From 88993de1e45f1421d23c669da7b5cd252339d3d7 Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Fri, 7 Nov 2025 16:01:33 +0800 Subject: [PATCH 01/18] feat: simplify simple tree (#461) * feat: simplify simple tree * feat: add product_api examples * feat: modify online bot * feat: modify notification * feat: time * format: dingding report --- examples/api/__init__.py | 0 examples/api/product_api.py | 144 +++++++++++ src/memos/mem_os/product.py | 58 ++++- src/memos/memories/textual/simple_tree.py | 249 +------------------- src/memos/memories/textual/tree.py | 27 ++- src/memos/memos_tools/dinding_report_bot.py | 99 +++++--- src/memos/utils.py | 2 +- 7 files changed, 277 insertions(+), 302 deletions(-) create mode 100644 examples/api/__init__.py create mode 100644 examples/api/product_api.py diff --git a/examples/api/__init__.py b/examples/api/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/api/product_api.py b/examples/api/product_api.py new file mode 100644 index 000000000..b98f3b8e5 --- /dev/null +++ b/examples/api/product_api.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +""" +Simulate full MemOS Product API workflow: +1. Register user +2. Add memory +3. Search memory +4. Chat (stream) +""" + +import json + +import requests + + +BASE_URL = "http://0.0.0.0:8001/product" +HEADERS = {"Content-Type": "application/json"} + +index = "24" +USER_ID = f"memos_user_id_{index}" +USER_NAME = f"memos_user_alice_{index}" +MEM_CUBE_ID = f"memos_cube_id_{index}" +SESSION_ID = f"memos_session_id_{index}" +SESSION_ID2 = f"memos_session_id_{index}_s2" + + +def register_user(): + url = f"{BASE_URL}/users/register" + data = { + "user_id": USER_ID, + "user_name": USER_NAME, + "interests": "memory,retrieval,test", + "mem_cube_id": MEM_CUBE_ID, + } + print(f"[*] Registering user {USER_ID} ...") + resp = requests.post(url, headers=HEADERS, data=json.dumps(data), timeout=30) + print(resp.status_code, resp.text) + return resp.json() + + +def add_memory(): + url = f"{BASE_URL}/add" + data = { + "user_id": USER_ID, + "memory_content": "今天我在测试 MemOS 的记忆添加与检索流程。", + "messages": [{"role": "user", "content": "我今天在做系统测试"}], + "doc_path": None, + "mem_cube_id": MEM_CUBE_ID, + "source": "test_script", + "user_profile": False, + "session_id": SESSION_ID, + } + print("[*] Adding memory ...") + resp = requests.post(url, headers=HEADERS, data=json.dumps(data), timeout=30) + print(resp.status_code, resp.text) + return resp.json() + + +def search_memory(query="系统测试"): + url = f"{BASE_URL}/search" + data = { + "user_id": USER_ID, + "query": query, + "mem_cube_id": MEM_CUBE_ID, + "top_k": 5, + "session_id": SESSION_ID, + } + print("[*] Searching memory ...") + resp = requests.post(url, headers=HEADERS, data=json.dumps(data), timeout=30) + print(resp.status_code, resp.text) + return resp.json() + + +def chat_stream(query: str, session_id: str, history: list | None = None): + url = f"{BASE_URL}/chat" + data = { + "user_id": USER_ID, + "query": query, + "mem_cube_id": MEM_CUBE_ID, + "history": history, + "internet_search": False, + "moscube": False, + "session_id": session_id, + } + + print("[*] Starting streaming chat ...") + + with requests.post(url, headers=HEADERS, data=json.dumps(data), stream=True) as resp: + for raw_line in resp.iter_lines(): + if not raw_line: + continue + line = raw_line.decode("utf-8", errors="ignore") + + payload = line.removeprefix("data: ").strip() + if payload == "[DONE]": + print("[done]") + break + + try: + msg = json.loads(payload) + msg_type = msg.get("type") + msg_data = msg.get("data") or msg.get("content") + + if msg_type == "text": + print(msg_data, end="", flush=True) + elif msg_type == "reference": + print(f"\n[参考记忆] {msg_data}") + elif msg_type == "status": + pass + elif msg_type == "suggestion": + print(f"\n[建议] {msg_data}") + elif msg_type == "end": + print("\n[✅ Chat End]") + else: + print(f"\n[{msg_type}] {msg_data}") + except Exception: + try: + print(payload.encode("latin-1").decode("utf-8"), end="") + except Exception: + print(payload) + + +if __name__ == "__main__": + print("===== STEP 1: Register User =====") + register_user() + + print("\n===== STEP 2: Add Memory =====") + add_memory() + + print("\n===== STEP 3: Search Memory =====") + search_memory() + + print("\n===== STEP 4: Stream Chat =====") + chat_stream("我很开心,我今天吃了好吃的拉面", SESSION_ID, history=[]) + chat_stream( + "我刚和你说什么", + SESSION_ID, + history=[ + {"role": "user", "content": "我很开心,我今天吃了好吃的拉面"}, + {"role": "assistant", "content": "🉑"}, + ], + ) + + print("\n===== STEP 4: Stream Chat =====") + chat_stream("我刚和你说什么了呢", SESSION_ID2, history=[]) diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index 89e468bd7..9ddb77b52 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -563,6 +563,34 @@ def _extract_references_from_response(self, response: str) -> tuple[str, list[di logger.error(f"Error extracting references from response: {e}", exc_info=True) return response, [] + def _extract_struct_data_from_history(self, chat_data: list[dict]) -> dict: + """ + get struct message from chat-history + # TODO: @xcy make this more general + """ + system_content = "" + memory_content = "" + chat_history = [] + + for item in chat_data: + role = item.get("role") + content = item.get("content", "") + if role == "system": + parts = content.split("# Memories", 1) + system_content = parts[0].strip() + if len(parts) > 1: + memory_content = "# Memories" + parts[1].strip() + elif role in ("user", "assistant"): + chat_history.append({"role": role, "content": content}) + + if chat_history and chat_history[-1]["role"] == "assistant": + if len(chat_history) >= 2 and chat_history[-2]["role"] == "user": + chat_history = chat_history[:-2] + else: + chat_history = chat_history[:-1] + + return {"system": system_content, "memory": memory_content, "chat_history": chat_history} + def _chunk_response_with_tiktoken( self, response: str, chunk_size: int = 5 ) -> Generator[str, None, None]: @@ -640,23 +668,26 @@ async def _post_chat_processing( clean_response, extracted_references = self._extract_references_from_response( full_response ) + struct_message = self._extract_struct_data_from_history(current_messages) logger.info(f"Extracted {len(extracted_references)} references from response") # Send chat report notifications asynchronously if self.online_bot: + logger.info("Online Bot Open!") try: from memos.memos_tools.notification_utils import ( send_online_bot_notification_async, ) # Prepare notification data - chat_data = { - "query": query, - "user_id": user_id, - "cube_id": cube_id, - "system_prompt": system_prompt, - "full_response": full_response, - } + chat_data = {"query": query, "user_id": user_id, "cube_id": cube_id} + chat_data.update( + { + "memory": struct_message["memory"], + "chat_history": struct_message["chat_history"], + "full_response": full_response, + } + ) system_data = { "references": extracted_references, @@ -720,6 +751,7 @@ def _start_post_chat_processing( """ Asynchronous processing of logs, notifications and memory additions, handle synchronous and asynchronous environments """ + logger.info("Start post_chat_processing...") def run_async_in_thread(): """Running asynchronous tasks in a new thread""" @@ -1046,14 +1078,20 @@ def chat( memories_list = new_memories_list system_prompt = super()._build_system_prompt(memories_list, base_prompt) - history_info = [] - if history: + if history is not None: + # Use the provided history (even if it's empty) history_info = history[-20:] + else: + # Fall back to internal chat_history + if user_id not in self.chat_history_manager: + self._register_chat_history(user_id, session_id) + history_info = self.chat_history_manager[user_id].chat_history[-20:] current_messages = [ {"role": "system", "content": system_prompt}, *history_info, {"role": "user", "content": query}, ] + logger.info("Start to get final answer...") response = self.chat_llm.generate(current_messages) time_end = time.time() self._start_post_chat_processing( @@ -1129,7 +1167,7 @@ def chat_with_references( self._register_chat_history(user_id, session_id) chat_history = self.chat_history_manager[user_id] - if history: + if history is not None: chat_history.chat_history = history[-20:] current_messages = [ {"role": "system", "content": system_prompt}, diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py index 313989cd2..05e62e3ee 100644 --- a/src/memos/memories/textual/simple_tree.py +++ b/src/memos/memories/textual/simple_tree.py @@ -1,7 +1,4 @@ -import time - -from datetime import datetime -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from memos.configs.memory import TreeTextMemoryConfig from memos.embedders.base import BaseEmbedder @@ -9,13 +6,10 @@ from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_reader.base import BaseMemReader -from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25 -from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.base import BaseReranker -from memos.types import MessageList if TYPE_CHECKING: @@ -43,43 +37,22 @@ def __init__( is_reorganize: bool = False, ): """Initialize memory with the given configuration.""" - time_start = time.time() self.config: TreeTextMemoryConfig = config self.mode = self.config.mode logger.info(f"Tree mode is {self.mode}") self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = llm - logger.info(f"time init: extractor_llm time is: {time.time() - time_start}") - - time_start_ex = time.time() self.dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM = llm - logger.info(f"time init: dispatcher_llm time is: {time.time() - time_start_ex}") - - time_start_em = time.time() self.embedder: OllamaEmbedder = embedder - logger.info(f"time init: embedder time is: {time.time() - time_start_em}") - - time_start_gs = time.time() self.graph_store: Neo4jGraphDB = graph_db - logger.info(f"time init: graph_store time is: {time.time() - time_start_gs}") - - time_start_bm = time.time() self.search_strategy = config.search_strategy self.bm25_retriever = ( EnhancedBM25() if self.search_strategy and self.search_strategy.get("bm25", False) else None ) - logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}") - - time_start_rr = time.time() self.reranker = reranker - logger.info(f"time init: reranker time is: {time.time() - time_start_rr}") - - time_start_mm = time.time() self.memory_manager: MemoryManager = memory_manager - logger.info(f"time init: memory_manager time is: {time.time() - time_start_mm}") - time_start_ir = time.time() # Create internet retriever if configured self.internet_retriever = None if config.internet_retriever is not None: @@ -89,223 +62,3 @@ def __init__( ) else: logger.info("No internet retriever configured") - logger.info(f"time init: internet_retriever time is: {time.time() - time_start_ir}") - - def replace_working_memory( - self, memories: list[TextualMemoryItem], user_name: str | None = None - ) -> None: - self.memory_manager.replace_working_memory(memories, user_name=user_name) - - def get_working_memory(self, user_name: str | None = None) -> list[TextualMemoryItem]: - working_memories = self.graph_store.get_all_memory_items( - scope="WorkingMemory", user_name=user_name - ) - items = [TextualMemoryItem.from_dict(record) for record in (working_memories)] - # Sort by updated_at in descending order - sorted_items = sorted( - items, key=lambda x: x.metadata.updated_at or datetime.min, reverse=True - ) - return sorted_items - - def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int]: - """ - Get the current size of each memory type. - This delegates to the MemoryManager. - """ - return self.memory_manager.get_current_memory_size(user_name=user_name) - - def get_searcher( - self, - manual_close_internet: bool = False, - moscube: bool = False, - ): - if (self.internet_retriever is not None) and manual_close_internet: - logger.warning( - "Internet retriever is init by config , but this search set manual_close_internet is True and will close it" - ) - searcher = Searcher( - self.dispatcher_llm, - self.graph_store, - self.embedder, - self.reranker, - internet_retriever=None, - moscube=moscube, - ) - else: - searcher = Searcher( - self.dispatcher_llm, - self.graph_store, - self.embedder, - self.reranker, - internet_retriever=self.internet_retriever, - moscube=moscube, - ) - return searcher - - def search( - self, - query: str, - top_k: int, - info=None, - mode: str = "fast", - memory_type: str = "All", - manual_close_internet: bool = False, - moscube: bool = False, - search_filter: dict | None = None, - user_name: str | None = None, - ) -> list[TextualMemoryItem]: - """Search for memories based on a query. - User query -> TaskGoalParser -> MemoryPathResolver -> - GraphMemoryRetriever -> MemoryReranker -> MemoryReasoner -> Final output - Args: - query (str): The query to search for. - top_k (int): The number of top results to return. - info (dict): Leave a record of memory consumption. - mode (str, optional): The mode of the search. - - 'fast': Uses a faster search process, sacrificing some precision for speed. - - 'fine': Uses a more detailed search process, invoking large models for higher precision, but slower performance. - memory_type (str): Type restriction for search. - ['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory'] - manual_close_internet (bool): If True, the internet retriever will be closed by this search, it high priority than config. - moscube (bool): whether you use moscube to answer questions - search_filter (dict, optional): Optional metadata filters for search results. - - Keys correspond to memory metadata fields (e.g., "user_id", "session_id"). - - Values are exact-match conditions. - Example: {"user_id": "123", "session_id": "abc"} - If None, no additional filtering is applied. - Returns: - list[TextualMemoryItem]: List of matching memories. - """ - if (self.internet_retriever is not None) and manual_close_internet: - searcher = Searcher( - self.dispatcher_llm, - self.graph_store, - self.embedder, - self.reranker, - bm25_retriever=self.bm25_retriever, - internet_retriever=None, - moscube=moscube, - search_strategy=self.search_strategy, - ) - else: - searcher = Searcher( - self.dispatcher_llm, - self.graph_store, - self.embedder, - self.reranker, - bm25_retriever=self.bm25_retriever, - internet_retriever=self.internet_retriever, - moscube=moscube, - search_strategy=self.search_strategy, - ) - return searcher.search( - query, top_k, info, mode, memory_type, search_filter, user_name=user_name - ) - - def get_relevant_subgraph( - self, query: str, top_k: int = 5, depth: int = 2, center_status: str = "activated" - ) -> dict[str, Any]: - """ - Find and merge the local neighborhood sub-graphs of the top-k - nodes most relevant to the query. - Process: - 1. Embed the user query into a vector representation. - 2. Use vector similarity search to find the top-k similar nodes. - 3. For each similar node: - - Ensure its status matches `center_status` (e.g., 'active'). - - Retrieve its local subgraph up to `depth` hops. - - Collect the center node, its neighbors, and connecting edges. - 4. Merge all retrieved subgraphs into a single unified subgraph. - 5. Return the merged subgraph structure. - - Args: - query (str): The user input or concept to find relevant memories for. - top_k (int, optional): How many top similar nodes to retrieve. Default is 5. - depth (int, optional): The neighborhood depth (number of hops). Default is 2. - center_status (str, optional): Status condition the center node must satisfy (e.g., 'active'). - - Returns: - dict[str, Any]: A subgraph dict with: - - 'core_id': ID of the top matching core node, or None if none found. - - 'nodes': List of unique nodes (core + neighbors) in the merged subgraph. - - 'edges': List of unique edges (as dicts with 'from', 'to', 'type') in the merged subgraph. - """ - # Step 1: Embed query - query_embedding = self.embedder.embed([query])[0] - - # Step 2: Get top-1 similar node - similar_nodes = self.graph_store.search_by_embedding(query_embedding, top_k=top_k) - if not similar_nodes: - logger.info("No similar nodes found for query embedding.") - return {"core_id": None, "nodes": [], "edges": []} - - # Step 3: Fetch neighborhood - all_nodes = {} - all_edges = set() - cores = [] - - for node in similar_nodes: - core_id = node["id"] - score = node["score"] - - subgraph = self.graph_store.get_subgraph( - center_id=core_id, depth=depth, center_status=center_status - ) - - if not subgraph["core_node"]: - logger.info(f"Skipping node {core_id} (inactive or not found).") - continue - - core_node = subgraph["core_node"] - neighbors = subgraph["neighbors"] - edges = subgraph["edges"] - - # Collect nodes - all_nodes[core_node["id"]] = core_node - for n in neighbors: - all_nodes[n["id"]] = n - - # Collect edges - for e in edges: - all_edges.add((e["source"], e["target"], e["type"])) - - cores.append( - {"id": core_id, "score": score, "core_node": core_node, "neighbors": neighbors} - ) - - top_core = cores[0] - return { - "core_id": top_core["id"], - "nodes": list(all_nodes.values()), - "edges": [{"source": f, "target": t, "type": ty} for (f, t, ty) in all_edges], - } - - def extract(self, messages: MessageList) -> list[TextualMemoryItem]: - raise NotImplementedError - - def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None: - raise NotImplementedError - - def get(self, memory_id: str) -> TextualMemoryItem: - """Get a memory by its ID.""" - result = self.graph_store.get_node(memory_id) - if result is None: - raise ValueError(f"Memory with ID {memory_id} not found") - metadata_dict = result.get("metadata", {}) - return TextualMemoryItem( - id=result["id"], - memory=result["memory"], - metadata=TreeNodeTextualMemoryMetadata(**metadata_dict), - ) - - def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: - raise NotImplementedError - - def delete_all(self) -> None: - """Delete all memories and their relationships from the graph store.""" - try: - self.graph_store.clear() - logger.info("All memories and edges have been deleted from the graph.") - except Exception as e: - logger.error(f"An error occurred while deleting all memories: {e}") - raise diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index dea3cc1ab..e2e0be69c 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -103,11 +103,15 @@ def add( """ return self.memory_manager.add(memories, user_name=user_name, mode=self.mode) - def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None: - self.memory_manager.replace_working_memory(memories) - - def get_working_memory(self) -> list[TextualMemoryItem]: - working_memories = self.graph_store.get_all_memory_items(scope="WorkingMemory") + def replace_working_memory( + self, memories: list[TextualMemoryItem], user_name: str | None = None + ) -> None: + self.memory_manager.replace_working_memory(memories, user_name=user_name) + + def get_working_memory(self, user_name: str | None = None) -> list[TextualMemoryItem]: + working_memories = self.graph_store.get_all_memory_items( + scope="WorkingMemory", user_name=user_name + ) items = [TextualMemoryItem.from_dict(record) for record in (working_memories)] # Sort by updated_at in descending order sorted_items = sorted( @@ -115,12 +119,12 @@ def get_working_memory(self) -> list[TextualMemoryItem]: ) return sorted_items - def get_current_memory_size(self) -> dict[str, int]: + def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int]: """ Get the current size of each memory type. This delegates to the MemoryManager. """ - return self.memory_manager.get_current_memory_size() + return self.memory_manager.get_current_memory_size(user_name=user_name) def get_searcher( self, @@ -160,6 +164,7 @@ def search( manual_close_internet: bool = False, moscube: bool = False, search_filter: dict | None = None, + user_name: str | None = None, ) -> list[TextualMemoryItem]: """Search for memories based on a query. User query -> TaskGoalParser -> MemoryPathResolver -> @@ -208,7 +213,9 @@ def search( moscube=moscube, search_strategy=self.search_strategy, ) - return searcher.search(query, top_k, info, mode, memory_type, search_filter) + return searcher.search( + query, top_k, info, mode, memory_type, search_filter, user_name=user_name + ) def get_relevant_subgraph( self, query: str, top_k: int = 5, depth: int = 2, center_status: str = "activated" @@ -306,7 +313,9 @@ def get(self, memory_id: str) -> TextualMemoryItem: metadata=TreeNodeTextualMemoryMetadata(**metadata_dict), ) - def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]: + def get_by_ids( + self, memory_ids: list[str], user_name: str | None = None + ) -> list[TextualMemoryItem]: raise NotImplementedError def get_all(self, user_name: str | None = None) -> dict: diff --git a/src/memos/memos_tools/dinding_report_bot.py b/src/memos/memos_tools/dinding_report_bot.py index 9791cf65a..d8b762855 100644 --- a/src/memos/memos_tools/dinding_report_bot.py +++ b/src/memos/memos_tools/dinding_report_bot.py @@ -7,6 +7,7 @@ import json import os import time +import traceback import urllib.parse from datetime import datetime @@ -14,6 +15,11 @@ from dotenv import load_dotenv +from memos.log import get_logger + + +logger = get_logger(__name__) + load_dotenv() @@ -57,6 +63,20 @@ ROBOT_CODE = os.getenv("DINGDING_ROBOT_CODE") DING_APP_KEY = os.getenv("DINGDING_APP_KEY") DING_APP_SECRET = os.getenv("DINGDING_APP_SECRET") +ENV_NAME = os.getenv("ENV_NAME", "PLAYGROUND_OFFLINE") + +theme_map = { + "ONLINE": { + "color": "#2196F3", + "grad": ("#E3F2FD", "#BBDEFB"), + "emoji": "🩵", + }, + "OFFLINE": { + "color": "#FFC107", + "grad": ("#FFF8E1", "#FFECB3"), + "emoji": "🤍", + }, +} # Get access_token @@ -311,7 +331,7 @@ def error_bot( ) # ---------- Markdown ---------- - colored_title = f"{title}" + colored_title = f"{ENV_NAME}" at_suffix = "" if user_ids: at_suffix = "\n\n" + " ".join([f"@{m}" for m in user_ids]) @@ -367,41 +387,52 @@ def online_bot( other_data2: dict, emoji: dict, ): - heading_color = "#00956D" # Green for subtitle - - # 0) Banner - banner_bytes = make_header(header_name, sub_title_name) - banner_url = upload_bytes_to_oss(banner_bytes, filename="online_report.png") - - # 1) Colored main title - colored_title = f"{header_name}" - - # 3) Markdown - md = "\n\n".join( - filter( - None, - [ - f"![banner]({banner_url})", - f"### 🙄 {colored_title}\n\n", - _kv_lines( - other_data1, - next(iter(emoji.keys())), - next(iter(emoji.values())), - heading_color=heading_color, - ), - _kv_lines( - other_data2, - list(emoji.keys())[1], - list(emoji.values())[1], - heading_color=heading_color, - ), - f"Time: " - f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n", - ], + try: + logger.info("in online bot") + theme = "OFFLINE" if "OFFLINE" in ENV_NAME or "TEST" in ENV_NAME else "ONLINE" + style = theme_map.get(theme, theme_map["OFFLINE"]) + heading_color = style["color"] # Use theme color for subtitle + + # 0) Banner + banner_bytes = make_header( + header_name, + sub_title_name, + colors=style["grad"], + fg=style["color"], + ) + banner_url = upload_bytes_to_oss(banner_bytes, filename=f"{ENV_NAME}_online_report.png") + + # 1) Colored main title + colored_title = f"{ENV_NAME}" + + # 3) Markdown + md = "\n\n".join( + filter( + None, + [ + f"![banner]({banner_url})", + f"### {style['emoji']} {colored_title}\n\n", + _kv_lines( + other_data1, + next(iter(emoji.keys())), + next(iter(emoji.values())), + heading_color=heading_color, + ), + _kv_lines( + other_data2, + list(emoji.keys())[1], + list(emoji.values())[1], + heading_color=heading_color, + ), + f"Time: " + f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n", + ], + ) ) - ) - _send_md(colored_title, md, type="user") + _send_md(colored_title, md, type="user") + except Exception: + logger.error(traceback.format_exc()) if __name__ == "__main__": diff --git a/src/memos/utils.py b/src/memos/utils.py index 08934ed34..9ae27bb81 100644 --- a/src/memos/utils.py +++ b/src/memos/utils.py @@ -6,7 +6,7 @@ logger = get_logger(__name__) -def timed(func=None, *, log=False, log_prefix=""): +def timed(func=None, *, log=True, log_prefix=""): """Decorator to measure and optionally log time of retrieval steps. Can be used as @timed or @timed(log=True) From 31c4c9e67dfc476a2c7cb56cdf4ca1c800509308 Mon Sep 17 00:00:00 2001 From: CaralHsi Date: Mon, 10 Nov 2025 11:57:40 +0800 Subject: [PATCH 02/18] feat: max worker (#475) feat: set text-add max worker 200 --- src/memos/memories/textual/tree_text_memory/organize/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py index 0c41717ea..a71fee02f 100644 --- a/src/memos/memories/textual/tree_text_memory/organize/manager.py +++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py @@ -92,7 +92,7 @@ def add( """ added_ids: list[str] = [] - with ContextThreadPoolExecutor(max_workers=20) as executor: + with ContextThreadPoolExecutor(max_workers=200) as executor: futures = {executor.submit(self._process_memory, m, user_name): m for m in memories} for future in as_completed(futures, timeout=60): try: From 9c6b1ccef8226922bbc33821213812d7a17fb112 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Mon, 10 Nov 2025 15:58:30 +0800 Subject: [PATCH 03/18] Feat/dedup mem (#473) * add dedup strategy between pref and textual * make precommit * add try catch logic in server router, add dedup logic in explicit pref * fixbug in make pre_commit --------- Co-authored-by: yuan.wang --- evaluation/.env-example | 1 - src/memos/api/routers/server_router.py | 128 ++++++++++++------ src/memos/mem_cube/navie.py | 62 ++------- src/memos/memories/textual/item.py | 1 + .../textual/prefer_text_memory/adder.py | 95 ++++++++++++- .../textual/prefer_text_memory/factory.py | 11 +- .../textual/prefer_text_memory/retrievers.py | 8 ++ src/memos/templates/prefer_complete_prompt.py | 68 ++++++++++ src/memos/vec_dbs/milvus.py | 45 +++--- 9 files changed, 294 insertions(+), 125 deletions(-) diff --git a/evaluation/.env-example b/evaluation/.env-example index 5381532c2..bab6f679e 100644 --- a/evaluation/.env-example +++ b/evaluation/.env-example @@ -21,4 +21,3 @@ MEMU_API_KEY="mu_xxx" SUPERMEMORY_API_KEY="sm_xxx" MEMOBASE_API_KEY="xxx" MEMOBASE_PROJECT_URL="http://***.***.***.***:8019" - diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 8df383bfb..b426c2965 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -55,6 +55,8 @@ ExtractorFactory, RetrieverFactory, ) +from memos.memories.textual.simple_preference import SimplePreferenceTextMemory +from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, @@ -195,18 +197,43 @@ def init_server(): internet_retriever = InternetRetrieverFactory.from_config( internet_retriever_config, embedder=embedder ) + + # Initialize memory manager + memory_manager = MemoryManager( + graph_db, + embedder, + llm, + memory_size=_get_default_memory_size(default_cube_config), + is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False), + ) + + # Initialize text memory + text_mem = SimpleTreeTextMemory( + llm=llm, + embedder=embedder, + mem_reader=mem_reader, + graph_db=graph_db, + reranker=reranker, + memory_manager=memory_manager, + config=default_cube_config.text_mem.config, + internet_retriever=internet_retriever, + ) + pref_extractor = ExtractorFactory.from_config( config_factory=pref_extractor_config, llm_provider=llm, embedder=embedder, vector_db=vector_db, ) + pref_adder = AdderFactory.from_config( config_factory=pref_adder_config, llm_provider=llm, embedder=embedder, vector_db=vector_db, + text_mem=text_mem, ) + pref_retriever = RetrieverFactory.from_config( config_factory=pref_retriever_config, llm_provider=llm, @@ -215,33 +242,29 @@ def init_server(): vector_db=vector_db, ) - # Initialize memory manager - memory_manager = MemoryManager( - graph_db, - embedder, - llm, - memory_size=_get_default_memory_size(default_cube_config), - is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False), + # Initialize preference memory + pref_mem = SimplePreferenceTextMemory( + extractor_llm=llm, + vector_db=vector_db, + embedder=embedder, + reranker=reranker, + extractor=pref_extractor, + adder=pref_adder, + retriever=pref_retriever, ) + mos_server = MOSServer( mem_reader=mem_reader, llm=llm, online_bot=False, ) + # Create MemCube with pre-initialized memory instances naive_mem_cube = NaiveMemCube( - llm=llm, - embedder=embedder, - mem_reader=mem_reader, - graph_db=graph_db, - reranker=reranker, - internet_retriever=internet_retriever, - memory_manager=memory_manager, - default_cube_config=default_cube_config, - vector_db=vector_db, - pref_extractor=pref_extractor, - pref_adder=pref_adder, - pref_retriever=pref_retriever, + text_mem=text_mem, + pref_mem=pref_mem, + act_mem=None, + para_mem=None, ) # Initialize Scheduler @@ -279,6 +302,8 @@ def init_server(): pref_extractor, pref_adder, pref_retriever, + text_mem, + pref_mem, ) @@ -300,6 +325,8 @@ def init_server(): pref_extractor, pref_adder, pref_retriever, + text_mem, + pref_mem, ) = init_server() @@ -361,36 +388,46 @@ def search_memories(search_req: APISearchRequest): search_mode = search_req.mode def _search_text(): - if search_mode == SearchMode.FAST: - formatted_memories = fast_search_memories( - search_req=search_req, user_context=user_context - ) - elif search_mode == SearchMode.FINE: - formatted_memories = fine_search_memories( - search_req=search_req, user_context=user_context - ) - elif search_mode == SearchMode.MIXTURE: - formatted_memories = mix_search_memories( - search_req=search_req, user_context=user_context - ) - else: - logger.error(f"Unsupported search mode: {search_mode}") - raise HTTPException(status_code=400, detail=f"Unsupported search mode: {search_mode}") - return formatted_memories + try: + if search_mode == SearchMode.FAST: + formatted_memories = fast_search_memories( + search_req=search_req, user_context=user_context + ) + elif search_mode == SearchMode.FINE: + formatted_memories = fine_search_memories( + search_req=search_req, user_context=user_context + ) + elif search_mode == SearchMode.MIXTURE: + formatted_memories = mix_search_memories( + search_req=search_req, user_context=user_context + ) + else: + logger.error(f"Unsupported search mode: {search_mode}") + raise HTTPException( + status_code=400, detail=f"Unsupported search mode: {search_mode}" + ) + return formatted_memories + except Exception as e: + logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc()) + return [] def _search_pref(): if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": return [] - results = naive_mem_cube.pref_mem.search( - query=search_req.query, - top_k=search_req.pref_top_k, - info={ - "user_id": search_req.user_id, - "session_id": search_req.session_id, - "chat_history": search_req.chat_history, - }, - ) - return [_format_memory_item(data) for data in results] + try: + results = naive_mem_cube.pref_mem.search( + query=search_req.query, + top_k=search_req.pref_top_k, + info={ + "user_id": search_req.user_id, + "session_id": search_req.session_id, + "chat_history": search_req.chat_history, + }, + ) + return [_format_memory_item(data) for data in results] + except Exception as e: + logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) + return [] with ContextThreadPoolExecutor(max_workers=2) as executor: text_future = executor.submit(_search_text) @@ -601,6 +638,7 @@ def _process_pref_mem() -> list[dict[str, str]]: info={ "user_id": add_req.user_id, "session_id": target_session_id, + "mem_cube_id": add_req.mem_cube_id, }, ) pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local) diff --git a/src/memos/mem_cube/navie.py b/src/memos/mem_cube/navie.py index ba9f136b7..3afa78bab 100644 --- a/src/memos/mem_cube/navie.py +++ b/src/memos/mem_cube/navie.py @@ -2,26 +2,13 @@ from typing import Literal -from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.utils import get_json_file_model_schema -from memos.embedders.base import BaseEmbedder from memos.exceptions import ConfigurationError, MemCubeError -from memos.graph_dbs.base import BaseGraphDB -from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.base import BaseMemCube -from memos.mem_reader.base import BaseMemReader from memos.memories.activation.base import BaseActMemory from memos.memories.parametric.base import BaseParaMemory from memos.memories.textual.base import BaseTextMemory -from memos.memories.textual.prefer_text_memory.adder import BaseAdder -from memos.memories.textual.prefer_text_memory.extractor import BaseExtractor -from memos.memories.textual.prefer_text_memory.retrievers import BaseRetriever -from memos.memories.textual.simple_preference import SimplePreferenceTextMemory -from memos.memories.textual.simple_tree import SimpleTreeTextMemory -from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager -from memos.reranker.base import BaseReranker -from memos.vec_dbs.base import BaseVecDB logger = get_logger(__name__) @@ -32,51 +19,28 @@ class NaiveMemCube(BaseMemCube): def __init__( self, - llm: BaseLLM, - embedder: BaseEmbedder, - mem_reader: BaseMemReader, - graph_db: BaseGraphDB, - reranker: BaseReranker, - memory_manager: MemoryManager, - default_cube_config: GeneralMemCubeConfig, - vector_db: BaseVecDB, - internet_retriever: None = None, - pref_extractor: BaseExtractor | None = None, - pref_adder: BaseAdder | None = None, - pref_retriever: BaseRetriever | None = None, + text_mem: BaseTextMemory | None = None, + pref_mem: BaseTextMemory | None = None, + act_mem: BaseActMemory | None = None, + para_mem: BaseParaMemory | None = None, ): - """Initialize the MemCube with a configuration.""" - self._text_mem: BaseTextMemory | None = SimpleTreeTextMemory( - llm, - embedder, - mem_reader, - graph_db, - reranker, - memory_manager, - default_cube_config.text_mem.config, - internet_retriever, - ) - self._act_mem: BaseActMemory | None = None - self._para_mem: BaseParaMemory | None = None - self._pref_mem: BaseTextMemory | None = SimplePreferenceTextMemory( - extractor_llm=llm, - vector_db=vector_db, - embedder=embedder, - reranker=reranker, - extractor=pref_extractor, - adder=pref_adder, - retriever=pref_retriever, - ) + """Initialize the MemCube with memory instances.""" + self._text_mem: BaseTextMemory = text_mem + self._act_mem: BaseActMemory | None = act_mem + self._para_mem: BaseParaMemory | None = para_mem + self._pref_mem: BaseTextMemory | None = pref_mem def load( - self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None + self, + dir: str, + memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None, ) -> None: """Load memories. Args: dir (str): The directory containing the memory files. memory_types (list[str], optional): List of memory types to load. If None, loads all available memory types. - Options: ["text_mem", "act_mem", "para_mem"] + Options: ["text_mem", "act_mem", "para_mem", "pref_mem"] """ loaded_schema = get_json_file_model_schema(os.path.join(dir, self.config.config_filename)) if loaded_schema != self.config.model_schema: diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py index 2c23ae193..e7595443d 100644 --- a/src/memos/memories/textual/item.py +++ b/src/memos/memories/textual/item.py @@ -198,6 +198,7 @@ class PreferenceTextualMemoryMetadata(TextualMemoryMetadata): embedding: list[float] | None = Field(default=None, description="Vector of the dialog.") preference: str | None = Field(default=None, description="Preference.") created_at: str | None = Field(default=None, description="Timestamp of the dialog.") + mem_cube_id: str | None = Field(default=None, description="ID of the MemCube.") class TextualMemoryItem(BaseModel): diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py index a78601e86..5e58d23a5 100644 --- a/src/memos/memories/textual/prefer_text_memory/adder.py +++ b/src/memos/memories/textual/prefer_text_memory/adder.py @@ -10,6 +10,7 @@ from memos.log import get_logger from memos.memories.textual.item import TextualMemoryItem from memos.templates.prefer_complete_prompt import ( + NAIVE_JUDGE_DUP_WITH_TEXT_MEM_PROMPT, NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT, NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_FINE, NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE, @@ -24,7 +25,7 @@ class BaseAdder(ABC): """Abstract base class for adders.""" @abstractmethod - def __init__(self, llm_provider=None, embedder=None, vector_db=None): + def __init__(self, llm_provider=None, embedder=None, vector_db=None, text_mem=None): """Initialize the adder.""" @abstractmethod @@ -41,12 +42,13 @@ def add(self, memories: list[TextualMemoryItem | dict[str, Any]], *args, **kwarg class NaiveAdder(BaseAdder): """Naive adder.""" - def __init__(self, llm_provider=None, embedder=None, vector_db=None): + def __init__(self, llm_provider=None, embedder=None, vector_db=None, text_mem=None): """Initialize the naive adder.""" - super().__init__(llm_provider, embedder, vector_db) + super().__init__(llm_provider, embedder, vector_db, text_mem) self.llm_provider = llm_provider self.embedder = embedder self.vector_db = vector_db + self.text_mem = text_mem def _judge_update_or_add_fast(self, old_msg: str, new_msg: str) -> bool: """Judge if the new message expresses the same core content as the old message.""" @@ -81,6 +83,44 @@ def _judge_update_or_add_fine(self, new_mem: str, retrieved_mems: str) -> dict[s logger.error(f"Error in judge_update_or_add_fine: {e}") return None + def _judge_dup_with_text_mem(self, new_pref: MilvusVecDBItem) -> bool: + """Judge if the new message is the same as the text memory for a single preference.""" + if new_pref.payload["preference_type"] != "explicit_preference": + return False + text_recalls = self.text_mem.search( + query=new_pref.memory, + top_k=5, + info={ + "user_id": new_pref.payload["user_id"], + "session_id": new_pref.payload["session_id"], + }, + mode="fast", + search_filter={"session_id": new_pref.payload["session_id"]}, + user_name=new_pref.payload["mem_cube_id"], + ) + + text_mem_recalls = [ + {"id": text_recall.id, "memory": text_recall.memory} for text_recall in text_recalls + ] + + if not text_mem_recalls: + return False + + new_preference = {"id": new_pref.id, "memory": new_pref.payload["preference"]} + + prompt = NAIVE_JUDGE_DUP_WITH_TEXT_MEM_PROMPT.replace( + "{new_preference}", json.dumps(new_preference, ensure_ascii=False) + ).replace("{retrieved_memories}", json.dumps(text_mem_recalls, ensure_ascii=False)) + try: + response = self.llm_provider.generate([{"role": "user", "content": prompt}]) + response = response.strip().replace("```json", "").replace("```", "").strip() + result = json.loads(response) + exists = result.get("exists", False) + return exists + except Exception as e: + logger.error(f"Error in judge_dup_with_text_mem: {e}") + return False + def _judge_update_or_add_trace_op( self, new_mems: str, retrieved_mems: str ) -> dict[str, Any] | None: @@ -98,6 +138,32 @@ def _judge_update_or_add_trace_op( logger.error(f"Error in judge_update_or_add_trace_op: {e}") return None + def _dedup_explicit_pref_by_textual( + self, new_prefs: list[MilvusVecDBItem] + ) -> list[MilvusVecDBItem]: + """Deduplicate explicit preferences by textual memory.""" + if os.getenv("DEDUP_PREF_EXP_BY_TEXTUAL", "false").lower() != "true" or not self.text_mem: + return new_prefs + dedup_prefs = [] + with ContextThreadPoolExecutor(max_workers=max(1, min(len(new_prefs), 5))) as executor: + future_to_idx = { + executor.submit(self._judge_dup_with_text_mem, new_pref): idx + for idx, new_pref in enumerate(new_prefs) + } + is_dup_flags = [False] * len(new_prefs) + for future in as_completed(future_to_idx): + idx = future_to_idx[future] + try: + is_dup_flags[idx] = future.result() + except Exception as e: + logger.error( + f"Error in _judge_dup_with_text_mem for pref {new_prefs[idx].id}: {e}" + ) + is_dup_flags[idx] = False + + dedup_prefs = [pref for idx, pref in enumerate(new_prefs) if not is_dup_flags[idx]] + return dedup_prefs + def _update_memory_op_trace( self, new_memories: list[TextualMemoryItem], @@ -139,10 +205,17 @@ def _update_memory_op_trace( ] rsp = self._judge_update_or_add_trace_op( - new_mems=json.dumps(new_mem_inputs), - retrieved_mems=json.dumps(retrieved_mem_inputs) if retrieved_mem_inputs else "", + new_mems=json.dumps(new_mem_inputs, ensure_ascii=False), + retrieved_mems=json.dumps(retrieved_mem_inputs, ensure_ascii=False) + if retrieved_mem_inputs + else "", ) if not rsp: + dedup_rsp = self._dedup_explicit_pref_by_textual(new_vec_db_items) + if not dedup_rsp: + return [] + else: + new_vec_db_items = dedup_rsp with ContextThreadPoolExecutor(max_workers=min(len(new_vec_db_items), 5)) as executor: futures = { executor.submit(self.vector_db.add, collection_name, [db_item]): db_item @@ -222,8 +295,10 @@ def _update_memory_fine( if mem.payload.get("preference", None) ] rsp = self._judge_update_or_add_fine( - new_mem=json.dumps(new_mem_input), - retrieved_mems=json.dumps(retrieved_mem_inputs) if retrieved_mem_inputs else "", + new_mem=json.dumps(new_mem_input, ensure_ascii=False), + retrieved_mems=json.dumps(retrieved_mem_inputs, ensure_ascii=False) + if retrieved_mem_inputs + else "", ) need_update = rsp.get("need_update", False) if rsp else False need_update = ( @@ -245,6 +320,9 @@ def _update_memory_fine( self.vector_db.update(collection_name, rsp["id"], update_vec_db_item) return rsp["id"] else: + dedup_rsp = self._dedup_explicit_pref_by_textual([vec_db_item]) + if not dedup_rsp: + return "" self.vector_db.add(collection_name, [vec_db_item]) return vec_db_item.id @@ -272,6 +350,9 @@ def _update_memory_fast( old_msg_str = recall.memory new_msg_str = new_memory.memory is_same = self._judge_update_or_add_fast(old_msg=old_msg_str, new_msg=new_msg_str) + dedup_rsp = self._dedup_explicit_pref_by_textual([vec_db_item]) + if not dedup_rsp: + return "" if is_same: vec_db_item.id = recall.id self.vector_db.update(collection_name, recall.id, vec_db_item) diff --git a/src/memos/memories/textual/prefer_text_memory/factory.py b/src/memos/memories/textual/prefer_text_memory/factory.py index 22182261a..3c96b7dac 100644 --- a/src/memos/memories/textual/prefer_text_memory/factory.py +++ b/src/memos/memories/textual/prefer_text_memory/factory.py @@ -19,14 +19,21 @@ class AdderFactory(BaseAdder): @classmethod def from_config( - cls, config_factory: AdderConfigFactory, llm_provider=None, embedder=None, vector_db=None + cls, + config_factory: AdderConfigFactory, + llm_provider=None, + embedder=None, + vector_db=None, + text_mem=None, ) -> BaseAdder: """Create a Adder instance from a configuration factory.""" backend = config_factory.backend if backend not in cls.backend_to_class: raise ValueError(f"Invalid backend: {backend}") adder_class = cls.backend_to_class[backend] - return adder_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db) + return adder_class( + llm_provider=llm_provider, embedder=embedder, vector_db=vector_db, text_mem=text_mem + ) class ExtractorFactory(BaseExtractor): diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 0074c3f1c..9f0d1ab32 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -119,6 +119,9 @@ def retrieve( if pref.payload.get("preference", None) ] + # store explicit id and score, use it after reranker + explicit_id_scores = {item.id: item.score for item in explicit_prefs} + reranker_map = { "naive": self._naive_reranker, "original_text": self._original_text_reranker, @@ -131,4 +134,9 @@ def retrieve( query=query, prefs_mem=implicit_prefs_mem, prefs=implicit_prefs, top_k=top_k ) + # filter explicit mem by score bigger than threshold + explicit_prefs_mem = [ + item for item in explicit_prefs_mem if explicit_id_scores.get(item.id, 0) >= 0.2 + ] + return explicit_prefs_mem + implicit_prefs_mem diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py index 9e0274cba..3a468b943 100644 --- a/src/memos/templates/prefer_complete_prompt.py +++ b/src/memos/templates/prefer_complete_prompt.py @@ -132,6 +132,74 @@ """ +NAIVE_JUDGE_DUP_WITH_TEXT_MEM_PROMPT = """ +You are a content comparison expert. Your task is to determine whether each new preference information already exists in the retrieved text memories. + +**Task:** For each new preference, check if its content/topic/intent is already present in any of the retrieved text memories. + +**Input Structure:** +- New preferences: Array of objects, each with "id" and "memory" fields +- Retrieved memories: Array of objects, each with "id" and "memory" fields + +**Judgment Criteria:** +- If the core content, topic, or intent of a new preference is **already covered** in any retrieved memory, mark as "exists" (true). +- Consider both semantic similarity and topic overlap - even if wording differs, if the meaning is the same, it counts as existing. +- If the new preference introduces **new information, different topic, or unique content** not found in retrieved memories, mark as "exists" (false). +- Focus on the substantive content rather than minor phrasing differences. + +**Output Format (JSON):** +```json +{ + "new_preference_id": "ID of the new preference being evaluated", + "exists": true/false, + "reasoning": "Brief explanation of your judgment, citing which retrieved memory contains similar content (if exists=true) or why it's new content (if exists=false)", + "matched_memory_id": "If exists=true, indicate which retrieved memory id matches; otherwise null" +} +``` +**New Preferences (array):** +{new_preference} + +**Retrieved Text Memories (array):** +{retrieved_memories} + +Output only the JSON response, no additional text. +""" + + +NAIVE_JUDGE_DUP_WITH_TEXT_MEM_PROMPT_ZH = """ +你是一个内容比较专家。你的任务是判断每个新的偏好信息是否已经存在于召回的文本记忆中。 + +**任务:** 对每个新偏好,检查其内容/主题/意图是否已经在任何召回的文本记忆中存在。 + +**输入结构:** +- 新偏好:对象数组,每个对象包含"id"和"memory"字段 +- 召回记忆:对象数组,每个对象包含"id"和"memory"字段 + +**判断标准:** +- 如果新偏好的核心内容、主题或意图**已经被覆盖**在任何召回的记忆中,标记为"exists"(true)。 +- 考虑语义相似性和主题重叠 - 即使措辞不同,如果含义相同,也算作已存在。 +- 如果新偏好引入了**新信息、不同主题或独特内容**,且在召回记忆中未找到,标记为"exists"(false)。 +- 关注实质性内容,而非细微的表达差异。 + +**输出格式(JSON):** +```json +{ + "new_preference_id": "正在评估的新偏好ID", + "exists": true/false, + "reasoning": "简要说明你的判断理由,引用包含相似内容的召回记忆(如果exists=true)或说明为什么是新内容(如果exists=false)", + "matched_memory_id": "如果exists=true,指出匹配的召回记忆id;否则为null" +} +``` +**新偏好(数组):** +{new_preference} + +**召回的文本记忆(数组):** +{retrieved_memories} + +只输出JSON响应,不要输出其他任何文本。 +""" + + NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT = """ You are a content comparison expert. Now you are given old and new information, each containing a question, answer topic name and topic description. Please judge whether these two information express the **same question or core content**, regardless of expression differences, details or example differences. The judgment criteria are as follows: diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py index e50c8ce18..eafee2633 100644 --- a/src/memos/vec_dbs/milvus.py +++ b/src/memos/vec_dbs/milvus.py @@ -236,29 +236,32 @@ def search( "sparse": self._sparse_search, "hybrid": self._hybrid_search, } + try: + results = search_func_map[search_type]( + collection_name=collection_name, + query_vector=query_vector, + query=query, + top_k=top_k, + filter=expr, + ) - results = search_func_map[search_type]( - collection_name=collection_name, - query_vector=query_vector, - query=query, - top_k=top_k, - filter=expr, - ) - - items = [] - for hit in results[0]: - entity = hit.get("entity", {}) - - items.append( - MilvusVecDBItem( - id=str(entity.get("id")), - memory=entity.get("memory"), - original_text=entity.get("original_text"), - vector=entity.get("vector"), - payload=entity.get("payload", {}), - score=1 - float(hit["distance"]), + items = [] + for hit in results[0]: + entity = hit.get("entity", {}) + + items.append( + MilvusVecDBItem( + id=str(entity.get("id")), + memory=entity.get("memory"), + original_text=entity.get("original_text"), + vector=entity.get("vector"), + payload=entity.get("payload", {}), + score=1 - float(hit["distance"]), + ) ) - ) + except Exception as e: + logger.error("Error in _%s_search: %s", search_type, e) + return [] logger.info(f"Milvus search completed with {len(items)} results.") return items From 2019b4a8eef86f5320008b708df907c39553cf4a Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Mon, 10 Nov 2025 16:49:21 +0800 Subject: [PATCH 04/18] feat: add topk for working mem (#476) --- src/memos/memories/textual/tree_text_memory/retrieve/recall.py | 2 +- .../memories/textual/tree_text_memory/retrieve/searcher.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 8cf2f47f3..7bb2eba7e 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -66,7 +66,7 @@ def retrieve( working_memories = self.graph_store.get_all_memory_items( scope="WorkingMemory", include_embedding=False, user_name=user_name ) - return [TextualMemoryItem.from_dict(record) for record in working_memories] + return [TextualMemoryItem.from_dict(record) for record in working_memories[:top_k]] with ContextThreadPoolExecutor(max_workers=3) as executor: # Structured graph-based retrieval diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index f408755fd..f196c5569 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -108,7 +108,7 @@ def post_retrieve( def search( self, query: str, - top_k: int, + top_k: int = 10, info=None, mode="fast", memory_type="All", From 1f60c567608f2092fc5fea3d660c2e923e4cda75 Mon Sep 17 00:00:00 2001 From: Dubberman <48425266+whipser030@users.noreply.github.com> Date: Tue, 11 Nov 2025 11:29:55 +0800 Subject: [PATCH 05/18] fix: chat time issue 2023->2025 (#479) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * update reader and search strategy * set strategy reader and search config * fix install problem * fix * fix test * turn off graph recall * turn off graph recall * turn off graph recall * fix Searcher input bug * fix Searcher * fix Search * fix bug * adjust strategy reader * adjust strategy reader * adjust search config input * reformat code * re pr * format repair * fix time issue --------- Co-authored-by: 黑布林 <11641432+heiheiyouyou@user.noreply.gitee.com> Co-authored-by: CaralHsi Co-authored-by: chunyu li <78344051+fridayL@users.noreply.github.com> --- src/memos/mem_reader/simple_struct.py | 28 ++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py index 13515c038..3845f37d0 100644 --- a/src/memos/mem_reader/simple_struct.py +++ b/src/memos/mem_reader/simple_struct.py @@ -6,6 +6,7 @@ import traceback from abc import ABC +from datetime import datetime, timezone from typing import Any from tqdm import tqdm @@ -399,7 +400,7 @@ def get_memory( if not all(isinstance(info[field], str) for field in required_fields): raise ValueError("user_id and session_id must be strings") - + scene_data = self._complete_chat_time(scene_data, type) list_scene_data_info = self.get_scene_data_info(scene_data, type) memory_list = [] @@ -508,6 +509,31 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]: return results + def _complete_chat_time(self, scene_data: list[list[dict]], type: str): + if type != "chat": + return scene_data + complete_scene_data = [] + + for items in scene_data: + chat_time_value = None + + for item in items: + if "chat_time" in item: + chat_time_value = item["chat_time"] + break + + if chat_time_value is None: + session_date = datetime.now(timezone.utc) + date_format = "%I:%M %p on %d %B, %Y UTC" + chat_time_value = session_date.strftime(date_format) + + for i in range(len(items)): + if "chat_time" not in items[i]: + items[i]["chat_time"] = chat_time_value + + complete_scene_data.append(items) + return complete_scene_data + def _process_doc_data(self, scene_data_info, info, **kwargs): mode = kwargs.get("mode", "fine") if mode == "fast": From a296ba9cd1c8c7f9f8ae477e0164f479ee7ea1a6 Mon Sep 17 00:00:00 2001 From: Travis Tang Date: Tue, 11 Nov 2025 16:14:52 +0800 Subject: [PATCH 06/18] scheduler feat: implementation of redis queue and new api search functions of mixture and fine mode (#462) * debug an error function name * feat: Add DynamicCache compatibility for different transformers versions - Fix build_kv_cache method in hf.py to handle both old and new DynamicCache structures - Support new 'layers' attribute with key_cache/value_cache or keys/values - Maintain backward compatibility with direct key_cache/value_cache attributes - Add comprehensive error handling and logging for unsupported structures - Update move_dynamic_cache_htod function in kv.py for cross-version compatibility - Handle layers-based structure in newer transformers versions - Support alternative attribute names (keys/values vs key_cache/value_cache) - Preserve original functionality for older transformers versions - Add comprehensive tests for DynamicCache compatibility - Test activation memory update with mock DynamicCache layers - Verify layers attribute access across different transformers versions - Fix scheduler logger mock to include memory_manager attribute This resolves AttributeError issues when using different versions of the transformers library and ensures robust handling of DynamicCache objects. debug * feat: implement APIAnalyzerForScheduler for memory operations - Add APIAnalyzerForScheduler class with search/add operations - Support requests and http.client with connection reuse - Include comprehensive error handling and dynamic configuration - Add English test suite with real-world conversation scenarios * feat: Add search_ws API endpoint and enhance API analyzer functionality - Add search_ws endpoint in server_router.py for scheduler-enabled search - Fix missing imports: time module, SearchRequest class, and get_mos_product_instance function - Implement search_ws method in api_analyzer.py with HTTP client support - Add _search_ws_with_requests and _search_ws_with_http_client private methods - Include search_ws usage example in demonstration code - Enhance scheduler and dispatcher capabilities for improved memory management - Expand test coverage to ensure functionality stability This update primarily strengthens the memory scheduling system's search capabilities, providing users with more flexible API interface options. * fix: resolve test failures and warnings in test suite - Fix Pydantic serialization warning in test_memos_chen_tang_hello_world * Add warnings filter to suppress UserWarning from Pydantic serialization - Fix KeyError: 'past_key_values' in test_build_kv_cache_and_generation * Update mock configuration to properly return forward_output with past_key_values * Add DynamicCache version compatibility handling in test mocks * Support both old and new transformers versions with layers/key_cache attributes * Improve assertion logic to check all model calls for required parameters - Update base_scheduler.py to use centralized DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE constant * Add import for DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE from general_schemas * Replace hardcoded value 100 with configurable constant (1000) All tests now pass successfully with proper version compatibility handling. * feat: add a test_robustness execution to test thread pool execution * feat: optimize scheduler configuration and API search functionality - Add DEFAULT_TOP_K and DEFAULT_CONTEXT_WINDOW_SIZE global constants in general_schemas.py - Update base_scheduler.py to use global default values instead of hardcoded numbers - Fix SchedulerConfigFactory initialization issue by using keyword argument expansion - Resolve UnboundLocalError variable conflict in search_memories_ws function - Fix indentation and parameter issues in OptimizedScheduler search_for_api method - Improve code standardization and maintainability * feat: Add Redis auto-initialization with fallback strategies - Add auto_initialize_redis() with config/env/local fallback - Move Redis logic from dispatcher_monitor to redis_service - Update base_scheduler to use auto initialization - Add proper resource cleanup and error handling * feat: add database connection management to ORM module - Add MySQL engine loading from environment variables in BaseDBManager - Add Redis connection loading from environment variables in BaseDBManager - Enhance database configuration validation and error handling - Complete database adapter infrastructure for ORM module - Provide unified database connection management interface This update provides comprehensive database connection management capabilities for the mem_scheduler module, supporting dynamic MySQL and Redis configuration loading from environment variables, establishing reliable data persistence foundation for scheduling services and API services. * remove part of test * feat: add Redis-based ORM with multiprocess synchronization - Add RedisDBManager and RedisLockableORM classes - Implement atomic locking mechanism for concurrent access - Add merge functionality for different object types - Include comprehensive test suite and examples - Fix Redis key type conflicts in lock operations * fix: resolve scheduler module import and Redis integration issues * revise naive memcube creation in server router * remove long-time tests in test_scheduler * remove redis test which needs .env * refactor all codes about mixture search with scheduler * fix: resolve Redis API synchronization issues and implement search API with reranker - Fix running_entries to running_task_ids migration across codebase - Update sync_search_data method to properly handle TaskRunningStatus - Correct variable naming and logic in API synchronization flow - Implement search API endpoint with reranker functionality - Update test files to reflect new running_task_ids convention - Ensure proper Redis state management for concurrent tasks * remove a test for api module * revise to pass the test suite * address some bugs to make mix_search normally running * modify codes according to evaluation logs * feat: Optimize mixture search and enhance API client * feat: Add conversation_turn tracking for session-based memory search - Add conversation_turn field to APIMemoryHistoryEntryItem schema with default value 0 - Implement session counter in OptimizedScheduler to track turn count per session_id - Update sync_search_data method to accept and store conversation_turn parameter - Maintain session history with LRU eviction (max 5 sessions) - Rename conversation_id to session_id for consistency with request object - Enable direct access to session_id from search requests This feature allows tracking conversation turns within the same session, providing better context for memory retrieval and search history management. * adress time bug in monitor * revise simple tree * add mode to evaluation client; rewrite print to logger.info in db files * feat: 1. add redis queue for scheduler 2. finish the code related to mix search and fine search * debug the working memory code * addressed a range of bugs to make scheduler running correctly * remove test_dispatch_parallel test * print change to logger.info --------- Co-authored-by: CaralHsi --- examples/mem_scheduler/api_w_scheduler.py | 62 + .../memos_w_optimized_scheduler.py | 85 -- .../memos_w_optimized_scheduler_for_test.py | 87 -- examples/mem_scheduler/memos_w_scheduler.py | 73 +- .../memos_w_scheduler_for_test.py | 230 +-- examples/mem_scheduler/orm_examples.py | 374 ----- examples/mem_scheduler/redis_example.py | 8 +- .../mem_scheduler/try_schedule_modules.py | 1 + src/memos/api/config.py | 7 +- src/memos/api/product_models.py | 4 +- src/memos/api/routers/server_router.py | 46 +- src/memos/configs/mem_scheduler.py | 19 + src/memos/mem_os/core.py | 12 - src/memos/mem_os/main.py | 2 - src/memos/mem_os/product.py | 1 - .../mem_scheduler/analyzer/api_analyzer.py | 17 +- .../mem_scheduler/analyzer/eval_analyzer.py | 1322 +++++++++++++++++ .../analyzer/memory_processing.py | 246 +++ .../analyzer/mos_for_test_scheduler.py | 2 - .../analyzer/scheduler_for_eval.py | 4 +- src/memos/mem_scheduler/base_scheduler.py | 217 ++- .../general_modules/dispatcher.py | 60 +- .../mem_scheduler/general_modules/misc.py | 63 +- .../general_modules/redis_queue.py | 460 ++++++ src/memos/mem_scheduler/general_scheduler.py | 14 +- .../memory_manage_modules/memory_filter.py | 10 +- .../memory_manage_modules/retriever.py | 222 ++- .../monitors/dispatcher_monitor.py | 11 +- .../mem_scheduler/monitors/general_monitor.py | 8 +- .../mem_scheduler/optimized_scheduler.py | 137 +- .../mem_scheduler/schemas/general_schemas.py | 10 +- .../mem_scheduler/schemas/message_schemas.py | 23 +- src/memos/mem_scheduler/utils/metrics.py | 8 +- src/memos/mem_scheduler/utils/misc_utils.py | 124 +- .../webservice_modules/redis_service.py | 9 + .../retrieve/task_goal_parser.py | 37 +- src/memos/templates/mem_scheduler_prompts.py | 40 + tests/mem_scheduler/test_dispatcher.py | 43 - tests/mem_scheduler/test_scheduler.py | 249 ---- 39 files changed, 2947 insertions(+), 1400 deletions(-) create mode 100644 examples/mem_scheduler/api_w_scheduler.py delete mode 100644 examples/mem_scheduler/memos_w_optimized_scheduler.py delete mode 100644 examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py delete mode 100644 examples/mem_scheduler/orm_examples.py create mode 100644 src/memos/mem_scheduler/analyzer/eval_analyzer.py create mode 100644 src/memos/mem_scheduler/analyzer/memory_processing.py create mode 100644 src/memos/mem_scheduler/general_modules/redis_queue.py diff --git a/examples/mem_scheduler/api_w_scheduler.py b/examples/mem_scheduler/api_w_scheduler.py new file mode 100644 index 000000000..11f0ebb81 --- /dev/null +++ b/examples/mem_scheduler/api_w_scheduler.py @@ -0,0 +1,62 @@ +from memos.api.routers.server_router import mem_scheduler +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem + + +# Debug: Print scheduler configuration +print("=== Scheduler Configuration Debug ===") +print(f"Scheduler type: {type(mem_scheduler).__name__}") +print(f"Config: {mem_scheduler.config}") +print(f"use_redis_queue: {mem_scheduler.use_redis_queue}") +print(f"Queue type: {type(mem_scheduler.memos_message_queue).__name__}") +print(f"Queue maxsize: {getattr(mem_scheduler.memos_message_queue, 'maxsize', 'N/A')}") + +# Check if Redis queue is connected +if hasattr(mem_scheduler.memos_message_queue, "_is_connected"): + print(f"Redis connected: {mem_scheduler.memos_message_queue._is_connected}") +if hasattr(mem_scheduler.memos_message_queue, "_redis_conn"): + print(f"Redis connection: {mem_scheduler.memos_message_queue._redis_conn}") +print("=====================================\n") + +queue = mem_scheduler.memos_message_queue +queue.clear() + + +# 1. Define a handler function +def my_test_handler(messages: list[ScheduleMessageItem]): + print(f"My test handler received {len(messages)} messages:") + for msg in messages: + print(f" my_test_handler - {msg.item_id}: {msg.content}") + print( + f"{queue._redis_conn.xinfo_groups(queue.stream_name)} qsize: {queue.qsize()} messages:{messages}" + ) + + +# 2. Register the handler +TEST_HANDLER_LABEL = "test_handler" +mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler}) + +# 3. Create messages +messages_to_send = [ + ScheduleMessageItem( + item_id=f"test_item_{i}", + user_id="test_user", + mem_cube_id="test_mem_cube", + label=TEST_HANDLER_LABEL, + content=f"This is test message {i}", + ) + for i in range(5) +] + +# 5. Submit messages +for mes in messages_to_send: + print(f"Submitting message {mes.item_id} to the scheduler...") + mem_scheduler.submit_messages([mes]) + +# 6. Wait for messages to be processed (limited to 100 checks) +print("Waiting for messages to be consumed (max 100 checks)...") +mem_scheduler.mem_scheduler_wait() + + +# 7. Stop the scheduler +print("Stopping the scheduler...") +mem_scheduler.stop() diff --git a/examples/mem_scheduler/memos_w_optimized_scheduler.py b/examples/mem_scheduler/memos_w_optimized_scheduler.py deleted file mode 100644 index 664168f62..000000000 --- a/examples/mem_scheduler/memos_w_optimized_scheduler.py +++ /dev/null @@ -1,85 +0,0 @@ -import shutil -import sys - -from pathlib import Path - -from memos_w_scheduler import init_task, show_web_logs - -from memos.configs.mem_cube import GeneralMemCubeConfig -from memos.configs.mem_os import MOSConfig -from memos.configs.mem_scheduler import AuthConfig -from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube -from memos.mem_os.main import MOS - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -def run_with_scheduler_init(): - print("==== run_with_automatic_scheduler_init ====") - conversations, questions = init_task() - - # set configs - mos_config = MOSConfig.from_yaml_file( - f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml" - ) - - mem_cube_config = GeneralMemCubeConfig.from_yaml_file( - f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config_neo4j.yaml" - ) - - # default local graphdb uri - if AuthConfig.default_config_exists(): - auth_config = AuthConfig.from_local_config() - - mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key - mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url - - mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri - mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user - mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password - mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name - mem_cube_config.text_mem.config.graph_db.config.auto_create = ( - auth_config.graph_db.auto_create - ) - - # Initialization - mos = MOS(mos_config) - - user_id = "user_1" - mos.create_user(user_id) - - mem_cube_id = "mem_cube_5" - mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}" - - if Path(mem_cube_name_or_path).exists(): - shutil.rmtree(mem_cube_name_or_path) - print(f"{mem_cube_name_or_path} is not empty, and has been removed.") - - mem_cube = GeneralMemCube(mem_cube_config) - mem_cube.dump(mem_cube_name_or_path) - mos.register_mem_cube( - mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id - ) - - mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) - - for item in questions: - print("===== Chat Start =====") - query = item["question"] - print(f"Query:\n {query}\n") - response = mos.chat(query=query, user_id=user_id) - print(f"Answer:\n {response}\n") - - show_web_logs(mem_scheduler=mos.mem_scheduler) - - mos.mem_scheduler.stop() - - -if __name__ == "__main__": - run_with_scheduler_init() diff --git a/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py b/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py deleted file mode 100644 index ed4f721ad..000000000 --- a/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py +++ /dev/null @@ -1,87 +0,0 @@ -import json -import shutil -import sys - -from pathlib import Path - -from memos_w_scheduler_for_test import init_task - -from memos.configs.mem_cube import GeneralMemCubeConfig -from memos.configs.mem_os import MOSConfig -from memos.configs.mem_scheduler import AuthConfig -from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) - -# Enable execution from any working directory - -logger = get_logger(__name__) - -if __name__ == "__main__": - # set up data - conversations, questions = init_task() - - # set configs - mos_config = MOSConfig.from_yaml_file( - f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml" - ) - - mem_cube_config = GeneralMemCubeConfig.from_yaml_file( - f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config_neo4j.yaml" - ) - - # default local graphdb uri - if AuthConfig.default_config_exists(): - auth_config = AuthConfig.from_local_config() - - mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key - mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url - - mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri - mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user - mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password - mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name - mem_cube_config.text_mem.config.graph_db.config.auto_create = ( - auth_config.graph_db.auto_create - ) - - # Initialization - mos = MOSForTestScheduler(mos_config) - - user_id = "user_1" - mos.create_user(user_id) - - mem_cube_id = "mem_cube_5" - mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}" - - if Path(mem_cube_name_or_path).exists(): - shutil.rmtree(mem_cube_name_or_path) - print(f"{mem_cube_name_or_path} is not empty, and has been removed.") - - mem_cube = GeneralMemCube(mem_cube_config) - mem_cube.dump(mem_cube_name_or_path) - mos.register_mem_cube( - mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id - ) - - mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) - - # Add interfering conversations - file_path = Path(f"{BASE_DIR}/examples/data/mem_scheduler/scene_data.json") - scene_data = json.load(file_path.open("r", encoding="utf-8")) - mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id) - mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id) - - for item in questions: - print("===== Chat Start =====") - query = item["question"] - print(f"Query:\n {query}\n") - response = mos.chat(query=query, user_id=user_id) - print(f"Answer:\n {response}\n") - - mos.mem_scheduler.stop() diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py index dc196b85a..c523a8667 100644 --- a/examples/mem_scheduler/memos_w_scheduler.py +++ b/examples/mem_scheduler/memos_w_scheduler.py @@ -70,13 +70,48 @@ def init_task(): return conversations, questions +def show_web_logs(mem_scheduler: GeneralScheduler): + """Display all web log entries from the scheduler's log queue. + + Args: + mem_scheduler: The scheduler instance containing web logs to display + """ + if mem_scheduler._web_log_message_queue.empty(): + print("Web log queue is currently empty.") + return + + print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50) + + # Create a temporary queue to preserve the original queue contents + temp_queue = Queue() + log_count = 0 + + while not mem_scheduler._web_log_message_queue.empty(): + log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get() + temp_queue.put(log_item) + log_count += 1 + + # Print log entry details + print(f"\nLog Entry #{log_count}:") + print(f'- "{log_item.label}" log: {log_item}') + + print("-" * 50) + + # Restore items back to the original queue + while not temp_queue.empty(): + mem_scheduler._web_log_message_queue.put(temp_queue.get()) + + print(f"\nTotal {log_count} web log entries displayed.") + print("=" * 110 + "\n") + + def run_with_scheduler_init(): print("==== run_with_automatic_scheduler_init ====") conversations, questions = init_task() # set configs mos_config = MOSConfig.from_yaml_file( - f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml" + f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml" ) mem_cube_config = GeneralMemCubeConfig.from_yaml_file( @@ -118,6 +153,7 @@ def run_with_scheduler_init(): ) mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) + mos.mem_scheduler.current_mem_cube = mem_cube for item in questions: print("===== Chat Start =====") @@ -131,40 +167,5 @@ def run_with_scheduler_init(): mos.mem_scheduler.stop() -def show_web_logs(mem_scheduler: GeneralScheduler): - """Display all web log entries from the scheduler's log queue. - - Args: - mem_scheduler: The scheduler instance containing web logs to display - """ - if mem_scheduler._web_log_message_queue.empty(): - print("Web log queue is currently empty.") - return - - print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50) - - # Create a temporary queue to preserve the original queue contents - temp_queue = Queue() - log_count = 0 - - while not mem_scheduler._web_log_message_queue.empty(): - log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get() - temp_queue.put(log_item) - log_count += 1 - - # Print log entry details - print(f"\nLog Entry #{log_count}:") - print(f'- "{log_item.label}" log: {log_item}') - - print("-" * 50) - - # Restore items back to the original queue - while not temp_queue.empty(): - mem_scheduler._web_log_message_queue.put(temp_queue.get()) - - print(f"\nTotal {log_count} web log entries displayed.") - print("=" * 110 + "\n") - - if __name__ == "__main__": run_with_scheduler_init() diff --git a/examples/mem_scheduler/memos_w_scheduler_for_test.py b/examples/mem_scheduler/memos_w_scheduler_for_test.py index 6faac98af..2e135f127 100644 --- a/examples/mem_scheduler/memos_w_scheduler_for_test.py +++ b/examples/mem_scheduler/memos_w_scheduler_for_test.py @@ -1,10 +1,11 @@ import json import shutil import sys -import time from pathlib import Path +from memos_w_scheduler import init_task + from memos.configs.mem_cube import GeneralMemCubeConfig from memos.configs.mem_os import MOSConfig from memos.configs.mem_scheduler import AuthConfig @@ -15,155 +16,19 @@ FILE_PATH = Path(__file__).absolute() BASE_DIR = FILE_PATH.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -def display_memory_cube_stats(mos, user_id, mem_cube_id): - """Display detailed memory cube statistics.""" - print(f"\n📊 MEMORY CUBE STATISTICS for {mem_cube_id}:") - print("-" * 60) - - mem_cube = mos.mem_cubes.get(mem_cube_id) - if not mem_cube: - print(" ❌ Memory cube not found") - return - - # Text memory stats - if mem_cube.text_mem: - text_mem = mem_cube.text_mem - working_memories = text_mem.get_working_memory() - all_memories = text_mem.get_all() - - print(" 📝 Text Memory:") - print(f" • Working Memory Items: {len(working_memories)}") - print( - f" • Total Memory Items: {len(all_memories) if isinstance(all_memories, list) else 'N/A'}" - ) - - if working_memories: - print(" • Working Memory Content Preview:") - for i, mem in enumerate(working_memories[:2]): - content = mem.memory[:60] + "..." if len(mem.memory) > 60 else mem.memory - print(f" {i + 1}. {content}") - - # Activation memory stats - if mem_cube.act_mem: - act_mem = mem_cube.act_mem - act_memories = list(act_mem.get_all()) - print(" ⚡ Activation Memory:") - print(f" • KV Cache Items: {len(act_memories)}") - if act_memories: - print( - f" • Latest Cache Size: {len(act_memories[-1].memory) if hasattr(act_memories[-1], 'memory') else 'N/A'}" - ) - - print("-" * 60) - - -def display_scheduler_status(mos): - """Display current scheduler status and configuration.""" - print("\n⚙️ SCHEDULER STATUS:") - print("-" * 60) - - if not mos.mem_scheduler: - print(" ❌ Memory scheduler not initialized") - return - - scheduler = mos.mem_scheduler - print(f" 🔄 Scheduler Running: {scheduler._running}") - print(f" 📊 Internal Queue Size: {scheduler.memos_message_queue.qsize()}") - print(f" 🧵 Parallel Dispatch: {scheduler.enable_parallel_dispatch}") - print(f" 👥 Max Workers: {scheduler.thread_pool_max_workers}") - print(f" ⏱️ Consume Interval: {scheduler._consume_interval}s") - - if scheduler.monitor: - print(" 📈 Monitor Active: ✅") - print(f" 🗄️ Database Engine: {'✅' if scheduler.db_engine else '❌'}") - - if scheduler.dispatcher: - print(" 🚀 Dispatcher Active: ✅") - print( - f" 🔧 Dispatcher Status: {scheduler.dispatcher.status if hasattr(scheduler.dispatcher, 'status') else 'Unknown'}" - ) +sys.path.insert(0, str(BASE_DIR)) - print("-" * 60) - - -def init_task(): - conversations = [ - { - "role": "user", - "content": "I have two dogs - Max (golden retriever) and Bella (pug). We live in Seattle.", - }, - {"role": "assistant", "content": "Great! Any special care for them?"}, - { - "role": "user", - "content": "Max needs joint supplements. Actually, we're moving to Chicago next month.", - }, - { - "role": "user", - "content": "Correction: Bella is 6, not 5. And she's allergic to chicken.", - }, - { - "role": "user", - "content": "My partner's cat Whiskers visits weekends. Bella chases her sometimes.", - }, - ] - - questions = [ - # 1. Basic factual recall (simple) - { - "question": "What breed is Max?", - "category": "Pet", - "expected": "golden retriever", - "difficulty": "easy", - }, - # 2. Temporal context (medium) - { - "question": "Where will I live next month?", - "category": "Location", - "expected": "Chicago", - "difficulty": "medium", - }, - # 3. Information correction (hard) - { - "question": "How old is Bella really?", - "category": "Pet", - "expected": "6", - "difficulty": "hard", - "hint": "User corrected the age later", - }, - # 4. Relationship inference (harder) - { - "question": "Why might Whiskers be nervous around my pets?", - "category": "Behavior", - "expected": "Bella chases her sometimes", - "difficulty": "harder", - }, - # 5. Combined medical info (hardest) - { - "question": "Which pets have health considerations?", - "category": "Health", - "expected": "Max needs joint supplements, Bella is allergic to chicken", - "difficulty": "hardest", - "requires": ["combining multiple facts", "ignoring outdated info"], - }, - ] - return conversations, questions +# Enable execution from any working directory +logger = get_logger(__name__) if __name__ == "__main__": - print("🚀 Starting Enhanced Memory Scheduler Test...") - print("=" * 80) - # set up data conversations, questions = init_task() # set configs mos_config = MOSConfig.from_yaml_file( - f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml" + f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml" ) mem_cube_config = GeneralMemCubeConfig.from_yaml_file( @@ -186,7 +51,6 @@ def init_task(): ) # Initialization - print("🔧 Initializing MOS with Scheduler...") mos = MOSForTestScheduler(mos_config) user_id = "user_1" @@ -197,15 +61,15 @@ def init_task(): if Path(mem_cube_name_or_path).exists(): shutil.rmtree(mem_cube_name_or_path) - print(f"🗑️ {mem_cube_name_or_path} is not empty, and has been removed.") + print(f"{mem_cube_name_or_path} is not empty, and has been removed.") mem_cube = GeneralMemCube(mem_cube_config) mem_cube.dump(mem_cube_name_or_path) mos.register_mem_cube( mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id ) + mos.mem_scheduler.current_mem_cube = mem_cube - print("📚 Adding initial conversations...") mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) # Add interfering conversations @@ -214,77 +78,11 @@ def init_task(): mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id) mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id) - # Display initial status - print("\n📊 INITIAL SYSTEM STATUS:") - display_scheduler_status(mos) - display_memory_cube_stats(mos, user_id, mem_cube_id) - - # Process questions with enhanced monitoring - print(f"\n🎯 Starting Question Processing ({len(questions)} questions)...") - question_start_time = time.time() - - for i, item in enumerate(questions, 1): - print(f"\n{'=' * 20} Question {i}/{len(questions)} {'=' * 20}") - print(f"📝 Category: {item['category']} | Difficulty: {item['difficulty']}") - print(f"🎯 Expected: {item['expected']}") - if "hint" in item: - print(f"💡 Hint: {item['hint']}") - if "requires" in item: - print(f"🔍 Requires: {', '.join(item['requires'])}") - - print(f"\n🚀 Processing Query: {item['question']}") - query_start_time = time.time() - - response = mos.chat(query=item["question"], user_id=user_id) - - query_time = time.time() - query_start_time - print(f"⏱️ Query Processing Time: {query_time:.3f}s") - print(f"🤖 Response: {response}") - - # Display intermediate status every 2 questions - if i % 2 == 0: - print(f"\n📊 INTERMEDIATE STATUS (Question {i}):") - display_scheduler_status(mos) - display_memory_cube_stats(mos, user_id, mem_cube_id) - - total_processing_time = time.time() - question_start_time - print(f"\n⏱️ Total Question Processing Time: {total_processing_time:.3f}s") - - # Display final scheduler performance summary - print("\n" + "=" * 80) - print("📊 FINAL SCHEDULER PERFORMANCE SUMMARY") - print("=" * 80) - - summary = mos.get_scheduler_summary() - print(f"🔢 Total Queries Processed: {summary['total_queries']}") - print(f"⚡ Total Scheduler Calls: {summary['total_scheduler_calls']}") - print(f"⏱️ Average Scheduler Response Time: {summary['average_scheduler_response_time']:.3f}s") - print(f"🧠 Memory Optimizations Applied: {summary['memory_optimization_count']}") - print(f"🔄 Working Memory Updates: {summary['working_memory_updates']}") - print(f"⚡ Activation Memory Updates: {summary['activation_memory_updates']}") - print(f"📈 Average Query Processing Time: {summary['average_query_processing_time']:.3f}s") - - # Performance insights - print("\n💡 PERFORMANCE INSIGHTS:") - if summary["total_scheduler_calls"] > 0: - optimization_rate = ( - summary["memory_optimization_count"] / summary["total_scheduler_calls"] - ) * 100 - print(f" • Memory Optimization Rate: {optimization_rate:.1f}%") - - if summary["average_scheduler_response_time"] < 0.1: - print(" • Scheduler Performance: 🟢 Excellent (< 100ms)") - elif summary["average_scheduler_response_time"] < 0.5: - print(" • Scheduler Performance: 🟡 Good (100-500ms)") - else: - print(" • Scheduler Performance: 🔴 Needs Improvement (> 500ms)") - - # Final system status - print("\n🔍 FINAL SYSTEM STATUS:") - display_scheduler_status(mos) - display_memory_cube_stats(mos, user_id, mem_cube_id) - - print("=" * 80) - print("🏁 Test completed successfully!") + for item in questions: + print("===== Chat Start =====") + query = item["question"] + print(f"Query:\n {query}\n") + response = mos.chat(query=query, user_id=user_id) + print(f"Answer:\n {response}\n") mos.mem_scheduler.stop() diff --git a/examples/mem_scheduler/orm_examples.py b/examples/mem_scheduler/orm_examples.py deleted file mode 100644 index bbb57b4ab..000000000 --- a/examples/mem_scheduler/orm_examples.py +++ /dev/null @@ -1,374 +0,0 @@ -#!/usr/bin/env python3 -""" -ORM Examples for MemScheduler - -This script demonstrates how to use the BaseDBManager's new environment variable loading methods -for MySQL and Redis connections. -""" - -import multiprocessing -import os -import sys - -from pathlib import Path - - -# Add the src directory to the Python path -sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src")) - -from memos.log import get_logger -from memos.mem_scheduler.orm_modules.base_model import BaseDBManager, DatabaseError -from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager, SimpleListManager - - -logger = get_logger(__name__) - - -def test_mysql_engine_from_env(): - """Test loading MySQL engine from environment variables""" - print("\n" + "=" * 60) - print("Testing MySQL Engine from Environment Variables") - print("=" * 60) - - try: - # Test loading MySQL engine from current environment variables - mysql_engine = BaseDBManager.load_mysql_engine_from_env() - if mysql_engine is None: - print("❌ Failed to create MySQL engine - check environment variables") - return - - print(f"✅ Successfully created MySQL engine: {mysql_engine}") - print(f" Engine URL: {mysql_engine.url}") - - # Test connection - with mysql_engine.connect() as conn: - from sqlalchemy import text - - result = conn.execute(text("SELECT 'MySQL connection test successful' as message")) - message = result.fetchone()[0] - print(f" Connection test: {message}") - - mysql_engine.dispose() - print(" MySQL engine disposed successfully") - - except DatabaseError as e: - print(f"❌ DatabaseError: {e}") - except Exception as e: - print(f"❌ Unexpected error: {e}") - - -def test_redis_connection_from_env(): - """Test loading Redis connection from environment variables""" - print("\n" + "=" * 60) - print("Testing Redis Connection from Environment Variables") - print("=" * 60) - - try: - # Test loading Redis connection from current environment variables - redis_client = BaseDBManager.load_redis_engine_from_env() - if redis_client is None: - print("❌ Failed to create Redis connection - check environment variables") - return - - print(f"✅ Successfully created Redis connection: {redis_client}") - - # Test basic Redis operations - redis_client.set("test_key", "Hello from ORM Examples!") - value = redis_client.get("test_key") - print(f" Redis test - Set/Get: {value}") - - # Test Redis info - info = redis_client.info("server") - redis_version = info.get("redis_version", "unknown") - print(f" Redis server version: {redis_version}") - - # Clean up test key - redis_client.delete("test_key") - print(" Test key cleaned up") - - redis_client.close() - print(" Redis connection closed successfully") - - except DatabaseError as e: - print(f"❌ DatabaseError: {e}") - except Exception as e: - print(f"❌ Unexpected error: {e}") - - -def test_environment_variables(): - """Test and display current environment variables""" - print("\n" + "=" * 60) - print("Current Environment Variables") - print("=" * 60) - - # MySQL environment variables - mysql_vars = [ - "MYSQL_HOST", - "MYSQL_PORT", - "MYSQL_USERNAME", - "MYSQL_PASSWORD", - "MYSQL_DATABASE", - "MYSQL_CHARSET", - ] - - print("\nMySQL Environment Variables:") - for var in mysql_vars: - value = os.getenv(var, "Not set") - # Mask password for security - if "PASSWORD" in var and value != "Not set": - value = "*" * len(value) - print(f" {var}: {value}") - - # Redis environment variables - redis_vars = [ - "REDIS_HOST", - "REDIS_PORT", - "REDIS_DB", - "REDIS_PASSWORD", - "MEMSCHEDULER_REDIS_HOST", - "MEMSCHEDULER_REDIS_PORT", - "MEMSCHEDULER_REDIS_DB", - "MEMSCHEDULER_REDIS_PASSWORD", - ] - - print("\nRedis Environment Variables:") - for var in redis_vars: - value = os.getenv(var, "Not set") - # Mask password for security - if "PASSWORD" in var and value != "Not set": - value = "*" * len(value) - print(f" {var}: {value}") - - -def test_manual_env_loading(): - """Test loading environment variables manually from .env file""" - print("\n" + "=" * 60) - print("Testing Manual Environment Loading") - print("=" * 60) - - env_file_path = "/Users/travistang/Documents/codes/memos/.env" - - if not os.path.exists(env_file_path): - print(f"❌ Environment file not found: {env_file_path}") - return - - try: - from dotenv import load_dotenv - - # Load environment variables - load_dotenv(env_file_path) - print(f"✅ Successfully loaded environment variables from {env_file_path}") - - # Test some key variables - test_vars = ["OPENAI_API_KEY", "MOS_CHAT_MODEL", "TZ"] - for var in test_vars: - value = os.getenv(var, "Not set") - if "KEY" in var and value != "Not set": - value = f"{value[:10]}..." if len(value) > 10 else value - print(f" {var}: {value}") - - except ImportError: - print("❌ python-dotenv not installed. Install with: pip install python-dotenv") - except Exception as e: - print(f"❌ Error loading environment file: {e}") - - -def test_redis_lockable_orm_with_list(): - """Test RedisDBManager with list[str] type synchronization""" - print("\n" + "=" * 60) - print("Testing RedisDBManager with list[str]") - print("=" * 60) - - try: - from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager - - # Create a simple list manager instance - list_manager = SimpleListManager(["apple", "banana", "cherry"]) - print(f"Original list manager: {list_manager}") - - # Create RedisDBManager instance - redis_client = BaseDBManager.load_redis_engine_from_env() - if redis_client is None: - print("❌ Failed to create Redis connection - check environment variables") - return - - db_manager = RedisDBManager( - redis_client=redis_client, - user_id="test_user", - mem_cube_id="test_list_cube", - obj=list_manager, - ) - - # Save to Redis - db_manager.save_to_db(list_manager) - print("✅ List manager saved to Redis") - - # Load from Redis - loaded_manager = db_manager.load_from_db() - if loaded_manager: - print(f"Loaded list manager: {loaded_manager}") - print(f"Items match: {list_manager.items == loaded_manager.items}") - else: - print("❌ Failed to load list manager from Redis") - - # Clean up - redis_client.delete("lockable_orm:test_user:test_list_cube:data") - redis_client.delete("lockable_orm:test_user:test_list_cube:lock") - redis_client.delete("lockable_orm:test_user:test_list_cube:version") - redis_client.close() - - except Exception as e: - print(f"❌ Error in RedisDBManager test: {e}") - - -def modify_list_process(process_id: int, items_to_add: list[str]): - """Function to be run in separate processes to modify the list using merge_items""" - try: - from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager - - # Create Redis connection - redis_client = BaseDBManager.load_redis_engine_from_env() - if redis_client is None: - print(f"Process {process_id}: Failed to create Redis connection") - return - - # Create a temporary list manager for this process with items to add - temp_manager = SimpleListManager() - - db_manager = RedisDBManager( - redis_client=redis_client, - user_id="test_user", - mem_cube_id="multiprocess_list", - obj=temp_manager, - ) - - print(f"Process {process_id}: Starting modification with items: {items_to_add}") - for item in items_to_add: - db_manager.obj.add_item(item) - # Use sync_with_orm which internally uses merge_items - db_manager.sync_with_orm(size_limit=None) - - print(f"Process {process_id}: Successfully synchronized with Redis") - - redis_client.close() - - except Exception as e: - print(f"Process {process_id}: Error - {e}") - import traceback - - traceback.print_exc() - - -def test_multiprocess_synchronization(): - """Test multiprocess synchronization with RedisDBManager""" - print("\n" + "=" * 60) - print("Testing Multiprocess Synchronization") - print("=" * 60) - - try: - # Initialize Redis with empty list - redis_client = BaseDBManager.load_redis_engine_from_env() - if redis_client is None: - print("❌ Failed to create Redis connection") - return - - # Initialize with empty list - initial_manager = SimpleListManager([]) - db_manager = RedisDBManager( - redis_client=redis_client, - user_id="test_user", - mem_cube_id="multiprocess_list", - obj=initial_manager, - ) - db_manager.save_to_db(initial_manager) - print("✅ Initialized empty list manager in Redis") - - # Define items for each process to add - process_items = [ - ["item1", "item2"], - ["item3", "item4"], - ["item5", "item6"], - ["item1", "item7"], # item1 is duplicate, should not be added twice - ] - - # Create and start processes - processes = [] - for i, items in enumerate(process_items): - p = multiprocessing.Process(target=modify_list_process, args=(i + 1, items)) - processes.append(p) - p.start() - - # Wait for all processes to complete - for p in processes: - p.join() - - print("\n" + "-" * 40) - print("All processes completed. Checking final result...") - - # Load final result - final_db_manager = RedisDBManager( - redis_client=redis_client, - user_id="test_user", - mem_cube_id="multiprocess_list", - obj=SimpleListManager([]), - ) - final_manager = final_db_manager.load_from_db() - - if final_manager: - print(f"Final synchronized list manager: {final_manager}") - print(f"Final list length: {len(final_manager)}") - print("Expected items: {'item1', 'item2', 'item3', 'item4', 'item5', 'item6', 'item7'}") - print(f"Actual items: {set(final_manager.items)}") - - # Check if all unique items are present - expected_items = {"item1", "item2", "item3", "item4", "item5", "item6", "item7"} - actual_items = set(final_manager.items) - - if expected_items == actual_items: - print("✅ All processes contributed correctly - synchronization successful!") - else: - print(f"❌ Expected items: {expected_items}") - print(f" Actual items: {actual_items}") - else: - print("❌ Failed to load final result") - - # Clean up - redis_client.delete("lockable_orm:test_user:multiprocess_list:data") - redis_client.delete("lockable_orm:test_user:multiprocess_list:lock") - redis_client.delete("lockable_orm:test_user:multiprocess_list:version") - redis_client.close() - - except Exception as e: - print(f"❌ Error in multiprocess synchronization test: {e}") - - -def main(): - """Main function to run all tests""" - print("ORM Examples - Environment Variable Loading Tests") - print("=" * 80) - - # Test environment variables display - test_environment_variables() - - # Test manual environment loading - test_manual_env_loading() - - # Test MySQL engine loading - test_mysql_engine_from_env() - - # Test Redis connection loading - test_redis_connection_from_env() - - # Test RedisLockableORM with list[str] - test_redis_lockable_orm_with_list() - - # Test multiprocess synchronization - test_multiprocess_synchronization() - - print("\n" + "=" * 80) - print("All tests completed!") - print("=" * 80) - - -if __name__ == "__main__": - main() diff --git a/examples/mem_scheduler/redis_example.py b/examples/mem_scheduler/redis_example.py index 1660d6c02..2c3801539 100644 --- a/examples/mem_scheduler/redis_example.py +++ b/examples/mem_scheduler/redis_example.py @@ -22,7 +22,7 @@ sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory -async def service_run(): +def service_run(): # Init example_scheduler_config_path = ( f"{BASE_DIR}/examples/data/config/mem_scheduler/general_scheduler_config.yaml" @@ -60,11 +60,11 @@ async def service_run(): content=query, timestamp=datetime.now(), ) - res = await mem_scheduler.redis_add_message_stream(message=message_item.to_dict()) + res = mem_scheduler.redis_add_message_stream(message=message_item.to_dict()) print( f"Added: {res}", ) - await asyncio.sleep(0.5) + asyncio.sleep(0.5) mem_scheduler.redis_stop_listening() @@ -72,4 +72,4 @@ async def service_run(): if __name__ == "__main__": - asyncio.run(service_run()) + service_run() diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py index de99f1c95..4aedac711 100644 --- a/examples/mem_scheduler/try_schedule_modules.py +++ b/examples/mem_scheduler/try_schedule_modules.py @@ -176,6 +176,7 @@ def show_web_logs(mem_scheduler: GeneralScheduler): mos.register_mem_cube( mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id ) + mos.mem_scheduler.current_mem_cube = mem_cube mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id) diff --git a/src/memos/api/config.py b/src/memos/api/config.py index f02edaad6..a276fa63d 100644 --- a/src/memos/api/config.py +++ b/src/memos/api/config.py @@ -175,7 +175,7 @@ def start_config_watch(cls): @classmethod def start_watch_if_enabled(cls) -> None: enable = os.getenv("NACOS_ENABLE_WATCH", "false").lower() == "true" - print("enable:", enable) + logger.info(f"NACOS_ENABLE_WATCH: {enable}") if not enable: return interval = int(os.getenv("NACOS_WATCH_INTERVAL", "60")) @@ -623,7 +623,10 @@ def get_scheduler_config() -> dict[str, Any]: "MOS_SCHEDULER_ENABLE_PARALLEL_DISPATCH", "true" ).lower() == "true", - "enable_activation_memory": True, + "enable_activation_memory": os.getenv( + "MOS_SCHEDULER_ENABLE_ACTIVATION_MEMORY", "false" + ).lower() + == "true", }, } diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 0412754c3..3b1ce2fc9 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -171,7 +171,9 @@ class APISearchRequest(BaseRequest): query: str = Field(..., description="Search query") user_id: str = Field(None, description="User ID") mem_cube_id: str | None = Field(None, description="Cube ID to search in") - mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture") + mode: SearchMode = Field( + SearchMode.NOT_INITIALIZED, description="search mode: fast, fine, or mixture" + ) internet_search: bool = Field(False, description="Whether to use internet search") moscube: bool = Field(False, description="Whether to use MemOSCube") top_k: int = Field(10, description="Number of results to return") diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index b426c2965..7d9f141dc 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -279,12 +279,14 @@ def init_server(): db_engine=BaseDBManager.create_default_sqlite_engine(), mem_reader=mem_reader, ) - mem_scheduler.current_mem_cube = naive_mem_cube - mem_scheduler.start() + mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube) # Initialize SchedulerAPIModule api_module = mem_scheduler.api_module + if os.getenv("API_SCHEDULER_ON", True): + mem_scheduler.start() + return ( graph_db, mem_reader, @@ -384,8 +386,10 @@ def search_memories(search_req: APISearchRequest): "pref_mem": [], "pref_note": "", } - - search_mode = search_req.mode + if search_req.mode == SearchMode.NOT_INITIALIZED: + search_mode = os.getenv("SEARCH_MODE", SearchMode.FAST) + else: + search_mode = search_req.mode def _search_text(): try: @@ -481,22 +485,38 @@ def fine_search_memories( target_session_id = "default_session" search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - # Create MemCube and perform search - search_results = naive_mem_cube.text_mem.search( + searcher = mem_scheduler.searcher + + info = { + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + } + + fast_retrieved_memories = searcher.retrieve( query=search_req.query, user_name=user_context.mem_cube_id, top_k=search_req.top_k, - mode=SearchMode.FINE, + mode=SearchMode.FAST, manual_close_internet=not search_req.internet_search, moscube=search_req.moscube, search_filter=search_filter, - info={ - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - }, + info=info, ) - formatted_memories = [_format_memory_item(data) for data in search_results] + + fast_memories = searcher.post_retrieve( + retrieved_results=fast_retrieved_memories, + top_k=search_req.top_k, + user_name=user_context.mem_cube_id, + info=info, + ) + + enhanced_results, _ = mem_scheduler.retriever.enhance_memories_with_query( + query_history=[search_req.query], + memories=fast_memories, + ) + + formatted_memories = [_format_memory_item(data) for data in enhanced_results] return formatted_memories diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py index e757f243b..afdaf6871 100644 --- a/src/memos/configs/mem_scheduler.py +++ b/src/memos/configs/mem_scheduler.py @@ -12,10 +12,13 @@ BASE_DIR, DEFAULT_ACT_MEM_DUMP_PATH, DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT, + DEFAULT_CONSUME_BATCH, DEFAULT_CONSUME_INTERVAL_SECONDS, DEFAULT_CONTEXT_WINDOW_SIZE, DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE, DEFAULT_MULTI_TASK_RUNNING_TIMEOUT, + DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE, + DEFAULT_SCHEDULER_RETRIEVER_RETRIES, DEFAULT_THREAD_POOL_MAX_WORKERS, DEFAULT_TOP_K, DEFAULT_USE_REDIS_QUEUE, @@ -43,6 +46,11 @@ class BaseSchedulerConfig(BaseConfig): gt=0, description=f"Interval for consuming messages from queue in seconds (default: {DEFAULT_CONSUME_INTERVAL_SECONDS})", ) + consume_batch: int = Field( + default=DEFAULT_CONSUME_BATCH, + gt=0, + description=f"Number of messages to consume in each batch (default: {DEFAULT_CONSUME_BATCH})", + ) auth_config_path: str | None = Field( default=None, description="Path to the authentication configuration file containing private credentials", @@ -91,6 +99,17 @@ class GeneralSchedulerConfig(BaseSchedulerConfig): description="Capacity of the activation memory monitor", ) + # Memory enhancement concurrency & retries configuration + enhance_batch_size: int | None = Field( + default=DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE, + description="Batch size for concurrent memory enhancement; None or <=1 disables batching", + ) + enhance_retries: int = Field( + default=DEFAULT_SCHEDULER_RETRIEVER_RETRIES, + ge=0, + description="Number of retry attempts per enhancement batch", + ) + # Database configuration for ORM persistence db_path: str | None = Field( default=None, diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 97ff9879f..1b6d4e126 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -283,7 +283,6 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=QUERY_LABEL, content=query, timestamp=datetime.utcnow(), @@ -344,7 +343,6 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=ANSWER_LABEL, content=response, timestamp=datetime.utcnow(), @@ -768,12 +766,10 @@ def process_textual_memory(): ) # submit messages for scheduler if self.enable_mem_scheduler and self.mem_scheduler is not None: - mem_cube = self.mem_cubes[mem_cube_id] if sync_mode == "async": message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=MEM_READ_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), @@ -783,7 +779,6 @@ def process_textual_memory(): message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=ADD_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), @@ -797,7 +792,6 @@ def process_preference_memory(): and self.mem_cubes[mem_cube_id].pref_mem ): messages_list = [messages] - mem_cube = self.mem_cubes[mem_cube_id] if sync_mode == "sync": pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory( messages_list, @@ -816,7 +810,6 @@ def process_preference_memory(): user_id=target_user_id, session_id=target_session_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=PREF_ADD_LABEL, content=json.dumps(messages_list), timestamp=datetime.utcnow(), @@ -867,12 +860,10 @@ def process_preference_memory(): # submit messages for scheduler if self.enable_mem_scheduler and self.mem_scheduler is not None: - mem_cube = self.mem_cubes[mem_cube_id] if sync_mode == "async": message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=MEM_READ_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), @@ -882,7 +873,6 @@ def process_preference_memory(): message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=ADD_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), @@ -909,11 +899,9 @@ def process_preference_memory(): # submit messages for scheduler if self.enable_mem_scheduler and self.mem_scheduler is not None: - mem_cube = self.mem_cubes[mem_cube_id] message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=ADD_LABEL, content=json.dumps(mem_ids), timestamp=datetime.utcnow(), diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py index 6fc64c5e3..0114fc0da 100644 --- a/src/memos/mem_os/main.py +++ b/src/memos/mem_os/main.py @@ -205,7 +205,6 @@ def _chat_with_cot_enhancement( # Step 7: Submit message to scheduler (same as core method) if len(accessible_cubes) == 1: mem_cube_id = accessible_cubes[0].cube_id - mem_cube = self.mem_cubes[mem_cube_id] if self.enable_mem_scheduler and self.mem_scheduler is not None: from datetime import datetime @@ -217,7 +216,6 @@ def _chat_with_cot_enhancement( message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=ANSWER_LABEL, content=enhanced_response, timestamp=datetime.now().isoformat(), diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index 9ddb77b52..359db72ba 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -637,7 +637,6 @@ def _send_message_to_scheduler( message_item = ScheduleMessageItem( user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=self.mem_cubes[mem_cube_id], label=label, content=query, timestamp=datetime.utcnow(), diff --git a/src/memos/mem_scheduler/analyzer/api_analyzer.py b/src/memos/mem_scheduler/analyzer/api_analyzer.py index 28ca182e5..085025b7f 100644 --- a/src/memos/mem_scheduler/analyzer/api_analyzer.py +++ b/src/memos/mem_scheduler/analyzer/api_analyzer.py @@ -7,7 +7,6 @@ import http.client import json -import time from typing import Any from urllib.parse import urlparse @@ -15,6 +14,7 @@ import requests from memos.log import get_logger +from memos.mem_scheduler.schemas.general_schemas import SearchMode logger = get_logger(__name__) @@ -487,7 +487,7 @@ def search_in_conversation(self, query, mode="fast", top_k=10, include_history=T return result - def test_continuous_conversation(self): + def test_continuous_conversation(self, mode=SearchMode.MIXTURE): """Test continuous conversation functionality""" print("=" * 80) print("Testing Continuous Conversation Functionality") @@ -542,15 +542,15 @@ def test_continuous_conversation(self): # Search for trip-related information self.search_in_conversation( - query="New Year's Eve Shanghai recommendations", mode="mixture", top_k=5 + query="New Year's Eve Shanghai recommendations", mode=mode, top_k=5 ) # Search for food-related information - self.search_in_conversation(query="budget food Shanghai", mode="mixture", top_k=3) + self.search_in_conversation(query="budget food Shanghai", mode=mode, top_k=3) # Search without conversation history self.search_in_conversation( - query="Shanghai travel", mode="mixture", top_k=3, include_history=False + query="Shanghai travel", mode=mode, top_k=3, include_history=False ) print("\n✅ Continuous conversation test completed successfully!") @@ -645,7 +645,7 @@ def create_test_add_request( operation=None, ) - def run_all_tests(self): + def run_all_tests(self, mode=SearchMode.MIXTURE): """Run all available tests""" print("🚀 Starting comprehensive test suite") print("=" * 80) @@ -653,8 +653,7 @@ def run_all_tests(self): # Test continuous conversation functionality print("\n💬 Testing CONTINUOUS CONVERSATION functions:") try: - self.test_continuous_conversation() - time.sleep(5) + self.test_continuous_conversation(mode=mode) print("✅ Continuous conversation test completed successfully") except Exception as e: print(f"❌ Continuous conversation test failed: {e}") @@ -682,7 +681,7 @@ def run_all_tests(self): print("Using direct test mode") try: direct_analyzer = DirectSearchMemoriesAnalyzer() - direct_analyzer.run_all_tests() + direct_analyzer.run_all_tests(mode=SearchMode.MIXTURE) except Exception as e: print(f"Direct test mode failed: {e}") import traceback diff --git a/src/memos/mem_scheduler/analyzer/eval_analyzer.py b/src/memos/mem_scheduler/analyzer/eval_analyzer.py new file mode 100644 index 000000000..d37e17456 --- /dev/null +++ b/src/memos/mem_scheduler/analyzer/eval_analyzer.py @@ -0,0 +1,1322 @@ +""" +Evaluation Analyzer for Bad Cases + +This module provides the EvalAnalyzer class that extracts bad cases from evaluation results +and analyzes whether memories contain sufficient information to answer golden answers. +""" + +import json +import os +import sys + +from pathlib import Path +from typing import Any + +from openai import OpenAI + +from memos.api.routers.server_router import mem_scheduler +from memos.log import get_logger +from memos.memories.textual.item import TextualMemoryMetadata +from memos.memories.textual.tree import TextualMemoryItem + + +FILE_PATH = Path(__file__).absolute() +BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent # Go up to project root +sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory + +logger = get_logger(__name__) + + +class EvalAnalyzer: + """ + Evaluation Analyzer class for extracting and analyzing bad cases. + + This class extracts bad cases from evaluation results and uses LLM to analyze + whether memories contain sufficient information to answer golden answers. + """ + + def __init__( + self, + openai_api_key: str | None = None, + openai_base_url: str | None = None, + openai_model: str = "gpt-4o-mini", + output_dir: str = "./tmp/eval_analyzer", + ): + """ + Initialize the EvalAnalyzer. + + Args: + openai_api_key: OpenAI API key + openai_base_url: OpenAI base URL + openai_model: OpenAI model to use + output_dir: Output directory for results + """ + self.output_dir = Path(output_dir) + self.output_dir.mkdir(parents=True, exist_ok=True) + + # Initialize OpenAI client + self.openai_client = OpenAI( + api_key=openai_api_key or os.getenv("MEMSCHEDULER_OPENAI_API_KEY"), + base_url=openai_base_url or os.getenv("MEMSCHEDULER_OPENAI_BASE_URL"), + ) + self.openai_model = openai_model or os.getenv( + "MEMSCHEDULER_OPENAI_DEFAULT_MODEL", "gpt-4o-mini" + ) + + logger.info(f"EvalAnalyzer initialized with model: {self.openai_model}") + + def load_json_file(self, filepath: str) -> Any: + """Load JSON file safely.""" + try: + with open(filepath, encoding="utf-8") as f: + return json.load(f) + except FileNotFoundError: + logger.error(f"File not found: {filepath}") + return None + except json.JSONDecodeError as e: + logger.error(f"JSON decode error in {filepath}: {e}") + return None + + def extract_bad_cases(self, judged_file: str, search_results_file: str) -> list[dict[str, Any]]: + """ + Extract bad cases from judged results and corresponding search results. + + Args: + judged_file: Path to the judged results JSON file + search_results_file: Path to the search results JSON file + + Returns: + List of bad cases with their memories + """ + logger.info(f"Loading judged results from: {judged_file}") + judged_data = self.load_json_file(judged_file) + if not judged_data: + return [] + + logger.info(f"Loading search results from: {search_results_file}") + search_data = self.load_json_file(search_results_file) + if not search_data: + return [] + + bad_cases = [] + + # Process each user's data + for user_id, user_judged_results in judged_data.items(): + user_search_results = search_data.get(user_id, []) + + # Create a mapping from query to search context + search_context_map = {} + for search_result in user_search_results: + query = search_result.get("query", "") + context = search_result.get("context", "") + search_context_map[query] = context + + # Process each question for this user + for result in user_judged_results: + # Check if this is a bad case (all judgments are False) + judgments = result.get("llm_judgments", {}) + is_bad_case = all(not judgment for judgment in judgments.values()) + + if is_bad_case: + question = result.get("question", "") + answer = result.get("answer", "") + golden_answer = result.get("golden_answer", "") + + # Find corresponding memories from search results + memories = search_context_map.get(question, "") + + bad_case = { + "user_id": user_id, + "query": question, + "answer": answer, + "golden_answer": golden_answer, + "memories": memories, + "category": result.get("category", 0), + "nlp_metrics": result.get("nlp_metrics", {}), + "response_duration_ms": result.get("response_duration_ms", 0), + "search_duration_ms": result.get("search_duration_ms", 0), + "total_duration_ms": result.get("total_duration_ms", 0), + } + + bad_cases.append(bad_case) + + logger.info(f"Extracted {len(bad_cases)} bad cases") + return bad_cases + + def analyze_memory_sufficiency( + self, query: str, golden_answer: str, memories: str + ) -> dict[str, Any]: + """ + Use LLM to analyze whether memories contain sufficient information to answer the golden answer. + + Args: + query: The original query + golden_answer: The correct answer + memories: The memory context + + Returns: + Analysis result containing sufficiency judgment and relevant memory indices + """ + prompt = f""" +You are an expert analyst tasked with determining whether the provided memories contain sufficient information to answer a specific question correctly. + +**Question:** {query} + +**Golden Answer (Correct Answer):** {golden_answer} + +**Available Memories:** +{memories} + +**Task:** +1. Analyze whether the memories contain enough information to derive the golden answer +2. Identify which specific memory entries (if any) contain relevant information +3. Provide a clear judgment: True if sufficient, False if insufficient + +**Response Format (JSON):** +{{ + "sufficient": true/false, + "confidence": 0.0-1.0, + "relevant_memories": ["memory_1", "memory_2", ...], + "reasoning": "Detailed explanation of your analysis", + "missing_information": "What key information is missing (if insufficient)" +}} + +**Guidelines:** +- Be strict in your evaluation - only mark as sufficient if the memories clearly contain the information needed +- Consider both direct and indirect information that could lead to the golden answer +- Pay attention to dates, names, events, and specific details +- If information is ambiguous or requires significant inference, lean towards insufficient +""" + + try: + response = self.openai_client.chat.completions.create( + model=self.openai_model, + messages=[ + { + "role": "system", + "content": "You are a precise analyst who evaluates information sufficiency.", + }, + {"role": "user", "content": prompt}, + ], + temperature=0.1, + max_tokens=1000, + ) + + content = response.choices[0].message.content.strip() + + # Try to parse JSON response + try: + # Remove markdown code blocks if present + if content.startswith("```json"): + content = content[7:] + if content.endswith("```"): + content = content[:-3] + content = content.strip() + + analysis = json.loads(content) + return analysis + + except json.JSONDecodeError: + logger.warning(f"Failed to parse LLM response as JSON: {content}") + return { + "sufficient": False, + "confidence": 0.0, + "relevant_memories": [], + "reasoning": f"Failed to parse LLM response: {content}", + "missing_information": "Analysis failed", + } + + except Exception as e: + logger.error(f"Error in LLM analysis: {e}") + return { + "sufficient": False, + "confidence": 0.0, + "relevant_memories": [], + "reasoning": f"Error occurred: {e!s}", + "missing_information": "Analysis failed due to error", + } + + def process_memories_with_llm( + self, memories: str, query: str, processing_type: str = "summarize" + ) -> dict[str, Any]: + """ + Use LLM to process memories for better question answering. + + Args: + memories: The raw memory content + query: The query that will be answered using these memories + processing_type: Type of processing ("summarize", "restructure", "enhance") + + Returns: + Dictionary containing processed memories and processing metadata + """ + if processing_type == "summarize": + prompt = f""" +You are an expert at summarizing and organizing information to help answer specific questions. + +**Target Question:** {query} + +**Raw Memories:** +{memories} + +**Task:** +Summarize and organize the above memories in a way that would be most helpful for answering the target question. Focus on: +1. Key facts and information relevant to the question +2. Important relationships and connections +3. Chronological or logical organization where applicable +4. Remove redundant or irrelevant information + +**Processed Memories:** +""" + elif processing_type == "restructure": + prompt = f""" +You are an expert at restructuring information to optimize question answering. + +**Target Question:** {query} + +**Raw Memories:** +{memories} + +**Task:** +Restructure the above memories into a clear, logical format that directly supports answering the target question. Organize by: +1. Most relevant information first +2. Supporting details and context +3. Clear categorization of different types of information +4. Logical flow that leads to the answer + +**Restructured Memories:** +""" + elif processing_type == "enhance": + prompt = f""" +You are an expert at enhancing information by adding context and making connections. + +**Target Question:** {query} + +**Raw Memories:** +{memories} + +**Task:** +Enhance the above memories by: +1. Making implicit connections explicit +2. Adding relevant context that helps answer the question +3. Highlighting key relationships between different pieces of information +4. Organizing information in a question-focused manner + +**Enhanced Memories:** +""" + else: + raise ValueError(f"Unknown processing_type: {processing_type}") + + try: + response = self.openai_client.chat.completions.create( + model=self.openai_model, + messages=[ + { + "role": "system", + "content": "You are an expert information processor who optimizes content for question answering.", + }, + {"role": "user", "content": prompt}, + ], + temperature=0.3, + max_tokens=2000, + ) + + processed_memories = response.choices[0].message.content.strip() + + return { + "processed_memories": processed_memories, + "processing_type": processing_type, + "original_length": len(memories), + "processed_length": len(processed_memories), + "compression_ratio": len(processed_memories) / len(memories) + if len(memories) > 0 + else 0, + } + + except Exception as e: + logger.error(f"Error in memory processing: {e}") + return { + "processed_memories": memories, # Fallback to original + "processing_type": processing_type, + "original_length": len(memories), + "processed_length": len(memories), + "compression_ratio": 1.0, + "error": str(e), + } + + def generate_answer_with_memories( + self, query: str, memories: str, memory_type: str = "original" + ) -> dict[str, Any]: + """ + Generate an answer to the query using the provided memories. + + Args: + query: The question to answer + memories: The memory content to use + memory_type: Type of memories ("original", "processed") + + Returns: + Dictionary containing the generated answer and metadata + """ + prompt = f""" + You are a knowledgeable and helpful AI assistant. + + # CONTEXT: + You have access to memories from two speakers in a conversation. These memories contain + timestamped information that may be relevant to answering the question. + + # INSTRUCTIONS: + 1. Carefully analyze all provided memories. Synthesize information across different entries if needed to form a complete answer. + 2. Pay close attention to the timestamps to determine the answer. If memories contain contradictory information, the **most recent memory** is the source of truth. + 3. If the question asks about a specific event or fact, look for direct evidence in the memories. + 4. Your answer must be grounded in the memories. However, you may use general world knowledge to interpret or complete information found within a memory (e.g., identifying a landmark mentioned by description). + 5. If the question involves time references (like "last year", "two months ago", etc.), you **must** calculate the actual date based on the memory's timestamp. For example, if a memory from 4 May 2022 mentions "went to India last year," then the trip occurred in 2021. + 6. Always convert relative time references to specific dates, months, or years in your final answer. + 7. Do not confuse character names mentioned in memories with the actual users who created them. + 8. The answer must be brief (under 5-6 words) and direct, with no extra description. + + # APPROACH (Think step by step): + 1. First, examine all memories that contain information related to the question. + 2. Synthesize findings from multiple memories if a single entry is insufficient. + 3. Examine timestamps and content carefully, looking for explicit dates, times, locations, or events. + 4. If the answer requires calculation (e.g., converting relative time references), perform the calculation. + 5. Formulate a precise, concise answer based on the evidence from the memories (and allowed world knowledge). + 6. Double-check that your answer directly addresses the question asked and adheres to all instructions. + 7. Ensure your final answer is specific and avoids vague time references. + + {memories} + + Question: {query} + + Answer: +""" + + try: + response = self.openai_client.chat.completions.create( + model=self.openai_model, + messages=[ + { + "role": "system", + "content": "You are a precise assistant who answers questions based only on provided information.", + }, + {"role": "user", "content": prompt}, + ], + temperature=0.1, + max_tokens=1000, + ) + + answer = response.choices[0].message.content.strip() + + return { + "answer": answer, + "memory_type": memory_type, + "query": query, + "memory_length": len(memories), + "answer_length": len(answer), + } + + except Exception as e: + logger.error(f"Error in answer generation: {e}") + return { + "answer": f"Error generating answer: {e!s}", + "memory_type": memory_type, + "query": query, + "memory_length": len(memories), + "answer_length": 0, + "error": str(e), + } + + def compare_answer_quality( + self, query: str, golden_answer: str, original_answer: str, processed_answer: str + ) -> dict[str, Any]: + """ + Compare the quality of answers generated from original vs processed memories. + + Args: + query: The original query + golden_answer: The correct/expected answer + original_answer: Answer generated from original memories + processed_answer: Answer generated from processed memories + + Returns: + Dictionary containing comparison results + """ + prompt = f""" +You are an expert evaluator comparing the quality of two answers against a golden standard. + +**Question:** {query} + +**Golden Answer (Correct):** {golden_answer} + +**Answer A (Original Memories):** {original_answer} + +**Answer B (Processed Memories):** {processed_answer} + +**Task:** +Compare both answers against the golden answer and evaluate: +1. Accuracy: How correct is each answer? +2. Completeness: How complete is each answer? +3. Relevance: How relevant is each answer to the question? +4. Clarity: How clear and well-structured is each answer? + +**Response Format (JSON):** +{{ + "original_scores": {{ + "accuracy": 0.0-1.0, + "completeness": 0.0-1.0, + "relevance": 0.0-1.0, + "clarity": 0.0-1.0, + "overall": 0.0-1.0 + }}, + "processed_scores": {{ + "accuracy": 0.0-1.0, + "completeness": 0.0-1.0, + "relevance": 0.0-1.0, + "clarity": 0.0-1.0, + "overall": 0.0-1.0 + }}, + "winner": "original|processed|tie", + "improvement": 0.0-1.0, + "reasoning": "Detailed explanation of the comparison" +}} +""" + + try: + response = self.openai_client.chat.completions.create( + model=self.openai_model, + messages=[ + { + "role": "system", + "content": "You are an expert evaluator who compares answer quality objectively.", + }, + {"role": "user", "content": prompt}, + ], + temperature=0.1, + max_tokens=1500, + ) + + content = response.choices[0].message.content.strip() + + # Try to parse JSON response + try: + if content.startswith("```json"): + content = content[7:] + if content.endswith("```"): + content = content[:-3] + content = content.strip() + + comparison = json.loads(content) + return comparison + + except json.JSONDecodeError: + logger.warning(f"Failed to parse comparison response as JSON: {content}") + return { + "original_scores": { + "accuracy": 0.5, + "completeness": 0.5, + "relevance": 0.5, + "clarity": 0.5, + "overall": 0.5, + }, + "processed_scores": { + "accuracy": 0.5, + "completeness": 0.5, + "relevance": 0.5, + "clarity": 0.5, + "overall": 0.5, + }, + "winner": "tie", + "improvement": 0.0, + "reasoning": f"Failed to parse comparison: {content}", + } + + except Exception as e: + logger.error(f"Error in answer comparison: {e}") + return { + "original_scores": { + "accuracy": 0.0, + "completeness": 0.0, + "relevance": 0.0, + "clarity": 0.0, + "overall": 0.0, + }, + "processed_scores": { + "accuracy": 0.0, + "completeness": 0.0, + "relevance": 0.0, + "clarity": 0.0, + "overall": 0.0, + }, + "winner": "tie", + "improvement": 0.0, + "reasoning": f"Error occurred: {e!s}", + } + + def analyze_memory_processing_effectiveness( + self, + bad_cases: list[dict[str, Any]], + processing_types: list[str] | None = None, + ) -> dict[str, Any]: + """ + Analyze the effectiveness of different memory processing techniques. + + Args: + bad_cases: List of bad cases to analyze + processing_types: List of processing types to test + + Returns: + Dictionary containing comprehensive analysis results + """ + if processing_types is None: + processing_types = ["summarize", "restructure", "enhance"] + results = {"processing_results": [], "statistics": {}, "processing_types": processing_types} + + for i, case in enumerate(bad_cases): + logger.info(f"Processing case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...") + + case_result = { + "case_id": i, + "query": case["query"], + "golden_answer": case["golden_answer"], + "original_memories": case["memories"], + "processing_results": {}, + } + + # Generate answer with original memories + original_answer_result = self.generate_answer_with_memories( + case["query"], case["memories"], "original" + ) + case_result["original_answer"] = original_answer_result + + # Test each processing type + for processing_type in processing_types: + logger.info(f" Testing {processing_type} processing...") + + # Process memories + processing_result = self.process_memories_with_llm( + case["memories"], case["query"], processing_type + ) + + # Generate answer with processed memories + processed_answer_result = self.generate_answer_with_memories( + case["query"], + processing_result["processed_memories"], + f"processed_{processing_type}", + ) + + # Compare answer quality + comparison_result = self.compare_answer_quality( + case["query"], + case["golden_answer"], + original_answer_result["answer"], + processed_answer_result["answer"], + ) + + case_result["processing_results"][processing_type] = { + "processing": processing_result, + "answer": processed_answer_result, + "comparison": comparison_result, + } + + results["processing_results"].append(case_result) + + # Calculate statistics + self._calculate_processing_statistics(results) + + return results + + def _calculate_processing_statistics(self, results: dict[str, Any]) -> None: + """Calculate statistics for processing effectiveness analysis.""" + processing_types = results["processing_types"] + processing_results = results["processing_results"] + + if not processing_results: + results["statistics"] = {} + return + + stats = {"total_cases": len(processing_results), "processing_type_stats": {}} + + for processing_type in processing_types: + type_stats = { + "wins": 0, + "ties": 0, + "losses": 0, + "avg_improvement": 0.0, + "avg_compression_ratio": 0.0, + "avg_scores": { + "accuracy": 0.0, + "completeness": 0.0, + "relevance": 0.0, + "clarity": 0.0, + "overall": 0.0, + }, + } + + valid_cases = [] + for case in processing_results: + if processing_type in case["processing_results"]: + result = case["processing_results"][processing_type] + comparison = result["comparison"] + + # Count wins/ties/losses + if comparison["winner"] == "processed": + type_stats["wins"] += 1 + elif comparison["winner"] == "tie": + type_stats["ties"] += 1 + else: + type_stats["losses"] += 1 + + valid_cases.append(result) + + if valid_cases: + # Calculate averages + type_stats["avg_improvement"] = sum( + case["comparison"]["improvement"] for case in valid_cases + ) / len(valid_cases) + + type_stats["avg_compression_ratio"] = sum( + case["processing"]["compression_ratio"] for case in valid_cases + ) / len(valid_cases) + + # Calculate average scores + for score_type in type_stats["avg_scores"]: + type_stats["avg_scores"][score_type] = sum( + case["comparison"]["processed_scores"][score_type] for case in valid_cases + ) / len(valid_cases) + + # Calculate win rate + total_decisions = type_stats["wins"] + type_stats["ties"] + type_stats["losses"] + type_stats["win_rate"] = ( + type_stats["wins"] / total_decisions if total_decisions > 0 else 0.0 + ) + type_stats["success_rate"] = ( + (type_stats["wins"] + type_stats["ties"]) / total_decisions + if total_decisions > 0 + else 0.0 + ) + + stats["processing_type_stats"][processing_type] = type_stats + + results["statistics"] = stats + + def analyze_bad_cases(self, bad_cases: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Analyze all bad cases to determine memory sufficiency. + + Args: + bad_cases: List of bad cases to analyze + + Returns: + List of analyzed bad cases with sufficiency information + """ + analyzed_cases = [] + + for i, case in enumerate(bad_cases): + logger.info(f"Analyzing bad case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...") + + analysis = self.analyze_memory_sufficiency( + case["query"], case["golden_answer"], case["memories"] + ) + + # Add analysis results to the case + analyzed_case = case.copy() + analyzed_case.update( + { + "memory_analysis": analysis, + "has_sufficient_memories": analysis["sufficient"], + "analysis_confidence": analysis["confidence"], + "relevant_memory_count": len(analysis["relevant_memories"]), + } + ) + + analyzed_cases.append(analyzed_case) + + return analyzed_cases + + def collect_bad_cases(self, eval_result_dir: str | None = None) -> dict[str, Any]: + """ + Main method to collect and analyze bad cases from evaluation results. + + Args: + eval_result_dir: Directory containing evaluation results + + Returns: + Dictionary containing analysis results and statistics + """ + if eval_result_dir is None: + eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-072005-fast" + + judged_file = os.path.join(eval_result_dir, "memos-api_locomo_judged.json") + search_results_file = os.path.join(eval_result_dir, "memos-api_locomo_search_results.json") + + # Extract bad cases + bad_cases = self.extract_bad_cases(judged_file, search_results_file) + + if not bad_cases: + logger.warning("No bad cases found") + return {"bad_cases": [], "statistics": {}} + + # Analyze bad cases + analyzed_cases = self.analyze_bad_cases(bad_cases) + + # Calculate statistics + total_cases = len(analyzed_cases) + sufficient_cases = sum( + 1 for case in analyzed_cases if case.get("has_sufficient_memories", False) + ) + insufficient_cases = total_cases - sufficient_cases + + avg_confidence = ( + sum(case["analysis_confidence"] for case in analyzed_cases) / total_cases + if total_cases > 0 + else 0 + ) + avg_relevant_memories = ( + sum(case["relevant_memory_count"] for case in analyzed_cases) / total_cases + if total_cases > 0 + else 0 + ) + + statistics = { + "total_bad_cases": total_cases, + "sufficient_memory_cases": sufficient_cases, + "insufficient_memory_cases": insufficient_cases, + "sufficiency_rate": sufficient_cases / total_cases if total_cases > 0 else 0, + "average_confidence": avg_confidence, + "average_relevant_memories": avg_relevant_memories, + } + + # Save results + results = { + "bad_cases": analyzed_cases, + "statistics": statistics, + "metadata": { + "eval_result_dir": eval_result_dir, + "judged_file": judged_file, + "search_results_file": search_results_file, + "analysis_model": self.openai_model, + }, + } + + output_file = self.output_dir / "bad_cases_analysis.json" + with open(output_file, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + + logger.info(f"Analysis complete. Results saved to: {output_file}") + logger.info(f"Statistics: {statistics}") + + return results + + def _parse_json_response(self, response_text: str) -> dict: + """ + Parse JSON response from LLM, handling various formats and potential errors. + + Args: + response_text: Raw response text from LLM + + Returns: + Parsed JSON dictionary + + Raises: + ValueError: If JSON cannot be parsed + """ + import re + + # Try to extract JSON from response text + # Look for JSON blocks between ```json and ``` or just {} blocks + json_patterns = [r"```json\s*(\{.*?\})\s*```", r"```\s*(\{.*?\})\s*```", r"(\{.*\})"] + + for pattern in json_patterns: + matches = re.findall(pattern, response_text, re.DOTALL) + if matches: + json_str = matches[0].strip() + try: + return json.loads(json_str) + except json.JSONDecodeError: + continue + + # If no JSON pattern found, try parsing the entire response + try: + return json.loads(response_text.strip()) + except json.JSONDecodeError as e: + logger.error(f"Failed to parse JSON response: {response_text[:200]}...") + raise ValueError(f"Invalid JSON response: {e!s}") from e + + def filter_memories_with_llm(self, memories: list[str], query: str) -> tuple[list[str], bool]: + """ + Use LLM to filter memories based on relevance to the query. + + Args: + memories: List of memory strings + query: Query to filter memories against + + Returns: + Tuple of (filtered_memories, success_flag) + """ + if not memories: + return [], True + + # Build prompt for memory filtering + memories_text = "\n".join([f"{i + 1}. {memory}" for i, memory in enumerate(memories)]) + + prompt = f"""You are a memory filtering system. Given a query and a list of memories, identify which memories are relevant and non-redundant for answering the query. + +Query: {query} + +Memories: +{memories_text} + +Please analyze each memory and return a JSON response with the following format: +{{ + "relevant_memory_indices": [list of indices (1-based) of memories that are relevant to the query], + "reasoning": "Brief explanation of your filtering decisions" +}} + +Only include memories that are directly relevant to answering the query. Remove redundant or unrelated memories.""" + + try: + response = self.openai_client.chat.completions.create( + model=self.openai_model, + messages=[{"role": "user", "content": prompt}], + temperature=0.1, + ) + + response_text = response.choices[0].message.content + + # Extract JSON from response + result = self._parse_json_response(response_text) + + if "relevant_memory_indices" in result: + relevant_indices = result["relevant_memory_indices"] + filtered_memories = [] + + for idx in relevant_indices: + if 1 <= idx <= len(memories): + filtered_memories.append(memories[idx - 1]) + + logger.info(f"Filtered memories: {len(memories)} -> {len(filtered_memories)}") + return filtered_memories, True + else: + logger.warning("Invalid response format from memory filtering LLM") + return memories, False + + except Exception as e: + logger.error(f"Error in memory filtering: {e}") + return memories, False + + def evaluate_answer_ability_with_llm(self, query: str, memories: list[str]) -> bool: + """ + Use LLM to evaluate whether the given memories can answer the query. + + Args: + query: Query to evaluate + memories: List of memory strings + + Returns: + Boolean indicating whether memories can answer the query + """ + if not memories: + return False + + memories_text = "\n".join([f"- {memory}" for memory in memories]) + + prompt = f"""You are an answer ability evaluator. Given a query and a list of memories, determine whether the memories contain sufficient information to answer the query. + +Query: {query} + +Available Memories: +{memories_text} + +Please analyze the memories and return a JSON response with the following format: +{{ + "can_answer": true/false, + "confidence": 0.0-1.0, + "reasoning": "Brief explanation of your decision" +}} + +Consider whether the memories contain the specific information needed to provide a complete and accurate answer to the query.""" + + try: + response = self.openai_client.chat.completions.create( + model=self.openai_model, + messages=[{"role": "user", "content": prompt}], + temperature=0.1, + ) + + response_text = response.choices[0].message.content + result = self._parse_json_response(response_text) + + if "can_answer" in result: + can_answer = result["can_answer"] + confidence = result.get("confidence", 0.5) + reasoning = result.get("reasoning", "No reasoning provided") + + logger.info( + f"Answer ability evaluation: {can_answer} (confidence: {confidence:.2f}) - {reasoning}" + ) + return can_answer + else: + logger.warning("Invalid response format from answer ability evaluation") + return False + + except Exception as e: + logger.error(f"Error in answer ability evaluation: {e}") + return False + + def memory_llm_processing_analysis( + self, bad_cases: list[dict[str, Any]], use_llm_filtering: bool = True + ) -> list[dict[str, Any]]: + """ + Analyze bad cases by processing memories with LLM filtering and testing answer ability. + + This method: + 1. Parses memory strings from bad cases + 2. Uses LLM to filter unrelated and redundant memories + 3. Tests whether processed memories can help answer questions correctly + 4. Compares results before and after LLM processing + + Args: + bad_cases: List of bad cases to analyze + use_llm_filtering: Whether to use LLM filtering + + Returns: + List of analyzed bad cases with LLM processing results + """ + analyzed_cases = [] + + for i, case in enumerate(bad_cases): + logger.info(f"Processing bad case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...") + + try: + # Parse memory string + memories_text = case.get("memories", "") + if not memories_text: + logger.warning(f"No memories found for case {i + 1}") + analyzed_case = case.copy() + analyzed_case.update( + { + "llm_processing_analysis": { + "error": "No memories available", + "original_memories_count": 0, + "processed_memories_count": 0, + "can_answer_with_original": False, + "can_answer_with_processed": False, + "processing_improved_answer": False, + } + } + ) + analyzed_cases.append(analyzed_case) + continue + + # Split memories by lines + memory_lines = [line.strip() for line in memories_text.split("\n") if line.strip()] + original_memories = [line for line in memory_lines if line] + + logger.info(f"Parsed {len(original_memories)} memories from text") + + # Test answer ability with original memories + can_answer_original = self.evaluate_answer_ability_with_llm( + query=case["query"], memories=original_memories + ) + + # Process memories with LLM filtering if enabled + processed_memories = original_memories + processing_success = False + + if use_llm_filtering and len(original_memories) > 0: + processed_memories, processing_success = self.filter_memories_with_llm( + memories=original_memories, query=case["query"] + ) + logger.info( + f"LLM filtering: {len(original_memories)} -> {len(processed_memories)} memories, success: {processing_success}" + ) + + # Test answer ability with processed memories + can_answer_processed = self.evaluate_answer_ability_with_llm( + query=case["query"], memories=processed_memories + ) + + # Determine if processing improved answer ability + processing_improved = can_answer_processed and not can_answer_original + + # Create analysis result + llm_analysis = { + "processing_success": processing_success, + "original_memories_count": len(original_memories), + "processed_memories_count": len(processed_memories), + "memories_removed_count": len(original_memories) - len(processed_memories), + "can_answer_with_original": can_answer_original, + "can_answer_with_processed": can_answer_processed, + "processing_improved_answer": processing_improved, + "original_memories": original_memories, + "processed_memories": processed_memories, + } + + # Add analysis to case + analyzed_case = case.copy() + analyzed_case["llm_processing_analysis"] = llm_analysis + + logger.info( + f"Case {i + 1} analysis complete: " + f"Original: {can_answer_original}, " + f"Processed: {can_answer_processed}, " + f"Improved: {processing_improved}" + ) + + except Exception as e: + logger.error(f"Error processing case {i + 1}: {e}") + analyzed_case = case.copy() + analyzed_case["llm_processing_analysis"] = { + "error": str(e), + "processing_success": False, + "original_memories_count": 0, + "processed_memories_count": 0, + "can_answer_with_original": False, + "can_answer_with_processed": False, + "processing_improved_answer": False, + } + + analyzed_cases.append(analyzed_case) + + return analyzed_cases + + def scheduler_mem_process(self, query, memories): + from memos.mem_scheduler.utils.misc_utils import extract_list_items_in_answer + + _memories = [] + for mem in memories: + mem_item = TextualMemoryItem(memory=mem, metadata=TextualMemoryMetadata()) + _memories.append(mem_item) + prompt = mem_scheduler.retriever._build_enhancement_prompt( + query_history=[query], batch_texts=memories + ) + logger.debug( + f"[Enhance][batch={0}] Prompt (first 200 chars, len={len(prompt)}): {prompt[:200]}..." + ) + + response = mem_scheduler.retriever.process_llm.generate( + [{"role": "user", "content": prompt}] + ) + logger.debug(f"[Enhance][batch={0}] Response (first 200 chars): {response[:200]}...") + + processed_results = extract_list_items_in_answer(response) + + return { + "processed_memories": processed_results, + "processing_type": "enhance", + "original_length": len("\n".join(memories)), + "processed_length": len("\n".join(processed_results)), + "compression_ratio": len("\n".join(processed_results)) / len("\n".join(memories)) + if len(memories) > 0 + else 0, + } + + def analyze_bad_cases_with_llm_processing( + self, + bad_cases: list[dict[str, Any]], + save_results: bool = True, + output_file: str | None = None, + ) -> dict[str, Any]: + """ + Comprehensive analysis of bad cases with LLM memory processing. + + This method performs a complete analysis including: + 1. Basic bad case analysis + 2. LLM memory processing analysis + 3. Statistical summary of improvements + 4. Detailed reporting + + Args: + bad_cases: List of bad cases to analyze + save_results: Whether to save results to file + output_file: Optional output file path + + Returns: + Dictionary containing comprehensive analysis results + """ + from datetime import datetime + + logger.info( + f"Starting comprehensive analysis of {len(bad_cases)} bad cases with LLM processing" + ) + + # Perform LLM memory processing analysis + analyzed_cases = self.memory_llm_processing_analysis( + bad_cases=bad_cases, use_llm_filtering=True + ) + + # Calculate statistics + total_cases = len(analyzed_cases) + successful_processing = 0 + improved_cases = 0 + original_answerable = 0 + processed_answerable = 0 + total_memories_before = 0 + total_memories_after = 0 + + for case in analyzed_cases: + llm_analysis = case.get("llm_processing_analysis", {}) + + if llm_analysis.get("processing_success", False): + successful_processing += 1 + + if llm_analysis.get("processing_improved_answer", False): + improved_cases += 1 + + if llm_analysis.get("can_answer_with_original", False): + original_answerable += 1 + + if llm_analysis.get("can_answer_with_processed", False): + processed_answerable += 1 + + total_memories_before += llm_analysis.get("original_memories_count", 0) + total_memories_after += llm_analysis.get("processed_memories_count", 0) + + # Calculate improvement metrics + processing_success_rate = successful_processing / total_cases if total_cases > 0 else 0 + improvement_rate = improved_cases / total_cases if total_cases > 0 else 0 + original_answer_rate = original_answerable / total_cases if total_cases > 0 else 0 + processed_answer_rate = processed_answerable / total_cases if total_cases > 0 else 0 + memory_reduction_rate = ( + (total_memories_before - total_memories_after) / total_memories_before + if total_memories_before > 0 + else 0 + ) + + # Create comprehensive results + results = { + "analysis_metadata": { + "total_cases_analyzed": total_cases, + "analysis_timestamp": datetime.now().isoformat(), + "llm_model_used": self.openai_model, + }, + "processing_statistics": { + "successful_processing_count": successful_processing, + "processing_success_rate": processing_success_rate, + "cases_with_improvement": improved_cases, + "improvement_rate": improvement_rate, + "original_answerable_cases": original_answerable, + "original_answer_rate": original_answer_rate, + "processed_answerable_cases": processed_answerable, + "processed_answer_rate": processed_answer_rate, + "answer_rate_improvement": processed_answer_rate - original_answer_rate, + }, + "memory_statistics": { + "total_memories_before_processing": total_memories_before, + "total_memories_after_processing": total_memories_after, + "memories_removed": total_memories_before - total_memories_after, + "memory_reduction_rate": memory_reduction_rate, + "average_memories_per_case_before": total_memories_before / total_cases + if total_cases > 0 + else 0, + "average_memories_per_case_after": total_memories_after / total_cases + if total_cases > 0 + else 0, + }, + "analyzed_cases": analyzed_cases, + } + + # Log summary + logger.info("LLM Processing Analysis Summary:") + logger.info(f" - Total cases: {total_cases}") + logger.info(f" - Processing success rate: {processing_success_rate:.2%}") + logger.info(f" - Cases with improvement: {improved_cases} ({improvement_rate:.2%})") + logger.info(f" - Original answer rate: {original_answer_rate:.2%}") + logger.info(f" - Processed answer rate: {processed_answer_rate:.2%}") + logger.info( + f" - Answer rate improvement: {processed_answer_rate - original_answer_rate:.2%}" + ) + logger.info(f" - Memory reduction: {memory_reduction_rate:.2%}") + + # Save results if requested + if save_results: + if output_file is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_file = f"llm_processing_analysis_{timestamp}.json" + + try: + with open(output_file, "w", encoding="utf-8") as f: + json.dump(results, f, indent=2, ensure_ascii=False) + logger.info(f"Analysis results saved to: {output_file}") + except Exception as e: + logger.error(f"Failed to save results to {output_file}: {e}") + + return results + + +def main(): + """Main test function.""" + print("=== EvalAnalyzer Simple Test ===") + + # Initialize analyzer + analyzer = EvalAnalyzer(output_dir="./tmp/eval_analyzer") + + print("Analyzer initialized") + + # Test file paths + eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-xcy-1030-2114-locomo" + judged_file = os.path.join(eval_result_dir, "memos-api_locomo_judged.json") + search_results_file = os.path.join(eval_result_dir, "memos-api_locomo_search_results.json") + + print("Testing with files:") + print(f" Judged file: {judged_file}") + print(f" Search results file: {search_results_file}") + + # Check if files exist + if not os.path.exists(judged_file): + print(f"❌ Judged file not found: {judged_file}") + return + + if not os.path.exists(search_results_file): + print(f"❌ Search results file not found: {search_results_file}") + return + + print("✅ Both files exist") + + # Test bad case extraction only + try: + print("\n=== Testing Bad Case Extraction ===") + bad_cases = analyzer.extract_bad_cases(judged_file, search_results_file) + + print(f"✅ Successfully extracted {len(bad_cases)} bad cases") + + if bad_cases: + print("\n=== Sample Bad Cases ===") + for i, case in enumerate(bad_cases[:3]): # Show first 3 cases + print(f"\nBad Case {i + 1}:") + print(f" User ID: {case['user_id']}") + print(f" Query: {case['query'][:100]}...") + print(f" Golden Answer: {case['golden_answer']}...") + print(f" Answer: {case['answer']}...") + print(f" Has Memories: {len(case['memories']) > 0}") + print(f" Memory Length: {len(case['memories'])} chars") + + # Save basic results without LLM analysis + basic_results = { + "bad_cases_count": len(bad_cases), + "bad_cases": bad_cases, + "metadata": { + "eval_result_dir": eval_result_dir, + "judged_file": judged_file, + "search_results_file": search_results_file, + "extraction_only": True, + }, + } + + output_file = analyzer.output_dir / "bad_cases_extraction_only.json" + import json + + with open(output_file, "w", encoding="utf-8") as f: + json.dump(basic_results, f, indent=2, ensure_ascii=False) + + print(f"\n✅ Basic extraction results saved to: {output_file}") + + except Exception as e: + print(f"❌ Error during extraction: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/src/memos/mem_scheduler/analyzer/memory_processing.py b/src/memos/mem_scheduler/analyzer/memory_processing.py new file mode 100644 index 000000000..b692341c2 --- /dev/null +++ b/src/memos/mem_scheduler/analyzer/memory_processing.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +""" +Test script for memory processing functionality in eval_analyzer.py + +This script demonstrates how to use the new LLM memory processing features +to analyze and improve memory-based question answering. +""" + +import json +import os +import sys + +from pathlib import Path +from typing import Any + +from memos.log import get_logger +from memos.mem_scheduler.analyzer.eval_analyzer import EvalAnalyzer + + +FILE_PATH = Path(__file__).absolute() +BASE_DIR = FILE_PATH.parent # Go up to project root +sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory + + +logger = get_logger(__name__) + + +def create_sample_bad_cases() -> list[dict[str, Any]]: + """Create sample bad cases for testing memory processing.""" + return [ + { + "query": "What is the capital of France?", + "golden_answer": "Paris", + "memories": """ + Memory 1: France is a country in Western Europe. + Memory 2: The Eiffel Tower is located in Paris. + Memory 3: Paris is known for its art museums and fashion. + Memory 4: French cuisine is famous worldwide. + Memory 5: The Seine River flows through Paris. + """, + }, + { + "query": "When was the iPhone first released?", + "golden_answer": "June 29, 2007", + "memories": """ + Memory 1: Apple Inc. was founded by Steve Jobs, Steve Wozniak, and Ronald Wayne. + Memory 2: The iPhone was announced by Steve Jobs at the Macworld Conference & Expo on January 9, 2007. + Memory 3: The iPhone went on sale on June 29, 2007. + Memory 4: The original iPhone had a 3.5-inch screen. + Memory 5: Apple's stock price increased significantly after the iPhone launch. + """, + }, + { + "query": "What is photosynthesis?", + "golden_answer": "Photosynthesis is the process by which plants use sunlight, water, and carbon dioxide to produce glucose and oxygen.", + "memories": """ + Memory 1: Plants are living organisms that need sunlight to grow. + Memory 2: Chlorophyll is the green pigment in plants. + Memory 3: Plants take in carbon dioxide from the air. + Memory 4: Water is absorbed by plant roots from the soil. + Memory 5: Oxygen is released by plants during the day. + Memory 6: Glucose is a type of sugar that plants produce. + """, + }, + ] + + +def memory_processing(bad_cases): + """ + Test the memory processing functionality with cover rate and acc rate analysis. + + This function analyzes: + 1. Cover rate: Whether memories contain all information needed to answer the query + 2. Acc rate: Whether processed memories can correctly answer the query + """ + print("🧪 Testing Memory Processing Functionality with Cover Rate & Acc Rate Analysis") + print("=" * 80) + + # Initialize analyzer + analyzer = EvalAnalyzer() + + print(f"📊 Testing with {len(bad_cases)} sample cases") + print() + + # Initialize counters for real-time statistics + total_cases = 0 + cover_count = 0 # Cases where memories cover all needed information + acc_count = 0 # Cases where processed memories can correctly answer + + # Process each case + for i, case in enumerate(bad_cases): + total_cases += 1 + + # Safely handle query display + query_display = str(case.get("query", "Unknown query")) + print(f"🔍 Case {i + 1}/{len(bad_cases)}: {query_display}...") + + # Safely handle golden_answer display (convert to string if needed) + golden_answer = case.get("golden_answer", "Unknown answer") + golden_answer_str = str(golden_answer) if golden_answer is not None else "Unknown answer" + print(f"📝 Golden Answer: {golden_answer_str}") + print() + + # Step 1: Analyze if memories contain sufficient information (Cover Rate) + print(" 📋 Step 1: Analyzing memory coverage...") + coverage_analysis = analyzer.analyze_memory_sufficiency( + case["query"], + golden_answer_str, # Use the string version + case["memories"], + ) + + has_coverage = coverage_analysis.get("sufficient", False) + if has_coverage: + cover_count += 1 + + print(f" ✅ Memory Coverage: {'SUFFICIENT' if has_coverage else 'INSUFFICIENT'}") + print(f" 🎯 Confidence: {coverage_analysis.get('confidence', 0):.2f}") + print(f" 💭 Reasoning: {coverage_analysis.get('reasoning', 'N/A')}...") + if not has_coverage: + print( + f" ❌ Missing Info: {coverage_analysis.get('missing_information', 'N/A')[:100]}..." + ) + continue + print() + + # Step 2: Process memories and test answer ability (Acc Rate) + print(" 🔄 Step 2: Processing memories and testing answer ability...") + + processing_result = analyzer.scheduler_mem_process( + query=case["query"], + memories=case["memories"], + ) + print(f"Original Memories: {case['memories']}") + print(f"Processed Memories: {processing_result['processed_memories']}") + print(f" 📏 Compression ratio: {processing_result['compression_ratio']:.2f}") + print(f" 📄 Processed memories length: {processing_result['processed_length']} chars") + + # Generate answer with processed memories + answer_result = analyzer.generate_answer_with_memories( + case["query"], processing_result["processed_memories"], "processed_enhanced" + ) + + # Evaluate if the generated answer is correct + print(" 🎯 Step 3: Evaluating answer correctness...") + answer_evaluation = analyzer.compare_answer_quality( + case["query"], + golden_answer_str, # Use the string version + "No original answer available", # We don't have original answer + answer_result["answer"], + ) + + # Determine if processed memories can correctly answer (simplified logic) + processed_accuracy = answer_evaluation.get("processed_scores", {}).get("accuracy", 0) + can_answer_correctly = processed_accuracy >= 0.7 # Threshold for "correct" answer + + if can_answer_correctly: + acc_count += 1 + + print(f" 💬 Generated Answer: {answer_result['answer']}...") + print( + f" ✅ Answer Accuracy: {'CORRECT' if can_answer_correctly else 'INCORRECT'} (score: {processed_accuracy:.2f})" + ) + print() + + # Calculate and print real-time rates + current_cover_rate = cover_count / total_cases + current_acc_rate = acc_count / total_cases + + print(" 📊 REAL-TIME STATISTICS:") + print(f" 🎯 Cover Rate: {current_cover_rate:.2%} ({cover_count}/{total_cases})") + print(f" ✅ Acc Rate: {current_acc_rate:.2%} ({acc_count}/{total_cases})") + print() + + print("-" * 80) + print() + + # Final summary + print("🏁 FINAL ANALYSIS SUMMARY") + print("=" * 80) + print(f"📊 Total Cases Processed: {total_cases}") + print(f"🎯 Final Cover Rate: {cover_count / total_cases:.2%} ({cover_count}/{total_cases})") + print(f" - Cases with sufficient memory coverage: {cover_count}") + print(f" - Cases with insufficient memory coverage: {total_cases - cover_count}") + print() + print(f"✅ Final Acc Rate: {acc_count / total_cases:.2%} ({acc_count}/{total_cases})") + print(f" - Cases where processed memories can answer correctly: {acc_count}") + print(f" - Cases where processed memories cannot answer correctly: {total_cases - acc_count}") + print() + + # Additional insights + if cover_count > 0: + effective_processing_rate = acc_count / cover_count if cover_count > 0 else 0 + print(f"🔄 Processing Effectiveness: {effective_processing_rate:.2%}") + print( + f" - Among cases with sufficient coverage, {effective_processing_rate:.1%} can be answered correctly after processing" + ) + + print("=" * 80) + + +def load_real_bad_cases(file_path: str) -> list[dict[str, Any]]: + """Load real bad cases from JSON file.""" + print(f"📂 Loading bad cases from: {file_path}") + + with open(file_path, encoding="utf-8") as f: + data = json.load(f) + + bad_cases = data.get("bad_cases", []) + print(f"✅ Loaded {len(bad_cases)} bad cases") + + return bad_cases + + +def main(): + """Main test function.""" + print("🚀 Memory Processing Test Suite") + print("=" * 60) + print() + + # Check if OpenAI API key is set + if not os.getenv("OPENAI_API_KEY"): + print("⚠️ Warning: OPENAI_API_KEY not found in environment variables") + print(" Please set your OpenAI API key to run the tests") + return + + try: + bad_cases_file = f"{BASE_DIR}/tmp/eval_analyzer/bad_cases_extraction_only.json" + bad_cases = load_real_bad_cases(bad_cases_file) + + print(f"✅ Created {len(bad_cases)} sample bad cases") + print() + + # Run memory processing tests + memory_processing(bad_cases) + + print("✅ All tests completed successfully!") + + except Exception as e: + print(f"❌ Test failed with error: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + main() diff --git a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py index ace67eff6..03e1fc778 100644 --- a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +++ b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py @@ -427,7 +427,6 @@ def chat(self, query: str, user_id: str | None = None) -> str: message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=QUERY_LABEL, content=query, timestamp=datetime.now(), @@ -518,7 +517,6 @@ def chat(self, query: str, user_id: str | None = None) -> str: message_item = ScheduleMessageItem( user_id=target_user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, label=ANSWER_LABEL, content=response, timestamp=datetime.now(), diff --git a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py index 7c0fa5a4a..3d0235871 100644 --- a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py +++ b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py @@ -226,9 +226,9 @@ def evaluate_memory_answer_ability( try: # Extract JSON response - from memos.mem_scheduler.utils.misc_utils import extract_json_dict + from memos.mem_scheduler.utils.misc_utils import extract_json_obj - result = extract_json_dict(response) + result = extract_json_obj(response) # Validate response structure if "result" in result: diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 028fe8e3f..eb49d0238 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -1,6 +1,5 @@ import contextlib import multiprocessing -import queue import threading import time @@ -18,15 +17,18 @@ from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue +from memos.mem_scheduler.general_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_ACT_MEM_DUMP_PATH, + DEFAULT_CONSUME_BATCH, DEFAULT_CONSUME_INTERVAL_SECONDS, DEFAULT_CONTEXT_WINDOW_SIZE, DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE, + DEFAULT_MAX_WEB_LOG_QUEUE_SIZE, DEFAULT_STARTUP_MODE, DEFAULT_THREAD_POOL_MAX_WORKERS, DEFAULT_TOP_K, @@ -86,6 +88,22 @@ def __init__(self, config: BaseSchedulerConfig): "scheduler_startup_mode", DEFAULT_STARTUP_MODE ) + # message queue configuration + self.use_redis_queue = self.config.get("use_redis_queue", DEFAULT_USE_REDIS_QUEUE) + self.max_internal_message_queue_size = self.config.get( + "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE + ) + + # Initialize message queue based on configuration + if self.use_redis_queue: + self.memos_message_queue = SchedulerRedisQueue( + maxsize=self.max_internal_message_queue_size + ) + else: + self.memos_message_queue: Queue[ScheduleMessageItem] = Queue( + maxsize=self.max_internal_message_queue_size + ) + self.retriever: SchedulerRetriever | None = None self.db_engine: Engine | None = None self.monitor: SchedulerGeneralMonitor | None = None @@ -93,6 +111,8 @@ def __init__(self, config: BaseSchedulerConfig): self.mem_reader = None # Will be set by MOSCore self.dispatcher = SchedulerDispatcher( config=self.config, + memos_message_queue=self.memos_message_queue, + use_redis_queue=self.use_redis_queue, max_workers=self.thread_pool_max_workers, enable_parallel_dispatch=self.enable_parallel_dispatch, ) @@ -100,23 +120,9 @@ def __init__(self, config: BaseSchedulerConfig): # optional configs self.disable_handlers: list | None = self.config.get("disable_handlers", None) - # message queue configuration - self.use_redis_queue = self.config.get("use_redis_queue", DEFAULT_USE_REDIS_QUEUE) - self.max_internal_message_queue_size = self.config.get( - "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE + self.max_web_log_queue_size = self.config.get( + "max_web_log_queue_size", DEFAULT_MAX_WEB_LOG_QUEUE_SIZE ) - - # Initialize message queue based on configuration - if self.use_redis_queue: - self.memos_message_queue = None # Will use Redis instead - # Initialize Redis if using Redis queue with auto-initialization - self.auto_initialize_redis() - else: - self.memos_message_queue: Queue[ScheduleMessageItem] = Queue( - maxsize=self.max_internal_message_queue_size - ) - - self.max_web_log_queue_size = self.config.get("max_web_log_queue_size", 50) self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue( maxsize=self.max_web_log_queue_size ) @@ -126,6 +132,7 @@ def __init__(self, config: BaseSchedulerConfig): self._consume_interval = self.config.get( "consume_interval_seconds", DEFAULT_CONSUME_INTERVAL_SECONDS ) + self.consume_batch = self.config.get("consume_batch", DEFAULT_CONSUME_BATCH) # other attributes self._context_lock = threading.Lock() @@ -216,7 +223,7 @@ def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None: with self._context_lock: self.current_user_id = msg.user_id self.current_mem_cube_id = msg.mem_cube_id - self.current_mem_cube = msg.mem_cube + self.current_mem_cube = self.get_mem_cube(msg.mem_cube_id) def transform_working_memories_to_monitors( self, query_keywords, memories: list[TextualMemoryItem] @@ -533,17 +540,9 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt if self.disable_handlers and message.label in self.disable_handlers: logger.info(f"Skipping disabled handler: {message.label} - {message.content}") continue + self.memos_message_queue.put(message) + logger.info(f"Submitted message to local queue: {message.label} - {message.content}") - if self.use_redis_queue: - # Use Redis stream for message queue - self.redis_add_message_stream(message.to_dict()) - logger.info(f"Submitted message to Redis: {message.label} - {message.content}") - else: - # Use local queue - self.memos_message_queue.put(message) - logger.info( - f"Submitted message to local queue: {message.label} - {message.content}" - ) with contextlib.suppress(Exception): if messages: self.dispatcher.on_messages_enqueued(messages) @@ -590,7 +589,7 @@ def get_web_log_messages(self) -> list[dict]: try: item = self._web_log_message_queue.get_nowait() # Thread-safe get messages.append(item.to_dict()) - except queue.Empty: + except Exception: break return messages @@ -601,62 +600,28 @@ def _message_consumer(self) -> None: Runs in a dedicated thread to process messages at regular intervals. For Redis queue, this method starts the Redis listener. """ - if self.use_redis_queue: - # For Redis queue, start the Redis listener - def redis_message_handler(message_data): - """Handler for Redis messages""" - try: - # Redis message data needs to be decoded from bytes to string - decoded_data = {} - for key, value in message_data.items(): - if isinstance(key, bytes): - key = key.decode("utf-8") - if isinstance(value, bytes): - value = value.decode("utf-8") - decoded_data[key] = value - - message = ScheduleMessageItem.from_dict(decoded_data) - self.dispatcher.dispatch([message]) - except Exception as e: - logger.error(f"Error processing Redis message: {e}") - logger.error(f"Message data: {message_data}") - - self.redis_start_listening(handler=redis_message_handler) - - # Keep the thread alive while Redis listener is running - while self._running: - time.sleep(self._consume_interval) - else: - # Original local queue logic - while self._running: # Use a running flag for graceful shutdown - try: - # Get all available messages at once (thread-safe approach) - messages = [] - while True: - try: - # Use get_nowait() directly without empty() check to avoid race conditions - message = self.memos_message_queue.get_nowait() - messages.append(message) - except queue.Empty: - # No more messages available - break - if messages: - try: - self.dispatcher.dispatch(messages) - except Exception as e: - logger.error(f"Error dispatching messages: {e!s}") - finally: - # Mark all messages as processed - for _ in messages: - self.memos_message_queue.task_done() + # Original local queue logic + while self._running: # Use a running flag for graceful shutdown + try: + # Get messages in batches based on consume_batch setting + + messages = self.memos_message_queue.get(block=True, batch_size=self.consume_batch) + + if messages: + try: + self.dispatcher.dispatch(messages) + except Exception as e: + logger.error(f"Error dispatching messages: {e!s}") - # Sleep briefly to prevent busy waiting - time.sleep(self._consume_interval) # Adjust interval as needed + # Sleep briefly to prevent busy waiting + time.sleep(self._consume_interval) # Adjust interval as needed - except Exception as e: + except Exception as e: + # Don't log error for "No messages available in Redis queue" as it's expected + if "No messages available in Redis queue" not in str(e): logger.error(f"Unexpected error in message consumer: {e!s}") - time.sleep(self._consume_interval) # Prevent tight error loops + time.sleep(self._consume_interval) # Prevent tight error loops def start(self) -> None: """ @@ -666,16 +631,25 @@ def start(self) -> None: 1. Message consumer thread or process (based on startup_mode) 2. Dispatcher thread pool (if parallel dispatch enabled) """ - if self._running: - logger.warning("Memory Scheduler is already running") - return - # Initialize dispatcher resources if self.enable_parallel_dispatch: logger.info( f"Initializing dispatcher thread pool with {self.thread_pool_max_workers} workers" ) + self.start_consumer() + + def start_consumer(self) -> None: + """ + Start only the message consumer thread/process. + + This method can be used to restart the consumer after it has been stopped + with stop_consumer(), without affecting other scheduler components. + """ + if self._running: + logger.warning("Memory Scheduler consumer is already running") + return + # Start consumer based on startup mode self._running = True @@ -698,15 +672,15 @@ def start(self) -> None: self._consumer_thread.start() logger.info("Message consumer thread started") - def stop(self) -> None: - """Stop all scheduler components gracefully. + def stop_consumer(self) -> None: + """Stop only the message consumer thread/process gracefully. - 1. Stops message consumer thread/process - 2. Shuts down dispatcher thread pool - 3. Cleans up resources + This method stops the consumer without affecting other components like + dispatcher or monitors. Useful when you want to pause message processing + while keeping other scheduler components running. """ if not self._running: - logger.warning("Memory Scheduler is not running") + logger.warning("Memory Scheduler consumer is not running") return # Signal consumer thread/process to stop @@ -726,12 +700,30 @@ def stop(self) -> None: logger.info("Consumer process terminated") else: logger.info("Consumer process stopped") + self._consumer_process = None elif self._consumer_thread and self._consumer_thread.is_alive(): self._consumer_thread.join(timeout=5.0) if self._consumer_thread.is_alive(): logger.warning("Consumer thread did not stop gracefully") else: logger.info("Consumer thread stopped") + self._consumer_thread = None + + logger.info("Memory Scheduler consumer stopped") + + def stop(self) -> None: + """Stop all scheduler components gracefully. + + 1. Stops message consumer thread/process + 2. Shuts down dispatcher thread pool + 3. Cleans up resources + """ + if not self._running: + logger.warning("Memory Scheduler is not running") + return + + # Stop consumer first + self.stop_consumer() # Shutdown dispatcher if self.dispatcher: @@ -743,10 +735,6 @@ def stop(self) -> None: logger.info("Shutting down monitor...") self.dispatcher_monitor.stop() - # Clean up queues - self._cleanup_queues() - logger.info("Memory Scheduler stopped completely") - @property def handlers(self) -> dict[str, Callable]: """ @@ -819,30 +807,6 @@ def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, di return result - def _cleanup_queues(self) -> None: - """Ensure all queues are emptied and marked as closed.""" - if self.use_redis_queue: - # For Redis queue, stop the listener and close connection - try: - self.redis_stop_listening() - self.redis_close() - except Exception as e: - logger.error(f"Error cleaning up Redis connection: {e}") - else: - # Original local queue cleanup - try: - while not self.memos_message_queue.empty(): - self.memos_message_queue.get_nowait() - self.memos_message_queue.task_done() - except queue.Empty: - pass - - try: - while not self._web_log_message_queue.empty(): - self._web_log_message_queue.get_nowait() - except queue.Empty: - pass - def mem_scheduler_wait( self, timeout: float = 180.0, poll: float = 0.1, log_every: float = 0.01 ) -> bool: @@ -906,11 +870,24 @@ def _fmt_eta(seconds: float | None) -> str: st = ( stats_fn() ) # expected: {'pending':int,'running':int,'done':int?,'rate':float?} - pend = int(st.get("pending", 0)) run = int(st.get("running", 0)) + except Exception: pass + if isinstance(self.memos_message_queue, SchedulerRedisQueue): + # For Redis queue, prefer XINFO GROUPS to compute pending + groups_info = self.memos_message_queue.redis.xinfo_groups( + self.memos_message_queue.stream_name + ) + if groups_info: + for group in groups_info: + if group.get("name") == self.memos_message_queue.consumer_group: + pend = int(group.get("pending", pend)) + break + else: + pend = run + # 2) dynamic total (allows new tasks queued while waiting) total_now = max(init_unfinished, done_total + curr_unfinished) done_total = max(0, total_now - curr_unfinished) diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/general_modules/dispatcher.py index c2407b9e6..b74529c8c 100644 --- a/src/memos/mem_scheduler/general_modules/dispatcher.py +++ b/src/memos/mem_scheduler/general_modules/dispatcher.py @@ -10,7 +10,9 @@ from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule +from memos.mem_scheduler.general_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.general_modules.task_threads import ThreadManager +from memos.mem_scheduler.schemas.general_schemas import DEFAULT_STOP_WAIT from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem from memos.mem_scheduler.utils.metrics import MetricsRegistry @@ -32,13 +34,23 @@ class SchedulerDispatcher(BaseSchedulerModule): - Thread race competition for parallel task execution """ - def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None): + def __init__( + self, + max_workers: int = 30, + memos_message_queue: Any | None = None, + use_redis_queue: bool | None = None, + enable_parallel_dispatch: bool = True, + config=None, + ): super().__init__() self.config = config # Main dispatcher thread pool self.max_workers = max_workers + self.memos_message_queue = memos_message_queue + self.use_redis_queue = use_redis_queue + # Get multi-task timeout from config self.multi_task_running_timeout = ( self.config.get("multi_task_running_timeout") if self.config else None @@ -73,6 +85,11 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None): self._completed_tasks = [] self.completed_tasks_max_show_size = 10 + # Configure shutdown wait behavior from config or default + self.stop_wait = ( + self.config.get("stop_wait", DEFAULT_STOP_WAIT) if self.config else DEFAULT_STOP_WAIT + ) + self.metrics = MetricsRegistry( topk_per_label=(self.config or {}).get("metrics_topk_per_label", 50) ) @@ -131,6 +148,19 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): # --- mark done --- for m in messages: self.metrics.on_done(label=m.label, mem_cube_id=m.mem_cube_id, now=time.time()) + + # acknowledge redis messages + + if ( + self.use_redis_queue + and self.memos_message_queue is not None + and isinstance(self.memos_message_queue, SchedulerRedisQueue) + ): + for msg in messages: + redis_message_id = msg.redis_message_id + # Acknowledge message processing + self.memos_message_queue.ack_message(redis_message_id=redis_message_id) + # Mark task as completed and remove from tracking with self._task_lock: if task_item.item_id in self._running_tasks: @@ -138,7 +168,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): del self._running_tasks[task_item.item_id] self._completed_tasks.append(task_item) if len(self._completed_tasks) > self.completed_tasks_max_show_size: - self._completed_tasks[-self.completed_tasks_max_show_size :] + self._completed_tasks.pop(0) logger.info(f"Task completed: {task_item.get_execution_info()}") return result @@ -152,7 +182,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): task_item.mark_failed(str(e)) del self._running_tasks[task_item.item_id] if len(self._completed_tasks) > self.completed_tasks_max_show_size: - self._completed_tasks[-self.completed_tasks_max_show_size :] + self._completed_tasks.pop(0) logger.error(f"Task failed: {task_item.get_execution_info()}, Error: {e}") raise @@ -381,17 +411,13 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]): wrapped_handler = self._create_task_wrapper(handler, task_item) # dispatch to different handler - logger.debug( - f"Dispatch {len(msgs)} message(s) to {label} handler for user {user_id} and mem_cube {mem_cube_id}." - ) - logger.info(f"Task started: {task_item.get_execution_info()}") - + logger.debug(f"Task started: {task_item.get_execution_info()}") if self.enable_parallel_dispatch and self.dispatcher_executor is not None: # Capture variables in lambda to avoid loop variable issues - future = self.dispatcher_executor.submit(wrapped_handler, msgs) - self._futures.add(future) - future.add_done_callback(self._handle_future_result) - logger.info(f"Dispatched {len(msgs)} message(s) as future task") + _ = self.dispatcher_executor.submit(wrapped_handler, msgs) + logger.info( + f"Dispatch {len(msgs)} message(s) to {label} handler for user {user_id} and mem_cube {mem_cube_id}." + ) else: wrapped_handler(msgs) @@ -484,17 +510,9 @@ def shutdown(self) -> None: """Gracefully shutdown the dispatcher.""" self._running = False - if self.dispatcher_executor is not None: - # Cancel pending tasks - cancelled = 0 - for future in self._futures: - if future.cancel(): - cancelled += 1 - logger.info(f"Cancelled {cancelled}/{len(self._futures)} pending tasks") - # Shutdown executor try: - self.dispatcher_executor.shutdown(wait=True) + self.dispatcher_executor.shutdown(wait=self.stop_wait, cancel_futures=True) except Exception as e: logger.error(f"Executor shutdown error: {e}", exc_info=True) finally: diff --git a/src/memos/mem_scheduler/general_modules/misc.py b/src/memos/mem_scheduler/general_modules/misc.py index b6f48d043..e4e7edb89 100644 --- a/src/memos/mem_scheduler/general_modules/misc.py +++ b/src/memos/mem_scheduler/general_modules/misc.py @@ -199,6 +199,9 @@ class AutoDroppingQueue(Queue[T]): """A thread-safe queue that automatically drops the oldest item when full.""" def __init__(self, maxsize: int = 0): + # If maxsize <= 0, set to 0 (unlimited queue size) + if maxsize <= 0: + maxsize = 0 super().__init__(maxsize=maxsize) def put(self, item: T, block: bool = False, timeout: float | None = None) -> None: @@ -218,7 +221,7 @@ def put(self, item: T, block: bool = False, timeout: float | None = None) -> Non # First try non-blocking put super().put(item, block=block, timeout=timeout) except Full: - # Remove oldest item and mark it done to avoid leaking unfinished_tasks + # Remove the oldest item and mark it done to avoid leaking unfinished_tasks with suppress(Empty): _ = self.get_nowait() # If the removed item had previously incremented unfinished_tasks, @@ -228,12 +231,70 @@ def put(self, item: T, block: bool = False, timeout: float | None = None) -> Non # Retry putting the new item super().put(item, block=block, timeout=timeout) + def get( + self, block: bool = True, timeout: float | None = None, batch_size: int | None = None + ) -> list[T] | T: + """Get items from the queue. + + Args: + block: Whether to block if no items are available (default: True) + timeout: Timeout in seconds for blocking operations (default: None) + batch_size: Number of items to retrieve (default: 1) + + Returns: + List of items (always returns a list for consistency) + + Raises: + Empty: If no items are available and block=False or timeout expires + """ + + if batch_size is None: + return super().get(block=block, timeout=timeout) + items = [] + for _ in range(batch_size): + try: + items.append(super().get(block=block, timeout=timeout)) + except Empty: + if not items and block: + # If we haven't gotten any items and we're blocking, re-raise Empty + raise + break + return items + + def get_nowait(self, batch_size: int | None = None) -> list[T]: + """Get items from the queue without blocking. + + Args: + batch_size: Number of items to retrieve (default: 1) + + Returns: + List of items (always returns a list for consistency) + """ + if batch_size is None: + return super().get_nowait() + + items = [] + for _ in range(batch_size): + try: + items.append(super().get_nowait()) + except Empty: + break + return items + def get_queue_content_without_pop(self) -> list[T]: """Return a copy of the queue's contents without modifying it.""" # Ensure a consistent snapshot by holding the mutex with self.mutex: return list(self.queue) + def qsize(self) -> int: + """Return the approximate size of the queue. + + Returns: + Number of items currently in the queue + """ + return super().qsize() + def clear(self) -> None: """Remove all items from the queue. diff --git a/src/memos/mem_scheduler/general_modules/redis_queue.py b/src/memos/mem_scheduler/general_modules/redis_queue.py new file mode 100644 index 000000000..c10765d05 --- /dev/null +++ b/src/memos/mem_scheduler/general_modules/redis_queue.py @@ -0,0 +1,460 @@ +""" +Redis Queue implementation for SchedulerMessageItem objects. + +This module provides a Redis-based queue implementation that can replace +the local memos_message_queue functionality in BaseScheduler. +""" + +import time + +from collections.abc import Callable +from uuid import uuid4 + +from memos.log import get_logger +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule + + +logger = get_logger(__name__) + + +class SchedulerRedisQueue(RedisSchedulerModule): + """ + Redis-based queue for storing and processing SchedulerMessageItem objects. + + This class provides a Redis Stream-based implementation that can replace + the local memos_message_queue functionality, offering better scalability + and persistence for message processing. + + Inherits from RedisSchedulerModule to leverage existing Redis connection + and initialization functionality. + """ + + def __init__( + self, + stream_name: str = "scheduler:messages:stream", + consumer_group: str = "scheduler_group", + consumer_name: str | None = "scheduler_consumer", + max_len: int = 10000, + maxsize: int = 0, # For Queue compatibility + auto_delete_acked: bool = True, # Whether to automatically delete acknowledged messages + ): + """ + Initialize the Redis queue. + + Args: + stream_name: Name of the Redis stream + consumer_group: Name of the consumer group + consumer_name: Name of the consumer (auto-generated if None) + max_len: Maximum length of the stream (for memory management) + maxsize: Maximum size of the queue (for Queue compatibility, ignored) + auto_delete_acked: Whether to automatically delete acknowledged messages from stream + """ + super().__init__() + + # If maxsize <= 0, set to None (unlimited queue size) + if maxsize <= 0: + maxsize = 0 + + # Stream configuration + self.stream_name = stream_name + self.consumer_group = consumer_group + self.consumer_name = consumer_name or f"consumer_{uuid4().hex[:8]}" + self.max_len = max_len + self.maxsize = maxsize # For Queue compatibility + self.auto_delete_acked = auto_delete_acked # Whether to delete acknowledged messages + + # Consumer state + self._is_listening = False + self._message_handler: Callable[[ScheduleMessageItem], None] | None = None + + # Connection state + self._is_connected = False + + # Task tracking for mem_scheduler_wait compatibility + self._unfinished_tasks = 0 + + # Auto-initialize Redis connection + if self.auto_initialize_redis(): + self._is_connected = True + self._ensure_consumer_group() + + def _ensure_consumer_group(self) -> None: + """Ensure the consumer group exists for the stream.""" + if not self._redis_conn: + return + + try: + self._redis_conn.xgroup_create( + self.stream_name, self.consumer_group, id="0", mkstream=True + ) + logger.debug( + f"Created consumer group '{self.consumer_group}' for stream '{self.stream_name}'" + ) + except Exception as e: + # Check if it's a "consumer group already exists" error + error_msg = str(e).lower() + if "busygroup" in error_msg or "already exists" in error_msg: + logger.info( + f"Consumer group '{self.consumer_group}' already exists for stream '{self.stream_name}'" + ) + else: + logger.error(f"Error creating consumer group: {e}", exc_info=True) + + def put( + self, message: ScheduleMessageItem, block: bool = True, timeout: float | None = None + ) -> None: + """ + Add a message to the Redis queue (Queue-compatible interface). + + Args: + message: SchedulerMessageItem to add to the queue + block: Ignored for Redis implementation (always non-blocking) + timeout: Ignored for Redis implementation + + Raises: + ConnectionError: If not connected to Redis + TypeError: If message is not a ScheduleMessageItem + """ + if not self._redis_conn: + raise ConnectionError("Not connected to Redis. Redis connection not available.") + + if not isinstance(message, ScheduleMessageItem): + raise TypeError(f"Expected ScheduleMessageItem, got {type(message)}") + + try: + # Convert message to dictionary for Redis storage + message_data = message.to_dict() + + # Add to Redis stream with automatic trimming + message_id = self._redis_conn.xadd( + self.stream_name, message_data, maxlen=self.max_len, approximate=True + ) + + logger.info( + f"Added message {message_id} to Redis stream: {message.label} - {message.content[:100]}..." + ) + + except Exception as e: + logger.error(f"Failed to add message to Redis queue: {e}") + raise + + def put_nowait(self, message: ScheduleMessageItem) -> None: + """ + Add a message to the Redis queue without blocking (Queue-compatible interface). + + Args: + message: SchedulerMessageItem to add to the queue + """ + self.put(message, block=False) + + def ack_message(self, redis_message_id): + self.redis.xack(self.stream_name, self.consumer_group, redis_message_id) + + # Optionally delete the message from the stream to keep it clean + if self.auto_delete_acked: + try: + self._redis_conn.xdel(self.stream_name, redis_message_id) + logger.info(f"Successfully delete acknowledged message {redis_message_id}") + except Exception as e: + logger.warning(f"Failed to delete acknowledged message {redis_message_id}: {e}") + + def get( + self, + block: bool = True, + timeout: float | None = None, + batch_size: int | None = None, + ) -> list[ScheduleMessageItem]: + if not self._redis_conn: + raise ConnectionError("Not connected to Redis. Redis connection not available.") + + try: + # Calculate timeout for Redis + redis_timeout = None + if block and timeout is not None: + redis_timeout = int(timeout * 1000) + elif not block: + redis_timeout = None # Non-blocking + + # Read messages from the consumer group + try: + messages = self._redis_conn.xreadgroup( + self.consumer_group, + self.consumer_name, + {self.stream_name: ">"}, + count=batch_size if not batch_size else 1, + block=redis_timeout, + ) + except Exception as read_err: + # Handle missing group/stream by creating and retrying once + err_msg = str(read_err).lower() + if "nogroup" in err_msg or "no such key" in err_msg: + logger.warning( + f"Consumer group or stream missing for '{self.stream_name}/{self.consumer_group}'. Attempting to create and retry." + ) + messages = self._redis_conn.xreadgroup( + self.consumer_group, + self.consumer_name, + {self.stream_name: ">"}, + count=batch_size if not batch_size else 1, + block=redis_timeout, + ) + else: + raise + result_messages = [] + + for _stream, stream_messages in messages: + for message_id, fields in stream_messages: + try: + # Convert Redis message back to SchedulerMessageItem + message = ScheduleMessageItem.from_dict(fields) + message.redis_message_id = message_id + + result_messages.append(message) + + except Exception as e: + logger.error(f"Failed to parse message {message_id}: {e}") + + # Always return a list for consistency + if not result_messages: + if not block: + return [] # Return empty list for non-blocking calls + else: + # If no messages were found, raise Empty exception + from queue import Empty + + raise Empty("No messages available in Redis queue") + + return result_messages if batch_size is not None else result_messages[0] + + except Exception as e: + if "Empty" in str(type(e).__name__): + raise + logger.error(f"Failed to get message from Redis queue: {e}") + raise + + def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem]: + """ + Get messages from the Redis queue without blocking (Queue-compatible interface). + + Returns: + List of SchedulerMessageItem objects + + Raises: + Empty: If no message is available + """ + return self.get(block=False, batch_size=batch_size) + + def qsize(self) -> int: + """ + Get the current size of the Redis queue (Queue-compatible interface). + + Returns the number of pending (unacknowledged) messages in the consumer group, + which represents the actual queue size for processing. + + Returns: + Number of pending messages in the queue + """ + if not self._redis_conn: + return 0 + + try: + # Get pending messages info for the consumer group + # XPENDING returns info about pending messages that haven't been acknowledged + pending_info = self._redis_conn.xpending(self.stream_name, self.consumer_group) + + # pending_info[0] contains the count of pending messages + if pending_info and len(pending_info) > 0 and pending_info[0] is not None: + pending_count = int(pending_info[0]) + if pending_count > 0: + return pending_count + + # If no pending messages, check if there are new messages in the stream + # that haven't been read by any consumer yet + try: + # Get the last delivered ID for the consumer group + groups_info = self._redis_conn.xinfo_groups(self.stream_name) + if not groups_info: + # No groups exist, check total stream length + return self._redis_conn.xlen(self.stream_name) or 0 + + last_delivered_id = "0-0" + + for group_info in groups_info: + if group_info and group_info.get("name") == self.consumer_group: + last_delivered_id = group_info.get("last-delivered-id", "0-0") + break + + # Count messages after the last delivered ID + new_messages = self._redis_conn.xrange( + self.stream_name, + f"({last_delivered_id}", # Exclusive start + "+", # End at the latest message + count=1000, # Limit to avoid memory issues + ) + + return len(new_messages) if new_messages else 0 + + except Exception as inner_e: + logger.debug(f"Failed to get new messages count: {inner_e}") + # Fallback: return stream length + try: + stream_len = self._redis_conn.xlen(self.stream_name) + return stream_len if stream_len is not None else 0 + except Exception: + return 0 + + except Exception as e: + logger.debug(f"Failed to get Redis queue size via XPENDING: {e}") + # Fallback to stream length if pending check fails + try: + stream_len = self._redis_conn.xlen(self.stream_name) + return stream_len if stream_len is not None else 0 + except Exception as fallback_e: + logger.error(f"Failed to get Redis queue size (all methods failed): {fallback_e}") + return 0 + + def size(self) -> int: + """ + Get the current size of the Redis queue (alias for qsize). + + Returns: + Number of messages in the queue + """ + return self.qsize() + + def empty(self) -> bool: + """ + Check if the Redis queue is empty (Queue-compatible interface). + + Returns: + True if the queue is empty, False otherwise + """ + return self.qsize() == 0 + + def full(self) -> bool: + """ + Check if the Redis queue is full (Queue-compatible interface). + + For Redis streams, we consider the queue full if it exceeds maxsize. + If maxsize is 0 or None, the queue is never considered full. + + Returns: + True if the queue is full, False otherwise + """ + if self.maxsize <= 0: + return False + return self.qsize() >= self.maxsize + + def join(self) -> None: + """ + Block until all items in the queue have been gotten and processed (Queue-compatible interface). + + For Redis streams, this would require tracking pending messages, + which is complex. For now, this is a no-op. + """ + + def clear(self) -> None: + """Clear all messages from the queue.""" + if not self._is_connected or not self._redis_conn: + return + + try: + # Delete the entire stream + self._redis_conn.delete(self.stream_name) + logger.info(f"Cleared Redis stream: {self.stream_name}") + + # Recreate the consumer group + self._ensure_consumer_group() + except Exception as e: + logger.error(f"Failed to clear Redis queue: {e}") + + def start_listening( + self, + handler: Callable[[ScheduleMessageItem], None], + batch_size: int = 10, + poll_interval: float = 0.1, + ) -> None: + """ + Start listening for messages and process them with the provided handler. + + Args: + handler: Function to call for each received message + batch_size: Number of messages to process in each batch + poll_interval: Interval between polling attempts in seconds + """ + if not self._is_connected: + raise ConnectionError("Not connected to Redis. Call connect() first.") + + self._message_handler = handler + self._is_listening = True + + logger.info(f"Started listening on Redis stream: {self.stream_name}") + + try: + while self._is_listening: + messages = self.get(timeout=poll_interval, count=batch_size) + + for message in messages: + try: + self._message_handler(message) + except Exception as e: + logger.error(f"Error processing message {message.item_id}: {e}") + + # Small sleep to prevent excessive CPU usage + if not messages: + time.sleep(poll_interval) + + except KeyboardInterrupt: + logger.info("Received interrupt signal, stopping listener") + except Exception as e: + logger.error(f"Error in message listener: {e}") + finally: + self._is_listening = False + logger.info("Stopped listening for messages") + + def stop_listening(self) -> None: + """Stop the message listener.""" + self._is_listening = False + logger.info("Requested stop for message listener") + + def connect(self) -> None: + """Establish connection to Redis and set up the queue.""" + if self._redis_conn is not None: + try: + # Test the connection + self._redis_conn.ping() + self._is_connected = True + logger.debug("Redis connection established successfully") + except Exception as e: + logger.error(f"Failed to connect to Redis: {e}") + self._is_connected = False + else: + logger.error("Redis connection not initialized") + self._is_connected = False + + def disconnect(self) -> None: + """Disconnect from Redis and clean up resources.""" + self._is_connected = False + if self._is_listening: + self.stop_listening() + logger.debug("Disconnected from Redis") + + def __enter__(self): + """Context manager entry.""" + self.connect() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit.""" + self.stop_listening() + self.disconnect() + + def __del__(self): + """Cleanup when object is destroyed.""" + if self._is_connected: + self.disconnect() + + @property + def unfinished_tasks(self) -> int: + return self.qsize() diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 6840adc2b..32fefce63 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -51,7 +51,7 @@ def __init__(self, config: GeneralSchedulerConfig): def long_memory_update_process( self, user_id: str, mem_cube_id: str, messages: list[ScheduleMessageItem] ): - mem_cube = messages[0].mem_cube + mem_cube = self.current_mem_cube # for status update self._set_current_context_from_message(msg=messages[0]) @@ -140,7 +140,7 @@ def long_memory_update_process( label=QUERY_LABEL, user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=messages[0].mem_cube, + mem_cube=self.current_mem_cube, ) def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: @@ -212,7 +212,7 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: logger.error(f"Error: {e}. Content: {msg.content}", exc_info=True) userinput_memory_ids = [] - mem_cube = msg.mem_cube + mem_cube = self.current_mem_cube for memory_id in userinput_memory_ids: try: mem_item: TextualMemoryItem = mem_cube.text_mem.get( @@ -234,7 +234,7 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: memory_type=mem_type, user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, - mem_cube=msg.mem_cube, + mem_cube=self.current_mem_cube, log_func_callback=self._submit_web_logs, ) @@ -248,7 +248,7 @@ def process_message(message: ScheduleMessageItem): try: user_id = message.user_id mem_cube_id = message.mem_cube_id - mem_cube = message.mem_cube + mem_cube = self.current_mem_cube content = message.content user_name = message.user_name @@ -412,7 +412,7 @@ def process_message(message: ScheduleMessageItem): try: user_id = message.user_id mem_cube_id = message.mem_cube_id - mem_cube = message.mem_cube + mem_cube = self.current_mem_cube content = message.content user_name = message.user_name @@ -516,7 +516,7 @@ def process_message(message: ScheduleMessageItem): user_id = message.user_id session_id = message.session_id mem_cube_id = message.mem_cube_id - mem_cube = message.mem_cube + mem_cube = self.current_mem_cube content = message.content messages_list = json.loads(content) diff --git a/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py b/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py index e18c6e51a..25b9a98f3 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py +++ b/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py @@ -2,7 +2,7 @@ from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule -from memos.mem_scheduler.utils.misc_utils import extract_json_dict +from memos.mem_scheduler.utils.misc_utils import extract_json_obj from memos.memories.textual.tree import TextualMemoryItem @@ -66,7 +66,7 @@ def filter_unrelated_memories( try: # Parse JSON response - response = extract_json_dict(response) + response = extract_json_obj(response) logger.debug(f"Parsed JSON response: {response}") relevant_indices = response["relevant_memories"] filtered_count = response["filtered_count"] @@ -164,7 +164,7 @@ def filter_redundant_memories( try: # Parse JSON response - response = extract_json_dict(response) + response = extract_json_obj(response) logger.debug(f"Parsed JSON response: {response}") kept_indices = response["kept_memories"] redundant_groups = response.get("redundant_groups", []) @@ -226,8 +226,6 @@ def filter_unrelated_and_redundant_memories( Note: If LLM filtering fails, returns all memories (conservative approach) """ - success_flag = False - if not memories: logger.info("No memories to filter for unrelated and redundant - returning empty list") return [], True @@ -265,7 +263,7 @@ def filter_unrelated_and_redundant_memories( try: # Parse JSON response - response = extract_json_dict(response) + response = extract_json_obj(response) logger.debug(f"Parsed JSON response: {response}") kept_indices = response["kept_memories"] unrelated_removed_count = response.get("unrelated_removed_count", 0) diff --git a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py index b766f0010..848b1d257 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py +++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py @@ -1,9 +1,14 @@ +from concurrent.futures import as_completed + from memos.configs.mem_scheduler import BaseSchedulerConfig +from memos.context.context import ContextThreadPoolExecutor from memos.llms.base import BaseLLM from memos.log import get_logger from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.general_modules.base import BaseSchedulerModule from memos.mem_scheduler.schemas.general_schemas import ( + DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE, + DEFAULT_SCHEDULER_RETRIEVER_RETRIES, TreeTextMemory_FINE_SEARCH_METHOD, TreeTextMemory_SEARCH_METHOD, ) @@ -12,11 +17,11 @@ filter_vector_based_similar_memories, transform_name_to_key, ) -from memos.mem_scheduler.utils.misc_utils import ( - extract_json_dict, -) +from memos.mem_scheduler.utils.misc_utils import extract_json_obj, extract_list_items_in_answer +from memos.memories.textual.item import TextualMemoryMetadata from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +# Extract JSON response from .memory_filter import MemoryFilter @@ -30,12 +35,213 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig): # hyper-parameters self.filter_similarity_threshold = 0.75 self.filter_min_length_threshold = 6 - - self.config: BaseSchedulerConfig = config + self.memory_filter = MemoryFilter(process_llm=process_llm, config=config) self.process_llm = process_llm + self.config = config - # Initialize memory filter - self.memory_filter = MemoryFilter(process_llm=process_llm, config=config) + # Configure enhancement batching & retries from config with safe defaults + self.batch_size: int | None = getattr( + config, "scheduler_retriever_batch_size", DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE + ) + self.retries: int = getattr( + config, "scheduler_retriever_enhance_retries", DEFAULT_SCHEDULER_RETRIEVER_RETRIES + ) + + def evaluate_memory_answer_ability( + self, query: str, memory_texts: list[str], top_k: int | None = None + ) -> bool: + limited_memories = memory_texts[:top_k] if top_k is not None else memory_texts + # Build prompt using the template + prompt = self.build_prompt( + template_name="memory_answer_ability_evaluation", + query=query, + memory_list="\n".join([f"- {memory}" for memory in limited_memories]) + if limited_memories + else "No memories available", + ) + + # Use the process LLM to generate response + response = self.process_llm.generate([{"role": "user", "content": prompt}]) + + try: + result = extract_json_obj(response) + + # Validate response structure + if "result" in result: + logger.info( + f"Answerability: result={result['result']}; reason={result.get('reason', 'n/a')}; evaluated={len(limited_memories)}" + ) + return result["result"] + else: + logger.warning(f"Answerability: invalid LLM JSON structure; payload={result}") + return False + + except Exception as e: + logger.error(f"Answerability: parse failed; err={e}; raw={str(response)[:200]}...") + # Fallback: return False if we can't determine answer ability + return False + + # ---------------------- Enhancement helpers ---------------------- + def _build_enhancement_prompt(self, query_history: list[str], batch_texts: list[str]) -> str: + if len(query_history) == 1: + query_history = query_history[0] + else: + query_history = ( + [f"[{i}] {query}" for i, query in enumerate(query_history)] + if len(query_history) > 1 + else query_history[0] + ) + text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(batch_texts)]) + return self.build_prompt( + "memory_enhancement", + query_history=query_history, + memories=text_memories, + ) + + def _process_enhancement_batch( + self, + batch_index: int, + query_history: list[str], + memories: list[TextualMemoryItem], + retries: int, + ) -> tuple[list[TextualMemoryItem], bool]: + attempt = 0 + text_memories = [one.memory for one in memories] + while attempt <= max(0, retries) + 1: + try: + prompt = self._build_enhancement_prompt( + query_history=query_history, batch_texts=text_memories + ) + logger.debug( + f"[Enhance][batch={batch_index}] Prompt (first 200 chars, len={len(prompt)}): " + f"{prompt[:200]}]..." + ) + + response = self.process_llm.generate([{"role": "user", "content": prompt}]) + logger.debug( + f"[Enhance][batch={batch_index}] Response (first 200 chars): {response}..." + ) + + processed_text_memories = extract_list_items_in_answer(response) + if len(processed_text_memories) == len(memories): + # Update + for i, new_mem in enumerate(processed_text_memories): + memories[i].memory = new_mem + enhanced_memories = memories + else: + # create new + enhanced_memories = [] + user_id = memories[0].metadata.user_id + for new_mem in processed_text_memories: + enhanced_memories.append( + TextualMemoryItem( + memory=new_mem, metadata=TextualMemoryMetadata(user_id=user_id) + ) + ) + enhanced_memories = ( + enhanced_memories + memories[: len(memories) - len(enhanced_memories)] + ) + + logger.info( + f"[Enhance]: processed_text_memories: {len(processed_text_memories)}; padded with original memories to preserve total count" + ) + + return enhanced_memories, True + except Exception as e: + attempt += 1 + logger.debug( + f"[Enhance][batch={batch_index}] 🔁 retry {attempt}/{max(1, retries) + 1} failed: {e}" + ) + logger.error( + f"Fail to run memory enhancement; original memories: {memories}", exc_info=True + ) + return memories, False + + @staticmethod + def _split_batches( + memories: list[TextualMemoryItem], batch_size: int + ) -> list[tuple[int, int, list[TextualMemoryItem]]]: + batches: list[tuple[int, int, list[TextualMemoryItem]]] = [] + start = 0 + n = len(memories) + while start < n: + end = min(start + batch_size, n) + batches.append((start, end, memories[start:end])) + start = end + return batches + + def enhance_memories_with_query( + self, + query_history: list[str], + memories: list[TextualMemoryItem], + ) -> (list[TextualMemoryItem], bool): + """ + Enhance memories by adding context and making connections to better answer queries. + + Args: + query_history: List of user queries in chronological order + memories: List of memory items to enhance + + Returns: + Tuple of (enhanced_memories, success_flag) + """ + if not memories: + logger.warning("[Enhance] ⚠️ skipped (no memories to process)") + return memories, True + + batch_size = self.batch_size + retries = self.retries + num_of_memories = len(memories) + try: + # no parallel + if batch_size is None or num_of_memories <= batch_size: + # Single batch path with retry + enhanced_memories, success_flag = self._process_enhancement_batch( + batch_index=0, + query_history=query_history, + memories=memories, + retries=retries, + ) + + all_success = success_flag + else: + # parallel running batches + # Split into batches preserving order + batches = self._split_batches(memories=memories, batch_size=batch_size) + + # Process batches concurrently + all_success = True + failed_batches = 0 + with ContextThreadPoolExecutor(max_workers=len(batches)) as executor: + future_map = { + executor.submit( + self._process_enhancement_batch, bi, query_history, texts, retries + ): (bi, s, e) + for bi, (s, e, texts) in enumerate(batches) + } + enhanced_memories = [] + for fut in as_completed(future_map): + bi, s, e = future_map[fut] + + batch_memories, ok = fut.result() + enhanced_memories.extend(batch_memories) + if not ok: + all_success = False + failed_batches += 1 + logger.info( + f"[Enhance] ✅ multi-batch done | batches={len(batches)} | enhanced={len(enhanced_memories)} |" + f" failed_batches={failed_batches} | success={all_success}" + ) + + except Exception as e: + logger.error(f"[Enhance] ❌ fatal error: {e}", exc_info=True) + all_success = False + enhanced_memories = memories + + if len(enhanced_memories) == 0: + enhanced_memories = memories + logger.error("[Enhance] ❌ fatal error: enhanced_memories is empty", exc_info=True) + return enhanced_memories, all_success def search( self, @@ -115,7 +321,7 @@ def rerank_memories( try: # Parse JSON response - response = extract_json_dict(response) + response = extract_json_obj(response) new_order = response["new_order"][:top_k] text_memories_with_new_order = [original_memories[idx] for idx in new_order] logger.info( diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py index 46c4e2d49..99982d2e6 100644 --- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py +++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py @@ -11,6 +11,7 @@ from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL, DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES, + DEFAULT_STOP_WAIT, DEFAULT_STUCK_THREAD_TOLERANCE, ) from memos.mem_scheduler.utils.db_utils import get_utc_now @@ -46,6 +47,11 @@ def __init__(self, config: BaseSchedulerConfig): self.dispatcher: SchedulerDispatcher | None = None self.dispatcher_pool_name = "dispatcher" + # Configure shutdown wait behavior from config or default + self.stop_wait = ( + self.config.get("stop_wait", DEFAULT_STOP_WAIT) if self.config else DEFAULT_STOP_WAIT + ) + def initialize(self, dispatcher: SchedulerDispatcher): self.dispatcher = dispatcher self.register_pool( @@ -367,12 +373,9 @@ def stop(self) -> None: if not executor._shutdown: # pylint: disable=protected-access try: logger.info(f"Shutting down thread pool '{name}'") - executor.shutdown(wait=True, cancel_futures=True) + executor.shutdown(wait=self.stop_wait, cancel_futures=True) logger.info(f"Successfully shut down thread pool '{name}'") except Exception as e: logger.error(f"Error shutting down pool '{name}': {e!s}", exc_info=True) - # Clear the pool registry - self._pools.clear() - logger.info("Thread pool monitor and all pools stopped") diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py index a789d581e..a5f1c0097 100644 --- a/src/memos/mem_scheduler/monitors/general_monitor.py +++ b/src/memos/mem_scheduler/monitors/general_monitor.py @@ -29,7 +29,7 @@ QueryMonitorQueue, ) from memos.mem_scheduler.utils.db_utils import get_utc_now -from memos.mem_scheduler.utils.misc_utils import extract_json_dict +from memos.mem_scheduler.utils.misc_utils import extract_json_obj from memos.memories.textual.tree import TreeTextMemory @@ -92,7 +92,7 @@ def extract_query_keywords(self, query: str) -> list: llm_response = self._process_llm.generate([{"role": "user", "content": prompt}]) try: # Parse JSON output from LLM response - keywords = extract_json_dict(llm_response) + keywords = extract_json_obj(llm_response) assert isinstance(keywords, list) except Exception as e: logger.error( @@ -206,7 +206,7 @@ def update_working_memory_monitors( self.working_mem_monitor_capacity = min( DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT, ( - text_mem_base.memory_manager.memory_size["WorkingMemory"] + int(text_mem_base.memory_manager.memory_size["WorkingMemory"]) + self.partial_retention_number ), ) @@ -353,7 +353,7 @@ def detect_intent( ) response = self._process_llm.generate([{"role": "user", "content": prompt}]) try: - response = extract_json_dict(response) + response = extract_json_obj(response) assert ("trigger_retrieval" in response) and ("missing_evidences" in response) except Exception: logger.error(f"Fail to extract json dict from response: {response}") diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index a087ab2df..b62b1e51d 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -2,7 +2,7 @@ import os from collections import OrderedDict -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from memos.api.product_models import APISearchRequest from memos.configs.mem_scheduler import GeneralSchedulerConfig @@ -52,11 +52,24 @@ def __init__(self, config: GeneralSchedulerConfig): API_MIX_SEARCH_LABEL: self._api_mix_search_message_consumer, } ) + self.searcher = None + self.reranker = None + self.text_mem = None + + def init_mem_cube(self, mem_cube): + self.current_mem_cube = mem_cube + self.text_mem: TreeTextMemory = self.current_mem_cube.text_mem + self.searcher: Searcher = self.text_mem.get_searcher( + manual_close_internet=False, + moscube=False, + ) + self.reranker: HTTPBGEReranker = self.text_mem.reranker def submit_memory_history_async_task( self, search_req: APISearchRequest, user_context: UserContext, + memories_to_store: dict | None = None, session_id: str | None = None, ): # Create message for async fine search @@ -71,19 +84,16 @@ def submit_memory_history_async_task( "chat_history": search_req.chat_history, }, "user_context": {"mem_cube_id": user_context.mem_cube_id}, + "memories_to_store": memories_to_store, } async_task_id = f"mix_search_{search_req.user_id}_{get_utc_now().timestamp()}" - # Get mem_cube for the message - mem_cube = self.current_mem_cube - message = ScheduleMessageItem( item_id=async_task_id, user_id=search_req.user_id, mem_cube_id=user_context.mem_cube_id, label=API_MIX_SEARCH_LABEL, - mem_cube=mem_cube, content=json.dumps(message_content), timestamp=get_utc_now(), ) @@ -127,33 +137,26 @@ def mix_search_memories( self, search_req: APISearchRequest, user_context: UserContext, - ): + ) -> list[dict[str, Any]]: """ Mix search memories: fast search + async fine search """ # Get mem_cube for fast search - mem_cube = self.current_mem_cube - target_session_id = search_req.session_id if not target_session_id: target_session_id = "default_session" search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - text_mem: TreeTextMemory = mem_cube.text_mem - searcher: Searcher = text_mem.get_searcher( - manual_close_internet=not search_req.internet_search, - moscube=False, - ) # Rerank Memories - reranker expects TextualMemoryItem objects - reranker: HTTPBGEReranker = text_mem.reranker + info = { "user_id": search_req.user_id, "session_id": target_session_id, "chat_history": search_req.chat_history, } - fast_retrieved_memories = searcher.retrieve( + fast_retrieved_memories = self.searcher.retrieve( query=search_req.query, user_name=user_context.mem_cube_id, top_k=search_req.top_k, @@ -164,13 +167,7 @@ def mix_search_memories( info=info, ) - self.submit_memory_history_async_task( - search_req=search_req, - user_context=user_context, - session_id=search_req.session_id, - ) - - # Try to get pre-computed fine memories if available + # Try to get pre-computed memories if available history_memories = self.api_module.get_history_memories( user_id=search_req.user_id, mem_cube_id=user_context.mem_cube_id, @@ -178,7 +175,7 @@ def mix_search_memories( ) if not history_memories: - fast_memories = searcher.post_retrieve( + fast_memories = self.searcher.post_retrieve( retrieved_results=fast_retrieved_memories, top_k=search_req.top_k, user_name=user_context.mem_cube_id, @@ -187,39 +184,72 @@ def mix_search_memories( # Format fast memories for return formatted_memories = [format_textual_memory_item(data) for data in fast_memories] return formatted_memories + else: + # if history memories can directly answer + sorted_history_memories = self.reranker.rerank( + query=search_req.query, # Use search_req.query instead of undefined query + graph_results=history_memories, # Pass TextualMemoryItem objects directly + top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k + search_filter=search_filter, + ) - sorted_history_memories = reranker.rerank( - query=search_req.query, # Use search_req.query instead of undefined query - graph_results=history_memories, # Pass TextualMemoryItem objects directly - top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k - search_filter=search_filter, - ) + processed_hist_mem = self.searcher.post_retrieve( + retrieved_results=sorted_history_memories, + top_k=search_req.top_k, + user_name=user_context.mem_cube_id, + info=info, + ) - sorted_results = fast_retrieved_memories + sorted_history_memories - final_results = searcher.post_retrieve( - retrieved_results=sorted_results, - top_k=search_req.top_k, - user_name=user_context.mem_cube_id, - info=info, - ) + can_answer = self.retriever.evaluate_memory_answer_ability( + query=search_req.query, memory_texts=[one.memory for one in processed_hist_mem] + ) - formatted_memories = [ - format_textual_memory_item(item) for item in final_results[: search_req.top_k] - ] + if can_answer: + sorted_results = fast_retrieved_memories + sorted_history_memories + combined_results = self.searcher.post_retrieve( + retrieved_results=sorted_results, + top_k=search_req.top_k, + user_name=user_context.mem_cube_id, + info=info, + ) + memories = combined_results[: search_req.top_k] + formatted_memories = [format_textual_memory_item(item) for item in memories] + logger.info("can_answer") + else: + sorted_results = fast_retrieved_memories + sorted_history_memories + combined_results = self.searcher.post_retrieve( + retrieved_results=sorted_results, + top_k=search_req.top_k, + user_name=user_context.mem_cube_id, + info=info, + ) + enhanced_results, _ = self.retriever.enhance_memories_with_query( + query_history=[search_req.query], + memories=combined_results, + ) + memories = enhanced_results[: search_req.top_k] + formatted_memories = [format_textual_memory_item(item) for item in memories] + logger.info("cannot answer") + + self.submit_memory_history_async_task( + search_req=search_req, + user_context=user_context, + memories_to_store={ + "memories": [one.to_dict() for one in memories], + "formatted_memories": formatted_memories, + }, + ) - return formatted_memories + return formatted_memories def update_search_memories_to_redis( self, messages: list[ScheduleMessageItem], ): - mem_cube: NaiveMemCube = self.current_mem_cube - for msg in messages: content_dict = json.loads(msg.content) search_req = content_dict["search_req"] user_context = content_dict["user_context"] - session_id = search_req.get("session_id") if session_id: if session_id not in self.session_counter: @@ -237,13 +267,20 @@ def update_search_memories_to_redis( else: session_turn = 0 - memories: list[TextualMemoryItem] = self.search_memories( - search_req=APISearchRequest(**content_dict["search_req"]), - user_context=UserContext(**content_dict["user_context"]), - mem_cube=mem_cube, - mode=SearchMode.FAST, - ) - formatted_memories = [format_textual_memory_item(data) for data in memories] + memories_to_store = content_dict["memories_to_store"] + if memories_to_store is None: + memories: list[TextualMemoryItem] = self.search_memories( + search_req=APISearchRequest(**content_dict["search_req"]), + user_context=UserContext(**content_dict["user_context"]), + mem_cube=self.current_mem_cube, + mode=SearchMode.FAST, + ) + formatted_memories = [format_textual_memory_item(data) for data in memories] + else: + memories = [ + TextualMemoryItem.from_dict(one) for one in memories_to_store["memories"] + ] + formatted_memories = memories_to_store["formatted_memories"] # Sync search data to Redis self.api_module.sync_search_data( diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index f3d2191f8..7f2c09b7d 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -6,6 +6,7 @@ class SearchMode(str, Enum): """Enumeration for search modes.""" + NOT_INITIALIZED = "not_initialized" FAST = "fast" FINE = "fine" MIXTURE = "mixture" @@ -32,14 +33,18 @@ class SearchMode(str, Enum): DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache" DEFAULT_THREAD_POOL_MAX_WORKERS = 50 DEFAULT_CONSUME_INTERVAL_SECONDS = 0.05 +DEFAULT_CONSUME_BATCH = 1 DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300 DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2 DEFAULT_STUCK_THREAD_TOLERANCE = 10 -DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 1000000 +DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 0 DEFAULT_TOP_K = 10 DEFAULT_CONTEXT_WINDOW_SIZE = 5 -DEFAULT_USE_REDIS_QUEUE = False +DEFAULT_USE_REDIS_QUEUE = True DEFAULT_MULTI_TASK_RUNNING_TIMEOUT = 30 +DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE = 10 +DEFAULT_SCHEDULER_RETRIEVER_RETRIES = 1 +DEFAULT_STOP_WAIT = False # startup mode configuration STARTUP_BY_THREAD = "thread" @@ -64,6 +69,7 @@ class SearchMode(str, Enum): MONITOR_ACTIVATION_MEMORY_TYPE = "MonitorActivationMemoryType" DEFAULT_MAX_QUERY_KEY_WORDS = 1000 DEFAULT_WEIGHT_VECTOR_FOR_RANKING = [0.9, 0.05, 0.05] +DEFAULT_MAX_WEB_LOG_QUEUE_SIZE = 50 # new types diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py index 7f328474f..f1d48f3f1 100644 --- a/src/memos/mem_scheduler/schemas/message_schemas.py +++ b/src/memos/mem_scheduler/schemas/message_schemas.py @@ -2,11 +2,10 @@ from typing import Any from uuid import uuid4 -from pydantic import BaseModel, ConfigDict, Field, field_serializer +from pydantic import BaseModel, ConfigDict, Field from typing_extensions import TypedDict from memos.log import get_logger -from memos.mem_cube.base import BaseMemCube from memos.mem_scheduler.general_modules.misc import DictConversionMixin from memos.mem_scheduler.utils.db_utils import get_utc_now @@ -34,22 +33,19 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): item_id: str = Field(description="uuid", default_factory=lambda: str(uuid4())) + redis_message_id: str = Field(default="", description="the message get from redis stream") user_id: str = Field(..., description="user id") mem_cube_id: str = Field(..., description="memcube id") + session_id: str = Field(default="", description="Session ID for soft-filtering memories") label: str = Field(..., description="Label of the schedule message") - mem_cube: BaseMemCube | str = Field(..., description="memcube for schedule") content: str = Field(..., description="Content of the schedule message") timestamp: datetime = Field( default_factory=get_utc_now, description="submit time for schedule_messages" ) - user_name: str | None = Field( - default=None, + user_name: str = Field( + default="", description="user name / display name (optional)", ) - session_id: str | None = Field( - default=None, - description="session_id (optional)", - ) # Pydantic V2 model configuration model_config = ConfigDict( @@ -65,7 +61,6 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): "user_id": "user123", # Example user identifier "mem_cube_id": "cube456", # Sample memory cube ID "label": "sample_label", # Demonstration label value - "mem_cube": "obj of GeneralMemCube", # Added mem_cube example "content": "sample content", # Example message content "timestamp": "2024-07-22T12:00:00Z", # Added timestamp example "user_name": "Alice", # Added username example @@ -73,13 +68,6 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin): }, ) - @field_serializer("mem_cube") - def serialize_mem_cube(self, cube: BaseMemCube | str, _info) -> str: - """Custom serializer for BaseMemCube objects to string representation""" - if isinstance(cube, str): - return cube - return f"<{type(cube).__name__}:{id(cube)}>" - def to_dict(self) -> dict: """Convert model to dictionary suitable for Redis Stream""" return { @@ -101,7 +89,6 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem": user_id=data["user_id"], mem_cube_id=data["cube_id"], label=data["label"], - mem_cube="Not Applicable", # Custom cube deserialization content=data["content"], timestamp=datetime.fromisoformat(data["timestamp"]), user_name=data.get("user_name"), diff --git a/src/memos/mem_scheduler/utils/metrics.py b/src/memos/mem_scheduler/utils/metrics.py index 5155c98b3..45abc5b36 100644 --- a/src/memos/mem_scheduler/utils/metrics.py +++ b/src/memos/mem_scheduler/utils/metrics.py @@ -6,10 +6,14 @@ from dataclasses import dataclass, field +from memos.log import get_logger + # ==== global window config ==== WINDOW_SEC = 120 # 2 minutes sliding window +logger = get_logger(__name__) + # ---------- O(1) EWMA ---------- class Ewma: @@ -187,7 +191,7 @@ def on_enqueue( old_lam = ls.lambda_ewma.value_at(now) ls.lambda_ewma.update(inst_rate, now) new_lam = ls.lambda_ewma.value_at(now) - print( + logger.info( f"[DEBUG enqueue] {label} backlog={ls.backlog} dt={dt if dt is not None else '—'}s inst={inst_rate:.3f} λ {old_lam:.3f}→{new_lam:.3f}" ) self._label_topk[label].add(mem_cube_id) @@ -225,7 +229,7 @@ def on_done( old_mu = ls.mu_ewma.value_at(now) ls.mu_ewma.update(inst_rate, now) new_mu = ls.mu_ewma.value_at(now) - print( + logger.info( f"[DEBUG done] {label} backlog={ls.backlog} dt={dt if dt is not None else '—'}s inst={inst_rate:.3f} μ {old_mu:.3f}→{new_mu:.3f}" ) ds = self._detail_stats.get((label, mem_cube_id)) diff --git a/src/memos/mem_scheduler/utils/misc_utils.py b/src/memos/mem_scheduler/utils/misc_utils.py index aa9b5c489..cce1286bb 100644 --- a/src/memos/mem_scheduler/utils/misc_utils.py +++ b/src/memos/mem_scheduler/utils/misc_utils.py @@ -1,5 +1,6 @@ import json import re +import traceback from functools import wraps from pathlib import Path @@ -12,7 +13,7 @@ logger = get_logger(__name__) -def extract_json_dict(text: str): +def extract_json_obj(text: str): """ Safely extracts JSON from LLM response text with robust error handling. @@ -40,7 +41,7 @@ def extract_json_dict(text: str): try: return json.loads(text.strip()) except json.JSONDecodeError as e: - logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True) + logger.info(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True) # Fallback 1: Extract JSON using regex json_pattern = r"\{[\s\S]*\}|\[[\s\S]*\]" @@ -49,7 +50,7 @@ def extract_json_dict(text: str): try: return json.loads(matches[0]) except json.JSONDecodeError as e: - logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True) + logger.info(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True) # Fallback 2: Handle malformed JSON (common LLM issues) try: @@ -57,10 +58,125 @@ def extract_json_dict(text: str): text = re.sub(r"([\{\s,])(\w+)(:)", r'\1"\2"\3', text) return json.loads(text) except json.JSONDecodeError as e: - logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True) + logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}") + logger.error("Full traceback:\n" + traceback.format_exc()) raise ValueError(text) from e +def extract_list_items(text: str, bullet_prefixes: tuple[str, ...] = ("- ",)) -> list[str]: + """ + Extract bullet list items from LLM output where each item is on a single line + starting with a given bullet prefix (default: "- "). + + This function is designed to be robust to common LLM formatting variations, + following similar normalization practices as `extract_json_obj`. + + Behavior: + - Strips common code-fence markers (```json, ```python, ``` etc.). + - Collects all lines that start with any of the provided `bullet_prefixes`. + - Tolerates the "• " bullet as a loose fallback. + - Unescapes common sequences like "\\n" and "\\t" within items. + - If no bullet lines are found, falls back to attempting to parse a JSON array + (using `extract_json_obj`) and returns its string elements. + + Args: + text: Raw text response from LLM. + bullet_prefixes: Tuple of accepted bullet line prefixes. + + Returns: + List of extracted items (strings). Returns an empty list if none can be parsed. + """ + if not text: + return [] + + # Normalize the text similar to extract_json_obj + normalized = text.strip() + patterns_to_remove = ["json```", "```python", "```json", "latex```", "```latex", "```"] + for pattern in patterns_to_remove: + normalized = normalized.replace(pattern, "") + normalized = normalized.replace("\r\n", "\n") + + lines = normalized.splitlines() + items: list[str] = [] + seen: set[str] = set() + + for raw in lines: + line = raw.strip() + if not line: + continue + + matched = False + for prefix in bullet_prefixes: + if line.startswith(prefix): + content = line[len(prefix) :].strip() + content = content.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r") + if content and content not in seen: + items.append(content) + seen.add(content) + matched = True + break + + if matched: + continue + + if items: + return items + else: + logger.error(f"Fail to parse {text}") + + return [] + + +def extract_list_items_in_answer( + text: str, bullet_prefixes: tuple[str, ...] = ("- ",) +) -> list[str]: + """ + Extract list items specifically from content enclosed within `...` tags. + + - When one or more `...` blocks are present, concatenates their inner + contents with newlines and parses using `extract_list_items`. + - When no `` block is found, falls back to parsing the entire input with + `extract_list_items`. + - Case-insensitive matching of the `` tag. + + Args: + text: Raw text that may contain `...` blocks. + bullet_prefixes: Accepted bullet prefixes (default: strictly `"- "`). + + Returns: + List of extracted items (strings), or an empty list when nothing is parseable. + """ + if not text: + return [] + + try: + normalized = text.strip().replace("\r\n", "\n") + # Ordered, exact-case matching for blocks: answer -> Answer -> ANSWER + tag_variants = ["answer", "Answer", "ANSWER"] + matches: list[str] = [] + for tag in tag_variants: + matches = re.findall(rf"<{tag}>([\\s\\S]*?)", normalized) + if matches: + break + # Fallback: case-insensitive matching if none of the exact-case variants matched + if not matches: + matches = re.findall(r"([\\s\\S]*?)", normalized, flags=re.IGNORECASE) + + if matches: + combined = "\n".join(m.strip() for m in matches if m is not None) + return extract_list_items(combined, bullet_prefixes=bullet_prefixes) + + # Fallback: parse the whole text if tags are absent + return extract_list_items(normalized, bullet_prefixes=bullet_prefixes) + except Exception as e: + logger.info(f"Failed to extract items within tags: {e!s}", exc_info=True) + # Final fallback: attempt direct list extraction + try: + return extract_list_items(text, bullet_prefixes=bullet_prefixes) + except Exception: + return [] + + def parse_yaml(yaml_file: str | Path): yaml_path = Path(yaml_file) if not yaml_path.is_file(): diff --git a/src/memos/mem_scheduler/webservice_modules/redis_service.py b/src/memos/mem_scheduler/webservice_modules/redis_service.py index 5439af9c6..e79553f33 100644 --- a/src/memos/mem_scheduler/webservice_modules/redis_service.py +++ b/src/memos/mem_scheduler/webservice_modules/redis_service.py @@ -333,6 +333,15 @@ def redis_start_listening(self, handler: Callable | None = None): logger.warning("Listener is already running") return + # Check Redis connection before starting listener + if self.redis is None: + logger.warning( + "Redis connection is None, attempting to auto-initialize before starting listener..." + ) + if not self.auto_initialize_redis(): + logger.error("Failed to initialize Redis connection, cannot start listener") + return + if handler is None: handler = self.redis_consume_message_stream diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py index 55e33494c..b9814f079 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py @@ -22,6 +22,7 @@ class TaskGoalParser: def __init__(self, llm=BaseLLM): self.llm = llm self.tokenizer = FastTokenizer() + self.retries = 1 def parse( self, @@ -103,18 +104,24 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal: """ Parse LLM JSON output safely. """ - try: - context = kwargs.get("context", "") - response = response.replace("```", "").replace("json", "").strip() - response_json = eval(response) - return ParsedTaskGoal( - memories=response_json.get("memories", []), - keys=response_json.get("keys", []), - tags=response_json.get("tags", []), - rephrased_query=response_json.get("rephrased_instruction", None), - internet_search=response_json.get("internet_search", False), - goal_type=response_json.get("goal_type", "default"), - context=context, - ) - except Exception as e: - raise ValueError(f"Failed to parse LLM output: {e}\nRaw response:\n{response}") from e + # Ensure at least one attempt + attempts = max(1, getattr(self, "retries", 1)) + + for attempt_times in range(attempts): + try: + context = kwargs.get("context", "") + response = response.replace("```", "").replace("json", "").strip() + response_json = eval(response) + return ParsedTaskGoal( + memories=response_json.get("memories", []), + keys=response_json.get("keys", []), + tags=response_json.get("tags", []), + rephrased_query=response_json.get("rephrased_instruction", None), + internet_search=response_json.get("internet_search", False), + goal_type=response_json.get("goal_type", "default"), + context=context, + ) + except Exception as e: + raise ValueError( + f"Failed to parse LLM output: {e}\nRaw response:\n{response} retried: {attempt_times + 1}/{attempts + 1}" + ) from e diff --git a/src/memos/templates/mem_scheduler_prompts.py b/src/memos/templates/mem_scheduler_prompts.py index b4d091c1f..197a2c1a7 100644 --- a/src/memos/templates/mem_scheduler_prompts.py +++ b/src/memos/templates/mem_scheduler_prompts.py @@ -390,6 +390,45 @@ - Focus on whether the memories can fully answer the query without additional information """ +MEMORY_ENHANCEMENT_PROMPT = """ +You are a knowledgeable and precise AI assistant. + +# GOAL +Transform each raw memory into an enhanced version that preserves all relevant factual details and makes the information directly useful for answering the user's query. + +# CORE PRINCIPLE +Focus on **relevance** — the enhanced memories should highlight, clarify, and preserve the information that most directly supports answering the current query. + +# RULES & THINKING STEPS +1. Read the user query carefully and identify what specific facts are needed to answer it. +2. Go through each memory and: + - Keep only details directly relevant to the query (dates, actions, entities, outcomes). + - Remove unrelated or background details. + - If nothing in a memory relates to the query, delete the entire memory. +3. Do not add or infer new facts. +4. Keep facts accurate and phrased clearly. +5. Each resulting line should stand alone as a usable fact for answering the query. + +# OUTPUT FORMAT (STRICT) +Return ONLY the following block, with **one enhanced memory per line**. +Each line MUST start with "- " (dash + space). + +Wrap the final output inside: + +- enhanced memory 1 +- enhanced memory 2 +... + + +## User Query +{query_history} + +## Available Memories +{memories} + +Answer: +""" + PROMPT_MAPPING = { "intent_recognizing": INTENT_RECOGNIZING_PROMPT, "memory_reranking": MEMORY_RERANKING_PROMPT, @@ -398,6 +437,7 @@ "memory_redundancy_filtering": MEMORY_REDUNDANCY_FILTERING_PROMPT, "memory_combined_filtering": MEMORY_COMBINED_FILTERING_PROMPT, "memory_answer_ability_evaluation": MEMORY_ANSWER_ABILITY_EVALUATION_PROMPT, + "memory_enhancement": MEMORY_ENHANCEMENT_PROMPT, } MEMORY_ASSEMBLY_TEMPLATE = """The retrieved memories are listed as follows:\n\n {memory_text}""" diff --git a/tests/mem_scheduler/test_dispatcher.py b/tests/mem_scheduler/test_dispatcher.py index e3064660b..fc154e013 100644 --- a/tests/mem_scheduler/test_dispatcher.py +++ b/tests/mem_scheduler/test_dispatcher.py @@ -90,7 +90,6 @@ def setUp(self): ScheduleMessageItem( item_id="msg1", user_id="user1", - mem_cube="cube1", mem_cube_id="msg1", label="label1", content="Test content 1", @@ -99,7 +98,6 @@ def setUp(self): ScheduleMessageItem( item_id="msg2", user_id="user1", - mem_cube="cube1", mem_cube_id="msg2", label="label2", content="Test content 2", @@ -108,7 +106,6 @@ def setUp(self): ScheduleMessageItem( item_id="msg3", user_id="user2", - mem_cube="cube2", mem_cube_id="msg3", label="label1", content="Test content 3", @@ -193,46 +190,6 @@ def test_dispatch_serial(self): self.assertEqual(len(label2_messages), 1) self.assertEqual(label2_messages[0].item_id, "msg2") - def test_dispatch_parallel(self): - """Test dispatching messages in parallel mode.""" - # Create fresh mock handlers for this test - mock_handler1 = MagicMock() - mock_handler2 = MagicMock() - - # Create a new dispatcher for this test to avoid interference - parallel_dispatcher = SchedulerDispatcher(max_workers=2, enable_parallel_dispatch=True) - parallel_dispatcher.register_handler("label1", mock_handler1) - parallel_dispatcher.register_handler("label2", mock_handler2) - - # Dispatch messages - parallel_dispatcher.dispatch(self.test_messages) - - # Wait for all futures to complete - parallel_dispatcher.join(timeout=1.0) - - # Verify handlers were called - label1 handler should be called twice (for user1 and user2) - # label2 handler should be called once (only for user1) - self.assertEqual(mock_handler1.call_count, 2) # Called for user1/msg1 and user2/msg3 - mock_handler2.assert_called_once() # Called for user1/msg2 - - # Check that each handler received the correct messages - # For label1: should have two calls, each with one message - label1_calls = mock_handler1.call_args_list - self.assertEqual(len(label1_calls), 2) - - # Extract messages from calls - call1_messages = label1_calls[0][0][0] # First call, first argument (messages list) - call2_messages = label1_calls[1][0][0] # Second call, first argument (messages list) - - # Verify the messages in each call - self.assertEqual(len(call1_messages), 1) - self.assertEqual(len(call2_messages), 1) - - # For label2: should have one call with [msg2] - label2_messages = mock_handler2.call_args[0][0] - self.assertEqual(len(label2_messages), 1) - self.assertEqual(label2_messages[0].item_id, "msg2") - def test_group_messages_by_user_and_mem_cube(self): """Test grouping messages by user and cube.""" # Check actual grouping logic diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py index 03a8e4318..fed1e8500 100644 --- a/tests/mem_scheduler/test_scheduler.py +++ b/tests/mem_scheduler/test_scheduler.py @@ -1,7 +1,6 @@ import sys import unittest -from contextlib import suppress from datetime import datetime from pathlib import Path from unittest.mock import MagicMock, patch @@ -21,12 +20,9 @@ from memos.mem_scheduler.schemas.general_schemas import ( ANSWER_LABEL, QUERY_LABEL, - STARTUP_BY_PROCESS, - STARTUP_BY_THREAD, ) from memos.mem_scheduler.schemas.message_schemas import ( ScheduleLogForWebItem, - ScheduleMessageItem, ) from memos.memories.textual.tree import TreeTextMemory @@ -182,124 +178,6 @@ def test_submit_web_logs(self): self.assertTrue(hasattr(actual_message, "timestamp")) self.assertTrue(isinstance(actual_message.timestamp, datetime)) - def test_scheduler_startup_mode_default(self): - """Test that scheduler has default startup mode set to thread.""" - self.assertEqual(self.scheduler.scheduler_startup_mode, STARTUP_BY_THREAD) - - def test_scheduler_startup_mode_thread(self): - """Test scheduler with thread startup mode.""" - # Set scheduler startup mode to thread - self.scheduler.scheduler_startup_mode = STARTUP_BY_THREAD - - # Start the scheduler - self.scheduler.start() - - # Verify that consumer thread is created and process is None - self.assertIsNotNone(self.scheduler._consumer_thread) - self.assertIsNone(self.scheduler._consumer_process) - self.assertTrue(self.scheduler._running) - - # Stop the scheduler - self.scheduler.stop() - - def test_redis_message_queue(self): - """Test Redis message queue functionality for sending and receiving messages.""" - import time - - from unittest.mock import MagicMock, patch - - # Mock Redis connection and operations - mock_redis = MagicMock() - mock_redis.xadd = MagicMock(return_value=b"1234567890-0") - - # Track received messages - received_messages = [] - - def redis_handler(messages: list[ScheduleMessageItem]) -> None: - """Handler for Redis messages.""" - received_messages.extend(messages) - - # Register Redis handler - redis_label = "test_redis" - handlers = {redis_label: redis_handler} - self.scheduler.register_handlers(handlers) - - # Enable Redis queue for this test - with ( - patch.object(self.scheduler, "use_redis_queue", True), - patch.object(self.scheduler, "_redis_conn", mock_redis), - ): - # Start scheduler - self.scheduler.start() - - # Create test message for Redis - redis_message = ScheduleMessageItem( - label=redis_label, - content="Redis test message", - user_id="redis_user", - mem_cube_id="redis_cube", - mem_cube="redis_mem_cube_obj", - timestamp=datetime.now(), - ) - - # Submit message to Redis queue - self.scheduler.submit_messages(redis_message) - - # Verify Redis xadd was called - mock_redis.xadd.assert_called_once() - call_args = mock_redis.xadd.call_args - self.assertEqual(call_args[0][0], "user:queries:stream") - - # Verify message data was serialized correctly - message_data = call_args[0][1] - self.assertEqual(message_data["label"], redis_label) - self.assertEqual(message_data["content"], "Redis test message") - self.assertEqual(message_data["user_id"], "redis_user") - self.assertEqual(message_data["cube_id"], "redis_cube") # Note: to_dict uses cube_id - - # Simulate Redis message consumption - # This would normally be handled by the Redis consumer in the scheduler - time.sleep(0.1) # Brief wait for async operations - - # Stop scheduler - self.scheduler.stop() - - print("Redis message queue test completed successfully!") - - # Removed test_robustness method - was too time-consuming for CI/CD pipeline - - def test_scheduler_startup_mode_process(self): - """Test scheduler with process startup mode.""" - # Set scheduler startup mode to process - self.scheduler.scheduler_startup_mode = STARTUP_BY_PROCESS - - # Start the scheduler - try: - self.scheduler.start() - - # Verify that consumer process is created and thread is None - self.assertIsNotNone(self.scheduler._consumer_process) - self.assertIsNone(self.scheduler._consumer_thread) - self.assertTrue(self.scheduler._running) - - except Exception as e: - # Process mode may fail due to pickling issues in test environment - # This is expected behavior - we just verify the startup mode is set correctly - self.assertEqual(self.scheduler.scheduler_startup_mode, STARTUP_BY_PROCESS) - print(f"Process mode test encountered expected pickling issue: {e}") - finally: - # Always attempt to stop the scheduler - with suppress(Exception): - self.scheduler.stop() - - # Verify cleanup attempt was made - self.assertEqual(self.scheduler.scheduler_startup_mode, STARTUP_BY_PROCESS) - - def test_scheduler_startup_mode_constants(self): - """Test that startup mode constants are properly defined.""" - self.assertEqual(STARTUP_BY_THREAD, "thread") - self.assertEqual(STARTUP_BY_PROCESS, "process") - def test_activation_memory_update(self): """Test activation memory update functionality with DynamicCache handling.""" if not self.RUN_ACTIVATION_MEMORY_TESTS: @@ -401,130 +279,3 @@ def test_dynamic_cache_layers_access(self): # If layers attribute doesn't exist, verify our fix handles this case print("⚠️ DynamicCache doesn't have 'layers' attribute in this transformers version") print("✅ Test passed - our code should handle this gracefully") - - def test_get_running_tasks_with_filter(self): - """Test get_running_tasks method with filter function.""" - # Mock dispatcher and its get_running_tasks method - mock_task_item1 = MagicMock() - mock_task_item1.item_id = "task_1" - mock_task_item1.user_id = "user_1" - mock_task_item1.mem_cube_id = "cube_1" - mock_task_item1.task_info = {"type": "query"} - mock_task_item1.task_name = "test_task_1" - mock_task_item1.start_time = datetime.now() - mock_task_item1.end_time = None - mock_task_item1.status = "running" - mock_task_item1.result = None - mock_task_item1.error_message = None - mock_task_item1.messages = [] - - # Define a filter function - def user_filter(task): - return task.user_id == "user_1" - - # Mock the filtered result (only task_1 matches the filter) - with patch.object( - self.scheduler.dispatcher, "get_running_tasks", return_value={"task_1": mock_task_item1} - ) as mock_get_running_tasks: - # Call get_running_tasks with filter - result = self.scheduler.get_running_tasks(filter_func=user_filter) - - # Verify result - self.assertIsInstance(result, dict) - self.assertIn("task_1", result) - self.assertEqual(len(result), 1) - - # Verify dispatcher method was called with filter - mock_get_running_tasks.assert_called_once_with(filter_func=user_filter) - - def test_get_running_tasks_empty_result(self): - """Test get_running_tasks method when no tasks are running.""" - # Mock dispatcher to return empty dict - with patch.object( - self.scheduler.dispatcher, "get_running_tasks", return_value={} - ) as mock_get_running_tasks: - # Call get_running_tasks - result = self.scheduler.get_running_tasks() - - # Verify empty result - self.assertIsInstance(result, dict) - self.assertEqual(len(result), 0) - - # Verify dispatcher method was called - mock_get_running_tasks.assert_called_once_with(filter_func=None) - - def test_get_running_tasks_no_dispatcher(self): - """Test get_running_tasks method when dispatcher is None.""" - # Temporarily set dispatcher to None - original_dispatcher = self.scheduler.dispatcher - self.scheduler.dispatcher = None - - # Call get_running_tasks - result = self.scheduler.get_running_tasks() - - # Verify empty result and warning behavior - self.assertIsInstance(result, dict) - self.assertEqual(len(result), 0) - - # Restore dispatcher - self.scheduler.dispatcher = original_dispatcher - - def test_get_running_tasks_multiple_tasks(self): - """Test get_running_tasks method with multiple tasks.""" - # Mock multiple task items - mock_task_item1 = MagicMock() - mock_task_item1.item_id = "task_1" - mock_task_item1.user_id = "user_1" - mock_task_item1.mem_cube_id = "cube_1" - mock_task_item1.task_info = {"type": "query"} - mock_task_item1.task_name = "test_task_1" - mock_task_item1.start_time = datetime.now() - mock_task_item1.end_time = None - mock_task_item1.status = "running" - mock_task_item1.result = None - mock_task_item1.error_message = None - mock_task_item1.messages = [] - - mock_task_item2 = MagicMock() - mock_task_item2.item_id = "task_2" - mock_task_item2.user_id = "user_2" - mock_task_item2.mem_cube_id = "cube_2" - mock_task_item2.task_info = {"type": "answer"} - mock_task_item2.task_name = "test_task_2" - mock_task_item2.start_time = datetime.now() - mock_task_item2.end_time = None - mock_task_item2.status = "completed" - mock_task_item2.result = "success" - mock_task_item2.error_message = None - mock_task_item2.messages = ["message1", "message2"] - - with patch.object( - self.scheduler.dispatcher, - "get_running_tasks", - return_value={"task_1": mock_task_item1, "task_2": mock_task_item2}, - ) as mock_get_running_tasks: - # Call get_running_tasks - result = self.scheduler.get_running_tasks() - - # Verify result structure - self.assertIsInstance(result, dict) - self.assertEqual(len(result), 2) - self.assertIn("task_1", result) - self.assertIn("task_2", result) - - # Verify task_1 details - task1_dict = result["task_1"] - self.assertEqual(task1_dict["item_id"], "task_1") - self.assertEqual(task1_dict["user_id"], "user_1") - self.assertEqual(task1_dict["status"], "running") - - # Verify task_2 details - task2_dict = result["task_2"] - self.assertEqual(task2_dict["item_id"], "task_2") - self.assertEqual(task2_dict["user_id"], "user_2") - self.assertEqual(task2_dict["status"], "completed") - self.assertEqual(task2_dict["result"], "success") - self.assertEqual(task2_dict["messages"], ["message1", "message2"]) - - # Verify dispatcher method was called - mock_get_running_tasks.assert_called_once_with(filter_func=None) From 940dfde683421c997d0d0f14e40a55f8f0cea191 Mon Sep 17 00:00:00 2001 From: Wustzdy <67457465+wustzdy@users.noreply.github.com> Date: Tue, 11 Nov 2025 19:47:00 +0800 Subject: [PATCH 07/18] add pool health && log (#482) add pool health --- src/memos/graph_dbs/polardb.py | 46 ++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py index 60902420f..da1635296 100644 --- a/src/memos/graph_dbs/polardb.py +++ b/src/memos/graph_dbs/polardb.py @@ -199,6 +199,7 @@ def _get_connection(self): max_retries = 3 for attempt in range(max_retries): + conn = None try: conn = self.connection_pool.getconn() @@ -216,8 +217,49 @@ def _get_connection(self): # Set autocommit for PolarDB compatibility conn.autocommit = True + + # Test connection health with SELECT 1 + try: + cursor = conn.cursor() + cursor.execute("SELECT 1") + cursor.fetchone() + cursor.close() + except Exception as health_check_error: + # Connection is not usable, close it and try again + logger.warning( + f"Connection health check failed: {health_check_error}, closing connection and retrying..." + ) + try: + conn.close() + except Exception as close_error: + logger.warning(f"Failed to close unhealthy connection: {close_error}") + + # Return connection to pool if it's still valid + try: + self.connection_pool.putconn(conn, close=True) + except Exception as close_error: + logger.warning(f"Failed to connection_pool.putconn: {close_error}") + + conn = None + if attempt < max_retries - 1: + continue + else: + raise RuntimeError( + f"Failed to get a healthy connection from pool after {max_retries} attempts: {health_check_error}" + ) from health_check_error + + # Connection is healthy, return it return conn except Exception as e: + # If we have a connection that failed, try to return it to pool + if conn is not None: + try: + self.connection_pool.putconn(conn, close=True) + except Exception as putconn_error: + logger.warning( + f"Failed to connection_pool.putconn to pool: {putconn_error}" + ) + if attempt >= max_retries - 1: raise RuntimeError(f"Failed to get a valid connection from pool: {e}") from e continue @@ -647,12 +689,16 @@ def add_edge( self, source_id: str, target_id: str, type: str, user_name: str | None = None ) -> None: if not source_id or not target_id: + logger.warning(f"Edge '{source_id}' and '{target_id}' are both None") raise ValueError("[add_edge] source_id and target_id must be provided") source_exists = self.get_node(source_id) is not None target_exists = self.get_node(target_id) is not None if not source_exists or not target_exists: + logger.warning( + "[add_edge] Source %s or target %s does not exist.", source_exists, target_exists + ) raise ValueError("[add_edge] source_id and target_id must be provided") properties = {} From 2bf2ad00710b7dd327e24bb555a92bf9c0dea783 Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Thu, 13 Nov 2025 21:16:27 +0800 Subject: [PATCH 08/18] Feat: reorganize playground code and merge to API (#488) * feat: re org code * feat: code reorg and merge API and playground * feat: update memcube info * feat: remove act mem and params mem * feat: upadte init * code suffix --------- Co-authored-by: CaralHsi --- src/memos/api/handlers/__init__.py | 62 ++ src/memos/api/handlers/add_handler.py | 294 ++++++ src/memos/api/handlers/base_handler.py | 173 ++++ src/memos/api/handlers/chat_handler.py | 824 +++++++++++++++ src/memos/api/handlers/component_init.py | 264 +++++ src/memos/api/handlers/config_builders.py | 153 +++ src/memos/api/handlers/formatters_handler.py | 92 ++ src/memos/api/handlers/memory_handler.py | 151 +++ src/memos/api/handlers/scheduler_handler.py | 220 ++++ src/memos/api/handlers/search_handler.py | 289 ++++++ src/memos/api/handlers/suggestion_handler.py | 117 +++ src/memos/api/product_models.py | 4 + src/memos/api/routers/server_router.py | 966 +++--------------- src/memos/mem_os/utils/reference_utils.py | 23 +- .../mem_scheduler/general_modules/base.py | 2 +- src/memos/mem_scheduler/general_scheduler.py | 9 - src/memos/memories/textual/tree.py | 13 +- 17 files changed, 2823 insertions(+), 833 deletions(-) create mode 100644 src/memos/api/handlers/__init__.py create mode 100644 src/memos/api/handlers/add_handler.py create mode 100644 src/memos/api/handlers/base_handler.py create mode 100644 src/memos/api/handlers/chat_handler.py create mode 100644 src/memos/api/handlers/component_init.py create mode 100644 src/memos/api/handlers/config_builders.py create mode 100644 src/memos/api/handlers/formatters_handler.py create mode 100644 src/memos/api/handlers/memory_handler.py create mode 100644 src/memos/api/handlers/scheduler_handler.py create mode 100644 src/memos/api/handlers/search_handler.py create mode 100644 src/memos/api/handlers/suggestion_handler.py diff --git a/src/memos/api/handlers/__init__.py b/src/memos/api/handlers/__init__.py new file mode 100644 index 000000000..90347768c --- /dev/null +++ b/src/memos/api/handlers/__init__.py @@ -0,0 +1,62 @@ +""" +Server handlers for MemOS API routers. + +This package contains modular handlers for the server_router, responsible for: +- Building component configurations (config_builders) +- Initializing server components (component_init) +- Formatting data for API responses (formatters) +- Handling search, add, scheduler, and chat operations +""" + +# Lazy imports to avoid circular dependencies +from memos.api.handlers import ( + add_handler, + chat_handler, + memory_handler, + scheduler_handler, + search_handler, + suggestion_handler, +) +from memos.api.handlers.component_init import init_server +from memos.api.handlers.config_builders import ( + build_embedder_config, + build_graph_db_config, + build_internet_retriever_config, + build_llm_config, + build_mem_reader_config, + build_pref_adder_config, + build_pref_extractor_config, + build_pref_retriever_config, + build_reranker_config, + build_vec_db_config, +) +from memos.api.handlers.formatters_handler import ( + format_memory_item, + post_process_pref_mem, + to_iter, +) + + +__all__ = [ + "add_handler", + "build_embedder_config", + "build_graph_db_config", + "build_internet_retriever_config", + "build_llm_config", + "build_mem_reader_config", + "build_pref_adder_config", + "build_pref_extractor_config", + "build_pref_retriever_config", + "build_reranker_config", + "build_vec_db_config", + "chat_handler", + "format_memory_item", + "formatters_handler", + "init_server", + "memory_handler", + "post_process_pref_mem", + "scheduler_handler", + "search_handler", + "suggestion_handler", + "to_iter", +] diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py new file mode 100644 index 000000000..48db7ae6e --- /dev/null +++ b/src/memos/api/handlers/add_handler.py @@ -0,0 +1,294 @@ +""" +Add handler for memory addition functionality (Class-based version). + +This module provides a class-based implementation of add handlers, +using dependency injection for better modularity and testability. +""" + +import json +import os + +from datetime import datetime + +from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies +from memos.api.product_models import APIADDRequest, MemoryResponse +from memos.context.context import ContextThreadPoolExecutor +from memos.mem_scheduler.schemas.general_schemas import ( + ADD_LABEL, + MEM_READ_LABEL, + PREF_ADD_LABEL, +) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.types import UserContext + + +class AddHandler(BaseHandler): + """ + Handler for memory addition operations. + + Handles both text and preference memory additions with sync/async support. + """ + + def __init__(self, dependencies: HandlerDependencies): + """ + Initialize add handler. + + Args: + dependencies: HandlerDependencies instance + """ + super().__init__(dependencies) + self._validate_dependencies("naive_mem_cube", "mem_reader", "mem_scheduler") + + def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse: + """ + Main handler for add memories endpoint. + + Orchestrates the addition of both text and preference memories, + supporting concurrent processing. + + Args: + add_req: Add memory request + + Returns: + MemoryResponse with added memory information + """ + # Create UserContext object + user_context = UserContext( + user_id=add_req.user_id, + mem_cube_id=add_req.mem_cube_id, + session_id=add_req.session_id or "default_session", + ) + + self.logger.info(f"Add Req is: {add_req}") + if (not add_req.messages) and add_req.memory_content: + add_req.messages = self._convert_content_messsage(add_req.memory_content) + self.logger.info(f"Converted Add Req content to messages: {add_req.messages}") + # Process text and preference memories in parallel + with ContextThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(self._process_text_mem, add_req, user_context) + pref_future = executor.submit(self._process_pref_mem, add_req, user_context) + + text_response_data = text_future.result() + pref_response_data = pref_future.result() + + self.logger.info(f"add_memories Text response data: {text_response_data}") + self.logger.info(f"add_memories Pref response data: {pref_response_data}") + + return MemoryResponse( + message="Memory added successfully", + data=text_response_data + pref_response_data, + ) + + def _convert_content_messsage(self, memory_content: str) -> list[dict[str, str]]: + """ + Convert content string to list of message dictionaries. + + Args: + content: add content string + + Returns: + List of message dictionaries + """ + messages_list = [ + { + "role": "user", + "content": memory_content, + "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), + } + ] + # for only user-str input and convert message + return messages_list + + def _process_text_mem( + self, + add_req: APIADDRequest, + user_context: UserContext, + ) -> list[dict[str, str]]: + """ + Process and add text memories. + + Extracts memories from messages and adds them to the text memory system. + Handles both sync and async modes. + + Args: + add_req: Add memory request + user_context: User context with IDs + + Returns: + List of formatted memory responses + """ + target_session_id = add_req.session_id or "default_session" + + # Determine sync mode + sync_mode = add_req.async_mode or self._get_sync_mode() + + self.logger.info(f"Processing text memory with mode: {sync_mode}") + + # Extract memories + memories_local = self.mem_reader.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + }, + mode="fast" if sync_mode == "async" else "fine", + ) + flattened_local = [mm for m in memories_local for mm in m] + self.logger.info(f"Memory extraction completed for user {add_req.user_id}") + + # Add memories to text_mem + mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add( + flattened_local, + user_name=user_context.mem_cube_id, + ) + self.logger.info( + f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " + f"in session {add_req.session_id}: {mem_ids_local}" + ) + + # Schedule async/sync tasks + self._schedule_memory_tasks( + add_req=add_req, + user_context=user_context, + mem_ids=mem_ids_local, + sync_mode=sync_mode, + ) + + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.memory_type, + } + for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) + ] + + def _process_pref_mem( + self, + add_req: APIADDRequest, + user_context: UserContext, + ) -> list[dict[str, str]]: + """ + Process and add preference memories. + + Extracts preferences from messages and adds them to the preference memory system. + Handles both sync and async modes. + + Args: + add_req: Add memory request + user_context: User context with IDs + + Returns: + List of formatted preference responses + """ + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + + # Determine sync mode + sync_mode = add_req.async_mode or self._get_sync_mode() + target_session_id = add_req.session_id or "default_session" + + # Follow async behavior: enqueue when async + if sync_mode == "async": + try: + messages_list = [add_req.messages] + message_item_pref = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=self.naive_mem_cube, + label=PREF_ADD_LABEL, + content=json.dumps(messages_list), + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item_pref]) + self.logger.info("Submitted preference add to scheduler (async mode)") + except Exception as e: + self.logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True) + return [] + else: + # Sync mode: process immediately + pref_memories_local = self.naive_mem_cube.pref_mem.get_memory( + [add_req.messages], + type="chat", + info={ + "user_id": add_req.user_id, + "session_id": target_session_id, + "mem_cube_id": add_req.mem_cube_id, + }, + ) + pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local) + self.logger.info( + f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " + f"in session {add_req.session_id}: {pref_ids_local}" + ) + return [ + { + "memory": memory.memory, + "memory_id": memory_id, + "memory_type": memory.metadata.preference_type, + } + for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) + ] + + def _get_sync_mode(self) -> str: + """ + Get synchronization mode from memory cube. + + Returns: + Sync mode string ("sync" or "async") + """ + try: + return getattr(self.naive_mem_cube.text_mem, "mode", "sync") + except Exception: + return "sync" + + def _schedule_memory_tasks( + self, + add_req: APIADDRequest, + user_context: UserContext, + mem_ids: list[str], + sync_mode: str, + ) -> None: + """ + Schedule memory processing tasks based on sync mode. + + Args: + add_req: Add memory request + user_context: User context + mem_ids: List of memory IDs + sync_mode: Synchronization mode + """ + target_session_id = add_req.session_id or "default_session" + + if sync_mode == "async": + # Async mode: submit MEM_READ_LABEL task + try: + message_item_read = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=self.naive_mem_cube, + label=MEM_READ_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + user_name=add_req.mem_cube_id, + ) + self.mem_scheduler.submit_messages(messages=[message_item_read]) + self.logger.info(f"Submitted async memory read task: {json.dumps(mem_ids)}") + except Exception as e: + self.logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True) + else: + # Sync mode: submit ADD_LABEL task + message_item_add = ScheduleMessageItem( + user_id=add_req.user_id, + session_id=target_session_id, + mem_cube_id=add_req.mem_cube_id, + mem_cube=self.naive_mem_cube, + label=ADD_LABEL, + content=json.dumps(mem_ids), + timestamp=datetime.utcnow(), + user_name=add_req.mem_cube_id, + ) + self.mem_scheduler.submit_messages(messages=[message_item_add]) diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py new file mode 100644 index 000000000..86a00dc37 --- /dev/null +++ b/src/memos/api/handlers/base_handler.py @@ -0,0 +1,173 @@ +""" +Base handler for MemOS API handlers. + +This module provides the base class for all API handlers, implementing +dependency injection and common functionality. +""" + +from typing import Any + +from memos.log import get_logger + + +logger = get_logger(__name__) + + +class HandlerDependencies: + """ + Container for handler dependencies. + + This class acts as a dependency injection container, holding all + shared resources needed by handlers. + """ + + def __init__( + self, + llm: Any | None = None, + naive_mem_cube: Any | None = None, + mem_reader: Any | None = None, + mem_scheduler: Any | None = None, + embedder: Any | None = None, + reranker: Any | None = None, + graph_db: Any | None = None, + vector_db: Any | None = None, + internet_retriever: Any | None = None, + memory_manager: Any | None = None, + mos_server: Any | None = None, + **kwargs, + ): + """ + Initialize handler dependencies. + + Args: + llm: Language model instance + naive_mem_cube: Memory cube instance + mem_reader: Memory reader instance + mem_scheduler: Scheduler instance + embedder: Embedder instance + reranker: Reranker instance + graph_db: Graph database instance + vector_db: Vector database instance + internet_retriever: Internet retriever instance + memory_manager: Memory manager instance + mos_server: MOS server instance + **kwargs: Additional dependencies + """ + self.llm = llm + self.naive_mem_cube = naive_mem_cube + self.mem_reader = mem_reader + self.mem_scheduler = mem_scheduler + self.embedder = embedder + self.reranker = reranker + self.graph_db = graph_db + self.vector_db = vector_db + self.internet_retriever = internet_retriever + self.memory_manager = memory_manager + self.mos_server = mos_server + + # Store any additional dependencies + for key, value in kwargs.items(): + setattr(self, key, value) + + @classmethod + def from_init_server(cls, components: dict[str, Any]): + """ + Create dependencies from init_server() return values. + + Args: + components: Dictionary of components returned by init_server(). + All components will be automatically unpacked as dependencies. + + Returns: + HandlerDependencies instance + + Note: + This method uses **kwargs unpacking, so any new components added to + init_server() will automatically become available as dependencies + without modifying this code. + """ + return cls(**components) + + +class BaseHandler: + """ + Base class for all API handlers. + + Provides common functionality and dependency injection for handlers. + All specific handlers should inherit from this class. + """ + + def __init__(self, dependencies: HandlerDependencies): + """ + Initialize base handler. + + Args: + dependencies: HandlerDependencies instance containing all shared resources + """ + self.deps = dependencies + self.logger = get_logger(self.__class__.__name__) + + @property + def llm(self): + """Get LLM instance.""" + return self.deps.llm + + @property + def naive_mem_cube(self): + """Get memory cube instance.""" + return self.deps.naive_mem_cube + + @property + def mem_reader(self): + """Get memory reader instance.""" + return self.deps.mem_reader + + @property + def mem_scheduler(self): + """Get scheduler instance.""" + return self.deps.mem_scheduler + + @property + def embedder(self): + """Get embedder instance.""" + return self.deps.embedder + + @property + def reranker(self): + """Get reranker instance.""" + return self.deps.reranker + + @property + def graph_db(self): + """Get graph database instance.""" + return self.deps.graph_db + + @property + def vector_db(self): + """Get vector database instance.""" + return self.deps.vector_db + + @property + def mos_server(self): + """Get MOS server instance.""" + return self.deps.mos_server + + def _validate_dependencies(self, *required_deps: str) -> None: + """ + Validate that required dependencies are available. + + Args: + *required_deps: Names of required dependency attributes + + Raises: + ValueError: If any required dependency is None + """ + missing = [] + for dep_name in required_deps: + if not hasattr(self.deps, dep_name) or getattr(self.deps, dep_name) is None: + missing.append(dep_name) + + if missing: + raise ValueError( + f"{self.__class__.__name__} requires the following dependencies: {', '.join(missing)}" + ) diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py new file mode 100644 index 000000000..9b0048ed4 --- /dev/null +++ b/src/memos/api/handlers/chat_handler.py @@ -0,0 +1,824 @@ +""" +Chat handler for chat functionality (Class-based version). + +This module provides a complete implementation of chat handlers, +consolidating all chat-related logic without depending on mos_server. +""" + +import asyncio +import json +import traceback + +from collections.abc import Generator +from datetime import datetime +from typing import Any, Literal + +from fastapi import HTTPException +from fastapi.responses import StreamingResponse + +from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies +from memos.api.product_models import ( + APIADDRequest, + APIChatCompleteRequest, + APISearchRequest, + ChatRequest, +) +from memos.context.context import ContextThread +from memos.mem_os.utils.format_utils import clean_json_response +from memos.mem_os.utils.reference_utils import ( + prepare_reference_data, + process_streaming_references_complete, +) +from memos.mem_scheduler.schemas.general_schemas import ( + ANSWER_LABEL, + QUERY_LABEL, + SearchMode, +) +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.templates.mos_prompts import ( + FURTHER_SUGGESTION_PROMPT, + get_memos_prompt, +) +from memos.types import MessageList + + +class ChatHandler(BaseHandler): + """ + Handler for chat operations. + + Composes SearchHandler and AddHandler to provide complete chat functionality + without depending on mos_server. All chat logic is centralized here. + """ + + def __init__( + self, + dependencies: HandlerDependencies, + search_handler=None, + add_handler=None, + online_bot=None, + ): + """ + Initialize chat handler. + + Args: + dependencies: HandlerDependencies instance + search_handler: Optional SearchHandler instance (created if not provided) + add_handler: Optional AddHandler instance (created if not provided) + online_bot: Optional DingDing bot function for notifications + """ + super().__init__(dependencies) + self._validate_dependencies("llm", "naive_mem_cube", "mem_reader", "mem_scheduler") + + # Lazy import to avoid circular dependencies + if search_handler is None: + from memos.api.handlers.search_handler import SearchHandler + + search_handler = SearchHandler(dependencies) + + if add_handler is None: + from memos.api.handlers.add_handler import AddHandler + + add_handler = AddHandler(dependencies) + + self.search_handler = search_handler + self.add_handler = add_handler + self.online_bot = online_bot + + # Check if scheduler is enabled + self.enable_mem_scheduler = ( + hasattr(dependencies, "enable_mem_scheduler") and dependencies.enable_mem_scheduler + ) + + def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, Any]: + """ + Chat with MemOS for complete response (non-streaming). + + This implementation directly uses search/add handlers instead of mos_server. + + Args: + chat_req: Chat complete request + + Returns: + Dictionary with response and references + + Raises: + HTTPException: If chat fails + """ + try: + import time + + time_start = time.time() + + # Step 1: Search for relevant memories + search_req = APISearchRequest( + user_id=chat_req.user_id, + mem_cube_id=chat_req.mem_cube_id, + query=chat_req.query, + top_k=chat_req.top_k or 10, + session_id=chat_req.session_id, + mode=SearchMode.FINE, + internet_search=chat_req.internet_search, + moscube=chat_req.moscube, + chat_history=chat_req.history, + ) + + search_response = self.search_handler.handle_search_memories(search_req) + + # Extract memories from search results + memories_list = [] + if search_response.data and search_response.data.get("text_mem"): + text_mem_results = search_response.data["text_mem"] + if text_mem_results and text_mem_results[0].get("memories"): + memories_list = text_mem_results[0]["memories"] + + # Filter memories by threshold + filtered_memories = self._filter_memories_by_threshold( + memories_list, chat_req.threshold or 0.5 + ) + + # Step 2: Build system prompt + system_prompt = self._build_system_prompt(filtered_memories, chat_req.base_prompt) + + # Prepare message history + history_info = chat_req.history[-20:] if chat_req.history else [] + current_messages = [ + {"role": "system", "content": system_prompt}, + *history_info, + {"role": "user", "content": chat_req.query}, + ] + + self.logger.info("Starting to generate complete response...") + + # Step 3: Generate complete response from LLM + response = self.llm.generate(current_messages) + + time_end = time.time() + + # Step 4: Start post-chat processing asynchronously + self._start_post_chat_processing( + user_id=chat_req.user_id, + cube_id=chat_req.mem_cube_id, + session_id=chat_req.session_id or "default_session", + query=chat_req.query, + full_response=response, + system_prompt=system_prompt, + time_start=time_start, + time_end=time_end, + speed_improvement=0.0, + current_messages=current_messages, + ) + + # Return the complete response + return { + "message": "Chat completed successfully", + "data": {"response": response, "references": filtered_memories}, + } + + except ValueError as err: + raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err + except Exception as err: + self.logger.error(f"Failed to complete chat: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err + + def handle_chat_stream(self, chat_req: ChatRequest) -> StreamingResponse: + """ + Chat with MemOS via Server-Sent Events (SSE) stream using search/add handlers. + + This implementation directly uses search_handler and add_handler. + + Args: + chat_req: Chat stream request + + Returns: + StreamingResponse with SSE formatted chat stream + + Raises: + HTTPException: If stream initialization fails + """ + try: + + def generate_chat_response() -> Generator[str, None, None]: + """Generate chat response as SSE stream.""" + try: + import time + + time_start = time.time() + + # Step 1: Search for memories using search handler + yield f"data: {json.dumps({'type': 'status', 'data': '0'})}\n\n" + + search_req = APISearchRequest( + user_id=chat_req.user_id, + mem_cube_id=chat_req.mem_cube_id, + query=chat_req.query, + top_k=20, + session_id=chat_req.session_id, + mode=SearchMode.FINE, + internet_search=chat_req.internet_search, + moscube=chat_req.moscube, + chat_history=chat_req.history, + ) + + search_response = self.search_handler.handle_search_memories(search_req) + + yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n" + self._send_message_to_scheduler( + user_id=chat_req.user_id, + mem_cube_id=chat_req.mem_cube_id, + query=chat_req.query, + label=QUERY_LABEL, + ) + # Extract memories from search results + memories_list = [] + if search_response.data and search_response.data.get("text_mem"): + text_mem_results = search_response.data["text_mem"] + if text_mem_results and text_mem_results[0].get("memories"): + memories_list = text_mem_results[0]["memories"] + + # Filter memories by threshold + filtered_memories = self._filter_memories_by_threshold(memories_list) + + # Prepare reference data + reference = prepare_reference_data(filtered_memories) + yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n" + + # Step 2: Build system prompt with memories + system_prompt = self._build_enhance_system_prompt(filtered_memories) + + # Prepare messages + history_info = chat_req.history[-20:] if chat_req.history else [] + current_messages = [ + {"role": "system", "content": system_prompt}, + *history_info, + {"role": "user", "content": chat_req.query}, + ] + + self.logger.info( + f"user_id: {chat_req.user_id}, cube_id: {chat_req.mem_cube_id}, " + f"current_system_prompt: {system_prompt}" + ) + + yield f"data: {json.dumps({'type': 'status', 'data': '2'})}\n\n" + + # Step 3: Generate streaming response from LLM + response_stream = self.llm.generate_stream(current_messages) + + # Stream the response + buffer = "" + full_response = "" + + for chunk in response_stream: + if chunk in ["", ""]: + continue + + buffer += chunk + full_response += chunk + + # Process buffer to ensure complete reference tags + processed_chunk, remaining_buffer = process_streaming_references_complete( + buffer + ) + + if processed_chunk: + chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n" + yield chunk_data + buffer = remaining_buffer + + # Process any remaining buffer + if buffer: + processed_chunk, _ = process_streaming_references_complete(buffer) + if processed_chunk: + chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n" + yield chunk_data + + # Calculate timing + time_end = time.time() + speed_improvement = round(float((len(system_prompt) / 2) * 0.0048 + 44.5), 1) + total_time = round(float(time_end - time_start), 1) + + yield f"data: {json.dumps({'type': 'time', 'data': {'total_time': total_time, 'speed_improvement': f'{speed_improvement}%'}})}\n\n" + + # Get further suggestion + current_messages.append({"role": "assistant", "content": full_response}) + further_suggestion = self._get_further_suggestion(current_messages) + self.logger.info(f"further_suggestion: {further_suggestion}") + yield f"data: {json.dumps({'type': 'suggestion', 'data': further_suggestion})}\n\n" + + yield f"data: {json.dumps({'type': 'end'})}\n\n" + + # Step 4: Add conversation to memory asynchronously + self._start_post_chat_processing( + user_id=chat_req.user_id, + cube_id=chat_req.mem_cube_id, + session_id=chat_req.session_id or "default_session", + query=chat_req.query, + full_response=full_response, + system_prompt=system_prompt, + time_start=time_start, + time_end=time_end, + speed_improvement=speed_improvement, + current_messages=current_messages, + ) + + except Exception as e: + self.logger.error(f"Error in chat stream: {e}", exc_info=True) + error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n" + yield error_data + + return StreamingResponse( + generate_chat_response(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Content-Type": "text/event-stream", + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Headers": "*", + "Access-Control-Allow-Methods": "*", + }, + ) + + except ValueError as err: + raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err + except Exception as err: + self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}") + raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err + + def _build_system_prompt( + self, + memories: list | None = None, + base_prompt: str | None = None, + **kwargs, + ) -> str: + """Build system prompt with optional memories context.""" + if base_prompt is None: + base_prompt = ( + "You are a knowledgeable and helpful AI assistant. " + "You have access to conversation memories that help you provide more personalized responses. " + "Use the memories to understand the user's context, preferences, and past interactions. " + "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories." + ) + + memory_context = "" + if memories: + memory_list = [] + for i, memory in enumerate(memories, 1): + text_memory = memory.get("memory", "") + memory_list.append(f"{i}. {text_memory}") + memory_context = "\n".join(memory_list) + + if "{memories}" in base_prompt: + return base_prompt.format(memories=memory_context) + elif base_prompt and memories: + # For backward compatibility, append memories if no placeholder is found + memory_context_with_header = "\n\n## Memories:\n" + memory_context + return base_prompt + memory_context_with_header + return base_prompt + + def _build_enhance_system_prompt( + self, + memories_list: list, + tone: str = "friendly", + verbosity: str = "mid", + ) -> str: + """ + Build enhanced system prompt with memories (for streaming response). + + Args: + memories_list: List of memory items + tone: Tone of the prompt + verbosity: Verbosity level + + Returns: + System prompt string + """ + now = datetime.now() + formatted_date = now.strftime("%Y-%m-%d (%A)") + sys_body = get_memos_prompt( + date=formatted_date, tone=tone, verbosity=verbosity, mode="enhance" + ) + + # Format memories + mem_block_o, mem_block_p = self._format_mem_block(memories_list) + + return ( + sys_body + + "\n\n# Memories\n## PersonalMemory (ordered)\n" + + mem_block_p + + "\n## OuterMemory (ordered)\n" + + mem_block_o + ) + + def _format_mem_block( + self, memories_all: list, max_items: int = 20, max_chars_each: int = 320 + ) -> tuple[str, str]: + """ + Format memory block for prompt. + + Args: + memories_all: List of memory items + max_items: Maximum number of items to format + max_chars_each: Maximum characters per item + + Returns: + Tuple of (outer_memory_block, personal_memory_block) + """ + if not memories_all: + return "(none)", "(none)" + + lines_o = [] + lines_p = [] + + for idx, m in enumerate(memories_all[:max_items], 1): + mid = m.get("id", "").split("-")[0] if m.get("id") else f"mem_{idx}" + memory_content = m.get("memory", "") + metadata = m.get("metadata", {}) + memory_type = metadata.get("memory_type", "") + + tag = "O" if "Outer" in str(memory_type) else "P" + txt = memory_content.replace("\n", " ").strip() + if len(txt) > max_chars_each: + txt = txt[: max_chars_each - 1] + "…" + + mid = mid or f"mem_{idx}" + if tag == "O": + lines_o.append(f"[{idx}:{mid}] :: [{tag}] {txt}\n") + elif tag == "P": + lines_p.append(f"[{idx}:{mid}] :: [{tag}] {txt}") + + return "\n".join(lines_o), "\n".join(lines_p) + + def _filter_memories_by_threshold( + self, + memories: list, + threshold: float = 0.30, + min_num: int = 3, + memory_type: Literal["OuterMemory"] = "OuterMemory", + ) -> list: + """ + Filter memories by threshold and type. + + Args: + memories: List of memory items + threshold: Relevance threshold + min_num: Minimum number of memories to keep + memory_type: Memory type to filter + + Returns: + Filtered list of memories + """ + if not memories: + return [] + + # Handle dict format (from search results) + def get_relativity(m): + if isinstance(m, dict): + return m.get("metadata", {}).get("relativity", 0.0) + return getattr(getattr(m, "metadata", None), "relativity", 0.0) + + def get_memory_type(m): + if isinstance(m, dict): + return m.get("metadata", {}).get("memory_type", "") + return getattr(getattr(m, "metadata", None), "memory_type", "") + + sorted_memories = sorted(memories, key=get_relativity, reverse=True) + filtered_person = [m for m in memories if get_memory_type(m) != memory_type] + filtered_outer = [m for m in memories if get_memory_type(m) == memory_type] + + filtered = [] + per_memory_count = 0 + + for m in sorted_memories: + if get_relativity(m) >= threshold: + if get_memory_type(m) != memory_type: + per_memory_count += 1 + filtered.append(m) + + if len(filtered) < min_num: + filtered = filtered_person[:min_num] + filtered_outer[:min_num] + else: + if per_memory_count < min_num: + filtered += filtered_person[per_memory_count:min_num] + + filtered_memory = sorted(filtered, key=get_relativity, reverse=True) + return filtered_memory + + def _get_further_suggestion( + self, + current_messages: MessageList, + ) -> list[str]: + """Get further suggestion based on current messages.""" + try: + dialogue_info = "\n".join( + [f"{msg['role']}: {msg['content']}" for msg in current_messages[-2:]] + ) + further_suggestion_prompt = FURTHER_SUGGESTION_PROMPT.format(dialogue=dialogue_info) + message_list = [{"role": "system", "content": further_suggestion_prompt}] + response = self.llm.generate(message_list) + clean_response = clean_json_response(response) + response_json = json.loads(clean_response) + return response_json["query"] + except Exception as e: + self.logger.error(f"Error getting further suggestion: {e}", exc_info=True) + return [] + + def _extract_references_from_response(self, response: str) -> tuple[str, list[dict]]: + """Extract reference information from the response and return clean text.""" + import re + + try: + references = [] + # Pattern to match [refid:memoriesID] + pattern = r"\[(\d+):([^\]]+)\]" + + matches = re.findall(pattern, response) + for ref_number, memory_id in matches: + references.append({"memory_id": memory_id, "reference_number": int(ref_number)}) + + # Remove all reference markers from the text to get clean text + clean_text = re.sub(pattern, "", response) + + # Clean up any extra whitespace that might be left after removing markers + clean_text = re.sub(r"\s+", " ", clean_text).strip() + + return clean_text, references + except Exception as e: + self.logger.error(f"Error extracting references from response: {e}", exc_info=True) + return response, [] + + def _extract_struct_data_from_history(self, chat_data: list[dict]) -> dict: + """ + Extract structured message data from chat history. + + Args: + chat_data: List of chat messages + + Returns: + Dictionary with system, memory, and chat_history + """ + system_content = "" + memory_content = "" + chat_history = [] + + for item in chat_data: + role = item.get("role") + content = item.get("content", "") + if role == "system": + parts = content.split("# Memories", 1) + system_content = parts[0].strip() + if len(parts) > 1: + memory_content = "# Memories" + parts[1].strip() + elif role in ("user", "assistant"): + chat_history.append({"role": role, "content": content}) + + if chat_history and chat_history[-1]["role"] == "assistant": + if len(chat_history) >= 2 and chat_history[-2]["role"] == "user": + chat_history = chat_history[:-2] + else: + chat_history = chat_history[:-1] + + return {"system": system_content, "memory": memory_content, "chat_history": chat_history} + + def _send_message_to_scheduler( + self, + user_id: str, + mem_cube_id: str, + query: str, + label: str, + ) -> None: + """ + Send message to scheduler. + + Args: + user_id: User ID + mem_cube_id: Memory cube ID + query: Query content + label: Message label + """ + try: + message_item = ScheduleMessageItem( + user_id=user_id, + mem_cube_id=mem_cube_id, + label=label, + content=query, + timestamp=datetime.utcnow(), + ) + self.mem_scheduler.submit_messages(messages=[message_item]) + self.logger.info(f"Sent message to scheduler with label: {label}") + except Exception as e: + self.logger.error(f"Failed to send message to scheduler: {e}", exc_info=True) + + async def _post_chat_processing( + self, + user_id: str, + cube_id: str, + session_id: str, + query: str, + full_response: str, + system_prompt: str, + time_start: float, + time_end: float, + speed_improvement: float, + current_messages: list, + ) -> None: + """ + Asynchronous post-chat processing with complete functionality. + + Includes: + - Reference extraction + - DingDing notification + - Scheduler messaging + - Memory addition + + Args: + user_id: User ID + cube_id: Memory cube ID + session_id: Session ID + query: User query + full_response: Full LLM response + system_prompt: System prompt used + time_start: Start timestamp + time_end: End timestamp + speed_improvement: Speed improvement metric + current_messages: Current message history + """ + try: + self.logger.info( + f"user_id: {user_id}, cube_id: {cube_id}, current_messages: {current_messages}" + ) + self.logger.info( + f"user_id: {user_id}, cube_id: {cube_id}, full_response: {full_response}" + ) + + # Extract references and clean response + clean_response, extracted_references = self._extract_references_from_response( + full_response + ) + struct_message = self._extract_struct_data_from_history(current_messages) + self.logger.info(f"Extracted {len(extracted_references)} references from response") + + # Send DingDing notification if enabled + if self.online_bot: + self.logger.info("Online Bot Open!") + try: + from memos.memos_tools.notification_utils import ( + send_online_bot_notification_async, + ) + + # Prepare notification data + chat_data = {"query": query, "user_id": user_id, "cube_id": cube_id} + chat_data.update( + { + "memory": struct_message["memory"], + "chat_history": struct_message["chat_history"], + "full_response": full_response, + } + ) + + system_data = { + "references": extracted_references, + "time_start": time_start, + "time_end": time_end, + "speed_improvement": speed_improvement, + } + + emoji_config = {"chat": "💬", "system_info": "📊"} + + await send_online_bot_notification_async( + online_bot=self.online_bot, + header_name="MemOS Chat Report", + sub_title_name="chat_with_references", + title_color="#00956D", + other_data1=chat_data, + other_data2=system_data, + emoji=emoji_config, + ) + except Exception as e: + self.logger.warning(f"Failed to send chat notification (async): {e}") + + # Send answer to scheduler + self._send_message_to_scheduler( + user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_LABEL + ) + + # Add conversation to memory using add handler + add_req = APIADDRequest( + user_id=user_id, + mem_cube_id=cube_id, + session_id=session_id, + messages=[ + { + "role": "user", + "content": query, + "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), + }, + { + "role": "assistant", + "content": clean_response, # Store clean text without reference markers + "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")), + }, + ], + async_mode="sync", # set suync for playground + ) + + self.add_handler.handle_add_memories(add_req) + + self.logger.info(f"Post-chat processing completed for user {user_id}") + + except Exception as e: + self.logger.error( + f"Error in post-chat processing for user {user_id}: {e}", exc_info=True + ) + + def _start_post_chat_processing( + self, + user_id: str, + cube_id: str, + session_id: str, + query: str, + full_response: str, + system_prompt: str, + time_start: float, + time_end: float, + speed_improvement: float, + current_messages: list, + ) -> None: + """ + Start asynchronous post-chat processing in a background thread. + + Args: + user_id: User ID + cube_id: Memory cube ID + session_id: Session ID + query: User query + full_response: Full LLM response + system_prompt: System prompt used + time_start: Start timestamp + time_end: End timestamp + speed_improvement: Speed improvement metric + current_messages: Current message history + """ + + def run_async_in_thread(): + """Running asynchronous tasks in a new thread""" + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + self._post_chat_processing( + user_id=user_id, + cube_id=cube_id, + session_id=session_id, + query=query, + full_response=full_response, + system_prompt=system_prompt, + time_start=time_start, + time_end=time_end, + speed_improvement=speed_improvement, + current_messages=current_messages, + ) + ) + finally: + loop.close() + except Exception as e: + self.logger.error( + f"Error in thread-based post-chat processing for user {user_id}: {e}", + exc_info=True, + ) + + try: + # Try to get the current event loop + asyncio.get_running_loop() + # Create task and store reference to prevent garbage collection + task = asyncio.create_task( + self._post_chat_processing( + user_id=user_id, + cube_id=cube_id, + session_id=session_id, + query=query, + full_response=full_response, + system_prompt=system_prompt, + time_start=time_start, + time_end=time_end, + speed_improvement=speed_improvement, + current_messages=current_messages, + ) + ) + # Add exception handling for the background task + task.add_done_callback( + lambda t: self.logger.error( + f"Error in background post-chat processing for user {user_id}: {t.exception()}", + exc_info=True, + ) + if t.exception() + else None + ) + except RuntimeError: + # No event loop, run in a new thread with context propagation + thread = ContextThread( + target=run_async_in_thread, + name=f"PostChatProcessing-{user_id}", + daemon=True, + ) + thread.start() diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py new file mode 100644 index 000000000..4e696a341 --- /dev/null +++ b/src/memos/api/handlers/component_init.py @@ -0,0 +1,264 @@ +""" +Server component initialization module. + +This module handles the initialization of all MemOS server components +including databases, LLMs, memory systems, and schedulers. +""" + +from typing import TYPE_CHECKING, Any + +from memos.api.config import APIConfig +from memos.api.handlers.config_builders import ( + build_embedder_config, + build_graph_db_config, + build_internet_retriever_config, + build_llm_config, + build_mem_reader_config, + build_pref_adder_config, + build_pref_extractor_config, + build_pref_retriever_config, + build_reranker_config, + build_vec_db_config, +) +from memos.configs.mem_scheduler import SchedulerConfigFactory +from memos.embedders.factory import EmbedderFactory +from memos.graph_dbs.factory import GraphStoreFactory +from memos.llms.factory import LLMFactory +from memos.log import get_logger +from memos.mem_cube.navie import NaiveMemCube +from memos.mem_os.product_server import MOSServer +from memos.mem_reader.factory import MemReaderFactory +from memos.mem_scheduler.orm_modules.base_model import BaseDBManager +from memos.mem_scheduler.scheduler_factory import SchedulerFactory +from memos.memories.textual.prefer_text_memory.factory import ( + AdderFactory, + ExtractorFactory, + RetrieverFactory, +) +from memos.memories.textual.simple_preference import SimplePreferenceTextMemory +from memos.memories.textual.simple_tree import SimpleTreeTextMemory +from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager +from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( + InternetRetrieverFactory, +) +from memos.reranker.factory import RerankerFactory +from memos.vec_dbs.factory import VecDBFactory + + +if TYPE_CHECKING: + from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler + +logger = get_logger(__name__) + + +def _get_default_memory_size(cube_config: Any) -> dict[str, int]: + """ + Get default memory size configuration. + + Attempts to retrieve memory size from cube config, falls back to defaults + if not found. + + Args: + cube_config: The cube configuration object + + Returns: + Dictionary with memory sizes for different memory types + """ + return getattr(cube_config.text_mem.config, "memory_size", None) or { + "WorkingMemory": 20, + "LongTermMemory": 1500, + "UserMemory": 480, + } + + +def init_server() -> dict[str, Any]: + """ + Initialize all server components and configurations. + + This function orchestrates the creation and initialization of all components + required by the MemOS server, including: + - Database connections (graph DB, vector DB) + - Language models and embedders + - Memory systems (text, preference) + - Scheduler and related modules + + Returns: + A dictionary containing all initialized components with descriptive keys. + This approach allows easy addition of new components without breaking + existing code that uses the components. + """ + logger.info("Initializing MemOS server components...") + + # Get default cube configuration + default_cube_config = APIConfig.get_default_cube_config() + + # Get online bot setting + dingding_enabled = APIConfig.is_dingding_bot_enabled() + + # Build component configurations + graph_db_config = build_graph_db_config() + llm_config = build_llm_config() + embedder_config = build_embedder_config() + mem_reader_config = build_mem_reader_config() + reranker_config = build_reranker_config() + internet_retriever_config = build_internet_retriever_config() + vector_db_config = build_vec_db_config() + pref_extractor_config = build_pref_extractor_config() + pref_adder_config = build_pref_adder_config() + pref_retriever_config = build_pref_retriever_config() + + logger.debug("Component configurations built successfully") + + # Create component instances + graph_db = GraphStoreFactory.from_config(graph_db_config) + vector_db = VecDBFactory.from_config(vector_db_config) + llm = LLMFactory.from_config(llm_config) + embedder = EmbedderFactory.from_config(embedder_config) + mem_reader = MemReaderFactory.from_config(mem_reader_config) + reranker = RerankerFactory.from_config(reranker_config) + internet_retriever = InternetRetrieverFactory.from_config( + internet_retriever_config, embedder=embedder + ) + + logger.debug("Core components instantiated") + + # Initialize memory manager + memory_manager = MemoryManager( + graph_db, + embedder, + llm, + memory_size=_get_default_memory_size(default_cube_config), + is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False), + ) + + logger.debug("Memory manager initialized") + + # Initialize text memory + text_mem = SimpleTreeTextMemory( + llm=llm, + embedder=embedder, + mem_reader=mem_reader, + graph_db=graph_db, + reranker=reranker, + memory_manager=memory_manager, + config=default_cube_config.text_mem.config, + internet_retriever=internet_retriever, + ) + + logger.debug("Text memory initialized") + + # Initialize preference memory components + pref_extractor = ExtractorFactory.from_config( + config_factory=pref_extractor_config, + llm_provider=llm, + embedder=embedder, + vector_db=vector_db, + ) + + pref_adder = AdderFactory.from_config( + config_factory=pref_adder_config, + llm_provider=llm, + embedder=embedder, + vector_db=vector_db, + text_mem=text_mem, + ) + + pref_retriever = RetrieverFactory.from_config( + config_factory=pref_retriever_config, + llm_provider=llm, + embedder=embedder, + reranker=reranker, + vector_db=vector_db, + ) + + logger.debug("Preference memory components initialized") + + # Initialize preference memory + pref_mem = SimplePreferenceTextMemory( + extractor_llm=llm, + vector_db=vector_db, + embedder=embedder, + reranker=reranker, + extractor=pref_extractor, + adder=pref_adder, + retriever=pref_retriever, + ) + + logger.debug("Preference memory initialized") + + # Initialize MOS Server + mos_server = MOSServer( + mem_reader=mem_reader, + llm=llm, + online_bot=False, + ) + + logger.debug("MOS server initialized") + + # Create MemCube with pre-initialized memory instances + naive_mem_cube = NaiveMemCube( + text_mem=text_mem, + pref_mem=pref_mem, + act_mem=None, + para_mem=None, + ) + + logger.debug("MemCube created") + + # Initialize Scheduler + scheduler_config_dict = APIConfig.get_scheduler_config() + scheduler_config = SchedulerConfigFactory( + backend="optimized_scheduler", config=scheduler_config_dict + ) + mem_scheduler: OptimizedScheduler = SchedulerFactory.from_config(scheduler_config) + mem_scheduler.initialize_modules( + chat_llm=llm, + process_llm=mem_reader.llm, + db_engine=BaseDBManager.create_default_sqlite_engine(), + mem_reader=mem_reader, + ) + mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube) + logger.debug("Scheduler initialized") + + # Initialize SchedulerAPIModule + api_module = mem_scheduler.api_module + + # Start scheduler if enabled + import os + + if os.getenv("API_SCHEDULER_ON", True): + mem_scheduler.start() + logger.info("Scheduler started") + + logger.info("MemOS server components initialized successfully") + + # Initialize online bot if enabled + online_bot = None + if dingding_enabled: + from memos.memos_tools.notification_service import get_online_bot_function + + online_bot = get_online_bot_function() if dingding_enabled else None + logger.info("DingDing bot is enabled") + + # Return all components as a dictionary for easy access and extension + return { + "graph_db": graph_db, + "mem_reader": mem_reader, + "llm": llm, + "embedder": embedder, + "reranker": reranker, + "internet_retriever": internet_retriever, + "memory_manager": memory_manager, + "default_cube_config": default_cube_config, + "mos_server": mos_server, + "mem_scheduler": mem_scheduler, + "naive_mem_cube": naive_mem_cube, + "api_module": api_module, + "vector_db": vector_db, + "pref_extractor": pref_extractor, + "pref_adder": pref_adder, + "pref_retriever": pref_retriever, + "text_mem": text_mem, + "pref_mem": pref_mem, + "online_bot": online_bot, + } diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py new file mode 100644 index 000000000..9f510add0 --- /dev/null +++ b/src/memos/api/handlers/config_builders.py @@ -0,0 +1,153 @@ +""" +Configuration builders for server handlers. + +This module contains factory functions that build configurations for various +components used by the MemOS server. Each function constructs and validates +a configuration dictionary using the appropriate ConfigFactory. +""" + +import os + +from typing import Any + +from memos.api.config import APIConfig +from memos.configs.embedder import EmbedderConfigFactory +from memos.configs.graph_db import GraphDBConfigFactory +from memos.configs.internet_retriever import InternetRetrieverConfigFactory +from memos.configs.llm import LLMConfigFactory +from memos.configs.mem_reader import MemReaderConfigFactory +from memos.configs.reranker import RerankerConfigFactory +from memos.configs.vec_db import VectorDBConfigFactory +from memos.memories.textual.prefer_text_memory.config import ( + AdderConfigFactory, + ExtractorConfigFactory, + RetrieverConfigFactory, +) + + +def build_graph_db_config(user_id: str = "default") -> dict[str, Any]: + """ + Build graph database configuration. + + Args: + user_id: User ID for configuration context (default: "default") + + Returns: + Validated graph database configuration dictionary + """ + graph_db_backend_map = { + "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), + "neo4j": APIConfig.get_neo4j_config(user_id=user_id), + "nebular": APIConfig.get_nebular_config(user_id=user_id), + "polardb": APIConfig.get_polardb_config(user_id=user_id), + } + + graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() + return GraphDBConfigFactory.model_validate( + { + "backend": graph_db_backend, + "config": graph_db_backend_map[graph_db_backend], + } + ) + + +def build_vec_db_config() -> dict[str, Any]: + """ + Build vector database configuration. + + Returns: + Validated vector database configuration dictionary + """ + return VectorDBConfigFactory.model_validate( + { + "backend": "milvus", + "config": APIConfig.get_milvus_config(), + } + ) + + +def build_llm_config() -> dict[str, Any]: + """ + Build LLM configuration. + + Returns: + Validated LLM configuration dictionary + """ + return LLMConfigFactory.model_validate( + { + "backend": "openai", + "config": APIConfig.get_openai_config(), + } + ) + + +def build_embedder_config() -> dict[str, Any]: + """ + Build embedder configuration. + + Returns: + Validated embedder configuration dictionary + """ + return EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config()) + + +def build_mem_reader_config() -> dict[str, Any]: + """ + Build memory reader configuration. + + Returns: + Validated memory reader configuration dictionary + """ + return MemReaderConfigFactory.model_validate( + APIConfig.get_product_default_config()["mem_reader"] + ) + + +def build_reranker_config() -> dict[str, Any]: + """ + Build reranker configuration. + + Returns: + Validated reranker configuration dictionary + """ + return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) + + +def build_internet_retriever_config() -> dict[str, Any]: + """ + Build internet retriever configuration. + + Returns: + Validated internet retriever configuration dictionary + """ + return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config()) + + +def build_pref_extractor_config() -> dict[str, Any]: + """ + Build preference memory extractor configuration. + + Returns: + Validated extractor configuration dictionary + """ + return ExtractorConfigFactory.model_validate({"backend": "naive", "config": {}}) + + +def build_pref_adder_config() -> dict[str, Any]: + """ + Build preference memory adder configuration. + + Returns: + Validated adder configuration dictionary + """ + return AdderConfigFactory.model_validate({"backend": "naive", "config": {}}) + + +def build_pref_retriever_config() -> dict[str, Any]: + """ + Build preference memory retriever configuration. + + Returns: + Validated retriever configuration dictionary + """ + return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}}) diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py new file mode 100644 index 000000000..976be87bb --- /dev/null +++ b/src/memos/api/handlers/formatters_handler.py @@ -0,0 +1,92 @@ +""" +Data formatting utilities for server handlers. + +This module provides utility functions for formatting and transforming data +structures for API responses, including memory items and preferences. +""" + +from typing import Any + +from memos.templates.instruction_completion import instruct_completion + + +def to_iter(running: Any) -> list[Any]: + """ + Normalize running tasks to a list of task objects. + + Handles different input types and converts them to a consistent list format. + + Args: + running: Running tasks, can be None, dict, or iterable + + Returns: + List of task objects + """ + if running is None: + return [] + if isinstance(running, dict): + return list(running.values()) + return list(running) if running else [] + + +def format_memory_item(memory_data: Any) -> dict[str, Any]: + """ + Format a single memory item for API response. + + Transforms a memory object into a dictionary with metadata properly + structured for API consumption. + + Args: + memory_data: Memory object to format + + Returns: + Formatted memory dictionary with ref_id and metadata + """ + memory = memory_data.model_dump() + memory_id = memory["id"] + ref_id = f"[{memory_id.split('-')[0]}]" + + memory["ref_id"] = ref_id + memory["metadata"]["embedding"] = [] + memory["metadata"]["sources"] = [] + memory["metadata"]["usage"] = [] + memory["metadata"]["ref_id"] = ref_id + memory["metadata"]["id"] = memory_id + memory["metadata"]["memory"] = memory["memory"] + + return memory + + +def post_process_pref_mem( + memories_result: dict[str, Any], + pref_formatted_mem: list[dict[str, Any]], + mem_cube_id: str, + include_preference: bool, +) -> dict[str, Any]: + """ + Post-process preference memory results. + + Adds formatted preference memories to the result dictionary and generates + instruction completion strings if preferences are included. + + Args: + memories_result: Result dictionary to update + pref_formatted_mem: List of formatted preference memories + mem_cube_id: Memory cube ID + include_preference: Whether to include preferences in result + + Returns: + Updated memories_result dictionary + """ + if include_preference: + memories_result["pref_mem"].append( + { + "cube_id": mem_cube_id, + "memories": pref_formatted_mem, + } + ) + pref_instruction, pref_note = instruct_completion(pref_formatted_mem) + memories_result["pref_string"] = pref_instruction + memories_result["pref_note"] = pref_note + + return memories_result diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py new file mode 100644 index 000000000..85f339f3f --- /dev/null +++ b/src/memos/api/handlers/memory_handler.py @@ -0,0 +1,151 @@ +""" +Memory handler for retrieving and managing memories. + +This module handles retrieving all memories or specific subgraphs based on queries. +""" + +from typing import Any, Literal + +from memos.api.product_models import MemoryResponse +from memos.log import get_logger +from memos.mem_os.utils.format_utils import ( + convert_graph_to_tree_forworkmem, + ensure_unique_tree_ids, + filter_nodes_by_tree_ids, + remove_embedding_recursive, + sort_children_by_memory_type, +) + + +logger = get_logger(__name__) + + +def handle_get_all_memories( + user_id: str, + mem_cube_id: str, + memory_type: Literal["text_mem", "act_mem", "param_mem", "para_mem"], + naive_mem_cube: Any, +) -> MemoryResponse: + """ + Main handler for getting all memories. + + Retrieves all memories of specified type for a user and formats them appropriately. + + Args: + user_id: User ID + mem_cube_id: Memory cube ID + memory_type: Type of memory to retrieve + naive_mem_cube: Memory cube instance + + Returns: + MemoryResponse with formatted memory data + """ + try: + reformat_memory_list = [] + + if memory_type == "text_mem": + # Get all text memories from the graph database + memories = naive_mem_cube.text_mem.get_all(user_name=mem_cube_id) + + # Format and convert to tree structure + memories_cleaned = remove_embedding_recursive(memories) + custom_type_ratios = { + "WorkingMemory": 0.20, + "LongTermMemory": 0.40, + "UserMemory": 0.40, + } + tree_result, node_type_count = convert_graph_to_tree_forworkmem( + memories_cleaned, target_node_count=200, type_ratios=custom_type_ratios + ) + # Ensure all node IDs are unique in the tree structure + tree_result = ensure_unique_tree_ids(tree_result) + memories_filtered = filter_nodes_by_tree_ids(tree_result, memories_cleaned) + children = tree_result["children"] + children_sort = sort_children_by_memory_type(children) + tree_result["children"] = children_sort + memories_filtered["tree_structure"] = tree_result + + reformat_memory_list.append( + { + "cube_id": mem_cube_id, + "memories": [memories_filtered], + "memory_statistics": node_type_count, + } + ) + + elif memory_type == "act_mem": + logger.warning("Activity memory retrieval not implemented yet.") + elif memory_type == "para_mem": + logger.warning("Parameter memory retrieval not implemented yet.") + return MemoryResponse( + message="Memories retrieved successfully", + data=reformat_memory_list, + ) + + except Exception as e: + logger.error(f"Failed to get all memories: {e}", exc_info=True) + raise + + +def handle_get_subgraph( + user_id: str, + mem_cube_id: str, + query: str, + top_k: int, + naive_mem_cube: Any, +) -> MemoryResponse: + """ + Main handler for getting memory subgraph based on query. + + Retrieves relevant memory subgraph and formats it as a tree structure. + + Args: + user_id: User ID + mem_cube_id: Memory cube ID + query: Search query + top_k: Number of top results to return + naive_mem_cube: Memory cube instance + + Returns: + MemoryResponse with formatted subgraph data + """ + try: + # Get relevant subgraph from text memory + memories = naive_mem_cube.text_mem.get_relevant_subgraph( + query, top_k=top_k, user_name=mem_cube_id + ) + + # Format and convert to tree structure + memories_cleaned = remove_embedding_recursive(memories) + custom_type_ratios = { + "WorkingMemory": 0.20, + "LongTermMemory": 0.40, + "UserMemory": 0.40, + } + tree_result, node_type_count = convert_graph_to_tree_forworkmem( + memories_cleaned, target_node_count=150, type_ratios=custom_type_ratios + ) + # Ensure all node IDs are unique in the tree structure + tree_result = ensure_unique_tree_ids(tree_result) + memories_filtered = filter_nodes_by_tree_ids(tree_result, memories_cleaned) + children = tree_result["children"] + children_sort = sort_children_by_memory_type(children) + tree_result["children"] = children_sort + memories_filtered["tree_structure"] = tree_result + + reformat_memory_list = [ + { + "cube_id": mem_cube_id, + "memories": [memories_filtered], + "memory_statistics": node_type_count, + } + ] + + return MemoryResponse( + message="Memories retrieved successfully", + data=reformat_memory_list, + ) + + except Exception as e: + logger.error(f"Failed to get subgraph: {e}", exc_info=True) + raise diff --git a/src/memos/api/handlers/scheduler_handler.py b/src/memos/api/handlers/scheduler_handler.py new file mode 100644 index 000000000..8d3c6dc70 --- /dev/null +++ b/src/memos/api/handlers/scheduler_handler.py @@ -0,0 +1,220 @@ +""" +Scheduler handler for scheduler management functionality. + +This module handles all scheduler-related operations including status checking, +waiting for idle state, and streaming progress updates. +""" + +import json +import time +import traceback + +from typing import Any + +from fastapi import HTTPException +from fastapi.responses import StreamingResponse + +from memos.api.handlers.formatters_handler import to_iter +from memos.log import get_logger + + +logger = get_logger(__name__) + + +def handle_scheduler_status( + user_name: str | None = None, + mem_scheduler: Any | None = None, + instance_id: str = "", +) -> dict[str, Any]: + """ + Get scheduler running status. + + Retrieves the number of running tasks for a specific user or globally. + + Args: + user_name: Optional specific user name to filter tasks + mem_scheduler: Scheduler instance + instance_id: Instance ID for response + + Returns: + Dictionary with status information + + Raises: + HTTPException: If status retrieval fails + """ + try: + if user_name: + running = mem_scheduler.dispatcher.get_running_tasks( + lambda task: getattr(task, "mem_cube_id", None) == user_name + ) + tasks_iter = to_iter(running) + running_count = len(tasks_iter) + return { + "message": "ok", + "data": { + "scope": "user", + "user_name": user_name, + "running_tasks": running_count, + "timestamp": time.time(), + "instance_id": instance_id, + }, + } + else: + running_all = mem_scheduler.dispatcher.get_running_tasks(lambda _t: True) + tasks_iter = to_iter(running_all) + running_count = len(tasks_iter) + + task_count_per_user: dict[str, int] = {} + for task in tasks_iter: + cube = getattr(task, "mem_cube_id", "unknown") + task_count_per_user[cube] = task_count_per_user.get(cube, 0) + 1 + + try: + metrics_snapshot = mem_scheduler.dispatcher.metrics.snapshot() + except Exception: + metrics_snapshot = {} + + return { + "message": "ok", + "data": { + "scope": "global", + "running_tasks": running_count, + "task_count_per_user": task_count_per_user, + "timestamp": time.time(), + "instance_id": instance_id, + "metrics": metrics_snapshot, + }, + } + except Exception as err: + logger.error("Failed to get scheduler status: %s", traceback.format_exc()) + raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err + + +def handle_scheduler_wait( + user_name: str, + timeout_seconds: float = 120.0, + poll_interval: float = 0.2, + mem_scheduler: Any | None = None, +) -> dict[str, Any]: + """ + Wait until scheduler is idle for a specific user. + + Blocks until scheduler has no running tasks for the given user, or timeout. + + Args: + user_name: User name to wait for + timeout_seconds: Maximum wait time in seconds + poll_interval: Polling interval in seconds + mem_scheduler: Scheduler instance + + Returns: + Dictionary with wait result and statistics + + Raises: + HTTPException: If wait operation fails + """ + start = time.time() + try: + while True: + running = mem_scheduler.dispatcher.get_running_tasks( + lambda task: task.mem_cube_id == user_name + ) + running_count = len(running) + elapsed = time.time() - start + + # success -> scheduler is idle + if running_count == 0: + return { + "message": "idle", + "data": { + "running_tasks": 0, + "waited_seconds": round(elapsed, 3), + "timed_out": False, + "user_name": user_name, + }, + } + + # timeout check + if elapsed > timeout_seconds: + return { + "message": "timeout", + "data": { + "running_tasks": running_count, + "waited_seconds": round(elapsed, 3), + "timed_out": True, + "user_name": user_name, + }, + } + + time.sleep(poll_interval) + + except Exception as err: + logger.error("Failed while waiting for scheduler: %s", traceback.format_exc()) + raise HTTPException(status_code=500, detail="Failed while waiting for scheduler") from err + + +def handle_scheduler_wait_stream( + user_name: str, + timeout_seconds: float = 120.0, + poll_interval: float = 0.2, + mem_scheduler: Any | None = None, + instance_id: str = "", +) -> StreamingResponse: + """ + Stream scheduler progress via Server-Sent Events (SSE). + + Emits periodic heartbeat frames while tasks are running, then final + status frame indicating idle or timeout. + + Args: + user_name: User name to monitor + timeout_seconds: Maximum stream duration in seconds + poll_interval: Polling interval between updates + mem_scheduler: Scheduler instance + instance_id: Instance ID for response + + Returns: + StreamingResponse with SSE formatted progress updates + + Example: + curl -N "http://localhost:8000/product/scheduler/wait/stream?timeout_seconds=10" + """ + + def event_generator(): + start = time.time() + try: + while True: + running = mem_scheduler.dispatcher.get_running_tasks( + lambda task: task.mem_cube_id == user_name + ) + running_count = len(running) + elapsed = time.time() - start + + payload = { + "user_name": user_name, + "running_tasks": running_count, + "elapsed_seconds": round(elapsed, 3), + "status": "running" if running_count > 0 else "idle", + "instance_id": instance_id, + } + yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n" + + if running_count == 0 or elapsed > timeout_seconds: + payload["status"] = "idle" if running_count == 0 else "timeout" + payload["timed_out"] = running_count > 0 + yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n" + break + + time.sleep(poll_interval) + + except Exception as e: + err_payload = { + "status": "error", + "detail": "stream_failed", + "exception": str(e), + "user_name": user_name, + } + logger.error(f"Scheduler stream error for {user_name}: {traceback.format_exc()}") + yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n" + + return StreamingResponse(event_generator(), media_type="text/event-stream") diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py new file mode 100644 index 000000000..9fc8a5b28 --- /dev/null +++ b/src/memos/api/handlers/search_handler.py @@ -0,0 +1,289 @@ +""" +Search handler for memory search functionality (Class-based version). + +This module provides a class-based implementation of search handlers, +using dependency injection for better modularity and testability. +""" + +import os +import traceback + +from typing import Any + +from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies +from memos.api.handlers.formatters_handler import ( + format_memory_item, + post_process_pref_mem, +) +from memos.api.product_models import APISearchRequest, SearchResponse +from memos.context.context import ContextThreadPoolExecutor +from memos.mem_scheduler.schemas.general_schemas import SearchMode +from memos.types import MOSSearchResult, UserContext + + +class SearchHandler(BaseHandler): + """ + Handler for memory search operations. + + Provides fast, fine-grained, and mixture-based search modes. + """ + + def __init__(self, dependencies: HandlerDependencies): + """ + Initialize search handler. + + Args: + dependencies: HandlerDependencies instance + """ + super().__init__(dependencies) + self._validate_dependencies("naive_mem_cube", "mem_scheduler") + + def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: + """ + Main handler for search memories endpoint. + + Orchestrates the search process based on the requested search mode, + supporting both text and preference memory searches. + + Args: + search_req: Search request containing query and parameters + + Returns: + SearchResponse with formatted results + """ + # Create UserContext object + user_context = UserContext( + user_id=search_req.user_id, + mem_cube_id=search_req.mem_cube_id, + session_id=search_req.session_id or "default_session", + ) + self.logger.info(f"Search Req is: {search_req}") + + memories_result: MOSSearchResult = { + "text_mem": [], + "act_mem": [], + "para_mem": [], + "pref_mem": [], + "pref_note": "", + } + + # Determine search mode + search_mode = self._get_search_mode(search_req.mode) + + # Execute search in parallel for text and preference memories + with ContextThreadPoolExecutor(max_workers=2) as executor: + text_future = executor.submit(self._search_text, search_req, user_context, search_mode) + pref_future = executor.submit(self._search_pref, search_req, user_context) + + text_formatted_memories = text_future.result() + pref_formatted_memories = pref_future.result() + + # Build result + memories_result["text_mem"].append( + { + "cube_id": search_req.mem_cube_id, + "memories": text_formatted_memories, + } + ) + + memories_result = post_process_pref_mem( + memories_result, + pref_formatted_memories, + search_req.mem_cube_id, + search_req.include_preference, + ) + + self.logger.info(f"Search memories result: {memories_result}") + + return SearchResponse( + message="Search completed successfully", + data=memories_result, + ) + + def _get_search_mode(self, mode: str) -> str: + """ + Get search mode with environment variable fallback. + + Args: + mode: Requested search mode + + Returns: + Search mode string + """ + if mode == SearchMode.NOT_INITIALIZED: + return os.getenv("SEARCH_MODE", SearchMode.FAST) + return mode + + def _search_text( + self, + search_req: APISearchRequest, + user_context: UserContext, + search_mode: str, + ) -> list[dict[str, Any]]: + """ + Search text memories based on mode. + + Args: + search_req: Search request + user_context: User context + search_mode: Search mode (FAST, FINE, or MIXTURE) + + Returns: + List of formatted memory items + """ + try: + if search_mode == SearchMode.FAST: + memories = self._fast_search(search_req, user_context) + elif search_mode == SearchMode.FINE: + memories = self._fine_search(search_req, user_context) + elif search_mode == SearchMode.MIXTURE: + memories = self._mix_search(search_req, user_context) + else: + self.logger.error(f"Unsupported search mode: {search_mode}") + return [] + + return [format_memory_item(data) for data in memories] + + except Exception as e: + self.logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc()) + return [] + + def _search_pref( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list[dict[str, Any]]: + """ + Search preference memories. + + Args: + search_req: Search request + user_context: User context + + Returns: + List of formatted preference memory items + """ + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": + return [] + + try: + results = self.naive_mem_cube.pref_mem.search( + query=search_req.query, + top_k=search_req.pref_top_k, + info={ + "user_id": search_req.user_id, + "session_id": search_req.session_id, + "chat_history": search_req.chat_history, + }, + ) + return [format_memory_item(data) for data in results] + except Exception as e: + self.logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) + return [] + + def _fast_search( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list: + """ + Fast search using vector database. + + Args: + search_req: Search request + user_context: User context + + Returns: + List of search results + """ + target_session_id = search_req.session_id or "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + return self.naive_mem_cube.text_mem.search( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=SearchMode.FAST, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info={ + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + }, + ) + + def _fine_search( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list: + """ + Fine-grained search with query enhancement. + + Args: + search_req: Search request + user_context: User context + + Returns: + List of enhanced search results + """ + target_session_id = search_req.session_id or "default_session" + search_filter = {"session_id": search_req.session_id} if search_req.session_id else None + + searcher = self.mem_scheduler.searcher + + info = { + "user_id": search_req.user_id, + "session_id": target_session_id, + "chat_history": search_req.chat_history, + } + + # Fast retrieve + fast_retrieved_memories = searcher.retrieve( + query=search_req.query, + user_name=user_context.mem_cube_id, + top_k=search_req.top_k, + mode=SearchMode.FAST, + manual_close_internet=not search_req.internet_search, + moscube=search_req.moscube, + search_filter=search_filter, + info=info, + ) + + # Post retrieve + fast_memories = searcher.post_retrieve( + retrieved_results=fast_retrieved_memories, + top_k=search_req.top_k, + user_name=user_context.mem_cube_id, + info=info, + ) + + # Enhance with query + enhanced_results, _ = self.mem_scheduler.retriever.enhance_memories_with_query( + query_history=[search_req.query], + memories=fast_memories, + ) + + return enhanced_results + + def _mix_search( + self, + search_req: APISearchRequest, + user_context: UserContext, + ) -> list: + """ + Mix search combining fast and fine-grained approaches. + + Args: + search_req: Search request + user_context: User context + + Returns: + List of formatted search results + """ + return self.mem_scheduler.mix_search_memories( + search_req=search_req, + user_context=user_context, + ) diff --git a/src/memos/api/handlers/suggestion_handler.py b/src/memos/api/handlers/suggestion_handler.py new file mode 100644 index 000000000..dce894003 --- /dev/null +++ b/src/memos/api/handlers/suggestion_handler.py @@ -0,0 +1,117 @@ +""" +Suggestion handler for generating suggestion queries. + +This module handles suggestion query generation based on user's recent memories +or further suggestions from chat history. +""" + +import json + +from typing import Any + +from memos.api.product_models import SuggestionResponse +from memos.log import get_logger +from memos.mem_os.utils.format_utils import clean_json_response +from memos.templates.mos_prompts import ( + FURTHER_SUGGESTION_PROMPT, + SUGGESTION_QUERY_PROMPT_EN, + SUGGESTION_QUERY_PROMPT_ZH, +) +from memos.types import MessageList + + +logger = get_logger(__name__) + + +def _get_further_suggestion( + llm: Any, + message: MessageList, +) -> list[str]: + """ + Get further suggestion based on recent dialogue. + + Args: + llm: LLM instance for generating suggestions + message: Recent chat messages + + Returns: + List of suggestion queries + """ + try: + dialogue_info = "\n".join([f"{msg['role']}: {msg['content']}" for msg in message[-2:]]) + further_suggestion_prompt = FURTHER_SUGGESTION_PROMPT.format(dialogue=dialogue_info) + message_list = [{"role": "system", "content": further_suggestion_prompt}] + response = llm.generate(message_list) + clean_response = clean_json_response(response) + response_json = json.loads(clean_response) + return response_json["query"] + except Exception as e: + logger.error(f"Error getting further suggestion: {e}", exc_info=True) + return [] + + +def handle_get_suggestion_queries( + user_id: str, + language: str, + message: MessageList | None, + llm: Any, + naive_mem_cube: Any, +) -> SuggestionResponse: + """ + Main handler for suggestion queries endpoint. + + Generates suggestion queries based on user's recent memories or chat history. + + Args: + user_id: User ID + language: Language preference ("zh" or "en") + message: Optional chat message list for further suggestions + llm: LLM instance + naive_mem_cube: Memory cube instance + + Returns: + SuggestionResponse with generated queries + """ + try: + # If message is provided, get further suggestions based on dialogue + if message: + suggestions = _get_further_suggestion(llm, message) + return SuggestionResponse( + message="Suggestions retrieved successfully", + data={"query": suggestions}, + ) + + # Otherwise, generate suggestions based on recent memories + if language == "zh": + suggestion_prompt = SUGGESTION_QUERY_PROMPT_ZH + else: # English + suggestion_prompt = SUGGESTION_QUERY_PROMPT_EN + + # Search for recent memories + text_mem_results = naive_mem_cube.text_mem.search( + query="my recently memories", + user_name=user_id, + top_k=3, + mode="fast", + info={"user_id": user_id}, + ) + + # Extract memory content + memories = "" + if text_mem_results: + memories = "\n".join([m.memory[:200] for m in text_mem_results]) + + # Generate suggestions using LLM + message_list = [{"role": "system", "content": suggestion_prompt.format(memories=memories)}] + response = llm.generate(message_list) + clean_response = clean_json_response(response) + response_json = json.loads(clean_response) + + return SuggestionResponse( + message="Suggestions retrieved successfully", + data={"query": response_json["query"]}, + ) + + except Exception as e: + logger.error(f"Failed to get suggestions: {e}", exc_info=True) + raise diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 3b1ce2fc9..892d2d436 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -200,6 +200,9 @@ class APIADDRequest(BaseRequest): operation: list[PermissionDict] | None = Field( None, description="operation ids for multi cubes" ) + async_mode: Literal["async", "sync"] = Field( + "async", description="Whether to add memory in async mode" + ) class APIChatCompleteRequest(BaseRequest): @@ -223,6 +226,7 @@ class SuggestionRequest(BaseRequest): """Request model for getting suggestion queries.""" user_id: str = Field(..., description="User ID") + mem_cube_id: str = Field(..., description="Cube ID") language: Literal["zh", "en"] = Field("zh", description="Language for suggestions") message: list[MessageDict] | None = Field(None, description="List of messages to store.") diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index 7d9f141dc..d43f9ccdc 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -1,880 +1,220 @@ -import json +""" +Server API Router for MemOS (Class-based handlers version). + +This router demonstrates the improved architecture using class-based handlers +with dependency injection, providing better modularity and maintainability. + +Comparison with function-based approach: +- Cleaner code: No need to pass dependencies in every endpoint +- Better testability: Easy to mock handler dependencies +- Improved extensibility: Add new handlers or modify existing ones easily +- Clear separation of concerns: Router focuses on routing, handlers handle business logic +""" + import os import random as _random import socket -import time -import traceback -from collections.abc import Iterable -from datetime import datetime -from typing import TYPE_CHECKING, Any +from fastapi import APIRouter -from fastapi import APIRouter, HTTPException -from fastapi.responses import StreamingResponse - -from memos.api.config import APIConfig +from memos.api import handlers +from memos.api.handlers.add_handler import AddHandler +from memos.api.handlers.base_handler import HandlerDependencies +from memos.api.handlers.chat_handler import ChatHandler +from memos.api.handlers.search_handler import SearchHandler from memos.api.product_models import ( APIADDRequest, APIChatCompleteRequest, APISearchRequest, + ChatRequest, + GetMemoryRequest, MemoryResponse, SearchResponse, + SuggestionRequest, + SuggestionResponse, ) -from memos.configs.embedder import EmbedderConfigFactory -from memos.configs.graph_db import GraphDBConfigFactory -from memos.configs.internet_retriever import InternetRetrieverConfigFactory -from memos.configs.llm import LLMConfigFactory -from memos.configs.mem_reader import MemReaderConfigFactory -from memos.configs.mem_scheduler import SchedulerConfigFactory -from memos.configs.reranker import RerankerConfigFactory -from memos.configs.vec_db import VectorDBConfigFactory -from memos.context.context import ContextThreadPoolExecutor -from memos.embedders.factory import EmbedderFactory -from memos.graph_dbs.factory import GraphStoreFactory -from memos.llms.factory import LLMFactory from memos.log import get_logger -from memos.mem_cube.navie import NaiveMemCube -from memos.mem_os.product_server import MOSServer -from memos.mem_reader.factory import MemReaderFactory -from memos.mem_scheduler.orm_modules.base_model import BaseDBManager -from memos.mem_scheduler.scheduler_factory import SchedulerFactory -from memos.mem_scheduler.schemas.general_schemas import ( - ADD_LABEL, - MEM_READ_LABEL, - PREF_ADD_LABEL, - SearchMode, -) -from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem -from memos.memories.textual.prefer_text_memory.config import ( - AdderConfigFactory, - ExtractorConfigFactory, - RetrieverConfigFactory, -) -from memos.memories.textual.prefer_text_memory.factory import ( - AdderFactory, - ExtractorFactory, - RetrieverFactory, -) -from memos.memories.textual.simple_preference import SimplePreferenceTextMemory -from memos.memories.textual.simple_tree import SimpleTreeTextMemory -from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager -from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( - InternetRetrieverFactory, -) -from memos.reranker.factory import RerankerFactory -from memos.templates.instruction_completion import instruct_completion - - -if TYPE_CHECKING: - from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler -from memos.types import MOSSearchResult, UserContext -from memos.vec_dbs.factory import VecDBFactory logger = get_logger(__name__) router = APIRouter(prefix="/product", tags=["Server API"]) -INSTANCE_ID = f"{socket.gethostname()}:{os.getpid()}:{_random.randint(1000, 9999)}" - - -def _to_iter(running: Any) -> Iterable: - """Normalize running tasks to an iterable of task objects.""" - if running is None: - return [] - if isinstance(running, dict): - return running.values() - return running # assume it's already an iterable (e.g., list) - - -def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]: - """Build graph database configuration.""" - graph_db_backend_map = { - "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id), - "neo4j": APIConfig.get_neo4j_config(user_id=user_id), - "nebular": APIConfig.get_nebular_config(user_id=user_id), - "polardb": APIConfig.get_polardb_config(user_id=user_id), - } - - graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower() - return GraphDBConfigFactory.model_validate( - { - "backend": graph_db_backend, - "config": graph_db_backend_map[graph_db_backend], - } - ) - - -def _build_vec_db_config() -> dict[str, Any]: - """Build vector database configuration.""" - return VectorDBConfigFactory.model_validate( - { - "backend": "milvus", - "config": APIConfig.get_milvus_config(), - } - ) - - -def _build_llm_config() -> dict[str, Any]: - """Build LLM configuration.""" - return LLMConfigFactory.model_validate( - { - "backend": "openai", - "config": APIConfig.get_openai_config(), - } - ) - - -def _build_embedder_config() -> dict[str, Any]: - """Build embedder configuration.""" - return EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config()) - - -def _build_mem_reader_config() -> dict[str, Any]: - """Build memory reader configuration.""" - return MemReaderConfigFactory.model_validate( - APIConfig.get_product_default_config()["mem_reader"] - ) - - -def _build_reranker_config() -> dict[str, Any]: - """Build reranker configuration.""" - return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config()) - - -def _build_internet_retriever_config() -> dict[str, Any]: - """Build internet retriever configuration.""" - return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config()) - - -def _build_pref_extractor_config() -> dict[str, Any]: - """Build extractor configuration.""" - return ExtractorConfigFactory.model_validate({"backend": "naive", "config": {}}) +# Instance ID for identifying this server instance in logs and responses +INSTANCE_ID = f"{socket.gethostname()}:{os.getpid()}:{_random.randint(1000, 9999)}" -def _build_pref_adder_config() -> dict[str, Any]: - """Build adder configuration.""" - return AdderConfigFactory.model_validate({"backend": "naive", "config": {}}) +# Initialize all server components +components = handlers.init_server() +# Create dependency container +dependencies = HandlerDependencies.from_init_server(components) -def _build_pref_retriever_config() -> dict[str, Any]: - """Build retriever configuration.""" - return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}}) +# Initialize all handlers with dependency injection +search_handler = SearchHandler(dependencies) +add_handler = AddHandler(dependencies) +chat_handler = ChatHandler( + dependencies, search_handler, add_handler, online_bot=components.get("online_bot") +) +# Extract commonly used components for function-based handlers +# (These can be accessed from the components dict without unpacking all of them) +mem_scheduler = components["mem_scheduler"] +llm = components["llm"] +naive_mem_cube = components["naive_mem_cube"] -def _get_default_memory_size(cube_config) -> dict[str, int]: - """Get default memory size configuration.""" - return getattr(cube_config.text_mem.config, "memory_size", None) or { - "WorkingMemory": 20, - "LongTermMemory": 1500, - "UserMemory": 480, - } +# ============================================================================= +# Search API Endpoints +# ============================================================================= -def init_server(): - """Initialize server components and configurations.""" - # Get default cube configuration - default_cube_config = APIConfig.get_default_cube_config() - # Build component configurations - graph_db_config = _build_graph_db_config() - llm_config = _build_llm_config() - embedder_config = _build_embedder_config() - mem_reader_config = _build_mem_reader_config() - reranker_config = _build_reranker_config() - internet_retriever_config = _build_internet_retriever_config() - vector_db_config = _build_vec_db_config() - pref_extractor_config = _build_pref_extractor_config() - pref_adder_config = _build_pref_adder_config() - pref_retriever_config = _build_pref_retriever_config() +@router.post("/search", summary="Search memories", response_model=SearchResponse) +def search_memories(search_req: APISearchRequest): + """ + Search memories for a specific user. - # Create component instances - graph_db = GraphStoreFactory.from_config(graph_db_config) - vector_db = VecDBFactory.from_config(vector_db_config) - llm = LLMFactory.from_config(llm_config) - embedder = EmbedderFactory.from_config(embedder_config) - mem_reader = MemReaderFactory.from_config(mem_reader_config) - reranker = RerankerFactory.from_config(reranker_config) - internet_retriever = InternetRetrieverFactory.from_config( - internet_retriever_config, embedder=embedder - ) + This endpoint uses the class-based SearchHandler for better code organization. + """ + return search_handler.handle_search_memories(search_req) - # Initialize memory manager - memory_manager = MemoryManager( - graph_db, - embedder, - llm, - memory_size=_get_default_memory_size(default_cube_config), - is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False), - ) - # Initialize text memory - text_mem = SimpleTreeTextMemory( - llm=llm, - embedder=embedder, - mem_reader=mem_reader, - graph_db=graph_db, - reranker=reranker, - memory_manager=memory_manager, - config=default_cube_config.text_mem.config, - internet_retriever=internet_retriever, - ) +# ============================================================================= +# Add API Endpoints +# ============================================================================= - pref_extractor = ExtractorFactory.from_config( - config_factory=pref_extractor_config, - llm_provider=llm, - embedder=embedder, - vector_db=vector_db, - ) - pref_adder = AdderFactory.from_config( - config_factory=pref_adder_config, - llm_provider=llm, - embedder=embedder, - vector_db=vector_db, - text_mem=text_mem, - ) +@router.post("/add", summary="Add memories", response_model=MemoryResponse) +def add_memories(add_req: APIADDRequest): + """ + Add memories for a specific user. - pref_retriever = RetrieverFactory.from_config( - config_factory=pref_retriever_config, - llm_provider=llm, - embedder=embedder, - reranker=reranker, - vector_db=vector_db, - ) + This endpoint uses the class-based AddHandler for better code organization. + """ + return add_handler.handle_add_memories(add_req) - # Initialize preference memory - pref_mem = SimplePreferenceTextMemory( - extractor_llm=llm, - vector_db=vector_db, - embedder=embedder, - reranker=reranker, - extractor=pref_extractor, - adder=pref_adder, - retriever=pref_retriever, - ) - mos_server = MOSServer( - mem_reader=mem_reader, - llm=llm, - online_bot=False, - ) +# ============================================================================= +# Scheduler API Endpoints +# ============================================================================= - # Create MemCube with pre-initialized memory instances - naive_mem_cube = NaiveMemCube( - text_mem=text_mem, - pref_mem=pref_mem, - act_mem=None, - para_mem=None, - ) - # Initialize Scheduler - scheduler_config_dict = APIConfig.get_scheduler_config() - scheduler_config = SchedulerConfigFactory( - backend="optimized_scheduler", config=scheduler_config_dict - ) - mem_scheduler: OptimizedScheduler = SchedulerFactory.from_config(scheduler_config) - mem_scheduler.initialize_modules( - chat_llm=llm, - process_llm=mem_reader.llm, - db_engine=BaseDBManager.create_default_sqlite_engine(), - mem_reader=mem_reader, - ) - mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube) - - # Initialize SchedulerAPIModule - api_module = mem_scheduler.api_module - - if os.getenv("API_SCHEDULER_ON", True): - mem_scheduler.start() - - return ( - graph_db, - mem_reader, - llm, - embedder, - reranker, - internet_retriever, - memory_manager, - default_cube_config, - mos_server, - mem_scheduler, - naive_mem_cube, - api_module, - vector_db, - pref_extractor, - pref_adder, - pref_retriever, - text_mem, - pref_mem, +@router.get("/scheduler/status", summary="Get scheduler running status") +def scheduler_status(user_name: str | None = None): + """Get scheduler running status.""" + return handlers.scheduler_handler.handle_scheduler_status( + user_name=user_name, + mem_scheduler=mem_scheduler, + instance_id=INSTANCE_ID, ) -# Initialize global components -( - graph_db, - mem_reader, - llm, - embedder, - reranker, - internet_retriever, - memory_manager, - default_cube_config, - mos_server, - mem_scheduler, - naive_mem_cube, - api_module, - vector_db, - pref_extractor, - pref_adder, - pref_retriever, - text_mem, - pref_mem, -) = init_server() - - -def _format_memory_item(memory_data: Any) -> dict[str, Any]: - """Format a single memory item for API response.""" - memory = memory_data.model_dump() - memory_id = memory["id"] - ref_id = f"[{memory_id.split('-')[0]}]" - - memory["ref_id"] = ref_id - memory["metadata"]["embedding"] = [] - memory["metadata"]["sources"] = [] - memory["metadata"]["usage"] = [] - memory["metadata"]["ref_id"] = ref_id - memory["metadata"]["id"] = memory_id - memory["metadata"]["memory"] = memory["memory"] - - return memory - - -def _post_process_pref_mem( - memories_result: list[dict[str, Any]], - pref_formatted_mem: list[dict[str, Any]], - mem_cube_id: str, - include_preference: bool, +@router.post("/scheduler/wait", summary="Wait until scheduler is idle for a specific user") +def scheduler_wait( + user_name: str, + timeout_seconds: float = 120.0, + poll_interval: float = 0.2, ): - if include_preference: - memories_result["pref_mem"].append( - { - "cube_id": mem_cube_id, - "memories": pref_formatted_mem, - } - ) - pref_instruction, pref_note = instruct_completion(pref_formatted_mem) - memories_result["pref_string"] = pref_instruction - memories_result["pref_note"] = pref_note - - return memories_result - - -@router.post("/search", summary="Search memories", response_model=SearchResponse) -def search_memories(search_req: APISearchRequest): - """Search memories for a specific user.""" - # Create UserContext object - how to assign values - user_context = UserContext( - user_id=search_req.user_id, - mem_cube_id=search_req.mem_cube_id, - session_id=search_req.session_id or "default_session", - ) - logger.info(f"Search Req is: {search_req}") - memories_result: MOSSearchResult = { - "text_mem": [], - "act_mem": [], - "para_mem": [], - "pref_mem": [], - "pref_note": "", - } - if search_req.mode == SearchMode.NOT_INITIALIZED: - search_mode = os.getenv("SEARCH_MODE", SearchMode.FAST) - else: - search_mode = search_req.mode - - def _search_text(): - try: - if search_mode == SearchMode.FAST: - formatted_memories = fast_search_memories( - search_req=search_req, user_context=user_context - ) - elif search_mode == SearchMode.FINE: - formatted_memories = fine_search_memories( - search_req=search_req, user_context=user_context - ) - elif search_mode == SearchMode.MIXTURE: - formatted_memories = mix_search_memories( - search_req=search_req, user_context=user_context - ) - else: - logger.error(f"Unsupported search mode: {search_mode}") - raise HTTPException( - status_code=400, detail=f"Unsupported search mode: {search_mode}" - ) - return formatted_memories - except Exception as e: - logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc()) - return [] - - def _search_pref(): - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - try: - results = naive_mem_cube.pref_mem.search( - query=search_req.query, - top_k=search_req.pref_top_k, - info={ - "user_id": search_req.user_id, - "session_id": search_req.session_id, - "chat_history": search_req.chat_history, - }, - ) - return [_format_memory_item(data) for data in results] - except Exception as e: - logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc()) - return [] - - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(_search_text) - pref_future = executor.submit(_search_pref) - text_formatted_memories = text_future.result() - pref_formatted_memories = pref_future.result() - - memories_result["text_mem"].append( - { - "cube_id": search_req.mem_cube_id, - "memories": text_formatted_memories, - } - ) - - memories_result = _post_process_pref_mem( - memories_result, - pref_formatted_memories, - search_req.mem_cube_id, - search_req.include_preference, - ) - - logger.info(f"Search memories result: {memories_result}") - - return SearchResponse( - message="Search completed successfully", - data=memories_result, + """Wait until scheduler is idle for a specific user.""" + return handlers.scheduler_handler.handle_scheduler_wait( + user_name=user_name, + timeout_seconds=timeout_seconds, + poll_interval=poll_interval, + mem_scheduler=mem_scheduler, ) -def mix_search_memories( - search_req: APISearchRequest, - user_context: UserContext, +@router.get("/scheduler/wait/stream", summary="Stream scheduler progress for a user") +def scheduler_wait_stream( + user_name: str, + timeout_seconds: float = 120.0, + poll_interval: float = 0.2, ): - """ - Mix search memories: fast search + async fine search - """ - - formatted_memories = mem_scheduler.mix_search_memories( - search_req=search_req, - user_context=user_context, + """Stream scheduler progress via Server-Sent Events (SSE).""" + return handlers.scheduler_handler.handle_scheduler_wait_stream( + user_name=user_name, + timeout_seconds=timeout_seconds, + poll_interval=poll_interval, + mem_scheduler=mem_scheduler, + instance_id=INSTANCE_ID, ) - return formatted_memories -def fine_search_memories( - search_req: APISearchRequest, - user_context: UserContext, -): - target_session_id = search_req.session_id - if not target_session_id: - target_session_id = "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - - searcher = mem_scheduler.searcher - - info = { - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - } - - fast_retrieved_memories = searcher.retrieve( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, - mode=SearchMode.FAST, - manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, - search_filter=search_filter, - info=info, - ) +# ============================================================================= +# Chat API Endpoints +# ============================================================================= - fast_memories = searcher.post_retrieve( - retrieved_results=fast_retrieved_memories, - top_k=search_req.top_k, - user_name=user_context.mem_cube_id, - info=info, - ) - enhanced_results, _ = mem_scheduler.retriever.enhance_memories_with_query( - query_history=[search_req.query], - memories=fast_memories, - ) +@router.post("/chat/complete", summary="Chat with MemOS (Complete Response)") +def chat_complete(chat_req: APIChatCompleteRequest): + """ + Chat with MemOS for a specific user. Returns complete response (non-streaming). - formatted_memories = [_format_memory_item(data) for data in enhanced_results] + This endpoint uses the class-based ChatHandler. + """ + return chat_handler.handle_chat_complete(chat_req) - return formatted_memories +@router.post("/chat", summary="Chat with MemOS") +def chat(chat_req: ChatRequest): + """ + Chat with MemOS for a specific user. Returns SSE stream. -def fast_search_memories( - search_req: APISearchRequest, - user_context: UserContext, -): - target_session_id = search_req.session_id - if not target_session_id: - target_session_id = "default_session" - search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - - # Create MemCube and perform search - search_results = naive_mem_cube.text_mem.search( - query=search_req.query, - user_name=user_context.mem_cube_id, - top_k=search_req.top_k, - mode=SearchMode.FAST, - manual_close_internet=not search_req.internet_search, - moscube=search_req.moscube, - search_filter=search_filter, - info={ - "user_id": search_req.user_id, - "session_id": target_session_id, - "chat_history": search_req.chat_history, - }, - ) - formatted_memories = [_format_memory_item(data) for data in search_results] + This endpoint uses the class-based ChatHandler which internally + composes SearchHandler and AddHandler for a clean architecture. + """ + return chat_handler.handle_chat_stream(chat_req) - return formatted_memories +# ============================================================================= +# Suggestion API Endpoints +# ============================================================================= -@router.post("/add", summary="Add memories", response_model=MemoryResponse) -def add_memories(add_req: APIADDRequest): - """Add memories for a specific user.""" - # Create UserContext object - how to assign values - user_context = UserContext( - user_id=add_req.user_id, - mem_cube_id=add_req.mem_cube_id, - session_id=add_req.session_id or "default_session", - ) - logger.info(f"Add Req is: {add_req}") - - target_session_id = add_req.session_id - if not target_session_id: - target_session_id = "default_session" - - # If text memory backend works in async mode, submit tasks to scheduler - try: - sync_mode = getattr(naive_mem_cube.text_mem, "mode", "sync") - except Exception: - sync_mode = "sync" - logger.info(f"Add sync_mode mode is: {sync_mode}") - - def _process_text_mem() -> list[dict[str, str]]: - memories_local = mem_reader.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - }, - mode="fast" if sync_mode == "async" else "fine", - ) - flattened_local = [mm for m in memories_local for mm in m] - logger.info(f"Memory extraction completed for user {add_req.user_id}") - mem_ids_local: list[str] = naive_mem_cube.text_mem.add( - flattened_local, - user_name=user_context.mem_cube_id, - ) - logger.info( - f"Added {len(mem_ids_local)} memories for user {add_req.user_id} " - f"in session {add_req.session_id}: {mem_ids_local}" - ) - if sync_mode == "async": - try: - message_item_read = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=naive_mem_cube, - label=MEM_READ_LABEL, - content=json.dumps(mem_ids_local), - timestamp=datetime.utcnow(), - user_name=add_req.mem_cube_id, - ) - mem_scheduler.submit_messages(messages=[message_item_read]) - logger.info(f"2105Submit messages!!!!!: {json.dumps(mem_ids_local)}") - except Exception as e: - logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True) - else: - message_item_add = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=naive_mem_cube, - label=ADD_LABEL, - content=json.dumps(mem_ids_local), - timestamp=datetime.utcnow(), - user_name=add_req.mem_cube_id, - ) - mem_scheduler.submit_messages(messages=[message_item_add]) - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.memory_type, - } - for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False) - ] - - def _process_pref_mem() -> list[dict[str, str]]: - if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true": - return [] - # Follow async behavior similar to core.py: enqueue when async - if sync_mode == "async": - try: - messages_list = [add_req.messages] - message_item_pref = ScheduleMessageItem( - user_id=add_req.user_id, - session_id=target_session_id, - mem_cube_id=add_req.mem_cube_id, - mem_cube=naive_mem_cube, - label=PREF_ADD_LABEL, - content=json.dumps(messages_list), - timestamp=datetime.utcnow(), - ) - mem_scheduler.submit_messages(messages=[message_item_pref]) - logger.info("Submitted preference add to scheduler (async mode)") - except Exception as e: - logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True) - return [] - else: - pref_memories_local = naive_mem_cube.pref_mem.get_memory( - [add_req.messages], - type="chat", - info={ - "user_id": add_req.user_id, - "session_id": target_session_id, - "mem_cube_id": add_req.mem_cube_id, - }, - ) - pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local) - logger.info( - f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} " - f"in session {add_req.session_id}: {pref_ids_local}" - ) - return [ - { - "memory": memory.memory, - "memory_id": memory_id, - "memory_type": memory.metadata.preference_type, - } - for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False) - ] - - with ContextThreadPoolExecutor(max_workers=2) as executor: - text_future = executor.submit(_process_text_mem) - pref_future = executor.submit(_process_pref_mem) - text_response_data = text_future.result() - pref_response_data = pref_future.result() - - logger.info(f"add_memories Text response data: {text_response_data}") - logger.info(f"add_memories Pref response data: {pref_response_data}") - - return MemoryResponse( - message="Memory added successfully", - data=text_response_data + pref_response_data, +@router.post( + "/suggestions", + summary="Get suggestion queries", + response_model=SuggestionResponse, +) +def get_suggestion_queries(suggestion_req: SuggestionRequest): + """Get suggestion queries for a specific user with language preference.""" + return handlers.suggestion_handler.handle_get_suggestion_queries( + user_id=suggestion_req.mem_cube_id, + language=suggestion_req.language, + message=suggestion_req.message, + llm=llm, + naive_mem_cube=naive_mem_cube, ) -@router.get("/scheduler/status", summary="Get scheduler running status") -def scheduler_status(user_name: str | None = None): - try: - if user_name: - running = mem_scheduler.dispatcher.get_running_tasks( - lambda task: getattr(task, "mem_cube_id", None) == user_name - ) - tasks_iter = list(_to_iter(running)) - running_count = len(tasks_iter) - return { - "message": "ok", - "data": { - "scope": "user", - "user_name": user_name, - "running_tasks": running_count, - "timestamp": time.time(), - "instance_id": INSTANCE_ID, - }, - } - else: - running_all = mem_scheduler.dispatcher.get_running_tasks(lambda _t: True) - tasks_iter = list(_to_iter(running_all)) - running_count = len(tasks_iter) - - task_count_per_user: dict[str, int] = {} - for task in tasks_iter: - cube = getattr(task, "mem_cube_id", "unknown") - task_count_per_user[cube] = task_count_per_user.get(cube, 0) + 1 - - try: - metrics_snapshot = mem_scheduler.dispatcher.metrics.snapshot() - except Exception: - metrics_snapshot = {} - - return { - "message": "ok", - "data": { - "scope": "global", - "running_tasks": running_count, - "task_count_per_user": task_count_per_user, - "timestamp": time.time(), - "instance_id": INSTANCE_ID, - "metrics": metrics_snapshot, - }, - } - except Exception as err: - logger.error("Failed to get scheduler status: %s", traceback.format_exc()) - raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err - - -@router.post("/scheduler/wait", summary="Wait until scheduler is idle for a specific user") -def scheduler_wait( - user_name: str, - timeout_seconds: float = 120.0, - poll_interval: float = 0.2, -): - """ - Block until scheduler has no running tasks for the given user_name, or timeout. - """ - start = time.time() - try: - while True: - running = mem_scheduler.dispatcher.get_running_tasks( - lambda task: task.mem_cube_id == user_name - ) - running_count = len(running) - elapsed = time.time() - start - - # success -> scheduler is idle - if running_count == 0: - return { - "message": "idle", - "data": { - "running_tasks": 0, - "waited_seconds": round(elapsed, 3), - "timed_out": False, - "user_name": user_name, - }, - } - - # timeout check - if elapsed > timeout_seconds: - return { - "message": "timeout", - "data": { - "running_tasks": running_count, - "waited_seconds": round(elapsed, 3), - "timed_out": True, - "user_name": user_name, - }, - } - - time.sleep(poll_interval) - - except Exception as err: - logger.error("Failed while waiting for scheduler: %s", traceback.format_exc()) - raise HTTPException(status_code=500, detail="Failed while waiting for scheduler") from err +# ============================================================================= +# Memory Retrieval API Endpoints +# ============================================================================= -@router.get("/scheduler/wait/stream", summary="Stream scheduler progress for a user") -def scheduler_wait_stream( - user_name: str, - timeout_seconds: float = 120.0, - poll_interval: float = 0.2, -): +@router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse) +def get_all_memories(memory_req: GetMemoryRequest): """ - Stream scheduler progress via Server-Sent Events (SSE). - - Contract: - - We emit periodic heartbeat frames while tasks are still running. - - Each heartbeat frame is JSON, prefixed with "data: ". - - On final frame, we include status = "idle" or "timeout" and timed_out flag, - with the same semantics as /scheduler/wait. + Get all memories or subgraph for a specific user. - Example curl: - curl -N "${API_HOST}/product/scheduler/wait/stream?timeout_seconds=10&poll_interval=0.5" + If search_query is provided, returns a subgraph based on the query. + Otherwise, returns all memories of the specified type. """ - - def event_generator(): - start = time.time() - try: - while True: - running = mem_scheduler.dispatcher.get_running_tasks( - lambda task: task.mem_cube_id == user_name - ) - running_count = len(running) - elapsed = time.time() - start - - payload = { - "user_name": user_name, - "running_tasks": running_count, - "elapsed_seconds": round(elapsed, 3), - "status": "running" if running_count > 0 else "idle", - "instance_id": INSTANCE_ID, - } - yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n" - - if running_count == 0 or elapsed > timeout_seconds: - payload["status"] = "idle" if running_count == 0 else "timeout" - payload["timed_out"] = running_count > 0 - yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n" - break - - time.sleep(poll_interval) - - except Exception as e: - err_payload = { - "status": "error", - "detail": "stream_failed", - "exception": str(e), - "user_name": user_name, - } - logger.error(f"Scheduler stream error for {user_name}: {traceback.format_exc()}") - yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n" - - return StreamingResponse(event_generator(), media_type="text/event-stream") - - -@router.post("/chat/complete", summary="Chat with MemOS (Complete Response)") -def chat_complete(chat_req: APIChatCompleteRequest): - """Chat with MemOS for a specific user. Returns complete response (non-streaming).""" - try: - # Collect all responses from the generator - content, references = mos_server.chat( - query=chat_req.query, - user_id=chat_req.user_id, - cube_id=chat_req.mem_cube_id, - mem_cube=naive_mem_cube, - history=chat_req.history, - internet_search=chat_req.internet_search, - moscube=chat_req.moscube, - base_prompt=chat_req.base_prompt, - top_k=chat_req.top_k, - threshold=chat_req.threshold, - session_id=chat_req.session_id, + if memory_req.search_query: + return handlers.memory_handler.handle_get_subgraph( + user_id=memory_req.user_id, + mem_cube_id=( + memory_req.mem_cube_ids[0] if memory_req.mem_cube_ids else memory_req.user_id + ), + query=memory_req.search_query, + top_k=20, + naive_mem_cube=naive_mem_cube, + ) + else: + return handlers.memory_handler.handle_get_all_memories( + user_id=memory_req.user_id, + mem_cube_id=( + memory_req.mem_cube_ids[0] if memory_req.mem_cube_ids else memory_req.user_id + ), + memory_type=memory_req.memory_type or "text_mem", + naive_mem_cube=naive_mem_cube, ) - - # Return the complete response - return { - "message": "Chat completed successfully", - "data": {"response": content, "references": references}, - } - - except ValueError as err: - raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err - except Exception as err: - logger.error(f"Failed to start chat: {traceback.format_exc()}") - raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err diff --git a/src/memos/mem_os/utils/reference_utils.py b/src/memos/mem_os/utils/reference_utils.py index c2f4431c3..09b812207 100644 --- a/src/memos/mem_os/utils/reference_utils.py +++ b/src/memos/mem_os/utils/reference_utils.py @@ -142,12 +142,21 @@ def prepare_reference_data(memories_list: list[TextualMemoryItem]) -> list[dict] # Prepare reference data reference = [] for memories in memories_list: - memories_json = memories.model_dump() - memories_json["metadata"]["ref_id"] = f"{memories.id.split('-')[0]}" - memories_json["metadata"]["embedding"] = [] - memories_json["metadata"]["sources"] = [] - memories_json["metadata"]["memory"] = memories.memory - memories_json["metadata"]["id"] = memories.id - reference.append({"metadata": memories_json["metadata"]}) + if isinstance(memories, TextualMemoryItem): + memories_json = memories.model_dump() + memories_json["metadata"]["ref_id"] = f"{memories.id.split('-')[0]}" + memories_json["metadata"]["embedding"] = [] + memories_json["metadata"]["sources"] = [] + memories_json["metadata"]["memory"] = memories.memory + memories_json["metadata"]["id"] = memories.id + reference.append({"metadata": memories_json["metadata"]}) + else: + memories_json = memories + memories_json["metadata"]["ref_id"] = f"{memories_json['id'].split('-')[0]}" + memories_json["metadata"]["embedding"] = [] + memories_json["metadata"]["sources"] = [] + memories_json["metadata"]["memory"] = memories_json["memory"] + memories_json["metadata"]["id"] = memories_json["id"] + reference.append({"metadata": memories_json["metadata"]}) return reference diff --git a/src/memos/mem_scheduler/general_modules/base.py b/src/memos/mem_scheduler/general_modules/base.py index 392f2bde3..0b80b9e7d 100644 --- a/src/memos/mem_scheduler/general_modules/base.py +++ b/src/memos/mem_scheduler/general_modules/base.py @@ -51,7 +51,7 @@ def _build_system_prompt(self, memories: list | None = None) -> str: def get_mem_cube(self, mem_cube_id: str) -> GeneralMemCube: logger.error(f"mem_cube {mem_cube_id} does not exists.") - return self.mem_cubes.get(mem_cube_id, None) + return self.current_mem_cube @property def chat_llm(self) -> BaseLLM: diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 32fefce63..2b14887d6 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -53,9 +53,6 @@ def long_memory_update_process( ): mem_cube = self.current_mem_cube - # for status update - self._set_current_context_from_message(msg=messages[0]) - # update query monitors for msg in messages: self.monitor.register_query_monitor_if_not_exists( @@ -185,9 +182,6 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: if len(messages) == 0: return - # for status update - self._set_current_context_from_message(msg=messages[0]) - def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: logger.info(f"Messages {messages} assigned to {ADD_LABEL} handler.") # Process the query in a session turn @@ -201,9 +195,6 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: if len(messages) == 0: return - # for status update - self._set_current_context_from_message(msg=messages[0]) - # submit logs for msg in messages: try: diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index e2e0be69c..15a6a8b49 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -218,7 +218,12 @@ def search( ) def get_relevant_subgraph( - self, query: str, top_k: int = 5, depth: int = 2, center_status: str = "activated" + self, + query: str, + top_k: int = 5, + depth: int = 2, + center_status: str = "activated", + user_name: str | None = None, ) -> dict[str, Any]: """ Find and merge the local neighborhood sub-graphs of the top-k @@ -249,7 +254,9 @@ def get_relevant_subgraph( query_embedding = self.embedder.embed([query])[0] # Step 2: Get top-1 similar node - similar_nodes = self.graph_store.search_by_embedding(query_embedding, top_k=top_k) + similar_nodes = self.graph_store.search_by_embedding( + query_embedding, top_k=top_k, user_name=user_name + ) if not similar_nodes: logger.info("No similar nodes found for query embedding.") return {"core_id": None, "nodes": [], "edges": []} @@ -264,7 +271,7 @@ def get_relevant_subgraph( score = node["score"] subgraph = self.graph_store.get_subgraph( - center_id=core_id, depth=depth, center_status=center_status + center_id=core_id, depth=depth, center_status=center_status, user_name=user_name ) if subgraph is None or not subgraph["core_node"]: From fa0573928250533a4287d8a6e7d342d8144c233a Mon Sep 17 00:00:00 2001 From: HarveyXiang Date: Thu, 13 Nov 2025 21:17:20 +0800 Subject: [PATCH 09/18] Fix/no response (#490) * fix: response error * fix: response error * fix: response error * feat: replace context thread --------- Co-authored-by: harvey_xiang Co-authored-by: CaralHsi From e8de395b1c0ef1a37dfc665ed98f5f24557bad1b Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Fri, 14 Nov 2025 17:31:05 +0800 Subject: [PATCH 10/18] fix lack mem_cube_id bug in pref async (#494) Co-authored-by: yuan.wang --- src/memos/mem_os/core.py | 6 +++++- src/memos/mem_scheduler/general_scheduler.py | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 1b6d4e126..3b53cef1a 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -796,7 +796,11 @@ def process_preference_memory(): pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory( messages_list, type="chat", - info={"user_id": target_user_id, "session_id": self.session_id}, + info={ + "user_id": target_user_id, + "session_id": self.session_id, + "mem_cube_id": mem_cube_id, + }, ) pref_ids = self.mem_cubes[mem_cube_id].pref_mem.add(pref_memories) logger.info( diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 2b14887d6..6e916962e 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -521,7 +521,9 @@ def process_message(message: ScheduleMessageItem): # Use pref_mem.get_memory to process the memories pref_memories = pref_mem.get_memory( - messages_list, type="chat", info={"user_id": user_id, "session_id": session_id} + messages_list, + type="chat", + info={"user_id": user_id, "session_id": session_id, "mem_cube_id": mem_cube_id}, ) # Add pref_mem to vector db pref_ids = pref_mem.add(pref_memories) From e9f663a2bea658b5ca96526b4d25633bad4ccb9c Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Sat, 15 Nov 2025 13:59:12 +0800 Subject: [PATCH 11/18] Feat/fix explicit threshold (#495) * fix explicit pref threshold * modify 2.0 * change theeshold --------- Co-authored-by: yuan.wang --- src/memos/memories/textual/prefer_text_memory/retrievers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py index 9f0d1ab32..c3aa950e4 100644 --- a/src/memos/memories/textual/prefer_text_memory/retrievers.py +++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py @@ -136,7 +136,7 @@ def retrieve( # filter explicit mem by score bigger than threshold explicit_prefs_mem = [ - item for item in explicit_prefs_mem if explicit_id_scores.get(item.id, 0) >= 0.2 + item for item in explicit_prefs_mem if explicit_id_scores.get(item.id, 0) >= 0.0 ] return explicit_prefs_mem + implicit_prefs_mem From 7541827718d1d1759960001732dcb882a63b7d0b Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Mon, 17 Nov 2025 16:29:24 +0800 Subject: [PATCH 12/18] =?UTF-8?q?Feat=EF=BC=9Areorg=20playground=20code=20?= =?UTF-8?q?(#497)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: re org code * feat: code reorg and merge API and playground * feat: update memcube info * feat: remove act mem and params mem * feat: upadte init * code suffix * feat: update internet search mode --------- Co-authored-by: CaralHsi --- src/memos/api/handlers/chat_handler.py | 6 +++--- src/memos/api/handlers/search_handler.py | 2 +- .../memories/textual/tree_text_memory/retrieve/searcher.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index 9b0048ed4..f6023e5c8 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -116,7 +116,7 @@ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, An query=chat_req.query, top_k=chat_req.top_k or 10, session_id=chat_req.session_id, - mode=SearchMode.FINE, + mode=SearchMode.FAST, internet_search=chat_req.internet_search, moscube=chat_req.moscube, chat_history=chat_req.history, @@ -213,8 +213,8 @@ def generate_chat_response() -> Generator[str, None, None]: query=chat_req.query, top_k=20, session_id=chat_req.session_id, - mode=SearchMode.FINE, - internet_search=chat_req.internet_search, + mode=SearchMode.FINE if chat_req.internet_search else SearchMode.FAST, + internet_search=chat_req.internet_search, # TODO this param is not worked at fine mode moscube=chat_req.moscube, chat_history=chat_req.history, ) diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index 9fc8a5b28..e8e4e07d6 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -245,7 +245,7 @@ def _fine_search( query=search_req.query, user_name=user_context.mem_cube_id, top_k=search_req.top_k, - mode=SearchMode.FAST, + mode=SearchMode.FINE, manual_close_internet=not search_req.internet_search, moscube=search_req.moscube, search_filter=search_filter, diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index f196c5569..14ea8e2cb 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -182,7 +182,7 @@ def _parse_task( query_embedding = None # fine mode will trigger initial embedding search - if mode == "fine": + if mode == "fine_old": logger.info("[SEARCH] Fine mode: embedding search") query_embedding = self.embedder.embed([query])[0] From 934d00db33978ae1a28c04d0240acd440bad1c90 Mon Sep 17 00:00:00 2001 From: chunyu li <78344051+fridayL@users.noreply.github.com> Date: Tue, 18 Nov 2025 17:07:03 +0800 Subject: [PATCH 13/18] Feat: add redis_scheduler (#499) * debug an error function name * feat: Add DynamicCache compatibility for different transformers versions - Fix build_kv_cache method in hf.py to handle both old and new DynamicCache structures - Support new 'layers' attribute with key_cache/value_cache or keys/values - Maintain backward compatibility with direct key_cache/value_cache attributes - Add comprehensive error handling and logging for unsupported structures - Update move_dynamic_cache_htod function in kv.py for cross-version compatibility - Handle layers-based structure in newer transformers versions - Support alternative attribute names (keys/values vs key_cache/value_cache) - Preserve original functionality for older transformers versions - Add comprehensive tests for DynamicCache compatibility - Test activation memory update with mock DynamicCache layers - Verify layers attribute access across different transformers versions - Fix scheduler logger mock to include memory_manager attribute This resolves AttributeError issues when using different versions of the transformers library and ensures robust handling of DynamicCache objects. debug * feat: implement APIAnalyzerForScheduler for memory operations - Add APIAnalyzerForScheduler class with search/add operations - Support requests and http.client with connection reuse - Include comprehensive error handling and dynamic configuration - Add English test suite with real-world conversation scenarios * feat: Add search_ws API endpoint and enhance API analyzer functionality - Add search_ws endpoint in server_router.py for scheduler-enabled search - Fix missing imports: time module, SearchRequest class, and get_mos_product_instance function - Implement search_ws method in api_analyzer.py with HTTP client support - Add _search_ws_with_requests and _search_ws_with_http_client private methods - Include search_ws usage example in demonstration code - Enhance scheduler and dispatcher capabilities for improved memory management - Expand test coverage to ensure functionality stability This update primarily strengthens the memory scheduling system's search capabilities, providing users with more flexible API interface options. * fix: resolve test failures and warnings in test suite - Fix Pydantic serialization warning in test_memos_chen_tang_hello_world * Add warnings filter to suppress UserWarning from Pydantic serialization - Fix KeyError: 'past_key_values' in test_build_kv_cache_and_generation * Update mock configuration to properly return forward_output with past_key_values * Add DynamicCache version compatibility handling in test mocks * Support both old and new transformers versions with layers/key_cache attributes * Improve assertion logic to check all model calls for required parameters - Update base_scheduler.py to use centralized DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE constant * Add import for DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE from general_schemas * Replace hardcoded value 100 with configurable constant (1000) All tests now pass successfully with proper version compatibility handling. * feat: add a test_robustness execution to test thread pool execution * feat: optimize scheduler configuration and API search functionality - Add DEFAULT_TOP_K and DEFAULT_CONTEXT_WINDOW_SIZE global constants in general_schemas.py - Update base_scheduler.py to use global default values instead of hardcoded numbers - Fix SchedulerConfigFactory initialization issue by using keyword argument expansion - Resolve UnboundLocalError variable conflict in search_memories_ws function - Fix indentation and parameter issues in OptimizedScheduler search_for_api method - Improve code standardization and maintainability * feat: Add Redis auto-initialization with fallback strategies - Add auto_initialize_redis() with config/env/local fallback - Move Redis logic from dispatcher_monitor to redis_service - Update base_scheduler to use auto initialization - Add proper resource cleanup and error handling * feat: add database connection management to ORM module - Add MySQL engine loading from environment variables in BaseDBManager - Add Redis connection loading from environment variables in BaseDBManager - Enhance database configuration validation and error handling - Complete database adapter infrastructure for ORM module - Provide unified database connection management interface This update provides comprehensive database connection management capabilities for the mem_scheduler module, supporting dynamic MySQL and Redis configuration loading from environment variables, establishing reliable data persistence foundation for scheduling services and API services. * remove part of test * feat: add Redis-based ORM with multiprocess synchronization - Add RedisDBManager and RedisLockableORM classes - Implement atomic locking mechanism for concurrent access - Add merge functionality for different object types - Include comprehensive test suite and examples - Fix Redis key type conflicts in lock operations * fix: resolve scheduler module import and Redis integration issues * revise naive memcube creation in server router * remove long-time tests in test_scheduler * remove redis test which needs .env * refactor all codes about mixture search with scheduler * fix: resolve Redis API synchronization issues and implement search API with reranker - Fix running_entries to running_task_ids migration across codebase - Update sync_search_data method to properly handle TaskRunningStatus - Correct variable naming and logic in API synchronization flow - Implement search API endpoint with reranker functionality - Update test files to reflect new running_task_ids convention - Ensure proper Redis state management for concurrent tasks * remove a test for api module * revise to pass the test suite * address some bugs to make mix_search normally running * modify codes according to evaluation logs * feat: Optimize mixture search and enhance API client * feat: Add conversation_turn tracking for session-based memory search - Add conversation_turn field to APIMemoryHistoryEntryItem schema with default value 0 - Implement session counter in OptimizedScheduler to track turn count per session_id - Update sync_search_data method to accept and store conversation_turn parameter - Maintain session history with LRU eviction (max 5 sessions) - Rename conversation_id to session_id for consistency with request object - Enable direct access to session_id from search requests This feature allows tracking conversation turns within the same session, providing better context for memory retrieval and search history management. * adress time bug in monitor * revise simple tree * add mode to evaluation client; rewrite print to logger.info in db files * feat: 1. add redis queue for scheduler 2. finish the code related to mix search and fine search * debug the working memory code * addressed a range of bugs to make scheduler running correctly * remove test_dispatch_parallel test * print change to logger.info * adjucted the core code related to fine and mixture apis * feat: create task queue to wrap local queue and redis queue. queue now split FIFO to multi queue from different users. addressed a range of bugs * fix bugs: debug bugs about internet trigger * debug get searcher mode * feat: add manual internet * Fix: fix code format * feat: add strategy for fine search * debug redis queue * debug redis queue --------- Co-authored-by: chentang --- .../temporal_locomo/models/__init__.py | 0 .../temporal_locomo/models/locomo_eval.py | 531 ---------------- .../models/locomo_ingestion.py | 303 --------- .../temporal_locomo/models/locomo_metric.py | 390 ------------ .../models/locomo_processor.py | 370 ----------- .../models/locomo_processor_w_time_eval.py | 229 ------- .../scripts/temporal_locomo/modules/README.md | 83 --- .../temporal_locomo/modules/__init__.py | 0 .../modules/base_eval_module.py | 386 ------------ .../temporal_locomo/modules/client_manager.py | 191 ------ .../temporal_locomo/modules/constants.py | 19 - .../modules/locomo_eval_module.py | 578 ------------------ .../temporal_locomo/modules/prompts.py | 219 ------- .../temporal_locomo/modules/schemas.py | 161 ----- .../scripts/temporal_locomo/modules/utils.py | 296 --------- .../temporal_locomo/scheduler_time_eval.py | 93 --- .../temporal_locomo/temporal_locomo_eval.py | 155 ----- examples/mem_scheduler/api_w_scheduler.py | 42 +- src/memos/api/handlers/add_handler.py | 6 +- src/memos/api/handlers/base_handler.py | 3 +- src/memos/api/handlers/chat_handler.py | 4 +- src/memos/api/handlers/search_handler.py | 68 ++- src/memos/api/product_models.py | 3 +- src/memos/api/routers/server_router.py | 3 +- src/memos/mem_os/core.py | 24 +- src/memos/mem_os/main.py | 2 +- src/memos/mem_os/product.py | 2 +- .../mem_scheduler/analyzer/eval_analyzer.py | 4 +- .../analyzer/mos_for_test_scheduler.py | 2 +- src/memos/mem_scheduler/base_scheduler.py | 119 ++-- .../mem_scheduler/general_modules/base.py | 2 - .../general_modules/scheduler_logger.py | 10 +- src/memos/mem_scheduler/general_scheduler.py | 26 +- .../memory_manage_modules/retriever.py | 223 ++++--- .../monitors/dispatcher_monitor.py | 2 +- .../mem_scheduler/optimized_scheduler.py | 90 +-- .../mem_scheduler/schemas/general_schemas.py | 30 +- .../task_schedule_modules}/__init__.py | 0 .../dispatcher.py | 48 +- .../task_schedule_modules/local_queue.py | 155 +++++ .../redis_queue.py | 166 +++-- .../task_schedule_modules/task_queue.py | 151 +++++ src/memos/mem_scheduler/utils/misc_utils.py | 37 ++ src/memos/memories/textual/tree.py | 7 +- .../tree_text_memory/retrieve/searcher.py | 9 +- src/memos/templates/mem_scheduler_prompts.py | 111 +++- tests/mem_scheduler/test_dispatcher.py | 8 +- 47 files changed, 948 insertions(+), 4413 deletions(-) delete mode 100644 evaluation/scripts/temporal_locomo/models/__init__.py delete mode 100644 evaluation/scripts/temporal_locomo/models/locomo_eval.py delete mode 100644 evaluation/scripts/temporal_locomo/models/locomo_ingestion.py delete mode 100644 evaluation/scripts/temporal_locomo/models/locomo_metric.py delete mode 100644 evaluation/scripts/temporal_locomo/models/locomo_processor.py delete mode 100644 evaluation/scripts/temporal_locomo/models/locomo_processor_w_time_eval.py delete mode 100644 evaluation/scripts/temporal_locomo/modules/README.md delete mode 100644 evaluation/scripts/temporal_locomo/modules/__init__.py delete mode 100644 evaluation/scripts/temporal_locomo/modules/base_eval_module.py delete mode 100644 evaluation/scripts/temporal_locomo/modules/client_manager.py delete mode 100644 evaluation/scripts/temporal_locomo/modules/constants.py delete mode 100644 evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py delete mode 100644 evaluation/scripts/temporal_locomo/modules/prompts.py delete mode 100644 evaluation/scripts/temporal_locomo/modules/schemas.py delete mode 100644 evaluation/scripts/temporal_locomo/modules/utils.py delete mode 100644 evaluation/scripts/temporal_locomo/scheduler_time_eval.py delete mode 100644 evaluation/scripts/temporal_locomo/temporal_locomo_eval.py rename {evaluation/scripts/temporal_locomo => src/memos/mem_scheduler/task_schedule_modules}/__init__.py (100%) rename src/memos/mem_scheduler/{general_modules => task_schedule_modules}/dispatcher.py (92%) create mode 100644 src/memos/mem_scheduler/task_schedule_modules/local_queue.py rename src/memos/mem_scheduler/{general_modules => task_schedule_modules}/redis_queue.py (74%) create mode 100644 src/memos/mem_scheduler/task_schedule_modules/task_queue.py diff --git a/evaluation/scripts/temporal_locomo/models/__init__.py b/evaluation/scripts/temporal_locomo/models/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/evaluation/scripts/temporal_locomo/models/locomo_eval.py b/evaluation/scripts/temporal_locomo/models/locomo_eval.py deleted file mode 100644 index f98a481e2..000000000 --- a/evaluation/scripts/temporal_locomo/models/locomo_eval.py +++ /dev/null @@ -1,531 +0,0 @@ -import argparse -import asyncio -import json -import os -import time - -import nltk -import numpy as np - -from bert_score import score as bert_score -from dotenv import load_dotenv -from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu -from nltk.translate.meteor_score import meteor_score -from openai import AsyncOpenAI -from pydantic import BaseModel, Field -from rouge_score import rouge_scorer -from scipy.spatial.distance import cosine -from sentence_transformers import SentenceTransformer -from tqdm import tqdm - -from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules -from memos.log import get_logger - - -logger = get_logger(__name__) - - -# Download necessary NLTK resources -try: - nltk.download("wordnet", quiet=True) - nltk.download("punkt", quiet=True) - print("NLTK resources downloaded successfully.") -except Exception as e: - print(f"Warning: Failed to download NLTK resources: {e}") - - -try: - sentence_model_name = "Qwen/Qwen3-Embedding-0.6B" - sentence_model = SentenceTransformer(sentence_model_name) - print(f"SentenceTransformer model : {sentence_model_name} loaded successfully.") -except Exception as e: - print(f"Failed to load SentenceTransformer model: {e}") - sentence_model = None - - -class LLMGrade(BaseModel): - llm_judgment: str = Field(description="CORRECT or WRONG") - llm_reasoning: str = Field(description="Explain why the answer is correct or incorrect.") - - -async def locomo_grader(llm_client, question: str, gold_answer: str, response: str) -> bool: - system_prompt = """ - You are an expert grader that determines if answers to questions match a gold standard answer - """ - - accuracy_prompt = f""" - Your task is to label an answer to a question as ’CORRECT’ or ’WRONG’. You will be given the following data: - (1) a question (posed by one user to another user), - (2) a ’gold’ (ground truth) answer, - (3) a generated answer - which you will score as CORRECT/WRONG. - - The point of the question is to ask about something one user should know about the other user based on their prior conversations. - The gold answer will usually be a concise and short answer that includes the referenced topic, for example: - Question: Do you remember what I got the last time I went to Hawaii? - Gold answer: A shell necklace - The generated answer might be much longer, but you should be generous with your grading - as long as it touches on the same topic as the gold answer, it should be counted as CORRECT. - - For time related questions, the gold answer will be a specific date, month, year, etc. The generated answer might be much longer or use relative time references (like "last Tuesday" or "next month"), but you should be generous with your grading - as long as it refers to the same date or time period as the gold answer, it should be counted as CORRECT. Even if the format differs (e.g., "May 7th" vs "7 May"), consider it CORRECT if it's the same date. - - Now it’s time for the real question: - Question: {question} - Gold answer: {gold_answer} - Generated answer: {response} - - First, provide a short (one sentence) explanation of your reasoning, then finish with CORRECT or WRONG. - Do NOT include both CORRECT and WRONG in your response, or it will break the evaluation script. - - Just return the label CORRECT or WRONG in a json format with the key as "label". - """ - - response = await llm_client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": accuracy_prompt}, - ], - temperature=0, - ) - message_content = response.choices[0].message.content - label = json.loads(message_content)["label"] - parsed = LLMGrade(llm_judgment=label, llm_reasoning="") - - return parsed.llm_judgment.strip().lower() == "correct" - - -def calculate_rouge_scores(gold_answer, response): - metrics = {"rouge1_f": 0.0, "rouge2_f": 0.0, "rougeL_f": 0.0} - try: - scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True) - rouge_scores = scorer.score(gold_answer, response) - metrics["rouge1_f"] = rouge_scores["rouge1"].fmeasure - metrics["rouge2_f"] = rouge_scores["rouge2"].fmeasure - metrics["rougeL_f"] = rouge_scores["rougeL"].fmeasure - except Exception as e: - print(f"Failed to calculate ROUGE scores: {e}") - return metrics - - -def calculate_bleu_scores(gold_tokens, response_tokens): - metrics = {"bleu1": 0.0, "bleu2": 0.0, "bleu3": 0.0, "bleu4": 0.0} - - try: - smoothing = SmoothingFunction().method1 - weights = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (0.33, 0.33, 0.33, 0), (0.25, 0.25, 0.25, 0.25)] - - for i, weight in enumerate(weights, 1): - metrics[f"bleu{i}"] = sentence_bleu( - [gold_tokens], response_tokens, weights=weight, smoothing_function=smoothing - ) - except ZeroDivisionError: - pass - except Exception as e: - print(f"Failed to calculate BLEU scores: {e}") - - return metrics - - -def calculate_meteor_score(gold_tokens, response_tokens): - try: - return meteor_score([gold_tokens], response_tokens) - except Exception as e: - print(f"Failed to calculate METEOR score: {e}") - return 0.0 - - -def calculate_semantic_similarity(gold_answer, response): - global sentence_model - - try: - if sentence_model is None: - sentence_model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B") - - gold_embedding = sentence_model.encode([gold_answer], show_progress_bar=False)[0] - response_embedding = sentence_model.encode([response], show_progress_bar=False)[0] - return 1 - cosine(gold_embedding, response_embedding) - except Exception as e: - print(f"Failed to calculate semantic similarity: {e}") - return 0.0 - - -def calculate_f1_score(gold_tokens, response_tokens): - try: - gold_set = set(gold_tokens) - response_set = set(response_tokens) - - if len(gold_set) == 0 or len(response_set) == 0: - return 0.0 - - precision = len(gold_set.intersection(response_set)) / len(response_set) - recall = len(gold_set.intersection(response_set)) / len(gold_set) - - if precision + recall > 0: - return 2 * precision * recall / (precision + recall) - return 0.0 - except Exception as e: - print(f"Failed to calculate F1 score: {e}") - return 0.0 - - -def calculate_nlp_metrics(gold_answer, response, context, options=None): - if options is None: - options = ["lexical", "semantic"] - - gold_answer = str(gold_answer) if gold_answer is not None else "" - response = str(response) if response is not None else "" - - metrics = {"context_tokens": len(nltk.word_tokenize(context)) if context else 0} - - if "lexical" in options: - gold_tokens = nltk.word_tokenize(gold_answer.lower()) - response_tokens = nltk.word_tokenize(response.lower()) - - metrics["lexical"] = {} - metrics["lexical"]["f1"] = calculate_f1_score(gold_tokens, response_tokens) - metrics["lexical"].update(calculate_rouge_scores(gold_answer, response)) - metrics["lexical"].update(calculate_bleu_scores(gold_tokens, response_tokens)) - metrics["lexical"]["meteor"] = calculate_meteor_score(gold_tokens, response_tokens) - - if "semantic" in options: - metrics["semantic"] = {} - metrics["semantic"]["similarity"] = calculate_semantic_similarity(gold_answer, response) - _, _, f1 = bert_score( - [gold_answer], [response], lang="en", rescale_with_baseline=True, verbose=False - ) - metrics["semantic"]["bert_f1"] = f1.item() if f1 is not None else 0.0 - - return metrics - - -def convert_numpy_types(obj): - if isinstance(obj, np.number): - return float(obj) - elif isinstance(obj, dict): - return {k: convert_numpy_types(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [convert_numpy_types(i) for i in obj] - else: - return obj - - -async def process_group_responses( - group_id, group_responses, oai_client, evaluation_options, num_runs: int -): - graded_responses = [] - - # Process responses with asyncio for concurrent API calls - for response in tqdm(group_responses, desc=f"Processing group {group_id}"): - question = response.get("question") - answer = response.get("answer") - ground_truth = response.get("golden_answer") - category = response.get("category") - - context = response.get("search_context", "") - response_duration_ms = response.get("response_duration_ms", 0.0) - search_duration_ms = response.get("search_duration_ms", 0.0) - - if ground_truth is None: - continue - - grading_tasks = [ - locomo_grader(oai_client, question, ground_truth, answer) for _ in range(num_runs) - ] - judgments = await asyncio.gather(*grading_tasks) - judgments_dict = {f"judgment_{i + 1}": j for i, j in enumerate(judgments)} - - nlp_metrics = calculate_nlp_metrics(ground_truth, answer, context, evaluation_options) - - graded_response = { - "question": question, - "answer": answer, - "golden_answer": ground_truth, - "category": category, - "llm_judgments": judgments_dict, - "nlp_metrics": nlp_metrics, - "response_duration_ms": response_duration_ms, - "search_duration_ms": search_duration_ms, - "total_duration_ms": response_duration_ms + search_duration_ms, - } - graded_responses.append(graded_response) - - return group_id, graded_responses - - -async def process_single_group(group_id, group_responses, oai_client, evaluation_options, num_runs): - try: - start_time = time.time() - result = await process_group_responses( - group_id, group_responses, oai_client, evaluation_options, num_runs - ) - end_time = time.time() - elapsed_time = round(end_time - start_time, 2) - print(f"Group {group_id} processed in {elapsed_time} seconds") - return result - except Exception as e: - logger.error(f"Error processing group {group_id}: {e}", exc_info=True) - return group_id, [] - - -class LocomoEvaluator(LocomoEvalModelModules): - def __init__(self, args): - # Initialize base class to populate self.frame, self.version, etc. - super().__init__(args=args) - - self.evaluation_options = getattr(args, "evaluation_options", ["lexical", "semantic"]) - self.num_runs = getattr(args, "num_runs", 1) - self.max_workers = getattr(args, "workers", 4) - - load_dotenv() - self.oai_client = AsyncOpenAI( - api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL") - ) - - def _load_response_data(self): - """ - Load response data from the response path file. - - Returns: - dict: The loaded response data - """ - with open(self.response_path) as file: - return json.load(file) - - def _load_existing_evaluation_results(self): - """ - Attempt to load existing evaluation results from the judged path. - If the file doesn't exist or there's an error loading it, return an empty dict. - - Returns: - dict: Existing evaluation results or empty dict if none available - """ - all_grades = {} - try: - if os.path.exists(self.judged_path): - with open(self.judged_path) as f: - all_grades = json.load(f) - print(f"Loaded existing evaluation results from {self.judged_path}") - except Exception as e: - print(f"Error loading existing evaluation results: {e}") - - return all_grades - - def _create_evaluation_tasks(self, locomo_responses, all_grades, num_users): - """ - Create evaluation tasks for groups that haven't been evaluated yet. - - Args: - locomo_responses (dict): The loaded response data - all_grades (dict): Existing evaluation results - num_users (int): Number of user groups to process - - Returns: - tuple: (tasks list, active users count) - """ - tasks = [] - active_users = 0 - - for group_idx in range(num_users): - group_id = f"locomo_exp_user_{group_idx}" - group_responses = locomo_responses.get(group_id, []) - - if not group_responses: - print(f"No responses found for group {group_id}") - continue - - # Skip groups that already have evaluation results - if all_grades.get(group_id): - print(f"Skipping group {group_id} as it already has evaluation results") - active_users += 1 - continue - - active_users += 1 - tasks.append( - process_single_group( - group_id=group_id, - group_responses=group_responses, - oai_client=self.oai_client, - evaluation_options=self.evaluation_options, - num_runs=self.num_runs, - ) - ) - - return tasks, active_users - - async def _process_tasks(self, tasks): - """ - Process evaluation tasks with concurrency control. - - Args: - tasks (list): List of tasks to process - - Returns: - list: Results from processing all tasks - """ - if not tasks: - return [] - - semaphore = asyncio.Semaphore(self.max_workers) - - async def limited_task(task): - """Helper function to limit concurrent task execution""" - async with semaphore: - return await task - - limited_tasks = [limited_task(task) for task in tasks] - return await asyncio.gather(*limited_tasks) - - def _calculate_scores(self, all_grades): - """ - Calculate evaluation scores based on all grades. - - Args: - all_grades (dict): The complete evaluation results - - Returns: - tuple: (run_scores, evaluated_count) - """ - run_scores = [] - evaluated_count = 0 - - if self.num_runs > 0: - for i in range(1, self.num_runs + 1): - judgment_key = f"judgment_{i}" - current_run_correct_count = 0 - current_run_total_count = 0 - - for group in all_grades.values(): - for response in group: - if judgment_key in response["llm_judgments"]: - if response["llm_judgments"][judgment_key]: - current_run_correct_count += 1 - current_run_total_count += 1 - - if current_run_total_count > 0: - run_accuracy = current_run_correct_count / current_run_total_count - run_scores.append(run_accuracy) - - evaluated_count = current_run_total_count - - return run_scores, evaluated_count - - def _report_scores(self, run_scores, evaluated_count): - """ - Report evaluation scores to the console. - - Args: - run_scores (list): List of accuracy scores for each run - evaluated_count (int): Number of evaluated responses - """ - if evaluated_count > 0: - mean_of_scores = np.mean(run_scores) - std_of_scores = np.std(run_scores) - print(f"LLM-as-a-Judge Mean Score: {mean_of_scores:.4f}") - print(f"LLM-as-a-Judge Standard Deviation: {std_of_scores:.4f}") - print( - f"(Calculated from {self.num_runs} separate runs over {evaluated_count} questions)" - ) - print(f"Individual run scores: {[round(s, 4) for s in run_scores]}") - else: - print("No responses were evaluated") - print("LLM-as-a-Judge score: N/A (0/0)") - - def _save_results(self, all_grades): - """ - Save evaluation results to the judged path file. - - Args: - all_grades (dict): The complete evaluation results to save - """ - all_grades = convert_numpy_types(all_grades) - with open(self.judged_path, "w") as f: - json.dump(all_grades, f, indent=2) - print(f"Saved detailed evaluation results to {self.judged_path}") - - async def run(self): - """ - Main execution method for the LoCoMo evaluation process. - This method orchestrates the entire evaluation workflow: - 1. Loads existing evaluation results if available - 2. Processes only groups that haven't been evaluated yet - 3. Calculates and reports final evaluation scores - """ - print( - f"\n=== Starting LoCoMo evaluation for {self.frame} (version: {self.version}) with {self.num_runs} run(s) per question ===" - ) - print(f"Using {self.max_workers} concurrent workers for processing groups") - - # Load response data and existing evaluation results - locomo_responses = self._load_response_data() - all_grades = self._load_existing_evaluation_results() - - # Count total responses for reporting - num_users = 10 - total_responses_count = sum( - len(locomo_responses.get(f"locomo_exp_user_{i}", [])) for i in range(num_users) - ) - print(f"Found {total_responses_count} total responses across {num_users} users to evaluate") - - # Create tasks only for groups that haven't been evaluated yet - tasks, active_users = self._create_evaluation_tasks(locomo_responses, all_grades, num_users) - print( - f"Starting evaluation of {len(tasks)} user groups with responses (out of {active_users} active users)" - ) - - # Process tasks and update all_grades with results - if tasks: - group_results = await self._process_tasks(tasks) - for group_id, graded_responses in group_results: - all_grades[group_id] = graded_responses - - print("\n=== Evaluation Complete: Calculating final scores ===") - - # Calculate and report scores - run_scores, evaluated_count = self._calculate_scores(all_grades) - self._report_scores(run_scores, evaluated_count) - - # Save results - self._save_results(all_grades) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - default="memos_scheduler", - choices=["zep", "memos", "memos_scheduler", "mem0", "mem0_graph", "langmem", "openai"], - help="Specify the memory framework (zep or memos or mem0 or mem0_graph)", - ) - parser.add_argument( - "--version", - type=str, - default="v1.0.1", - help="Version identifier for loading results (e.g., 1010)", - ) - parser.add_argument( - "--num_runs", - type=int, - default=3, - help="Number of times to run the LLM grader for each question", - ) - parser.add_argument("--evaluation_options", nargs="+", default=["lexical", "semantic"]) - parser.add_argument( - "--workers", type=int, default=10, help="Number of concurrent workers for processing groups" - ) - cli_args = parser.parse_args() - - # Build args for evaluator - class Args: - def __init__(self, cli_args): - self.frame = cli_args.lib - self.version = cli_args.version - self.workers = cli_args.workers - self.num_runs = cli_args.num_runs - self.evaluation_options = cli_args.evaluation_options - self.top_k = 20 - self.scheduler_flag = True - - args = Args(cli_args) - evaluator = LocomoEvaluator(args=args) - asyncio.run(evaluator.run()) diff --git a/evaluation/scripts/temporal_locomo/models/locomo_ingestion.py b/evaluation/scripts/temporal_locomo/models/locomo_ingestion.py deleted file mode 100644 index b45ec3d61..000000000 --- a/evaluation/scripts/temporal_locomo/models/locomo_ingestion.py +++ /dev/null @@ -1,303 +0,0 @@ -import concurrent.futures -import sys -import time -import traceback - -from datetime import datetime, timezone -from pathlib import Path - -from tqdm import tqdm - -from evaluation.scripts.temporal_locomo.modules.constants import ( - MEM0_GRAPH_MODEL, - MEM0_MODEL, - MEMOS_MODEL, - MEMOS_SCHEDULER_MODEL, - ZEP_MODEL, -) -from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules -from memos.log import get_logger - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -class LocomoIngestor(LocomoEvalModelModules): - def __init__(self, args): - super().__init__(args=args) - - def ingest_session(self, client, session, frame, metadata, revised_client=None): - session_date = metadata["session_date"] - date_format = "%I:%M %p on %d %B, %Y UTC" - date_string = datetime.strptime(session_date, date_format).replace(tzinfo=timezone.utc) - iso_date = date_string.isoformat() - conv_id = metadata["conv_id"] - conv_id = "locomo_exp_user_" + str(conv_id) - dt = datetime.fromisoformat(iso_date) - timestamp = int(dt.timestamp()) - print(f"Processing conv {conv_id}, session {metadata['session_key']}") - start_time = time.time() - print_once = True # Print example only once per session - - if frame == ZEP_MODEL: - for chat in tqdm(session, desc=f"{metadata['session_key']}"): - data = chat.get("speaker") + ": " + chat.get("text") - - # Print example only once per session - if print_once: - print({"context": data, "conv_id": conv_id, "created_at": iso_date}) - print_once = False - - # Check if the group exists, if not create it - groups = client.group.get_all_groups() - groups = dict(groups)["groups"] - exist_ids = [gp.group_id for gp in groups] - if conv_id not in exist_ids: - client.group.add(group_id=conv_id) - - # Add the message to the group - client.graph.add( - data=data, - type="message", - created_at=iso_date, - group_id=conv_id, - ) - - elif frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: - messages = [] - messages_reverse = [] - - for chat in tqdm(session, desc=f"{metadata['session_key']}"): - data = chat.get("speaker") + ": " + chat.get("text") - - if chat.get("speaker") == metadata["speaker_a"]: - messages.append({"role": "user", "content": data, "chat_time": iso_date}) - messages_reverse.append( - {"role": "assistant", "content": data, "chat_time": iso_date} - ) - elif chat.get("speaker") == metadata["speaker_b"]: - messages.append({"role": "assistant", "content": data, "chat_time": iso_date}) - messages_reverse.append( - {"role": "user", "content": data, "chat_time": iso_date} - ) - else: - raise ValueError( - f"Unknown speaker {chat.get('speaker')} in session {metadata['session_key']}" - ) - - # Print example only once per session - if print_once: - print({"context": data, "conv_id": conv_id, "created_at": iso_date}) - print_once = False - - speaker_a_user_id = conv_id + "_speaker_a" - speaker_b_user_id = conv_id + "_speaker_b" - - client.add( - messages=messages, - user_id=speaker_a_user_id, - ) - - revised_client.add( - messages=messages_reverse, - user_id=speaker_b_user_id, - ) - print(f"Added messages for {speaker_a_user_id} and {speaker_b_user_id} successfully.") - - elif frame in [MEM0_MODEL, MEM0_GRAPH_MODEL]: - print(f"Processing abc for {metadata['session_key']}") - messages = [] - messages_reverse = [] - - for chat in tqdm(session, desc=f"{metadata['session_key']}"): - data = chat.get("speaker") + ": " + chat.get("text") - - if chat.get("speaker") == metadata["speaker_a"]: - messages.append({"role": "user", "content": data}) - messages_reverse.append({"role": "assistant", "content": data}) - elif chat.get("speaker") == metadata["speaker_b"]: - messages.append({"role": "assistant", "content": data}) - messages_reverse.append({"role": "user", "content": data}) - else: - raise ValueError( - f"Unknown speaker {chat.get('speaker')} in session {metadata['session_key']}" - ) - - # Print example only once per session - if print_once: - print({"context": data, "conv_id": conv_id, "created_at": iso_date}) - print_once = False - - for i in range(0, len(messages), 2): - batch_messages = messages[i : i + 2] - batch_messages_reverse = messages_reverse[i : i + 2] - - if frame == "mem0": - client.add( - messages=batch_messages, - timestamp=timestamp, - user_id=metadata["speaker_a_user_id"], - version="v2", - ) - client.add( - messages=batch_messages_reverse, - timestamp=timestamp, - user_id=metadata["speaker_b_user_id"], - version="v2", - ) - - elif frame == "mem0_graph": - client.add( - messages=batch_messages, - timestamp=timestamp, - user_id=metadata["speaker_a_user_id"], - output_format="v1.1", - version="v2", - enable_graph=True, - ) - client.add( - messages=batch_messages_reverse, - timestamp=timestamp, - user_id=metadata["speaker_b_user_id"], - output_format="v1.1", - version="v2", - enable_graph=True, - ) - - end_time = time.time() - elapsed_time = round(end_time - start_time, 2) - - return elapsed_time - - def process_user_for_ingestion(self, conv_id, frame, locomo_df, version, num_workers=1): - try: - # Check if locomo_df is empty or doesn't have the required columns - if locomo_df.empty or "conversation" not in locomo_df.columns: - logger.warning( - f"Skipping user {conv_id}: locomo_df is empty or missing 'conversation' column" - ) - return 0 - - conversation = locomo_df["conversation"].iloc[conv_id] - max_session_count = 35 - start_time = time.time() - total_session_time = 0 - valid_sessions = 0 - - revised_client = None - if frame == "zep": - client = self.get_client_for_ingestion(frame=frame, user_id=None, version="default") - elif frame == "mem0" or frame == "mem0_graph": - client = self.get_client_for_ingestion(frame=frame, user_id=None, version="default") - client.delete_all(user_id=f"locomo_exp_user_{conv_id}") - client.delete_all(user_id=f"{conversation.get('speaker_a')}_{conv_id}") - client.delete_all(user_id=f"{conversation.get('speaker_b')}_{conv_id}") - elif frame in ["memos", "memos_scheduler"]: - conv_id = "locomo_exp_user_" + str(conv_id) - speaker_a_user_id = conv_id + "_speaker_a" - speaker_b_user_id = conv_id + "_speaker_b" - - client = self.get_client_for_ingestion( - frame=frame, user_id=speaker_a_user_id, version=version - ) - revised_client = self.get_client_for_ingestion( - frame=frame, user_id=speaker_b_user_id, version=version - ) - else: - raise NotImplementedError() - - sessions_to_process = [] - for session_idx in tqdm(range(max_session_count), desc=f"process_user {conv_id}"): - session_key = f"session_{session_idx}" - session = conversation.get(session_key) - if session is None: - continue - - metadata = { - "session_date": conversation.get(f"session_{session_idx}_date_time") + " UTC", - "speaker_a": conversation.get("speaker_a"), - "speaker_b": conversation.get("speaker_b"), - "speaker_a_user_id": f"{conversation.get('speaker_a')}_{conv_id}", - "speaker_b_user_id": f"{conversation.get('speaker_b')}_{conv_id}", - "conv_id": conv_id, - "session_key": session_key, - } - sessions_to_process.append((session, metadata)) - valid_sessions += 1 - - print( - f"Processing {valid_sessions} sessions for user {conv_id} with {num_workers} workers" - ) - with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = { - executor.submit( - self.ingest_session, client, session, frame, metadata, revised_client - ): metadata["session_key"] - for session, metadata in sessions_to_process - } - - for future in concurrent.futures.as_completed(futures): - session_key = futures[future] - try: - session_time = future.result() - total_session_time += session_time - print(f"User {conv_id}, {session_key} processed in {session_time} seconds") - except Exception as e: - print(f"Error processing user {conv_id}, session {session_key}: {e!s}") - - end_time = time.time() - elapsed_time = round(end_time - start_time, 2) - print(f"User {conv_id} processed successfully in {elapsed_time} seconds") - - return elapsed_time - - except Exception as e: - return f"Error processing user {conv_id}: {e!s}. Exception: {traceback.format_exc()}" - - def run_ingestion(self): - frame = self.frame - version = self.version - num_workers = self.workers - - num_users = 10 - start_time = time.time() - total_time = 0 - - print( - f"Starting processing for {num_users} users in serial mode," - f" each user using {num_workers} workers for sessions..." - ) - - for user_id in range(num_users): - try: - result = self.process_user_for_ingestion( - user_id, frame, self.locomo_df, version, num_workers - ) - if isinstance(result, float): - total_time += result - else: - print(result) - except Exception as e: - print( - f"Error processing user {user_id}: {e!s}. Traceback: {traceback.format_exc()}" - ) - - if num_users > 0: - average_time = total_time / num_users - minutes = int(average_time // 60) - seconds = int(average_time % 60) - average_time_formatted = f"{minutes} minutes and {seconds} seconds" - print( - f"The frame {frame} processed {num_users} users in average of {average_time_formatted} per user." - ) - - end_time = time.time() - elapsed_time = round(end_time - start_time, 2) - minutes = int(elapsed_time // 60) - seconds = int(elapsed_time % 60) - elapsed_time = f"{minutes} minutes and {seconds} seconds" - print(f"Total processing time: {elapsed_time}.") diff --git a/evaluation/scripts/temporal_locomo/models/locomo_metric.py b/evaluation/scripts/temporal_locomo/models/locomo_metric.py deleted file mode 100644 index 532fe2e14..000000000 --- a/evaluation/scripts/temporal_locomo/models/locomo_metric.py +++ /dev/null @@ -1,390 +0,0 @@ -import argparse -import json - -import numpy as np -import pandas as pd - -from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules - - -# Category mapping as per your request -category_mapping = { - "4": "single hop", - "1": "multi hop", - "2": "temporal reasoning", - "3": "open domain", -} - - -def calculate_scores(data): - category_scores = {} - category_question_count = {} - - overall_metrics = { - "lexical": { - m: [] - for m in [ - "f1", - "rouge1_f", - "rouge2_f", - "rougeL_f", - "bleu1", - "bleu2", - "bleu3", - "bleu4", - "meteor", - ] - }, - "semantic": {m: [] for m in ["bert_f1", "similarity"]}, - "context_tokens": [], - "duration": { - m: [] for m in ["response_duration_ms", "search_duration_ms", "total_duration_ms"] - }, - } - - category_metrics = {} - user_metrics = {} - - total_questions = 0 - - all_judgment_keys = set() - judgment_run_scores = {} - - for _user, questions in data.items(): - for question in questions: - if "llm_judgments" in question: - all_judgment_keys.update(question["llm_judgments"].keys()) - - for key in all_judgment_keys: - judgment_run_scores[key] = [] - - for user, questions in data.items(): - user_total = 0 - - # Initialize user_metrics with each judgment run - user_metrics[user] = { - "total": 0, - "llm_judge_score": 0, - "llm_judge_std": 0, - "judgment_run_scores": {key: [] for key in all_judgment_keys}, - "lexical": {m: [] for m in overall_metrics["lexical"]}, - "semantic": {m: [] for m in overall_metrics["semantic"]}, - "context_tokens": [], - "duration": {m: [] for m in overall_metrics["duration"]}, - } - - for question in questions: - total_questions += 1 - user_total += 1 - - if "llm_judgments" in question: - for judgment_key, judgment_value in question["llm_judgments"].items(): - score = 1 if judgment_value else 0 - judgment_run_scores[judgment_key].append(score) - user_metrics[user]["judgment_run_scores"][judgment_key].append(score) - - category = question["category"] - if category not in category_scores: - category_scores[category] = { - "total": 0, - "category_name": category_mapping.get(str(category), "Unknown"), - "judgment_run_scores": {key: [] for key in all_judgment_keys}, - } - category_metrics[category] = { - "lexical": {m: [] for m in overall_metrics["lexical"]}, - "semantic": {m: [] for m in overall_metrics["semantic"]}, - "context_tokens": [], - "duration": {m: [] for m in overall_metrics["duration"]}, - } - category_question_count[category] = 0 - - category_scores[category]["total"] += 1 - category_question_count[category] += 1 - - if "llm_judgments" in question: - for judgment_key, judgment_value in question["llm_judgments"].items(): - score = 1 if judgment_value else 0 - category_scores[category]["judgment_run_scores"][judgment_key].append(score) - - nlp = question.get("nlp_metrics", {}) - for metric in overall_metrics["lexical"]: - v = nlp.get("lexical", {}).get(metric) - if v is not None: - overall_metrics["lexical"][metric].append(v) - category_metrics[category]["lexical"][metric].append(v) - user_metrics[user]["lexical"][metric].append(v) - - for metric in overall_metrics["semantic"]: - v = nlp.get("semantic", {}).get(metric) - if v is not None: - overall_metrics["semantic"][metric].append(v) - category_metrics[category]["semantic"][metric].append(v) - user_metrics[user]["semantic"][metric].append(v) - - ct = nlp.get("context_tokens") - if ct is not None: - overall_metrics["context_tokens"].append(ct) - category_metrics[category]["context_tokens"].append(ct) - user_metrics[user]["context_tokens"].append(ct) - - for metric in overall_metrics["duration"]: - v = question.get(metric) - if v is not None: - overall_metrics["duration"][metric].append(v) - category_metrics[category]["duration"][metric].append(v) - user_metrics[user]["duration"][metric].append(v) - - user_metrics[user]["total"] = user_total - - judgment_avgs = [] - for _judgment_key, scores in user_metrics[user]["judgment_run_scores"].items(): - if scores: - avg = np.mean(scores) - judgment_avgs.append(avg) - - user_metrics[user]["llm_judge_score"] = np.mean(judgment_avgs) if judgment_avgs else 0.0 - user_metrics[user]["llm_judge_std"] = ( - np.std(judgment_avgs) if len(judgment_avgs) > 1 else 0.0 - ) - - for group in ["lexical", "semantic"]: - for metric in user_metrics[user][group]: - values = user_metrics[user][group][metric] - user_metrics[user][group][metric] = np.mean(values) if values else 0.0 - - user_metrics[user]["context_tokens"] = ( - np.mean(user_metrics[user]["context_tokens"]) - if user_metrics[user]["context_tokens"] - else 0.0 - ) - - duration_metrics = list(user_metrics[user]["duration"].keys()) - for metric in duration_metrics: - values = user_metrics[user]["duration"][metric] - if values: - user_metrics[user]["duration"][metric] = np.mean(values) - user_metrics[user]["duration"][f"{metric}_p50"] = np.percentile(values, 50) - user_metrics[user]["duration"][f"{metric}_p95"] = np.percentile(values, 95) - else: - user_metrics[user]["duration"][metric] = 0.0 - user_metrics[user]["duration"][f"{metric}_p50"] = 0.0 - user_metrics[user]["duration"][f"{metric}_p95"] = 0.0 - - judgment_run_averages = [] - for _judgment_key, scores in judgment_run_scores.items(): - if scores: - judgment_run_averages.append(np.mean(scores)) - - llm_judge_score = np.mean(judgment_run_averages) if judgment_run_averages else 0.0 - llm_judge_std = np.std(judgment_run_averages) if len(judgment_run_averages) > 1 else 0.0 - - category_overall_scores = {} - for category, score_data in category_scores.items(): - category_judgment_avgs = [] - for _judgment_key, scores in score_data["judgment_run_scores"].items(): - if scores: - category_judgment_avgs.append(np.mean(scores)) - - category_overall_scores[category] = { - "category_name": score_data["category_name"], - "llm_judge_score": np.mean(category_judgment_avgs) if category_judgment_avgs else 0.0, - "llm_judge_std": np.std(category_judgment_avgs) - if len(category_judgment_avgs) > 1 - else 0.0, - "total": score_data["total"], - "lexical": {}, - "semantic": {}, - "duration": {}, - "context_tokens": 0.0, - } - - for group in ["lexical", "semantic"]: - for metric in category_metrics[category][group]: - values = category_metrics[category][group][metric] - category_overall_scores[category][group][metric] = ( - np.mean(values) if values else 0.0 - ) - - category_overall_scores[category]["context_tokens"] = ( - np.mean(category_metrics[category]["context_tokens"]) - if category_metrics[category]["context_tokens"] - else 0.0 - ) - - # Calculate mean and percentiles for category duration metrics - duration_metrics = list( - category_metrics[category]["duration"].keys() - ) # Create a list of keys first - for metric in duration_metrics: - values = category_metrics[category]["duration"][metric] - if values: - category_overall_scores[category]["duration"][metric] = np.mean(values) - # Add P50 (median) and P95 percentiles - category_overall_scores[category]["duration"][f"{metric}_p50"] = np.percentile( - values, 50 - ) - category_overall_scores[category]["duration"][f"{metric}_p95"] = np.percentile( - values, 95 - ) - else: - category_overall_scores[category]["duration"][metric] = 0.0 - category_overall_scores[category]["duration"][f"{metric}_p50"] = 0.0 - category_overall_scores[category]["duration"][f"{metric}_p95"] = 0.0 - - # calculate overall scores - overall_metric_averages = { - "llm_judge_score": llm_judge_score, - "llm_judge_std": llm_judge_std, - "lexical": {}, - "semantic": {}, - "context_tokens": 0.0, - "duration": {}, - } - - for group in ["lexical", "semantic"]: - for metric in overall_metrics[group]: - values = overall_metrics[group][metric] - overall_metric_averages[group][metric] = np.mean(values) if values else 0.0 - - overall_metric_averages["context_tokens"] = ( - np.mean(overall_metrics["context_tokens"]) if overall_metrics["context_tokens"] else 0.0 - ) - - duration_metrics = list(overall_metrics["duration"].keys()) - for metric in duration_metrics: - values = overall_metrics["duration"][metric] - if values: - overall_metric_averages["duration"][metric] = np.mean(values) - overall_metric_averages["duration"][f"{metric}_p50"] = np.percentile(values, 50) - overall_metric_averages["duration"][f"{metric}_p95"] = np.percentile(values, 95) - else: - overall_metric_averages["duration"][metric] = 0.0 - overall_metric_averages["duration"][f"{metric}_p50"] = 0.0 - overall_metric_averages["duration"][f"{metric}_p95"] = 0.0 - - return { - "metrics": overall_metric_averages, - "category_scores": category_overall_scores, - "user_scores": user_metrics, - } - - -def save_to_excel(results, output_path): - # Create a combined data structure for metrics and category scores - combined_data = [] - - # Process overall metrics - flatten nested structures - overall_row = {"category": "overall"} - overall_row["llm_judge_score"] = results["metrics"]["llm_judge_score"] - overall_row["llm_judge_std"] = results["metrics"]["llm_judge_std"] - - # Add all lexical metrics - for metric, value in results["metrics"]["lexical"].items(): - overall_row[metric] = value - - # Add all semantic metrics - for metric, value in results["metrics"]["semantic"].items(): - overall_row[metric] = value - - # Add context tokens - overall_row["context_tokens"] = results["metrics"]["context_tokens"] - - # Add all duration metrics, including percentiles - for metric, value in results["metrics"]["duration"].items(): - overall_row[metric] = value - - combined_data.append(overall_row) - - # Process category scores - flatten nested structures - for _, scores in results["category_scores"].items(): - category_row = {"category": scores["category_name"]} - category_row["llm_judge_score"] = scores["llm_judge_score"] - category_row["llm_judge_std"] = scores["llm_judge_std"] - - # Add all lexical metrics - for metric, value in scores["lexical"].items(): - category_row[metric] = value - - # Add all semantic metrics - for metric, value in scores["semantic"].items(): - category_row[metric] = value - - # Add context tokens - category_row["context_tokens"] = scores["context_tokens"] - - # Add all duration metrics, including percentiles - for metric, value in scores["duration"].items(): - category_row[metric] = value - - combined_data.append(category_row) - - # Create DataFrame and save to Excel - combined_df = pd.DataFrame(combined_data) - - # Create a pandas Excel writer - with pd.ExcelWriter(output_path) as writer: - combined_df.to_excel(writer, sheet_name="Metrics", index=False) - - print(f"Excel file saved to: {output_path}") - - -class LocomoMetric(LocomoEvalModelModules): - def __init__(self, args): - super().__init__(args=args) - - def run(self): - with open(self.judged_path) as file: - data = json.load(file) - - results = calculate_scores(data) - - with open(self.grade_path, "w") as outfile: - json.dump(results, outfile, indent=4) - - save_to_excel(results, self.excel_path) - - print("\n=== Metric Calculation Complete ===") - total = sum(results["category_scores"][cat]["total"] for cat in results["category_scores"]) - print( - f"LLM-as-a-Judge score: {results['metrics']['llm_judge_score']:.4f} ± {results['metrics']['llm_judge_std']:.4f}" - ) - print(f"Total questions evaluated: {total}") - - print("\n=== Duration Metrics ===") - for metric in ["response_duration_ms", "search_duration_ms", "total_duration_ms"]: - print(f"{metric} (avg): {results['metrics']['duration'][metric]:.2f} ms") - print(f"{metric} (P50): {results['metrics']['duration'][f'{metric}_p50']:.2f} ms") - print(f"{metric} (P95): {results['metrics']['duration'][f'{metric}_p95']:.2f} ms") - - print(f"\nResults have been written to {self.grade_path}") - print(f"Excel report has been saved to {self.excel_path}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - default="memos_scheduler", - choices=["zep", "memos", "memos_scheduler", "mem0", "mem0_graph", "langmem", "openai"], - help="Specify the memory framework (zep or memos or mem0 or mem0_graph)", - ) - parser.add_argument( - "--version", - type=str, - default="v1.0.1", - help="Version identifier for loading results (e.g., 1010)", - ) - cli_args = parser.parse_args() - - # Build a minimal args namespace compatible with LocomoEvalModelModules - class _Args: - def __init__(self, frame, version): - self.frame = frame - self.version = version - self.workers = 1 - self.top_k = 20 - self.scheduler_flag = True - - args = _Args(frame=cli_args.lib, version=cli_args.version) - LocomoMetric(args=args).run() diff --git a/evaluation/scripts/temporal_locomo/models/locomo_processor.py b/evaluation/scripts/temporal_locomo/models/locomo_processor.py deleted file mode 100644 index 7cec6f5af..000000000 --- a/evaluation/scripts/temporal_locomo/models/locomo_processor.py +++ /dev/null @@ -1,370 +0,0 @@ -import json -import sys - -from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor, as_completed -from pathlib import Path -from time import time - -from dotenv import load_dotenv - -from evaluation.scripts.temporal_locomo.modules.constants import ( - MEMOS_SCHEDULER_MODEL, -) -from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules -from evaluation.scripts.temporal_locomo.modules.prompts import ( - SEARCH_PROMPT_MEM0, - SEARCH_PROMPT_MEM0_GRAPH, - SEARCH_PROMPT_MEMOS, - SEARCH_PROMPT_ZEP, -) -from evaluation.scripts.temporal_locomo.modules.schemas import ContextUpdateMethod, RecordingCase -from evaluation.scripts.temporal_locomo.modules.utils import save_evaluation_cases -from memos.log import get_logger - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -class LocomoProcessor(LocomoEvalModelModules): - """ - A class for handling conversational memory management across different memory frameworks. - Supports multiple memory backends (zep, mem0, memos, etc.) for searching and retrieving - relevant context to generate conversational responses. - """ - - def __init__(self, args): - """Initialize the LocomoChatter with path configurations and templates""" - super().__init__(args=args) - - # Template definitions for different memory frameworks - self.search_template_zep = SEARCH_PROMPT_ZEP - - self.search_template_mem0 = SEARCH_PROMPT_MEM0 - - self.search_template_mem0_graph = SEARCH_PROMPT_MEM0_GRAPH - - self.search_template_memos = SEARCH_PROMPT_MEMOS - - self.processed_data_dir = self.result_dir / "processed_data" - - def update_context(self, conv_id, method, **kwargs): - if method == ContextUpdateMethod.CHAT_HISTORY: - if "query" not in kwargs or "answer" not in kwargs: - raise ValueError("query and answer are required for TEMPLATE update method") - new_context = f"User: {kwargs['query']}\nAssistant: {kwargs['answer']}\n\n" - if self.pre_context_cache[conv_id] is None: - self.pre_context_cache[conv_id] = "" - self.pre_context_cache[conv_id] += new_context - else: - if "cur_context" not in kwargs: - raise ValueError("cur_context is required for DIRECT update method") - cur_context = kwargs["cur_context"] - self.pre_context_cache[conv_id] = cur_context - - def eval_context(self, context, query, gold_answer, oai_client): - can_answer_start = time() - can_answer = self.analyze_context_answerability(context, query, gold_answer, oai_client) - can_answer_duration_ms = (time() - can_answer_start) * 1000 - # Update global stats - with self.stats_lock: - self.stats[self.frame][self.version]["memory_stats"]["total_queries"] += 1 - if can_answer: - self.stats[self.frame][self.version]["memory_stats"]["can_answer_count"] += 1 - else: - self.stats[self.frame][self.version]["memory_stats"]["cannot_answer_count"] += 1 - total_queries = self.stats[self.frame][self.version]["memory_stats"]["total_queries"] - can_answer_count = self.stats[self.frame][self.version]["memory_stats"][ - "can_answer_count" - ] - hit_rate = (can_answer_count / total_queries * 100) if total_queries > 0 else 0 - self.stats[self.frame][self.version]["memory_stats"]["answer_hit_rate"] = hit_rate - self.stats[self.frame][self.version]["memory_stats"]["can_answer_duration_ms"] = ( - can_answer_duration_ms - ) - self.save_stats() - return can_answer, can_answer_duration_ms - - def _update_stats_and_context( - self, - *, - conv_id, - frame, - version, - conv_stats, - conv_stats_path, - query, - answer, - gold_answer, - cur_context, - can_answer, - ): - """ - Update conversation statistics and context. - - Args: - conv_id: Conversation ID - frame: Model frame - version: Model version - conv_stats: Conversation statistics dictionary - conv_stats_path: Path to save conversation statistics - query: User query - answer: Generated answer - gold_answer: Golden answer - cur_context: Current context - can_answer: Whether the context can answer the query - """ - # Update conversation stats - conv_stats["total_queries"] += 1 - conv_stats["response_count"] += 1 - if frame == MEMOS_SCHEDULER_MODEL: - if can_answer: - conv_stats["can_answer_count"] += 1 - else: - conv_stats["cannot_answer_count"] += 1 - if conv_stats["total_queries"] > 0: - conv_stats["answer_hit_rate"] = ( - conv_stats["can_answer_count"] / conv_stats["total_queries"] - ) * 100 - - # Persist conversation stats snapshot - self._save_conv_stats(conv_id, frame, version, conv_stats, conv_stats_path) - - logger.info(f"Processed question: {query[:100]}") - logger.info(f"Answer: {answer[:100]}") - - # Update pre-context cache with current context - with self.stats_lock: - if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY: - self.update_context( - conv_id=conv_id, - method=self.context_update_method, - query=query, - answer=answer, - ) - else: - self.update_context( - conv_id=conv_id, - method=self.context_update_method, - cur_context=cur_context, - ) - - self.print_eval_info() - - def _process_single_qa( - self, - qa, - *, - client, - reversed_client, - metadata, - frame, - version, - conv_id, - conv_stats_path, - oai_client, - top_k, - conv_stats, - ): - query = qa.get("question") - gold_answer = qa.get("answer") - qa_category = qa.get("category") - if qa_category == 5: - return None - - # Search - cur_context, search_duration_ms = self.search_query( - client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k - ) - if not cur_context: - logger.warning(f"No context found for query: {query[:100]}") - cur_context = "" - - if self.context_update_method == ContextUpdateMethod.CURRENT_CONTEXT: - context = cur_context - else: - # Context answer ability analysis (for memos_scheduler only) - if self.pre_context_cache[conv_id] is None: - # Update pre-context cache with current context and return - if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY: - answer_from_cur_context = self.locomo_response( - frame, oai_client, cur_context, query - ) - self.update_context( - conv_id=conv_id, - method=self.context_update_method, - query=query, - answer=answer_from_cur_context, - ) - else: - self.update_context( - conv_id=conv_id, - method=self.context_update_method, - cur_context=cur_context, - ) - return None - else: - context = self.pre_context_cache[conv_id] - - # Generate answer - answer_start = time() - answer = self.locomo_response(frame, oai_client, context, query) - response_duration_ms = (time() - answer_start) * 1000 - - can_answer, can_answer_duration_ms = self.eval_context( - context=context, query=query, gold_answer=gold_answer, oai_client=oai_client - ) - - # Record case for memos_scheduler - try: - recording_case = RecordingCase( - conv_id=conv_id, - query=query, - answer=answer, - context=cur_context, - pre_context=self.pre_context_cache[conv_id], - can_answer=can_answer, - can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}", - search_duration_ms=search_duration_ms, - can_answer_duration_ms=can_answer_duration_ms, - response_duration_ms=response_duration_ms, - category=int(qa_category) if qa_category is not None else None, - golden_answer=str(qa.get("answer", "")), - ) - if can_answer: - self.can_answer_cases.append(recording_case) - else: - self.cannot_answer_cases.append(recording_case) - except Exception as e: - logger.error(f"Error creating RecordingCase: {e}") - print(f"Error creating RecordingCase: {e}") - logger.error(f"QA data: {qa}") - print(f"QA data: {qa}") - logger.error(f"Query: {query}") - logger.error(f"Answer: {answer}") - logger.error( - f"Golden answer (raw): {qa.get('answer')} (type: {type(qa.get('answer'))})" - ) - logger.error(f"Category: {qa_category} (type: {type(qa_category)})") - logger.error(f"Can answer: {can_answer}") - raise e - - if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY: - answer_from_cur_context = self.locomo_response(frame, oai_client, cur_context, query) - answer = answer_from_cur_context - # Update conversation stats and context - self._update_stats_and_context( - conv_id=conv_id, - frame=frame, - version=version, - conv_stats=conv_stats, - conv_stats_path=conv_stats_path, - query=query, - answer=answer, - gold_answer=gold_answer, - cur_context=cur_context, - can_answer=can_answer, - ) - - return { - "question": query, - "answer": answer, - "category": qa_category, - "golden_answer": gold_answer, - "search_context": cur_context, - "response_duration_ms": response_duration_ms, - "search_duration_ms": search_duration_ms, - "can_answer_duration_ms": can_answer_duration_ms, - "can_answer": can_answer if frame == MEMOS_SCHEDULER_MODEL else None, - } - - def run_locomo_processing(self, num_users=10): - load_dotenv() - - frame = self.frame - version = self.version - num_workers = self.workers - top_k = self.top_k - - # Storage for aggregated results - all_search_results = defaultdict(list) - all_response_results = defaultdict(list) - num_users = num_users - - # Prepare arguments for each user processing task - user_args = [(idx, self.locomo_df, frame, version, top_k) for idx in range(num_users)] - - if num_workers > 1: - # === parallel running ==== - # Use ThreadPoolExecutor for parallel processing - print( - f"Starting parallel processing for {num_users} users, using {num_workers} workers for sessions..." - ) - with ThreadPoolExecutor(max_workers=num_workers) as executor: - # Submit all user processing tasks - future_to_user = { - executor.submit(self.process_user_wrapper, args): idx - for idx, args in enumerate(user_args) - } - - # Collect results as they complete - for future in as_completed(future_to_user): - idx = future_to_user[future] - user_search_results, user_response_results, error = future.result() - if error is not None: - idx, e, traceback_str = error - print(f"Error processing user {idx}: {e}. Exception: {traceback_str}") - else: - # Aggregate results - conv_id = f"locomo_exp_user_{idx}" - all_search_results[conv_id].extend(user_search_results[conv_id]) - all_response_results[conv_id].extend(user_response_results[conv_id]) - - else: - # Serial processing - print( - f"Starting serial processing for {num_users} users in serial mode, each user using {num_workers} workers for sessions..." - ) - for idx, args in enumerate(user_args): - user_search_results, user_response_results, error = self.process_user_wrapper(args) - if error is not None: - idx, e, traceback_str = error - print(f"Error processing user {idx}: {e}. Exception: {traceback_str}") - else: - # Aggregate results - conv_id = f"locomo_exp_user_{idx}" - all_search_results[conv_id].extend(user_search_results[conv_id]) - all_response_results[conv_id].extend(user_response_results[conv_id]) - - # Print evaluation information statistics - self.print_eval_info() - self.save_stats() - - # Save all aggregated results - with open(self.search_path, "w") as fw: - json.dump(all_search_results, fw, indent=2) - print(f"Saved all search results to {self.search_path}") - - with open(self.response_path, "w") as fw: - json.dump(all_response_results, fw, indent=2) - print(f"Saved all response results to {self.response_path}") - - # Save evaluation cases if they exist - if self.can_answer_cases or self.cannot_answer_cases: - try: - saved_files = save_evaluation_cases( - can_answer_cases=self.can_answer_cases, - cannot_answer_cases=self.cannot_answer_cases, - output_dir=self.stats_dir, - frame=self.frame, - version=self.version, - ) - print(f"Saved evaluation cases: {saved_files}") - except Exception as e: - logger.error(f"Error saving evaluation cases: {e}") - - return dict(all_search_results), dict(all_response_results) diff --git a/evaluation/scripts/temporal_locomo/models/locomo_processor_w_time_eval.py b/evaluation/scripts/temporal_locomo/models/locomo_processor_w_time_eval.py deleted file mode 100644 index b909c64e1..000000000 --- a/evaluation/scripts/temporal_locomo/models/locomo_processor_w_time_eval.py +++ /dev/null @@ -1,229 +0,0 @@ -import sys -import time - -from pathlib import Path -from typing import TYPE_CHECKING - -from evaluation.scripts.temporal_locomo.models.locomo_processor import LocomoProcessor -from evaluation.scripts.temporal_locomo.modules.constants import ( - MEMOS_SCHEDULER_MODEL, -) -from evaluation.scripts.temporal_locomo.modules.prompts import ( - SEARCH_PROMPT_MEMOS, -) -from evaluation.scripts.temporal_locomo.modules.schemas import ContextUpdateMethod, RecordingCase -from memos.log import get_logger - - -if TYPE_CHECKING: - from memos.mem_os.main import MOS - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -class LocomoProcessorWithTimeEval(LocomoProcessor): - def __init__(self, args): - super().__init__(args=args) - self.time_eval_mode = getattr(self.args, "time_eval_mode", False) - assert self.args.frame == MEMOS_SCHEDULER_MODEL - assert self.context_update_method == ContextUpdateMethod.PRE_CONTEXT - if self.time_eval_mode: - logger.warning( - "time_eval_mode is activated. _process_single_qa is replaced by _process_single_qa_for_time_eval" - ) - self._process_single_qa = self._process_single_qa_for_time_eval - - def memos_scheduler_search( - self, client, query, conv_id, speaker_a, speaker_b, reversed_client=None, top_k=20 - ): - # MemOS full search process and skip the parts of scheduler - start = time.time() - client: MOS = client - - if not self.scheduler_flag: - # if not scheduler_flag, search to update working memory - self.memos_search(client, query, conv_id, speaker_a, speaker_b, reversed_client) - - # ========= MemOS Search ========= - # Search for speaker A - search_a_results = client.search( - query=query, - user_id=conv_id + "_speaker_a", - install_cube_ids=[conv_id + "_speaker_a"], - top_k=top_k, - mode="fine", - internet_search=False, - moscube=False, # cube for mos introduction - session_id=None, - )["text_mem"] - search_a_results = [[m.memory for m in one["memories"]] for one in search_a_results] - search_a_results = [item for sublist in search_a_results for item in sublist] - - # Search for speaker B - search_b_results = client.search( - query=query, - user_id=conv_id + "_speaker_b", - install_cube_ids=[conv_id + "_speaker_b"], - top_k=top_k, - mode="fine", - internet_search=False, - moscube=False, # cube for mos introduction - session_id=None, - )["text_mem"] - search_b_results = [[m.memory for m in one["memories"]] for one in search_b_results] - search_b_results = [item for sublist in search_b_results for item in sublist] - - speaker_a_context = "" - for item in search_a_results: - speaker_a_context += f"{item}\n" - - speaker_b_context = "" - for item in search_b_results: - speaker_b_context += f"{item}\n" - - context = SEARCH_PROMPT_MEMOS.format( - speaker_1=speaker_a, - speaker_1_memories=speaker_a_context, - speaker_2=speaker_b, - speaker_2_memories=speaker_b_context, - ) - - logger.info(f'query "{query[:100]}", context: {context[:100]}"') - duration_ms = (time.time() - start) * 1000 - - return context, duration_ms - - def _process_single_qa_for_time_eval( - self, - qa, - *, - client, - reversed_client, - metadata, - frame, - version, - conv_id, - conv_stats_path, - oai_client, - top_k, - conv_stats, - ): - query = qa.get("question") - gold_answer = qa.get("answer") - qa_category = qa.get("category") - if qa_category == 5: - return None - - # 1. two parallel process, - # 1. memos search + response - # 2. pre_memories can answer, true : direct answer false: - - # Search - assert self.args.frame == MEMOS_SCHEDULER_MODEL - cur_context, search_duration_ms = self.search_query( - client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k - ) - if not cur_context: - logger.warning(f"No context found for query: {query[:100]}") - cur_context = "" - - # Context answer ability analysis (for memos_scheduler only) - if self.pre_context_cache[conv_id] is None: - # Update pre-context cache with current context and return - self.update_context( - conv_id=conv_id, - method=self.context_update_method, - cur_context=cur_context, - ) - - # ========= MemOS Scheduler update ========= - _ = client.mem_scheduler.update_working_memory_for_eval( - query=query, user_id=conv_id + "_speaker_a", top_k=top_k - ) - - _ = client.mem_scheduler.update_working_memory_for_eval( - query=query, user_id=conv_id + "_speaker_b", top_k=top_k - ) - return None - - context = self.pre_context_cache[conv_id] - - # Generate answer - answer_start = time.time() - answer = self.locomo_response(frame, oai_client, context, query) - response_duration_ms = (time.time() - answer_start) * 1000 - - can_answer, can_answer_duration_ms = self.eval_context( - context=context, query=query, gold_answer=gold_answer, oai_client=oai_client - ) - - # Record case for memos_scheduler - try: - recording_case = RecordingCase( - conv_id=conv_id, - query=query, - answer=answer, - context=cur_context, - pre_context=self.pre_context_cache[conv_id], - can_answer=can_answer, - can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}", - search_duration_ms=search_duration_ms, - can_answer_duration_ms=can_answer_duration_ms, - response_duration_ms=response_duration_ms, - category=int(qa_category) if qa_category is not None else None, - golden_answer=str(qa.get("answer", "")), - ) - if can_answer: - self.can_answer_cases.append(recording_case) - else: - self.cannot_answer_cases.append(recording_case) - except Exception as e: - logger.error(f"Error creating RecordingCase: {e}") - print(f"Error creating RecordingCase: {e}") - logger.error(f"QA data: {qa}") - print(f"QA data: {qa}") - logger.error(f"Query: {query}") - logger.error(f"Answer: {answer}") - logger.error( - f"Golden answer (raw): {qa.get('answer')} (type: {type(qa.get('answer'))})" - ) - logger.error(f"Category: {qa_category} (type: {type(qa_category)})") - logger.error(f"Can answer: {can_answer}") - raise e - - # Update conversation stats and context - self._update_stats_and_context( - conv_id=conv_id, - frame=frame, - version=version, - conv_stats=conv_stats, - conv_stats_path=conv_stats_path, - query=query, - answer=answer, - gold_answer=gold_answer, - cur_context=cur_context, - can_answer=can_answer, - ) - # ========= MemOS Scheduler update ========= - _ = client.mem_scheduler.update_working_memory_for_eval( - query=query, user_id=conv_id + "_speaker_a", top_k=top_k - ) - - _ = client.mem_scheduler.update_working_memory_for_eval( - query=query, user_id=conv_id + "_speaker_b", top_k=top_k - ) - return { - "question": query, - "answer": answer, - "category": qa_category, - "golden_answer": gold_answer, - "search_context": cur_context, - "response_duration_ms": response_duration_ms, - "search_duration_ms": search_duration_ms, - "can_answer_duration_ms": can_answer_duration_ms, - "can_answer": can_answer if frame == MEMOS_SCHEDULER_MODEL else None, - } diff --git a/evaluation/scripts/temporal_locomo/modules/README.md b/evaluation/scripts/temporal_locomo/modules/README.md deleted file mode 100644 index 31a274dd0..000000000 --- a/evaluation/scripts/temporal_locomo/modules/README.md +++ /dev/null @@ -1,83 +0,0 @@ -# Evaluation Modules - -This directory contains the modularized evaluation system for temporal locomo evaluation, organized using inheritance and composition patterns. - -## Structure - -### Base Classes - -- **`base_eval_module.py`**: Contains the `BaseEvalModule` class with common functionality: - - Statistics management - - Data loading and processing - - File I/O operations - - Basic evaluation methods - -### Specialized Modules - -- **`client_manager.py`**: Contains the `ClientManager` class for managing different memory framework clients: - - Zep client management - - Mem0 client management - - Memos client management - - Memos scheduler client management - -- **`search_modules.py`**: Contains the `SearchModules` class with all search methods: - - `mem0_search()`: Mem0 framework search - - `mem0_graph_search()`: Mem0 graph framework search - - `memos_search()`: Memos framework search - - `memos_scheduler_search()`: Memos scheduler framework search - - `zep_search()`: Zep framework search - -- **`locomo_eval_module.py`**: Contains the main `LocomoEvalModule` class that combines all functionality: - - Inherits from `BaseEvalModule` - - Uses `ClientManager` for client management - - Uses `SearchModules` for search operations - - Provides unified interface for evaluation - -## Usage - -### Basic Usage - -```python -from modules import LocomoEvalModule -import argparse - -# Create arguments -args = argparse.Namespace() -args.frame = 'memos_scheduler' -args.version = 'v0.2.1' -args.top_k = 20 -args.workers = 1 - -# Initialize the evaluation module -eval_module = LocomoEvalModule(args) - -# Use the module -eval_module.print_eval_info() -eval_module.save_stats() -``` - -### Backward Compatibility - -For backward compatibility, the original `LocomoEvalModelModules` class is available as an alias: - -```python -from modules import LocomoEvalModule as LocomoEvalModelModules -``` - -## Benefits of Modularization - -1. **Separation of Concerns**: Each module has a specific responsibility -2. **Maintainability**: Easier to modify and extend individual components -3. **Testability**: Each module can be tested independently -4. **Reusability**: Modules can be reused in different contexts -5. **Readability**: Code is more organized and easier to understand - -## Migration from Original Code - -The original `eval_model_modules.py` has been refactored into this modular structure: - -- **Original class**: `LocomoEvalModelModules` -- **New main class**: `LocomoEvalModule` -- **Backward compatibility**: `LocomoEvalModelModules = LocomoEvalModule` - -All existing functionality is preserved, but now organized in a more maintainable structure. diff --git a/evaluation/scripts/temporal_locomo/modules/__init__.py b/evaluation/scripts/temporal_locomo/modules/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/evaluation/scripts/temporal_locomo/modules/base_eval_module.py b/evaluation/scripts/temporal_locomo/modules/base_eval_module.py deleted file mode 100644 index d056745cc..000000000 --- a/evaluation/scripts/temporal_locomo/modules/base_eval_module.py +++ /dev/null @@ -1,386 +0,0 @@ -import json -import os -import traceback - -from collections import defaultdict -from pathlib import Path -from threading import Lock -from typing import TYPE_CHECKING - -import pandas as pd - -from dotenv import load_dotenv - -from memos.configs.mem_scheduler import AuthConfig -from memos.log import get_logger - -from .constants import ( - BASE_DIR, - MEMOS_SCHEDULER_MODEL, -) -from .prompts import ( - CUSTOM_INSTRUCTIONS, -) -from .schemas import ContextUpdateMethod - - -if TYPE_CHECKING: - from .schemas import RecordingCase - - -logger = get_logger(__name__) - - -class BaseEvalModule: - def __init__(self, args): - # hyper-parameters - self.args = args - self.frame = self.args.frame - self.version = self.args.version - self.workers = self.args.workers - self.top_k = self.args.top_k - - # attributes - self.context_update_method = getattr( - self.args, "context_update_method", ContextUpdateMethod.PRE_CONTEXT - ) - self.custom_instructions = CUSTOM_INSTRUCTIONS - self.data_dir = Path(f"{BASE_DIR}/data") - self.locomo_df = pd.read_json(f"{self.data_dir}/locomo/locomo10.json") - - # Load temporal_locomo dataset if it exists - self.temporal_locomo_data = None - temporal_locomo_file = self.data_dir / "temporal_locomo" / "temporal_locomo_qa.json" - if temporal_locomo_file.exists(): - with open(temporal_locomo_file, encoding="utf-8") as f: - self.temporal_locomo_data = json.load(f) - logger.info( - f"Loaded temporal_locomo dataset with {len(self.temporal_locomo_data)} conversations" - ) - else: - logger.warning(f"Temporal locomo dataset not found at {temporal_locomo_file}") - - result_dir_prefix = getattr(self.args, "result_dir_prefix", "") - - # Configure result dir; if scheduler disabled and using memos scheduler, mark as ablation - if ( - hasattr(self.args, "scheduler_flag") - and self.frame == MEMOS_SCHEDULER_MODEL - and self.args.scheduler_flag is False - ): - self.result_dir = Path( - f"{BASE_DIR}/results/temporal_locomo/{result_dir_prefix}{self.frame}-{self.version}-ablation/" - ) - else: - self.result_dir = Path( - f"{BASE_DIR}/results/temporal_locomo/{result_dir_prefix}{self.frame}-{self.version}/" - ) - - if self.context_update_method != ContextUpdateMethod.PRE_CONTEXT: - self.result_dir = ( - self.result_dir.parent / f"{self.result_dir.name}_{self.context_update_method}" - ) - self.result_dir.mkdir(parents=True, exist_ok=True) - - self.search_path = self.result_dir / f"{self.frame}-{self.version}_search_results.json" - self.response_path = self.result_dir / f"{self.frame}-{self.version}_responses.json" - self.judged_path = self.result_dir / f"{self.frame}-{self.version}_judged.json" - self.grade_path = self.result_dir / f"{self.frame}-{self.version}_grades.json" - self.excel_path = self.result_dir / f"{self.frame}-{self.version}_metrics.xlsx" - - self.ingestion_storage_dir = self.result_dir / "storages" - self.mos_config_path = Path(f"{BASE_DIR}/configs-example/mos_w_scheduler_config.json") - self.mem_cube_config_path = Path(f"{BASE_DIR}/configs-example/mem_cube_config.json") - - self.openai_api_key = os.getenv("CHAT_MODEL_API_KEY") - self.openai_base_url = os.getenv("CHAT_MODEL_BASE_URL") - self.openai_chat_model = os.getenv("CHAT_MODEL") - - auth_config_path = Path(f"{BASE_DIR}/scripts/temporal_locomo/eval_auth.json") - if auth_config_path.exists(): - auth_config = AuthConfig.from_local_config(config_path=auth_config_path) - print( - f"✅ Configuration loaded successfully: from local config file {auth_config_path}" - ) - else: - # Load .env file first before reading environment variables - load_dotenv() - auth_config = AuthConfig.from_local_env() - print("✅ Configuration loaded successfully: from environment variables") - self.openai_api_key = auth_config.openai.api_key - self.openai_base_url = auth_config.openai.base_url - self.openai_chat_model = auth_config.openai.default_model - - self.mos_config_data = json.load(self.mos_config_path.open("r", encoding="utf-8")) - self.mem_cube_config_data = json.load(self.mem_cube_config_path.open("r", encoding="utf-8")) - - # Update LLM authentication information in MOS configuration using dictionary assignment - self.mos_config_data["mem_reader"]["config"]["llm"]["config"]["api_key"] = ( - auth_config.openai.api_key - ) - self.mos_config_data["mem_reader"]["config"]["llm"]["config"]["api_base"] = ( - auth_config.openai.base_url - ) - - # Update graph database authentication information in memory cube configuration using dictionary assignment - self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["uri"] = ( - auth_config.graph_db.uri - ) - self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["user"] = ( - auth_config.graph_db.user - ) - self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["password"] = ( - auth_config.graph_db.password - ) - self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["db_name"] = ( - auth_config.graph_db.db_name - ) - self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["auto_create"] = ( - auth_config.graph_db.auto_create - ) - - # Logger initialization - self.logger = logger - - # Statistics tracking with thread safety - self.stats = {self.frame: {self.version: defaultdict(dict)}} - self.stats[self.frame][self.version]["memory_stats"] = defaultdict(dict) - self.stats[self.frame][self.version]["memory_stats"]["total_queries"] = 0 - self.stats[self.frame][self.version]["memory_stats"]["can_answer_count"] = 0 - self.stats[self.frame][self.version]["memory_stats"]["cannot_answer_count"] = 0 - self.stats[self.frame][self.version]["memory_stats"]["answer_hit_rate"] = 0.0 - - # Initialize memory history for tracking retrieval results - self.stats_lock = Lock() - # Reflect CLI flag - self.scheduler_flag = bool(getattr(self.args, "scheduler_flag", True)) - self.stats_dir = self.result_dir / f"stats/{self.frame}_{self.version}" - self.stats_dir.mkdir(parents=True, exist_ok=True) # Ensure the directory exists - self.stats_path = self.stats_dir / "stats.txt" - - self.can_answer_cases: list[RecordingCase] = [] - self.cannot_answer_cases: list[RecordingCase] = [] - - def print_eval_info(self): - """ - Calculate and print the evaluation information including answer statistics for memory scheduler (thread-safe). - Shows total queries, can answer count, cannot answer count, and answer hit rate. - """ - with self.stats_lock: - # Get statistics - total_queries = self.stats[self.frame][self.version]["memory_stats"]["total_queries"] - can_answer_count = self.stats[self.frame][self.version]["memory_stats"][ - "can_answer_count" - ] - cannot_answer_count = self.stats[self.frame][self.version]["memory_stats"][ - "cannot_answer_count" - ] - hit_rate = self.stats[self.frame][self.version]["memory_stats"]["answer_hit_rate"] - - # Print basic statistics - print(f"Total Queries: {total_queries}") - logger.info(f"Total Queries: {total_queries}") - - print(f"Can Answer Count: {can_answer_count}") - logger.info(f"Can Answer Count: {can_answer_count}") - - print(f"Cannot Answer Count: {cannot_answer_count}") - logger.info(f"Cannot Answer Count: {cannot_answer_count}") - - # Verify count consistency - if total_queries != (can_answer_count + cannot_answer_count): - print( - f"WARNING: Count mismatch! Total ({total_queries}) != Can Answer ({can_answer_count}) + Cannot Answer ({cannot_answer_count})" - ) - logger.warning( - f"Count mismatch! Total ({total_queries}) != Can Answer ({can_answer_count}) + Cannot Answer ({cannot_answer_count})" - ) - - print(f"Answer Hit Rate: {hit_rate:.2f}% ({can_answer_count}/{total_queries})") - logger.info(f"Answer Hit Rate: {hit_rate:.2f}% ({can_answer_count}/{total_queries})") - - def save_stats(self): - """ - Serializes and saves the contents of self.stats to the specified path: - Base_dir/results/frame-version/stats - - This method handles directory creation, thread-safe access to statistics data, - and proper JSON serialization of complex data structures. - """ - try: - # Thread-safe access to the stats data using the lock - # Create a copy of the data to prevent modification during serialization - stats_data = dict(self.stats) - - # Helper function to convert defaultdict to regular dict for JSON serialization - def convert_defaultdict(obj): - if isinstance(obj, defaultdict): - return dict(obj) - return obj - - # Debug: Print stats summary before saving - self.logger.info(f"DEBUG: Saving stats for {self.frame}-{self.version}") - self.logger.info(f"DEBUG: Stats path: {self.stats_path}") - self.logger.info(f"DEBUG: Stats data keys: {list(stats_data.keys())}") - if self.frame in stats_data and self.version in stats_data[self.frame]: - frame_data = stats_data[self.frame][self.version] - self.logger.info(f"DEBUG: Memory stats: {frame_data.get('memory_stats', {})}") - self.logger.info( - f"DEBUG: Total queries: {frame_data.get('memory_stats', {}).get('total_queries', 0)}" - ) - - # Serialize and save the statistics data to file - with self.stats_path.open("w", encoding="utf-8") as fw: - json.dump(stats_data, fw, ensure_ascii=False, indent=2, default=convert_defaultdict) - - self.logger.info(f"Successfully saved stats to: {self.stats_path}") - print(f"DEBUG: Stats file created at {self.stats_path}") - - except Exception as e: - self.logger.error(f"Failed to save stats: {e!s}") - self.logger.error(traceback.format_exc()) - print(f"DEBUG: Error saving stats: {e}") - - def get_answer_hit_rate(self): - """ - Get current answer hit rate statistics. - - Returns: - dict: Hit rate statistics - """ - with self.stats_lock: - return { - "total_queries": self.stats[self.frame][self.version]["memory_stats"][ - "total_queries" - ], - "can_answer_count": self.stats[self.frame][self.version]["memory_stats"][ - "can_answer_count" - ], - "hit_rate_percentage": self.stats[self.frame][self.version]["memory_stats"][ - "answer_hit_rate" - ], - } - - def group_and_sort_qa_by_day(self, qa_set, sort_by_evidence): - """ - Groups QA pairs by day and sorts them chronologically within each day group. - - Args: - qa_set (list): List of dictionaries containing QA data with evidence references - - Returns: - dict: Dictionary where keys are day strings (e.g., 'D1') and values are - lists of QA pairs sorted by evidence order within that day - """ - # Initialize a dictionary that automatically creates lists for new keys - day_groups = defaultdict(list) - - # Process each QA pair in the input dataset - for qa in qa_set: - # Extract all unique days referenced in this QA pair's evidence - days = set() - for evidence in qa["evidence"]: - # Split evidence string (e.g., 'D1:3') into day and position parts - day = evidence.split(":")[0] # Gets 'D1', 'D2', etc. - days.add(day) - - # Add this QA pair to each day group it references - for day in days: - day_groups[day].append(qa) - - if sort_by_evidence: - # Sort QA pairs within each day group by their earliest evidence position - for day in day_groups: - # Create list of (qa, position) pairs for proper sorting - qa_position_pairs = [] - - for qa in day_groups[day]: - # Find the earliest evidence position for this day - earliest_position = None - for evidence in qa["evidence"]: - if evidence.startswith(day + ":"): - try: - position = int(evidence.split(":")[1]) - if earliest_position is None or position < earliest_position: - earliest_position = position - except (IndexError, ValueError): - # Skip invalid evidence format - continue - - if earliest_position is not None: - qa_position_pairs.append((qa, earliest_position)) - - # Sort by evidence position (earliest first) - qa_position_pairs = sorted(qa_position_pairs, key=lambda x: x[1]) - day_groups[day] = [qa for qa, _ in qa_position_pairs] - - return dict(day_groups) - - def convert_locomo_to_temporal_locomo(self, output_dir: str | None = None): - """ - Convert locomo dataset to temporal_locomo dataset format. - - This function processes the original locomo dataset and reorganizes it by days - with proper chronological ordering within each day group. - - Args: - output_dir: Output directory for the converted dataset. - Defaults to evaluation/data/temporal_locomo/ - - Returns: - str: Path to the converted dataset file - """ - if output_dir is None: - output_dir = f"{BASE_DIR}/data/temporal_locomo" - - # Create output directory - os.makedirs(output_dir, exist_ok=True) - - # Load original locomo data - locomo_data = self.locomo_df.to_dict("records") - - # Process each conversation - temporal_data = [] - - for conv_id, conversation in enumerate(locomo_data): - logger.info(f"Processing conversation {conv_id + 1}/{len(locomo_data)}") - - # Get QA pairs for this conversation - qa_set = conversation.get("qa", []) - - # Group and sort QA pairs by day - day_groups = self.group_and_sort_qa_by_day(qa_set, sort_by_evidence=False) - - # Create temporal structure for this conversation - temporal_conversation = {"conversation_id": f"locomo_exp_user_{conv_id}", "days": {}} - - # Process each day group - for day, qa_list in day_groups.items(): - temporal_conversation["days"][day] = { - "day_id": day, - "qa_pairs": qa_list, - "total_qa_pairs": len(qa_list), - } - - temporal_data.append(temporal_conversation) - - # Save the converted dataset - output_file = os.path.join(output_dir, "temporal_locomo_qa.json") - with open(output_file, "w", encoding="utf-8") as f: - json.dump(temporal_data, f, indent=2, ensure_ascii=False) - - logger.info(f"Converted dataset saved to: {output_file}") - logger.info(f"Total conversations: {len(temporal_data)}") - - # Log statistics - total_qa_pairs = sum(len(conv["qa"]) for conv in locomo_data) - total_temporal_qa_pairs = sum( - sum(day_data["total_qa_pairs"] for day_data in conv["days"].values()) - for conv in temporal_data - ) - - logger.info(f"Original QA pairs: {total_qa_pairs}") - logger.info(f"Temporal QA pairs: {total_temporal_qa_pairs}") - logger.info("QA pairs may be duplicated across days if they reference multiple days") - - return output_file diff --git a/evaluation/scripts/temporal_locomo/modules/client_manager.py b/evaluation/scripts/temporal_locomo/modules/client_manager.py deleted file mode 100644 index c5882179e..000000000 --- a/evaluation/scripts/temporal_locomo/modules/client_manager.py +++ /dev/null @@ -1,191 +0,0 @@ -""" -Client management module for handling different memory framework clients. -""" - -import os - -from mem0 import MemoryClient -from zep_cloud.client import Zep - -from memos.configs.mem_cube import GeneralMemCubeConfig -from memos.configs.mem_os import MOSConfig -from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube -from memos.mem_os.main import MOS -from memos.mem_scheduler.analyzer.scheduler_for_eval import SchedulerForEval - -from .base_eval_module import BaseEvalModule -from .constants import ( - MEM0_GRAPH_MODEL, - MEM0_MODEL, - MEMOS_MODEL, - MEMOS_SCHEDULER_MODEL, - ZEP_MODEL, -) -from .prompts import ( - ANSWER_PROMPT_MEM0, - ANSWER_PROMPT_MEMOS, - ANSWER_PROMPT_ZEP, -) - - -logger = get_logger(__name__) - - -class EvalModuleWithClientManager(BaseEvalModule): - """ - Manages different memory framework clients for evaluation. - """ - - def __init__(self, args): - super().__init__(args=args) - - def get_client_for_ingestion( - self, frame: str, user_id: str | None = None, version: str = "default" - ): - if frame == ZEP_MODEL: - zep = Zep(api_key=os.getenv("ZEP_API_KEY"), base_url="https://api.getzep.com/api/v2") - return zep - - elif frame in (MEM0_MODEL, MEM0_GRAPH_MODEL): - mem0 = MemoryClient(api_key=os.getenv("MEM0_API_KEY")) - mem0.update_project(custom_instructions=self.custom_instructions) - return mem0 - else: - if frame not in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: - raise NotImplementedError(f"Unsupported framework: {frame}") - - # scheduler is not needed in the ingestion step - self.mos_config_data["top_k"] = 20 - self.mos_config_data["enable_mem_scheduler"] = False - - mos_config = MOSConfig(**self.mos_config_data) - mos = MOS(mos_config) - mos.create_user(user_id=user_id) - - self.mem_cube_config_data["user_id"] = user_id - self.mem_cube_config_data["cube_id"] = user_id - self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["db_name"] = ( - f"{user_id.replace('_', '')}{version}" - ) - mem_cube_config = GeneralMemCubeConfig.model_validate(self.mem_cube_config_data) - mem_cube = GeneralMemCube(mem_cube_config) - - storage_path = str(self.ingestion_storage_dir / user_id) - try: - mem_cube.dump(storage_path) - except Exception as e: - print(f"dumping memory cube: {e!s} already exists, will use it.") - - mos.register_mem_cube( - mem_cube_name_or_path=storage_path, - mem_cube_id=user_id, - user_id=user_id, - ) - - return mos - - def get_client_from_storage( - self, frame: str, user_id: str | None = None, version: str = "default", top_k: int = 20 - ): - """ - Get a client instance for the specified memory framework. - - Args: - frame: Memory framework to use (zep, mem0, mem0_graph, memos, memos_scheduler) - user_id: Unique identifier for the user - version: Version identifier for result storage - top_k: Number of results to retrieve in search queries - - Returns: - Client instance for the specified framework - """ - storage_path = str(self.ingestion_storage_dir / user_id) - - if frame == ZEP_MODEL: - zep = Zep(api_key=os.getenv("ZEP_API_KEY"), base_url="https://api.getzep.com/api/v2") - return zep - - elif frame == [MEM0_MODEL, MEM0_GRAPH_MODEL]: - mem0 = MemoryClient(api_key=os.getenv("MEM0_API_KEY")) - return mem0 - - else: - if frame not in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: - raise NotImplementedError(f"Unsupported framework: {frame}") - - if frame == MEMOS_MODEL: - self.mos_config_data["enable_mem_scheduler"] = False - - self.mos_config_data["top_k"] = top_k - mos_config = MOSConfig(**self.mos_config_data) - mos = MOS(mos_config) - mos.create_user(user_id=user_id) - mos.register_mem_cube( - mem_cube_name_or_path=storage_path, - mem_cube_id=user_id, - user_id=user_id, - ) - - if frame == MEMOS_SCHEDULER_MODEL: - # Configure memory scheduler - mos.mem_scheduler.current_mem_cube = mos.mem_cubes[user_id] - mos.mem_scheduler.current_mem_cube_id = user_id - mos.mem_scheduler.current_user_id = user_id - - # Create SchedulerForEval instance with the same config - scheduler_for_eval = SchedulerForEval(config=mos.mem_scheduler.config) - # Initialize with the same modules as the original scheduler - scheduler_for_eval.initialize_modules( - chat_llm=mos.mem_scheduler.chat_llm, - process_llm=mos.mem_scheduler.process_llm, - db_engine=mos.mem_scheduler.db_engine, - ) - # Set the same context - scheduler_for_eval.current_mem_cube = mos.mem_cubes[user_id] - scheduler_for_eval.current_mem_cube_id = user_id - scheduler_for_eval.current_user_id = user_id - - # set llms to openai api - mos.chat_llm = mos.mem_reader.llm - for cube in mos.mem_cubes.values(): - cube.text_mem.dispatcher_llm = mos.mem_reader.llm - cube.text_mem.extractor_llm = mos.mem_reader.llm - - # Replace the original scheduler - mos.mem_scheduler = scheduler_for_eval - return mos - - def locomo_response(self, frame, llm_client, context: str, question: str) -> str: - if frame == ZEP_MODEL: - prompt = ANSWER_PROMPT_ZEP.format( - context=context, - question=question, - ) - elif frame in (MEM0_MODEL, MEM0_GRAPH_MODEL): - prompt = ANSWER_PROMPT_MEM0.format( - context=context, - question=question, - ) - elif frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: - prompt = ANSWER_PROMPT_MEMOS.format( - context=context, - question=question, - ) - else: - raise NotImplementedError() - response = llm_client.chat.completions.create( - model=self.openai_chat_model, - messages=[ - {"role": "system", "content": prompt}, - ], - temperature=0, - ) - - result = response.choices[0].message.content or "" - - if result == "": - with self.stats_lock: - self.stats[self.frame][self.version]["response_stats"]["response_failure"] += 1 - self.stats[self.frame][self.version]["response_stats"]["response_count"] += 1 - return result diff --git a/evaluation/scripts/temporal_locomo/modules/constants.py b/evaluation/scripts/temporal_locomo/modules/constants.py deleted file mode 100644 index 51ad7c729..000000000 --- a/evaluation/scripts/temporal_locomo/modules/constants.py +++ /dev/null @@ -1,19 +0,0 @@ -import sys - -from pathlib import Path - -from memos.log import get_logger - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -ZEP_MODEL = "zep" -MEM0_MODEL = "mem0" -MEM0_GRAPH_MODEL = "mem0_graph" -MEMOS_MODEL = "memos" -MEMOS_SCHEDULER_MODEL = "memos_scheduler" diff --git a/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py b/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py deleted file mode 100644 index d444ea62c..000000000 --- a/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py +++ /dev/null @@ -1,578 +0,0 @@ -import json -import time -import traceback - -from collections import defaultdict -from datetime import datetime -from typing import TYPE_CHECKING - -from openai import OpenAI -from tqdm import tqdm - -from memos.log import get_logger - -from .client_manager import EvalModuleWithClientManager -from .constants import ( - MEM0_GRAPH_MODEL, - MEM0_MODEL, - MEMOS_MODEL, - MEMOS_SCHEDULER_MODEL, - ZEP_MODEL, -) -from .prompts import ( - CONTEXT_ANSWERABILITY_PROMPT, - SEARCH_PROMPT_MEM0, - SEARCH_PROMPT_MEM0_GRAPH, - SEARCH_PROMPT_MEMOS, - SEARCH_PROMPT_ZEP, -) -from .utils import filter_memory_data - - -if TYPE_CHECKING: - from memos.mem_os.main import MOS -logger = get_logger(__name__) - - -class LocomoEvalModelModules(EvalModuleWithClientManager): - """ - Contains search methods for different memory frameworks. - """ - - def __init__(self, args): - super().__init__(args=args) - self.pre_context_cache = {} - - def analyze_context_answerability(self, context, query, gold_answer, oai_client): - """ - Analyze whether the given context can answer the query. - - Args: - context: The context string to analyze - query: The query string - oai_client: OpenAI client for LLM analysis - - Returns: - bool: True if context can answer the query, False otherwise - """ - try: - prompt = CONTEXT_ANSWERABILITY_PROMPT.format( - context=context, question=query, gold_answer=str(gold_answer) - ) - - response = oai_client.chat.completions.create( - model="gpt-4o-mini", - messages=[{"role": "user", "content": prompt}], - temperature=0.1, - max_tokens=10, - ) - - answer = response.choices[0].message.content.strip().upper() - return answer == "YES" - except Exception as e: - logger.error(f"Error analyzing context answerability: {e}") - return False - - def mem0_search(self, client, query, speaker_a_user_id, speaker_b_user_id, top_k=20): - """ - Search memories using the mem0 framework. - - Args: - client: mem0 client instance - query: Search query string - speaker_a_user_id: User ID for first speaker - speaker_b_user_id: User ID for second speaker - top_k: Number of results to retrieve - - Returns: - Tuple containing formatted context and search duration in milliseconds - """ - start = time.time() - search_speaker_a_results = client.search( - query=query, - top_k=top_k, - user_id=speaker_a_user_id, - output_format="v1.1", - version="v2", - filters={"AND": [{"user_id": f"{speaker_a_user_id}"}, {"run_id": "*"}]}, - ) - search_speaker_b_results = client.search( - query=query, - top_k=top_k, - user_id=speaker_b_user_id, - output_format="v1.1", - version="v2", - filters={"AND": [{"user_id": f"{speaker_b_user_id}"}, {"run_id": "*"}]}, - ) - - # Format speaker A memories - search_speaker_a_memory = [ - { - "memory": memory["memory"], - "timestamp": memory["created_at"], - "score": round(memory["score"], 2), - } - for memory in search_speaker_a_results["results"] - ] - - search_speaker_a_memory = [ - [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_a_memory] - ] - - # Format speaker B memories - search_speaker_b_memory = [ - { - "memory": memory["memory"], - "timestamp": memory["created_at"], - "score": round(memory["score"], 2), - } - for memory in search_speaker_b_results["results"] - ] - - search_speaker_b_memory = [ - [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_b_memory] - ] - - # Create context using template - context = SEARCH_PROMPT_MEM0.format( - speaker_1_user_id=speaker_a_user_id.split("_")[0], - speaker_1_memories=json.dumps(search_speaker_a_memory, indent=4), - speaker_2_user_id=speaker_b_user_id.split("_")[0], - speaker_2_memories=json.dumps(search_speaker_b_memory, indent=4), - ) - - duration_ms = (time.time() - start) * 1000 - return context, duration_ms - - def memos_search( - self, client, query, conv_id, speaker_a, speaker_b, reversed_client=None, top_k=20 - ): - """ - Search memories using the memos framework. - - Args: - client: memos client instance - query: Search query string - conv_id: Conversation ID - speaker_a: First speaker identifier - speaker_b: Second speaker identifier - reversed_client: Client instance for reversed speaker context - - Returns: - Tuple containing formatted context and search duration in milliseconds - """ - start = time.time() - # Search memories for speaker A - search_a_results = client.search(query=query, user_id=conv_id + "_speaker_a") - filtered_search_a_results = filter_memory_data(search_a_results)["text_mem"][0]["memories"] - speaker_a_context = "" - for item in filtered_search_a_results[:top_k]: - speaker_a_context += f"{item['memory']}\n" - - # Search memories for speaker B - search_b_results = reversed_client.search( - query=query, - user_id=conv_id + "_speaker_b", - ) - filtered_search_b_results = filter_memory_data(search_b_results)["text_mem"][0]["memories"] - speaker_b_context = "" - for item in filtered_search_b_results[:top_k]: - speaker_b_context += f"{item['memory']}\n" - - # Create context using template - context = SEARCH_PROMPT_MEMOS.format( - speaker_1=speaker_a, - speaker_1_memories=speaker_a_context, - speaker_2=speaker_b, - speaker_2_memories=speaker_b_context, - ) - - duration_ms = (time.time() - start) * 1000 - return context, duration_ms - - def memos_scheduler_search( - self, client, query, conv_id, speaker_a, speaker_b, reversed_client=None, top_k=20 - ): - start = time.time() - client: MOS = client - - if not self.scheduler_flag: - # if not scheduler_flag, search to update working memory - self.memos_search(client, query, conv_id, speaker_a, speaker_b, reversed_client, top_k) - - # Search for speaker A - search_a_results = client.mem_scheduler.search_for_eval( - query=query, - user_id=conv_id + "_speaker_a", - top_k=top_k, - scheduler_flag=self.scheduler_flag, - ) - - # Search for speaker B - search_b_results = reversed_client.mem_scheduler.search_for_eval( - query=query, - user_id=conv_id + "_speaker_b", - top_k=top_k, - scheduler_flag=self.scheduler_flag, - ) - - speaker_a_context = "" - for item in search_a_results: - speaker_a_context += f"{item}\n" - - speaker_b_context = "" - for item in search_b_results: - speaker_b_context += f"{item}\n" - - context = SEARCH_PROMPT_MEMOS.format( - speaker_1=speaker_a, - speaker_1_memories=speaker_a_context, - speaker_2=speaker_b, - speaker_2_memories=speaker_b_context, - ) - - logger.info(f'query "{query[:100]}", context: {context[:100]}"') - duration_ms = (time.time() - start) * 1000 - - return context, duration_ms - - def mem0_graph_search(self, client, query, speaker_a_user_id, speaker_b_user_id, top_k=20): - start = time.time() - search_speaker_a_results = client.search( - query=query, - top_k=top_k, - user_id=speaker_a_user_id, - output_format="v1.1", - version="v2", - enable_graph=True, - filters={"AND": [{"user_id": f"{speaker_a_user_id}"}, {"run_id": "*"}]}, - ) - search_speaker_b_results = client.search( - query=query, - top_k=top_k, - user_id=speaker_b_user_id, - output_format="v1.1", - version="v2", - enable_graph=True, - filters={"AND": [{"user_id": f"{speaker_b_user_id}"}, {"run_id": "*"}]}, - ) - - search_speaker_a_memory = [ - { - "memory": memory["memory"], - "timestamp": memory["created_at"], - "score": round(memory["score"], 2), - } - for memory in search_speaker_a_results["results"] - ] - - search_speaker_a_memory = [ - [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_a_memory] - ] - - search_speaker_b_memory = [ - { - "memory": memory["memory"], - "timestamp": memory["created_at"], - "score": round(memory["score"], 2), - } - for memory in search_speaker_b_results["results"] - ] - - search_speaker_b_memory = [ - [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_b_memory] - ] - - search_speaker_a_graph = [ - { - "source": relation["source"], - "relationship": relation["relationship"], - "target": relation["target"], - } - for relation in search_speaker_a_results["relations"] - ] - - search_speaker_b_graph = [ - { - "source": relation["source"], - "relationship": relation["relationship"], - "target": relation["target"], - } - for relation in search_speaker_b_results["relations"] - ] - context = SEARCH_PROMPT_MEM0_GRAPH.format( - speaker_1_user_id=speaker_a_user_id.split("_")[0], - speaker_1_memories=json.dumps(search_speaker_a_memory, indent=4), - speaker_1_graph_memories=json.dumps(search_speaker_a_graph, indent=4), - speaker_2_user_id=speaker_b_user_id.split("_")[0], - speaker_2_memories=json.dumps(search_speaker_b_memory, indent=4), - speaker_2_graph_memories=json.dumps(search_speaker_b_graph, indent=4), - ) - print(query, context) - duration_ms = (time.time() - start) * 1000 - return context, duration_ms - - def zep_search(self, client, query, group_id, top_k=20): - start = time.time() - nodes_result = client.graph.search( - query=query, - group_id=group_id, - scope="nodes", - reranker="rrf", - limit=top_k, - ) - edges_result = client.graph.search( - query=query, - group_id=group_id, - scope="edges", - reranker="cross_encoder", - limit=top_k, - ) - - nodes = nodes_result.nodes - edges = edges_result.edges - - facts = [f" - {edge.fact} (event_time: {edge.valid_at})" for edge in edges] - entities = [f" - {node.name}: {node.summary}" for node in nodes] - - context = SEARCH_PROMPT_ZEP.format(facts="\n".join(facts), entities="\n".join(entities)) - - duration_ms = (time.time() - start) * 1000 - - return context, duration_ms - - def search_query(self, client, query, metadata, frame, reversed_client=None, top_k=20): - conv_id = metadata.get("conv_id") - speaker_a = metadata.get("speaker_a") - speaker_b = metadata.get("speaker_b") - speaker_a_user_id = metadata.get("speaker_a_user_id") - speaker_b_user_id = metadata.get("speaker_b_user_id") - - if frame == ZEP_MODEL: - context, duration_ms = self.zep_search(client, query, conv_id, top_k) - elif frame == MEM0_MODEL: - context, duration_ms = self.mem0_search( - client, query, speaker_a_user_id, speaker_b_user_id, top_k - ) - elif frame == MEM0_GRAPH_MODEL: - context, duration_ms = self.mem0_graph_search( - client, query, speaker_a_user_id, speaker_b_user_id, top_k - ) - elif frame == MEMOS_MODEL: - context, duration_ms = self.memos_search( - client, query, conv_id, speaker_a, speaker_b, reversed_client, top_k - ) - elif frame == MEMOS_SCHEDULER_MODEL: - context, duration_ms = self.memos_scheduler_search( - client, query, conv_id, speaker_a, speaker_b, reversed_client, top_k - ) - else: - raise NotImplementedError() - - return context, duration_ms - - def _initialize_conv_stats(self): - """Create a fresh statistics dictionary for a conversation.""" - return { - "total_queries": 0, - "can_answer_count": 0, - "cannot_answer_count": 0, - "answer_hit_rate": 0.0, - "response_failure": 0, - "response_count": 0, - } - - def _build_day_groups(self, temporal_conv): - """Build mapping day_id -> qa_pairs from a temporal conversation dict.""" - day_groups = {} - for day_id, day_data in temporal_conv.get("days", {}).items(): - day_groups[day_id] = day_data.get("qa_pairs", []) - return day_groups - - def _build_metadata(self, speaker_a, speaker_b, speaker_a_user_id, speaker_b_user_id, conv_id): - """Assemble metadata for downstream calls.""" - return { - "speaker_a": speaker_a, - "speaker_b": speaker_b, - "speaker_a_user_id": speaker_a_user_id, - "speaker_b_user_id": speaker_b_user_id, - "conv_id": conv_id, - } - - def _get_clients(self, frame, speaker_a_user_id, speaker_b_user_id, conv_id, version, top_k): - """Return (client, reversed_client) according to the target frame.""" - reversed_client = None - if frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: - client = self.get_client_from_storage(frame, speaker_a_user_id, version, top_k=top_k) - reversed_client = self.get_client_from_storage( - frame, speaker_b_user_id, version, top_k=top_k - ) - else: - client = self.get_client_from_storage(frame, conv_id, version) - return client, reversed_client - - def _save_conv_stats(self, conv_id, frame, version, conv_stats, conv_stats_path): - """Persist per-conversation stats to disk.""" - conv_stats_data = { - "conversation_id": conv_id, - "frame": frame, - "version": version, - "statistics": conv_stats, - "timestamp": str(datetime.now()), - } - with open(conv_stats_path, "w") as fw: - json.dump(conv_stats_data, fw, indent=2, ensure_ascii=False) - print(f"Saved conversation stats for {conv_id} to {conv_stats_path}") - - def _write_user_search_results(self, user_search_path, search_results, conv_id): - """Write per-user search results to a temporary JSON file.""" - with open(user_search_path, "w") as fw: - json.dump(dict(search_results), fw, indent=2) - print(f"Save search results {conv_id}") - - def process_user(self, conv_id, locomo_df, frame, version, top_k=20): - user_search_path = self.result_dir / f"tmp/{frame}_locomo_search_results_{conv_id}.json" - user_search_path.parent.mkdir(exist_ok=True, parents=True) - search_results = defaultdict(list) - response_results = defaultdict(list) - conv_stats_path = self.stats_dir / f"{frame}_{version}_conv_{conv_id}_stats.json" - - conversation = locomo_df["conversation"].iloc[conv_id] - speaker_a = conversation.get("speaker_a", "speaker_a") - speaker_b = conversation.get("speaker_b", "speaker_b") - - # Use temporal_locomo data if available, otherwise fall back to original locomo data - temporal_conv = self.temporal_locomo_data[conv_id] - conv_id = temporal_conv["conversation_id"] - speaker_a_user_id = f"{conv_id}_speaker_a" - speaker_b_user_id = f"{conv_id}_speaker_b" - - # Process temporal data by days - day_groups = {} - for day_id, day_data in temporal_conv["days"].items(): - day_groups[day_id] = day_data["qa_pairs"] - - # Initialize conversation-level statistics - conv_stats = self._initialize_conv_stats() - - metadata = self._build_metadata( - speaker_a, speaker_b, speaker_a_user_id, speaker_b_user_id, conv_id - ) - - client, reversed_client = self._get_clients( - frame, speaker_a_user_id, speaker_b_user_id, conv_id, version, top_k - ) - - oai_client = OpenAI(api_key=self.openai_api_key, base_url=self.openai_base_url) - - with self.stats_lock: - self.pre_context_cache[conv_id] = None - - def process_qa(qa): - return self._process_single_qa( - qa, - client=client, - reversed_client=reversed_client, - metadata=metadata, - frame=frame, - version=version, - conv_id=conv_id, - conv_stats_path=conv_stats_path, - oai_client=oai_client, - top_k=top_k, - conv_stats=conv_stats, - ) - - # =================================== - conv_stats["theoretical_total_queries"] = 0 - for day, qa_list in day_groups.items(): - conv_stats["theoretical_total_queries"] += len(qa_list) - 1 - conv_stats["processing_failure_count"] = 0 - print(f"Processing user {conv_id} day {day}") - for qa in tqdm(qa_list, desc=f"Processing user {conv_id} day {day}"): - try: - result = process_qa(qa) - except Exception as e: - logger.error(f"Error: {e}. traceback: {traceback.format_exc()}") - conv_stats["processing_failure_count"] += 1 - continue - if result: - context_preview = ( - result["search_context"][:20] + "..." - if result["search_context"] - else "No context" - ) - if "can_answer" in result: - logger.info("Print can_answer case") - logger.info( - { - "question": result["question"][:100], - "pre context can answer": result["can_answer"], - "answer": result["answer"][:100], - "golden_answer": result["golden_answer"], - "search_context": context_preview[:100], - "search_duration_ms": result["search_duration_ms"], - } - ) - - search_results[conv_id].append( - { - "question": result["question"], - "context": result["search_context"], - "search_duration_ms": result["search_duration_ms"], - } - ) - response_results[conv_id].append(result) - - logger.warning( - f"Finished processing user {conv_id} day {day}, data_length: {len(qa_list)}" - ) - - # recording separate search results - with open(user_search_path, "w") as fw: - json.dump(dict(search_results), fw, indent=2) - print(f"Save search results {conv_id}") - - search_durations = [] - for result in response_results[conv_id]: - if "search_duration_ms" in result: - search_durations.append(result["search_duration_ms"]) - - if search_durations: - avg_search_duration = sum(search_durations) / len(search_durations) - with self.stats_lock: - if self.stats[self.frame][self.version]["memory_stats"]["avg_search_duration_ms"]: - self.stats[self.frame][self.version]["memory_stats"][ - "avg_search_duration_ms" - ] = ( - self.stats[self.frame][self.version]["memory_stats"][ - "avg_search_duration_ms" - ] - + avg_search_duration - ) / 2 - print(f"Average search duration: {avg_search_duration:.2f} ms") - - # Dump stats after processing each user - self.save_stats() - - return search_results, response_results - - def process_user_wrapper(self, args): - """ - Wraps the process_user function to support parallel execution and error handling. - - Args: - args: Tuple containing parameters for process_user - - Returns: - tuple: Contains user results or error information - """ - idx, locomo_df, frame, version, top_k = args - try: - print(f"Processing user {idx}...") - user_search_results, user_response_results = self.process_user( - idx, locomo_df, frame, version, top_k - ) - return (user_search_results, user_response_results, None) - except Exception as e: - return (None, None, (idx, e, traceback.format_exc())) diff --git a/evaluation/scripts/temporal_locomo/modules/prompts.py b/evaluation/scripts/temporal_locomo/modules/prompts.py deleted file mode 100644 index c88a8ff28..000000000 --- a/evaluation/scripts/temporal_locomo/modules/prompts.py +++ /dev/null @@ -1,219 +0,0 @@ -CUSTOM_INSTRUCTIONS = """ -Generate personal memories that follow these guidelines: - -1. Each memory should be self-contained with complete context, including: - - The person's name, do not use "user" while creating memories - - Personal details (career aspirations, hobbies, life circumstances) - - Emotional states and reactions - - Ongoing journeys or future plans - - Specific dates when events occurred - -2. Include meaningful personal narratives focusing on: - - Identity and self-acceptance journeys - - Family planning and parenting - - Creative outlets and hobbies - - Mental health and self-care activities - - Career aspirations and education goals - - Important life events and milestones - -3. Make each memory rich with specific details rather than general statements - - Include timeframes (exact dates when possible) - - Name specific activities (e.g., "charity race for mental health" rather than just "exercise") - - Include emotional context and personal growth elements - -4. Extract memories only from user messages, not incorporating assistant responses - -5. Format each memory as a paragraph with a clear narrative structure that captures the person's experience, challenges, and aspirations -""" - -SEARCH_PROMPT_ZEP = """ -FACTS and ENTITIES represent relevant context to the current conversation. - -# These are the most relevant facts for the conversation along with the datetime of the event that the fact refers to. -If a fact mentions something happening a week ago, then the datetime will be the date time of last week and not the datetime -of when the fact was stated. -Timestamps in memories represent the actual time the event occurred, not the time the event was mentioned in a message. - - -{facts} - - -# These are the most relevant entities -# ENTITY_NAME: entity summary - -{entities} - -""" - -SEARCH_PROMPT_MEM0 = """Memories for user {speaker_1_user_id}: - - {speaker_1_memories} - - Memories for user {speaker_2_user_id}: - - {speaker_2_memories} -""" - -SEARCH_PROMPT_MEM0_GRAPH = """Memories for user {speaker_1_user_id}: - - {speaker_1_memories} - - Relations for user {speaker_1_user_id}: - - {speaker_1_graph_memories} - - Memories for user {speaker_2_user_id}: - - {speaker_2_memories} - - Relations for user {speaker_2_user_id}: - - {speaker_2_graph_memories} -""" - -SEARCH_PROMPT_MEMOS = """Memories for user {speaker_1}: - - {speaker_1_memories} - - Memories for user {speaker_2}: - - {speaker_2_memories} -""" - - -ANSWER_PROMPT_MEM0 = """ - You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories. - - # CONTEXT: - You have access to memories from two speakers in a conversation. These memories contain - timestamped information that may be relevant to answering the question. - - # INSTRUCTIONS: - 1. Carefully analyze all provided memories from both speakers - 2. Pay special attention to the timestamps to determine the answer - 3. If the question asks about a specific event or fact, look for direct evidence in the memories - 4. If the memories contain contradictory information, prioritize the most recent memory - 5. If there is a question about time references (like "last year", "two months ago", etc.), - calculate the actual date based on the memory timestamp. For example, if a memory from - 4 May 2022 mentions "went to India last year," then the trip occurred in 2021. - 6. Always convert relative time references to specific dates, months, or years. For example, - convert "last year" to "2022" or "two months ago" to "March 2023" based on the memory - timestamp. Ignore the reference while answering the question. - 7. Focus only on the content of the memories from both speakers. Do not confuse character - names mentioned in memories with the actual users who created those memories. - 8. The answer should be less than 5-6 words. - - # APPROACH (Think step by step): - 1. First, examine all memories that contain information related to the question - 2. Examine the timestamps and content of these memories carefully - 3. Look for explicit mentions of dates, times, locations, or events that answer the question - 4. If the answer requires calculation (e.g., converting relative time references), show your work - 5. Formulate a precise, concise answer based solely on the evidence in the memories - 6. Double-check that your answer directly addresses the question asked - 7. Ensure your final answer is specific and avoids vague time references - - {context} - - Question: {question} - - Answer: - """ - - -ANSWER_PROMPT_ZEP = """ - You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories. - - # CONTEXT: - You have access to memories from a conversation. These memories contain - timestamped information that may be relevant to answering the question. - - # INSTRUCTIONS: - 1. Carefully analyze all provided memories - 2. Pay special attention to the timestamps to determine the answer - 3. If the question asks about a specific event or fact, look for direct evidence in the memories - 4. If the memories contain contradictory information, prioritize the most recent memory - 5. If there is a question about time references (like "last year", "two months ago", etc.), - calculate the actual date based on the memory timestamp. For example, if a memory from - 4 May 2022 mentions "went to India last year," then the trip occurred in 2021. - 6. Always convert relative time references to specific dates, months, or years. For example, - convert "last year" to "2022" or "two months ago" to "March 2023" based on the memory - timestamp. Ignore the reference while answering the question. - 7. Focus only on the content of the memories. Do not confuse character - names mentioned in memories with the actual users who created those memories. - 8. The answer should be less than 5-6 words. - - # APPROACH (Think step by step): - 1. First, examine all memories that contain information related to the question - 2. Examine the timestamps and content of these memories carefully - 3. Look for explicit mentions of dates, times, locations, or events that answer the question - 4. If the answer requires calculation (e.g., converting relative time references), show your work - 5. Formulate a precise, concise answer based solely on the evidence in the memories - 6. Double-check that your answer directly addresses the question asked - 7. Ensure your final answer is specific and avoids vague time references - - Context: - - {context} - - Question: {question} - Answer: - """ - -ANSWER_PROMPT_MEMOS = """ - You are a knowledgeable and helpful AI assistant. - - # CONTEXT: - You have access to memories from two speakers in a conversation. These memories contain - timestamped information that may be relevant to answering the question. - - # INSTRUCTIONS: - 1. Carefully analyze all provided memories. Synthesize information across different entries if needed to form a complete answer. - 2. Pay close attention to the timestamps to determine the answer. If memories contain contradictory information, the **most recent memory** is the source of truth. - 3. If the question asks about a specific event or fact, look for direct evidence in the memories. - 4. Your answer must be grounded in the memories. However, you may use general world knowledge to interpret or complete information found within a memory (e.g., identifying a landmark mentioned by description). - 5. If the question involves time references (like "last year", "two months ago", etc.), you **must** calculate the actual date based on the memory's timestamp. For example, if a memory from 4 May 2022 mentions "went to India last year," then the trip occurred in 2021. - 6. Always convert relative time references to specific dates, months, or years in your final answer. - 7. Do not confuse character names mentioned in memories with the actual users who created them. - 8. The answer must be brief (under 5-6 words) and direct, with no extra description. - - # APPROACH (Think step by step): - 1. First, examine all memories that contain information related to the question. - 2. Synthesize findings from multiple memories if a single entry is insufficient. - 3. Examine timestamps and content carefully, looking for explicit dates, times, locations, or events. - 4. If the answer requires calculation (e.g., converting relative time references), perform the calculation. - 5. Formulate a precise, concise answer based on the evidence from the memories (and allowed world knowledge). - 6. Double-check that your answer directly addresses the question asked and adheres to all instructions. - 7. Ensure your final answer is specific and avoids vague time references. - - {context} - - Question: {question} - - Answer: - """ - -CONTEXT_ANSWERABILITY_PROMPT = """ -You are an AI assistant that analyzes whether given context can answer a specific question, considering the ground-truth answer. - -# TASK: -Analyze the provided context and determine if it contains sufficient information to answer the given question. Use the provided ground-truth answer to guide your judgment: if the context contains the necessary evidence to derive that answer (explicitly or via direct inference), respond YES; otherwise respond NO. - -# INSTRUCTIONS: -1. Carefully examine the context provided -2. Identify if the context contains information directly related to the question -3. Determine if the information is sufficient to provide a complete answer that matches the ground-truth -4. Consider both explicit mentions and straightforward implications present in the context -5. Return only "YES" if the context can yield the ground-truth answer, "NO" if it cannot - -# CONTEXT: -{context} - -# QUESTION: -{question} - -# GROUND_TRUTH_ANSWER: -{gold_answer} - -# ANALYSIS: -Can this context answer the question and support the ground-truth answer? (YES/NO): -""" diff --git a/evaluation/scripts/temporal_locomo/modules/schemas.py b/evaluation/scripts/temporal_locomo/modules/schemas.py deleted file mode 100644 index fee89cc62..000000000 --- a/evaluation/scripts/temporal_locomo/modules/schemas.py +++ /dev/null @@ -1,161 +0,0 @@ -from typing import Any - -from pydantic import BaseModel, Field - - -class ContextUpdateMethod: - """Enumeration for context update methods""" - - PRE_CONTEXT = "pre_context" - CHAT_HISTORY = "chat_history" - CURRENT_CONTEXT = "current_context" - - @classmethod - def values(cls): - """Return a list of all constant values""" - return [ - getattr(cls, attr) - for attr in dir(cls) - if not attr.startswith("_") and isinstance(getattr(cls, attr), str) - ] - - -class RecordingCase(BaseModel): - """ - Data structure for recording evaluation cases in temporal locomo evaluation. - - This schema represents a single evaluation case containing conversation history, - context information, memory data, and evaluation results. - """ - - # Conversation identification - conv_id: str = Field(description="Conversation identifier for this evaluation case") - - context: str = Field( - default="", - description="Current search context retrieved from memory systems for answering the query", - ) - - pre_context: str | None = Field( - default=None, - description="Previous context from the last query, used for answerability analysis", - ) - - # Query and answer information - query: str = Field(description="The current question/query being evaluated") - - answer: str = Field(description="The generated answer for the query") - - # Evaluation metrics - can_answer: bool | None = Field( - default=None, - description="Whether the context can answer the query (only for memos_scheduler frame)", - ) - - can_answer_reason: str | None = Field( - default=None, description="Reasoning for the can_answer decision" - ) - - # Additional metadata - category: int | None = Field( - default=None, description="Category of the query (1-4, where 5 is filtered out)" - ) - - golden_answer: str | None = Field( - default=None, description="Ground truth answer for evaluation" - ) - - search_duration_ms: float | None = Field( - default=None, description="Time taken for memory search in milliseconds" - ) - - response_duration_ms: float | None = Field( - default=None, description="Time taken for response generation in milliseconds" - ) - - can_answer_duration_ms: float | None = Field( - default=None, description="Time taken for answerability analysis in milliseconds" - ) - - def to_dict(self) -> dict[str, Any]: - """ - Convert the RecordingCase to a dictionary for serialization. - - Returns: - Dict[str, Any]: Dictionary representation of the RecordingCase - """ - return self.dict() - - def to_json(self, indent: int = 2) -> str: - """ - Convert the RecordingCase to a JSON string. - - Args: - indent: JSON indentation level - - Returns: - str: JSON string representation of the RecordingCase - """ - return self.json(indent=indent, ensure_ascii=False) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> "RecordingCase": - """ - Create a RecordingCase from a dictionary. - - Args: - data: Dictionary containing RecordingCase data - - Returns: - RecordingCase: New instance created from the dictionary - """ - return cls(**data) - - @classmethod - def from_json(cls, json_str: str) -> "RecordingCase": - """ - Create a RecordingCase from a JSON string. - - Args: - json_str: JSON string containing RecordingCase data - - Returns: - RecordingCase: New instance created from the JSON string - """ - import json - - data = json.loads(json_str) - return cls.from_dict(data) - - class Config: - """Pydantic configuration""" - - extra = "allow" # Allow additional fields not defined in the schema - validate_assignment = True # Validate on assignment - use_enum_values = True # Use enum values instead of enum names - - -class TimeEvalRecordingCase(BaseModel): - memos_search_duration_ms: float | None = Field( - default=None, description="Time taken for memory search in milliseconds" - ) - - memos_response_duration_ms: float | None = Field( - default=None, description="Time taken for response generation in milliseconds" - ) - - memos_can_answer_duration_ms: float | None = Field( - default=None, description="Time taken for answerability analysis in milliseconds" - ) - - scheduler_search_duration_ms: float | None = Field( - default=None, description="Time taken for memory search in milliseconds" - ) - - scheduler_response_duration_ms: float | None = Field( - default=None, description="Time taken for response generation in milliseconds" - ) - - scheduler_can_answer_duration_ms: float | None = Field( - default=None, description="Time taken for answerability analysis in milliseconds" - ) diff --git a/evaluation/scripts/temporal_locomo/modules/utils.py b/evaluation/scripts/temporal_locomo/modules/utils.py deleted file mode 100644 index 215bc4256..000000000 --- a/evaluation/scripts/temporal_locomo/modules/utils.py +++ /dev/null @@ -1,296 +0,0 @@ -import json - -from pathlib import Path - -from .schemas import RecordingCase - - -def filter_memory_data(memories_data): - filtered_data = {} - for key, value in memories_data.items(): - if key == "text_mem": - filtered_data[key] = [] - for mem_group in value: - # Check if it's the new data structure (list of TextualMemoryItem objects) - if "memories" in mem_group and isinstance(mem_group["memories"], list): - # New data structure: directly a list of TextualMemoryItem objects - filtered_memories = [] - for memory_item in mem_group["memories"]: - # Create filtered dictionary - filtered_item = { - "id": memory_item.id, - "memory": memory_item.memory, - "metadata": {}, - } - # Filter metadata, excluding embedding - if hasattr(memory_item, "metadata") and memory_item.metadata: - for attr_name in dir(memory_item.metadata): - if not attr_name.startswith("_") and attr_name != "embedding": - attr_value = getattr(memory_item.metadata, attr_name) - if not callable(attr_value): - filtered_item["metadata"][attr_name] = attr_value - filtered_memories.append(filtered_item) - - filtered_group = { - "cube_id": mem_group.get("cube_id", ""), - "memories": filtered_memories, - } - filtered_data[key].append(filtered_group) - else: - # Old data structure: dictionary with nodes and edges - filtered_group = { - "memories": {"nodes": [], "edges": mem_group["memories"].get("edges", [])} - } - for node in mem_group["memories"].get("nodes", []): - filtered_node = { - "id": node.get("id"), - "memory": node.get("memory"), - "metadata": { - k: v - for k, v in node.get("metadata", {}).items() - if k != "embedding" - }, - } - filtered_group["memories"]["nodes"].append(filtered_node) - filtered_data[key].append(filtered_group) - else: - filtered_data[key] = value - return filtered_data - - -def save_recording_cases( - cases: list[RecordingCase], output_dir: str | Path, filename: str = "recording_cases.json" -) -> Path: - """ - Save a list of RecordingCase objects to a JSON file. - - Args: - cases: List of RecordingCase objects to save - output_dir: Directory to save the file - filename: Name of the output file (default: "recording_cases.json") - - Returns: - Path: Path to the saved file - """ - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - file_path = output_dir / filename - - # Convert cases to dictionaries for JSON serialization - cases_data = [case.to_dict() for case in cases] - - with open(file_path, "w", encoding="utf-8") as f: - json.dump(cases_data, f, indent=2, ensure_ascii=False) - - return file_path - - -def load_recording_cases(file_path: str | Path) -> list[RecordingCase]: - """ - Load RecordingCase objects from a JSON file. - - Args: - file_path: Path to the JSON file containing RecordingCase data - - Returns: - List[RecordingCase]: List of RecordingCase objects loaded from the file - """ - file_path = Path(file_path) - - with open(file_path, encoding="utf-8") as f: - cases_data = json.load(f) - - return [RecordingCase.from_dict(case_data) for case_data in cases_data] - - -def save_evaluation_cases( - can_answer_cases: list[RecordingCase], - cannot_answer_cases: list[RecordingCase], - output_dir: str | Path, - frame: str = "default", - version: str = "default", -) -> dict[str, Path]: - """ - Save both can_answer_cases and cannot_answer_cases to separate JSON files. - - Args: - can_answer_cases: List of cases that can be answered - cannot_answer_cases: List of cases that cannot be answered - output_dir: Directory to save the files - frame: Framework name for filename prefix - version: Version identifier for filename - - Returns: - Dict[str, Path]: Dictionary mapping case type to saved file path - """ - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - saved_files = {} - - # Save can_answer_cases - if can_answer_cases: - can_answer_filename = f"{frame}_{version}_can_answer_cases.json" - can_answer_path = save_recording_cases(can_answer_cases, output_dir, can_answer_filename) - saved_files["can_answer_cases"] = can_answer_path - print(f"Saved {len(can_answer_cases)} can_answer_cases to {can_answer_path}") - - # Save cannot_answer_cases - if cannot_answer_cases: - cannot_answer_filename = f"{frame}_{version}_cannot_answer_cases.json" - cannot_answer_path = save_recording_cases( - cannot_answer_cases, output_dir, cannot_answer_filename - ) - saved_files["cannot_answer_cases"] = cannot_answer_path - print(f"Saved {len(cannot_answer_cases)} cannot_answer_cases to {cannot_answer_path}") - - return saved_files - - -def compute_can_answer_stats(day_groups, rounds_to_consider=float("inf")): - """ - Compute can-answer statistics for each day using the union of all prior evidences. - - For each day, iterate over the QAs in the given order. If the current QA's - evidences (restricted to the same day) are a subset of the union of all - previously seen evidences for that day, increment can_answer_count. Then add - the current evidences to the seen set. - - Note: - The first QA of each day is excluded from the statistics because it - cannot be answered without any prior evidences. It is still used to - seed the seen evidences for subsequent QAs. - - Args: - day_groups: Dict mapping day_id (e.g., "D1") to a list of QA dicts. Each QA - dict should contain an "evidence" field that is a list of strings. - rounds_to_consider: Number of previous rounds to consider for evidence accumulation. - Default is infinity (all previous rounds). - Set to 1 to only consider the immediately preceding round. - - Returns: - dict: Mapping day_id -> {"can_answer_count": int, "total": int, "ratio": float} - """ - results = {} - for day, qa_list in day_groups.items(): - seen = set() - # Keep track of evidence history for limited rounds - evidence_history = [] - can_answer = 0 - total = max(len(qa_list) - 1, 0) - rounds_count = 0 - for idx, qa in enumerate(qa_list): - cur = set(qa.get("evidence", [])) - rounds_count += 1 - - if idx == 0: - # Seed seen evidences with the first QA but do not count it - evidence_history.append(cur) - seen = set().union(*evidence_history) - continue - - # Check if current evidence is subset of accumulated evidence - if cur and cur.issubset(seen): - can_answer += 1 - - # Add current evidence to history - evidence_history.append(cur) - - # Limit history to specified number of rounds - if rounds_count > rounds_to_consider: - evidence_history.pop(0) - - # Recalculate seen as union of evidence_history - seen = set().union(*evidence_history) - - results[day] = { - "can_answer_count": can_answer, - "total": total, - "ratio": (can_answer / total) if total else 0.0, - } - return results - - -def compute_can_answer_count_by_pre_evidences( - temporal_locomo_data, num_of_users, stats_dir=None, rounds_to_consider=float("inf") -): - """ - Compute can-answer statistics per day for each conversation using the - union of all previously asked evidences within the same day. - - Args: - temporal_locomo_data: The temporal locomo data containing conversations - num_of_users: Number of users/conversations to process - stats_dir: Directory to save statistics (optional) - rounds_to_consider: Number of previous rounds to consider for evidence accumulation. - Default is infinity (all previous rounds). - Set to 1 to only consider the immediately preceding round. - - Returns: - dict: Mapping conversation_id -> per-day stats as produced by compute_can_answer_stats - """ - all_conversations_stats = {} - for conv_idx in range(num_of_users): - temporal_conv = temporal_locomo_data[conv_idx] - conversation_id = temporal_conv["conversation_id"] - - # Build day -> qa_pairs mapping - day_groups = {} - for day_id, day_data in temporal_conv.get("days", {}).items(): - day_groups[day_id] = day_data.get("qa_pairs", []) - - # Use shared utility to compute stats with correct accumulation logic - per_day_stats = compute_can_answer_stats(day_groups, rounds_to_consider) - all_conversations_stats[conversation_id] = per_day_stats - - # Build per-conversation summaries and overall summary - per_conversation_summaries = {} - overall_can = 0 - overall_total = 0 - for conv_id, day_stats in all_conversations_stats.items(): - conv_can = 0 - conv_total = 0 - for _day, stats in day_stats.items(): - conv_can += int(stats.get("can_answer_count", 0)) - conv_total += int(stats.get("total", 0)) - conv_ratio = (conv_can / conv_total) if conv_total else 0.0 - per_conversation_summaries[conv_id] = { - "can_answer_count": conv_can, - "total": conv_total, - "ratio": conv_ratio, - } - overall_can += conv_can - overall_total += conv_total - - overall_summary = { - "can_answer_count": overall_can, - "total": overall_total, - "ratio": (overall_can / overall_total) if overall_total else 0.0, - } - - # Add rounds information to the result - result_payload = { - "per_conversation_summary": per_conversation_summaries, - "overall_summary": overall_summary, - "rounds_considered": rounds_to_consider if rounds_to_consider != float("inf") else "all", - } - - # Print results - print("\nComputed can-answer-by-pre-evidences stats:") - print( - f"Rounds considered: {rounds_to_consider if rounds_to_consider != float('inf') else 'all'}" - ) - print(json.dumps(result_payload, indent=2, ensure_ascii=False)) - - # Save results if stats_dir is provided - if stats_dir: - output_path = ( - stats_dir - / f"evidences_rounds_{rounds_to_consider if rounds_to_consider != float('inf') else 'all'}.json" - ) - with open(output_path, "w", encoding="utf-8") as fw: - json.dump(result_payload, fw, indent=2, ensure_ascii=False) - print(f"Saved stats to {output_path}") - - return result_payload diff --git a/evaluation/scripts/temporal_locomo/scheduler_time_eval.py b/evaluation/scripts/temporal_locomo/scheduler_time_eval.py deleted file mode 100644 index 12d1964cd..000000000 --- a/evaluation/scripts/temporal_locomo/scheduler_time_eval.py +++ /dev/null @@ -1,93 +0,0 @@ -import argparse -import sys - -from pathlib import Path - -from modules.locomo_eval_module import LocomoEvalModelModules -from modules.schemas import ContextUpdateMethod - -from evaluation.scripts.temporal_locomo.models.locomo_ingestion import LocomoIngestor -from evaluation.scripts.temporal_locomo.models.locomo_processor_w_time_eval import ( - LocomoProcessorWithTimeEval, -) -from memos.log import get_logger - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -# TODO: This evaluation has been suspended—it is not finished yet. -class TemporalLocomoForTimeEval(LocomoEvalModelModules): - def __init__(self, args): - args.result_dir_prefix = "time_eval-" - - super().__init__(args=args) - self.num_of_users = 10 - - self.locomo_ingestor = LocomoIngestor(args=args) - self.locomo_processor = LocomoProcessorWithTimeEval(args=args) - - def run_time_eval_pipeline(self, skip_ingestion=True, skip_processing=False): - """ - Run the complete evaluation pipeline including dataset conversion, - data ingestion, and processing. - """ - print("=" * 80) - print("Starting TimeLocomo Evaluation Pipeline") - print("=" * 80) - - # Step 1: Check if temporal_locomo dataset exists, if not convert it - temporal_locomo_file = self.data_dir / "temporal_locomo" / "temporal_locomo_qa.json" - if not temporal_locomo_file.exists(): - print(f"Temporal locomo dataset not found at {temporal_locomo_file}") - print("Converting locomo dataset to temporal_locomo format...") - self.convert_locomo_to_temporal_locomo(output_dir=self.data_dir / "temporal_locomo") - print("Dataset conversion completed.") - else: - print(f"Temporal locomo dataset found at {temporal_locomo_file}, skipping conversion.") - - # Step 2: Data ingestion - if not skip_ingestion: - print("\n" + "=" * 50) - print("Step 2: Data Ingestion") - print("=" * 50) - self.locomo_ingestor.run_ingestion() - - # Step 3: Processing and evaluation - print("\n" + "=" * 50) - print("Step 3: Processing and Evaluation") - print("=" * 50) - print("Running locomo processing to search and answer...") - - print("Starting locomo processing to generate search and response results...") - self.locomo_processor.run_locomo_processing(num_users=self.num_of_users) - print("Processing completed successfully.") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--version", - type=str, - default="v1.0.1", - help="Version identifier for saving results (e.g., 1010)", - ) - parser.add_argument( - "--workers", type=int, default=10, help="Number of parallel workers to process users" - ) - parser.add_argument( - "--top_k", type=int, default=20, help="Number of results to retrieve in search queries" - ) - - args = parser.parse_args() - - args.frame = "memos_scheduler" - args.scheduler_flag = True - args.context_update_method = ContextUpdateMethod.PRE_CONTEXT - - evaluator = TemporalLocomoForTimeEval(args=args) - evaluator.run_time_eval_pipeline() diff --git a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py b/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py deleted file mode 100644 index bb6967e7f..000000000 --- a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py +++ /dev/null @@ -1,155 +0,0 @@ -import argparse -import asyncio -import os -import sys - -from pathlib import Path - -from modules.locomo_eval_module import LocomoEvalModelModules -from modules.schemas import ContextUpdateMethod -from modules.utils import compute_can_answer_count_by_pre_evidences - -from evaluation.scripts.temporal_locomo.models.locomo_eval import LocomoEvaluator -from evaluation.scripts.temporal_locomo.models.locomo_ingestion import LocomoIngestor -from evaluation.scripts.temporal_locomo.models.locomo_metric import LocomoMetric -from evaluation.scripts.temporal_locomo.models.locomo_processor import LocomoProcessor -from memos.log import get_logger - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -class TemporalLocomoEval(LocomoEvalModelModules): - def __init__(self, args): - super().__init__(args=args) - self.num_of_users = 10 - - self.locomo_ingestor = LocomoIngestor(args=args) - self.locomo_processor = LocomoProcessor(args=args) - self.locomo_evaluator = LocomoEvaluator(args=args) - self.locomo_metric = LocomoMetric(args=args) - - def run_answer_hit_eval_pipeline(self, skip_ingestion=True, skip_processing=False): - """ - Run the complete evaluation pipeline including dataset conversion, - data ingestion, and processing. - """ - print("=" * 80) - print("Starting TimeLocomo Evaluation Pipeline") - print("=" * 80) - - # Step 1: Check if temporal_locomo dataset exists, if not convert it - temporal_locomo_file = self.data_dir / "temporal_locomo" / "temporal_locomo_qa.json" - if not temporal_locomo_file.exists(): - print(f"Temporal locomo dataset not found at {temporal_locomo_file}") - print("Converting locomo dataset to temporal_locomo format...") - self.convert_locomo_to_temporal_locomo(output_dir=self.data_dir / "temporal_locomo") - print("Dataset conversion completed.") - else: - print(f"Temporal locomo dataset found at {temporal_locomo_file}, skipping conversion.") - - # Step 2: Data ingestion - if not skip_ingestion: - print("\n" + "=" * 50) - print("Step 2: Data Ingestion") - print("=" * 50) - self.locomo_ingestor.run_ingestion() - - # Step 3: Processing and evaluation - if not skip_processing: - print("\n" + "=" * 50) - print("Step 3: Processing and Evaluation") - print("=" * 50) - print("Running locomo processing to search and answer...") - - print("Starting locomo processing to generate search and response results...") - self.locomo_processor.run_locomo_processing(num_users=self.num_of_users) - print("Processing completed successfully.") - - # Optional: run post-hoc evaluation over generated responses if available - try: - if os.path.exists(self.response_path): - print("Running LocomoEvaluator over existing response results...") - asyncio.run(self.locomo_evaluator.run()) - else: - print( - f"Skipping LocomoEvaluator: response file not found at {evaluator.response_path}" - ) - # Run metrics summarization if judged file is produced - - if os.path.exists(self.judged_path): - print("Running LocomoMetric over judged results...") - self.locomo_metric.run() - else: - print(f"Skipping LocomoMetric: judged file not found at {self.judged_path}") - except Exception as e: - logger.error(f"LocomoEvaluator step skipped due to error: {e}", exc_info=True) - - # Step 4: Summary - print("\n" + "=" * 80) - print("Evaluation Pipeline Completed Successfully!") - print("=" * 80) - print("Results saved to:") - print(f" - Search results: {self.search_path}") - print(f" - Response results: {self.response_path}") - print(f" - Statistics: {self.stats_path}") - print("=" * 80) - - def compute_can_answer_count_by_pre_evidences(self, rounds_to_consider): - """ - Compute can-answer statistics per day for each conversation using the - union of all previously asked evidences within the same day. - - Returns: - dict: Mapping conversation_id -> per-day stats as produced by compute_can_answer_stats - """ - return compute_can_answer_count_by_pre_evidences( - temporal_locomo_data=self.temporal_locomo_data, - num_of_users=self.num_of_users, - stats_dir=self.stats_dir, - rounds_to_consider=rounds_to_consider, - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--frame", - type=str, - default="memos", - choices=["zep", "memos", "mem0", "mem0_graph", "memos_scheduler"], - help="Specify the memory framework (zep or memos or mem0 or mem0_graph)", - ) - parser.add_argument( - "--version", - type=str, - default="v1.0.1", - help="Version identifier for saving results (e.g., 1010)", - ) - parser.add_argument( - "--workers", type=int, default=10, help="Number of parallel workers to process users" - ) - parser.add_argument( - "--top_k", type=int, default=20, help="Number of results to retrieve in search queries" - ) - parser.add_argument( - "--scheduler_flag", - action=argparse.BooleanOptionalAction, - default=False, - help="Enable or disable memory scheduler features", - ) - parser.add_argument( - "--context_update_method", - type=str, - default="chat_history", - choices=ContextUpdateMethod.values(), - help="Method to update context: pre_context (use previous context), chat_history (use template with history), current_context (use current context)", - ) - args = parser.parse_args() - - evaluator = TemporalLocomoEval(args=args) - evaluator.run_answer_hit_eval_pipeline() diff --git a/examples/mem_scheduler/api_w_scheduler.py b/examples/mem_scheduler/api_w_scheduler.py index 11f0ebb81..a2184e9ca 100644 --- a/examples/mem_scheduler/api_w_scheduler.py +++ b/examples/mem_scheduler/api_w_scheduler.py @@ -1,3 +1,7 @@ +from memos.api.handlers.scheduler_handler import ( + handle_scheduler_status, + handle_scheduler_wait, +) from memos.api.routers.server_router import mem_scheduler from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem @@ -9,14 +13,9 @@ print(f"use_redis_queue: {mem_scheduler.use_redis_queue}") print(f"Queue type: {type(mem_scheduler.memos_message_queue).__name__}") print(f"Queue maxsize: {getattr(mem_scheduler.memos_message_queue, 'maxsize', 'N/A')}") - -# Check if Redis queue is connected -if hasattr(mem_scheduler.memos_message_queue, "_is_connected"): - print(f"Redis connected: {mem_scheduler.memos_message_queue._is_connected}") -if hasattr(mem_scheduler.memos_message_queue, "_redis_conn"): - print(f"Redis connection: {mem_scheduler.memos_message_queue._redis_conn}") print("=====================================\n") +mem_scheduler.memos_message_queue.debug_mode_on() queue = mem_scheduler.memos_message_queue queue.clear() @@ -27,7 +26,7 @@ def my_test_handler(messages: list[ScheduleMessageItem]): for msg in messages: print(f" my_test_handler - {msg.item_id}: {msg.content}") print( - f"{queue._redis_conn.xinfo_groups(queue.stream_name)} qsize: {queue.qsize()} messages:{messages}" + f"{queue._redis_conn.xinfo_groups(queue.stream_key_prefix)} qsize: {queue.qsize()} messages:{messages}" ) @@ -35,6 +34,12 @@ def my_test_handler(messages: list[ScheduleMessageItem]): TEST_HANDLER_LABEL = "test_handler" mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler}) +# 2.1 Monitor global scheduler status before submitting tasks +global_status_before = handle_scheduler_status( + user_name=None, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler" +) +print("[Monitor] Global status before submit:", global_status_before) + # 3. Create messages messages_to_send = [ ScheduleMessageItem( @@ -50,12 +55,33 @@ def my_test_handler(messages: list[ScheduleMessageItem]): # 5. Submit messages for mes in messages_to_send: print(f"Submitting message {mes.item_id} to the scheduler...") - mem_scheduler.submit_messages([mes]) + mem_scheduler.memos_message_queue.submit_messages([mes]) + +# 5.1 Monitor status for specific mem_cube while running +USER_MEM_CUBE = "test_mem_cube" +user_status_running = handle_scheduler_status( + user_name=USER_MEM_CUBE, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler" +) +print(f"[Monitor] Status for {USER_MEM_CUBE} after submit:", user_status_running) # 6. Wait for messages to be processed (limited to 100 checks) print("Waiting for messages to be consumed (max 100 checks)...") mem_scheduler.mem_scheduler_wait() +# 6.1 Wait until idle for specific mem_cube via handler +wait_result = handle_scheduler_wait( + user_name=USER_MEM_CUBE, + timeout_seconds=120.0, + poll_interval=0.2, + mem_scheduler=mem_scheduler, +) +print(f"[Monitor] Wait result for {USER_MEM_CUBE}:", wait_result) + +# 6.2 Monitor global scheduler status after processing +global_status_after = handle_scheduler_status( + user_name=None, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler" +) +print("[Monitor] Global status after processing:", global_status_after) # 7. Stop the scheduler print("Stopping the scheduler...") diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 48db7ae6e..ee481d028 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -202,7 +202,7 @@ def _process_pref_mem( content=json.dumps(messages_list), timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item_pref]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_pref]) self.logger.info("Submitted preference add to scheduler (async mode)") except Exception as e: self.logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True) @@ -275,7 +275,7 @@ def _schedule_memory_tasks( timestamp=datetime.utcnow(), user_name=add_req.mem_cube_id, ) - self.mem_scheduler.submit_messages(messages=[message_item_read]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_read]) self.logger.info(f"Submitted async memory read task: {json.dumps(mem_ids)}") except Exception as e: self.logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True) @@ -291,4 +291,4 @@ def _schedule_memory_tasks( timestamp=datetime.utcnow(), user_name=add_req.mem_cube_id, ) - self.mem_scheduler.submit_messages(messages=[message_item_add]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_add]) diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py index 86a00dc37..a174defb1 100644 --- a/src/memos/api/handlers/base_handler.py +++ b/src/memos/api/handlers/base_handler.py @@ -8,6 +8,7 @@ from typing import Any from memos.log import get_logger +from memos.mem_scheduler.base_scheduler import BaseScheduler logger = get_logger(__name__) @@ -123,7 +124,7 @@ def mem_reader(self): return self.deps.mem_reader @property - def mem_scheduler(self): + def mem_scheduler(self) -> BaseScheduler: """Get scheduler instance.""" return self.deps.mem_scheduler diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index f6023e5c8..8540a67ec 100644 --- a/src/memos/api/handlers/chat_handler.py +++ b/src/memos/api/handlers/chat_handler.py @@ -213,7 +213,7 @@ def generate_chat_response() -> Generator[str, None, None]: query=chat_req.query, top_k=20, session_id=chat_req.session_id, - mode=SearchMode.FINE if chat_req.internet_search else SearchMode.FAST, + mode=SearchMode.FAST, internet_search=chat_req.internet_search, # TODO this param is not worked at fine mode moscube=chat_req.moscube, chat_history=chat_req.history, @@ -603,7 +603,7 @@ def _send_message_to_scheduler( content=query, timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) self.logger.info(f"Sent message to scheduler with label: {label}") except Exception as e: self.logger.error(f"Failed to send message to scheduler: {e}", exc_info=True) diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index e8e4e07d6..cf2ab73bb 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -17,10 +17,14 @@ ) from memos.api.product_models import APISearchRequest, SearchResponse from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger from memos.mem_scheduler.schemas.general_schemas import SearchMode from memos.types import MOSSearchResult, UserContext +logger = get_logger(__name__) + + class SearchHandler(BaseHandler): """ Handler for memory search operations. @@ -101,17 +105,6 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse ) def _get_search_mode(self, mode: str) -> str: - """ - Get search mode with environment variable fallback. - - Args: - mode: Requested search mode - - Returns: - Search mode string - """ - if mode == SearchMode.NOT_INITIALIZED: - return os.getenv("SEARCH_MODE", SearchMode.FAST) return mode def _search_text( @@ -133,16 +126,16 @@ def _search_text( """ try: if search_mode == SearchMode.FAST: - memories = self._fast_search(search_req, user_context) + text_memories = self._fast_search(search_req, user_context) elif search_mode == SearchMode.FINE: - memories = self._fine_search(search_req, user_context) + text_memories = self._fine_search(search_req, user_context) elif search_mode == SearchMode.MIXTURE: - memories = self._mix_search(search_req, user_context) + text_memories = self._mix_search(search_req, user_context) else: self.logger.error(f"Unsupported search mode: {search_mode}") return [] - return [format_memory_item(data) for data in memories] + return text_memories except Exception as e: self.logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc()) @@ -199,7 +192,7 @@ def _fast_search( target_session_id = search_req.session_id or "default_session" search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - return self.naive_mem_cube.text_mem.search( + search_results = self.naive_mem_cube.text_mem.search( query=search_req.query, user_name=user_context.mem_cube_id, top_k=search_req.top_k, @@ -214,6 +207,10 @@ def _fast_search( }, ) + formatted_memories = [format_memory_item(data) for data in search_results] + + return formatted_memories + def _fine_search( self, search_req: APISearchRequest, @@ -240,7 +237,7 @@ def _fine_search( "chat_history": search_req.chat_history, } - # Fast retrieve + # Fine retrieve fast_retrieved_memories = searcher.retrieve( query=search_req.query, user_name=user_context.mem_cube_id, @@ -261,12 +258,45 @@ def _fine_search( ) # Enhance with query - enhanced_results, _ = self.mem_scheduler.retriever.enhance_memories_with_query( + enhanced_memories, _ = self.mem_scheduler.retriever.enhance_memories_with_query( query_history=[search_req.query], memories=fast_memories, ) - return enhanced_results + if len(enhanced_memories) < len(fast_memories): + logger.info( + f"Enhanced memories ({len(enhanced_memories)}) are less than fast memories ({len(fast_memories)}). Recalling for more." + ) + missing_info_hint, trigger = self.mem_scheduler.retriever.recall_for_missing_memories( + query=search_req.query, + memories=fast_memories, + ) + retrieval_size = len(fast_memories) - len(enhanced_memories) + logger.info(f"Retrieval size: {retrieval_size}") + if trigger: + logger.info(f"Triggering additional search with hint: {missing_info_hint}") + additional_memories = searcher.search( + query=missing_info_hint, + user_name=user_context.mem_cube_id, + top_k=retrieval_size, + mode=SearchMode.FAST, + memory_type="All", + search_filter=search_filter, + info=info, + ) + else: + logger.info("Not triggering additional search, using fast memories.") + additional_memories = fast_memories[:retrieval_size] + + enhanced_memories += additional_memories + logger.info( + f"Added {len(additional_memories)} more memories. Total enhanced memories: {len(enhanced_memories)}" + ) + formatted_memories = [format_memory_item(data) for data in enhanced_memories] + + logger.info(f"Found {len(formatted_memories)} memories for user {search_req.user_id}") + + return formatted_memories def _mix_search( self, diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 892d2d436..30df150ea 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -1,3 +1,4 @@ +import os import uuid from typing import Generic, Literal, TypeVar @@ -172,7 +173,7 @@ class APISearchRequest(BaseRequest): user_id: str = Field(None, description="User ID") mem_cube_id: str | None = Field(None, description="Cube ID to search in") mode: SearchMode = Field( - SearchMode.NOT_INITIALIZED, description="search mode: fast, fine, or mixture" + os.getenv("SEARCH_MODE", SearchMode.FAST), description="search mode: fast, fine, or mixture" ) internet_search: bool = Field(False, description="Whether to use internet search") moscube: bool = Field(False, description="Whether to use MemOSCube") diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index d43f9ccdc..b3b517305 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -34,6 +34,7 @@ SuggestionResponse, ) from memos.log import get_logger +from memos.mem_scheduler.base_scheduler import BaseScheduler logger = get_logger(__name__) @@ -58,7 +59,7 @@ # Extract commonly used components for function-based handlers # (These can be accessed from the components dict without unpacking all of them) -mem_scheduler = components["mem_scheduler"] +mem_scheduler: BaseScheduler = components["mem_scheduler"] llm = components["llm"] naive_mem_cube = components["naive_mem_cube"] diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 3b53cef1a..f11b3a44c 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -287,7 +287,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = content=query, timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) memories = mem_cube.text_mem.search( query, @@ -347,7 +347,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = content=response, timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) return response @@ -774,7 +774,9 @@ def process_textual_memory(): content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages( + messages=[message_item] + ) else: message_item = ScheduleMessageItem( user_id=target_user_id, @@ -783,7 +785,9 @@ def process_textual_memory(): content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages( + messages=[message_item] + ) def process_preference_memory(): if ( @@ -818,7 +822,7 @@ def process_preference_memory(): content=json.dumps(messages_list), timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) # Execute both memory processing functions in parallel with ContextThreadPoolExecutor(max_workers=2) as executor: @@ -872,7 +876,9 @@ def process_preference_memory(): content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages( + messages=[message_item] + ) else: message_item = ScheduleMessageItem( user_id=target_user_id, @@ -881,7 +887,9 @@ def process_preference_memory(): content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages( + messages=[message_item] + ) # user doc input if ( @@ -910,7 +918,7 @@ def process_preference_memory(): content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) logger.info(f"Add memory to {mem_cube_id} successfully") diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py index 0114fc0da..11c112d52 100644 --- a/src/memos/mem_os/main.py +++ b/src/memos/mem_os/main.py @@ -220,7 +220,7 @@ def _chat_with_cot_enhancement( content=enhanced_response, timestamp=datetime.now().isoformat(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) return enhanced_response diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index 359db72ba..9a4ab3f4d 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -641,7 +641,7 @@ def _send_message_to_scheduler( content=query, timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) async def _post_chat_processing( self, diff --git a/src/memos/mem_scheduler/analyzer/eval_analyzer.py b/src/memos/mem_scheduler/analyzer/eval_analyzer.py index d37e17456..cf0b8f1dd 100644 --- a/src/memos/mem_scheduler/analyzer/eval_analyzer.py +++ b/src/memos/mem_scheduler/analyzer/eval_analyzer.py @@ -1244,7 +1244,7 @@ def analyze_bad_cases_with_llm_processing( return results -def main(): +def main(version_name="ct-1111"): """Main test function.""" print("=== EvalAnalyzer Simple Test ===") @@ -1254,7 +1254,7 @@ def main(): print("Analyzer initialized") # Test file paths - eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-xcy-1030-2114-locomo" + eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-{version_name}-locomo" judged_file = os.path.join(eval_result_dir, "memos-api_locomo_judged.json") search_results_file = os.path.join(eval_result_dir, "memos-api_locomo_search_results.json") diff --git a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py index 03e1fc778..df504ee75 100644 --- a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +++ b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py @@ -521,7 +521,7 @@ def chat(self, query: str, user_id: str | None = None) -> str: content=response, timestamp=datetime.now(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) return response diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index eb49d0238..657ceea0f 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -1,5 +1,5 @@ -import contextlib import multiprocessing +import os import threading import time @@ -14,10 +14,9 @@ from memos.context.context import ContextThread from memos.llms.base import BaseLLM from memos.log import get_logger +from memos.mem_cube.base import BaseMemCube from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue -from memos.mem_scheduler.general_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor @@ -43,6 +42,9 @@ ScheduleMessageItem, ) from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem +from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher +from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue +from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.filter_utils import ( transform_name_to_key, @@ -56,7 +58,8 @@ if TYPE_CHECKING: - from memos.mem_cube.base import BaseMemCube + from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher + from memos.reranker.http_bge import HTTPBGEReranker logger = get_logger(__name__) @@ -88,22 +91,34 @@ def __init__(self, config: BaseSchedulerConfig): "scheduler_startup_mode", DEFAULT_STARTUP_MODE ) + # optional configs + self.disabled_handlers: list | None = self.config.get("disabled_handlers", None) + + self.max_web_log_queue_size = self.config.get( + "max_web_log_queue_size", DEFAULT_MAX_WEB_LOG_QUEUE_SIZE + ) + self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue( + maxsize=self.max_web_log_queue_size + ) + self._consumer_thread = None # Reference to our consumer thread/process + self._consumer_process = None # Reference to our consumer process + self._running = False + self._consume_interval = self.config.get( + "consume_interval_seconds", DEFAULT_CONSUME_INTERVAL_SECONDS + ) + self.consume_batch = self.config.get("consume_batch", DEFAULT_CONSUME_BATCH) + # message queue configuration self.use_redis_queue = self.config.get("use_redis_queue", DEFAULT_USE_REDIS_QUEUE) self.max_internal_message_queue_size = self.config.get( "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE ) - - # Initialize message queue based on configuration - if self.use_redis_queue: - self.memos_message_queue = SchedulerRedisQueue( - maxsize=self.max_internal_message_queue_size - ) - else: - self.memos_message_queue: Queue[ScheduleMessageItem] = Queue( - maxsize=self.max_internal_message_queue_size - ) - + self.memos_message_queue = ScheduleTaskQueue( + use_redis_queue=self.use_redis_queue, + maxsize=self.max_internal_message_queue_size, + disabled_handlers=self.disabled_handlers, + ) + self.searcher: Searcher | None = None self.retriever: SchedulerRetriever | None = None self.db_engine: Engine | None = None self.monitor: SchedulerGeneralMonitor | None = None @@ -117,23 +132,6 @@ def __init__(self, config: BaseSchedulerConfig): enable_parallel_dispatch=self.enable_parallel_dispatch, ) - # optional configs - self.disable_handlers: list | None = self.config.get("disable_handlers", None) - - self.max_web_log_queue_size = self.config.get( - "max_web_log_queue_size", DEFAULT_MAX_WEB_LOG_QUEUE_SIZE - ) - self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue( - maxsize=self.max_web_log_queue_size - ) - self._consumer_thread = None # Reference to our consumer thread/process - self._consumer_process = None # Reference to our consumer process - self._running = False - self._consume_interval = self.config.get( - "consume_interval_seconds", DEFAULT_CONSUME_INTERVAL_SECONDS - ) - self.consume_batch = self.config.get("consume_batch", DEFAULT_CONSUME_BATCH) - # other attributes self._context_lock = threading.Lock() self.current_user_id: UserID | str | None = None @@ -143,6 +141,15 @@ def __init__(self, config: BaseSchedulerConfig): self.auth_config = None self.rabbitmq_config = None + def init_mem_cube(self, mem_cube): + self.mem_cube = mem_cube + self.text_mem: TreeTextMemory = self.mem_cube.text_mem + self.searcher: Searcher = self.text_mem.get_searcher( + manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", + moscube=False, + ) + self.reranker: HTTPBGEReranker = self.text_mem.reranker + def initialize_modules( self, chat_llm: BaseLLM, @@ -199,6 +206,9 @@ def initialize_modules( # start queue monitor if enabled and a bot is set later + def debug_mode_on(self): + self.memos_message_queue.debug_mode_on() + def _cleanup_on_init_failure(self): """Clean up resources if initialization fails.""" try: @@ -208,23 +218,16 @@ def _cleanup_on_init_failure(self): logger.warning(f"Error during cleanup: {e}") @property - def mem_cube(self) -> GeneralMemCube: + def mem_cube(self) -> BaseMemCube: """The memory cube associated with this MemChat.""" return self.current_mem_cube @mem_cube.setter - def mem_cube(self, value: GeneralMemCube) -> None: + def mem_cube(self, value: BaseMemCube) -> None: """The memory cube associated with this MemChat.""" self.current_mem_cube = value self.retriever.mem_cube = value - def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None: - """Update current user/cube context from the incoming message (thread-safe).""" - with self._context_lock: - self.current_user_id = msg.user_id - self.current_mem_cube_id = msg.mem_cube_id - self.current_mem_cube = self.get_mem_cube(msg.mem_cube_id) - def transform_working_memories_to_monitors( self, query_keywords, memories: list[TextualMemoryItem] ) -> list[MemoryMonitorItem]: @@ -523,29 +526,7 @@ def update_activation_memory_periodically( logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True) def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): - """Submit messages to the message queue (either local queue or Redis).""" - if isinstance(messages, ScheduleMessageItem): - messages = [messages] # transform single message to list - - for message in messages: - if not isinstance(message, ScheduleMessageItem): - error_msg = f"Invalid message type: {type(message)}, expected ScheduleMessageItem" - logger.error(error_msg) - raise TypeError(error_msg) - - if getattr(message, "timestamp", None) is None: - with contextlib.suppress(Exception): - message.timestamp = datetime.utcnow() - - if self.disable_handlers and message.label in self.disable_handlers: - logger.info(f"Skipping disabled handler: {message.label} - {message.content}") - continue - self.memos_message_queue.put(message) - logger.info(f"Submitted message to local queue: {message.label} - {message.content}") - - with contextlib.suppress(Exception): - if messages: - self.dispatcher.on_messages_enqueued(messages) + self.memos_message_queue.submit_messages(messages=messages) def _submit_web_logs( self, messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem] @@ -606,10 +587,16 @@ def _message_consumer(self) -> None: try: # Get messages in batches based on consume_batch setting - messages = self.memos_message_queue.get(block=True, batch_size=self.consume_batch) + messages = self.memos_message_queue.get_messages(batch_size=self.consume_batch) if messages: try: + import contextlib + + with contextlib.suppress(Exception): + if messages: + self.dispatcher.on_messages_enqueued(messages) + self.dispatcher.dispatch(messages) except Exception as e: logger.error(f"Error dispatching messages: {e!s}") @@ -878,7 +865,7 @@ def _fmt_eta(seconds: float | None) -> str: if isinstance(self.memos_message_queue, SchedulerRedisQueue): # For Redis queue, prefer XINFO GROUPS to compute pending groups_info = self.memos_message_queue.redis.xinfo_groups( - self.memos_message_queue.stream_name + self.memos_message_queue.stream_key_prefix ) if groups_info: for group in groups_info: diff --git a/src/memos/mem_scheduler/general_modules/base.py b/src/memos/mem_scheduler/general_modules/base.py index 0b80b9e7d..e0ee65ba0 100644 --- a/src/memos/mem_scheduler/general_modules/base.py +++ b/src/memos/mem_scheduler/general_modules/base.py @@ -18,8 +18,6 @@ def __init__(self): self._chat_llm = None self._process_llm = None - self.mem_cubes: dict[str, GeneralMemCube] = {} - def load_template(self, template_name: str) -> str: if template_name not in PROMPT_MAPPING: logger.error("Prompt template is not found!") diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py index 1f89d3b02..d35a4f106 100644 --- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py +++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py @@ -1,7 +1,7 @@ from collections.abc import Callable from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube +from memos.mem_cube.base import BaseMemCube from memos.mem_scheduler.general_modules.base import BaseSchedulerModule from memos.mem_scheduler.schemas.general_schemas import ( ACTIVATION_MEMORY_TYPE, @@ -44,7 +44,7 @@ def create_autofilled_log_item( to_memory_type: str, user_id: str, mem_cube_id: str, - mem_cube: GeneralMemCube, + mem_cube: BaseMemCube, ) -> ScheduleLogForWebItem: text_mem_base: TreeTextMemory = mem_cube.text_mem current_memory_sizes = text_mem_base.get_current_memory_size() @@ -106,7 +106,7 @@ def log_working_memory_replacement( new_memory: list[TextualMemoryItem], user_id: str, mem_cube_id: str, - mem_cube: GeneralMemCube, + mem_cube: BaseMemCube, log_func_callback: Callable[[list[ScheduleLogForWebItem]], None], ): """Log changes when working memory is replaced.""" @@ -163,7 +163,7 @@ def log_activation_memory_update( label: str, user_id: str, mem_cube_id: str, - mem_cube: GeneralMemCube, + mem_cube: BaseMemCube, log_func_callback: Callable[[list[ScheduleLogForWebItem]], None], ): """Log changes when activation memory is updated.""" @@ -214,7 +214,7 @@ def log_adding_memory( memory_type: str, user_id: str, mem_cube_id: str, - mem_cube: GeneralMemCube, + mem_cube: BaseMemCube, log_func_callback: Callable[[list[ScheduleLogForWebItem]], None], ): """Log changes when working memory is replaced.""" diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 6e916962e..92e317881 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -5,6 +5,7 @@ from memos.configs.mem_scheduler import GeneralSchedulerConfig from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger +from memos.mem_cube.base import BaseMemCube from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.base_scheduler import BaseScheduler from memos.mem_scheduler.schemas.general_schemas import ( @@ -22,6 +23,7 @@ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem from memos.mem_scheduler.utils.filter_utils import is_all_chinese, is_all_english +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube from memos.memories.textual.item import TextualMemoryItem from memos.memories.textual.preference import PreferenceTextMemory from memos.memories.textual.tree import TreeTextMemory @@ -137,7 +139,7 @@ def long_memory_update_process( label=QUERY_LABEL, user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=self.current_mem_cube, + mem_cube=self.mem_cube, ) def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: @@ -150,7 +152,7 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: logger.info(f"Messages {messages} assigned to {QUERY_LABEL} handler.") # Process the query in a session turn - grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) + grouped_messages = group_messages_by_user_and_mem_cube(messages) self.validate_schedule_messages(messages=messages, label=QUERY_LABEL) @@ -172,7 +174,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: """ logger.info(f"Messages {messages} assigned to {ANSWER_LABEL} handler.") # Process the query in a session turn - grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) + grouped_messages = group_messages_by_user_and_mem_cube(messages) self.validate_schedule_messages(messages=messages, label=ANSWER_LABEL) @@ -185,7 +187,8 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: logger.info(f"Messages {messages} assigned to {ADD_LABEL} handler.") # Process the query in a session turn - grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) + grouped_messages = group_messages_by_user_and_mem_cube(messages) + mem_cube = self.mem_cube self.validate_schedule_messages(messages=messages, label=ADD_LABEL) try: @@ -203,7 +206,6 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: logger.error(f"Error: {e}. Content: {msg.content}", exc_info=True) userinput_memory_ids = [] - mem_cube = self.current_mem_cube for memory_id in userinput_memory_ids: try: mem_item: TextualMemoryItem = mem_cube.text_mem.get( @@ -225,7 +227,7 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: memory_type=mem_type, user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, - mem_cube=self.current_mem_cube, + mem_cube=self.mem_cube, log_func_callback=self._submit_web_logs, ) @@ -239,7 +241,7 @@ def process_message(message: ScheduleMessageItem): try: user_id = message.user_id mem_cube_id = message.mem_cube_id - mem_cube = self.current_mem_cube + mem_cube = self.mem_cube content = message.content user_name = message.user_name @@ -263,7 +265,6 @@ def process_message(message: ScheduleMessageItem): mem_ids=mem_ids, user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, text_mem=text_mem, user_name=user_name, ) @@ -288,7 +289,6 @@ def _process_memories_with_reader( mem_ids: list[str], user_id: str, mem_cube_id: str, - mem_cube: GeneralMemCube, text_mem: TreeTextMemory, user_name: str, ) -> None: @@ -299,7 +299,6 @@ def _process_memories_with_reader( mem_ids: List of memory IDs to process user_id: User ID mem_cube_id: Memory cube ID - mem_cube: Memory cube instance text_mem: Text memory instance """ try: @@ -403,7 +402,7 @@ def process_message(message: ScheduleMessageItem): try: user_id = message.user_id mem_cube_id = message.mem_cube_id - mem_cube = self.current_mem_cube + mem_cube = self.mem_cube content = message.content user_name = message.user_name @@ -452,7 +451,7 @@ def _process_memories_with_reorganize( mem_ids: list[str], user_id: str, mem_cube_id: str, - mem_cube: GeneralMemCube, + mem_cube: BaseMemCube, text_mem: TreeTextMemory, user_name: str, ) -> None: @@ -504,10 +503,11 @@ def _pref_add_message_consumer(self, messages: list[ScheduleMessageItem]) -> Non def process_message(message: ScheduleMessageItem): try: + mem_cube = self.mem_cube + user_id = message.user_id session_id = message.session_id mem_cube_id = message.mem_cube_id - mem_cube = self.current_mem_cube content = message.content messages_list = json.loads(content) diff --git a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py index 848b1d257..01b57563d 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py +++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py @@ -1,3 +1,5 @@ +import time + from concurrent.futures import as_completed from memos.configs.mem_scheduler import BaseSchedulerConfig @@ -9,6 +11,8 @@ from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE, DEFAULT_SCHEDULER_RETRIEVER_RETRIES, + FINE_STRATEGY, + FineStrategy, TreeTextMemory_FINE_SEARCH_METHOD, TreeTextMemory_SEARCH_METHOD, ) @@ -91,9 +95,15 @@ def _build_enhancement_prompt(self, query_history: list[str], batch_texts: list[ if len(query_history) > 1 else query_history[0] ) - text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(batch_texts)]) + # Include numbering for rewrite mode to help LLM reference original memory IDs + if FINE_STRATEGY == FineStrategy.REWRITE: + text_memories = "\n".join([f"- [{i}] {mem}" for i, mem in enumerate(batch_texts)]) + prompt_name = "memory_rewrite_enhancement" + else: + text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(batch_texts)]) + prompt_name = "memory_recreate_enhancement" return self.build_prompt( - "memory_enhancement", + prompt_name, query_history=query_history, memories=text_memories, ) @@ -107,53 +117,80 @@ def _process_enhancement_batch( ) -> tuple[list[TextualMemoryItem], bool]: attempt = 0 text_memories = [one.memory for one in memories] - while attempt <= max(0, retries) + 1: - try: - prompt = self._build_enhancement_prompt( - query_history=query_history, batch_texts=text_memories - ) - logger.debug( - f"[Enhance][batch={batch_index}] Prompt (first 200 chars, len={len(prompt)}): " - f"{prompt[:200]}]..." - ) - response = self.process_llm.generate([{"role": "user", "content": prompt}]) - logger.debug( - f"[Enhance][batch={batch_index}] Response (first 200 chars): {response}..." - ) + prompt = self._build_enhancement_prompt( + query_history=query_history, batch_texts=text_memories + ) - processed_text_memories = extract_list_items_in_answer(response) - if len(processed_text_memories) == len(memories): - # Update - for i, new_mem in enumerate(processed_text_memories): - memories[i].memory = new_mem - enhanced_memories = memories - else: + llm_response = None + while attempt <= max(0, retries) + 1: + try: + llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) + processed_text_memories = extract_list_items_in_answer(llm_response) + if len(processed_text_memories) > 0: # create new enhanced_memories = [] user_id = memories[0].metadata.user_id - for new_mem in processed_text_memories: - enhanced_memories.append( - TextualMemoryItem( - memory=new_mem, metadata=TextualMemoryMetadata(user_id=user_id) + if FINE_STRATEGY == FineStrategy.RECREATE: + for new_mem in processed_text_memories: + enhanced_memories.append( + TextualMemoryItem( + memory=new_mem, metadata=TextualMemoryMetadata(user_id=user_id) + ) ) - ) - enhanced_memories = ( - enhanced_memories + memories[: len(memories) - len(enhanced_memories)] - ) + elif FINE_STRATEGY == FineStrategy.REWRITE: + # Parse index from each processed line and rewrite corresponding original memory + def _parse_index_and_text(s: str) -> tuple[int | None, str]: + import re + + s = (s or "").strip() + # Preferred: [index] text + m = re.match(r"^\s*\[(\d+)\]\s*(.+)$", s) + if m: + return int(m.group(1)), m.group(2).strip() + # Fallback: index: text or index - text + m = re.match(r"^\s*(\d+)\s*[:\-\)]\s*(.+)$", s) + if m: + return int(m.group(1)), m.group(2).strip() + return None, s + + idx_to_original = dict(enumerate(memories)) + for j, item in enumerate(processed_text_memories): + idx, new_text = _parse_index_and_text(item) + if idx is not None and idx in idx_to_original: + orig = idx_to_original[idx] + else: + # Fallback: align by order if index missing/invalid + orig = memories[j] if j < len(memories) else None + if not orig: + continue + enhanced_memories.append( + TextualMemoryItem( + id=orig.id, + memory=new_text, + metadata=orig.metadata, + ) + ) + else: + logger.error(f"Fine search strategy {FINE_STRATEGY} not exists") logger.info( - f"[Enhance]: processed_text_memories: {len(processed_text_memories)}; padded with original memories to preserve total count" + f"[enhance_memories_with_query] ✅ done | Strategy={FINE_STRATEGY} | prompt={prompt} | llm_response={llm_response}" + ) + return enhanced_memories, True + else: + raise ValueError( + f"Fail to run memory enhancement; retry {attempt}/{max(1, retries) + 1}; processed_text_memories: {processed_text_memories}" ) - - return enhanced_memories, True except Exception as e: attempt += 1 + time.sleep(1) logger.debug( - f"[Enhance][batch={batch_index}] 🔁 retry {attempt}/{max(1, retries) + 1} failed: {e}" + f"[enhance_memories_with_query][batch={batch_index}] 🔁 retry {attempt}/{max(1, retries) + 1} failed: {e}" ) logger.error( - f"Fail to run memory enhancement; original memories: {memories}", exc_info=True + f"Fail to run memory enhancement; prompt: {prompt};\n llm_response: {llm_response}", + exc_info=True, ) return memories, False @@ -170,6 +207,76 @@ def _split_batches( start = end return batches + def recall_for_missing_memories( + self, + query: str, + memories: list[TextualMemoryItem], + ) -> tuple[str, bool]: + text_memories = [one.memory for one in memories] if memories else [] + text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(text_memories)]) + + prompt = self.build_prompt( + template_name="enlarge_recall", + query=query, + memories_inline=text_memories, + ) + llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) + + json_result: dict = extract_json_obj(llm_response) + + logger.info( + f"[recall_for_missing_memories] ✅ done | prompt={prompt} | llm_response={llm_response}" + ) + + hint = json_result.get("hint", "") + if len(hint) == 0: + return hint, False + return hint, json_result.get("trigger_recall", False) + + def search( + self, + query: str, + mem_cube: GeneralMemCube, + top_k: int, + method: str = TreeTextMemory_SEARCH_METHOD, + info: dict | None = None, + ) -> list[TextualMemoryItem]: + """Search in text memory with the given query. + + Args: + query: The search query string + top_k: Number of top results to return + method: Search method to use + + Returns: + Search results or None if not implemented + """ + text_mem_base = mem_cube.text_mem + try: + if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]: + assert isinstance(text_mem_base, TreeTextMemory) + if info is None: + logger.warning( + "Please input 'info' when use tree.search so that " + "the database would store the consume history." + ) + info = {"user_id": "", "session_id": ""} + + mode = "fast" if method == TreeTextMemory_SEARCH_METHOD else "fine" + results_long_term = text_mem_base.search( + query=query, top_k=top_k, memory_type="LongTermMemory", mode=mode, info=info + ) + results_user = text_mem_base.search( + query=query, top_k=top_k, memory_type="UserMemory", mode=mode, info=info + ) + results = results_long_term + results_user + else: + raise NotImplementedError(str(type(text_mem_base))) + except Exception as e: + logger.error(f"Fail to search. The exeption is {e}.", exc_info=True) + results = [] + return results + def enhance_memories_with_query( self, query_history: list[str], @@ -239,54 +346,10 @@ def enhance_memories_with_query( enhanced_memories = memories if len(enhanced_memories) == 0: - enhanced_memories = memories + enhanced_memories = [] logger.error("[Enhance] ❌ fatal error: enhanced_memories is empty", exc_info=True) return enhanced_memories, all_success - def search( - self, - query: str, - mem_cube: GeneralMemCube, - top_k: int, - method: str = TreeTextMemory_SEARCH_METHOD, - info: dict | None = None, - ) -> list[TextualMemoryItem]: - """Search in text memory with the given query. - - Args: - query: The search query string - top_k: Number of top results to return - method: Search method to use - - Returns: - Search results or None if not implemented - """ - text_mem_base = mem_cube.text_mem - try: - if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]: - assert isinstance(text_mem_base, TreeTextMemory) - if info is None: - logger.warning( - "Please input 'info' when use tree.search so that " - "the database would store the consume history." - ) - info = {"user_id": "", "session_id": ""} - - mode = "fast" if method == TreeTextMemory_SEARCH_METHOD else "fine" - results_long_term = text_mem_base.search( - query=query, top_k=top_k, memory_type="LongTermMemory", mode=mode, info=info - ) - results_user = text_mem_base.search( - query=query, top_k=top_k, memory_type="UserMemory", mode=mode, info=info - ) - results = results_long_term + results_user - else: - raise NotImplementedError(str(type(text_mem_base))) - except Exception as e: - logger.error(f"Fail to search. The exeption is {e}.", exc_info=True) - results = [] - return results - def rerank_memories( self, queries: list[str], original_memories: list[str], top_k: int ) -> (list[str], bool): diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py index 99982d2e6..f8e321a82 100644 --- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py +++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py @@ -7,13 +7,13 @@ from memos.context.context import ContextThread, ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule -from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL, DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES, DEFAULT_STOP_WAIT, DEFAULT_STUCK_THREAD_TOLERANCE, ) +from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.utils.db_utils import get_utc_now diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index b62b1e51d..21b2d63f0 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -20,15 +20,13 @@ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.utils.api_utils import format_textual_memory_item from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory from memos.types import UserContext if TYPE_CHECKING: from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem - from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher - from memos.reranker.http_bge import HTTPBGEReranker - logger = get_logger(__name__) @@ -56,15 +54,6 @@ def __init__(self, config: GeneralSchedulerConfig): self.reranker = None self.text_mem = None - def init_mem_cube(self, mem_cube): - self.current_mem_cube = mem_cube - self.text_mem: TreeTextMemory = self.current_mem_cube.text_mem - self.searcher: Searcher = self.text_mem.get_searcher( - manual_close_internet=False, - moscube=False, - ) - self.reranker: HTTPBGEReranker = self.text_mem.reranker - def submit_memory_history_async_task( self, search_req: APISearchRequest, @@ -99,7 +88,7 @@ def submit_memory_history_async_task( ) # Submit async task - self.submit_messages([message]) + self.memos_message_queue.submit_messages([message]) logger.info(f"Submitted async fine search task for user {search_req.user_id}") return async_task_id @@ -141,6 +130,9 @@ def mix_search_memories( """ Mix search memories: fast search + async fine search """ + logger.info( + f"Mix searching memories for user {search_req.user_id} with query: {search_req.query}" + ) # Get mem_cube for fast search target_session_id = search_req.session_id @@ -173,17 +165,14 @@ def mix_search_memories( mem_cube_id=user_context.mem_cube_id, turns=self.history_memory_turns, ) - + logger.info(f"Found {len(history_memories)} history memories.") if not history_memories: - fast_memories = self.searcher.post_retrieve( + memories = self.searcher.post_retrieve( retrieved_results=fast_retrieved_memories, top_k=search_req.top_k, user_name=user_context.mem_cube_id, info=info, ) - # Format fast memories for return - formatted_memories = [format_textual_memory_item(data) for data in fast_memories] - return formatted_memories else: # if history memories can directly answer sorted_history_memories = self.reranker.rerank( @@ -192,7 +181,7 @@ def mix_search_memories( top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k search_filter=search_filter, ) - + logger.info(f"Reranked {len(sorted_history_memories)} history memories.") processed_hist_mem = self.searcher.post_retrieve( retrieved_results=sorted_history_memories, top_k=search_req.top_k, @@ -205,6 +194,7 @@ def mix_search_memories( ) if can_answer: + logger.info("History memories can answer the query.") sorted_results = fast_retrieved_memories + sorted_history_memories combined_results = self.searcher.post_retrieve( retrieved_results=sorted_results, @@ -213,9 +203,8 @@ def mix_search_memories( info=info, ) memories = combined_results[: search_req.top_k] - formatted_memories = [format_textual_memory_item(item) for item in memories] - logger.info("can_answer") else: + logger.info("History memories cannot answer the query, enhancing memories.") sorted_results = fast_retrieved_memories + sorted_history_memories combined_results = self.searcher.post_retrieve( retrieved_results=sorted_results, @@ -223,24 +212,53 @@ def mix_search_memories( user_name=user_context.mem_cube_id, info=info, ) - enhanced_results, _ = self.retriever.enhance_memories_with_query( + enhanced_memories, _ = self.retriever.enhance_memories_with_query( query_history=[search_req.query], memories=combined_results, ) - memories = enhanced_results[: search_req.top_k] - formatted_memories = [format_textual_memory_item(item) for item in memories] - logger.info("cannot answer") - - self.submit_memory_history_async_task( - search_req=search_req, - user_context=user_context, - memories_to_store={ - "memories": [one.to_dict() for one in memories], - "formatted_memories": formatted_memories, - }, - ) - return formatted_memories + if len(enhanced_memories) < search_req.top_k: + logger.info( + f"Enhanced memories ({len(enhanced_memories)}) are less than top_k ({search_req.top_k}). Recalling for more." + ) + missing_info_hint, trigger = self.retriever.recall_for_missing_memories( + query=search_req.query, + memories=combined_results, + ) + retrieval_size = search_req.top_k - len(enhanced_memories) + if trigger: + logger.info(f"Triggering additional search with hint: {missing_info_hint}") + additional_memories = self.searcher.search( + query=missing_info_hint, + user_name=user_context.mem_cube_id, + top_k=retrieval_size, + mode=SearchMode.FAST, + memory_type="All", + search_filter=search_filter, + info=info, + ) + else: + logger.info("Not triggering additional search, using combined results.") + additional_memories = combined_results[:retrieval_size] + logger.info( + f"Added {len(additional_memories)} more memories. Total enhanced memories: {len(enhanced_memories)}" + ) + enhanced_memories += additional_memories + + memories = enhanced_memories[: search_req.top_k] + + formatted_memories = [format_textual_memory_item(item) for item in memories] + logger.info("Submitted memory history async task.") + self.submit_memory_history_async_task( + search_req=search_req, + user_context=user_context, + memories_to_store={ + "memories": [one.to_dict() for one in memories], + "formatted_memories": formatted_memories, + }, + ) + + return formatted_memories def update_search_memories_to_redis( self, @@ -304,7 +322,7 @@ def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem]) logger.info(f"Messages {messages} assigned to {API_MIX_SEARCH_LABEL} handler.") # Process the query in a session turn - grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) + grouped_messages = group_messages_by_user_and_mem_cube(messages) self.validate_schedule_messages(messages=messages, label=API_MIX_SEARCH_LABEL) diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index 7f2c09b7d..524eab785 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -1,3 +1,5 @@ +import os + from enum import Enum from pathlib import Path from typing import NewType @@ -6,12 +8,18 @@ class SearchMode(str, Enum): """Enumeration for search modes.""" - NOT_INITIALIZED = "not_initialized" FAST = "fast" FINE = "fine" MIXTURE = "mixture" +class FineStrategy(str, Enum): + """Enumeration for fine strategies.""" + + REWRITE = "rewrite" + RECREATE = "recreate" + + FILE_PATH = Path(__file__).absolute() BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent @@ -32,17 +40,17 @@ class SearchMode(str, Enum): DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT = 20 DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache" DEFAULT_THREAD_POOL_MAX_WORKERS = 50 -DEFAULT_CONSUME_INTERVAL_SECONDS = 0.05 +DEFAULT_CONSUME_INTERVAL_SECONDS = 0.01 DEFAULT_CONSUME_BATCH = 1 DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300 DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2 DEFAULT_STUCK_THREAD_TOLERANCE = 10 -DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 0 +DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = -1 DEFAULT_TOP_K = 10 DEFAULT_CONTEXT_WINDOW_SIZE = 5 DEFAULT_USE_REDIS_QUEUE = True DEFAULT_MULTI_TASK_RUNNING_TIMEOUT = 30 -DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE = 10 +DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE = 20 DEFAULT_SCHEDULER_RETRIEVER_RETRIES = 1 DEFAULT_STOP_WAIT = False @@ -75,3 +83,17 @@ class SearchMode(str, Enum): # new types UserID = NewType("UserID", str) MemCubeID = NewType("CubeID", str) + +# algorithm strategies +DEFAULT_FINE_STRATEGY = FineStrategy.REWRITE + +# Read fine strategy from environment variable `FINE_STRATEGY`. +# If provided and valid, use it; otherwise fall back to default. +_env_fine_strategy = os.getenv("FINE_STRATEGY") +if _env_fine_strategy: + try: + FINE_STRATEGY = FineStrategy(_env_fine_strategy) + except ValueError: + FINE_STRATEGY = DEFAULT_FINE_STRATEGY +else: + FINE_STRATEGY = DEFAULT_FINE_STRATEGY diff --git a/evaluation/scripts/temporal_locomo/__init__.py b/src/memos/mem_scheduler/task_schedule_modules/__init__.py similarity index 100% rename from evaluation/scripts/temporal_locomo/__init__.py rename to src/memos/mem_scheduler/task_schedule_modules/__init__.py diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py similarity index 92% rename from src/memos/mem_scheduler/general_modules/dispatcher.py rename to src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index b74529c8c..ac9f9a6d0 100644 --- a/src/memos/mem_scheduler/general_modules/dispatcher.py +++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -10,12 +10,12 @@ from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule -from memos.mem_scheduler.general_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.general_modules.task_threads import ThreadManager from memos.mem_scheduler.schemas.general_schemas import DEFAULT_STOP_WAIT from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem from memos.mem_scheduler.utils.metrics import MetricsRegistry +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube logger = get_logger(__name__) @@ -151,15 +151,15 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): # acknowledge redis messages - if ( - self.use_redis_queue - and self.memos_message_queue is not None - and isinstance(self.memos_message_queue, SchedulerRedisQueue) - ): + if self.use_redis_queue and self.memos_message_queue is not None: for msg in messages: redis_message_id = msg.redis_message_id # Acknowledge message processing - self.memos_message_queue.ack_message(redis_message_id=redis_message_id) + self.memos_message_queue.ack_message( + user_id=msg.user_id, + mem_cube_id=msg.mem_cube_id, + redis_message_id=redis_message_id, + ) # Mark task as completed and remove from tracking with self._task_lock: @@ -329,38 +329,6 @@ def stats(self) -> dict[str, int]: def _default_message_handler(self, messages: list[ScheduleMessageItem]) -> None: logger.debug(f"Using _default_message_handler to deal with messages: {messages}") - def _group_messages_by_user_and_mem_cube( - self, messages: list[ScheduleMessageItem] - ) -> dict[str, dict[str, list[ScheduleMessageItem]]]: - """ - Groups messages into a nested dictionary structure first by user_id, then by mem_cube_id. - - Args: - messages: List of ScheduleMessageItem objects to be grouped - - Returns: - A nested dictionary with the structure: - { - "user_id_1": { - "mem_cube_id_1": [msg1, msg2, ...], - "mem_cube_id_2": [msg3, msg4, ...], - ... - }, - "user_id_2": { - ... - }, - ... - } - Where each msg is the original ScheduleMessageItem object - """ - grouped_dict = defaultdict(lambda: defaultdict(list)) - - for msg in messages: - grouped_dict[msg.user_id][msg.mem_cube_id].append(msg) - - # Convert defaultdict to regular dict for cleaner output - return {user_id: dict(cube_groups) for user_id, cube_groups in grouped_dict.items()} - def _handle_future_result(self, future): self._futures.remove(future) try: @@ -380,7 +348,7 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]): return # Group messages by user_id and mem_cube_id first - user_cube_groups = self._group_messages_by_user_and_mem_cube(msg_list) + user_cube_groups = group_messages_by_user_and_mem_cube(msg_list) # Process each user and mem_cube combination for user_id, cube_groups in user_cube_groups.items(): diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py new file mode 100644 index 000000000..93dd81132 --- /dev/null +++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py @@ -0,0 +1,155 @@ +""" +Local Queue implementation for SchedulerMessageItem objects. +This module provides a local-based queue implementation that can replace +the local memos_message_queue functionality in BaseScheduler. +""" + +from memos.log import get_logger +from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule + + +logger = get_logger(__name__) + + +class SchedulerLocalQueue(RedisSchedulerModule): + def __init__( + self, + maxsize: int, + ): + """ + Initialize the SchedulerLocalQueue with a maximum queue size limit. + + Args: + maxsize (int): Maximum number of messages allowed + in each individual queue. + If exceeded, subsequent puts will block + or raise an exception based on `block` parameter. + """ + super().__init__() + + self.stream_key_prefix = "local_queue" + + self.max_internal_message_queue_size = maxsize + # Dictionary to hold per-stream queues: key = stream_key, value = Queue[ScheduleMessageItem] + self.queue_streams: dict[str, Queue[ScheduleMessageItem]] = {} + logger.info( + f"SchedulerLocalQueue initialized with max_internal_message_queue_size={maxsize}" + ) + + def get_stream_key(self, user_id: str, mem_cube_id: str) -> str: + stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}" + return stream_key + + def put( + self, message: ScheduleMessageItem, block: bool = True, timeout: float | None = None + ) -> None: + """ + Put a message into the appropriate internal queue based on user_id and mem_cube_id. + + If the corresponding queue does not exist, it is created automatically. + This method uses a local in-memory queue (not Redis) for buffering messages. + + Args: + message (ScheduleMessageItem): The message to enqueue. + block (bool): If True, block if the queue is full; if False, raise Full immediately. + timeout (float | None): Maximum time to wait for the queue to become available. + If None, block indefinitely. Ignored if block=False. + + Raises: + queue.Full: If the queue is full and block=False or timeout expires. + Exception: Any underlying error during queue.put() operation. + """ + stream_key = self.get_stream_key(user_id=message.user_id, mem_cube_id=message.mem_cube_id) + + # Create the queue if it doesn't exist yet + if stream_key not in self.queue_streams: + logger.info(f"Creating new internal queue for stream: {stream_key}") + self.queue_streams[stream_key] = Queue(maxsize=self.max_internal_message_queue_size) + + try: + self.queue_streams[stream_key].put(item=message, block=block, timeout=timeout) + logger.info( + f"Message successfully put into queue '{stream_key}'. Current size: {self.queue_streams[stream_key].qsize()}" + ) + except Exception as e: + logger.error(f"Failed to put message into queue '{stream_key}': {e}", exc_info=True) + raise # Re-raise to maintain caller expectations + + def get( + self, + user_id: str, + mem_cube_id: str, + block: bool = True, + timeout: float | None = None, + batch_size: int | None = None, + ) -> list[ScheduleMessageItem]: + if batch_size is not None and batch_size <= 0: + logger.warning( + f"get() called with invalid batch_size: {batch_size}. Returning empty list." + ) + return [] + + stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id) + + # Return empty list if queue does not exist + if stream_key not in self.queue_streams: + logger.error(f"Stream {stream_key} does not exist when trying to get messages.") + return [] + + # Note: Assumes custom Queue implementation supports batch_size parameter + res = self.queue_streams[stream_key].get( + block=block, timeout=timeout, batch_size=batch_size + ) + logger.debug( + f"Retrieved {len(res)} messages from queue '{stream_key}'. Current size: {self.queue_streams[stream_key].qsize()}" + ) + return res + + def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem]: + """ + Non-blocking version of get(). Equivalent to get(block=False, batch_size=batch_size). + + Returns immediately with available messages or an empty list if queue is empty. + + Args: + batch_size (int | None): Number of messages to retrieve in a batch. + If None, retrieves one message. + + Returns: + List[ScheduleMessageItem]: Retrieved messages or empty list if queue is empty. + """ + logger.debug(f"get_nowait() called with batch_size: {batch_size}") + return self.get(block=False, batch_size=batch_size) + + def qsize(self) -> dict: + """ + Return the current size of all internal queues as a dictionary. + + Each key is the stream name, and each value is the number of messages in that queue. + + Returns: + Dict[str, int]: Mapping from stream name to current queue size. + """ + sizes = {stream: queue.qsize() for stream, queue in self.queue_streams.items()} + logger.debug(f"Current queue sizes: {sizes}") + return sizes + + def clear(self) -> None: + for queue in self.queue_streams.values(): + queue.clear() + + @property + def unfinished_tasks(self) -> int: + """ + Calculate the total number of unprocessed messages across all queues. + + This is a convenience property for monitoring overall system load. + + Returns: + int: Sum of all message counts in all internal queues. + """ + total = sum(self.qsize().values()) + logger.debug(f"Total unfinished tasks across all queues: {total}") + return total diff --git a/src/memos/mem_scheduler/general_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py similarity index 74% rename from src/memos/mem_scheduler/general_modules/redis_queue.py rename to src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index c10765d05..fe7e3452c 100644 --- a/src/memos/mem_scheduler/general_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -32,7 +32,7 @@ class SchedulerRedisQueue(RedisSchedulerModule): def __init__( self, - stream_name: str = "scheduler:messages:stream", + stream_key_prefix: str = "scheduler:messages:stream", consumer_group: str = "scheduler_group", consumer_name: str | None = "scheduler_consumer", max_len: int = 10000, @@ -43,7 +43,7 @@ def __init__( Initialize the Redis queue. Args: - stream_name: Name of the Redis stream + stream_key_prefix: Name of the Redis stream consumer_group: Name of the consumer group consumer_name: Name of the consumer (auto-generated if None) max_len: Maximum length of the stream (for memory management) @@ -57,7 +57,7 @@ def __init__( maxsize = 0 # Stream configuration - self.stream_name = stream_name + self.stream_key_prefix = stream_key_prefix self.consumer_group = consumer_group self.consumer_name = consumer_name or f"consumer_{uuid4().hex[:8]}" self.max_len = max_len @@ -77,26 +77,29 @@ def __init__( # Auto-initialize Redis connection if self.auto_initialize_redis(): self._is_connected = True - self._ensure_consumer_group() - def _ensure_consumer_group(self) -> None: + self.seen_streams = set() + + def get_stream_key(self, user_id: str, mem_cube_id: str) -> str: + stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}" + return stream_key + + def _ensure_consumer_group(self, stream_key) -> None: """Ensure the consumer group exists for the stream.""" if not self._redis_conn: return try: - self._redis_conn.xgroup_create( - self.stream_name, self.consumer_group, id="0", mkstream=True - ) + self._redis_conn.xgroup_create(stream_key, self.consumer_group, id="0", mkstream=True) logger.debug( - f"Created consumer group '{self.consumer_group}' for stream '{self.stream_name}'" + f"Created consumer group '{self.consumer_group}' for stream '{stream_key}'" ) except Exception as e: # Check if it's a "consumer group already exists" error error_msg = str(e).lower() if "busygroup" in error_msg or "already exists" in error_msg: logger.info( - f"Consumer group '{self.consumer_group}' already exists for stream '{self.stream_name}'" + f"Consumer group '{self.consumer_group}' already exists for stream '{stream_key}'" ) else: logger.error(f"Error creating consumer group: {e}", exc_info=True) @@ -123,12 +126,20 @@ def put( raise TypeError(f"Expected ScheduleMessageItem, got {type(message)}") try: + stream_key = self.get_stream_key( + user_id=message.user_id, mem_cube_id=message.mem_cube_id + ) + + if stream_key not in self.seen_streams: + self.seen_streams.add(stream_key) + self._ensure_consumer_group(stream_key=stream_key) + # Convert message to dictionary for Redis storage message_data = message.to_dict() # Add to Redis stream with automatic trimming message_id = self._redis_conn.xadd( - self.stream_name, message_data, maxlen=self.max_len, approximate=True + stream_key, message_data, maxlen=self.max_len, approximate=True ) logger.info( @@ -139,28 +150,23 @@ def put( logger.error(f"Failed to add message to Redis queue: {e}") raise - def put_nowait(self, message: ScheduleMessageItem) -> None: - """ - Add a message to the Redis queue without blocking (Queue-compatible interface). + def ack_message(self, user_id, mem_cube_id, redis_message_id) -> None: + stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id) - Args: - message: SchedulerMessageItem to add to the queue - """ - self.put(message, block=False) - - def ack_message(self, redis_message_id): - self.redis.xack(self.stream_name, self.consumer_group, redis_message_id) + self.redis.xack(stream_key, self.consumer_group, redis_message_id) # Optionally delete the message from the stream to keep it clean if self.auto_delete_acked: try: - self._redis_conn.xdel(self.stream_name, redis_message_id) + self._redis_conn.xdel(stream_key, redis_message_id) logger.info(f"Successfully delete acknowledged message {redis_message_id}") except Exception as e: logger.warning(f"Failed to delete acknowledged message {redis_message_id}: {e}") def get( self, + user_id: str, + mem_cube_id: str, block: bool = True, timeout: float | None = None, batch_size: int | None = None, @@ -169,6 +175,8 @@ def get( raise ConnectionError("Not connected to Redis. Redis connection not available.") try: + stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id) + # Calculate timeout for Redis redis_timeout = None if block and timeout is not None: @@ -181,7 +189,7 @@ def get( messages = self._redis_conn.xreadgroup( self.consumer_group, self.consumer_name, - {self.stream_name: ">"}, + {stream_key: ">"}, count=batch_size if not batch_size else 1, block=redis_timeout, ) @@ -190,12 +198,13 @@ def get( err_msg = str(read_err).lower() if "nogroup" in err_msg or "no such key" in err_msg: logger.warning( - f"Consumer group or stream missing for '{self.stream_name}/{self.consumer_group}'. Attempting to create and retry." + f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry." ) + self._ensure_consumer_group(stream_key=stream_key) messages = self._redis_conn.xreadgroup( self.consumer_group, self.consumer_name, - {self.stream_name: ">"}, + {stream_key: ">"}, count=batch_size if not batch_size else 1, block=redis_timeout, ) @@ -233,7 +242,9 @@ def get( logger.error(f"Failed to get message from Redis queue: {e}") raise - def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem]: + def get_nowait( + self, user_id: str, mem_cube_id: str, batch_size: int | None = None + ) -> list[ScheduleMessageItem]: """ Get messages from the Redis queue without blocking (Queue-compatible interface). @@ -243,76 +254,58 @@ def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem] Raises: Empty: If no message is available """ - return self.get(block=False, batch_size=batch_size) + return self.get( + user_id=user_id, mem_cube_id=mem_cube_id, block=False, batch_size=batch_size + ) def qsize(self) -> int: """ Get the current size of the Redis queue (Queue-compatible interface). - Returns the number of pending (unacknowledged) messages in the consumer group, - which represents the actual queue size for processing. + This method scans for all streams matching the `stream_key_prefix` + and sums up their lengths to get the total queue size. Returns: - Number of pending messages in the queue + Total number of messages across all matching streams. """ if not self._redis_conn: return 0 + total_size = 0 try: - # Get pending messages info for the consumer group - # XPENDING returns info about pending messages that haven't been acknowledged - pending_info = self._redis_conn.xpending(self.stream_name, self.consumer_group) - - # pending_info[0] contains the count of pending messages - if pending_info and len(pending_info) > 0 and pending_info[0] is not None: - pending_count = int(pending_info[0]) - if pending_count > 0: - return pending_count - - # If no pending messages, check if there are new messages in the stream - # that haven't been read by any consumer yet - try: - # Get the last delivered ID for the consumer group - groups_info = self._redis_conn.xinfo_groups(self.stream_name) - if not groups_info: - # No groups exist, check total stream length - return self._redis_conn.xlen(self.stream_name) or 0 - - last_delivered_id = "0-0" - - for group_info in groups_info: - if group_info and group_info.get("name") == self.consumer_group: - last_delivered_id = group_info.get("last-delivered-id", "0-0") - break - - # Count messages after the last delivered ID - new_messages = self._redis_conn.xrange( - self.stream_name, - f"({last_delivered_id}", # Exclusive start - "+", # End at the latest message - count=1000, # Limit to avoid memory issues - ) + # Scan for all stream keys matching the prefix + for stream_key in self._redis_conn.scan_iter(f"{self.stream_key_prefix}:*"): + try: + # Get the length of each stream and add to total + total_size += self._redis_conn.xlen(stream_key) + except Exception as e: + logger.debug(f"Failed to get length for stream {stream_key}: {e}") + return total_size + except Exception as e: + logger.error(f"Failed to get Redis queue size: {e}") + return 0 - return len(new_messages) if new_messages else 0 + def get_stream_keys(self) -> list[str]: + """ + List all Redis stream keys that match this queue's prefix. - except Exception as inner_e: - logger.debug(f"Failed to get new messages count: {inner_e}") - # Fallback: return stream length - try: - stream_len = self._redis_conn.xlen(self.stream_name) - return stream_len if stream_len is not None else 0 - except Exception: - return 0 + Returns: + A list of stream keys like `"{prefix}:{user_id}:{mem_cube_id}"`. + """ + if not self._redis_conn: + return [] + try: + # Use match parameter and decode byte strings to regular strings + stream_keys = [ + key.decode("utf-8") if isinstance(key, bytes) else key + for key in self._redis_conn.scan_iter(match=f"{self.stream_key_prefix}:*") + ] + logger.debug(f"get stream_keys from redis: {stream_keys}") + return stream_keys except Exception as e: - logger.debug(f"Failed to get Redis queue size via XPENDING: {e}") - # Fallback to stream length if pending check fails - try: - stream_len = self._redis_conn.xlen(self.stream_name) - return stream_len if stream_len is not None else 0 - except Exception as fallback_e: - logger.error(f"Failed to get Redis queue size (all methods failed): {fallback_e}") - return 0 + logger.error(f"Failed to list Redis stream keys: {e}") + return [] def size(self) -> int: """ @@ -360,12 +353,13 @@ def clear(self) -> None: return try: - # Delete the entire stream - self._redis_conn.delete(self.stream_name) - logger.info(f"Cleared Redis stream: {self.stream_name}") + stream_keys = self.get_stream_keys() + + for stream_key in stream_keys: + # Delete the entire stream + self._redis_conn.delete(stream_key) + logger.info(f"Cleared Redis stream: {stream_key}") - # Recreate the consumer group - self._ensure_consumer_group() except Exception as e: logger.error(f"Failed to clear Redis queue: {e}") @@ -389,7 +383,7 @@ def start_listening( self._message_handler = handler self._is_listening = True - logger.info(f"Started listening on Redis stream: {self.stream_name}") + logger.info(f"Started listening on Redis stream: {self.stream_key_prefix}") try: while self._is_listening: diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py new file mode 100644 index 000000000..74f1ad1f8 --- /dev/null +++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py @@ -0,0 +1,151 @@ +""" +Redis Queue implementation for SchedulerMessageItem objects. + +This module provides a Redis-based queue implementation that can replace +the local memos_message_queue functionality in BaseScheduler. +""" + +from collections import defaultdict + +from memos.log import get_logger +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue +from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue +from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube + + +logger = get_logger(__name__) + + +class ScheduleTaskQueue: + def __init__( + self, + use_redis_queue: bool, + maxsize: int, + disabled_handlers: list | None = None, + ): + self.use_redis_queue = use_redis_queue + self.maxsize = maxsize + + if self.use_redis_queue: + self.memos_message_queue = SchedulerRedisQueue(maxsize=self.maxsize) + else: + self.memos_message_queue = SchedulerLocalQueue(maxsize=self.maxsize) + + self.disabled_handlers = disabled_handlers + + def ack_message( + self, + user_id, + mem_cube_id, + redis_message_id, + ) -> None: + if not isinstance(self.memos_message_queue, SchedulerRedisQueue): + logger.warning("ack_message is only supported for Redis queues") + return + + self.memos_message_queue.ack_message( + user_id=user_id, + mem_cube_id=mem_cube_id, + redis_message_id=redis_message_id, + ) + + def debug_mode_on(self): + self.memos_message_queue.stream_key_prefix = ( + f"debug_mode:{self.memos_message_queue.stream_key_prefix}" + ) + + def get_stream_keys(self) -> list[str]: + if isinstance(self.memos_message_queue, SchedulerRedisQueue): + return self.memos_message_queue.get_stream_keys() + else: + return list(self.memos_message_queue.queue_streams.keys()) + + def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): + """Submit messages to the message queue (either local queue or Redis).""" + if isinstance(messages, ScheduleMessageItem): + messages = [messages] + + if len(messages) < 1: + logger.error("Submit empty") + elif len(messages) == 1: + self.memos_message_queue.put(messages[0]) + else: + user_cube_groups = group_messages_by_user_and_mem_cube(messages) + + # Process each user and mem_cube combination + for _user_id, cube_groups in user_cube_groups.items(): + for _mem_cube_id, user_cube_msgs in cube_groups.items(): + for message in user_cube_msgs: + if not isinstance(message, ScheduleMessageItem): + error_msg = f"Invalid message type: {type(message)}, expected ScheduleMessageItem" + logger.error(error_msg) + raise TypeError(error_msg) + + if getattr(message, "timestamp", None) is None: + message.timestamp = get_utc_now() + + if self.disabled_handlers and message.label in self.disabled_handlers: + logger.info( + f"Skipping disabled handler: {message.label} - {message.content}" + ) + continue + + self.memos_message_queue.put(message) + logger.info( + f"Submitted message to local queue: {message.label} - {message.content}" + ) + + def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: + # Discover all active streams via queue API + streams: list[tuple[str, str]] = [] + + stream_keys = self.get_stream_keys() + for stream_key in stream_keys: + try: + parts = stream_key.split(":") + if len(parts) >= 3: + user_id = parts[-2] + mem_cube_id = parts[-1] + streams.append((user_id, mem_cube_id)) + except Exception as e: + logger.debug(f"Failed to parse stream key {stream_key}: {e}") + + if not streams: + return [] + + messages: list[ScheduleMessageItem] = [] + + # Group by user: {user_id: [mem_cube_id, ...]} + + streams_by_user: dict[str, list[str]] = defaultdict(list) + for user_id, mem_cube_id in streams: + streams_by_user[user_id].append(mem_cube_id) + + # For each user, fairly consume up to batch_size across their streams + for user_id, mem_cube_ids in streams_by_user.items(): + if not mem_cube_ids: + continue + + # First pass: give each stream an equal share for this user + for mem_cube_id in mem_cube_ids: + fetched = self.memos_message_queue.get( + user_id=user_id, + mem_cube_id=mem_cube_id, + block=False, + batch_size=batch_size, + ) + + messages.extend(fetched) + + logger.info( + f"Fetched {len(messages)} messages across users with per-user batch_size={batch_size}" + ) + return messages + + def clear(self): + self.memos_message_queue.clear() + + def qsize(self): + return self.memos_message_queue.qsize() diff --git a/src/memos/mem_scheduler/utils/misc_utils.py b/src/memos/mem_scheduler/utils/misc_utils.py index cce1286bb..7b0bcea34 100644 --- a/src/memos/mem_scheduler/utils/misc_utils.py +++ b/src/memos/mem_scheduler/utils/misc_utils.py @@ -2,12 +2,16 @@ import re import traceback +from collections import defaultdict from functools import wraps from pathlib import Path import yaml from memos.log import get_logger +from memos.mem_scheduler.schemas.message_schemas import ( + ScheduleMessageItem, +) logger = get_logger(__name__) @@ -216,3 +220,36 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def group_messages_by_user_and_mem_cube( + messages: list[ScheduleMessageItem], +) -> dict[str, dict[str, list[ScheduleMessageItem]]]: + """ + Groups messages into a nested dictionary structure first by user_id, then by mem_cube_id. + + Args: + messages: List of ScheduleMessageItem objects to be grouped + + Returns: + A nested dictionary with the structure: + { + "user_id_1": { + "mem_cube_id_1": [msg1, msg2, ...], + "mem_cube_id_2": [msg3, msg4, ...], + ... + }, + "user_id_2": { + ... + }, + ... + } + Where each msg is the original ScheduleMessageItem object + """ + grouped_dict = defaultdict(lambda: defaultdict(list)) + + for msg in messages: + grouped_dict[msg.user_id][msg.mem_cube_id].append(msg) + + # Convert defaultdict to regular dict for cleaner output + return {user_id: dict(cube_groups) for user_id, cube_groups in grouped_dict.items()} diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 15a6a8b49..1b2355bc8 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -161,7 +161,7 @@ def search( info=None, mode: str = "fast", memory_type: str = "All", - manual_close_internet: bool = False, + manual_close_internet: bool = True, moscube: bool = False, search_filter: dict | None = None, user_name: str | None = None, @@ -189,9 +189,6 @@ def search( list[TextualMemoryItem]: List of matching memories. """ if (self.internet_retriever is not None) and manual_close_internet: - logger.warning( - "Internet retriever is init by config , but this search set manual_close_internet is True and will close it" - ) searcher = Searcher( self.dispatcher_llm, self.graph_store, @@ -201,6 +198,7 @@ def search( internet_retriever=None, moscube=moscube, search_strategy=self.search_strategy, + manual_close_internet=manual_close_internet, ) else: searcher = Searcher( @@ -212,6 +210,7 @@ def search( internet_retriever=self.internet_retriever, moscube=moscube, search_strategy=self.search_strategy, + manual_close_internet=manual_close_internet, ) return searcher.search( query, top_k, info, mode, memory_type, search_filter, user_name=user_name diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 14ea8e2cb..933ef5af1 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -43,6 +43,7 @@ def __init__( internet_retriever: None = None, moscube: bool = False, search_strategy: dict | None = None, + manual_close_internet: bool = True, ): self.graph_store = graph_store self.embedder = embedder @@ -58,7 +59,7 @@ def __init__( self.moscube = moscube self.vec_cot = search_strategy.get("cot", False) if search_strategy else False self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False - + self.manual_close_internet = manual_close_internet self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @timed @@ -72,7 +73,7 @@ def retrieve( search_filter: dict | None = None, user_name: str | None = None, **kwargs, - ) -> list[TextualMemoryItem]: + ) -> list[tuple[TextualMemoryItem, float]]: logger.info( f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}" ) @@ -94,7 +95,7 @@ def retrieve( def post_retrieve( self, - retrieved_results: list[TextualMemoryItem], + retrieved_results: list[tuple[TextualMemoryItem, float]], top_k: int, user_name: str | None = None, info=None, @@ -458,7 +459,7 @@ def _retrieve_from_internet( user_id: str | None = None, ): """Retrieve and rerank from Internet source""" - if not self.internet_retriever or mode == "fast": + if not self.internet_retriever or self.manual_close_internet: logger.info(f"[PATH-C] '{query}' Skipped (no retriever, fast mode)") return [] if memory_type not in ["All"]: diff --git a/src/memos/templates/mem_scheduler_prompts.py b/src/memos/templates/mem_scheduler_prompts.py index 197a2c1a7..7f7415e79 100644 --- a/src/memos/templates/mem_scheduler_prompts.py +++ b/src/memos/templates/mem_scheduler_prompts.py @@ -390,24 +390,22 @@ - Focus on whether the memories can fully answer the query without additional information """ -MEMORY_ENHANCEMENT_PROMPT = """ +MEMORY_RECREATE_ENHANCEMENT_PROMPT = """ You are a knowledgeable and precise AI assistant. # GOAL -Transform each raw memory into an enhanced version that preserves all relevant factual details and makes the information directly useful for answering the user's query. - -# CORE PRINCIPLE -Focus on **relevance** — the enhanced memories should highlight, clarify, and preserve the information that most directly supports answering the current query. +Transform raw memories into clean, query-relevant facts — preserving timestamps and resolving ambiguities without inference. # RULES & THINKING STEPS -1. Read the user query carefully and identify what specific facts are needed to answer it. -2. Go through each memory and: - - Keep only details directly relevant to the query (dates, actions, entities, outcomes). - - Remove unrelated or background details. - - If nothing in a memory relates to the query, delete the entire memory. -3. Do not add or infer new facts. -4. Keep facts accurate and phrased clearly. -5. Each resulting line should stand alone as a usable fact for answering the query. +1. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely. +2. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”, “after injury”). +3. Resolve all ambiguities using only memory content: + - Pronouns → full name: “she” → “Melanie” + - Vague nouns → specific detail: “home” → “her childhood home in Guangzhou” + - “the user” → identity from context (e.g., “Melanie” if travel/running memories) +4. Never invent, assume, or extrapolate. +5. Each output line must be a standalone, clear, factual statement. +6. Output format: one line per fact, starting with "- ", no extra text. # OUTPUT FORMAT (STRICT) Return ONLY the following block, with **one enhanced memory per line**. @@ -423,12 +421,91 @@ ## User Query {query_history} -## Available Memories +## Original Memories {memories} -Answer: +Final Output: +""" + +# Rewrite version: return enhanced memories with original IDs +MEMORY_REWRITE_ENHANCEMENT_PROMPT = """ +You are a knowledgeable and precise AI assistant. + +# GOAL +Transform raw memories into clean, query-relevant facts — preserving timestamps and resolving ambiguities without inference. Return each enhanced fact with the ID of the original memory being modified. + +# RULES & THINKING STEPS +1. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely. +2. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”, “after injury”). +3. Resolve all ambiguities using only memory content: + - Pronouns → full name: “she” → “Melanie” + - Vague nouns → specific detail: “home” → “her childhood home in Guangzhou” + - “the user” → identity from context (e.g., “Melanie” if travel/running memories) +4. Never invent, assume, or extrapolate. +5. Each output line must be a standalone, clear, factual statement. +6. Output format: one line per fact, starting with "- ", no extra text. + +# IMPORTANT FOR REWRITE +- Each output line MUST include the original memory’s ID shown in the input list. +- Use the index shown for each original memory (e.g., "[0]", "[1]") as the ID to reference which memory you are rewriting. +- For every rewritten line, prefix with the corresponding index in square brackets. + +# OUTPUT FORMAT (STRICT) +Return ONLY the following block, with **one enhanced memory per line**. +Each line MUST start with "- " (dash + space) AND include index in square brackets. + +Wrap the final output inside: + +- [index] enhanced memory 1 +- [index] enhanced memory 2 +... + + +## User Query +{query_history} + +## Original Memories +{memories} + +Final Output: """ +# One-sentence prompt for recalling missing information to answer the query (English) +ENLARGE_RECALL_PROMPT_ONE_SENTENCE = """ +You are a precise AI assistant. Your job is to analyze the user's query and the available memories to identify what specific information is missing to fully answer the query. + +# GOAL + +Identify the specific missing facts needed to fully answer the user's query and generate a concise hint for recalling them. + +# RULES + +- Analyze the user's query to understand what information is being asked. +- Review the available memories to see what information is already present. +- Identify the gap between the user's query and the available memories. +- Generate a single, concise hint that prompts the user to provide the missing information. +- The hint should be a direct question or a statement that clearly indicates what is needed. + +# OUTPUT FORMAT +A JSON object with: + +trigger_retrieval: true if information is missing, false if sufficient. +hint: A clear, specific prompt to retrieve the missing information (or an empty string if trigger_retrieval is false): +{{ + "trigger_recall": , + "hint": a paraphrase to retrieve support memories +}} + +## User Query +{query} + +## Available Memories +{memories_inline} + +Final Output: +""" + + PROMPT_MAPPING = { "intent_recognizing": INTENT_RECOGNIZING_PROMPT, "memory_reranking": MEMORY_RERANKING_PROMPT, @@ -437,7 +514,9 @@ "memory_redundancy_filtering": MEMORY_REDUNDANCY_FILTERING_PROMPT, "memory_combined_filtering": MEMORY_COMBINED_FILTERING_PROMPT, "memory_answer_ability_evaluation": MEMORY_ANSWER_ABILITY_EVALUATION_PROMPT, - "memory_enhancement": MEMORY_ENHANCEMENT_PROMPT, + "memory_recreate_enhancement": MEMORY_RECREATE_ENHANCEMENT_PROMPT, + "memory_rewrite_enhancement": MEMORY_REWRITE_ENHANCEMENT_PROMPT, + "enlarge_recall": ENLARGE_RECALL_PROMPT_ONE_SENTENCE, } MEMORY_ASSEMBLY_TEMPLATE = """The retrieved memories are listed as follows:\n\n {memory_text}""" diff --git a/tests/mem_scheduler/test_dispatcher.py b/tests/mem_scheduler/test_dispatcher.py index fc154e013..e687d2986 100644 --- a/tests/mem_scheduler/test_dispatcher.py +++ b/tests/mem_scheduler/test_dispatcher.py @@ -14,10 +14,11 @@ ) from memos.llms.base import BaseLLM from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.scheduler_factory import SchedulerFactory from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem +from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube from memos.memories.textual.tree import TreeTextMemory @@ -192,9 +193,8 @@ def test_dispatch_serial(self): def test_group_messages_by_user_and_mem_cube(self): """Test grouping messages by user and cube.""" - # Check actual grouping logic - with patch("memos.mem_scheduler.general_modules.dispatcher.logger.debug"): - result = self.dispatcher._group_messages_by_user_and_mem_cube(self.test_messages) + # Check actual grouping logic using shared utility function + result = group_messages_by_user_and_mem_cube(self.test_messages) # Adjust expected results based on actual grouping logic # Note: According to dispatcher.py implementation, grouping is by mem_cube_id not mem_cube From e234da99e2232a8cf05ca485171ecbb640282a5b Mon Sep 17 00:00:00 2001 From: Travis Tang Date: Wed, 19 Nov 2025 16:17:22 +0800 Subject: [PATCH 14/18] feat & refactor: add searcher to handler_init and remove logger info from get_messages (#501) * debug an error function name * feat: Add DynamicCache compatibility for different transformers versions - Fix build_kv_cache method in hf.py to handle both old and new DynamicCache structures - Support new 'layers' attribute with key_cache/value_cache or keys/values - Maintain backward compatibility with direct key_cache/value_cache attributes - Add comprehensive error handling and logging for unsupported structures - Update move_dynamic_cache_htod function in kv.py for cross-version compatibility - Handle layers-based structure in newer transformers versions - Support alternative attribute names (keys/values vs key_cache/value_cache) - Preserve original functionality for older transformers versions - Add comprehensive tests for DynamicCache compatibility - Test activation memory update with mock DynamicCache layers - Verify layers attribute access across different transformers versions - Fix scheduler logger mock to include memory_manager attribute This resolves AttributeError issues when using different versions of the transformers library and ensures robust handling of DynamicCache objects. debug * feat: implement APIAnalyzerForScheduler for memory operations - Add APIAnalyzerForScheduler class with search/add operations - Support requests and http.client with connection reuse - Include comprehensive error handling and dynamic configuration - Add English test suite with real-world conversation scenarios * feat: Add search_ws API endpoint and enhance API analyzer functionality - Add search_ws endpoint in server_router.py for scheduler-enabled search - Fix missing imports: time module, SearchRequest class, and get_mos_product_instance function - Implement search_ws method in api_analyzer.py with HTTP client support - Add _search_ws_with_requests and _search_ws_with_http_client private methods - Include search_ws usage example in demonstration code - Enhance scheduler and dispatcher capabilities for improved memory management - Expand test coverage to ensure functionality stability This update primarily strengthens the memory scheduling system's search capabilities, providing users with more flexible API interface options. * fix: resolve test failures and warnings in test suite - Fix Pydantic serialization warning in test_memos_chen_tang_hello_world * Add warnings filter to suppress UserWarning from Pydantic serialization - Fix KeyError: 'past_key_values' in test_build_kv_cache_and_generation * Update mock configuration to properly return forward_output with past_key_values * Add DynamicCache version compatibility handling in test mocks * Support both old and new transformers versions with layers/key_cache attributes * Improve assertion logic to check all model calls for required parameters - Update base_scheduler.py to use centralized DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE constant * Add import for DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE from general_schemas * Replace hardcoded value 100 with configurable constant (1000) All tests now pass successfully with proper version compatibility handling. * feat: add a test_robustness execution to test thread pool execution * feat: optimize scheduler configuration and API search functionality - Add DEFAULT_TOP_K and DEFAULT_CONTEXT_WINDOW_SIZE global constants in general_schemas.py - Update base_scheduler.py to use global default values instead of hardcoded numbers - Fix SchedulerConfigFactory initialization issue by using keyword argument expansion - Resolve UnboundLocalError variable conflict in search_memories_ws function - Fix indentation and parameter issues in OptimizedScheduler search_for_api method - Improve code standardization and maintainability * feat: Add Redis auto-initialization with fallback strategies - Add auto_initialize_redis() with config/env/local fallback - Move Redis logic from dispatcher_monitor to redis_service - Update base_scheduler to use auto initialization - Add proper resource cleanup and error handling * feat: add database connection management to ORM module - Add MySQL engine loading from environment variables in BaseDBManager - Add Redis connection loading from environment variables in BaseDBManager - Enhance database configuration validation and error handling - Complete database adapter infrastructure for ORM module - Provide unified database connection management interface This update provides comprehensive database connection management capabilities for the mem_scheduler module, supporting dynamic MySQL and Redis configuration loading from environment variables, establishing reliable data persistence foundation for scheduling services and API services. * remove part of test * feat: add Redis-based ORM with multiprocess synchronization - Add RedisDBManager and RedisLockableORM classes - Implement atomic locking mechanism for concurrent access - Add merge functionality for different object types - Include comprehensive test suite and examples - Fix Redis key type conflicts in lock operations * fix: resolve scheduler module import and Redis integration issues * revise naive memcube creation in server router * remove long-time tests in test_scheduler * remove redis test which needs .env * refactor all codes about mixture search with scheduler * fix: resolve Redis API synchronization issues and implement search API with reranker - Fix running_entries to running_task_ids migration across codebase - Update sync_search_data method to properly handle TaskRunningStatus - Correct variable naming and logic in API synchronization flow - Implement search API endpoint with reranker functionality - Update test files to reflect new running_task_ids convention - Ensure proper Redis state management for concurrent tasks * remove a test for api module * revise to pass the test suite * address some bugs to make mix_search normally running * modify codes according to evaluation logs * feat: Optimize mixture search and enhance API client * feat: Add conversation_turn tracking for session-based memory search - Add conversation_turn field to APIMemoryHistoryEntryItem schema with default value 0 - Implement session counter in OptimizedScheduler to track turn count per session_id - Update sync_search_data method to accept and store conversation_turn parameter - Maintain session history with LRU eviction (max 5 sessions) - Rename conversation_id to session_id for consistency with request object - Enable direct access to session_id from search requests This feature allows tracking conversation turns within the same session, providing better context for memory retrieval and search history management. * adress time bug in monitor * revise simple tree * add mode to evaluation client; rewrite print to logger.info in db files * feat: 1. add redis queue for scheduler 2. finish the code related to mix search and fine search * debug the working memory code * addressed a range of bugs to make scheduler running correctly * remove test_dispatch_parallel test * print change to logger.info * adjucted the core code related to fine and mixture apis * feat: create task queue to wrap local queue and redis queue. queue now split FIFO to multi queue from different users. addressed a range of bugs * fix bugs: debug bugs about internet trigger * debug get searcher mode * feat: add manual internet * Fix: fix code format * feat: add strategy for fine search * debug redis queue * debug redis queue * fix bugs: completely addressed bugs about redis queue * refactor: add searcher to handler_init; remove info log from task_queue --------- Co-authored-by: fridayL --- examples/mem_scheduler/api_w_scheduler.py | 9 +-- src/memos/api/handlers/base_handler.py | 8 +++ src/memos/api/handlers/component_init.py | 22 +++++-- src/memos/api/handlers/search_handler.py | 39 +++++++----- src/memos/mem_scheduler/base_scheduler.py | 19 ++++-- .../mem_scheduler/schemas/general_schemas.py | 1 + .../task_schedule_modules/dispatcher.py | 1 - .../task_schedule_modules/local_queue.py | 5 +- .../task_schedule_modules/redis_queue.py | 32 +++++----- .../task_schedule_modules/task_queue.py | 60 ++++++------------- 10 files changed, 101 insertions(+), 95 deletions(-) diff --git a/examples/mem_scheduler/api_w_scheduler.py b/examples/mem_scheduler/api_w_scheduler.py index a2184e9ca..1b59543f3 100644 --- a/examples/mem_scheduler/api_w_scheduler.py +++ b/examples/mem_scheduler/api_w_scheduler.py @@ -25,9 +25,10 @@ def my_test_handler(messages: list[ScheduleMessageItem]): print(f"My test handler received {len(messages)} messages:") for msg in messages: print(f" my_test_handler - {msg.item_id}: {msg.content}") - print( - f"{queue._redis_conn.xinfo_groups(queue.stream_key_prefix)} qsize: {queue.qsize()} messages:{messages}" + user_status_running = handle_scheduler_status( + user_name=USER_MEM_CUBE, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler" ) + print(f"[Monitor] Status for {USER_MEM_CUBE} after submit:", user_status_running) # 2. Register the handler @@ -59,10 +60,6 @@ def my_test_handler(messages: list[ScheduleMessageItem]): # 5.1 Monitor status for specific mem_cube while running USER_MEM_CUBE = "test_mem_cube" -user_status_running = handle_scheduler_status( - user_name=USER_MEM_CUBE, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler" -) -print(f"[Monitor] Status for {USER_MEM_CUBE} after submit:", user_status_running) # 6. Wait for messages to be processed (limited to 100 checks) print("Waiting for messages to be consumed (max 100 checks)...") diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py index a174defb1..a686ac8f9 100644 --- a/src/memos/api/handlers/base_handler.py +++ b/src/memos/api/handlers/base_handler.py @@ -9,6 +9,7 @@ from memos.log import get_logger from memos.mem_scheduler.base_scheduler import BaseScheduler +from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher logger = get_logger(__name__) @@ -28,6 +29,7 @@ def __init__( naive_mem_cube: Any | None = None, mem_reader: Any | None = None, mem_scheduler: Any | None = None, + searcher: Any | None = None, embedder: Any | None = None, reranker: Any | None = None, graph_db: Any | None = None, @@ -58,6 +60,7 @@ def __init__( self.naive_mem_cube = naive_mem_cube self.mem_reader = mem_reader self.mem_scheduler = mem_scheduler + self.searcher = searcher self.embedder = embedder self.reranker = reranker self.graph_db = graph_db @@ -128,6 +131,11 @@ def mem_scheduler(self) -> BaseScheduler: """Get scheduler instance.""" return self.deps.mem_scheduler + @property + def searcher(self) -> Searcher: + """Get scheduler instance.""" + return self.deps.searcher + @property def embedder(self): """Get embedder instance.""" diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 4e696a341..78ed13e1f 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -5,6 +5,8 @@ including databases, LLMs, memory systems, and schedulers. """ +import os + from typing import TYPE_CHECKING, Any from memos.api.config import APIConfig @@ -38,6 +40,10 @@ from memos.memories.textual.simple_preference import SimplePreferenceTextMemory from memos.memories.textual.simple_tree import SimpleTreeTextMemory from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager + + +if TYPE_CHECKING: + from memos.memories.textual.tree import TreeTextMemory from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import ( InternetRetrieverFactory, ) @@ -47,7 +53,7 @@ if TYPE_CHECKING: from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler - + from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher logger = get_logger(__name__) @@ -205,6 +211,13 @@ def init_server() -> dict[str, Any]: logger.debug("MemCube created") + tree_mem: TreeTextMemory = naive_mem_cube.text_mem + searcher: Searcher = tree_mem.get_searcher( + manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", + moscube=False, + ) + logger.debug("Searcher created") + # Initialize Scheduler scheduler_config_dict = APIConfig.get_scheduler_config() scheduler_config = SchedulerConfigFactory( @@ -217,16 +230,14 @@ def init_server() -> dict[str, Any]: db_engine=BaseDBManager.create_default_sqlite_engine(), mem_reader=mem_reader, ) - mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube) + mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube, searcher=searcher) logger.debug("Scheduler initialized") # Initialize SchedulerAPIModule api_module = mem_scheduler.api_module # Start scheduler if enabled - import os - - if os.getenv("API_SCHEDULER_ON", True): + if os.getenv("API_SCHEDULER_ON", "true").lower() == "true": mem_scheduler.start() logger.info("Scheduler started") @@ -253,6 +264,7 @@ def init_server() -> dict[str, Any]: "mos_server": mos_server, "mem_scheduler": mem_scheduler, "naive_mem_cube": naive_mem_cube, + "searcher": searcher, "api_module": api_module, "vector_db": vector_db, "pref_extractor": pref_extractor, diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index cf2ab73bb..7d7d52dc4 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -18,7 +18,7 @@ from memos.api.product_models import APISearchRequest, SearchResponse from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger -from memos.mem_scheduler.schemas.general_schemas import SearchMode +from memos.mem_scheduler.schemas.general_schemas import FINE_STRATEGY, FineStrategy, SearchMode from memos.types import MOSSearchResult, UserContext @@ -40,7 +40,7 @@ def __init__(self, dependencies: HandlerDependencies): dependencies: HandlerDependencies instance """ super().__init__(dependencies) - self._validate_dependencies("naive_mem_cube", "mem_scheduler") + self._validate_dependencies("naive_mem_cube", "mem_scheduler", "searcher") def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse: """ @@ -211,11 +211,17 @@ def _fast_search( return formatted_memories + def _deep_search( + self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int + ) -> list: + logger.error("waiting to be implemented") + return [] + def _fine_search( self, search_req: APISearchRequest, user_context: UserContext, - ) -> list: + ) -> list[str]: """ Fine-grained search with query enhancement. @@ -226,11 +232,14 @@ def _fine_search( Returns: List of enhanced search results """ + if FINE_STRATEGY == FineStrategy.DEEP_SEARCH: + return self._deep_search( + search_req=search_req, user_context=user_context, max_thinking_depth=3 + ) + target_session_id = search_req.session_id or "default_session" search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - searcher = self.mem_scheduler.searcher - info = { "user_id": search_req.user_id, "session_id": target_session_id, @@ -238,7 +247,7 @@ def _fine_search( } # Fine retrieve - fast_retrieved_memories = searcher.retrieve( + raw_retrieved_memories = self.searcher.retrieve( query=search_req.query, user_name=user_context.mem_cube_id, top_k=search_req.top_k, @@ -250,8 +259,8 @@ def _fine_search( ) # Post retrieve - fast_memories = searcher.post_retrieve( - retrieved_results=fast_retrieved_memories, + raw_memories = self.searcher.post_retrieve( + retrieved_results=raw_retrieved_memories, top_k=search_req.top_k, user_name=user_context.mem_cube_id, info=info, @@ -260,22 +269,22 @@ def _fine_search( # Enhance with query enhanced_memories, _ = self.mem_scheduler.retriever.enhance_memories_with_query( query_history=[search_req.query], - memories=fast_memories, + memories=raw_memories, ) - if len(enhanced_memories) < len(fast_memories): + if len(enhanced_memories) < len(raw_memories): logger.info( - f"Enhanced memories ({len(enhanced_memories)}) are less than fast memories ({len(fast_memories)}). Recalling for more." + f"Enhanced memories ({len(enhanced_memories)}) are less than raw memories ({len(raw_memories)}). Recalling for more." ) missing_info_hint, trigger = self.mem_scheduler.retriever.recall_for_missing_memories( query=search_req.query, - memories=fast_memories, + memories=raw_memories, ) - retrieval_size = len(fast_memories) - len(enhanced_memories) + retrieval_size = len(raw_memories) - len(enhanced_memories) logger.info(f"Retrieval size: {retrieval_size}") if trigger: logger.info(f"Triggering additional search with hint: {missing_info_hint}") - additional_memories = searcher.search( + additional_memories = self.searcher.search( query=missing_info_hint, user_name=user_context.mem_cube_id, top_k=retrieval_size, @@ -286,7 +295,7 @@ def _fine_search( ) else: logger.info("Not triggering additional search, using fast memories.") - additional_memories = fast_memories[:retrieval_size] + additional_memories = raw_memories[:retrieval_size] enhanced_memories += additional_memories logger.info( diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index 657ceea0f..6ad7f5cdd 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -54,11 +54,11 @@ from memos.memories.activation.kv import KVCacheMemory from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory +from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE if TYPE_CHECKING: - from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher from memos.reranker.http_bge import HTTPBGEReranker @@ -141,14 +141,21 @@ def __init__(self, config: BaseSchedulerConfig): self.auth_config = None self.rabbitmq_config = None - def init_mem_cube(self, mem_cube): + def init_mem_cube( + self, + mem_cube: BaseMemCube, + searcher: Searcher | None = None, + ): self.mem_cube = mem_cube self.text_mem: TreeTextMemory = self.mem_cube.text_mem - self.searcher: Searcher = self.text_mem.get_searcher( - manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", - moscube=False, - ) self.reranker: HTTPBGEReranker = self.text_mem.reranker + if searcher is None: + self.searcher: Searcher = self.text_mem.get_searcher( + manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", + moscube=False, + ) + else: + self.searcher = searcher def initialize_modules( self, diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index 524eab785..8dd51c5bd 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -18,6 +18,7 @@ class FineStrategy(str, Enum): REWRITE = "rewrite" RECREATE = "recreate" + DEEP_SEARCH = "deep_search" FILE_PATH = Path(__file__).absolute() diff --git a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index ac9f9a6d0..b1a304754 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py +++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -150,7 +150,6 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): self.metrics.on_done(label=m.label, mem_cube_id=m.mem_cube_id, now=time.time()) # acknowledge redis messages - if self.use_redis_queue and self.memos_message_queue is not None: for msg in messages: redis_message_id = msg.redis_message_id diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py index 93dd81132..f7e3eac15 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py @@ -79,8 +79,7 @@ def put( def get( self, - user_id: str, - mem_cube_id: str, + stream_key: str, block: bool = True, timeout: float | None = None, batch_size: int | None = None, @@ -91,8 +90,6 @@ def get( ) return [] - stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id) - # Return empty list if queue does not exist if stream_key not in self.queue_streams: logger.error(f"Stream {stream_key} does not exist when trying to get messages.") diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index fe7e3452c..5e850c8ce 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -5,6 +5,7 @@ the local memos_message_queue functionality in BaseScheduler. """ +import re import time from collections.abc import Callable @@ -165,8 +166,7 @@ def ack_message(self, user_id, mem_cube_id, redis_message_id) -> None: def get( self, - user_id: str, - mem_cube_id: str, + stream_key: str, block: bool = True, timeout: float | None = None, batch_size: int | None = None, @@ -175,8 +175,6 @@ def get( raise ConnectionError("Not connected to Redis. Redis connection not available.") try: - stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id) - # Calculate timeout for Redis redis_timeout = None if block and timeout is not None: @@ -295,17 +293,21 @@ def get_stream_keys(self) -> list[str]: if not self._redis_conn: return [] - try: - # Use match parameter and decode byte strings to regular strings - stream_keys = [ - key.decode("utf-8") if isinstance(key, bytes) else key - for key in self._redis_conn.scan_iter(match=f"{self.stream_key_prefix}:*") - ] - logger.debug(f"get stream_keys from redis: {stream_keys}") - return stream_keys - except Exception as e: - logger.error(f"Failed to list Redis stream keys: {e}") - return [] + # First, get all keys that might match (using Redis pattern matching) + redis_pattern = f"{self.stream_key_prefix}:*" + raw_keys = [ + key.decode("utf-8") if isinstance(key, bytes) else key + for key in self._redis_conn.scan_iter(match=redis_pattern) + ] + + # Second, filter using Python regex to ensure exact prefix match + # Escape special regex characters in the prefix, then add :.* + escaped_prefix = re.escape(self.stream_key_prefix) + regex_pattern = f"^{escaped_prefix}:" + stream_keys = [key for key in raw_keys if re.match(regex_pattern, key)] + + logger.debug(f"get stream_keys from redis: {stream_keys}") + return stream_keys def size(self) -> int: """ diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py index 74f1ad1f8..6d824f4b1 100644 --- a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py @@ -5,8 +5,6 @@ the local memos_message_queue functionality in BaseScheduler. """ -from collections import defaultdict - from memos.log import get_logger from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue @@ -58,9 +56,10 @@ def debug_mode_on(self): def get_stream_keys(self) -> list[str]: if isinstance(self.memos_message_queue, SchedulerRedisQueue): - return self.memos_message_queue.get_stream_keys() + stream_keys = self.memos_message_queue.get_stream_keys() else: - return list(self.memos_message_queue.queue_streams.keys()) + stream_keys = list(self.memos_message_queue.queue_streams.keys()) + return stream_keys def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): """Submit messages to the message queue (either local queue or Redis).""" @@ -98,50 +97,25 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt ) def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: - # Discover all active streams via queue API - streams: list[tuple[str, str]] = [] - stream_keys = self.get_stream_keys() - for stream_key in stream_keys: - try: - parts = stream_key.split(":") - if len(parts) >= 3: - user_id = parts[-2] - mem_cube_id = parts[-1] - streams.append((user_id, mem_cube_id)) - except Exception as e: - logger.debug(f"Failed to parse stream key {stream_key}: {e}") - - if not streams: + + if len(stream_keys) == 0: return [] messages: list[ScheduleMessageItem] = [] - # Group by user: {user_id: [mem_cube_id, ...]} - - streams_by_user: dict[str, list[str]] = defaultdict(list) - for user_id, mem_cube_id in streams: - streams_by_user[user_id].append(mem_cube_id) - - # For each user, fairly consume up to batch_size across their streams - for user_id, mem_cube_ids in streams_by_user.items(): - if not mem_cube_ids: - continue - - # First pass: give each stream an equal share for this user - for mem_cube_id in mem_cube_ids: - fetched = self.memos_message_queue.get( - user_id=user_id, - mem_cube_id=mem_cube_id, - block=False, - batch_size=batch_size, - ) - - messages.extend(fetched) - - logger.info( - f"Fetched {len(messages)} messages across users with per-user batch_size={batch_size}" - ) + for stream_key in stream_keys: + fetched = self.memos_message_queue.get( + stream_key=stream_key, + block=False, + batch_size=batch_size, + ) + + messages.extend(fetched) + if len(messages) > 0: + logger.debug( + f"Fetched {len(messages)} messages across users with per-user batch_size={batch_size}" + ) return messages def clear(self): From 4e1a9919ba8be2f7d1646e7483d458f0319d90b2 Mon Sep 17 00:00:00 2001 From: HarveyXiang Date: Thu, 20 Nov 2025 14:32:56 +0800 Subject: [PATCH 15/18] feat: log support time rotating (#504) * Update API Reference link in README.md * hotfix bug in pref init * feat: log support time rotating * feat: log support time rotating * feat: log support time rotating * feat: delete useless log * feat: delete useless log --------- Co-authored-by: CaralHsi Co-authored-by: yuan.wang Co-authored-by: harvey_xiang --- README.md | 6 +++--- src/memos/log.py | 8 +++++--- .../monitors/dispatcher_monitor.py | 17 +++-------------- src/memos/mem_scheduler/utils/metrics.py | 10 ---------- .../tree_text_memory/retrieve/recall.py | 18 ------------------ 5 files changed, 11 insertions(+), 48 deletions(-) diff --git a/README.md b/README.md index 4269abf0a..cb464b9cd 100644 --- a/README.md +++ b/README.md @@ -53,8 +53,8 @@ Designed for **AI companions, role-playing NPCs, and multi-agent systems**, MemO -Get Free API: [Try API](https://memos-dashboard.openmem.net/quickstart/?source=github) - +Get Free API: [Try API](https://memos-dashboard.openmem.net/quickstart/?source=github) + --- @@ -64,7 +64,7 @@ Get Free API: [Try API](https://memos-dashboard.openmem.net/quickstart/?source=g - **Website**: https://memos.openmem.net/ - **Documentation**: https://memos-docs.openmem.net/home/overview/ -- **API Reference**: https://memos-docs.openmem.net/docs/api/info/ +- **API Reference**: https://memos-docs.openmem.net/api-reference/configure-memos/ - **Source Code**: https://github.com/MemTensor/MemOS ## 📰 News diff --git a/src/memos/log.py b/src/memos/log.py index faa808414..874f2c6a7 100644 --- a/src/memos/log.py +++ b/src/memos/log.py @@ -37,6 +37,7 @@ def _setup_logfile() -> Path: logfile = Path(settings.MEMOS_DIR / "logs" / "memos.log") logfile.parent.mkdir(parents=True, exist_ok=True) logfile.touch(exist_ok=True) + return logfile @@ -195,10 +196,11 @@ def close(self): }, "file": { "level": "DEBUG", - "class": "logging.handlers.RotatingFileHandler", + "class": "logging.handlers.TimedRotatingFileHandler", + "when": "midnight", + "interval": 1, + "backupCount": 3, "filename": _setup_logfile(), - "maxBytes": 1024**2 * 10, - "backupCount": 10, "formatter": "standard", "filters": ["context_filter"], }, diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py index f8e321a82..03221aa7b 100644 --- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py +++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py @@ -135,7 +135,9 @@ def _check_pools_health(self) -> None: pool_info=pool_info, stuck_max_interval=4, ) - logger.info(f"Pool '{name}'. is_healthy: {is_healthy}. pool_info: {pool_info}") + if not is_healthy: + logger.info(f"Pool '{name}'. is_healthy: {is_healthy}. pool_info: {pool_info}") + with self._pool_lock: if is_healthy: pool_info["failure_count"] = 0 @@ -237,20 +239,7 @@ def _check_pool_health( # Log health status with comprehensive information if self.dispatcher: - # Check thread activity - active_threads = sum( - 1 - for t in threading.enumerate() - if t.name.startswith(executor._thread_name_prefix) # pylint: disable=protected-access - ) - - task_count = self.dispatcher.get_running_task_count() max_workers = pool_info.get("max_workers", 0) - stuck_count = len(stuck_tasks) - logger.info( - f"Pool health check passed - {active_threads} active threads, " - f"{task_count} running tasks, pool size: {max_workers}, stuck tasks: {stuck_count}" - ) return True, "" diff --git a/src/memos/mem_scheduler/utils/metrics.py b/src/memos/mem_scheduler/utils/metrics.py index 45abc5b36..0d781c996 100644 --- a/src/memos/mem_scheduler/utils/metrics.py +++ b/src/memos/mem_scheduler/utils/metrics.py @@ -188,12 +188,7 @@ def on_enqueue( inst_rate = (1.0 / max(1e-3, dt)) if dt is not None else 0.0 # first sample: no spike ls.last_enqueue_ts = now ls.backlog += 1 - old_lam = ls.lambda_ewma.value_at(now) ls.lambda_ewma.update(inst_rate, now) - new_lam = ls.lambda_ewma.value_at(now) - logger.info( - f"[DEBUG enqueue] {label} backlog={ls.backlog} dt={dt if dt is not None else '—'}s inst={inst_rate:.3f} λ {old_lam:.3f}→{new_lam:.3f}" - ) self._label_topk[label].add(mem_cube_id) ds = self._get_detail(label, mem_cube_id) if ds: @@ -226,12 +221,7 @@ def on_done( ls.last_done_ts = now if ls.backlog > 0: ls.backlog -= 1 - old_mu = ls.mu_ewma.value_at(now) ls.mu_ewma.update(inst_rate, now) - new_mu = ls.mu_ewma.value_at(now) - logger.info( - f"[DEBUG done] {label} backlog={ls.backlog} dt={dt if dt is not None else '—'}s inst={inst_rate:.3f} μ {old_mu:.3f}→{new_mu:.3f}" - ) ds = self._detail_stats.get((label, mem_cube_id)) if ds: prev_ts_d = ds.last_done_ts diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py index 7bb2eba7e..375048900 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py @@ -104,15 +104,6 @@ def retrieve( # Merge and deduplicate by ID combined = {item.id: item for item in graph_results + vector_results + bm25_results} - graph_ids = {item.id for item in graph_results} - combined_ids = set(combined.keys()) - lost_ids = graph_ids - combined_ids - - if lost_ids: - print( - f"[DEBUG] The following nodes were in graph_results but missing in combined: {lost_ids}" - ) - return list(combined.values()) def retrieve_from_cube( @@ -150,15 +141,6 @@ def retrieve_from_cube( # Merge and deduplicate by ID combined = {item.id: item for item in graph_results} - graph_ids = {item.id for item in graph_results} - combined_ids = set(combined.keys()) - lost_ids = graph_ids - combined_ids - - if lost_ids: - print( - f"[DEBUG] The following nodes were in graph_results but missing in combined: {lost_ids}" - ) - return list(combined.values()) def _graph_recall( From 02284a41203cf3fc39d066e74df2e245b0f45540 Mon Sep 17 00:00:00 2001 From: Wang Daoji <75928131+Wang-Daoji@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:44:24 +0800 Subject: [PATCH 16/18] fix init bug pref (#508) Co-authored-by: yuan.wang --- src/memos/api/handlers/component_init.py | 72 +++++++++++++++--------- 1 file changed, 46 insertions(+), 26 deletions(-) diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py index 78ed13e1f..89e61e79d 100644 --- a/src/memos/api/handlers/component_init.py +++ b/src/memos/api/handlers/component_init.py @@ -117,7 +117,11 @@ def init_server() -> dict[str, Any]: # Create component instances graph_db = GraphStoreFactory.from_config(graph_db_config) - vector_db = VecDBFactory.from_config(vector_db_config) + vector_db = ( + VecDBFactory.from_config(vector_db_config) + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" + else None + ) llm = LLMFactory.from_config(llm_config) embedder = EmbedderFactory.from_config(embedder_config) mem_reader = MemReaderFactory.from_config(mem_reader_config) @@ -154,40 +158,56 @@ def init_server() -> dict[str, Any]: logger.debug("Text memory initialized") # Initialize preference memory components - pref_extractor = ExtractorFactory.from_config( - config_factory=pref_extractor_config, - llm_provider=llm, - embedder=embedder, - vector_db=vector_db, + pref_extractor = ( + ExtractorFactory.from_config( + config_factory=pref_extractor_config, + llm_provider=llm, + embedder=embedder, + vector_db=vector_db, + ) + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" + else None ) - pref_adder = AdderFactory.from_config( - config_factory=pref_adder_config, - llm_provider=llm, - embedder=embedder, - vector_db=vector_db, - text_mem=text_mem, + pref_adder = ( + AdderFactory.from_config( + config_factory=pref_adder_config, + llm_provider=llm, + embedder=embedder, + vector_db=vector_db, + text_mem=text_mem, + ) + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" + else None ) - pref_retriever = RetrieverFactory.from_config( - config_factory=pref_retriever_config, - llm_provider=llm, - embedder=embedder, - reranker=reranker, - vector_db=vector_db, + pref_retriever = ( + RetrieverFactory.from_config( + config_factory=pref_retriever_config, + llm_provider=llm, + embedder=embedder, + reranker=reranker, + vector_db=vector_db, + ) + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" + else None ) logger.debug("Preference memory components initialized") # Initialize preference memory - pref_mem = SimplePreferenceTextMemory( - extractor_llm=llm, - vector_db=vector_db, - embedder=embedder, - reranker=reranker, - extractor=pref_extractor, - adder=pref_adder, - retriever=pref_retriever, + pref_mem = ( + SimplePreferenceTextMemory( + extractor_llm=llm, + vector_db=vector_db, + embedder=embedder, + reranker=reranker, + extractor=pref_extractor, + adder=pref_adder, + retriever=pref_retriever, + ) + if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true" + else None ) logger.debug("Preference memory initialized") From f9a10d130ad2d53c574198195e74c12ffd124585 Mon Sep 17 00:00:00 2001 From: HarveyXiang Date: Thu, 20 Nov 2025 16:57:20 +0800 Subject: [PATCH 17/18] Feat/log rotating (#507) * Update API Reference link in README.md * hotfix bug in pref init * feat: log support time rotating * feat: log support time rotating * feat: log support time rotating * feat: delete useless log * feat: delete useless log * feat: add time log * feat: add time log --------- Co-authored-by: CaralHsi Co-authored-by: yuan.wang Co-authored-by: harvey_xiang --- src/memos/embedders/universal_api.py | 2 +- src/memos/llms/openai.py | 12 ++++++++++-- src/memos/reranker/http_bge.py | 2 +- src/memos/utils.py | 3 ++- 4 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py index fc51cf073..583a02acb 100644 --- a/src/memos/embedders/universal_api.py +++ b/src/memos/embedders/universal_api.py @@ -26,7 +26,7 @@ def __init__(self, config: UniversalAPIEmbedderConfig): else: raise ValueError(f"Embeddings unsupported provider: {self.provider}") - @timed(log=True, log_prefix="EmbedderAPI") + @timed(log=True, log_prefix="model_timed_embedding") def embed(self, texts: list[str]) -> list[list[float]]: if self.provider == "openai" or self.provider == "azure": try: diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py index 1a1703340..da55ae593 100644 --- a/src/memos/llms/openai.py +++ b/src/memos/llms/openai.py @@ -1,5 +1,6 @@ import hashlib import json +import time from collections.abc import Generator from typing import ClassVar @@ -57,12 +58,15 @@ def clear_cache(cls): cls._instances.clear() logger.info("OpenAI LLM instance cache cleared") - @timed(log=True, log_prefix="OpenAI LLM") + @timed(log=True, log_prefix="model_timed_openai") def generate(self, messages: MessageList, **kwargs) -> str: """Generate a response from OpenAI LLM, optionally overriding generation params.""" temperature = kwargs.get("temperature", self.config.temperature) max_tokens = kwargs.get("max_tokens", self.config.max_tokens) top_p = kwargs.get("top_p", self.config.top_p) + start_time = time.time() + logger.info(f"openai model request start, model_name: {self.config.model_name_or_path}") + response = self.client.chat.completions.create( model=self.config.model_name_or_path, messages=messages, @@ -71,7 +75,11 @@ def generate(self, messages: MessageList, **kwargs) -> str: max_tokens=max_tokens, top_p=top_p, ) - logger.info(f"Response from OpenAI: {response.model_dump_json()}") + + end_time = time.time() + logger.info( + f"openai model request end, time_cost: {end_time - start_time:.0f} ms, response from OpenAI: {response.model_dump_json()}" + ) response_content = response.choices[0].message.content if self.config.remove_think_prefix: return remove_thinking_tags(response_content) diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py index 41011df14..db5a51fc2 100644 --- a/src/memos/reranker/http_bge.py +++ b/src/memos/reranker/http_bge.py @@ -119,7 +119,7 @@ def __init__( self.warn_unknown_filter_keys = bool(warn_unknown_filter_keys) self._warned_missing_keys: set[str] = set() - @timed(log=True, log_prefix="RerankerAPI") + @timed(log=True, log_prefix="model_timed_rerank") def rerank( self, query: str, diff --git a/src/memos/utils.py b/src/memos/utils.py index 9ae27bb81..4b1a59834 100644 --- a/src/memos/utils.py +++ b/src/memos/utils.py @@ -17,8 +17,9 @@ def wrapper(*args, **kwargs): start = time.perf_counter() result = fn(*args, **kwargs) elapsed = time.perf_counter() - start + elapsed_ms = elapsed * 1000.0 if log: - logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed:.2f} seconds") + logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms") return result return wrapper From 2f6571f7dfe126d1f0e2e166aa210254d35d236a Mon Sep 17 00:00:00 2001 From: HarveyXiang Date: Fri, 21 Nov 2025 10:24:15 +0800 Subject: [PATCH 18/18] fix: format error (#510) Co-authored-by: harvey_xiang --- src/memos/api/routers/server_router.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index b02569e60..b3b517305 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -98,6 +98,7 @@ def add_memories(add_req: APIADDRequest): # Scheduler API Endpoints # ============================================================================= + @router.get("/scheduler/status", summary="Get scheduler running status") def scheduler_status(user_name: str | None = None): """Get scheduler running status."""