From 33bb45e07b2e36a69477a03d7c1048e29b335ef0 Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 26 Nov 2025 16:34:01 +0800 Subject: [PATCH 1/5] feat: add full text memory --- src/memos/api/handlers/component_init.py | 5 +- src/memos/api/handlers/search_handler.py | 3 +- src/memos/api/product_models.py | 1 + src/memos/graph_dbs/polardb.py | 109 ++++++++++++++++++ src/memos/memories/textual/simple_tree.py | 3 + src/memos/memories/textual/tree.py | 13 ++- .../tree_text_memory/retrieve/recall.py | 22 +++- .../retrieve/retrieve_utils.py | 27 +++++ .../tree_text_memory/retrieve/searcher.py | 82 +++++++++++-- 9 files changed, 250 insertions(+), 15 deletions(-) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 89e61e79d..a01d8fc2a 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -40,6 +40,7 @@ from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer if TYPE_CHECKING: @@ -142,7 +143,7 @@ def init_server() -> dict[str, Any]: ) logger.debug("Memory manager initialized") - + tokenizer = FastTokenizer() # Initialize text memory text_mem = SimpleTreeTextMemory( llm=llm, @@ -153,6 +154,7 @@ def init_server() -> dict[str, Any]: memory_manager=memory_manager, config=default_cube_config.text_mem.config, internet_retriever=internet_retriever, + tokenizer=tokenizer, ) logger.debug("Text memory initialized") @@ -270,7 +272,6 @@ def init_server() -> dict[str, Any]: online_bot = get_online_bot_function() if dingding_enabled else None logger.info("DingDing bot is enabled") - # Return all components as a dictionary for easy access and extension return { "graph_db": graph_db, diff --git 
a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 7d7d52dc4..6e2f6f712 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -191,7 +191,7 @@ def _fast_search( """ target_session_id = search_req.session_id or "default_session" search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - + plugin = bool(search_req.info is not None and search_req.info.get("origin_model")) search_results = self.naive_mem_cube.text_mem.search( query=search_req.query, user_name=user_context.mem_cube_id, @@ -205,6 +205,7 @@ def _fast_search( "session_id": target_session_id, "chat_history": search_req.chat_history, }, + plugin=plugin, ) formatted_memories = [format_memory_item(data) for data in search_results] diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index f7f0304c7..85c4d21cd 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -185,6 +185,7 @@ class APISearchRequest(BaseRequest): ) include_preference: bool = Field(True, description="Whether to handle preference memory") pref_top_k: int = Field(6, description="Number of preference results to return") + info: dict | None = Field(None, description="Info for search") class APIADDRequest(BaseRequest): diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index da1635296..eaa0a6881 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1450,6 +1450,115 @@ def get_context_chain(self, id: str, type: str = "FOLLOWS") -> list[str]: """Get the ordered context chain starting from a node.""" raise NotImplementedError + @timed + def search_by_fulltext( + self, + query_words: list[str], + top_k: int = 10, + scope: str | None = None, + status: str | None = None, + threshold: float | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + tsvector_field: str = 
"properties_tsvector_zh", + tsquery_config: str = "jiebaqry", + **kwargs, + ) -> list[dict]: + """ + Full-text search functionality using PostgreSQL's full-text search capabilities. + + Args: + query_text: query text + top_k: maximum number of results to return + scope: memory type filter (memory_type) + status: status filter, defaults to "activated" + threshold: similarity threshold filter + search_filter: additional property filter conditions + user_name: username filter + tsvector_field: full-text index field name, defaults to properties_tsvector_zh_1 + tsquery_config: full-text search configuration, defaults to jiebaqry (Chinese word segmentation) + **kwargs: other parameters (e.g. cube_name) + + Returns: + list[dict]: result list containing id and score + """ + # Build WHERE clause dynamically, same as search_by_embedding + where_clauses = [] + + if scope: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"memory_type\"'::agtype) = '\"{scope}\"'::agtype" + ) + if status: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"{status}\"'::agtype" + ) + else: + where_clauses.append( + "ag_catalog.agtype_access_operator(properties, '\"status\"'::agtype) = '\"activated\"'::agtype" + ) + + # Add user_name filter + user_name = user_name if user_name else self.config.user_name + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"user_name\"'::agtype) = '\"{user_name}\"'::agtype" + ) + + # Add search_filter conditions + if search_filter: + for key, value in search_filter.items(): + if isinstance(value, str): + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = '\"{value}\"'::agtype" + ) + else: + where_clauses.append( + f"ag_catalog.agtype_access_operator(properties, '\"{key}\"'::agtype) = {value}::agtype" + ) + + # Add fulltext search condition + # Convert query_text to OR query format: "word1 | word2 | word3" + tsquery_string = 
" | ".join(query_words) + + where_clauses.append(f"{tsvector_field} @@ to_tsquery('{tsquery_config}', %s)") + + where_clause = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else "" + + # Build fulltext search query + query = f""" + SELECT + ag_catalog.agtype_access_operator(properties, '"id"'::agtype) AS old_id, + agtype_object_field_text(properties, 'memory') as memory_text, + ts_rank({tsvector_field}, to_tsquery('{tsquery_config}', %s)) as rank + FROM "{self.db_name}_graph"."Memory" + {where_clause} + ORDER BY rank DESC + LIMIT {top_k}; + """ + + params = [tsquery_string, tsquery_string] + + conn = self._get_connection() + try: + with conn.cursor() as cursor: + cursor.execute(query, params) + results = cursor.fetchall() + output = [] + for row in results: + oldid = row[0] # old_id + rank = row[2] # rank score + + id_val = str(oldid) + score_val = float(rank) + + # Apply threshold filter if specified + if threshold is None or score_val >= threshold: + output.append({"id": id_val, "score": score_val}) + + return output[:top_k] + finally: + self._return_connection(conn) + @timed def search_by_embedding( self, diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 05e62e3ee..c67271f76 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -9,6 +9,7 @@ from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import FastTokenizer from memos.reranker.base import BaseReranker @@ -35,6 +36,7 @@ def __init__( config: TreeTextMemoryConfig, internet_retriever: None = None, is_reorganize: bool = False, + tokenizer: FastTokenizer | None = None, ): """Initialize memory with the given configuration.""" self.config: TreeTextMemoryConfig = 
config @@ -51,6 +53,7 @@ def __init__( if self.search_strategy and self.search_strategy.get("bm25", False) else None ) + self.tokenizer = tokenizer self.reranker = reranker self.memory_manager: MemoryManager = memory_manager # Create internet retriever if configured diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 1b2355bc8..60cc25263 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -89,6 +89,7 @@ def __init__(self, config: TreeTextMemoryConfig): ) else: logger.info("No internet retriever configured") + self.tokenizer = None def add( self, @@ -165,6 +166,7 @@ def search( moscube: bool = False, search_filter: dict | None = None, user_name: str | None = None, + **kwargs, ) -> list[TextualMemoryItem]: """Search for memories based on a query. User query -> TaskGoalParser -> MemoryPathResolver -> @@ -199,6 +201,7 @@ def search( moscube=moscube, search_strategy=self.search_strategy, manual_close_internet=manual_close_internet, + tokenizer=self.tokenizer, ) else: searcher = Searcher( @@ -211,9 +214,17 @@ def search( moscube=moscube, search_strategy=self.search_strategy, manual_close_internet=manual_close_internet, + tokenizer=self.tokenizer, ) return searcher.search( - query, top_k, info, mode, memory_type, search_filter, user_name=user_name + query, + top_k, + info, + mode, + memory_type, + search_filter, + user_name=user_name, + plugin=kwargs.get("plugin", False), ) def get_relevant_subgraph( diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 375048900..c5bf6cade 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -143,6 +143,26 @@ def retrieve_from_cube( return list(combined.values()) + def retrieve_from_mixed( + self, + top_k: int, + memory_scope: str | None = None, + query_embedding: 
list[list[float]] | None = None, + search_filter: dict | None = None, + user_name: str | None = None, + use_fast_graph: bool = False, + ) -> list[TextualMemoryItem]: + """Retrieve from mixed and memory""" + vector_results = self._vector_recall( + query_embedding or [], + memory_scope, + top_k, + search_filter=search_filter, + user_name=user_name, + ) # Merge and deduplicate by ID + combined = {item.id: item for item in vector_results} + return list(combined.values()) + def _graph_recall( self, parsed_goal: ParsedTaskGoal, memory_scope: str, user_name: str | None = None, **kwargs ) -> list[TextualMemoryItem]: @@ -270,7 +290,7 @@ def _vector_recall( query_embedding: list[list[float]], memory_scope: str, top_k: int = 20, - max_num: int = 5, + max_num: int = 20, status: str = "activated", cube_name: str | None = None, search_filter: dict | None = None, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py index 3f2b41a47..824f93b26 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/retrieve_utils.py @@ -3,6 +3,8 @@ from pathlib import Path +import numpy as np + from memos.dependency import require_python_package from memos.log import get_logger @@ -376,3 +378,28 @@ def detect_lang(text): return "en" except Exception: return "en" + + +def find_best_unrelated_subgroup(sentences: list, similarity_matrix: list, bar: float = 0.8): + assert len(sentences) == len(similarity_matrix) + + num_sentence = len(sentences) + selected_sentences = [] + selected_indices = [] + for i in range(num_sentence): + can_add = True + for j in selected_indices: + if similarity_matrix[i][j] > bar: + can_add = False + break + if can_add: + selected_sentences.append(i) + selected_indices.append(i) + return selected_sentences, selected_indices + + +def cosine_similarity_matrix(embeddings: list[list[float]]) 
-> list[list[float]]: + norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + x_normalized = embeddings / norms + similarity_matrix = np.dot(x_normalized, x_normalized.T) + return similarity_matrix diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 933ef5af1..275cea575 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -8,7 +8,10 @@ from memos.memories.textual.item import SearchedTreeNodeTextualMemoryMetadata, TextualMemoryItem from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import ( + FastTokenizer, + cosine_similarity_matrix, detect_lang, + find_best_unrelated_subgroup, parse_json_result, ) from memos.reranker.base import BaseReranker @@ -44,6 +47,7 @@ def __init__( moscube: bool = False, search_strategy: dict | None = None, manual_close_internet: bool = True, + tokenizer: FastTokenizer | None = None, ): self.graph_store = graph_store self.embedder = embedder @@ -60,6 +64,7 @@ def __init__( self.vec_cot = search_strategy.get("cot", False) if search_strategy else False self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False self.manual_close_internet = manual_close_internet + self.tokenizer = tokenizer self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @timed @@ -90,6 +95,7 @@ def retrieve( memory_type, search_filter, user_name, + **kwargs, ) return results @@ -115,6 +121,7 @@ def search( memory_type="All", search_filter: dict | None = None, user_name: str | None = None, + **kwargs, ) -> list[TextualMemoryItem]: """ Search for memories based on a query. 
@@ -142,15 +149,21 @@ def search( else: logger.debug(f"[SEARCH] Received info dict: {info}") - retrieved_results = self.retrieve( - query=query, - top_k=top_k, - info=info, - mode=mode, - memory_type=memory_type, - search_filter=search_filter, - user_name=user_name, - ) + if kwargs.get("plugin"): + logger.info(f"[SEARCH] Retrieve from plugin: {query}") + retrieved_results = self._retrieve_simple( + query=query, top_k=top_k, search_filter=search_filter, user_name=user_name + ) + else: + retrieved_results = self.retrieve( + query=query, + top_k=top_k, + info=info, + mode=mode, + memory_type=memory_type, + search_filter=search_filter, + user_name=user_name, + ) final_results = self.post_retrieve( retrieved_results=retrieved_results, @@ -235,6 +248,45 @@ def _parse_task( return parsed_goal, query_embedding, context, query + @timed + def _retrieve_simple( + self, + query: str, + top_k: int, + search_filter: dict | None = None, + user_name: str | None = None, + **kwargs, + ): + """Retrieve from by keywords and embedding""" + query_words = [] + if self.tokenizer: + query_words = self.tokenizer.tokenize_mixed(query) + else: + query_words = query.strip().split() + query_words = [query, *query_words] + logger.info(f"[SIMPLESEARCH] Query words: {query_words}") + query_embeddings = self.embedder.embed(query_words) + + items = self.graph_retriever.retrieve_from_mixed( + top_k=top_k * 2, + memory_scope=None, + query_embedding=query_embeddings, + search_filter=search_filter, + user_name=user_name, + use_fast_graph=self.use_fast_graph, + ) + documents = [getattr(item, "memory", "") for item in items] + documents_embeddings = self.embedder.embed(documents) + similarity_matrix = cosine_similarity_matrix(documents_embeddings) + selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix) + selected_items = [items[i] for i in selected_indices] + return self.reranker.rerank( + query=query, + query_embedding=query_embeddings[0], + graph_results=selected_items, + 
top_k=top_k, ) + @timed def _retrieve_paths( self, @@ -247,6 +299,7 @@ def _retrieve_paths( memory_type, search_filter: dict | None = None, user_name: str | None = None, + **kwargs, ): """Run A/B/C retrieval paths in parallel""" tasks = [] @@ -308,7 +361,16 @@ def _retrieve_paths( "memos_cube01", ) ) - + if kwargs.get("keywords"): + tasks.append( + executor.submit( + self._retrieve_from_fulltext, + query, + parsed_goal, + query_embedding, + top_k, + ) + ) results = [] for t in tasks: results.extend(t.result()) From 9291b8b3501dd0038ef0f2b9c67fe5f4d539a089 Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 26 Nov 2025 16:38:11 +0800 Subject: [PATCH 2/5] =?UTF-8?q?fix=EF=BC=9A=20remove=20search?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../textual/tree_text_memory/retrieve/searcher.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 275cea575..02a951a09 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -361,16 +361,6 @@ def _retrieve_paths( "memos_cube01", ) ) - if kwargs.get("keywords"): - tasks.append( - executor.submit( - self._retrieve_from_fulltext, - query, - parsed_goal, - query_embedding, - top_k, - ) - ) results = [] for t in tasks: results.extend(t.result()) From 7e5c113315900a131f3ffa32f96da69cadae4349 Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 26 Nov 2025 17:43:09 +0800 Subject: [PATCH 3/5] feat: update plugin --- src/memos/api/handlers/search_handler.py | 2 +- src/memos/api/product_models.py | 2 +- .../memories/textual/tree_text_memory/retrieve/searcher.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 6e2f6f712..c8b92e225 100644 --- 
a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -191,7 +191,7 @@ def _fast_search( """ target_session_id = search_req.session_id or "default_session" search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - plugin = bool(search_req.info is not None and search_req.info.get("origin_model")) + plugin = bool(search_req.source is not None and search_req.source == "plugin") search_results = self.naive_mem_cube.text_mem.search( query=search_req.query, user_name=user_context.mem_cube_id, diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 85c4d21cd..c238e7d09 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -185,7 +185,7 @@ class APISearchRequest(BaseRequest): ) include_preference: bool = Field(True, description="Whether to handle preference memory") pref_top_k: int = Field(6, description="Number of preference results to return") - info: dict | None = Field(None, description="Info for search") + source: str | None = Field(None, description="Source of the search") class APIADDRequest(BaseRequest): diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 02a951a09..56aa92e6e 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -275,6 +275,7 @@ def _retrieve_simple( user_name=user_name, use_fast_graph=self.use_fast_graph, ) + logger.info(f"[SIMPLESEARCH] Items count: {len(items)}") documents = [getattr(item, "memory", "") for item in items] documents_embeddings = self.embedder.embed(documents) similarity_matrix = cosine_similarity_matrix(documents_embeddings) From 85cadef1e6ec39c2d5aecdfe5e6a3703feaee6f4 Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 26 Nov 2025 18:51:17 +0800 Subject: [PATCH 4/5] feat: add logger --- 
src/memos/memories/textual/tree_text_memory/retrieve/searcher.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 56aa92e6e..6451084cc 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -281,6 +281,7 @@ def _retrieve_simple( similarity_matrix = cosine_similarity_matrix(documents_embeddings) selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix) selected_items = [items[i] for i in selected_indices] + logger.info(f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}") return self.reranker.rerank( query=query, query_embedding=query_embeddings[0], From cd2397498133341a28c33884101864ecbf415612 Mon Sep 17 00:00:00 2001 From: fridayL Date: Wed, 26 Nov 2025 20:27:34 +0800 Subject: [PATCH 5/5] feat: update score --- .../textual/tree_text_memory/retrieve/searcher.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 6451084cc..4f5feb9d9 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -105,9 +105,10 @@ def post_retrieve( top_k: int, user_name: str | None = None, info=None, + plugin=False, ): deduped = self._deduplicate_results(retrieved_results) - final_results = self._sort_and_trim(deduped, top_k) + final_results = self._sort_and_trim(deduped, top_k, plugin) self._update_usage_history(final_results, info, user_name) return final_results @@ -170,6 +171,7 @@ def search( top_k=top_k, user_name=user_name, info=None, + plugin=kwargs.get("plugin", False), ) logger.info(f"[SEARCH] Done. 
Total {len(final_results)} results.") @@ -281,7 +283,9 @@ def _retrieve_simple( similarity_matrix = cosine_similarity_matrix(documents_embeddings) selected_indices, _ = find_best_unrelated_subgroup(documents, similarity_matrix) selected_items = [items[i] for i in selected_indices] - logger.info(f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}") + logger.info( + f"[SIMPLESEARCH] after unrelated subgroup selection items count: {len(selected_items)}" + ) return self.reranker.rerank( query=query, query_embedding=query_embeddings[0], @@ -541,12 +545,14 @@ def _deduplicate_results(self, results): return list(deduped.values()) @timed - def _sort_and_trim(self, results, top_k): + def _sort_and_trim(self, results, top_k, plugin=False): """Sort results by score and trim to top_k""" sorted_results = sorted(results, key=lambda pair: pair[1], reverse=True)[:top_k] final_items = [] for item, score in sorted_results: + if plugin and round(score, 2) == 0.00: + continue meta_data = item.metadata.model_dump() meta_data["relativity"] = score final_items.append(