diff --git a/evaluation/scripts/temporal_locomo/models/locomo_eval.py b/evaluation/scripts/temporal_locomo/models/locomo_eval.py
deleted file mode 100644
index f98a481e2..000000000
--- a/evaluation/scripts/temporal_locomo/models/locomo_eval.py
+++ /dev/null
@@ -1,531 +0,0 @@
-import argparse
-import asyncio
-import json
-import os
-import time
-
-import nltk
-import numpy as np
-
-from bert_score import score as bert_score
-from dotenv import load_dotenv
-from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
-from nltk.translate.meteor_score import meteor_score
-from openai import AsyncOpenAI
-from pydantic import BaseModel, Field
-from rouge_score import rouge_scorer
-from scipy.spatial.distance import cosine
-from sentence_transformers import SentenceTransformer
-from tqdm import tqdm
-
-from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules
-from memos.log import get_logger
-
-
-logger = get_logger(__name__)
-
-
-# Download necessary NLTK resources
-try:
- nltk.download("wordnet", quiet=True)
- nltk.download("punkt", quiet=True)
- print("NLTK resources downloaded successfully.")
-except Exception as e:
- print(f"Warning: Failed to download NLTK resources: {e}")
-
-
-try:
- sentence_model_name = "Qwen/Qwen3-Embedding-0.6B"
- sentence_model = SentenceTransformer(sentence_model_name)
- print(f"SentenceTransformer model : {sentence_model_name} loaded successfully.")
-except Exception as e:
- print(f"Failed to load SentenceTransformer model: {e}")
- sentence_model = None
-
-
-class LLMGrade(BaseModel):
- llm_judgment: str = Field(description="CORRECT or WRONG")
- llm_reasoning: str = Field(description="Explain why the answer is correct or incorrect.")
-
-
-async def locomo_grader(llm_client, question: str, gold_answer: str, response: str) -> bool:
- system_prompt = """
- You are an expert grader that determines if answers to questions match a gold standard answer
- """
-
- accuracy_prompt = f"""
- Your task is to label an answer to a question as ’CORRECT’ or ’WRONG’. You will be given the following data:
- (1) a question (posed by one user to another user),
- (2) a ’gold’ (ground truth) answer,
- (3) a generated answer
- which you will score as CORRECT/WRONG.
-
- The point of the question is to ask about something one user should know about the other user based on their prior conversations.
- The gold answer will usually be a concise and short answer that includes the referenced topic, for example:
- Question: Do you remember what I got the last time I went to Hawaii?
- Gold answer: A shell necklace
- The generated answer might be much longer, but you should be generous with your grading - as long as it touches on the same topic as the gold answer, it should be counted as CORRECT.
-
- For time related questions, the gold answer will be a specific date, month, year, etc. The generated answer might be much longer or use relative time references (like "last Tuesday" or "next month"), but you should be generous with your grading - as long as it refers to the same date or time period as the gold answer, it should be counted as CORRECT. Even if the format differs (e.g., "May 7th" vs "7 May"), consider it CORRECT if it's the same date.
-
- Now it’s time for the real question:
- Question: {question}
- Gold answer: {gold_answer}
- Generated answer: {response}
-
- First, provide a short (one sentence) explanation of your reasoning, then finish with CORRECT or WRONG.
- Do NOT include both CORRECT and WRONG in your response, or it will break the evaluation script.
-
- Just return the label CORRECT or WRONG in a json format with the key as "label".
- """
-
- response = await llm_client.chat.completions.create(
- model="gpt-4o-mini",
- messages=[
- {"role": "system", "content": system_prompt},
- {"role": "user", "content": accuracy_prompt},
- ],
- temperature=0,
- )
- message_content = response.choices[0].message.content
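- # The grading prompt asks for a JSON object with a "label" key, so parse the reply directly.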
- label = json.loads(message_content)["label"]
- parsed = LLMGrade(llm_judgment=label, llm_reasoning="")
-
- return parsed.llm_judgment.strip().lower() == "correct"
-
-
-def calculate_rouge_scores(gold_answer, response):
- metrics = {"rouge1_f": 0.0, "rouge2_f": 0.0, "rougeL_f": 0.0}
- try:
- scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
- rouge_scores = scorer.score(gold_answer, response)
- metrics["rouge1_f"] = rouge_scores["rouge1"].fmeasure
- metrics["rouge2_f"] = rouge_scores["rouge2"].fmeasure
- metrics["rougeL_f"] = rouge_scores["rougeL"].fmeasure
- except Exception as e:
- print(f"Failed to calculate ROUGE scores: {e}")
- return metrics
-
-
-def calculate_bleu_scores(gold_tokens, response_tokens):
- metrics = {"bleu1": 0.0, "bleu2": 0.0, "bleu3": 0.0, "bleu4": 0.0}
-
- try:
- smoothing = SmoothingFunction().method1
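- # n-gram weight tuples for BLEU-1 through BLEU-4 (roughly uniform over the first n orders).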
- weights = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (0.33, 0.33, 0.33, 0), (0.25, 0.25, 0.25, 0.25)]
-
- for i, weight in enumerate(weights, 1):
- metrics[f"bleu{i}"] = sentence_bleu(
- [gold_tokens], response_tokens, weights=weight, smoothing_function=smoothing
- )
- except ZeroDivisionError:
- pass
- except Exception as e:
- print(f"Failed to calculate BLEU scores: {e}")
-
- return metrics
-
-
-def calculate_meteor_score(gold_tokens, response_tokens):
- try:
- return meteor_score([gold_tokens], response_tokens)
- except Exception as e:
- print(f"Failed to calculate METEOR score: {e}")
- return 0.0
-
-
-def calculate_semantic_similarity(gold_answer, response):
- global sentence_model
-
- try:
- if sentence_model is None:
- sentence_model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")
-
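- # Embed both texts and return cosine similarity (1 minus cosine distance).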
- gold_embedding = sentence_model.encode([gold_answer], show_progress_bar=False)[0]
- response_embedding = sentence_model.encode([response], show_progress_bar=False)[0]
- return 1 - cosine(gold_embedding, response_embedding)
- except Exception as e:
- print(f"Failed to calculate semantic similarity: {e}")
- return 0.0
-
-
-def calculate_f1_score(gold_tokens, response_tokens):
- try:
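- # Set-based token overlap: precision and recall over unique tokens.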
- gold_set = set(gold_tokens)
- response_set = set(response_tokens)
-
- if len(gold_set) == 0 or len(response_set) == 0:
- return 0.0
-
- precision = len(gold_set.intersection(response_set)) / len(response_set)
- recall = len(gold_set.intersection(response_set)) / len(gold_set)
-
- if precision + recall > 0:
- return 2 * precision * recall / (precision + recall)
- return 0.0
- except Exception as e:
- print(f"Failed to calculate F1 score: {e}")
- return 0.0
-
-
-def calculate_nlp_metrics(gold_answer, response, context, options=None):
- if options is None:
- options = ["lexical", "semantic"]
-
- gold_answer = str(gold_answer) if gold_answer is not None else ""
- response = str(response) if response is not None else ""
-
- metrics = {"context_tokens": len(nltk.word_tokenize(context)) if context else 0}
-
- if "lexical" in options:
- gold_tokens = nltk.word_tokenize(gold_answer.lower())
- response_tokens = nltk.word_tokenize(response.lower())
-
- metrics["lexical"] = {}
- metrics["lexical"]["f1"] = calculate_f1_score(gold_tokens, response_tokens)
- metrics["lexical"].update(calculate_rouge_scores(gold_answer, response))
- metrics["lexical"].update(calculate_bleu_scores(gold_tokens, response_tokens))
- metrics["lexical"]["meteor"] = calculate_meteor_score(gold_tokens, response_tokens)
-
- if "semantic" in options:
- metrics["semantic"] = {}
- metrics["semantic"]["similarity"] = calculate_semantic_similarity(gold_answer, response)
- _, _, f1 = bert_score(
- [gold_answer], [response], lang="en", rescale_with_baseline=True, verbose=False
- )
- metrics["semantic"]["bert_f1"] = f1.item() if f1 is not None else 0.0
-
- return metrics
-
-
-def convert_numpy_types(obj):
- if isinstance(obj, np.number):
- return float(obj)
- elif isinstance(obj, dict):
- return {k: convert_numpy_types(v) for k, v in obj.items()}
- elif isinstance(obj, list):
- return [convert_numpy_types(i) for i in obj]
- else:
- return obj
-
-
-async def process_group_responses(
- group_id, group_responses, oai_client, evaluation_options, num_runs: int
-):
- graded_responses = []
-
- # For each response, run the LLM grader num_runs times concurrently with asyncio
- for response in tqdm(group_responses, desc=f"Processing group {group_id}"):
- question = response.get("question")
- answer = response.get("answer")
- ground_truth = response.get("golden_answer")
- category = response.get("category")
-
- context = response.get("search_context", "")
- response_duration_ms = response.get("response_duration_ms", 0.0)
- search_duration_ms = response.get("search_duration_ms", 0.0)
-
- if ground_truth is None:
- continue
-
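- # Grade the same answer num_runs times concurrently so run-to-run judge variance can be measured.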
- grading_tasks = [
- locomo_grader(oai_client, question, ground_truth, answer) for _ in range(num_runs)
- ]
- judgments = await asyncio.gather(*grading_tasks)
- judgments_dict = {f"judgment_{i + 1}": j for i, j in enumerate(judgments)}
-
- nlp_metrics = calculate_nlp_metrics(ground_truth, answer, context, evaluation_options)
-
- graded_response = {
- "question": question,
- "answer": answer,
- "golden_answer": ground_truth,
- "category": category,
- "llm_judgments": judgments_dict,
- "nlp_metrics": nlp_metrics,
- "response_duration_ms": response_duration_ms,
- "search_duration_ms": search_duration_ms,
- "total_duration_ms": response_duration_ms + search_duration_ms,
- }
- graded_responses.append(graded_response)
-
- return group_id, graded_responses
-
-
-async def process_single_group(group_id, group_responses, oai_client, evaluation_options, num_runs):
- try:
- start_time = time.time()
- result = await process_group_responses(
- group_id, group_responses, oai_client, evaluation_options, num_runs
- )
- end_time = time.time()
- elapsed_time = round(end_time - start_time, 2)
- print(f"Group {group_id} processed in {elapsed_time} seconds")
- return result
- except Exception as e:
- logger.error(f"Error processing group {group_id}: {e}", exc_info=True)
- return group_id, []
-
-
-class LocomoEvaluator(LocomoEvalModelModules):
- def __init__(self, args):
- # Initialize base class to populate self.frame, self.version, etc.
- super().__init__(args=args)
-
- self.evaluation_options = getattr(args, "evaluation_options", ["lexical", "semantic"])
- self.num_runs = getattr(args, "num_runs", 1)
- self.max_workers = getattr(args, "workers", 4)
-
- load_dotenv()
- self.oai_client = AsyncOpenAI(
- api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL")
- )
-
- def _load_response_data(self):
- """
- Load response data from the response path file.
-
- Returns:
- dict: The loaded response data
- """
- with open(self.response_path) as file:
- return json.load(file)
-
- def _load_existing_evaluation_results(self):
- """
- Attempt to load existing evaluation results from the judged path.
- If the file doesn't exist or there's an error loading it, return an empty dict.
-
- Returns:
- dict: Existing evaluation results or empty dict if none available
- """
- all_grades = {}
- try:
- if os.path.exists(self.judged_path):
- with open(self.judged_path) as f:
- all_grades = json.load(f)
- print(f"Loaded existing evaluation results from {self.judged_path}")
- except Exception as e:
- print(f"Error loading existing evaluation results: {e}")
-
- return all_grades
-
- def _create_evaluation_tasks(self, locomo_responses, all_grades, num_users):
- """
- Create evaluation tasks for groups that haven't been evaluated yet.
-
- Args:
- locomo_responses (dict): The loaded response data
- all_grades (dict): Existing evaluation results
- num_users (int): Number of user groups to process
-
- Returns:
- tuple: (tasks list, active users count)
- """
- tasks = []
- active_users = 0
-
- for group_idx in range(num_users):
- group_id = f"locomo_exp_user_{group_idx}"
- group_responses = locomo_responses.get(group_id, [])
-
- if not group_responses:
- print(f"No responses found for group {group_id}")
- continue
-
- # Skip groups that already have evaluation results
- if all_grades.get(group_id):
- print(f"Skipping group {group_id} as it already has evaluation results")
- active_users += 1
- continue
-
- active_users += 1
- tasks.append(
- process_single_group(
- group_id=group_id,
- group_responses=group_responses,
- oai_client=self.oai_client,
- evaluation_options=self.evaluation_options,
- num_runs=self.num_runs,
- )
- )
-
- return tasks, active_users
-
- async def _process_tasks(self, tasks):
- """
- Process evaluation tasks with concurrency control.
-
- Args:
- tasks (list): List of tasks to process
-
- Returns:
- list: Results from processing all tasks
- """
- if not tasks:
- return []
-
- semaphore = asyncio.Semaphore(self.max_workers)
-
- async def limited_task(task):
- """Helper function to limit concurrent task execution"""
- async with semaphore:
- return await task
-
- limited_tasks = [limited_task(task) for task in tasks]
- return await asyncio.gather(*limited_tasks)
-
- def _calculate_scores(self, all_grades):
- """
- Calculate evaluation scores based on all grades.
-
- Args:
- all_grades (dict): The complete evaluation results
-
- Returns:
- tuple: (run_scores, evaluated_count)
- """
- run_scores = []
- evaluated_count = 0
-
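- # Compute one accuracy value per grading run, aggregated over every group and response.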
- if self.num_runs > 0:
- for i in range(1, self.num_runs + 1):
- judgment_key = f"judgment_{i}"
- current_run_correct_count = 0
- current_run_total_count = 0
-
- for group in all_grades.values():
- for response in group:
- if judgment_key in response["llm_judgments"]:
- if response["llm_judgments"][judgment_key]:
- current_run_correct_count += 1
- current_run_total_count += 1
-
- if current_run_total_count > 0:
- run_accuracy = current_run_correct_count / current_run_total_count
- run_scores.append(run_accuracy)
-
- evaluated_count = current_run_total_count
-
- return run_scores, evaluated_count
-
- def _report_scores(self, run_scores, evaluated_count):
- """
- Report evaluation scores to the console.
-
- Args:
- run_scores (list): List of accuracy scores for each run
- evaluated_count (int): Number of evaluated responses
- """
- if evaluated_count > 0:
- mean_of_scores = np.mean(run_scores)
- std_of_scores = np.std(run_scores)
- print(f"LLM-as-a-Judge Mean Score: {mean_of_scores:.4f}")
- print(f"LLM-as-a-Judge Standard Deviation: {std_of_scores:.4f}")
- print(
- f"(Calculated from {self.num_runs} separate runs over {evaluated_count} questions)"
- )
- print(f"Individual run scores: {[round(s, 4) for s in run_scores]}")
- else:
- print("No responses were evaluated")
- print("LLM-as-a-Judge score: N/A (0/0)")
-
- def _save_results(self, all_grades):
- """
- Save evaluation results to the judged path file.
-
- Args:
- all_grades (dict): The complete evaluation results to save
- """
- all_grades = convert_numpy_types(all_grades)
- with open(self.judged_path, "w") as f:
- json.dump(all_grades, f, indent=2)
- print(f"Saved detailed evaluation results to {self.judged_path}")
-
- async def run(self):
- """
- Main execution method for the LoCoMo evaluation process.
- This method orchestrates the entire evaluation workflow:
- 1. Loads existing evaluation results if available
- 2. Processes only groups that haven't been evaluated yet
- 3. Calculates and reports final evaluation scores
- """
- print(
- f"\n=== Starting LoCoMo evaluation for {self.frame} (version: {self.version}) with {self.num_runs} run(s) per question ==="
- )
- print(f"Using {self.max_workers} concurrent workers for processing groups")
-
- # Load response data and existing evaluation results
- locomo_responses = self._load_response_data()
- all_grades = self._load_existing_evaluation_results()
-
- # Count total responses for reporting
- num_users = 10
- total_responses_count = sum(
- len(locomo_responses.get(f"locomo_exp_user_{i}", [])) for i in range(num_users)
- )
- print(f"Found {total_responses_count} total responses across {num_users} users to evaluate")
-
- # Create tasks only for groups that haven't been evaluated yet
- tasks, active_users = self._create_evaluation_tasks(locomo_responses, all_grades, num_users)
- print(
- f"Starting evaluation of {len(tasks)} user groups with responses (out of {active_users} active users)"
- )
-
- # Process tasks and update all_grades with results
- if tasks:
- group_results = await self._process_tasks(tasks)
- for group_id, graded_responses in group_results:
- all_grades[group_id] = graded_responses
-
- print("\n=== Evaluation Complete: Calculating final scores ===")
-
- # Calculate and report scores
- run_scores, evaluated_count = self._calculate_scores(all_grades)
- self._report_scores(run_scores, evaluated_count)
-
- # Save results
- self._save_results(all_grades)
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--lib",
- type=str,
- default="memos_scheduler",
- choices=["zep", "memos", "memos_scheduler", "mem0", "mem0_graph", "langmem", "openai"],
- help="Specify the memory framework (zep, memos, memos_scheduler, mem0, mem0_graph, langmem, or openai)",
- )
- parser.add_argument(
- "--version",
- type=str,
- default="v1.0.1",
- help="Version identifier for loading results (e.g., v1.0.1)",
- )
- parser.add_argument(
- "--num_runs",
- type=int,
- default=3,
- help="Number of times to run the LLM grader for each question",
- )
- parser.add_argument("--evaluation_options", nargs="+", default=["lexical", "semantic"])
- parser.add_argument(
- "--workers", type=int, default=10, help="Number of concurrent workers for processing groups"
- )
- cli_args = parser.parse_args()
-
- # Build args for evaluator
- class Args:
- def __init__(self, cli_args):
- self.frame = cli_args.lib
- self.version = cli_args.version
- self.workers = cli_args.workers
- self.num_runs = cli_args.num_runs
- self.evaluation_options = cli_args.evaluation_options
- self.top_k = 20
- self.scheduler_flag = True
-
- args = Args(cli_args)
- evaluator = LocomoEvaluator(args=args)
- asyncio.run(evaluator.run())
diff --git a/evaluation/scripts/temporal_locomo/models/locomo_ingestion.py b/evaluation/scripts/temporal_locomo/models/locomo_ingestion.py
deleted file mode 100644
index b45ec3d61..000000000
--- a/evaluation/scripts/temporal_locomo/models/locomo_ingestion.py
+++ /dev/null
@@ -1,303 +0,0 @@
-import concurrent.futures
-import sys
-import time
-import traceback
-
-from datetime import datetime, timezone
-from pathlib import Path
-
-from tqdm import tqdm
-
-from evaluation.scripts.temporal_locomo.modules.constants import (
- MEM0_GRAPH_MODEL,
- MEM0_MODEL,
- MEMOS_MODEL,
- MEMOS_SCHEDULER_MODEL,
- ZEP_MODEL,
-)
-from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules
-from memos.log import get_logger
-
-
-FILE_PATH = Path(__file__).absolute()
-BASE_DIR = FILE_PATH.parent.parent.parent
-sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
-
-logger = get_logger(__name__)
-
-
-class LocomoIngestor(LocomoEvalModelModules):
- def __init__(self, args):
- super().__init__(args=args)
-
- def ingest_session(self, client, session, frame, metadata, revised_client=None):
- session_date = metadata["session_date"]
- date_format = "%I:%M %p on %d %B, %Y UTC"
- session_dt = datetime.strptime(session_date, date_format).replace(tzinfo=timezone.utc)
- iso_date = session_dt.isoformat()
- conv_id = metadata["conv_id"]
- conv_id = "locomo_exp_user_" + str(conv_id)
- dt = datetime.fromisoformat(iso_date)
- timestamp = int(dt.timestamp())
- print(f"Processing conv {conv_id}, session {metadata['session_key']}")
- start_time = time.time()
- print_once = True # Print example only once per session
-
- if frame == ZEP_MODEL:
- for chat in tqdm(session, desc=f"{metadata['session_key']}"):
- data = chat.get("speaker") + ": " + chat.get("text")
-
- # Print example only once per session
- if print_once:
- print({"context": data, "conv_id": conv_id, "created_at": iso_date})
- print_once = False
-
- # Check if the group exists, if not create it
- groups = client.group.get_all_groups()
- groups = dict(groups)["groups"]
- exist_ids = [gp.group_id for gp in groups]
- if conv_id not in exist_ids:
- client.group.add(group_id=conv_id)
-
- # Add the message to the group
- client.graph.add(
- data=data,
- type="message",
- created_at=iso_date,
- group_id=conv_id,
- )
-
- elif frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
- messages = []
- messages_reverse = []
-
- for chat in tqdm(session, desc=f"{metadata['session_key']}"):
- data = chat.get("speaker") + ": " + chat.get("text")
-
- if chat.get("speaker") == metadata["speaker_a"]:
- messages.append({"role": "user", "content": data, "chat_time": iso_date})
- messages_reverse.append(
- {"role": "assistant", "content": data, "chat_time": iso_date}
- )
- elif chat.get("speaker") == metadata["speaker_b"]:
- messages.append({"role": "assistant", "content": data, "chat_time": iso_date})
- messages_reverse.append(
- {"role": "user", "content": data, "chat_time": iso_date}
- )
- else:
- raise ValueError(
- f"Unknown speaker {chat.get('speaker')} in session {metadata['session_key']}"
- )
-
- # Print example only once per session
- if print_once:
- print({"context": data, "conv_id": conv_id, "created_at": iso_date})
- print_once = False
-
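- # Ingest the conversation twice, once from each speaker's perspective (user/assistant roles swapped).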
- speaker_a_user_id = conv_id + "_speaker_a"
- speaker_b_user_id = conv_id + "_speaker_b"
-
- client.add(
- messages=messages,
- user_id=speaker_a_user_id,
- )
-
- revised_client.add(
- messages=messages_reverse,
- user_id=speaker_b_user_id,
- )
- print(f"Added messages for {speaker_a_user_id} and {speaker_b_user_id} successfully.")
-
- elif frame in [MEM0_MODEL, MEM0_GRAPH_MODEL]:
- print(f"Processing mem0 ingestion for {metadata['session_key']}")
- messages = []
- messages_reverse = []
-
- for chat in tqdm(session, desc=f"{metadata['session_key']}"):
- data = chat.get("speaker") + ": " + chat.get("text")
-
- if chat.get("speaker") == metadata["speaker_a"]:
- messages.append({"role": "user", "content": data})
- messages_reverse.append({"role": "assistant", "content": data})
- elif chat.get("speaker") == metadata["speaker_b"]:
- messages.append({"role": "assistant", "content": data})
- messages_reverse.append({"role": "user", "content": data})
- else:
- raise ValueError(
- f"Unknown speaker {chat.get('speaker')} in session {metadata['session_key']}"
- )
-
- # Print example only once per session
- if print_once:
- print({"context": data, "conv_id": conv_id, "created_at": iso_date})
- print_once = False
-
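- # mem0 ingests the dialogue in pairs (one user turn plus the following assistant turn).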
- for i in range(0, len(messages), 2):
- batch_messages = messages[i : i + 2]
- batch_messages_reverse = messages_reverse[i : i + 2]
-
- if frame == MEM0_MODEL:
- client.add(
- messages=batch_messages,
- timestamp=timestamp,
- user_id=metadata["speaker_a_user_id"],
- version="v2",
- )
- client.add(
- messages=batch_messages_reverse,
- timestamp=timestamp,
- user_id=metadata["speaker_b_user_id"],
- version="v2",
- )
-
- elif frame == MEM0_GRAPH_MODEL:
- client.add(
- messages=batch_messages,
- timestamp=timestamp,
- user_id=metadata["speaker_a_user_id"],
- output_format="v1.1",
- version="v2",
- enable_graph=True,
- )
- client.add(
- messages=batch_messages_reverse,
- timestamp=timestamp,
- user_id=metadata["speaker_b_user_id"],
- output_format="v1.1",
- version="v2",
- enable_graph=True,
- )
-
- end_time = time.time()
- elapsed_time = round(end_time - start_time, 2)
-
- return elapsed_time
-
- def process_user_for_ingestion(self, conv_id, frame, locomo_df, version, num_workers=1):
- try:
- # Check if locomo_df is empty or doesn't have the required columns
- if locomo_df.empty or "conversation" not in locomo_df.columns:
- logger.warning(
- f"Skipping user {conv_id}: locomo_df is empty or missing 'conversation' column"
- )
- return 0
-
- conversation = locomo_df["conversation"].iloc[conv_id]
- max_session_count = 35
- start_time = time.time()
- total_session_time = 0
- valid_sessions = 0
-
- revised_client = None
- if frame == ZEP_MODEL:
- client = self.get_client_for_ingestion(frame=frame, user_id=None, version="default")
- elif frame in [MEM0_MODEL, MEM0_GRAPH_MODEL]:
- client = self.get_client_for_ingestion(frame=frame, user_id=None, version="default")
- client.delete_all(user_id=f"locomo_exp_user_{conv_id}")
- client.delete_all(user_id=f"{conversation.get('speaker_a')}_{conv_id}")
- client.delete_all(user_id=f"{conversation.get('speaker_b')}_{conv_id}")
- elif frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
- conv_id = "locomo_exp_user_" + str(conv_id)
- speaker_a_user_id = conv_id + "_speaker_a"
- speaker_b_user_id = conv_id + "_speaker_b"
-
- client = self.get_client_for_ingestion(
- frame=frame, user_id=speaker_a_user_id, version=version
- )
- revised_client = self.get_client_for_ingestion(
- frame=frame, user_id=speaker_b_user_id, version=version
- )
- else:
- raise NotImplementedError()
-
- sessions_to_process = []
- for session_idx in tqdm(range(max_session_count), desc=f"process_user {conv_id}"):
- session_key = f"session_{session_idx}"
- session = conversation.get(session_key)
- if session is None:
- continue
-
- metadata = {
- "session_date": conversation.get(f"session_{session_idx}_date_time") + " UTC",
- "speaker_a": conversation.get("speaker_a"),
- "speaker_b": conversation.get("speaker_b"),
- "speaker_a_user_id": f"{conversation.get('speaker_a')}_{conv_id}",
- "speaker_b_user_id": f"{conversation.get('speaker_b')}_{conv_id}",
- "conv_id": conv_id,
- "session_key": session_key,
- }
- sessions_to_process.append((session, metadata))
- valid_sessions += 1
-
- print(
- f"Processing {valid_sessions} sessions for user {conv_id} with {num_workers} workers"
- )
- with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor:
- futures = {
- executor.submit(
- self.ingest_session, client, session, frame, metadata, revised_client
- ): metadata["session_key"]
- for session, metadata in sessions_to_process
- }
-
- for future in concurrent.futures.as_completed(futures):
- session_key = futures[future]
- try:
- session_time = future.result()
- total_session_time += session_time
- print(f"User {conv_id}, {session_key} processed in {session_time} seconds")
- except Exception as e:
- print(f"Error processing user {conv_id}, session {session_key}: {e!s}")
-
- end_time = time.time()
- elapsed_time = round(end_time - start_time, 2)
- print(f"User {conv_id} processed successfully in {elapsed_time} seconds")
-
- return elapsed_time
-
- except Exception as e:
- return f"Error processing user {conv_id}: {e!s}. Exception: {traceback.format_exc()}"
-
- def run_ingestion(self):
- frame = self.frame
- version = self.version
- num_workers = self.workers
-
- num_users = 10
- start_time = time.time()
- total_time = 0
-
- print(
- f"Starting processing for {num_users} users in serial mode,"
- f" each user using {num_workers} workers for sessions..."
- )
-
- for user_id in range(num_users):
- try:
- result = self.process_user_for_ingestion(
- user_id, frame, self.locomo_df, version, num_workers
- )
- if isinstance(result, float):
- total_time += result
- else:
- print(result)
- except Exception as e:
- print(
- f"Error processing user {user_id}: {e!s}. Traceback: {traceback.format_exc()}"
- )
-
- if num_users > 0:
- average_time = total_time / num_users
- minutes = int(average_time // 60)
- seconds = int(average_time % 60)
- average_time_formatted = f"{minutes} minutes and {seconds} seconds"
- print(
- f"The frame {frame} processed {num_users} users, averaging {average_time_formatted} per user."
- )
-
- end_time = time.time()
- elapsed_time = round(end_time - start_time, 2)
- minutes = int(elapsed_time // 60)
- seconds = int(elapsed_time % 60)
- elapsed_time = f"{minutes} minutes and {seconds} seconds"
- print(f"Total processing time: {elapsed_time}.")
diff --git a/evaluation/scripts/temporal_locomo/models/locomo_metric.py b/evaluation/scripts/temporal_locomo/models/locomo_metric.py
deleted file mode 100644
index 532fe2e14..000000000
--- a/evaluation/scripts/temporal_locomo/models/locomo_metric.py
+++ /dev/null
@@ -1,390 +0,0 @@
-import argparse
-import json
-
-import numpy as np
-import pandas as pd
-
-from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules
-
-
-# Mapping from LoCoMo category IDs to human-readable category names
-category_mapping = {
- "4": "single hop",
- "1": "multi hop",
- "2": "temporal reasoning",
- "3": "open domain",
-}
-
-
-def calculate_scores(data):
- category_scores = {}
- category_question_count = {}
-
- overall_metrics = {
- "lexical": {
- m: []
- for m in [
- "f1",
- "rouge1_f",
- "rouge2_f",
- "rougeL_f",
- "bleu1",
- "bleu2",
- "bleu3",
- "bleu4",
- "meteor",
- ]
- },
- "semantic": {m: [] for m in ["bert_f1", "similarity"]},
- "context_tokens": [],
- "duration": {
- m: [] for m in ["response_duration_ms", "search_duration_ms", "total_duration_ms"]
- },
- }
-
- category_metrics = {}
- user_metrics = {}
-
- total_questions = 0
-
- all_judgment_keys = set()
- judgment_run_scores = {}
-
- for _user, questions in data.items():
- for question in questions:
- if "llm_judgments" in question:
- all_judgment_keys.update(question["llm_judgments"].keys())
-
- for key in all_judgment_keys:
- judgment_run_scores[key] = []
-
- for user, questions in data.items():
- user_total = 0
-
- # Initialize user_metrics with each judgment run
- user_metrics[user] = {
- "total": 0,
- "llm_judge_score": 0,
- "llm_judge_std": 0,
- "judgment_run_scores": {key: [] for key in all_judgment_keys},
- "lexical": {m: [] for m in overall_metrics["lexical"]},
- "semantic": {m: [] for m in overall_metrics["semantic"]},
- "context_tokens": [],
- "duration": {m: [] for m in overall_metrics["duration"]},
- }
-
- for question in questions:
- total_questions += 1
- user_total += 1
-
- if "llm_judgments" in question:
- for judgment_key, judgment_value in question["llm_judgments"].items():
- score = 1 if judgment_value else 0
- judgment_run_scores[judgment_key].append(score)
- user_metrics[user]["judgment_run_scores"][judgment_key].append(score)
-
- category = question["category"]
- if category not in category_scores:
- category_scores[category] = {
- "total": 0,
- "category_name": category_mapping.get(str(category), "Unknown"),
- "judgment_run_scores": {key: [] for key in all_judgment_keys},
- }
- category_metrics[category] = {
- "lexical": {m: [] for m in overall_metrics["lexical"]},
- "semantic": {m: [] for m in overall_metrics["semantic"]},
- "context_tokens": [],
- "duration": {m: [] for m in overall_metrics["duration"]},
- }
- category_question_count[category] = 0
-
- category_scores[category]["total"] += 1
- category_question_count[category] += 1
-
- if "llm_judgments" in question:
- for judgment_key, judgment_value in question["llm_judgments"].items():
- score = 1 if judgment_value else 0
- category_scores[category]["judgment_run_scores"][judgment_key].append(score)
-
- nlp = question.get("nlp_metrics", {})
- for metric in overall_metrics["lexical"]:
- v = nlp.get("lexical", {}).get(metric)
- if v is not None:
- overall_metrics["lexical"][metric].append(v)
- category_metrics[category]["lexical"][metric].append(v)
- user_metrics[user]["lexical"][metric].append(v)
-
- for metric in overall_metrics["semantic"]:
- v = nlp.get("semantic", {}).get(metric)
- if v is not None:
- overall_metrics["semantic"][metric].append(v)
- category_metrics[category]["semantic"][metric].append(v)
- user_metrics[user]["semantic"][metric].append(v)
-
- ct = nlp.get("context_tokens")
- if ct is not None:
- overall_metrics["context_tokens"].append(ct)
- category_metrics[category]["context_tokens"].append(ct)
- user_metrics[user]["context_tokens"].append(ct)
-
- for metric in overall_metrics["duration"]:
- v = question.get(metric)
- if v is not None:
- overall_metrics["duration"][metric].append(v)
- category_metrics[category]["duration"][metric].append(v)
- user_metrics[user]["duration"][metric].append(v)
-
- user_metrics[user]["total"] = user_total
-
- judgment_avgs = []
- for _judgment_key, scores in user_metrics[user]["judgment_run_scores"].items():
- if scores:
- avg = np.mean(scores)
- judgment_avgs.append(avg)
-
- user_metrics[user]["llm_judge_score"] = np.mean(judgment_avgs) if judgment_avgs else 0.0
- user_metrics[user]["llm_judge_std"] = (
- np.std(judgment_avgs) if len(judgment_avgs) > 1 else 0.0
- )
-
- for group in ["lexical", "semantic"]:
- for metric in user_metrics[user][group]:
- values = user_metrics[user][group][metric]
- user_metrics[user][group][metric] = np.mean(values) if values else 0.0
-
- user_metrics[user]["context_tokens"] = (
- np.mean(user_metrics[user]["context_tokens"])
- if user_metrics[user]["context_tokens"]
- else 0.0
- )
-
- duration_metrics = list(user_metrics[user]["duration"].keys())
- for metric in duration_metrics:
- values = user_metrics[user]["duration"][metric]
- if values:
- user_metrics[user]["duration"][metric] = np.mean(values)
- user_metrics[user]["duration"][f"{metric}_p50"] = np.percentile(values, 50)
- user_metrics[user]["duration"][f"{metric}_p95"] = np.percentile(values, 95)
- else:
- user_metrics[user]["duration"][metric] = 0.0
- user_metrics[user]["duration"][f"{metric}_p50"] = 0.0
- user_metrics[user]["duration"][f"{metric}_p95"] = 0.0
-
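- # Overall LLM-judge score: mean accuracy per grading run, then mean and std across runs.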
- judgment_run_averages = []
- for _judgment_key, scores in judgment_run_scores.items():
- if scores:
- judgment_run_averages.append(np.mean(scores))
-
- llm_judge_score = np.mean(judgment_run_averages) if judgment_run_averages else 0.0
- llm_judge_std = np.std(judgment_run_averages) if len(judgment_run_averages) > 1 else 0.0
-
- category_overall_scores = {}
- for category, score_data in category_scores.items():
- category_judgment_avgs = []
- for _judgment_key, scores in score_data["judgment_run_scores"].items():
- if scores:
- category_judgment_avgs.append(np.mean(scores))
-
- category_overall_scores[category] = {
- "category_name": score_data["category_name"],
- "llm_judge_score": np.mean(category_judgment_avgs) if category_judgment_avgs else 0.0,
- "llm_judge_std": np.std(category_judgment_avgs)
- if len(category_judgment_avgs) > 1
- else 0.0,
- "total": score_data["total"],
- "lexical": {},
- "semantic": {},
- "duration": {},
- "context_tokens": 0.0,
- }
-
- for group in ["lexical", "semantic"]:
- for metric in category_metrics[category][group]:
- values = category_metrics[category][group][metric]
- category_overall_scores[category][group][metric] = (
- np.mean(values) if values else 0.0
- )
-
- category_overall_scores[category]["context_tokens"] = (
- np.mean(category_metrics[category]["context_tokens"])
- if category_metrics[category]["context_tokens"]
- else 0.0
- )
-
- # Calculate mean and percentiles for category duration metrics
- duration_metrics = list(
- category_metrics[category]["duration"].keys()
- ) # Create a list of keys first
- for metric in duration_metrics:
- values = category_metrics[category]["duration"][metric]
- if values:
- category_overall_scores[category]["duration"][metric] = np.mean(values)
- # Add P50 (median) and P95 percentiles
- category_overall_scores[category]["duration"][f"{metric}_p50"] = np.percentile(
- values, 50
- )
- category_overall_scores[category]["duration"][f"{metric}_p95"] = np.percentile(
- values, 95
- )
- else:
- category_overall_scores[category]["duration"][metric] = 0.0
- category_overall_scores[category]["duration"][f"{metric}_p50"] = 0.0
- category_overall_scores[category]["duration"][f"{metric}_p95"] = 0.0
-
- # calculate overall scores
- overall_metric_averages = {
- "llm_judge_score": llm_judge_score,
- "llm_judge_std": llm_judge_std,
- "lexical": {},
- "semantic": {},
- "context_tokens": 0.0,
- "duration": {},
- }
-
- for group in ["lexical", "semantic"]:
- for metric in overall_metrics[group]:
- values = overall_metrics[group][metric]
- overall_metric_averages[group][metric] = np.mean(values) if values else 0.0
-
- overall_metric_averages["context_tokens"] = (
- np.mean(overall_metrics["context_tokens"]) if overall_metrics["context_tokens"] else 0.0
- )
-
- duration_metrics = list(overall_metrics["duration"].keys())
- for metric in duration_metrics:
- values = overall_metrics["duration"][metric]
- if values:
- overall_metric_averages["duration"][metric] = np.mean(values)
- overall_metric_averages["duration"][f"{metric}_p50"] = np.percentile(values, 50)
- overall_metric_averages["duration"][f"{metric}_p95"] = np.percentile(values, 95)
- else:
- overall_metric_averages["duration"][metric] = 0.0
- overall_metric_averages["duration"][f"{metric}_p50"] = 0.0
- overall_metric_averages["duration"][f"{metric}_p95"] = 0.0
-
- return {
- "metrics": overall_metric_averages,
- "category_scores": category_overall_scores,
- "user_scores": user_metrics,
- }
-
-
-def save_to_excel(results, output_path):
- # Create a combined data structure for metrics and category scores
- combined_data = []
-
- # Process overall metrics - flatten nested structures
- overall_row = {"category": "overall"}
- overall_row["llm_judge_score"] = results["metrics"]["llm_judge_score"]
- overall_row["llm_judge_std"] = results["metrics"]["llm_judge_std"]
-
- # Add all lexical metrics
- for metric, value in results["metrics"]["lexical"].items():
- overall_row[metric] = value
-
- # Add all semantic metrics
- for metric, value in results["metrics"]["semantic"].items():
- overall_row[metric] = value
-
- # Add context tokens
- overall_row["context_tokens"] = results["metrics"]["context_tokens"]
-
- # Add all duration metrics, including percentiles
- for metric, value in results["metrics"]["duration"].items():
- overall_row[metric] = value
-
- combined_data.append(overall_row)
-
- # Process category scores - flatten nested structures
- for _, scores in results["category_scores"].items():
- category_row = {"category": scores["category_name"]}
- category_row["llm_judge_score"] = scores["llm_judge_score"]
- category_row["llm_judge_std"] = scores["llm_judge_std"]
-
- # Add all lexical metrics
- for metric, value in scores["lexical"].items():
- category_row[metric] = value
-
- # Add all semantic metrics
- for metric, value in scores["semantic"].items():
- category_row[metric] = value
-
- # Add context tokens
- category_row["context_tokens"] = scores["context_tokens"]
-
- # Add all duration metrics, including percentiles
- for metric, value in scores["duration"].items():
- category_row[metric] = value
-
- combined_data.append(category_row)
-
- # Create DataFrame and save to Excel
- combined_df = pd.DataFrame(combined_data)
-
- # Create a pandas Excel writer
- with pd.ExcelWriter(output_path) as writer:
- combined_df.to_excel(writer, sheet_name="Metrics", index=False)
-
- print(f"Excel file saved to: {output_path}")
-
-
-class LocomoMetric(LocomoEvalModelModules):
- def __init__(self, args):
- super().__init__(args=args)
-
- def run(self):
- with open(self.judged_path) as file:
- data = json.load(file)
-
- results = calculate_scores(data)
-
- with open(self.grade_path, "w") as outfile:
- json.dump(results, outfile, indent=4)
-
- save_to_excel(results, self.excel_path)
-
- print("\n=== Metric Calculation Complete ===")
- total = sum(results["category_scores"][cat]["total"] for cat in results["category_scores"])
- print(
- f"LLM-as-a-Judge score: {results['metrics']['llm_judge_score']:.4f} ± {results['metrics']['llm_judge_std']:.4f}"
- )
- print(f"Total questions evaluated: {total}")
-
- print("\n=== Duration Metrics ===")
- for metric in ["response_duration_ms", "search_duration_ms", "total_duration_ms"]:
- print(f"{metric} (avg): {results['metrics']['duration'][metric]:.2f} ms")
- print(f"{metric} (P50): {results['metrics']['duration'][f'{metric}_p50']:.2f} ms")
- print(f"{metric} (P95): {results['metrics']['duration'][f'{metric}_p95']:.2f} ms")
-
- print(f"\nResults have been written to {self.grade_path}")
- print(f"Excel report has been saved to {self.excel_path}")
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--lib",
- type=str,
- default="memos_scheduler",
- choices=["zep", "memos", "memos_scheduler", "mem0", "mem0_graph", "langmem", "openai"],
- help="Specify the memory framework (zep, memos, memos_scheduler, mem0, mem0_graph, langmem, or openai)",
- )
- parser.add_argument(
- "--version",
- type=str,
- default="v1.0.1",
- help="Version identifier for loading results (e.g., v1.0.1)",
- )
- cli_args = parser.parse_args()
-
- # Build a minimal args namespace compatible with LocomoEvalModelModules
- class _Args:
- def __init__(self, frame, version):
- self.frame = frame
- self.version = version
- self.workers = 1
- self.top_k = 20
- self.scheduler_flag = True
-
- args = _Args(frame=cli_args.lib, version=cli_args.version)
- LocomoMetric(args=args).run()
diff --git a/evaluation/scripts/temporal_locomo/models/locomo_processor.py b/evaluation/scripts/temporal_locomo/models/locomo_processor.py
deleted file mode 100644
index 7cec6f5af..000000000
--- a/evaluation/scripts/temporal_locomo/models/locomo_processor.py
+++ /dev/null
@@ -1,370 +0,0 @@
-import json
-import sys
-
-from collections import defaultdict
-from concurrent.futures import ThreadPoolExecutor, as_completed
-from pathlib import Path
-from time import time
-
-from dotenv import load_dotenv
-
-from evaluation.scripts.temporal_locomo.modules.constants import (
- MEMOS_SCHEDULER_MODEL,
-)
-from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules
-from evaluation.scripts.temporal_locomo.modules.prompts import (
- SEARCH_PROMPT_MEM0,
- SEARCH_PROMPT_MEM0_GRAPH,
- SEARCH_PROMPT_MEMOS,
- SEARCH_PROMPT_ZEP,
-)
-from evaluation.scripts.temporal_locomo.modules.schemas import ContextUpdateMethod, RecordingCase
-from evaluation.scripts.temporal_locomo.modules.utils import save_evaluation_cases
-from memos.log import get_logger
-
-
-FILE_PATH = Path(__file__).absolute()
-BASE_DIR = FILE_PATH.parent.parent.parent
-sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
-
-logger = get_logger(__name__)
-
-
-class LocomoProcessor(LocomoEvalModelModules):
- """
- A class for handling conversational memory management across different memory frameworks.
- Supports multiple memory backends (zep, mem0, memos, etc.) for searching and retrieving
- relevant context to generate conversational responses.
- """
-
- def __init__(self, args):
- """Initialize the LocomoChatter with path configurations and templates"""
- super().__init__(args=args)
-
- # Template definitions for different memory frameworks
- self.search_template_zep = SEARCH_PROMPT_ZEP
-
- self.search_template_mem0 = SEARCH_PROMPT_MEM0
-
- self.search_template_mem0_graph = SEARCH_PROMPT_MEM0_GRAPH
-
- self.search_template_memos = SEARCH_PROMPT_MEMOS
-
- self.processed_data_dir = self.result_dir / "processed_data"
-
- def update_context(self, conv_id, method, **kwargs):
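- # CHAT_HISTORY appends the new Q/A pair to the cached context; other methods overwrite the cache with the latest search context.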
- if method == ContextUpdateMethod.CHAT_HISTORY:
- if "query" not in kwargs or "answer" not in kwargs:
- raise ValueError("query and answer are required for the CHAT_HISTORY update method")
- new_context = f"User: {kwargs['query']}\nAssistant: {kwargs['answer']}\n\n"
- if self.pre_context_cache[conv_id] is None:
- self.pre_context_cache[conv_id] = ""
- self.pre_context_cache[conv_id] += new_context
- else:
- if "cur_context" not in kwargs:
- raise ValueError("cur_context is required for non-CHAT_HISTORY update methods")
- cur_context = kwargs["cur_context"]
- self.pre_context_cache[conv_id] = cur_context
-
- def eval_context(self, context, query, gold_answer, oai_client):
- can_answer_start = time()
- can_answer = self.analyze_context_answerability(context, query, gold_answer, oai_client)
- can_answer_duration_ms = (time() - can_answer_start) * 1000
- # Update global stats
- with self.stats_lock:
- self.stats[self.frame][self.version]["memory_stats"]["total_queries"] += 1
- if can_answer:
- self.stats[self.frame][self.version]["memory_stats"]["can_answer_count"] += 1
- else:
- self.stats[self.frame][self.version]["memory_stats"]["cannot_answer_count"] += 1
- total_queries = self.stats[self.frame][self.version]["memory_stats"]["total_queries"]
- can_answer_count = self.stats[self.frame][self.version]["memory_stats"][
- "can_answer_count"
- ]
- hit_rate = (can_answer_count / total_queries * 100) if total_queries > 0 else 0
- self.stats[self.frame][self.version]["memory_stats"]["answer_hit_rate"] = hit_rate
- self.stats[self.frame][self.version]["memory_stats"]["can_answer_duration_ms"] = (
- can_answer_duration_ms
- )
- self.save_stats()
- return can_answer, can_answer_duration_ms
-
- def _update_stats_and_context(
- self,
- *,
- conv_id,
- frame,
- version,
- conv_stats,
- conv_stats_path,
- query,
- answer,
- gold_answer,
- cur_context,
- can_answer,
- ):
- """
- Update conversation statistics and context.
-
- Args:
- conv_id: Conversation ID
- frame: Model frame
- version: Model version
- conv_stats: Conversation statistics dictionary
- conv_stats_path: Path to save conversation statistics
- query: User query
- answer: Generated answer
- gold_answer: Golden answer
- cur_context: Current context
- can_answer: Whether the context can answer the query
- """
- # Update conversation stats
- conv_stats["total_queries"] += 1
- conv_stats["response_count"] += 1
- if frame == MEMOS_SCHEDULER_MODEL:
- if can_answer:
- conv_stats["can_answer_count"] += 1
- else:
- conv_stats["cannot_answer_count"] += 1
- if conv_stats["total_queries"] > 0:
- conv_stats["answer_hit_rate"] = (
- conv_stats["can_answer_count"] / conv_stats["total_queries"]
- ) * 100
-
- # Persist conversation stats snapshot
- self._save_conv_stats(conv_id, frame, version, conv_stats, conv_stats_path)
-
- logger.info(f"Processed question: {query[:100]}")
- logger.info(f"Answer: {answer[:100]}")
-
- # Update pre-context cache with current context
- with self.stats_lock:
- if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY:
- self.update_context(
- conv_id=conv_id,
- method=self.context_update_method,
- query=query,
- answer=answer,
- )
- else:
- self.update_context(
- conv_id=conv_id,
- method=self.context_update_method,
- cur_context=cur_context,
- )
-
- self.print_eval_info()
-
- def _process_single_qa(
- self,
- qa,
- *,
- client,
- reversed_client,
- metadata,
- frame,
- version,
- conv_id,
- conv_stats_path,
- oai_client,
- top_k,
- conv_stats,
- ):
- query = qa.get("question")
- gold_answer = qa.get("answer")
- qa_category = qa.get("category")
- if qa_category == 5:
- return None
-
- # Search
- cur_context, search_duration_ms = self.search_query(
- client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k
- )
- if not cur_context:
- logger.warning(f"No context found for query: {query[:100]}")
- cur_context = ""
-
- if self.context_update_method == ContextUpdateMethod.CURRENT_CONTEXT:
- context = cur_context
- else:
- # Context answerability analysis (memos_scheduler only)
- if self.pre_context_cache[conv_id] is None:
- # Update pre-context cache with current context and return
- if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY:
- answer_from_cur_context = self.locomo_response(
- frame, oai_client, cur_context, query
- )
- self.update_context(
- conv_id=conv_id,
- method=self.context_update_method,
- query=query,
- answer=answer_from_cur_context,
- )
- else:
- self.update_context(
- conv_id=conv_id,
- method=self.context_update_method,
- cur_context=cur_context,
- )
- return None
- else:
- context = self.pre_context_cache[conv_id]
-
- # Generate answer
- answer_start = time()
- answer = self.locomo_response(frame, oai_client, context, query)
- response_duration_ms = (time() - answer_start) * 1000
-
- can_answer, can_answer_duration_ms = self.eval_context(
- context=context, query=query, gold_answer=gold_answer, oai_client=oai_client
- )
-
- # Record case for memos_scheduler
- try:
- recording_case = RecordingCase(
- conv_id=conv_id,
- query=query,
- answer=answer,
- context=cur_context,
- pre_context=self.pre_context_cache[conv_id],
- can_answer=can_answer,
- can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}",
- search_duration_ms=search_duration_ms,
- can_answer_duration_ms=can_answer_duration_ms,
- response_duration_ms=response_duration_ms,
- category=int(qa_category) if qa_category is not None else None,
- golden_answer=str(qa.get("answer", "")),
- )
- if can_answer:
- self.can_answer_cases.append(recording_case)
- else:
- self.cannot_answer_cases.append(recording_case)
- except Exception as e:
- logger.error(f"Error creating RecordingCase: {e}")
- print(f"Error creating RecordingCase: {e}")
- logger.error(f"QA data: {qa}")
- print(f"QA data: {qa}")
- logger.error(f"Query: {query}")
- logger.error(f"Answer: {answer}")
- logger.error(
- f"Golden answer (raw): {qa.get('answer')} (type: {type(qa.get('answer'))})"
- )
- logger.error(f"Category: {qa_category} (type: {type(qa_category)})")
- logger.error(f"Can answer: {can_answer}")
- raise e
-
- if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY:
- answer_from_cur_context = self.locomo_response(frame, oai_client, cur_context, query)
- answer = answer_from_cur_context
- # Update conversation stats and context
- self._update_stats_and_context(
- conv_id=conv_id,
- frame=frame,
- version=version,
- conv_stats=conv_stats,
- conv_stats_path=conv_stats_path,
- query=query,
- answer=answer,
- gold_answer=gold_answer,
- cur_context=cur_context,
- can_answer=can_answer,
- )
-
- return {
- "question": query,
- "answer": answer,
- "category": qa_category,
- "golden_answer": gold_answer,
- "search_context": cur_context,
- "response_duration_ms": response_duration_ms,
- "search_duration_ms": search_duration_ms,
- "can_answer_duration_ms": can_answer_duration_ms,
- "can_answer": can_answer if frame == MEMOS_SCHEDULER_MODEL else None,
- }
-
- def run_locomo_processing(self, num_users=10):
- load_dotenv()
-
- frame = self.frame
- version = self.version
- num_workers = self.workers
- top_k = self.top_k
-
- # Storage for aggregated results
- all_search_results = defaultdict(list)
- all_response_results = defaultdict(list)
-
- # Prepare arguments for each user processing task
- user_args = [(idx, self.locomo_df, frame, version, top_k) for idx in range(num_users)]
-
- if num_workers > 1:
- # === parallel running ====
- # Use ThreadPoolExecutor for parallel processing
- print(
- f"Starting parallel processing for {num_users} users with {num_workers} workers..."
- )
- with ThreadPoolExecutor(max_workers=num_workers) as executor:
- # Submit all user processing tasks
- future_to_user = {
- executor.submit(self.process_user_wrapper, args): idx
- for idx, args in enumerate(user_args)
- }
-
- # Collect results as they complete
- for future in as_completed(future_to_user):
- idx = future_to_user[future]
- user_search_results, user_response_results, error = future.result()
- if error is not None:
- idx, e, traceback_str = error
- print(f"Error processing user {idx}: {e}. Exception: {traceback_str}")
- else:
- # Aggregate results
- conv_id = f"locomo_exp_user_{idx}"
- all_search_results[conv_id].extend(user_search_results[conv_id])
- all_response_results[conv_id].extend(user_response_results[conv_id])
-
- else:
- # Serial processing
- print(
- f"Starting serial processing for {num_users} users, one user at a time..."
- )
- for idx, args in enumerate(user_args):
- user_search_results, user_response_results, error = self.process_user_wrapper(args)
- if error is not None:
- idx, e, traceback_str = error
- print(f"Error processing user {idx}: {e}. Exception: {traceback_str}")
- else:
- # Aggregate results
- conv_id = f"locomo_exp_user_{idx}"
- all_search_results[conv_id].extend(user_search_results[conv_id])
- all_response_results[conv_id].extend(user_response_results[conv_id])
-
- # Print evaluation information statistics
- self.print_eval_info()
- self.save_stats()
-
- # Save all aggregated results
- with open(self.search_path, "w") as fw:
- json.dump(all_search_results, fw, indent=2)
- print(f"Saved all search results to {self.search_path}")
-
- with open(self.response_path, "w") as fw:
- json.dump(all_response_results, fw, indent=2)
- print(f"Saved all response results to {self.response_path}")
-
- # Save evaluation cases if they exist
- if self.can_answer_cases or self.cannot_answer_cases:
- try:
- saved_files = save_evaluation_cases(
- can_answer_cases=self.can_answer_cases,
- cannot_answer_cases=self.cannot_answer_cases,
- output_dir=self.stats_dir,
- frame=self.frame,
- version=self.version,
- )
- print(f"Saved evaluation cases: {saved_files}")
- except Exception as e:
- logger.error(f"Error saving evaluation cases: {e}")
-
- return dict(all_search_results), dict(all_response_results)
diff --git a/evaluation/scripts/temporal_locomo/models/locomo_processor_w_time_eval.py b/evaluation/scripts/temporal_locomo/models/locomo_processor_w_time_eval.py
deleted file mode 100644
index b909c64e1..000000000
--- a/evaluation/scripts/temporal_locomo/models/locomo_processor_w_time_eval.py
+++ /dev/null
@@ -1,229 +0,0 @@
-import sys
-import time
-
-from pathlib import Path
-from typing import TYPE_CHECKING
-
-from evaluation.scripts.temporal_locomo.models.locomo_processor import LocomoProcessor
-from evaluation.scripts.temporal_locomo.modules.constants import (
- MEMOS_SCHEDULER_MODEL,
-)
-from evaluation.scripts.temporal_locomo.modules.prompts import (
- SEARCH_PROMPT_MEMOS,
-)
-from evaluation.scripts.temporal_locomo.modules.schemas import ContextUpdateMethod, RecordingCase
-from memos.log import get_logger
-
-
-if TYPE_CHECKING:
- from memos.mem_os.main import MOS
-
-FILE_PATH = Path(__file__).absolute()
-BASE_DIR = FILE_PATH.parent.parent.parent
-sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
-
-logger = get_logger(__name__)
-
-
-class LocomoProcessorWithTimeEval(LocomoProcessor):
- def __init__(self, args):
- super().__init__(args=args)
- self.time_eval_mode = getattr(self.args, "time_eval_mode", False)
- assert self.args.frame == MEMOS_SCHEDULER_MODEL
- assert self.context_update_method == ContextUpdateMethod.PRE_CONTEXT
- if self.time_eval_mode:
- logger.warning(
- "time_eval_mode is activated. _process_single_qa is replaced by _process_single_qa_for_time_eval"
- )
- self._process_single_qa = self._process_single_qa_for_time_eval
-
- def memos_scheduler_search(
- self, client, query, conv_id, speaker_a, speaker_b, reversed_client=None, top_k=20
- ):
- # Full MemOS search process; the scheduler-specific parts are skipped here.
- start = time.time()
- client: MOS = client
-
- if not self.scheduler_flag:
- # if not scheduler_flag, search to update working memory
- self.memos_search(client, query, conv_id, speaker_a, speaker_b, reversed_client)
-
- # ========= MemOS Search =========
- # Search for speaker A
- search_a_results = client.search(
- query=query,
- user_id=conv_id + "_speaker_a",
- install_cube_ids=[conv_id + "_speaker_a"],
- top_k=top_k,
- mode="fine",
- internet_search=False,
- moscube=False, # cube for mos introduction
- session_id=None,
- )["text_mem"]
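- # Flatten the nested memory results into a single list of memory strings.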
- search_a_results = [[m.memory for m in one["memories"]] for one in search_a_results]
- search_a_results = [item for sublist in search_a_results for item in sublist]
-
- # Search for speaker B
- search_b_results = client.search(
- query=query,
- user_id=conv_id + "_speaker_b",
- install_cube_ids=[conv_id + "_speaker_b"],
- top_k=top_k,
- mode="fine",
- internet_search=False,
- moscube=False, # cube for mos introduction
- session_id=None,
- )["text_mem"]
- search_b_results = [[m.memory for m in one["memories"]] for one in search_b_results]
- search_b_results = [item for sublist in search_b_results for item in sublist]
-
- speaker_a_context = ""
- for item in search_a_results:
- speaker_a_context += f"{item}\n"
-
- speaker_b_context = ""
- for item in search_b_results:
- speaker_b_context += f"{item}\n"
-
- context = SEARCH_PROMPT_MEMOS.format(
- speaker_1=speaker_a,
- speaker_1_memories=speaker_a_context,
- speaker_2=speaker_b,
- speaker_2_memories=speaker_b_context,
- )
-
- logger.info(f'query: "{query[:100]}", context: "{context[:100]}"')
- duration_ms = (time.time() - start) * 1000
-
- return context, duration_ms
-
- def _process_single_qa_for_time_eval(
- self,
- qa,
- *,
- client,
- reversed_client,
- metadata,
- frame,
- version,
- conv_id,
- conv_stats_path,
- oai_client,
- top_k,
- conv_stats,
- ):
- query = qa.get("question")
- gold_answer = qa.get("answer")
- qa_category = qa.get("category")
- if qa_category == 5:
- return None
-
- # Per-QA flow:
- # 1. run the MemOS search and answer the query from the cached pre-context;
- # 2. check whether the pre-context can answer the query, then let the scheduler update the working memory.
-
- # Search
- assert self.args.frame == MEMOS_SCHEDULER_MODEL
- cur_context, search_duration_ms = self.search_query(
- client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k
- )
- if not cur_context:
- logger.warning(f"No context found for query: {query[:100]}")
- cur_context = ""
-
- # Context answerability analysis (memos_scheduler only)
- if self.pre_context_cache[conv_id] is None:
- # Update pre-context cache with current context and return
- self.update_context(
- conv_id=conv_id,
- method=self.context_update_method,
- cur_context=cur_context,
- )
-
- # ========= MemOS Scheduler update =========
- _ = client.mem_scheduler.update_working_memory_for_eval(
- query=query, user_id=conv_id + "_speaker_a", top_k=top_k
- )
-
- _ = client.mem_scheduler.update_working_memory_for_eval(
- query=query, user_id=conv_id + "_speaker_b", top_k=top_k
- )
- return None
-
- context = self.pre_context_cache[conv_id]
-
- # Generate answer
- answer_start = time.time()
- answer = self.locomo_response(frame, oai_client, context, query)
- response_duration_ms = (time.time() - answer_start) * 1000
-
- can_answer, can_answer_duration_ms = self.eval_context(
- context=context, query=query, gold_answer=gold_answer, oai_client=oai_client
- )
-
- # Record case for memos_scheduler
- try:
- recording_case = RecordingCase(
- conv_id=conv_id,
- query=query,
- answer=answer,
- context=cur_context,
- pre_context=self.pre_context_cache[conv_id],
- can_answer=can_answer,
- can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}",
- search_duration_ms=search_duration_ms,
- can_answer_duration_ms=can_answer_duration_ms,
- response_duration_ms=response_duration_ms,
- category=int(qa_category) if qa_category is not None else None,
- golden_answer=str(qa.get("answer", "")),
- )
- if can_answer:
- self.can_answer_cases.append(recording_case)
- else:
- self.cannot_answer_cases.append(recording_case)
- except Exception as e:
- logger.error(f"Error creating RecordingCase: {e}")
- print(f"Error creating RecordingCase: {e}")
- logger.error(f"QA data: {qa}")
- print(f"QA data: {qa}")
- logger.error(f"Query: {query}")
- logger.error(f"Answer: {answer}")
- logger.error(
- f"Golden answer (raw): {qa.get('answer')} (type: {type(qa.get('answer'))})"
- )
- logger.error(f"Category: {qa_category} (type: {type(qa_category)})")
- logger.error(f"Can answer: {can_answer}")
- raise e
-
- # Update conversation stats and context
- self._update_stats_and_context(
- conv_id=conv_id,
- frame=frame,
- version=version,
- conv_stats=conv_stats,
- conv_stats_path=conv_stats_path,
- query=query,
- answer=answer,
- gold_answer=gold_answer,
- cur_context=cur_context,
- can_answer=can_answer,
- )
- # ========= MemOS Scheduler update =========
- _ = client.mem_scheduler.update_working_memory_for_eval(
- query=query, user_id=conv_id + "_speaker_a", top_k=top_k
- )
-
- _ = client.mem_scheduler.update_working_memory_for_eval(
- query=query, user_id=conv_id + "_speaker_b", top_k=top_k
- )
- return {
- "question": query,
- "answer": answer,
- "category": qa_category,
- "golden_answer": gold_answer,
- "search_context": cur_context,
- "response_duration_ms": response_duration_ms,
- "search_duration_ms": search_duration_ms,
- "can_answer_duration_ms": can_answer_duration_ms,
- "can_answer": can_answer if frame == MEMOS_SCHEDULER_MODEL else None,
- }
diff --git a/evaluation/scripts/temporal_locomo/modules/README.md b/evaluation/scripts/temporal_locomo/modules/README.md
deleted file mode 100644
index 31a274dd0..000000000
--- a/evaluation/scripts/temporal_locomo/modules/README.md
+++ /dev/null
@@ -1,83 +0,0 @@
-# Evaluation Modules
-
-This directory contains the modularized evaluation system for temporal locomo evaluation, organized using inheritance and composition patterns.
-
-## Structure
-
-### Base Classes
-
-- **`base_eval_module.py`**: Contains the `BaseEvalModule` class with common functionality:
- - Statistics management
- - Data loading and processing
- - File I/O operations
- - Basic evaluation methods
-
-### Specialized Modules
-
-- **`client_manager.py`**: Contains the `ClientManager` class for managing different memory framework clients:
- - Zep client management
- - Mem0 client management
- - Memos client management
- - Memos scheduler client management
-
-- **`search_modules.py`**: Contains the `SearchModules` class with all search methods:
- - `mem0_search()`: Mem0 framework search
- - `mem0_graph_search()`: Mem0 graph framework search
- - `memos_search()`: Memos framework search
- - `memos_scheduler_search()`: Memos scheduler framework search
- - `zep_search()`: Zep framework search
-
-- **`locomo_eval_module.py`**: Contains the main `LocomoEvalModule` class that combines all functionality:
- - Inherits from `BaseEvalModule`
- - Uses `ClientManager` for client management
- - Uses `SearchModules` for search operations
- - Provides unified interface for evaluation
-
-## Usage
-
-### Basic Usage
-
-```python
-from modules import LocomoEvalModule
-import argparse
-
-# Create arguments
-args = argparse.Namespace()
-args.frame = 'memos_scheduler'
-args.version = 'v0.2.1'
-args.top_k = 20
-args.workers = 1
-
-# Initialize the evaluation module
-eval_module = LocomoEvalModule(args)
-
-# Use the module
-eval_module.print_eval_info()
-eval_module.save_stats()
-```
-
-### Backward Compatibility
-
-For backward compatibility, existing code can keep using the original `LocomoEvalModelModules` name by importing the new class under that alias:
-
-```python
-from modules import LocomoEvalModule as LocomoEvalModelModules
-```
-
-## Benefits of Modularization
-
-1. **Separation of Concerns**: Each module has a specific responsibility
-2. **Maintainability**: Easier to modify and extend individual components
-3. **Testability**: Each module can be tested independently
-4. **Reusability**: Modules can be reused in different contexts
-5. **Readability**: Code is more organized and easier to understand
-
-## Migration from Original Code
-
-The original `eval_model_modules.py` has been refactored into this modular structure:
-
-- **Original class**: `LocomoEvalModelModules`
-- **New main class**: `LocomoEvalModule`
-- **Backward compatibility**: `LocomoEvalModelModules = LocomoEvalModule`
-
-All existing functionality is preserved, but now organized in a more maintainable structure.
diff --git a/evaluation/scripts/temporal_locomo/modules/__init__.py b/evaluation/scripts/temporal_locomo/modules/__init__.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/evaluation/scripts/temporal_locomo/modules/base_eval_module.py b/evaluation/scripts/temporal_locomo/modules/base_eval_module.py
deleted file mode 100644
index d056745cc..000000000
--- a/evaluation/scripts/temporal_locomo/modules/base_eval_module.py
+++ /dev/null
@@ -1,386 +0,0 @@
-import json
-import os
-import traceback
-
-from collections import defaultdict
-from pathlib import Path
-from threading import Lock
-from typing import TYPE_CHECKING
-
-import pandas as pd
-
-from dotenv import load_dotenv
-
-from memos.configs.mem_scheduler import AuthConfig
-from memos.log import get_logger
-
-from .constants import (
- BASE_DIR,
- MEMOS_SCHEDULER_MODEL,
-)
-from .prompts import (
- CUSTOM_INSTRUCTIONS,
-)
-from .schemas import ContextUpdateMethod
-
-
-if TYPE_CHECKING:
- from .schemas import RecordingCase
-
-
-logger = get_logger(__name__)
-
-
-class BaseEvalModule:
- def __init__(self, args):
- # hyper-parameters
- self.args = args
- self.frame = self.args.frame
- self.version = self.args.version
- self.workers = self.args.workers
- self.top_k = self.args.top_k
-
- # attributes
- self.context_update_method = getattr(
- self.args, "context_update_method", ContextUpdateMethod.PRE_CONTEXT
- )
- self.custom_instructions = CUSTOM_INSTRUCTIONS
- self.data_dir = Path(f"{BASE_DIR}/data")
- self.locomo_df = pd.read_json(f"{self.data_dir}/locomo/locomo10.json")
-
- # Load temporal_locomo dataset if it exists
- self.temporal_locomo_data = None
- temporal_locomo_file = self.data_dir / "temporal_locomo" / "temporal_locomo_qa.json"
- if temporal_locomo_file.exists():
- with open(temporal_locomo_file, encoding="utf-8") as f:
- self.temporal_locomo_data = json.load(f)
- logger.info(
- f"Loaded temporal_locomo dataset with {len(self.temporal_locomo_data)} conversations"
- )
- else:
- logger.warning(f"Temporal locomo dataset not found at {temporal_locomo_file}")
-
- result_dir_prefix = getattr(self.args, "result_dir_prefix", "")
-
- # Configure result dir; if scheduler disabled and using memos scheduler, mark as ablation
- if (
- hasattr(self.args, "scheduler_flag")
- and self.frame == MEMOS_SCHEDULER_MODEL
- and self.args.scheduler_flag is False
- ):
- self.result_dir = Path(
- f"{BASE_DIR}/results/temporal_locomo/{result_dir_prefix}{self.frame}-{self.version}-ablation/"
- )
- else:
- self.result_dir = Path(
- f"{BASE_DIR}/results/temporal_locomo/{result_dir_prefix}{self.frame}-{self.version}/"
- )
-
- if self.context_update_method != ContextUpdateMethod.PRE_CONTEXT:
- self.result_dir = (
- self.result_dir.parent / f"{self.result_dir.name}_{self.context_update_method}"
- )
- self.result_dir.mkdir(parents=True, exist_ok=True)
-
- self.search_path = self.result_dir / f"{self.frame}-{self.version}_search_results.json"
- self.response_path = self.result_dir / f"{self.frame}-{self.version}_responses.json"
- self.judged_path = self.result_dir / f"{self.frame}-{self.version}_judged.json"
- self.grade_path = self.result_dir / f"{self.frame}-{self.version}_grades.json"
- self.excel_path = self.result_dir / f"{self.frame}-{self.version}_metrics.xlsx"
-
- self.ingestion_storage_dir = self.result_dir / "storages"
- self.mos_config_path = Path(f"{BASE_DIR}/configs-example/mos_w_scheduler_config.json")
- self.mem_cube_config_path = Path(f"{BASE_DIR}/configs-example/mem_cube_config.json")
-
- self.openai_api_key = os.getenv("CHAT_MODEL_API_KEY")
- self.openai_base_url = os.getenv("CHAT_MODEL_BASE_URL")
- self.openai_chat_model = os.getenv("CHAT_MODEL")
-
- auth_config_path = Path(f"{BASE_DIR}/scripts/temporal_locomo/eval_auth.json")
- if auth_config_path.exists():
- auth_config = AuthConfig.from_local_config(config_path=auth_config_path)
- print(
- f"✅ Configuration loaded successfully: from local config file {auth_config_path}"
- )
- else:
- # Load .env file first before reading environment variables
- load_dotenv()
- auth_config = AuthConfig.from_local_env()
- print("✅ Configuration loaded successfully: from environment variables")
- self.openai_api_key = auth_config.openai.api_key
- self.openai_base_url = auth_config.openai.base_url
- self.openai_chat_model = auth_config.openai.default_model
-
- self.mos_config_data = json.load(self.mos_config_path.open("r", encoding="utf-8"))
- self.mem_cube_config_data = json.load(self.mem_cube_config_path.open("r", encoding="utf-8"))
-
- # Update LLM authentication information in MOS configuration using dictionary assignment
- self.mos_config_data["mem_reader"]["config"]["llm"]["config"]["api_key"] = (
- auth_config.openai.api_key
- )
- self.mos_config_data["mem_reader"]["config"]["llm"]["config"]["api_base"] = (
- auth_config.openai.base_url
- )
-
- # Update graph database authentication information in memory cube configuration using dictionary assignment
- self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["uri"] = (
- auth_config.graph_db.uri
- )
- self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["user"] = (
- auth_config.graph_db.user
- )
- self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["password"] = (
- auth_config.graph_db.password
- )
- self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["db_name"] = (
- auth_config.graph_db.db_name
- )
- self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["auto_create"] = (
- auth_config.graph_db.auto_create
- )
-
- # Logger initialization
- self.logger = logger
-
- # Statistics tracking with thread safety
- self.stats = {self.frame: {self.version: defaultdict(dict)}}
- self.stats[self.frame][self.version]["memory_stats"] = defaultdict(dict)
- self.stats[self.frame][self.version]["memory_stats"]["total_queries"] = 0
- self.stats[self.frame][self.version]["memory_stats"]["can_answer_count"] = 0
- self.stats[self.frame][self.version]["memory_stats"]["cannot_answer_count"] = 0
- self.stats[self.frame][self.version]["memory_stats"]["answer_hit_rate"] = 0.0
-
- # Initialize memory history for tracking retrieval results
- self.stats_lock = Lock()
- # Reflect CLI flag
- self.scheduler_flag = bool(getattr(self.args, "scheduler_flag", True))
- self.stats_dir = self.result_dir / f"stats/{self.frame}_{self.version}"
- self.stats_dir.mkdir(parents=True, exist_ok=True) # Ensure the directory exists
- self.stats_path = self.stats_dir / "stats.txt"
-
- self.can_answer_cases: list[RecordingCase] = []
- self.cannot_answer_cases: list[RecordingCase] = []
-
- def print_eval_info(self):
- """
- Calculate and print the evaluation information including answer statistics for memory scheduler (thread-safe).
- Shows total queries, can answer count, cannot answer count, and answer hit rate.
- """
- with self.stats_lock:
- # Get statistics
- total_queries = self.stats[self.frame][self.version]["memory_stats"]["total_queries"]
- can_answer_count = self.stats[self.frame][self.version]["memory_stats"][
- "can_answer_count"
- ]
- cannot_answer_count = self.stats[self.frame][self.version]["memory_stats"][
- "cannot_answer_count"
- ]
- hit_rate = self.stats[self.frame][self.version]["memory_stats"]["answer_hit_rate"]
-
- # Print basic statistics
- print(f"Total Queries: {total_queries}")
- logger.info(f"Total Queries: {total_queries}")
-
- print(f"Can Answer Count: {can_answer_count}")
- logger.info(f"Can Answer Count: {can_answer_count}")
-
- print(f"Cannot Answer Count: {cannot_answer_count}")
- logger.info(f"Cannot Answer Count: {cannot_answer_count}")
-
- # Verify count consistency
- if total_queries != (can_answer_count + cannot_answer_count):
- print(
- f"WARNING: Count mismatch! Total ({total_queries}) != Can Answer ({can_answer_count}) + Cannot Answer ({cannot_answer_count})"
- )
- logger.warning(
- f"Count mismatch! Total ({total_queries}) != Can Answer ({can_answer_count}) + Cannot Answer ({cannot_answer_count})"
- )
-
- print(f"Answer Hit Rate: {hit_rate:.2f}% ({can_answer_count}/{total_queries})")
- logger.info(f"Answer Hit Rate: {hit_rate:.2f}% ({can_answer_count}/{total_queries})")
-
- def save_stats(self):
- """
-        Serialize self.stats and save it as JSON to self.stats_path
-        (<result_dir>/stats/<frame>_<version>/stats.txt).
-
-        Handles thread-safe access to the statistics data and JSON serialization
-        of nested defaultdict structures.
- """
- try:
-            # Take a snapshot of the top-level stats dict under the lock;
-            # nested structures are still shared with concurrent writers
-            with self.stats_lock:
-                stats_data = dict(self.stats)
-
- # Helper function to convert defaultdict to regular dict for JSON serialization
- def convert_defaultdict(obj):
- if isinstance(obj, defaultdict):
- return dict(obj)
- return obj
-
- # Debug: Print stats summary before saving
- self.logger.info(f"DEBUG: Saving stats for {self.frame}-{self.version}")
- self.logger.info(f"DEBUG: Stats path: {self.stats_path}")
- self.logger.info(f"DEBUG: Stats data keys: {list(stats_data.keys())}")
- if self.frame in stats_data and self.version in stats_data[self.frame]:
- frame_data = stats_data[self.frame][self.version]
- self.logger.info(f"DEBUG: Memory stats: {frame_data.get('memory_stats', {})}")
- self.logger.info(
- f"DEBUG: Total queries: {frame_data.get('memory_stats', {}).get('total_queries', 0)}"
- )
-
- # Serialize and save the statistics data to file
- with self.stats_path.open("w", encoding="utf-8") as fw:
- json.dump(stats_data, fw, ensure_ascii=False, indent=2, default=convert_defaultdict)
-
- self.logger.info(f"Successfully saved stats to: {self.stats_path}")
- print(f"DEBUG: Stats file created at {self.stats_path}")
-
- except Exception as e:
- self.logger.error(f"Failed to save stats: {e!s}")
- self.logger.error(traceback.format_exc())
- print(f"DEBUG: Error saving stats: {e}")
-
- def get_answer_hit_rate(self):
- """
- Get current answer hit rate statistics.
-
- Returns:
- dict: Hit rate statistics
- """
- with self.stats_lock:
- return {
- "total_queries": self.stats[self.frame][self.version]["memory_stats"][
- "total_queries"
- ],
- "can_answer_count": self.stats[self.frame][self.version]["memory_stats"][
- "can_answer_count"
- ],
- "hit_rate_percentage": self.stats[self.frame][self.version]["memory_stats"][
- "answer_hit_rate"
- ],
- }
-
- def group_and_sort_qa_by_day(self, qa_set, sort_by_evidence):
- """
- Groups QA pairs by day and sorts them chronologically within each day group.
-
-        Args:
-            qa_set (list): List of dictionaries containing QA data with evidence references
-            sort_by_evidence (bool): If True, sort the QA pairs within each day group by
-                their earliest evidence position
-
-        Returns:
-            dict: Dictionary where keys are day strings (e.g., 'D1') and values are
-                lists of QA pairs referencing that day (ordered by evidence position
-                when sort_by_evidence is True)
- """
- # Initialize a dictionary that automatically creates lists for new keys
- day_groups = defaultdict(list)
-
- # Process each QA pair in the input dataset
- for qa in qa_set:
- # Extract all unique days referenced in this QA pair's evidence
- days = set()
- for evidence in qa["evidence"]:
- # Split evidence string (e.g., 'D1:3') into day and position parts
- day = evidence.split(":")[0] # Gets 'D1', 'D2', etc.
- days.add(day)
-
- # Add this QA pair to each day group it references
- for day in days:
- day_groups[day].append(qa)
-
- if sort_by_evidence:
- # Sort QA pairs within each day group by their earliest evidence position
- for day in day_groups:
- # Create list of (qa, position) pairs for proper sorting
- qa_position_pairs = []
-
- for qa in day_groups[day]:
- # Find the earliest evidence position for this day
- earliest_position = None
- for evidence in qa["evidence"]:
- if evidence.startswith(day + ":"):
- try:
- position = int(evidence.split(":")[1])
- if earliest_position is None or position < earliest_position:
- earliest_position = position
- except (IndexError, ValueError):
- # Skip invalid evidence format
- continue
-
- if earliest_position is not None:
- qa_position_pairs.append((qa, earliest_position))
-
- # Sort by evidence position (earliest first)
- qa_position_pairs = sorted(qa_position_pairs, key=lambda x: x[1])
- day_groups[day] = [qa for qa, _ in qa_position_pairs]
-
- return dict(day_groups)
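-
-    # Illustrative sketch of the grouping above, using hypothetical QA entries
-    # (field values are made up; only the "evidence" strings follow the real 'D<day>:<pos>' format):
-    #   qa_set = [
-    #       {"question": "q1", "evidence": ["D1:3"]},
-    #       {"question": "q2", "evidence": ["D1:1", "D2:4"]},
-    #   ]
-    #   group_and_sort_qa_by_day(qa_set, sort_by_evidence=True)
-    #   -> {"D1": [q2, q1], "D2": [q2]}  # q2 precedes q1 in D1 because evidence D1:1 < D1:3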
-
- def convert_locomo_to_temporal_locomo(self, output_dir: str | None = None):
- """
- Convert locomo dataset to temporal_locomo dataset format.
-
- This function processes the original locomo dataset and reorganizes it by days
- with proper chronological ordering within each day group.
-
- Args:
- output_dir: Output directory for the converted dataset.
- Defaults to evaluation/data/temporal_locomo/
-
- Returns:
- str: Path to the converted dataset file
- """
- if output_dir is None:
- output_dir = f"{BASE_DIR}/data/temporal_locomo"
-
- # Create output directory
- os.makedirs(output_dir, exist_ok=True)
-
- # Load original locomo data
- locomo_data = self.locomo_df.to_dict("records")
-
- # Process each conversation
- temporal_data = []
-
- for conv_id, conversation in enumerate(locomo_data):
- logger.info(f"Processing conversation {conv_id + 1}/{len(locomo_data)}")
-
- # Get QA pairs for this conversation
- qa_set = conversation.get("qa", [])
-
- # Group and sort QA pairs by day
- day_groups = self.group_and_sort_qa_by_day(qa_set, sort_by_evidence=False)
-
- # Create temporal structure for this conversation
- temporal_conversation = {"conversation_id": f"locomo_exp_user_{conv_id}", "days": {}}
-
- # Process each day group
- for day, qa_list in day_groups.items():
- temporal_conversation["days"][day] = {
- "day_id": day,
- "qa_pairs": qa_list,
- "total_qa_pairs": len(qa_list),
- }
-
- temporal_data.append(temporal_conversation)
-
- # Save the converted dataset
- output_file = os.path.join(output_dir, "temporal_locomo_qa.json")
- with open(output_file, "w", encoding="utf-8") as f:
- json.dump(temporal_data, f, indent=2, ensure_ascii=False)
-
- logger.info(f"Converted dataset saved to: {output_file}")
- logger.info(f"Total conversations: {len(temporal_data)}")
-
- # Log statistics
- total_qa_pairs = sum(len(conv["qa"]) for conv in locomo_data)
- total_temporal_qa_pairs = sum(
- sum(day_data["total_qa_pairs"] for day_data in conv["days"].values())
- for conv in temporal_data
- )
-
- logger.info(f"Original QA pairs: {total_qa_pairs}")
- logger.info(f"Temporal QA pairs: {total_temporal_qa_pairs}")
- logger.info("QA pairs may be duplicated across days if they reference multiple days")
-
- return output_file
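-
-    # Shape of the file written above (sketch; day ids and counts depend on the data):
-    #   [
-    #     {
-    #       "conversation_id": "locomo_exp_user_0",
-    #       "days": {
-    #         "D1": {"day_id": "D1", "qa_pairs": [...], "total_qa_pairs": 12},
-    #         "D2": {"day_id": "D2", "qa_pairs": [...], "total_qa_pairs": 9}
-    #       }
-    #     },
-    #     ...
-    #   ]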
diff --git a/evaluation/scripts/temporal_locomo/modules/client_manager.py b/evaluation/scripts/temporal_locomo/modules/client_manager.py
deleted file mode 100644
index c5882179e..000000000
--- a/evaluation/scripts/temporal_locomo/modules/client_manager.py
+++ /dev/null
@@ -1,191 +0,0 @@
-"""
-Client management module for handling different memory framework clients.
-"""
-
-import os
-
-from mem0 import MemoryClient
-from zep_cloud.client import Zep
-
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from memos.configs.mem_os import MOSConfig
-from memos.log import get_logger
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_os.main import MOS
-from memos.mem_scheduler.analyzer.scheduler_for_eval import SchedulerForEval
-
-from .base_eval_module import BaseEvalModule
-from .constants import (
- MEM0_GRAPH_MODEL,
- MEM0_MODEL,
- MEMOS_MODEL,
- MEMOS_SCHEDULER_MODEL,
- ZEP_MODEL,
-)
-from .prompts import (
- ANSWER_PROMPT_MEM0,
- ANSWER_PROMPT_MEMOS,
- ANSWER_PROMPT_ZEP,
-)
-
-
-logger = get_logger(__name__)
-
-
-class EvalModuleWithClientManager(BaseEvalModule):
- """
- Manages different memory framework clients for evaluation.
- """
-
- def __init__(self, args):
- super().__init__(args=args)
-
- def get_client_for_ingestion(
- self, frame: str, user_id: str | None = None, version: str = "default"
- ):
- if frame == ZEP_MODEL:
- zep = Zep(api_key=os.getenv("ZEP_API_KEY"), base_url="https://api.getzep.com/api/v2")
- return zep
-
- elif frame in (MEM0_MODEL, MEM0_GRAPH_MODEL):
- mem0 = MemoryClient(api_key=os.getenv("MEM0_API_KEY"))
- mem0.update_project(custom_instructions=self.custom_instructions)
- return mem0
- else:
- if frame not in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
- raise NotImplementedError(f"Unsupported framework: {frame}")
-
- # scheduler is not needed in the ingestion step
- self.mos_config_data["top_k"] = 20
- self.mos_config_data["enable_mem_scheduler"] = False
-
- mos_config = MOSConfig(**self.mos_config_data)
- mos = MOS(mos_config)
- mos.create_user(user_id=user_id)
-
- self.mem_cube_config_data["user_id"] = user_id
- self.mem_cube_config_data["cube_id"] = user_id
- self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["db_name"] = (
- f"{user_id.replace('_', '')}{version}"
- )
- mem_cube_config = GeneralMemCubeConfig.model_validate(self.mem_cube_config_data)
- mem_cube = GeneralMemCube(mem_cube_config)
-
- storage_path = str(self.ingestion_storage_dir / user_id)
- try:
- mem_cube.dump(storage_path)
- except Exception as e:
-                print(f"Memory cube dump skipped ({e!s}); the existing cube will be reused.")
-
- mos.register_mem_cube(
- mem_cube_name_or_path=storage_path,
- mem_cube_id=user_id,
- user_id=user_id,
- )
-
- return mos
-
- def get_client_from_storage(
- self, frame: str, user_id: str | None = None, version: str = "default", top_k: int = 20
- ):
- """
- Get a client instance for the specified memory framework.
-
- Args:
- frame: Memory framework to use (zep, mem0, mem0_graph, memos, memos_scheduler)
- user_id: Unique identifier for the user
- version: Version identifier for result storage
- top_k: Number of results to retrieve in search queries
-
- Returns:
- Client instance for the specified framework
- """
- storage_path = str(self.ingestion_storage_dir / user_id)
-
- if frame == ZEP_MODEL:
- zep = Zep(api_key=os.getenv("ZEP_API_KEY"), base_url="https://api.getzep.com/api/v2")
- return zep
-
-        elif frame in (MEM0_MODEL, MEM0_GRAPH_MODEL):
- mem0 = MemoryClient(api_key=os.getenv("MEM0_API_KEY"))
- return mem0
-
- else:
- if frame not in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
- raise NotImplementedError(f"Unsupported framework: {frame}")
-
- if frame == MEMOS_MODEL:
- self.mos_config_data["enable_mem_scheduler"] = False
-
- self.mos_config_data["top_k"] = top_k
- mos_config = MOSConfig(**self.mos_config_data)
- mos = MOS(mos_config)
- mos.create_user(user_id=user_id)
- mos.register_mem_cube(
- mem_cube_name_or_path=storage_path,
- mem_cube_id=user_id,
- user_id=user_id,
- )
-
- if frame == MEMOS_SCHEDULER_MODEL:
- # Configure memory scheduler
- mos.mem_scheduler.current_mem_cube = mos.mem_cubes[user_id]
- mos.mem_scheduler.current_mem_cube_id = user_id
- mos.mem_scheduler.current_user_id = user_id
-
- # Create SchedulerForEval instance with the same config
- scheduler_for_eval = SchedulerForEval(config=mos.mem_scheduler.config)
- # Initialize with the same modules as the original scheduler
- scheduler_for_eval.initialize_modules(
- chat_llm=mos.mem_scheduler.chat_llm,
- process_llm=mos.mem_scheduler.process_llm,
- db_engine=mos.mem_scheduler.db_engine,
- )
- # Set the same context
- scheduler_for_eval.current_mem_cube = mos.mem_cubes[user_id]
- scheduler_for_eval.current_mem_cube_id = user_id
- scheduler_for_eval.current_user_id = user_id
-
- # set llms to openai api
- mos.chat_llm = mos.mem_reader.llm
- for cube in mos.mem_cubes.values():
- cube.text_mem.dispatcher_llm = mos.mem_reader.llm
- cube.text_mem.extractor_llm = mos.mem_reader.llm
-
- # Replace the original scheduler
- mos.mem_scheduler = scheduler_for_eval
- return mos
-
- def locomo_response(self, frame, llm_client, context: str, question: str) -> str:
- if frame == ZEP_MODEL:
- prompt = ANSWER_PROMPT_ZEP.format(
- context=context,
- question=question,
- )
- elif frame in (MEM0_MODEL, MEM0_GRAPH_MODEL):
- prompt = ANSWER_PROMPT_MEM0.format(
- context=context,
- question=question,
- )
- elif frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
- prompt = ANSWER_PROMPT_MEMOS.format(
- context=context,
- question=question,
- )
- else:
- raise NotImplementedError()
- response = llm_client.chat.completions.create(
- model=self.openai_chat_model,
- messages=[
- {"role": "system", "content": prompt},
- ],
- temperature=0,
- )
-
- result = response.choices[0].message.content or ""
-
- if result == "":
- with self.stats_lock:
- self.stats[self.frame][self.version]["response_stats"]["response_failure"] += 1
- self.stats[self.frame][self.version]["response_stats"]["response_count"] += 1
- return result
diff --git a/evaluation/scripts/temporal_locomo/modules/constants.py b/evaluation/scripts/temporal_locomo/modules/constants.py
deleted file mode 100644
index 51ad7c729..000000000
--- a/evaluation/scripts/temporal_locomo/modules/constants.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import sys
-
-from pathlib import Path
-
-from memos.log import get_logger
-
-
-FILE_PATH = Path(__file__).absolute()
-BASE_DIR = FILE_PATH.parent.parent.parent.parent
-sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
-
-logger = get_logger(__name__)
-
-
-ZEP_MODEL = "zep"
-MEM0_MODEL = "mem0"
-MEM0_GRAPH_MODEL = "mem0_graph"
-MEMOS_MODEL = "memos"
-MEMOS_SCHEDULER_MODEL = "memos_scheduler"
diff --git a/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py b/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py
deleted file mode 100644
index d444ea62c..000000000
--- a/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py
+++ /dev/null
@@ -1,578 +0,0 @@
-import json
-import time
-import traceback
-
-from collections import defaultdict
-from datetime import datetime
-from typing import TYPE_CHECKING
-
-from openai import OpenAI
-from tqdm import tqdm
-
-from memos.log import get_logger
-
-from .client_manager import EvalModuleWithClientManager
-from .constants import (
- MEM0_GRAPH_MODEL,
- MEM0_MODEL,
- MEMOS_MODEL,
- MEMOS_SCHEDULER_MODEL,
- ZEP_MODEL,
-)
-from .prompts import (
- CONTEXT_ANSWERABILITY_PROMPT,
- SEARCH_PROMPT_MEM0,
- SEARCH_PROMPT_MEM0_GRAPH,
- SEARCH_PROMPT_MEMOS,
- SEARCH_PROMPT_ZEP,
-)
-from .utils import filter_memory_data
-
-
-if TYPE_CHECKING:
- from memos.mem_os.main import MOS
-logger = get_logger(__name__)
-
-
-class LocomoEvalModelModules(EvalModuleWithClientManager):
- """
- Contains search methods for different memory frameworks.
- """
-
- def __init__(self, args):
- super().__init__(args=args)
- self.pre_context_cache = {}
-
- def analyze_context_answerability(self, context, query, gold_answer, oai_client):
- """
- Analyze whether the given context can answer the query.
-
-        Args:
-            context: The context string to analyze
-            query: The query string
-            gold_answer: Ground-truth answer used to guide the judgment
-            oai_client: OpenAI client for LLM analysis
-
- Returns:
- bool: True if context can answer the query, False otherwise
- """
- try:
- prompt = CONTEXT_ANSWERABILITY_PROMPT.format(
- context=context, question=query, gold_answer=str(gold_answer)
- )
-
- response = oai_client.chat.completions.create(
- model="gpt-4o-mini",
- messages=[{"role": "user", "content": prompt}],
- temperature=0.1,
- max_tokens=10,
- )
-
- answer = response.choices[0].message.content.strip().upper()
- return answer == "YES"
- except Exception as e:
- logger.error(f"Error analyzing context answerability: {e}")
- return False
-
- def mem0_search(self, client, query, speaker_a_user_id, speaker_b_user_id, top_k=20):
- """
- Search memories using the mem0 framework.
-
- Args:
- client: mem0 client instance
- query: Search query string
- speaker_a_user_id: User ID for first speaker
- speaker_b_user_id: User ID for second speaker
- top_k: Number of results to retrieve
-
- Returns:
- Tuple containing formatted context and search duration in milliseconds
- """
- start = time.time()
- search_speaker_a_results = client.search(
- query=query,
- top_k=top_k,
- user_id=speaker_a_user_id,
- output_format="v1.1",
- version="v2",
- filters={"AND": [{"user_id": f"{speaker_a_user_id}"}, {"run_id": "*"}]},
- )
- search_speaker_b_results = client.search(
- query=query,
- top_k=top_k,
- user_id=speaker_b_user_id,
- output_format="v1.1",
- version="v2",
- filters={"AND": [{"user_id": f"{speaker_b_user_id}"}, {"run_id": "*"}]},
- )
-
- # Format speaker A memories
- search_speaker_a_memory = [
- {
- "memory": memory["memory"],
- "timestamp": memory["created_at"],
- "score": round(memory["score"], 2),
- }
- for memory in search_speaker_a_results["results"]
- ]
-
- search_speaker_a_memory = [
- [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_a_memory]
- ]
-
- # Format speaker B memories
- search_speaker_b_memory = [
- {
- "memory": memory["memory"],
- "timestamp": memory["created_at"],
- "score": round(memory["score"], 2),
- }
- for memory in search_speaker_b_results["results"]
- ]
-
- search_speaker_b_memory = [
- [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_b_memory]
- ]
-
- # Create context using template
- context = SEARCH_PROMPT_MEM0.format(
- speaker_1_user_id=speaker_a_user_id.split("_")[0],
- speaker_1_memories=json.dumps(search_speaker_a_memory, indent=4),
- speaker_2_user_id=speaker_b_user_id.split("_")[0],
- speaker_2_memories=json.dumps(search_speaker_b_memory, indent=4),
- )
-
- duration_ms = (time.time() - start) * 1000
- return context, duration_ms
-
- def memos_search(
- self, client, query, conv_id, speaker_a, speaker_b, reversed_client=None, top_k=20
- ):
- """
- Search memories using the memos framework.
-
- Args:
- client: memos client instance
- query: Search query string
- conv_id: Conversation ID
- speaker_a: First speaker identifier
- speaker_b: Second speaker identifier
-            reversed_client: Client instance for reversed speaker context
-            top_k: Number of memories to retrieve per speaker
-
- Returns:
- Tuple containing formatted context and search duration in milliseconds
- """
- start = time.time()
- # Search memories for speaker A
- search_a_results = client.search(query=query, user_id=conv_id + "_speaker_a")
- filtered_search_a_results = filter_memory_data(search_a_results)["text_mem"][0]["memories"]
- speaker_a_context = ""
- for item in filtered_search_a_results[:top_k]:
- speaker_a_context += f"{item['memory']}\n"
-
- # Search memories for speaker B
- search_b_results = reversed_client.search(
- query=query,
- user_id=conv_id + "_speaker_b",
- )
- filtered_search_b_results = filter_memory_data(search_b_results)["text_mem"][0]["memories"]
- speaker_b_context = ""
- for item in filtered_search_b_results[:top_k]:
- speaker_b_context += f"{item['memory']}\n"
-
- # Create context using template
- context = SEARCH_PROMPT_MEMOS.format(
- speaker_1=speaker_a,
- speaker_1_memories=speaker_a_context,
- speaker_2=speaker_b,
- speaker_2_memories=speaker_b_context,
- )
-
- duration_ms = (time.time() - start) * 1000
- return context, duration_ms
-
- def memos_scheduler_search(
- self, client, query, conv_id, speaker_a, speaker_b, reversed_client=None, top_k=20
- ):
- start = time.time()
- client: MOS = client
-
- if not self.scheduler_flag:
-            # When the scheduler is disabled, run a plain search so the working memory is still updated
- self.memos_search(client, query, conv_id, speaker_a, speaker_b, reversed_client, top_k)
-
- # Search for speaker A
- search_a_results = client.mem_scheduler.search_for_eval(
- query=query,
- user_id=conv_id + "_speaker_a",
- top_k=top_k,
- scheduler_flag=self.scheduler_flag,
- )
-
- # Search for speaker B
- search_b_results = reversed_client.mem_scheduler.search_for_eval(
- query=query,
- user_id=conv_id + "_speaker_b",
- top_k=top_k,
- scheduler_flag=self.scheduler_flag,
- )
-
- speaker_a_context = ""
- for item in search_a_results:
- speaker_a_context += f"{item}\n"
-
- speaker_b_context = ""
- for item in search_b_results:
- speaker_b_context += f"{item}\n"
-
- context = SEARCH_PROMPT_MEMOS.format(
- speaker_1=speaker_a,
- speaker_1_memories=speaker_a_context,
- speaker_2=speaker_b,
- speaker_2_memories=speaker_b_context,
- )
-
-        logger.info(f'query: "{query[:100]}", context: "{context[:100]}"')
- duration_ms = (time.time() - start) * 1000
-
- return context, duration_ms
-
- def mem0_graph_search(self, client, query, speaker_a_user_id, speaker_b_user_id, top_k=20):
- start = time.time()
- search_speaker_a_results = client.search(
- query=query,
- top_k=top_k,
- user_id=speaker_a_user_id,
- output_format="v1.1",
- version="v2",
- enable_graph=True,
- filters={"AND": [{"user_id": f"{speaker_a_user_id}"}, {"run_id": "*"}]},
- )
- search_speaker_b_results = client.search(
- query=query,
- top_k=top_k,
- user_id=speaker_b_user_id,
- output_format="v1.1",
- version="v2",
- enable_graph=True,
- filters={"AND": [{"user_id": f"{speaker_b_user_id}"}, {"run_id": "*"}]},
- )
-
- search_speaker_a_memory = [
- {
- "memory": memory["memory"],
- "timestamp": memory["created_at"],
- "score": round(memory["score"], 2),
- }
- for memory in search_speaker_a_results["results"]
- ]
-
- search_speaker_a_memory = [
- [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_a_memory]
- ]
-
- search_speaker_b_memory = [
- {
- "memory": memory["memory"],
- "timestamp": memory["created_at"],
- "score": round(memory["score"], 2),
- }
- for memory in search_speaker_b_results["results"]
- ]
-
- search_speaker_b_memory = [
- [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_b_memory]
- ]
-
- search_speaker_a_graph = [
- {
- "source": relation["source"],
- "relationship": relation["relationship"],
- "target": relation["target"],
- }
- for relation in search_speaker_a_results["relations"]
- ]
-
- search_speaker_b_graph = [
- {
- "source": relation["source"],
- "relationship": relation["relationship"],
- "target": relation["target"],
- }
- for relation in search_speaker_b_results["relations"]
- ]
- context = SEARCH_PROMPT_MEM0_GRAPH.format(
- speaker_1_user_id=speaker_a_user_id.split("_")[0],
- speaker_1_memories=json.dumps(search_speaker_a_memory, indent=4),
- speaker_1_graph_memories=json.dumps(search_speaker_a_graph, indent=4),
- speaker_2_user_id=speaker_b_user_id.split("_")[0],
- speaker_2_memories=json.dumps(search_speaker_b_memory, indent=4),
- speaker_2_graph_memories=json.dumps(search_speaker_b_graph, indent=4),
- )
- print(query, context)
- duration_ms = (time.time() - start) * 1000
- return context, duration_ms
-
- def zep_search(self, client, query, group_id, top_k=20):
- start = time.time()
- nodes_result = client.graph.search(
- query=query,
- group_id=group_id,
- scope="nodes",
- reranker="rrf",
- limit=top_k,
- )
- edges_result = client.graph.search(
- query=query,
- group_id=group_id,
- scope="edges",
- reranker="cross_encoder",
- limit=top_k,
- )
-
- nodes = nodes_result.nodes
- edges = edges_result.edges
-
- facts = [f" - {edge.fact} (event_time: {edge.valid_at})" for edge in edges]
- entities = [f" - {node.name}: {node.summary}" for node in nodes]
-
- context = SEARCH_PROMPT_ZEP.format(facts="\n".join(facts), entities="\n".join(entities))
-
- duration_ms = (time.time() - start) * 1000
-
- return context, duration_ms
-
- def search_query(self, client, query, metadata, frame, reversed_client=None, top_k=20):
- conv_id = metadata.get("conv_id")
- speaker_a = metadata.get("speaker_a")
- speaker_b = metadata.get("speaker_b")
- speaker_a_user_id = metadata.get("speaker_a_user_id")
- speaker_b_user_id = metadata.get("speaker_b_user_id")
-
- if frame == ZEP_MODEL:
- context, duration_ms = self.zep_search(client, query, conv_id, top_k)
- elif frame == MEM0_MODEL:
- context, duration_ms = self.mem0_search(
- client, query, speaker_a_user_id, speaker_b_user_id, top_k
- )
- elif frame == MEM0_GRAPH_MODEL:
- context, duration_ms = self.mem0_graph_search(
- client, query, speaker_a_user_id, speaker_b_user_id, top_k
- )
- elif frame == MEMOS_MODEL:
- context, duration_ms = self.memos_search(
- client, query, conv_id, speaker_a, speaker_b, reversed_client, top_k
- )
- elif frame == MEMOS_SCHEDULER_MODEL:
- context, duration_ms = self.memos_scheduler_search(
- client, query, conv_id, speaker_a, speaker_b, reversed_client, top_k
- )
- else:
- raise NotImplementedError()
-
- return context, duration_ms
-
- def _initialize_conv_stats(self):
- """Create a fresh statistics dictionary for a conversation."""
- return {
- "total_queries": 0,
- "can_answer_count": 0,
- "cannot_answer_count": 0,
- "answer_hit_rate": 0.0,
- "response_failure": 0,
- "response_count": 0,
- }
-
- def _build_day_groups(self, temporal_conv):
- """Build mapping day_id -> qa_pairs from a temporal conversation dict."""
- day_groups = {}
- for day_id, day_data in temporal_conv.get("days", {}).items():
- day_groups[day_id] = day_data.get("qa_pairs", [])
- return day_groups
-
- def _build_metadata(self, speaker_a, speaker_b, speaker_a_user_id, speaker_b_user_id, conv_id):
- """Assemble metadata for downstream calls."""
- return {
- "speaker_a": speaker_a,
- "speaker_b": speaker_b,
- "speaker_a_user_id": speaker_a_user_id,
- "speaker_b_user_id": speaker_b_user_id,
- "conv_id": conv_id,
- }
-
- def _get_clients(self, frame, speaker_a_user_id, speaker_b_user_id, conv_id, version, top_k):
- """Return (client, reversed_client) according to the target frame."""
- reversed_client = None
- if frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]:
- client = self.get_client_from_storage(frame, speaker_a_user_id, version, top_k=top_k)
- reversed_client = self.get_client_from_storage(
- frame, speaker_b_user_id, version, top_k=top_k
- )
- else:
- client = self.get_client_from_storage(frame, conv_id, version)
- return client, reversed_client
-
- def _save_conv_stats(self, conv_id, frame, version, conv_stats, conv_stats_path):
- """Persist per-conversation stats to disk."""
- conv_stats_data = {
- "conversation_id": conv_id,
- "frame": frame,
- "version": version,
- "statistics": conv_stats,
- "timestamp": str(datetime.now()),
- }
- with open(conv_stats_path, "w") as fw:
- json.dump(conv_stats_data, fw, indent=2, ensure_ascii=False)
- print(f"Saved conversation stats for {conv_id} to {conv_stats_path}")
-
- def _write_user_search_results(self, user_search_path, search_results, conv_id):
- """Write per-user search results to a temporary JSON file."""
- with open(user_search_path, "w") as fw:
- json.dump(dict(search_results), fw, indent=2)
-        print(f"Saved search results for {conv_id} to {user_search_path}")
-
- def process_user(self, conv_id, locomo_df, frame, version, top_k=20):
- user_search_path = self.result_dir / f"tmp/{frame}_locomo_search_results_{conv_id}.json"
- user_search_path.parent.mkdir(exist_ok=True, parents=True)
- search_results = defaultdict(list)
- response_results = defaultdict(list)
- conv_stats_path = self.stats_dir / f"{frame}_{version}_conv_{conv_id}_stats.json"
-
- conversation = locomo_df["conversation"].iloc[conv_id]
- speaker_a = conversation.get("speaker_a", "speaker_a")
- speaker_b = conversation.get("speaker_b", "speaker_b")
-
-        # Use the temporal_locomo dataset (the converted dataset must already be loaded)
- temporal_conv = self.temporal_locomo_data[conv_id]
- conv_id = temporal_conv["conversation_id"]
- speaker_a_user_id = f"{conv_id}_speaker_a"
- speaker_b_user_id = f"{conv_id}_speaker_b"
-
- # Process temporal data by days
- day_groups = {}
- for day_id, day_data in temporal_conv["days"].items():
- day_groups[day_id] = day_data["qa_pairs"]
-
- # Initialize conversation-level statistics
- conv_stats = self._initialize_conv_stats()
-
- metadata = self._build_metadata(
- speaker_a, speaker_b, speaker_a_user_id, speaker_b_user_id, conv_id
- )
-
- client, reversed_client = self._get_clients(
- frame, speaker_a_user_id, speaker_b_user_id, conv_id, version, top_k
- )
-
- oai_client = OpenAI(api_key=self.openai_api_key, base_url=self.openai_base_url)
-
- with self.stats_lock:
- self.pre_context_cache[conv_id] = None
-
- def process_qa(qa):
- return self._process_single_qa(
- qa,
- client=client,
- reversed_client=reversed_client,
- metadata=metadata,
- frame=frame,
- version=version,
- conv_id=conv_id,
- conv_stats_path=conv_stats_path,
- oai_client=oai_client,
- top_k=top_k,
- conv_stats=conv_stats,
- )
-
-        # ===================================
-        conv_stats["theoretical_total_queries"] = 0
-        conv_stats["processing_failure_count"] = 0
-        for day, qa_list in day_groups.items():
-            conv_stats["theoretical_total_queries"] += len(qa_list) - 1
-            print(f"Processing user {conv_id} day {day}")
- for qa in tqdm(qa_list, desc=f"Processing user {conv_id} day {day}"):
- try:
- result = process_qa(qa)
- except Exception as e:
- logger.error(f"Error: {e}. traceback: {traceback.format_exc()}")
- conv_stats["processing_failure_count"] += 1
- continue
- if result:
- context_preview = (
- result["search_context"][:20] + "..."
- if result["search_context"]
- else "No context"
- )
- if "can_answer" in result:
- logger.info("Print can_answer case")
- logger.info(
- {
- "question": result["question"][:100],
- "pre context can answer": result["can_answer"],
- "answer": result["answer"][:100],
- "golden_answer": result["golden_answer"],
- "search_context": context_preview[:100],
- "search_duration_ms": result["search_duration_ms"],
- }
- )
-
- search_results[conv_id].append(
- {
- "question": result["question"],
- "context": result["search_context"],
- "search_duration_ms": result["search_duration_ms"],
- }
- )
- response_results[conv_id].append(result)
-
- logger.warning(
- f"Finished processing user {conv_id} day {day}, data_length: {len(qa_list)}"
- )
-
-        # Persist the per-user search results
-        self._write_user_search_results(user_search_path, search_results, conv_id)
-
- search_durations = []
- for result in response_results[conv_id]:
- if "search_duration_ms" in result:
- search_durations.append(result["search_duration_ms"])
-
- if search_durations:
- avg_search_duration = sum(search_durations) / len(search_durations)
- with self.stats_lock:
-                memory_stats = self.stats[self.frame][self.version]["memory_stats"]
-                if memory_stats.get("avg_search_duration_ms"):
-                    # Simple running two-point average with the previously stored value
-                    memory_stats["avg_search_duration_ms"] = (
-                        memory_stats["avg_search_duration_ms"] + avg_search_duration
-                    ) / 2
-                else:
-                    memory_stats["avg_search_duration_ms"] = avg_search_duration
- print(f"Average search duration: {avg_search_duration:.2f} ms")
-
- # Dump stats after processing each user
- self.save_stats()
-
- return search_results, response_results
-
- def process_user_wrapper(self, args):
- """
- Wraps the process_user function to support parallel execution and error handling.
-
- Args:
- args: Tuple containing parameters for process_user
-
- Returns:
- tuple: Contains user results or error information
- """
- idx, locomo_df, frame, version, top_k = args
- try:
- print(f"Processing user {idx}...")
- user_search_results, user_response_results = self.process_user(
- idx, locomo_df, frame, version, top_k
- )
- return (user_search_results, user_response_results, None)
- except Exception as e:
- return (None, None, (idx, e, traceback.format_exc()))
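-
-    # Sketch of how a caller might drive process_user_wrapper in parallel
-    # (hypothetical driver code; the real entry point lives outside this module):
-    #   from concurrent.futures import ThreadPoolExecutor
-    #   tasks = [
-    #       (idx, self.locomo_df, self.frame, self.version, self.top_k)
-    #       for idx in range(len(self.locomo_df))
-    #   ]
-    #   with ThreadPoolExecutor(max_workers=self.workers) as pool:
-    #       outputs = list(pool.map(self.process_user_wrapper, tasks))
-    #   errors = [err for _, _, err in outputs if err is not None]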
diff --git a/evaluation/scripts/temporal_locomo/modules/prompts.py b/evaluation/scripts/temporal_locomo/modules/prompts.py
deleted file mode 100644
index c88a8ff28..000000000
--- a/evaluation/scripts/temporal_locomo/modules/prompts.py
+++ /dev/null
@@ -1,219 +0,0 @@
-CUSTOM_INSTRUCTIONS = """
-Generate personal memories that follow these guidelines:
-
-1. Each memory should be self-contained with complete context, including:
- - The person's name, do not use "user" while creating memories
- - Personal details (career aspirations, hobbies, life circumstances)
- - Emotional states and reactions
- - Ongoing journeys or future plans
- - Specific dates when events occurred
-
-2. Include meaningful personal narratives focusing on:
- - Identity and self-acceptance journeys
- - Family planning and parenting
- - Creative outlets and hobbies
- - Mental health and self-care activities
- - Career aspirations and education goals
- - Important life events and milestones
-
-3. Make each memory rich with specific details rather than general statements
- - Include timeframes (exact dates when possible)
- - Name specific activities (e.g., "charity race for mental health" rather than just "exercise")
- - Include emotional context and personal growth elements
-
-4. Extract memories only from user messages, not incorporating assistant responses
-
-5. Format each memory as a paragraph with a clear narrative structure that captures the person's experience, challenges, and aspirations
-"""
-
-SEARCH_PROMPT_ZEP = """
-FACTS and ENTITIES represent relevant context to the current conversation.
-
-# These are the most relevant facts for the conversation along with the datetime of the event that the fact refers to.
-If a fact mentions something happening a week ago, then the datetime will be the date time of last week and not the datetime
-of when the fact was stated.
-Timestamps in memories represent the actual time the event occurred, not the time the event was mentioned in a message.
-
-
-{facts}
-
-
-# These are the most relevant entities
-# ENTITY_NAME: entity summary
-
-{entities}
-
-"""
-
-SEARCH_PROMPT_MEM0 = """Memories for user {speaker_1_user_id}:
-
- {speaker_1_memories}
-
- Memories for user {speaker_2_user_id}:
-
- {speaker_2_memories}
-"""
-
-SEARCH_PROMPT_MEM0_GRAPH = """Memories for user {speaker_1_user_id}:
-
- {speaker_1_memories}
-
- Relations for user {speaker_1_user_id}:
-
- {speaker_1_graph_memories}
-
- Memories for user {speaker_2_user_id}:
-
- {speaker_2_memories}
-
- Relations for user {speaker_2_user_id}:
-
- {speaker_2_graph_memories}
-"""
-
-SEARCH_PROMPT_MEMOS = """Memories for user {speaker_1}:
-
- {speaker_1_memories}
-
- Memories for user {speaker_2}:
-
- {speaker_2_memories}
-"""
-
-
-ANSWER_PROMPT_MEM0 = """
- You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories.
-
- # CONTEXT:
- You have access to memories from two speakers in a conversation. These memories contain
- timestamped information that may be relevant to answering the question.
-
- # INSTRUCTIONS:
- 1. Carefully analyze all provided memories from both speakers
- 2. Pay special attention to the timestamps to determine the answer
- 3. If the question asks about a specific event or fact, look for direct evidence in the memories
- 4. If the memories contain contradictory information, prioritize the most recent memory
- 5. If there is a question about time references (like "last year", "two months ago", etc.),
- calculate the actual date based on the memory timestamp. For example, if a memory from
- 4 May 2022 mentions "went to India last year," then the trip occurred in 2021.
- 6. Always convert relative time references to specific dates, months, or years. For example,
- convert "last year" to "2022" or "two months ago" to "March 2023" based on the memory
- timestamp. Ignore the reference while answering the question.
- 7. Focus only on the content of the memories from both speakers. Do not confuse character
- names mentioned in memories with the actual users who created those memories.
- 8. The answer should be less than 5-6 words.
-
- # APPROACH (Think step by step):
- 1. First, examine all memories that contain information related to the question
- 2. Examine the timestamps and content of these memories carefully
- 3. Look for explicit mentions of dates, times, locations, or events that answer the question
- 4. If the answer requires calculation (e.g., converting relative time references), show your work
- 5. Formulate a precise, concise answer based solely on the evidence in the memories
- 6. Double-check that your answer directly addresses the question asked
- 7. Ensure your final answer is specific and avoids vague time references
-
- {context}
-
- Question: {question}
-
- Answer:
- """
-
-
-ANSWER_PROMPT_ZEP = """
- You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories.
-
- # CONTEXT:
- You have access to memories from a conversation. These memories contain
- timestamped information that may be relevant to answering the question.
-
- # INSTRUCTIONS:
- 1. Carefully analyze all provided memories
- 2. Pay special attention to the timestamps to determine the answer
- 3. If the question asks about a specific event or fact, look for direct evidence in the memories
- 4. If the memories contain contradictory information, prioritize the most recent memory
- 5. If there is a question about time references (like "last year", "two months ago", etc.),
- calculate the actual date based on the memory timestamp. For example, if a memory from
- 4 May 2022 mentions "went to India last year," then the trip occurred in 2021.
- 6. Always convert relative time references to specific dates, months, or years. For example,
- convert "last year" to "2022" or "two months ago" to "March 2023" based on the memory
- timestamp. Ignore the reference while answering the question.
- 7. Focus only on the content of the memories. Do not confuse character
- names mentioned in memories with the actual users who created those memories.
- 8. The answer should be less than 5-6 words.
-
- # APPROACH (Think step by step):
- 1. First, examine all memories that contain information related to the question
- 2. Examine the timestamps and content of these memories carefully
- 3. Look for explicit mentions of dates, times, locations, or events that answer the question
- 4. If the answer requires calculation (e.g., converting relative time references), show your work
- 5. Formulate a precise, concise answer based solely on the evidence in the memories
- 6. Double-check that your answer directly addresses the question asked
- 7. Ensure your final answer is specific and avoids vague time references
-
- Context:
-
- {context}
-
- Question: {question}
- Answer:
- """
-
-ANSWER_PROMPT_MEMOS = """
- You are a knowledgeable and helpful AI assistant.
-
- # CONTEXT:
- You have access to memories from two speakers in a conversation. These memories contain
- timestamped information that may be relevant to answering the question.
-
- # INSTRUCTIONS:
- 1. Carefully analyze all provided memories. Synthesize information across different entries if needed to form a complete answer.
- 2. Pay close attention to the timestamps to determine the answer. If memories contain contradictory information, the **most recent memory** is the source of truth.
- 3. If the question asks about a specific event or fact, look for direct evidence in the memories.
- 4. Your answer must be grounded in the memories. However, you may use general world knowledge to interpret or complete information found within a memory (e.g., identifying a landmark mentioned by description).
- 5. If the question involves time references (like "last year", "two months ago", etc.), you **must** calculate the actual date based on the memory's timestamp. For example, if a memory from 4 May 2022 mentions "went to India last year," then the trip occurred in 2021.
- 6. Always convert relative time references to specific dates, months, or years in your final answer.
- 7. Do not confuse character names mentioned in memories with the actual users who created them.
- 8. The answer must be brief (under 5-6 words) and direct, with no extra description.
-
- # APPROACH (Think step by step):
- 1. First, examine all memories that contain information related to the question.
- 2. Synthesize findings from multiple memories if a single entry is insufficient.
- 3. Examine timestamps and content carefully, looking for explicit dates, times, locations, or events.
- 4. If the answer requires calculation (e.g., converting relative time references), perform the calculation.
- 5. Formulate a precise, concise answer based on the evidence from the memories (and allowed world knowledge).
- 6. Double-check that your answer directly addresses the question asked and adheres to all instructions.
- 7. Ensure your final answer is specific and avoids vague time references.
-
- {context}
-
- Question: {question}
-
- Answer:
- """
-
-CONTEXT_ANSWERABILITY_PROMPT = """
-You are an AI assistant that analyzes whether given context can answer a specific question, considering the ground-truth answer.
-
-# TASK:
-Analyze the provided context and determine if it contains sufficient information to answer the given question. Use the provided ground-truth answer to guide your judgment: if the context contains the necessary evidence to derive that answer (explicitly or via direct inference), respond YES; otherwise respond NO.
-
-# INSTRUCTIONS:
-1. Carefully examine the context provided
-2. Identify if the context contains information directly related to the question
-3. Determine if the information is sufficient to provide a complete answer that matches the ground-truth
-4. Consider both explicit mentions and straightforward implications present in the context
-5. Return only "YES" if the context can yield the ground-truth answer, "NO" if it cannot
-
-# CONTEXT:
-{context}
-
-# QUESTION:
-{question}
-
-# GROUND_TRUTH_ANSWER:
-{gold_answer}
-
-# ANALYSIS:
-Can this context answer the question and support the ground-truth answer? (YES/NO):
-"""
diff --git a/evaluation/scripts/temporal_locomo/modules/schemas.py b/evaluation/scripts/temporal_locomo/modules/schemas.py
deleted file mode 100644
index fee89cc62..000000000
--- a/evaluation/scripts/temporal_locomo/modules/schemas.py
+++ /dev/null
@@ -1,161 +0,0 @@
-from typing import Any
-
-from pydantic import BaseModel, Field
-
-
-class ContextUpdateMethod:
- """Enumeration for context update methods"""
-
- PRE_CONTEXT = "pre_context"
- CHAT_HISTORY = "chat_history"
- CURRENT_CONTEXT = "current_context"
-
- @classmethod
- def values(cls):
- """Return a list of all constant values"""
- return [
- getattr(cls, attr)
- for attr in dir(cls)
- if not attr.startswith("_") and isinstance(getattr(cls, attr), str)
- ]
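-
-    # Example (order follows dir(), i.e. alphabetical attribute names):
-    #   ContextUpdateMethod.values()
-    #   -> ["chat_history", "current_context", "pre_context"]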
-
-
-class RecordingCase(BaseModel):
- """
- Data structure for recording evaluation cases in temporal locomo evaluation.
-
- This schema represents a single evaluation case containing conversation history,
- context information, memory data, and evaluation results.
- """
-
- # Conversation identification
- conv_id: str = Field(description="Conversation identifier for this evaluation case")
-
- context: str = Field(
- default="",
- description="Current search context retrieved from memory systems for answering the query",
- )
-
- pre_context: str | None = Field(
- default=None,
- description="Previous context from the last query, used for answerability analysis",
- )
-
- # Query and answer information
- query: str = Field(description="The current question/query being evaluated")
-
- answer: str = Field(description="The generated answer for the query")
-
- # Evaluation metrics
- can_answer: bool | None = Field(
- default=None,
- description="Whether the context can answer the query (only for memos_scheduler frame)",
- )
-
- can_answer_reason: str | None = Field(
- default=None, description="Reasoning for the can_answer decision"
- )
-
- # Additional metadata
- category: int | None = Field(
- default=None, description="Category of the query (1-4, where 5 is filtered out)"
- )
-
- golden_answer: str | None = Field(
- default=None, description="Ground truth answer for evaluation"
- )
-
- search_duration_ms: float | None = Field(
- default=None, description="Time taken for memory search in milliseconds"
- )
-
- response_duration_ms: float | None = Field(
- default=None, description="Time taken for response generation in milliseconds"
- )
-
- can_answer_duration_ms: float | None = Field(
- default=None, description="Time taken for answerability analysis in milliseconds"
- )
-
- def to_dict(self) -> dict[str, Any]:
- """
- Convert the RecordingCase to a dictionary for serialization.
-
- Returns:
- Dict[str, Any]: Dictionary representation of the RecordingCase
- """
- return self.dict()
-
- def to_json(self, indent: int = 2) -> str:
- """
- Convert the RecordingCase to a JSON string.
-
- Args:
- indent: JSON indentation level
-
- Returns:
- str: JSON string representation of the RecordingCase
- """
- return self.json(indent=indent, ensure_ascii=False)
-
- @classmethod
- def from_dict(cls, data: dict[str, Any]) -> "RecordingCase":
- """
- Create a RecordingCase from a dictionary.
-
- Args:
- data: Dictionary containing RecordingCase data
-
- Returns:
- RecordingCase: New instance created from the dictionary
- """
- return cls(**data)
-
- @classmethod
- def from_json(cls, json_str: str) -> "RecordingCase":
- """
- Create a RecordingCase from a JSON string.
-
- Args:
- json_str: JSON string containing RecordingCase data
-
- Returns:
- RecordingCase: New instance created from the JSON string
- """
- import json
-
- data = json.loads(json_str)
- return cls.from_dict(data)
-
- class Config:
- """Pydantic configuration"""
-
- extra = "allow" # Allow additional fields not defined in the schema
- validate_assignment = True # Validate on assignment
- use_enum_values = True # Use enum values instead of enum names
-
-
-class TimeEvalRecordingCase(BaseModel):
- memos_search_duration_ms: float | None = Field(
- default=None, description="Time taken for memory search in milliseconds"
- )
-
- memos_response_duration_ms: float | None = Field(
- default=None, description="Time taken for response generation in milliseconds"
- )
-
- memos_can_answer_duration_ms: float | None = Field(
- default=None, description="Time taken for answerability analysis in milliseconds"
- )
-
- scheduler_search_duration_ms: float | None = Field(
- default=None, description="Time taken for memory search in milliseconds"
- )
-
- scheduler_response_duration_ms: float | None = Field(
- default=None, description="Time taken for response generation in milliseconds"
- )
-
- scheduler_can_answer_duration_ms: float | None = Field(
- default=None, description="Time taken for answerability analysis in milliseconds"
- )
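A minimal usage sketch for the deleted RecordingCase schema (pydantic v1 style, matching the .dict()/.json() helpers above); the field values are made up for illustration and the import refers to the pre-deletion module path.

from evaluation.scripts.temporal_locomo.modules.schemas import RecordingCase

case = RecordingCase(
    conv_id="conv_1",
    query="Where does the user live?",
    answer="The user lives in Seattle.",
    golden_answer="Seattle",
    category=1,
    can_answer=True,
)
payload = case.to_json()                    # serialize for logging or storage
restored = RecordingCase.from_json(payload)  # round-trip back into the model
assert restored.conv_id == case.conv_id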
diff --git a/evaluation/scripts/temporal_locomo/modules/utils.py b/evaluation/scripts/temporal_locomo/modules/utils.py
deleted file mode 100644
index 215bc4256..000000000
--- a/evaluation/scripts/temporal_locomo/modules/utils.py
+++ /dev/null
@@ -1,296 +0,0 @@
-import json
-
-from pathlib import Path
-
-from .schemas import RecordingCase
-
-
-def filter_memory_data(memories_data):
- filtered_data = {}
- for key, value in memories_data.items():
- if key == "text_mem":
- filtered_data[key] = []
- for mem_group in value:
- # Check if it's the new data structure (list of TextualMemoryItem objects)
- if "memories" in mem_group and isinstance(mem_group["memories"], list):
- # New data structure: directly a list of TextualMemoryItem objects
- filtered_memories = []
- for memory_item in mem_group["memories"]:
- # Create filtered dictionary
- filtered_item = {
- "id": memory_item.id,
- "memory": memory_item.memory,
- "metadata": {},
- }
- # Filter metadata, excluding embedding
- if hasattr(memory_item, "metadata") and memory_item.metadata:
- for attr_name in dir(memory_item.metadata):
- if not attr_name.startswith("_") and attr_name != "embedding":
- attr_value = getattr(memory_item.metadata, attr_name)
- if not callable(attr_value):
- filtered_item["metadata"][attr_name] = attr_value
- filtered_memories.append(filtered_item)
-
- filtered_group = {
- "cube_id": mem_group.get("cube_id", ""),
- "memories": filtered_memories,
- }
- filtered_data[key].append(filtered_group)
- else:
- # Old data structure: dictionary with nodes and edges
- filtered_group = {
- "memories": {"nodes": [], "edges": mem_group["memories"].get("edges", [])}
- }
- for node in mem_group["memories"].get("nodes", []):
- filtered_node = {
- "id": node.get("id"),
- "memory": node.get("memory"),
- "metadata": {
- k: v
- for k, v in node.get("metadata", {}).items()
- if k != "embedding"
- },
- }
- filtered_group["memories"]["nodes"].append(filtered_node)
- filtered_data[key].append(filtered_group)
- else:
- filtered_data[key] = value
- return filtered_data
-
-
-def save_recording_cases(
- cases: list[RecordingCase], output_dir: str | Path, filename: str = "recording_cases.json"
-) -> Path:
- """
- Save a list of RecordingCase objects to a JSON file.
-
- Args:
- cases: List of RecordingCase objects to save
- output_dir: Directory to save the file
- filename: Name of the output file (default: "recording_cases.json")
-
- Returns:
- Path: Path to the saved file
- """
- output_dir = Path(output_dir)
- output_dir.mkdir(parents=True, exist_ok=True)
-
- file_path = output_dir / filename
-
- # Convert cases to dictionaries for JSON serialization
- cases_data = [case.to_dict() for case in cases]
-
- with open(file_path, "w", encoding="utf-8") as f:
- json.dump(cases_data, f, indent=2, ensure_ascii=False)
-
- return file_path
-
-
-def load_recording_cases(file_path: str | Path) -> list[RecordingCase]:
- """
- Load RecordingCase objects from a JSON file.
-
- Args:
- file_path: Path to the JSON file containing RecordingCase data
-
- Returns:
- List[RecordingCase]: List of RecordingCase objects loaded from the file
- """
- file_path = Path(file_path)
-
- with open(file_path, encoding="utf-8") as f:
- cases_data = json.load(f)
-
- return [RecordingCase.from_dict(case_data) for case_data in cases_data]
-
-
-def save_evaluation_cases(
- can_answer_cases: list[RecordingCase],
- cannot_answer_cases: list[RecordingCase],
- output_dir: str | Path,
- frame: str = "default",
- version: str = "default",
-) -> dict[str, Path]:
- """
- Save both can_answer_cases and cannot_answer_cases to separate JSON files.
-
- Args:
- can_answer_cases: List of cases that can be answered
- cannot_answer_cases: List of cases that cannot be answered
- output_dir: Directory to save the files
- frame: Framework name for filename prefix
- version: Version identifier for filename
-
- Returns:
- Dict[str, Path]: Dictionary mapping case type to saved file path
- """
- output_dir = Path(output_dir)
- output_dir.mkdir(parents=True, exist_ok=True)
-
- saved_files = {}
-
- # Save can_answer_cases
- if can_answer_cases:
- can_answer_filename = f"{frame}_{version}_can_answer_cases.json"
- can_answer_path = save_recording_cases(can_answer_cases, output_dir, can_answer_filename)
- saved_files["can_answer_cases"] = can_answer_path
- print(f"Saved {len(can_answer_cases)} can_answer_cases to {can_answer_path}")
-
- # Save cannot_answer_cases
- if cannot_answer_cases:
- cannot_answer_filename = f"{frame}_{version}_cannot_answer_cases.json"
- cannot_answer_path = save_recording_cases(
- cannot_answer_cases, output_dir, cannot_answer_filename
- )
- saved_files["cannot_answer_cases"] = cannot_answer_path
- print(f"Saved {len(cannot_answer_cases)} cannot_answer_cases to {cannot_answer_path}")
-
- return saved_files
-
-
-def compute_can_answer_stats(day_groups, rounds_to_consider=float("inf")):
- """
- Compute can-answer statistics for each day using the union of all prior evidences.
-
- For each day, iterate over the QAs in the given order. If the current QA's
- evidences (restricted to the same day) are a subset of the union of all
- previously seen evidences for that day, increment can_answer_count. Then add
- the current evidences to the seen set.
-
- Note:
- The first QA of each day is excluded from the statistics because it
- cannot be answered without any prior evidences. It is still used to
- seed the seen evidences for subsequent QAs.
-
- Args:
- day_groups: Dict mapping day_id (e.g., "D1") to a list of QA dicts. Each QA
- dict should contain an "evidence" field that is a list of strings.
- rounds_to_consider: Number of previous rounds to consider for evidence accumulation.
- Default is infinity (all previous rounds).
- Set to 1 to only consider the immediately preceding round.
-
- Returns:
- dict: Mapping day_id -> {"can_answer_count": int, "total": int, "ratio": float}
- """
- results = {}
- for day, qa_list in day_groups.items():
- seen = set()
- # Keep track of evidence history for limited rounds
- evidence_history = []
- can_answer = 0
- total = max(len(qa_list) - 1, 0)
- rounds_count = 0
- for idx, qa in enumerate(qa_list):
- cur = set(qa.get("evidence", []))
- rounds_count += 1
-
- if idx == 0:
- # Seed seen evidences with the first QA but do not count it
- evidence_history.append(cur)
- seen = set().union(*evidence_history)
- continue
-
- # Check if current evidence is subset of accumulated evidence
- if cur and cur.issubset(seen):
- can_answer += 1
-
- # Add current evidence to history
- evidence_history.append(cur)
-
- # Limit history to specified number of rounds
- if rounds_count > rounds_to_consider:
- evidence_history.pop(0)
-
- # Recalculate seen as union of evidence_history
- seen = set().union(*evidence_history)
-
- results[day] = {
- "can_answer_count": can_answer,
- "total": total,
- "ratio": (can_answer / total) if total else 0.0,
- }
- return results
-
-
-def compute_can_answer_count_by_pre_evidences(
- temporal_locomo_data, num_of_users, stats_dir=None, rounds_to_consider=float("inf")
-):
- """
- Compute can-answer statistics per day for each conversation using the
- union of all previously asked evidences within the same day.
-
- Args:
- temporal_locomo_data: The temporal locomo data containing conversations
- num_of_users: Number of users/conversations to process
- stats_dir: Directory to save statistics (optional)
- rounds_to_consider: Number of previous rounds to consider for evidence accumulation.
- Default is infinity (all previous rounds).
- Set to 1 to only consider the immediately preceding round.
-
- Returns:
- dict: Mapping conversation_id -> per-day stats as produced by compute_can_answer_stats
- """
- all_conversations_stats = {}
- for conv_idx in range(num_of_users):
- temporal_conv = temporal_locomo_data[conv_idx]
- conversation_id = temporal_conv["conversation_id"]
-
- # Build day -> qa_pairs mapping
- day_groups = {}
- for day_id, day_data in temporal_conv.get("days", {}).items():
- day_groups[day_id] = day_data.get("qa_pairs", [])
-
- # Use shared utility to compute stats with correct accumulation logic
- per_day_stats = compute_can_answer_stats(day_groups, rounds_to_consider)
- all_conversations_stats[conversation_id] = per_day_stats
-
- # Build per-conversation summaries and overall summary
- per_conversation_summaries = {}
- overall_can = 0
- overall_total = 0
- for conv_id, day_stats in all_conversations_stats.items():
- conv_can = 0
- conv_total = 0
- for _day, stats in day_stats.items():
- conv_can += int(stats.get("can_answer_count", 0))
- conv_total += int(stats.get("total", 0))
- conv_ratio = (conv_can / conv_total) if conv_total else 0.0
- per_conversation_summaries[conv_id] = {
- "can_answer_count": conv_can,
- "total": conv_total,
- "ratio": conv_ratio,
- }
- overall_can += conv_can
- overall_total += conv_total
-
- overall_summary = {
- "can_answer_count": overall_can,
- "total": overall_total,
- "ratio": (overall_can / overall_total) if overall_total else 0.0,
- }
-
- # Add rounds information to the result
- result_payload = {
- "per_conversation_summary": per_conversation_summaries,
- "overall_summary": overall_summary,
- "rounds_considered": rounds_to_consider if rounds_to_consider != float("inf") else "all",
- }
-
- # Print results
- print("\nComputed can-answer-by-pre-evidences stats:")
- print(
- f"Rounds considered: {rounds_to_consider if rounds_to_consider != float('inf') else 'all'}"
- )
- print(json.dumps(result_payload, indent=2, ensure_ascii=False))
-
- # Save results if stats_dir is provided
- if stats_dir:
- output_path = (
- stats_dir
- / f"evidences_rounds_{rounds_to_consider if rounds_to_consider != float('inf') else 'all'}.json"
- )
- with open(output_path, "w", encoding="utf-8") as fw:
- json.dump(result_payload, fw, indent=2, ensure_ascii=False)
- print(f"Saved stats to {output_path}")
-
- return result_payload
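A small worked example of compute_can_answer_stats, assuming the deleted module is importable; the evidence ids are illustrative. Because the first QA of a day only seeds the seen set, three QAs give a denominator of 2.

from evaluation.scripts.temporal_locomo.modules.utils import compute_can_answer_stats

day_groups = {
    "D1": [
        {"evidence": ["D1:1"]},  # seeds the seen evidences, not counted
        {"evidence": ["D1:1"]},  # subset of prior evidence -> counted as answerable
        {"evidence": ["D1:2"]},  # introduces new evidence -> not answerable from prior rounds
    ]
}
print(compute_can_answer_stats(day_groups))
# {'D1': {'can_answer_count': 1, 'total': 2, 'ratio': 0.5}}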
diff --git a/evaluation/scripts/temporal_locomo/scheduler_time_eval.py b/evaluation/scripts/temporal_locomo/scheduler_time_eval.py
deleted file mode 100644
index 12d1964cd..000000000
--- a/evaluation/scripts/temporal_locomo/scheduler_time_eval.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import argparse
-import sys
-
-from pathlib import Path
-
-from modules.locomo_eval_module import LocomoEvalModelModules
-from modules.schemas import ContextUpdateMethod
-
-from evaluation.scripts.temporal_locomo.models.locomo_ingestion import LocomoIngestor
-from evaluation.scripts.temporal_locomo.models.locomo_processor_w_time_eval import (
- LocomoProcessorWithTimeEval,
-)
-from memos.log import get_logger
-
-
-FILE_PATH = Path(__file__).absolute()
-BASE_DIR = FILE_PATH.parent.parent.parent
-sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
-
-logger = get_logger(__name__)
-
-
-# TODO: This evaluation has been suspended—it is not finished yet.
-class TemporalLocomoForTimeEval(LocomoEvalModelModules):
- def __init__(self, args):
- args.result_dir_prefix = "time_eval-"
-
- super().__init__(args=args)
- self.num_of_users = 10
-
- self.locomo_ingestor = LocomoIngestor(args=args)
- self.locomo_processor = LocomoProcessorWithTimeEval(args=args)
-
- def run_time_eval_pipeline(self, skip_ingestion=True, skip_processing=False):
- """
- Run the complete evaluation pipeline including dataset conversion,
- data ingestion, and processing.
- """
- print("=" * 80)
- print("Starting TimeLocomo Evaluation Pipeline")
- print("=" * 80)
-
- # Step 1: Check if temporal_locomo dataset exists, if not convert it
- temporal_locomo_file = self.data_dir / "temporal_locomo" / "temporal_locomo_qa.json"
- if not temporal_locomo_file.exists():
- print(f"Temporal locomo dataset not found at {temporal_locomo_file}")
- print("Converting locomo dataset to temporal_locomo format...")
- self.convert_locomo_to_temporal_locomo(output_dir=self.data_dir / "temporal_locomo")
- print("Dataset conversion completed.")
- else:
- print(f"Temporal locomo dataset found at {temporal_locomo_file}, skipping conversion.")
-
- # Step 2: Data ingestion
- if not skip_ingestion:
- print("\n" + "=" * 50)
- print("Step 2: Data Ingestion")
- print("=" * 50)
- self.locomo_ingestor.run_ingestion()
-
- # Step 3: Processing and evaluation
- print("\n" + "=" * 50)
- print("Step 3: Processing and Evaluation")
- print("=" * 50)
- print("Running locomo processing to search and answer...")
-
- print("Starting locomo processing to generate search and response results...")
- self.locomo_processor.run_locomo_processing(num_users=self.num_of_users)
- print("Processing completed successfully.")
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--version",
- type=str,
- default="v1.0.1",
- help="Version identifier for saving results (e.g., 1010)",
- )
- parser.add_argument(
- "--workers", type=int, default=10, help="Number of parallel workers to process users"
- )
- parser.add_argument(
- "--top_k", type=int, default=20, help="Number of results to retrieve in search queries"
- )
-
- args = parser.parse_args()
-
- args.frame = "memos_scheduler"
- args.scheduler_flag = True
- args.context_update_method = ContextUpdateMethod.PRE_CONTEXT
-
- evaluator = TemporalLocomoForTimeEval(args=args)
- evaluator.run_time_eval_pipeline()
diff --git a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py b/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py
deleted file mode 100644
index bb6967e7f..000000000
--- a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py
+++ /dev/null
@@ -1,155 +0,0 @@
-import argparse
-import asyncio
-import os
-import sys
-
-from pathlib import Path
-
-from modules.locomo_eval_module import LocomoEvalModelModules
-from modules.schemas import ContextUpdateMethod
-from modules.utils import compute_can_answer_count_by_pre_evidences
-
-from evaluation.scripts.temporal_locomo.models.locomo_eval import LocomoEvaluator
-from evaluation.scripts.temporal_locomo.models.locomo_ingestion import LocomoIngestor
-from evaluation.scripts.temporal_locomo.models.locomo_metric import LocomoMetric
-from evaluation.scripts.temporal_locomo.models.locomo_processor import LocomoProcessor
-from memos.log import get_logger
-
-
-FILE_PATH = Path(__file__).absolute()
-BASE_DIR = FILE_PATH.parent.parent.parent
-sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
-
-logger = get_logger(__name__)
-
-
-class TemporalLocomoEval(LocomoEvalModelModules):
- def __init__(self, args):
- super().__init__(args=args)
- self.num_of_users = 10
-
- self.locomo_ingestor = LocomoIngestor(args=args)
- self.locomo_processor = LocomoProcessor(args=args)
- self.locomo_evaluator = LocomoEvaluator(args=args)
- self.locomo_metric = LocomoMetric(args=args)
-
- def run_answer_hit_eval_pipeline(self, skip_ingestion=True, skip_processing=False):
- """
- Run the complete evaluation pipeline including dataset conversion,
- data ingestion, and processing.
- """
- print("=" * 80)
- print("Starting TimeLocomo Evaluation Pipeline")
- print("=" * 80)
-
- # Step 1: Check if temporal_locomo dataset exists, if not convert it
- temporal_locomo_file = self.data_dir / "temporal_locomo" / "temporal_locomo_qa.json"
- if not temporal_locomo_file.exists():
- print(f"Temporal locomo dataset not found at {temporal_locomo_file}")
- print("Converting locomo dataset to temporal_locomo format...")
- self.convert_locomo_to_temporal_locomo(output_dir=self.data_dir / "temporal_locomo")
- print("Dataset conversion completed.")
- else:
- print(f"Temporal locomo dataset found at {temporal_locomo_file}, skipping conversion.")
-
- # Step 2: Data ingestion
- if not skip_ingestion:
- print("\n" + "=" * 50)
- print("Step 2: Data Ingestion")
- print("=" * 50)
- self.locomo_ingestor.run_ingestion()
-
- # Step 3: Processing and evaluation
- if not skip_processing:
- print("\n" + "=" * 50)
- print("Step 3: Processing and Evaluation")
- print("=" * 50)
- print("Running locomo processing to search and answer...")
-
- print("Starting locomo processing to generate search and response results...")
- self.locomo_processor.run_locomo_processing(num_users=self.num_of_users)
- print("Processing completed successfully.")
-
- # Optional: run post-hoc evaluation over generated responses if available
- try:
- if os.path.exists(self.response_path):
- print("Running LocomoEvaluator over existing response results...")
- asyncio.run(self.locomo_evaluator.run())
- else:
- print(
- f"Skipping LocomoEvaluator: response file not found at {evaluator.response_path}"
- )
- # Run metrics summarization if judged file is produced
-
- if os.path.exists(self.judged_path):
- print("Running LocomoMetric over judged results...")
- self.locomo_metric.run()
- else:
- print(f"Skipping LocomoMetric: judged file not found at {self.judged_path}")
- except Exception as e:
- logger.error(f"LocomoEvaluator step skipped due to error: {e}", exc_info=True)
-
- # Step 4: Summary
- print("\n" + "=" * 80)
- print("Evaluation Pipeline Completed Successfully!")
- print("=" * 80)
- print("Results saved to:")
- print(f" - Search results: {self.search_path}")
- print(f" - Response results: {self.response_path}")
- print(f" - Statistics: {self.stats_path}")
- print("=" * 80)
-
- def compute_can_answer_count_by_pre_evidences(self, rounds_to_consider):
- """
- Compute can-answer statistics per day for each conversation using the
- union of all previously asked evidences within the same day.
-
- Returns:
- dict: Mapping conversation_id -> per-day stats as produced by compute_can_answer_stats
- """
- return compute_can_answer_count_by_pre_evidences(
- temporal_locomo_data=self.temporal_locomo_data,
- num_of_users=self.num_of_users,
- stats_dir=self.stats_dir,
- rounds_to_consider=rounds_to_consider,
- )
-
-
-if __name__ == "__main__":
- parser = argparse.ArgumentParser()
- parser.add_argument(
- "--frame",
- type=str,
- default="memos",
- choices=["zep", "memos", "mem0", "mem0_graph", "memos_scheduler"],
- help="Specify the memory framework (zep or memos or mem0 or mem0_graph)",
- )
- parser.add_argument(
- "--version",
- type=str,
- default="v1.0.1",
- help="Version identifier for saving results (e.g., 1010)",
- )
- parser.add_argument(
- "--workers", type=int, default=10, help="Number of parallel workers to process users"
- )
- parser.add_argument(
- "--top_k", type=int, default=20, help="Number of results to retrieve in search queries"
- )
- parser.add_argument(
- "--scheduler_flag",
- action=argparse.BooleanOptionalAction,
- default=False,
- help="Enable or disable memory scheduler features",
- )
- parser.add_argument(
- "--context_update_method",
- type=str,
- default="chat_history",
- choices=ContextUpdateMethod.values(),
- help="Method to update context: pre_context (use previous context), chat_history (use template with history), current_context (use current context)",
- )
- args = parser.parse_args()
-
- evaluator = TemporalLocomoEval(args=args)
- evaluator.run_answer_hit_eval_pipeline()
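A sketch of driving the deleted pipeline programmatically instead of through the CLI; the attribute names mirror the argparse options above, but LocomoEvalModelModules may expect additional attributes, so treat this as illustrative only.

import argparse

args = argparse.Namespace(
    frame="memos_scheduler",
    version="v1.0.1",
    workers=10,
    top_k=20,
    scheduler_flag=True,
    context_update_method="pre_context",
)
evaluator = TemporalLocomoEval(args=args)
evaluator.run_answer_hit_eval_pipeline(skip_ingestion=False)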
diff --git a/evaluation/scripts/temporal_locomo/__init__.py b/examples/api/__init__.py
similarity index 100%
rename from evaluation/scripts/temporal_locomo/__init__.py
rename to examples/api/__init__.py
diff --git a/examples/api/product_api.py b/examples/api/product_api.py
new file mode 100644
index 000000000..b98f3b8e5
--- /dev/null
+++ b/examples/api/product_api.py
@@ -0,0 +1,144 @@
+#!/usr/bin/env python3
+"""
+Simulate full MemOS Product API workflow:
+1. Register user
+2. Add memory
+3. Search memory
+4. Chat (stream)
+5. Chat (stream) in a new session
+"""
+
+import json
+
+import requests
+
+
+BASE_URL = "http://0.0.0.0:8001/product"
+HEADERS = {"Content-Type": "application/json"}
+
+index = "24"
+USER_ID = f"memos_user_id_{index}"
+USER_NAME = f"memos_user_alice_{index}"
+MEM_CUBE_ID = f"memos_cube_id_{index}"
+SESSION_ID = f"memos_session_id_{index}"
+SESSION_ID2 = f"memos_session_id_{index}_s2"
+
+
+def register_user():
+ url = f"{BASE_URL}/users/register"
+ data = {
+ "user_id": USER_ID,
+ "user_name": USER_NAME,
+ "interests": "memory,retrieval,test",
+ "mem_cube_id": MEM_CUBE_ID,
+ }
+ print(f"[*] Registering user {USER_ID} ...")
+ resp = requests.post(url, headers=HEADERS, data=json.dumps(data), timeout=30)
+ print(resp.status_code, resp.text)
+ return resp.json()
+
+
+def add_memory():
+ url = f"{BASE_URL}/add"
+ data = {
+ "user_id": USER_ID,
+ "memory_content": "今天我在测试 MemOS 的记忆添加与检索流程。",
+ "messages": [{"role": "user", "content": "我今天在做系统测试"}],
+ "doc_path": None,
+ "mem_cube_id": MEM_CUBE_ID,
+ "source": "test_script",
+ "user_profile": False,
+ "session_id": SESSION_ID,
+ }
+ print("[*] Adding memory ...")
+ resp = requests.post(url, headers=HEADERS, data=json.dumps(data), timeout=30)
+ print(resp.status_code, resp.text)
+ return resp.json()
+
+
+def search_memory(query="系统测试"):  # default query: "system test"
+ url = f"{BASE_URL}/search"
+ data = {
+ "user_id": USER_ID,
+ "query": query,
+ "mem_cube_id": MEM_CUBE_ID,
+ "top_k": 5,
+ "session_id": SESSION_ID,
+ }
+ print("[*] Searching memory ...")
+ resp = requests.post(url, headers=HEADERS, data=json.dumps(data), timeout=30)
+ print(resp.status_code, resp.text)
+ return resp.json()
+
+
+def chat_stream(query: str, session_id: str, history: list | None = None):
+ url = f"{BASE_URL}/chat"
+ data = {
+ "user_id": USER_ID,
+ "query": query,
+ "mem_cube_id": MEM_CUBE_ID,
+ "history": history,
+ "internet_search": False,
+ "moscube": False,
+ "session_id": session_id,
+ }
+
+ print("[*] Starting streaming chat ...")
+
+ with requests.post(url, headers=HEADERS, data=json.dumps(data), stream=True) as resp:
+ for raw_line in resp.iter_lines():
+ if not raw_line:
+ continue
+ line = raw_line.decode("utf-8", errors="ignore")
+
+ payload = line.removeprefix("data: ").strip()
+ if payload == "[DONE]":
+ print("[done]")
+ break
+
+ try:
+ msg = json.loads(payload)
+ msg_type = msg.get("type")
+ msg_data = msg.get("data") or msg.get("content")
+
+ if msg_type == "text":
+ print(msg_data, end="", flush=True)
+ elif msg_type == "reference":
+ print(f"\n[参考记忆] {msg_data}")
+ elif msg_type == "status":
+ pass
+ elif msg_type == "suggestion":
+ print(f"\n[建议] {msg_data}")
+ elif msg_type == "end":
+ print("\n[✅ Chat End]")
+ else:
+ print(f"\n[{msg_type}] {msg_data}")
+ except Exception:
+                # Fallback for payloads decoded with the wrong charset: re-interpret
+                # the latin-1 bytes as UTF-8, otherwise print the raw payload.
+                try:
+                    print(payload.encode("latin-1").decode("utf-8"), end="")
+ except Exception:
+ print(payload)
+
+
+if __name__ == "__main__":
+ print("===== STEP 1: Register User =====")
+ register_user()
+
+ print("\n===== STEP 2: Add Memory =====")
+ add_memory()
+
+ print("\n===== STEP 3: Search Memory =====")
+ search_memory()
+
+ print("\n===== STEP 4: Stream Chat =====")
+ chat_stream("我很开心,我今天吃了好吃的拉面", SESSION_ID, history=[])
+ chat_stream(
+ "我刚和你说什么",
+ SESSION_ID,
+ history=[
+ {"role": "user", "content": "我很开心,我今天吃了好吃的拉面"},
+ {"role": "assistant", "content": "🉑"},
+ ],
+ )
+
+ print("\n===== STEP 4: Stream Chat =====")
+ chat_stream("我刚和你说什么了呢", SESSION_ID2, history=[])
diff --git a/examples/mem_scheduler/api_w_scheduler.py b/examples/mem_scheduler/api_w_scheduler.py
new file mode 100644
index 000000000..1b59543f3
--- /dev/null
+++ b/examples/mem_scheduler/api_w_scheduler.py
@@ -0,0 +1,85 @@
+from memos.api.handlers.scheduler_handler import (
+ handle_scheduler_status,
+ handle_scheduler_wait,
+)
+from memos.api.routers.server_router import mem_scheduler
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+
+
+# Debug: Print scheduler configuration
+print("=== Scheduler Configuration Debug ===")
+print(f"Scheduler type: {type(mem_scheduler).__name__}")
+print(f"Config: {mem_scheduler.config}")
+print(f"use_redis_queue: {mem_scheduler.use_redis_queue}")
+print(f"Queue type: {type(mem_scheduler.memos_message_queue).__name__}")
+print(f"Queue maxsize: {getattr(mem_scheduler.memos_message_queue, 'maxsize', 'N/A')}")
+print("=====================================\n")
+
+mem_scheduler.memos_message_queue.debug_mode_on()
+queue = mem_scheduler.memos_message_queue
+queue.clear()
+
+
+# Mem cube id used by the handler below and by the later monitoring calls
+USER_MEM_CUBE = "test_mem_cube"
+
+
+# 1. Define a handler function
+def my_test_handler(messages: list[ScheduleMessageItem]):
+ print(f"My test handler received {len(messages)} messages:")
+ for msg in messages:
+ print(f" my_test_handler - {msg.item_id}: {msg.content}")
+ user_status_running = handle_scheduler_status(
+ user_name=USER_MEM_CUBE, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler"
+ )
+ print(f"[Monitor] Status for {USER_MEM_CUBE} after submit:", user_status_running)
+
+
+# 2. Register the handler
+TEST_HANDLER_LABEL = "test_handler"
+mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler})
+
+# 2.1 Monitor global scheduler status before submitting tasks
+global_status_before = handle_scheduler_status(
+ user_name=None, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler"
+)
+print("[Monitor] Global status before submit:", global_status_before)
+
+# 3. Create messages
+messages_to_send = [
+ ScheduleMessageItem(
+ item_id=f"test_item_{i}",
+ user_id="test_user",
+ mem_cube_id="test_mem_cube",
+ label=TEST_HANDLER_LABEL,
+ content=f"This is test message {i}",
+ )
+ for i in range(5)
+]
+
+# 4. Submit messages
+for mes in messages_to_send:
+ print(f"Submitting message {mes.item_id} to the scheduler...")
+ mem_scheduler.memos_message_queue.submit_messages([mes])
+
+
+# 5. Wait for messages to be processed (limited to 100 checks)
+print("Waiting for messages to be consumed (max 100 checks)...")
+mem_scheduler.mem_scheduler_wait()
+
+# 5.1 Wait until idle for specific mem_cube via handler
+wait_result = handle_scheduler_wait(
+ user_name=USER_MEM_CUBE,
+ timeout_seconds=120.0,
+ poll_interval=0.2,
+ mem_scheduler=mem_scheduler,
+)
+print(f"[Monitor] Wait result for {USER_MEM_CUBE}:", wait_result)
+
+# 5.2 Monitor global scheduler status after processing
+global_status_after = handle_scheduler_status(
+ user_name=None, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler"
+)
+print("[Monitor] Global status after processing:", global_status_after)
+
+# 6. Stop the scheduler
+print("Stopping the scheduler...")
+mem_scheduler.stop()
diff --git a/examples/mem_scheduler/memos_w_optimized_scheduler.py b/examples/mem_scheduler/memos_w_optimized_scheduler.py
deleted file mode 100644
index 664168f62..000000000
--- a/examples/mem_scheduler/memos_w_optimized_scheduler.py
+++ /dev/null
@@ -1,85 +0,0 @@
-import shutil
-import sys
-
-from pathlib import Path
-
-from memos_w_scheduler import init_task, show_web_logs
-
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from memos.configs.mem_os import MOSConfig
-from memos.configs.mem_scheduler import AuthConfig
-from memos.log import get_logger
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_os.main import MOS
-
-
-FILE_PATH = Path(__file__).absolute()
-BASE_DIR = FILE_PATH.parent.parent.parent
-sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
-
-logger = get_logger(__name__)
-
-
-def run_with_scheduler_init():
- print("==== run_with_automatic_scheduler_init ====")
- conversations, questions = init_task()
-
- # set configs
- mos_config = MOSConfig.from_yaml_file(
- f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml"
- )
-
- mem_cube_config = GeneralMemCubeConfig.from_yaml_file(
- f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config_neo4j.yaml"
- )
-
- # default local graphdb uri
- if AuthConfig.default_config_exists():
- auth_config = AuthConfig.from_local_config()
-
- mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key
- mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url
-
- mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
- mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
- mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
- mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
- mem_cube_config.text_mem.config.graph_db.config.auto_create = (
- auth_config.graph_db.auto_create
- )
-
- # Initialization
- mos = MOS(mos_config)
-
- user_id = "user_1"
- mos.create_user(user_id)
-
- mem_cube_id = "mem_cube_5"
- mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
-
- if Path(mem_cube_name_or_path).exists():
- shutil.rmtree(mem_cube_name_or_path)
- print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
-
- mem_cube = GeneralMemCube(mem_cube_config)
- mem_cube.dump(mem_cube_name_or_path)
- mos.register_mem_cube(
- mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
- )
-
- mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
-
- for item in questions:
- print("===== Chat Start =====")
- query = item["question"]
- print(f"Query:\n {query}\n")
- response = mos.chat(query=query, user_id=user_id)
- print(f"Answer:\n {response}\n")
-
- show_web_logs(mem_scheduler=mos.mem_scheduler)
-
- mos.mem_scheduler.stop()
-
-
-if __name__ == "__main__":
- run_with_scheduler_init()
diff --git a/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py b/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py
deleted file mode 100644
index ed4f721ad..000000000
--- a/examples/mem_scheduler/memos_w_optimized_scheduler_for_test.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import json
-import shutil
-import sys
-
-from pathlib import Path
-
-from memos_w_scheduler_for_test import init_task
-
-from memos.configs.mem_cube import GeneralMemCubeConfig
-from memos.configs.mem_os import MOSConfig
-from memos.configs.mem_scheduler import AuthConfig
-from memos.log import get_logger
-from memos.mem_cube.general import GeneralMemCube
-from memos.mem_scheduler.analyzer.mos_for_test_scheduler import MOSForTestScheduler
-
-
-FILE_PATH = Path(__file__).absolute()
-BASE_DIR = FILE_PATH.parent.parent.parent
-sys.path.insert(0, str(BASE_DIR))
-
-# Enable execution from any working directory
-
-logger = get_logger(__name__)
-
-if __name__ == "__main__":
- # set up data
- conversations, questions = init_task()
-
- # set configs
- mos_config = MOSConfig.from_yaml_file(
- f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml"
- )
-
- mem_cube_config = GeneralMemCubeConfig.from_yaml_file(
- f"{BASE_DIR}/examples/data/config/mem_scheduler/mem_cube_config_neo4j.yaml"
- )
-
- # default local graphdb uri
- if AuthConfig.default_config_exists():
- auth_config = AuthConfig.from_local_config()
-
- mos_config.mem_reader.config.llm.config.api_key = auth_config.openai.api_key
- mos_config.mem_reader.config.llm.config.api_base = auth_config.openai.base_url
-
- mem_cube_config.text_mem.config.graph_db.config.uri = auth_config.graph_db.uri
- mem_cube_config.text_mem.config.graph_db.config.user = auth_config.graph_db.user
- mem_cube_config.text_mem.config.graph_db.config.password = auth_config.graph_db.password
- mem_cube_config.text_mem.config.graph_db.config.db_name = auth_config.graph_db.db_name
- mem_cube_config.text_mem.config.graph_db.config.auto_create = (
- auth_config.graph_db.auto_create
- )
-
- # Initialization
- mos = MOSForTestScheduler(mos_config)
-
- user_id = "user_1"
- mos.create_user(user_id)
-
- mem_cube_id = "mem_cube_5"
- mem_cube_name_or_path = f"{BASE_DIR}/outputs/mem_scheduler/{user_id}/{mem_cube_id}"
-
- if Path(mem_cube_name_or_path).exists():
- shutil.rmtree(mem_cube_name_or_path)
- print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
-
- mem_cube = GeneralMemCube(mem_cube_config)
- mem_cube.dump(mem_cube_name_or_path)
- mos.register_mem_cube(
- mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
- )
-
- mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
-
- # Add interfering conversations
- file_path = Path(f"{BASE_DIR}/examples/data/mem_scheduler/scene_data.json")
- scene_data = json.load(file_path.open("r", encoding="utf-8"))
- mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id)
- mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id)
-
- for item in questions:
- print("===== Chat Start =====")
- query = item["question"]
- print(f"Query:\n {query}\n")
- response = mos.chat(query=query, user_id=user_id)
- print(f"Answer:\n {response}\n")
-
- mos.mem_scheduler.stop()
diff --git a/examples/mem_scheduler/memos_w_scheduler.py b/examples/mem_scheduler/memos_w_scheduler.py
index dc196b85a..c523a8667 100644
--- a/examples/mem_scheduler/memos_w_scheduler.py
+++ b/examples/mem_scheduler/memos_w_scheduler.py
@@ -70,13 +70,48 @@ def init_task():
return conversations, questions
+def show_web_logs(mem_scheduler: GeneralScheduler):
+ """Display all web log entries from the scheduler's log queue.
+
+ Args:
+ mem_scheduler: The scheduler instance containing web logs to display
+ """
+ if mem_scheduler._web_log_message_queue.empty():
+ print("Web log queue is currently empty.")
+ return
+
+ print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50)
+
+ # Create a temporary queue to preserve the original queue contents
+ temp_queue = Queue()
+ log_count = 0
+
+ while not mem_scheduler._web_log_message_queue.empty():
+ log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get()
+ temp_queue.put(log_item)
+ log_count += 1
+
+ # Print log entry details
+ print(f"\nLog Entry #{log_count}:")
+ print(f'- "{log_item.label}" log: {log_item}')
+
+ print("-" * 50)
+
+ # Restore items back to the original queue
+ while not temp_queue.empty():
+ mem_scheduler._web_log_message_queue.put(temp_queue.get())
+
+ print(f"\nTotal {log_count} web log entries displayed.")
+ print("=" * 110 + "\n")
+
+
def run_with_scheduler_init():
print("==== run_with_automatic_scheduler_init ====")
conversations, questions = init_task()
# set configs
mos_config = MOSConfig.from_yaml_file(
- f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml"
+ f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml"
)
mem_cube_config = GeneralMemCubeConfig.from_yaml_file(
@@ -118,6 +153,7 @@ def run_with_scheduler_init():
)
mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
+ mos.mem_scheduler.current_mem_cube = mem_cube
for item in questions:
print("===== Chat Start =====")
@@ -131,40 +167,5 @@ def run_with_scheduler_init():
mos.mem_scheduler.stop()
-def show_web_logs(mem_scheduler: GeneralScheduler):
- """Display all web log entries from the scheduler's log queue.
-
- Args:
- mem_scheduler: The scheduler instance containing web logs to display
- """
- if mem_scheduler._web_log_message_queue.empty():
- print("Web log queue is currently empty.")
- return
-
- print("\n" + "=" * 50 + " WEB LOGS " + "=" * 50)
-
- # Create a temporary queue to preserve the original queue contents
- temp_queue = Queue()
- log_count = 0
-
- while not mem_scheduler._web_log_message_queue.empty():
- log_item: ScheduleLogForWebItem = mem_scheduler._web_log_message_queue.get()
- temp_queue.put(log_item)
- log_count += 1
-
- # Print log entry details
- print(f"\nLog Entry #{log_count}:")
- print(f'- "{log_item.label}" log: {log_item}')
-
- print("-" * 50)
-
- # Restore items back to the original queue
- while not temp_queue.empty():
- mem_scheduler._web_log_message_queue.put(temp_queue.get())
-
- print(f"\nTotal {log_count} web log entries displayed.")
- print("=" * 110 + "\n")
-
-
if __name__ == "__main__":
run_with_scheduler_init()
diff --git a/examples/mem_scheduler/memos_w_scheduler_for_test.py b/examples/mem_scheduler/memos_w_scheduler_for_test.py
index 6faac98af..2e135f127 100644
--- a/examples/mem_scheduler/memos_w_scheduler_for_test.py
+++ b/examples/mem_scheduler/memos_w_scheduler_for_test.py
@@ -1,10 +1,11 @@
import json
import shutil
import sys
-import time
from pathlib import Path
+from memos_w_scheduler import init_task
+
from memos.configs.mem_cube import GeneralMemCubeConfig
from memos.configs.mem_os import MOSConfig
from memos.configs.mem_scheduler import AuthConfig
@@ -15,155 +16,19 @@
FILE_PATH = Path(__file__).absolute()
BASE_DIR = FILE_PATH.parent.parent.parent
-sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
-
-logger = get_logger(__name__)
-
-
-def display_memory_cube_stats(mos, user_id, mem_cube_id):
- """Display detailed memory cube statistics."""
- print(f"\n📊 MEMORY CUBE STATISTICS for {mem_cube_id}:")
- print("-" * 60)
-
- mem_cube = mos.mem_cubes.get(mem_cube_id)
- if not mem_cube:
- print(" ❌ Memory cube not found")
- return
-
- # Text memory stats
- if mem_cube.text_mem:
- text_mem = mem_cube.text_mem
- working_memories = text_mem.get_working_memory()
- all_memories = text_mem.get_all()
-
- print(" 📝 Text Memory:")
- print(f" • Working Memory Items: {len(working_memories)}")
- print(
- f" • Total Memory Items: {len(all_memories) if isinstance(all_memories, list) else 'N/A'}"
- )
-
- if working_memories:
- print(" • Working Memory Content Preview:")
- for i, mem in enumerate(working_memories[:2]):
- content = mem.memory[:60] + "..." if len(mem.memory) > 60 else mem.memory
- print(f" {i + 1}. {content}")
-
- # Activation memory stats
- if mem_cube.act_mem:
- act_mem = mem_cube.act_mem
- act_memories = list(act_mem.get_all())
- print(" ⚡ Activation Memory:")
- print(f" • KV Cache Items: {len(act_memories)}")
- if act_memories:
- print(
- f" • Latest Cache Size: {len(act_memories[-1].memory) if hasattr(act_memories[-1], 'memory') else 'N/A'}"
- )
-
- print("-" * 60)
-
-
-def display_scheduler_status(mos):
- """Display current scheduler status and configuration."""
- print("\n⚙️ SCHEDULER STATUS:")
- print("-" * 60)
-
- if not mos.mem_scheduler:
- print(" ❌ Memory scheduler not initialized")
- return
-
- scheduler = mos.mem_scheduler
- print(f" 🔄 Scheduler Running: {scheduler._running}")
- print(f" 📊 Internal Queue Size: {scheduler.memos_message_queue.qsize()}")
- print(f" 🧵 Parallel Dispatch: {scheduler.enable_parallel_dispatch}")
- print(f" 👥 Max Workers: {scheduler.thread_pool_max_workers}")
- print(f" ⏱️ Consume Interval: {scheduler._consume_interval}s")
-
- if scheduler.monitor:
- print(" 📈 Monitor Active: ✅")
- print(f" 🗄️ Database Engine: {'✅' if scheduler.db_engine else '❌'}")
-
- if scheduler.dispatcher:
- print(" 🚀 Dispatcher Active: ✅")
- print(
- f" 🔧 Dispatcher Status: {scheduler.dispatcher.status if hasattr(scheduler.dispatcher, 'status') else 'Unknown'}"
- )
+sys.path.insert(0, str(BASE_DIR))
- print("-" * 60)
-
-
-def init_task():
- conversations = [
- {
- "role": "user",
- "content": "I have two dogs - Max (golden retriever) and Bella (pug). We live in Seattle.",
- },
- {"role": "assistant", "content": "Great! Any special care for them?"},
- {
- "role": "user",
- "content": "Max needs joint supplements. Actually, we're moving to Chicago next month.",
- },
- {
- "role": "user",
- "content": "Correction: Bella is 6, not 5. And she's allergic to chicken.",
- },
- {
- "role": "user",
- "content": "My partner's cat Whiskers visits weekends. Bella chases her sometimes.",
- },
- ]
-
- questions = [
- # 1. Basic factual recall (simple)
- {
- "question": "What breed is Max?",
- "category": "Pet",
- "expected": "golden retriever",
- "difficulty": "easy",
- },
- # 2. Temporal context (medium)
- {
- "question": "Where will I live next month?",
- "category": "Location",
- "expected": "Chicago",
- "difficulty": "medium",
- },
- # 3. Information correction (hard)
- {
- "question": "How old is Bella really?",
- "category": "Pet",
- "expected": "6",
- "difficulty": "hard",
- "hint": "User corrected the age later",
- },
- # 4. Relationship inference (harder)
- {
- "question": "Why might Whiskers be nervous around my pets?",
- "category": "Behavior",
- "expected": "Bella chases her sometimes",
- "difficulty": "harder",
- },
- # 5. Combined medical info (hardest)
- {
- "question": "Which pets have health considerations?",
- "category": "Health",
- "expected": "Max needs joint supplements, Bella is allergic to chicken",
- "difficulty": "hardest",
- "requires": ["combining multiple facts", "ignoring outdated info"],
- },
- ]
- return conversations, questions
+# Enable execution from any working directory
+logger = get_logger(__name__)
if __name__ == "__main__":
- print("🚀 Starting Enhanced Memory Scheduler Test...")
- print("=" * 80)
-
# set up data
conversations, questions = init_task()
# set configs
mos_config = MOSConfig.from_yaml_file(
- f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_scheduler.yaml"
+ f"{BASE_DIR}/examples/data/config/mem_scheduler/memos_config_w_optimized_scheduler.yaml"
)
mem_cube_config = GeneralMemCubeConfig.from_yaml_file(
@@ -186,7 +51,6 @@ def init_task():
)
# Initialization
- print("🔧 Initializing MOS with Scheduler...")
mos = MOSForTestScheduler(mos_config)
user_id = "user_1"
@@ -197,15 +61,15 @@ def init_task():
if Path(mem_cube_name_or_path).exists():
shutil.rmtree(mem_cube_name_or_path)
- print(f"🗑️ {mem_cube_name_or_path} is not empty, and has been removed.")
+ print(f"{mem_cube_name_or_path} is not empty, and has been removed.")
mem_cube = GeneralMemCube(mem_cube_config)
mem_cube.dump(mem_cube_name_or_path)
mos.register_mem_cube(
mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
)
+ mos.mem_scheduler.current_mem_cube = mem_cube
- print("📚 Adding initial conversations...")
mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
# Add interfering conversations
@@ -214,77 +78,11 @@ def init_task():
mos.add(scene_data[0], user_id=user_id, mem_cube_id=mem_cube_id)
mos.add(scene_data[1], user_id=user_id, mem_cube_id=mem_cube_id)
- # Display initial status
- print("\n📊 INITIAL SYSTEM STATUS:")
- display_scheduler_status(mos)
- display_memory_cube_stats(mos, user_id, mem_cube_id)
-
- # Process questions with enhanced monitoring
- print(f"\n🎯 Starting Question Processing ({len(questions)} questions)...")
- question_start_time = time.time()
-
- for i, item in enumerate(questions, 1):
- print(f"\n{'=' * 20} Question {i}/{len(questions)} {'=' * 20}")
- print(f"📝 Category: {item['category']} | Difficulty: {item['difficulty']}")
- print(f"🎯 Expected: {item['expected']}")
- if "hint" in item:
- print(f"💡 Hint: {item['hint']}")
- if "requires" in item:
- print(f"🔍 Requires: {', '.join(item['requires'])}")
-
- print(f"\n🚀 Processing Query: {item['question']}")
- query_start_time = time.time()
-
- response = mos.chat(query=item["question"], user_id=user_id)
-
- query_time = time.time() - query_start_time
- print(f"⏱️ Query Processing Time: {query_time:.3f}s")
- print(f"🤖 Response: {response}")
-
- # Display intermediate status every 2 questions
- if i % 2 == 0:
- print(f"\n📊 INTERMEDIATE STATUS (Question {i}):")
- display_scheduler_status(mos)
- display_memory_cube_stats(mos, user_id, mem_cube_id)
-
- total_processing_time = time.time() - question_start_time
- print(f"\n⏱️ Total Question Processing Time: {total_processing_time:.3f}s")
-
- # Display final scheduler performance summary
- print("\n" + "=" * 80)
- print("📊 FINAL SCHEDULER PERFORMANCE SUMMARY")
- print("=" * 80)
-
- summary = mos.get_scheduler_summary()
- print(f"🔢 Total Queries Processed: {summary['total_queries']}")
- print(f"⚡ Total Scheduler Calls: {summary['total_scheduler_calls']}")
- print(f"⏱️ Average Scheduler Response Time: {summary['average_scheduler_response_time']:.3f}s")
- print(f"🧠 Memory Optimizations Applied: {summary['memory_optimization_count']}")
- print(f"🔄 Working Memory Updates: {summary['working_memory_updates']}")
- print(f"⚡ Activation Memory Updates: {summary['activation_memory_updates']}")
- print(f"📈 Average Query Processing Time: {summary['average_query_processing_time']:.3f}s")
-
- # Performance insights
- print("\n💡 PERFORMANCE INSIGHTS:")
- if summary["total_scheduler_calls"] > 0:
- optimization_rate = (
- summary["memory_optimization_count"] / summary["total_scheduler_calls"]
- ) * 100
- print(f" • Memory Optimization Rate: {optimization_rate:.1f}%")
-
- if summary["average_scheduler_response_time"] < 0.1:
- print(" • Scheduler Performance: 🟢 Excellent (< 100ms)")
- elif summary["average_scheduler_response_time"] < 0.5:
- print(" • Scheduler Performance: 🟡 Good (100-500ms)")
- else:
- print(" • Scheduler Performance: 🔴 Needs Improvement (> 500ms)")
-
- # Final system status
- print("\n🔍 FINAL SYSTEM STATUS:")
- display_scheduler_status(mos)
- display_memory_cube_stats(mos, user_id, mem_cube_id)
-
- print("=" * 80)
- print("🏁 Test completed successfully!")
+ for item in questions:
+ print("===== Chat Start =====")
+ query = item["question"]
+ print(f"Query:\n {query}\n")
+ response = mos.chat(query=query, user_id=user_id)
+ print(f"Answer:\n {response}\n")
mos.mem_scheduler.stop()
diff --git a/examples/mem_scheduler/orm_examples.py b/examples/mem_scheduler/orm_examples.py
deleted file mode 100644
index bbb57b4ab..000000000
--- a/examples/mem_scheduler/orm_examples.py
+++ /dev/null
@@ -1,374 +0,0 @@
-#!/usr/bin/env python3
-"""
-ORM Examples for MemScheduler
-
-This script demonstrates how to use the BaseDBManager's new environment variable loading methods
-for MySQL and Redis connections.
-"""
-
-import multiprocessing
-import os
-import sys
-
-from pathlib import Path
-
-
-# Add the src directory to the Python path
-sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))
-
-from memos.log import get_logger
-from memos.mem_scheduler.orm_modules.base_model import BaseDBManager, DatabaseError
-from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager, SimpleListManager
-
-
-logger = get_logger(__name__)
-
-
-def test_mysql_engine_from_env():
- """Test loading MySQL engine from environment variables"""
- print("\n" + "=" * 60)
- print("Testing MySQL Engine from Environment Variables")
- print("=" * 60)
-
- try:
- # Test loading MySQL engine from current environment variables
- mysql_engine = BaseDBManager.load_mysql_engine_from_env()
- if mysql_engine is None:
- print("❌ Failed to create MySQL engine - check environment variables")
- return
-
- print(f"✅ Successfully created MySQL engine: {mysql_engine}")
- print(f" Engine URL: {mysql_engine.url}")
-
- # Test connection
- with mysql_engine.connect() as conn:
- from sqlalchemy import text
-
- result = conn.execute(text("SELECT 'MySQL connection test successful' as message"))
- message = result.fetchone()[0]
- print(f" Connection test: {message}")
-
- mysql_engine.dispose()
- print(" MySQL engine disposed successfully")
-
- except DatabaseError as e:
- print(f"❌ DatabaseError: {e}")
- except Exception as e:
- print(f"❌ Unexpected error: {e}")
-
-
-def test_redis_connection_from_env():
- """Test loading Redis connection from environment variables"""
- print("\n" + "=" * 60)
- print("Testing Redis Connection from Environment Variables")
- print("=" * 60)
-
- try:
- # Test loading Redis connection from current environment variables
- redis_client = BaseDBManager.load_redis_engine_from_env()
- if redis_client is None:
- print("❌ Failed to create Redis connection - check environment variables")
- return
-
- print(f"✅ Successfully created Redis connection: {redis_client}")
-
- # Test basic Redis operations
- redis_client.set("test_key", "Hello from ORM Examples!")
- value = redis_client.get("test_key")
- print(f" Redis test - Set/Get: {value}")
-
- # Test Redis info
- info = redis_client.info("server")
- redis_version = info.get("redis_version", "unknown")
- print(f" Redis server version: {redis_version}")
-
- # Clean up test key
- redis_client.delete("test_key")
- print(" Test key cleaned up")
-
- redis_client.close()
- print(" Redis connection closed successfully")
-
- except DatabaseError as e:
- print(f"❌ DatabaseError: {e}")
- except Exception as e:
- print(f"❌ Unexpected error: {e}")
-
-
-def test_environment_variables():
- """Test and display current environment variables"""
- print("\n" + "=" * 60)
- print("Current Environment Variables")
- print("=" * 60)
-
- # MySQL environment variables
- mysql_vars = [
- "MYSQL_HOST",
- "MYSQL_PORT",
- "MYSQL_USERNAME",
- "MYSQL_PASSWORD",
- "MYSQL_DATABASE",
- "MYSQL_CHARSET",
- ]
-
- print("\nMySQL Environment Variables:")
- for var in mysql_vars:
- value = os.getenv(var, "Not set")
- # Mask password for security
- if "PASSWORD" in var and value != "Not set":
- value = "*" * len(value)
- print(f" {var}: {value}")
-
- # Redis environment variables
- redis_vars = [
- "REDIS_HOST",
- "REDIS_PORT",
- "REDIS_DB",
- "REDIS_PASSWORD",
- "MEMSCHEDULER_REDIS_HOST",
- "MEMSCHEDULER_REDIS_PORT",
- "MEMSCHEDULER_REDIS_DB",
- "MEMSCHEDULER_REDIS_PASSWORD",
- ]
-
- print("\nRedis Environment Variables:")
- for var in redis_vars:
- value = os.getenv(var, "Not set")
- # Mask password for security
- if "PASSWORD" in var and value != "Not set":
- value = "*" * len(value)
- print(f" {var}: {value}")
-
-
-def test_manual_env_loading():
- """Test loading environment variables manually from .env file"""
- print("\n" + "=" * 60)
- print("Testing Manual Environment Loading")
- print("=" * 60)
-
- env_file_path = "/Users/travistang/Documents/codes/memos/.env"
-
- if not os.path.exists(env_file_path):
- print(f"❌ Environment file not found: {env_file_path}")
- return
-
- try:
- from dotenv import load_dotenv
-
- # Load environment variables
- load_dotenv(env_file_path)
- print(f"✅ Successfully loaded environment variables from {env_file_path}")
-
- # Test some key variables
- test_vars = ["OPENAI_API_KEY", "MOS_CHAT_MODEL", "TZ"]
- for var in test_vars:
- value = os.getenv(var, "Not set")
- if "KEY" in var and value != "Not set":
- value = f"{value[:10]}..." if len(value) > 10 else value
- print(f" {var}: {value}")
-
- except ImportError:
- print("❌ python-dotenv not installed. Install with: pip install python-dotenv")
- except Exception as e:
- print(f"❌ Error loading environment file: {e}")
-
-
-def test_redis_lockable_orm_with_list():
- """Test RedisDBManager with list[str] type synchronization"""
- print("\n" + "=" * 60)
- print("Testing RedisDBManager with list[str]")
- print("=" * 60)
-
- try:
- from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager
-
- # Create a simple list manager instance
- list_manager = SimpleListManager(["apple", "banana", "cherry"])
- print(f"Original list manager: {list_manager}")
-
- # Create RedisDBManager instance
- redis_client = BaseDBManager.load_redis_engine_from_env()
- if redis_client is None:
- print("❌ Failed to create Redis connection - check environment variables")
- return
-
- db_manager = RedisDBManager(
- redis_client=redis_client,
- user_id="test_user",
- mem_cube_id="test_list_cube",
- obj=list_manager,
- )
-
- # Save to Redis
- db_manager.save_to_db(list_manager)
- print("✅ List manager saved to Redis")
-
- # Load from Redis
- loaded_manager = db_manager.load_from_db()
- if loaded_manager:
- print(f"Loaded list manager: {loaded_manager}")
- print(f"Items match: {list_manager.items == loaded_manager.items}")
- else:
- print("❌ Failed to load list manager from Redis")
-
- # Clean up
- redis_client.delete("lockable_orm:test_user:test_list_cube:data")
- redis_client.delete("lockable_orm:test_user:test_list_cube:lock")
- redis_client.delete("lockable_orm:test_user:test_list_cube:version")
- redis_client.close()
-
- except Exception as e:
- print(f"❌ Error in RedisDBManager test: {e}")
-
-
-def modify_list_process(process_id: int, items_to_add: list[str]):
- """Function to be run in separate processes to modify the list using merge_items"""
- try:
- from memos.mem_scheduler.orm_modules.redis_model import RedisDBManager
-
- # Create Redis connection
- redis_client = BaseDBManager.load_redis_engine_from_env()
- if redis_client is None:
- print(f"Process {process_id}: Failed to create Redis connection")
- return
-
- # Create a temporary list manager for this process with items to add
- temp_manager = SimpleListManager()
-
- db_manager = RedisDBManager(
- redis_client=redis_client,
- user_id="test_user",
- mem_cube_id="multiprocess_list",
- obj=temp_manager,
- )
-
- print(f"Process {process_id}: Starting modification with items: {items_to_add}")
- for item in items_to_add:
- db_manager.obj.add_item(item)
- # Use sync_with_orm which internally uses merge_items
- db_manager.sync_with_orm(size_limit=None)
-
- print(f"Process {process_id}: Successfully synchronized with Redis")
-
- redis_client.close()
-
- except Exception as e:
- print(f"Process {process_id}: Error - {e}")
- import traceback
-
- traceback.print_exc()
-
-
-def test_multiprocess_synchronization():
- """Test multiprocess synchronization with RedisDBManager"""
- print("\n" + "=" * 60)
- print("Testing Multiprocess Synchronization")
- print("=" * 60)
-
- try:
- # Initialize Redis with empty list
- redis_client = BaseDBManager.load_redis_engine_from_env()
- if redis_client is None:
- print("❌ Failed to create Redis connection")
- return
-
- # Initialize with empty list
- initial_manager = SimpleListManager([])
- db_manager = RedisDBManager(
- redis_client=redis_client,
- user_id="test_user",
- mem_cube_id="multiprocess_list",
- obj=initial_manager,
- )
- db_manager.save_to_db(initial_manager)
- print("✅ Initialized empty list manager in Redis")
-
- # Define items for each process to add
- process_items = [
- ["item1", "item2"],
- ["item3", "item4"],
- ["item5", "item6"],
- ["item1", "item7"], # item1 is duplicate, should not be added twice
- ]
-
- # Create and start processes
- processes = []
- for i, items in enumerate(process_items):
- p = multiprocessing.Process(target=modify_list_process, args=(i + 1, items))
- processes.append(p)
- p.start()
-
- # Wait for all processes to complete
- for p in processes:
- p.join()
-
- print("\n" + "-" * 40)
- print("All processes completed. Checking final result...")
-
- # Load final result
- final_db_manager = RedisDBManager(
- redis_client=redis_client,
- user_id="test_user",
- mem_cube_id="multiprocess_list",
- obj=SimpleListManager([]),
- )
- final_manager = final_db_manager.load_from_db()
-
- if final_manager:
- print(f"Final synchronized list manager: {final_manager}")
- print(f"Final list length: {len(final_manager)}")
- print("Expected items: {'item1', 'item2', 'item3', 'item4', 'item5', 'item6', 'item7'}")
- print(f"Actual items: {set(final_manager.items)}")
-
- # Check if all unique items are present
- expected_items = {"item1", "item2", "item3", "item4", "item5", "item6", "item7"}
- actual_items = set(final_manager.items)
-
- if expected_items == actual_items:
- print("✅ All processes contributed correctly - synchronization successful!")
- else:
- print(f"❌ Expected items: {expected_items}")
- print(f" Actual items: {actual_items}")
- else:
- print("❌ Failed to load final result")
-
- # Clean up
- redis_client.delete("lockable_orm:test_user:multiprocess_list:data")
- redis_client.delete("lockable_orm:test_user:multiprocess_list:lock")
- redis_client.delete("lockable_orm:test_user:multiprocess_list:version")
- redis_client.close()
-
- except Exception as e:
- print(f"❌ Error in multiprocess synchronization test: {e}")
-
-
-def main():
- """Main function to run all tests"""
- print("ORM Examples - Environment Variable Loading Tests")
- print("=" * 80)
-
- # Test environment variables display
- test_environment_variables()
-
- # Test manual environment loading
- test_manual_env_loading()
-
- # Test MySQL engine loading
- test_mysql_engine_from_env()
-
- # Test Redis connection loading
- test_redis_connection_from_env()
-
- # Test RedisLockableORM with list[str]
- test_redis_lockable_orm_with_list()
-
- # Test multiprocess synchronization
- test_multiprocess_synchronization()
-
- print("\n" + "=" * 80)
- print("All tests completed!")
- print("=" * 80)
-
-
-if __name__ == "__main__":
- main()
diff --git a/examples/mem_scheduler/redis_example.py b/examples/mem_scheduler/redis_example.py
index 1660d6c02..2c3801539 100644
--- a/examples/mem_scheduler/redis_example.py
+++ b/examples/mem_scheduler/redis_example.py
@@ -22,7 +22,7 @@
sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
-async def service_run():
+def service_run():
# Init
example_scheduler_config_path = (
f"{BASE_DIR}/examples/data/config/mem_scheduler/general_scheduler_config.yaml"
@@ -60,11 +60,11 @@ async def service_run():
content=query,
timestamp=datetime.now(),
)
- res = await mem_scheduler.redis_add_message_stream(message=message_item.to_dict())
+ res = mem_scheduler.redis_add_message_stream(message=message_item.to_dict())
print(
f"Added: {res}",
)
- await asyncio.sleep(0.5)
+ asyncio.sleep(0.5)
mem_scheduler.redis_stop_listening()
@@ -72,4 +72,4 @@ async def service_run():
if __name__ == "__main__":
- asyncio.run(service_run())
+ service_run()
diff --git a/examples/mem_scheduler/try_schedule_modules.py b/examples/mem_scheduler/try_schedule_modules.py
index de99f1c95..4aedac711 100644
--- a/examples/mem_scheduler/try_schedule_modules.py
+++ b/examples/mem_scheduler/try_schedule_modules.py
@@ -176,6 +176,7 @@ def show_web_logs(mem_scheduler: GeneralScheduler):
mos.register_mem_cube(
mem_cube_name_or_path=mem_cube_name_or_path, mem_cube_id=mem_cube_id, user_id=user_id
)
+ mos.mem_scheduler.current_mem_cube = mem_cube
mos.add(conversations, user_id=user_id, mem_cube_id=mem_cube_id)
diff --git a/src/memos/api/config.py b/src/memos/api/config.py
index f02edaad6..a276fa63d 100644
--- a/src/memos/api/config.py
+++ b/src/memos/api/config.py
@@ -175,7 +175,7 @@ def start_config_watch(cls):
@classmethod
def start_watch_if_enabled(cls) -> None:
enable = os.getenv("NACOS_ENABLE_WATCH", "false").lower() == "true"
- print("enable:", enable)
+ logger.info(f"NACOS_ENABLE_WATCH: {enable}")
if not enable:
return
interval = int(os.getenv("NACOS_WATCH_INTERVAL", "60"))
@@ -623,7 +623,10 @@ def get_scheduler_config() -> dict[str, Any]:
"MOS_SCHEDULER_ENABLE_PARALLEL_DISPATCH", "true"
).lower()
== "true",
- "enable_activation_memory": True,
+ "enable_activation_memory": os.getenv(
+ "MOS_SCHEDULER_ENABLE_ACTIVATION_MEMORY", "false"
+ ).lower()
+ == "true",
},
}
diff --git a/src/memos/api/handlers/__init__.py b/src/memos/api/handlers/__init__.py
new file mode 100644
index 000000000..90347768c
--- /dev/null
+++ b/src/memos/api/handlers/__init__.py
@@ -0,0 +1,62 @@
+"""
+Server handlers for MemOS API routers.
+
+This package contains modular handlers for the server_router, responsible for:
+- Building component configurations (config_builders)
+- Initializing server components (component_init)
+- Formatting data for API responses (formatters)
+- Handling search, add, scheduler, and chat operations
+"""
+
+# Import handler submodules and shared helpers for package-level access
+from memos.api.handlers import (
+ add_handler,
+ chat_handler,
+ memory_handler,
+ scheduler_handler,
+ search_handler,
+ suggestion_handler,
+)
+from memos.api.handlers.component_init import init_server
+from memos.api.handlers.config_builders import (
+ build_embedder_config,
+ build_graph_db_config,
+ build_internet_retriever_config,
+ build_llm_config,
+ build_mem_reader_config,
+ build_pref_adder_config,
+ build_pref_extractor_config,
+ build_pref_retriever_config,
+ build_reranker_config,
+ build_vec_db_config,
+)
+from memos.api.handlers.formatters_handler import (
+ format_memory_item,
+ post_process_pref_mem,
+ to_iter,
+)
+
+
+__all__ = [
+ "add_handler",
+ "build_embedder_config",
+ "build_graph_db_config",
+ "build_internet_retriever_config",
+ "build_llm_config",
+ "build_mem_reader_config",
+ "build_pref_adder_config",
+ "build_pref_extractor_config",
+ "build_pref_retriever_config",
+ "build_reranker_config",
+ "build_vec_db_config",
+ "chat_handler",
+ "format_memory_item",
+ "formatters_handler",
+ "init_server",
+ "memory_handler",
+ "post_process_pref_mem",
+ "scheduler_handler",
+ "search_handler",
+ "suggestion_handler",
+ "to_iter",
+]
diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py
new file mode 100644
index 000000000..ee481d028
--- /dev/null
+++ b/src/memos/api/handlers/add_handler.py
@@ -0,0 +1,294 @@
+"""
+Add handler for memory addition functionality (Class-based version).
+
+This module provides a class-based implementation of add handlers,
+using dependency injection for better modularity and testability.
+"""
+
+import json
+import os
+
+from datetime import datetime
+
+from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
+from memos.api.product_models import APIADDRequest, MemoryResponse
+from memos.context.context import ContextThreadPoolExecutor
+from memos.mem_scheduler.schemas.general_schemas import (
+ ADD_LABEL,
+ MEM_READ_LABEL,
+ PREF_ADD_LABEL,
+)
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.types import UserContext
+
+
+class AddHandler(BaseHandler):
+ """
+ Handler for memory addition operations.
+
+ Handles both text and preference memory additions with sync/async support.
+ """
+
+ def __init__(self, dependencies: HandlerDependencies):
+ """
+ Initialize add handler.
+
+ Args:
+ dependencies: HandlerDependencies instance
+ """
+ super().__init__(dependencies)
+ self._validate_dependencies("naive_mem_cube", "mem_reader", "mem_scheduler")
+
+ def handle_add_memories(self, add_req: APIADDRequest) -> MemoryResponse:
+ """
+ Main handler for add memories endpoint.
+
+ Orchestrates the addition of both text and preference memories,
+ supporting concurrent processing.
+
+ Args:
+ add_req: Add memory request
+
+ Returns:
+ MemoryResponse with added memory information
+ """
+ # Create UserContext object
+ user_context = UserContext(
+ user_id=add_req.user_id,
+ mem_cube_id=add_req.mem_cube_id,
+ session_id=add_req.session_id or "default_session",
+ )
+
+ self.logger.info(f"Add Req is: {add_req}")
+ if (not add_req.messages) and add_req.memory_content:
+            add_req.messages = self._convert_content_message(add_req.memory_content)
+ self.logger.info(f"Converted Add Req content to messages: {add_req.messages}")
+ # Process text and preference memories in parallel
+ with ContextThreadPoolExecutor(max_workers=2) as executor:
+ text_future = executor.submit(self._process_text_mem, add_req, user_context)
+ pref_future = executor.submit(self._process_pref_mem, add_req, user_context)
+
+ text_response_data = text_future.result()
+ pref_response_data = pref_future.result()
+
+ self.logger.info(f"add_memories Text response data: {text_response_data}")
+ self.logger.info(f"add_memories Pref response data: {pref_response_data}")
+
+ return MemoryResponse(
+ message="Memory added successfully",
+ data=text_response_data + pref_response_data,
+ )
+
+    def _convert_content_message(self, memory_content: str) -> list[dict[str, str]]:
+ """
+ Convert content string to list of message dictionaries.
+
+ Args:
+            memory_content: Plain content string to wrap as a single user message
+
+ Returns:
+ List of message dictionaries
+ """
+ messages_list = [
+ {
+ "role": "user",
+ "content": memory_content,
+ "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
+ }
+ ]
+        # Wrap a plain user-provided string as a single-message conversation
+ return messages_list
+
+ def _process_text_mem(
+ self,
+ add_req: APIADDRequest,
+ user_context: UserContext,
+ ) -> list[dict[str, str]]:
+ """
+ Process and add text memories.
+
+ Extracts memories from messages and adds them to the text memory system.
+ Handles both sync and async modes.
+
+ Args:
+ add_req: Add memory request
+ user_context: User context with IDs
+
+ Returns:
+ List of formatted memory responses
+ """
+ target_session_id = add_req.session_id or "default_session"
+
+ # Determine sync mode
+ sync_mode = add_req.async_mode or self._get_sync_mode()
+
+ self.logger.info(f"Processing text memory with mode: {sync_mode}")
+
+ # Extract memories
+ memories_local = self.mem_reader.get_memory(
+ [add_req.messages],
+ type="chat",
+ info={
+ "user_id": add_req.user_id,
+ "session_id": target_session_id,
+ },
+ mode="fast" if sync_mode == "async" else "fine",
+ )
+ flattened_local = [mm for m in memories_local for mm in m]
+ self.logger.info(f"Memory extraction completed for user {add_req.user_id}")
+
+ # Add memories to text_mem
+ mem_ids_local: list[str] = self.naive_mem_cube.text_mem.add(
+ flattened_local,
+ user_name=user_context.mem_cube_id,
+ )
+ self.logger.info(
+ f"Added {len(mem_ids_local)} memories for user {add_req.user_id} "
+ f"in session {add_req.session_id}: {mem_ids_local}"
+ )
+
+ # Schedule async/sync tasks
+ self._schedule_memory_tasks(
+ add_req=add_req,
+ user_context=user_context,
+ mem_ids=mem_ids_local,
+ sync_mode=sync_mode,
+ )
+
+ return [
+ {
+ "memory": memory.memory,
+ "memory_id": memory_id,
+ "memory_type": memory.metadata.memory_type,
+ }
+ for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False)
+ ]
+
+ def _process_pref_mem(
+ self,
+ add_req: APIADDRequest,
+ user_context: UserContext,
+ ) -> list[dict[str, str]]:
+ """
+ Process and add preference memories.
+
+ Extracts preferences from messages and adds them to the preference memory system.
+ Handles both sync and async modes.
+
+ Args:
+ add_req: Add memory request
+ user_context: User context with IDs
+
+ Returns:
+ List of formatted preference responses
+ """
+ if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
+ return []
+
+ # Determine sync mode
+ sync_mode = add_req.async_mode or self._get_sync_mode()
+ target_session_id = add_req.session_id or "default_session"
+
+ # Follow async behavior: enqueue when async
+ if sync_mode == "async":
+ try:
+ messages_list = [add_req.messages]
+ message_item_pref = ScheduleMessageItem(
+ user_id=add_req.user_id,
+ session_id=target_session_id,
+ mem_cube_id=add_req.mem_cube_id,
+ mem_cube=self.naive_mem_cube,
+ label=PREF_ADD_LABEL,
+ content=json.dumps(messages_list),
+ timestamp=datetime.utcnow(),
+ )
+ self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_pref])
+ self.logger.info("Submitted preference add to scheduler (async mode)")
+ except Exception as e:
+ self.logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True)
+ return []
+ else:
+ # Sync mode: process immediately
+ pref_memories_local = self.naive_mem_cube.pref_mem.get_memory(
+ [add_req.messages],
+ type="chat",
+ info={
+ "user_id": add_req.user_id,
+ "session_id": target_session_id,
+ "mem_cube_id": add_req.mem_cube_id,
+ },
+ )
+ pref_ids_local: list[str] = self.naive_mem_cube.pref_mem.add(pref_memories_local)
+ self.logger.info(
+ f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} "
+ f"in session {add_req.session_id}: {pref_ids_local}"
+ )
+ return [
+ {
+ "memory": memory.memory,
+ "memory_id": memory_id,
+ "memory_type": memory.metadata.preference_type,
+ }
+ for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False)
+ ]
+
+ def _get_sync_mode(self) -> str:
+ """
+ Get synchronization mode from memory cube.
+
+ Returns:
+ Sync mode string ("sync" or "async")
+ """
+ try:
+ return getattr(self.naive_mem_cube.text_mem, "mode", "sync")
+ except Exception:
+ return "sync"
+
+ def _schedule_memory_tasks(
+ self,
+ add_req: APIADDRequest,
+ user_context: UserContext,
+ mem_ids: list[str],
+ sync_mode: str,
+ ) -> None:
+ """
+ Schedule memory processing tasks based on sync mode.
+
+ Args:
+ add_req: Add memory request
+ user_context: User context
+ mem_ids: List of memory IDs
+ sync_mode: Synchronization mode
+ """
+ target_session_id = add_req.session_id or "default_session"
+
+ if sync_mode == "async":
+ # Async mode: submit MEM_READ_LABEL task
+ try:
+ message_item_read = ScheduleMessageItem(
+ user_id=add_req.user_id,
+ session_id=target_session_id,
+ mem_cube_id=add_req.mem_cube_id,
+ mem_cube=self.naive_mem_cube,
+ label=MEM_READ_LABEL,
+ content=json.dumps(mem_ids),
+ timestamp=datetime.utcnow(),
+ user_name=add_req.mem_cube_id,
+ )
+ self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_read])
+ self.logger.info(f"Submitted async memory read task: {json.dumps(mem_ids)}")
+ except Exception as e:
+ self.logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True)
+ else:
+ # Sync mode: submit ADD_LABEL task
+ message_item_add = ScheduleMessageItem(
+ user_id=add_req.user_id,
+ session_id=target_session_id,
+ mem_cube_id=add_req.mem_cube_id,
+ mem_cube=self.naive_mem_cube,
+ label=ADD_LABEL,
+ content=json.dumps(mem_ids),
+ timestamp=datetime.utcnow(),
+ user_name=add_req.mem_cube_id,
+ )
+ self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_add])
diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py
new file mode 100644
index 000000000..a686ac8f9
--- /dev/null
+++ b/src/memos/api/handlers/base_handler.py
@@ -0,0 +1,182 @@
+"""
+Base handler for MemOS API handlers.
+
+This module provides the base class for all API handlers, implementing
+dependency injection and common functionality.
+"""
+
+from typing import Any
+
+from memos.log import get_logger
+from memos.mem_scheduler.base_scheduler import BaseScheduler
+from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
+
+
+logger = get_logger(__name__)
+
+
+class HandlerDependencies:
+ """
+ Container for handler dependencies.
+
+ This class acts as a dependency injection container, holding all
+ shared resources needed by handlers.
+ """
+
+ def __init__(
+ self,
+ llm: Any | None = None,
+ naive_mem_cube: Any | None = None,
+ mem_reader: Any | None = None,
+ mem_scheduler: Any | None = None,
+ searcher: Any | None = None,
+ embedder: Any | None = None,
+ reranker: Any | None = None,
+ graph_db: Any | None = None,
+ vector_db: Any | None = None,
+ internet_retriever: Any | None = None,
+ memory_manager: Any | None = None,
+ mos_server: Any | None = None,
+ **kwargs,
+ ):
+ """
+ Initialize handler dependencies.
+
+ Args:
+ llm: Language model instance
+ naive_mem_cube: Memory cube instance
+ mem_reader: Memory reader instance
+            mem_scheduler: Scheduler instance
+            searcher: Searcher instance
+ embedder: Embedder instance
+ reranker: Reranker instance
+ graph_db: Graph database instance
+ vector_db: Vector database instance
+ internet_retriever: Internet retriever instance
+ memory_manager: Memory manager instance
+ mos_server: MOS server instance
+ **kwargs: Additional dependencies
+ """
+ self.llm = llm
+ self.naive_mem_cube = naive_mem_cube
+ self.mem_reader = mem_reader
+ self.mem_scheduler = mem_scheduler
+ self.searcher = searcher
+ self.embedder = embedder
+ self.reranker = reranker
+ self.graph_db = graph_db
+ self.vector_db = vector_db
+ self.internet_retriever = internet_retriever
+ self.memory_manager = memory_manager
+ self.mos_server = mos_server
+
+ # Store any additional dependencies
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+ @classmethod
+ def from_init_server(cls, components: dict[str, Any]):
+ """
+ Create dependencies from init_server() return values.
+
+ Args:
+ components: Dictionary of components returned by init_server().
+ All components will be automatically unpacked as dependencies.
+
+ Returns:
+ HandlerDependencies instance
+
+ Note:
+ This method uses **kwargs unpacking, so any new components added to
+ init_server() will automatically become available as dependencies
+ without modifying this code.
+ """
+ return cls(**components)
+
+
+class BaseHandler:
+ """
+ Base class for all API handlers.
+
+ Provides common functionality and dependency injection for handlers.
+ All specific handlers should inherit from this class.
+ """
+
+ def __init__(self, dependencies: HandlerDependencies):
+ """
+ Initialize base handler.
+
+ Args:
+ dependencies: HandlerDependencies instance containing all shared resources
+ """
+ self.deps = dependencies
+ self.logger = get_logger(self.__class__.__name__)
+
+ @property
+ def llm(self):
+ """Get LLM instance."""
+ return self.deps.llm
+
+ @property
+ def naive_mem_cube(self):
+ """Get memory cube instance."""
+ return self.deps.naive_mem_cube
+
+ @property
+ def mem_reader(self):
+ """Get memory reader instance."""
+ return self.deps.mem_reader
+
+ @property
+ def mem_scheduler(self) -> BaseScheduler:
+ """Get scheduler instance."""
+ return self.deps.mem_scheduler
+
+ @property
+ def searcher(self) -> Searcher:
+ """Get scheduler instance."""
+ return self.deps.searcher
+
+ @property
+ def embedder(self):
+ """Get embedder instance."""
+ return self.deps.embedder
+
+ @property
+ def reranker(self):
+ """Get reranker instance."""
+ return self.deps.reranker
+
+ @property
+ def graph_db(self):
+ """Get graph database instance."""
+ return self.deps.graph_db
+
+ @property
+ def vector_db(self):
+ """Get vector database instance."""
+ return self.deps.vector_db
+
+ @property
+ def mos_server(self):
+ """Get MOS server instance."""
+ return self.deps.mos_server
+
+ def _validate_dependencies(self, *required_deps: str) -> None:
+ """
+ Validate that required dependencies are available.
+
+ Args:
+ *required_deps: Names of required dependency attributes
+
+ Raises:
+ ValueError: If any required dependency is None
+ """
+ missing = []
+ for dep_name in required_deps:
+ if not hasattr(self.deps, dep_name) or getattr(self.deps, dep_name) is None:
+ missing.append(dep_name)
+
+ if missing:
+ raise ValueError(
+ f"{self.__class__.__name__} requires the following dependencies: {', '.join(missing)}"
+ )
diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py
new file mode 100644
index 000000000..8540a67ec
--- /dev/null
+++ b/src/memos/api/handlers/chat_handler.py
@@ -0,0 +1,824 @@
+"""
+Chat handler for chat functionality (Class-based version).
+
+This module provides a complete implementation of chat handlers,
+consolidating all chat-related logic without depending on mos_server.
+"""
+
+import asyncio
+import json
+import traceback
+
+from collections.abc import Generator
+from datetime import datetime
+from typing import Any, Literal
+
+from fastapi import HTTPException
+from fastapi.responses import StreamingResponse
+
+from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
+from memos.api.product_models import (
+ APIADDRequest,
+ APIChatCompleteRequest,
+ APISearchRequest,
+ ChatRequest,
+)
+from memos.context.context import ContextThread
+from memos.mem_os.utils.format_utils import clean_json_response
+from memos.mem_os.utils.reference_utils import (
+ prepare_reference_data,
+ process_streaming_references_complete,
+)
+from memos.mem_scheduler.schemas.general_schemas import (
+ ANSWER_LABEL,
+ QUERY_LABEL,
+ SearchMode,
+)
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.templates.mos_prompts import (
+ FURTHER_SUGGESTION_PROMPT,
+ get_memos_prompt,
+)
+from memos.types import MessageList
+
+
+class ChatHandler(BaseHandler):
+ """
+ Handler for chat operations.
+
+ Composes SearchHandler and AddHandler to provide complete chat functionality
+ without depending on mos_server. All chat logic is centralized here.
+ """
+
+ def __init__(
+ self,
+ dependencies: HandlerDependencies,
+ search_handler=None,
+ add_handler=None,
+ online_bot=None,
+ ):
+ """
+ Initialize chat handler.
+
+ Args:
+ dependencies: HandlerDependencies instance
+ search_handler: Optional SearchHandler instance (created if not provided)
+ add_handler: Optional AddHandler instance (created if not provided)
+ online_bot: Optional DingDing bot function for notifications
+ """
+ super().__init__(dependencies)
+ self._validate_dependencies("llm", "naive_mem_cube", "mem_reader", "mem_scheduler")
+
+ # Lazy import to avoid circular dependencies
+ if search_handler is None:
+ from memos.api.handlers.search_handler import SearchHandler
+
+ search_handler = SearchHandler(dependencies)
+
+ if add_handler is None:
+ from memos.api.handlers.add_handler import AddHandler
+
+ add_handler = AddHandler(dependencies)
+
+ self.search_handler = search_handler
+ self.add_handler = add_handler
+ self.online_bot = online_bot
+
+ # Check if scheduler is enabled
+ self.enable_mem_scheduler = (
+ hasattr(dependencies, "enable_mem_scheduler") and dependencies.enable_mem_scheduler
+ )
+
+ def handle_chat_complete(self, chat_req: APIChatCompleteRequest) -> dict[str, Any]:
+ """
+ Chat with MemOS for complete response (non-streaming).
+
+ This implementation directly uses search/add handlers instead of mos_server.
+
+ Args:
+ chat_req: Chat complete request
+
+ Returns:
+ Dictionary with response and references
+
+ Raises:
+ HTTPException: If chat fails
+ """
+ try:
+ import time
+
+ time_start = time.time()
+
+ # Step 1: Search for relevant memories
+ search_req = APISearchRequest(
+ user_id=chat_req.user_id,
+ mem_cube_id=chat_req.mem_cube_id,
+ query=chat_req.query,
+ top_k=chat_req.top_k or 10,
+ session_id=chat_req.session_id,
+ mode=SearchMode.FAST,
+ internet_search=chat_req.internet_search,
+ moscube=chat_req.moscube,
+ chat_history=chat_req.history,
+ )
+
+ search_response = self.search_handler.handle_search_memories(search_req)
+
+ # Extract memories from search results
+ memories_list = []
+ if search_response.data and search_response.data.get("text_mem"):
+ text_mem_results = search_response.data["text_mem"]
+ if text_mem_results and text_mem_results[0].get("memories"):
+ memories_list = text_mem_results[0]["memories"]
+
+ # Filter memories by threshold
+ filtered_memories = self._filter_memories_by_threshold(
+ memories_list, chat_req.threshold or 0.5
+ )
+
+ # Step 2: Build system prompt
+ system_prompt = self._build_system_prompt(filtered_memories, chat_req.base_prompt)
+
+ # Prepare message history
+ history_info = chat_req.history[-20:] if chat_req.history else []
+ current_messages = [
+ {"role": "system", "content": system_prompt},
+ *history_info,
+ {"role": "user", "content": chat_req.query},
+ ]
+
+ self.logger.info("Starting to generate complete response...")
+
+ # Step 3: Generate complete response from LLM
+ response = self.llm.generate(current_messages)
+
+ time_end = time.time()
+
+ # Step 4: Start post-chat processing asynchronously
+ self._start_post_chat_processing(
+ user_id=chat_req.user_id,
+ cube_id=chat_req.mem_cube_id,
+ session_id=chat_req.session_id or "default_session",
+ query=chat_req.query,
+ full_response=response,
+ system_prompt=system_prompt,
+ time_start=time_start,
+ time_end=time_end,
+ speed_improvement=0.0,
+ current_messages=current_messages,
+ )
+
+ # Return the complete response
+ return {
+ "message": "Chat completed successfully",
+ "data": {"response": response, "references": filtered_memories},
+ }
+
+ except ValueError as err:
+ raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
+ except Exception as err:
+ self.logger.error(f"Failed to complete chat: {traceback.format_exc()}")
+ raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
+
+ def handle_chat_stream(self, chat_req: ChatRequest) -> StreamingResponse:
+ """
+ Chat with MemOS via Server-Sent Events (SSE) stream using search/add handlers.
+
+ This implementation directly uses search_handler and add_handler.
+
+ Args:
+ chat_req: Chat stream request
+
+ Returns:
+ StreamingResponse with SSE formatted chat stream
+
+ Raises:
+ HTTPException: If stream initialization fails
+ """
+ try:
+
+ def generate_chat_response() -> Generator[str, None, None]:
+ """Generate chat response as SSE stream."""
+ try:
+ import time
+
+ time_start = time.time()
+
+ # Step 1: Search for memories using search handler
+ yield f"data: {json.dumps({'type': 'status', 'data': '0'})}\n\n"
+
+ search_req = APISearchRequest(
+ user_id=chat_req.user_id,
+ mem_cube_id=chat_req.mem_cube_id,
+ query=chat_req.query,
+ top_k=20,
+ session_id=chat_req.session_id,
+ mode=SearchMode.FAST,
+                        internet_search=chat_req.internet_search,  # TODO: this param currently has no effect in fine mode
+ moscube=chat_req.moscube,
+ chat_history=chat_req.history,
+ )
+
+ search_response = self.search_handler.handle_search_memories(search_req)
+
+ yield f"data: {json.dumps({'type': 'status', 'data': '1'})}\n\n"
+ self._send_message_to_scheduler(
+ user_id=chat_req.user_id,
+ mem_cube_id=chat_req.mem_cube_id,
+ query=chat_req.query,
+ label=QUERY_LABEL,
+ )
+ # Extract memories from search results
+ memories_list = []
+ if search_response.data and search_response.data.get("text_mem"):
+ text_mem_results = search_response.data["text_mem"]
+ if text_mem_results and text_mem_results[0].get("memories"):
+ memories_list = text_mem_results[0]["memories"]
+
+ # Filter memories by threshold
+ filtered_memories = self._filter_memories_by_threshold(memories_list)
+
+ # Prepare reference data
+ reference = prepare_reference_data(filtered_memories)
+ yield f"data: {json.dumps({'type': 'reference', 'data': reference})}\n\n"
+
+ # Step 2: Build system prompt with memories
+ system_prompt = self._build_enhance_system_prompt(filtered_memories)
+
+ # Prepare messages
+ history_info = chat_req.history[-20:] if chat_req.history else []
+ current_messages = [
+ {"role": "system", "content": system_prompt},
+ *history_info,
+ {"role": "user", "content": chat_req.query},
+ ]
+
+ self.logger.info(
+ f"user_id: {chat_req.user_id}, cube_id: {chat_req.mem_cube_id}, "
+ f"current_system_prompt: {system_prompt}"
+ )
+
+ yield f"data: {json.dumps({'type': 'status', 'data': '2'})}\n\n"
+
+ # Step 3: Generate streaming response from LLM
+ response_stream = self.llm.generate_stream(current_messages)
+
+ # Stream the response
+ buffer = ""
+ full_response = ""
+
+ for chunk in response_stream:
+ if chunk in ["", ""]:
+ continue
+
+ buffer += chunk
+ full_response += chunk
+
+ # Process buffer to ensure complete reference tags
+ processed_chunk, remaining_buffer = process_streaming_references_complete(
+ buffer
+ )
+
+ if processed_chunk:
+ chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n"
+ yield chunk_data
+ buffer = remaining_buffer
+
+ # Process any remaining buffer
+ if buffer:
+ processed_chunk, _ = process_streaming_references_complete(buffer)
+ if processed_chunk:
+ chunk_data = f"data: {json.dumps({'type': 'text', 'data': processed_chunk}, ensure_ascii=False)}\n\n"
+ yield chunk_data
+
+ # Calculate timing
+ time_end = time.time()
+ speed_improvement = round(float((len(system_prompt) / 2) * 0.0048 + 44.5), 1)
+ total_time = round(float(time_end - time_start), 1)
+
+ yield f"data: {json.dumps({'type': 'time', 'data': {'total_time': total_time, 'speed_improvement': f'{speed_improvement}%'}})}\n\n"
+
+ # Get further suggestion
+ current_messages.append({"role": "assistant", "content": full_response})
+ further_suggestion = self._get_further_suggestion(current_messages)
+ self.logger.info(f"further_suggestion: {further_suggestion}")
+ yield f"data: {json.dumps({'type': 'suggestion', 'data': further_suggestion})}\n\n"
+
+ yield f"data: {json.dumps({'type': 'end'})}\n\n"
+
+ # Step 4: Add conversation to memory asynchronously
+ self._start_post_chat_processing(
+ user_id=chat_req.user_id,
+ cube_id=chat_req.mem_cube_id,
+ session_id=chat_req.session_id or "default_session",
+ query=chat_req.query,
+ full_response=full_response,
+ system_prompt=system_prompt,
+ time_start=time_start,
+ time_end=time_end,
+ speed_improvement=speed_improvement,
+ current_messages=current_messages,
+ )
+
+ except Exception as e:
+ self.logger.error(f"Error in chat stream: {e}", exc_info=True)
+ error_data = f"data: {json.dumps({'type': 'error', 'content': str(traceback.format_exc())})}\n\n"
+ yield error_data
+
+ return StreamingResponse(
+ generate_chat_response(),
+ media_type="text/event-stream",
+ headers={
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ "Content-Type": "text/event-stream",
+ "Access-Control-Allow-Origin": "*",
+ "Access-Control-Allow-Headers": "*",
+ "Access-Control-Allow-Methods": "*",
+ },
+ )
+
+ except ValueError as err:
+ raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
+ except Exception as err:
+ self.logger.error(f"Failed to start chat stream: {traceback.format_exc()}")
+ raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
+
+ def _build_system_prompt(
+ self,
+ memories: list | None = None,
+ base_prompt: str | None = None,
+ **kwargs,
+ ) -> str:
+ """Build system prompt with optional memories context."""
+ if base_prompt is None:
+ base_prompt = (
+ "You are a knowledgeable and helpful AI assistant. "
+ "You have access to conversation memories that help you provide more personalized responses. "
+ "Use the memories to understand the user's context, preferences, and past interactions. "
+ "If memories are provided, reference them naturally when relevant, but don't explicitly mention having memories."
+ )
+
+ memory_context = ""
+ if memories:
+ memory_list = []
+ for i, memory in enumerate(memories, 1):
+ text_memory = memory.get("memory", "")
+ memory_list.append(f"{i}. {text_memory}")
+ memory_context = "\n".join(memory_list)
+
+ if "{memories}" in base_prompt:
+ return base_prompt.format(memories=memory_context)
+ elif base_prompt and memories:
+ # For backward compatibility, append memories if no placeholder is found
+ memory_context_with_header = "\n\n## Memories:\n" + memory_context
+ return base_prompt + memory_context_with_header
+ return base_prompt
+
+ def _build_enhance_system_prompt(
+ self,
+ memories_list: list,
+ tone: str = "friendly",
+ verbosity: str = "mid",
+ ) -> str:
+ """
+ Build enhanced system prompt with memories (for streaming response).
+
+ Args:
+ memories_list: List of memory items
+ tone: Tone of the prompt
+ verbosity: Verbosity level
+
+ Returns:
+ System prompt string
+ """
+ now = datetime.now()
+ formatted_date = now.strftime("%Y-%m-%d (%A)")
+ sys_body = get_memos_prompt(
+ date=formatted_date, tone=tone, verbosity=verbosity, mode="enhance"
+ )
+
+ # Format memories
+ mem_block_o, mem_block_p = self._format_mem_block(memories_list)
+
+ return (
+ sys_body
+ + "\n\n# Memories\n## PersonalMemory (ordered)\n"
+ + mem_block_p
+ + "\n## OuterMemory (ordered)\n"
+ + mem_block_o
+ )
+
+ def _format_mem_block(
+ self, memories_all: list, max_items: int = 20, max_chars_each: int = 320
+ ) -> tuple[str, str]:
+ """
+ Format memory block for prompt.
+
+ Args:
+ memories_all: List of memory items
+ max_items: Maximum number of items to format
+ max_chars_each: Maximum characters per item
+
+ Returns:
+ Tuple of (outer_memory_block, personal_memory_block)
+ """
+ if not memories_all:
+ return "(none)", "(none)"
+
+ lines_o = []
+ lines_p = []
+
+ for idx, m in enumerate(memories_all[:max_items], 1):
+ mid = m.get("id", "").split("-")[0] if m.get("id") else f"mem_{idx}"
+ memory_content = m.get("memory", "")
+ metadata = m.get("metadata", {})
+ memory_type = metadata.get("memory_type", "")
+
+ tag = "O" if "Outer" in str(memory_type) else "P"
+ txt = memory_content.replace("\n", " ").strip()
+ if len(txt) > max_chars_each:
+ txt = txt[: max_chars_each - 1] + "…"
+
+ mid = mid or f"mem_{idx}"
+ if tag == "O":
+                lines_o.append(f"[{idx}:{mid}] :: [{tag}] {txt}")
+ elif tag == "P":
+ lines_p.append(f"[{idx}:{mid}] :: [{tag}] {txt}")
+
+ return "\n".join(lines_o), "\n".join(lines_p)
+
+ def _filter_memories_by_threshold(
+ self,
+ memories: list,
+ threshold: float = 0.30,
+ min_num: int = 3,
+ memory_type: Literal["OuterMemory"] = "OuterMemory",
+ ) -> list:
+ """
+ Filter memories by threshold and type.
+
+ Args:
+ memories: List of memory items
+ threshold: Relevance threshold
+ min_num: Minimum number of memories to keep
+ memory_type: Memory type to filter
+
+ Returns:
+ Filtered list of memories
+ """
+ if not memories:
+ return []
+
+ # Handle dict format (from search results)
+ def get_relativity(m):
+ if isinstance(m, dict):
+ return m.get("metadata", {}).get("relativity", 0.0)
+ return getattr(getattr(m, "metadata", None), "relativity", 0.0)
+
+ def get_memory_type(m):
+ if isinstance(m, dict):
+ return m.get("metadata", {}).get("memory_type", "")
+ return getattr(getattr(m, "metadata", None), "memory_type", "")
+
+ sorted_memories = sorted(memories, key=get_relativity, reverse=True)
+ filtered_person = [m for m in memories if get_memory_type(m) != memory_type]
+ filtered_outer = [m for m in memories if get_memory_type(m) == memory_type]
+
+ filtered = []
+ per_memory_count = 0
+
+ for m in sorted_memories:
+ if get_relativity(m) >= threshold:
+ if get_memory_type(m) != memory_type:
+ per_memory_count += 1
+ filtered.append(m)
+
+ if len(filtered) < min_num:
+ filtered = filtered_person[:min_num] + filtered_outer[:min_num]
+ else:
+ if per_memory_count < min_num:
+ filtered += filtered_person[per_memory_count:min_num]
+
+ filtered_memory = sorted(filtered, key=get_relativity, reverse=True)
+ return filtered_memory
+
+ def _get_further_suggestion(
+ self,
+ current_messages: MessageList,
+ ) -> list[str]:
+ """Get further suggestion based on current messages."""
+ try:
+ dialogue_info = "\n".join(
+ [f"{msg['role']}: {msg['content']}" for msg in current_messages[-2:]]
+ )
+ further_suggestion_prompt = FURTHER_SUGGESTION_PROMPT.format(dialogue=dialogue_info)
+ message_list = [{"role": "system", "content": further_suggestion_prompt}]
+ response = self.llm.generate(message_list)
+ clean_response = clean_json_response(response)
+ response_json = json.loads(clean_response)
+ return response_json["query"]
+ except Exception as e:
+ self.logger.error(f"Error getting further suggestion: {e}", exc_info=True)
+ return []
+
+ def _extract_references_from_response(self, response: str) -> tuple[str, list[dict]]:
+ """Extract reference information from the response and return clean text."""
+ import re
+
+ try:
+ references = []
+            # Pattern to match reference markers of the form [ref_number:memory_id]
+ pattern = r"\[(\d+):([^\]]+)\]"
+
+ matches = re.findall(pattern, response)
+ for ref_number, memory_id in matches:
+ references.append({"memory_id": memory_id, "reference_number": int(ref_number)})
+
+ # Remove all reference markers from the text to get clean text
+ clean_text = re.sub(pattern, "", response)
+
+ # Clean up any extra whitespace that might be left after removing markers
+ clean_text = re.sub(r"\s+", " ", clean_text).strip()
+
+ return clean_text, references
+ except Exception as e:
+ self.logger.error(f"Error extracting references from response: {e}", exc_info=True)
+ return response, []
+
+ def _extract_struct_data_from_history(self, chat_data: list[dict]) -> dict:
+ """
+ Extract structured message data from chat history.
+
+ Args:
+ chat_data: List of chat messages
+
+ Returns:
+ Dictionary with system, memory, and chat_history
+ """
+ system_content = ""
+ memory_content = ""
+ chat_history = []
+
+ for item in chat_data:
+ role = item.get("role")
+ content = item.get("content", "")
+ if role == "system":
+ parts = content.split("# Memories", 1)
+ system_content = parts[0].strip()
+ if len(parts) > 1:
+ memory_content = "# Memories" + parts[1].strip()
+ elif role in ("user", "assistant"):
+ chat_history.append({"role": role, "content": content})
+
+ if chat_history and chat_history[-1]["role"] == "assistant":
+ if len(chat_history) >= 2 and chat_history[-2]["role"] == "user":
+ chat_history = chat_history[:-2]
+ else:
+ chat_history = chat_history[:-1]
+
+ return {"system": system_content, "memory": memory_content, "chat_history": chat_history}
+
+ def _send_message_to_scheduler(
+ self,
+ user_id: str,
+ mem_cube_id: str,
+ query: str,
+ label: str,
+ ) -> None:
+ """
+ Send message to scheduler.
+
+ Args:
+ user_id: User ID
+ mem_cube_id: Memory cube ID
+ query: Query content
+ label: Message label
+ """
+ try:
+ message_item = ScheduleMessageItem(
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ label=label,
+ content=query,
+ timestamp=datetime.utcnow(),
+ )
+ self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item])
+ self.logger.info(f"Sent message to scheduler with label: {label}")
+ except Exception as e:
+ self.logger.error(f"Failed to send message to scheduler: {e}", exc_info=True)
+
+ async def _post_chat_processing(
+ self,
+ user_id: str,
+ cube_id: str,
+ session_id: str,
+ query: str,
+ full_response: str,
+ system_prompt: str,
+ time_start: float,
+ time_end: float,
+ speed_improvement: float,
+ current_messages: list,
+ ) -> None:
+ """
+ Asynchronous post-chat processing with complete functionality.
+
+ Includes:
+ - Reference extraction
+ - DingDing notification
+ - Scheduler messaging
+ - Memory addition
+
+ Args:
+ user_id: User ID
+ cube_id: Memory cube ID
+ session_id: Session ID
+ query: User query
+ full_response: Full LLM response
+ system_prompt: System prompt used
+ time_start: Start timestamp
+ time_end: End timestamp
+ speed_improvement: Speed improvement metric
+ current_messages: Current message history
+ """
+ try:
+ self.logger.info(
+ f"user_id: {user_id}, cube_id: {cube_id}, current_messages: {current_messages}"
+ )
+ self.logger.info(
+ f"user_id: {user_id}, cube_id: {cube_id}, full_response: {full_response}"
+ )
+
+ # Extract references and clean response
+ clean_response, extracted_references = self._extract_references_from_response(
+ full_response
+ )
+ struct_message = self._extract_struct_data_from_history(current_messages)
+ self.logger.info(f"Extracted {len(extracted_references)} references from response")
+
+ # Send DingDing notification if enabled
+ if self.online_bot:
+ self.logger.info("Online Bot Open!")
+ try:
+ from memos.memos_tools.notification_utils import (
+ send_online_bot_notification_async,
+ )
+
+ # Prepare notification data
+ chat_data = {"query": query, "user_id": user_id, "cube_id": cube_id}
+ chat_data.update(
+ {
+ "memory": struct_message["memory"],
+ "chat_history": struct_message["chat_history"],
+ "full_response": full_response,
+ }
+ )
+
+ system_data = {
+ "references": extracted_references,
+ "time_start": time_start,
+ "time_end": time_end,
+ "speed_improvement": speed_improvement,
+ }
+
+ emoji_config = {"chat": "💬", "system_info": "📊"}
+
+ await send_online_bot_notification_async(
+ online_bot=self.online_bot,
+ header_name="MemOS Chat Report",
+ sub_title_name="chat_with_references",
+ title_color="#00956D",
+ other_data1=chat_data,
+ other_data2=system_data,
+ emoji=emoji_config,
+ )
+ except Exception as e:
+ self.logger.warning(f"Failed to send chat notification (async): {e}")
+
+ # Send answer to scheduler
+ self._send_message_to_scheduler(
+ user_id=user_id, mem_cube_id=cube_id, query=clean_response, label=ANSWER_LABEL
+ )
+
+ # Add conversation to memory using add handler
+ add_req = APIADDRequest(
+ user_id=user_id,
+ mem_cube_id=cube_id,
+ session_id=session_id,
+ messages=[
+ {
+ "role": "user",
+ "content": query,
+ "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
+ },
+ {
+ "role": "assistant",
+ "content": clean_response, # Store clean text without reference markers
+ "chat_time": str(datetime.now().strftime("%Y-%m-%d %H:%M:%S")),
+ },
+ ],
+ async_mode="sync", # set suync for playground
+ )
+
+ self.add_handler.handle_add_memories(add_req)
+
+ self.logger.info(f"Post-chat processing completed for user {user_id}")
+
+ except Exception as e:
+ self.logger.error(
+ f"Error in post-chat processing for user {user_id}: {e}", exc_info=True
+ )
+
+ def _start_post_chat_processing(
+ self,
+ user_id: str,
+ cube_id: str,
+ session_id: str,
+ query: str,
+ full_response: str,
+ system_prompt: str,
+ time_start: float,
+ time_end: float,
+ speed_improvement: float,
+ current_messages: list,
+ ) -> None:
+ """
+ Start asynchronous post-chat processing in a background thread.
+
+ Args:
+ user_id: User ID
+ cube_id: Memory cube ID
+ session_id: Session ID
+ query: User query
+ full_response: Full LLM response
+ system_prompt: System prompt used
+ time_start: Start timestamp
+ time_end: End timestamp
+ speed_improvement: Speed improvement metric
+ current_messages: Current message history
+ """
+
+ def run_async_in_thread():
+ """Running asynchronous tasks in a new thread"""
+ try:
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ try:
+ loop.run_until_complete(
+ self._post_chat_processing(
+ user_id=user_id,
+ cube_id=cube_id,
+ session_id=session_id,
+ query=query,
+ full_response=full_response,
+ system_prompt=system_prompt,
+ time_start=time_start,
+ time_end=time_end,
+ speed_improvement=speed_improvement,
+ current_messages=current_messages,
+ )
+ )
+ finally:
+ loop.close()
+ except Exception as e:
+ self.logger.error(
+ f"Error in thread-based post-chat processing for user {user_id}: {e}",
+ exc_info=True,
+ )
+
+ try:
+ # Try to get the current event loop
+ asyncio.get_running_loop()
+ # Create task and store reference to prevent garbage collection
+ task = asyncio.create_task(
+ self._post_chat_processing(
+ user_id=user_id,
+ cube_id=cube_id,
+ session_id=session_id,
+ query=query,
+ full_response=full_response,
+ system_prompt=system_prompt,
+ time_start=time_start,
+ time_end=time_end,
+ speed_improvement=speed_improvement,
+ current_messages=current_messages,
+ )
+ )
+ # Add exception handling for the background task
+ task.add_done_callback(
+ lambda t: self.logger.error(
+ f"Error in background post-chat processing for user {user_id}: {t.exception()}",
+ exc_info=True,
+ )
+ if t.exception()
+ else None
+ )
+ except RuntimeError:
+ # No event loop, run in a new thread with context propagation
+ thread = ContextThread(
+ target=run_async_in_thread,
+ name=f"PostChatProcessing-{user_id}",
+ daemon=True,
+ )
+ thread.start()
diff --git a/src/memos/api/handlers/component_init.py b/src/memos/api/handlers/component_init.py
new file mode 100644
index 000000000..89e61e79d
--- /dev/null
+++ b/src/memos/api/handlers/component_init.py
@@ -0,0 +1,296 @@
+"""
+Server component initialization module.
+
+This module handles the initialization of all MemOS server components
+including databases, LLMs, memory systems, and schedulers.
+"""
+
+import os
+
+from typing import TYPE_CHECKING, Any
+
+from memos.api.config import APIConfig
+from memos.api.handlers.config_builders import (
+ build_embedder_config,
+ build_graph_db_config,
+ build_internet_retriever_config,
+ build_llm_config,
+ build_mem_reader_config,
+ build_pref_adder_config,
+ build_pref_extractor_config,
+ build_pref_retriever_config,
+ build_reranker_config,
+ build_vec_db_config,
+)
+from memos.configs.mem_scheduler import SchedulerConfigFactory
+from memos.embedders.factory import EmbedderFactory
+from memos.graph_dbs.factory import GraphStoreFactory
+from memos.llms.factory import LLMFactory
+from memos.log import get_logger
+from memos.mem_cube.navie import NaiveMemCube
+from memos.mem_os.product_server import MOSServer
+from memos.mem_reader.factory import MemReaderFactory
+from memos.mem_scheduler.orm_modules.base_model import BaseDBManager
+from memos.mem_scheduler.scheduler_factory import SchedulerFactory
+from memos.memories.textual.prefer_text_memory.factory import (
+ AdderFactory,
+ ExtractorFactory,
+ RetrieverFactory,
+)
+from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
+from memos.memories.textual.simple_tree import SimpleTreeTextMemory
+from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
+
+
+if TYPE_CHECKING:
+ from memos.memories.textual.tree import TreeTextMemory
+from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import (
+ InternetRetrieverFactory,
+)
+from memos.reranker.factory import RerankerFactory
+from memos.vec_dbs.factory import VecDBFactory
+
+
+if TYPE_CHECKING:
+ from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler
+ from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
+logger = get_logger(__name__)
+
+
+def _get_default_memory_size(cube_config: Any) -> dict[str, int]:
+ """
+ Get default memory size configuration.
+
+ Attempts to retrieve memory size from cube config, falls back to defaults
+ if not found.
+
+ Args:
+ cube_config: The cube configuration object
+
+ Returns:
+ Dictionary with memory sizes for different memory types
+ """
+ return getattr(cube_config.text_mem.config, "memory_size", None) or {
+ "WorkingMemory": 20,
+ "LongTermMemory": 1500,
+ "UserMemory": 480,
+ }
+
+
+def init_server() -> dict[str, Any]:
+ """
+ Initialize all server components and configurations.
+
+ This function orchestrates the creation and initialization of all components
+ required by the MemOS server, including:
+ - Database connections (graph DB, vector DB)
+ - Language models and embedders
+ - Memory systems (text, preference)
+ - Scheduler and related modules
+
+ Returns:
+ A dictionary containing all initialized components with descriptive keys.
+ This approach allows easy addition of new components without breaking
+ existing code that uses the components.
+ """
+ logger.info("Initializing MemOS server components...")
+
+ # Get default cube configuration
+ default_cube_config = APIConfig.get_default_cube_config()
+
+ # Get online bot setting
+ dingding_enabled = APIConfig.is_dingding_bot_enabled()
+
+ # Build component configurations
+ graph_db_config = build_graph_db_config()
+ llm_config = build_llm_config()
+ embedder_config = build_embedder_config()
+ mem_reader_config = build_mem_reader_config()
+ reranker_config = build_reranker_config()
+ internet_retriever_config = build_internet_retriever_config()
+ vector_db_config = build_vec_db_config()
+ pref_extractor_config = build_pref_extractor_config()
+ pref_adder_config = build_pref_adder_config()
+ pref_retriever_config = build_pref_retriever_config()
+
+ logger.debug("Component configurations built successfully")
+
+ # Create component instances
+ graph_db = GraphStoreFactory.from_config(graph_db_config)
+ vector_db = (
+ VecDBFactory.from_config(vector_db_config)
+ if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true"
+ else None
+ )
+ llm = LLMFactory.from_config(llm_config)
+ embedder = EmbedderFactory.from_config(embedder_config)
+ mem_reader = MemReaderFactory.from_config(mem_reader_config)
+ reranker = RerankerFactory.from_config(reranker_config)
+ internet_retriever = InternetRetrieverFactory.from_config(
+ internet_retriever_config, embedder=embedder
+ )
+
+ logger.debug("Core components instantiated")
+
+ # Initialize memory manager
+ memory_manager = MemoryManager(
+ graph_db,
+ embedder,
+ llm,
+ memory_size=_get_default_memory_size(default_cube_config),
+ is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False),
+ )
+
+ logger.debug("Memory manager initialized")
+
+ # Initialize text memory
+ text_mem = SimpleTreeTextMemory(
+ llm=llm,
+ embedder=embedder,
+ mem_reader=mem_reader,
+ graph_db=graph_db,
+ reranker=reranker,
+ memory_manager=memory_manager,
+ config=default_cube_config.text_mem.config,
+ internet_retriever=internet_retriever,
+ )
+
+ logger.debug("Text memory initialized")
+
+ # Initialize preference memory components
+ pref_extractor = (
+ ExtractorFactory.from_config(
+ config_factory=pref_extractor_config,
+ llm_provider=llm,
+ embedder=embedder,
+ vector_db=vector_db,
+ )
+ if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true"
+ else None
+ )
+
+ pref_adder = (
+ AdderFactory.from_config(
+ config_factory=pref_adder_config,
+ llm_provider=llm,
+ embedder=embedder,
+ vector_db=vector_db,
+ text_mem=text_mem,
+ )
+ if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true"
+ else None
+ )
+
+ pref_retriever = (
+ RetrieverFactory.from_config(
+ config_factory=pref_retriever_config,
+ llm_provider=llm,
+ embedder=embedder,
+ reranker=reranker,
+ vector_db=vector_db,
+ )
+ if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true"
+ else None
+ )
+
+ logger.debug("Preference memory components initialized")
+
+ # Initialize preference memory
+ pref_mem = (
+ SimplePreferenceTextMemory(
+ extractor_llm=llm,
+ vector_db=vector_db,
+ embedder=embedder,
+ reranker=reranker,
+ extractor=pref_extractor,
+ adder=pref_adder,
+ retriever=pref_retriever,
+ )
+ if os.getenv("ENABLE_PREFERENCE_MEMORY", "false") == "true"
+ else None
+ )
+
+ logger.debug("Preference memory initialized")
+
+ # Initialize MOS Server
+ mos_server = MOSServer(
+ mem_reader=mem_reader,
+ llm=llm,
+ online_bot=False,
+ )
+
+ logger.debug("MOS server initialized")
+
+ # Create MemCube with pre-initialized memory instances
+ naive_mem_cube = NaiveMemCube(
+ text_mem=text_mem,
+ pref_mem=pref_mem,
+ act_mem=None,
+ para_mem=None,
+ )
+
+ logger.debug("MemCube created")
+
+ tree_mem: TreeTextMemory = naive_mem_cube.text_mem
+ searcher: Searcher = tree_mem.get_searcher(
+ manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
+ moscube=False,
+ )
+ logger.debug("Searcher created")
+
+ # Initialize Scheduler
+ scheduler_config_dict = APIConfig.get_scheduler_config()
+ scheduler_config = SchedulerConfigFactory(
+ backend="optimized_scheduler", config=scheduler_config_dict
+ )
+ mem_scheduler: OptimizedScheduler = SchedulerFactory.from_config(scheduler_config)
+ mem_scheduler.initialize_modules(
+ chat_llm=llm,
+ process_llm=mem_reader.llm,
+ db_engine=BaseDBManager.create_default_sqlite_engine(),
+ mem_reader=mem_reader,
+ )
+ mem_scheduler.init_mem_cube(mem_cube=naive_mem_cube, searcher=searcher)
+ logger.debug("Scheduler initialized")
+
+ # Initialize SchedulerAPIModule
+ api_module = mem_scheduler.api_module
+
+ # Start scheduler if enabled
+ if os.getenv("API_SCHEDULER_ON", "true").lower() == "true":
+ mem_scheduler.start()
+ logger.info("Scheduler started")
+
+ logger.info("MemOS server components initialized successfully")
+
+ # Initialize online bot if enabled
+ online_bot = None
+ if dingding_enabled:
+ from memos.memos_tools.notification_service import get_online_bot_function
+
+        online_bot = get_online_bot_function()
+ logger.info("DingDing bot is enabled")
+
+ # Return all components as a dictionary for easy access and extension
+ return {
+ "graph_db": graph_db,
+ "mem_reader": mem_reader,
+ "llm": llm,
+ "embedder": embedder,
+ "reranker": reranker,
+ "internet_retriever": internet_retriever,
+ "memory_manager": memory_manager,
+ "default_cube_config": default_cube_config,
+ "mos_server": mos_server,
+ "mem_scheduler": mem_scheduler,
+ "naive_mem_cube": naive_mem_cube,
+ "searcher": searcher,
+ "api_module": api_module,
+ "vector_db": vector_db,
+ "pref_extractor": pref_extractor,
+ "pref_adder": pref_adder,
+ "pref_retriever": pref_retriever,
+ "text_mem": text_mem,
+ "pref_mem": pref_mem,
+ "online_bot": online_bot,
+ }
diff --git a/src/memos/api/handlers/config_builders.py b/src/memos/api/handlers/config_builders.py
new file mode 100644
index 000000000..9f510add0
--- /dev/null
+++ b/src/memos/api/handlers/config_builders.py
@@ -0,0 +1,153 @@
+"""
+Configuration builders for server handlers.
+
+This module contains factory functions that build configurations for various
+components used by the MemOS server. Each function constructs and validates
+a configuration dictionary using the appropriate ConfigFactory.
+"""
+
+import os
+
+from typing import Any
+
+from memos.api.config import APIConfig
+from memos.configs.embedder import EmbedderConfigFactory
+from memos.configs.graph_db import GraphDBConfigFactory
+from memos.configs.internet_retriever import InternetRetrieverConfigFactory
+from memos.configs.llm import LLMConfigFactory
+from memos.configs.mem_reader import MemReaderConfigFactory
+from memos.configs.reranker import RerankerConfigFactory
+from memos.configs.vec_db import VectorDBConfigFactory
+from memos.memories.textual.prefer_text_memory.config import (
+ AdderConfigFactory,
+ ExtractorConfigFactory,
+ RetrieverConfigFactory,
+)
+
+
+def build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
+ """
+ Build graph database configuration.
+
+ Args:
+ user_id: User ID for configuration context (default: "default")
+
+ Returns:
+ Validated graph database configuration dictionary
+ """
+ graph_db_backend_map = {
+ "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id),
+ "neo4j": APIConfig.get_neo4j_config(user_id=user_id),
+ "nebular": APIConfig.get_nebular_config(user_id=user_id),
+ "polardb": APIConfig.get_polardb_config(user_id=user_id),
+ }
+
+ graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower()
+ return GraphDBConfigFactory.model_validate(
+ {
+ "backend": graph_db_backend,
+ "config": graph_db_backend_map[graph_db_backend],
+ }
+ )
+
+
+def build_vec_db_config() -> dict[str, Any]:
+ """
+ Build vector database configuration.
+
+ Returns:
+ Validated vector database configuration dictionary
+ """
+ return VectorDBConfigFactory.model_validate(
+ {
+ "backend": "milvus",
+ "config": APIConfig.get_milvus_config(),
+ }
+ )
+
+
+def build_llm_config() -> dict[str, Any]:
+ """
+ Build LLM configuration.
+
+ Returns:
+ Validated LLM configuration dictionary
+ """
+ return LLMConfigFactory.model_validate(
+ {
+ "backend": "openai",
+ "config": APIConfig.get_openai_config(),
+ }
+ )
+
+
+def build_embedder_config() -> dict[str, Any]:
+ """
+ Build embedder configuration.
+
+ Returns:
+ Validated embedder configuration dictionary
+ """
+ return EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config())
+
+
+def build_mem_reader_config() -> dict[str, Any]:
+ """
+ Build memory reader configuration.
+
+ Returns:
+ Validated memory reader configuration dictionary
+ """
+ return MemReaderConfigFactory.model_validate(
+ APIConfig.get_product_default_config()["mem_reader"]
+ )
+
+
+def build_reranker_config() -> dict[str, Any]:
+ """
+ Build reranker configuration.
+
+ Returns:
+ Validated reranker configuration dictionary
+ """
+ return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config())
+
+
+def build_internet_retriever_config() -> dict[str, Any]:
+ """
+ Build internet retriever configuration.
+
+ Returns:
+ Validated internet retriever configuration dictionary
+ """
+ return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config())
+
+
+def build_pref_extractor_config() -> dict[str, Any]:
+ """
+ Build preference memory extractor configuration.
+
+ Returns:
+ Validated extractor configuration dictionary
+ """
+ return ExtractorConfigFactory.model_validate({"backend": "naive", "config": {}})
+
+
+def build_pref_adder_config() -> dict[str, Any]:
+ """
+ Build preference memory adder configuration.
+
+ Returns:
+ Validated adder configuration dictionary
+ """
+ return AdderConfigFactory.model_validate({"backend": "naive", "config": {}})
+
+
+def build_pref_retriever_config() -> dict[str, Any]:
+ """
+ Build preference memory retriever configuration.
+
+ Returns:
+ Validated retriever configuration dictionary
+ """
+ return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}})
diff --git a/src/memos/api/handlers/formatters_handler.py b/src/memos/api/handlers/formatters_handler.py
new file mode 100644
index 000000000..976be87bb
--- /dev/null
+++ b/src/memos/api/handlers/formatters_handler.py
@@ -0,0 +1,92 @@
+"""
+Data formatting utilities for server handlers.
+
+This module provides utility functions for formatting and transforming data
+structures for API responses, including memory items and preferences.
+"""
+
+from typing import Any
+
+from memos.templates.instruction_completion import instruct_completion
+
+
+def to_iter(running: Any) -> list[Any]:
+ """
+ Normalize running tasks to a list of task objects.
+
+ Handles different input types and converts them to a consistent list format.
+
+ Args:
+ running: Running tasks, can be None, dict, or iterable
+
+ Returns:
+ List of task objects
+ """
+ if running is None:
+ return []
+ if isinstance(running, dict):
+ return list(running.values())
+ return list(running) if running else []
+
+
+def format_memory_item(memory_data: Any) -> dict[str, Any]:
+ """
+ Format a single memory item for API response.
+
+ Transforms a memory object into a dictionary with metadata properly
+ structured for API consumption.
+
+ Args:
+ memory_data: Memory object to format
+
+ Returns:
+ Formatted memory dictionary with ref_id and metadata
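+
+    Example:
+        A memory with id "1a2b3c4d-..." yields ref_id "[1a2b3c4d]"; heavy fields
+        (embedding, sources, usage) are cleared in the returned metadata.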
+ """
+ memory = memory_data.model_dump()
+ memory_id = memory["id"]
+ ref_id = f"[{memory_id.split('-')[0]}]"
+
+ memory["ref_id"] = ref_id
+ memory["metadata"]["embedding"] = []
+ memory["metadata"]["sources"] = []
+ memory["metadata"]["usage"] = []
+ memory["metadata"]["ref_id"] = ref_id
+ memory["metadata"]["id"] = memory_id
+ memory["metadata"]["memory"] = memory["memory"]
+
+ return memory
+
+
+def post_process_pref_mem(
+ memories_result: dict[str, Any],
+ pref_formatted_mem: list[dict[str, Any]],
+ mem_cube_id: str,
+ include_preference: bool,
+) -> dict[str, Any]:
+ """
+ Post-process preference memory results.
+
+ Adds formatted preference memories to the result dictionary and generates
+ instruction completion strings if preferences are included.
+
+ Args:
+ memories_result: Result dictionary to update
+ pref_formatted_mem: List of formatted preference memories
+ mem_cube_id: Memory cube ID
+ include_preference: Whether to include preferences in result
+
+ Returns:
+ Updated memories_result dictionary
+ """
+ if include_preference:
+ memories_result["pref_mem"].append(
+ {
+ "cube_id": mem_cube_id,
+ "memories": pref_formatted_mem,
+ }
+ )
+ pref_instruction, pref_note = instruct_completion(pref_formatted_mem)
+ memories_result["pref_string"] = pref_instruction
+ memories_result["pref_note"] = pref_note
+
+ return memories_result
diff --git a/src/memos/api/handlers/memory_handler.py b/src/memos/api/handlers/memory_handler.py
new file mode 100644
index 000000000..85f339f3f
--- /dev/null
+++ b/src/memos/api/handlers/memory_handler.py
@@ -0,0 +1,151 @@
+"""
+Memory handler for retrieving and managing memories.
+
+This module handles retrieving all memories or specific subgraphs based on queries.
+"""
+
+from typing import Any, Literal
+
+from memos.api.product_models import MemoryResponse
+from memos.log import get_logger
+from memos.mem_os.utils.format_utils import (
+ convert_graph_to_tree_forworkmem,
+ ensure_unique_tree_ids,
+ filter_nodes_by_tree_ids,
+ remove_embedding_recursive,
+ sort_children_by_memory_type,
+)
+
+
+logger = get_logger(__name__)
+
+
+def handle_get_all_memories(
+ user_id: str,
+ mem_cube_id: str,
+ memory_type: Literal["text_mem", "act_mem", "param_mem", "para_mem"],
+ naive_mem_cube: Any,
+) -> MemoryResponse:
+ """
+ Main handler for getting all memories.
+
+    Retrieves all memories of the specified type for a user and formats them appropriately.
+
+ Args:
+ user_id: User ID
+ mem_cube_id: Memory cube ID
+ memory_type: Type of memory to retrieve
+ naive_mem_cube: Memory cube instance
+
+ Returns:
+ MemoryResponse with formatted memory data
+ """
+ try:
+ reformat_memory_list = []
+
+ if memory_type == "text_mem":
+ # Get all text memories from the graph database
+ memories = naive_mem_cube.text_mem.get_all(user_name=mem_cube_id)
+
+ # Format and convert to tree structure
+ memories_cleaned = remove_embedding_recursive(memories)
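+            # Per-type ratios passed to convert_graph_to_tree_forworkmem (presumably the
+            # share of target_node_count allocated to each memory type; they sum to 1.0).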
+ custom_type_ratios = {
+ "WorkingMemory": 0.20,
+ "LongTermMemory": 0.40,
+ "UserMemory": 0.40,
+ }
+ tree_result, node_type_count = convert_graph_to_tree_forworkmem(
+ memories_cleaned, target_node_count=200, type_ratios=custom_type_ratios
+ )
+ # Ensure all node IDs are unique in the tree structure
+ tree_result = ensure_unique_tree_ids(tree_result)
+ memories_filtered = filter_nodes_by_tree_ids(tree_result, memories_cleaned)
+ children = tree_result["children"]
+ children_sort = sort_children_by_memory_type(children)
+ tree_result["children"] = children_sort
+ memories_filtered["tree_structure"] = tree_result
+
+ reformat_memory_list.append(
+ {
+ "cube_id": mem_cube_id,
+ "memories": [memories_filtered],
+ "memory_statistics": node_type_count,
+ }
+ )
+
+        elif memory_type == "act_mem":
+            logger.warning("Activation memory retrieval not implemented yet.")
+        elif memory_type == "para_mem":
+            logger.warning("Parametric memory retrieval not implemented yet.")
+ return MemoryResponse(
+ message="Memories retrieved successfully",
+ data=reformat_memory_list,
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to get all memories: {e}", exc_info=True)
+ raise
+
+
+def handle_get_subgraph(
+ user_id: str,
+ mem_cube_id: str,
+ query: str,
+ top_k: int,
+ naive_mem_cube: Any,
+) -> MemoryResponse:
+ """
+ Main handler for getting memory subgraph based on query.
+
+ Retrieves relevant memory subgraph and formats it as a tree structure.
+
+ Args:
+ user_id: User ID
+ mem_cube_id: Memory cube ID
+ query: Search query
+ top_k: Number of top results to return
+ naive_mem_cube: Memory cube instance
+
+ Returns:
+ MemoryResponse with formatted subgraph data
+ """
+ try:
+ # Get relevant subgraph from text memory
+ memories = naive_mem_cube.text_mem.get_relevant_subgraph(
+ query, top_k=top_k, user_name=mem_cube_id
+ )
+
+ # Format and convert to tree structure
+ memories_cleaned = remove_embedding_recursive(memories)
+ custom_type_ratios = {
+ "WorkingMemory": 0.20,
+ "LongTermMemory": 0.40,
+ "UserMemory": 0.40,
+ }
+ tree_result, node_type_count = convert_graph_to_tree_forworkmem(
+ memories_cleaned, target_node_count=150, type_ratios=custom_type_ratios
+ )
+ # Ensure all node IDs are unique in the tree structure
+ tree_result = ensure_unique_tree_ids(tree_result)
+ memories_filtered = filter_nodes_by_tree_ids(tree_result, memories_cleaned)
+ children = tree_result["children"]
+ children_sort = sort_children_by_memory_type(children)
+ tree_result["children"] = children_sort
+ memories_filtered["tree_structure"] = tree_result
+
+ reformat_memory_list = [
+ {
+ "cube_id": mem_cube_id,
+ "memories": [memories_filtered],
+ "memory_statistics": node_type_count,
+ }
+ ]
+
+ return MemoryResponse(
+ message="Memories retrieved successfully",
+ data=reformat_memory_list,
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to get subgraph: {e}", exc_info=True)
+ raise
diff --git a/src/memos/api/handlers/scheduler_handler.py b/src/memos/api/handlers/scheduler_handler.py
new file mode 100644
index 000000000..8d3c6dc70
--- /dev/null
+++ b/src/memos/api/handlers/scheduler_handler.py
@@ -0,0 +1,220 @@
+"""
+Scheduler handler for scheduler management functionality.
+
+This module handles all scheduler-related operations including status checking,
+waiting for idle state, and streaming progress updates.
+"""
+
+import json
+import time
+import traceback
+
+from typing import Any
+
+from fastapi import HTTPException
+from fastapi.responses import StreamingResponse
+
+from memos.api.handlers.formatters_handler import to_iter
+from memos.log import get_logger
+
+
+logger = get_logger(__name__)
+
+
+def handle_scheduler_status(
+ user_name: str | None = None,
+ mem_scheduler: Any | None = None,
+ instance_id: str = "",
+) -> dict[str, Any]:
+ """
+ Get scheduler running status.
+
+ Retrieves the number of running tasks for a specific user or globally.
+
+ Args:
+ user_name: Optional specific user name to filter tasks
+ mem_scheduler: Scheduler instance
+ instance_id: Instance ID for response
+
+ Returns:
+ Dictionary with status information
+
+ Raises:
+ HTTPException: If status retrieval fails
+ """
+ try:
+ if user_name:
+ running = mem_scheduler.dispatcher.get_running_tasks(
+ lambda task: getattr(task, "mem_cube_id", None) == user_name
+ )
+ tasks_iter = to_iter(running)
+ running_count = len(tasks_iter)
+ return {
+ "message": "ok",
+ "data": {
+ "scope": "user",
+ "user_name": user_name,
+ "running_tasks": running_count,
+ "timestamp": time.time(),
+ "instance_id": instance_id,
+ },
+ }
+ else:
+ running_all = mem_scheduler.dispatcher.get_running_tasks(lambda _t: True)
+ tasks_iter = to_iter(running_all)
+ running_count = len(tasks_iter)
+
+ task_count_per_user: dict[str, int] = {}
+ for task in tasks_iter:
+ cube = getattr(task, "mem_cube_id", "unknown")
+ task_count_per_user[cube] = task_count_per_user.get(cube, 0) + 1
+
+ try:
+ metrics_snapshot = mem_scheduler.dispatcher.metrics.snapshot()
+ except Exception:
+ metrics_snapshot = {}
+
+ return {
+ "message": "ok",
+ "data": {
+ "scope": "global",
+ "running_tasks": running_count,
+ "task_count_per_user": task_count_per_user,
+ "timestamp": time.time(),
+ "instance_id": instance_id,
+ "metrics": metrics_snapshot,
+ },
+ }
+ except Exception as err:
+ logger.error("Failed to get scheduler status: %s", traceback.format_exc())
+ raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err
+
+
+def handle_scheduler_wait(
+ user_name: str,
+ timeout_seconds: float = 120.0,
+ poll_interval: float = 0.2,
+ mem_scheduler: Any | None = None,
+) -> dict[str, Any]:
+ """
+ Wait until scheduler is idle for a specific user.
+
+    Blocks until the scheduler has no running tasks for the given user, or until the timeout elapses.
+
+ Args:
+ user_name: User name to wait for
+ timeout_seconds: Maximum wait time in seconds
+ poll_interval: Polling interval in seconds
+ mem_scheduler: Scheduler instance
+
+ Returns:
+ Dictionary with wait result and statistics
+
+ Raises:
+ HTTPException: If wait operation fails
+ """
+ start = time.time()
+ try:
+ while True:
+ running = mem_scheduler.dispatcher.get_running_tasks(
+ lambda task: task.mem_cube_id == user_name
+ )
+ running_count = len(running)
+ elapsed = time.time() - start
+
+ # success -> scheduler is idle
+ if running_count == 0:
+ return {
+ "message": "idle",
+ "data": {
+ "running_tasks": 0,
+ "waited_seconds": round(elapsed, 3),
+ "timed_out": False,
+ "user_name": user_name,
+ },
+ }
+
+ # timeout check
+ if elapsed > timeout_seconds:
+ return {
+ "message": "timeout",
+ "data": {
+ "running_tasks": running_count,
+ "waited_seconds": round(elapsed, 3),
+ "timed_out": True,
+ "user_name": user_name,
+ },
+ }
+
+ time.sleep(poll_interval)
+
+ except Exception as err:
+ logger.error("Failed while waiting for scheduler: %s", traceback.format_exc())
+ raise HTTPException(status_code=500, detail="Failed while waiting for scheduler") from err
+
+
+def handle_scheduler_wait_stream(
+ user_name: str,
+ timeout_seconds: float = 120.0,
+ poll_interval: float = 0.2,
+ mem_scheduler: Any | None = None,
+ instance_id: str = "",
+) -> StreamingResponse:
+ """
+ Stream scheduler progress via Server-Sent Events (SSE).
+
+    Emits periodic heartbeat frames while tasks are running, then a final
+    status frame indicating idle or timeout.
+
+ Args:
+ user_name: User name to monitor
+ timeout_seconds: Maximum stream duration in seconds
+ poll_interval: Polling interval between updates
+ mem_scheduler: Scheduler instance
+ instance_id: Instance ID for response
+
+ Returns:
+ StreamingResponse with SSE formatted progress updates
+
+ Example:
+        curl -N "http://localhost:8000/product/scheduler/wait/stream?user_name=<user_name>&timeout_seconds=10"
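+
+    Example frame (illustrative values; fields match the payload built in the generator):
+        data: {"user_name": "alice", "running_tasks": 2, "elapsed_seconds": 0.4,
+               "status": "running", "instance_id": "host:1234:5678"}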
+ """
+
+ def event_generator():
+ start = time.time()
+ try:
+ while True:
+ running = mem_scheduler.dispatcher.get_running_tasks(
+ lambda task: task.mem_cube_id == user_name
+ )
+ running_count = len(running)
+ elapsed = time.time() - start
+
+ payload = {
+ "user_name": user_name,
+ "running_tasks": running_count,
+ "elapsed_seconds": round(elapsed, 3),
+ "status": "running" if running_count > 0 else "idle",
+ "instance_id": instance_id,
+ }
+ yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"
+
+ if running_count == 0 or elapsed > timeout_seconds:
+ payload["status"] = "idle" if running_count == 0 else "timeout"
+ payload["timed_out"] = running_count > 0
+ yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"
+ break
+
+ time.sleep(poll_interval)
+
+ except Exception as e:
+ err_payload = {
+ "status": "error",
+ "detail": "stream_failed",
+ "exception": str(e),
+ "user_name": user_name,
+ }
+ logger.error(f"Scheduler stream error for {user_name}: {traceback.format_exc()}")
+ yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n"
+
+ return StreamingResponse(event_generator(), media_type="text/event-stream")
diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py
new file mode 100644
index 000000000..7d7d52dc4
--- /dev/null
+++ b/src/memos/api/handlers/search_handler.py
@@ -0,0 +1,328 @@
+"""
+Search handler for memory search functionality (Class-based version).
+
+This module provides a class-based implementation of search handlers,
+using dependency injection for better modularity and testability.
+"""
+
+import os
+import traceback
+
+from typing import Any
+
+from memos.api.handlers.base_handler import BaseHandler, HandlerDependencies
+from memos.api.handlers.formatters_handler import (
+ format_memory_item,
+ post_process_pref_mem,
+)
+from memos.api.product_models import APISearchRequest, SearchResponse
+from memos.context.context import ContextThreadPoolExecutor
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.general_schemas import FINE_STRATEGY, FineStrategy, SearchMode
+from memos.types import MOSSearchResult, UserContext
+
+
+logger = get_logger(__name__)
+
+
+class SearchHandler(BaseHandler):
+ """
+ Handler for memory search operations.
+
+ Provides fast, fine-grained, and mixture-based search modes.
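+
+    A minimal wiring sketch (mirroring server_router.py):
+        deps = HandlerDependencies.from_init_server(handlers.init_server())
+        handler = SearchHandler(deps)
+        response = handler.handle_search_memories(search_req)  # search_req: APISearchRequest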
+ """
+
+ def __init__(self, dependencies: HandlerDependencies):
+ """
+ Initialize search handler.
+
+ Args:
+ dependencies: HandlerDependencies instance
+ """
+ super().__init__(dependencies)
+ self._validate_dependencies("naive_mem_cube", "mem_scheduler", "searcher")
+
+ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse:
+ """
+ Main handler for search memories endpoint.
+
+ Orchestrates the search process based on the requested search mode,
+ supporting both text and preference memory searches.
+
+ Args:
+ search_req: Search request containing query and parameters
+
+ Returns:
+ SearchResponse with formatted results
+ """
+ # Create UserContext object
+ user_context = UserContext(
+ user_id=search_req.user_id,
+ mem_cube_id=search_req.mem_cube_id,
+ session_id=search_req.session_id or "default_session",
+ )
+ self.logger.info(f"Search Req is: {search_req}")
+
+ memories_result: MOSSearchResult = {
+ "text_mem": [],
+ "act_mem": [],
+ "para_mem": [],
+ "pref_mem": [],
+ "pref_note": "",
+ }
+
+ # Determine search mode
+ search_mode = self._get_search_mode(search_req.mode)
+
+ # Execute search in parallel for text and preference memories
+ with ContextThreadPoolExecutor(max_workers=2) as executor:
+ text_future = executor.submit(self._search_text, search_req, user_context, search_mode)
+ pref_future = executor.submit(self._search_pref, search_req, user_context)
+
+ text_formatted_memories = text_future.result()
+ pref_formatted_memories = pref_future.result()
+
+ # Build result
+ memories_result["text_mem"].append(
+ {
+ "cube_id": search_req.mem_cube_id,
+ "memories": text_formatted_memories,
+ }
+ )
+
+ memories_result = post_process_pref_mem(
+ memories_result,
+ pref_formatted_memories,
+ search_req.mem_cube_id,
+ search_req.include_preference,
+ )
+
+ self.logger.info(f"Search memories result: {memories_result}")
+
+ return SearchResponse(
+ message="Search completed successfully",
+ data=memories_result,
+ )
+
+    def _get_search_mode(self, mode: str) -> str:
+        """Resolve the effective search mode (currently a pass-through of the requested mode)."""
+        return mode
+
+ def _search_text(
+ self,
+ search_req: APISearchRequest,
+ user_context: UserContext,
+ search_mode: str,
+ ) -> list[dict[str, Any]]:
+ """
+ Search text memories based on mode.
+
+ Args:
+ search_req: Search request
+ user_context: User context
+ search_mode: Search mode (FAST, FINE, or MIXTURE)
+
+ Returns:
+ List of formatted memory items
+ """
+ try:
+ if search_mode == SearchMode.FAST:
+ text_memories = self._fast_search(search_req, user_context)
+ elif search_mode == SearchMode.FINE:
+ text_memories = self._fine_search(search_req, user_context)
+ elif search_mode == SearchMode.MIXTURE:
+ text_memories = self._mix_search(search_req, user_context)
+ else:
+ self.logger.error(f"Unsupported search mode: {search_mode}")
+ return []
+
+ return text_memories
+
+ except Exception as e:
+ self.logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc())
+ return []
+
+ def _search_pref(
+ self,
+ search_req: APISearchRequest,
+ user_context: UserContext,
+ ) -> list[dict[str, Any]]:
+ """
+ Search preference memories.
+
+ Args:
+ search_req: Search request
+ user_context: User context
+
+ Returns:
+ List of formatted preference memory items
+ """
+ if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
+ return []
+
+ try:
+ results = self.naive_mem_cube.pref_mem.search(
+ query=search_req.query,
+ top_k=search_req.pref_top_k,
+ info={
+ "user_id": search_req.user_id,
+ "session_id": search_req.session_id,
+ "chat_history": search_req.chat_history,
+ },
+ )
+ return [format_memory_item(data) for data in results]
+ except Exception as e:
+ self.logger.error("Error in _search_pref: %s; traceback: %s", e, traceback.format_exc())
+ return []
+
+ def _fast_search(
+ self,
+ search_req: APISearchRequest,
+ user_context: UserContext,
+ ) -> list:
+ """
+ Fast search using vector database.
+
+ Args:
+ search_req: Search request
+ user_context: User context
+
+ Returns:
+ List of search results
+ """
+ target_session_id = search_req.session_id or "default_session"
+ search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
+
+ search_results = self.naive_mem_cube.text_mem.search(
+ query=search_req.query,
+ user_name=user_context.mem_cube_id,
+ top_k=search_req.top_k,
+ mode=SearchMode.FAST,
+ manual_close_internet=not search_req.internet_search,
+ moscube=search_req.moscube,
+ search_filter=search_filter,
+ info={
+ "user_id": search_req.user_id,
+ "session_id": target_session_id,
+ "chat_history": search_req.chat_history,
+ },
+ )
+
+ formatted_memories = [format_memory_item(data) for data in search_results]
+
+ return formatted_memories
+
+    def _deep_search(
+        self, search_req: APISearchRequest, user_context: UserContext, max_thinking_depth: int
+    ) -> list:
+        """Placeholder for a deep, multi-step search strategy; not implemented yet."""
+        logger.error("Deep search is not implemented yet; returning no results.")
+        return []
+
+ def _fine_search(
+ self,
+ search_req: APISearchRequest,
+ user_context: UserContext,
+    ) -> list[dict[str, Any]]:
+ """
+ Fine-grained search with query enhancement.
+
+ Args:
+ search_req: Search request
+ user_context: User context
+
+ Returns:
+ List of enhanced search results
+ """
+ if FINE_STRATEGY == FineStrategy.DEEP_SEARCH:
+ return self._deep_search(
+ search_req=search_req, user_context=user_context, max_thinking_depth=3
+ )
+
+ target_session_id = search_req.session_id or "default_session"
+ search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
+
+ info = {
+ "user_id": search_req.user_id,
+ "session_id": target_session_id,
+ "chat_history": search_req.chat_history,
+ }
+
+ # Fine retrieve
+ raw_retrieved_memories = self.searcher.retrieve(
+ query=search_req.query,
+ user_name=user_context.mem_cube_id,
+ top_k=search_req.top_k,
+ mode=SearchMode.FINE,
+ manual_close_internet=not search_req.internet_search,
+ moscube=search_req.moscube,
+ search_filter=search_filter,
+ info=info,
+ )
+
+ # Post retrieve
+ raw_memories = self.searcher.post_retrieve(
+ retrieved_results=raw_retrieved_memories,
+ top_k=search_req.top_k,
+ user_name=user_context.mem_cube_id,
+ info=info,
+ )
+
+ # Enhance with query
+ enhanced_memories, _ = self.mem_scheduler.retriever.enhance_memories_with_query(
+ query_history=[search_req.query],
+ memories=raw_memories,
+ )
+
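+        # If enhancement filtered out some memories, top the result back up to the original
+        # count: either run an additional fast search using the missing-info hint (when the
+        # recall check triggers) or fall back to the first `retrieval_size` raw memories.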
+ if len(enhanced_memories) < len(raw_memories):
+ logger.info(
+ f"Enhanced memories ({len(enhanced_memories)}) are less than raw memories ({len(raw_memories)}). Recalling for more."
+ )
+ missing_info_hint, trigger = self.mem_scheduler.retriever.recall_for_missing_memories(
+ query=search_req.query,
+ memories=raw_memories,
+ )
+ retrieval_size = len(raw_memories) - len(enhanced_memories)
+ logger.info(f"Retrieval size: {retrieval_size}")
+ if trigger:
+ logger.info(f"Triggering additional search with hint: {missing_info_hint}")
+ additional_memories = self.searcher.search(
+ query=missing_info_hint,
+ user_name=user_context.mem_cube_id,
+ top_k=retrieval_size,
+ mode=SearchMode.FAST,
+ memory_type="All",
+ search_filter=search_filter,
+ info=info,
+ )
+ else:
+ logger.info("Not triggering additional search, using fast memories.")
+ additional_memories = raw_memories[:retrieval_size]
+
+ enhanced_memories += additional_memories
+ logger.info(
+ f"Added {len(additional_memories)} more memories. Total enhanced memories: {len(enhanced_memories)}"
+ )
+ formatted_memories = [format_memory_item(data) for data in enhanced_memories]
+
+ logger.info(f"Found {len(formatted_memories)} memories for user {search_req.user_id}")
+
+ return formatted_memories
+
+ def _mix_search(
+ self,
+ search_req: APISearchRequest,
+ user_context: UserContext,
+ ) -> list:
+ """
+ Mix search combining fast and fine-grained approaches.
+
+ Args:
+ search_req: Search request
+ user_context: User context
+
+ Returns:
+ List of formatted search results
+ """
+ return self.mem_scheduler.mix_search_memories(
+ search_req=search_req,
+ user_context=user_context,
+ )
diff --git a/src/memos/api/handlers/suggestion_handler.py b/src/memos/api/handlers/suggestion_handler.py
new file mode 100644
index 000000000..dce894003
--- /dev/null
+++ b/src/memos/api/handlers/suggestion_handler.py
@@ -0,0 +1,117 @@
+"""
+Suggestion handler for generating suggestion queries.
+
+This module handles suggestion query generation, based either on a user's recent
+memories or on follow-up suggestions derived from recent chat history.
+"""
+
+import json
+
+from typing import Any
+
+from memos.api.product_models import SuggestionResponse
+from memos.log import get_logger
+from memos.mem_os.utils.format_utils import clean_json_response
+from memos.templates.mos_prompts import (
+ FURTHER_SUGGESTION_PROMPT,
+ SUGGESTION_QUERY_PROMPT_EN,
+ SUGGESTION_QUERY_PROMPT_ZH,
+)
+from memos.types import MessageList
+
+
+logger = get_logger(__name__)
+
+
+def _get_further_suggestion(
+ llm: Any,
+ message: MessageList,
+) -> list[str]:
+ """
+    Get further suggestions based on the recent dialogue.
+
+ Args:
+ llm: LLM instance for generating suggestions
+ message: Recent chat messages
+
+ Returns:
+ List of suggestion queries
+ """
+ try:
+ dialogue_info = "\n".join([f"{msg['role']}: {msg['content']}" for msg in message[-2:]])
+ further_suggestion_prompt = FURTHER_SUGGESTION_PROMPT.format(dialogue=dialogue_info)
+ message_list = [{"role": "system", "content": further_suggestion_prompt}]
+ response = llm.generate(message_list)
+ clean_response = clean_json_response(response)
+ response_json = json.loads(clean_response)
+ return response_json["query"]
+ except Exception as e:
+ logger.error(f"Error getting further suggestion: {e}", exc_info=True)
+ return []
+
+
+def handle_get_suggestion_queries(
+ user_id: str,
+ language: str,
+ message: MessageList | None,
+ llm: Any,
+ naive_mem_cube: Any,
+) -> SuggestionResponse:
+ """
+ Main handler for suggestion queries endpoint.
+
+    Generates suggestion queries based on the user's recent memories or chat history.
+
+ Args:
+ user_id: User ID
+ language: Language preference ("zh" or "en")
+ message: Optional chat message list for further suggestions
+ llm: LLM instance
+ naive_mem_cube: Memory cube instance
+
+ Returns:
+ SuggestionResponse with generated queries
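+
+    Example response data (illustrative):
+        {"query": ["What did I plan for this weekend?", "Any updates on my project?"]}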
+ """
+ try:
+ # If message is provided, get further suggestions based on dialogue
+ if message:
+ suggestions = _get_further_suggestion(llm, message)
+ return SuggestionResponse(
+ message="Suggestions retrieved successfully",
+ data={"query": suggestions},
+ )
+
+ # Otherwise, generate suggestions based on recent memories
+ if language == "zh":
+ suggestion_prompt = SUGGESTION_QUERY_PROMPT_ZH
+ else: # English
+ suggestion_prompt = SUGGESTION_QUERY_PROMPT_EN
+
+ # Search for recent memories
+ text_mem_results = naive_mem_cube.text_mem.search(
+            query="my recent memories",
+ user_name=user_id,
+ top_k=3,
+ mode="fast",
+ info={"user_id": user_id},
+ )
+
+ # Extract memory content
+ memories = ""
+ if text_mem_results:
+ memories = "\n".join([m.memory[:200] for m in text_mem_results])
+
+ # Generate suggestions using LLM
+ message_list = [{"role": "system", "content": suggestion_prompt.format(memories=memories)}]
+ response = llm.generate(message_list)
+ clean_response = clean_json_response(response)
+ response_json = json.loads(clean_response)
+
+ return SuggestionResponse(
+ message="Suggestions retrieved successfully",
+ data={"query": response_json["query"]},
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to get suggestions: {e}", exc_info=True)
+ raise
diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py
index 0412754c3..30df150ea 100644
--- a/src/memos/api/product_models.py
+++ b/src/memos/api/product_models.py
@@ -1,3 +1,4 @@
+import os
import uuid
from typing import Generic, Literal, TypeVar
@@ -171,7 +172,9 @@ class APISearchRequest(BaseRequest):
query: str = Field(..., description="Search query")
user_id: str = Field(None, description="User ID")
mem_cube_id: str | None = Field(None, description="Cube ID to search in")
- mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
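+    # NOTE: the default mode can be overridden via the SEARCH_MODE environment variable.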
+ mode: SearchMode = Field(
+ os.getenv("SEARCH_MODE", SearchMode.FAST), description="search mode: fast, fine, or mixture"
+ )
internet_search: bool = Field(False, description="Whether to use internet search")
moscube: bool = Field(False, description="Whether to use MemOSCube")
top_k: int = Field(10, description="Number of results to return")
@@ -198,6 +201,9 @@ class APIADDRequest(BaseRequest):
operation: list[PermissionDict] | None = Field(
None, description="operation ids for multi cubes"
)
+    async_mode: Literal["async", "sync"] = Field(
+        "async", description="Whether to add memories asynchronously ('async', default) or synchronously ('sync')"
+    )
class APIChatCompleteRequest(BaseRequest):
@@ -221,6 +227,7 @@ class SuggestionRequest(BaseRequest):
"""Request model for getting suggestion queries."""
user_id: str = Field(..., description="User ID")
+ mem_cube_id: str = Field(..., description="Cube ID")
language: Literal["zh", "en"] = Field("zh", description="Language for suggestions")
message: list[MessageDict] | None = Field(None, description="List of messages to store.")
diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py
index 23aec0cb0..b3b517305 100644
--- a/src/memos/api/routers/server_router.py
+++ b/src/memos/api/routers/server_router.py
@@ -1,838 +1,221 @@
-import json
+"""
+Server API Router for MemOS (Class-based handlers version).
+
+This router demonstrates the improved architecture using class-based handlers
+with dependency injection, providing better modularity and maintainability.
+
+Comparison with function-based approach:
+- Cleaner code: No need to pass dependencies in every endpoint
+- Better testability: Easy to mock handler dependencies
+- Improved extensibility: Add new handlers or modify existing ones easily
+- Clear separation of concerns: Router focuses on routing, handlers handle business logic
+"""
+
import os
import random as _random
import socket
-import time
-import traceback
-
-from collections.abc import Iterable
-from datetime import datetime
-from typing import TYPE_CHECKING, Any
-from fastapi import APIRouter, HTTPException
-from fastapi.responses import StreamingResponse
+from fastapi import APIRouter
-from memos.api.config import APIConfig
+from memos.api import handlers
+from memos.api.handlers.add_handler import AddHandler
+from memos.api.handlers.base_handler import HandlerDependencies
+from memos.api.handlers.chat_handler import ChatHandler
+from memos.api.handlers.search_handler import SearchHandler
from memos.api.product_models import (
APIADDRequest,
APIChatCompleteRequest,
APISearchRequest,
+ ChatRequest,
+ GetMemoryRequest,
MemoryResponse,
SearchResponse,
+ SuggestionRequest,
+ SuggestionResponse,
)
-from memos.configs.embedder import EmbedderConfigFactory
-from memos.configs.graph_db import GraphDBConfigFactory
-from memos.configs.internet_retriever import InternetRetrieverConfigFactory
-from memos.configs.llm import LLMConfigFactory
-from memos.configs.mem_reader import MemReaderConfigFactory
-from memos.configs.mem_scheduler import SchedulerConfigFactory
-from memos.configs.reranker import RerankerConfigFactory
-from memos.configs.vec_db import VectorDBConfigFactory
-from memos.context.context import ContextThreadPoolExecutor
-from memos.embedders.factory import EmbedderFactory
-from memos.graph_dbs.factory import GraphStoreFactory
-from memos.llms.factory import LLMFactory
from memos.log import get_logger
-from memos.mem_cube.navie import NaiveMemCube
-from memos.mem_os.product_server import MOSServer
-from memos.mem_reader.factory import MemReaderFactory
-from memos.mem_scheduler.orm_modules.base_model import BaseDBManager
-from memos.mem_scheduler.scheduler_factory import SchedulerFactory
-from memos.mem_scheduler.schemas.general_schemas import (
- ADD_LABEL,
- MEM_READ_LABEL,
- PREF_ADD_LABEL,
- SearchMode,
-)
-from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
-from memos.memories.textual.prefer_text_memory.config import (
- AdderConfigFactory,
- ExtractorConfigFactory,
- RetrieverConfigFactory,
-)
-from memos.memories.textual.prefer_text_memory.factory import (
- AdderFactory,
- ExtractorFactory,
- RetrieverFactory,
-)
-from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
-from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import (
- InternetRetrieverFactory,
-)
-from memos.reranker.factory import RerankerFactory
-from memos.templates.instruction_completion import instruct_completion
-
-
-if TYPE_CHECKING:
- from memos.mem_scheduler.optimized_scheduler import OptimizedScheduler
-from memos.types import MOSSearchResult, UserContext
-from memos.vec_dbs.factory import VecDBFactory
+from memos.mem_scheduler.base_scheduler import BaseScheduler
logger = get_logger(__name__)
router = APIRouter(prefix="/product", tags=["Server API"])
-INSTANCE_ID = f"{socket.gethostname()}:{os.getpid()}:{_random.randint(1000, 9999)}"
-
-
-def _to_iter(running: Any) -> Iterable:
- """Normalize running tasks to an iterable of task objects."""
- if running is None:
- return []
- if isinstance(running, dict):
- return running.values()
- return running # assume it's already an iterable (e.g., list)
-
-
-def _build_graph_db_config(user_id: str = "default") -> dict[str, Any]:
- """Build graph database configuration."""
- graph_db_backend_map = {
- "neo4j-community": APIConfig.get_neo4j_community_config(user_id=user_id),
- "neo4j": APIConfig.get_neo4j_config(user_id=user_id),
- "nebular": APIConfig.get_nebular_config(user_id=user_id),
- "polardb": APIConfig.get_polardb_config(user_id=user_id),
- }
-
- graph_db_backend = os.getenv("NEO4J_BACKEND", "nebular").lower()
- return GraphDBConfigFactory.model_validate(
- {
- "backend": graph_db_backend,
- "config": graph_db_backend_map[graph_db_backend],
- }
- )
-
-def _build_vec_db_config() -> dict[str, Any]:
- """Build vector database configuration."""
- return VectorDBConfigFactory.model_validate(
- {
- "backend": "milvus",
- "config": APIConfig.get_milvus_config(),
- }
- )
-
-
-def _build_llm_config() -> dict[str, Any]:
- """Build LLM configuration."""
- return LLMConfigFactory.model_validate(
- {
- "backend": "openai",
- "config": APIConfig.get_openai_config(),
- }
- )
-
-
-def _build_embedder_config() -> dict[str, Any]:
- """Build embedder configuration."""
- return EmbedderConfigFactory.model_validate(APIConfig.get_embedder_config())
+# Instance ID for identifying this server instance in logs and responses
+INSTANCE_ID = f"{socket.gethostname()}:{os.getpid()}:{_random.randint(1000, 9999)}"
+# Initialize all server components
+components = handlers.init_server()
-def _build_mem_reader_config() -> dict[str, Any]:
- """Build memory reader configuration."""
- return MemReaderConfigFactory.model_validate(
- APIConfig.get_product_default_config()["mem_reader"]
- )
+# Create dependency container
+dependencies = HandlerDependencies.from_init_server(components)
+# Initialize all handlers with dependency injection
+search_handler = SearchHandler(dependencies)
+add_handler = AddHandler(dependencies)
+chat_handler = ChatHandler(
+ dependencies, search_handler, add_handler, online_bot=components.get("online_bot")
+)
-def _build_reranker_config() -> dict[str, Any]:
- """Build reranker configuration."""
- return RerankerConfigFactory.model_validate(APIConfig.get_reranker_config())
-
+# Extract commonly used components for function-based handlers
+# (These can be accessed from the components dict without unpacking all of them)
+mem_scheduler: BaseScheduler = components["mem_scheduler"]
+llm = components["llm"]
+naive_mem_cube = components["naive_mem_cube"]
-def _build_internet_retriever_config() -> dict[str, Any]:
- """Build internet retriever configuration."""
- return InternetRetrieverConfigFactory.model_validate(APIConfig.get_internet_config())
+# =============================================================================
+# Search API Endpoints
+# =============================================================================
-def _build_pref_extractor_config() -> dict[str, Any]:
- """Build extractor configuration."""
- return ExtractorConfigFactory.model_validate({"backend": "naive", "config": {}})
+@router.post("/search", summary="Search memories", response_model=SearchResponse)
+def search_memories(search_req: APISearchRequest):
+ """
+ Search memories for a specific user.
-def _build_pref_adder_config() -> dict[str, Any]:
- """Build adder configuration."""
- return AdderConfigFactory.model_validate({"backend": "naive", "config": {}})
+ This endpoint uses the class-based SearchHandler for better code organization.
+ """
+ return search_handler.handle_search_memories(search_req)
-def _build_pref_retriever_config() -> dict[str, Any]:
- """Build retriever configuration."""
- return RetrieverConfigFactory.model_validate({"backend": "naive", "config": {}})
+# =============================================================================
+# Add API Endpoints
+# =============================================================================
-def _get_default_memory_size(cube_config) -> dict[str, int]:
- """Get default memory size configuration."""
- return getattr(cube_config.text_mem.config, "memory_size", None) or {
- "WorkingMemory": 20,
- "LongTermMemory": 1500,
- "UserMemory": 480,
- }
+@router.post("/add", summary="Add memories", response_model=MemoryResponse)
+def add_memories(add_req: APIADDRequest):
+ """
+ Add memories for a specific user.
+ This endpoint uses the class-based AddHandler for better code organization.
+ """
+ return add_handler.handle_add_memories(add_req)
-def init_server():
- """Initialize server components and configurations."""
- # Get default cube configuration
- default_cube_config = APIConfig.get_default_cube_config()
- # Build component configurations
- graph_db_config = _build_graph_db_config()
- llm_config = _build_llm_config()
- embedder_config = _build_embedder_config()
- mem_reader_config = _build_mem_reader_config()
- reranker_config = _build_reranker_config()
- internet_retriever_config = _build_internet_retriever_config()
- vector_db_config = _build_vec_db_config()
- pref_extractor_config = _build_pref_extractor_config()
- pref_adder_config = _build_pref_adder_config()
- pref_retriever_config = _build_pref_retriever_config()
+# =============================================================================
+# Scheduler API Endpoints
+# =============================================================================
- # Create component instances
- graph_db = GraphStoreFactory.from_config(graph_db_config)
- vector_db = (
- VecDBFactory.from_config(vector_db_config)
- if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "true"
- else None
- )
- llm = LLMFactory.from_config(llm_config)
- embedder = EmbedderFactory.from_config(embedder_config)
- mem_reader = MemReaderFactory.from_config(mem_reader_config)
- reranker = RerankerFactory.from_config(reranker_config)
- internet_retriever = InternetRetrieverFactory.from_config(
- internet_retriever_config, embedder=embedder
- )
- pref_extractor = (
- ExtractorFactory.from_config(
- config_factory=pref_extractor_config,
- llm_provider=llm,
- embedder=embedder,
- vector_db=vector_db,
- )
- if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "true"
- else None
- )
- pref_adder = (
- AdderFactory.from_config(
- config_factory=pref_adder_config,
- llm_provider=llm,
- embedder=embedder,
- vector_db=vector_db,
- )
- if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "true"
- else None
- )
- pref_retriever = (
- RetrieverFactory.from_config(
- config_factory=pref_retriever_config,
- llm_provider=llm,
- embedder=embedder,
- reranker=reranker,
- vector_db=vector_db,
- )
- if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "true"
- else None
- )
- # Initialize memory manager
- memory_manager = MemoryManager(
- graph_db,
- embedder,
- llm,
- memory_size=_get_default_memory_size(default_cube_config),
- is_reorganize=getattr(default_cube_config.text_mem.config, "reorganize", False),
- )
- mos_server = MOSServer(
- mem_reader=mem_reader,
- llm=llm,
- online_bot=False,
+@router.get("/scheduler/status", summary="Get scheduler running status")
+def scheduler_status(user_name: str | None = None):
+ """Get scheduler running status."""
+ return handlers.scheduler_handler.handle_scheduler_status(
+ user_name=user_name,
+ mem_scheduler=mem_scheduler,
+ instance_id=INSTANCE_ID,
)
- naive_mem_cube = NaiveMemCube(
- llm=llm,
- embedder=embedder,
- mem_reader=mem_reader,
- graph_db=graph_db,
- reranker=reranker,
- internet_retriever=internet_retriever,
- memory_manager=memory_manager,
- default_cube_config=default_cube_config,
- vector_db=vector_db,
- pref_extractor=pref_extractor,
- pref_adder=pref_adder,
- pref_retriever=pref_retriever,
- )
- # Initialize Scheduler
- scheduler_config_dict = APIConfig.get_scheduler_config()
- scheduler_config = SchedulerConfigFactory(
- backend="optimized_scheduler", config=scheduler_config_dict
- )
- mem_scheduler: OptimizedScheduler = SchedulerFactory.from_config(scheduler_config)
- mem_scheduler.initialize_modules(
- chat_llm=llm,
- process_llm=mem_reader.llm,
- db_engine=BaseDBManager.create_default_sqlite_engine(),
- mem_reader=mem_reader,
- )
- mem_scheduler.current_mem_cube = naive_mem_cube
- mem_scheduler.start()
-
- # Initialize SchedulerAPIModule
- api_module = mem_scheduler.api_module
-
- return (
- graph_db,
- mem_reader,
- llm,
- embedder,
- reranker,
- internet_retriever,
- memory_manager,
- default_cube_config,
- mos_server,
- mem_scheduler,
- naive_mem_cube,
- api_module,
- vector_db,
- pref_extractor,
- pref_adder,
- pref_retriever,
+@router.post("/scheduler/wait", summary="Wait until scheduler is idle for a specific user")
+def scheduler_wait(
+ user_name: str,
+ timeout_seconds: float = 120.0,
+ poll_interval: float = 0.2,
+):
+ """Wait until scheduler is idle for a specific user."""
+ return handlers.scheduler_handler.handle_scheduler_wait(
+ user_name=user_name,
+ timeout_seconds=timeout_seconds,
+ poll_interval=poll_interval,
+ mem_scheduler=mem_scheduler,
)
-# Initialize global components
-(
- graph_db,
- mem_reader,
- llm,
- embedder,
- reranker,
- internet_retriever,
- memory_manager,
- default_cube_config,
- mos_server,
- mem_scheduler,
- naive_mem_cube,
- api_module,
- vector_db,
- pref_extractor,
- pref_adder,
- pref_retriever,
-) = init_server()
-
-
-def _format_memory_item(memory_data: Any) -> dict[str, Any]:
- """Format a single memory item for API response."""
- memory = memory_data.model_dump()
- memory_id = memory["id"]
- ref_id = f"[{memory_id.split('-')[0]}]"
-
- memory["ref_id"] = ref_id
- memory["metadata"]["embedding"] = []
- memory["metadata"]["sources"] = []
- memory["metadata"]["usage"] = []
- memory["metadata"]["ref_id"] = ref_id
- memory["metadata"]["id"] = memory_id
- memory["metadata"]["memory"] = memory["memory"]
-
- return memory
-
-
-def _post_process_pref_mem(
- memories_result: list[dict[str, Any]],
- pref_formatted_mem: list[dict[str, Any]],
- mem_cube_id: str,
- include_preference: bool,
+@router.get("/scheduler/wait/stream", summary="Stream scheduler progress for a user")
+def scheduler_wait_stream(
+ user_name: str,
+ timeout_seconds: float = 120.0,
+ poll_interval: float = 0.2,
):
- if include_preference:
- memories_result["pref_mem"].append(
- {
- "cube_id": mem_cube_id,
- "memories": pref_formatted_mem,
- }
- )
- pref_instruction, pref_note = instruct_completion(pref_formatted_mem)
- memories_result["pref_string"] = pref_instruction
- memories_result["pref_note"] = pref_note
-
- return memories_result
-
-
-@router.post("/search", summary="Search memories", response_model=SearchResponse)
-def search_memories(search_req: APISearchRequest):
- """Search memories for a specific user."""
- # Create UserContext object - how to assign values
- user_context = UserContext(
- user_id=search_req.user_id,
- mem_cube_id=search_req.mem_cube_id,
- session_id=search_req.session_id or "default_session",
- )
- logger.info(f"Search Req is: {search_req}")
- memories_result: MOSSearchResult = {
- "text_mem": [],
- "act_mem": [],
- "para_mem": [],
- "pref_mem": [],
- "pref_note": "",
- }
-
- search_mode = search_req.mode
-
- def _search_text():
- if search_mode == SearchMode.FAST:
- formatted_memories = fast_search_memories(
- search_req=search_req, user_context=user_context
- )
- elif search_mode == SearchMode.FINE:
- formatted_memories = fine_search_memories(
- search_req=search_req, user_context=user_context
- )
- elif search_mode == SearchMode.MIXTURE:
- formatted_memories = mix_search_memories(
- search_req=search_req, user_context=user_context
- )
- else:
- logger.error(f"Unsupported search mode: {search_mode}")
- raise HTTPException(status_code=400, detail=f"Unsupported search mode: {search_mode}")
- return formatted_memories
-
- def _search_pref():
- if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
- return []
- results = naive_mem_cube.pref_mem.search(
- query=search_req.query,
- top_k=search_req.pref_top_k,
- info={
- "user_id": search_req.user_id,
- "session_id": search_req.session_id,
- "chat_history": search_req.chat_history,
- },
- )
- return [_format_memory_item(data) for data in results]
-
- with ContextThreadPoolExecutor(max_workers=2) as executor:
- text_future = executor.submit(_search_text)
- pref_future = executor.submit(_search_pref)
- text_formatted_memories = text_future.result()
- pref_formatted_memories = pref_future.result()
-
- memories_result["text_mem"].append(
- {
- "cube_id": search_req.mem_cube_id,
- "memories": text_formatted_memories,
- }
+ """Stream scheduler progress via Server-Sent Events (SSE)."""
+ return handlers.scheduler_handler.handle_scheduler_wait_stream(
+ user_name=user_name,
+ timeout_seconds=timeout_seconds,
+ poll_interval=poll_interval,
+ mem_scheduler=mem_scheduler,
+ instance_id=INSTANCE_ID,
)
- memories_result = _post_process_pref_mem(
- memories_result,
- pref_formatted_memories,
- search_req.mem_cube_id,
- search_req.include_preference,
- )
-
- logger.info(f"Search memories result: {memories_result}")
- return SearchResponse(
- message="Search completed successfully",
- data=memories_result,
- )
+# =============================================================================
+# Chat API Endpoints
+# =============================================================================
-def mix_search_memories(
- search_req: APISearchRequest,
- user_context: UserContext,
-):
- """
- Mix search memories: fast search + async fine search
+@router.post("/chat/complete", summary="Chat with MemOS (Complete Response)")
+def chat_complete(chat_req: APIChatCompleteRequest):
"""
+ Chat with MemOS for a specific user. Returns complete response (non-streaming).
- formatted_memories = mem_scheduler.mix_search_memories(
- search_req=search_req,
- user_context=user_context,
- )
- return formatted_memories
-
-
-def fine_search_memories(
- search_req: APISearchRequest,
- user_context: UserContext,
-):
- target_session_id = search_req.session_id
- if not target_session_id:
- target_session_id = "default_session"
- search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
-
- # Create MemCube and perform search
- search_results = naive_mem_cube.text_mem.search(
- query=search_req.query,
- user_name=user_context.mem_cube_id,
- top_k=search_req.top_k,
- mode=SearchMode.FINE,
- manual_close_internet=not search_req.internet_search,
- moscube=search_req.moscube,
- search_filter=search_filter,
- info={
- "user_id": search_req.user_id,
- "session_id": target_session_id,
- "chat_history": search_req.chat_history,
- },
- )
- formatted_memories = [_format_memory_item(data) for data in search_results]
+ This endpoint uses the class-based ChatHandler.
+ """
+ return chat_handler.handle_chat_complete(chat_req)
- return formatted_memories
+@router.post("/chat", summary="Chat with MemOS")
+def chat(chat_req: ChatRequest):
+ """
+ Chat with MemOS for a specific user. Returns SSE stream.
-def fast_search_memories(
- search_req: APISearchRequest,
- user_context: UserContext,
-):
- target_session_id = search_req.session_id
- if not target_session_id:
- target_session_id = "default_session"
- search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
-
- # Create MemCube and perform search
- search_results = naive_mem_cube.text_mem.search(
- query=search_req.query,
- user_name=user_context.mem_cube_id,
- top_k=search_req.top_k,
- mode=SearchMode.FAST,
- manual_close_internet=not search_req.internet_search,
- moscube=search_req.moscube,
- search_filter=search_filter,
- info={
- "user_id": search_req.user_id,
- "session_id": target_session_id,
- "chat_history": search_req.chat_history,
- },
- )
- formatted_memories = [_format_memory_item(data) for data in search_results]
+ This endpoint uses the class-based ChatHandler which internally
+ composes SearchHandler and AddHandler for a clean architecture.
+ """
+ return chat_handler.handle_chat_stream(chat_req)
- return formatted_memories
+# =============================================================================
+# Suggestion API Endpoints
+# =============================================================================
-@router.post("/add", summary="Add memories", response_model=MemoryResponse)
-def add_memories(add_req: APIADDRequest):
- """Add memories for a specific user."""
- # Create UserContext object - how to assign values
- user_context = UserContext(
- user_id=add_req.user_id,
- mem_cube_id=add_req.mem_cube_id,
- session_id=add_req.session_id or "default_session",
- )
- logger.info(f"Add Req is: {add_req}")
-
- target_session_id = add_req.session_id
- if not target_session_id:
- target_session_id = "default_session"
-
- # If text memory backend works in async mode, submit tasks to scheduler
- try:
- sync_mode = getattr(naive_mem_cube.text_mem, "mode", "sync")
- except Exception:
- sync_mode = "sync"
- logger.info(f"Add sync_mode mode is: {sync_mode}")
-
- def _process_text_mem() -> list[dict[str, str]]:
- memories_local = mem_reader.get_memory(
- [add_req.messages],
- type="chat",
- info={
- "user_id": add_req.user_id,
- "session_id": target_session_id,
- },
- mode="fast" if sync_mode == "async" else "fine",
- )
- flattened_local = [mm for m in memories_local for mm in m]
- logger.info(f"Memory extraction completed for user {add_req.user_id}")
- mem_ids_local: list[str] = naive_mem_cube.text_mem.add(
- flattened_local,
- user_name=user_context.mem_cube_id,
- )
- logger.info(
- f"Added {len(mem_ids_local)} memories for user {add_req.user_id} "
- f"in session {add_req.session_id}: {mem_ids_local}"
- )
- if sync_mode == "async":
- try:
- message_item_read = ScheduleMessageItem(
- user_id=add_req.user_id,
- session_id=target_session_id,
- mem_cube_id=add_req.mem_cube_id,
- mem_cube=naive_mem_cube,
- label=MEM_READ_LABEL,
- content=json.dumps(mem_ids_local),
- timestamp=datetime.utcnow(),
- user_name=add_req.mem_cube_id,
- )
- mem_scheduler.submit_messages(messages=[message_item_read])
- logger.info(f"2105Submit messages!!!!!: {json.dumps(mem_ids_local)}")
- except Exception as e:
- logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True)
- else:
- message_item_add = ScheduleMessageItem(
- user_id=add_req.user_id,
- session_id=target_session_id,
- mem_cube_id=add_req.mem_cube_id,
- mem_cube=naive_mem_cube,
- label=ADD_LABEL,
- content=json.dumps(mem_ids_local),
- timestamp=datetime.utcnow(),
- user_name=add_req.mem_cube_id,
- )
- mem_scheduler.submit_messages(messages=[message_item_add])
- return [
- {
- "memory": memory.memory,
- "memory_id": memory_id,
- "memory_type": memory.metadata.memory_type,
- }
- for memory_id, memory in zip(mem_ids_local, flattened_local, strict=False)
- ]
-
- def _process_pref_mem() -> list[dict[str, str]]:
- if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() != "true":
- return []
- # Follow async behavior similar to core.py: enqueue when async
- if sync_mode == "async":
- try:
- messages_list = [add_req.messages]
- message_item_pref = ScheduleMessageItem(
- user_id=add_req.user_id,
- session_id=target_session_id,
- mem_cube_id=add_req.mem_cube_id,
- mem_cube=naive_mem_cube,
- label=PREF_ADD_LABEL,
- content=json.dumps(messages_list),
- timestamp=datetime.utcnow(),
- )
- mem_scheduler.submit_messages(messages=[message_item_pref])
- logger.info("Submitted preference add to scheduler (async mode)")
- except Exception as e:
- logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True)
- return []
- else:
- pref_memories_local = naive_mem_cube.pref_mem.get_memory(
- [add_req.messages],
- type="chat",
- info={
- "user_id": add_req.user_id,
- "session_id": target_session_id,
- },
- )
- pref_ids_local: list[str] = naive_mem_cube.pref_mem.add(pref_memories_local)
- logger.info(
- f"Added {len(pref_ids_local)} preferences for user {add_req.user_id} "
- f"in session {add_req.session_id}: {pref_ids_local}"
- )
- return [
- {
- "memory": memory.memory,
- "memory_id": memory_id,
- "memory_type": memory.metadata.preference_type,
- }
- for memory_id, memory in zip(pref_ids_local, pref_memories_local, strict=False)
- ]
-
- with ContextThreadPoolExecutor(max_workers=2) as executor:
- text_future = executor.submit(_process_text_mem)
- pref_future = executor.submit(_process_pref_mem)
- text_response_data = text_future.result()
- pref_response_data = pref_future.result()
-
- logger.info(f"add_memories Text response data: {text_response_data}")
- logger.info(f"add_memories Pref response data: {pref_response_data}")
-
- return MemoryResponse(
- message="Memory added successfully",
- data=text_response_data + pref_response_data,
+@router.post(
+ "/suggestions",
+ summary="Get suggestion queries",
+ response_model=SuggestionResponse,
+)
+def get_suggestion_queries(suggestion_req: SuggestionRequest):
+ """Get suggestion queries for a specific user with language preference."""
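+    # Text memories are scoped per cube, so the cube id is passed as the handler's user_id here.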
+ return handlers.suggestion_handler.handle_get_suggestion_queries(
+ user_id=suggestion_req.mem_cube_id,
+ language=suggestion_req.language,
+ message=suggestion_req.message,
+ llm=llm,
+ naive_mem_cube=naive_mem_cube,
)
-@router.get("/scheduler/status", summary="Get scheduler running status")
-def scheduler_status(user_name: str | None = None):
- try:
- if user_name:
- running = mem_scheduler.dispatcher.get_running_tasks(
- lambda task: getattr(task, "mem_cube_id", None) == user_name
- )
- tasks_iter = list(_to_iter(running))
- running_count = len(tasks_iter)
- return {
- "message": "ok",
- "data": {
- "scope": "user",
- "user_name": user_name,
- "running_tasks": running_count,
- "timestamp": time.time(),
- "instance_id": INSTANCE_ID,
- },
- }
- else:
- running_all = mem_scheduler.dispatcher.get_running_tasks(lambda _t: True)
- tasks_iter = list(_to_iter(running_all))
- running_count = len(tasks_iter)
-
- task_count_per_user: dict[str, int] = {}
- for task in tasks_iter:
- cube = getattr(task, "mem_cube_id", "unknown")
- task_count_per_user[cube] = task_count_per_user.get(cube, 0) + 1
-
- try:
- metrics_snapshot = mem_scheduler.dispatcher.metrics.snapshot()
- except Exception:
- metrics_snapshot = {}
-
- return {
- "message": "ok",
- "data": {
- "scope": "global",
- "running_tasks": running_count,
- "task_count_per_user": task_count_per_user,
- "timestamp": time.time(),
- "instance_id": INSTANCE_ID,
- "metrics": metrics_snapshot,
- },
- }
- except Exception as err:
- logger.error("Failed to get scheduler status: %s", traceback.format_exc())
- raise HTTPException(status_code=500, detail="Failed to get scheduler status") from err
+# =============================================================================
+# Memory Retrieval API Endpoints
+# =============================================================================
-@router.post("/scheduler/wait", summary="Wait until scheduler is idle for a specific user")
-def scheduler_wait(
- user_name: str,
- timeout_seconds: float = 120.0,
- poll_interval: float = 0.2,
-):
- """
- Block until scheduler has no running tasks for the given user_name, or timeout.
+@router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse)
+def get_all_memories(memory_req: GetMemoryRequest):
"""
- start = time.time()
- try:
- while True:
- running = mem_scheduler.dispatcher.get_running_tasks(
- lambda task: task.mem_cube_id == user_name
- )
- running_count = len(running)
- elapsed = time.time() - start
-
- # success -> scheduler is idle
- if running_count == 0:
- return {
- "message": "idle",
- "data": {
- "running_tasks": 0,
- "waited_seconds": round(elapsed, 3),
- "timed_out": False,
- "user_name": user_name,
- },
- }
-
- # timeout check
- if elapsed > timeout_seconds:
- return {
- "message": "timeout",
- "data": {
- "running_tasks": running_count,
- "waited_seconds": round(elapsed, 3),
- "timed_out": True,
- "user_name": user_name,
- },
- }
-
- time.sleep(poll_interval)
-
- except Exception as err:
- logger.error("Failed while waiting for scheduler: %s", traceback.format_exc())
- raise HTTPException(status_code=500, detail="Failed while waiting for scheduler") from err
+ Get all memories or subgraph for a specific user.
-
-@router.get("/scheduler/wait/stream", summary="Stream scheduler progress for a user")
-def scheduler_wait_stream(
- user_name: str,
- timeout_seconds: float = 120.0,
- poll_interval: float = 0.2,
-):
- """
- Stream scheduler progress via Server-Sent Events (SSE).
-
- Contract:
- - We emit periodic heartbeat frames while tasks are still running.
- - Each heartbeat frame is JSON, prefixed with "data: ".
- - On final frame, we include status = "idle" or "timeout" and timed_out flag,
- with the same semantics as /scheduler/wait.
-
- Example curl:
- curl -N "${API_HOST}/product/scheduler/wait/stream?timeout_seconds=10&poll_interval=0.5"
+ If search_query is provided, returns a subgraph based on the query.
+ Otherwise, returns all memories of the specified type.
"""
-
- def event_generator():
- start = time.time()
- try:
- while True:
- running = mem_scheduler.dispatcher.get_running_tasks(
- lambda task: task.mem_cube_id == user_name
- )
- running_count = len(running)
- elapsed = time.time() - start
-
- payload = {
- "user_name": user_name,
- "running_tasks": running_count,
- "elapsed_seconds": round(elapsed, 3),
- "status": "running" if running_count > 0 else "idle",
- "instance_id": INSTANCE_ID,
- }
- yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"
-
- if running_count == 0 or elapsed > timeout_seconds:
- payload["status"] = "idle" if running_count == 0 else "timeout"
- payload["timed_out"] = running_count > 0
- yield "data: " + json.dumps(payload, ensure_ascii=False) + "\n\n"
- break
-
- time.sleep(poll_interval)
-
- except Exception as e:
- err_payload = {
- "status": "error",
- "detail": "stream_failed",
- "exception": str(e),
- "user_name": user_name,
- }
- logger.error(f"Scheduler stream error for {user_name}: {traceback.format_exc()}")
- yield "data: " + json.dumps(err_payload, ensure_ascii=False) + "\n\n"
-
- return StreamingResponse(event_generator(), media_type="text/event-stream")
-
-
-@router.post("/chat/complete", summary="Chat with MemOS (Complete Response)")
-def chat_complete(chat_req: APIChatCompleteRequest):
- """Chat with MemOS for a specific user. Returns complete response (non-streaming)."""
- try:
- # Collect all responses from the generator
- content, references = mos_server.chat(
- query=chat_req.query,
- user_id=chat_req.user_id,
- cube_id=chat_req.mem_cube_id,
- mem_cube=naive_mem_cube,
- history=chat_req.history,
- internet_search=chat_req.internet_search,
- moscube=chat_req.moscube,
- base_prompt=chat_req.base_prompt,
- top_k=chat_req.top_k,
- threshold=chat_req.threshold,
- session_id=chat_req.session_id,
+ if memory_req.search_query:
+ return handlers.memory_handler.handle_get_subgraph(
+ user_id=memory_req.user_id,
+ mem_cube_id=(
+ memory_req.mem_cube_ids[0] if memory_req.mem_cube_ids else memory_req.user_id
+ ),
+ query=memory_req.search_query,
+ top_k=20,
+ naive_mem_cube=naive_mem_cube,
+ )
+ else:
+ return handlers.memory_handler.handle_get_all_memories(
+ user_id=memory_req.user_id,
+ mem_cube_id=(
+ memory_req.mem_cube_ids[0] if memory_req.mem_cube_ids else memory_req.user_id
+ ),
+ memory_type=memory_req.memory_type or "text_mem",
+ naive_mem_cube=naive_mem_cube,
)
-
- # Return the complete response
- return {
- "message": "Chat completed successfully",
- "data": {"response": content, "references": references},
- }
-
- except ValueError as err:
- raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
- except Exception as err:
- logger.error(f"Failed to start chat: {traceback.format_exc()}")
- raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
diff --git a/src/memos/configs/mem_scheduler.py b/src/memos/configs/mem_scheduler.py
index e757f243b..afdaf6871 100644
--- a/src/memos/configs/mem_scheduler.py
+++ b/src/memos/configs/mem_scheduler.py
@@ -12,10 +12,13 @@
BASE_DIR,
DEFAULT_ACT_MEM_DUMP_PATH,
DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT,
+ DEFAULT_CONSUME_BATCH,
DEFAULT_CONSUME_INTERVAL_SECONDS,
DEFAULT_CONTEXT_WINDOW_SIZE,
DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE,
DEFAULT_MULTI_TASK_RUNNING_TIMEOUT,
+ DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
+ DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
DEFAULT_THREAD_POOL_MAX_WORKERS,
DEFAULT_TOP_K,
DEFAULT_USE_REDIS_QUEUE,
@@ -43,6 +46,11 @@ class BaseSchedulerConfig(BaseConfig):
gt=0,
description=f"Interval for consuming messages from queue in seconds (default: {DEFAULT_CONSUME_INTERVAL_SECONDS})",
)
+ consume_batch: int = Field(
+ default=DEFAULT_CONSUME_BATCH,
+ gt=0,
+ description=f"Number of messages to consume in each batch (default: {DEFAULT_CONSUME_BATCH})",
+ )
auth_config_path: str | None = Field(
default=None,
description="Path to the authentication configuration file containing private credentials",
@@ -91,6 +99,17 @@ class GeneralSchedulerConfig(BaseSchedulerConfig):
description="Capacity of the activation memory monitor",
)
+ # Memory enhancement concurrency & retries configuration
+ enhance_batch_size: int | None = Field(
+ default=DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
+ description="Batch size for concurrent memory enhancement; None or <=1 disables batching",
+ )
+ enhance_retries: int = Field(
+ default=DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
+ ge=0,
+ description="Number of retry attempts per enhancement batch",
+ )
+
# Database configuration for ORM persistence
db_path: str | None = Field(
default=None,
diff --git a/src/memos/embedders/universal_api.py b/src/memos/embedders/universal_api.py
index fc51cf073..583a02acb 100644
--- a/src/memos/embedders/universal_api.py
+++ b/src/memos/embedders/universal_api.py
@@ -26,7 +26,7 @@ def __init__(self, config: UniversalAPIEmbedderConfig):
else:
raise ValueError(f"Embeddings unsupported provider: {self.provider}")
- @timed(log=True, log_prefix="EmbedderAPI")
+ @timed(log=True, log_prefix="model_timed_embedding")
def embed(self, texts: list[str]) -> list[list[float]]:
if self.provider == "openai" or self.provider == "azure":
try:
diff --git a/src/memos/graph_dbs/polardb.py b/src/memos/graph_dbs/polardb.py
index 60902420f..da1635296 100644
--- a/src/memos/graph_dbs/polardb.py
+++ b/src/memos/graph_dbs/polardb.py
@@ -199,6 +199,7 @@ def _get_connection(self):
max_retries = 3
for attempt in range(max_retries):
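+            # Keep a reference to the connection so the failure paths below can hand it back to the pool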
+ conn = None
try:
conn = self.connection_pool.getconn()
@@ -216,8 +217,49 @@ def _get_connection(self):
# Set autocommit for PolarDB compatibility
conn.autocommit = True
+
+ # Test connection health with SELECT 1
+ try:
+ cursor = conn.cursor()
+ cursor.execute("SELECT 1")
+ cursor.fetchone()
+ cursor.close()
+ except Exception as health_check_error:
+ # Connection is not usable, close it and try again
+ logger.warning(
+ f"Connection health check failed: {health_check_error}, closing connection and retrying..."
+ )
+ try:
+ conn.close()
+ except Exception as close_error:
+ logger.warning(f"Failed to close unhealthy connection: {close_error}")
+
+                    # Hand the already-closed connection back to the pool so its slot is released
+                    try:
+                        self.connection_pool.putconn(conn, close=True)
+                    except Exception as close_error:
+                        logger.warning(f"Failed to return unhealthy connection to pool: {close_error}")
+
+ conn = None
+ if attempt < max_retries - 1:
+ continue
+ else:
+ raise RuntimeError(
+ f"Failed to get a healthy connection from pool after {max_retries} attempts: {health_check_error}"
+ ) from health_check_error
+
+ # Connection is healthy, return it
return conn
except Exception as e:
+ # If we have a connection that failed, try to return it to pool
+ if conn is not None:
+ try:
+ self.connection_pool.putconn(conn, close=True)
+ except Exception as putconn_error:
+ logger.warning(
+                        f"Failed to return connection to pool: {putconn_error}"
+ )
+
if attempt >= max_retries - 1:
raise RuntimeError(f"Failed to get a valid connection from pool: {e}") from e
continue
@@ -647,12 +689,16 @@ def add_edge(
self, source_id: str, target_id: str, type: str, user_name: str | None = None
) -> None:
if not source_id or not target_id:
+            logger.warning(f"[add_edge] Missing node id: source_id={source_id!r}, target_id={target_id!r}")
raise ValueError("[add_edge] source_id and target_id must be provided")
source_exists = self.get_node(source_id) is not None
target_exists = self.get_node(target_id) is not None
if not source_exists or not target_exists:
+            logger.warning(
+                "[add_edge] Missing node: source_id=%s (exists=%s), target_id=%s (exists=%s)",
+                source_id, source_exists, target_id, target_exists,
+            )
raise ValueError("[add_edge] source_id and target_id must be provided")
properties = {}
diff --git a/src/memos/llms/openai.py b/src/memos/llms/openai.py
index 1a1703340..da55ae593 100644
--- a/src/memos/llms/openai.py
+++ b/src/memos/llms/openai.py
@@ -1,5 +1,6 @@
import hashlib
import json
+import time
from collections.abc import Generator
from typing import ClassVar
@@ -57,12 +58,15 @@ def clear_cache(cls):
cls._instances.clear()
logger.info("OpenAI LLM instance cache cleared")
- @timed(log=True, log_prefix="OpenAI LLM")
+ @timed(log=True, log_prefix="model_timed_openai")
def generate(self, messages: MessageList, **kwargs) -> str:
"""Generate a response from OpenAI LLM, optionally overriding generation params."""
temperature = kwargs.get("temperature", self.config.temperature)
max_tokens = kwargs.get("max_tokens", self.config.max_tokens)
top_p = kwargs.get("top_p", self.config.top_p)
+ start_time = time.time()
+ logger.info(f"openai model request start, model_name: {self.config.model_name_or_path}")
+
response = self.client.chat.completions.create(
model=self.config.model_name_or_path,
messages=messages,
@@ -71,7 +75,11 @@ def generate(self, messages: MessageList, **kwargs) -> str:
max_tokens=max_tokens,
top_p=top_p,
)
- logger.info(f"Response from OpenAI: {response.model_dump_json()}")
+
+ end_time = time.time()
+ logger.info(
+ f"openai model request end, time_cost: {end_time - start_time:.0f} ms, response from OpenAI: {response.model_dump_json()}"
+ )
response_content = response.choices[0].message.content
if self.config.remove_think_prefix:
return remove_thinking_tags(response_content)
diff --git a/src/memos/log.py b/src/memos/log.py
index faa808414..874f2c6a7 100644
--- a/src/memos/log.py
+++ b/src/memos/log.py
@@ -37,6 +37,7 @@ def _setup_logfile() -> Path:
logfile = Path(settings.MEMOS_DIR / "logs" / "memos.log")
logfile.parent.mkdir(parents=True, exist_ok=True)
logfile.touch(exist_ok=True)
+
return logfile
@@ -195,10 +196,11 @@ def close(self):
},
"file": {
"level": "DEBUG",
- "class": "logging.handlers.RotatingFileHandler",
+ "class": "logging.handlers.TimedRotatingFileHandler",
+ "when": "midnight",
+ "interval": 1,
+ "backupCount": 3,
"filename": _setup_logfile(),
- "maxBytes": 1024**2 * 10,
- "backupCount": 10,
"formatter": "standard",
"filters": ["context_filter"],
},
diff --git a/src/memos/mem_cube/navie.py b/src/memos/mem_cube/navie.py
index 3acc441a0..3afa78bab 100644
--- a/src/memos/mem_cube/navie.py
+++ b/src/memos/mem_cube/navie.py
@@ -2,26 +2,13 @@
from typing import Literal
-from memos.configs.mem_cube import GeneralMemCubeConfig
from memos.configs.utils import get_json_file_model_schema
-from memos.embedders.base import BaseEmbedder
from memos.exceptions import ConfigurationError, MemCubeError
-from memos.graph_dbs.base import BaseGraphDB
-from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.mem_cube.base import BaseMemCube
-from memos.mem_reader.base import BaseMemReader
from memos.memories.activation.base import BaseActMemory
from memos.memories.parametric.base import BaseParaMemory
from memos.memories.textual.base import BaseTextMemory
-from memos.memories.textual.prefer_text_memory.adder import BaseAdder
-from memos.memories.textual.prefer_text_memory.extractor import BaseExtractor
-from memos.memories.textual.prefer_text_memory.retrievers import BaseRetriever
-from memos.memories.textual.simple_preference import SimplePreferenceTextMemory
-from memos.memories.textual.simple_tree import SimpleTreeTextMemory
-from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
-from memos.reranker.base import BaseReranker
-from memos.vec_dbs.base import BaseVecDB
logger = get_logger(__name__)
@@ -32,55 +19,28 @@ class NaiveMemCube(BaseMemCube):
def __init__(
self,
- llm: BaseLLM,
- embedder: BaseEmbedder,
- mem_reader: BaseMemReader,
- graph_db: BaseGraphDB,
- reranker: BaseReranker,
- memory_manager: MemoryManager,
- default_cube_config: GeneralMemCubeConfig,
- vector_db: BaseVecDB,
- internet_retriever: None = None,
- pref_extractor: BaseExtractor | None = None,
- pref_adder: BaseAdder | None = None,
- pref_retriever: BaseRetriever | None = None,
+ text_mem: BaseTextMemory | None = None,
+ pref_mem: BaseTextMemory | None = None,
+ act_mem: BaseActMemory | None = None,
+ para_mem: BaseParaMemory | None = None,
):
- """Initialize the MemCube with a configuration."""
- self._text_mem: BaseTextMemory | None = SimpleTreeTextMemory(
- llm,
- embedder,
- mem_reader,
- graph_db,
- reranker,
- memory_manager,
- default_cube_config.text_mem.config,
- internet_retriever,
- )
- self._act_mem: BaseActMemory | None = None
- self._para_mem: BaseParaMemory | None = None
- self._pref_mem: BaseTextMemory | None = (
- SimplePreferenceTextMemory(
- extractor_llm=llm,
- vector_db=vector_db,
- embedder=embedder,
- reranker=reranker,
- extractor=pref_extractor,
- adder=pref_adder,
- retriever=pref_retriever,
- )
- if os.getenv("ENABLE_PREFERENCE_MEMORY", "false").lower() == "true"
- else None
- )
+ """Initialize the MemCube with memory instances."""
+ self._text_mem: BaseTextMemory = text_mem
+ self._act_mem: BaseActMemory | None = act_mem
+ self._para_mem: BaseParaMemory | None = para_mem
+ self._pref_mem: BaseTextMemory | None = pref_mem
def load(
- self, dir: str, memory_types: list[Literal["text_mem", "act_mem", "para_mem"]] | None = None
+ self,
+ dir: str,
+ memory_types: list[Literal["text_mem", "act_mem", "para_mem", "pref_mem"]] | None = None,
) -> None:
"""Load memories.
Args:
dir (str): The directory containing the memory files.
memory_types (list[str], optional): List of memory types to load.
If None, loads all available memory types.
- Options: ["text_mem", "act_mem", "para_mem"]
+ Options: ["text_mem", "act_mem", "para_mem", "pref_mem"]
"""
loaded_schema = get_json_file_model_schema(os.path.join(dir, self.config.config_filename))
if loaded_schema != self.config.model_schema:
diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py
index 97ff9879f..f11b3a44c 100644
--- a/src/memos/mem_os/core.py
+++ b/src/memos/mem_os/core.py
@@ -283,12 +283,11 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None =
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=QUERY_LABEL,
content=query,
timestamp=datetime.utcnow(),
)
- self.mem_scheduler.submit_messages(messages=[message_item])
+ self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item])
memories = mem_cube.text_mem.search(
query,
@@ -344,12 +343,11 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None =
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=ANSWER_LABEL,
content=response,
timestamp=datetime.utcnow(),
)
- self.mem_scheduler.submit_messages(messages=[message_item])
+ self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item])
return response
@@ -768,27 +766,28 @@ def process_textual_memory():
)
# submit messages for scheduler
if self.enable_mem_scheduler and self.mem_scheduler is not None:
- mem_cube = self.mem_cubes[mem_cube_id]
if sync_mode == "async":
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=MEM_READ_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.utcnow(),
)
- self.mem_scheduler.submit_messages(messages=[message_item])
+ self.mem_scheduler.memos_message_queue.submit_messages(
+ messages=[message_item]
+ )
else:
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=ADD_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.utcnow(),
)
- self.mem_scheduler.submit_messages(messages=[message_item])
+ self.mem_scheduler.memos_message_queue.submit_messages(
+ messages=[message_item]
+ )
def process_preference_memory():
if (
@@ -797,12 +796,15 @@ def process_preference_memory():
and self.mem_cubes[mem_cube_id].pref_mem
):
messages_list = [messages]
- mem_cube = self.mem_cubes[mem_cube_id]
if sync_mode == "sync":
pref_memories = self.mem_cubes[mem_cube_id].pref_mem.get_memory(
messages_list,
type="chat",
- info={"user_id": target_user_id, "session_id": self.session_id},
+ info={
+ "user_id": target_user_id,
+ "session_id": self.session_id,
+ "mem_cube_id": mem_cube_id,
+ },
)
pref_ids = self.mem_cubes[mem_cube_id].pref_mem.add(pref_memories)
logger.info(
@@ -816,12 +818,11 @@ def process_preference_memory():
user_id=target_user_id,
session_id=target_session_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=PREF_ADD_LABEL,
content=json.dumps(messages_list),
timestamp=datetime.utcnow(),
)
- self.mem_scheduler.submit_messages(messages=[message_item])
+ self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item])
# Execute both memory processing functions in parallel
with ContextThreadPoolExecutor(max_workers=2) as executor:
@@ -867,27 +868,28 @@ def process_preference_memory():
# submit messages for scheduler
if self.enable_mem_scheduler and self.mem_scheduler is not None:
- mem_cube = self.mem_cubes[mem_cube_id]
if sync_mode == "async":
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=MEM_READ_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.utcnow(),
)
- self.mem_scheduler.submit_messages(messages=[message_item])
+ self.mem_scheduler.memos_message_queue.submit_messages(
+ messages=[message_item]
+ )
else:
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=ADD_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.utcnow(),
)
- self.mem_scheduler.submit_messages(messages=[message_item])
+ self.mem_scheduler.memos_message_queue.submit_messages(
+ messages=[message_item]
+ )
# user doc input
if (
@@ -909,16 +911,14 @@ def process_preference_memory():
# submit messages for scheduler
if self.enable_mem_scheduler and self.mem_scheduler is not None:
- mem_cube = self.mem_cubes[mem_cube_id]
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=ADD_LABEL,
content=json.dumps(mem_ids),
timestamp=datetime.utcnow(),
)
- self.mem_scheduler.submit_messages(messages=[message_item])
+ self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item])
logger.info(f"Add memory to {mem_cube_id} successfully")
diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py
index 6fc64c5e3..11c112d52 100644
--- a/src/memos/mem_os/main.py
+++ b/src/memos/mem_os/main.py
@@ -205,7 +205,6 @@ def _chat_with_cot_enhancement(
# Step 7: Submit message to scheduler (same as core method)
if len(accessible_cubes) == 1:
mem_cube_id = accessible_cubes[0].cube_id
- mem_cube = self.mem_cubes[mem_cube_id]
if self.enable_mem_scheduler and self.mem_scheduler is not None:
from datetime import datetime
@@ -217,12 +216,11 @@ def _chat_with_cot_enhancement(
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=ANSWER_LABEL,
content=enhanced_response,
timestamp=datetime.now().isoformat(),
)
- self.mem_scheduler.submit_messages(messages=[message_item])
+ self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item])
return enhanced_response
diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py
index 89e468bd7..9a4ab3f4d 100644
--- a/src/memos/mem_os/product.py
+++ b/src/memos/mem_os/product.py
@@ -563,6 +563,34 @@ def _extract_references_from_response(self, response: str) -> tuple[str, list[di
logger.error(f"Error extracting references from response: {e}", exc_info=True)
return response, []
+ def _extract_struct_data_from_history(self, chat_data: list[dict]) -> dict:
+ """
+        Extract structured parts (system prompt, memory block, chat history) from the chat messages.
+ # TODO: @xcy make this more general
+ """
+ system_content = ""
+ memory_content = ""
+ chat_history = []
+
+ for item in chat_data:
+ role = item.get("role")
+ content = item.get("content", "")
+ if role == "system":
+ parts = content.split("# Memories", 1)
+ system_content = parts[0].strip()
+ if len(parts) > 1:
+ memory_content = "# Memories" + parts[1].strip()
+ elif role in ("user", "assistant"):
+ chat_history.append({"role": role, "content": content})
+
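+        # Trim the trailing assistant turn (and its preceding user turn, if present) from the extracted history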
+ if chat_history and chat_history[-1]["role"] == "assistant":
+ if len(chat_history) >= 2 and chat_history[-2]["role"] == "user":
+ chat_history = chat_history[:-2]
+ else:
+ chat_history = chat_history[:-1]
+
+ return {"system": system_content, "memory": memory_content, "chat_history": chat_history}
+
def _chunk_response_with_tiktoken(
self, response: str, chunk_size: int = 5
) -> Generator[str, None, None]:
@@ -609,12 +637,11 @@ def _send_message_to_scheduler(
message_item = ScheduleMessageItem(
user_id=user_id,
mem_cube_id=mem_cube_id,
- mem_cube=self.mem_cubes[mem_cube_id],
label=label,
content=query,
timestamp=datetime.utcnow(),
)
- self.mem_scheduler.submit_messages(messages=[message_item])
+ self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item])
async def _post_chat_processing(
self,
@@ -640,23 +667,26 @@ async def _post_chat_processing(
clean_response, extracted_references = self._extract_references_from_response(
full_response
)
+ struct_message = self._extract_struct_data_from_history(current_messages)
logger.info(f"Extracted {len(extracted_references)} references from response")
# Send chat report notifications asynchronously
if self.online_bot:
+ logger.info("Online Bot Open!")
try:
from memos.memos_tools.notification_utils import (
send_online_bot_notification_async,
)
# Prepare notification data
- chat_data = {
- "query": query,
- "user_id": user_id,
- "cube_id": cube_id,
- "system_prompt": system_prompt,
- "full_response": full_response,
- }
+ chat_data = {"query": query, "user_id": user_id, "cube_id": cube_id}
+ chat_data.update(
+ {
+ "memory": struct_message["memory"],
+ "chat_history": struct_message["chat_history"],
+ "full_response": full_response,
+ }
+ )
system_data = {
"references": extracted_references,
@@ -720,6 +750,7 @@ def _start_post_chat_processing(
"""
Asynchronous processing of logs, notifications and memory additions, handle synchronous and asynchronous environments
"""
+ logger.info("Start post_chat_processing...")
def run_async_in_thread():
"""Running asynchronous tasks in a new thread"""
@@ -1046,14 +1077,20 @@ def chat(
memories_list = new_memories_list
system_prompt = super()._build_system_prompt(memories_list, base_prompt)
- history_info = []
- if history:
+ if history is not None:
+ # Use the provided history (even if it's empty)
history_info = history[-20:]
+ else:
+ # Fall back to internal chat_history
+ if user_id not in self.chat_history_manager:
+ self._register_chat_history(user_id, session_id)
+ history_info = self.chat_history_manager[user_id].chat_history[-20:]
current_messages = [
{"role": "system", "content": system_prompt},
*history_info,
{"role": "user", "content": query},
]
+ logger.info("Start to get final answer...")
response = self.chat_llm.generate(current_messages)
time_end = time.time()
self._start_post_chat_processing(
@@ -1129,7 +1166,7 @@ def chat_with_references(
self._register_chat_history(user_id, session_id)
chat_history = self.chat_history_manager[user_id]
- if history:
+ if history is not None:
chat_history.chat_history = history[-20:]
current_messages = [
{"role": "system", "content": system_prompt},
diff --git a/src/memos/mem_os/utils/reference_utils.py b/src/memos/mem_os/utils/reference_utils.py
index c2f4431c3..09b812207 100644
--- a/src/memos/mem_os/utils/reference_utils.py
+++ b/src/memos/mem_os/utils/reference_utils.py
@@ -142,12 +142,21 @@ def prepare_reference_data(memories_list: list[TextualMemoryItem]) -> list[dict]
# Prepare reference data
reference = []
for memories in memories_list:
- memories_json = memories.model_dump()
- memories_json["metadata"]["ref_id"] = f"{memories.id.split('-')[0]}"
- memories_json["metadata"]["embedding"] = []
- memories_json["metadata"]["sources"] = []
- memories_json["metadata"]["memory"] = memories.memory
- memories_json["metadata"]["id"] = memories.id
- reference.append({"metadata": memories_json["metadata"]})
+ if isinstance(memories, TextualMemoryItem):
+ memories_json = memories.model_dump()
+ memories_json["metadata"]["ref_id"] = f"{memories.id.split('-')[0]}"
+ memories_json["metadata"]["embedding"] = []
+ memories_json["metadata"]["sources"] = []
+ memories_json["metadata"]["memory"] = memories.memory
+ memories_json["metadata"]["id"] = memories.id
+ reference.append({"metadata": memories_json["metadata"]})
+ else:
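+            # The item may already be a plain dict (e.g. an item serialized earlier); shape its metadata in place the same way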
+ memories_json = memories
+ memories_json["metadata"]["ref_id"] = f"{memories_json['id'].split('-')[0]}"
+ memories_json["metadata"]["embedding"] = []
+ memories_json["metadata"]["sources"] = []
+ memories_json["metadata"]["memory"] = memories_json["memory"]
+ memories_json["metadata"]["id"] = memories_json["id"]
+ reference.append({"metadata": memories_json["metadata"]})
return reference
diff --git a/src/memos/mem_reader/simple_struct.py b/src/memos/mem_reader/simple_struct.py
index 13515c038..3845f37d0 100644
--- a/src/memos/mem_reader/simple_struct.py
+++ b/src/memos/mem_reader/simple_struct.py
@@ -6,6 +6,7 @@
import traceback
from abc import ABC
+from datetime import datetime, timezone
from typing import Any
from tqdm import tqdm
@@ -399,7 +400,7 @@ def get_memory(
if not all(isinstance(info[field], str) for field in required_fields):
raise ValueError("user_id and session_id must be strings")
-
+ scene_data = self._complete_chat_time(scene_data, type)
list_scene_data_info = self.get_scene_data_info(scene_data, type)
memory_list = []
@@ -508,6 +509,31 @@ def get_scene_data_info(self, scene_data: list, type: str) -> list[str]:
return results
+ def _complete_chat_time(self, scene_data: list[list[dict]], type: str):
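+        """Backfill chat_time on each message, reusing the session's first chat_time or falling back to the current UTC time."""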
+ if type != "chat":
+ return scene_data
+ complete_scene_data = []
+
+ for items in scene_data:
+ chat_time_value = None
+
+ for item in items:
+ if "chat_time" in item:
+ chat_time_value = item["chat_time"]
+ break
+
+ if chat_time_value is None:
+ session_date = datetime.now(timezone.utc)
+ date_format = "%I:%M %p on %d %B, %Y UTC"
+ chat_time_value = session_date.strftime(date_format)
+
+ for i in range(len(items)):
+ if "chat_time" not in items[i]:
+ items[i]["chat_time"] = chat_time_value
+
+ complete_scene_data.append(items)
+ return complete_scene_data
+
def _process_doc_data(self, scene_data_info, info, **kwargs):
mode = kwargs.get("mode", "fine")
if mode == "fast":
diff --git a/src/memos/mem_scheduler/analyzer/api_analyzer.py b/src/memos/mem_scheduler/analyzer/api_analyzer.py
index 28ca182e5..085025b7f 100644
--- a/src/memos/mem_scheduler/analyzer/api_analyzer.py
+++ b/src/memos/mem_scheduler/analyzer/api_analyzer.py
@@ -7,7 +7,6 @@
import http.client
import json
-import time
from typing import Any
from urllib.parse import urlparse
@@ -15,6 +14,7 @@
import requests
from memos.log import get_logger
+from memos.mem_scheduler.schemas.general_schemas import SearchMode
logger = get_logger(__name__)
@@ -487,7 +487,7 @@ def search_in_conversation(self, query, mode="fast", top_k=10, include_history=T
return result
- def test_continuous_conversation(self):
+ def test_continuous_conversation(self, mode=SearchMode.MIXTURE):
"""Test continuous conversation functionality"""
print("=" * 80)
print("Testing Continuous Conversation Functionality")
@@ -542,15 +542,15 @@ def test_continuous_conversation(self):
# Search for trip-related information
self.search_in_conversation(
- query="New Year's Eve Shanghai recommendations", mode="mixture", top_k=5
+ query="New Year's Eve Shanghai recommendations", mode=mode, top_k=5
)
# Search for food-related information
- self.search_in_conversation(query="budget food Shanghai", mode="mixture", top_k=3)
+ self.search_in_conversation(query="budget food Shanghai", mode=mode, top_k=3)
# Search without conversation history
self.search_in_conversation(
- query="Shanghai travel", mode="mixture", top_k=3, include_history=False
+ query="Shanghai travel", mode=mode, top_k=3, include_history=False
)
print("\n✅ Continuous conversation test completed successfully!")
@@ -645,7 +645,7 @@ def create_test_add_request(
operation=None,
)
- def run_all_tests(self):
+ def run_all_tests(self, mode=SearchMode.MIXTURE):
"""Run all available tests"""
print("🚀 Starting comprehensive test suite")
print("=" * 80)
@@ -653,8 +653,7 @@ def run_all_tests(self):
# Test continuous conversation functionality
print("\n💬 Testing CONTINUOUS CONVERSATION functions:")
try:
- self.test_continuous_conversation()
- time.sleep(5)
+ self.test_continuous_conversation(mode=mode)
print("✅ Continuous conversation test completed successfully")
except Exception as e:
print(f"❌ Continuous conversation test failed: {e}")
@@ -682,7 +681,7 @@ def run_all_tests(self):
print("Using direct test mode")
try:
direct_analyzer = DirectSearchMemoriesAnalyzer()
- direct_analyzer.run_all_tests()
+ direct_analyzer.run_all_tests(mode=SearchMode.MIXTURE)
except Exception as e:
print(f"Direct test mode failed: {e}")
import traceback
diff --git a/src/memos/mem_scheduler/analyzer/eval_analyzer.py b/src/memos/mem_scheduler/analyzer/eval_analyzer.py
new file mode 100644
index 000000000..cf0b8f1dd
--- /dev/null
+++ b/src/memos/mem_scheduler/analyzer/eval_analyzer.py
@@ -0,0 +1,1322 @@
+"""
+Evaluation Analyzer for Bad Cases
+
+This module provides the EvalAnalyzer class that extracts bad cases from evaluation results
+and analyzes whether memories contain sufficient information to answer golden answers.
+"""
+
+import json
+import os
+import sys
+
+from pathlib import Path
+from typing import Any
+
+from openai import OpenAI
+
+from memos.api.routers.server_router import mem_scheduler
+from memos.log import get_logger
+from memos.memories.textual.item import TextualMemoryMetadata
+from memos.memories.textual.tree import TextualMemoryItem
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent # Go up to project root
+sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
+
+logger = get_logger(__name__)
+
+
+class EvalAnalyzer:
+ """
+ Evaluation Analyzer class for extracting and analyzing bad cases.
+
+ This class extracts bad cases from evaluation results and uses LLM to analyze
+ whether memories contain sufficient information to answer golden answers.
+ """
+
+ def __init__(
+ self,
+ openai_api_key: str | None = None,
+ openai_base_url: str | None = None,
+ openai_model: str = "gpt-4o-mini",
+ output_dir: str = "./tmp/eval_analyzer",
+ ):
+ """
+ Initialize the EvalAnalyzer.
+
+ Args:
+ openai_api_key: OpenAI API key
+ openai_base_url: OpenAI base URL
+ openai_model: OpenAI model to use
+ output_dir: Output directory for results
+ """
+ self.output_dir = Path(output_dir)
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Initialize OpenAI client
+ self.openai_client = OpenAI(
+ api_key=openai_api_key or os.getenv("MEMSCHEDULER_OPENAI_API_KEY"),
+ base_url=openai_base_url or os.getenv("MEMSCHEDULER_OPENAI_BASE_URL"),
+ )
+ self.openai_model = openai_model or os.getenv(
+ "MEMSCHEDULER_OPENAI_DEFAULT_MODEL", "gpt-4o-mini"
+ )
+
+ logger.info(f"EvalAnalyzer initialized with model: {self.openai_model}")
+
+ def load_json_file(self, filepath: str) -> Any:
+ """Load JSON file safely."""
+ try:
+ with open(filepath, encoding="utf-8") as f:
+ return json.load(f)
+ except FileNotFoundError:
+ logger.error(f"File not found: {filepath}")
+ return None
+ except json.JSONDecodeError as e:
+ logger.error(f"JSON decode error in {filepath}: {e}")
+ return None
+
+ def extract_bad_cases(self, judged_file: str, search_results_file: str) -> list[dict[str, Any]]:
+ """
+ Extract bad cases from judged results and corresponding search results.
+
+ Args:
+ judged_file: Path to the judged results JSON file
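+        # Wrap the raw memory strings as TextualMemoryItem objects and reuse the scheduler retriever's enhancement prompt on them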
+ search_results_file: Path to the search results JSON file
+
+ Returns:
+ List of bad cases with their memories
+ """
+ logger.info(f"Loading judged results from: {judged_file}")
+ judged_data = self.load_json_file(judged_file)
+ if not judged_data:
+ return []
+
+ logger.info(f"Loading search results from: {search_results_file}")
+ search_data = self.load_json_file(search_results_file)
+ if not search_data:
+ return []
+
+ bad_cases = []
+
+ # Process each user's data
+ for user_id, user_judged_results in judged_data.items():
+ user_search_results = search_data.get(user_id, [])
+
+ # Create a mapping from query to search context
+ search_context_map = {}
+ for search_result in user_search_results:
+ query = search_result.get("query", "")
+ context = search_result.get("context", "")
+ search_context_map[query] = context
+
+ # Process each question for this user
+ for result in user_judged_results:
+ # Check if this is a bad case (all judgments are False)
+ judgments = result.get("llm_judgments", {})
+ is_bad_case = all(not judgment for judgment in judgments.values())
+
+ if is_bad_case:
+ question = result.get("question", "")
+ answer = result.get("answer", "")
+ golden_answer = result.get("golden_answer", "")
+
+ # Find corresponding memories from search results
+ memories = search_context_map.get(question, "")
+
+ bad_case = {
+ "user_id": user_id,
+ "query": question,
+ "answer": answer,
+ "golden_answer": golden_answer,
+ "memories": memories,
+ "category": result.get("category", 0),
+ "nlp_metrics": result.get("nlp_metrics", {}),
+ "response_duration_ms": result.get("response_duration_ms", 0),
+ "search_duration_ms": result.get("search_duration_ms", 0),
+ "total_duration_ms": result.get("total_duration_ms", 0),
+ }
+
+ bad_cases.append(bad_case)
+
+ logger.info(f"Extracted {len(bad_cases)} bad cases")
+ return bad_cases
+
+ def analyze_memory_sufficiency(
+ self, query: str, golden_answer: str, memories: str
+ ) -> dict[str, Any]:
+ """
+ Use LLM to analyze whether memories contain sufficient information to answer the golden answer.
+
+ Args:
+ query: The original query
+ golden_answer: The correct answer
+ memories: The memory context
+
+ Returns:
+ Analysis result containing sufficiency judgment and relevant memory indices
+ """
+ prompt = f"""
+You are an expert analyst tasked with determining whether the provided memories contain sufficient information to answer a specific question correctly.
+
+**Question:** {query}
+
+**Golden Answer (Correct Answer):** {golden_answer}
+
+**Available Memories:**
+{memories}
+
+**Task:**
+1. Analyze whether the memories contain enough information to derive the golden answer
+2. Identify which specific memory entries (if any) contain relevant information
+3. Provide a clear judgment: True if sufficient, False if insufficient
+
+**Response Format (JSON):**
+{{
+ "sufficient": true/false,
+ "confidence": 0.0-1.0,
+ "relevant_memories": ["memory_1", "memory_2", ...],
+ "reasoning": "Detailed explanation of your analysis",
+ "missing_information": "What key information is missing (if insufficient)"
+}}
+
+**Guidelines:**
+- Be strict in your evaluation - only mark as sufficient if the memories clearly contain the information needed
+- Consider both direct and indirect information that could lead to the golden answer
+- Pay attention to dates, names, events, and specific details
+- If information is ambiguous or requires significant inference, lean towards insufficient
+"""
+
+ try:
+ response = self.openai_client.chat.completions.create(
+ model=self.openai_model,
+ messages=[
+ {
+ "role": "system",
+ "content": "You are a precise analyst who evaluates information sufficiency.",
+ },
+ {"role": "user", "content": prompt},
+ ],
+ temperature=0.1,
+ max_tokens=1000,
+ )
+
+ content = response.choices[0].message.content.strip()
+
+ # Try to parse JSON response
+ try:
+ # Remove markdown code blocks if present
+ if content.startswith("```json"):
+ content = content[7:]
+ if content.endswith("```"):
+ content = content[:-3]
+ content = content.strip()
+
+ analysis = json.loads(content)
+ return analysis
+
+ except json.JSONDecodeError:
+ logger.warning(f"Failed to parse LLM response as JSON: {content}")
+ return {
+ "sufficient": False,
+ "confidence": 0.0,
+ "relevant_memories": [],
+ "reasoning": f"Failed to parse LLM response: {content}",
+ "missing_information": "Analysis failed",
+ }
+
+ except Exception as e:
+ logger.error(f"Error in LLM analysis: {e}")
+ return {
+ "sufficient": False,
+ "confidence": 0.0,
+ "relevant_memories": [],
+ "reasoning": f"Error occurred: {e!s}",
+ "missing_information": "Analysis failed due to error",
+ }
+
+ def process_memories_with_llm(
+ self, memories: str, query: str, processing_type: str = "summarize"
+ ) -> dict[str, Any]:
+ """
+ Use LLM to process memories for better question answering.
+
+ Args:
+ memories: The raw memory content
+ query: The query that will be answered using these memories
+ processing_type: Type of processing ("summarize", "restructure", "enhance")
+
+ Returns:
+ Dictionary containing processed memories and processing metadata
+ """
+ if processing_type == "summarize":
+ prompt = f"""
+You are an expert at summarizing and organizing information to help answer specific questions.
+
+**Target Question:** {query}
+
+**Raw Memories:**
+{memories}
+
+**Task:**
+Summarize and organize the above memories in a way that would be most helpful for answering the target question. Focus on:
+1. Key facts and information relevant to the question
+2. Important relationships and connections
+3. Chronological or logical organization where applicable
+4. Remove redundant or irrelevant information
+
+**Processed Memories:**
+"""
+ elif processing_type == "restructure":
+ prompt = f"""
+You are an expert at restructuring information to optimize question answering.
+
+**Target Question:** {query}
+
+**Raw Memories:**
+{memories}
+
+**Task:**
+Restructure the above memories into a clear, logical format that directly supports answering the target question. Organize by:
+1. Most relevant information first
+2. Supporting details and context
+3. Clear categorization of different types of information
+4. Logical flow that leads to the answer
+
+**Restructured Memories:**
+"""
+ elif processing_type == "enhance":
+ prompt = f"""
+You are an expert at enhancing information by adding context and making connections.
+
+**Target Question:** {query}
+
+**Raw Memories:**
+{memories}
+
+**Task:**
+Enhance the above memories by:
+1. Making implicit connections explicit
+2. Adding relevant context that helps answer the question
+3. Highlighting key relationships between different pieces of information
+4. Organizing information in a question-focused manner
+
+**Enhanced Memories:**
+"""
+ else:
+ raise ValueError(f"Unknown processing_type: {processing_type}")
+
+ try:
+ response = self.openai_client.chat.completions.create(
+ model=self.openai_model,
+ messages=[
+ {
+ "role": "system",
+ "content": "You are an expert information processor who optimizes content for question answering.",
+ },
+ {"role": "user", "content": prompt},
+ ],
+ temperature=0.3,
+ max_tokens=2000,
+ )
+
+ processed_memories = response.choices[0].message.content.strip()
+
+ return {
+ "processed_memories": processed_memories,
+ "processing_type": processing_type,
+ "original_length": len(memories),
+ "processed_length": len(processed_memories),
+ "compression_ratio": len(processed_memories) / len(memories)
+ if len(memories) > 0
+ else 0,
+ }
+
+ except Exception as e:
+ logger.error(f"Error in memory processing: {e}")
+ return {
+ "processed_memories": memories, # Fallback to original
+ "processing_type": processing_type,
+ "original_length": len(memories),
+ "processed_length": len(memories),
+ "compression_ratio": 1.0,
+ "error": str(e),
+ }
+
+ def generate_answer_with_memories(
+ self, query: str, memories: str, memory_type: str = "original"
+ ) -> dict[str, Any]:
+ """
+ Generate an answer to the query using the provided memories.
+
+ Args:
+ query: The question to answer
+ memories: The memory content to use
+ memory_type: Type of memories ("original", "processed")
+
+ Returns:
+ Dictionary containing the generated answer and metadata
+ """
+ prompt = f"""
+ You are a knowledgeable and helpful AI assistant.
+
+ # CONTEXT:
+ You have access to memories from two speakers in a conversation. These memories contain
+ timestamped information that may be relevant to answering the question.
+
+ # INSTRUCTIONS:
+ 1. Carefully analyze all provided memories. Synthesize information across different entries if needed to form a complete answer.
+ 2. Pay close attention to the timestamps to determine the answer. If memories contain contradictory information, the **most recent memory** is the source of truth.
+ 3. If the question asks about a specific event or fact, look for direct evidence in the memories.
+ 4. Your answer must be grounded in the memories. However, you may use general world knowledge to interpret or complete information found within a memory (e.g., identifying a landmark mentioned by description).
+ 5. If the question involves time references (like "last year", "two months ago", etc.), you **must** calculate the actual date based on the memory's timestamp. For example, if a memory from 4 May 2022 mentions "went to India last year," then the trip occurred in 2021.
+ 6. Always convert relative time references to specific dates, months, or years in your final answer.
+ 7. Do not confuse character names mentioned in memories with the actual users who created them.
+ 8. The answer must be brief (under 5-6 words) and direct, with no extra description.
+
+ # APPROACH (Think step by step):
+ 1. First, examine all memories that contain information related to the question.
+ 2. Synthesize findings from multiple memories if a single entry is insufficient.
+ 3. Examine timestamps and content carefully, looking for explicit dates, times, locations, or events.
+ 4. If the answer requires calculation (e.g., converting relative time references), perform the calculation.
+ 5. Formulate a precise, concise answer based on the evidence from the memories (and allowed world knowledge).
+ 6. Double-check that your answer directly addresses the question asked and adheres to all instructions.
+ 7. Ensure your final answer is specific and avoids vague time references.
+
+ {memories}
+
+ Question: {query}
+
+ Answer:
+"""
+
+ try:
+ response = self.openai_client.chat.completions.create(
+ model=self.openai_model,
+ messages=[
+ {
+ "role": "system",
+ "content": "You are a precise assistant who answers questions based only on provided information.",
+ },
+ {"role": "user", "content": prompt},
+ ],
+ temperature=0.1,
+ max_tokens=1000,
+ )
+
+ answer = response.choices[0].message.content.strip()
+
+ return {
+ "answer": answer,
+ "memory_type": memory_type,
+ "query": query,
+ "memory_length": len(memories),
+ "answer_length": len(answer),
+ }
+
+ except Exception as e:
+ logger.error(f"Error in answer generation: {e}")
+ return {
+ "answer": f"Error generating answer: {e!s}",
+ "memory_type": memory_type,
+ "query": query,
+ "memory_length": len(memories),
+ "answer_length": 0,
+ "error": str(e),
+ }
+
+ def compare_answer_quality(
+ self, query: str, golden_answer: str, original_answer: str, processed_answer: str
+ ) -> dict[str, Any]:
+ """
+ Compare the quality of answers generated from original vs processed memories.
+
+ Args:
+ query: The original query
+ golden_answer: The correct/expected answer
+ original_answer: Answer generated from original memories
+ processed_answer: Answer generated from processed memories
+
+ Returns:
+ Dictionary containing comparison results
+ """
+ prompt = f"""
+You are an expert evaluator comparing the quality of two answers against a golden standard.
+
+**Question:** {query}
+
+**Golden Answer (Correct):** {golden_answer}
+
+**Answer A (Original Memories):** {original_answer}
+
+**Answer B (Processed Memories):** {processed_answer}
+
+**Task:**
+Compare both answers against the golden answer and evaluate:
+1. Accuracy: How correct is each answer?
+2. Completeness: How complete is each answer?
+3. Relevance: How relevant is each answer to the question?
+4. Clarity: How clear and well-structured is each answer?
+
+**Response Format (JSON):**
+{{
+ "original_scores": {{
+ "accuracy": 0.0-1.0,
+ "completeness": 0.0-1.0,
+ "relevance": 0.0-1.0,
+ "clarity": 0.0-1.0,
+ "overall": 0.0-1.0
+ }},
+ "processed_scores": {{
+ "accuracy": 0.0-1.0,
+ "completeness": 0.0-1.0,
+ "relevance": 0.0-1.0,
+ "clarity": 0.0-1.0,
+ "overall": 0.0-1.0
+ }},
+ "winner": "original|processed|tie",
+ "improvement": 0.0-1.0,
+ "reasoning": "Detailed explanation of the comparison"
+}}
+"""
+
+ try:
+ response = self.openai_client.chat.completions.create(
+ model=self.openai_model,
+ messages=[
+ {
+ "role": "system",
+ "content": "You are an expert evaluator who compares answer quality objectively.",
+ },
+ {"role": "user", "content": prompt},
+ ],
+ temperature=0.1,
+ max_tokens=1500,
+ )
+
+ content = response.choices[0].message.content.strip()
+
+ # Try to parse JSON response
+ try:
+ if content.startswith("```json"):
+ content = content[7:]
+ if content.endswith("```"):
+ content = content[:-3]
+ content = content.strip()
+
+ comparison = json.loads(content)
+ return comparison
+
+ except json.JSONDecodeError:
+ logger.warning(f"Failed to parse comparison response as JSON: {content}")
+ return {
+ "original_scores": {
+ "accuracy": 0.5,
+ "completeness": 0.5,
+ "relevance": 0.5,
+ "clarity": 0.5,
+ "overall": 0.5,
+ },
+ "processed_scores": {
+ "accuracy": 0.5,
+ "completeness": 0.5,
+ "relevance": 0.5,
+ "clarity": 0.5,
+ "overall": 0.5,
+ },
+ "winner": "tie",
+ "improvement": 0.0,
+ "reasoning": f"Failed to parse comparison: {content}",
+ }
+
+ except Exception as e:
+ logger.error(f"Error in answer comparison: {e}")
+ return {
+ "original_scores": {
+ "accuracy": 0.0,
+ "completeness": 0.0,
+ "relevance": 0.0,
+ "clarity": 0.0,
+ "overall": 0.0,
+ },
+ "processed_scores": {
+ "accuracy": 0.0,
+ "completeness": 0.0,
+ "relevance": 0.0,
+ "clarity": 0.0,
+ "overall": 0.0,
+ },
+ "winner": "tie",
+ "improvement": 0.0,
+ "reasoning": f"Error occurred: {e!s}",
+ }
+
+ def analyze_memory_processing_effectiveness(
+ self,
+ bad_cases: list[dict[str, Any]],
+ processing_types: list[str] | None = None,
+ ) -> dict[str, Any]:
+ """
+ Analyze the effectiveness of different memory processing techniques.
+
+ Args:
+ bad_cases: List of bad cases to analyze
+ processing_types: List of processing types to test
+
+ Returns:
+ Dictionary containing comprehensive analysis results
+ """
+ if processing_types is None:
+ processing_types = ["summarize", "restructure", "enhance"]
+ results = {"processing_results": [], "statistics": {}, "processing_types": processing_types}
+
+ for i, case in enumerate(bad_cases):
+ logger.info(f"Processing case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...")
+
+ case_result = {
+ "case_id": i,
+ "query": case["query"],
+ "golden_answer": case["golden_answer"],
+ "original_memories": case["memories"],
+ "processing_results": {},
+ }
+
+ # Generate answer with original memories
+ original_answer_result = self.generate_answer_with_memories(
+ case["query"], case["memories"], "original"
+ )
+ case_result["original_answer"] = original_answer_result
+
+ # Test each processing type
+ for processing_type in processing_types:
+ logger.info(f" Testing {processing_type} processing...")
+
+ # Process memories
+ processing_result = self.process_memories_with_llm(
+ case["memories"], case["query"], processing_type
+ )
+
+ # Generate answer with processed memories
+ processed_answer_result = self.generate_answer_with_memories(
+ case["query"],
+ processing_result["processed_memories"],
+ f"processed_{processing_type}",
+ )
+
+ # Compare answer quality
+ comparison_result = self.compare_answer_quality(
+ case["query"],
+ case["golden_answer"],
+ original_answer_result["answer"],
+ processed_answer_result["answer"],
+ )
+
+ case_result["processing_results"][processing_type] = {
+ "processing": processing_result,
+ "answer": processed_answer_result,
+ "comparison": comparison_result,
+ }
+
+ results["processing_results"].append(case_result)
+
+ # Calculate statistics
+ self._calculate_processing_statistics(results)
+
+ return results
+
+ def _calculate_processing_statistics(self, results: dict[str, Any]) -> None:
+ """Calculate statistics for processing effectiveness analysis."""
+ processing_types = results["processing_types"]
+ processing_results = results["processing_results"]
+
+ if not processing_results:
+ results["statistics"] = {}
+ return
+
+ stats = {"total_cases": len(processing_results), "processing_type_stats": {}}
+
+ for processing_type in processing_types:
+ type_stats = {
+ "wins": 0,
+ "ties": 0,
+ "losses": 0,
+ "avg_improvement": 0.0,
+ "avg_compression_ratio": 0.0,
+ "avg_scores": {
+ "accuracy": 0.0,
+ "completeness": 0.0,
+ "relevance": 0.0,
+ "clarity": 0.0,
+ "overall": 0.0,
+ },
+ }
+
+ valid_cases = []
+ for case in processing_results:
+ if processing_type in case["processing_results"]:
+ result = case["processing_results"][processing_type]
+ comparison = result["comparison"]
+
+ # Count wins/ties/losses
+ if comparison["winner"] == "processed":
+ type_stats["wins"] += 1
+ elif comparison["winner"] == "tie":
+ type_stats["ties"] += 1
+ else:
+ type_stats["losses"] += 1
+
+ valid_cases.append(result)
+
+ if valid_cases:
+ # Calculate averages
+ type_stats["avg_improvement"] = sum(
+ case["comparison"]["improvement"] for case in valid_cases
+ ) / len(valid_cases)
+
+ type_stats["avg_compression_ratio"] = sum(
+ case["processing"]["compression_ratio"] for case in valid_cases
+ ) / len(valid_cases)
+
+ # Calculate average scores
+ for score_type in type_stats["avg_scores"]:
+ type_stats["avg_scores"][score_type] = sum(
+ case["comparison"]["processed_scores"][score_type] for case in valid_cases
+ ) / len(valid_cases)
+
+ # Calculate win rate
+ total_decisions = type_stats["wins"] + type_stats["ties"] + type_stats["losses"]
+ type_stats["win_rate"] = (
+ type_stats["wins"] / total_decisions if total_decisions > 0 else 0.0
+ )
+ type_stats["success_rate"] = (
+ (type_stats["wins"] + type_stats["ties"]) / total_decisions
+ if total_decisions > 0
+ else 0.0
+ )
+
+ stats["processing_type_stats"][processing_type] = type_stats
+
+ results["statistics"] = stats
+
+ def analyze_bad_cases(self, bad_cases: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ """
+ Analyze all bad cases to determine memory sufficiency.
+
+ Args:
+ bad_cases: List of bad cases to analyze
+
+ Returns:
+ List of analyzed bad cases with sufficiency information
+ """
+ analyzed_cases = []
+
+ for i, case in enumerate(bad_cases):
+ logger.info(f"Analyzing bad case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...")
+
+ analysis = self.analyze_memory_sufficiency(
+ case["query"], case["golden_answer"], case["memories"]
+ )
+
+ # Add analysis results to the case
+ analyzed_case = case.copy()
+ analyzed_case.update(
+ {
+ "memory_analysis": analysis,
+ "has_sufficient_memories": analysis["sufficient"],
+ "analysis_confidence": analysis["confidence"],
+ "relevant_memory_count": len(analysis["relevant_memories"]),
+ }
+ )
+
+ analyzed_cases.append(analyzed_case)
+
+ return analyzed_cases
+
+ def collect_bad_cases(self, eval_result_dir: str | None = None) -> dict[str, Any]:
+ """
+ Main method to collect and analyze bad cases from evaluation results.
+
+ Args:
+ eval_result_dir: Directory containing evaluation results
+
+ Returns:
+ Dictionary containing analysis results and statistics
+ """
+ if eval_result_dir is None:
+ eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-072005-fast"
+
+ judged_file = os.path.join(eval_result_dir, "memos-api_locomo_judged.json")
+ search_results_file = os.path.join(eval_result_dir, "memos-api_locomo_search_results.json")
+
+ # Extract bad cases
+ bad_cases = self.extract_bad_cases(judged_file, search_results_file)
+
+ if not bad_cases:
+ logger.warning("No bad cases found")
+ return {"bad_cases": [], "statistics": {}}
+
+ # Analyze bad cases
+ analyzed_cases = self.analyze_bad_cases(bad_cases)
+
+ # Calculate statistics
+ total_cases = len(analyzed_cases)
+ sufficient_cases = sum(
+ 1 for case in analyzed_cases if case.get("has_sufficient_memories", False)
+ )
+ insufficient_cases = total_cases - sufficient_cases
+
+ avg_confidence = (
+ sum(case["analysis_confidence"] for case in analyzed_cases) / total_cases
+ if total_cases > 0
+ else 0
+ )
+ avg_relevant_memories = (
+ sum(case["relevant_memory_count"] for case in analyzed_cases) / total_cases
+ if total_cases > 0
+ else 0
+ )
+
+ statistics = {
+ "total_bad_cases": total_cases,
+ "sufficient_memory_cases": sufficient_cases,
+ "insufficient_memory_cases": insufficient_cases,
+ "sufficiency_rate": sufficient_cases / total_cases if total_cases > 0 else 0,
+ "average_confidence": avg_confidence,
+ "average_relevant_memories": avg_relevant_memories,
+ }
+
+ # Save results
+ results = {
+ "bad_cases": analyzed_cases,
+ "statistics": statistics,
+ "metadata": {
+ "eval_result_dir": eval_result_dir,
+ "judged_file": judged_file,
+ "search_results_file": search_results_file,
+ "analysis_model": self.openai_model,
+ },
+ }
+
+ output_file = self.output_dir / "bad_cases_analysis.json"
+ with open(output_file, "w", encoding="utf-8") as f:
+ json.dump(results, f, indent=2, ensure_ascii=False)
+
+ logger.info(f"Analysis complete. Results saved to: {output_file}")
+ logger.info(f"Statistics: {statistics}")
+
+ return results
+
+ def _parse_json_response(self, response_text: str) -> dict:
+ """
+ Parse JSON response from LLM, handling various formats and potential errors.
+
+ Args:
+ response_text: Raw response text from LLM
+
+ Returns:
+ Parsed JSON dictionary
+
+ Raises:
+ ValueError: If JSON cannot be parsed
+ """
+ import re
+
+ # Try to extract JSON from response text
+ # Look for JSON blocks between ```json and ``` or just {} blocks
+ json_patterns = [r"```json\s*(\{.*?\})\s*```", r"```\s*(\{.*?\})\s*```", r"(\{.*\})"]
+
+ for pattern in json_patterns:
+ matches = re.findall(pattern, response_text, re.DOTALL)
+ if matches:
+ json_str = matches[0].strip()
+ try:
+ return json.loads(json_str)
+ except json.JSONDecodeError:
+ continue
+
+ # If no JSON pattern found, try parsing the entire response
+ try:
+ return json.loads(response_text.strip())
+ except json.JSONDecodeError as e:
+ logger.error(f"Failed to parse JSON response: {response_text[:200]}...")
+ raise ValueError(f"Invalid JSON response: {e!s}") from e
+
+ def filter_memories_with_llm(self, memories: list[str], query: str) -> tuple[list[str], bool]:
+ """
+ Use LLM to filter memories based on relevance to the query.
+
+ Args:
+ memories: List of memory strings
+ query: Query to filter memories against
+
+ Returns:
+ Tuple of (filtered_memories, success_flag)
+ """
+ if not memories:
+ return [], True
+
+ # Build prompt for memory filtering
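+        # Memories are numbered 1-based so indices returned by the model map directly back onto this list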
+ memories_text = "\n".join([f"{i + 1}. {memory}" for i, memory in enumerate(memories)])
+
+ prompt = f"""You are a memory filtering system. Given a query and a list of memories, identify which memories are relevant and non-redundant for answering the query.
+
+Query: {query}
+
+Memories:
+{memories_text}
+
+Please analyze each memory and return a JSON response with the following format:
+{{
+ "relevant_memory_indices": [list of indices (1-based) of memories that are relevant to the query],
+ "reasoning": "Brief explanation of your filtering decisions"
+}}
+
+Only include memories that are directly relevant to answering the query. Remove redundant or unrelated memories."""
+
+ try:
+ response = self.openai_client.chat.completions.create(
+ model=self.openai_model,
+ messages=[{"role": "user", "content": prompt}],
+ temperature=0.1,
+ )
+
+ response_text = response.choices[0].message.content
+
+ # Extract JSON from response
+ result = self._parse_json_response(response_text)
+
+ if "relevant_memory_indices" in result:
+ relevant_indices = result["relevant_memory_indices"]
+ filtered_memories = []
+
+ for idx in relevant_indices:
+ if 1 <= idx <= len(memories):
+ filtered_memories.append(memories[idx - 1])
+
+ logger.info(f"Filtered memories: {len(memories)} -> {len(filtered_memories)}")
+ return filtered_memories, True
+ else:
+ logger.warning("Invalid response format from memory filtering LLM")
+ return memories, False
+
+ except Exception as e:
+ logger.error(f"Error in memory filtering: {e}")
+ return memories, False
+
+ def evaluate_answer_ability_with_llm(self, query: str, memories: list[str]) -> bool:
+ """
+ Use LLM to evaluate whether the given memories can answer the query.
+
+ Args:
+ query: Query to evaluate
+ memories: List of memory strings
+
+ Returns:
+ Boolean indicating whether memories can answer the query
+ """
+ if not memories:
+ return False
+
+ memories_text = "\n".join([f"- {memory}" for memory in memories])
+
+ prompt = f"""You are an answer ability evaluator. Given a query and a list of memories, determine whether the memories contain sufficient information to answer the query.
+
+Query: {query}
+
+Available Memories:
+{memories_text}
+
+Please analyze the memories and return a JSON response with the following format:
+{{
+ "can_answer": true/false,
+ "confidence": 0.0-1.0,
+ "reasoning": "Brief explanation of your decision"
+}}
+
+Consider whether the memories contain the specific information needed to provide a complete and accurate answer to the query."""
+
+ try:
+ response = self.openai_client.chat.completions.create(
+ model=self.openai_model,
+ messages=[{"role": "user", "content": prompt}],
+ temperature=0.1,
+ )
+
+ response_text = response.choices[0].message.content
+ result = self._parse_json_response(response_text)
+
+ if "can_answer" in result:
+ can_answer = result["can_answer"]
+ confidence = result.get("confidence", 0.5)
+ reasoning = result.get("reasoning", "No reasoning provided")
+
+ logger.info(
+ f"Answer ability evaluation: {can_answer} (confidence: {confidence:.2f}) - {reasoning}"
+ )
+ return can_answer
+ else:
+ logger.warning("Invalid response format from answer ability evaluation")
+ return False
+
+ except Exception as e:
+ logger.error(f"Error in answer ability evaluation: {e}")
+ return False
+
+ def memory_llm_processing_analysis(
+ self, bad_cases: list[dict[str, Any]], use_llm_filtering: bool = True
+ ) -> list[dict[str, Any]]:
+ """
+ Analyze bad cases by processing memories with LLM filtering and testing answer ability.
+
+ This method:
+ 1. Parses memory strings from bad cases
+ 2. Uses LLM to filter unrelated and redundant memories
+ 3. Tests whether processed memories can help answer questions correctly
+ 4. Compares results before and after LLM processing
+
+ Args:
+ bad_cases: List of bad cases to analyze
+ use_llm_filtering: Whether to use LLM filtering
+
+ Returns:
+ List of analyzed bad cases with LLM processing results
+ """
+ analyzed_cases = []
+
+ for i, case in enumerate(bad_cases):
+ logger.info(f"Processing bad case {i + 1}/{len(bad_cases)}: {case['query'][:50]}...")
+
+ try:
+ # Parse memory string
+ memories_text = case.get("memories", "")
+ if not memories_text:
+ logger.warning(f"No memories found for case {i + 1}")
+ analyzed_case = case.copy()
+ analyzed_case.update(
+ {
+ "llm_processing_analysis": {
+ "error": "No memories available",
+ "original_memories_count": 0,
+ "processed_memories_count": 0,
+ "can_answer_with_original": False,
+ "can_answer_with_processed": False,
+ "processing_improved_answer": False,
+ }
+ }
+ )
+ analyzed_cases.append(analyzed_case)
+ continue
+
+ # Split memories by lines
+                original_memories = [
+                    line.strip() for line in memories_text.split("\n") if line.strip()
+                ]
+
+ logger.info(f"Parsed {len(original_memories)} memories from text")
+
+ # Test answer ability with original memories
+ can_answer_original = self.evaluate_answer_ability_with_llm(
+ query=case["query"], memories=original_memories
+ )
+
+ # Process memories with LLM filtering if enabled
+ processed_memories = original_memories
+ processing_success = False
+
+ if use_llm_filtering and len(original_memories) > 0:
+ processed_memories, processing_success = self.filter_memories_with_llm(
+ memories=original_memories, query=case["query"]
+ )
+ logger.info(
+ f"LLM filtering: {len(original_memories)} -> {len(processed_memories)} memories, success: {processing_success}"
+ )
+
+ # Test answer ability with processed memories
+ can_answer_processed = self.evaluate_answer_ability_with_llm(
+ query=case["query"], memories=processed_memories
+ )
+
+ # Determine if processing improved answer ability
+ processing_improved = can_answer_processed and not can_answer_original
+
+ # Create analysis result
+ llm_analysis = {
+ "processing_success": processing_success,
+ "original_memories_count": len(original_memories),
+ "processed_memories_count": len(processed_memories),
+ "memories_removed_count": len(original_memories) - len(processed_memories),
+ "can_answer_with_original": can_answer_original,
+ "can_answer_with_processed": can_answer_processed,
+ "processing_improved_answer": processing_improved,
+ "original_memories": original_memories,
+ "processed_memories": processed_memories,
+ }
+
+ # Add analysis to case
+ analyzed_case = case.copy()
+ analyzed_case["llm_processing_analysis"] = llm_analysis
+
+ logger.info(
+ f"Case {i + 1} analysis complete: "
+ f"Original: {can_answer_original}, "
+ f"Processed: {can_answer_processed}, "
+ f"Improved: {processing_improved}"
+ )
+
+ except Exception as e:
+ logger.error(f"Error processing case {i + 1}: {e}")
+ analyzed_case = case.copy()
+ analyzed_case["llm_processing_analysis"] = {
+ "error": str(e),
+ "processing_success": False,
+ "original_memories_count": 0,
+ "processed_memories_count": 0,
+ "can_answer_with_original": False,
+ "can_answer_with_processed": False,
+ "processing_improved_answer": False,
+ }
+
+ analyzed_cases.append(analyzed_case)
+
+ return analyzed_cases
+
+ def scheduler_mem_process(self, query, memories):
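+        """Enhance memories for a query via the scheduler retriever (sketch of intent).
+
+        Assumes a `mem_scheduler` instance with an initialized retriever is available
+        in this module's scope. Returns a dict with the enhanced memory texts plus
+        simple length/compression statistics.
+        """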
+ from memos.mem_scheduler.utils.misc_utils import extract_list_items_in_answer
+
+ _memories = []
+ for mem in memories:
+ mem_item = TextualMemoryItem(memory=mem, metadata=TextualMemoryMetadata())
+ _memories.append(mem_item)
+ prompt = mem_scheduler.retriever._build_enhancement_prompt(
+ query_history=[query], batch_texts=memories
+ )
+ logger.debug(
+ f"[Enhance][batch={0}] Prompt (first 200 chars, len={len(prompt)}): {prompt[:200]}..."
+ )
+
+ response = mem_scheduler.retriever.process_llm.generate(
+ [{"role": "user", "content": prompt}]
+ )
+ logger.debug(f"[Enhance][batch={0}] Response (first 200 chars): {response[:200]}...")
+
+ processed_results = extract_list_items_in_answer(response)
+
+ return {
+ "processed_memories": processed_results,
+ "processing_type": "enhance",
+ "original_length": len("\n".join(memories)),
+ "processed_length": len("\n".join(processed_results)),
+ "compression_ratio": len("\n".join(processed_results)) / len("\n".join(memories))
+ if len(memories) > 0
+ else 0,
+ }
+
+ def analyze_bad_cases_with_llm_processing(
+ self,
+ bad_cases: list[dict[str, Any]],
+ save_results: bool = True,
+ output_file: str | None = None,
+ ) -> dict[str, Any]:
+ """
+ Comprehensive analysis of bad cases with LLM memory processing.
+
+ This method performs a complete analysis including:
+ 1. Basic bad case analysis
+ 2. LLM memory processing analysis
+ 3. Statistical summary of improvements
+ 4. Detailed reporting
+
+ Args:
+ bad_cases: List of bad cases to analyze
+ save_results: Whether to save results to file
+ output_file: Optional output file path
+
+ Returns:
+ Dictionary containing comprehensive analysis results
+ """
+ from datetime import datetime
+
+ logger.info(
+ f"Starting comprehensive analysis of {len(bad_cases)} bad cases with LLM processing"
+ )
+
+ # Perform LLM memory processing analysis
+ analyzed_cases = self.memory_llm_processing_analysis(
+ bad_cases=bad_cases, use_llm_filtering=True
+ )
+
+ # Calculate statistics
+ total_cases = len(analyzed_cases)
+ successful_processing = 0
+ improved_cases = 0
+ original_answerable = 0
+ processed_answerable = 0
+ total_memories_before = 0
+ total_memories_after = 0
+
+ for case in analyzed_cases:
+ llm_analysis = case.get("llm_processing_analysis", {})
+
+ if llm_analysis.get("processing_success", False):
+ successful_processing += 1
+
+ if llm_analysis.get("processing_improved_answer", False):
+ improved_cases += 1
+
+ if llm_analysis.get("can_answer_with_original", False):
+ original_answerable += 1
+
+ if llm_analysis.get("can_answer_with_processed", False):
+ processed_answerable += 1
+
+ total_memories_before += llm_analysis.get("original_memories_count", 0)
+ total_memories_after += llm_analysis.get("processed_memories_count", 0)
+
+ # Calculate improvement metrics
+ processing_success_rate = successful_processing / total_cases if total_cases > 0 else 0
+ improvement_rate = improved_cases / total_cases if total_cases > 0 else 0
+ original_answer_rate = original_answerable / total_cases if total_cases > 0 else 0
+ processed_answer_rate = processed_answerable / total_cases if total_cases > 0 else 0
+ memory_reduction_rate = (
+ (total_memories_before - total_memories_after) / total_memories_before
+ if total_memories_before > 0
+ else 0
+ )
+
+ # Create comprehensive results
+ results = {
+ "analysis_metadata": {
+ "total_cases_analyzed": total_cases,
+ "analysis_timestamp": datetime.now().isoformat(),
+ "llm_model_used": self.openai_model,
+ },
+ "processing_statistics": {
+ "successful_processing_count": successful_processing,
+ "processing_success_rate": processing_success_rate,
+ "cases_with_improvement": improved_cases,
+ "improvement_rate": improvement_rate,
+ "original_answerable_cases": original_answerable,
+ "original_answer_rate": original_answer_rate,
+ "processed_answerable_cases": processed_answerable,
+ "processed_answer_rate": processed_answer_rate,
+ "answer_rate_improvement": processed_answer_rate - original_answer_rate,
+ },
+ "memory_statistics": {
+ "total_memories_before_processing": total_memories_before,
+ "total_memories_after_processing": total_memories_after,
+ "memories_removed": total_memories_before - total_memories_after,
+ "memory_reduction_rate": memory_reduction_rate,
+ "average_memories_per_case_before": total_memories_before / total_cases
+ if total_cases > 0
+ else 0,
+ "average_memories_per_case_after": total_memories_after / total_cases
+ if total_cases > 0
+ else 0,
+ },
+ "analyzed_cases": analyzed_cases,
+ }
+
+ # Log summary
+ logger.info("LLM Processing Analysis Summary:")
+ logger.info(f" - Total cases: {total_cases}")
+ logger.info(f" - Processing success rate: {processing_success_rate:.2%}")
+ logger.info(f" - Cases with improvement: {improved_cases} ({improvement_rate:.2%})")
+ logger.info(f" - Original answer rate: {original_answer_rate:.2%}")
+ logger.info(f" - Processed answer rate: {processed_answer_rate:.2%}")
+ logger.info(
+ f" - Answer rate improvement: {processed_answer_rate - original_answer_rate:.2%}"
+ )
+ logger.info(f" - Memory reduction: {memory_reduction_rate:.2%}")
+
+ # Save results if requested
+ if save_results:
+ if output_file is None:
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+ output_file = f"llm_processing_analysis_{timestamp}.json"
+
+ try:
+ with open(output_file, "w", encoding="utf-8") as f:
+ json.dump(results, f, indent=2, ensure_ascii=False)
+ logger.info(f"Analysis results saved to: {output_file}")
+ except Exception as e:
+ logger.error(f"Failed to save results to {output_file}: {e}")
+
+ return results
+
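+# Illustrative end-to-end sketch (assumes bad cases were already extracted, e.g. with
+# analyzer.extract_bad_cases(...)):
+#
+#     analyzer = EvalAnalyzer(output_dir="./tmp/eval_analyzer")
+#     results = analyzer.analyze_bad_cases_with_llm_processing(bad_cases=bad_cases)
+#     print(results["processing_statistics"]["answer_rate_improvement"])
+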
+
+def main(version_name="ct-1111"):
+ """Main test function."""
+ print("=== EvalAnalyzer Simple Test ===")
+
+ # Initialize analyzer
+ analyzer = EvalAnalyzer(output_dir="./tmp/eval_analyzer")
+
+ print("Analyzer initialized")
+
+ # Test file paths
+ eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-{version_name}-locomo"
+ judged_file = os.path.join(eval_result_dir, "memos-api_locomo_judged.json")
+ search_results_file = os.path.join(eval_result_dir, "memos-api_locomo_search_results.json")
+
+ print("Testing with files:")
+ print(f" Judged file: {judged_file}")
+ print(f" Search results file: {search_results_file}")
+
+ # Check if files exist
+ if not os.path.exists(judged_file):
+ print(f"❌ Judged file not found: {judged_file}")
+ return
+
+ if not os.path.exists(search_results_file):
+ print(f"❌ Search results file not found: {search_results_file}")
+ return
+
+ print("✅ Both files exist")
+
+ # Test bad case extraction only
+ try:
+ print("\n=== Testing Bad Case Extraction ===")
+ bad_cases = analyzer.extract_bad_cases(judged_file, search_results_file)
+
+ print(f"✅ Successfully extracted {len(bad_cases)} bad cases")
+
+ if bad_cases:
+ print("\n=== Sample Bad Cases ===")
+ for i, case in enumerate(bad_cases[:3]): # Show first 3 cases
+ print(f"\nBad Case {i + 1}:")
+ print(f" User ID: {case['user_id']}")
+ print(f" Query: {case['query'][:100]}...")
+ print(f" Golden Answer: {case['golden_answer']}...")
+ print(f" Answer: {case['answer']}...")
+ print(f" Has Memories: {len(case['memories']) > 0}")
+ print(f" Memory Length: {len(case['memories'])} chars")
+
+ # Save basic results without LLM analysis
+ basic_results = {
+ "bad_cases_count": len(bad_cases),
+ "bad_cases": bad_cases,
+ "metadata": {
+ "eval_result_dir": eval_result_dir,
+ "judged_file": judged_file,
+ "search_results_file": search_results_file,
+ "extraction_only": True,
+ },
+ }
+
+ output_file = analyzer.output_dir / "bad_cases_extraction_only.json"
+ import json
+
+ with open(output_file, "w", encoding="utf-8") as f:
+ json.dump(basic_results, f, indent=2, ensure_ascii=False)
+
+ print(f"\n✅ Basic extraction results saved to: {output_file}")
+
+ except Exception as e:
+ print(f"❌ Error during extraction: {e}")
+ import traceback
+
+ traceback.print_exc()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/memos/mem_scheduler/analyzer/memory_processing.py b/src/memos/mem_scheduler/analyzer/memory_processing.py
new file mode 100644
index 000000000..b692341c2
--- /dev/null
+++ b/src/memos/mem_scheduler/analyzer/memory_processing.py
@@ -0,0 +1,246 @@
+#!/usr/bin/env python3
+"""
+Test script for memory processing functionality in eval_analyzer.py
+
+This script demonstrates how to use the new LLM memory processing features
+to analyze and improve memory-based question answering.
+"""
+
+import json
+import os
+import sys
+
+from pathlib import Path
+from typing import Any
+
+from memos.log import get_logger
+from memos.mem_scheduler.analyzer.eval_analyzer import EvalAnalyzer
+
+
+FILE_PATH = Path(__file__).absolute()
+BASE_DIR = FILE_PATH.parent  # Directory containing this script
+sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory
+
+
+logger = get_logger(__name__)
+
+
+def create_sample_bad_cases() -> list[dict[str, Any]]:
+ """Create sample bad cases for testing memory processing."""
+ return [
+ {
+ "query": "What is the capital of France?",
+ "golden_answer": "Paris",
+ "memories": """
+ Memory 1: France is a country in Western Europe.
+ Memory 2: The Eiffel Tower is located in Paris.
+ Memory 3: Paris is known for its art museums and fashion.
+ Memory 4: French cuisine is famous worldwide.
+ Memory 5: The Seine River flows through Paris.
+ """,
+ },
+ {
+ "query": "When was the iPhone first released?",
+ "golden_answer": "June 29, 2007",
+ "memories": """
+ Memory 1: Apple Inc. was founded by Steve Jobs, Steve Wozniak, and Ronald Wayne.
+ Memory 2: The iPhone was announced by Steve Jobs at the Macworld Conference & Expo on January 9, 2007.
+ Memory 3: The iPhone went on sale on June 29, 2007.
+ Memory 4: The original iPhone had a 3.5-inch screen.
+ Memory 5: Apple's stock price increased significantly after the iPhone launch.
+ """,
+ },
+ {
+ "query": "What is photosynthesis?",
+ "golden_answer": "Photosynthesis is the process by which plants use sunlight, water, and carbon dioxide to produce glucose and oxygen.",
+ "memories": """
+ Memory 1: Plants are living organisms that need sunlight to grow.
+ Memory 2: Chlorophyll is the green pigment in plants.
+ Memory 3: Plants take in carbon dioxide from the air.
+ Memory 4: Water is absorbed by plant roots from the soil.
+ Memory 5: Oxygen is released by plants during the day.
+ Memory 6: Glucose is a type of sugar that plants produce.
+ """,
+ },
+ ]
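+
+# Each sample case above uses the minimal schema memory_processing() consumes:
+# a "query", a "golden_answer", and a newline-separated "memories" string.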
+
+
+def memory_processing(bad_cases):
+ """
+ Test the memory processing functionality with cover rate and acc rate analysis.
+
+ This function analyzes:
+ 1. Cover rate: Whether memories contain all information needed to answer the query
+ 2. Acc rate: Whether processed memories can correctly answer the query
+ """
+ print("🧪 Testing Memory Processing Functionality with Cover Rate & Acc Rate Analysis")
+ print("=" * 80)
+
+ # Initialize analyzer
+ analyzer = EvalAnalyzer()
+
+ print(f"📊 Testing with {len(bad_cases)} sample cases")
+ print()
+
+ # Initialize counters for real-time statistics
+ total_cases = 0
+ cover_count = 0 # Cases where memories cover all needed information
+ acc_count = 0 # Cases where processed memories can correctly answer
+
+ # Process each case
+ for i, case in enumerate(bad_cases):
+ total_cases += 1
+
+ # Safely handle query display
+ query_display = str(case.get("query", "Unknown query"))
+        print(f"🔍 Case {i + 1}/{len(bad_cases)}: {query_display[:100]}...")
+
+ # Safely handle golden_answer display (convert to string if needed)
+ golden_answer = case.get("golden_answer", "Unknown answer")
+ golden_answer_str = str(golden_answer) if golden_answer is not None else "Unknown answer"
+ print(f"📝 Golden Answer: {golden_answer_str}")
+ print()
+
+ # Step 1: Analyze if memories contain sufficient information (Cover Rate)
+ print(" 📋 Step 1: Analyzing memory coverage...")
+ coverage_analysis = analyzer.analyze_memory_sufficiency(
+ case["query"],
+ golden_answer_str, # Use the string version
+ case["memories"],
+ )
+
+ has_coverage = coverage_analysis.get("sufficient", False)
+ if has_coverage:
+ cover_count += 1
+
+ print(f" ✅ Memory Coverage: {'SUFFICIENT' if has_coverage else 'INSUFFICIENT'}")
+ print(f" 🎯 Confidence: {coverage_analysis.get('confidence', 0):.2f}")
+ print(f" 💭 Reasoning: {coverage_analysis.get('reasoning', 'N/A')}...")
+ if not has_coverage:
+ print(
+ f" ❌ Missing Info: {coverage_analysis.get('missing_information', 'N/A')[:100]}..."
+ )
+ continue
+ print()
+
+ # Step 2: Process memories and test answer ability (Acc Rate)
+ print(" 🔄 Step 2: Processing memories and testing answer ability...")
+
+ processing_result = analyzer.scheduler_mem_process(
+ query=case["query"],
+ memories=case["memories"],
+ )
+ print(f"Original Memories: {case['memories']}")
+ print(f"Processed Memories: {processing_result['processed_memories']}")
+ print(f" 📏 Compression ratio: {processing_result['compression_ratio']:.2f}")
+ print(f" 📄 Processed memories length: {processing_result['processed_length']} chars")
+
+ # Generate answer with processed memories
+ answer_result = analyzer.generate_answer_with_memories(
+ case["query"], processing_result["processed_memories"], "processed_enhanced"
+ )
+
+ # Evaluate if the generated answer is correct
+ print(" 🎯 Step 3: Evaluating answer correctness...")
+ answer_evaluation = analyzer.compare_answer_quality(
+ case["query"],
+ golden_answer_str, # Use the string version
+ "No original answer available", # We don't have original answer
+ answer_result["answer"],
+ )
+
+ # Determine if processed memories can correctly answer (simplified logic)
+ processed_accuracy = answer_evaluation.get("processed_scores", {}).get("accuracy", 0)
+ can_answer_correctly = processed_accuracy >= 0.7 # Threshold for "correct" answer
+
+ if can_answer_correctly:
+ acc_count += 1
+
+ print(f" 💬 Generated Answer: {answer_result['answer']}...")
+ print(
+ f" ✅ Answer Accuracy: {'CORRECT' if can_answer_correctly else 'INCORRECT'} (score: {processed_accuracy:.2f})"
+ )
+ print()
+
+ # Calculate and print real-time rates
+ current_cover_rate = cover_count / total_cases
+ current_acc_rate = acc_count / total_cases
+
+ print(" 📊 REAL-TIME STATISTICS:")
+ print(f" 🎯 Cover Rate: {current_cover_rate:.2%} ({cover_count}/{total_cases})")
+ print(f" ✅ Acc Rate: {current_acc_rate:.2%} ({acc_count}/{total_cases})")
+ print()
+
+ print("-" * 80)
+ print()
+
+ # Final summary
+ print("🏁 FINAL ANALYSIS SUMMARY")
+ print("=" * 80)
+ print(f"📊 Total Cases Processed: {total_cases}")
+ print(f"🎯 Final Cover Rate: {cover_count / total_cases:.2%} ({cover_count}/{total_cases})")
+ print(f" - Cases with sufficient memory coverage: {cover_count}")
+ print(f" - Cases with insufficient memory coverage: {total_cases - cover_count}")
+ print()
+ print(f"✅ Final Acc Rate: {acc_count / total_cases:.2%} ({acc_count}/{total_cases})")
+ print(f" - Cases where processed memories can answer correctly: {acc_count}")
+ print(f" - Cases where processed memories cannot answer correctly: {total_cases - acc_count}")
+ print()
+
+ # Additional insights
+ if cover_count > 0:
+        effective_processing_rate = acc_count / cover_count
+ print(f"🔄 Processing Effectiveness: {effective_processing_rate:.2%}")
+ print(
+ f" - Among cases with sufficient coverage, {effective_processing_rate:.1%} can be answered correctly after processing"
+ )
+
+ print("=" * 80)
+
+
+def load_real_bad_cases(file_path: str) -> list[dict[str, Any]]:
+ """Load real bad cases from JSON file."""
+ print(f"📂 Loading bad cases from: {file_path}")
+
+ with open(file_path, encoding="utf-8") as f:
+ data = json.load(f)
+
+ bad_cases = data.get("bad_cases", [])
+ print(f"✅ Loaded {len(bad_cases)} bad cases")
+
+ return bad_cases
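+
+# Expected file shape (illustrative; written by eval_analyzer's extraction step):
+#     {
+#       "bad_cases_count": 2,
+#       "bad_cases": [{"query": "...", "golden_answer": "...", "memories": "...", ...}],
+#       "metadata": {...}
+#     }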
+
+
+def main():
+ """Main test function."""
+ print("🚀 Memory Processing Test Suite")
+ print("=" * 60)
+ print()
+
+ # Check if OpenAI API key is set
+ if not os.getenv("OPENAI_API_KEY"):
+ print("⚠️ Warning: OPENAI_API_KEY not found in environment variables")
+ print(" Please set your OpenAI API key to run the tests")
+ return
+
+ try:
+ bad_cases_file = f"{BASE_DIR}/tmp/eval_analyzer/bad_cases_extraction_only.json"
+ bad_cases = load_real_bad_cases(bad_cases_file)
+
+        print(f"✅ Prepared {len(bad_cases)} bad cases for analysis")
+ print()
+
+ # Run memory processing tests
+ memory_processing(bad_cases)
+
+ print("✅ All tests completed successfully!")
+
+ except Exception as e:
+ print(f"❌ Test failed with error: {e}")
+ import traceback
+
+ traceback.print_exc()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py
index ace67eff6..df504ee75 100644
--- a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py
+++ b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py
@@ -427,7 +427,6 @@ def chat(self, query: str, user_id: str | None = None) -> str:
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=QUERY_LABEL,
content=query,
timestamp=datetime.now(),
@@ -518,12 +517,11 @@ def chat(self, query: str, user_id: str | None = None) -> str:
message_item = ScheduleMessageItem(
user_id=target_user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
label=ANSWER_LABEL,
content=response,
timestamp=datetime.now(),
)
- self.mem_scheduler.submit_messages(messages=[message_item])
+ self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item])
return response
diff --git a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py
index 7c0fa5a4a..3d0235871 100644
--- a/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py
+++ b/src/memos/mem_scheduler/analyzer/scheduler_for_eval.py
@@ -226,9 +226,9 @@ def evaluate_memory_answer_ability(
try:
# Extract JSON response
- from memos.mem_scheduler.utils.misc_utils import extract_json_dict
+ from memos.mem_scheduler.utils.misc_utils import extract_json_obj
- result = extract_json_dict(response)
+ result = extract_json_obj(response)
# Validate response structure
if "result" in result:
diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py
index 028fe8e3f..6ad7f5cdd 100644
--- a/src/memos/mem_scheduler/base_scheduler.py
+++ b/src/memos/mem_scheduler/base_scheduler.py
@@ -1,6 +1,5 @@
-import contextlib
import multiprocessing
-import queue
+import os
import threading
import time
@@ -15,8 +14,8 @@
from memos.context.context import ContextThread
from memos.llms.base import BaseLLM
from memos.log import get_logger
+from memos.mem_cube.base import BaseMemCube
from memos.mem_cube.general import GeneralMemCube
-from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher
from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule
from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever
@@ -24,9 +23,11 @@
from memos.mem_scheduler.monitors.general_monitor import SchedulerGeneralMonitor
from memos.mem_scheduler.schemas.general_schemas import (
DEFAULT_ACT_MEM_DUMP_PATH,
+ DEFAULT_CONSUME_BATCH,
DEFAULT_CONSUME_INTERVAL_SECONDS,
DEFAULT_CONTEXT_WINDOW_SIZE,
DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE,
+ DEFAULT_MAX_WEB_LOG_QUEUE_SIZE,
DEFAULT_STARTUP_MODE,
DEFAULT_THREAD_POOL_MAX_WORKERS,
DEFAULT_TOP_K,
@@ -41,6 +42,9 @@
ScheduleMessageItem,
)
from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
+from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher
+from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
+from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue
from memos.mem_scheduler.utils.db_utils import get_utc_now
from memos.mem_scheduler.utils.filter_utils import (
transform_name_to_key,
@@ -50,11 +54,12 @@
from memos.memories.activation.kv import KVCacheMemory
from memos.memories.activation.vllmkv import VLLMKVCacheItem, VLLMKVCacheMemory
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
+from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
from memos.templates.mem_scheduler_prompts import MEMORY_ASSEMBLY_TEMPLATE
if TYPE_CHECKING:
- from memos.mem_cube.base import BaseMemCube
+ from memos.reranker.http_bge import HTTPBGEReranker
logger = get_logger(__name__)
@@ -86,37 +91,12 @@ def __init__(self, config: BaseSchedulerConfig):
"scheduler_startup_mode", DEFAULT_STARTUP_MODE
)
- self.retriever: SchedulerRetriever | None = None
- self.db_engine: Engine | None = None
- self.monitor: SchedulerGeneralMonitor | None = None
- self.dispatcher_monitor: SchedulerDispatcherMonitor | None = None
- self.mem_reader = None # Will be set by MOSCore
- self.dispatcher = SchedulerDispatcher(
- config=self.config,
- max_workers=self.thread_pool_max_workers,
- enable_parallel_dispatch=self.enable_parallel_dispatch,
- )
-
# optional configs
- self.disable_handlers: list | None = self.config.get("disable_handlers", None)
+ self.disabled_handlers: list | None = self.config.get("disabled_handlers", None)
- # message queue configuration
- self.use_redis_queue = self.config.get("use_redis_queue", DEFAULT_USE_REDIS_QUEUE)
- self.max_internal_message_queue_size = self.config.get(
- "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE
+ self.max_web_log_queue_size = self.config.get(
+ "max_web_log_queue_size", DEFAULT_MAX_WEB_LOG_QUEUE_SIZE
)
-
- # Initialize message queue based on configuration
- if self.use_redis_queue:
- self.memos_message_queue = None # Will use Redis instead
- # Initialize Redis if using Redis queue with auto-initialization
- self.auto_initialize_redis()
- else:
- self.memos_message_queue: Queue[ScheduleMessageItem] = Queue(
- maxsize=self.max_internal_message_queue_size
- )
-
- self.max_web_log_queue_size = self.config.get("max_web_log_queue_size", 50)
self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue(
maxsize=self.max_web_log_queue_size
)
@@ -126,6 +106,31 @@ def __init__(self, config: BaseSchedulerConfig):
self._consume_interval = self.config.get(
"consume_interval_seconds", DEFAULT_CONSUME_INTERVAL_SECONDS
)
+ self.consume_batch = self.config.get("consume_batch", DEFAULT_CONSUME_BATCH)
+
+ # message queue configuration
+ self.use_redis_queue = self.config.get("use_redis_queue", DEFAULT_USE_REDIS_QUEUE)
+ self.max_internal_message_queue_size = self.config.get(
+ "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE
+ )
+ self.memos_message_queue = ScheduleTaskQueue(
+ use_redis_queue=self.use_redis_queue,
+ maxsize=self.max_internal_message_queue_size,
+ disabled_handlers=self.disabled_handlers,
+ )
+ self.searcher: Searcher | None = None
+ self.retriever: SchedulerRetriever | None = None
+ self.db_engine: Engine | None = None
+ self.monitor: SchedulerGeneralMonitor | None = None
+ self.dispatcher_monitor: SchedulerDispatcherMonitor | None = None
+ self.mem_reader = None # Will be set by MOSCore
+ self.dispatcher = SchedulerDispatcher(
+ config=self.config,
+ memos_message_queue=self.memos_message_queue,
+ use_redis_queue=self.use_redis_queue,
+ max_workers=self.thread_pool_max_workers,
+ enable_parallel_dispatch=self.enable_parallel_dispatch,
+ )
# other attributes
self._context_lock = threading.Lock()
@@ -136,6 +141,22 @@ def __init__(self, config: BaseSchedulerConfig):
self.auth_config = None
self.rabbitmq_config = None
+ def init_mem_cube(
+ self,
+ mem_cube: BaseMemCube,
+ searcher: Searcher | None = None,
+ ):
+ self.mem_cube = mem_cube
+ self.text_mem: TreeTextMemory = self.mem_cube.text_mem
+ self.reranker: HTTPBGEReranker = self.text_mem.reranker
+ if searcher is None:
+ self.searcher: Searcher = self.text_mem.get_searcher(
+ manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false",
+ moscube=False,
+ )
+ else:
+ self.searcher = searcher
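+
+    # Usage sketch (illustrative; `my_cube` is a placeholder): bind a cube before the
+    # consumer starts so handlers can resolve text memory, reranker and searcher, e.g.
+    #     scheduler.init_mem_cube(mem_cube=my_cube)  # builds a Searcher unless one is given
+    #     scheduler.start()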
+
def initialize_modules(
self,
chat_llm: BaseLLM,
@@ -192,6 +213,9 @@ def initialize_modules(
# start queue monitor if enabled and a bot is set later
+ def debug_mode_on(self):
+ self.memos_message_queue.debug_mode_on()
+
def _cleanup_on_init_failure(self):
"""Clean up resources if initialization fails."""
try:
@@ -201,23 +225,16 @@ def _cleanup_on_init_failure(self):
logger.warning(f"Error during cleanup: {e}")
@property
- def mem_cube(self) -> GeneralMemCube:
+ def mem_cube(self) -> BaseMemCube:
"""The memory cube associated with this MemChat."""
return self.current_mem_cube
@mem_cube.setter
- def mem_cube(self, value: GeneralMemCube) -> None:
+ def mem_cube(self, value: BaseMemCube) -> None:
"""The memory cube associated with this MemChat."""
self.current_mem_cube = value
self.retriever.mem_cube = value
- def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None:
- """Update current user/cube context from the incoming message (thread-safe)."""
- with self._context_lock:
- self.current_user_id = msg.user_id
- self.current_mem_cube_id = msg.mem_cube_id
- self.current_mem_cube = msg.mem_cube
-
def transform_working_memories_to_monitors(
self, query_keywords, memories: list[TextualMemoryItem]
) -> list[MemoryMonitorItem]:
@@ -516,37 +533,7 @@ def update_activation_memory_periodically(
logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True)
def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]):
- """Submit messages to the message queue (either local queue or Redis)."""
- if isinstance(messages, ScheduleMessageItem):
- messages = [messages] # transform single message to list
-
- for message in messages:
- if not isinstance(message, ScheduleMessageItem):
- error_msg = f"Invalid message type: {type(message)}, expected ScheduleMessageItem"
- logger.error(error_msg)
- raise TypeError(error_msg)
-
- if getattr(message, "timestamp", None) is None:
- with contextlib.suppress(Exception):
- message.timestamp = datetime.utcnow()
-
- if self.disable_handlers and message.label in self.disable_handlers:
- logger.info(f"Skipping disabled handler: {message.label} - {message.content}")
- continue
-
- if self.use_redis_queue:
- # Use Redis stream for message queue
- self.redis_add_message_stream(message.to_dict())
- logger.info(f"Submitted message to Redis: {message.label} - {message.content}")
- else:
- # Use local queue
- self.memos_message_queue.put(message)
- logger.info(
- f"Submitted message to local queue: {message.label} - {message.content}"
- )
- with contextlib.suppress(Exception):
- if messages:
- self.dispatcher.on_messages_enqueued(messages)
+ self.memos_message_queue.submit_messages(messages=messages)
def _submit_web_logs(
self, messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem]
@@ -590,7 +577,7 @@ def get_web_log_messages(self) -> list[dict]:
try:
item = self._web_log_message_queue.get_nowait() # Thread-safe get
messages.append(item.to_dict())
- except queue.Empty:
+ except Exception:
break
return messages
@@ -601,62 +588,34 @@ def _message_consumer(self) -> None:
Runs in a dedicated thread to process messages at regular intervals.
For Redis queue, this method starts the Redis listener.
"""
- if self.use_redis_queue:
- # For Redis queue, start the Redis listener
- def redis_message_handler(message_data):
- """Handler for Redis messages"""
- try:
- # Redis message data needs to be decoded from bytes to string
- decoded_data = {}
- for key, value in message_data.items():
- if isinstance(key, bytes):
- key = key.decode("utf-8")
- if isinstance(value, bytes):
- value = value.decode("utf-8")
- decoded_data[key] = value
-
- message = ScheduleMessageItem.from_dict(decoded_data)
- self.dispatcher.dispatch([message])
- except Exception as e:
- logger.error(f"Error processing Redis message: {e}")
- logger.error(f"Message data: {message_data}")
-
- self.redis_start_listening(handler=redis_message_handler)
-
- # Keep the thread alive while Redis listener is running
- while self._running:
- time.sleep(self._consume_interval)
- else:
- # Original local queue logic
- while self._running: # Use a running flag for graceful shutdown
- try:
- # Get all available messages at once (thread-safe approach)
- messages = []
- while True:
- try:
- # Use get_nowait() directly without empty() check to avoid race conditions
- message = self.memos_message_queue.get_nowait()
- messages.append(message)
- except queue.Empty:
- # No more messages available
- break
- if messages:
- try:
- self.dispatcher.dispatch(messages)
- except Exception as e:
- logger.error(f"Error dispatching messages: {e!s}")
- finally:
- # Mark all messages as processed
- for _ in messages:
- self.memos_message_queue.task_done()
+        # Consume messages from the task queue (local or Redis) in batches
+ while self._running: # Use a running flag for graceful shutdown
+ try:
+ # Get messages in batches based on consume_batch setting
+
+ messages = self.memos_message_queue.get_messages(batch_size=self.consume_batch)
- # Sleep briefly to prevent busy waiting
- time.sleep(self._consume_interval) # Adjust interval as needed
+ if messages:
+ try:
+ import contextlib
- except Exception as e:
+ with contextlib.suppress(Exception):
+                            self.dispatcher.on_messages_enqueued(messages)
+
+ self.dispatcher.dispatch(messages)
+ except Exception as e:
+ logger.error(f"Error dispatching messages: {e!s}")
+
+ # Sleep briefly to prevent busy waiting
+ time.sleep(self._consume_interval) # Adjust interval as needed
+
+ except Exception as e:
+ # Don't log error for "No messages available in Redis queue" as it's expected
+ if "No messages available in Redis queue" not in str(e):
logger.error(f"Unexpected error in message consumer: {e!s}")
- time.sleep(self._consume_interval) # Prevent tight error loops
+ time.sleep(self._consume_interval) # Prevent tight error loops
def start(self) -> None:
"""
@@ -666,16 +625,25 @@ def start(self) -> None:
1. Message consumer thread or process (based on startup_mode)
2. Dispatcher thread pool (if parallel dispatch enabled)
"""
- if self._running:
- logger.warning("Memory Scheduler is already running")
- return
-
# Initialize dispatcher resources
if self.enable_parallel_dispatch:
logger.info(
f"Initializing dispatcher thread pool with {self.thread_pool_max_workers} workers"
)
+ self.start_consumer()
+
+ def start_consumer(self) -> None:
+ """
+ Start only the message consumer thread/process.
+
+ This method can be used to restart the consumer after it has been stopped
+ with stop_consumer(), without affecting other scheduler components.
+ """
+ if self._running:
+ logger.warning("Memory Scheduler consumer is already running")
+ return
+
# Start consumer based on startup mode
self._running = True
@@ -698,15 +666,15 @@ def start(self) -> None:
self._consumer_thread.start()
logger.info("Message consumer thread started")
- def stop(self) -> None:
- """Stop all scheduler components gracefully.
+ def stop_consumer(self) -> None:
+ """Stop only the message consumer thread/process gracefully.
- 1. Stops message consumer thread/process
- 2. Shuts down dispatcher thread pool
- 3. Cleans up resources
+ This method stops the consumer without affecting other components like
+ dispatcher or monitors. Useful when you want to pause message processing
+ while keeping other scheduler components running.
"""
if not self._running:
- logger.warning("Memory Scheduler is not running")
+ logger.warning("Memory Scheduler consumer is not running")
return
# Signal consumer thread/process to stop
@@ -726,12 +694,30 @@ def stop(self) -> None:
logger.info("Consumer process terminated")
else:
logger.info("Consumer process stopped")
+ self._consumer_process = None
elif self._consumer_thread and self._consumer_thread.is_alive():
self._consumer_thread.join(timeout=5.0)
if self._consumer_thread.is_alive():
logger.warning("Consumer thread did not stop gracefully")
else:
logger.info("Consumer thread stopped")
+ self._consumer_thread = None
+
+ logger.info("Memory Scheduler consumer stopped")
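+
+    # Pause/resume sketch (illustrative):
+    #     scheduler.stop_consumer()   # pause message consumption only
+    #     ...                         # dispatcher and monitors keep running
+    #     scheduler.start_consumer()  # resume consuming from the same queue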
+
+ def stop(self) -> None:
+ """Stop all scheduler components gracefully.
+
+ 1. Stops message consumer thread/process
+ 2. Shuts down dispatcher thread pool
+ 3. Cleans up resources
+ """
+ if not self._running:
+ logger.warning("Memory Scheduler is not running")
+ return
+
+ # Stop consumer first
+ self.stop_consumer()
# Shutdown dispatcher
if self.dispatcher:
@@ -743,10 +729,6 @@ def stop(self) -> None:
logger.info("Shutting down monitor...")
self.dispatcher_monitor.stop()
- # Clean up queues
- self._cleanup_queues()
- logger.info("Memory Scheduler stopped completely")
-
@property
def handlers(self) -> dict[str, Callable]:
"""
@@ -819,30 +801,6 @@ def get_running_tasks(self, filter_func: Callable | None = None) -> dict[str, di
return result
- def _cleanup_queues(self) -> None:
- """Ensure all queues are emptied and marked as closed."""
- if self.use_redis_queue:
- # For Redis queue, stop the listener and close connection
- try:
- self.redis_stop_listening()
- self.redis_close()
- except Exception as e:
- logger.error(f"Error cleaning up Redis connection: {e}")
- else:
- # Original local queue cleanup
- try:
- while not self.memos_message_queue.empty():
- self.memos_message_queue.get_nowait()
- self.memos_message_queue.task_done()
- except queue.Empty:
- pass
-
- try:
- while not self._web_log_message_queue.empty():
- self._web_log_message_queue.get_nowait()
- except queue.Empty:
- pass
-
def mem_scheduler_wait(
self, timeout: float = 180.0, poll: float = 0.1, log_every: float = 0.01
) -> bool:
@@ -906,11 +864,24 @@ def _fmt_eta(seconds: float | None) -> str:
st = (
stats_fn()
) # expected: {'pending':int,'running':int,'done':int?,'rate':float?}
- pend = int(st.get("pending", 0))
run = int(st.get("running", 0))
+
except Exception:
pass
+ if isinstance(self.memos_message_queue, SchedulerRedisQueue):
+ # For Redis queue, prefer XINFO GROUPS to compute pending
+ groups_info = self.memos_message_queue.redis.xinfo_groups(
+ self.memos_message_queue.stream_key_prefix
+ )
+ if groups_info:
+ for group in groups_info:
+ if group.get("name") == self.memos_message_queue.consumer_group:
+ pend = int(group.get("pending", pend))
+ break
+ else:
+ pend = run
+
# 2) dynamic total (allows new tasks queued while waiting)
total_now = max(init_unfinished, done_total + curr_unfinished)
done_total = max(0, total_now - curr_unfinished)
diff --git a/src/memos/mem_scheduler/general_modules/base.py b/src/memos/mem_scheduler/general_modules/base.py
index 392f2bde3..e0ee65ba0 100644
--- a/src/memos/mem_scheduler/general_modules/base.py
+++ b/src/memos/mem_scheduler/general_modules/base.py
@@ -18,8 +18,6 @@ def __init__(self):
self._chat_llm = None
self._process_llm = None
- self.mem_cubes: dict[str, GeneralMemCube] = {}
-
def load_template(self, template_name: str) -> str:
if template_name not in PROMPT_MAPPING:
logger.error("Prompt template is not found!")
@@ -51,7 +49,7 @@ def _build_system_prompt(self, memories: list | None = None) -> str:
def get_mem_cube(self, mem_cube_id: str) -> GeneralMemCube:
logger.error(f"mem_cube {mem_cube_id} does not exists.")
- return self.mem_cubes.get(mem_cube_id, None)
+ return self.current_mem_cube
@property
def chat_llm(self) -> BaseLLM:
diff --git a/src/memos/mem_scheduler/general_modules/misc.py b/src/memos/mem_scheduler/general_modules/misc.py
index b6f48d043..e4e7edb89 100644
--- a/src/memos/mem_scheduler/general_modules/misc.py
+++ b/src/memos/mem_scheduler/general_modules/misc.py
@@ -199,6 +199,9 @@ class AutoDroppingQueue(Queue[T]):
"""A thread-safe queue that automatically drops the oldest item when full."""
def __init__(self, maxsize: int = 0):
+ # If maxsize <= 0, set to 0 (unlimited queue size)
+ if maxsize <= 0:
+ maxsize = 0
super().__init__(maxsize=maxsize)
def put(self, item: T, block: bool = False, timeout: float | None = None) -> None:
@@ -218,7 +221,7 @@ def put(self, item: T, block: bool = False, timeout: float | None = None) -> Non
# First try non-blocking put
super().put(item, block=block, timeout=timeout)
except Full:
- # Remove oldest item and mark it done to avoid leaking unfinished_tasks
+ # Remove the oldest item and mark it done to avoid leaking unfinished_tasks
with suppress(Empty):
_ = self.get_nowait()
# If the removed item had previously incremented unfinished_tasks,
@@ -228,12 +231,70 @@ def put(self, item: T, block: bool = False, timeout: float | None = None) -> Non
# Retry putting the new item
super().put(item, block=block, timeout=timeout)
+ def get(
+ self, block: bool = True, timeout: float | None = None, batch_size: int | None = None
+ ) -> list[T] | T:
+ """Get items from the queue.
+
+ Args:
+ block: Whether to block if no items are available (default: True)
+ timeout: Timeout in seconds for blocking operations (default: None)
+            batch_size: Number of items to retrieve; None (default) returns a single item
+
+        Returns:
+            A single item when batch_size is None, otherwise a list of up to
+            batch_size items (possibly fewer if the queue runs out)
+
+        Raises:
+            Empty: If batch_size is None and no item is available, or if blocking
+                and no items arrive before the timeout expires
+ """
+
+ if batch_size is None:
+ return super().get(block=block, timeout=timeout)
+ items = []
+ for _ in range(batch_size):
+ try:
+ items.append(super().get(block=block, timeout=timeout))
+ except Empty:
+ if not items and block:
+ # If we haven't gotten any items and we're blocking, re-raise Empty
+ raise
+ break
+ return items
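+
+    # Example (illustrative): with block=False, q.get(block=False, batch_size=3) drains
+    # up to three items and returns them as a list (an empty list if nothing is queued),
+    # while q.get(block=False) keeps the classic single-item Queue.get semantics.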
+
+    def get_nowait(self, batch_size: int | None = None) -> list[T] | T:
+        """Get items from the queue without blocking.
+
+        Args:
+            batch_size: Number of items to retrieve; None (default) returns a single item
+
+        Returns:
+            A single item when batch_size is None, otherwise a list of up to
+            batch_size items (possibly fewer if the queue runs out)
+ """
+ if batch_size is None:
+ return super().get_nowait()
+
+ items = []
+ for _ in range(batch_size):
+ try:
+ items.append(super().get_nowait())
+ except Empty:
+ break
+ return items
+
def get_queue_content_without_pop(self) -> list[T]:
"""Return a copy of the queue's contents without modifying it."""
# Ensure a consistent snapshot by holding the mutex
with self.mutex:
return list(self.queue)
+ def qsize(self) -> int:
+ """Return the approximate size of the queue.
+
+ Returns:
+ Number of items currently in the queue
+ """
+ return super().qsize()
+
def clear(self) -> None:
"""Remove all items from the queue.
diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py
index 1f89d3b02..d35a4f106 100644
--- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py
+++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py
@@ -1,7 +1,7 @@
from collections.abc import Callable
from memos.log import get_logger
-from memos.mem_cube.general import GeneralMemCube
+from memos.mem_cube.base import BaseMemCube
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.schemas.general_schemas import (
ACTIVATION_MEMORY_TYPE,
@@ -44,7 +44,7 @@ def create_autofilled_log_item(
to_memory_type: str,
user_id: str,
mem_cube_id: str,
- mem_cube: GeneralMemCube,
+ mem_cube: BaseMemCube,
) -> ScheduleLogForWebItem:
text_mem_base: TreeTextMemory = mem_cube.text_mem
current_memory_sizes = text_mem_base.get_current_memory_size()
@@ -106,7 +106,7 @@ def log_working_memory_replacement(
new_memory: list[TextualMemoryItem],
user_id: str,
mem_cube_id: str,
- mem_cube: GeneralMemCube,
+ mem_cube: BaseMemCube,
log_func_callback: Callable[[list[ScheduleLogForWebItem]], None],
):
"""Log changes when working memory is replaced."""
@@ -163,7 +163,7 @@ def log_activation_memory_update(
label: str,
user_id: str,
mem_cube_id: str,
- mem_cube: GeneralMemCube,
+ mem_cube: BaseMemCube,
log_func_callback: Callable[[list[ScheduleLogForWebItem]], None],
):
"""Log changes when activation memory is updated."""
@@ -214,7 +214,7 @@ def log_adding_memory(
memory_type: str,
user_id: str,
mem_cube_id: str,
- mem_cube: GeneralMemCube,
+ mem_cube: BaseMemCube,
log_func_callback: Callable[[list[ScheduleLogForWebItem]], None],
):
"""Log changes when working memory is replaced."""
diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py
index 6840adc2b..92e317881 100644
--- a/src/memos/mem_scheduler/general_scheduler.py
+++ b/src/memos/mem_scheduler/general_scheduler.py
@@ -5,6 +5,7 @@
from memos.configs.mem_scheduler import GeneralSchedulerConfig
from memos.context.context import ContextThreadPoolExecutor
from memos.log import get_logger
+from memos.mem_cube.base import BaseMemCube
from memos.mem_cube.general import GeneralMemCube
from memos.mem_scheduler.base_scheduler import BaseScheduler
from memos.mem_scheduler.schemas.general_schemas import (
@@ -22,6 +23,7 @@
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem
from memos.mem_scheduler.utils.filter_utils import is_all_chinese, is_all_english
+from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
from memos.memories.textual.item import TextualMemoryItem
from memos.memories.textual.preference import PreferenceTextMemory
from memos.memories.textual.tree import TreeTextMemory
@@ -51,10 +53,7 @@ def __init__(self, config: GeneralSchedulerConfig):
def long_memory_update_process(
self, user_id: str, mem_cube_id: str, messages: list[ScheduleMessageItem]
):
- mem_cube = messages[0].mem_cube
-
- # for status update
- self._set_current_context_from_message(msg=messages[0])
+ mem_cube = self.current_mem_cube
# update query monitors
for msg in messages:
@@ -140,7 +139,7 @@ def long_memory_update_process(
label=QUERY_LABEL,
user_id=user_id,
mem_cube_id=mem_cube_id,
- mem_cube=messages[0].mem_cube,
+ mem_cube=self.mem_cube,
)
def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
@@ -153,7 +152,7 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
logger.info(f"Messages {messages} assigned to {QUERY_LABEL} handler.")
# Process the query in a session turn
- grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages)
+ grouped_messages = group_messages_by_user_and_mem_cube(messages)
self.validate_schedule_messages(messages=messages, label=QUERY_LABEL)
@@ -175,7 +174,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
"""
logger.info(f"Messages {messages} assigned to {ANSWER_LABEL} handler.")
# Process the query in a session turn
- grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages)
+ grouped_messages = group_messages_by_user_and_mem_cube(messages)
self.validate_schedule_messages(messages=messages, label=ANSWER_LABEL)
@@ -185,13 +184,11 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
if len(messages) == 0:
return
- # for status update
- self._set_current_context_from_message(msg=messages[0])
-
def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
logger.info(f"Messages {messages} assigned to {ADD_LABEL} handler.")
# Process the query in a session turn
- grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages)
+ grouped_messages = group_messages_by_user_and_mem_cube(messages)
+ mem_cube = self.mem_cube
self.validate_schedule_messages(messages=messages, label=ADD_LABEL)
try:
@@ -201,9 +198,6 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
if len(messages) == 0:
return
- # for status update
- self._set_current_context_from_message(msg=messages[0])
-
# submit logs
for msg in messages:
try:
@@ -212,7 +206,6 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
logger.error(f"Error: {e}. Content: {msg.content}", exc_info=True)
userinput_memory_ids = []
- mem_cube = msg.mem_cube
for memory_id in userinput_memory_ids:
try:
mem_item: TextualMemoryItem = mem_cube.text_mem.get(
@@ -234,7 +227,7 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None:
memory_type=mem_type,
user_id=msg.user_id,
mem_cube_id=msg.mem_cube_id,
- mem_cube=msg.mem_cube,
+ mem_cube=self.mem_cube,
log_func_callback=self._submit_web_logs,
)
@@ -248,7 +241,7 @@ def process_message(message: ScheduleMessageItem):
try:
user_id = message.user_id
mem_cube_id = message.mem_cube_id
- mem_cube = message.mem_cube
+ mem_cube = self.mem_cube
content = message.content
user_name = message.user_name
@@ -272,7 +265,6 @@ def process_message(message: ScheduleMessageItem):
mem_ids=mem_ids,
user_id=user_id,
mem_cube_id=mem_cube_id,
- mem_cube=mem_cube,
text_mem=text_mem,
user_name=user_name,
)
@@ -297,7 +289,6 @@ def _process_memories_with_reader(
mem_ids: list[str],
user_id: str,
mem_cube_id: str,
- mem_cube: GeneralMemCube,
text_mem: TreeTextMemory,
user_name: str,
) -> None:
@@ -308,7 +299,6 @@ def _process_memories_with_reader(
mem_ids: List of memory IDs to process
user_id: User ID
mem_cube_id: Memory cube ID
- mem_cube: Memory cube instance
text_mem: Text memory instance
"""
try:
@@ -412,7 +402,7 @@ def process_message(message: ScheduleMessageItem):
try:
user_id = message.user_id
mem_cube_id = message.mem_cube_id
- mem_cube = message.mem_cube
+ mem_cube = self.mem_cube
content = message.content
user_name = message.user_name
@@ -461,7 +451,7 @@ def _process_memories_with_reorganize(
mem_ids: list[str],
user_id: str,
mem_cube_id: str,
- mem_cube: GeneralMemCube,
+ mem_cube: BaseMemCube,
text_mem: TreeTextMemory,
user_name: str,
) -> None:
@@ -513,10 +503,11 @@ def _pref_add_message_consumer(self, messages: list[ScheduleMessageItem]) -> Non
def process_message(message: ScheduleMessageItem):
try:
+ mem_cube = self.mem_cube
+
user_id = message.user_id
session_id = message.session_id
mem_cube_id = message.mem_cube_id
- mem_cube = message.mem_cube
content = message.content
messages_list = json.loads(content)
@@ -530,7 +521,9 @@ def process_message(message: ScheduleMessageItem):
# Use pref_mem.get_memory to process the memories
pref_memories = pref_mem.get_memory(
- messages_list, type="chat", info={"user_id": user_id, "session_id": session_id}
+ messages_list,
+ type="chat",
+ info={"user_id": user_id, "session_id": session_id, "mem_cube_id": mem_cube_id},
)
# Add pref_mem to vector db
pref_ids = pref_mem.add(pref_memories)
diff --git a/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py b/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py
index e18c6e51a..25b9a98f3 100644
--- a/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py
+++ b/src/memos/mem_scheduler/memory_manage_modules/memory_filter.py
@@ -2,7 +2,7 @@
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
-from memos.mem_scheduler.utils.misc_utils import extract_json_dict
+from memos.mem_scheduler.utils.misc_utils import extract_json_obj
from memos.memories.textual.tree import TextualMemoryItem
@@ -66,7 +66,7 @@ def filter_unrelated_memories(
try:
# Parse JSON response
- response = extract_json_dict(response)
+ response = extract_json_obj(response)
logger.debug(f"Parsed JSON response: {response}")
relevant_indices = response["relevant_memories"]
filtered_count = response["filtered_count"]
@@ -164,7 +164,7 @@ def filter_redundant_memories(
try:
# Parse JSON response
- response = extract_json_dict(response)
+ response = extract_json_obj(response)
logger.debug(f"Parsed JSON response: {response}")
kept_indices = response["kept_memories"]
redundant_groups = response.get("redundant_groups", [])
@@ -226,8 +226,6 @@ def filter_unrelated_and_redundant_memories(
Note:
If LLM filtering fails, returns all memories (conservative approach)
"""
- success_flag = False
-
if not memories:
logger.info("No memories to filter for unrelated and redundant - returning empty list")
return [], True
@@ -265,7 +263,7 @@ def filter_unrelated_and_redundant_memories(
try:
# Parse JSON response
- response = extract_json_dict(response)
+ response = extract_json_obj(response)
logger.debug(f"Parsed JSON response: {response}")
kept_indices = response["kept_memories"]
unrelated_removed_count = response.get("unrelated_removed_count", 0)
diff --git a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py
index b766f0010..01b57563d 100644
--- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py
+++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py
@@ -1,9 +1,18 @@
+import time
+
+from concurrent.futures import as_completed
+
from memos.configs.mem_scheduler import BaseSchedulerConfig
+from memos.context.context import ContextThreadPoolExecutor
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.mem_cube.general import GeneralMemCube
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.schemas.general_schemas import (
+ DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE,
+ DEFAULT_SCHEDULER_RETRIEVER_RETRIES,
+ FINE_STRATEGY,
+ FineStrategy,
TreeTextMemory_FINE_SEARCH_METHOD,
TreeTextMemory_SEARCH_METHOD,
)
@@ -12,11 +21,11 @@
filter_vector_based_similar_memories,
transform_name_to_key,
)
-from memos.mem_scheduler.utils.misc_utils import (
- extract_json_dict,
-)
+from memos.mem_scheduler.utils.misc_utils import extract_json_obj, extract_list_items_in_answer
+from memos.memories.textual.item import TextualMemoryMetadata
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
from .memory_filter import MemoryFilter
@@ -30,12 +39,199 @@ def __init__(self, process_llm: BaseLLM, config: BaseSchedulerConfig):
# hyper-parameters
self.filter_similarity_threshold = 0.75
self.filter_min_length_threshold = 6
-
- self.config: BaseSchedulerConfig = config
+ self.memory_filter = MemoryFilter(process_llm=process_llm, config=config)
self.process_llm = process_llm
+ self.config = config
- # Initialize memory filter
- self.memory_filter = MemoryFilter(process_llm=process_llm, config=config)
+ # Configure enhancement batching & retries from config with safe defaults
+ self.batch_size: int | None = getattr(
+ config, "scheduler_retriever_batch_size", DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE
+ )
+ self.retries: int = getattr(
+ config, "scheduler_retriever_enhance_retries", DEFAULT_SCHEDULER_RETRIEVER_RETRIES
+ )
+
+ def evaluate_memory_answer_ability(
+ self, query: str, memory_texts: list[str], top_k: int | None = None
+ ) -> bool:
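+        """Ask the process LLM whether the given memory texts suffice to answer the query.
+
+        Expects the LLM to return JSON with a boolean "result" field (see the
+        "memory_answer_ability_evaluation" template); on any parse failure this
+        conservatively returns False.
+        """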
+ limited_memories = memory_texts[:top_k] if top_k is not None else memory_texts
+ # Build prompt using the template
+ prompt = self.build_prompt(
+ template_name="memory_answer_ability_evaluation",
+ query=query,
+ memory_list="\n".join([f"- {memory}" for memory in limited_memories])
+ if limited_memories
+ else "No memories available",
+ )
+
+ # Use the process LLM to generate response
+ response = self.process_llm.generate([{"role": "user", "content": prompt}])
+
+ try:
+ result = extract_json_obj(response)
+
+ # Validate response structure
+ if "result" in result:
+ logger.info(
+ f"Answerability: result={result['result']}; reason={result.get('reason', 'n/a')}; evaluated={len(limited_memories)}"
+ )
+ return result["result"]
+ else:
+ logger.warning(f"Answerability: invalid LLM JSON structure; payload={result}")
+ return False
+
+ except Exception as e:
+ logger.error(f"Answerability: parse failed; err={e}; raw={str(response)[:200]}...")
+ # Fallback: return False if we can't determine answer ability
+ return False
+
+ # ---------------------- Enhancement helpers ----------------------
+ def _build_enhancement_prompt(self, query_history: list[str], batch_texts: list[str]) -> str:
+        if len(query_history) == 1:
+            query_history = query_history[0]
+        else:
+            query_history = [f"[{i}] {query}" for i, query in enumerate(query_history)]
+ # Include numbering for rewrite mode to help LLM reference original memory IDs
+ if FINE_STRATEGY == FineStrategy.REWRITE:
+ text_memories = "\n".join([f"- [{i}] {mem}" for i, mem in enumerate(batch_texts)])
+ prompt_name = "memory_rewrite_enhancement"
+ else:
+            text_memories = "\n".join([f"- {mem}" for mem in batch_texts])
+ prompt_name = "memory_recreate_enhancement"
+ return self.build_prompt(
+ prompt_name,
+ query_history=query_history,
+ memories=text_memories,
+ )
+
+ def _process_enhancement_batch(
+ self,
+ batch_index: int,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ retries: int,
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ attempt = 0
+ text_memories = [one.memory for one in memories]
+
+ prompt = self._build_enhancement_prompt(
+ query_history=query_history, batch_texts=text_memories
+ )
+
+ llm_response = None
+ while attempt <= max(0, retries) + 1:
+ try:
+ llm_response = self.process_llm.generate([{"role": "user", "content": prompt}])
+ processed_text_memories = extract_list_items_in_answer(llm_response)
+ if len(processed_text_memories) > 0:
+ # create new
+ enhanced_memories = []
+ user_id = memories[0].metadata.user_id
+ if FINE_STRATEGY == FineStrategy.RECREATE:
+ for new_mem in processed_text_memories:
+ enhanced_memories.append(
+ TextualMemoryItem(
+ memory=new_mem, metadata=TextualMemoryMetadata(user_id=user_id)
+ )
+ )
+ elif FINE_STRATEGY == FineStrategy.REWRITE:
+ # Parse index from each processed line and rewrite corresponding original memory
+ def _parse_index_and_text(s: str) -> tuple[int | None, str]:
+ import re
+
+ s = (s or "").strip()
+ # Preferred: [index] text
+ m = re.match(r"^\s*\[(\d+)\]\s*(.+)$", s)
+ if m:
+ return int(m.group(1)), m.group(2).strip()
+ # Fallback: index: text or index - text
+ m = re.match(r"^\s*(\d+)\s*[:\-\)]\s*(.+)$", s)
+ if m:
+ return int(m.group(1)), m.group(2).strip()
+ return None, s
+
+ idx_to_original = dict(enumerate(memories))
+ for j, item in enumerate(processed_text_memories):
+ idx, new_text = _parse_index_and_text(item)
+ if idx is not None and idx in idx_to_original:
+ orig = idx_to_original[idx]
+ else:
+ # Fallback: align by order if index missing/invalid
+ orig = memories[j] if j < len(memories) else None
+ if not orig:
+ continue
+ enhanced_memories.append(
+ TextualMemoryItem(
+ id=orig.id,
+ memory=new_text,
+ metadata=orig.metadata,
+ )
+ )
+ else:
+ logger.error(f"Fine search strategy {FINE_STRATEGY} not exists")
+
+ logger.info(
+ f"[enhance_memories_with_query] ✅ done | Strategy={FINE_STRATEGY} | prompt={prompt} | llm_response={llm_response}"
+ )
+ return enhanced_memories, True
+ else:
+ raise ValueError(
+ f"Fail to run memory enhancement; retry {attempt}/{max(1, retries) + 1}; processed_text_memories: {processed_text_memories}"
+ )
+ except Exception as e:
+ attempt += 1
+ time.sleep(1)
+ logger.debug(
+ f"[enhance_memories_with_query][batch={batch_index}] 🔁 retry {attempt}/{max(1, retries) + 1} failed: {e}"
+ )
+ logger.error(
+ f"Fail to run memory enhancement; prompt: {prompt};\n llm_response: {llm_response}",
+ exc_info=True,
+ )
+ return memories, False
+
+ @staticmethod
+ def _split_batches(
+ memories: list[TextualMemoryItem], batch_size: int
+ ) -> list[tuple[int, int, list[TextualMemoryItem]]]:
+ batches: list[tuple[int, int, list[TextualMemoryItem]]] = []
+ start = 0
+ n = len(memories)
+ while start < n:
+ end = min(start + batch_size, n)
+ batches.append((start, end, memories[start:end]))
+ start = end
+ return batches
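+ # Example (illustrative): batch_size=20 with 45 memories yields index ranges
+ # (0, 20), (20, 40), (40, 45); order within and across batches is preserved.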
+
+ def recall_for_missing_memories(
+ self,
+ query: str,
+ memories: list[TextualMemoryItem],
+ ) -> tuple[str, bool]:
+ text_memories = [one.memory for one in memories] if memories else []
+ memories_inline = "\n".join([f"- {mem}" for mem in text_memories])
+
+ prompt = self.build_prompt(
+ template_name="enlarge_recall",
+ query=query,
+ memories_inline=memories_inline,
+ )
+ llm_response = self.process_llm.generate([{"role": "user", "content": prompt}])
+
+ json_result: dict = extract_json_obj(llm_response)
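+ # Illustrative expected reply shape for the "enlarge_recall" prompt (an assumption):
+ # {"hint": "look for entries about the user's trip plans", "trigger_recall": true}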
+
+ logger.info(
+ f"[recall_for_missing_memories] ✅ done | prompt={prompt} | llm_response={llm_response}"
+ )
+
+ hint = json_result.get("hint", "")
+ if len(hint) == 0:
+ return hint, False
+ return hint, json_result.get("trigger_recall", False)
def search(
self,
@@ -81,6 +277,79 @@ def search(
results = []
return results
+ def enhance_memories_with_query(
+ self,
+ query_history: list[str],
+ memories: list[TextualMemoryItem],
+ ) -> tuple[list[TextualMemoryItem], bool]:
+ """
+ Enhance memories by adding context and making connections to better answer queries.
+
+ Args:
+ query_history: List of user queries in chronological order
+ memories: List of memory items to enhance
+
+ Returns:
+ Tuple of (enhanced_memories, success_flag)
+ """
+ if not memories:
+ logger.warning("[Enhance] ⚠️ skipped (no memories to process)")
+ return memories, True
+
+ batch_size = self.batch_size
+ retries = self.retries
+ num_of_memories = len(memories)
+ try:
+ # no parallel
+ if batch_size is None or num_of_memories <= batch_size:
+ # Single batch path with retry
+ enhanced_memories, success_flag = self._process_enhancement_batch(
+ batch_index=0,
+ query_history=query_history,
+ memories=memories,
+ retries=retries,
+ )
+
+ all_success = success_flag
+ else:
+ # parallel running batches
+ # Split into batches preserving order
+ batches = self._split_batches(memories=memories, batch_size=batch_size)
+
+ # Process batches concurrently
+ all_success = True
+ failed_batches = 0
+ with ContextThreadPoolExecutor(max_workers=len(batches)) as executor:
+ future_map = {
+ executor.submit(
+ self._process_enhancement_batch, bi, query_history, texts, retries
+ ): (bi, s, e)
+ for bi, (s, e, texts) in enumerate(batches)
+ }
+ enhanced_memories = []
+ for fut in as_completed(future_map):
+ bi, s, e = future_map[fut]
+
+ batch_memories, ok = fut.result()
+ enhanced_memories.extend(batch_memories)
+ if not ok:
+ all_success = False
+ failed_batches += 1
+ logger.info(
+ f"[Enhance] ✅ multi-batch done | batches={len(batches)} | enhanced={len(enhanced_memories)} |"
+ f" failed_batches={failed_batches} | success={all_success}"
+ )
+
+ except Exception as e:
+ logger.error(f"[Enhance] ❌ fatal error: {e}", exc_info=True)
+ all_success = False
+ enhanced_memories = memories
+
+ if len(enhanced_memories) == 0:
+ logger.error("[Enhance] ❌ enhanced_memories is empty after enhancement")
+ return enhanced_memories, all_success
+
def rerank_memories(
self, queries: list[str], original_memories: list[str], top_k: int
) -> (list[str], bool):
@@ -115,7 +384,7 @@ def rerank_memories(
try:
# Parse JSON response
- response = extract_json_dict(response)
+ response = extract_json_obj(response)
new_order = response["new_order"][:top_k]
text_memories_with_new_order = [original_memories[idx] for idx in new_order]
logger.info(
diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
index 46c4e2d49..03221aa7b 100644
--- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
+++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py
@@ -7,12 +7,13 @@
from memos.context.context import ContextThread, ContextThreadPoolExecutor
from memos.log import get_logger
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
-from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher
from memos.mem_scheduler.schemas.general_schemas import (
DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL,
DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES,
+ DEFAULT_STOP_WAIT,
DEFAULT_STUCK_THREAD_TOLERANCE,
)
+from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher
from memos.mem_scheduler.utils.db_utils import get_utc_now
@@ -46,6 +47,11 @@ def __init__(self, config: BaseSchedulerConfig):
self.dispatcher: SchedulerDispatcher | None = None
self.dispatcher_pool_name = "dispatcher"
+ # Configure shutdown wait behavior from config or default
+ self.stop_wait = (
+ self.config.get("stop_wait", DEFAULT_STOP_WAIT) if self.config else DEFAULT_STOP_WAIT
+ )
+
def initialize(self, dispatcher: SchedulerDispatcher):
self.dispatcher = dispatcher
self.register_pool(
@@ -129,7 +135,9 @@ def _check_pools_health(self) -> None:
pool_info=pool_info,
stuck_max_interval=4,
)
- logger.info(f"Pool '{name}'. is_healthy: {is_healthy}. pool_info: {pool_info}")
+ if not is_healthy:
+ logger.info(f"Pool '{name}'. is_healthy: {is_healthy}. pool_info: {pool_info}")
+
with self._pool_lock:
if is_healthy:
pool_info["failure_count"] = 0
@@ -231,20 +239,7 @@ def _check_pool_health(
# Log health status with comprehensive information
if self.dispatcher:
- # Check thread activity
- active_threads = sum(
- 1
- for t in threading.enumerate()
- if t.name.startswith(executor._thread_name_prefix) # pylint: disable=protected-access
- )
-
- task_count = self.dispatcher.get_running_task_count()
max_workers = pool_info.get("max_workers", 0)
- stuck_count = len(stuck_tasks)
- logger.info(
- f"Pool health check passed - {active_threads} active threads, "
- f"{task_count} running tasks, pool size: {max_workers}, stuck tasks: {stuck_count}"
- )
return True, ""
@@ -367,12 +362,9 @@ def stop(self) -> None:
if not executor._shutdown: # pylint: disable=protected-access
try:
logger.info(f"Shutting down thread pool '{name}'")
- executor.shutdown(wait=True, cancel_futures=True)
+ executor.shutdown(wait=self.stop_wait, cancel_futures=True)
logger.info(f"Successfully shut down thread pool '{name}'")
except Exception as e:
logger.error(f"Error shutting down pool '{name}': {e!s}", exc_info=True)
- # Clear the pool registry
- self._pools.clear()
-
logger.info("Thread pool monitor and all pools stopped")
diff --git a/src/memos/mem_scheduler/monitors/general_monitor.py b/src/memos/mem_scheduler/monitors/general_monitor.py
index a789d581e..a5f1c0097 100644
--- a/src/memos/mem_scheduler/monitors/general_monitor.py
+++ b/src/memos/mem_scheduler/monitors/general_monitor.py
@@ -29,7 +29,7 @@
QueryMonitorQueue,
)
from memos.mem_scheduler.utils.db_utils import get_utc_now
-from memos.mem_scheduler.utils.misc_utils import extract_json_dict
+from memos.mem_scheduler.utils.misc_utils import extract_json_obj
from memos.memories.textual.tree import TreeTextMemory
@@ -92,7 +92,7 @@ def extract_query_keywords(self, query: str) -> list:
llm_response = self._process_llm.generate([{"role": "user", "content": prompt}])
try:
# Parse JSON output from LLM response
- keywords = extract_json_dict(llm_response)
+ keywords = extract_json_obj(llm_response)
assert isinstance(keywords, list)
except Exception as e:
logger.error(
@@ -206,7 +206,7 @@ def update_working_memory_monitors(
self.working_mem_monitor_capacity = min(
DEFAULT_WORKING_MEM_MONITOR_SIZE_LIMIT,
(
- text_mem_base.memory_manager.memory_size["WorkingMemory"]
+ int(text_mem_base.memory_manager.memory_size["WorkingMemory"])
+ self.partial_retention_number
),
)
@@ -353,7 +353,7 @@ def detect_intent(
)
response = self._process_llm.generate([{"role": "user", "content": prompt}])
try:
- response = extract_json_dict(response)
+ response = extract_json_obj(response)
assert ("trigger_retrieval" in response) and ("missing_evidences" in response)
except Exception:
logger.error(f"Fail to extract json dict from response: {response}")
diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py
index a087ab2df..21b2d63f0 100644
--- a/src/memos/mem_scheduler/optimized_scheduler.py
+++ b/src/memos/mem_scheduler/optimized_scheduler.py
@@ -2,7 +2,7 @@
import os
from collections import OrderedDict
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
from memos.api.product_models import APISearchRequest
from memos.configs.mem_scheduler import GeneralSchedulerConfig
@@ -20,15 +20,13 @@
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.utils.api_utils import format_textual_memory_item
from memos.mem_scheduler.utils.db_utils import get_utc_now
+from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory
from memos.types import UserContext
if TYPE_CHECKING:
from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem
- from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
- from memos.reranker.http_bge import HTTPBGEReranker
-
logger = get_logger(__name__)
@@ -52,11 +50,15 @@ def __init__(self, config: GeneralSchedulerConfig):
API_MIX_SEARCH_LABEL: self._api_mix_search_message_consumer,
}
)
+ self.searcher = None
+ self.reranker = None
+ self.text_mem = None
def submit_memory_history_async_task(
self,
search_req: APISearchRequest,
user_context: UserContext,
+ memories_to_store: dict | None = None,
session_id: str | None = None,
):
# Create message for async fine search
@@ -71,25 +73,22 @@ def submit_memory_history_async_task(
"chat_history": search_req.chat_history,
},
"user_context": {"mem_cube_id": user_context.mem_cube_id},
+ "memories_to_store": memories_to_store,
}
async_task_id = f"mix_search_{search_req.user_id}_{get_utc_now().timestamp()}"
- # Get mem_cube for the message
- mem_cube = self.current_mem_cube
-
message = ScheduleMessageItem(
item_id=async_task_id,
user_id=search_req.user_id,
mem_cube_id=user_context.mem_cube_id,
label=API_MIX_SEARCH_LABEL,
- mem_cube=mem_cube,
content=json.dumps(message_content),
timestamp=get_utc_now(),
)
# Submit async task
- self.submit_messages([message])
+ self.memos_message_queue.submit_messages([message])
logger.info(f"Submitted async fine search task for user {search_req.user_id}")
return async_task_id
@@ -127,33 +126,29 @@ def mix_search_memories(
self,
search_req: APISearchRequest,
user_context: UserContext,
- ):
+ ) -> list[dict[str, Any]]:
"""
Mix search memories: fast search + async fine search
"""
+ logger.info(
+ f"Mix searching memories for user {search_req.user_id} with query: {search_req.query}"
+ )
# Get mem_cube for fast search
- mem_cube = self.current_mem_cube
-
target_session_id = search_req.session_id
if not target_session_id:
target_session_id = "default_session"
search_filter = {"session_id": search_req.session_id} if search_req.session_id else None
- text_mem: TreeTextMemory = mem_cube.text_mem
- searcher: Searcher = text_mem.get_searcher(
- manual_close_internet=not search_req.internet_search,
- moscube=False,
- )
# Rerank Memories - reranker expects TextualMemoryItem objects
- reranker: HTTPBGEReranker = text_mem.reranker
+
info = {
"user_id": search_req.user_id,
"session_id": target_session_id,
"chat_history": search_req.chat_history,
}
- fast_retrieved_memories = searcher.retrieve(
+ fast_retrieved_memories = self.searcher.retrieve(
query=search_req.query,
user_name=user_context.mem_cube_id,
top_k=search_req.top_k,
@@ -164,48 +159,104 @@ def mix_search_memories(
info=info,
)
- self.submit_memory_history_async_task(
- search_req=search_req,
- user_context=user_context,
- session_id=search_req.session_id,
- )
-
- # Try to get pre-computed fine memories if available
+ # Try to get pre-computed memories if available
history_memories = self.api_module.get_history_memories(
user_id=search_req.user_id,
mem_cube_id=user_context.mem_cube_id,
turns=self.history_memory_turns,
)
-
+ logger.info(f"Found {len(history_memories)} history memories.")
if not history_memories:
- fast_memories = searcher.post_retrieve(
+ memories = self.searcher.post_retrieve(
retrieved_results=fast_retrieved_memories,
top_k=search_req.top_k,
user_name=user_context.mem_cube_id,
info=info,
)
- # Format fast memories for return
- formatted_memories = [format_textual_memory_item(data) for data in fast_memories]
- return formatted_memories
-
- sorted_history_memories = reranker.rerank(
- query=search_req.query, # Use search_req.query instead of undefined query
- graph_results=history_memories, # Pass TextualMemoryItem objects directly
- top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k
- search_filter=search_filter,
- )
+ else:
+ # if history memories can directly answer
+ sorted_history_memories = self.reranker.rerank(
+ query=search_req.query, # Use search_req.query instead of undefined query
+ graph_results=history_memories, # Pass TextualMemoryItem objects directly
+ top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k
+ search_filter=search_filter,
+ )
+ logger.info(f"Reranked {len(sorted_history_memories)} history memories.")
+ processed_hist_mem = self.searcher.post_retrieve(
+ retrieved_results=sorted_history_memories,
+ top_k=search_req.top_k,
+ user_name=user_context.mem_cube_id,
+ info=info,
+ )
- sorted_results = fast_retrieved_memories + sorted_history_memories
- final_results = searcher.post_retrieve(
- retrieved_results=sorted_results,
- top_k=search_req.top_k,
- user_name=user_context.mem_cube_id,
- info=info,
- )
+ can_answer = self.retriever.evaluate_memory_answer_ability(
+ query=search_req.query, memory_texts=[one.memory for one in processed_hist_mem]
+ )
- formatted_memories = [
- format_textual_memory_item(item) for item in final_results[: search_req.top_k]
- ]
+ if can_answer:
+ logger.info("History memories can answer the query.")
+ sorted_results = fast_retrieved_memories + sorted_history_memories
+ combined_results = self.searcher.post_retrieve(
+ retrieved_results=sorted_results,
+ top_k=search_req.top_k,
+ user_name=user_context.mem_cube_id,
+ info=info,
+ )
+ memories = combined_results[: search_req.top_k]
+ else:
+ logger.info("History memories cannot answer the query, enhancing memories.")
+ sorted_results = fast_retrieved_memories + sorted_history_memories
+ combined_results = self.searcher.post_retrieve(
+ retrieved_results=sorted_results,
+ top_k=search_req.top_k,
+ user_name=user_context.mem_cube_id,
+ info=info,
+ )
+ enhanced_memories, _ = self.retriever.enhance_memories_with_query(
+ query_history=[search_req.query],
+ memories=combined_results,
+ )
+
+ if len(enhanced_memories) < search_req.top_k:
+ logger.info(
+ f"Enhanced memories ({len(enhanced_memories)}) are less than top_k ({search_req.top_k}). Recalling for more."
+ )
+ missing_info_hint, trigger = self.retriever.recall_for_missing_memories(
+ query=search_req.query,
+ memories=combined_results,
+ )
+ retrieval_size = search_req.top_k - len(enhanced_memories)
+ if trigger:
+ logger.info(f"Triggering additional search with hint: {missing_info_hint}")
+ additional_memories = self.searcher.search(
+ query=missing_info_hint,
+ user_name=user_context.mem_cube_id,
+ top_k=retrieval_size,
+ mode=SearchMode.FAST,
+ memory_type="All",
+ search_filter=search_filter,
+ info=info,
+ )
+ else:
+ logger.info("Not triggering additional search, using combined results.")
+ additional_memories = combined_results[:retrieval_size]
+ enhanced_memories += additional_memories
+ logger.info(
+ f"Added {len(additional_memories)} more memories. Total enhanced memories: {len(enhanced_memories)}"
+ )
+
+ memories = enhanced_memories[: search_req.top_k]
+
+ formatted_memories = [format_textual_memory_item(item) for item in memories]
+ logger.info("Submitted memory history async task.")
+ self.submit_memory_history_async_task(
+ search_req=search_req,
+ user_context=user_context,
+ memories_to_store={
+ "memories": [one.to_dict() for one in memories],
+ "formatted_memories": formatted_memories,
+ },
+ )
return formatted_memories
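+ # Summary of the mixture path above (descriptive comment, no behavioral change):
+ # fast retrieve -> rerank history -> answerability check -> enhance / top-up recall ->
+ # async persistence via submit_memory_history_async_task, returning at most top_k items.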
@@ -213,13 +264,10 @@ def update_search_memories_to_redis(
self,
messages: list[ScheduleMessageItem],
):
- mem_cube: NaiveMemCube = self.current_mem_cube
-
for msg in messages:
content_dict = json.loads(msg.content)
search_req = content_dict["search_req"]
user_context = content_dict["user_context"]
-
session_id = search_req.get("session_id")
if session_id:
if session_id not in self.session_counter:
@@ -237,13 +285,20 @@ def update_search_memories_to_redis(
else:
session_turn = 0
- memories: list[TextualMemoryItem] = self.search_memories(
- search_req=APISearchRequest(**content_dict["search_req"]),
- user_context=UserContext(**content_dict["user_context"]),
- mem_cube=mem_cube,
- mode=SearchMode.FAST,
- )
- formatted_memories = [format_textual_memory_item(data) for data in memories]
+ memories_to_store = content_dict["memories_to_store"]
+ if memories_to_store is None:
+ memories: list[TextualMemoryItem] = self.search_memories(
+ search_req=APISearchRequest(**content_dict["search_req"]),
+ user_context=UserContext(**content_dict["user_context"]),
+ mem_cube=self.current_mem_cube,
+ mode=SearchMode.FAST,
+ )
+ formatted_memories = [format_textual_memory_item(data) for data in memories]
+ else:
+ memories = [
+ TextualMemoryItem.from_dict(one) for one in memories_to_store["memories"]
+ ]
+ formatted_memories = memories_to_store["formatted_memories"]
# Sync search data to Redis
self.api_module.sync_search_data(
@@ -267,7 +322,7 @@ def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem])
logger.info(f"Messages {messages} assigned to {API_MIX_SEARCH_LABEL} handler.")
# Process the query in a session turn
- grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages)
+ grouped_messages = group_messages_by_user_and_mem_cube(messages)
self.validate_schedule_messages(messages=messages, label=API_MIX_SEARCH_LABEL)
diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py
index f3d2191f8..8dd51c5bd 100644
--- a/src/memos/mem_scheduler/schemas/general_schemas.py
+++ b/src/memos/mem_scheduler/schemas/general_schemas.py
@@ -1,3 +1,5 @@
+import os
+
from enum import Enum
from pathlib import Path
from typing import NewType
@@ -11,6 +13,14 @@ class SearchMode(str, Enum):
MIXTURE = "mixture"
+class FineStrategy(str, Enum):
+ """Enumeration for fine strategies."""
+
+ REWRITE = "rewrite"
+ RECREATE = "recreate"
+ DEEP_SEARCH = "deep_search"
+
+
FILE_PATH = Path(__file__).absolute()
BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent
@@ -31,15 +41,19 @@ class SearchMode(str, Enum):
DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT = 20
DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache"
DEFAULT_THREAD_POOL_MAX_WORKERS = 50
-DEFAULT_CONSUME_INTERVAL_SECONDS = 0.05
+DEFAULT_CONSUME_INTERVAL_SECONDS = 0.01
+DEFAULT_CONSUME_BATCH = 1
DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300
DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2
DEFAULT_STUCK_THREAD_TOLERANCE = 10
-DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 1000000
+DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = -1
DEFAULT_TOP_K = 10
DEFAULT_CONTEXT_WINDOW_SIZE = 5
-DEFAULT_USE_REDIS_QUEUE = False
+DEFAULT_USE_REDIS_QUEUE = True
DEFAULT_MULTI_TASK_RUNNING_TIMEOUT = 30
+DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE = 20
+DEFAULT_SCHEDULER_RETRIEVER_RETRIES = 1
+DEFAULT_STOP_WAIT = False
# startup mode configuration
STARTUP_BY_THREAD = "thread"
@@ -64,8 +78,23 @@ class SearchMode(str, Enum):
MONITOR_ACTIVATION_MEMORY_TYPE = "MonitorActivationMemoryType"
DEFAULT_MAX_QUERY_KEY_WORDS = 1000
DEFAULT_WEIGHT_VECTOR_FOR_RANKING = [0.9, 0.05, 0.05]
+DEFAULT_MAX_WEB_LOG_QUEUE_SIZE = 50
# new types
UserID = NewType("UserID", str)
MemCubeID = NewType("CubeID", str)
+
+# algorithm strategies
+DEFAULT_FINE_STRATEGY = FineStrategy.REWRITE
+
+# Read fine strategy from environment variable `FINE_STRATEGY`.
+# If provided and valid, use it; otherwise fall back to default.
+_env_fine_strategy = os.getenv("FINE_STRATEGY")
+if _env_fine_strategy:
+ try:
+ FINE_STRATEGY = FineStrategy(_env_fine_strategy)
+ except ValueError:
+ FINE_STRATEGY = DEFAULT_FINE_STRATEGY
+else:
+ FINE_STRATEGY = DEFAULT_FINE_STRATEGY
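+# Example (illustrative): exporting FINE_STRATEGY=recreate before startup selects
+# FineStrategy.RECREATE; any unrecognized value silently falls back to DEFAULT_FINE_STRATEGY.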
diff --git a/src/memos/mem_scheduler/schemas/message_schemas.py b/src/memos/mem_scheduler/schemas/message_schemas.py
index 7f328474f..f1d48f3f1 100644
--- a/src/memos/mem_scheduler/schemas/message_schemas.py
+++ b/src/memos/mem_scheduler/schemas/message_schemas.py
@@ -2,11 +2,10 @@
from typing import Any
from uuid import uuid4
-from pydantic import BaseModel, ConfigDict, Field, field_serializer
+from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import TypedDict
from memos.log import get_logger
-from memos.mem_cube.base import BaseMemCube
from memos.mem_scheduler.general_modules.misc import DictConversionMixin
from memos.mem_scheduler.utils.db_utils import get_utc_now
@@ -34,22 +33,19 @@
class ScheduleMessageItem(BaseModel, DictConversionMixin):
item_id: str = Field(description="uuid", default_factory=lambda: str(uuid4()))
+ redis_message_id: str = Field(default="", description="the message get from redis stream")
user_id: str = Field(..., description="user id")
mem_cube_id: str = Field(..., description="memcube id")
+ session_id: str = Field(default="", description="Session ID for soft-filtering memories")
label: str = Field(..., description="Label of the schedule message")
- mem_cube: BaseMemCube | str = Field(..., description="memcube for schedule")
content: str = Field(..., description="Content of the schedule message")
timestamp: datetime = Field(
default_factory=get_utc_now, description="submit time for schedule_messages"
)
- user_name: str | None = Field(
- default=None,
+ user_name: str = Field(
+ default="",
description="user name / display name (optional)",
)
- session_id: str | None = Field(
- default=None,
- description="session_id (optional)",
- )
# Pydantic V2 model configuration
model_config = ConfigDict(
@@ -65,7 +61,6 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin):
"user_id": "user123", # Example user identifier
"mem_cube_id": "cube456", # Sample memory cube ID
"label": "sample_label", # Demonstration label value
- "mem_cube": "obj of GeneralMemCube", # Added mem_cube example
"content": "sample content", # Example message content
"timestamp": "2024-07-22T12:00:00Z", # Added timestamp example
"user_name": "Alice", # Added username example
@@ -73,13 +68,6 @@ class ScheduleMessageItem(BaseModel, DictConversionMixin):
},
)
- @field_serializer("mem_cube")
- def serialize_mem_cube(self, cube: BaseMemCube | str, _info) -> str:
- """Custom serializer for BaseMemCube objects to string representation"""
- if isinstance(cube, str):
- return cube
- return f"<{type(cube).__name__}:{id(cube)}>"
-
def to_dict(self) -> dict:
"""Convert model to dictionary suitable for Redis Stream"""
return {
@@ -101,7 +89,6 @@ def from_dict(cls, data: dict) -> "ScheduleMessageItem":
user_id=data["user_id"],
mem_cube_id=data["cube_id"],
label=data["label"],
- mem_cube="Not Applicable", # Custom cube deserialization
content=data["content"],
timestamp=datetime.fromisoformat(data["timestamp"]),
user_name=data.get("user_name"),
diff --git a/evaluation/scripts/temporal_locomo/models/__init__.py b/src/memos/mem_scheduler/task_schedule_modules/__init__.py
similarity index 100%
rename from evaluation/scripts/temporal_locomo/models/__init__.py
rename to src/memos/mem_scheduler/task_schedule_modules/__init__.py
diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
similarity index 87%
rename from src/memos/mem_scheduler/general_modules/dispatcher.py
rename to src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
index c2407b9e6..b1a304754 100644
--- a/src/memos/mem_scheduler/general_modules/dispatcher.py
+++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py
@@ -11,9 +11,11 @@
from memos.log import get_logger
from memos.mem_scheduler.general_modules.base import BaseSchedulerModule
from memos.mem_scheduler.general_modules.task_threads import ThreadManager
+from memos.mem_scheduler.schemas.general_schemas import DEFAULT_STOP_WAIT
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem
from memos.mem_scheduler.utils.metrics import MetricsRegistry
+from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
logger = get_logger(__name__)
@@ -32,13 +34,23 @@ class SchedulerDispatcher(BaseSchedulerModule):
- Thread race competition for parallel task execution
"""
- def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None):
+ def __init__(
+ self,
+ max_workers: int = 30,
+ memos_message_queue: Any | None = None,
+ use_redis_queue: bool | None = None,
+ enable_parallel_dispatch: bool = True,
+ config=None,
+ ):
super().__init__()
self.config = config
# Main dispatcher thread pool
self.max_workers = max_workers
+ self.memos_message_queue = memos_message_queue
+ self.use_redis_queue = use_redis_queue
+
# Get multi-task timeout from config
self.multi_task_running_timeout = (
self.config.get("multi_task_running_timeout") if self.config else None
@@ -73,6 +85,11 @@ def __init__(self, max_workers=30, enable_parallel_dispatch=True, config=None):
self._completed_tasks = []
self.completed_tasks_max_show_size = 10
+ # Configure shutdown wait behavior from config or default
+ self.stop_wait = (
+ self.config.get("stop_wait", DEFAULT_STOP_WAIT) if self.config else DEFAULT_STOP_WAIT
+ )
+
self.metrics = MetricsRegistry(
topk_per_label=(self.config or {}).get("metrics_topk_per_label", 50)
)
@@ -131,6 +148,18 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
# --- mark done ---
for m in messages:
self.metrics.on_done(label=m.label, mem_cube_id=m.mem_cube_id, now=time.time())
+
+ # acknowledge redis messages
+ if self.use_redis_queue and self.memos_message_queue is not None:
+ for msg in messages:
+ redis_message_id = msg.redis_message_id
+ # Acknowledge message processing
+ self.memos_message_queue.ack_message(
+ user_id=msg.user_id,
+ mem_cube_id=msg.mem_cube_id,
+ redis_message_id=redis_message_id,
+ )
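+ # Note (descriptive): acking only after the handler returns gives at-least-once
+ # delivery — unacked messages stay pending in the consumer group and can be re-read.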
+
# Mark task as completed and remove from tracking
with self._task_lock:
if task_item.item_id in self._running_tasks:
@@ -138,7 +167,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
del self._running_tasks[task_item.item_id]
self._completed_tasks.append(task_item)
if len(self._completed_tasks) > self.completed_tasks_max_show_size:
- self._completed_tasks[-self.completed_tasks_max_show_size :]
+ self._completed_tasks.pop(0)
logger.info(f"Task completed: {task_item.get_execution_info()}")
return result
@@ -152,7 +181,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
task_item.mark_failed(str(e))
del self._running_tasks[task_item.item_id]
if len(self._completed_tasks) > self.completed_tasks_max_show_size:
- self._completed_tasks[-self.completed_tasks_max_show_size :]
+ self._completed_tasks.pop(0)
logger.error(f"Task failed: {task_item.get_execution_info()}, Error: {e}")
raise
@@ -299,38 +328,6 @@ def stats(self) -> dict[str, int]:
def _default_message_handler(self, messages: list[ScheduleMessageItem]) -> None:
logger.debug(f"Using _default_message_handler to deal with messages: {messages}")
- def _group_messages_by_user_and_mem_cube(
- self, messages: list[ScheduleMessageItem]
- ) -> dict[str, dict[str, list[ScheduleMessageItem]]]:
- """
- Groups messages into a nested dictionary structure first by user_id, then by mem_cube_id.
-
- Args:
- messages: List of ScheduleMessageItem objects to be grouped
-
- Returns:
- A nested dictionary with the structure:
- {
- "user_id_1": {
- "mem_cube_id_1": [msg1, msg2, ...],
- "mem_cube_id_2": [msg3, msg4, ...],
- ...
- },
- "user_id_2": {
- ...
- },
- ...
- }
- Where each msg is the original ScheduleMessageItem object
- """
- grouped_dict = defaultdict(lambda: defaultdict(list))
-
- for msg in messages:
- grouped_dict[msg.user_id][msg.mem_cube_id].append(msg)
-
- # Convert defaultdict to regular dict for cleaner output
- return {user_id: dict(cube_groups) for user_id, cube_groups in grouped_dict.items()}
-
def _handle_future_result(self, future):
self._futures.remove(future)
try:
@@ -350,7 +347,7 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]):
return
# Group messages by user_id and mem_cube_id first
- user_cube_groups = self._group_messages_by_user_and_mem_cube(msg_list)
+ user_cube_groups = group_messages_by_user_and_mem_cube(msg_list)
# Process each user and mem_cube combination
for user_id, cube_groups in user_cube_groups.items():
@@ -381,17 +378,13 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]):
wrapped_handler = self._create_task_wrapper(handler, task_item)
# dispatch to different handler
- logger.debug(
- f"Dispatch {len(msgs)} message(s) to {label} handler for user {user_id} and mem_cube {mem_cube_id}."
- )
- logger.info(f"Task started: {task_item.get_execution_info()}")
-
+ logger.debug(f"Task started: {task_item.get_execution_info()}")
if self.enable_parallel_dispatch and self.dispatcher_executor is not None:
# Capture variables in lambda to avoid loop variable issues
- future = self.dispatcher_executor.submit(wrapped_handler, msgs)
- self._futures.add(future)
- future.add_done_callback(self._handle_future_result)
- logger.info(f"Dispatched {len(msgs)} message(s) as future task")
+ _ = self.dispatcher_executor.submit(wrapped_handler, msgs)
+ logger.info(
+ f"Dispatch {len(msgs)} message(s) to {label} handler for user {user_id} and mem_cube {mem_cube_id}."
+ )
else:
wrapped_handler(msgs)
@@ -484,17 +477,9 @@ def shutdown(self) -> None:
"""Gracefully shutdown the dispatcher."""
self._running = False
- if self.dispatcher_executor is not None:
- # Cancel pending tasks
- cancelled = 0
- for future in self._futures:
- if future.cancel():
- cancelled += 1
- logger.info(f"Cancelled {cancelled}/{len(self._futures)} pending tasks")
-
# Shutdown executor
try:
- self.dispatcher_executor.shutdown(wait=True)
+ self.dispatcher_executor.shutdown(wait=self.stop_wait, cancel_futures=True)
except Exception as e:
logger.error(f"Executor shutdown error: {e}", exc_info=True)
finally:
diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
new file mode 100644
index 000000000..f7e3eac15
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py
@@ -0,0 +1,152 @@
+"""
+Local Queue implementation for SchedulerMessageItem objects.
+This module provides a local-based queue implementation that can replace
+the local memos_message_queue functionality in BaseScheduler.
+"""
+
+from memos.log import get_logger
+from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
+
+
+logger = get_logger(__name__)
+
+
+class SchedulerLocalQueue(RedisSchedulerModule):
+ def __init__(
+ self,
+ maxsize: int,
+ ):
+ """
+ Initialize the SchedulerLocalQueue with a maximum queue size limit.
+
+ Args:
+ maxsize (int): Maximum number of messages allowed
+ in each individual queue.
+ If exceeded, subsequent puts will block
+ or raise an exception based on `block` parameter.
+ """
+ super().__init__()
+
+ self.stream_key_prefix = "local_queue"
+
+ self.max_internal_message_queue_size = maxsize
+ # Dictionary to hold per-stream queues: key = stream_key, value = Queue[ScheduleMessageItem]
+ self.queue_streams: dict[str, Queue[ScheduleMessageItem]] = {}
+ logger.info(
+ f"SchedulerLocalQueue initialized with max_internal_message_queue_size={maxsize}"
+ )
+
+ def get_stream_key(self, user_id: str, mem_cube_id: str) -> str:
+ stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}"
+ return stream_key
+
+ def put(
+ self, message: ScheduleMessageItem, block: bool = True, timeout: float | None = None
+ ) -> None:
+ """
+ Put a message into the appropriate internal queue based on user_id and mem_cube_id.
+
+ If the corresponding queue does not exist, it is created automatically.
+ This method uses a local in-memory queue (not Redis) for buffering messages.
+
+ Args:
+ message (ScheduleMessageItem): The message to enqueue.
+ block (bool): If True, block if the queue is full; if False, raise Full immediately.
+ timeout (float | None): Maximum time to wait for the queue to become available.
+ If None, block indefinitely. Ignored if block=False.
+
+ Raises:
+ queue.Full: If the queue is full and block=False or timeout expires.
+ Exception: Any underlying error during queue.put() operation.
+ """
+ stream_key = self.get_stream_key(user_id=message.user_id, mem_cube_id=message.mem_cube_id)
+
+ # Create the queue if it doesn't exist yet
+ if stream_key not in self.queue_streams:
+ logger.info(f"Creating new internal queue for stream: {stream_key}")
+ self.queue_streams[stream_key] = Queue(maxsize=self.max_internal_message_queue_size)
+
+ try:
+ self.queue_streams[stream_key].put(item=message, block=block, timeout=timeout)
+ logger.info(
+ f"Message successfully put into queue '{stream_key}'. Current size: {self.queue_streams[stream_key].qsize()}"
+ )
+ except Exception as e:
+ logger.error(f"Failed to put message into queue '{stream_key}': {e}", exc_info=True)
+ raise # Re-raise to maintain caller expectations
+
+ def get(
+ self,
+ stream_key: str,
+ block: bool = True,
+ timeout: float | None = None,
+ batch_size: int | None = None,
+ ) -> list[ScheduleMessageItem]:
+ if batch_size is not None and batch_size <= 0:
+ logger.warning(
+ f"get() called with invalid batch_size: {batch_size}. Returning empty list."
+ )
+ return []
+
+ # Return empty list if queue does not exist
+ if stream_key not in self.queue_streams:
+ logger.error(f"Stream {stream_key} does not exist when trying to get messages.")
+ return []
+
+ # Note: Assumes custom Queue implementation supports batch_size parameter
+ res = self.queue_streams[stream_key].get(
+ block=block, timeout=timeout, batch_size=batch_size
+ )
+ logger.debug(
+ f"Retrieved {len(res)} messages from queue '{stream_key}'. Current size: {self.queue_streams[stream_key].qsize()}"
+ )
+ return res
+
+ def get_nowait(
+ self, stream_key: str, batch_size: int | None = None
+ ) -> list[ScheduleMessageItem]:
+ """
+ Non-blocking version of get(). Equivalent to get(stream_key, block=False, batch_size=batch_size).
+
+ Returns immediately with available messages or an empty list if the queue is empty.
+
+ Args:
+ stream_key (str): Key of the per-user/per-cube queue to read from.
+ batch_size (int | None): Number of messages to retrieve in a batch.
+ If None, retrieves one message.
+
+ Returns:
+ List[ScheduleMessageItem]: Retrieved messages or empty list if queue is empty.
+ """
+ logger.debug(f"get_nowait() called with batch_size: {batch_size}")
+ return self.get(stream_key=stream_key, block=False, batch_size=batch_size)
+
+ def qsize(self) -> dict:
+ """
+ Return the current size of all internal queues as a dictionary.
+
+ Each key is the stream name, and each value is the number of messages in that queue.
+
+ Returns:
+ Dict[str, int]: Mapping from stream name to current queue size.
+ """
+ sizes = {stream: queue.qsize() for stream, queue in self.queue_streams.items()}
+ logger.debug(f"Current queue sizes: {sizes}")
+ return sizes
+
+ def clear(self) -> None:
+ for queue in self.queue_streams.values():
+ queue.clear()
+
+ @property
+ def unfinished_tasks(self) -> int:
+ """
+ Calculate the total number of unprocessed messages across all queues.
+
+ This is a convenience property for monitoring overall system load.
+
+ Returns:
+ int: Sum of all message counts in all internal queues.
+ """
+ total = sum(self.qsize().values())
+ logger.debug(f"Total unfinished tasks across all queues: {total}")
+ return total
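+# Minimal usage sketch (illustrative; identifiers are placeholders):
+# q = SchedulerLocalQueue(maxsize=100)
+# q.put(ScheduleMessageItem(user_id="u1", mem_cube_id="c1", label="demo", content="{}"))
+# q.get(stream_key=q.get_stream_key("u1", "c1"), block=False, batch_size=10)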
diff --git a/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
new file mode 100644
index 000000000..5e850c8ce
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py
@@ -0,0 +1,456 @@
+"""
+Redis Queue implementation for SchedulerMessageItem objects.
+
+This module provides a Redis-based queue implementation that can replace
+the local memos_message_queue functionality in BaseScheduler.
+"""
+
+import re
+import time
+
+from collections.abc import Callable
+from uuid import uuid4
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule
+
+
+logger = get_logger(__name__)
+
+
+class SchedulerRedisQueue(RedisSchedulerModule):
+ """
+ Redis-based queue for storing and processing SchedulerMessageItem objects.
+
+ This class provides a Redis Stream-based implementation that can replace
+ the local memos_message_queue functionality, offering better scalability
+ and persistence for message processing.
+
+ Inherits from RedisSchedulerModule to leverage existing Redis connection
+ and initialization functionality.
+ """
+
+ def __init__(
+ self,
+ stream_key_prefix: str = "scheduler:messages:stream",
+ consumer_group: str = "scheduler_group",
+ consumer_name: str | None = "scheduler_consumer",
+ max_len: int = 10000,
+ maxsize: int = 0, # For Queue compatibility
+ auto_delete_acked: bool = True, # Whether to automatically delete acknowledged messages
+ ):
+ """
+ Initialize the Redis queue.
+
+ Args:
+ stream_key_prefix: Name of the Redis stream
+ consumer_group: Name of the consumer group
+ consumer_name: Name of the consumer (auto-generated if None)
+ max_len: Maximum length of the stream (for memory management)
+ maxsize: Maximum size of the queue (for Queue compatibility, ignored)
+ auto_delete_acked: Whether to automatically delete acknowledged messages from stream
+ """
+ super().__init__()
+
+ # Normalize non-positive values to 0 (treated as an unlimited queue size)
+ if maxsize <= 0:
+ maxsize = 0
+
+ # Stream configuration
+ self.stream_key_prefix = stream_key_prefix
+ self.consumer_group = consumer_group
+ self.consumer_name = consumer_name or f"consumer_{uuid4().hex[:8]}"
+ self.max_len = max_len
+ self.maxsize = maxsize # For Queue compatibility
+ self.auto_delete_acked = auto_delete_acked # Whether to delete acknowledged messages
+
+ # Consumer state
+ self._is_listening = False
+ self._message_handler: Callable[[ScheduleMessageItem], None] | None = None
+
+ # Connection state
+ self._is_connected = False
+
+ # Task tracking for mem_scheduler_wait compatibility
+ self._unfinished_tasks = 0
+
+ # Auto-initialize Redis connection
+ if self.auto_initialize_redis():
+ self._is_connected = True
+
+ self.seen_streams = set()
+
+ def get_stream_key(self, user_id: str, mem_cube_id: str) -> str:
+ stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}"
+ return stream_key
+
+ def _ensure_consumer_group(self, stream_key) -> None:
+ """Ensure the consumer group exists for the stream."""
+ if not self._redis_conn:
+ return
+
+ try:
+ self._redis_conn.xgroup_create(stream_key, self.consumer_group, id="0", mkstream=True)
+ logger.debug(
+ f"Created consumer group '{self.consumer_group}' for stream '{stream_key}'"
+ )
+ except Exception as e:
+ # Check if it's a "consumer group already exists" error
+ error_msg = str(e).lower()
+ if "busygroup" in error_msg or "already exists" in error_msg:
+ logger.info(
+ f"Consumer group '{self.consumer_group}' already exists for stream '{stream_key}'"
+ )
+ else:
+ logger.error(f"Error creating consumer group: {e}", exc_info=True)
+
+ def put(
+ self, message: ScheduleMessageItem, block: bool = True, timeout: float | None = None
+ ) -> None:
+ """
+ Add a message to the Redis queue (Queue-compatible interface).
+
+ Args:
+ message: SchedulerMessageItem to add to the queue
+ block: Ignored for Redis implementation (always non-blocking)
+ timeout: Ignored for Redis implementation
+
+ Raises:
+ ConnectionError: If not connected to Redis
+ TypeError: If message is not a ScheduleMessageItem
+ """
+ if not self._redis_conn:
+ raise ConnectionError("Not connected to Redis. Redis connection not available.")
+
+ if not isinstance(message, ScheduleMessageItem):
+ raise TypeError(f"Expected ScheduleMessageItem, got {type(message)}")
+
+ try:
+ stream_key = self.get_stream_key(
+ user_id=message.user_id, mem_cube_id=message.mem_cube_id
+ )
+
+ if stream_key not in self.seen_streams:
+ self.seen_streams.add(stream_key)
+ self._ensure_consumer_group(stream_key=stream_key)
+
+ # Convert message to dictionary for Redis storage
+ message_data = message.to_dict()
+
+ # Add to Redis stream with automatic trimming
+ message_id = self._redis_conn.xadd(
+ stream_key, message_data, maxlen=self.max_len, approximate=True
+ )
+
+ logger.info(
+ f"Added message {message_id} to Redis stream: {message.label} - {message.content[:100]}..."
+ )
+
+ except Exception as e:
+ logger.error(f"Failed to add message to Redis queue: {e}")
+ raise
+
+ def ack_message(self, user_id, mem_cube_id, redis_message_id) -> None:
+ stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id)
+
+ self._redis_conn.xack(stream_key, self.consumer_group, redis_message_id)
+
+ # Optionally delete the message from the stream to keep it clean
+ if self.auto_delete_acked:
+ try:
+ self._redis_conn.xdel(stream_key, redis_message_id)
+ logger.info(f"Successfully delete acknowledged message {redis_message_id}")
+ except Exception as e:
+ logger.warning(f"Failed to delete acknowledged message {redis_message_id}: {e}")
+
+ def get(
+ self,
+ stream_key: str,
+ block: bool = True,
+ timeout: float | None = None,
+ batch_size: int | None = None,
+ ) -> list[ScheduleMessageItem]:
+ if not self._redis_conn:
+ raise ConnectionError("Not connected to Redis. Redis connection not available.")
+
+ try:
+ # Calculate timeout for Redis
+ redis_timeout = None
+ if block and timeout is not None:
+ redis_timeout = int(timeout * 1000)
+ elif not block:
+ redis_timeout = None # Non-blocking
+
+ # Read messages from the consumer group
+ try:
+ messages = self._redis_conn.xreadgroup(
+ self.consumer_group,
+ self.consumer_name,
+ {stream_key: ">"},
+ count=batch_size if batch_size else 1,
+ block=redis_timeout,
+ )
+ except Exception as read_err:
+ # Handle missing group/stream by creating and retrying once
+ err_msg = str(read_err).lower()
+ if "nogroup" in err_msg or "no such key" in err_msg:
+ logger.warning(
+ f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry."
+ )
+ self._ensure_consumer_group(stream_key=stream_key)
+ messages = self._redis_conn.xreadgroup(
+ self.consumer_group,
+ self.consumer_name,
+ {stream_key: ">"},
+ count=batch_size if batch_size else 1,
+ block=redis_timeout,
+ )
+ else:
+ raise
+ result_messages = []
+
+ for _stream, stream_messages in messages:
+ for message_id, fields in stream_messages:
+ try:
+ # Convert Redis message back to SchedulerMessageItem
+ message = ScheduleMessageItem.from_dict(fields)
+ message.redis_message_id = message_id
+
+ result_messages.append(message)
+
+ except Exception as e:
+ logger.error(f"Failed to parse message {message_id}: {e}")
+
+ # Always return a list for consistency
+ if not result_messages:
+ if not block:
+ return [] # Return empty list for non-blocking calls
+ else:
+ # If no messages were found, raise Empty exception
+ from queue import Empty
+
+ raise Empty("No messages available in Redis queue")
+
+ return result_messages if batch_size is not None else result_messages[0]
+
+ except Exception as e:
+ if "Empty" in str(type(e).__name__):
+ raise
+ logger.error(f"Failed to get message from Redis queue: {e}")
+ raise
+
+ def get_nowait(
+ self, user_id: str, mem_cube_id: str, batch_size: int | None = None
+ ) -> list[ScheduleMessageItem]:
+ """
+ Get messages from the Redis queue without blocking (Queue-compatible interface).
+
+ Returns:
+ List of SchedulerMessageItem objects
+
+ Raises:
+ Empty: If no message is available
+ """
+ stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id)
+ return self.get(stream_key=stream_key, block=False, batch_size=batch_size)
+
+ def qsize(self) -> int:
+ """
+ Get the current size of the Redis queue (Queue-compatible interface).
+
+ This method scans for all streams matching the `stream_key_prefix`
+ and sums up their lengths to get the total queue size.
+
+ Returns:
+ Total number of messages across all matching streams.
+ """
+ if not self._redis_conn:
+ return 0
+
+ total_size = 0
+ try:
+ # Scan for all stream keys matching the prefix
+ for stream_key in self._redis_conn.scan_iter(f"{self.stream_key_prefix}:*"):
+ try:
+ # Get the length of each stream and add to total
+ total_size += self._redis_conn.xlen(stream_key)
+ except Exception as e:
+ logger.debug(f"Failed to get length for stream {stream_key}: {e}")
+ return total_size
+ except Exception as e:
+ logger.error(f"Failed to get Redis queue size: {e}")
+ return 0
+
+ def get_stream_keys(self) -> list[str]:
+ """
+ List all Redis stream keys that match this queue's prefix.
+
+ Returns:
+ A list of stream keys like `"{prefix}:{user_id}:{mem_cube_id}"`.
+ """
+ if not self._redis_conn:
+ return []
+
+ # First, get all keys that might match (using Redis pattern matching)
+ redis_pattern = f"{self.stream_key_prefix}:*"
+ raw_keys = [
+ key.decode("utf-8") if isinstance(key, bytes) else key
+ for key in self._redis_conn.scan_iter(match=redis_pattern)
+ ]
+
+ # Second, filter using Python regex to ensure exact prefix match
+ # Escape special regex characters in the prefix, then add :.*
+ escaped_prefix = re.escape(self.stream_key_prefix)
+ regex_pattern = f"^{escaped_prefix}:"
+ stream_keys = [key for key in raw_keys if re.match(regex_pattern, key)]
+
+ logger.debug(f"get stream_keys from redis: {stream_keys}")
+ return stream_keys
+
+ def size(self) -> int:
+ """
+ Get the current size of the Redis queue (alias for qsize).
+
+ Returns:
+ Number of messages in the queue
+ """
+ return self.qsize()
+
+ def empty(self) -> bool:
+ """
+ Check if the Redis queue is empty (Queue-compatible interface).
+
+ Returns:
+ True if the queue is empty, False otherwise
+ """
+ return self.qsize() == 0
+
+ def full(self) -> bool:
+ """
+ Check if the Redis queue is full (Queue-compatible interface).
+
+ For Redis streams, we consider the queue full if it exceeds maxsize.
+ If maxsize is 0 or None, the queue is never considered full.
+
+ Returns:
+ True if the queue is full, False otherwise
+ """
+ if self.maxsize <= 0:
+ return False
+ return self.qsize() >= self.maxsize
+
+ def join(self) -> None:
+ """
+ Block until all items in the queue have been gotten and processed (Queue-compatible interface).
+
+ For Redis streams, this would require tracking pending messages,
+ which is complex. For now, this is a no-op.
+ """
+
+ def clear(self) -> None:
+ """Clear all messages from the queue."""
+ if not self._is_connected or not self._redis_conn:
+ return
+
+ try:
+ stream_keys = self.get_stream_keys()
+
+ for stream_key in stream_keys:
+ # Delete the entire stream
+ self._redis_conn.delete(stream_key)
+ logger.info(f"Cleared Redis stream: {stream_key}")
+
+ except Exception as e:
+ logger.error(f"Failed to clear Redis queue: {e}")
+
+ def start_listening(
+ self,
+ handler: Callable[[ScheduleMessageItem], None],
+ batch_size: int = 10,
+ poll_interval: float = 0.1,
+ ) -> None:
+ """
+ Start listening for messages and process them with the provided handler.
+
+ Args:
+ handler: Function to call for each received message
+ batch_size: Number of messages to process in each batch
+ poll_interval: Interval between polling attempts in seconds
+ """
+ if not self._is_connected:
+ raise ConnectionError("Not connected to Redis. Call connect() first.")
+
+ self._message_handler = handler
+ self._is_listening = True
+
+ logger.info(f"Started listening on Redis stream: {self.stream_key_prefix}")
+
+ try:
+ while self._is_listening:
+ messages = []
+ for stream_key in self.get_stream_keys():
+ messages.extend(
+ self.get(stream_key=stream_key, block=False, batch_size=batch_size)
+ )
+
+ for message in messages:
+ try:
+ self._message_handler(message)
+ except Exception as e:
+ logger.error(f"Error processing message {message.item_id}: {e}")
+
+ # Small sleep to prevent excessive CPU usage
+ if not messages:
+ time.sleep(poll_interval)
+
+ except KeyboardInterrupt:
+ logger.info("Received interrupt signal, stopping listener")
+ except Exception as e:
+ logger.error(f"Error in message listener: {e}")
+ finally:
+ self._is_listening = False
+ logger.info("Stopped listening for messages")
+
+ def stop_listening(self) -> None:
+ """Stop the message listener."""
+ self._is_listening = False
+ logger.info("Requested stop for message listener")
+
+ def connect(self) -> None:
+ """Establish connection to Redis and set up the queue."""
+ if self._redis_conn is not None:
+ try:
+ # Test the connection
+ self._redis_conn.ping()
+ self._is_connected = True
+ logger.debug("Redis connection established successfully")
+ except Exception as e:
+ logger.error(f"Failed to connect to Redis: {e}")
+ self._is_connected = False
+ else:
+ logger.error("Redis connection not initialized")
+ self._is_connected = False
+
+ def disconnect(self) -> None:
+ """Disconnect from Redis and clean up resources."""
+ self._is_connected = False
+ if self._is_listening:
+ self.stop_listening()
+ logger.debug("Disconnected from Redis")
+
+ def __enter__(self):
+ """Context manager entry."""
+ self.connect()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """Context manager exit."""
+ self.stop_listening()
+ self.disconnect()
+
+ def __del__(self):
+ """Cleanup when object is destroyed."""
+ if self._is_connected:
+ self.disconnect()
+
+ @property
+ def unfinished_tasks(self) -> int:
+ return self.qsize()
diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py
new file mode 100644
index 000000000..6d824f4b1
--- /dev/null
+++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py
@@ -0,0 +1,125 @@
+"""
+Redis Queue implementation for SchedulerMessageItem objects.
+
+This module provides a Redis-based queue implementation that can replace
+the local memos_message_queue functionality in BaseScheduler.
+"""
+
+from memos.log import get_logger
+from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
+from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue
+from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue
+from memos.mem_scheduler.utils.db_utils import get_utc_now
+from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
+
+
+logger = get_logger(__name__)
+
+
+class ScheduleTaskQueue:
+ def __init__(
+ self,
+ use_redis_queue: bool,
+ maxsize: int,
+ disabled_handlers: list | None = None,
+ ):
+ self.use_redis_queue = use_redis_queue
+ self.maxsize = maxsize
+
+ if self.use_redis_queue:
+ self.memos_message_queue = SchedulerRedisQueue(maxsize=self.maxsize)
+ else:
+ self.memos_message_queue = SchedulerLocalQueue(maxsize=self.maxsize)
+
+ self.disabled_handlers = disabled_handlers
+
+ def ack_message(
+ self,
+ user_id,
+ mem_cube_id,
+ redis_message_id,
+ ) -> None:
+ if not isinstance(self.memos_message_queue, SchedulerRedisQueue):
+ logger.warning("ack_message is only supported for Redis queues")
+ return
+
+ self.memos_message_queue.ack_message(
+ user_id=user_id,
+ mem_cube_id=mem_cube_id,
+ redis_message_id=redis_message_id,
+ )
+
+ def debug_mode_on(self):
+ self.memos_message_queue.stream_key_prefix = (
+ f"debug_mode:{self.memos_message_queue.stream_key_prefix}"
+ )
+
+ def get_stream_keys(self) -> list[str]:
+ if isinstance(self.memos_message_queue, SchedulerRedisQueue):
+ stream_keys = self.memos_message_queue.get_stream_keys()
+ else:
+ stream_keys = list(self.memos_message_queue.queue_streams.keys())
+ return stream_keys
+
+ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]):
+ """Submit messages to the message queue (either local queue or Redis)."""
+ if isinstance(messages, ScheduleMessageItem):
+ messages = [messages]
+
+ if len(messages) < 1:
+ logger.error("Submit empty")
+ elif len(messages) == 1:
+ self.memos_message_queue.put(messages[0])
+ else:
+ user_cube_groups = group_messages_by_user_and_mem_cube(messages)
+
+ # Process each user and mem_cube combination
+ for _user_id, cube_groups in user_cube_groups.items():
+ for _mem_cube_id, user_cube_msgs in cube_groups.items():
+ for message in user_cube_msgs:
+ if not isinstance(message, ScheduleMessageItem):
+ error_msg = f"Invalid message type: {type(message)}, expected ScheduleMessageItem"
+ logger.error(error_msg)
+ raise TypeError(error_msg)
+
+ if getattr(message, "timestamp", None) is None:
+ message.timestamp = get_utc_now()
+
+ if self.disabled_handlers and message.label in self.disabled_handlers:
+ logger.info(
+ f"Skipping disabled handler: {message.label} - {message.content}"
+ )
+ continue
+
+ self.memos_message_queue.put(message)
+ logger.info(
+ f"Submitted message to local queue: {message.label} - {message.content}"
+ )
+
+ def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]:
+ stream_keys = self.get_stream_keys()
+
+ if len(stream_keys) == 0:
+ return []
+
+ messages: list[ScheduleMessageItem] = []
+
+ for stream_key in stream_keys:
+ fetched = self.memos_message_queue.get(
+ stream_key=stream_key,
+ block=False,
+ batch_size=batch_size,
+ )
+
+ messages.extend(fetched)
+ if len(messages) > 0:
+ logger.debug(
+ f"Fetched {len(messages)} messages across users with per-user batch_size={batch_size}"
+ )
+ return messages
+
+ def clear(self):
+ self.memos_message_queue.clear()
+
+ def qsize(self):
+ return self.memos_message_queue.qsize()
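+# Minimal wiring sketch (illustrative; values are placeholders):
+# task_queue = ScheduleTaskQueue(use_redis_queue=False, maxsize=1000)
+# task_queue.submit_messages(msg)  # msg: ScheduleMessageItem
+# pending = task_queue.get_messages(batch_size=5)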
diff --git a/src/memos/mem_scheduler/utils/metrics.py b/src/memos/mem_scheduler/utils/metrics.py
index 5155c98b3..0d781c996 100644
--- a/src/memos/mem_scheduler/utils/metrics.py
+++ b/src/memos/mem_scheduler/utils/metrics.py
@@ -6,10 +6,14 @@
from dataclasses import dataclass, field
+from memos.log import get_logger
+
# ==== global window config ====
WINDOW_SEC = 120 # 2 minutes sliding window
+logger = get_logger(__name__)
+
# ---------- O(1) EWMA ----------
class Ewma:
@@ -184,12 +188,7 @@ def on_enqueue(
inst_rate = (1.0 / max(1e-3, dt)) if dt is not None else 0.0 # first sample: no spike
ls.last_enqueue_ts = now
ls.backlog += 1
- old_lam = ls.lambda_ewma.value_at(now)
ls.lambda_ewma.update(inst_rate, now)
- new_lam = ls.lambda_ewma.value_at(now)
- print(
- f"[DEBUG enqueue] {label} backlog={ls.backlog} dt={dt if dt is not None else '—'}s inst={inst_rate:.3f} λ {old_lam:.3f}→{new_lam:.3f}"
- )
self._label_topk[label].add(mem_cube_id)
ds = self._get_detail(label, mem_cube_id)
if ds:
@@ -222,12 +221,7 @@ def on_done(
ls.last_done_ts = now
if ls.backlog > 0:
ls.backlog -= 1
- old_mu = ls.mu_ewma.value_at(now)
ls.mu_ewma.update(inst_rate, now)
- new_mu = ls.mu_ewma.value_at(now)
- print(
- f"[DEBUG done] {label} backlog={ls.backlog} dt={dt if dt is not None else '—'}s inst={inst_rate:.3f} μ {old_mu:.3f}→{new_mu:.3f}"
- )
ds = self._detail_stats.get((label, mem_cube_id))
if ds:
prev_ts_d = ds.last_done_ts
diff --git a/src/memos/mem_scheduler/utils/misc_utils.py b/src/memos/mem_scheduler/utils/misc_utils.py
index aa9b5c489..7b0bcea34 100644
--- a/src/memos/mem_scheduler/utils/misc_utils.py
+++ b/src/memos/mem_scheduler/utils/misc_utils.py
@@ -1,18 +1,23 @@
import json
import re
+import traceback
+from collections import defaultdict
from functools import wraps
from pathlib import Path
import yaml
from memos.log import get_logger
+from memos.mem_scheduler.schemas.message_schemas import (
+ ScheduleMessageItem,
+)
logger = get_logger(__name__)
-def extract_json_dict(text: str):
+def extract_json_obj(text: str):
"""
Safely extracts JSON from LLM response text with robust error handling.
@@ -40,7 +45,7 @@ def extract_json_dict(text: str):
try:
return json.loads(text.strip())
except json.JSONDecodeError as e:
- logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
+ logger.info(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
# Fallback 1: Extract JSON using regex
json_pattern = r"\{[\s\S]*\}|\[[\s\S]*\]"
@@ -49,7 +54,7 @@ def extract_json_dict(text: str):
try:
return json.loads(matches[0])
except json.JSONDecodeError as e:
- logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
+ logger.info(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
# Fallback 2: Handle malformed JSON (common LLM issues)
try:
@@ -57,10 +62,125 @@ def extract_json_dict(text: str):
text = re.sub(r"([\{\s,])(\w+)(:)", r'\1"\2"\3', text)
return json.loads(text)
except json.JSONDecodeError as e:
- logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}", exc_info=True)
+ logger.error(f"Failed to parse JSON from text: {text}. Error: {e!s}")
+ logger.error("Full traceback:\n" + traceback.format_exc())
raise ValueError(text) from e
+def extract_list_items(text: str, bullet_prefixes: tuple[str, ...] = ("- ",)) -> list[str]:
+ """
+ Extract bullet list items from LLM output where each item is on a single line
+ starting with a given bullet prefix (default: "- ").
+
+ This function is designed to be robust to common LLM formatting variations,
+ following similar normalization practices as `extract_json_obj`.
+
+ Behavior:
+ - Strips common code-fence markers (```json, ```python, ``` etc.).
+ - Collects all lines that start with any of the provided `bullet_prefixes`.
+ - Tolerates the "• " bullet as a loose fallback.
+ - Unescapes common sequences like "\\n" and "\\t" within items.
+ - If no bullet lines are found, falls back to attempting to parse a JSON array
+ (using `extract_json_obj`) and returns its string elements.
+
+ Args:
+ text: Raw text response from LLM.
+ bullet_prefixes: Tuple of accepted bullet line prefixes.
+
+ Returns:
+ List of extracted items (strings). Returns an empty list if none can be parsed.
+ """
+ if not text:
+ return []
+
+ # Normalize the text similar to extract_json_obj
+ normalized = text.strip()
+ patterns_to_remove = ["json```", "```python", "```json", "latex```", "```latex", "```"]
+ for pattern in patterns_to_remove:
+ normalized = normalized.replace(pattern, "")
+ normalized = normalized.replace("\r\n", "\n")
+
+ lines = normalized.splitlines()
+ items: list[str] = []
+ seen: set[str] = set()
+
+ for raw in lines:
+ line = raw.strip()
+ if not line:
+ continue
+
+ matched = False
+ for prefix in bullet_prefixes:
+ if line.startswith(prefix):
+ content = line[len(prefix) :].strip()
+ content = content.replace("\\n", "\n").replace("\\t", "\t").replace("\\r", "\r")
+ if content and content not in seen:
+ items.append(content)
+ seen.add(content)
+ matched = True
+ break
+
+ if matched:
+ continue
+
+ if items:
+ return items
+ else:
+ logger.error(f"Failed to parse bullet list items from text: {text}")
+
+ return []
+
+
+def extract_list_items_in_answer(
+ text: str, bullet_prefixes: tuple[str, ...] = ("- ",)
+) -> list[str]:
+ """
+ Extract list items specifically from content enclosed within `<answer>...</answer>` tags.
+
+ - When one or more `<answer>...</answer>` blocks are present, concatenates their inner
+ contents with newlines and parses using `extract_list_items`.
+ - When no `<answer>` block is found, falls back to parsing the entire input with
+ `extract_list_items`.
+ - Case-insensitive matching of the `<answer>` tag.
+
+ Args:
+ text: Raw text that may contain `<answer>...</answer>` blocks.
+ bullet_prefixes: Accepted bullet prefixes (default: strictly `"- "`).
+
+ Returns:
+ List of extracted items (strings), or an empty list when nothing is parseable.
+ """
+ if not text:
+ return []
+
+ try:
+ normalized = text.strip().replace("\r\n", "\n")
+ # Ordered, exact-case matching for blocks: answer -> Answer -> ANSWER
+ tag_variants = ["answer", "Answer", "ANSWER"]
+ matches: list[str] = []
+ for tag in tag_variants:
+ matches = re.findall(rf"<{tag}>([\s\S]*?)</{tag}>", normalized)
+ if matches:
+ break
+ # Fallback: case-insensitive matching if none of the exact-case variants matched
+ if not matches:
+ matches = re.findall(r"<answer>([\s\S]*?)</answer>", normalized, flags=re.IGNORECASE)
+
+ if matches:
+ combined = "\n".join(m.strip() for m in matches if m is not None)
+ return extract_list_items(combined, bullet_prefixes=bullet_prefixes)
+
+ # Fallback: parse the whole text if tags are absent
+ return extract_list_items(normalized, bullet_prefixes=bullet_prefixes)
+ except Exception as e:
+ logger.info(f"Failed to extract items within tags: {e!s}", exc_info=True)
+ # Final fallback: attempt direct list extraction
+ try:
+ return extract_list_items(text, bullet_prefixes=bullet_prefixes)
+ except Exception:
+ return []
+
+
def parse_yaml(yaml_file: str | Path):
yaml_path = Path(yaml_file)
if not yaml_path.is_file():
@@ -100,3 +220,36 @@ def wrapper(*args, **kwargs):
return wrapper
return decorator
+
+
+def group_messages_by_user_and_mem_cube(
+ messages: list[ScheduleMessageItem],
+) -> dict[str, dict[str, list[ScheduleMessageItem]]]:
+ """
+ Groups messages into a nested dictionary structure first by user_id, then by mem_cube_id.
+
+ Args:
+ messages: List of ScheduleMessageItem objects to be grouped
+
+ Returns:
+ A nested dictionary with the structure:
+ {
+ "user_id_1": {
+ "mem_cube_id_1": [msg1, msg2, ...],
+ "mem_cube_id_2": [msg3, msg4, ...],
+ ...
+ },
+ "user_id_2": {
+ ...
+ },
+ ...
+ }
+ Where each msg is the original ScheduleMessageItem object
+ """
+ grouped_dict = defaultdict(lambda: defaultdict(list))
+
+ for msg in messages:
+ grouped_dict[msg.user_id][msg.mem_cube_id].append(msg)
+
+ # Convert defaultdict to regular dict for cleaner output
+ return {user_id: dict(cube_groups) for user_id, cube_groups in grouped_dict.items()}
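A hedged usage sketch for the two helpers added above. `group_messages_by_user_and_mem_cube` only reads the `user_id` and `mem_cube_id` attributes, so a small dataclass stand-in is enough for illustration (real callers pass `ScheduleMessageItem`); the `<answer>`-wrapped sample string is likewise invented for the example.

```python
from dataclasses import dataclass

from memos.mem_scheduler.utils.misc_utils import (
    extract_list_items_in_answer,
    group_messages_by_user_and_mem_cube,
)

@dataclass
class Msg:  # stand-in with just the attributes the grouping helper reads
    user_id: str
    mem_cube_id: str
    content: str

msgs = [
    Msg("user1", "cube_a", "q1"),
    Msg("user1", "cube_b", "q2"),
    Msg("user2", "cube_a", "q3"),
]
grouped = group_messages_by_user_and_mem_cube(msgs)
print(sorted(grouped))                                   # ['user1', 'user2']
print([m.content for m in grouped["user1"]["cube_b"]])   # ['q2']

raw = "<answer>\n- Melanie ran on October 6\n- Her childhood home is in Guangzhou\n</answer>"
print(extract_list_items_in_answer(raw))
# ['Melanie ran on October 6', 'Her childhood home is in Guangzhou']
```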
diff --git a/src/memos/mem_scheduler/webservice_modules/redis_service.py b/src/memos/mem_scheduler/webservice_modules/redis_service.py
index 5439af9c6..e79553f33 100644
--- a/src/memos/mem_scheduler/webservice_modules/redis_service.py
+++ b/src/memos/mem_scheduler/webservice_modules/redis_service.py
@@ -333,6 +333,15 @@ def redis_start_listening(self, handler: Callable | None = None):
logger.warning("Listener is already running")
return
+ # Check Redis connection before starting listener
+ if self.redis is None:
+ logger.warning(
+ "Redis connection is None, attempting to auto-initialize before starting listener..."
+ )
+ if not self.auto_initialize_redis():
+ logger.error("Failed to initialize Redis connection, cannot start listener")
+ return
+
if handler is None:
handler = self.redis_consume_message_stream
diff --git a/src/memos/memories/textual/item.py b/src/memos/memories/textual/item.py
index 2c23ae193..e7595443d 100644
--- a/src/memos/memories/textual/item.py
+++ b/src/memos/memories/textual/item.py
@@ -198,6 +198,7 @@ class PreferenceTextualMemoryMetadata(TextualMemoryMetadata):
embedding: list[float] | None = Field(default=None, description="Vector of the dialog.")
preference: str | None = Field(default=None, description="Preference.")
created_at: str | None = Field(default=None, description="Timestamp of the dialog.")
+ mem_cube_id: str | None = Field(default=None, description="ID of the MemCube.")
class TextualMemoryItem(BaseModel):
diff --git a/src/memos/memories/textual/prefer_text_memory/adder.py b/src/memos/memories/textual/prefer_text_memory/adder.py
index a78601e86..5e58d23a5 100644
--- a/src/memos/memories/textual/prefer_text_memory/adder.py
+++ b/src/memos/memories/textual/prefer_text_memory/adder.py
@@ -10,6 +10,7 @@
from memos.log import get_logger
from memos.memories.textual.item import TextualMemoryItem
from memos.templates.prefer_complete_prompt import (
+ NAIVE_JUDGE_DUP_WITH_TEXT_MEM_PROMPT,
NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT,
NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_FINE,
NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT_OP_TRACE,
@@ -24,7 +25,7 @@ class BaseAdder(ABC):
"""Abstract base class for adders."""
@abstractmethod
- def __init__(self, llm_provider=None, embedder=None, vector_db=None):
+ def __init__(self, llm_provider=None, embedder=None, vector_db=None, text_mem=None):
"""Initialize the adder."""
@abstractmethod
@@ -41,12 +42,13 @@ def add(self, memories: list[TextualMemoryItem | dict[str, Any]], *args, **kwarg
class NaiveAdder(BaseAdder):
"""Naive adder."""
- def __init__(self, llm_provider=None, embedder=None, vector_db=None):
+ def __init__(self, llm_provider=None, embedder=None, vector_db=None, text_mem=None):
"""Initialize the naive adder."""
- super().__init__(llm_provider, embedder, vector_db)
+ super().__init__(llm_provider, embedder, vector_db, text_mem)
self.llm_provider = llm_provider
self.embedder = embedder
self.vector_db = vector_db
+ self.text_mem = text_mem
def _judge_update_or_add_fast(self, old_msg: str, new_msg: str) -> bool:
"""Judge if the new message expresses the same core content as the old message."""
@@ -81,6 +83,44 @@ def _judge_update_or_add_fine(self, new_mem: str, retrieved_mems: str) -> dict[s
logger.error(f"Error in judge_update_or_add_fine: {e}")
return None
+ def _judge_dup_with_text_mem(self, new_pref: MilvusVecDBItem) -> bool:
+ """Judge if the new message is the same as the text memory for a single preference."""
+ if new_pref.payload["preference_type"] != "explicit_preference":
+ return False
+ text_recalls = self.text_mem.search(
+ query=new_pref.memory,
+ top_k=5,
+ info={
+ "user_id": new_pref.payload["user_id"],
+ "session_id": new_pref.payload["session_id"],
+ },
+ mode="fast",
+ search_filter={"session_id": new_pref.payload["session_id"]},
+ user_name=new_pref.payload["mem_cube_id"],
+ )
+
+ text_mem_recalls = [
+ {"id": text_recall.id, "memory": text_recall.memory} for text_recall in text_recalls
+ ]
+
+ if not text_mem_recalls:
+ return False
+
+ new_preference = {"id": new_pref.id, "memory": new_pref.payload["preference"]}
+
+ prompt = NAIVE_JUDGE_DUP_WITH_TEXT_MEM_PROMPT.replace(
+ "{new_preference}", json.dumps(new_preference, ensure_ascii=False)
+ ).replace("{retrieved_memories}", json.dumps(text_mem_recalls, ensure_ascii=False))
+ try:
+ response = self.llm_provider.generate([{"role": "user", "content": prompt}])
+ response = response.strip().replace("```json", "").replace("```", "").strip()
+ result = json.loads(response)
+ exists = result.get("exists", False)
+ return exists
+ except Exception as e:
+ logger.error(f"Error in judge_dup_with_text_mem: {e}")
+ return False
+
def _judge_update_or_add_trace_op(
self, new_mems: str, retrieved_mems: str
) -> dict[str, Any] | None:
@@ -98,6 +138,32 @@ def _judge_update_or_add_trace_op(
logger.error(f"Error in judge_update_or_add_trace_op: {e}")
return None
+ def _dedup_explicit_pref_by_textual(
+ self, new_prefs: list[MilvusVecDBItem]
+ ) -> list[MilvusVecDBItem]:
+ """Deduplicate explicit preferences by textual memory."""
+ if os.getenv("DEDUP_PREF_EXP_BY_TEXTUAL", "false").lower() != "true" or not self.text_mem:
+ return new_prefs
+ dedup_prefs = []
+ with ContextThreadPoolExecutor(max_workers=max(1, min(len(new_prefs), 5))) as executor:
+ future_to_idx = {
+ executor.submit(self._judge_dup_with_text_mem, new_pref): idx
+ for idx, new_pref in enumerate(new_prefs)
+ }
+ is_dup_flags = [False] * len(new_prefs)
+ for future in as_completed(future_to_idx):
+ idx = future_to_idx[future]
+ try:
+ is_dup_flags[idx] = future.result()
+ except Exception as e:
+ logger.error(
+ f"Error in _judge_dup_with_text_mem for pref {new_prefs[idx].id}: {e}"
+ )
+ is_dup_flags[idx] = False
+
+ dedup_prefs = [pref for idx, pref in enumerate(new_prefs) if not is_dup_flags[idx]]
+ return dedup_prefs
+
def _update_memory_op_trace(
self,
new_memories: list[TextualMemoryItem],
@@ -139,10 +205,17 @@ def _update_memory_op_trace(
]
rsp = self._judge_update_or_add_trace_op(
- new_mems=json.dumps(new_mem_inputs),
- retrieved_mems=json.dumps(retrieved_mem_inputs) if retrieved_mem_inputs else "",
+ new_mems=json.dumps(new_mem_inputs, ensure_ascii=False),
+ retrieved_mems=json.dumps(retrieved_mem_inputs, ensure_ascii=False)
+ if retrieved_mem_inputs
+ else "",
)
if not rsp:
+ dedup_rsp = self._dedup_explicit_pref_by_textual(new_vec_db_items)
+ if not dedup_rsp:
+ return []
+ else:
+ new_vec_db_items = dedup_rsp
with ContextThreadPoolExecutor(max_workers=min(len(new_vec_db_items), 5)) as executor:
futures = {
executor.submit(self.vector_db.add, collection_name, [db_item]): db_item
@@ -222,8 +295,10 @@ def _update_memory_fine(
if mem.payload.get("preference", None)
]
rsp = self._judge_update_or_add_fine(
- new_mem=json.dumps(new_mem_input),
- retrieved_mems=json.dumps(retrieved_mem_inputs) if retrieved_mem_inputs else "",
+ new_mem=json.dumps(new_mem_input, ensure_ascii=False),
+ retrieved_mems=json.dumps(retrieved_mem_inputs, ensure_ascii=False)
+ if retrieved_mem_inputs
+ else "",
)
need_update = rsp.get("need_update", False) if rsp else False
need_update = (
@@ -245,6 +320,9 @@ def _update_memory_fine(
self.vector_db.update(collection_name, rsp["id"], update_vec_db_item)
return rsp["id"]
else:
+ dedup_rsp = self._dedup_explicit_pref_by_textual([vec_db_item])
+ if not dedup_rsp:
+ return ""
self.vector_db.add(collection_name, [vec_db_item])
return vec_db_item.id
@@ -272,6 +350,9 @@ def _update_memory_fast(
old_msg_str = recall.memory
new_msg_str = new_memory.memory
is_same = self._judge_update_or_add_fast(old_msg=old_msg_str, new_msg=new_msg_str)
+ dedup_rsp = self._dedup_explicit_pref_by_textual([vec_db_item])
+ if not dedup_rsp:
+ return ""
if is_same:
vec_db_item.id = recall.id
self.vector_db.update(collection_name, recall.id, vec_db_item)
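`_dedup_explicit_pref_by_textual` fans the per-preference LLM judgment out over a thread pool and writes the results back by index, so input order is preserved and a failed judge call keeps the preference. Below is a self-contained sketch of the same gating and fan-out pattern, with a stubbed judge in place of the LLM call; the preference strings are invented.

```python
import os
from concurrent.futures import ThreadPoolExecutor, as_completed

def judge_is_duplicate(pref: str) -> bool:
    # Stub standing in for the LLM-backed duplicate check.
    return pref.startswith("dup:")

def dedup_explicit_prefs(prefs: list[str]) -> list[str]:
    if os.getenv("DEDUP_PREF_EXP_BY_TEXTUAL", "false").lower() != "true":
        return prefs  # feature is opt-in via environment variable
    flags = [False] * len(prefs)
    with ThreadPoolExecutor(max_workers=max(1, min(len(prefs), 5))) as pool:
        future_to_idx = {pool.submit(judge_is_duplicate, p): i for i, p in enumerate(prefs)}
        for future in as_completed(future_to_idx):
            idx = future_to_idx[future]
            try:
                flags[idx] = future.result()
            except Exception:
                flags[idx] = False  # on judge failure, keep the preference
    return [p for i, p in enumerate(prefs) if not flags[i]]

os.environ["DEDUP_PREF_EXP_BY_TEXTUAL"] = "true"
print(dedup_explicit_prefs(["likes jazz", "dup: likes jazz", "runs daily"]))
# ['likes jazz', 'runs daily']
```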
diff --git a/src/memos/memories/textual/prefer_text_memory/factory.py b/src/memos/memories/textual/prefer_text_memory/factory.py
index 22182261a..3c96b7dac 100644
--- a/src/memos/memories/textual/prefer_text_memory/factory.py
+++ b/src/memos/memories/textual/prefer_text_memory/factory.py
@@ -19,14 +19,21 @@ class AdderFactory(BaseAdder):
@classmethod
def from_config(
- cls, config_factory: AdderConfigFactory, llm_provider=None, embedder=None, vector_db=None
+ cls,
+ config_factory: AdderConfigFactory,
+ llm_provider=None,
+ embedder=None,
+ vector_db=None,
+ text_mem=None,
) -> BaseAdder:
"""Create a Adder instance from a configuration factory."""
backend = config_factory.backend
if backend not in cls.backend_to_class:
raise ValueError(f"Invalid backend: {backend}")
adder_class = cls.backend_to_class[backend]
- return adder_class(llm_provider=llm_provider, embedder=embedder, vector_db=vector_db)
+ return adder_class(
+ llm_provider=llm_provider, embedder=embedder, vector_db=vector_db, text_mem=text_mem
+ )
class ExtractorFactory(BaseExtractor):
diff --git a/src/memos/memories/textual/prefer_text_memory/retrievers.py b/src/memos/memories/textual/prefer_text_memory/retrievers.py
index 0074c3f1c..c3aa950e4 100644
--- a/src/memos/memories/textual/prefer_text_memory/retrievers.py
+++ b/src/memos/memories/textual/prefer_text_memory/retrievers.py
@@ -119,6 +119,9 @@ def retrieve(
if pref.payload.get("preference", None)
]
+ # Store explicit ids and scores so they can be reused after reranking
+ explicit_id_scores = {item.id: item.score for item in explicit_prefs}
+
reranker_map = {
"naive": self._naive_reranker,
"original_text": self._original_text_reranker,
@@ -131,4 +134,9 @@ def retrieve(
query=query, prefs_mem=implicit_prefs_mem, prefs=implicit_prefs, top_k=top_k
)
+ # Keep explicit memories whose score is at or above the threshold (0.0)
+ explicit_prefs_mem = [
+ item for item in explicit_prefs_mem if explicit_id_scores.get(item.id, 0) >= 0.0
+ ]
+
return explicit_prefs_mem + implicit_prefs_mem
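A tiny stand-alone illustration of the post-rerank gate added above: explicit preferences keep their pre-rerank vector scores and are dropped after reranking if the stored score falls below the threshold (hard-coded to 0.0 in this patch). The ids and scores here are invented.

```python
SCORE_THRESHOLD = 0.0  # matches the hard-coded value in the patch

explicit_id_scores = {"p1": 0.42, "p2": -0.10, "p3": 0.05}  # pre-rerank scores by id
reranked_ids = ["p2", "p1", "p3"]                            # order after reranking

kept = [pid for pid in reranked_ids if explicit_id_scores.get(pid, 0) >= SCORE_THRESHOLD]
print(kept)  # ['p1', 'p3']
```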
diff --git a/src/memos/memories/textual/simple_tree.py b/src/memos/memories/textual/simple_tree.py
index 313989cd2..05e62e3ee 100644
--- a/src/memos/memories/textual/simple_tree.py
+++ b/src/memos/memories/textual/simple_tree.py
@@ -1,7 +1,4 @@
-import time
-
-from datetime import datetime
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
from memos.configs.memory import TreeTextMemoryConfig
from memos.embedders.base import BaseEmbedder
@@ -9,13 +6,10 @@
from memos.llms.base import BaseLLM
from memos.log import get_logger
from memos.mem_reader.base import BaseMemReader
-from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
from memos.memories.textual.tree import TreeTextMemory
from memos.memories.textual.tree_text_memory.organize.manager import MemoryManager
from memos.memories.textual.tree_text_memory.retrieve.bm25_util import EnhancedBM25
-from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher
from memos.reranker.base import BaseReranker
-from memos.types import MessageList
if TYPE_CHECKING:
@@ -43,43 +37,22 @@ def __init__(
is_reorganize: bool = False,
):
"""Initialize memory with the given configuration."""
- time_start = time.time()
self.config: TreeTextMemoryConfig = config
self.mode = self.config.mode
logger.info(f"Tree mode is {self.mode}")
self.extractor_llm: OpenAILLM | OllamaLLM | AzureLLM = llm
- logger.info(f"time init: extractor_llm time is: {time.time() - time_start}")
-
- time_start_ex = time.time()
self.dispatcher_llm: OpenAILLM | OllamaLLM | AzureLLM = llm
- logger.info(f"time init: dispatcher_llm time is: {time.time() - time_start_ex}")
-
- time_start_em = time.time()
self.embedder: OllamaEmbedder = embedder
- logger.info(f"time init: embedder time is: {time.time() - time_start_em}")
-
- time_start_gs = time.time()
self.graph_store: Neo4jGraphDB = graph_db
- logger.info(f"time init: graph_store time is: {time.time() - time_start_gs}")
-
- time_start_bm = time.time()
self.search_strategy = config.search_strategy
self.bm25_retriever = (
EnhancedBM25()
if self.search_strategy and self.search_strategy.get("bm25", False)
else None
)
- logger.info(f"time init: bm25_retriever time is: {time.time() - time_start_bm}")
-
- time_start_rr = time.time()
self.reranker = reranker
- logger.info(f"time init: reranker time is: {time.time() - time_start_rr}")
-
- time_start_mm = time.time()
self.memory_manager: MemoryManager = memory_manager
- logger.info(f"time init: memory_manager time is: {time.time() - time_start_mm}")
- time_start_ir = time.time()
# Create internet retriever if configured
self.internet_retriever = None
if config.internet_retriever is not None:
@@ -89,223 +62,3 @@ def __init__(
)
else:
logger.info("No internet retriever configured")
- logger.info(f"time init: internet_retriever time is: {time.time() - time_start_ir}")
-
- def replace_working_memory(
- self, memories: list[TextualMemoryItem], user_name: str | None = None
- ) -> None:
- self.memory_manager.replace_working_memory(memories, user_name=user_name)
-
- def get_working_memory(self, user_name: str | None = None) -> list[TextualMemoryItem]:
- working_memories = self.graph_store.get_all_memory_items(
- scope="WorkingMemory", user_name=user_name
- )
- items = [TextualMemoryItem.from_dict(record) for record in (working_memories)]
- # Sort by updated_at in descending order
- sorted_items = sorted(
- items, key=lambda x: x.metadata.updated_at or datetime.min, reverse=True
- )
- return sorted_items
-
- def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int]:
- """
- Get the current size of each memory type.
- This delegates to the MemoryManager.
- """
- return self.memory_manager.get_current_memory_size(user_name=user_name)
-
- def get_searcher(
- self,
- manual_close_internet: bool = False,
- moscube: bool = False,
- ):
- if (self.internet_retriever is not None) and manual_close_internet:
- logger.warning(
- "Internet retriever is init by config , but this search set manual_close_internet is True and will close it"
- )
- searcher = Searcher(
- self.dispatcher_llm,
- self.graph_store,
- self.embedder,
- self.reranker,
- internet_retriever=None,
- moscube=moscube,
- )
- else:
- searcher = Searcher(
- self.dispatcher_llm,
- self.graph_store,
- self.embedder,
- self.reranker,
- internet_retriever=self.internet_retriever,
- moscube=moscube,
- )
- return searcher
-
- def search(
- self,
- query: str,
- top_k: int,
- info=None,
- mode: str = "fast",
- memory_type: str = "All",
- manual_close_internet: bool = False,
- moscube: bool = False,
- search_filter: dict | None = None,
- user_name: str | None = None,
- ) -> list[TextualMemoryItem]:
- """Search for memories based on a query.
- User query -> TaskGoalParser -> MemoryPathResolver ->
- GraphMemoryRetriever -> MemoryReranker -> MemoryReasoner -> Final output
- Args:
- query (str): The query to search for.
- top_k (int): The number of top results to return.
- info (dict): Leave a record of memory consumption.
- mode (str, optional): The mode of the search.
- - 'fast': Uses a faster search process, sacrificing some precision for speed.
- - 'fine': Uses a more detailed search process, invoking large models for higher precision, but slower performance.
- memory_type (str): Type restriction for search.
- ['All', 'WorkingMemory', 'LongTermMemory', 'UserMemory']
- manual_close_internet (bool): If True, the internet retriever will be closed by this search, it high priority than config.
- moscube (bool): whether you use moscube to answer questions
- search_filter (dict, optional): Optional metadata filters for search results.
- - Keys correspond to memory metadata fields (e.g., "user_id", "session_id").
- - Values are exact-match conditions.
- Example: {"user_id": "123", "session_id": "abc"}
- If None, no additional filtering is applied.
- Returns:
- list[TextualMemoryItem]: List of matching memories.
- """
- if (self.internet_retriever is not None) and manual_close_internet:
- searcher = Searcher(
- self.dispatcher_llm,
- self.graph_store,
- self.embedder,
- self.reranker,
- bm25_retriever=self.bm25_retriever,
- internet_retriever=None,
- moscube=moscube,
- search_strategy=self.search_strategy,
- )
- else:
- searcher = Searcher(
- self.dispatcher_llm,
- self.graph_store,
- self.embedder,
- self.reranker,
- bm25_retriever=self.bm25_retriever,
- internet_retriever=self.internet_retriever,
- moscube=moscube,
- search_strategy=self.search_strategy,
- )
- return searcher.search(
- query, top_k, info, mode, memory_type, search_filter, user_name=user_name
- )
-
- def get_relevant_subgraph(
- self, query: str, top_k: int = 5, depth: int = 2, center_status: str = "activated"
- ) -> dict[str, Any]:
- """
- Find and merge the local neighborhood sub-graphs of the top-k
- nodes most relevant to the query.
- Process:
- 1. Embed the user query into a vector representation.
- 2. Use vector similarity search to find the top-k similar nodes.
- 3. For each similar node:
- - Ensure its status matches `center_status` (e.g., 'active').
- - Retrieve its local subgraph up to `depth` hops.
- - Collect the center node, its neighbors, and connecting edges.
- 4. Merge all retrieved subgraphs into a single unified subgraph.
- 5. Return the merged subgraph structure.
-
- Args:
- query (str): The user input or concept to find relevant memories for.
- top_k (int, optional): How many top similar nodes to retrieve. Default is 5.
- depth (int, optional): The neighborhood depth (number of hops). Default is 2.
- center_status (str, optional): Status condition the center node must satisfy (e.g., 'active').
-
- Returns:
- dict[str, Any]: A subgraph dict with:
- - 'core_id': ID of the top matching core node, or None if none found.
- - 'nodes': List of unique nodes (core + neighbors) in the merged subgraph.
- - 'edges': List of unique edges (as dicts with 'from', 'to', 'type') in the merged subgraph.
- """
- # Step 1: Embed query
- query_embedding = self.embedder.embed([query])[0]
-
- # Step 2: Get top-1 similar node
- similar_nodes = self.graph_store.search_by_embedding(query_embedding, top_k=top_k)
- if not similar_nodes:
- logger.info("No similar nodes found for query embedding.")
- return {"core_id": None, "nodes": [], "edges": []}
-
- # Step 3: Fetch neighborhood
- all_nodes = {}
- all_edges = set()
- cores = []
-
- for node in similar_nodes:
- core_id = node["id"]
- score = node["score"]
-
- subgraph = self.graph_store.get_subgraph(
- center_id=core_id, depth=depth, center_status=center_status
- )
-
- if not subgraph["core_node"]:
- logger.info(f"Skipping node {core_id} (inactive or not found).")
- continue
-
- core_node = subgraph["core_node"]
- neighbors = subgraph["neighbors"]
- edges = subgraph["edges"]
-
- # Collect nodes
- all_nodes[core_node["id"]] = core_node
- for n in neighbors:
- all_nodes[n["id"]] = n
-
- # Collect edges
- for e in edges:
- all_edges.add((e["source"], e["target"], e["type"]))
-
- cores.append(
- {"id": core_id, "score": score, "core_node": core_node, "neighbors": neighbors}
- )
-
- top_core = cores[0]
- return {
- "core_id": top_core["id"],
- "nodes": list(all_nodes.values()),
- "edges": [{"source": f, "target": t, "type": ty} for (f, t, ty) in all_edges],
- }
-
- def extract(self, messages: MessageList) -> list[TextualMemoryItem]:
- raise NotImplementedError
-
- def update(self, memory_id: str, new_memory: TextualMemoryItem | dict[str, Any]) -> None:
- raise NotImplementedError
-
- def get(self, memory_id: str) -> TextualMemoryItem:
- """Get a memory by its ID."""
- result = self.graph_store.get_node(memory_id)
- if result is None:
- raise ValueError(f"Memory with ID {memory_id} not found")
- metadata_dict = result.get("metadata", {})
- return TextualMemoryItem(
- id=result["id"],
- memory=result["memory"],
- metadata=TreeNodeTextualMemoryMetadata(**metadata_dict),
- )
-
- def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]:
- raise NotImplementedError
-
- def delete_all(self) -> None:
- """Delete all memories and their relationships from the graph store."""
- try:
- self.graph_store.clear()
- logger.info("All memories and edges have been deleted from the graph.")
- except Exception as e:
- logger.error(f"An error occurred while deleting all memories: {e}")
- raise
diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py
index dea3cc1ab..1b2355bc8 100644
--- a/src/memos/memories/textual/tree.py
+++ b/src/memos/memories/textual/tree.py
@@ -103,11 +103,15 @@ def add(
"""
return self.memory_manager.add(memories, user_name=user_name, mode=self.mode)
- def replace_working_memory(self, memories: list[TextualMemoryItem]) -> None:
- self.memory_manager.replace_working_memory(memories)
-
- def get_working_memory(self) -> list[TextualMemoryItem]:
- working_memories = self.graph_store.get_all_memory_items(scope="WorkingMemory")
+ def replace_working_memory(
+ self, memories: list[TextualMemoryItem], user_name: str | None = None
+ ) -> None:
+ self.memory_manager.replace_working_memory(memories, user_name=user_name)
+
+ def get_working_memory(self, user_name: str | None = None) -> list[TextualMemoryItem]:
+ working_memories = self.graph_store.get_all_memory_items(
+ scope="WorkingMemory", user_name=user_name
+ )
items = [TextualMemoryItem.from_dict(record) for record in (working_memories)]
# Sort by updated_at in descending order
sorted_items = sorted(
@@ -115,12 +119,12 @@ def get_working_memory(self) -> list[TextualMemoryItem]:
)
return sorted_items
- def get_current_memory_size(self) -> dict[str, int]:
+ def get_current_memory_size(self, user_name: str | None = None) -> dict[str, int]:
"""
Get the current size of each memory type.
This delegates to the MemoryManager.
"""
- return self.memory_manager.get_current_memory_size()
+ return self.memory_manager.get_current_memory_size(user_name=user_name)
def get_searcher(
self,
@@ -157,9 +161,10 @@ def search(
info=None,
mode: str = "fast",
memory_type: str = "All",
- manual_close_internet: bool = False,
+ manual_close_internet: bool = True,
moscube: bool = False,
search_filter: dict | None = None,
+ user_name: str | None = None,
) -> list[TextualMemoryItem]:
"""Search for memories based on a query.
User query -> TaskGoalParser -> MemoryPathResolver ->
@@ -184,9 +189,6 @@ def search(
list[TextualMemoryItem]: List of matching memories.
"""
if (self.internet_retriever is not None) and manual_close_internet:
- logger.warning(
- "Internet retriever is init by config , but this search set manual_close_internet is True and will close it"
- )
searcher = Searcher(
self.dispatcher_llm,
self.graph_store,
@@ -196,6 +198,7 @@ def search(
internet_retriever=None,
moscube=moscube,
search_strategy=self.search_strategy,
+ manual_close_internet=manual_close_internet,
)
else:
searcher = Searcher(
@@ -207,11 +210,19 @@ def search(
internet_retriever=self.internet_retriever,
moscube=moscube,
search_strategy=self.search_strategy,
+ manual_close_internet=manual_close_internet,
)
- return searcher.search(query, top_k, info, mode, memory_type, search_filter)
+ return searcher.search(
+ query, top_k, info, mode, memory_type, search_filter, user_name=user_name
+ )
def get_relevant_subgraph(
- self, query: str, top_k: int = 5, depth: int = 2, center_status: str = "activated"
+ self,
+ query: str,
+ top_k: int = 5,
+ depth: int = 2,
+ center_status: str = "activated",
+ user_name: str | None = None,
) -> dict[str, Any]:
"""
Find and merge the local neighborhood sub-graphs of the top-k
@@ -242,7 +253,9 @@ def get_relevant_subgraph(
query_embedding = self.embedder.embed([query])[0]
# Step 2: Get top-1 similar node
- similar_nodes = self.graph_store.search_by_embedding(query_embedding, top_k=top_k)
+ similar_nodes = self.graph_store.search_by_embedding(
+ query_embedding, top_k=top_k, user_name=user_name
+ )
if not similar_nodes:
logger.info("No similar nodes found for query embedding.")
return {"core_id": None, "nodes": [], "edges": []}
@@ -257,7 +270,7 @@ def get_relevant_subgraph(
score = node["score"]
subgraph = self.graph_store.get_subgraph(
- center_id=core_id, depth=depth, center_status=center_status
+ center_id=core_id, depth=depth, center_status=center_status, user_name=user_name
)
if subgraph is None or not subgraph["core_node"]:
@@ -306,7 +319,9 @@ def get(self, memory_id: str) -> TextualMemoryItem:
metadata=TreeNodeTextualMemoryMetadata(**metadata_dict),
)
- def get_by_ids(self, memory_ids: list[str]) -> list[TextualMemoryItem]:
+ def get_by_ids(
+ self, memory_ids: list[str], user_name: str | None = None
+ ) -> list[TextualMemoryItem]:
raise NotImplementedError
def get_all(self, user_name: str | None = None) -> dict:
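Note that `manual_close_internet` now defaults to `True` in `TreeTextMemory.search`, so internet retrieval is skipped unless the caller opts back in. A hypothetical stand-in function (not the real class) showing the behavioural consequence of the flipped default:

```python
# Stand-in mirroring only the new default, not the real search pipeline.
def search(query: str, top_k: int = 10, manual_close_internet: bool = True) -> list[str]:
    results = [f"graph/vector hit for {query!r}"]
    if not manual_close_internet:
        results.append(f"internet hit for {query!r}")
    return results[:top_k]

print(search("recent trips"))                               # internet skipped (new default)
print(search("recent trips", manual_close_internet=False))  # internet retrieval allowed
```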
diff --git a/src/memos/memories/textual/tree_text_memory/organize/manager.py b/src/memos/memories/textual/tree_text_memory/organize/manager.py
index 0c41717ea..a71fee02f 100644
--- a/src/memos/memories/textual/tree_text_memory/organize/manager.py
+++ b/src/memos/memories/textual/tree_text_memory/organize/manager.py
@@ -92,7 +92,7 @@ def add(
"""
added_ids: list[str] = []
- with ContextThreadPoolExecutor(max_workers=20) as executor:
+ with ContextThreadPoolExecutor(max_workers=200) as executor:
futures = {executor.submit(self._process_memory, m, user_name): m for m in memories}
for future in as_completed(futures, timeout=60):
try:
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
index 8cf2f47f3..375048900 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/recall.py
@@ -66,7 +66,7 @@ def retrieve(
working_memories = self.graph_store.get_all_memory_items(
scope="WorkingMemory", include_embedding=False, user_name=user_name
)
- return [TextualMemoryItem.from_dict(record) for record in working_memories]
+ return [TextualMemoryItem.from_dict(record) for record in working_memories[:top_k]]
with ContextThreadPoolExecutor(max_workers=3) as executor:
# Structured graph-based retrieval
@@ -104,15 +104,6 @@ def retrieve(
# Merge and deduplicate by ID
combined = {item.id: item for item in graph_results + vector_results + bm25_results}
- graph_ids = {item.id for item in graph_results}
- combined_ids = set(combined.keys())
- lost_ids = graph_ids - combined_ids
-
- if lost_ids:
- print(
- f"[DEBUG] The following nodes were in graph_results but missing in combined: {lost_ids}"
- )
-
return list(combined.values())
def retrieve_from_cube(
@@ -150,15 +141,6 @@ def retrieve_from_cube(
# Merge and deduplicate by ID
combined = {item.id: item for item in graph_results}
- graph_ids = {item.id for item in graph_results}
- combined_ids = set(combined.keys())
- lost_ids = graph_ids - combined_ids
-
- if lost_ids:
- print(
- f"[DEBUG] The following nodes were in graph_results but missing in combined: {lost_ids}"
- )
-
return list(combined.values())
def _graph_recall(
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
index f408755fd..933ef5af1 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py
@@ -43,6 +43,7 @@ def __init__(
internet_retriever: None = None,
moscube: bool = False,
search_strategy: dict | None = None,
+ manual_close_internet: bool = True,
):
self.graph_store = graph_store
self.embedder = embedder
@@ -58,7 +59,7 @@ def __init__(
self.moscube = moscube
self.vec_cot = search_strategy.get("cot", False) if search_strategy else False
self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False
-
+ self.manual_close_internet = manual_close_internet
self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage")
@timed
@@ -72,7 +73,7 @@ def retrieve(
search_filter: dict | None = None,
user_name: str | None = None,
**kwargs,
- ) -> list[TextualMemoryItem]:
+ ) -> list[tuple[TextualMemoryItem, float]]:
logger.info(
f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}"
)
@@ -94,7 +95,7 @@ def retrieve(
def post_retrieve(
self,
- retrieved_results: list[TextualMemoryItem],
+ retrieved_results: list[tuple[TextualMemoryItem, float]],
top_k: int,
user_name: str | None = None,
info=None,
@@ -108,7 +109,7 @@ def post_retrieve(
def search(
self,
query: str,
- top_k: int,
+ top_k: int = 10,
info=None,
mode="fast",
memory_type="All",
@@ -182,7 +183,7 @@ def _parse_task(
query_embedding = None
# fine mode will trigger initial embedding search
- if mode == "fine":
+ if mode == "fine_old":
logger.info("[SEARCH] Fine mode: embedding search")
query_embedding = self.embedder.embed([query])[0]
@@ -458,7 +459,7 @@ def _retrieve_from_internet(
user_id: str | None = None,
):
"""Retrieve and rerank from Internet source"""
- if not self.internet_retriever or mode == "fast":
+ if not self.internet_retriever or self.manual_close_internet:
-            logger.info(f"[PATH-C] '{query}' Skipped (no retriever, fast mode)")
+            logger.info(f"[PATH-C] '{query}' Skipped (no retriever or internet retrieval manually closed)")
return []
if memory_type not in ["All"]:
diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
index 55e33494c..b9814f079 100644
--- a/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
+++ b/src/memos/memories/textual/tree_text_memory/retrieve/task_goal_parser.py
@@ -22,6 +22,7 @@ class TaskGoalParser:
def __init__(self, llm=BaseLLM):
self.llm = llm
self.tokenizer = FastTokenizer()
+ self.retries = 1
def parse(
self,
@@ -103,18 +104,24 @@ def _parse_response(self, response: str, **kwargs) -> ParsedTaskGoal:
"""
Parse LLM JSON output safely.
"""
- try:
- context = kwargs.get("context", "")
- response = response.replace("```", "").replace("json", "").strip()
- response_json = eval(response)
- return ParsedTaskGoal(
- memories=response_json.get("memories", []),
- keys=response_json.get("keys", []),
- tags=response_json.get("tags", []),
- rephrased_query=response_json.get("rephrased_instruction", None),
- internet_search=response_json.get("internet_search", False),
- goal_type=response_json.get("goal_type", "default"),
- context=context,
- )
- except Exception as e:
- raise ValueError(f"Failed to parse LLM output: {e}\nRaw response:\n{response}") from e
+ # Ensure at least one attempt
+ attempts = max(1, getattr(self, "retries", 1))
+
+ for attempt_times in range(attempts):
+ try:
+ context = kwargs.get("context", "")
+ response = response.replace("```", "").replace("json", "").strip()
+ response_json = eval(response)
+ return ParsedTaskGoal(
+ memories=response_json.get("memories", []),
+ keys=response_json.get("keys", []),
+ tags=response_json.get("tags", []),
+ rephrased_query=response_json.get("rephrased_instruction", None),
+ internet_search=response_json.get("internet_search", False),
+ goal_type=response_json.get("goal_type", "default"),
+ context=context,
+ )
+ except Exception as e:
+ raise ValueError(
+ f"Failed to parse LLM output: {e}\nRaw response:\n{response} (attempt {attempt_times + 1}/{attempts})"
+ ) from e
diff --git a/src/memos/memos_tools/dinding_report_bot.py b/src/memos/memos_tools/dinding_report_bot.py
index 9791cf65a..d8b762855 100644
--- a/src/memos/memos_tools/dinding_report_bot.py
+++ b/src/memos/memos_tools/dinding_report_bot.py
@@ -7,6 +7,7 @@
import json
import os
import time
+import traceback
import urllib.parse
from datetime import datetime
@@ -14,6 +15,11 @@
from dotenv import load_dotenv
+from memos.log import get_logger
+
+
+logger = get_logger(__name__)
+
load_dotenv()
@@ -57,6 +63,20 @@
ROBOT_CODE = os.getenv("DINGDING_ROBOT_CODE")
DING_APP_KEY = os.getenv("DINGDING_APP_KEY")
DING_APP_SECRET = os.getenv("DINGDING_APP_SECRET")
+ENV_NAME = os.getenv("ENV_NAME", "PLAYGROUND_OFFLINE")
+
+theme_map = {
+ "ONLINE": {
+ "color": "#2196F3",
+ "grad": ("#E3F2FD", "#BBDEFB"),
+ "emoji": "🩵",
+ },
+ "OFFLINE": {
+ "color": "#FFC107",
+ "grad": ("#FFF8E1", "#FFECB3"),
+ "emoji": "🤍",
+ },
+}
# Get access_token
@@ -311,7 +331,7 @@ def error_bot(
)
# ---------- Markdown ----------
- colored_title = f"{title}"
+ colored_title = f"{ENV_NAME}"
at_suffix = ""
if user_ids:
at_suffix = "\n\n" + " ".join([f"@{m}" for m in user_ids])
@@ -367,41 +387,52 @@ def online_bot(
other_data2: dict,
emoji: dict,
):
- heading_color = "#00956D" # Green for subtitle
-
- # 0) Banner
- banner_bytes = make_header(header_name, sub_title_name)
- banner_url = upload_bytes_to_oss(banner_bytes, filename="online_report.png")
-
- # 1) Colored main title
- colored_title = f"{header_name}"
-
- # 3) Markdown
- md = "\n\n".join(
- filter(
- None,
- [
- f"",
- f"### 🙄 {colored_title}\n\n",
- _kv_lines(
- other_data1,
- next(iter(emoji.keys())),
- next(iter(emoji.values())),
- heading_color=heading_color,
- ),
- _kv_lines(
- other_data2,
- list(emoji.keys())[1],
- list(emoji.values())[1],
- heading_color=heading_color,
- ),
- f"Time: "
- f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n",
- ],
+ try:
+ logger.info("Building online report for DingTalk bot")
+ theme = "OFFLINE" if "OFFLINE" in ENV_NAME or "TEST" in ENV_NAME else "ONLINE"
+ style = theme_map.get(theme, theme_map["OFFLINE"])
+ heading_color = style["color"] # Use theme color for subtitle
+
+ # 0) Banner
+ banner_bytes = make_header(
+ header_name,
+ sub_title_name,
+ colors=style["grad"],
+ fg=style["color"],
+ )
+ banner_url = upload_bytes_to_oss(banner_bytes, filename=f"{ENV_NAME}_online_report.png")
+
+ # 1) Colored main title
+ colored_title = f"{ENV_NAME}"
+
+ # 3) Markdown
+ md = "\n\n".join(
+ filter(
+ None,
+ [
+ f"",
+ f"### {style['emoji']} {colored_title}\n\n",
+ _kv_lines(
+ other_data1,
+ next(iter(emoji.keys())),
+ next(iter(emoji.values())),
+ heading_color=heading_color,
+ ),
+ _kv_lines(
+ other_data2,
+ list(emoji.keys())[1],
+ list(emoji.values())[1],
+ heading_color=heading_color,
+ ),
+ f"Time: "
+ f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n",
+ ],
+ )
)
- )
- _send_md(colored_title, md, type="user")
+ _send_md(colored_title, md, type="user")
+ except Exception:
+ logger.error(traceback.format_exc())
if __name__ == "__main__":
diff --git a/src/memos/reranker/http_bge.py b/src/memos/reranker/http_bge.py
index 41011df14..db5a51fc2 100644
--- a/src/memos/reranker/http_bge.py
+++ b/src/memos/reranker/http_bge.py
@@ -119,7 +119,7 @@ def __init__(
self.warn_unknown_filter_keys = bool(warn_unknown_filter_keys)
self._warned_missing_keys: set[str] = set()
- @timed(log=True, log_prefix="RerankerAPI")
+ @timed(log=True, log_prefix="model_timed_rerank")
def rerank(
self,
query: str,
diff --git a/src/memos/templates/mem_scheduler_prompts.py b/src/memos/templates/mem_scheduler_prompts.py
index b4d091c1f..7f7415e79 100644
--- a/src/memos/templates/mem_scheduler_prompts.py
+++ b/src/memos/templates/mem_scheduler_prompts.py
@@ -390,6 +390,122 @@
- Focus on whether the memories can fully answer the query without additional information
"""
+MEMORY_RECREATE_ENHANCEMENT_PROMPT = """
+You are a knowledgeable and precise AI assistant.
+
+# GOAL
+Transform raw memories into clean, query-relevant facts — preserving timestamps and resolving ambiguities without inference.
+
+# RULES & THINKING STEPS
+1. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely.
+2. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”, “after injury”).
+3. Resolve all ambiguities using only memory content:
+ - Pronouns → full name: “she” → “Melanie”
+ - Vague nouns → specific detail: “home” → “her childhood home in Guangzhou”
+ - “the user” → identity from context (e.g., “Melanie” if travel/running memories)
+4. Never invent, assume, or extrapolate.
+5. Each output line must be a standalone, clear, factual statement.
+6. Output format: one line per fact, starting with "- ", no extra text.
+
+# OUTPUT FORMAT (STRICT)
+Return ONLY the following block, with **one enhanced memory per line**.
+Each line MUST start with "- " (dash + space).
+
+Wrap the final output inside:
+<answer>
+- enhanced memory 1
+- enhanced memory 2
+...
+</answer>
+
+## User Query
+{query_history}
+
+## Original Memories
+{memories}
+
+Final Output:
+"""
+
+# Rewrite version: return enhanced memories with original IDs
+MEMORY_REWRITE_ENHANCEMENT_PROMPT = """
+You are a knowledgeable and precise AI assistant.
+
+# GOAL
+Transform raw memories into clean, query-relevant facts — preserving timestamps and resolving ambiguities without inference. Return each enhanced fact with the ID of the original memory being modified.
+
+# RULES & THINKING STEPS
+1. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely.
+2. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”, “after injury”).
+3. Resolve all ambiguities using only memory content:
+ - Pronouns → full name: “she” → “Melanie”
+ - Vague nouns → specific detail: “home” → “her childhood home in Guangzhou”
+ - “the user” → identity from context (e.g., “Melanie” if travel/running memories)
+4. Never invent, assume, or extrapolate.
+5. Each output line must be a standalone, clear, factual statement.
+6. Output format: one line per fact, starting with "- ", no extra text.
+
+# IMPORTANT FOR REWRITE
+- Each output line MUST include the original memory’s ID shown in the input list.
+- Use the index shown for each original memory (e.g., "[0]", "[1]") as the ID to reference which memory you are rewriting.
+- For every rewritten line, prefix with the corresponding index in square brackets.
+
+# OUTPUT FORMAT (STRICT)
+Return ONLY the following block, with **one enhanced memory per line**.
+Each line MUST start with "- " (dash + space) AND include index in square brackets.
+
+Wrap the final output inside:
+<answer>
+- [index] enhanced memory 1
+- [index] enhanced memory 2
+...
+</answer>
+
+## User Query
+{query_history}
+
+## Original Memories
+{memories}
+
+Final Output:
+"""
+
+# One-sentence prompt for recalling missing information to answer the query (English)
+ENLARGE_RECALL_PROMPT_ONE_SENTENCE = """
+You are a precise AI assistant. Your job is to analyze the user's query and the available memories to identify what specific information is missing to fully answer the query.
+
+# GOAL
+
+Identify the specific missing facts needed to fully answer the user's query and generate a concise hint for recalling them.
+
+# RULES
+
+- Analyze the user's query to understand what information is being asked.
+- Review the available memories to see what information is already present.
+- Identify the gap between the user's query and the available memories.
+- Generate a single, concise hint that prompts the user to provide the missing information.
+- The hint should be a direct question or a statement that clearly indicates what is needed.
+
+# OUTPUT FORMAT
+A JSON object with:
+
+trigger_recall: true if information is missing, false if sufficient.
+hint: A clear, specific prompt to retrieve the missing information (or an empty string if trigger_recall is false):
+{{
+ "trigger_recall": true/false,
+ "hint": "a paraphrase to retrieve supporting memories"
+}}
+
+## User Query
+{query}
+
+## Available Memories
+{memories_inline}
+
+Final Output:
+"""
+
+
PROMPT_MAPPING = {
"intent_recognizing": INTENT_RECOGNIZING_PROMPT,
"memory_reranking": MEMORY_RERANKING_PROMPT,
@@ -398,6 +514,9 @@
"memory_redundancy_filtering": MEMORY_REDUNDANCY_FILTERING_PROMPT,
"memory_combined_filtering": MEMORY_COMBINED_FILTERING_PROMPT,
"memory_answer_ability_evaluation": MEMORY_ANSWER_ABILITY_EVALUATION_PROMPT,
+ "memory_recreate_enhancement": MEMORY_RECREATE_ENHANCEMENT_PROMPT,
+ "memory_rewrite_enhancement": MEMORY_REWRITE_ENHANCEMENT_PROMPT,
+ "enlarge_recall": ENLARGE_RECALL_PROMPT_ONE_SENTENCE,
}
MEMORY_ASSEMBLY_TEMPLATE = """The retrieved memories are listed as follows:\n\n {memory_text}"""
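The rewrite-enhancement prompt asks the model for `- [index] fact` lines wrapped in `<answer>` tags. Assuming it is paired with the bullet parser added in `misc_utils.py`, the original indices can be recovered as sketched below; the sample LLM output string is fabricated for the example.

```python
import re

from memos.mem_scheduler.utils.misc_utils import extract_list_items_in_answer

llm_output = (
    "<answer>\n"
    "- [0] Melanie ran a half-marathon on October 6\n"
    "- [2] Melanie's childhood home is in Guangzhou\n"
    "</answer>"
)

for item in extract_list_items_in_answer(llm_output):
    match = re.match(r"\[(\d+)\]\s*(.*)", item)
    if match:
        original_idx, fact = int(match.group(1)), match.group(2)
        print(original_idx, fact)
# 0 Melanie ran a half-marathon on October 6
# 2 Melanie's childhood home is in Guangzhou
```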
diff --git a/src/memos/templates/prefer_complete_prompt.py b/src/memos/templates/prefer_complete_prompt.py
index 9e0274cba..3a468b943 100644
--- a/src/memos/templates/prefer_complete_prompt.py
+++ b/src/memos/templates/prefer_complete_prompt.py
@@ -132,6 +132,74 @@
"""
+NAIVE_JUDGE_DUP_WITH_TEXT_MEM_PROMPT = """
+You are a content comparison expert. Your task is to determine whether each new preference information already exists in the retrieved text memories.
+
+**Task:** For each new preference, check if its content/topic/intent is already present in any of the retrieved text memories.
+
+**Input Structure:**
+- New preferences: Array of objects, each with "id" and "memory" fields
+- Retrieved memories: Array of objects, each with "id" and "memory" fields
+
+**Judgment Criteria:**
+- If the core content, topic, or intent of a new preference is **already covered** in any retrieved memory, mark as "exists" (true).
+- Consider both semantic similarity and topic overlap - even if wording differs, if the meaning is the same, it counts as existing.
+- If the new preference introduces **new information, different topic, or unique content** not found in retrieved memories, mark as "exists" (false).
+- Focus on the substantive content rather than minor phrasing differences.
+
+**Output Format (JSON):**
+```json
+{
+ "new_preference_id": "ID of the new preference being evaluated",
+ "exists": true/false,
+ "reasoning": "Brief explanation of your judgment, citing which retrieved memory contains similar content (if exists=true) or why it's new content (if exists=false)",
+ "matched_memory_id": "If exists=true, indicate which retrieved memory id matches; otherwise null"
+}
+```
+**New Preferences (array):**
+{new_preference}
+
+**Retrieved Text Memories (array):**
+{retrieved_memories}
+
+Output only the JSON response, no additional text.
+"""
+
+
+NAIVE_JUDGE_DUP_WITH_TEXT_MEM_PROMPT_ZH = """
+你是一个内容比较专家。你的任务是判断每个新的偏好信息是否已经存在于召回的文本记忆中。
+
+**任务:** 对每个新偏好,检查其内容/主题/意图是否已经在任何召回的文本记忆中存在。
+
+**输入结构:**
+- 新偏好:对象数组,每个对象包含"id"和"memory"字段
+- 召回记忆:对象数组,每个对象包含"id"和"memory"字段
+
+**判断标准:**
+- 如果新偏好的核心内容、主题或意图**已经被覆盖**在任何召回的记忆中,标记为"exists"(true)。
+- 考虑语义相似性和主题重叠 - 即使措辞不同,如果含义相同,也算作已存在。
+- 如果新偏好引入了**新信息、不同主题或独特内容**,且在召回记忆中未找到,标记为"exists"(false)。
+- 关注实质性内容,而非细微的表达差异。
+
+**输出格式(JSON):**
+```json
+{
+ "new_preference_id": "正在评估的新偏好ID",
+ "exists": true/false,
+ "reasoning": "简要说明你的判断理由,引用包含相似内容的召回记忆(如果exists=true)或说明为什么是新内容(如果exists=false)",
+ "matched_memory_id": "如果exists=true,指出匹配的召回记忆id;否则为null"
+}
+```
+**新偏好(数组):**
+{new_preference}
+
+**召回的文本记忆(数组):**
+{retrieved_memories}
+
+只输出JSON响应,不要输出其他任何文本。
+"""
+
+
NAIVE_JUDGE_UPDATE_OR_ADD_PROMPT = """
You are a content comparison expert. Now you are given old and new information, each containing a question, answer topic name and topic description.
Please judge whether these two information express the **same question or core content**, regardless of expression differences, details or example differences. The judgment criteria are as follows:
diff --git a/src/memos/utils.py b/src/memos/utils.py
index 08934ed34..4b1a59834 100644
--- a/src/memos/utils.py
+++ b/src/memos/utils.py
@@ -6,7 +6,7 @@
logger = get_logger(__name__)
-def timed(func=None, *, log=False, log_prefix=""):
+def timed(func=None, *, log=True, log_prefix=""):
"""Decorator to measure and optionally log time of retrieval steps.
Can be used as @timed or @timed(log=True)
@@ -17,8 +17,9 @@ def wrapper(*args, **kwargs):
start = time.perf_counter()
result = fn(*args, **kwargs)
elapsed = time.perf_counter() - start
+ elapsed_ms = elapsed * 1000.0
if log:
- logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed:.2f} seconds")
+ logger.info(f"[TIMER] {log_prefix or fn.__name__} took {elapsed_ms:.0f} ms")
return result
return wrapper
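With `log=True` now the default, any function decorated with `@timed` reports its elapsed time in milliseconds through the project logger. A minimal usage sketch (the actual log line depends on the logging configuration):

```python
import time

from memos.utils import timed

@timed(log_prefix="demo_sleep")
def slow_step() -> str:
    time.sleep(0.25)
    return "done"

slow_step()
# Expected log line (approximate): [TIMER] demo_sleep took 250 ms
```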
diff --git a/src/memos/vec_dbs/milvus.py b/src/memos/vec_dbs/milvus.py
index e50c8ce18..eafee2633 100644
--- a/src/memos/vec_dbs/milvus.py
+++ b/src/memos/vec_dbs/milvus.py
@@ -236,29 +236,32 @@ def search(
"sparse": self._sparse_search,
"hybrid": self._hybrid_search,
}
+ try:
+ results = search_func_map[search_type](
+ collection_name=collection_name,
+ query_vector=query_vector,
+ query=query,
+ top_k=top_k,
+ filter=expr,
+ )
- results = search_func_map[search_type](
- collection_name=collection_name,
- query_vector=query_vector,
- query=query,
- top_k=top_k,
- filter=expr,
- )
-
- items = []
- for hit in results[0]:
- entity = hit.get("entity", {})
-
- items.append(
- MilvusVecDBItem(
- id=str(entity.get("id")),
- memory=entity.get("memory"),
- original_text=entity.get("original_text"),
- vector=entity.get("vector"),
- payload=entity.get("payload", {}),
- score=1 - float(hit["distance"]),
+ items = []
+ for hit in results[0]:
+ entity = hit.get("entity", {})
+
+ items.append(
+ MilvusVecDBItem(
+ id=str(entity.get("id")),
+ memory=entity.get("memory"),
+ original_text=entity.get("original_text"),
+ vector=entity.get("vector"),
+ payload=entity.get("payload", {}),
+ score=1 - float(hit["distance"]),
+ )
)
- )
+ except Exception as e:
+ logger.error("Error in _%s_search: %s", search_type, e)
+ return []
logger.info(f"Milvus search completed with {len(items)} results.")
return items
diff --git a/tests/mem_scheduler/test_dispatcher.py b/tests/mem_scheduler/test_dispatcher.py
index e3064660b..e687d2986 100644
--- a/tests/mem_scheduler/test_dispatcher.py
+++ b/tests/mem_scheduler/test_dispatcher.py
@@ -14,10 +14,11 @@
)
from memos.llms.base import BaseLLM
from memos.mem_cube.general import GeneralMemCube
-from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher
from memos.mem_scheduler.scheduler_factory import SchedulerFactory
from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem
from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem
+from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher
+from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube
from memos.memories.textual.tree import TreeTextMemory
@@ -90,7 +91,6 @@ def setUp(self):
ScheduleMessageItem(
item_id="msg1",
user_id="user1",
- mem_cube="cube1",
mem_cube_id="msg1",
label="label1",
content="Test content 1",
@@ -99,7 +99,6 @@ def setUp(self):
ScheduleMessageItem(
item_id="msg2",
user_id="user1",
- mem_cube="cube1",
mem_cube_id="msg2",
label="label2",
content="Test content 2",
@@ -108,7 +107,6 @@ def setUp(self):
ScheduleMessageItem(
item_id="msg3",
user_id="user2",
- mem_cube="cube2",
mem_cube_id="msg3",
label="label1",
content="Test content 3",
@@ -193,51 +191,10 @@ def test_dispatch_serial(self):
self.assertEqual(len(label2_messages), 1)
self.assertEqual(label2_messages[0].item_id, "msg2")
- def test_dispatch_parallel(self):
- """Test dispatching messages in parallel mode."""
- # Create fresh mock handlers for this test
- mock_handler1 = MagicMock()
- mock_handler2 = MagicMock()
-
- # Create a new dispatcher for this test to avoid interference
- parallel_dispatcher = SchedulerDispatcher(max_workers=2, enable_parallel_dispatch=True)
- parallel_dispatcher.register_handler("label1", mock_handler1)
- parallel_dispatcher.register_handler("label2", mock_handler2)
-
- # Dispatch messages
- parallel_dispatcher.dispatch(self.test_messages)
-
- # Wait for all futures to complete
- parallel_dispatcher.join(timeout=1.0)
-
- # Verify handlers were called - label1 handler should be called twice (for user1 and user2)
- # label2 handler should be called once (only for user1)
- self.assertEqual(mock_handler1.call_count, 2) # Called for user1/msg1 and user2/msg3
- mock_handler2.assert_called_once() # Called for user1/msg2
-
- # Check that each handler received the correct messages
- # For label1: should have two calls, each with one message
- label1_calls = mock_handler1.call_args_list
- self.assertEqual(len(label1_calls), 2)
-
- # Extract messages from calls
- call1_messages = label1_calls[0][0][0] # First call, first argument (messages list)
- call2_messages = label1_calls[1][0][0] # Second call, first argument (messages list)
-
- # Verify the messages in each call
- self.assertEqual(len(call1_messages), 1)
- self.assertEqual(len(call2_messages), 1)
-
- # For label2: should have one call with [msg2]
- label2_messages = mock_handler2.call_args[0][0]
- self.assertEqual(len(label2_messages), 1)
- self.assertEqual(label2_messages[0].item_id, "msg2")
-
def test_group_messages_by_user_and_mem_cube(self):
"""Test grouping messages by user and cube."""
- # Check actual grouping logic
- with patch("memos.mem_scheduler.general_modules.dispatcher.logger.debug"):
- result = self.dispatcher._group_messages_by_user_and_mem_cube(self.test_messages)
+ # Check actual grouping logic using shared utility function
+ result = group_messages_by_user_and_mem_cube(self.test_messages)
# Adjust expected results based on actual grouping logic
# Note: According to dispatcher.py implementation, grouping is by mem_cube_id not mem_cube
diff --git a/tests/mem_scheduler/test_scheduler.py b/tests/mem_scheduler/test_scheduler.py
index 03a8e4318..fed1e8500 100644
--- a/tests/mem_scheduler/test_scheduler.py
+++ b/tests/mem_scheduler/test_scheduler.py
@@ -1,7 +1,6 @@
import sys
import unittest
-from contextlib import suppress
from datetime import datetime
from pathlib import Path
from unittest.mock import MagicMock, patch
@@ -21,12 +20,9 @@
from memos.mem_scheduler.schemas.general_schemas import (
ANSWER_LABEL,
QUERY_LABEL,
- STARTUP_BY_PROCESS,
- STARTUP_BY_THREAD,
)
from memos.mem_scheduler.schemas.message_schemas import (
ScheduleLogForWebItem,
- ScheduleMessageItem,
)
from memos.memories.textual.tree import TreeTextMemory
@@ -182,124 +178,6 @@ def test_submit_web_logs(self):
self.assertTrue(hasattr(actual_message, "timestamp"))
self.assertTrue(isinstance(actual_message.timestamp, datetime))
- def test_scheduler_startup_mode_default(self):
- """Test that scheduler has default startup mode set to thread."""
- self.assertEqual(self.scheduler.scheduler_startup_mode, STARTUP_BY_THREAD)
-
- def test_scheduler_startup_mode_thread(self):
- """Test scheduler with thread startup mode."""
- # Set scheduler startup mode to thread
- self.scheduler.scheduler_startup_mode = STARTUP_BY_THREAD
-
- # Start the scheduler
- self.scheduler.start()
-
- # Verify that consumer thread is created and process is None
- self.assertIsNotNone(self.scheduler._consumer_thread)
- self.assertIsNone(self.scheduler._consumer_process)
- self.assertTrue(self.scheduler._running)
-
- # Stop the scheduler
- self.scheduler.stop()
-
- def test_redis_message_queue(self):
- """Test Redis message queue functionality for sending and receiving messages."""
- import time
-
- from unittest.mock import MagicMock, patch
-
- # Mock Redis connection and operations
- mock_redis = MagicMock()
- mock_redis.xadd = MagicMock(return_value=b"1234567890-0")
-
- # Track received messages
- received_messages = []
-
- def redis_handler(messages: list[ScheduleMessageItem]) -> None:
- """Handler for Redis messages."""
- received_messages.extend(messages)
-
- # Register Redis handler
- redis_label = "test_redis"
- handlers = {redis_label: redis_handler}
- self.scheduler.register_handlers(handlers)
-
- # Enable Redis queue for this test
- with (
- patch.object(self.scheduler, "use_redis_queue", True),
- patch.object(self.scheduler, "_redis_conn", mock_redis),
- ):
- # Start scheduler
- self.scheduler.start()
-
- # Create test message for Redis
- redis_message = ScheduleMessageItem(
- label=redis_label,
- content="Redis test message",
- user_id="redis_user",
- mem_cube_id="redis_cube",
- mem_cube="redis_mem_cube_obj",
- timestamp=datetime.now(),
- )
-
- # Submit message to Redis queue
- self.scheduler.submit_messages(redis_message)
-
- # Verify Redis xadd was called
- mock_redis.xadd.assert_called_once()
- call_args = mock_redis.xadd.call_args
- self.assertEqual(call_args[0][0], "user:queries:stream")
-
- # Verify message data was serialized correctly
- message_data = call_args[0][1]
- self.assertEqual(message_data["label"], redis_label)
- self.assertEqual(message_data["content"], "Redis test message")
- self.assertEqual(message_data["user_id"], "redis_user")
- self.assertEqual(message_data["cube_id"], "redis_cube") # Note: to_dict uses cube_id
-
- # Simulate Redis message consumption
- # This would normally be handled by the Redis consumer in the scheduler
- time.sleep(0.1) # Brief wait for async operations
-
- # Stop scheduler
- self.scheduler.stop()
-
- print("Redis message queue test completed successfully!")
-
- # Removed test_robustness method - was too time-consuming for CI/CD pipeline
-
- def test_scheduler_startup_mode_process(self):
- """Test scheduler with process startup mode."""
- # Set scheduler startup mode to process
- self.scheduler.scheduler_startup_mode = STARTUP_BY_PROCESS
-
- # Start the scheduler
- try:
- self.scheduler.start()
-
- # Verify that consumer process is created and thread is None
- self.assertIsNotNone(self.scheduler._consumer_process)
- self.assertIsNone(self.scheduler._consumer_thread)
- self.assertTrue(self.scheduler._running)
-
- except Exception as e:
- # Process mode may fail due to pickling issues in test environment
- # This is expected behavior - we just verify the startup mode is set correctly
- self.assertEqual(self.scheduler.scheduler_startup_mode, STARTUP_BY_PROCESS)
- print(f"Process mode test encountered expected pickling issue: {e}")
- finally:
- # Always attempt to stop the scheduler
- with suppress(Exception):
- self.scheduler.stop()
-
- # Verify cleanup attempt was made
- self.assertEqual(self.scheduler.scheduler_startup_mode, STARTUP_BY_PROCESS)
-
- def test_scheduler_startup_mode_constants(self):
- """Test that startup mode constants are properly defined."""
- self.assertEqual(STARTUP_BY_THREAD, "thread")
- self.assertEqual(STARTUP_BY_PROCESS, "process")
-
def test_activation_memory_update(self):
"""Test activation memory update functionality with DynamicCache handling."""
if not self.RUN_ACTIVATION_MEMORY_TESTS:
@@ -401,130 +279,3 @@ def test_dynamic_cache_layers_access(self):
# If layers attribute doesn't exist, verify our fix handles this case
print("⚠️ DynamicCache doesn't have 'layers' attribute in this transformers version")
print("✅ Test passed - our code should handle this gracefully")
-
- def test_get_running_tasks_with_filter(self):
- """Test get_running_tasks method with filter function."""
- # Mock dispatcher and its get_running_tasks method
- mock_task_item1 = MagicMock()
- mock_task_item1.item_id = "task_1"
- mock_task_item1.user_id = "user_1"
- mock_task_item1.mem_cube_id = "cube_1"
- mock_task_item1.task_info = {"type": "query"}
- mock_task_item1.task_name = "test_task_1"
- mock_task_item1.start_time = datetime.now()
- mock_task_item1.end_time = None
- mock_task_item1.status = "running"
- mock_task_item1.result = None
- mock_task_item1.error_message = None
- mock_task_item1.messages = []
-
- # Define a filter function
- def user_filter(task):
- return task.user_id == "user_1"
-
- # Mock the filtered result (only task_1 matches the filter)
- with patch.object(
- self.scheduler.dispatcher, "get_running_tasks", return_value={"task_1": mock_task_item1}
- ) as mock_get_running_tasks:
- # Call get_running_tasks with filter
- result = self.scheduler.get_running_tasks(filter_func=user_filter)
-
- # Verify result
- self.assertIsInstance(result, dict)
- self.assertIn("task_1", result)
- self.assertEqual(len(result), 1)
-
- # Verify dispatcher method was called with filter
- mock_get_running_tasks.assert_called_once_with(filter_func=user_filter)
-
- def test_get_running_tasks_empty_result(self):
- """Test get_running_tasks method when no tasks are running."""
- # Mock dispatcher to return empty dict
- with patch.object(
- self.scheduler.dispatcher, "get_running_tasks", return_value={}
- ) as mock_get_running_tasks:
- # Call get_running_tasks
- result = self.scheduler.get_running_tasks()
-
- # Verify empty result
- self.assertIsInstance(result, dict)
- self.assertEqual(len(result), 0)
-
- # Verify dispatcher method was called
- mock_get_running_tasks.assert_called_once_with(filter_func=None)
-
- def test_get_running_tasks_no_dispatcher(self):
- """Test get_running_tasks method when dispatcher is None."""
- # Temporarily set dispatcher to None
- original_dispatcher = self.scheduler.dispatcher
- self.scheduler.dispatcher = None
-
- # Call get_running_tasks
- result = self.scheduler.get_running_tasks()
-
- # Verify empty result and warning behavior
- self.assertIsInstance(result, dict)
- self.assertEqual(len(result), 0)
-
- # Restore dispatcher
- self.scheduler.dispatcher = original_dispatcher
-
- def test_get_running_tasks_multiple_tasks(self):
- """Test get_running_tasks method with multiple tasks."""
- # Mock multiple task items
- mock_task_item1 = MagicMock()
- mock_task_item1.item_id = "task_1"
- mock_task_item1.user_id = "user_1"
- mock_task_item1.mem_cube_id = "cube_1"
- mock_task_item1.task_info = {"type": "query"}
- mock_task_item1.task_name = "test_task_1"
- mock_task_item1.start_time = datetime.now()
- mock_task_item1.end_time = None
- mock_task_item1.status = "running"
- mock_task_item1.result = None
- mock_task_item1.error_message = None
- mock_task_item1.messages = []
-
- mock_task_item2 = MagicMock()
- mock_task_item2.item_id = "task_2"
- mock_task_item2.user_id = "user_2"
- mock_task_item2.mem_cube_id = "cube_2"
- mock_task_item2.task_info = {"type": "answer"}
- mock_task_item2.task_name = "test_task_2"
- mock_task_item2.start_time = datetime.now()
- mock_task_item2.end_time = None
- mock_task_item2.status = "completed"
- mock_task_item2.result = "success"
- mock_task_item2.error_message = None
- mock_task_item2.messages = ["message1", "message2"]
-
- with patch.object(
- self.scheduler.dispatcher,
- "get_running_tasks",
- return_value={"task_1": mock_task_item1, "task_2": mock_task_item2},
- ) as mock_get_running_tasks:
- # Call get_running_tasks
- result = self.scheduler.get_running_tasks()
-
- # Verify result structure
- self.assertIsInstance(result, dict)
- self.assertEqual(len(result), 2)
- self.assertIn("task_1", result)
- self.assertIn("task_2", result)
-
- # Verify task_1 details
- task1_dict = result["task_1"]
- self.assertEqual(task1_dict["item_id"], "task_1")
- self.assertEqual(task1_dict["user_id"], "user_1")
- self.assertEqual(task1_dict["status"], "running")
-
- # Verify task_2 details
- task2_dict = result["task_2"]
- self.assertEqual(task2_dict["item_id"], "task_2")
- self.assertEqual(task2_dict["user_id"], "user_2")
- self.assertEqual(task2_dict["status"], "completed")
- self.assertEqual(task2_dict["result"], "success")
- self.assertEqual(task2_dict["messages"], ["message1", "message2"])
-
- # Verify dispatcher method was called
- mock_get_running_tasks.assert_called_once_with(filter_func=None)
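
The removed tests above exercised `get_running_tasks(filter_func=...)`, which delegates to the dispatcher and returns a dict mapping task IDs to per-task details. A minimal usage sketch follows; the `filter_func` signature and the per-task fields (`user_id`, `status`, and so on) come from the deleted assertions, while the helper itself is hypothetical.

```python
# Illustrative usage of the filtering behavior the deleted tests covered.
# `scheduler` is assumed to be an already-configured scheduler instance.
def running_tasks_for_user(scheduler, user_id: str) -> dict:
    """Return running-task details for a single user via the scheduler's filter hook."""
    return scheduler.get_running_tasks(filter_func=lambda task: task.user_id == user_id)


# Example:
# tasks = running_tasks_for_user(scheduler, "user_1")
# for task_id, info in tasks.items():
#     # Per the deleted assertions, each entry exposes item_id, user_id,
#     # status, result, and messages.
#     print(task_id, info["status"])
```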