Skip to content

Is profile not supported as expected? #158

@AllentDan

Description

@AllentDan

gentic_layer.memory_manager - ERROR - memory_manager.py:618 - Error in get_vector_search_results: Unsupported memory type: MemoryType.PROFILE

Claude gave me the following modification to fix it, in src/agentic_layer/memory_manager.py.

from __future__ import annotations

from typing import Any, List, Optional, Tuple
import logging
import asyncio

from datetime import datetime, timedelta
import jieba
import numpy as np
import time
from typing import Dict, Any
from dataclasses import dataclass

from api_specs.memory_types import (
    BaseMemory,
    EpisodeMemory,
    EventLog,
    Foresight,
    RawDataType,
)
from biz_layer.mem_memorize import memorize
from api_specs.dtos import MemorizeRequest
from .fetch_mem_service import get_fetch_memory_service
from api_specs.dtos import (
    FetchMemRequest,
    FetchMemResponse,
    PendingMessage,
    RetrieveMemRequest,
    RetrieveMemResponse,
)
from api_specs.memory_models import Metadata
from core.di import get_bean_by_type
from core.oxm.constants import MAGIC_ALL
from infra_layer.adapters.out.search.repository.episodic_memory_es_repository import (
    EpisodicMemoryEsRepository,
)
from infra_layer.adapters.out.search.repository.foresight_es_repository import (
    ForesightEsRepository,
)
from infra_layer.adapters.out.search.repository.event_log_es_repository import (
    EventLogEsRepository,
)
from core.observation.tracing.decorators import trace_logger
from core.nlp.stopwords_utils import filter_stopwords
from common_utils.datetime_utils import (
    from_iso_format,
    get_now_with_timezone,
    to_iso_format,
)
from infra_layer.adapters.out.persistence.repository.memcell_raw_repository import (
    MemCellRawRepository,
)
from service.memory_request_log_service import MemoryRequestLogService
from infra_layer.adapters.out.persistence.repository.group_user_profile_memory_raw_repository import (
    GroupUserProfileMemoryRawRepository,
)
from infra_layer.adapters.out.persistence.document.memory.memcell import DataTypeEnum
from infra_layer.adapters.out.persistence.document.memory.user_profile import (
    UserProfile,
)
from infra_layer.adapters.out.search.repository.episodic_memory_milvus_repository import (
    EpisodicMemoryMilvusRepository,
)
from infra_layer.adapters.out.search.repository.foresight_milvus_repository import (
    ForesightMilvusRepository,
)
from infra_layer.adapters.out.search.repository.event_log_milvus_repository import (
    EventLogMilvusRepository,
)
from .vectorize_service import get_vectorize_service
from .rerank_service import get_rerank_service
from api_specs.memory_models import MemoryType, RetrieveMethod
from agentic_layer.metrics.retrieve_metrics import (
    record_retrieve_request,
    record_retrieve_stage,
    record_retrieve_error,
)
import os
from memory_layer.llm.llm_provider import LLMProvider
from agentic_layer.agentic_utils import (
    AgenticConfig,
    check_sufficiency,
    generate_multi_queries,
)
from agentic_layer.retrieval_utils import reciprocal_rank_fusion

# Module-level logger for this memory manager.
logger = logging.getLogger(__name__)


# MemoryType -> ES Repository mapping.
# NOTE: PROFILE / GROUP_PROFILE are intentionally absent — profile memories
# are stored in MongoDB (see get_vector_search_results), so keyword search
# treats them as unsupported and returns no hits.
ES_REPO_MAP = {
    MemoryType.FORESIGHT: ForesightEsRepository,
    MemoryType.EVENT_LOG: EventLogEsRepository,
    MemoryType.EPISODIC_MEMORY: EpisodicMemoryEsRepository,
}


@dataclass
class EventLogCandidate:
    """Event Log candidate object (used for retrieval from atomic_fact).

    Lightweight carrier for one event-log hit prior to scoring/merging.
    """

    event_id: str  # identifier of the source event
    user_id: str  # owning user
    group_id: str  # owning group / conversation
    timestamp: datetime  # event time (naive vs. aware not visible here — TODO confirm upstream)
    episode: str  # atomic_fact content
    summary: str  # short textual summary of the event
    subject: str  # subject of the event
    extend: dict  # extra payload; contains embedding


