121 changes: 115 additions & 6 deletions src/memos/graph_dbs/neo4j.py
@@ -812,6 +812,7 @@ def get_by_metadata(
user_name: str | None = None,
filter: dict | None = None,
knowledgebase_ids: list[str] | None = None,
user_name_flag: bool = True,
) -> list[str]:
"""
TODO:
@@ -876,11 +877,19 @@ def get_by_metadata(
raise ValueError(f"Unsupported operator: {op}")

# Build user_name filter with knowledgebase_ids support (OR relationship) using common method
user_name_conditions, user_name_params = self._build_user_name_and_kb_ids_conditions_cypher(
user_name=user_name,
knowledgebase_ids=knowledgebase_ids,
default_user_name=self.config.user_name,
node_alias="n",
user_name_conditions = []
user_name_params = {}
if user_name_flag:
user_name_conditions, user_name_params = (
self._build_user_name_and_kb_ids_conditions_cypher(
user_name=user_name,
knowledgebase_ids=knowledgebase_ids,
default_user_name=self.config.user_name,
node_alias="n",
)
)
print(
f"[get_by_metadata] user_name_conditions: {user_name_conditions},user_name_params: {user_name_params}"
)

# Add user_name WHERE clause
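The new `user_name_flag` parameter makes the user/knowledge-base scoping optional rather than always-on, which the `delete_node_by_prams` method added below relies on when it resolves a bare metadata filter into node IDs. A minimal sketch of the two call modes, assuming a `graph_db` instance and a field/op/value shape for the `filters` entries (both are illustrative, not taken from this diff):

```python
# Hypothetical usage sketch: the instance name and filter shapes are assumptions.

# Default mode: results are scoped to the given user (or knowledge bases).
scoped_ids = graph_db.get_by_metadata(
    filters=[{"field": "memory_type", "op": "=", "value": "WorkingMemory"}],
    user_name="alice",
)

# New mode: user_name_flag=False skips the scoping conditions entirely,
# so the metadata filter alone decides which nodes match.
unscoped_ids = graph_db.get_by_metadata(
    filters=[],
    filter={"file_id": "doc-123"},
    user_name_flag=False,
)
```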
@@ -1425,7 +1434,7 @@ def build_filter_condition(condition_dict: dict, param_counter: list) -> tuple[s
# Use datetime() function for date comparisons
if key in ("created_at", "updated_at") or key.endswith("_at"):
condition_parts.append(
f"{node_alias}.{key} {cypher_op} datetime(${param_name})"
f"datetime({node_alias}.{key}) {cypher_op} datetime(${param_name})"
)
else:
condition_parts.append(
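The one-line change above wraps the stored property in `datetime()` as well. This matters because Cypher's ordering comparisons between values of different types evaluate to `null`: if `created_at` is persisted as an ISO-8601 string, `n.created_at >= datetime($param)` compares a string against a temporal value and filters out every row. Casting both sides makes the range check behave as intended:

```python
# Illustrative Cypher fragments (not verbatim from the diff):
broken = "n.created_at >= datetime($from_date)"
# STRING >= DATETIME evaluates to null in Cypher, so the predicate never matches.

fixed = "datetime(n.created_at) >= datetime($from_date)"
# Both sides are temporal values, so the comparison is a real range check.
```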
@@ -1482,6 +1491,12 @@ def build_filter_condition(condition_dict: dict, param_counter: list) -> tuple[s
if condition_str:
filter_conditions.append(f"({condition_str})")
filter_params.update(params)
else:
# Handle simple dict without "and" or "or" (e.g., {"id": "xxx"})
condition_str, params = build_filter_condition(filter, param_counter)
if condition_str:
filter_conditions.append(condition_str)
filter_params.update(params)

return filter_conditions, filter_params
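The added `else` branch closes a gap in the filter parser: previously only dicts with a top-level "and" or "or" key produced conditions, so a bare filter such as `{"id": "xxx"}` fell through silently. Both shapes are now accepted; a sketch (key and operator names are illustrative):

```python
# Explicit boolean form, handled by the existing branches:
explicit_filter = {"and": [
    {"memory_type": "WorkingMemory"},
    {"created_at": {"gte": "2024-01-01T00:00:00"}},  # operator name assumed
]}

# Bare form, now routed through build_filter_condition by the new else branch:
simple_filter = {"id": "3f2a0c"}  # hypothetical node id
```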

@@ -1505,3 +1520,97 @@ def _parse_node(self, node_data: dict[str, Any]) -> dict[str, Any]:
break
node["sources"][idx] = json.loads(node["sources"][idx])
return {"id": node.pop("id"), "memory": node.pop("memory", ""), "metadata": node}

def delete_node_by_prams(
self,
memory_ids: list[str] | None = None,
file_ids: list[str] | None = None,
filter: dict | None = None,
) -> int:
"""
Delete nodes by memory_ids, file_ids, or filter.

Args:
memory_ids (list[str], optional): List of memory node IDs to delete.
file_ids (list[str], optional): List of file node IDs to delete.
filter (dict, optional): Filter dictionary to query matching nodes for deletion.

Returns:
int: Number of nodes deleted.
"""
# Collect all node IDs to delete
ids_to_delete = set()

# Add memory_ids if provided
if memory_ids and len(memory_ids) > 0:
ids_to_delete.update(memory_ids)

# Add file_ids if provided (treating them as node IDs)
if file_ids and len(file_ids) > 0:
ids_to_delete.update(file_ids)

# Query nodes by filter if provided
if filter:
# Use get_by_metadata with empty filters list and filter
filter_ids = self.get_by_metadata(
filters=[],
user_name=None,
filter=filter,
knowledgebase_ids=None,
user_name_flag=False,
)
ids_to_delete.update(filter_ids)

# If no IDs to delete, return 0
if not ids_to_delete:
logger.warning("[delete_node_by_prams] No nodes to delete")
return 0

# Convert to list for easier handling
ids_list = list(ids_to_delete)
logger.info(f"[delete_node_by_prams] Deleting {len(ids_list)} nodes: {ids_list}")

# Build WHERE condition for collected IDs (query n.id)
ids_where = "n.id IN $ids_to_delete"
params = {"ids_to_delete": ids_list}

# Calculate total count for logging
total_count = len(ids_list)
logger.info(
f"[delete_node_by_prams] Deleting nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
)
print(
f"[delete_node_by_prams] Deleting {total_count} nodes - memory_ids: {memory_ids}, file_ids: {file_ids}, filter: {filter}"
)

# First count matching nodes to get accurate count
count_query = f"MATCH (n:Memory) WHERE {ids_where} RETURN count(n) AS node_count"
logger.info(f"[delete_node_by_prams] count_query: {count_query}")
print(f"[delete_node_by_prams] count_query: {count_query}")

# Then delete nodes
delete_query = f"MATCH (n:Memory) WHERE {ids_where} DETACH DELETE n"
logger.info(f"[delete_node_by_prams] delete_query: {delete_query}")
print(f"[delete_node_by_prams] delete_query: {delete_query}")

deleted_count = 0
try:
with self.driver.session(database=self.db_name) as session:
# Count nodes before deletion
count_result = session.run(count_query, **params)
count_record = count_result.single()
expected_count = total_count
if count_record:
expected_count = count_record["node_count"] or total_count

# Delete nodes
session.run(delete_query, **params)
# Use the count from before deletion as the actual deleted count
deleted_count = expected_count

except Exception as e:
logger.error(f"[delete_node_by_prams] Failed to delete nodes: {e}", exc_info=True)
raise

logger.info(f"[delete_node_by_prams] Successfully deleted {deleted_count} nodes")
return deleted_count
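A minimal usage sketch of the new method (the `graph_db` instance and id values are hypothetical). The three selectors are OR-combined into one id set, and the filter path reuses `get_by_metadata` with `user_name_flag=False`, so filter-based deletion is not additionally scoped to a user or knowledge base:

```python
# Delete by explicit memory node ids:
deleted = graph_db.delete_node_by_prams(memory_ids=["mem-1", "mem-2"])

# Delete every node matching a metadata filter (unscoped by user_name):
deleted = graph_db.delete_node_by_prams(filter={"file_id": "doc-123"})

# Selectors can be combined; the union of matching ids is DETACH DELETEd once:
deleted = graph_db.delete_node_by_prams(
    memory_ids=["mem-3"],
    file_ids=["file-9"],   # file ids are treated as node ids here
    filter={"memory_type": "WorkingMemory"},
)
```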