diff --git a/docker/requirements.txt b/docker/requirements.txt
index d3268edae..21f246599 100644
--- a/docker/requirements.txt
+++ b/docker/requirements.txt
@@ -160,3 +160,4 @@ xlrd==2.0.2
 xlsxwriter==3.2.5
 prometheus-client==0.23.1
 pymilvus==2.5.12
+langchain-text-splitters==1.0.0
diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py
index fe6b600b8..e9bb2e499 100644
--- a/src/memos/api/handlers/chat_handler.py
+++ b/src/memos/api/handlers/chat_handler.py
@@ -142,7 +142,9 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
 
         # Step 2: Build system prompt
         system_prompt = self._build_system_prompt(
-            filtered_memories, search_response.data["pref_string"], chat_req.system_prompt
+            filtered_memories,
+            search_response.data.get("pref_string", ""),
+            chat_req.system_prompt,
         )
 
         # Prepare message history
@@ -257,7 +259,7 @@ def generate_chat_response() -> Generator[str, None, None]:
             # Step 2: Build system prompt with memories
             system_prompt = self._build_system_prompt(
                 filtered_memories,
-                search_response.data["pref_string"],
+                search_response.data.get("pref_string", ""),
                 chat_req.system_prompt,
             )
 
@@ -449,7 +451,7 @@ def generate_chat_response() -> Generator[str, None, None]:
 
             # Step 2: Build system prompt with memories
             system_prompt = self._build_enhance_system_prompt(
-                filtered_memories, search_response.data["pref_string"]
+                filtered_memories, search_response.data.get("pref_string", "")
             )
 
             # Prepare messages
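Reviewer note: the switch from indexing to `.get("pref_string", "")` makes prompt building tolerant of search responses that carry no preference data. A minimal standalone sketch of the behavior difference (the `data` dict is illustrative):

```python
# Sketch: why .get(...) matters when "pref_string" may be absent.
data = {"text_mem": []}  # a search response without preference data

# Old behavior: data["pref_string"] raises KeyError and aborts the request.
# New behavior: falls back to an empty preference section.
pref = data.get("pref_string", "")
assert pref == ""
```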
+ """ + fact_mem = [ + mem + for mem in text_formatted_mem + if mem["metadata"]["memory_type"] not in ["ToolSchemaMemory", "ToolTrajectoryMemory"] + ] + tool_mem = [ + mem + for mem in text_formatted_mem + if mem["metadata"]["memory_type"] in ["ToolSchemaMemory", "ToolTrajectoryMemory"] + ] + + memories_result["text_mem"].append( + { + "cube_id": mem_cube_id, + "memories": fact_mem, + } + ) + memories_result["tool_mem"].append( + { + "cube_id": mem_cube_id, + "memories": tool_mem, + } + ) + return memories_result diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index ffe736aa3..4ad4016bc 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -6,7 +6,7 @@ # Import message types from core types module from memos.log import get_logger -from memos.types import MessageDict, PermissionDict, SearchMode +from memos.types import MessageList, MessagesType, PermissionDict, SearchMode logger = get_logger(__name__) @@ -56,7 +56,7 @@ class Message(BaseModel): class MemoryCreate(BaseRequest): user_id: str = Field(..., description="User ID") - messages: list | None = Field(None, description="List of messages to store.") + messages: MessageList | None = Field(None, description="List of messages to store.") memory_content: str | None = Field(None, description="Content to store as memory") doc_path: str | None = Field(None, description="Path to document to store") mem_cube_id: str | None = Field(None, description="ID of the memory cube") @@ -83,7 +83,7 @@ class ChatRequest(BaseRequest): writable_cube_ids: list[str] | None = Field( None, description="List of cube IDs user can write for multi-cube chat" ) - history: list | None = Field(None, description="Chat history") + history: MessageList | None = Field(None, description="Chat history") mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") system_prompt: str | None = Field(None, description="Base system prompt to use for chat") top_k: int = Field(10, description="Number of results to return") @@ -165,7 +165,7 @@ class ChatCompleteRequest(BaseRequest): user_id: str = Field(..., description="User ID") query: str = Field(..., description="Chat query message") mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") - history: list | None = Field(None, description="Chat history") + history: MessageList | None = Field(None, description="Chat history") internet_search: bool = Field(False, description="Whether to use internet search") system_prompt: str | None = Field(None, description="Base prompt to use for chat") top_k: int = Field(10, description="Number of results to return") @@ -251,7 +251,7 @@ class MemoryCreateRequest(BaseRequest): """Request model for creating memories.""" user_id: str = Field(..., description="User ID") - messages: str | list | None = Field(None, description="List of messages to store.") + messages: str | MessagesType | None = Field(None, description="List of messages to store.") memory_content: str | None = Field(None, description="Memory content to store") doc_path: str | None = Field(None, description="Path to document to store") mem_cube_id: str | None = Field(None, description="Cube ID") @@ -326,6 +326,21 @@ class APISearchRequest(BaseRequest): ), ) + search_tool_memory: bool = Field( + True, + description=( + "Whether to retrieve tool memories along with general memories. " + "If enabled, the system will automatically recall tool memories " + "relevant to the query. Default: True." 
+ ), + ) + + tool_mem_top_k: int = Field( + 6, + ge=0, + description="Number of tool memories to retrieve (top-K). Default: 6.", + ) + # ==== Filter conditions ==== # TODO: maybe add detailed description later filter: dict[str, Any] | None = Field( @@ -360,7 +375,7 @@ class APISearchRequest(BaseRequest): ) # ==== Context ==== - chat_history: list | None = Field( + chat_history: MessageList | None = Field( None, description=( "Historical chat messages used internally by algorithms. " @@ -490,7 +505,7 @@ class APIADDRequest(BaseRequest): ) # ==== Input content ==== - messages: str | list | None = Field( + messages: MessagesType | None = Field( None, description=( "List of messages to store. Supports: " @@ -506,7 +521,7 @@ class APIADDRequest(BaseRequest): ) # ==== Chat history ==== - chat_history: list | None = Field( + chat_history: MessageList | None = Field( None, description=( "Historical chat messages used internally by algorithms. " @@ -636,7 +651,7 @@ class APIFeedbackRequest(BaseRequest): "default_session", description="Session ID for soft-filtering memories" ) task_id: str | None = Field(None, description="Task ID for monitering async tasks") - history: list[MessageDict] | None = Field(..., description="Chat history") + history: MessageList | None = Field(..., description="Chat history") retrieved_memory_ids: list[str] | None = Field( None, description="Retrieved memory ids at last turn" ) @@ -671,7 +686,7 @@ class APIChatCompleteRequest(BaseRequest): writable_cube_ids: list[str] | None = Field( None, description="List of cube IDs user can write for multi-cube chat" ) - history: list | None = Field(None, description="Chat history") + history: MessageList | None = Field(None, description="Chat history") mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") system_prompt: str | None = Field(None, description="Base system prompt to use for chat") top_k: int = Field(10, description="Number of results to return") @@ -740,7 +755,7 @@ class SuggestionRequest(BaseRequest): user_id: str = Field(..., description="User ID") mem_cube_id: str = Field(..., description="Cube ID") language: Literal["zh", "en"] = Field("zh", description="Language for suggestions") - message: list | None = Field(None, description="List of messages to store.") + message: MessagesType | None = Field(None, description="List of messages to store.") # ─── MemOS Client Response Models ────────────────────────────────────────────── diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py index 57774cf3a..e0aa40913 100644 --- a/src/memos/mem_reader/multi_modal_struct.py +++ b/src/memos/mem_reader/multi_modal_struct.py @@ -1,4 +1,5 @@ import concurrent.futures +import json import traceback from typing import Any @@ -7,8 +8,9 @@ from memos.configs.mem_reader import MultiModalStructMemReaderConfig from memos.context.context import ContextThreadPoolExecutor from memos.mem_reader.read_multi_modal import MultiModalParser -from memos.mem_reader.simple_struct import SimpleStructMemReader +from memos.mem_reader.simple_struct import SimpleStructMemReader, detect_lang from memos.memories.textual.item import TextualMemoryItem +from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH from memos.types import MessagesType from memos.utils import timed @@ -297,6 +299,61 @@ def _process_string_fine( return fine_memory_items + def _get_llm_tool_trajectory_response(self, mem_str: str) -> dict: + """ + Generete tool 
diff --git a/src/memos/mem_reader/multi_modal_struct.py b/src/memos/mem_reader/multi_modal_struct.py
index 57774cf3a..e0aa40913 100644
--- a/src/memos/mem_reader/multi_modal_struct.py
+++ b/src/memos/mem_reader/multi_modal_struct.py
@@ -1,4 +1,5 @@
 import concurrent.futures
+import json
 import traceback
 
 from typing import Any
@@ -7,8 +8,9 @@
 from memos.configs.mem_reader import MultiModalStructMemReaderConfig
 from memos.context.context import ContextThreadPoolExecutor
 from memos.mem_reader.read_multi_modal import MultiModalParser
-from memos.mem_reader.simple_struct import SimpleStructMemReader
+from memos.mem_reader.simple_struct import SimpleStructMemReader, detect_lang
 from memos.memories.textual.item import TextualMemoryItem
+from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
 from memos.types import MessagesType
 from memos.utils import timed
@@ -297,6 +299,61 @@ def _process_string_fine(
         return fine_memory_items
 
+    def _get_llm_tool_trajectory_response(self, mem_str: str) -> list:
+        """
+        Generate tool trajectory experience items via the LLM.
+        """
+        try:
+            lang = detect_lang(mem_str)
+            template = TOOL_TRAJECTORY_PROMPT_ZH if lang == "zh" else TOOL_TRAJECTORY_PROMPT_EN
+            prompt = template.replace("{messages}", mem_str)
+            rsp = self.llm.generate([{"role": "user", "content": prompt}])
+            rsp = rsp.replace("```json", "").replace("```", "")
+            return json.loads(rsp)
+        except Exception as e:
+            logger.error(f"[MultiModalFine] Error calling LLM for tool trajectory: {e}")
+            return []
+
+    def _process_tool_trajectory_fine(
+        self,
+        fast_memory_items: list[TextualMemoryItem],
+        info: dict[str, Any],
+    ) -> list[TextualMemoryItem]:
+        """
+        Process tool trajectory memory items through LLM to generate fine mode memories.
+        """
+        if not fast_memory_items:
+            return []
+
+        fine_memory_items = []
+
+        for fast_item in fast_memory_items:
+            # Extract memory text (string content)
+            mem_str = fast_item.memory or ""
+            if not mem_str.strip() or "tool:" not in mem_str:
+                continue
+            try:
+                resp = self._get_llm_tool_trajectory_response(mem_str)
+            except Exception as e:
+                logger.error(f"[MultiModalFine] Error calling LLM for tool trajectory: {e}")
+                continue
+            for m in resp:
+                try:
+                    # Normalize memory_type (same as simple_struct)
+                    memory_type = "ToolTrajectoryMemory"
+
+                    node = self._make_memory_item(
+                        value=m.get("trajectory", ""),
+                        info=info,
+                        memory_type=memory_type,
+                        tool_used_status=m.get("tool_used_status", []),
+                    )
+                    fine_memory_items.append(node)
+                except Exception as e:
+                    logger.error(f"[MultiModalFine] parse error for tool trajectory: {e}")
+
+        return fine_memory_items
+
     @timed
     def _process_multi_modal_data(
         self, scene_data_info: MessagesType, info, mode: str = "fine", **kwargs
@@ -339,6 +396,11 @@ def _process_multi_modal_data(
             )
             fine_memory_items.extend(fine_memory_items_string_parser)
 
+            fine_memory_items_tool_trajectory_parser = self._process_tool_trajectory_fine(
+                fast_memory_items, info
+            )
+            fine_memory_items.extend(fine_memory_items_tool_trajectory_parser)
+
             # Part B: get fine multimodal items
             for fast_item in fast_memory_items:
                 sources = fast_item.metadata.sources
@@ -377,6 +439,12 @@ def _process_transfer_multi_modal_data(
         # Part A: call llm
         fine_memory_items_string_parser = self._process_string_fine([raw_node], info, custom_tags)
         fine_memory_items.extend(fine_memory_items_string_parser)
+
+        fine_memory_items_tool_trajectory_parser = self._process_tool_trajectory_fine(
+            [raw_node], info
+        )
+        fine_memory_items.extend(fine_memory_items_tool_trajectory_parser)
+
         # Part B: get fine multimodal items
         for source in sources:
             items = self.multi_modal_parser.process_transfer(
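For orientation, this is the JSON shape `_process_tool_trajectory_fine` expects back from the LLM; it mirrors the output format of the `TOOL_TRAJECTORY_PROMPT_*` templates added later in this diff (field values are illustrative):

```python
# Sketch: one extracted trajectory with per-tool usage status.
resp = [
    {
        "trajectory": "User asked for weather; assistant called get_weather; ...",
        "tool_used_status": [
            {
                "used_tool": "get_weather",
                "success_rate": "1.0",
                "error_type": "",
                "experience": "Pass city names in English for best results.",
            }
        ],
    }
]
for m in resp:
    print(m.get("trajectory", ""), m.get("tool_used_status", []))
```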
diff --git a/src/memos/mem_reader/read_multi_modal/system_parser.py b/src/memos/mem_reader/read_multi_modal/system_parser.py
index d2a6611af..3f467d649 100644
--- a/src/memos/mem_reader/read_multi_modal/system_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/system_parser.py
@@ -1,5 +1,9 @@
 """Parser for system messages."""
 
+import json
+import re
+import uuid
+
 from typing import Any
 
 from memos.embedders.base import BaseEmbedder
@@ -12,7 +16,7 @@
 )
 from memos.types.openai_chat_completion_types import ChatCompletionSystemMessageParam
 
-from .base import BaseMessageParser, _derive_key, _extract_text_from_content
+from .base import BaseMessageParser
 
 logger = get_logger(__name__)
@@ -35,63 +39,42 @@ def create_source(
         self,
         message: ChatCompletionSystemMessageParam,
         info: dict[str, Any],
-    ) -> SourceMessage | list[SourceMessage]:
-        """
-        Create SourceMessage(s) from system message.
-
-        For multimodal messages (content is a list of text parts), creates one SourceMessage per part.
-        For simple messages (content is str), creates a single SourceMessage.
-        """
-        if not isinstance(message, dict):
-            return []
-
-        role = message.get("role", "system")
-        raw_content = message.get("content", "")
-        chat_time = message.get("chat_time")
-        message_id = message.get("message_id")
-
-        sources = []
-
-        if isinstance(raw_content, list):
-            # Multimodal: create one SourceMessage per text part
-            for part in raw_content:
-                if isinstance(part, dict):
-                    part_type = part.get("type", "")
-                    if part_type == "text":
-                        sources.append(
-                            SourceMessage(
-                                type="chat",
-                                role=role,
-                                chat_time=chat_time,
-                                message_id=message_id,
-                                content=part.get("text", ""),
-                            )
-                        )
-        else:
-            # Simple message: single SourceMessage
-            content = _extract_text_from_content(raw_content)
-            if content:
-                sources.append(
-                    SourceMessage(
-                        type="chat",
-                        role=role,
-                        chat_time=chat_time,
-                        message_id=message_id,
-                        content=content,
-                    )
-                )
-
-        return (
-            sources
-            if len(sources) > 1
-            else (sources[0] if sources else SourceMessage(type="chat", role=role))
-        )
+    ) -> SourceMessage:
+        """Create SourceMessage from system message."""
+        content = message["content"]
+        if isinstance(content, dict):
+            content = content["text"]
+
+        content_wo_tool_schema = re.sub(
+            r"<tool_schema>(.*?)</tool_schema>",
+            r"omitted",
+            content,
+            flags=re.DOTALL,
+        )
+        tool_schema_match = re.search(r"<tool_schema>(.*?)</tool_schema>", content, re.DOTALL)
+        tool_schema_content = tool_schema_match.group(1) if tool_schema_match else ""
+
+        return SourceMessage(
+            type="chat",
+            role="system",
+            chat_time=message.get("chat_time", None),
+            message_id=message.get("message_id", None),
+            content=content_wo_tool_schema,
+            tool_schema=tool_schema_content,
+        )
 
     def rebuild_from_source(
         self,
         source: SourceMessage,
     ) -> ChatCompletionSystemMessageParam:
-        """We only need rebuild from specific multimodal source"""
+        """Rebuild system message from SourceMessage."""
+        # Only rebuild the tool schema content; the full content is used in the chat context by the LLM.
+        return {
+            "role": "system",
+            "content": source.tool_schema or "",
+            "chat_time": source.chat_time,
+            "message_id": source.message_id,
+        }
 
     def parse_fast(
         self,
@@ -99,59 +82,47 @@
         message: ChatCompletionSystemMessageParam,
         info: dict[str, Any],
         **kwargs,
     ) -> list[TextualMemoryItem]:
-        if not isinstance(message, dict):
-            logger.warning(f"[SystemParser] Expected dict, got {type(message)}")
-            return []
-
-        role = message.get("role", "")
-        raw_content = message.get("content", "")
-        chat_time = message.get("chat_time", None)
-        content = _extract_text_from_content(raw_content)
-        if role != "system":
-            logger.warning(f"[SystemParser] Expected role is `system`, got {role}")
-            return []
-        parts = [f"{role}: "]
-        if chat_time:
-            parts.append(f"[{chat_time}]: ")
-        prefix = "".join(parts)
-        line = f"{prefix}{content}\n"
-        if not line:
-            return []
-        memory_type = "LongTermMemory"
+        content = message["content"]
+        if isinstance(content, dict):
+            content = content["text"]
+
+        # Replace tool_schema content with "omitted" in remaining content
+        content_wo_tool_schema = re.sub(
+            r"<tool_schema>(.*?)</tool_schema>",
+            r"omitted",
+            content,
+            flags=re.DOTALL,
+        )
 
-        # Create source(s) using parser's create_source method
-        sources = self.create_source(message, info)
-        if isinstance(sources, SourceMessage):
-            sources = [sources]
-        elif not sources:
-            return []
+        source = self.create_source(message, info)
 
         # Extract info fields
         info_ = info.copy()
         user_id = info_.pop("user_id", "")
         session_id = info_.pop("session_id", "")
 
-        # Create memory item (equivalent to _make_memory_item)
-        memory_item = TextualMemoryItem(
-            memory=line,
-            metadata=TreeNodeTextualMemoryMetadata(
-                user_id=user_id,
-                session_id=session_id,
-                memory_type=memory_type,
-                status="activated",
-                tags=["mode:fast"],
-                key=_derive_key(line),
-                embedding=self.embedder.embed([line])[0],
-                usage=[],
-                sources=sources,
-                background="",
-                confidence=0.99,
-                type="fact",
-                info=info_,
-            ),
-        )
-
-        return [memory_item]
+        # Split parsed text into chunks
+        content_chunks = self._split_text(content_wo_tool_schema)
+
+        memory_items = []
+        for _chunk_idx, chunk_text in enumerate(content_chunks):
+            if not chunk_text.strip():
+                continue
+
+            memory_item = TextualMemoryItem(
+                memory=chunk_text,
+                metadata=TreeNodeTextualMemoryMetadata(
+                    user_id=user_id,
+                    session_id=session_id,
+                    memory_type="LongTermMemory",  # only choose LongTermMemory for system messages as a placeholder
+                    status="activated",
+                    tags=["mode:fast"],
+                    sources=[source],
+                    info=info_,
+                ),
+            )
+            memory_items.append(memory_item)
+        return memory_items
 
     def parse_fine(
         self,
@@ -159,4 +130,35 @@
         message: ChatCompletionSystemMessageParam,
         info: dict[str, Any],
         **kwargs,
     ) -> list[TextualMemoryItem]:
-        return []
+        content = message["content"]
+        if isinstance(content, dict):
+            content = content["text"]
+        try:
+            tool_schema = json.loads(content)
+            assert isinstance(tool_schema, list), "Tool schema must be a list[dict]"
+        except json.JSONDecodeError:
+            logger.warning(f"[SystemParser] Failed to parse tool schema: {content}")
+            return []
+        except AssertionError:
+            logger.warning(f"[SystemParser] Tool schema must be a list[dict]: {content}")
+            return []
+
+        info_ = info.copy()
+        user_id = info_.pop("user_id", "")
+        session_id = info_.pop("session_id", "")
+
+        return [
+            TextualMemoryItem(
+                id=str(uuid.uuid4()),
+                memory=json.dumps(schema),
+                metadata=TreeNodeTextualMemoryMetadata(
+                    user_id=user_id,
+                    session_id=session_id,
+                    memory_type="ToolSchemaMemory",
+                    status="activated",
+                    embedding=self.embedder.embed([json.dumps(schema)])[0],
+                    info=info_,
+                ),
+            )
+            for schema in tool_schema
+        ]
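Note on the regex above: the tag name `<tool_schema>` is reconstructed here, since the angle-bracket markers were garbled in the original text; the variable names (`content_wo_tool_schema`, `tool_schema_match`) make the intent clear even if the exact delimiter differs. Separately, `parse_fine` now expects the (rebuilt) system content to be a JSON list of tool schemas, each becoming one `ToolSchemaMemory` item. A self-contained sketch of that input shape:

```python
import json

# Sketch: a valid parse_fine input — a JSON list of tool schemas.
content = json.dumps(
    [
        {
            "name": "get_weather",
            "description": "Query current weather for a city",
            "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
        }
    ]
)
schemas = json.loads(content)
assert isinstance(schemas, list)  # mirrors the parser's list[dict] assertion
```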
diff --git a/src/memos/mem_reader/read_multi_modal/tool_parser.py b/src/memos/mem_reader/read_multi_modal/tool_parser.py
index 7a11d931a..09bd9e9d0 100644
--- a/src/memos/mem_reader/read_multi_modal/tool_parser.py
+++ b/src/memos/mem_reader/read_multi_modal/tool_parser.py
@@ -1,14 +1,20 @@
 """Parser for tool messages."""
 
+import json
+
 from typing import Any
 
 from memos.embedders.base import BaseEmbedder
 from memos.llms.base import BaseLLM
 from memos.log import get_logger
-from memos.memories.textual.item import SourceMessage, TextualMemoryItem
+from memos.memories.textual.item import (
+    SourceMessage,
+    TextualMemoryItem,
+    TreeNodeTextualMemoryMetadata,
+)
 from memos.types.openai_chat_completion_types import ChatCompletionToolMessageParam
 
-from .base import BaseMessageParser, _extract_text_from_content
+from .base import BaseMessageParser
 
 logger = get_logger(__name__)
@@ -29,190 +35,155 @@ def __init__(self, embedder: BaseEmbedder, llm: BaseLLM | None = None):
 
     def create_source(
         self,
-        message: ChatCompletionToolMessageParam | dict[str, Any],
+        message: ChatCompletionToolMessageParam,
         info: dict[str, Any],
-    ) -> SourceMessage:
-        """Create SourceMessage from tool message or custom tool format."""
+    ) -> SourceMessage | list[SourceMessage]:
+        """Create SourceMessage from tool message."""
+
         if not isinstance(message, dict):
-            return SourceMessage(type="chat", role="tool")
-
-        # Handle custom tool formats (tool_description, tool_input, tool_output)
-        msg_type = message.get("type", "")
-        if msg_type == "tool_description":
-            name = message.get("name", "")
-            description = message.get("description", "")
-            parameters = message.get("parameters", {})
-            content = f"[tool_description] name={name}, description={description}, parameters={parameters}"
-            return SourceMessage(
-                type="tool_description",
-                content=content,
-                original_part=message,
-            )
-        elif msg_type == "tool_input":
-            call_id = message.get("call_id", "")
-            name = message.get("name", "")
-            argument = message.get("argument", {})
-            content = f"[tool_input] call_id={call_id}, name={name}, argument={argument}"
-            return SourceMessage(
-                type="tool_input",
-                content=content,
-                message_id=call_id,
-                original_part=message,
-            )
-        elif msg_type == "tool_output":
-            call_id = message.get("call_id", "")
-            name = message.get("name", "")
-            output = message.get("output", {})
-            content = f"[tool_output] call_id={call_id}, name={name}, output={output}"
-            return SourceMessage(
-                type="tool_output",
-                content=content,
-                message_id=call_id,
-                original_part=message,
-            )
+            return []
 
-        # Handle standard tool message
-        content = _extract_text_from_content(message.get("content", ""))
-        return SourceMessage(
-            type="tool",
-            role="tool",
-            chat_time=message.get("chat_time"),
-            message_id=message.get("message_id"),
-            content=content,
-        )
+        role = message.get("role", "tool")
+        raw_content = message.get("content", "")
+        tool_call_id = message.get("tool_call_id", "")
+        chat_time = message.get("chat_time")
+        message_id = message.get("message_id")
+
+        sources = []
+
+        if isinstance(raw_content, list):
+            # Multimodal: create one SourceMessage per part
+            for part in raw_content:
+                if isinstance(part, dict):
+                    part_type = part.get("type", "")
+                    if part_type == "text":
+                        sources.append(
+                            SourceMessage(
+                                type="text",
+                                role=role,
+                                chat_time=chat_time,
+                                message_id=message_id,
+                                content=part.get("text", ""),
+                                tool_call_id=tool_call_id,
+                            )
+                        )
+                    elif part_type == "file":
+                        file_info = part.get("file", {})
+                        sources.append(
+                            SourceMessage(
+                                type="file",
+                                role=role,
+                                chat_time=chat_time,
+                                message_id=message_id,
+                                content=file_info.get("file_data", ""),
+                                filename=file_info.get("filename", ""),
+                                file_id=file_info.get("file_id", ""),
+                                tool_call_id=tool_call_id,
+                                original_part=part,
+                            )
+                        )
+                    elif part_type == "image_url":
+                        file_info = part.get("image_url", {})
+                        sources.append(
+                            SourceMessage(
+                                type="image_url",
+                                role=role,
+                                chat_time=chat_time,
+                                message_id=message_id,
+                                content=file_info.get("url", ""),
+                                detail=file_info.get("detail", "auto"),
+                                tool_call_id=tool_call_id,
+                                original_part=part,
+                            )
+                        )
+                    elif part_type == "input_audio":
+                        file_info = part.get("input_audio", {})
+                        sources.append(
+                            SourceMessage(
+                                type="input_audio",
+                                role=role,
+                                chat_time=chat_time,
+                                message_id=message_id,
+                                content=file_info.get("data", ""),
+                                format=file_info.get("format", "wav"),
+                                tool_call_id=tool_call_id,
+                                original_part=part,
+                            )
+                        )
+                    else:
+                        logger.warning(f"[ToolParser] Unsupported part type: {part_type}")
+                        continue
+        else:
+            # Simple string content message: single SourceMessage
+            if raw_content:
+                sources.append(
+                    SourceMessage(
+                        type="chat",
+                        role=role,
+                        chat_time=chat_time,
+                        message_id=message_id,
+                        content=raw_content,
+                        tool_call_id=tool_call_id,
+                    )
+                )
+
+        return sources
 
     def rebuild_from_source(
         self,
         source: SourceMessage,
     ) -> ChatCompletionToolMessageParam:
         """Rebuild tool message from SourceMessage."""
-        return {
-            "role": "tool",
-            "content": source.content or "",
-            "tool_call_id": source.message_id or "",  # tool_call_id might be in message_id
-            "chat_time": source.chat_time,
-            "message_id": source.message_id,
-        }
 
     def parse_fast(
         self,
-        message: ChatCompletionToolMessageParam | dict[str, Any],
+        message: ChatCompletionToolMessageParam,
         info: dict[str, Any],
         **kwargs,
     ) -> list[TextualMemoryItem]:
-        """
-        Parse tool message in fast mode.
-
-        Handles both standard tool messages and custom tool formats:
-        - Standard tool message: role="tool", content, tool_call_id
-        - Custom formats: tool_description, tool_input, tool_output
+        role = message.get("role", "")
+        content = message.get("content", "")
+        chat_time = message.get("chat_time", None)
 
-        Args:
-            message: Tool message to parse
-            info: Dictionary containing user_id and session_id
-            **kwargs: Additional parameters
-
-        Returns:
-            List of TextualMemoryItem objects
-        """
-        from memos.memories.textual.item import TreeNodeTextualMemoryMetadata
-
-        from .base import _derive_key
-
-        if not isinstance(message, dict):
-            logger.warning(f"[ToolParser] Expected dict, got {type(message)}")
-            return []
-
-        # Handle custom tool formats (tool_description, tool_input, tool_output)
-        msg_type = message.get("type", "")
-        if msg_type in ("tool_description", "tool_input", "tool_output"):
-            # Create source
-            source = self.create_source(message, info)
-            content = source.content or ""
-            if not content:
-                return []
-
-            # Extract info fields
-            info_ = info.copy()
-            user_id = info_.pop("user_id", "")
-            session_id = info_.pop("session_id", "")
-
-            # Create memory item
-            memory_item = TextualMemoryItem(
-                memory=content,
-                metadata=TreeNodeTextualMemoryMetadata(
-                    user_id=user_id,
-                    session_id=session_id,
-                    memory_type="LongTermMemory",
-                    status="activated",
-                    tags=["mode:fast"],
-                    key=_derive_key(content),
-                    embedding=self.embedder.embed([content])[0],
-                    usage=[],
-                    sources=[source],
-                    background="",
-                    confidence=0.99,
-                    type="fact",
-                    info=info_,
-                ),
-            )
-            return [memory_item]
-
-        # Handle standard tool message (role="tool")
-        role = message.get("role", "").strip().lower()
         if role != "tool":
-            logger.warning(f"[ToolParser] Expected role='tool', got role='{role}'")
+            logger.warning(f"[ToolParser] Expected role is `tool`, got {role}")
             return []
-
-        # Extract content from tool message
-        content = _extract_text_from_content(message.get("content", ""))
-        if not content:
-            return []
-
-        # Build formatted line similar to assistant_parser
-        tool_call_id = message.get("tool_call_id", "")
-        chat_time = message.get("chat_time")
-
         parts = [f"{role}: "]
         if chat_time:
             parts.append(f"[{chat_time}]: ")
-        if tool_call_id:
-            parts.append(f"[tool_call_id: {tool_call_id}]: ")
         prefix = "".join(parts)
+        content = json.dumps(content) if isinstance(content, list | dict) else content
         line = f"{prefix}{content}\n"
+        if not line:
+            return []
 
-        # Create source
-        source = self.create_source(message, info)
+        sources = self.create_source(message, info)
 
         # Extract info fields
         info_ = info.copy()
         user_id = info_.pop("user_id", "")
         session_id = info_.pop("session_id", "")
 
-        # Tool messages are typically LongTermMemory (they're system/assistant tool results)
-        memory_type = "LongTermMemory"
-
-        # Create memory item
-        memory_item = TextualMemoryItem(
-            memory=line,
-            metadata=TreeNodeTextualMemoryMetadata(
-                user_id=user_id,
-                session_id=session_id,
-                memory_type=memory_type,
-                status="activated",
-                tags=["mode:fast"],
-                key=_derive_key(line),
-                embedding=self.embedder.embed([line])[0],
-                usage=[],
-                sources=[source],
-                background="",
-                confidence=0.99,
-                type="fact",
-                info=info_,
-            ),
-        )
-
-        return [memory_item]
+        content_chunks = self._split_text(line)
+        memory_items = []
+        for _chunk_idx, chunk_text in enumerate(content_chunks):
+            if not chunk_text.strip():
+                continue
+
+            memory_item = TextualMemoryItem(
+                memory=chunk_text,
+                metadata=TreeNodeTextualMemoryMetadata(
+                    user_id=user_id,
+                    session_id=session_id,
+                    memory_type="LongTermMemory",  # only choose LongTermMemory for tool messages as a placeholder
+                    status="activated",
+                    tags=["mode:fast"],
+                    sources=sources,
+                    info=info_,
+                ),
+            )
+            memory_items.append(memory_item)
+        return memory_items
 
     def parse_fine(
         self,
@@ -220,4 +191,5 @@ def parse_fine(
         info: dict[str, Any],
         **kwargs,
     ) -> list[TextualMemoryItem]:
+        # Tool messages need no special multimodal handling in fine mode.
         return []
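A standalone sketch of the fast-path line format that `parse_fast` now builds for a tool message (note the `tool_call_id` prefix was dropped from the line itself; it now travels on the `SourceMessage`):

```python
import json

# Sketch: reproduce the "role: [chat_time]: content" line built above.
message = {
    "role": "tool",
    "tool_call_id": "call_abc",
    "chat_time": "2024-05-01 10:00:00",
    "content": {"temp_c": 21, "sky": "clear"},
}
content = message["content"]
content = json.dumps(content) if isinstance(content, (list, dict)) else content
line = f"tool: [{message['chat_time']}]: {content}\n"
# -> 'tool: [2024-05-01 10:00:00]: {"temp_c": 21, "sky": "clear"}\n'
```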
message.get("type", "") - if msg_type == "tool_description": - name = message.get("name", "") - description = message.get("description", "") - parameters = message.get("parameters", {}) - content = f"[tool_description] name={name}, description={description}, parameters={parameters}" - return SourceMessage( - type="tool_description", - content=content, - original_part=message, - ) - elif msg_type == "tool_input": - call_id = message.get("call_id", "") - name = message.get("name", "") - argument = message.get("argument", {}) - content = f"[tool_input] call_id={call_id}, name={name}, argument={argument}" - return SourceMessage( - type="tool_input", - content=content, - message_id=call_id, - original_part=message, - ) - elif msg_type == "tool_output": - call_id = message.get("call_id", "") - name = message.get("name", "") - output = message.get("output", {}) - content = f"[tool_output] call_id={call_id}, name={name}, output={output}" - return SourceMessage( - type="tool_output", - content=content, - message_id=call_id, - original_part=message, - ) + return [] - # Handle standard tool message - content = _extract_text_from_content(message.get("content", "")) - return SourceMessage( - type="tool", - role="tool", - chat_time=message.get("chat_time"), - message_id=message.get("message_id"), - content=content, - ) + role = message.get("role", "tool") + raw_content = message.get("content", "") + tool_call_id = message.get("tool_call_id", "") + chat_time = message.get("chat_time") + message_id = message.get("message_id") + + sources = [] + + if isinstance(raw_content, list): + # Multimodal: create one SourceMessage per part + for part in raw_content: + if isinstance(part, dict): + part_type = part.get("type", "") + if part_type == "text": + sources.append( + SourceMessage( + type="text", + role=role, + chat_time=chat_time, + message_id=message_id, + content=part.get("text", ""), + tool_call_id=tool_call_id, + ) + ) + elif part_type == "file": + file_info = part.get("file", {}) + sources.append( + SourceMessage( + type="file", + role=role, + chat_time=chat_time, + message_id=message_id, + content=file_info.get("file_data", ""), + filename=file_info.get("filename", ""), + file_id=file_info.get("file_id", ""), + tool_call_id=tool_call_id, + original_part=part, + ) + ) + elif part_type == "image_url": + file_info = part.get("image_url", {}) + sources.append( + SourceMessage( + type="image_url", + role=role, + chat_time=chat_time, + message_id=message_id, + content=file_info.get("url", ""), + detail=file_info.get("detail", "auto"), + tool_call_id=tool_call_id, + original_part=part, + ) + ) + elif part_type == "input_audio": + file_info = part.get("input_audio", {}) + sources.append( + SourceMessage( + type="input_audio", + role=role, + chat_time=chat_time, + message_id=message_id, + content=file_info.get("data", ""), + format=file_info.get("format", "wav"), + tool_call_id=tool_call_id, + original_part=part, + ) + ) + else: + logger.warning(f"[ToolParser] Unsupported part type: {part_type}") + continue + else: + # Simple string content message: single SourceMessage + if raw_content: + sources.append( + SourceMessage( + type="chat", + role=role, + chat_time=chat_time, + message_id=message_id, + content=raw_content, + tool_call_id=tool_call_id, + ) + ) + + return sources def rebuild_from_source( self, source: SourceMessage, ) -> ChatCompletionToolMessageParam: """Rebuild tool message from SourceMessage.""" - return { - "role": "tool", - "content": source.content or "", - "tool_call_id": source.message_id or 
"", # tool_call_id might be in message_id - "chat_time": source.chat_time, - "message_id": source.message_id, - } def parse_fast( self, - message: ChatCompletionToolMessageParam | dict[str, Any], + message: ChatCompletionToolMessageParam, info: dict[str, Any], **kwargs, ) -> list[TextualMemoryItem]: - """ - Parse tool message in fast mode. - - Handles both standard tool messages and custom tool formats: - - Standard tool message: role="tool", content, tool_call_id - - Custom formats: tool_description, tool_input, tool_output + role = message.get("role", "") + content = message.get("content", "") + chat_time = message.get("chat_time", None) - Args: - message: Tool message to parse - info: Dictionary containing user_id and session_id - **kwargs: Additional parameters - - Returns: - List of TextualMemoryItem objects - """ - from memos.memories.textual.item import TreeNodeTextualMemoryMetadata - - from .base import _derive_key - - if not isinstance(message, dict): - logger.warning(f"[ToolParser] Expected dict, got {type(message)}") - return [] - - # Handle custom tool formats (tool_description, tool_input, tool_output) - msg_type = message.get("type", "") - if msg_type in ("tool_description", "tool_input", "tool_output"): - # Create source - source = self.create_source(message, info) - content = source.content or "" - if not content: - return [] - - # Extract info fields - info_ = info.copy() - user_id = info_.pop("user_id", "") - session_id = info_.pop("session_id", "") - - # Create memory item - memory_item = TextualMemoryItem( - memory=content, - metadata=TreeNodeTextualMemoryMetadata( - user_id=user_id, - session_id=session_id, - memory_type="LongTermMemory", - status="activated", - tags=["mode:fast"], - key=_derive_key(content), - embedding=self.embedder.embed([content])[0], - usage=[], - sources=[source], - background="", - confidence=0.99, - type="fact", - info=info_, - ), - ) - return [memory_item] - - # Handle standard tool message (role="tool") - role = message.get("role", "").strip().lower() if role != "tool": - logger.warning(f"[ToolParser] Expected role='tool', got role='{role}'") + logger.warning(f"[ToolParser] Expected role is `tool`, got {role}") return [] - - # Extract content from tool message - content = _extract_text_from_content(message.get("content", "")) - if not content: - return [] - - # Build formatted line similar to assistant_parser - tool_call_id = message.get("tool_call_id", "") - chat_time = message.get("chat_time") - parts = [f"{role}: "] if chat_time: parts.append(f"[{chat_time}]: ") - if tool_call_id: - parts.append(f"[tool_call_id: {tool_call_id}]: ") prefix = "".join(parts) + content = json.dumps(content) if isinstance(content, list | dict) else content line = f"{prefix}{content}\n" + if not line: + return [] - # Create source - source = self.create_source(message, info) + sources = self.create_source(message, info) # Extract info fields info_ = info.copy() user_id = info_.pop("user_id", "") session_id = info_.pop("session_id", "") - # Tool messages are typically LongTermMemory (they're system/assistant tool results) - memory_type = "LongTermMemory" - - # Create memory item - memory_item = TextualMemoryItem( - memory=line, - metadata=TreeNodeTextualMemoryMetadata( - user_id=user_id, - session_id=session_id, - memory_type=memory_type, - status="activated", - tags=["mode:fast"], - key=_derive_key(line), - embedding=self.embedder.embed([line])[0], - usage=[], - sources=[source], - background="", - confidence=0.99, - type="fact", - info=info_, - ), - ) - - 
diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py
index b7956bfec..75a16bace 100644
--- a/src/memos/memories/textual/item.py
+++ b/src/memos/memories/textual/item.py
@@ -24,7 +24,7 @@ class SourceMessage(BaseModel):
     - type: Source kind (e.g., "chat", "doc", "web", "file", "system", ...).
       If not provided, upstream logic may infer it: presence of `role` ⇒
       "chat"; otherwise ⇒ "doc".
-    - role: Conversation role ("user" | "assistant" | "system") when the
+    - role: Conversation role ("user" | "assistant" | "system" | "tool") when the
       source is a chat turn.
     - content: Minimal reproducible snippet from the source. If omitted,
       upstream may fall back to `doc_path` / `url` / `message_id`.
@@ -99,9 +99,14 @@ def __str__(self) -> str:
 class TreeNodeTextualMemoryMetadata(TextualMemoryMetadata):
     """Extended metadata for structured memory, layered retrieval, and lifecycle tracking."""
 
-    memory_type: Literal["WorkingMemory", "LongTermMemory", "UserMemory", "OuterMemory"] = Field(
-        default="WorkingMemory", description="Memory lifecycle type."
-    )
+    memory_type: Literal[
+        "WorkingMemory",
+        "LongTermMemory",
+        "UserMemory",
+        "OuterMemory",
+        "ToolSchemaMemory",
+        "ToolTrajectoryMemory",
+    ] = Field(default="WorkingMemory", description="Memory lifecycle type.")
     sources: list[SourceMessage] | None = Field(
         default=None, description="Multiple origins of the memory (e.g., URLs, notes)."
     )
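With the `Literal` widened, memory items can carry the new lifecycle types directly. A sketch of constructing one (field set beyond `memory_type` follows the parser usage elsewhere in this diff; required fields of the base metadata model are an assumption here):

```python
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata

# Sketch: a memory item tagged with one of the new lifecycle types.
meta = TreeNodeTextualMemoryMetadata(
    user_id="u_123",
    session_id="s_1",
    memory_type="ToolSchemaMemory",  # now a valid Literal value
    status="activated",
)
item = TextualMemoryItem(memory='{"name": "get_weather"}', metadata=meta)
```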
diff --git a/src/memos/memories/textual/prefer_text_memory/spliter.py b/src/memos/memories/textual/prefer_text_memory/spliter.py
index 3059d611b..a54036778 100644
--- a/src/memos/memories/textual/prefer_text_memory/spliter.py
+++ b/src/memos/memories/textual/prefer_text_memory/spliter.py
@@ -87,7 +87,7 @@ def _split_with_overlap(self, data: MessageList) -> list[MessageList]:
                 # overlap 1 turns (Q + A = 2)
                 context = copy.deepcopy(chunk[-2:]) if i + 1 < len(data) else []
                 chunk = context
-        if chunk and len(chunk) % 2 == 0:
+        if chunk:
             chunks.append(chunk)
         return chunks
diff --git a/src/memos/memories/textual/prefer_text_memory/utils.py b/src/memos/memories/textual/prefer_text_memory/utils.py
index 76d4b4211..03d2ef923 100644
--- a/src/memos/memories/textual/prefer_text_memory/utils.py
+++ b/src/memos/memories/textual/prefer_text_memory/utils.py
@@ -1,3 +1,4 @@
+import json
 import re
 
 from memos.dependency import require_python_package
@@ -9,12 +10,36 @@ def convert_messages_to_string(messages: MessageList) -> str:
     """Convert a list of messages to a string."""
     message_text = ""
     for message in messages:
+        content = message.get("content", "")
+        content = (
+            content.strip()
+            if isinstance(content, str)
+            else json.dumps(content, ensure_ascii=False).strip()
+        )
+        if message["role"] == "system":
+            continue
         if message["role"] == "user":
-            message_text += f"Query: {message['content']}\n" if message["content"].strip() else ""
+            message_text += f"User: {content}\n" if content else ""
         elif message["role"] == "assistant":
-            message_text += f"Answer: {message['content']}\n" if message["content"].strip() else ""
-    message_text = message_text.strip()
-    return message_text
+            tool_calls = message.get("tool_calls", [])
+            tool_calls_str = (
+                f"[tool_calls]: {json.dumps(tool_calls, ensure_ascii=False)}" if tool_calls else ""
+            )
+            line_str = (
+                f"Assistant: {content} {tool_calls_str}".strip()
+                if content or tool_calls_str
+                else ""
+            )
+            message_text += f"{line_str}\n" if line_str else ""
+        elif message["role"] == "tool":
+            tool_call_id = message.get("tool_call_id", "")
+            line_str = (
+                f"Tool: {content} [tool_call_id]: {tool_call_id}".strip()
+                if tool_call_id
+                else f"Tool: {content}".strip()
+            )
+            message_text += f"{line_str}\n" if line_str else ""
+    return message_text.strip()
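The transcript flattener now covers tool-calling turns. A quick usage sketch (message values illustrative):

```python
from memos.memories.textual.prefer_text_memory.utils import convert_messages_to_string

messages = [
    {"role": "system", "content": "You are helpful."},  # skipped entirely
    {"role": "user", "content": "Weather in Tokyo?"},
    {"role": "assistant", "content": "Let me check.", "tool_calls": [{"id": "call_1"}]},
    {"role": "tool", "content": "21C, clear", "tool_call_id": "call_1"},
]
print(convert_messages_to_string(messages))
# User: Weather in Tokyo?
# Assistant: Let me check. [tool_calls]: [{"id": "call_1"}]
# Tool: 21C, clear [tool_call_id]: call_1
```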
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py
index ad2bcd9c4..cad850d2d 100644
--- a/src/memos/memories/textual/tree.py
+++ b/src/memos/memories/textual/tree.py
@@ -166,6 +166,8 @@ def search(
         search_priority: dict | None = None,
         search_filter: dict | None = None,
         user_name: str | None = None,
+        search_tool_memory: bool = False,
+        tool_mem_top_k: int = 6,
         **kwargs,
     ) -> list[TextualMemoryItem]:
         """Search for memories based on a query.
@@ -223,6 +225,8 @@ def search(
             search_priority,
             user_name=user_name,
             plugin=kwargs.get("plugin", False),
+            search_tool_memory=search_tool_memory,
+            tool_mem_top_k=tool_mem_top_k,
         )
 
     def get_relevant_subgraph(
diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py
index a71fee02f..3226f7ca0 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/manager.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py
@@ -181,12 +181,18 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non
         working_id = str(uuid.uuid4())
 
         with ContextThreadPoolExecutor(max_workers=2, thread_name_prefix="mem") as ex:
-            f_working = ex.submit(
-                self._add_memory_to_db, memory, "WorkingMemory", user_name, working_id
-            )
-            futures.append(("working", f_working))
-
-            if memory.metadata.memory_type in ("LongTermMemory", "UserMemory"):
+            if memory.metadata.memory_type not in ("ToolSchemaMemory", "ToolTrajectoryMemory"):
+                f_working = ex.submit(
+                    self._add_memory_to_db, memory, "WorkingMemory", user_name, working_id
+                )
+                futures.append(("working", f_working))
+
+            if memory.metadata.memory_type in (
+                "LongTermMemory",
+                "UserMemory",
+                "ToolSchemaMemory",
+                "ToolTrajectoryMemory",
+            ):
                 f_graph = ex.submit(
                     self._add_to_graph_memory,
                     memory=memory,
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
index 5dfbde704..dea83887e 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
@@ -59,7 +59,13 @@ def retrieve(
         Returns:
             list: Combined memory items.
         """
-        if memory_scope not in ["WorkingMemory", "LongTermMemory", "UserMemory"]:
+        if memory_scope not in [
+            "WorkingMemory",
+            "LongTermMemory",
+            "UserMemory",
+            "ToolSchemaMemory",
+            "ToolTrajectoryMemory",
+        ]:
             raise ValueError(f"Unsupported memory scope: {memory_scope}")
 
         if memory_scope == "WorkingMemory":
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index 830b915c1..0666f1d86 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -76,6 +76,8 @@ def retrieve(
         search_filter: dict | None = None,
         search_priority: dict | None = None,
         user_name: str | None = None,
+        search_tool_memory: bool = False,
+        tool_mem_top_k: int = 6,
         **kwargs,
     ) -> list[tuple[TextualMemoryItem, float]]:
         logger.info(
@@ -100,6 +102,8 @@ def retrieve(
             search_filter,
             search_priority,
             user_name,
+            search_tool_memory,
+            tool_mem_top_k,
         )
         return results
 
@@ -109,10 +113,14 @@ def post_retrieve(
         top_k: int,
         user_name: str | None = None,
         info=None,
+        search_tool_memory: bool = False,
+        tool_mem_top_k: int = 6,
         plugin=False,
     ):
         deduped = self._deduplicate_results(retrieved_results)
-        final_results = self._sort_and_trim(deduped, top_k, plugin)
+        final_results = self._sort_and_trim(
+            deduped, top_k, plugin, search_tool_memory, tool_mem_top_k
+        )
         self._update_usage_history(final_results, info, user_name)
         return final_results
 
@@ -127,6 +135,8 @@ def search(
         search_filter: dict | None = None,
         search_priority: dict | None = None,
         user_name: str | None = None,
+        search_tool_memory: bool = False,
+        tool_mem_top_k: int = 6,
         **kwargs,
     ) -> list[TextualMemoryItem]:
         """
@@ -171,6 +181,8 @@ def search(
             search_filter=search_filter,
             search_priority=search_priority,
             user_name=user_name,
+            search_tool_memory=search_tool_memory,
+            tool_mem_top_k=tool_mem_top_k,
        )
 
         final_results = self.post_retrieve(
@@ -179,6 +191,8 @@ def search(
             user_name=user_name,
             info=None,
             plugin=kwargs.get("plugin", False),
+            search_tool_memory=search_tool_memory,
+            tool_mem_top_k=tool_mem_top_k,
         )
 
         logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
@@ -272,6 +286,8 @@ def _retrieve_paths(
         search_filter: dict | None = None,
         search_priority: dict | None = None,
         user_name: str | None = None,
+        search_tool_memory: bool = False,
+        tool_mem_top_k: int = 6,
     ):
         """Run A/B/C retrieval paths in parallel"""
         tasks = []
@@ -324,6 +340,22 @@ def _retrieve_paths(
                     user_name,
                 )
             )
+        if search_tool_memory:
+            tasks.append(
+                executor.submit(
+                    self._retrieve_from_tool_memory,
+                    query,
+                    parsed_goal,
+                    query_embedding,
+                    top_k,
+                    memory_type,
+                    search_filter,
+                    search_priority,
+                    user_name,
+                    id_filter,
+                    mode=mode,
+                )
+            )
 
         results = []
         for t in tasks:
@@ -498,6 +530,98 @@ def _retrieve_from_internet(
             parsed_goal=parsed_goal,
         )
 
+    # --- Path D
+    @timed
+    def _retrieve_from_tool_memory(
+        self,
+        query,
+        parsed_goal,
+        query_embedding,
+        top_k,
+        memory_type,
+        search_filter: dict | None = None,
+        search_priority: dict | None = None,
+        user_name: str | None = None,
+        id_filter: dict | None = None,
+        mode: str = "fast",
+    ):
+        """Retrieve and rerank from ToolMemory"""
+        results = {
+            "ToolSchemaMemory": [],
+            "ToolTrajectoryMemory": [],
+        }
+        tasks = []
+
+        # chain of thinking
+        cot_embeddings = []
+        if self.vec_cot:
+            queries = self._cot_query(query, mode=mode, context=parsed_goal.context)
+            if len(queries) > 1:
+                cot_embeddings = self.embedder.embed(queries)
+                cot_embeddings.extend(query_embedding)
+        else:
+            cot_embeddings = query_embedding
+
+        with ContextThreadPoolExecutor(max_workers=2) as executor:
+            if memory_type in ["All", "ToolSchemaMemory"]:
+                tasks.append(
+                    executor.submit(
+                        self.graph_retriever.retrieve,
+                        query=query,
+                        parsed_goal=parsed_goal,
+                        query_embedding=cot_embeddings,
+                        top_k=top_k * 2,
+                        memory_scope="ToolSchemaMemory",
+                        search_filter=search_filter,
+                        search_priority=search_priority,
+                        user_name=user_name,
+                        id_filter=id_filter,
+                        use_fast_graph=self.use_fast_graph,
+                    )
+                )
+            if memory_type in ["All", "ToolTrajectoryMemory"]:
+                tasks.append(
+                    executor.submit(
+                        self.graph_retriever.retrieve,
+                        query=query,
+                        parsed_goal=parsed_goal,
+                        query_embedding=cot_embeddings,
+                        top_k=top_k * 2,
+                        memory_scope="ToolTrajectoryMemory",
+                        search_filter=search_filter,
+                        search_priority=search_priority,
+                        user_name=user_name,
+                        id_filter=id_filter,
+                        use_fast_graph=self.use_fast_graph,
+                    )
+                )
+
+            # Collect results from all tasks
+            for task in tasks:
+                rsp = task.result()
+                if rsp and rsp[0].metadata.memory_type == "ToolSchemaMemory":
+                    results["ToolSchemaMemory"].extend(rsp)
+                elif rsp and rsp[0].metadata.memory_type == "ToolTrajectoryMemory":
+                    results["ToolTrajectoryMemory"].extend(rsp)
+
+        schema_reranked = self.reranker.rerank(
+            query=query,
+            query_embedding=query_embedding[0],
+            graph_results=results["ToolSchemaMemory"],
+            top_k=top_k,
+            parsed_goal=parsed_goal,
+            search_filter=search_filter,
+        )
+        trajectory_reranked = self.reranker.rerank(
+            query=query,
+            query_embedding=query_embedding[0],
+            graph_results=results["ToolTrajectoryMemory"],
+            top_k=top_k,
+            parsed_goal=parsed_goal,
+            search_filter=search_filter,
+        )
+        return schema_reranked + trajectory_reranked
+
     @timed
     def _retrieve_simple(
         self,
@@ -554,11 +678,41 @@ def _deduplicate_results(self, results):
         return list(deduped.values())
 
     @timed
-    def _sort_and_trim(self, results, top_k, plugin=False):
+    def _sort_and_trim(
+        self, results, top_k, plugin=False, search_tool_memory=False, tool_mem_top_k=6
+    ):
         """Sort results by score and trim to top_k"""
+        final_items = []
+        if search_tool_memory:
+            tool_results = [
+                (item, score)
+                for item, score in results
+                if item.metadata.memory_type in ["ToolSchemaMemory", "ToolTrajectoryMemory"]
+            ]
+            sorted_tool_results = sorted(tool_results, key=lambda pair: pair[1], reverse=True)[
+                :tool_mem_top_k
+            ]
+            for item, score in sorted_tool_results:
+                if plugin and round(score, 2) == 0.00:
+                    continue
+                meta_data = item.metadata.model_dump()
+                meta_data["relativity"] = score
+                final_items.append(
+                    TextualMemoryItem(
+                        id=item.id,
+                        memory=item.memory,
+                        metadata=SearchedTreeNodeTextualMemoryMetadata(**meta_data),
+                    )
+                )
+            # separate textual results
+            results = [
+                (item, score)
+                for item, score in results
+                if item.metadata.memory_type not in ["ToolSchemaMemory", "ToolTrajectoryMemory"]
+            ]
         sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)[:top_k]
-        final_items = []
+
         for item, score in sorted_results:
             if plugin and round(score, 2) == 0.00:
                 continue
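The separate trimming budget is the point of this change: tool memories compete only with each other for `tool_mem_top_k` slots, while ordinary memories keep the full `top_k`. A standalone toy illustrating the bucketing (tuples stand in for `(item, score)` pairs; names and scores are illustrative):

```python
# Sketch: the two-bucket trim performed by _sort_and_trim.
TOOL_TYPES = {"ToolSchemaMemory", "ToolTrajectoryMemory"}
results = [
    ("fact", "LongTermMemory", 0.91),
    ("schema", "ToolSchemaMemory", 0.88),
    ("traj", "ToolTrajectoryMemory", 0.75),
]
tool_mem_top_k, top_k = 2, 1

tool = sorted((r for r in results if r[1] in TOOL_TYPES), key=lambda r: r[2], reverse=True)
text = sorted((r for r in results if r[1] not in TOOL_TYPES), key=lambda r: r[2], reverse=True)
final = tool[:tool_mem_top_k] + text[:top_k]  # tool hits no longer crowd out text hits
assert [r[0] for r in final] == ["schema", "traj", "fact"]
```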
diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py
index b5bd34417..1ddd2b1b7 100644
--- a/src/memos/multi_mem_cube/single_cube.py
+++ b/src/memos/multi_mem_cube/single_cube.py
@@ -11,6 +11,7 @@
 from memos.api.handlers.formatters_handler import (
     format_memory_item,
     post_process_pref_mem,
+    post_process_textual_mem,
 )
 from memos.context.context import ContextThreadPoolExecutor
 from memos.log import get_logger
@@ -109,6 +110,7 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]:
             "para_mem": [],
             "pref_mem": [],
             "pref_note": "",
+            "tool_mem": [],
         }
 
         # Determine search mode
@@ -123,11 +125,10 @@ def search_memories(self, search_req: APISearchRequest) -> dict[str, Any]:
             pref_formatted_memories = pref_future.result()
 
         # Build result
-        memories_result["text_mem"].append(
-            {
-                "cube_id": self.cube_id,
-                "memories": text_formatted_memories,
-            }
+        memories_result = post_process_textual_mem(
+            memories_result,
+            text_formatted_memories,
+            self.cube_id,
         )
 
         memories_result = post_process_pref_mem(
@@ -278,6 +279,8 @@ def _fine_search(
         Returns:
             List of enhanced search results
         """
+        # TODO: support tool memory search in future
+        logger.info(f"Fine strategy: {FINE_STRATEGY}")
         if FINE_STRATEGY == FineStrategy.DEEP_SEARCH:
             return self._deep_search(search_req=search_req, user_context=user_context)
@@ -375,6 +378,9 @@ def _search_pref(
         """
         if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
             return []
+        if not search_req.include_preference:
+            return []
+
         logger.info(f"search_req.filter for preference memory: {search_req.filter}")
         logger.info(f"type of pref_mem: {type(self.naive_mem_cube.pref_mem)}")
         try:
@@ -427,6 +433,8 @@ def _fast_search(
                 "chat_history": search_req.chat_history,
             },
             plugin=plugin,
+            search_tool_memory=search_req.search_tool_memory,
+            tool_mem_top_k=search_req.tool_mem_top_k,
         )
 
         formatted_memories = [format_memory_item(data) for data in search_results]
@@ -543,6 +551,13 @@ def _process_pref_mem(
         if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
             return []
 
+        if add_req.messages is None or isinstance(add_req.messages, str):
+            return []
+
+        for message in add_req.messages:
+            if message.get("role", None) is None:
+                return []
+
         target_session_id = add_req.session_id or "default_session"
 
         if sync_mode == "async":
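After these handler changes, a search response carries tool memories in their own bucket. An abridged sketch of the payload shape returned by `search_memories` (only keys visible in the hunks above are shown):

```python
# Sketch: tool memories arrive in a dedicated "tool_mem" bucket.
memories_result = {
    "text_mem": [{"cube_id": "cube_1", "memories": ["..."]}],
    "para_mem": [],
    "pref_mem": [],
    "pref_note": "",
    "tool_mem": [{"cube_id": "cube_1", "memories": ["..."]}],
}
```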
diff --git a/src/memos/templates/tool_mem_prompts.py b/src/memos/templates/tool_mem_prompts.py
new file mode 100644
index 000000000..7d5363956
--- /dev/null
+++ b/src/memos/templates/tool_mem_prompts.py
@@ -0,0 +1,84 @@
+TOOL_TRAJECTORY_PROMPT_ZH = """
+你是一个专业的工具调用轨迹提取专家。你的任务是从给定的对话消息中提取完整的工具调用轨迹经验。
+
+## 提取规则:
+1. 只有当对话中存在有价值的工具调用过程时才进行提取
+2. 有价值的轨迹至少包含以下元素:
+   - 用户的问题(user message)
+   - 助手的工具调用尝试(assistant message with tool_calls)
+   - 工具的执行结果(tool message with tool_call_id and content,无论成功或失败)
+   - 助手的响应(assistant message,无论是否给出最终答案)
+
+## 输出格式:
+返回一个JSON数组,格式如下:
+```json
+[
+    {
+        "trajectory": "自然语言输出包含'任务、使用的工具、工具观察、最终回答'的完整精炼的总结,体现顺序",
+        "tool_used_status": [
+            {
+                "used_tool": "工具名1",
+                "success_rate": "0.0-1.0之间的数值,表示该工具在本次轨迹中的成功率",
+                "error_type": "调用失败时的错误类型和描述,成功时为空字符串",
+                "experience": "该工具的使用经验,比如常见的参数模式、执行特点、结果解读方式等"
+            }
+        ]
+    }
+]
+```
+
+## 注意事项:
+- 如果对话中没有完整的工具调用轨迹,返回空数组
+- 每个轨迹必须是独立的完整过程
+- 一个轨迹中可能涉及多个工具的使用,每个工具在tool_used_status中独立记录
+- 只提取事实内容,不要添加任何解释或额外信息
+- 确保返回的是有效的JSON格式
+
+请分析以下对话消息并提取工具调用轨迹:
+
+{messages}
+
+"""
+
+
+TOOL_TRAJECTORY_PROMPT_EN = """
+You are a professional tool call trajectory extraction expert. Your task is to extract valuable tool call trajectory experiences from given conversation messages.
+
+## Extraction Rules:
+1. Only extract when there are valuable tool calling processes in the conversation
+2. Valuable trajectories must contain at least the following elements:
+   - User's question (user message)
+   - Assistant's tool call attempt (assistant message with tool_calls)
+   - Tool execution results (tool message with tool_call_id and content, regardless of success or failure)
+   - Assistant's response (assistant message, whether or not a final answer is given)
+
+## Output Format:
+Return a JSON array in the following format:
+```json
+[
+    {
+        "trajectory": "Natural language summary containing 'task, tools used, tool observations, final answer' in a complete and refined manner, reflecting the sequence",
+        "tool_used_status": [
+            {
+                "used_tool": "Tool Name 1",
+                "success_rate": "Numerical value between 0.0-1.0, indicating the success rate of this tool in the current trajectory",
+                "error_type": "Error type and description when call fails, empty string when successful",
+                "experience": "Usage experience of this tool, such as common parameter patterns, execution characteristics, result interpretation methods, etc."
+            }
+        ]
+    }
+]
+```
+
+## Notes:
+- If there are no complete tool call trajectories in the conversation, return an empty array
+- Each trajectory must be an independent complete process
+- Multiple tools may be used in one trajectory, each tool is recorded independently in tool_used_status
+- Only extract factual content, do not add any additional explanations or information
+- Ensure the returned content is valid JSON format
+
+Please analyze the following conversation messages and extract tool call trajectories:
+
+{messages}
+
+"""
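Worth noting: the reader fills these templates with `str.replace` rather than `str.format`, because the embedded JSON examples contain literal braces that `.format` would try to interpret as replacement fields. A minimal sketch:

```python
from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN

# Sketch: fill the {messages} slot without tripping over the JSON braces.
prompt = TOOL_TRAJECTORY_PROMPT_EN.replace("{messages}", "user: ...\ntool: ...")
# TOOL_TRAJECTORY_PROMPT_EN.format(messages=...) would raise on the literal
# "{ ... }" blocks inside the template's JSON example.
```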
+ } + ] + } +] +``` + +## Notes: +- If there are no complete tool call trajectories in the conversation, return an empty array +- Each trajectory must be an independent complete process +- Multiple tools may be used in one trajectory, each tool is recorded independently in tool_used_status +- Only extract factual content, do not add any additional explanations or information +- Ensure the returned content is valid JSON format + +Please analyze the following conversation messages and extract tool call trajectories: + +{messages} + +""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py index a742de3a9..3c5638788 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py @@ -2,7 +2,6 @@ from __future__ import annotations -from collections.abc import Iterable from typing import Literal, TypeAlias from typing_extensions import Required, TypedDict @@ -35,7 +34,7 @@ class ChatCompletionAssistantMessageParam(TypedDict, total=False): [Learn more](https://platform.openai.com/docs/guides/audio). """ - content: str | Iterable[ContentArrayOfContentPart] | None + content: str | list[ContentArrayOfContentPart] | ContentArrayOfContentPart | None """The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified. @@ -44,7 +43,9 @@ class ChatCompletionAssistantMessageParam(TypedDict, total=False): refusal: str | None """The refusal message by the assistant.""" - tool_calls: Iterable[ChatCompletionMessageToolCallUnionParam] + tool_calls: ( + list[ChatCompletionMessageToolCallUnionParam] | ChatCompletionMessageToolCallUnionParam + ) """The tool calls generated by the model, such as function calls.""" chat_time: str | None diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py index 7faa90e2e..ea2101229 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py @@ -2,7 +2,6 @@ from __future__ import annotations -from collections.abc import Iterable from typing import Literal from typing_extensions import Required, TypedDict @@ -14,7 +13,9 @@ class ChatCompletionSystemMessageParam(TypedDict, total=False): - content: Required[str | Iterable[ChatCompletionContentPartTextParam]] + content: Required[ + str | list[ChatCompletionContentPartTextParam] | ChatCompletionContentPartTextParam + ] """The contents of the system message.""" role: Required[Literal["system"]] diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py index c03220915..99c845d11 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py @@ -2,7 +2,6 @@ from __future__ import annotations -from collections.abc import Iterable from typing import Literal from typing_extensions import Required, TypedDict @@ -14,7 +13,7 @@ class ChatCompletionToolMessageParam(TypedDict, total=False): - content: Required[str | 
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py
index c03220915..99c845d11 100644
--- a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-from collections.abc import Iterable
 from typing import Literal
 
 from typing_extensions import Required, TypedDict
@@ -14,7 +13,7 @@ class ChatCompletionToolMessageParam(TypedDict, total=False):
-    content: Required[str | Iterable[ChatCompletionContentPartParam]]
+    content: Required[str | list[ChatCompletionContentPartParam] | ChatCompletionContentPartParam]
     """The contents of the tool message."""
 
     role: Required[Literal["tool"]]
diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py
index 2c2a1f23f..8c004f340 100644
--- a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py
+++ b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py
@@ -2,7 +2,6 @@
 
 from __future__ import annotations
 
-from collections.abc import Iterable
 from typing import Literal
 
 from typing_extensions import Required, TypedDict
@@ -14,7 +13,7 @@ class ChatCompletionUserMessageParam(TypedDict, total=False):
-    content: Required[str | Iterable[ChatCompletionContentPartParam]]
+    content: Required[str | list[ChatCompletionContentPartParam] | ChatCompletionContentPartParam]
     """The contents of the user message."""
 
     role: Required[Literal["user"]]
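With the `Iterable[...]` unions replaced by concrete `list[...] | <single part>` unions, a single content part no longer has to be wrapped in a list. A sketch of the three shapes these TypedDicts now accept (values illustrative):

```python
from memos.types.openai_chat_completion_types import ChatCompletionToolMessageParam

as_str: ChatCompletionToolMessageParam = {
    "role": "tool", "content": "21C, clear", "tool_call_id": "call_1",
}
as_list: ChatCompletionToolMessageParam = {
    "role": "tool", "content": [{"type": "text", "text": "21C, clear"}], "tool_call_id": "call_1",
}
as_single_part: ChatCompletionToolMessageParam = {
    "role": "tool", "content": {"type": "text", "text": "21C, clear"}, "tool_call_id": "call_1",
}
```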