-
Notifications
You must be signed in to change notification settings - Fork 18
feat(sigv4_helper): inject AWS_REGION in _meta #62
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b9e509f
eeee136
094fb77
6eb9fc0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,10 +16,13 @@ | |
|
|
||
| import boto3 | ||
| import httpx | ||
| import json | ||
| import logging | ||
| from botocore.auth import SigV4Auth | ||
| from botocore.awsrequest import AWSRequest | ||
| from botocore.credentials import Credentials | ||
| from functools import partial | ||
| from httpx._content import ByteStream | ||
| from typing import Any, Dict, Generator, Optional | ||
|
|
||
|
|
||
|
|
@@ -120,6 +123,126 @@ async def _handle_error_response(response: httpx.Response) -> None: | |
| raise e | ||
|
|
||
|
|
||
| def _resign_request_with_sigv4( | ||
| request: httpx.Request, | ||
| region: str, | ||
| service: str, | ||
| profile: Optional[str] = None, | ||
| ) -> None: | ||
| """Re-sign an HTTP request with AWS SigV4 after content modification. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably only do signing here and not re-sign. But this changes the core way the proxy works. We need to align on this. |
||
|
|
||
| This function removes old signature headers, creates a new signature based on | ||
| the current request content, and updates the request headers with the new signature. | ||
|
|
||
| Args: | ||
| request: The HTTP request object to re-sign (modified in-place) | ||
| region: AWS region for SigV4 signing | ||
| service: AWS service name for SigV4 signing | ||
| profile: AWS profile to use (optional) | ||
| """ | ||
| # Remove old signature headers before re-signing | ||
| headers_to_remove = ['Content-Length', 'x-amz-date', 'x-amz-security-token', 'authorization'] | ||
| for header in headers_to_remove: | ||
| request.headers.pop(header, None) | ||
|
|
||
| # Set the new Content-Length | ||
| request.headers['Content-Length'] = str(len(request.content)) | ||
|
|
||
| logger.info('Headers after cleanup: %s', request.headers) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't think we should log this anymore, atleast not with info level. |
||
|
|
||
| # Get AWS credentials | ||
| session = create_aws_session(profile) | ||
| credentials = session.get_credentials() | ||
| logger.info('Re-signing request with credentials for access key: %s', credentials.access_key) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should not be logged / be on debug level. |
||
|
|
||
| # Create headers dict for signing, removing connection header like in auth_flow | ||
| headers_for_signing = dict(request.headers) | ||
| headers_for_signing.pop('connection', None) # Remove connection header for signing | ||
|
|
||
| # Create SigV4 signer and AWS request | ||
| signer = SigV4Auth(credentials, service, region) | ||
| aws_request = AWSRequest( | ||
| method=request.method, | ||
| url=str(request.url), | ||
| data=request.content, | ||
| headers=headers_for_signing, | ||
| ) | ||
|
|
||
| # Sign the request | ||
| logger.info('AWS request before signing: %s', aws_request.headers) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above, should be |
||
| signer.add_auth(aws_request) | ||
| logger.info('AWS request after signing: %s', aws_request.headers) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above, should be |
||
|
|
||
| # Update request headers with signed headers | ||
| request.headers.update(dict(aws_request.headers)) | ||
| logger.info('Request headers after re-signing: %s', request.headers) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above, should be |
||
|
|
||
|
|
||
| async def _inject_metadata_hook( | ||
| metadata: Dict[str, Any], region: str, service: str, request: httpx.Request | ||
| ) -> None: | ||
| """Request hook to inject metadata into MCP calls. | ||
|
|
||
| Args: | ||
| metadata: Dictionary of metadata to inject into _meta field | ||
| region: AWS region for SigV4 re-signing after metadata injection | ||
| service: AWS service name for SigV4 re-signing after metadata injection | ||
| request: The HTTP request object | ||
| """ | ||
| logger.info('=== Outgoing Request ===') | ||
| logger.info('URL: %s', request.url) | ||
| logger.info('Method: %s', request.method) | ||
|
|
||
| # Try to inject metadata if it's a JSON-RPC/MCP request | ||
| if request.content and metadata: | ||
| try: | ||
| # Parse the request body | ||
| body = json.loads(await request.aread()) | ||
|
|
||
| # Check if it's a JSON-RPC request | ||
| if isinstance(body, dict) and 'jsonrpc' in body: | ||
| # Ensure _meta exists in params | ||
| if '_meta' not in body['params']: | ||
| body['params']['_meta'] = {} | ||
|
|
||
| # Get existing metadata | ||
| existing_meta = body['params']['_meta'] | ||
|
|
||
| # Merge metadata (existing takes precedence) | ||
| if isinstance(existing_meta, dict): | ||
| # Check for conflicting keys before merge | ||
| conflicting_keys = set(metadata.keys()) & set(existing_meta.keys()) | ||
| if conflicting_keys: | ||
| for key in conflicting_keys: | ||
| logger.warning( | ||
| 'Metadata key "%s" already exists in _meta. ' | ||
| 'Keeping existing value "%s", ignoring injected value "%s"', | ||
| key, | ||
| existing_meta[key], | ||
| metadata[key], | ||
| ) | ||
| body['params']['_meta'] = {**metadata, **existing_meta} | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we may need to log if there are failed overwrites here, i.e. if a key was available in both, then this would make it clearer what happened. |
||
| else: | ||
| logger.info('Replacing non-dict _meta value with injected metadata') | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do you mean with |
||
| body['params']['_meta'] = metadata | ||
|
|
||
| # Create new content with updated metadata | ||
| new_content = json.dumps(body).encode('utf-8') | ||
|
|
||
| # Update the request with new content | ||
| request.stream = ByteStream(new_content) | ||
| request._content = new_content | ||
|
|
||
| # Re-sign the request with the new content | ||
| _resign_request_with_sigv4(request, region, service) | ||
|
|
||
| logger.info('Injected metadata into _meta: %s', body['params']['_meta']) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: Maybe |
||
|
|
||
| except (json.JSONDecodeError, KeyError, TypeError) as e: | ||
| # Not a JSON request or invalid format, skip metadata injection | ||
| logger.error('Skipping metadata injection: %s', e) | ||
|
|
||
|
|
||
| def create_aws_session(profile: Optional[str] = None) -> boto3.Session: | ||
| """Create an AWS session with optional profile. | ||
|
|
||
|
|
@@ -185,6 +308,7 @@ def create_sigv4_client( | |
| profile: Optional[str] = None, | ||
| headers: Optional[Dict[str, str]] = None, | ||
| auth: Optional[httpx.Auth] = None, | ||
| metadata: Optional[Dict[str, Any]] = None, | ||
| **kwargs: Any, | ||
| ) -> httpx.AsyncClient: | ||
| """Create an httpx.AsyncClient with SigV4 authentication. | ||
|
|
@@ -196,6 +320,7 @@ def create_sigv4_client( | |
| timeout: Timeout configuration for the HTTP client | ||
| headers: Headers to include in requests | ||
| auth: Auth parameter (ignored as we provide our own) | ||
| metadata: Metadata to inject into MCP _meta field | ||
| **kwargs: Additional arguments to pass to httpx.AsyncClient | ||
|
|
||
| Returns: | ||
|
|
@@ -228,5 +353,8 @@ def create_sigv4_client( | |
| return httpx.AsyncClient( | ||
| auth=sigv4_auth, | ||
| **client_kwargs, | ||
| event_hooks={'response': [_handle_error_response]}, | ||
| event_hooks={ | ||
| 'response': [_handle_error_response], | ||
| 'request': [partial(_inject_metadata_hook, metadata or {}, region, service)], | ||
| }, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| """Test the features about testing connecting to remote MCP Server runtime via the proxy.""" | ||
|
|
||
| import fastmcp | ||
| import json | ||
| import logging | ||
| import pytest | ||
| from mcp.types import TextContent | ||
|
|
@@ -91,3 +92,39 @@ async def test_handle_elicitation_when_declining( | |
| async def test_handle_sampling(mcp_client: fastmcp.Client): | ||
| """TODO.""" | ||
| pass | ||
|
|
||
|
|
||
| @pytest.mark.asyncio(loop_scope='module') | ||
| async def test_metadata_injection_aws_region( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the integ test! |
||
| mcp_client: fastmcp.Client, remote_mcp_server_configuration | ||
| ): | ||
| """Test that AWS_REGION is automatically injected and received by the server. | ||
|
|
||
| This integration test verifies the full flow: | ||
| 1. Client makes a request through the proxy | ||
| 2. Proxy injects AWS_REGION into the _meta field | ||
| 3. Server receives the request with metadata | ||
| 4. Server echoes back the metadata it received | ||
| 5. We verify AWS_REGION was correctly transmitted | ||
| """ | ||
| # Call the echo_metadata tool which returns the _meta field it received | ||
| actual_response = await mcp_client.call_tool('echo_metadata', {}) | ||
|
|
||
| # Extract the response content | ||
| actual_text = get_text_content(actual_response) | ||
|
|
||
| # Parse the JSON response | ||
| response_data = json.loads(actual_text) | ||
|
|
||
| # Verify that AWS_REGION was injected and received by the server | ||
| assert 'received_meta' in response_data, ( | ||
| f'Response should contain received_meta: {response_data}' | ||
| ) | ||
| assert response_data['received_meta'] is not None, 'Metadata should not be None' | ||
| assert 'AWS_REGION' in response_data['received_meta'], ( | ||
| f'Metadata should contain AWS_REGION: {response_data["received_meta"]}' | ||
| ) | ||
| assert ( | ||
| response_data['received_meta']['AWS_REGION'] | ||
| == remote_mcp_server_configuration['region_name'] | ||
| ), f'AWS_REGION should be {remote_mcp_server_configuration["region_name"]}' | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tbh, if we go with this approach we should not sign it at all before this point and only do signing here. I.e. no need to re-sign.