From d283ee66d2f2addfdb7d4fc9d15121470c41d754 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Sat, 8 Mar 2025 13:32:43 +0800 Subject: [PATCH 1/8] fix gradio gui --- requirements.txt | 5 +++-- setup.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/requirements.txt b/requirements.txt index 51dae49d..651f8d3e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -21,9 +21,10 @@ ipython ipykernel peft bitsandbytes -gradio +gradio<5.16.0 # Constrain spacy version to avoid blis build issues on ARM64 spacy<3.8.0 cerebras_cloud_sdk outlines[transformers] -sentencepiece \ No newline at end of file +sentencepiece +mcp \ No newline at end of file diff --git a/setup.py b/setup.py index 2ebd1f13..6fc6d188 100644 --- a/setup.py +++ b/setup.py @@ -38,12 +38,13 @@ "ipykernel", "peft", "bitsandbytes", - "gradio", + "gradio<5.16.0", # Constrain spacy version to avoid blis build issues on ARM64 "spacy<3.8.0", "cerebras_cloud_sdk", "outlines[transformers]", "sentencepiece", + "mcp", ], entry_points={ 'console_scripts': [ From 0b045a34ad8e0bf1b965e4b1737134a8036a725a Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Sat, 8 Mar 2025 13:32:54 +0800 Subject: [PATCH 2/8] Create mcp_plugin.py --- optillm/plugins/mcp_plugin.py | 639 ++++++++++++++++++++++++++++++++++ 1 file changed, 639 insertions(+) create mode 100644 optillm/plugins/mcp_plugin.py diff --git a/optillm/plugins/mcp_plugin.py b/optillm/plugins/mcp_plugin.py new file mode 100644 index 00000000..9a50bdb0 --- /dev/null +++ b/optillm/plugins/mcp_plugin.py @@ -0,0 +1,639 @@ +import os +import json +import logging +import asyncio +from typing import Dict, Any, Tuple, Optional, List +from dataclasses import dataclass +from enum import Enum +import re +import pydantic +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client +import mcp.types as types + +logger = logging.getLogger(__name__) + +# Plugin identifier +SLUG = "mcp" + +class ServerType(str, Enum): + """Supported server types""" + STDIO = "stdio" + HTTP = "http" + +class ResourceAccess(str, Enum): + """Resource access modes""" + READ_ONLY = "read_only" + READ_WRITE = "read_write" + NONE = "none" + +class ServerConfig(pydantic.BaseModel): + """Configuration for a single MCP server""" + type: ServerType = ServerType.STDIO + command: str + args: List[str] = [] + env: Dict[str, str] = {} + url: Optional[str] = None # For HTTP servers + resource_access: ResourceAccess = ResourceAccess.READ_ONLY + allowed_tools: Optional[List[str]] = None # If None, all tools allowed + description: Optional[str] = None + +class MCPConfig(pydantic.BaseModel): + """Root configuration model""" + mcpServers: Dict[str, ServerConfig] + default_model: Optional[str] = None + log_level: str = "INFO" + +@dataclass +class ToolMatch: + """Represents a matched tool with extracted arguments""" + server_name: str + tool_name: str + arguments: Dict[str, Any] + confidence: float + +@dataclass +class PromptMatch: + """Represents a matched prompt template""" + server_name: str + prompt_name: str + arguments: Dict[str, Any] + +class ArgumentExtractor: + """Extracts arguments from text using LLM""" + + def __init__(self, client, model: str): + self.client = client + self.model = model + + async def extract_arguments(self, text: str, tool: types.Tool) -> Dict[str, Any]: + """Use LLM to extract arguments from text""" + prompt = f""" + Extract arguments for the tool '{tool.name}' from this text: "{text}" + + The tool accepts these arguments: + {json.dumps(tool.inputSchema, indent=2)} + + Return only a JSON object with the extracted arguments. If an argument can't be found, omit it. + """ + + response = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + temperature=0.1 + ) + + try: + return json.loads(response.choices[0].message.content) + except json.JSONDecodeError: + logger.error("Failed to parse argument extraction response") + return {} + +class ToolMatcher: + """Matches tools to text using semantic and pattern matching""" + + def __init__(self, client, model: str): + self.client = client + self.model = model + self.argument_extractor = ArgumentExtractor(client, model) + + async def find_matching_tools( + self, + text: str, + available_tools: Dict[str, List[types.Tool]] + ) -> List[ToolMatch]: + """Find tools that match the given text""" + matches = [] + + # First pass: Look for explicit tool mentions + for server_name, tools in available_tools.items(): + for tool in tools: + # Check for direct name matches + if tool.name.lower() in text.lower(): + args = await self.argument_extractor.extract_arguments(text, tool) + matches.append(ToolMatch( + server_name=server_name, + tool_name=tool.name, + arguments=args, + confidence=0.9 + )) + continue + + # Check for semantic matches using the tool description + if tool.description and self._semantic_match(text, tool.description): + args = await self.argument_extractor.extract_arguments(text, tool) + matches.append(ToolMatch( + server_name=server_name, + tool_name=tool.name, + arguments=args, + confidence=0.7 + )) + + # Use LLM for additional tool matching if needed + if not matches: + llm_matches = await self._llm_tool_matching(text, available_tools) + matches.extend(llm_matches) + + # Sort by confidence + matches.sort(key=lambda x: x.confidence, reverse=True) + return matches + + def _semantic_match(self, text: str, description: str) -> bool: + """Simple semantic matching using keywords""" + keywords = description.lower().split() + text_words = text.lower().split() + matches = sum(1 for word in keywords if any(w.startswith(word) for w in text_words)) + return matches / len(keywords) > 0.5 + + async def _llm_tool_matching( + self, + text: str, + available_tools: Dict[str, List[types.Tool]] + ) -> List[ToolMatch]: + """Use LLM to find matching tools""" + # Create tool descriptions + tool_descriptions = [] + for server_name, tools in available_tools.items(): + for tool in tools: + desc = f"Server: {server_name}, Tool: {tool.name}" + if tool.description: + desc += f", Description: {tool.description}" + tool_descriptions.append(desc) + + prompt = f""" + Given this user request: "{text}" + + And these available tools: + {json.dumps(tool_descriptions, indent=2)} + + Which tools would be most appropriate to use? Return a JSON array of objects with: + - server_name: The server name + - tool_name: The tool name + - confidence: A number between 0 and 1 indicating confidence + + Only include tools with confidence > 0.5. + """ + + response = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + temperature=0.1 + ) + + try: + llm_matches = json.loads(response.choices[0].message.content) + matches = [] + + for match in llm_matches: + server_name = match["server_name"] + tool_name = match["tool_name"] + + # Find the tool + tool = next( + (t for t in available_tools[server_name] if t.name == tool_name), + None + ) + + if tool: + args = await self.argument_extractor.extract_arguments(text, tool) + matches.append(ToolMatch( + server_name=server_name, + tool_name=tool_name, + arguments=args, + confidence=match["confidence"] + )) + + return matches + except (json.JSONDecodeError, KeyError): + logger.error("Failed to parse LLM tool matching response") + return [] + +class ResourceManager: + """Manages MCP resources""" + + def __init__(self, client, model: str): + self.client = client + self.model = model + self.resource_cache: Dict[str, List[types.Resource]] = {} + + async def find_relevant_resources( + self, + text: str, + available_resources: Dict[str, List[types.Resource]] + ) -> List[str]: + """Find resources relevant to the given text""" + prompt = f""" + Given this user request: "{text}" + + And these available resources: + {json.dumps([ + { + "server": server, + "resources": [ + {"uri": r.uri, "name": r.name, "description": r.description} + for r in resources + ] + } + for server, resources in available_resources.items() + ], indent=2)} + + Which resources would be most relevant? Return only a JSON array of resource URIs. + """ + + response = self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + temperature=0.1 + ) + + try: + return json.loads(response.choices[0].message.content) + except json.JSONDecodeError: + logger.error("Failed to parse resource relevance response") + return [] + +class MCPClientManager: + """Manages multiple MCP client connections""" + + def __init__(self, config_path: Optional[str] = None): + self.config_path = config_path or os.path.expanduser("~/mcp_config.json") + self.sessions: Dict[str, ClientSession] = {} + self.tools_cache: Dict[str, List[types.Tool]] = {} + self.resources_cache: Dict[str, List[types.Resource]] = {} + self.prompts_cache: Dict[str, List[types.Prompt]] = {} + self.config: Optional[MCPConfig] = None + + def validate_config(self, config_data: Dict[str, Any]) -> MCPConfig: + """Validate configuration using pydantic""" + try: + return MCPConfig(**config_data) + except pydantic.ValidationError as e: + logger.error(f"Invalid configuration: {e}") + raise + + async def initialize_servers(self): + """Initialize connections to all configured servers""" + try: + with open(self.config_path, 'r') as f: + config_data = json.load(f) + + self.config = self.validate_config(config_data) + + for server_name, server_config in self.config.mcpServers.items(): + try: + await self.connect_server(server_name, server_config) + except Exception as e: + logger.error(f"Failed to connect to server {server_name}: {e}") + + except FileNotFoundError: + logger.warning(f"MCP config file not found at {self.config_path}") + except Exception as e: + logger.error(f"Error initializing MCP servers: {e}") + + async def connect_server(self, server_name: str, server_config: ServerConfig): + """Connect to a single MCP server""" + if server_config.type == ServerType.STDIO: + server_params = StdioServerParameters( + command=server_config.command, + args=server_config.args, + env=server_config.env + ) + + transport = await stdio_client(server_params) + read_stream, write_stream = transport + + elif server_config.type == ServerType.HTTP: + # HTTP transport implementation would go here + raise NotImplementedError("HTTP transport not yet implemented") + + try: + session = ClientSession(read_stream, write_stream) + await session.initialize() + + # Cache available capabilities + tools_result = await session.list_tools() + self.tools_cache[server_name] = tools_result.tools + + resources_result = await session.list_resources() + self.resources_cache[server_name] = resources_result.resources + + prompts_result = await session.list_prompts() + self.prompts_cache[server_name] = prompts_result.prompts + + self.sessions[server_name] = session + logger.info(f"Connected to MCP server: {server_name}") + + except Exception as e: + logger.error(f"Error connecting to server {server_name}: {e}") + raise + + async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> str: + """Call a tool on a specific server""" + if server_name not in self.sessions: + raise ValueError(f"Server {server_name} not connected") + + # Validate against allowed tools if configured + server_config = self.config.mcpServers[server_name] + if server_config.allowed_tools is not None: + if tool_name not in server_config.allowed_tools: + raise ValueError(f"Tool {tool_name} not allowed on server {server_name}") + + session = self.sessions[server_name] + result = await session.call_tool(tool_name, arguments) + + # Extract text content from result + text_contents = [] + for content in result.content: + if content.type == "text": + text_contents.append(content.text) + + return "\n".join(text_contents) + + async def read_resource(self, server_name: str, uri: str) -> Tuple[str, Optional[str]]: + """Read a resource from a server""" + if server_name not in self.sessions: + raise ValueError(f"Server {server_name} not connected") + + # Check resource access permissions + server_config = self.config.mcpServers[server_name] + if server_config.resource_access == ResourceAccess.NONE: + raise ValueError(f"Resource access not allowed on server {server_name}") + + session = self.sessions[server_name] + result = await session.read_resource(uri) + + # Return first content and its MIME type + if result.contents: + content = result.contents[0] + return content.text or content.blob or "", content.mimeType + return "", None + + async def get_prompt(self, server_name: str, prompt_name: str, arguments: Dict[str, Any]) -> str: + """Get a prompt from a server""" + if server_name not in self.sessions: + raise ValueError(f"Server {server_name} not connected") + + session = self.sessions[server_name] + result = await session.get_prompt(prompt_name, arguments) + + # Convert prompt messages to text + messages = [] + for msg in result.messages: + if msg.content.type == "text": + messages.append(f"{msg.role}: {msg.content.text}") + + return "\n".join(messages) + + async def cleanup(self): + """Clean up all server connections""" + for session in self.sessions.values(): + await session.aclose() + self.sessions.clear() + self.tools_cache.clear() + self.resources_cache.clear() + self.prompts_cache.clear() + +class MCPPlugin: + """optillm plugin for MCP integration""" + + def __init__(self): + self.client_manager = MCPClientManager() + self.initialized = False + self.tool_matcher: Optional[ToolMatcher] = None + self.resource_manager: Optional[ResourceManager] = None + + async def ensure_initialized(self, client, model: str): + """Initialize if not already done""" + if not self.initialized: + await self.client_manager.initialize_servers() + self.tool_matcher = ToolMatcher(client, model) + self.resource_manager = ResourceManager(client, model) + self.initialized = True + + async def process_request( + self, + messages: List[Dict[str, Any]], + model: str + ) -> str: + """Process the request and handle MCP interactions""" + # Last message contains the current request + current_message = messages[-1]["content"] + + # Find matching tools + tool_matches = await self.tool_matcher.find_matching_tools( + current_message, + self.client_manager.tools_cache + ) + + # Find relevant resources + relevant_resources = await self.resource_manager.find_relevant_resources( + current_message, + self.client_manager.resources_cache + ) + + # Collect context and results + context_parts = [] + + # Add resource content + for uri in relevant_resources: + server_name = uri.split("://")[0] # Simple server extraction from URI + try: + content, mime_type = await self.client_manager.read_resource( + server_name, + uri + ) + if content: + context_parts.append(f"Resource {uri}:\n{content}") + except Exception as e: + logger.error(f"Error reading resource {uri}: {e}") + + # Execute tool calls + for match in tool_matches: + try: + result = await self.client_manager.call_tool( + match.server_name, + match.tool_name, + match.arguments + ) + context_parts.append(f"Tool {match.tool_name} result:\n{result}") + except Exception as e: + logger.error(f"Error calling tool {match.tool_name}: {e}") + context_parts.append(f"Error calling tool {match.tool_name}: {str(e)}") + + # Build final context + context = "\n\n".join(context_parts) + if context: + return f"{current_message}\n\nContext:\n{context}" + return current_message + + async def handle_tool_error( + self, + error: Exception, + tool_match: ToolMatch, + client, + model: str + ) -> str: + """Handle tool execution errors intelligently""" + prompt = f""" + An error occurred while executing tool '{tool_match.tool_name}': + Error: {str(error)} + + The tool was called with these arguments: + {json.dumps(tool_match.arguments, indent=2)} + + Analyze the error and provide a brief explanation of what went wrong. + Focus on possible solutions or alternatives. + """ + + response = client.chat.completions.create( + model=model, + messages=[{"role": "user", "content": prompt}], + temperature=0.3 + ) + + return response.choices[0].message.content + +async def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str, int]: + """Main plugin execution function""" + plugin = MCPPlugin() + + try: + await plugin.ensure_initialized(client, model) + + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": initial_query} + ] + + processed_query = await plugin.process_request(messages, model) + + # Create a system prompt that includes MCP capabilities + enhanced_system_prompt = f""" + {system_prompt} + + You have access to the following MCP capabilities: + + Tools: + {json.dumps([ + { + "server": server, + "tools": [ + {"name": t.name, "description": t.description} + for t in tools + ] + } + for server, tools in plugin.client_manager.tools_cache.items() + ], indent=2)} + + Resources: + {json.dumps([ + { + "server": server, + "resources": [ + {"uri": r.uri, "name": r.name, "description": r.description} + for r in resources + ] + } + for server, resources in plugin.client_manager.resources_cache.items() + ], indent=2)} + + Prompts: + {json.dumps([ + { + "server": server, + "prompts": [ + {"name": p.name, "description": p.description} + for p in prompts + ] + } + for server, prompts in plugin.client_manager.prompts_cache.items() + ], indent=2)} + """ + + # Pass the processed query and enhanced system prompt to the model + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": enhanced_system_prompt}, + {"role": "user", "content": processed_query} + ], + temperature=0.7, + ) + + return response.choices[0].message.content, response.usage.completion_tokens + + except Exception as e: + logger.error(f"Error in MCP plugin: {str(e)}") + # In case of error, pass through the original query + return initial_query, 0 + finally: + await plugin.client_manager.cleanup() + +def validate_config_file(config_path: str) -> None: + """Validate MCP configuration file""" + try: + with open(config_path, 'r') as f: + config_data = json.load(f) + + MCPConfig(**config_data) + except FileNotFoundError: + raise ValueError(f"Configuration file not found: {config_path}") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON in configuration file: {e}") + except pydantic.ValidationError as e: + raise ValueError(f"Invalid configuration format: {e}") + except Exception as e: + raise ValueError(f"Error validating configuration: {e}") + +def create_default_config(config_path: str) -> None: + """Create a default MCP configuration file""" + default_config = { + "mcpServers": { + "example": { + "type": "stdio", + "command": "python", + "args": ["example_server.py"], + "env": {}, + "resource_access": "read_only", + "description": "Example MCP server" + } + }, + "log_level": "INFO" + } + + os.makedirs(os.path.dirname(config_path), exist_ok=True) + with open(config_path, 'w') as f: + json.dump(default_config, f, indent=2) + +async def test_server_connection( + server_name: str, + server_config: ServerConfig +) -> Tuple[bool, str]: + """Test connection to a single MCP server""" + try: + if server_config.type == ServerType.STDIO: + server_params = StdioServerParameters( + command=server_config.command, + args=server_config.args, + env=server_config.env + ) + + transport = await stdio_client(server_params) + read_stream, write_stream = transport + + session = ClientSession(read_stream, write_stream) + await session.initialize() + + # Test basic operations + await session.list_tools() + await session.list_resources() + await session.list_prompts() + + await session.aclose() + return True, "Connection successful" + + elif server_config.type == ServerType.HTTP: + return False, "HTTP transport not yet implemented" + + except Exception as e: + return False, f"Connection failed: {str(e)}" \ No newline at end of file From 84383766aca4dcc55afbc1494b0e77b56535f647 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Sat, 8 Mar 2025 13:44:40 +0800 Subject: [PATCH 3/8] Update optillm.py --- optillm.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/optillm.py b/optillm.py index bb509198..0e2e8939 100644 --- a/optillm.py +++ b/optillm.py @@ -118,6 +118,29 @@ def get_config(): plugin_approaches = {} +def normalize_message_content(messages): + """ + Ensure all message content fields are strings, not lists. + Some models don't handle list-format content correctly. + """ + normalized_messages = [] + for message in messages: + normalized_message = message.copy() + content = message.get('content', '') + + # Convert list content to string if needed + if isinstance(content, list): + # Extract text content from the list + text_content = ' '.join( + item.get('text', '') for item in content + if isinstance(item, dict) and item.get('type') == 'text' + ) + normalized_message['content'] = text_content + + normalized_messages.append(normalized_message) + + return normalized_messages + def none_approach( client: Any, model: str, @@ -143,10 +166,13 @@ def none_approach( model = model[5:] try: - # Make the direct completion call with original messages and parameters + # Normalize message content to ensure it's always string + normalized_messages = normalize_message_content(original_messages) + + # Make the direct completion call with normalized messages and parameters response = client.chat.completions.create( model=model, - messages=original_messages, + messages=normalized_messages, **kwargs ) From 7f822d97e19d795db8467c9765a3cfa7c587e3ce Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Sat, 8 Mar 2025 19:37:53 +0800 Subject: [PATCH 4/8] Update optillm.py --- optillm.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/optillm.py b/optillm.py index 0e2e8939..498ba2ec 100644 --- a/optillm.py +++ b/optillm.py @@ -346,12 +346,32 @@ def execute_single_approach(approach, system_prompt, initial_query, client, mode import inspect sig = inspect.signature(plugin_func) - if 'request_config' in sig.parameters: - # Plugin supports request_config - return plugin_func(system_prompt, initial_query, client, model, request_config=request_config) + # Check if the plugin function is async + is_async = inspect.iscoroutinefunction(plugin_func) + + if is_async: + # For async functions, we need to run them in an event loop + import asyncio + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + if 'request_config' in sig.parameters: + # Plugin supports request_config + result = loop.run_until_complete(plugin_func(system_prompt, initial_query, client, model, request_config=request_config)) + else: + # Legacy plugin without request_config support + result = loop.run_until_complete(plugin_func(system_prompt, initial_query, client, model)) + return result + finally: + loop.close() else: - # Legacy plugin without request_config support - return plugin_func(system_prompt, initial_query, client, model) + # For synchronous functions, call directly + if 'request_config' in sig.parameters: + # Plugin supports request_config + return plugin_func(system_prompt, initial_query, client, model, request_config=request_config) + else: + # Legacy plugin without request_config support + return plugin_func(system_prompt, initial_query, client, model) else: raise ValueError(f"Unknown approach: {approach}") From 164bd74189bc1602bd3adc17ded8cc913114c839 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 10 Mar 2025 12:20:29 +0800 Subject: [PATCH 5/8] Update mcp_plugin.py --- optillm/plugins/mcp_plugin.py | 1110 ++++++++++++++++----------------- 1 file changed, 548 insertions(+), 562 deletions(-) diff --git a/optillm/plugins/mcp_plugin.py b/optillm/plugins/mcp_plugin.py index 9a50bdb0..b93c249e 100644 --- a/optillm/plugins/mcp_plugin.py +++ b/optillm/plugins/mcp_plugin.py @@ -1,639 +1,625 @@ +""" +MCP Plugin for OptILLM + +This plugin integrates the Model Context Protocol (MCP) with OptILLM, +allowing access to external tools, resources, and prompts through MCP servers. +""" + import os import json import logging import asyncio -from typing import Dict, Any, Tuple, Optional, List -from dataclasses import dataclass -from enum import Enum +import sys +import time import re -import pydantic +from typing import Dict, List, Any, Optional, Tuple, Set, Union, Callable +from dataclasses import dataclass +from pathlib import Path +import traceback + from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client import mcp.types as types -logger = logging.getLogger(__name__) +# Configure logging +LOG_DIR = Path.home() / ".optillm" / "logs" +LOG_DIR.mkdir(parents=True, exist_ok=True) +LOG_FILE = LOG_DIR / "mcp_plugin.log" -# Plugin identifier -SLUG = "mcp" +# Configure root logger +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + handlers=[ + logging.FileHandler(LOG_FILE), + logging.StreamHandler() + ] +) -class ServerType(str, Enum): - """Supported server types""" - STDIO = "stdio" - HTTP = "http" +logger = logging.getLogger("optillm.mcp_plugin") -class ResourceAccess(str, Enum): - """Resource access modes""" - READ_ONLY = "read_only" - READ_WRITE = "read_write" - NONE = "none" +# Plugin identifier +SLUG = "mcp" -class ServerConfig(pydantic.BaseModel): +@dataclass +class ServerConfig: """Configuration for a single MCP server""" - type: ServerType = ServerType.STDIO command: str - args: List[str] = [] - env: Dict[str, str] = {} - url: Optional[str] = None # For HTTP servers - resource_access: ResourceAccess = ResourceAccess.READ_ONLY - allowed_tools: Optional[List[str]] = None # If None, all tools allowed + args: List[str] + env: Dict[str, str] description: Optional[str] = None - -class MCPConfig(pydantic.BaseModel): - """Root configuration model""" - mcpServers: Dict[str, ServerConfig] - default_model: Optional[str] = None - log_level: str = "INFO" - -@dataclass -class ToolMatch: - """Represents a matched tool with extracted arguments""" - server_name: str - tool_name: str - arguments: Dict[str, Any] - confidence: float -@dataclass -class PromptMatch: - """Represents a matched prompt template""" - server_name: str - prompt_name: str - arguments: Dict[str, Any] + @classmethod + def from_dict(cls, config: Dict[str, Any]) -> 'ServerConfig': + """Create ServerConfig from a dictionary""" + return cls( + command=config.get("command", ""), + args=config.get("args", []), + env=config.get("env", {}), + description=config.get("description") + ) -class ArgumentExtractor: - """Extracts arguments from text using LLM""" +class MCPConfigManager: + """Manages MCP configuration loading and validation""" - def __init__(self, client, model: str): - self.client = client - self.model = model - - async def extract_arguments(self, text: str, tool: types.Tool) -> Dict[str, Any]: - """Use LLM to extract arguments from text""" - prompt = f""" - Extract arguments for the tool '{tool.name}' from this text: "{text}" - - The tool accepts these arguments: - {json.dumps(tool.inputSchema, indent=2)} - - Return only a JSON object with the extracted arguments. If an argument can't be found, omit it. - """ - - response = self.client.chat.completions.create( - model=self.model, - messages=[{"role": "user", "content": prompt}], - temperature=0.1 - ) - + def __init__(self, config_path: Optional[str] = None): + """Initialize with optional custom config path""" + if config_path: + self.config_path = Path(config_path) + else: + self.config_path = Path.home() / ".optillm" / "mcp_config.json" + + # Default configuration + self.servers: Dict[str, ServerConfig] = {} + self.log_level: str = "INFO" + + def load_config(self) -> bool: + """Load configuration from file""" try: - return json.loads(response.choices[0].message.content) - except json.JSONDecodeError: - logger.error("Failed to parse argument extraction response") - return {} - -class ToolMatcher: - """Matches tools to text using semantic and pattern matching""" + if not self.config_path.exists(): + logger.warning(f"MCP config file not found at {self.config_path}") + return False + + with open(self.config_path, 'r') as f: + config = json.load(f) + + # Set log level + self.log_level = config.get("log_level", "INFO") + log_level = getattr(logging, self.log_level.upper(), logging.INFO) + logger.setLevel(log_level) + + # Load server configurations + servers_config = config.get("mcpServers", {}) + for server_name, server_config in servers_config.items(): + self.servers[server_name] = ServerConfig.from_dict(server_config) + + logger.info(f"Loaded configuration with {len(self.servers)} servers") + return True + + except Exception as e: + logger.error(f"Error loading MCP configuration: {e}") + logger.error(traceback.format_exc()) + return False - def __init__(self, client, model: str): - self.client = client - self.model = model - self.argument_extractor = ArgumentExtractor(client, model) - - async def find_matching_tools( - self, - text: str, - available_tools: Dict[str, List[types.Tool]] - ) -> List[ToolMatch]: - """Find tools that match the given text""" - matches = [] - - # First pass: Look for explicit tool mentions - for server_name, tools in available_tools.items(): - for tool in tools: - # Check for direct name matches - if tool.name.lower() in text.lower(): - args = await self.argument_extractor.extract_arguments(text, tool) - matches.append(ToolMatch( - server_name=server_name, - tool_name=tool.name, - arguments=args, - confidence=0.9 - )) - continue + def create_default_config(self) -> bool: + """Create a default configuration file if none exists""" + try: + if self.config_path.exists(): + return True - # Check for semantic matches using the tool description - if tool.description and self._semantic_match(text, tool.description): - args = await self.argument_extractor.extract_arguments(text, tool) - matches.append(ToolMatch( - server_name=server_name, - tool_name=tool.name, - arguments=args, - confidence=0.7 - )) - - # Use LLM for additional tool matching if needed - if not matches: - llm_matches = await self._llm_tool_matching(text, available_tools) - matches.extend(llm_matches) - - # Sort by confidence - matches.sort(key=lambda x: x.confidence, reverse=True) - return matches - - def _semantic_match(self, text: str, description: str) -> bool: - """Simple semantic matching using keywords""" - keywords = description.lower().split() - text_words = text.lower().split() - matches = sum(1 for word in keywords if any(w.startswith(word) for w in text_words)) - return matches / len(keywords) > 0.5 + default_config = { + "mcpServers": {}, + "log_level": "INFO" + } + + self.config_path.parent.mkdir(parents=True, exist_ok=True) + with open(self.config_path, 'w') as f: + json.dump(default_config, f, indent=2) + + logger.info(f"Created default configuration at {self.config_path}") + return True + + except Exception as e: + logger.error(f"Error creating default configuration: {e}") + return False - async def _llm_tool_matching( - self, - text: str, - available_tools: Dict[str, List[types.Tool]] - ) -> List[ToolMatch]: - """Use LLM to find matching tools""" - # Create tool descriptions - tool_descriptions = [] - for server_name, tools in available_tools.items(): - for tool in tools: - desc = f"Server: {server_name}, Tool: {tool.name}" - if tool.description: - desc += f", Description: {tool.description}" - tool_descriptions.append(desc) - - prompt = f""" - Given this user request: "{text}" - - And these available tools: - {json.dumps(tool_descriptions, indent=2)} - - Which tools would be most appropriate to use? Return a JSON array of objects with: - - server_name: The server name - - tool_name: The tool name - - confidence: A number between 0 and 1 indicating confidence - - Only include tools with confidence > 0.5. - """ - - response = self.client.chat.completions.create( - model=self.model, - messages=[{"role": "user", "content": prompt}], - temperature=0.1 - ) - +class MCPServer: + """Represents a connection to an MCP server""" + + def __init__(self, server_name: str, config: ServerConfig): + self.server_name = server_name + self.config = config + self.session: Optional[ClientSession] = None + self.transport: Optional[Tuple] = None + self.connected = False + self.tools: List[types.Tool] = [] + self.resources: List[types.Resource] = [] + self.prompts: List[types.Prompt] = [] + + async def connect(self) -> bool: + """Connect to the MCP server""" try: - llm_matches = json.loads(response.choices[0].message.content) - matches = [] + logger.info(f"Connecting to MCP server: {self.server_name}") + + # Create server parameters + server_params = StdioServerParameters( + command=self.config.command, + args=self.config.args, + env=self.config.env + ) - for match in llm_matches: - server_name = match["server_name"] - tool_name = match["tool_name"] + # Create transport using async with + transport = None + try: + # Using context manager in a way that's compatible with asyncio + ctx = stdio_client(server_params) + transport = await ctx.__aenter__() + self.transport = transport - # Find the tool - tool = next( - (t for t in available_tools[server_name] if t.name == tool_name), - None - ) + read_stream, write_stream = transport - if tool: - args = await self.argument_extractor.extract_arguments(text, tool) - matches.append(ToolMatch( - server_name=server_name, - tool_name=tool_name, - arguments=args, - confidence=match["confidence"] - )) - - return matches - except (json.JSONDecodeError, KeyError): - logger.error("Failed to parse LLM tool matching response") - return [] - -class ResourceManager: - """Manages MCP resources""" + # Create session + self.session = ClientSession(read_stream, write_stream) + + # Initialize session + await self.session.initialize() + + # Discover capabilities + await self.discover_capabilities() + + self.connected = True + logger.info(f"Successfully connected to MCP server: {self.server_name}") + return True + + except Exception as e: + # Make sure to clean up resources in case of an error + if transport: + try: + await ctx.__aexit__(type(e), e, e.__traceback__) + except: + pass + raise + + except Exception as e: + logger.error(f"Error connecting to MCP server {self.server_name}: {e}") + logger.error(traceback.format_exc()) + + if self.session: + try: + await self.session.aclose() + except: + pass + + self.session = None + self.connected = False + return False - def __init__(self, client, model: str): - self.client = client - self.model = model - self.resource_cache: Dict[str, List[types.Resource]] = {} - - async def find_relevant_resources( - self, - text: str, - available_resources: Dict[str, List[types.Resource]] - ) -> List[str]: - """Find resources relevant to the given text""" - prompt = f""" - Given this user request: "{text}" - - And these available resources: - {json.dumps([ - { - "server": server, - "resources": [ - {"uri": r.uri, "name": r.name, "description": r.description} - for r in resources - ] - } - for server, resources in available_resources.items() - ], indent=2)} - - Which resources would be most relevant? Return only a JSON array of resource URIs. - """ - - response = self.client.chat.completions.create( - model=self.model, - messages=[{"role": "user", "content": prompt}], - temperature=0.1 - ) - + async def discover_capabilities(self) -> bool: + """Discover the server's capabilities""" + if not self.session: + logger.error(f"Cannot discover capabilities for {self.server_name}: Not connected") + return False + try: - return json.loads(response.choices[0].message.content) - except json.JSONDecodeError: - logger.error("Failed to parse resource relevance response") - return [] - -class MCPClientManager: - """Manages multiple MCP client connections""" + # List tools + tools_result = await self.session.list_tools() + self.tools = tools_result.tools + + # List resources + resources_result = await self.session.list_resources() + self.resources = resources_result.resources + + # List prompts + prompts_result = await self.session.list_prompts() + self.prompts = prompts_result.prompts + + logger.info(f"Server {self.server_name} capabilities: " + f"{len(self.tools)} tools, {len(self.resources)} resources, " + f"{len(self.prompts)} prompts") + return True + + except Exception as e: + logger.error(f"Error discovering capabilities for {self.server_name}: {e}") + logger.error(traceback.format_exc()) + return False - def __init__(self, config_path: Optional[str] = None): - self.config_path = config_path or os.path.expanduser("~/mcp_config.json") - self.sessions: Dict[str, ClientSession] = {} - self.tools_cache: Dict[str, List[types.Tool]] = {} - self.resources_cache: Dict[str, List[types.Resource]] = {} - self.prompts_cache: Dict[str, List[types.Prompt]] = {} - self.config: Optional[MCPConfig] = None - - def validate_config(self, config_data: Dict[str, Any]) -> MCPConfig: - """Validate configuration using pydantic""" - try: - return MCPConfig(**config_data) - except pydantic.ValidationError as e: - logger.error(f"Invalid configuration: {e}") - raise + async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Call a tool on this server""" + if not self.session or not self.connected: + logger.error(f"Cannot call tool for {self.server_name}: Not connected") + return {"error": f"Server {self.server_name} is not connected"} - async def initialize_servers(self): - """Initialize connections to all configured servers""" try: - with open(self.config_path, 'r') as f: - config_data = json.load(f) + # Find the tool + tool = next((t for t in self.tools if t.name == tool_name), None) + if not tool: + return {"error": f"Tool {tool_name} not found on server {self.server_name}"} + + # Call the tool + logger.info(f"Calling tool {tool_name} on server {self.server_name} with arguments: {arguments}") + result = await self.session.call_tool(tool_name, arguments) - self.config = self.validate_config(config_data) + # Process the result + content_results = [] + for content in result.content: + if content.type == "text": + content_results.append({ + "type": "text", + "text": content.text + }) + elif content.type == "image": + content_results.append({ + "type": "image", + "data": content.data, + "mimeType": content.mimeType + }) + + return { + "result": content_results, + "is_error": result.isError + } - for server_name, server_config in self.config.mcpServers.items(): - try: - await self.connect_server(server_name, server_config) - except Exception as e: - logger.error(f"Failed to connect to server {server_name}: {e}") - - except FileNotFoundError: - logger.warning(f"MCP config file not found at {self.config_path}") except Exception as e: - logger.error(f"Error initializing MCP servers: {e}") - - async def connect_server(self, server_name: str, server_config: ServerConfig): - """Connect to a single MCP server""" - if server_config.type == ServerType.STDIO: - server_params = StdioServerParameters( - command=server_config.command, - args=server_config.args, - env=server_config.env - ) + logger.error(f"Error calling tool {tool_name} on server {self.server_name}: {e}") + logger.error(traceback.format_exc()) + return {"error": f"Error calling tool: {str(e)}"} + + async def read_resource(self, uri: str) -> Dict[str, Any]: + """Read a resource from this server""" + if not self.session or not self.connected: + logger.error(f"Cannot read resource for {self.server_name}: Not connected") + return {"error": f"Server {self.server_name} is not connected"} - transport = await stdio_client(server_params) - read_stream, write_stream = transport + try: + # Find the resource + resource = next((r for r in self.resources if r.uri == uri), None) + if not resource: + return {"error": f"Resource {uri} not found on server {self.server_name}"} + + # Read the resource + logger.info(f"Reading resource {uri} from server {self.server_name}") + result = await self.session.read_resource(uri) - elif server_config.type == ServerType.HTTP: - # HTTP transport implementation would go here - raise NotImplementedError("HTTP transport not yet implemented") + # Process the result + content_results = [] + for content in result.contents: + if content.text: + content_results.append({ + "type": "text", + "uri": content.uri or uri, + "text": content.text, + "mimeType": content.mimeType + }) + elif content.blob: + content_results.append({ + "type": "binary", + "uri": content.uri or uri, + "data": content.blob, + "mimeType": content.mimeType + }) - try: - session = ClientSession(read_stream, write_stream) - await session.initialize() + return { + "result": content_results + } - # Cache available capabilities - tools_result = await session.list_tools() - self.tools_cache[server_name] = tools_result.tools + except Exception as e: + logger.error(f"Error reading resource {uri} from server {self.server_name}: {e}") + logger.error(traceback.format_exc()) + return {"error": f"Error reading resource: {str(e)}"} + + async def get_prompt(self, prompt_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Get a prompt from this server""" + if not self.session or not self.connected: + logger.error(f"Cannot get prompt for {self.server_name}: Not connected") + return {"error": f"Server {self.server_name} is not connected"} - resources_result = await session.list_resources() - self.resources_cache[server_name] = resources_result.resources + try: + # Find the prompt + prompt = next((p for p in self.prompts if p.name == prompt_name), None) + if not prompt: + return {"error": f"Prompt {prompt_name} not found on server {self.server_name}"} + + # Get the prompt + logger.info(f"Getting prompt {prompt_name} from server {self.server_name} with arguments: {arguments}") + result = await self.session.get_prompt(prompt_name, arguments) - prompts_result = await session.list_prompts() - self.prompts_cache[server_name] = prompts_result.prompts + # Process the result + messages = [] + for msg in result.messages: + if msg.content.type == "text": + messages.append({ + "role": msg.role, + "content": msg.content.text + }) + elif msg.content.type == "image": + messages.append({ + "role": msg.role, + "content": { + "type": "image", + "data": msg.content.data, + "mimeType": msg.content.mimeType + } + }) - self.sessions[server_name] = session - logger.info(f"Connected to MCP server: {server_name}") + return { + "result": messages + } except Exception as e: - logger.error(f"Error connecting to server {server_name}: {e}") - raise - - async def call_tool(self, server_name: str, tool_name: str, arguments: Dict[str, Any]) -> str: - """Call a tool on a specific server""" - if server_name not in self.sessions: - raise ValueError(f"Server {server_name} not connected") - - # Validate against allowed tools if configured - server_config = self.config.mcpServers[server_name] - if server_config.allowed_tools is not None: - if tool_name not in server_config.allowed_tools: - raise ValueError(f"Tool {tool_name} not allowed on server {server_name}") - - session = self.sessions[server_name] - result = await session.call_tool(tool_name, arguments) - - # Extract text content from result - text_contents = [] - for content in result.content: - if content.type == "text": - text_contents.append(content.text) - - return "\n".join(text_contents) - - async def read_resource(self, server_name: str, uri: str) -> Tuple[str, Optional[str]]: - """Read a resource from a server""" - if server_name not in self.sessions: - raise ValueError(f"Server {server_name} not connected") - - # Check resource access permissions - server_config = self.config.mcpServers[server_name] - if server_config.resource_access == ResourceAccess.NONE: - raise ValueError(f"Resource access not allowed on server {server_name}") - - session = self.sessions[server_name] - result = await session.read_resource(uri) - - # Return first content and its MIME type - if result.contents: - content = result.contents[0] - return content.text or content.blob or "", content.mimeType - return "", None - - async def get_prompt(self, server_name: str, prompt_name: str, arguments: Dict[str, Any]) -> str: - """Get a prompt from a server""" - if server_name not in self.sessions: - raise ValueError(f"Server {server_name} not connected") - - session = self.sessions[server_name] - result = await session.get_prompt(prompt_name, arguments) - - # Convert prompt messages to text - messages = [] - for msg in result.messages: - if msg.content.type == "text": - messages.append(f"{msg.role}: {msg.content.text}") - - return "\n".join(messages) - - async def cleanup(self): - """Clean up all server connections""" - for session in self.sessions.values(): - await session.aclose() - self.sessions.clear() - self.tools_cache.clear() - self.resources_cache.clear() - self.prompts_cache.clear() + logger.error(f"Error getting prompt {prompt_name} from server {self.server_name}: {e}") + logger.error(traceback.format_exc()) + return {"error": f"Error getting prompt: {str(e)}"} + + async def close(self): + """Close the connection to the server""" + if self.session: + try: + await self.session.aclose() + logger.info(f"Closed connection to MCP server: {self.server_name}") + except Exception as e: + logger.error(f"Error closing connection to {self.server_name}: {e}") + finally: + self.session = None + self.connected = False -class MCPPlugin: - """optillm plugin for MCP integration""" +class MCPServerManager: + """Manages MCP server connections and capabilities""" - def __init__(self): - self.client_manager = MCPClientManager() + def __init__(self, config_manager: MCPConfigManager): + self.config_manager = config_manager + self.servers: Dict[str, MCPServer] = {} self.initialized = False - self.tool_matcher: Optional[ToolMatcher] = None - self.resource_manager: Optional[ResourceManager] = None - - async def ensure_initialized(self, client, model: str): - """Initialize if not already done""" - if not self.initialized: - await self.client_manager.initialize_servers() - self.tool_matcher = ToolMatcher(client, model) - self.resource_manager = ResourceManager(client, model) - self.initialized = True - - async def process_request( - self, - messages: List[Dict[str, Any]], - model: str - ) -> str: - """Process the request and handle MCP interactions""" - # Last message contains the current request - current_message = messages[-1]["content"] - - # Find matching tools - tool_matches = await self.tool_matcher.find_matching_tools( - current_message, - self.client_manager.tools_cache - ) - - # Find relevant resources - relevant_resources = await self.resource_manager.find_relevant_resources( - current_message, - self.client_manager.resources_cache - ) + + async def initialize(self) -> bool: + """Initialize connections to all configured servers""" + if self.initialized: + return True - # Collect context and results - context_parts = [] + # Create servers + for server_name, server_config in self.config_manager.servers.items(): + self.servers[server_name] = MCPServer(server_name, server_config) + + # Connect to all servers asynchronously + if self.servers: + connect_tasks = [server.connect() for server in self.servers.values()] + results = await asyncio.gather(*connect_tasks, return_exceptions=True) - # Add resource content - for uri in relevant_resources: - server_name = uri.split("://")[0] # Simple server extraction from URI - try: - content, mime_type = await self.client_manager.read_resource( - server_name, - uri - ) - if content: - context_parts.append(f"Resource {uri}:\n{content}") - except Exception as e: - logger.error(f"Error reading resource {uri}: {e}") + # Check how many servers connected successfully + success_count = sum(1 for r in results if r is True) + logger.info(f"Connected to {success_count}/{len(self.servers)} MCP servers") - # Execute tool calls - for match in tool_matches: - try: - result = await self.client_manager.call_tool( - match.server_name, - match.tool_name, - match.arguments - ) - context_parts.append(f"Tool {match.tool_name} result:\n{result}") - except Exception as e: - logger.error(f"Error calling tool {match.tool_name}: {e}") - context_parts.append(f"Error calling tool {match.tool_name}: {str(e)}") - - # Build final context - context = "\n\n".join(context_parts) - if context: - return f"{current_message}\n\nContext:\n{context}" - return current_message - - async def handle_tool_error( - self, - error: Exception, - tool_match: ToolMatch, - client, - model: str - ) -> str: - """Handle tool execution errors intelligently""" - prompt = f""" - An error occurred while executing tool '{tool_match.tool_name}': - Error: {str(error)} + if success_count > 0: + self.initialized = True + return True + else: + logger.error("Failed to connect to any MCP servers") + return False + else: + logger.warning("No MCP servers configured") + self.initialized = True + return True + + def get_tools_for_model(self) -> List[Dict[str, Any]]: + """Get tools from all servers in a format suitable for the model's tool-calling API""" + tools = [] - The tool was called with these arguments: - {json.dumps(tool_match.arguments, indent=2)} + for server_name, server in self.servers.items(): + if not server.connected or not server.tools: + continue + + for tool in server.tools: + # Convert MCP tool to model tool format + tool_entry = { + "type": "function", + "function": { + "name": f"{server_name}.{tool.name}", + "description": tool.description or f"Tool {tool.name} from server {server_name}", + "parameters": tool.inputSchema + } + } + tools.append(tool_entry) + + return tools + + def get_capabilities_description(self) -> str: + """Get a formatted description of all server capabilities""" + if not self.servers: + return "No MCP servers available." + + description_parts = [] - Analyze the error and provide a brief explanation of what went wrong. - Focus on possible solutions or alternatives. - """ + for server_name, server in self.servers.items(): + if not server.connected: + description_parts.append(f"## {server_name}\nServer connection failed or not established.\n") + continue + + server_description = f"## {server_name}\n" + + if server.config.description: + server_description += f"{server.config.description}\n\n" + + if server.tools: + server_description += "### Tools\n" + for tool in server.tools: + server_description += f"- {server_name}.{tool.name}: {tool.description or 'No description'}\n" + server_description += "\n" + + if server.resources: + server_description += "### Resources\n" + for resource in server.resources: + server_description += f"- {resource.uri}: {resource.name or 'No name'} - {resource.description or 'No description'}\n" + server_description += "\n" + + if server.prompts: + server_description += "### Prompts\n" + for prompt in server.prompts: + server_description += f"- {prompt.name}: {prompt.description or 'No description'}\n" + server_description += "\n" + + description_parts.append(server_description) + + return "\n".join(description_parts) + + async def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """Execute a tool on the appropriate server""" + if "." not in tool_name: + return {"error": f"Invalid tool name format: {tool_name}. Expected format: server_name.tool_name"} + + server_name, function_name = tool_name.split(".", 1) - response = client.chat.completions.create( - model=model, - messages=[{"role": "user", "content": prompt}], - temperature=0.3 - ) + if server_name not in self.servers: + return {"error": f"Server not found: {server_name}"} + + server = self.servers[server_name] + if not server.connected: + return {"error": f"Server {server_name} is not connected"} + + # Execute the tool + return await server.call_tool(function_name, arguments) + + async def close(self): + """Close all server connections""" + if not self.servers: + return + + # Close all server connections in parallel + close_tasks = [server.close() for server in self.servers.values()] + await asyncio.gather(*close_tasks, return_exceptions=True) - return response.choices[0].message.content + self.servers = {} + self.initialized = False async def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str, int]: - """Main plugin execution function""" - plugin = MCPPlugin() + """ + Main plugin execution function called by OptILLM + + Args: + system_prompt: System prompt + initial_query: User query + client: OptILLM client + model: Model identifier + + Returns: + Tuple of (response text, token usage) + """ + logger.info(f"MCP Plugin run called with model: {model}") + + # Create server manager + config_manager = MCPConfigManager() + server_manager = MCPServerManager(config_manager) try: - await plugin.ensure_initialized(client, model) - - messages = [ - {"role": "system", "content": system_prompt}, - {"role": "user", "content": initial_query} - ] - - processed_query = await plugin.process_request(messages, model) - - # Create a system prompt that includes MCP capabilities - enhanced_system_prompt = f""" - {system_prompt} + # Load configuration + if not config_manager.load_config(): + # Try to create default config + config_manager.create_default_config() + # Try loading again + if not config_manager.load_config(): + logger.error("Failed to load or create MCP configuration") + # In case of no configuration, pass through the original query + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": initial_query} + ], + temperature=0.7, + ) + return response.choices[0].message.content, response.usage.completion_tokens - You have access to the following MCP capabilities: + # Initialize server manager + await server_manager.initialize() - Tools: - {json.dumps([ - { - "server": server, - "tools": [ - {"name": t.name, "description": t.description} - for t in tools - ] - } - for server, tools in plugin.client_manager.tools_cache.items() - ], indent=2)} + # Get tools formatted for the model + tools = server_manager.get_tools_for_model() - Resources: - {json.dumps([ - { - "server": server, - "resources": [ - {"uri": r.uri, "name": r.name, "description": r.description} - for r in resources - ] - } - for server, resources in plugin.client_manager.resources_cache.items() - ], indent=2)} + # Get capabilities description + capabilities_description = server_manager.get_capabilities_description() - Prompts: - {json.dumps([ - { - "server": server, - "prompts": [ - {"name": p.name, "description": p.description} - for p in prompts - ] - } - for server, prompts in plugin.client_manager.prompts_cache.items() - ], indent=2)} - """ + # Enhance system prompt with MCP capabilities + enhanced_system_prompt = f"{system_prompt}\n\nYou have access to the following MCP capabilities:\n\n{capabilities_description}" - # Pass the processed query and enhanced system prompt to the model + # First request - ask the model what it wants to do + logger.info("Sending initial request to model") response = client.chat.completions.create( model=model, messages=[ {"role": "system", "content": enhanced_system_prompt}, - {"role": "user", "content": processed_query} + {"role": "user", "content": initial_query} ], + tools=tools if tools else None, # Only include tools if available temperature=0.7, ) - return response.choices[0].message.content, response.usage.completion_tokens - - except Exception as e: - logger.error(f"Error in MCP plugin: {str(e)}") - # In case of error, pass through the original query - return initial_query, 0 - finally: - await plugin.client_manager.cleanup() - -def validate_config_file(config_path: str) -> None: - """Validate MCP configuration file""" - try: - with open(config_path, 'r') as f: - config_data = json.load(f) + # Check if the model wants to use any tools + response_message = response.choices[0].message + response_content = response_message.content or "" - MCPConfig(**config_data) - except FileNotFoundError: - raise ValueError(f"Configuration file not found: {config_path}") - except json.JSONDecodeError as e: - raise ValueError(f"Invalid JSON in configuration file: {e}") - except pydantic.ValidationError as e: - raise ValueError(f"Invalid configuration format: {e}") - except Exception as e: - raise ValueError(f"Error validating configuration: {e}") - -def create_default_config(config_path: str) -> None: - """Create a default MCP configuration file""" - default_config = { - "mcpServers": { - "example": { - "type": "stdio", - "command": "python", - "args": ["example_server.py"], - "env": {}, - "resource_access": "read_only", - "description": "Example MCP server" - } - }, - "log_level": "INFO" - } - - os.makedirs(os.path.dirname(config_path), exist_ok=True) - with open(config_path, 'w') as f: - json.dump(default_config, f, indent=2) - -async def test_server_connection( - server_name: str, - server_config: ServerConfig -) -> Tuple[bool, str]: - """Test connection to a single MCP server""" - try: - if server_config.type == ServerType.STDIO: - server_params = StdioServerParameters( - command=server_config.command, - args=server_config.args, - env=server_config.env - ) - - transport = await stdio_client(server_params) - read_stream, write_stream = transport - - session = ClientSession(read_stream, write_stream) - await session.initialize() + # Check for tool calls + if hasattr(response_message, "tool_calls") and response_message.tool_calls: + logger.info(f"Model requested tool calls: {len(response_message.tool_calls)}") - # Test basic operations - await session.list_tools() - await session.list_resources() - await session.list_prompts() + # Create new messages with the original system and user message + messages = [ + {"role": "system", "content": enhanced_system_prompt}, + {"role": "user", "content": initial_query}, + {"role": "assistant", "content": response_content, "tool_calls": response_message.tool_calls} + ] - await session.aclose() - return True, "Connection successful" + # Process each tool call + for tool_call in response_message.tool_calls: + tool_call_id = tool_call.id + tool_name = tool_call.function.name + try: + # Parse arguments + arguments = json.loads(tool_call.function.arguments) + + # Execute tool + logger.info(f"Executing tool: {tool_name} with arguments: {arguments}") + result = await server_manager.execute_tool(tool_name, arguments) + + # Add tool result to messages + messages.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "content": json.dumps(result) + }) + except Exception as e: + logger.error(f"Error processing tool call {tool_name}: {e}") + messages.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "content": json.dumps({"error": f"Error: {str(e)}"}) + }) - elif server_config.type == ServerType.HTTP: - return False, "HTTP transport not yet implemented" + # Send follow-up request with tool results + logger.info("Sending follow-up request to model with tool results") + final_response = client.chat.completions.create( + model=model, + messages=messages, + tools=tools if tools else None, # Keep tools available in case the model wants to make additional calls + temperature=0.7, + ) + final_message = final_response.choices[0].message + response_text = final_message.content or "" + token_usage = final_response.usage.completion_tokens + else: + # Model didn't call any tools, use its initial response + response_text = response_content + token_usage = response.usage.completion_tokens + + return response_text, token_usage + except Exception as e: - return False, f"Connection failed: {str(e)}" \ No newline at end of file + logger.error(f"Error in MCP plugin run: {e}") + logger.error(traceback.format_exc()) + # In case of error, pass through the original query + return initial_query, 0 + + finally: + # Always clean up server connections + try: + await server_manager.close() + except Exception as e: + logger.error(f"Error cleaning up server connections: {e}") From 8f1a3be88097500e85bfdabbc14b07b7f41c65ac Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 10 Mar 2025 12:38:18 +0800 Subject: [PATCH 6/8] Update mcp_plugin.py --- optillm/plugins/mcp_plugin.py | 145 ++++++++++++++++++++++++++++++---- 1 file changed, 131 insertions(+), 14 deletions(-) diff --git a/optillm/plugins/mcp_plugin.py b/optillm/plugins/mcp_plugin.py index b93c249e..23490cd5 100644 --- a/optillm/plugins/mcp_plugin.py +++ b/optillm/plugins/mcp_plugin.py @@ -12,6 +12,8 @@ import sys import time import re +import shutil +import subprocess from typing import Dict, List, Any, Optional, Tuple, Set, Union, Callable from dataclasses import dataclass from pathlib import Path @@ -41,6 +43,115 @@ # Plugin identifier SLUG = "mcp" +def find_executable(cmd: str) -> Optional[str]: + """ + Find the full path to an executable command. + + This function will: + 1. Check if the command exists in PATH + 2. Check common install locations + 3. Try npm prefix paths + + Args: + cmd: The command to find + + Returns: + Full path to the executable if found, None otherwise + """ + # First check if it's already a full path + if os.path.isfile(cmd) and os.access(cmd, os.X_OK): + return cmd + + # Next check if it's in PATH + cmd_path = shutil.which(cmd) + if cmd_path: + logger.info(f"Found {cmd} in PATH at {cmd_path}") + return cmd_path + + # Try common locations for Node.js tools + common_paths = [ + "/usr/local/bin", + "/usr/bin", + "/bin", + "/opt/homebrew/bin", # macOS with Homebrew + "/opt/homebrew/opt/node/bin", # Specific Homebrew Node.js location + "/usr/local/opt/node/bin", + os.path.expanduser("~/.npm-global/bin"), + os.path.expanduser("~/.nvm/current/bin"), + os.path.expanduser("~/npm/bin"), + os.path.expanduser("~/.npm/bin"), + ] + + for path in common_paths: + full_path = os.path.join(path, cmd) + if os.path.isfile(full_path) and os.access(full_path, os.X_OK): + logger.info(f"Found {cmd} at {full_path}") + return full_path + + # Try using npm to find global bin path + try: + npm_bin_path = subprocess.run( + ["npm", "bin", "-g"], + capture_output=True, + text=True, + check=True + ).stdout.strip() + + full_path = os.path.join(npm_bin_path, cmd) + if os.path.isfile(full_path) and os.access(full_path, os.X_OK): + logger.info(f"Found {cmd} in npm global bin at {full_path}") + return full_path + except: + pass + + # If all else fails, create a wrapper script that sources profile files + wrapper_path = create_command_wrapper(cmd) + if wrapper_path: + logger.info(f"Created wrapper script for {cmd} at {wrapper_path}") + return wrapper_path + + logger.error(f"Could not find executable: {cmd}") + return None + +def create_command_wrapper(cmd: str) -> Optional[str]: + """ + Create a shell wrapper script for a command that might be in PATH after shell initialization. + + Args: + cmd: The command to wrap + + Returns: + Path to the wrapper script if successful, None otherwise + """ + try: + wrapper_dir = Path.home() / ".optillm" / "wrappers" + wrapper_dir.mkdir(parents=True, exist_ok=True) + + wrapper_path = wrapper_dir / f"{cmd}_wrapper.sh" + + with open(wrapper_path, 'w') as f: + f.write(f"""#!/bin/bash +# Source profile to get correct PATH +if [ -f ~/.bash_profile ]; then + source ~/.bash_profile +elif [ -f ~/.profile ]; then + source ~/.profile +elif [ -f ~/.zshrc ]; then + source ~/.zshrc +fi + +# Run command with all arguments +exec {cmd} "$@" +""") + + # Make the script executable + os.chmod(wrapper_path, 0o755) + + return str(wrapper_path) + except Exception as e: + logger.error(f"Error creating wrapper script: {e}") + return None + @dataclass class ServerConfig: """Configuration for a single MCP server""" @@ -141,27 +252,36 @@ async def connect(self) -> bool: try: logger.info(f"Connecting to MCP server: {self.server_name}") - # Create server parameters + # Find the full path to the command + full_command = find_executable(self.config.command) + if not full_command: + logger.error(f"Failed to find executable for command: {self.config.command}") + logger.error("Please make sure the command is in your PATH or use an absolute path") + return False + + # Create environment with PATH included + merged_env = os.environ.copy() + if self.config.env: + merged_env.update(self.config.env) + + # Create server parameters with the full command path server_params = StdioServerParameters( - command=self.config.command, + command=full_command, args=self.config.args, - env=self.config.env + env=merged_env ) # Create transport using async with - transport = None try: - # Using context manager in a way that's compatible with asyncio + # Using context manager directly with try/except/finally for cleanup ctx = stdio_client(server_params) transport = await ctx.__aenter__() self.transport = transport read_stream, write_stream = transport - # Create session + # Create and initialize session self.session = ClientSession(read_stream, write_stream) - - # Initialize session await self.session.initialize() # Discover capabilities @@ -172,12 +292,9 @@ async def connect(self) -> bool: return True except Exception as e: - # Make sure to clean up resources in case of an error - if transport: - try: - await ctx.__aexit__(type(e), e, e.__traceback__) - except: - pass + # Clean up resources in case of error + if 'ctx' in locals() and 'transport' in locals(): + await ctx.__aexit__(type(e), e, e.__traceback__) raise except Exception as e: From e7846b8a6228dae8d064ffc7bc5468af778c6ba1 Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 10 Mar 2025 13:48:26 +0800 Subject: [PATCH 7/8] Update mcp_plugin.py init version --- optillm/plugins/mcp_plugin.py | 676 ++++++++++++++-------------------- 1 file changed, 285 insertions(+), 391 deletions(-) diff --git a/optillm/plugins/mcp_plugin.py b/optillm/plugins/mcp_plugin.py index 23490cd5..10a39ad0 100644 --- a/optillm/plugins/mcp_plugin.py +++ b/optillm/plugins/mcp_plugin.py @@ -22,6 +22,7 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client import mcp.types as types +from mcp.shared.exceptions import McpError # Configure logging LOG_DIR = Path.home() / ".optillm" / "logs" @@ -47,11 +48,6 @@ def find_executable(cmd: str) -> Optional[str]: """ Find the full path to an executable command. - This function will: - 1. Check if the command exists in PATH - 2. Check common install locations - 3. Try npm prefix paths - Args: cmd: The command to find @@ -68,18 +64,14 @@ def find_executable(cmd: str) -> Optional[str]: logger.info(f"Found {cmd} in PATH at {cmd_path}") return cmd_path - # Try common locations for Node.js tools + # Try common locations common_paths = [ "/usr/local/bin", "/usr/bin", "/bin", - "/opt/homebrew/bin", # macOS with Homebrew - "/opt/homebrew/opt/node/bin", # Specific Homebrew Node.js location - "/usr/local/opt/node/bin", + "/opt/homebrew/bin", os.path.expanduser("~/.npm-global/bin"), os.path.expanduser("~/.nvm/current/bin"), - os.path.expanduser("~/npm/bin"), - os.path.expanduser("~/.npm/bin"), ] for path in common_paths: @@ -87,71 +79,10 @@ def find_executable(cmd: str) -> Optional[str]: if os.path.isfile(full_path) and os.access(full_path, os.X_OK): logger.info(f"Found {cmd} at {full_path}") return full_path - - # Try using npm to find global bin path - try: - npm_bin_path = subprocess.run( - ["npm", "bin", "-g"], - capture_output=True, - text=True, - check=True - ).stdout.strip() - - full_path = os.path.join(npm_bin_path, cmd) - if os.path.isfile(full_path) and os.access(full_path, os.X_OK): - logger.info(f"Found {cmd} in npm global bin at {full_path}") - return full_path - except: - pass - - # If all else fails, create a wrapper script that sources profile files - wrapper_path = create_command_wrapper(cmd) - if wrapper_path: - logger.info(f"Created wrapper script for {cmd} at {wrapper_path}") - return wrapper_path - + logger.error(f"Could not find executable: {cmd}") return None -def create_command_wrapper(cmd: str) -> Optional[str]: - """ - Create a shell wrapper script for a command that might be in PATH after shell initialization. - - Args: - cmd: The command to wrap - - Returns: - Path to the wrapper script if successful, None otherwise - """ - try: - wrapper_dir = Path.home() / ".optillm" / "wrappers" - wrapper_dir.mkdir(parents=True, exist_ok=True) - - wrapper_path = wrapper_dir / f"{cmd}_wrapper.sh" - - with open(wrapper_path, 'w') as f: - f.write(f"""#!/bin/bash -# Source profile to get correct PATH -if [ -f ~/.bash_profile ]; then - source ~/.bash_profile -elif [ -f ~/.profile ]; then - source ~/.profile -elif [ -f ~/.zshrc ]; then - source ~/.zshrc -fi - -# Run command with all arguments -exec {cmd} "$@" -""") - - # Make the script executable - os.chmod(wrapper_path, 0o755) - - return str(wrapper_path) - except Exception as e: - logger.error(f"Error creating wrapper script: {e}") - return None - @dataclass class ServerConfig: """Configuration for a single MCP server""" @@ -240,255 +171,131 @@ class MCPServer: def __init__(self, server_name: str, config: ServerConfig): self.server_name = server_name self.config = config - self.session: Optional[ClientSession] = None - self.transport: Optional[Tuple] = None + self.tools = [] + self.resources = [] + self.prompts = [] self.connected = False - self.tools: List[types.Tool] = [] - self.resources: List[types.Resource] = [] - self.prompts: List[types.Prompt] = [] + self.has_tools_capability = False + self.has_resources_capability = False + self.has_prompts_capability = False - async def connect(self) -> bool: - """Connect to the MCP server""" - try: - logger.info(f"Connecting to MCP server: {self.server_name}") - - # Find the full path to the command - full_command = find_executable(self.config.command) - if not full_command: - logger.error(f"Failed to find executable for command: {self.config.command}") - logger.error("Please make sure the command is in your PATH or use an absolute path") - return False - - # Create environment with PATH included - merged_env = os.environ.copy() - if self.config.env: - merged_env.update(self.config.env) + async def connect_and_discover(self) -> bool: + """Connect to the server and discover capabilities using proper context management""" + logger.info(f"Connecting to MCP server: {self.server_name}") + + # Find the full path to the command + full_command = find_executable(self.config.command) + if not full_command: + logger.error(f"Failed to find executable for command: {self.config.command}") + return False - # Create server parameters with the full command path - server_params = StdioServerParameters( - command=full_command, - args=self.config.args, - env=merged_env + # Create environment with PATH included + merged_env = os.environ.copy() + if self.config.env: + merged_env.update(self.config.env) + + # Create server parameters + server_params = StdioServerParameters( + command=full_command, + args=self.config.args, + env=merged_env + ) + + try: + # Start the server separately to see its output + process = await asyncio.create_subprocess_exec( + full_command, + *self.config.args, + env=merged_env, + stderr=asyncio.subprocess.PIPE ) - # Create transport using async with - try: - # Using context manager directly with try/except/finally for cleanup - ctx = stdio_client(server_params) - transport = await ctx.__aenter__() - self.transport = transport - - read_stream, write_stream = transport - - # Create and initialize session - self.session = ClientSession(read_stream, write_stream) - await self.session.initialize() - - # Discover capabilities - await self.discover_capabilities() - - self.connected = True - logger.info(f"Successfully connected to MCP server: {self.server_name}") - return True - - except Exception as e: - # Clean up resources in case of error - if 'ctx' in locals() and 'transport' in locals(): - await ctx.__aexit__(type(e), e, e.__traceback__) - raise - - except Exception as e: - logger.error(f"Error connecting to MCP server {self.server_name}: {e}") - logger.error(traceback.format_exc()) - - if self.session: - try: - await self.session.aclose() - except: - pass + # Log startup message from stderr + async def log_stderr(): + while True: + line = await process.stderr.readline() + if not line: + break + logger.info(f"Server {self.server_name} stderr: {line.decode().strip()}") - self.session = None - self.connected = False - return False - - async def discover_capabilities(self) -> bool: - """Discover the server's capabilities""" - if not self.session: - logger.error(f"Cannot discover capabilities for {self.server_name}: Not connected") - return False + # Start stderr logging task + asyncio.create_task(log_stderr()) - try: - # List tools - tools_result = await self.session.list_tools() - self.tools = tools_result.tools + # Wait a bit for the server to start up + await asyncio.sleep(2) - # List resources - resources_result = await self.session.list_resources() - self.resources = resources_result.resources - - # List prompts - prompts_result = await self.session.list_prompts() - self.prompts = prompts_result.prompts + # Use the MCP client with proper context management + async with stdio_client(server_params) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + logger.info(f"Connected to server: {self.server_name}") + + # Initialize session + result = await session.initialize() + logger.info(f"Server {self.server_name} initialized with capabilities: {result.capabilities}") + + # Check which capabilities the server supports + server_capabilities = result.capabilities + + # Discover tools if supported + if hasattr(server_capabilities, "tools"): + self.has_tools_capability = True + logger.info(f"Discovering tools for {self.server_name}") + try: + tools_result = await session.list_tools() + self.tools = tools_result.tools + logger.info(f"Found {len(self.tools)} tools") + except McpError as e: + logger.warning(f"Failed to list tools: {e}") + + # Discover resources if supported + if hasattr(server_capabilities, "resources"): + self.has_resources_capability = True + logger.info(f"Discovering resources for {self.server_name}") + try: + resources_result = await session.list_resources() + self.resources = resources_result.resources + logger.info(f"Found {len(self.resources)} resources") + except McpError as e: + logger.warning(f"Failed to list resources: {e}") + + # Discover prompts if supported + if hasattr(server_capabilities, "prompts"): + self.has_prompts_capability = True + logger.info(f"Discovering prompts for {self.server_name}") + try: + prompts_result = await session.list_prompts() + self.prompts = prompts_result.prompts + logger.info(f"Found {len(self.prompts)} prompts") + except McpError as e: + logger.warning(f"Failed to list prompts: {e}") + + logger.info(f"Server {self.server_name} capabilities: " + f"{len(self.tools)} tools, {len(self.resources)} resources, " + f"{len(self.prompts)} prompts") - logger.info(f"Server {self.server_name} capabilities: " - f"{len(self.tools)} tools, {len(self.resources)} resources, " - f"{len(self.prompts)} prompts") + self.connected = True return True except Exception as e: - logger.error(f"Error discovering capabilities for {self.server_name}: {e}") + logger.error(f"Error connecting to MCP server {self.server_name}: {e}") logger.error(traceback.format_exc()) return False - - async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: - """Call a tool on this server""" - if not self.session or not self.connected: - logger.error(f"Cannot call tool for {self.server_name}: Not connected") - return {"error": f"Server {self.server_name} is not connected"} - - try: - # Find the tool - tool = next((t for t in self.tools if t.name == tool_name), None) - if not tool: - return {"error": f"Tool {tool_name} not found on server {self.server_name}"} - - # Call the tool - logger.info(f"Calling tool {tool_name} on server {self.server_name} with arguments: {arguments}") - result = await self.session.call_tool(tool_name, arguments) - - # Process the result - content_results = [] - for content in result.content: - if content.type == "text": - content_results.append({ - "type": "text", - "text": content.text - }) - elif content.type == "image": - content_results.append({ - "type": "image", - "data": content.data, - "mimeType": content.mimeType - }) - - return { - "result": content_results, - "is_error": result.isError - } - - except Exception as e: - logger.error(f"Error calling tool {tool_name} on server {self.server_name}: {e}") - logger.error(traceback.format_exc()) - return {"error": f"Error calling tool: {str(e)}"} - - async def read_resource(self, uri: str) -> Dict[str, Any]: - """Read a resource from this server""" - if not self.session or not self.connected: - logger.error(f"Cannot read resource for {self.server_name}: Not connected") - return {"error": f"Server {self.server_name} is not connected"} - - try: - # Find the resource - resource = next((r for r in self.resources if r.uri == uri), None) - if not resource: - return {"error": f"Resource {uri} not found on server {self.server_name}"} - - # Read the resource - logger.info(f"Reading resource {uri} from server {self.server_name}") - result = await self.session.read_resource(uri) - - # Process the result - content_results = [] - for content in result.contents: - if content.text: - content_results.append({ - "type": "text", - "uri": content.uri or uri, - "text": content.text, - "mimeType": content.mimeType - }) - elif content.blob: - content_results.append({ - "type": "binary", - "uri": content.uri or uri, - "data": content.blob, - "mimeType": content.mimeType - }) - - return { - "result": content_results - } - - except Exception as e: - logger.error(f"Error reading resource {uri} from server {self.server_name}: {e}") - logger.error(traceback.format_exc()) - return {"error": f"Error reading resource: {str(e)}"} - - async def get_prompt(self, prompt_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: - """Get a prompt from this server""" - if not self.session or not self.connected: - logger.error(f"Cannot get prompt for {self.server_name}: Not connected") - return {"error": f"Server {self.server_name} is not connected"} - - try: - # Find the prompt - prompt = next((p for p in self.prompts if p.name == prompt_name), None) - if not prompt: - return {"error": f"Prompt {prompt_name} not found on server {self.server_name}"} - - # Get the prompt - logger.info(f"Getting prompt {prompt_name} from server {self.server_name} with arguments: {arguments}") - result = await self.session.get_prompt(prompt_name, arguments) - - # Process the result - messages = [] - for msg in result.messages: - if msg.content.type == "text": - messages.append({ - "role": msg.role, - "content": msg.content.text - }) - elif msg.content.type == "image": - messages.append({ - "role": msg.role, - "content": { - "type": "image", - "data": msg.content.data, - "mimeType": msg.content.mimeType - } - }) - - return { - "result": messages - } - - except Exception as e: - logger.error(f"Error getting prompt {prompt_name} from server {self.server_name}: {e}") - logger.error(traceback.format_exc()) - return {"error": f"Error getting prompt: {str(e)}"} - - async def close(self): - """Close the connection to the server""" - if self.session: - try: - await self.session.aclose() - logger.info(f"Closed connection to MCP server: {self.server_name}") - except Exception as e: - logger.error(f"Error closing connection to {self.server_name}: {e}") - finally: - self.session = None - self.connected = False class MCPServerManager: - """Manages MCP server connections and capabilities""" + """Manages MCP servers and capabilities""" def __init__(self, config_manager: MCPConfigManager): self.config_manager = config_manager self.servers: Dict[str, MCPServer] = {} self.initialized = False + + # Cache of capabilities + self.all_tools = [] + self.all_resources = [] + self.all_prompts = [] async def initialize(self) -> bool: - """Initialize connections to all configured servers""" + """Initialize and cache all server capabilities""" if self.initialized: return True @@ -496,50 +303,65 @@ async def initialize(self) -> bool: for server_name, server_config in self.config_manager.servers.items(): self.servers[server_name] = MCPServer(server_name, server_config) - # Connect to all servers asynchronously - if self.servers: - connect_tasks = [server.connect() for server in self.servers.values()] - results = await asyncio.gather(*connect_tasks, return_exceptions=True) - - # Check how many servers connected successfully - success_count = sum(1 for r in results if r is True) - logger.info(f"Connected to {success_count}/{len(self.servers)} MCP servers") - - if success_count > 0: - self.initialized = True - return True - else: - logger.error("Failed to connect to any MCP servers") - return False - else: - logger.warning("No MCP servers configured") - self.initialized = True - return True + # Connect to all servers and discover capabilities + for server_name, server in self.servers.items(): + success = await server.connect_and_discover() + if success: + # Cache server capabilities + for tool in server.tools: + self.all_tools.append({ + "server": server_name, + "name": tool.name, + "description": tool.description, + "input_schema": tool.inputSchema + }) + + for resource in server.resources: + self.all_resources.append({ + "server": server_name, + "uri": resource.uri, + "name": resource.name, + "description": resource.description + }) + + for prompt in server.prompts: + self.all_prompts.append({ + "server": server_name, + "name": prompt.name, + "description": prompt.description, + "arguments": prompt.arguments + }) + + self.initialized = True + + # Check if we successfully connected to any servers + connected_servers = sum(1 for server in self.servers.values() if server.connected) + logger.info(f"Connected to {connected_servers}/{len(self.servers)} MCP servers") + return connected_servers > 0 def get_tools_for_model(self) -> List[Dict[str, Any]]: - """Get tools from all servers in a format suitable for the model's tool-calling API""" + """Get tools in a format suitable for the model's tool-calling API""" tools = [] - for server_name, server in self.servers.items(): - if not server.connected or not server.tools: - continue - - for tool in server.tools: - # Convert MCP tool to model tool format - tool_entry = { - "type": "function", - "function": { - "name": f"{server_name}.{tool.name}", - "description": tool.description or f"Tool {tool.name} from server {server_name}", - "parameters": tool.inputSchema - } + for tool_info in self.all_tools: + server_name = tool_info["server"] + tool_name = tool_info["name"] + + # Format for model tools API + tool_entry = { + "type": "function", + "function": { + "name": f"{server_name}.{tool_name}", + "description": tool_info["description"] or f"Tool {tool_name} from server {server_name}", + "parameters": tool_info["input_schema"] } - tools.append(tool_entry) + } + tools.append(tool_entry) return tools def get_capabilities_description(self) -> str: - """Get a formatted description of all server capabilities""" + """Get a description of all capabilities""" if not self.servers: return "No MCP servers available." @@ -576,35 +398,74 @@ def get_capabilities_description(self) -> str: description_parts.append(server_description) return "\n".join(description_parts) + +async def execute_tool(server_name: str, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: + """ + Execute a tool on an MCP server - async def execute_tool(self, tool_name: str, arguments: Dict[str, Any]) -> Dict[str, Any]: - """Execute a tool on the appropriate server""" - if "." not in tool_name: - return {"error": f"Invalid tool name format: {tool_name}. Expected format: server_name.tool_name"} - - server_name, function_name = tool_name.split(".", 1) - - if server_name not in self.servers: - return {"error": f"Server not found: {server_name}"} - - server = self.servers[server_name] - if not server.connected: - return {"error": f"Server {server_name} is not connected"} - - # Execute the tool - return await server.call_tool(function_name, arguments) + This function creates a fresh connection for each tool execution to ensure reliability. + """ + logger.info(f"Executing tool {tool_name} on server {server_name} with arguments: {arguments}") - async def close(self): - """Close all server connections""" - if not self.servers: - return - - # Close all server connections in parallel - close_tasks = [server.close() for server in self.servers.values()] - await asyncio.gather(*close_tasks, return_exceptions=True) - - self.servers = {} - self.initialized = False + # Load configuration + config_manager = MCPConfigManager() + if not config_manager.load_config(): + return {"error": "Failed to load MCP configuration"} + + # Get server configuration + server_config = config_manager.servers.get(server_name) + if not server_config: + return {"error": f"Server {server_name} not found in configuration"} + + # Find executable + full_command = find_executable(server_config.command) + if not full_command: + return {"error": f"Failed to find executable for command: {server_config.command}"} + + # Create environment with PATH included + merged_env = os.environ.copy() + if server_config.env: + merged_env.update(server_config.env) + + # Create server parameters + server_params = StdioServerParameters( + command=full_command, + args=server_config.args, + env=merged_env + ) + + try: + # Use the MCP client with proper context management + async with stdio_client(server_params) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: + # Call the tool and get the result + logger.info(f"Calling tool {tool_name} with arguments: {arguments}") + result = await session.call_tool(tool_name, arguments) + + # Process the result + content_results = [] + for content in result.content: + if content.type == "text": + content_results.append({ + "type": "text", + "text": content.text + }) + elif content.type == "image": + content_results.append({ + "type": "image", + "data": content.data, + "mimeType": content.mimeType + }) + + return { + "result": content_results, + "is_error": result.isError + } + + except Exception as e: + logger.error(f"Error executing tool {tool_name} on server {server_name}: {e}") + logger.error(traceback.format_exc()) + return {"error": f"Error executing tool: {str(e)}"} async def run(system_prompt: str, initial_query: str, client, model: str) -> Tuple[str, int]: """ @@ -621,12 +482,9 @@ async def run(system_prompt: str, initial_query: str, client, model: str) -> Tup """ logger.info(f"MCP Plugin run called with model: {model}") - # Create server manager - config_manager = MCPConfigManager() - server_manager = MCPServerManager(config_manager) - try: # Load configuration + config_manager = MCPConfigManager() if not config_manager.load_config(): # Try to create default config config_manager.create_default_config() @@ -645,10 +503,34 @@ async def run(system_prompt: str, initial_query: str, client, model: str) -> Tup return response.choices[0].message.content, response.usage.completion_tokens # Initialize server manager - await server_manager.initialize() + server_manager = MCPServerManager(config_manager) + success = await server_manager.initialize() + + if not success: + logger.warning("Failed to connect to any MCP servers, falling back to default behavior") + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": initial_query} + ], + temperature=0.7, + ) + return response.choices[0].message.content, response.usage.completion_tokens # Get tools formatted for the model tools = server_manager.get_tools_for_model() + if not tools: + logger.warning("No tools available from MCP servers") + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": initial_query} + ], + temperature=0.7, + ) + return response.choices[0].message.content, response.usage.completion_tokens # Get capabilities description capabilities_description = server_manager.get_capabilities_description() @@ -686,27 +568,38 @@ async def run(system_prompt: str, initial_query: str, client, model: str) -> Tup # Process each tool call for tool_call in response_message.tool_calls: tool_call_id = tool_call.id - tool_name = tool_call.function.name - try: - # Parse arguments - arguments = json.loads(tool_call.function.arguments) - - # Execute tool - logger.info(f"Executing tool: {tool_name} with arguments: {arguments}") - result = await server_manager.execute_tool(tool_name, arguments) + full_tool_name = tool_call.function.name + + # Split into server and tool name + if "." in full_tool_name: + server_name, tool_name = full_tool_name.split(".", 1) - # Add tool result to messages - messages.append({ - "role": "tool", - "tool_call_id": tool_call_id, - "content": json.dumps(result) - }) - except Exception as e: - logger.error(f"Error processing tool call {tool_name}: {e}") + try: + # Parse arguments + arguments = json.loads(tool_call.function.arguments) + + # Execute tool (creates a fresh connection for reliability) + result = await execute_tool(server_name, tool_name, arguments) + + # Add tool result to messages + messages.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "content": json.dumps(result) + }) + except Exception as e: + logger.error(f"Error processing tool call {full_tool_name}: {e}") + messages.append({ + "role": "tool", + "tool_call_id": tool_call_id, + "content": json.dumps({"error": f"Error: {str(e)}"}) + }) + else: + # Invalid tool name format messages.append({ "role": "tool", "tool_call_id": tool_call_id, - "content": json.dumps({"error": f"Error: {str(e)}"}) + "content": json.dumps({"error": f"Invalid tool name format: {full_tool_name}. Expected format: server_name.tool_name"}) }) # Send follow-up request with tool results @@ -732,11 +625,12 @@ async def run(system_prompt: str, initial_query: str, client, model: str) -> Tup logger.error(f"Error in MCP plugin run: {e}") logger.error(traceback.format_exc()) # In case of error, pass through the original query - return initial_query, 0 - - finally: - # Always clean up server connections - try: - await server_manager.close() - except Exception as e: - logger.error(f"Error cleaning up server connections: {e}") + response = client.chat.completions.create( + model=model, + messages=[ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": initial_query} + ], + temperature=0.7, + ) + return response.choices[0].message.content, response.usage.completion_tokens From 8dbf4a1cd069ef211ddbefa66664150c35743a6e Mon Sep 17 00:00:00 2001 From: Asankhaya Sharma Date: Mon, 10 Mar 2025 15:35:33 +0800 Subject: [PATCH 8/8] bump version --- optillm/__init__.py | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/optillm/__init__.py b/optillm/__init__.py index 9ff9e0a4..81258dfe 100644 --- a/optillm/__init__.py +++ b/optillm/__init__.py @@ -2,7 +2,7 @@ import os # Version information -__version__ = "0.1.6" +__version__ = "0.1.7" # Get the path to the root optillm.py spec = util.spec_from_file_location( diff --git a/setup.py b/setup.py index 6fc6d188..8e5d7315 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ setup( name="optillm", - version="0.1.6", + version="0.1.7", packages=find_packages(include=['optillm', 'optillm.*']), # This ensures all subpackages are included py_modules=['optillm'], package_data={