Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions mcp_proxy_for_aws/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,20 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:

# Validate and determine region
region = determine_aws_region(args.endpoint, args.region)
forwarding_region = args.forwarding_region or region
logger.debug('Using region: %s', region)

# Get profile
profile = args.profile

# Log server configuration
logger.info('Using service: %s, region: %s, profile: %s', service, region, profile)
logger.info(
'Using service: %s, region: %s, forwarding region: %s, profile: %s',
service,
region,
forwarding_region,
profile,
)
logger.info('Running in MCP mode')

timeout = httpx.Timeout(
Expand All @@ -72,7 +79,9 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
)

# Create transport with SigV4 authentication
transport = create_transport_with_sigv4(args.endpoint, service, region, timeout, profile)
transport = create_transport_with_sigv4(
args.endpoint, service, region, forwarding_region, timeout, profile
)

# Create proxy with the transport
proxy = FastMCP.as_proxy(transport)
Expand Down Expand Up @@ -163,7 +172,13 @@ def parse_args():

parser.add_argument(
'--region',
help='AWS region to use (uses AWS_REGION environment variable if not provided, with final fallback to us-east-1)',
help='AWS region to sign (uses AWS_REGION environment variable if not provided, with final fallback to us-east-1)',
default=None,
)

parser.add_argument(
'--forwarding-region',
help='AWS region to forward to server (uses --region if not provided)',
default=None,
)

Expand Down
130 changes: 129 additions & 1 deletion mcp_proxy_for_aws/sigv4_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@

import boto3
import httpx
import json
import logging
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
from functools import partial
from httpx._content import ByteStream
from typing import Any, Dict, Generator, Optional


Expand Down Expand Up @@ -120,6 +123,126 @@ async def _handle_error_response(response: httpx.Response) -> None:
raise e


def _resign_request_with_sigv4(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tbh, if we go with this approach we should not sign it at all before this point and only do signing here. I.e. no need to re-sign.

request: httpx.Request,
region: str,
service: str,
profile: Optional[str] = None,
) -> None:
"""Re-sign an HTTP request with AWS SigV4 after content modification.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably only do signing here and not re-sign. But this changes the core way the proxy works. We need to align on this.


This function removes old signature headers, creates a new signature based on
the current request content, and updates the request headers with the new signature.

Args:
request: The HTTP request object to re-sign (modified in-place)
region: AWS region for SigV4 signing
service: AWS service name for SigV4 signing
profile: AWS profile to use (optional)
"""
# Remove old signature headers before re-signing
headers_to_remove = ['Content-Length', 'x-amz-date', 'x-amz-security-token', 'authorization']
for header in headers_to_remove:
request.headers.pop(header, None)

# Set the new Content-Length
request.headers['Content-Length'] = str(len(request.content))

logger.info('Headers after cleanup: %s', request.headers)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think we should log this anymore, atleast not with info level.


# Get AWS credentials
session = create_aws_session(profile)
credentials = session.get_credentials()
logger.info('Re-signing request with credentials for access key: %s', credentials.access_key)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should not be logged / be on debug level.


# Create headers dict for signing, removing connection header like in auth_flow
headers_for_signing = dict(request.headers)
headers_for_signing.pop('connection', None) # Remove connection header for signing

# Create SigV4 signer and AWS request
signer = SigV4Auth(credentials, service, region)
aws_request = AWSRequest(
method=request.method,
url=str(request.url),
data=request.content,
headers=headers_for_signing,
)

# Sign the request
logger.info('AWS request before signing: %s', aws_request.headers)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, should be debug

signer.add_auth(aws_request)
logger.info('AWS request after signing: %s', aws_request.headers)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, should be debug


# Update request headers with signed headers
request.headers.update(dict(aws_request.headers))
logger.info('Request headers after re-signing: %s', request.headers)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, should be debug



async def _inject_metadata_hook(
metadata: Dict[str, Any], region: str, service: str, request: httpx.Request
) -> None:
"""Request hook to inject metadata into MCP calls.

Args:
metadata: Dictionary of metadata to inject into _meta field
region: AWS region for SigV4 re-signing after metadata injection
service: AWS service name for SigV4 re-signing after metadata injection
request: The HTTP request object
"""
logger.info('=== Outgoing Request ===')
logger.info('URL: %s', request.url)
logger.info('Method: %s', request.method)

