From a7bb4a7bb068584a92eaea60daff9c1615447235 Mon Sep 17 00:00:00 2001 From: Colin Francis Date: Wed, 19 Nov 2025 13:12:44 -0500 Subject: [PATCH] feat: parallelize retrieve memories API calls for multiple namespaces to improve latency --- .../integrations/strands/session_manager.py | 63 ++++++++++++------- 1 file changed, 41 insertions(+), 22 deletions(-) diff --git a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py index 0a7e9bc..d77db53 100644 --- a/src/bedrock_agentcore/memory/integrations/strands/session_manager.py +++ b/src/bedrock_agentcore/memory/integrations/strands/session_manager.py @@ -3,6 +3,7 @@ import json import logging import threading +from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Optional @@ -494,30 +495,48 @@ def retrieve_customer_context(self, event: MessageAddedEvent) -> None: return None user_query = messages[-1]["content"][0]["text"] + + def retrieve_for_namespace(namespace: str, retrieval_config: AgentCoreMemoryConfig): + """Helper function to retrieve memories for a single namespace.""" + resolved_namespace = namespace.format( + actorId=self.config.actor_id, + sessionId=self.config.session_id, + memoryStrategyId=retrieval_config.strategy_id or "", + ) + + memories = self.memory_client.retrieve_memories( + memory_id=self.config.memory_id, + namespace=resolved_namespace, + query=user_query, + top_k=retrieval_config.top_k, + ) + context_items = [] + for memory in memories: + if isinstance(memory, dict): + content = memory.get("content", {}) + if isinstance(content, dict): + text = content.get("text", "").strip() + if text: + context_items.append(text) + return context_items + try: - # Retrieve customer context from all namespaces + # Retrieve customer context from all namespaces in parallel all_context = [] - for namespace, retrieval_config in self.config.retrieval_config.items(): - resolved_namespace = namespace.format( - actorId=self.config.actor_id, - sessionId=self.config.session_id, - memoryStrategyId=retrieval_config.strategy_id or "", - ) - - memories = self.memory_client.retrieve_memories( - memory_id=self.config.memory_id, - namespace=resolved_namespace, - query=user_query, - top_k=retrieval_config.top_k, - ) - for memory in memories: - if isinstance(memory, dict): - content = memory.get("content", {}) - if isinstance(content, dict): - text = content.get("text", "").strip() - if text: - all_context.append(text) + with ThreadPoolExecutor() as executor: + future_to_namespace = { + executor.submit(retrieve_for_namespace, namespace, retrieval_config): namespace + for namespace, retrieval_config in self.config.retrieval_config.items() + } + for future in as_completed(future_to_namespace): + try: + context_items = future.result() + all_context.extend(context_items) + except Exception as e: + # Continue processing other futures event if one fails rather than failing the entire operation + namespace = future_to_namespace[future] + logger.error("Failed to retrieve memories for namespace %s: %s", namespace, e) # Inject customer context into the query if all_context: @@ -527,7 +546,7 @@ def retrieve_customer_context(self, event: MessageAddedEvent) -> None: "content": [{"text": f"{context_text}"}], } event.agent.messages.append(ltm_msg) - logger.info("Retrieved %s customer context items", {len(all_context)}) + logger.info("Retrieved %s customer context items", len(all_context)) except Exception as e: logger.error("Failed to retrieve customer context: %s", e)