diff --git a/evaluation/scripts/utils/client.py b/evaluation/scripts/utils/client.py
index 9b686a131..4e7cfdbca 100644
--- a/evaluation/scripts/utils/client.py
+++ b/evaluation/scripts/utils/client.py
@@ -181,7 +181,7 @@ def search(self, query, user_id, top_k):
                 "mem_cube_id": user_id,
                 "conversation_id": "",
                 "top_k": top_k,
-                "mode": "fast",
+                "mode": os.getenv("SEARCH_MODE", "fast"),
                 "handle_pref_mem": False,
             },
             ensure_ascii=False,
@@ -232,7 +232,7 @@ def search(self, query, user_id, top_k):
                 "query": query,
                 "user_id": user_id,
                 "memory_limit_number": top_k,
-                "mode": "mixture",
+                "mode": os.getenv("SEARCH_MODE", "fast"),
             }
         )
diff --git a/src/memos/graph_dbs/neo4j.py b/src/memos/graph_dbs/neo4j.py
index f3a36a887..367b486cd 100644
--- a/src/memos/graph_dbs/neo4j.py
+++ b/src/memos/graph_dbs/neo4j.py
@@ -1072,7 +1072,7 @@ def drop_database(self) -> None:
             with self.driver.session(database=self.system_db_name) as session:
                 session.run(f"DROP DATABASE {self.db_name} IF EXISTS")
