diff --git a/src/agents/judge.py b/src/agents/judge.py index 8ca8c57..35d326f 100644 --- a/src/agents/judge.py +++ b/src/agents/judge.py @@ -32,6 +32,7 @@ OperationType, ) from src.storage.base import BaseVectorStore, SearchResult +from src.config import settings # --------------------------------------------------------------------------- @@ -87,13 +88,13 @@ def _format_similar_block( return "\n".join(lines) -SUMMARY_JUDGE_SIMILARITY_THRESHOLD = 0.4 - def _has_summary_judge_candidates( matches_per_item: Dict[str, List[SearchResult]], - threshold: float = SUMMARY_JUDGE_SIMILARITY_THRESHOLD, + threshold: Optional[float] = None, ) -> bool: + if threshold is None: + threshold = settings.summary_judge_similarity_threshold for matches in matches_per_item.values(): for match in matches: if match.score >= threshold: @@ -112,11 +113,12 @@ def _filter_matches_by_threshold( def _deterministic_summary_add(items_strings: List[str], confidence: float = 0.8) -> JudgeResult: + threshold = settings.summary_judge_similarity_threshold operations = [ Operation( type=OperationType.ADD, content=item, - reason="No similar summary at or above 0.4 — defaulting to ADD.", + reason=f"No similar summary at or above {threshold} — defaulting to ADD.", ) for item in items_strings if str(item).strip() @@ -196,7 +198,7 @@ async def arun(self, state: Dict[str, Any]) -> JudgeResult: if domain == JudgeDomain.SUMMARY: matches_per_item = _filter_matches_by_threshold( matches_per_item, - SUMMARY_JUDGE_SIMILARITY_THRESHOLD, + settings.summary_judge_similarity_threshold, ) # 3. Build the prompt diff --git a/src/config/settings.py b/src/config/settings.py index 78630ce..9a73c99 100644 --- a/src/config/settings.py +++ b/src/config/settings.py @@ -136,6 +136,18 @@ class Settings(BaseSettings): default=0.4, description="LLM temperature for generation" ) + summary_judge_similarity_threshold: float = Field( + default=0.4, + ge=0.0, + le=1.0, + description="Threshold score for the Judge to match summary memories" + ) + temporal_search_similarity_threshold: float = Field( + default=0.3, + ge=-1.0, + le=1.0, + description="Minimum cosine similarity threshold score for Neo4j temporal search" + ) llm_timeout_seconds: float = Field( default=45.0, description="Per-agent LLM call timeout in seconds", diff --git a/src/graph/neo4j_client.py b/src/graph/neo4j_client.py index 52eb4f0..77de30e 100644 --- a/src/graph/neo4j_client.py +++ b/src/graph/neo4j_client.py @@ -30,6 +30,7 @@ from neo4j import GraphDatabase from src.graph.schema import GraphSchema +from src.config import settings logger = logging.getLogger("xmem.graph.neo4j") @@ -250,7 +251,7 @@ def search_events_by_embedding( user_id: str, query_text: str, top_k: int = 1, - similarity_threshold: float = 0.3, + similarity_threshold: Optional[float] = None, ) -> List[Dict[str, Any]]: """Semantic search over event embeddings stored on HAS_EVENT relationships. @@ -263,6 +264,9 @@ def search_events_by_embedding( ``similarity_score`` is raw cosine in [-1, 1] (matches the previous dot-product semantics, which assumed unit-normalised embeddings). """ + if similarity_threshold is None: + similarity_threshold = settings.temporal_search_similarity_threshold + if not self._embedding_fn: logger.warning("No embedding function — cannot search by embedding.") return []