diff --git a/docker/requirements.txt b/docker/requirements.txt
index d3268edae..21f246599 100644
--- a/docker/requirements.txt
+++ b/docker/requirements.txt
@@ -160,3 +160,4 @@ xlrd==2.0.2
xlsxwriter==3.2.5
prometheus-client==0.23.1
pymilvus==2.5.12
+langchain-text-splitters==1.0.0
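
The new `langchain-text-splitters` pin backs the `_split_text` chunking used by the reworked system and tool parsers below. A minimal sketch of that kind of chunking, assuming the helper wraps `RecursiveCharacterTextSplitter` (the chunk sizes here are illustrative, not the repo's settings):

```python
# Hedged sketch: assumes _split_text wraps RecursiveCharacterTextSplitter;
# chunk_size/chunk_overlap are illustrative, not the repo's configuration.
from langchain_text_splitters import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=64)
chunks = splitter.split_text("system: You are a helpful agent. " * 100)
print(len(chunks), repr(chunks[0][:40]))
```
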
diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py
index fe6b600b8..e9bb2e499 100644
--- a/src/memos/api/handlers/chat_handler.py
+++ b/src/memos/api/handlers/chat_handler.py
@@ -142,7 +142,9 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
# Step 2: Build system prompt
system_prompt = self._build_system_prompt(
- filtered_memories, search_response.data["pref_string"], chat_req.system_prompt
+ filtered_memories,
+ search_response.data.get("pref_string", ""),
+ chat_req.system_prompt,
)
# Prepare message history
@@ -257,7 +259,7 @@ def generate_chat_response() -> Generator[str, None, None]:
# Step 2: Build system prompt with memories
system_prompt = self._build_system_prompt(
filtered_memories,
- search_response.data["pref_string"],
+ search_response.data.get("pref_string", ""),
chat_req.system_prompt,
)
@@ -449,7 +451,7 @@ def generate_chat_response() -> Generator[str, None, None]:
# Step 2: Build system prompt with memories
system_prompt = self._build_enhance_system_prompt(
- filtered_memories, search_response.data["pref_string"]
+ filtered_memories, search_response.data.get("pref_string", "")
)
# Prepare messages
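
The three call sites above share one failure mode: when preference memory is disabled, the search payload may omit `pref_string`, and direct indexing raised `KeyError`. A tiny sketch of the difference, with fabricated payloads:

```python
# Fabricated payloads: "pref_string" is only present when preference
# memory is enabled, so data["pref_string"] raised KeyError without it.
with_prefs = {"pref_string": "Prefers concise answers.", "text_mem": []}
without_prefs = {"text_mem": []}

for data in (with_prefs, without_prefs):
    print(repr(data.get("pref_string", "")))  # the fix: default to ""
```
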
diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py
index 976be87bb..88875cacc 100644
--- a/src/memos/api/handlers/formatters_handler.py
+++ b/src/memos/api/handlers/formatters_handler.py
@@ -90,3 +90,37 @@ def post_process_pref_mem(
memories_result["pref_note"] = pref_note
return memories_result
+
+
+def post_process_textual_mem(
+ memories_result: dict[str, Any],
+ text_formatted_mem: list[dict[str, Any]],
+ mem_cube_id: str,
+) -> dict[str, Any]:
+ """
+ Split formatted memories into fact and tool buckets and append them to the result.
+ """
+ fact_mem = [
+ mem
+ for mem in text_formatted_mem
+ if mem["metadata"]["memory_type"] not in ["ToolSchemaMemory", "ToolTrajectoryMemory"]
+ ]
+ tool_mem = [
+ mem
+ for mem in text_formatted_mem
+ if mem["metadata"]["memory_type"] in ["ToolSchemaMemory", "ToolTrajectoryMemory"]
+ ]
+
+ memories_result["text_mem"].append(
+ {
+ "cube_id": mem_cube_id,
+ "memories": fact_mem,
+ }
+ )
+ memories_result["tool_mem"].append(
+ {
+ "cube_id": mem_cube_id,
+ "memories": tool_mem,
+ }
+ )
+ return memories_result
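
A usage sketch for the new helper; the input shape mirrors what `format_memory_item` is assumed to emit, and the values are fabricated:

```python
# Fabricated formatted memories; only metadata["memory_type"] matters here.
from memos.api.handlers.formatters_handler import post_process_textual_mem

formatted = [
    {"memory": "User lives in Berlin.", "metadata": {"memory_type": "LongTermMemory"}},
    {"memory": '{"name": "get_weather"}', "metadata": {"memory_type": "ToolSchemaMemory"}},
    {"memory": "get_weather succeeded for Berlin.", "metadata": {"memory_type": "ToolTrajectoryMemory"}},
]
result = {"text_mem": [], "tool_mem": []}
result = post_process_textual_mem(result, formatted, mem_cube_id="cube_1")
print(len(result["text_mem"][0]["memories"]))  # 1 fact memory
print(len(result["tool_mem"][0]["memories"]))  # 2 tool memories
```
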
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index ffe736aa3..4ad4016bc 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -6,7 +6,7 @@
# Import message types from core types module
from memos.log import get_logger
-from memos.types import MessageDict, PermissionDict, SearchMode
+from memos.types import MessageList, MessagesType, PermissionDict, SearchMode
logger = get_logger(__name__)
@@ -56,7 +56,7 @@ class Message(BaseModel):
class MemoryCreate(BaseRequest):
user_id: str = Field(..., description="User ID")
- messages: list | None = Field(None, description="List of messages to store.")
+ messages: MessageList | None = Field(None, description="List of messages to store.")
memory_content: str | None = Field(None, description="Content to store as memory")
doc_path: str | None = Field(None, description="Path to document to store")
mem_cube_id: str | None = Field(None, description="ID of the memory cube")
@@ -83,7 +83,7 @@ class ChatRequest(BaseRequest):
writable_cube_ids: list[str] | None = Field(
None, description="List of cube IDs user can write for multi-cube chat"
)
- history: list | None = Field(None, description="Chat history")
+ history: MessageList | None = Field(None, description="Chat history")
mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
top_k: int = Field(10, description="Number of results to return")
@@ -165,7 +165,7 @@ class ChatCompleteRequest(BaseRequest):
user_id: str = Field(..., description="User ID")
query: str = Field(..., description="Chat query message")
mem_cube_id: str | None = Field(None, description="Cube ID to use for chat")
- history: list | None = Field(None, description="Chat history")
+ history: MessageList | None = Field(None, description="Chat history")
internet_search: bool = Field(False, description="Whether to use internet search")
system_prompt: str | None = Field(None, description="Base prompt to use for chat")
top_k: int = Field(10, description="Number of results to return")
@@ -251,7 +251,7 @@ class MemoryCreateRequest(BaseRequest):
"""Request model for creating memories."""
user_id: str = Field(..., description="User ID")
- messages: str | list | None = Field(None, description="List of messages to store.")
+ messages: str | MessagesType | None = Field(None, description="List of messages to store.")
memory_content: str | None = Field(None, description="Memory content to store")
doc_path: str | None = Field(None, description="Path to document to store")
mem_cube_id: str | None = Field(None, description="Cube ID")
@@ -326,6 +326,21 @@ class APISearchRequest(BaseRequest):
),
)
+ search_tool_memory: bool = Field(
+ True,
+ description=(
+ "Whether to retrieve tool memories along with general memories. "
+ "If enabled, the system will automatically recall tool memories "
+ "relevant to the query. Default: True."
+ ),
+ )
+
+ tool_mem_top_k: int = Field(
+ 6,
+ ge=0,
+ description="Number of tool memories to retrieve (top-K). Default: 6.",
+ )
+
# ==== Filter conditions ====
# TODO: maybe add detailed description later
filter: dict[str, Any] | None = Field(
@@ -360,7 +375,7 @@ class APISearchRequest(BaseRequest):
)
# ==== Context ====
- chat_history: list | None = Field(
+ chat_history: MessageList | None = Field(
None,
description=(
"Historical chat messages used internally by algorithms. "
@@ -490,7 +505,7 @@ class APIADDRequest(BaseRequest):
)
# ==== Input content ====
- messages: str | list | None = Field(
+ messages: MessagesType | None = Field(
None,
description=(
"List of messages to store. Supports: "
@@ -506,7 +521,7 @@ class APIADDRequest(BaseRequest):
)
# ==== Chat history ====
- chat_history: list | None = Field(
+ chat_history: MessageList | None = Field(
None,
description=(
"Historical chat messages used internally by algorithms. "
@@ -636,7 +651,7 @@ class APIFeedbackRequest(BaseRequest):
"default_session", description="Session ID for soft-filtering memories"
)
task_id: str | None = Field(None, description="Task ID for monitering async tasks")
- history: list[MessageDict] | None = Field(..., description="Chat history")
+ history: MessageList | None = Field(..., description="Chat history")
retrieved_memory_ids: list[str] | None = Field(
None, description="Retrieved memory ids at last turn"
)
@@ -671,7 +686,7 @@ class APIChatCompleteRequest(BaseRequest):
writable_cube_ids: list[str] | None = Field(
None, description="List of cube IDs user can write for multi-cube chat"
)
- history: list | None = Field(None, description="Chat history")
+ history: MessageList | None = Field(None, description="Chat history")
mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
top_k: int = Field(10, description="Number of results to return")
@@ -740,7 +755,7 @@ class SuggestionRequest(BaseRequest):
user_id: str = Field(..., description="User ID")
mem_cube_id: str = Field(..., description="Cube ID")
language: Literal["zh", "en"] = Field("zh", description="Language for suggestions")
- message: list | None = Field(None, description="List of messages to store.")
+ message: MessagesType | None = Field(None, description="List of messages to store.")
# ─── MemOS Client Response Models ──────────────────────────────────────────────
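
The `list` → `MessageList`/`MessagesType` tightening means pydantic validates each history entry instead of accepting any list. A self-contained sketch of the effect, using stand-in types rather than the repo's actual `MessageList` definition:

```python
# Stand-in types, not memos.types: illustrates why typed history fields
# reject malformed entries that a bare `list` silently accepted.
from typing import Literal
from pydantic import BaseModel, TypeAdapter, ValidationError

class MessageDict(BaseModel):
    role: Literal["user", "assistant", "system", "tool"]
    content: str

adapter = TypeAdapter(list[MessageDict])
adapter.validate_python([{"role": "user", "content": "hi"}])  # ok
try:
    adapter.validate_python([{"note": "no role field"}])  # now rejected
except ValidationError as e:
    print("rejected with", e.error_count(), "errors")
```
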
diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index 57774cf3a..e0aa40913 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -1,4 +1,5 @@
import concurrent.futures
+import json
import traceback
from typing import Any
@@ -7,8 +8,9 @@
from memos.configs.mem_reader import MultiModalStructMemReaderConfig
from memos.context.context import ContextThreadPoolExecutor
from memos.mem_reader.read_multi_modal import MultiModalParser
-from memos.mem_reader.simple_struct import SimpleStructMemReader
+from memos.mem_reader.simple_struct import SimpleStructMemReader, detect_lang
from memos.memories.textual.item import TextualMemoryItem
+from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
from memos.types import MessagesType
from memos.utils import timed
@@ -297,6 +299,61 @@ def _process_string_fine(
return fine_memory_items
+ def _get_llm_tool_trajectory_response(self, mem_str: str) -> list[dict]:
+ """
+ Generate tool trajectory experience items via the LLM.
+ """
+ try:
+ lang = detect_lang(mem_str)
+ template = TOOL_TRAJECTORY_PROMPT_ZH if lang == "zh" else TOOL_TRAJECTORY_PROMPT_EN
+ prompt = template.replace("{messages}", mem_str)
+ rsp = self.llm.generate([{"role": "user", "content": prompt}])
+ rsp = rsp.replace("```json", "").replace("```", "")
+ return json.loads(rsp)
+ except Exception as e:
+ logger.error(f"[MultiModalFine] Error calling LLM for tool trajectory: {e}")
+ return []
+
+ def _process_tool_trajectory_fine(
+ self,
+ fast_memory_items: list[TextualMemoryItem],
+ info: dict[str, Any],
+ ) -> list[TextualMemoryItem]:
+ """
+ Process tool trajectory memory items through LLM to generate fine mode memories.
+ """
+ if not fast_memory_items:
+ return []
+
+ fine_memory_items = []
+
+ for fast_item in fast_memory_items:
+ # Extract memory text (string content)
+ mem_str = fast_item.memory or ""
+ if not mem_str.strip() or "tool:" not in mem_str:
+ continue
+ try:
+ resp = self._get_llm_tool_trajectory_response(mem_str)
+ except Exception as e:
+ logger.error(f"[MultiModalFine] Error calling LLM for tool trajectory: {e}")
+ continue
+ for m in resp:
+ try:
+ # All extracted trajectories are stored as ToolTrajectoryMemory
+ memory_type = "ToolTrajectoryMemory"
+
+ node = self._make_memory_item(
+ value=m.get("trajectory", ""),
+ info=info,
+ memory_type=memory_type,
+ tool_used_status=m.get("tool_used_status", []),
+ )
+ fine_memory_items.append(node)
+ except Exception as e:
+ logger.error(f"[MultiModalFine] parse error for tool trajectory: {e}")
+
+ return fine_memory_items
+
@timed
def _process_multi_modal_data(
self, scene_data_info: MessagesType, info, mode: str = "fine", **kwargs
@@ -339,6 +396,11 @@ def _process_multi_modal_data(
)
fine_memory_items.extend(fine_memory_items_string_parser)
+ fine_memory_items_tool_trajectory_parser = self._process_tool_trajectory_fine(
+ fast_memory_items, info
+ )
+ fine_memory_items.extend(fine_memory_items_tool_trajectory_parser)
+
# Part B: get fine multimodal items
for fast_item in fast_memory_items:
sources = fast_item.metadata.sources
@@ -377,6 +439,12 @@ def _process_transfer_multi_modal_data(
# Part A: call llm
fine_memory_items_string_parser = self._process_string_fine([raw_node], info, custom_tags)
fine_memory_items.extend(fine_memory_items_string_parser)
+
+ fine_memory_items_tool_trajectory_parser = self._process_tool_trajectory_fine(
+ [raw_node], info
+ )
+ fine_memory_items.extend(fine_memory_items_tool_trajectory_parser)
+
# Part B: get fine multimodal items
for source in sources:
items = self.multi_modal_parser.process_transfer(
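
The fine pass above funnels each fast item through `_get_llm_tool_trajectory_response`. A minimal sketch of that round trip with a fabricated LLM reply, showing the per-entry parsing the reader relies on:

```python
import json

# Fabricated reply; the real helper first strips the model's Markdown
# code fences with str.replace before parsing.
rsp = (
    '[{"trajectory": "Task: weather lookup. Tool: get_weather. Result: 21C.",'
    ' "tool_used_status": [{"used_tool": "get_weather", "success_rate": "1.0",'
    ' "error_type": "", "experience": "city name is enough"}]}]'
)
for m in json.loads(rsp):
    print(m.get("trajectory", "")[:40], len(m.get("tool_used_status", [])))
```
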
diff --git a/src/memos/mem_reader/read_multi_modal/system_parser.py b/src/memos/mem_reader/read_multi_modal/system_parser.py
index d2a6611af..3f467d649 100644
--- a/src/memos/mem_reader/read_multi_modal/system_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/system_parser.py
@@ -1,5 +1,9 @@
"""Parser for system messages."""
+import json
+import re
+import uuid
+
from typing import Any
from memos.embedders.base import BaseEmbedder
@@ -12,7 +16,7 @@
)
from memos.types.openai_chat_completion_types import ChatCompletionSystemMessageParam
-from .base import BaseMessageParser, _derive_key, _extract_text_from_content
+from .base import BaseMessageParser
logger = get_logger(__name__)
@@ -35,63 +39,42 @@ def create_source(
self,
message: ChatCompletionSystemMessageParam,
info: dict[str, Any],
- ) -> SourceMessage | list[SourceMessage]:
- """
- Create SourceMessage(s) from system message.
-
- For multimodal messages (content is a list of text parts), creates one SourceMessage per part.
- For simple messages (content is str), creates a single SourceMessage.
- """
- if not isinstance(message, dict):
- return []
-
- role = message.get("role", "system")
- raw_content = message.get("content", "")
- chat_time = message.get("chat_time")
- message_id = message.get("message_id")
-
- sources = []
-
- if isinstance(raw_content, list):
- # Multimodal: create one SourceMessage per text part
- for part in raw_content:
- if isinstance(part, dict):
- part_type = part.get("type", "")
- if part_type == "text":
- sources.append(
- SourceMessage(
- type="chat",
- role=role,
- chat_time=chat_time,
- message_id=message_id,
- content=part.get("text", ""),
- )
- )
- else:
- # Simple message: single SourceMessage
- content = _extract_text_from_content(raw_content)
- if content:
- sources.append(
- SourceMessage(
- type="chat",
- role=role,
- chat_time=chat_time,
- message_id=message_id,
- content=content,
- )
- )
-
- return (
- sources
- if len(sources) > 1
- else (sources[0] if sources else SourceMessage(type="chat", role=role))
+ ) -> SourceMessage:
+ """Create SourceMessage from system message."""
+ content = message["content"]
+ if isinstance(content, dict):
+ content = content["text"]
+
+ content_wo_tool_schema = re.sub(
+ r"<tool_schema>(.*?)</tool_schema>",
+ r"<tool_schema>omitted</tool_schema>",
+ content,
+ flags=re.DOTALL,
+ )
+ tool_schema_match = re.search(r"<tool_schema>(.*?)</tool_schema>", content, re.DOTALL)
+ tool_schema_content = tool_schema_match.group(1) if tool_schema_match else ""
+
+ return SourceMessage(
+ type="chat",
+ role="system",
+ chat_time=message.get("chat_time", None),
+ message_id=message.get("message_id", None),
+ content=content_wo_tool_schema,
+ tool_schema=tool_schema_content,
)
def rebuild_from_source(
self,
source: SourceMessage,
) -> ChatCompletionSystemMessageParam:
- """We only need rebuild from specific multimodal source"""
+ """Rebuild system message from SourceMessage."""
+ # Only rebuild the tool schema; the full chat content is reassembled for the LLM elsewhere.
+ return {
+ "role": "system",
+ "content": source.tool_schema or "",
+ "chat_time": source.chat_time,
+ "message_id": source.message_id,
+ }
def parse_fast(
self,
@@ -99,59 +82,47 @@ def parse_fast(
info: dict[str, Any],
**kwargs,
) -> list[TextualMemoryItem]:
- if not isinstance(message, dict):
- logger.warning(f"[SystemParser] Expected dict, got {type(message)}")
- return []
-
- role = message.get("role", "")
- raw_content = message.get("content", "")
- chat_time = message.get("chat_time", None)
- content = _extract_text_from_content(raw_content)
- if role != "system":
- logger.warning(f"[SystemParser] Expected role is `system`, got {role}")
- return []
- parts = [f"{role}: "]
- if chat_time:
- parts.append(f"[{chat_time}]: ")
- prefix = "".join(parts)
- line = f"{prefix}{content}\n"
- if not line:
- return []
- memory_type = "LongTermMemory"
+ content = message["content"]
+ if isinstance(content, dict):
+ content = content["text"]
+
+ # Replace tool_schema content with "omitted" in remaining content
+ content_wo_tool_schema = re.sub(
+ r"<tool_schema>(.*?)</tool_schema>",
+ r"<tool_schema>omitted</tool_schema>",
+ content,
+ flags=re.DOTALL,
+ )
- # Create source(s) using parser's create_source method
- sources = self.create_source(message, info)
- if isinstance(sources, SourceMessage):
- sources = [sources]
- elif not sources:
- return []
+ source = self.create_source(message, info)
# Extract info fields
info_ = info.copy()
user_id = info_.pop("user_id", "")
session_id = info_.pop("session_id", "")
- # Create memory item (equivalent to _make_memory_item)
- memory_item = TextualMemoryItem(
- memory=line,
- metadata=TreeNodeTextualMemoryMetadata(
- user_id=user_id,
- session_id=session_id,
- memory_type=memory_type,
- status="activated",
- tags=["mode:fast"],
- key=_derive_key(line),
- embedding=self.embedder.embed([line])[0],
- usage=[],
- sources=sources,
- background="",
- confidence=0.99,
- type="fact",
- info=info_,
- ),
- )
-
- return [memory_item]
+ # Split parsed text into chunks
+ content_chunks = self._split_text(content_wo_tool_schema)
+
+ memory_items = []
+ for chunk_text in content_chunks:
+ if not chunk_text.strip():
+ continue
+
+ memory_item = TextualMemoryItem(
+ memory=chunk_text,
+ metadata=TreeNodeTextualMemoryMetadata(
+ user_id=user_id,
+ session_id=session_id,
+ memory_type="LongTermMemory", # only choce long term memory for system messages as a placeholder
+ status="activated",
+ tags=["mode:fast"],
+ sources=[source],
+ info=info_,
+ ),
+ )
+ memory_items.append(memory_item)
+ return memory_items
def parse_fine(
self,
@@ -159,4 +130,35 @@ def parse_fine(
info: dict[str, Any],
**kwargs,
) -> list[TextualMemoryItem]:
- return []
+ content = message["content"]
+ if isinstance(content, dict):
+ content = content["text"]
+ try:
+ tool_schema = json.loads(content)
+ assert isinstance(tool_schema, list), "Tool schema must be a list[dict]"
+ except json.JSONDecodeError:
+ logger.warning(f"[SystemParser] Failed to parse tool schema: {content}")
+ return []
+ except AssertionError:
+ logger.warning(f"[SystemParser] Tool schema must be a list[dict]: {content}")
+ return []
+
+ info_ = info.copy()
+ user_id = info_.pop("user_id", "")
+ session_id = info_.pop("session_id", "")
+
+ return [
+ TextualMemoryItem(
+ id=str(uuid.uuid4()),
+ memory=json.dumps(schema),
+ metadata=TreeNodeTextualMemoryMetadata(
+ user_id=user_id,
+ session_id=session_id,
+ memory_type="ToolSchemaMemory",
+ status="activated",
+ embedding=self.embedder.embed([json.dumps(schema)])[0],
+ info=info_,
+ ),
+ )
+ for schema in tool_schema
+ ]
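
A small demo of the system parser's split between instruction text and tool schema, assuming the sentinel markers are `<tool_schema>...</tool_schema>` tags (inferred from the parser's variable names):

```python
# Assumes <tool_schema>...</tool_schema> sentinels; the tag name is an
# inference from the variable names, not confirmed elsewhere.
import re

content = (
    "You are a weather agent.\n"
    '<tool_schema>[{"name": "get_weather"}]</tool_schema>'
)
wo_schema = re.sub(
    r"<tool_schema>(.*?)</tool_schema>",
    "<tool_schema>omitted</tool_schema>",
    content,
    flags=re.DOTALL,
)
match = re.search(r"<tool_schema>(.*?)</tool_schema>", content, re.DOTALL)
print(wo_schema)                        # schema body replaced with "omitted"
print(match.group(1) if match else "")  # raw JSON schema list for parse_fine
```
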
diff --git a/src/memos/mem_reader/read_multi_modal/tool_parser.py b/src/memos/mem_reader/read_multi_modal/tool_parser.py
index 7a11d931a..09bd9e9d0 100644
--- a/src/memos/mem_reader/read_multi_modal/tool_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/tool_parser.py
@@ -1,14 +1,20 @@
"""Parser for tool messages."""
+import json
+
from typing import Any
from memos.embedders.base import BaseEmbedder
from memos.llms.base import BaseLLM
from memos.log import get_logger
-from memos.memories.textual.item import SourceMessage, TextualMemoryItem
+from memos.memories.textual.item import (
+ SourceMessage,
+ TextualMemoryItem,
+ TreeNodeTextualMemoryMetadata,
+)
from memos.types.openai_chat_completion_types import ChatCompletionToolMessageParam
-from .base import BaseMessageParser, _extract_text_from_content
+from .base import BaseMessageParser
logger = get_logger(__name__)
@@ -29,190 +35,155 @@ def __init__(self, embedder: BaseEmbedder, llm: BaseLLM | None = None):
def create_source(
self,
- message: ChatCompletionToolMessageParam | dict[str, Any],
+ message: ChatCompletionToolMessageParam,
info: dict[str, Any],
- ) -> SourceMessage:
- """Create SourceMessage from tool message or custom tool format."""
+ ) -> SourceMessage | list[SourceMessage]:
+ """Create SourceMessage from tool message."""
+
if not isinstance(message, dict):
- return SourceMessage(type="chat", role="tool")
-
- # Handle custom tool formats (tool_description, tool_input, tool_output)
- msg_type = message.get("type", "")
- if msg_type == "tool_description":
- name = message.get("name", "")
- description = message.get("description", "")
- parameters = message.get("parameters", {})
- content = f"[tool_description] name={name}, description={description}, parameters={parameters}"
- return SourceMessage(
- type="tool_description",
- content=content,
- original_part=message,
- )
- elif msg_type == "tool_input":
- call_id = message.get("call_id", "")
- name = message.get("name", "")
- argument = message.get("argument", {})
- content = f"[tool_input] call_id={call_id}, name={name}, argument={argument}"
- return SourceMessage(
- type="tool_input",
- content=content,
- message_id=call_id,
- original_part=message,
- )
- elif msg_type == "tool_output":
- call_id = message.get("call_id", "")
- name = message.get("name", "")
- output = message.get("output", {})
- content = f"[tool_output] call_id={call_id}, name={name}, output={output}"
- return SourceMessage(
- type="tool_output",
- content=content,
- message_id=call_id,
- original_part=message,
- )
+ return []
- # Handle standard tool message
- content = _extract_text_from_content(message.get("content", ""))
- return SourceMessage(
- type="tool",
- role="tool",
- chat_time=message.get("chat_time"),
- message_id=message.get("message_id"),
- content=content,
- )
+ role = message.get("role", "tool")
+ raw_content = message.get("content", "")
+ tool_call_id = message.get("tool_call_id", "")
+ chat_time = message.get("chat_time")
+ message_id = message.get("message_id")
+
+ sources = []
+
+ if isinstance(raw_content, list):
+ # Multimodal: create one SourceMessage per part
+ for part in raw_content:
+ if isinstance(part, dict):
+ part_type = part.get("type", "")
+ if part_type == "text":
+ sources.append(
+ SourceMessage(
+ type="text",
+ role=role,
+ chat_time=chat_time,
+ message_id=message_id,
+ content=part.get("text", ""),
+ tool_call_id=tool_call_id,
+ )
+ )
+ elif part_type == "file":
+ file_info = part.get("file", {})
+ sources.append(
+ SourceMessage(
+ type="file",
+ role=role,
+ chat_time=chat_time,
+ message_id=message_id,
+ content=file_info.get("file_data", ""),
+ filename=file_info.get("filename", ""),
+ file_id=file_info.get("file_id", ""),
+ tool_call_id=tool_call_id,
+ original_part=part,
+ )
+ )
+ elif part_type == "image_url":
+ file_info = part.get("image_url", {})
+ sources.append(
+ SourceMessage(
+ type="image_url",
+ role=role,
+ chat_time=chat_time,
+ message_id=message_id,
+ content=file_info.get("url", ""),
+ detail=file_info.get("detail", "auto"),
+ tool_call_id=tool_call_id,
+ original_part=part,
+ )
+ )
+ elif part_type == "input_audio":
+ file_info = part.get("input_audio", {})
+ sources.append(
+ SourceMessage(
+ type="input_audio",
+ role=role,
+ chat_time=chat_time,
+ message_id=message_id,
+ content=file_info.get("data", ""),
+ format=file_info.get("format", "wav"),
+ tool_call_id=tool_call_id,
+ original_part=part,
+ )
+ )
+ else:
+ logger.warning(f"[ToolParser] Unsupported part type: {part_type}")
+ continue
+ else:
+ # Simple string content message: single SourceMessage
+ if raw_content:
+ sources.append(
+ SourceMessage(
+ type="chat",
+ role=role,
+ chat_time=chat_time,
+ message_id=message_id,
+ content=raw_content,
+ tool_call_id=tool_call_id,
+ )
+ )
+
+ return sources
def rebuild_from_source(
self,
source: SourceMessage,
) -> ChatCompletionToolMessageParam:
"""Rebuild tool message from SourceMessage."""
- return {
- "role": "tool",
- "content": source.content or "",
- "tool_call_id": source.message_id or "", # tool_call_id might be in message_id
- "chat_time": source.chat_time,
- "message_id": source.message_id,
- }
def parse_fast(
self,
- message: ChatCompletionToolMessageParam | dict[str, Any],
+ message: ChatCompletionToolMessageParam,
info: dict[str, Any],
**kwargs,
) -> list[TextualMemoryItem]:
- """
- Parse tool message in fast mode.
-
- Handles both standard tool messages and custom tool formats:
- - Standard tool message: role="tool", content, tool_call_id
- - Custom formats: tool_description, tool_input, tool_output
+ role = message.get("role", "")
+ content = message.get("content", "")
+ chat_time = message.get("chat_time", None)
- Args:
- message: Tool message to parse
- info: Dictionary containing user_id and session_id
- **kwargs: Additional parameters
-
- Returns:
- List of TextualMemoryItem objects
- """
- from memos.memories.textual.item import TreeNodeTextualMemoryMetadata
-
- from .base import _derive_key
-
- if not isinstance(message, dict):
- logger.warning(f"[ToolParser] Expected dict, got {type(message)}")
- return []
-
- # Handle custom tool formats (tool_description, tool_input, tool_output)
- msg_type = message.get("type", "")
- if msg_type in ("tool_description", "tool_input", "tool_output"):
- # Create source
- source = self.create_source(message, info)
- content = source.content or ""
- if not content:
- return []
-
- # Extract info fields
- info_ = info.copy()
- user_id = info_.pop("user_id", "")
- session_id = info_.pop("session_id", "")
-
- # Create memory item
- memory_item = TextualMemoryItem(
- memory=content,
- metadata=TreeNodeTextualMemoryMetadata(
- user_id=user_id,
- session_id=session_id,
- memory_type="LongTermMemory",
- status="activated",
- tags=["mode:fast"],
- key=_derive_key(content),
- embedding=self.embedder.embed([content])[0],
- usage=[],
- sources=[source],
- background="",
- confidence=0.99,
- type="fact",
- info=info_,
- ),
- )
- return [memory_item]
-
- # Handle standard tool message (role="tool")
- role = message.get("role", "").strip().lower()
if role != "tool":
- logger.warning(f"[ToolParser] Expected role='tool', got role='{role}'")
+ logger.warning(f"[ToolParser] Expected role is `tool`, got {role}")
return []
-
- # Extract content from tool message
- content = _extract_text_from_content(message.get("content", ""))
- if not content:
- return []
-
- # Build formatted line similar to assistant_parser
- tool_call_id = message.get("tool_call_id", "")
- chat_time = message.get("chat_time")
-
parts = [f"{role}: "]
if chat_time:
parts.append(f"[{chat_time}]: ")
- if tool_call_id:
- parts.append(f"[tool_call_id: {tool_call_id}]: ")
prefix = "".join(parts)
+ content = json.dumps(content) if isinstance(content, list | dict) else content
line = f"{prefix}{content}\n"
+ if not line:
+ return []
- # Create source
- source = self.create_source(message, info)
+ sources = self.create_source(message, info)
# Extract info fields
info_ = info.copy()
user_id = info_.pop("user_id", "")
session_id = info_.pop("session_id", "")
- # Tool messages are typically LongTermMemory (they're system/assistant tool results)
- memory_type = "LongTermMemory"
-
- # Create memory item
- memory_item = TextualMemoryItem(
- memory=line,
- metadata=TreeNodeTextualMemoryMetadata(
- user_id=user_id,
- session_id=session_id,
- memory_type=memory_type,
- status="activated",
- tags=["mode:fast"],
- key=_derive_key(line),
- embedding=self.embedder.embed([line])[0],
- usage=[],
- sources=[source],
- background="",
- confidence=0.99,
- type="fact",
- info=info_,
- ),
- )
-
- return [memory_item]
+ content_chunks = self._split_text(line)
+ memory_items = []
+ for chunk_text in content_chunks:
+ if not chunk_text.strip():
+ continue
+
+ memory_item = TextualMemoryItem(
+ memory=chunk_text,
+ metadata=TreeNodeTextualMemoryMetadata(
+ user_id=user_id,
+ session_id=session_id,
+ memory_type="LongTermMemory", # only choce long term memory for tool messages as a placeholder
+ status="activated",
+ tags=["mode:fast"],
+ sources=sources,
+ info=info_,
+ ),
+ )
+ memory_items.append(memory_item)
+ return memory_items
def parse_fine(
self,
@@ -220,4 +191,5 @@ def parse_fine(
info: dict[str, Any],
**kwargs,
) -> list[TextualMemoryItem]:
+ # Tool messages need no special multimodal handling in fine mode.
return []
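
A sketch of the fast-path line the rewritten tool parser now builds before chunking; the message is fabricated and shows the `json.dumps` fallback for structured content:

```python
# Fabricated tool message; mirrors the prefix + json.dumps fallback above.
import json

message = {"role": "tool", "content": {"temp_c": 21}, "chat_time": "2024-05-01 10:00"}
parts = [f"{message['role']}: "]
if message.get("chat_time"):
    parts.append(f"[{message['chat_time']}]: ")
content = message["content"]
content = json.dumps(content) if isinstance(content, (list, dict)) else content
line = "".join(parts) + f"{content}\n"
print(line)  # tool: [2024-05-01 10:00]: {"temp_c": 21}
```
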
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index 53a7de035..7f7b16234 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -223,6 +223,7 @@ def _make_memory_item(
background: str = "",
type_: str = "fact",
confidence: float = 0.99,
+ **kwargs,
) -> TextualMemoryItem:
"""construct memory item"""
info_ = info.copy()
@@ -245,6 +246,7 @@ def _make_memory_item(
confidence=confidence,
type=type_,
info=info_,
+ **kwargs,
),
)
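
The `**kwargs` pass-through lets callers attach new metadata fields (the trajectory pass above uses it for `tool_used_status`) without widening the signature each time. A toy illustration with a stand-in model, not the repo's actual metadata class:

```python
# Stand-in metadata model; illustrates kwargs flowing through a factory
# into the metadata constructor, as _make_memory_item now allows.
from pydantic import BaseModel

class Metadata(BaseModel):
    memory_type: str = "WorkingMemory"
    tool_used_status: list[dict] | None = None

def make_metadata(memory_type: str, **kwargs) -> Metadata:
    return Metadata(memory_type=memory_type, **kwargs)

meta = make_metadata(
    "ToolTrajectoryMemory",
    tool_used_status=[{"used_tool": "get_weather", "success_rate": "1.0"}],
)
print(meta.tool_used_status)
```
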
diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py
index a85c533a0..f99360a86 100644
--- a/src/memos/mem_scheduler/optimized_scheduler.py
+++ b/src/memos/mem_scheduler/optimized_scheduler.py
@@ -159,6 +159,8 @@ def mix_search_memories(
search_filter=search_filter,
search_priority=search_priority,
info=info,
+ search_tool_memory=search_req.search_tool_memory,
+ tool_mem_top_k=search_req.tool_mem_top_k,
)
# Try to get pre-computed memories if available
@@ -182,6 +184,8 @@ def mix_search_memories(
top_k=search_req.top_k,
user_name=user_context.mem_cube_id,
info=info,
+ search_tool_memory=search_req.search_tool_memory,
+ tool_mem_top_k=search_req.tool_mem_top_k,
)
memories = merged_memories[: search_req.top_k]
diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py
index b7956bfec..75a16bace 100644
--- a/src/memos/memories/textual/item.py
+++ b/src/memos/memories/textual/item.py
@@ -24,7 +24,7 @@ class SourceMessage(BaseModel):
- type: Source kind (e.g., "chat", "doc", "web", "file", "system", ...).
If not provided, upstream logic may infer it:
presence of `role` ⇒ "chat"; otherwise ⇒ "doc".
- - role: Conversation role ("user" | "assistant" | "system") when the
+ - role: Conversation role ("user" | "assistant" | "system" | "tool") when the
source is a chat turn.
- content: Minimal reproducible snippet from the source. If omitted,
upstream may fall back to `doc_path` / `url` / `message_id`.
@@ -99,9 +99,14 @@ def __str__(self) -> str:
class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata):
"""Extended metadata for structured memory, layered retrieval, and lifecycle tracking."""
- memory_type: Literal["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"] = Field(
- default="WorkingMemory", description="Memory lifecycle type."
- )
+ memory_type: Literal[
+ "WorkingMemory",
+ "LongTermMemory",
+ "UserMemory",
+ "OuterMemory",
+ "ToolSchemaMemory",
+ "ToolTrajectoryMemory",
+ ] = Field(default="WorkingMemory", description="Memory lifecycle type.")
sources: list[SourceMessage] | None = Field(
default=None, description="Multiple origins of the memory (e.g., URLs, notes)."
)
diff --git a/src/memos/memories/textual/prefer_text_memory/spliter.py b/src/memos/memories/textual/prefer_text_memory/spliter.py
index 3059d611b..a54036778 100644
--- a/src/memos/memories/textual/prefer_text_memory/spliter.py
+++ b/src/memos/memories/textual/prefer_text_memory/spliter.py
@@ -87,7 +87,7 @@ def _split_with_overlap(self, data: MessageList) -> list[MessageList]:
# overlap 1 turns (Q + A = 2)
context = copy.deepcopy(chunk[-2:]) if i + 1 < len(data) else []
chunk = context
- if chunk and len(chunk) % 2 == 0:
+ if chunk:
chunks.append(chunk)
return chunks
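
The relaxed guard matters once tool turns enter the stream: a chunk is no longer guaranteed to hold an even Q/A count, so an odd-sized trailing chunk (previously dropped) is now kept. Illustration with a fabricated three-turn chunk:

```python
# Fabricated chunk: user question + assistant tool call + tool result.
chunk = [
    {"role": "user", "content": "Weather in Berlin?"},
    {"role": "assistant", "content": "", "tool_calls": [{"id": "c1"}]},
    {"role": "tool", "content": "21C", "tool_call_id": "c1"},
]
chunks = []
if chunk:  # old guard `if chunk and len(chunk) % 2 == 0` dropped this chunk
    chunks.append(chunk)
print(len(chunks))  # 1
```
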
diff --git a/src/memos/memories/textual/prefer_text_memory/utils.py b/src/memos/memories/textual/prefer_text_memory/utils.py
index 76d4b4211..03d2ef923 100644
--- a/src/memos/memories/textual/prefer_text_memory/utils.py
+++ b/src/memos/memories/textual/prefer_text_memory/utils.py
@@ -1,3 +1,4 @@
+import json
import re
from memos.dependency import require_python_package
@@ -9,12 +10,36 @@ def convert_messages_to_string(messages: MessageList) -> str:
"""Convert a list of messages to a string."""
message_text = ""
for message in messages:
+ content = message.get("content", "")
+ content = (
+ content.strip()
+ if isinstance(content, str)
+ else json.dumps(content, ensure_ascii=False).strip()
+ )
+ if message["role"] == "system":
+ continue
if message["role"] == "user":
- message_text += f"Query: {message['content']}\n" if message["content"].strip() else ""
+ message_text += f"User: {content}\n" if content else ""
elif message["role"] == "assistant":
- message_text += f"Answer: {message['content']}\n" if message["content"].strip() else ""
- message_text = message_text.strip()
- return message_text
+ tool_calls = message.get("tool_calls", [])
+ tool_calls_str = (
+ f"[tool_calls]: {json.dumps(tool_calls, ensure_ascii=False)}" if tool_calls else ""
+ )
+ line_str = (
+ f"Assistant: {content} {tool_calls_str}".strip()
+ if content or tool_calls_str
+ else ""
+ )
+ message_text += f"{line_str}\n" if line_str else ""
+ elif message["role"] == "tool":
+ tool_call_id = message.get("tool_call_id", "")
+ line_str = (
+ f"Tool: {content} [tool_call_id]: {tool_call_id}".strip()
+ if tool_call_id
+ else f"Tool: {content}".strip()
+ )
+ message_text += f"{line_str}\n" if line_str else ""
+ return message_text.strip()
@require_python_package(
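
Expected rendering from the extended `convert_messages_to_string` on a tool-calling exchange (messages fabricated; output shown approximately):

```python
from memos.memories.textual.prefer_text_memory.utils import convert_messages_to_string

messages = [
    {"role": "system", "content": "You are a weather agent."},  # skipped
    {"role": "user", "content": "Weather in Berlin?"},
    {"role": "assistant", "content": "", "tool_calls": [{"id": "c1"}]},
    {"role": "tool", "content": {"temp_c": 21}, "tool_call_id": "c1"},
    {"role": "assistant", "content": "It is 21C in Berlin."},
]
print(convert_messages_to_string(messages))
# User: Weather in Berlin?
# Assistant: [tool_calls]: [{"id": "c1"}]
# Tool: {"temp_c": 21} [tool_call_id]: c1
# Assistant: It is 21C in Berlin.
```
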
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py
index ad2bcd9c4..cad850d2d 100644
--- a/src/memos/memories/textual/tree.py
+++ b/src/memos/memories/textual/tree.py
@@ -166,6 +166,8 @@ def search(
search_priority: dict | None = None,
search_filter: dict | None = None,
user_name: str | None = None,
+ search_tool_memory: bool = False,
+ tool_mem_top_k: int = 6,
**kwargs,
) -> list[TextualMemoryItem]:
"""Search for memories based on a query.
@@ -223,6 +225,8 @@ def search(
search_priority,
user_name=user_name,
plugin=kwargs.get("plugin", False),
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
)
def get_relevant_subgraph(
diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py
index a71fee02f..3226f7ca0 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/manager.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py
@@ -181,12 +181,18 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non
working_id = str(uuid.uuid4())
with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex:
- f_working = ex.submit(
- self._add_memory_to_db, memory, "WorkingMemory", user_name, working_id
- )
- futures.append(("working", f_working))
-
- if memory.metadata.memory_type in ("LongTermMemory", "UserMemory"):
+ if memory.metadata.memory_type not in ("ToolSchemaMemory", "ToolTrajectoryMemory"):
+ f_working = ex.submit(
+ self._add_memory_to_db, memory, "WorkingMemory", user_name, working_id
+ )
+ futures.append(("working", f_working))
+
+ if memory.metadata.memory_type in (
+ "LongTermMemory",
+ "UserMemory",
+ "ToolSchemaMemory",
+ "ToolTrajectoryMemory",
+ ):
f_graph = ex.submit(
self._add_to_graph_memory,
memory=memory,
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
index 5dfbde704..dea83887e 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
@@ -59,7 +59,13 @@ def retrieve(
Returns:
list: Combined memory items.
"""
- if memory_scope not in ["WorkingMemory", "LongTermMemory", "UserMemory"]:
+ if memory_scope not in [
+ "WorkingMemory",
+ "LongTermMemory",
+ "UserMemory",
+ "ToolSchemaMemory",
+ "ToolTrajectoryMemory",
+ ]:
raise ValueError(f"Unsupported memory scope: {memory_scope}")
if memory_scope == "WorkingMemory":
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index 830b915c1..0666f1d86 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -76,6 +76,8 @@ def retrieve(
search_filter: dict | None = None,
search_priority: dict | None = None,
user_name: str | None = None,
+ search_tool_memory: bool = False,
+ tool_mem_top_k: int = 6,
**kwargs,
) -> list[tuple[TextualMemoryItem, float]]:
logger.info(
@@ -100,6 +102,8 @@ def retrieve(
search_filter,
search_priority,
user_name,
+ search_tool_memory,
+ tool_mem_top_k,
)
return results
@@ -109,10 +113,14 @@ def post_retrieve(
top_k: int,
user_name: str | None = None,
info=None,
+ search_tool_memory: bool = False,
+ tool_mem_top_k: int = 6,
plugin=False,
):
deduped = self._deduplicate_results(retrieved_results)
- final_results = self._sort_and_trim(deduped, top_k, plugin)
+ final_results = self._sort_and_trim(
+ deduped, top_k, plugin, search_tool_memory, tool_mem_top_k
+ )
self._update_usage_history(final_results, info, user_name)
return final_results
@@ -127,6 +135,8 @@ def search(
search_filter: dict | None = None,
search_priority: dict | None = None,
user_name: str | None = None,
+ search_tool_memory: bool = False,
+ tool_mem_top_k: int = 6,
**kwargs,
) -> list[TextualMemoryItem]:
"""
@@ -171,6 +181,8 @@ def search(
search_filter=search_filter,
search_priority=search_priority,
user_name=user_name,
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
)
final_results = self.post_retrieve(
@@ -179,6 +191,8 @@ def search(
user_name=user_name,
info=None,
plugin=kwargs.get("plugin", False),
+ search_tool_memory=search_tool_memory,
+ tool_mem_top_k=tool_mem_top_k,
)
logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
@@ -272,6 +286,8 @@ def _retrieve_paths(
search_filter: dict | None = None,
search_priority: dict | None = None,
user_name: str | None = None,
+ search_tool_memory: bool = False,
+ tool_mem_top_k: int = 6,
):
"""Run A/B/C retrieval paths in parallel"""
tasks = []
@@ -324,6 +340,22 @@ def _retrieve_paths(
user_name,
)
)
+ if search_tool_memory:
+ tasks.append(
+ executor.submit(
+ self._retrieve_from_tool_memory,
+ query,
+ parsed_goal,
+ query_embedding,
+ top_k,
+ memory_type,
+ search_filter,
+ search_priority,
+ user_name,
+ id_filter,
+ mode=mode,
+ )
+ )
results = []
for t in tasks:
@@ -498,6 +530,98 @@ def _retrieve_from_internet(
parsed_goal=parsed_goal,
)
+ # --- Path D: tool memory retrieval
+ @timed
+ def _retrieve_from_tool_memory(
+ self,
+ query,
+ parsed_goal,
+ query_embedding,
+ top_k,
+ memory_type,
+ search_filter: dict | None = None,
+ search_priority: dict | None = None,
+ user_name: str | None = None,
+ id_filter: dict | None = None,
+ mode: str = "fast",
+ ):
+ """Retrieve and rerank from ToolMemory"""
+ results = {
+ "ToolSchemaMemory": [],
+ "ToolTrajectoryMemory": [],
+ }
+ tasks = []
+
+ # chain-of-thought query expansion
+ cot_embeddings = []
+ if self.vec_cot:
+ queries = self._cot_query(query, mode=mode, context=parsed_goal.context)
+ if len(queries) > 1:
+ cot_embeddings = self.embedder.embed(queries)
+ cot_embeddings.extend(query_embedding)
+ else:
+ cot_embeddings = query_embedding
+
+ with ContextThreadPoolExecutor(max_workers=2) as executor:
+ if memory_type in ["All", "ToolSchemaMemory"]:
+ tasks.append(
+ executor.submit(
+ self.graph_retriever.retrieve,
+ query=query,
+ parsed_goal=parsed_goal,
+ query_embedding=cot_embeddings,
+ top_k=top_k * 2,
+ memory_scope="ToolSchemaMemory",
+ search_filter=search_filter,
+ search_priority=search_priority,
+ user_name=user_name,
+ id_filter=id_filter,
+ use_fast_graph=self.use_fast_graph,
+ )
+ )
+ if memory_type in ["All", "ToolTrajectoryMemory"]:
+ tasks.append(
+ executor.submit(
+ self.graph_retriever.retrieve,
+ query=query,
+ parsed_goal=parsed_goal,
+ query_embedding=cot_embeddings,
+ top_k=top_k * 2,
+ memory_scope="ToolTrajectoryMemory",
+ search_filter=search_filter,
+ search_priority=search_priority,
+ user_name=user_name,
+ id_filter=id_filter,
+ use_fast_graph=self.use_fast_graph,
+ )
+ )
+
+ # Collect results from all tasks
+ for task in tasks:
+ rsp = task.result()
+ if rsp and rsp[0].metadata.memory_type == "ToolSchemaMemory":
+ results["ToolSchemaMemory"].extend(rsp)
+ elif rsp and rsp[0].metadata.memory_type == "ToolTrajectoryMemory":
+ results["ToolTrajectoryMemory"].extend(rsp)
+
+ schema_reranked = self.reranker.rerank(
+ query=query,
+ query_embedding=query_embedding[0],
+ graph_results=results["ToolSchemaMemory"],
+ top_k=top_k,
+ parsed_goal=parsed_goal,
+ search_filter=search_filter,
+ )
+ trajectory_reranked = self.reranker.rerank(
+ query=query,
+ query_embedding=query_embedding[0],
+ graph_results=results["ToolTrajectoryMemory"],
+ top_k=top_k,
+ parsed_goal=parsed_goal,
+ search_filter=search_filter,
+ )
+ return schema_reranked + trajectory_reranked
+
@timed
def _retrieve_simple(
self,
@@ -554,11 +678,41 @@ def _deduplicate_results(self, results):
return list(deduped.values())
@timed
- def _sort_and_trim(self, results, top_k, plugin=False):
+ def _sort_and_trim(
+ self, results, top_k, plugin=False, search_tool_memory=False, tool_mem_top_k=6
+ ):
"""Sort results by score and trim to top_k"""
+ final_items = []
+ if search_tool_memory:
+ tool_results = [
+ (item, score)
+ for item, score in results
+ if item.metadata.memory_type in ["ToolSchemaMemory", "ToolTrajectoryMemory"]
+ ]
+ sorted_tool_results = sorted(tool_results, key=lambda pair: pair[1], reverse=True)[
+ :tool_mem_top_k
+ ]
+ for item, score in sorted_tool_results:
+ if plugin and round(score, 2) == 0.00:
+ continue
+ meta_data = item.metadata.model_dump()
+ meta_data["relativity"] = score
+ final_items.append(
+ TextualMemoryItem(
+ id=item.id,
+ memory=item.memory,
+ metadata=SearchedTreeNodeTextualMemoryMetadata(**meta_data),
+ )
+ )
+ # separate textual results
+ results = [
+ (item, score)
+ for item, score in results
+ if item.metadata.memory_type not in ["ToolSchemaMemory", "ToolTrajectoryMemory"]
+ ]
sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)[:top_k]
- final_items = []
+
for item, score in sorted_results:
if plugin and round(score, 2) == 0.00:
continue
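
A toy run of the new trimming policy: tool memories compete only for the `tool_mem_top_k` budget and textual memories only for `top_k`, so neither class can crowd out the other:

```python
# Fabricated (name, memory_type, score) triples; budgets of 1 each.
scored = [
    ("fact-a", "LongTermMemory", 0.9),
    ("schema-x", "ToolSchemaMemory", 0.8),
    ("fact-b", "UserMemory", 0.7),
    ("traj-y", "ToolTrajectoryMemory", 0.6),
]
TOOL_TYPES = {"ToolSchemaMemory", "ToolTrajectoryMemory"}
tool = sorted((r for r in scored if r[1] in TOOL_TYPES), key=lambda r: r[2], reverse=True)[:1]
text = sorted((r for r in scored if r[1] not in TOOL_TYPES), key=lambda r: r[2], reverse=True)[:1]
print([name for name, *_ in tool + text])  # ['schema-x', 'fact-a']
```
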
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index b5bd34417..1ddd2b1b7 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -11,6 +11,7 @@
from memos.api.handlers.formatters_handler import (
format_memory_item,
post_process_pref_mem,
+ post_process_textual_mem,
)
from memos.context.context import ContextThreadPoolExecutor
from memos.log import get_logger
@@ -109,6 +110,7 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]:
"para_mem": [],
"pref_mem": [],
"pref_note": "",
+ "tool_mem": [],
}
# Determine search mode
@@ -123,11 +125,10 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]:
pref_formatted_memories = pref_future.result()
# Build result
- memories_result["text_mem"].append(
- {
- "cube_id": self.cube_id,
- "memories": text_formatted_memories,
- }
+ memories_result = post_process_textual_mem(
+ memories_result,
+ text_formatted_memories,
+ self.cube_id,
)
memories_result = post_process_pref_mem(
@@ -278,6 +279,8 @@ def _fine_search(
Returns:
List of enhanced search results
"""
+ # TODO: support tool memory search in fine mode in the future
+
logger.info(f"Fine strategy: {FINE_STRATEGY}")
if FINE_STRATEGY == FineStrategy.DEEP_SEARCH:
return self._deep_search(search_req=search_req, user_context=user_context)
@@ -375,6 +378,9 @@ def _search_pref(
"""
if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
return []
+ if not search_req.include_preference:
+ return []
+
logger.info(f"search_req.filter for preference memory: {search_req.filter}")
logger.info(f"type of pref_mem: {type(self.naive_mem_cube.pref_mem)}")
try:
@@ -427,6 +433,8 @@ def _fast_search(
"chat_history": search_req.chat_history,
},
plugin=plugin,
+ search_tool_memory=search_req.search_tool_memory,
+ tool_mem_top_k=search_req.tool_mem_top_k,
)
formatted_memories = [format_memory_item(data) for data in search_results]
@@ -543,6 +551,13 @@ def _process_pref_mem(
if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
return []
+ if add_req.messages is None or isinstance(add_req.messages, str):
+ return []
+
+ for message in add_req.messages:
+ if message.get("role", None) is None:
+ return []
+
target_session_id = add_req.session_id or "default_session"
if sync_mode == "async":
diff --git a/src/memos/templates/tool_mem_prompts.py b/src/memos/templates/tool_mem_prompts.py
new file mode 100644
index 000000000..7d5363956
--- /dev/null
+++ b/src/memos/templates/tool_mem_prompts.py
@@ -0,0 +1,84 @@
+TOOL_TRAJECTORY_PROMPT_ZH = """
+你是一个专业的工具调用轨迹提取专家。你的任务是从给定的对话消息中提取完整的工具调用轨迹经验。
+
+## 提取规则:
+1. 只有当对话中存在有价值的工具调用过程时才进行提取
+2. 有价值的轨迹至少包含以下元素:
+ - 用户的问题(user message)
+ - 助手的工具调用尝试(assistant message with tool_calls)
+ - 工具的执行结果(tool message with tool_call_id and content,无论成功或失败)
+ - 助手的响应(assistant message,无论是否给出最终答案)
+
+## 输出格式:
+返回一个JSON数组,格式如下:
+```json
+[
+ {
+ "trajectory": "自然语言输出包含'任务、使用的工具、工具观察、最终回答'的完整精炼的总结,体现顺序",
+ "tool_used_status": [
+ {
+ "used_tool": "工具名1",
+ "success_rate": "0.0-1.0之间的数值,表示该工具在本次轨迹中的成功率",
+ "error_type": "调用失败时的错误类型和描述,成功时为空字符串",
+ "experience": "该工具的使用经验,比如常见的参数模式、执行特点、结果解读方式等"
+ }
+ ]
+ }
+]
+```
+
+## 注意事项:
+- 如果对话中没有完整的工具调用轨迹,返回空数组
+- 每个轨迹必须是独立的完整过程
+- 一个轨迹中可能涉及多个工具的使用,每个工具在tool_used_status中独立记录
+- 只提取事实内容,不要添加任何解释或额外信息
+- 确保返回的是有效的JSON格式
+
+请分析以下对话消息并提取工具调用轨迹:
+
+{messages}
+
+"""
+
+
+TOOL_TRAJECTORY_PROMPT_EN = """
+You are a professional tool call trajectory extraction expert. Your task is to extract valuable tool call trajectory experiences from given conversation messages.
+
+## Extraction Rules:
+1. Only extract when there are valuable tool calling processes in the conversation
+2. Valuable trajectories must contain at least the following elements:
+ - User's question (user message)
+ - Assistant's tool call attempt (assistant message with tool_calls)
+ - Tool execution results (tool message with tool_call_id and content, regardless of success or failure)
+ - Assistant's response (assistant message, whether or not a final answer is given)
+
+## Output Format:
+Return a JSON array in the following format:
+```json
+[
+ {
+ "trajectory": "Natural language summary containing 'task, tools used, tool observations, final answer' in a complete and refined manner, reflecting the sequence",
+ "tool_used_status": [
+ {
+ "used_tool": "Tool Name 1",
+ "success_rate": "Numerical value between 0.0-1.0, indicating the success rate of this tool in the current trajectory",
+ "error_type": "Error type and description when call fails, empty string when successful",
+ "experience": "Usage experience of this tool, such as common parameter patterns, execution characteristics, result interpretation methods, etc."
+ }
+ ]
+ }
+]
+```
+
+## Notes:
+- If there are no complete tool call trajectories in the conversation, return an empty array
+- Each trajectory must be an independent complete process
+- Multiple tools may be used in one trajectory, each tool is recorded independently in tool_used_status
+- Only extract factual content, do not add any additional explanations or information
+- Ensure the returned content is valid JSON format
+
+Please analyze the following conversation messages and extract tool call trajectories:
+
+{messages}
+
+"""
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py
index a742de3a9..3c5638788 100644
--- a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py
@@ -2,7 +2,6 @@
from __future__ import annotations
-from collections.abc import Iterable
from typing import Literal, TypeAlias
from typing_extensions import Required, TypedDict
@@ -35,7 +34,7 @@ class ChatCompletionAssistantMessageParam(TypedDict, total=False):
[Learn more](https://platform.openai.com/docs/guides/audio).
"""
- content: str | Iterable[ContentArrayOfContentPart] | None
+ content: str | list[ContentArrayOfContentPart] | ContentArrayOfContentPart | None
"""The contents of the assistant message.
Required unless `tool_calls` or `function_call` is specified.
@@ -44,7 +43,9 @@ class ChatCompletionAssistantMessageParam(TypedDict, total=False):
refusal: str | None
"""The refusal message by the assistant."""
- tool_calls: Iterable[ChatCompletionMessageToolCallUnionParam]
+ tool_calls: (
+ list[ChatCompletionMessageToolCallUnionParam] | ChatCompletionMessageToolCallUnionParam
+ )
"""The tool calls generated by the model, such as function calls."""
chat_time: str | None
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py
index 7faa90e2e..ea2101229 100644
--- a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py
@@ -2,7 +2,6 @@
from __future__ import annotations
-from collections.abc import Iterable
from typing import Literal
from typing_extensions import Required, TypedDict
@@ -14,7 +13,9 @@
class ChatCompletionSystemMessageParam(TypedDict, total=False):
- content: Required[str | Iterable[ChatCompletionContentPartTextParam]]
+ content: Required[
+ str | list[ChatCompletionContentPartTextParam] | ChatCompletionContentPartTextParam
+ ]
"""The contents of the system message."""
role: Required[Literal["system"]]
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py
index c03220915..99c845d11 100644
--- a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py
@@ -2,7 +2,6 @@
from __future__ import annotations
-from collections.abc import Iterable
from typing import Literal
from typing_extensions import Required, TypedDict
@@ -14,7 +13,7 @@
class ChatCompletionToolMessageParam(TypedDict, total=False):
- content: Required[str | Iterable[ChatCompletionContentPartParam]]
+ content: Required[str | list[ChatCompletionContentPartParam] | ChatCompletionContentPartParam]
"""The contents of the tool message."""
role: Required[Literal["tool"]]
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py
index 2c2a1f23f..8c004f340 100644
--- a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py
@@ -2,7 +2,6 @@
from __future__ import annotations
-from collections.abc import Iterable
from typing import Literal
from typing_extensions import Required, TypedDict
@@ -14,7 +13,7 @@
class ChatCompletionUserMessageParam(TypedDict, total=False):
- content: Required[str | Iterable[ChatCompletionContentPartParam]]
+ content: Required[str | list[ChatCompletionContentPartParam] | ChatCompletionContentPartParam]
"""The contents of the user message."""
role: Required[Literal["user"]]