# Try to inject metadata if it's a JSON-RPC/MCP request
if request.content and metadata:
try:
# Parse the request body
body = json.loads(await request.aread())

# Check if it's a JSON-RPC request
if isinstance(body, dict) and 'jsonrpc' in body:
# Ensure _meta exists in params
if '_meta' not in body['params']:
body['params']['_meta'] = {}

# Get existing metadata
existing_meta = body['params']['_meta']

# Merge metadata (existing takes precedence)
if isinstance(existing_meta, dict):
# Check for conflicting keys before merge
conflicting_keys = set(metadata.keys()) & set(existing_meta.keys())
if conflicting_keys:
for key in conflicting_keys:
logger.warning(
'Metadata key "%s" already exists in _meta. '
'Keeping existing value "%s", ignoring injected value "%s"',
key,
existing_meta[key],
metadata[key],
)
body['params']['_meta'] = {**metadata, **existing_meta}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we may need to log if there are failed overwrites here, i.e. if a key was available in both, then this would make it clearer what happened.

else:
logger.info('Replacing non-dict _meta value with injected metadata')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean with non-dict here?
Maybe:

'No conclicting _meta entries exist.`

body['params']['_meta'] = metadata

# Create new content with updated metadata
new_content = json.dumps(body).encode('utf-8')

# Update the request with new content
request.stream = ByteStream(new_content)
request._content = new_content

# Re-sign the request with the new content
_resign_request_with_sigv4(request, region, service)

logger.info('Injected metadata into _meta: %s', body['params']['_meta'])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: MaybeUpdated _meta after injection of additional _meta parameters: <meta>


except (json.JSONDecodeError, KeyError, TypeError) as e:
# Not a JSON request or invalid format, skip metadata injection
logger.error('Skipping metadata injection: %s', e)


