Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 1 addition & 33 deletions src/mcp_optimizer/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,29 +106,6 @@ def _determine_proxy_mode(self) -> ToolHiveProxyMode:
)
return url_to_toolhive_proxy_mode(self.workload.url)

def _prepare_url_for_proxy_mode(self, proxy_mode: ToolHiveProxyMode) -> str:
"""
Prepare the URL with the correct path for the transport type.

Args:
proxy_mode: The proxy mode being used

Returns:
The URL with the correct path appended
"""
url = self.workload.url
if proxy_mode == ToolHiveProxyMode.STREAMABLE:
# Streamable HTTP expects /mcp path
if not url.endswith("/mcp"):
url = url.rstrip("/") + "/mcp"
logger.debug(f"Appended /mcp path to URL: {url}", workload=self.workload.name)
elif proxy_mode == ToolHiveProxyMode.SSE:
# SSE expects /sse path
if not url.endswith("/sse"):
url = url.rstrip("/") + "/sse"
logger.debug(f"Appended /sse path to URL: {url}", workload=self.workload.name)
return url

async def _execute_with_session(self, operation: Callable[[ClientSession], Awaitable]) -> Any:
"""
Execute an operation with an MCP session.
Expand All @@ -151,15 +128,9 @@ async def _execute_with_session(self, operation: Callable[[ClientSession], Await
f"Using {proxy_mode} client for workload",
workload=self.workload.name,
proxy_mode_field=self.workload.proxy_mode,
url=self.workload.url,
)

# Prepare URL with correct path for the transport type
url = self._prepare_url_for_proxy_mode(proxy_mode)

# Temporarily override the URL for this connection
original_url = self.workload.url
self.workload.url = url

try:
if proxy_mode == ToolHiveProxyMode.STREAMABLE:
return await asyncio.wait_for(
Expand Down Expand Up @@ -197,9 +168,6 @@ async def _execute_with_session(self, operation: Callable[[ClientSession], Await
error=str(e),
)
raise WorkloadConnectionError(f"MCP protocol error: {e}") from e
finally:
# Restore original URL
self.workload.url = original_url

async def _execute_streamable_session(
self, operation: Callable[[ClientSession], Awaitable]
Expand Down
176 changes: 176 additions & 0 deletions tests/test_mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,179 @@ def test_extract_error_from_exception_group():
result = client._extract_error_from_exception_group(eg)
assert "ValueError" in result
assert "Some error" in result


@pytest.fixture
def mock_mcp_session():
"""Create a mock MCP session for testing."""
mock_session = AsyncMock()
mock_list_result = AsyncMock()
mock_list_result.tools = []
mock_session.list_tools.return_value = mock_list_result

mock_call_result = AsyncMock()
mock_call_result.content = [AsyncMock(text="Tool result")]
mock_session.call_tool.return_value = mock_call_result

return mock_session


@pytest.mark.parametrize(
"url,proxy_mode",
[
("http://localhost:8080/sse/test-server", None),
("http://localhost:8080/mcp/test-server", "streamable-http"),
("http://localhost:8080/custom/endpoint", "sse"),
],
)
def test_workload_url_unchanged_after_init(url, proxy_mode):
"""Test that workload URL is not modified during MCPServerClient initialization."""
workload = Workload(
name="test-server",
url=url,
proxy_mode=proxy_mode,
status="running",
tool_type="mcp",
)

# Create client
_client = MCPServerClient(workload, timeout=10)

# Verify URL is unchanged
assert workload.url == url


@pytest.mark.asyncio
@pytest.mark.parametrize(
"url,client_mock_name,context_return",
[
(
"http://localhost:8080/mcp/test-server",
"streamablehttp_client",
(AsyncMock(), AsyncMock(), AsyncMock()),
),
("http://localhost:8080/sse/test-server", "sse_client", (AsyncMock(), AsyncMock())),
],
)
async def test_workload_url_unchanged_during_list_tools(
url, client_mock_name, context_return, mock_mcp_session
):
"""Test that workload URL remains unchanged during list_tools for both transport types."""
workload = Workload(
name="test-server",
url=url,
status="running",
tool_type="mcp",
)

client = MCPServerClient(workload, timeout=10)

with (
patch(f"mcp_optimizer.mcp_client.{client_mock_name}") as mock_client,
patch(
"mcp_optimizer.mcp_client.ClientSession", return_value=mock_mcp_session
) as mock_session_class,
):
# Mock the context manager
mock_client.return_value.__aenter__.return_value = context_return
mock_session_class.return_value.__aenter__.return_value = mock_mcp_session

# Call list_tools
await client.list_tools()

# Verify URL is unchanged in workload
assert workload.url == url

# Verify the client was called with the original URL
mock_client.assert_called_once_with(url)


@pytest.mark.asyncio
@pytest.mark.parametrize(
"url,proxy_mode,client_mock_name,context_return",
[
(
"http://localhost:8080/mcp/test-server",
None,
"streamablehttp_client",
(AsyncMock(), AsyncMock(), AsyncMock()),
),
(
"http://localhost:8080/custom/endpoint",
"streamable-http",
"streamablehttp_client",
(AsyncMock(), AsyncMock(), AsyncMock()),
),
],
)
async def test_workload_url_unchanged_during_call_tool(
url, proxy_mode, client_mock_name, context_return, mock_mcp_session
):
"""Test that workload URL remains unchanged during call_tool."""
workload = Workload(
name="test-server",
url=url,
proxy_mode=proxy_mode,
status="running",
tool_type="mcp",
)

client = MCPServerClient(workload, timeout=10)

with (
patch(f"mcp_optimizer.mcp_client.{client_mock_name}") as mock_client,
patch(
"mcp_optimizer.mcp_client.ClientSession", return_value=mock_mcp_session
) as mock_session_class,
):
# Mock the context manager
mock_client.return_value.__aenter__.return_value = context_return
mock_session_class.return_value.__aenter__.return_value = mock_mcp_session

# Call tool
await client.call_tool("test_tool", {"param": "value"})

# Verify URL is unchanged in workload
assert workload.url == url

# Verify the client was called with the original URL
mock_client.assert_called_once_with(url)


@pytest.mark.asyncio
async def test_workload_url_unchanged_multiple_operations(mock_mcp_session):
"""Test that workload URL remains unchanged across multiple operations."""
original_url = "http://localhost:8080/mcp/test-server"
workload = Workload(
name="test-server",
url=original_url,
status="running",
tool_type="mcp",
)

client = MCPServerClient(workload, timeout=10)

with (
patch("mcp_optimizer.mcp_client.streamablehttp_client") as mock_client,
patch(
"mcp_optimizer.mcp_client.ClientSession", return_value=mock_mcp_session
) as mock_session_class,
):
# Mock the context manager
mock_client.return_value.__aenter__.return_value = (AsyncMock(), AsyncMock(), AsyncMock())
mock_session_class.return_value.__aenter__.return_value = mock_mcp_session

# Perform multiple operations
await client.list_tools()
assert workload.url == original_url

await client.call_tool("test_tool", {"param": "value"})
assert workload.url == original_url

await client.list_tools()
assert workload.url == original_url

# Verify the client was always called with the original URL
assert mock_client.call_count == 3
for call in mock_client.call_args_list:
assert call[0][0] == original_url