Skip to content
Merged
1 change: 1 addition & 0 deletions mcp_proxy_for_aws/logging_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@ def configure_logging(level: Optional[str] = None) -> None:
# Set httpx logging to WARNING by default to reduce noise
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('httpcore').setLevel(logging.WARNING)
logging.getLogger('botocore').setLevel(logging.WARNING)
131 changes: 99 additions & 32 deletions mcp_proxy_for_aws/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,23 @@
"""

import asyncio
import contextlib
import httpx
import logging
import sys
from fastmcp import Client
from fastmcp.client import ClientTransport
from fastmcp.server.middleware.error_handling import RetryMiddleware
from fastmcp.server.middleware.logging import LoggingMiddleware
from fastmcp.server.server import FastMCP
from mcp import McpError
from mcp.types import (
CONNECTION_CLOSED,
ErrorData,
JSONRPCError,
JSONRPCMessage,
JSONRPCResponse,
)
from mcp_proxy_for_aws.cli import parse_args
from mcp_proxy_for_aws.logging_config import configure_logging
from mcp_proxy_for_aws.middleware.tool_filter import ToolFilteringMiddleware
Expand All @@ -37,13 +48,75 @@
determine_aws_region,
determine_service_name,
)
from typing import Any


logger = logging.getLogger(__name__)


async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
@contextlib.asynccontextmanager
async def _initialize_client(transport: ClientTransport):
"""Handle the exceptions for during client initialize."""
# line = sys.stdin.readline()
# logger.debug('First line from kiro %s', line)
Comment on lines +59 to +60
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# line = sys.stdin.readline()
# logger.debug('First line from kiro %s', line)

async with contextlib.AsyncExitStack() as stack:
try:
client = await stack.enter_async_context(Client(transport))
if client.initialize_result:
print(
client.initialize_result.model_dump_json(
by_alias=True,
exclude_none=True,
),
file=sys.stdout,
)
except httpx.HTTPStatusError as http_error:
logger.error('HTTP Error during initialize %s', http_error)
response = http_error.response
try:
body = await response.aread()
jsonrpc_msg = JSONRPCMessage.model_validate_json(body).root
if isinstance(jsonrpc_msg, (JSONRPCError, JSONRPCResponse)):
line = jsonrpc_msg.model_dump_json(
by_alias=True,
exclude_none=True,
)
logger.debug('Writing the unhandled http error to stdout %s', http_error)
print(line, file=sys.stdout)
else:
logger.debug('Ignoring jsonrpc message type=%s', type(jsonrpc_msg))
except Exception as _:
logger.debug('Cannot read HTTP response body')
raise http_error
except Exception as e:
cause = e.__cause__
if isinstance(cause, McpError):
logger.error('MCP Error during initialize %s', cause.error)
jsonrpc_error = JSONRPCError(jsonrpc='2.0', id=0, error=cause.error)
line = jsonrpc_error.model_dump_json(
by_alias=True,
exclude_none=True,
)
else:
logger.error('Error during initialize %s', e)
jsonrpc_error = JSONRPCError(
jsonrpc='2.0',
id=0,
error=ErrorData(
code=CONNECTION_CLOSED,
message=str(e),
),
)
line = jsonrpc_error.model_dump_json(
by_alias=True,
exclude_none=True,
)
print(line, file=sys.stdout)
raise e
logger.debug('Initialized MCP client')
yield client


async def run_proxy(args) -> None:
"""Set up the server in MCP mode."""
logger.info('Setting up server in MCP mode')

Expand Down Expand Up @@ -84,16 +157,25 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
transport = create_transport_with_sigv4(
args.endpoint, service, region, metadata, timeout, profile
)
async with Client(transport=transport) as client:
# Create proxy with the transport
proxy = FastMCP.as_proxy(client)
add_logging_middleware(proxy, args.log_level)
add_tool_filtering_middleware(proxy, args.read_only)

if args.retries:
add_retry_middleware(proxy, args.retries)

await proxy.run_async()
async with _initialize_client(transport) as client:
try:
proxy = FastMCP.as_proxy(
client,
name='MCP Proxy for AWS',
instructions=(
'MCP Proxy for AWS provides access to SigV4 protected MCP servers through a single interface. '
'This proxy handles authentication and request routing to the appropriate backend services.'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: SigV4 authentication and we don't really do any routing.

),
)
add_logging_middleware(proxy, args.log_level)
add_tool_filtering_middleware(proxy, args.read_only)

if args.retries:
add_retry_middleware(proxy, args.retries)
await proxy.run_async(transport='stdio')
except Exception as e:
logger.error('Cannot start proxy server: %s', e)
raise e


def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None:
Expand Down Expand Up @@ -146,27 +228,12 @@ def main():
configure_logging(args.log_level)
logger.info('Starting MCP Proxy for AWS Server')

# Create FastMCP instance
mcp = FastMCP[Any](
name='MCP Proxy',
instructions=(
'MCP Proxy for AWS Server that provides access to backend servers through a single interface. '
'This proxy handles authentication and request routing to the appropriate backend services.'
),
)

async def setup_and_run():
try:
await setup_mcp_mode(mcp, args)

logger.info('Server setup complete, starting MCP server')

except Exception as e:
logger.error('Failed to start server: %s', e)
raise

# Run the server
asyncio.run(setup_and_run())
try:
asyncio.run(run_proxy(args))
except Exception:
logger.exception('Error launching MCP proxy for aws')
return 1


if __name__ == '__main__':
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ description = "MCP Proxy for AWS"
readme = "README.md"
requires-python = ">=3.10,<3.14"
dependencies = [
"fastmcp>=2.13.0.2",
"fastmcp>=2.13.1",
"boto3>=1.34.0",
"botocore>=1.34.0",
]
Expand Down
6 changes: 4 additions & 2 deletions tests/integ/mcp/simple_mcp_server/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,10 @@ async def elicit_for_my_name(elicitation_expected: str, ctx: Context):
@mcp.tool
def echo_metadata(ctx: Context):
"""MCP Tool that echoes back the _meta field from the request."""
meta = ctx.request_context.meta
return {'received_meta': meta}
if ctx.request_context:
meta = ctx.request_context.meta
return {'received_meta': meta}
raise RuntimeError('No request context received')


#### Server Setup
Expand Down
174 changes: 174 additions & 0 deletions tests/unit/test_initialize_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for _initialize_client error handling."""

