From d614a4d8dfc3ce026d8eff909783653d6b1ed897 Mon Sep 17 00:00:00 2001 From: Weizhou Xing <169175349+wzxxing@users.noreply.github.com> Date: Tue, 11 Nov 2025 12:12:40 +0100 Subject: [PATCH] feat: allow iam mcp client to take a botocore credentials object --- mcp_proxy_for_aws/client.py | 46 ++++++++++++++++--------- tests/unit/test_client.py | 69 +++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 17 deletions(-) diff --git a/mcp_proxy_for_aws/client.py b/mcp_proxy_for_aws/client.py index 50ed5a7..efdb6bd 100644 --- a/mcp_proxy_for_aws/client.py +++ b/mcp_proxy_for_aws/client.py @@ -15,6 +15,7 @@ import boto3 import logging from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from botocore.credentials import Credentials from contextlib import _AsyncGeneratorContextManager from datetime import timedelta from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client @@ -32,6 +33,7 @@ def aws_iam_streamablehttp_client( aws_service: str, aws_region: Optional[str] = None, aws_profile: Optional[str] = None, + credentials: Optional[Credentials] = None, headers: Optional[dict[str, str]] = None, timeout: float | timedelta = 30, sse_read_timeout: float | timedelta = 60 * 5, @@ -55,6 +57,7 @@ def aws_iam_streamablehttp_client( aws_service: The name of the AWS service the MCP server is hosted on, e.g. "bedrock-agentcore". aws_region: The AWS region name of the MCP server, e.g. "us-west-2". aws_profile: The AWS profile to use for authentication. + credentials: Optional AWS credentials from boto3/botocore. If provided, takes precedence over aws_profile. headers: Optional additional HTTP headers to include in requests. timeout: Request timeout in seconds or timedelta object. Defaults to 30 seconds. sse_read_timeout: Server-sent events read timeout in seconds or timedelta object. @@ -78,28 +81,37 @@ def aws_iam_streamablehttp_client( """ logger.debug('Preparing AWS IAM MCP client for endpoint: %s', endpoint) - kwargs = {} - if aws_profile is not None: - kwargs['profile_name'] = aws_profile - if aws_region is not None: - kwargs['region_name'] = aws_region + if credentials is not None: + creds = credentials + region = aws_region + if not region: + raise ValueError( + 'AWS region must be specified via aws_region parameter when using credentials.' + ) + logger.debug('Using provided AWS credentials') + else: + kwargs = {} + if aws_profile is not None: + kwargs['profile_name'] = aws_profile + if aws_region is not None: + kwargs['region_name'] = aws_region + + session = boto3.Session(**kwargs) + creds = session.get_credentials() + region = session.region_name + + if not region: + raise ValueError( + 'AWS region must be specified via aws_region parameter, AWS_REGION environment variable, or AWS config.' + ) + + logger.debug('AWS profile: %s', session.profile_name) - session = boto3.Session(**kwargs) - - profile = session.profile_name - region = session.region_name - - if not region: - raise ValueError( - 'AWS region must be specified via aws_region parameter, AWS_PROFILE environment variable, or AWS config.' - ) - - logger.debug('AWS profile: %s', profile) logger.debug('AWS region: %s', region) logger.debug('AWS service: %s', aws_service) # Create a SigV4 authentication handler with AWS credentials - auth = SigV4HTTPXAuth(session.get_credentials(), aws_service, region) + auth = SigV4HTTPXAuth(creds, aws_service, region) # Return the streamable HTTP client context manager with AWS IAM authentication return streamablehttp_client( diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 0f9df07..2960e7d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -15,6 +15,7 @@ """Unit tests for the client, parameterized by internal call.""" import pytest +from botocore.credentials import Credentials from datetime import timedelta from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client from unittest.mock import AsyncMock, Mock, patch @@ -210,3 +211,71 @@ async def mock_aexit(*_): pass assert cleanup_called + + +@pytest.mark.asyncio +async def test_credentials_parameter_with_region(mock_streams): + """Test using provided credentials with aws_region.""" + mock_read, mock_write, mock_get_session = mock_streams + creds = Credentials('test_key', 'test_secret', 'test_token') + + with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth') as mock_auth_cls: + with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + mock_auth = Mock() + mock_auth_cls.return_value = mock_auth + mock_stream_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write, mock_get_session) + ) + mock_stream_client.return_value.__aexit__ = AsyncMock(return_value=None) + + async with aws_iam_streamablehttp_client( + endpoint='https://test.example.com/mcp', + aws_service='bedrock-agentcore', + aws_region='us-east-1', + credentials=creds, + ): + pass + + mock_auth_cls.assert_called_once_with(creds, 'bedrock-agentcore', 'us-east-1') + + +@pytest.mark.asyncio +async def test_credentials_parameter_without_region_raises_error(): + """Test that using credentials without aws_region raises ValueError.""" + creds = Credentials('test_key', 'test_secret', 'test_token') + + with pytest.raises( + ValueError, + match='AWS region must be specified via aws_region parameter when using credentials', + ): + async with aws_iam_streamablehttp_client( + endpoint='https://test.example.com/mcp', + aws_service='bedrock-agentcore', + credentials=creds, + ): + pass + + +@pytest.mark.asyncio +async def test_credentials_parameter_bypasses_boto3_session(mock_streams): + """Test that providing credentials bypasses boto3.Session creation.""" + mock_read, mock_write, mock_get_session = mock_streams + creds = Credentials('test_key', 'test_secret', 'test_token') + + with patch('boto3.Session') as mock_boto: + with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth'): + with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + mock_stream_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write, mock_get_session) + ) + mock_stream_client.return_value.__aexit__ = AsyncMock(return_value=None) + + async with aws_iam_streamablehttp_client( + endpoint='https://test.example.com/mcp', + aws_service='bedrock-agentcore', + aws_region='us-west-2', + credentials=creds, + ): + pass + + mock_boto.assert_not_called()