diff --git a/src/mcp_optimizer/mcp_client.py b/src/mcp_optimizer/mcp_client.py index 6dfbd12..300d23e 100644 --- a/src/mcp_optimizer/mcp_client.py +++ b/src/mcp_optimizer/mcp_client.py @@ -106,29 +106,6 @@ def _determine_proxy_mode(self) -> ToolHiveProxyMode: ) return url_to_toolhive_proxy_mode(self.workload.url) - def _prepare_url_for_proxy_mode(self, proxy_mode: ToolHiveProxyMode) -> str: - """ - Prepare the URL with the correct path for the transport type. - - Args: - proxy_mode: The proxy mode being used - - Returns: - The URL with the correct path appended - """ - url = self.workload.url - if proxy_mode == ToolHiveProxyMode.STREAMABLE: - # Streamable HTTP expects /mcp path - if not url.endswith("/mcp"): - url = url.rstrip("/") + "/mcp" - logger.debug(f"Appended /mcp path to URL: {url}", workload=self.workload.name) - elif proxy_mode == ToolHiveProxyMode.SSE: - # SSE expects /sse path - if not url.endswith("/sse"): - url = url.rstrip("/") + "/sse" - logger.debug(f"Appended /sse path to URL: {url}", workload=self.workload.name) - return url - async def _execute_with_session(self, operation: Callable[[ClientSession], Awaitable]) -> Any: """ Execute an operation with an MCP session. @@ -151,15 +128,9 @@ async def _execute_with_session(self, operation: Callable[[ClientSession], Await f"Using {proxy_mode} client for workload", workload=self.workload.name, proxy_mode_field=self.workload.proxy_mode, + url=self.workload.url, ) - # Prepare URL with correct path for the transport type - url = self._prepare_url_for_proxy_mode(proxy_mode) - - # Temporarily override the URL for this connection - original_url = self.workload.url - self.workload.url = url - try: if proxy_mode == ToolHiveProxyMode.STREAMABLE: return await asyncio.wait_for( @@ -197,9 +168,6 @@ async def _execute_with_session(self, operation: Callable[[ClientSession], Await error=str(e), ) raise WorkloadConnectionError(f"MCP protocol error: {e}") from e - finally: - # Restore original URL - self.workload.url = original_url async def _execute_streamable_session( self, operation: Callable[[ClientSession], Awaitable] diff --git a/tests/test_mcp_client.py b/tests/test_mcp_client.py index 062701d..2cced53 100644 --- a/tests/test_mcp_client.py +++ b/tests/test_mcp_client.py @@ -236,3 +236,179 @@ def test_extract_error_from_exception_group(): result = client._extract_error_from_exception_group(eg) assert "ValueError" in result assert "Some error" in result + + +@pytest.fixture +def mock_mcp_session(): + """Create a mock MCP session for testing.""" + mock_session = AsyncMock() + mock_list_result = AsyncMock() + mock_list_result.tools = [] + mock_session.list_tools.return_value = mock_list_result + + mock_call_result = AsyncMock() + mock_call_result.content = [AsyncMock(text="Tool result")] + mock_session.call_tool.return_value = mock_call_result + + return mock_session + + +@pytest.mark.parametrize( + "url,proxy_mode", + [ + ("http://localhost:8080/sse/test-server", None), + ("http://localhost:8080/mcp/test-server", "streamable-http"), + ("http://localhost:8080/custom/endpoint", "sse"), + ], +) +def test_workload_url_unchanged_after_init(url, proxy_mode): + """Test that workload URL is not modified during MCPServerClient initialization.""" + workload = Workload( + name="test-server", + url=url, + proxy_mode=proxy_mode, + status="running", + tool_type="mcp", + ) + + # Create client + _client = MCPServerClient(workload, timeout=10) + + # Verify URL is unchanged + assert workload.url == url + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "url,client_mock_name,context_return", + [ + ( + "http://localhost:8080/mcp/test-server", + "streamablehttp_client", + (AsyncMock(), AsyncMock(), AsyncMock()), + ), + ("http://localhost:8080/sse/test-server", "sse_client", (AsyncMock(), AsyncMock())), + ], +) +async def test_workload_url_unchanged_during_list_tools( + url, client_mock_name, context_return, mock_mcp_session +): + """Test that workload URL remains unchanged during list_tools for both transport types.""" + workload = Workload( + name="test-server", + url=url, + status="running", + tool_type="mcp", + ) + + client = MCPServerClient(workload, timeout=10) + + with ( + patch(f"mcp_optimizer.mcp_client.{client_mock_name}") as mock_client, + patch( + "mcp_optimizer.mcp_client.ClientSession", return_value=mock_mcp_session + ) as mock_session_class, + ): + # Mock the context manager + mock_client.return_value.__aenter__.return_value = context_return + mock_session_class.return_value.__aenter__.return_value = mock_mcp_session + + # Call list_tools + await client.list_tools() + + # Verify URL is unchanged in workload + assert workload.url == url + + # Verify the client was called with the original URL + mock_client.assert_called_once_with(url) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "url,proxy_mode,client_mock_name,context_return", + [ + ( + "http://localhost:8080/mcp/test-server", + None, + "streamablehttp_client", + (AsyncMock(), AsyncMock(), AsyncMock()), + ), + ( + "http://localhost:8080/custom/endpoint", + "streamable-http", + "streamablehttp_client", + (AsyncMock(), AsyncMock(), AsyncMock()), + ), + ], +) +async def test_workload_url_unchanged_during_call_tool( + url, proxy_mode, client_mock_name, context_return, mock_mcp_session +): + """Test that workload URL remains unchanged during call_tool.""" + workload = Workload( + name="test-server", + url=url, + proxy_mode=proxy_mode, + status="running", + tool_type="mcp", + ) + + client = MCPServerClient(workload, timeout=10) + + with ( + patch(f"mcp_optimizer.mcp_client.{client_mock_name}") as mock_client, + patch( + "mcp_optimizer.mcp_client.ClientSession", return_value=mock_mcp_session + ) as mock_session_class, + ): + # Mock the context manager + mock_client.return_value.__aenter__.return_value = context_return + mock_session_class.return_value.__aenter__.return_value = mock_mcp_session + + # Call tool + await client.call_tool("test_tool", {"param": "value"}) + + # Verify URL is unchanged in workload + assert workload.url == url + + # Verify the client was called with the original URL + mock_client.assert_called_once_with(url) + + +@pytest.mark.asyncio +async def test_workload_url_unchanged_multiple_operations(mock_mcp_session): + """Test that workload URL remains unchanged across multiple operations.""" + original_url = "http://localhost:8080/mcp/test-server" + workload = Workload( + name="test-server", + url=original_url, + status="running", + tool_type="mcp", + ) + + client = MCPServerClient(workload, timeout=10) + + with ( + patch("mcp_optimizer.mcp_client.streamablehttp_client") as mock_client, + patch( + "mcp_optimizer.mcp_client.ClientSession", return_value=mock_mcp_session + ) as mock_session_class, + ): + # Mock the context manager + mock_client.return_value.__aenter__.return_value = (AsyncMock(), AsyncMock(), AsyncMock()) + mock_session_class.return_value.__aenter__.return_value = mock_mcp_session + + # Perform multiple operations + await client.list_tools() + assert workload.url == original_url + + await client.call_tool("test_tool", {"param": "value"}) + assert workload.url == original_url + + await client.list_tools() + assert workload.url == original_url + + # Verify the client was always called with the original URL + assert mock_client.call_count == 3 + for call in mock_client.call_args_list: + assert call[0][0] == original_url