From dae7d08ff2a705d6d6bbe081bd284f829f2ec0ff Mon Sep 17 00:00:00 2001 From: Yuta Saito Date: Sun, 28 Sep 2025 08:23:42 +0900 Subject: [PATCH 1/3] Revert "Revert "Merge pull request #14720 from uc4w6c/feat/remove-servername-prefix-mcp_tools"" This reverts commit a88d774f9467d09c3ce638759a268968d9823001. --- .../mcp_server/mcp_server_manager.py | 19 ++- .../mcp_server/rest_endpoints.py | 1 + .../proxy/_experimental/mcp_server/server.py | 4 + .../mcp_server/test_mcp_server.py | 121 +++++++++++++++++- .../mcp_server/test_mcp_server_manager.py | 104 +++++++++++++++ 5 files changed, 240 insertions(+), 9 deletions(-) diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 6423a9ae1531..3cf1f764d139 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -443,6 +443,7 @@ async def _get_tools_from_server( server: MCPServer, mcp_auth_header: Optional[Union[str, Dict[str, str]]] = None, extra_headers: Optional[Dict[str, str]] = None, + add_prefix: bool = True, ) -> List[MCPTool]: """ Helper method to get tools from a single MCP server with prefixed names. @@ -468,9 +469,11 @@ async def _get_tools_from_server( tools = await self._fetch_tools_with_timeout(client, server.name) - prefixed_tools = self._create_prefixed_tools(tools, server) + prefixed_or_original_tools = self._create_prefixed_tools( + tools, server, add_prefix=add_prefix + ) - return prefixed_tools + return prefixed_or_original_tools except Exception as e: verbose_logger.warning( @@ -539,7 +542,7 @@ async def _list_tools_task(): return [] def _create_prefixed_tools( - self, tools: List[MCPTool], server: MCPServer + self, tools: List[MCPTool], server: MCPServer, add_prefix: bool = True ) -> List[MCPTool]: """ Create prefixed tools and update tool mapping. @@ -557,14 +560,16 @@ def _create_prefixed_tools( for tool in tools: prefixed_name = add_server_prefix_to_tool_name(tool.name, prefix) - prefixed_tool = MCPTool( - name=prefixed_name, + name_to_use = prefixed_name if add_prefix else tool.name + + tool_obj = MCPTool( + name=name_to_use, description=tool.description, inputSchema=tool.inputSchema, ) - prefixed_tools.append(prefixed_tool) + prefixed_tools.append(tool_obj) - # Update tool to server mapping with both original and prefixed names + # Update tool to server mapping for resolution (support both forms) self.tool_name_to_mcp_server_name_mapping[tool.name] = prefix self.tool_name_to_mcp_server_name_mapping[prefixed_name] = prefix diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index ef770a9d4332..236dfdb6cf8c 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -73,6 +73,7 @@ async def _get_tools_for_single_server(server, server_auth_header): tools = await global_mcp_server_manager._get_tools_from_server( server=server, mcp_auth_header=server_auth_header, + add_prefix=False, ) return _create_tool_response_objects(tools, server.mcp_info) diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 94f81ccef8d3..9c2551fec218 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -414,6 +414,9 @@ async def _get_tools_from_mcp_servers( allowed_mcp_servers=allowed_mcp_servers, ) + # Decide whether to add prefix based on number of allowed servers + add_prefix = not (len(allowed_mcp_servers) == 1) + # Get tools from each allowed server all_tools = [] for server_id in allowed_mcp_servers: @@ -448,6 +451,7 @@ async def _get_tools_from_mcp_servers( server=server, mcp_auth_header=server_auth_header, extra_headers=extra_headers, + add_prefix=add_prefix, ) all_tools.extend(filter_tools_by_allowed_tools(tools, server)) verbose_logger.debug( diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py index 10a6ab8cb01e..6a1d43b33e1e 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py @@ -103,7 +103,7 @@ async def test_get_tools_from_mcp_servers_continues_when_one_server_fails(): ) async def mock_get_tools_from_server( - server, mcp_auth_header=None, extra_headers=None + server, mcp_auth_header=None, extra_headers=None, add_prefix=True ): if server.name == "working_server": # Working server returns tools @@ -187,7 +187,7 @@ async def test_get_tools_from_mcp_servers_handles_all_servers_failing(): ) async def mock_get_tools_from_server( - server, mcp_auth_header=None, extra_headers=None + server, mcp_auth_header=None, extra_headers=None, add_prefix=True ): # All servers fail raise Exception(f"Server {server.name} connection failed") @@ -564,3 +564,120 @@ async def mock_fetch_tools_with_timeout(client, server_name): captured_client_args["extra_headers"]["Authorization"] == "Bearer github_oauth_token_12345" ) + +@pytest.mark.asyncio +async def test_list_tools_single_server_unprefixed_names(): + """When only one MCP server is allowed, list tools should return unprefixed names.""" + try: + from litellm.proxy._experimental.mcp_server.server import ( + _get_tools_from_mcp_servers, + set_auth_context, + ) + except ImportError: + pytest.skip("MCP server not available") + + # Mock user auth + user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user") + set_auth_context(user_api_key_auth) + + # One allowed server + server = MagicMock() + server.server_id = "server1" + server.name = "Zapier MCP" + server.alias = "zapier" + + # Mock manager: allow just one server and return a tool based on add_prefix flag + mock_manager = MagicMock() + mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"]) + mock_manager.get_mcp_server_by_id = ( + lambda server_id: server if server_id == "server1" else None + ) + + async def mock_get_tools_from_server( + server, mcp_auth_header=None, extra_headers=None, add_prefix=False + ): + tool = MagicMock() + tool.name = f"{server.alias}-toolA" if add_prefix else "toolA" + tool.description = "desc" + tool.inputSchema = {} + return [tool] + + mock_manager._get_tools_from_server = mock_get_tools_from_server + + with patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", + mock_manager, + ): + tools = await _get_tools_from_mcp_servers( + user_api_key_auth=user_api_key_auth, + mcp_auth_header=None, + mcp_servers=None, + mcp_server_auth_headers=None, + ) + + # Should be unprefixed since only one server is allowed + assert len(tools) == 1 + assert tools[0].name == "toolA" + + +@pytest.mark.asyncio +async def test_list_tools_multiple_servers_prefixed_names(): + """When multiple MCP servers are allowed, list tools should return prefixed names.""" + try: + from litellm.proxy._experimental.mcp_server.server import ( + _get_tools_from_mcp_servers, + set_auth_context, + ) + except ImportError: + pytest.skip("MCP server not available") + + # Mock user auth + user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user") + set_auth_context(user_api_key_auth) + + # Two allowed servers + server1 = MagicMock() + server1.server_id = "server1" + server1.name = "Zapier MCP" + server1.alias = "zapier" + + server2 = MagicMock() + server2.server_id = "server2" + server2.name = "Jira MCP" + server2.alias = "jira" + + # Mock manager + mock_manager = MagicMock() + mock_manager.get_allowed_mcp_servers = AsyncMock( + return_value=["server1", "server2"] + ) + mock_manager.get_mcp_server_by_id = ( + lambda server_id: server1 if server_id == "server1" else server2 + ) + + async def mock_get_tools_from_server( + server, mcp_auth_header=None, extra_headers=None, add_prefix=True + ): + tool = MagicMock() + # When multiple servers, add_prefix should be True -> prefixed names + tool.name = f"{server.alias}-toolA" if add_prefix else "toolA" + tool.description = "desc" + tool.inputSchema = {} + return [tool] + + mock_manager._get_tools_from_server = mock_get_tools_from_server + + with patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", + mock_manager, + ): + tools = await _get_tools_from_mcp_servers( + user_api_key_auth=user_api_key_auth, + mcp_auth_header=None, + mcp_servers=None, + mcp_server_auth_headers=None, + ) + + # Should be prefixed since multiple servers are allowed + names = sorted([t.name for t in tools]) + assert names == ["jira-toolA", "zapier-toolA"] diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py index 126a4ebb896a..80ea95a22102 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py @@ -654,6 +654,110 @@ async def test_pre_call_tool_check_allowed_tools_takes_precedence(self): "Tool tool3 is not allowed for server test-server" in exc_info.value.detail["error"] ) + async def test_get_tools_from_server_add_prefix(self): + """Verify _get_tools_from_server respects add_prefix True/False.""" + manager = MCPServerManager() + + # Create a minimal server with alias used as prefix + server = MCPServer( + server_id="zapier", + name="zapier", + transport=MCPTransport.http, + ) + + # Mock client creation and fetching tools + manager._create_mcp_client = MagicMock(return_value=object()) + + # Tools returned upstream (unprefixed from provider) + upstream_tool = MagicMock() + upstream_tool.name = "send_email" + upstream_tool.description = "Send an email" + upstream_tool.inputSchema = {} + + manager._fetch_tools_with_timeout = AsyncMock(return_value=[upstream_tool]) + + # Case 1: add_prefix=True (default for multi-server) -> expect prefixed + tools_prefixed = await manager._get_tools_from_server(server, add_prefix=True) + assert len(tools_prefixed) == 1 + assert tools_prefixed[0].name == "zapier-send_email" + + # Case 2: add_prefix=False (single-server) -> expect unprefixed + tools_unprefixed = await manager._get_tools_from_server( + server, add_prefix=False + ) + assert len(tools_unprefixed) == 1 + assert tools_unprefixed[0].name == "send_email" + + def test_create_prefixed_tools_updates_mapping_for_both_forms(self): + """_create_prefixed_tools should populate mapping for prefixed and original names even when not adding prefix in output.""" + manager = MCPServerManager() + + server = MCPServer( + server_id="jira", + name="jira", + transport=MCPTransport.http, + ) + + # Input tools as would come from upstream + t1 = MagicMock() + t1.name = "create_issue" + t1.description = "" + t1.inputSchema = {} + t2 = MagicMock() + t2.name = "close_issue" + t2.description = "" + t2.inputSchema = {} + + # Do not add prefix in returned objects + out_tools = manager._create_prefixed_tools([t1, t2], server, add_prefix=False) + + # Returned names should be unprefixed + names = sorted([t.name for t in out_tools]) + assert names == ["close_issue", "create_issue"] + + # Mapping should include both original and prefixed names -> resolves calls either way + assert manager.tool_name_to_mcp_server_name_mapping["create_issue"] == "jira" + assert ( + manager.tool_name_to_mcp_server_name_mapping["jira-create_issue"] == "jira" + ) + assert manager.tool_name_to_mcp_server_name_mapping["close_issue"] == "jira" + assert ( + manager.tool_name_to_mcp_server_name_mapping["jira-close_issue"] == "jira" + ) + + def test_get_mcp_server_from_tool_name_with_prefixed_and_unprefixed(self): + """After mapping is populated, manager resolves both prefixed and unprefixed tool names to the same server.""" + manager = MCPServerManager() + + server = MCPServer( + server_id="zapier", + name="zapier", + server_name="zapier", + transport=MCPTransport.http, + ) + + # Register server so resolution can find it + manager.registry = {server.server_id: server} + + # Populate mapping (add_prefix value doesn't matter for mapping population) + base_tool = MagicMock() + base_tool.name = "create_zap" + base_tool.description = "" + base_tool.inputSchema = {} + _ = manager._create_prefixed_tools([base_tool], server, add_prefix=False) + + # Unprefixed resolution + resolved_server_unpref = manager._get_mcp_server_from_tool_name("create_zap") + print(resolved_server_unpref) + assert resolved_server_unpref is not None + assert resolved_server_unpref.server_id == server.server_id + + # Prefixed resolution + resolved_server_pref = manager._get_mcp_server_from_tool_name( + "zapier-create_zap" + ) + assert resolved_server_pref is not None + assert resolved_server_pref.server_id == server.server_id if __name__ == "__main__": From 4f15d1240ee0d265dffc530f0148919b56d24b11 Mon Sep 17 00:00:00 2001 From: Yuta Saito Date: Sun, 28 Sep 2025 08:45:38 +0900 Subject: [PATCH 2/3] fix: resolve linting errors --- litellm/proxy/_experimental/mcp_server/server.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 9c2551fec218..6802a14fc47d 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -571,16 +571,16 @@ async def call_mcp_tool( "litellm_logging_obj", None ) if litellm_logging_obj: - litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = ( - standard_logging_mcp_tool_call - ) + litellm_logging_obj.model_call_details[ + "mcp_tool_call_metadata" + ] = standard_logging_mcp_tool_call litellm_logging_obj.model = f"MCP: {name}" # Try managed server tool first (pass the full prefixed name) # Primary and recommended way to use MCP servers ######################################################### - mcp_server: Optional[MCPServer] = ( - global_mcp_server_manager._get_mcp_server_from_tool_name(name) - ) + mcp_server: Optional[ + MCPServer + ] = global_mcp_server_manager._get_mcp_server_from_tool_name(name) if mcp_server: standard_logging_mcp_tool_call["mcp_server_cost_info"] = ( mcp_server.mcp_info or {} From be5010a67532e076321dfce8204b3b2c701f56d6 Mon Sep 17 00:00:00 2001 From: Yuta Saito Date: Mon, 29 Sep 2025 07:21:32 +0900 Subject: [PATCH 3/3] test: fix tests --- tests/mcp_tests/test_mcp_server.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py index 2dfe39777fd5..cc7cd4cd82e1 100644 --- a/tests/mcp_tests/test_mcp_server.py +++ b/tests/mcp_tests/test_mcp_server.py @@ -707,9 +707,7 @@ def mock_client_constructor(*args, **kwargs): assert isinstance(response, dict) assert len(response["tools"]) == 1 - # The server should use the server_name as prefix since no alias is provided - expected_prefix = "test_server" - assert response["tools"][0].name == f"{expected_prefix}-test_tool" + assert response["tools"][0].name == "test_tool" finally: # Restore original state global_mcp_server_manager.registry = {} @@ -800,7 +798,7 @@ def mock_get_server_by_id(server_id): ) mock_manager_2.get_mcp_server_by_id = mock_get_server_by_id mock_manager_2._get_tools_from_server = AsyncMock( - side_effect=lambda server, mcp_auth_header=None, extra_headers=None: ( + side_effect=lambda server, mcp_auth_header=None, extra_headers=None, add_prefix=False: ( [mock_tool_1] if server.server_id == "server1_id" else [mock_tool_2] ) ) @@ -1713,6 +1711,7 @@ async def test_get_tools_for_single_server(): mock_manager._get_tools_from_server.assert_called_once_with( server=mock_server, mcp_auth_header="Bearer test_token", + add_prefix=False, ) # Verify the result