diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 2758c9e3..3cdbedab 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -89,6 +89,7 @@ def _check_messages(messages: MessageList) -> None: feedback_content=feedback_content, writable_cube_ids=add_req.writable_cube_ids, async_mode=add_req.async_mode, + info=add_req.info, ) process_record = cube_view.feedback_memories(feedback_req) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 632c2ed4..670a1911 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -304,6 +304,7 @@ def init_server() -> dict[str, Any]: memory_manager=memory_manager, mem_reader=mem_reader, searcher=searcher, + reranker=reranker, ) # Initialize Scheduler diff --git a/src/memos/api/handlers/feedback_handler.py b/src/memos/api/handlers/feedback_handler.py index cf5c536e..217bca7c 100644 --- a/src/memos/api/handlers/feedback_handler.py +++ b/src/memos/api/handlers/feedback_handler.py @@ -28,7 +28,7 @@ def __init__(self, dependencies: HandlerDependencies): dependencies: HandlerDependencies instance """ super().__init__(dependencies) - self._validate_dependencies("mem_reader", "mem_scheduler", "searcher") + self._validate_dependencies("mem_reader", "mem_scheduler", "searcher", "reranker") def handle_feedback_memories(self, feedback_req: APIFeedbackRequest) -> MemoryResponse: """ diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 06cc2972..d583f3e1 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -684,6 +684,19 @@ class APIFeedbackRequest(BaseRequest): "async", description="feedback mode: sync or async" ) corrected_answer: bool = Field(False, description="Whether need return corrected answer") + info: dict[str, Any] | None = Field( + None, + description=( + "Additional metadata for the add 
request. " + "All keys can be used as filters in search. " + "Example: " + "{'agent_id': 'xxxxxx', " + "'app_id': 'xxxx', " + "'source_type': 'web', " + "'source_url': 'https://www.baidu.com', " + "'source_content': 'West Lake is the most famous scenic spot in Hangzhou'}." + ), + ) # ==== mem_cube_id is NOT enabled==== mem_cube_id: str | None = Field( None, diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 8dff5824..84e6bf19 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -1639,12 +1639,9 @@ def seach_by_keywords_like( """ params = (query_word,) - logger.info( - f"[seach_by_keywords_LIKE start:] user_name: {user_name}, query: {query}, params: {params}" - ) - conn = None + logger.info(f"[seach_by_keywords_LIKE start:] user_name: {user_name}, params: {params}") + conn = self._get_connection() try: - conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -1654,7 +1651,7 @@ def seach_by_keywords_like( id_val = str(oldid) output.append({"id": id_val}) logger.info( - f"[seach_by_keywords_LIKE end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + f"[seach_by_keywords_LIKE end:] user_name: {user_name}, params: {params} recalled: {output}" ) return output finally: @@ -1739,9 +1736,8 @@ def seach_by_keywords_tfidf( logger.info( f"[seach_by_keywords_TFIDF start:] user_name: {user_name}, query: {query}, params: {params}" ) - conn = None + conn = self._get_connection() try: - conn = self._get_connection() with conn.cursor() as cursor: cursor.execute(query, params) results = cursor.fetchall() @@ -1751,6 +1747,9 @@ def seach_by_keywords_tfidf( id_val = str(oldid) output.append({"id": id_val}) + logger.info( + f"[seach_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: {params} recalled: {output}" + ) logger.info( f"[seach_by_keywords_TFIDF end:] user_name: {user_name}, query: {query}, params: 
{params} recalled: {output}" ) diff --git a/src/memos/mem_feedback/feedback.py b/src/memos/mem_feedback/feedback.py index 831701b9..3d650c17 100644 --- a/src/memos/mem_feedback/feedback.py +++ b/src/memos/mem_feedback/feedback.py @@ -3,7 +3,7 @@ import json from datetime import datetime -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from tenacity import retry, stop_after_attempt, wait_exponential @@ -15,14 +15,15 @@ from memos.graph_dbs.factory import GraphStoreFactory, PolarDBGraphDB from memos.llms.factory import AzureLLM, LLMFactory, OllamaLLM, OpenAILLM from memos.mem_feedback.base import BaseMemFeedback -from memos.mem_feedback.utils import should_keep_update, split_into_chunks +from memos.mem_feedback.utils import make_mem_item, should_keep_update, split_into_chunks from memos.mem_reader.factory import MemReaderFactory from memos.mem_reader.read_multi_modal import detect_lang -from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata +from memos.memories.textual.item import TextualMemoryItem from memos.memories.textual.tree_text_memory.organize.manager import ( MemoryManager, extract_working_binding_ids, ) +from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import StopwordManager if TYPE_CHECKING: @@ -77,7 +78,9 @@ def __init__(self, config: MemFeedbackConfig): }, is_reorganize=self.is_reorganize, ) + self.stopword_manager = StopwordManager self.searcher: Searcher = None + self.reranker = None self.DB_IDX_READY = False def _batch_embed(self, texts: list[str], embed_bs: int = 5): @@ -259,7 +262,6 @@ def _del_working_binding(self, user_name, mem_items: list[TextualMemoryItem]) -> for mid in delete_ids: try: - print("del", mid) self.graph_store.delete_node(mid, user_name=user_name) logger.info( @@ -276,14 +278,30 @@ def semantics_feedback( user_name: str, memory_item: TextualMemoryItem, current_memories: list[TextualMemoryItem], - fact_history: str, + history_str: str, + 
chat_history_list: list, + info: dict, ): + """Modify memory at the semantic level""" lang = detect_lang("".join(memory_item.memory)) template = FEEDBACK_PROMPT_DICT["compare"][lang] if current_memories == []: - current_memories = self._retrieve( - memory_item.memory, info={"user_id": user_id}, user_name=user_name - ) + # retrieve feedback + feedback_retrieved = self._retrieve(memory_item.memory, info=info, user_name=user_name) + + # retrieve question + last_user_index = max(i for i, d in enumerate(chat_history_list) if d["role"] == "user") + last_qa = " ".join([item["content"] for item in chat_history_list[last_user_index:]]) + supplementary_retrieved = self._retrieve(last_qa, info=info, user_name=user_name) + ids = [] + for item in feedback_retrieved + supplementary_retrieved: + if item.id not in ids: + ids.append(item.id) + current_memories.append(item) + include_keys = ["agent_id", "app_id"] + current_memories = [ + item for item in current_memories if self._info_comparison(item, info, include_keys) + ] if not current_memories: operations = [{"operation": "ADD"}] @@ -300,7 +318,7 @@ def semantics_feedback( prompt = template.format( current_memories=current_memories_str, new_facts=memory_item.memory, - chat_history=fact_history, + chat_history=history_str, ) future = executor.submit(self._get_llm_response, prompt) @@ -319,7 +337,6 @@ def semantics_feedback( operations = self.standard_operations(all_operations, current_memories) - # TODO based on the operation, change memory_item memory info ; change source info logger.info(f"[Feedback memory operations]: {operations!s}") if not operations: @@ -378,9 +395,10 @@ def _feedback_memory( retrieved_memory_ids = kwargs.get("retrieved_memory_ids") or [] chat_history = kwargs.get("chat_history", []) feedback_content = kwargs.get("feedback_content", "") + info = kwargs.get("info", {}) chat_history_lis = [f"""{msg["role"]}: {msg["content"]}""" for msg in chat_history[-4:]] - fact_history = "\n".join(chat_history_lis) + 
f"\nuser feedback: \n{feedback_content}" + history_str = "\n".join(chat_history_lis) + f"\nuser feedback: \n{feedback_content}" retrieved_memories = [ self.graph_store.get_node(_id, user_name=user_name) for _id in retrieved_memory_ids @@ -402,7 +420,14 @@ def _feedback_memory( with ContextThreadPoolExecutor(max_workers=3) as ex: futures = { ex.submit( - self.semantics_feedback, user_id, user_name, mem, current_memories, fact_history + self.semantics_feedback, + user_id, + user_name, + mem, + current_memories, + history_str, + chat_history, + info, ): i for i, mem in enumerate(feedback_memories) } @@ -427,6 +452,17 @@ def _feedback_memory( } } + def _info_comparison(self, memory: TextualMemoryItem, _info: dict, include_keys: list) -> bool: + if not _info and not memory.metadata.info: + return True + + record = [] + for key in include_keys: + info_v = _info.get(key) + mem_v = memory.metadata.info.get(key, None) + record.append(info_v == mem_v) + return all(record) + def _retrieve(self, query: str, info=None, user_name=None): """Retrieve memory items""" retrieved_mems = self.searcher.search( @@ -460,8 +496,6 @@ def _vec_query(self, new_memories_embedding: list[float], user_name=None): self.graph_store.get_node(item["id"], user_name=user_name) for item in retrieved_ids ] - for item in current_memories: - print(item["id"], item["metadata"]["memory_type"], item["metadata"]["status"]) if not retrieved_ids: logger.info( f"[Feedback Core: _vec_query] No similar memories found for embedding query for user {user_name}." 
@@ -542,7 +576,17 @@ def correct_item(data): return None dehallu_res = [correct_item(item) for item in operations] - return [item for item in dehallu_res if item] + llm_operations = [item for item in dehallu_res if item] + + # Update takes precedence over add + has_update = any(item.get("operation").lower() == "update" for item in llm_operations) + if has_update: + filtered_items = [ + item for item in llm_operations if item.get("operation").lower() != "add" + ] + return filtered_items + else: + return llm_operations def _generate_answer( self, chat_history: list[MessageDict], feedback_content: str, corrected_answer: bool @@ -562,13 +606,49 @@ def _generate_answer( return self._get_llm_response(prompt, dsl=False) - def process_keyword_replace(self, user_id: str, user_name: str, kwp_judge: dict | None = None): + def _doc_filter(self, doc_scope: str, memories: list[TextualMemoryItem]): + """ + Filter the memory based on filename + """ + filename2_memid = {} + filename_mems = [] + + for item in memories: + for file_info in item.metadata.sources: + if file_info.type == "file": + file_dict = file_info.original_part + filename = file_dict["file"]["filename"] + if filename not in filename2_memid: + filename2_memid[filename] = [] + filename_mems.append(make_mem_item(filename)) + filename2_memid[filename].append(item.id) + + rerank_res = self.reranker.rerank(doc_scope, filename_mems, top_k=100) + inscope_docs = [item[0].memory for item in rerank_res if item[1] > 0.95] + + inscope_ids = [ + memid for inscope_file in inscope_docs for memid in filename2_memid[inscope_file] + ] + logger.info( + f"[Feedback Core: process_keyword_replace] These docs are in scope : {inscope_docs}, related memids: {inscope_ids}" + ) + filter_memories = [mem for mem in memories if mem.id in inscope_ids] + return filter_memories + + def process_keyword_replace( + self, user_id: str, user_name: str, kwp_judge: dict | None = None, info: dict | None = None + ): """ - memory keyword replace process + 
Memory keyword replace process """ + info = info or {} doc_scope = kwp_judge.get("doc_scope", "NONE") original_word = kwp_judge.get("original") target_word = kwp_judge.get("target") + include_keys = ["agent_id", "app_id"] + + mem_info = {key: info[key] for key in info if key in include_keys} + filter_dict = {f"info.{key}": info[key] for key in mem_info} if self.DB_IDX_READY: # retrieve @@ -579,29 +659,29 @@ def process_keyword_replace(self, user_id: str, user_name: str, kwp_judge: dict must_part = f"{' & '.join(queries)}" if len(queries) > 1 else queries[0] retrieved_ids = self.graph_store.seach_by_keywords_tfidf( - [must_part], user_name=user_name + [must_part], user_name=user_name, filter=filter_dict ) if len(retrieved_ids) < 1: retrieved_ids = self.graph_store.search_by_fulltext( - queries, top_k=100, user_name=user_name + queries, top_k=100, user_name=user_name, filter=filter_dict ) else: retrieved_ids = self.graph_store.seach_by_keywords_like( - f"%{original_word}%", user_name=user_name + f"%{original_word}%", user_name=user_name, filter=filter_dict ) - # filter by doc scope mem_data = [ self.graph_store.get_node(item["id"], user_name=user_name) for item in retrieved_ids ] retrieved_memories = [TextualMemoryItem(**item) for item in mem_data] + retrieved_memories = [ + item + for item in retrieved_memories + if self._info_comparison(item, mem_info, include_keys) + ] if doc_scope != "NONE": - retrieved_memories = [ - item - for item in retrieved_memories - if doc_scope in item.metadata.sources # TODO - ] + retrieved_memories = self._doc_filter(doc_scope, retrieved_memories) if not retrieved_memories: return {"record": {"add": [], "update": []}} @@ -645,7 +725,7 @@ def process_keyword_replace(self, user_id: str, user_name: str, kwp_judge: dict update_results.append(result) except Exception as e: mem_id = future_to_info[future][0] - self.logger.error( + logger.error( f"[Feedback Core DB] Exception during update operation for memory {mem_id}: {e}" ) @@ -657,6 
+737,7 @@ def process_feedback_core( user_name: str, chat_history: list[MessageDict], feedback_content: str, + info: dict | None = None, **kwargs, ) -> dict: """ @@ -678,7 +759,11 @@ def check_validity(item): try: feedback_time = kwargs.get("feedback_time") or datetime.now().isoformat() session_id = kwargs.get("session_id") - info = {"user_id": user_id, "user_name": user_name, "session_id": session_id} + if not info: + info = {"user_id": user_id, "user_name": user_name, "session_id": session_id} + else: + info.update({"user_id": user_id, "user_name": user_name, "session_id": session_id}) + logger.info( f"[Feedback Core: process_feedback_core] Starting memory feedback process for user {user_name}" ) @@ -690,7 +775,9 @@ def check_validity(item): and kwp_judge.get("original", "NONE") != "NONE" and kwp_judge.get("target", "NONE") != "NONE" ): - return self.process_keyword_replace(user_id, user_name, kwp_judge=kwp_judge) + return self.process_keyword_replace( + user_id, user_name, kwp_judge=kwp_judge, info=info + ) # llm update memory if not chat_history: @@ -728,29 +815,26 @@ def check_validity(item): value = item["corrected_info"] key = item["key"] tags = item["tags"] - feedback_memories.append( - TextualMemoryItem( - memory=value, - metadata=TreeNodeTextualMemoryMetadata( - user_id=info.get("user_id", ""), - session_id=info.get("session_id", ""), - memory_type="LongTermMemory", - status="activated", - tags=tags, - key=key, - embedding=embedding, - usage=[], - sources=[{"type": "chat"}], - user_name=user_name, - background="[Feedback update background]: " - + str(chat_history) - + "\nUser feedback: " - + str(feedback_content), - confidence=0.99, - type="fine", - ), - ) + background = ( + "[Feedback update background]: " + + str(chat_history) + + "\nUser feedback: " + + str(feedback_content) + ) + mem_item = make_mem_item( + value, + user_id=user_id, + user_name=user_name, + session_id=session_id, + tags=tags, + key=key, + embedding=embedding, + sources=[{"type": 
"chat"}], + background=background, + type="fine", + info=info, ) + feedback_memories.append(mem_item) mem_record = self._feedback_memory( user_id, @@ -758,6 +842,7 @@ def check_validity(item): feedback_memories, chat_history=chat_history, feedback_content=feedback_content, + info=info, **kwargs, ) logger.info( @@ -775,6 +860,7 @@ def process_feedback( user_name: str, chat_history: list[MessageDict], feedback_content: str, + info: dict[str, Any] | None = None, **kwargs, ): """ @@ -804,6 +890,7 @@ def process_feedback( user_name, chat_history, feedback_content, + info, **kwargs, ) done, pending = concurrent.futures.wait([answer_future, core_future], timeout=30) diff --git a/src/memos/mem_feedback/simple_feedback.py b/src/memos/mem_feedback/simple_feedback.py index 478fa104..429c2ea2 100644 --- a/src/memos/mem_feedback/simple_feedback.py +++ b/src/memos/mem_feedback/simple_feedback.py @@ -7,6 +7,7 @@ from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.retrieve_utils import StopwordManager from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher +from memos.reranker.base import BaseReranker logger = log.get_logger(__name__) @@ -21,6 +22,7 @@ def __init__( memory_manager: MemoryManager, mem_reader: SimpleStructMemReader, searcher: Searcher, + reranker: BaseReranker, ): self.llm = llm self.embedder = embedder @@ -29,4 +31,5 @@ def __init__( self.mem_reader = mem_reader self.searcher = searcher self.stopword_manager = StopwordManager + self.reranker = reranker self.DB_IDX_READY = False diff --git a/src/memos/mem_feedback/utils.py b/src/memos/mem_feedback/utils.py index b290993c..0033d85b 100644 --- a/src/memos/mem_feedback/utils.py +++ b/src/memos/mem_feedback/utils.py @@ -1,4 +1,4 @@ -from memos.memories.textual.item import TextualMemoryItem +from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata def estimate_tokens(text: str) -> 
int: @@ -48,13 +48,13 @@ def calculate_similarity(text1: str, text2: str) -> float: similarity = calculate_similarity(old_text, new_text) change_ratio = 1 - similarity - if old_len < 50: + if old_len < 200: return change_ratio < 0.5 else: - return change_ratio < 0.15 + return change_ratio < 0.2 -def split_into_chunks(memories: list[TextualMemoryItem], max_tokens_per_chunk=500): +def split_into_chunks(memories: list[TextualMemoryItem], max_tokens_per_chunk: int = 500): chunks = [] current_chunk = [] current_tokens = 0 @@ -84,3 +84,31 @@ def split_into_chunks(memories: list[TextualMemoryItem], max_tokens_per_chunk=50 chunks.append(current_chunk) return chunks + + +def make_mem_item(text: str, **kwargs) -> TextualMemoryItem: + """Build a minimal TextualMemoryItem.""" + info = kwargs.get("info", {}) + info_ = info.copy() + user_id = info_.pop("user_id", "") + session_id = info_.pop("session_id", "") + + return TextualMemoryItem( + memory=text, + metadata=TreeNodeTextualMemoryMetadata( + user_id=user_id, + session_id=session_id, + memory_type="LongTermMemory", + status="activated", + tags=kwargs.get("tags", []), + key=kwargs.get("key", ""), + embedding=kwargs.get("embedding", []), + usage=[], + sources=kwargs.get("sources", []), + user_name=kwargs.get("user_name", ""), + background=kwargs.get("background", ""), + confidence=0.99, + type=kwargs.get("type", ""), + info=info_, + ), + ) diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 59bd1c0a..71012d42 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -604,6 +604,7 @@ def _mem_feedback_message_consumer(self, messages: list[ScheduleMessageItem]) -> feedback_content=feedback_data.get("feedback_content"), feedback_time=feedback_data.get("feedback_time"), task_id=task_id, + info=feedback_data.get("info", None), ) logger.info( diff --git a/src/memos/multi_mem_cube/single_cube.py 
b/src/memos/multi_mem_cube/single_cube.py index f0157952..71a34beb 100644 --- a/src/memos/multi_mem_cube/single_cube.py +++ b/src/memos/multi_mem_cube/single_cube.py @@ -183,6 +183,7 @@ def feedback_memories(self, feedback_req: APIFeedbackRequest) -> dict[str, Any]: async_mode=feedback_req.async_mode, corrected_answer=feedback_req.corrected_answer, task_id=feedback_req.task_id, + info=feedback_req.info, ) self.logger.info(f"Feedback memories result: {feedback_result}") return feedback_result