From 1c696bba7514141bd6eb3bdd1cf8597189c0c132 Mon Sep 17 00:00:00 2001 From: Jonathan James Date: Wed, 1 Oct 2025 11:57:39 -0400 Subject: [PATCH] fix: fix validation exception which occurs if the default aws region mismatches with the user's region_name --- src/bedrock_agentcore/memory/README.md | 18 ++--- src/bedrock_agentcore/memory/session.py | 10 +-- .../bedrock_agentcore/memory/test_session.py | 74 ++++++++++++++++++- 3 files changed, 86 insertions(+), 16 deletions(-) diff --git a/src/bedrock_agentcore/memory/README.md b/src/bedrock_agentcore/memory/README.md index b1cb53f..195b4d1 100644 --- a/src/bedrock_agentcore/memory/README.md +++ b/src/bedrock_agentcore/memory/README.md @@ -198,10 +198,10 @@ from bedrock_agentcore.memory.constants import RetrievalConfig def my_llm(user_input: str, memories: List[Dict]) -> str: # Format context from retrieved memories context = "\n".join([ - m.get('content', {}).get('text', '') + m.get('content', {}).get('text', '') for m in memories ]) - + # Call your LLM (Bedrock, OpenAI, etc.) # This is just an example - use your actual LLM integration response = f"Based on our previous discussions about {context}, here's my response to: {user_input}" @@ -343,25 +343,25 @@ try: memory_id="your-memory-id", region_name="us-east-1" ) - + session = manager.create_memory_session( actor_id="user-123", session_id="session-456" ) - + # Add conversation turns event = session.add_turns([ ConversationalMessage("Hello", MessageRole.USER), ConversationalMessage("Hi there!", MessageRole.ASSISTANT) ]) - + except NoCredentialsError: print("AWS credentials not found. Please configure your credentials.") - + except ClientError as e: error_code = e.response['Error']['Code'] error_message = e.response['Error']['Message'] - + if error_code == 'ResourceNotFoundException': print(f"Memory not found: {error_message}") elif error_code == 'ValidationException': @@ -372,7 +372,7 @@ except ClientError as e: print(f"Request throttled: {error_message}") else: print(f"AWS error ({error_code}): {error_message}") - + except Exception as e: print(f"Unexpected error: {str(e)}") ``` @@ -458,7 +458,7 @@ from bedrock_agentcore.memory import MemoryClient client = MemoryClient() event = client.create_event( memory_id="memory-123", - actor_id="user-456", + actor_id="user-456", session_id="session-789", messages=[("Hello", "USER"), ("Hi there", "ASSISTANT")] ) diff --git a/src/bedrock_agentcore/memory/session.py b/src/bedrock_agentcore/memory/session.py index 6fd9eb6..bd284af 100644 --- a/src/bedrock_agentcore/memory/session.py +++ b/src/bedrock_agentcore/memory/session.py @@ -114,8 +114,8 @@ def __init__( self._memory_id = memory_id # Setup session and validate region consistency + self.region_name = self._validate_and_resolve_region(region_name, boto3_session) session = boto3_session if boto3_session else boto3.Session() - self.region_name = self._validate_and_resolve_region(region_name, session) # Configure and create boto3 client client_config = self._build_client_config(boto_client_config) @@ -138,12 +138,12 @@ def __init__( "list_events", } - def _validate_and_resolve_region(self, region_name: Optional[str], session: boto3.Session) -> str: + def _validate_and_resolve_region(self, region_name: Optional[str], session: Optional[boto3.Session]) -> str: """Validate region consistency and resolve the final region to use. Args: region_name: Explicitly provided region name - session: Boto3 session instance + session: Optional Boto3 session instance Returns: The resolved region name to use @@ -151,10 +151,10 @@ def _validate_and_resolve_region(self, region_name: Optional[str], session: boto Raises: ValueError: If region_name conflicts with session region """ - session_region = session.region_name + session_region = session.region_name if session else None # Validate region consistency if both are provided - if region_name and session_region and isinstance(session_region, str) and region_name != session_region: + if region_name and session and session_region and (region_name != session_region): raise ValueError( f"Region mismatch: provided region_name '{region_name}' does not match " f"boto3_session region '{session_region}'. Please ensure both " diff --git a/tests/bedrock_agentcore/memory/test_session.py b/tests/bedrock_agentcore/memory/test_session.py index 2a81868..ef9f91e 100644 --- a/tests/bedrock_agentcore/memory/test_session.py +++ b/tests/bedrock_agentcore/memory/test_session.py @@ -2228,7 +2228,7 @@ def test_list_events_with_next_token(self): assert second_call_args["nextToken"] == "token-123" def test_validate_and_resolve_region_no_session_region(self): - """Test _validate_and_resolve_region when session has no region - covers line 158.""" + """Test _validate_and_resolve_region when session has no region - covers line 154.""" with patch("boto3.Session") as mock_session_class: mock_session = MagicMock() mock_session.region_name = None # No region in session @@ -2530,7 +2530,7 @@ def mock_llm_callback(user_input: str, memories: List[Dict[str, Any]]) -> str: assert response == "Response" def test_validate_and_resolve_region_edge_case(self): - """Test _validate_and_resolve_region edge case - covers line 158.""" + """Test _validate_and_resolve_region edge case - covers line 154.""" with patch("boto3.Session") as mock_session_class: mock_session = MagicMock() mock_session.region_name = None # No region in session @@ -2606,6 +2606,76 @@ def test_region_validation_with_non_string_session_region(self): manager = MemorySessionManager(memory_id="test-memory", region_name="us-west-1") assert manager.region_name == "us-west-1" + def test_region_validation_order_change(self): + """Test that region validation happens before session creation - covers recent commit changes.""" + # Test case: Conflicting regions should raise ValueError + custom_session = MagicMock() + custom_session.region_name = "us-east-1" + mock_client_instance = MagicMock() + custom_session.client.return_value = mock_client_instance + + with pytest.raises(ValueError) as exc_info: + MemorySessionManager( + memory_id="test-memory", + region_name="us-west-1", # Different from session region + boto3_session=custom_session, + ) + + assert "Region mismatch" in str(exc_info.value) + assert "us-west-1" in str(exc_info.value) + assert "us-east-1" in str(exc_info.value) + + def test_region_validation_with_none_session(self): + """Test region validation when boto3_session is None - covers recent commit changes.""" + with patch("boto3.Session") as mock_session_class: + mock_session = MagicMock() + mock_session.region_name = "us-east-1" + mock_client_instance = MagicMock() + mock_session.client.return_value = mock_client_instance + mock_session_class.return_value = mock_session + + # Test validation when boto3_session parameter is None + manager = MemorySessionManager( + memory_id="test-memory", + region_name="us-west-1", + boto3_session=None, # Explicitly None + ) + + # Should use the provided region_name + assert manager.region_name == "us-west-1" + + def test_region_validation_simplified_logic(self): + """Test the simplified region validation logic - covers recent commit changes.""" + # Test case 1: Conflicting regions should raise ValueError + custom_session = MagicMock() + custom_session.region_name = "us-east-1" + mock_client_instance = MagicMock() + custom_session.client.return_value = mock_client_instance + + with pytest.raises(ValueError) as exc_info: + MemorySessionManager( + memory_id="test-memory", + region_name="us-west-1", # Different from session region + boto3_session=custom_session, + ) + + assert "Region mismatch" in str(exc_info.value) + assert "us-west-1" in str(exc_info.value) + assert "us-east-1" in str(exc_info.value) + + # Test case 2: Matching regions should work + custom_session2 = MagicMock() + custom_session2.region_name = "us-west-1" + custom_session2.client.return_value = mock_client_instance + + manager = MemorySessionManager( + memory_id="test-memory", + region_name="us-west-1", # Same as session region + boto3_session=custom_session2, + ) + + assert manager.region_name == "us-west-1" + def test_configure_timestamp_serialization_non_datetime_value(self): """Test timestamp serialization with non-datetime value.""" with patch("boto3.Session") as mock_session_class: