diff --git a/mcp_proxy_for_aws/server.py b/mcp_proxy_for_aws/server.py index adbc0f0..c782bd1 100644 --- a/mcp_proxy_for_aws/server.py +++ b/mcp_proxy_for_aws/server.py @@ -25,6 +25,7 @@ import asyncio import httpx import logging +from fastmcp import Client from fastmcp.server.middleware.error_handling import RetryMiddleware from fastmcp.server.middleware.logging import LoggingMiddleware from fastmcp.server.server import FastMCP @@ -83,16 +84,16 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None: transport = create_transport_with_sigv4( args.endpoint, service, region, metadata, timeout, profile ) + async with Client(transport=transport) as client: + # Create proxy with the transport + proxy = FastMCP.as_proxy(client) + add_logging_middleware(proxy, args.log_level) + add_tool_filtering_middleware(proxy, args.read_only) - # Create proxy with the transport - proxy = FastMCP.as_proxy(transport) - add_logging_middleware(proxy, args.log_level) - add_tool_filtering_middleware(proxy, args.read_only) + if args.retries: + add_retry_middleware(proxy, args.retries) - if args.retries: - add_retry_middleware(proxy, args.retries) - - await proxy.run_async() + await proxy.run_async() def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None: diff --git a/tests/unit/test_server.py b/tests/unit/test_server.py index 63350a9..45233ea 100644 --- a/tests/unit/test_server.py +++ b/tests/unit/test_server.py @@ -15,6 +15,7 @@ """Tests for the mcp-proxy-for-aws Server.""" import pytest +from fastmcp.client.transports import ClientTransport from fastmcp.server.server import FastMCP from mcp_proxy_for_aws.server import ( add_retry_middleware, @@ -31,6 +32,7 @@ class TestServer: """Tests for the server module.""" + @patch('mcp_proxy_for_aws.server.Client') @patch('mcp_proxy_for_aws.server.create_transport_with_sigv4') @patch('mcp_proxy_for_aws.server.FastMCP.as_proxy') @patch('mcp_proxy_for_aws.server.determine_aws_region') @@ -45,6 +47,7 @@ async def test_setup_mcp_mode( mock_determine_region, mock_as_proxy, mock_create_transport, + mock_client_class, ): """Test that MCP mode is set up correctly.""" # Arrange @@ -68,9 +71,15 @@ async def test_setup_mcp_mode( mock_determine_service.return_value = 'test-service' mock_determine_region.return_value = 'us-east-1' - # Mock the transport and proxy - mock_transport = Mock() + # Mock the transport and client + mock_transport = Mock(spec=ClientTransport) mock_create_transport.return_value = mock_transport + + mock_client = Mock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + mock_proxy = Mock() mock_proxy.run_async = AsyncMock() mock_as_proxy.return_value = mock_proxy @@ -90,11 +99,13 @@ async def test_setup_mcp_mode( assert call_args[0][3] == {'AWS_REGION': 'us-east-1'} # metadata # call_args[0][4] is the Timeout object assert call_args[0][5] is None # profile - mock_as_proxy.assert_called_once_with(mock_transport) + mock_client_class.assert_called_once_with(transport=mock_transport) + mock_as_proxy.assert_called_once_with(mock_client) mock_add_filtering.assert_called_once_with(mock_proxy, True) mock_add_retry.assert_called_once_with(mock_proxy, 1) mock_proxy.run_async.assert_called_once() + @patch('mcp_proxy_for_aws.server.Client') @patch('mcp_proxy_for_aws.server.create_transport_with_sigv4') @patch('mcp_proxy_for_aws.server.FastMCP.as_proxy') @patch('mcp_proxy_for_aws.server.determine_aws_region') @@ -107,6 +118,7 @@ async def test_setup_mcp_mode_no_retries( mock_determine_region, mock_as_proxy, mock_create_transport, + mock_client_class, ): """Test that MCP mode setup without retries doesn't add retry middleware.""" # Arrange @@ -130,9 +142,15 @@ async def test_setup_mcp_mode_no_retries( mock_determine_service.return_value = 'test-service' mock_determine_region.return_value = 'us-east-1' - # Mock the transport and proxy - mock_transport = Mock() + # Mock the transport and client + mock_transport = Mock(spec=ClientTransport) mock_create_transport.return_value = mock_transport + + mock_client = Mock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + mock_proxy = Mock() mock_proxy.run_async = AsyncMock() mock_as_proxy.return_value = mock_proxy @@ -155,10 +173,12 @@ async def test_setup_mcp_mode_no_retries( } # metadata # call_args[0][4] is the Timeout object assert call_args[0][5] == 'test-profile' # profile - mock_as_proxy.assert_called_once_with(mock_transport) + mock_client_class.assert_called_once_with(transport=mock_transport) + mock_as_proxy.assert_called_once_with(mock_client) mock_add_filtering.assert_called_once_with(mock_proxy, False) mock_proxy.run_async.assert_called_once() + @patch('mcp_proxy_for_aws.server.Client') @patch('mcp_proxy_for_aws.server.create_transport_with_sigv4') @patch('mcp_proxy_for_aws.server.FastMCP.as_proxy') @patch('mcp_proxy_for_aws.server.determine_aws_region') @@ -171,6 +191,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region( mock_determine_region, mock_as_proxy, mock_create_transport, + mock_client_class, ): """Test that AWS_REGION is automatically injected when no metadata is provided.""" # Arrange @@ -192,8 +213,14 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region( mock_determine_service.return_value = 'test-service' mock_determine_region.return_value = 'ap-southeast-1' - mock_transport = Mock() + mock_transport = Mock(spec=ClientTransport) mock_create_transport.return_value = mock_transport + + mock_client = Mock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + mock_proxy = Mock() mock_proxy.run_async = AsyncMock() mock_as_proxy.return_value = mock_proxy @@ -207,6 +234,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region( metadata = call_args[0][3] assert metadata == {'AWS_REGION': 'ap-southeast-1'} + @patch('mcp_proxy_for_aws.server.Client') @patch('mcp_proxy_for_aws.server.create_transport_with_sigv4') @patch('mcp_proxy_for_aws.server.FastMCP.as_proxy') @patch('mcp_proxy_for_aws.server.determine_aws_region') @@ -219,6 +247,7 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it( mock_determine_region, mock_as_proxy, mock_create_transport, + mock_client_class, ): """Test that AWS_REGION is injected even when other metadata is provided.""" # Arrange @@ -240,8 +269,14 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it( mock_determine_service.return_value = 'test-service' mock_determine_region.return_value = 'us-west-1' - mock_transport = Mock() + mock_transport = Mock(spec=ClientTransport) mock_create_transport.return_value = mock_transport + + mock_client = Mock() + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + mock_proxy = Mock() mock_proxy.run_async = AsyncMock() mock_as_proxy.return_value = mock_proxy