From f34a2f03ad5b3df303b9e04d3189c345144ef7fc Mon Sep 17 00:00:00 2001
From: fridayL
Date: Tue, 11 Nov 2025 21:03:45 +0800
Subject: [PATCH 1/7] feat: reorganize code

---
 src/memos/api/handlers/__init__.py            |  55 ++
 src/memos/api/handlers/add_handlers.py        | 253 +++
 src/memos/api/handlers/chat_handlers.py       | 130 +++
 src/memos/api/handlers/component_init.py      | 270 ++++++
 src/memos/api/handlers/config_builders.py     | 153 +++
 src/memos/api/handlers/formatters_handlers.py |  92 ++
 src/memos/api/handlers/scheduler_handlers.py  | 220 +++++
 src/memos/api/handlers/search_handlers.py     | 261 ++++++
 src/memos/api/routers/server_router.py        | 873 ++----------------
 9 files changed, 1509 insertions(+), 798 deletions(-)
 create mode 100644 src/memos/api/handlers/__init__.py
 create mode 100644 src/memos/api/handlers/add_handlers.py
 create mode 100644 src/memos/api/handlers/chat_handlers.py
 create mode 100644 src/memos/api/handlers/component_init.py
 create mode 100644 src/memos/api/handlers/config_builders.py
 create mode 100644 src/memos/api/handlers/formatters_handlers.py
 create mode 100644 src/memos/api/handlers/scheduler_handlers.py
 create mode 100644 src/memos/api/handlers/search_handlers.py

diff --git a/src/memos/api/handlers/__init__.py b/src/memos/api/handlers/__init__.py
new file mode 100644
index 00000000..c57a33d3
--- /dev/null
+++ b/src/memos/api/handlers/__init__.py
@@ -0,0 +1,55 @@
+"""
+Server handlers for MemOS API routers.
+
+This package contains modular handlers for the server_router, responsible for:
+- Building component configurations (config_builders)
+- Initializing server components (component_init)
+- Formatting data for API responses (formatters_handlers)
+- Handling search, add, scheduler, and chat operations
+"""
+
+# Import handler submodules eagerly and re-export the public helpers.
+from memos.api import handlers
+from memos.api.handlers import add_handlers, chat_handlers, formatters_handlers, scheduler_handlers, search_handlers
+from memos.api.handlers.component_init import init_server
+from memos.api.handlers.config_builders import (
+    build_embedder_config,
+    build_graph_db_config,
+    build_internet_retriever_config,
+    build_llm_config,
+    build_mem_reader_config,
+    build_pref_adder_config,
+    build_pref_extractor_config,
+    build_pref_retriever_config,
+    build_reranker_config,
+    build_vec_db_config,
+)
+from memos.api.handlers.formatters_handlers import (
+    format_memory_item,
+    post_process_pref_mem,
+    to_iter,
+)
+
+
+__all__ = [
+    "add_handlers",
+    "build_embedder_config",
+    "build_graph_db_config",
+    "build_internet_retriever_config",
+    "build_llm_config",
+    "build_mem_reader_config",
+    "build_pref_adder_config",
+    "build_pref_extractor_config",
+    "build_pref_retriever_config",
+    "build_reranker_config",
+    "build_vec_db_config",
+    "chat_handlers",
+    "format_memory_item",
+    "formatters_handlers",
+    "handlers",
+    "init_server",
+    "post_process_pref_mem",
+    "scheduler_handlers",
+    "search_handlers",
+    "to_iter",
+]
diff --git a/src/memos/api/handlers/add_handlers.py b/src/memos/api/handlers/add_handlers.py
new file mode 100644
index 00000000..2392d85d
--- /dev/null
+++ b/src/memos/api/handlers/add_handlers.py
@@ -0,0 +1,253 @@
+"""
+Add handler for memory addition functionality.
+
+This module handles adding new memories to the system, supporting both
+text and preference memory additions with optional async processing.
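+
+Example (illustrative; assumes components created by component_init.init_server):
+
+    response = handle_add_memories(
+        add_req=add_req,
+        naive_mem_cube=naive_mem_cube,
+        mem_reader=mem_reader,
+        mem_scheduler=mem_scheduler,
+    )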
+""" + +import json +import os + +from datetime import datetime +from typing import Any + +from memos.api.product_models import APIADDRequest, MemoryResponse +from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger +from memos.mem_scheduler.schemas.general_schemas import ( + ADD_LABEL, + MEM_READ_LABEL, + PREF_ADD_LABEL, +) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.types import UserContext + + +logger = get_logger(__name__) + + +def _process_text_mem( + add_req: APIADDRequest, + user_context: UserContext, + naive_mem_cube: Any, + mem_reader: Any, + mem_scheduler: Any, +) -> list[dict[str, str]]: + """ + Process and add text memories. + + Extracts memories from messages and adds them to the text memory system. + Handles both sync and async modes. + + Args: + add_req: Add memory request + user_context: User context with IDs + naive_mem_cube: Memory cube instance + mem_reader: Memory reader for extraction + mem_scheduler: Scheduler for async tasks + + Returns: + List of formatted memory responses + """ + target_session_id = add_req.session_id or "default_session" + + # Determine sync mode + try: + sync_mode = getattr(naive_mem_cube.text_mem, "mode", "sync") + except Exception: + sync_mode = "sync" + + logger.info(f"Processing text memory with mode: {sync_mode}") + + memories_local = mem_reader.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + mode="fast" if sync_mode == "async" else "fine", + ) + flattened_local = [mm for m in memories_local for mm in m] + logger.info(f"Memory extraction completed for user {add_req.user_id}") + + mem_ids_local: list[str] = naive_mem_cube.text_mem.add( + flattened_local, + user_name=user_context.mem_cube_id, + ) + logger.info( + f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " + f"in session {add_req.session_id}: {mem_ids_local}" + ) + + # Handle async/sync scheduling + if sync_mode == "async": + try: + message_item_read = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=naive_mem_cube, + label=MEM_READ_LABEL, + content=json.dumps(mem_ids_local), + timestamp=datetime.utcnow(), + user_name=add_req.mem_cube_id, + ) + mem_scheduler.submit_messages(messages=[message_item_read]) + logger.info(f"Submitted async memory read task: {json.dumps(mem_ids_local)}") + except Exception as e: + logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True) + else: + message_item_add = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=naive_mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids_local), + timestamp=datetime.utcnow(), + user_name=add_req.mem_cube_id, + ) + mem_scheduler.submit_messages(messages=[message_item_add]) + + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.memory_type, + } + for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) + ] + + +def _process_pref_mem( + add_req: APIADDRequest, + user_context: UserContext, + naive_mem_cube: Any, + mem_scheduler: Any, +) -> list[dict[str, str]]: + """ + Process and add preference memories. + + Extracts preferences from messages and adds them to the preference memory system. + Handles both sync and async modes. 
+ + Args: + add_req: Add memory request + user_context: User context with IDs + naive_mem_cube: Memory cube instance + mem_scheduler: Scheduler for async tasks + + Returns: + List of formatted preference responses + """ + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + + # Determine sync mode + try: + sync_mode = getattr(naive_mem_cube.text_mem, "mode", "sync") + except Exception: + sync_mode = "sync" + + target_session_id = add_req.session_id or "default_session" + + # Follow async behavior: enqueue when async + if sync_mode == "async": + try: + messages_list = [add_req.messages] + message_item_pref = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=naive_mem_cube, + label=PREF_ADD_LABEL, + content=json.dumps(messages_list), + timestamp=datetime.utcnow(), + ) + mem_scheduler.submit_messages(messages=[message_item_pref]) + logger.info("Submitted preference add to scheduler (async mode)") + except Exception as e: + logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True) + return [] + else: + pref_memories_local = naive_mem_cube.pref_mem.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + "mem_cube_id": add_req.mem_cube_id, + }, + ) + pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local) + logger.info( + f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " + f"in session {add_req.session_id}: {pref_ids_local}" + ) + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.preference_type, + } + for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) + ] + + +def handle_add_memories( + add_req: APIADDRequest, + naive_mem_cube: Any, + mem_reader: Any, + mem_scheduler: Any, +) -> MemoryResponse: + """ + Main handler for add memories endpoint. + + Orchestrates the addition of both text and preference memories, + supporting concurrent processing. + + Args: + add_req: Add memory request + naive_mem_cube: Memory cube instance + mem_reader: Memory reader for extraction + mem_scheduler: Scheduler for async tasks + + Returns: + MemoryResponse with added memory information + """ + # Create UserContext object + user_context = UserContext( + user_id=add_req.user_id, + mem_cube_id=add_req.mem_cube_id, + session_id=add_req.session_id or "default_session", + ) + + logger.info(f"Add Req is: {add_req}") + + with ContextThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit( + _process_text_mem, + add_req, + user_context, + naive_mem_cube, + mem_reader, + mem_scheduler, + ) + pref_future = executor.submit( + _process_pref_mem, + add_req, + user_context, + naive_mem_cube, + mem_scheduler, + ) + text_response_data = text_future.result() + pref_response_data = pref_future.result() + + logger.info(f"add_memories Text response data: {text_response_data}") + logger.info(f"add_memories Pref response data: {pref_response_data}") + + return MemoryResponse( + message="Memory added successfully", + data=text_response_data + pref_response_data, + ) diff --git a/src/memos/api/handlers/chat_handlers.py b/src/memos/api/handlers/chat_handlers.py new file mode 100644 index 00000000..6cd7a135 --- /dev/null +++ b/src/memos/api/handlers/chat_handlers.py @@ -0,0 +1,130 @@ +""" +Chat handler for chat functionality. + +This module handles both streaming and complete chat responses with memory integration. 
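+
+Example (illustrative; assumes a configured MOSServer and memory cube):
+
+    result = handle_chat_complete(chat_req, mos_server, naive_mem_cube)
+    answer = result["data"]["response"]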
+""" + +import json +import traceback + +from typing import Any + +from fastapi import HTTPException +from fastapi.responses import StreamingResponse + +from memos.api.product_models import APIChatCompleteRequest, ChatRequest +from memos.log import get_logger +from memos.mem_os.product_server import MOSServer + + +logger = get_logger(__name__) + + +def handle_chat_complete( + chat_req: APIChatCompleteRequest, + mos_server: Any, + naive_mem_cube: Any, +) -> dict[str, Any]: + """ + Chat with MemOS for complete response (non-streaming). + + Processes a chat request and returns the complete response with references. + + Args: + chat_req: Chat complete request + mos_server: MOS server instance + naive_mem_cube: Memory cube instance + + Returns: + Dictionary with response and references + + Raises: + HTTPException: If chat fails + """ + try: + # Collect all responses from the server + content, references = mos_server.chat( + query=chat_req.query, + user_id=chat_req.user_id, + cube_id=chat_req.mem_cube_id, + mem_cube=naive_mem_cube, + history=chat_req.history, + internet_search=chat_req.internet_search, + moscube=chat_req.moscube, + base_prompt=chat_req.base_prompt, + top_k=chat_req.top_k, + threshold=chat_req.threshold, + session_id=chat_req.session_id, + ) + + # Return the complete response + return { + "message": "Chat completed successfully", + "data": {"response": content, "references": references}, + } + + except ValueError as err: + raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err + except Exception as err: + logger.error(f"Failed to complete chat: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err + + +def handle_chat_stream( + chat_req: ChatRequest, + mos_server: MOSServer, +) -> StreamingResponse: + """ + Chat with MemOS via Server-Sent Events (SSE) stream. + + Processes a chat request and streams the response back via SSE. 
+ + Args: + chat_req: Chat stream request + mos_server: MOS server instance + + Returns: + StreamingResponse with SSE formatted chat stream + + Raises: + HTTPException: If stream initialization fails + """ + try: + + def generate_chat_response(): + """Generate chat response as SSE stream.""" + try: + # Directly yield from the generator without async wrapper + yield from mos_server.chat_with_references( + query=chat_req.query, + user_id=chat_req.user_id, + cube_id=chat_req.mem_cube_id, + history=chat_req.history, + internet_search=chat_req.internet_search, + moscube=chat_req.moscube, + session_id=chat_req.session_id, + ) + + except Exception as e: + logger.error(f"Error in chat stream: {e}") + error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n" + yield error_data + + return StreamingResponse( + generate_chat_response(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "text/event-stream", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "*", + "Access-Control-Allow-Methods": "*", + }, + ) + + except ValueError as err: + raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err + except Exception as err: + logger.error(f"Failed to start chat stream: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py new file mode 100644 index 00000000..b89e879b --- /dev/null +++ b/src/memos/api/handlers/component_init.py @@ -0,0 +1,270 @@ +""" +Server component initialization module. + +This module handles the initialization of all MemOS server components +including databases, LLMs, memory systems, and schedulers. 
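+
+Example (illustrative; init_server documents the full return order):
+
+    graph_db, mem_reader, llm, *rest = init_server()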
+""" + +from typing import TYPE_CHECKING, Any + +from memos.api.config import APIConfig +from memos.api.handlers.config_builders import ( + build_embedder_config, + build_graph_db_config, + build_internet_retriever_config, + build_llm_config, + build_mem_reader_config, + build_pref_adder_config, + build_pref_extractor_config, + build_pref_retriever_config, + build_reranker_config, + build_vec_db_config, +) +from memos.configs.mem_scheduler import SchedulerConfigFactory +from memos.embedders.factory import EmbedderFactory +from memos.graph_dbs.factory import GraphStoreFactory +from memos.llms.factory import LLMFactory +from memos.log import get_logger +from memos.mem_cube.navie import NaiveMemCube +from memos.mem_os.product_server import MOSServer +from memos.mem_reader.factory import MemReaderFactory +from memos.mem_scheduler.orm_modules.base_model import BaseDBManager +from memos.mem_scheduler.scheduler_factory import SchedulerFactory +from memos.memories.textual.prefer_text_memory.factory import ( + AdderFactory, + ExtractorFactory, + RetrieverFactory, +) +from memos.memories.textual.simple_preference import SimplePreferenceTextMemory +from memos.memories.textual.simple_tree import SimpleTreeTextMemory +from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( + InternetRetrieverFactory, +) +from memos.reranker.factory import RerankerFactory +from memos.vec_dbs.factory import VecDBFactory + + +if TYPE_CHECKING: + from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler + +logger = get_logger(__name__) + + +def _get_default_memory_size(cube_config: Any) -> dict[str, int]: + """ + Get default memory size configuration. + + Attempts to retrieve memory size from cube config, falls back to defaults + if not found. + + Args: + cube_config: The cube configuration object + + Returns: + Dictionary with memory sizes for different memory types + """ + return getattr(cube_config.text_mem.config, "memory_size", None) or { + "WorkingMemory": 20, + "LongTermMemory": 1500, + "UserMemory": 480, + } + + +def init_server() -> tuple[Any, ...]: + """ + Initialize all server components and configurations. 
+ + This function orchestrates the creation and initialization of all components + required by the MemOS server, including: + - Database connections (graph DB, vector DB) + - Language models and embedders + - Memory systems (text, preference) + - Scheduler and related modules + + Returns: + A tuple containing all initialized components in this order: + ( + graph_db, + mem_reader, + llm, + embedder, + reranker, + internet_retriever, + memory_manager, + default_cube_config, + mos_server, + mem_scheduler, + naive_mem_cube, + api_module, + vector_db, + pref_extractor, + pref_adder, + pref_retriever, + text_mem, + pref_mem, + ) + """ + logger.info("Initializing MemOS server components...") + + # Get default cube configuration + default_cube_config = APIConfig.get_default_cube_config() + + # Build component configurations + graph_db_config = build_graph_db_config() + llm_config = build_llm_config() + embedder_config = build_embedder_config() + mem_reader_config = build_mem_reader_config() + reranker_config = build_reranker_config() + internet_retriever_config = build_internet_retriever_config() + vector_db_config = build_vec_db_config() + pref_extractor_config = build_pref_extractor_config() + pref_adder_config = build_pref_adder_config() + pref_retriever_config = build_pref_retriever_config() + + logger.debug("Component configurations built successfully") + + # Create component instances + graph_db = GraphStoreFactory.from_config(graph_db_config) + vector_db = VecDBFactory.from_config(vector_db_config) + llm = LLMFactory.from_config(llm_config) + embedder = EmbedderFactory.from_config(embedder_config) + mem_reader = MemReaderFactory.from_config(mem_reader_config) + reranker = RerankerFactory.from_config(reranker_config) + internet_retriever = InternetRetrieverFactory.from_config( + internet_retriever_config, embedder=embedder + ) + + logger.debug("Core components instantiated") + + # Initialize memory manager + memory_manager = MemoryManager( + graph_db, + embedder, + llm, + memory_size=_get_default_memory_size(default_cube_config), + is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False), + ) + + logger.debug("Memory manager initialized") + + # Initialize text memory + text_mem = SimpleTreeTextMemory( + llm=llm, + embedder=embedder, + mem_reader=mem_reader, + graph_db=graph_db, + reranker=reranker, + memory_manager=memory_manager, + config=default_cube_config.text_mem.config, + internet_retriever=internet_retriever, + ) + + logger.debug("Text memory initialized") + + # Initialize preference memory components + pref_extractor = ExtractorFactory.from_config( + config_factory=pref_extractor_config, + llm_provider=llm, + embedder=embedder, + vector_db=vector_db, + ) + + pref_adder = AdderFactory.from_config( + config_factory=pref_adder_config, + llm_provider=llm, + embedder=embedder, + vector_db=vector_db, + text_mem=text_mem, + ) + + pref_retriever = RetrieverFactory.from_config( + config_factory=pref_retriever_config, + llm_provider=llm, + embedder=embedder, + reranker=reranker, + vector_db=vector_db, + ) + + logger.debug("Preference memory components initialized") + + # Initialize preference memory + pref_mem = SimplePreferenceTextMemory( + extractor_llm=llm, + vector_db=vector_db, + embedder=embedder, + reranker=reranker, + extractor=pref_extractor, + adder=pref_adder, + retriever=pref_retriever, + ) + + logger.debug("Preference memory initialized") + + # Initialize MOS Server + mos_server = MOSServer( + mem_reader=mem_reader, + llm=llm, + online_bot=False, + ) + + 
logger.debug("MOS server initialized") + + # Create MemCube with pre-initialized memory instances + naive_mem_cube = NaiveMemCube( + text_mem=text_mem, + pref_mem=pref_mem, + act_mem=None, + para_mem=None, + ) + + logger.debug("MemCube created") + + # Initialize Scheduler + scheduler_config_dict = APIConfig.get_scheduler_config() + scheduler_config = SchedulerConfigFactory( + backend="optimized_scheduler", config=scheduler_config_dict + ) + mem_scheduler: OptimizedScheduler = SchedulerFactory.from_config(scheduler_config) + mem_scheduler.initialize_modules( + chat_llm=llm, + process_llm=mem_reader.llm, + db_engine=BaseDBManager.create_default_sqlite_engine(), + mem_reader=mem_reader, + ) + mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube) + + logger.debug("Scheduler initialized") + + # Initialize SchedulerAPIModule + api_module = mem_scheduler.api_module + + # Start scheduler if enabled + import os + + if os.getenv("API_SCHEDULER_ON", True): + mem_scheduler.start() + logger.info("Scheduler started") + + logger.info("MemOS server components initialized successfully") + + return ( + graph_db, + mem_reader, + llm, + embedder, + reranker, + internet_retriever, + memory_manager, + default_cube_config, + mos_server, + mem_scheduler, + naive_mem_cube, + api_module, + vector_db, + pref_extractor, + pref_adder, + pref_retriever, + text_mem, + pref_mem, + ) diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py new file mode 100644 index 00000000..9f510add --- /dev/null +++ b/src/memos/api/handlers/config_builders.py @@ -0,0 +1,153 @@ +""" +Configuration builders for server handlers. + +This module contains factory functions that build configurations for various +components used by the MemOS server. Each function constructs and validates +a configuration dictionary using the appropriate ConfigFactory. +""" + +import os + +from typing import Any + +from memos.api.config import APIConfig +from memos.configs.embedder import EmbedderConfigFactory +from memos.configs.graph_db import GraphDBConfigFactory +from memos.configs.internet_retriever import InternetRetrieverConfigFactory +from memos.configs.llm import LLMConfigFactory +from memos.configs.mem_reader import MemReaderConfigFactory +from memos.configs.reranker import RerankerConfigFactory +from memos.configs.vec_db import VectorDBConfigFactory +from memos.memories.textual.prefer_text_memory.config import ( + AdderConfigFactory, + ExtractorConfigFactory, + RetrieverConfigFactory, +) + + +def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: + """ + Build graph database configuration. + + Args: + user_id: User ID for configuration context (default: "default") + + Returns: + Validated graph database configuration dictionary + """ + graph_db_backend_map = { + "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), + "neo4j": APIConfig.get_neo4j_config(user_id=user_id), + "nebular": APIConfig.get_nebular_config(user_id=user_id), + "polardb": APIConfig.get_polardb_config(user_id=user_id), + } + + graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() + return GraphDBConfigFactory.model_validate( + { + "backend": graph_db_backend, + "config": graph_db_backend_map[graph_db_backend], + } + ) + + +def build_vec_db_config() -> dict[str, Any]: + """ + Build vector database configuration. 
+ + Returns: + Validated vector database configuration dictionary + """ + return VectorDBConfigFactory.model_validate( + { + "backend": "milvus", + "config": APIConfig.get_milvus_config(), + } + ) + + +def build_llm_config() -> dict[str, Any]: + """ + Build LLM configuration. + + Returns: + Validated LLM configuration dictionary + """ + return LLMConfigFactory.model_validate( + { + "backend": "openai", + "config": APIConfig.get_openai_config(), + } + ) + + +def build_embedder_config() -> dict[str, Any]: + """ + Build embedder configuration. + + Returns: + Validated embedder configuration dictionary + """ + return EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config()) + + +def build_mem_reader_config() -> dict[str, Any]: + """ + Build memory reader configuration. + + Returns: + Validated memory reader configuration dictionary + """ + return MemReaderConfigFactory.model_validate( + APIConfig.get_product_default_config()["mem_reader"] + ) + + +def build_reranker_config() -> dict[str, Any]: + """ + Build reranker configuration. + + Returns: + Validated reranker configuration dictionary + """ + return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) + + +def build_internet_retriever_config() -> dict[str, Any]: + """ + Build internet retriever configuration. + + Returns: + Validated internet retriever configuration dictionary + """ + return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config()) + + +def build_pref_extractor_config() -> dict[str, Any]: + """ + Build preference memory extractor configuration. + + Returns: + Validated extractor configuration dictionary + """ + return ExtractorConfigFactory.model_validate({"backend": "naive", "config": {}}) + + +def build_pref_adder_config() -> dict[str, Any]: + """ + Build preference memory adder configuration. + + Returns: + Validated adder configuration dictionary + """ + return AdderConfigFactory.model_validate({"backend": "naive", "config": {}}) + + +def build_pref_retriever_config() -> dict[str, Any]: + """ + Build preference memory retriever configuration. + + Returns: + Validated retriever configuration dictionary + """ + return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}}) diff --git a/src/memos/api/handlers/formatters_handlers.py b/src/memos/api/handlers/formatters_handlers.py new file mode 100644 index 00000000..976be87b --- /dev/null +++ b/src/memos/api/handlers/formatters_handlers.py @@ -0,0 +1,92 @@ +""" +Data formatting utilities for server handlers. + +This module provides utility functions for formatting and transforming data +structures for API responses, including memory items and preferences. +""" + +from typing import Any + +from memos.templates.instruction_completion import instruct_completion + + +def to_iter(running: Any) -> list[Any]: + """ + Normalize running tasks to a list of task objects. + + Handles different input types and converts them to a consistent list format. + + Args: + running: Running tasks, can be None, dict, or iterable + + Returns: + List of task objects + """ + if running is None: + return [] + if isinstance(running, dict): + return list(running.values()) + return list(running) if running else [] + + +def format_memory_item(memory_data: Any) -> dict[str, Any]: + """ + Format a single memory item for API response. + + Transforms a memory object into a dictionary with metadata properly + structured for API consumption. 
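+
+    For example, a memory with id "3f2a9c1e-..." (illustrative) gets
+    ref_id "[3f2a9c1e]"; embedding, sources, and usage are cleared, and
+    id, ref_id, and the memory text are mirrored into metadata.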
+ + Args: + memory_data: Memory object to format + + Returns: + Formatted memory dictionary with ref_id and metadata + """ + memory = memory_data.model_dump() + memory_id = memory["id"] + ref_id = f"[{memory_id.split('-')[0]}]" + + memory["ref_id"] = ref_id + memory["metadata"]["embedding"] = [] + memory["metadata"]["sources"] = [] + memory["metadata"]["usage"] = [] + memory["metadata"]["ref_id"] = ref_id + memory["metadata"]["id"] = memory_id + memory["metadata"]["memory"] = memory["memory"] + + return memory + + +def post_process_pref_mem( + memories_result: dict[str, Any], + pref_formatted_mem: list[dict[str, Any]], + mem_cube_id: str, + include_preference: bool, +) -> dict[str, Any]: + """ + Post-process preference memory results. + + Adds formatted preference memories to the result dictionary and generates + instruction completion strings if preferences are included. + + Args: + memories_result: Result dictionary to update + pref_formatted_mem: List of formatted preference memories + mem_cube_id: Memory cube ID + include_preference: Whether to include preferences in result + + Returns: + Updated memories_result dictionary + """ + if include_preference: + memories_result["pref_mem"].append( + { + "cube_id": mem_cube_id, + "memories": pref_formatted_mem, + } + ) + pref_instruction, pref_note = instruct_completion(pref_formatted_mem) + memories_result["pref_string"] = pref_instruction + memories_result["pref_note"] = pref_note + + return memories_result diff --git a/src/memos/api/handlers/scheduler_handlers.py b/src/memos/api/handlers/scheduler_handlers.py new file mode 100644 index 00000000..f621e891 --- /dev/null +++ b/src/memos/api/handlers/scheduler_handlers.py @@ -0,0 +1,220 @@ +""" +Scheduler handler for scheduler management functionality. + +This module handles all scheduler-related operations including status checking, +waiting for idle state, and streaming progress updates. +""" + +import json +import time +import traceback + +from typing import Any + +from fastapi import HTTPException +from fastapi.responses import StreamingResponse + +from memos.api.handlers.formatters_handlers import to_iter +from memos.log import get_logger + + +logger = get_logger(__name__) + + +def handle_scheduler_status( + user_name: str | None = None, + mem_scheduler: Any | None = None, + instance_id: str = "", +) -> dict[str, Any]: + """ + Get scheduler running status. + + Retrieves the number of running tasks for a specific user or globally. 
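+
+    Example user-scoped response (illustrative values):
+
+        {"message": "ok", "data": {"scope": "user", "user_name": "u1",
+         "running_tasks": 2, "timestamp": 1700000000.0, "instance_id": "..."}}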
+ + Args: + user_name: Optional specific user name to filter tasks + mem_scheduler: Scheduler instance + instance_id: Instance ID for response + + Returns: + Dictionary with status information + + Raises: + HTTPException: If status retrieval fails + """ + try: + if user_name: + running = mem_scheduler.dispatcher.get_running_tasks( + lambda task: getattr(task, "mem_cube_id", None) == user_name + ) + tasks_iter = to_iter(running) + running_count = len(tasks_iter) + return { + "message": "ok", + "data": { + "scope": "user", + "user_name": user_name, + "running_tasks": running_count, + "timestamp": time.time(), + "instance_id": instance_id, + }, + } + else: + running_all = mem_scheduler.dispatcher.get_running_tasks(lambda _t: True) + tasks_iter = to_iter(running_all) + running_count = len(tasks_iter) + + task_count_per_user: dict[str, int] = {} + for task in tasks_iter: + cube = getattr(task, "mem_cube_id", "unknown") + task_count_per_user[cube] = task_count_per_user.get(cube, 0) + 1 + + try: + metrics_snapshot = mem_scheduler.dispatcher.metrics.snapshot() + except Exception: + metrics_snapshot = {} + + return { + "message": "ok", + "data": { + "scope": "global", + "running_tasks": running_count, + "task_count_per_user": task_count_per_user, + "timestamp": time.time(), + "instance_id": instance_id, + "metrics": metrics_snapshot, + }, + } + except Exception as err: + logger.error("Failed to get scheduler status: %s", traceback.format_exc()) + raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err + + +def handle_scheduler_wait( + user_name: str, + timeout_seconds: float = 120.0, + poll_interval: float = 0.2, + mem_scheduler: Any | None = None, +) -> dict[str, Any]: + """ + Wait until scheduler is idle for a specific user. + + Blocks until scheduler has no running tasks for the given user, or timeout. + + Args: + user_name: User name to wait for + timeout_seconds: Maximum wait time in seconds + poll_interval: Polling interval in seconds + mem_scheduler: Scheduler instance + + Returns: + Dictionary with wait result and statistics + + Raises: + HTTPException: If wait operation fails + """ + start = time.time() + try: + while True: + running = mem_scheduler.dispatcher.get_running_tasks( + lambda task: task.mem_cube_id == user_name + ) + running_count = len(running) + elapsed = time.time() - start + + # success -> scheduler is idle + if running_count == 0: + return { + "message": "idle", + "data": { + "running_tasks": 0, + "waited_seconds": round(elapsed, 3), + "timed_out": False, + "user_name": user_name, + }, + } + + # timeout check + if elapsed > timeout_seconds: + return { + "message": "timeout", + "data": { + "running_tasks": running_count, + "waited_seconds": round(elapsed, 3), + "timed_out": True, + "user_name": user_name, + }, + } + + time.sleep(poll_interval) + + except Exception as err: + logger.error("Failed while waiting for scheduler: %s", traceback.format_exc()) + raise HTTPException(status_code=500, detail="Failed while waiting for scheduler") from err + + +def handle_scheduler_wait_stream( + user_name: str, + timeout_seconds: float = 120.0, + poll_interval: float = 0.2, + mem_scheduler: Any | None = None, + instance_id: str = "", +) -> StreamingResponse: + """ + Stream scheduler progress via Server-Sent Events (SSE). + + Emits periodic heartbeat frames while tasks are running, then final + status frame indicating idle or timeout. 
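+
+    Each frame is "data: <json>" followed by a blank line, carrying user_name,
+    running_tasks, elapsed_seconds, status ("running", "idle", or "timeout"),
+    and instance_id; the final frame also includes a timed_out flag.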
+ + Args: + user_name: User name to monitor + timeout_seconds: Maximum stream duration in seconds + poll_interval: Polling interval between updates + mem_scheduler: Scheduler instance + instance_id: Instance ID for response + + Returns: + StreamingResponse with SSE formatted progress updates + + Example: + curl -N "http://localhost:8000/product/scheduler/wait/stream?timeout_seconds=10" + """ + + def event_generator(): + start = time.time() + try: + while True: + running = mem_scheduler.dispatcher.get_running_tasks( + lambda task: task.mem_cube_id == user_name + ) + running_count = len(running) + elapsed = time.time() - start + + payload = { + "user_name": user_name, + "running_tasks": running_count, + "elapsed_seconds": round(elapsed, 3), + "status": "running" if running_count > 0 else "idle", + "instance_id": instance_id, + } + yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n" + + if running_count == 0 or elapsed > timeout_seconds: + payload["status"] = "idle" if running_count == 0 else "timeout" + payload["timed_out"] = running_count > 0 + yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n" + break + + time.sleep(poll_interval) + + except Exception as e: + err_payload = { + "status": "error", + "detail": "stream_failed", + "exception": str(e), + "user_name": user_name, + } + logger.error(f"Scheduler stream error for {user_name}: {traceback.format_exc()}") + yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n" + + return StreamingResponse(event_generator(), media_type="text/event-stream") diff --git a/src/memos/api/handlers/search_handlers.py b/src/memos/api/handlers/search_handlers.py new file mode 100644 index 00000000..a1a03d34 --- /dev/null +++ b/src/memos/api/handlers/search_handlers.py @@ -0,0 +1,261 @@ +""" +Search handler for memory search functionality. + +This module handles all memory search operations including fast, fine-grained, +and mixture-based search modes. +""" + +import os +import traceback + +from typing import Any + +from memos.api.handlers.formatters_handlers import ( + format_memory_item, + post_process_pref_mem, +) +from memos.api.product_models import APISearchRequest, SearchResponse +from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger +from memos.mem_scheduler.schemas.general_schemas import SearchMode +from memos.types import MOSSearchResult, UserContext + + +logger = get_logger(__name__) + + +def fast_search_memories( + search_req: APISearchRequest, + user_context: UserContext, + naive_mem_cube: Any, +) -> list[dict[str, Any]]: + """ + Fast search memories using vector database. + + Performs a quick vector-based search for memories. 
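+
+    Runs SearchMode.FAST against the cube's text memory; when a session_id is
+    provided, it is also applied as a search filter.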
+ + Args: + search_req: Search request containing query and parameters + user_context: User context with IDs + naive_mem_cube: Memory cube instance + + Returns: + List of formatted search results + """ + target_session_id = search_req.session_id or "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + # Create MemCube and perform search + search_results = naive_mem_cube.text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=SearchMode.FAST, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, + ) + formatted_memories = [format_memory_item(data) for data in search_results] + + return formatted_memories + + +def fine_search_memories( + search_req: APISearchRequest, + user_context: UserContext, + mem_scheduler: Any, +) -> list[dict[str, Any]]: + """ + Fine-grained search memories using scheduler and retriever. + + Performs a more comprehensive search with query enhancement. + + Args: + search_req: Search request containing query and parameters + user_context: User context with IDs + mem_scheduler: Scheduler instance for advanced retrieval + + Returns: + List of formatted search results + """ + target_session_id = search_req.session_id or "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + searcher = mem_scheduler.searcher + + info = { + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + } + + fast_retrieved_memories = searcher.retrieve( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=SearchMode.FAST, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info=info, + ) + + fast_memories = searcher.post_retrieve( + retrieved_results=fast_retrieved_memories, + top_k=search_req.top_k, + user_name=user_context.mem_cube_id, + info=info, + ) + + enhanced_results, _ = mem_scheduler.retriever.enhance_memories_with_query( + query_history=[search_req.query], + memories=fast_memories, + ) + + formatted_memories = [format_memory_item(data) for data in enhanced_results] + + return formatted_memories + + +def mix_search_memories( + search_req: APISearchRequest, + user_context: UserContext, + mem_scheduler: Any, +) -> list[dict[str, Any]]: + """ + Mix search memories: fast search + async fine search. + + Combines fast initial search with asynchronous fine-grained search. + + Args: + search_req: Search request containing query and parameters + user_context: User context with IDs + mem_scheduler: Scheduler instance + + Returns: + List of formatted search results + """ + formatted_memories = mem_scheduler.mix_search_memories( + search_req=search_req, + user_context=user_context, + ) + return formatted_memories + + +def handle_search_memories( + search_req: APISearchRequest, + naive_mem_cube: Any, + mem_scheduler: Any, +) -> SearchResponse: + """ + Main handler for search memories endpoint. + + Orchestrates the search process based on the requested search mode, + supporting both text and preference memory searches. 
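+
+    When the request mode is NOT_INITIALIZED, the effective mode falls back to
+    the SEARCH_MODE environment variable (defaulting to fast search). Text and
+    preference searches run concurrently in a small thread pool.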
+ + Args: + search_req: Search request + naive_mem_cube: Memory cube instance + mem_scheduler: Scheduler instance + + Returns: + SearchResponse with formatted results + """ + # Create UserContext object + user_context = UserContext( + user_id=search_req.user_id, + mem_cube_id=search_req.mem_cube_id, + session_id=search_req.session_id or "default_session", + ) + logger.info(f"Search Req is: {search_req}") + + memories_result: MOSSearchResult = { + "text_mem": [], + "act_mem": [], + "para_mem": [], + "pref_mem": [], + "pref_note": "", + } + + if search_req.mode == SearchMode.NOT_INITIALIZED: + search_mode = os.getenv("SEARCH_MODE", SearchMode.FAST) + else: + search_mode = search_req.mode + + def _search_text(): + try: + if search_mode == SearchMode.FAST: + formatted_memories = fast_search_memories( + search_req=search_req, + user_context=user_context, + naive_mem_cube=naive_mem_cube, + ) + elif search_mode == SearchMode.FINE: + formatted_memories = fine_search_memories( + search_req=search_req, + user_context=user_context, + mem_scheduler=mem_scheduler, + ) + elif search_mode == SearchMode.MIXTURE: + formatted_memories = mix_search_memories( + search_req=search_req, + user_context=user_context, + mem_scheduler=mem_scheduler, + ) + else: + logger.error(f"Unsupported search mode: {search_mode}") + return [] + return formatted_memories + except Exception as e: + logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc()) + return [] + + def _search_pref(): + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + try: + results = naive_mem_cube.pref_mem.search( + query=search_req.query, + top_k=search_req.pref_top_k, + info={ + "user_id": search_req.user_id, + "session_id": search_req.session_id, + "chat_history": search_req.chat_history, + }, + ) + return [format_memory_item(data) for data in results] + except Exception as e: + logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) + return [] + + with ContextThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(_search_text) + pref_future = executor.submit(_search_pref) + text_formatted_memories = text_future.result() + pref_formatted_memories = pref_future.result() + + memories_result["text_mem"].append( + { + "cube_id": search_req.mem_cube_id, + "memories": text_formatted_memories, + } + ) + + memories_result = post_process_pref_mem( + memories_result, + pref_formatted_memories, + search_req.mem_cube_id, + search_req.include_preference, + ) + + logger.info(f"Search memories result: {memories_result}") + + return SearchResponse( + message="Search completed successfully", + data=memories_result, + ) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 7d9f141d..fbe6f7aa 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,315 +1,39 @@ -import json +""" +Server API Router for MemOS. + +This router provides low-level API endpoints for direct server instance operations, +including search, add, scheduler management, and chat functionalities. + +The actual implementation logic is delegated to specialized handler modules +in server_handlers package for better modularity and maintainability. 
+""" + import os import random as _random import socket -import time -import traceback - -from collections.abc import Iterable -from datetime import datetime -from typing import TYPE_CHECKING, Any -from fastapi import APIRouter, HTTPException -from fastapi.responses import StreamingResponse +from fastapi import APIRouter -from memos.api.config import APIConfig +from memos.api import handlers from memos.api.product_models import ( APIADDRequest, APIChatCompleteRequest, APISearchRequest, + ChatRequest, MemoryResponse, SearchResponse, ) -from memos.configs.embedder import EmbedderConfigFactory -from memos.configs.graph_db import GraphDBConfigFactory -from memos.configs.internet_retriever import InternetRetrieverConfigFactory -from memos.configs.llm import LLMConfigFactory -from memos.configs.mem_reader import MemReaderConfigFactory -from memos.configs.mem_scheduler import SchedulerConfigFactory -from memos.configs.reranker import RerankerConfigFactory -from memos.configs.vec_db import VectorDBConfigFactory -from memos.context.context import ContextThreadPoolExecutor -from memos.embedders.factory import EmbedderFactory -from memos.graph_dbs.factory import GraphStoreFactory -from memos.llms.factory import LLMFactory from memos.log import get_logger -from memos.mem_cube.navie import NaiveMemCube -from memos.mem_os.product_server import MOSServer -from memos.mem_reader.factory import MemReaderFactory -from memos.mem_scheduler.orm_modules.base_model import BaseDBManager -from memos.mem_scheduler.scheduler_factory import SchedulerFactory -from memos.mem_scheduler.schemas.general_schemas import ( - ADD_LABEL, - MEM_READ_LABEL, - PREF_ADD_LABEL, - SearchMode, -) -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.memories.textual.prefer_text_memory.config import ( - AdderConfigFactory, - ExtractorConfigFactory, - RetrieverConfigFactory, -) -from memos.memories.textual.prefer_text_memory.factory import ( - AdderFactory, - ExtractorFactory, - RetrieverFactory, -) -from memos.memories.textual.simple_preference import SimplePreferenceTextMemory -from memos.memories.textual.simple_tree import SimpleTreeTextMemory -from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager -from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( - InternetRetrieverFactory, -) -from memos.reranker.factory import RerankerFactory -from memos.templates.instruction_completion import instruct_completion - - -if TYPE_CHECKING: - from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler -from memos.types import MOSSearchResult, UserContext -from memos.vec_dbs.factory import VecDBFactory logger = get_logger(__name__) router = APIRouter(prefix="/product", tags=["Server API"]) -INSTANCE_ID = f"{socket.gethostname()}:{os.getpid()}:{_random.randint(1000, 9999)}" - - -def _to_iter(running: Any) -> Iterable: - """Normalize running tasks to an iterable of task objects.""" - if running is None: - return [] - if isinstance(running, dict): - return running.values() - return running # assume it's already an iterable (e.g., list) - - -def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]: - """Build graph database configuration.""" - graph_db_backend_map = { - "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), - "neo4j": APIConfig.get_neo4j_config(user_id=user_id), - "nebular": APIConfig.get_nebular_config(user_id=user_id), - "polardb": APIConfig.get_polardb_config(user_id=user_id), - } - - 
graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() - return GraphDBConfigFactory.model_validate( - { - "backend": graph_db_backend, - "config": graph_db_backend_map[graph_db_backend], - } - ) - - -def _build_vec_db_config() -> dict[str, Any]: - """Build vector database configuration.""" - return VectorDBConfigFactory.model_validate( - { - "backend": "milvus", - "config": APIConfig.get_milvus_config(), - } - ) - - -def _build_llm_config() -> dict[str, Any]: - """Build LLM configuration.""" - return LLMConfigFactory.model_validate( - { - "backend": "openai", - "config": APIConfig.get_openai_config(), - } - ) - - -def _build_embedder_config() -> dict[str, Any]: - """Build embedder configuration.""" - return EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config()) - - -def _build_mem_reader_config() -> dict[str, Any]: - """Build memory reader configuration.""" - return MemReaderConfigFactory.model_validate( - APIConfig.get_product_default_config()["mem_reader"] - ) - - -def _build_reranker_config() -> dict[str, Any]: - """Build reranker configuration.""" - return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) - - -def _build_internet_retriever_config() -> dict[str, Any]: - """Build internet retriever configuration.""" - return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config()) - - -def _build_pref_extractor_config() -> dict[str, Any]: - """Build extractor configuration.""" - return ExtractorConfigFactory.model_validate({"backend": "naive", "config": {}}) - - -def _build_pref_adder_config() -> dict[str, Any]: - """Build adder configuration.""" - return AdderConfigFactory.model_validate({"backend": "naive", "config": {}}) - - -def _build_pref_retriever_config() -> dict[str, Any]: - """Build retriever configuration.""" - return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}}) - - -def _get_default_memory_size(cube_config) -> dict[str, int]: - """Get default memory size configuration.""" - return getattr(cube_config.text_mem.config, "memory_size", None) or { - "WorkingMemory": 20, - "LongTermMemory": 1500, - "UserMemory": 480, - } - - -def init_server(): - """Initialize server components and configurations.""" - # Get default cube configuration - default_cube_config = APIConfig.get_default_cube_config() - - # Build component configurations - graph_db_config = _build_graph_db_config() - llm_config = _build_llm_config() - embedder_config = _build_embedder_config() - mem_reader_config = _build_mem_reader_config() - reranker_config = _build_reranker_config() - internet_retriever_config = _build_internet_retriever_config() - vector_db_config = _build_vec_db_config() - pref_extractor_config = _build_pref_extractor_config() - pref_adder_config = _build_pref_adder_config() - pref_retriever_config = _build_pref_retriever_config() - - # Create component instances - graph_db = GraphStoreFactory.from_config(graph_db_config) - vector_db = VecDBFactory.from_config(vector_db_config) - llm = LLMFactory.from_config(llm_config) - embedder = EmbedderFactory.from_config(embedder_config) - mem_reader = MemReaderFactory.from_config(mem_reader_config) - reranker = RerankerFactory.from_config(reranker_config) - internet_retriever = InternetRetrieverFactory.from_config( - internet_retriever_config, embedder=embedder - ) - - # Initialize memory manager - memory_manager = MemoryManager( - graph_db, - embedder, - llm, - memory_size=_get_default_memory_size(default_cube_config), - 
is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False), - ) - - # Initialize text memory - text_mem = SimpleTreeTextMemory( - llm=llm, - embedder=embedder, - mem_reader=mem_reader, - graph_db=graph_db, - reranker=reranker, - memory_manager=memory_manager, - config=default_cube_config.text_mem.config, - internet_retriever=internet_retriever, - ) - - pref_extractor = ExtractorFactory.from_config( - config_factory=pref_extractor_config, - llm_provider=llm, - embedder=embedder, - vector_db=vector_db, - ) - - pref_adder = AdderFactory.from_config( - config_factory=pref_adder_config, - llm_provider=llm, - embedder=embedder, - vector_db=vector_db, - text_mem=text_mem, - ) - - pref_retriever = RetrieverFactory.from_config( - config_factory=pref_retriever_config, - llm_provider=llm, - embedder=embedder, - reranker=reranker, - vector_db=vector_db, - ) - - # Initialize preference memory - pref_mem = SimplePreferenceTextMemory( - extractor_llm=llm, - vector_db=vector_db, - embedder=embedder, - reranker=reranker, - extractor=pref_extractor, - adder=pref_adder, - retriever=pref_retriever, - ) - - mos_server = MOSServer( - mem_reader=mem_reader, - llm=llm, - online_bot=False, - ) - - # Create MemCube with pre-initialized memory instances - naive_mem_cube = NaiveMemCube( - text_mem=text_mem, - pref_mem=pref_mem, - act_mem=None, - para_mem=None, - ) - - # Initialize Scheduler - scheduler_config_dict = APIConfig.get_scheduler_config() - scheduler_config = SchedulerConfigFactory( - backend="optimized_scheduler", config=scheduler_config_dict - ) - mem_scheduler: OptimizedScheduler = SchedulerFactory.from_config(scheduler_config) - mem_scheduler.initialize_modules( - chat_llm=llm, - process_llm=mem_reader.llm, - db_engine=BaseDBManager.create_default_sqlite_engine(), - mem_reader=mem_reader, - ) - mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube) - - # Initialize SchedulerAPIModule - api_module = mem_scheduler.api_module - - if os.getenv("API_SCHEDULER_ON", True): - mem_scheduler.start() - - return ( - graph_db, - mem_reader, - llm, - embedder, - reranker, - internet_retriever, - memory_manager, - default_cube_config, - mos_server, - mem_scheduler, - naive_mem_cube, - api_module, - vector_db, - pref_extractor, - pref_adder, - pref_retriever, - text_mem, - pref_mem, - ) +# Instance ID for identifying this server instance in logs and responses +INSTANCE_ID = f"{socket.gethostname()}:{os.getpid()}:{_random.randint(1000, 9999)}" -# Initialize global components +# Initialize all server components ( graph_db, mem_reader, @@ -329,415 +53,53 @@ def init_server(): pref_retriever, text_mem, pref_mem, -) = init_server() - - -def _format_memory_item(memory_data: Any) -> dict[str, Any]: - """Format a single memory item for API response.""" - memory = memory_data.model_dump() - memory_id = memory["id"] - ref_id = f"[{memory_id.split('-')[0]}]" +) = handlers.init_server() - memory["ref_id"] = ref_id - memory["metadata"]["embedding"] = [] - memory["metadata"]["sources"] = [] - memory["metadata"]["usage"] = [] - memory["metadata"]["ref_id"] = ref_id - memory["metadata"]["id"] = memory_id - memory["metadata"]["memory"] = memory["memory"] - return memory - - -def _post_process_pref_mem( - memories_result: list[dict[str, Any]], - pref_formatted_mem: list[dict[str, Any]], - mem_cube_id: str, - include_preference: bool, -): - if include_preference: - memories_result["pref_mem"].append( - { - "cube_id": mem_cube_id, - "memories": pref_formatted_mem, - } - ) - pref_instruction, pref_note = 
instruct_completion(pref_formatted_mem) - memories_result["pref_string"] = pref_instruction - memories_result["pref_note"] = pref_note - - return memories_result +# ============================================================================= +# Search API Endpoints +# ============================================================================= @router.post("/search", summary="Search memories", response_model=SearchResponse) def search_memories(search_req: APISearchRequest): """Search memories for a specific user.""" - # Create UserContext object - how to assign values - user_context = UserContext( - user_id=search_req.user_id, - mem_cube_id=search_req.mem_cube_id, - session_id=search_req.session_id or "default_session", - ) - logger.info(f"Search Req is: {search_req}") - memories_result: MOSSearchResult = { - "text_mem": [], - "act_mem": [], - "para_mem": [], - "pref_mem": [], - "pref_note": "", - } - if search_req.mode == SearchMode.NOT_INITIALIZED: - search_mode = os.getenv("SEARCH_MODE", SearchMode.FAST) - else: - search_mode = search_req.mode - - def _search_text(): - try: - if search_mode == SearchMode.FAST: - formatted_memories = fast_search_memories( - search_req=search_req, user_context=user_context - ) - elif search_mode == SearchMode.FINE: - formatted_memories = fine_search_memories( - search_req=search_req, user_context=user_context - ) - elif search_mode == SearchMode.MIXTURE: - formatted_memories = mix_search_memories( - search_req=search_req, user_context=user_context - ) - else: - logger.error(f"Unsupported search mode: {search_mode}") - raise HTTPException( - status_code=400, detail=f"Unsupported search mode: {search_mode}" - ) - return formatted_memories - except Exception as e: - logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc()) - return [] - - def _search_pref(): - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - try: - results = naive_mem_cube.pref_mem.search( - query=search_req.query, - top_k=search_req.pref_top_k, - info={ - "user_id": search_req.user_id, - "session_id": search_req.session_id, - "chat_history": search_req.chat_history, - }, - ) - return [_format_memory_item(data) for data in results] - except Exception as e: - logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) - return [] - - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(_search_text) - pref_future = executor.submit(_search_pref) - text_formatted_memories = text_future.result() - pref_formatted_memories = pref_future.result() - - memories_result["text_mem"].append( - { - "cube_id": search_req.mem_cube_id, - "memories": text_formatted_memories, - } - ) - - memories_result = _post_process_pref_mem( - memories_result, - pref_formatted_memories, - search_req.mem_cube_id, - search_req.include_preference, - ) - - logger.info(f"Search memories result: {memories_result}") - - return SearchResponse( - message="Search completed successfully", - data=memories_result, - ) - - -def mix_search_memories( - search_req: APISearchRequest, - user_context: UserContext, -): - """ - Mix search memories: fast search + async fine search - """ - - formatted_memories = mem_scheduler.mix_search_memories( + return handlers.search_handlers.handle_search_memories( search_req=search_req, - user_context=user_context, - ) - return formatted_memories - - -def fine_search_memories( - search_req: APISearchRequest, - user_context: UserContext, -): - target_session_id = search_req.session_id - 
if not target_session_id: - target_session_id = "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - - searcher = mem_scheduler.searcher - - info = { - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - } - - fast_retrieved_memories = searcher.retrieve( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, - mode=SearchMode.FAST, - manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, - search_filter=search_filter, - info=info, - ) - - fast_memories = searcher.post_retrieve( - retrieved_results=fast_retrieved_memories, - top_k=search_req.top_k, - user_name=user_context.mem_cube_id, - info=info, - ) - - enhanced_results, _ = mem_scheduler.retriever.enhance_memories_with_query( - query_history=[search_req.query], - memories=fast_memories, + naive_mem_cube=naive_mem_cube, + mem_scheduler=mem_scheduler, ) - formatted_memories = [_format_memory_item(data) for data in enhanced_results] - - return formatted_memories - -def fast_search_memories( - search_req: APISearchRequest, - user_context: UserContext, -): - target_session_id = search_req.session_id - if not target_session_id: - target_session_id = "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - - # Create MemCube and perform search - search_results = naive_mem_cube.text_mem.search( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, - mode=SearchMode.FAST, - manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, - search_filter=search_filter, - info={ - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - }, - ) - formatted_memories = [_format_memory_item(data) for data in search_results] - - return formatted_memories +# ============================================================================= +# Add API Endpoints +# ============================================================================= @router.post("/add", summary="Add memories", response_model=MemoryResponse) def add_memories(add_req: APIADDRequest): """Add memories for a specific user.""" - # Create UserContext object - how to assign values - user_context = UserContext( - user_id=add_req.user_id, - mem_cube_id=add_req.mem_cube_id, - session_id=add_req.session_id or "default_session", + return handlers.add_handlers.handle_add_memories( + add_req=add_req, + naive_mem_cube=naive_mem_cube, + mem_reader=mem_reader, + mem_scheduler=mem_scheduler, ) - logger.info(f"Add Req is: {add_req}") - - target_session_id = add_req.session_id - if not target_session_id: - target_session_id = "default_session" - - # If text memory backend works in async mode, submit tasks to scheduler - try: - sync_mode = getattr(naive_mem_cube.text_mem, "mode", "sync") - except Exception: - sync_mode = "sync" - logger.info(f"Add sync_mode mode is: {sync_mode}") - - def _process_text_mem() -> list[dict[str, str]]: - memories_local = mem_reader.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - }, - mode="fast" if sync_mode == "async" else "fine", - ) - flattened_local = [mm for m in memories_local for mm in m] - logger.info(f"Memory extraction completed for user {add_req.user_id}") - mem_ids_local: list[str] = naive_mem_cube.text_mem.add( - flattened_local, - 
user_name=user_context.mem_cube_id, - ) - logger.info( - f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " - f"in session {add_req.session_id}: {mem_ids_local}" - ) - if sync_mode == "async": - try: - message_item_read = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=naive_mem_cube, - label=MEM_READ_LABEL, - content=json.dumps(mem_ids_local), - timestamp=datetime.utcnow(), - user_name=add_req.mem_cube_id, - ) - mem_scheduler.submit_messages(messages=[message_item_read]) - logger.info(f"2105Submit messages!!!!!: {json.dumps(mem_ids_local)}") - except Exception as e: - logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True) - else: - message_item_add = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=naive_mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids_local), - timestamp=datetime.utcnow(), - user_name=add_req.mem_cube_id, - ) - mem_scheduler.submit_messages(messages=[message_item_add]) - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.memory_type, - } - for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) - ] - - def _process_pref_mem() -> list[dict[str, str]]: - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - # Follow async behavior similar to core.py: enqueue when async - if sync_mode == "async": - try: - messages_list = [add_req.messages] - message_item_pref = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=naive_mem_cube, - label=PREF_ADD_LABEL, - content=json.dumps(messages_list), - timestamp=datetime.utcnow(), - ) - mem_scheduler.submit_messages(messages=[message_item_pref]) - logger.info("Submitted preference add to scheduler (async mode)") - except Exception as e: - logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True) - return [] - else: - pref_memories_local = naive_mem_cube.pref_mem.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - "mem_cube_id": add_req.mem_cube_id, - }, - ) - pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local) - logger.info( - f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " - f"in session {add_req.session_id}: {pref_ids_local}" - ) - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.preference_type, - } - for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) - ] - - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(_process_text_mem) - pref_future = executor.submit(_process_pref_mem) - text_response_data = text_future.result() - pref_response_data = pref_future.result() - - logger.info(f"add_memories Text response data: {text_response_data}") - logger.info(f"add_memories Pref response data: {pref_response_data}") - - return MemoryResponse( - message="Memory added successfully", - data=text_response_data + pref_response_data, - ) + +# ============================================================================= +# Scheduler API Endpoints +# ============================================================================= @router.get("/scheduler/status", summary="Get scheduler running status") def scheduler_status(user_name: str | None = 
None): - try: - if user_name: - running = mem_scheduler.dispatcher.get_running_tasks( - lambda task: getattr(task, "mem_cube_id", None) == user_name - ) - tasks_iter = list(_to_iter(running)) - running_count = len(tasks_iter) - return { - "message": "ok", - "data": { - "scope": "user", - "user_name": user_name, - "running_tasks": running_count, - "timestamp": time.time(), - "instance_id": INSTANCE_ID, - }, - } - else: - running_all = mem_scheduler.dispatcher.get_running_tasks(lambda _t: True) - tasks_iter = list(_to_iter(running_all)) - running_count = len(tasks_iter) - - task_count_per_user: dict[str, int] = {} - for task in tasks_iter: - cube = getattr(task, "mem_cube_id", "unknown") - task_count_per_user[cube] = task_count_per_user.get(cube, 0) + 1 - - try: - metrics_snapshot = mem_scheduler.dispatcher.metrics.snapshot() - except Exception: - metrics_snapshot = {} - - return { - "message": "ok", - "data": { - "scope": "global", - "running_tasks": running_count, - "task_count_per_user": task_count_per_user, - "timestamp": time.time(), - "instance_id": INSTANCE_ID, - "metrics": metrics_snapshot, - }, - } - except Exception as err: - logger.error("Failed to get scheduler status: %s", traceback.format_exc()) - raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err + """Get scheduler running status.""" + return handlers.scheduler_handlers.handle_scheduler_status( + user_name=user_name, + mem_scheduler=mem_scheduler, + instance_id=INSTANCE_ID, + ) @router.post("/scheduler/wait", summary="Wait until scheduler is idle for a specific user") @@ -746,47 +108,13 @@ def scheduler_wait( timeout_seconds: float = 120.0, poll_interval: float = 0.2, ): - """ - Block until scheduler has no running tasks for the given user_name, or timeout. - """ - start = time.time() - try: - while True: - running = mem_scheduler.dispatcher.get_running_tasks( - lambda task: task.mem_cube_id == user_name - ) - running_count = len(running) - elapsed = time.time() - start - - # success -> scheduler is idle - if running_count == 0: - return { - "message": "idle", - "data": { - "running_tasks": 0, - "waited_seconds": round(elapsed, 3), - "timed_out": False, - "user_name": user_name, - }, - } - - # timeout check - if elapsed > timeout_seconds: - return { - "message": "timeout", - "data": { - "running_tasks": running_count, - "waited_seconds": round(elapsed, 3), - "timed_out": True, - "user_name": user_name, - }, - } - - time.sleep(poll_interval) - - except Exception as err: - logger.error("Failed while waiting for scheduler: %s", traceback.format_exc()) - raise HTTPException(status_code=500, detail="Failed while waiting for scheduler") from err + """Wait until scheduler is idle for a specific user.""" + return handlers.scheduler_handlers.handle_scheduler_wait( + user_name=user_name, + timeout_seconds=timeout_seconds, + poll_interval=poll_interval, + mem_scheduler=mem_scheduler, + ) @router.get("/scheduler/wait/stream", summary="Stream scheduler progress for a user") @@ -795,86 +123,35 @@ def scheduler_wait_stream( timeout_seconds: float = 120.0, poll_interval: float = 0.2, ): - """ - Stream scheduler progress via Server-Sent Events (SSE). - - Contract: - - We emit periodic heartbeat frames while tasks are still running. - - Each heartbeat frame is JSON, prefixed with "data: ". - - On final frame, we include status = "idle" or "timeout" and timed_out flag, - with the same semantics as /scheduler/wait. 
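The frame contract above is easy to exercise from Python; a minimal client sketch (the `requests` dependency and the `/product` mount point, taken from the curl example below, are assumptions):

    import json
    import requests  # assumed available; any streaming-capable HTTP client works

    def wait_until_idle(api_host: str, user_name: str, timeout_seconds: float = 10.0) -> dict:
        """Consume SSE frames until a terminal status ("idle"/"timeout"/"error") arrives."""
        url = f"{api_host}/product/scheduler/wait/stream"
        params = {"user_name": user_name, "timeout_seconds": timeout_seconds}
        with requests.get(url, params=params, stream=True, timeout=timeout_seconds + 5) as resp:
            for line in resp.iter_lines(decode_unicode=True):
                if not line or not line.startswith("data: "):
                    continue  # blank lines separate SSE frames
                frame = json.loads(line[len("data: "):])
                if frame.get("status") in ("idle", "timeout", "error"):
                    return frame
        return {"status": "disconnected"}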
- - Example curl: - curl -N "${API_HOST}/product/scheduler/wait/stream?timeout_seconds=10&poll_interval=0.5" - """ - - def event_generator(): - start = time.time() - try: - while True: - running = mem_scheduler.dispatcher.get_running_tasks( - lambda task: task.mem_cube_id == user_name - ) - running_count = len(running) - elapsed = time.time() - start - - payload = { - "user_name": user_name, - "running_tasks": running_count, - "elapsed_seconds": round(elapsed, 3), - "status": "running" if running_count > 0 else "idle", - "instance_id": INSTANCE_ID, - } - yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n" - - if running_count == 0 or elapsed > timeout_seconds: - payload["status"] = "idle" if running_count == 0 else "timeout" - payload["timed_out"] = running_count > 0 - yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n" - break - - time.sleep(poll_interval) - - except Exception as e: - err_payload = { - "status": "error", - "detail": "stream_failed", - "exception": str(e), - "user_name": user_name, - } - logger.error(f"Scheduler stream error for {user_name}: {traceback.format_exc()}") - yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n" - - return StreamingResponse(event_generator(), media_type="text/event-stream") + """Stream scheduler progress via Server-Sent Events (SSE).""" + return handlers.scheduler_handlers.handle_scheduler_wait_stream( + user_name=user_name, + timeout_seconds=timeout_seconds, + poll_interval=poll_interval, + mem_scheduler=mem_scheduler, + instance_id=INSTANCE_ID, + ) + + +# ============================================================================= +# Chat API Endpoints +# ============================================================================= @router.post("/chat/complete", summary="Chat with MemOS (Complete Response)") def chat_complete(chat_req: APIChatCompleteRequest): """Chat with MemOS for a specific user. Returns complete response (non-streaming).""" - try: - # Collect all responses from the generator - content, references = mos_server.chat( - query=chat_req.query, - user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, - mem_cube=naive_mem_cube, - history=chat_req.history, - internet_search=chat_req.internet_search, - moscube=chat_req.moscube, - base_prompt=chat_req.base_prompt, - top_k=chat_req.top_k, - threshold=chat_req.threshold, - session_id=chat_req.session_id, - ) - - # Return the complete response - return { - "message": "Chat completed successfully", - "data": {"response": content, "references": references}, - } - - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to start chat: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err + return handlers.chat_handlers.handle_chat_complete( + chat_req=chat_req, + mos_server=mos_server, + naive_mem_cube=naive_mem_cube, + ) + + +@router.post("/chat", summary="Chat with MemOS") +def chat(chat_req: ChatRequest): + """Chat with MemOS for a specific user. 
Returns SSE stream.""" + return handlers.chat_handlers.handle_chat_stream( + chat_req=chat_req, + mos_server=mos_server, + ) From 30f9e88995ba42914dedc15cb32398f04547059e Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 13 Nov 2025 15:13:54 +0800 Subject: [PATCH 2/7] feat: code reorg and merge API and playground --- src/memos/api/handlers/__init__.py | 25 +- src/memos/api/handlers/add_handler.py | 272 ++++++ src/memos/api/handlers/add_handlers.py | 253 ------ src/memos/api/handlers/base_handler.py | 207 +++++ src/memos/api/handlers/chat_handler.py | 830 ++++++++++++++++++ src/memos/api/handlers/chat_handlers.py | 130 --- src/memos/api/handlers/component_init.py | 1 - ...ters_handlers.py => formatters_handler.py} | 0 src/memos/api/handlers/memory_handler.py | 201 +++++ ...duler_handlers.py => scheduler_handler.py} | 2 +- src/memos/api/handlers/search_handler.py | 289 ++++++ src/memos/api/handlers/search_handlers.py | 261 ------ src/memos/api/handlers/suggestion_handler.py | 117 +++ src/memos/api/product_models.py | 3 + src/memos/api/routers/server_router.py | 141 ++- src/memos/mem_os/utils/reference_utils.py | 23 +- .../mem_scheduler/general_modules/base.py | 2 +- src/memos/mem_scheduler/general_scheduler.py | 9 - 18 files changed, 2062 insertions(+), 704 deletions(-) create mode 100644 src/memos/api/handlers/add_handler.py delete mode 100644 src/memos/api/handlers/add_handlers.py create mode 100644 src/memos/api/handlers/base_handler.py create mode 100644 src/memos/api/handlers/chat_handler.py delete mode 100644 src/memos/api/handlers/chat_handlers.py rename src/memos/api/handlers/{formatters_handlers.py => formatters_handler.py} (100%) create mode 100644 src/memos/api/handlers/memory_handler.py rename src/memos/api/handlers/{scheduler_handlers.py => scheduler_handler.py} (99%) create mode 100644 src/memos/api/handlers/search_handler.py delete mode 100644 src/memos/api/handlers/search_handlers.py create mode 100644 src/memos/api/handlers/suggestion_handler.py diff --git a/src/memos/api/handlers/__init__.py b/src/memos/api/handlers/__init__.py index c57a33d3..90347768 100644 --- a/src/memos/api/handlers/__init__.py +++ b/src/memos/api/handlers/__init__.py @@ -9,8 +9,14 @@ """ # Lazy imports to avoid circular dependencies -from memos.api import handlers -from memos.api.handlers import add_handlers, chat_handlers, scheduler_handlers, search_handlers +from memos.api.handlers import ( + add_handler, + chat_handler, + memory_handler, + scheduler_handler, + search_handler, + suggestion_handler, +) from memos.api.handlers.component_init import init_server from memos.api.handlers.config_builders import ( build_embedder_config, @@ -24,7 +30,7 @@ build_reranker_config, build_vec_db_config, ) -from memos.api.handlers.formatters_handlers import ( +from memos.api.handlers.formatters_handler import ( format_memory_item, post_process_pref_mem, to_iter, @@ -32,7 +38,7 @@ __all__ = [ - "add_handlers", + "add_handler", "build_embedder_config", "build_graph_db_config", "build_internet_retriever_config", @@ -43,13 +49,14 @@ "build_pref_retriever_config", "build_reranker_config", "build_vec_db_config", - "chat_handlers", + "chat_handler", "format_memory_item", - "formatters_handlers", - "handlers", + "formatters_handler", "init_server", + "memory_handler", "post_process_pref_mem", - "scheduler_handlers", - "search_handlers", + "scheduler_handler", + "search_handler", + "suggestion_handler", "to_iter", ] diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py new file mode 
100644 index 00000000..257d911b --- /dev/null +++ b/src/memos/api/handlers/add_handler.py @@ -0,0 +1,272 @@ +""" +Add handler for memory addition functionality (Class-based version). + +This module provides a class-based implementation of add handlers, +using dependency injection for better modularity and testability. +""" + +import json +import os + +from datetime import datetime + +from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies +from memos.api.product_models import APIADDRequest, MemoryResponse +from memos.context.context import ContextThreadPoolExecutor +from memos.mem_scheduler.schemas.general_schemas import ( + ADD_LABEL, + MEM_READ_LABEL, + PREF_ADD_LABEL, +) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.types import UserContext + + +class AddHandler(BaseHandler): + """ + Handler for memory addition operations. + + Handles both text and preference memory additions with sync/async support. + """ + + def __init__(self, dependencies: HandlerDependencies): + """ + Initialize add handler. + + Args: + dependencies: HandlerDependencies instance + """ + super().__init__(dependencies) + self._validate_dependencies("naive_mem_cube", "mem_reader", "mem_scheduler") + + def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: + """ + Main handler for add memories endpoint. + + Orchestrates the addition of both text and preference memories, + supporting concurrent processing. + + Args: + add_req: Add memory request + + Returns: + MemoryResponse with added memory information + """ + # Create UserContext object + user_context = UserContext( + user_id=add_req.user_id, + mem_cube_id=add_req.mem_cube_id, + session_id=add_req.session_id or "default_session", + ) + + self.logger.info(f"Add Req is: {add_req}") + + # Process text and preference memories in parallel + with ContextThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(self._process_text_mem, add_req, user_context) + pref_future = executor.submit(self._process_pref_mem, add_req, user_context) + + text_response_data = text_future.result() + pref_response_data = pref_future.result() + + self.logger.info(f"add_memories Text response data: {text_response_data}") + self.logger.info(f"add_memories Pref response data: {pref_response_data}") + + return MemoryResponse( + message="Memory added successfully", + data=text_response_data + pref_response_data, + ) + + def _process_text_mem( + self, + add_req: APIADDRequest, + user_context: UserContext, + ) -> list[dict[str, str]]: + """ + Process and add text memories. + + Extracts memories from messages and adds them to the text memory system. + Handles both sync and async modes. 
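Seen from a client, the whole add flow above is a single HTTP call; a hedged sketch (host, port, and route prefix are assumptions that depend on how server_router is mounted):

    import requests  # hypothetical client call against the /add route above

    payload = {
        "user_id": "u1",
        "mem_cube_id": "u1_cube",
        "session_id": "s1",
        "messages": [
            {"role": "user", "content": "I moved to Berlin."},
            {"role": "assistant", "content": "Noted!"},
        ],
    }
    resp = requests.post("http://localhost:8000/add", json=payload, timeout=30)
    print(resp.json())  # {"message": "Memory added successfully", "data": [...]}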
+ + Args: + add_req: Add memory request + user_context: User context with IDs + + Returns: + List of formatted memory responses + """ + target_session_id = add_req.session_id or "default_session" + + # Determine sync mode + sync_mode = add_req.async_mode or self._get_sync_mode() + + self.logger.info(f"Processing text memory with mode: {sync_mode}") + + # Extract memories + memories_local = self.mem_reader.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + mode="fast" if sync_mode == "async" else "fine", + ) + flattened_local = [mm for m in memories_local for mm in m] + self.logger.info(f"Memory extraction completed for user {add_req.user_id}") + + # Add memories to text_mem + mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add( + flattened_local, + user_name=user_context.mem_cube_id, + ) + self.logger.info( + f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " + f"in session {add_req.session_id}: {mem_ids_local}" + ) + + # Schedule async/sync tasks + self._schedule_memory_tasks( + add_req=add_req, + user_context=user_context, + mem_ids=mem_ids_local, + sync_mode=sync_mode, + ) + + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.memory_type, + } + for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) + ] + + def _process_pref_mem( + self, + add_req: APIADDRequest, + user_context: UserContext, + ) -> list[dict[str, str]]: + """ + Process and add preference memories. + + Extracts preferences from messages and adds them to the preference memory system. + Handles both sync and async modes. + + Args: + add_req: Add memory request + user_context: User context with IDs + + Returns: + List of formatted preference responses + """ + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + + # Determine sync mode + sync_mode = add_req.async_mode or self._get_sync_mode() + target_session_id = add_req.session_id or "default_session" + + # Follow async behavior: enqueue when async + if sync_mode == "async": + try: + messages_list = [add_req.messages] + message_item_pref = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=self.naive_mem_cube, + label=PREF_ADD_LABEL, + content=json.dumps(messages_list), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item_pref]) + self.logger.info("Submitted preference add to scheduler (async mode)") + except Exception as e: + self.logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True) + return [] + else: + # Sync mode: process immediately + pref_memories_local = self.naive_mem_cube.pref_mem.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + "mem_cube_id": add_req.mem_cube_id, + }, + ) + pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local) + self.logger.info( + f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " + f"in session {add_req.session_id}: {pref_ids_local}" + ) + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.preference_type, + } + for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) + ] + + def _get_sync_mode(self) -> str: + """ + Get synchronization mode from memory cube. 
+ + Returns: + Sync mode string ("sync" or "async") + """ + try: + return getattr(self.naive_mem_cube.text_mem, "mode", "sync") + except Exception: + return "sync" + + def _schedule_memory_tasks( + self, + add_req: APIADDRequest, + user_context: UserContext, + mem_ids: list[str], + sync_mode: str, + ) -> None: + """ + Schedule memory processing tasks based on sync mode. + + Args: + add_req: Add memory request + user_context: User context + mem_ids: List of memory IDs + sync_mode: Synchronization mode + """ + target_session_id = add_req.session_id or "default_session" + + if sync_mode == "async": + # Async mode: submit MEM_READ_LABEL task + try: + message_item_read = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=self.naive_mem_cube, + label=MEM_READ_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + user_name=add_req.mem_cube_id, + ) + self.mem_scheduler.submit_messages(messages=[message_item_read]) + self.logger.info(f"Submitted async memory read task: {json.dumps(mem_ids)}") + except Exception as e: + self.logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True) + else: + # Sync mode: submit ADD_LABEL task + message_item_add = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=self.naive_mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + user_name=add_req.mem_cube_id, + ) + self.mem_scheduler.submit_messages(messages=[message_item_add]) diff --git a/src/memos/api/handlers/add_handlers.py b/src/memos/api/handlers/add_handlers.py deleted file mode 100644 index 2392d85d..00000000 --- a/src/memos/api/handlers/add_handlers.py +++ /dev/null @@ -1,253 +0,0 @@ -""" -Add handler for memory addition functionality. - -This module handles adding new memories to the system, supporting both -text and preference memory additions with optional async processing. -""" - -import json -import os - -from datetime import datetime -from typing import Any - -from memos.api.product_models import APIADDRequest, MemoryResponse -from memos.context.context import ContextThreadPoolExecutor -from memos.log import get_logger -from memos.mem_scheduler.schemas.general_schemas import ( - ADD_LABEL, - MEM_READ_LABEL, - PREF_ADD_LABEL, -) -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.types import UserContext - - -logger = get_logger(__name__) - - -def _process_text_mem( - add_req: APIADDRequest, - user_context: UserContext, - naive_mem_cube: Any, - mem_reader: Any, - mem_scheduler: Any, -) -> list[dict[str, str]]: - """ - Process and add text memories. - - Extracts memories from messages and adds them to the text memory system. - Handles both sync and async modes. 
- - Args: - add_req: Add memory request - user_context: User context with IDs - naive_mem_cube: Memory cube instance - mem_reader: Memory reader for extraction - mem_scheduler: Scheduler for async tasks - - Returns: - List of formatted memory responses - """ - target_session_id = add_req.session_id or "default_session" - - # Determine sync mode - try: - sync_mode = getattr(naive_mem_cube.text_mem, "mode", "sync") - except Exception: - sync_mode = "sync" - - logger.info(f"Processing text memory with mode: {sync_mode}") - - memories_local = mem_reader.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - }, - mode="fast" if sync_mode == "async" else "fine", - ) - flattened_local = [mm for m in memories_local for mm in m] - logger.info(f"Memory extraction completed for user {add_req.user_id}") - - mem_ids_local: list[str] = naive_mem_cube.text_mem.add( - flattened_local, - user_name=user_context.mem_cube_id, - ) - logger.info( - f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " - f"in session {add_req.session_id}: {mem_ids_local}" - ) - - # Handle async/sync scheduling - if sync_mode == "async": - try: - message_item_read = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=naive_mem_cube, - label=MEM_READ_LABEL, - content=json.dumps(mem_ids_local), - timestamp=datetime.utcnow(), - user_name=add_req.mem_cube_id, - ) - mem_scheduler.submit_messages(messages=[message_item_read]) - logger.info(f"Submitted async memory read task: {json.dumps(mem_ids_local)}") - except Exception as e: - logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True) - else: - message_item_add = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=naive_mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids_local), - timestamp=datetime.utcnow(), - user_name=add_req.mem_cube_id, - ) - mem_scheduler.submit_messages(messages=[message_item_add]) - - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.memory_type, - } - for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) - ] - - -def _process_pref_mem( - add_req: APIADDRequest, - user_context: UserContext, - naive_mem_cube: Any, - mem_scheduler: Any, -) -> list[dict[str, str]]: - """ - Process and add preference memories. - - Extracts preferences from messages and adds them to the preference memory system. - Handles both sync and async modes. 
- - Args: - add_req: Add memory request - user_context: User context with IDs - naive_mem_cube: Memory cube instance - mem_scheduler: Scheduler for async tasks - - Returns: - List of formatted preference responses - """ - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - - # Determine sync mode - try: - sync_mode = getattr(naive_mem_cube.text_mem, "mode", "sync") - except Exception: - sync_mode = "sync" - - target_session_id = add_req.session_id or "default_session" - - # Follow async behavior: enqueue when async - if sync_mode == "async": - try: - messages_list = [add_req.messages] - message_item_pref = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=naive_mem_cube, - label=PREF_ADD_LABEL, - content=json.dumps(messages_list), - timestamp=datetime.utcnow(), - ) - mem_scheduler.submit_messages(messages=[message_item_pref]) - logger.info("Submitted preference add to scheduler (async mode)") - except Exception as e: - logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True) - return [] - else: - pref_memories_local = naive_mem_cube.pref_mem.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - "mem_cube_id": add_req.mem_cube_id, - }, - ) - pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local) - logger.info( - f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " - f"in session {add_req.session_id}: {pref_ids_local}" - ) - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.preference_type, - } - for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) - ] - - -def handle_add_memories( - add_req: APIADDRequest, - naive_mem_cube: Any, - mem_reader: Any, - mem_scheduler: Any, -) -> MemoryResponse: - """ - Main handler for add memories endpoint. - - Orchestrates the addition of both text and preference memories, - supporting concurrent processing. - - Args: - add_req: Add memory request - naive_mem_cube: Memory cube instance - mem_reader: Memory reader for extraction - mem_scheduler: Scheduler for async tasks - - Returns: - MemoryResponse with added memory information - """ - # Create UserContext object - user_context = UserContext( - user_id=add_req.user_id, - mem_cube_id=add_req.mem_cube_id, - session_id=add_req.session_id or "default_session", - ) - - logger.info(f"Add Req is: {add_req}") - - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit( - _process_text_mem, - add_req, - user_context, - naive_mem_cube, - mem_reader, - mem_scheduler, - ) - pref_future = executor.submit( - _process_pref_mem, - add_req, - user_context, - naive_mem_cube, - mem_scheduler, - ) - text_response_data = text_future.result() - pref_response_data = pref_future.result() - - logger.info(f"add_memories Text response data: {text_response_data}") - logger.info(f"add_memories Pref response data: {pref_response_data}") - - return MemoryResponse( - message="Memory added successfully", - data=text_response_data + pref_response_data, - ) diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py new file mode 100644 index 00000000..3bb8ae57 --- /dev/null +++ b/src/memos/api/handlers/base_handler.py @@ -0,0 +1,207 @@ +""" +Base handler for MemOS API handlers. 
+ +This module provides the base class for all API handlers, implementing +dependency injection and common functionality. +""" + +from typing import Any + +from memos.log import get_logger + + +logger = get_logger(__name__) + + +class HandlerDependencies: + """ + Container for handler dependencies. + + This class acts as a dependency injection container, holding all + shared resources needed by handlers. + """ + + def __init__( + self, + llm: Any | None = None, + naive_mem_cube: Any | None = None, + mem_reader: Any | None = None, + mem_scheduler: Any | None = None, + embedder: Any | None = None, + reranker: Any | None = None, + graph_db: Any | None = None, + vector_db: Any | None = None, + internet_retriever: Any | None = None, + memory_manager: Any | None = None, + mos_server: Any | None = None, + **kwargs, + ): + """ + Initialize handler dependencies. + + Args: + llm: Language model instance + naive_mem_cube: Memory cube instance + mem_reader: Memory reader instance + mem_scheduler: Scheduler instance + embedder: Embedder instance + reranker: Reranker instance + graph_db: Graph database instance + vector_db: Vector database instance + internet_retriever: Internet retriever instance + memory_manager: Memory manager instance + mos_server: MOS server instance + **kwargs: Additional dependencies + """ + self.llm = llm + self.naive_mem_cube = naive_mem_cube + self.mem_reader = mem_reader + self.mem_scheduler = mem_scheduler + self.embedder = embedder + self.reranker = reranker + self.graph_db = graph_db + self.vector_db = vector_db + self.internet_retriever = internet_retriever + self.memory_manager = memory_manager + self.mos_server = mos_server + + # Store any additional dependencies + for key, value in kwargs.items(): + setattr(self, key, value) + + @classmethod + def from_init_server(cls, *components): + """ + Create dependencies from init_server() return values. + + Args: + *components: Tuple of components returned by init_server() + + Returns: + HandlerDependencies instance + """ + ( + graph_db, + mem_reader, + llm, + embedder, + reranker, + internet_retriever, + memory_manager, + default_cube_config, + mos_server, + mem_scheduler, + naive_mem_cube, + api_module, + vector_db, + pref_extractor, + pref_adder, + pref_retriever, + text_mem, + pref_mem, + ) = components + + return cls( + llm=llm, + naive_mem_cube=naive_mem_cube, + mem_reader=mem_reader, + mem_scheduler=mem_scheduler, + embedder=embedder, + reranker=reranker, + graph_db=graph_db, + vector_db=vector_db, + internet_retriever=internet_retriever, + memory_manager=memory_manager, + mos_server=mos_server, + default_cube_config=default_cube_config, + api_module=api_module, + pref_extractor=pref_extractor, + pref_adder=pref_adder, + pref_retriever=pref_retriever, + text_mem=text_mem, + pref_mem=pref_mem, + ) + + +class BaseHandler: + """ + Base class for all API handlers. + + Provides common functionality and dependency injection for handlers. + All specific handlers should inherit from this class. + """ + + def __init__(self, dependencies: HandlerDependencies): + """ + Initialize base handler. 
+ + Args: + dependencies: HandlerDependencies instance containing all shared resources + """ + self.deps = dependencies + self.logger = get_logger(self.__class__.__name__) + + @property + def llm(self): + """Get LLM instance.""" + return self.deps.llm + + @property + def naive_mem_cube(self): + """Get memory cube instance.""" + return self.deps.naive_mem_cube + + @property + def mem_reader(self): + """Get memory reader instance.""" + return self.deps.mem_reader + + @property + def mem_scheduler(self): + """Get scheduler instance.""" + return self.deps.mem_scheduler + + @property + def embedder(self): + """Get embedder instance.""" + return self.deps.embedder + + @property + def reranker(self): + """Get reranker instance.""" + return self.deps.reranker + + @property + def graph_db(self): + """Get graph database instance.""" + return self.deps.graph_db + + @property + def vector_db(self): + """Get vector database instance.""" + return self.deps.vector_db + + @property + def mos_server(self): + """Get MOS server instance.""" + return self.deps.mos_server + + def _validate_dependencies(self, *required_deps: str) -> None: + """ + Validate that required dependencies are available. + + Args: + *required_deps: Names of required dependency attributes + + Raises: + ValueError: If any required dependency is None + """ + missing = [] + for dep_name in required_deps: + if not hasattr(self.deps, dep_name) or getattr(self.deps, dep_name) is None: + missing.append(dep_name) + + if missing: + raise ValueError( + f"{self.__class__.__name__} requires the following dependencies: {', '.join(missing)}" + ) diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py new file mode 100644 index 00000000..01a44ca7 --- /dev/null +++ b/src/memos/api/handlers/chat_handler.py @@ -0,0 +1,830 @@ +""" +Chat handler for chat functionality (Class-based version). + +This module provides a complete implementation of chat handlers, +consolidating all chat-related logic without depending on mos_server. +""" + +import asyncio +import json +import traceback + +from collections.abc import Generator +from datetime import datetime +from typing import Any, Literal + +from fastapi import HTTPException +from fastapi.responses import StreamingResponse + +from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies +from memos.api.product_models import ( + APIADDRequest, + APIChatCompleteRequest, + APISearchRequest, + ChatRequest, +) +from memos.context.context import ContextThread +from memos.mem_os.utils.format_utils import clean_json_response +from memos.mem_os.utils.reference_utils import ( + prepare_reference_data, + process_streaming_references_complete, +) +from memos.mem_scheduler.schemas.general_schemas import ( + ANSWER_LABEL, + QUERY_LABEL, + SearchMode, +) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.templates.mos_prompts import ( + FURTHER_SUGGESTION_PROMPT, + get_memos_prompt, +) +from memos.types import MessageList + + +class ChatHandler(BaseHandler): + """ + Handler for chat operations. + + Composes SearchHandler and AddHandler to provide complete chat functionality + without depending on mos_server. All chat logic is centralized here. + """ + + def __init__( + self, + dependencies: HandlerDependencies, + search_handler=None, + add_handler=None, + online_bot=None, + ): + """ + Initialize chat handler. 
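The base-class contract above is easiest to see with a toy subclass; a minimal sketch using hypothetical stub objects, assuming base_handler is importable:

    from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies

    class _StubLLM:
        def generate(self, messages):
            return "ok"

    class EchoHandler(BaseHandler):
        """Toy handler: validates its one dependency, then uses it."""

        def __init__(self, dependencies: HandlerDependencies):
            super().__init__(dependencies)
            self._validate_dependencies("llm")  # raises ValueError if llm is None

        def handle(self, messages):
            return self.llm.generate(messages)

    deps = HandlerDependencies(llm=_StubLLM())
    print(EchoHandler(deps).handle([{"role": "user", "content": "hi"}]))  # -> "ok"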
+ + Args: + dependencies: HandlerDependencies instance + search_handler: Optional SearchHandler instance (created if not provided) + add_handler: Optional AddHandler instance (created if not provided) + online_bot: Optional DingDing bot function for notifications + """ + super().__init__(dependencies) + self._validate_dependencies("llm", "naive_mem_cube", "mem_reader", "mem_scheduler") + + # Lazy import to avoid circular dependencies + if search_handler is None: + from memos.api.handlers.search_handler import SearchHandler + + search_handler = SearchHandler(dependencies) + + if add_handler is None: + from memos.api.handlers.add_handler import AddHandler + + add_handler = AddHandler(dependencies) + + self.search_handler = search_handler + self.add_handler = add_handler + self.online_bot = online_bot + + # Check if scheduler is enabled + self.enable_mem_scheduler = ( + hasattr(dependencies, "enable_mem_scheduler") and dependencies.enable_mem_scheduler + ) + + def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, Any]: + """ + Chat with MemOS for complete response (non-streaming). + + This implementation directly uses search/add handlers instead of mos_server. + + Args: + chat_req: Chat complete request + + Returns: + Dictionary with response and references + + Raises: + HTTPException: If chat fails + """ + try: + import time + + time_start = time.time() + + # Step 1: Search for relevant memories + search_req = APISearchRequest( + user_id=chat_req.user_id, + mem_cube_id=chat_req.mem_cube_id, + query=chat_req.query, + top_k=chat_req.top_k or 10, + session_id=chat_req.session_id, + mode=SearchMode.FINE, + internet_search=chat_req.internet_search, + moscube=chat_req.moscube, + chat_history=chat_req.history, + ) + + search_response = self.search_handler.handle_search_memories(search_req) + + # Extract memories from search results + memories_list = [] + if search_response.data and search_response.data.get("text_mem"): + text_mem_results = search_response.data["text_mem"] + if text_mem_results and text_mem_results[0].get("memories"): + memories_list = text_mem_results[0]["memories"] + + # Filter memories by threshold + filtered_memories = self._filter_memories_by_threshold( + memories_list, chat_req.threshold or 0.5 + ) + + # Step 2: Build system prompt + system_prompt = self._build_system_prompt(filtered_memories, chat_req.base_prompt) + + # Prepare message history + history_info = chat_req.history[-20:] if chat_req.history else [] + current_messages = [ + {"role": "system", "content": system_prompt}, + *history_info, + {"role": "user", "content": chat_req.query}, + ] + + self.logger.info("Starting to generate complete response...") + + # Step 3: Generate complete response from LLM + response = self.llm.generate(current_messages) + + time_end = time.time() + + # Step 4: Start post-chat processing asynchronously + self._start_post_chat_processing( + user_id=chat_req.user_id, + cube_id=chat_req.mem_cube_id, + session_id=chat_req.session_id or "default_session", + query=chat_req.query, + full_response=response, + system_prompt=system_prompt, + time_start=time_start, + time_end=time_end, + speed_improvement=0.0, + current_messages=current_messages, + ) + + # Return the complete response + return { + "message": "Chat completed successfully", + "data": {"response": response, "references": filtered_memories}, + } + + except ValueError as err: + raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err + except Exception as err: + self.logger.error(f"Failed 
to complete chat: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err + + def handle_chat_stream(self, chat_req: ChatRequest) -> StreamingResponse: + """ + Chat with MemOS via Server-Sent Events (SSE) stream using search/add handlers. + + This implementation directly uses search_handler and add_handler. + + Args: + chat_req: Chat stream request + + Returns: + StreamingResponse with SSE formatted chat stream + + Raises: + HTTPException: If stream initialization fails + """ + try: + + def generate_chat_response() -> Generator[str, None, None]: + """Generate chat response as SSE stream.""" + try: + import time + + time_start = time.time() + + # Step 1: Search for memories using search handler + yield f"data: {json.dumps({'type': 'status', 'data': '0'})}\n\n" + + search_req = APISearchRequest( + user_id=chat_req.user_id, + mem_cube_id=chat_req.mem_cube_id, + query=chat_req.query, + top_k=20, + session_id=chat_req.session_id, + mode=SearchMode.FINE, + internet_search=chat_req.internet_search, + moscube=chat_req.moscube, + chat_history=chat_req.history, + ) + + search_response = self.search_handler.handle_search_memories(search_req) + + yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n" + self._send_message_to_scheduler( + user_id=chat_req.user_id, + mem_cube_id=chat_req.mem_cube_id, + query=chat_req.query, + label=QUERY_LABEL, + ) + # Extract memories from search results + memories_list = [] + if search_response.data and search_response.data.get("text_mem"): + text_mem_results = search_response.data["text_mem"] + if text_mem_results and text_mem_results[0].get("memories"): + memories_list = text_mem_results[0]["memories"] + + # Filter memories by threshold + filtered_memories = self._filter_memories_by_threshold(memories_list) + + # Prepare reference data + reference = prepare_reference_data(filtered_memories) + yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" + + # Step 2: Build system prompt with memories + system_prompt = self._build_enhance_system_prompt(filtered_memories) + + # Prepare messages + history_info = chat_req.history[-20:] if chat_req.history else [] + current_messages = [ + {"role": "system", "content": system_prompt}, + *history_info, + {"role": "user", "content": chat_req.query}, + ] + + self.logger.info( + f"user_id: {chat_req.user_id}, cube_id: {chat_req.mem_cube_id}, " + f"current_system_prompt: {system_prompt}" + ) + + yield f"data: {json.dumps({'type': 'status', 'data': '2'})}\n\n" + + # Step 3: Generate streaming response from LLM + response_stream = self.llm.generate_stream(current_messages) + + # Stream the response + buffer = "" + full_response = "" + + for chunk in response_stream: + if chunk in ["", ""]: + continue + + buffer += chunk + full_response += chunk + + # Process buffer to ensure complete reference tags + processed_chunk, remaining_buffer = process_streaming_references_complete( + buffer + ) + + if processed_chunk: + chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n" + yield chunk_data + buffer = remaining_buffer + + # Process any remaining buffer + if buffer: + processed_chunk, _ = process_streaming_references_complete(buffer) + if processed_chunk: + chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n" + yield chunk_data + + # Calculate timing + time_end = time.time() + speed_improvement = round(float((len(system_prompt) / 2) * 0.0048 + 44.5), 1) + total_time 
= round(float(time_end - time_start), 1) + + yield f"data: {json.dumps({'type': 'time', 'data': {'total_time': total_time, 'speed_improvement': f'{speed_improvement}%'}})}\n\n" + + # Get further suggestion + current_messages.append({"role": "assistant", "content": full_response}) + further_suggestion = self._get_further_suggestion(current_messages) + self.logger.info(f"further_suggestion: {further_suggestion}") + yield f"data: {json.dumps({'type': 'suggestion', 'data': further_suggestion})}\n\n" + + yield f"data: {json.dumps({'type': 'end'})}\n\n" + + # Step 4: Add conversation to memory asynchronously + self._start_post_chat_processing( + user_id=chat_req.user_id, + cube_id=chat_req.mem_cube_id, + session_id=chat_req.session_id or "default_session", + query=chat_req.query, + full_response=full_response, + system_prompt=system_prompt, + time_start=time_start, + time_end=time_end, + speed_improvement=speed_improvement, + current_messages=current_messages, + ) + + except Exception as e: + self.logger.error(f"Error in chat stream: {e}", exc_info=True) + error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n" + yield error_data + + return StreamingResponse( + generate_chat_response(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "text/event-stream", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "*", + "Access-Control-Allow-Methods": "*", + }, + ) + + except ValueError as err: + raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err + except Exception as err: + self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err + + def _build_system_prompt( + self, + memories_all: list, + base_prompt: str | None = None, + tone: str = "friendly", + verbosity: str = "mid", + ) -> str: + """ + Build system prompt with memory references (for complete response). + + Args: + memories_all: List of memory items + base_prompt: Optional base prompt + tone: Tone of the prompt + verbosity: Verbosity level + + Returns: + System prompt string + """ + now = datetime.now() + formatted_date = now.strftime("%Y-%m-%d (%A)") + sys_body = get_memos_prompt( + date=formatted_date, tone=tone, verbosity=verbosity, mode="base" + ) + + # Format memories + mem_block_o, mem_block_p = self._format_mem_block(memories_all) + mem_block = mem_block_o + "\n" + mem_block_p + + prefix = (base_prompt.strip() + "\n\n") if base_prompt else "" + return ( + prefix + + sys_body + + "\n\n# Memories\n## PersonalMemory & OuterMemory (ordered)\n" + + mem_block + ) + + def _build_enhance_system_prompt( + self, + memories_list: list, + tone: str = "friendly", + verbosity: str = "mid", + ) -> str: + """ + Build enhanced system prompt with memories (for streaming response). 
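Both prompt builders render memories through _format_mem_block (defined just after); a toy illustration of the "[idx:id] :: [tag] text" lines it emits, with invented items:

    mems = [
        {"id": "9f2c1a77-0000", "memory": "Prefers green tea",
         "metadata": {"memory_type": "UserMemory"}},
        {"id": "b41e0d12-0000", "memory": "Berlin travel article",
         "metadata": {"memory_type": "OuterMemory"}},
    ]
    # _format_mem_block(mems) returns an (outer, personal) pair, roughly:
    #   outer    -> "[2:b41e0d12] :: [O] Berlin travel article\n"
    #   personal -> "[1:9f2c1a77] :: [P] Prefers green tea"
    # i.e. the UUID prefix before the first "-" becomes the citation id.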
+ + Args: + memories_list: List of memory items + tone: Tone of the prompt + verbosity: Verbosity level + + Returns: + System prompt string + """ + now = datetime.now() + formatted_date = now.strftime("%Y-%m-%d (%A)") + sys_body = get_memos_prompt( + date=formatted_date, tone=tone, verbosity=verbosity, mode="enhance" + ) + + # Format memories + mem_block_o, mem_block_p = self._format_mem_block(memories_list) + + return ( + sys_body + + "\n\n# Memories\n## PersonalMemory (ordered)\n" + + mem_block_p + + "\n## OuterMemory (ordered)\n" + + mem_block_o + ) + + def _format_mem_block( + self, memories_all: list, max_items: int = 20, max_chars_each: int = 320 + ) -> tuple[str, str]: + """ + Format memory block for prompt. + + Args: + memories_all: List of memory items + max_items: Maximum number of items to format + max_chars_each: Maximum characters per item + + Returns: + Tuple of (outer_memory_block, personal_memory_block) + """ + if not memories_all: + return "(none)", "(none)" + + lines_o = [] + lines_p = [] + + for idx, m in enumerate(memories_all[:max_items], 1): + mid = m.get("id", "").split("-")[0] if m.get("id") else f"mem_{idx}" + memory_content = m.get("memory", "") + metadata = m.get("metadata", {}) + memory_type = metadata.get("memory_type", "") + + tag = "O" if "Outer" in str(memory_type) else "P" + txt = memory_content.replace("\n", " ").strip() + if len(txt) > max_chars_each: + txt = txt[: max_chars_each - 1] + "…" + + mid = mid or f"mem_{idx}" + if tag == "O": + lines_o.append(f"[{idx}:{mid}] :: [{tag}] {txt}\n") + elif tag == "P": + lines_p.append(f"[{idx}:{mid}] :: [{tag}] {txt}") + + return "\n".join(lines_o), "\n".join(lines_p) + + def _filter_memories_by_threshold( + self, + memories: list, + threshold: float = 0.30, + min_num: int = 3, + memory_type: Literal["OuterMemory"] = "OuterMemory", + ) -> list: + """ + Filter memories by threshold and type. 
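The back-fill rule below is easiest to see on toy data; a sketch with invented relativity scores:

    mems = [
        {"memory": "m1", "metadata": {"relativity": 0.9, "memory_type": "UserMemory"}},
        {"memory": "m2", "metadata": {"relativity": 0.1, "memory_type": "UserMemory"}},
        {"memory": "m3", "metadata": {"relativity": 0.8, "memory_type": "OuterMemory"}},
    ]
    # With the defaults threshold=0.30 and min_num=3, only m1 and m3 clear the
    # threshold, so the result is back-filled from the per-type lists to reach
    # min_num and re-sorted by relativity: [m1, m3, m2].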
+ + Args: + memories: List of memory items + threshold: Relevance threshold + min_num: Minimum number of memories to keep + memory_type: Memory type to filter + + Returns: + Filtered list of memories + """ + if not memories: + return [] + + # Handle dict format (from search results) + def get_relativity(m): + if isinstance(m, dict): + return m.get("metadata", {}).get("relativity", 0.0) + return getattr(getattr(m, "metadata", None), "relativity", 0.0) + + def get_memory_type(m): + if isinstance(m, dict): + return m.get("metadata", {}).get("memory_type", "") + return getattr(getattr(m, "metadata", None), "memory_type", "") + + sorted_memories = sorted(memories, key=get_relativity, reverse=True) + filtered_person = [m for m in memories if get_memory_type(m) != memory_type] + filtered_outer = [m for m in memories if get_memory_type(m) == memory_type] + + filtered = [] + per_memory_count = 0 + + for m in sorted_memories: + if get_relativity(m) >= threshold: + if get_memory_type(m) != memory_type: + per_memory_count += 1 + filtered.append(m) + + if len(filtered) < min_num: + filtered = filtered_person[:min_num] + filtered_outer[:min_num] + else: + if per_memory_count < min_num: + filtered += filtered_person[per_memory_count:min_num] + + filtered_memory = sorted(filtered, key=get_relativity, reverse=True) + return filtered_memory + + def _get_further_suggestion( + self, + current_messages: MessageList, + ) -> list[str]: + """Get further suggestion based on current messages.""" + try: + dialogue_info = "\n".join( + [f"{msg['role']}: {msg['content']}" for msg in current_messages[-2:]] + ) + further_suggestion_prompt = FURTHER_SUGGESTION_PROMPT.format(dialogue=dialogue_info) + message_list = [{"role": "system", "content": further_suggestion_prompt}] + response = self.llm.generate(message_list) + clean_response = clean_json_response(response) + response_json = json.loads(clean_response) + return response_json["query"] + except Exception as e: + self.logger.error(f"Error getting further suggestion: {e}", exc_info=True) + return [] + + def _extract_references_from_response(self, response: str) -> tuple[str, list[dict]]: + """Extract reference information from the response and return clean text.""" + import re + + try: + references = [] + # Pattern to match [refid:memoriesID] + pattern = r"\[(\d+):([^\]]+)\]" + + matches = re.findall(pattern, response) + for ref_number, memory_id in matches: + references.append({"memory_id": memory_id, "reference_number": int(ref_number)}) + + # Remove all reference markers from the text to get clean text + clean_text = re.sub(pattern, "", response) + + # Clean up any extra whitespace that might be left after removing markers + clean_text = re.sub(r"\s+", " ", clean_text).strip() + + return clean_text, references + except Exception as e: + self.logger.error(f"Error extracting references from response: {e}", exc_info=True) + return response, [] + + def _extract_struct_data_from_history(self, chat_data: list[dict]) -> dict: + """ + Extract structured message data from chat history. 
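A worked example of the split performed below, with invented content:

    history = [
        {"role": "system", "content": "You are helpful.\n# Memories\n[1:abc] :: [P] likes tea"},
        {"role": "user", "content": "hi"},
        {"role": "assistant", "content": "hello"},
    ]
    # _extract_struct_data_from_history(history) yields:
    #   system       -> "You are helpful."
    #   memory       -> "# Memories[1:abc] :: [P] likes tea"  (parts re-joined after strip())
    #   chat_history -> []  # the trailing user/assistant pair is dropped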
+ + Args: + chat_data: List of chat messages + + Returns: + Dictionary with system, memory, and chat_history + """ + system_content = "" + memory_content = "" + chat_history = [] + + for item in chat_data: + role = item.get("role") + content = item.get("content", "") + if role == "system": + parts = content.split("# Memories", 1) + system_content = parts[0].strip() + if len(parts) > 1: + memory_content = "# Memories" + parts[1].strip() + elif role in ("user", "assistant"): + chat_history.append({"role": role, "content": content}) + + if chat_history and chat_history[-1]["role"] == "assistant": + if len(chat_history) >= 2 and chat_history[-2]["role"] == "user": + chat_history = chat_history[:-2] + else: + chat_history = chat_history[:-1] + + return {"system": system_content, "memory": memory_content, "chat_history": chat_history} + + def _send_message_to_scheduler( + self, + user_id: str, + mem_cube_id: str, + query: str, + label: str, + ) -> None: + """ + Send message to scheduler. + + Args: + user_id: User ID + mem_cube_id: Memory cube ID + query: Query content + label: Message label + """ + try: + message_item = ScheduleMessageItem( + user_id=user_id, + mem_cube_id=mem_cube_id, + label=label, + content=query, + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) + self.logger.info(f"Sent message to scheduler with label: {label}") + except Exception as e: + self.logger.error(f"Failed to send message to scheduler: {e}", exc_info=True) + + async def _post_chat_processing( + self, + user_id: str, + cube_id: str, + session_id: str, + query: str, + full_response: str, + system_prompt: str, + time_start: float, + time_end: float, + speed_improvement: float, + current_messages: list, + ) -> None: + """ + Asynchronous post-chat processing with complete functionality. 
+ + Includes: + - Reference extraction + - DingDing notification + - Scheduler messaging + - Memory addition + + Args: + user_id: User ID + cube_id: Memory cube ID + session_id: Session ID + query: User query + full_response: Full LLM response + system_prompt: System prompt used + time_start: Start timestamp + time_end: End timestamp + speed_improvement: Speed improvement metric + current_messages: Current message history + """ + try: + self.logger.info( + f"user_id: {user_id}, cube_id: {cube_id}, current_messages: {current_messages}" + ) + self.logger.info( + f"user_id: {user_id}, cube_id: {cube_id}, full_response: {full_response}" + ) + + # Extract references and clean response + clean_response, extracted_references = self._extract_references_from_response( + full_response + ) + struct_message = self._extract_struct_data_from_history(current_messages) + self.logger.info(f"Extracted {len(extracted_references)} references from response") + + # Send DingDing notification if enabled + if self.online_bot: + self.logger.info("Online Bot Open!") + try: + from memos.memos_tools.notification_utils import ( + send_online_bot_notification_async, + ) + + # Prepare notification data + chat_data = {"query": query, "user_id": user_id, "cube_id": cube_id} + chat_data.update( + { + "memory": struct_message["memory"], + "chat_history": struct_message["chat_history"], + "full_response": full_response, + } + ) + + system_data = { + "references": extracted_references, + "time_start": time_start, + "time_end": time_end, + "speed_improvement": speed_improvement, + } + + emoji_config = {"chat": "💬", "system_info": "📊"} + + await send_online_bot_notification_async( + online_bot=self.online_bot, + header_name="MemOS Chat Report", + sub_title_name="chat_with_references", + title_color="#00956D", + other_data1=chat_data, + other_data2=system_data, + emoji=emoji_config, + ) + except Exception as e: + self.logger.warning(f"Failed to send chat notification (async): {e}") + + # Send answer to scheduler + self._send_message_to_scheduler( + user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_LABEL + ) + + # Add conversation to memory using add handler + add_req = APIADDRequest( + user_id=user_id, + mem_cube_id=cube_id, + session_id=session_id, + messages=[ + { + "role": "user", + "content": query, + "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), + }, + { + "role": "assistant", + "content": clean_response, # Store clean text without reference markers + "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), + }, + ], + async_mode="sync", # set suync for playground + ) + + self.add_handler.handle_add_memories(add_req) + + self.logger.info(f"Post-chat processing completed for user {user_id}") + + except Exception as e: + self.logger.error( + f"Error in post-chat processing for user {user_id}: {e}", exc_info=True + ) + + def _start_post_chat_processing( + self, + user_id: str, + cube_id: str, + session_id: str, + query: str, + full_response: str, + system_prompt: str, + time_start: float, + time_end: float, + speed_improvement: float, + current_messages: list, + ) -> None: + """ + Start asynchronous post-chat processing in a background thread. 
+ + Args: + user_id: User ID + cube_id: Memory cube ID + session_id: Session ID + query: User query + full_response: Full LLM response + system_prompt: System prompt used + time_start: Start timestamp + time_end: End timestamp + speed_improvement: Speed improvement metric + current_messages: Current message history + """ + + def run_async_in_thread(): + """Running asynchronous tasks in a new thread""" + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + self._post_chat_processing( + user_id=user_id, + cube_id=cube_id, + session_id=session_id, + query=query, + full_response=full_response, + system_prompt=system_prompt, + time_start=time_start, + time_end=time_end, + speed_improvement=speed_improvement, + current_messages=current_messages, + ) + ) + finally: + loop.close() + except Exception as e: + self.logger.error( + f"Error in thread-based post-chat processing for user {user_id}: {e}", + exc_info=True, + ) + + try: + # Try to get the current event loop + asyncio.get_running_loop() + # Create task and store reference to prevent garbage collection + task = asyncio.create_task( + self._post_chat_processing( + user_id=user_id, + cube_id=cube_id, + session_id=session_id, + query=query, + full_response=full_response, + system_prompt=system_prompt, + time_start=time_start, + time_end=time_end, + speed_improvement=speed_improvement, + current_messages=current_messages, + ) + ) + # Add exception handling for the background task + task.add_done_callback( + lambda t: self.logger.error( + f"Error in background post-chat processing for user {user_id}: {t.exception()}", + exc_info=True, + ) + if t.exception() + else None + ) + except RuntimeError: + # No event loop, run in a new thread with context propagation + thread = ContextThread( + target=run_async_in_thread, + name=f"PostChatProcessing-{user_id}", + daemon=True, + ) + thread.start() diff --git a/src/memos/api/handlers/chat_handlers.py b/src/memos/api/handlers/chat_handlers.py deleted file mode 100644 index 6cd7a135..00000000 --- a/src/memos/api/handlers/chat_handlers.py +++ /dev/null @@ -1,130 +0,0 @@ -""" -Chat handler for chat functionality. - -This module handles both streaming and complete chat responses with memory integration. -""" - -import json -import traceback - -from typing import Any - -from fastapi import HTTPException -from fastapi.responses import StreamingResponse - -from memos.api.product_models import APIChatCompleteRequest, ChatRequest -from memos.log import get_logger -from memos.mem_os.product_server import MOSServer - - -logger = get_logger(__name__) - - -def handle_chat_complete( - chat_req: APIChatCompleteRequest, - mos_server: Any, - naive_mem_cube: Any, -) -> dict[str, Any]: - """ - Chat with MemOS for complete response (non-streaming). - - Processes a chat request and returns the complete response with references. 
- - Args: - chat_req: Chat complete request - mos_server: MOS server instance - naive_mem_cube: Memory cube instance - - Returns: - Dictionary with response and references - - Raises: - HTTPException: If chat fails - """ - try: - # Collect all responses from the server - content, references = mos_server.chat( - query=chat_req.query, - user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, - mem_cube=naive_mem_cube, - history=chat_req.history, - internet_search=chat_req.internet_search, - moscube=chat_req.moscube, - base_prompt=chat_req.base_prompt, - top_k=chat_req.top_k, - threshold=chat_req.threshold, - session_id=chat_req.session_id, - ) - - # Return the complete response - return { - "message": "Chat completed successfully", - "data": {"response": content, "references": references}, - } - - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to complete chat: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err - - -def handle_chat_stream( - chat_req: ChatRequest, - mos_server: MOSServer, -) -> StreamingResponse: - """ - Chat with MemOS via Server-Sent Events (SSE) stream. - - Processes a chat request and streams the response back via SSE. - - Args: - chat_req: Chat stream request - mos_server: MOS server instance - - Returns: - StreamingResponse with SSE formatted chat stream - - Raises: - HTTPException: If stream initialization fails - """ - try: - - def generate_chat_response(): - """Generate chat response as SSE stream.""" - try: - # Directly yield from the generator without async wrapper - yield from mos_server.chat_with_references( - query=chat_req.query, - user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, - history=chat_req.history, - internet_search=chat_req.internet_search, - moscube=chat_req.moscube, - session_id=chat_req.session_id, - ) - - except Exception as e: - logger.error(f"Error in chat stream: {e}") - error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n" - yield error_data - - return StreamingResponse( - generate_chat_response(), - media_type="text/event-stream", - headers={ - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Content-Type": "text/event-stream", - "Access-Control-Allow-Origin": "*", - "Access-Control-Allow-Headers": "*", - "Access-Control-Allow-Methods": "*", - }, - ) - - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to start chat stream: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index b89e879b..7e3fccc0 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -233,7 +233,6 @@ def init_server() -> tuple[Any, ...]: mem_reader=mem_reader, ) mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube) - logger.debug("Scheduler initialized") # Initialize SchedulerAPIModule diff --git a/src/memos/api/handlers/formatters_handlers.py b/src/memos/api/handlers/formatters_handler.py similarity index 100% rename from src/memos/api/handlers/formatters_handlers.py rename to src/memos/api/handlers/formatters_handler.py diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py new 
file mode 100644 index 00000000..2242e6e3 --- /dev/null +++ b/src/memos/api/handlers/memory_handler.py @@ -0,0 +1,201 @@ +""" +Memory handler for retrieving and managing memories. + +This module handles retrieving all memories or specific subgraphs based on queries. +""" + +import random + +from typing import Any, Literal + +from memos.api.product_models import MemoryResponse +from memos.log import get_logger +from memos.mem_os.utils.format_utils import ( + convert_graph_to_tree_forworkmem, + ensure_unique_tree_ids, + filter_nodes_by_tree_ids, + remove_embedding_recursive, + sort_children_by_memory_type, +) + + +logger = get_logger(__name__) + + +def handle_get_all_memories( + user_id: str, + mem_cube_id: str, + memory_type: Literal["text_mem", "act_mem", "param_mem", "para_mem"], + naive_mem_cube: Any, +) -> MemoryResponse: + """ + Main handler for getting all memories. + + Retrieves all memories of specified type for a user and formats them appropriately. + + Args: + user_id: User ID + mem_cube_id: Memory cube ID + memory_type: Type of memory to retrieve + naive_mem_cube: Memory cube instance + + Returns: + MemoryResponse with formatted memory data + """ + try: + reformat_memory_list = [] + + if memory_type == "text_mem": + # Get all text memories from the graph database + memories = naive_mem_cube.text_mem.get_all(user_name=mem_cube_id) + + # Format and convert to tree structure + memories_cleaned = remove_embedding_recursive(memories) + custom_type_ratios = { + "WorkingMemory": 0.20, + "LongTermMemory": 0.40, + "UserMemory": 0.40, + } + tree_result, node_type_count = convert_graph_to_tree_forworkmem( + memories_cleaned, target_node_count=200, type_ratios=custom_type_ratios + ) + # Ensure all node IDs are unique in the tree structure + tree_result = ensure_unique_tree_ids(tree_result) + memories_filtered = filter_nodes_by_tree_ids(tree_result, memories_cleaned) + children = tree_result["children"] + children_sort = sort_children_by_memory_type(children) + tree_result["children"] = children_sort + memories_filtered["tree_structure"] = tree_result + + reformat_memory_list.append( + { + "cube_id": mem_cube_id, + "memories": [memories_filtered], + "memory_statistics": node_type_count, + } + ) + + elif memory_type == "act_mem": + # Get activation memory + memories_list = [] + act_mem = getattr(naive_mem_cube, "act_mem", None) + if act_mem: + act_mem_params = act_mem.get_all() + if act_mem_params: + memories_data = act_mem_params[0].model_dump() + records = memories_data.get("records", []) + for record in records.get("text_memories", []): + memories_list.append( + { + "id": memories_data["id"], + "text": record, + "create_time": records.get("timestamp"), + "size": random.randint(1, 20), + "modify_times": 1, + } + ) + + reformat_memory_list.append( + { + "cube_id": mem_cube_id, + "memories": memories_list, + } + ) + + elif memory_type == "para_mem": + # Get parameter memory + act_mem = getattr(naive_mem_cube, "act_mem", None) + if act_mem: + act_mem_params = act_mem.get_all() + if act_mem_params: + reformat_memory_list.append( + { + "cube_id": mem_cube_id, + "memories": act_mem_params[0].model_dump(), + } + ) + else: + reformat_memory_list.append( + { + "cube_id": mem_cube_id, + "memories": {}, + } + ) + else: + reformat_memory_list.append( + { + "cube_id": mem_cube_id, + "memories": {}, + } + ) + + return MemoryResponse( + message="Memories retrieved successfully", + data=reformat_memory_list, + ) + + except Exception as e: + logger.error(f"Failed to get all memories: {e}", exc_info=True) + 
raise + + +def handle_get_subgraph( + user_id: str, + mem_cube_id: str, + query: str, + top_k: int, + naive_mem_cube: Any, +) -> MemoryResponse: + """ + Main handler for getting memory subgraph based on query. + + Retrieves relevant memory subgraph and formats it as a tree structure. + + Args: + user_id: User ID + mem_cube_id: Memory cube ID + query: Search query + top_k: Number of top results to return + naive_mem_cube: Memory cube instance + + Returns: + MemoryResponse with formatted subgraph data + """ + try: + # Get relevant subgraph from text memory + memories = naive_mem_cube.text_mem.get_relevant_subgraph(query, top_k=top_k) + + # Format and convert to tree structure + memories_cleaned = remove_embedding_recursive(memories) + custom_type_ratios = { + "WorkingMemory": 0.20, + "LongTermMemory": 0.40, + "UserMemory": 0.40, + } + tree_result, node_type_count = convert_graph_to_tree_forworkmem( + memories_cleaned, target_node_count=150, type_ratios=custom_type_ratios + ) + # Ensure all node IDs are unique in the tree structure + tree_result = ensure_unique_tree_ids(tree_result) + memories_filtered = filter_nodes_by_tree_ids(tree_result, memories_cleaned) + children = tree_result["children"] + children_sort = sort_children_by_memory_type(children) + tree_result["children"] = children_sort + memories_filtered["tree_structure"] = tree_result + + reformat_memory_list = [ + { + "cube_id": mem_cube_id, + "memories": [memories_filtered], + "memory_statistics": node_type_count, + } + ] + + return MemoryResponse( + message="Memories retrieved successfully", + data=reformat_memory_list, + ) + + except Exception as e: + logger.error(f"Failed to get subgraph: {e}", exc_info=True) + raise diff --git a/src/memos/api/handlers/scheduler_handlers.py b/src/memos/api/handlers/scheduler_handler.py similarity index 99% rename from src/memos/api/handlers/scheduler_handlers.py rename to src/memos/api/handlers/scheduler_handler.py index f621e891..8d3c6dc7 100644 --- a/src/memos/api/handlers/scheduler_handlers.py +++ b/src/memos/api/handlers/scheduler_handler.py @@ -14,7 +14,7 @@ from fastapi import HTTPException from fastapi.responses import StreamingResponse -from memos.api.handlers.formatters_handlers import to_iter +from memos.api.handlers.formatters_handler import to_iter from memos.log import get_logger diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py new file mode 100644 index 00000000..9fc8a5b2 --- /dev/null +++ b/src/memos/api/handlers/search_handler.py @@ -0,0 +1,289 @@ +""" +Search handler for memory search functionality (Class-based version). + +This module provides a class-based implementation of search handlers, +using dependency injection for better modularity and testability. +""" + +import os +import traceback + +from typing import Any + +from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies +from memos.api.handlers.formatters_handler import ( + format_memory_item, + post_process_pref_mem, +) +from memos.api.product_models import APISearchRequest, SearchResponse +from memos.context.context import ContextThreadPoolExecutor +from memos.mem_scheduler.schemas.general_schemas import SearchMode +from memos.types import MOSSearchResult, UserContext + + +class SearchHandler(BaseHandler): + """ + Handler for memory search operations. + + Provides fast, fine-grained, and mixture-based search modes. + """ + + def __init__(self, dependencies: HandlerDependencies): + """ + Initialize search handler. 
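A rough sketch of the dependency-injection shape introduced here, with a stub container standing in for HandlerDependencies; the validation loop mirrors what _validate_dependencies is expected to do (the stub names are assumptions).

from dataclasses import dataclass
from typing import Any


@dataclass
class StubDependencies:
    naive_mem_cube: Any = None
    mem_scheduler: Any = None


class StubSearchHandler:
    def __init__(self, deps: StubDependencies) -> None:
        # Fail fast if a required collaborator was not wired in
        for name in ("naive_mem_cube", "mem_scheduler"):
            if getattr(deps, name, None) is None:
                raise ValueError(f"SearchHandler missing dependency: {name}")
        self.deps = deps


handler = StubSearchHandler(StubDependencies(naive_mem_cube=object(), mem_scheduler=object()))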
+ + Args: + dependencies: HandlerDependencies instance + """ + super().__init__(dependencies) + self._validate_dependencies("naive_mem_cube", "mem_scheduler") + + def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: + """ + Main handler for search memories endpoint. + + Orchestrates the search process based on the requested search mode, + supporting both text and preference memory searches. + + Args: + search_req: Search request containing query and parameters + + Returns: + SearchResponse with formatted results + """ + # Create UserContext object + user_context = UserContext( + user_id=search_req.user_id, + mem_cube_id=search_req.mem_cube_id, + session_id=search_req.session_id or "default_session", + ) + self.logger.info(f"Search Req is: {search_req}") + + memories_result: MOSSearchResult = { + "text_mem": [], + "act_mem": [], + "para_mem": [], + "pref_mem": [], + "pref_note": "", + } + + # Determine search mode + search_mode = self._get_search_mode(search_req.mode) + + # Execute search in parallel for text and preference memories + with ContextThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(self._search_text, search_req, user_context, search_mode) + pref_future = executor.submit(self._search_pref, search_req, user_context) + + text_formatted_memories = text_future.result() + pref_formatted_memories = pref_future.result() + + # Build result + memories_result["text_mem"].append( + { + "cube_id": search_req.mem_cube_id, + "memories": text_formatted_memories, + } + ) + + memories_result = post_process_pref_mem( + memories_result, + pref_formatted_memories, + search_req.mem_cube_id, + search_req.include_preference, + ) + + self.logger.info(f"Search memories result: {memories_result}") + + return SearchResponse( + message="Search completed successfully", + data=memories_result, + ) + + def _get_search_mode(self, mode: str) -> str: + """ + Get search mode with environment variable fallback. + + Args: + mode: Requested search mode + + Returns: + Search mode string + """ + if mode == SearchMode.NOT_INITIALIZED: + return os.getenv("SEARCH_MODE", SearchMode.FAST) + return mode + + def _search_text( + self, + search_req: APISearchRequest, + user_context: UserContext, + search_mode: str, + ) -> list[dict[str, Any]]: + """ + Search text memories based on mode. + + Args: + search_req: Search request + user_context: User context + search_mode: Search mode (FAST, FINE, or MIXTURE) + + Returns: + List of formatted memory items + """ + try: + if search_mode == SearchMode.FAST: + memories = self._fast_search(search_req, user_context) + elif search_mode == SearchMode.FINE: + memories = self._fine_search(search_req, user_context) + elif search_mode == SearchMode.MIXTURE: + memories = self._mix_search(search_req, user_context) + else: + self.logger.error(f"Unsupported search mode: {search_mode}") + return [] + + return [format_memory_item(data) for data in memories] + + except Exception as e: + self.logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc()) + return [] + + def _search_pref( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list[dict[str, Any]]: + """ + Search preference memories. 
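handle_search_memories above fans the text and preference lookups out in parallel and joins the two futures. A minimal sketch, with the standard-library executor assumed to behave like ContextThreadPoolExecutor minus the context propagation:

from concurrent.futures import ThreadPoolExecutor


def search_text(query: str) -> list[str]:
    return [f"text:{query}"]  # stand-in for graph/vector retrieval


def search_pref(query: str) -> list[str]:
    return [f"pref:{query}"]  # stand-in for preference retrieval


with ThreadPoolExecutor(max_workers=2) as pool:
    text_future = pool.submit(search_text, "coffee preferences")
    pref_future = pool.submit(search_pref, "coffee preferences")
    # result() re-raises worker exceptions, which is why each branch catches its own
    merged = text_future.result() + pref_future.result()

print(merged)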
+ + Args: + search_req: Search request + user_context: User context + + Returns: + List of formatted preference memory items + """ + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + + try: + results = self.naive_mem_cube.pref_mem.search( + query=search_req.query, + top_k=search_req.pref_top_k, + info={ + "user_id": search_req.user_id, + "session_id": search_req.session_id, + "chat_history": search_req.chat_history, + }, + ) + return [format_memory_item(data) for data in results] + except Exception as e: + self.logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) + return [] + + def _fast_search( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list: + """ + Fast search using vector database. + + Args: + search_req: Search request + user_context: User context + + Returns: + List of search results + """ + target_session_id = search_req.session_id or "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + return self.naive_mem_cube.text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=SearchMode.FAST, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, + ) + + def _fine_search( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list: + """ + Fine-grained search with query enhancement. + + Args: + search_req: Search request + user_context: User context + + Returns: + List of enhanced search results + """ + target_session_id = search_req.session_id or "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + searcher = self.mem_scheduler.searcher + + info = { + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + } + + # Fast retrieve + fast_retrieved_memories = searcher.retrieve( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=SearchMode.FAST, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info=info, + ) + + # Post retrieve + fast_memories = searcher.post_retrieve( + retrieved_results=fast_retrieved_memories, + top_k=search_req.top_k, + user_name=user_context.mem_cube_id, + info=info, + ) + + # Enhance with query + enhanced_results, _ = self.mem_scheduler.retriever.enhance_memories_with_query( + query_history=[search_req.query], + memories=fast_memories, + ) + + return enhanced_results + + def _mix_search( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list: + """ + Mix search combining fast and fine-grained approaches. + + Args: + search_req: Search request + user_context: User context + + Returns: + List of formatted search results + """ + return self.mem_scheduler.mix_search_memories( + search_req=search_req, + user_context=user_context, + ) diff --git a/src/memos/api/handlers/search_handlers.py b/src/memos/api/handlers/search_handlers.py deleted file mode 100644 index a1a03d34..00000000 --- a/src/memos/api/handlers/search_handlers.py +++ /dev/null @@ -1,261 +0,0 @@ -""" -Search handler for memory search functionality. 
- -This module handles all memory search operations including fast, fine-grained, -and mixture-based search modes. -""" - -import os -import traceback - -from typing import Any - -from memos.api.handlers.formatters_handlers import ( - format_memory_item, - post_process_pref_mem, -) -from memos.api.product_models import APISearchRequest, SearchResponse -from memos.context.context import ContextThreadPoolExecutor -from memos.log import get_logger -from memos.mem_scheduler.schemas.general_schemas import SearchMode -from memos.types import MOSSearchResult, UserContext - - -logger = get_logger(__name__) - - -def fast_search_memories( - search_req: APISearchRequest, - user_context: UserContext, - naive_mem_cube: Any, -) -> list[dict[str, Any]]: - """ - Fast search memories using vector database. - - Performs a quick vector-based search for memories. - - Args: - search_req: Search request containing query and parameters - user_context: User context with IDs - naive_mem_cube: Memory cube instance - - Returns: - List of formatted search results - """ - target_session_id = search_req.session_id or "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - - # Create MemCube and perform search - search_results = naive_mem_cube.text_mem.search( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, - mode=SearchMode.FAST, - manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, - search_filter=search_filter, - info={ - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - }, - ) - formatted_memories = [format_memory_item(data) for data in search_results] - - return formatted_memories - - -def fine_search_memories( - search_req: APISearchRequest, - user_context: UserContext, - mem_scheduler: Any, -) -> list[dict[str, Any]]: - """ - Fine-grained search memories using scheduler and retriever. - - Performs a more comprehensive search with query enhancement. - - Args: - search_req: Search request containing query and parameters - user_context: User context with IDs - mem_scheduler: Scheduler instance for advanced retrieval - - Returns: - List of formatted search results - """ - target_session_id = search_req.session_id or "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - - searcher = mem_scheduler.searcher - - info = { - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - } - - fast_retrieved_memories = searcher.retrieve( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, - mode=SearchMode.FAST, - manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, - search_filter=search_filter, - info=info, - ) - - fast_memories = searcher.post_retrieve( - retrieved_results=fast_retrieved_memories, - top_k=search_req.top_k, - user_name=user_context.mem_cube_id, - info=info, - ) - - enhanced_results, _ = mem_scheduler.retriever.enhance_memories_with_query( - query_history=[search_req.query], - memories=fast_memories, - ) - - formatted_memories = [format_memory_item(data) for data in enhanced_results] - - return formatted_memories - - -def mix_search_memories( - search_req: APISearchRequest, - user_context: UserContext, - mem_scheduler: Any, -) -> list[dict[str, Any]]: - """ - Mix search memories: fast search + async fine search. 
- - Combines fast initial search with asynchronous fine-grained search. - - Args: - search_req: Search request containing query and parameters - user_context: User context with IDs - mem_scheduler: Scheduler instance - - Returns: - List of formatted search results - """ - formatted_memories = mem_scheduler.mix_search_memories( - search_req=search_req, - user_context=user_context, - ) - return formatted_memories - - -def handle_search_memories( - search_req: APISearchRequest, - naive_mem_cube: Any, - mem_scheduler: Any, -) -> SearchResponse: - """ - Main handler for search memories endpoint. - - Orchestrates the search process based on the requested search mode, - supporting both text and preference memory searches. - - Args: - search_req: Search request - naive_mem_cube: Memory cube instance - mem_scheduler: Scheduler instance - - Returns: - SearchResponse with formatted results - """ - # Create UserContext object - user_context = UserContext( - user_id=search_req.user_id, - mem_cube_id=search_req.mem_cube_id, - session_id=search_req.session_id or "default_session", - ) - logger.info(f"Search Req is: {search_req}") - - memories_result: MOSSearchResult = { - "text_mem": [], - "act_mem": [], - "para_mem": [], - "pref_mem": [], - "pref_note": "", - } - - if search_req.mode == SearchMode.NOT_INITIALIZED: - search_mode = os.getenv("SEARCH_MODE", SearchMode.FAST) - else: - search_mode = search_req.mode - - def _search_text(): - try: - if search_mode == SearchMode.FAST: - formatted_memories = fast_search_memories( - search_req=search_req, - user_context=user_context, - naive_mem_cube=naive_mem_cube, - ) - elif search_mode == SearchMode.FINE: - formatted_memories = fine_search_memories( - search_req=search_req, - user_context=user_context, - mem_scheduler=mem_scheduler, - ) - elif search_mode == SearchMode.MIXTURE: - formatted_memories = mix_search_memories( - search_req=search_req, - user_context=user_context, - mem_scheduler=mem_scheduler, - ) - else: - logger.error(f"Unsupported search mode: {search_mode}") - return [] - return formatted_memories - except Exception as e: - logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc()) - return [] - - def _search_pref(): - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - try: - results = naive_mem_cube.pref_mem.search( - query=search_req.query, - top_k=search_req.pref_top_k, - info={ - "user_id": search_req.user_id, - "session_id": search_req.session_id, - "chat_history": search_req.chat_history, - }, - ) - return [format_memory_item(data) for data in results] - except Exception as e: - logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) - return [] - - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(_search_text) - pref_future = executor.submit(_search_pref) - text_formatted_memories = text_future.result() - pref_formatted_memories = pref_future.result() - - memories_result["text_mem"].append( - { - "cube_id": search_req.mem_cube_id, - "memories": text_formatted_memories, - } - ) - - memories_result = post_process_pref_mem( - memories_result, - pref_formatted_memories, - search_req.mem_cube_id, - search_req.include_preference, - ) - - logger.info(f"Search memories result: {memories_result}") - - return SearchResponse( - message="Search completed successfully", - data=memories_result, - ) diff --git a/src/memos/api/handlers/suggestion_handler.py b/src/memos/api/handlers/suggestion_handler.py new file mode 100644 index 
00000000..dce89400 --- /dev/null +++ b/src/memos/api/handlers/suggestion_handler.py @@ -0,0 +1,117 @@ +""" +Suggestion handler for generating suggestion queries. + +This module handles suggestion query generation based on user's recent memories +or further suggestions from chat history. +""" + +import json + +from typing import Any + +from memos.api.product_models import SuggestionResponse +from memos.log import get_logger +from memos.mem_os.utils.format_utils import clean_json_response +from memos.templates.mos_prompts import ( + FURTHER_SUGGESTION_PROMPT, + SUGGESTION_QUERY_PROMPT_EN, + SUGGESTION_QUERY_PROMPT_ZH, +) +from memos.types import MessageList + + +logger = get_logger(__name__) + + +def _get_further_suggestion( + llm: Any, + message: MessageList, +) -> list[str]: + """ + Get further suggestion based on recent dialogue. + + Args: + llm: LLM instance for generating suggestions + message: Recent chat messages + + Returns: + List of suggestion queries + """ + try: + dialogue_info = "\n".join([f"{msg['role']}: {msg['content']}" for msg in message[-2:]]) + further_suggestion_prompt = FURTHER_SUGGESTION_PROMPT.format(dialogue=dialogue_info) + message_list = [{"role": "system", "content": further_suggestion_prompt}] + response = llm.generate(message_list) + clean_response = clean_json_response(response) + response_json = json.loads(clean_response) + return response_json["query"] + except Exception as e: + logger.error(f"Error getting further suggestion: {e}", exc_info=True) + return [] + + +def handle_get_suggestion_queries( + user_id: str, + language: str, + message: MessageList | None, + llm: Any, + naive_mem_cube: Any, +) -> SuggestionResponse: + """ + Main handler for suggestion queries endpoint. + + Generates suggestion queries based on user's recent memories or chat history. 
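The suggestion flow is prompt in, JSON out: format the template, call the LLM, strip any code fence, parse. A hedged sketch with a stubbed model and cleaner (clean_json_response and the prompt templates are the real counterparts; the stubs are assumptions):

import json


def clean_json(text: str) -> str:
    # Strip a ```json ... ``` fence if the model wrapped its answer in one
    return text.strip().removeprefix("```json").removesuffix("```").strip()


def stub_llm(messages: list[dict[str, str]]) -> str:
    return '```json\n{"query": ["What did I plan for next week?"]}\n```'


prompt = "Based on these memories, suggest follow-up queries:\n{memories}"
response = stub_llm([{"role": "system", "content": prompt.format(memories="...")}])
print(json.loads(clean_json(response))["query"])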
+
+
+    Args:
+        user_id: User ID
+        language: Language preference ("zh" or "en")
+        message: Optional chat message list for further suggestions
+        llm: LLM instance
+        naive_mem_cube: Memory cube instance
+
+    Returns:
+        SuggestionResponse with generated queries
+    """
+    try:
+        # If message is provided, get further suggestions based on dialogue
+        if message:
+            suggestions = _get_further_suggestion(llm, message)
+            return SuggestionResponse(
+                message="Suggestions retrieved successfully",
+                data={"query": suggestions},
+            )
+
+        # Otherwise, generate suggestions based on recent memories
+        if language == "zh":
+            suggestion_prompt = SUGGESTION_QUERY_PROMPT_ZH
+        else:  # English
+            suggestion_prompt = SUGGESTION_QUERY_PROMPT_EN
+
+        # Search for recent memories
+        text_mem_results = naive_mem_cube.text_mem.search(
+            query="my recent memories",
+            user_name=user_id,
+            top_k=3,
+            mode="fast",
+            info={"user_id": user_id},
+        )
+
+        # Extract memory content
+        memories = ""
+        if text_mem_results:
+            memories = "\n".join([m.memory[:200] for m in text_mem_results])
+
+        # Generate suggestions using LLM
+        message_list = [{"role": "system", "content": suggestion_prompt.format(memories=memories)}]
+        response = llm.generate(message_list)
+        clean_response = clean_json_response(response)
+        response_json = json.loads(clean_response)
+
+        return SuggestionResponse(
+            message="Suggestions retrieved successfully",
+            data={"query": response_json["query"]},
+        )
+
+    except Exception as e:
+        logger.error(f"Failed to get suggestions: {e}", exc_info=True)
+        raise
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index 3b1ce2fc..39935f34 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -200,6 +200,9 @@ class APIADDRequest(BaseRequest):
     operation: list[PermissionDict] | None = Field(
         None, description="operation ids for multi cubes"
     )
+    async_mode: Literal["async", "sync"] = Field(
+        "async", description="Whether to add memory in async mode"
+    )
 
 
 class APIChatCompleteRequest(BaseRequest):
diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py
index fbe6f7aa..140ffaaa 100644
--- a/src/memos/api/routers/server_router.py
+++ b/src/memos/api/routers/server_router.py
@@ -1,11 +1,14 @@
 """
-Server API Router for MemOS.
+Server API Router for MemOS (Class-based handlers version).
 
-This router provides low-level API endpoints for direct server instance operations,
-including search, add, scheduler management, and chat functionalities.
+This router demonstrates the improved architecture using class-based handlers
+with dependency injection, providing better modularity and maintainability.
 
-The actual implementation logic is delegated to specialized handler modules
-in server_handlers package for better modularity and maintainability. 
+Comparison with function-based approach: +- Cleaner code: No need to pass dependencies in every endpoint +- Better testability: Easy to mock handler dependencies +- Improved extensibility: Add new handlers or modify existing ones easily +- Clear separation of concerns: Router focuses on routing, handlers handle business logic """ import os @@ -15,13 +18,20 @@ from fastapi import APIRouter from memos.api import handlers +from memos.api.handlers.add_handler import AddHandler +from memos.api.handlers.base_handler import HandlerDependencies +from memos.api.handlers.chat_handler import ChatHandler +from memos.api.handlers.search_handler import SearchHandler from memos.api.product_models import ( APIADDRequest, APIChatCompleteRequest, APISearchRequest, ChatRequest, + GetMemoryRequest, MemoryResponse, SearchResponse, + SuggestionRequest, + SuggestionResponse, ) from memos.log import get_logger @@ -34,6 +44,17 @@ INSTANCE_ID = f"{socket.gethostname()}:{os.getpid()}:{_random.randint(1000, 9999)}" # Initialize all server components +components = handlers.init_server() + +# Create dependency container +dependencies = HandlerDependencies.from_init_server(*components) + +# Initialize all handlers with dependency injection +search_handler = SearchHandler(dependencies) +add_handler = AddHandler(dependencies) +chat_handler = ChatHandler(dependencies, search_handler, add_handler) + +# For backward compatibility, also provide component access ( graph_db, mem_reader, @@ -53,7 +74,7 @@ pref_retriever, text_mem, pref_mem, -) = handlers.init_server() +) = components # ============================================================================= @@ -63,12 +84,12 @@ @router.post("/search", summary="Search memories", response_model=SearchResponse) def search_memories(search_req: APISearchRequest): - """Search memories for a specific user.""" - return handlers.search_handlers.handle_search_memories( - search_req=search_req, - naive_mem_cube=naive_mem_cube, - mem_scheduler=mem_scheduler, - ) + """ + Search memories for a specific user. + + This endpoint uses the class-based SearchHandler for better code organization. + """ + return search_handler.handle_search_memories(search_req) # ============================================================================= @@ -78,13 +99,12 @@ def search_memories(search_req: APISearchRequest): @router.post("/add", summary="Add memories", response_model=MemoryResponse) def add_memories(add_req: APIADDRequest): - """Add memories for a specific user.""" - return handlers.add_handlers.handle_add_memories( - add_req=add_req, - naive_mem_cube=naive_mem_cube, - mem_reader=mem_reader, - mem_scheduler=mem_scheduler, - ) + """ + Add memories for a specific user. + + This endpoint uses the class-based AddHandler for better code organization. 
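A hypothetical client call against this endpoint; the field names follow APIADDRequest, including the async_mode field added in this series, while the host and port are assumptions:

import json
import urllib.request

payload = {
    "user_id": "user-123",
    "mem_cube_id": "user-123-cube",
    "session_id": "demo-session",
    "messages": [{"role": "user", "content": "I moved to Berlin last month."}],
    "async_mode": "sync",  # "async" defers fine-grained extraction to the scheduler
}
req = urllib.request.Request(
    "http://localhost:8000/add",  # assumed local deployment
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["message"])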
+ """ + return add_handler.handle_add_memories(add_req) # ============================================================================= @@ -95,7 +115,7 @@ def add_memories(add_req: APIADDRequest): @router.get("/scheduler/status", summary="Get scheduler running status") def scheduler_status(user_name: str | None = None): """Get scheduler running status.""" - return handlers.scheduler_handlers.handle_scheduler_status( + return handlers.scheduler_handler.handle_scheduler_status( user_name=user_name, mem_scheduler=mem_scheduler, instance_id=INSTANCE_ID, @@ -109,7 +129,7 @@ def scheduler_wait( poll_interval: float = 0.2, ): """Wait until scheduler is idle for a specific user.""" - return handlers.scheduler_handlers.handle_scheduler_wait( + return handlers.scheduler_handler.handle_scheduler_wait( user_name=user_name, timeout_seconds=timeout_seconds, poll_interval=poll_interval, @@ -124,7 +144,7 @@ def scheduler_wait_stream( poll_interval: float = 0.2, ): """Stream scheduler progress via Server-Sent Events (SSE).""" - return handlers.scheduler_handlers.handle_scheduler_wait_stream( + return handlers.scheduler_handler.handle_scheduler_wait_stream( user_name=user_name, timeout_seconds=timeout_seconds, poll_interval=poll_interval, @@ -140,18 +160,75 @@ def scheduler_wait_stream( @router.post("/chat/complete", summary="Chat with MemOS (Complete Response)") def chat_complete(chat_req: APIChatCompleteRequest): - """Chat with MemOS for a specific user. Returns complete response (non-streaming).""" - return handlers.chat_handlers.handle_chat_complete( - chat_req=chat_req, - mos_server=mos_server, - naive_mem_cube=naive_mem_cube, - ) + """ + Chat with MemOS for a specific user. Returns complete response (non-streaming). + + This endpoint uses the class-based ChatHandler. + """ + return chat_handler.handle_chat_complete(chat_req) @router.post("/chat", summary="Chat with MemOS") def chat(chat_req: ChatRequest): - """Chat with MemOS for a specific user. Returns SSE stream.""" - return handlers.chat_handlers.handle_chat_stream( - chat_req=chat_req, - mos_server=mos_server, + """ + Chat with MemOS for a specific user. Returns SSE stream. + + This endpoint uses the class-based ChatHandler which internally + composes SearchHandler and AddHandler for a clean architecture. + """ + return chat_handler.handle_chat_stream(chat_req) + + +# ============================================================================= +# Suggestion API Endpoints +# ============================================================================= + + +@router.post( + "/suggestions", + summary="Get suggestion queries", + response_model=SuggestionResponse, +) +def get_suggestion_queries(suggestion_req: SuggestionRequest): + """Get suggestion queries for a specific user with language preference.""" + return handlers.suggestion_handler.handle_get_suggestion_queries( + user_id=suggestion_req.user_id, + language=suggestion_req.language, + message=suggestion_req.message, + llm=llm, + naive_mem_cube=naive_mem_cube, ) + + +# ============================================================================= +# Memory Retrieval API Endpoints +# ============================================================================= + + +@router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse) +def get_all_memories(memory_req: GetMemoryRequest): + """ + Get all memories or subgraph for a specific user. + + If search_query is provided, returns a subgraph based on the query. + Otherwise, returns all memories of the specified type. 
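Two hypothetical payloads for this endpoint, using the GetMemoryRequest fields referenced below:

# Subgraph mode: a query narrows the result to the relevant neighborhood
subgraph_req = {
    "user_id": "user-123",
    "mem_cube_ids": ["user-123-cube"],
    "search_query": "trips to Japan",
}

# Full mode: no query, so all memories of one type are returned
all_req = {
    "user_id": "user-123",
    "mem_cube_ids": ["user-123-cube"],
    "memory_type": "text_mem",
}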
+ """ + if memory_req.search_query: + return handlers.memory_handler.handle_get_subgraph( + user_id=memory_req.user_id, + mem_cube_id=( + memory_req.mem_cube_ids[0] if memory_req.mem_cube_ids else memory_req.user_id + ), + query=memory_req.search_query, + top_k=20, + naive_mem_cube=naive_mem_cube, + ) + else: + return handlers.memory_handler.handle_get_all_memories( + user_id=memory_req.user_id, + mem_cube_id=( + memory_req.mem_cube_ids[0] if memory_req.mem_cube_ids else memory_req.user_id + ), + memory_type=memory_req.memory_type or "text_mem", + naive_mem_cube=naive_mem_cube, + ) diff --git a/src/memos/mem_os/utils/reference_utils.py b/src/memos/mem_os/utils/reference_utils.py index c2f4431c..09b81220 100644 --- a/src/memos/mem_os/utils/reference_utils.py +++ b/src/memos/mem_os/utils/reference_utils.py @@ -142,12 +142,21 @@ def prepare_reference_data(memories_list: list[TextualMemoryItem]) -> list[dict] # Prepare reference data reference = [] for memories in memories_list: - memories_json = memories.model_dump() - memories_json["metadata"]["ref_id"] = f"{memories.id.split('-')[0]}" - memories_json["metadata"]["embedding"] = [] - memories_json["metadata"]["sources"] = [] - memories_json["metadata"]["memory"] = memories.memory - memories_json["metadata"]["id"] = memories.id - reference.append({"metadata": memories_json["metadata"]}) + if isinstance(memories, TextualMemoryItem): + memories_json = memories.model_dump() + memories_json["metadata"]["ref_id"] = f"{memories.id.split('-')[0]}" + memories_json["metadata"]["embedding"] = [] + memories_json["metadata"]["sources"] = [] + memories_json["metadata"]["memory"] = memories.memory + memories_json["metadata"]["id"] = memories.id + reference.append({"metadata": memories_json["metadata"]}) + else: + memories_json = memories + memories_json["metadata"]["ref_id"] = f"{memories_json['id'].split('-')[0]}" + memories_json["metadata"]["embedding"] = [] + memories_json["metadata"]["sources"] = [] + memories_json["metadata"]["memory"] = memories_json["memory"] + memories_json["metadata"]["id"] = memories_json["id"] + reference.append({"metadata": memories_json["metadata"]}) return reference diff --git a/src/memos/mem_scheduler/general_modules/base.py b/src/memos/mem_scheduler/general_modules/base.py index 392f2bde..0b80b9e7 100644 --- a/src/memos/mem_scheduler/general_modules/base.py +++ b/src/memos/mem_scheduler/general_modules/base.py @@ -51,7 +51,7 @@ def _build_system_prompt(self, memories: list | None = None) -> str: def get_mem_cube(self, mem_cube_id: str) -> GeneralMemCube: logger.error(f"mem_cube {mem_cube_id} does not exists.") - return self.mem_cubes.get(mem_cube_id, None) + return self.current_mem_cube @property def chat_llm(self) -> BaseLLM: diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 32fefce6..2b14887d 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -53,9 +53,6 @@ def long_memory_update_process( ): mem_cube = self.current_mem_cube - # for status update - self._set_current_context_from_message(msg=messages[0]) - # update query monitors for msg in messages: self.monitor.register_query_monitor_if_not_exists( @@ -185,9 +182,6 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: if len(messages) == 0: return - # for status update - self._set_current_context_from_message(msg=messages[0]) - def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: 
        logger.info(f"Messages {messages} assigned to {ADD_LABEL} handler.")
 
         # Process the query in a session turn
@@ -201,9 +195,6 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
         if len(messages) == 0:
             return
 
-        # for status update
-        self._set_current_context_from_message(msg=messages[0])
-
         # submit logs
         for msg in messages:
             try:

From 553a0e39611b98f0267f55ab91c4e7fafcd696f1 Mon Sep 17 00:00:00 2001
From: fridayL
Date: Thu, 13 Nov 2025 15:41:24 +0800
Subject: [PATCH 3/7] feat: update memcube info

---
 src/memos/api/product_models.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index 39935f34..892d2d43 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -226,6 +226,7 @@ class SuggestionRequest(BaseRequest):
     """Request model for getting suggestion queries."""
 
     user_id: str = Field(..., description="User ID")
+    mem_cube_id: str = Field(..., description="Cube ID")
     language: Literal["zh", "en"] = Field("zh", description="Language for suggestions")
     message: list[MessageDict] | None = Field(None, description="List of messages to store.")
 

From b84bb12764805e1ac1053e7c2ba43885830c9358 Mon Sep 17 00:00:00 2001
From: fridayL
Date: Thu, 13 Nov 2025 15:48:31 +0800
Subject: [PATCH 4/7] feat: remove act mem and params mem

---
 src/memos/api/handlers/memory_handler.py | 54 +-----------------------
 1 file changed, 2 insertions(+), 52 deletions(-)

diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py
index 2242e6e3..f816ef65 100644
--- a/src/memos/api/handlers/memory_handler.py
+++ b/src/memos/api/handlers/memory_handler.py
@@ -76,59 +76,9 @@ def handle_get_all_memories(
             )
 
         elif memory_type == "act_mem":
-            # Get activation memory
-            memories_list = []
-            act_mem = getattr(naive_mem_cube, "act_mem", None)
-            if act_mem:
-                act_mem_params = act_mem.get_all()
-                if act_mem_params:
-                    memories_data = act_mem_params[0].model_dump()
-                    records = memories_data.get("records", [])
-                    for record in records.get("text_memories", []):
-                        memories_list.append(
-                            {
-                                "id": memories_data["id"],
-                                "text": record,
-                                "create_time": records.get("timestamp"),
-                                "size": random.randint(1, 20),
-                                "modify_times": 1,
-                            }
-                        )
-
-            reformat_memory_list.append(
-                {
-                    "cube_id": mem_cube_id,
-                    "memories": memories_list,
-                }
-            )
-
+            logger.warning("Activation memory retrieval not implemented yet.")
         elif memory_type == "para_mem":
-            # Get parameter memory
-            act_mem = getattr(naive_mem_cube, "act_mem", None)
-            if act_mem:
-                act_mem_params = act_mem.get_all()
-                if act_mem_params:
-                    reformat_memory_list.append(
-                        {
-                            "cube_id": mem_cube_id,
-                            "memories": act_mem_params[0].model_dump(),
-                        }
-                    )
-                else:
-                    reformat_memory_list.append(
-                        {
-                            "cube_id": mem_cube_id,
-                            "memories": {},
-                        }
-                    )
-            else:
-                reformat_memory_list.append(
-                    {
-                        "cube_id": mem_cube_id,
-                        "memories": {},
-                    }
-                )
-
+            logger.warning("Parameter memory retrieval not implemented yet.")
         return MemoryResponse(
             message="Memories retrieved successfully",
             data=reformat_memory_list,

From 672426579f64f78fb8b5c883ebc868be568b0234 Mon Sep 17 00:00:00 2001
From: fridayL
Date: Thu, 13 Nov 2025 18:54:54 +0800
Subject: [PATCH 5/7] feat: update init

---
 src/memos/api/handlers/add_handler.py    | 18 +++++-
 src/memos/api/handlers/base_handler.py   | 52 +++-------------
 src/memos/api/handlers/chat_handler.py   | 56 ++++++++---------
 src/memos/api/handlers/component_init.py | 78 +++++++++++-----------
 src/memos/api/handlers/memory_handler.py | 2 
+- src/memos/api/routers/server_router.py | 35 +++-------- src/memos/memories/textual/tree.py | 6 +- 7 files changed, 101 insertions(+), 146 deletions(-) diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 257d911b..4cbf1108 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -60,7 +60,9 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: ) self.logger.info(f"Add Req is: {add_req}") - + if (not add_req.messages) and add_req.memory_content: + add_req.messages = self._convert_content_messsage(add_req.memory_content) + self.logger.info(f"Converted Add Req content to messages: {add_req.messages}") # Process text and preference memories in parallel with ContextThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(self._process_text_mem, add_req, user_context) @@ -77,6 +79,20 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: data=text_response_data + pref_response_data, ) + def _convert_content_messsage(seflf, memory_content: str) -> list[dict[str, str]]: + """ + Convert content string to list of message dictionaries. + + Args: + content: add content string + + Returns: + List of message dictionaries + """ + messages_list = [{"role": "user", "content": memory_content, "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}] + # for only user-str input and convert message + return messages_list + def _process_text_mem( self, add_req: APIADDRequest, diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py index 3bb8ae57..86a00dc3 100644 --- a/src/memos/api/handlers/base_handler.py +++ b/src/memos/api/handlers/base_handler.py @@ -70,57 +70,23 @@ def __init__( setattr(self, key, value) @classmethod - def from_init_server(cls, *components): + def from_init_server(cls, components: dict[str, Any]): """ Create dependencies from init_server() return values. Args: - *components: Tuple of components returned by init_server() + components: Dictionary of components returned by init_server(). + All components will be automatically unpacked as dependencies. Returns: HandlerDependencies instance + + Note: + This method uses **kwargs unpacking, so any new components added to + init_server() will automatically become available as dependencies + without modifying this code. 
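A minimal sketch of the registry idea the Note describes: every key in the dict returned by init_server() becomes an attribute, so newly added components need no wiring changes (stub class, illustrative keys):

from typing import Any


class StubDependencies:
    def __init__(self, **components: Any) -> None:
        # Each component key becomes an attribute on the container
        for key, value in components.items():
            setattr(self, key, value)

    @classmethod
    def from_init_server(cls, components: dict[str, Any]) -> "StubDependencies":
        return cls(**components)


deps = StubDependencies.from_init_server({"llm": object(), "online_bot": None})
print(hasattr(deps, "llm"), deps.online_bot)  # True None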
""" - ( - graph_db, - mem_reader, - llm, - embedder, - reranker, - internet_retriever, - memory_manager, - default_cube_config, - mos_server, - mem_scheduler, - naive_mem_cube, - api_module, - vector_db, - pref_extractor, - pref_adder, - pref_retriever, - text_mem, - pref_mem, - ) = components - - return cls( - llm=llm, - naive_mem_cube=naive_mem_cube, - mem_reader=mem_reader, - mem_scheduler=mem_scheduler, - embedder=embedder, - reranker=reranker, - graph_db=graph_db, - vector_db=vector_db, - internet_retriever=internet_retriever, - memory_manager=memory_manager, - mos_server=mos_server, - default_cube_config=default_cube_config, - api_module=api_module, - pref_extractor=pref_extractor, - pref_adder=pref_adder, - pref_retriever=pref_retriever, - text_mem=text_mem, - pref_mem=pref_mem, - ) + return cls(**components) class BaseHandler: diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 01a44ca7..a36991a1 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -346,40 +346,34 @@ def generate_chat_response() -> Generator[str, None, None]: def _build_system_prompt( self, - memories_all: list, + memories: list | None = None, base_prompt: str | None = None, - tone: str = "friendly", - verbosity: str = "mid", + **kwargs, ) -> str: - """ - Build system prompt with memory references (for complete response). - - Args: - memories_all: List of memory items - base_prompt: Optional base prompt - tone: Tone of the prompt - verbosity: Verbosity level - - Returns: - System prompt string - """ - now = datetime.now() - formatted_date = now.strftime("%Y-%m-%d (%A)") - sys_body = get_memos_prompt( - date=formatted_date, tone=tone, verbosity=verbosity, mode="base" - ) - - # Format memories - mem_block_o, mem_block_p = self._format_mem_block(memories_all) - mem_block = mem_block_o + "\n" + mem_block_p + """Build system prompt with optional memories context.""" + if base_prompt is None: + base_prompt = ( + "You are a knowledgeable and helpful AI assistant. " + "You have access to conversation memories that help you provide more personalized responses. " + "Use the memories to understand the user's context, preferences, and past interactions. " + "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories." + ) - prefix = (base_prompt.strip() + "\n\n") if base_prompt else "" - return ( - prefix - + sys_body - + "\n\n# Memories\n## PersonalMemory & OuterMemory (ordered)\n" - + mem_block - ) + memory_context = "" + if memories: + memory_list = [] + for i, memory in enumerate(memories, 1): + text_memory = memory.get("memory", "") + memory_list.append(f"{i}. 
{text_memory}") + memory_context = "\n".join(memory_list) + + if "{memories}" in base_prompt: + return base_prompt.format(memories=memory_context) + elif base_prompt and memories: + # For backward compatibility, append memories if no placeholder is found + memory_context_with_header = "\n\n## Memories:\n" + memory_context + return base_prompt + memory_context_with_header + return base_prompt def _build_enhance_system_prompt( self, diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 7e3fccc0..e0d4586d 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -71,7 +71,7 @@ def _get_default_memory_size(cube_config: Any) -> dict[str, int]: } -def init_server() -> tuple[Any, ...]: +def init_server() -> dict[str, Any]: """ Initialize all server components and configurations. @@ -83,33 +83,18 @@ def init_server() -> tuple[Any, ...]: - Scheduler and related modules Returns: - A tuple containing all initialized components in this order: - ( - graph_db, - mem_reader, - llm, - embedder, - reranker, - internet_retriever, - memory_manager, - default_cube_config, - mos_server, - mem_scheduler, - naive_mem_cube, - api_module, - vector_db, - pref_extractor, - pref_adder, - pref_retriever, - text_mem, - pref_mem, - ) + A dictionary containing all initialized components with descriptive keys. + This approach allows easy addition of new components without breaking + existing code that uses the components. """ logger.info("Initializing MemOS server components...") # Get default cube configuration default_cube_config = APIConfig.get_default_cube_config() + #Get online bot setting + dingding_enabled = APIConfig.is_dingding_bot_enabled() + # Build component configurations graph_db_config = build_graph_db_config() llm_config = build_llm_config() @@ -247,23 +232,32 @@ def init_server() -> tuple[Any, ...]: logger.info("MemOS server components initialized successfully") - return ( - graph_db, - mem_reader, - llm, - embedder, - reranker, - internet_retriever, - memory_manager, - default_cube_config, - mos_server, - mem_scheduler, - naive_mem_cube, - api_module, - vector_db, - pref_extractor, - pref_adder, - pref_retriever, - text_mem, - pref_mem, - ) + # Initialize online bot if enabled + online_bot = None + if dingding_enabled: + from memos.memos_tools.notification_service import get_online_bot_function + online_bot = get_online_bot_function() if dingding_enabled else None + logger.info("DingDing bot is enabled") + + # Return all components as a dictionary for easy access and extension + return { + "graph_db": graph_db, + "mem_reader": mem_reader, + "llm": llm, + "embedder": embedder, + "reranker": reranker, + "internet_retriever": internet_retriever, + "memory_manager": memory_manager, + "default_cube_config": default_cube_config, + "mos_server": mos_server, + "mem_scheduler": mem_scheduler, + "naive_mem_cube": naive_mem_cube, + "api_module": api_module, + "vector_db": vector_db, + "pref_extractor": pref_extractor, + "pref_adder": pref_adder, + "pref_retriever": pref_retriever, + "text_mem": text_mem, + "pref_mem": pref_mem, + "online_bot": online_bot, + } diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index f816ef65..eacf615c 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -113,7 +113,7 @@ def handle_get_subgraph( """ try: # Get relevant subgraph from text memory - memories = 
naive_mem_cube.text_mem.get_relevant_subgraph(query, top_k=top_k) + memories = naive_mem_cube.text_mem.get_relevant_subgraph(query, top_k=top_k, user_name=mem_cube_id) # Format and convert to tree structure memories_cleaned = remove_embedding_recursive(memories) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 140ffaaa..0c382f40 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -47,34 +47,19 @@ components = handlers.init_server() # Create dependency container -dependencies = HandlerDependencies.from_init_server(*components) +dependencies = HandlerDependencies.from_init_server(components) # Initialize all handlers with dependency injection search_handler = SearchHandler(dependencies) add_handler = AddHandler(dependencies) -chat_handler = ChatHandler(dependencies, search_handler, add_handler) - -# For backward compatibility, also provide component access -( - graph_db, - mem_reader, - llm, - embedder, - reranker, - internet_retriever, - memory_manager, - default_cube_config, - mos_server, - mem_scheduler, - naive_mem_cube, - api_module, - vector_db, - pref_extractor, - pref_adder, - pref_retriever, - text_mem, - pref_mem, -) = components +chat_handler = ChatHandler(dependencies, search_handler, add_handler, online_bot=components.get("online_bot")) + +# Extract commonly used components for function-based handlers +# (These can be accessed from the components dict without unpacking all of them) +mem_scheduler = components["mem_scheduler"] +llm = components["llm"] +naive_mem_cube = components["naive_mem_cube"] + # ============================================================================= @@ -192,7 +177,7 @@ def chat(chat_req: ChatRequest): def get_suggestion_queries(suggestion_req: SuggestionRequest): """Get suggestion queries for a specific user with language preference.""" return handlers.suggestion_handler.handle_get_suggestion_queries( - user_id=suggestion_req.user_id, + user_id=suggestion_req.mem_cube_id, language=suggestion_req.language, message=suggestion_req.message, llm=llm, diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index e2e0be69..28f4cd77 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -218,7 +218,7 @@ def search( ) def get_relevant_subgraph( - self, query: str, top_k: int = 5, depth: int = 2, center_status: str = "activated" + self, query: str, top_k: int = 5, depth: int = 2, center_status: str = "activated", user_name: str | None = None ) -> dict[str, Any]: """ Find and merge the local neighborhood sub-graphs of the top-k @@ -249,7 +249,7 @@ def get_relevant_subgraph( query_embedding = self.embedder.embed([query])[0] # Step 2: Get top-1 similar node - similar_nodes = self.graph_store.search_by_embedding(query_embedding, top_k=top_k) + similar_nodes = self.graph_store.search_by_embedding(query_embedding, top_k=top_k, user_name=user_name) if not similar_nodes: logger.info("No similar nodes found for query embedding.") return {"core_id": None, "nodes": [], "edges": []} @@ -264,7 +264,7 @@ def get_relevant_subgraph( score = node["score"] subgraph = self.graph_store.get_subgraph( - center_id=core_id, depth=depth, center_status=center_status + center_id=core_id, depth=depth, center_status=center_status, user_name=user_name ) if subgraph is None or not subgraph["core_node"]: From 2fd351e90cce69e715132be94a8855d2aa24979b Mon Sep 17 00:00:00 2001 From: fridayL Date: Thu, 13 Nov 2025 18:59:00 +0800 
Subject: [PATCH 6/7] code suffix --- src/memos/api/handlers/add_handler.py | 12 +++++++++--- src/memos/api/handlers/chat_handler.py | 4 ++-- src/memos/api/handlers/component_init.py | 3 ++- src/memos/api/handlers/memory_handler.py | 6 +++--- src/memos/api/routers/server_router.py | 5 +++-- src/memos/memories/textual/tree.py | 11 +++++++++-- 6 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 4cbf1108..48db7ae6 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -79,7 +79,7 @@ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: data=text_response_data + pref_response_data, ) - def _convert_content_messsage(seflf, memory_content: str) -> list[dict[str, str]]: + def _convert_content_messsage(self, memory_content: str) -> list[dict[str, str]]: """ Convert content string to list of message dictionaries. @@ -89,10 +89,16 @@ def _convert_content_messsage(seflf, memory_content: str) -> list[dict[str, str] Returns: List of message dictionaries """ - messages_list = [{"role": "user", "content": memory_content, "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))}] + messages_list = [ + { + "role": "user", + "content": memory_content, + "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), + } + ] # for only user-str input and convert message return messages_list - + def _process_text_mem( self, add_req: APIADDRequest, diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index a36991a1..9b0048ed 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -363,8 +363,8 @@ def _build_system_prompt( if memories: memory_list = [] for i, memory in enumerate(memories, 1): - text_memory = memory.get("memory", "") - memory_list.append(f"{i}. {text_memory}") + text_memory = memory.get("memory", "") + memory_list.append(f"{i}. {text_memory}") memory_context = "\n".join(memory_list) if "{memories}" in base_prompt: diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index e0d4586d..4e696a34 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -92,7 +92,7 @@ def init_server() -> dict[str, Any]: # Get default cube configuration default_cube_config = APIConfig.get_default_cube_config() - #Get online bot setting + # Get online bot setting dingding_enabled = APIConfig.is_dingding_bot_enabled() # Build component configurations @@ -236,6 +236,7 @@ def init_server() -> dict[str, Any]: online_bot = None if dingding_enabled: from memos.memos_tools.notification_service import get_online_bot_function + online_bot = get_online_bot_function() if dingding_enabled else None logger.info("DingDing bot is enabled") diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py index eacf615c..85f339f3 100644 --- a/src/memos/api/handlers/memory_handler.py +++ b/src/memos/api/handlers/memory_handler.py @@ -4,8 +4,6 @@ This module handles retrieving all memories or specific subgraphs based on queries. 
""" -import random - from typing import Any, Literal from memos.api.product_models import MemoryResponse @@ -113,7 +111,9 @@ def handle_get_subgraph( """ try: # Get relevant subgraph from text memory - memories = naive_mem_cube.text_mem.get_relevant_subgraph(query, top_k=top_k, user_name=mem_cube_id) + memories = naive_mem_cube.text_mem.get_relevant_subgraph( + query, top_k=top_k, user_name=mem_cube_id + ) # Format and convert to tree structure memories_cleaned = remove_embedding_recursive(memories) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 0c382f40..d43f9ccd 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -52,7 +52,9 @@ # Initialize all handlers with dependency injection search_handler = SearchHandler(dependencies) add_handler = AddHandler(dependencies) -chat_handler = ChatHandler(dependencies, search_handler, add_handler, online_bot=components.get("online_bot")) +chat_handler = ChatHandler( + dependencies, search_handler, add_handler, online_bot=components.get("online_bot") +) # Extract commonly used components for function-based handlers # (These can be accessed from the components dict without unpacking all of them) @@ -61,7 +63,6 @@ naive_mem_cube = components["naive_mem_cube"] - # ============================================================================= # Search API Endpoints # ============================================================================= diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 28f4cd77..15a6a8b4 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -218,7 +218,12 @@ def search( ) def get_relevant_subgraph( - self, query: str, top_k: int = 5, depth: int = 2, center_status: str = "activated", user_name: str | None = None + self, + query: str, + top_k: int = 5, + depth: int = 2, + center_status: str = "activated", + user_name: str | None = None, ) -> dict[str, Any]: """ Find and merge the local neighborhood sub-graphs of the top-k @@ -249,7 +254,9 @@ def get_relevant_subgraph( query_embedding = self.embedder.embed([query])[0] # Step 2: Get top-1 similar node - similar_nodes = self.graph_store.search_by_embedding(query_embedding, top_k=top_k, user_name=user_name) + similar_nodes = self.graph_store.search_by_embedding( + query_embedding, top_k=top_k, user_name=user_name + ) if not similar_nodes: logger.info("No similar nodes found for query embedding.") return {"core_id": None, "nodes": [], "edges": []} From 736035f323312e906c9d1827835d0a7c2772cead Mon Sep 17 00:00:00 2001 From: fridayL Date: Mon, 17 Nov 2025 16:11:50 +0800 Subject: [PATCH 7/7] feat: update internet search mode --- src/memos/api/handlers/chat_handler.py | 6 +++--- src/memos/api/handlers/search_handler.py | 2 +- .../memories/textual/tree_text_memory/retrieve/searcher.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 9b0048ed..f6023e5c 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -116,7 +116,7 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An query=chat_req.query, top_k=chat_req.top_k or 10, session_id=chat_req.session_id, - mode=SearchMode.FINE, + mode=SearchMode.FAST, internet_search=chat_req.internet_search, moscube=chat_req.moscube, chat_history=chat_req.history, @@ -213,8 +213,8 @@ def generate_chat_response() 
From 736035f323312e906c9d1827835d0a7c2772cead Mon Sep 17 00:00:00 2001
From: fridayL
Date: Mon, 17 Nov 2025 16:11:50 +0800
Subject: [PATCH 7/7] feat: update internet search mode

---
 src/memos/api/handlers/chat_handler.py                     | 6 +++---
 src/memos/api/handlers/search_handler.py                   | 2 +-
 .../memories/textual/tree_text_memory/retrieve/searcher.py | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py
index 9b0048ed..f6023e5c 100644
--- a/src/memos/api/handlers/chat_handler.py
+++ b/src/memos/api/handlers/chat_handler.py
@@ -116,7 +116,7 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An
             query=chat_req.query,
             top_k=chat_req.top_k or 10,
             session_id=chat_req.session_id,
-            mode=SearchMode.FINE,
+            mode=SearchMode.FAST,
             internet_search=chat_req.internet_search,
             moscube=chat_req.moscube,
             chat_history=chat_req.history,
@@ -213,8 +213,8 @@ def generate_chat_response() -> Generator[str, None, None]:
             query=chat_req.query,
             top_k=20,
             session_id=chat_req.session_id,
-            mode=SearchMode.FINE,
-            internet_search=chat_req.internet_search,
+            mode=SearchMode.FINE if chat_req.internet_search else SearchMode.FAST,
+            internet_search=chat_req.internet_search,  # TODO: this param has no effect in fine mode
             moscube=chat_req.moscube,
             chat_history=chat_req.history,
         )
diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py
index 9fc8a5b2..e8e4e07d 100644
--- a/src/memos/api/handlers/search_handler.py
+++ b/src/memos/api/handlers/search_handler.py
@@ -245,7 +245,7 @@ def _fine_search(
             query=search_req.query,
             user_name=user_context.mem_cube_id,
             top_k=search_req.top_k,
-            mode=SearchMode.FAST,
+            mode=SearchMode.FINE,
             manual_close_internet=not search_req.internet_search,
             moscube=search_req.moscube,
             search_filter=search_filter,
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index f196c556..14ea8e2c 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -182,7 +182,7 @@ def _parse_task(
         query_embedding = None
 
         # fine mode will trigger initial embedding search
-        if mode == "fine":
+        if mode == "fine_old":
             logger.info("[SEARCH] Fine mode: embedding search")
             query_embedding = self.embedder.embed([query])[0]
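[Editor's example] Net effect of PATCH 7/7: handle_chat_complete is pinned to fast retrieval, streaming chat escalates to fine mode only when internet search is requested, _fine_search now actually passes SearchMode.FINE, and the old fine-mode pre-embedding branch in _parse_task is parked under "fine_old". A minimal sketch of the resulting routing; SearchMode.FAST and SearchMode.FINE are the names used in the diff, but this Enum is a local stand-in and pick_mode is an illustrative helper, not project API:

from enum import Enum


class SearchMode(Enum):
    FAST = "fast"
    FINE = "fine"


def pick_mode(internet_search: bool) -> SearchMode:
    # Fine-grained (reranked, internet-augmented) retrieval is reserved for
    # internet searches; plain memory lookups take the cheaper fast path.
    return SearchMode.FINE if internet_search else SearchMode.FAST


assert pick_mode(True) is SearchMode.FINE
assert pick_mode(False) is SearchMode.FAST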