From 6e5002e788293d8824bf29acce95aae043a33fc0 Mon Sep 17 00:00:00 2001 From: Yan Xiao Date: Tue, 14 Apr 2026 11:25:40 +0800 Subject: [PATCH 1/3] Migrate to online llm clustering --- .../demo/utils/simple_memory_manager.py | 67 +- methods/evermemos/env.template | 17 +- .../src/agentic_layer/get_mem_service.py | 4 +- .../src/agentic_layer/memory_manager.py | 3 + .../src/agentic_layer/retrieval_utils.py | 3 + .../src/agentic_layer/search_mem_service.py | 151 +- .../evermemos/src/api_specs/dtos/memory.py | 31 - .../evermemos/src/biz_layer/mem_memorize.py | 975 +++++------- .../src/biz_layer/memorize_config.py | 97 +- .../evermemos/src/common_utils/async_utils.py | 32 + .../src/core/constants/exceptions.py | 260 +--- .../src/core/tenants/init_tenant_all.py | 4 +- .../src/core/tenants/tenant_constants.py | 20 +- .../tenantize/kv/redis/tenant_key_utils.py | 33 +- .../tenants/tenantize/tenant_cache_utils.py | 15 +- .../data_fix/milvus_rebuild_collection.py | 17 +- .../input/api/memory/memory_controller.py | 137 +- .../persistence/document/memory/mem_scene.py | 14 +- .../repository/agent_case_raw_repository.py | 33 + .../converter/agent_skill_milvus_converter.py | 3 +- methods/evermemos/src/manage.py | 7 +- .../memory_layer/cluster_manager/config.py | 10 + .../memory_layer/cluster_manager/manager.py | 630 ++++---- .../src/memory_layer/llm/api_key_rotator.py | 90 ++ .../src/memory_layer/llm/llm_metrics.py | 61 + .../src/memory_layer/llm/openai_provider.py | 488 +++--- .../conv_memcell_extractor.py | 78 +- .../memory_extractor/agent_case_extractor.py | 14 +- .../memory_extractor/agent_skill_extractor.py | 253 ++-- .../episode_memory_extractor.py | 5 +- .../src/memory_layer/memory_manager.py | 17 +- .../memory_layer/profile_manager/manager.py | 13 +- .../src/memory_layer/prompts/__init__.py | 7 +- .../memory_layer/prompts/en/agent_prompts.py | 108 +- .../prompts/en/cluster_prompts.py | 1 + .../memory_layer/prompts/zh/agent_prompts.py | 4 +- .../prompts/zh/cluster_prompts.py | 1 + .../src/service/memcell_delete_service.py | 3 + .../test_agent_converters_and_pipeline.py | 26 +- .../tests/test_agent_skill_extractor.py | 8 +- .../test_agent_skill_relevance_verify.py | 80 +- .../evermemos/tests/test_api_key_rotator.py | 104 ++ .../tests/test_cluster_memcell_llm.py | 458 ++++++ methods/evermemos/tests/test_llm_metrics.py | 200 +++ methods/evermemos/tests/test_mem_memorize.py | 1318 ----------------- .../test_openai_provider_key_rotation.py | 310 ++++ .../tests/test_profile_extraction_interval.py | 29 + .../tests/test_tenant_cache_utils.py | 156 ++ 48 files changed, 3053 insertions(+), 3342 deletions(-) create mode 100644 methods/evermemos/src/common_utils/async_utils.py create mode 100644 methods/evermemos/src/memory_layer/llm/api_key_rotator.py create mode 100644 methods/evermemos/src/memory_layer/llm/llm_metrics.py create mode 100644 methods/evermemos/src/memory_layer/prompts/en/cluster_prompts.py create mode 100644 methods/evermemos/src/memory_layer/prompts/zh/cluster_prompts.py create mode 100644 methods/evermemos/tests/test_api_key_rotator.py create mode 100644 methods/evermemos/tests/test_cluster_memcell_llm.py create mode 100644 methods/evermemos/tests/test_llm_metrics.py delete mode 100644 methods/evermemos/tests/test_mem_memorize.py create mode 100644 methods/evermemos/tests/test_openai_provider_key_rotation.py create mode 100644 methods/evermemos/tests/test_tenant_cache_utils.py diff --git a/methods/evermemos/demo/utils/simple_memory_manager.py b/methods/evermemos/demo/utils/simple_memory_manager.py index f3ed0dd66..f7b9499fd 100644 --- a/methods/evermemos/demo/utils/simple_memory_manager.py +++ b/methods/evermemos/demo/utils/simple_memory_manager.py @@ -80,7 +80,6 @@ def __init__( base_url: str = "http://localhost:1995", group_id: str = "default_group", scene: str = ScenarioType.SOLO.value, - user_id: str = "demo_user", ): """Initialize the manager @@ -88,13 +87,11 @@ def __init__( base_url: API server address (default: localhost:1995) group_id: Group ID (default: default_group) scene: Scene type (default: "solo", options: "solo" or "team") - user_id: User ID for personal endpoint (default: "demo_user") """ self.base_url = base_url self.group_id = group_id self.group_name = "Simple Demo Group" self.scene = scene - self.user_id = user_id self.memorize_url = f"{base_url}/api/v1/memories" self.retrieve_url = f"{base_url}/api/v1/memories/search" self.settings_url = f"{base_url}/api/v1/settings" @@ -122,32 +119,29 @@ async def store(self, content: str, sender: str = "User") -> bool: ) # Use project's unified time utility (with timezone) message_id = f"msg_{self._message_counter}_{int(now.timestamp() * 1000)}" - # Build v1 PersonalAddRequest payload - role = "user" if sender.lower() == "user" else "assistant" - message_item = { + # Build message data (completely consistent with test_v1api_search.py format) + message_data = { "message_id": message_id, - "sender_id": self.user_id if role == "user" else sender, - "sender_name": sender, - "role": role, - "timestamp": int(now.timestamp() * 1000), + "create_time": to_iso_format( + now + ), # Use project's unified time formatting (with timezone) + "sender": sender, + "sender_name": sender, # Consistent with JSON data format + "type": "text", # Message type "content": content, - } - payload = { - "user_id": self.user_id, - "messages": [message_item], + "group_id": self.group_id, + "group_name": self.group_name, + "scene": self.scene, # Use configured scene } try: async with httpx.AsyncClient(timeout=500.0) as client: - response = await client.post(self.memorize_url, json=payload) + response = await client.post(self.memorize_url, json=message_data) response.raise_for_status() result = response.json() - # v1 response: {"data": {"status": "...", "count": N, ...}} - data = result.get("data", {}) - status = data.get("status", "") - count = data.get("count", 0) - if status: + if result.get("status") == "ok": + count = result.get("result", {}).get("count", 0) if count > 0: print( f" ✅ Stored: {content[:40]}... (Extracted {count} memories)" @@ -206,48 +200,51 @@ async def _init_settings(self) -> bool: return False async def search( - self, query: str, top_k: int = 3, mode: str = "vector", show_details: bool = True + self, query: str, top_k: int = 3, mode: str = "rrf", show_details: bool = True ) -> List[Dict[str, Any]]: """Search memories Args: query: Query text top_k: Number of results to return (default: 3) - mode: + mode: Retrieval mode (default: "rrf") + - "rrf": RRF fusion (recommended) - "keyword": Keyword retrieval (BM25) - "vector": Vector retrieval - "hybrid": Keyword + Vector + Rerank + - "rrf": Keyword + Vector + RRF fusion - "agentic": LLM-guided multi-round retrieval show_details: Whether to show detailed information (default: True) Returns: List of memories """ - # v1 SearchMemoriesRequest: POST with body {query, method, memory_types, top_k, filters} payload = { "query": query, - "method": mode, - "memory_types": ["episodic_memory"], "top_k": top_k, - "filters": {"user_id": self.user_id}, + "memory_types": "episodic_memory", + "retrieve_method": mode, + "group_id": self.group_id, } try: async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.post(self.retrieve_url, json=payload) + response = await client.get(self.retrieve_url, params=payload) response.raise_for_status() result = response.json() - # v1 response: {"data": {"episodes": [...], "profiles": [...], "raw_messages": [...], "agent_memory": ...}} - data = result.get("data", {}) - if data: - # Aggregate across memory_type buckets (we only requested episodic_memory here) - memories = [] - for key in ("episodes", "profiles", "raw_messages"): - memories.extend(data.get(key) or []) - metadata = data.get("metadata", {}) or {} + if result.get("status") == "ok": + # memories is grouped: [{"group_id": [Memory, ...]}, ...] + raw_memories = result.get("result", {}).get("memories", []) + metadata = result.get("result", {}).get("metadata", {}) latency = metadata.get("total_latency_ms", 0) + # Flatten grouped memories to flat list + memories = [] + for group_dict in raw_memories: + for group_id, mem_list in group_dict.items(): + memories.extend(mem_list) + if show_details: print( f" 🔍 Found {len(memories)} memories (took {latency:.2f}ms)" diff --git a/methods/evermemos/env.template b/methods/evermemos/env.template index 5bb3bc25f..ee8beed18 100755 --- a/methods/evermemos/env.template +++ b/methods/evermemos/env.template @@ -30,7 +30,9 @@ LLM_BASE_URL=https://openrouter.ai/api/v1 # OpenRouter Configuration # Preferred provider naming rule: # {PROVIDER}_API_KEY / {PROVIDER}_BASE_URL -OPENROUTER_API_KEY=sk-or-v1-xxxx +# Supports multiple keys (comma-separated) for rate-limit distribution: +# OPENROUTER_API_KEY=key1,key2,key3 +OPENROUTER_API_KEY=your-openrouter-api-key OPENROUTER_BASE_URL=https://openrouter.ai/api/v1 # Also supported: # {PROVIDER}_LLM_API_KEY / {PROVIDER}_LLM_BASE_URL @@ -215,19 +217,14 @@ RERANK_SCORE_THRESHOLD=0.6 AGENTIC_ROUND1_RERANK_TOP_N=10 # =================== -# Agent Memorize Configuration +# Agent Memorize Mode # =================== -# Agent memorize mode: "online" (real-time, embedding-based clustering; -# LLM-based clustering is WIP) or "fast_skill" (batched, LLM-based clustering; -# skips non-skill memory). -# Default: online +# Controls which MemorizeConfig is used for agent conversations. +# - online: full pipeline (default) +# - fast_skill: skip profile/foresight/eventlog, skip maturity scoring AGENT_MEMORIZE_MODE=online -# Clustering similarity threshold for agent memory (0.0-1.0) -# Default: 0.5 -AGENT_CLUSTER_SIMILARITY_THRESHOLD=0.5 - # =================== # Environment & Logging Configuration # =================== diff --git a/methods/evermemos/src/agentic_layer/get_mem_service.py b/methods/evermemos/src/agentic_layer/get_mem_service.py index 641f02b4f..84e6edb46 100644 --- a/methods/evermemos/src/agentic_layer/get_mem_service.py +++ b/methods/evermemos/src/agentic_layer/get_mem_service.py @@ -43,7 +43,7 @@ from infra_layer.adapters.out.persistence.repository.agent_skill_raw_repository import ( AgentSkillRawRepository, ) -from biz_layer.memorize_config import AGENT_DEFAULT_MEMORIZE_CONFIG +from biz_layer.memorize_config import DEFAULT_MEMORIZE_CONFIG logger = logging.getLogger(__name__) @@ -243,7 +243,7 @@ async def _get_agent_skills( self, mongo_filter: dict, skip: int, limit: int, sort: list ) -> GetMemResponse: """Query v1_agent_skills via repository and return GetMemResponse.""" - retire_confidence = AGENT_DEFAULT_MEMORIZE_CONFIG.skill_retire_confidence + retire_confidence = DEFAULT_MEMORIZE_CONFIG.skill_retire_confidence mongo_filter.setdefault("confidence", {"$gte": retire_confidence}) with timed("query_memories"): docs, total_count = await self._agent_skill_repo.find_by_query( diff --git a/methods/evermemos/src/agentic_layer/memory_manager.py b/methods/evermemos/src/agentic_layer/memory_manager.py index 2cd03643d..e6e930057 100644 --- a/methods/evermemos/src/agentic_layer/memory_manager.py +++ b/methods/evermemos/src/agentic_layer/memory_manager.py @@ -1077,6 +1077,9 @@ async def do_search(q: str) -> List[Dict]: round2_results = await asyncio.gather( *[do_search(q) for q in refined_queries], return_exceptions=True ) + from common_utils.async_utils import reraise_critical_errors + + reraise_critical_errors(round2_results) all_round2 = [ h for r in round2_results if not isinstance(r, Exception) for h in r ] diff --git a/methods/evermemos/src/agentic_layer/retrieval_utils.py b/methods/evermemos/src/agentic_layer/retrieval_utils.py index 8e66284c8..c592b48aa 100644 --- a/methods/evermemos/src/agentic_layer/retrieval_utils.py +++ b/methods/evermemos/src/agentic_layer/retrieval_utils.py @@ -350,6 +350,9 @@ async def multi_query_retrieval( ] multi_query_results = await asyncio.gather(*tasks, return_exceptions=True) + from common_utils.async_utils import reraise_critical_errors + + reraise_critical_errors(multi_query_results) # Collect valid results valid_results = [] diff --git a/methods/evermemos/src/agentic_layer/search_mem_service.py b/methods/evermemos/src/agentic_layer/search_mem_service.py index f325657dc..f4186501d 100644 --- a/methods/evermemos/src/agentic_layer/search_mem_service.py +++ b/methods/evermemos/src/agentic_layer/search_mem_service.py @@ -91,7 +91,7 @@ from service.raw_message_service import RawMessageService # Memorize config (for skill_retire_confidence threshold) -from biz_layer.memorize_config import AGENT_DEFAULT_MEMORIZE_CONFIG +from biz_layer.memorize_config import DEFAULT_MEMORIZE_CONFIG # Constants from core.oxm.constants import MAGIC_ALL @@ -510,6 +510,11 @@ async def search_memories( results = await asyncio.gather(*search_tasks, return_exceptions=True) search_duration = time.perf_counter() - search_start + # Propagate critical system errors before processing results + from common_utils.async_utils import reraise_critical_errors + + reraise_critical_errors(results) + # Collect results from parallel searches has_error = False total_result_count = 0 @@ -1393,7 +1398,7 @@ async def _fetch_agent_skills_by_ids( id_filter = self._build_mongo_id_filter(skill_ids) if not id_filter: return {} - retire_confidence = AGENT_DEFAULT_MEMORIZE_CONFIG.skill_retire_confidence + retire_confidence = DEFAULT_MEMORIZE_CONFIG.skill_retire_confidence id_filter["confidence"] = {"$gte": retire_confidence} docs = await AgentSkillRecord.find_many(id_filter).to_list() return {str(d.id): d for d in docs} @@ -1442,7 +1447,7 @@ async def _search_agent_skills( agent_skill_mt = MemoryType.AGENT_SKILL.value with timed("agent_skill_search"): - retire_confidence = AGENT_DEFAULT_MEMORIZE_CONFIG.skill_retire_confidence + retire_confidence = DEFAULT_MEMORIZE_CONFIG.skill_retire_confidence if method == "keyword": stage_start = time.perf_counter() @@ -1575,12 +1580,74 @@ async def _search_agent_skills( query=query, filter_values=filter_values, top_k=top_k ) - if AGENT_DEFAULT_MEMORIZE_CONFIG.enable_skill_llm_verify: - if results and query: - results = await self._verify_skill_relevance(query, results) + if results and DEFAULT_MEMORIZE_CONFIG.enable_skill_llm_verify: + results = await self._verify_skill_relevance(query, results) return results + async def _verify_skill_relevance( + self, query: str, skills: List[SearchAgentSkillItem] + ) -> List[SearchAgentSkillItem]: + """Use LLM to post-verify whether retrieved skills are relevant to the query.""" + import json + from common_utils.json_utils import parse_json_response + from memory_layer.prompts import get_prompt_by + from memory_layer.llm.llm_provider import build_default_provider + + if not skills or not query: + return skills + + skills_for_prompt = [ + { + "index": i, + "name": skill.name or "", + "description": skill.description or "", + "content": skill.content or "", + } + for i, skill in enumerate(skills) + ] + + prompt_template = get_prompt_by("AGENT_SKILL_RELEVANCE_VERIFY_PROMPT") + prompt = prompt_template.format( + query=query, skills_json=json.dumps(skills_for_prompt, ensure_ascii=False) + ) + + try: + llm_provider = build_default_provider() + response_text = await llm_provider.generate( + prompt, temperature=0.0, response_format={"type": "json_object"} + ) + + result = parse_json_response(response_text) + score_map = { + item["index"]: item.get("score", 0.0) + for item in result.get("results", []) + } + + scored = [] + for i, skill in enumerate(skills): + relevance_score = score_map.get(i, 0.0) + if relevance_score >= 0.4: + skill.score = relevance_score + scored.append(skill) + + scored.sort(key=lambda s: s.score, reverse=True) + + logger.info( + "Skill relevance verification: %d/%d skills passed (threshold=0.4) for query: %s", + len(scored), + len(skills), + query[:60], + ) + + return scored + + except Exception as e: + logger.warning( + "Skill relevance verification failed, returning all results: %s", e + ) + return skills + # ------------------------------------------------------------------ # Agentic retrieval for agent memory types # ------------------------------------------------------------------ @@ -1628,7 +1695,6 @@ async def _search_agentic_agent_skills( ) -> List[SearchAgentSkillItem]: """Search agent skills using agentic retrieval (LLM-guided multi-round). - After retrieval, uses LLM to verify skill relevance to the query. """ retrieve_request = RetrieveMemRequest( query=query, @@ -1663,74 +1729,3 @@ async def _search_agentic_agent_skills( results.append(self._agent_skill_doc_to_item(doc, score=score)) return results - - async def _verify_skill_relevance( - self, query: str, skills: List[SearchAgentSkillItem] - ) -> List[SearchAgentSkillItem]: - """Use LLM to post-verify whether retrieved skills are relevant to the query. - - Args: - query: The user's search query - skills: List of SearchAgentSkillItem results - - Returns: - Filtered list containing only helpful skills - """ - import json - from common_utils.json_utils import parse_json_response - from memory_layer.prompts import get_prompt_by - from memory_layer.llm.llm_provider import build_default_provider - - if not skills or not query: - return skills - - skills_for_prompt = [ - { - "index": i, - "name": skill.name or "", - "description": skill.description or "", - "content": skill.content or "", - } - for i, skill in enumerate(skills) - ] - - prompt_template = get_prompt_by("AGENT_SKILL_RELEVANCE_VERIFY_PROMPT") - prompt = prompt_template.format( - query=query, skills_json=json.dumps(skills_for_prompt, ensure_ascii=False) - ) - - try: - llm_provider = build_default_provider() - response_text = await llm_provider.generate( - prompt, temperature=0.0, response_format={"type": "json_object"} - ) - - result = parse_json_response(response_text) - score_map = { - item["index"]: item.get("score", 0.0) - for item in result.get("results", []) - } - - scored = [] - for i, skill in enumerate(skills): - relevance_score = score_map.get(i, 0.0) - if relevance_score >= 0.4: - skill.score = relevance_score - scored.append(skill) - - scored.sort(key=lambda s: s.score, reverse=True) - - logger.info( - "Skill relevance verification: %d/%d skills passed (threshold=0.4) for query: %s", - len(scored), - len(skills), - query[:60], - ) - - return scored - - except Exception as e: - logger.warning( - "Skill relevance verification failed, returning all results: %s", e - ) - return skills diff --git a/methods/evermemos/src/api_specs/dtos/memory.py b/methods/evermemos/src/api_specs/dtos/memory.py index 06b0c5f36..2ee9d04f5 100644 --- a/methods/evermemos/src/api_specs/dtos/memory.py +++ b/methods/evermemos/src/api_specs/dtos/memory.py @@ -568,37 +568,6 @@ class FlushResponse(BaseApiResponse[FlushResult]): data: FlushResult = Field(default_factory=FlushResult, description="Flush result") -# ============================================================================= -# Flush Clustering DTOs (POST /api/v1/memories/agent/flush-clustering) -# ============================================================================= - - -class AgentFlushClusteringRequest(BaseModel): - """Request to force-drain pending memcells and run batch clustering.""" - - user_id: str = Field( - ..., - description="User ID (used to derive group_id for solo scene)", - ) - - -class AgentFlushClusteringResult(BaseModel): - """Result of flush-clustering operation.""" - - request_id: str = Field(default="", description="Request ID") - status: str = Field(default="", description="Flush clustering status") - message: str = Field(default="", description="Status message") - - -class AgentFlushClusteringResponse(BaseApiResponse[AgentFlushClusteringResult]): - """Flush-clustering endpoint response.""" - - data: AgentFlushClusteringResult = Field( - default_factory=AgentFlushClusteringResult, - description="Flush clustering result", - ) - - # ============================================================================= # Search/Retrieve DTOs (GET /api/v1/memories/search) # ============================================================================= diff --git a/methods/evermemos/src/biz_layer/mem_memorize.py b/methods/evermemos/src/biz_layer/mem_memorize.py index 78e6caf62..be1adcbc9 100644 --- a/methods/evermemos/src/biz_layer/mem_memorize.py +++ b/methods/evermemos/src/biz_layer/mem_memorize.py @@ -1,8 +1,11 @@ from dataclasses import dataclass +import random import time +import json import traceback from core.observation.stage_timer import timed, timed_parallel +from api_specs.memory_types import ScenarioType from agentic_layer.metrics.memorize_metrics import ( record_extraction_stage, record_memory_extracted, @@ -13,11 +16,11 @@ from api_specs.memory_types import ( MemoryType, MemCell, + BaseMemory, EpisodeMemory, RawDataType, Foresight, AgentCase, - ScenarioType, ) from api_specs.memory_types import AtomicFact, get_text_from_content_items from biz_layer.memorize_config import DEFAULT_MEMORIZE_CONFIG @@ -35,15 +38,23 @@ ConversationStatusRawRepository, ) from service.settings_service import SettingsService +from infra_layer.adapters.out.persistence.repository.memcell_raw_repository import ( + MemCellRawRepository, +) from infra_layer.adapters.out.persistence.repository.conversation_data_raw_repository import ( ConversationDataRepository, ) +from api_specs.memory_types import RawDataType from typing import List, Dict, Optional, Any +from dataclasses import dataclass +import uuid from datetime import datetime, timedelta +import os import asyncio from collections import defaultdict from common_utils.datetime_utils import get_now_with_timezone, to_iso_format from memory_layer.memcell_extractor.base_memcell_extractor import StatusResult +import traceback from core.observation.logger import get_logger from infra_layer.adapters.out.search.elasticsearch.converter.episodic_memory_converter import ( @@ -78,7 +89,6 @@ class MemoryDocPayload: from biz_layer.memorize_config import ( MemorizeConfig, DEFAULT_MEMORIZE_CONFIG, - AGENT_DEFAULT_MEMORIZE_CONFIG, ) @@ -96,6 +106,7 @@ def _is_agent_case_quality_sufficient( return True + async def _trigger_clustering( group_id: str, memcell: MemCell, @@ -104,415 +115,274 @@ async def _trigger_clustering( episode_text: Optional[str] = None, agent_case: Optional[AgentCase] = None, ) -> None: - """Trigger MemCell clustering. - - Accumulates memcells in a pending queue. Once the queue reaches - cluster_batch_size, drains it and runs clustering. When - cluster_batch_size == 1, this means every memcell is processed - immediately. Use flush_clustering() to drain on demand. + """Trigger MemCell clustering Args: group_id: Group ID memcell: The MemCell just saved - scene: Conversation scene ("solo" or "team") - config: Memory extraction configuration - episode_text: Episode text extracted from this MemCell - agent_case: Extracted AgentCase (if agent conversation) + scene: Conversation scene (used to determine Profile extraction strategy) + - "solo": 1 user + N agents scenario + - "team": multi-user + agents scenario + episode_text: Episode text extracted from this MemCell (used for clustering similarity) + agent_case: Extracted AgentCase (if agent conversation), used for skill extraction """ logger.info( - f"[Clustering] Start triggering clustering: group_id={group_id}, " - f"event_id={memcell.event_id}, scene={scene}" + f"[Clustering] Start triggering clustering: group_id={group_id}, event_id={memcell.event_id}, scene={scene}" ) - pending_entry = { - "event_id": str(memcell.event_id), - "episode": episode_text, - "timestamp": memcell.timestamp.timestamp() if memcell.timestamp else None, - "participants": memcell.participants or [], - "group_id": group_id, - "scene": scene, - } - if agent_case: - pending_entry["agent_case"] = { - "id": agent_case.id, - "task_intent": agent_case.task_intent, - "approach": agent_case.approach, - "key_insight": getattr(agent_case, "key_insight", None), - "quality_score": agent_case.quality_score, - } + try: + from memory_layer.cluster_manager import ( + ClusterManager, + ClusterManagerConfig, + MemSceneState, + ) + from infra_layer.adapters.out.persistence.repository.mem_scene_raw_repository import ( + MemSceneRawRepository, + ) + from core.di import get_bean_by_type - await _drain_and_cluster( - group_id=group_id, - config=config, - new_entry=pending_entry, - ) + logger.info(f"[Clustering] Retrieving MemSceneRawRepository...") + # Get MongoDB storage + cluster_storage = get_bean_by_type(MemSceneRawRepository) + logger.info( + f"[Clustering] MemSceneRawRepository retrieved successfully: {type(cluster_storage)}" + ) + # Create ClusterManager (pure computation component) + has_case = agent_case is not None + cluster_config = ClusterManagerConfig( + similarity_threshold=config.cluster_similarity_threshold, + max_time_gap_days=config.cluster_max_time_gap_days, + ) -async def _drain_and_cluster( - group_id: str, - config: MemorizeConfig, - new_entry: Optional[Dict] = None, - force_drain: bool = False, -) -> int: - """Core clustering pipeline shared by _trigger_clustering and flush_clustering. + # Build LLM provider and context fetcher for agent clustering + llm_provider = None + context_fetcher = None + if has_case: + from memory_layer.llm.llm_provider import build_default_provider + from infra_layer.adapters.out.persistence.repository.agent_case_raw_repository import ( + AgentCaseRawRepository, + ) - Acquires the distributed lock, optionally appends a new pending entry, - drains the queue when batch_size is reached (or force_drain=True), - and runs batch clustering. Skill extraction runs after the lock is released. + llm_provider = build_default_provider() + agent_case_repo = get_bean_by_type(AgentCaseRawRepository) + context_fetcher = agent_case_repo.fetch_task_intents_by_event_ids - Args: - group_id: Group ID - config: Memory extraction configuration - new_entry: Pending entry dict to append (None for flush) - force_drain: If True, drain regardless of batch_size (flush mode) + cluster_manager = ClusterManager( + config=cluster_config, + llm_provider=llm_provider, + context_fetcher=context_fetcher, + ) - Returns: - Number of pending memcells that were drained and clustered. - """ - try: - from memory_layer.cluster_manager import MemSceneState - from infra_layer.adapters.out.persistence.repository.mem_scene_raw_repository import ( - MemSceneRawRepository, + # Clustering text: task_intent for agent case, episode for normal + clustering_text = ( + agent_case.task_intent if has_case and agent_case.task_intent + else episode_text + ) + logger.info( + f"[Clustering] ClusterManager created (has_case={has_case})" ) - from core.lock.redis_distributed_lock import distributed_lock - cluster_storage = get_bean_by_type(MemSceneRawRepository) + # Convert MemCell to dictionary format required for clustering + memcell_dict = { + "event_id": str(memcell.event_id), + "episode": episode_text, + "clustering_text": clustering_text, + "timestamp": memcell.timestamp.timestamp() if memcell.timestamp else None, + "participants": memcell.participants or [], + "group_id": group_id, + } + + logger.debug( + f"[Clustering] Start clustering execution: {memcell_dict['event_id']}" + ) + + from core.lock.redis_distributed_lock import distributed_lock - drained_memcells = [] - cluster_ids = None - result = 0 + # ===== Phase 1 + 2: Clustering + Profile extraction ===== + # Lock: trigger_clustering:{group_id} + # + # Protected shared state (read-modify-write): + # - mem_scene_state: loaded from DB, mutated by cluster_memcell(), saved back. + # Concurrent writes without this lock would cause lost updates. + # - Profile extraction (Phase 2) reads mem_scene_state snapshot taken in Phase 1 + # and reads/writes user profiles in DB. + # + # Released before Phase 3 so the next request's Phase 1+2 is not blocked + # by a slow LLM skill-extraction call. + # + # Data flow out of this lock: + # - cluster_id: determined by Phase 1, used as key for Phase 3 lock. + # - agent_case: passed through from caller, not modified here. + # Both are safe to use after lock release because Phase 3 re-reads + # its own shared state (existing_skills) from DB inside its own lock. + cluster_id = None lock_resource = f"trigger_clustering:{group_id}" async with distributed_lock( - resource=lock_resource, timeout=1200.0, blocking_timeout=3600.0 + resource=lock_resource, + timeout=config.clustering_lock_timeout, + blocking_timeout=config.clustering_lock_blocking_timeout, ) as acquired: if not acquired: logger.error( - f"[Clustering] Failed to acquire lock for group {group_id}" + f"[Clustering] Failed to acquire lock for group {group_id}, " + f"skipping memcell {memcell.event_id}" ) - return 0 + return + # ===== Phase 1: Clustering ===== state_dict = await cluster_storage.load_mem_scene(group_id) mem_scene_state = ( MemSceneState.from_dict(state_dict) if state_dict else MemSceneState() ) + logger.info( + f"[Clustering] Loaded clustering state: {len(mem_scene_state.event_ids)} clustered events" + ) - # Append new entry if provided (trigger mode) - if new_entry: - mem_scene_state.pending_clustering.append(new_entry) + cluster_id, mem_scene_state = await cluster_manager.cluster_memcell( + memcell_dict, mem_scene_state, has_case=has_case + ) - pending_count = len(mem_scene_state.pending_clustering) + await cluster_storage.save_mem_scene(group_id, mem_scene_state.to_dict()) + logger.info(f"[Clustering] Clustering state saved") - # Check whether to drain - should_drain = force_drain or pending_count >= config.cluster_batch_size - if not should_drain: - await cluster_storage.save_mem_scene( - group_id, mem_scene_state.to_dict() + if cluster_id: + logger.debug( + f"[Clustering] ✅ MemCell {memcell.event_id} -> Cluster {cluster_id} (group: {group_id})" ) - logger.info( - f"[Clustering] Accumulated {pending_count}/{config.cluster_batch_size} " - f"pending memcells for group {group_id}" + else: + logger.warning( + f"[Clustering] ⚠️ MemCell {memcell.event_id} clustering returned None (group: {group_id})" ) - return 0 - if pending_count == 0: - logger.info( - f"[Clustering] No pending memcells for group {group_id}" + # ===== Phase 2: Profile extraction (with interval-based throttling) ===== + if cluster_id and not config.skip_profile_extraction: + total_memcell_count = sum(mem_scene_state.cluster_counts.values()) + should_extract = ( + config.profile_extraction_interval <= 1 + or total_memcell_count % config.profile_extraction_interval == 0 ) - return 0 - # Drain pending queue - drained_memcells = list(mem_scene_state.pending_clustering) - mem_scene_state.pending_clustering.clear() + if should_extract: + # --- Group-level cluster selection (Layer 1 of 2) --- + # Profile extraction uses a two-layer filtering strategy: + # + # Layer 1 (here): Select which clusters to fetch from DB. + # Uses min(last_updated_ts) across all users in the group as baseline. + # This is intentionally broad — it covers the "slowest" user so no + # cluster is missed for any user. Fetches ALL events from selected + # clusters in one DB query. + # + # Layer 2 (manager.py, per-user loop): Filters fetched original_data per + # user based on each user's own last_updated timestamp, so each user's + # LLM prompt only contains data they haven't seen yet. + # (Note: the code calls them "episodes" but the actual content is + # memcell original_data — raw chat messages, not episode summaries.) + # + # For new groups with no profiles, defaults to current memcell timestamp + # to avoid selecting all historical clusters (cold-start protection). + from infra_layer.adapters.out.persistence.repository.user_profile_raw_repository import ( + UserProfileRawRepository, + ) + from core.di import get_bean_by_type - logger.info( - f"[Clustering] Draining {len(drained_memcells)} pending memcells " - f"(group={group_id}, existing_events={len(mem_scene_state.event_ids)})" - ) + profile_repo = get_bean_by_type(UserProfileRawRepository) + existing_profiles = await profile_repo.get_all_by_group(group_id) - cluster_ids = await _run_batch_clustering( - group_id=group_id, - drained_memcells=drained_memcells, - mem_scene_state=mem_scene_state, - cluster_storage=cluster_storage, - config=config, - ) + current_memcell_ts = memcell.timestamp.timestamp() + + if existing_profiles: + timestamps = [ + p.last_updated_ts + for p in existing_profiles + if p.last_updated_ts is not None + ] + last_profile_ts = ( + min(timestamps) if timestamps else current_memcell_ts + ) + else: + last_profile_ts = current_memcell_ts + + target_cluster_ids = [ + cid + for cid, ts in mem_scene_state.cluster_last_ts.items() + if ts is not None and ts > last_profile_ts + ] + if cluster_id not in target_cluster_ids: + target_cluster_ids.append(cluster_id) + + logger.info( + f"[Profile] Timestamp-based selection: last_profile_ts={last_profile_ts}, " + f"target_clusters={target_cluster_ids}" + ) - # Profile extraction inside the lock to prevent concurrent - # reads/writes to profile data for the same group. - if cluster_ids: - try: - await _run_profile_extraction_for_batch( + await _trigger_profile_extraction( group_id=group_id, - drained_memcells=drained_memcells, - cluster_ids=cluster_ids, + cluster_ids=target_cluster_ids, mem_scene_state=mem_scene_state, + memcell=memcell, + scene=scene, config=config, - force_drain=force_drain, ) - except Exception as e: - logger.error( - f"[Clustering] Profile extraction failed after clustering " - f"(group={group_id}): {e}", - exc_info=True, + else: + logger.debug( + f"[Profile] Skipping extraction: total_memcells={total_memcell_count}, " + f"interval={config.profile_extraction_interval}" ) - result = pending_count - - # Lock released — run skill extraction outside the lock - if drained_memcells and cluster_ids is not None: - try: - await _run_skill_extraction_for_batch( + # ===== Phase 3: Agent skill extraction ===== + # Lock: trigger_agent_skill:{group_id}:{cluster_id} + # + # Separate lock so Phase 1+2 of the next request is not blocked by this + # slow LLM call. + # + # Data dependencies (all safe after Lock 1 release): + # - cluster_id: immutable identifier, determined in Phase 1. + # - agent_case: this request's own extraction result, not shared state. + # - existing_skills: re-read from DB inside _trigger_agent_skill_extraction, + # so it always reflects the latest state (including writes by prior requests). + # + # IMPORTANT for future maintainers: + # This function does NOT read memcells or agent_cases from DB. It only uses + # the passed-in agent_case (current request) + existing_skills (from DB). + # If you add logic that reads cluster memcells from DB here, you must + # consider that new memcells may have been added between Lock 1 release + # and Lock 2 acquisition. + if cluster_id and agent_case and _is_agent_case_quality_sufficient(agent_case, config): + skill_lock_resource = f"trigger_agent_skill:{group_id}:{cluster_id}" + async with distributed_lock( + resource=skill_lock_resource, + timeout=config.skill_extraction_lock_timeout, + blocking_timeout=config.skill_extraction_lock_blocking_timeout, + ) as skill_acquired: + if not skill_acquired: + logger.error( + f"[AgentSkill] Failed to acquire lock for group {group_id}, " + f"cluster {cluster_id}, skipping memcell {memcell.event_id}" + ) + return + await _trigger_agent_skill_extraction( group_id=group_id, - drained_memcells=drained_memcells, - cluster_ids=cluster_ids, + cluster_id=cluster_id, + memcell=memcell, + agent_case=agent_case, config=config, ) - except Exception as e: - logger.error( - f"[Clustering] Skill extraction failed after clustering " - f"(group={group_id}): {e}", - exc_info=True, - ) - - return result except Exception as e: logger.error( - f"[Clustering] Clustering failed: {e}", exc_info=True + f"[Clustering] ❌ Triggering clustering failed: {e}", exc_info=True ) raise -async def _run_batch_clustering( - group_id: str, - drained_memcells: List[Dict], - mem_scene_state, - cluster_storage, - config: MemorizeConfig, -) -> List[str]: - """Run clustering for drained memcells and return cluster_ids. - - Unified path for both single-memcell and accumulated batch clustering. - Uses LLM-based assignment when more than one item; - otherwise uses incremental embedding similarity for single items. - - Called from _drain_and_cluster. Caller must hold the distributed lock - and have already loaded mem_scene_state. - - Args: - group_id: Group ID - drained_memcells: List of pending memcell dicts to cluster - mem_scene_state: Loaded MemSceneState (mutated in-place) - cluster_storage: MemSceneRawRepository instance - config: Memory extraction configuration - - Returns: - List of cluster_ids corresponding to each entry in drained_memcells. - """ - from memory_layer.cluster_manager import ClusterManager, ClusterManagerConfig - - cluster_config = ClusterManagerConfig( - similarity_threshold=config.cluster_similarity_threshold, - max_time_gap_days=config.cluster_max_time_gap_days, - ) - cluster_manager = ClusterManager(config=cluster_config) - - if len(drained_memcells) > 1: - # LLM-based clustering - from memory_layer.llm.llm_provider import build_default_provider - from infra_layer.adapters.out.persistence.document.memory.episodic_memory import ( - EpisodicMemory, - ) - - llm_custom_setting = await _load_llm_custom_setting() - if llm_custom_setting: - from memory_layer.llm.llm_provider import LLMProvider - llm_provider = LLMProvider(**llm_custom_setting) - else: - llm_provider = build_default_provider() - - # Fetch recent episodes per existing cluster from DB - existing_cluster_episodes: Dict[str, List[str]] = {} - cluster_to_events: Dict[str, List[tuple]] = defaultdict(list) - for eid, cid in mem_scene_state.eventid_to_cluster.items(): - ts = 0.0 - idx = None - try: - idx = mem_scene_state.event_ids.index(eid) - except ValueError: - pass - if idx is not None and idx < len(mem_scene_state.timestamps): - ts = mem_scene_state.timestamps[idx] - cluster_to_events[cid].append((ts, eid)) - - # Pick top 3 most recent event_ids per cluster - top_event_ids = [] - cluster_event_map: Dict[str, List[str]] = {} - for cid, items in cluster_to_events.items(): - items.sort(reverse=True) - recent_eids = [eid for _, eid in items[:3]] - cluster_event_map[cid] = recent_eids - top_event_ids.extend(recent_eids) - - if top_event_ids: - try: - episode_docs = await EpisodicMemory.find( - {"parent_id": {"$in": top_event_ids}} - ).to_list() - parent_to_episode: Dict[str, str] = {} - for doc in episode_docs: - pid = doc.parent_id - if pid and hasattr(doc, "episode") and doc.episode: - parent_to_episode[pid] = doc.episode[:500] - for cid, eids in cluster_event_map.items(): - episodes = [ - parent_to_episode[eid] - for eid in eids - if eid in parent_to_episode - ] - if episodes: - existing_cluster_episodes[cid] = episodes - except Exception as e: - logger.warning(f"[Clustering] Failed to fetch cluster episodes: {e}") - - clustering_extra_body = None - if config.skip_clustering_reasoning: - clustering_extra_body = {"chat_template_kwargs": {"enable_thinking": False}} - - cluster_ids, mem_scene_state = await cluster_manager.cluster_memcells_batch_llm( - drained_memcells, mem_scene_state, - llm_provider=llm_provider, - existing_cluster_episodes=existing_cluster_episodes, - extra_body=clustering_extra_body, - ) - else: - # Incremental embedding-based clustering (single memcell) - cluster_id, mem_scene_state = await cluster_manager.cluster_memcell( - drained_memcells[0], mem_scene_state - ) - cluster_ids = [cluster_id] - - await cluster_storage.save_mem_scene(group_id, mem_scene_state.to_dict()) - - # Log results - valid_ids = [cid for cid in cluster_ids if cid is not None] - unique_clusters = set(valid_ids) - logger.info( - f"[Clustering] Batch complete: {len(drained_memcells)} memcells -> " - f"{len(unique_clusters)} unique clusters (group={group_id})" - ) - - return cluster_ids - - -async def _run_profile_extraction_for_batch( - group_id: str, - drained_memcells: List[Dict], - cluster_ids: List[str], - mem_scene_state, - config: MemorizeConfig, - force_drain: bool = False, -) -> None: - """Run profile extraction for the batch of drained memcells. - - Adapts the release/20260306 per-memcell profile extraction logic to work - with the batch clustering pipeline. Runs outside the distributed lock. - - Args: - group_id: Group ID - drained_memcells: List of pending memcell dicts - cluster_ids: Cluster ID for each entry in drained_memcells - mem_scene_state: MemSceneState after clustering (read-only here) - config: Memory extraction configuration - """ - if config.skip_profile_extraction: - return - - total_memcell_count = sum(mem_scene_state.cluster_counts.values()) - is_batch_mode = config.cluster_batch_size > 1 - if config.profile_extraction_interval <= 1: - should_extract = True - elif is_batch_mode or force_drain: - should_extract = True - else: - should_extract = total_memcell_count % config.profile_extraction_interval == 0 - if not should_extract: - logger.debug( - f"[Profile] Skipping extraction: total_memcells={total_memcell_count}, " - f"interval={config.profile_extraction_interval}" - ) - return - - # Determine scene from the last drained entry - scene = None - for entry in reversed(drained_memcells): - if entry.get("scene"): - scene = entry["scene"] - break - - # Determine the latest memcell timestamp (used as last_updated_ts on save) - latest_ts = 0.0 - for entry in drained_memcells: - ts = entry.get("timestamp") - if ts and ts > latest_ts: - latest_ts = ts - - from infra_layer.adapters.out.persistence.repository.user_profile_raw_repository import ( - UserProfileRawRepository, - ) - - profile_repo = get_bean_by_type(UserProfileRawRepository) - existing_profiles = await profile_repo.get_all_by_group(group_id) - - current_memcell_ts = latest_ts or time.time() - if existing_profiles: - timestamps = [ - p.last_updated_ts - for p in existing_profiles - if p.last_updated_ts is not None - ] - last_profile_ts = min(timestamps) if timestamps else current_memcell_ts - else: - last_profile_ts = current_memcell_ts - - # Select clusters that have newer data than last profile extraction - target_cluster_ids = [ - cid - for cid, ts in mem_scene_state.cluster_last_ts.items() - if ts is not None and ts > last_profile_ts - ] - # Ensure all clusters from this batch are included - for cid in cluster_ids: - if cid and cid not in target_cluster_ids: - target_cluster_ids.append(cid) - - if not target_cluster_ids: - return - - logger.info( - f"[Profile] Timestamp-based selection: last_profile_ts={last_profile_ts}, " - f"target_clusters={target_cluster_ids}" - ) - - await _trigger_profile_extraction( - group_id=group_id, - cluster_ids=target_cluster_ids, - mem_scene_state=mem_scene_state, - latest_memcell_ts=latest_ts, - scene=scene, - config=config, - new_user_max_context=max(config.profile_extraction_interval, len(drained_memcells)), - ) - - async def _trigger_profile_extraction( group_id: str, cluster_ids: List[str], mem_scene_state, # MemSceneState - latest_memcell_ts: float, + memcell: MemCell, scene: Optional[str] = None, config: MemorizeConfig = DEFAULT_MEMORIZE_CONFIG, - new_user_max_context: int = 0, ) -> None: """Trigger Profile extraction for one or more clusters. @@ -520,21 +390,18 @@ async def _trigger_profile_extraction( group_id: Group ID cluster_ids: Cluster IDs to extract profiles from mem_scene_state: Current mem scene state - latest_memcell_ts: Timestamp of the latest memcell in the batch - new_user_max_context: Max cluster context for new users (0 = no limit) + memcell: The MemCell currently being processed (appended as new_memcell) scene: Conversation scene config: Memory extraction configuration """ - user_id_list: List[str] = [] try: from memory_layer.profile_manager import ProfileManager, ProfileManagerConfig from infra_layer.adapters.out.persistence.repository.user_profile_raw_repository import ( UserProfileRawRepository, ) - from infra_layer.adapters.out.persistence.repository.memcell_raw_repository import ( - MemCellRawRepository, - ) from memory_layer.llm.llm_provider import build_default_provider + from core.di import get_bean_by_type + import os total_memcell_count = sum( mem_scene_state.cluster_counts.get(cid, 0) for cid in cluster_ids @@ -571,24 +438,25 @@ async def _trigger_profile_extraction( ) # ===== Fetch memcells from all target clusters ===== + current_event_id = str(memcell.event_id) if memcell.event_id else None target_cluster_set = set(cluster_ids) target_event_ids = set() if mem_scene_state and hasattr(mem_scene_state, 'eventid_to_cluster'): for event_id, cid in mem_scene_state.eventid_to_cluster.items(): - if cid in target_cluster_set: + if cid in target_cluster_set and event_id != current_event_id: target_event_ids.add(event_id) all_memcells = [] if target_event_ids: try: fetched = await memcell_repo.get_by_event_ids(list(target_event_ids)) - all_memcells = sorted( - fetched.values(), - key=lambda mc: mc.timestamp or datetime.min, - ) + all_memcells = list(fetched.values()) except Exception as e: logger.warning(f"[Profile] Failed to fetch cluster memcells: {e}") + # Append current memcell as the last one (new_memcell) + all_memcells.append(memcell) + # Merge participants from all memcells (deduplicated) all_participants: set = set() for mc in all_memcells: @@ -602,12 +470,12 @@ async def _trigger_profile_extraction( logger.info( f"[Profile] Context: clusters={len(cluster_ids)}, " - f"memcells={len(all_memcells)}, users={len(user_id_list)}" + f"memcells={len(all_memcells) - 1}, new=1, users={len(user_id_list)}" ) # ===== Extract and save profiles ===== - # Caller (_trigger_clustering) already holds the distributed lock for this group, - # so no additional lock is needed here. + # Caller (_trigger_clustering) holds trigger_clustering:{group_id} while calling + # this function, so concurrent profile writes for the same group are serialized. # Load old profiles old_profiles_dict = await profile_repo.get_all_profiles(group_id=group_id) @@ -621,7 +489,9 @@ async def _trigger_profile_extraction( logger.info(f"[Profile] Profile for {uid}: keys={keys[:8]}") # Extract profiles - profile_scene = ScenarioType.TEAM if scene == ScenarioType.TEAM.value else ScenarioType.SOLO + profile_scene = ( + ScenarioType.TEAM if scene == ScenarioType.TEAM.value else ScenarioType.SOLO + ) new_profiles = await profile_manager.extract_profiles( memcells=all_memcells, old_profiles=old_profiles, @@ -629,13 +499,12 @@ async def _trigger_profile_extraction( group_id=group_id, max_items=config.profile_max_items, scene=profile_scene, - new_user_max_context=new_user_max_context, ) # Save profiles - memcell_ts = latest_memcell_ts if latest_memcell_ts else 0.0 for profile in new_profiles: try: + memcell_ts = memcell.timestamp.timestamp() if memcell.timestamp else 0.0 user_id = profile.user_id profile_data = profile.to_dict() metadata = { @@ -663,10 +532,14 @@ async def _trigger_profile_extraction( # of the same clusters. The data is "skipped" — acceptable tradeoff vs. # getting stuck in a loop retrying the same failing extraction. try: - memcell_ts = latest_memcell_ts if latest_memcell_ts else 0.0 + memcell_ts = memcell.timestamp.timestamp() if memcell.timestamp else 0.0 for uid in user_id_list: existing = await profile_repo.get_by_user_and_group(uid, group_id) - profile_data = existing.profile_data if existing else {"explicit_info": [], "implicit_traits": []} + profile_data = ( + existing.profile_data + if existing + else {"explicit_info": [], "implicit_traits": []} + ) await profile_repo.upsert( user_id=uid, group_id=group_id, @@ -678,158 +551,26 @@ async def _trigger_profile_extraction( f"[Profile] Advanced last_updated_ts to {memcell_ts} for {len(user_id_list)} users despite failure" ) except Exception as ts_err: - logger.warning(f"[Profile] Failed to advance last_updated_ts on failure: {ts_err}") - - -async def _run_skill_extraction_for_batch( - group_id: str, - drained_memcells: List[Dict], - cluster_ids: List[str], - config: MemorizeConfig, -) -> None: - """Run agent skill extraction for clustered drained memcells. - - Called after the clustering lock is released to avoid holding - the lock during potentially slow LLM-based skill extraction. - Each cluster does its own Milvus/ES writes (without flush), - then a single Milvus flush is issued at the end to avoid rate limiting. - - Args: - group_id: Group ID - drained_memcells: List of pending memcell dicts (same as passed to clustering) - cluster_ids: Cluster ID for each entry in drained_memcells - config: Memory extraction configuration - """ - if config.skip_skill_extraction: - return - - agent_cases = _build_agent_cases_from_batch(drained_memcells) or None - if not agent_cases: - return - - # Group cases by cluster_id - cluster_cases: Dict[str, List[AgentCase]] = defaultdict(list) - for i, entry in enumerate(drained_memcells): - eid = entry.get("event_id") - if not eid or eid not in agent_cases: - continue - cid = cluster_ids[i] if i < len(cluster_ids) else None - if not cid: - continue - agent_case = agent_cases[eid] - if not _is_agent_case_quality_sufficient(agent_case, config): - continue - cluster_cases[cid].append(agent_case) - - if not cluster_cases: - return - - # Resolve user_id from the first agent_case's participants - first_case = next(iter(agent_cases.values())) - user_id = first_case.user_id or None - - # Run skill extraction for each cluster in parallel (Milvus writes without flush) - has_milvus_changes = await asyncio.gather( - *( - _trigger_agent_skill_extraction( - group_id=group_id, - cluster_id=cid, - user_id=user_id, - agent_cases=cases, - config=config, - ) - for cid, cases in cluster_cases.items() - ) - ) - - # Single Milvus flush after all clusters are done - if any(has_milvus_changes): - try: - from infra_layer.adapters.out.search.repository.agent_skill_milvus_repository import ( - AgentSkillMilvusRepository, - ) - agent_skill_milvus_repo = get_bean_by_type(AgentSkillMilvusRepository) - await agent_skill_milvus_repo.flush() - except Exception as milvus_exc: logger.warning( - f"[AgentSkill] Milvus flush failed (group={group_id}): {milvus_exc}" + f"[Profile] Failed to advance last_updated_ts on failure: {ts_err}" ) -async def flush_clustering( - user_id: str, - config: Optional[MemorizeConfig] = None, -) -> int: - """Public entry point: force-drain all pending memcells and run batch clustering. - - Called by the flush-clustering HTTP endpoint. Reuses the same - _drain_and_cluster path as _trigger_clustering with force_drain=True. - - Args: - user_id: User ID (used to derive group_id) - config: Optional config override. Defaults to AGENT_DEFAULT_MEMORIZE_CONFIG. - - Returns: - Number of pending memcells that were drained and clustered. - """ - if config is None: - config = AGENT_DEFAULT_MEMORIZE_CONFIG - - from api_specs.id_generator import generate_single_user_group_id - group_id = generate_single_user_group_id(user_id) - - return await _drain_and_cluster( - group_id=group_id, - config=config, - force_drain=True, - ) - - -def _build_agent_cases_from_batch(drained_memcells: List[Dict]) -> Dict[str, AgentCase]: - """Reconstruct AgentCase objects from serialized agent_case dicts in pending entries.""" - agent_cases: Dict[str, AgentCase] = {} - for entry in drained_memcells: - eid = entry.get("event_id") - ac_dict = entry.get("agent_case") - if not eid or not ac_dict: - continue - ts = entry.get("timestamp") - participants = entry.get("participants", []) - agent_cases[eid] = AgentCase( - id=ac_dict.get("id"), - memory_type=MemoryType.AGENT_CASE, - user_id=participants[0] if participants else "", - timestamp=datetime.fromtimestamp(ts) if ts else datetime.now(), - task_intent=ac_dict.get("task_intent"), - approach=ac_dict.get("approach"), - key_insight=ac_dict.get("key_insight"), - quality_score=ac_dict.get("quality_score"), - ) - return agent_cases - - async def _trigger_agent_skill_extraction( group_id: str, cluster_id: str, - user_id: Optional[str], - agent_cases: List[AgentCase], - config: MemorizeConfig = AGENT_DEFAULT_MEMORIZE_CONFIG, -) -> bool: + memcell: MemCell, + agent_case: AgentCase, + config: MemorizeConfig = DEFAULT_MEMORIZE_CONFIG, +) -> None: """Trigger incremental AgentSkill extraction for a MemScene cluster. - Performs DB read-modify-write under a per-cluster lock, then syncs - to Milvus/ES without flushing. Caller is responsible for a single - Milvus flush after all clusters complete. - Args: group_id: Group ID cluster_id: The cluster (MemScene) to extract skills for - user_id: User ID (agent owner) - agent_cases: List of AgentCase BOs to integrate into this cluster + memcell: The MemCell currently being processed (for user_id and event_id) + agent_case: The extracted AgentCase BO config: Memory extraction configuration - - Returns: - True if any Milvus writes were made (caller should flush). """ try: from infra_layer.adapters.out.persistence.repository.agent_skill_raw_repository import ( @@ -851,113 +592,120 @@ async def _trigger_agent_skill_extraction( from infra_layer.adapters.out.search.repository.agent_skill_es_repository import ( AgentSkillEsRepository, ) - from core.lock.redis_distributed_lock import distributed_lock - # Per-cluster lock to prevent concurrent skill extraction on the same cluster - # when multiple clustering batches finish close together. - lock_resource = f"skill_extraction:{group_id}:{cluster_id}" - has_milvus_changes = False + # Caller (_trigger_clustering) acquires trigger_agent_skill:{group_id}:{cluster_id} + # before calling this function, so concurrent skill writes for the same cluster are + # serialized while different clusters within the same group can run in parallel. + # + # Concurrency safety of data used in this function: + # - existing_skills: read from DB below (inside the caller's lock), always fresh. + # - agent_case: passed in from the current request, not shared with other requests. + # - memcell: only used to extract user_id, no shared-state concern. + # - extract_and_save() does NOT read memcells or agent_cases from DB. + # It only merges new_case_records (passed-in) with existing_skill_records (from DB). + # If future changes add DB reads of memcells/cases here, re-evaluate the lock + # boundary — the gap between Lock 1 release and Lock 2 acquisition means + # new memcells may have been clustered in between. - async with distributed_lock( - resource=lock_resource, timeout=1200.0, blocking_timeout=1200.0 - ) as acquired: - if not acquired: - logger.warning( - f"[AgentSkill] Failed to acquire lock for cluster={cluster_id}, " - f"group={group_id}, skipping extraction" - ) - return False + # Fetch existing skills for incremental merging + skill_repo = get_bean_by_type(AgentSkillRawRepository) + existing_skills = await skill_repo.get_by_cluster_id( + cluster_id, group_id=group_id, min_confidence=config.skill_retire_confidence + ) - skill_repo = get_bean_by_type(AgentSkillRawRepository) - llm_provider = build_default_provider() - extractor = AgentSkillExtractor( - llm_provider=llm_provider, - maturity_threshold=config.skill_maturity_threshold, - retire_confidence=config.skill_retire_confidence, - skip_maturity_scoring=config.skip_skill_maturity_scoring, + logger.info( + f"[AgentSkill] Incremental extraction: cluster={cluster_id}, " + f"new_experience=1, existing_skills={len(existing_skills)}" + ) + + # Resolve user_id from the memcell's original conversation data + user_id = _extract_user_id_from_memcell(memcell) + + # Run incremental skill extraction + llm_provider = build_default_provider() + extractor = AgentSkillExtractor( + llm_provider=llm_provider, + maturity_threshold=config.skill_maturity_threshold, + retire_confidence=config.skill_retire_confidence, + skip_maturity_scoring=config.skip_skill_maturity_scoring, + ) + extraction_result = await extractor.extract_and_save( + cluster_id=cluster_id, + group_id=group_id, + new_case_records=[agent_case], + existing_skill_records=existing_skills, + skill_repo=skill_repo, + user_id=user_id, + ) + + if extraction_result.deleted_ids: + logger.info( + f"[AgentSkill] Retired skills for cluster={cluster_id}: " + f"ids={extraction_result.deleted_ids}" ) + logger.info( + f"[AgentSkill] Extraction result for cluster={cluster_id}: " + f"added={len(extraction_result.added_records)}, " + f"updated={len(extraction_result.updated_records)}, " + f"retired={len(extraction_result.deleted_ids)}" + ) - # Process cases one by one within this cluster - for case_idx, agent_case in enumerate(agent_cases): - # Reload existing skills each round (previous round may have changed them) - existing_skills = await skill_repo.get_by_cluster_id( - cluster_id, group_id=group_id, min_confidence=config.skill_retire_confidence - ) + # Records that need insert into search engines (added + updated) + upsert_records = ( + extraction_result.added_records + extraction_result.updated_records + ) + # IDs of updated records that need their old entry removed first + updated_ids = [str(r.id) for r in extraction_result.updated_records] + # IDs to remove from search engines (deleted + updated-old-entries) + remove_ids = extraction_result.deleted_ids + updated_ids + if upsert_records or remove_ids: + # Milvus sync: delete stale entries -> insert new/updated + try: + agent_skill_milvus_repo = get_bean_by_type(AgentSkillMilvusRepository) + for old_id in remove_ids: + await agent_skill_milvus_repo.delete_by_id(old_id) + inserted_count = 0 + for record in upsert_records: + milvus_entity = AgentSkillMilvusConverter.from_mongo(record) + if milvus_entity.get("vector"): + await agent_skill_milvus_repo.insert(milvus_entity, flush=False) + inserted_count += 1 + else: + logger.warning( + f"[AgentSkill] Milvus skip (no vector): record={record.id}" + ) logger.info( - f"[AgentSkill] Incremental extraction: cluster={cluster_id}, " - f"case {case_idx + 1}/{len(agent_cases)}, existing_skills={len(existing_skills)}" + f"[AgentSkill] Milvus synced for cluster={cluster_id}: " + f"inserted={inserted_count}, removed={len(remove_ids)}" ) - - extraction_result = await extractor.extract_and_save( - cluster_id=cluster_id, - group_id=group_id, - new_case_records=[agent_case], - existing_skill_records=existing_skills, - skill_repo=skill_repo, - user_id=user_id, + except Exception as milvus_exc: + logger.warning( + f"[AgentSkill] Milvus sync failed for cluster={cluster_id}: {milvus_exc}" ) - if extraction_result.deleted_ids: - logger.info( - f"[AgentSkill] Retired skills for cluster={cluster_id}: " - f"ids={extraction_result.deleted_ids}" - ) + # ES sync: delete stale entries -> insert new/updated + try: + agent_skill_es_repo = get_bean_by_type(AgentSkillEsRepository) + for old_id in remove_ids: + await agent_skill_es_repo.delete_by_id(old_id) + for record in upsert_records: + es_doc = AgentSkillConverter.from_mongo(record) + await agent_skill_es_repo.create(es_doc) logger.info( - f"[AgentSkill] Extraction result for cluster={cluster_id} " - f"case {case_idx + 1}/{len(agent_cases)}: " - f"added={len(extraction_result.added_records)}, " - f"updated={len(extraction_result.updated_records)}, " - f"retired={len(extraction_result.deleted_ids)}" + f"[AgentSkill] ES synced for cluster={cluster_id}: " + f"inserted={len(upsert_records)}, removed={len(remove_ids)}" ) - - # Sync to search engines (without Milvus flush) - upsert_records = ( - extraction_result.added_records + extraction_result.updated_records + except Exception as es_exc: + logger.warning( + f"[AgentSkill] ES sync failed for cluster={cluster_id}: {es_exc}" ) - updated_ids = [str(r.id) for r in extraction_result.updated_records] - remove_ids = extraction_result.deleted_ids + updated_ids - - if upsert_records or remove_ids: - try: - agent_skill_milvus_repo = get_bean_by_type(AgentSkillMilvusRepository) - for old_id in remove_ids: - await agent_skill_milvus_repo.delete_by_id(old_id) - has_milvus_changes = True - for record in upsert_records: - milvus_entity = AgentSkillMilvusConverter.from_mongo(record) - if milvus_entity.get("vector"): - await agent_skill_milvus_repo.insert(milvus_entity, flush=False) - has_milvus_changes = True - else: - logger.warning( - f"[AgentSkill] Milvus skip (no vector): record={record.id}" - ) - except Exception as milvus_exc: - logger.warning( - f"[AgentSkill] Milvus sync failed for cluster={cluster_id}: {milvus_exc}" - ) - - try: - agent_skill_es_repo = get_bean_by_type(AgentSkillEsRepository) - for old_id in remove_ids: - await agent_skill_es_repo.delete_by_id(old_id) - for record in upsert_records: - es_doc = AgentSkillConverter.from_mongo(record) - await agent_skill_es_repo.create(es_doc) - except Exception as es_exc: - logger.warning( - f"[AgentSkill] ES sync failed for cluster={cluster_id}: {es_exc}" - ) - - return has_milvus_changes except Exception as e: logger.error( f"[AgentSkill] Skill extraction failed for cluster={cluster_id}: {e}", exc_info=True, ) - return False from biz_layer.mem_db_operations import ( @@ -965,10 +713,12 @@ async def _trigger_agent_skill_extraction( _convert_foresight_to_doc, _convert_atomic_fact_to_docs, _convert_agent_case_to_doc, + _extract_user_id_from_memcell, _save_memcell_to_database, _update_status_for_continuing_conversation, _update_status_after_memcell_extraction, ) +from typing import Tuple def if_memorize(memcell: MemCell) -> bool: @@ -1124,7 +874,7 @@ async def _timed_extract_agent_case(): count=1, ) - # 3. Fire-and-forget clustering + skill extraction (no data dependency on step 4) + # 3. Fire-and-forget clustering + profile extraction (no data dependency on step 4) async def _clustering_with_metrics(): cluster_start = time.perf_counter() try: @@ -1154,12 +904,7 @@ async def _clustering_with_metrics(): ) # Fire-and-forget: extract and save foresight/atomic_fact in background. # Solo scenes only; episode_saved confirms parent_doc is available for linking. - # Skip for agent conversations when skip_foresight_and_eventlog is enabled. - skip_foresight = ( - is_agent_conversation - and AGENT_DEFAULT_MEMORIZE_CONFIG.skip_foresight_and_eventlog - ) - if state.is_solo_scene and state.episode_saved and not skip_foresight: + if state.is_solo_scene and state.episode_saved and not DEFAULT_MEMORIZE_CONFIG.skip_foresight_and_eventlog: asyncio.create_task( _foresight_and_atomic_facts_with_metrics(state, memory_manager) ) @@ -1203,6 +948,9 @@ async def _extract_episodes(state: ExtractionState, memory_manager: MemoryManage ) results = await asyncio.gather(*tasks, return_exceptions=True) + from common_utils.async_utils import reraise_critical_errors + + reraise_critical_errors(results) _process_episode_results(state, results) @@ -1257,11 +1005,7 @@ async def _update_memcell_and_cluster(state: ExtractionState): return try: - # Select config based on conversation type - is_agent = state.memcell.type == RawDataType.AGENTCONVERSATION - cluster_config = ( - AGENT_DEFAULT_MEMORIZE_CONFIG if is_agent else DEFAULT_MEMORIZE_CONFIG - ) + cluster_config = DEFAULT_MEMORIZE_CONFIG await _trigger_clustering( state.request.group_id, @@ -1427,13 +1171,6 @@ def _clone_episodes_for_users(state: ExtractionState) -> List[EpisodeMemory]: cloned = [] group_ep = state.group_episode_memories[0] for user_id in state.participants: - if ( - "robot" in user_id.lower() - or "assistant" in user_id.lower() - or "agent" in user_id.lower() - or "tool" in user_id.lower() - ): - continue cloned.append(replace(group_ep, user_id=user_id, user_name=user_id)) logger.info(f"[MemCell Processing] Copied group Episode to {len(cloned)} users") return cloned @@ -1501,13 +1238,7 @@ async def _save_foresight_and_atomic_fact( # solo scene: copy to each user if state.is_solo_scene: - user_ids = [ - u - for u in state.participants - if "robot" not in u.lower() - and "assistant" not in u.lower() - and "agent" not in u.lower() - ] + user_ids = list(state.participants) foresight_docs.extend( [ doc.model_copy(update={"user_id": uid, "user_name": uid}) diff --git a/methods/evermemos/src/biz_layer/memorize_config.py b/methods/evermemos/src/biz_layer/memorize_config.py index fc55deee0..f97de895c 100644 --- a/methods/evermemos/src/biz_layer/memorize_config.py +++ b/methods/evermemos/src/biz_layer/memorize_config.py @@ -4,11 +4,14 @@ Centralized management of all trigger conditions and thresholds for easy adjustment and maintenance. """ -import os +import logging from dataclasses import dataclass +import os from api_specs.memory_types import ParentType +logger = logging.getLogger(__name__) + @dataclass class MemorizeConfig: @@ -40,6 +43,18 @@ class MemorizeConfig: # Default parent type for AtomicFact (memcell or episode) default_atomic_fact_parent_type: str = ParentType.MEMCELL.value + # ===== Clustering lock configuration ===== + # Timeout (seconds) for acquiring the clustering lock + clustering_lock_timeout: float = 600.0 + # Blocking timeout (seconds) for waiting to acquire the clustering lock + clustering_lock_blocking_timeout: float = 2400.0 + + # ===== Skill extraction lock configuration ===== + # Timeout (seconds) for acquiring the skill extraction lock + skill_extraction_lock_timeout: float = 600.0 + # Blocking timeout (seconds) for waiting to acquire the skill extraction lock + skill_extraction_lock_blocking_timeout: float = 2400.0 + # ===== Agent Skill extraction configuration ===== # Minimum quality score (0.0-1.0) of the AgentCase required to trigger # skill extraction. Cases below this threshold are considered too low @@ -52,65 +67,33 @@ class MemorizeConfig: # (data preserved) but removed from search engines and excluded from # future extraction context. skill_retire_confidence: float = 0.1 - # Skip LLM-based maturity scoring for skills. When True, all skills - # are assigned maturity_score=1.0 directly, saving one LLM call per - # add/update operation. + + # ===== Skip flags ===== + # Skip skill maturity scoring during skill extraction skip_skill_maturity_scoring: bool = False - # Skip foresight and atomic_fact extraction for agent conversations. - # When True, only episodes and agent_case are extracted, saving LLM - # calls that are not needed for the skill extraction pipeline. + # Skip foresight and eventlog extraction skip_foresight_and_eventlog: bool = False - - # ===== Extraction toggles (for fast evaluation) ===== - # When True, skip agent skill extraction entirely. - skip_skill_extraction: bool = False - # When True, skip profile extraction entirely. + # Skip profile extraction skip_profile_extraction: bool = False - - # ===== Skill retrieval configuration ===== - # When True, apply LLM-based relevance verification after vector search - # for agent skills, filtering out irrelevant results. + # Enable LLM-based relevance verification for skill search results enable_skill_llm_verify: bool = False - # ===== LLM request configuration ===== - # When True, disable reasoning/thinking for episode and agent case - # extraction by injecting {"chat_template_kwargs": {"enable_thinking": false}} - # into every LLM request body. Useful for reasoning models (e.g. Qwen3.5) - # deployed via vLLM or SGLang. - skip_episode_case_reasoning: bool = False - # When True, disable reasoning/thinking for LLM-based clustering. - skip_clustering_reasoning: bool = False - - # ===== Batch clustering configuration ===== - # Accumulate N memcells before running clustering. 1 = cluster immediately - # (current behavior). Values > 1 reduce lock contention and enable batched - # embedding calls. Use the flush-clustering API to drain pending items on demand. - cluster_batch_size: int = 1 - - - -# Global default configuration (can be overridden via from_env()) -# TODO Move nescessary configurations to ENV. Use default values for now. -DEFAULT_MEMORIZE_CONFIG = MemorizeConfig() - -_agent_cluster_similarity_threshold = float(os.getenv("AGENT_CLUSTER_SIMILARITY_THRESHOLD", "0.5")) - -FAST_SKILL_MEMORIZE_CONFIG = MemorizeConfig( - cluster_similarity_threshold=_agent_cluster_similarity_threshold, - cluster_batch_size=int(os.getenv("AGENT_CLUSTER_BATCH_SIZE", "20")), - skip_skill_maturity_scoring=True, - skip_foresight_and_eventlog=True, - skip_profile_extraction=True, - enable_skill_llm_verify=True, - skip_episode_case_reasoning=True, -) - -ONLINE_AGENT_MEMORIZE_CONFIG = MemorizeConfig( - cluster_similarity_threshold=_agent_cluster_similarity_threshold, -) - -_agent_mode = os.getenv("AGENT_MEMORIZE_MODE", "online").lower() -if _agent_mode == "fast_skill": - AGENT_DEFAULT_MEMORIZE_CONFIG = FAST_SKILL_MEMORIZE_CONFIG + +# Select config based on AGENT_MEMORIZE_MODE env var: +# "online" (default) — full pipeline +# "fast_skill" — skip profile/foresight/eventlog, skip maturity scoring +_AGENT_MEMORIZE_MODE = os.getenv("AGENT_MEMORIZE_MODE", "online").strip().lower() + +if _AGENT_MEMORIZE_MODE == "fast_skill": + DEFAULT_MEMORIZE_CONFIG = MemorizeConfig( + skip_skill_maturity_scoring=True, + skip_foresight_and_eventlog=True, + skip_profile_extraction=True, + clustering_lock_blocking_timeout=4800, + skill_extraction_lock_blocking_timeout=4800, + enable_skill_llm_verify=True + ) else: - AGENT_DEFAULT_MEMORIZE_CONFIG = ONLINE_AGENT_MEMORIZE_CONFIG + if _AGENT_MEMORIZE_MODE != "online": + logger.warning("Unknown AGENT_MEMORIZE_MODE=%r, falling back to 'online'", _AGENT_MEMORIZE_MODE) + DEFAULT_MEMORIZE_CONFIG = MemorizeConfig() diff --git a/methods/evermemos/src/common_utils/async_utils.py b/methods/evermemos/src/common_utils/async_utils.py new file mode 100644 index 000000000..2cd506104 --- /dev/null +++ b/methods/evermemos/src/common_utils/async_utils.py @@ -0,0 +1,32 @@ +""" +Async utility functions + +Provides helper functions for common async patterns like processing +asyncio.gather results with proper error propagation. +""" + +from typing import Sequence, Any + +from core.constants.exceptions import CriticalError + + +def reraise_critical_errors(results: Sequence[Any]) -> None: + """Re-raise any CriticalError found in asyncio.gather results. + + When using ``asyncio.gather(return_exceptions=True)``, all exceptions are + captured as return values. The common ``isinstance(result, Exception)`` + check then logs-and-continues, silently swallowing every error. + + Call this function **before** processing gather results to ensure + ``CriticalError`` subclasses (e.g. missing tenant context, broken + invariants) always propagate to the caller. + + Args: + results: The list returned by ``asyncio.gather(return_exceptions=True)`` + + Raises: + CriticalError: The first CriticalError found in *results* + """ + for result in results: + if isinstance(result, CriticalError): + raise result diff --git a/methods/evermemos/src/core/constants/exceptions.py b/methods/evermemos/src/core/constants/exceptions.py index 14269d6a4..5035a4646 100644 --- a/methods/evermemos/src/core/constants/exceptions.py +++ b/methods/evermemos/src/core/constants/exceptions.py @@ -5,79 +5,25 @@ Follows a unified exception handling specification, facilitating error tracking and debugging. """ -from enum import Enum from typing import Optional, Dict, Any from core.constants.errors import ErrorCode -class BaseException(Exception): - """Base exception class +class CriticalError(Exception): + """Marker base class for critical errors that must never be silently swallowed. - Base class for all custom exceptions, providing a unified exception handling interface. - Includes error code, error message, and optional details. - """ - - def __init__( - self, - code: str, - message: str, - details: Optional[Dict[str, Any]] = None, - original_exception: Optional[Exception] = None, - ): - """ - Initialize base exception - - Args: - code: Error code - message: Error message - details: Optional dictionary of detailed information - original_exception: Original exception object - """ - super().__init__(message) - self.code = code - self.message = message - self.details = details or {} - self.original_exception = original_exception - - def __str__(self) -> str: - """Return string representation of the exception""" - return f"[{self.code}] {self.message}" + Errors inheriting from this class indicate serious system-level issues + (e.g., missing tenant context, broken invariants) that should always + propagate to the caller and result in a request failure (HTTP 500). - def __repr__(self) -> str: - """Return detailed representation of the exception""" - details_str = f", details={self.details}" if self.details else "" - original_str = ( - f", original={self.original_exception}" if self.original_exception else "" - ) - return f"{self.__class__.__name__}(code='{self.code}', message='{self.message}'{details_str}{original_str})" - - def to_dict(self) -> Dict[str, Any]: - """Convert exception to dictionary format for easy serialization""" - return { - "code": self.code, - "message": self.message, - "details": self.details, - "exception_type": self.__class__.__name__, - } - - -class AgentException(BaseException): - """Base class for Agent-related exceptions - - Base class for all exceptions related to Agent execution. + Use ``reraise_critical_errors()`` from ``common_utils.async_utils`` + to guard ``asyncio.gather(return_exceptions=True)`` result processing. """ - def __init__( - self, - code: str, - message: str, - details: Optional[Dict[str, Any]] = None, - original_exception: Optional[Exception] = None, - ): - super().__init__(code, message, details, original_exception) + pass -class ValidationException(BaseException): +class ValidationException(Exception): """Data validation exception Raised when input data validation fails. @@ -92,180 +38,10 @@ def __init__( if field: message = f"Field '{field}': {message}" - super().__init__( - code=ErrorCode.VALIDATION_ERROR.value, message=message, details=details - ) - - -class ResourceNotFoundException(BaseException): - """Resource not found exception - - Raised when the requested resource does not exist. - """ - - def __init__( - self, - resource_type: str, - resource_id: str, - details: Optional[Dict[str, Any]] = None, - ): - message = f"{resource_type} with id '{resource_id}' not found" - super().__init__( - code=ErrorCode.RESOURCE_NOT_FOUND.value, message=message, details=details - ) - - -class ConfigurationException(BaseException): - """Configuration exception - - Raised when system configuration is incorrect or missing. - """ - - def __init__( - self, - message: str, - config_key: Optional[str] = None, - details: Optional[Dict[str, Any]] = None, - ): - if config_key: - message = f"Configuration error for '{config_key}': {message}" - - super().__init__( - code=ErrorCode.CONFIGURATION_ERROR.value, message=message, details=details - ) - - -class DatabaseException(BaseException): - """Database exception - - Raised when a database operation fails. - """ - - def __init__( - self, - message: str, - operation: Optional[str] = None, - details: Optional[Dict[str, Any]] = None, - original_exception: Optional[Exception] = None, - ): - if operation: - message = f"Database {operation} failed: {message}" - - super().__init__( - code=ErrorCode.DATABASE_ERROR.value, - message=message, - details=details, - original_exception=original_exception, - ) - - -class ExternalServiceException(BaseException): - """External service exception - - Raised when calling an external service fails. - """ - - def __init__( - self, - service_name: str, - message: str, - status_code: Optional[int] = None, - details: Optional[Dict[str, Any]] = None, - original_exception: Optional[Exception] = None, - ): - if status_code: - message = f"{service_name} service error (HTTP {status_code}): {message}" - else: - message = f"{service_name} service error: {message}" - - super().__init__( - code=ErrorCode.EXTERNAL_SERVICE_ERROR.value, - message=message, - details=details, - original_exception=original_exception, - ) - - -class AuthenticationException(BaseException): - """Authentication exception - - Raised when user authentication fails. - """ - - def __init__( - self, - message: str, - details: Optional[Dict[str, Any]] = None, - original_exception: Optional[Exception] = None, - ): - super().__init__( - code=ErrorCode.AUTHENTICATION_ERROR.value, - message=message, - details=details, - original_exception=original_exception, - ) - - -class LLMOutputParsingException(AgentException): - """LLM output parsing exception - - Raised when the content returned by LLM cannot be parsed correctly. - """ - - def __init__( - self, - message: str, - llm_output: Optional[str] = None, - expected_format: Optional[str] = None, - attempt_count: Optional[int] = None, - details: Optional[Dict[str, Any]] = None, - original_exception: Optional[Exception] = None, - ): - if expected_format: - message = f"LLM output parsing failed, expected format: {expected_format}, error: {message}" - if attempt_count: - message = f"{message} [Attempt {attempt_count}]" - - # Add LLM output to details - if details is None: - details = {} - if llm_output: - details["llm_output"] = llm_output[ - :500 - ] # Limit length to avoid being too long - - super().__init__( - code=ErrorCode.LLM_OUTPUT_PARSING_ERROR.value, - message=message, - details=details, - original_exception=original_exception, - ) - - -def create_exception_from_error_code( - error_code: ErrorCode, - message: str, - details: Optional[Dict[str, Any]] = None, - original_exception: Optional[Exception] = None, -) -> BaseException: - """ - Create corresponding exception object based on error code - - Args: - error_code: Error code enumeration - message: Error message - details: Optional detailed information - original_exception: Original exception object - - Returns: - Corresponding exception object - """ - return BaseException( - code=error_code.value, - message=message, - details=details, - original_exception=original_exception, - ) + super().__init__(message) + self.code = ErrorCode.VALIDATION_ERROR.value + self.message = message + self.details = details or {} # Long Job System Errors - Long job system error classes @@ -284,16 +60,8 @@ def create_exception_from_error_code( __all__ = [ # Error codes and base exception 'ErrorCode', - 'BaseException', - 'AgentException', + 'CriticalError', 'ValidationException', - 'ResourceNotFoundException', - 'ConfigurationException', - 'DatabaseException', - 'ExternalServiceException', - 'AuthenticationException', - 'LLMOutputParsingException', - 'create_exception_from_error_code', # Long job system error classes 'FatalError', 'BusinessLogicError', diff --git a/methods/evermemos/src/core/tenants/init_tenant_all.py b/methods/evermemos/src/core/tenants/init_tenant_all.py index 1175d4128..980123337 100644 --- a/methods/evermemos/src/core/tenants/init_tenant_all.py +++ b/methods/evermemos/src/core/tenants/init_tenant_all.py @@ -54,7 +54,7 @@ TENANT_INIT_STORAGE_INFO_ENV = "TENANT_INIT_STORAGE_INFO" -def _setup_tenant_context_from_env() -> str: +def setup_tenant_context_from_env() -> str: """ Set up tenant context from TENANT_INIT_STORAGE_INFO environment variable. @@ -183,7 +183,7 @@ async def run_tenant_init() -> bool: logger.info("*" * 60) # Set up tenant context - tenant_id = _setup_tenant_context_from_env() + tenant_id = setup_tenant_context_from_env() logger.info("Tenant ID: %s", tenant_id) logger.info("*" * 60) diff --git a/methods/evermemos/src/core/tenants/tenant_constants.py b/methods/evermemos/src/core/tenants/tenant_constants.py index b0d9532bf..1e6ef8409 100644 --- a/methods/evermemos/src/core/tenants/tenant_constants.py +++ b/methods/evermemos/src/core/tenants/tenant_constants.py @@ -8,10 +8,9 @@ Prefix Mode Example Source ────── ────────────────── ──────────────── ────────────────────────── - b0001 Base (no tenant) b0001_memsys get_base_resource_prefix() + s0001 Base / shared pool s0001_memsys get_base_resource_prefix() dev Single-tenant dev_memsys TENANT_SINGLE_TENANT_ID t3a7b2c Multi-tenant excl. t3a7b2c_memsys enterprise tenant_id_generator - s0001 Multi-tenant shared s0001_memsys enterprise tenant_id_generator All separators use underscore (_). No hyphens in resource names. @@ -47,18 +46,21 @@ ISOLATION_MODE_EXCLUSIVE = "exclusive" # ============================================================ -# Resource prefix: base ("b") — no tenant context +# Resource prefix: base / shared pool ("s") # ============================================================ # Used during multi-tenant startup before any request arrives. +# Using "s" (same as shared pool prefix) means the ORM startup +# resources (Beanie init, etc.) land directly in the shared pool, +# avoiding an extra set of "b0001_*" phantom resources. +# # The version suffix is configurable via TENANT_BASE_RESOURCE_VERSION env var, -# allowing operators to bump the version (b0001 -> b0002) on upgrades so +# allowing operators to bump the version (s0001 -> s0002) on upgrades so # new resources are auto-created while old ones are left intact for rollback. # -# Companion prefixes defined in enterprise (tenant_id_generator.py): +# Other prefixes defined in enterprise (tenant_id_generator.py): # "t" — exclusive tenant (t + 10-hex hash of org+space, e.g., "t3a7b2c1d9e") -# "s" — shared node (s + 4-digit node id, e.g., "s0001") -BASE_RESOURCE_PREFIX_LETTER = "b" +BASE_RESOURCE_PREFIX_LETTER = "s" @lru_cache(maxsize=1) @@ -66,11 +68,11 @@ def get_base_resource_prefix() -> str: """ Get the base resource prefix for resources created without tenant context. - Format: "b" + version (e.g., "b0001", "b0002"). + Format: "s" + version (e.g., "s0001", "s0002"). Version is read from env TENANT_BASE_RESOURCE_VERSION, defaults to "0001". Returns: - str: e.g., "b0001" + str: e.g., "s0001" """ version = os.getenv("TENANT_BASE_RESOURCE_VERSION", "0001") return f"{BASE_RESOURCE_PREFIX_LETTER}{version}" diff --git a/methods/evermemos/src/core/tenants/tenantize/kv/redis/tenant_key_utils.py b/methods/evermemos/src/core/tenants/tenantize/kv/redis/tenant_key_utils.py index 7a7a73385..87fdc74f0 100644 --- a/methods/evermemos/src/core/tenants/tenantize/kv/redis/tenant_key_utils.py +++ b/methods/evermemos/src/core/tenants/tenantize/kv/redis/tenant_key_utils.py @@ -9,35 +9,22 @@ from core.tenants.tenant_contextvar import get_current_tenant_id -def patch_redis_tenant_key(key: str) -> str: +def build_tenant_redis_key(prefix: str, tenant_id: str, key: str) -> str: """ - Add tenant prefix to Redis key name + Build a tenant-scoped Redis key with an explicit tenant_id. - Retrieve the tenant ID from the current context and prepend it to the key to achieve multi-tenant data isolation. - If no tenant information is set in the current context, return the original key. - - Format: {tenant_id}:{key} + Format: {prefix}:{tenant_id}:{key} Args: - key: Original Redis key name + prefix: Key namespace prefix (e.g. "task_status") + tenant_id: Tenant identifier + key: Business key (e.g. task_id, request_id) Returns: - str: Redis key name with tenant prefix; if no tenant, return the original key + str: "{prefix}:{tenant_id}:{key}" Examples: - >>> # Assume current tenant ID is "tenant_001" - >>> patch_redis_tenant_key("conversation_data:group_123") - 'tenant_001:conversation_data:group_123' - - >>> # If no tenant is set - >>> patch_redis_tenant_key("conversation_data:group_123") - 'conversation_data:group_123' + >>> build_tenant_redis_key("task_status", "t3a7b2c1d9e", "abc123") + 'task_status:t3a7b2c1d9e:abc123' """ - tenant_id: Optional[str] = get_current_tenant_id() - - if tenant_id: - # When tenant ID exists, concatenate tenant prefix - return f"{tenant_id}:{key}" - - # When no tenant ID exists, return the original key - return key + return f"{prefix}:{tenant_id}:{key}" diff --git a/methods/evermemos/src/core/tenants/tenantize/tenant_cache_utils.py b/methods/evermemos/src/core/tenants/tenantize/tenant_cache_utils.py index 8cca153d5..88a49c159 100644 --- a/methods/evermemos/src/core/tenants/tenantize/tenant_cache_utils.py +++ b/methods/evermemos/src/core/tenants/tenantize/tenant_cache_utils.py @@ -17,6 +17,7 @@ """ from typing import TypeVar, Callable, Optional, Union +from core.constants.exceptions import CriticalError from core.observation.logger import get_logger from core.tenants.tenant_contextvar import get_current_tenant from core.tenants.tenant_config import ( @@ -29,6 +30,16 @@ T = TypeVar("T") +class TenantContextMissingError(CriticalError): + """Raised when tenant context is missing after app startup (strict check). + + Inherits from CriticalError so that ``reraise_critical_errors()`` will + propagate it out of ``asyncio.gather(return_exceptions=True)`` result processing. + """ + + pass + + def get_or_compute_tenant_cache( patch_key: TenantPatchKey, compute_func: Callable[[], T], @@ -95,7 +106,7 @@ def get_or_compute_tenant_cache( if not tenant_info: # Strict check mode: after app startup, tenant context must exist in tenant mode if config.app_ready: - raise RuntimeError( + raise TenantContextMissingError( f"🚨 Strict tenant check failed: app is ready but tenant context is missing!" f"This usually indicates a serious code issue, please check the call chain." f"[cache_key={patch_key.value}, cache_description={cache_description}]" @@ -142,6 +153,8 @@ def get_or_compute_tenant_cache( return computed_value + except CriticalError: + raise except Exception as e: # Exception handling: try to use fallback (lazy evaluation) fallback_value = _resolve_fallback(fallback, cache_description) diff --git a/methods/evermemos/src/devops_scripts/data_fix/milvus_rebuild_collection.py b/methods/evermemos/src/devops_scripts/data_fix/milvus_rebuild_collection.py index f2e8d3bba..841b2005e 100644 --- a/methods/evermemos/src/devops_scripts/data_fix/milvus_rebuild_collection.py +++ b/methods/evermemos/src/devops_scripts/data_fix/milvus_rebuild_collection.py @@ -13,7 +13,7 @@ - All collections rebuild: --all Usage (via bootstrap with SKIP_LIFESPAN to avoid schema validation on startup): - SKIP_LIFESPAN=true python src/bootstrap.py src/devops_scripts/data_fix/milvus_rebuild_collection.py --all + SKIP_LIFESPAN=true TENANT_INIT_STORAGE_INFO='...' python src/bootstrap.py src/devops_scripts/data_fix/milvus_rebuild_collection.py --all Note: This script migrates data by default (in batches of 3000). To disable data migration, use the --no-migrate-data option. @@ -204,7 +204,7 @@ def discover_all_aliases() -> List[str]: Discover all concrete collection aliases by scanning MilvusCollectionBase subclasses. Returns: - List of collection base names (e.g., ["v1_episodic_memory", "v1_user_profile"]) + List of collection base names (e.g., ["v1_episodic_memory", "v1_user_profile"]) #skip-sensitive-check """ aliases = [] for cls in get_all_subclasses(MilvusCollectionBase): @@ -301,8 +301,8 @@ def main(argv: Optional[List[str]] = None) -> int: # Rebuild ALL discovered collections ... --all - # With tenant context - SKIP_LIFESPAN=true TENANT_SINGLE_TENANT_ID=xxx python src/bootstrap.py ... --all + # With tenant context (recommended) + SKIP_LIFESPAN=true TENANT_INIT_STORAGE_INFO='{"tenant_id":"s0001","isolation_mode":"shared","storage_info":{"milvus":{"collection_prefix":"s0001"}}}' python src/bootstrap.py ... --all #skip-sensitive-check # Rebuild without migrating data ... -a v1_episodic_memory --no-migrate-data @@ -348,6 +348,15 @@ def main(argv: Optional[List[str]] = None) -> int: migrate_data = not args.no_migrate_data progress = StdoutProgressReporter() + # Set up tenant context from TENANT_INIT_STORAGE_INFO if available + import os + + if os.getenv("TENANT_INIT_STORAGE_INFO"): + from core.tenants.init_tenant_all import setup_tenant_context_from_env + + tenant_id = setup_tenant_context_from_env() + logger.info("Rebuild running with tenant context: %s", tenant_id) + # Determine which aliases to rebuild if args.rebuild_all: aliases = discover_all_aliases() diff --git a/methods/evermemos/src/infra_layer/adapters/input/api/memory/memory_controller.py b/methods/evermemos/src/infra_layer/adapters/input/api/memory/memory_controller.py index 67674cecc..a9ae63b98 100644 --- a/methods/evermemos/src/infra_layer/adapters/input/api/memory/memory_controller.py +++ b/methods/evermemos/src/infra_layer/adapters/input/api/memory/memory_controller.py @@ -47,12 +47,7 @@ AddResponse, FlushResponse, ) -from api_specs.dtos.memory import ( - AgentAddRequest, - AgentFlushRequest, - AgentFlushClusteringRequest, - AgentFlushClusteringResponse, -) +from api_specs.dtos.memory import AgentAddRequest, AgentFlushRequest from api_specs.dtos.memory_delete import DeleteMemoriesRequest from core.request import log_request from core.request.app_logic_provider import AppLogicProvider @@ -142,19 +137,21 @@ async def add_personal_memories( if session_id and session_id != DEFAULT_SESSION_ID: asyncio.create_task(self._ensure_session_exists(session_id=session_id)) - # Auto-register senders from messages - messages = request_data.get("messages", []) - self._auto_register_senders(messages) + # Auto-register senders from converted data (includes auto-filled sender_ids) + self._auto_register_senders(memorize_request.new_raw_data_list) # Enrich sender_name from DB for messages that didn't provide one - await self._enrich_sender_names( - messages, memorize_request.new_raw_data_list - ) + messages = request_data.get("messages", []) + with timed("enrich_sender_names"): + await self._enrich_sender_names( + messages, memorize_request.new_raw_data_list + ) # Content enrichment (e.g. multimodal parsing, no-op by default) # Must run BEFORE save_request_logs so that parsed multimodal text # is included in the flat content saved to RawMessage. - await self._content_enrich.enrich(memorize_request.new_raw_data_list) + with timed("enrich_content"): + await self._content_enrich.enrich(memorize_request.new_raw_data_list) # Save request logs with timed("persist_raw_messages"): @@ -273,9 +270,9 @@ async def add_group_memories( group_id = memorize_request.group_id - # Auto-register group with metadata (only when group_meta is provided) - group_meta = request_data.get("group_meta") - if group_id and group_meta: + # Auto-register group (with optional metadata) + if group_id: + group_meta = request_data.get("group_meta") or {} asyncio.create_task( self._ensure_group_exists( group_id=group_id, @@ -284,19 +281,21 @@ async def add_group_memories( ) ) - # Auto-register senders - messages = request_data.get("messages", []) - self._auto_register_senders(messages) + # Auto-register senders from converted data (includes sender_ids from request) + self._auto_register_senders(memorize_request.new_raw_data_list) # Enrich sender_name from DB for messages that didn't provide one - await self._enrich_sender_names( - messages, memorize_request.new_raw_data_list - ) + messages = request_data.get("messages", []) + with timed("enrich_sender_names"): + await self._enrich_sender_names( + messages, memorize_request.new_raw_data_list + ) # Content enrichment (e.g. multimodal parsing, no-op by default) # Must run BEFORE save_request_logs so that parsed multimodal text # is included in the flat content saved to RawMessage. - await self._content_enrich.enrich(memorize_request.new_raw_data_list) + with timed("enrich_content"): + await self._content_enrich.enrich(memorize_request.new_raw_data_list) # Save request logs with timed("persist_raw_messages"): @@ -609,73 +608,25 @@ async def flush_agent_memories( status_code=500, detail="Flush failed, please try again later" ) from e - # ========================================================================= - # Flush Clustering (POST /memories/agent/flush-clustering) - # ========================================================================= - - @post( - "/agent/flush-clustering", - response_model=AgentFlushClusteringResponse, - summary="Flush pending clustering", - description="Force-drain all pending memcells for a group and run batch clustering. " - "Use this when cluster_batch_size > 1 and you want to trigger clustering immediately.", - ) - @log_request() - @stage_timed("flush_clustering") - async def flush_clustering_endpoint( - self, request: FastAPIRequest, request_body: AgentFlushClusteringRequest = None - ) -> AgentFlushClusteringResponse: - """POST /api/v1/memories/agent/flush-clustering - Force batch clustering.""" - del request_body - - try: - request_data = await request.json() - user_id = request_data.get("user_id") - if not user_id: - raise ValueError("user_id is required") - - logger.info( - "Received flush-clustering: user_id=%s", - user_id, - ) - - from biz_layer.mem_memorize import flush_clustering - - pending_count = await flush_clustering(user_id=user_id) - - status = 'clustered' if pending_count > 0 else 'no_clustering' - return { - "data": { - "request_id": self._app_logic.get_current_request_id(), - "status": status, - "message": "Flush clustering completed", - } - } - - except ValueError as e: - logger.error("Flush-clustering parameter error: %s", e) - raise HTTPException(status_code=422, detail=str(e)) from e - except Exception as e: - logger.error("Flush-clustering failed: %s", e, exc_info=True) - raise HTTPException( - status_code=500, - detail="Flush clustering failed, please try again later", - ) from e - # ========================================================================= # Helper Methods # ========================================================================= - def _auto_register_senders(self, messages: list) -> None: - """Fire-and-forget auto-register senders from message list.""" + def _auto_register_senders(self, raw_data_list: list) -> None: + """Fire-and-forget auto-register senders from converted raw data list. + + Uses the converted RawData objects (not the original request JSON) + so that auto-filled sender_ids are included. + """ seen = set() - for msg in messages: - sender_id = msg.get("sender_id") + for raw_data in raw_data_list: + content = raw_data.content + sender_id = content.get("sender_id") if sender_id and sender_id not in seen: seen.add(sender_id) asyncio.create_task( self._ensure_sender_exists( - sender_id=sender_id, sender_name=msg.get("sender_name") + sender_id=sender_id, sender_name=content.get("sender_name") ) ) @@ -834,18 +785,30 @@ async def add_agent_memories( if session_id and session_id != DEFAULT_SESSION_ID: asyncio.create_task(self._ensure_session_exists(session_id=session_id)) - # Auto-register senders (skip role=tool as they are not real users) - messages = request_data.get("messages", []) - for msg in messages: - if msg.get("role") != "tool": - sender_id = msg.get("sender_id") + # Auto-register senders from converted data (skip role=tool) + for raw_data in memorize_request.new_raw_data_list: + content = raw_data.content + if content.get("role") != "tool": + sender_id = content.get("sender_id") if sender_id: asyncio.create_task( self._ensure_sender_exists( - sender_id=sender_id, sender_name=msg.get("sender_name") + sender_id=sender_id, + sender_name=content.get("sender_name"), ) ) + # Enrich sender_name from DB for messages that didn't provide one + messages = request_data.get("messages", []) + await self._enrich_sender_names( + messages, memorize_request.new_raw_data_list + ) + + # Content enrichment (e.g. multimodal parsing, no-op by default) + # Must run BEFORE save_request_logs so that parsed multimodal text + # is included in the flat content saved to RawMessage. + await self._content_enrich.enrich(memorize_request.new_raw_data_list) + # Save request logs with timed("persist_raw_messages"): await self._save_raw_messages( diff --git a/methods/evermemos/src/infra_layer/adapters/out/persistence/document/memory/mem_scene.py b/methods/evermemos/src/infra_layer/adapters/out/persistence/document/memory/mem_scene.py index fb2719ec1..c43f74825 100644 --- a/methods/evermemos/src/infra_layer/adapters/out/persistence/document/memory/mem_scene.py +++ b/methods/evermemos/src/infra_layer/adapters/out/persistence/document/memory/mem_scene.py @@ -29,19 +29,17 @@ class MemScene(TenantAwareDocumentBase, AuditBase): description="Per-cluster centroid, latest timestamp, and member count", ) + # Cluster IDs that contain agent conversation (case) memcells. + # Used to route case memcells to LLM-based clustering and exclude them from embedding-only clustering. + case_cluster_ids: Optional[List[str]] = Field( + default_factory=list, description="Cluster IDs containing agent case memcells" + ) + # Auto-increment counter for cluster ID generation (cluster_000, cluster_001, ...). next_cluster_idx: int = Field( default=0, description="Counter for generating unique cluster IDs" ) - # Pending memcells waiting for batch clustering. - # Each entry is a dict with keys: event_id, episode, timestamp, participants, group_id, agent_case. - # Drained when len(pending_clustering) >= cluster_batch_size or on explicit flush. - pending_clustering: List[Dict[str, Any]] = Field( - default_factory=list, - description="Memcells waiting for batch clustering", - ) - class Settings: name = "v1_mem_scenes" indexes = [ diff --git a/methods/evermemos/src/infra_layer/adapters/out/persistence/repository/agent_case_raw_repository.py b/methods/evermemos/src/infra_layer/adapters/out/persistence/repository/agent_case_raw_repository.py index 67f4c8e01..dad81ce3b 100644 --- a/methods/evermemos/src/infra_layer/adapters/out/persistence/repository/agent_case_raw_repository.py +++ b/methods/evermemos/src/infra_layer/adapters/out/persistence/repository/agent_case_raw_repository.py @@ -14,6 +14,7 @@ from core.oxm.mongo.base_repository import BaseRepository from infra_layer.adapters.out.persistence.document.memory.agent_case import ( AgentCaseRecord, + AgentCaseProjection, ) from agentic_layer.vectorize_service import get_vectorize_service @@ -345,6 +346,38 @@ async def delete_by_filters( logger.error(f"[AgentCaseRepo] Failed to delete by filters: {e}") return 0 + async def fetch_task_intents_by_event_ids( + self, event_ids: List[str] + ) -> Dict[str, str]: + """Fetch task_intent texts from AgentCase DB by parent event IDs. + + Used as context_fetcher callback for ClusterManager in LLM mode. + + Args: + event_ids: List of memcell event IDs (used as parent_id in agent cases) + + Returns: + Dict mapping event_id -> task_intent text + """ + if not event_ids: + return {} + + try: + cases = ( + await self.model.find({"parent_id": {"$in": event_ids}}) + .project(AgentCaseProjection) + .to_list() + ) + + result: Dict[str, str] = {} + for case in cases: + if case.parent_id and case.task_intent: + result[case.parent_id] = case.task_intent + return result + except Exception as e: + logger.error(f"[AgentCaseRepo] Failed to fetch task intents: {e}") + return {} + async def find_by_filter_paginated( self, query_filter: Optional[Dict[str, Any]] = None, diff --git a/methods/evermemos/src/infra_layer/adapters/out/search/milvus/converter/agent_skill_milvus_converter.py b/methods/evermemos/src/infra_layer/adapters/out/search/milvus/converter/agent_skill_milvus_converter.py index 2d86b387e..c5ae9f72f 100644 --- a/methods/evermemos/src/infra_layer/adapters/out/search/milvus/converter/agent_skill_milvus_converter.py +++ b/methods/evermemos/src/infra_layer/adapters/out/search/milvus/converter/agent_skill_milvus_converter.py @@ -46,6 +46,7 @@ def from_mongo(cls, source_doc: AgentSkillRecord) -> Dict[str, Any]: try: name = source_doc.name or "" description = source_doc.description or "" + content = source_doc.content or "" # Primary text field: name + description combined content_field = "\n".join(s for s in [name, description] if s) @@ -56,7 +57,7 @@ def from_mongo(cls, source_doc: AgentSkillRecord) -> Dict[str, Any]: "user_id": source_doc.user_id or "", "group_id": source_doc.group_id or "", "cluster_id": source_doc.cluster_id or "", - "content": content_field[:4000], + "content": content_field[:5000], "maturity_score": source_doc.maturity_score, "confidence": source_doc.confidence, } diff --git a/methods/evermemos/src/manage.py b/methods/evermemos/src/manage.py index f06f3095f..a0345cadd 100644 --- a/methods/evermemos/src/manage.py +++ b/methods/evermemos/src/manage.py @@ -216,12 +216,11 @@ def tenant_init( """ Initialize MongoDB and Milvus databases for a specific tenant - Tenant ID is specified via environment variable TENANT_SINGLE_TENANT_ID. - Database connection configurations are obtained from default environment variables. + Tenant context is specified via environment variable TENANT_INIT_STORAGE_INFO (JSON). Examples: - # Set tenant ID environment variable - export TENANT_SINGLE_TENANT_ID=tenant_001 + # Set tenant context + export TENANT_INIT_STORAGE_INFO='{"tenant_id":"s0001","isolation_mode":"shared","storage_info":{"mongodb":{"database":"s0001_memsys"},"elasticsearch":{"index_prefix":"s0001"},"milvus":{"collection_prefix":"s0001","num_partitions":256}}}' #skip-sensitive-check # Run initialization python src/manage.py tenant-init diff --git a/methods/evermemos/src/memory_layer/cluster_manager/config.py b/methods/evermemos/src/memory_layer/cluster_manager/config.py index 87b3300cd..412c7fce4 100644 --- a/methods/evermemos/src/memory_layer/cluster_manager/config.py +++ b/methods/evermemos/src/memory_layer/cluster_manager/config.py @@ -13,6 +13,9 @@ class ClusterManagerConfig: enable_persistence: Whether to persist mem scene state to disk persist_dir: Directory for mem scene state persistence (required if enable_persistence=True) clustering_algorithm: Algorithm to use ('centroid' or 'nearest') + llm_top_k_clusters: Number of candidate clusters pre-filtered by embedding for LLM + llm_max_context_per_cluster: Max recent items per cluster in LLM context + llm_skip_threshold: Skip LLM if top-1 embedding similarity exceeds this """ similarity_threshold: float = 0.65 @@ -20,6 +23,13 @@ class ClusterManagerConfig: enable_persistence: bool = False persist_dir: str = None clustering_algorithm: str = "centroid" # 'centroid' or 'nearest' + # LLM clustering: number of candidate clusters pre-filtered by embedding similarity + llm_top_k_clusters: int = 30 + # LLM clustering: max recent items per cluster to include in LLM context + llm_max_context_per_cluster: int = 5 + # LLM clustering: if top-1 embedding similarity exceeds this threshold, + # skip LLM and assign directly (set to 1.0 to always use LLM) + llm_skip_threshold: float = 0.85 def __post_init__(self): """Validate configuration.""" diff --git a/methods/evermemos/src/memory_layer/cluster_manager/manager.py b/methods/evermemos/src/memory_layer/cluster_manager/manager.py index 3c4e71ea0..0f807cb95 100644 --- a/methods/evermemos/src/memory_layer/cluster_manager/manager.py +++ b/methods/evermemos/src/memory_layer/cluster_manager/manager.py @@ -11,6 +11,7 @@ """ import asyncio +import json import numpy as np from typing import Any, Callable, Dict, List, Optional, Tuple from pathlib import Path @@ -30,6 +31,7 @@ logger.warning("Vectorize service not available, clustering will be limited") + class MemSceneState: """Internal state for a single group's clustering.""" @@ -47,8 +49,8 @@ def __init__(self): self.cluster_counts: Dict[str, int] = {} self.cluster_last_ts: Dict[str, Optional[float]] = {} - # Pending memcells waiting for batch clustering - self.pending_clustering: List[Dict[str, Any]] = [] + # Clusters that contain agent conversation memcells + self.case_cluster_ids: set = set() def assign_new_cluster(self, event_id: str) -> str: """Assign a new cluster ID to an event.""" @@ -127,12 +129,14 @@ def to_dict(self) -> Dict[str, Any]: "count": self.cluster_counts.get(cid, 0), } - return { + result = { "memcell_info": memcell_info, "memscene_info": memscene_info, "next_cluster_idx": self.next_cluster_idx, - "pending_clustering": self.pending_clustering, } + if self.case_cluster_ids: + result["case_cluster_ids"] = sorted(self.case_cluster_ids) + return result @staticmethod def from_dict(data: Dict[str, Any]) -> "MemSceneState": @@ -143,6 +147,7 @@ def from_dict(data: Dict[str, Any]) -> "MemSceneState": """ state = MemSceneState() state.next_cluster_idx = int(data.get("next_cluster_idx", 0)) + state.case_cluster_ids = set(data.get("case_cluster_ids") or []) if "memcell_info" in data: # New format @@ -178,8 +183,6 @@ def from_dict(data: Dict[str, Any]) -> "MemSceneState": k: float(v) for k, v in (data.get("cluster_last_ts", {}) or {}).items() } - state.pending_clustering = list(data.get("pending_clustering", [])) - return state @@ -208,16 +211,26 @@ class ClusterManager: ``` """ - def __init__(self, config: Optional[ClusterManagerConfig] = None): + def __init__( + self, + config: Optional[ClusterManagerConfig] = None, + llm_provider: Optional[Any] = None, + context_fetcher: Optional[Callable] = None, + ): """Initialize ClusterManager. Args: config: Clustering configuration (uses defaults if None) + llm_provider: LLM provider instance (required for agent memcell clustering) + context_fetcher: Async callback to fetch context texts from DB. + Signature: (event_ids: List[str]) -> Dict[str, str] + Returns mapping of event_id -> task_intent text. + Required for agent memcell clustering. """ self.config = config or ClusterManagerConfig() self._callbacks: List[Callable] = [] - # Vectorize service + # Vectorize service (for embedding) self._vectorize_service = None if VECTORIZE_SERVICE_AVAILABLE: try: @@ -225,6 +238,10 @@ def __init__(self, config: Optional[ClusterManagerConfig] = None): except Exception as e: logger.warning(f"Failed to initialize vectorize service: {e}") + # LLM provider (for llm algorithm) + self._llm_provider = llm_provider + self._context_fetcher = context_fetcher + # Statistics self._stats = { "total_memcells": 0, @@ -244,22 +261,37 @@ def on_cluster_assigned( self._callbacks.append(callback) async def cluster_memcell( - self, memcell: Dict[str, Any], state: MemSceneState + self, + memcell: Dict[str, Any], + state: MemSceneState, + has_case: bool = False, ) -> Tuple[Optional[str], MemSceneState]: """Cluster a memcell and return updated state. - Pure computation method - no storage operations. Caller is responsible for loading state before and saving it after. + Routing: + - has_case=False: embedding clustering over non-case clusters, text=episode + - has_case=True: embedding recall + LLM over case clusters, text=task_intent + Args: memcell: Memcell dictionary with event_id, timestamp, episode/summary state: Current mem scene state for the group + has_case: Whether this memcell has an agent case Returns: Tuple of (cluster_id, updated_state): - cluster_id: Assigned cluster ID, or None if failed - state: Updated MemSceneState (same object, mutated) """ + if has_case: + return await self._cluster_memcell_llm(memcell, state) + return await self._cluster_memcell_embedding(memcell, state) + + async def _cluster_memcell_embedding( + self, memcell: Dict[str, Any], state: MemSceneState + ) -> Tuple[Optional[str], MemSceneState]: + """Embedding-based clustering using vector cosine similarity.""" self._stats["total_memcells"] += 1 # Extract key fields @@ -273,273 +305,377 @@ async def cluster_memcell( # Get embedding vector = await self._get_embedding(text) - cluster_id = await self._assign_to_cluster( - event_id, vector, timestamp, state + if vector is None or vector.size == 0: + logger.warning( + f"Failed to get embedding for event {event_id}, creating singleton cluster" + ) + cluster_id = state.assign_new_cluster(event_id) + state.event_ids.append(event_id) + state.timestamps.append(timestamp or 0.0) + state.vectors.append(np.zeros((1,), dtype=np.float32)) + self._stats["new_clusters"] += 1 + self._stats["failed_embeddings"] += 1 + return cluster_id, state + + # Find best matching cluster (exclude case clusters) + cluster_id = self._find_best_cluster( + state, vector, timestamp, exclude_cids=state.case_cluster_ids ) + + # Add to cluster + if cluster_id is None: + cluster_id = state.assign_new_cluster(event_id) + state._update_cluster_centroid(cluster_id, vector, timestamp) + self._stats["new_clusters"] += 1 + else: + state.add_to_cluster(event_id, cluster_id, vector, timestamp) + + # Update state + state.event_ids.append(event_id) + state.timestamps.append(timestamp or 0.0) + state.vectors.append(vector) + + self._stats["clustered_memcells"] += 1 + return cluster_id, state - async def cluster_memcells_batch_llm( + def _create_new_cluster( self, - memcells: List[Dict[str, Any]], state: MemSceneState, - llm_provider=None, - existing_cluster_episodes: Optional[Dict[str, List[str]]] = None, - extra_body: Optional[dict] = None, - ) -> Tuple[List[Optional[str]], MemSceneState]: - """Cluster multiple memcells using LLM-based assignment. + event_id: str, + vector: Optional[np.ndarray], + timestamp: Optional[float], + is_case: bool = False, + ) -> str: + """Create a new cluster and assign the event to it.""" + cluster_id = state.assign_new_cluster(event_id) + if is_case: + state.case_cluster_ids.add(cluster_id) + # _update_cluster_centroid handles cluster_counts when vector is present; + # for None/empty vector we must set it explicitly. + if vector is not None and vector.size > 0: + state._update_cluster_centroid(cluster_id, vector, timestamp) + else: + state.cluster_counts[cluster_id] = 1 + if timestamp is not None: + state.cluster_last_ts[cluster_id] = timestamp + self._stats["new_clusters"] += 1 + return cluster_id - Instead of embedding similarity, sends batch items and existing cluster - descriptions to an LLM, which decides assignments (existing or new clusters). + def _assign_to_cluster( + self, + state: MemSceneState, + event_id: str, + cluster_id: str, + vector: Optional[np.ndarray], + timestamp: Optional[float], + ) -> None: + """Assign an event to an existing cluster.""" + state.eventid_to_cluster[event_id] = cluster_id + state.cluster_ids.append(cluster_id) + # _update_cluster_centroid handles cluster_counts when vector is present; + # for None/empty vector we must increment explicitly. + if vector is not None and vector.size > 0: + state._update_cluster_centroid(cluster_id, vector, timestamp) + else: + state.cluster_counts[cluster_id] = ( + state.cluster_counts.get(cluster_id, 0) + 1 + ) + if timestamp is not None: + prev_ts = state.cluster_last_ts.get(cluster_id) + state.cluster_last_ts[cluster_id] = max( + prev_ts or timestamp, timestamp + ) - Falls back to per-item cluster_memcell if LLM call fails. + def _append_event( + self, + state: MemSceneState, + event_id: str, + vector: Optional[np.ndarray], + timestamp: Optional[float], + ) -> None: + """Append event metadata to state lists.""" + state.event_ids.append(event_id) + state.timestamps.append(timestamp or 0.0) + state.vectors.append( + vector if vector is not None else np.zeros((1,), dtype=np.float32) + ) - Args: - memcells: List of memcell dicts - state: Current MemSceneState - llm_provider: LLMProvider instance for LLM calls - existing_cluster_episodes: Dict of cluster_id -> list of recent episode - texts (up to 3), fetched by the caller from the database. - extra_body: Optional extra body for LLM requests (e.g. disable thinking). + async def _cluster_memcell_llm( + self, + memcell: Dict[str, Any], + state: MemSceneState, + ) -> Tuple[Optional[str], MemSceneState]: + """LLM-based clustering with embedding pre-filtering. - Returns: - (list_of_cluster_ids, updated_state) + Two-stage approach: + 1. Use embedding similarity to recall top-K candidate clusters + 2. Fetch recent episodes for candidates, let LLM make the final decision """ - if not memcells: - return [], state + self._stats["total_memcells"] += 1 - if llm_provider is None: - logger.warning( - "[LLMClustering] No LLM provider, falling back to per-item embedding clustering" + event_id = str(memcell.get("event_id", "")) + if not event_id: + logger.warning("Memcell missing event_id, skipping clustering") + return None, state + + timestamp = self._parse_timestamp(memcell.get("timestamp")) + text = self._extract_text(memcell) + + if self._llm_provider is None: + logger.error( + "[LLM Clustering] No LLM provider configured, " + "falling back to embedding-only case clustering" + ) + vector = await self._get_embedding(text) + best_cid = self._find_top_k_clusters( + state, vector, k=1, only_cids=state.case_cluster_ids, ) - return await self._fallback_cluster_individually(memcells, state) + if best_cid and best_cid[0][1] >= self.config.similarity_threshold: + cluster_id = best_cid[0][0] + self._assign_to_cluster( + state, event_id, cluster_id, vector, timestamp + ) + else: + cluster_id = self._create_new_cluster( + state, event_id, vector, timestamp, is_case=True + ) + self._append_event(state, event_id, vector, timestamp) + self._stats["clustered_memcells"] += 1 + return cluster_id, state + + # No existing case clusters — just create a new one directly + if not state.case_cluster_ids: + vector = await self._get_embedding(text) + cluster_id = self._create_new_cluster( + state, event_id, vector, timestamp, is_case=True + ) + self._append_event(state, event_id, vector, timestamp) + self._stats["clustered_memcells"] += 1 + logger.info( + f"[LLM Clustering] First case cluster: {event_id} -> {cluster_id}" + ) + return cluster_id, state - num_clusters = len(existing_cluster_episodes) if existing_cluster_episodes else 0 - if num_clusters > 200: - logger.warning( - f"[LLMClustering] Too many existing clusters ({num_clusters} > 200), " - "falling back to per-item embedding clustering" + # Stage 1: Embedding recall — find top-K candidate clusters (case only) + vector = await self._get_embedding(text) + scored_candidates = self._find_top_k_clusters( + state, vector, + k=self.config.llm_top_k_clusters, + only_cids=state.case_cluster_ids, + ) + candidate_ids = [cid for cid, _ in scored_candidates] + top1_sim = scored_candidates[0][1] if scored_candidates else -1.0 + logger.info( + f"[LLM Clustering] Embedding recall: {len(candidate_ids)} candidates " + f"(top1_sim={top1_sim:.3f}), " + f"from {len(state.case_cluster_ids)} case clusters" + ) + + # Fast path: if top-1 similarity is high enough, skip LLM + if top1_sim >= self.config.llm_skip_threshold: + cluster_id = scored_candidates[0][0] + self._assign_to_cluster(state, event_id, cluster_id, vector, timestamp) + self._append_event(state, event_id, vector, timestamp) + self._stats["clustered_memcells"] += 1 + logger.info( + f"[LLM Clustering] Fast path: {event_id} -> {cluster_id} " + f"(sim={top1_sim:.3f} >= {self.config.llm_skip_threshold})" ) - return await self._fallback_cluster_individually(memcells, state) - episode_truncate_len = 200 if num_clusters > 100 else 500 - - # ===== Extract fields ===== - event_ids: List[str] = [] - timestamps: List[Optional[float]] = [] - texts: List[str] = [] - valid_indices: List[int] = [] - for i, mc in enumerate(memcells): - eid = str(mc.get("event_id", "")) - if not eid: - logger.warning("Memcell missing event_id in batch, skipping") - continue - event_ids.append(eid) - timestamps.append(self._parse_timestamp(mc.get("timestamp"))) - texts.append(self._extract_text(mc)) - valid_indices.append(i) - - cluster_ids: List[Optional[str]] = [None] * len(memcells) - if not event_ids: - return cluster_ids, state - - # ===== Build LLM prompt ===== - existing_clusters_desc = [] - if existing_cluster_episodes: - for cid, episodes in existing_cluster_episodes.items(): - if not episodes: - continue - count = state.cluster_counts.get(cid, 0) - existing_clusters_desc.append({ - "cluster_id": cid, - "recent_episodes": [ep[:episode_truncate_len] for ep in episodes[-3:]], - "item_count": count, - }) + return cluster_id, state - new_items_desc = [] - for idx, text in enumerate(texts): - new_items_desc.append({ - "item_index": idx, - "text": text[:500], - }) + # Stage 2: Fetch recent context for candidates + cluster_context = await self._fetch_cluster_context(state, candidate_ids) - import json + # Stage 3: LLM decision + clusters_json = self._build_clusters_json( + state, candidate_ids, cluster_context + ) + next_new_id = f"{state.next_cluster_idx:03d}" from memory_layer.prompts import get_prompt_by - from common_utils.json_utils import parse_json_response - prompt_template = get_prompt_by("CLUSTER_LLM_ASSIGNMENT_PROMPT") - next_new_id = state.next_cluster_idx + prompt_template = get_prompt_by("AGENT_CLUSTER_LLM_ASSIGN_PROMPT") prompt = prompt_template.format( - existing_clusters_json=json.dumps(existing_clusters_desc, ensure_ascii=False, indent=2), - new_items_json=json.dumps(new_items_desc, ensure_ascii=False, indent=2), + memcell_text=text, + clusters_json=clusters_json, next_new_id=next_new_id, - next_new_id_plus1=next_new_id + 1, ) - - # ===== Call LLM ===== - llm_result = None - for attempt in range(2): - try: - resp = await llm_provider.generate(prompt, extra_body=extra_body) - data = parse_json_response(resp) - if data and isinstance(data.get("assignments"), list): - llm_result = data - break - logger.warning( - f"[LLMClustering] LLM retry {attempt + 1}/2: invalid format" - ) - except Exception as e: - logger.warning(f"[LLMClustering] LLM retry {attempt + 1}/2: {e}") + llm_result = await self._call_llm_for_clustering(prompt) if llm_result is None: logger.warning( - "[LLMClustering] LLM failed after retries, falling back to per-item embedding clustering" + f"[LLM Clustering] LLM call failed for event {event_id}, " + f"falling back to embedding top-1" ) - return await self._fallback_cluster_individually(memcells, state) - - # ===== Batch embed (for centroid updates) ===== - vectors = await self._get_embeddings_batch(texts) - - # ===== Apply LLM assignments ===== - assignments = llm_result["assignments"] - assigned_indices = set() - - for assignment in assignments: - item_idx = assignment.get("item_index") - target_cid = assignment.get("cluster_id") - if item_idx is None or target_cid is None: - continue - if item_idx < 0 or item_idx >= len(event_ids): - continue - if item_idx in assigned_indices: - continue - assigned_indices.add(item_idx) - - eid = event_ids[item_idx] - ts = timestamps[item_idx] - vec = vectors[item_idx] if item_idx < len(vectors) else None - text = texts[item_idx] - - self._stats["total_memcells"] += 1 - - known_clusters = set(state.cluster_centroids.keys()) - if existing_cluster_episodes: - known_clusters.update(existing_cluster_episodes.keys()) - is_existing = target_cid in known_clusters - if is_existing: - # Assign to existing cluster - if vec is not None and hasattr(vec, 'size') and vec.size > 0: - state.add_to_cluster(eid, target_cid, vec, ts) - else: - state.eventid_to_cluster[eid] = target_cid - state.cluster_ids.append(target_cid) - self._stats["clustered_memcells"] += 1 + # Fall back to embedding: use top-1 candidate if available, else new cluster + if scored_candidates and scored_candidates[0][1] >= self.config.similarity_threshold: + cluster_id = scored_candidates[0][0] + self._assign_to_cluster( + state, event_id, cluster_id, vector, timestamp + ) else: - # New cluster assigned by LLM - state.eventid_to_cluster[eid] = target_cid - state.cluster_ids.append(target_cid) - # Update next_cluster_idx if LLM used a higher index - try: - idx_num = int(target_cid.split("_")[-1]) - if idx_num >= state.next_cluster_idx: - state.next_cluster_idx = idx_num + 1 - except (ValueError, IndexError): - pass - if vec is not None and hasattr(vec, 'size') and vec.size > 0: - state._update_cluster_centroid(target_cid, vec, ts) - else: - state._update_cluster_centroid(target_cid, None, ts) - self._stats["new_clusters"] += 1 - self._stats["clustered_memcells"] += 1 - - state.event_ids.append(eid) - state.timestamps.append(ts or 0.0) - state.vectors.append(vec if vec is not None else np.zeros((1,), dtype=np.float32)) - cluster_ids[valid_indices[item_idx]] = target_cid - - # Handle any items not assigned by LLM (fallback: create singleton clusters) - unassigned = [i for i in range(len(event_ids)) if i not in assigned_indices] - if unassigned: - logger.warning( - "[LLMClustering] %d/%d items not assigned by LLM, falling back to embedding clustering: %s", - len(unassigned), len(event_ids), - [event_ids[i] for i in unassigned], - ) - unassigned_memcells = [memcells[valid_indices[idx]] for idx in unassigned] - fallback_ids, state = await self._fallback_cluster_individually( - unassigned_memcells, state - ) - for idx, cid in zip(unassigned, fallback_ids): - cluster_ids[valid_indices[idx]] = cid + cluster_id = self._create_new_cluster( + state, event_id, vector, timestamp, is_case=True + ) + else: + chosen_id = llm_result.get("cluster_id", "") + if chosen_id in state.cluster_counts and chosen_id in state.case_cluster_ids: + cluster_id = chosen_id + self._assign_to_cluster( + state, event_id, cluster_id, vector, timestamp + ) + else: + cluster_id = self._create_new_cluster( + state, event_id, vector, timestamp, is_case=True + ) - valid_ids = [cid for cid in cluster_ids if cid is not None] - unique_clusters = set(valid_ids) + self._append_event(state, event_id, vector, timestamp) + self._stats["clustered_memcells"] += 1 + reason = llm_result.get("reason", "") if llm_result else "" logger.info( - f"[LLMClustering] 📦 Batch complete: {len(event_ids)} memcells -> " - f"{len(unique_clusters)} unique clusters" + f"[LLM Clustering] 🎯 Event {event_id} -> {cluster_id} " + f"| intent: {text} | reason: {reason}" ) + return cluster_id, state - return cluster_ids, state - - async def _fallback_cluster_individually( + def _find_top_k_clusters( self, - memcells: List[Dict[str, Any]], state: MemSceneState, - ) -> Tuple[List[Optional[str]], MemSceneState]: - """Fallback: cluster each memcell individually via embedding similarity. + vector: Optional[np.ndarray], + k: int = 10, + only_cids: Optional[set] = None, + ) -> List[Tuple[str, float]]: + """Find top-K candidate clusters by embedding similarity. + + Args: + only_cids: If provided, only consider these cluster IDs. - Used when LLM clustering fails. Iterates over each memcell and calls - cluster_memcell one by one. + Returns: + List of (cluster_id, similarity) tuples, sorted by similarity desc. + Similarity is -1.0 when embedding is unavailable. """ - cluster_ids: List[Optional[str]] = [] - for mc in memcells: - cid, state = await self.cluster_memcell(mc, state) - cluster_ids.append(cid) - return cluster_ids, state + all_cids = list(state.cluster_counts.keys()) + if only_cids is not None: + all_cids = [c for c in all_cids if c in only_cids] + if not all_cids: + return [] + + # If no embedding or no centroids, return all with unknown similarity + if vector is None or vector.size == 0 or not state.cluster_centroids: + return [(c, -1.0) for c in all_cids[:k]] - async def _assign_to_cluster( + # Score each cluster by cosine similarity (ignore time gap for recall stage) + vector_norm = np.linalg.norm(vector) + 1e-9 + scored = [] + for cid in all_cids: + centroid = state.cluster_centroids.get(cid) + if centroid is None or centroid.size == 0: + scored.append((cid, -1.0)) + continue + centroid_norm = np.linalg.norm(centroid) + 1e-9 + sim = float((centroid @ vector) / (centroid_norm * vector_norm)) + scored.append((cid, sim)) + + scored.sort(key=lambda x: x[1], reverse=True) + return scored[:k] + + async def _fetch_cluster_context( self, - event_id: str, - vector: Optional[np.ndarray], - timestamp: Optional[float], state: MemSceneState, - ) -> Optional[str]: - """Core assignment logic for single-memcell clustering. + candidate_ids: List[str], + ) -> Dict[str, List[str]]: + """Fetch recent context texts for candidate clusters via context_fetcher. - Assigns event_id to a cluster (existing or new), updates state in-place, - and returns the cluster_id. + Returns: + Dict mapping cluster_id -> list of recent context texts """ - if vector is None or (hasattr(vector, 'size') and vector.size == 0): - logger.warning( - f"Failed to get embedding for event {event_id}, creating singleton cluster" + if not self._context_fetcher or not candidate_ids: + return {} + + max_per = self.config.llm_max_context_per_cluster + + # Collect recent event_ids per candidate cluster + from collections import defaultdict + + candidate_set = set(candidate_ids) + cluster_event_ids: Dict[str, List[str]] = defaultdict(list) + for eid, cid in state.eventid_to_cluster.items(): + if cid in candidate_set: + cluster_event_ids[cid].append(eid) + + # Take last N per cluster, collect all target event_ids + cluster_slices: Dict[str, List[str]] = {} + all_target_eids: List[str] = [] + for cid in candidate_ids: + eids = cluster_event_ids.get(cid, []) + recent = eids[-max_per:] + cluster_slices[cid] = recent + all_target_eids.extend(recent) + + if not all_target_eids: + return {} + + # Call the fetcher: event_ids -> {event_id: episode_text} + eid_to_text = await self._context_fetcher(all_target_eids) + + # Assemble per-cluster context + result: Dict[str, List[str]] = {} + for cid, eids in cluster_slices.items(): + texts = [eid_to_text[eid] for eid in eids if eid in eid_to_text] + if texts: + result[cid] = texts + return result + + def _build_clusters_json( + self, + state: MemSceneState, + candidate_ids: List[str], + cluster_context: Dict[str, List[str]], + ) -> str: + """Build JSON representation of candidate clusters for LLM prompt.""" + if not candidate_ids: + return "(No existing clusters)" + + clusters = [] + for cid in candidate_ids: + count = state.cluster_counts.get(cid, 0) + recent = cluster_context.get(cid, []) + clusters.append( + { + "cluster_id": cid, + "item_count": count, + "recent_task_intents": recent, + } ) - cluster_id = state.assign_new_cluster(event_id) - state.event_ids.append(event_id) - state.timestamps.append(timestamp or 0.0) - state.vectors.append(np.zeros((1,), dtype=np.float32)) - self._stats["new_clusters"] += 1 - self._stats["failed_embeddings"] += 1 - return cluster_id + return json.dumps(clusters, ensure_ascii=False, indent=2) - # Find best matching cluster - cluster_id = self._find_best_cluster(state, vector, timestamp) - - # Add to cluster - if cluster_id is None: - cluster_id = state.assign_new_cluster(event_id) - state._update_cluster_centroid(cluster_id, vector, timestamp) - self._stats["new_clusters"] += 1 - else: - state.add_to_cluster(event_id, cluster_id, vector, timestamp) - - # Update state - state.event_ids.append(event_id) - state.timestamps.append(timestamp or 0.0) - state.vectors.append(vector) + async def _call_llm_for_clustering( + self, prompt: str + ) -> Optional[Dict[str, Any]]: + """Call LLM and parse clustering decision.""" + for attempt in range(3): + try: + resp = await self._llm_provider.generate(prompt) + from common_utils.json_utils import parse_json_response - self._stats["clustered_memcells"] += 1 - return cluster_id + data = parse_json_response(resp) + if data and "cluster_id" in data: + return data + logger.warning( + f"[LLM Clustering] Retry {attempt + 1}/3: invalid response format" + ) + except Exception as e: + logger.warning( + f"[LLM Clustering] Retry {attempt + 1}/3: {e}" + ) + return None def _find_best_cluster( - self, state: MemSceneState, vector: np.ndarray, timestamp: Optional[float] + self, + state: MemSceneState, + vector: np.ndarray, + timestamp: Optional[float], + exclude_cids: Optional[set] = None, ) -> Optional[str]: """Find the best matching cluster for a vector.""" if not state.cluster_centroids: @@ -551,6 +687,8 @@ def _find_best_cluster( vector_norm = np.linalg.norm(vector) + 1e-9 for cluster_id, centroid in state.cluster_centroids.items(): + if exclude_cids and cluster_id in exclude_cids: + continue if centroid is None or centroid.size == 0: continue @@ -590,31 +728,15 @@ async def _get_embedding(self, text: str) -> Optional[np.ndarray]: return None - async def _get_embeddings_batch(self, texts: List[str]) -> List[Optional[np.ndarray]]: - """Get embeddings for multiple texts in a single batched call.""" - if not self._vectorize_service: - logger.warning("Vectorize service not available for batch embedding") - return [None] * len(texts) - - try: - vectors = await self._vectorize_service.get_embeddings(texts) - return [ - np.array(v, dtype=np.float32) if v is not None else None - for v in vectors - ] - except Exception as e: - logger.warning(f"Batch embedding failed, falling back to individual: {e}") - # Fallback: call individual embeddings - results = [] - for text in texts: - results.append(await self._get_embedding(text)) - return results - def _extract_text(self, memcell: Dict[str, Any]) -> str: """Extract representative text from memcell. - Priority: episode > original_data + Priority: clustering_text > episode > original_data """ + clustering_text = memcell.get("clustering_text") + if isinstance(clustering_text, str) and clustering_text.strip(): + return clustering_text.strip() + episode = memcell.get("episode") if isinstance(episode, str) and episode.strip(): return episode.strip() diff --git a/methods/evermemos/src/memory_layer/llm/api_key_rotator.py b/methods/evermemos/src/memory_layer/llm/api_key_rotator.py new file mode 100644 index 000000000..767e0ab2f --- /dev/null +++ b/methods/evermemos/src/memory_layer/llm/api_key_rotator.py @@ -0,0 +1,90 @@ +"""API key round-robin rotator. + +Supports multiple keys rotating in turn to spread rate-limit pressure +across OpenRouter and other providers. Behaves identically to a single +key when only one key is supplied. +""" + +import itertools +from collections.abc import Sequence +from typing import ClassVar + +from core.observation.logger import get_logger + +logger = get_logger(__name__) + + +class ApiKeyRotator: + """API Key rotator for round-robin selection. + + Spreads rate-limit pressure across multiple keys. Behaves identically + to a single key when only one key is supplied. Process-level shared + instance available via ``get_or_create``. + + Args: + keys: One or more API keys for rotation. + + Note: + Relies on asyncio single-threaded event loop; not thread-safe. + """ + + _shared: ClassVar["ApiKeyRotator | None"] = None + + def __init__(self, keys: Sequence[str]) -> None: + if not keys: + raise ValueError("At least one API key is required") + if len(keys) != len(set(keys)): + logger.warning( + "ApiKeyRotator: duplicate keys detected, rotation may be uneven" + ) + self._keys: tuple[str, ...] = tuple(keys) + self._cycle = itertools.cycle(range(len(self._keys))) + + def get_rotation(self) -> tuple[str, ...]: + """Return all keys starting from the current cycle position. + + Advances the global cycle by one position for load distribution. + The returned tuple is a per-request local snapshot: + + - ``rotation[0]`` is the key for the first attempt (equivalent to + a single ``get_next()`` call). + - ``rotation[1:]`` are the retry keys, starting from the one after + the first-attempt key, guaranteeing each key is tried before any + key is reused. + + Concurrent requests each advance the cycle independently, so their + first-attempt keys are naturally staggered. + """ + start_idx = next(self._cycle) + if self.size > 1: + logger.debug( + "ApiKeyRotator: selected key index %d/%d", start_idx + 1, self.size + ) + return tuple(self._keys[(start_idx + i) % self.size] for i in range(self.size)) + + @property + def size(self) -> int: + """Number of API keys in the rotation pool.""" + return len(self._keys) + + @classmethod + def get_or_create(cls, raw: str) -> "ApiKeyRotator": + """Get the process-level shared instance, creating it on first call. + + Subsequent calls return the existing instance; ``raw`` is only used + for the initial creation and is ignored afterward. + """ + if cls._shared is None: + keys = [k.strip() for k in raw.split(",") if k.strip()] + cls._shared = cls(keys) + else: + new_keys = tuple(k.strip() for k in raw.split(",") if k.strip()) + if new_keys != cls._shared._keys: + logger.warning( + "ApiKeyRotator: get_or_create called with different keys, " + "returning existing instance (keys are locked at first creation)" + ) + return cls._shared + + def __repr__(self) -> str: + return f"ApiKeyRotator(size={self.size})" diff --git a/methods/evermemos/src/memory_layer/llm/llm_metrics.py b/methods/evermemos/src/memory_layer/llm/llm_metrics.py new file mode 100644 index 000000000..543b9be37 --- /dev/null +++ b/methods/evermemos/src/memory_layer/llm/llm_metrics.py @@ -0,0 +1,61 @@ +""" +LLM Provider Metrics + +Prometheus metrics for monitoring LLM API call volume and error rates. +Co-located with the LLM provider per Prometheus instrumentation best practices. +""" + +from core.observation.metrics import Counter + + +# ============================================================ +# Counter Metrics +# ============================================================ + +LLM_REQUESTS_TOTAL = Counter( + name='llm_requests_total', + description='Total number of LLM API requests', + labelnames=['model', 'status'], + namespace='evermemos', + subsystem='memory_layer', +) +""" +LLM requests counter. + +Labels: +- model: LLM model name (e.g. "gpt-4.1-mini", "qwen/qwen3-235b-a22b-2507") +- status: Request outcome + - success: HTTP 200 with valid response + - rate_limit: HTTP 429 (all keys exhausted) + - key_error: HTTP 401/402/403 (all keys exhausted) + - server_error: HTTP 5xx (after max retries) + - client_error: Network / connection error (after max retries) + - request_error: HTTP 400/404/422 (no retry) + +PromQL examples: + # Total requests per second + rate(evermemos_memory_layer_llm_requests_total[5m]) + + # 429 count + evermemos_memory_layer_llm_requests_total{status="rate_limit"} + + # 429 ratio + evermemos_memory_layer_llm_requests_total{status="rate_limit"} + / evermemos_memory_layer_llm_requests_total +""" + + +# ============================================================ +# Helper Functions +# ============================================================ + + +def record_llm_request(model: str, status: str) -> None: + """Record an LLM request outcome. + + Args: + model: LLM model name. + status: Request outcome (success, rate_limit, key_error, + server_error, client_error, request_error). + """ + LLM_REQUESTS_TOTAL.labels(model=model, status=status).inc() diff --git a/methods/evermemos/src/memory_layer/llm/openai_provider.py b/methods/evermemos/src/memory_layer/llm/openai_provider.py index 0dbfc0de2..96ac717f8 100644 --- a/methods/evermemos/src/memory_layer/llm/openai_provider.py +++ b/methods/evermemos/src/memory_layer/llm/openai_provider.py @@ -4,81 +4,73 @@ This provider uses a caller-supplied API key and base URL. """ -from math import log +import asyncio +import json import os +import random import time -import json -import urllib.request -import urllib.parse -import urllib.error + import aiohttp -from typing import Optional -import asyncio -import random -from memory_layer.llm.protocol import LLMProvider, LLMError -from core.observation.logger import get_logger -from core.di.utils import get_bean_by_type from core.component.token_usage_collector import TokenUsageCollector +from core.di.utils import get_bean_by_type +from core.observation.logger import get_logger +from memory_layer.llm.api_key_rotator import ApiKeyRotator +from memory_layer.llm.llm_metrics import record_llm_request +from memory_layer.llm.protocol import LLMProvider, LLMError logger = get_logger(__name__) +_MAX_RETRIES = 5 -class OpenAIProvider(LLMProvider): - """ - OpenAI-compatible LLM provider. - This provider expects the caller to supply API key and base URL. +class OpenAIProvider(LLMProvider): + """OpenAI-compatible LLM provider. + + Sends requests to any OpenAI-compatible endpoint (OpenRouter, OpenAI, etc.) + with automatic multi-key rotation and differentiated retry strategies. + + Args: + model: Model name (e.g. "gpt-4.1-mini", "qwen/qwen3-235b-a22b-2507"). + api_key: API key(s), comma-separated for multi-key rotation. + base_url: API base URL. + temperature: Sampling temperature. + max_tokens: Maximum tokens to generate. + enable_stats: Enable per-call usage statistics. + provider_type: Provider identifier ("openai" or "openrouter"). """ def __init__( self, - model: str = "gpt-4.1-mini", + model: str = "gpt-4.1-mini", # skip-sensitive-check api_key: str | None = None, base_url: str | None = None, temperature: float = 0.3, max_tokens: int | None = 100 * 1024, - enable_stats: bool = False, # New: optional statistics feature, disabled by default - provider_type: str | None = None, # Provider type: "openai" or "openrouter" + enable_stats: bool = False, + provider_type: str | None = None, **kwargs, - ): - """ - Initialize OpenAI provider. - - Args: - model: Model name (e.g., "gpt-4o-mini", "gpt-4o") - api_key: API key (required by caller) - base_url: API base URL (required by caller) - temperature: Sampling temperature - max_tokens: Maximum tokens to generate - enable_stats: Enable usage statistics accumulation (default: False) - provider_type: Provider type ("openai" or "openrouter") - **kwargs: Additional arguments (ignored for now) - """ + ) -> None: self.model = model self.temperature = temperature self.max_tokens = max_tokens self.enable_stats = enable_stats - self.provider_type = (provider_type or "openrouter").lower() - self.api_key = api_key + self.provider_type = ( + provider_type or "openrouter" # skip-sensitive-check + ).lower() + self._key_rotator = ( + ApiKeyRotator.get_or_create(api_key) if api_key else ApiKeyRotator([""]) + ) self.base_url = base_url - # Validate model whitelist from env: {PROVIDER}_WHITE_LIST - # If whitelist is empty or not set, no restriction is applied. self._validate_model_whitelist(self.provider_type, model) - # Optional per-call statistics (disabled by default) if self.enable_stats: - self.current_call_stats = None # Store statistics for current call + self.current_call_stats = None @staticmethod def _validate_model_whitelist(provider_type: str, model: str) -> None: - """ - Validate model against the provider's whitelist from environment variable. - - Reads {PROVIDER}_WHITE_LIST env var (comma-separated model names). - If the env var is not set or empty, no restriction is applied. - """ + """Validate model against the provider's whitelist from environment variable.""" env_key = f"{provider_type.upper()}_WHITE_LIST" raw = os.getenv(env_key, "").strip() if not raw: @@ -91,233 +83,257 @@ def _validate_model_whitelist(provider_type: str, model: str) -> None: f"Provider '{provider_type}' only supports: {', '.join(sorted(allowed_models))}. Got: '{model}'." ) - async def generate( - self, - prompt: str, - temperature: float | None = None, - max_tokens: int | None = None, - extra_body: dict | None = None, - response_format: dict | None = None, - ) -> str: - """ - Generate a response for the given prompt. - - Args: - prompt: Input prompt - temperature: Override temperature for this request - max_tokens: Override max tokens for this request + @staticmethod + def _resolve_openrouter_provider() -> dict | None: + """Parse LLM_OPENROUTER_PROVIDER env var into an OpenRouter provider dict.""" + raw = os.getenv("LLM_OPENROUTER_PROVIDER", "default") # skip-sensitive-check + if raw == "default": + return None + provider_list = [p.strip() for p in raw.split(",")] + return {"order": provider_list, "allow_fallbacks": False} - Returns: - Generated response text + @staticmethod + def _extract_error_message(response_data: dict, status_code: int) -> str: + """Extract a human-readable error message from an error response body.""" + return response_data.get("error", {}).get("message", f"HTTP {status_code}") - Raises: - LLMError: If generation fails - """ - # Use time.perf_counter() for more precise time measurement - start_time = time.perf_counter() - # Prepare request data - if os.getenv("LLM_OPENROUTER_PROVIDER", "default") != "default": - provider_str = os.getenv('LLM_OPENROUTER_PROVIDER') - provider_list = [p.strip() for p in provider_str.split(',')] - openrouter_provider = {"order": provider_list, "allow_fallbacks": False} - else: - openrouter_provider = None - # Prepare request data - data = { + def _build_request_data( + self, + prompt: str, + temperature: float | None, + max_tokens: int | None, + response_format: dict | None, + ) -> dict: + """Build the JSON payload for the chat completions request.""" + data: dict = { "model": self.model, "messages": [{"role": "user", "content": prompt}], "temperature": temperature if temperature is not None else self.temperature, - "provider": openrouter_provider, + "provider": self._resolve_openrouter_provider(), "response_format": response_format, } - # print(data) - # print(data["extra_body"]) - # Add max_tokens if specified if max_tokens is not None: data["max_tokens"] = max_tokens elif self.max_tokens is not None: data["max_tokens"] = self.max_tokens + return data - # Merge per-call extra_body into request data - if extra_body: - data.update(extra_body) + async def _do_request(self, data: dict, api_key: str) -> tuple[int, dict]: + """Execute a single HTTP POST to the chat completions endpoint. - # Use asynchronous aiohttp instead of synchronous urllib + Returns: + (status_code, parsed_response_body) + """ headers = { - 'Content-Type': 'application/json', - 'Authorization': f'Bearer {self.api_key}', + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", } - max_retries = 5 - for retry_num in range(max_retries): + timeout = aiohttp.ClientTimeout(total=600) + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post( + f"{self.base_url}/chat/completions", json=data, headers=headers + ) as response: + raw = await response.read() + try: + response_data = json.loads(raw.decode()) + except (json.JSONDecodeError, UnicodeDecodeError): + # Non-JSON response (e.g. Cloudflare HTML error page) + return response.status, { + "error": {"message": raw[:500].decode(errors="replace")} + } + return response.status, response_data + + def _report_token_usage(self, prompt_tokens: int, completion_tokens: int) -> None: + """Report token usage to the global TokenUsageCollector (best-effort).""" + try: + collector = get_bean_by_type(TokenUsageCollector) + collector.add(self.model, prompt_tokens, completion_tokens, call_type="llm") + except Exception: + pass + + def _log_completion_metrics(self, response_data: dict, duration: float) -> None: + """Log finish reason, duration, and token usage for a completed request.""" + finish_reason = response_data.get("choices", [{}])[0].get("finish_reason", "") + if finish_reason == "stop": + logger.debug("[OpenAI-%s] Finish reason: %s", self.model, finish_reason) + else: + logger.warning("[OpenAI-%s] Finish reason: %s", self.model, finish_reason) + + usage = response_data.get("usage", {}) + prompt_tokens = usage.get("prompt_tokens", 0) + completion_tokens = usage.get("completion_tokens", 0) + + logger.debug("[OpenAI-%s] Duration: %.2fs", self.model, duration) + if duration > 30: + logger.warning("[OpenAI-%s] Duration too long: %.2fs", self.model, duration) + logger.debug( + "[OpenAI-%s] Tokens: %s prompt, %s completion, %s total", + self.model, + format(prompt_tokens, ","), + format(completion_tokens, ","), + format(usage.get("total_tokens", 0), ","), + ) + + self._report_token_usage(prompt_tokens, completion_tokens) + + def _handle_success(self, response_data: dict, start_time: float) -> str: + """Process a successful (HTTP 200) response: log metrics, report usage, return text.""" + duration = time.perf_counter() - start_time + self._log_completion_metrics(response_data, duration) + + if self.enable_stats: + usage = response_data.get("usage", {}) + self.current_call_stats = { + "prompt_tokens": usage.get("prompt_tokens", 0), + "completion_tokens": usage.get("completion_tokens", 0), + "total_tokens": usage.get("total_tokens", 0), + "duration": duration, + "timestamp": time.time(), + } + + return response_data["choices"][0]["message"]["content"] + + def _handle_key_error( + self, status_code: int, error_msg: str, consecutive_failures: int + ) -> int: + """Handle key-level errors (401/402/403/429): rotate key, raise if all exhausted.""" + consecutive_failures += 1 + if consecutive_failures >= self._key_rotator.size: + metric_status = "rate_limit" if status_code == 429 else "key_error" + record_llm_request(self.model, metric_status) + raise LLMError( + f"HTTP {status_code}: {error_msg} " + f"(all {self._key_rotator.size} keys exhausted)" + ) + logger.warning( + "[OpenAI-%s] Key error %d, rotating key (%d/%d exhausted)", + self.model, + status_code, + consecutive_failures, + self._key_rotator.size, + ) + return consecutive_failures + + async def _handle_server_error( + self, status_code: int, error_msg: str, retry_num: int + ) -> None: + """Handle 5xx server error: sleep and retry, or raise on final attempt.""" + if retry_num < _MAX_RETRIES - 1: + logger.warning( + "[OpenAI-%s] Server error %d, retry %d/%d", + self.model, + status_code, + retry_num + 1, + _MAX_RETRIES, + ) + await asyncio.sleep(random.randint(5, 20)) + return + record_llm_request(self.model, "server_error") + raise LLMError( + f"HTTP Error {status_code}: {error_msg} (after {_MAX_RETRIES} retries)" + ) + + async def _execute_with_retry(self, data: dict, start_time: float) -> str: + """Retry loop: key-level errors rotate key, 5xx backs off with sleep.""" + consecutive_key_failures = 0 + key_rotation = self._key_rotator.get_rotation() + + for retry_num in range(_MAX_RETRIES): + current_key = key_rotation[retry_num % len(key_rotation)] try: - timeout = aiohttp.ClientTimeout(total=600) - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post( - f"{self.base_url}/chat/completions", json=data, headers=headers - ) as response: - chunks = [] - async for chunk in response.content.iter_any(): - chunks.append(chunk) - test = b"".join(chunks).decode() - response_data = json.loads(test) - # print(response_data) - # Handle error responses - if response.status != 200: - error_msg = response_data.get('error', {}).get( - 'message', f"HTTP {response.status}" - ) - logger.error( - f"❌ [OpenAI-{self.model}] HTTP error {response.status}:" - ) - logger.error(f" 💬 Error message: {error_msg}") - - # Retryable errors: rate limit, server errors - if response.status in (429, 500, 502, 503, 504): - logger.warning( - f"Retryable error {response.status}, retry {retry_num + 1}/{max_retries}" - ) - await asyncio.sleep(random.randint(5, 20)) - if retry_num < max_retries - 1: - continue # Retry - # Last retry failed - raise LLMError( - f"HTTP Error {response.status}: {error_msg} (after {max_retries} retries)" - ) - - # Non-retryable errors (401, 403, 404, 400, etc.) - raise LLMError(f"HTTP Error {response.status}: {error_msg}") - - # Use time.perf_counter() for more precise time measurement - end_time = time.perf_counter() - - # Extract finish_reason - finish_reason = response_data.get('choices', [{}])[0].get( - 'finish_reason', '' - ) - if finish_reason == 'stop': - logger.debug( - f"[OpenAI-{self.model}] Finish reason: {finish_reason}" - ) - else: - logger.warning( - f"[OpenAI-{self.model}] Finish reason: {finish_reason}" - ) - - # Extract token usage information - usage = response_data.get('usage', {}) - prompt_tokens = usage.get('prompt_tokens', 0) - completion_tokens = usage.get('completion_tokens', 0) - total_tokens = usage.get('total_tokens', 0) - - # Print detailed usage information - - logger.debug(f"[OpenAI-{self.model}] API call completed:") - logger.debug( - f"[OpenAI-{self.model}] Duration: {end_time - start_time:.2f}s" - ) - # If the duration is too long - if end_time - start_time > 30: - logger.warning( - f"[OpenAI-{self.model}] Duration too long: {end_time - start_time:.2f}s" - ) - logger.debug( - f"[OpenAI-{self.model}] Prompt Tokens: {prompt_tokens:,}" - ) - logger.debug( - f"[OpenAI-{self.model}] Completion Tokens: {completion_tokens:,}" - ) - logger.debug( - f"[OpenAI-{self.model}] Total Tokens: {total_tokens:,}" - ) - - # Report token usage to collector - try: - collector = get_bean_by_type(TokenUsageCollector) - collector.add( - self.model, - prompt_tokens, - completion_tokens, - call_type="llm", - ) - except Exception: - pass - - # New: record statistics for current call (if statistics enabled) - if self.enable_stats: - self.current_call_stats = { - 'prompt_tokens': prompt_tokens, - 'completion_tokens': completion_tokens, - 'total_tokens': total_tokens, - 'duration': end_time - start_time, - 'timestamp': time.time(), - } - - message = response_data['choices'][0]['message'] - reasoning = message.get('reasoning_content') or message.get('reasoning') or message.get('thinking') - if reasoning: - logger.debug( - f"[OpenAI-{self.model}] " - f"🧠 Thinking detected: " - f"{len(reasoning)} chars" - ) - else: - logger.debug( - f"[OpenAI-{self.model}] " - f"💭 No thinking in response" - ) - return message['content'] - - except aiohttp.ClientError as e: - error_time = time.perf_counter() - logger.error("aiohttp.ClientError: %s", e) - # logger.error(f"❌ [OpenAI-{self.model}] Request failed:") - logger.error(f" ⏱️ Duration: {error_time - start_time:.2f}s") - logger.error(f" 💬 Error message: {str(e)}") - logger.error(f"retry_num: {retry_num}") - # raise LLMError(f"Request failed: {str(e)}") - if retry_num == max_retries - 1: - raise LLMError(f"Request failed: {str(e)}") - except Exception as e: - error_time = time.perf_counter() - logger.error("Exception: %s", e) - logger.error(f" ⏱️ Duration: {error_time - start_time:.2f}s") - logger.error(f" 💬 Error message: {str(e)}") - logger.error(f"retry_num: {retry_num}") - if retry_num == max_retries - 1: - raise LLMError(f"Request failed: {str(e)}") + status_code, response_data = await self._do_request(data, current_key) + except aiohttp.ClientError as exc: + logger.error("aiohttp.ClientError: %s", exc) + if retry_num == _MAX_RETRIES - 1: + record_llm_request(self.model, "client_error") + raise LLMError(f"Request failed: {exc}") from exc + continue + except LLMError: + raise + except Exception as exc: + logger.error("Unexpected error: %s", exc) + if retry_num == _MAX_RETRIES - 1: + record_llm_request(self.model, "client_error") + raise LLMError(f"Request failed: {exc}") from exc + continue + + if status_code == 200: + record_llm_request(self.model, "success") + return self._handle_success(response_data, start_time) + + error_msg = self._extract_error_message(response_data, status_code) + logger.error("[OpenAI-%s] HTTP %d: %s", self.model, status_code, error_msg) + + # Key-level errors: rotate key immediately, no sleep. + # - 401 Unauthorized: invalid/missing key + # - 402 Payment Required: key quota exhausted + # - 403 Forbidden: key lacks permission + # - 429 Too Many Requests: key rate-limited + if status_code in (401, 402, 403, 429): + consecutive_key_failures = self._handle_key_error( + status_code, error_msg, consecutive_key_failures + ) + continue - async def test_connection(self) -> bool: - """ - Test the connection to the OpenRouter API. + # 5xx: sleep then retry (key rotates per retry_rotation sequence) + if status_code in (500, 502, 503, 504): + await self._handle_server_error(status_code, error_msg, retry_num) + continue - Returns: - True if connection successful, False otherwise - """ + # Request-level errors (400, 404, 422, etc.): not key-related, no retry + record_llm_request(self.model, "request_error") + raise LLMError(f"HTTP Error {status_code}: {error_msg}") + + record_llm_request(self.model, "client_error") + raise LLMError(f"Request failed after {_MAX_RETRIES} retries") + + async def generate( + self, + prompt: str, + temperature: float | None = None, + max_tokens: int | None = None, + extra_body: dict | None = None, + response_format: dict | None = None, + ) -> str: + """Generate a response for the given prompt.""" + start_time = time.perf_counter() + data = self._build_request_data( + prompt, temperature, max_tokens, response_format + ) + return await self._execute_with_retry(data, start_time) + + async def test_connection(self) -> bool: + """Test the connection to the API endpoint.""" try: - logger.info(f"🔗 [OpenAI-{self.model}] Testing API connection...") - # Try a simple generation to test connection + logger.info("\U0001f517 [OpenAI-%s] Testing API connection...", self.model) test_response = await self.generate("Hello", temperature=0.1) success = len(test_response) > 0 if success: - logger.info(f"✅ [OpenAI-{self.model}] API connection test succeeded") + logger.info( + "\u2705 [OpenAI-%s] API connection test succeeded", self.model + ) else: logger.error( - f"❌ [OpenAI-{self.model}] API connection test failed: Empty response" + "\u274c [OpenAI-%s] API connection test failed: Empty response", + self.model, ) return success except Exception as e: - logger.error(f"❌ [OpenAI-{self.model}] API connection test failed: {e}") + logger.error( + "\u274c [OpenAI-%s] API connection test failed: %s", self.model, e + ) return False - def get_current_call_stats(self) -> Optional[dict]: + def get_current_call_stats(self) -> dict | None: + """Return per-call statistics if stats collection is enabled.""" if self.enable_stats: return self.current_call_stats return None def __repr__(self) -> str: - """String representation of the provider.""" return ( "OpenAIProvider(" - f"provider_type={self.provider_type}, model={self.model}, base_url={self.base_url}" + f"provider_type={self.provider_type}, model={self.model}, " + f"base_url={self.base_url}, keys={self._key_rotator.size}" ")" ) diff --git a/methods/evermemos/src/memory_layer/memcell_extractor/conv_memcell_extractor.py b/methods/evermemos/src/memory_layer/memcell_extractor/conv_memcell_extractor.py index bd0326a10..eff4c0130 100644 --- a/methods/evermemos/src/memory_layer/memcell_extractor/conv_memcell_extractor.py +++ b/methods/evermemos/src/memory_layer/memcell_extractor/conv_memcell_extractor.py @@ -143,7 +143,9 @@ def _extract_participant_ids( participant_ids = set() for raw_data in chat_raw_data_list: - if raw_data.get('role') == MessageSenderRole.USER.value and raw_data.get('sender_id'): + if raw_data.get('role') == MessageSenderRole.USER.value and raw_data.get( + 'sender_id' + ): participant_ids.add(raw_data['sender_id']) return list(participant_ids) @@ -400,50 +402,46 @@ async def _detect_boundaries( ) with timed("detect_boundaries"): + # Retry only when LLM returns unparseable content. + # Infrastructure errors (auth, rate-limit, network) are handled + # by the lower layer and will propagate as exceptions. for i in range(5): - try: - resp = await self.llm_provider.generate(prompt) - logger.debug( - f"[ConvMemCellExtractor] === BOUNDARY DETECTION RESPONSE (attempt {i+1}) ===\n" - f"{resp}\n" - f"[ConvMemCellExtractor] === END RESPONSE ===" - ) + resp = await self.llm_provider.generate(prompt) + logger.debug( + f"[ConvMemCellExtractor] === BOUNDARY DETECTION RESPONSE (attempt {i+1}) ===\n" + f"{resp}\n" + f"[ConvMemCellExtractor] === END RESPONSE ===" + ) - result = self._parse_batch_boundary_response(resp) - if result is not None: - # Validate boundary indices - valid_boundaries = [ - b for b in result.boundaries if 1 <= b < len(messages) - ] - if len(valid_boundaries) != len(result.boundaries): - logger.warning( - f"[ConvMemCellExtractor] Filtered {len(result.boundaries) - len(valid_boundaries)} " - f"out-of-range boundaries (total messages: {len(messages)})" - ) - result.boundaries = sorted(valid_boundaries) - - # Record metrics for the overall detection - detection_result = ( - 'should_end' if result.boundaries else 'should_wait' - ) - record_boundary_detection( - space_id=get_space_id_for_metrics(), - raw_data_type=self.raw_data_type.value, - result=detection_result, - trigger_type='llm', - ) - return result - else: + result = self._parse_batch_boundary_response(resp) + if result is not None: + # Validate boundary indices + valid_boundaries = [ + b for b in result.boundaries if 1 <= b < len(messages) + ] + if len(valid_boundaries) != len(result.boundaries): logger.warning( - f"[ConvMemCellExtractor] Failed to parse JSON from LLM response " - f"(attempt {i + 1}/5), response: {resp[:200]}..." + f"[ConvMemCellExtractor] Filtered {len(result.boundaries) - len(valid_boundaries)} " + f"out-of-range boundaries (total messages: {len(messages)})" ) - continue - except Exception as e: - logger.warning( - f"[ConvMemCellExtractor] Boundary detection error (attempt {i + 1}/5): {e}" + result.boundaries = sorted(valid_boundaries) + + # Record metrics for the overall detection + detection_result = ( + 'should_end' if result.boundaries else 'should_wait' + ) + record_boundary_detection( + space_id=get_space_id_for_metrics(), + raw_data_type=self.raw_data_type.value, + result=detection_result, + trigger_type='llm', ) - continue + return result + + logger.warning( + f"[ConvMemCellExtractor] Failed to parse JSON from LLM response " + f"(attempt {i + 1}/5), response: {resp[:200]}..." + ) # All retries exhausted, raise error to interrupt the flow error_msg = ( diff --git a/methods/evermemos/src/memory_layer/memory_extractor/agent_case_extractor.py b/methods/evermemos/src/memory_layer/memory_extractor/agent_case_extractor.py index b07de310b..317b760b2 100644 --- a/methods/evermemos/src/memory_layer/memory_extractor/agent_case_extractor.py +++ b/methods/evermemos/src/memory_layer/memory_extractor/agent_case_extractor.py @@ -97,11 +97,9 @@ def __init__( max_tool_output_tokens: int = MAX_TOOL_OUTPUT_TOKENS, max_tool_args_tokens: int = MAX_TOOL_ARGS_TOKENS, max_assistant_response_tokens: int = MAX_ASSISTANT_RESPONSE_TOKENS, - extra_body: dict | None = None, ): super().__init__(MemoryType.AGENT_CASE) self.llm_provider = llm_provider - self.extra_body = extra_body self.filter_prompt = filter_prompt or get_prompt_by("AGENT_CASE_FILTER_PROMPT") self.experience_compress_prompt = experience_compress_prompt or get_prompt_by( "AGENT_CASE_COMPRESS_PROMPT" @@ -369,9 +367,7 @@ async def _compress_tool_chunk( for attempt in range(2): try: - resp = await self.llm_provider.generate( - prompt, extra_body=self.extra_body - ) + resp = await self.llm_provider.generate(prompt) data = parse_json_response(resp) if ( data @@ -395,9 +391,7 @@ async def _filter_conversation(self, messages_json: str) -> bool: """LLM-based filter to determine if the conversation is worth extracting.""" prompt = self.filter_prompt.format(messages=messages_json) try: - resp = await self.llm_provider.generate( - prompt, extra_body=self.extra_body - ) + resp = await self.llm_provider.generate(prompt) data = parse_json_response(resp) if data and "worth_extracting" in data: worth = data["worth_extracting"] @@ -418,9 +412,7 @@ async def _compress_experience( for attempt in range(2): try: - resp = await self.llm_provider.generate( - prompt, extra_body=self.extra_body - ) + resp = await self.llm_provider.generate(prompt) data = parse_json_response(resp) if data and "task_intent" in data: if not data["task_intent"]: diff --git a/methods/evermemos/src/memory_layer/memory_extractor/agent_skill_extractor.py b/methods/evermemos/src/memory_layer/memory_extractor/agent_skill_extractor.py index 96e5c9368..9135178fa 100644 --- a/methods/evermemos/src/memory_layer/memory_extractor/agent_skill_extractor.py +++ b/methods/evermemos/src/memory_layer/memory_extractor/agent_skill_extractor.py @@ -26,6 +26,7 @@ from memory_layer.llm.llm_provider import LLMProvider from memory_layer.prompts import get_prompt_by from core.observation.logger import get_logger +from core.observation.stage_timer import timed logger = get_logger(__name__) @@ -113,7 +114,8 @@ def _truncate_text(text: str, max_chars: int = 800) -> str: text = (text or "").strip() if len(text) <= max_chars: return text - return text[: max_chars - 3].rstrip() + "..." + suffix = "... [omitted]" + return text[: max_chars - len(suffix)].rstrip() + suffix def _summarize_case_for_prompt( self, case_record: Any, max_approach_chars: int = 800 @@ -158,8 +160,8 @@ def _format_existing_skills( item: Dict[str, Any] = { "index": idx, "name": rec.name, - "description": rec.description, - "content": rec.content, + "description": self._truncate_text(rec.description, max_chars=1000), + "content": self._truncate_text(rec.content, max_chars=20000), "confidence": rec.confidence, } @@ -219,7 +221,8 @@ async def _select_top_k_skills( if not query_embedding: logger.warning( "[AgentSkillExtractor] Failed to compute query embedding for top-k selection, " - "falling back to first %d skills", top_k, + "falling back to first %d skills", + top_k, ) return existing_records[:top_k] query_vec = query_embedding["embedding"] @@ -242,7 +245,9 @@ async def _select_top_k_skills( logger.info( "[AgentSkillExtractor] Top-k selection: %d/%d skills selected (top_k=%d)", - len(selected), len(existing_records), top_k, + len(selected), + len(existing_records), + top_k, ) return selected @@ -265,18 +270,22 @@ async def _compute_embedding(self, text: str) -> Optional[Dict[str, Any]]: def _select_prompt(self, case_records: List[AgentCase]) -> str: """Select extraction prompt based on the max quality_score of new cases.""" - max_quality = max( - (getattr(rec, "quality_score", 0.5) or 0.5) for rec in case_records - ) if case_records else 0.5 + max_quality = ( + max((getattr(rec, "quality_score", 0.5) or 0.5) for rec in case_records) + if case_records + else 0.5 + ) if max_quality < self.FAILURE_QUALITY_THRESHOLD: logger.debug( "[AgentSkillExtractor] Using failure prompt (max_quality=%.2f < %.2f)", - max_quality, self.FAILURE_QUALITY_THRESHOLD, + max_quality, + self.FAILURE_QUALITY_THRESHOLD, ) return self.failure_extract_prompt logger.debug( "[AgentSkillExtractor] Using success prompt (max_quality=%.2f >= %.2f)", - max_quality, self.FAILURE_QUALITY_THRESHOLD, + max_quality, + self.FAILURE_QUALITY_THRESHOLD, ) return self.success_extract_prompt @@ -285,8 +294,7 @@ async def _call_llm( ) -> Optional[Dict[str, Any]]: """Single LLM call to produce incremental skill operations.""" prompt = prompt_template.format( - new_case_json=new_case_json, - existing_skills_json=existing_skills_json, + new_case_json=new_case_json, existing_skills_json=existing_skills_json ) for attempt in range(3): try: @@ -302,17 +310,16 @@ async def _call_llm( return None async def _evaluate_maturity( - self, - name: str, - description: str, - content: str, - confidence: float, + self, name: str, description: str, content: str, confidence: float ) -> Optional[float]: """Evaluate maturity of a skill via LLM scoring. Scores the skill across 4 dimensions (1-5 each, total out of 20), then normalizes to 0.0-1.0. """ + if self.skip_maturity_scoring: + logger.info("[AgentSkillExtractor] Maturity scoring skipped by config, returning 1.0") + return 1.0 try: prompt = self.maturity_prompt.format( name=name or "", @@ -334,14 +341,16 @@ async def _evaluate_maturity( logger.info( "[AgentSkillExtractor] Maturity evaluation: name='%s', " "raw=%.1f, score=%.2f, threshold=%.2f, ready=%s, reason=%s", - name, raw_total, score, self.maturity_threshold, - score >= self.maturity_threshold, data.get("reason", ""), + name, + raw_total, + score, + self.maturity_threshold, + score >= self.maturity_threshold, + data.get("reason", ""), ) return score except Exception as e: - logger.warning( - "[AgentSkillExtractor] Maturity evaluation failed: %s", e - ) + logger.warning("[AgentSkillExtractor] Maturity evaluation failed: %s", e) return None # Content change ratio below which maturity re-evaluation is always skipped @@ -359,9 +368,7 @@ def _is_hypothesis_promotion(old_content: str, new_content: str) -> bool: old_has_potential = bool( re.search(r"^##\s+Potential Steps", old_content or "", re.MULTILINE) ) - new_has_steps = bool( - re.search(r"^##\s+Steps", new_content or "", re.MULTILINE) - ) + new_has_steps = bool(re.search(r"^##\s+Steps", new_content or "", re.MULTILINE)) new_has_potential = bool( re.search(r"^##\s+Potential Steps", new_content or "", re.MULTILINE) ) @@ -406,7 +413,9 @@ async def _apply_add( data = op.get("data", {}) content = data.get("content", "") if not content: - logger.warning("[AgentSkillExtractor] add operation has empty content, skipping") + logger.warning( + "[AgentSkillExtractor] add operation has empty content, skipping" + ) return None if not self._is_skill_content_sufficient(content): @@ -420,7 +429,9 @@ async def _apply_add( name = data.get("name", "") description = data.get("description", "") if not name and not description: - logger.warning("[AgentSkillExtractor] add operation has no name and no description, skipping") + logger.warning( + "[AgentSkillExtractor] add operation has no name and no description, skipping" + ) return None try: @@ -435,19 +446,9 @@ async def _apply_add( AgentSkillRecord, ) - if self.skip_maturity_scoring: - score = 1.0 - logger.debug( - "[AgentSkillExtractor] Skipping maturity scoring for new skill '%s', " - "using default score=1.0", name, - ) - else: - score = await self._evaluate_maturity( - name=name, - description=description, - content=content, - confidence=confidence, - ) + score = await self._evaluate_maturity( + name=name, description=description, content=content, confidence=confidence + ) record = AgentSkillRecord( cluster_id=cluster_id, @@ -529,7 +530,8 @@ async def _apply_update( logger.warning( "[AgentSkillExtractor] update operation for index %d has insufficient content " "(too short or no steps), skipping. content=%r", - index, new_content[:100], + index, + new_content[:100], ) return False @@ -570,7 +572,10 @@ async def _apply_update( logger.warning( "[AgentSkillExtractor] Retiring skill[%d] (confidence=%.2f < %.2f): " "id=%s, name=%r", - index, final_confidence, self.retire_confidence, record_id, + index, + final_confidence, + self.retire_confidence, + record_id, getattr(record, "name", ""), ) retire_updates: Dict[str, Any] = {"confidence": final_confidence} @@ -586,7 +591,9 @@ async def _apply_update( # Re-embed only if name or description actually changed name_changed = bool(new_name) and new_name != (record.name or "") - desc_changed = bool(new_description) and new_description != (record.description or "") + desc_changed = bool(new_description) and new_description != ( + record.description or "" + ) if name_changed or desc_changed: effective_name = new_name or record.name or "" effective_desc = new_description or record.description or "" @@ -605,15 +612,11 @@ async def _apply_update( # - mature (>= threshold) AND confidence not dropping: skip # - immature (< threshold) AND case quality < 0.3: skip (low-quality case won't help) # - otherwise: re-score via LLM - real_content_changed = (bool(new_content) and new_content != (record.content or "")) + real_content_changed = bool(new_content) and new_content != ( + record.content or "" + ) content_changed = real_content_changed or name_changed or desc_changed - if self.skip_maturity_scoring and content_changed: - updates["maturity_score"] = 1.0 - logger.debug( - "[AgentSkillExtractor] Skipping maturity re-evaluation for skill[%d], " - "using default score=1.0", index, - ) - elif content_changed: + if content_changed: change_ratio = self._content_change_ratio( record.content or "", new_content or record.content or "" ) @@ -623,7 +626,9 @@ async def _apply_update( logger.info( "[AgentSkillExtractor] Skipping maturity re-evaluation for skill[%d]: " "trivial change_ratio=%.2f < %.2f", - index, change_ratio, self.MATURITY_TRIVIAL_CHANGE_RATIO, + index, + change_ratio, + self.MATURITY_TRIVIAL_CHANGE_RATIO, ) # 2) Major change (>= 40%) or hypothesis promotion: always LLM elif ( @@ -634,14 +639,19 @@ async def _apply_update( ): reason = ( "hypothesis promotion" - if self._is_hypothesis_promotion(record.content or "", new_content or "") + if self._is_hypothesis_promotion( + record.content or "", new_content or "" + ) else f"major content change (ratio={change_ratio:.2f})" ) logger.info( "[AgentSkillExtractor] %s for skill[%d], using LLM maturity evaluation", - reason, index, + reason, + index, + ) + await self._rescore_maturity( + updates, new_name, new_description, new_content, record ) - await self._rescore_maturity(updates, new_name, new_description, new_content, record) # 3) Moderate change (20~40%) else: old_score = record.maturity_score or 0.0 @@ -656,14 +666,22 @@ async def _apply_update( logger.info( "[AgentSkillExtractor] Skipping maturity re-evaluation for skill[%d]: " "already mature (%.2f >= %.2f), confidence=%.2f (dropping=%s), change_ratio=%.2f", - index, old_score, self.maturity_threshold, new_confidence_val, confidence_dropping, change_ratio, + index, + old_score, + self.maturity_threshold, + new_confidence_val, + confidence_dropping, + change_ratio, ) elif old_score < self.maturity_threshold and source_quality < 0.3: # Immature but low-quality case won't improve it: skip logger.info( "[AgentSkillExtractor] Skipping maturity re-evaluation for skill[%d]: " "immature (%.2f) but low source quality (%.2f < 0.3), change_ratio=%.2f", - index, old_score, source_quality, change_ratio, + index, + old_score, + source_quality, + change_ratio, ) else: # Re-score: immature skill with decent case, or mature but confidence dropping @@ -671,9 +689,14 @@ async def _apply_update( "[AgentSkillExtractor] Moderate change for skill[%d]: " "score=%.2f, confidence_dropping=%s, source_quality=%.2f, " "using LLM maturity evaluation", - index, old_score, confidence_dropping, source_quality, + index, + old_score, + confidence_dropping, + source_quality, + ) + await self._rescore_maturity( + updates, new_name, new_description, new_content, record ) - await self._rescore_maturity(updates, new_name, new_description, new_content, record) success = await skill_repo.update_skill_by_id(record_id, updates) if success: @@ -773,11 +796,10 @@ async def extract_and_save( f"[AgentSkillExtractor] {len(existing_skill_records)} existing skills exceed " f"max_skills_in_prompt={max_skills_in_prompt}, selecting top-k" ) - existing_skill_records = await self._select_top_k_skills( - existing_skill_records, - new_case_records, - top_k=max_skills_in_prompt, - ) + with timed("select_top_k_skills"): + existing_skill_records = await self._select_top_k_skills( + existing_skill_records, new_case_records, top_k=max_skills_in_prompt + ) # Load case history AFTER top-k selection so we only load cases # relevant to the skills that will actually appear in the prompt. @@ -796,7 +818,10 @@ async def extract_and_save( f"new_cases={len(new_case_records)}, existing_skills={len(existing_skill_records)}" ) - llm_result = await self._call_llm(new_case_json, existing_skills_json, prompt_template) + with timed("extract_skill_ops"): + llm_result = await self._call_llm( + new_case_json, existing_skills_json, prompt_template + ) if not llm_result: logger.warning( f"[AgentSkillExtractor] LLM extraction failed for cluster={cluster_id}" @@ -817,53 +842,65 @@ async def extract_and_save( update_count = 0 processed_indices: set = set() - for op in operations: - action = op.get("action", "none") - - if action == "add": - saved = await self._apply_add( - op, cluster_id, group_id, user_id, skill_repo, - source_case_ids=source_case_ids, - ) - if saved: - result.added_records.append(saved) - - elif action == "update": - try: - index = int(op.get("index", -1)) - except (ValueError, TypeError): - logger.warning( - f"[AgentSkillExtractor] update index is not a valid integer: {op.get('index')!r}, skipping" + with timed("apply_operations"): + for op in operations: + action = op.get("action", "none") + + if action == "add": + saved = await self._apply_add( + op, + cluster_id, + group_id, + user_id, + skill_repo, + source_case_ids=source_case_ids, ) - continue - if index in processed_indices: - logger.warning( - f"[AgentSkillExtractor] Duplicate operation on index {index}, skipping update" + if saved: + result.added_records.append(saved) + + elif action == "update": + try: + index = int(op.get("index", -1)) + except (ValueError, TypeError): + logger.warning( + f"[AgentSkillExtractor] update index is not a valid integer: {op.get('index')!r}, skipping" + ) + continue + if index in processed_indices: + logger.warning( + f"[AgentSkillExtractor] Duplicate operation on index {index}, skipping update" + ) + continue + processed_indices.add(index) + # Pass max quality_score from new cases for maturity decision + max_quality = ( + max( + (getattr(rec, "quality_score", 0.5) or 0.5) + for rec in new_case_records + ) + if new_case_records + else 0.5 ) - continue - processed_indices.add(index) - # Pass max quality_score from new cases for maturity decision - max_quality = max( - (getattr(rec, "quality_score", 0.5) or 0.5) - for rec in new_case_records - ) if new_case_records else 0.5 - success = await self._apply_update( - op, existing_skill_records, skill_repo, result, - source_case_ids=source_case_ids, - source_quality=max_quality, - ) - if success: - update_count += 1 + success = await self._apply_update( + op, + existing_skill_records, + skill_repo, + result, + source_case_ids=source_case_ids, + source_quality=max_quality, + ) + if success: + update_count += 1 - elif action == "none": - logger.debug( - f"[AgentSkillExtractor] No-op for cluster={cluster_id}" - ) + elif action == "none": + logger.debug( + f"[AgentSkillExtractor] No-op for cluster={cluster_id}" + ) - else: - logger.warning( - f"[AgentSkillExtractor] Unknown action '{action}', skipping" - ) + else: + logger.warning( + f"[AgentSkillExtractor] Unknown action '{action}', skipping" + ) logger.info( f"[AgentSkillExtractor] cluster={cluster_id} operations applied: " diff --git a/methods/evermemos/src/memory_layer/memory_extractor/episode_memory_extractor.py b/methods/evermemos/src/memory_layer/memory_extractor/episode_memory_extractor.py index 8339ca223..6f346a920 100644 --- a/methods/evermemos/src/memory_layer/memory_extractor/episode_memory_extractor.py +++ b/methods/evermemos/src/memory_layer/memory_extractor/episode_memory_extractor.py @@ -75,7 +75,6 @@ def __init__( """ super().__init__(MemoryType.EPISODIC_MEMORY) self.llm_provider = llm_provider - self.extra_body: dict | None = None self.default_parent_type = DEFAULT_MEMORIZE_CONFIG.default_episode_parent_type # Use custom prompts or get default via PromptManager @@ -268,9 +267,7 @@ async def _extract_episode( for i in range(5): try: prompt = prompt_template.format(**format_params) - response = await self.llm_provider.generate( - prompt, extra_body=self.extra_body - ) + response = await self.llm_provider.generate(prompt) # Parse JSON if '```json' in response: diff --git a/methods/evermemos/src/memory_layer/memory_manager.py b/methods/evermemos/src/memory_layer/memory_manager.py index 942b35d87..e496f5c54 100644 --- a/methods/evermemos/src/memory_layer/memory_manager.py +++ b/methods/evermemos/src/memory_layer/memory_manager.py @@ -44,10 +44,6 @@ from memory_layer.memcell_extractor.base_memcell_extractor import StatusResult from api_specs.memory_models import MessageSenderRole from memory_layer.constants import EXTRACT_SCENES -from biz_layer.memorize_config import ( - DEFAULT_MEMORIZE_CONFIG, - AGENT_DEFAULT_MEMORIZE_CONFIG, -) logger = get_logger(__name__) @@ -162,15 +158,6 @@ def _get_provider_for_scene(self, scene: str) -> LLMProvider: provider = self.providers_mapping.get(DEFAULT_PROVIDER_NAME) return provider - @staticmethod - def _get_skip_reasoning_extra_body(memcell: MemCell) -> dict | None: - """Return extra_body to disable reasoning if config requires it.""" - is_agent = memcell and memcell.type == RawDataType.AGENTCONVERSATION - config = AGENT_DEFAULT_MEMORIZE_CONFIG if is_agent else DEFAULT_MEMORIZE_CONFIG - if config.skip_episode_case_reasoning: - return {"chat_template_kwargs": {"enable_thinking": False}} - return None - # TODO: add username async def extract_memcell( self, @@ -334,7 +321,6 @@ async def _extract_episode( self._episode_extractor = EpisodeMemoryExtractor( self._get_provider_for_scene("extraction") ) - self._episode_extractor.extra_body = self._get_skip_reasoning_extra_body(memcell) # Build extraction request from memory_layer.memory_extractor.base_memory_extractor import ( @@ -475,8 +461,7 @@ async def _extract_agent_case( logger.debug("[MemoryManager] Extracting AgentCase") extractor = AgentCaseExtractor( - llm_provider=self._get_provider_for_scene("extraction"), - extra_body=self._get_skip_reasoning_extra_body(memcell), + llm_provider=self._get_provider_for_scene("extraction") ) request = AgentCaseExtractRequest( memcell=memcell, diff --git a/methods/evermemos/src/memory_layer/profile_manager/manager.py b/methods/evermemos/src/memory_layer/profile_manager/manager.py index 03fea6e72..fdf05120a 100644 --- a/methods/evermemos/src/memory_layer/profile_manager/manager.py +++ b/methods/evermemos/src/memory_layer/profile_manager/manager.py @@ -76,7 +76,6 @@ async def extract_profiles( group_id: Optional[str] = None, max_items: int = 25, scene: ScenarioType = ScenarioType.SOLO, - new_user_max_context: int = 0, ) -> List[ProfileMemory]: """Extract profiles from memcells (batch multi-user). @@ -156,19 +155,11 @@ async def extract_profiles( if old_profile and old_profile.last_updated: user_baseline = old_profile.last_updated else: - # New user: use the Nth oldest cluster memcell as baseline - # so that at most new_user_max_context episodes are included - max_cluster = new_user_max_context - 1 - if max_cluster > 0 and len(cluster_contexts) > max_cluster: - user_baseline = cluster_contexts[-(max_cluster + 1)].get("created_at") - elif max_cluster > 0 and cluster_contexts: - user_baseline = None - else: - user_baseline = new_context.get("created_at") + user_baseline = new_context.get("created_at") user_cluster_episodes = [ ep for ep in cluster_contexts - if user_baseline is None or ep.get("created_at") is None or ep.get("created_at") > user_baseline + if ep.get("created_at") is None or ep.get("created_at") > user_baseline ] # Build request diff --git a/methods/evermemos/src/memory_layer/prompts/__init__.py b/methods/evermemos/src/memory_layer/prompts/__init__.py index 369fceb1f..453b1a145 100644 --- a/methods/evermemos/src/memory_layer/prompts/__init__.py +++ b/methods/evermemos/src/memory_layer/prompts/__init__.py @@ -109,7 +109,12 @@ "zh": ("memory_layer.prompts.zh.agent_prompts", False), }, # Clustering - "CLUSTER_LLM_ASSIGNMENT_PROMPT": { + "AGENT_CLUSTER_LLM_ASSIGN_PROMPT": { + "en": ("memory_layer.prompts.en.agent_prompts", False), + "zh": ("memory_layer.prompts.zh.agent_prompts", False), + }, + # Skill relevance verification + "AGENT_SKILL_RELEVANCE_VERIFY_PROMPT": { "en": ("memory_layer.prompts.en.agent_prompts", False), "zh": ("memory_layer.prompts.zh.agent_prompts", False), }, diff --git a/methods/evermemos/src/memory_layer/prompts/en/agent_prompts.py b/methods/evermemos/src/memory_layer/prompts/en/agent_prompts.py index 3572f6ac7..4940bf177 100644 --- a/methods/evermemos/src/memory_layer/prompts/en/agent_prompts.py +++ b/methods/evermemos/src/memory_layer/prompts/en/agent_prompts.py @@ -163,15 +163,20 @@ **What makes a GOOD skill:** - Reasoning principles WITH concrete patterns: teaches HOW to think, not just what to do - Decision branches that cover the different problem variants seen across cases -- Examples preserve real entity names, numbers, and scenarios from cases +- A FEW well-chosen examples that illustrate distinct branches — not an exhaustive catalog **What makes a BAD skill:** - Too abstract: "Analyze constraints" without showing what analysis looks like in practice - Too narrow: A single solution template that only works for one exact case +- **Bloated**: Listing dozens of case-specific details (names, dates, institutions, compounds, etc.) inside parentheses or comma-separated lists. Each How/Decision/e.g. field should contain 1-2 illustrative examples, NOT an inventory of every case seen **Field-level requirements:** -- **description** (max 150 tokens): One-sentence summary of what this skill solves + trigger scenarios. Append `Keywords:` with terms an agent would use when facing this problem class. +- **description** (HARD LIMIT: max 150 tokens, must be under 500 characters): + - One-sentence summary of the **abstract problem class** this skill solves — describe the general pattern, NOT specific cases. + - Do NOT list multiple scenarios, entity names, or case-specific details. + - Append `Keywords:` with up to 10 general terms (no specific names, numbers, or case-specific phrases). + - Example: "Identifies academic researchers by cross-referencing biographical constraints with publication records. Keywords: researcher identification, biographical verification, publication matching, academic search" - **content** (max 2000 tokens): Markdown format: ```markdown @@ -189,20 +194,21 @@ **HARD RULES for content:** - **Max 5 steps.** - - **Max 2 examples per step.** Each example MUST illustrate a distinct decision branch, using real entities/numbers/scenarios from the cases. - - **Decision branches**: REQUIRED when the next action depends on what was found. For linear steps with no branching, Decision may be omitted. + - **Max 2 examples per step.** Each example MUST be a SHORT, single-sentence illustration of a distinct decision branch. Do NOT list multiple sub-examples inside parentheses or comma-separated lists. + - **Decision branches**: REQUIRED when the next action depends on what was found. For linear steps with no branching, Decision may be omitted. Each Decision should have at most 3 branches. - **Max 4 pitfalls.** When adding a new one beyond 4, replace the most generic existing pitfall. + - **No parenthetical catalogs**: FORBIDDEN to stuff dozens of case-specific terms (names, dates, compounds, institutions, etc.) inside a single parenthetical `(e.g., X, Y, Z, ...)`. Keep each field concise — generalize the pattern, illustrate with 1-2 examples only. -【New AgentCase(s) to integrate】 +[New AgentCase(s) to integrate] {new_case_json} -【Existing skills for this cluster】(Each item has an index number) +[Existing skills for this cluster](Each item has an index number) {existing_skills_json} -【Task】 +[Task] Analyze the new case(s) and output a list of operations (add / update / none). -【Operation Guide — follow in order】 +[Operation Guide — follow in order] **Step 1: Overlap Check (mandatory before every add/update decision)** For each new case, compare against each existing skill: @@ -221,12 +227,14 @@ - **update**: The new case overlaps an existing skill (coverage >= 60%). Enrich it with new Decision branches, better examples, or sharper How explanations. - You MAY substantially rewrite content (restructure steps, replace examples, refine How explanations), but **preserve existing verified content unless the new case directly contradicts it**. + - **CRITICAL: The updated content MUST stay within 2000 tokens. Do NOT simply append new content — replace weaker examples with stronger ones, merge redundant steps, and compress prose. If the existing content is already long, aggressively condense it while preserving the core logic.** + - **CRITICAL: The updated description MUST stay under 500 characters. Generalize — do NOT accumulate case-specific details.** - **Hypothesis promotion rule**: If the existing skill contains `## Potential Steps`, treat this update as a **promotion** — rewrite as `## Steps` using the new case as primary source. confidence = `0.6`. - **Confidence-only update**: If the new case merely confirms the existing skill without adding new decision logic or better examples, bump confidence only. - **none**: Trivially duplicate — no new decision branches, no new examples worth keeping, no confidence change needed. -【Confidence Anchoring Rules】 +[Confidence Anchoring Rules] - **New skill (add)**: confidence = `0.5` - **Promoted skill (hypothesis → verified)**: confidence = `0.6` - **Update with new decision branch**: confidence = existing + `0.1` (cap 0.95) @@ -235,7 +243,7 @@ **CRITICAL LANGUAGE RULE**: Output in the SAME language as the input conversation content. -【Output Format】 +[Output Format] ```json {{ "operations": [ @@ -264,10 +272,10 @@ **Field-level requirements:** -- **description** (max 200 tokens): Must include three parts: - 1. A one-sentence summary of the problem class and the known failure patterns - 2. **Use cases**: 2-3 brief trigger scenarios - 3. **Keywords**: Include concrete case phrases, failure symptom terms, and tool names. Format: `Keywords: term1, term2, term3, ...` +- **description** (HARD LIMIT: max 150 tokens, must be under 500 characters): + - One-sentence summary of the **abstract problem class** and the known failure patterns — describe the general pattern, NOT specific cases. + - Do NOT list multiple scenarios, entity names, or case-specific details. + - Append `Keywords:` with up to 10 general terms (no specific names, numbers, or case-specific phrases). - **content**: Output in **Markdown format** using this template: ```markdown @@ -290,23 +298,26 @@ - **Potential Steps**: Include ONLY steps with demonstrable forward progress. If NO steps clearly progressed, omit the numbered list and keep only the `> Extracted from...` note. - **Pitfalls**: MUST be included and populated. Every failed case must contribute at least one specific, traceable pitfall. FORBIDDEN: generic warnings, speculative risks, best-practice reminders not directly traceable to a failure in this case. -【New AgentCase(s) to integrate】 +[New AgentCase(s) to integrate] {new_case_json} -【Existing skills for this cluster】(Each item has an index number) +[Existing skills for this cluster](Each item has an index number) {existing_skills_json} -【Task】 +[Task] Analyze the failed case(s) and output operations (add / update / none). -【Operation Guide】 +[Operation Guide] - **update**: If an existing skill covers the same problem class, integrate failure insights by index: - If existing skill has `## Steps` (verified): preserve Steps intact — only append new entries to `## Pitfalls`. - If existing skill has `## Potential Steps` (hypothesis): you may also enrich `## Potential Steps` with any steps from this case that demonstrably succeeded, in addition to appending to `## Pitfalls`. + - **CRITICAL: The updated content MUST stay within 2000 tokens. Do NOT simply append — if Pitfalls exceed 4 entries, replace the most generic one. If Potential Steps are already sufficient, do NOT add redundant ones. Aggressively condense existing content if it is already long.** + - **CRITICAL: The updated description MUST stay under 500 characters. Generalize — do NOT accumulate case-specific details.** + - **No parenthetical catalogs**: FORBIDDEN to stuff dozens of case-specific terms (names, dates, compounds, etc.) inside parentheses. Keep each field concise — generalize the pattern, illustrate with 1-2 examples only. - **add**: If no existing skill covers this problem class, create a new skill using the Potential Steps + Pitfalls template above. - **none**: The case is completely irrelevant to all existing skills and too isolated to form a useful pattern. Use very sparingly. -【Confidence Anchoring Rules】 +[Confidence Anchoring Rules] - **New skill (add)**: confidence = `0.5` - **Update existing skill with pitfall only**: confidence unchanged (failure insight doesn't validate the SOP steps). - **Update existing hypothesis skill with new Potential Steps**: confidence = existing + 0.05 (slight bump for additional partial evidence). @@ -314,7 +325,7 @@ **CRITICAL LANGUAGE RULE**: Output in the SAME language as the input conversation content. -【Output Format】 +[Output Format] No operations: ```json {{"operations": [{{"action": "none"}}], "update_note": "failed case adds no new failure patterns to existing skills"}} @@ -324,7 +335,7 @@ ```json {{ "operations": [ - {{"action": "add", "data": {{"name": "Short descriptive name (max 10 words)", "description": "One-sentence summary of problem class. Use when: scenario1; scenario2. Keywords: term1, term2 (max 150 tokens)", "content": "## Potential Steps\\n> Extracted from a failed case. Only steps that demonstrably progressed correctly are listed.\\n1. \\n - How: \\n - e.g., ``\\n - Check: \\n\\n## Pitfalls\\n- ", "confidence": 0.5}}}}, + {{"action": "add", "data": {{"name": "Short descriptive name (max 10 words)", "description": "One-sentence abstract summary of problem class. Keywords: term1, term2 (max 150 tokens, under 500 chars)", "content": "## Potential Steps\\n> Extracted from a failed case. Only steps that demonstrably progressed correctly are listed.\\n1. \\n - How: \\n - e.g., ``\\n - Check: \\n\\n## Pitfalls\\n- ", "confidence": 0.5}}}}, {{"action": "update", "index": 0, "data": {{"content": "## Steps\\n\\n\\n## Pitfalls\\n\\n- "}}}} ], "update_note": "added pitfall from failed case to skill[0]; created new skill from partial exploration" @@ -356,35 +367,6 @@ {{"results": [{{"index": 0, "score": 0.85, "reason": "brief reason"}}, {{"index": 1, "score": 0.15, "reason": "brief reason"}}]}} """ -CLUSTER_LLM_ASSIGNMENT_PROMPT = """You are a clustering expert. You will receive a batch of new memory items and a list of existing clusters (each described by its most recent episodes). Your task is to assign each new item to either an existing cluster or a new cluster. - -Two items belong in the same cluster if they are about the **same topic, domain, or recurring theme**. Focus on WHAT the content is about, not surface-level wording. - -【Existing Clusters】 -Each cluster is represented by its cluster_id, item_count, and up to 3 most recent episodes. -If this list is empty, carefully distinguish the new items below and create different clusters for them. -{existing_clusters_json} - -【New Items to Classify】 -{new_items_json} - -【Rules】 -1. For each new item, decide: assign to an existing cluster (by cluster_id) OR create a new cluster. -2. If multiple new items belong together AND no existing cluster fits, group them into the same NEW cluster. -3. Use numeric IDs for new clusters starting from {next_new_id} (e.g., "cluster_{next_new_id:03d}", "cluster_{next_new_id_plus1:03d}", ...). -4. Every item_index must appear in exactly one assignment. -5. Be specific enough to separate genuinely different topics, but do NOT over-split. Items about slightly different sub-topics within the same domain should share a cluster. -6. Keep each "reason" field to 20 tokens or fewer. - -Return ONLY valid JSON (no markdown fences, no explanation): -{{ - "assignments": [ - {{"item_index": 0, "cluster_id": "", "reason": "short reason (max 20 tokens)"}}, - {{"item_index": 1, "cluster_id": "cluster_{next_new_id:03d}", "reason": "new topic not in existing clusters"}}, - ... - ] -}}""" - AGENT_SKILL_MATURITY_SCORE_PROMPT = """You are a quality evaluator for agent skill documents (SOPs). Skills come in two forms — detect which type before scoring: @@ -435,3 +417,29 @@ Return ONLY valid JSON (no markdown fences): {{"completeness": 1-5, "executability": 1-5, "evidence": 1-5, "clarity": 1-5, "reason": "brief justification for the scores"}} """ + +AGENT_CLUSTER_LLM_ASSIGN_PROMPT = """You are a clustering expert. Your goal is to group similar and related tasks together so that patterns and reusable strategies can be extracted from each cluster. Assign the new task intent to an existing cluster, or create a new one if no existing cluster fits. + +[How to decide] +The goal of clustering is to group cases that would produce a **specific, actionable skill** — not generic advice. Use this test: "Would an agent who solved one task in this cluster have a **concrete advantage** (reusable tools, domain knowledge, verified strategies) when facing the other tasks?" + +1. **Identify two dimensions**: the task's **subject domain** (e.g., medical research, urban planning, e-commerce) and its **problem-solving pattern** (e.g., root cause analysis, constraint satisfaction, data pipeline design). +2. **Cluster by the more specific dimension**. If the domain is already narrow (e.g., "clinical trial data extraction"), domain alone is enough. If the domain is broad (e.g., "software engineering"), use the problem-solving pattern to differentiate (e.g., "performance profiling" vs. "schema migration"). +3. **Do NOT merge across unrelated domains just because the strategy is similar.** "Diagnose a patient's symptoms via differential diagnosis" and "diagnose a supply chain bottleneck via constraint analysis" both use diagnostic reasoning, but involve completely different domain knowledge and belong in separate clusters. +4. Scan candidate clusters. Prefer the cluster whose existing items would **benefit most from sharing a skill** with the new task. +5. Create a new cluster only when no candidate cluster is a good fit. + +[Candidate Clusters] +Each cluster is represented by its cluster_id, item_count, and most recent task intents. +{clusters_json} + +[New Task Intent] +{memcell_text} + +[Rules] +- Output decision as JSON. Keep "reason" under 50 tokens. +- To assign: use an existing cluster_id. To create new: use "cluster_{next_new_id}". + +Return ONLY valid JSON (no markdown fences, no explanation): +{{"cluster_id": "", "reason": "short reason"}} +""" diff --git a/methods/evermemos/src/memory_layer/prompts/en/cluster_prompts.py b/methods/evermemos/src/memory_layer/prompts/en/cluster_prompts.py new file mode 100644 index 000000000..ba2a87390 --- /dev/null +++ b/methods/evermemos/src/memory_layer/prompts/en/cluster_prompts.py @@ -0,0 +1 @@ +# Cluster prompts moved to agent_prompts.py diff --git a/methods/evermemos/src/memory_layer/prompts/zh/agent_prompts.py b/methods/evermemos/src/memory_layer/prompts/zh/agent_prompts.py index f3f6bf470..48cb847bc 100644 --- a/methods/evermemos/src/memory_layer/prompts/zh/agent_prompts.py +++ b/methods/evermemos/src/memory_layer/prompts/zh/agent_prompts.py @@ -7,7 +7,7 @@ AGENT_SKILL_FAILURE_EXTRACT_PROMPT, AGENT_SKILL_RELEVANCE_VERIFY_PROMPT, AGENT_SKILL_MATURITY_SCORE_PROMPT, - CLUSTER_LLM_ASSIGNMENT_PROMPT, + AGENT_CLUSTER_LLM_ASSIGN_PROMPT, ) __all__ = [ @@ -18,5 +18,5 @@ "AGENT_SKILL_FAILURE_EXTRACT_PROMPT", "AGENT_SKILL_RELEVANCE_VERIFY_PROMPT", "AGENT_SKILL_MATURITY_SCORE_PROMPT", - "CLUSTER_LLM_ASSIGNMENT_PROMPT", + "AGENT_CLUSTER_LLM_ASSIGN_PROMPT", ] diff --git a/methods/evermemos/src/memory_layer/prompts/zh/cluster_prompts.py b/methods/evermemos/src/memory_layer/prompts/zh/cluster_prompts.py new file mode 100644 index 000000000..ba2a87390 --- /dev/null +++ b/methods/evermemos/src/memory_layer/prompts/zh/cluster_prompts.py @@ -0,0 +1 @@ +# Cluster prompts moved to agent_prompts.py diff --git a/methods/evermemos/src/service/memcell_delete_service.py b/methods/evermemos/src/service/memcell_delete_service.py index b7a3fff69..871d5c5d9 100644 --- a/methods/evermemos/src/service/memcell_delete_service.py +++ b/methods/evermemos/src/service/memcell_delete_service.py @@ -199,6 +199,9 @@ async def _gather_deletes(self, *tasks: tuple[str, Any, dict]) -> dict[str, int] names = [t[0] for t in tasks] coros = [t[1].delete_by_filters(**t[2]) for t in tasks] results = await asyncio.gather(*coros, return_exceptions=True) + from common_utils.async_utils import reraise_critical_errors + + reraise_critical_errors(results) counts: dict[str, int] = {} for name, result in zip(names, results): if isinstance(result, Exception): diff --git a/methods/evermemos/tests/test_agent_converters_and_pipeline.py b/methods/evermemos/tests/test_agent_converters_and_pipeline.py index 8ffa6b7b7..73e8ae360 100644 --- a/methods/evermemos/tests/test_agent_converters_and_pipeline.py +++ b/methods/evermemos/tests/test_agent_converters_and_pipeline.py @@ -954,8 +954,8 @@ def get_bean_side_effect(cls): await memorize_mod._trigger_agent_skill_extraction( group_id="g1", cluster_id="cluster_001", - user_id=agent_case.user_id, - agent_cases=[agent_case], + memcell=memcell, + agent_case=agent_case, ) # Verify extractor was called @@ -999,7 +999,7 @@ async def test_extraction_exception_handled(self): # Should not raise - exception is caught in the outer try-except await memorize_mod._trigger_agent_skill_extraction( - group_id="g1", cluster_id="c1", user_id=agent_case.user_id, agent_cases=[agent_case] + group_id="g1", cluster_id="c1", memcell=memcell, agent_case=agent_case ) @@ -1507,7 +1507,7 @@ async def test_lock_not_acquired_skips_extraction(self): importlib.reload(memorize_mod) await memorize_mod._trigger_agent_skill_extraction( - group_id="g1", cluster_id="c1", user_id=agent_case.user_id, agent_cases=[agent_case] + group_id="g1", cluster_id="c1", memcell=memcell, agent_case=agent_case ) @pytest.mark.asyncio @@ -1571,7 +1571,7 @@ async def test_updated_records_delete_old_then_insert_new(self): importlib.reload(memorize_mod) await memorize_mod._trigger_agent_skill_extraction( - group_id="g1", cluster_id="c1", user_id=agent_case.user_id, agent_cases=[agent_case] + group_id="g1", cluster_id="c1", memcell=memcell, agent_case=agent_case ) updated_id = str(updated_record.id) @@ -1641,7 +1641,7 @@ async def test_record_without_vector_skipped_in_milvus(self): importlib.reload(memorize_mod) await memorize_mod._trigger_agent_skill_extraction( - group_id="g1", cluster_id="c1", user_id=agent_case.user_id, agent_cases=[agent_case] + group_id="g1", cluster_id="c1", memcell=memcell, agent_case=agent_case ) mock_milvus_repo.insert.assert_not_called() @@ -1709,7 +1709,7 @@ async def test_milvus_failure_does_not_block_es(self): importlib.reload(memorize_mod) await memorize_mod._trigger_agent_skill_extraction( - group_id="g1", cluster_id="c1", user_id=agent_case.user_id, agent_cases=[agent_case] + group_id="g1", cluster_id="c1", memcell=memcell, agent_case=agent_case ) mock_es_repo.create.assert_called_once() @@ -1776,7 +1776,7 @@ async def test_es_failure_does_not_raise(self): importlib.reload(memorize_mod) await memorize_mod._trigger_agent_skill_extraction( - group_id="g1", cluster_id="c1", user_id=agent_case.user_id, agent_cases=[agent_case] + group_id="g1", cluster_id="c1", memcell=memcell, agent_case=agent_case ) @pytest.mark.asyncio @@ -1838,7 +1838,7 @@ async def test_deleted_ids_removed_from_search_engines(self): importlib.reload(memorize_mod) await memorize_mod._trigger_agent_skill_extraction( - group_id="g1", cluster_id="c1", user_id=agent_case.user_id, agent_cases=[agent_case] + group_id="g1", cluster_id="c1", memcell=memcell, agent_case=agent_case ) assert mock_milvus_repo.delete_by_id.call_count == 2 @@ -1904,7 +1904,7 @@ async def test_no_changes_skips_sync(self): importlib.reload(memorize_mod) await memorize_mod._trigger_agent_skill_extraction( - group_id="g1", cluster_id="c1", user_id=agent_case.user_id, agent_cases=[agent_case] + group_id="g1", cluster_id="c1", memcell=memcell, agent_case=agent_case ) mock_milvus_repo.insert.assert_not_called() @@ -2000,8 +2000,8 @@ async def test_no_group_episode_skips_clustering(self): @pytest.mark.asyncio async def test_agent_conversation_uses_agent_config(self): - """Agent conversations should use AGENT_DEFAULT_MEMORIZE_CONFIG.""" - from biz_layer.memorize_config import AGENT_DEFAULT_MEMORIZE_CONFIG + """Agent conversations should use DEFAULT_MEMORIZE_CONFIG.""" + from biz_layer.memorize_config import DEFAULT_MEMORIZE_CONFIG import biz_layer.mem_memorize as mod memcell = _make_agent_memcell_for_trigger() @@ -2023,7 +2023,7 @@ async def test_agent_conversation_uses_agent_config(self): await mod._update_memcell_and_cluster(state) call_kwargs = mock_cluster.call_args.kwargs - assert call_kwargs["config"] is AGENT_DEFAULT_MEMORIZE_CONFIG + assert call_kwargs["config"] is DEFAULT_MEMORIZE_CONFIG finally: mod._trigger_clustering = original_trigger diff --git a/methods/evermemos/tests/test_agent_skill_extractor.py b/methods/evermemos/tests/test_agent_skill_extractor.py index a6dc73416..97cb55802 100644 --- a/methods/evermemos/tests/test_agent_skill_extractor.py +++ b/methods/evermemos/tests/test_agent_skill_extractor.py @@ -1853,7 +1853,7 @@ def test_approach_truncated(self): case = _make_case_record(approach=long_approach) result = extractor._summarize_case_for_prompt(case, max_approach_chars=100) assert len(result["approach"]) == 100 - assert result["approach"].endswith("...") + assert result["approach"].endswith("... [omitted]") def test_no_approach_omitted(self): extractor = _build_extractor() @@ -1877,9 +1877,9 @@ def test_exact_length_unchanged(self): assert AgentSkillExtractor._truncate_text("hello", 5) == "hello" def test_long_text_truncated_with_ellipsis(self): - result = AgentSkillExtractor._truncate_text("abcdefghij", 8) - assert len(result) == 8 - assert result.endswith("...") + result = AgentSkillExtractor._truncate_text("a" * 30, 20) + assert len(result) == 20 + assert result.endswith("... [omitted]") def test_none_returns_empty(self): assert AgentSkillExtractor._truncate_text(None, 10) == "" diff --git a/methods/evermemos/tests/test_agent_skill_relevance_verify.py b/methods/evermemos/tests/test_agent_skill_relevance_verify.py index 9450e0f98..732da54b8 100644 --- a/methods/evermemos/tests/test_agent_skill_relevance_verify.py +++ b/methods/evermemos/tests/test_agent_skill_relevance_verify.py @@ -8,7 +8,7 @@ import pytest from unittest.mock import AsyncMock, patch, MagicMock -from api_specs.dtos.memory import SearchAgentSkillItem as AgentSkillItem +from api_specs.dtos.memory import AgentSkillItem, SearchAgentSkillItem def _make_service(): @@ -26,9 +26,9 @@ def _make_service(): return SearchMemoryService() -def _make_skill(name: str, description: str = "desc", content: str = "content", score: float = 0.8) -> AgentSkillItem: - """Helper to create an AgentSkillItem instance.""" - return AgentSkillItem( +def _make_skill(name: str, description: str = "desc", content: str = "content", score: float = 0.8) -> SearchAgentSkillItem: + """Helper to create a SearchAgentSkillItem instance.""" + return SearchAgentSkillItem( id=f"skill_{name}", user_id="test_user", name=name, @@ -63,7 +63,7 @@ async def test_empty_query_returns_all(service): @pytest.mark.asyncio async def test_filters_irrelevant_skills(service): - """LLM gives high score to relevant skill and low score to irrelevant — only high-scoring one is returned.""" + """LLM marks one skill as helpful and one as not — only helpful one is returned.""" skills = [ _make_skill( "Database connection pool tuning", @@ -96,18 +96,17 @@ async def test_filters_irrelevant_skills(service): assert len(result) == 1 assert result[0].name == "Database connection pool tuning" - assert result[0].score == 0.9 @pytest.mark.asyncio -async def test_all_skills_high_score(service): - """When LLM gives all skills high scores, all are returned sorted by score descending.""" +async def test_all_skills_helpful(service): + """When LLM says all skills are helpful, all are returned.""" skills = [_make_skill("skill_a"), _make_skill("skill_b")] llm_response = json.dumps({ "results": [ - {"index": 0, "score": 0.7, "reason": "relevant"}, - {"index": 1, "score": 0.9, "reason": "very relevant"}, + {"index": 0, "score": 0.8, "reason": "relevant"}, + {"index": 1, "score": 0.7, "reason": "also relevant"}, ] }) @@ -121,18 +120,16 @@ async def test_all_skills_high_score(service): ) assert len(result) == 2 - assert result[0].name == "skill_b" - assert result[1].name == "skill_a" @pytest.mark.asyncio -async def test_all_skills_low_score(service): - """When LLM gives all skills low scores, empty list is returned.""" +async def test_no_skills_helpful(service): + """When LLM says no skills are helpful, empty list is returned.""" skills = [_make_skill("skill_a")] llm_response = json.dumps({ "results": [ - {"index": 0, "score": 0.2, "reason": "not relevant"}, + {"index": 0, "score": 0.1, "reason": "not relevant"}, ] }) @@ -197,7 +194,7 @@ async def test_none_skills_returns_none(service): @pytest.mark.asyncio async def test_none_fields_in_skill_use_empty_string(service): """Skills with None name/description/content are serialised as empty strings in prompt.""" - skills = [AgentSkillItem(id="s1", user_id="u1", name=None, description=None, content=None, score=0.5)] + skills = [SearchAgentSkillItem(id="s1", user_id="u1", name=None, description=None, content=None, score=0.5)] llm_response = json.dumps({"results": [{"index": 0, "score": 0.8, "reason": "ok"}]}) mock_provider = AsyncMock() @@ -229,8 +226,8 @@ async def test_out_of_range_index_ignored(service): llm_response = json.dumps({ "results": [ - {"index": 0, "score": 0.85, "reason": "ok"}, - {"index": 99, "score": 0.9, "reason": "ghost"}, + {"index": 0, "score": 0.8, "reason": "ok"}, + {"index": 99, "score": 0.8, "reason": "ghost"}, ] }) @@ -280,13 +277,13 @@ async def test_missing_score_field_defaults_zero(service): @pytest.mark.asyncio -async def test_partial_indices_only_returns_scored_above_threshold(service): - """LLM only returns judgement for some skills — unjudged ones default to 0.0 and are excluded.""" +async def test_partial_indices_only_returns_judged_helpful(service): + """LLM only returns judgement for some skills — unjudged ones are excluded.""" skills = [_make_skill("a"), _make_skill("b"), _make_skill("c")] llm_response = json.dumps({ "results": [ - {"index": 1, "score": 0.75, "reason": "relevant"}, + {"index": 1, "score": 0.8, "reason": "relevant"}, ] }) mock_provider = AsyncMock() @@ -301,15 +298,14 @@ async def test_partial_indices_only_returns_scored_above_threshold(service): @pytest.mark.asyncio -async def test_results_sorted_by_score_descending(service): - """Results are sorted by LLM relevance score in descending order.""" +async def test_filtered_results_sorted_by_score_desc(service): + """Filtered results are sorted by score descending.""" skills = [_make_skill("first"), _make_skill("second"), _make_skill("third")] llm_response = json.dumps({ "results": [ - {"index": 0, "score": 0.6, "reason": "ok"}, - {"index": 1, "score": 0.9, "reason": "great"}, - {"index": 2, "score": 0.75, "reason": "good"}, + {"index": 2, "score": 0.9, "reason": "ok"}, + {"index": 0, "score": 0.5, "reason": "ok"}, ] }) mock_provider = AsyncMock() @@ -319,33 +315,9 @@ async def test_results_sorted_by_score_descending(service): patch("memory_layer.prompts.get_prompt_by", return_value="{query}{skills_json}"): result = await service._verify_skill_relevance(query="q", skills=skills) - assert len(result) == 3 - assert result[0].name == "second" - assert result[1].name == "third" - assert result[2].name == "first" - - -@pytest.mark.asyncio -async def test_threshold_boundary(service): - """Score exactly at 0.4 passes, score below 0.4 is excluded.""" - skills = [_make_skill("at_boundary"), _make_skill("below_boundary")] - - llm_response = json.dumps({ - "results": [ - {"index": 0, "score": 0.4, "reason": "borderline"}, - {"index": 1, "score": 0.39, "reason": "just below"}, - ] - }) - mock_provider = AsyncMock() - mock_provider.generate = AsyncMock(return_value=llm_response) - - with patch("memory_layer.llm.llm_provider.build_default_provider", return_value=mock_provider), \ - patch("memory_layer.prompts.get_prompt_by", return_value="{query}{skills_json}"): - result = await service._verify_skill_relevance(query="q", skills=skills) - - assert len(result) == 1 - assert result[0].name == "at_boundary" - assert result[0].score == 0.4 + assert len(result) == 2 + assert result[0].name == "third" + assert result[1].name == "first" @pytest.mark.asyncio @@ -353,7 +325,7 @@ async def test_prompt_receives_correct_arguments(service): """Verify get_prompt_by is called with correct key and format receives query + skills_json.""" skills = [_make_skill("db_tuning", "tune db", "step1")] - llm_response = json.dumps({"results": [{"index": 0, "score": 0.85, "reason": "ok"}]}) + llm_response = json.dumps({"results": [{"index": 0, "score": 0.8, "reason": "ok"}]}) mock_provider = AsyncMock() mock_provider.generate = AsyncMock(return_value=llm_response) diff --git a/methods/evermemos/tests/test_api_key_rotator.py b/methods/evermemos/tests/test_api_key_rotator.py new file mode 100644 index 000000000..379e9f7c1 --- /dev/null +++ b/methods/evermemos/tests/test_api_key_rotator.py @@ -0,0 +1,104 @@ +"""ApiKeyRotator unit tests.""" + +import pytest + +from memory_layer.llm.api_key_rotator import ApiKeyRotator + + +@pytest.fixture(autouse=True) +def _reset_shared_rotator(): + """Ensure each test starts with a clean singleton state.""" + ApiKeyRotator._shared = None + yield + ApiKeyRotator._shared = None + + +class TestApiKeyRotator: + """Unit tests for ApiKeyRotator core rotation logic.""" + + def test_single_key_always_returns_same(self) -> None: + rotator = ApiKeyRotator(["key-a"]) + results = [rotator.get_next() for _ in range(5)] + assert results == ["key-a"] * 5 + + def test_multiple_keys_round_robin(self) -> None: + rotator = ApiKeyRotator(["key-a", "key-b", "key-c"]) + results = [rotator.get_next() for _ in range(3)] + assert results == ["key-a", "key-b", "key-c"] + + def test_multiple_keys_wraps_around(self) -> None: + rotator = ApiKeyRotator(["key-a", "key-b", "key-c"]) + results = [rotator.get_next() for _ in range(6)] + assert results == ["key-a", "key-b", "key-c", "key-a", "key-b", "key-c"] + + def test_size_property(self) -> None: + assert ApiKeyRotator(["key-a"]).size == 1 + assert ApiKeyRotator(["key-a", "key-b", "key-c"]).size == 3 + + def test_empty_keys_raises_value_error(self) -> None: + with pytest.raises(ValueError, match="At least one API key is required"): + ApiKeyRotator([]) + + def test_repr(self) -> None: + rotator = ApiKeyRotator(["key-a", "key-b"]) + assert repr(rotator) == "ApiKeyRotator(size=2)" + + def test_keys_are_immutable(self) -> None: + original = ["key-a", "key-b"] + rotator = ApiKeyRotator(original) + original.append("key-c") + assert rotator.size == 2 + + +class TestApiKeyRotatorGetOrCreate: + """Tests for get_or_create: parsing + singleton behavior.""" + + def test_single_key(self) -> None: + rotator = ApiKeyRotator.get_or_create("key-a") + assert rotator.size == 1 + assert rotator.get_next() == "key-a" + + def test_multiple_keys_comma_separated(self) -> None: + rotator = ApiKeyRotator.get_or_create("key-a,key-b,key-c") + assert rotator.size == 3 + assert rotator.get_next() == "key-a" + assert rotator.get_next() == "key-b" + assert rotator.get_next() == "key-c" + + def test_strips_whitespace(self) -> None: + rotator = ApiKeyRotator.get_or_create(" key-a , key-b , key-c ") + assert rotator.size == 3 + assert rotator.get_next() == "key-a" + + def test_ignores_trailing_comma(self) -> None: + rotator = ApiKeyRotator.get_or_create("key-a,key-b,") + assert rotator.size == 2 + + def test_ignores_empty_segments(self) -> None: + rotator = ApiKeyRotator.get_or_create("key-a,,key-b") + assert rotator.size == 2 + + def test_empty_string_raises_value_error(self) -> None: + with pytest.raises(ValueError, match="At least one API key is required"): + ApiKeyRotator.get_or_create("") + + def test_only_commas_raises_value_error(self) -> None: + with pytest.raises(ValueError, match="At least one API key is required"): + ApiKeyRotator.get_or_create(",,,") + + def test_returns_same_instance(self) -> None: + r1 = ApiKeyRotator.get_or_create("key-a,key-b") + r2 = ApiKeyRotator.get_or_create("key-a,key-b") + assert r1 is r2 + + def test_shared_counter_across_calls(self) -> None: + r1 = ApiKeyRotator.get_or_create("key-a,key-b") + assert r1.get_next() == "key-a" + r2 = ApiKeyRotator.get_or_create("key-a,key-b") + assert r2.get_next() == "key-b" + + def test_new_instance_after_clearing_shared(self) -> None: + r1 = ApiKeyRotator.get_or_create("key-a,key-b") + ApiKeyRotator._shared = None + r2 = ApiKeyRotator.get_or_create("key-a,key-b") + assert r1 is not r2 diff --git a/methods/evermemos/tests/test_cluster_memcell_llm.py b/methods/evermemos/tests/test_cluster_memcell_llm.py new file mode 100644 index 000000000..7ec87bc11 --- /dev/null +++ b/methods/evermemos/tests/test_cluster_memcell_llm.py @@ -0,0 +1,458 @@ +"""Unit tests for ClusterManager._cluster_memcell_llm. + +Covers every branch of the LLM-based clustering method: +1. Missing event_id +2. No LLM provider (embedding fallback) + 2a. top-1 sim >= threshold -> assign existing + 2b. top-1 sim < threshold -> new cluster +3. No existing case clusters -> new cluster +4. Fast path: top-1 sim >= llm_skip_threshold +5. LLM failure (returns None) -> embedding fallback + 5a. top-1 sim >= threshold -> assign existing + 5b. no good candidate -> new cluster +6. LLM returns valid result + 6a. chosen_id is valid case cluster -> assign + 6b. chosen_id invalid -> new cluster +""" + +import numpy as np +import pytest +from unittest.mock import AsyncMock, patch, MagicMock + +from memory_layer.cluster_manager.manager import ClusterManager, MemSceneState +from memory_layer.cluster_manager.config import ClusterManagerConfig + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_config(**overrides) -> ClusterManagerConfig: + defaults = dict( + similarity_threshold=0.65, + llm_skip_threshold=0.85, + llm_top_k_clusters=5, + llm_max_context_per_cluster=3, + ) + defaults.update(overrides) + return ClusterManagerConfig(**defaults) + + +def _make_manager( + config=None, + llm_provider=None, + context_fetcher=None, + embedding_vec=None, +) -> ClusterManager: + """Build a ClusterManager with vectorize service mocked out.""" + cfg = config or _make_config() + mgr = ClusterManager.__new__(ClusterManager) + mgr.config = cfg + mgr._callbacks = [] + mgr._llm_provider = llm_provider + mgr._context_fetcher = context_fetcher + mgr._stats = { + "total_memcells": 0, + "clustered_memcells": 0, + "new_clusters": 0, + "failed_embeddings": 0, + } + + # Mock vectorize service to return a controlled vector + mock_vs = AsyncMock() + if embedding_vec is not None: + mock_vs.get_embedding = AsyncMock(return_value=embedding_vec.tolist()) + else: + mock_vs.get_embedding = AsyncMock(return_value=[1.0, 0.0, 0.0]) + mgr._vectorize_service = mock_vs + return mgr + + +def _make_memcell(event_id="evt_1", text="some task", timestamp=1000.0): + return { + "event_id": event_id, + "clustering_text": text, + "timestamp": timestamp, + } + + +def _state_with_case_cluster( + cluster_id="cluster_000", + event_id="existing_evt", + centroid=None, + count=1, +): + """Build a MemSceneState that already has one case cluster.""" + state = MemSceneState() + state.next_cluster_idx = 1 + state.case_cluster_ids = {cluster_id} + state.cluster_counts[cluster_id] = count + state.cluster_last_ts[cluster_id] = 900.0 + state.eventid_to_cluster[event_id] = cluster_id + state.event_ids.append(event_id) + state.timestamps.append(900.0) + if centroid is not None: + state.cluster_centroids[cluster_id] = centroid + state.vectors.append(centroid) + else: + vec = np.array([1.0, 0.0, 0.0], dtype=np.float32) + state.cluster_centroids[cluster_id] = vec + state.vectors.append(vec) + state.cluster_ids.append(cluster_id) + return state + + +# =========================================================================== +# 1. Missing event_id -> (None, state) +# =========================================================================== + + +class TestMissingEventId: + + @pytest.mark.asyncio + async def test_empty_event_id_returns_none(self): + mgr = _make_manager(llm_provider=MagicMock()) + state = MemSceneState() + memcell = {"event_id": "", "clustering_text": "x", "timestamp": 1.0} + + cid, out_state = await mgr._cluster_memcell_llm(memcell, state) + + assert cid is None + assert out_state is state + assert mgr._stats["total_memcells"] == 1 + assert mgr._stats["clustered_memcells"] == 0 + + @pytest.mark.asyncio + async def test_missing_event_id_key_returns_none(self): + mgr = _make_manager(llm_provider=MagicMock()) + state = MemSceneState() + memcell = {"clustering_text": "x"} + + cid, _ = await mgr._cluster_memcell_llm(memcell, state) + assert cid is None + + +# =========================================================================== +# 2. No LLM provider -> embedding fallback +# =========================================================================== + + +class TestNoLlmProvider: + + @pytest.mark.asyncio + async def test_no_llm_assign_existing_when_similar(self): + """2a: top-1 sim >= threshold -> assign to existing cluster.""" + centroid = np.array([1.0, 0.0, 0.0], dtype=np.float32) + mgr = _make_manager( + llm_provider=None, + embedding_vec=centroid, # identical to centroid -> sim=1.0 + ) + state = _state_with_case_cluster(centroid=centroid) + + cid, out_state = await mgr._cluster_memcell_llm( + _make_memcell(), state + ) + + assert cid == "cluster_000" + assert "evt_1" in out_state.eventid_to_cluster + assert out_state.eventid_to_cluster["evt_1"] == "cluster_000" + assert mgr._stats["clustered_memcells"] == 1 + + @pytest.mark.asyncio + async def test_no_llm_new_cluster_when_dissimilar(self): + """2b: top-1 sim < threshold -> new cluster.""" + centroid = np.array([1.0, 0.0, 0.0], dtype=np.float32) + orthogonal = np.array([0.0, 1.0, 0.0], dtype=np.float32) + mgr = _make_manager( + llm_provider=None, + embedding_vec=orthogonal, # sim ~ 0 with centroid + ) + state = _state_with_case_cluster(centroid=centroid) + + cid, out_state = await mgr._cluster_memcell_llm( + _make_memcell(), state + ) + + assert cid != "cluster_000" + assert cid.startswith("cluster_") + assert cid in out_state.case_cluster_ids + assert mgr._stats["new_clusters"] == 1 + + +# =========================================================================== +# 3. No existing case clusters -> create first case cluster +# =========================================================================== + + +class TestNoCaseClusters: + + @pytest.mark.asyncio + async def test_first_case_cluster_created(self): + mgr = _make_manager(llm_provider=MagicMock()) + state = MemSceneState() + + cid, out_state = await mgr._cluster_memcell_llm( + _make_memcell(), state + ) + + assert cid == "cluster_000" + assert cid in out_state.case_cluster_ids + assert out_state.eventid_to_cluster["evt_1"] == cid + assert mgr._stats["new_clusters"] == 1 + assert mgr._stats["clustered_memcells"] == 1 + assert "evt_1" in out_state.event_ids + + +# =========================================================================== +# 4. Fast path: top-1 sim >= llm_skip_threshold +# =========================================================================== + + +class TestFastPath: + + @pytest.mark.asyncio + async def test_skip_llm_when_very_similar(self): + """sim=1.0 >= llm_skip_threshold=0.85 -> assign without LLM.""" + centroid = np.array([1.0, 0.0, 0.0], dtype=np.float32) + mock_llm = MagicMock() + mgr = _make_manager( + llm_provider=mock_llm, + embedding_vec=centroid, + ) + state = _state_with_case_cluster(centroid=centroid) + + cid, out_state = await mgr._cluster_memcell_llm( + _make_memcell(), state + ) + + assert cid == "cluster_000" + assert out_state.eventid_to_cluster["evt_1"] == "cluster_000" + assert mgr._stats["clustered_memcells"] == 1 + # LLM should NOT have been called + assert not hasattr(mock_llm, 'generate') or not mock_llm.generate.called + + @pytest.mark.asyncio + async def test_no_fast_path_when_below_threshold(self): + """sim < llm_skip_threshold -> should proceed to LLM stage.""" + centroid = np.array([1.0, 0.0, 0.0], dtype=np.float32) + different = np.array([0.6, 0.8, 0.0], dtype=np.float32) # sim ~ 0.6 + mock_llm = AsyncMock() + mock_llm.generate = AsyncMock( + return_value='{"cluster_id": "cluster_000", "reason": "same topic"}' + ) + mgr = _make_manager( + llm_provider=mock_llm, + embedding_vec=different, + ) + state = _state_with_case_cluster(centroid=centroid) + + with patch( + "memory_layer.prompts.get_prompt_by", + return_value="{memcell_text}{clusters_json}{next_new_id}", + ): + cid, _ = await mgr._cluster_memcell_llm( + _make_memcell(), state + ) + + # LLM was called and returned cluster_000 + assert cid == "cluster_000" + mock_llm.generate.assert_called() + + +# =========================================================================== +# 5. LLM failure (returns None) -> embedding fallback +# =========================================================================== + + +class TestLlmFailureFallback: + + @pytest.mark.asyncio + async def test_llm_fail_assign_existing_when_similar(self): + """5a: LLM fails, top-1 sim >= threshold -> assign existing.""" + centroid = np.array([1.0, 0.0, 0.0], dtype=np.float32) + close_vec = np.array([0.95, 0.31, 0.0], dtype=np.float32) # sim ~ 0.95 + mock_llm = AsyncMock() + mock_llm.generate = AsyncMock(return_value="invalid json {{{") + mgr = _make_manager( + llm_provider=mock_llm, + embedding_vec=close_vec, + config=_make_config(llm_skip_threshold=1.0), # never skip + ) + state = _state_with_case_cluster(centroid=centroid) + + with patch( + "memory_layer.prompts.get_prompt_by", + return_value="{memcell_text}{clusters_json}{next_new_id}", + ): + cid, out_state = await mgr._cluster_memcell_llm( + _make_memcell(), state + ) + + assert cid == "cluster_000" + assert out_state.eventid_to_cluster["evt_1"] == "cluster_000" + + @pytest.mark.asyncio + async def test_llm_fail_new_cluster_when_dissimilar(self): + """5b: LLM fails, no good candidate -> new cluster.""" + centroid = np.array([1.0, 0.0, 0.0], dtype=np.float32) + orthogonal = np.array([0.0, 1.0, 0.0], dtype=np.float32) + mock_llm = AsyncMock() + mock_llm.generate = AsyncMock(return_value="invalid json {{{") + mgr = _make_manager( + llm_provider=mock_llm, + embedding_vec=orthogonal, + config=_make_config(llm_skip_threshold=1.0), + ) + state = _state_with_case_cluster(centroid=centroid) + + with patch( + "memory_layer.prompts.get_prompt_by", + return_value="{memcell_text}{clusters_json}{next_new_id}", + ): + cid, out_state = await mgr._cluster_memcell_llm( + _make_memcell(), state + ) + + assert cid != "cluster_000" + assert cid in out_state.case_cluster_ids + assert mgr._stats["new_clusters"] == 1 + + +# =========================================================================== +# 6. LLM returns valid result +# =========================================================================== + + +class TestLlmValidResult: + + @pytest.mark.asyncio + async def test_llm_assigns_valid_existing_cluster(self): + """6a: LLM returns valid case cluster_id -> assign.""" + centroid = np.array([1.0, 0.0, 0.0], dtype=np.float32) + different = np.array([0.6, 0.8, 0.0], dtype=np.float32) + mock_llm = AsyncMock() + mock_llm.generate = AsyncMock( + return_value='{"cluster_id": "cluster_000", "reason": "related"}' + ) + mgr = _make_manager( + llm_provider=mock_llm, + embedding_vec=different, + config=_make_config(llm_skip_threshold=1.0), + ) + state = _state_with_case_cluster(centroid=centroid) + + with patch( + "memory_layer.prompts.get_prompt_by", + return_value="{memcell_text}{clusters_json}{next_new_id}", + ): + cid, out_state = await mgr._cluster_memcell_llm( + _make_memcell(), state + ) + + assert cid == "cluster_000" + assert out_state.eventid_to_cluster["evt_1"] == "cluster_000" + assert mgr._stats["clustered_memcells"] == 1 + + @pytest.mark.asyncio + async def test_llm_returns_new_cluster_id(self): + """6b: LLM returns an id not in state -> new cluster.""" + centroid = np.array([1.0, 0.0, 0.0], dtype=np.float32) + different = np.array([0.6, 0.8, 0.0], dtype=np.float32) + mock_llm = AsyncMock() + mock_llm.generate = AsyncMock( + return_value='{"cluster_id": "001", "reason": "new topic"}' + ) + mgr = _make_manager( + llm_provider=mock_llm, + embedding_vec=different, + config=_make_config(llm_skip_threshold=1.0), + ) + state = _state_with_case_cluster(centroid=centroid) + + with patch( + "memory_layer.prompts.get_prompt_by", + return_value="{memcell_text}{clusters_json}{next_new_id}", + ): + cid, out_state = await mgr._cluster_memcell_llm( + _make_memcell(), state + ) + + assert cid != "cluster_000" + assert cid in out_state.case_cluster_ids + assert mgr._stats["new_clusters"] == 1 + + @pytest.mark.asyncio + async def test_llm_returns_non_case_cluster_creates_new(self): + """6b variant: LLM returns cluster_id that exists but is NOT a case cluster.""" + centroid = np.array([1.0, 0.0, 0.0], dtype=np.float32) + different = np.array([0.6, 0.8, 0.0], dtype=np.float32) + mock_llm = AsyncMock() + # Return a cluster id that exists in cluster_counts but not in case_cluster_ids + mock_llm.generate = AsyncMock( + return_value='{"cluster_id": "cluster_999", "reason": "matched"}' + ) + mgr = _make_manager( + llm_provider=mock_llm, + embedding_vec=different, + config=_make_config(llm_skip_threshold=1.0), + ) + state = _state_with_case_cluster(centroid=centroid) + # Add a non-case cluster + state.cluster_counts["cluster_999"] = 2 + + with patch( + "memory_layer.prompts.get_prompt_by", + return_value="{memcell_text}{clusters_json}{next_new_id}", + ): + cid, out_state = await mgr._cluster_memcell_llm( + _make_memcell(), state + ) + + # Should NOT assign to cluster_999 because it's not a case cluster + assert cid != "cluster_999" + assert cid in out_state.case_cluster_ids + + +# =========================================================================== +# State mutation correctness +# =========================================================================== + + +class TestStateMutation: + + @pytest.mark.asyncio + async def test_event_appended_to_state(self): + """Every successful path appends event to state lists.""" + mgr = _make_manager(llm_provider=MagicMock()) + state = MemSceneState() + + cid, out_state = await mgr._cluster_memcell_llm( + _make_memcell(event_id="e1", timestamp=500.0), state + ) + + assert "e1" in out_state.event_ids + assert 500.0 in out_state.timestamps + assert len(out_state.vectors) == 1 + + @pytest.mark.asyncio + async def test_stats_incremented(self): + """total_memcells and clustered_memcells always incremented on success.""" + mgr = _make_manager(llm_provider=MagicMock()) + state = MemSceneState() + + await mgr._cluster_memcell_llm(_make_memcell(), state) + + assert mgr._stats["total_memcells"] == 1 + assert mgr._stats["clustered_memcells"] == 1 + + @pytest.mark.asyncio + async def test_multiple_memcells_increment_cluster_count(self): + """Assigning two events to the same cluster increments count.""" + centroid = np.array([1.0, 0.0, 0.0], dtype=np.float32) + mgr = _make_manager(llm_provider=None, embedding_vec=centroid) + state = _state_with_case_cluster(centroid=centroid) + + original_count = state.cluster_counts["cluster_000"] + await mgr._cluster_memcell_llm(_make_memcell(), state) + + assert state.cluster_counts["cluster_000"] > original_count diff --git a/methods/evermemos/tests/test_llm_metrics.py b/methods/evermemos/tests/test_llm_metrics.py new file mode 100644 index 000000000..06d748f32 --- /dev/null +++ b/methods/evermemos/tests/test_llm_metrics.py @@ -0,0 +1,200 @@ +"""LLM Prometheus metrics integration tests (mock HTTP).""" + +import pytest +import aiohttp +from unittest.mock import AsyncMock, patch + +from memory_layer.llm.openai_provider import OpenAIProvider +from memory_layer.llm.protocol import LLMError +from memory_layer.llm.api_key_rotator import ApiKeyRotator + + +@pytest.fixture(autouse=True) +def _reset_shared_rotator(): + """Ensure each test starts with a clean singleton state.""" + ApiKeyRotator._shared = None + yield + ApiKeyRotator._shared = None + + +def _success_body(content: str = "hello") -> dict: + return { + "choices": [{"message": {"content": content}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + + +def _error_body(message: str = "error") -> dict: + return {"error": {"message": message}} + + +METRICS_PATCH = "memory_layer.llm.openai_provider.record_llm_request" + + +class TestMetricsOnSuccess: + """HTTP 200: record status=success.""" + + @pytest.mark.asyncio + async def test_success_records_metric(self) -> None: + provider = OpenAIProvider( + api_key="key-a", base_url="https://fake.api", model="test-model" + ) + + async def mock_request(data: dict, api_key: str) -> tuple[int, dict]: + return 200, _success_body("ok") + + provider._do_request = mock_request + + with patch(METRICS_PATCH) as mock_record: + await provider.generate("test") + mock_record.assert_called_once_with("test-model", "success") + + +class TestMetricsOnRateLimit: + """HTTP 429 (all keys exhausted): record status=rate_limit.""" + + @pytest.mark.asyncio + async def test_429_all_keys_exhausted_records_rate_limit(self) -> None: + provider = OpenAIProvider( + api_key="key-a,key-b", base_url="https://fake.api", model="test-model" + ) + + async def always_429(data: dict, api_key: str) -> tuple[int, dict]: + return 429, _error_body("rate limited") + + provider._do_request = always_429 + + with patch(METRICS_PATCH) as mock_record: + with pytest.raises(LLMError, match="keys exhausted"): + await provider.generate("test") + mock_record.assert_called_once_with("test-model", "rate_limit") + + +class TestMetricsOnKeyError: + """HTTP 401/402/403 (all keys exhausted): record status=key_error.""" + + @pytest.mark.asyncio + async def test_401_all_keys_exhausted_records_key_error(self) -> None: + provider = OpenAIProvider( + api_key="key-a", base_url="https://fake.api", model="test-model" + ) + + async def always_401(data: dict, api_key: str) -> tuple[int, dict]: + return 401, _error_body("unauthorized") + + provider._do_request = always_401 + + with patch(METRICS_PATCH) as mock_record: + with pytest.raises(LLMError, match="keys exhausted"): + await provider.generate("test") + mock_record.assert_called_once_with("test-model", "key_error") + + +class TestMetricsOnServerError: + """HTTP 5xx (after max retries): record status=server_error.""" + + @pytest.mark.asyncio + async def test_5xx_exhausted_records_server_error(self) -> None: + provider = OpenAIProvider( + api_key="key-a", base_url="https://fake.api", model="test-model" + ) + + async def always_502(data: dict, api_key: str) -> tuple[int, dict]: + return 502, _error_body("bad gateway") + + provider._do_request = always_502 + + with ( + patch(METRICS_PATCH) as mock_record, + patch( + "memory_layer.llm.openai_provider.asyncio.sleep", new_callable=AsyncMock + ), + ): + with pytest.raises(LLMError, match="after 5 retries"): + await provider.generate("test") + mock_record.assert_called_once_with("test-model", "server_error") + + +class TestMetricsOnClientError: + """Network errors (after max retries): record status=client_error.""" + + @pytest.mark.asyncio + async def test_network_error_records_client_error(self) -> None: + provider = OpenAIProvider( + api_key="key-a", base_url="https://fake.api", model="test-model" + ) + + async def always_fail(data: dict, api_key: str) -> tuple[int, dict]: + raise aiohttp.ClientError("connection reset") + + provider._do_request = always_fail + + with patch(METRICS_PATCH) as mock_record: + with pytest.raises(LLMError, match="Request failed"): + await provider.generate("test") + mock_record.assert_called_once_with("test-model", "client_error") + + +class TestMetricsOnRequestError: + """HTTP 400/404/422: record status=request_error.""" + + @pytest.mark.asyncio + async def test_400_records_request_error(self) -> None: + provider = OpenAIProvider( + api_key="key-a", base_url="https://fake.api", model="test-model" + ) + + async def always_400(data: dict, api_key: str) -> tuple[int, dict]: + return 400, _error_body("bad request") + + provider._do_request = always_400 + + with patch(METRICS_PATCH) as mock_record: + with pytest.raises(LLMError, match="HTTP Error 400"): + await provider.generate("test") + mock_record.assert_called_once_with("test-model", "request_error") + + +class TestMetricsNotRecordedOnRetry: + """Metrics only recorded on final outcome, not intermediate retries.""" + + @pytest.mark.asyncio + async def test_429_then_success_records_only_success(self) -> None: + """429 followed by 200: only 'success' is recorded.""" + provider = OpenAIProvider( + api_key="key-a,key-b", base_url="https://fake.api", model="test-model" + ) + + responses = [(429, _error_body()), (200, _success_body("ok"))] + + async def mock_request(data: dict, api_key: str) -> tuple[int, dict]: + return responses.pop(0) + + provider._do_request = mock_request + + with patch(METRICS_PATCH) as mock_record: + await provider.generate("test") + mock_record.assert_called_once_with("test-model", "success") + + @pytest.mark.asyncio + async def test_5xx_then_success_records_only_success(self) -> None: + """502 followed by 200: only 'success' is recorded.""" + provider = OpenAIProvider( + api_key="key-a", base_url="https://fake.api", model="test-model" + ) + + responses = [(502, _error_body()), (200, _success_body("ok"))] + + async def mock_request(data: dict, api_key: str) -> tuple[int, dict]: + return responses.pop(0) + + provider._do_request = mock_request + + with ( + patch(METRICS_PATCH) as mock_record, + patch( + "memory_layer.llm.openai_provider.asyncio.sleep", new_callable=AsyncMock + ), + ): + await provider.generate("test") + mock_record.assert_called_once_with("test-model", "success") diff --git a/methods/evermemos/tests/test_mem_memorize.py b/methods/evermemos/tests/test_mem_memorize.py deleted file mode 100644 index 23e00e87b..000000000 --- a/methods/evermemos/tests/test_mem_memorize.py +++ /dev/null @@ -1,1318 +0,0 @@ -""" -tests/test_mem_memorize.py - -Unit tests for biz_layer/mem_memorize.py — the core memorization pipeline. - -Usage: - PYTHONPATH=src pytest tests/test_mem_memorize.py -v -""" - -import asyncio -from collections import defaultdict -from contextlib import ExitStack -from dataclasses import dataclass -from datetime import datetime -from types import SimpleNamespace -from typing import Dict, List, Optional -from unittest.mock import AsyncMock, MagicMock, patch, PropertyMock - -import pytest - -from api_specs.memory_types import ( - AgentCase, - AtomicFact, - EpisodeMemory, - Foresight, - MemCell, - MemoryType, - RawDataType, - ScenarioType, -) -from api_specs.dtos import MemorizeRequest -from biz_layer.memorize_config import ( - MemorizeConfig, - DEFAULT_MEMORIZE_CONFIG, - AGENT_DEFAULT_MEMORIZE_CONFIG, -) -from biz_layer.mem_memorize import ( - ExtractionState, - _is_agent_case_quality_sufficient, - _build_agent_cases_from_batch, - _clone_episodes_for_users, - _should_skip_atomic_fact_for_agent, - if_memorize, -) - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _make_memcell( - raw_data_type: RawDataType = RawDataType.CONVERSATION, - event_id: str = "evt-001", - participants: Optional[List[str]] = None, - timestamp: Optional[datetime] = None, -) -> MemCell: - mc = MagicMock(spec=MemCell) - mc.type = raw_data_type - mc.event_id = event_id - mc.original_data = [] - mc.participants = participants or ["user_001"] - mc.timestamp = timestamp or datetime(2026, 4, 7, 10, 0, 0) - return mc - - -def _make_request( - scene: str = "solo", - group_id: str = "grp-001", - raw_data_type: RawDataType = RawDataType.CONVERSATION, -) -> MemorizeRequest: - req = MagicMock(spec=MemorizeRequest) - req.group_id = group_id - req.session_id = "sess-001" - req.scene = scene - req.raw_data_type = raw_data_type - return req - - -def _make_agent_case( - quality_score: float = 0.8, - task_intent: str = "deploy service", - user_id: str = "user_001", -) -> AgentCase: - ac = MagicMock(spec=AgentCase) - ac.id = "ac-001" - ac.quality_score = quality_score - ac.task_intent = task_intent - ac.approach = "ssh and restart" - ac.key_insight = "check logs first" - ac.user_id = user_id - ac.memory_type = MemoryType.AGENT_CASE - ac.timestamp = datetime(2026, 4, 7, 10, 0, 0) - return ac - - -def _make_config(**overrides) -> MemorizeConfig: - return MemorizeConfig(**overrides) - - -def _make_state( - is_solo: bool = True, - has_episode: bool = True, - participants: Optional[List[str]] = None, -) -> ExtractionState: - state = MagicMock(spec=ExtractionState) - state.memcell = _make_memcell() - state.request = _make_request("solo" if is_solo else "team") - state.is_solo_scene = is_solo - state.participants = participants or ["user_001"] - state.current_time = datetime(2026, 4, 7, 10, 0, 0) - state.foresight_parent_type = "memcell" - state.atomic_fact_parent_type = "memcell" - state.parent_id = "evt-001" - saved_ep = MagicMock() - saved_ep.id = "ep-mongo-001" - state.group_episode_memories = [MagicMock(id="ep-mongo-001")] if has_episode else [] - state.parent_docs_map = {"ep-mongo-001": saved_ep} if has_episode else {} - state.episode_saved = has_episode - state.agent_case = None - state.group_episode = MagicMock() if has_episode else None - state.episode_memories = [] - return state - - -def _make_pending_entry( - event_id: str = "evt-001", - episode: str = "user asked about deployment", - timestamp: float = 1712484000.0, - participants: Optional[List[str]] = None, - scene: str = "solo", - agent_case: Optional[dict] = None, -) -> dict: - entry = { - "event_id": event_id, - "episode": episode, - "timestamp": timestamp, - "participants": participants if participants is not None else ["user_001"], - "group_id": "grp-001", - "scene": scene, - } - if agent_case: - entry["agent_case"] = agent_case - return entry - - -def _make_mem_scene_state( - pending: Optional[list] = None, - cluster_counts: Optional[dict] = None, - eventid_to_cluster: Optional[dict] = None, - cluster_last_ts: Optional[dict] = None, -): - state = MagicMock() - state.pending_clustering = pending if pending is not None else [] - state.cluster_counts = cluster_counts or {} - state.eventid_to_cluster = eventid_to_cluster or {} - state.cluster_last_ts = cluster_last_ts or {} - state.event_ids = list((eventid_to_cluster or {}).keys()) - state.timestamps = [] - state.to_dict.return_value = {} - return state - - -# =========================================================================== -# _is_agent_case_quality_sufficient -# =========================================================================== - -class TestIsAgentCaseQualitySufficient: - - def test_score_above_threshold_returns_true(self): - ac = _make_agent_case(quality_score=0.5) - config = _make_config(skill_min_quality_score=0.2) - assert _is_agent_case_quality_sufficient(ac, config) is True - - def test_score_equal_to_threshold_returns_true(self): - ac = _make_agent_case(quality_score=0.2) - config = _make_config(skill_min_quality_score=0.2) - assert _is_agent_case_quality_sufficient(ac, config) is True - - def test_score_below_threshold_returns_false(self): - ac = _make_agent_case(quality_score=0.1) - config = _make_config(skill_min_quality_score=0.2) - assert _is_agent_case_quality_sufficient(ac, config) is False - - def test_score_none_returns_false(self): - ac = _make_agent_case(quality_score=0.5) - ac.quality_score = None - config = _make_config(skill_min_quality_score=0.2) - assert _is_agent_case_quality_sufficient(ac, config) is False - - def test_score_zero_below_nonzero_threshold_returns_false(self): - ac = _make_agent_case(quality_score=0.0) - config = _make_config(skill_min_quality_score=0.1) - assert _is_agent_case_quality_sufficient(ac, config) is False - - def test_score_zero_with_zero_threshold_returns_true(self): - ac = _make_agent_case(quality_score=0.0) - config = _make_config(skill_min_quality_score=0.0) - assert _is_agent_case_quality_sufficient(ac, config) is True - - -# =========================================================================== -# _build_agent_cases_from_batch -# =========================================================================== - -class TestBuildAgentCasesFromBatch: - - def test_builds_cases_from_valid_entries(self): - entries = [ - _make_pending_entry( - event_id="evt-001", - timestamp=1712484000.0, - participants=["user_A"], - agent_case={ - "id": "ac-1", - "task_intent": "deploy", - "approach": "ssh", - "key_insight": "check logs", - "quality_score": 0.9, - }, - ), - _make_pending_entry( - event_id="evt-002", - timestamp=1712484100.0, - participants=["user_B"], - agent_case={ - "id": "ac-2", - "task_intent": "rollback", - "approach": "revert", - "key_insight": None, - "quality_score": 0.5, - }, - ), - ] - result = _build_agent_cases_from_batch(entries) - assert len(result) == 2 - assert "evt-001" in result - assert "evt-002" in result - assert result["evt-001"].task_intent == "deploy" - assert result["evt-002"].quality_score == 0.5 - - def test_skips_entry_without_event_id(self): - entries = [ - { - "episode": "test", - "timestamp": 1712484000.0, - "participants": [], - "agent_case": {"id": "ac-1", "task_intent": "x"}, - } - ] - result = _build_agent_cases_from_batch(entries) - assert len(result) == 0 - - def test_skips_entry_without_agent_case(self): - entries = [_make_pending_entry(event_id="evt-001")] - result = _build_agent_cases_from_batch(entries) - assert len(result) == 0 - - def test_empty_list_returns_empty_dict(self): - result = _build_agent_cases_from_batch([]) - assert result == {} - - def test_user_id_from_first_participant(self): - entries = [ - _make_pending_entry( - event_id="evt-001", - participants=["user_X", "user_Y"], - agent_case={"id": "ac-1", "task_intent": "x", "approach": "y"}, - ) - ] - result = _build_agent_cases_from_batch(entries) - assert result["evt-001"].user_id == "user_X" - - def test_empty_participants_gives_empty_user_id(self): - entries = [ - _make_pending_entry( - event_id="evt-001", - participants=[], - agent_case={"id": "ac-1", "task_intent": "x", "approach": "y"}, - ) - ] - result = _build_agent_cases_from_batch(entries) - assert result["evt-001"].user_id == "" - - def test_timestamp_none_uses_datetime_now(self): - entries = [ - _make_pending_entry( - event_id="evt-001", - timestamp=None, - agent_case={"id": "ac-1", "task_intent": "x", "approach": "y"}, - ) - ] - # timestamp=None in entry dict - entries[0]["timestamp"] = None - result = _build_agent_cases_from_batch(entries) - # Should not raise; datetime should be approximately now - assert result["evt-001"].timestamp is not None - - -# =========================================================================== -# _clone_episodes_for_users -# =========================================================================== - -class TestCloneEpisodesForUsers: - - def _make_clone_state(self, participants: List[str]) -> ExtractionState: - memcell = _make_memcell(participants=participants) - state = ExtractionState( - memcell=memcell, - request=_make_request(), - current_time=datetime(2026, 4, 7), - scene="solo", - is_solo_scene=True, - participants=participants, - ) - ep = EpisodeMemory( - memory_type=MemoryType.EPISODIC_MEMORY, - user_id="group", - timestamp=datetime(2026, 4, 7), - episode="test episode", - ) - state.group_episode_memories = [ep] - return state - - def test_clones_to_regular_users(self): - state = self._make_clone_state(["alice", "bob"]) - cloned = _clone_episodes_for_users(state) - assert len(cloned) == 2 - user_ids = {ep.user_id for ep in cloned} - assert user_ids == {"alice", "bob"} - - def test_filters_robot_names(self): - state = self._make_clone_state(["alice", "Robot_1", "assistant_bot", "Agent_X", "tool_call"]) - cloned = _clone_episodes_for_users(state) - assert len(cloned) == 1 - assert cloned[0].user_id == "alice" - - def test_filter_is_case_insensitive(self): - state = self._make_clone_state(["ROBOT_1", "ASSISTANT", "AGENT", "TOOL"]) - cloned = _clone_episodes_for_users(state) - assert len(cloned) == 0 - - def test_empty_participants(self): - state = self._make_clone_state([]) - cloned = _clone_episodes_for_users(state) - assert len(cloned) == 0 - - -# =========================================================================== -# _should_skip_atomic_fact_for_agent -# =========================================================================== - -class TestShouldSkipAtomicFactForAgent: - - def _make_memcell_with_data(self, messages: list) -> MemCell: - mc = MagicMock(spec=MemCell) - mc.original_data = [{"message": m} for m in messages] - return mc - - def test_no_tool_calls_returns_false(self): - mc = self._make_memcell_with_data([ - {"role": "user", "content": "hello"}, - {"role": "assistant", "content": "hi there, how can I help?"}, - ]) - assert _should_skip_atomic_fact_for_agent(mc) is False - - def test_tool_calls_short_response_returns_false(self): - mc = self._make_memcell_with_data([ - {"role": "user", "content": "run test"}, - {"role": "assistant", "tool_calls": [{"id": "t1"}]}, - {"role": "tool", "content": "ok"}, - {"role": "assistant", "content": "done"}, - ]) - assert _should_skip_atomic_fact_for_agent(mc) is False - - def test_tool_calls_long_response_returns_true(self): - long_text = "x" * 1000 - mc = self._make_memcell_with_data([ - {"role": "user", "content": "analyze"}, - {"role": "assistant", "tool_calls": [{"id": "t1"}]}, - {"role": "tool", "content": "data"}, - {"role": "assistant", "content": long_text}, - ]) - assert _should_skip_atomic_fact_for_agent(mc) is True - - def test_cumulative_response_across_messages(self): - # 600 + 500 = 1100 >= 1000 -> True - mc = self._make_memcell_with_data([ - {"role": "user", "content": "go"}, - {"role": "assistant", "tool_calls": [{"id": "t1"}]}, - {"role": "tool", "content": "ok"}, - {"role": "assistant", "content": "a" * 600}, - {"role": "assistant", "content": "b" * 500}, - ]) - assert _should_skip_atomic_fact_for_agent(mc) is True - - def test_tool_call_assistant_messages_not_counted(self): - # Only non-tool-call assistant msgs count - mc = self._make_memcell_with_data([ - {"role": "user", "content": "go"}, - {"role": "assistant", "tool_calls": [{"id": "t1"}], "content": "x" * 2000}, - {"role": "tool", "content": "ok"}, - {"role": "assistant", "content": "short"}, - ]) - assert _should_skip_atomic_fact_for_agent(mc) is False - - def test_empty_original_data(self): - mc = MagicMock(spec=MemCell) - mc.original_data = [] - assert _should_skip_atomic_fact_for_agent(mc) is False - - def test_none_original_data(self): - mc = MagicMock(spec=MemCell) - mc.original_data = None - assert _should_skip_atomic_fact_for_agent(mc) is False - - -# =========================================================================== -# if_memorize -# =========================================================================== - -class TestIfMemorize: - - def test_always_returns_true(self): - mc = _make_memcell() - assert if_memorize(mc) is True - - -# =========================================================================== -# _trigger_clustering -# =========================================================================== - -class TestTriggerClustering: - - @pytest.mark.asyncio - async def test_builds_pending_entry_and_calls_drain(self): - from biz_layer.mem_memorize import _trigger_clustering - - mc = _make_memcell(event_id="evt-100", participants=["u1", "u2"]) - mc.timestamp = datetime(2026, 4, 7, 12, 0, 0) - - with patch( - 'biz_layer.mem_memorize._drain_and_cluster', - new_callable=AsyncMock, - return_value=0, - ) as mock_drain: - await _trigger_clustering( - group_id="grp-001", - memcell=mc, - scene="team", - episode_text="they discussed plans", - ) - - mock_drain.assert_called_once() - call_kwargs = mock_drain.call_args[1] - assert call_kwargs["group_id"] == "grp-001" - entry = call_kwargs["new_entry"] - assert entry["event_id"] == "evt-100" - assert entry["episode"] == "they discussed plans" - assert entry["participants"] == ["u1", "u2"] - assert entry["scene"] == "team" - - @pytest.mark.asyncio - async def test_includes_agent_case_when_provided(self): - from biz_layer.mem_memorize import _trigger_clustering - - mc = _make_memcell() - ac = _make_agent_case() - - with patch( - 'biz_layer.mem_memorize._drain_and_cluster', - new_callable=AsyncMock, - return_value=0, - ) as mock_drain: - await _trigger_clustering( - group_id="grp-001", - memcell=mc, - agent_case=ac, - ) - entry = mock_drain.call_args[1]["new_entry"] - assert "agent_case" in entry - assert entry["agent_case"]["task_intent"] == "deploy service" - - @pytest.mark.asyncio - async def test_no_agent_case_key_when_none(self): - from biz_layer.mem_memorize import _trigger_clustering - - mc = _make_memcell() - - with patch( - 'biz_layer.mem_memorize._drain_and_cluster', - new_callable=AsyncMock, - return_value=0, - ) as mock_drain: - await _trigger_clustering(group_id="grp-001", memcell=mc) - entry = mock_drain.call_args[1]["new_entry"] - assert "agent_case" not in entry - - -# =========================================================================== -# _drain_and_cluster -# =========================================================================== - -class TestDrainAndCluster: - - def _patch_drain_deps(self, mem_scene_state=None, acquired=True, cluster_ids=None): - """Build patches for _drain_and_cluster dependencies.""" - if mem_scene_state is None: - mem_scene_state = _make_mem_scene_state() - - mock_storage = AsyncMock() - mock_storage.load_mem_scene = AsyncMock(return_value=None) - mock_storage.save_mem_scene = AsyncMock() - - patches = [ - patch( - 'biz_layer.mem_memorize.get_bean_by_type', - return_value=mock_storage, - ), - patch( - 'memory_layer.cluster_manager.MemSceneState', - return_value=mem_scene_state, - ), - patch( - 'memory_layer.cluster_manager.MemSceneState.from_dict', - return_value=mem_scene_state, - ), - patch( - 'biz_layer.mem_memorize._run_batch_clustering', - new_callable=AsyncMock, - return_value=cluster_ids or ["cluster-1"], - ), - patch( - 'biz_layer.mem_memorize._run_profile_extraction_for_batch', - new_callable=AsyncMock, - ), - patch( - 'biz_layer.mem_memorize._run_skill_extraction_for_batch', - new_callable=AsyncMock, - ), - ] - - # Mock distributed_lock context manager - lock_cm = AsyncMock() - lock_cm.__aenter__ = AsyncMock(return_value=acquired) - lock_cm.__aexit__ = AsyncMock(return_value=False) - - patches.append( - patch( - 'core.lock.redis_distributed_lock.distributed_lock', - return_value=lock_cm, - ) - ) - - return patches, mock_storage, mem_scene_state - - @pytest.mark.asyncio - async def test_lock_not_acquired_returns_zero(self): - from biz_layer.mem_memorize import _drain_and_cluster - - patches, _, _ = self._patch_drain_deps(acquired=False) - with ExitStack() as stack: - for p in patches: - stack.enter_context(p) - result = await _drain_and_cluster("grp-001", _make_config()) - assert result == 0 - - @pytest.mark.asyncio - async def test_accumulates_when_below_batch_size(self): - from biz_layer.mem_memorize import _drain_and_cluster - - mss = _make_mem_scene_state(pending=[]) - patches, mock_storage, _ = self._patch_drain_deps(mem_scene_state=mss) - config = _make_config(cluster_batch_size=5) - - with ExitStack() as stack: - mocks = [stack.enter_context(p) for p in patches] - entry = _make_pending_entry() - result = await _drain_and_cluster("grp-001", config, new_entry=entry) - - assert result == 0 - mock_storage.save_mem_scene.assert_called() - - @pytest.mark.asyncio - async def test_drains_when_batch_size_reached(self): - from biz_layer.mem_memorize import _drain_and_cluster - - pending = [_make_pending_entry(event_id=f"evt-{i}") for i in range(4)] - mss = _make_mem_scene_state(pending=pending) - patches, _, _ = self._patch_drain_deps(mem_scene_state=mss) - config = _make_config(cluster_batch_size=5) - - with ExitStack() as stack: - mocks = [stack.enter_context(p) for p in patches] - entry = _make_pending_entry(event_id="evt-4") - result = await _drain_and_cluster("grp-001", config, new_entry=entry) - - assert result == 5 - - @pytest.mark.asyncio - async def test_force_drain_with_pending_items(self): - from biz_layer.mem_memorize import _drain_and_cluster - - pending = [_make_pending_entry(event_id="evt-0")] - mss = _make_mem_scene_state(pending=pending) - patches, _, _ = self._patch_drain_deps(mem_scene_state=mss) - config = _make_config(cluster_batch_size=100) - - with ExitStack() as stack: - mocks = [stack.enter_context(p) for p in patches] - result = await _drain_and_cluster( - "grp-001", config, force_drain=True - ) - - assert result == 1 - - @pytest.mark.asyncio - async def test_force_drain_empty_returns_zero(self): - from biz_layer.mem_memorize import _drain_and_cluster - - mss = _make_mem_scene_state(pending=[]) - patches, _, _ = self._patch_drain_deps(mem_scene_state=mss) - config = _make_config(cluster_batch_size=100) - - with ExitStack() as stack: - for p in patches: - stack.enter_context(p) - result = await _drain_and_cluster( - "grp-001", config, force_drain=True - ) - - assert result == 0 - - @pytest.mark.asyncio - async def test_skill_extraction_runs_outside_lock(self): - from biz_layer.mem_memorize import _drain_and_cluster - - pending = [_make_pending_entry()] - mss = _make_mem_scene_state(pending=pending) - patches, _, _ = self._patch_drain_deps(mem_scene_state=mss) - config = _make_config(cluster_batch_size=1) - - with ExitStack() as stack: - mocks = [stack.enter_context(p) for p in patches] - # Find _run_skill_extraction_for_batch mock - skill_mock = None - for m in mocks: - if hasattr(m, '_mock_name') and 'skill' in str(getattr(m, '_mock_name', '')): - skill_mock = m - result = await _drain_and_cluster( - "grp-001", config, new_entry=_make_pending_entry() - ) - - # result > 0 means drain happened - assert result > 0 - - -# =========================================================================== -# _run_profile_extraction_for_batch -# =========================================================================== - -class TestRunProfileExtractionForBatch: - - @pytest.mark.asyncio - async def test_skip_when_config_flag_set(self): - from biz_layer.mem_memorize import _run_profile_extraction_for_batch - - config = _make_config(skip_profile_extraction=True) - # Should return immediately without calling anything - await _run_profile_extraction_for_batch( - group_id="grp-001", - drained_memcells=[_make_pending_entry()], - cluster_ids=["c1"], - mem_scene_state=_make_mem_scene_state(cluster_counts={"c1": 5}), - config=config, - ) - # No exception means it returned early - - @pytest.mark.asyncio - async def test_interval_lte_1_always_extracts(self): - from biz_layer.mem_memorize import _run_profile_extraction_for_batch - - config = _make_config( - skip_profile_extraction=False, - profile_extraction_interval=1, - ) - mss = _make_mem_scene_state( - cluster_counts={"c1": 1}, - cluster_last_ts={"c1": 9999999999.0}, - ) - - mock_profile_repo = AsyncMock() - mock_profile_repo.get_all_by_group = AsyncMock(return_value=[]) - - with ( - patch('biz_layer.mem_memorize.get_bean_by_type', return_value=mock_profile_repo), - patch( - 'biz_layer.mem_memorize._trigger_profile_extraction', - new_callable=AsyncMock, - ) as mock_trigger, - ): - await _run_profile_extraction_for_batch( - group_id="grp-001", - drained_memcells=[_make_pending_entry(timestamp=100.0)], - cluster_ids=["c1"], - mem_scene_state=mss, - config=config, - ) - mock_trigger.assert_called_once() - - @pytest.mark.asyncio - async def test_interval_modulo_skips_when_not_met(self): - from biz_layer.mem_memorize import _run_profile_extraction_for_batch - - config = _make_config( - skip_profile_extraction=False, - profile_extraction_interval=5, - cluster_batch_size=1, # not batch mode - ) - mss = _make_mem_scene_state(cluster_counts={"c1": 3}) - - with patch( - 'biz_layer.mem_memorize._trigger_profile_extraction', - new_callable=AsyncMock, - ) as mock_trigger: - await _run_profile_extraction_for_batch( - group_id="grp-001", - drained_memcells=[_make_pending_entry()], - cluster_ids=["c1"], - mem_scene_state=mss, - config=config, - ) - mock_trigger.assert_not_called() - - @pytest.mark.asyncio - async def test_batch_mode_extracts_when_count_gte_interval(self): - from biz_layer.mem_memorize import _run_profile_extraction_for_batch - - config = _make_config( - skip_profile_extraction=False, - profile_extraction_interval=5, - cluster_batch_size=10, # batch mode - ) - mss = _make_mem_scene_state( - cluster_counts={"c1": 5}, - cluster_last_ts={"c1": 9999999999.0}, - ) - - mock_profile_repo = AsyncMock() - mock_profile_repo.get_all_by_group = AsyncMock(return_value=[]) - - with ( - patch('biz_layer.mem_memorize.get_bean_by_type', return_value=mock_profile_repo), - patch( - 'biz_layer.mem_memorize._trigger_profile_extraction', - new_callable=AsyncMock, - ) as mock_trigger, - ): - await _run_profile_extraction_for_batch( - group_id="grp-001", - drained_memcells=[_make_pending_entry(timestamp=100.0)], - cluster_ids=["c1"], - mem_scene_state=mss, - config=config, - ) - mock_trigger.assert_called_once() - - @pytest.mark.asyncio - async def test_force_drain_extracts_when_count_gte_interval(self): - from biz_layer.mem_memorize import _run_profile_extraction_for_batch - - config = _make_config( - skip_profile_extraction=False, - profile_extraction_interval=5, - cluster_batch_size=1, # non-batch - ) - mss = _make_mem_scene_state( - cluster_counts={"c1": 6}, - cluster_last_ts={"c1": 9999999999.0}, - ) - - mock_profile_repo = AsyncMock() - mock_profile_repo.get_all_by_group = AsyncMock(return_value=[]) - - with ( - patch('biz_layer.mem_memorize.get_bean_by_type', return_value=mock_profile_repo), - patch( - 'biz_layer.mem_memorize._trigger_profile_extraction', - new_callable=AsyncMock, - ) as mock_trigger, - ): - await _run_profile_extraction_for_batch( - group_id="grp-001", - drained_memcells=[_make_pending_entry(timestamp=100.0)], - cluster_ids=["c1"], - mem_scene_state=mss, - config=config, - force_drain=True, - ) - mock_trigger.assert_called_once() - - @pytest.mark.asyncio - async def test_scene_picked_from_last_entry(self): - from biz_layer.mem_memorize import _run_profile_extraction_for_batch - - config = _make_config( - skip_profile_extraction=False, - profile_extraction_interval=1, - ) - mss = _make_mem_scene_state( - cluster_counts={"c1": 2}, - cluster_last_ts={"c1": 9999999999.0}, - ) - entries = [ - _make_pending_entry(event_id="e1", scene="solo"), - _make_pending_entry(event_id="e2", scene="team"), - ] - - mock_profile_repo = AsyncMock() - mock_profile_repo.get_all_by_group = AsyncMock(return_value=[]) - - with ( - patch('biz_layer.mem_memorize.get_bean_by_type', return_value=mock_profile_repo), - patch( - 'biz_layer.mem_memorize._trigger_profile_extraction', - new_callable=AsyncMock, - ) as mock_trigger, - ): - await _run_profile_extraction_for_batch( - group_id="grp-001", - drained_memcells=entries, - cluster_ids=["c1", "c1"], - mem_scene_state=mss, - config=config, - ) - call_kwargs = mock_trigger.call_args[1] - assert call_kwargs["scene"] == "team" - - @pytest.mark.asyncio - async def test_no_target_clusters_skips_extraction(self): - from biz_layer.mem_memorize import _run_profile_extraction_for_batch - - config = _make_config( - skip_profile_extraction=False, - profile_extraction_interval=1, - ) - # cluster_last_ts has old timestamps, all cluster_ids are None - mss = _make_mem_scene_state( - cluster_counts={"c1": 2}, - cluster_last_ts={"c1": 0.0}, - ) - - mock_profile_repo = AsyncMock() - existing_profile = MagicMock() - existing_profile.last_updated_ts = 9999999999.0 - mock_profile_repo.get_all_by_group = AsyncMock(return_value=[existing_profile]) - - with ( - patch('biz_layer.mem_memorize.get_bean_by_type', return_value=mock_profile_repo), - patch( - 'biz_layer.mem_memorize._trigger_profile_extraction', - new_callable=AsyncMock, - ) as mock_trigger, - ): - await _run_profile_extraction_for_batch( - group_id="grp-001", - drained_memcells=[_make_pending_entry(timestamp=100.0)], - cluster_ids=[None], # None cluster_id - mem_scene_state=mss, - config=config, - ) - mock_trigger.assert_not_called() - - -# =========================================================================== -# _trigger_profile_extraction -# =========================================================================== - -class TestTriggerProfileExtraction: - - def _patch_profile_deps(self, all_memcells=None, old_profiles=None, new_profiles=None): - mock_profile_repo = AsyncMock() - mock_profile_repo.get_all_profiles = AsyncMock( - return_value=old_profiles or {} - ) - mock_profile_repo.save_profile = AsyncMock() - mock_profile_repo.get_by_user_and_group = AsyncMock(return_value=None) - mock_profile_repo.upsert = AsyncMock() - - mock_memcell_repo = AsyncMock() - mock_memcell_repo.get_by_event_ids = AsyncMock( - return_value={} if all_memcells is None else {mc.event_id: mc for mc in all_memcells} - ) - - mock_llm = MagicMock() - - mock_profile_manager = AsyncMock() - mock_profile_manager.extract_profiles = AsyncMock( - return_value=new_profiles or [] - ) - - patches = [ - patch( - 'biz_layer.mem_memorize.get_bean_by_type', - side_effect=lambda cls: { - 'UserProfileRawRepository': mock_profile_repo, - 'MemCellRawRepository': mock_memcell_repo, - }.get(cls.__name__, MagicMock()), - ), - patch( - 'memory_layer.llm.llm_provider.build_default_provider', - return_value=mock_llm, - ), - patch( - 'memory_layer.profile_manager.ProfileManager', - return_value=mock_profile_manager, - ), - ] - return patches, mock_profile_repo, mock_memcell_repo, mock_profile_manager - - @pytest.mark.asyncio - async def test_skips_when_below_min_memcells(self): - from biz_layer.mem_memorize import _trigger_profile_extraction - - mss = _make_mem_scene_state(cluster_counts={"c1": 1}) - config = _make_config(profile_min_memcells=5) - - patches, mock_repo, _, _ = self._patch_profile_deps() - with ExitStack() as stack: - for p in patches: - stack.enter_context(p) - await _trigger_profile_extraction( - group_id="grp-001", - cluster_ids=["c1"], - mem_scene_state=mss, - latest_memcell_ts=100.0, - config=config, - ) - mock_repo.get_all_profiles.assert_not_called() - - @pytest.mark.asyncio - async def test_extracts_and_saves_profiles(self): - from biz_layer.mem_memorize import _trigger_profile_extraction - - fetched_mc = MagicMock() - fetched_mc.event_id = "evt-001" - fetched_mc.participants = ["user_A", "user_B"] - - new_profile = MagicMock() - new_profile.user_id = "user_A" - new_profile.to_dict.return_value = {"explicit_info": ["trait"]} - new_profile.total_items.return_value = 1 - - mss = _make_mem_scene_state( - cluster_counts={"c1": 3}, - eventid_to_cluster={"evt-001": "c1", "evt-002": "c1"}, - ) - - patches, mock_repo, _, mock_pm = self._patch_profile_deps( - all_memcells=[fetched_mc], - new_profiles=[new_profile], - ) - config = _make_config(profile_min_memcells=1) - - with ExitStack() as stack: - for p in patches: - stack.enter_context(p) - await _trigger_profile_extraction( - group_id="grp-001", - cluster_ids=["c1"], - mem_scene_state=mss, - latest_memcell_ts=100.0, - config=config, - ) - - mock_pm.extract_profiles.assert_called_once() - mock_repo.save_profile.assert_called_once() - - @pytest.mark.asyncio - async def test_advances_ts_on_failure(self): - from biz_layer.mem_memorize import _trigger_profile_extraction - - mss = _make_mem_scene_state( - cluster_counts={"c1": 3}, - eventid_to_cluster={"evt-001": "c1"}, - ) - - mock_profile_repo = AsyncMock() - # get_all_profiles succeeds so we reach extract_profiles which fails - mock_profile_repo.get_all_profiles = AsyncMock(return_value={}) - mock_profile_repo.get_by_user_and_group = AsyncMock(return_value=None) - mock_profile_repo.upsert = AsyncMock() - - mock_memcell_repo = AsyncMock() - fetched_mc = MagicMock() - fetched_mc.event_id = "evt-001" - fetched_mc.participants = ["user_X"] - mock_memcell_repo.get_by_event_ids = AsyncMock( - return_value={"evt-001": fetched_mc} - ) - - mock_pm = AsyncMock() - mock_pm.extract_profiles = AsyncMock(side_effect=RuntimeError("LLM down")) - - with ( - patch( - 'biz_layer.mem_memorize.get_bean_by_type', - side_effect=lambda cls: { - 'UserProfileRawRepository': mock_profile_repo, - 'MemCellRawRepository': mock_memcell_repo, - }.get(cls.__name__, MagicMock()), - ), - patch('memory_layer.llm.llm_provider.build_default_provider', return_value=MagicMock()), - patch('memory_layer.profile_manager.ProfileManager', return_value=mock_pm), - ): - # Should not raise - await _trigger_profile_extraction( - group_id="grp-001", - cluster_ids=["c1"], - mem_scene_state=mss, - latest_memcell_ts=100.0, - config=_make_config(profile_min_memcells=1), - ) - - # Should attempt to advance timestamp for user_X despite failure - mock_profile_repo.upsert.assert_called_once() - call_kwargs = mock_profile_repo.upsert.call_args[1] - assert call_kwargs["user_id"] == "user_X" - assert call_kwargs["metadata"]["last_updated_ts"] == 100.0 - - -# =========================================================================== -# _run_skill_extraction_for_batch -# =========================================================================== - -class TestRunSkillExtractionForBatch: - - @pytest.mark.asyncio - async def test_skip_when_config_flag_set(self): - from biz_layer.mem_memorize import _run_skill_extraction_for_batch - - config = _make_config(skip_skill_extraction=True) - await _run_skill_extraction_for_batch( - group_id="grp-001", - drained_memcells=[_make_pending_entry()], - cluster_ids=["c1"], - config=config, - ) - # No exception = returned early - - @pytest.mark.asyncio - async def test_skip_when_no_agent_cases(self): - from biz_layer.mem_memorize import _run_skill_extraction_for_batch - - config = _make_config(skip_skill_extraction=False) - entries = [_make_pending_entry()] # no agent_case - - with patch( - 'biz_layer.mem_memorize._trigger_agent_skill_extraction', - new_callable=AsyncMock, - ) as mock_trigger: - await _run_skill_extraction_for_batch( - group_id="grp-001", - drained_memcells=entries, - cluster_ids=["c1"], - config=config, - ) - mock_trigger.assert_not_called() - - @pytest.mark.asyncio - async def test_filters_low_quality_cases(self): - from biz_layer.mem_memorize import _run_skill_extraction_for_batch - - config = _make_config( - skip_skill_extraction=False, - skill_min_quality_score=0.5, - ) - entries = [ - _make_pending_entry( - event_id="evt-001", - agent_case={ - "id": "ac-1", - "task_intent": "x", - "approach": "y", - "quality_score": 0.1, # below threshold - }, - ) - ] - - with patch( - 'biz_layer.mem_memorize._trigger_agent_skill_extraction', - new_callable=AsyncMock, - ) as mock_trigger: - await _run_skill_extraction_for_batch( - group_id="grp-001", - drained_memcells=entries, - cluster_ids=["c1"], - config=config, - ) - mock_trigger.assert_not_called() - - @pytest.mark.asyncio - async def test_groups_by_cluster_and_triggers(self): - from biz_layer.mem_memorize import _run_skill_extraction_for_batch - - config = _make_config( - skip_skill_extraction=False, - skill_min_quality_score=0.0, - ) - ac_dict = { - "id": "ac-1", - "task_intent": "deploy", - "approach": "ssh", - "quality_score": 0.9, - } - entries = [ - _make_pending_entry(event_id="e1", agent_case=ac_dict), - _make_pending_entry(event_id="e2", agent_case={**ac_dict, "id": "ac-2"}), - _make_pending_entry(event_id="e3", agent_case={**ac_dict, "id": "ac-3"}), - ] - cluster_ids = ["c1", "c1", "c2"] - - with patch( - 'biz_layer.mem_memorize._trigger_agent_skill_extraction', - new_callable=AsyncMock, - return_value=False, - ) as mock_trigger: - await _run_skill_extraction_for_batch( - group_id="grp-001", - drained_memcells=entries, - cluster_ids=cluster_ids, - config=config, - ) - assert mock_trigger.call_count == 2 # c1 and c2 - - @pytest.mark.asyncio - async def test_milvus_flush_when_changes(self): - from biz_layer.mem_memorize import _run_skill_extraction_for_batch - - config = _make_config( - skip_skill_extraction=False, - skill_min_quality_score=0.0, - ) - entries = [ - _make_pending_entry( - event_id="e1", - agent_case={"id": "ac-1", "task_intent": "x", "approach": "y", "quality_score": 0.9}, - ) - ] - - mock_milvus_repo = AsyncMock() - mock_milvus_repo.flush = AsyncMock() - - with ( - patch( - 'biz_layer.mem_memorize._trigger_agent_skill_extraction', - new_callable=AsyncMock, - return_value=True, # has milvus changes - ), - patch( - 'biz_layer.mem_memorize.get_bean_by_type', - return_value=mock_milvus_repo, - ), - ): - await _run_skill_extraction_for_batch( - group_id="grp-001", - drained_memcells=entries, - cluster_ids=["c1"], - config=config, - ) - mock_milvus_repo.flush.assert_called_once() - - @pytest.mark.asyncio - async def test_no_milvus_flush_when_no_changes(self): - from biz_layer.mem_memorize import _run_skill_extraction_for_batch - - config = _make_config( - skip_skill_extraction=False, - skill_min_quality_score=0.0, - ) - entries = [ - _make_pending_entry( - event_id="e1", - agent_case={"id": "ac-1", "task_intent": "x", "approach": "y", "quality_score": 0.9}, - ) - ] - - with patch( - 'biz_layer.mem_memorize._trigger_agent_skill_extraction', - new_callable=AsyncMock, - return_value=False, # no milvus changes - ): - # Should not try to flush - await _run_skill_extraction_for_batch( - group_id="grp-001", - drained_memcells=entries, - cluster_ids=["c1"], - config=config, - ) - # No exception = no flush attempted - - -# =========================================================================== -# flush_clustering -# =========================================================================== - -class TestFlushClustering: - - @pytest.mark.asyncio - async def test_uses_agent_default_config(self): - from biz_layer.mem_memorize import flush_clustering - - with ( - patch( - 'api_specs.id_generator.generate_single_user_group_id', - return_value="grp-user1", - ), - patch( - 'biz_layer.mem_memorize._drain_and_cluster', - new_callable=AsyncMock, - return_value=3, - ) as mock_drain, - ): - result = await flush_clustering("user1") - assert result == 3 - call_kwargs = mock_drain.call_args[1] - assert call_kwargs["force_drain"] is True - assert call_kwargs["config"] is AGENT_DEFAULT_MEMORIZE_CONFIG - - @pytest.mark.asyncio - async def test_custom_config_override(self): - from biz_layer.mem_memorize import flush_clustering - - custom = _make_config(cluster_batch_size=50) - with ( - patch( - 'api_specs.id_generator.generate_single_user_group_id', - return_value="grp-user1", - ), - patch( - 'biz_layer.mem_memorize._drain_and_cluster', - new_callable=AsyncMock, - return_value=0, - ) as mock_drain, - ): - await flush_clustering("user1", config=custom) - assert mock_drain.call_args[1]["config"] is custom - - -# =========================================================================== -# ExtractionState -# =========================================================================== - -class TestExtractionState: - - def test_episode_saved_true_when_parent_docs_map_populated(self): - mc = _make_memcell() - state = ExtractionState( - memcell=mc, - request=_make_request(), - current_time=datetime(2026, 4, 7), - scene="solo", - is_solo_scene=True, - participants=["user_001"], - ) - state.parent_docs_map["ep-001"] = MagicMock() - assert state.episode_saved is True - - def test_episode_saved_false_when_empty(self): - mc = _make_memcell() - state = ExtractionState( - memcell=mc, - request=_make_request(), - current_time=datetime(2026, 4, 7), - scene="solo", - is_solo_scene=True, - participants=["user_001"], - ) - assert state.episode_saved is False - - def test_default_parent_types_from_config(self): - mc = _make_memcell() - state = ExtractionState( - memcell=mc, - request=_make_request(), - current_time=datetime(2026, 4, 7), - scene="solo", - is_solo_scene=True, - participants=["user_001"], - ) - assert state.episode_parent_type == DEFAULT_MEMORIZE_CONFIG.default_episode_parent_type - assert state.foresight_parent_type == DEFAULT_MEMORIZE_CONFIG.default_foresight_parent_type - assert state.atomic_fact_parent_type == DEFAULT_MEMORIZE_CONFIG.default_atomic_fact_parent_type - - def test_parent_id_defaults_to_event_id(self): - mc = _make_memcell(event_id="evt-999") - state = ExtractionState( - memcell=mc, - request=_make_request(), - current_time=datetime(2026, 4, 7), - scene="solo", - is_solo_scene=True, - participants=["user_001"], - ) - assert state.parent_id == "evt-999" - - def test_post_init_creates_empty_lists(self): - mc = _make_memcell() - state = ExtractionState( - memcell=mc, - request=_make_request(), - current_time=datetime(2026, 4, 7), - scene="solo", - is_solo_scene=True, - participants=["user_001"], - ) - assert state.group_episode_memories == [] - assert state.episode_memories == [] - assert state.parent_docs_map == {} diff --git a/methods/evermemos/tests/test_openai_provider_key_rotation.py b/methods/evermemos/tests/test_openai_provider_key_rotation.py new file mode 100644 index 000000000..f2ed998ca --- /dev/null +++ b/methods/evermemos/tests/test_openai_provider_key_rotation.py @@ -0,0 +1,310 @@ +"""OpenAIProvider key rotation integration tests (mock HTTP).""" + +import pytest +import aiohttp +from unittest.mock import AsyncMock, patch + +from memory_layer.llm.openai_provider import OpenAIProvider +from memory_layer.llm.protocol import LLMError +from memory_layer.llm.api_key_rotator import ApiKeyRotator + + +@pytest.fixture(autouse=True) +def _reset_shared_rotator(): + """Ensure each test starts with a clean singleton state.""" + ApiKeyRotator._shared = None + yield + ApiKeyRotator._shared = None + + +def _success_body(content: str = "hello") -> dict: + return { + "choices": [{"message": {"content": content}, "finish_reason": "stop"}], + "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, + } + + +def _error_body(message: str = "rate limited") -> dict: + return {"error": {"message": message}} + + +class TestKeyRotationOn429: + """429: rotate key immediately, no sleep.""" + + @pytest.mark.asyncio + async def test_429_then_success_with_next_key(self) -> None: + provider = OpenAIProvider( + api_key="key-a,key-b,key-c", base_url="https://fake.api", model="test-model" + ) + + responses = [(429, _error_body()), (200, _success_body("ok"))] + call_keys: list[str] = [] + + async def capture_do_request(data: dict, api_key: str) -> tuple[int, dict]: + call_keys.append(api_key) + return responses.pop(0) + + provider._do_request = capture_do_request + + result = await provider.generate("test") + assert result == "ok" + assert len(call_keys) == 2 + assert call_keys[0] != call_keys[1] + + @pytest.mark.asyncio + async def test_all_keys_429_raises_error(self) -> None: + provider = OpenAIProvider( + api_key="key-a,key-b,key-c", base_url="https://fake.api", model="test-model" + ) + + async def always_429(data: dict, api_key: str) -> tuple[int, dict]: + return 429, _error_body("rate limited") + + provider._do_request = always_429 + + with pytest.raises(LLMError, match="3 keys exhausted"): + await provider.generate("test") + + @pytest.mark.asyncio + async def test_429_does_not_sleep(self) -> None: + provider = OpenAIProvider( + api_key="key-a,key-b", base_url="https://fake.api", model="test-model" + ) + + call_count = 0 + + async def mock_request(data: dict, api_key: str) -> tuple[int, dict]: + nonlocal call_count + call_count += 1 + if call_count == 1: + return 429, _error_body() + return 200, _success_body() + + provider._do_request = mock_request + + with patch( + "memory_layer.llm.openai_provider.asyncio.sleep", new_callable=AsyncMock + ) as mock_sleep: + await provider.generate("test") + mock_sleep.assert_not_called() + + @pytest.mark.asyncio + async def test_429_counter_not_reset_by_5xx(self) -> None: + """429 -> 5xx -> 429: consecutive_rate_limits does not reset on 5xx.""" + provider = OpenAIProvider( + api_key="key-a,key-b", base_url="https://fake.api", model="test-model" + ) + + responses = [ + (429, _error_body("rate limited")), + (502, _error_body("bad gateway")), + (429, _error_body("rate limited")), + ] + idx = 0 + + async def mock_request(data: dict, api_key: str) -> tuple[int, dict]: + nonlocal idx + resp = responses[idx] + idx += 1 + return resp + + provider._do_request = mock_request + + with patch( + "memory_layer.llm.openai_provider.asyncio.sleep", new_callable=AsyncMock + ): + with pytest.raises(LLMError, match="2 keys exhausted"): + await provider.generate("test") + + +class TestKeyRotationOn429And5xxInterleaved: + """429/5xx interleaved: sleep only on 5xx, not on 429.""" + + @pytest.mark.asyncio + async def test_429_then_5xx_sleeps_only_on_5xx(self) -> None: + """429 -> 502 -> 200: sleep called exactly once (on 502 only).""" + provider = OpenAIProvider( + api_key="key-a,key-b,key-c", base_url="https://fake.api", model="test-model" + ) + + responses = [ + (429, _error_body("rate limited")), + (502, _error_body("bad gateway")), + (200, _success_body("ok")), + ] + + async def mock_request(data: dict, api_key: str) -> tuple[int, dict]: + return responses.pop(0) + + provider._do_request = mock_request + + with patch( + "memory_layer.llm.openai_provider.asyncio.sleep", new_callable=AsyncMock + ) as mock_sleep: + result = await provider.generate("test") + assert result == "ok" + mock_sleep.assert_called_once() # only on 502, not on 429 + + +class TestRequestLevelErrors: + """400/404/422: no retry, raise immediately.""" + + @pytest.mark.asyncio + async def test_400_raises_immediately_no_retry(self) -> None: + provider = OpenAIProvider( + api_key="key-a,key-b", base_url="https://fake.api", model="test-model" + ) + + call_count = 0 + + async def mock_request(data: dict, api_key: str) -> tuple[int, dict]: + nonlocal call_count + call_count += 1 + return 400, _error_body("bad request") + + provider._do_request = mock_request + + with pytest.raises(LLMError, match="HTTP Error 400"): + await provider.generate("test") + assert call_count == 1 # no retry + + +class TestNetworkErrors: + """aiohttp.ClientError: retry up to max attempts.""" + + @pytest.mark.asyncio + async def test_client_error_retries_then_raises(self) -> None: + provider = OpenAIProvider( + api_key="key-a", base_url="https://fake.api", model="test-model" + ) + + async def always_fail(data: dict, api_key: str) -> tuple[int, dict]: + raise aiohttp.ClientError("connection reset") + + provider._do_request = always_fail + + with pytest.raises(LLMError, match="Request failed"): + await provider.generate("test") + + @pytest.mark.asyncio + async def test_client_error_then_success(self) -> None: + provider = OpenAIProvider( + api_key="key-a", base_url="https://fake.api", model="test-model" + ) + + call_count = 0 + + async def fail_then_ok(data: dict, api_key: str) -> tuple[int, dict]: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise aiohttp.ClientError("timeout") + return 200, _success_body("recovered") + + provider._do_request = fail_then_ok + + result = await provider.generate("test") + assert result == "recovered" + assert call_count == 2 + + +class TestKeyRotationOn5xx: + """5xx: sleep then retry.""" + + @pytest.mark.asyncio + async def test_5xx_retries_with_sleep(self) -> None: + provider = OpenAIProvider( + api_key="key-a,key-b", base_url="https://fake.api", model="test-model" + ) + + call_count = 0 + + async def mock_request(data: dict, api_key: str) -> tuple[int, dict]: + nonlocal call_count + call_count += 1 + if call_count == 1: + return 502, _error_body("bad gateway") + return 200, _success_body() + + provider._do_request = mock_request + + with patch( + "memory_layer.llm.openai_provider.asyncio.sleep", new_callable=AsyncMock + ) as mock_sleep: + result = await provider.generate("test") + assert result == "hello" + mock_sleep.assert_called_once() + + +class TestRetryStartsFromNextKey: + """Retries start from the key AFTER the failed one, cycling through all.""" + + @pytest.mark.asyncio + async def test_retry_uses_rotation_sequence(self) -> None: + """All attempts follow the rotation: rotation[0], [1], [2], all distinct.""" + provider = OpenAIProvider( + api_key="key-a,key-b,key-c", base_url="https://fake.api", model="test-model" + ) + + call_keys: list[str] = [] + + async def capture_request(data: dict, api_key: str) -> tuple[int, dict]: + call_keys.append(api_key) + return 429, _error_body("rate limited") + + provider._do_request = capture_request + + with pytest.raises(LLMError, match="keys exhausted"): + await provider.generate("test") + # 3 keys, 3 attempts, all distinct, in rotation order + assert len(call_keys) == 3 + assert len(set(call_keys)) == 3 + first = call_keys[0] + keys = ["key-a", "key-b", "key-c"] + first_idx = keys.index(first) + assert call_keys[1] == keys[(first_idx + 1) % 3] + assert call_keys[2] == keys[(first_idx + 2) % 3] + + @pytest.mark.asyncio + async def test_keys_repeat_after_full_cycle(self) -> None: + """With 2 keys and 5 retries (5xx), keys alternate without adjacent repeats.""" + provider = OpenAIProvider( + api_key="key-a,key-b", base_url="https://fake.api", model="test-model" + ) + + call_keys: list[str] = [] + + async def capture_request(data: dict, api_key: str) -> tuple[int, dict]: + call_keys.append(api_key) + return 502, _error_body("bad gateway") + + provider._do_request = capture_request + + with patch( + "memory_layer.llm.openai_provider.asyncio.sleep", new_callable=AsyncMock + ): + with pytest.raises(LLMError, match="after 5 retries"): + await provider.generate("test") + assert len(call_keys) == 5 + # Adjacent keys always differ + for i in range(len(call_keys) - 1): + assert call_keys[i] != call_keys[i + 1] + + +class TestSingleKeyBackwardCompat: + """Single key: behavior unchanged.""" + + @pytest.mark.asyncio + async def test_single_key_works_normally(self) -> None: + provider = OpenAIProvider( + api_key="single-key", base_url="https://fake.api", model="test-model" + ) + + async def mock_request(data: dict, api_key: str) -> tuple[int, dict]: + assert api_key == "single-key" + return 200, _success_body("single key response") + + provider._do_request = mock_request + + result = await provider.generate("test") + assert result == "single key response" diff --git a/methods/evermemos/tests/test_profile_extraction_interval.py b/methods/evermemos/tests/test_profile_extraction_interval.py index e9dc1fc6c..32ded293b 100644 --- a/methods/evermemos/tests/test_profile_extraction_interval.py +++ b/methods/evermemos/tests/test_profile_extraction_interval.py @@ -106,6 +106,35 @@ def test_backward_compat_old_format(self): assert state.cluster_counts == {"cluster_000": 1} assert state.cluster_last_ts == {"cluster_000": 100.0} + def test_from_dict_case_cluster_ids_none(self): + """case_cluster_ids=None in DB should deserialize to empty set.""" + data = { + "memcell_info": {}, + "memscene_info": {}, + "next_cluster_idx": 0, + "case_cluster_ids": None, + } + state = MemSceneState.from_dict(data) + assert state.case_cluster_ids == set() + + def test_from_dict_case_cluster_ids_missing(self): + """Missing case_cluster_ids key should deserialize to empty set.""" + data = { + "memcell_info": {}, + "memscene_info": {}, + "next_cluster_idx": 0, + } + state = MemSceneState.from_dict(data) + assert state.case_cluster_ids == set() + + def test_from_dict_case_cluster_ids_with_values(self): + """case_cluster_ids with values should roundtrip correctly.""" + state = MemSceneState() + state.case_cluster_ids = {"cluster_000", "cluster_001"} + d = state.to_dict() + restored = MemSceneState.from_dict(d) + assert restored.case_cluster_ids == {"cluster_000", "cluster_001"} + class TestIntervalLogic: """Interval skip/trigger decision logic.""" diff --git a/methods/evermemos/tests/test_tenant_cache_utils.py b/methods/evermemos/tests/test_tenant_cache_utils.py new file mode 100644 index 000000000..bb246e361 --- /dev/null +++ b/methods/evermemos/tests/test_tenant_cache_utils.py @@ -0,0 +1,156 @@ +""" +Test: TenantContextMissingError propagation in tenant_cache_utils + +Verifies that when app is ready but tenant context is missing, +TenantContextMissingError is raised instead of silently falling back. + +Run: + PYTHONPATH=src uv run pytest tests/test_tenant_cache_utils.py -v +""" + +from unittest.mock import patch, MagicMock + +import pytest + +from core.constants.exceptions import CriticalError +from core.tenants.tenant_models import TenantPatchKey +from core.tenants.tenantize.tenant_cache_utils import ( + get_or_compute_tenant_cache, + TenantContextMissingError, +) + + +@pytest.fixture +def mock_app_ready(): + """Mock tenant config with app_ready=True and no tenant context.""" + config = MagicMock() + config.app_ready = True + with ( + patch( + "core.tenants.tenantize.tenant_cache_utils.get_tenant_config", + return_value=config, + ), + patch( + "core.tenants.tenantize.tenant_cache_utils.get_current_tenant", + return_value=None, + ), + ): + yield config + + +@pytest.fixture +def mock_app_not_ready(): + """Mock tenant config with app_ready=False and no tenant context.""" + config = MagicMock() + config.app_ready = False + with ( + patch( + "core.tenants.tenantize.tenant_cache_utils.get_tenant_config", + return_value=config, + ), + patch( + "core.tenants.tenantize.tenant_cache_utils.get_current_tenant", + return_value=None, + ), + ): + yield config + + +class TestTenantContextMissingError: + """Test that strict tenant check raises TenantContextMissingError after app startup.""" + + def test_app_ready_no_tenant_raises_error(self, mock_app_ready): + """When app is ready and tenant context is missing, should raise even with fallback.""" + with pytest.raises( + TenantContextMissingError, match="Strict tenant check failed" + ): + get_or_compute_tenant_cache( + patch_key=TenantPatchKey.MILVUS_CONNECTION_CACHE_KEY, + compute_func=lambda: "computed", + fallback="default", + cache_description="test cache", + ) + + def test_app_ready_no_tenant_raises_error_callable_fallback(self, mock_app_ready): + """Callable fallback should not be invoked when strict check fails.""" + fallback_called = False + + def fallback_func(): + nonlocal fallback_called + fallback_called = True + return "fallback_value" + + with pytest.raises(TenantContextMissingError): + get_or_compute_tenant_cache( + patch_key=TenantPatchKey.MILVUS_CONNECTION_CACHE_KEY, + compute_func=lambda: "computed", + fallback=fallback_func, + cache_description="test cache", + ) + + assert ( + not fallback_called + ), "Fallback should not be called when strict check fails" + + def test_app_not_ready_uses_fallback(self, mock_app_not_ready): + """During startup (app not ready), should use fallback instead of raising.""" + result = get_or_compute_tenant_cache( + patch_key=TenantPatchKey.MILVUS_CONNECTION_CACHE_KEY, + compute_func=lambda: "computed", + fallback="default", + cache_description="test cache", + ) + assert result == "default" + + def test_app_not_ready_no_fallback_raises_runtime_error(self, mock_app_not_ready): + """During startup with no fallback, should raise RuntimeError (not TenantContextMissingError).""" + with pytest.raises(RuntimeError, match="no fallback provided"): + get_or_compute_tenant_cache( + patch_key=TenantPatchKey.MILVUS_CONNECTION_CACHE_KEY, + compute_func=lambda: "computed", + fallback=None, + cache_description="test cache", + ) + + def test_error_inherits_critical_error(self): + """TenantContextMissingError should be a CriticalError (and thus Exception).""" + assert issubclass(TenantContextMissingError, CriticalError) + assert issubclass(TenantContextMissingError, Exception) + + def test_error_not_swallowed_by_except_exception_in_cache_func( + self, mock_app_ready + ): + """The outer except Exception in get_or_compute_tenant_cache should not swallow it.""" + with pytest.raises(TenantContextMissingError): + get_or_compute_tenant_cache( + patch_key=TenantPatchKey.MILVUS_CONNECTION_CACHE_KEY, + compute_func=lambda: "computed", + fallback="default", + cache_description="test cache", + ) + + +class TestReraiseGatherCriticalErrors: + """Test that reraise_critical_errors works with asyncio.gather patterns.""" + + def test_reraise_critical_error(self): + """CriticalError in gather results should be re-raised.""" + from common_utils.async_utils import reraise_critical_errors + + error = TenantContextMissingError("tenant missing") + results = ["ok", error, "also ok"] + with pytest.raises(TenantContextMissingError, match="tenant missing"): + reraise_critical_errors(results) + + def test_regular_exceptions_not_reraised(self): + """Regular Exception in gather results should NOT be re-raised.""" + from common_utils.async_utils import reraise_critical_errors + + results = ["ok", ValueError("some error"), "also ok"] + reraise_critical_errors(results) # Should not raise + + def test_gather_isinstance_exception_still_matches(self): + """isinstance(error, Exception) should be True — CriticalError IS an Exception.""" + error = TenantContextMissingError("test") + assert isinstance(error, Exception) + assert isinstance(error, CriticalError) From 815679a9279881f999fd61963983d774c7cd82ca Mon Sep 17 00:00:00 2001 From: Yan Xiao Date: Tue, 14 Apr 2026 11:33:23 +0800 Subject: [PATCH 2/3] v0 -> v1 interface --- .../demo/utils/simple_memory_manager.py | 69 ++++++++++--------- 1 file changed, 36 insertions(+), 33 deletions(-) diff --git a/methods/evermemos/demo/utils/simple_memory_manager.py b/methods/evermemos/demo/utils/simple_memory_manager.py index f7b9499fd..5f88d000c 100644 --- a/methods/evermemos/demo/utils/simple_memory_manager.py +++ b/methods/evermemos/demo/utils/simple_memory_manager.py @@ -80,6 +80,7 @@ def __init__( base_url: str = "http://localhost:1995", group_id: str = "default_group", scene: str = ScenarioType.SOLO.value, + user_id: str = "demo_user", ): """Initialize the manager @@ -87,11 +88,13 @@ def __init__( base_url: API server address (default: localhost:1995) group_id: Group ID (default: default_group) scene: Scene type (default: "solo", options: "solo" or "team") + user_id: User ID for personal endpoint (default: "demo_user") """ self.base_url = base_url self.group_id = group_id self.group_name = "Simple Demo Group" self.scene = scene + self.user_id = user_id self.memorize_url = f"{base_url}/api/v1/memories" self.retrieve_url = f"{base_url}/api/v1/memories/search" self.settings_url = f"{base_url}/api/v1/settings" @@ -119,29 +122,32 @@ async def store(self, content: str, sender: str = "User") -> bool: ) # Use project's unified time utility (with timezone) message_id = f"msg_{self._message_counter}_{int(now.timestamp() * 1000)}" - # Build message data (completely consistent with test_v1api_search.py format) - message_data = { + # Build v1 PersonalAddRequest payload + role = "user" if sender.lower() == "user" else "assistant" + message_item = { "message_id": message_id, - "create_time": to_iso_format( - now - ), # Use project's unified time formatting (with timezone) - "sender": sender, - "sender_name": sender, # Consistent with JSON data format - "type": "text", # Message type + "sender_id": self.user_id if role == "user" else sender, + "sender_name": sender, + "role": role, + "timestamp": int(now.timestamp() * 1000), "content": content, - "group_id": self.group_id, - "group_name": self.group_name, - "scene": self.scene, # Use configured scene + } + payload = { + "user_id": self.user_id, + "messages": [message_item], } try: async with httpx.AsyncClient(timeout=500.0) as client: - response = await client.post(self.memorize_url, json=message_data) + response = await client.post(self.memorize_url, json=payload) response.raise_for_status() result = response.json() - if result.get("status") == "ok": - count = result.get("result", {}).get("count", 0) + # v1 response: {"data": {"status": "...", "count": N, ...}} + data = result.get("data", {}) + status = data.get("status", "") + count = data.get("count", 0) + if status: if count > 0: print( f" ✅ Stored: {content[:40]}... (Extracted {count} memories)" @@ -200,50 +206,47 @@ async def _init_settings(self) -> bool: return False async def search( - self, query: str, top_k: int = 3, mode: str = "rrf", show_details: bool = True + self, query: str, top_k: int = 3, mode: str = "vector", show_details: bool = True ) -> List[Dict[str, Any]]: """Search memories Args: query: Query text top_k: Number of results to return (default: 3) - mode: Retrieval mode (default: "rrf") - - "rrf": RRF fusion (recommended) + mode: - "keyword": Keyword retrieval (BM25) - "vector": Vector retrieval - "hybrid": Keyword + Vector + Rerank - - "rrf": Keyword + Vector + RRF fusion - "agentic": LLM-guided multi-round retrieval show_details: Whether to show detailed information (default: True) Returns: List of memories """ + # v1 SearchMemoriesRequest: POST with body {query, method, memory_types, top_k, filters} payload = { "query": query, + "method": mode, + "memory_types": ["episodic_memory"], "top_k": top_k, - "memory_types": "episodic_memory", - "retrieve_method": mode, - "group_id": self.group_id, + "filters": {"user_id": self.user_id}, } try: async with httpx.AsyncClient(timeout=30.0) as client: - response = await client.get(self.retrieve_url, params=payload) + response = await client.post(self.retrieve_url, json=payload) response.raise_for_status() result = response.json() - if result.get("status") == "ok": - # memories is grouped: [{"group_id": [Memory, ...]}, ...] - raw_memories = result.get("result", {}).get("memories", []) - metadata = result.get("result", {}).get("metadata", {}) - latency = metadata.get("total_latency_ms", 0) - - # Flatten grouped memories to flat list + # v1 response: {"data": {"episodes": [...], "profiles": [...], "raw_messages": [...], "agent_memory": ...}} + data = result.get("data", {}) + if data: + # Aggregate across memory_type buckets (we only requested episodic_memory here) memories = [] - for group_dict in raw_memories: - for group_id, mem_list in group_dict.items(): - memories.extend(mem_list) + for key in ("episodes", "profiles", "raw_messages"): + memories.extend(data.get(key) or []) + metadata = data.get("metadata", {}) or {} + latency = metadata.get("total_latency_ms", 0) if show_details: print( @@ -336,4 +339,4 @@ def print_summary(self): print(" - ❌ Won't extract: Too brief, low-information small talk") print( " - 🎯 Best practice: Multi-turn conversations, rich context, specific details" - ) + ) \ No newline at end of file From cb7e242b3925076ee70f0a8df01ac5661a99a59e Mon Sep 17 00:00:00 2001 From: Yan Xiao Date: Tue, 14 Apr 2026 11:48:33 +0800 Subject: [PATCH 3/3] Update comment --- methods/evermemos/env.template | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/methods/evermemos/env.template b/methods/evermemos/env.template index ee8beed18..c9495b778 100755 --- a/methods/evermemos/env.template +++ b/methods/evermemos/env.template @@ -221,7 +221,7 @@ AGENTIC_ROUND1_RERANK_TOP_N=10 # =================== # Controls which MemorizeConfig is used for agent conversations. -# - online: full pipeline (default) +# - online: full pipeline, fast skill search (default) # - fast_skill: skip profile/foresight/eventlog, skip maturity scoring AGENT_MEMORIZE_MODE=online