Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 29 additions & 17 deletions mcp_proxy_for_aws/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import boto3
import logging
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from botocore.credentials import Credentials
from contextlib import _AsyncGeneratorContextManager
from datetime import timedelta
from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client
Expand All @@ -32,6 +33,7 @@ def aws_iam_streamablehttp_client(
aws_service: str,
aws_region: Optional[str] = None,
aws_profile: Optional[str] = None,
credentials: Optional[Credentials] = None,
headers: Optional[dict[str, str]] = None,
timeout: float | timedelta = 30,
sse_read_timeout: float | timedelta = 60 * 5,
Expand All @@ -55,6 +57,7 @@ def aws_iam_streamablehttp_client(
aws_service: The name of the AWS service the MCP server is hosted on, e.g. "bedrock-agentcore".
aws_region: The AWS region name of the MCP server, e.g. "us-west-2".
aws_profile: The AWS profile to use for authentication.
credentials: Optional AWS credentials from boto3/botocore. If provided, takes precedence over aws_profile.
headers: Optional additional HTTP headers to include in requests.
timeout: Request timeout in seconds or timedelta object. Defaults to 30 seconds.
sse_read_timeout: Server-sent events read timeout in seconds or timedelta object.
Expand All @@ -78,28 +81,37 @@ def aws_iam_streamablehttp_client(
"""
logger.debug('Preparing AWS IAM MCP client for endpoint: %s', endpoint)

kwargs = {}
if aws_profile is not None:
kwargs['profile_name'] = aws_profile
if aws_region is not None:
kwargs['region_name'] = aws_region
if credentials is not None:
creds = credentials
region = aws_region
if not region:
raise ValueError(
'AWS region must be specified via aws_region parameter when using credentials.'
)
logger.debug('Using provided AWS credentials')
else:
kwargs = {}
if aws_profile is not None:
kwargs['profile_name'] = aws_profile
if aws_region is not None:
kwargs['region_name'] = aws_region

session = boto3.Session(**kwargs)
creds = session.get_credentials()
region = session.region_name

if not region:
raise ValueError(
'AWS region must be specified via aws_region parameter, AWS_REGION environment variable, or AWS config.'
)

logger.debug('AWS profile: %s', session.profile_name)

session = boto3.Session(**kwargs)

profile = session.profile_name
region = session.region_name

if not region:
raise ValueError(
'AWS region must be specified via aws_region parameter, AWS_PROFILE environment variable, or AWS config.'
)

logger.debug('AWS profile: %s', profile)
logger.debug('AWS region: %s', region)
logger.debug('AWS service: %s', aws_service)

# Create a SigV4 authentication handler with AWS credentials
auth = SigV4HTTPXAuth(session.get_credentials(), aws_service, region)
auth = SigV4HTTPXAuth(creds, aws_service, region)

# Return the streamable HTTP client context manager with AWS IAM authentication
return streamablehttp_client(
Expand Down
69 changes: 69 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Unit tests for the client, parameterized by internal call."""

import pytest
from botocore.credentials import Credentials
from datetime import timedelta
from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client
from unittest.mock import AsyncMock, Mock, patch
Expand Down Expand Up @@ -210,3 +211,71 @@ async def mock_aexit(*_):
pass

assert cleanup_called


@pytest.mark.asyncio
async def test_credentials_parameter_with_region(mock_streams):
"""Test using provided credentials with aws_region."""
mock_read, mock_write, mock_get_session = mock_streams
creds = Credentials('test_key', 'test_secret', 'test_token')

with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth') as mock_auth_cls:
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
mock_auth = Mock()
mock_auth_cls.return_value = mock_auth
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
mock_stream_client.return_value.__aexit__ = AsyncMock(return_value=None)

async with aws_iam_streamablehttp_client(
endpoint='https://test.example.com/mcp',
aws_service='bedrock-agentcore',
aws_region='us-east-1',
credentials=creds,
):
pass

mock_auth_cls.assert_called_once_with(creds, 'bedrock-agentcore', 'us-east-1')


@pytest.mark.asyncio
async def test_credentials_parameter_without_region_raises_error():
"""Test that using credentials without aws_region raises ValueError."""
creds = Credentials('test_key', 'test_secret', 'test_token')

with pytest.raises(
ValueError,
match='AWS region must be specified via aws_region parameter when using credentials',
):
async with aws_iam_streamablehttp_client(
endpoint='https://test.example.com/mcp',
aws_service='bedrock-agentcore',
credentials=creds,
):
pass


@pytest.mark.asyncio
async def test_credentials_parameter_bypasses_boto3_session(mock_streams):
"""Test that providing credentials bypasses boto3.Session creation."""
mock_read, mock_write, mock_get_session = mock_streams
creds = Credentials('test_key', 'test_secret', 'test_token')

with patch('boto3.Session') as mock_boto:
with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth'):
with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client:
mock_stream_client.return_value.__aenter__ = AsyncMock(
return_value=(mock_read, mock_write, mock_get_session)
)
mock_stream_client.return_value.__aexit__ = AsyncMock(return_value=None)

async with aws_iam_streamablehttp_client(
endpoint='https://test.example.com/mcp',
aws_service='bedrock-agentcore',
aws_region='us-west-2',
credentials=creds,
):
pass

mock_boto.assert_not_called()