Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Optional

Expand Down Expand Up @@ -494,30 +495,48 @@ def retrieve_customer_context(self, event: MessageAddedEvent) -> None:
return None

user_query = messages[-1]["content"][0]["text"]

def retrieve_for_namespace(namespace: str, retrieval_config: AgentCoreMemoryConfig):
"""Helper function to retrieve memories for a single namespace."""
resolved_namespace = namespace.format(
actorId=self.config.actor_id,
sessionId=self.config.session_id,
memoryStrategyId=retrieval_config.strategy_id or "",
)

memories = self.memory_client.retrieve_memories(
memory_id=self.config.memory_id,
namespace=resolved_namespace,
query=user_query,
top_k=retrieval_config.top_k,
)
context_items = []
for memory in memories:
if isinstance(memory, dict):
content = memory.get("content", {})
if isinstance(content, dict):
text = content.get("text", "").strip()
if text:
context_items.append(text)
return context_items

try:
# Retrieve customer context from all namespaces
# Retrieve customer context from all namespaces in parallel
all_context = []
for namespace, retrieval_config in self.config.retrieval_config.items():
resolved_namespace = namespace.format(
actorId=self.config.actor_id,
sessionId=self.config.session_id,
memoryStrategyId=retrieval_config.strategy_id or "",
)

memories = self.memory_client.retrieve_memories(
memory_id=self.config.memory_id,
namespace=resolved_namespace,
query=user_query,
top_k=retrieval_config.top_k,
)

for memory in memories:
if isinstance(memory, dict):
content = memory.get("content", {})
if isinstance(content, dict):
text = content.get("text", "").strip()
if text:
all_context.append(text)
with ThreadPoolExecutor() as executor:
future_to_namespace = {
executor.submit(retrieve_for_namespace, namespace, retrieval_config): namespace
for namespace, retrieval_config in self.config.retrieval_config.items()
}
for future in as_completed(future_to_namespace):
try:
context_items = future.result()
all_context.extend(context_items)
except Exception as e:
# Continue processing other futures event if one fails rather than failing the entire operation
namespace = future_to_namespace[future]
logger.error("Failed to retrieve memories for namespace %s: %s", namespace, e)

# Inject customer context into the query
if all_context:
Expand All @@ -527,7 +546,7 @@ def retrieve_customer_context(self, event: MessageAddedEvent) -> None:
"content": [{"text": f"<user_context>{context_text}</user_context>"}],
}
event.agent.messages.append(ltm_msg)
logger.info("Retrieved %s customer context items", {len(all_context)})
logger.info("Retrieved %s customer context items", len(all_context))

except Exception as e:
logger.error("Failed to retrieve customer context: %s", e)
Expand Down
Loading