diff --git a/src/memos/graph_dbs/nebular.py b/src/memos/graph_dbs/nebular.py index 66ad894ad..45656b770 100644 --- a/src/memos/graph_dbs/nebular.py +++ b/src/memos/graph_dbs/nebular.py @@ -188,6 +188,19 @@ def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "N client = cls._CLIENT_CACHE.get(key) if client is None: # Connection setting + + tmp_client = NebulaClient( + hosts=cfg.uri, + username=cfg.user, + password=cfg.password, + session_config=SessionConfig(graph=None), + session_pool_config=SessionPoolConfig(size=1, wait_timeout=3000), + ) + try: + cls._ensure_space_exists(tmp_client, cfg) + finally: + tmp_client.close() + conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None) if conn_conf is None: conn_conf = ConnectionConfig.from_defults( @@ -318,6 +331,7 @@ def __init__(self, config: NebulaGraphDBConfig): } """ + assert config.use_multi_db is False, "Multi-DB MODE IS NOT SUPPORTED" self.config = config self.db_name = config.space self.user_name = config.user_name @@ -429,15 +443,21 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None: if not self.config.use_multi_db and self.config.user_name: optional_condition = f"AND n.user_name = '{self.config.user_name}'" - query = f""" - MATCH (n@Memory) - WHERE n.memory_type = '{memory_type}' - {optional_condition} - ORDER BY n.updated_at DESC - OFFSET {keep_latest} - DETACH DELETE n - """ - self.execute_query(query) + count = self.count_nodes(memory_type) + + if count > keep_latest: + delete_query = f""" + MATCH (n@Memory) + WHERE n.memory_type = '{memory_type}' + {optional_condition} + ORDER BY n.updated_at DESC + OFFSET {keep_latest} + DETACH DELETE n + """ + try: + self.execute_query(delete_query) + except Exception as e: + logger.warning(f"Delete old mem error: {e}") @timed def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None: @@ -597,14 +617,19 @@ def get_memory_count(self, memory_type: str) -> int: return -1 @timed - def count_nodes(self, scope: str) -> int: - query = f""" - MATCH (n@Memory) - WHERE n.memory_type = "{scope}" - """ + def count_nodes(self, scope: str | None = None) -> int: + query = "MATCH (n@Memory)" + conditions = [] + + if scope: + conditions.append(f'n.memory_type = "{scope}"') if not self.config.use_multi_db and self.config.user_name: user_name = self.config.user_name - query += f"\nAND n.user_name = '{user_name}'" + conditions.append(f"n.user_name = '{user_name}'") + + if conditions: + query += "\nWHERE " + " AND ".join(conditions) + query += "\nRETURN count(n) AS count" result = self.execute_query(query) @@ -985,8 +1010,7 @@ def search_by_embedding( dim = len(vector) vector_str = ",".join(f"{float(x)}" for x in vector) gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])" - - where_clauses = [] + where_clauses = [f"n.{self.dim_field} IS NOT NULL"] if scope: where_clauses.append(f'n.memory_type = "{scope}"') if status: @@ -1008,15 +1032,12 @@ def search_by_embedding( where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" gql = f""" - MATCH (n@Memory) + let a = {gql_vector} + MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */) {where_clause} - ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC - APPROXIMATE + ORDER BY inner_product(n.{self.dim_field}, a) DESC LIMIT {top_k} - OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }} - RETURN n.id AS id, inner_product(n.{self.dim_field}, {gql_vector}) AS score - """ - + RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score""" try: result = self.execute_query(gql) except Exception as e: @@ -1471,6 +1492,25 @@ def merge_nodes(self, id1: str, id2: str) -> str: """ raise NotImplementedError + @classmethod + def _ensure_space_exists(cls, tmp_client, cfg): + """Lightweight check to ensure target graph (space) exists.""" + db_name = getattr(cfg, "space", None) + if not db_name: + logger.warning("[NebulaGraphDBSync] No `space` specified in cfg.") + return + + try: + res = tmp_client.execute("SHOW GRAPHS;") + existing = {row.values()[0].as_string() for row in res} + if db_name not in existing: + tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{db_name}` TYPED MemOSBgeM3Type;") + logger.info(f"✅ Graph `{db_name}` created before session binding.") + else: + logger.debug(f"Graph `{db_name}` already exists.") + except Exception: + logger.exception("[NebulaGraphDBSync] Failed to ensure space exists") + @timed def _ensure_database_exists(self): graph_type_name = "MemOSBgeM3Type" diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index f324f41c9..0048f4a59 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -326,10 +326,10 @@ def load(self, dir: str) -> None: except Exception as e: logger.error(f"An error occurred while loading memories: {e}") - def dump(self, dir: str) -> None: + def dump(self, dir: str, include_embedding: bool = False) -> None: """Dump memories to os.path.join(dir, self.config.memory_filename)""" try: - json_memories = self.graph_store.export_graph() + json_memories = self.graph_store.export_graph(include_embedding=include_embedding) os.makedirs(dir, exist_ok=True) memory_file = os.path.join(dir, self.config.memory_filename) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 5cc714806..b0224655c 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -67,30 +67,33 @@ def add(self, memories: list[TextualMemoryItem]) -> list[str]: except Exception as e: logger.exception("Memory processing error: ", exc_info=e) - try: - self.graph_store.remove_oldest_memory( - memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"] - ) - except Exception: - logger.warning(f"Remove WorkingMemory error: {traceback.format_exc()}") - - try: - self.graph_store.remove_oldest_memory( - memory_type="LongTermMemory", keep_latest=self.memory_size["LongTermMemory"] - ) - except Exception: - logger.warning(f"Remove LongTermMemory error: {traceback.format_exc()}") - - try: - self.graph_store.remove_oldest_memory( - memory_type="UserMemory", keep_latest=self.memory_size["UserMemory"] - ) - except Exception: - logger.warning(f"Remove UserMemory error: {traceback.format_exc()}") + # Only clean up if we're close to or over the limit + self._cleanup_memories_if_needed() self._refresh_memory_size() return added_ids + def _cleanup_memories_if_needed(self) -> None: + """ + Only clean up memories if we're close to or over the limit. + This reduces unnecessary database operations. + """ + cleanup_threshold = 0.8 # Clean up when 80% full + + for memory_type, limit in self.memory_size.items(): + current_count = self.current_memory_size.get(memory_type, 0) + threshold = int(limit * cleanup_threshold) + + # Only clean up if we're at or above the threshold + if current_count >= threshold: + try: + self.graph_store.remove_oldest_memory( + memory_type=memory_type, keep_latest=limit + ) + logger.debug(f"Cleaned up {memory_type}: {current_count} -> {limit}") + except Exception: + logger.warning(f"Remove {memory_type} error: {traceback.format_exc()}") + def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None: """ Replace WorkingMemory