class MemoryManager:
    """Unified memory interface.

    Provides the following main functions:
    - memorize: Accept raw data and persistently store
    - fetch_mem: Retrieve memory fields by key, supports multiple memory types
    - retrieve_mem: Memory reading based on prompt-based retrieval methods
    """

    def __init__(self) -> None:
        """Resolve the services this manager delegates to via the DI container."""
        # Get memory service instance (backs fetch_mem)
        self._fetch_service = get_fetch_memory_service()
        # Backs _get_pending_messages (unconsumed cached messages)
        self._request_log_service: MemoryRequestLogService = get_bean_by_type(
            MemoryRequestLogService
        )

        # NOTE(review): the message mentions "retrieve_mem_service", but retrieval
        # is implemented directly on this class — consider rewording.
        logger.info(
            "MemoryManager initialized with fetch_mem_service and retrieve_mem_service"
        )

    # --------- Write path (raw data -> memorize) ---------
    @trace_logger(operation_name="agentic_layer memory storage")
    async def memorize(self, memorize_request: MemorizeRequest) -> int:
        """Persist a batch of heterogeneous raw items as memories.

        Each item may be one of the typed raw dataclasses (ChatRawData /
        EmailRawData / MemoRawData / LincDocRawData) or any dict-like object;
        every item is stored as a MemoryCell with a synthetic key.

        Returns:
            int: Number of memories extracted (0 if no boundary detected)
        """
        return await memorize(memorize_request)

    # --------- Read path (query -> fetch_mem) ---------
    # Memory reading based on key-value, including static and dynamic memory
    @trace_logger(operation_name="agentic_layer memory reading")
    async def fetch_mem(self, request: FetchMemRequest) -> FetchMemResponse:
        """Look up stored memories by user/group/type and time window.

        Args:
            request: FetchMemRequest describing the query parameters.

        Returns:
            FetchMemResponse containing the matching memories and metadata.
        """
        logger.debug(
            f"fetch_mem called with request: user_id={request.user_id}, group_id={request.group_id}, "
            f"memory_type={request.memory_type}, time_range=[{request.start_time}, {request.end_time}]"
        )

        # Delegate the lookup to the fetch service; the repository defaults to
        # MemoryType.EPISODIC_MEMORY when applicable.
        lookup_args = dict(
            user_id=request.user_id,
            memory_type=request.memory_type,
            group_id=request.group_id,
            start_time=request.start_time,
            end_time=request.end_time,
            version_range=request.version_range,
            limit=request.limit,
        )
        response = await self._fetch_service.find_memories(**lookup_args)

        # response.metadata is already fully populated by fetch_mem_service
        # (source, user_id, memory_type, limit, email, phone, full_name) —
        # nothing to patch up here.
        logger.debug(
            f"fetch_mem returned {len(response.memories)} memories for user {request.user_id}"
        )
        return response

    # Memory reading based on retrieve_method, including static and dynamic memory
    @trace_logger(operation_name="agentic_layer memory retrieval")
    async def retrieve_mem(
        self, retrieve_mem_request: 'RetrieveMemRequest'
    ) -> RetrieveMemResponse:
        """Retrieve memory data, dispatching to different retrieval methods based on retrieve_method

        Args:
            retrieve_mem_request: RetrieveMemRequest containing retrieval parameters

        Returns:
            RetrieveMemResponse containing retrieval results
        """
        try:
            # Validate request parameters
            if not retrieve_mem_request:
                raise ValueError("retrieve_mem_request is required for retrieve_mem")

            # Dispatch based on retrieve_method
            retrieve_method = retrieve_mem_request.retrieve_method

            logger.info(
                f"retrieve_mem dispatching request: user_id={retrieve_mem_request.user_id}, "
                f"retrieve_method={retrieve_method}, query={retrieve_mem_request.query}"
            )

            # Create task to fetch pending messages concurrently
            pending_messages_task = asyncio.create_task(
                self._get_pending_messages(
                    user_id=retrieve_mem_request.user_id,
                    group_id=retrieve_mem_request.group_id,
                )
            )

            # Dispatch based on retrieval method
            match retrieve_method:
                case RetrieveMethod.KEYWORD:
                    response = await self.retrieve_mem_keyword(retrieve_mem_request)
                case RetrieveMethod.VECTOR:
                    response = await self.retrieve_mem_vector(retrieve_mem_request)
                case RetrieveMethod.HYBRID:
                    response = await self.retrieve_mem_hybrid(retrieve_mem_request)
                case RetrieveMethod.RRF:
                    response = await self.retrieve_mem_rrf(retrieve_mem_request)
                case RetrieveMethod.AGENTIC:
                    response = await self.retrieve_mem_agentic(retrieve_mem_request)
                case _:
                    raise ValueError(f"Unsupported retrieval method: {retrieve_method}")

            # Await pending messages and attach to response
            pending_messages = await pending_messages_task
            response.pending_messages = pending_messages

            return response

        except Exception as e:
            logger.error(f"Error in retrieve_mem: {e}", exc_info=True)
            return RetrieveMemResponse(
                memories=[],
                original_data=[],
                scores=[],
                importance_scores=[],
                total_count=0,
                has_more=False,
                query_metadata=Metadata(
                    source="retrieve_mem_service",
                    user_id=(
                        retrieve_mem_request.user_id if retrieve_mem_request else ""
                    ),
                    memory_type="retrieve",
                ),
                metadata=Metadata(
                    source="retrieve_mem_service",
                    user_id=(
                        retrieve_mem_request.user_id if retrieve_mem_request else ""
                    ),
                    memory_type="retrieve",
                ),
                pending_messages=[],
            )

    async def _get_pending_messages(
        self, user_id: Optional[str] = None, group_id: Optional[str] = None
    ) -> List[PendingMessage]:
        """Fetch not-yet-consumed cached messages (sync_status=-1 or 0).

        Best-effort: failures are logged and an empty list is returned so the
        retrieval response can always carry a pending_messages field.

        Args:
            user_id: User ID filter (from retrieve_request)
            group_id: Group ID filter (from retrieve_request)

        Returns:
            List of PendingMessage objects (possibly empty).
        """
        try:
            pending = await self._request_log_service.get_pending_messages(
                user_id=user_id, group_id=group_id, limit=1000
            )
        except Exception as e:
            logger.error(f"Error fetching pending messages: {e}", exc_info=True)
            return []

        logger.debug(
            f"Retrieved {len(pending)} pending messages: "
            f"user_id={user_id}, group_id={group_id}"
        )
        return pending

    # Keyword retrieval method (original retrieve_mem logic)
    @trace_logger(operation_name="agentic_layer keyword memory retrieval")
    async def retrieve_mem_keyword(
        self, retrieve_mem_request: 'RetrieveMemRequest'
    ) -> RetrieveMemResponse:
        """Keyword-based memory retrieval"""
        started = time.perf_counter()
        memory_type = (
            retrieve_mem_request.memory_types[0].value
            if retrieve_mem_request.memory_types
            else 'unknown'
        )

        def _record(status: str, count: int) -> None:
            # Request-level metric stamped with elapsed wall time.
            record_retrieve_request(
                memory_type=memory_type,
                retrieve_method=RetrieveMethod.KEYWORD.value,
                status=status,
                duration_seconds=time.perf_counter() - started,
                results_count=count,
            )

        try:
            hits = await self.get_keyword_search_results(
                retrieve_mem_request, retrieve_method=RetrieveMethod.KEYWORD.value
            )
            _record('success' if hits else 'empty_result', len(hits))
            return await self._to_response(hits, retrieve_mem_request)
        except Exception as e:
            _record('error', 0)
            logger.error(f"Error in retrieve_mem_keyword: {e}", exc_info=True)
            return await self._to_response([], retrieve_mem_request)

    async def get_keyword_search_results(
        self,
        retrieve_mem_request: 'RetrieveMemRequest',
        retrieve_method: str = RetrieveMethod.KEYWORD.value,
    ) -> List[Dict[str, Any]]:
        """Keyword (Elasticsearch) search with stage-level metrics.

        Tokenizes the query with jieba (search mode), filters stopwords,
        dispatches to the ES repository mapped to the first requested memory
        type, and normalizes hits (ES '_id'/'_score' -> 'id'/'score').

        Args:
            retrieve_mem_request: Retrieval parameters (query, user/group,
                time window, top_k, memory_types).
            retrieve_method: Metric label for the calling retrieval strategy.

        Returns:
            List of normalized hit dicts; [] when no memory type was requested
            or the type has no ES repository (e.g. PROFILE, MongoDB-backed).

        Raises:
            ValueError: If retrieve_mem_request is falsy.
            Exception: Search errors are re-raised after metrics are recorded.
        """
        stage_start = time.perf_counter()
        # Fix: guard the request here too, so the ValueError below is reachable
        # (previously `.memory_types` was dereferenced before the check).
        memory_type = (
            retrieve_mem_request.memory_types[0].value
            if retrieve_mem_request and retrieve_mem_request.memory_types
            else 'unknown'
        )

        try:
            # Get parameters from Request
            if not retrieve_mem_request:
                raise ValueError("retrieve_mem_request is required for retrieve_mem")

            top_k = retrieve_mem_request.top_k
            query = retrieve_mem_request.query
            user_id = retrieve_mem_request.user_id
            group_id = retrieve_mem_request.group_id
            start_time = retrieve_mem_request.start_time
            end_time = retrieve_mem_request.end_time
            memory_types = retrieve_mem_request.memory_types

            # Fix: empty memory_types previously raised IndexError below.
            if not memory_types:
                logger.warning(
                    "get_keyword_search_results called without memory_types"
                )
                return []

            # Convert query string to search word list:
            # jieba search-mode segmentation, then stopword filtering.
            if query:
                raw_words = list(jieba.cut_for_search(query))
                query_words = filter_stopwords(raw_words, min_length=2)
            else:
                query_words = []

            logger.debug(f"query_words: {query_words}")

            # Build time range filter conditions, handle None values
            date_range = {}
            if start_time is not None:
                date_range["gte"] = start_time
            if end_time is not None:
                date_range["lte"] = end_time

            mem_type = memory_types[0]

            repo_class = ES_REPO_MAP.get(mem_type)
            if not repo_class:
                # PROFILE/GROUP_PROFILE (MongoDB-backed) and unknown types land here.
                logger.warning(f"Unsupported memory_type: {mem_type}")
                return []

            es_repo = get_bean_by_type(repo_class)
            logger.debug(f"Using {repo_class.__name__} for {mem_type}")

            results = await es_repo.multi_search(
                query=query_words,
                user_id=user_id,
                group_id=group_id,
                size=top_k,
                from_=0,
                date_range=date_range,
            )

            # Mark memory_type, search_source, and unified score
            if results:
                for r in results:
                    r['memory_type'] = mem_type.value
                    r['_search_source'] = RetrieveMethod.KEYWORD.value
                    r['id'] = r.get('_id', '')  # Unify ES '_id' to 'id'
                    r['score'] = r.get('_score', 0.0)  # Unified score field

            # Record stage metrics
            record_retrieve_stage(
                retrieve_method=retrieve_method,
                stage=RetrieveMethod.KEYWORD.value,
                memory_type=memory_type,
                duration_seconds=time.perf_counter() - stage_start,
            )

            return results or []
        except Exception as e:
            record_retrieve_stage(
                retrieve_method=retrieve_method,
                stage=RetrieveMethod.KEYWORD.value,
                memory_type=memory_type,
                duration_seconds=time.perf_counter() - stage_start,
            )
            record_retrieve_error(
                retrieve_method=retrieve_method,
                stage=RetrieveMethod.KEYWORD.value,
                error_type=self._classify_retrieve_error(e),
            )
            # Fix: keep the traceback (consistent with retrieve_mem_keyword).
            logger.error(f"Error in get_keyword_search_results: {e}", exc_info=True)
            raise

    # Vector-based memory retrieval
    @trace_logger(operation_name="agentic_layer vector memory retrieval")
    async def retrieve_mem_vector(
        self, retrieve_mem_request: 'RetrieveMemRequest'
    ) -> RetrieveMemResponse:
        """Vector-based memory retrieval.

        Delegates to get_vector_search_results, records request-level metrics,
        and falls back to an empty response on failure (never raises).
        """
        start_time = time.perf_counter()
        memory_type = (
            retrieve_mem_request.memory_types[0].value
            if retrieve_mem_request.memory_types
            else 'unknown'
        )

        try:
            hits = await self.get_vector_search_results(
                retrieve_mem_request, retrieve_method=RetrieveMethod.VECTOR.value
            )
            duration = time.perf_counter() - start_time
            status = 'success' if hits else 'empty_result'

            record_retrieve_request(
                memory_type=memory_type,
                retrieve_method=RetrieveMethod.VECTOR.value,
                status=status,
                duration_seconds=duration,
                results_count=len(hits),
            )

            return await self._to_response(hits, retrieve_mem_request)
        except Exception as e:
            duration = time.perf_counter() - start_time
            record_retrieve_request(
                memory_type=memory_type,
                retrieve_method=RetrieveMethod.VECTOR.value,
                status='error',
                duration_seconds=duration,
                results_count=0,
            )
            # Fix: exc_info added for parity with retrieve_mem_keyword/_rrf so
            # the traceback isn't lost.
            logger.error(f"Error in retrieve_mem_vector: {e}", exc_info=True)
            return await self._to_response([], retrieve_mem_request)

    async def get_vector_search_results(
        self,
        retrieve_mem_request: 'RetrieveMemRequest',
        retrieve_method: str = RetrieveMethod.VECTOR.value,
    ) -> List[Dict[str, Any]]:
        """Vector search with stage-level metrics (embedding + milvus_search)"""
        memory_type = (
            retrieve_mem_request.memory_types[0].value
            if retrieve_mem_request.memory_types
            else 'unknown'
        )

        milvus_start = time.perf_counter()  # init early so except handler never fails
        try:
            # Get parameters from Request
            logger.debug(
                f"get_vector_search_results called with retrieve_mem_request: {retrieve_mem_request}"
            )
            if not retrieve_mem_request:
                raise ValueError(
                    "retrieve_mem_request is required for get_vector_search_results"
                )
            query = retrieve_mem_request.query
            if not query:
                raise ValueError("query is required for retrieve_mem_vector")

            user_id = retrieve_mem_request.user_id
            group_id = retrieve_mem_request.group_id
            top_k = retrieve_mem_request.top_k
            start_time = retrieve_mem_request.start_time
            end_time = retrieve_mem_request.end_time
            mem_type = retrieve_mem_request.memory_types[0]

            logger.debug(
                f"retrieve_mem_vector called with query: {query}, user_id: {user_id}, group_id: {group_id}, top_k: {top_k}"
            )

            # Get vectorization service
            vectorize_service = get_vectorize_service()

            # Convert query text to vector (embedding stage)
            logger.debug(f"Starting to vectorize query text: {query}")
            embedding_start = time.perf_counter()
            query_vector = await vectorize_service.get_embedding(query)
            query_vector_list = query_vector.tolist()  # Convert to list format
            record_retrieve_stage(
                retrieve_method=retrieve_method,
                stage='embedding',
                memory_type=memory_type,
                duration_seconds=time.perf_counter() - embedding_start,
            )
            logger.debug(
                f"Query text vectorization completed, vector dimension: {len(query_vector_list)}"
            )

            # Select Milvus repository based on memory type
            # Note: PROFILE and GROUP_PROFILE are stored in MongoDB, not ES/Milvus
            match mem_type:
                case MemoryType.FORESIGHT:
                    milvus_repo = get_bean_by_type(ForesightMilvusRepository)
                case MemoryType.EVENT_LOG:
                    milvus_repo = get_bean_by_type(EventLogMilvusRepository)
                case MemoryType.EPISODIC_MEMORY:
                    milvus_repo = get_bean_by_type(EpisodicMemoryMilvusRepository)
                case MemoryType.PROFILE | MemoryType.GROUP_PROFILE:
                    # Profile memories are stored in MongoDB, not searchable via vector search
                    logger.warning(f"Memory type {mem_type} is not supported for vector search")
                    return []
                case _:
                    logger.warning(f"Unsupported memory type for vector search: {mem_type}")
                    return []

            # Handle time range filter conditions
            start_time_dt = None
            end_time_dt = None
            current_time_dt = None

            if start_time is not None:
                start_time_dt = (
                    from_iso_format(start_time)
                    if isinstance(start_time, str)
                    else start_time
                )

            if end_time is not None:
                if isinstance(end_time, str):
                    end_time_dt = from_iso_format(end_time)
                    # If date only format, set to end of day
                    if len(end_time) == 10:
                        end_time_dt = end_time_dt.replace(hour=23, minute=59, second=59)
                else:
                    end_time_dt = end_time

            # Handle foresight time range (only valid for foresight)
            if mem_type == MemoryType.FORESIGHT:
                if retrieve_mem_request.start_time:
                    start_time_dt = from_iso_format(retrieve_mem_request.start_time)
                if retrieve_mem_request.end_time:
                    end_time_dt = from_iso_format(retrieve_mem_request.end_time)
                if retrieve_mem_request.current_time:
                    current_time_dt = from_iso_format(retrieve_mem_request.current_time)

            # Call Milvus vector search (pass different parameters based on memory type)
            milvus_start = time.perf_counter()
            if mem_type == MemoryType.FORESIGHT:
                # Foresight: supports time range and validity filtering, supports radius parameter
                search_results = await milvus_repo.vector_search(
                    query_vector=query_vector_list,
                    user_id=user_id,
                    group_id=group_id,
                    start_time=start_time_dt,
                    end_time=end_time_dt,
                    current_time=current_time_dt,
                    limit=top_k,
                    score_threshold=0.0,
                    radius=retrieve_mem_request.radius,
                )
            else:
                # Episodic memory and event log: use timestamp filtering, supports radius parameter
                search_results = await milvus_repo.vector_search(
                    query_vector=query_vector_list,
                    user_id=user_id,
                    group_id=group_id,
                    start_time=start_time_dt,
                    end_time=end_time_dt,
                    limit=top_k,
                    score_threshold=0.0,
                    radius=retrieve_mem_request.radius,
                )
            record_retrieve_stage(
                retrieve_method=retrieve_method,
                stage='milvus_search',
                memory_type=memory_type,
                duration_seconds=time.perf_counter() - milvus_start,
            )

            for r in search_results:
                r['memory_type'] = mem_type.value
                r['_search_source'] = RetrieveMethod.VECTOR.value
                # Milvus already uses 'score', no need to rename

            return search_results
        except Exception as e:
            record_retrieve_stage(
                retrieve_method=retrieve_method,
                stage=RetrieveMethod.VECTOR.value,
                memory_type=memory_type,
                duration_seconds=time.perf_counter() - milvus_start,
            )
            record_retrieve_error(
                retrieve_method=retrieve_method,
                stage=RetrieveMethod.VECTOR.value,
                error_type=self._classify_retrieve_error(e),
            )
            logger.error(f"Error in get_vector_search_results: {e}")
            raise

    # Hybrid memory retrieval
    @trace_logger(operation_name="agentic_layer hybrid memory retrieval")
    async def retrieve_mem_hybrid(
        self, retrieve_mem_request: 'RetrieveMemRequest'
    ) -> RetrieveMemResponse:
        """Hybrid memory retrieval: keyword + vector + rerank.

        Delegates to _search_hybrid, records request-level metrics, and falls
        back to an empty response on failure (never raises).
        """
        start_time = time.perf_counter()
        memory_type = (
            retrieve_mem_request.memory_types[0].value
            if retrieve_mem_request.memory_types
            else 'unknown'
        )

        try:
            hits = await self._search_hybrid(
                retrieve_mem_request, retrieve_method=RetrieveMethod.HYBRID.value
            )
            duration = time.perf_counter() - start_time
            status = 'success' if hits else 'empty_result'

            record_retrieve_request(
                memory_type=memory_type,
                retrieve_method=RetrieveMethod.HYBRID.value,
                status=status,
                duration_seconds=duration,
                results_count=len(hits),
            )

            return await self._to_response(hits, retrieve_mem_request)
        except Exception as e:
            duration = time.perf_counter() - start_time
            record_retrieve_request(
                memory_type=memory_type,
                retrieve_method=RetrieveMethod.HYBRID.value,
                status='error',
                duration_seconds=duration,
                results_count=0,
            )
            # Fix: exc_info added for parity with retrieve_mem_keyword/_rrf so
            # the traceback isn't lost.
            logger.error(f"Error in retrieve_mem_hybrid: {e}", exc_info=True)
            return await self._to_response([], retrieve_mem_request)

    # ================== Core Internal Methods ==================

    async def _rerank(
        self,
        query: str,
        hits: List[Dict],
        top_k: int,
        memory_type: str = 'unknown',
        retrieve_method: str = RetrieveMethod.HYBRID.value,
        instruction: Optional[str] = None,
    ) -> List[Dict]:
        """Rerank hits via the rerank service, recording a 'rerank' stage metric.

        Args:
            query: The user query the hits were retrieved for.
            hits: Candidate hit dicts to reorder.
            top_k: Maximum number of hits kept after reranking.
            memory_type: Label used for metrics only.
            retrieve_method: Label used for metrics only.
            instruction: Optional instruction forwarded to the reranker.

        Returns:
            Reranked hits ([] immediately when hits is empty).

        Raises:
            Exception: Re-raised from the rerank service after recording an
                error metric (note: no stage-duration metric on failure).
        """
        if not hits:
            return []

        stage_start = time.perf_counter()
        try:
            result = await get_rerank_service().rerank_memories(
                query, hits, top_k, instruction=instruction
            )
            record_retrieve_stage(
                retrieve_method=retrieve_method,
                stage='rerank',
                memory_type=memory_type,
                duration_seconds=time.perf_counter() - stage_start,
            )
            return result
        except Exception as e:
            record_retrieve_error(
                retrieve_method=retrieve_method,
                stage='rerank',
                error_type=self._classify_retrieve_error(e),
            )
            raise

    async def _search_hybrid(
        self,
        request: 'RetrieveMemRequest',
        retrieve_method: str = RetrieveMethod.HYBRID.value,
    ) -> List[Dict]:
        """Core hybrid search: keyword + vector in parallel, then rerank.

        Keyword hits take precedence; vector hits are appended only when
        their 'id' was not already produced by the keyword pass.
        """
        memory_type = (
            request.memory_types[0].value if request.memory_types else 'unknown'
        )
        # Run both search backends concurrently.
        keyword_hits, vector_hits = await asyncio.gather(
            self.get_keyword_search_results(request, retrieve_method=retrieve_method),
            self.get_vector_search_results(request, retrieve_method=retrieve_method),
        )
        # Merge, deduplicating vector hits against the keyword ids.
        keyword_ids = {hit.get('id') for hit in keyword_hits}
        merged = list(keyword_hits)
        for hit in vector_hits:
            if hit.get('id') not in keyword_ids:
                merged.append(hit)
        return await self._rerank(
            request.query, merged, request.top_k, memory_type, retrieve_method
        )

    async def _search_rrf(
        self,
        request: 'RetrieveMemRequest',
        retrieve_method: str = RetrieveMethod.RRF.value,
    ) -> List[Dict]:
        """Core RRF search: keyword + vector fused via reciprocal rank fusion.

        Returns the top request.top_k fused hits, each carrying its fused
        score under 'score'.
        """
        memory_type = (
            request.memory_types[0].value if request.memory_types else 'unknown'
        )

        # Run both search backends concurrently.
        keyword_hits, vector_hits = await asyncio.gather(
            self.get_keyword_search_results(request, retrieve_method=retrieve_method),
            self.get_vector_search_results(request, retrieve_method=retrieve_method),
        )

        # Fuse the two ranked lists, timing the fusion as its own stage.
        fusion_start = time.perf_counter()
        fused = reciprocal_rank_fusion(
            [(hit, hit.get('score', 0)) for hit in keyword_hits],
            [(hit, hit.get('score', 0)) for hit in vector_hits],
            k=60,
        )
        record_retrieve_stage(
            retrieve_method=retrieve_method,
            stage='rrf_fusion',
            memory_type=memory_type,
            duration_seconds=time.perf_counter() - fusion_start,
        )

        top_hits = fused[: request.top_k]
        return [dict(doc, score=score) for doc, score in top_hits]

    def _classify_retrieve_error(self, error: Exception) -> str:
        """Classify error type for metrics"""
        error_str = str(error).lower()
        if 'timeout' in error_str or 'timed out' in error_str:
            return 'timeout'
        elif 'connection' in error_str or 'connect' in error_str:
            return 'connection_error'
        elif 'not found' in error_str or 'notfound' in error_str:
            return 'not_found'
        elif 'validation' in error_str or 'invalid' in error_str:
            return 'validation_error'
        else:
            return 'unknown'

    async def _to_response(
        self, hits: List[Dict], req: 'RetrieveMemRequest'
    ) -> RetrieveMemResponse:
        """Convert a flat hits list into a grouped RetrieveMemResponse.

        Args:
            hits: Flat list of normalized hit dicts (may be empty).
            req: The originating request; used only to fill metadata. May be
                None on error-fallback paths.

        Returns:
            RetrieveMemResponse; empty-but-valid when hits is empty.
        """
        # Fix: guard every req access. The previous code guarded user_id with
        # `if req else ""` but then dereferenced req unconditionally on the
        # very next line, and req.memory_types[0] raised IndexError on an
        # empty list (sibling methods use the same 'unknown' fallback).
        user_id = req.user_id if req else ""
        source_type = req.retrieve_method.value if req else "unknown"
        memory_type = (
            req.memory_types[0].value if req and req.memory_types else "unknown"
        )

        if not hits:
            return RetrieveMemResponse(
                memories=[],
                original_data=[],
                scores=[],
                importance_scores=[],
                total_count=0,
                has_more=False,
                query_metadata=Metadata(
                    source=source_type, user_id=user_id or "", memory_type=memory_type
                ),
                metadata=Metadata(
                    source=source_type, user_id=user_id or "", memory_type=memory_type
                ),
            )
        memories, scores, importance_scores, original_data, total_count = (
            await self.group_by_groupid_stratagy(hits, source_type=source_type)
        )
        return RetrieveMemResponse(
            memories=memories,
            scores=scores,
            importance_scores=importance_scores,
            original_data=original_data,
            total_count=total_count,
            has_more=False,
            query_metadata=Metadata(
                source=source_type, user_id=user_id or "", memory_type=memory_type
            ),
            metadata=Metadata(
                source=source_type, user_id=user_id or "", memory_type=memory_type
            ),
        )

    # --------- RRF retrieval (keyword + vector + RRF fusion, no rerank) ---------
    @trace_logger(operation_name="agentic_layer RRF memory retrieval")
    async def retrieve_mem_rrf(
        self, retrieve_mem_request: 'RetrieveMemRequest'
    ) -> RetrieveMemResponse:
        """RRF-based memory retrieval: keyword + vector + RRF fusion.

        Runs the fused keyword+vector search, records a latency/outcome
        metric, and returns the grouped response. Failures are swallowed:
        an error metric is recorded and an empty response is returned.
        """
        started = time.perf_counter()
        req = retrieve_mem_request
        mem_type = req.memory_types[0].value if req.memory_types else 'unknown'

        try:
            fused_hits = await self._search_rrf(
                req, retrieve_method=RetrieveMethod.RRF.value
            )
            record_retrieve_request(
                memory_type=mem_type,
                retrieve_method=RetrieveMethod.RRF.value,
                status='success' if fused_hits else 'empty_result',
                duration_seconds=time.perf_counter() - started,
                results_count=len(fused_hits),
            )
            return await self._to_response(fused_hits, req)
        except Exception as e:
            record_retrieve_request(
                memory_type=mem_type,
                retrieve_method=RetrieveMethod.RRF.value,
                status='error',
                duration_seconds=time.perf_counter() - started,
                results_count=0,
            )
            logger.error(f"Error in retrieve_mem_rrf: {e}", exc_info=True)
            return await self._to_response([], req)

    # --------- Agentic retrieval (LLM-guided multi-round) ---------
    @trace_logger(operation_name="agentic_layer Agentic memory retrieval")
    async def retrieve_mem_agentic(
        self, retrieve_mem_request: 'RetrieveMemRequest'
    ) -> RetrieveMemResponse:
        """Agentic retrieval: LLM-guided multi-round intelligent retrieval

        Process: Round 1 (Hybrid) → Rerank → LLM sufficiency check → Round 2 (multi-query) → Merge → Final Rerank

        Args:
            retrieve_mem_request: Request carrying query, user_id, group_id,
                top_k and memory_types.

        Returns:
            RetrieveMemResponse with up to top_k memories; an empty response
            on error or when the first hybrid round returns nothing.
        """
        start_time = time.perf_counter()
        req = retrieve_mem_request  # alias
        top_k = req.top_k
        config = AgenticConfig()
        memory_type = req.memory_types[0].value if req.memory_types else 'unknown'

        try:
            # LLM client used for the sufficiency check and multi-query
            # generation; configured entirely from environment variables.
            llm_provider = LLMProvider(
                provider_type=os.getenv("LLM_PROVIDER", "openai"),
                model=os.getenv("LLM_MODEL", "Qwen3-235B"),
                base_url=os.getenv("LLM_BASE_URL"),
                api_key=os.getenv("LLM_API_KEY"),
                temperature=float(os.getenv("LLM_TEMPERATURE", "0.3")),
                max_tokens=int(os.getenv("LLM_MAX_TOKENS", "16384")),
            )

            logger.info(f"Agentic Retrieval: {req.query[:60]}...")

            # ========== Round 1: Hybrid search ==========
            # Over-fetch (round1_top_n) so the reranker has candidates to cut.
            req1 = RetrieveMemRequest(
                query=req.query,
                user_id=req.user_id,
                group_id=req.group_id,
                top_k=config.round1_top_n,
                memory_types=req.memory_types,
            )
            round1 = await self._search_hybrid(req1, retrieve_method='agentic')
            logger.info(f"Round 1: {len(round1)} memories")

            if not round1:
                # Nothing to rerank or refine — short-circuit with empty result.
                duration = time.perf_counter() - start_time
                record_retrieve_request(
                    memory_type=memory_type,
                    retrieve_method=RetrieveMethod.AGENTIC.value,
                    status='empty_result',
                    duration_seconds=duration,
                    results_count=0,
                )
                return await self._to_response([], req)

            # ========== Rerank → max(5, top_k) for LLM & return ==========
            rerank_n = max(config.round1_rerank_top_n, top_k)
            reranked = await self._rerank(
                req.query, round1, rerank_n, memory_type, 'agentic',
                instruction=config.reranker_instruction,
            )
            # Use top 5 for sufficiency check
            topn_for_llm = reranked[:config.round1_rerank_top_n]
            # NOTE(review): assumes each reranked hit is a dict carrying a
            # 'score' key — confirm against _rerank's output shape.
            topn_pairs = [(m, m.get("score", 0)) for m in topn_for_llm]

            # ========== LLM sufficiency check ==========
            is_sufficient, reasoning, missing_info = await check_sufficiency(
                query=req.query,
                results=topn_pairs,
                llm_provider=llm_provider,
                max_docs=config.round1_rerank_top_n,
            )
            logger.info(
                f"LLM: {'Sufficient' if is_sufficient else 'Insufficient'} - {reasoning}"
            )

            if is_sufficient:
                # Return reranked results (already done above, no extra rerank)
                final_results = reranked[:top_k]
                duration = time.perf_counter() - start_time
                record_retrieve_request(
                    memory_type=memory_type,
                    retrieve_method=RetrieveMethod.AGENTIC.value,
                    status='success',
                    duration_seconds=duration,
                    results_count=len(final_results),
                )
                return await self._to_response(final_results, req)

            # ========== Round 2: Multi-query ==========
            # LLM expands the original query using the missing-info hints.
            refined_queries, _ = await generate_multi_queries(
                original_query=req.query,
                results=topn_pairs,
                missing_info=missing_info,
                llm_provider=llm_provider,
                max_docs=config.round1_rerank_top_n,
                num_queries=config.num_queries,
            )
            logger.info(f"Generated {len(refined_queries)} queries")

            # Parallel hybrid search
            async def do_search(q: str) -> List[Dict]:
                # One hybrid search per refined query, capped per-query.
                return await self._search_hybrid(
                    RetrieveMemRequest(
                        query=q,
                        user_id=req.user_id,
                        group_id=req.group_id,
                        top_k=config.round2_per_query_top_n,
                        memory_types=req.memory_types,
                    ),
                    retrieve_method='agentic',
                )

            # return_exceptions=True: one failed query must not sink the batch.
            round2_results = await asyncio.gather(
                *[do_search(q) for q in refined_queries], return_exceptions=True
            )
            all_round2 = [
                h for r in round2_results if not isinstance(r, Exception) for h in r
            ]

            # Deduplicate and merge
            # Round-1 hits win on id collisions; round-2 fills the remaining
            # budget up to config.combined_total.
            seen_ids = {m.get("id") for m in round1}
            round2_unique = [m for m in all_round2 if m.get("id") not in seen_ids]
            combined = round1 + round2_unique[: config.combined_total - len(round1)]
            logger.info(f"Combined: {len(combined)} memories")

            # ========== Final Rerank ==========
            final = await self._rerank(
                req.query, combined, top_k, memory_type, 'agentic',
                instruction=config.reranker_instruction,
            )

            duration = time.perf_counter() - start_time
            record_retrieve_request(
                memory_type=memory_type,
                retrieve_method=RetrieveMethod.AGENTIC.value,
                status='success',
                duration_seconds=duration,
                results_count=len(final[:top_k]),
            )

            return await self._to_response(final[:top_k], req)

        except Exception as e:
            # Fail closed: record the error metric and return an empty response.
            duration = time.perf_counter() - start_time
            record_retrieve_request(
                memory_type=memory_type,
                retrieve_method=RetrieveMethod.AGENTIC.value,
                status='error',
                duration_seconds=duration,
                results_count=0,
            )
            logger.error(f"Error in retrieve_mem_agentic: {e}", exc_info=True)
            return await self._to_response([], req)

    def _calculate_importance_score(
        self, importance_evidence: Optional[Dict[str, Any]]
    ) -> float:
        """Calculate group importance score

        Calculate score based on group importance evidence, mainly considering:
        - speak_count: User's speaking count in this group
        - refer_count: Number of times user was mentioned
        - conversation_count: Total conversation count in this group

        Importance score = (total speaking count + total mention count) / total conversation count

        Args:
            importance_evidence: Group importance evidence dictionary

        Returns:
            float: Importance score, range [0, +∞), larger value means more important group
        """
        if not importance_evidence or not isinstance(importance_evidence, dict):
            return 0.0

        evidence_list = importance_evidence.get('evidence_list', [])
        if not evidence_list:
            return 0.0

        total_speak_count = 0
        total_refer_count = 0
        total_conversation_count = 0

        # Accumulate statistics from all evidence
        for evidence in evidence_list:
            if isinstance(evidence, dict):
                total_speak_count += evidence.get('speak_count', 0)
                total_refer_count += evidence.get('refer_count', 0)
                total_conversation_count += evidence.get('conversation_count', 0)

        # Avoid division by zero
        if total_conversation_count == 0:
            return 0.0

        # Calculate importance score
        return (total_speak_count + total_refer_count) / total_conversation_count

    async def _batch_get_memcells(
        self, event_ids: List[str], batch_size: int = 100
    ) -> Dict[str, Any]:
        """Fetch MemCells for the given event ids in fixed-size batches.

        Args:
            event_ids: event_id values to look up (duplicates are collapsed).
            batch_size: Maximum ids per repository query, default 100.

        Returns:
            Dict[event_id, MemCell]: mapping for every id that was found.
        """
        if not event_ids:
            return {}

        # Collapse duplicates so each id is fetched at most once.
        distinct_ids = list(set(event_ids))
        logger.debug(
            f"Batch get MemCells: Total {len(distinct_ids)} (before deduplication: {len(event_ids)})"
        )

        repo = get_bean_by_type(MemCellRawRepository)
        collected: Dict[str, Any] = {}

        # Query in chunks of batch_size to bound single-query size.
        for batch_no, start in enumerate(range(0, len(distinct_ids), batch_size), 1):
            chunk = distinct_ids[start : start + batch_size]
            logger.debug(
                f"Getting batch {batch_no} MemCells: {len(chunk)} items"
            )
            collected.update(await repo.get_by_event_ids(chunk))

        logger.debug(
            f"Batch get MemCells completed: Successfully retrieved {len(collected)} items"
        )
        return collected

    async def _batch_get_group_profiles(
        self, user_group_pairs: List[Tuple[str, str]]
    ) -> Dict[Tuple[str, str], Any]:
        """Fetch group user profiles for the given (user_id, group_id) pairs.

        Args:
            user_group_pairs: (user_id, group_id) tuples; duplicates collapsed.

        Returns:
            Dict[(user_id, group_id), GroupUserProfileMemory]: lookup mapping.
        """
        if not user_group_pairs:
            return {}

        # Collapse duplicate pairs before hitting the repository.
        distinct_pairs = list(set(user_group_pairs))
        logger.debug(
            f"Batch get group user profiles: Total {len(distinct_pairs)} (before deduplication: {len(user_group_pairs)})"
        )

        repo = get_bean_by_type(GroupUserProfileMemoryRawRepository)
        profiles = await repo.batch_get_by_user_groups(distinct_pairs)

        found = len([v for v in profiles.values() if v is not None])
        logger.debug(
            f"Batch get group user profiles completed: Successfully retrieved {found} items"
        )
        return profiles

    def _get_type_str(self, val) -> str:
        """Return the string form of a type field.

        RawDataType enums yield their .value; any other truthy value is
        str()-ified; falsy values collapse to the empty string.
        """
        if isinstance(val, RawDataType):
            return val.value
        if not val:
            return ''
        return str(val)

    def _extract_hit_fields_from_es(self, hit: Dict[str, Any]) -> Dict[str, Any]:
        """Extract fields from ES search result"""
        source = hit.get('_source', {})
        return {
            'hit_id': source.get('event_id', ''),
            'user_id': source.get('user_id', ''),
            'group_id': source.get('group_id', ''),
            'timestamp_raw': source.get('timestamp', ''),
            'episode': source.get('episode', ''),
            'memcell_event_id_list': source.get('memcell_event_id_list', []),
            'subject': source.get('subject', ''),
            'summary': source.get('summary', ''),
            'participants': source.get('participants', []),
            'event_type': source.get('type', ''),
            'atomic_fact': source.get('atomic_fact', ''),
            'foresight': source.get('foresight', ''),
            'evidence': source.get('evidence', ''),
            'extend_data': source.get('extend', {}) or {},
            'search_source': 'keyword',
        }

    def _extract_hit_fields_from_milvus(self, hit: Dict[str, Any]) -> Dict[str, Any]:
        """Extract fields from Milvus search result"""
        metadata = hit.get('metadata', {})
        timestamp_val = hit.get('timestamp') or hit.get('start_time')
        return {
            'hit_id': hit.get('id', ''),
            'user_id': hit.get('user_id', ''),
            'group_id': hit.get('group_id', ''),
            'timestamp_raw': timestamp_val,
            'episode': hit.get('episode', ''),
            'memcell_event_id_list': metadata.get('memcell_event_id_list', []),
            'subject': metadata.get('subject', ''),
            'summary': metadata.get('summary', ''),
            'participants': metadata.get('participants', []),
            'event_type': self._get_type_str(hit.get('type') or hit.get('event_type')),
            'atomic_fact': hit.get('atomic_fact', ''),
            'foresight': hit.get(
                'content', ''
            ),  # Milvus foresight uses 'content' field
            'evidence': hit.get('evidence', ''),
            'extend_data': metadata.get('extend', {}) or {},
            'search_source': 'vector',
        }

    def _extract_hit_fields(self, hit: Dict[str, Any]) -> Dict[str, Any]:
        """Dispatch field extraction on the hit's '_search_source' tag."""
        origin = hit.get('_search_source')
        if origin == RetrieveMethod.KEYWORD.value:
            return self._extract_hit_fields_from_es(hit)
        if origin == RetrieveMethod.VECTOR.value:
            return self._extract_hit_fields_from_milvus(hit)
        raise ValueError(f"Unknown _search_source: {origin}")

    async def group_by_groupid_stratagy(
        self,
        search_results: List[Dict[str, Any]],
        source_type: str = RetrieveMethod.VECTOR.value,
    ) -> tuple:
        """Generic search result grouping processing strategy

        Args:
            search_results: List of search results
            source_type: Retrieval method (keyword/vector/hybrid)

        Returns:
            tuple: (memories, scores, importance_scores, original_data, total_count)
        """
        # Step 1: Collect all data needed for queries
        all_memcell_event_ids = []
        all_user_group_pairs = []

        for hit in search_results:
            fields = self._extract_hit_fields(hit)
            memcell_event_id_list = fields['memcell_event_id_list']
            user_id = fields['user_id']
            group_id = fields['group_id']

            if memcell_event_id_list:
                all_memcell_event_ids.extend(memcell_event_id_list)

            # Collect user_id and group_id pairs
            if user_id and group_id:
                all_user_group_pairs.append((user_id, group_id))

        # Step 2: Execute two batch query tasks concurrently
        memcells_task = asyncio.create_task(
            self._batch_get_memcells(all_memcell_event_ids)
        )
        profiles_task = asyncio.create_task(
            self._batch_get_group_profiles(all_user_group_pairs)
        )

        # Wait for all tasks to complete
        memcells_cache, profiles_cache = await asyncio.gather(
            memcells_task, profiles_task
        )

        # Step 3: Process search results
        memories_by_group = (
            {}
        )  # {group_id: {'memories': [Memory], 'scores': [float], 'importance_evidence': dict}}
        original_data_by_group = {}

        for hit in search_results:
            # Extract fields
            fields = self._extract_hit_fields(hit)
            # Get score (each retrieval method uses its own score field)
            score = hit.get('score', 0.0)

            hit_id = fields['hit_id']
            user_id = fields['user_id']
            group_id = fields['group_id']
            timestamp_raw = fields['timestamp_raw']
            memcell_event_id_list = fields['memcell_event_id_list']
            episode = fields['episode']
            subject = fields['subject']
            summary = fields['summary']
            participants = fields['participants']
            event_type = fields['event_type']
            atomic_fact = fields['atomic_fact']
            foresight = fields['foresight']
            evidence = fields['evidence']
            extend_data = fields['extend_data']
            search_source = fields['search_source']
            # Process timestamp
            timestamp = from_iso_format(timestamp_raw)

            # Get memcell data from cache (foresight doesn't need this)
            memory_type_value = hit.get('memory_type', 'episodic_memory')
            memcells = []
            if memcell_event_id_list:
                # Get memcells from cache in original order
                for event_id in memcell_event_id_list:
                    memcell = memcells_cache.get(event_id)
                    if memcell:
                        memcells.append(memcell)
                    else:
                        logger.debug(f"Memcell not found: event_id={event_id}")
                        continue

            # Add raw data for each memcell
            for memcell in memcells:
                if group_id not in original_data_by_group:
                    original_data_by_group[group_id] = []
                # Use extend instead of append to flatten the list structure
                # memcell.original_data is a List[Dict], not a single Dict
                if memcell.original_data:
                    original_data_by_group[group_id].extend(memcell.original_data)

            # Create object based on memory type
            base_kwargs = dict(
                id=hit_id,
                memory_type=memory_type_value,
                user_id=user_id,
                timestamp=timestamp,
                ori_event_id_list=[hit_id],
                group_id=group_id,
                participants=participants,
                memcell_event_id_list=memcell_event_id_list,
                type=RawDataType.from_string(event_type),
                extend={
                    '_search_source': search_source,
                    'parent_type': extend_data.get('parent_type'),
                    'parent_id': extend_data.get('parent_id'),
                },
            )

            match memory_type_value:
                case MemoryType.EVENT_LOG.value:
                    memory = EventLog(**base_kwargs, atomic_fact=atomic_fact)
                case MemoryType.FORESIGHT.value:
                    memory = Foresight(
                        **base_kwargs, foresight=foresight, evidence=evidence
                    )
                case MemoryType.EPISODIC_MEMORY.value:
                    # EpisodeMemory has additional fields: subject, summary, episode
                    memory = EpisodeMemory(
                        **base_kwargs, subject=subject, summary=summary, episode=episode
                    )
                case MemoryType.PROFILE.value | MemoryType.GROUP_PROFILE.value:
                    # Profile memories should not reach here as they are not searchable via ES/Milvus
                    # Skip this hit and continue processing other results
                    logger.warning(
                        f"Skipping PROFILE type memory in search results - profile should be fetched from MongoDB: hit_id={hit_id}"
                    )
                    continue
                case _:
                    logger.warning(f"Skipping unsupported memory type in search results: {memory_type_value}")
                    continue

            # Get group_importance_evidence from cache
            group_importance_evidence = None
            if user_id and group_id:
                group_user_profile = profiles_cache.get((user_id, group_id))
                if (
                    group_user_profile
                    and hasattr(group_user_profile, 'group_importance_evidence')
                    and group_user_profile.group_importance_evidence
                ):
                    group_importance_evidence = (
                        group_user_profile.group_importance_evidence
                    )
                    # Add group_importance_evidence to memory's extend field
                    if not hasattr(memory, 'extend') or memory.extend is None:
                        memory.extend = {}
                    memory.extend['group_importance_evidence'] = (
                        group_importance_evidence
                    )
                    logger.debug(
                        f"Added group_importance_evidence to memory: user_id={user_id}, group_id={group_id}"
                    )

            # Group by group_id
            if group_id not in memories_by_group:
                memories_by_group[group_id] = {
                    'memories': [],
                    'scores': [],
                    'importance_evidence': group_importance_evidence,
                }

            memories_by_group[group_id]['memories'].append(memory)
            memories_by_group[group_id]['scores'].append(score)  # Save original score

            # Update group_importance_evidence (if current memory has updated evidence)
            if group_importance_evidence:
                memories_by_group[group_id][
                    'importance_evidence'
                ] = group_importance_evidence

        # Sort memories within each group by timestamp, and calculate importance score
        group_scores = []
        for group_id, group_data in memories_by_group.items():
            # Sort memories by timestamp
            group_data['memories'].sort(
                key=lambda m: m.timestamp if m.timestamp else ''
            )

            # Calculate importance score
            importance_score = self._calculate_importance_score(
                group_data['importance_evidence']
            )
            group_scores.append((group_id, importance_score))

        # Sort groups by importance score
        group_scores.sort(key=lambda x: x[1], reverse=True)

        # Build final results
        memories = []
        scores = []
        importance_scores = []
        original_data = []
        for group_id, importance_score in group_scores:
            group_data = memories_by_group[group_id]
            group_memories = group_data['memories']
            group_scores_list = group_data['scores']
            group_original_data = original_data_by_group.get(group_id, [])
            memories.append({group_id: group_memories})
            # scores structure consistent with memories: List[Dict[str, List[float]]]
            scores.append({group_id: group_scores_list})
            # original_data structure consistent with memories: List[Dict[str, List[Dict[str, Any]]]]
            original_data.append({group_id: group_original_data})
            importance_scores.append(importance_score)

        total_count = sum(
            len(group_data['memories']) for group_data in memories_by_group.values()
        )
        return memories, scores, importance_scores, original_data, total_count

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions