diff --git a/examples/mem_scheduler/api_w_scheduler.py b/examples/mem_scheduler/api_w_scheduler.py
index a2184e9c..1b59543f 100644
--- a/examples/mem_scheduler/api_w_scheduler.py
+++ b/examples/mem_scheduler/api_w_scheduler.py
@@ -25,9 +25,10 @@ def my_test_handler(messages: list[ScheduleMessageItem]):
     print(f"My test handler received {len(messages)} messages:")
     for msg in messages:
         print(f"  my_test_handler - {msg.item_id}: {msg.content}")
-    print(
-        f"{queue._redis_conn.xinfo_groups(queue.stream_key_prefix)} qsize: {queue.qsize()} messages:{messages}"
+    user_status_running = handle_scheduler_status(
+        user_name=USER_MEM_CUBE, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler"
     )
+    print(f"[Monitor] Status for {USER_MEM_CUBE} after submit:", user_status_running)
 
 
 # 2. Register the handler
@@ -59,10 +60,6 @@ def my_test_handler(messages: list[ScheduleMessageItem]):
 
 # 5.1 Monitor status for specific mem_cube while running
 USER_MEM_CUBE = "test_mem_cube"
-user_status_running = handle_scheduler_status(
-    user_name=USER_MEM_CUBE, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler"
-)
-print(f"[Monitor] Status for {USER_MEM_CUBE} after submit:", user_status_running)
 
 # 6. Wait for messages to be processed (limited to 100 checks)
 print("Waiting for messages to be consumed (max 100 checks)...")
diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py
index a174defb..a686ac8f 100644
--- a/src/memos/api/handlers/base_handler.py
+++ b/src/memos/api/handlers/base_handler.py
@@ -9,6 +9,7 @@
 
 from memos.log import get_logger
 from memos.mem_scheduler.base_scheduler import BaseScheduler
+from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
 
 
 logger = get_logger(__name__)
@@ -28,6 +29,7 @@ def __init__(
         naive_mem_cube: Any | None = None,
         mem_reader: Any | None = None,
         mem_scheduler: Any | None = None,
+        searcher: Any | None = None,
         embedder: Any | None = None,
         reranker: Any | None = None,
         graph_db: Any | None = None,
@@ -58,6 +60,7 @@ def __init__(
         self.naive_mem_cube = naive_mem_cube
         self.mem_reader = mem_reader
         self.mem_scheduler = mem_scheduler
+        self.searcher = searcher
        self.embedder = embedder
         self.reranker = reranker
         self.graph_db = graph_db
@@ -128,6 +131,11 @@ def mem_scheduler(self) -> BaseScheduler:
         """Get scheduler instance."""
         return self.deps.mem_scheduler
 
+    @property
+    def searcher(self) -> Searcher:
+        """Get searcher instance."""
+        return self.deps.searcher
+
     @property
     def embedder(self):
         """Get embedder instance."""
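Note on the `base_handler.py` change: the new `searcher` rides the same dependency-injection path as `mem_scheduler`, so every handler consumes one shared `Searcher` instead of building its own. A minimal, self-contained sketch of the pattern (illustrative class names, not the real MemOS types):

    # Hypothetical mirror of the injected-dependency + property pattern above.
    class Deps:
        def __init__(self, searcher=None):
            self.searcher = searcher


    class Handler:
        def __init__(self, deps: Deps):
            self.deps = deps

        @property
        def searcher(self):
            """Expose the injected searcher, as BaseHandler.searcher does."""
            return self.deps.searcher


    deps = Deps(searcher=object())  # stand-in for a real Searcher instance
    assert Handler(deps).searcher is deps.searcher  # all handlers share one instance
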
""" +import os + from typing import TYPE_CHECKING, Any from memos.api.config import APIConfig @@ -38,6 +40,10 @@ from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager + + +if TYPE_CHECKING: + from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) @@ -47,7 +53,7 @@ if TYPE_CHECKING: from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler - + from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher logger = get_logger(__name__) @@ -205,6 +211,13 @@ def init_server() -> dict[str, Any]: logger.debug("MemCube created") + tree_mem: TreeTextMemory = naive_mem_cube.text_mem + searcher: Searcher = tree_mem.get_searcher( + manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", + moscube=False, + ) + logger.debug("Searcher created") + # Initialize Scheduler scheduler_config_dict = APIConfig.get_scheduler_config() scheduler_config = SchedulerConfigFactory( @@ -217,16 +230,14 @@ def init_server() -> dict[str, Any]: db_engine=BaseDBManager.create_default_sqlite_engine(), mem_reader=mem_reader, ) - mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube) + mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube, searcher=searcher) logger.debug("Scheduler initialized") # Initialize SchedulerAPIModule api_module = mem_scheduler.api_module # Start scheduler if enabled - import os - - if os.getenv("API_SCHEDULER_ON", True): + if os.getenv("API_SCHEDULER_ON", "true").lower() == "true": mem_scheduler.start() logger.info("Scheduler started") @@ -253,6 +264,7 @@ def init_server() -> dict[str, Any]: "mos_server": mos_server, "mem_scheduler": mem_scheduler, "naive_mem_cube": naive_mem_cube, + "searcher": searcher, "api_module": api_module, "vector_db": vector_db, "pref_extractor": pref_extractor, diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index cf2ab73b..7d7d52dc 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -18,7 +18,7 @@ from memos.api.product_models import APISearchRequest, SearchResponse from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger -from memos.mem_scheduler.schemas.general_schemas import SearchMode +from memos.mem_scheduler.schemas.general_schemas import FINE_STRATEGY, FineStrategy, SearchMode from memos.types import MOSSearchResult, UserContext @@ -40,7 +40,7 @@ def __init__(self, dependencies: HandlerDependencies): dependencies: HandlerDependencies instance """ super().__init__(dependencies) - self._validate_dependencies("naive_mem_cube", "mem_scheduler") + self._validate_dependencies("naive_mem_cube", "mem_scheduler", "searcher") def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: """ @@ -211,11 +211,17 @@ def _fast_search( return formatted_memories + def _deep_search( + self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int + ) -> list: + logger.error("waiting to be implemented") + return [] + def _fine_search( self, search_req: APISearchRequest, user_context: UserContext, - ) -> list: + ) -> list[str]: """ Fine-grained search with query enhancement. 
diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py
index cf2ab73b..7d7d52dc 100644
--- a/src/memos/api/handlers/search_handler.py
+++ b/src/memos/api/handlers/search_handler.py
@@ -18,7 +18,7 @@
 from memos.api.product_models import APISearchRequest, SearchResponse
 from memos.context.context import ContextThreadPoolExecutor
 from memos.log import get_logger
-from memos.mem_scheduler.schemas.general_schemas import SearchMode
+from memos.mem_scheduler.schemas.general_schemas import FINE_STRATEGY, FineStrategy, SearchMode
 from memos.types import MOSSearchResult, UserContext
 
 
@@ -40,7 +40,7 @@ def __init__(self, dependencies: HandlerDependencies):
             dependencies: HandlerDependencies instance
         """
         super().__init__(dependencies)
-        self._validate_dependencies("naive_mem_cube", "mem_scheduler")
+        self._validate_dependencies("naive_mem_cube", "mem_scheduler", "searcher")
 
     def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse:
         """
@@ -211,11 +211,17 @@ def _fast_search(
 
         return formatted_memories
 
+    def _deep_search(
+        self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int
+    ) -> list:
+        logger.error("_deep_search is not implemented yet")
+        return []
+
     def _fine_search(
         self,
         search_req: APISearchRequest,
         user_context: UserContext,
-    ) -> list:
+    ) -> list[str]:
         """
         Fine-grained search with query enhancement.
@@ -226,11 +232,14 @@ def _fine_search(
         Returns:
             List of enhanced search results
         """
+        if FINE_STRATEGY == FineStrategy.DEEP_SEARCH:
+            return self._deep_search(
+                search_req=search_req, user_context=user_context, max_thinking_depth=3
+            )
+
         target_session_id = search_req.session_id or "default_session"
         search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
 
-        searcher = self.mem_scheduler.searcher
-
         info = {
             "user_id": search_req.user_id,
             "session_id": target_session_id,
@@ -238,7 +247,7 @@ def _fine_search(
         }
 
         # Fine retrieve
-        fast_retrieved_memories = searcher.retrieve(
+        raw_retrieved_memories = self.searcher.retrieve(
             query=search_req.query,
             user_name=user_context.mem_cube_id,
             top_k=search_req.top_k,
@@ -250,8 +259,8 @@ def _fine_search(
         )
 
         # Post retrieve
-        fast_memories = searcher.post_retrieve(
-            retrieved_results=fast_retrieved_memories,
+        raw_memories = self.searcher.post_retrieve(
+            retrieved_results=raw_retrieved_memories,
             top_k=search_req.top_k,
             user_name=user_context.mem_cube_id,
             info=info,
@@ -260,22 +269,22 @@ def _fine_search(
 
         # Enhance with query
         enhanced_memories, _ = self.mem_scheduler.retriever.enhance_memories_with_query(
             query_history=[search_req.query],
-            memories=fast_memories,
+            memories=raw_memories,
         )
-        if len(enhanced_memories) < len(fast_memories):
+        if len(enhanced_memories) < len(raw_memories):
             logger.info(
-                f"Enhanced memories ({len(enhanced_memories)}) are less than fast memories ({len(fast_memories)}). Recalling for more."
+                f"Enhanced memories ({len(enhanced_memories)}) are fewer than raw memories ({len(raw_memories)}). Recalling for more."
             )
             missing_info_hint, trigger = self.mem_scheduler.retriever.recall_for_missing_memories(
                 query=search_req.query,
-                memories=fast_memories,
+                memories=raw_memories,
             )
-            retrieval_size = len(fast_memories) - len(enhanced_memories)
+            retrieval_size = len(raw_memories) - len(enhanced_memories)
             logger.info(f"Retrieval size: {retrieval_size}")
             if trigger:
                 logger.info(f"Triggering additional search with hint: {missing_info_hint}")
-                additional_memories = searcher.search(
+                additional_memories = self.searcher.search(
                     query=missing_info_hint,
                     user_name=user_context.mem_cube_id,
                     top_k=retrieval_size,
@@ -286,7 +295,7 @@ def _fine_search(
                 )
             else:
-                logger.info("Not triggering additional search, using fast memories.")
-                additional_memories = fast_memories[:retrieval_size]
+                logger.info("Not triggering additional search, using raw memories.")
+                additional_memories = raw_memories[:retrieval_size]
             enhanced_memories += additional_memories
 
             logger.info(
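Context for the new strategy gate in `_fine_search`: `FINE_STRATEGY` is a module-level constant imported from `general_schemas`, and `DEEP_SEARCH` currently routes to a stub. A self-contained sketch of the dispatch (the `REWRITE` default below is an assumption for illustration; the real value lives in `general_schemas`):

    from enum import Enum


    class FineStrategy(str, Enum):
        REWRITE = "rewrite"
        RECREATE = "recreate"
        DEEP_SEARCH = "deep_search"


    FINE_STRATEGY = FineStrategy.REWRITE  # assumed default, for illustration only


    def deep_search(query: str, max_thinking_depth: int) -> list[str]:
        return []  # stub, mirroring the placeholder _deep_search above


    def fine_search(query: str) -> list[str]:
        # The strategy is decided once per call, before any retrieval work happens.
        if FINE_STRATEGY == FineStrategy.DEEP_SEARCH:
            return deep_search(query, max_thinking_depth=3)
        return [f"enhanced result for {query!r}"]


    print(fine_search("what did I eat yesterday"))
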
        self.text_mem: TreeTextMemory = self.mem_cube.text_mem
-        self.searcher: Searcher = self.text_mem.get_searcher(
-            manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
-            moscube=False,
-        )
         self.reranker: HTTPBGEReranker = self.text_mem.reranker
+        if searcher is None:
+            self.searcher: Searcher = self.text_mem.get_searcher(
+                manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
+                moscube=False,
+            )
+        else:
+            self.searcher = searcher
 
     def initialize_modules(
         self,
diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py
index 524eab78..8dd51c5b 100644
--- a/src/memos/mem_scheduler/schemas/general_schemas.py
+++ b/src/memos/mem_scheduler/schemas/general_schemas.py
@@ -18,6 +18,7 @@ class FineStrategy(str, Enum):
 
     REWRITE = "rewrite"
     RECREATE = "recreate"
+    DEEP_SEARCH = "deep_search"
 
 
 FILE_PATH = Path(__file__).absolute()
diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
index ac9f9a6d..b1a30475 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
@@ -150,7 +150,6 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
                     self.metrics.on_done(label=m.label, mem_cube_id=m.mem_cube_id, now=time.time())
 
                 # acknowledge redis messages
-
                 if self.use_redis_queue and self.memos_message_queue is not None:
                     for msg in messages:
                         redis_message_id = msg.redis_message_id
diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
index 93dd8113..f7e3eac1 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
@@ -79,8 +79,7 @@ def put(
 
     def get(
         self,
-        user_id: str,
-        mem_cube_id: str,
+        stream_key: str,
         block: bool = True,
         timeout: float | None = None,
         batch_size: int | None = None,
@@ -91,8 +90,6 @@ def get(
             )
             return []
 
-        stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id)
-
         # Return empty list if queue does not exist
         if stream_key not in self.queue_streams:
             logger.error(f"Stream {stream_key} does not exist when trying to get messages.")
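Migration note for callers of `get`: both queue implementations now take the fully formed `stream_key` instead of `(user_id, mem_cube_id)`. Assuming keys keep the `<prefix>:<user_id>:<mem_cube_id>` shape implied by the parsing code removed from `task_queue.py`, a caller migrates roughly like this (the prefix below is hypothetical; the real one lives on the queue object):

    PREFIX = "scheduler:streams"  # hypothetical prefix, for illustration only


    def get_stream_key(user_id: str, mem_cube_id: str) -> str:
        # Same shape the old task_queue code recovered via stream_key.split(":")[-2:].
        return f"{PREFIX}:{user_id}:{mem_cube_id}"


    # Before: queue.get(user_id="u1", mem_cube_id="cube1", block=False, batch_size=10)
    # After:
    stream_key = get_stream_key("u1", "cube1")
    # queue.get(stream_key=stream_key, block=False, batch_size=10)
    print(stream_key)  # scheduler:streams:u1:cube1
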
diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
index fe7e3452..5e850c8c 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
@@ -5,6 +5,7 @@
 the local memos_message_queue functionality in BaseScheduler.
 """
 
+import re
 import time
 
 from collections.abc import Callable
@@ -165,8 +166,7 @@ def ack_message(self, user_id, mem_cube_id, redis_message_id) -> None:
 
     def get(
         self,
-        user_id: str,
-        mem_cube_id: str,
+        stream_key: str,
         block: bool = True,
         timeout: float | None = None,
         batch_size: int | None = None,
@@ -175,8 +175,6 @@ def get(
             raise ConnectionError("Not connected to Redis. Redis connection not available.")
 
         try:
-            stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id)
-
             # Calculate timeout for Redis
             redis_timeout = None
             if block and timeout is not None:
@@ -295,17 +293,21 @@ def get_stream_keys(self) -> list[str]:
         if not self._redis_conn:
             return []
 
-        try:
-            # Use match parameter and decode byte strings to regular strings
-            stream_keys = [
-                key.decode("utf-8") if isinstance(key, bytes) else key
-                for key in self._redis_conn.scan_iter(match=f"{self.stream_key_prefix}:*")
-            ]
-            logger.debug(f"get stream_keys from redis: {stream_keys}")
-            return stream_keys
-        except Exception as e:
-            logger.error(f"Failed to list Redis stream keys: {e}")
-            return []
+        # First, get all keys that might match (using Redis glob-style pattern matching)
+        redis_pattern = f"{self.stream_key_prefix}:*"
+        raw_keys = [
+            key.decode("utf-8") if isinstance(key, bytes) else key
+            for key in self._redis_conn.scan_iter(match=redis_pattern)
+        ]
+
+        # Second, filter with a Python regex so only keys carrying the exact literal
+        # prefix survive (re.escape neutralizes any regex metacharacters in it)
+        escaped_prefix = re.escape(self.stream_key_prefix)
+        regex_pattern = f"^{escaped_prefix}:"
+        stream_keys = [key for key in raw_keys if re.match(regex_pattern, key)]
+
+        logger.debug(f"get stream_keys from redis: {stream_keys}")
+        return stream_keys
 
     def size(self) -> int:
         """
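Why `get_stream_keys` adds a Python-side filter: Redis `MATCH` patterns are globs, so characters such as `[`, `]`, and `?` in the prefix act as pattern syntax rather than literals, and `scan_iter` can over- or under-match. The `re.escape` check guarantees an exact literal prefix. A runnable illustration with a deliberately awkward, hypothetical prefix:

    import re

    stream_key_prefix = "sched:queue[v2]"  # hypothetical prefix containing glob metacharacters
    raw_keys = [  # pretend scan_iter returned these (globs can over-match)
        "sched:queuev:stream",
        "sched:queue[v2]:u1:cube1",
    ]

    pattern = f"^{re.escape(stream_key_prefix)}:"
    print([k for k in raw_keys if re.match(pattern, k)])  # ['sched:queue[v2]:u1:cube1']
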
diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py
index 74f1ad1f..6d824f4b 100644
--- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py
@@ -5,8 +5,6 @@
 the local memos_message_queue functionality in BaseScheduler.
 """
 
-from collections import defaultdict
-
 from memos.log import get_logger
 from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
 from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue
@@ -58,9 +56,10 @@ def debug_mode_on(self):
 
     def get_stream_keys(self) -> list[str]:
         if isinstance(self.memos_message_queue, SchedulerRedisQueue):
-            return self.memos_message_queue.get_stream_keys()
+            stream_keys = self.memos_message_queue.get_stream_keys()
         else:
-            return list(self.memos_message_queue.queue_streams.keys())
+            stream_keys = list(self.memos_message_queue.queue_streams.keys())
+        return stream_keys
 
     def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]):
         """Submit messages to the message queue (either local queue or Redis)."""
@@ -98,50 +97,26 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt
         )
 
     def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
-        # Discover all active streams via queue API
-        streams: list[tuple[str, str]] = []
         stream_keys = self.get_stream_keys()
-        for stream_key in stream_keys:
-            try:
-                parts = stream_key.split(":")
-                if len(parts) >= 3:
-                    user_id = parts[-2]
-                    mem_cube_id = parts[-1]
-                    streams.append((user_id, mem_cube_id))
-            except Exception as e:
-                logger.debug(f"Failed to parse stream key {stream_key}: {e}")
-
-        if not streams:
+
+        if len(stream_keys) == 0:
             return []
 
         messages: list[ScheduleMessageItem] = []
-        # Group by user: {user_id: [mem_cube_id, ...]}
-
-        streams_by_user: dict[str, list[str]] = defaultdict(list)
-        for user_id, mem_cube_id in streams:
-            streams_by_user[user_id].append(mem_cube_id)
-
-        # For each user, fairly consume up to batch_size across their streams
-        for user_id, mem_cube_ids in streams_by_user.items():
-            if not mem_cube_ids:
-                continue
-
-            # First pass: give each stream an equal share for this user
-            for mem_cube_id in mem_cube_ids:
-                fetched = self.memos_message_queue.get(
-                    user_id=user_id,
-                    mem_cube_id=mem_cube_id,
-                    block=False,
-                    batch_size=batch_size,
-                )
-
-                messages.extend(fetched)
-
-        logger.info(
-            f"Fetched {len(messages)} messages across users with per-user batch_size={batch_size}"
-        )
+        for stream_key in stream_keys:
+            fetched = self.memos_message_queue.get(
+                stream_key=stream_key,
+                block=False,
+                batch_size=batch_size,
+            )
+
+            messages.extend(fetched)
+
+        if len(messages) > 0:
+            logger.debug(
+                f"Fetched {len(messages)} messages across {len(stream_keys)} streams with per-stream batch_size={batch_size}"
+            )
         return messages
 
     def clear(self):
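The rewritten `get_messages` drops the per-user grouping and simply drains up to `batch_size` messages from every stream in one pass. A self-contained sketch of that consumption loop against a dict-backed stand-in queue (not the real Redis-backed one):

    from collections import deque

    queue_streams = {  # stand-in for SchedulerLocalQueue.queue_streams
        "q:u1:cube1": deque(["m1", "m2", "m3"]),
        "q:u2:cube2": deque(["m4"]),
    }


    def get(stream_key: str, batch_size: int) -> list[str]:
        stream = queue_streams.get(stream_key, deque())
        return [stream.popleft() for _ in range(min(batch_size, len(stream)))]


    def get_messages(batch_size: int) -> list[str]:
        messages: list[str] = []
        for stream_key in list(queue_streams):
            messages.extend(get(stream_key, batch_size))  # per-stream cap, no user grouping
        return messages


    print(get_messages(batch_size=2))  # ['m1', 'm2', 'm4']

The trade-off of this simplification: fairness across users is no longer enforced within a batch, so a busy stream can claim its full `batch_size` on every cycle.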