Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ async def _get_tools_from_server(
server: MCPServer,
mcp_auth_header: Optional[Union[str, Dict[str, str]]] = None,
extra_headers: Optional[Dict[str, str]] = None,
add_prefix: bool = True,
) -> List[MCPTool]:
"""
Helper method to get tools from a single MCP server with prefixed names.
Expand All @@ -468,9 +469,11 @@ async def _get_tools_from_server(

tools = await self._fetch_tools_with_timeout(client, server.name)

prefixed_tools = self._create_prefixed_tools(tools, server)
prefixed_or_original_tools = self._create_prefixed_tools(
tools, server, add_prefix=add_prefix
)

return prefixed_tools
return prefixed_or_original_tools

except Exception as e:
verbose_logger.warning(
Expand Down Expand Up @@ -539,7 +542,7 @@ async def _list_tools_task():
return []

def _create_prefixed_tools(
self, tools: List[MCPTool], server: MCPServer
self, tools: List[MCPTool], server: MCPServer, add_prefix: bool = True
) -> List[MCPTool]:
"""
Create prefixed tools and update tool mapping.
Expand All @@ -557,14 +560,16 @@ def _create_prefixed_tools(
for tool in tools:
prefixed_name = add_server_prefix_to_tool_name(tool.name, prefix)

prefixed_tool = MCPTool(
name=prefixed_name,
name_to_use = prefixed_name if add_prefix else tool.name

tool_obj = MCPTool(
name=name_to_use,
description=tool.description,
inputSchema=tool.inputSchema,
)
prefixed_tools.append(prefixed_tool)
prefixed_tools.append(tool_obj)

# Update tool to server mapping with both original and prefixed names
# Update tool to server mapping for resolution (support both forms)
self.tool_name_to_mcp_server_name_mapping[tool.name] = prefix
self.tool_name_to_mcp_server_name_mapping[prefixed_name] = prefix

Expand Down
1 change: 1 addition & 0 deletions litellm/proxy/_experimental/mcp_server/rest_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ async def _get_tools_for_single_server(server, server_auth_header):
tools = await global_mcp_server_manager._get_tools_from_server(
server=server,
mcp_auth_header=server_auth_header,
add_prefix=False,
)
return _create_tool_response_objects(tools, server.mcp_info)

Expand Down
16 changes: 10 additions & 6 deletions litellm/proxy/_experimental/mcp_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,9 @@ async def _get_tools_from_mcp_servers(
allowed_mcp_servers=allowed_mcp_servers,
)

# Decide whether to add prefix based on number of allowed servers
add_prefix = not (len(allowed_mcp_servers) == 1)

# Get tools from each allowed server
all_tools = []
for server_id in allowed_mcp_servers:
Expand Down Expand Up @@ -448,6 +451,7 @@ async def _get_tools_from_mcp_servers(
server=server,
mcp_auth_header=server_auth_header,
extra_headers=extra_headers,
add_prefix=add_prefix,
)
all_tools.extend(filter_tools_by_allowed_tools(tools, server))
verbose_logger.debug(
Expand Down Expand Up @@ -567,16 +571,16 @@ async def call_mcp_tool(
"litellm_logging_obj", None
)
if litellm_logging_obj:
litellm_logging_obj.model_call_details["mcp_tool_call_metadata"] = (
standard_logging_mcp_tool_call
)
litellm_logging_obj.model_call_details[
"mcp_tool_call_metadata"
] = standard_logging_mcp_tool_call
litellm_logging_obj.model = f"MCP: {name}"
# Try managed server tool first (pass the full prefixed name)
# Primary and recommended way to use MCP servers
#########################################################
mcp_server: Optional[MCPServer] = (
global_mcp_server_manager._get_mcp_server_from_tool_name(name)
)
mcp_server: Optional[
MCPServer
] = global_mcp_server_manager._get_mcp_server_from_tool_name(name)
if mcp_server:
standard_logging_mcp_tool_call["mcp_server_cost_info"] = (
mcp_server.mcp_info or {}
Expand Down
7 changes: 3 additions & 4 deletions tests/mcp_tests/test_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,9 +707,7 @@ def mock_client_constructor(*args, **kwargs):

assert isinstance(response, dict)
assert len(response["tools"]) == 1
# The server should use the server_name as prefix since no alias is provided
expected_prefix = "test_server"
assert response["tools"][0].name == f"{expected_prefix}-test_tool"
assert response["tools"][0].name == "test_tool"
finally:
# Restore original state
global_mcp_server_manager.registry = {}
Expand Down Expand Up @@ -800,7 +798,7 @@ def mock_get_server_by_id(server_id):
)
mock_manager_2.get_mcp_server_by_id = mock_get_server_by_id
mock_manager_2._get_tools_from_server = AsyncMock(
side_effect=lambda server, mcp_auth_header=None, extra_headers=None: (
side_effect=lambda server, mcp_auth_header=None, extra_headers=None, add_prefix=False: (
[mock_tool_1] if server.server_id == "server1_id" else [mock_tool_2]
)
)
Expand Down Expand Up @@ -1713,6 +1711,7 @@ async def test_get_tools_for_single_server():
mock_manager._get_tools_from_server.assert_called_once_with(
server=mock_server,
mcp_auth_header="Bearer test_token",
add_prefix=False,
)

# Verify the result
Expand Down
121 changes: 119 additions & 2 deletions tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def test_get_tools_from_mcp_servers_continues_when_one_server_fails():
)

async def mock_get_tools_from_server(
server, mcp_auth_header=None, extra_headers=None
server, mcp_auth_header=None, extra_headers=None, add_prefix=True
):
if server.name == "working_server":
# Working server returns tools
Expand Down Expand Up @@ -187,7 +187,7 @@ async def test_get_tools_from_mcp_servers_handles_all_servers_failing():
)

async def mock_get_tools_from_server(
server, mcp_auth_header=None, extra_headers=None
server, mcp_auth_header=None, extra_headers=None, add_prefix=True
):
# All servers fail
raise Exception(f"Server {server.name} connection failed")
Expand Down Expand Up @@ -564,3 +564,120 @@ async def mock_fetch_tools_with_timeout(client, server_name):
captured_client_args["extra_headers"]["Authorization"]
== "Bearer github_oauth_token_12345"
)

@pytest.mark.asyncio
async def test_list_tools_single_server_unprefixed_names():
"""When only one MCP server is allowed, list tools should return unprefixed names."""
try:
from litellm.proxy._experimental.mcp_server.server import (
_get_tools_from_mcp_servers,
set_auth_context,
)
except ImportError:
pytest.skip("MCP server not available")

# Mock user auth
user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user")
set_auth_context(user_api_key_auth)

# One allowed server
server = MagicMock()
server.server_id = "server1"
server.name = "Zapier MCP"
server.alias = "zapier"

# Mock manager: allow just one server and return a tool based on add_prefix flag
mock_manager = MagicMock()
mock_manager.get_allowed_mcp_servers = AsyncMock(return_value=["server1"])
mock_manager.get_mcp_server_by_id = (
lambda server_id: server if server_id == "server1" else None
)

async def mock_get_tools_from_server(
server, mcp_auth_header=None, extra_headers=None, add_prefix=False
):
tool = MagicMock()
tool.name = f"{server.alias}-toolA" if add_prefix else "toolA"
tool.description = "desc"
tool.inputSchema = {}
return [tool]

mock_manager._get_tools_from_server = mock_get_tools_from_server

with patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
mock_manager,
):
tools = await _get_tools_from_mcp_servers(
user_api_key_auth=user_api_key_auth,
mcp_auth_header=None,
mcp_servers=None,
mcp_server_auth_headers=None,
)