def create_aws_session(profile: Optional[str] = None) -> boto3.Session:
"""Create an AWS session with optional profile.

Expand Down Expand Up @@ -185,6 +308,7 @@ def create_sigv4_client(
profile: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
auth: Optional[httpx.Auth] = None,
metadata: Optional[Dict[str, Any]] = None,
**kwargs: Any,
) -> httpx.AsyncClient:
"""Create an httpx.AsyncClient with SigV4 authentication.
Expand All @@ -196,6 +320,7 @@ def create_sigv4_client(
timeout: Timeout configuration for the HTTP client
headers: Headers to include in requests
auth: Auth parameter (ignored as we provide our own)
metadata: Metadata to inject into MCP _meta field
**kwargs: Additional arguments to pass to httpx.AsyncClient

Returns:
Expand Down Expand Up @@ -228,5 +353,8 @@ def create_sigv4_client(
return httpx.AsyncClient(
auth=sigv4_auth,
**client_kwargs,
event_hooks={'response': [_handle_error_response]},
event_hooks={
'response': [_handle_error_response],
'request': [partial(_inject_metadata_hook, metadata or {}, region, service)],
},
)
3 changes: 3 additions & 0 deletions mcp_proxy_for_aws/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def create_transport_with_sigv4(
url: str,
service: str,
region: str,
forwarding_region: str,
custom_timeout: httpx.Timeout,
profile: Optional[str] = None,
) -> StreamableHttpTransport:
Expand All @@ -41,6 +42,7 @@ def create_transport_with_sigv4(
url: The endpoint URL
service: AWS service name for SigV4 signing
region: AWS region to use
forwarding_region: AWS region to forward to server
custom_timeout: httpx.Timeout used to connect to the endpoint
profile: AWS profile to use (optional)

Expand All @@ -60,6 +62,7 @@ def client_factory(
region=region,
headers=headers,
timeout=custom_timeout,
metadata={'AWS_REGION': forwarding_region},
auth=auth,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/integ/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _build_endpoint_environment_remote_configuration():

region_name = os.environ.get('AWS_REGION')
if not region_name:
logger.warn('AWS_REGION param not set. Defaulting to us-east-1')
logger.warning('AWS_REGION param not set. Defaulting to us-east-1')
region_name = 'us-east-1'

logger.info(f'Starting server with config - {remote_endpoint_url=} and {region_name=}')
Expand Down
10 changes: 10 additions & 0 deletions tests/integ/mcp/simple_mcp_server/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ async def elicit_for_my_name(elicitation_expected: str, ctx: Context):
return 'cancelled'


##### Metadata Testing


@mcp.tool
def echo_metadata(ctx: Context):
"""MCP Tool that echoes back the _meta field from the request."""
meta = ctx.request_context.meta
return {'received_meta': meta}


#### Server Setup


Expand Down
37 changes: 37 additions & 0 deletions tests/integ/test_proxy_simple_mcp_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Test the features about testing connecting to remote MCP Server runtime via the proxy."""

import fastmcp
import json
import logging
import pytest
from mcp.types import TextContent
Expand Down Expand Up @@ -91,3 +92,39 @@ async def test_handle_elicitation_when_declining(
async def test_handle_sampling(mcp_client: fastmcp.Client):
"""TODO."""
pass


@pytest.mark.asyncio(loop_scope='module')
async def test_metadata_injection_aws_region(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the integ test!

mcp_client: fastmcp.Client, remote_mcp_server_configuration
):
"""Test that AWS_REGION is automatically injected and received by the server.

This integration test verifies the full flow:
1. Client makes a request through the proxy
2. Proxy injects AWS_REGION into the _meta field
3. Server receives the request with metadata
4. Server echoes back the metadata it received
5. We verify AWS_REGION was correctly transmitted
"""
# Call the echo_metadata tool which returns the _meta field it received
actual_response = await mcp_client.call_tool('echo_metadata', {})

# Extract the response content
actual_text = get_text_content(actual_response)

# Parse the JSON response
response_data = json.loads(actual_text)

# Verify that AWS_REGION was injected and received by the server
assert 'received_meta' in response_data, (
f'Response should contain received_meta: {response_data}'
)
assert response_data['received_meta'] is not None, 'Metadata should not be None'
assert 'AWS_REGION' in response_data['received_meta'], (
f'Metadata should contain AWS_REGION: {response_data["received_meta"]}'
)
assert (
response_data['received_meta']['AWS_REGION']
== remote_mcp_server_configuration['region_name']
), f'AWS_REGION should be {remote_mcp_server_configuration["region_name"]}'
12 changes: 8 additions & 4 deletions tests/unit/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ async def test_setup_mcp_mode(
mock_args.profile = None
mock_args.read_only = True
mock_args.retries = 1
mock_args.forwarding_region = None
# Add timeout parameters
mock_args.timeout = 180.0
mock_args.connect_timeout = 60.0
Expand Down Expand Up @@ -86,8 +87,9 @@ async def test_setup_mcp_mode(
assert call_args[0][0] == 'https://test.example.com'
assert call_args[0][1] == 'test-service'
assert call_args[0][2] == 'us-east-1'
# call_args[0][3] is the Timeout object
assert call_args[0][4] is None # profile
assert call_args[0][3] == 'us-east-1' # forwarding_region (defaults to region)
# call_args[0][4] is the Timeout object
assert call_args[0][5] is None # profile
mock_as_proxy.assert_called_once_with(mock_transport)
mock_add_filtering.assert_called_once_with(mock_proxy, True)
mock_add_retry.assert_called_once_with(mock_proxy, 1)
Expand Down Expand Up @@ -116,6 +118,7 @@ async def test_setup_mcp_mode_no_retries(
mock_args.profile = 'test-profile'
mock_args.read_only = False
mock_args.retries = 0 # No retries
mock_args.forwarding_region = 'eu-west-1'
# Add timeout parameters
mock_args.timeout = 180.0
mock_args.connect_timeout = 60.0
Expand Down Expand Up @@ -146,8 +149,9 @@ async def test_setup_mcp_mode_no_retries(
assert call_args[0][0] == 'https://test.example.com'
assert call_args[0][1] == 'test-service'
assert call_args[0][2] == 'us-east-1'
# call_args[0][3] is the Timeout object
assert call_args[0][4] == 'test-profile' # profile
assert call_args[0][3] == 'eu-west-1' # forwarding_region
# call_args[0][4] is the Timeout object
assert call_args[0][5] == 'test-profile' # profile
mock_as_proxy.assert_called_once_with(mock_transport)
mock_add_filtering.assert_called_once_with(mock_proxy, False)
mock_proxy.run_async.assert_called_once()
Expand Down
Loading