diff --git a/evaluation/scripts/temporal_locomo/models/__init__.py b/evaluation/scripts/temporal_locomo/models/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/evaluation/scripts/temporal_locomo/models/locomo_eval.py b/evaluation/scripts/temporal_locomo/models/locomo_eval.py deleted file mode 100644 index f98a481e2..000000000 --- a/evaluation/scripts/temporal_locomo/models/locomo_eval.py +++ /dev/null @@ -1,531 +0,0 @@ -import argparse -import asyncio -import json -import os -import time - -import nltk -import numpy as np - -from bert_score import score as bert_score -from dotenv import load_dotenv -from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu -from nltk.translate.meteor_score import meteor_score -from openai import AsyncOpenAI -from pydantic import BaseModel, Field -from rouge_score import rouge_scorer -from scipy.spatial.distance import cosine -from sentence_transformers import SentenceTransformer -from tqdm import tqdm - -from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules -from memos.log import get_logger - - -logger = get_logger(__name__) - - -# Download necessary NLTK resources -try: - nltk.download("wordnet", quiet=True) - nltk.download("punkt", quiet=True) - print("NLTK resources downloaded successfully.") -except Exception as e: - print(f"Warning: Failed to download NLTK resources: {e}") - - -try: - sentence_model_name = "Qwen/Qwen3-Embedding-0.6B" - sentence_model = SentenceTransformer(sentence_model_name) - print(f"SentenceTransformer model : {sentence_model_name} loaded successfully.") -except Exception as e: - print(f"Failed to load SentenceTransformer model: {e}") - sentence_model = None - - -class LLMGrade(BaseModel): - llm_judgment: str = Field(description="CORRECT or WRONG") - llm_reasoning: str = Field(description="Explain why the answer is correct or incorrect.") - - -async def locomo_grader(llm_client, question: str, gold_answer: str, response: str) -> bool: - system_prompt = """ - You are an expert grader that determines if answers to questions match a gold standard answer - """ - - accuracy_prompt = f""" - Your task is to label an answer to a question as ’CORRECT’ or ’WRONG’. You will be given the following data: - (1) a question (posed by one user to another user), - (2) a ’gold’ (ground truth) answer, - (3) a generated answer - which you will score as CORRECT/WRONG. - - The point of the question is to ask about something one user should know about the other user based on their prior conversations. - The gold answer will usually be a concise and short answer that includes the referenced topic, for example: - Question: Do you remember what I got the last time I went to Hawaii? - Gold answer: A shell necklace - The generated answer might be much longer, but you should be generous with your grading - as long as it touches on the same topic as the gold answer, it should be counted as CORRECT. - - For time related questions, the gold answer will be a specific date, month, year, etc. The generated answer might be much longer or use relative time references (like "last Tuesday" or "next month"), but you should be generous with your grading - as long as it refers to the same date or time period as the gold answer, it should be counted as CORRECT. Even if the format differs (e.g., "May 7th" vs "7 May"), consider it CORRECT if it's the same date. 
- - Now it’s time for the real question: - Question: {question} - Gold answer: {gold_answer} - Generated answer: {response} - - First, provide a short (one sentence) explanation of your reasoning, then finish with CORRECT or WRONG. - Do NOT include both CORRECT and WRONG in your response, or it will break the evaluation script. - - Just return the label CORRECT or WRONG in a json format with the key as "label". - """ - - response = await llm_client.chat.completions.create( - model="gpt-4o-mini", - messages=[ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": accuracy_prompt}, - ], - temperature=0, - ) - message_content = response.choices[0].message.content - label = json.loads(message_content)["label"] - parsed = LLMGrade(llm_judgment=label, llm_reasoning="") - - return parsed.llm_judgment.strip().lower() == "correct" - - -def calculate_rouge_scores(gold_answer, response): - metrics = {"rouge1_f": 0.0, "rouge2_f": 0.0, "rougeL_f": 0.0} - try: - scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True) - rouge_scores = scorer.score(gold_answer, response) - metrics["rouge1_f"] = rouge_scores["rouge1"].fmeasure - metrics["rouge2_f"] = rouge_scores["rouge2"].fmeasure - metrics["rougeL_f"] = rouge_scores["rougeL"].fmeasure - except Exception as e: - print(f"Failed to calculate ROUGE scores: {e}") - return metrics - - -def calculate_bleu_scores(gold_tokens, response_tokens): - metrics = {"bleu1": 0.0, "bleu2": 0.0, "bleu3": 0.0, "bleu4": 0.0} - - try: - smoothing = SmoothingFunction().method1 - weights = [(1, 0, 0, 0), (0.5, 0.5, 0, 0), (0.33, 0.33, 0.33, 0), (0.25, 0.25, 0.25, 0.25)] - - for i, weight in enumerate(weights, 1): - metrics[f"bleu{i}"] = sentence_bleu( - [gold_tokens], response_tokens, weights=weight, smoothing_function=smoothing - ) - except ZeroDivisionError: - pass - except Exception as e: - print(f"Failed to calculate BLEU scores: {e}") - - return metrics - - -def calculate_meteor_score(gold_tokens, response_tokens): - try: - return meteor_score([gold_tokens], response_tokens) - except Exception as e: - print(f"Failed to calculate METEOR score: {e}") - return 0.0 - - -def calculate_semantic_similarity(gold_answer, response): - global sentence_model - - try: - if sentence_model is None: - sentence_model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B") - - gold_embedding = sentence_model.encode([gold_answer], show_progress_bar=False)[0] - response_embedding = sentence_model.encode([response], show_progress_bar=False)[0] - return 1 - cosine(gold_embedding, response_embedding) - except Exception as e: - print(f"Failed to calculate semantic similarity: {e}") - return 0.0 - - -def calculate_f1_score(gold_tokens, response_tokens): - try: - gold_set = set(gold_tokens) - response_set = set(response_tokens) - - if len(gold_set) == 0 or len(response_set) == 0: - return 0.0 - - precision = len(gold_set.intersection(response_set)) / len(response_set) - recall = len(gold_set.intersection(response_set)) / len(gold_set) - - if precision + recall > 0: - return 2 * precision * recall / (precision + recall) - return 0.0 - except Exception as e: - print(f"Failed to calculate F1 score: {e}") - return 0.0 - - -def calculate_nlp_metrics(gold_answer, response, context, options=None): - if options is None: - options = ["lexical", "semantic"] - - gold_answer = str(gold_answer) if gold_answer is not None else "" - response = str(response) if response is not None else "" - - metrics = {"context_tokens": len(nltk.word_tokenize(context)) if context else 
0} - - if "lexical" in options: - gold_tokens = nltk.word_tokenize(gold_answer.lower()) - response_tokens = nltk.word_tokenize(response.lower()) - - metrics["lexical"] = {} - metrics["lexical"]["f1"] = calculate_f1_score(gold_tokens, response_tokens) - metrics["lexical"].update(calculate_rouge_scores(gold_answer, response)) - metrics["lexical"].update(calculate_bleu_scores(gold_tokens, response_tokens)) - metrics["lexical"]["meteor"] = calculate_meteor_score(gold_tokens, response_tokens) - - if "semantic" in options: - metrics["semantic"] = {} - metrics["semantic"]["similarity"] = calculate_semantic_similarity(gold_answer, response) - _, _, f1 = bert_score( - [gold_answer], [response], lang="en", rescale_with_baseline=True, verbose=False - ) - metrics["semantic"]["bert_f1"] = f1.item() if f1 is not None else 0.0 - - return metrics - - -def convert_numpy_types(obj): - if isinstance(obj, np.number): - return float(obj) - elif isinstance(obj, dict): - return {k: convert_numpy_types(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [convert_numpy_types(i) for i in obj] - else: - return obj - - -async def process_group_responses( - group_id, group_responses, oai_client, evaluation_options, num_runs: int -): - graded_responses = [] - - # Process responses with asyncio for concurrent API calls - for response in tqdm(group_responses, desc=f"Processing group {group_id}"): - question = response.get("question") - answer = response.get("answer") - ground_truth = response.get("golden_answer") - category = response.get("category") - - context = response.get("search_context", "") - response_duration_ms = response.get("response_duration_ms", 0.0) - search_duration_ms = response.get("search_duration_ms", 0.0) - - if ground_truth is None: - continue - - grading_tasks = [ - locomo_grader(oai_client, question, ground_truth, answer) for _ in range(num_runs) - ] - judgments = await asyncio.gather(*grading_tasks) - judgments_dict = {f"judgment_{i + 1}": j for i, j in enumerate(judgments)} - - nlp_metrics = calculate_nlp_metrics(ground_truth, answer, context, evaluation_options) - - graded_response = { - "question": question, - "answer": answer, - "golden_answer": ground_truth, - "category": category, - "llm_judgments": judgments_dict, - "nlp_metrics": nlp_metrics, - "response_duration_ms": response_duration_ms, - "search_duration_ms": search_duration_ms, - "total_duration_ms": response_duration_ms + search_duration_ms, - } - graded_responses.append(graded_response) - - return group_id, graded_responses - - -async def process_single_group(group_id, group_responses, oai_client, evaluation_options, num_runs): - try: - start_time = time.time() - result = await process_group_responses( - group_id, group_responses, oai_client, evaluation_options, num_runs - ) - end_time = time.time() - elapsed_time = round(end_time - start_time, 2) - print(f"Group {group_id} processed in {elapsed_time} seconds") - return result - except Exception as e: - logger.error(f"Error processing group {group_id}: {e}", exc_info=True) - return group_id, [] - - -class LocomoEvaluator(LocomoEvalModelModules): - def __init__(self, args): - # Initialize base class to populate self.frame, self.version, etc. 
- super().__init__(args=args) - - self.evaluation_options = getattr(args, "evaluation_options", ["lexical", "semantic"]) - self.num_runs = getattr(args, "num_runs", 1) - self.max_workers = getattr(args, "workers", 4) - - load_dotenv() - self.oai_client = AsyncOpenAI( - api_key=os.getenv("OPENAI_API_KEY"), base_url=os.getenv("OPENAI_BASE_URL") - ) - - def _load_response_data(self): - """ - Load response data from the response path file. - - Returns: - dict: The loaded response data - """ - with open(self.response_path) as file: - return json.load(file) - - def _load_existing_evaluation_results(self): - """ - Attempt to load existing evaluation results from the judged path. - If the file doesn't exist or there's an error loading it, return an empty dict. - - Returns: - dict: Existing evaluation results or empty dict if none available - """ - all_grades = {} - try: - if os.path.exists(self.judged_path): - with open(self.judged_path) as f: - all_grades = json.load(f) - print(f"Loaded existing evaluation results from {self.judged_path}") - except Exception as e: - print(f"Error loading existing evaluation results: {e}") - - return all_grades - - def _create_evaluation_tasks(self, locomo_responses, all_grades, num_users): - """ - Create evaluation tasks for groups that haven't been evaluated yet. - - Args: - locomo_responses (dict): The loaded response data - all_grades (dict): Existing evaluation results - num_users (int): Number of user groups to process - - Returns: - tuple: (tasks list, active users count) - """ - tasks = [] - active_users = 0 - - for group_idx in range(num_users): - group_id = f"locomo_exp_user_{group_idx}" - group_responses = locomo_responses.get(group_id, []) - - if not group_responses: - print(f"No responses found for group {group_id}") - continue - - # Skip groups that already have evaluation results - if all_grades.get(group_id): - print(f"Skipping group {group_id} as it already has evaluation results") - active_users += 1 - continue - - active_users += 1 - tasks.append( - process_single_group( - group_id=group_id, - group_responses=group_responses, - oai_client=self.oai_client, - evaluation_options=self.evaluation_options, - num_runs=self.num_runs, - ) - ) - - return tasks, active_users - - async def _process_tasks(self, tasks): - """ - Process evaluation tasks with concurrency control. - - Args: - tasks (list): List of tasks to process - - Returns: - list: Results from processing all tasks - """ - if not tasks: - return [] - - semaphore = asyncio.Semaphore(self.max_workers) - - async def limited_task(task): - """Helper function to limit concurrent task execution""" - async with semaphore: - return await task - - limited_tasks = [limited_task(task) for task in tasks] - return await asyncio.gather(*limited_tasks) - - def _calculate_scores(self, all_grades): - """ - Calculate evaluation scores based on all grades. 
- - Args: - all_grades (dict): The complete evaluation results - - Returns: - tuple: (run_scores, evaluated_count) - """ - run_scores = [] - evaluated_count = 0 - - if self.num_runs > 0: - for i in range(1, self.num_runs + 1): - judgment_key = f"judgment_{i}" - current_run_correct_count = 0 - current_run_total_count = 0 - - for group in all_grades.values(): - for response in group: - if judgment_key in response["llm_judgments"]: - if response["llm_judgments"][judgment_key]: - current_run_correct_count += 1 - current_run_total_count += 1 - - if current_run_total_count > 0: - run_accuracy = current_run_correct_count / current_run_total_count - run_scores.append(run_accuracy) - - evaluated_count = current_run_total_count - - return run_scores, evaluated_count - - def _report_scores(self, run_scores, evaluated_count): - """ - Report evaluation scores to the console. - - Args: - run_scores (list): List of accuracy scores for each run - evaluated_count (int): Number of evaluated responses - """ - if evaluated_count > 0: - mean_of_scores = np.mean(run_scores) - std_of_scores = np.std(run_scores) - print(f"LLM-as-a-Judge Mean Score: {mean_of_scores:.4f}") - print(f"LLM-as-a-Judge Standard Deviation: {std_of_scores:.4f}") - print( - f"(Calculated from {self.num_runs} separate runs over {evaluated_count} questions)" - ) - print(f"Individual run scores: {[round(s, 4) for s in run_scores]}") - else: - print("No responses were evaluated") - print("LLM-as-a-Judge score: N/A (0/0)") - - def _save_results(self, all_grades): - """ - Save evaluation results to the judged path file. - - Args: - all_grades (dict): The complete evaluation results to save - """ - all_grades = convert_numpy_types(all_grades) - with open(self.judged_path, "w") as f: - json.dump(all_grades, f, indent=2) - print(f"Saved detailed evaluation results to {self.judged_path}") - - async def run(self): - """ - Main execution method for the LoCoMo evaluation process. - This method orchestrates the entire evaluation workflow: - 1. Loads existing evaluation results if available - 2. Processes only groups that haven't been evaluated yet - 3. 
Calculates and reports final evaluation scores - """ - print( - f"\n=== Starting LoCoMo evaluation for {self.frame} (version: {self.version}) with {self.num_runs} run(s) per question ===" - ) - print(f"Using {self.max_workers} concurrent workers for processing groups") - - # Load response data and existing evaluation results - locomo_responses = self._load_response_data() - all_grades = self._load_existing_evaluation_results() - - # Count total responses for reporting - num_users = 10 - total_responses_count = sum( - len(locomo_responses.get(f"locomo_exp_user_{i}", [])) for i in range(num_users) - ) - print(f"Found {total_responses_count} total responses across {num_users} users to evaluate") - - # Create tasks only for groups that haven't been evaluated yet - tasks, active_users = self._create_evaluation_tasks(locomo_responses, all_grades, num_users) - print( - f"Starting evaluation of {len(tasks)} user groups with responses (out of {active_users} active users)" - ) - - # Process tasks and update all_grades with results - if tasks: - group_results = await self._process_tasks(tasks) - for group_id, graded_responses in group_results: - all_grades[group_id] = graded_responses - - print("\n=== Evaluation Complete: Calculating final scores ===") - - # Calculate and report scores - run_scores, evaluated_count = self._calculate_scores(all_grades) - self._report_scores(run_scores, evaluated_count) - - # Save results - self._save_results(all_grades) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - default="memos_scheduler", - choices=["zep", "memos", "memos_scheduler", "mem0", "mem0_graph", "langmem", "openai"], - help="Specify the memory framework (zep or memos or mem0 or mem0_graph)", - ) - parser.add_argument( - "--version", - type=str, - default="v1.0.1", - help="Version identifier for loading results (e.g., 1010)", - ) - parser.add_argument( - "--num_runs", - type=int, - default=3, - help="Number of times to run the LLM grader for each question", - ) - parser.add_argument("--evaluation_options", nargs="+", default=["lexical", "semantic"]) - parser.add_argument( - "--workers", type=int, default=10, help="Number of concurrent workers for processing groups" - ) - cli_args = parser.parse_args() - - # Build args for evaluator - class Args: - def __init__(self, cli_args): - self.frame = cli_args.lib - self.version = cli_args.version - self.workers = cli_args.workers - self.num_runs = cli_args.num_runs - self.evaluation_options = cli_args.evaluation_options - self.top_k = 20 - self.scheduler_flag = True - - args = Args(cli_args) - evaluator = LocomoEvaluator(args=args) - asyncio.run(evaluator.run()) diff --git a/evaluation/scripts/temporal_locomo/models/locomo_ingestion.py b/evaluation/scripts/temporal_locomo/models/locomo_ingestion.py deleted file mode 100644 index b45ec3d61..000000000 --- a/evaluation/scripts/temporal_locomo/models/locomo_ingestion.py +++ /dev/null @@ -1,303 +0,0 @@ -import concurrent.futures -import sys -import time -import traceback - -from datetime import datetime, timezone -from pathlib import Path - -from tqdm import tqdm - -from evaluation.scripts.temporal_locomo.modules.constants import ( - MEM0_GRAPH_MODEL, - MEM0_MODEL, - MEMOS_MODEL, - MEMOS_SCHEDULER_MODEL, - ZEP_MODEL, -) -from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules -from memos.log import get_logger - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent 
-sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -class LocomoIngestor(LocomoEvalModelModules): - def __init__(self, args): - super().__init__(args=args) - - def ingest_session(self, client, session, frame, metadata, revised_client=None): - session_date = metadata["session_date"] - date_format = "%I:%M %p on %d %B, %Y UTC" - date_string = datetime.strptime(session_date, date_format).replace(tzinfo=timezone.utc) - iso_date = date_string.isoformat() - conv_id = metadata["conv_id"] - conv_id = "locomo_exp_user_" + str(conv_id) - dt = datetime.fromisoformat(iso_date) - timestamp = int(dt.timestamp()) - print(f"Processing conv {conv_id}, session {metadata['session_key']}") - start_time = time.time() - print_once = True # Print example only once per session - - if frame == ZEP_MODEL: - for chat in tqdm(session, desc=f"{metadata['session_key']}"): - data = chat.get("speaker") + ": " + chat.get("text") - - # Print example only once per session - if print_once: - print({"context": data, "conv_id": conv_id, "created_at": iso_date}) - print_once = False - - # Check if the group exists, if not create it - groups = client.group.get_all_groups() - groups = dict(groups)["groups"] - exist_ids = [gp.group_id for gp in groups] - if conv_id not in exist_ids: - client.group.add(group_id=conv_id) - - # Add the message to the group - client.graph.add( - data=data, - type="message", - created_at=iso_date, - group_id=conv_id, - ) - - elif frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: - messages = [] - messages_reverse = [] - - for chat in tqdm(session, desc=f"{metadata['session_key']}"): - data = chat.get("speaker") + ": " + chat.get("text") - - if chat.get("speaker") == metadata["speaker_a"]: - messages.append({"role": "user", "content": data, "chat_time": iso_date}) - messages_reverse.append( - {"role": "assistant", "content": data, "chat_time": iso_date} - ) - elif chat.get("speaker") == metadata["speaker_b"]: - messages.append({"role": "assistant", "content": data, "chat_time": iso_date}) - messages_reverse.append( - {"role": "user", "content": data, "chat_time": iso_date} - ) - else: - raise ValueError( - f"Unknown speaker {chat.get('speaker')} in session {metadata['session_key']}" - ) - - # Print example only once per session - if print_once: - print({"context": data, "conv_id": conv_id, "created_at": iso_date}) - print_once = False - - speaker_a_user_id = conv_id + "_speaker_a" - speaker_b_user_id = conv_id + "_speaker_b" - - client.add( - messages=messages, - user_id=speaker_a_user_id, - ) - - revised_client.add( - messages=messages_reverse, - user_id=speaker_b_user_id, - ) - print(f"Added messages for {speaker_a_user_id} and {speaker_b_user_id} successfully.") - - elif frame in [MEM0_MODEL, MEM0_GRAPH_MODEL]: - print(f"Processing abc for {metadata['session_key']}") - messages = [] - messages_reverse = [] - - for chat in tqdm(session, desc=f"{metadata['session_key']}"): - data = chat.get("speaker") + ": " + chat.get("text") - - if chat.get("speaker") == metadata["speaker_a"]: - messages.append({"role": "user", "content": data}) - messages_reverse.append({"role": "assistant", "content": data}) - elif chat.get("speaker") == metadata["speaker_b"]: - messages.append({"role": "assistant", "content": data}) - messages_reverse.append({"role": "user", "content": data}) - else: - raise ValueError( - f"Unknown speaker {chat.get('speaker')} in session {metadata['session_key']}" - ) - - # Print example only once per session - if print_once: - 
print({"context": data, "conv_id": conv_id, "created_at": iso_date}) - print_once = False - - for i in range(0, len(messages), 2): - batch_messages = messages[i : i + 2] - batch_messages_reverse = messages_reverse[i : i + 2] - - if frame == "mem0": - client.add( - messages=batch_messages, - timestamp=timestamp, - user_id=metadata["speaker_a_user_id"], - version="v2", - ) - client.add( - messages=batch_messages_reverse, - timestamp=timestamp, - user_id=metadata["speaker_b_user_id"], - version="v2", - ) - - elif frame == "mem0_graph": - client.add( - messages=batch_messages, - timestamp=timestamp, - user_id=metadata["speaker_a_user_id"], - output_format="v1.1", - version="v2", - enable_graph=True, - ) - client.add( - messages=batch_messages_reverse, - timestamp=timestamp, - user_id=metadata["speaker_b_user_id"], - output_format="v1.1", - version="v2", - enable_graph=True, - ) - - end_time = time.time() - elapsed_time = round(end_time - start_time, 2) - - return elapsed_time - - def process_user_for_ingestion(self, conv_id, frame, locomo_df, version, num_workers=1): - try: - # Check if locomo_df is empty or doesn't have the required columns - if locomo_df.empty or "conversation" not in locomo_df.columns: - logger.warning( - f"Skipping user {conv_id}: locomo_df is empty or missing 'conversation' column" - ) - return 0 - - conversation = locomo_df["conversation"].iloc[conv_id] - max_session_count = 35 - start_time = time.time() - total_session_time = 0 - valid_sessions = 0 - - revised_client = None - if frame == "zep": - client = self.get_client_for_ingestion(frame=frame, user_id=None, version="default") - elif frame == "mem0" or frame == "mem0_graph": - client = self.get_client_for_ingestion(frame=frame, user_id=None, version="default") - client.delete_all(user_id=f"locomo_exp_user_{conv_id}") - client.delete_all(user_id=f"{conversation.get('speaker_a')}_{conv_id}") - client.delete_all(user_id=f"{conversation.get('speaker_b')}_{conv_id}") - elif frame in ["memos", "memos_scheduler"]: - conv_id = "locomo_exp_user_" + str(conv_id) - speaker_a_user_id = conv_id + "_speaker_a" - speaker_b_user_id = conv_id + "_speaker_b" - - client = self.get_client_for_ingestion( - frame=frame, user_id=speaker_a_user_id, version=version - ) - revised_client = self.get_client_for_ingestion( - frame=frame, user_id=speaker_b_user_id, version=version - ) - else: - raise NotImplementedError() - - sessions_to_process = [] - for session_idx in tqdm(range(max_session_count), desc=f"process_user {conv_id}"): - session_key = f"session_{session_idx}" - session = conversation.get(session_key) - if session is None: - continue - - metadata = { - "session_date": conversation.get(f"session_{session_idx}_date_time") + " UTC", - "speaker_a": conversation.get("speaker_a"), - "speaker_b": conversation.get("speaker_b"), - "speaker_a_user_id": f"{conversation.get('speaker_a')}_{conv_id}", - "speaker_b_user_id": f"{conversation.get('speaker_b')}_{conv_id}", - "conv_id": conv_id, - "session_key": session_key, - } - sessions_to_process.append((session, metadata)) - valid_sessions += 1 - - print( - f"Processing {valid_sessions} sessions for user {conv_id} with {num_workers} workers" - ) - with concurrent.futures.ThreadPoolExecutor(max_workers=num_workers) as executor: - futures = { - executor.submit( - self.ingest_session, client, session, frame, metadata, revised_client - ): metadata["session_key"] - for session, metadata in sessions_to_process - } - - for future in concurrent.futures.as_completed(futures): - session_key = 
futures[future] - try: - session_time = future.result() - total_session_time += session_time - print(f"User {conv_id}, {session_key} processed in {session_time} seconds") - except Exception as e: - print(f"Error processing user {conv_id}, session {session_key}: {e!s}") - - end_time = time.time() - elapsed_time = round(end_time - start_time, 2) - print(f"User {conv_id} processed successfully in {elapsed_time} seconds") - - return elapsed_time - - except Exception as e: - return f"Error processing user {conv_id}: {e!s}. Exception: {traceback.format_exc()}" - - def run_ingestion(self): - frame = self.frame - version = self.version - num_workers = self.workers - - num_users = 10 - start_time = time.time() - total_time = 0 - - print( - f"Starting processing for {num_users} users in serial mode," - f" each user using {num_workers} workers for sessions..." - ) - - for user_id in range(num_users): - try: - result = self.process_user_for_ingestion( - user_id, frame, self.locomo_df, version, num_workers - ) - if isinstance(result, float): - total_time += result - else: - print(result) - except Exception as e: - print( - f"Error processing user {user_id}: {e!s}. Traceback: {traceback.format_exc()}" - ) - - if num_users > 0: - average_time = total_time / num_users - minutes = int(average_time // 60) - seconds = int(average_time % 60) - average_time_formatted = f"{minutes} minutes and {seconds} seconds" - print( - f"The frame {frame} processed {num_users} users in average of {average_time_formatted} per user." - ) - - end_time = time.time() - elapsed_time = round(end_time - start_time, 2) - minutes = int(elapsed_time // 60) - seconds = int(elapsed_time % 60) - elapsed_time = f"{minutes} minutes and {seconds} seconds" - print(f"Total processing time: {elapsed_time}.") diff --git a/evaluation/scripts/temporal_locomo/models/locomo_metric.py b/evaluation/scripts/temporal_locomo/models/locomo_metric.py deleted file mode 100644 index 532fe2e14..000000000 --- a/evaluation/scripts/temporal_locomo/models/locomo_metric.py +++ /dev/null @@ -1,390 +0,0 @@ -import argparse -import json - -import numpy as np -import pandas as pd - -from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules - - -# Category mapping as per your request -category_mapping = { - "4": "single hop", - "1": "multi hop", - "2": "temporal reasoning", - "3": "open domain", -} - - -def calculate_scores(data): - category_scores = {} - category_question_count = {} - - overall_metrics = { - "lexical": { - m: [] - for m in [ - "f1", - "rouge1_f", - "rouge2_f", - "rougeL_f", - "bleu1", - "bleu2", - "bleu3", - "bleu4", - "meteor", - ] - }, - "semantic": {m: [] for m in ["bert_f1", "similarity"]}, - "context_tokens": [], - "duration": { - m: [] for m in ["response_duration_ms", "search_duration_ms", "total_duration_ms"] - }, - } - - category_metrics = {} - user_metrics = {} - - total_questions = 0 - - all_judgment_keys = set() - judgment_run_scores = {} - - for _user, questions in data.items(): - for question in questions: - if "llm_judgments" in question: - all_judgment_keys.update(question["llm_judgments"].keys()) - - for key in all_judgment_keys: - judgment_run_scores[key] = [] - - for user, questions in data.items(): - user_total = 0 - - # Initialize user_metrics with each judgment run - user_metrics[user] = { - "total": 0, - "llm_judge_score": 0, - "llm_judge_std": 0, - "judgment_run_scores": {key: [] for key in all_judgment_keys}, - "lexical": {m: [] for m in overall_metrics["lexical"]}, - "semantic": {m: [] 
for m in overall_metrics["semantic"]}, - "context_tokens": [], - "duration": {m: [] for m in overall_metrics["duration"]}, - } - - for question in questions: - total_questions += 1 - user_total += 1 - - if "llm_judgments" in question: - for judgment_key, judgment_value in question["llm_judgments"].items(): - score = 1 if judgment_value else 0 - judgment_run_scores[judgment_key].append(score) - user_metrics[user]["judgment_run_scores"][judgment_key].append(score) - - category = question["category"] - if category not in category_scores: - category_scores[category] = { - "total": 0, - "category_name": category_mapping.get(str(category), "Unknown"), - "judgment_run_scores": {key: [] for key in all_judgment_keys}, - } - category_metrics[category] = { - "lexical": {m: [] for m in overall_metrics["lexical"]}, - "semantic": {m: [] for m in overall_metrics["semantic"]}, - "context_tokens": [], - "duration": {m: [] for m in overall_metrics["duration"]}, - } - category_question_count[category] = 0 - - category_scores[category]["total"] += 1 - category_question_count[category] += 1 - - if "llm_judgments" in question: - for judgment_key, judgment_value in question["llm_judgments"].items(): - score = 1 if judgment_value else 0 - category_scores[category]["judgment_run_scores"][judgment_key].append(score) - - nlp = question.get("nlp_metrics", {}) - for metric in overall_metrics["lexical"]: - v = nlp.get("lexical", {}).get(metric) - if v is not None: - overall_metrics["lexical"][metric].append(v) - category_metrics[category]["lexical"][metric].append(v) - user_metrics[user]["lexical"][metric].append(v) - - for metric in overall_metrics["semantic"]: - v = nlp.get("semantic", {}).get(metric) - if v is not None: - overall_metrics["semantic"][metric].append(v) - category_metrics[category]["semantic"][metric].append(v) - user_metrics[user]["semantic"][metric].append(v) - - ct = nlp.get("context_tokens") - if ct is not None: - overall_metrics["context_tokens"].append(ct) - category_metrics[category]["context_tokens"].append(ct) - user_metrics[user]["context_tokens"].append(ct) - - for metric in overall_metrics["duration"]: - v = question.get(metric) - if v is not None: - overall_metrics["duration"][metric].append(v) - category_metrics[category]["duration"][metric].append(v) - user_metrics[user]["duration"][metric].append(v) - - user_metrics[user]["total"] = user_total - - judgment_avgs = [] - for _judgment_key, scores in user_metrics[user]["judgment_run_scores"].items(): - if scores: - avg = np.mean(scores) - judgment_avgs.append(avg) - - user_metrics[user]["llm_judge_score"] = np.mean(judgment_avgs) if judgment_avgs else 0.0 - user_metrics[user]["llm_judge_std"] = ( - np.std(judgment_avgs) if len(judgment_avgs) > 1 else 0.0 - ) - - for group in ["lexical", "semantic"]: - for metric in user_metrics[user][group]: - values = user_metrics[user][group][metric] - user_metrics[user][group][metric] = np.mean(values) if values else 0.0 - - user_metrics[user]["context_tokens"] = ( - np.mean(user_metrics[user]["context_tokens"]) - if user_metrics[user]["context_tokens"] - else 0.0 - ) - - duration_metrics = list(user_metrics[user]["duration"].keys()) - for metric in duration_metrics: - values = user_metrics[user]["duration"][metric] - if values: - user_metrics[user]["duration"][metric] = np.mean(values) - user_metrics[user]["duration"][f"{metric}_p50"] = np.percentile(values, 50) - user_metrics[user]["duration"][f"{metric}_p95"] = np.percentile(values, 95) - else: - user_metrics[user]["duration"][metric] = 0.0 - 
user_metrics[user]["duration"][f"{metric}_p50"] = 0.0 - user_metrics[user]["duration"][f"{metric}_p95"] = 0.0 - - judgment_run_averages = [] - for _judgment_key, scores in judgment_run_scores.items(): - if scores: - judgment_run_averages.append(np.mean(scores)) - - llm_judge_score = np.mean(judgment_run_averages) if judgment_run_averages else 0.0 - llm_judge_std = np.std(judgment_run_averages) if len(judgment_run_averages) > 1 else 0.0 - - category_overall_scores = {} - for category, score_data in category_scores.items(): - category_judgment_avgs = [] - for _judgment_key, scores in score_data["judgment_run_scores"].items(): - if scores: - category_judgment_avgs.append(np.mean(scores)) - - category_overall_scores[category] = { - "category_name": score_data["category_name"], - "llm_judge_score": np.mean(category_judgment_avgs) if category_judgment_avgs else 0.0, - "llm_judge_std": np.std(category_judgment_avgs) - if len(category_judgment_avgs) > 1 - else 0.0, - "total": score_data["total"], - "lexical": {}, - "semantic": {}, - "duration": {}, - "context_tokens": 0.0, - } - - for group in ["lexical", "semantic"]: - for metric in category_metrics[category][group]: - values = category_metrics[category][group][metric] - category_overall_scores[category][group][metric] = ( - np.mean(values) if values else 0.0 - ) - - category_overall_scores[category]["context_tokens"] = ( - np.mean(category_metrics[category]["context_tokens"]) - if category_metrics[category]["context_tokens"] - else 0.0 - ) - - # Calculate mean and percentiles for category duration metrics - duration_metrics = list( - category_metrics[category]["duration"].keys() - ) # Create a list of keys first - for metric in duration_metrics: - values = category_metrics[category]["duration"][metric] - if values: - category_overall_scores[category]["duration"][metric] = np.mean(values) - # Add P50 (median) and P95 percentiles - category_overall_scores[category]["duration"][f"{metric}_p50"] = np.percentile( - values, 50 - ) - category_overall_scores[category]["duration"][f"{metric}_p95"] = np.percentile( - values, 95 - ) - else: - category_overall_scores[category]["duration"][metric] = 0.0 - category_overall_scores[category]["duration"][f"{metric}_p50"] = 0.0 - category_overall_scores[category]["duration"][f"{metric}_p95"] = 0.0 - - # calculate overall scores - overall_metric_averages = { - "llm_judge_score": llm_judge_score, - "llm_judge_std": llm_judge_std, - "lexical": {}, - "semantic": {}, - "context_tokens": 0.0, - "duration": {}, - } - - for group in ["lexical", "semantic"]: - for metric in overall_metrics[group]: - values = overall_metrics[group][metric] - overall_metric_averages[group][metric] = np.mean(values) if values else 0.0 - - overall_metric_averages["context_tokens"] = ( - np.mean(overall_metrics["context_tokens"]) if overall_metrics["context_tokens"] else 0.0 - ) - - duration_metrics = list(overall_metrics["duration"].keys()) - for metric in duration_metrics: - values = overall_metrics["duration"][metric] - if values: - overall_metric_averages["duration"][metric] = np.mean(values) - overall_metric_averages["duration"][f"{metric}_p50"] = np.percentile(values, 50) - overall_metric_averages["duration"][f"{metric}_p95"] = np.percentile(values, 95) - else: - overall_metric_averages["duration"][metric] = 0.0 - overall_metric_averages["duration"][f"{metric}_p50"] = 0.0 - overall_metric_averages["duration"][f"{metric}_p95"] = 0.0 - - return { - "metrics": overall_metric_averages, - "category_scores": category_overall_scores, - 
"user_scores": user_metrics, - } - - -def save_to_excel(results, output_path): - # Create a combined data structure for metrics and category scores - combined_data = [] - - # Process overall metrics - flatten nested structures - overall_row = {"category": "overall"} - overall_row["llm_judge_score"] = results["metrics"]["llm_judge_score"] - overall_row["llm_judge_std"] = results["metrics"]["llm_judge_std"] - - # Add all lexical metrics - for metric, value in results["metrics"]["lexical"].items(): - overall_row[metric] = value - - # Add all semantic metrics - for metric, value in results["metrics"]["semantic"].items(): - overall_row[metric] = value - - # Add context tokens - overall_row["context_tokens"] = results["metrics"]["context_tokens"] - - # Add all duration metrics, including percentiles - for metric, value in results["metrics"]["duration"].items(): - overall_row[metric] = value - - combined_data.append(overall_row) - - # Process category scores - flatten nested structures - for _, scores in results["category_scores"].items(): - category_row = {"category": scores["category_name"]} - category_row["llm_judge_score"] = scores["llm_judge_score"] - category_row["llm_judge_std"] = scores["llm_judge_std"] - - # Add all lexical metrics - for metric, value in scores["lexical"].items(): - category_row[metric] = value - - # Add all semantic metrics - for metric, value in scores["semantic"].items(): - category_row[metric] = value - - # Add context tokens - category_row["context_tokens"] = scores["context_tokens"] - - # Add all duration metrics, including percentiles - for metric, value in scores["duration"].items(): - category_row[metric] = value - - combined_data.append(category_row) - - # Create DataFrame and save to Excel - combined_df = pd.DataFrame(combined_data) - - # Create a pandas Excel writer - with pd.ExcelWriter(output_path) as writer: - combined_df.to_excel(writer, sheet_name="Metrics", index=False) - - print(f"Excel file saved to: {output_path}") - - -class LocomoMetric(LocomoEvalModelModules): - def __init__(self, args): - super().__init__(args=args) - - def run(self): - with open(self.judged_path) as file: - data = json.load(file) - - results = calculate_scores(data) - - with open(self.grade_path, "w") as outfile: - json.dump(results, outfile, indent=4) - - save_to_excel(results, self.excel_path) - - print("\n=== Metric Calculation Complete ===") - total = sum(results["category_scores"][cat]["total"] for cat in results["category_scores"]) - print( - f"LLM-as-a-Judge score: {results['metrics']['llm_judge_score']:.4f} ± {results['metrics']['llm_judge_std']:.4f}" - ) - print(f"Total questions evaluated: {total}") - - print("\n=== Duration Metrics ===") - for metric in ["response_duration_ms", "search_duration_ms", "total_duration_ms"]: - print(f"{metric} (avg): {results['metrics']['duration'][metric]:.2f} ms") - print(f"{metric} (P50): {results['metrics']['duration'][f'{metric}_p50']:.2f} ms") - print(f"{metric} (P95): {results['metrics']['duration'][f'{metric}_p95']:.2f} ms") - - print(f"\nResults have been written to {self.grade_path}") - print(f"Excel report has been saved to {self.excel_path}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--lib", - type=str, - default="memos_scheduler", - choices=["zep", "memos", "memos_scheduler", "mem0", "mem0_graph", "langmem", "openai"], - help="Specify the memory framework (zep or memos or mem0 or mem0_graph)", - ) - parser.add_argument( - "--version", - type=str, - default="v1.0.1", - 
help="Version identifier for loading results (e.g., 1010)", - ) - cli_args = parser.parse_args() - - # Build a minimal args namespace compatible with LocomoEvalModelModules - class _Args: - def __init__(self, frame, version): - self.frame = frame - self.version = version - self.workers = 1 - self.top_k = 20 - self.scheduler_flag = True - - args = _Args(frame=cli_args.lib, version=cli_args.version) - LocomoMetric(args=args).run() diff --git a/evaluation/scripts/temporal_locomo/models/locomo_processor.py b/evaluation/scripts/temporal_locomo/models/locomo_processor.py deleted file mode 100644 index 7cec6f5af..000000000 --- a/evaluation/scripts/temporal_locomo/models/locomo_processor.py +++ /dev/null @@ -1,370 +0,0 @@ -import json -import sys - -from collections import defaultdict -from concurrent.futures import ThreadPoolExecutor, as_completed -from pathlib import Path -from time import time - -from dotenv import load_dotenv - -from evaluation.scripts.temporal_locomo.modules.constants import ( - MEMOS_SCHEDULER_MODEL, -) -from evaluation.scripts.temporal_locomo.modules.locomo_eval_module import LocomoEvalModelModules -from evaluation.scripts.temporal_locomo.modules.prompts import ( - SEARCH_PROMPT_MEM0, - SEARCH_PROMPT_MEM0_GRAPH, - SEARCH_PROMPT_MEMOS, - SEARCH_PROMPT_ZEP, -) -from evaluation.scripts.temporal_locomo.modules.schemas import ContextUpdateMethod, RecordingCase -from evaluation.scripts.temporal_locomo.modules.utils import save_evaluation_cases -from memos.log import get_logger - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -class LocomoProcessor(LocomoEvalModelModules): - """ - A class for handling conversational memory management across different memory frameworks. - Supports multiple memory backends (zep, mem0, memos, etc.) for searching and retrieving - relevant context to generate conversational responses. 
- """ - - def __init__(self, args): - """Initialize the LocomoChatter with path configurations and templates""" - super().__init__(args=args) - - # Template definitions for different memory frameworks - self.search_template_zep = SEARCH_PROMPT_ZEP - - self.search_template_mem0 = SEARCH_PROMPT_MEM0 - - self.search_template_mem0_graph = SEARCH_PROMPT_MEM0_GRAPH - - self.search_template_memos = SEARCH_PROMPT_MEMOS - - self.processed_data_dir = self.result_dir / "processed_data" - - def update_context(self, conv_id, method, **kwargs): - if method == ContextUpdateMethod.CHAT_HISTORY: - if "query" not in kwargs or "answer" not in kwargs: - raise ValueError("query and answer are required for TEMPLATE update method") - new_context = f"User: {kwargs['query']}\nAssistant: {kwargs['answer']}\n\n" - if self.pre_context_cache[conv_id] is None: - self.pre_context_cache[conv_id] = "" - self.pre_context_cache[conv_id] += new_context - else: - if "cur_context" not in kwargs: - raise ValueError("cur_context is required for DIRECT update method") - cur_context = kwargs["cur_context"] - self.pre_context_cache[conv_id] = cur_context - - def eval_context(self, context, query, gold_answer, oai_client): - can_answer_start = time() - can_answer = self.analyze_context_answerability(context, query, gold_answer, oai_client) - can_answer_duration_ms = (time() - can_answer_start) * 1000 - # Update global stats - with self.stats_lock: - self.stats[self.frame][self.version]["memory_stats"]["total_queries"] += 1 - if can_answer: - self.stats[self.frame][self.version]["memory_stats"]["can_answer_count"] += 1 - else: - self.stats[self.frame][self.version]["memory_stats"]["cannot_answer_count"] += 1 - total_queries = self.stats[self.frame][self.version]["memory_stats"]["total_queries"] - can_answer_count = self.stats[self.frame][self.version]["memory_stats"][ - "can_answer_count" - ] - hit_rate = (can_answer_count / total_queries * 100) if total_queries > 0 else 0 - self.stats[self.frame][self.version]["memory_stats"]["answer_hit_rate"] = hit_rate - self.stats[self.frame][self.version]["memory_stats"]["can_answer_duration_ms"] = ( - can_answer_duration_ms - ) - self.save_stats() - return can_answer, can_answer_duration_ms - - def _update_stats_and_context( - self, - *, - conv_id, - frame, - version, - conv_stats, - conv_stats_path, - query, - answer, - gold_answer, - cur_context, - can_answer, - ): - """ - Update conversation statistics and context. 
- - Args: - conv_id: Conversation ID - frame: Model frame - version: Model version - conv_stats: Conversation statistics dictionary - conv_stats_path: Path to save conversation statistics - query: User query - answer: Generated answer - gold_answer: Golden answer - cur_context: Current context - can_answer: Whether the context can answer the query - """ - # Update conversation stats - conv_stats["total_queries"] += 1 - conv_stats["response_count"] += 1 - if frame == MEMOS_SCHEDULER_MODEL: - if can_answer: - conv_stats["can_answer_count"] += 1 - else: - conv_stats["cannot_answer_count"] += 1 - if conv_stats["total_queries"] > 0: - conv_stats["answer_hit_rate"] = ( - conv_stats["can_answer_count"] / conv_stats["total_queries"] - ) * 100 - - # Persist conversation stats snapshot - self._save_conv_stats(conv_id, frame, version, conv_stats, conv_stats_path) - - logger.info(f"Processed question: {query[:100]}") - logger.info(f"Answer: {answer[:100]}") - - # Update pre-context cache with current context - with self.stats_lock: - if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY: - self.update_context( - conv_id=conv_id, - method=self.context_update_method, - query=query, - answer=answer, - ) - else: - self.update_context( - conv_id=conv_id, - method=self.context_update_method, - cur_context=cur_context, - ) - - self.print_eval_info() - - def _process_single_qa( - self, - qa, - *, - client, - reversed_client, - metadata, - frame, - version, - conv_id, - conv_stats_path, - oai_client, - top_k, - conv_stats, - ): - query = qa.get("question") - gold_answer = qa.get("answer") - qa_category = qa.get("category") - if qa_category == 5: - return None - - # Search - cur_context, search_duration_ms = self.search_query( - client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k - ) - if not cur_context: - logger.warning(f"No context found for query: {query[:100]}") - cur_context = "" - - if self.context_update_method == ContextUpdateMethod.CURRENT_CONTEXT: - context = cur_context - else: - # Context answer ability analysis (for memos_scheduler only) - if self.pre_context_cache[conv_id] is None: - # Update pre-context cache with current context and return - if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY: - answer_from_cur_context = self.locomo_response( - frame, oai_client, cur_context, query - ) - self.update_context( - conv_id=conv_id, - method=self.context_update_method, - query=query, - answer=answer_from_cur_context, - ) - else: - self.update_context( - conv_id=conv_id, - method=self.context_update_method, - cur_context=cur_context, - ) - return None - else: - context = self.pre_context_cache[conv_id] - - # Generate answer - answer_start = time() - answer = self.locomo_response(frame, oai_client, context, query) - response_duration_ms = (time() - answer_start) * 1000 - - can_answer, can_answer_duration_ms = self.eval_context( - context=context, query=query, gold_answer=gold_answer, oai_client=oai_client - ) - - # Record case for memos_scheduler - try: - recording_case = RecordingCase( - conv_id=conv_id, - query=query, - answer=answer, - context=cur_context, - pre_context=self.pre_context_cache[conv_id], - can_answer=can_answer, - can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}", - search_duration_ms=search_duration_ms, - can_answer_duration_ms=can_answer_duration_ms, - response_duration_ms=response_duration_ms, - category=int(qa_category) if qa_category is not None else None, - 
golden_answer=str(qa.get("answer", "")), - ) - if can_answer: - self.can_answer_cases.append(recording_case) - else: - self.cannot_answer_cases.append(recording_case) - except Exception as e: - logger.error(f"Error creating RecordingCase: {e}") - print(f"Error creating RecordingCase: {e}") - logger.error(f"QA data: {qa}") - print(f"QA data: {qa}") - logger.error(f"Query: {query}") - logger.error(f"Answer: {answer}") - logger.error( - f"Golden answer (raw): {qa.get('answer')} (type: {type(qa.get('answer'))})" - ) - logger.error(f"Category: {qa_category} (type: {type(qa_category)})") - logger.error(f"Can answer: {can_answer}") - raise e - - if self.context_update_method == ContextUpdateMethod.CHAT_HISTORY: - answer_from_cur_context = self.locomo_response(frame, oai_client, cur_context, query) - answer = answer_from_cur_context - # Update conversation stats and context - self._update_stats_and_context( - conv_id=conv_id, - frame=frame, - version=version, - conv_stats=conv_stats, - conv_stats_path=conv_stats_path, - query=query, - answer=answer, - gold_answer=gold_answer, - cur_context=cur_context, - can_answer=can_answer, - ) - - return { - "question": query, - "answer": answer, - "category": qa_category, - "golden_answer": gold_answer, - "search_context": cur_context, - "response_duration_ms": response_duration_ms, - "search_duration_ms": search_duration_ms, - "can_answer_duration_ms": can_answer_duration_ms, - "can_answer": can_answer if frame == MEMOS_SCHEDULER_MODEL else None, - } - - def run_locomo_processing(self, num_users=10): - load_dotenv() - - frame = self.frame - version = self.version - num_workers = self.workers - top_k = self.top_k - - # Storage for aggregated results - all_search_results = defaultdict(list) - all_response_results = defaultdict(list) - num_users = num_users - - # Prepare arguments for each user processing task - user_args = [(idx, self.locomo_df, frame, version, top_k) for idx in range(num_users)] - - if num_workers > 1: - # === parallel running ==== - # Use ThreadPoolExecutor for parallel processing - print( - f"Starting parallel processing for {num_users} users, using {num_workers} workers for sessions..." - ) - with ThreadPoolExecutor(max_workers=num_workers) as executor: - # Submit all user processing tasks - future_to_user = { - executor.submit(self.process_user_wrapper, args): idx - for idx, args in enumerate(user_args) - } - - # Collect results as they complete - for future in as_completed(future_to_user): - idx = future_to_user[future] - user_search_results, user_response_results, error = future.result() - if error is not None: - idx, e, traceback_str = error - print(f"Error processing user {idx}: {e}. Exception: {traceback_str}") - else: - # Aggregate results - conv_id = f"locomo_exp_user_{idx}" - all_search_results[conv_id].extend(user_search_results[conv_id]) - all_response_results[conv_id].extend(user_response_results[conv_id]) - - else: - # Serial processing - print( - f"Starting serial processing for {num_users} users in serial mode, each user using {num_workers} workers for sessions..." - ) - for idx, args in enumerate(user_args): - user_search_results, user_response_results, error = self.process_user_wrapper(args) - if error is not None: - idx, e, traceback_str = error - print(f"Error processing user {idx}: {e}. 
Exception: {traceback_str}") - else: - # Aggregate results - conv_id = f"locomo_exp_user_{idx}" - all_search_results[conv_id].extend(user_search_results[conv_id]) - all_response_results[conv_id].extend(user_response_results[conv_id]) - - # Print evaluation information statistics - self.print_eval_info() - self.save_stats() - - # Save all aggregated results - with open(self.search_path, "w") as fw: - json.dump(all_search_results, fw, indent=2) - print(f"Saved all search results to {self.search_path}") - - with open(self.response_path, "w") as fw: - json.dump(all_response_results, fw, indent=2) - print(f"Saved all response results to {self.response_path}") - - # Save evaluation cases if they exist - if self.can_answer_cases or self.cannot_answer_cases: - try: - saved_files = save_evaluation_cases( - can_answer_cases=self.can_answer_cases, - cannot_answer_cases=self.cannot_answer_cases, - output_dir=self.stats_dir, - frame=self.frame, - version=self.version, - ) - print(f"Saved evaluation cases: {saved_files}") - except Exception as e: - logger.error(f"Error saving evaluation cases: {e}") - - return dict(all_search_results), dict(all_response_results) diff --git a/evaluation/scripts/temporal_locomo/models/locomo_processor_w_time_eval.py b/evaluation/scripts/temporal_locomo/models/locomo_processor_w_time_eval.py deleted file mode 100644 index b909c64e1..000000000 --- a/evaluation/scripts/temporal_locomo/models/locomo_processor_w_time_eval.py +++ /dev/null @@ -1,229 +0,0 @@ -import sys -import time - -from pathlib import Path -from typing import TYPE_CHECKING - -from evaluation.scripts.temporal_locomo.models.locomo_processor import LocomoProcessor -from evaluation.scripts.temporal_locomo.modules.constants import ( - MEMOS_SCHEDULER_MODEL, -) -from evaluation.scripts.temporal_locomo.modules.prompts import ( - SEARCH_PROMPT_MEMOS, -) -from evaluation.scripts.temporal_locomo.modules.schemas import ContextUpdateMethod, RecordingCase -from memos.log import get_logger - - -if TYPE_CHECKING: - from memos.mem_os.main import MOS - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -class LocomoProcessorWithTimeEval(LocomoProcessor): - def __init__(self, args): - super().__init__(args=args) - self.time_eval_mode = getattr(self.args, "time_eval_mode", False) - assert self.args.frame == MEMOS_SCHEDULER_MODEL - assert self.context_update_method == ContextUpdateMethod.PRE_CONTEXT - if self.time_eval_mode: - logger.warning( - "time_eval_mode is activated. 
_process_single_qa is replaced by _process_single_qa_for_time_eval" - ) - self._process_single_qa = self._process_single_qa_for_time_eval - - def memos_scheduler_search( - self, client, query, conv_id, speaker_a, speaker_b, reversed_client=None, top_k=20 - ): - # MemOS full search process and skip the parts of scheduler - start = time.time() - client: MOS = client - - if not self.scheduler_flag: - # if not scheduler_flag, search to update working memory - self.memos_search(client, query, conv_id, speaker_a, speaker_b, reversed_client) - - # ========= MemOS Search ========= - # Search for speaker A - search_a_results = client.search( - query=query, - user_id=conv_id + "_speaker_a", - install_cube_ids=[conv_id + "_speaker_a"], - top_k=top_k, - mode="fine", - internet_search=False, - moscube=False, # cube for mos introduction - session_id=None, - )["text_mem"] - search_a_results = [[m.memory for m in one["memories"]] for one in search_a_results] - search_a_results = [item for sublist in search_a_results for item in sublist] - - # Search for speaker B - search_b_results = client.search( - query=query, - user_id=conv_id + "_speaker_b", - install_cube_ids=[conv_id + "_speaker_b"], - top_k=top_k, - mode="fine", - internet_search=False, - moscube=False, # cube for mos introduction - session_id=None, - )["text_mem"] - search_b_results = [[m.memory for m in one["memories"]] for one in search_b_results] - search_b_results = [item for sublist in search_b_results for item in sublist] - - speaker_a_context = "" - for item in search_a_results: - speaker_a_context += f"{item}\n" - - speaker_b_context = "" - for item in search_b_results: - speaker_b_context += f"{item}\n" - - context = SEARCH_PROMPT_MEMOS.format( - speaker_1=speaker_a, - speaker_1_memories=speaker_a_context, - speaker_2=speaker_b, - speaker_2_memories=speaker_b_context, - ) - - logger.info(f'query "{query[:100]}", context: {context[:100]}"') - duration_ms = (time.time() - start) * 1000 - - return context, duration_ms - - def _process_single_qa_for_time_eval( - self, - qa, - *, - client, - reversed_client, - metadata, - frame, - version, - conv_id, - conv_stats_path, - oai_client, - top_k, - conv_stats, - ): - query = qa.get("question") - gold_answer = qa.get("answer") - qa_category = qa.get("category") - if qa_category == 5: - return None - - # 1. two parallel process, - # 1. memos search + response - # 2. 
pre_memories can answer, true : direct answer false: - - # Search - assert self.args.frame == MEMOS_SCHEDULER_MODEL - cur_context, search_duration_ms = self.search_query( - client, query, metadata, frame, reversed_client=reversed_client, top_k=top_k - ) - if not cur_context: - logger.warning(f"No context found for query: {query[:100]}") - cur_context = "" - - # Context answer ability analysis (for memos_scheduler only) - if self.pre_context_cache[conv_id] is None: - # Update pre-context cache with current context and return - self.update_context( - conv_id=conv_id, - method=self.context_update_method, - cur_context=cur_context, - ) - - # ========= MemOS Scheduler update ========= - _ = client.mem_scheduler.update_working_memory_for_eval( - query=query, user_id=conv_id + "_speaker_a", top_k=top_k - ) - - _ = client.mem_scheduler.update_working_memory_for_eval( - query=query, user_id=conv_id + "_speaker_b", top_k=top_k - ) - return None - - context = self.pre_context_cache[conv_id] - - # Generate answer - answer_start = time.time() - answer = self.locomo_response(frame, oai_client, context, query) - response_duration_ms = (time.time() - answer_start) * 1000 - - can_answer, can_answer_duration_ms = self.eval_context( - context=context, query=query, gold_answer=gold_answer, oai_client=oai_client - ) - - # Record case for memos_scheduler - try: - recording_case = RecordingCase( - conv_id=conv_id, - query=query, - answer=answer, - context=cur_context, - pre_context=self.pre_context_cache[conv_id], - can_answer=can_answer, - can_answer_reason=f"Context analysis result: {'can answer' if can_answer else 'cannot answer'}", - search_duration_ms=search_duration_ms, - can_answer_duration_ms=can_answer_duration_ms, - response_duration_ms=response_duration_ms, - category=int(qa_category) if qa_category is not None else None, - golden_answer=str(qa.get("answer", "")), - ) - if can_answer: - self.can_answer_cases.append(recording_case) - else: - self.cannot_answer_cases.append(recording_case) - except Exception as e: - logger.error(f"Error creating RecordingCase: {e}") - print(f"Error creating RecordingCase: {e}") - logger.error(f"QA data: {qa}") - print(f"QA data: {qa}") - logger.error(f"Query: {query}") - logger.error(f"Answer: {answer}") - logger.error( - f"Golden answer (raw): {qa.get('answer')} (type: {type(qa.get('answer'))})" - ) - logger.error(f"Category: {qa_category} (type: {type(qa_category)})") - logger.error(f"Can answer: {can_answer}") - raise e - - # Update conversation stats and context - self._update_stats_and_context( - conv_id=conv_id, - frame=frame, - version=version, - conv_stats=conv_stats, - conv_stats_path=conv_stats_path, - query=query, - answer=answer, - gold_answer=gold_answer, - cur_context=cur_context, - can_answer=can_answer, - ) - # ========= MemOS Scheduler update ========= - _ = client.mem_scheduler.update_working_memory_for_eval( - query=query, user_id=conv_id + "_speaker_a", top_k=top_k - ) - - _ = client.mem_scheduler.update_working_memory_for_eval( - query=query, user_id=conv_id + "_speaker_b", top_k=top_k - ) - return { - "question": query, - "answer": answer, - "category": qa_category, - "golden_answer": gold_answer, - "search_context": cur_context, - "response_duration_ms": response_duration_ms, - "search_duration_ms": search_duration_ms, - "can_answer_duration_ms": can_answer_duration_ms, - "can_answer": can_answer if frame == MEMOS_SCHEDULER_MODEL else None, - } diff --git a/evaluation/scripts/temporal_locomo/modules/README.md 
b/evaluation/scripts/temporal_locomo/modules/README.md deleted file mode 100644 index 31a274dd0..000000000 --- a/evaluation/scripts/temporal_locomo/modules/README.md +++ /dev/null @@ -1,83 +0,0 @@ -# Evaluation Modules - -This directory contains the modularized evaluation system for temporal locomo evaluation, organized using inheritance and composition patterns. - -## Structure - -### Base Classes - -- **`base_eval_module.py`**: Contains the `BaseEvalModule` class with common functionality: - - Statistics management - - Data loading and processing - - File I/O operations - - Basic evaluation methods - -### Specialized Modules - -- **`client_manager.py`**: Contains the `ClientManager` class for managing different memory framework clients: - - Zep client management - - Mem0 client management - - Memos client management - - Memos scheduler client management - -- **`search_modules.py`**: Contains the `SearchModules` class with all search methods: - - `mem0_search()`: Mem0 framework search - - `mem0_graph_search()`: Mem0 graph framework search - - `memos_search()`: Memos framework search - - `memos_scheduler_search()`: Memos scheduler framework search - - `zep_search()`: Zep framework search - -- **`locomo_eval_module.py`**: Contains the main `LocomoEvalModule` class that combines all functionality: - - Inherits from `BaseEvalModule` - - Uses `ClientManager` for client management - - Uses `SearchModules` for search operations - - Provides unified interface for evaluation - -## Usage - -### Basic Usage - -```python -from modules import LocomoEvalModule -import argparse - -# Create arguments -args = argparse.Namespace() -args.frame = 'memos_scheduler' -args.version = 'v0.2.1' -args.top_k = 20 -args.workers = 1 - -# Initialize the evaluation module -eval_module = LocomoEvalModule(args) - -# Use the module -eval_module.print_eval_info() -eval_module.save_stats() -``` - -### Backward Compatibility - -For backward compatibility, the original `LocomoEvalModelModules` class is available as an alias: - -```python -from modules import LocomoEvalModule as LocomoEvalModelModules -``` - -## Benefits of Modularization - -1. **Separation of Concerns**: Each module has a specific responsibility -2. **Maintainability**: Easier to modify and extend individual components -3. **Testability**: Each module can be tested independently -4. **Reusability**: Modules can be reused in different contexts -5. **Readability**: Code is more organized and easier to understand - -## Migration from Original Code - -The original `eval_model_modules.py` has been refactored into this modular structure: - -- **Original class**: `LocomoEvalModelModules` -- **New main class**: `LocomoEvalModule` -- **Backward compatibility**: `LocomoEvalModelModules = LocomoEvalModule` - -All existing functionality is preserved, but now organized in a more maintainable structure. 
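Editorial note: the deleted README's usage snippet stops after `save_stats()`. For reference, the sketch below shows how the removed module was typically driven end to end, reconstructed from the method signatures visible in the deleted `base_eval_module.py` and `locomo_eval_module.py` further down in this diff (`process_user`, `get_answer_hit_rate`, `save_stats`, and the `args` attributes read in `BaseEvalModule.__init__`). It is an illustrative sketch, not part of the original README, and it assumes the locomo data files, the MOS/mem-cube config files, and the API credentials that `base_eval_module.py` loads are already in place.

```python
# Illustrative sketch of driving the deleted evaluation module end to end.
# Argument names mirror the attributes read in BaseEvalModule.__init__;
# external data/config files and credentials are assumed to exist.
import argparse

from modules import LocomoEvalModule  # alias of LocomoEvalModelModules per the README

args = argparse.Namespace(
    frame="memos_scheduler",  # zep | mem0 | mem0_graph | memos | memos_scheduler
    version="v0.2.1",
    top_k=20,
    workers=1,
)

eval_module = LocomoEvalModule(args)
eval_module.print_eval_info()

# Process one conversation (index 0 of the loaded locomo dataframe).
search_results, response_results = eval_module.process_user(
    conv_id=0,
    locomo_df=eval_module.locomo_df,
    frame=args.frame,
    version=args.version,
    top_k=args.top_k,
)

# Hit-rate statistics accumulated while processing, then persist all stats.
print(eval_module.get_answer_hit_rate())
eval_module.save_stats()
```
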
diff --git a/evaluation/scripts/temporal_locomo/modules/__init__.py b/evaluation/scripts/temporal_locomo/modules/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/evaluation/scripts/temporal_locomo/modules/base_eval_module.py b/evaluation/scripts/temporal_locomo/modules/base_eval_module.py deleted file mode 100644 index d056745cc..000000000 --- a/evaluation/scripts/temporal_locomo/modules/base_eval_module.py +++ /dev/null @@ -1,386 +0,0 @@ -import json -import os -import traceback - -from collections import defaultdict -from pathlib import Path -from threading import Lock -from typing import TYPE_CHECKING - -import pandas as pd - -from dotenv import load_dotenv - -from memos.configs.mem_scheduler import AuthConfig -from memos.log import get_logger - -from .constants import ( - BASE_DIR, - MEMOS_SCHEDULER_MODEL, -) -from .prompts import ( - CUSTOM_INSTRUCTIONS, -) -from .schemas import ContextUpdateMethod - - -if TYPE_CHECKING: - from .schemas import RecordingCase - - -logger = get_logger(__name__) - - -class BaseEvalModule: - def __init__(self, args): - # hyper-parameters - self.args = args - self.frame = self.args.frame - self.version = self.args.version - self.workers = self.args.workers - self.top_k = self.args.top_k - - # attributes - self.context_update_method = getattr( - self.args, "context_update_method", ContextUpdateMethod.PRE_CONTEXT - ) - self.custom_instructions = CUSTOM_INSTRUCTIONS - self.data_dir = Path(f"{BASE_DIR}/data") - self.locomo_df = pd.read_json(f"{self.data_dir}/locomo/locomo10.json") - - # Load temporal_locomo dataset if it exists - self.temporal_locomo_data = None - temporal_locomo_file = self.data_dir / "temporal_locomo" / "temporal_locomo_qa.json" - if temporal_locomo_file.exists(): - with open(temporal_locomo_file, encoding="utf-8") as f: - self.temporal_locomo_data = json.load(f) - logger.info( - f"Loaded temporal_locomo dataset with {len(self.temporal_locomo_data)} conversations" - ) - else: - logger.warning(f"Temporal locomo dataset not found at {temporal_locomo_file}") - - result_dir_prefix = getattr(self.args, "result_dir_prefix", "") - - # Configure result dir; if scheduler disabled and using memos scheduler, mark as ablation - if ( - hasattr(self.args, "scheduler_flag") - and self.frame == MEMOS_SCHEDULER_MODEL - and self.args.scheduler_flag is False - ): - self.result_dir = Path( - f"{BASE_DIR}/results/temporal_locomo/{result_dir_prefix}{self.frame}-{self.version}-ablation/" - ) - else: - self.result_dir = Path( - f"{BASE_DIR}/results/temporal_locomo/{result_dir_prefix}{self.frame}-{self.version}/" - ) - - if self.context_update_method != ContextUpdateMethod.PRE_CONTEXT: - self.result_dir = ( - self.result_dir.parent / f"{self.result_dir.name}_{self.context_update_method}" - ) - self.result_dir.mkdir(parents=True, exist_ok=True) - - self.search_path = self.result_dir / f"{self.frame}-{self.version}_search_results.json" - self.response_path = self.result_dir / f"{self.frame}-{self.version}_responses.json" - self.judged_path = self.result_dir / f"{self.frame}-{self.version}_judged.json" - self.grade_path = self.result_dir / f"{self.frame}-{self.version}_grades.json" - self.excel_path = self.result_dir / f"{self.frame}-{self.version}_metrics.xlsx" - - self.ingestion_storage_dir = self.result_dir / "storages" - self.mos_config_path = Path(f"{BASE_DIR}/configs-example/mos_w_scheduler_config.json") - self.mem_cube_config_path = Path(f"{BASE_DIR}/configs-example/mem_cube_config.json") - - self.openai_api_key = 
os.getenv("CHAT_MODEL_API_KEY") - self.openai_base_url = os.getenv("CHAT_MODEL_BASE_URL") - self.openai_chat_model = os.getenv("CHAT_MODEL") - - auth_config_path = Path(f"{BASE_DIR}/scripts/temporal_locomo/eval_auth.json") - if auth_config_path.exists(): - auth_config = AuthConfig.from_local_config(config_path=auth_config_path) - print( - f"✅ Configuration loaded successfully: from local config file {auth_config_path}" - ) - else: - # Load .env file first before reading environment variables - load_dotenv() - auth_config = AuthConfig.from_local_env() - print("✅ Configuration loaded successfully: from environment variables") - self.openai_api_key = auth_config.openai.api_key - self.openai_base_url = auth_config.openai.base_url - self.openai_chat_model = auth_config.openai.default_model - - self.mos_config_data = json.load(self.mos_config_path.open("r", encoding="utf-8")) - self.mem_cube_config_data = json.load(self.mem_cube_config_path.open("r", encoding="utf-8")) - - # Update LLM authentication information in MOS configuration using dictionary assignment - self.mos_config_data["mem_reader"]["config"]["llm"]["config"]["api_key"] = ( - auth_config.openai.api_key - ) - self.mos_config_data["mem_reader"]["config"]["llm"]["config"]["api_base"] = ( - auth_config.openai.base_url - ) - - # Update graph database authentication information in memory cube configuration using dictionary assignment - self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["uri"] = ( - auth_config.graph_db.uri - ) - self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["user"] = ( - auth_config.graph_db.user - ) - self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["password"] = ( - auth_config.graph_db.password - ) - self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["db_name"] = ( - auth_config.graph_db.db_name - ) - self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["auto_create"] = ( - auth_config.graph_db.auto_create - ) - - # Logger initialization - self.logger = logger - - # Statistics tracking with thread safety - self.stats = {self.frame: {self.version: defaultdict(dict)}} - self.stats[self.frame][self.version]["memory_stats"] = defaultdict(dict) - self.stats[self.frame][self.version]["memory_stats"]["total_queries"] = 0 - self.stats[self.frame][self.version]["memory_stats"]["can_answer_count"] = 0 - self.stats[self.frame][self.version]["memory_stats"]["cannot_answer_count"] = 0 - self.stats[self.frame][self.version]["memory_stats"]["answer_hit_rate"] = 0.0 - - # Initialize memory history for tracking retrieval results - self.stats_lock = Lock() - # Reflect CLI flag - self.scheduler_flag = bool(getattr(self.args, "scheduler_flag", True)) - self.stats_dir = self.result_dir / f"stats/{self.frame}_{self.version}" - self.stats_dir.mkdir(parents=True, exist_ok=True) # Ensure the directory exists - self.stats_path = self.stats_dir / "stats.txt" - - self.can_answer_cases: list[RecordingCase] = [] - self.cannot_answer_cases: list[RecordingCase] = [] - - def print_eval_info(self): - """ - Calculate and print the evaluation information including answer statistics for memory scheduler (thread-safe). - Shows total queries, can answer count, cannot answer count, and answer hit rate. 
- """ - with self.stats_lock: - # Get statistics - total_queries = self.stats[self.frame][self.version]["memory_stats"]["total_queries"] - can_answer_count = self.stats[self.frame][self.version]["memory_stats"][ - "can_answer_count" - ] - cannot_answer_count = self.stats[self.frame][self.version]["memory_stats"][ - "cannot_answer_count" - ] - hit_rate = self.stats[self.frame][self.version]["memory_stats"]["answer_hit_rate"] - - # Print basic statistics - print(f"Total Queries: {total_queries}") - logger.info(f"Total Queries: {total_queries}") - - print(f"Can Answer Count: {can_answer_count}") - logger.info(f"Can Answer Count: {can_answer_count}") - - print(f"Cannot Answer Count: {cannot_answer_count}") - logger.info(f"Cannot Answer Count: {cannot_answer_count}") - - # Verify count consistency - if total_queries != (can_answer_count + cannot_answer_count): - print( - f"WARNING: Count mismatch! Total ({total_queries}) != Can Answer ({can_answer_count}) + Cannot Answer ({cannot_answer_count})" - ) - logger.warning( - f"Count mismatch! Total ({total_queries}) != Can Answer ({can_answer_count}) + Cannot Answer ({cannot_answer_count})" - ) - - print(f"Answer Hit Rate: {hit_rate:.2f}% ({can_answer_count}/{total_queries})") - logger.info(f"Answer Hit Rate: {hit_rate:.2f}% ({can_answer_count}/{total_queries})") - - def save_stats(self): - """ - Serializes and saves the contents of self.stats to the specified path: - Base_dir/results/frame-version/stats - - This method handles directory creation, thread-safe access to statistics data, - and proper JSON serialization of complex data structures. - """ - try: - # Thread-safe access to the stats data using the lock - # Create a copy of the data to prevent modification during serialization - stats_data = dict(self.stats) - - # Helper function to convert defaultdict to regular dict for JSON serialization - def convert_defaultdict(obj): - if isinstance(obj, defaultdict): - return dict(obj) - return obj - - # Debug: Print stats summary before saving - self.logger.info(f"DEBUG: Saving stats for {self.frame}-{self.version}") - self.logger.info(f"DEBUG: Stats path: {self.stats_path}") - self.logger.info(f"DEBUG: Stats data keys: {list(stats_data.keys())}") - if self.frame in stats_data and self.version in stats_data[self.frame]: - frame_data = stats_data[self.frame][self.version] - self.logger.info(f"DEBUG: Memory stats: {frame_data.get('memory_stats', {})}") - self.logger.info( - f"DEBUG: Total queries: {frame_data.get('memory_stats', {}).get('total_queries', 0)}" - ) - - # Serialize and save the statistics data to file - with self.stats_path.open("w", encoding="utf-8") as fw: - json.dump(stats_data, fw, ensure_ascii=False, indent=2, default=convert_defaultdict) - - self.logger.info(f"Successfully saved stats to: {self.stats_path}") - print(f"DEBUG: Stats file created at {self.stats_path}") - - except Exception as e: - self.logger.error(f"Failed to save stats: {e!s}") - self.logger.error(traceback.format_exc()) - print(f"DEBUG: Error saving stats: {e}") - - def get_answer_hit_rate(self): - """ - Get current answer hit rate statistics. 
- - Returns: - dict: Hit rate statistics - """ - with self.stats_lock: - return { - "total_queries": self.stats[self.frame][self.version]["memory_stats"][ - "total_queries" - ], - "can_answer_count": self.stats[self.frame][self.version]["memory_stats"][ - "can_answer_count" - ], - "hit_rate_percentage": self.stats[self.frame][self.version]["memory_stats"][ - "answer_hit_rate" - ], - } - - def group_and_sort_qa_by_day(self, qa_set, sort_by_evidence): - """ - Groups QA pairs by day and sorts them chronologically within each day group. - - Args: - qa_set (list): List of dictionaries containing QA data with evidence references - - Returns: - dict: Dictionary where keys are day strings (e.g., 'D1') and values are - lists of QA pairs sorted by evidence order within that day - """ - # Initialize a dictionary that automatically creates lists for new keys - day_groups = defaultdict(list) - - # Process each QA pair in the input dataset - for qa in qa_set: - # Extract all unique days referenced in this QA pair's evidence - days = set() - for evidence in qa["evidence"]: - # Split evidence string (e.g., 'D1:3') into day and position parts - day = evidence.split(":")[0] # Gets 'D1', 'D2', etc. - days.add(day) - - # Add this QA pair to each day group it references - for day in days: - day_groups[day].append(qa) - - if sort_by_evidence: - # Sort QA pairs within each day group by their earliest evidence position - for day in day_groups: - # Create list of (qa, position) pairs for proper sorting - qa_position_pairs = [] - - for qa in day_groups[day]: - # Find the earliest evidence position for this day - earliest_position = None - for evidence in qa["evidence"]: - if evidence.startswith(day + ":"): - try: - position = int(evidence.split(":")[1]) - if earliest_position is None or position < earliest_position: - earliest_position = position - except (IndexError, ValueError): - # Skip invalid evidence format - continue - - if earliest_position is not None: - qa_position_pairs.append((qa, earliest_position)) - - # Sort by evidence position (earliest first) - qa_position_pairs = sorted(qa_position_pairs, key=lambda x: x[1]) - day_groups[day] = [qa for qa, _ in qa_position_pairs] - - return dict(day_groups) - - def convert_locomo_to_temporal_locomo(self, output_dir: str | None = None): - """ - Convert locomo dataset to temporal_locomo dataset format. - - This function processes the original locomo dataset and reorganizes it by days - with proper chronological ordering within each day group. - - Args: - output_dir: Output directory for the converted dataset. 
- Defaults to evaluation/data/temporal_locomo/ - - Returns: - str: Path to the converted dataset file - """ - if output_dir is None: - output_dir = f"{BASE_DIR}/data/temporal_locomo" - - # Create output directory - os.makedirs(output_dir, exist_ok=True) - - # Load original locomo data - locomo_data = self.locomo_df.to_dict("records") - - # Process each conversation - temporal_data = [] - - for conv_id, conversation in enumerate(locomo_data): - logger.info(f"Processing conversation {conv_id + 1}/{len(locomo_data)}") - - # Get QA pairs for this conversation - qa_set = conversation.get("qa", []) - - # Group and sort QA pairs by day - day_groups = self.group_and_sort_qa_by_day(qa_set, sort_by_evidence=False) - - # Create temporal structure for this conversation - temporal_conversation = {"conversation_id": f"locomo_exp_user_{conv_id}", "days": {}} - - # Process each day group - for day, qa_list in day_groups.items(): - temporal_conversation["days"][day] = { - "day_id": day, - "qa_pairs": qa_list, - "total_qa_pairs": len(qa_list), - } - - temporal_data.append(temporal_conversation) - - # Save the converted dataset - output_file = os.path.join(output_dir, "temporal_locomo_qa.json") - with open(output_file, "w", encoding="utf-8") as f: - json.dump(temporal_data, f, indent=2, ensure_ascii=False) - - logger.info(f"Converted dataset saved to: {output_file}") - logger.info(f"Total conversations: {len(temporal_data)}") - - # Log statistics - total_qa_pairs = sum(len(conv["qa"]) for conv in locomo_data) - total_temporal_qa_pairs = sum( - sum(day_data["total_qa_pairs"] for day_data in conv["days"].values()) - for conv in temporal_data - ) - - logger.info(f"Original QA pairs: {total_qa_pairs}") - logger.info(f"Temporal QA pairs: {total_temporal_qa_pairs}") - logger.info("QA pairs may be duplicated across days if they reference multiple days") - - return output_file diff --git a/evaluation/scripts/temporal_locomo/modules/client_manager.py b/evaluation/scripts/temporal_locomo/modules/client_manager.py deleted file mode 100644 index c5882179e..000000000 --- a/evaluation/scripts/temporal_locomo/modules/client_manager.py +++ /dev/null @@ -1,191 +0,0 @@ -""" -Client management module for handling different memory framework clients. -""" - -import os - -from mem0 import MemoryClient -from zep_cloud.client import Zep - -from memos.configs.mem_cube import GeneralMemCubeConfig -from memos.configs.mem_os import MOSConfig -from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube -from memos.mem_os.main import MOS -from memos.mem_scheduler.analyzer.scheduler_for_eval import SchedulerForEval - -from .base_eval_module import BaseEvalModule -from .constants import ( - MEM0_GRAPH_MODEL, - MEM0_MODEL, - MEMOS_MODEL, - MEMOS_SCHEDULER_MODEL, - ZEP_MODEL, -) -from .prompts import ( - ANSWER_PROMPT_MEM0, - ANSWER_PROMPT_MEMOS, - ANSWER_PROMPT_ZEP, -) - - -logger = get_logger(__name__) - - -class EvalModuleWithClientManager(BaseEvalModule): - """ - Manages different memory framework clients for evaluation. 
- """ - - def __init__(self, args): - super().__init__(args=args) - - def get_client_for_ingestion( - self, frame: str, user_id: str | None = None, version: str = "default" - ): - if frame == ZEP_MODEL: - zep = Zep(api_key=os.getenv("ZEP_API_KEY"), base_url="https://api.getzep.com/api/v2") - return zep - - elif frame in (MEM0_MODEL, MEM0_GRAPH_MODEL): - mem0 = MemoryClient(api_key=os.getenv("MEM0_API_KEY")) - mem0.update_project(custom_instructions=self.custom_instructions) - return mem0 - else: - if frame not in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: - raise NotImplementedError(f"Unsupported framework: {frame}") - - # scheduler is not needed in the ingestion step - self.mos_config_data["top_k"] = 20 - self.mos_config_data["enable_mem_scheduler"] = False - - mos_config = MOSConfig(**self.mos_config_data) - mos = MOS(mos_config) - mos.create_user(user_id=user_id) - - self.mem_cube_config_data["user_id"] = user_id - self.mem_cube_config_data["cube_id"] = user_id - self.mem_cube_config_data["text_mem"]["config"]["graph_db"]["config"]["db_name"] = ( - f"{user_id.replace('_', '')}{version}" - ) - mem_cube_config = GeneralMemCubeConfig.model_validate(self.mem_cube_config_data) - mem_cube = GeneralMemCube(mem_cube_config) - - storage_path = str(self.ingestion_storage_dir / user_id) - try: - mem_cube.dump(storage_path) - except Exception as e: - print(f"dumping memory cube: {e!s} already exists, will use it.") - - mos.register_mem_cube( - mem_cube_name_or_path=storage_path, - mem_cube_id=user_id, - user_id=user_id, - ) - - return mos - - def get_client_from_storage( - self, frame: str, user_id: str | None = None, version: str = "default", top_k: int = 20 - ): - """ - Get a client instance for the specified memory framework. - - Args: - frame: Memory framework to use (zep, mem0, mem0_graph, memos, memos_scheduler) - user_id: Unique identifier for the user - version: Version identifier for result storage - top_k: Number of results to retrieve in search queries - - Returns: - Client instance for the specified framework - """ - storage_path = str(self.ingestion_storage_dir / user_id) - - if frame == ZEP_MODEL: - zep = Zep(api_key=os.getenv("ZEP_API_KEY"), base_url="https://api.getzep.com/api/v2") - return zep - - elif frame == [MEM0_MODEL, MEM0_GRAPH_MODEL]: - mem0 = MemoryClient(api_key=os.getenv("MEM0_API_KEY")) - return mem0 - - else: - if frame not in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: - raise NotImplementedError(f"Unsupported framework: {frame}") - - if frame == MEMOS_MODEL: - self.mos_config_data["enable_mem_scheduler"] = False - - self.mos_config_data["top_k"] = top_k - mos_config = MOSConfig(**self.mos_config_data) - mos = MOS(mos_config) - mos.create_user(user_id=user_id) - mos.register_mem_cube( - mem_cube_name_or_path=storage_path, - mem_cube_id=user_id, - user_id=user_id, - ) - - if frame == MEMOS_SCHEDULER_MODEL: - # Configure memory scheduler - mos.mem_scheduler.current_mem_cube = mos.mem_cubes[user_id] - mos.mem_scheduler.current_mem_cube_id = user_id - mos.mem_scheduler.current_user_id = user_id - - # Create SchedulerForEval instance with the same config - scheduler_for_eval = SchedulerForEval(config=mos.mem_scheduler.config) - # Initialize with the same modules as the original scheduler - scheduler_for_eval.initialize_modules( - chat_llm=mos.mem_scheduler.chat_llm, - process_llm=mos.mem_scheduler.process_llm, - db_engine=mos.mem_scheduler.db_engine, - ) - # Set the same context - scheduler_for_eval.current_mem_cube = mos.mem_cubes[user_id] - scheduler_for_eval.current_mem_cube_id = 
user_id - scheduler_for_eval.current_user_id = user_id - - # set llms to openai api - mos.chat_llm = mos.mem_reader.llm - for cube in mos.mem_cubes.values(): - cube.text_mem.dispatcher_llm = mos.mem_reader.llm - cube.text_mem.extractor_llm = mos.mem_reader.llm - - # Replace the original scheduler - mos.mem_scheduler = scheduler_for_eval - return mos - - def locomo_response(self, frame, llm_client, context: str, question: str) -> str: - if frame == ZEP_MODEL: - prompt = ANSWER_PROMPT_ZEP.format( - context=context, - question=question, - ) - elif frame in (MEM0_MODEL, MEM0_GRAPH_MODEL): - prompt = ANSWER_PROMPT_MEM0.format( - context=context, - question=question, - ) - elif frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: - prompt = ANSWER_PROMPT_MEMOS.format( - context=context, - question=question, - ) - else: - raise NotImplementedError() - response = llm_client.chat.completions.create( - model=self.openai_chat_model, - messages=[ - {"role": "system", "content": prompt}, - ], - temperature=0, - ) - - result = response.choices[0].message.content or "" - - if result == "": - with self.stats_lock: - self.stats[self.frame][self.version]["response_stats"]["response_failure"] += 1 - self.stats[self.frame][self.version]["response_stats"]["response_count"] += 1 - return result diff --git a/evaluation/scripts/temporal_locomo/modules/constants.py b/evaluation/scripts/temporal_locomo/modules/constants.py deleted file mode 100644 index 51ad7c729..000000000 --- a/evaluation/scripts/temporal_locomo/modules/constants.py +++ /dev/null @@ -1,19 +0,0 @@ -import sys - -from pathlib import Path - -from memos.log import get_logger - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -ZEP_MODEL = "zep" -MEM0_MODEL = "mem0" -MEM0_GRAPH_MODEL = "mem0_graph" -MEMOS_MODEL = "memos" -MEMOS_SCHEDULER_MODEL = "memos_scheduler" diff --git a/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py b/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py deleted file mode 100644 index d444ea62c..000000000 --- a/evaluation/scripts/temporal_locomo/modules/locomo_eval_module.py +++ /dev/null @@ -1,578 +0,0 @@ -import json -import time -import traceback - -from collections import defaultdict -from datetime import datetime -from typing import TYPE_CHECKING - -from openai import OpenAI -from tqdm import tqdm - -from memos.log import get_logger - -from .client_manager import EvalModuleWithClientManager -from .constants import ( - MEM0_GRAPH_MODEL, - MEM0_MODEL, - MEMOS_MODEL, - MEMOS_SCHEDULER_MODEL, - ZEP_MODEL, -) -from .prompts import ( - CONTEXT_ANSWERABILITY_PROMPT, - SEARCH_PROMPT_MEM0, - SEARCH_PROMPT_MEM0_GRAPH, - SEARCH_PROMPT_MEMOS, - SEARCH_PROMPT_ZEP, -) -from .utils import filter_memory_data - - -if TYPE_CHECKING: - from memos.mem_os.main import MOS -logger = get_logger(__name__) - - -class LocomoEvalModelModules(EvalModuleWithClientManager): - """ - Contains search methods for different memory frameworks. - """ - - def __init__(self, args): - super().__init__(args=args) - self.pre_context_cache = {} - - def analyze_context_answerability(self, context, query, gold_answer, oai_client): - """ - Analyze whether the given context can answer the query. 
- - Args: - context: The context string to analyze - query: The query string - oai_client: OpenAI client for LLM analysis - - Returns: - bool: True if context can answer the query, False otherwise - """ - try: - prompt = CONTEXT_ANSWERABILITY_PROMPT.format( - context=context, question=query, gold_answer=str(gold_answer) - ) - - response = oai_client.chat.completions.create( - model="gpt-4o-mini", - messages=[{"role": "user", "content": prompt}], - temperature=0.1, - max_tokens=10, - ) - - answer = response.choices[0].message.content.strip().upper() - return answer == "YES" - except Exception as e: - logger.error(f"Error analyzing context answerability: {e}") - return False - - def mem0_search(self, client, query, speaker_a_user_id, speaker_b_user_id, top_k=20): - """ - Search memories using the mem0 framework. - - Args: - client: mem0 client instance - query: Search query string - speaker_a_user_id: User ID for first speaker - speaker_b_user_id: User ID for second speaker - top_k: Number of results to retrieve - - Returns: - Tuple containing formatted context and search duration in milliseconds - """ - start = time.time() - search_speaker_a_results = client.search( - query=query, - top_k=top_k, - user_id=speaker_a_user_id, - output_format="v1.1", - version="v2", - filters={"AND": [{"user_id": f"{speaker_a_user_id}"}, {"run_id": "*"}]}, - ) - search_speaker_b_results = client.search( - query=query, - top_k=top_k, - user_id=speaker_b_user_id, - output_format="v1.1", - version="v2", - filters={"AND": [{"user_id": f"{speaker_b_user_id}"}, {"run_id": "*"}]}, - ) - - # Format speaker A memories - search_speaker_a_memory = [ - { - "memory": memory["memory"], - "timestamp": memory["created_at"], - "score": round(memory["score"], 2), - } - for memory in search_speaker_a_results["results"] - ] - - search_speaker_a_memory = [ - [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_a_memory] - ] - - # Format speaker B memories - search_speaker_b_memory = [ - { - "memory": memory["memory"], - "timestamp": memory["created_at"], - "score": round(memory["score"], 2), - } - for memory in search_speaker_b_results["results"] - ] - - search_speaker_b_memory = [ - [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_b_memory] - ] - - # Create context using template - context = SEARCH_PROMPT_MEM0.format( - speaker_1_user_id=speaker_a_user_id.split("_")[0], - speaker_1_memories=json.dumps(search_speaker_a_memory, indent=4), - speaker_2_user_id=speaker_b_user_id.split("_")[0], - speaker_2_memories=json.dumps(search_speaker_b_memory, indent=4), - ) - - duration_ms = (time.time() - start) * 1000 - return context, duration_ms - - def memos_search( - self, client, query, conv_id, speaker_a, speaker_b, reversed_client=None, top_k=20 - ): - """ - Search memories using the memos framework. 
- - Args: - client: memos client instance - query: Search query string - conv_id: Conversation ID - speaker_a: First speaker identifier - speaker_b: Second speaker identifier - reversed_client: Client instance for reversed speaker context - - Returns: - Tuple containing formatted context and search duration in milliseconds - """ - start = time.time() - # Search memories for speaker A - search_a_results = client.search(query=query, user_id=conv_id + "_speaker_a") - filtered_search_a_results = filter_memory_data(search_a_results)["text_mem"][0]["memories"] - speaker_a_context = "" - for item in filtered_search_a_results[:top_k]: - speaker_a_context += f"{item['memory']}\n" - - # Search memories for speaker B - search_b_results = reversed_client.search( - query=query, - user_id=conv_id + "_speaker_b", - ) - filtered_search_b_results = filter_memory_data(search_b_results)["text_mem"][0]["memories"] - speaker_b_context = "" - for item in filtered_search_b_results[:top_k]: - speaker_b_context += f"{item['memory']}\n" - - # Create context using template - context = SEARCH_PROMPT_MEMOS.format( - speaker_1=speaker_a, - speaker_1_memories=speaker_a_context, - speaker_2=speaker_b, - speaker_2_memories=speaker_b_context, - ) - - duration_ms = (time.time() - start) * 1000 - return context, duration_ms - - def memos_scheduler_search( - self, client, query, conv_id, speaker_a, speaker_b, reversed_client=None, top_k=20 - ): - start = time.time() - client: MOS = client - - if not self.scheduler_flag: - # if not scheduler_flag, search to update working memory - self.memos_search(client, query, conv_id, speaker_a, speaker_b, reversed_client, top_k) - - # Search for speaker A - search_a_results = client.mem_scheduler.search_for_eval( - query=query, - user_id=conv_id + "_speaker_a", - top_k=top_k, - scheduler_flag=self.scheduler_flag, - ) - - # Search for speaker B - search_b_results = reversed_client.mem_scheduler.search_for_eval( - query=query, - user_id=conv_id + "_speaker_b", - top_k=top_k, - scheduler_flag=self.scheduler_flag, - ) - - speaker_a_context = "" - for item in search_a_results: - speaker_a_context += f"{item}\n" - - speaker_b_context = "" - for item in search_b_results: - speaker_b_context += f"{item}\n" - - context = SEARCH_PROMPT_MEMOS.format( - speaker_1=speaker_a, - speaker_1_memories=speaker_a_context, - speaker_2=speaker_b, - speaker_2_memories=speaker_b_context, - ) - - logger.info(f'query "{query[:100]}", context: {context[:100]}"') - duration_ms = (time.time() - start) * 1000 - - return context, duration_ms - - def mem0_graph_search(self, client, query, speaker_a_user_id, speaker_b_user_id, top_k=20): - start = time.time() - search_speaker_a_results = client.search( - query=query, - top_k=top_k, - user_id=speaker_a_user_id, - output_format="v1.1", - version="v2", - enable_graph=True, - filters={"AND": [{"user_id": f"{speaker_a_user_id}"}, {"run_id": "*"}]}, - ) - search_speaker_b_results = client.search( - query=query, - top_k=top_k, - user_id=speaker_b_user_id, - output_format="v1.1", - version="v2", - enable_graph=True, - filters={"AND": [{"user_id": f"{speaker_b_user_id}"}, {"run_id": "*"}]}, - ) - - search_speaker_a_memory = [ - { - "memory": memory["memory"], - "timestamp": memory["created_at"], - "score": round(memory["score"], 2), - } - for memory in search_speaker_a_results["results"] - ] - - search_speaker_a_memory = [ - [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_a_memory] - ] - - search_speaker_b_memory = [ - { - "memory": memory["memory"], - 
"timestamp": memory["created_at"], - "score": round(memory["score"], 2), - } - for memory in search_speaker_b_results["results"] - ] - - search_speaker_b_memory = [ - [f"{item['timestamp']}: {item['memory']}" for item in search_speaker_b_memory] - ] - - search_speaker_a_graph = [ - { - "source": relation["source"], - "relationship": relation["relationship"], - "target": relation["target"], - } - for relation in search_speaker_a_results["relations"] - ] - - search_speaker_b_graph = [ - { - "source": relation["source"], - "relationship": relation["relationship"], - "target": relation["target"], - } - for relation in search_speaker_b_results["relations"] - ] - context = SEARCH_PROMPT_MEM0_GRAPH.format( - speaker_1_user_id=speaker_a_user_id.split("_")[0], - speaker_1_memories=json.dumps(search_speaker_a_memory, indent=4), - speaker_1_graph_memories=json.dumps(search_speaker_a_graph, indent=4), - speaker_2_user_id=speaker_b_user_id.split("_")[0], - speaker_2_memories=json.dumps(search_speaker_b_memory, indent=4), - speaker_2_graph_memories=json.dumps(search_speaker_b_graph, indent=4), - ) - print(query, context) - duration_ms = (time.time() - start) * 1000 - return context, duration_ms - - def zep_search(self, client, query, group_id, top_k=20): - start = time.time() - nodes_result = client.graph.search( - query=query, - group_id=group_id, - scope="nodes", - reranker="rrf", - limit=top_k, - ) - edges_result = client.graph.search( - query=query, - group_id=group_id, - scope="edges", - reranker="cross_encoder", - limit=top_k, - ) - - nodes = nodes_result.nodes - edges = edges_result.edges - - facts = [f" - {edge.fact} (event_time: {edge.valid_at})" for edge in edges] - entities = [f" - {node.name}: {node.summary}" for node in nodes] - - context = SEARCH_PROMPT_ZEP.format(facts="\n".join(facts), entities="\n".join(entities)) - - duration_ms = (time.time() - start) * 1000 - - return context, duration_ms - - def search_query(self, client, query, metadata, frame, reversed_client=None, top_k=20): - conv_id = metadata.get("conv_id") - speaker_a = metadata.get("speaker_a") - speaker_b = metadata.get("speaker_b") - speaker_a_user_id = metadata.get("speaker_a_user_id") - speaker_b_user_id = metadata.get("speaker_b_user_id") - - if frame == ZEP_MODEL: - context, duration_ms = self.zep_search(client, query, conv_id, top_k) - elif frame == MEM0_MODEL: - context, duration_ms = self.mem0_search( - client, query, speaker_a_user_id, speaker_b_user_id, top_k - ) - elif frame == MEM0_GRAPH_MODEL: - context, duration_ms = self.mem0_graph_search( - client, query, speaker_a_user_id, speaker_b_user_id, top_k - ) - elif frame == MEMOS_MODEL: - context, duration_ms = self.memos_search( - client, query, conv_id, speaker_a, speaker_b, reversed_client, top_k - ) - elif frame == MEMOS_SCHEDULER_MODEL: - context, duration_ms = self.memos_scheduler_search( - client, query, conv_id, speaker_a, speaker_b, reversed_client, top_k - ) - else: - raise NotImplementedError() - - return context, duration_ms - - def _initialize_conv_stats(self): - """Create a fresh statistics dictionary for a conversation.""" - return { - "total_queries": 0, - "can_answer_count": 0, - "cannot_answer_count": 0, - "answer_hit_rate": 0.0, - "response_failure": 0, - "response_count": 0, - } - - def _build_day_groups(self, temporal_conv): - """Build mapping day_id -> qa_pairs from a temporal conversation dict.""" - day_groups = {} - for day_id, day_data in temporal_conv.get("days", {}).items(): - day_groups[day_id] = day_data.get("qa_pairs", []) - return 
day_groups - - def _build_metadata(self, speaker_a, speaker_b, speaker_a_user_id, speaker_b_user_id, conv_id): - """Assemble metadata for downstream calls.""" - return { - "speaker_a": speaker_a, - "speaker_b": speaker_b, - "speaker_a_user_id": speaker_a_user_id, - "speaker_b_user_id": speaker_b_user_id, - "conv_id": conv_id, - } - - def _get_clients(self, frame, speaker_a_user_id, speaker_b_user_id, conv_id, version, top_k): - """Return (client, reversed_client) according to the target frame.""" - reversed_client = None - if frame in [MEMOS_MODEL, MEMOS_SCHEDULER_MODEL]: - client = self.get_client_from_storage(frame, speaker_a_user_id, version, top_k=top_k) - reversed_client = self.get_client_from_storage( - frame, speaker_b_user_id, version, top_k=top_k - ) - else: - client = self.get_client_from_storage(frame, conv_id, version) - return client, reversed_client - - def _save_conv_stats(self, conv_id, frame, version, conv_stats, conv_stats_path): - """Persist per-conversation stats to disk.""" - conv_stats_data = { - "conversation_id": conv_id, - "frame": frame, - "version": version, - "statistics": conv_stats, - "timestamp": str(datetime.now()), - } - with open(conv_stats_path, "w") as fw: - json.dump(conv_stats_data, fw, indent=2, ensure_ascii=False) - print(f"Saved conversation stats for {conv_id} to {conv_stats_path}") - - def _write_user_search_results(self, user_search_path, search_results, conv_id): - """Write per-user search results to a temporary JSON file.""" - with open(user_search_path, "w") as fw: - json.dump(dict(search_results), fw, indent=2) - print(f"Save search results {conv_id}") - - def process_user(self, conv_id, locomo_df, frame, version, top_k=20): - user_search_path = self.result_dir / f"tmp/{frame}_locomo_search_results_{conv_id}.json" - user_search_path.parent.mkdir(exist_ok=True, parents=True) - search_results = defaultdict(list) - response_results = defaultdict(list) - conv_stats_path = self.stats_dir / f"{frame}_{version}_conv_{conv_id}_stats.json" - - conversation = locomo_df["conversation"].iloc[conv_id] - speaker_a = conversation.get("speaker_a", "speaker_a") - speaker_b = conversation.get("speaker_b", "speaker_b") - - # Use temporal_locomo data if available, otherwise fall back to original locomo data - temporal_conv = self.temporal_locomo_data[conv_id] - conv_id = temporal_conv["conversation_id"] - speaker_a_user_id = f"{conv_id}_speaker_a" - speaker_b_user_id = f"{conv_id}_speaker_b" - - # Process temporal data by days - day_groups = {} - for day_id, day_data in temporal_conv["days"].items(): - day_groups[day_id] = day_data["qa_pairs"] - - # Initialize conversation-level statistics - conv_stats = self._initialize_conv_stats() - - metadata = self._build_metadata( - speaker_a, speaker_b, speaker_a_user_id, speaker_b_user_id, conv_id - ) - - client, reversed_client = self._get_clients( - frame, speaker_a_user_id, speaker_b_user_id, conv_id, version, top_k - ) - - oai_client = OpenAI(api_key=self.openai_api_key, base_url=self.openai_base_url) - - with self.stats_lock: - self.pre_context_cache[conv_id] = None - - def process_qa(qa): - return self._process_single_qa( - qa, - client=client, - reversed_client=reversed_client, - metadata=metadata, - frame=frame, - version=version, - conv_id=conv_id, - conv_stats_path=conv_stats_path, - oai_client=oai_client, - top_k=top_k, - conv_stats=conv_stats, - ) - - # =================================== - conv_stats["theoretical_total_queries"] = 0 - for day, qa_list in day_groups.items(): - 
conv_stats["theoretical_total_queries"] += len(qa_list) - 1 - conv_stats["processing_failure_count"] = 0 - print(f"Processing user {conv_id} day {day}") - for qa in tqdm(qa_list, desc=f"Processing user {conv_id} day {day}"): - try: - result = process_qa(qa) - except Exception as e: - logger.error(f"Error: {e}. traceback: {traceback.format_exc()}") - conv_stats["processing_failure_count"] += 1 - continue - if result: - context_preview = ( - result["search_context"][:20] + "..." - if result["search_context"] - else "No context" - ) - if "can_answer" in result: - logger.info("Print can_answer case") - logger.info( - { - "question": result["question"][:100], - "pre context can answer": result["can_answer"], - "answer": result["answer"][:100], - "golden_answer": result["golden_answer"], - "search_context": context_preview[:100], - "search_duration_ms": result["search_duration_ms"], - } - ) - - search_results[conv_id].append( - { - "question": result["question"], - "context": result["search_context"], - "search_duration_ms": result["search_duration_ms"], - } - ) - response_results[conv_id].append(result) - - logger.warning( - f"Finished processing user {conv_id} day {day}, data_length: {len(qa_list)}" - ) - - # recording separate search results - with open(user_search_path, "w") as fw: - json.dump(dict(search_results), fw, indent=2) - print(f"Save search results {conv_id}") - - search_durations = [] - for result in response_results[conv_id]: - if "search_duration_ms" in result: - search_durations.append(result["search_duration_ms"]) - - if search_durations: - avg_search_duration = sum(search_durations) / len(search_durations) - with self.stats_lock: - if self.stats[self.frame][self.version]["memory_stats"]["avg_search_duration_ms"]: - self.stats[self.frame][self.version]["memory_stats"][ - "avg_search_duration_ms" - ] = ( - self.stats[self.frame][self.version]["memory_stats"][ - "avg_search_duration_ms" - ] - + avg_search_duration - ) / 2 - print(f"Average search duration: {avg_search_duration:.2f} ms") - - # Dump stats after processing each user - self.save_stats() - - return search_results, response_results - - def process_user_wrapper(self, args): - """ - Wraps the process_user function to support parallel execution and error handling. - - Args: - args: Tuple containing parameters for process_user - - Returns: - tuple: Contains user results or error information - """ - idx, locomo_df, frame, version, top_k = args - try: - print(f"Processing user {idx}...") - user_search_results, user_response_results = self.process_user( - idx, locomo_df, frame, version, top_k - ) - return (user_search_results, user_response_results, None) - except Exception as e: - return (None, None, (idx, e, traceback.format_exc())) diff --git a/evaluation/scripts/temporal_locomo/modules/prompts.py b/evaluation/scripts/temporal_locomo/modules/prompts.py deleted file mode 100644 index c88a8ff28..000000000 --- a/evaluation/scripts/temporal_locomo/modules/prompts.py +++ /dev/null @@ -1,219 +0,0 @@ -CUSTOM_INSTRUCTIONS = """ -Generate personal memories that follow these guidelines: - -1. Each memory should be self-contained with complete context, including: - - The person's name, do not use "user" while creating memories - - Personal details (career aspirations, hobbies, life circumstances) - - Emotional states and reactions - - Ongoing journeys or future plans - - Specific dates when events occurred - -2. 
Include meaningful personal narratives focusing on: - - Identity and self-acceptance journeys - - Family planning and parenting - - Creative outlets and hobbies - - Mental health and self-care activities - - Career aspirations and education goals - - Important life events and milestones - -3. Make each memory rich with specific details rather than general statements - - Include timeframes (exact dates when possible) - - Name specific activities (e.g., "charity race for mental health" rather than just "exercise") - - Include emotional context and personal growth elements - -4. Extract memories only from user messages, not incorporating assistant responses - -5. Format each memory as a paragraph with a clear narrative structure that captures the person's experience, challenges, and aspirations -""" - -SEARCH_PROMPT_ZEP = """ -FACTS and ENTITIES represent relevant context to the current conversation. - -# These are the most relevant facts for the conversation along with the datetime of the event that the fact refers to. -If a fact mentions something happening a week ago, then the datetime will be the date time of last week and not the datetime -of when the fact was stated. -Timestamps in memories represent the actual time the event occurred, not the time the event was mentioned in a message. - - -{facts} - - -# These are the most relevant entities -# ENTITY_NAME: entity summary - -{entities} - -""" - -SEARCH_PROMPT_MEM0 = """Memories for user {speaker_1_user_id}: - - {speaker_1_memories} - - Memories for user {speaker_2_user_id}: - - {speaker_2_memories} -""" - -SEARCH_PROMPT_MEM0_GRAPH = """Memories for user {speaker_1_user_id}: - - {speaker_1_memories} - - Relations for user {speaker_1_user_id}: - - {speaker_1_graph_memories} - - Memories for user {speaker_2_user_id}: - - {speaker_2_memories} - - Relations for user {speaker_2_user_id}: - - {speaker_2_graph_memories} -""" - -SEARCH_PROMPT_MEMOS = """Memories for user {speaker_1}: - - {speaker_1_memories} - - Memories for user {speaker_2}: - - {speaker_2_memories} -""" - - -ANSWER_PROMPT_MEM0 = """ - You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories. - - # CONTEXT: - You have access to memories from two speakers in a conversation. These memories contain - timestamped information that may be relevant to answering the question. - - # INSTRUCTIONS: - 1. Carefully analyze all provided memories from both speakers - 2. Pay special attention to the timestamps to determine the answer - 3. If the question asks about a specific event or fact, look for direct evidence in the memories - 4. If the memories contain contradictory information, prioritize the most recent memory - 5. If there is a question about time references (like "last year", "two months ago", etc.), - calculate the actual date based on the memory timestamp. For example, if a memory from - 4 May 2022 mentions "went to India last year," then the trip occurred in 2021. - 6. Always convert relative time references to specific dates, months, or years. For example, - convert "last year" to "2022" or "two months ago" to "March 2023" based on the memory - timestamp. Ignore the reference while answering the question. - 7. Focus only on the content of the memories from both speakers. Do not confuse character - names mentioned in memories with the actual users who created those memories. - 8. The answer should be less than 5-6 words. - - # APPROACH (Think step by step): - 1. 
First, examine all memories that contain information related to the question - 2. Examine the timestamps and content of these memories carefully - 3. Look for explicit mentions of dates, times, locations, or events that answer the question - 4. If the answer requires calculation (e.g., converting relative time references), show your work - 5. Formulate a precise, concise answer based solely on the evidence in the memories - 6. Double-check that your answer directly addresses the question asked - 7. Ensure your final answer is specific and avoids vague time references - - {context} - - Question: {question} - - Answer: - """ - - -ANSWER_PROMPT_ZEP = """ - You are an intelligent memory assistant tasked with retrieving accurate information from conversation memories. - - # CONTEXT: - You have access to memories from a conversation. These memories contain - timestamped information that may be relevant to answering the question. - - # INSTRUCTIONS: - 1. Carefully analyze all provided memories - 2. Pay special attention to the timestamps to determine the answer - 3. If the question asks about a specific event or fact, look for direct evidence in the memories - 4. If the memories contain contradictory information, prioritize the most recent memory - 5. If there is a question about time references (like "last year", "two months ago", etc.), - calculate the actual date based on the memory timestamp. For example, if a memory from - 4 May 2022 mentions "went to India last year," then the trip occurred in 2021. - 6. Always convert relative time references to specific dates, months, or years. For example, - convert "last year" to "2022" or "two months ago" to "March 2023" based on the memory - timestamp. Ignore the reference while answering the question. - 7. Focus only on the content of the memories. Do not confuse character - names mentioned in memories with the actual users who created those memories. - 8. The answer should be less than 5-6 words. - - # APPROACH (Think step by step): - 1. First, examine all memories that contain information related to the question - 2. Examine the timestamps and content of these memories carefully - 3. Look for explicit mentions of dates, times, locations, or events that answer the question - 4. If the answer requires calculation (e.g., converting relative time references), show your work - 5. Formulate a precise, concise answer based solely on the evidence in the memories - 6. Double-check that your answer directly addresses the question asked - 7. Ensure your final answer is specific and avoids vague time references - - Context: - - {context} - - Question: {question} - Answer: - """ - -ANSWER_PROMPT_MEMOS = """ - You are a knowledgeable and helpful AI assistant. - - # CONTEXT: - You have access to memories from two speakers in a conversation. These memories contain - timestamped information that may be relevant to answering the question. - - # INSTRUCTIONS: - 1. Carefully analyze all provided memories. Synthesize information across different entries if needed to form a complete answer. - 2. Pay close attention to the timestamps to determine the answer. If memories contain contradictory information, the **most recent memory** is the source of truth. - 3. If the question asks about a specific event or fact, look for direct evidence in the memories. - 4. Your answer must be grounded in the memories. However, you may use general world knowledge to interpret or complete information found within a memory (e.g., identifying a landmark mentioned by description). - 5. 
If the question involves time references (like "last year", "two months ago", etc.), you **must** calculate the actual date based on the memory's timestamp. For example, if a memory from 4 May 2022 mentions "went to India last year," then the trip occurred in 2021. - 6. Always convert relative time references to specific dates, months, or years in your final answer. - 7. Do not confuse character names mentioned in memories with the actual users who created them. - 8. The answer must be brief (under 5-6 words) and direct, with no extra description. - - # APPROACH (Think step by step): - 1. First, examine all memories that contain information related to the question. - 2. Synthesize findings from multiple memories if a single entry is insufficient. - 3. Examine timestamps and content carefully, looking for explicit dates, times, locations, or events. - 4. If the answer requires calculation (e.g., converting relative time references), perform the calculation. - 5. Formulate a precise, concise answer based on the evidence from the memories (and allowed world knowledge). - 6. Double-check that your answer directly addresses the question asked and adheres to all instructions. - 7. Ensure your final answer is specific and avoids vague time references. - - {context} - - Question: {question} - - Answer: - """ - -CONTEXT_ANSWERABILITY_PROMPT = """ -You are an AI assistant that analyzes whether given context can answer a specific question, considering the ground-truth answer. - -# TASK: -Analyze the provided context and determine if it contains sufficient information to answer the given question. Use the provided ground-truth answer to guide your judgment: if the context contains the necessary evidence to derive that answer (explicitly or via direct inference), respond YES; otherwise respond NO. - -# INSTRUCTIONS: -1. Carefully examine the context provided -2. Identify if the context contains information directly related to the question -3. Determine if the information is sufficient to provide a complete answer that matches the ground-truth -4. Consider both explicit mentions and straightforward implications present in the context -5. Return only "YES" if the context can yield the ground-truth answer, "NO" if it cannot - -# CONTEXT: -{context} - -# QUESTION: -{question} - -# GROUND_TRUTH_ANSWER: -{gold_answer} - -# ANALYSIS: -Can this context answer the question and support the ground-truth answer? (YES/NO): -""" diff --git a/evaluation/scripts/temporal_locomo/modules/schemas.py b/evaluation/scripts/temporal_locomo/modules/schemas.py deleted file mode 100644 index fee89cc62..000000000 --- a/evaluation/scripts/temporal_locomo/modules/schemas.py +++ /dev/null @@ -1,161 +0,0 @@ -from typing import Any - -from pydantic import BaseModel, Field - - -class ContextUpdateMethod: - """Enumeration for context update methods""" - - PRE_CONTEXT = "pre_context" - CHAT_HISTORY = "chat_history" - CURRENT_CONTEXT = "current_context" - - @classmethod - def values(cls): - """Return a list of all constant values""" - return [ - getattr(cls, attr) - for attr in dir(cls) - if not attr.startswith("_") and isinstance(getattr(cls, attr), str) - ] - - -class RecordingCase(BaseModel): - """ - Data structure for recording evaluation cases in temporal locomo evaluation. - - This schema represents a single evaluation case containing conversation history, - context information, memory data, and evaluation results. 
- """ - - # Conversation identification - conv_id: str = Field(description="Conversation identifier for this evaluation case") - - context: str = Field( - default="", - description="Current search context retrieved from memory systems for answering the query", - ) - - pre_context: str | None = Field( - default=None, - description="Previous context from the last query, used for answerability analysis", - ) - - # Query and answer information - query: str = Field(description="The current question/query being evaluated") - - answer: str = Field(description="The generated answer for the query") - - # Evaluation metrics - can_answer: bool | None = Field( - default=None, - description="Whether the context can answer the query (only for memos_scheduler frame)", - ) - - can_answer_reason: str | None = Field( - default=None, description="Reasoning for the can_answer decision" - ) - - # Additional metadata - category: int | None = Field( - default=None, description="Category of the query (1-4, where 5 is filtered out)" - ) - - golden_answer: str | None = Field( - default=None, description="Ground truth answer for evaluation" - ) - - search_duration_ms: float | None = Field( - default=None, description="Time taken for memory search in milliseconds" - ) - - response_duration_ms: float | None = Field( - default=None, description="Time taken for response generation in milliseconds" - ) - - can_answer_duration_ms: float | None = Field( - default=None, description="Time taken for answerability analysis in milliseconds" - ) - - def to_dict(self) -> dict[str, Any]: - """ - Convert the RecordingCase to a dictionary for serialization. - - Returns: - Dict[str, Any]: Dictionary representation of the RecordingCase - """ - return self.dict() - - def to_json(self, indent: int = 2) -> str: - """ - Convert the RecordingCase to a JSON string. - - Args: - indent: JSON indentation level - - Returns: - str: JSON string representation of the RecordingCase - """ - return self.json(indent=indent, ensure_ascii=False) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> "RecordingCase": - """ - Create a RecordingCase from a dictionary. - - Args: - data: Dictionary containing RecordingCase data - - Returns: - RecordingCase: New instance created from the dictionary - """ - return cls(**data) - - @classmethod - def from_json(cls, json_str: str) -> "RecordingCase": - """ - Create a RecordingCase from a JSON string. 
- - Args: - json_str: JSON string containing RecordingCase data - - Returns: - RecordingCase: New instance created from the JSON string - """ - import json - - data = json.loads(json_str) - return cls.from_dict(data) - - class Config: - """Pydantic configuration""" - - extra = "allow" # Allow additional fields not defined in the schema - validate_assignment = True # Validate on assignment - use_enum_values = True # Use enum values instead of enum names - - -class TimeEvalRecordingCase(BaseModel): - memos_search_duration_ms: float | None = Field( - default=None, description="Time taken for memory search in milliseconds" - ) - - memos_response_duration_ms: float | None = Field( - default=None, description="Time taken for response generation in milliseconds" - ) - - memos_can_answer_duration_ms: float | None = Field( - default=None, description="Time taken for answerability analysis in milliseconds" - ) - - scheduler_search_duration_ms: float | None = Field( - default=None, description="Time taken for memory search in milliseconds" - ) - - scheduler_response_duration_ms: float | None = Field( - default=None, description="Time taken for response generation in milliseconds" - ) - - scheduler_can_answer_duration_ms: float | None = Field( - default=None, description="Time taken for answerability analysis in milliseconds" - ) diff --git a/evaluation/scripts/temporal_locomo/modules/utils.py b/evaluation/scripts/temporal_locomo/modules/utils.py deleted file mode 100644 index 215bc4256..000000000 --- a/evaluation/scripts/temporal_locomo/modules/utils.py +++ /dev/null @@ -1,296 +0,0 @@ -import json - -from pathlib import Path - -from .schemas import RecordingCase - - -def filter_memory_data(memories_data): - filtered_data = {} - for key, value in memories_data.items(): - if key == "text_mem": - filtered_data[key] = [] - for mem_group in value: - # Check if it's the new data structure (list of TextualMemoryItem objects) - if "memories" in mem_group and isinstance(mem_group["memories"], list): - # New data structure: directly a list of TextualMemoryItem objects - filtered_memories = [] - for memory_item in mem_group["memories"]: - # Create filtered dictionary - filtered_item = { - "id": memory_item.id, - "memory": memory_item.memory, - "metadata": {}, - } - # Filter metadata, excluding embedding - if hasattr(memory_item, "metadata") and memory_item.metadata: - for attr_name in dir(memory_item.metadata): - if not attr_name.startswith("_") and attr_name != "embedding": - attr_value = getattr(memory_item.metadata, attr_name) - if not callable(attr_value): - filtered_item["metadata"][attr_name] = attr_value - filtered_memories.append(filtered_item) - - filtered_group = { - "cube_id": mem_group.get("cube_id", ""), - "memories": filtered_memories, - } - filtered_data[key].append(filtered_group) - else: - # Old data structure: dictionary with nodes and edges - filtered_group = { - "memories": {"nodes": [], "edges": mem_group["memories"].get("edges", [])} - } - for node in mem_group["memories"].get("nodes", []): - filtered_node = { - "id": node.get("id"), - "memory": node.get("memory"), - "metadata": { - k: v - for k, v in node.get("metadata", {}).items() - if k != "embedding" - }, - } - filtered_group["memories"]["nodes"].append(filtered_node) - filtered_data[key].append(filtered_group) - else: - filtered_data[key] = value - return filtered_data - - -def save_recording_cases( - cases: list[RecordingCase], output_dir: str | Path, filename: str = "recording_cases.json" -) -> Path: - """ - Save a list of 
RecordingCase objects to a JSON file. - - Args: - cases: List of RecordingCase objects to save - output_dir: Directory to save the file - filename: Name of the output file (default: "recording_cases.json") - - Returns: - Path: Path to the saved file - """ - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - file_path = output_dir / filename - - # Convert cases to dictionaries for JSON serialization - cases_data = [case.to_dict() for case in cases] - - with open(file_path, "w", encoding="utf-8") as f: - json.dump(cases_data, f, indent=2, ensure_ascii=False) - - return file_path - - -def load_recording_cases(file_path: str | Path) -> list[RecordingCase]: - """ - Load RecordingCase objects from a JSON file. - - Args: - file_path: Path to the JSON file containing RecordingCase data - - Returns: - List[RecordingCase]: List of RecordingCase objects loaded from the file - """ - file_path = Path(file_path) - - with open(file_path, encoding="utf-8") as f: - cases_data = json.load(f) - - return [RecordingCase.from_dict(case_data) for case_data in cases_data] - - -def save_evaluation_cases( - can_answer_cases: list[RecordingCase], - cannot_answer_cases: list[RecordingCase], - output_dir: str | Path, - frame: str = "default", - version: str = "default", -) -> dict[str, Path]: - """ - Save both can_answer_cases and cannot_answer_cases to separate JSON files. - - Args: - can_answer_cases: List of cases that can be answered - cannot_answer_cases: List of cases that cannot be answered - output_dir: Directory to save the files - frame: Framework name for filename prefix - version: Version identifier for filename - - Returns: - Dict[str, Path]: Dictionary mapping case type to saved file path - """ - output_dir = Path(output_dir) - output_dir.mkdir(parents=True, exist_ok=True) - - saved_files = {} - - # Save can_answer_cases - if can_answer_cases: - can_answer_filename = f"{frame}_{version}_can_answer_cases.json" - can_answer_path = save_recording_cases(can_answer_cases, output_dir, can_answer_filename) - saved_files["can_answer_cases"] = can_answer_path - print(f"Saved {len(can_answer_cases)} can_answer_cases to {can_answer_path}") - - # Save cannot_answer_cases - if cannot_answer_cases: - cannot_answer_filename = f"{frame}_{version}_cannot_answer_cases.json" - cannot_answer_path = save_recording_cases( - cannot_answer_cases, output_dir, cannot_answer_filename - ) - saved_files["cannot_answer_cases"] = cannot_answer_path - print(f"Saved {len(cannot_answer_cases)} cannot_answer_cases to {cannot_answer_path}") - - return saved_files - - -def compute_can_answer_stats(day_groups, rounds_to_consider=float("inf")): - """ - Compute can-answer statistics for each day using the union of all prior evidences. - - For each day, iterate over the QAs in the given order. If the current QA's - evidences (restricted to the same day) are a subset of the union of all - previously seen evidences for that day, increment can_answer_count. Then add - the current evidences to the seen set. - - Note: - The first QA of each day is excluded from the statistics because it - cannot be answered without any prior evidences. It is still used to - seed the seen evidences for subsequent QAs. - - Args: - day_groups: Dict mapping day_id (e.g., "D1") to a list of QA dicts. Each QA - dict should contain an "evidence" field that is a list of strings. - rounds_to_consider: Number of previous rounds to consider for evidence accumulation. - Default is infinity (all previous rounds). 
- Set to 1 to only consider the immediately preceding round. - - Returns: - dict: Mapping day_id -> {"can_answer_count": int, "total": int, "ratio": float} - """ - results = {} - for day, qa_list in day_groups.items(): - seen = set() - # Keep track of evidence history for limited rounds - evidence_history = [] - can_answer = 0 - total = max(len(qa_list) - 1, 0) - rounds_count = 0 - for idx, qa in enumerate(qa_list): - cur = set(qa.get("evidence", [])) - rounds_count += 1 - - if idx == 0: - # Seed seen evidences with the first QA but do not count it - evidence_history.append(cur) - seen = set().union(*evidence_history) - continue - - # Check if current evidence is subset of accumulated evidence - if cur and cur.issubset(seen): - can_answer += 1 - - # Add current evidence to history - evidence_history.append(cur) - - # Limit history to specified number of rounds - if rounds_count > rounds_to_consider: - evidence_history.pop(0) - - # Recalculate seen as union of evidence_history - seen = set().union(*evidence_history) - - results[day] = { - "can_answer_count": can_answer, - "total": total, - "ratio": (can_answer / total) if total else 0.0, - } - return results - - -def compute_can_answer_count_by_pre_evidences( - temporal_locomo_data, num_of_users, stats_dir=None, rounds_to_consider=float("inf") -): - """ - Compute can-answer statistics per day for each conversation using the - union of all previously asked evidences within the same day. - - Args: - temporal_locomo_data: The temporal locomo data containing conversations - num_of_users: Number of users/conversations to process - stats_dir: Directory to save statistics (optional) - rounds_to_consider: Number of previous rounds to consider for evidence accumulation. - Default is infinity (all previous rounds). - Set to 1 to only consider the immediately preceding round. 
- - Returns: - dict: Mapping conversation_id -> per-day stats as produced by compute_can_answer_stats - """ - all_conversations_stats = {} - for conv_idx in range(num_of_users): - temporal_conv = temporal_locomo_data[conv_idx] - conversation_id = temporal_conv["conversation_id"] - - # Build day -> qa_pairs mapping - day_groups = {} - for day_id, day_data in temporal_conv.get("days", {}).items(): - day_groups[day_id] = day_data.get("qa_pairs", []) - - # Use shared utility to compute stats with correct accumulation logic - per_day_stats = compute_can_answer_stats(day_groups, rounds_to_consider) - all_conversations_stats[conversation_id] = per_day_stats - - # Build per-conversation summaries and overall summary - per_conversation_summaries = {} - overall_can = 0 - overall_total = 0 - for conv_id, day_stats in all_conversations_stats.items(): - conv_can = 0 - conv_total = 0 - for _day, stats in day_stats.items(): - conv_can += int(stats.get("can_answer_count", 0)) - conv_total += int(stats.get("total", 0)) - conv_ratio = (conv_can / conv_total) if conv_total else 0.0 - per_conversation_summaries[conv_id] = { - "can_answer_count": conv_can, - "total": conv_total, - "ratio": conv_ratio, - } - overall_can += conv_can - overall_total += conv_total - - overall_summary = { - "can_answer_count": overall_can, - "total": overall_total, - "ratio": (overall_can / overall_total) if overall_total else 0.0, - } - - # Add rounds information to the result - result_payload = { - "per_conversation_summary": per_conversation_summaries, - "overall_summary": overall_summary, - "rounds_considered": rounds_to_consider if rounds_to_consider != float("inf") else "all", - } - - # Print results - print("\nComputed can-answer-by-pre-evidences stats:") - print( - f"Rounds considered: {rounds_to_consider if rounds_to_consider != float('inf') else 'all'}" - ) - print(json.dumps(result_payload, indent=2, ensure_ascii=False)) - - # Save results if stats_dir is provided - if stats_dir: - output_path = ( - stats_dir - / f"evidences_rounds_{rounds_to_consider if rounds_to_consider != float('inf') else 'all'}.json" - ) - with open(output_path, "w", encoding="utf-8") as fw: - json.dump(result_payload, fw, indent=2, ensure_ascii=False) - print(f"Saved stats to {output_path}") - - return result_payload diff --git a/evaluation/scripts/temporal_locomo/scheduler_time_eval.py b/evaluation/scripts/temporal_locomo/scheduler_time_eval.py deleted file mode 100644 index 12d1964cd..000000000 --- a/evaluation/scripts/temporal_locomo/scheduler_time_eval.py +++ /dev/null @@ -1,93 +0,0 @@ -import argparse -import sys - -from pathlib import Path - -from modules.locomo_eval_module import LocomoEvalModelModules -from modules.schemas import ContextUpdateMethod - -from evaluation.scripts.temporal_locomo.models.locomo_ingestion import LocomoIngestor -from evaluation.scripts.temporal_locomo.models.locomo_processor_w_time_eval import ( - LocomoProcessorWithTimeEval, -) -from memos.log import get_logger - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -# TODO: This evaluation has been suspended—it is not finished yet. 
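# --- Editor's illustrative sketch (not part of the original diff) -------------------
# The sliding-window evidence rule implemented by compute_can_answer_stats (defined
# in the removed modules/utils.py above) on a tiny hand-made day_groups mapping.
# The day id "D1" and the evidence labels are invented purely for the example.
from modules.utils import compute_can_answer_stats  # assumes the removed module is importable

example_day_groups = {
    "D1": [
        {"evidence": ["e1", "e2"]},  # first QA only seeds the seen-evidence set, not counted
        {"evidence": ["e1"]},        # subset of {e1, e2}  -> counted as answerable
        {"evidence": ["e3"]},        # not covered by prior evidence -> not counted
    ]
}
stats = compute_can_answer_stats(example_day_groups)
# -> {"D1": {"can_answer_count": 1, "total": 2, "ratio": 0.5}}
# --- end of sketch -------------------------------------------------------------------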
-class TemporalLocomoForTimeEval(LocomoEvalModelModules): - def __init__(self, args): - args.result_dir_prefix = "time_eval-" - - super().__init__(args=args) - self.num_of_users = 10 - - self.locomo_ingestor = LocomoIngestor(args=args) - self.locomo_processor = LocomoProcessorWithTimeEval(args=args) - - def run_time_eval_pipeline(self, skip_ingestion=True, skip_processing=False): - """ - Run the complete evaluation pipeline including dataset conversion, - data ingestion, and processing. - """ - print("=" * 80) - print("Starting TimeLocomo Evaluation Pipeline") - print("=" * 80) - - # Step 1: Check if temporal_locomo dataset exists, if not convert it - temporal_locomo_file = self.data_dir / "temporal_locomo" / "temporal_locomo_qa.json" - if not temporal_locomo_file.exists(): - print(f"Temporal locomo dataset not found at {temporal_locomo_file}") - print("Converting locomo dataset to temporal_locomo format...") - self.convert_locomo_to_temporal_locomo(output_dir=self.data_dir / "temporal_locomo") - print("Dataset conversion completed.") - else: - print(f"Temporal locomo dataset found at {temporal_locomo_file}, skipping conversion.") - - # Step 2: Data ingestion - if not skip_ingestion: - print("\n" + "=" * 50) - print("Step 2: Data Ingestion") - print("=" * 50) - self.locomo_ingestor.run_ingestion() - - # Step 3: Processing and evaluation - print("\n" + "=" * 50) - print("Step 3: Processing and Evaluation") - print("=" * 50) - print("Running locomo processing to search and answer...") - - print("Starting locomo processing to generate search and response results...") - self.locomo_processor.run_locomo_processing(num_users=self.num_of_users) - print("Processing completed successfully.") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--version", - type=str, - default="v1.0.1", - help="Version identifier for saving results (e.g., 1010)", - ) - parser.add_argument( - "--workers", type=int, default=10, help="Number of parallel workers to process users" - ) - parser.add_argument( - "--top_k", type=int, default=20, help="Number of results to retrieve in search queries" - ) - - args = parser.parse_args() - - args.frame = "memos_scheduler" - args.scheduler_flag = True - args.context_update_method = ContextUpdateMethod.PRE_CONTEXT - - evaluator = TemporalLocomoForTimeEval(args=args) - evaluator.run_time_eval_pipeline() diff --git a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py b/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py deleted file mode 100644 index bb6967e7f..000000000 --- a/evaluation/scripts/temporal_locomo/temporal_locomo_eval.py +++ /dev/null @@ -1,155 +0,0 @@ -import argparse -import asyncio -import os -import sys - -from pathlib import Path - -from modules.locomo_eval_module import LocomoEvalModelModules -from modules.schemas import ContextUpdateMethod -from modules.utils import compute_can_answer_count_by_pre_evidences - -from evaluation.scripts.temporal_locomo.models.locomo_eval import LocomoEvaluator -from evaluation.scripts.temporal_locomo.models.locomo_ingestion import LocomoIngestor -from evaluation.scripts.temporal_locomo.models.locomo_metric import LocomoMetric -from evaluation.scripts.temporal_locomo.models.locomo_processor import LocomoProcessor -from memos.log import get_logger - - -FILE_PATH = Path(__file__).absolute() -BASE_DIR = FILE_PATH.parent.parent.parent -sys.path.insert(0, str(BASE_DIR)) # Enable execution from any working directory - -logger = get_logger(__name__) - - -class 
TemporalLocomoEval(LocomoEvalModelModules): - def __init__(self, args): - super().__init__(args=args) - self.num_of_users = 10 - - self.locomo_ingestor = LocomoIngestor(args=args) - self.locomo_processor = LocomoProcessor(args=args) - self.locomo_evaluator = LocomoEvaluator(args=args) - self.locomo_metric = LocomoMetric(args=args) - - def run_answer_hit_eval_pipeline(self, skip_ingestion=True, skip_processing=False): - """ - Run the complete evaluation pipeline including dataset conversion, - data ingestion, and processing. - """ - print("=" * 80) - print("Starting TimeLocomo Evaluation Pipeline") - print("=" * 80) - - # Step 1: Check if temporal_locomo dataset exists, if not convert it - temporal_locomo_file = self.data_dir / "temporal_locomo" / "temporal_locomo_qa.json" - if not temporal_locomo_file.exists(): - print(f"Temporal locomo dataset not found at {temporal_locomo_file}") - print("Converting locomo dataset to temporal_locomo format...") - self.convert_locomo_to_temporal_locomo(output_dir=self.data_dir / "temporal_locomo") - print("Dataset conversion completed.") - else: - print(f"Temporal locomo dataset found at {temporal_locomo_file}, skipping conversion.") - - # Step 2: Data ingestion - if not skip_ingestion: - print("\n" + "=" * 50) - print("Step 2: Data Ingestion") - print("=" * 50) - self.locomo_ingestor.run_ingestion() - - # Step 3: Processing and evaluation - if not skip_processing: - print("\n" + "=" * 50) - print("Step 3: Processing and Evaluation") - print("=" * 50) - print("Running locomo processing to search and answer...") - - print("Starting locomo processing to generate search and response results...") - self.locomo_processor.run_locomo_processing(num_users=self.num_of_users) - print("Processing completed successfully.") - - # Optional: run post-hoc evaluation over generated responses if available - try: - if os.path.exists(self.response_path): - print("Running LocomoEvaluator over existing response results...") - asyncio.run(self.locomo_evaluator.run()) - else: - print( - f"Skipping LocomoEvaluator: response file not found at {evaluator.response_path}" - ) - # Run metrics summarization if judged file is produced - - if os.path.exists(self.judged_path): - print("Running LocomoMetric over judged results...") - self.locomo_metric.run() - else: - print(f"Skipping LocomoMetric: judged file not found at {self.judged_path}") - except Exception as e: - logger.error(f"LocomoEvaluator step skipped due to error: {e}", exc_info=True) - - # Step 4: Summary - print("\n" + "=" * 80) - print("Evaluation Pipeline Completed Successfully!") - print("=" * 80) - print("Results saved to:") - print(f" - Search results: {self.search_path}") - print(f" - Response results: {self.response_path}") - print(f" - Statistics: {self.stats_path}") - print("=" * 80) - - def compute_can_answer_count_by_pre_evidences(self, rounds_to_consider): - """ - Compute can-answer statistics per day for each conversation using the - union of all previously asked evidences within the same day. 
- - Returns: - dict: Mapping conversation_id -> per-day stats as produced by compute_can_answer_stats - """ - return compute_can_answer_count_by_pre_evidences( - temporal_locomo_data=self.temporal_locomo_data, - num_of_users=self.num_of_users, - stats_dir=self.stats_dir, - rounds_to_consider=rounds_to_consider, - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument( - "--frame", - type=str, - default="memos", - choices=["zep", "memos", "mem0", "mem0_graph", "memos_scheduler"], - help="Specify the memory framework (zep or memos or mem0 or mem0_graph)", - ) - parser.add_argument( - "--version", - type=str, - default="v1.0.1", - help="Version identifier for saving results (e.g., 1010)", - ) - parser.add_argument( - "--workers", type=int, default=10, help="Number of parallel workers to process users" - ) - parser.add_argument( - "--top_k", type=int, default=20, help="Number of results to retrieve in search queries" - ) - parser.add_argument( - "--scheduler_flag", - action=argparse.BooleanOptionalAction, - default=False, - help="Enable or disable memory scheduler features", - ) - parser.add_argument( - "--context_update_method", - type=str, - default="chat_history", - choices=ContextUpdateMethod.values(), - help="Method to update context: pre_context (use previous context), chat_history (use template with history), current_context (use current context)", - ) - args = parser.parse_args() - - evaluator = TemporalLocomoEval(args=args) - evaluator.run_answer_hit_eval_pipeline() diff --git a/examples/mem_scheduler/api_w_scheduler.py b/examples/mem_scheduler/api_w_scheduler.py index 11f0ebb81..a2184e9ca 100644 --- a/examples/mem_scheduler/api_w_scheduler.py +++ b/examples/mem_scheduler/api_w_scheduler.py @@ -1,3 +1,7 @@ +from memos.api.handlers.scheduler_handler import ( + handle_scheduler_status, + handle_scheduler_wait, +) from memos.api.routers.server_router import mem_scheduler from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem @@ -9,14 +13,9 @@ print(f"use_redis_queue: {mem_scheduler.use_redis_queue}") print(f"Queue type: {type(mem_scheduler.memos_message_queue).__name__}") print(f"Queue maxsize: {getattr(mem_scheduler.memos_message_queue, 'maxsize', 'N/A')}") - -# Check if Redis queue is connected -if hasattr(mem_scheduler.memos_message_queue, "_is_connected"): - print(f"Redis connected: {mem_scheduler.memos_message_queue._is_connected}") -if hasattr(mem_scheduler.memos_message_queue, "_redis_conn"): - print(f"Redis connection: {mem_scheduler.memos_message_queue._redis_conn}") print("=====================================\n") +mem_scheduler.memos_message_queue.debug_mode_on() queue = mem_scheduler.memos_message_queue queue.clear() @@ -27,7 +26,7 @@ def my_test_handler(messages: list[ScheduleMessageItem]): for msg in messages: print(f" my_test_handler - {msg.item_id}: {msg.content}") print( - f"{queue._redis_conn.xinfo_groups(queue.stream_name)} qsize: {queue.qsize()} messages:{messages}" + f"{queue._redis_conn.xinfo_groups(queue.stream_key_prefix)} qsize: {queue.qsize()} messages:{messages}" ) @@ -35,6 +34,12 @@ def my_test_handler(messages: list[ScheduleMessageItem]): TEST_HANDLER_LABEL = "test_handler" mem_scheduler.register_handlers({TEST_HANDLER_LABEL: my_test_handler}) +# 2.1 Monitor global scheduler status before submitting tasks +global_status_before = handle_scheduler_status( + user_name=None, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler" +) +print("[Monitor] Global status before submit:", global_status_before) 
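# --- Editor's illustrative sketch (not part of the original example) -----------------
# handle_scheduler_status can also be polled in a short loop instead of relying only
# on handle_scheduler_wait further below; the 5-iteration budget and the 0.5 s
# interval are arbitrary choices for illustration, and the returned snapshot is
# printed as-is without assuming anything about its structure.
import time

for _ in range(5):
    snapshot = handle_scheduler_status(
        user_name=None, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler"
    )
    print("[Monitor] Poll:", snapshot)
    time.sleep(0.5)
# --- end of sketch --------------------------------------------------------------------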
+ # 3. Create messages messages_to_send = [ ScheduleMessageItem( @@ -50,12 +55,33 @@ def my_test_handler(messages: list[ScheduleMessageItem]): # 5. Submit messages for mes in messages_to_send: print(f"Submitting message {mes.item_id} to the scheduler...") - mem_scheduler.submit_messages([mes]) + mem_scheduler.memos_message_queue.submit_messages([mes]) + +# 5.1 Monitor status for specific mem_cube while running +USER_MEM_CUBE = "test_mem_cube" +user_status_running = handle_scheduler_status( + user_name=USER_MEM_CUBE, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler" +) +print(f"[Monitor] Status for {USER_MEM_CUBE} after submit:", user_status_running) # 6. Wait for messages to be processed (limited to 100 checks) print("Waiting for messages to be consumed (max 100 checks)...") mem_scheduler.mem_scheduler_wait() +# 6.1 Wait until idle for specific mem_cube via handler +wait_result = handle_scheduler_wait( + user_name=USER_MEM_CUBE, + timeout_seconds=120.0, + poll_interval=0.2, + mem_scheduler=mem_scheduler, +) +print(f"[Monitor] Wait result for {USER_MEM_CUBE}:", wait_result) + +# 6.2 Monitor global scheduler status after processing +global_status_after = handle_scheduler_status( + user_name=None, mem_scheduler=mem_scheduler, instance_id="api_w_scheduler" +) +print("[Monitor] Global status after processing:", global_status_after) # 7. Stop the scheduler print("Stopping the scheduler...") diff --git a/src/memos/api/handlers/add_handler.py b/src/memos/api/handlers/add_handler.py index 48db7ae6e..ee481d028 100644 --- a/src/memos/api/handlers/add_handler.py +++ b/src/memos/api/handlers/add_handler.py @@ -202,7 +202,7 @@ def _process_pref_mem( content=json.dumps(messages_list), timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item_pref]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_pref]) self.logger.info("Submitted preference add to scheduler (async mode)") except Exception as e: self.logger.error(f"Failed to submit PREF_ADD task: {e}", exc_info=True) @@ -275,7 +275,7 @@ def _schedule_memory_tasks( timestamp=datetime.utcnow(), user_name=add_req.mem_cube_id, ) - self.mem_scheduler.submit_messages(messages=[message_item_read]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_read]) self.logger.info(f"Submitted async memory read task: {json.dumps(mem_ids)}") except Exception as e: self.logger.error(f"Failed to submit async memory tasks: {e}", exc_info=True) @@ -291,4 +291,4 @@ def _schedule_memory_tasks( timestamp=datetime.utcnow(), user_name=add_req.mem_cube_id, ) - self.mem_scheduler.submit_messages(messages=[message_item_add]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item_add]) diff --git a/src/memos/api/handlers/base_handler.py b/src/memos/api/handlers/base_handler.py index 86a00dc37..a174defb1 100644 --- a/src/memos/api/handlers/base_handler.py +++ b/src/memos/api/handlers/base_handler.py @@ -8,6 +8,7 @@ from typing import Any from memos.log import get_logger +from memos.mem_scheduler.base_scheduler import BaseScheduler logger = get_logger(__name__) @@ -123,7 +124,7 @@ def mem_reader(self): return self.deps.mem_reader @property - def mem_scheduler(self): + def mem_scheduler(self) -> BaseScheduler: """Get scheduler instance.""" return self.deps.mem_scheduler diff --git a/src/memos/api/handlers/chat_handler.py b/src/memos/api/handlers/chat_handler.py index f6023e5c8..8540a67ec 100644 --- a/src/memos/api/handlers/chat_handler.py +++ 
b/src/memos/api/handlers/chat_handler.py @@ -213,7 +213,7 @@ def generate_chat_response() -> Generator[str, None, None]: query=chat_req.query, top_k=20, session_id=chat_req.session_id, - mode=SearchMode.FINE if chat_req.internet_search else SearchMode.FAST, + mode=SearchMode.FAST, internet_search=chat_req.internet_search, # TODO this param is not worked at fine mode moscube=chat_req.moscube, chat_history=chat_req.history, @@ -603,7 +603,7 @@ def _send_message_to_scheduler( content=query, timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) self.logger.info(f"Sent message to scheduler with label: {label}") except Exception as e: self.logger.error(f"Failed to send message to scheduler: {e}", exc_info=True) diff --git a/src/memos/api/handlers/search_handler.py b/src/memos/api/handlers/search_handler.py index e8e4e07d6..cf2ab73bb 100644 --- a/src/memos/api/handlers/search_handler.py +++ b/src/memos/api/handlers/search_handler.py @@ -17,10 +17,14 @@ ) from memos.api.product_models import APISearchRequest, SearchResponse from memos.context.context import ContextThreadPoolExecutor +from memos.log import get_logger from memos.mem_scheduler.schemas.general_schemas import SearchMode from memos.types import MOSSearchResult, UserContext +logger = get_logger(__name__) + + class SearchHandler(BaseHandler): """ Handler for memory search operations. @@ -101,17 +105,6 @@ def handle_search_memories(self, search_req: APISearchRequest) -> SearchResponse ) def _get_search_mode(self, mode: str) -> str: - """ - Get search mode with environment variable fallback. - - Args: - mode: Requested search mode - - Returns: - Search mode string - """ - if mode == SearchMode.NOT_INITIALIZED: - return os.getenv("SEARCH_MODE", SearchMode.FAST) return mode def _search_text( @@ -133,16 +126,16 @@ def _search_text( """ try: if search_mode == SearchMode.FAST: - memories = self._fast_search(search_req, user_context) + text_memories = self._fast_search(search_req, user_context) elif search_mode == SearchMode.FINE: - memories = self._fine_search(search_req, user_context) + text_memories = self._fine_search(search_req, user_context) elif search_mode == SearchMode.MIXTURE: - memories = self._mix_search(search_req, user_context) + text_memories = self._mix_search(search_req, user_context) else: self.logger.error(f"Unsupported search mode: {search_mode}") return [] - return [format_memory_item(data) for data in memories] + return text_memories except Exception as e: self.logger.error("Error in search_text: %s; traceback: %s", e, traceback.format_exc()) @@ -199,7 +192,7 @@ def _fast_search( target_session_id = search_req.session_id or "default_session" search_filter = {"session_id": search_req.session_id} if search_req.session_id else None - return self.naive_mem_cube.text_mem.search( + search_results = self.naive_mem_cube.text_mem.search( query=search_req.query, user_name=user_context.mem_cube_id, top_k=search_req.top_k, @@ -214,6 +207,10 @@ def _fast_search( }, ) + formatted_memories = [format_memory_item(data) for data in search_results] + + return formatted_memories + def _fine_search( self, search_req: APISearchRequest, @@ -240,7 +237,7 @@ def _fine_search( "chat_history": search_req.chat_history, } - # Fast retrieve + # Fine retrieve fast_retrieved_memories = searcher.retrieve( query=search_req.query, user_name=user_context.mem_cube_id, @@ -261,12 +258,45 @@ def _fine_search( ) # Enhance with query - 
enhanced_results, _ = self.mem_scheduler.retriever.enhance_memories_with_query( + enhanced_memories, _ = self.mem_scheduler.retriever.enhance_memories_with_query( query_history=[search_req.query], memories=fast_memories, ) - return enhanced_results + if len(enhanced_memories) < len(fast_memories): + logger.info( + f"Enhanced memories ({len(enhanced_memories)}) are less than fast memories ({len(fast_memories)}). Recalling for more." + ) + missing_info_hint, trigger = self.mem_scheduler.retriever.recall_for_missing_memories( + query=search_req.query, + memories=fast_memories, + ) + retrieval_size = len(fast_memories) - len(enhanced_memories) + logger.info(f"Retrieval size: {retrieval_size}") + if trigger: + logger.info(f"Triggering additional search with hint: {missing_info_hint}") + additional_memories = searcher.search( + query=missing_info_hint, + user_name=user_context.mem_cube_id, + top_k=retrieval_size, + mode=SearchMode.FAST, + memory_type="All", + search_filter=search_filter, + info=info, + ) + else: + logger.info("Not triggering additional search, using fast memories.") + additional_memories = fast_memories[:retrieval_size] + + enhanced_memories += additional_memories + logger.info( + f"Added {len(additional_memories)} more memories. Total enhanced memories: {len(enhanced_memories)}" + ) + formatted_memories = [format_memory_item(data) for data in enhanced_memories] + + logger.info(f"Found {len(formatted_memories)} memories for user {search_req.user_id}") + + return formatted_memories def _mix_search( self, diff --git a/src/memos/api/product_models.py b/src/memos/api/product_models.py index 892d2d436..30df150ea 100644 --- a/src/memos/api/product_models.py +++ b/src/memos/api/product_models.py @@ -1,3 +1,4 @@ +import os import uuid from typing import Generic, Literal, TypeVar @@ -172,7 +173,7 @@ class APISearchRequest(BaseRequest): user_id: str = Field(None, description="User ID") mem_cube_id: str | None = Field(None, description="Cube ID to search in") mode: SearchMode = Field( - SearchMode.NOT_INITIALIZED, description="search mode: fast, fine, or mixture" + os.getenv("SEARCH_MODE", SearchMode.FAST), description="search mode: fast, fine, or mixture" ) internet_search: bool = Field(False, description="Whether to use internet search") moscube: bool = Field(False, description="Whether to use MemOSCube") diff --git a/src/memos/api/routers/server_router.py b/src/memos/api/routers/server_router.py index d43f9ccdc..b3b517305 100644 --- a/src/memos/api/routers/server_router.py +++ b/src/memos/api/routers/server_router.py @@ -34,6 +34,7 @@ SuggestionResponse, ) from memos.log import get_logger +from memos.mem_scheduler.base_scheduler import BaseScheduler logger = get_logger(__name__) @@ -58,7 +59,7 @@ # Extract commonly used components for function-based handlers # (These can be accessed from the components dict without unpacking all of them) -mem_scheduler = components["mem_scheduler"] +mem_scheduler: BaseScheduler = components["mem_scheduler"] llm = components["llm"] naive_mem_cube = components["naive_mem_cube"] diff --git a/src/memos/mem_os/core.py b/src/memos/mem_os/core.py index 3b53cef1a..f11b3a44c 100644 --- a/src/memos/mem_os/core.py +++ b/src/memos/mem_os/core.py @@ -287,7 +287,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = content=query, timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) memories = 
mem_cube.text_mem.search( query, @@ -347,7 +347,7 @@ def chat(self, query: str, user_id: str | None = None, base_prompt: str | None = content=response, timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) return response @@ -774,7 +774,9 @@ def process_textual_memory(): content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages( + messages=[message_item] + ) else: message_item = ScheduleMessageItem( user_id=target_user_id, @@ -783,7 +785,9 @@ def process_textual_memory(): content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages( + messages=[message_item] + ) def process_preference_memory(): if ( @@ -818,7 +822,7 @@ def process_preference_memory(): content=json.dumps(messages_list), timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) # Execute both memory processing functions in parallel with ContextThreadPoolExecutor(max_workers=2) as executor: @@ -872,7 +876,9 @@ def process_preference_memory(): content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages( + messages=[message_item] + ) else: message_item = ScheduleMessageItem( user_id=target_user_id, @@ -881,7 +887,9 @@ def process_preference_memory(): content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages( + messages=[message_item] + ) # user doc input if ( @@ -910,7 +918,7 @@ def process_preference_memory(): content=json.dumps(mem_ids), timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) logger.info(f"Add memory to {mem_cube_id} successfully") diff --git a/src/memos/mem_os/main.py b/src/memos/mem_os/main.py index 0114fc0da..11c112d52 100644 --- a/src/memos/mem_os/main.py +++ b/src/memos/mem_os/main.py @@ -220,7 +220,7 @@ def _chat_with_cot_enhancement( content=enhanced_response, timestamp=datetime.now().isoformat(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) return enhanced_response diff --git a/src/memos/mem_os/product.py b/src/memos/mem_os/product.py index 359db72ba..9a4ab3f4d 100644 --- a/src/memos/mem_os/product.py +++ b/src/memos/mem_os/product.py @@ -641,7 +641,7 @@ def _send_message_to_scheduler( content=query, timestamp=datetime.utcnow(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) async def _post_chat_processing( self, diff --git a/src/memos/mem_scheduler/analyzer/eval_analyzer.py b/src/memos/mem_scheduler/analyzer/eval_analyzer.py index d37e17456..cf0b8f1dd 100644 --- a/src/memos/mem_scheduler/analyzer/eval_analyzer.py +++ b/src/memos/mem_scheduler/analyzer/eval_analyzer.py @@ -1244,7 +1244,7 @@ def analyze_bad_cases_with_llm_processing( return results -def main(): +def 
main(version_name="ct-1111"): """Main test function.""" print("=== EvalAnalyzer Simple Test ===") @@ -1254,7 +1254,7 @@ def main(): print("Analyzer initialized") # Test file paths - eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-xcy-1030-2114-locomo" + eval_result_dir = f"{BASE_DIR}/evaluation/results/locomo/memos-api-{version_name}-locomo" judged_file = os.path.join(eval_result_dir, "memos-api_locomo_judged.json") search_results_file = os.path.join(eval_result_dir, "memos-api_locomo_search_results.json") diff --git a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py index 03e1fc778..df504ee75 100644 --- a/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py +++ b/src/memos/mem_scheduler/analyzer/mos_for_test_scheduler.py @@ -521,7 +521,7 @@ def chat(self, query: str, user_id: str | None = None) -> str: content=response, timestamp=datetime.now(), ) - self.mem_scheduler.submit_messages(messages=[message_item]) + self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item]) return response diff --git a/src/memos/mem_scheduler/base_scheduler.py b/src/memos/mem_scheduler/base_scheduler.py index eb49d0238..657ceea0f 100644 --- a/src/memos/mem_scheduler/base_scheduler.py +++ b/src/memos/mem_scheduler/base_scheduler.py @@ -1,5 +1,5 @@ -import contextlib import multiprocessing +import os import threading import time @@ -14,10 +14,9 @@ from memos.context.context import ContextThread from memos.llms.base import BaseLLM from memos.log import get_logger +from memos.mem_cube.base import BaseMemCube from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue -from memos.mem_scheduler.general_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.general_modules.scheduler_logger import SchedulerLoggerModule from memos.mem_scheduler.memory_manage_modules.retriever import SchedulerRetriever from memos.mem_scheduler.monitors.dispatcher_monitor import SchedulerDispatcherMonitor @@ -43,6 +42,9 @@ ScheduleMessageItem, ) from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem +from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher +from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue +from memos.mem_scheduler.task_schedule_modules.task_queue import ScheduleTaskQueue from memos.mem_scheduler.utils.db_utils import get_utc_now from memos.mem_scheduler.utils.filter_utils import ( transform_name_to_key, @@ -56,7 +58,8 @@ if TYPE_CHECKING: - from memos.mem_cube.base import BaseMemCube + from memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher + from memos.reranker.http_bge import HTTPBGEReranker logger = get_logger(__name__) @@ -88,22 +91,34 @@ def __init__(self, config: BaseSchedulerConfig): "scheduler_startup_mode", DEFAULT_STARTUP_MODE ) + # optional configs + self.disabled_handlers: list | None = self.config.get("disabled_handlers", None) + + self.max_web_log_queue_size = self.config.get( + "max_web_log_queue_size", DEFAULT_MAX_WEB_LOG_QUEUE_SIZE + ) + self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue( + maxsize=self.max_web_log_queue_size + ) + self._consumer_thread = None # Reference to our consumer thread/process + self._consumer_process = None # Reference to our consumer process + self._running = False + 
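        # Editor's note (illustrative, not part of the original diff): this constructor
        # reads the following optional keys from BaseSchedulerConfig, falling back to the
        # DEFAULT_* constants when a key is missing; the example values are invented:
        #   disabled_handlers: ["test_handler"]      # message labels the queue will skip
        #   max_web_log_queue_size: 1000
        #   consume_interval_seconds: 1.0
        #   consume_batch: 10
        #   use_redis_queue: false
        #   max_internal_message_queue_size: 10000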
self._consume_interval = self.config.get( + "consume_interval_seconds", DEFAULT_CONSUME_INTERVAL_SECONDS + ) + self.consume_batch = self.config.get("consume_batch", DEFAULT_CONSUME_BATCH) + # message queue configuration self.use_redis_queue = self.config.get("use_redis_queue", DEFAULT_USE_REDIS_QUEUE) self.max_internal_message_queue_size = self.config.get( "max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE ) - - # Initialize message queue based on configuration - if self.use_redis_queue: - self.memos_message_queue = SchedulerRedisQueue( - maxsize=self.max_internal_message_queue_size - ) - else: - self.memos_message_queue: Queue[ScheduleMessageItem] = Queue( - maxsize=self.max_internal_message_queue_size - ) - + self.memos_message_queue = ScheduleTaskQueue( + use_redis_queue=self.use_redis_queue, + maxsize=self.max_internal_message_queue_size, + disabled_handlers=self.disabled_handlers, + ) + self.searcher: Searcher | None = None self.retriever: SchedulerRetriever | None = None self.db_engine: Engine | None = None self.monitor: SchedulerGeneralMonitor | None = None @@ -117,23 +132,6 @@ def __init__(self, config: BaseSchedulerConfig): enable_parallel_dispatch=self.enable_parallel_dispatch, ) - # optional configs - self.disable_handlers: list | None = self.config.get("disable_handlers", None) - - self.max_web_log_queue_size = self.config.get( - "max_web_log_queue_size", DEFAULT_MAX_WEB_LOG_QUEUE_SIZE - ) - self._web_log_message_queue: Queue[ScheduleLogForWebItem] = Queue( - maxsize=self.max_web_log_queue_size - ) - self._consumer_thread = None # Reference to our consumer thread/process - self._consumer_process = None # Reference to our consumer process - self._running = False - self._consume_interval = self.config.get( - "consume_interval_seconds", DEFAULT_CONSUME_INTERVAL_SECONDS - ) - self.consume_batch = self.config.get("consume_batch", DEFAULT_CONSUME_BATCH) - # other attributes self._context_lock = threading.Lock() self.current_user_id: UserID | str | None = None @@ -143,6 +141,15 @@ def __init__(self, config: BaseSchedulerConfig): self.auth_config = None self.rabbitmq_config = None + def init_mem_cube(self, mem_cube): + self.mem_cube = mem_cube + self.text_mem: TreeTextMemory = self.mem_cube.text_mem + self.searcher: Searcher = self.text_mem.get_searcher( + manual_close_internet=os.getenv("ENABLE_INTERNET", "true").lower() == "false", + moscube=False, + ) + self.reranker: HTTPBGEReranker = self.text_mem.reranker + def initialize_modules( self, chat_llm: BaseLLM, @@ -199,6 +206,9 @@ def initialize_modules( # start queue monitor if enabled and a bot is set later + def debug_mode_on(self): + self.memos_message_queue.debug_mode_on() + def _cleanup_on_init_failure(self): """Clean up resources if initialization fails.""" try: @@ -208,23 +218,16 @@ def _cleanup_on_init_failure(self): logger.warning(f"Error during cleanup: {e}") @property - def mem_cube(self) -> GeneralMemCube: + def mem_cube(self) -> BaseMemCube: """The memory cube associated with this MemChat.""" return self.current_mem_cube @mem_cube.setter - def mem_cube(self, value: GeneralMemCube) -> None: + def mem_cube(self, value: BaseMemCube) -> None: """The memory cube associated with this MemChat.""" self.current_mem_cube = value self.retriever.mem_cube = value - def _set_current_context_from_message(self, msg: ScheduleMessageItem) -> None: - """Update current user/cube context from the incoming message (thread-safe).""" - with self._context_lock: - self.current_user_id = msg.user_id - 
self.current_mem_cube_id = msg.mem_cube_id - self.current_mem_cube = self.get_mem_cube(msg.mem_cube_id) - def transform_working_memories_to_monitors( self, query_keywords, memories: list[TextualMemoryItem] ) -> list[MemoryMonitorItem]: @@ -523,29 +526,7 @@ def update_activation_memory_periodically( logger.error(f"Error in update_activation_memory_periodically: {e}", exc_info=True) def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): - """Submit messages to the message queue (either local queue or Redis).""" - if isinstance(messages, ScheduleMessageItem): - messages = [messages] # transform single message to list - - for message in messages: - if not isinstance(message, ScheduleMessageItem): - error_msg = f"Invalid message type: {type(message)}, expected ScheduleMessageItem" - logger.error(error_msg) - raise TypeError(error_msg) - - if getattr(message, "timestamp", None) is None: - with contextlib.suppress(Exception): - message.timestamp = datetime.utcnow() - - if self.disable_handlers and message.label in self.disable_handlers: - logger.info(f"Skipping disabled handler: {message.label} - {message.content}") - continue - self.memos_message_queue.put(message) - logger.info(f"Submitted message to local queue: {message.label} - {message.content}") - - with contextlib.suppress(Exception): - if messages: - self.dispatcher.on_messages_enqueued(messages) + self.memos_message_queue.submit_messages(messages=messages) def _submit_web_logs( self, messages: ScheduleLogForWebItem | list[ScheduleLogForWebItem] @@ -606,10 +587,16 @@ def _message_consumer(self) -> None: try: # Get messages in batches based on consume_batch setting - messages = self.memos_message_queue.get(block=True, batch_size=self.consume_batch) + messages = self.memos_message_queue.get_messages(batch_size=self.consume_batch) if messages: try: + import contextlib + + with contextlib.suppress(Exception): + if messages: + self.dispatcher.on_messages_enqueued(messages) + self.dispatcher.dispatch(messages) except Exception as e: logger.error(f"Error dispatching messages: {e!s}") @@ -878,7 +865,7 @@ def _fmt_eta(seconds: float | None) -> str: if isinstance(self.memos_message_queue, SchedulerRedisQueue): # For Redis queue, prefer XINFO GROUPS to compute pending groups_info = self.memos_message_queue.redis.xinfo_groups( - self.memos_message_queue.stream_name + self.memos_message_queue.stream_key_prefix ) if groups_info: for group in groups_info: diff --git a/src/memos/mem_scheduler/general_modules/base.py b/src/memos/mem_scheduler/general_modules/base.py index 0b80b9e7d..e0ee65ba0 100644 --- a/src/memos/mem_scheduler/general_modules/base.py +++ b/src/memos/mem_scheduler/general_modules/base.py @@ -18,8 +18,6 @@ def __init__(self): self._chat_llm = None self._process_llm = None - self.mem_cubes: dict[str, GeneralMemCube] = {} - def load_template(self, template_name: str) -> str: if template_name not in PROMPT_MAPPING: logger.error("Prompt template is not found!") diff --git a/src/memos/mem_scheduler/general_modules/scheduler_logger.py b/src/memos/mem_scheduler/general_modules/scheduler_logger.py index 1f89d3b02..d35a4f106 100644 --- a/src/memos/mem_scheduler/general_modules/scheduler_logger.py +++ b/src/memos/mem_scheduler/general_modules/scheduler_logger.py @@ -1,7 +1,7 @@ from collections.abc import Callable from memos.log import get_logger -from memos.mem_cube.general import GeneralMemCube +from memos.mem_cube.base import BaseMemCube from memos.mem_scheduler.general_modules.base import BaseSchedulerModule from 
memos.mem_scheduler.schemas.general_schemas import ( ACTIVATION_MEMORY_TYPE, @@ -44,7 +44,7 @@ def create_autofilled_log_item( to_memory_type: str, user_id: str, mem_cube_id: str, - mem_cube: GeneralMemCube, + mem_cube: BaseMemCube, ) -> ScheduleLogForWebItem: text_mem_base: TreeTextMemory = mem_cube.text_mem current_memory_sizes = text_mem_base.get_current_memory_size() @@ -106,7 +106,7 @@ def log_working_memory_replacement( new_memory: list[TextualMemoryItem], user_id: str, mem_cube_id: str, - mem_cube: GeneralMemCube, + mem_cube: BaseMemCube, log_func_callback: Callable[[list[ScheduleLogForWebItem]], None], ): """Log changes when working memory is replaced.""" @@ -163,7 +163,7 @@ def log_activation_memory_update( label: str, user_id: str, mem_cube_id: str, - mem_cube: GeneralMemCube, + mem_cube: BaseMemCube, log_func_callback: Callable[[list[ScheduleLogForWebItem]], None], ): """Log changes when activation memory is updated.""" @@ -214,7 +214,7 @@ def log_adding_memory( memory_type: str, user_id: str, mem_cube_id: str, - mem_cube: GeneralMemCube, + mem_cube: BaseMemCube, log_func_callback: Callable[[list[ScheduleLogForWebItem]], None], ): """Log changes when working memory is replaced.""" diff --git a/src/memos/mem_scheduler/general_scheduler.py b/src/memos/mem_scheduler/general_scheduler.py index 6e916962e..92e317881 100644 --- a/src/memos/mem_scheduler/general_scheduler.py +++ b/src/memos/mem_scheduler/general_scheduler.py @@ -5,6 +5,7 @@ from memos.configs.mem_scheduler import GeneralSchedulerConfig from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger +from memos.mem_cube.base import BaseMemCube from memos.mem_cube.general import GeneralMemCube from memos.mem_scheduler.base_scheduler import BaseScheduler from memos.mem_scheduler.schemas.general_schemas import ( @@ -22,6 +23,7 @@ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.monitor_schemas import QueryMonitorItem from memos.mem_scheduler.utils.filter_utils import is_all_chinese, is_all_english +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube from memos.memories.textual.item import TextualMemoryItem from memos.memories.textual.preference import PreferenceTextMemory from memos.memories.textual.tree import TreeTextMemory @@ -137,7 +139,7 @@ def long_memory_update_process( label=QUERY_LABEL, user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=self.current_mem_cube, + mem_cube=self.mem_cube, ) def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: @@ -150,7 +152,7 @@ def _query_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: logger.info(f"Messages {messages} assigned to {QUERY_LABEL} handler.") # Process the query in a session turn - grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) + grouped_messages = group_messages_by_user_and_mem_cube(messages) self.validate_schedule_messages(messages=messages, label=QUERY_LABEL) @@ -172,7 +174,7 @@ def _answer_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: """ logger.info(f"Messages {messages} assigned to {ANSWER_LABEL} handler.") # Process the query in a session turn - grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) + grouped_messages = group_messages_by_user_and_mem_cube(messages) self.validate_schedule_messages(messages=messages, label=ANSWER_LABEL) @@ -185,7 +187,8 @@ def _answer_message_consumer(self, 
messages: list[ScheduleMessageItem]) -> None: def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: logger.info(f"Messages {messages} assigned to {ADD_LABEL} handler.") # Process the query in a session turn - grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) + grouped_messages = group_messages_by_user_and_mem_cube(messages) + mem_cube = self.mem_cube self.validate_schedule_messages(messages=messages, label=ADD_LABEL) try: @@ -203,7 +206,6 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: logger.error(f"Error: {e}. Content: {msg.content}", exc_info=True) userinput_memory_ids = [] - mem_cube = self.current_mem_cube for memory_id in userinput_memory_ids: try: mem_item: TextualMemoryItem = mem_cube.text_mem.get( @@ -225,7 +227,7 @@ def _add_message_consumer(self, messages: list[ScheduleMessageItem]) -> None: memory_type=mem_type, user_id=msg.user_id, mem_cube_id=msg.mem_cube_id, - mem_cube=self.current_mem_cube, + mem_cube=self.mem_cube, log_func_callback=self._submit_web_logs, ) @@ -239,7 +241,7 @@ def process_message(message: ScheduleMessageItem): try: user_id = message.user_id mem_cube_id = message.mem_cube_id - mem_cube = self.current_mem_cube + mem_cube = self.mem_cube content = message.content user_name = message.user_name @@ -263,7 +265,6 @@ def process_message(message: ScheduleMessageItem): mem_ids=mem_ids, user_id=user_id, mem_cube_id=mem_cube_id, - mem_cube=mem_cube, text_mem=text_mem, user_name=user_name, ) @@ -288,7 +289,6 @@ def _process_memories_with_reader( mem_ids: list[str], user_id: str, mem_cube_id: str, - mem_cube: GeneralMemCube, text_mem: TreeTextMemory, user_name: str, ) -> None: @@ -299,7 +299,6 @@ def _process_memories_with_reader( mem_ids: List of memory IDs to process user_id: User ID mem_cube_id: Memory cube ID - mem_cube: Memory cube instance text_mem: Text memory instance """ try: @@ -403,7 +402,7 @@ def process_message(message: ScheduleMessageItem): try: user_id = message.user_id mem_cube_id = message.mem_cube_id - mem_cube = self.current_mem_cube + mem_cube = self.mem_cube content = message.content user_name = message.user_name @@ -452,7 +451,7 @@ def _process_memories_with_reorganize( mem_ids: list[str], user_id: str, mem_cube_id: str, - mem_cube: GeneralMemCube, + mem_cube: BaseMemCube, text_mem: TreeTextMemory, user_name: str, ) -> None: @@ -504,10 +503,11 @@ def _pref_add_message_consumer(self, messages: list[ScheduleMessageItem]) -> Non def process_message(message: ScheduleMessageItem): try: + mem_cube = self.mem_cube + user_id = message.user_id session_id = message.session_id mem_cube_id = message.mem_cube_id - mem_cube = self.current_mem_cube content = message.content messages_list = json.loads(content) diff --git a/src/memos/mem_scheduler/memory_manage_modules/retriever.py b/src/memos/mem_scheduler/memory_manage_modules/retriever.py index 848b1d257..01b57563d 100644 --- a/src/memos/mem_scheduler/memory_manage_modules/retriever.py +++ b/src/memos/mem_scheduler/memory_manage_modules/retriever.py @@ -1,3 +1,5 @@ +import time + from concurrent.futures import as_completed from memos.configs.mem_scheduler import BaseSchedulerConfig @@ -9,6 +11,8 @@ from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE, DEFAULT_SCHEDULER_RETRIEVER_RETRIES, + FINE_STRATEGY, + FineStrategy, TreeTextMemory_FINE_SEARCH_METHOD, TreeTextMemory_SEARCH_METHOD, ) @@ -91,9 +95,15 @@ def _build_enhancement_prompt(self, query_history: list[str], 
batch_texts: list[ if len(query_history) > 1 else query_history[0] ) - text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(batch_texts)]) + # Include numbering for rewrite mode to help LLM reference original memory IDs + if FINE_STRATEGY == FineStrategy.REWRITE: + text_memories = "\n".join([f"- [{i}] {mem}" for i, mem in enumerate(batch_texts)]) + prompt_name = "memory_rewrite_enhancement" + else: + text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(batch_texts)]) + prompt_name = "memory_recreate_enhancement" return self.build_prompt( - "memory_enhancement", + prompt_name, query_history=query_history, memories=text_memories, ) @@ -107,53 +117,80 @@ def _process_enhancement_batch( ) -> tuple[list[TextualMemoryItem], bool]: attempt = 0 text_memories = [one.memory for one in memories] - while attempt <= max(0, retries) + 1: - try: - prompt = self._build_enhancement_prompt( - query_history=query_history, batch_texts=text_memories - ) - logger.debug( - f"[Enhance][batch={batch_index}] Prompt (first 200 chars, len={len(prompt)}): " - f"{prompt[:200]}]..." - ) - response = self.process_llm.generate([{"role": "user", "content": prompt}]) - logger.debug( - f"[Enhance][batch={batch_index}] Response (first 200 chars): {response}..." - ) + prompt = self._build_enhancement_prompt( + query_history=query_history, batch_texts=text_memories + ) - processed_text_memories = extract_list_items_in_answer(response) - if len(processed_text_memories) == len(memories): - # Update - for i, new_mem in enumerate(processed_text_memories): - memories[i].memory = new_mem - enhanced_memories = memories - else: + llm_response = None + while attempt <= max(0, retries) + 1: + try: + llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) + processed_text_memories = extract_list_items_in_answer(llm_response) + if len(processed_text_memories) > 0: # create new enhanced_memories = [] user_id = memories[0].metadata.user_id - for new_mem in processed_text_memories: - enhanced_memories.append( - TextualMemoryItem( - memory=new_mem, metadata=TextualMemoryMetadata(user_id=user_id) + if FINE_STRATEGY == FineStrategy.RECREATE: + for new_mem in processed_text_memories: + enhanced_memories.append( + TextualMemoryItem( + memory=new_mem, metadata=TextualMemoryMetadata(user_id=user_id) + ) ) - ) - enhanced_memories = ( - enhanced_memories + memories[: len(memories) - len(enhanced_memories)] - ) + elif FINE_STRATEGY == FineStrategy.REWRITE: + # Parse index from each processed line and rewrite corresponding original memory + def _parse_index_and_text(s: str) -> tuple[int | None, str]: + import re + + s = (s or "").strip() + # Preferred: [index] text + m = re.match(r"^\s*\[(\d+)\]\s*(.+)$", s) + if m: + return int(m.group(1)), m.group(2).strip() + # Fallback: index: text or index - text + m = re.match(r"^\s*(\d+)\s*[:\-\)]\s*(.+)$", s) + if m: + return int(m.group(1)), m.group(2).strip() + return None, s + + idx_to_original = dict(enumerate(memories)) + for j, item in enumerate(processed_text_memories): + idx, new_text = _parse_index_and_text(item) + if idx is not None and idx in idx_to_original: + orig = idx_to_original[idx] + else: + # Fallback: align by order if index missing/invalid + orig = memories[j] if j < len(memories) else None + if not orig: + continue + enhanced_memories.append( + TextualMemoryItem( + id=orig.id, + memory=new_text, + metadata=orig.metadata, + ) + ) + else: + logger.error(f"Fine search strategy {FINE_STRATEGY} not exists") logger.info( - f"[Enhance]: 
processed_text_memories: {len(processed_text_memories)}; padded with original memories to preserve total count" + f"[enhance_memories_with_query] ✅ done | Strategy={FINE_STRATEGY} | prompt={prompt} | llm_response={llm_response}" + ) + return enhanced_memories, True + else: + raise ValueError( + f"Fail to run memory enhancement; retry {attempt}/{max(1, retries) + 1}; processed_text_memories: {processed_text_memories}" ) - - return enhanced_memories, True except Exception as e: attempt += 1 + time.sleep(1) logger.debug( - f"[Enhance][batch={batch_index}] 🔁 retry {attempt}/{max(1, retries) + 1} failed: {e}" + f"[enhance_memories_with_query][batch={batch_index}] 🔁 retry {attempt}/{max(1, retries) + 1} failed: {e}" ) logger.error( - f"Fail to run memory enhancement; original memories: {memories}", exc_info=True + f"Fail to run memory enhancement; prompt: {prompt};\n llm_response: {llm_response}", + exc_info=True, ) return memories, False @@ -170,6 +207,76 @@ def _split_batches( start = end return batches + def recall_for_missing_memories( + self, + query: str, + memories: list[TextualMemoryItem], + ) -> tuple[str, bool]: + text_memories = [one.memory for one in memories] if memories else [] + text_memories = "\n".join([f"- {mem}" for i, mem in enumerate(text_memories)]) + + prompt = self.build_prompt( + template_name="enlarge_recall", + query=query, + memories_inline=text_memories, + ) + llm_response = self.process_llm.generate([{"role": "user", "content": prompt}]) + + json_result: dict = extract_json_obj(llm_response) + + logger.info( + f"[recall_for_missing_memories] ✅ done | prompt={prompt} | llm_response={llm_response}" + ) + + hint = json_result.get("hint", "") + if len(hint) == 0: + return hint, False + return hint, json_result.get("trigger_recall", False) + + def search( + self, + query: str, + mem_cube: GeneralMemCube, + top_k: int, + method: str = TreeTextMemory_SEARCH_METHOD, + info: dict | None = None, + ) -> list[TextualMemoryItem]: + """Search in text memory with the given query. + + Args: + query: The search query string + top_k: Number of top results to return + method: Search method to use + + Returns: + Search results or None if not implemented + """ + text_mem_base = mem_cube.text_mem + try: + if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]: + assert isinstance(text_mem_base, TreeTextMemory) + if info is None: + logger.warning( + "Please input 'info' when use tree.search so that " + "the database would store the consume history." + ) + info = {"user_id": "", "session_id": ""} + + mode = "fast" if method == TreeTextMemory_SEARCH_METHOD else "fine" + results_long_term = text_mem_base.search( + query=query, top_k=top_k, memory_type="LongTermMemory", mode=mode, info=info + ) + results_user = text_mem_base.search( + query=query, top_k=top_k, memory_type="UserMemory", mode=mode, info=info + ) + results = results_long_term + results_user + else: + raise NotImplementedError(str(type(text_mem_base))) + except Exception as e: + logger.error(f"Fail to search. 
The exeption is {e}.", exc_info=True) + results = [] + return results + def enhance_memories_with_query( self, query_history: list[str], @@ -239,54 +346,10 @@ def enhance_memories_with_query( enhanced_memories = memories if len(enhanced_memories) == 0: - enhanced_memories = memories + enhanced_memories = [] logger.error("[Enhance] ❌ fatal error: enhanced_memories is empty", exc_info=True) return enhanced_memories, all_success - def search( - self, - query: str, - mem_cube: GeneralMemCube, - top_k: int, - method: str = TreeTextMemory_SEARCH_METHOD, - info: dict | None = None, - ) -> list[TextualMemoryItem]: - """Search in text memory with the given query. - - Args: - query: The search query string - top_k: Number of top results to return - method: Search method to use - - Returns: - Search results or None if not implemented - """ - text_mem_base = mem_cube.text_mem - try: - if method in [TreeTextMemory_SEARCH_METHOD, TreeTextMemory_FINE_SEARCH_METHOD]: - assert isinstance(text_mem_base, TreeTextMemory) - if info is None: - logger.warning( - "Please input 'info' when use tree.search so that " - "the database would store the consume history." - ) - info = {"user_id": "", "session_id": ""} - - mode = "fast" if method == TreeTextMemory_SEARCH_METHOD else "fine" - results_long_term = text_mem_base.search( - query=query, top_k=top_k, memory_type="LongTermMemory", mode=mode, info=info - ) - results_user = text_mem_base.search( - query=query, top_k=top_k, memory_type="UserMemory", mode=mode, info=info - ) - results = results_long_term + results_user - else: - raise NotImplementedError(str(type(text_mem_base))) - except Exception as e: - logger.error(f"Fail to search. The exeption is {e}.", exc_info=True) - results = [] - return results - def rerank_memories( self, queries: list[str], original_memories: list[str], top_k: int ) -> (list[str], bool): diff --git a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py index 99982d2e6..f8e321a82 100644 --- a/src/memos/mem_scheduler/monitors/dispatcher_monitor.py +++ b/src/memos/mem_scheduler/monitors/dispatcher_monitor.py @@ -7,13 +7,13 @@ from memos.context.context import ContextThread, ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule -from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.schemas.general_schemas import ( DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL, DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES, DEFAULT_STOP_WAIT, DEFAULT_STUCK_THREAD_TOLERANCE, ) +from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.utils.db_utils import get_utc_now diff --git a/src/memos/mem_scheduler/optimized_scheduler.py b/src/memos/mem_scheduler/optimized_scheduler.py index b62b1e51d..21b2d63f0 100644 --- a/src/memos/mem_scheduler/optimized_scheduler.py +++ b/src/memos/mem_scheduler/optimized_scheduler.py @@ -20,15 +20,13 @@ from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.utils.api_utils import format_textual_memory_item from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube from memos.memories.textual.tree import TextualMemoryItem, TreeTextMemory from memos.types import UserContext if TYPE_CHECKING: from memos.mem_scheduler.schemas.monitor_schemas import MemoryMonitorItem - from 
memos.memories.textual.tree_text_memory.retrieve.searcher import Searcher - from memos.reranker.http_bge import HTTPBGEReranker - logger = get_logger(__name__) @@ -56,15 +54,6 @@ def __init__(self, config: GeneralSchedulerConfig): self.reranker = None self.text_mem = None - def init_mem_cube(self, mem_cube): - self.current_mem_cube = mem_cube - self.text_mem: TreeTextMemory = self.current_mem_cube.text_mem - self.searcher: Searcher = self.text_mem.get_searcher( - manual_close_internet=False, - moscube=False, - ) - self.reranker: HTTPBGEReranker = self.text_mem.reranker - def submit_memory_history_async_task( self, search_req: APISearchRequest, @@ -99,7 +88,7 @@ def submit_memory_history_async_task( ) # Submit async task - self.submit_messages([message]) + self.memos_message_queue.submit_messages([message]) logger.info(f"Submitted async fine search task for user {search_req.user_id}") return async_task_id @@ -141,6 +130,9 @@ def mix_search_memories( """ Mix search memories: fast search + async fine search """ + logger.info( + f"Mix searching memories for user {search_req.user_id} with query: {search_req.query}" + ) # Get mem_cube for fast search target_session_id = search_req.session_id @@ -173,17 +165,14 @@ def mix_search_memories( mem_cube_id=user_context.mem_cube_id, turns=self.history_memory_turns, ) - + logger.info(f"Found {len(history_memories)} history memories.") if not history_memories: - fast_memories = self.searcher.post_retrieve( + memories = self.searcher.post_retrieve( retrieved_results=fast_retrieved_memories, top_k=search_req.top_k, user_name=user_context.mem_cube_id, info=info, ) - # Format fast memories for return - formatted_memories = [format_textual_memory_item(data) for data in fast_memories] - return formatted_memories else: # if history memories can directly answer sorted_history_memories = self.reranker.rerank( @@ -192,7 +181,7 @@ def mix_search_memories( top_k=search_req.top_k, # Use search_req.top_k instead of undefined top_k search_filter=search_filter, ) - + logger.info(f"Reranked {len(sorted_history_memories)} history memories.") processed_hist_mem = self.searcher.post_retrieve( retrieved_results=sorted_history_memories, top_k=search_req.top_k, @@ -205,6 +194,7 @@ def mix_search_memories( ) if can_answer: + logger.info("History memories can answer the query.") sorted_results = fast_retrieved_memories + sorted_history_memories combined_results = self.searcher.post_retrieve( retrieved_results=sorted_results, @@ -213,9 +203,8 @@ def mix_search_memories( info=info, ) memories = combined_results[: search_req.top_k] - formatted_memories = [format_textual_memory_item(item) for item in memories] - logger.info("can_answer") else: + logger.info("History memories cannot answer the query, enhancing memories.") sorted_results = fast_retrieved_memories + sorted_history_memories combined_results = self.searcher.post_retrieve( retrieved_results=sorted_results, @@ -223,24 +212,53 @@ def mix_search_memories( user_name=user_context.mem_cube_id, info=info, ) - enhanced_results, _ = self.retriever.enhance_memories_with_query( + enhanced_memories, _ = self.retriever.enhance_memories_with_query( query_history=[search_req.query], memories=combined_results, ) - memories = enhanced_results[: search_req.top_k] - formatted_memories = [format_textual_memory_item(item) for item in memories] - logger.info("cannot answer") - - self.submit_memory_history_async_task( - search_req=search_req, - user_context=user_context, - memories_to_store={ - "memories": [one.to_dict() for one in memories], - 
"formatted_memories": formatted_memories, - }, - ) - return formatted_memories + if len(enhanced_memories) < search_req.top_k: + logger.info( + f"Enhanced memories ({len(enhanced_memories)}) are less than top_k ({search_req.top_k}). Recalling for more." + ) + missing_info_hint, trigger = self.retriever.recall_for_missing_memories( + query=search_req.query, + memories=combined_results, + ) + retrieval_size = search_req.top_k - len(enhanced_memories) + if trigger: + logger.info(f"Triggering additional search with hint: {missing_info_hint}") + additional_memories = self.searcher.search( + query=missing_info_hint, + user_name=user_context.mem_cube_id, + top_k=retrieval_size, + mode=SearchMode.FAST, + memory_type="All", + search_filter=search_filter, + info=info, + ) + else: + logger.info("Not triggering additional search, using combined results.") + additional_memories = combined_results[:retrieval_size] + logger.info( + f"Added {len(additional_memories)} more memories. Total enhanced memories: {len(enhanced_memories)}" + ) + enhanced_memories += additional_memories + + memories = enhanced_memories[: search_req.top_k] + + formatted_memories = [format_textual_memory_item(item) for item in memories] + logger.info("Submitted memory history async task.") + self.submit_memory_history_async_task( + search_req=search_req, + user_context=user_context, + memories_to_store={ + "memories": [one.to_dict() for one in memories], + "formatted_memories": formatted_memories, + }, + ) + + return formatted_memories def update_search_memories_to_redis( self, @@ -304,7 +322,7 @@ def _api_mix_search_message_consumer(self, messages: list[ScheduleMessageItem]) logger.info(f"Messages {messages} assigned to {API_MIX_SEARCH_LABEL} handler.") # Process the query in a session turn - grouped_messages = self.dispatcher._group_messages_by_user_and_mem_cube(messages=messages) + grouped_messages = group_messages_by_user_and_mem_cube(messages) self.validate_schedule_messages(messages=messages, label=API_MIX_SEARCH_LABEL) diff --git a/src/memos/mem_scheduler/schemas/general_schemas.py b/src/memos/mem_scheduler/schemas/general_schemas.py index 7f2c09b7d..524eab785 100644 --- a/src/memos/mem_scheduler/schemas/general_schemas.py +++ b/src/memos/mem_scheduler/schemas/general_schemas.py @@ -1,3 +1,5 @@ +import os + from enum import Enum from pathlib import Path from typing import NewType @@ -6,12 +8,18 @@ class SearchMode(str, Enum): """Enumeration for search modes.""" - NOT_INITIALIZED = "not_initialized" FAST = "fast" FINE = "fine" MIXTURE = "mixture" +class FineStrategy(str, Enum): + """Enumeration for fine strategies.""" + + REWRITE = "rewrite" + RECREATE = "recreate" + + FILE_PATH = Path(__file__).absolute() BASE_DIR = FILE_PATH.parent.parent.parent.parent.parent @@ -32,17 +40,17 @@ class SearchMode(str, Enum): DEFAULT_ACTIVATION_MEM_MONITOR_SIZE_LIMIT = 20 DEFAULT_ACT_MEM_DUMP_PATH = f"{BASE_DIR}/outputs/mem_scheduler/mem_cube_scheduler_test.kv_cache" DEFAULT_THREAD_POOL_MAX_WORKERS = 50 -DEFAULT_CONSUME_INTERVAL_SECONDS = 0.05 +DEFAULT_CONSUME_INTERVAL_SECONDS = 0.01 DEFAULT_CONSUME_BATCH = 1 DEFAULT_DISPATCHER_MONITOR_CHECK_INTERVAL = 300 DEFAULT_DISPATCHER_MONITOR_MAX_FAILURES = 2 DEFAULT_STUCK_THREAD_TOLERANCE = 10 -DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = 0 +DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE = -1 DEFAULT_TOP_K = 10 DEFAULT_CONTEXT_WINDOW_SIZE = 5 DEFAULT_USE_REDIS_QUEUE = True DEFAULT_MULTI_TASK_RUNNING_TIMEOUT = 30 -DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE = 10 +DEFAULT_SCHEDULER_RETRIEVER_BATCH_SIZE = 20 
DEFAULT_SCHEDULER_RETRIEVER_RETRIES = 1 DEFAULT_STOP_WAIT = False @@ -75,3 +83,17 @@ class SearchMode(str, Enum): # new types UserID = NewType("UserID", str) MemCubeID = NewType("CubeID", str) + +# algorithm strategies +DEFAULT_FINE_STRATEGY = FineStrategy.REWRITE + +# Read fine strategy from environment variable `FINE_STRATEGY`. +# If provided and valid, use it; otherwise fall back to default. +_env_fine_strategy = os.getenv("FINE_STRATEGY") +if _env_fine_strategy: + try: + FINE_STRATEGY = FineStrategy(_env_fine_strategy) + except ValueError: + FINE_STRATEGY = DEFAULT_FINE_STRATEGY +else: + FINE_STRATEGY = DEFAULT_FINE_STRATEGY diff --git a/evaluation/scripts/temporal_locomo/__init__.py b/src/memos/mem_scheduler/task_schedule_modules/__init__.py similarity index 100% rename from evaluation/scripts/temporal_locomo/__init__.py rename to src/memos/mem_scheduler/task_schedule_modules/__init__.py diff --git a/src/memos/mem_scheduler/general_modules/dispatcher.py b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py similarity index 92% rename from src/memos/mem_scheduler/general_modules/dispatcher.py rename to src/memos/mem_scheduler/task_schedule_modules/dispatcher.py index b74529c8c..ac9f9a6d0 100644 --- a/src/memos/mem_scheduler/general_modules/dispatcher.py +++ b/src/memos/mem_scheduler/task_schedule_modules/dispatcher.py @@ -10,12 +10,12 @@ from memos.context.context import ContextThreadPoolExecutor from memos.log import get_logger from memos.mem_scheduler.general_modules.base import BaseSchedulerModule -from memos.mem_scheduler.general_modules.redis_queue import SchedulerRedisQueue from memos.mem_scheduler.general_modules.task_threads import ThreadManager from memos.mem_scheduler.schemas.general_schemas import DEFAULT_STOP_WAIT from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem from memos.mem_scheduler.utils.metrics import MetricsRegistry +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube logger = get_logger(__name__) @@ -151,15 +151,15 @@ def wrapped_handler(messages: list[ScheduleMessageItem]): # acknowledge redis messages - if ( - self.use_redis_queue - and self.memos_message_queue is not None - and isinstance(self.memos_message_queue, SchedulerRedisQueue) - ): + if self.use_redis_queue and self.memos_message_queue is not None: for msg in messages: redis_message_id = msg.redis_message_id # Acknowledge message processing - self.memos_message_queue.ack_message(redis_message_id=redis_message_id) + self.memos_message_queue.ack_message( + user_id=msg.user_id, + mem_cube_id=msg.mem_cube_id, + redis_message_id=redis_message_id, + ) # Mark task as completed and remove from tracking with self._task_lock: @@ -329,38 +329,6 @@ def stats(self) -> dict[str, int]: def _default_message_handler(self, messages: list[ScheduleMessageItem]) -> None: logger.debug(f"Using _default_message_handler to deal with messages: {messages}") - def _group_messages_by_user_and_mem_cube( - self, messages: list[ScheduleMessageItem] - ) -> dict[str, dict[str, list[ScheduleMessageItem]]]: - """ - Groups messages into a nested dictionary structure first by user_id, then by mem_cube_id. - - Args: - messages: List of ScheduleMessageItem objects to be grouped - - Returns: - A nested dictionary with the structure: - { - "user_id_1": { - "mem_cube_id_1": [msg1, msg2, ...], - "mem_cube_id_2": [msg3, msg4, ...], - ... - }, - "user_id_2": { - ... - }, - ... 
- } - Where each msg is the original ScheduleMessageItem object - """ - grouped_dict = defaultdict(lambda: defaultdict(list)) - - for msg in messages: - grouped_dict[msg.user_id][msg.mem_cube_id].append(msg) - - # Convert defaultdict to regular dict for cleaner output - return {user_id: dict(cube_groups) for user_id, cube_groups in grouped_dict.items()} - def _handle_future_result(self, future): self._futures.remove(future) try: @@ -380,7 +348,7 @@ def dispatch(self, msg_list: list[ScheduleMessageItem]): return # Group messages by user_id and mem_cube_id first - user_cube_groups = self._group_messages_by_user_and_mem_cube(msg_list) + user_cube_groups = group_messages_by_user_and_mem_cube(msg_list) # Process each user and mem_cube combination for user_id, cube_groups in user_cube_groups.items(): diff --git a/src/memos/mem_scheduler/task_schedule_modules/local_queue.py b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py new file mode 100644 index 000000000..93dd81132 --- /dev/null +++ b/src/memos/mem_scheduler/task_schedule_modules/local_queue.py @@ -0,0 +1,155 @@ +""" +Local Queue implementation for SchedulerMessageItem objects. +This module provides a local-based queue implementation that can replace +the local memos_message_queue functionality in BaseScheduler. +""" + +from memos.log import get_logger +from memos.mem_scheduler.general_modules.misc import AutoDroppingQueue as Queue +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.webservice_modules.redis_service import RedisSchedulerModule + + +logger = get_logger(__name__) + + +class SchedulerLocalQueue(RedisSchedulerModule): + def __init__( + self, + maxsize: int, + ): + """ + Initialize the SchedulerLocalQueue with a maximum queue size limit. + + Args: + maxsize (int): Maximum number of messages allowed + in each individual queue. + If exceeded, subsequent puts will block + or raise an exception based on `block` parameter. + """ + super().__init__() + + self.stream_key_prefix = "local_queue" + + self.max_internal_message_queue_size = maxsize + # Dictionary to hold per-stream queues: key = stream_key, value = Queue[ScheduleMessageItem] + self.queue_streams: dict[str, Queue[ScheduleMessageItem]] = {} + logger.info( + f"SchedulerLocalQueue initialized with max_internal_message_queue_size={maxsize}" + ) + + def get_stream_key(self, user_id: str, mem_cube_id: str) -> str: + stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}" + return stream_key + + def put( + self, message: ScheduleMessageItem, block: bool = True, timeout: float | None = None + ) -> None: + """ + Put a message into the appropriate internal queue based on user_id and mem_cube_id. + + If the corresponding queue does not exist, it is created automatically. + This method uses a local in-memory queue (not Redis) for buffering messages. + + Args: + message (ScheduleMessageItem): The message to enqueue. + block (bool): If True, block if the queue is full; if False, raise Full immediately. + timeout (float | None): Maximum time to wait for the queue to become available. + If None, block indefinitely. Ignored if block=False. + + Raises: + queue.Full: If the queue is full and block=False or timeout expires. + Exception: Any underlying error during queue.put() operation. 
+ """ + stream_key = self.get_stream_key(user_id=message.user_id, mem_cube_id=message.mem_cube_id) + + # Create the queue if it doesn't exist yet + if stream_key not in self.queue_streams: + logger.info(f"Creating new internal queue for stream: {stream_key}") + self.queue_streams[stream_key] = Queue(maxsize=self.max_internal_message_queue_size) + + try: + self.queue_streams[stream_key].put(item=message, block=block, timeout=timeout) + logger.info( + f"Message successfully put into queue '{stream_key}'. Current size: {self.queue_streams[stream_key].qsize()}" + ) + except Exception as e: + logger.error(f"Failed to put message into queue '{stream_key}': {e}", exc_info=True) + raise # Re-raise to maintain caller expectations + + def get( + self, + user_id: str, + mem_cube_id: str, + block: bool = True, + timeout: float | None = None, + batch_size: int | None = None, + ) -> list[ScheduleMessageItem]: + if batch_size is not None and batch_size <= 0: + logger.warning( + f"get() called with invalid batch_size: {batch_size}. Returning empty list." + ) + return [] + + stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id) + + # Return empty list if queue does not exist + if stream_key not in self.queue_streams: + logger.error(f"Stream {stream_key} does not exist when trying to get messages.") + return [] + + # Note: Assumes custom Queue implementation supports batch_size parameter + res = self.queue_streams[stream_key].get( + block=block, timeout=timeout, batch_size=batch_size + ) + logger.debug( + f"Retrieved {len(res)} messages from queue '{stream_key}'. Current size: {self.queue_streams[stream_key].qsize()}" + ) + return res + + def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem]: + """ + Non-blocking version of get(). Equivalent to get(block=False, batch_size=batch_size). + + Returns immediately with available messages or an empty list if queue is empty. + + Args: + batch_size (int | None): Number of messages to retrieve in a batch. + If None, retrieves one message. + + Returns: + List[ScheduleMessageItem]: Retrieved messages or empty list if queue is empty. + """ + logger.debug(f"get_nowait() called with batch_size: {batch_size}") + return self.get(block=False, batch_size=batch_size) + + def qsize(self) -> dict: + """ + Return the current size of all internal queues as a dictionary. + + Each key is the stream name, and each value is the number of messages in that queue. + + Returns: + Dict[str, int]: Mapping from stream name to current queue size. + """ + sizes = {stream: queue.qsize() for stream, queue in self.queue_streams.items()} + logger.debug(f"Current queue sizes: {sizes}") + return sizes + + def clear(self) -> None: + for queue in self.queue_streams.values(): + queue.clear() + + @property + def unfinished_tasks(self) -> int: + """ + Calculate the total number of unprocessed messages across all queues. + + This is a convenience property for monitoring overall system load. + + Returns: + int: Sum of all message counts in all internal queues. 
+ """ + total = sum(self.qsize().values()) + logger.debug(f"Total unfinished tasks across all queues: {total}") + return total diff --git a/src/memos/mem_scheduler/general_modules/redis_queue.py b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py similarity index 74% rename from src/memos/mem_scheduler/general_modules/redis_queue.py rename to src/memos/mem_scheduler/task_schedule_modules/redis_queue.py index c10765d05..fe7e3452c 100644 --- a/src/memos/mem_scheduler/general_modules/redis_queue.py +++ b/src/memos/mem_scheduler/task_schedule_modules/redis_queue.py @@ -32,7 +32,7 @@ class SchedulerRedisQueue(RedisSchedulerModule): def __init__( self, - stream_name: str = "scheduler:messages:stream", + stream_key_prefix: str = "scheduler:messages:stream", consumer_group: str = "scheduler_group", consumer_name: str | None = "scheduler_consumer", max_len: int = 10000, @@ -43,7 +43,7 @@ def __init__( Initialize the Redis queue. Args: - stream_name: Name of the Redis stream + stream_key_prefix: Name of the Redis stream consumer_group: Name of the consumer group consumer_name: Name of the consumer (auto-generated if None) max_len: Maximum length of the stream (for memory management) @@ -57,7 +57,7 @@ def __init__( maxsize = 0 # Stream configuration - self.stream_name = stream_name + self.stream_key_prefix = stream_key_prefix self.consumer_group = consumer_group self.consumer_name = consumer_name or f"consumer_{uuid4().hex[:8]}" self.max_len = max_len @@ -77,26 +77,29 @@ def __init__( # Auto-initialize Redis connection if self.auto_initialize_redis(): self._is_connected = True - self._ensure_consumer_group() - def _ensure_consumer_group(self) -> None: + self.seen_streams = set() + + def get_stream_key(self, user_id: str, mem_cube_id: str) -> str: + stream_key = f"{self.stream_key_prefix}:{user_id}:{mem_cube_id}" + return stream_key + + def _ensure_consumer_group(self, stream_key) -> None: """Ensure the consumer group exists for the stream.""" if not self._redis_conn: return try: - self._redis_conn.xgroup_create( - self.stream_name, self.consumer_group, id="0", mkstream=True - ) + self._redis_conn.xgroup_create(stream_key, self.consumer_group, id="0", mkstream=True) logger.debug( - f"Created consumer group '{self.consumer_group}' for stream '{self.stream_name}'" + f"Created consumer group '{self.consumer_group}' for stream '{stream_key}'" ) except Exception as e: # Check if it's a "consumer group already exists" error error_msg = str(e).lower() if "busygroup" in error_msg or "already exists" in error_msg: logger.info( - f"Consumer group '{self.consumer_group}' already exists for stream '{self.stream_name}'" + f"Consumer group '{self.consumer_group}' already exists for stream '{stream_key}'" ) else: logger.error(f"Error creating consumer group: {e}", exc_info=True) @@ -123,12 +126,20 @@ def put( raise TypeError(f"Expected ScheduleMessageItem, got {type(message)}") try: + stream_key = self.get_stream_key( + user_id=message.user_id, mem_cube_id=message.mem_cube_id + ) + + if stream_key not in self.seen_streams: + self.seen_streams.add(stream_key) + self._ensure_consumer_group(stream_key=stream_key) + # Convert message to dictionary for Redis storage message_data = message.to_dict() # Add to Redis stream with automatic trimming message_id = self._redis_conn.xadd( - self.stream_name, message_data, maxlen=self.max_len, approximate=True + stream_key, message_data, maxlen=self.max_len, approximate=True ) logger.info( @@ -139,28 +150,23 @@ def put( logger.error(f"Failed to add message to Redis 
queue: {e}") raise - def put_nowait(self, message: ScheduleMessageItem) -> None: - """ - Add a message to the Redis queue without blocking (Queue-compatible interface). + def ack_message(self, user_id, mem_cube_id, redis_message_id) -> None: + stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id) - Args: - message: SchedulerMessageItem to add to the queue - """ - self.put(message, block=False) - - def ack_message(self, redis_message_id): - self.redis.xack(self.stream_name, self.consumer_group, redis_message_id) + self.redis.xack(stream_key, self.consumer_group, redis_message_id) # Optionally delete the message from the stream to keep it clean if self.auto_delete_acked: try: - self._redis_conn.xdel(self.stream_name, redis_message_id) + self._redis_conn.xdel(stream_key, redis_message_id) logger.info(f"Successfully delete acknowledged message {redis_message_id}") except Exception as e: logger.warning(f"Failed to delete acknowledged message {redis_message_id}: {e}") def get( self, + user_id: str, + mem_cube_id: str, block: bool = True, timeout: float | None = None, batch_size: int | None = None, @@ -169,6 +175,8 @@ def get( raise ConnectionError("Not connected to Redis. Redis connection not available.") try: + stream_key = self.get_stream_key(user_id=user_id, mem_cube_id=mem_cube_id) + # Calculate timeout for Redis redis_timeout = None if block and timeout is not None: @@ -181,7 +189,7 @@ def get( messages = self._redis_conn.xreadgroup( self.consumer_group, self.consumer_name, - {self.stream_name: ">"}, + {stream_key: ">"}, count=batch_size if not batch_size else 1, block=redis_timeout, ) @@ -190,12 +198,13 @@ def get( err_msg = str(read_err).lower() if "nogroup" in err_msg or "no such key" in err_msg: logger.warning( - f"Consumer group or stream missing for '{self.stream_name}/{self.consumer_group}'. Attempting to create and retry." + f"Consumer group or stream missing for '{stream_key}/{self.consumer_group}'. Attempting to create and retry." ) + self._ensure_consumer_group(stream_key=stream_key) messages = self._redis_conn.xreadgroup( self.consumer_group, self.consumer_name, - {self.stream_name: ">"}, + {stream_key: ">"}, count=batch_size if not batch_size else 1, block=redis_timeout, ) @@ -233,7 +242,9 @@ def get( logger.error(f"Failed to get message from Redis queue: {e}") raise - def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem]: + def get_nowait( + self, user_id: str, mem_cube_id: str, batch_size: int | None = None + ) -> list[ScheduleMessageItem]: """ Get messages from the Redis queue without blocking (Queue-compatible interface). @@ -243,76 +254,58 @@ def get_nowait(self, batch_size: int | None = None) -> list[ScheduleMessageItem] Raises: Empty: If no message is available """ - return self.get(block=False, batch_size=batch_size) + return self.get( + user_id=user_id, mem_cube_id=mem_cube_id, block=False, batch_size=batch_size + ) def qsize(self) -> int: """ Get the current size of the Redis queue (Queue-compatible interface). - Returns the number of pending (unacknowledged) messages in the consumer group, - which represents the actual queue size for processing. + This method scans for all streams matching the `stream_key_prefix` + and sums up their lengths to get the total queue size. Returns: - Number of pending messages in the queue + Total number of messages across all matching streams. 
""" if not self._redis_conn: return 0 + total_size = 0 try: - # Get pending messages info for the consumer group - # XPENDING returns info about pending messages that haven't been acknowledged - pending_info = self._redis_conn.xpending(self.stream_name, self.consumer_group) - - # pending_info[0] contains the count of pending messages - if pending_info and len(pending_info) > 0 and pending_info[0] is not None: - pending_count = int(pending_info[0]) - if pending_count > 0: - return pending_count - - # If no pending messages, check if there are new messages in the stream - # that haven't been read by any consumer yet - try: - # Get the last delivered ID for the consumer group - groups_info = self._redis_conn.xinfo_groups(self.stream_name) - if not groups_info: - # No groups exist, check total stream length - return self._redis_conn.xlen(self.stream_name) or 0 - - last_delivered_id = "0-0" - - for group_info in groups_info: - if group_info and group_info.get("name") == self.consumer_group: - last_delivered_id = group_info.get("last-delivered-id", "0-0") - break - - # Count messages after the last delivered ID - new_messages = self._redis_conn.xrange( - self.stream_name, - f"({last_delivered_id}", # Exclusive start - "+", # End at the latest message - count=1000, # Limit to avoid memory issues - ) + # Scan for all stream keys matching the prefix + for stream_key in self._redis_conn.scan_iter(f"{self.stream_key_prefix}:*"): + try: + # Get the length of each stream and add to total + total_size += self._redis_conn.xlen(stream_key) + except Exception as e: + logger.debug(f"Failed to get length for stream {stream_key}: {e}") + return total_size + except Exception as e: + logger.error(f"Failed to get Redis queue size: {e}") + return 0 - return len(new_messages) if new_messages else 0 + def get_stream_keys(self) -> list[str]: + """ + List all Redis stream keys that match this queue's prefix. - except Exception as inner_e: - logger.debug(f"Failed to get new messages count: {inner_e}") - # Fallback: return stream length - try: - stream_len = self._redis_conn.xlen(self.stream_name) - return stream_len if stream_len is not None else 0 - except Exception: - return 0 + Returns: + A list of stream keys like `"{prefix}:{user_id}:{mem_cube_id}"`. 
+ """ + if not self._redis_conn: + return [] + try: + # Use match parameter and decode byte strings to regular strings + stream_keys = [ + key.decode("utf-8") if isinstance(key, bytes) else key + for key in self._redis_conn.scan_iter(match=f"{self.stream_key_prefix}:*") + ] + logger.debug(f"get stream_keys from redis: {stream_keys}") + return stream_keys except Exception as e: - logger.debug(f"Failed to get Redis queue size via XPENDING: {e}") - # Fallback to stream length if pending check fails - try: - stream_len = self._redis_conn.xlen(self.stream_name) - return stream_len if stream_len is not None else 0 - except Exception as fallback_e: - logger.error(f"Failed to get Redis queue size (all methods failed): {fallback_e}") - return 0 + logger.error(f"Failed to list Redis stream keys: {e}") + return [] def size(self) -> int: """ @@ -360,12 +353,13 @@ def clear(self) -> None: return try: - # Delete the entire stream - self._redis_conn.delete(self.stream_name) - logger.info(f"Cleared Redis stream: {self.stream_name}") + stream_keys = self.get_stream_keys() + + for stream_key in stream_keys: + # Delete the entire stream + self._redis_conn.delete(stream_key) + logger.info(f"Cleared Redis stream: {stream_key}") - # Recreate the consumer group - self._ensure_consumer_group() except Exception as e: logger.error(f"Failed to clear Redis queue: {e}") @@ -389,7 +383,7 @@ def start_listening( self._message_handler = handler self._is_listening = True - logger.info(f"Started listening on Redis stream: {self.stream_name}") + logger.info(f"Started listening on Redis stream: {self.stream_key_prefix}") try: while self._is_listening: diff --git a/src/memos/mem_scheduler/task_schedule_modules/task_queue.py b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py new file mode 100644 index 000000000..74f1ad1f8 --- /dev/null +++ b/src/memos/mem_scheduler/task_schedule_modules/task_queue.py @@ -0,0 +1,151 @@ +""" +Redis Queue implementation for SchedulerMessageItem objects. + +This module provides a Redis-based queue implementation that can replace +the local memos_message_queue functionality in BaseScheduler. 
+""" + +from collections import defaultdict + +from memos.log import get_logger +from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem +from memos.mem_scheduler.task_schedule_modules.local_queue import SchedulerLocalQueue +from memos.mem_scheduler.task_schedule_modules.redis_queue import SchedulerRedisQueue +from memos.mem_scheduler.utils.db_utils import get_utc_now +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube + + +logger = get_logger(__name__) + + +class ScheduleTaskQueue: + def __init__( + self, + use_redis_queue: bool, + maxsize: int, + disabled_handlers: list | None = None, + ): + self.use_redis_queue = use_redis_queue + self.maxsize = maxsize + + if self.use_redis_queue: + self.memos_message_queue = SchedulerRedisQueue(maxsize=self.maxsize) + else: + self.memos_message_queue = SchedulerLocalQueue(maxsize=self.maxsize) + + self.disabled_handlers = disabled_handlers + + def ack_message( + self, + user_id, + mem_cube_id, + redis_message_id, + ) -> None: + if not isinstance(self.memos_message_queue, SchedulerRedisQueue): + logger.warning("ack_message is only supported for Redis queues") + return + + self.memos_message_queue.ack_message( + user_id=user_id, + mem_cube_id=mem_cube_id, + redis_message_id=redis_message_id, + ) + + def debug_mode_on(self): + self.memos_message_queue.stream_key_prefix = ( + f"debug_mode:{self.memos_message_queue.stream_key_prefix}" + ) + + def get_stream_keys(self) -> list[str]: + if isinstance(self.memos_message_queue, SchedulerRedisQueue): + return self.memos_message_queue.get_stream_keys() + else: + return list(self.memos_message_queue.queue_streams.keys()) + + def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageItem]): + """Submit messages to the message queue (either local queue or Redis).""" + if isinstance(messages, ScheduleMessageItem): + messages = [messages] + + if len(messages) < 1: + logger.error("Submit empty") + elif len(messages) == 1: + self.memos_message_queue.put(messages[0]) + else: + user_cube_groups = group_messages_by_user_and_mem_cube(messages) + + # Process each user and mem_cube combination + for _user_id, cube_groups in user_cube_groups.items(): + for _mem_cube_id, user_cube_msgs in cube_groups.items(): + for message in user_cube_msgs: + if not isinstance(message, ScheduleMessageItem): + error_msg = f"Invalid message type: {type(message)}, expected ScheduleMessageItem" + logger.error(error_msg) + raise TypeError(error_msg) + + if getattr(message, "timestamp", None) is None: + message.timestamp = get_utc_now() + + if self.disabled_handlers and message.label in self.disabled_handlers: + logger.info( + f"Skipping disabled handler: {message.label} - {message.content}" + ) + continue + + self.memos_message_queue.put(message) + logger.info( + f"Submitted message to local queue: {message.label} - {message.content}" + ) + + def get_messages(self, batch_size: int) -> list[ScheduleMessageItem]: + # Discover all active streams via queue API + streams: list[tuple[str, str]] = [] + + stream_keys = self.get_stream_keys() + for stream_key in stream_keys: + try: + parts = stream_key.split(":") + if len(parts) >= 3: + user_id = parts[-2] + mem_cube_id = parts[-1] + streams.append((user_id, mem_cube_id)) + except Exception as e: + logger.debug(f"Failed to parse stream key {stream_key}: {e}") + + if not streams: + return [] + + messages: list[ScheduleMessageItem] = [] + + # Group by user: {user_id: [mem_cube_id, ...]} + + streams_by_user: dict[str, list[str]] = 
defaultdict(list) + for user_id, mem_cube_id in streams: + streams_by_user[user_id].append(mem_cube_id) + + # For each user, fairly consume up to batch_size across their streams + for user_id, mem_cube_ids in streams_by_user.items(): + if not mem_cube_ids: + continue + + # First pass: give each stream an equal share for this user + for mem_cube_id in mem_cube_ids: + fetched = self.memos_message_queue.get( + user_id=user_id, + mem_cube_id=mem_cube_id, + block=False, + batch_size=batch_size, + ) + + messages.extend(fetched) + + logger.info( + f"Fetched {len(messages)} messages across users with per-user batch_size={batch_size}" + ) + return messages + + def clear(self): + self.memos_message_queue.clear() + + def qsize(self): + return self.memos_message_queue.qsize() diff --git a/src/memos/mem_scheduler/utils/misc_utils.py b/src/memos/mem_scheduler/utils/misc_utils.py index cce1286bb..7b0bcea34 100644 --- a/src/memos/mem_scheduler/utils/misc_utils.py +++ b/src/memos/mem_scheduler/utils/misc_utils.py @@ -2,12 +2,16 @@ import re import traceback +from collections import defaultdict from functools import wraps from pathlib import Path import yaml from memos.log import get_logger +from memos.mem_scheduler.schemas.message_schemas import ( + ScheduleMessageItem, +) logger = get_logger(__name__) @@ -216,3 +220,36 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +def group_messages_by_user_and_mem_cube( + messages: list[ScheduleMessageItem], +) -> dict[str, dict[str, list[ScheduleMessageItem]]]: + """ + Groups messages into a nested dictionary structure first by user_id, then by mem_cube_id. + + Args: + messages: List of ScheduleMessageItem objects to be grouped + + Returns: + A nested dictionary with the structure: + { + "user_id_1": { + "mem_cube_id_1": [msg1, msg2, ...], + "mem_cube_id_2": [msg3, msg4, ...], + ... + }, + "user_id_2": { + ... + }, + ... + } + Where each msg is the original ScheduleMessageItem object + """ + grouped_dict = defaultdict(lambda: defaultdict(list)) + + for msg in messages: + grouped_dict[msg.user_id][msg.mem_cube_id].append(msg) + + # Convert defaultdict to regular dict for cleaner output + return {user_id: dict(cube_groups) for user_id, cube_groups in grouped_dict.items()} diff --git a/src/memos/memories/textual/tree.py b/src/memos/memories/textual/tree.py index 15a6a8b49..1b2355bc8 100644 --- a/src/memos/memories/textual/tree.py +++ b/src/memos/memories/textual/tree.py @@ -161,7 +161,7 @@ def search( info=None, mode: str = "fast", memory_type: str = "All", - manual_close_internet: bool = False, + manual_close_internet: bool = True, moscube: bool = False, search_filter: dict | None = None, user_name: str | None = None, @@ -189,9 +189,6 @@ def search( list[TextualMemoryItem]: List of matching memories. 
""" if (self.internet_retriever is not None) and manual_close_internet: - logger.warning( - "Internet retriever is init by config , but this search set manual_close_internet is True and will close it" - ) searcher = Searcher( self.dispatcher_llm, self.graph_store, @@ -201,6 +198,7 @@ def search( internet_retriever=None, moscube=moscube, search_strategy=self.search_strategy, + manual_close_internet=manual_close_internet, ) else: searcher = Searcher( @@ -212,6 +210,7 @@ def search( internet_retriever=self.internet_retriever, moscube=moscube, search_strategy=self.search_strategy, + manual_close_internet=manual_close_internet, ) return searcher.search( query, top_k, info, mode, memory_type, search_filter, user_name=user_name diff --git a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py index 14ea8e2cb..933ef5af1 100644 --- a/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py +++ b/src/memos/memories/textual/tree_text_memory/retrieve/searcher.py @@ -43,6 +43,7 @@ def __init__( internet_retriever: None = None, moscube: bool = False, search_strategy: dict | None = None, + manual_close_internet: bool = True, ): self.graph_store = graph_store self.embedder = embedder @@ -58,7 +59,7 @@ def __init__( self.moscube = moscube self.vec_cot = search_strategy.get("cot", False) if search_strategy else False self.use_fast_graph = search_strategy.get("fast_graph", False) if search_strategy else False - + self.manual_close_internet = manual_close_internet self._usage_executor = ContextThreadPoolExecutor(max_workers=4, thread_name_prefix="usage") @timed @@ -72,7 +73,7 @@ def retrieve( search_filter: dict | None = None, user_name: str | None = None, **kwargs, - ) -> list[TextualMemoryItem]: + ) -> list[tuple[TextualMemoryItem, float]]: logger.info( f"[RECALL] Start query='{query}', top_k={top_k}, mode={mode}, memory_type={memory_type}" ) @@ -94,7 +95,7 @@ def retrieve( def post_retrieve( self, - retrieved_results: list[TextualMemoryItem], + retrieved_results: list[tuple[TextualMemoryItem, float]], top_k: int, user_name: str | None = None, info=None, @@ -458,7 +459,7 @@ def _retrieve_from_internet( user_id: str | None = None, ): """Retrieve and rerank from Internet source""" - if not self.internet_retriever or mode == "fast": + if not self.internet_retriever or self.manual_close_internet: logger.info(f"[PATH-C] '{query}' Skipped (no retriever, fast mode)") return [] if memory_type not in ["All"]: diff --git a/src/memos/templates/mem_scheduler_prompts.py b/src/memos/templates/mem_scheduler_prompts.py index 197a2c1a7..7f7415e79 100644 --- a/src/memos/templates/mem_scheduler_prompts.py +++ b/src/memos/templates/mem_scheduler_prompts.py @@ -390,24 +390,22 @@ - Focus on whether the memories can fully answer the query without additional information """ -MEMORY_ENHANCEMENT_PROMPT = """ +MEMORY_RECREATE_ENHANCEMENT_PROMPT = """ You are a knowledgeable and precise AI assistant. # GOAL -Transform each raw memory into an enhanced version that preserves all relevant factual details and makes the information directly useful for answering the user's query. - -# CORE PRINCIPLE -Focus on **relevance** — the enhanced memories should highlight, clarify, and preserve the information that most directly supports answering the current query. +Transform raw memories into clean, query-relevant facts — preserving timestamps and resolving ambiguities without inference. # RULES & THINKING STEPS -1. 
Read the user query carefully and identify what specific facts are needed to answer it. -2. Go through each memory and: - - Keep only details directly relevant to the query (dates, actions, entities, outcomes). - - Remove unrelated or background details. - - If nothing in a memory relates to the query, delete the entire memory. -3. Do not add or infer new facts. -4. Keep facts accurate and phrased clearly. -5. Each resulting line should stand alone as a usable fact for answering the query. +1. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely. +2. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”, “after injury”). +3. Resolve all ambiguities using only memory content: + - Pronouns → full name: “she” → “Melanie” + - Vague nouns → specific detail: “home” → “her childhood home in Guangzhou” + - “the user” → identity from context (e.g., “Melanie” if travel/running memories) +4. Never invent, assume, or extrapolate. +5. Each output line must be a standalone, clear, factual statement. +6. Output format: one line per fact, starting with "- ", no extra text. # OUTPUT FORMAT (STRICT) Return ONLY the following block, with **one enhanced memory per line**. @@ -423,12 +421,91 @@ ## User Query {query_history} -## Available Memories +## Original Memories {memories} -Answer: +Final Output: +""" + +# Rewrite version: return enhanced memories with original IDs +MEMORY_REWRITE_ENHANCEMENT_PROMPT = """ +You are a knowledgeable and precise AI assistant. + +# GOAL +Transform raw memories into clean, query-relevant facts — preserving timestamps and resolving ambiguities without inference. Return each enhanced fact with the ID of the original memory being modified. + +# RULES & THINKING STEPS +1. Keep ONLY what’s relevant to the user’s query. Delete irrelevant memories entirely. +2. Preserve ALL explicit timestamps (e.g., “on October 6”, “daily”, “after injury”). +3. Resolve all ambiguities using only memory content: + - Pronouns → full name: “she” → “Melanie” + - Vague nouns → specific detail: “home” → “her childhood home in Guangzhou” + - “the user” → identity from context (e.g., “Melanie” if travel/running memories) +4. Never invent, assume, or extrapolate. +5. Each output line must be a standalone, clear, factual statement. +6. Output format: one line per fact, starting with "- ", no extra text. + +# IMPORTANT FOR REWRITE +- Each output line MUST include the original memory’s ID shown in the input list. +- Use the index shown for each original memory (e.g., "[0]", "[1]") as the ID to reference which memory you are rewriting. +- For every rewritten line, prefix with the corresponding index in square brackets. + +# OUTPUT FORMAT (STRICT) +Return ONLY the following block, with **one enhanced memory per line**. +Each line MUST start with "- " (dash + space) AND include index in square brackets. + +Wrap the final output inside: + +- [index] enhanced memory 1 +- [index] enhanced memory 2 +... + + +## User Query +{query_history} + +## Original Memories +{memories} + +Final Output: """ +# One-sentence prompt for recalling missing information to answer the query (English) +ENLARGE_RECALL_PROMPT_ONE_SENTENCE = """ +You are a precise AI assistant. Your job is to analyze the user's query and the available memories to identify what specific information is missing to fully answer the query. + +# GOAL + +Identify the specific missing facts needed to fully answer the user's query and generate a concise hint for recalling them. 
+ +# RULES + +- Analyze the user's query to understand what information is being asked. +- Review the available memories to see what information is already present. +- Identify the gap between the user's query and the available memories. +- Generate a single, concise hint that can be used to retrieve the missing information. +- The hint should be a direct question or a statement that clearly indicates what is needed. + +# OUTPUT FORMAT +A JSON object with: + +trigger_recall: true if information is missing, false if sufficient. +hint: A clear, specific prompt to retrieve the missing information (or an empty string if trigger_recall is false): +{{ + "trigger_recall": <true or false>, + "hint": <a short query phrase to retrieve the missing supporting memories> +}} + +## User Query +{query} + +## Available Memories +{memories_inline} + +Final Output: +""" + + PROMPT_MAPPING = { "intent_recognizing": INTENT_RECOGNIZING_PROMPT, "memory_reranking": MEMORY_RERANKING_PROMPT, @@ -437,7 +514,9 @@ "memory_redundancy_filtering": MEMORY_REDUNDANCY_FILTERING_PROMPT, "memory_combined_filtering": MEMORY_COMBINED_FILTERING_PROMPT, "memory_answer_ability_evaluation": MEMORY_ANSWER_ABILITY_EVALUATION_PROMPT, - "memory_enhancement": MEMORY_ENHANCEMENT_PROMPT, + "memory_recreate_enhancement": MEMORY_RECREATE_ENHANCEMENT_PROMPT, + "memory_rewrite_enhancement": MEMORY_REWRITE_ENHANCEMENT_PROMPT, + "enlarge_recall": ENLARGE_RECALL_PROMPT_ONE_SENTENCE, } MEMORY_ASSEMBLY_TEMPLATE = """The retrieved memories are listed as follows:\n\n {memory_text}""" diff --git a/tests/mem_scheduler/test_dispatcher.py b/tests/mem_scheduler/test_dispatcher.py index fc154e013..e687d2986 100644 --- a/tests/mem_scheduler/test_dispatcher.py +++ b/tests/mem_scheduler/test_dispatcher.py @@ -14,10 +14,11 @@ ) from memos.llms.base import BaseLLM from memos.mem_cube.general import GeneralMemCube -from memos.mem_scheduler.general_modules.dispatcher import SchedulerDispatcher from memos.mem_scheduler.scheduler_factory import SchedulerFactory from memos.mem_scheduler.schemas.message_schemas import ScheduleMessageItem from memos.mem_scheduler.schemas.task_schemas import RunningTaskItem +from memos.mem_scheduler.task_schedule_modules.dispatcher import SchedulerDispatcher +from memos.mem_scheduler.utils.misc_utils import group_messages_by_user_and_mem_cube from memos.memories.textual.tree import TreeTextMemory @@ -192,9 +193,8 @@ def test_dispatch_serial(self): def test_group_messages_by_user_and_mem_cube(self): """Test grouping messages by user and cube.""" - # Check actual grouping logic - with patch("memos.mem_scheduler.general_modules.dispatcher.logger.debug"): - result = self.dispatcher._group_messages_by_user_and_mem_cube(self.test_messages) + # Check actual grouping logic using shared utility function + result = group_messages_by_user_and_mem_cube(self.test_messages) # Adjust expected results based on actual grouping logic # Note: According to dispatcher.py implementation, grouping is by mem_cube_id not mem_cube
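For reference, a minimal standalone sketch of the "[index] rewritten text" convention that the rewrite strategy above expects from the LLM output. The two regex patterns mirror _parse_index_and_text in the patch; the function name, sample data, and fall-back-by-position behaviour are illustrative assumptions, not the repository's API.

import re


def parse_index_and_text(line: str) -> tuple[int | None, str]:
    """Parse a '- [2] fact' style line into (index, text); index is None when absent."""
    s = (line or "").strip().lstrip("-").strip()
    m = re.match(r"^\s*\[(\d+)\]\s*(.+)$", s)  # preferred form: "[index] text"
    if m:
        return int(m.group(1)), m.group(2).strip()
    m = re.match(r"^\s*(\d+)\s*[:\-\)]\s*(.+)$", s)  # fallback: "index: text" / "index - text"
    if m:
        return int(m.group(1)), m.group(2).strip()
    return None, s


# Illustrative data only: map each rewritten line back to the original memory it replaces.
originals = ["she went home on October 6", "the weather was rainy"]
llm_lines = ["- [0] Melanie went to her childhood home in Guangzhou on October 6"]
for position, line in enumerate(llm_lines):
    idx, text = parse_index_and_text(line)
    # Prefer the explicit index; fall back to positional alignment, as the patch does.
    source = originals[idx] if idx is not None and idx < len(originals) else originals[position]
    print(f"rewrite {source!r} -> {text!r}")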