# Should be unprefixed since only one server is allowed
assert len(tools) == 1
assert tools[0].name == "toolA"


@pytest.mark.asyncio
async def test_list_tools_multiple_servers_prefixed_names():
"""When multiple MCP servers are allowed, list tools should return prefixed names."""
try:
from litellm.proxy._experimental.mcp_server.server import (
_get_tools_from_mcp_servers,
set_auth_context,
)
except ImportError:
pytest.skip("MCP server not available")

# Mock user auth
user_api_key_auth = UserAPIKeyAuth(api_key="test_key", user_id="test_user")
set_auth_context(user_api_key_auth)

# Two allowed servers
server1 = MagicMock()
server1.server_id = "server1"
server1.name = "Zapier MCP"
server1.alias = "zapier"

server2 = MagicMock()
server2.server_id = "server2"
server2.name = "Jira MCP"
server2.alias = "jira"

# Mock manager
mock_manager = MagicMock()
mock_manager.get_allowed_mcp_servers = AsyncMock(
return_value=["server1", "server2"]
)
mock_manager.get_mcp_server_by_id = (
lambda server_id: server1 if server_id == "server1" else server2
)

async def mock_get_tools_from_server(
server, mcp_auth_header=None, extra_headers=None, add_prefix=True
):
tool = MagicMock()
# When multiple servers, add_prefix should be True -> prefixed names
tool.name = f"{server.alias}-toolA" if add_prefix else "toolA"
tool.description = "desc"
tool.inputSchema = {}
return [tool]

