Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- AlterTable
ALTER TABLE "LiteLLM_MCPServerTable" ADD COLUMN "allowed_tools" TEXT[] DEFAULT ARRAY[]::TEXT[];

Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-- AlterTable
ALTER TABLE "LiteLLM_MCPServerTable" ADD COLUMN "extra_headers" TEXT[] DEFAULT ARRAY[]::TEXT[];

1 change: 1 addition & 0 deletions litellm-proxy-extras/litellm_proxy_extras/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ model LiteLLM_MCPServerTable {
mcp_info Json? @default("{}")
mcp_access_groups String[]
allowed_tools String[] @default([])
extra_headers String[] @default([])
// Health check status
status String? @default("unknown")
last_health_check DateTime?
Expand Down
4 changes: 2 additions & 2 deletions litellm/proxy/_experimental/mcp_server/db.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from litellm._uuid import uuid
from typing import Any, Dict, Iterable, List, Optional, Set, Union

from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.proxy._types import (
LiteLLM_MCPServerTable,
LiteLLM_ObjectPermissionTable,
Expand Down Expand Up @@ -30,7 +30,7 @@ def _prepare_mcp_server_data(
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

# Convert model to dict
data_dict = data.model_dump()
data_dict = data.model_dump(exclude_none=True)
# Ensure alias is always present in the dict (even if None)
if "alias" not in data_dict:
data_dict["alias"] = getattr(data, "alias", None)
Expand Down
189 changes: 88 additions & 101 deletions litellm/proxy/_experimental/mcp_server/mcp_server_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import datetime
import hashlib
import json
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Optional, Set, Union, cast

from fastapi import HTTPException
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
Expand Down Expand Up @@ -240,50 +240,64 @@ def remove_server(self, mcp_server: LiteLLM_MCPServerTable):
)

def add_update_server(self, mcp_server: LiteLLM_MCPServerTable):
if mcp_server.server_id not in self.get_registry():
_mcp_info: MCPInfo = mcp_server.mcp_info or {}
# Use helper to deserialize environment dictionary
# Safely access env field which may not exist on Prisma model objects
env_data = getattr(mcp_server, "env", None)
env_dict = _deserialize_env_dict(env_data)
# Use alias for name if present, else server_name
name_for_prefix = (
mcp_server.alias or mcp_server.server_name or mcp_server.server_id
)
# Preserve all custom fields from database while setting defaults for core fields
mcp_info: MCPInfo = _mcp_info.copy()
# Set default values for core fields if not present
if "server_name" not in mcp_info:
mcp_info["server_name"] = mcp_server.server_name or mcp_server.server_id
if "description" not in mcp_info and mcp_server.description:
mcp_info["description"] = mcp_server.description
try:
if mcp_server.server_id not in self.get_registry():
_mcp_info: MCPInfo = mcp_server.mcp_info or {}
# Use helper to deserialize environment dictionary
# Safely access env field which may not exist on Prisma model objects
env_data = getattr(mcp_server, "env", None)
env_dict = _deserialize_env_dict(env_data)
# Use alias for name if present, else server_name
name_for_prefix = (
mcp_server.alias or mcp_server.server_name or mcp_server.server_id
)
# Preserve all custom fields from database while setting defaults for core fields
mcp_info: MCPInfo = _mcp_info.copy()
# Set default values for core fields if not present
if "server_name" not in mcp_info:
mcp_info["server_name"] = (
mcp_server.server_name or mcp_server.server_id
)
if "description" not in mcp_info and mcp_server.description:
mcp_info["description"] = mcp_server.description

new_server = MCPServer(
server_id=mcp_server.server_id,
name=name_for_prefix,
alias=getattr(mcp_server, "alias", None),
server_name=getattr(mcp_server, "server_name", None),
url=mcp_server.url,
transport=cast(MCPTransportType, mcp_server.transport),
auth_type=cast(MCPAuthType, mcp_server.auth_type),
mcp_info=mcp_info,
extra_headers=getattr(mcp_server, "extra_headers", None),
# oauth specific fields
client_id=getattr(mcp_server, "client_id", None),
client_secret=getattr(mcp_server, "client_secret", None),
scopes=getattr(mcp_server, "scopes", None),
authorization_url=getattr(mcp_server, "authorization_url", None),
token_url=getattr(mcp_server, "token_url", None),
# Stdio-specific fields
command=getattr(mcp_server, "command", None),
args=getattr(mcp_server, "args", None) or [],
env=env_dict,
access_groups=getattr(mcp_server, "mcp_access_groups", None),
allowed_tools=getattr(mcp_server, "allowed_tools", None),
disallowed_tools=getattr(mcp_server, "disallowed_tools", None),
)
self.registry[mcp_server.server_id] = new_server
verbose_logger.debug(f"Added MCP Server: {name_for_prefix}")

new_server = MCPServer(
server_id=mcp_server.server_id,
name=name_for_prefix,
alias=getattr(mcp_server, "alias", None),
server_name=getattr(mcp_server, "server_name", None),
url=mcp_server.url,
transport=cast(MCPTransportType, mcp_server.transport),
auth_type=cast(MCPAuthType, mcp_server.auth_type),
mcp_info=mcp_info,
extra_headers=getattr(mcp_server, "extra_headers", None),
# oauth specific fields
client_id=getattr(mcp_server, "client_id", None),
client_secret=getattr(mcp_server, "client_secret", None),
scopes=getattr(mcp_server, "scopes", None),
authorization_url=getattr(mcp_server, "authorization_url", None),
token_url=getattr(mcp_server, "token_url", None),
# Stdio-specific fields
command=getattr(mcp_server, "command", None),
args=getattr(mcp_server, "args", None) or [],
env=env_dict,
access_groups=getattr(mcp_server, "mcp_access_groups", None),
allowed_tools=getattr(mcp_server, "allowed_tools", None),
disallowed_tools=getattr(mcp_server, "disallowed_tools", None),
)
self.registry[mcp_server.server_id] = new_server
verbose_logger.debug(f"Added MCP Server: {name_for_prefix}")
except Exception as e:
verbose_logger.debug(f"Failed to add MCP server: {str(e)}")
raise e

def get_all_mcp_server_ids(self) -> Set[str]:
"""
Get all MCP server IDs
"""
all_servers = list(self.get_registry().values())
return {server.server_id for server in all_servers}

async def get_allowed_mcp_servers(
self, user_api_key_auth: Optional[UserAPIKeyAuth] = None
Expand Down Expand Up @@ -1118,25 +1132,23 @@ async def get_all_mcp_servers_with_health_and_teams(
if _server_id in allowed_server_ids:
list_mcp_servers.append(
LiteLLM_MCPServerTable(
server_id=_server_id,
server_name=_server_config.name,
alias=_server_config.alias,
url=_server_config.url,
transport=_server_config.transport,
auth_type=_server_config.auth_type,
created_at=datetime.datetime.now(),
updated_at=datetime.datetime.now(),
description=(
_server_config.mcp_info.get("description")
if _server_config.mcp_info
else None
),
mcp_info=_server_config.mcp_info,
mcp_access_groups=_server_config.access_groups or [],
# Stdio-specific fields
command=getattr(_server_config, "command", None),
args=getattr(_server_config, "args", None) or [],
env=getattr(_server_config, "env", None) or {},
**{
**_server_config.model_dump(),
"created_at": datetime.datetime.now(),
"updated_at": datetime.datetime.now(),
"description": (
_server_config.mcp_info.get("description")
if _server_config.mcp_info
else None
),
"allowed_tools": _server_config.allowed_tools or [],
"mcp_info": _server_config.mcp_info,
"mcp_access_groups": _server_config.access_groups or [],
"extra_headers": _server_config.extra_headers or [],
"command": getattr(_server_config, "command", None),
"args": getattr(_server_config, "args", None) or [],
"env": getattr(_server_config, "env", None) or {},
}
)
)

Expand Down Expand Up @@ -1176,44 +1188,19 @@ async def get_all_mcp_servers_with_health_and_teams(
}
)

# Map servers to their teams and return with health data
from typing import cast

return [
LiteLLM_MCPServerTable(
server_id=server.server_id,
server_name=server.server_name,
alias=server.alias,
description=server.description,
url=server.url,
transport=server.transport,
auth_type=server.auth_type,
created_at=server.created_at,
created_by=server.created_by,
updated_at=server.updated_at,
updated_by=server.updated_by,
mcp_access_groups=(
server.mcp_access_groups
if server.mcp_access_groups is not None
else []
),
allowed_tools=(
server.allowed_tools
if server.allowed_tools is not None
else []
),
mcp_info=server.mcp_info,
teams=cast(
List[Dict[str, str | None]],
server_to_teams_map.get(server.server_id, []),
),
# Stdio-specific fields
command=getattr(server, "command", None),
args=getattr(server, "args", None) or [],
env=getattr(server, "env", None) or {},
)
for server in list_mcp_servers
]
## mark invalid servers w/ reason for being invalid
valid_server_ids = self.get_all_mcp_server_ids()
for server in list_mcp_servers:
if server.server_id not in valid_server_ids:
server.status = "unhealthy"
## try adding server to registry to get error
try:
self.add_update_server(server)
except Exception as e:
server.health_check_error = str(e)
server.health_check_error = "Server is not in in memory registry yet. This could be a temporary sync issue."

return list_mcp_servers

async def reload_servers_from_database(self):
"""
Expand Down
4 changes: 3 additions & 1 deletion litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,7 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase):
mcp_info: Optional[MCPInfo] = None
mcp_access_groups: List[str] = Field(default_factory=list)
allowed_tools: Optional[List[str]] = None
extra_headers: Optional[List[str]] = None
# Stdio-specific fields
command: Optional[str] = None
args: List[str] = Field(default_factory=list)
Expand Down Expand Up @@ -994,9 +995,10 @@ class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase):
teams: List[Dict[str, Optional[str]]] = Field(default_factory=list)
mcp_access_groups: List[str] = Field(default_factory=list)
allowed_tools: List[str] = Field(default_factory=list)
extra_headers: List[str] = Field(default_factory=list)
mcp_info: Optional[MCPInfo] = None
# Health check status
status: Optional[str] = Field(
status: Optional[Literal["healthy", "unhealthy", "unknown"]] = Field(
default="unknown",
description="Health status: 'healthy', 'unhealthy', 'unknown'",
)
Expand Down
1 change: 1 addition & 0 deletions litellm/proxy/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ model LiteLLM_MCPServerTable {
mcp_info Json? @default("{}")
mcp_access_groups String[]
allowed_tools String[] @default([])
extra_headers String[] @default([])
// Health check status
status String? @default("unknown")
last_health_check DateTime?
Expand Down
1 change: 1 addition & 0 deletions schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ model LiteLLM_MCPServerTable {
mcp_info Json? @default("{}")
mcp_access_groups String[]
allowed_tools String[] @default([])
extra_headers String[] @default([])
// Health check status
status String? @default("unknown")
last_health_check DateTime?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ async def test_pre_call_tool_check_allowed_tools_takes_precedence(self):
"Tool tool3 is not allowed for server test-server"
in exc_info.value.detail["error"]
)

async def test_get_tools_from_server_add_prefix(self):
"""Verify _get_tools_from_server respects add_prefix True/False."""
manager = MCPServerManager()
Expand Down Expand Up @@ -909,6 +910,39 @@ async def test_rest_endpoint_shows_all_when_allowed_tools_is_empty_list(self):
assert "tool_1" in tool_names
assert "tool_2" in tool_names

def test_add_db_mcp_server_to_registry(self):
"""Test that add_db_mcp_server_to_registry adds a MCP server to the registry"""
manager = MCPServerManager()
server = LiteLLM_MCPServerTable(
**{
"server_id": "4c679a81-acd9-4954-9f84-30b739362498",
"server_name": "edc_mcp_server",
"alias": "edc_mcp_server",
"description": None,
"url": "fake_mcp_url",
"transport": "http",
"auth_type": "none",
"created_at": "2025-09-30T08:28:31.353000Z",
"created_by": "a1248959",
"updated_at": "2025-09-30T08:28:31.353000Z",
"updated_by": "a1248959",
"teams": [],
"mcp_access_groups": [],
"mcp_info": {
"server_name": "edc_mcp_server",
"mcp_server_cost_info": None,
},
"status": "unknown",
"last_health_check": None,
"health_check_error": None,
"command": None,
"args": [],
"env": {},
},
)
manager.add_update_server(server)
assert server.server_id in manager.get_registry()


if __name__ == "__main__":
pytest.main([__file__])
Loading
Loading