Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions mcp_proxy_for_aws/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import asyncio
import httpx
import logging
from fastmcp import Client
from fastmcp.server.middleware.error_handling import RetryMiddleware
from fastmcp.server.middleware.logging import LoggingMiddleware
from fastmcp.server.server import FastMCP
Expand Down Expand Up @@ -83,16 +84,16 @@ async def setup_mcp_mode(local_mcp: FastMCP, args) -> None:
transport = create_transport_with_sigv4(
args.endpoint, service, region, metadata, timeout, profile
)
async with Client(transport=transport) as client:
# Create proxy with the transport
proxy = FastMCP.as_proxy(client)
add_logging_middleware(proxy, args.log_level)
add_tool_filtering_middleware(proxy, args.read_only)

# Create proxy with the transport
proxy = FastMCP.as_proxy(transport)
add_logging_middleware(proxy, args.log_level)
add_tool_filtering_middleware(proxy, args.read_only)
if args.retries:
add_retry_middleware(proxy, args.retries)

if args.retries:
add_retry_middleware(proxy, args.retries)

await proxy.run_async()
await proxy.run_async()


def add_tool_filtering_middleware(mcp: FastMCP, read_only: bool = False) -> None:
Expand Down
51 changes: 43 additions & 8 deletions tests/unit/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Tests for the mcp-proxy-for-aws Server."""

import pytest
from fastmcp.client.transports import ClientTransport
from fastmcp.server.server import FastMCP
from mcp_proxy_for_aws.server import (
add_retry_middleware,
Expand All @@ -31,6 +32,7 @@
class TestServer:
"""Tests for the server module."""

@patch('mcp_proxy_for_aws.server.Client')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
Expand All @@ -45,6 +47,7 @@ async def test_setup_mcp_mode(
mock_determine_region,
mock_as_proxy,
mock_create_transport,
mock_client_class,
):
"""Test that MCP mode is set up correctly."""
# Arrange
Expand All @@ -68,9 +71,15 @@ async def test_setup_mcp_mode(
mock_determine_service.return_value = 'test-service'
mock_determine_region.return_value = 'us-east-1'

# Mock the transport and proxy
mock_transport = Mock()
# Mock the transport and client
mock_transport = Mock(spec=ClientTransport)
mock_create_transport.return_value = mock_transport

mock_client = Mock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_class.return_value = mock_client

mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_as_proxy.return_value = mock_proxy
Expand All @@ -90,11 +99,13 @@ async def test_setup_mcp_mode(
assert call_args[0][3] == {'AWS_REGION': 'us-east-1'} # metadata
# call_args[0][4] is the Timeout object
assert call_args[0][5] is None # profile
mock_as_proxy.assert_called_once_with(mock_transport)
mock_client_class.assert_called_once_with(transport=mock_transport)
mock_as_proxy.assert_called_once_with(mock_client)
mock_add_filtering.assert_called_once_with(mock_proxy, True)
mock_add_retry.assert_called_once_with(mock_proxy, 1)
mock_proxy.run_async.assert_called_once()

@patch('mcp_proxy_for_aws.server.Client')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
Expand All @@ -107,6 +118,7 @@ async def test_setup_mcp_mode_no_retries(
mock_determine_region,
mock_as_proxy,
mock_create_transport,
mock_client_class,
):
"""Test that MCP mode setup without retries doesn't add retry middleware."""
# Arrange
Expand All @@ -130,9 +142,15 @@ async def test_setup_mcp_mode_no_retries(
mock_determine_service.return_value = 'test-service'
mock_determine_region.return_value = 'us-east-1'

# Mock the transport and proxy
mock_transport = Mock()
# Mock the transport and client
mock_transport = Mock(spec=ClientTransport)
mock_create_transport.return_value = mock_transport

mock_client = Mock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_class.return_value = mock_client

mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_as_proxy.return_value = mock_proxy
Expand All @@ -155,10 +173,12 @@ async def test_setup_mcp_mode_no_retries(
} # metadata
# call_args[0][4] is the Timeout object
assert call_args[0][5] == 'test-profile' # profile
mock_as_proxy.assert_called_once_with(mock_transport)
mock_client_class.assert_called_once_with(transport=mock_transport)
mock_as_proxy.assert_called_once_with(mock_client)
mock_add_filtering.assert_called_once_with(mock_proxy, False)
mock_proxy.run_async.assert_called_once()

@patch('mcp_proxy_for_aws.server.Client')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
Expand All @@ -171,6 +191,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
mock_determine_region,
mock_as_proxy,
mock_create_transport,
mock_client_class,
):
"""Test that AWS_REGION is automatically injected when no metadata is provided."""
# Arrange
Expand All @@ -192,8 +213,14 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
mock_determine_service.return_value = 'test-service'
mock_determine_region.return_value = 'ap-southeast-1'

mock_transport = Mock()
mock_transport = Mock(spec=ClientTransport)
mock_create_transport.return_value = mock_transport

mock_client = Mock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_class.return_value = mock_client

mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_as_proxy.return_value = mock_proxy
Expand All @@ -207,6 +234,7 @@ async def test_setup_mcp_mode_no_metadata_injects_aws_region(
metadata = call_args[0][3]
assert metadata == {'AWS_REGION': 'ap-southeast-1'}

@patch('mcp_proxy_for_aws.server.Client')
@patch('mcp_proxy_for_aws.server.create_transport_with_sigv4')
@patch('mcp_proxy_for_aws.server.FastMCP.as_proxy')
@patch('mcp_proxy_for_aws.server.determine_aws_region')
Expand All @@ -219,6 +247,7 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
mock_determine_region,
mock_as_proxy,
mock_create_transport,
mock_client_class,
):
"""Test that AWS_REGION is injected even when other metadata is provided."""
# Arrange
Expand All @@ -240,8 +269,14 @@ async def test_setup_mcp_mode_metadata_without_aws_region_injects_it(
mock_determine_service.return_value = 'test-service'
mock_determine_region.return_value = 'us-west-1'

mock_transport = Mock()
mock_transport = Mock(spec=ClientTransport)
mock_create_transport.return_value = mock_transport

mock_client = Mock()
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=None)
mock_client_class.return_value = mock_client

mock_proxy = Mock()
mock_proxy.run_async = AsyncMock()
mock_as_proxy.return_value = mock_proxy
Expand Down
Loading