mock_manager._get_tools_from_server = mock_get_tools_from_server

with patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
mock_manager,
):
tools = await _get_tools_from_mcp_servers(
user_api_key_auth=user_api_key_auth,
mcp_auth_header=None,
mcp_servers=None,
mcp_server_auth_headers=None,
)

# Should be prefixed since multiple servers are allowed
names = sorted([t.name for t in tools])
assert names == ["jira-toolA", "zapier-toolA"]
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,110 @@ async def test_pre_call_tool_check_allowed_tools_takes_precedence(self):
"Tool tool3 is not allowed for server test-server"
in exc_info.value.detail["error"]
)
async def test_get_tools_from_server_add_prefix(self):
"""Verify _get_tools_from_server respects add_prefix True/False."""
manager = MCPServerManager()

# Create a minimal server with alias used as prefix
server = MCPServer(
server_id="zapier",
name="zapier",
transport=MCPTransport.http,
)

# Mock client creation and fetching tools
manager._create_mcp_client = MagicMock(return_value=object())

# Tools returned upstream (unprefixed from provider)
upstream_tool = MagicMock()
upstream_tool.name = "send_email"
upstream_tool.description = "Send an email"
upstream_tool.inputSchema = {}

manager._fetch_tools_with_timeout = AsyncMock(return_value=[upstream_tool])

# Case 1: add_prefix=True (default for multi-server) -> expect prefixed
tools_prefixed = await manager._get_tools_from_server(server, add_prefix=True)
assert len(tools_prefixed) == 1
assert tools_prefixed[0].name == "zapier-send_email"

# Case 2: add_prefix=False (single-server) -> expect unprefixed
tools_unprefixed = await manager._get_tools_from_server(
server, add_prefix=False
)
assert len(tools_unprefixed) == 1
assert tools_unprefixed[0].name == "send_email"

def test_create_prefixed_tools_updates_mapping_for_both_forms(self):
"""_create_prefixed_tools should populate mapping for prefixed and original names even when not adding prefix in output."""
manager = MCPServerManager()

server = MCPServer(
server_id="jira",
name="jira",
transport=MCPTransport.http,
)

# Input tools as would come from upstream
t1 = MagicMock()
t1.name = "create_issue"
t1.description = ""
t1.inputSchema = {}
t2 = MagicMock()
t2.name = "close_issue"
t2.description = ""
t2.inputSchema = {}

# Do not add prefix in returned objects
out_tools = manager._create_prefixed_tools([t1, t2], server, add_prefix=False)

# Returned names should be unprefixed
names = sorted([t.name for t in out_tools])
assert names == ["close_issue", "create_issue"]

# Mapping should include both original and prefixed names -> resolves calls either way
assert manager.tool_name_to_mcp_server_name_mapping["create_issue"] == "jira"
assert (
manager.tool_name_to_mcp_server_name_mapping["jira-create_issue"] == "jira"
)
assert manager.tool_name_to_mcp_server_name_mapping["close_issue"] == "jira"
assert (
manager.tool_name_to_mcp_server_name_mapping["jira-close_issue"] == "jira"
)

def test_get_mcp_server_from_tool_name_with_prefixed_and_unprefixed(self):
"""After mapping is populated, manager resolves both prefixed and unprefixed tool names to the same server."""
manager = MCPServerManager()

server = MCPServer(
server_id="zapier",
name="zapier",
server_name="zapier",
transport=MCPTransport.http,
)

# Register server so resolution can find it
manager.registry = {server.server_id: server}

# Populate mapping (add_prefix value doesn't matter for mapping population)
base_tool = MagicMock()
base_tool.name = "create_zap"
base_tool.description = ""
base_tool.inputSchema = {}
_ = manager._create_prefixed_tools([base_tool], server, add_prefix=False)

# Unprefixed resolution
resolved_server_unpref = manager._get_mcp_server_from_tool_name("create_zap")
print(resolved_server_unpref)
assert resolved_server_unpref is not None
assert resolved_server_unpref.server_id == server.server_id

# Prefixed resolution
resolved_server_pref = manager._get_mcp_server_from_tool_name(
"zapier-create_zap"
)
assert resolved_server_pref is not None
assert resolved_server_pref.server_id == server.server_id


if __name__ == "__main__":
Expand Down
Loading