diff --git a/mcp_proxy_for_aws/sigv4_helper.py b/mcp_proxy_for_aws/sigv4_helper.py index 570df15..7140b15 100644 --- a/mcp_proxy_for_aws/sigv4_helper.py +++ b/mcp_proxy_for_aws/sigv4_helper.py @@ -107,6 +107,7 @@ def create_sigv4_client( region: str, timeout: Optional[httpx.Timeout] = None, profile: Optional[str] = None, + session: Optional[boto3.Session] = None, headers: Optional[Dict[str, str]] = None, metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, @@ -115,7 +116,8 @@ def create_sigv4_client( Args: service: AWS service name for SigV4 signing - profile: AWS profile to use (optional) + profile: AWS profile to use (optional, only used if session is not provided) + session: AWS boto3 session to use (optional, takes precedence over profile) region: AWS region (optional, defaults to AWS_REGION env var or us-east-1) timeout: Timeout configuration for the HTTP client headers: Headers to include in requests @@ -125,6 +127,10 @@ def create_sigv4_client( Returns: httpx.AsyncClient with SigV4 authentication """ + # Create or use provided AWS session + if session is None: + session = create_aws_session(profile) + # Create a copy of kwargs to avoid modifying the passed dict client_kwargs = { 'follow_redirects': True, @@ -151,7 +157,7 @@ def create_sigv4_client( 'response': [_handle_error_response], 'request': [ partial(_inject_metadata_hook, metadata or {}), - partial(_sign_request_hook, region, service, profile), + partial(_sign_request_hook, region, service, session), ], }, ) @@ -210,7 +216,7 @@ async def _handle_error_response(response: httpx.Response) -> None: async def _sign_request_hook( region: str, service: str, - profile: Optional[str], + session: boto3.Session, request: httpx.Request, ) -> None: """Request hook to sign HTTP requests with AWS SigV4. @@ -222,14 +228,13 @@ async def _sign_request_hook( Args: region: AWS region for SigV4 signing service: AWS service name for SigV4 signing - profile: AWS profile to use (optional) + session: AWS boto3 session to use for credentials request: The HTTP request object to sign (modified in-place) """ # Set Content-Length for signing request.headers['Content-Length'] = str(len(request.content)) - # Get AWS credentials - session = create_aws_session(profile) + # Get AWS credentials from the session credentials = session.get_credentials() logger.info('Signing request with credentials for access key: %s', credentials.access_key) diff --git a/mcp_proxy_for_aws/utils.py b/mcp_proxy_for_aws/utils.py index 5d00cf7..4c3aeac 100644 --- a/mcp_proxy_for_aws/utils.py +++ b/mcp_proxy_for_aws/utils.py @@ -19,7 +19,7 @@ import logging import os from fastmcp.client.transports import StreamableHttpTransport -from mcp_proxy_for_aws.sigv4_helper import create_sigv4_client +from mcp_proxy_for_aws.sigv4_helper import create_aws_session, create_sigv4_client from typing import Any, Dict, Optional, Tuple from urllib.parse import urlparse @@ -49,6 +49,9 @@ def create_transport_with_sigv4( Returns: StreamableHttpTransport instance with SigV4 authentication """ + # Create AWS session once and reuse it for all httpx clients + logger.debug('Creating AWS session with profile: %s', profile) + session = create_aws_session(profile) def client_factory( headers: Optional[Dict[str, str]] = None, @@ -57,7 +60,7 @@ def client_factory( ) -> httpx.AsyncClient: return create_sigv4_client( service=service, - profile=profile, + session=session, region=region, headers=headers, timeout=custom_timeout, diff --git a/tests/unit/test_hooks.py b/tests/unit/test_hooks.py index 51b039a..ca247f5 100644 --- a/tests/unit/test_hooks.py +++ b/tests/unit/test_hooks.py @@ -23,7 +23,7 @@ _inject_metadata_hook, _sign_request_hook, ) -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import MagicMock, Mock def create_request_with_sigv4_headers( @@ -343,23 +343,21 @@ async def test_hook_preserves_other_params(self): class TestSignRequestHook: """Test cases for sign_request_hook function.""" - @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') @pytest.mark.asyncio - async def test_sign_request_hook_signs_request(self, mock_create_session): + async def test_sign_request_hook_signs_request(self): """Test that sign_request_hook properly signs requests.""" # Setup mocks - mock_create_session.return_value = create_mock_session() + mock_session = create_mock_session() region = 'us-east-1' service = 'bedrock-agentcore' - profile = None # Create request without signature headers request_body = json.dumps({'test': 'data'}).encode('utf-8') request = httpx.Request('POST', 'https://example.com/mcp', content=request_body) # Call the hook - await _sign_request_hook(region, service, profile, request) + await _sign_request_hook(region, service, mock_session, request) # Verify signature headers were added assert 'authorization' in request.headers @@ -367,63 +365,54 @@ async def test_sign_request_hook_signs_request(self, mock_create_session): assert 'x-amz-security-token' in request.headers assert request.headers['content-length'] == str(len(request_body)) - @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') @pytest.mark.asyncio - async def test_sign_request_hook_with_profile(self, mock_create_session): - """Test that sign_request_hook uses profile when provided.""" + async def test_sign_request_hook_with_profile(self): + """Test that sign_request_hook uses session when provided.""" # Setup mocks - mock_create_session.return_value = create_mock_session() + mock_session = create_mock_session() region = 'us-west-2' service = 'execute-api' - profile = 'test-profile' request_body = b'test content' request = httpx.Request('POST', 'https://example.com/api', content=request_body) # Call the hook - await _sign_request_hook(region, service, profile, request) - - # Verify session was created with profile - mock_create_session.assert_called_once_with(profile) + await _sign_request_hook(region, service, mock_session, request) # Verify request was signed assert 'authorization' in request.headers assert 'x-amz-date' in request.headers - @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') @pytest.mark.asyncio - async def test_sign_request_hook_sets_content_length(self, mock_create_session): + async def test_sign_request_hook_sets_content_length(self): """Test that sign_request_hook sets Content-Length header.""" # Setup mocks - mock_create_session.return_value = create_mock_session() + mock_session = create_mock_session() region = 'eu-west-1' service = 'lambda' - profile = None # Create request request_body = b'test content with specific length' request = httpx.Request('POST', 'https://example.com/api', content=request_body) - await _sign_request_hook(region, service, profile, request) + await _sign_request_hook(region, service, mock_session, request) # Verify Content-Length was set correctly assert request.headers['content-length'] == str(len(request_body)) - @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') @pytest.mark.asyncio - async def test_sign_request_hook_with_partial_application(self, mock_create_session): + async def test_sign_request_hook_with_partial_application(self): """Test that sign_request_hook works with functools.partial.""" # Setup mocks - mock_create_session.return_value = create_mock_session() + mock_session = create_mock_session() region = 'ap-southeast-1' service = 'execute-api' - profile = 'prod-profile' # Create curried function using partial - curried_hook = partial(_sign_request_hook, region, service, profile) + curried_hook = partial(_sign_request_hook, region, service, mock_session) request_body = b'request data' request = httpx.Request('POST', 'https://example.com/mcp', content=request_body) @@ -434,4 +423,3 @@ async def test_sign_request_hook_with_partial_application(self, mock_create_sess # Verify request was signed assert 'authorization' in request.headers assert 'x-amz-date' in request.headers - mock_create_session.assert_called_once_with(profile) diff --git a/tests/unit/test_server.py b/tests/unit/test_server.py index 4269afd..8cb531c 100644 --- a/tests/unit/test_server.py +++ b/tests/unit/test_server.py @@ -401,17 +401,25 @@ def test_validate_service_name_service_parsing(self): result = determine_service_name(endpoint) assert result == expected_service + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') @patch('mcp_proxy_for_aws.sigv4_helper.httpx.AsyncClient') - def test_create_sigv4_client(self, mock_async_client): + def test_create_sigv4_client(self, mock_async_client, mock_create_session): """Test creating SigV4 authenticated client with request hooks. Note: Session creation and signing now happens in sign_request_hook, not during client creation. """ + # Mock session creation + mock_session = Mock() + mock_session.get_credentials.return_value = Mock(access_key='test-key') + mock_create_session.return_value = mock_session + # Act create_sigv4_client(service='test-service', region='us-west-2', profile='test-profile') # Assert + # Verify session was created with profile + mock_create_session.assert_called_once_with('test-profile') # Verify AsyncClient was called (signing happens via hooks) assert mock_async_client.call_count == 1 call_args = mock_async_client.call_args @@ -422,12 +430,16 @@ def test_create_sigv4_client(self, mock_async_client): # Should have metadata injection + sign hooks assert len(call_args[1]['event_hooks']['request']) == 2 - def test_create_sigv4_client_no_credentials(self): + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') + def test_create_sigv4_client_no_credentials(self, mock_create_session): """Test that credential check happens in sign_request_hook, not during client creation. Note: With the refactoring, client creation no longer validates credentials. Credential validation now happens in sign_request_hook when the request is signed. """ + mock_session = Mock() + mock_create_session.return_value = mock_session + # Client creation should succeed even without credentials # (credentials are checked when signing happens) client = create_sigv4_client(service='test-service', region='test-region') diff --git a/tests/unit/test_sigv4_helper.py b/tests/unit/test_sigv4_helper.py index 7c6cd20..d9bef3e 100644 --- a/tests/unit/test_sigv4_helper.py +++ b/tests/unit/test_sigv4_helper.py @@ -119,11 +119,14 @@ def test_create_aws_session_creation_failure(self, mock_session_class): class TestCreateSigv4Client: """Test cases for the create_sigv4_client function.""" + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') @patch('httpx.AsyncClient') - def test_create_sigv4_client_default(self, mock_client_class): + def test_create_sigv4_client_default(self, mock_client_class, mock_create_session): """Test creating SigV4 client with default parameters.""" mock_client = Mock() mock_client_class.return_value = mock_client + mock_session = Mock() + mock_create_session.return_value = mock_session # Test client creation result = create_sigv4_client(service='test-service', region='test-region') @@ -139,11 +142,14 @@ def test_create_sigv4_client_default(self, mock_client_class): assert call_args[1]['headers']['Accept'] == 'application/json, text/event-stream' assert result == mock_client + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') @patch('httpx.AsyncClient') - def test_create_sigv4_client_with_custom_headers(self, mock_client_class): + def test_create_sigv4_client_with_custom_headers(self, mock_client_class, mock_create_session): """Test creating SigV4 client with custom headers.""" mock_client = Mock() mock_client_class.return_value = mock_client + mock_session = Mock() + mock_create_session.return_value = mock_session # Test client creation with custom headers custom_headers = {'Custom-Header': 'custom-value'} @@ -160,25 +166,38 @@ def test_create_sigv4_client_with_custom_headers(self, mock_client_class): assert call_args[1]['headers'] == expected_headers assert result == mock_client + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') @patch('httpx.AsyncClient') - def test_create_sigv4_client_with_custom_service_and_region(self, mock_client_class): + def test_create_sigv4_client_with_custom_service_and_region( + self, mock_client_class, mock_create_session + ): """Test creating SigV4 client with custom service and region.""" mock_client = Mock() mock_client_class.return_value = mock_client + # Mock session creation + mock_session = Mock() + mock_session.get_credentials.return_value = Mock(access_key='test-key') + mock_create_session.return_value = mock_session + # Test client creation with custom parameters result = create_sigv4_client( service='custom-service', profile='test-profile', region='us-east-1' ) + # Verify session was created with profile + mock_create_session.assert_called_once_with('test-profile') # Verify client was created assert result == mock_client + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') @patch('httpx.AsyncClient') - def test_create_sigv4_client_with_kwargs(self, mock_client_class): + def test_create_sigv4_client_with_kwargs(self, mock_client_class, mock_create_session): """Test creating SigV4 client with additional kwargs.""" mock_client = Mock() mock_client_class.return_value = mock_client + mock_session = Mock() + mock_create_session.return_value = mock_session # Test client creation with additional kwargs result = create_sigv4_client( @@ -194,8 +213,9 @@ def test_create_sigv4_client_with_kwargs(self, mock_client_class): assert call_args[1]['proxies'] == {'http': 'http://proxy:8080'} assert result == mock_client + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') @patch('httpx.AsyncClient') - def test_create_sigv4_client_with_prompt_context(self, mock_client_class): + def test_create_sigv4_client_with_prompt_context(self, mock_client_class, mock_create_session): """Test creating SigV4 client when prompts exist in the system context. This test simulates the scenario where the sigv4_helper is used in a context @@ -204,6 +224,8 @@ def test_create_sigv4_client_with_prompt_context(self, mock_client_class): """ mock_client = Mock() mock_client_class.return_value = mock_client + mock_session = Mock() + mock_create_session.return_value = mock_session # Test client creation with headers that might be used when prompts exist prompt_context_headers = { diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index a58ef22..a4a4435 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -27,13 +27,16 @@ class TestCreateTransportWithSigv4: """Test cases for create_transport_with_sigv4 function (line 129).""" + @patch('mcp_proxy_for_aws.utils.create_aws_session') @patch('mcp_proxy_for_aws.utils.create_sigv4_client') - def test_create_transport_with_sigv4(self, mock_create_sigv4_client): + def test_create_transport_with_sigv4(self, mock_create_sigv4_client, mock_create_session): """Test creating StreamableHttpTransport with SigV4 authentication.""" from httpx import Timeout mock_client = MagicMock() mock_create_sigv4_client.return_value = mock_client + mock_session = MagicMock() + mock_create_session.return_value = mock_session url = 'https://test-service.us-west-2.api.aws/mcp' service = 'test-service' @@ -46,6 +49,9 @@ def test_create_transport_with_sigv4(self, mock_create_sigv4_client): url, service, region, metadata, custom_timeout, profile ) + # Verify session was created with profile + mock_create_session.assert_called_once_with(profile) + # Verify result is StreamableHttpTransport assert isinstance(result, StreamableHttpTransport) assert result.url == url @@ -59,7 +65,7 @@ def test_create_transport_with_sigv4(self, mock_create_sigv4_client): mock_create_sigv4_client.assert_called_once_with( service=service, - profile=profile, + session=mock_session, region=region, headers={'test': 'header'}, timeout=custom_timeout, @@ -70,11 +76,17 @@ def test_create_transport_with_sigv4(self, mock_create_sigv4_client): # If we can't access the factory directly, just verify the transport was created assert result is not None + @patch('mcp_proxy_for_aws.utils.create_aws_session') @patch('mcp_proxy_for_aws.utils.create_sigv4_client') - def test_create_transport_with_sigv4_no_profile(self, mock_create_sigv4_client): + def test_create_transport_with_sigv4_no_profile( + self, mock_create_sigv4_client, mock_create_session + ): """Test creating transport without profile.""" from httpx import Timeout + mock_session = MagicMock() + mock_create_session.return_value = mock_session + url = 'https://test-service.us-west-2.api.aws/mcp' service = 'test-service' region = 'test-region' @@ -83,6 +95,9 @@ def test_create_transport_with_sigv4_no_profile(self, mock_create_sigv4_client): result = create_transport_with_sigv4(url, service, region, metadata, custom_timeout) + # Verify session was created without profile + mock_create_session.assert_called_once_with(None) + # Test that the httpx_client_factory calls create_sigv4_client correctly # We need to access the factory through the transport's internal structure if hasattr(result, 'httpx_client_factory') and result.httpx_client_factory: @@ -91,8 +106,8 @@ def test_create_transport_with_sigv4_no_profile(self, mock_create_sigv4_client): mock_create_sigv4_client.assert_called_once_with( service=service, + session=mock_session, region=region, - profile=None, headers=None, timeout=custom_timeout, auth=None,