Skip to content
Merged
88 changes: 64 additions & 24 deletions src/memos/graph_dbs/nebular.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,19 @@ def _get_or_create_shared_client(cls, cfg: NebulaGraphDBConfig) -> tuple[str, "N
client = cls._CLIENT_CACHE.get(key)
if client is None:
# Connection setting

tmp_client = NebulaClient(
hosts=cfg.uri,
username=cfg.user,
password=cfg.password,
session_config=SessionConfig(graph=None),
session_pool_config=SessionPoolConfig(size=1, wait_timeout=3000),
)
try:
cls._ensure_space_exists(tmp_client, cfg)
finally:
tmp_client.close()

conn_conf: ConnectionConfig | None = getattr(cfg, "conn_config", None)
if conn_conf is None:
conn_conf = ConnectionConfig.from_defults(
Expand Down Expand Up @@ -318,6 +331,7 @@ def __init__(self, config: NebulaGraphDBConfig):
}
"""

assert config.use_multi_db is False, "Multi-DB MODE IS NOT SUPPORTED"
self.config = config
self.db_name = config.space
self.user_name = config.user_name
Expand Down Expand Up @@ -429,15 +443,21 @@ def remove_oldest_memory(self, memory_type: str, keep_latest: int) -> None:
if not self.config.use_multi_db and self.config.user_name:
optional_condition = f"AND n.user_name = '{self.config.user_name}'"

query = f"""
MATCH (n@Memory)
WHERE n.memory_type = '{memory_type}'
{optional_condition}
ORDER BY n.updated_at DESC
OFFSET {keep_latest}
DETACH DELETE n
"""
self.execute_query(query)
count = self.count_nodes(memory_type)

if count > keep_latest:
delete_query = f"""
MATCH (n@Memory)
WHERE n.memory_type = '{memory_type}'
{optional_condition}
ORDER BY n.updated_at DESC
OFFSET {keep_latest}
DETACH DELETE n
"""
try:
self.execute_query(delete_query)
except Exception as e:
logger.warning(f"Delete old mem error: {e}")

@timed
def add_node(self, id: str, memory: str, metadata: dict[str, Any]) -> None:
Expand Down Expand Up @@ -597,14 +617,19 @@ def get_memory_count(self, memory_type: str) -> int:
return -1

@timed
def count_nodes(self, scope: str) -> int:
query = f"""
MATCH (n@Memory)
WHERE n.memory_type = "{scope}"
"""
def count_nodes(self, scope: str | None = None) -> int:
query = "MATCH (n@Memory)"
conditions = []

if scope:
conditions.append(f'n.memory_type = "{scope}"')
if not self.config.use_multi_db and self.config.user_name:
user_name = self.config.user_name
query += f"\nAND n.user_name = '{user_name}'"
conditions.append(f"n.user_name = '{user_name}'")

if conditions:
query += "\nWHERE " + " AND ".join(conditions)

query += "\nRETURN count(n) AS count"

result = self.execute_query(query)
Expand Down Expand Up @@ -985,8 +1010,7 @@ def search_by_embedding(
dim = len(vector)
vector_str = ",".join(f"{float(x)}" for x in vector)
gql_vector = f"VECTOR<{dim}, FLOAT>([{vector_str}])"

where_clauses = []
where_clauses = [f"n.{self.dim_field} IS NOT NULL"]
if scope:
where_clauses.append(f'n.memory_type = "{scope}"')
if status:
Expand All @@ -1008,15 +1032,12 @@ def search_by_embedding(
where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

gql = f"""
MATCH (n@Memory)
let a = {gql_vector}
MATCH (n@Memory /*+ INDEX(idx_memory_user_name) */)
{where_clause}
ORDER BY inner_product(n.{self.dim_field}, {gql_vector}) DESC
APPROXIMATE
ORDER BY inner_product(n.{self.dim_field}, a) DESC
LIMIT {top_k}
OPTIONS {{ METRIC: IP, TYPE: IVF, NPROBE: 8 }}
RETURN n.id AS id, inner_product(n.{self.dim_field}, {gql_vector}) AS score
"""

RETURN n.id AS id, inner_product(n.{self.dim_field}, a) AS score"""
try:
result = self.execute_query(gql)
except Exception as e:
Expand Down Expand Up @@ -1471,6 +1492,25 @@ def merge_nodes(self, id1: str, id2: str) -> str:
"""
raise NotImplementedError

@classmethod
def _ensure_space_exists(cls, tmp_client, cfg):
    """Lightweight check to ensure target graph (space) exists.

    Lists the graphs known to the server and issues a CREATE GRAPH
    statement when the configured space is missing. Server-side
    failures are logged but never propagated to the caller.
    """
    target = getattr(cfg, "space", None)
    if not target:
        # Without a configured space name there is nothing to verify.
        logger.warning("[NebulaGraphDBSync] No `space` specified in cfg.")
        return

    try:
        rows = tmp_client.execute("SHOW GRAPHS;")
        known_graphs = set()
        for row in rows:
            known_graphs.add(row.values()[0].as_string())
        if target in known_graphs:
            logger.debug(f"Graph `{target}` already exists.")
        else:
            # Create before any session binds to the graph.
            tmp_client.execute(f"CREATE GRAPH IF NOT EXISTS `{target}` TYPED MemOSBgeM3Type;")
            logger.info(f"✅ Graph `{target}` created before session binding.")
    except Exception:
        logger.exception("[NebulaGraphDBSync] Failed to ensure space exists")

@timed
def _ensure_database_exists(self):
graph_type_name = "MemOSBgeM3Type"
Expand Down
4 changes: 2 additions & 2 deletions src/memos/memories/textual/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,10 @@ def load(self, dir: str) -> None:
except Exception as e:
logger.error(f"An error occurred while loading memories: {e}")

def dump(self, dir: str) -> None:
def dump(self, dir: str, include_embedding: bool = False) -> None:
"""Dump memories to os.path.join(dir, self.config.memory_filename)"""
try:
json_memories = self.graph_store.export_graph()
json_memories = self.graph_store.export_graph(include_embedding=include_embedding)

os.makedirs(dir, exist_ok=True)
memory_file = os.path.join(dir, self.config.memory_filename)
Expand Down
43 changes: 23 additions & 20 deletions src/memos/memories/textual/tree_text_memory/organize/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,30 +67,33 @@ def add(self, memories: list[TextualMemoryItem]) -> list[str]:
except Exception as e:
logger.exception("Memory processing error: ", exc_info=e)

try:
self.graph_store.remove_oldest_memory(
memory_type="WorkingMemory", keep_latest=self.memory_size["WorkingMemory"]
)
except Exception:
logger.warning(f"Remove WorkingMemory error: {traceback.format_exc()}")

try:
self.graph_store.remove_oldest_memory(
memory_type="LongTermMemory", keep_latest=self.memory_size["LongTermMemory"]
)
except Exception:
logger.warning(f"Remove LongTermMemory error: {traceback.format_exc()}")

try:
self.graph_store.remove_oldest_memory(
memory_type="UserMemory", keep_latest=self.memory_size["UserMemory"]
)
except Exception:
logger.warning(f"Remove UserMemory error: {traceback.format_exc()}")
# Only clean up if we're close to or over the limit
self._cleanup_memories_if_needed()

self._refresh_memory_size()
return added_ids

def _cleanup_memories_if_needed(self) -> None:
    """Trim each memory type only when it is close to or over its limit.

    Skipping types that are well under their cap avoids issuing a
    delete query against the graph store on every add() call.
    """
    cleanup_threshold = 0.8  # Clean up when 80% full

    for mem_type, cap in self.memory_size.items():
        count = self.current_memory_size.get(mem_type, 0)
        if count < int(cap * cleanup_threshold):
            # Below the trigger point -- skip this type entirely.
            continue
        try:
            self.graph_store.remove_oldest_memory(memory_type=mem_type, keep_latest=cap)
            logger.debug(f"Cleaned up {mem_type}: {count} -> {cap}")
        except Exception:
            # Best-effort cleanup: log and move on to the next type.
            logger.warning(f"Remove {mem_type} error: {traceback.format_exc()}")

def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None:
"""
Replace WorkingMemory
Expand Down