-            print(f"Database '{self.db_name}' has been dropped.")
+            logger.info(f"Database '{self.db_name}' has been dropped.")
         else:
             raise ValueError(
                 f"Refusing to drop protected database: {self.db_name} in "
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index 971a56e04..5d50cf68f 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -72,7 +72,7 @@ def detect_embedding_field(embedding_list):
     if dim == 1024:
         return "embedding"
     else:
-        print(f"⚠️ Unknown embedding dimension {dim}, skipping this vector")
+        logger.warning(f"Unknown embedding dimension {dim}, skipping this vector")
         return None
@@ -274,8 +274,6 @@ def get_memory_count(self, memory_type: str, user_name: str | None = None) -> in
             query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
             params = [f'"{memory_type}"', f'"{user_name}"']
 
-        print(f"[get_memory_count] Query: {query}, Params: {params}")
-
         try:
             with self.connection.cursor() as cursor:
                 cursor.execute(query, params)
@@ -298,13 +296,10 @@ def node_not_exist(self, scope: str, user_name: str | None = None) -> int:
         query += "\nLIMIT 1"
         params = [f'"{scope}"', f'"{user_name}"']
 
-        print(f"[node_not_exist] Query: {query}, Params: {params}")
-
         try:
             with self.connection.cursor() as cursor:
                 cursor.execute(query, params)
                 result = cursor.fetchone()
-                print(f"[node_not_exist] Query result: {result}")
                 return 1 if result else 0
         except Exception as e:
             logger.error(f"[node_not_exist] Query failed: {e}", exc_info=True)
@@ -419,7 +414,6 @@ def update_node(self, id: str, fields: dict[str, Any], user_name: str | None = N
             query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
             params.append(f'"{user_name}"')
 
-        print(f"[update_node] query: {query}, params: {params}")
        try:
             with self.connection.cursor() as cursor:
                 cursor.execute(query, params)
@@ -446,7 +440,6 @@ def delete_node(self, id: str, user_name: str | None = None) -> None:
             query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
             params.append(f'"{user_name}"')
 
-        print(f"[delete_node] query: {query}, params: {params}")
        try:
             with self.connection.cursor() as cursor:
                 cursor.execute(query, params)
@@ -462,22 +455,24 @@ def create_extension(self):
                 # Ensure in the correct database context
                 cursor.execute("SELECT current_database();")
                 current_db = cursor.fetchone()[0]
-                print(f"Current database context: {current_db}")
+                logger.info(f"Current database context: {current_db}")
 
                 for ext_name, ext_desc in extensions:
                     try:
                         cursor.execute(f"create extension if not exists {ext_name};")
-                        print(f"✅ Extension '{ext_name}' ({ext_desc}) ensured.")
+                        logger.info(f"Extension '{ext_name}' ({ext_desc}) ensured.")
                     except Exception as e:
                         if "already exists" in str(e):
-                            print(f"ℹ️ Extension '{ext_name}' ({ext_desc}) already exists.")
+                            logger.info(f"Extension '{ext_name}' ({ext_desc}) already exists.")
                         else:
-                            print(f"⚠️ Failed to create extension '{ext_name}' ({ext_desc}): {e}")
+                            logger.warning(
+                                f"Failed to create extension '{ext_name}' ({ext_desc}): {e}"
+                            )
                             logger.error(
                                 f"Failed to create extension '{ext_name}': {e}", exc_info=True
                             )
         except Exception as e:
-            print(f"⚠️ Failed to access database context: {e}")
+            logger.warning(f"Failed to access database context: {e}")
             logger.error(f"Failed to access database context: {e}", exc_info=True)
 
     @timed
@@ -491,12 +486,12 @@ def create_graph(self):
                 graph_exists = cursor.fetchone()[0] > 0
 
                 if graph_exists:
-                    print(f"ℹ️ Graph '{self.db_name}_graph' already exists.")
+                    logger.info(f"Graph '{self.db_name}_graph' already exists.")
                 else:
                     cursor.execute(f"select create_graph('{self.db_name}_graph');")
-                    print(f"✅ Graph database '{self.db_name}_graph' created.")
+                    logger.info(f"Graph database '{self.db_name}_graph' created.")
         except Exception as e:
-            print(f"⚠️ Failed to create graph '{self.db_name}_graph': {e}")
+            logger.warning(f"Failed to create graph '{self.db_name}_graph': {e}")
             logger.error(f"Failed to create graph '{self.db_name}_graph': {e}", exc_info=True)
 
     @timed
@@ -506,16 +501,16 @@ def create_edge(self):
         valid_rel_types = {"AGGREGATE_TO", "FOLLOWS", "INFERS", "MERGED_TO", "RELATE_TO", "PARENT"}
 
         for label_name in valid_rel_types:
-            print(f"🪶 Creating elabel: {label_name}")
+            logger.info(f"Creating elabel: {label_name}")
             try:
                 with self.connection.cursor() as cursor:
                     cursor.execute(f"select create_elabel('{self.db_name}_graph', '{label_name}');")
-                    print(f"✅ Successfully created elabel: {label_name}")
+                    logger.info(f"Successfully created elabel: {label_name}")
             except Exception as e:
                 if "already exists" in str(e):
-                    print(f"ℹ️ Label '{label_name}' already exists, skipping.")
+                    logger.info(f"Label '{label_name}' already exists, skipping.")
                 else:
-                    print(f"⚠️ Failed to create label {label_name}: {e}")
+                    logger.warning(f"Failed to create label {label_name}: {e}")
                     logger.error(f"Failed to create elabel '{label_name}': {e}", exc_info=True)
 
     @timed
@@ -547,7 +542,6 @@ def add_edge(
             AND end_id = ag_catalog._make_graph_id('{self.db_name}_graph'::name, 'Memory'::name, '{target_id}'::text::cstring)
         );
         """
-        print(f"Executing add_edge: {query}")
 
         try:
             with self.connection.cursor() as cursor:
@@ -658,7 +652,6 @@ def edge_exists(
         # Prepare the relationship pattern
         user_name = user_name if user_name else self.config.user_name
-        print(f"edge_exists direction: {direction}")
 
         # Prepare the match pattern with direction
         if direction == "OUTGOING":
@@ -681,7 +674,6 @@
         query += "\nRETURN r"
         query += "\n$$) AS (r agtype)"
-        print(f"edge_exists query: {query}")
         with self.connection.cursor() as cursor:
             cursor.execute(query)
             result = cursor.fetchone()
@@ -728,7 +720,6 @@ def format_param_value(value: str) -> str:
             query += "\nAND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
             params.append(format_param_value(user_name))
 
-        print(f"[get_node] query: {query}, params: {params}")
        try:
             with self.connection.cursor() as cursor:
                 cursor.execute(query, params)
@@ -812,7 +803,6 @@ def get_nodes(
             query += " AND ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = %s::agtype"
             params.append(f'"{user_name}"')
 
-        print(f"[get_nodes] query: {query}, params: {params}")
         with self.connection.cursor() as cursor:
             cursor.execute(query, params)
             results = cursor.fetchall()
@@ -1067,8 +1057,6 @@ def get_children_with_embeddings(
             WHERE t.cid::graphid = m.id;
         """
 
-        print("[get_children_with_embeddings] query:", query)
-
         try:
             with self.connection.cursor() as cursor:
                 cursor.execute(query)
@@ -1191,7 +1179,6 @@ def get_subgraph(
         with self.connection.cursor() as cursor:
             cursor.execute(query)
             result = cursor.fetchone()
-            print("[get_subgraph] result:", result)
 
             if not result or not result[0]:
                 return {"core_node": None, "neighbors": [], "edges": []}
@@ -1346,9 +1333,6 @@ def search_by_embedding(
         """
 
         params = [vector]
-        print(
-            f"[search_by_embedding] query: {query}, params: {params}, where_clause: {where_clause}"
-        )
         with self.connection.cursor() as cursor:
             cursor.execute(query, params)
             results = cursor.fetchall()
@@ -1417,7 +1401,6 @@ def get_by_metadata(
                 escaped_value = f"[{', '.join(list_items)}]"
             else:
                 escaped_value = f"'{value}'" if isinstance(value, str) else str(value)
-            print("op=============:", op)
 
             # Build WHERE conditions
             if op == "=":
                 where_conditions.append(f"n.{field} = {escaped_value}")
@@ -1455,16 +1438,13 @@ def get_by_metadata(
             $$) AS (id agtype)
         """
 
-        print(f"[get_by_metadata] query: {cypher_query}, where_str: {where_str}")
         ids = []
         try:
             with self.connection.cursor() as cursor:
                 cursor.execute(cypher_query)
                 results = cursor.fetchall()
-                print("[get_by_metadata] result:", results)
                 ids = [str(item[0]).strip('"') for item in results]
         except Exception as e:
-            print("Failed to get metadata:", {e})
             logger.error(f"Failed to get metadata: {e}, query is {cypher_query}")
 
         return ids
@@ -1494,7 +1474,6 @@ def get_grouped_counts1(
             raise ValueError("group_fields cannot be empty")
 
         final_params = params.copy() if params else {}
-        print("username:" + user_name)
         if not self.config.use_multi_db and (self.config.user_name or user_name):
             user_clause = "n.user_name = $user_name"
             final_params["user_name"] = user_name
@@ -1506,14 +1485,12 @@ def get_grouped_counts1(
                 where_clause = f"WHERE {where_clause} AND {user_clause}"
             else:
                 where_clause = f"WHERE {user_clause}"
-        print("where_clause:" + where_clause)
 
         # Force RETURN field AS field to guarantee key match
         group_fields_cypher = ", ".join([f"n.{field} AS {field}" for field in group_fields])
         """
        # group_fields_cypher_polardb = "agtype, ".join([f"{field}" for field in group_fields])
         """
         group_fields_cypher_polardb = ", ".join([f"{field} agtype" for field in group_fields])
-        print("group_fields_cypher_polardb:" + group_fields_cypher_polardb)
         query = f"""
             SELECT * FROM cypher('{self.db_name}_graph', $$
                 MATCH (n:Memory)
@@ -1521,7 +1498,6 @@ def get_grouped_counts1(
                 RETURN {group_fields_cypher}, COUNT(n) AS count1
             $$ ) as ({group_fields_cypher_polardb}, count1 agtype);
         """
-        print("get_grouped_counts:" + query)
         try:
             with self.connection.cursor() as cursor:
                 # Handle parameterized query
@@ -1620,8 +1596,6 @@ def get_grouped_counts(
             GROUP BY {", ".join(group_by_fields)}
         """
 
-        print("[get_grouped_counts] query:", query)
-
         try:
             with self.connection.cursor() as cursor:
                 # Handle parameterized query
@@ -1889,7 +1863,6 @@ def get_all_memory_items(
         """
         nodes = []
         node_ids = set()
-        print("[get_all_memory_items embedding true ] cypher_query:", cypher_query)
         try:
             with self.connection.cursor() as cursor:
                 cursor.execute(cypher_query)
@@ -1924,7 +1897,6 @@ def get_all_memory_items(
             LIMIT 100
             $$) AS (nprops agtype)
         """
-        print("[get_all_memory_items embedding false ] cypher_query:", cypher_query)
 
         nodes = []
         try:
@@ -1993,14 +1965,12 @@ def get_all_memory_items_old(
             LIMIT 100
             $$) AS (nprops agtype)
         """
-        print("[get_all_memory_items] cypher_query:", cypher_query)
 
         nodes = []
         try:
             with self.connection.cursor() as cursor:
                 cursor.execute(cypher_query)
                 results = cursor.fetchall()
-                print("[get_all_memory_items] results:", results)
 
                 for row in results:
                     node_agtype = row[0]
@@ -2025,16 +1995,14 @@ def get_all_memory_items_old(
                                 parsed_node_data["embedding"] = properties["embedding"]
 
                             nodes.append(self._parse_node(parsed_node_data))
-                            print(
-                                f"[get_all_memory_items] ✅ Parsed node successfully: {properties.get('id', '')}"
+                            logger.debug(
+                                f"[get_all_memory_items] Parsed node successfully: {properties.get('id', '')}"
                             )
                         else:
-                            print(
-                                f"[get_all_memory_items] ❌ Invalid node data format: {node_data}"
-                            )
+                            logger.warning(f"Invalid node data format: {node_data}")
                     except (json.JSONDecodeError, TypeError) as e:
-                        print(f"[get_all_memory_items] ❌ JSON parsing failed: {e}")
+                        logger.error(f"JSON parsing failed: {e}")
                 elif node_agtype and hasattr(node_agtype, "value"):
                     # Handle agtype object
                     node_props = node_agtype.value
@@ -2050,13 +2018,8 @@ def get_all_memory_items_old(
                         node_data["embedding"] = node_props["embedding"]
 
                     nodes.append(self._parse_node(node_data))
-                    print(
-                        f"[get_all_memory_items] ✅ Parsed agtype node successfully: {node_props.get('id', '')}"
-                    )
                 else:
-                    print(
-                        f"[get_all_memory_items] ❌ Unknown data format: {type(node_agtype)}"
-                    )
+                    logger.warning(f"Unknown data format: {type(node_agtype)}")
 
         except Exception as e:
             logger.error(f"Failed to get memories: {e}", exc_info=True)
@@ -2152,7 +2115,7 @@ def get_structure_optimization_candidates(
             {self.db_name}_graph."Memory" m
         WHERE t.id1 = m.id
         """
-        print("[get_structure_optimization_candidates] query:", cypher_query)
+        logger.info(f"[get_structure_optimization_candidates] query: {cypher_query}")
 
         candidates = []
         node_ids = set()
@@ -2160,7 +2123,7 @@ def get_structure_optimization_candidates(
             with self.connection.cursor() as cursor:
                 cursor.execute(cypher_query)
                 results = cursor.fetchall()
-                print("result------", len(results))
+                logger.info(f"Found {len(results)} structure optimization candidates")
                 for row in results:
                     if include_embedding:
                         # When include_embedding=True, return full node object
@@ -2228,9 +2191,9 @@ def get_structure_optimization_candidates(
                             if node_id not in node_ids:
                                 candidates.append(node)
                                 node_ids.add(node_id)
-                                print(f"✅ Parsed node successfully: {node_id}")
+                                logger.debug(f"Parsed node successfully: {node_id}")
                         except Exception as e:
-                            print(f"❌ Failed to parse node: {e}")
+                            logger.error(f"Failed to parse node: {e}")
 
         except Exception as e:
             logger.error(f"Failed to get structure optimization candidates: {e}", exc_info=True)
@@ -2243,7 +2206,7 @@ def drop_database(self) -> None:
         if self._get_config_value("use_multi_db", True):
             with self.connection.cursor() as cursor:
                 cursor.execute(f"SELECT drop_graph('{self.db_name}_graph', true)")
-                print(f"Graph '{self.db_name}_graph' has been dropped.")
+                logger.info(f"Graph '{self.db_name}_graph' has been dropped.")
         else:
             raise ValueError(
                 f"Refusing to drop graph '{self.db_name}_graph' in "
@@ -2498,7 +2461,7 @@ def get_neighbors_by_tag(
             WHERE {where_clause}
         """
 
-        print(f"[get_neighbors_by_tag] query: {query}, params: {params}")
+        logger.debug(f"[get_neighbors_by_tag] query: {query}, params: {params}")
         try:
             with self.connection.cursor() as cursor:
@@ -2646,7 +2609,7 @@ def get_neighbors_by_tag_ccl(
             ORDER BY (overlap_count::integer) DESC
             LIMIT {top_k}
         """
-        print("get_neighbors_by_tag:", query)
+        logger.debug(f"get_neighbors_by_tag: {query}")
         try:
             with self.connection.cursor() as cursor:
                 cursor.execute(query)
diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 0360396af..d679eba9c 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -6,6 +6,7 @@
 from collections.abc import Callable
 from datetime import datetime
 from pathlib import Path
+from typing import TYPE_CHECKING
 
 from sqlalchemy.engine import Engine
 
@@ -50,6 +51,10 @@
 from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE
 
 
+if TYPE_CHECKING:
+    from memos.mem_cube.base import BaseMemCube
+
+
 logger = get_logger(__name__)
 
 
@@ -124,7 +129,7 @@ def __init__(self, config: BaseSchedulerConfig):
         self._context_lock = threading.Lock()
         self.current_user_id: UserID | str | None = None
         self.current_mem_cube_id: MemCubeID | str | None = None
-        self.current_mem_cube: GeneralMemCube | None = None
+        self.current_mem_cube: BaseMemCube | None = None
         self.auth_config_path: str | Path | None = self.config.get("auth_config_path", None)
         self.auth_config = None
         self.rabbitmq_config = None
diff --git a/src/memos/mem_scheduler/general_modules/api_misc.py b/src/memos/mem_scheduler/general_modules/api_misc.py
index bb993de38..1b10804fc 100644
--- a/src/memos/mem_scheduler/general_modules/api_misc.py
+++ b/src/memos/mem_scheduler/general_modules/api_misc.py
@@ -8,7 +8,6 @@
     APISearchHistoryManager,
     TaskRunningStatus,
 )
-from memos.mem_scheduler.utils.db_utils import get_utc_now
 from memos.memories.textual.item import TextualMemoryItem
 
 
@@ -16,16 +15,20 @@
 
 
 class SchedulerAPIModule(BaseSchedulerModule):
-    def __init__(self, window_size=5):
+    def __init__(self, window_size: int | None = None, history_memory_turns: int | None = None):
         super().__init__()
 
         self.window_size = window_size
+        self.history_memory_turns = history_memory_turns
         self.search_history_managers: dict[str, APIRedisDBManager] = {}
-        self.pre_memory_turns = 5
 
     def get_search_history_manager(self, user_id: str, mem_cube_id: str) -> APIRedisDBManager:
         """Get or create a Redis manager for search history."""
+        logger.info(
+            f"Getting search history manager for user_id: {user_id}, mem_cube_id: {mem_cube_id}"
+        )
         key = f"search_history:{user_id}:{mem_cube_id}"
         if key not in self.search_history_managers:
+            logger.info(f"Creating new search history manager for key: {key}")
             self.search_history_managers[key] = APIRedisDBManager(
                 user_id=user_id,
                 mem_cube_id=mem_cube_id,
@@ -41,8 +44,12 @@ def sync_search_data(
         query: str,
         memories: list[TextualMemoryItem],
         formatted_memories: Any,
-        conversation_id: str | None = None,
+        session_id: str | None = None,
+        conversation_turn: int = 0,
     ) -> Any:
+        logger.info(
+            f"Syncing search data for item_id: {item_id}, user_id: {user_id}, mem_cube_id: {mem_cube_id}"
+        )
         # Get the search history manager
         manager = self.get_search_history_manager(user_id, mem_cube_id)
         manager.sync_with_redis(size_limit=self.window_size)
@@ -59,7 +66,7 @@ def sync_search_data(
                 query=query,
                 formatted_memories=formatted_memories,
                 task_status=TaskRunningStatus.COMPLETED,  # Use the provided running_status
-                conversation_id=conversation_id,
+                session_id=session_id,
                 memories=memories,
             )
 
@@ -69,18 +76,18 @@ def sync_search_data(
                 logger.warning(f"Failed to update entry with item_id: {item_id}")
         else:
             # Add new entry based on running_status
-            search_entry = APIMemoryHistoryEntryItem(
+            entry_item = APIMemoryHistoryEntryItem(
                 item_id=item_id,
                 query=query,
                 formatted_memories=formatted_memories,
                 memories=memories,
                 task_status=TaskRunningStatus.COMPLETED,
-                conversation_id=conversation_id,
-                created_time=get_utc_now(),
+                session_id=session_id,
+                conversation_turn=conversation_turn,
             )
 
             # Add directly to completed list as APIMemoryHistoryEntryItem instance
-            search_history.completed_entries.append(search_entry)
+            search_history.completed_entries.append(entry_item)
 
             # Maintain window size
             if len(search_history.completed_entries) > search_history.window_size:
@@ -101,37 +108,22 @@ def sync_search_data(
         manager.sync_with_redis(size_limit=self.window_size)
         return manager
 
-    def get_pre_memories(self, user_id: str, mem_cube_id: str) -> list:
-        """
-        Get pre-computed memories from the most recent completed search entry.
-
-        Args:
-            user_id: User identifier
-            mem_cube_id: Memory cube identifier
-
-        Returns:
-            List of TextualMemoryItem objects from the most recent completed search
-        """
-        manager = self.get_search_history_manager(user_id, mem_cube_id)
-
-        existing_data = manager.load_from_db()
-        if existing_data is None:
-            return []
-
-        search_history: APISearchHistoryManager = existing_data
-
-        # Get memories from the most recent completed entry
-        history_memories = search_history.get_history_memories(turns=self.pre_memory_turns)
-        return history_memories
-
-    def get_history_memories(self, user_id: str, mem_cube_id: str, n: int) -> list:
+    def get_history_memories(
+        self, user_id: str, mem_cube_id: str, turns: int | None = None
+    ) -> list:
         """Get history memories for backward compatibility with tests."""
+        logger.info(
+            f"Getting history memories for user_id: {user_id}, mem_cube_id: {mem_cube_id}, turns: {turns}"
+        )
         manager = self.get_search_history_manager(user_id, mem_cube_id)
 
         existing_data = manager.load_from_db()
         if existing_data is None:
             return []
 
+        if turns is None:
+            turns = self.history_memory_turns
+
         # Handle different data formats
         if isinstance(existing_data, APISearchHistoryManager):
             search_history = existing_data
@@ -142,4 +134,4 @@ def get_history_memories(
             except Exception:
                 return []
 
-        return search_history.get_history_memories(turns=n)
+        return search_history.get_history_memories(turns=turns)
diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py
index 22fb78445..a789d581e 100644
--- a/src/memos/mem_scheduler/monitors/general_monitor.py
+++ b/src/memos/mem_scheduler/monitors/general_monitor.py
@@ -76,8 +76,8 @@ def __init__(
         ] = {}
 
         # Lifecycle monitor
-        self.last_activation_mem_update_time = datetime.min
-        self.last_query_consume_time = datetime.min
+        self.last_activation_mem_update_time = get_utc_now()
+        self.last_query_consume_time = get_utc_now()
 
         self._register_lock = Lock()
         self._process_llm = process_llm
diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py
index c8e2eb59e..a087ab2df 100644
--- a/src/memos/mem_scheduler/optimized_scheduler.py
+++ b/src/memos/mem_scheduler/optimized_scheduler.py
@@ -1,11 +1,14 @@
 import json
+import os
+from collections import OrderedDict
 from typing import TYPE_CHECKING
 
 from memos.api.product_models import APISearchRequest
 from memos.configs.mem_scheduler import GeneralSchedulerConfig
 from memos.log import get_logger
 from memos.mem_cube.general import GeneralMemCube
+from memos.mem_cube.navie import NaiveMemCube
 from memos.mem_scheduler.general_modules.api_misc import SchedulerAPIModule
 from memos.mem_scheduler.general_scheduler import GeneralScheduler
 from memos.mem_scheduler.schemas.general_schemas import (
@@ -23,6 +26,7 @@
 
 if TYPE_CHECKING:
     from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
+    from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
     from memos.reranker.http_bge import HTTPBGEReranker
@@ -34,54 +38,33 @@ class OptimizedScheduler(GeneralScheduler):
     def __init__(self, config: GeneralSchedulerConfig):
         super().__init__(config)
-        self.api_module = SchedulerAPIModule()
+        self.window_size = int(os.getenv("API_SEARCH_WINDOW_SIZE", 5))
+        self.history_memory_turns = int(os.getenv("API_SEARCH_HISTORY_TURNS", 5))
+        self.session_counter = OrderedDict()
+        self.max_session_history = 5
+
+        self.api_module = SchedulerAPIModule(
+            window_size=self.window_size,
+            history_memory_turns=self.history_memory_turns,
+        )
         self.register_handlers(
             {
                 API_MIX_SEARCH_LABEL: self._api_mix_search_message_consumer,
             }
         )
 
-    def search_memories(
-        self,
-        search_req: APISearchRequest,
-        user_context: UserContext,
-        mem_cube: GeneralMemCube,
-        mode: SearchMode,
-    ):
-        """Fine search memories function copied from server_router to avoid circular import"""
-        target_session_id = search_req.session_id
-        if not target_session_id:
-            target_session_id = "default_session"
-        search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
-
-        # Create MemCube and perform search
-        search_results = mem_cube.text_mem.search(
-            query=search_req.query,
-            user_name=user_context.mem_cube_id,
-            top_k=search_req.top_k,
-            mode=mode,
-            manual_close_internet=not search_req.internet_search,
-            moscube=search_req.moscube,
-            search_filter=search_filter,
-            info={
-                "user_id": search_req.user_id,
-                "session_id": target_session_id,
-                "chat_history": search_req.chat_history,
-            },
-        )
-        return search_results
-
     def submit_memory_history_async_task(
         self,
         search_req: APISearchRequest,
         user_context: UserContext,
+        session_id: str | None = None,
     ):
         # Create message for async fine search
         message_content = {
             "search_req": {
                 "query": search_req.query,
                 "user_id": search_req.user_id,
-                "session_id": search_req.session_id,
+                "session_id": session_id,
                 "top_k": search_req.top_k,
                 "internet_search": search_req.internet_search,
                 "moscube": search_req.moscube,
@@ -110,6 +93,36 @@ def submit_memory_history_async_task(
         logger.info(f"Submitted async fine search task for user {search_req.user_id}")
         return async_task_id
 
+    def search_memories(
+        self,
+        search_req: APISearchRequest,
+        user_context: UserContext,
+        mem_cube: NaiveMemCube,
+        mode: SearchMode,
+    ):
+        """Fine search memories function copied from server_router to avoid circular import"""
+        target_session_id = search_req.session_id
+        if not target_session_id:
+            target_session_id = "default_session"
+        search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
+
+        # Create MemCube and perform search
+        search_results = mem_cube.text_mem.search(
+            query=search_req.query,
+            user_name=user_context.mem_cube_id,
+            top_k=search_req.top_k,
+            mode=mode,
+            manual_close_internet=not search_req.internet_search,
+            moscube=search_req.moscube,
+            search_filter=search_filter,
+            info={
+                "user_id": search_req.user_id,
+                "session_id": target_session_id,
+                "chat_history": search_req.chat_history,
+            },
+        )
+        return search_results
+
     def mix_search_memories(
         self,
         search_req: APISearchRequest,
@@ -122,82 +135,115 @@ def mix_search_memories(
         # Get mem_cube for fast search
         mem_cube = self.current_mem_cube
 
-        # Perform fast search
-        fast_memories = self.search_memories(
-            search_req=search_req,
-            user_context=user_context,
-            mem_cube=mem_cube,
+        target_session_id = search_req.session_id
+        if not target_session_id:
+            target_session_id = "default_session"
+        search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
+
+        text_mem: TreeTextMemory = mem_cube.text_mem
+        searcher: Searcher = text_mem.get_searcher(
+            manual_close_internet=not search_req.internet_search,
+            moscube=False,
+        )
+        # Rerank Memories - reranker expects TextualMemoryItem objects
+        reranker: HTTPBGEReranker = text_mem.reranker
+        info = {
+            "user_id": search_req.user_id,
+            "session_id": target_session_id,
+            "chat_history": search_req.chat_history,
+        }
+
+        fast_retrieved_memories = searcher.retrieve(
+            query=search_req.query,
+            user_name=user_context.mem_cube_id,
+            top_k=search_req.top_k,
             mode=SearchMode.FAST,
+            manual_close_internet=not search_req.internet_search,
+            moscube=search_req.moscube,
+            search_filter=search_filter,
+            info=info,
         )
 
         self.submit_memory_history_async_task(
             search_req=search_req,
             user_context=user_context,
+            session_id=search_req.session_id,
         )
 
         # Try to get pre-computed fine memories if available
-        pre_fine_memories = self.api_module.get_pre_memories(
-            user_id=search_req.user_id, mem_cube_id=user_context.mem_cube_id
+        history_memories = self.api_module.get_history_memories(
+            user_id=search_req.user_id,
+            mem_cube_id=user_context.mem_cube_id,
+            turns=self.history_memory_turns,
         )
-        if not pre_fine_memories:
+
+        if not history_memories:
+            fast_memories = searcher.post_retrieve(
+                retrieved_results=fast_retrieved_memories,
+                top_k=search_req.top_k,
+                user_name=user_context.mem_cube_id,
+                info=info,
+            )
             # Format fast memories for return
             formatted_memories = [format_textual_memory_item(data) for data in fast_memories]
             return formatted_memories
 
-        # Merge fast and pre-computed fine memories (both are TextualMemoryItem objects)
-        combined_memories = fast_memories + pre_fine_memories
-        # Remove duplicates based on memory content
-        seen_contents = set()
-        unique_memories = []
-        for memory in combined_memories:
-            # Both fast_memories and pre_fine_memories are TextualMemoryItem objects
-            content_key = memory.memory  # Use .memory attribute instead of .get("content", "")
-            if content_key not in seen_contents:
-                seen_contents.add(content_key)
-                unique_memories.append(memory)
-
-        # Rerank Memories - reranker expects TextualMemoryItem objects
-        reranker: HTTPBGEReranker = mem_cube.text_mem.reranker
-
-        # Use search_req parameters for reranking
-        target_session_id = search_req.session_id
-        if not target_session_id:
-            target_session_id = "default_session"
-        search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
-
-        sorted_results = reranker.rerank(
+        sorted_history_memories = reranker.rerank(
             query=search_req.query,  # Use search_req.query instead of undefined query
-            graph_results=unique_memories,  # Pass TextualMemoryItem objects directly
+            graph_results=history_memories,  # Pass TextualMemoryItem objects directly
             top_k=search_req.top_k,  # Use search_req.top_k instead of undefined top_k
             search_filter=search_filter,
         )
 
+        sorted_results = fast_retrieved_memories + sorted_history_memories
+        final_results = searcher.post_retrieve(
+            retrieved_results=sorted_results,
+            top_k=search_req.top_k,
+            user_name=user_context.mem_cube_id,
+            info=info,
+        )
+
         formatted_memories = [
-            format_textual_memory_item(item) for item, score in sorted_results[: search_req.top_k]
+            format_textual_memory_item(item) for item in final_results[: search_req.top_k]
         ]
 
         return formatted_memories
 
     def update_search_memories_to_redis(
         self,
-        user_id: str,
-        mem_cube_id: str,
         messages: list[ScheduleMessageItem],
     ):
-        mem_cube = messages[0].mem_cube
+        mem_cube: NaiveMemCube = self.current_mem_cube
 
         for msg in messages:
             content_dict = json.loads(msg.content)
             search_req = content_dict["search_req"]
             user_context = content_dict["user_context"]
 
-            fine_memories: list[TextualMemoryItem] = self.search_memories(
+            session_id = search_req.get("session_id")
+            if session_id:
+                if session_id not in self.session_counter:
+                    self.session_counter[session_id] = 0
+                else:
+                    self.session_counter[session_id] += 1
+                session_turn = self.session_counter[session_id]
+
+                # Move the current session to the end to mark it as recently used
+                self.session_counter.move_to_end(session_id)
+
+                # If the counter exceeds the max size, remove the oldest item
+                if len(self.session_counter) > self.max_session_history:
+                    self.session_counter.popitem(last=False)
+            else:
+                session_turn = 0
+
+            memories: list[TextualMemoryItem] = self.search_memories(
                 search_req=APISearchRequest(**content_dict["search_req"]),
                 user_context=UserContext(**content_dict["user_context"]),
                 mem_cube=mem_cube,
-                mode=SearchMode.FINE,
+                mode=SearchMode.FAST,
             )
-            formatted_memories = [format_textual_memory_item(data) for data in fine_memories]
+            formatted_memories = [format_textual_memory_item(data) for data in memories]
 
             # Sync search data to Redis
             self.api_module.sync_search_data(
@@ -205,8 +251,10 @@ def update_search_memories_to_redis(
                 user_id=search_req["user_id"],
                 mem_cube_id=user_context["mem_cube_id"],
                 query=search_req["query"],
-                memories=fine_memories,
+                memories=memories,
                 formatted_memories=formatted_memories,
+                session_id=session_id,
+                conversation_turn=session_turn,
             )
 
     def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
@@ -228,9 +276,7 @@ def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem])
                 messages = grouped_messages[user_id][mem_cube_id]
                 if len(messages) == 0:
                     return
-                self.update_search_memories_to_redis(
-                    user_id=user_id, mem_cube_id=mem_cube_id, messages=messages
-                )
+                self.update_search_memories_to_redis(messages=messages)
 
     def replace_working_memory(
         self,
diff --git a/src/memos/mem_scheduler/schemas/api_schemas.py b/src/memos/mem_scheduler/schemas/api_schemas.py
index 23eb5a848..6d0de49c4 100644
--- a/src/memos/mem_scheduler/schemas/api_schemas.py
+++ b/src/memos/mem_scheduler/schemas/api_schemas.py
@@ -35,11 +35,10 @@ class APIMemoryHistoryEntryItem(BaseModel, DictConversionMixin):
     task_status: str = Field(
         default="running", description="Task status: running, completed, failed"
     )
-    conversation_id: str | None = Field(
-        default=None, description="Optional conversation identifier"
-    )
+    session_id: str | None = Field(default=None, description="Optional conversation identifier")
     created_time: datetime = Field(description="Entry creation time", default_factory=get_utc_now)
     timestamp: datetime | None = Field(default=None, description="Timestamp for the entry")
+    conversation_turn: int = Field(default=0, description="Turn count for the same session_id")
 
     model_config = ConfigDict(
         arbitrary_types_allowed=True,
@@ -107,11 +106,13 @@ def get_running_item_ids(self) -> list[str]:
         """Get all running task IDs"""
         return self.running_item_ids.copy()
 
-    def get_completed_entries(self) -> list[dict[str, Any]]:
+    def get_completed_entries(self) -> list[APIMemoryHistoryEntryItem]:
         """Get all completed entries"""
         return self.completed_entries.copy()
 
-    def get_history_memory_entries(self, turns: int | None = None) -> list[dict[str, Any]]:
+    def get_history_memory_entries(
+        self, turns: int | None = None
+    ) -> list[APIMemoryHistoryEntryItem]:
         """
         Get the most recent n completed search entries, sorted by created_time.
 
@@ -179,7 +180,7 @@ def update_entry_by_item_id(
         query: str,
         formatted_memories: Any,
         task_status: TaskRunningStatus,
-        conversation_id: str | None = None,
+        session_id: str | None = None,
         memories: list[TextualMemoryItem] | None = None,
     ) -> bool:
         """
@@ -191,7 +192,7 @@ def update_entry_by_item_id(
             query: New query string
             formatted_memories: New formatted memories
             task_status: New task status
-            conversation_id: New conversation ID
+            session_id: New conversation ID
             memories: List of TextualMemoryItem objects
 
         Returns:
@@ -204,8 +205,8 @@ def update_entry_by_item_id(
                 entry.query = query
                 entry.formatted_memories = formatted_memories
                 entry.task_status = task_status
-                if conversation_id is not None:
-                    entry.conversation_id = conversation_id
+                if session_id is not None:
+                    entry.session_id = session_id
                 if memories is not None:
                     entry.memories = memories
diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py
index 8d07522cd..8ce81a8bd 100644
--- a/src/memos/memories/textual/simple_tree.py
+++ b/src/memos/memories/textual/simple_tree.py
@@ -104,6 +104,34 @@ def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int
         """
         return self.memory_manager.get_current_memory_size(user_name=user_name)
 
+    def get_searcher(
+        self,
+        manual_close_internet: bool = False,
+        moscube: bool = False,
+    ):
+        if (self.internet_retriever is not None) and manual_close_internet:
+            logger.warning(
+                "Internet retriever is init by config , but this search set manual_close_internet is True and will close it"
+            )
+            searcher = Searcher(
+                self.dispatcher_llm,
+                self.graph_store,
+                self.embedder,
+                self.reranker,
+                internet_retriever=None,
+                moscube=moscube,
+            )
+        else:
+            searcher = Searcher(
+                self.dispatcher_llm,
+                self.graph_store,
+                self.embedder,
+                self.reranker,
+                internet_retriever=self.internet_retriever,
+                moscube=moscube,
+            )
+        return searcher
+
     def search(
         self,
         query: str,
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py
index 472bed219..56c8117e9 100644
--- a/src/memos/memories/textual/tree.py
+++ b/src/memos/memories/textual/tree.py
@@ -115,6 +115,34 @@ def get_current_memory_size(self) -> dict[str, int]:
         """
         return self.memory_manager.get_current_memory_size()
 
+    def get_searcher(
+        self,
+        manual_close_internet: bool = False,
+        moscube: bool = False,
+    ):
+        if (self.internet_retriever is not None) and manual_close_internet:
+            logger.warning(
+                "Internet retriever is init by config , but this search set manual_close_internet is True and will close it"
+            )
+            searcher = Searcher(
+                self.dispatcher_llm,
+                self.graph_store,
+                self.embedder,
+                self.reranker,
+                internet_retriever=None,
+                moscube=moscube,
+            )
+        else:
+            searcher = Searcher(
+                self.dispatcher_llm,
+                self.graph_store,
+                self.embedder,
+                self.reranker,
+                internet_retriever=self.internet_retriever,
+                moscube=moscube,
+            )
+        return searcher
+
     def search(
         self,
         query: str,
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index 96c6c97f1..9d540b311 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -44,6 +44,49 @@ def __init__(
 
         self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage")
 
+    @timed
+    def retrieve(
+        self,
+        query: str,
+        top_k: int,
+        info=None,
+        mode="fast",
+        memory_type="All",
+        search_filter: dict | None = None,
+        user_name: str | None = None,
+        **kwargs,
+    ) -> list[TextualMemoryItem]:
+        logger.info(
+            f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}"
+        )
+        parsed_goal, query_embedding, context, query = self._parse_task(
+            query, info, mode, search_filter=search_filter, user_name=user_name
+        )
+        results = self._retrieve_paths(
+            query,
+            parsed_goal,
+            query_embedding,
+            info,
+            top_k,
+            mode,
+            memory_type,
+            search_filter,
+            user_name,
+        )
+        return results
+
+    def post_retrieve(
+        self,
+        retrieved_results: list[TextualMemoryItem],
+        top_k: int,
+        user_name: str | None = None,
+        info=None,
+    ):
+        deduped = self._deduplicate_results(retrieved_results)
+        final_results = self._sort_and_trim(deduped, top_k)
+        self._update_usage_history(final_results, info, user_name)
+        return final_results
+
     @timed
     def search(
         self,
@@ -72,9 +115,6 @@ def search(
         Returns:
             list[TextualMemoryItem]: List of matching memories.
         """
-        logger.info(
-            f"[SEARCH] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}"
-        )
         if not info:
             logger.warning(
                 "Please input 'info' when use tree.search so that "
@@ -84,23 +124,22 @@ def search(
         else:
             logger.debug(f"[SEARCH] Received info dict: {info}")
 
-        parsed_goal, query_embedding, context, query = self._parse_task(
-            query, info, mode, search_filter=search_filter, user_name=user_name
+        retrieved_results = self.retrieve(
+            query=query,
+            top_k=top_k,
+            info=info,
+            mode=mode,
+            memory_type=memory_type,
+            search_filter=search_filter,
+            user_name=user_name,
         )
-        results = self._retrieve_paths(
-            query,
-            parsed_goal,
-            query_embedding,
-            info,
-            top_k,
-            mode,
-            memory_type,
-            search_filter,
-            user_name,
+
+        final_results = self.post_retrieve(
+            retrieved_results=retrieved_results,
+            top_k=top_k,
+            user_name=user_name,
+            info=None,
         )
-        deduped = self._deduplicate_results(results)
-        final_results = self._sort_and_trim(deduped, top_k)
-        self._update_usage_history(final_results, info, user_name)
 
         logger.info(f"[SEARCH] Done. Total {len(final_results)} results.")
         res_results = ""