From 6dcb8d377fb4ad4c2495c1056637078b5988c65b Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 14 Nov 2025 16:16:39 +0800 Subject: [PATCH 01/25] new type --- src/memos/types/__init__.py | 1 + .../openai_chat_completion_types/__init__.py | 17 +++++++ ...chat_completion_assistant_message_param.py | 48 +++++++++++++++++++ ...hat_completion_content_part_image_param.py | 24 ++++++++++ ...mpletion_content_part_input_audio_param.py | 20 ++++++++ .../chat_completion_content_part_param.py | 39 +++++++++++++++ ...t_completion_content_part_refusal_param.py | 13 +++++ ...chat_completion_content_part_text_param.py | 13 +++++ ...mpletion_message_custom_tool_call_param.py | 24 ++++++++++ ...letion_message_function_tool_call_param.py | 29 +++++++++++ .../chat_completion_message_param.py | 18 +++++++ ...ompletion_message_tool_call_union_param.py | 13 +++++ .../chat_completion_system_message_param.py | 30 ++++++++++++ .../chat_completion_tool_message_param.py | 28 +++++++++++ .../chat_completion_user_message_param.py | 30 ++++++++++++ src/memos/{ => types}/types.py | 23 ++++++++- 16 files changed, 369 insertions(+), 1 deletion(-) create mode 100644 src/memos/types/__init__.py create mode 100644 src/memos/types/openai_chat_completion_types/__init__.py create mode 100644 src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py create mode 100644 src/memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py create mode 100644 src/memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py create mode 100644 src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py create mode 100644 src/memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py create mode 100644 src/memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py create mode 100644 src/memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py create mode 100644 src/memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py create mode 100644 src/memos/types/openai_chat_completion_types/chat_completion_message_param.py create mode 100644 src/memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py create mode 100644 src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py create mode 100644 src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py create mode 100644 src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py rename src/memos/{ => types}/types.py (83%) diff --git a/src/memos/types/__init__.py b/src/memos/types/__init__.py new file mode 100644 index 000000000..4192f6a10 --- /dev/null +++ b/src/memos/types/__init__.py @@ -0,0 +1 @@ +from .types import * \ No newline at end of file diff --git a/src/memos/types/openai_chat_completion_types/__init__.py b/src/memos/types/openai_chat_completion_types/__init__.py new file mode 100644 index 000000000..3d742fe3b --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/__init__.py @@ -0,0 +1,17 @@ +from .chat_completion_message_param import * + +from .chat_completion_assistant_message_param import * +from .chat_completion_system_message_param import * +from .chat_completion_tool_message_param import * +from .chat_completion_user_message_param import * + +from .chat_completion_message_custom_tool_call_param import * +from 
.chat_completion_message_function_tool_call_param import * +from .chat_completion_message_tool_call_union_param import * + +from .chat_completion_content_part_input_audio_param import * +from .chat_completion_content_part_image_param import * +from .chat_completion_content_part_refusal_param import * +from .chat_completion_content_part_text_param import * +from .chat_completion_content_part_param import * + diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py new file mode 100644 index 000000000..698f2a6e0 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import Union, Iterable, Optional +from typing_extensions import Literal, Required, TypeAlias, TypedDict + +from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam +from .chat_completion_content_part_refusal_param import ChatCompletionContentPartRefusalParam +from .chat_completion_message_tool_call_union_param import ChatCompletionMessageToolCallUnionParam + +__all__ = ["ChatCompletionAssistantMessageParam", "Audio", "ContentArrayOfContentPart"] + + +class Audio(TypedDict, total=False): + id: Required[str] + """Unique identifier for a previous audio response from the model.""" + + +ContentArrayOfContentPart: TypeAlias = Union[ChatCompletionContentPartTextParam, ChatCompletionContentPartRefusalParam] + + +class ChatCompletionAssistantMessageParam(TypedDict, total=False): + role: Required[Literal["assistant"]] + """The role of the messages author, in this case `assistant`.""" + + audio: Optional[Audio] + """ + Data about a previous audio response from the model. + [Learn more](https://platform.openai.com/docs/guides/audio). + """ + + content: Union[str, Iterable[ContentArrayOfContentPart], None] + """The contents of the assistant message. + + Required unless `tool_calls` or `function_call` is specified. + """ + + refusal: Optional[str] + """The refusal message by the assistant.""" + + tool_calls: Iterable[ChatCompletionMessageToolCallUnionParam] + """The tool calls generated by the model, such as function calls.""" + + chat_time: Optional[str] + """Optional timestamp for the message, format is not + restricted, it can be any vague or precise time string.""" + + message_id: Optional[str] + """Optional unique identifier for the message""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py new file mode 100644 index 000000000..f57ab33cb --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["ChatCompletionContentPartImageParam", "ImageURL"] + + +class ImageURL(TypedDict, total=False): + url: Required[str] + """Either a URL of the image or the base64 encoded image data.""" + + detail: Literal["auto", "low", "high"] + """Specifies the detail level of the image. + + Learn more in the + [Vision guide](https://platform.openai.com/docs/guides/vision#low-or-high-fidelity-image-understanding). 
+ """ + + +class ChatCompletionContentPartImageParam(TypedDict, total=False): + image_url: Required[ImageURL] + + type: Required[Literal["image_url"]] + """The type of the content part.""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py new file mode 100644 index 000000000..be90f84db --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["ChatCompletionContentPartInputAudioParam", "InputAudio"] + + +class InputAudio(TypedDict, total=False): + data: Required[str] + """Base64 encoded audio data.""" + + format: Required[Literal["wav", "mp3"]] + """The format of the encoded audio data. Currently supports "wav" and "mp3".""" + + +class ChatCompletionContentPartInputAudioParam(TypedDict, total=False): + input_audio: Required[InputAudio] + + type: Required[Literal["input_audio"]] + """The type of the content part. Always `input_audio`.""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py new file mode 100644 index 000000000..65ce3b2ee --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Union +from typing_extensions import Literal, Required, TypeAlias, TypedDict + +from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam +from .chat_completion_content_part_image_param import ChatCompletionContentPartImageParam +from .chat_completion_content_part_input_audio_param import ChatCompletionContentPartInputAudioParam + +__all__ = ["ChatCompletionContentPartParam", "File", "FileFile"] + + +class FileFile(TypedDict, total=False): + file_data: str + """ + The base64 encoded file data, used when passing the file to the model as a + string. + """ + + file_id: str + """The ID of an uploaded file to use as input.""" + + filename: str + """The name of the file, used when passing the file to the model as a string.""" + + +class File(TypedDict, total=False): + file: Required[FileFile] + + type: Required[Literal["file"]] + """The type of the content part. 
Always `file`.""" + + +ChatCompletionContentPartParam: TypeAlias = Union[ + ChatCompletionContentPartTextParam, + ChatCompletionContentPartImageParam, + ChatCompletionContentPartInputAudioParam, + File, +] diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py new file mode 100644 index 000000000..f239c48d5 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["ChatCompletionContentPartRefusalParam"] + + +class ChatCompletionContentPartRefusalParam(TypedDict, total=False): + refusal: Required[str] + """The refusal message generated by the model.""" + + type: Required[Literal["refusal"]] + """The type of the content part.""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py new file mode 100644 index 000000000..e15461ab4 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["ChatCompletionContentPartTextParam"] + + +class ChatCompletionContentPartTextParam(TypedDict, total=False): + text: Required[str] + """The text content.""" + + type: Required[Literal["text"]] + """The type of the content part.""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py new file mode 100644 index 000000000..8bcba4c59 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["ChatCompletionMessageCustomToolCallParam", "Custom"] + + +class Custom(TypedDict, total=False): + input: Required[str] + """The input for the custom tool call generated by the model.""" + + name: Required[str] + """The name of the custom tool to call.""" + + +class ChatCompletionMessageCustomToolCallParam(TypedDict, total=False): + id: Required[str] + """The ID of the tool call.""" + + custom: Required[Custom] + """The custom tool that the model called.""" + + type: Required[Literal["custom"]] + """The type of the tool. Always `custom`.""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py new file mode 100644 index 000000000..01a910787 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["ChatCompletionMessageFunctionToolCallParam", "Function"] + + +class Function(TypedDict, total=False): + arguments: Required[str] + """ + The arguments to call the function with, as generated by the model in JSON + format. 
Note that the model does not always generate valid JSON, and may + hallucinate parameters not defined by your function schema. Validate the + arguments in your code before calling your function. + """ + + name: Required[str] + """The name of the function to call.""" + + +class ChatCompletionMessageFunctionToolCallParam(TypedDict, total=False): + id: Required[str] + """The ID of the tool call.""" + + function: Required[Function] + """The function that the model called.""" + + type: Required[Literal["function"]] + """The type of the tool. Currently, only `function` is supported.""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_param.py new file mode 100644 index 000000000..5beee37a9 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_param.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from typing import Union +from typing_extensions import TypeAlias + +from .chat_completion_tool_message_param import ChatCompletionToolMessageParam +from .chat_completion_user_message_param import ChatCompletionUserMessageParam +from .chat_completion_system_message_param import ChatCompletionSystemMessageParam +from .chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam + +__all__ = ["ChatCompletionMessageParam"] + +ChatCompletionMessageParam: TypeAlias = Union[ + ChatCompletionSystemMessageParam, + ChatCompletionUserMessageParam, + ChatCompletionAssistantMessageParam, + ChatCompletionToolMessageParam, +] diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py new file mode 100644 index 000000000..97eccf344 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +from typing import Union +from typing_extensions import TypeAlias + +from .chat_completion_message_custom_tool_call_param import ChatCompletionMessageCustomToolCallParam +from .chat_completion_message_function_tool_call_param import ChatCompletionMessageFunctionToolCallParam + +__all__ = ["ChatCompletionMessageToolCallUnionParam"] + +ChatCompletionMessageToolCallUnionParam: TypeAlias = Union[ + ChatCompletionMessageFunctionToolCallParam, ChatCompletionMessageCustomToolCallParam +] diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py new file mode 100644 index 000000000..544f0e977 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import Union, Iterable, Optional +from typing_extensions import Literal, Required, TypedDict + +from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam + +__all__ = ["ChatCompletionSystemMessageParam"] + + +class ChatCompletionSystemMessageParam(TypedDict, total=False): + content: Required[Union[str, Iterable[ChatCompletionContentPartTextParam]]] + """The contents of the system message.""" + + role: Required[Literal["system"]] + """The role of the messages author, in this case `system`.""" + + name: str + """An optional name for the participant. 
+ + Provides the model information to differentiate between participants of the same + role. + """ + + chat_time: Optional[str] + """Optional timestamp for the message, format is not + restricted, it can be any vague or precise time string.""" + + message_id: Optional[str] + """Optional unique identifier for the message""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py new file mode 100644 index 000000000..8fb75fe35 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py @@ -0,0 +1,28 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Union, Iterable, Optional +from typing_extensions import Literal, Required, TypedDict + +from .chat_completion_content_part_param import ChatCompletionContentPartParam + +__all__ = ["ChatCompletionToolMessageParam"] + + +class ChatCompletionToolMessageParam(TypedDict, total=False): + content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]] + """The contents of the tool message.""" + + role: Required[Literal["tool"]] + """The role of the messages author, in this case `tool`.""" + + tool_call_id: Required[str] + """Tool call that this message is responding to.""" + + chat_time: Optional[str] + """Optional timestamp for the message, format is not + restricted, it can be any vague or precise time string.""" + + message_id: Optional[str] + """Optional unique identifier for the message""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py new file mode 100644 index 000000000..c48240c71 --- /dev/null +++ b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from typing import Union, Iterable, Optional +from typing_extensions import Literal, Required, TypedDict + +from .chat_completion_content_part_param import ChatCompletionContentPartParam + +__all__ = ["ChatCompletionUserMessageParam"] + + +class ChatCompletionUserMessageParam(TypedDict, total=False): + content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]] + """The contents of the user message.""" + + role: Required[Literal["user"]] + """The role of the messages author, in this case `user`.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the same + role. 
+ """ + + chat_time: Optional[str] + """Optional timestamp for the message, format is not + restricted, it can be any vague or precise time string.""" + + message_id: Optional[str] + """Optional unique identifier for the message""" diff --git a/src/memos/types.py b/src/memos/types/types.py similarity index 83% rename from src/memos/types.py rename to src/memos/types/types.py index 635fabccc..dae741afc 100644 --- a/src/memos/types.py +++ b/src/memos/types/types.py @@ -13,8 +13,20 @@ from memos.memories.activation.item import ActivationMemoryItem from memos.memories.parametric.item import ParametricMemoryItem from memos.memories.textual.item import TextualMemoryItem +from .openai_chat_completion_types import ChatCompletionMessageParam, ChatCompletionContentPartTextParam, File +__all__ = [ + "MessageRole", + "MessageDict", + "MessageList", + "ChatHistory", + "MOSSearchResult", + "Permission", + "PermissionDict", + "UserContext", +] + # ─── Message Types ────────────────────────────────────────────────────────────── # Chat message roles @@ -32,8 +44,17 @@ class MessageDict(TypedDict, total=False): message_id: str | None # Optional unique identifier for the message +RawMessageDict: TypeAlias = ChatCompletionContentPartTextParam | File + + # Message collections -MessageList: TypeAlias = list[MessageDict] +MessageList: TypeAlias = list[ChatCompletionMessageParam] +RawMessageList: TypeAlias = list[RawMessageDict] + + +# Messages Type +MessagesType: TypeAlias = str | MessageList | RawMessageList + # Chat history structure From 56dc725b514d0d56bed484e6309f9bca8d8081cc Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 18 Nov 2025 18:47:11 +0800 Subject: [PATCH 02/25] llm reconstruct and add search api modify --- src/memos/api/handlers/scheduler_handler.py | 8 +- src/memos/api/product_models.py | 41 +++- src/memos/api/routers/server_router.py | 12 +- src/memos/configs/llm.py | 42 ++++- src/memos/llms/deepseek.py | 41 ---- src/memos/llms/hf.py | 32 ++-- src/memos/llms/ollama.py | 66 ++++++- src/memos/llms/openai.py | 181 +++++++++--------- src/memos/llms/openai_new.py | 196 ++++++++++++++++++++ src/memos/llms/qwen.py | 50 ----- src/memos/llms/vllm.py | 97 +++++++--- 11 files changed, 526 insertions(+), 240 deletions(-) create mode 100644 src/memos/llms/openai_new.py diff --git a/src/memos/api/handlers/scheduler_handler.py b/src/memos/api/handlers/scheduler_handler.py index 8d3c6dc70..32b312f8a 100644 --- a/src/memos/api/handlers/scheduler_handler.py +++ b/src/memos/api/handlers/scheduler_handler.py @@ -22,7 +22,7 @@ def handle_scheduler_status( - user_name: str | None = None, + mem_cube_id: str | None = None, mem_scheduler: Any | None = None, instance_id: str = "", ) -> dict[str, Any]: @@ -43,9 +43,9 @@ def handle_scheduler_status( HTTPException: If status retrieval fails """ try: - if user_name: + if mem_cube_id: running = mem_scheduler.dispatcher.get_running_tasks( - lambda task: getattr(task, "mem_cube_id", None) == user_name + lambda task: getattr(task, "mem_cube_id", None) == mem_cube_id ) tasks_iter = to_iter(running) running_count = len(tasks_iter) @@ -53,7 +53,7 @@ def handle_scheduler_status( "message": "ok", "data": { "scope": "user", - "user_name": user_name, + "mem_cube_id": mem_cube_id, "running_tasks": running_count, "timestamp": time.time(), "instance_id": instance_id, diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 892d2d436..4b4d5acd9 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -1,6 +1,6 @@ import uuid 
-from typing import Generic, Literal, TypeVar
+from typing import Any, Generic, Literal, TypeVar
 
 from pydantic import BaseModel, Field
 
@@ -75,7 +75,18 @@ class ChatRequest(BaseRequest):
     history: list[MessageDict] | None = Field(None, description="Chat history")
     internet_search: bool = Field(True, description="Whether to use internet search")
     moscube: bool = Field(False, description="Whether to use MemOSCube")
+    system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
+    top_k: int = Field(10, description="Number of results to return")
+    threshold: float = Field(0.5, description="Threshold for filtering references")
     session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
+    include_preference: bool = Field(True, description="Whether to handle preference memory")
+    pref_top_k: int = Field(6, description="Number of preference results to return")
+    filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
+    model_name: str | None = Field(None, description="Model name to use for chat")
+    max_tokens: int | None = Field(None, description="Max tokens to generate")
+    temperature: float | None = Field(None, description="Temperature for sampling")
+    top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter")
+    add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat")
 
 
 class ChatCompleteRequest(BaseRequest):
@@ -122,6 +133,10 @@ class SuggestionResponse(BaseResponse[list]):
     data: dict[str, list[str]] | None = Field(None, description="Response data")
 
 
+class AddStatusResponse(BaseResponse[dict]):
+    """Response model for add status operations."""
+
+
 class ConfigResponse(BaseResponse[None]):
     """Response model for configuration endpoint."""
 
@@ -184,6 +199,7 @@ class APISearchRequest(BaseRequest):
     )
     include_preference: bool = Field(True, description="Whether to handle preference memory")
     pref_top_k: int = Field(6, description="Number of preference results to return")
+    filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
 
 
 class APIADDRequest(BaseRequest):
@@ -203,6 +219,11 @@ class APIADDRequest(BaseRequest):
     async_mode: Literal["async", "sync"] = Field(
         "async", description="Whether to add memory in async mode"
     )
+    custom_tags: list[str] | None = Field(None, description="Custom tags for the memory")
+    info: dict[str, str] | None = Field(None, description="Additional information for the memory")
+    is_feedback: bool = Field(
+        False, description="Whether this is user feedback in the knowledge base service"
+    )
 
 
 class APIChatCompleteRequest(BaseRequest):
@@ -214,12 +235,28 @@ class APIChatCompleteRequest(BaseRequest):
     history: list[MessageDict] | None = Field(None, description="Chat history")
     internet_search: bool = Field(False, description="Whether to use internet search")
     moscube: bool = Field(True, description="Whether to use MemOSCube")
-    base_prompt: str | None = Field(None, description="Base prompt to use for chat")
+    system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
     top_k: int = Field(10, description="Number of results to return")
     threshold: float = Field(0.5, description="Threshold for filtering references")
     session_id: str | None = Field(
         "default_session", description="Session ID for soft-filtering memories"
     )
+    include_preference: bool = Field(True, description="Whether to handle preference memory")
+    pref_top_k: int = Field(6, description="Number of preference results to return")
+    filter:
dict[str, Any] | None = Field(None, description="Filter for the memory") + model_name: str | None = Field(None, description="Model name to use for chat") + max_tokens: int | None = Field(None, description="Max tokens to generate") + temperature: float | None = Field(None, description="Temperature for sampling") + top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter") + add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat") + + +class AddStatusRequest(BaseRequest): + """Request model for checking add status.""" + + mem_cube_id: str = Field(..., description="Cube ID") + user_id: str | None = Field(None, description="User ID") + session_id: str | None = Field(None, description="Session ID") class SuggestionRequest(BaseRequest): diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index d43f9ccdc..fdb97dc7d 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -23,6 +23,8 @@ from memos.api.handlers.chat_handler import ChatHandler from memos.api.handlers.search_handler import SearchHandler from memos.api.product_models import ( + AddStatusRequest, + AddStatusResponse, APIADDRequest, APIChatCompleteRequest, APISearchRequest, @@ -98,11 +100,15 @@ def add_memories(add_req: APIADDRequest): # ============================================================================= -@router.get("/scheduler/status", summary="Get scheduler running status") -def scheduler_status(user_name: str | None = None): +@router.get( + "/scheduler/status", summary="Get scheduler running status", response_model=AddStatusResponse +) +def scheduler_status(add_status_req: AddStatusRequest): """Get scheduler running status.""" return handlers.scheduler_handler.handle_scheduler_status( - user_name=user_name, + mem_cube_id=add_status_req.mem_cube_id, + user_id=add_status_req.user_id, + session_id=add_status_req.session_id, mem_scheduler=mem_scheduler, instance_id=INSTANCE_ID, ) diff --git a/src/memos/configs/llm.py b/src/memos/configs/llm.py index d69a0a0fc..916b90ad1 100644 --- a/src/memos/configs/llm.py +++ b/src/memos/configs/llm.py @@ -9,9 +9,9 @@ class BaseLLMConfig(BaseConfig): """Base configuration class for LLMs.""" model_name_or_path: str = Field(..., description="Model name or path") - temperature: float = Field(default=0.8, description="Temperature for sampling") - max_tokens: int = Field(default=1024, description="Maximum number of tokens to generate") - top_p: float = Field(default=0.9, description="Top-p sampling parameter") + temperature: float = Field(default=0.7, description="Temperature for sampling") + max_tokens: int = Field(default=8192, description="Maximum number of tokens to generate") + top_p: float = Field(default=0.95, description="Top-p sampling parameter") top_k: int = Field(default=50, description="Top-k sampling parameter") remove_think_prefix: bool = Field( default=False, @@ -27,6 +27,18 @@ class OpenAILLMConfig(BaseLLMConfig): extra_body: Any = Field(default=None, description="extra body") +class OpenAIResponsesLLMConfig(BaseLLMConfig): + api_key: str = Field(..., description="API key for OpenAI") + api_base: str = Field( + default="https://api.openai.com/v1", description="Base URL for OpenAI responses API" + ) + extra_body: Any = Field(default=None, description="extra body") + enable_thinking: bool = Field( + default=False, + description="Enable reasoning outputs from vLLM", + ) + + class QwenLLMConfig(BaseLLMConfig): api_key: str = Field(..., 
description="API key for DashScope (Qwen)") api_base: str = Field( @@ -34,7 +46,6 @@ class QwenLLMConfig(BaseLLMConfig): description="Base URL for Qwen OpenAI-compatible API", ) extra_body: Any = Field(default=None, description="extra body") - model_name_or_path: str = Field(..., description="Model name for Qwen, e.g., 'qwen-plus'") class DeepSeekLLMConfig(BaseLLMConfig): @@ -44,9 +55,6 @@ class DeepSeekLLMConfig(BaseLLMConfig): description="Base URL for DeepSeek OpenAI-compatible API", ) extra_body: Any = Field(default=None, description="Extra options for API") - model_name_or_path: str = Field( - ..., description="Model name: 'deepseek-chat' or 'deepseek-reasoner'" - ) class AzureLLMConfig(BaseLLMConfig): @@ -61,11 +69,27 @@ class AzureLLMConfig(BaseLLMConfig): api_key: str = Field(..., description="API key for Azure OpenAI") +class AzureResponsesLLMConfig(BaseLLMConfig): + base_url: str = Field( + default="https://api.openai.azure.com/", + description="Base URL for Azure OpenAI API", + ) + api_version: str = Field( + default="2024-03-01-preview", + description="API version for Azure OpenAI", + ) + api_key: str = Field(..., description="API key for Azure OpenAI") + + class OllamaLLMConfig(BaseLLMConfig): api_base: str = Field( default="http://localhost:11434", description="Base URL for Ollama API", ) + enable_thinking: bool = Field( + default=False, + description="Enable reasoning outputs from Ollama", + ) class HFLLMConfig(BaseLLMConfig): @@ -85,6 +109,10 @@ class VLLMLLMConfig(BaseLLMConfig): default="http://localhost:8088/v1", description="Base URL for vLLM API", ) + enable_thinking: bool = Field( + default=False, + description="Enable reasoning outputs from vLLM", + ) class LLMConfigFactory(BaseConfig): diff --git a/src/memos/llms/deepseek.py b/src/memos/llms/deepseek.py index f5ee4842b..a90f8eb31 100644 --- a/src/memos/llms/deepseek.py +++ b/src/memos/llms/deepseek.py @@ -1,10 +1,6 @@ -from collections.abc import Generator - from memos.configs.llm import DeepSeekLLMConfig from memos.llms.openai import OpenAILLM -from memos.llms.utils import remove_thinking_tags from memos.log import get_logger -from memos.types import MessageList logger = get_logger(__name__) @@ -15,40 +11,3 @@ class DeepSeekLLM(OpenAILLM): def __init__(self, config: DeepSeekLLMConfig): super().__init__(config) - - def generate(self, messages: MessageList) -> str: - """Generate a response from DeepSeek.""" - response = self.client.chat.completions.create( - model=self.config.model_name_or_path, - messages=messages, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - top_p=self.config.top_p, - extra_body=self.config.extra_body, - ) - logger.info(f"Response from DeepSeek: {response.model_dump_json()}") - response_content = response.choices[0].message.content - if self.config.remove_think_prefix: - return remove_thinking_tags(response_content) - else: - return response_content - - def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: - """Stream response from DeepSeek.""" - response = self.client.chat.completions.create( - model=self.config.model_name_or_path, - messages=messages, - stream=True, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - top_p=self.config.top_p, - extra_body=self.config.extra_body, - ) - # Streaming chunks of text - for chunk in response: - delta = chunk.choices[0].delta - if hasattr(delta, "reasoning_content") and delta.reasoning_content: - yield delta.reasoning_content - - if hasattr(delta, "content") and 
delta.content: - yield delta.content diff --git a/src/memos/llms/hf.py b/src/memos/llms/hf.py index be0d1d95f..d46db7c9e 100644 --- a/src/memos/llms/hf.py +++ b/src/memos/llms/hf.py @@ -54,7 +54,9 @@ def __init__(self, config: HFLLMConfig): processors.append(TopPLogitsWarper(self.config.top_p)) self.logits_processors = LogitsProcessorList(processors) - def generate(self, messages: MessageList, past_key_values: DynamicCache | None = None): + def generate( + self, messages: MessageList, past_key_values: DynamicCache | None = None, **kwargs + ): """ Generate a response from the model. If past_key_values is provided, use cache-augmented generation. Args: @@ -68,12 +70,12 @@ def generate(self, messages: MessageList, past_key_values: DynamicCache | None = ) logger.info(f"HFLLM prompt: {prompt}") if past_key_values is None: - return self._generate_full(prompt) + return self._generate_full(prompt, **kwargs) else: - return self._generate_with_cache(prompt, past_key_values) + return self._generate_with_cache(prompt, past_key_values, **kwargs) def generate_stream( - self, messages: MessageList, past_key_values: DynamicCache | None = None + self, messages: MessageList, past_key_values: DynamicCache | None = None, **kwargs ) -> Generator[str, None, None]: """ Generate a streaming response from the model. @@ -92,7 +94,7 @@ def generate_stream( else: yield from self._generate_with_cache_stream(prompt, past_key_values) - def _generate_full(self, prompt: str) -> str: + def _generate_full(self, prompt: str, **kwargs) -> str: """ Generate output from scratch using the full prompt. Args: @@ -102,13 +104,13 @@ def _generate_full(self, prompt: str) -> str: """ inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device) gen_kwargs = { - "max_new_tokens": getattr(self.config, "max_tokens", 128), + "max_new_tokens": kwargs.get("max_tokens", self.config.max_tokens), "do_sample": getattr(self.config, "do_sample", True), } if self.config.do_sample: - gen_kwargs["temperature"] = self.config.temperature - gen_kwargs["top_k"] = self.config.top_k - gen_kwargs["top_p"] = self.config.top_p + gen_kwargs["temperature"] = kwargs.get("temperature", self.config.temperature) + gen_kwargs["top_k"] = kwargs.get("top_k", self.config.top_k) + gen_kwargs["top_p"] = kwargs.get("top_p", self.config.top_p) gen_ids = self.model.generate( **inputs, **gen_kwargs, @@ -125,7 +127,7 @@ def _generate_full(self, prompt: str) -> str: else response ) - def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]: + def _generate_full_stream(self, prompt: str, **kwargs) -> Generator[str, None, None]: """ Generate output from scratch using the full prompt with streaming. Args: @@ -138,7 +140,7 @@ def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]: inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device) # Get generation parameters - max_new_tokens = getattr(self.config, "max_tokens", 128) + max_new_tokens = kwargs.get("max_tokens", self.config.max_tokens) remove_think_prefix = getattr(self.config, "remove_think_prefix", False) # Manual streaming generation @@ -192,7 +194,7 @@ def _generate_full_stream(self, prompt: str) -> Generator[str, None, None]: else: yield new_token_text - def _generate_with_cache(self, query: str, kv: DynamicCache) -> str: + def _generate_with_cache(self, query: str, kv: DynamicCache, **kwargs) -> str: """ Generate output incrementally using an existing KV cache. 
Args: @@ -209,7 +211,7 @@ def _generate_with_cache(self, query: str, kv: DynamicCache) -> str: logits, kv = self._prefill(query_ids, kv) next_token = self._select_next_token(logits) generated = [next_token] - for _ in range(getattr(self.config, "max_tokens", 128) - 1): + for _ in range(kwargs.get("max_tokens", self.config.max_tokens) - 1): if self._should_stop(next_token): break logits, kv = self._prefill(next_token, kv) @@ -228,7 +230,7 @@ def _generate_with_cache(self, query: str, kv: DynamicCache) -> str: ) def _generate_with_cache_stream( - self, query: str, kv: DynamicCache + self, query: str, kv: DynamicCache, **kwargs ) -> Generator[str, None, None]: """ Generate output incrementally using an existing KV cache with streaming. @@ -242,7 +244,7 @@ def _generate_with_cache_stream( query, return_tensors="pt", add_special_tokens=False ).input_ids.to(self.model.device) - max_new_tokens = getattr(self.config, "max_tokens", 128) + max_new_tokens = kwargs.get("max_tokens", self.config.max_tokens) remove_think_prefix = getattr(self.config, "remove_think_prefix", False) # Initial forward pass diff --git a/src/memos/llms/ollama.py b/src/memos/llms/ollama.py index 050b7a253..c8643c763 100644 --- a/src/memos/llms/ollama.py +++ b/src/memos/llms/ollama.py @@ -1,7 +1,7 @@ from collections.abc import Generator from typing import Any -from ollama import Client +from ollama import Client, Message from memos.configs.llm import OllamaLLMConfig from memos.llms.base import BaseLLM @@ -54,7 +54,7 @@ def _ensure_model_exists(self): except Exception as e: logger.warning(f"Could not verify model existence: {e}") - def generate(self, messages: MessageList) -> Any: + def generate(self, messages: MessageList, **kwargs) -> Any: """ Generate a response from Ollama LLM. @@ -68,19 +68,67 @@ def generate(self, messages: MessageList) -> Any: model=self.config.model_name_or_path, messages=messages, options={ - "temperature": self.config.temperature, - "num_predict": self.config.max_tokens, - "top_p": self.config.top_p, - "top_k": self.config.top_k, + "temperature": kwargs.get("temperature", self.config.temperature), + "num_predict": kwargs.get("max_tokens", self.config.max_tokens), + "top_p": kwargs.get("top_p", self.config.top_p), + "top_k": kwargs.get("top_k", self.config.top_k), }, + think=self.config.enable_thinking, + tools=kwargs.get("tools"), ) logger.info(f"Raw response from Ollama: {response.model_dump_json()}") + if response.message.tool_calls: + return self.tool_call_parser(response.message.tool_calls) - str_response = response["message"]["content"] or "" + str_thinking = ( + f"{response.message.thinking}" + if hasattr(response.message, "thinking") + else "" + ) + str_response = response.message.content if self.config.remove_think_prefix: return remove_thinking_tags(str_response) else: - return str_response + return str_thinking + str_response def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: - raise NotImplementedError + if kwargs.get("tools"): + logger.info("stream api not support tools") + return + + response = self.client.chat( + model=kwargs.get("model_name_or_path", self.config.model_name_or_path), + messages=messages, + options={ + "temperature": kwargs.get("temperature", self.config.temperature), + "num_predict": kwargs.get("max_tokens", self.config.max_tokens), + "top_p": kwargs.get("top_p", self.config.top_p), + "top_k": kwargs.get("top_k", self.config.top_k), + }, + think=self.config.enable_thinking, + stream=True, + ) + # Streaming chunks of text + 
reasoning_started = False + for chunk in response: + if hasattr(chunk.message, "thinking") and chunk.message.thinking: + if not reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = True + yield chunk.message.thinking + + if hasattr(chunk.message, "content") and chunk.message.content: + if reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = False + yield chunk.message.content + + def tool_call_parser(self, tool_calls: list[Message.ToolCall]) -> list[dict]: + """Parse tool calls from OpenAI response.""" + return [ + { + "function_name": tool_call.function.name, + "arguments": tool_call.function.arguments, + } + for tool_call in tool_calls + ] diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 1a1703340..400d1fff9 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -1,11 +1,12 @@ -import hashlib import json from collections.abc import Generator -from typing import ClassVar import openai +from openai._types import NOT_GIVEN +from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall + from memos.configs.llm import AzureLLMConfig, OpenAILLMConfig from memos.llms.base import BaseLLM from memos.llms.utils import remove_thinking_tags @@ -18,77 +19,55 @@ class OpenAILLM(BaseLLM): - """OpenAI LLM class with singleton pattern.""" - - _instances: ClassVar[dict] = {} # Class variable to store instances - - def __new__(cls, config: OpenAILLMConfig) -> "OpenAILLM": - config_hash = cls._get_config_hash(config) - - if config_hash not in cls._instances: - logger.info(f"Creating new OpenAI LLM instance for config hash: {config_hash}") - instance = super().__new__(cls) - cls._instances[config_hash] = instance - else: - logger.info(f"Reusing existing OpenAI LLM instance for config hash: {config_hash}") - - return cls._instances[config_hash] + """OpenAI LLM class via openai.chat.completions.create.""" def __init__(self, config: OpenAILLMConfig): - # Avoid duplicate initialization - if hasattr(self, "_initialized"): - return - self.config = config self.client = openai.Client(api_key=config.api_key, base_url=config.api_base) - self._initialized = True logger.info("OpenAI LLM instance initialized") - @classmethod - def _get_config_hash(cls, config: OpenAILLMConfig) -> str: - """Generate hash value of configuration""" - config_dict = config.model_dump() - config_str = json.dumps(config_dict, sort_keys=True) - return hashlib.md5(config_str.encode()).hexdigest() - - @classmethod - def clear_cache(cls): - """Clear all cached instances""" - cls._instances.clear() - logger.info("OpenAI LLM instance cache cleared") - @timed(log=True, log_prefix="OpenAI LLM") def generate(self, messages: MessageList, **kwargs) -> str: """Generate a response from OpenAI LLM, optionally overriding generation params.""" - temperature = kwargs.get("temperature", self.config.temperature) - max_tokens = kwargs.get("max_tokens", self.config.max_tokens) - top_p = kwargs.get("top_p", self.config.top_p) response = self.client.chat.completions.create( - model=self.config.model_name_or_path, + model=kwargs.get("model_name_or_path", self.config.model_name_or_path), messages=messages, - extra_body=self.config.extra_body, - temperature=temperature, - max_tokens=max_tokens, - top_p=top_p, + temperature=kwargs.get("temperature", self.config.temperature), + max_tokens=kwargs.get("max_tokens", self.config.max_tokens), + top_p=kwargs.get("top_p", self.config.top_p), + extra_body=kwargs.get("extra_body", 
self.config.extra_body), + tools=kwargs.get("tools", NOT_GIVEN), ) logger.info(f"Response from OpenAI: {response.model_dump_json()}") + if response.choices[0].message.tool_calls: + return self.tool_call_parser(response.choices[0].message.tool_calls) + reasoning_content = ( + f"{response.choices[0].message.reasoning_content}" + if hasattr(response.choices[0].message, "reasoning_content") + else "" + ) response_content = response.choices[0].message.content if self.config.remove_think_prefix: return remove_thinking_tags(response_content) else: - return response_content + return reasoning_content + response_content @timed(log=True, log_prefix="OpenAI LLM") def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: """Stream response from OpenAI LLM with optional reasoning support.""" + if kwargs.get("tools"): + logger.info("stream api not support tools") + return + response = self.client.chat.completions.create( model=self.config.model_name_or_path, messages=messages, stream=True, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - top_p=self.config.top_p, - extra_body=self.config.extra_body, + temperature=kwargs.get("temperature", self.config.temperature), + max_tokens=kwargs.get("max_tokens", self.config.max_tokens), + top_p=kwargs.get("top_p", self.config.top_p), + extra_body=kwargs.get("extra_body", self.config.extra_body), + tools=kwargs.get("tools", NOT_GIVEN), ) reasoning_started = False @@ -96,7 +75,7 @@ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, Non for chunk in response: delta = chunk.choices[0].delta - # Support for custom 'reasoning_content' (if present in OpenAI-compatible models like Qwen) + # Support for custom 'reasoning_content' (if present in OpenAI-compatible models like Qwen, DeepSeek) if hasattr(delta, "reasoning_content") and delta.reasoning_content: if not reasoning_started and not self.config.remove_think_prefix: yield "" @@ -112,63 +91,44 @@ def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, Non if reasoning_started and not self.config.remove_think_prefix: yield "" + def tool_call_parser(self, tool_calls: list[ChatCompletionMessageToolCall]) -> list[dict]: + """Parse tool calls from OpenAI response.""" + return [ + { + "tool_call_id": tool_call.id, + "function_name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + for tool_call in tool_calls + ] + class AzureLLM(BaseLLM): """Azure OpenAI LLM class with singleton pattern.""" - _instances: ClassVar[dict] = {} # Class variable to store instances - - def __new__(cls, config: AzureLLMConfig): - # Generate hash value of config as cache key - config_hash = cls._get_config_hash(config) - - if config_hash not in cls._instances: - logger.info(f"Creating new Azure LLM instance for config hash: {config_hash}") - instance = super().__new__(cls) - cls._instances[config_hash] = instance - else: - logger.info(f"Reusing existing Azure LLM instance for config hash: {config_hash}") - - return cls._instances[config_hash] - def __init__(self, config: AzureLLMConfig): - # Avoid duplicate initialization - if hasattr(self, "_initialized"): - return - self.config = config self.client = openai.AzureOpenAI( azure_endpoint=config.base_url, api_version=config.api_version, api_key=config.api_key, ) - self._initialized = True logger.info("Azure LLM instance initialized") - @classmethod - def _get_config_hash(cls, config: AzureLLMConfig) -> str: - """Generate hash value of configuration""" - # 
Convert config to dict and sort to ensure consistency - config_dict = config.model_dump() - config_str = json.dumps(config_dict, sort_keys=True) - return hashlib.md5(config_str.encode()).hexdigest() - - @classmethod - def clear_cache(cls): - """Clear all cached instances""" - cls._instances.clear() - logger.info("Azure LLM instance cache cleared") - - def generate(self, messages: MessageList) -> str: + def generate(self, messages: MessageList, **kwargs) -> str: """Generate a response from Azure OpenAI LLM.""" response = self.client.chat.completions.create( model=self.config.model_name_or_path, messages=messages, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - top_p=self.config.top_p, + temperature=kwargs.get("temperature", self.config.temperature), + max_tokens=kwargs.get("max_tokens", self.config.max_tokens), + top_p=kwargs.get("top_p", self.config.top_p), + tools=kwargs.get("tools", NOT_GIVEN), + extra_body=kwargs.get("extra_body", self.config.extra_body), ) logger.info(f"Response from Azure OpenAI: {response.model_dump_json()}") + if response.choices[0].message.tool_calls: + return self.tool_call_parser(response.choices[0].message.tool_calls) response_content = response.choices[0].message.content if self.config.remove_think_prefix: return remove_thinking_tags(response_content) @@ -176,4 +136,49 @@ def generate(self, messages: MessageList) -> str: return response_content def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: - raise NotImplementedError + """Stream response from Azure OpenAI LLM with optional reasoning support.""" + if kwargs.get("tools"): + logger.info("stream api not support tools") + return + + response = self.client.chat.completions.create( + model=self.config.model_name_or_path, + messages=messages, + stream=True, + temperature=kwargs.get("temperature", self.config.temperature), + max_tokens=kwargs.get("max_tokens", self.config.max_tokens), + top_p=kwargs.get("top_p", self.config.top_p), + extra_body=kwargs.get("extra_body", self.config.extra_body), + ) + + reasoning_started = False + + for chunk in response: + delta = chunk.choices[0].delta + + # Support for custom 'reasoning_content' (if present in OpenAI-compatible models like Qwen, DeepSeek) + if hasattr(delta, "reasoning_content") and delta.reasoning_content: + if not reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = True + yield delta.reasoning_content + elif hasattr(delta, "content") and delta.content: + if reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = False + yield delta.content + + # Ensure we close the block if not already done + if reasoning_started and not self.config.remove_think_prefix: + yield "" + + def tool_call_parser(self, tool_calls: list[ChatCompletionMessageToolCall]) -> list[dict]: + """Parse tool calls from OpenAI response.""" + return [ + { + "tool_call_id": tool_call.id, + "function_name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + for tool_call in tool_calls + ] diff --git a/src/memos/llms/openai_new.py b/src/memos/llms/openai_new.py new file mode 100644 index 000000000..c1e0ee8a6 --- /dev/null +++ b/src/memos/llms/openai_new.py @@ -0,0 +1,196 @@ +import json + +from collections.abc import Generator + +import openai + +from openai._types import NOT_GIVEN +from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall +from openai.types.responses.response_reasoning_item 
import ResponseReasoningItem + +from memos.configs.llm import AzureLLMConfig, OpenAILLMConfig +from memos.llms.base import BaseLLM +from memos.llms.utils import remove_thinking_tags +from memos.log import get_logger +from memos.types import MessageList +from memos.utils import timed + + +logger = get_logger(__name__) + + +class OpenAIResponsesLLM(BaseLLM): + def __init__(self, config: OpenAILLMConfig): + self.config = config + self.client = openai.Client(api_key=config.api_key, base_url=config.api_base) + + @timed(log=True, log_prefix="OpenAI Responses LLM") + def generate(self, messages: MessageList, **kwargs) -> str: + response = self.client.responses.create( + model=kwargs.get("model_name_or_path", self.config.model_name_or_path), + input=messages, + temperature=kwargs.get("temperature", self.config.temperature), + top_p=kwargs.get("top_p", self.config.top_p), + max_output_tokens=kwargs.get("max_tokens", self.config.max_tokens), + reasoning={"effort": "low", "summary": "auto"} + if self.config.enable_thinking + else NOT_GIVEN, + tools=kwargs.get("tools", NOT_GIVEN), + extra_body=kwargs.get("extra_body", self.config.extra_body), + ) + tool_call_outputs = [ + item for item in response.output if isinstance(item, ResponseFunctionToolCall) + ] + if tool_call_outputs: + return self.tool_call_parser(tool_call_outputs) + + output_text = getattr(response, "output_text", "") + output_reasoning = [ + item for item in response.output if isinstance(item, ResponseReasoningItem) + ] + summary = output_reasoning[0].summary + + if self.config.remove_think_prefix: + return remove_thinking_tags(output_text) + if summary: + return f"{summary[0].text}" + output_text + return output_text + + @timed(log=True, log_prefix="OpenAI Responses LLM") + def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: + if kwargs.get("tools"): + logger.info("stream api not support tools") + return + + stream = self.client.responses.create( + model=kwargs.get("model_name_or_path", self.config.model_name_or_path), + input=messages, + temperature=kwargs.get("temperature", self.config.temperature), + top_p=kwargs.get("top_p", self.config.top_p), + max_output_tokens=kwargs.get("max_tokens", self.config.max_tokens), + reasoning={"effort": "low", "summary": "auto"} + if self.config.enable_thinking + else NOT_GIVEN, + extra_body=kwargs.get("extra_body", self.config.extra_body), + stream=True, + ) + + reasoning_started = False + + for event in stream: + event_type = getattr(event, "type", "") + if event_type in ( + "response.reasoning.delta", + "response.reasoning_summary_text.delta", + ) and hasattr(event, "delta"): + if not self.config.remove_think_prefix: + if not reasoning_started: + yield "" + reasoning_started = True + yield event.delta + elif event_type == "response.output_text.delta" and hasattr(event, "delta"): + if reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = False + yield event.delta + + if reasoning_started and not self.config.remove_think_prefix: + yield "" + + def tool_call_parser(self, tool_calls: list[ResponseFunctionToolCall]) -> list[dict]: + """Parse tool calls from OpenAI response.""" + return [ + { + "tool_call_id": tool_call.call_id, + "function_name": tool_call.name, + "arguments": json.loads(tool_call.arguments), + } + for tool_call in tool_calls + ] + + +class AzureResponsesLLM(BaseLLM): + def __init__(self, config: AzureLLMConfig): + self.config = config + self.client = openai.AzureOpenAI( + azure_endpoint=config.base_url, + 
api_version=config.api_version, + api_key=config.api_key, + ) + + def generate(self, messages: MessageList, **kwargs) -> str: + response = self.client.responses.create( + model=self.config.model_name_or_path, + input=messages, + temperature=kwargs.get("temperature", self.config.temperature), + top_p=kwargs.get("top_p", self.config.top_p), + max_output_tokens=kwargs.get("max_tokens", self.config.max_tokens), + tools=kwargs.get("tools", NOT_GIVEN), + extra_body=kwargs.get("extra_body", self.config.extra_body), + reasoning={"effort": "low", "summary": "auto"} + if self.config.enable_thinking + else NOT_GIVEN, + ) + + output_text = getattr(response, "output_text", "") + output_reasoning = [ + item for item in response.output if isinstance(item, ResponseReasoningItem) + ] + summary = output_reasoning[0].summary + + if self.config.remove_think_prefix: + return remove_thinking_tags(output_text) + if summary: + return f"{summary[0].text}" + output_text + return output_text + + def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: + if kwargs.get("tools"): + logger.info("stream api not support tools") + return + + stream = self.client.responses.create( + model=self.config.model_name_or_path, + input=messages, + temperature=kwargs.get("temperature", self.config.temperature), + top_p=kwargs.get("top_p", self.config.top_p), + max_output_tokens=kwargs.get("max_tokens", self.config.max_tokens), + extra_body=kwargs.get("extra_body", self.config.extra_body), + stream=True, + reasoning={"effort": "low", "summary": "auto"} + if self.config.enable_thinking + else NOT_GIVEN, + ) + + reasoning_started = False + + for event in stream: + event_type = getattr(event, "type", "") + if event_type in ( + "response.reasoning.delta", + "response.reasoning_summary_text.delta", + ) and hasattr(event, "delta"): + if not self.config.remove_think_prefix: + if not reasoning_started: + yield "" + reasoning_started = True + yield event.delta + elif event_type == "response.output_text.delta" and hasattr(event, "delta"): + if reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = False + yield event.delta + + if reasoning_started and not self.config.remove_think_prefix: + yield "" + + def tool_call_parser(self, tool_calls: list[ResponseFunctionToolCall]) -> list[dict]: + """Parse tool calls from OpenAI response.""" + return [ + { + "tool_call_id": tool_call.call_id, + "function_name": tool_call.name, + "arguments": json.loads(tool_call.arguments), + } + for tool_call in tool_calls + ] diff --git a/src/memos/llms/qwen.py b/src/memos/llms/qwen.py index a47fcdf36..d54e23c7f 100644 --- a/src/memos/llms/qwen.py +++ b/src/memos/llms/qwen.py @@ -1,10 +1,6 @@ -from collections.abc import Generator - from memos.configs.llm import QwenLLMConfig from memos.llms.openai import OpenAILLM -from memos.llms.utils import remove_thinking_tags from memos.log import get_logger -from memos.types import MessageList logger = get_logger(__name__) @@ -15,49 +11,3 @@ class QwenLLM(OpenAILLM): def __init__(self, config: QwenLLMConfig): super().__init__(config) - - def generate(self, messages: MessageList) -> str: - """Generate a response from Qwen LLM.""" - response = self.client.chat.completions.create( - model=self.config.model_name_or_path, - messages=messages, - extra_body=self.config.extra_body, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - top_p=self.config.top_p, - ) - logger.info(f"Response from Qwen: {response.model_dump_json()}") - 
response_content = response.choices[0].message.content - if self.config.remove_think_prefix: - return remove_thinking_tags(response_content) - else: - return response_content - - def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: - """Stream response from Qwen LLM.""" - response = self.client.chat.completions.create( - model=self.config.model_name_or_path, - messages=messages, - stream=True, - temperature=self.config.temperature, - max_tokens=self.config.max_tokens, - top_p=self.config.top_p, - extra_body=self.config.extra_body, - ) - - reasoning_started = False - for chunk in response: - delta = chunk.choices[0].delta - - # Some models may have separate `reasoning_content` vs `content` - # For Qwen (DashScope), likely only `content` is used - if hasattr(delta, "reasoning_content") and delta.reasoning_content: - if not reasoning_started and not self.config.remove_think_prefix: - yield "" - reasoning_started = True - yield delta.reasoning_content - elif hasattr(delta, "content") and delta.content: - if reasoning_started and not self.config.remove_think_prefix: - yield "" - reasoning_started = False - yield delta.content diff --git a/src/memos/llms/vllm.py b/src/memos/llms/vllm.py index c3750bb4b..fdb07a1b8 100644 --- a/src/memos/llms/vllm.py +++ b/src/memos/llms/vllm.py @@ -1,5 +1,11 @@ +import json + from typing import Any, cast +import openai + +from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall + from memos.configs.llm import VLLMLLMConfig from memos.llms.base import BaseLLM from memos.llms.utils import remove_thinking_tags @@ -27,8 +33,6 @@ def __init__(self, config: VLLMLLMConfig): if not api_key: api_key = "dummy" - import openai - self.client = openai.Client( api_key=api_key, base_url=getattr(self.config, "api_base", "http://localhost:8088/v1") ) @@ -85,36 +89,54 @@ def build_vllm_kv_cache(self, messages: Any) -> str: return prompt - def generate(self, messages: list[MessageDict]) -> str: + def generate(self, messages: list[MessageDict], **kwargs) -> str: """ Generate a response from the model. """ if self.client: - return self._generate_with_api_client(messages) + return self._generate_with_api_client(messages, **kwargs) else: raise RuntimeError("API client is not available") - def _generate_with_api_client(self, messages: list[MessageDict]) -> str: + def _generate_with_api_client(self, messages: list[MessageDict], **kwargs) -> str: """ - Generate response using vLLM API client. + Generate response using vLLM API client. 
detail view https://docs.vllm.ai/en/latest/features/reasoning_outputs/ """ if self.client: completion_kwargs = { - "model": self.config.model_name_or_path, + "model": kwargs.get("model_name_or_path", self.config.model_name_or_path), "messages": messages, - "temperature": float(getattr(self.config, "temperature", 0.8)), - "max_tokens": int(getattr(self.config, "max_tokens", 1024)), - "top_p": float(getattr(self.config, "top_p", 0.9)), - "extra_body": {"chat_template_kwargs": {"enable_thinking": False}}, + "temperature": kwargs.get("temperature", self.config.temperature), + "max_tokens": kwargs.get("max_tokens", self.config.max_tokens), + "top_p": kwargs.get("top_p", self.config.top_p), + "extra_body": { + "chat_template_kwargs": { + "enable_thinking": kwargs.get( + "enable_thinking", self.config.enable_thinking + ) + } + }, } + if kwargs.get("tools"): + completion_kwargs["tools"] = kwargs.get("tools") + completion_kwargs["tool_choice"] = kwargs.get("tool_choice", "auto") response = self.client.chat.completions.create(**completion_kwargs) + + if response.choices[0].message.tool_calls: + return self.tool_call_parser(response.choices[0].message.tool_calls) + + reasoning_content = ( + f"{response.choices[0].message.reasoning}" + if hasattr(response.choices[0].message, "reasoning") + else "" + ) response_text = response.choices[0].message.content or "" logger.info(f"VLLM API response: {response_text}") return ( remove_thinking_tags(response_text) if getattr(self.config, "remove_think_prefix", False) - else response_text + else reasoning_content + response_text ) else: raise RuntimeError("API client is not available") @@ -130,26 +152,59 @@ def _messages_to_prompt(self, messages: list[MessageDict]) -> str: prompt_parts.append(f"{role.capitalize()}: {content}") return "\n".join(prompt_parts) - def generate_stream(self, messages: list[MessageDict]): + def generate_stream(self, messages: list[MessageDict], **kwargs): """ Generate a response from the model using streaming. Yields content chunks as they are received. 
""" + if kwargs.get("tools"): + logger.info("stream api not support tools") + return + if self.client: completion_kwargs = { "model": self.config.model_name_or_path, "messages": messages, - "temperature": float(getattr(self.config, "temperature", 0.8)), - "max_tokens": int(getattr(self.config, "max_tokens", 1024)), - "top_p": float(getattr(self.config, "top_p", 0.9)), - "stream": True, # Enable streaming - "extra_body": {"chat_template_kwargs": {"enable_thinking": False}}, + "temperature": kwargs.get("temperature", self.config.temperature), + "max_tokens": kwargs.get("max_tokens", self.config.max_tokens), + "top_p": kwargs.get("top_p", self.config.top_p), + "stream": True, + "extra_body": { + "chat_template_kwargs": { + "enable_thinking": kwargs.get( + "enable_thinking", self.config.enable_thinking + ) + } + }, } stream = self.client.chat.completions.create(**completion_kwargs) + + reasoning_started = False for chunk in stream: - content = chunk.choices[0].delta.content - if content: - yield content + delta = chunk.choices[0].delta + if hasattr(delta, "reasoning") and delta.reasoning: + if not reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = True + yield delta.reasoning + + if hasattr(delta, "content") and delta.content: + if reasoning_started and not self.config.remove_think_prefix: + yield "" + reasoning_started = False + yield delta.content + else: raise RuntimeError("API client is not available") + + def tool_call_parser(self, tool_calls: list[ChatCompletionMessageToolCall]) -> list[dict]: + """Parse tool calls from OpenAI response.""" + return [ + { + "tool_call_id": tool_call.id, + "function_name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + for tool_call in tool_calls + ] From ddf3dd14864c440c6a7f033a9e9549d07c030ed6 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 18 Nov 2025 18:52:01 +0800 Subject: [PATCH 03/25] llm construction --- src/memos/configs/llm.py | 1 + src/memos/llms/factory.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/src/memos/configs/llm.py b/src/memos/configs/llm.py index 916b90ad1..e446e72bb 100644 --- a/src/memos/configs/llm.py +++ b/src/memos/configs/llm.py @@ -130,6 +130,7 @@ class LLMConfigFactory(BaseConfig): "huggingface_singleton": HFLLMConfig, # Add singleton support "qwen": QwenLLMConfig, "deepseek": DeepSeekLLMConfig, + "openai_new": OpenAIResponsesLLMConfig, } @field_validator("backend") diff --git a/src/memos/llms/factory.py b/src/memos/llms/factory.py index 8589d7750..8f4da662f 100644 --- a/src/memos/llms/factory.py +++ b/src/memos/llms/factory.py @@ -7,6 +7,7 @@ from memos.llms.hf_singleton import HFSingletonLLM from memos.llms.ollama import OllamaLLM from memos.llms.openai import AzureLLM, OpenAILLM +from memos.llms.openai_new import OpenAIResponsesLLM from memos.llms.qwen import QwenLLM from memos.llms.vllm import VLLMLLM from memos.memos_tools.singleton import singleton_factory @@ -24,6 +25,7 @@ class LLMFactory(BaseLLM): "vllm": VLLMLLM, "qwen": QwenLLM, "deepseek": DeepSeekLLM, + "openai_new": OpenAIResponsesLLM, } @classmethod From 3006812175c224e5173464a116b05710bcea440f Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Wed, 19 Nov 2025 15:16:11 +0800 Subject: [PATCH 04/25] add delete and get, modify chat --- src/memos/api/handlers/chat_handler.py | 83 ++++++++++++++++++++---- src/memos/api/handlers/memory_handler.py | 22 ++++++- src/memos/api/product_models.py | 44 ++++++++++--- src/memos/api/routers/product_router.py | 4 +- 
src/memos/api/routers/server_router.py | 4 +- 5 files changed, 129 insertions(+), 28 deletions(-) diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 9b0048ed4..df42badd7 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -32,7 +32,6 @@ from memos.mem_scheduler.schemas.general_schemas import ( ANSWER_LABEL, QUERY_LABEL, - SearchMode, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.templates.mos_prompts import ( @@ -111,15 +110,17 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An # Step 1: Search for relevant memories search_req = APISearchRequest( + query=chat_req.query, user_id=chat_req.user_id, mem_cube_id=chat_req.mem_cube_id, - query=chat_req.query, - top_k=chat_req.top_k or 10, - session_id=chat_req.session_id, - mode=SearchMode.FINE, + mode=chat_req.mode, internet_search=chat_req.internet_search, - moscube=chat_req.moscube, + top_k=chat_req.top_k, chat_history=chat_req.history, + session_id=chat_req.session_id, + include_preference=chat_req.include_preference, + pref_top_k=chat_req.pref_top_k, + filter=chat_req.filter, ) search_response = self.search_handler.handle_search_memories(search_req) @@ -137,7 +138,9 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An ) # Step 2: Build system prompt - system_prompt = self._build_system_prompt(filtered_memories, chat_req.base_prompt) + system_prompt = self._build_system_prompt( + filtered_memories, search_response.data["pref_string"], chat_req.system_prompt + ) # Prepare message history history_info = chat_req.history[-20:] if chat_req.history else [] @@ -180,7 +183,7 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An self.logger.error(f"Failed to complete chat: {traceback.format_exc()}") raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - def handle_chat_stream(self, chat_req: ChatRequest) -> StreamingResponse: + def handle_chat_stream_playground(self, chat_req: ChatRequest) -> StreamingResponse: """ Chat with MemOS via Server-Sent Events (SSE) stream using search/add handlers. 
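[Editor's note] The hunk above indexes into `search_response.data` by key; a minimal sketch of the shape it assumes, with key names inferred from the accesses in this patch ("text_mem", "pref_mem", "pref_string") and dummy placeholder values:

    # Illustrative only: structure inferred from the fields read by
    # handle_chat_complete / handle_chat_stream_playground in this patch.
    example_search_data = {
        "text_mem": [
            {"memories": [{"memory": "User prefers tea over coffee", "metadata": {}}]}
        ],
        "pref_mem": [
            {
                "memories": [
                    {
                        "preference": "Short, direct answers",
                        "metadata": {
                            "preference_type": "explicit",  # or "implicit"
                            "reasoning": "User asked for brevity twice",
                        },
                    }
                ]
            }
        ],
        "pref_string": "Prefers concise answers.",
    }
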
@@ -208,15 +211,17 @@ def generate_chat_response() -> Generator[str, None, None]: yield f"data: {json.dumps({'type': 'status', 'data': '0'})}\n\n" search_req = APISearchRequest( + query=chat_req.query, user_id=chat_req.user_id, mem_cube_id=chat_req.mem_cube_id, - query=chat_req.query, - top_k=20, - session_id=chat_req.session_id, - mode=SearchMode.FINE, + mode=chat_req.mode, internet_search=chat_req.internet_search, - moscube=chat_req.moscube, + top_k=chat_req.top_k, chat_history=chat_req.history, + session_id=chat_req.session_id, + include_preference=chat_req.include_preference, + pref_top_k=chat_req.pref_top_k, + filter=chat_req.filter, ) search_response = self.search_handler.handle_search_memories(search_req) @@ -242,8 +247,17 @@ def generate_chat_response() -> Generator[str, None, None]: reference = prepare_reference_data(filtered_memories) yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" + # Prepare preference markdown string + if chat_req.include_preference: + pref_md_string = self._build_pref_md_string_for_playground( + search_response.data["pref_mem"][0].get("memories", []) + ) + yield f"data: {json.dumps({'type': 'pref_md_string', 'data': pref_md_string})}\n\n" + # Step 2: Build system prompt with memories - system_prompt = self._build_enhance_system_prompt(filtered_memories) + system_prompt = self._build_enhance_system_prompt( + filtered_memories, search_response.data["pref_string"] + ) # Prepare messages history_info = chat_req.history[-20:] if chat_req.history else [] @@ -344,9 +358,45 @@ def generate_chat_response() -> Generator[str, None, None]: self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}") raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err + def _build_pref_md_string_for_playground(self, pref_mem_list: list[any]) -> str: + """Build preference markdown string for playground.""" + explicit = [] + implicit = [] + for pref_mem in pref_mem_list: + if pref_mem["metadata"]["preference_type"] == "explicit": + explicit.append( + { + "content": pref_mem["preference"], + "reasoning": pref_mem["metadata"]["reasoning"], + } + ) + elif pref_mem["metadata"]["preference_type"] == "implicit": + implicit.append( + { + "content": pref_mem["preference"], + "reasoning": pref_mem["metadata"]["reasoning"], + } + ) + + explicit_md = "\n\n".join( + [ + f"显性偏好 {i + 1}:\n- 抽取内容: {pref['content']}\n- 抽取理由: {pref['reasoning']}" + for i, pref in enumerate(explicit) + ] + ) + implicit_md = "\n\n".join( + [ + f"隐性偏好 {i + 1}:\n- 抽取内容: {pref['content']}\n- 抽取理由: {pref['reasoning']}" + for i, pref in enumerate(implicit) + ] + ) + + return f"{explicit_md}\n\n{implicit_md}" + def _build_system_prompt( self, memories: list | None = None, + pref_string: str | None = None, base_prompt: str | None = None, **kwargs, ) -> str: @@ -366,6 +416,8 @@ def _build_system_prompt( text_memory = memory.get("memory", "") memory_list.append(f"{i}. 
{text_memory}") memory_context = "\n".join(memory_list) + if pref_string: + memory_context += f"\n\n{pref_string}" if "{memories}" in base_prompt: return base_prompt.format(memories=memory_context) @@ -378,6 +430,7 @@ def _build_system_prompt( def _build_enhance_system_prompt( self, memories_list: list, + pref_string: str = "", tone: str = "friendly", verbosity: str = "mid", ) -> str: @@ -386,6 +439,7 @@ def _build_enhance_system_prompt( Args: memories_list: List of memory items + pref_string: Preference string tone: Tone of the prompt verbosity: Verbosity level @@ -407,6 +461,7 @@ def _build_enhance_system_prompt( + mem_block_p + "\n## OuterMemory (ordered)\n" + mem_block_o + + f"\n\n{pref_string}" ) def _format_mem_block( diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index 85f339f3f..994fe71a7 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -6,7 +6,13 @@ from typing import Any, Literal -from memos.api.product_models import MemoryResponse +from memos.api.product_models import ( + DeleteMemoryRequest, + DeleteMemoryResponse, + GetMemoryRequest, + GetMemoryResponse, + MemoryResponse, +) from memos.log import get_logger from memos.mem_os.utils.format_utils import ( convert_graph_to_tree_forworkmem, @@ -149,3 +155,17 @@ def handle_get_subgraph( except Exception as e: logger.error(f"Failed to get subgraph: {e}", exc_info=True) raise + + +def handle_get_memories(self, get_mem_req: GetMemoryRequest): + return GetMemoryResponse( + message="Memories retrieved successfully", + data=None, + ) + + +def handle_delete_memories(self, delete_mem_req: DeleteMemoryRequest): + return DeleteMemoryResponse( + message="Memories deleted successfully", + data=None, + ) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 4b4d5acd9..4eb28b38e 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -36,7 +36,7 @@ class UserRegisterRequest(BaseRequest): interests: str | None = Field(None, description="User interests") -class GetMemoryRequest(BaseRequest): +class GetAllMemoryRequest(BaseRequest): """Request model for getting memories.""" user_id: str = Field(..., description="User ID") @@ -74,7 +74,6 @@ class ChatRequest(BaseRequest): mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") history: list[MessageDict] | None = Field(None, description="Chat history") internet_search: bool = Field(True, description="Whether to use internet search") - moscube: bool = Field(False, description="Whether to use MemOSCube") system_prompt: str | None = Field(None, description="Base system prompt to use for chat") top_k: int = Field(10, description="Number of results to return") threshold: float = Field(0.5, description="Threshold for filtering references") @@ -82,7 +81,7 @@ class ChatRequest(BaseRequest): include_preference: bool = Field(True, description="Whether to handle preference memory") pref_top_k: int = Field(6, description="Number of preference results to return") filter: dict[str, Any] | None = Field(None, description="Filter for the memory") - model_name: str | None = Field(None, description="Model name to use for chat") + model_name_or_path: str | None = Field(None, description="Model name to use for chat") max_tokens: int | None = Field(None, description="Max tokens to generate") temperature: float | None = Field(None, description="Temperature for sampling") top_p: float | None = Field(None, description="Top-p (nucleus) 
sampling parameter") @@ -97,11 +96,18 @@ class ChatCompleteRequest(BaseRequest): mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") history: list[MessageDict] | None = Field(None, description="Chat history") internet_search: bool = Field(False, description="Whether to use internet search") - moscube: bool = Field(False, description="Whether to use MemOSCube") - base_prompt: str | None = Field(None, description="Base prompt to use for chat") + system_prompt: str | None = Field(None, description="Base prompt to use for chat") top_k: int = Field(10, description="Number of results to return") threshold: float = Field(0.5, description="Threshold for filtering references") session_id: str | None = Field(None, description="Session ID for soft-filtering memories") + include_preference: bool = Field(True, description="Whether to handle preference memory") + pref_top_k: int = Field(6, description="Number of preference results to return") + filter: dict[str, Any] | None = Field(None, description="Filter for the memory") + model_name_or_path: str | None = Field(None, description="Model name to use for chat") + max_tokens: int | None = Field(None, description="Max tokens to generate") + temperature: float | None = Field(None, description="Temperature for sampling") + top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter") + add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat") class UserCreate(BaseRequest): @@ -149,6 +155,14 @@ class ChatResponse(BaseResponse[str]): """Response model for chat operations.""" +class GetMemoryResponse(BaseResponse[dict]): + """Response model for getting memories.""" + + +class DeleteMemoryResponse(BaseResponse[dict]): + """Response model for deleting memories.""" + + class UserResponse(BaseResponse[dict]): """Response model for user operations.""" @@ -186,11 +200,8 @@ class APISearchRequest(BaseRequest): query: str = Field(..., description="Search query") user_id: str = Field(None, description="User ID") mem_cube_id: str | None = Field(None, description="Cube ID to search in") - mode: SearchMode = Field( - SearchMode.NOT_INITIALIZED, description="search mode: fast, fine, or mixture" - ) + mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") internet_search: bool = Field(False, description="Whether to use internet search") - moscube: bool = Field(False, description="Whether to use MemOSCube") top_k: int = Field(10, description="Number of results to return") chat_history: list[MessageDict] | None = Field(None, description="Chat history") session_id: str | None = Field(None, description="Session ID for soft-filtering memories") @@ -236,6 +247,7 @@ class APIChatCompleteRequest(BaseRequest): internet_search: bool = Field(False, description="Whether to use internet search") moscube: bool = Field(True, description="Whether to use MemOSCube") system_prompt: str | None = Field(None, description="Base system prompt to use for chat") + mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") top_k: int = Field(10, description="Number of results to return") threshold: float = Field(0.5, description="Threshold for filtering references") session_id: str | None = Field( @@ -259,6 +271,20 @@ class AddStatusRequest(BaseRequest): session_id: str | None = Field(None, description="Session ID") +class GetMemoryRequest(BaseRequest): + """Request model for getting memories.""" + + mem_cube_id: str = Field(..., description="Cube 
ID") + user_id: str | None = Field(None, description="User ID") + include_preference: bool = Field(True, description="Whether to handle preference memory") + + +class DeleteMemoryRequest(BaseRequest): + """Request model for deleting memories.""" + + memory_id: str = Field(..., description="Memory ID") + + class SuggestionRequest(BaseRequest): """Request model for getting suggestion queries.""" diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py index 75b614cf4..5d83cf359 100644 --- a/src/memos/api/routers/product_router.py +++ b/src/memos/api/routers/product_router.py @@ -10,7 +10,7 @@ BaseResponse, ChatCompleteRequest, ChatRequest, - GetMemoryRequest, + GetAllMemoryRequest, MemoryCreateRequest, MemoryResponse, SearchRequest, @@ -159,7 +159,7 @@ def get_suggestion_queries_post(suggestion_req: SuggestionRequest): @router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse) -def get_all_memories(memory_req: GetMemoryRequest): +def get_all_memories(memory_req: GetAllMemoryRequest): """Get all memories for a specific user.""" try: mos_product = get_mos_product_instance() diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index fdb97dc7d..bfe8b1a93 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -29,7 +29,7 @@ APIChatCompleteRequest, APISearchRequest, ChatRequest, - GetMemoryRequest, + GetAllMemoryRequest, MemoryResponse, SearchResponse, SuggestionRequest, @@ -198,7 +198,7 @@ def get_suggestion_queries(suggestion_req: SuggestionRequest): @router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse) -def get_all_memories(memory_req: GetMemoryRequest): +def get_all_memories(memory_req: GetAllMemoryRequest): """ Get all memories or subgraph for a specific user. 
From 872013bf3b2ade6fc2ba2a78e311d4169937c481 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 20 Nov 2025 15:58:04 +0800 Subject: [PATCH 05/25] modify code --- src/memos/api/handlers/chat_handler.py | 297 +++++++++++++++--- src/memos/api/handlers/memory_handler.py | 29 +- src/memos/api/handlers/search_handler.py | 2 - src/memos/api/product_models.py | 7 +- src/memos/api/routers/product_router.py | 4 +- src/memos/api/routers/server_router.py | 40 ++- .../mem_scheduler/optimized_scheduler.py | 4 - src/memos/memories/textual/preference.py | 32 +- src/memos/memories/textual/tree.py | 7 - .../tree_text_memory/retrieve/searcher.py | 13 - 10 files changed, 351 insertions(+), 84 deletions(-) diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index df42badd7..a3a187982 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -104,10 +104,6 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An HTTPException: If chat fails """ try: - import time - - time_start = time.time() - # Step 1: Search for relevant memories search_req = APISearchRequest( query=chat_req.query, @@ -155,26 +151,27 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An # Step 3: Generate complete response from LLM response = self.llm.generate(current_messages) - time_end = time.time() - - # Step 4: Start post-chat processing asynchronously - self._start_post_chat_processing( + # Step 4: start add after chat asynchronously + self._start_add_to_memory( user_id=chat_req.user_id, cube_id=chat_req.mem_cube_id, session_id=chat_req.session_id or "default_session", query=chat_req.query, full_response=response, - system_prompt=system_prompt, - time_start=time_start, - time_end=time_end, - speed_improvement=0.0, - current_messages=current_messages, + async_mode="async", + ) + + import re + + match = re.search(r"([\s\S]*?)", response) + reasoning_text = match.group(1) if match else None + final_text = ( + re.sub(r"[\s\S]*?", "", response, count=1) if match else response ) - # Return the complete response return { "message": "Chat completed successfully", - "data": {"response": response, "references": filtered_memories}, + "data": {"response": final_text, "reasoning": reasoning_text}, } except ValueError as err: @@ -183,6 +180,140 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An self.logger.error(f"Failed to complete chat: {traceback.format_exc()}") raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err + def handle_chat_stream(self, chat_req: ChatRequest) -> StreamingResponse: + """ + Chat with MemOS via Server-Sent Events (SSE) stream using search/add handlers. + + This implementation directly uses search_handler and add_handler. 
+ + Args: + chat_req: Chat stream request + + Returns: + StreamingResponse with SSE formatted chat stream + + Raises: + HTTPException: If stream initialization fails + """ + try: + + def generate_chat_response() -> Generator[str, None, None]: + """Generate chat response as SSE stream.""" + try: + search_req = APISearchRequest( + query=chat_req.query, + user_id=chat_req.user_id, + mem_cube_id=chat_req.mem_cube_id, + mode=chat_req.mode, + internet_search=chat_req.internet_search, + top_k=chat_req.top_k, + chat_history=chat_req.history, + session_id=chat_req.session_id, + include_preference=chat_req.include_preference, + pref_top_k=chat_req.pref_top_k, + filter=chat_req.filter, + ) + + search_response = self.search_handler.handle_search_memories(search_req) + + self._send_message_to_scheduler( + user_id=chat_req.user_id, + mem_cube_id=chat_req.mem_cube_id, + query=chat_req.query, + label=QUERY_LABEL, + ) + # Extract memories from search results + memories_list = [] + if search_response.data and search_response.data.get("text_mem"): + text_mem_results = search_response.data["text_mem"] + if text_mem_results and text_mem_results[0].get("memories"): + memories_list = text_mem_results[0]["memories"] + + # Filter memories by threshold + filtered_memories = self._filter_memories_by_threshold(memories_list) + + # Step 2: Build system prompt with memories + system_prompt = self._build_system_prompt( + filtered_memories, + search_response.data["pref_string"], + chat_req.system_prompt, + ) + + # Prepare messages + history_info = chat_req.history[-20:] if chat_req.history else [] + current_messages = [ + {"role": "system", "content": system_prompt}, + *history_info, + {"role": "user", "content": chat_req.query}, + ] + + self.logger.info( + f"user_id: {chat_req.user_id}, cube_id: {chat_req.mem_cube_id}, " + f"current_system_prompt: {system_prompt}" + ) + + # Step 3: Generate streaming response from LLM + response_stream = self.llm.generate_stream(current_messages) + + # Stream the response + buffer = "" + full_response = "" + in_think = False + + for chunk in response_stream: + if chunk == "": + in_think = True + continue + if chunk == "": + in_think = False + continue + + if in_think: + chunk_data = f"data: {json.dumps({'type': 'reasoning', 'data': chunk}, ensure_ascii=False)}\n\n" + yield chunk_data + continue + + buffer += chunk + full_response += chunk + + chunk_data = f"data: {json.dumps({'type': 'text', 'data': chunk}, ensure_ascii=False)}\n\n" + yield chunk_data + + current_messages.append({"role": "assistant", "content": full_response}) + + self._start_add_to_memory( + user_id=chat_req.user_id, + cube_id=chat_req.mem_cube_id, + session_id=chat_req.session_id or "default_session", + query=chat_req.query, + full_response=full_response, + async_mode="async", + ) + + except Exception as e: + self.logger.error(f"Error in chat stream: {e}", exc_info=True) + error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n" + yield error_data + + return StreamingResponse( + generate_chat_response(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "text/event-stream", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "*", + "Access-Control-Allow-Methods": "*", + }, + ) + + except ValueError as err: + raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err + except Exception as err: + self.logger.error(f"Failed to start chat stream: 
{traceback.format_exc()}") + raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err + def handle_chat_stream_playground(self, chat_req: ChatRequest) -> StreamingResponse: """ Chat with MemOS via Server-Sent Events (SSE) stream using search/add handlers. @@ -280,9 +411,19 @@ def generate_chat_response() -> Generator[str, None, None]: # Stream the response buffer = "" full_response = "" + in_think = False for chunk in response_stream: - if chunk in ["", ""]: + if chunk == "": + in_think = True + continue + if chunk == "": + in_think = False + continue + + if in_think: + chunk_data = f"data: {json.dumps({'type': 'reasoning', 'data': chunk}, ensure_ascii=False)}\n\n" + yield chunk_data continue buffer += chunk @@ -320,7 +461,6 @@ def generate_chat_response() -> Generator[str, None, None]: yield f"data: {json.dumps({'type': 'end'})}\n\n" - # Step 4: Add conversation to memory asynchronously self._start_post_chat_processing( user_id=chat_req.user_id, cube_id=chat_req.mem_cube_id, @@ -334,6 +474,15 @@ def generate_chat_response() -> Generator[str, None, None]: current_messages=current_messages, ) + self._start_add_to_memory( + user_id=chat_req.user_id, + cube_id=chat_req.mem_cube_id, + session_id=chat_req.session_id or "default_session", + query=chat_req.query, + full_response=full_response, + async_mode="sync", + ) + except Exception as e: self.logger.error(f"Error in chat stream: {e}", exc_info=True) error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n" @@ -663,6 +812,36 @@ def _send_message_to_scheduler( except Exception as e: self.logger.error(f"Failed to send message to scheduler: {e}", exc_info=True) + async def _add_conversation_to_memory( + self, + user_id: str, + cube_id: str, + session_id: str, + query: str, + clean_response: str, + async_mode: Literal["async", "sync"] = "sync", + ) -> None: + add_req = APIADDRequest( + user_id=user_id, + mem_cube_id=cube_id, + session_id=session_id, + messages=[ + { + "role": "user", + "content": query, + "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), + }, + { + "role": "assistant", + "content": clean_response, + "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), + }, + ], + async_mode=async_mode, + ) + + self.add_handler.handle_add_memories(add_req) + async def _post_chat_processing( self, user_id: str, @@ -756,28 +935,6 @@ async def _post_chat_processing( user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_LABEL ) - # Add conversation to memory using add handler - add_req = APIADDRequest( - user_id=user_id, - mem_cube_id=cube_id, - session_id=session_id, - messages=[ - { - "role": "user", - "content": query, - "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - }, - { - "role": "assistant", - "content": clean_response, # Store clean text without reference markers - "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), - }, - ], - async_mode="sync", # set suync for playground - ) - - self.add_handler.handle_add_memories(add_req) - self.logger.info(f"Post-chat processing completed for user {user_id}") except Exception as e: @@ -877,3 +1034,65 @@ def run_async_in_thread(): daemon=True, ) thread.start() + + def _start_add_to_memory( + self, + user_id: str, + cube_id: str, + session_id: str, + query: str, + full_response: str, + async_mode: Literal["async", "sync"] = "sync", + ) -> None: + def run_async_in_thread(): + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + clean_response, 
_ = self._extract_references_from_response(full_response) + loop.run_until_complete( + self._add_conversation_to_memory( + user_id=user_id, + cube_id=cube_id, + session_id=session_id, + query=query, + clean_response=clean_response, + async_mode=async_mode, + ) + ) + finally: + loop.close() + except Exception as e: + self.logger.error( + f"Error in thread-based add to memory for user {user_id}: {e}", + exc_info=True, + ) + + try: + asyncio.get_running_loop() + clean_response, _ = self._extract_references_from_response(full_response) + task = asyncio.create_task( + self._add_conversation_to_memory( + user_id=user_id, + cube_id=cube_id, + session_id=session_id, + query=query, + clean_response=clean_response, + async_mode=async_mode, + ) + ) + task.add_done_callback( + lambda t: self.logger.error( + f"Error in background add to memory for user {user_id}: {t.exception()}", + exc_info=True, + ) + if t.exception() + else None + ) + except RuntimeError: + thread = ContextThread( + target=run_async_in_thread, + name=f"AddToMemory-{user_id}", + daemon=True, + ) + thread.start() diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index 994fe71a7..bb5672740 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -6,6 +6,7 @@ from typing import Any, Literal +from memos.api.handlers.formatters_handler import format_memory_item from memos.api.product_models import ( DeleteMemoryRequest, DeleteMemoryResponse, @@ -157,15 +158,35 @@ def handle_get_subgraph( raise -def handle_get_memories(self, get_mem_req: GetMemoryRequest): +def handle_get_memories(get_mem_req: GetMemoryRequest, naive_mem_cube: Any) -> GetMemoryResponse: + # TODO: Implement get memory with filter + memories = naive_mem_cube.text_mem.get_all(user_name=get_mem_req.mem_cube_id) + filter_params: dict[str, Any] = {} + if get_mem_req.user_id is not None: + filter_params["user_id"] = get_mem_req.user_id + if get_mem_req.mem_cube_id is not None: + filter_params["mem_cube_id"] = get_mem_req.mem_cube_id + preferences = naive_mem_cube.pref_mem.get_memory_by_filter(filter_params) return GetMemoryResponse( message="Memories retrieved successfully", - data=None, + data={ + "text_mem": [format_memory_item(mem) for mem in memories], + "pref_mem": [format_memory_item(mem) for mem in preferences], + }, ) -def handle_delete_memories(self, delete_mem_req: DeleteMemoryRequest): +def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: Any): + try: + naive_mem_cube.text_mem.delete(delete_mem_req.memory_ids) + naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids) + except Exception as e: + logger.error(f"Failed to delete memories: {e}", exc_info=True) + return DeleteMemoryResponse( + message="Failed to delete memories", + data="failure", + ) return DeleteMemoryResponse( message="Memories deleted successfully", - data=None, + data="success", ) diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 9fc8a5b28..76f087edf 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -205,7 +205,6 @@ def _fast_search( top_k=search_req.top_k, mode=SearchMode.FAST, manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, search_filter=search_filter, info={ "user_id": search_req.user_id, @@ -247,7 +246,6 @@ def _fine_search( top_k=search_req.top_k, mode=SearchMode.FAST, manual_close_internet=not search_req.internet_search, - 
moscube=search_req.moscube, search_filter=search_filter, info=info, ) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 4eb28b38e..26b2f9218 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -36,7 +36,7 @@ class UserRegisterRequest(BaseRequest): interests: str | None = Field(None, description="User interests") -class GetAllMemoryRequest(BaseRequest): +class GetMemoryPlaygroundRequest(BaseRequest): """Request model for getting memories.""" user_id: str = Field(..., description="User ID") @@ -245,7 +245,6 @@ class APIChatCompleteRequest(BaseRequest): mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") history: list[MessageDict] | None = Field(None, description="Chat history") internet_search: bool = Field(False, description="Whether to use internet search") - moscube: bool = Field(True, description="Whether to use MemOSCube") system_prompt: str | None = Field(None, description="Base system prompt to use for chat") mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") top_k: int = Field(10, description="Number of results to return") @@ -256,7 +255,7 @@ class APIChatCompleteRequest(BaseRequest): include_preference: bool = Field(True, description="Whether to handle preference memory") pref_top_k: int = Field(6, description="Number of preference results to return") filter: dict[str, Any] | None = Field(None, description="Filter for the memory") - model_name: str | None = Field(None, description="Model name to use for chat") + model_name_or_path: str | None = Field(None, description="Model name to use for chat") max_tokens: int | None = Field(None, description="Max tokens to generate") temperature: float | None = Field(None, description="Temperature for sampling") top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter") @@ -282,7 +281,7 @@ class GetMemoryRequest(BaseRequest): class DeleteMemoryRequest(BaseRequest): """Request model for deleting memories.""" - memory_id: str = Field(..., description="Memory ID") + memory_ids: list[str] = Field(..., description="Memory IDs") class SuggestionRequest(BaseRequest): diff --git a/src/memos/api/routers/product_router.py b/src/memos/api/routers/product_router.py index 5d83cf359..2f6c5c317 100644 --- a/src/memos/api/routers/product_router.py +++ b/src/memos/api/routers/product_router.py @@ -10,7 +10,7 @@ BaseResponse, ChatCompleteRequest, ChatRequest, - GetAllMemoryRequest, + GetMemoryPlaygroundRequest, MemoryCreateRequest, MemoryResponse, SearchRequest, @@ -159,7 +159,7 @@ def get_suggestion_queries_post(suggestion_req: SuggestionRequest): @router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse) -def get_all_memories(memory_req: GetAllMemoryRequest): +def get_all_memories(memory_req: GetMemoryPlaygroundRequest): """Get all memories for a specific user.""" try: mos_product = get_mos_product_instance() diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index bfe8b1a93..592fb814a 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -29,7 +29,11 @@ APIChatCompleteRequest, APISearchRequest, ChatRequest, - GetAllMemoryRequest, + DeleteMemoryRequest, + DeleteMemoryResponse, + GetMemoryPlaygroundRequest, + GetMemoryRequest, + GetMemoryResponse, MemoryResponse, SearchResponse, SuggestionRequest, @@ -160,8 +164,8 @@ def chat_complete(chat_req: APIChatCompleteRequest): return 
chat_handler.handle_chat_complete(chat_req) -@router.post("/chat", summary="Chat with MemOS") -def chat(chat_req: ChatRequest): +@router.post("/chat/stream", summary="Chat with MemOS") +def chat_stream(chat_req: ChatRequest): """ Chat with MemOS for a specific user. Returns SSE stream. @@ -171,6 +175,17 @@ def chat(chat_req: ChatRequest): return chat_handler.handle_chat_stream(chat_req) +@router.post("/chat/stream/playground", summary="Chat with MemOS playground") +def chat_stream_playground(chat_req: ChatRequest): + """ + Chat with MemOS for a specific user. Returns SSE stream. + + This endpoint uses the class-based ChatHandler which internally + composes SearchHandler and AddHandler for a clean architecture. + """ + return chat_handler.handle_chat_stream_playground(chat_req) + + # ============================================================================= # Suggestion API Endpoints # ============================================================================= @@ -193,12 +208,12 @@ def get_suggestion_queries(suggestion_req: SuggestionRequest): # ============================================================================= -# Memory Retrieval API Endpoints +# Memory Retrieval Delete API Endpoints # ============================================================================= @router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse) -def get_all_memories(memory_req: GetAllMemoryRequest): +def get_all_memories(memory_req: GetMemoryPlaygroundRequest): """ Get all memories or subgraph for a specific user. @@ -224,3 +239,18 @@ def get_all_memories(memory_req: GetAllMemoryRequest): memory_type=memory_req.memory_type or "text_mem", naive_mem_cube=naive_mem_cube, ) + + +@router.post("/get_memory", summary="Get memories for user", response_model=GetMemoryResponse) +def get_memories(memory_req: GetMemoryRequest): + return handlers.memory_handler.handle_get_memories( + get_mem_req=memory_req, + naive_mem_cube=naive_mem_cube, + ) + + +@router.post( + "/delete_memory", summary="Delete memories for user", response_model=DeleteMemoryResponse +) +def delete_memories(memory_req: DeleteMemoryRequest): + return handlers.memory_handler.handle_delete_memories(memory_req) diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index b62b1e51d..e2e30f5ad 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -61,7 +61,6 @@ def init_mem_cube(self, mem_cube): self.text_mem: TreeTextMemory = self.current_mem_cube.text_mem self.searcher: Searcher = self.text_mem.get_searcher( manual_close_internet=False, - moscube=False, ) self.reranker: HTTPBGEReranker = self.text_mem.reranker @@ -80,7 +79,6 @@ def submit_memory_history_async_task( "session_id": session_id, "top_k": search_req.top_k, "internet_search": search_req.internet_search, - "moscube": search_req.moscube, "chat_history": search_req.chat_history, }, "user_context": {"mem_cube_id": user_context.mem_cube_id}, @@ -123,7 +121,6 @@ def search_memories( top_k=search_req.top_k, mode=mode, manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, search_filter=search_filter, info={ "user_id": search_req.user_id, @@ -162,7 +159,6 @@ def mix_search_memories( top_k=search_req.top_k, mode=SearchMode.FAST, manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, search_filter=search_filter, info=info, ) diff --git a/src/memos/memories/textual/preference.py 
b/src/memos/memories/textual/preference.py index 5f85aa907..6e196e23a 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -190,7 +190,7 @@ def get_with_collection_name( return None return TextualMemoryItem( id=res.id, - memory=res.payload.get("dialog_str", ""), + memory=res.memory, metadata=PreferenceTextualMemoryMetadata(**res.payload), ) except Exception as e: @@ -225,7 +225,7 @@ def get_by_ids_with_collection_name( return [ TextualMemoryItem( id=memo.id, - memory=memo.payload.get("dialog_str", ""), + memory=memo.memory, metadata=PreferenceTextualMemoryMetadata(**memo.payload), ) for memo in res @@ -248,19 +248,43 @@ def get_all(self) -> list[TextualMemoryItem]: all_memories[collection_name] = [ TextualMemoryItem( id=memo.id, - memory=memo.payload.get("dialog_str", ""), + memory=memo.memory, metadata=PreferenceTextualMemoryMetadata(**memo.payload), ) for memo in items ] return all_memories + def get_memory_by_filter(self, filter: dict[str, Any] | None = None) -> list[TextualMemoryItem]: + """Get memories by filter. + Args: + filter (dict[str, Any]): Filter criteria. + Returns: + list[TextualMemoryItem]: List of memories that match the filter. + """ + collection_list = self.vector_db.config.collection_name + all_db_items = [] + for collection_name in collection_list: + db_items = self.vector_db.get_by_filter(collection_name=collection_name, filter=filter) + all_db_items.extend(db_items) + memories = [ + TextualMemoryItem( + id=memo.id, + memory=memo.memory, + metadata=PreferenceTextualMemoryMetadata(**memo.payload), + ) + for memo in all_db_items + ] + return memories + def delete(self, memory_ids: list[str]) -> None: """Delete memories. Args: memory_ids (list[str]): List of memory IDs to delete. """ - raise NotImplementedError + collection_list = self.vector_db.config.collection_name + for collection_name in collection_list: + self.vector_db.delete(collection_name, memory_ids) def delete_with_collection_name(self, collection_name: str, memory_ids: list[str]) -> None: """Delete memories by their IDs and collection name. diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 15a6a8b49..b55c48dcb 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -129,7 +129,6 @@ def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int def get_searcher( self, manual_close_internet: bool = False, - moscube: bool = False, ): if (self.internet_retriever is not None) and manual_close_internet: logger.warning( @@ -141,7 +140,6 @@ def get_searcher( self.embedder, self.reranker, internet_retriever=None, - moscube=moscube, ) else: searcher = Searcher( @@ -150,7 +148,6 @@ def get_searcher( self.embedder, self.reranker, internet_retriever=self.internet_retriever, - moscube=moscube, ) return searcher @@ -162,7 +159,6 @@ def search( mode: str = "fast", memory_type: str = "All", manual_close_internet: bool = False, - moscube: bool = False, search_filter: dict | None = None, user_name: str | None = None, ) -> list[TextualMemoryItem]: @@ -179,7 +175,6 @@ def search( memory_type (str): Type restriction for search. ['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory'] manual_close_internet (bool): If True, the internet retriever will be closed by this search, it high priority than config. - moscube (bool): whether you use moscube to answer questions search_filter (dict, optional): Optional metadata filters for search results. 
- Keys correspond to memory metadata fields (e.g., "user_id", "session_id"). - Values are exact-match conditions. @@ -199,7 +194,6 @@ def search( self.reranker, bm25_retriever=self.bm25_retriever, internet_retriever=None, - moscube=moscube, search_strategy=self.search_strategy, ) else: @@ -210,7 +204,6 @@ def search( self.reranker, bm25_retriever=self.bm25_retriever, internet_retriever=self.internet_retriever, - moscube=moscube, search_strategy=self.search_strategy, ) return searcher.search( diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index f196c5569..ae4dde446 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -41,7 +41,6 @@ def __init__( reranker: BaseReranker, bm25_retriever: EnhancedBM25 | None = None, internet_retriever: None = None, - moscube: bool = False, search_strategy: dict | None = None, ): self.graph_store = graph_store @@ -55,7 +54,6 @@ def __init__( # Create internet retriever from config if provided self.internet_retriever = internet_retriever - self.moscube = moscube self.vec_cot = search_strategy.get("cot", False) if search_strategy else False self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False @@ -296,17 +294,6 @@ def _retrieve_paths( user_name, ) ) - if self.moscube: - tasks.append( - executor.submit( - self._retrieve_from_memcubes, - query, - parsed_goal, - query_embedding, - top_k, - "memos_cube01", - ) - ) results = [] for t in tasks: From 1d61e408db1e44c18f0d7acca255a1a1c2d06b94 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 20 Nov 2025 20:17:14 +0800 Subject: [PATCH 06/25] modify code --- src/memos/api/handlers/chat_handler.py | 28 +++++++++++++++----------- src/memos/api/product_models.py | 1 + 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index a3a187982..049513527 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -7,6 +7,7 @@ import asyncio import json +import re import traceback from collections.abc import Generator @@ -149,19 +150,20 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An self.logger.info("Starting to generate complete response...") # Step 3: Generate complete response from LLM - response = self.llm.generate(current_messages) - - # Step 4: start add after chat asynchronously - self._start_add_to_memory( - user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, - session_id=chat_req.session_id or "default_session", - query=chat_req.query, - full_response=response, - async_mode="async", + response = self.llm.generate( + current_messages, model_name_or_path=chat_req.model_name_or_path ) - import re + # Step 4: start add after chat asynchronously + if chat_req.add_message_on_answer: + self._start_add_to_memory( + user_id=chat_req.user_id, + cube_id=chat_req.mem_cube_id, + session_id=chat_req.session_id or "default_session", + query=chat_req.query, + full_response=response, + async_mode="async", + ) match = re.search(r"([\s\S]*?)", response) reasoning_text = match.group(1) if match else None @@ -253,7 +255,9 @@ def generate_chat_response() -> Generator[str, None, None]: ) # Step 3: Generate streaming response from LLM - response_stream = self.llm.generate_stream(current_messages) + response_stream = self.llm.generate_stream( + 
current_messages, chat_req.model_name_or_path + ) # Stream the response buffer = "" diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 26b2f9218..33a7f805c 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -73,6 +73,7 @@ class ChatRequest(BaseRequest): query: str = Field(..., description="Chat query message") mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") history: list[MessageDict] | None = Field(None, description="Chat history") + mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") internet_search: bool = Field(True, description="Whether to use internet search") system_prompt: str | None = Field(None, description="Base system prompt to use for chat") top_k: int = Field(10, description="Number of results to return") From ee301b523843b8eb15c68f6b078e6dbc24043dfc Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 20 Nov 2025 20:36:52 +0800 Subject: [PATCH 07/25] modify code --- src/memos/api/routers/server_router.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 592fb814a..a5cbdfeb4 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -241,7 +241,7 @@ def get_all_memories(memory_req: GetMemoryPlaygroundRequest): ) -@router.post("/get_memory", summary="Get memories for user", response_model=GetMemoryResponse) +@router.get("/get_memory", summary="Get memories for user", response_model=GetMemoryResponse) def get_memories(memory_req: GetMemoryRequest): return handlers.memory_handler.handle_get_memories( get_mem_req=memory_req, From 43a7903fa6cd866ed2406ebc668586fe2a6a3da9 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 21 Nov 2025 16:26:50 +0800 Subject: [PATCH 08/25] coding chat --- src/memos/api/handlers/chat_handler.py | 54 ++++++++++++++++------- src/memos/api/handlers/component_init.py | 38 ++++++++++++++++ src/memos/api/handlers/config_builders.py | 27 ++++++++++++ src/memos/api/routers/server_router.py | 6 ++- src/memos/configs/llm.py | 3 ++ src/memos/llms/openai.py | 4 +- src/memos/llms/openai_new.py | 4 +- src/memos/llms/vllm.py | 4 +- 8 files changed, 121 insertions(+), 19 deletions(-) diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 049513527..769ab32d3 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -53,6 +53,7 @@ class ChatHandler(BaseHandler): def __init__( self, dependencies: HandlerDependencies, + chat_llms: dict[str, Any], search_handler=None, add_handler=None, online_bot=None, @@ -62,6 +63,7 @@ def __init__( Args: dependencies: HandlerDependencies instance + chat_llms: Dictionary mapping model names to LLM instances search_handler: Optional SearchHandler instance (created if not provided) add_handler: Optional AddHandler instance (created if not provided) online_bot: Optional DingDing bot function for notifications @@ -80,6 +82,7 @@ def __init__( add_handler = AddHandler(dependencies) + self.chat_llms = chat_llms self.search_handler = search_handler self.add_handler = add_handler self.online_bot = online_bot @@ -150,9 +153,12 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An self.logger.info("Starting to generate complete response...") # Step 3: Generate complete response from LLM - response = self.llm.generate( - current_messages, 
model_name_or_path=chat_req.model_name_or_path - ) + if chat_req.model_name_or_path and chat_req.model_name_or_path not in self.chat_llms: + return { + "message": f"Model {chat_req.model_name_or_path} not suport, choose from {list(self.chat_llms.keys())}" + } + model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys())) + response = self.chat_llms[model].generate(current_messages, model_name_or_path=model) # Step 4: start add after chat asynchronously if chat_req.add_message_on_answer: @@ -255,8 +261,16 @@ def generate_chat_response() -> Generator[str, None, None]: ) # Step 3: Generate streaming response from LLM - response_stream = self.llm.generate_stream( - current_messages, chat_req.model_name_or_path + if ( + chat_req.model_name_or_path + and chat_req.model_name_or_path not in self.chat_llms + ): + return { + "message": f"Model {chat_req.model_name_or_path} not suport, choose from {list(self.chat_llms.keys())}" + } + model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys())) + response_stream = self.chat_llms[model].generate_stream( + current_messages, model_name_or_path=model ) # Stream the response @@ -284,15 +298,15 @@ def generate_chat_response() -> Generator[str, None, None]: yield chunk_data current_messages.append({"role": "assistant", "content": full_response}) - - self._start_add_to_memory( - user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, - session_id=chat_req.session_id or "default_session", - query=chat_req.query, - full_response=full_response, - async_mode="async", - ) + if chat_req.add_message_on_answer: + self._start_add_to_memory( + user_id=chat_req.user_id, + cube_id=chat_req.mem_cube_id, + session_id=chat_req.session_id or "default_session", + query=chat_req.query, + full_response=full_response, + async_mode="async", + ) except Exception as e: self.logger.error(f"Error in chat stream: {e}", exc_info=True) @@ -410,7 +424,17 @@ def generate_chat_response() -> Generator[str, None, None]: yield f"data: {json.dumps({'type': 'status', 'data': '2'})}\n\n" # Step 3: Generate streaming response from LLM - response_stream = self.llm.generate_stream(current_messages) + if ( + chat_req.model_name_or_path + and chat_req.model_name_or_path not in self.chat_llms + ): + return { + "message": f"Model {chat_req.model_name_or_path} not suport, choose from {list(self.chat_llms.keys())}" + } + model = chat_req.model_name_or_path or next(iter(self.chat_llms.keys())) + response_stream = self.chat_llms[model].generate_stream( + current_messages, model_name_or_path=model + ) # Stream the response buffer = "" diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 4e696a341..1b650093b 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -9,6 +9,7 @@ from memos.api.config import APIConfig from memos.api.handlers.config_builders import ( + build_chat_llm_config, build_embedder_config, build_graph_db_config, build_internet_retriever_config, @@ -71,6 +72,38 @@ def _get_default_memory_size(cube_config: Any) -> dict[str, int]: } +def _init_chat_llms(chat_llm_configs: list[dict]) -> dict[str, Any]: + """ + Initialize chat language models from configuration. 
+ + Args: + chat_llm_configs: List of chat LLM configuration dictionaries + + Returns: + Dictionary mapping model names to initialized LLM instances + """ + + def _list_models(client): + try: + models = ( + [model.id for model in client.models.list().data] + if client.models.list().data + else client.models.list().models + ) + except Exception as e: + logger.error(f"Error listing models: {e}") + models = [] + return models + + model_name_instrance_maping = {} + for cfg in chat_llm_configs: + llm = LLMFactory.from_config(cfg["config_class"]) + if cfg["support_models"]: + for model_name in cfg["support_models"]: + model_name_instrance_maping[model_name] = llm + return model_name_instrance_maping + + def init_server() -> dict[str, Any]: """ Initialize all server components and configurations. @@ -98,6 +131,7 @@ def init_server() -> dict[str, Any]: # Build component configurations graph_db_config = build_graph_db_config() llm_config = build_llm_config() + chat_llm_config = build_chat_llm_config() embedder_config = build_embedder_config() mem_reader_config = build_mem_reader_config() reranker_config = build_reranker_config() @@ -113,6 +147,7 @@ def init_server() -> dict[str, Any]: graph_db = GraphStoreFactory.from_config(graph_db_config) vector_db = VecDBFactory.from_config(vector_db_config) llm = LLMFactory.from_config(llm_config) + chat_llms = _init_chat_llms(chat_llm_config) embedder = EmbedderFactory.from_config(embedder_config) mem_reader = MemReaderFactory.from_config(mem_reader_config) reranker = RerankerFactory.from_config(reranker_config) @@ -120,6 +155,8 @@ def init_server() -> dict[str, Any]: internet_retriever_config, embedder=embedder ) + # Initialize chat llms + logger.debug("Core components instantiated") # Initialize memory manager @@ -245,6 +282,7 @@ def init_server() -> dict[str, Any]: "graph_db": graph_db, "mem_reader": mem_reader, "llm": llm, + "chat_llms": chat_llms, "embedder": embedder, "reranker": reranker, "internet_retriever": internet_retriever, diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py index 9f510add0..4a83700d0 100644 --- a/src/memos/api/handlers/config_builders.py +++ b/src/memos/api/handlers/config_builders.py @@ -6,6 +6,7 @@ a configuration dictionary using the appropriate ConfigFactory. """ +import json import os from typing import Any @@ -81,6 +82,32 @@ def build_llm_config() -> dict[str, Any]: ) +def build_chat_llm_config() -> list[dict[str, Any]]: + """ + Build chat LLM configuration. + + Returns: + Validated chat LLM configuration dictionary + """ + configs = json.loads(os.getenv("CHAT_MODEL_LIST")) + return [ + { + "config_class": LLMConfigFactory.model_validate( + { + "backend": cfg.get("backend", "openai"), + "config": ( + {k: v for k, v in cfg.items() if k not in ["backend", "support_models"]} + ) + if cfg + else APIConfig.get_openai_config(), + } + ), + "support_models": cfg.get("support_models", None), + } + for cfg in configs + ] + + def build_embedder_config() -> dict[str, Any]: """ Build embedder configuration. 
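[Editor's note] build_chat_llm_config() reads a JSON list from the CHAT_MODEL_LIST environment variable, but the expected schema is not documented in this patch; the sketch below is an assumption inferred from the keys the code strips ("backend", "support_models") and the OpenAI client fields used elsewhere in the series:

    # Assumed shape of CHAT_MODEL_LIST (illustrative placeholders only).
    import json
    import os

    os.environ["CHAT_MODEL_LIST"] = json.dumps(
        [
            {
                "backend": "openai",                     # LLMConfigFactory backend key
                "model_name_or_path": "gpt-4o-mini",     # placeholder model id
                "api_key": "placeholder-key",
                "api_base": "https://api.example.com/v1",
                "support_models": ["gpt-4o-mini", "gpt-4o"],  # names exposed to chat requests
            }
        ]
    )
    # Every name listed in support_models maps to the same client instance in
    # the chat_llms dict consumed by ChatHandler.
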
diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index a5cbdfeb4..bf8284c58 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -59,7 +59,11 @@ search_handler = SearchHandler(dependencies) add_handler = AddHandler(dependencies) chat_handler = ChatHandler( - dependencies, search_handler, add_handler, online_bot=components.get("online_bot") + dependencies, + components["chat_llms"], + search_handler, + add_handler, + online_bot=components.get("online_bot"), ) # Extract commonly used components for function-based handlers diff --git a/src/memos/configs/llm.py b/src/memos/configs/llm.py index e446e72bb..70217b896 100644 --- a/src/memos/configs/llm.py +++ b/src/memos/configs/llm.py @@ -17,6 +17,9 @@ class BaseLLMConfig(BaseConfig): default=False, description="Remove content within think tags from the generated text", ) + default_headers: dict[str, Any] | None = Field( + default=None, description="Default headers for LLM requests" + ) class OpenAILLMConfig(BaseLLMConfig): diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 400d1fff9..1e9f91e5b 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -23,7 +23,9 @@ class OpenAILLM(BaseLLM): def __init__(self, config: OpenAILLMConfig): self.config = config - self.client = openai.Client(api_key=config.api_key, base_url=config.api_base) + self.client = openai.Client( + api_key=config.api_key, base_url=config.api_base, default_headers=config.default_headers + ) logger.info("OpenAI LLM instance initialized") @timed(log=True, log_prefix="OpenAI LLM") diff --git a/src/memos/llms/openai_new.py b/src/memos/llms/openai_new.py index c1e0ee8a6..766a17fda 100644 --- a/src/memos/llms/openai_new.py +++ b/src/memos/llms/openai_new.py @@ -22,7 +22,9 @@ class OpenAIResponsesLLM(BaseLLM): def __init__(self, config: OpenAILLMConfig): self.config = config - self.client = openai.Client(api_key=config.api_key, base_url=config.api_base) + self.client = openai.Client( + api_key=config.api_key, base_url=config.api_base, default_headers=config.default_headers + ) @timed(log=True, log_prefix="OpenAI Responses LLM") def generate(self, messages: MessageList, **kwargs) -> str: diff --git a/src/memos/llms/vllm.py b/src/memos/llms/vllm.py index fdb07a1b8..1cf8d4f39 100644 --- a/src/memos/llms/vllm.py +++ b/src/memos/llms/vllm.py @@ -34,7 +34,9 @@ def __init__(self, config: VLLMLLMConfig): api_key = "dummy" self.client = openai.Client( - api_key=api_key, base_url=getattr(self.config, "api_base", "http://localhost:8088/v1") + api_key=api_key, + base_url=getattr(self.config, "api_base", "http://localhost:8088/v1"), + default_headers=self.config.default_headers, ) def build_vllm_kv_cache(self, messages: Any) -> str: From b4fe866a6fb96b13956920fff8a19725e8ea74e0 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 21 Nov 2025 17:22:10 +0800 Subject: [PATCH 09/25] fix bug in get and delete --- src/memos/api/handlers/memory_handler.py | 6 +++--- src/memos/api/routers/server_router.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index bb5672740..c47a3cf83 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -160,7 +160,7 @@ def handle_get_subgraph( def handle_get_memories(get_mem_req: GetMemoryRequest, naive_mem_cube: Any) -> GetMemoryResponse: # TODO: Implement get memory with filter - memories = 
naive_mem_cube.text_mem.get_all(user_name=get_mem_req.mem_cube_id) + memories = naive_mem_cube.text_mem.get_all(user_name=get_mem_req.mem_cube_id)["nodes"] filter_params: dict[str, Any] = {} if get_mem_req.user_id is not None: filter_params["user_id"] = get_mem_req.user_id @@ -170,7 +170,7 @@ def handle_get_memories(get_mem_req: GetMemoryRequest, naive_mem_cube: Any) -> G return GetMemoryResponse( message="Memories retrieved successfully", data={ - "text_mem": [format_memory_item(mem) for mem in memories], + "text_mem": memories, "pref_mem": [format_memory_item(mem) for mem in preferences], }, ) @@ -188,5 +188,5 @@ def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: ) return DeleteMemoryResponse( message="Memories deleted successfully", - data="success", + data={"status": "success"}, ) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index bf8284c58..3f3b10134 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -245,7 +245,7 @@ def get_all_memories(memory_req: GetMemoryPlaygroundRequest): ) -@router.get("/get_memory", summary="Get memories for user", response_model=GetMemoryResponse) +@router.post("/get_memory", summary="Get memories for user", response_model=GetMemoryResponse) def get_memories(memory_req: GetMemoryRequest): return handlers.memory_handler.handle_get_memories( get_mem_req=memory_req, @@ -257,4 +257,6 @@ def get_memories(memory_req: GetMemoryRequest): "/delete_memory", summary="Delete memories for user", response_model=DeleteMemoryResponse ) def delete_memories(memory_req: DeleteMemoryRequest): - return handlers.memory_handler.handle_delete_memories(memory_req) + return handlers.memory_handler.handle_delete_memories( + delete_mem_req=memory_req, naive_mem_cube=naive_mem_cube + ) From b1053c4185410ab944e8bcfecfe32f2a550fe2cd Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 21 Nov 2025 23:56:12 +0800 Subject: [PATCH 10/25] add internet reference in playground chat stream --- src/memos/api/handlers/chat_handler.py | 22 +++++++++++++++++++ .../tree_text_memory/retrieve/bochasearch.py | 11 +++++++++- .../tree_text_memory/retrieve/searcher.py | 2 +- 3 files changed, 33 insertions(+), 2 deletions(-) diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 769ab32d3..5571578be 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -394,7 +394,12 @@ def generate_chat_response() -> Generator[str, None, None]: # Prepare reference data reference = prepare_reference_data(filtered_memories) + # get internet reference + internet_reference = self._get_internet_reference( + search_response.data.get("text_mem")[0]["memories"] + ) yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" + yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference})}\n\n" # Prepare preference markdown string if chat_req.include_preference: @@ -535,6 +540,23 @@ def generate_chat_response() -> Generator[str, None, None]: self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}") raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err + def _get_internet_reference( + self, search_response: list[dict[str, any]] + ) -> list[dict[str, any]]: + """Get internet reference from search response.""" + unique_set = set() + result = [] + + for item in search_response: + meta = item.get("metadata", {}) + if meta.get("source") == "web" 
and meta.get("internet_info"): + info = meta.get("internet_info") + key = json.dumps(info, sort_keys=True) + if key not in unique_set: + unique_set.add(key) + result.append(info) + return result + def _build_pref_md_string_for_playground(self, pref_mem_list: list[any]) -> str: """Build preference markdown string for playground.""" explicit = [] diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py b/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py index 31b914776..042ed837e 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/bochasearch.py @@ -200,9 +200,11 @@ def _process_result( """Process one Bocha search result into TextualMemoryItem.""" title = result.get("name", "") content = result.get("summary", "") or result.get("snippet", "") - summary = result.get("snippet", "") + summary = result.get("summary", "") or result.get("snippet", "") url = result.get("url", "") publish_time = result.get("datePublished", "") + site_name = result.get("siteName", "") + site_icon = result.get("siteIcon") if publish_time: try: @@ -229,5 +231,12 @@ def _process_result( read_item_i.metadata.memory_type = "OuterMemory" read_item_i.metadata.sources = [SourceMessage(type="web", url=url)] if url else [] read_item_i.metadata.visibility = "public" + read_item_i.metadata.internet_info = { + "title": title, + "url": url, + "site_name": site_name, + "site_icon": site_icon, + "summary": summary, + } memory_items.append(read_item_i) return memory_items diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index ae4dde446..d2d6d5efb 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -445,7 +445,7 @@ def _retrieve_from_internet( user_id: str | None = None, ): """Retrieve and rerank from Internet source""" - if not self.internet_retriever or mode == "fast": + if not self.internet_retriever: logger.info(f"[PATH-C] '{query}' Skipped (no retriever, fast mode)") return [] if memory_type not in ["All"]: From 83702f0c62c3932042a62842110838b0583170c3 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Sat, 22 Nov 2025 00:53:02 +0800 Subject: [PATCH 11/25] remove moscube --- src/memos/api/handlers/component_init.py | 1 - src/memos/mem_scheduler/base_scheduler.py | 9 ++++----- src/memos/multi_mem_cube/single_cube.py | 2 -- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 9d329ed58..3ef1d529d 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -271,7 +271,6 @@ def init_server() -> dict[str, Any]: tree_mem: TreeTextMemory = naive_mem_cube.text_mem searcher: Searcher = tree_mem.get_searcher( manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", - moscube=False, ) logger.debug("Searcher created") diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 63b87157c..a53e19191 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -152,7 +152,6 @@ def init_mem_cube( if searcher is None: self.searcher: Searcher = self.text_mem.get_searcher( manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", - moscube=False, ) else: 
self.searcher = searcher @@ -577,12 +576,12 @@ def get_web_log_messages(self) -> list[dict]: def _map_label(label: str) -> str: from memos.mem_scheduler.schemas.general_schemas import ( - QUERY_LABEL, - ANSWER_LABEL, ADD_LABEL, - MEM_UPDATE_LABEL, - MEM_ORGANIZE_LABEL, + ANSWER_LABEL, MEM_ARCHIVE_LABEL, + MEM_ORGANIZE_LABEL, + MEM_UPDATE_LABEL, + QUERY_LABEL, ) mapping = { diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index f34cad1ef..2055615d2 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -232,7 +232,6 @@ def _fast_search( top_k=search_req.top_k, mode=SearchMode.FAST, manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, search_filter=search_filter, info={ "user_id": search_req.user_id, @@ -287,7 +286,6 @@ def _fine_search( top_k=search_req.top_k, mode=SearchMode.FINE, manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, search_filter=search_filter, info=info, ) From aa73811aef80d4983dc781cabe7827878959cc50 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Sat, 22 Nov 2025 09:09:45 +0800 Subject: [PATCH 12/25] modify code --- src/memos/api/handlers/chat_handler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 2d172bd46..2f40f1c91 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -399,7 +399,6 @@ def generate_chat_response() -> Generator[str, None, None]: search_response.data.get("text_mem")[0]["memories"] ) yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" - yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference})}\n\n" # Prepare preference markdown string if chat_req.include_preference: @@ -479,6 +478,9 @@ def generate_chat_response() -> Generator[str, None, None]: chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n" yield chunk_data + # Yield internet reference after text response + yield f"data: {json.dumps({'type': 'internet_reference', 'data': internet_reference})}\n\n" + # Calculate timing time_end = time.time() speed_improvement = round(float((len(system_prompt) / 2) * 0.0048 + 44.5), 1) From 42f34038e309390b12b32b58684f39e651d54147 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Sat, 22 Nov 2025 09:29:29 +0800 Subject: [PATCH 13/25] fix pre_commit --- examples/mem_scheduler/memos_w_scheduler.py | 15 +++++------ .../general_modules/scheduler_logger.py | 7 ++--- src/memos/mem_scheduler/general_scheduler.py | 6 ++--- src/memos/types/__init__.py | 4 ++- .../openai_chat_completion_types/__init__.py | 22 +++++++-------- ...chat_completion_assistant_message_param.py | 27 ++++++++++++------- ...hat_completion_content_part_image_param.py | 5 +++- ...mpletion_content_part_input_audio_param.py | 5 +++- .../chat_completion_content_part_param.py | 20 +++++++------- ...t_completion_content_part_refusal_param.py | 5 +++- ...chat_completion_content_part_text_param.py | 5 +++- ...mpletion_message_custom_tool_call_param.py | 5 +++- ...letion_message_function_tool_call_param.py | 5 +++- .../chat_completion_message_param.py | 20 +++++++------- ...ompletion_message_tool_call_union_param.py | 14 +++++----- .../chat_completion_system_message_param.py | 15 +++++++---- .../chat_completion_tool_message_param.py | 15 ++++++----- .../chat_completion_user_message_param.py | 15 +++++++---- 
src/memos/types/types.py | 14 ++++++---- 19 files changed, 135 insertions(+), 89 deletions(-) diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py index 17bfd3993..7d8cf2897 100644 --- a/examples/mem_scheduler/memos_w_scheduler.py +++ b/examples/mem_scheduler/memos_w_scheduler.py @@ -1,29 +1,28 @@ +import re import shutil import sys +from datetime import datetime from pathlib import Path from queue import Queue + from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig -from datetime import datetime -import re - from memos.configs.mem_scheduler import AuthConfig from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_os.main import MOS +from memos.mem_scheduler.general_scheduler import GeneralScheduler from memos.mem_scheduler.schemas.general_schemas import ( - QUERY_LABEL, - ANSWER_LABEL, ADD_LABEL, + ANSWER_LABEL, + MEM_ARCHIVE_LABEL, MEM_ORGANIZE_LABEL, MEM_UPDATE_LABEL, - MEM_ARCHIVE_LABEL, - NOT_APPLICABLE_TYPE, + QUERY_LABEL, ) from memos.mem_scheduler.schemas.message_schemas import ScheduleLogForWebItem from memos.mem_scheduler.utils.filter_utils import transform_name_to_key -from memos.mem_scheduler.general_scheduler import GeneralScheduler FILE_PATH = Path(__file__).absolute() diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py index 3859c9e6f..7da531a7f 100644 --- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py +++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py @@ -1,3 +1,5 @@ +import hashlib + from collections.abc import Callable from memos.log import get_logger @@ -6,13 +8,13 @@ from memos.mem_scheduler.schemas.general_schemas import ( ACTIVATION_MEMORY_TYPE, ADD_LABEL, + MEM_ARCHIVE_LABEL, + MEM_UPDATE_LABEL, NOT_INITIALIZED, PARAMETER_MEMORY_TYPE, TEXT_MEMORY_TYPE, USER_INPUT_TYPE, WORKING_MEMORY_TYPE, - MEM_UPDATE_LABEL, - MEM_ARCHIVE_LABEL, ) from memos.mem_scheduler.schemas.message_schemas import ( ScheduleLogForWebItem, @@ -23,7 +25,6 @@ ) from memos.mem_scheduler.utils.misc_utils import log_exceptions from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory -import hashlib logger = get_logger(__name__) diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index eeca890a9..e0d18dc72 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -12,14 +12,14 @@ ADD_LABEL, ANSWER_LABEL, DEFAULT_MAX_QUERY_KEY_WORDS, + LONG_TERM_MEMORY_TYPE, MEM_ORGANIZE_LABEL, MEM_READ_LABEL, + NOT_APPLICABLE_TYPE, PREF_ADD_LABEL, QUERY_LABEL, - WORKING_MEMORY_TYPE, USER_INPUT_TYPE, - NOT_APPLICABLE_TYPE, - LONG_TERM_MEMORY_TYPE, + WORKING_MEMORY_TYPE, MemCubeID, UserID, ) diff --git a/src/memos/types/__init__.py b/src/memos/types/__init__.py index 4192f6a10..dd1b98305 100644 --- a/src/memos/types/__init__.py +++ b/src/memos/types/__init__.py @@ -1 +1,3 @@ -from .types import * \ No newline at end of file +# ruff: noqa: F403, F401 + +from .types import * diff --git a/src/memos/types/openai_chat_completion_types/__init__.py b/src/memos/types/openai_chat_completion_types/__init__.py index 3d742fe3b..4a08a9f24 100644 --- a/src/memos/types/openai_chat_completion_types/__init__.py +++ b/src/memos/types/openai_chat_completion_types/__init__.py @@ -1,17 +1,15 @@ -from .chat_completion_message_param import * +# ruff: noqa: F403, F401 from 
.chat_completion_assistant_message_param import * -from .chat_completion_system_message_param import * -from .chat_completion_tool_message_param import * -from .chat_completion_user_message_param import * - -from .chat_completion_message_custom_tool_call_param import * -from .chat_completion_message_function_tool_call_param import * -from .chat_completion_message_tool_call_union_param import * - -from .chat_completion_content_part_input_audio_param import * from .chat_completion_content_part_image_param import * +from .chat_completion_content_part_input_audio_param import * +from .chat_completion_content_part_param import * from .chat_completion_content_part_refusal_param import * from .chat_completion_content_part_text_param import * -from .chat_completion_content_part_param import * - +from .chat_completion_message_custom_tool_call_param import * +from .chat_completion_message_function_tool_call_param import * +from .chat_completion_message_param import * +from .chat_completion_message_tool_call_union_param import * +from .chat_completion_system_message_param import * +from .chat_completion_tool_message_param import * +from .chat_completion_user_message_param import * diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py index 698f2a6e0..a742de3a9 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py @@ -1,13 +1,18 @@ +# ruff: noqa: TC001, TC003 + from __future__ import annotations -from typing import Union, Iterable, Optional -from typing_extensions import Literal, Required, TypeAlias, TypedDict +from collections.abc import Iterable +from typing import Literal, TypeAlias + +from typing_extensions import Required, TypedDict -from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam from .chat_completion_content_part_refusal_param import ChatCompletionContentPartRefusalParam +from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam from .chat_completion_message_tool_call_union_param import ChatCompletionMessageToolCallUnionParam -__all__ = ["ChatCompletionAssistantMessageParam", "Audio", "ContentArrayOfContentPart"] + +__all__ = ["Audio", "ChatCompletionAssistantMessageParam", "ContentArrayOfContentPart"] class Audio(TypedDict, total=False): @@ -15,34 +20,36 @@ class Audio(TypedDict, total=False): """Unique identifier for a previous audio response from the model.""" -ContentArrayOfContentPart: TypeAlias = Union[ChatCompletionContentPartTextParam, ChatCompletionContentPartRefusalParam] +ContentArrayOfContentPart: TypeAlias = ( + ChatCompletionContentPartTextParam | ChatCompletionContentPartRefusalParam +) class ChatCompletionAssistantMessageParam(TypedDict, total=False): role: Required[Literal["assistant"]] """The role of the messages author, in this case `assistant`.""" - audio: Optional[Audio] + audio: Audio | None """ Data about a previous audio response from the model. [Learn more](https://platform.openai.com/docs/guides/audio). """ - content: Union[str, Iterable[ContentArrayOfContentPart], None] + content: str | Iterable[ContentArrayOfContentPart] | None """The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified. 
""" - refusal: Optional[str] + refusal: str | None """The refusal message by the assistant.""" tool_calls: Iterable[ChatCompletionMessageToolCallUnionParam] """The tool calls generated by the model, such as function calls.""" - chat_time: Optional[str] + chat_time: str | None """Optional timestamp for the message, format is not restricted, it can be any vague or precise time string.""" - message_id: Optional[str] + message_id: str | None """Optional unique identifier for the message""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py index f57ab33cb..6718bd91e 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_image_param.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing_extensions import Literal, Required, TypedDict +from typing import Literal + +from typing_extensions import Required, TypedDict + __all__ = ["ChatCompletionContentPartImageParam", "ImageURL"] diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py index be90f84db..e7cfa4504 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_input_audio_param.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing_extensions import Literal, Required, TypedDict +from typing import Literal + +from typing_extensions import Required, TypedDict + __all__ = ["ChatCompletionContentPartInputAudioParam", "InputAudio"] diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py index 65ce3b2ee..a5e740791 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_param.py @@ -1,11 +1,13 @@ from __future__ import annotations -from typing import Union -from typing_extensions import Literal, Required, TypeAlias, TypedDict +from typing import Literal, TypeAlias + +from typing_extensions import Required, TypedDict -from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam from .chat_completion_content_part_image_param import ChatCompletionContentPartImageParam from .chat_completion_content_part_input_audio_param import ChatCompletionContentPartInputAudioParam +from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam + __all__ = ["ChatCompletionContentPartParam", "File", "FileFile"] @@ -31,9 +33,9 @@ class File(TypedDict, total=False): """The type of the content part. 
Always `file`.""" -ChatCompletionContentPartParam: TypeAlias = Union[ - ChatCompletionContentPartTextParam, - ChatCompletionContentPartImageParam, - ChatCompletionContentPartInputAudioParam, - File, -] +ChatCompletionContentPartParam: TypeAlias = ( + ChatCompletionContentPartTextParam + | ChatCompletionContentPartImageParam + | ChatCompletionContentPartInputAudioParam + | File +) diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py index f239c48d5..fc87e9e1a 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_refusal_param.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing_extensions import Literal, Required, TypedDict +from typing import Literal + +from typing_extensions import Required, TypedDict + __all__ = ["ChatCompletionContentPartRefusalParam"] diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py index e15461ab4..f43de0eff 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_content_part_text_param.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing_extensions import Literal, Required, TypedDict +from typing import Literal + +from typing_extensions import Required, TypedDict + __all__ = ["ChatCompletionContentPartTextParam"] diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py index 8bcba4c59..bc7a22edb 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_custom_tool_call_param.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing_extensions import Literal, Required, TypedDict +from typing import Literal + +from typing_extensions import Required, TypedDict + __all__ = ["ChatCompletionMessageCustomToolCallParam", "Custom"] diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py index 01a910787..56341d94a 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_function_tool_call_param.py @@ -1,6 +1,9 @@ from __future__ import annotations -from typing_extensions import Literal, Required, TypedDict +from typing import Literal + +from typing_extensions import Required, TypedDict + __all__ = ["ChatCompletionMessageFunctionToolCallParam", "Function"] diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_param.py index 5beee37a9..06a624297 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_param.py @@ -1,18 +1,18 @@ from __future__ import annotations -from 
typing import Union -from typing_extensions import TypeAlias +from typing import TypeAlias +from .chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam +from .chat_completion_system_message_param import ChatCompletionSystemMessageParam from .chat_completion_tool_message_param import ChatCompletionToolMessageParam from .chat_completion_user_message_param import ChatCompletionUserMessageParam -from .chat_completion_system_message_param import ChatCompletionSystemMessageParam -from .chat_completion_assistant_message_param import ChatCompletionAssistantMessageParam + __all__ = ["ChatCompletionMessageParam"] -ChatCompletionMessageParam: TypeAlias = Union[ - ChatCompletionSystemMessageParam, - ChatCompletionUserMessageParam, - ChatCompletionAssistantMessageParam, - ChatCompletionToolMessageParam, -] +ChatCompletionMessageParam: TypeAlias = ( + ChatCompletionSystemMessageParam + | ChatCompletionUserMessageParam + | ChatCompletionAssistantMessageParam + | ChatCompletionToolMessageParam +) diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py index 97eccf344..28bb880cf 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_message_tool_call_union_param.py @@ -1,13 +1,15 @@ from __future__ import annotations -from typing import Union -from typing_extensions import TypeAlias +from typing import TypeAlias from .chat_completion_message_custom_tool_call_param import ChatCompletionMessageCustomToolCallParam -from .chat_completion_message_function_tool_call_param import ChatCompletionMessageFunctionToolCallParam +from .chat_completion_message_function_tool_call_param import ( + ChatCompletionMessageFunctionToolCallParam, +) + __all__ = ["ChatCompletionMessageToolCallUnionParam"] -ChatCompletionMessageToolCallUnionParam: TypeAlias = Union[ - ChatCompletionMessageFunctionToolCallParam, ChatCompletionMessageCustomToolCallParam -] +ChatCompletionMessageToolCallUnionParam: TypeAlias = ( + ChatCompletionMessageFunctionToolCallParam | ChatCompletionMessageCustomToolCallParam +) diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py index 544f0e977..7faa90e2e 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py @@ -1,15 +1,20 @@ +# ruff: noqa: TC001, TC003 + from __future__ import annotations -from typing import Union, Iterable, Optional -from typing_extensions import Literal, Required, TypedDict +from collections.abc import Iterable +from typing import Literal + +from typing_extensions import Required, TypedDict from .chat_completion_content_part_text_param import ChatCompletionContentPartTextParam + __all__ = ["ChatCompletionSystemMessageParam"] class ChatCompletionSystemMessageParam(TypedDict, total=False): - content: Required[Union[str, Iterable[ChatCompletionContentPartTextParam]]] + content: Required[str | Iterable[ChatCompletionContentPartTextParam]] """The contents of the system message.""" role: Required[Literal["system"]] @@ -22,9 +27,9 @@ class ChatCompletionSystemMessageParam(TypedDict, total=False): role. 
""" - chat_time: Optional[str] + chat_time: str | None """Optional timestamp for the message, format is not restricted, it can be any vague or precise time string.""" - message_id: Optional[str] + message_id: str | None """Optional unique identifier for the message""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py index 8fb75fe35..c03220915 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py @@ -1,17 +1,20 @@ -# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. +# ruff: noqa: TC001, TC003 from __future__ import annotations -from typing import Union, Iterable, Optional -from typing_extensions import Literal, Required, TypedDict +from collections.abc import Iterable +from typing import Literal + +from typing_extensions import Required, TypedDict from .chat_completion_content_part_param import ChatCompletionContentPartParam + __all__ = ["ChatCompletionToolMessageParam"] class ChatCompletionToolMessageParam(TypedDict, total=False): - content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]] + content: Required[str | Iterable[ChatCompletionContentPartParam]] """The contents of the tool message.""" role: Required[Literal["tool"]] @@ -20,9 +23,9 @@ class ChatCompletionToolMessageParam(TypedDict, total=False): tool_call_id: Required[str] """Tool call that this message is responding to.""" - chat_time: Optional[str] + chat_time: str | None """Optional timestamp for the message, format is not restricted, it can be any vague or precise time string.""" - message_id: Optional[str] + message_id: str | None """Optional unique identifier for the message""" diff --git a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py index c48240c71..2c2a1f23f 100644 --- a/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py +++ b/src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py @@ -1,15 +1,20 @@ +# ruff: noqa: TC001, TC003 + from __future__ import annotations -from typing import Union, Iterable, Optional -from typing_extensions import Literal, Required, TypedDict +from collections.abc import Iterable +from typing import Literal + +from typing_extensions import Required, TypedDict from .chat_completion_content_part_param import ChatCompletionContentPartParam + __all__ = ["ChatCompletionUserMessageParam"] class ChatCompletionUserMessageParam(TypedDict, total=False): - content: Required[Union[str, Iterable[ChatCompletionContentPartParam]]] + content: Required[str | Iterable[ChatCompletionContentPartParam]] """The contents of the user message.""" role: Required[Literal["user"]] @@ -22,9 +27,9 @@ class ChatCompletionUserMessageParam(TypedDict, total=False): role. 
""" - chat_time: Optional[str] + chat_time: str | None """Optional timestamp for the message, format is not restricted, it can be any vague or precise time string.""" - message_id: Optional[str] + message_id: str | None """Optional unique identifier for the message""" diff --git a/src/memos/types/types.py b/src/memos/types/types.py index dae741afc..b8efc6208 100644 --- a/src/memos/types/types.py +++ b/src/memos/types/types.py @@ -13,15 +13,20 @@ from memos.memories.activation.item import ActivationMemoryItem from memos.memories.parametric.item import ParametricMemoryItem from memos.memories.textual.item import TextualMemoryItem -from .openai_chat_completion_types import ChatCompletionMessageParam, ChatCompletionContentPartTextParam, File + +from .openai_chat_completion_types import ( + ChatCompletionContentPartTextParam, + ChatCompletionMessageParam, + File, +) __all__ = [ - "MessageRole", - "MessageDict", - "MessageList", "ChatHistory", "MOSSearchResult", + "MessageDict", + "MessageList", + "MessageRole", "Permission", "PermissionDict", "UserContext", @@ -56,7 +61,6 @@ class MessageDict(TypedDict, total=False): MessagesType: TypeAlias = str | MessageList | RawMessageList - # Chat history structure class ChatHistory(BaseModel): """Model to represent chat history for export.""" From 86403ddbc7140d2ca3af285a4a1fddd8a52cdabf Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Sat, 22 Nov 2025 11:43:11 +0800 Subject: [PATCH 14/25] fix make test --- src/memos/llms/ollama.py | 5 ++-- src/memos/llms/openai.py | 16 ++++++------- tests/configs/test_llm.py | 13 +++++++++- tests/llms/test_deepseek.py | 10 ++++---- tests/llms/test_ollama.py | 47 +++++++++++++++++++++---------------- tests/llms/test_openai.py | 1 + tests/llms/test_qwen.py | 8 ++++--- 7 files changed, 62 insertions(+), 38 deletions(-) diff --git a/src/memos/llms/ollama.py b/src/memos/llms/ollama.py index c8643c763..bd92f9625 100644 --- a/src/memos/llms/ollama.py +++ b/src/memos/llms/ollama.py @@ -77,8 +77,9 @@ def generate(self, messages: MessageList, **kwargs) -> Any: tools=kwargs.get("tools"), ) logger.info(f"Raw response from Ollama: {response.model_dump_json()}") - if response.message.tool_calls: - return self.tool_call_parser(response.message.tool_calls) + tool_calls = getattr(response.message, "tool_calls", None) + if isinstance(tool_calls, list) and len(tool_calls) > 0: + return self.tool_call_parser(tool_calls) str_thinking = ( f"{response.message.thinking}" diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 1e9f91e5b..9b348adcf 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -41,18 +41,18 @@ def generate(self, messages: MessageList, **kwargs) -> str: tools=kwargs.get("tools", NOT_GIVEN), ) logger.info(f"Response from OpenAI: {response.model_dump_json()}") - if response.choices[0].message.tool_calls: - return self.tool_call_parser(response.choices[0].message.tool_calls) - reasoning_content = ( - f"{response.choices[0].message.reasoning_content}" - if hasattr(response.choices[0].message, "reasoning_content") - else "" - ) + tool_calls = getattr(response.choices[0].message, "tool_calls", None) + if isinstance(tool_calls, list) and len(tool_calls) > 0: + return self.tool_call_parser(tool_calls) response_content = response.choices[0].message.content + reasoning_content = getattr(response.choices[0].message, "reasoning_content", None) + if isinstance(reasoning_content, str) and reasoning_content: + reasoning_content = f"{reasoning_content}" if self.config.remove_think_prefix: return 
remove_thinking_tags(response_content) - else: + if reasoning_content: return reasoning_content + response_content + return response_content @timed(log=True, log_prefix="OpenAI LLM") def generate_stream(self, messages: MessageList, **kwargs) -> Generator[str, None, None]: diff --git a/tests/configs/test_llm.py b/tests/configs/test_llm.py index a977a4004..6562c9a95 100644 --- a/tests/configs/test_llm.py +++ b/tests/configs/test_llm.py @@ -19,7 +19,14 @@ def test_base_llm_config(): required_fields=[ "model_name_or_path", ], - optional_fields=["temperature", "max_tokens", "top_p", "top_k", "remove_think_prefix"], + optional_fields=[ + "temperature", + "max_tokens", + "top_p", + "top_k", + "remove_think_prefix", + "default_headers", + ], ) check_config_instantiation_valid( @@ -48,6 +55,7 @@ def test_openai_llm_config(): "api_base", "remove_think_prefix", "extra_body", + "default_headers", ], ) @@ -79,6 +87,8 @@ def test_ollama_llm_config(): "top_k", "remove_think_prefix", "api_base", + "default_headers", + "enable_thinking", ], ) @@ -111,6 +121,7 @@ def test_hf_llm_config(): "do_sample", "remove_think_prefix", "add_generation_prompt", + "default_headers", ], ) diff --git a/tests/llms/test_deepseek.py b/tests/llms/test_deepseek.py index 75c1ead5f..11be66887 100644 --- a/tests/llms/test_deepseek.py +++ b/tests/llms/test_deepseek.py @@ -12,12 +12,14 @@ def test_deepseek_llm_generate_with_and_without_think_prefix(self): """Test DeepSeekLLM generate method with and without tag removal.""" # Simulated full content including tag - full_content = "Thinking in progress...Hello from DeepSeek!" + full_content = "Hello from DeepSeek!" + reasoning_content = "Thinking in progress..." # Mock response object mock_response = MagicMock() mock_response.model_dump_json.return_value = '{"mock": "true"}' mock_response.choices[0].message.content = full_content + mock_response.choices[0].message.reasoning_content = reasoning_content # Config with think prefix preserved config_with_think = DeepSeekLLMConfig.model_validate( @@ -35,7 +37,7 @@ def test_deepseek_llm_generate_with_and_without_think_prefix(self): llm_with_think.client.chat.completions.create = MagicMock(return_value=mock_response) output_with_think = llm_with_think.generate([{"role": "user", "content": "Hello"}]) - self.assertEqual(output_with_think, full_content) + self.assertEqual(output_with_think, f"{reasoning_content}{full_content}") # Config with think tag removed config_without_think = config_with_think.model_copy(update={"remove_think_prefix": True}) @@ -43,7 +45,7 @@ def test_deepseek_llm_generate_with_and_without_think_prefix(self): llm_without_think.client.chat.completions.create = MagicMock(return_value=mock_response) output_without_think = llm_without_think.generate([{"role": "user", "content": "Hello"}]) - self.assertEqual(output_without_think, "Hello from DeepSeek!") + self.assertEqual(output_without_think, full_content) def test_deepseek_llm_generate_stream(self): """Test DeepSeekLLM generate_stream with reasoning_content and content chunks.""" @@ -84,5 +86,5 @@ def make_chunk(delta_dict): self.assertIn("Analyzing...", full_output) self.assertIn("Hello, DeepSeek!", full_output) - self.assertTrue(full_output.startswith("Analyzing...")) + self.assertTrue(full_output.startswith("")) self.assertTrue(full_output.endswith("DeepSeek!")) diff --git a/tests/llms/test_ollama.py b/tests/llms/test_ollama.py index 47002a21f..9ed252f37 100644 --- a/tests/llms/test_ollama.py +++ b/tests/llms/test_ollama.py @@ -1,5 +1,6 @@ import unittest +from types import 
SimpleNamespace from unittest.mock import MagicMock from memos.configs.llm import LLMConfigFactory, OllamaLLMConfig @@ -12,15 +13,15 @@ def test_llm_factory_with_mocked_ollama_backend(self): """Test LLMFactory with mocked Ollama backend.""" mock_chat = MagicMock() mock_response = MagicMock() - mock_response.model_dump_json.return_value = '{"model":"qwen3:0.6b","created_at":"2025-05-13T18:07:04.508998134Z","done":true,"done_reason":"stop","total_duration":348924420,"load_duration":14321072,"prompt_eval_count":16,"prompt_eval_duration":16770943,"eval_count":21,"eval_duration":317395459,"message":{"role":"assistant","content":"Hello! How are you? I\'m here to help and smile!","images":null,"tool_calls":null}}' - mock_response.__getitem__.side_effect = lambda key: { - "message": { - "role": "assistant", - "content": "Hello! How are you? I'm here to help and smile!", - "images": None, - "tool_calls": None, - } - }[key] + mock_response.model_dump_json.return_value = '{"model":"qwen3:0.6b","created_at":"2025-05-13T18:07:04.508998134Z","done":true,"done_reason":"stop","total_duration":348924420,"load_duration":14321072,"prompt_eval_count":16,"prompt_eval_duration":16770943,"eval_count":21,"eval_duration":317395459,"message":{"role":"assistant","content":"Hello! How are you? I\'m here to help and smile!", "thinking":"Analyzing your request...","images":null,"tool_calls":null}}' + + mock_response.message = SimpleNamespace( + role="assistant", + content="Hello! How are you? I'm here to help and smile!", + thinking="Analyzing your request...", + images=None, + tool_calls=None, + ) mock_chat.return_value = mock_response config = LLMConfigFactory.model_validate( @@ -32,6 +33,7 @@ def test_llm_factory_with_mocked_ollama_backend(self): "max_tokens": 1024, "top_p": 0.9, "top_k": 50, + "enable_thinking": True, }, } ) @@ -42,21 +44,23 @@ def test_llm_factory_with_mocked_ollama_backend(self): ] response = llm.generate(messages) - self.assertEqual(response, "Hello! How are you? I'm here to help and smile!") + self.assertEqual( + response, + "Analyzing your request...Hello! How are you? I'm here to help and smile!", + ) def test_ollama_llm_with_mocked_backend(self): """Test OllamaLLM with mocked backend.""" mock_chat = MagicMock() mock_response = MagicMock() - mock_response.model_dump_json.return_value = '{"model":"qwen3:0.6b","created_at":"2025-05-13T18:07:04.508998134Z","done":true,"done_reason":"stop","total_duration":348924420,"load_duration":14321072,"prompt_eval_count":16,"prompt_eval_duration":16770943,"eval_count":21,"eval_duration":317395459,"message":{"role":"assistant","content":"Hello! How are you? I\'m here to help and smile!","images":null,"tool_calls":null}}' - mock_response.__getitem__.side_effect = lambda key: { - "message": { - "role": "assistant", - "content": "Hello! How are you? I'm here to help and smile!", - "images": None, - "tool_calls": None, - } - }[key] + mock_response.model_dump_json.return_value = '{"model":"qwen3:0.6b","created_at":"2025-05-13T18:07:04.508998134Z","done":true,"done_reason":"stop","total_duration":348924420,"load_duration":14321072,"prompt_eval_count":16,"prompt_eval_duration":16770943,"eval_count":21,"eval_duration":317395459,"message":{"role":"assistant","content":"Hello! How are you? I\'m here to help and smile!","thinking":"Analyzing your request...","images":null,"tool_calls":null}}' + mock_response.message = SimpleNamespace( + role="assistant", + content="Hello! How are you? 
I'm here to help and smile!", + thinking="Analyzing your request...", + images=None, + tool_calls=None, + ) mock_chat.return_value = mock_response config = OllamaLLMConfig( @@ -73,4 +77,7 @@ def test_ollama_llm_with_mocked_backend(self): ] response = ollama.generate(messages) - self.assertEqual(response, "Hello! How are you? I'm here to help and smile!") + self.assertEqual( + response, + "Analyzing your request...Hello! How are you? I'm here to help and smile!", + ) diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index dff57c058..ba5b52df4 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -14,6 +14,7 @@ def test_llm_factory_with_mocked_openai_backend(self): mock_response = MagicMock() mock_response.model_dump_json.return_value = '{"id":"chatcmpl-BWoqIrvOeWdnFVZQUFzCcdVEpJ166","choices":[{"finish_reason":"stop","index":0,"message":{"content":"Hello! I\'m an AI language model created by OpenAI. I\'m here to help answer questions, provide information, and assist with a wide range of topics. How can I assist you today?","role":"assistant"}}],"created":1747161634,"model":"gpt-4o-2024-08-06","object":"chat.completion"}' mock_response.choices[0].message.content = "Hello! I'm an AI language model created by OpenAI. I'm here to help answer questions, provide information, and assist with a wide range of topics. How can I assist you today?" # fmt: skip + mock_response.choices[0].message.reasoning_content = None mock_chat_completions_create.return_value = mock_response config = LLMConfigFactory.model_validate( diff --git a/tests/llms/test_qwen.py b/tests/llms/test_qwen.py index 90f31e47f..71a4c75dd 100644 --- a/tests/llms/test_qwen.py +++ b/tests/llms/test_qwen.py @@ -12,12 +12,14 @@ def test_qwen_llm_generate_with_and_without_think_prefix(self): """Test QwenLLM non-streaming response generation with and without prefix removal.""" # Simulated full response content with tag - full_content = "Analyzing your request...Hello, world!" + full_content = "Hello from DeepSeek!" + reasoning_content = "Thinking in progress..." 
# Prepare the mock response object with expected structure mock_response = MagicMock() mock_response.model_dump_json.return_value = '{"mocked": "true"}' mock_response.choices[0].message.content = full_content + mock_response.choices[0].message.reasoning_content = reasoning_content # Create config with remove_think_prefix = False config_with_think = QwenLLMConfig.model_validate( @@ -37,7 +39,7 @@ def test_qwen_llm_generate_with_and_without_think_prefix(self): llm_with_think.client.chat.completions.create = MagicMock(return_value=mock_response) response_with_think = llm_with_think.generate([{"role": "user", "content": "Hi"}]) - self.assertEqual(response_with_think, full_content) + self.assertEqual(response_with_think, f"{reasoning_content}{full_content}") # Create config with remove_think_prefix = True config_without_think = config_with_think.model_copy(update={"remove_think_prefix": True}) @@ -47,7 +49,7 @@ def test_qwen_llm_generate_with_and_without_think_prefix(self): llm_without_think.client.chat.completions.create = MagicMock(return_value=mock_response) response_without_think = llm_without_think.generate([{"role": "user", "content": "Hi"}]) - self.assertEqual(response_without_think, "Hello, world!") + self.assertEqual(response_without_think, full_content) self.assertNotIn("", response_without_think) def test_qwen_llm_generate_stream(self): From c86dbe3e08790b96b9a6c71e0819598fd049df0f Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Mon, 24 Nov 2025 21:05:25 +0800 Subject: [PATCH 15/25] finish info transfer --- src/memos/api/handlers/add_handler.py | 10 ++++++++++ src/memos/mem_reader/simple_struct.py | 18 ++++++++++++++---- src/memos/mem_scheduler/general_scheduler.py | 8 +++++++- .../mem_scheduler/schemas/message_schemas.py | 1 + src/memos/memories/textual/item.py | 18 ++++++++++++++++++ src/memos/multi_mem_cube/single_cube.py | 2 ++ 6 files changed, 52 insertions(+), 5 deletions(-) diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 9b41477e1..8476404d3 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -9,6 +9,9 @@ from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies from memos.api.product_models import APIADDRequest, MemoryResponse +from memos.memories.textual.item import ( + list_all_fields, +) from memos.multi_mem_cube.composite_cube import CompositeCubeView from memos.multi_mem_cube.single_cube import SingleCubeView from memos.multi_mem_cube.views import MemCubeView @@ -50,6 +53,13 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: add_req.messages = self._convert_content_messsage(add_req.memory_content) self.logger.info(f"[AddHandler] Converted content to messages: {add_req.messages}") + if add_req.info: + exclude_fields = list_all_fields() + info_len = len(add_req.info) + add_req.info = {k: v for k, v in add_req.info.items() if k not in exclude_fields} + if len(add_req.info) < info_len: + self.logger.warning(f"[AddHandler] info fields can not contain {exclude_fields}.") + cube_view = self._build_cube_view(add_req) results = cube_view.add_memories(add_req) diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 3845f37d0..fee4fea93 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -121,11 +121,15 @@ def _build_node(idx, message, info, scene_file, llm, parse_json_result, embedder embedding = embedder.embed([value])[0] + info_ = info.copy() + user_id = 
info_.pop("user_id", "") + session_id = info_.pop("session_id", "") + return TextualMemoryItem( memory=value, metadata=TreeNodeTextualMemoryMetadata( - user_id=info.get("user_id", ""), - session_id=info.get("session_id", ""), + user_id=user_id, + session_id=session_id, memory_type="LongTermMemory", status="activated", tags=tags, @@ -136,6 +140,7 @@ def _build_node(idx, message, info, scene_file, llm, parse_json_result, embedder background="", confidence=0.99, type="fact", + info=info_, ), ) except Exception as e: @@ -183,11 +188,15 @@ def _make_memory_item( confidence: float = 0.99, ) -> TextualMemoryItem: """construct memory item""" + info_ = info.copy() + user_id = info_.pop("user_id", "") + session_id = info_.pop("session_id", "") + return TextualMemoryItem( memory=value, metadata=TreeNodeTextualMemoryMetadata( - user_id=info.get("user_id", ""), - session_id=info.get("session_id", ""), + user_id=user_id, + session_id=session_id, memory_type=memory_type, status="activated", tags=tags or [], @@ -198,6 +207,7 @@ def _make_memory_item( background=background, confidence=confidence, type=type_, + info=info_, ), ) diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 2c20520ea..135e6927f 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -754,6 +754,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id = message.mem_cube_id content = message.content messages_list = json.loads(content) + info = message.info or {} logger.info(f"Processing pref_add for user_id={user_id}, mem_cube_id={mem_cube_id}") @@ -776,7 +777,12 @@ def process_message(message: ScheduleMessageItem): pref_memories = pref_mem.get_memory( messages_list, type="chat", - info={"user_id": user_id, "session_id": session_id, "mem_cube_id": mem_cube_id}, + info={ + **info, + "user_id": user_id, + "session_id": session_id, + "mem_cube_id": mem_cube_id, + }, ) # Add pref_mem to vector db pref_ids = pref_mem.add(pref_memories) diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index d7e94e0e1..5652d6b6e 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -46,6 +46,7 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): default="", description="user name / display name (optional)", ) + info: dict | None = Field(default=None, description="user custom info") # Pydantic V2 model configuration model_config = ConfigDict( diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index e7595443d..fccd75bfd 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -83,6 +83,10 @@ class TextualMemoryMetadata(BaseModel): default_factory=lambda: datetime.now().isoformat(), description="The timestamp of the last modification to the memory. Useful for tracking memory freshness or change history. 
Format: ISO 8601.", ) + info: dict | None = Field( + default=None, + description="Arbitrary key-value pairs for additional metadata.", + ) model_config = ConfigDict(extra="allow") @@ -267,3 +271,17 @@ def _coerce_metadata(cls, v: Any): def __str__(self) -> str: """Pretty string representation of the memory item.""" return f"" + + +def list_all_fields() -> list[str]: + """List all possible fields of the TextualMemoryItem model.""" + top = list(TextualMemoryItem.model_fields.keys()) + meta_models = [ + TextualMemoryMetadata, + TreeNodeTextualMemoryMetadata, + SearchedTreeNodeTextualMemoryMetadata, + PreferenceTextualMemoryMetadata, + ] + meta_all = sorted(set().union(*[set(m.model_fields.keys()) for m in meta_models])) + + return top + meta_all diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 2055615d2..837ccd58a 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -473,6 +473,7 @@ def _process_pref_mem( [add_req.messages], type="chat", info={ + **(add_req.info or {}), "user_id": add_req.user_id, "session_id": target_session_id, "mem_cube_id": self.cube_id, @@ -524,6 +525,7 @@ def _process_text_mem( [add_req.messages], type="chat", info={ + **(add_req.info or {}), "user_id": add_req.user_id, "session_id": target_session_id, }, From 93f8befb8fe0bf6d2ca9b12ff96c3f00ce18b373 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 25 Nov 2025 16:52:53 +0800 Subject: [PATCH 16/25] add info and custom tags --- src/memos/api/product_models.py | 29 ++++++++----- src/memos/mem_reader/simple_struct.py | 41 +++++++++++++++---- src/memos/mem_reader/strategy_struct.py | 13 +++++- src/memos/mem_scheduler/general_scheduler.py | 5 +++ src/memos/memories/textual/general.py | 4 +- src/memos/multi_mem_cube/single_cube.py | 2 + src/memos/templates/mem_reader_prompts.py | 18 ++++++++ .../templates/mem_reader_strategy_prompts.py | 2 + src/memos/types/types.py | 4 +- 9 files changed, 97 insertions(+), 21 deletions(-) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 191b219e4..5c1ad2072 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -6,7 +6,7 @@ # Import message types from core types module from memos.mem_scheduler.schemas.general_schemas import SearchMode -from memos.types import MessageDict, PermissionDict +from memos.types import MessageList, MessagesType, PermissionDict T = TypeVar("T") @@ -78,7 +78,7 @@ class ChatRequest(BaseRequest): writable_cube_ids: list[str] | None = Field( None, description="List of cube IDs user can write for multi-cube chat" ) - history: list[MessageDict] | None = Field(None, description="Chat history") + history: MessageList | None = Field(None, description="Chat history") mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") internet_search: bool = Field(True, description="Whether to use internet search") system_prompt: str | None = Field(None, description="Base system prompt to use for chat") @@ -99,12 +99,12 @@ class ChatRequest(BaseRequest): class ChatCompleteRequest(BaseRequest): - """Request model for chat operations.""" + """Request model for chat operations. 
will (Deprecated), instead use APIChatCompleteRequest.""" user_id: str = Field(..., description="User ID") query: str = Field(..., description="Chat query message") mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") - history: list[MessageDict] | None = Field(None, description="Chat history") + history: MessageList | None = Field(None, description="Chat history") internet_search: bool = Field(False, description="Whether to use internet search") system_prompt: str | None = Field(None, description="Base prompt to use for chat") top_k: int = Field(10, description="Number of results to return") @@ -190,7 +190,7 @@ class MemoryCreateRequest(BaseRequest): """Request model for creating memories.""" user_id: str = Field(..., description="User ID") - messages: list[MessageDict] | None = Field(None, description="List of messages to store.") + messages: MessagesType | None = Field(None, description="List of messages to store.") memory_content: str | None = Field(None, description="Memory content to store") doc_path: str | None = Field(None, description="Path to document to store") mem_cube_id: str | None = Field(None, description="Cube ID") @@ -275,7 +275,14 @@ class APISearchRequest(BaseRequest): # TODO: maybe add detailed description later filter: dict[str, Any] | None = Field( None, - description=("Filter for the memory"), + description=""" + { + "`and` or `or`": [ + {"id": "uuid-xxx"}, + {"created_at": {"gt": "2024-01-01"}}, + ] + } + """, ) # ==== Extended capabilities ==== @@ -297,7 +304,7 @@ class APISearchRequest(BaseRequest): ) # ==== Context ==== - chat_history: list[MessageDict] | None = Field( + chat_history: MessageList | None = Field( None, description=( "Historical chat messages used internally by algorithms. " @@ -374,7 +381,7 @@ class APIADDRequest(BaseRequest): ) # ==== Input content ==== - messages: list[MessageDict] | None = Field( + messages: MessagesType | None = Field( None, description=( "List of messages to store. Supports: " @@ -390,7 +397,7 @@ class APIADDRequest(BaseRequest): ) # ==== Chat history ==== - chat_history: list[MessageDict] | None = Field( + chat_history: MessageList | None = Field( None, description=( "Historical chat messages used internally by algorithms. 
" @@ -439,7 +446,7 @@ class APIChatCompleteRequest(BaseRequest): writable_cube_ids: list[str] | None = Field( None, description="List of cube IDs user can write for multi-cube chat" ) - history: list[MessageDict] | None = Field(None, description="Chat history") + history: MessageList | None = Field(None, description="Chat history") internet_search: bool = Field(False, description="Whether to use internet search") system_prompt: str | None = Field(None, description="Base system prompt to use for chat") mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") @@ -486,7 +493,7 @@ class SuggestionRequest(BaseRequest): user_id: str = Field(..., description="User ID") mem_cube_id: str = Field(..., description="Cube ID") language: Literal["zh", "en"] = Field("zh", description="Language for suggestions") - message: list[MessageDict] | None = Field(None, description="List of messages to store.") + message: MessagesType | None = Field(None, description="List of messages to store.") # ─── MemOS Client Response Models ────────────────────────────────────────────── diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index fee4fea93..342aee952 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -22,6 +22,8 @@ from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.parsers.factory import ParserFactory from memos.templates.mem_reader_prompts import ( + CUSTOM_TAGS_INSTRUCTION, + CUSTOM_TAGS_INSTRUCTION_ZH, SIMPLE_STRUCT_DOC_READER_PROMPT, SIMPLE_STRUCT_DOC_READER_PROMPT_ZH, SIMPLE_STRUCT_MEM_READER_EXAMPLE, @@ -41,6 +43,7 @@ "zh_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH, }, "doc": {"en": SIMPLE_STRUCT_DOC_READER_PROMPT, "zh": SIMPLE_STRUCT_DOC_READER_PROMPT_ZH}, + "custom_tags": {"en": CUSTOM_TAGS_INSTRUCTION, "zh": CUSTOM_TAGS_INSTRUCTION_ZH}, } try: @@ -211,11 +214,19 @@ def _make_memory_item( ), ) - def _get_llm_response(self, mem_str: str) -> dict: + def _get_llm_response(self, mem_str: str, custom_tags: list[str] | None) -> dict: lang = detect_lang(mem_str) template = PROMPT_DICT["chat"][lang] examples = PROMPT_DICT["chat"][f"{lang}_example"] prompt = template.replace("${conversation}", mem_str) + + custom_tags_prompt = ( + PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags)) + if custom_tags + else "" + ) + prompt = prompt.replace("${custom_tags_prompt}", custom_tags_prompt) + if self.config.remove_prompt_example: prompt = prompt.replace(examples, "") messages = [{"role": "user", "content": prompt}] @@ -313,8 +324,9 @@ def _build_fast_node(w): else: logger.debug("Using unified Fine Mode") chat_read_nodes = [] + custom_tags = info.pop("custom_tags", None) for w in windows: - resp = self._get_llm_response(w["text"]) + resp = self._get_llm_response(w["text"], custom_tags) for m in resp.get("memory list", []): try: memory_type = ( @@ -336,9 +348,11 @@ def _build_fast_node(w): logger.error(f"[ChatFine] parse error: {e}") return chat_read_nodes - def _process_transfer_chat_data(self, raw_node: TextualMemoryItem): + def _process_transfer_chat_data( + self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None + ): raw_memory = raw_node.memory - response_json = self._get_llm_response(raw_memory) + response_json = self._get_llm_response(raw_memory, custom_tags) chat_read_nodes = [] for memory_i_raw in response_json.get("memory list", []): try: @@ -352,6 +366,7 @@ def _process_transfer_chat_data(self, 
raw_node: TextualMemoryItem): node_i = self._make_memory_item( value=memory_i_raw.get("value", ""), info={ + **(raw_node.metadata.info or {}), "user_id": raw_node.metadata.user_id, "session_id": raw_node.metadata.session_id, }, @@ -439,7 +454,10 @@ def get_memory( return memory_list def fine_transfer_simple_mem( - self, input_memories: list[TextualMemoryItem], type: str + self, + input_memories: list[TextualMemoryItem], + type: str, + custom_tags: list[str] | None = None, ) -> list[list[TextualMemoryItem]]: if not input_memories: return [] @@ -456,7 +474,7 @@ def fine_transfer_simple_mem( # Process Q&A pairs concurrently with context propagation with ContextThreadPoolExecutor() as executor: futures = [ - executor.submit(processing_func, scene_data_info) + executor.submit(processing_func, scene_data_info, custom_tags) for scene_data_info in input_memories ] for future in concurrent.futures.as_completed(futures): @@ -549,11 +567,18 @@ def _process_doc_data(self, scene_data_info, info, **kwargs): if mode == "fast": raise NotImplementedError chunks = self.chunker.chunk(scene_data_info["text"]) + custom_tags = info.pop("custom_tags", None) messages = [] for chunk in chunks: lang = detect_lang(chunk.text) template = PROMPT_DICT["doc"][lang] prompt = template.replace("{chunk_text}", chunk.text) + custom_tags_prompt = ( + PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags)) + if custom_tags + else "" + ) + prompt = prompt.replace("{custom_tags_prompt}", custom_tags_prompt) message = [{"role": "user", "content": prompt}] messages.append(message) @@ -588,7 +613,9 @@ def _process_doc_data(self, scene_data_info, info, **kwargs): logger.error(f"[DocReader] Future task failed: {e}") return doc_nodes - def _process_transfer_doc_data(self, raw_node: TextualMemoryItem): + def _process_transfer_doc_data( + self, raw_node: TextualMemoryItem, custom_tags: list[str] | None = None + ): raise NotImplementedError def parse_json_result(self, response_text: str) -> dict: diff --git a/src/memos/mem_reader/strategy_struct.py b/src/memos/mem_reader/strategy_struct.py index 1fc21461e..21be8bc39 100644 --- a/src/memos/mem_reader/strategy_struct.py +++ b/src/memos/mem_reader/strategy_struct.py @@ -8,6 +8,8 @@ from memos.mem_reader.simple_struct import SimpleStructMemReader, detect_lang from memos.parsers.factory import ParserFactory from memos.templates.mem_reader_prompts import ( + CUSTOM_TAGS_INSTRUCTION, + CUSTOM_TAGS_INSTRUCTION_ZH, SIMPLE_STRUCT_DOC_READER_PROMPT, SIMPLE_STRUCT_DOC_READER_PROMPT_ZH, SIMPLE_STRUCT_MEM_READER_EXAMPLE, @@ -28,6 +30,7 @@ "zh_example": SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH, }, "doc": {"en": SIMPLE_STRUCT_DOC_READER_PROMPT, "zh": SIMPLE_STRUCT_DOC_READER_PROMPT_ZH}, + "custom_tags": {"en": CUSTOM_TAGS_INSTRUCTION, "zh": CUSTOM_TAGS_INSTRUCTION_ZH}, } @@ -38,11 +41,19 @@ def __init__(self, config: StrategyStructMemReaderConfig): super().__init__(config) self.chat_chunker = config.chat_chunker["config"] - def _get_llm_response(self, mem_str: str) -> dict: + def _get_llm_response(self, mem_str: str, custom_tags: list[str] | None) -> dict: lang = detect_lang(mem_str) template = STRATEGY_PROMPT_DICT["chat"][lang] examples = STRATEGY_PROMPT_DICT["chat"][f"{lang}_example"] prompt = template.replace("${conversation}", mem_str) + + custom_tags_prompt = ( + STRATEGY_PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags)) + if custom_tags + else "" + ) + prompt = prompt.replace("${custom_tags_prompt}", custom_tags_prompt) + if 
self.config.remove_prompt_example: # TODO unused prompt = prompt.replace(examples, "") messages = [{"role": "user", "content": prompt}] diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 135e6927f..01f44f71f 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -365,6 +365,7 @@ def process_message(message: ScheduleMessageItem): mem_cube = self.current_mem_cube content = message.content user_name = message.user_name + info = message.info or {} # Parse the memory IDs from content mem_ids = json.loads(content) if isinstance(content, str) else content @@ -388,6 +389,7 @@ def process_message(message: ScheduleMessageItem): mem_cube_id=mem_cube_id, text_mem=text_mem, user_name=user_name, + custom_tags=info.get("custom_tags", None), ) logger.info( @@ -412,6 +414,7 @@ def _process_memories_with_reader( mem_cube_id: str, text_mem: TreeTextMemory, user_name: str, + custom_tags: list[str] | None = None, ) -> None: """ Process memories using mem_reader for enhanced memory processing. @@ -421,6 +424,7 @@ def _process_memories_with_reader( user_id: User ID mem_cube_id: Memory cube ID text_mem: Text memory instance + custom_tags: Optional list of custom tags for memory processing """ try: # Get the mem_reader from the parent MOSCore @@ -464,6 +468,7 @@ def _process_memories_with_reader( processed_memories = self.mem_reader.fine_transfer_simple_mem( memory_items, type="chat", + custom_tags=custom_tags, ) except Exception as e: logger.warning(f"{e}: Fail to transfer mem: {memory_items}") diff --git a/src/memos/memories/textual/general.py b/src/memos/memories/textual/general.py index d71a86d2e..f56b2028d 100644 --- a/src/memos/memories/textual/general.py +++ b/src/memos/memories/textual/general.py @@ -56,7 +56,9 @@ def extract(self, messages: MessageList) -> list[TextualMemoryItem]: [message["role"] + ":" + message["content"] for message in messages] ) - prompt = SIMPLE_STRUCT_MEM_READER_PROMPT.replace("${conversation}", str_messages) + prompt = SIMPLE_STRUCT_MEM_READER_PROMPT.replace("${conversation}", str_messages).replace( + "${custom_tags_prompt}", "" + ) messages = [{"role": "user", "content": prompt}] response_text = self.extractor_llm.generate(messages) response_json = self.parse_json_result(response_text) diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 837ccd58a..c5cf71316 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -401,6 +401,7 @@ def _schedule_memory_tasks( content=json.dumps(mem_ids), timestamp=datetime.utcnow(), user_name=self.cube_id, + info=add_req.info, ) self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_read]) self.logger.info( @@ -526,6 +527,7 @@ def _process_text_mem( type="chat", info={ **(add_req.info or {}), + "custom_tags": add_req.custom_tags, "user_id": add_req.user_id, "session_id": target_session_id, }, diff --git a/src/memos/templates/mem_reader_prompts.py b/src/memos/templates/mem_reader_prompts.py index ec6812743..3223e4694 100644 --- a/src/memos/templates/mem_reader_prompts.py +++ b/src/memos/templates/mem_reader_prompts.py @@ -39,6 +39,8 @@ - The `key`, `value`, `tags`, `summary` fields must match the mostly used language of the input conversation. **如果输入是中文,请输出中文** - Keep `memory_type` in English. +${custom_tags_prompt} + Example: Conversation: user: [June 26, 2025 at 3:00 PM]: Hi Jerry! 
Yesterday at 3 PM I had a meeting with my team about the new project. @@ -132,6 +134,8 @@ - `key`、`value`、`tags`、`summary` 字段必须与输入对话的主要语言一致。**如果输入是中文,请输出中文** - `memory_type` 保持英文。 +${custom_tags_prompt} + 示例: 对话: user: [2025年6月26日下午3:00]:嗨Jerry!昨天下午3点我和团队开了个会,讨论新项目。 @@ -212,6 +216,8 @@ - The `key`, `value`, `tags`, `summary` fields must match the mostly used language of the input document summaries. **如果输入是中文,请输出中文** - Keep `memory_type` in English. +{custom_tags_prompt} + Document chunk: {chunk_text} @@ -250,6 +256,8 @@ - `key`、`value`、`tags` 字段必须与输入文档摘要的主要语言一致。**如果输入是中文,请输出中文** - `memory_type` 保持英文。 +{custom_tags_prompt} + 文档片段: {chunk_text} @@ -341,3 +349,13 @@ } """ + + +CUSTOM_TAGS_INSTRUCTION = """Output tags can refer to the following tags: +{custom_tags} +You can choose tags from the above list that are relevant to the memory. Additionally, you can freely add tags based on the content of the memory.""" + + +CUSTOM_TAGS_INSTRUCTION_ZH = """输出tags可以参考下列标签: +{custom_tags} +你可以选择与memory相关的在上述列表中可以加入tags,同时你可以根据memory的内容自由添加tags。""" diff --git a/src/memos/templates/mem_reader_strategy_prompts.py b/src/memos/templates/mem_reader_strategy_prompts.py index ba4a00d0a..21421e30b 100644 --- a/src/memos/templates/mem_reader_strategy_prompts.py +++ b/src/memos/templates/mem_reader_strategy_prompts.py @@ -61,6 +61,7 @@ Language rules: - The `key`, `value`, `tags`, `summary` and `memory_type` fields must be in English. +${custom_tags_prompt} Example: Conversations: @@ -157,6 +158,7 @@ 语言规则: - `key`、`value`、`tags`、`summary` 、`memory_type` 字段必须输出中文 +${custom_tags_prompt} 示例1: 对话: diff --git a/src/memos/types/types.py b/src/memos/types/types.py index b8efc6208..a843bff8a 100644 --- a/src/memos/types/types.py +++ b/src/memos/types/types.py @@ -27,8 +27,10 @@ "MessageDict", "MessageList", "MessageRole", + "MessagesType", "Permission", "PermissionDict", + "RawMessageList", "UserContext", ] @@ -40,7 +42,7 @@ # Message structure class MessageDict(TypedDict, total=False): - """Typed dictionary for chat message dictionaries.""" + """Typed dictionary for chat message dictionaries, will (Deprecate), use ChatCompletionMessageParam instead.""" role: MessageRole content: str From 0d68e889611317b18df2ff16a4baa64a35bd5a9a Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Tue, 25 Nov 2025 20:14:27 +0800 Subject: [PATCH 17/25] modify model product fileds --- src/memos/api/product_models.py | 60 ++++++++++++++++++++++++++------- 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 180d299ab..24d2e7c3a 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -73,7 +73,6 @@ class ChatRequest(BaseRequest): user_id: str = Field(..., description="User ID") query: str = Field(..., description="Chat query message") - mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") readable_cube_ids: list[str] | None = Field( None, description="List of cube IDs user can read for multi-cube chat" ) @@ -82,19 +81,37 @@ class ChatRequest(BaseRequest): ) history: MessageList | None = Field(None, description="Chat history") mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") - internet_search: bool = Field(True, description="Whether to use internet search") system_prompt: str | None = Field(None, description="Base system prompt to use for chat") top_k: int = Field(10, description="Number of results to return") - threshold: float = Field(0.5, description="Threshold for 
filtering references") session_id: str | None = Field(None, description="Session ID for soft-filtering memories") include_preference: bool = Field(True, description="Whether to handle preference memory") pref_top_k: int = Field(6, description="Number of preference results to return") - filter: dict[str, Any] | None = Field(None, description="Filter for the memory") model_name_or_path: str | None = Field(None, description="Model name to use for chat") max_tokens: int | None = Field(None, description="Max tokens to generate") temperature: float | None = Field(None, description="Temperature for sampling") top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter") add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat") + + # ==== Filter conditions ==== + filter: dict[str, Any] | None = Field( + None, + description=""" + Filter for the memory, example: + { + "`and` or `or`": [ + {"id": "uuid-xxx"}, + {"created_at": {"gt": "2024-01-01"}}, + ] + } + """, + ) + + # ==== Extended capabilities ==== + internet_search: bool = Field(True, description="Whether to use internet search") + threshold: float = Field(0.5, description="Threshold for filtering references") + + # ==== Backward compatibility ==== + mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") moscube: bool = Field( False, description="(Deprecated) Whether to use legacy MemOSCube pipeline" ) @@ -271,6 +288,7 @@ class APISearchRequest(BaseRequest): filter: dict[str, Any] | None = Field( None, description=""" + Filter for the memory, example: { "`and` or `or`": [ {"id": "uuid-xxx"}, @@ -548,7 +566,6 @@ class APIChatCompleteRequest(BaseRequest): user_id: str = Field(..., description="User ID") query: str = Field(..., description="Chat query message") - mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") readable_cube_ids: list[str] | None = Field( None, description="List of cube IDs user can read for multi-cube chat" ) @@ -556,23 +573,42 @@ class APIChatCompleteRequest(BaseRequest): None, description="List of cube IDs user can write for multi-cube chat" ) history: MessageList | None = Field(None, description="Chat history") - internet_search: bool = Field(False, description="Whether to use internet search") - system_prompt: str | None = Field(None, description="Base system prompt to use for chat") mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") + system_prompt: str | None = Field(None, description="Base system prompt to use for chat") top_k: int = Field(10, description="Number of results to return") - threshold: float = Field(0.5, description="Threshold for filtering references") - session_id: str | None = Field( - "default_session", description="Session ID for soft-filtering memories" - ) + session_id: str | None = Field(None, description="Session ID for soft-filtering memories") include_preference: bool = Field(True, description="Whether to handle preference memory") pref_top_k: int = Field(6, description="Number of preference results to return") - filter: dict[str, Any] | None = Field(None, description="Filter for the memory") model_name_or_path: str | None = Field(None, description="Model name to use for chat") max_tokens: int | None = Field(None, description="Max tokens to generate") temperature: float | None = Field(None, description="Temperature for sampling") top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter") add_message_on_answer: bool = Field(True, 
description="Add dialogs to memory after chat") + # ==== Filter conditions ==== + filter: dict[str, Any] | None = Field( + None, + description=""" + Filter for the memory, example: + { + "`and` or `or`": [ + {"id": "uuid-xxx"}, + {"created_at": {"gt": "2024-01-01"}}, + ] + } + """, + ) + + # ==== Extended capabilities ==== + internet_search: bool = Field(True, description="Whether to use internet search") + threshold: float = Field(0.5, description="Threshold for filtering references") + + # ==== Backward compatibility ==== + mem_cube_id: str | None = Field(None, description="Cube ID to use for chat") + moscube: bool = Field( + False, description="(Deprecated) Whether to use legacy MemOSCube pipeline" + ) + class AddStatusRequest(BaseRequest): """Request model for checking add status.""" From 4964d2bbd2dfc628571b860af17f70b9c426f22d Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Wed, 26 Nov 2025 10:13:16 +0800 Subject: [PATCH 18/25] fix get api bug --- src/memos/api/handlers/memory_handler.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index c47a3cf83..689e2b16b 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -4,7 +4,7 @@ This module handles retrieving all memories or specific subgraphs based on queries. """ -from typing import Any, Literal +from typing import TYPE_CHECKING, Any, Literal from memos.api.handlers.formatters_handler import format_memory_item from memos.api.product_models import ( @@ -24,6 +24,10 @@ ) +if TYPE_CHECKING: + from memos.memories.textual.preference import TextualMemoryItem + + logger = get_logger(__name__) @@ -161,17 +165,20 @@ def handle_get_subgraph( def handle_get_memories(get_mem_req: GetMemoryRequest, naive_mem_cube: Any) -> GetMemoryResponse: # TODO: Implement get memory with filter memories = naive_mem_cube.text_mem.get_all(user_name=get_mem_req.mem_cube_id)["nodes"] - filter_params: dict[str, Any] = {} - if get_mem_req.user_id is not None: - filter_params["user_id"] = get_mem_req.user_id - if get_mem_req.mem_cube_id is not None: - filter_params["mem_cube_id"] = get_mem_req.mem_cube_id - preferences = naive_mem_cube.pref_mem.get_memory_by_filter(filter_params) + preferences: list[TextualMemoryItem] = [] + if get_mem_req.include_preference: + filter_params: dict[str, Any] = {} + if get_mem_req.user_id is not None: + filter_params["user_id"] = get_mem_req.user_id + if get_mem_req.mem_cube_id is not None: + filter_params["mem_cube_id"] = get_mem_req.mem_cube_id + preferences = naive_mem_cube.pref_mem.get_memory_by_filter(filter_params) + preferences = [format_memory_item(mem) for mem in preferences] return GetMemoryResponse( message="Memories retrieved successfully", data={ "text_mem": memories, - "pref_mem": [format_memory_item(mem) for mem in preferences], + "pref_mem": preferences, }, ) From e4eb9db201d696fa46f756a82feb1bcf9cff1b1a Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Wed, 26 Nov 2025 14:35:25 +0800 Subject: [PATCH 19/25] fix bug --- src/memos/mem_reader/simple_struct.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 342aee952..29ce49d90 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -295,6 +295,9 @@ def _iter_chat_windows(self, scene_data_info, max_tokens=None, overlap=200): def _process_chat_data(self, 
scene_data_info, info, **kwargs): mode = kwargs.get("mode", "fine") windows = list(self._iter_chat_windows(scene_data_info)) + custom_tags = info.pop( + "custom_tags", None + ) # msut pop here, avoid add to info, only used in sync fine mode if mode == "fast": logger.debug("Using unified Fast Mode") @@ -324,7 +327,6 @@ def _build_fast_node(w): else: logger.debug("Using unified Fine Mode") chat_read_nodes = [] - custom_tags = info.pop("custom_tags", None) for w in windows: resp = self._get_llm_response(w["text"], custom_tags) for m in resp.get("memory list", []): @@ -353,6 +355,7 @@ def _process_transfer_chat_data( ): raw_memory = raw_node.memory response_json = self._get_llm_response(raw_memory, custom_tags) + chat_read_nodes = [] for memory_i_raw in response_json.get("memory list", []): try: From c39fda48f77a50d689b8523abac25414f7fdd091 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Thu, 27 Nov 2025 19:38:26 +0800 Subject: [PATCH 20/25] fix bug in pref add info --- src/memos/mem_reader/read_multi_model/utils.py | 2 +- .../textual/prefer_text_memory/extractor.py | 17 +++++++++++++++-- src/memos/multi_mem_cube/single_cube.py | 1 + 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/memos/mem_reader/read_multi_model/utils.py b/src/memos/mem_reader/read_multi_model/utils.py index e42a564e4..c14710650 100644 --- a/src/memos/mem_reader/read_multi_model/utils.py +++ b/src/memos/mem_reader/read_multi_model/utils.py @@ -67,7 +67,7 @@ def _is_message_list(obj): return True -def coerce_scene_data(scene_data, scene_type: str) -> list[MessagesType]: +def coerce_scene_data(scene_data: SceneDataInput, scene_type: str) -> list[MessagesType]: """ Normalize ANY allowed SceneDataInput into: list[MessagesType]. Supports: diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index d5eab2aec..72daa31cd 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -113,7 +113,11 @@ def _process_single_chunk_explicit( vector_info = { "embedding": self.embedder.embed([pref["context_summary"]])[0], } - extract_info = {**basic_info, **pref, **vector_info, **info} + + inner_keys = set(PreferenceTextualMemoryMetadata.model_fields.keys()) + inner_info = {k: v for k, v in info.items() if k in inner_keys} + user_info = {k: v for k, v in info.items() if k not in inner_keys} + extract_info = {**basic_info, **pref, **vector_info, **inner_info, "info": user_info} metadata = PreferenceTextualMemoryMetadata( type=msg_type, preference_type="explicit_preference", **extract_info @@ -140,7 +144,16 @@ def _process_single_chunk_implicit( "embedding": self.embedder.embed([implicit_pref["context_summary"]])[0], } - extract_info = {**basic_info, **implicit_pref, **vector_info, **info} + inner_keys = set(PreferenceTextualMemoryMetadata.model_fields.keys()) + inner_info = {k: v for k, v in info.items() if k in inner_keys} + user_info = {k: v for k, v in info.items() if k not in inner_keys} + extract_info = { + **basic_info, + **implicit_pref, + **vector_info, + **inner_info, + "info": user_info, + } metadata = PreferenceTextualMemoryMetadata( type=msg_type, preference_type="implicit_preference", **extract_info diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 8f4a25a0b..681950954 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -491,6 +491,7 @@ def 
_process_pref_mem( label=PREF_ADD_LABEL, content=json.dumps(messages_list), timestamp=datetime.utcnow(), + info=add_req.info, ) self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_pref]) self.logger.info(f"[SingleCubeView] cube={self.cube_id} Submitted PREF_ADD async") From 5c40498d3be4b15f6558bf65cc40d7a28206cfe2 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 28 Nov 2025 10:35:12 +0800 Subject: [PATCH 21/25] modify code --- .../textual/prefer_text_memory/extractor.py | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index 72daa31cd..f23135754 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -114,10 +114,7 @@ def _process_single_chunk_explicit( "embedding": self.embedder.embed([pref["context_summary"]])[0], } - inner_keys = set(PreferenceTextualMemoryMetadata.model_fields.keys()) - inner_info = {k: v for k, v in info.items() if k in inner_keys} - user_info = {k: v for k, v in info.items() if k not in inner_keys} - extract_info = {**basic_info, **pref, **vector_info, **inner_info, "info": user_info} + extract_info = {**basic_info, **pref, **vector_info, **info} metadata = PreferenceTextualMemoryMetadata( type=msg_type, preference_type="explicit_preference", **extract_info @@ -144,16 +141,7 @@ def _process_single_chunk_implicit( "embedding": self.embedder.embed([implicit_pref["context_summary"]])[0], } - inner_keys = set(PreferenceTextualMemoryMetadata.model_fields.keys()) - inner_info = {k: v for k, v in info.items() if k in inner_keys} - user_info = {k: v for k, v in info.items() if k not in inner_keys} - extract_info = { - **basic_info, - **implicit_pref, - **vector_info, - **inner_info, - "info": user_info, - } + extract_info = {**basic_info, **implicit_pref, **vector_info, **info} metadata = PreferenceTextualMemoryMetadata( type=msg_type, preference_type="implicit_preference", **extract_info From d8bcfbe37240a922deb8cde1aea5fb1dbdc4c0df Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 28 Nov 2025 10:50:45 +0800 Subject: [PATCH 22/25] fix bug in get and delete --- src/memos/api/handlers/memory_handler.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index 689e2b16b..f0f3f39b9 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -15,6 +15,7 @@ MemoryResponse, ) from memos.log import get_logger +from memos.mem_cube.navie import NaiveMemCube from memos.mem_os.utils.format_utils import ( convert_graph_to_tree_forworkmem, ensure_unique_tree_ids, @@ -162,11 +163,13 @@ def handle_get_subgraph( raise -def handle_get_memories(get_mem_req: GetMemoryRequest, naive_mem_cube: Any) -> GetMemoryResponse: +def handle_get_memories( + get_mem_req: GetMemoryRequest, naive_mem_cube: NaiveMemCube +) -> GetMemoryResponse: # TODO: Implement get memory with filter memories = naive_mem_cube.text_mem.get_all(user_name=get_mem_req.mem_cube_id)["nodes"] preferences: list[TextualMemoryItem] = [] - if get_mem_req.include_preference: + if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None: filter_params: dict[str, Any] = {} if get_mem_req.user_id is not None: filter_params["user_id"] = get_mem_req.user_id @@ -183,10 +186,11 @@ def 
handle_get_memories(get_mem_req: GetMemoryRequest, naive_mem_cube: Any) -> G ) -def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: Any): +def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: NaiveMemCube): try: naive_mem_cube.text_mem.delete(delete_mem_req.memory_ids) - naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids) + if naive_mem_cube.pref_mem is not None: + naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids) except Exception as e: logger.error(f"Failed to delete memories: {e}", exc_info=True) return DeleteMemoryResponse( From 4bb10aef26934cba791b1c8c0a99571ec7d19c24 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 28 Nov 2025 15:26:37 +0800 Subject: [PATCH 23/25] modify delete code --- src/memos/api/handlers/memory_handler.py | 37 +++++++++++++++++++++--- src/memos/api/product_models.py | 4 ++- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index f0f3f39b9..b6c4a40b5 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -187,15 +187,44 @@ def handle_get_memories( def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: NaiveMemCube): + # Validate that only one of memory_ids, file_ids, or filter is provided + provided_params = [ + delete_mem_req.memory_ids is not None, + delete_mem_req.file_ids is not None, + delete_mem_req.filter is not None, + ] + if sum(provided_params) != 1: + return DeleteMemoryResponse( + message="Exactly one of memory_ids, file_ids, or filter must be provided", + data={"status": "failure"}, + ) + try: - naive_mem_cube.text_mem.delete(delete_mem_req.memory_ids) - if naive_mem_cube.pref_mem is not None: - naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids) + if delete_mem_req.memory_ids is not None: + naive_mem_cube.text_mem.delete(delete_mem_req.memory_ids) + if naive_mem_cube.pref_mem is not None: + naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids) + elif delete_mem_req.file_ids is not None: + # TODO: Implement deletion by file_ids + # Need to find memory_ids associated with file_ids and delete them + logger.warning("Deletion by file_ids not implemented yet") + return DeleteMemoryResponse( + message="Deletion by file_ids not implemented yet", + data={"status": "failure"}, + ) + elif delete_mem_req.filter is not None: + # TODO: Implement deletion by filter + # Need to find memories matching filter and delete them + logger.warning("Deletion by filter not implemented yet") + return DeleteMemoryResponse( + message="Deletion by filter not implemented yet", + data={"status": "failure"}, + ) except Exception as e: logger.error(f"Failed to delete memories: {e}", exc_info=True) return DeleteMemoryResponse( message="Failed to delete memories", - data="failure", + data={"status": "failure"}, ) return DeleteMemoryResponse( message="Memories deleted successfully", diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 5aa617d6e..ceede3e05 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -690,7 +690,9 @@ class GetMemoryRequest(BaseRequest): class DeleteMemoryRequest(BaseRequest): """Request model for deleting memories.""" - memory_ids: list[str] = Field(..., description="Memory IDs") + memory_ids: list[str] | None = Field(None, description="Memory IDs") + file_ids: list[str] | None = Field(None, description="File IDs") + filter: dict[str, Any] | None = 
Field(None, description="Filter for the memory") class SuggestionRequest(BaseRequest): From b664d55a04bc93ec6d2c65e60dd0ac208d884bd2 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Fri, 28 Nov 2025 16:29:26 +0800 Subject: [PATCH 24/25] new package --- docker/requirements.txt | 1 + poetry.lock | 4 ++-- pyproject.toml | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/docker/requirements.txt b/docker/requirements.txt index 873cb4d22..d3268edae 100644 --- a/docker/requirements.txt +++ b/docker/requirements.txt @@ -159,3 +159,4 @@ websockets==15.0.1 xlrd==2.0.2 xlsxwriter==3.2.5 prometheus-client==0.23.1 +pymilvus==2.5.12 diff --git a/poetry.lock b/poetry.lock index e5e3bc1bd..40d0f6210 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.2.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "absl-py" @@ -6421,4 +6421,4 @@ tree-mem = ["neo4j", "schedule"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<4.0" -content-hash = "a98b5ddffb4c031342ef1314a93666460ce0903e207bc79d23478b80a99b7f40" +content-hash = "95e737a53fed62215bcb523c162e19ed67ffc745e27fa081bc3da5e356eba086" diff --git a/pyproject.toml b/pyproject.toml index 7efd77d80..9a8db2694 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,7 +91,7 @@ mem-reader = [ # PreferenceTextMemory pref-mem = [ - "pymilvus (>=2.6.1,<3.0.0)", # Milvus Vector DB + "pymilvus (>=2.5.12,<3.0.0)", # Milvus Vector DB "datasketch (>=1.6.5,<2.0.0)", # MinHash library ] From 9d83a8d4d7bbe56f768d60c82e8b23a25a75e3e5 Mon Sep 17 00:00:00 2001 From: "yuan.wang" Date: Sat, 29 Nov 2025 17:33:05 +0800 Subject: [PATCH 25/25] fix bug --- src/memos/api/handlers/memory_handler.py | 4 ++-- .../textual/prefer_text_memory/extractor.py | 14 +++++++++----- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index b6c4a40b5..83f51428c 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -180,8 +180,8 @@ def handle_get_memories( return GetMemoryResponse( message="Memories retrieved successfully", data={ - "text_mem": memories, - "pref_mem": preferences, + "text_mem": [{"cube_id": get_mem_req.mem_cube_id, "memories": memories}], + "pref_mem": [{"cube_id": get_mem_req.mem_cube_id, "memories": preferences}], }, ) diff --git a/src/memos/memories/textual/prefer_text_memory/extractor.py b/src/memos/memories/textual/prefer_text_memory/extractor.py index cf40f109a..e105500bd 100644 --- a/src/memos/memories/textual/prefer_text_memory/extractor.py +++ b/src/memos/memories/textual/prefer_text_memory/extractor.py @@ -9,7 +9,11 @@ from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_reader.simple_struct import detect_lang -from memos.memories.textual.item import PreferenceTextualMemoryMetadata, TextualMemoryItem +from memos.memories.textual.item import ( + PreferenceTextualMemoryMetadata, + TextualMemoryItem, + list_all_fields, +) from memos.memories.textual.prefer_text_memory.spliter import Splitter from memos.memories.textual.prefer_text_memory.utils import convert_messages_to_string from memos.templates.prefer_complete_prompt import ( @@ -114,8 +118,8 @@ def _process_single_chunk_explicit( vector_info = { "embedding": self.embedder.embed([pref["context_summary"]])[0], } - - extract_info = {**basic_info, 
**pref, **vector_info, **info}
+        user_info = {k: v for k, v in info.items() if k not in list_all_fields()}
+        extract_info = {**basic_info, **pref, **vector_info, **info, "info": user_info}
 
         metadata = PreferenceTextualMemoryMetadata(
             type=msg_type, preference_type="explicit_preference", **extract_info
@@ -143,8 +147,8 @@ def _process_single_chunk_implicit(
         vector_info = {
             "embedding": self.embedder.embed([pref["context_summary"]])[0],
         }
-
-        extract_info = {**basic_info, **pref, **vector_info, **info}
+        user_info = {k: v for k, v in info.items() if k not in list_all_fields()}
+        extract_info = {**basic_info, **pref, **vector_info, **info, "info": user_info}
 
         metadata = PreferenceTextualMemoryMetadata(
             type=msg_type, preference_type="implicit_preference", **extract_info