Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions mcp_proxy_for_aws/sigv4_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def create_sigv4_client(
region: str,
timeout: Optional[httpx.Timeout] = None,
profile: Optional[str] = None,
session: Optional[boto3.Session] = None,
headers: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
Expand All @@ -115,7 +116,8 @@ def create_sigv4_client(

Args:
service: AWS service name for SigV4 signing
profile: AWS profile to use (optional)
profile: AWS profile to use (optional, only used if session is not provided)
session: AWS boto3 session to use (optional, takes precedence over profile)
region: AWS region (optional, defaults to AWS_REGION env var or us-east-1)
timeout: Timeout configuration for the HTTP client
headers: Headers to include in requests
Expand All @@ -125,6 +127,10 @@ def create_sigv4_client(
Returns:
httpx.AsyncClient with SigV4 authentication
"""
# Create or use provided AWS session
if session is None:
session = create_aws_session(profile)

# Create a copy of kwargs to avoid modifying the passed dict
client_kwargs = {
'follow_redirects': True,
Expand All @@ -151,7 +157,7 @@ def create_sigv4_client(
'response': [_handle_error_response],
'request': [
partial(_inject_metadata_hook, metadata or {}),
partial(_sign_request_hook, region, service, profile),
partial(_sign_request_hook, region, service, session),
],
},
)
Expand Down Expand Up @@ -210,7 +216,7 @@ async def _handle_error_response(response: httpx.Response) -> None:
async def _sign_request_hook(
region: str,
service: str,
profile: Optional[str],
session: boto3.Session,
request: httpx.Request,
) -> None:
"""Request hook to sign HTTP requests with AWS SigV4.
Expand All @@ -222,14 +228,13 @@ async def _sign_request_hook(
Args:
region: AWS region for SigV4 signing
service: AWS service name for SigV4 signing
profile: AWS profile to use (optional)
session: AWS boto3 session to use for credentials
request: The HTTP request object to sign (modified in-place)
"""
# Set Content-Length for signing
request.headers['Content-Length'] = str(len(request.content))

# Get AWS credentials
session = create_aws_session(profile)
# Get AWS credentials from the session
credentials = session.get_credentials()
logger.info('Signing request with credentials for access key: %s', credentials.access_key)

Expand Down
7 changes: 5 additions & 2 deletions mcp_proxy_for_aws/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import logging
import os
from fastmcp.client.transports import StreamableHttpTransport
from mcp_proxy_for_aws.sigv4_helper import create_sigv4_client
from mcp_proxy_for_aws.sigv4_helper import create_aws_session, create_sigv4_client
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlparse

Expand Down Expand Up @@ -49,6 +49,9 @@ def create_transport_with_sigv4(
Returns:
StreamableHttpTransport instance with SigV4 authentication
"""
# Create AWS session once and reuse it for all httpx clients
logger.debug('Creating AWS session with profile: %s', profile)
session = create_aws_session(profile)

def client_factory(
headers: Optional[Dict[str, str]] = None,
Expand All @@ -57,7 +60,7 @@ def client_factory(
) -> httpx.AsyncClient:
return create_sigv4_client(
service=service,
profile=profile,
session=session,
region=region,
headers=headers,
timeout=custom_timeout,
Expand Down
40 changes: 14 additions & 26 deletions tests/unit/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
_inject_metadata_hook,
_sign_request_hook,
)
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock, Mock


def create_request_with_sigv4_headers(
Expand Down Expand Up @@ -343,87 +343,76 @@ async def test_hook_preserves_other_params(self):
class TestSignRequestHook:
"""Test cases for sign_request_hook function."""

@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
@pytest.mark.asyncio
async def test_sign_request_hook_signs_request(self, mock_create_session):
async def test_sign_request_hook_signs_request(self):
"""Test that sign_request_hook properly signs requests."""
# Setup mocks
mock_create_session.return_value = create_mock_session()
mock_session = create_mock_session()

region = 'us-east-1'
service = 'bedrock-agentcore'
profile = None

# Create request without signature headers
request_body = json.dumps({'test': 'data'}).encode('utf-8')
request = httpx.Request('POST', 'https://example.com/mcp', content=request_body)

# Call the hook
await _sign_request_hook(region, service, profile, request)
await _sign_request_hook(region, service, mock_session, request)

# Verify signature headers were added
assert 'authorization' in request.headers
assert 'x-amz-date' in request.headers
assert 'x-amz-security-token' in request.headers
assert request.headers['content-length'] == str(len(request_body))

@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
@pytest.mark.asyncio
async def test_sign_request_hook_with_profile(self, mock_create_session):
"""Test that sign_request_hook uses profile when provided."""
async def test_sign_request_hook_with_profile(self):
"""Test that sign_request_hook uses session when provided."""
# Setup mocks
mock_create_session.return_value = create_mock_session()
mock_session = create_mock_session()

region = 'us-west-2'
service = 'execute-api'
profile = 'test-profile'

request_body = b'test content'
request = httpx.Request('POST', 'https://example.com/api', content=request_body)

# Call the hook
await _sign_request_hook(region, service, profile, request)

# Verify session was created with profile
mock_create_session.assert_called_once_with(profile)
await _sign_request_hook(region, service, mock_session, request)

# Verify request was signed
assert 'authorization' in request.headers
assert 'x-amz-date' in request.headers

@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
@pytest.mark.asyncio
async def test_sign_request_hook_sets_content_length(self, mock_create_session):
async def test_sign_request_hook_sets_content_length(self):
"""Test that sign_request_hook sets Content-Length header."""
# Setup mocks
mock_create_session.return_value = create_mock_session()
mock_session = create_mock_session()

region = 'eu-west-1'
service = 'lambda'
profile = None

# Create request
request_body = b'test content with specific length'
request = httpx.Request('POST', 'https://example.com/api', content=request_body)

await _sign_request_hook(region, service, profile, request)
await _sign_request_hook(region, service, mock_session, request)

# Verify Content-Length was set correctly
assert request.headers['content-length'] == str(len(request_body))

@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
@pytest.mark.asyncio
async def test_sign_request_hook_with_partial_application(self, mock_create_session):
async def test_sign_request_hook_with_partial_application(self):
"""Test that sign_request_hook works with functools.partial."""
# Setup mocks
mock_create_session.return_value = create_mock_session()
mock_session = create_mock_session()

region = 'ap-southeast-1'
service = 'execute-api'
profile = 'prod-profile'

# Create curried function using partial
curried_hook = partial(_sign_request_hook, region, service, profile)
curried_hook = partial(_sign_request_hook, region, service, mock_session)

request_body = b'request data'
request = httpx.Request('POST', 'https://example.com/mcp', content=request_body)
Expand All @@ -434,4 +423,3 @@ async def test_sign_request_hook_with_partial_application(self, mock_create_sess
# Verify request was signed
assert 'authorization' in request.headers
assert 'x-amz-date' in request.headers
mock_create_session.assert_called_once_with(profile)
16 changes: 14 additions & 2 deletions tests/unit/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,17 +401,25 @@ def test_validate_service_name_service_parsing(self):
result = determine_service_name(endpoint)
assert result == expected_service

@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
@patch('mcp_proxy_for_aws.sigv4_helper.httpx.AsyncClient')
def test_create_sigv4_client(self, mock_async_client):
def test_create_sigv4_client(self, mock_async_client, mock_create_session):
"""Test creating SigV4 authenticated client with request hooks.

Note: Session creation and signing now happens in sign_request_hook,
not during client creation.
"""
# Mock session creation
mock_session = Mock()
mock_session.get_credentials.return_value = Mock(access_key='test-key')
mock_create_session.return_value = mock_session

# Act
create_sigv4_client(service='test-service', region='us-west-2', profile='test-profile')

# Assert
# Verify session was created with profile
mock_create_session.assert_called_once_with('test-profile')
# Verify AsyncClient was called (signing happens via hooks)
assert mock_async_client.call_count == 1
call_args = mock_async_client.call_args
Expand All @@ -422,12 +430,16 @@ def test_create_sigv4_client(self, mock_async_client):
# Should have metadata injection + sign hooks
assert len(call_args[1]['event_hooks']['request']) == 2

def test_create_sigv4_client_no_credentials(self):
@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
def test_create_sigv4_client_no_credentials(self, mock_create_session):
"""Test that credential check happens in sign_request_hook, not during client creation.

Note: With the refactoring, client creation no longer validates credentials.
Credential validation now happens in sign_request_hook when the request is signed.
"""
mock_session = Mock()
mock_create_session.return_value = mock_session

# Client creation should succeed even without credentials
# (credentials are checked when signing happens)
client = create_sigv4_client(service='test-service', region='test-region')
Expand Down
32 changes: 27 additions & 5 deletions tests/unit/test_sigv4_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,14 @@ def test_create_aws_session_creation_failure(self, mock_session_class):
class TestCreateSigv4Client:
"""Test cases for the create_sigv4_client function."""

@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
@patch('httpx.AsyncClient')
def test_create_sigv4_client_default(self, mock_client_class):
def test_create_sigv4_client_default(self, mock_client_class, mock_create_session):
"""Test creating SigV4 client with default parameters."""
mock_client = Mock()
mock_client_class.return_value = mock_client
mock_session = Mock()
mock_create_session.return_value = mock_session

# Test client creation
result = create_sigv4_client(service='test-service', region='test-region')
Expand All @@ -139,11 +142,14 @@ def test_create_sigv4_client_default(self, mock_client_class):
assert call_args[1]['headers']['Accept'] == 'application/json, text/event-stream'
assert result == mock_client

@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
@patch('httpx.AsyncClient')
def test_create_sigv4_client_with_custom_headers(self, mock_client_class):
def test_create_sigv4_client_with_custom_headers(self, mock_client_class, mock_create_session):
"""Test creating SigV4 client with custom headers."""
mock_client = Mock()
mock_client_class.return_value = mock_client
mock_session = Mock()
mock_create_session.return_value = mock_session

# Test client creation with custom headers
custom_headers = {'Custom-Header': 'custom-value'}
Expand All @@ -160,25 +166,38 @@ def test_create_sigv4_client_with_custom_headers(self, mock_client_class):
assert call_args[1]['headers'] == expected_headers
assert result == mock_client

@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
@patch('httpx.AsyncClient')
def test_create_sigv4_client_with_custom_service_and_region(self, mock_client_class):
def test_create_sigv4_client_with_custom_service_and_region(
self, mock_client_class, mock_create_session
):
"""Test creating SigV4 client with custom service and region."""
mock_client = Mock()
mock_client_class.return_value = mock_client

# Mock session creation
mock_session = Mock()
mock_session.get_credentials.return_value = Mock(access_key='test-key')
mock_create_session.return_value = mock_session

# Test client creation with custom parameters
result = create_sigv4_client(
service='custom-service', profile='test-profile', region='us-east-1'
)

# Verify session was created with profile
mock_create_session.assert_called_once_with('test-profile')
# Verify client was created
assert result == mock_client

@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
@patch('httpx.AsyncClient')
def test_create_sigv4_client_with_kwargs(self, mock_client_class):
def test_create_sigv4_client_with_kwargs(self, mock_client_class, mock_create_session):
"""Test creating SigV4 client with additional kwargs."""
mock_client = Mock()
mock_client_class.return_value = mock_client
mock_session = Mock()
mock_create_session.return_value = mock_session

# Test client creation with additional kwargs
result = create_sigv4_client(
Expand All @@ -194,8 +213,9 @@ def test_create_sigv4_client_with_kwargs(self, mock_client_class):
assert call_args[1]['proxies'] == {'http': 'http://proxy:8080'}
assert result == mock_client

@patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session')
@patch('httpx.AsyncClient')
def test_create_sigv4_client_with_prompt_context(self, mock_client_class):
def test_create_sigv4_client_with_prompt_context(self, mock_client_class, mock_create_session):
"""Test creating SigV4 client when prompts exist in the system context.

This test simulates the scenario where the sigv4_helper is used in a context
Expand All @@ -204,6 +224,8 @@ def test_create_sigv4_client_with_prompt_context(self, mock_client_class):
"""
mock_client = Mock()
mock_client_class.return_value = mock_client
mock_session = Mock()
mock_create_session.return_value = mock_session

# Test client creation with headers that might be used when prompts exist
prompt_context_headers = {
Expand Down
Loading
Loading