From 5ac5998bbb708a88f2b5bf69c21b07fd77a7d5f7 Mon Sep 17 00:00:00 2001 From: Harvish N S Date: Mon, 10 Nov 2025 21:50:44 +0100 Subject: [PATCH 1/8] pypi release automation through github actions --- .github/workflows/pypi-publish-on-release.yml | 74 +++++++++++++++++++ .github/workflows/python.yml | 8 ++ 2 files changed, 82 insertions(+) create mode 100644 .github/workflows/pypi-publish-on-release.yml diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml new file mode 100644 index 0000000..8b8dfb1 --- /dev/null +++ b/.github/workflows/pypi-publish-on-release.yml @@ -0,0 +1,74 @@ +name: Publish to PyPI + +on: + release: + types: + - published + +permissions: {} + +jobs: + call-test-lint: + permissions: + contents: read + uses: ./.github/workflows/python.yml + with: + ref: ${{ github.event.release.target_commitish }} + + build: + needs: call-test-lint + runs-on: ubuntu-latest + permissions: + contents: read + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + persist-credentials: false + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install uv + run: pip install uv + + - name: Validate version format + run: | + VERSION=$(python -c "import tomllib; print(tomllib.load(open('pyproject.toml', 'rb'))['project']['version'])") + if [[ ! "$VERSION" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + echo "Invalid version format: $VERSION" + echo "Expected format: X.Y.Z (e.g., 1.0.0)" + exit 1 + fi + echo "Valid version format: $VERSION" + + - name: Build distribution packages + run: uv build + + - name: Upload distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + deploy: + needs: build + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/p/mcp-proxy-for-aws + permissions: + contents: read + steps: + - name: Download distribution packages + uses: actions/download-artifact@v5 + with: + name: python-package-distributions + path: dist/ + + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 20b1dee..fbc2384 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -4,6 +4,12 @@ on: push: pull_request: workflow_dispatch: + workflow_call: + inputs: + ref: + description: 'Git ref to checkout' + required: false + type: string permissions: {} @@ -25,6 +31,8 @@ jobs: actions: read steps: - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + ref: ${{ inputs.ref || github.ref }} - name: Install uv uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2 From b52085fecb495174c6d92d8ff0ebbe5dfeed1d32 Mon Sep 17 00:00:00 2001 From: Kyon <91875365+kyoncal@users.noreply.github.com> Date: Tue, 11 Nov 2025 11:15:51 +0100 Subject: [PATCH 2/8] Forward region via meta (#71) * feat(sigv4_helper): inject AWS_REGION in _meta * Override the sigv4 signature when adding _meta. * feat(sigv4_helper): add region and service argument to _inject_metadata_hook to allow for proper resigning of sigv4 to work * feat(server.py): add forwarding region as optional argument * feat: replace forwarding region with metadata forwarding * refactor: move the hooks from sigv4_helper.py into a new folder and add tests * refactor(siv4_helper.py): move signing logic from client creation to an event hook * test(test_hooks.py): add assertions * refactor(sigv4_helper.py): remove hooks.py module and move hooks to sigv4_helper.py This refactor was needed in order to avoid a circular depdency, which resulted in a mid-module import. --------- Co-authored-by: Kyon Caldera Co-authored-by: Leonardo Araneda Freccero --- README.md | 1 + mcp_proxy_for_aws/cli.py | 47 +- mcp_proxy_for_aws/server.py | 17 +- mcp_proxy_for_aws/sigv4_helper.py | 244 ++++++---- mcp_proxy_for_aws/utils.py | 5 +- tests/integ/mcp/simple_mcp_client.py | 52 ++- .../integ/mcp/simple_mcp_server/mcp_server.py | 10 + tests/integ/test_proxy_simple_mcp_server.py | 138 ++++++ tests/unit/test_cli.py | 46 ++ tests/unit/test_hooks.py | 437 ++++++++++++++++++ tests/unit/test_server.py | 169 +++++-- tests/unit/test_sigv4_helper.py | 176 +------ tests/unit/test_utils.py | 10 +- 13 files changed, 1044 insertions(+), 308 deletions(-) create mode 100644 tests/unit/test_hooks.py diff --git a/README.md b/README.md index a220556..ef8476f 100644 --- a/README.md +++ b/README.md @@ -82,6 +82,7 @@ docker build -t mcp-proxy-for-aws . | `--service` | AWS service name for SigV4 signing | Inferred from endpoint if not provided |No | | `--profile` | AWS profile for AWS credentials to use | Uses `AWS_PROFILE` environment variable if not set |No | | `--region` | AWS region to use | Uses `AWS_REGION` environment variable if not set, defaults to `us-east-1` |No | +| `--metadata` | Metadata to inject into MCP requests as key=value pairs (e.g., `--metadata KEY1=value1 KEY2=value2`) | `AWS_REGION` is automatically injected based on `--region` if not provided |No | | `--read-only` | Disable tools which may require write permissions (tools which DO NOT require write permissions are annotated with [`readOnlyHint=true`](https://modelcontextprotocol.io/specification/2025-06-18/schema#toolannotations-readonlyhint)) | `False` |No | | `--retries` | Configures number of retries done when calling upstream services, setting this to 0 disables retries. | 0 |No | | `--log-level` | Set the logging level (`DEBUG/INFO/WARNING/ERROR/CRITICAL`) | `INFO` |No | diff --git a/mcp_proxy_for_aws/cli.py b/mcp_proxy_for_aws/cli.py index db3a930..dc66665 100644 --- a/mcp_proxy_for_aws/cli.py +++ b/mcp_proxy_for_aws/cli.py @@ -18,6 +18,43 @@ import os from mcp_proxy_for_aws import __version__ from mcp_proxy_for_aws.utils import within_range +from typing import Any, Dict, Optional, Sequence + + +class KeyValueAction(argparse.Action): + """Custom argparse action to parse key=value pairs into a dictionary.""" + + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: str | Sequence[Any] | None, + option_string: Optional[str] = None, + ) -> None: + """Parse key=value pairs into a dictionary. + + Args: + parser: The argument parser + namespace: The namespace object to update + values: The values to parse (list of key=value strings) + option_string: The option string that triggered this action + """ + metadata: Dict[str, str] = {} + # Ensure values is a sequence + if values is None: + # No values provided, set empty dict + setattr(namespace, self.dest, metadata) + return + + if isinstance(values, str): + values = [values] + + for item in values: + if '=' not in item: + parser.error(f'Metadata must be in key=value format, got: {item}') + key, value = item.split('=', 1) + metadata[key] = value + setattr(namespace, self.dest, metadata) def parse_args(): @@ -60,10 +97,18 @@ def parse_args(): parser.add_argument( '--region', - help='AWS region to use (uses AWS_REGION environment variable if not provided, with final fallback to us-east-1)', + help='AWS region to sign (uses AWS_REGION environment variable if not provided, with final fallback to us-east-1)', default=None, ) + parser.add_argument( + '--metadata', + nargs='*', + action=KeyValueAction, + default=None, + help='Metadata to inject into MCP requests as key=value pairs (e.g., --metadata AWS_REGION=us-west-2 KEY=VALUE)', + ) + parser.add_argument( '--read-only', action='store_true', diff --git a/mcp_proxy_for_aws/server.py b/mcp_proxy_for_aws/server.py index 5eb51df..adbc0f0 100644 --- a/mcp_proxy_for_aws/server.py +++ b/mcp_proxy_for_aws/server.py @@ -54,11 +54,22 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None: region = determine_aws_region(args.endpoint, args.region) logger.debug('Using region: %s', region) + # Build metadata dictionary - start with AWS_REGION, then merge user metadata + metadata = {'AWS_REGION': region} + if args.metadata: + metadata.update(args.metadata) + # Get profile profile = args.profile # Log server configuration - logger.info('Using service: %s, region: %s, profile: %s', service, region, profile) + logger.info( + 'Using service: %s, region: %s, metadata: %s, profile: %s', + service, + region, + metadata, + profile, + ) logger.info('Running in MCP mode') timeout = httpx.Timeout( @@ -69,7 +80,9 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None: ) # Create transport with SigV4 authentication - transport = create_transport_with_sigv4(args.endpoint, service, region, timeout, profile) + transport = create_transport_with_sigv4( + args.endpoint, service, region, metadata, timeout, profile + ) # Create proxy with the transport proxy = FastMCP.as_proxy(transport) diff --git a/mcp_proxy_for_aws/sigv4_helper.py b/mcp_proxy_for_aws/sigv4_helper.py index 9313760..2f8b81f 100644 --- a/mcp_proxy_for_aws/sigv4_helper.py +++ b/mcp_proxy_for_aws/sigv4_helper.py @@ -16,10 +16,12 @@ import boto3 import httpx +import json import logging from botocore.auth import SigV4Auth from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials +from functools import partial from typing import Any, Dict, Generator, Optional @@ -71,56 +73,6 @@ def auth_flow(self, request: httpx.Request) -> Generator[httpx.Request, httpx.Re yield request -async def _handle_error_response(response: httpx.Response) -> None: - """Event hook to handle HTTP error responses and extract details. - - This function is called for every HTTP response to check for errors - and provide more detailed error information when requests fail. - - Args: - response: The HTTP response object - - Raises: - No raises. let the mcp http client handle the errors. - """ - if response.is_error: - # warning only because the SDK logs error - log_level = logging.WARNING - if ( - # The server MAY respond 405 to GET (SSE) and DELETE (session). - response.status_code == 405 and response.request.method in ('GET', 'DELETE') - ) or ( - # The server MAY terminate the session at any time, after which it MUST - # respond to requests containing that session ID with HTTP 404 Not Found. - response.status_code == 404 and response.request.method == 'POST' - ): - log_level = logging.DEBUG - - try: - # read the content and settle the response content. required to get body (.json(), .text) - await response.aread() - except Exception as e: - logger.debug('Failed to read response: %s', e) - # do nothing and let the client and SDK handle the error - return - - # Try to extract error details with fallbacks - try: - # Try to parse JSON error details - error_details = response.json() - logger.log(log_level, 'HTTP %d Error Details: %s', response.status_code, error_details) - except Exception: - # If JSON parsing fails, use response text or status code - try: - response_text = response.text - logger.log(log_level, 'HTTP %d Error: %s', response.status_code, response_text) - except Exception: - # Fallback to just status code and URL - logger.log( - log_level, 'HTTP %d Error for url %s', response.status_code, response.url - ) - - def create_aws_session(profile: Optional[str] = None) -> boto3.Session: """Create an AWS session with optional profile. @@ -150,42 +102,13 @@ def create_aws_session(profile: Optional[str] = None) -> boto3.Session: return session -def create_sigv4_auth(service: str, region: str, profile: Optional[str] = None) -> SigV4HTTPXAuth: - """Create SigV4 authentication for AWS requests. - - Args: - service: AWS service name for SigV4 signing - profile: AWS profile to use (optional) - region: AWS region (defaults to AWS_REGION env var or us-east-1) - - Returns: - SigV4HTTPXAuth instance - - Raises: - ValueError: If credentials cannot be obtained - """ - # Create session and get credentials - session = create_aws_session(profile) - credentials = session.get_credentials() - - # Create SigV4Auth with explicit credentials - sigv4_auth = SigV4HTTPXAuth( - credentials=credentials, - service=service, - region=region, - ) - - logger.info("Created SigV4 authentication for service '%s' in region '%s'", service, region) - return sigv4_auth - - def create_sigv4_client( service: str, region: str, timeout: Optional[httpx.Timeout] = None, profile: Optional[str] = None, headers: Optional[Dict[str, str]] = None, - auth: Optional[httpx.Auth] = None, + metadata: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> httpx.AsyncClient: """Create an httpx.AsyncClient with SigV4 authentication. @@ -196,7 +119,7 @@ def create_sigv4_client( region: AWS region (optional, defaults to AWS_REGION env var or us-east-1) timeout: Timeout configuration for the HTTP client headers: Headers to include in requests - auth: Auth parameter (ignored as we provide our own) + metadata: Metadata to inject into MCP _meta field **kwargs: Additional arguments to pass to httpx.AsyncClient Returns: @@ -220,14 +143,159 @@ def create_sigv4_client( 'Creating httpx.AsyncClient with custom headers: %s', client_kwargs.get('headers', {}) ) - # Create SigV4 auth - sigv4_auth = create_sigv4_auth(service, region, profile) - - # Create the client with SigV4 auth and error handling event hook - logger.info("Creating httpx.AsyncClient with SigV4 authentication for service '%s'", service) + logger.info("Creating httpx.AsyncClient with SigV4 request hooks for service '%s'", service) return httpx.AsyncClient( - auth=sigv4_auth, **client_kwargs, - event_hooks={'response': [_handle_error_response]}, + event_hooks={ + 'response': [_handle_error_response], + 'request': [ + partial(_inject_metadata_hook, metadata or {}), + partial(_sign_request_hook, region, service, profile), + ], + }, ) + + +async def _handle_error_response(response: httpx.Response) -> None: + """Event hook to handle HTTP error responses and extract details. + + This function is called for every HTTP response to check for errors + and provide more detailed error information when requests fail. + + Args: + response: The HTTP response object + + Raises: + No raises. let the mcp http client handle the errors. + """ + if response.is_error: + # warning only because the SDK logs error + log_level = logging.WARNING + if ( + # The server MAY respond 405 to GET (SSE) and DELETE (session). + response.status_code == 405 and response.request.method in ('GET', 'DELETE') + ) or ( + # The server MAY terminate the session at any time, after which it MUST + # respond to requests containing that session ID with HTTP 404 Not Found. + response.status_code == 404 and response.request.method == 'POST' + ): + log_level = logging.DEBUG + + try: + # read the content and settle the response content. required to get body (.json(), .text) + await response.aread() + except Exception as e: + logger.debug('Failed to read response: %s', e) + # do nothing and let the client and SDK handle the error + return + + # Try to extract error details with fallbacks + try: + # Try to parse JSON error details + error_details = response.json() + logger.log(log_level, 'HTTP %d Error Details: %s', response.status_code, error_details) + except Exception: + # If JSON parsing fails, use response text or status code + try: + response_text = response.text + logger.log(log_level, 'HTTP %d Error: %s', response.status_code, response_text) + except Exception: + # Fallback to just status code and URL + logger.log( + log_level, 'HTTP %d Error for url %s', response.status_code, response.url + ) + + +async def _sign_request_hook( + region: str, + service: str, + profile: Optional[str], + request: httpx.Request, +) -> None: + """Request hook to sign HTTP requests with AWS SigV4. + + This hook signs the request with AWS SigV4 credentials and adds signature headers. + + This should be the last hook called to ensure the signature includes any modifications. + + Args: + region: AWS region for SigV4 signing + service: AWS service name for SigV4 signing + profile: AWS profile to use (optional) + request: The HTTP request object to sign (modified in-place) + """ + # Set Content-Length for signing + request.headers['Content-Length'] = str(len(request.content)) + + # Get AWS credentials + session = create_aws_session(profile) + credentials = session.get_credentials() + logger.info('Signing request with credentials for access key: %s', credentials.access_key) + + # Create SigV4 auth and use its signing logic + auth = SigV4HTTPXAuth(credentials, service, region) + + # Call auth_flow to sign the request (it modifies request in-place) + auth_flow = auth.auth_flow(request) + next(auth_flow) # Execute the generator to perform signing + + logger.debug('Request headers after signing: %s', request.headers) + + +async def _inject_metadata_hook(metadata: Dict[str, Any], request: httpx.Request) -> None: + """Request hook to inject metadata into MCP calls. + + Args: + metadata: Dictionary of metadata to inject into _meta field + request: The HTTP request object + """ + logger.info('=== Outgoing Request ===') + logger.info('URL: %s', request.url) + logger.info('Method: %s', request.method) + + # Try to inject metadata if it's a JSON-RPC/MCP request + if request.content and metadata: + try: + # Parse the request body + body = json.loads(await request.aread()) + + # Check if it's a JSON-RPC request + if isinstance(body, dict) and 'jsonrpc' in body: + # Ensure _meta exists in params + if '_meta' not in body['params']: + body['params']['_meta'] = {} + + # Get existing metadata + existing_meta = body['params']['_meta'] + + # Merge metadata (existing takes precedence) + if isinstance(existing_meta, dict): + # Check for conflicting keys before merge + conflicting_keys = set(metadata.keys()) & set(existing_meta.keys()) + if conflicting_keys: + for key in conflicting_keys: + logger.warning( + 'Metadata key "%s" already exists in _meta. ' + 'Keeping existing value "%s", ignoring injected value "%s"', + key, + existing_meta[key], + metadata[key], + ) + body['params']['_meta'] = {**metadata, **existing_meta} + else: + logger.info('Replacing non-dict _meta value with injected metadata') + body['params']['_meta'] = metadata + + # Create new content with updated metadata + new_content = json.dumps(body).encode('utf-8') + + # Update the request with new content + request.stream = httpx.ByteStream(new_content) + request._content = new_content + + logger.info('Injected metadata into _meta: %s', body['params']['_meta']) + + except (json.JSONDecodeError, KeyError, TypeError) as e: + # Not a JSON request or invalid format, skip metadata injection + logger.error('Skipping metadata injection: %s', e) diff --git a/mcp_proxy_for_aws/utils.py b/mcp_proxy_for_aws/utils.py index d267bc1..7c588c0 100644 --- a/mcp_proxy_for_aws/utils.py +++ b/mcp_proxy_for_aws/utils.py @@ -21,7 +21,7 @@ import re from fastmcp.client.transports import StreamableHttpTransport from mcp_proxy_for_aws.sigv4_helper import create_sigv4_client -from typing import Dict, Optional +from typing import Any, Dict, Optional from urllib.parse import urlparse @@ -32,6 +32,7 @@ def create_transport_with_sigv4( url: str, service: str, region: str, + metadata: Dict[str, Any], custom_timeout: httpx.Timeout, profile: Optional[str] = None, ) -> StreamableHttpTransport: @@ -41,6 +42,7 @@ def create_transport_with_sigv4( url: The endpoint URL service: AWS service name for SigV4 signing region: AWS region to use + metadata: Metadata dictionary to inject into MCP requests custom_timeout: httpx.Timeout used to connect to the endpoint profile: AWS profile to use (optional) @@ -60,6 +62,7 @@ def client_factory( region=region, headers=headers, timeout=custom_timeout, + metadata=metadata, auth=auth, ) diff --git a/tests/integ/mcp/simple_mcp_client.py b/tests/integ/mcp/simple_mcp_client.py index 1f8b615..98b58e7 100644 --- a/tests/integ/mcp/simple_mcp_client.py +++ b/tests/integ/mcp/simple_mcp_client.py @@ -3,19 +3,28 @@ import logging from fastmcp.client import StdioTransport from fastmcp.client.elicitation import ElicitResult +from typing import Dict, Optional logger = logging.getLogger(__name__) -def build_mcp_client(endpoint: str, region_name: str) -> fastmcp.Client: - """Create a MCP Client using the mcp-proxy-for-aws against a remote MCP Server.""" +def build_mcp_client( + endpoint: str, region_name: str, metadata: Optional[Dict[str, str]] = None +) -> fastmcp.Client: + """Create a MCP Client with custom metadata. + + Args: + endpoint: The MCP server endpoint URL + region_name: AWS region name + metadata: Optional custom metadata to pass via --metadata flag + + Returns: + fastmcp.Client configured to use mcp-proxy-for-aws with custom metadata + """ return fastmcp.Client( StdioTransport( - **_build_mcp_config( - endpoint=endpoint, - region_name=region_name, - ) + **_build_mcp_config(endpoint=endpoint, region_name=region_name, metadata=metadata) ), elicitation_handler=_basic_elicitation_handler, timeout=30.0, # seconds @@ -39,7 +48,7 @@ async def _basic_elicitation_handler(message: str, response_type: type, params, raise RuntimeError(f'Unknown Response-type, rather failing - {response_type}') -def _build_mcp_config(endpoint: str, region_name: str): +def _build_mcp_config(endpoint: str, region_name: str, metadata: Optional[Dict[str, str]] = None): credentials = boto3.Session().get_credentials() environment_variables = { @@ -49,14 +58,29 @@ def _build_mcp_config(endpoint: str, region_name: str): 'AWS_SESSION_TOKEN': credentials.token, } + args = _build_args(endpoint, region_name, metadata) + return { 'command': 'mcp-proxy-for-aws', - 'args': [ - endpoint, - '--log-level', - 'DEBUG', - '--region', - region_name, - ], + 'args': args, 'env': environment_variables, } + + +def _build_args(endpoint: str, region_name: str, metadata: Optional[Dict[str, str]] = None): + """Build command line arguments for mcp-proxy-for-aws.""" + args = [ + endpoint, + '--log-level', + 'DEBUG', + '--region', + region_name, + ] + + # Add metadata arguments if provided + if metadata: + args.append('--metadata') + for key, value in metadata.items(): + args.append(f'{key}={value}') + + return args diff --git a/tests/integ/mcp/simple_mcp_server/mcp_server.py b/tests/integ/mcp/simple_mcp_server/mcp_server.py index 4cead55..57301d8 100644 --- a/tests/integ/mcp/simple_mcp_server/mcp_server.py +++ b/tests/integ/mcp/simple_mcp_server/mcp_server.py @@ -73,6 +73,16 @@ async def elicit_for_my_name(elicitation_expected: str, ctx: Context): return 'cancelled' +##### Metadata Testing + + +@mcp.tool +def echo_metadata(ctx: Context): + """MCP Tool that echoes back the _meta field from the request.""" + meta = ctx.request_context.meta + return {'received_meta': meta} + + #### Server Setup diff --git a/tests/integ/test_proxy_simple_mcp_server.py b/tests/integ/test_proxy_simple_mcp_server.py index 4e3314e..dd33b22 100644 --- a/tests/integ/test_proxy_simple_mcp_server.py +++ b/tests/integ/test_proxy_simple_mcp_server.py @@ -1,8 +1,10 @@ """Test the features about testing connecting to remote MCP Server runtime via the proxy.""" import fastmcp +import json import logging import pytest +from .mcp.simple_mcp_client import build_mcp_client from mcp.types import TextContent @@ -91,3 +93,139 @@ async def test_handle_elicitation_when_declining( async def test_handle_sampling(mcp_client: fastmcp.Client): """TODO.""" pass + + +@pytest.mark.asyncio(loop_scope='module') +async def test_metadata_injection_aws_region( + mcp_client: fastmcp.Client, remote_mcp_server_configuration +): + """Test that AWS_REGION is automatically injected and received by the server. + + This integration test verifies the full flow: + 1. Client makes a request through the proxy + 2. Proxy injects AWS_REGION into the _meta field + 3. Server receives the request with metadata + 4. Server echoes back the metadata it received + 5. We verify AWS_REGION was correctly transmitted + """ + # Call the echo_metadata tool which returns the _meta field it received + actual_response = await mcp_client.call_tool('echo_metadata', {}) + + # Extract the response content + actual_text = get_text_content(actual_response) + + # Parse the JSON response + response_data = json.loads(actual_text) + + # Verify that AWS_REGION was injected and received by the server + assert 'received_meta' in response_data, ( + f'Response should contain received_meta: {response_data}' + ) + assert response_data['received_meta'] is not None, 'Metadata should not be None' + assert 'AWS_REGION' in response_data['received_meta'], ( + f'Metadata should contain AWS_REGION: {response_data["received_meta"]}' + ) + assert ( + response_data['received_meta']['AWS_REGION'] + == remote_mcp_server_configuration['region_name'] + ), f'AWS_REGION should be {remote_mcp_server_configuration["region_name"]}' + + +@pytest.mark.asyncio(loop_scope='module') +async def test_metadata_injection_custom_fields(remote_mcp_server_configuration): + """Test that arbitrary metadata fields can be set via --metadata flag. + + This integration test verifies: + 1. Custom metadata fields are injected + 2. AWS_REGION is automatically added alongside custom fields + 3. Server receives all metadata fields + """ + # Build client with custom metadata + custom_metadata = { + 'TRACKING_ID': 'test-tracking-123', + 'ENVIRONMENT': 'integration-test', + 'CUSTOM_FIELD': 'custom-value-456', + } + + client = build_mcp_client( + endpoint=remote_mcp_server_configuration['endpoint'], + region_name=remote_mcp_server_configuration['region_name'], + metadata=custom_metadata, + ) + + async with client: + # Call the echo_metadata tool + actual_response = await client.call_tool('echo_metadata', {}) + + # Extract and parse response + actual_text = get_text_content(actual_response) + response_data = json.loads(actual_text) + + # Verify custom metadata was received + assert 'received_meta' in response_data, ( + f'Response should contain received_meta: {response_data}' + ) + received_meta = response_data['received_meta'] + + # Verify all custom fields are present + assert received_meta['TRACKING_ID'] == 'test-tracking-123', ( + 'TRACKING_ID should be test-tracking-123' + ) + assert received_meta['ENVIRONMENT'] == 'integration-test', ( + 'ENVIRONMENT should be integration-test' + ) + assert received_meta['CUSTOM_FIELD'] == 'custom-value-456', ( + 'CUSTOM_FIELD should be custom-value-456' + ) + + # Verify AWS_REGION is still auto-injected + assert 'AWS_REGION' in received_meta, 'AWS_REGION should be auto-injected' + assert received_meta['AWS_REGION'] == remote_mcp_server_configuration['region_name'], ( + f'AWS_REGION should be {remote_mcp_server_configuration["region_name"]}' + ) + + +@pytest.mark.asyncio(loop_scope='module') +async def test_metadata_injection_override_aws_region(remote_mcp_server_configuration): + """Test that AWS_REGION can be overridden via --metadata flag. + + This integration test verifies: + 1. User can override AWS_REGION using --metadata + 2. Override takes precedence over --region parameter + 3. Server receives the overridden value + """ + # Build client with AWS_REGION override + overridden_region = 'eu-central-1' + custom_metadata = { + 'AWS_REGION': overridden_region, + 'TEST_FIELD': 'test-value', + } + + client = build_mcp_client( + endpoint=remote_mcp_server_configuration['endpoint'], + region_name=remote_mcp_server_configuration['region_name'], # Original region + metadata=custom_metadata, + ) + + async with client: + # Call the echo_metadata tool + actual_response = await client.call_tool('echo_metadata', {}) + + # Extract and parse response + actual_text = get_text_content(actual_response) + response_data = json.loads(actual_text) + + # Verify metadata was received + assert 'received_meta' in response_data, ( + f'Response should contain received_meta: {response_data}' + ) + received_meta = response_data['received_meta'] + + # Verify AWS_REGION was overridden + assert received_meta['AWS_REGION'] == overridden_region, ( + f'AWS_REGION should be overridden to {overridden_region}, ' + f'not {remote_mcp_server_configuration["region_name"]}' + ) + + # Verify other custom fields are present + assert received_meta['TEST_FIELD'] == 'test-value', 'TEST_FIELD should be test-value' diff --git a/tests/unit/test_cli.py b/tests/unit/test_cli.py index 287a2cf..890fd05 100644 --- a/tests/unit/test_cli.py +++ b/tests/unit/test_cli.py @@ -140,3 +140,49 @@ def test_parse_args_negative_timeout(self): """Test parsing fails with negative timeout (within_range validation).""" with pytest.raises(SystemExit): parse_args() + + @patch( + 'sys.argv', + ['mcp-proxy-for-aws', 'https://example.com', '--metadata', 'KEY1=value1', 'KEY2=value2'], + ) + def test_parse_metadata_argument(self): + """Test parsing metadata key=value pairs.""" + args = parse_args() + assert args.metadata == {'KEY1': 'value1', 'KEY2': 'value2'} + + @patch( + 'sys.argv', + ['mcp-proxy-for-aws', 'https://example.com', '--metadata', 'AWS_REGION=us-west-2'], + ) + def test_parse_metadata_single_pair(self): + """Test parsing single metadata key=value pair.""" + args = parse_args() + assert args.metadata == {'AWS_REGION': 'us-west-2'} + + @patch( + 'sys.argv', + ['mcp-proxy-for-aws', 'https://example.com', '--metadata', 'KEY=value with spaces'], + ) + def test_parse_metadata_with_spaces_in_value(self): + """Test parsing metadata with spaces in value.""" + args = parse_args() + assert args.metadata == {'KEY': 'value with spaces'} + + @patch('sys.argv', ['mcp-proxy-for-aws', 'https://example.com', '--metadata']) + def test_parse_metadata_no_values(self): + """Test parsing --metadata flag with no values results in empty dict.""" + args = parse_args() + # When --metadata is provided with no values (nargs='*'), it should be empty dict + # This is handled by KeyValueAction which sets an empty dict when values is None + assert args.metadata == {} or args.metadata is None, ( + f'Expected empty dict or None, got {args.metadata}' + ) + + @patch('sys.argv', ['mcp-proxy-for-aws', 'https://example.com', '--metadata', 'INVALID']) + def test_parse_metadata_invalid_format(self): + """Test that invalid metadata format raises an error.""" + import argparse + + with pytest.raises((SystemExit, argparse.ArgumentTypeError)): + # argparse may call sys.exit or raise ArgumentTypeError + parse_args() diff --git a/tests/unit/test_hooks.py b/tests/unit/test_hooks.py new file mode 100644 index 0000000..51b039a --- /dev/null +++ b/tests/unit/test_hooks.py @@ -0,0 +1,437 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for hooks module.""" + +import httpx +import json +import pytest +from functools import partial +from mcp_proxy_for_aws.sigv4_helper import ( + _handle_error_response, + _inject_metadata_hook, + _sign_request_hook, +) +from unittest.mock import MagicMock, Mock, patch + + +def create_request_with_sigv4_headers( + url: str, body: bytes, method: str = 'POST' +) -> httpx.Request: + """Helper to create a request with required SigV4 headers for testing.""" + request = httpx.Request(method, url, content=body) + # Add minimal SigV4 headers that the hook will try to delete and re-add + request.headers['Content-Length'] = str(len(body)) + request.headers['x-amz-date'] = '20240101T000000Z' + request.headers['x-amz-security-token'] = 'test-token' + request.headers['Authorization'] = ( + 'AWS4-HMAC-SHA256 Credential=test/20240101/us-west-2/execute-api/aws4_request' + ) + return request + + +def create_mock_session(): + """Helper to create a mocked AWS session with credentials.""" + mock_session = MagicMock() + mock_credentials = MagicMock() + mock_credentials.access_key = 'test-access-key' + mock_credentials.secret_key = 'test-secret-key' + mock_credentials.token = 'test-token' + mock_session.get_credentials.return_value = mock_credentials + return mock_session + + +class TestHandleErrorResponse: + """Test cases for the _handle_error_response function.""" + + @pytest.mark.asyncio + async def test_handle_error_response_with_json_error(self): + """Test error handling with JSON error response.""" + # Create a mock error response with JSON content + request = httpx.Request('GET', 'https://example.com/test') + error_data = {'error': 'Not Found', 'message': 'The requested resource was not found'} + response = httpx.Response( + status_code=404, + headers={'content-type': 'application/json'}, + content=json.dumps(error_data).encode(), + request=request, + ) + + await _handle_error_response(response) + + # Verify response was read (content should be settled) + assert response.is_stream_consumed + + @pytest.mark.asyncio + async def test_handle_error_response_with_non_json_error(self): + """Test error handling with non-JSON error response.""" + # Create a mock error response with plain text content + request = httpx.Request('GET', 'https://example.com/test') + response = httpx.Response( + status_code=500, + headers={'content-type': 'text/plain'}, + content=b'Internal Server Error', + request=request, + ) + + await _handle_error_response(response) + + # Verify response was read + assert response.is_stream_consumed + + @pytest.mark.asyncio + async def test_handle_error_response_with_success_response(self): + """Test that successful responses don't raise errors.""" + # Create a mock success response + request = httpx.Request('GET', 'https://example.com/test') + response = httpx.Response( + status_code=200, + headers={'content-type': 'application/json'}, + content=b'{"success": true}', + request=request, + ) + + await _handle_error_response(response) + + # Verify function completes without error for success responses + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_handle_error_response_with_read_failure(self): + """Test error handling when response reading fails.""" + # Create a mock response that fails to read + request = httpx.Request('GET', 'https://example.com/test') + response = Mock(spec=httpx.Response) + response.is_error = True + response.aread = Mock(side_effect=Exception('Read failed')) + response.json = Mock(side_effect=Exception('JSON parsing failed')) + response.text = 'Mock error text' + response.status_code = 500 + response.url = 'https://example.com/test' + response.raise_for_status = Mock( + side_effect=httpx.HTTPStatusError( + message='HTTP Error', request=request, response=response + ) + ) + + await _handle_error_response(response) + + # Verify it handled the read failure gracefully (no exception raised) + # The aread() was attempted (would have been called) + response.aread.assert_called_once() + + @pytest.mark.asyncio + async def test_handle_error_response_with_invalid_json(self): + """Test error handling with invalid JSON response.""" + # Create a mock error response with invalid JSON + request = httpx.Request('GET', 'https://example.com/test') + response = httpx.Response( + status_code=400, + headers={'content-type': 'application/json'}, + content=b'Invalid JSON content {', + request=request, + ) + + await _handle_error_response(response) + + # Verify response was read despite invalid JSON + assert response.is_stream_consumed + + +class TestMetadataInjectionHook: + """Test cases for _inject_metadata_hook function.""" + + @pytest.mark.asyncio + async def test_hook_injects_metadata_into_jsonrpc_request(self): + """Test that hook injects metadata into JSON-RPC request body.""" + metadata = {'AWS_REGION': 'us-west-2', 'tracking_id': 'test-123'} + + # Create request with JSON-RPC body + request_body = json.dumps( + {'jsonrpc': '2.0', 'id': 1, 'method': 'tools/call', 'params': {'name': 'myTool'}} + ).encode('utf-8') + + request = create_request_with_sigv4_headers('https://example.com/mcp', request_body) + + # Call the hook + await _inject_metadata_hook(metadata, request) + + stream_content = await request.aread() + + # Verify metadata was injected + modified_body = json.loads(stream_content.decode('utf-8')) + assert '_meta' in modified_body['params'] + assert modified_body['params']['_meta']['AWS_REGION'] == 'us-west-2' + assert modified_body['params']['_meta']['tracking_id'] == 'test-123' + + @pytest.mark.asyncio + async def test_hook_merges_with_existing_metadata(self): + """Test that hook merges with existing _meta, existing takes precedence.""" + metadata = {'AWS_REGION': 'us-west-2', 'field1': 'injected'} + + request_body = json.dumps( + { + 'jsonrpc': '2.0', + 'id': 1, + 'method': 'tools/call', + 'params': { + 'name': 'myTool', + '_meta': {'field1': 'existing', 'field2': 'original'}, + }, + } + ).encode('utf-8') + + request = create_request_with_sigv4_headers('https://example.com/mcp', request_body) + + await _inject_metadata_hook(metadata, request) + + stream_content = await request.aread() + + modified_body = json.loads(stream_content.decode('utf-8')) + + # Existing metadata takes precedence + assert modified_body['params']['_meta']['field1'] == 'existing' + assert modified_body['params']['_meta']['field2'] == 'original' + assert modified_body['params']['_meta']['AWS_REGION'] == 'us-west-2' + + @pytest.mark.asyncio + async def test_hook_skips_non_jsonrpc_requests(self): + """Test that hook doesn't modify non-JSON-RPC requests.""" + metadata = {'AWS_REGION': 'us-west-2'} + + request_body = json.dumps({'regular': 'request'}).encode('utf-8') + original_body = request_body + + request = httpx.Request('POST', 'https://example.com/api', content=request_body) + + await _inject_metadata_hook(metadata, request) + + # Body should be unchanged + assert request._content == original_body + + @pytest.mark.asyncio + async def test_hook_handles_invalid_json_gracefully(self): + """Test that hook handles invalid JSON without crashing.""" + metadata = {'AWS_REGION': 'us-west-2'} + + request_body = b'not valid json' + request = httpx.Request('POST', 'https://example.com/mcp', content=request_body) + + # Should not raise exception + await _inject_metadata_hook(metadata, request) + + # Body should be unchanged + assert request._content == request_body + + @pytest.mark.asyncio + async def test_hook_handles_empty_body(self): + """Test that hook handles requests with no body.""" + metadata = {'AWS_REGION': 'us-west-2'} + + request = httpx.Request('GET', 'https://example.com/api') + + # Should not raise exception + await _inject_metadata_hook(metadata, request) + + @pytest.mark.asyncio + async def test_hook_handles_empty_metadata(self): + """Test that hook works with empty metadata dict.""" + metadata = {} + + request_body = json.dumps( + {'jsonrpc': '2.0', 'id': 1, 'method': 'tools/call', 'params': {'name': 'myTool'}} + ).encode('utf-8') + + request = httpx.Request('POST', 'https://example.com/mcp', content=request_body) + + # Should not inject anything but shouldn't crash + await _inject_metadata_hook(metadata, request) + + @pytest.mark.asyncio + async def test_hook_with_partial_application(self): + """Test that hook works correctly with functools.partial.""" + metadata = {'AWS_REGION': 'us-west-2', 'custom': 'value'} + + # Create curried function using partial + curried_hook = partial(_inject_metadata_hook, metadata) + + request_body = json.dumps( + {'jsonrpc': '2.0', 'id': 1, 'method': 'tools/call', 'params': {'name': 'myTool'}} + ).encode('utf-8') + + request = create_request_with_sigv4_headers('https://example.com/mcp', request_body) + + # Call the curried function (only needs request parameter) + await curried_hook(request) + + stream_content = await request.aread() + + modified_body = json.loads(stream_content.decode('utf-8')) + assert modified_body['params']['_meta']['AWS_REGION'] == 'us-west-2' + assert modified_body['params']['_meta']['custom'] == 'value' + + @pytest.mark.asyncio + async def test_hook_handles_non_dict_meta(self): + """Test that hook replaces non-dict _meta with dict.""" + metadata = {'AWS_REGION': 'us-west-2'} + + request_body = json.dumps( + { + 'jsonrpc': '2.0', + 'id': 1, + 'method': 'tools/call', + 'params': {'name': 'myTool', '_meta': 'not a dict'}, + } + ).encode('utf-8') + + request = create_request_with_sigv4_headers('https://example.com/mcp', request_body) + + await _inject_metadata_hook(metadata, request) + + stream_content = await request.aread() + + modified_body = json.loads(stream_content.decode('utf-8')) + + # _meta should be replaced with dict + assert isinstance(modified_body['params']['_meta'], dict) + assert modified_body['params']['_meta'] == metadata + + @pytest.mark.asyncio + async def test_hook_preserves_other_params(self): + """Test that hook doesn't modify other params fields.""" + metadata = {'AWS_REGION': 'us-west-2'} + + request_body = json.dumps( + { + 'jsonrpc': '2.0', + 'id': 1, + 'method': 'tools/call', + 'params': { + 'name': 'myTool', + 'arguments': {'arg1': 'value1'}, + 'other_field': 'preserved', + }, + } + ).encode('utf-8') + + request = create_request_with_sigv4_headers('https://example.com/mcp', request_body) + + await _inject_metadata_hook(metadata, request) + + stream_content = await request.aread() + + modified_body = json.loads(stream_content.decode('utf-8')) + + # Other params should be preserved + assert modified_body['params']['name'] == 'myTool' + assert modified_body['params']['arguments'] == {'arg1': 'value1'} + assert modified_body['params']['other_field'] == 'preserved' + assert modified_body['params']['_meta']['AWS_REGION'] == 'us-west-2' + + +class TestSignRequestHook: + """Test cases for sign_request_hook function.""" + + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') + @pytest.mark.asyncio + async def test_sign_request_hook_signs_request(self, mock_create_session): + """Test that sign_request_hook properly signs requests.""" + # Setup mocks + mock_create_session.return_value = create_mock_session() + + region = 'us-east-1' + service = 'bedrock-agentcore' + profile = None + + # Create request without signature headers + request_body = json.dumps({'test': 'data'}).encode('utf-8') + request = httpx.Request('POST', 'https://example.com/mcp', content=request_body) + + # Call the hook + await _sign_request_hook(region, service, profile, request) + + # Verify signature headers were added + assert 'authorization' in request.headers + assert 'x-amz-date' in request.headers + assert 'x-amz-security-token' in request.headers + assert request.headers['content-length'] == str(len(request_body)) + + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') + @pytest.mark.asyncio + async def test_sign_request_hook_with_profile(self, mock_create_session): + """Test that sign_request_hook uses profile when provided.""" + # Setup mocks + mock_create_session.return_value = create_mock_session() + + region = 'us-west-2' + service = 'execute-api' + profile = 'test-profile' + + request_body = b'test content' + request = httpx.Request('POST', 'https://example.com/api', content=request_body) + + # Call the hook + await _sign_request_hook(region, service, profile, request) + + # Verify session was created with profile + mock_create_session.assert_called_once_with(profile) + + # Verify request was signed + assert 'authorization' in request.headers + assert 'x-amz-date' in request.headers + + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') + @pytest.mark.asyncio + async def test_sign_request_hook_sets_content_length(self, mock_create_session): + """Test that sign_request_hook sets Content-Length header.""" + # Setup mocks + mock_create_session.return_value = create_mock_session() + + region = 'eu-west-1' + service = 'lambda' + profile = None + + # Create request + request_body = b'test content with specific length' + request = httpx.Request('POST', 'https://example.com/api', content=request_body) + + await _sign_request_hook(region, service, profile, request) + + # Verify Content-Length was set correctly + assert request.headers['content-length'] == str(len(request_body)) + + @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') + @pytest.mark.asyncio + async def test_sign_request_hook_with_partial_application(self, mock_create_session): + """Test that sign_request_hook works with functools.partial.""" + # Setup mocks + mock_create_session.return_value = create_mock_session() + + region = 'ap-southeast-1' + service = 'execute-api' + profile = 'prod-profile' + + # Create curried function using partial + curried_hook = partial(_sign_request_hook, region, service, profile) + + request_body = b'request data' + request = httpx.Request('POST', 'https://example.com/mcp', content=request_body) + + # Call the curried function (only needs request parameter) + await curried_hook(request) + + # Verify request was signed + assert 'authorization' in request.headers + assert 'x-amz-date' in request.headers + mock_create_session.assert_called_once_with(profile) diff --git a/tests/unit/test_server.py b/tests/unit/test_server.py index 8cac234..63350a9 100644 --- a/tests/unit/test_server.py +++ b/tests/unit/test_server.py @@ -56,6 +56,7 @@ async def test_setup_mcp_mode( mock_args.profile = None mock_args.read_only = True mock_args.retries = 1 + mock_args.metadata = None # Add timeout parameters mock_args.timeout = 180.0 mock_args.connect_timeout = 60.0 @@ -86,8 +87,9 @@ async def test_setup_mcp_mode( assert call_args[0][0] == 'https://test.example.com' assert call_args[0][1] == 'test-service' assert call_args[0][2] == 'us-east-1' - # call_args[0][3] is the Timeout object - assert call_args[0][4] is None # profile + assert call_args[0][3] == {'AWS_REGION': 'us-east-1'} # metadata + # call_args[0][4] is the Timeout object + assert call_args[0][5] is None # profile mock_as_proxy.assert_called_once_with(mock_transport) mock_add_filtering.assert_called_once_with(mock_proxy, True) mock_add_retry.assert_called_once_with(mock_proxy, 1) @@ -116,6 +118,7 @@ async def test_setup_mcp_mode_no_retries( mock_args.profile = 'test-profile' mock_args.read_only = False mock_args.retries = 0 # No retries + mock_args.metadata = {'AWS_REGION': 'eu-west-1', 'CUSTOM_KEY': 'custom_value'} # Add timeout parameters mock_args.timeout = 180.0 mock_args.connect_timeout = 60.0 @@ -146,12 +149,116 @@ async def test_setup_mcp_mode_no_retries( assert call_args[0][0] == 'https://test.example.com' assert call_args[0][1] == 'test-service' assert call_args[0][2] == 'us-east-1' - # call_args[0][3] is the Timeout object - assert call_args[0][4] == 'test-profile' # profile + assert call_args[0][3] == { + 'AWS_REGION': 'eu-west-1', + 'CUSTOM_KEY': 'custom_value', + } # metadata + # call_args[0][4] is the Timeout object + assert call_args[0][5] == 'test-profile' # profile mock_as_proxy.assert_called_once_with(mock_transport) mock_add_filtering.assert_called_once_with(mock_proxy, False) mock_proxy.run_async.assert_called_once() + @patch('mcp_proxy_for_aws.server.create_transport_with_sigv4') + @patch('mcp_proxy_for_aws.server.FastMCP.as_proxy') + @patch('mcp_proxy_for_aws.server.determine_aws_region') + @patch('mcp_proxy_for_aws.server.determine_service_name') + @patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware') + async def test_setup_mcp_mode_no_metadata_injects_aws_region( + self, + mock_add_filtering, + mock_determine_service, + mock_determine_region, + mock_as_proxy, + mock_create_transport, + ): + """Test that AWS_REGION is automatically injected when no metadata is provided.""" + # Arrange + local_mcp = Mock(spec=FastMCP) + mock_args = Mock() + mock_args.endpoint = 'https://test.example.com' + mock_args.service = 'test-service' + mock_args.region = 'ap-southeast-1' + mock_args.profile = None + mock_args.read_only = False + mock_args.retries = 0 + mock_args.metadata = None # No metadata provided + mock_args.timeout = 180.0 + mock_args.connect_timeout = 60.0 + mock_args.read_timeout = 120.0 + mock_args.write_timeout = 180.0 + mock_args.log_level = 'INFO' + + mock_determine_service.return_value = 'test-service' + mock_determine_region.return_value = 'ap-southeast-1' + + mock_transport = Mock() + mock_create_transport.return_value = mock_transport + mock_proxy = Mock() + mock_proxy.run_async = AsyncMock() + mock_as_proxy.return_value = mock_proxy + + # Act + await setup_mcp_mode(local_mcp, mock_args) + + # Assert - verify AWS_REGION was automatically injected + assert mock_create_transport.call_count == 1 + call_args = mock_create_transport.call_args + metadata = call_args[0][3] + assert metadata == {'AWS_REGION': 'ap-southeast-1'} + + @patch('mcp_proxy_for_aws.server.create_transport_with_sigv4') + @patch('mcp_proxy_for_aws.server.FastMCP.as_proxy') + @patch('mcp_proxy_for_aws.server.determine_aws_region') + @patch('mcp_proxy_for_aws.server.determine_service_name') + @patch('mcp_proxy_for_aws.server.add_tool_filtering_middleware') + async def test_setup_mcp_mode_metadata_without_aws_region_injects_it( + self, + mock_add_filtering, + mock_determine_service, + mock_determine_region, + mock_as_proxy, + mock_create_transport, + ): + """Test that AWS_REGION is injected even when other metadata is provided.""" + # Arrange + local_mcp = Mock(spec=FastMCP) + mock_args = Mock() + mock_args.endpoint = 'https://test.example.com' + mock_args.service = 'test-service' + mock_args.region = 'us-west-1' + mock_args.profile = None + mock_args.read_only = False + mock_args.retries = 0 + mock_args.metadata = {'CUSTOM_KEY': 'custom_value', 'ANOTHER_KEY': 'another_value'} + mock_args.timeout = 180.0 + mock_args.connect_timeout = 60.0 + mock_args.read_timeout = 120.0 + mock_args.write_timeout = 180.0 + mock_args.log_level = 'INFO' + + mock_determine_service.return_value = 'test-service' + mock_determine_region.return_value = 'us-west-1' + + mock_transport = Mock() + mock_create_transport.return_value = mock_transport + mock_proxy = Mock() + mock_proxy.run_async = AsyncMock() + mock_as_proxy.return_value = mock_proxy + + # Act + await setup_mcp_mode(local_mcp, mock_args) + + # Assert - verify AWS_REGION was injected along with custom metadata + assert mock_create_transport.call_count == 1 + call_args = mock_create_transport.call_args + metadata = call_args[0][3] + assert metadata == { + 'AWS_REGION': 'us-west-1', + 'CUSTOM_KEY': 'custom_value', + 'ANOTHER_KEY': 'another_value', + } + def test_add_tool_filtering_middleware(self): """Test that tool filtering middleware is added correctly.""" # Arrange @@ -262,41 +369,37 @@ def test_validate_service_name_service_parsing(self): result = determine_service_name(endpoint) assert result == expected_service - @patch('mcp_proxy_for_aws.sigv4_helper.boto3.Session') @patch('mcp_proxy_for_aws.sigv4_helper.httpx.AsyncClient') - @patch('mcp_proxy_for_aws.sigv4_helper.SigV4Auth') - def test_create_sigv4_client(self, mock_sigv4_auth, mock_async_client, mock_session): - """Test creating SigV4 authenticated client with HTTPX auth.""" - # Arrange - mock_credentials = Mock() - mock_credentials.access_key = 'test_access_key' - mock_credentials.secret_key = 'test_secret_key' - mock_credentials.token = 'test_token' - - mock_session_instance = Mock() - mock_session_instance.get_credentials.return_value = mock_credentials - mock_session.return_value = mock_session_instance + def test_create_sigv4_client(self, mock_async_client): + """Test creating SigV4 authenticated client with request hooks. + Note: Session creation and signing now happens in sign_request_hook, + not during client creation. + """ # Act create_sigv4_client(service='test-service', region='us-west-2', profile='test-profile') # Assert - mock_session.assert_called_once_with(profile_name='test-profile') - mock_sigv4_auth.assert_called_once_with(mock_credentials, 'test-service', 'us-west-2') - mock_async_client.assert_called_once() - - @patch('mcp_proxy_for_aws.sigv4_helper.boto3.Session') - def test_create_sigv4_client_no_credentials(self, mock_session): - """Test creating SigV4 client with no credentials.""" - # Arrange - mock_session_instance = Mock() - mock_session_instance.get_credentials.return_value = None - mock_session.return_value = mock_session_instance - - # Act & Assert - with pytest.raises(ValueError) as exc_info: - create_sigv4_client(service='test-service', region='test-region') - assert 'No AWS credentials found' in str(exc_info.value) + # Verify AsyncClient was called (signing happens via hooks) + assert mock_async_client.call_count == 1 + call_args = mock_async_client.call_args + # Verify hooks are registered + assert 'event_hooks' in call_args[1] + assert 'request' in call_args[1]['event_hooks'] + assert 'response' in call_args[1]['event_hooks'] + # Should have metadata injection + sign hooks + assert len(call_args[1]['event_hooks']['request']) == 2 + + def test_create_sigv4_client_no_credentials(self): + """Test that credential check happens in sign_request_hook, not during client creation. + + Note: With the refactoring, client creation no longer validates credentials. + Credential validation now happens in sign_request_hook when the request is signed. + """ + # Client creation should succeed even without credentials + # (credentials are checked when signing happens) + client = create_sigv4_client(service='test-service', region='test-region') + assert client is not None def test_main_module_execution(self): """Test that main is called when module is executed directly.""" diff --git a/tests/unit/test_sigv4_helper.py b/tests/unit/test_sigv4_helper.py index 804e88f..7c6cd20 100644 --- a/tests/unit/test_sigv4_helper.py +++ b/tests/unit/test_sigv4_helper.py @@ -15,13 +15,10 @@ """Unit tests for sigv4_helper module.""" import httpx -import json import pytest from mcp_proxy_for_aws.sigv4_helper import ( SigV4HTTPXAuth, - _handle_error_response, create_aws_session, - create_sigv4_auth, create_sigv4_client, ) from unittest.mock import Mock, patch @@ -54,87 +51,6 @@ async def test_auth_flow_signs_request(self): assert 'X-Amz-Date' in signed_request.headers -class TestHandleErrorResponse: - """Test cases for the _handle_error_response function.""" - - @pytest.mark.asyncio - async def test_handle_error_response_with_json_error(self): - """Test error handling with JSON error response.""" - # Create a mock error response with JSON content - request = httpx.Request('GET', 'https://example.com/test') - error_data = {'error': 'Not Found', 'message': 'The requested resource was not found'} - response = httpx.Response( - status_code=404, - headers={'content-type': 'application/json'}, - content=json.dumps(error_data).encode(), - request=request, - ) - - await _handle_error_response(response) - - @pytest.mark.asyncio - async def test_handle_error_response_with_non_json_error(self): - """Test error handling with non-JSON error response.""" - # Create a mock error response with plain text content - request = httpx.Request('GET', 'https://example.com/test') - response = httpx.Response( - status_code=500, - headers={'content-type': 'text/plain'}, - content=b'Internal Server Error', - request=request, - ) - - await _handle_error_response(response) - - @pytest.mark.asyncio - async def test_handle_error_response_with_success_response(self): - """Test that successful responses don't raise errors.""" - # Create a mock success response - request = httpx.Request('GET', 'https://example.com/test') - response = httpx.Response( - status_code=200, - headers={'content-type': 'application/json'}, - content=b'{"success": true}', - request=request, - ) - - await _handle_error_response(response) - - @pytest.mark.asyncio - async def test_handle_error_response_with_read_failure(self): - """Test error handling when response reading fails.""" - # Create a mock response that fails to read - request = httpx.Request('GET', 'https://example.com/test') - response = Mock(spec=httpx.Response) - response.is_error = True - response.aread = Mock(side_effect=Exception('Read failed')) - response.json = Mock(side_effect=Exception('JSON parsing failed')) - response.text = 'Mock error text' - response.status_code = 500 - response.url = 'https://example.com/test' - response.raise_for_status = Mock( - side_effect=httpx.HTTPStatusError( - message='HTTP Error', request=request, response=response - ) - ) - - await _handle_error_response(response) - - @pytest.mark.asyncio - async def test_handle_error_response_with_invalid_json(self): - """Test error handling with invalid JSON response.""" - # Create a mock error response with invalid JSON - request = httpx.Request('GET', 'https://example.com/test') - response = httpx.Response( - status_code=400, - headers={'content-type': 'application/json'}, - content=b'Invalid JSON content {', - request=request, - ) - - await _handle_error_response(response) - - class TestCreateAwsSession: """Test cases for the create_aws_session function.""" @@ -200,87 +116,32 @@ def test_create_aws_session_creation_failure(self, mock_session_class): assert 'invalid-profile' in str(exc_info.value) -class TestCreateSigv4Auth: - """Test cases for the create_sigv4_auth function.""" - - @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') - def test_create_sigv4_auth_default(self, mock_create_session): - """Test creating SigV4 auth with default parameters.""" - # Mock session and credentials - mock_session = Mock() - mock_credentials = Mock() - mock_credentials.access_key = 'test_access_key' - mock_credentials.secret_key = 'test_secret_key' - mock_credentials.token = 'test_token' - mock_session.get_credentials.return_value = mock_credentials - mock_create_session.return_value = mock_session - - # Test auth creation - result = create_sigv4_auth('test-service', 'test-region') - - # Verify auth was created correctly - assert isinstance(result, SigV4HTTPXAuth) - assert result.service == 'test-service' - assert result.region == 'test-region' # default region - assert result.credentials == mock_credentials - - @patch('mcp_proxy_for_aws.sigv4_helper.create_aws_session') - def test_create_sigv4_auth_with_explicit_region(self, mock_create_session): - """Test creating SigV4 auth with explicit region parameter.""" - # Mock session and credentials - mock_session = Mock() - mock_credentials = Mock() - mock_credentials.access_key = 'test_access_key' - mock_credentials.secret_key = 'test_secret_key' - mock_credentials.token = 'test_token' - mock_session.get_credentials.return_value = mock_credentials - mock_create_session.return_value = mock_session - - # Test auth creation with explicit region - result = create_sigv4_auth('test-service', region='ap-southeast-1') - - # Verify auth was created with explicit region - assert isinstance(result, SigV4HTTPXAuth) - assert result.service == 'test-service' - assert result.region == 'ap-southeast-1' - assert result.credentials == mock_credentials - - class TestCreateSigv4Client: """Test cases for the create_sigv4_client function.""" - @patch('mcp_proxy_for_aws.sigv4_helper.create_sigv4_auth') @patch('httpx.AsyncClient') - def test_create_sigv4_client_default(self, mock_client_class, mock_create_auth): + def test_create_sigv4_client_default(self, mock_client_class): """Test creating SigV4 client with default parameters.""" - # Mock auth and client - mock_auth = Mock() - mock_create_auth.return_value = mock_auth mock_client = Mock() mock_client_class.return_value = mock_client # Test client creation result = create_sigv4_client(service='test-service', region='test-region') - # Verify client was created correctly - mock_create_auth.assert_called_once_with('test-service', 'test-region', None) - # Check that AsyncClient was called with correct parameters call_args = mock_client_class.call_args - assert call_args[1]['auth'] == mock_auth + assert 'auth' not in call_args[1], 'Auth should not be used, signing via hooks' assert 'event_hooks' in call_args[1] assert 'response' in call_args[1]['event_hooks'] + assert 'request' in call_args[1]['event_hooks'] assert len(call_args[1]['event_hooks']['response']) == 1 + assert len(call_args[1]['event_hooks']['request']) == 2 # metadata + sign hooks assert call_args[1]['headers']['Accept'] == 'application/json, text/event-stream' assert result == mock_client - @patch('mcp_proxy_for_aws.sigv4_helper.create_sigv4_auth') @patch('httpx.AsyncClient') - def test_create_sigv4_client_with_custom_headers(self, mock_client_class, mock_create_auth): + def test_create_sigv4_client_with_custom_headers(self, mock_client_class): """Test creating SigV4 client with custom headers.""" - # Mock auth and client - mock_auth = Mock() - mock_create_auth.return_value = mock_auth mock_client = Mock() mock_client_class.return_value = mock_client @@ -299,15 +160,9 @@ def test_create_sigv4_client_with_custom_headers(self, mock_client_class, mock_c assert call_args[1]['headers'] == expected_headers assert result == mock_client - @patch('mcp_proxy_for_aws.sigv4_helper.create_sigv4_auth') @patch('httpx.AsyncClient') - def test_create_sigv4_client_with_custom_service_and_region( - self, mock_client_class, mock_create_auth - ): + def test_create_sigv4_client_with_custom_service_and_region(self, mock_client_class): """Test creating SigV4 client with custom service and region.""" - # Mock auth and client - mock_auth = Mock() - mock_create_auth.return_value = mock_auth mock_client = Mock() mock_client_class.return_value = mock_client @@ -316,17 +171,12 @@ def test_create_sigv4_client_with_custom_service_and_region( service='custom-service', profile='test-profile', region='us-east-1' ) - # Verify auth was created with custom parameters - mock_create_auth.assert_called_once_with('custom-service', 'us-east-1', 'test-profile') + # Verify client was created assert result == mock_client - @patch('mcp_proxy_for_aws.sigv4_helper.create_sigv4_auth') @patch('httpx.AsyncClient') - def test_create_sigv4_client_with_kwargs(self, mock_client_class, mock_create_auth): + def test_create_sigv4_client_with_kwargs(self, mock_client_class): """Test creating SigV4 client with additional kwargs.""" - # Mock auth and client - mock_auth = Mock() - mock_create_auth.return_value = mock_auth mock_client = Mock() mock_client_class.return_value = mock_client @@ -344,18 +194,14 @@ def test_create_sigv4_client_with_kwargs(self, mock_client_class, mock_create_au assert call_args[1]['proxies'] == {'http': 'http://proxy:8080'} assert result == mock_client - @patch('mcp_proxy_for_aws.sigv4_helper.create_sigv4_auth') @patch('httpx.AsyncClient') - def test_create_sigv4_client_with_prompt_context(self, mock_client_class, mock_create_auth): + def test_create_sigv4_client_with_prompt_context(self, mock_client_class): """Test creating SigV4 client when prompts exist in the system context. This test simulates the scenario where the sigv4_helper is used in a context where MCP prompts are present, ensuring the client is properly configured to handle requests that might include prompt-related content or headers. """ - # Mock auth and client - mock_auth = Mock() - mock_create_auth.return_value = mock_auth mock_client = Mock() mock_client_class.return_value = mock_client @@ -369,12 +215,8 @@ def test_create_sigv4_client_with_prompt_context(self, mock_client_class, mock_c service='test-service', headers=prompt_context_headers, region='us-west-2' ) - # Verify client was created correctly with prompt context - mock_create_auth.assert_called_once_with('test-service', 'us-west-2', None) - # Check that AsyncClient was called with correct parameters including prompt headers call_args = mock_client_class.call_args - assert call_args[1]['auth'] == mock_auth # Verify headers include both default and prompt-context headers expected_headers = { diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 4c18e33..15e09d1 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -39,9 +39,12 @@ def test_create_transport_with_sigv4(self, mock_create_sigv4_client): service = 'test-service' profile = 'test-profile' region = 'us-east-1' + metadata = {'AWS_REGION': 'us-west-2', 'CUSTOM_KEY': 'custom_value'} custom_timeout = Timeout(30.0) - result = create_transport_with_sigv4(url, service, region, custom_timeout, profile) + result = create_transport_with_sigv4( + url, service, region, metadata, custom_timeout, profile + ) # Verify result is StreamableHttpTransport assert isinstance(result, StreamableHttpTransport) @@ -61,6 +64,7 @@ def test_create_transport_with_sigv4(self, mock_create_sigv4_client): headers={'test': 'header'}, timeout=custom_timeout, auth=None, + metadata=metadata, ) else: # If we can't access the factory directly, just verify the transport was created @@ -74,9 +78,10 @@ def test_create_transport_with_sigv4_no_profile(self, mock_create_sigv4_client): url = 'https://test-service.us-west-2.api.aws/mcp' service = 'test-service' region = 'test-region' + metadata = {'AWS_REGION': 'test-forwarding-region'} custom_timeout = Timeout(60.0) - result = create_transport_with_sigv4(url, service, region, custom_timeout) + result = create_transport_with_sigv4(url, service, region, metadata, custom_timeout) # Test that the httpx_client_factory calls create_sigv4_client correctly # We need to access the factory through the transport's internal structure @@ -91,6 +96,7 @@ def test_create_transport_with_sigv4_no_profile(self, mock_create_sigv4_client): headers=None, timeout=custom_timeout, auth=None, + metadata=metadata, ) else: # If we can't access the factory directly, just verify the transport was created From d32cde8c27f407fd86c0abb320f1e1ed355793e1 Mon Sep 17 00:00:00 2001 From: wzxxing <169175349+wzxxing@users.noreply.github.com> Date: Tue, 11 Nov 2025 14:53:14 +0100 Subject: [PATCH 3/8] docs: mention creating issues first before sending large PR with new feature (#76) --- CONTRIBUTING.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 4473ab5..4ab70bb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,6 +4,9 @@ Thank you for your interest in contributing to the MCP Proxy for AWS! We welcome ## Quick Start +> [!NOTE] +> Before implementing new features, please create an issue first to ensure your contribution aligns with the project roadmap. + 1. **Fork the repository** on GitHub 2. **Set up your development environment** - see [DEVELOPMENT.md](DEVELOPMENT.md) for detailed setup instructions 3. **Create a feature branch** from `main` From 0afc5954c17a6f77b8d3e6cf77908a79f019bfc4 Mon Sep 17 00:00:00 2001 From: wzxxing <169175349+wzxxing@users.noreply.github.com> Date: Tue, 11 Nov 2025 16:48:54 +0100 Subject: [PATCH 4/8] feat: allow iam mcp client to take a botocore credentials object (#84) --- mcp_proxy_for_aws/client.py | 46 ++++++++++++++++--------- tests/unit/test_client.py | 69 +++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 17 deletions(-) diff --git a/mcp_proxy_for_aws/client.py b/mcp_proxy_for_aws/client.py index 50ed5a7..efdb6bd 100644 --- a/mcp_proxy_for_aws/client.py +++ b/mcp_proxy_for_aws/client.py @@ -15,6 +15,7 @@ import boto3 import logging from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from botocore.credentials import Credentials from contextlib import _AsyncGeneratorContextManager from datetime import timedelta from mcp.client.streamable_http import GetSessionIdCallback, streamablehttp_client @@ -32,6 +33,7 @@ def aws_iam_streamablehttp_client( aws_service: str, aws_region: Optional[str] = None, aws_profile: Optional[str] = None, + credentials: Optional[Credentials] = None, headers: Optional[dict[str, str]] = None, timeout: float | timedelta = 30, sse_read_timeout: float | timedelta = 60 * 5, @@ -55,6 +57,7 @@ def aws_iam_streamablehttp_client( aws_service: The name of the AWS service the MCP server is hosted on, e.g. "bedrock-agentcore". aws_region: The AWS region name of the MCP server, e.g. "us-west-2". aws_profile: The AWS profile to use for authentication. + credentials: Optional AWS credentials from boto3/botocore. If provided, takes precedence over aws_profile. headers: Optional additional HTTP headers to include in requests. timeout: Request timeout in seconds or timedelta object. Defaults to 30 seconds. sse_read_timeout: Server-sent events read timeout in seconds or timedelta object. @@ -78,28 +81,37 @@ def aws_iam_streamablehttp_client( """ logger.debug('Preparing AWS IAM MCP client for endpoint: %s', endpoint) - kwargs = {} - if aws_profile is not None: - kwargs['profile_name'] = aws_profile - if aws_region is not None: - kwargs['region_name'] = aws_region + if credentials is not None: + creds = credentials + region = aws_region + if not region: + raise ValueError( + 'AWS region must be specified via aws_region parameter when using credentials.' + ) + logger.debug('Using provided AWS credentials') + else: + kwargs = {} + if aws_profile is not None: + kwargs['profile_name'] = aws_profile + if aws_region is not None: + kwargs['region_name'] = aws_region + + session = boto3.Session(**kwargs) + creds = session.get_credentials() + region = session.region_name + + if not region: + raise ValueError( + 'AWS region must be specified via aws_region parameter, AWS_REGION environment variable, or AWS config.' + ) + + logger.debug('AWS profile: %s', session.profile_name) - session = boto3.Session(**kwargs) - - profile = session.profile_name - region = session.region_name - - if not region: - raise ValueError( - 'AWS region must be specified via aws_region parameter, AWS_PROFILE environment variable, or AWS config.' - ) - - logger.debug('AWS profile: %s', profile) logger.debug('AWS region: %s', region) logger.debug('AWS service: %s', aws_service) # Create a SigV4 authentication handler with AWS credentials - auth = SigV4HTTPXAuth(session.get_credentials(), aws_service, region) + auth = SigV4HTTPXAuth(creds, aws_service, region) # Return the streamable HTTP client context manager with AWS IAM authentication return streamablehttp_client( diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 0f9df07..2960e7d 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -15,6 +15,7 @@ """Unit tests for the client, parameterized by internal call.""" import pytest +from botocore.credentials import Credentials from datetime import timedelta from mcp_proxy_for_aws.client import aws_iam_streamablehttp_client from unittest.mock import AsyncMock, Mock, patch @@ -210,3 +211,71 @@ async def mock_aexit(*_): pass assert cleanup_called + + +@pytest.mark.asyncio +async def test_credentials_parameter_with_region(mock_streams): + """Test using provided credentials with aws_region.""" + mock_read, mock_write, mock_get_session = mock_streams + creds = Credentials('test_key', 'test_secret', 'test_token') + + with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth') as mock_auth_cls: + with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + mock_auth = Mock() + mock_auth_cls.return_value = mock_auth + mock_stream_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write, mock_get_session) + ) + mock_stream_client.return_value.__aexit__ = AsyncMock(return_value=None) + + async with aws_iam_streamablehttp_client( + endpoint='https://test.example.com/mcp', + aws_service='bedrock-agentcore', + aws_region='us-east-1', + credentials=creds, + ): + pass + + mock_auth_cls.assert_called_once_with(creds, 'bedrock-agentcore', 'us-east-1') + + +@pytest.mark.asyncio +async def test_credentials_parameter_without_region_raises_error(): + """Test that using credentials without aws_region raises ValueError.""" + creds = Credentials('test_key', 'test_secret', 'test_token') + + with pytest.raises( + ValueError, + match='AWS region must be specified via aws_region parameter when using credentials', + ): + async with aws_iam_streamablehttp_client( + endpoint='https://test.example.com/mcp', + aws_service='bedrock-agentcore', + credentials=creds, + ): + pass + + +@pytest.mark.asyncio +async def test_credentials_parameter_bypasses_boto3_session(mock_streams): + """Test that providing credentials bypasses boto3.Session creation.""" + mock_read, mock_write, mock_get_session = mock_streams + creds = Credentials('test_key', 'test_secret', 'test_token') + + with patch('boto3.Session') as mock_boto: + with patch('mcp_proxy_for_aws.client.SigV4HTTPXAuth'): + with patch('mcp_proxy_for_aws.client.streamablehttp_client') as mock_stream_client: + mock_stream_client.return_value.__aenter__ = AsyncMock( + return_value=(mock_read, mock_write, mock_get_session) + ) + mock_stream_client.return_value.__aexit__ = AsyncMock(return_value=None) + + async with aws_iam_streamablehttp_client( + endpoint='https://test.example.com/mcp', + aws_service='bedrock-agentcore', + aws_region='us-west-2', + credentials=creds, + ): + pass + + mock_boto.assert_not_called() From 459de24ae0c93973beeaa0ae29c5282a37db3865 Mon Sep 17 00:00:00 2001 From: Harvish N S Date: Mon, 10 Nov 2025 21:50:44 +0100 Subject: [PATCH 5/8] pypi release automation through github actions --- .github/workflows/pypi-publish-on-release.yml | 10 ++- .github/workflows/test-pypi-publish.yml | 77 +++++++++++++++++++ 2 files changed, 83 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/test-pypi-publish.yml diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index 8b8dfb1..0a2f521 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -11,6 +11,9 @@ jobs: call-test-lint: permissions: contents: read + pull-requests: read + security-events: write + actions: read uses: ./.github/workflows/python.yml with: ref: ${{ github.event.release.target_commitish }} @@ -36,7 +39,8 @@ jobs: - name: Validate version format run: | - VERSION=$(python -c "import tomllib; print(tomllib.load(open('pyproject.toml', 'rb'))['project']['version'])") + pip install tomli + VERSION=$(python -c "import tomli; print(tomli.load(open('pyproject.toml', 'rb'))['project']['version'])") if [[ ! "$VERSION" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then echo "Invalid version format: $VERSION" echo "Expected format: X.Y.Z (e.g., 1.0.0)" @@ -60,7 +64,7 @@ jobs: name: pypi url: https://pypi.org/p/mcp-proxy-for-aws permissions: - contents: read + id-token: write steps: - name: Download distribution packages uses: actions/download-artifact@v5 @@ -70,5 +74,3 @@ jobs: - name: Publish to PyPI uses: pypa/gh-action-pypi-publish@release/v1 - with: - password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/test-pypi-publish.yml b/.github/workflows/test-pypi-publish.yml new file mode 100644 index 0000000..6ca8975 --- /dev/null +++ b/.github/workflows/test-pypi-publish.yml @@ -0,0 +1,77 @@ +name: Test PyPI Publishing + +on: + workflow_dispatch: + push: + tags: + - 'v*' + +permissions: {} + +jobs: + call-test-lint: + permissions: + contents: read + pull-requests: read + security-events: write + actions: read + uses: ./.github/workflows/python.yml + + build: + needs: call-test-lint + runs-on: ubuntu-latest + permissions: + contents: read + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + persist-credentials: false + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.10' + + - name: Install uv + run: pip install uv + + - name: Validate version format + run: | + pip install tomli + VERSION=$(python -c "import tomli; print(tomli.load(open('pyproject.toml', 'rb'))['project']['version'])") + if [[ ! "$VERSION" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then + echo "Invalid version format: $VERSION" + echo "Expected format: X.Y.Z (e.g., 1.0.0)" + exit 1 + fi + echo "Valid version format: $VERSION" + + - name: Build distribution packages + run: uv build + + - name: Upload distribution packages + uses: actions/upload-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + deploy-test: + needs: build + runs-on: ubuntu-latest + environment: + name: testpypi + url: https://test.pypi.org/p/mcp-proxy-for-aws + permissions: + id-token: write + steps: + - name: Download distribution packages + uses: actions/download-artifact@v4 + with: + name: python-package-distributions + path: dist/ + + - name: Publish to TestPyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + repository-url: https://test.pypi.org/legacy/ From 18d4e0aa4607ba3fb3401a0497b2d4e79cf2ac14 Mon Sep 17 00:00:00 2001 From: Harvish N S Date: Wed, 12 Nov 2025 14:56:22 +0100 Subject: [PATCH 6/8] chore: pypi publishing added to development.md --- .github/workflows/pypi-publish-on-release.yml | 28 +++---- .github/workflows/test-pypi-publish.yml | 42 ++++------ DEVELOPMENT.md | 84 +++++++++++++++++++ 3 files changed, 114 insertions(+), 40 deletions(-) diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index 0a2f521..5027edd 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -8,7 +8,7 @@ on: permissions: {} jobs: - call-test-lint: + call-test: permissions: contents: read pull-requests: read @@ -18,8 +18,16 @@ jobs: with: ref: ${{ github.event.release.target_commitish }} + call-integ-tests: + permissions: + id-token: write + contents: read + actions: read + uses: ./.github/workflows/python-integ.yml + secrets: inherit + build: - needs: call-test-lint + needs: [call-test, call-integ-tests] runs-on: ubuntu-latest permissions: contents: read @@ -37,17 +45,6 @@ jobs: - name: Install uv run: pip install uv - - name: Validate version format - run: | - pip install tomli - VERSION=$(python -c "import tomli; print(tomli.load(open('pyproject.toml', 'rb'))['project']['version'])") - if [[ ! "$VERSION" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then - echo "Invalid version format: $VERSION" - echo "Expected format: X.Y.Z (e.g., 1.0.0)" - exit 1 - fi - echo "Valid version format: $VERSION" - - name: Build distribution packages run: uv build @@ -72,5 +69,8 @@ jobs: name: python-package-distributions path: dist/ + - name: Set up uv + uses: astral-sh/setup-uv@v4 + - name: Publish to PyPI - uses: pypa/gh-action-pypi-publish@release/v1 + run: uv publish diff --git a/.github/workflows/test-pypi-publish.yml b/.github/workflows/test-pypi-publish.yml index 6ca8975..a32eadc 100644 --- a/.github/workflows/test-pypi-publish.yml +++ b/.github/workflows/test-pypi-publish.yml @@ -2,14 +2,11 @@ name: Test PyPI Publishing on: workflow_dispatch: - push: - tags: - - 'v*' permissions: {} jobs: - call-test-lint: + call-test: permissions: contents: read pull-requests: read @@ -17,8 +14,16 @@ jobs: actions: read uses: ./.github/workflows/python.yml + call-integ-tests: + permissions: + id-token: write + contents: read + actions: read + uses: ./.github/workflows/python-integ.yml + secrets: inherit + build: - needs: call-test-lint + needs: [call-test, call-integ-tests] runs-on: ubuntu-latest permissions: contents: read @@ -28,24 +33,8 @@ jobs: with: persist-credentials: false - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.10' - - - name: Install uv - run: pip install uv - - - name: Validate version format - run: | - pip install tomli - VERSION=$(python -c "import tomli; print(tomli.load(open('pyproject.toml', 'rb'))['project']['version'])") - if [[ ! "$VERSION" =~ ^[0-9]+\.[0-9]+\.[0-9]+$ ]]; then - echo "Invalid version format: $VERSION" - echo "Expected format: X.Y.Z (e.g., 1.0.0)" - exit 1 - fi - echo "Valid version format: $VERSION" + - name: Set up uv + uses: astral-sh/setup-uv@v4 - name: Build distribution packages run: uv build @@ -71,7 +60,8 @@ jobs: name: python-package-distributions path: dist/ + - name: Set up uv + uses: astral-sh/setup-uv@v4 + - name: Publish to TestPyPI - uses: pypa/gh-action-pypi-publish@release/v1 - with: - repository-url: https://test.pypi.org/legacy/ + run: uv publish --publish-url https://test.pypi.org/legacy/ diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 39cf77c..7a33dda 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -379,12 +379,96 @@ export LOG_LEVEL=DEBUG uv run mcp_proxy_for_aws/server.py ``` +## Releasing to PyPI + +The project uses automated PyPI publishing through GitHub Actions. Releases are triggered by creating a GitHub Release. + +### Release Process + +1. **Ensure all changes are merged to main branch** + ```bash + git checkout main + git pull origin main + ``` + +2. **Create GitHub Release** + + Go to the [Releases page](https://github.com/aws/mcp-proxy-for-aws/releases) and click "Draft a new release", then fill in: + + - **Tag**: `v1.10.0` (must start with 'v' and follow semantic versioning) + - **Target**: `main` + - **Title**: `v1.10.0` + - **Description**: Click "Generate release notes" for auto-generated notes + + Click **"Publish release"** (not "Save draft") + +3. **Automated Publishing** + + Once the release is published, GitHub Actions will automatically: + - Run all tests and linting checks + - Build distribution packages (wheel and source) + - Publish to PyPI using Trusted Publishing + + Monitor the workflow at: [Actions tab](https://github.com/aws/mcp-proxy-for-aws/actions) + +### Version Numbering + +Follow [Semantic Versioning](https://semver.org/): +- **MAJOR** (v2.0.0): Breaking changes +- **MINOR** (v1.10.0): New features, backward compatible +- **PATCH** (v1.10.1): Bug fixes, backward compatible + +Version is managed in `pyproject.toml`: `version = "1.10.0"` + +Use Commitizen to bump versions automatically: +```bash +# Bump version based on conventional commits +uv run cz bump + +# This will update both files and create a git tag +``` + +### Testing Releases + +Before creating a production release, you can test with TestPyPI: + +```bash +# Create a test tag +git tag v1.10.0-beta +git push origin v1.10.0-beta + +# This triggers the TestPyPI workflow +# Monitor at: https://github.com/aws/mcp-proxy-for-aws/actions +``` + +### Troubleshooting Releases + +**Release workflow failed:** +- Check the Actions tab for error details +- Ensure all tests pass locally: `uv run pytest` +- Verify the tag follows semantic versioning format + +**Version already exists on PyPI:** +- PyPI doesn't allow re-uploading the same version +- Create a new patch version (e.g., v1.10.1) +- Delete and recreate the tag if needed: + ```bash + git tag -d v1.10.0 + git push origin --delete v1.10.0 + ``` + +**Trusted Publishing authentication failed:** +- Verify PyPI Trusted Publisher is configured correctly +- Check that the workflow file name matches the PyPI configuration +- Ensure the environment name is set to `pypi` + ## Additional Resources - [MCP Specification](https://spec.modelcontextprotocol.io/) - [FastMCP Documentation](https://fastmcp.readthedocs.io/) - [AWS SDK for Python (Boto3)](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html) - [Project README](README.md) +- [PyPI Publishing Guide](https://packaging.python.org/en/latest/guides/publishing-package-distribution-releases-using-github-actions-ci-cd-workflows/) --- From 6d04180621013453dcacceddbc26f314db47c770 Mon Sep 17 00:00:00 2001 From: Harvish N S Date: Wed, 12 Nov 2025 15:04:35 +0100 Subject: [PATCH 7/8] chore: simplifying uv install --- .github/workflows/pypi-publish-on-release.yml | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index 5027edd..815e402 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -37,13 +37,8 @@ jobs: with: persist-credentials: false - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: '3.10' - - - name: Install uv - run: pip install uv + - name: Set up uv + uses: astral-sh/setup-uv@v4 - name: Build distribution packages run: uv build From c65b56f608c774222bcebbac8ec890c4b55a753c Mon Sep 17 00:00:00 2001 From: Harvish N S Date: Wed, 12 Nov 2025 17:14:41 +0100 Subject: [PATCH 8/8] fix: adding input commit id for integ tests, and passing it in publish workflows. --- .github/workflows/pypi-publish-on-release.yml | 2 ++ .github/workflows/python-integ.yml | 9 ++++++++- .github/workflows/test-pypi-publish.yml | 4 ++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pypi-publish-on-release.yml b/.github/workflows/pypi-publish-on-release.yml index 815e402..d7567c9 100644 --- a/.github/workflows/pypi-publish-on-release.yml +++ b/.github/workflows/pypi-publish-on-release.yml @@ -24,6 +24,8 @@ jobs: contents: read actions: read uses: ./.github/workflows/python-integ.yml + with: + ref: ${{ github.event.release.target_commitish }} secrets: inherit build: diff --git a/.github/workflows/python-integ.yml b/.github/workflows/python-integ.yml index 43d454d..a6302c3 100644 --- a/.github/workflows/python-integ.yml +++ b/.github/workflows/python-integ.yml @@ -2,10 +2,15 @@ name: Python Integration Tests on: - workflow_call: workflow_dispatch: push: branches: main + workflow_call: + inputs: + ref: + description: 'Git ref to checkout' + required: false + type: string jobs: integration: @@ -24,6 +29,8 @@ jobs: steps: - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 + with: + ref: ${{ inputs.ref || github.ref }} - name: Install uv uses: astral-sh/setup-uv@85856786d1ce8acfbcc2f13a5f3fbd6b938f9f41 # v7.1.2 diff --git a/.github/workflows/test-pypi-publish.yml b/.github/workflows/test-pypi-publish.yml index a32eadc..3ced2fd 100644 --- a/.github/workflows/test-pypi-publish.yml +++ b/.github/workflows/test-pypi-publish.yml @@ -12,6 +12,8 @@ jobs: pull-requests: read security-events: write actions: read + with: + ref: ${{ github.event.release.target_commitish }} uses: ./.github/workflows/python.yml call-integ-tests: @@ -20,6 +22,8 @@ jobs: contents: read actions: read uses: ./.github/workflows/python-integ.yml + with: + ref: ${{ github.event.release.target_commitish }} secrets: inherit build: