From 07fe55b14cdc770508ef9f6c3218283905ee08ed Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 29 Sep 2025 18:03:44 +0800 Subject: [PATCH 1/3] Add multi-transport support to MCP plugin Enhanced MCP plugin to support stdio, SSE, and WebSocket transports for server connections and tool execution. Refactored ServerConfig to include transport-specific fields and improved environment/header handling. Added comprehensive test suite for MCP plugin covering configuration, server connection, tool execution, and plugin structure. Updated plugin loader tests to include MCP plugin. --- optillm/plugins/mcp_plugin.py | 456 ++++++++++++++++++++-------- tests/test_mcp_plugin.py | 538 ++++++++++++++++++++++++++++++++++ tests/test_plugins.py | 32 +- 3 files changed, 898 insertions(+), 128 deletions(-) create mode 100644 tests/test_mcp_plugin.py diff --git a/optillm/plugins/mcp_plugin.py b/optillm/plugins/mcp_plugin.py index 03359f7e..49983eeb 100644 --- a/optillm/plugins/mcp_plugin.py +++ b/optillm/plugins/mcp_plugin.py @@ -21,6 +21,8 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client +from mcp.client.sse import sse_client +from mcp.client.websocket import websocket_client import mcp.types as types from mcp.shared.exceptions import McpError @@ -116,19 +118,47 @@ def find_executable(cmd: str) -> Optional[str]: @dataclass class ServerConfig: """Configuration for a single MCP server""" - command: str - args: List[str] - env: Dict[str, str] + # Transport type: "stdio" (default), "sse", or "websocket" + transport: str = "stdio" + + # For stdio transport + command: Optional[str] = None + args: List[str] = None + + # For remote transports (SSE/WebSocket) + url: Optional[str] = None + headers: Dict[str, str] = None + + # Common fields + env: Dict[str, str] = None description: Optional[str] = None - + + # Timeout settings + timeout: float = 5.0 + sse_read_timeout: float = 300.0 + + def __post_init__(self): + """Initialize default values for mutable fields""" + if self.args is None: + self.args = [] + if self.headers is None: + self.headers = {} + if self.env is None: + self.env = {} + @classmethod def from_dict(cls, config: Dict[str, Any]) -> 'ServerConfig': """Create ServerConfig from a dictionary""" return cls( - command=config.get("command", ""), + transport=config.get("transport", "stdio"), + command=config.get("command"), args=config.get("args", []), + url=config.get("url"), + headers=config.get("headers", {}), env=config.get("env", {}), - description=config.get("description") + description=config.get("description"), + timeout=config.get("timeout", 5.0), + sse_read_timeout=config.get("sse_read_timeout", 300.0) ) class MCPConfigManager: @@ -230,7 +260,7 @@ async def send_notification(self, *args, **kwargs): class MCPServer: """Represents a connection to an MCP server""" - + def __init__(self, server_name: str, config: ServerConfig): self.server_name = server_name self.config = config @@ -241,34 +271,155 @@ def __init__(self, server_name: str, config: ServerConfig): self.has_tools_capability = False self.has_resources_capability = False self.has_prompts_capability = False - - async def connect_and_discover(self) -> bool: - """Connect to the server and discover capabilities using proper context management""" - logger.info(f"Connecting to MCP server: {self.server_name}") + + async def connect_stdio(self, session: LoggingClientSession) -> bool: + """Connect to server using stdio transport and discover capabilities""" + try: + logger.info(f"Connected to server: {self.server_name}") + + # Initialize session + logger.debug(f"Initializing MCP session for {self.server_name}") + result = await session.initialize() + logger.info(f"Server {self.server_name} initialized with capabilities: {result.capabilities}") + logger.debug(f"Full initialization result: {result}") + + # Check which capabilities the server supports + server_capabilities = result.capabilities + + # Discover tools if supported + if hasattr(server_capabilities, "tools"): + self.has_tools_capability = True + logger.info(f"Discovering tools for {self.server_name}") + try: + tools_result = await session.list_tools() + self.tools = tools_result.tools + logger.info(f"Found {len(self.tools)} tools") + logger.debug(f"Tools details: {[t.name for t in self.tools]}") + except McpError as e: + logger.warning(f"Failed to list tools: {e}") + + # Discover resources if supported + if hasattr(server_capabilities, "resources"): + self.has_resources_capability = True + logger.info(f"Discovering resources for {self.server_name}") + try: + resources_result = await session.list_resources() + self.resources = resources_result.resources + logger.info(f"Found {len(self.resources)} resources") + logger.debug(f"Resources details: {[r.uri for r in self.resources]}") + except McpError as e: + logger.warning(f"Failed to list resources: {e}") + + # Discover prompts if supported + if hasattr(server_capabilities, "prompts"): + self.has_prompts_capability = True + logger.info(f"Discovering prompts for {self.server_name}") + try: + prompts_result = await session.list_prompts() + self.prompts = prompts_result.prompts + logger.info(f"Found {len(self.prompts)} prompts") + logger.debug(f"Prompts details: {[p.name for p in self.prompts]}") + except McpError as e: + logger.warning(f"Failed to list prompts: {e}") + + logger.info(f"Server {self.server_name} capabilities: " + f"{len(self.tools)} tools, {len(self.resources)} resources, " + f"{len(self.prompts)} prompts") + return True + + except Exception as e: + logger.error(f"Error during stdio session: {e}") + logger.error(traceback.format_exc()) + return False + + async def connect_sse(self) -> bool: + """Connect to server using SSE transport and discover capabilities""" + logger.info(f"Connecting to SSE server: {self.server_name}") + logger.debug(f"SSE URL: {self.config.url}") + logger.debug(f"Headers: {self.config.headers}") + + if not self.config.url: + logger.error(f"SSE transport requires URL for server {self.server_name}") + return False + + try: + # Expand environment variables in headers + expanded_headers = {} + for key, value in self.config.headers.items(): + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + env_var = value[2:-1] + expanded_value = os.environ.get(env_var) + if expanded_value: + expanded_headers[key] = expanded_value + else: + logger.warning(f"Environment variable {env_var} not found for header {key}") + else: + expanded_headers[key] = value + + async with sse_client( + url=self.config.url, + headers=expanded_headers, + timeout=self.config.timeout, + sse_read_timeout=self.config.sse_read_timeout + ) as (read_stream, write_stream): + async with LoggingClientSession(read_stream, write_stream) as session: + return await self.connect_stdio(session) + + except Exception as e: + logger.error(f"Error connecting to SSE server {self.server_name}: {e}") + logger.error(traceback.format_exc()) + return False + + async def connect_websocket(self) -> bool: + """Connect to server using WebSocket transport and discover capabilities""" + logger.info(f"Connecting to WebSocket server: {self.server_name}") + logger.debug(f"WebSocket URL: {self.config.url}") + + if not self.config.url: + logger.error(f"WebSocket transport requires URL for server {self.server_name}") + return False + + try: + async with websocket_client(self.config.url) as (read_stream, write_stream): + async with LoggingClientSession(read_stream, write_stream) as session: + return await self.connect_stdio(session) + + except Exception as e: + logger.error(f"Error connecting to WebSocket server {self.server_name}: {e}") + logger.error(traceback.format_exc()) + return False + + async def connect_stdio_native(self) -> bool: + """Connect using stdio transport with local executable""" logger.debug(f"Server configuration: {vars(self.config)}") - + + # Validate stdio configuration + if not self.config.command: + logger.error(f"stdio transport requires command for server {self.server_name}") + return False + # Find the full path to the command full_command = find_executable(self.config.command) if not full_command: logger.error(f"Failed to find executable for command: {self.config.command}") return False - + # Create environment with PATH included merged_env = os.environ.copy() if self.config.env: merged_env.update(self.config.env) - + logger.debug(f"Using command: {full_command}") logger.debug(f"Arguments: {self.config.args}") logger.debug(f"Environment: {self.config.env}") - + # Create server parameters server_params = StdioServerParameters( command=full_command, args=self.config.args, env=merged_env ) - + try: # Start the server separately to see its output process = await asyncio.create_subprocess_exec( @@ -278,7 +429,7 @@ async def connect_and_discover(self) -> bool: stderr=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE ) - + # Log startup message from stderr async def log_stderr(): while True: @@ -287,7 +438,7 @@ async def log_stderr(): break stderr_text = line.decode().strip() logger.info(f"Server {self.server_name} stderr: {stderr_text}") - + # Log stdout too for debugging async def log_stdout(): while True: @@ -296,75 +447,52 @@ async def log_stdout(): break stdout_text = line.decode().strip() logger.debug(f"Server {self.server_name} stdout: {stdout_text}") - + # Start logging tasks asyncio.create_task(log_stderr()) asyncio.create_task(log_stdout()) - + # Wait a bit for the server to start up logger.debug(f"Waiting for server to start up...") await asyncio.sleep(2) - + # Use the MCP client with proper context management logger.debug(f"Establishing MCP client connection to {self.server_name}") async with stdio_client(server_params) as (read_stream, write_stream): logger.debug(f"Connection established, creating session") # Use our logging session instead of the regular one async with LoggingClientSession(read_stream, write_stream) as session: - logger.info(f"Connected to server: {self.server_name}") - - # Initialize session - logger.debug(f"Initializing MCP session for {self.server_name}") - result = await session.initialize() - logger.info(f"Server {self.server_name} initialized with capabilities: {result.capabilities}") - logger.debug(f"Full initialization result: {result}") - - # Check which capabilities the server supports - server_capabilities = result.capabilities - - # Discover tools if supported - if hasattr(server_capabilities, "tools"): - self.has_tools_capability = True - logger.info(f"Discovering tools for {self.server_name}") - try: - tools_result = await session.list_tools() - self.tools = tools_result.tools - logger.info(f"Found {len(self.tools)} tools") - logger.debug(f"Tools details: {[t.name for t in self.tools]}") - except McpError as e: - logger.warning(f"Failed to list tools: {e}") - - # Discover resources if supported - if hasattr(server_capabilities, "resources"): - self.has_resources_capability = True - logger.info(f"Discovering resources for {self.server_name}") - try: - resources_result = await session.list_resources() - self.resources = resources_result.resources - logger.info(f"Found {len(self.resources)} resources") - logger.debug(f"Resources details: {[r.uri for r in self.resources]}") - except McpError as e: - logger.warning(f"Failed to list resources: {e}") - - # Discover prompts if supported - if hasattr(server_capabilities, "prompts"): - self.has_prompts_capability = True - logger.info(f"Discovering prompts for {self.server_name}") - try: - prompts_result = await session.list_prompts() - self.prompts = prompts_result.prompts - logger.info(f"Found {len(self.prompts)} prompts") - logger.debug(f"Prompts details: {[p.name for p in self.prompts]}") - except McpError as e: - logger.warning(f"Failed to list prompts: {e}") - - logger.info(f"Server {self.server_name} capabilities: " - f"{len(self.tools)} tools, {len(self.resources)} resources, " - f"{len(self.prompts)} prompts") - - self.connected = True - return True - + return await self.connect_stdio(session) + + except Exception as e: + logger.error(f"Error connecting to MCP server {self.server_name}: {e}") + logger.error(traceback.format_exc()) + return False + + async def connect_and_discover(self) -> bool: + """Connect to the server and discover capabilities using appropriate transport""" + logger.info(f"Connecting to MCP server: {self.server_name} using {self.config.transport} transport") + + # Route to appropriate transport method + try: + if self.config.transport == "stdio": + success = await self.connect_stdio_native() + elif self.config.transport == "sse": + success = await self.connect_sse() + elif self.config.transport == "websocket": + success = await self.connect_websocket() + else: + logger.error(f"Unsupported transport type: {self.config.transport}") + return False + + if success: + self.connected = True + logger.info(f"Successfully connected to {self.server_name} via {self.config.transport}") + else: + logger.error(f"Failed to connect to {self.server_name} via {self.config.transport}") + + return success + except Exception as e: logger.error(f"Error connecting to MCP server {self.server_name}: {e}") logger.error(traceback.format_exc()) @@ -496,87 +624,173 @@ def get_capabilities_description(self) -> str: return "\n".join(description_parts) +async def execute_tool_with_session(session: LoggingClientSession, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Execute a tool using an existing session""" + try: + # Initialize the session + await session.initialize() + + # Call the tool and get the result + logger.info(f"Calling tool {tool_name} with arguments: {arguments}") + result = await session.call_tool(tool_name, arguments) + + # Process the result + content_results = [] + for content in result.content: + if content.type == "text": + content_results.append({ + "type": "text", + "text": content.text + }) + logger.debug(f"Tool result (text): {content.text[:100]}...") + elif content.type == "image": + content_results.append({ + "type": "image", + "data": content.data, + "mimeType": content.mimeType + }) + logger.debug(f"Tool result (image): {content.mimeType}") + + return { + "result": content_results, + "is_error": result.isError + } + + except Exception as e: + logger.error(f"Error executing tool {tool_name}: {e}") + logger.error(traceback.format_exc()) + return {"error": f"Error executing tool: {str(e)}"} + async def execute_tool(server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: """ Execute a tool on an MCP server - + This function creates a fresh connection for each tool execution to ensure reliability. """ logger.info(f"Executing tool {tool_name} on server {server_name} with arguments: {arguments}") - + # Load configuration config_manager = MCPConfigManager() if not config_manager.load_config(): return {"error": "Failed to load MCP configuration"} - + # Get server configuration server_config = config_manager.servers.get(server_name) if not server_config: return {"error": f"Server {server_name} not found in configuration"} - + + # Log the tool call in detail + logger.debug(f"Tool call details:") + logger.debug(f" Server: {server_name}") + logger.debug(f" Tool: {tool_name}") + logger.debug(f" Arguments: {json.dumps(arguments, indent=2)}") + logger.debug(f" Transport: {server_config.transport}") + + try: + # Route to appropriate transport + if server_config.transport == "stdio": + return await execute_tool_stdio(server_config, tool_name, arguments) + elif server_config.transport == "sse": + return await execute_tool_sse(server_config, tool_name, arguments) + elif server_config.transport == "websocket": + return await execute_tool_websocket(server_config, tool_name, arguments) + else: + return {"error": f"Unsupported transport type: {server_config.transport}"} + + except Exception as e: + logger.error(f"Error executing tool {tool_name} on server {server_name}: {e}") + logger.error(traceback.format_exc()) + return {"error": f"Error executing tool: {str(e)}"} + +async def execute_tool_stdio(server_config: ServerConfig, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Execute tool using stdio transport""" + if not server_config.command: + return {"error": "stdio transport requires command"} + # Find executable full_command = find_executable(server_config.command) if not full_command: return {"error": f"Failed to find executable for command: {server_config.command}"} - + # Create environment with PATH included merged_env = os.environ.copy() if server_config.env: merged_env.update(server_config.env) - + # Create server parameters server_params = StdioServerParameters( command=full_command, args=server_config.args, env=merged_env ) - + + logger.debug(f" Command: {full_command}") + logger.debug(f" Args: {server_config.args}") + try: - # Log the tool call in detail - logger.debug(f"Tool call details:") - logger.debug(f" Server: {server_name}") - logger.debug(f" Tool: {tool_name}") - logger.debug(f" Arguments: {json.dumps(arguments, indent=2)}") - logger.debug(f" Command: {full_command}") - logger.debug(f" Args: {server_config.args}") - # Use the MCP client with proper context management async with stdio_client(server_params) as (read_stream, write_stream): # Use our logging session async with LoggingClientSession(read_stream, write_stream) as session: - # Initialize the session - await session.initialize() - - # Call the tool and get the result - logger.info(f"Calling tool {tool_name} with arguments: {arguments}") - result = await session.call_tool(tool_name, arguments) - - # Process the result - content_results = [] - for content in result.content: - if content.type == "text": - content_results.append({ - "type": "text", - "text": content.text - }) - logger.debug(f"Tool result (text): {content.text[:100]}...") - elif content.type == "image": - content_results.append({ - "type": "image", - "data": content.data, - "mimeType": content.mimeType - }) - logger.debug(f"Tool result (image): {content.mimeType}") - - return { - "result": content_results, - "is_error": result.isError - } - + return await execute_tool_with_session(session, tool_name, arguments) + except Exception as e: - logger.error(f"Error executing tool {tool_name} on server {server_name}: {e}") + logger.error(f"Error with stdio tool execution: {e}") logger.error(traceback.format_exc()) - return {"error": f"Error executing tool: {str(e)}"} + return {"error": f"Error executing tool via stdio: {str(e)}"} + +async def execute_tool_sse(server_config: ServerConfig, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Execute tool using SSE transport""" + if not server_config.url: + return {"error": "SSE transport requires URL"} + + try: + # Expand environment variables in headers + expanded_headers = {} + for key, value in server_config.headers.items(): + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + env_var = value[2:-1] + expanded_value = os.environ.get(env_var) + if expanded_value: + expanded_headers[key] = expanded_value + else: + logger.warning(f"Environment variable {env_var} not found for header {key}") + else: + expanded_headers[key] = value + + logger.debug(f" URL: {server_config.url}") + logger.debug(f" Headers: {list(expanded_headers.keys())}") + + async with sse_client( + url=server_config.url, + headers=expanded_headers, + timeout=server_config.timeout, + sse_read_timeout=server_config.sse_read_timeout + ) as (read_stream, write_stream): + async with LoggingClientSession(read_stream, write_stream) as session: + return await execute_tool_with_session(session, tool_name, arguments) + + except Exception as e: + logger.error(f"Error with SSE tool execution: {e}") + logger.error(traceback.format_exc()) + return {"error": f"Error executing tool via SSE: {str(e)}"} + +async def execute_tool_websocket(server_config: ServerConfig, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Execute tool using WebSocket transport""" + if not server_config.url: + return {"error": "WebSocket transport requires URL"} + + try: + logger.debug(f" URL: {server_config.url}") + + async with websocket_client(server_config.url) as (read_stream, write_stream): + async with LoggingClientSession(read_stream, write_stream) as session: + return await execute_tool_with_session(session, tool_name, arguments) + + except Exception as e: + logger.error(f"Error with WebSocket tool execution: {e}") + logger.error(traceback.format_exc()) + return {"error": f"Error executing tool via WebSocket: {str(e)}"} async def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str, int]: """ diff --git a/tests/test_mcp_plugin.py b/tests/test_mcp_plugin.py new file mode 100644 index 00000000..6bd764ec --- /dev/null +++ b/tests/test_mcp_plugin.py @@ -0,0 +1,538 @@ +#!/usr/bin/env python3 +""" +Comprehensive test suite for MCP plugin functionality +""" + +import sys +import os +import asyncio +import json +import pytest +from unittest.mock import Mock, AsyncMock, patch, MagicMock +from pathlib import Path + +# Try to import pytest, but don't fail if it's not available +try: + import pytest +except ImportError: + pytest = None + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from optillm.plugins.mcp_plugin import ( + ServerConfig, MCPServer, MCPConfigManager, MCPServerManager, + execute_tool, execute_tool_stdio, execute_tool_sse, execute_tool_websocket, + LoggingClientSession, SLUG +) + + +class TestServerConfig: + """Test ServerConfig dataclass functionality""" + + def test_default_stdio_config(self): + """Test default configuration for stdio transport""" + config = ServerConfig() + assert config.transport == "stdio" + assert config.command is None + assert config.args == [] + assert config.url is None + assert config.headers == {} + assert config.env == {} + assert config.timeout == 5.0 + assert config.sse_read_timeout == 300.0 + + def test_stdio_config_from_dict(self): + """Test creating stdio config from dictionary""" + config_dict = { + "transport": "stdio", + "command": "npx", + "args": ["@modelcontextprotocol/server-filesystem", "/tmp"], + "env": {"PATH": "/usr/local/bin"}, + "description": "Filesystem server" + } + + config = ServerConfig.from_dict(config_dict) + assert config.transport == "stdio" + assert config.command == "npx" + assert config.args == ["@modelcontextprotocol/server-filesystem", "/tmp"] + assert config.env == {"PATH": "/usr/local/bin"} + assert config.description == "Filesystem server" + + def test_sse_config_from_dict(self): + """Test creating SSE config from dictionary""" + config_dict = { + "transport": "sse", + "url": "https://api.example.com/mcp", + "headers": {"Authorization": "Bearer token123"}, + "timeout": 10.0, + "sse_read_timeout": 600.0, + "description": "Remote SSE server" + } + + config = ServerConfig.from_dict(config_dict) + assert config.transport == "sse" + assert config.url == "https://api.example.com/mcp" + assert config.headers == {"Authorization": "Bearer token123"} + assert config.timeout == 10.0 + assert config.sse_read_timeout == 600.0 + assert config.description == "Remote SSE server" + + def test_websocket_config_from_dict(self): + """Test creating WebSocket config from dictionary""" + config_dict = { + "transport": "websocket", + "url": "wss://api.example.com/mcp", + "description": "WebSocket server" + } + + config = ServerConfig.from_dict(config_dict) + assert config.transport == "websocket" + assert config.url == "wss://api.example.com/mcp" + assert config.description == "WebSocket server" + + +class TestMCPConfigManager: + """Test MCP configuration management""" + + def test_init_default_path(self): + """Test default configuration path""" + manager = MCPConfigManager() + expected_path = Path.home() / ".optillm" / "mcp_config.json" + assert manager.config_path == expected_path + + def test_init_custom_path(self): + """Test custom configuration path""" + custom_path = "/tmp/custom_mcp_config.json" + manager = MCPConfigManager(custom_path) + assert manager.config_path == Path(custom_path) + + def test_create_default_config(self): + """Test creating default configuration file""" + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + config_path = Path(temp_dir) / "test_config.json" + manager = MCPConfigManager(str(config_path)) + + success = manager.create_default_config() + assert success + assert config_path.exists() + + # Verify default content + with open(config_path) as f: + config = json.load(f) + + assert "mcpServers" in config + assert "log_level" in config + assert config["mcpServers"] == {} + assert config["log_level"] == "INFO" + + def test_load_valid_config(self): + """Test loading valid configuration""" + import tempfile + + config_data = { + "mcpServers": { + "test_server": { + "transport": "stdio", + "command": "test-command", + "args": ["arg1", "arg2"] + } + }, + "log_level": "DEBUG" + } + + with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f: + json.dump(config_data, f) + config_path = f.name + + try: + manager = MCPConfigManager(config_path) + success = manager.load_config() + assert success + assert len(manager.servers) == 1 + assert "test_server" in manager.servers + assert manager.servers["test_server"].command == "test-command" + assert manager.log_level == "DEBUG" + finally: + os.unlink(config_path) + + def test_load_nonexistent_config(self): + """Test loading non-existent configuration""" + manager = MCPConfigManager("/nonexistent/path.json") + success = manager.load_config() + assert not success + assert len(manager.servers) == 0 + + +@pytest.mark.asyncio +class TestMCPServer: + """Test MCP server connection and capability discovery""" + + def test_init(self): + """Test MCPServer initialization""" + config = ServerConfig() + server = MCPServer("test_server", config) + + assert server.server_name == "test_server" + assert server.config == config + assert server.tools == [] + assert server.resources == [] + assert server.prompts == [] + assert not server.connected + assert not server.has_tools_capability + assert not server.has_resources_capability + assert not server.has_prompts_capability + + async def test_connect_stdio_validation(self): + """Test stdio connection validation""" + config = ServerConfig(transport="stdio") # No command + server = MCPServer("test_server", config) + + result = await server.connect_stdio_native() + assert not result + + async def test_connect_sse_validation(self): + """Test SSE connection validation""" + config = ServerConfig(transport="sse") # No URL + server = MCPServer("test_server", config) + + result = await server.connect_sse() + assert not result + + async def test_connect_websocket_validation(self): + """Test WebSocket connection validation""" + config = ServerConfig(transport="websocket") # No URL + server = MCPServer("test_server", config) + + result = await server.connect_websocket() + assert not result + + async def test_connect_and_discover_unsupported_transport(self): + """Test unsupported transport type""" + config = ServerConfig(transport="invalid") + server = MCPServer("test_server", config) + + result = await server.connect_and_discover() + assert not result + + @patch('optillm.plugins.mcp_plugin.sse_client') + async def test_connect_sse_success(self, mock_sse_client): + """Test successful SSE connection""" + # Mock the SSE client context manager + mock_streams = (AsyncMock(), AsyncMock()) + mock_sse_client.return_value.__aenter__ = AsyncMock(return_value=mock_streams) + mock_sse_client.return_value.__aexit__ = AsyncMock(return_value=None) + + # Mock session + mock_session = AsyncMock() + mock_result = Mock() + mock_result.capabilities = Mock() + mock_session.initialize.return_value = mock_result + + config = ServerConfig( + transport="sse", + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer token"} + ) + server = MCPServer("test_server", config) + + with patch.object(server, 'connect_stdio', return_value=True): + with patch('optillm.plugins.mcp_plugin.LoggingClientSession') as mock_session_class: + mock_session_class.return_value.__aenter__ = AsyncMock(return_value=mock_session) + mock_session_class.return_value.__aexit__ = AsyncMock(return_value=None) + + result = await server.connect_sse() + assert result + + +@pytest.mark.asyncio +class TestToolExecution: + """Test tool execution functionality""" + + async def test_execute_tool_server_not_found(self): + """Test tool execution with non-existent server""" + with patch('optillm.plugins.mcp_plugin.MCPConfigManager') as mock_manager_class: + mock_manager = Mock() + mock_manager.load_config.return_value = True + mock_manager.servers = {} + mock_manager_class.return_value = mock_manager + + result = await execute_tool("nonexistent", "test_tool", {}) + assert "error" in result + assert "not found" in result["error"] + + async def test_execute_tool_config_load_failure(self): + """Test tool execution with config load failure""" + with patch('optillm.plugins.mcp_plugin.MCPConfigManager') as mock_manager_class: + mock_manager = Mock() + mock_manager.load_config.return_value = False + mock_manager_class.return_value = mock_manager + + result = await execute_tool("test_server", "test_tool", {}) + assert "error" in result + assert "Failed to load MCP configuration" == result["error"] + + async def test_execute_tool_unsupported_transport(self): + """Test tool execution with unsupported transport""" + config = ServerConfig(transport="invalid") + + with patch('optillm.plugins.mcp_plugin.MCPConfigManager') as mock_manager_class: + mock_manager = Mock() + mock_manager.load_config.return_value = True + mock_manager.servers = {"test_server": config} + mock_manager_class.return_value = mock_manager + + result = await execute_tool("test_server", "test_tool", {}) + assert "error" in result + assert "Unsupported transport type" in result["error"] + + async def test_execute_tool_stdio_no_command(self): + """Test stdio tool execution without command""" + config = ServerConfig(transport="stdio") # No command + result = await execute_tool_stdio(config, "test_tool", {}) + assert "error" in result + assert "requires command" in result["error"] + + async def test_execute_tool_sse_no_url(self): + """Test SSE tool execution without URL""" + config = ServerConfig(transport="sse") # No URL + result = await execute_tool_sse(config, "test_tool", {}) + assert "error" in result + assert "requires URL" in result["error"] + + async def test_execute_tool_websocket_no_url(self): + """Test WebSocket tool execution without URL""" + config = ServerConfig(transport="websocket") # No URL + result = await execute_tool_websocket(config, "test_tool", {}) + assert "error" in result + assert "requires URL" in result["error"] + + +class TestMCPServerManager: + """Test MCP server manager functionality""" + + def test_init(self): + """Test MCPServerManager initialization""" + config_manager = MCPConfigManager() + manager = MCPServerManager(config_manager) + + assert manager.config_manager == config_manager + assert manager.servers == {} + assert not manager.initialized + assert manager.all_tools == [] + assert manager.all_resources == [] + assert manager.all_prompts == [] + + def test_get_tools_for_model_empty(self): + """Test getting tools when no tools are available""" + config_manager = MCPConfigManager() + manager = MCPServerManager(config_manager) + + tools = manager.get_tools_for_model() + assert tools == [] + + def test_get_capabilities_description_no_servers(self): + """Test getting capabilities description with no servers""" + config_manager = MCPConfigManager() + manager = MCPServerManager(config_manager) + + description = manager.get_capabilities_description() + assert "No MCP servers available" in description + + +@pytest.mark.asyncio +@pytest.mark.skipif(not os.getenv("GITHUB_TOKEN"), reason="GITHUB_TOKEN not set") +class TestGitHubMCPServer: + """Integration tests with GitHub MCP server (requires GITHUB_TOKEN)""" + + async def test_github_mcp_server_connection(self): + """Test real connection to GitHub MCP server""" + config = ServerConfig( + transport="sse", + url="https://api.githubcopilot.com/mcp", + headers={ + "Authorization": f"Bearer {os.getenv('GITHUB_TOKEN')}", + "Accept": "text/event-stream" + }, + description="GitHub MCP Server" + ) + + server = MCPServer("github", config) + + try: + connected = await server.connect_and_discover() + + if connected: + assert server.connected + assert len(server.tools) > 0 or len(server.resources) > 0 or len(server.prompts) > 0 + print(f"GitHub MCP server connected successfully!") + print(f"Found: {len(server.tools)} tools, {len(server.resources)} resources, {len(server.prompts)} prompts") + + # Test a simple tool if available + if server.tools: + tool_name = server.tools[0].name + print(f"Testing tool: {tool_name}") + + # Create minimal arguments - this might fail but tests the connection + result = await execute_tool_sse(config, tool_name, {}) + print(f"Tool execution result: {result}") + else: + pytest.skip("Could not connect to GitHub MCP server") + + except Exception as e: + pytest.skip(f"GitHub MCP server test failed: {e}") + + +class TestPluginStructure: + """Test plugin structure and exports""" + + def test_slug_exists(self): + """Test that plugin has SLUG defined""" + assert hasattr(sys.modules['optillm.plugins.mcp_plugin'], 'SLUG') + assert SLUG == "mcp" + + def test_run_function_exists(self): + """Test that plugin has run function defined""" + import optillm.plugins.mcp_plugin as plugin + assert hasattr(plugin, 'run') + assert callable(plugin.run) + + def test_required_imports(self): + """Test that required modules can be imported""" + try: + from mcp.client.sse import sse_client + from mcp.client.websocket import websocket_client + assert sse_client is not None + assert websocket_client is not None + except ImportError as e: + pytest.fail(f"Required MCP imports failed: {e}") + + +# Mock tests for various scenarios +class TestMockScenarios: + """Test various scenarios with mocked dependencies""" + + @patch('optillm.plugins.mcp_plugin.find_executable') + def test_stdio_command_not_found(self, mock_find_executable): + """Test stdio transport when command is not found""" + mock_find_executable.return_value = None + + config = ServerConfig(transport="stdio", command="nonexistent-command") + + async def test_async(): + result = await execute_tool_stdio(config, "test_tool", {}) + assert "error" in result + assert "Failed to find executable" in result["error"] + + asyncio.run(test_async()) + + def test_environment_variable_expansion(self): + """Test environment variable expansion in SSE headers""" + os.environ["TEST_TOKEN"] = "test-token-value" + + try: + config = ServerConfig( + transport="sse", + url="https://api.example.com/mcp", + headers={"Authorization": "Bearer ${TEST_TOKEN}"} + ) + + server = MCPServer("test", config) + + # Test the header expansion logic from connect_sse method + expanded_headers = {} + for key, value in config.headers.items(): + if isinstance(value, str) and value.startswith("${") and value.endswith("}"): + env_var = value[2:-1] + expanded_value = os.environ.get(env_var) + if expanded_value: + expanded_headers[key] = expanded_value + else: + expanded_headers[key] = value + + assert expanded_headers["Authorization"] == "Bearer test-token-value" + + finally: + del os.environ["TEST_TOKEN"] + + +if __name__ == "__main__": + print("Running MCP plugin tests...") + + # Run basic tests + test_classes = [ + TestServerConfig, + TestMCPConfigManager, + TestPluginStructure, + TestMockScenarios + ] + + for test_class in test_classes: + instance = test_class() + methods = [method for method in dir(instance) if method.startswith('test_')] + + for method_name in methods: + try: + method = getattr(instance, method_name) + if asyncio.iscoroutinefunction(method): + asyncio.run(method()) + else: + method() + print(f"✅ {test_class.__name__}.{method_name} passed") + except Exception as e: + print(f"❌ {test_class.__name__}.{method_name} failed: {e}") + + # Run async tests + async def run_async_tests(): + test_instance = TestMCPServer() + async_methods = [ + 'test_connect_stdio_validation', + 'test_connect_sse_validation', + 'test_connect_websocket_validation', + 'test_connect_and_discover_unsupported_transport' + ] + + for method_name in async_methods: + try: + method = getattr(test_instance, method_name) + await method() + print(f"✅ TestMCPServer.{method_name} passed") + except Exception as e: + print(f"❌ TestMCPServer.{method_name} failed: {e}") + + # Tool execution tests + tool_test_instance = TestToolExecution() + tool_methods = [ + 'test_execute_tool_server_not_found', + 'test_execute_tool_config_load_failure', + 'test_execute_tool_unsupported_transport', + 'test_execute_tool_stdio_no_command', + 'test_execute_tool_sse_no_url', + 'test_execute_tool_websocket_no_url' + ] + + for method_name in tool_methods: + try: + method = getattr(tool_test_instance, method_name) + await method() + print(f"✅ TestToolExecution.{method_name} passed") + except Exception as e: + print(f"❌ TestToolExecution.{method_name} failed: {e}") + + asyncio.run(run_async_tests()) + + print("\n🎯 MCP Plugin tests completed!") + print("💡 To run GitHub MCP server integration test, set GITHUB_TOKEN environment variable") + + if os.getenv("GITHUB_TOKEN"): + print("🔍 Running GitHub MCP server integration test...") + async def run_github_test(): + test_instance = TestGitHubMCPServer() + try: + await test_instance.test_github_mcp_server_connection() + print("✅ GitHub MCP server integration test passed") + except Exception as e: + print(f"❌ GitHub MCP server integration test failed: {e}") + + asyncio.run(run_github_test()) \ No newline at end of file diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 5f91fe1d..34e073a1 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -22,7 +22,7 @@ def test_plugin_module_imports(): """Test that plugin modules can be imported""" plugin_modules = [ 'optillm.plugins.memory_plugin', - 'optillm.plugins.readurls_plugin', + 'optillm.plugins.readurls_plugin', 'optillm.plugins.privacy_plugin', 'optillm.plugins.genselect_plugin', 'optillm.plugins.majority_voting_plugin', @@ -31,7 +31,8 @@ def test_plugin_module_imports(): 'optillm.plugins.deepthink_plugin', 'optillm.plugins.longcepo_plugin', 'optillm.plugins.spl_plugin', - 'optillm.plugins.proxy_plugin' + 'optillm.plugins.proxy_plugin', + 'optillm.plugins.mcp_plugin' ] for module_name in plugin_modules: @@ -52,7 +53,7 @@ def test_plugin_approach_detection(): load_plugins() # Check if known plugins are loaded - expected_plugins = ["memory", "readurls", "privacy", "web_search", "deep_research", "deepthink", "longcepo", "spl", "proxy"] + expected_plugins = ["memory", "readurls", "privacy", "web_search", "deep_research", "deepthink", "longcepo", "spl", "proxy", "mcp"] for plugin_name in expected_plugins: assert plugin_name in plugin_approaches, f"Plugin {plugin_name} not loaded" @@ -251,7 +252,7 @@ def test_proxy_plugin_timeout_handling(): from optillm.plugins.proxy.client import ProxyClient from unittest.mock import Mock, patch import concurrent.futures - + # Create config with short timeout config = { "providers": [ @@ -261,7 +262,7 @@ def test_proxy_plugin_timeout_handling(): "api_key": "test-key-1" }, { - "name": "fast_provider", + "name": "fast_provider", "base_url": "http://localhost:8002/v1", "api_key": "test-key-2" } @@ -279,10 +280,10 @@ def test_proxy_plugin_timeout_handling(): "timeout": 5 } } - + # Create proxy client proxy_client = ProxyClient(config) - + # Verify timeout settings are loaded assert proxy_client.request_timeout == 2, "Request timeout should be 2" assert proxy_client.connect_timeout == 1, "Connect timeout should be 1" @@ -290,6 +291,17 @@ def test_proxy_plugin_timeout_handling(): assert proxy_client.queue_timeout == 5, "Queue timeout should be 5" +def test_mcp_plugin(): + """Test MCP plugin module""" + import optillm.plugins.mcp_plugin as plugin + assert hasattr(plugin, 'run') + assert hasattr(plugin, 'SLUG') + assert hasattr(plugin, 'ServerConfig') + assert hasattr(plugin, 'MCPServer') + assert hasattr(plugin, 'execute_tool') + assert plugin.SLUG == "mcp" + + def test_plugin_subdirectory_imports(): """Test all plugins with subdirectories can import their submodules""" # Test deep_research @@ -435,6 +447,12 @@ def test_no_relative_import_errors(): print("✅ Proxy plugin timeout handling test passed") except Exception as e: print(f"❌ Proxy plugin timeout handling test failed: {e}") + + try: + test_mcp_plugin() + print("✅ MCP plugin test passed") + except Exception as e: + print(f"❌ MCP plugin test failed: {e}") try: test_plugin_subdirectory_imports() From 06a6eb013edbfe794c48ad7dec4cbad0c0bf5875 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 29 Sep 2025 18:04:50 +0800 Subject: [PATCH 2/3] Update README.md --- README.md | 181 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 171 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index f17049f5..7b580b06 100644 --- a/README.md +++ b/README.md @@ -339,6 +339,11 @@ response = client.chat.completions.create( The Model Context Protocol (MCP) plugin enables OptiLLM to connect with MCP servers, bringing external tools, resources, and prompts into the context of language models. This allows for powerful integrations with filesystem access, database queries, API connections, and more. +OptiLLM supports both **local** and **remote** MCP servers through multiple transport methods: +- **stdio**: Local servers (traditional) +- **SSE**: Remote servers via Server-Sent Events +- **WebSocket**: Remote servers via WebSocket connections + #### What is MCP? The [Model Context Protocol](https://modelcontextprotocol.io/) (MCP) is an open protocol standard that allows LLMs to securely access tools and data sources through a standardized interface. MCP servers can provide: @@ -351,12 +356,16 @@ The [Model Context Protocol](https://modelcontextprotocol.io/) (MCP) is an open ##### Setting up MCP Config +> **Note on Backwards Compatibility**: Existing MCP configurations will continue to work unchanged. The `transport` field defaults to "stdio" when not specified, maintaining full backwards compatibility with existing setups. + 1. Create a configuration file at `~/.optillm/mcp_config.json` with the following structure: +**Local Server (stdio) - Traditional Method:** ```json { "mcpServers": { "filesystem": { + "transport": "stdio", "command": "npx", "args": [ "-y", @@ -364,46 +373,169 @@ The [Model Context Protocol](https://modelcontextprotocol.io/) (MCP) is an open "/path/to/allowed/directory1", "/path/to/allowed/directory2" ], + "env": {}, + "description": "Local filesystem access" + } + }, + "log_level": "INFO" +} +``` + +**Legacy Format (still works):** +```json +{ + "mcpServers": { + "filesystem": { + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/path/to/directory"], "env": {} } + } +} +``` + +**Remote Server (SSE) - New Feature:** +```json +{ + "mcpServers": { + "github": { + "transport": "sse", + "url": "https://api.githubcopilot.com/mcp", + "headers": { + "Authorization": "Bearer ${GITHUB_TOKEN}", + "Accept": "text/event-stream" + }, + "timeout": 30.0, + "sse_read_timeout": 300.0, + "description": "GitHub MCP server for repository access" + } + }, + "log_level": "INFO" +} +``` + +**Remote Server (WebSocket) - New Feature:** +```json +{ + "mcpServers": { + "remote-ws": { + "transport": "websocket", + "url": "wss://api.example.com/mcp", + "description": "Remote WebSocket MCP server" + } + }, + "log_level": "INFO" +} +``` + +**Mixed Configuration (Local + Remote):** +```json +{ + "mcpServers": { + "filesystem": { + "transport": "stdio", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-filesystem", "/home/user/docs"], + "description": "Local filesystem access" + }, + "github": { + "transport": "sse", + "url": "https://api.githubcopilot.com/mcp", + "headers": { + "Authorization": "Bearer ${GITHUB_TOKEN}" + }, + "description": "GitHub MCP server" + }, + "remote-api": { + "transport": "websocket", + "url": "wss://api.company.com/mcp", + "description": "Company internal MCP server" + } }, "log_level": "INFO" } ``` -Each server entry in `mcpServers` consists of: +##### Configuration Parameters + +**Common Parameters:** +- **Server name**: A unique identifier for the server (e.g., "filesystem", "github") +- **transport**: Transport method - "stdio" (default), "sse", or "websocket" +- **description** (optional): Description of the server's functionality +- **timeout** (optional): Connection timeout in seconds (default: 5.0) -- **Server name**: A unique identifier for the server (e.g., "filesystem") +**stdio Transport (Local Servers):** - **command**: The executable to run the server - **args**: Command-line arguments for the server - **env**: Environment variables for the server process -- **description** (optional): Description of the server's functionality + +**sse Transport (Server-Sent Events):** +- **url**: The SSE endpoint URL +- **headers** (optional): HTTP headers for authentication +- **sse_read_timeout** (optional): SSE read timeout in seconds (default: 300.0) + +**websocket Transport (WebSocket):** +- **url**: The WebSocket endpoint URL + +**Environment Variable Expansion:** +Headers and other string values support environment variable expansion using `${VARIABLE_NAME}` syntax. This is especially useful for API keys: +```json +{ + "headers": { + "Authorization": "Bearer ${GITHUB_TOKEN}", + "X-API-Key": "${MY_API_KEY}" + } +} +``` #### Available MCP Servers -You can use any of the [official MCP servers](https://modelcontextprotocol.io/examples) or third-party servers. Some popular options include: +OptiLLM supports both local and remote MCP servers: + +##### Local MCP Servers (stdio transport) + +You can use any of the [official MCP servers](https://modelcontextprotocol.io/examples) or third-party servers that run as local processes: - **Filesystem**: `@modelcontextprotocol/server-filesystem` - File operations - **Git**: `mcp-server-git` - Git repository operations - **SQLite**: `@modelcontextprotocol/server-sqlite` - SQLite database access - **Brave Search**: `@modelcontextprotocol/server-brave-search` - Web search capabilities -Example configuration for multiple servers: +##### Remote MCP Servers (SSE/WebSocket transport) + +Remote servers provide centralized access without requiring local installation: + +- **GitHub MCP Server**: `https://api.githubcopilot.com/mcp` - Repository management, issue tracking, and code analysis +- **Third-party servers**: Any MCP server that supports SSE or WebSocket protocols + +##### Example: Comprehensive Configuration ```json { "mcpServers": { "filesystem": { + "transport": "stdio", "command": "npx", "args": ["-y", "@modelcontextprotocol/server-filesystem", "/home/user/documents"], - "env": {} + "description": "Local file system access" }, "search": { + "transport": "stdio", "command": "npx", "args": ["-y", "@modelcontextprotocol/server-brave-search"], "env": { "BRAVE_API_KEY": "your-api-key-here" - } + }, + "description": "Web search capabilities" + }, + "github": { + "transport": "sse", + "url": "https://api.githubcopilot.com/mcp", + "headers": { + "Authorization": "Bearer ${GITHUB_TOKEN}", + "Accept": "text/event-stream" + }, + "description": "GitHub repository and issue management" } }, "log_level": "INFO" @@ -429,11 +561,18 @@ The plugin enhances the system prompt with MCP capabilities so the model knows w Here are some examples of queries that will engage MCP tools: +**Local Server Examples:** - "List all the Python files in my documents directory" (Filesystem) - "What are the recent commits in my Git repository?" (Git) - "Search for the latest information about renewable energy" (Search) - "Query my database for all users who registered this month" (Database) +**Remote Server Examples:** +- "Show me the open issues in my GitHub repository" (GitHub MCP) +- "Create a new branch for the feature I'm working on" (GitHub MCP) +- "What are the most recent pull requests that need review?" (GitHub MCP) +- "Get the file contents from my remote repository" (GitHub MCP) + #### Troubleshooting ##### Logs @@ -447,13 +586,35 @@ Check this log file for connection issues, tool execution errors, and other diag ##### Common Issues +**Local Server Issues (stdio transport):** + 1. **Command not found**: Make sure the server executable is available in your PATH, or use an absolute path in the configuration. -2. **Connection failed**: Verify the server is properly configured and any required API keys are provided. +2. **Access denied**: For filesystem operations, ensure the paths specified in the configuration are accessible to the process. + +**Remote Server Issues (SSE/WebSocket transport):** + +3. **Connection timeout**: Remote servers may take longer to connect. Increase the `timeout` value in your configuration. + +4. **Authentication failed**: Verify your API keys and tokens are correct. For GitHub MCP server, ensure your `GITHUB_TOKEN` environment variable is set with appropriate permissions. + +5. **Network errors**: Check your internet connection and verify the server URL is accessible. + +6. **Environment variable not found**: If using `${VARIABLE_NAME}` syntax, ensure the environment variables are set before starting OptILLM. + +**General Issues:** + +7. **Method not found**: Some servers don't implement all MCP capabilities (tools, resources, prompts). Verify which capabilities the server supports. + +8. **Transport not supported**: Ensure you're using a supported transport: "stdio", "sse", or "websocket". + +**Example: Testing GitHub MCP Connection** -3. **Method not found**: Some servers don't implement all MCP capabilities (tools, resources, prompts). Verify which capabilities the server supports. +To test if your GitHub MCP server configuration is working: -4. **Access denied**: For filesystem operations, ensure the paths specified in the configuration are accessible to the process. +1. Set your GitHub token: `export GITHUB_TOKEN="your-github-token"` +2. Start OptILLM and check the logs at `~/.optillm/logs/mcp_plugin.log` +3. Look for connection success messages and discovered capabilities ## Available parameters From 413577b125262a44b69337974af13a56d7606758 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 29 Sep 2025 18:07:46 +0800 Subject: [PATCH 3/3] Bump version to 0.3.0 Updated __version__ in optillm/__init__.py and project version in pyproject.toml to 0.3.0 for a new release. --- optillm/__init__.py | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optillm/__init__.py b/optillm/__init__.py index 8c142f98..ef07b022 100644 --- a/optillm/__init__.py +++ b/optillm/__init__.py @@ -1,5 +1,5 @@ # Version information -__version__ = "0.2.10" +__version__ = "0.3.0" # Import from server module from .server import ( diff --git a/pyproject.toml b/pyproject.toml index 81f03ffd..77f9895f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "optillm" -version = "0.2.10" +version = "0.3.0" description = "An optimizing inference proxy for LLMs." readme = "README.md" license = "Apache-2.0"