import httpx
import pytest
from mcp import McpError
from mcp.types import ErrorData, JSONRPCError, JSONRPCResponse
from mcp_proxy_for_aws.server import _initialize_client
from unittest.mock import AsyncMock, Mock, patch


@pytest.mark.asyncio
async def test_successful_initialization():
"""Test successful client initialization."""
mock_transport = Mock()
mock_client = Mock()

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(return_value=mock_client)
mock_client_class.return_value.__aexit__ = AsyncMock(return_value=None)

async with _initialize_client(mock_transport) as client:
assert client == mock_client


@pytest.mark.asyncio
async def test_http_error_with_jsonrpc_error(capsys):
"""Test HTTPStatusError with JSONRPCError response."""
mock_transport = Mock()
error_data = ErrorData(code=-32600, message='Invalid Request')
jsonrpc_error = JSONRPCError(jsonrpc='2.0', id=1, error=error_data)

mock_response = Mock()
mock_response.aread = AsyncMock(return_value=jsonrpc_error.model_dump_json().encode())

http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)

with pytest.raises(httpx.HTTPStatusError):
async with _initialize_client(mock_transport):
pass

captured = capsys.readouterr()
assert 'Invalid Request' in captured.out


@pytest.mark.asyncio
async def test_http_error_with_jsonrpc_response(capsys):
"""Test HTTPStatusError with JSONRPCResponse."""
mock_transport = Mock()
jsonrpc_response = JSONRPCResponse(jsonrpc='2.0', id=1, result={'status': 'error'})

mock_response = Mock()
mock_response.aread = AsyncMock(return_value=jsonrpc_response.model_dump_json().encode())

http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)

with pytest.raises(httpx.HTTPStatusError):
async with _initialize_client(mock_transport):
pass

captured = capsys.readouterr()
assert '"result":{"status":"error"}' in captured.out


@pytest.mark.asyncio
async def test_http_error_with_invalid_json():
"""Test HTTPStatusError with invalid JSON response."""
mock_transport = Mock()

mock_response = Mock()
mock_response.aread = AsyncMock(return_value=b'invalid json')

http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)

with pytest.raises(httpx.HTTPStatusError):
async with _initialize_client(mock_transport):
pass


@pytest.mark.asyncio
async def test_http_error_with_non_jsonrpc_message():
"""Test HTTPStatusError with non-JSONRPCError/Response message."""
mock_transport = Mock()

mock_response = Mock()
mock_response.aread = AsyncMock(return_value=b'{"jsonrpc":"2.0","method":"test"}')

http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)

with pytest.raises(httpx.HTTPStatusError):
async with _initialize_client(mock_transport):
pass


@pytest.mark.asyncio
async def test_http_error_response_read_failure():
"""Test HTTPStatusError when response.aread() fails."""
mock_transport = Mock()

mock_response = Mock()
mock_response.aread = AsyncMock(side_effect=Exception('Read failed'))

http_error = httpx.HTTPStatusError('error', request=Mock(), response=mock_response)

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=http_error)

with pytest.raises(httpx.HTTPStatusError):
async with _initialize_client(mock_transport):
pass


@pytest.mark.asyncio
async def test_generic_error_with_mcp_error_cause(capsys):
"""Test generic exception with McpError as cause."""
mock_transport = Mock()
error_data = ErrorData(code=-32601, message='Method not found')
mcp_error = McpError(error_data)
generic_error = Exception('Wrapper error')
generic_error.__cause__ = mcp_error

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=generic_error)

with pytest.raises(Exception):
async with _initialize_client(mock_transport):
pass

captured = capsys.readouterr()
assert 'Method not found' in captured.out
assert '"code":-32601' in captured.out


@pytest.mark.asyncio
async def test_generic_error_without_mcp_error_cause(capsys):
"""Test generic exception without McpError cause."""
mock_transport = Mock()
generic_error = Exception('Generic error')

with patch('mcp_proxy_for_aws.server.Client') as mock_client_class:
mock_client_class.return_value.__aenter__ = AsyncMock(side_effect=generic_error)

with pytest.raises(Exception):
async with _initialize_client(mock_transport):
pass

captured = capsys.readouterr()
assert 'Generic error' in captured.out
assert '"code":-32000' in captured.out
Loading
Loading