From c7510a570415424da29437f16e6ed791aa555209 Mon Sep 17 00:00:00 2001 From: Wenqiang Wei Date: Fri, 28 Nov 2025 14:34:12 +0800 Subject: [PATCH 1/4] add filter for search_memories --- src/memos/api/product_models.py | 2 +- .../mem_scheduler/optimized_scheduler.py | 7 +- .../textual/prefer_text_memory/retrievers.py | 9 +- src/memos/memories/textual/preference.py | 5 +- .../memories/textual/simple_preference.py | 4 +- src/memos/memories/textual/tree.py | 3 +- .../tree_text_memory/retrieve/recall.py | 22 ++- .../tree_text_memory/retrieve/searcher.py | 19 +- src/memos/multi_mem_cube/single_cube.py | 15 +- src/memos/reranker/http_bge.py | 8 +- src/memos/vec_dbs/milvus.py | 171 ++++++++++++++++-- 11 files changed, 221 insertions(+), 44 deletions(-) diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 5aa617d6e..8da843683 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -469,7 +469,7 @@ class APIADDRequest(BaseRequest): ), ) - info: dict[str, str] | None = Field( + info: dict[str, Any] | None = Field( None, description=( "Additional metadata for the add request. " diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index 0e64ea9a0..e25c7cb1c 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -138,7 +138,8 @@ def mix_search_memories( target_session_id = search_req.session_id if not target_session_id: target_session_id = "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + search_priority = {"session_id": search_req.session_id} if search_req.session_id else None + search_filter = search_req.filter # Rerank Memories - reranker expects TextualMemoryItem objects @@ -155,6 +156,7 @@ def mix_search_memories( mode=SearchMode.FAST, manual_close_internet=not search_req.internet_search, search_filter=search_filter, + search_priority=search_priority, info=info, ) @@ -178,7 +180,7 @@ def mix_search_memories( query=search_req.query, # Use search_req.query instead of undefined query graph_results=history_memories, # Pass TextualMemoryItem objects directly top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k - search_filter=search_filter, + search_priority=search_priority, ) logger.info(f"Reranked {len(sorted_history_memories)} history memories.") processed_hist_mem = self.searcher.post_retrieve( @@ -234,6 +236,7 @@ def mix_search_memories( mode=SearchMode.FAST, memory_type="All", search_filter=search_filter, + search_priority=search_priority, info=info, ) else: diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 534f5d678..9e33ce587 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -17,7 +17,7 @@ def __init__(self, llm_provider=None, embedder=None, reranker=None, vector_db=No @abstractmethod def retrieve( - self, query: str, top_k: int, info: dict[str, Any] | None = None + self, query: str, top_k: int, info: dict[str, Any] | None = None, search_filter: dict[str, Any] | None = None ) -> list[TextualMemoryItem]: """Retrieve memories from the retriever.""" @@ -76,7 +76,7 @@ def _original_text_reranker( return prefs_mem def retrieve( - self, query: str, top_k: int, info: dict[str, Any] | None = None + self, query: str, top_k: int, info: dict[str, Any] | None = None, 
search_filter: dict[str, Any] | None = None ) -> list[TextualMemoryItem]: """Retrieve memories from the naive retriever.""" # TODO: un-support rewrite query and session filter now @@ -84,6 +84,7 @@ def retrieve( info = info.copy() # Create a copy to avoid modifying the original info.pop("chat_history", None) info.pop("session_id", None) + search_filter = {"and": [info, search_filter]} query_embeddings = self.embedder.embed([query]) # Pass as list to get list of embeddings query_embedding = query_embeddings[0] # Get the first (and only) embedding @@ -96,7 +97,7 @@ def retrieve( query, "explicit_preference", top_k * 2, - info, + search_filter, ) future_implicit = executor.submit( self.vector_db.search, @@ -104,7 +105,7 @@ def retrieve( query, "implicit_preference", top_k * 2, - info, + search_filter, ) # Wait for all results diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index 6e196e23a..fc92af063 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -76,7 +76,7 @@ def get_memory( """ return self.extractor.extract(messages, type, info) - def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: + def search(self, query: str, top_k: int, info=None, search_filter=None, **kwargs) -> list[TextualMemoryItem]: """Search for memories based on a query. Args: query (str): The query to search for. @@ -85,7 +85,8 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem Returns: list[TextualMemoryItem]: List of matching memories. """ - return self.retriever.retrieve(query, top_k, info) + print(f"search_filter for preference memory: {search_filter}") + return self.retriever.retrieve(query, top_k, info, search_filter) def load(self, dir: str) -> None: """Load memories from the specified directory. diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py index 29f30d384..496158d04 100644 --- a/src/memos/memories/textual/simple_preference.py +++ b/src/memos/memories/textual/simple_preference.py @@ -50,7 +50,7 @@ def get_memory( """ return self.extractor.extract(messages, type, info) - def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMemoryItem]: + def search(self, query: str, top_k: int, info=None, search_filter=None, **kwargs) -> list[TextualMemoryItem]: """Search for memories based on a query. Args: query (str): The query to search for. @@ -59,7 +59,7 @@ def search(self, query: str, top_k: int, info=None, **kwargs) -> list[TextualMem Returns: list[TextualMemoryItem]: List of matching memories. """ - return self.retriever.retrieve(query, top_k, info) + return self.retriever.retrieve(query, top_k, info, search_filter) def add(self, memories: list[TextualMemoryItem | dict[str, Any]]) -> list[str]: """Add memories. 
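
For context on the retriever change above: NaiveRetriever.retrieve() now wraps the caller's info dict and the incoming search_filter into one composite filter before calling vector_db.search(). The snippet below only illustrates that combined shape; the values (u_123, food, travel) are made up for illustration and do not appear anywhere in this series.

    # Sketch of the composite filter built in NaiveRetriever.retrieve() (hypothetical values).
    info = {"user_id": "u_123"}                 # after chat_history / session_id are popped
    search_filter = {"tags": {"in": ["food", "travel"]}}
    combined = {"and": [info, search_filter]}
    # -> {"and": [{"user_id": "u_123"}, {"tags": {"in": ["food", "travel"]}}]}
    # When search_filter is None, the "and" handler on the Milvus side skips the None
    # branch, so the behaviour degrades to the plain info-based filter.
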
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index df5e05a1f..5974a5fcf 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -162,6 +162,7 @@ def search( mode: str = "fast", memory_type: str = "All", manual_close_internet: bool = True, + search_priority: dict | None = None, search_filter: dict | None = None, user_name: str | None = None, ) -> list[TextualMemoryItem]: @@ -209,7 +210,7 @@ def search( manual_close_internet=manual_close_internet, ) return searcher.search( - query, top_k, info, mode, memory_type, search_filter, user_name=user_name + query, top_k, info, mode, memory_type, search_filter, search_priority, user_name=user_name ) def get_relevant_subgraph( diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 375048900..527df1dd2 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -38,6 +38,7 @@ def retrieve( memory_scope: str, query_embedding: list[list[float]] | None = None, search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, id_filter: dict | None = None, use_fast_graph: bool = False, @@ -62,9 +63,9 @@ def retrieve( raise ValueError(f"Unsupported memory scope: {memory_scope}") if memory_scope == "WorkingMemory": - # For working memory, retrieve all entries (no filtering) + # For working memory, retrieve all entries (no session-oriented filtering) working_memories = self.graph_store.get_all_memory_items( - scope="WorkingMemory", include_embedding=False, user_name=user_name + scope="WorkingMemory", include_embedding=False, user_name=user_name, filter=search_filter ) return [TextualMemoryItem.from_dict(record) for record in working_memories[:top_k]] @@ -84,6 +85,7 @@ def retrieve( memory_scope, top_k, search_filter=search_filter, + search_priority=search_priority, user_name=user_name, ) if self.use_bm25: @@ -274,6 +276,7 @@ def _vector_recall( status: str = "activated", cube_name: str | None = None, search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, ) -> list[TextualMemoryItem]: """ @@ -283,7 +286,7 @@ def _vector_recall( if not query_embedding: return [] - def search_single(vec, filt=None): + def search_single(vec, search_priority=None, search_filter=None): return ( self.graph_store.search_by_embedding( vector=vec, @@ -291,31 +294,32 @@ def search_single(vec, filt=None): status=status, scope=memory_scope, cube_name=cube_name, - search_filter=filt, + search_filter=search_priority, + filter=search_filter, user_name=user_name, ) or [] ) def search_path_a(): - """Path A: search without filter""" + """Path A: search without priority""" path_a_hits = [] with ContextThreadPoolExecutor() as executor: futures = [ - executor.submit(search_single, vec, None) for vec in query_embedding[:max_num] + executor.submit(search_single, vec, None, search_filter) for vec in query_embedding[:max_num] ] for f in concurrent.futures.as_completed(futures): path_a_hits.extend(f.result() or []) return path_a_hits def search_path_b(): - """Path B: search with filter""" - if not search_filter: + """Path B: search with priority""" + if not search_priority: return [] path_b_hits = [] with ContextThreadPoolExecutor() as executor: futures = [ - executor.submit(search_single, vec, search_filter) + executor.submit(search_single, vec, search_priority, 
search_filter) for vec in query_embedding[:max_num] ] for f in concurrent.futures.as_completed(futures): diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 26ae1a723..cd04a4deb 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -69,6 +69,7 @@ def retrieve( mode="fast", memory_type="All", search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, **kwargs, ) -> list[tuple[TextualMemoryItem, float]]: @@ -76,7 +77,7 @@ def retrieve( f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}" ) parsed_goal, query_embedding, context, query = self._parse_task( - query, info, mode, search_filter=search_filter, user_name=user_name + query, info, mode, search_filter=search_filter, search_priority=search_priority, user_name=user_name ) results = self._retrieve_paths( query, @@ -87,6 +88,7 @@ def retrieve( mode, memory_type, search_filter, + search_priority, user_name, ) return results @@ -112,6 +114,7 @@ def search( mode="fast", memory_type="All", search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, ) -> list[TextualMemoryItem]: """ @@ -128,6 +131,7 @@ def search( memory_type (str): Type restriction for search. ['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory'] search_filter (dict, optional): Optional metadata filters for search results. + search_priority (dict, optional): Optional metadata priority for search results. Returns: list[TextualMemoryItem]: List of matching memories. """ @@ -147,6 +151,7 @@ def search( mode=mode, memory_type=memory_type, search_filter=search_filter, + search_priority=search_priority, user_name=user_name, ) @@ -174,6 +179,7 @@ def _parse_task( mode, top_k=5, search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, ): """Parse user query, do embedding search and create context""" @@ -192,7 +198,8 @@ def _parse_task( query_embedding, top_k=top_k, status="activated", - search_filter=search_filter, + search_filter=search_priority, + filter=search_filter, user_name=user_name, ) ] @@ -244,6 +251,7 @@ def _retrieve_paths( mode, memory_type, search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, ): """Run A/B/C retrieval paths in parallel""" @@ -264,6 +272,7 @@ def _retrieve_paths( top_k, memory_type, search_filter, + search_priority, user_name, id_filter, ) @@ -277,6 +286,7 @@ def _retrieve_paths( top_k, memory_type, search_filter, + search_priority, user_name, id_filter, mode=mode, @@ -313,6 +323,7 @@ def _retrieve_from_working_memory( top_k, memory_type, search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, id_filter: dict | None = None, ): @@ -326,6 +337,7 @@ def _retrieve_from_working_memory( top_k=top_k, memory_scope="WorkingMemory", search_filter=search_filter, + search_priority=search_priority, user_name=user_name, id_filter=id_filter, use_fast_graph=self.use_fast_graph, @@ -349,6 +361,7 @@ def _retrieve_from_long_term_and_user( top_k, memory_type, search_filter: dict | None = None, + search_priority: dict | None = None, user_name: str | None = None, id_filter: dict | None = None, mode: str = "fast", @@ -378,6 +391,7 @@ def _retrieve_from_long_term_and_user( top_k=top_k * 2, 
memory_scope="LongTermMemory", search_filter=search_filter, + search_priority=search_priority, user_name=user_name, id_filter=id_filter, use_fast_graph=self.use_fast_graph, @@ -393,6 +407,7 @@ def _retrieve_from_long_term_and_user( top_k=top_k * 2, memory_scope="UserMemory", search_filter=search_filter, + search_priority=search_priority, user_name=user_name, id_filter=id_filter, use_fast_graph=self.use_fast_graph, diff --git a/src/memos/multi_mem_cube/single_cube.py b/src/memos/multi_mem_cube/single_cube.py index 92ad1a3c9..9b30d51e1 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -233,7 +233,8 @@ def _fine_search( return self._agentic_search(search_req=search_req, user_context=user_context) target_session_id = search_req.session_id or "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + search_priority = {"session_id": search_req.session_id} if search_req.session_id else None + search_filter = search_req.filter info = { "user_id": search_req.user_id, @@ -250,6 +251,7 @@ def _fine_search( manual_close_internet=not search_req.internet_search, moscube=search_req.moscube, search_filter=search_filter, + search_priority=search_priority, info=info, ) @@ -285,6 +287,7 @@ def _fine_search( top_k=retrieval_size, mode=SearchMode.FAST, memory_type="All", + search_priority=search_priority, search_filter=search_filter, info=info, ) @@ -320,7 +323,8 @@ def _search_pref( """ if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": return [] - + print(f"search_req.filter for preference memory: {search_req.filter}") + print(f"type of pref_mem: {type(self.naive_mem_cube.pref_mem)}") try: results = self.naive_mem_cube.pref_mem.search( query=search_req.query, @@ -330,6 +334,7 @@ def _search_pref( "session_id": search_req.session_id, "chat_history": search_req.chat_history, }, + search_filter=search_req.filter, ) return [format_memory_item(data) for data in results] except Exception as e: @@ -352,8 +357,9 @@ def _fast_search( List of search results """ target_session_id = search_req.session_id or "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - + search_priority = {"session_id": search_req.session_id} if search_req.session_id else None + search_filter = search_req.filter or None + print(f"type of text_mem: {type(self.naive_mem_cube.text_mem)}") search_results = self.naive_mem_cube.text_mem.search( query=search_req.query, user_name=user_context.mem_cube_id, @@ -361,6 +367,7 @@ def _fast_search( mode=SearchMode.FAST, manual_close_internet=not search_req.internet_search, search_filter=search_filter, + search_priority=search_priority, info={ "user_id": search_req.user_id, "session_id": target_session_id, diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py index db5a51fc2..764b53032 100644 --- a/src/memos/reranker/http_bge.py +++ b/src/memos/reranker/http_bge.py @@ -125,7 +125,7 @@ def rerank( query: str, graph_results: list[TextualMemoryItem], top_k: int, - search_filter: dict | None = None, + search_priority: dict | None = None, **kwargs, ) -> list[tuple[TextualMemoryItem, float]]: """ @@ -140,7 +140,7 @@ def rerank( `.memory` str field; non-strings are ignored. top_k : int Return at most this many items. - search_filter : dict | None + search_priority : dict | None, optional Currently unused. Present to keep signature compatible. 
Returns @@ -194,7 +194,7 @@ def rerank( raw_score = float(r.get("relevance_score", r.get("score", 0.0))) item = graph_results[idx] # generic boost - score = self._apply_boost_generic(item, raw_score, search_filter) + score = self._apply_boost_generic(item, raw_score, search_priority) scored_items.append((item, score)) scored_items.sort(key=lambda x: x[1], reverse=True) @@ -213,7 +213,7 @@ def rerank( scored_items = [] for item, raw_score in zip(graph_results, score_list, strict=False): - score = self._apply_boost_generic(item, raw_score, search_filter) + score = self._apply_boost_generic(item, raw_score, search_priority) scored_items.append((item, score)) scored_items.sort(key=lambda x: x[1], reverse=True) diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index eafee2633..4b225bf18 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -229,6 +229,7 @@ def search( List of search results with distance scores and payloads. """ # Convert filter to Milvus expression + print(f"filter for milvus: {filter}") expr = self._dict_to_expr(filter) if filter else "" search_func_map = { @@ -267,26 +268,170 @@ def search( return items def _dict_to_expr(self, filter_dict: dict[str, Any]) -> str: - """Convert a dictionary filter to a Milvus expression string.""" + """Convert a dictionary filter to a Milvus expression string. + + Supports complex query syntax with logical operators, comparison operators, + arithmetic operators, array operators, and string pattern matching. + + Args: + filter_dict: Dictionary containing filter conditions + + Returns: + Milvus expression string + """ if not filter_dict: return "" + return self._build_expression(filter_dict) + + def _build_expression(self, condition: Any) -> str: + """Build expression from condition dict or value.""" + if isinstance(condition, dict): + # Handle logical operators + if "and" in condition: + return self._handle_logical_and(condition["and"]) + elif "or" in condition: + return self._handle_logical_or(condition["or"]) + elif "not" in condition: + return self._handle_logical_not(condition["not"]) + else: + # Handle field conditions + return self._handle_field_conditions(condition) + else: + # Simple value comparison + return f"{condition}" + + def _handle_logical_and(self, conditions: list) -> str: + """Handle AND logical operator.""" + if not conditions: + return "" + expressions = [self._build_expression(cond) for cond in conditions if cond is not None] + expressions = [expr for expr in expressions if expr] + if not expressions: + return "" + return f"({' and '.join(expressions)})" + + def _handle_logical_or(self, conditions: list) -> str: + """Handle OR logical operator.""" + if not conditions: + return "" + expressions = [self._build_expression(cond) for cond in conditions if cond is not None] + expressions = [expr for expr in expressions if expr] + if not expressions: + return "" + return f"({' or '.join(expressions)})" + + def _handle_logical_not(self, condition: Any) -> str: + """Handle NOT logical operator.""" + expr = self._build_expression(condition) + if not expr: + return "" + return f"(not {expr})" + + def _handle_field_conditions(self, condition_dict: dict[str, Any]) -> str: + """Handle field-specific conditions.""" conditions = [] - for field, value in filter_dict.items(): - # Skip None values as they cause Milvus query syntax errors + + for field, value in condition_dict.items(): if value is None: continue - # For JSON fields, we need to use payload["field"] syntax - elif isinstance(value, 
str): - conditions.append(f"payload['{field}'] == '{value}'") - elif isinstance(value, list) and len(value) == 0: - # Skip empty lists as they cause Milvus query syntax errors - continue - elif isinstance(value, list) and len(value) > 0: - conditions.append(f"payload['{field}'] in {value}") - else: - conditions.append(f"payload['{field}'] == '{value}'") + + field_expr = self._build_field_expression(field, value) + if field_expr: + conditions.append(field_expr) + + if not conditions: + return "" return " and ".join(conditions) + + def _build_field_expression(self, field: str, value: Any) -> str: + """Build expression for a single field.""" + # Handle comparison operators + if isinstance(value, dict): + if len(value) == 1: + op, operand = next(iter(value.items())) + op_lower = op.lower() + + if op_lower == "in": + return self._handle_in_operator(field, operand) + elif op_lower == "contains": + return self._handle_contains_operator(field, operand, case_sensitive=True) + elif op_lower == "icontains": + return self._handle_contains_operator(field, operand, case_sensitive=False) + elif op_lower == "like": + return self._handle_like_operator(field, operand) + elif op_lower in ["gte", "lte", "gt", "lt", "ne"]: + return self._handle_comparison_operator(field, op_lower, operand) + else: + # Unknown operator, treat as equality + return f"payload['{field}'] == {self._format_value(operand)}" + else: + # Multiple operators, handle each one + sub_conditions = [] + for op, operand in value.items(): + op_lower = op.lower() + if op_lower in ["gte", "lte", "gt", "lt", "ne", "in", "contains", "icontains", "like"]: + sub_expr = self._build_field_expression(field, {op: operand}) + if sub_expr: + sub_conditions.append(sub_expr) + + if sub_conditions: + return f"({' and '.join(sub_conditions)})" + return "" + else: + # Simple equality + return f"payload['{field}'] == {self._format_value(value)}" + + def _handle_in_operator(self, field: str, values: list) -> str: + """Handle IN operator for arrays.""" + if not isinstance(values, list) or not values: + return "" + + formatted_values = [self._format_value(v) for v in values] + return f"payload['{field}'] in [{', '.join(formatted_values)}]" + + def _handle_contains_operator(self, field: str, value: Any, case_sensitive: bool = True) -> str: + """Handle CONTAINS/ICONTAINS operator.""" + formatted_value = self._format_value(value) + if case_sensitive: + return f"json_contains(payload['{field}'], {formatted_value})" + else: + # For case-insensitive contains, we need to use LIKE with lower case + return f"(not json_contains(payload['{field}'], {formatted_value}))" + + def _handle_like_operator(self, field: str, pattern: str) -> str: + """Handle LIKE operator for string pattern matching.""" + # Convert SQL-like pattern to Milvus-like pattern + return f"payload['{field}'] like '{pattern}'" + + def _handle_comparison_operator(self, field: str, operator: str, value: Any) -> str: + """Handle comparison operators (gte, lte, gt, lt, ne).""" + milvus_op = { + "gte": ">=", + "lte": "<=", + "gt": ">", + "lt": "<", + "ne": "!=" + }.get(operator, "==") + + formatted_value = self._format_value(value) + return f"payload['{field}'] {milvus_op} {formatted_value}" + + def _format_value(self, value: Any) -> str: + """Format value for Milvus expression.""" + if isinstance(value, str): + return f"'{value}'" + elif isinstance(value, (int, float)): + return str(value) + elif isinstance(value, bool): + return str(value).lower() + elif isinstance(value, list): + formatted_items = 
[self._format_value(item) for item in value] + return f"[{', '.join(formatted_items)}]" + elif value is None: + return "null" + else: + return f"'{str(value)}'" def _get_metric_type(self) -> str: """Get the metric type for search.""" From 3fd338227945bd9f67cf05a675f22124adbbd72a Mon Sep 17 00:00:00 2001 From: Wenqiang Wei Date: Fri, 28 Nov 2025 14:43:10 +0800 Subject: [PATCH 2/4] fix: data type incorrect --- src/memos/vec_dbs/milvus.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 4b225bf18..90ad4457d 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -421,7 +421,7 @@ def _format_value(self, value: Any) -> str: """Format value for Milvus expression.""" if isinstance(value, str): return f"'{value}'" - elif isinstance(value, (int, float)): + elif isinstance(value, int | float): return str(value) elif isinstance(value, bool): return str(value).lower() From d54c275815d182125c327550038517be0a5bbdb9 Mon Sep 17 00:00:00 2001 From: Wenqiang Wei Date: Fri, 28 Nov 2025 14:54:41 +0800 Subject: [PATCH 3/4] fix --- .../textual/prefer_text_memory/retrievers.py | 12 +++- src/memos/memories/textual/preference.py | 4 +- .../memories/textual/simple_preference.py | 4 +- src/memos/memories/textual/tree.py | 9 ++- .../tree_text_memory/retrieve/recall.py | 8 ++- .../tree_text_memory/retrieve/searcher.py | 7 ++- src/memos/vec_dbs/milvus.py | 58 ++++++++++--------- 7 files changed, 67 insertions(+), 35 deletions(-) diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 9e33ce587..6352d5840 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -17,7 +17,11 @@ def __init__(self, llm_provider=None, embedder=None, reranker=None, vector_db=No @abstractmethod def retrieve( - self, query: str, top_k: int, info: dict[str, Any] | None = None, search_filter: dict[str, Any] | None = None + self, + query: str, + top_k: int, + info: dict[str, Any] | None = None, + search_filter: dict[str, Any] | None = None, ) -> list[TextualMemoryItem]: """Retrieve memories from the retriever.""" @@ -76,7 +80,11 @@ def _original_text_reranker( return prefs_mem def retrieve( - self, query: str, top_k: int, info: dict[str, Any] | None = None, search_filter: dict[str, Any] | None = None + self, + query: str, + top_k: int, + info: dict[str, Any] | None = None, + search_filter: dict[str, Any] | None = None, ) -> list[TextualMemoryItem]: """Retrieve memories from the naive retriever.""" # TODO: un-support rewrite query and session filter now diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index fc92af063..c39d7d14c 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -76,7 +76,9 @@ def get_memory( """ return self.extractor.extract(messages, type, info) - def search(self, query: str, top_k: int, info=None, search_filter=None, **kwargs) -> list[TextualMemoryItem]: + def search( + self, query: str, top_k: int, info=None, search_filter=None, **kwargs + ) -> list[TextualMemoryItem]: """Search for memories based on a query. Args: query (str): The query to search for. 
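
The filter grammar accepted by MilvusVecDB._dict_to_expr (added in patch 1/4 and reformatted below) is easiest to follow with a worked example. The function here is a condensed, self-contained sketch that mirrors the and/or/not, equality, in and comparison branches of that translation; it deliberately omits the contains/icontains/like branches and the empty-value guards, and the field names in the sample filter are invented for illustration, not fields the patch defines.

    # Condensed sketch of the dict -> Milvus expression translation (illustrative only).
    from typing import Any

    _OPS = {"gte": ">=", "lte": "<=", "gt": ">", "lt": "<", "ne": "!="}

    def _fmt(v: Any) -> str:
        if isinstance(v, str):
            return f"'{v}'"
        if isinstance(v, list):
            return "[" + ", ".join(_fmt(x) for x in v) + "]"
        return str(v)

    def dict_to_expr(cond: Any) -> str:
        if not isinstance(cond, dict):
            return str(cond)
        if "and" in cond or "or" in cond:
            op = "and" if "and" in cond else "or"
            parts = [dict_to_expr(c) for c in cond[op] if c is not None]
            return "(" + f" {op} ".join(p for p in parts if p) + ")"
        if "not" in cond:
            return f"(not {dict_to_expr(cond['not'])})"
        exprs = []
        for field, value in cond.items():
            if value is None:
                continue
            if isinstance(value, dict):
                # operator form, e.g. {"gte": 0.5} or {"in": [...]}
                for o, operand in value.items():
                    if o == "in":
                        exprs.append(f"payload['{field}'] in {_fmt(operand)}")
                    else:
                        exprs.append(f"payload['{field}'] {_OPS.get(o, '==')} {_fmt(operand)}")
            else:
                # simple equality
                exprs.append(f"payload['{field}'] == {_fmt(value)}")
        return " and ".join(exprs)

    example = {"and": [{"user_id": "u_123"},
                       {"or": [{"tags": {"in": ["food", "travel"]}},
                               {"score": {"gte": 0.5}}]}]}
    print(dict_to_expr(example))
    # (payload['user_id'] == 'u_123' and (payload['tags'] in ['food', 'travel'] or payload['score'] >= 0.5))
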
diff --git a/src/memos/memories/textual/simple_preference.py b/src/memos/memories/textual/simple_preference.py index 496158d04..1f02132bb 100644 --- a/src/memos/memories/textual/simple_preference.py +++ b/src/memos/memories/textual/simple_preference.py @@ -50,7 +50,9 @@ def get_memory( """ return self.extractor.extract(messages, type, info) - def search(self, query: str, top_k: int, info=None, search_filter=None, **kwargs) -> list[TextualMemoryItem]: + def search( + self, query: str, top_k: int, info=None, search_filter=None, **kwargs + ) -> list[TextualMemoryItem]: """Search for memories based on a query. Args: query (str): The query to search for. diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 5974a5fcf..2a109bf71 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -210,7 +210,14 @@ def search( manual_close_internet=manual_close_internet, ) return searcher.search( - query, top_k, info, mode, memory_type, search_filter, search_priority, user_name=user_name + query, + top_k, + info, + mode, + memory_type, + search_filter, + search_priority, + user_name=user_name, ) def get_relevant_subgraph( diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 527df1dd2..7fa8a87be 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -65,7 +65,10 @@ def retrieve( if memory_scope == "WorkingMemory": # For working memory, retrieve all entries (no session-oriented filtering) working_memories = self.graph_store.get_all_memory_items( - scope="WorkingMemory", include_embedding=False, user_name=user_name, filter=search_filter + scope="WorkingMemory", + include_embedding=False, + user_name=user_name, + filter=search_filter, ) return [TextualMemoryItem.from_dict(record) for record in working_memories[:top_k]] @@ -306,7 +309,8 @@ def search_path_a(): path_a_hits = [] with ContextThreadPoolExecutor() as executor: futures = [ - executor.submit(search_single, vec, None, search_filter) for vec in query_embedding[:max_num] + executor.submit(search_single, vec, None, search_filter) + for vec in query_embedding[:max_num] ] for f in concurrent.futures.as_completed(futures): path_a_hits.extend(f.result() or []) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index cd04a4deb..976be6a54 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -77,7 +77,12 @@ def retrieve( f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}" ) parsed_goal, query_embedding, context, query = self._parse_task( - query, info, mode, search_filter=search_filter, search_priority=search_priority, user_name=user_name + query, + info, + mode, + search_filter=search_filter, + search_priority=search_priority, + user_name=user_name, ) results = self._retrieve_paths( query, diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index 90ad4457d..2181961d2 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -269,13 +269,13 @@ def search( def _dict_to_expr(self, filter_dict: dict[str, Any]) -> str: """Convert a dictionary filter to a Milvus expression string. 
- + Supports complex query syntax with logical operators, comparison operators, arithmetic operators, array operators, and string pattern matching. - + Args: filter_dict: Dictionary containing filter conditions - + Returns: Milvus expression string """ @@ -300,7 +300,7 @@ def _build_expression(self, condition: Any) -> str: else: # Simple value comparison return f"{condition}" - + def _handle_logical_and(self, conditions: list) -> str: """Handle AND logical operator.""" if not conditions: @@ -310,7 +310,7 @@ def _handle_logical_and(self, conditions: list) -> str: if not expressions: return "" return f"({' and '.join(expressions)})" - + def _handle_logical_or(self, conditions: list) -> str: """Handle OR logical operator.""" if not conditions: @@ -320,14 +320,14 @@ def _handle_logical_or(self, conditions: list) -> str: if not expressions: return "" return f"({' or '.join(expressions)})" - + def _handle_logical_not(self, condition: Any) -> str: """Handle NOT logical operator.""" expr = self._build_expression(condition) if not expr: return "" return f"(not {expr})" - + def _handle_field_conditions(self, condition_dict: dict[str, Any]) -> str: """Handle field-specific conditions.""" conditions = [] @@ -339,11 +339,11 @@ def _handle_field_conditions(self, condition_dict: dict[str, Any]) -> str: field_expr = self._build_field_expression(field, value) if field_expr: conditions.append(field_expr) - + if not conditions: return "" return " and ".join(conditions) - + def _build_field_expression(self, field: str, value: Any) -> str: """Build expression for a single field.""" # Handle comparison operators @@ -351,7 +351,7 @@ def _build_field_expression(self, field: str, value: Any) -> str: if len(value) == 1: op, operand = next(iter(value.items())) op_lower = op.lower() - + if op_lower == "in": return self._handle_in_operator(field, operand) elif op_lower == "contains": @@ -370,26 +370,36 @@ def _build_field_expression(self, field: str, value: Any) -> str: sub_conditions = [] for op, operand in value.items(): op_lower = op.lower() - if op_lower in ["gte", "lte", "gt", "lt", "ne", "in", "contains", "icontains", "like"]: + if op_lower in [ + "gte", + "lte", + "gt", + "lt", + "ne", + "in", + "contains", + "icontains", + "like", + ]: sub_expr = self._build_field_expression(field, {op: operand}) if sub_expr: sub_conditions.append(sub_expr) - + if sub_conditions: return f"({' and '.join(sub_conditions)})" return "" else: # Simple equality return f"payload['{field}'] == {self._format_value(value)}" - + def _handle_in_operator(self, field: str, values: list) -> str: """Handle IN operator for arrays.""" if not isinstance(values, list) or not values: return "" - + formatted_values = [self._format_value(v) for v in values] return f"payload['{field}'] in [{', '.join(formatted_values)}]" - + def _handle_contains_operator(self, field: str, value: Any, case_sensitive: bool = True) -> str: """Handle CONTAINS/ICONTAINS operator.""" formatted_value = self._format_value(value) @@ -398,25 +408,19 @@ def _handle_contains_operator(self, field: str, value: Any, case_sensitive: bool else: # For case-insensitive contains, we need to use LIKE with lower case return f"(not json_contains(payload['{field}'], {formatted_value}))" - + def _handle_like_operator(self, field: str, pattern: str) -> str: """Handle LIKE operator for string pattern matching.""" # Convert SQL-like pattern to Milvus-like pattern return f"payload['{field}'] like '{pattern}'" - + def _handle_comparison_operator(self, field: str, operator: str, value: Any) -> 
str: """Handle comparison operators (gte, lte, gt, lt, ne).""" - milvus_op = { - "gte": ">=", - "lte": "<=", - "gt": ">", - "lt": "<", - "ne": "!=" - }.get(operator, "==") - + milvus_op = {"gte": ">=", "lte": "<=", "gt": ">", "lt": "<", "ne": "!="}.get(operator, "==") + formatted_value = self._format_value(value) return f"payload['{field}'] {milvus_op} {formatted_value}" - + def _format_value(self, value: Any) -> str: """Format value for Milvus expression.""" if isinstance(value, str): @@ -431,7 +435,7 @@ def _format_value(self, value: Any) -> str: elif value is None: return "null" else: - return f"'{str(value)}'" + return f"'{value!s}'" def _get_metric_type(self) -> str: """Get the metric type for search.""" From 3b6578e821bcb023affddb3d576eb65c588cd289 Mon Sep 17 00:00:00 2001 From: Wenqiang Wei Date: Mon, 1 Dec 2025 10:23:31 +0800 Subject: [PATCH 4/4] fix textual filter bug and resolve conversation --- src/memos/graph_dbs/polardb.py | 8 ++++---- src/memos/memories/textual/preference.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index a7e60704e..f280a0673 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -3654,11 +3654,11 @@ def build_filter_condition(condition_dict: dict) -> str: if isinstance(op_value, str): escaped_value = escape_sql_string(op_value) condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype) @> '\"{escaped_value}\"'::agtype" + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> '\"{escaped_value}\"'::agtype" ) else: condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype) @> {op_value}::agtype" + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype]) @> {op_value}::agtype" ) else: # Direct property access @@ -3684,11 +3684,11 @@ def build_filter_condition(condition_dict: dict) -> str: .replace("_", "\\_") ) condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype)::text LIKE '%{escaped_value}%'" + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype])::text LIKE '%{escaped_value}%'" ) else: condition_parts.append( - f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype)::text LIKE '%{op_value}%'" + f"ag_catalog.agtype_access_operator(VARIADIC ARRAY[properties, '\"info\"'::ag_catalog.agtype, '\"{info_field}\"'::ag_catalog.agtype])::text LIKE '%{op_value}%'" ) else: # Direct property access diff --git a/src/memos/memories/textual/preference.py b/src/memos/memories/textual/preference.py index c39d7d14c..c0ed1217d 100644 --- a/src/memos/memories/textual/preference.py +++ b/src/memos/memories/textual/preference.py @@ -87,7 +87,7 @@ def search( Returns: list[TextualMemoryItem]: List of matching memories. """ - print(f"search_filter for preference memory: {search_filter}") + logger.info(f"search_filter for preference memory: {search_filter}") return self.retriever.retrieve(query, top_k, info, search_filter) def load(self, dir: str) -> None:
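
Taken together, the series leaves callers with two separate knobs: search_priority carries the session-scoped hint (fed to the extra recall path and the reranker boost), while search_filter carries the hard metadata constraint taken from the request's filter field. The sketch below shows how a caller might pass both; tree_text_memory stands in for an already-configured TreeTextMemory instance and every argument value is hypothetical.

    # Hypothetical caller-side usage of the two knobs (illustration only).
    results = tree_text_memory.search(
        query="what restaurants did I like?",
        top_k=10,
        info={"user_id": "u_123", "session_id": "sess_42"},
        mode="fast",
        memory_type="All",
        search_priority={"session_id": "sess_42"},   # soft hint: extra recall path + reranker boost
        search_filter={"and": [{"user_id": "u_123"},
                               {"tags": {"in": ["food"]}}]},  # hard constraint pushed into the vector/graph queries
    )
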