Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/bedrock_agentcore/memory/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ from bedrock_agentcore.memory.constants import RetrievalConfig
def my_llm(user_input: str, memories: List[Dict]) -> str:
# Format context from retrieved memories
context = "\n".join([
m.get('content', {}).get('text', '')
m.get('content', {}).get('text', '')
for m in memories
])

# Call your LLM (Bedrock, OpenAI, etc.)
# This is just an example - use your actual LLM integration
response = f"Based on our previous discussions about {context}, here's my response to: {user_input}"
Expand Down Expand Up @@ -343,25 +343,25 @@ try:
memory_id="your-memory-id",
region_name="us-east-1"
)

session = manager.create_memory_session(
actor_id="user-123",
session_id="session-456"
)

# Add conversation turns
event = session.add_turns([
ConversationalMessage("Hello", MessageRole.USER),
ConversationalMessage("Hi there!", MessageRole.ASSISTANT)
])

except NoCredentialsError:
print("AWS credentials not found. Please configure your credentials.")

except ClientError as e:
error_code = e.response['Error']['Code']
error_message = e.response['Error']['Message']

if error_code == 'ResourceNotFoundException':
print(f"Memory not found: {error_message}")
elif error_code == 'ValidationException':
Expand All @@ -372,7 +372,7 @@ except ClientError as e:
print(f"Request throttled: {error_message}")
else:
print(f"AWS error ({error_code}): {error_message}")

except Exception as e:
print(f"Unexpected error: {str(e)}")
```
Expand Down Expand Up @@ -458,7 +458,7 @@ from bedrock_agentcore.memory import MemoryClient
client = MemoryClient()
event = client.create_event(
memory_id="memory-123",
actor_id="user-456",
actor_id="user-456",
session_id="session-789",
messages=[("Hello", "USER"), ("Hi there", "ASSISTANT")]
)
Expand Down
10 changes: 5 additions & 5 deletions src/bedrock_agentcore/memory/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def __init__(
self._memory_id = memory_id

# Setup session and validate region consistency
self.region_name = self._validate_and_resolve_region(region_name, boto3_session)
session = boto3_session if boto3_session else boto3.Session()
self.region_name = self._validate_and_resolve_region(region_name, session)

# Configure and create boto3 client
client_config = self._build_client_config(boto_client_config)
Expand All @@ -138,23 +138,23 @@ def __init__(
"list_events",
}

def _validate_and_resolve_region(self, region_name: Optional[str], session: boto3.Session) -> str:
def _validate_and_resolve_region(self, region_name: Optional[str], session: Optional[boto3.Session]) -> str:
"""Validate region consistency and resolve the final region to use.

Args:
region_name: Explicitly provided region name
session: Boto3 session instance
session: Optional Boto3 session instance

Returns:
The resolved region name to use

Raises:
ValueError: If region_name conflicts with session region
"""
session_region = session.region_name
session_region = session.region_name if session else None

# Validate region consistency if both are provided
if region_name and session_region and isinstance(session_region, str) and region_name != session_region:
if region_name and session and session_region and (region_name != session_region):
raise ValueError(
f"Region mismatch: provided region_name '{region_name}' does not match "
f"boto3_session region '{session_region}'. Please ensure both "
Expand Down
74 changes: 72 additions & 2 deletions tests/bedrock_agentcore/memory/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2228,7 +2228,7 @@ def test_list_events_with_next_token(self):
assert second_call_args["nextToken"] == "token-123"

def test_validate_and_resolve_region_no_session_region(self):
"""Test _validate_and_resolve_region when session has no region - covers line 158."""
"""Test _validate_and_resolve_region when session has no region - covers line 154."""
with patch("boto3.Session") as mock_session_class:
mock_session = MagicMock()
mock_session.region_name = None # No region in session
Expand Down Expand Up @@ -2530,7 +2530,7 @@ def mock_llm_callback(user_input: str, memories: List[Dict[str, Any]]) -> str:
assert response == "Response"

def test_validate_and_resolve_region_edge_case(self):
"""Test _validate_and_resolve_region edge case - covers line 158."""
"""Test _validate_and_resolve_region edge case - covers line 154."""
with patch("boto3.Session") as mock_session_class:
mock_session = MagicMock()
mock_session.region_name = None # No region in session
Expand Down Expand Up @@ -2606,6 +2606,76 @@ def test_region_validation_with_non_string_session_region(self):
manager = MemorySessionManager(memory_id="test-memory", region_name="us-west-1")
assert manager.region_name == "us-west-1"

def test_region_validation_order_change(self):
"""Test that region validation happens before session creation - covers recent commit changes."""
# Test case: Conflicting regions should raise ValueError
custom_session = MagicMock()
custom_session.region_name = "us-east-1"
mock_client_instance = MagicMock()
custom_session.client.return_value = mock_client_instance

with pytest.raises(ValueError) as exc_info:
MemorySessionManager(
memory_id="test-memory",
region_name="us-west-1", # Different from session region
boto3_session=custom_session,
)

assert "Region mismatch" in str(exc_info.value)
assert "us-west-1" in str(exc_info.value)
assert "us-east-1" in str(exc_info.value)

def test_region_validation_with_none_session(self):
"""Test region validation when boto3_session is None - covers recent commit changes."""
with patch("boto3.Session") as mock_session_class:
mock_session = MagicMock()
mock_session.region_name = "us-east-1"
mock_client_instance = MagicMock()
mock_session.client.return_value = mock_client_instance
mock_session_class.return_value = mock_session

# Test validation when boto3_session parameter is None
manager = MemorySessionManager(
memory_id="test-memory",
region_name="us-west-1",
boto3_session=None, # Explicitly None
)

# Should use the provided region_name
assert manager.region_name == "us-west-1"

def test_region_validation_simplified_logic(self):
"""Test the simplified region validation logic - covers recent commit changes."""
# Test case 1: Conflicting regions should raise ValueError
custom_session = MagicMock()
custom_session.region_name = "us-east-1"
mock_client_instance = MagicMock()
custom_session.client.return_value = mock_client_instance

with pytest.raises(ValueError) as exc_info:
MemorySessionManager(
memory_id="test-memory",
region_name="us-west-1", # Different from session region
boto3_session=custom_session,
)

assert "Region mismatch" in str(exc_info.value)
assert "us-west-1" in str(exc_info.value)
assert "us-east-1" in str(exc_info.value)

# Test case 2: Matching regions should work
custom_session2 = MagicMock()
custom_session2.region_name = "us-west-1"
custom_session2.client.return_value = mock_client_instance

manager = MemorySessionManager(
memory_id="test-memory",
region_name="us-west-1", # Same as session region
boto3_session=custom_session2,
)

assert manager.region_name == "us-west-1"

def test_configure_timestamp_serialization_non_datetime_value(self):
"""Test timestamp serialization with non-datetime value."""
with patch("boto3.Session") as mock_session_class:
Expand Down
Loading