Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"watchfiles>=1.1.0",
"sqlite-vec>=0.1.6",
"tiktoken>=0.12.0",
"semver>=3.0.4",
]

[dependency-groups]
Expand Down
189 changes: 91 additions & 98 deletions src/mcp_optimizer/toolhive/toolhive_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
"""

import asyncio
import time
from functools import wraps
from typing import Any, Awaitable, Callable, Self, TypeVar
from urllib.parse import urlparse

import httpx
import structlog
from semver import Version

from mcp_optimizer.toolhive.api_models.core import Workload
from mcp_optimizer.toolhive.api_models.registry import ImageMetadata, Registry, RemoteServerMetadata
Expand All @@ -31,6 +31,12 @@ class ToolhiveConnectionError(Exception):
pass


class ToolhiveScanError(Exception):
"""Exception raised when unable to find ToolHive in the specified port range."""

pass


class ToolhiveClient:
"""Client for interacting with the Toolhive API."""

Expand Down Expand Up @@ -86,31 +92,33 @@ def __init__(

self._discover_port(port)

def _discover_port(self, port: int | None = None) -> None:
async def _discover_port_async(self, port: int | None = None) -> None:
"""
Discover the ToolHive port.
Async version: Discover the ToolHive port.

Args:
port: Optional specific port to try first
"""
if port is not None:
for attempt in range(3):
if self._is_toolhive_available(self.thv_host, port):
try:
_, port = await self._is_toolhive_available(self.thv_host, port)
self.thv_port = port
break
logger.warning(
"ToolHive not available at specified host/port, retrying...",
host=self.thv_host,
port=port,
attempt=attempt + 1,
)
time.sleep(1)
except ToolhiveScanError:
logger.warning(
"ToolHive not available at specified host/port, retrying...",
host=self.thv_host,
port=port,
attempt=attempt + 1,
)
await asyncio.sleep(1)
# If port is not found yet (either not specified or retries failed),
# try scanning the port range
if self.thv_port is None:
try:
# Scan for ToolHive in the port range
self.thv_port = self._scan_for_toolhive(
self.thv_port = await self._scan_for_toolhive(
self.thv_host, self.scan_port_start, self.scan_port_end
)
except Exception as e:
Expand All @@ -120,8 +128,42 @@ def _discover_port(self, port: int | None = None) -> None:
self.base_url = f"http://{self.thv_host}:{self.thv_port}"
logger.info("ToolhiveClient initialized", host=self.thv_host, port=self.thv_port)

async def _is_toolhive_available_async(self, host: str, port: int) -> bool:
"""Async version: Check if ToolHive is available at the given host and port."""
def _discover_port(self, port: int | None = None) -> None:
"""
Discover the ToolHive port.
Detects if there's a running event loop and executes appropriately.

Args:
port: Optional specific port to try first
"""
try:
# Try to get the running event loop
asyncio.get_running_loop()
# If we get here, there's a running loop - we need to create a task
# This shouldn't happen in __init__, but we'll handle it gracefully
raise RuntimeError(
"_discover_port called from async context. Use _discover_port_async instead."
)
except RuntimeError:
# No running loop - safe to use asyncio.run()
asyncio.run(self._discover_port_async(port))

def _parse_toolhive_version(self, version_str: str) -> Version:
"""Parse ToolHive version string into a Version object."""
try:
version = Version.parse(version_str.replace("v", ""))
return version
except (ValueError, TypeError) as e:
logger.warning("Invalid semver version string", version=version_str, error=str(e))
raise ToolhiveScanError(f"Invalid ToolHive version string: {version_str}") from e

async def _is_toolhive_available(self, host: str, port: int) -> tuple[Version, int]:
"""
Check if ToolHive is available at the given host and port.

Returns:
Tuple of (Version object, port) if available, raises ToolhiveScanError otherwise
"""
try:
async with httpx.AsyncClient(timeout=1.0) as client:
response = await client.get(f"http://{host}:{port}/api/v1beta/version")
Expand All @@ -137,60 +179,52 @@ async def _is_toolhive_available_async(self, host: str, port: int) -> bool:
port=port,
response=data,
)
return False
return True
except (ValueError, KeyError):
raise ToolhiveScanError(
f"Port {port} on host {host} did not respond with ToolHive format"
)
parsed_version = self._parse_toolhive_version(data["version"])
logger.info(
"Found ToolHive instance", host=host, port=port, version=str(parsed_version)
)
return parsed_version, port
except (ValueError, KeyError, ToolhiveScanError) as e:
logger.debug("Port responded but could not parse JSON", host=host, port=port)
return False
except (httpx.HTTPError, OSError):
return False

async def _scan_for_toolhive_async(
self, host: str, scan_port_start: int, scan_port_end: int
) -> int:
raise ToolhiveScanError(
f"Port {port} on host {host} did not respond with valid JSON"
) from e
except (httpx.HTTPError, OSError) as e:
logger.debug("Error checking ToolHive availability", host=host, port=port, error=str(e))
raise ToolhiveScanError(f"Error checking ToolHive availability on {host}:{port}") from e

async def _scan_for_toolhive(self, host: str, scan_port_start: int, scan_port_end: int) -> int:
"""Async version: Scan for ToolHive in the specified port range."""
logger.info(
"Scanning for ToolHive", host=host, port_range=f"{scan_port_start}-{scan_port_end}"
)

for port in range(scan_port_start, scan_port_end + 1):
if await self._is_toolhive_available_async(host, port):
logger.info("Found ToolHive", host=host, port=port)
return port
task_outcomes = await asyncio.gather(
*[
self._is_toolhive_available(host, port)
for port in range(scan_port_start, scan_port_end + 1)
],
return_exceptions=True,
)
thv_version_port = [
version_port
for version_port in task_outcomes
if not isinstance(version_port, Exception)
]

thv_version_port.sort(key=lambda x: x[0], reverse=True)

if thv_version_port:
return thv_version_port[0][1]

# If no port found, raise an error
raise ConnectionError(
f"ToolHive not found on {host} in port range {scan_port_start}-{scan_port_end}"
)

async def _discover_port_async(self, port: int | None = None) -> None:
"""Async version: Discover the ToolHive port."""
if port is not None:
for attempt in range(3):
if await self._is_toolhive_available_async(self.thv_host, port):
self.thv_port = port
break
logger.warning(
"ToolHive not available at specified host/port, retrying...",
host=self.thv_host,
port=port,
attempt=attempt + 1,
)
await asyncio.sleep(1)

# If port is not found yet, try scanning the port range
if self.thv_port is None:
try:
self.thv_port = await self._scan_for_toolhive_async(
self.thv_host, self.scan_port_start, self.scan_port_end
)
except Exception as e:
logger.error("Error scanning for ToolHive", error=str(e))
raise

self.base_url = f"http://{self.thv_host}:{self.thv_port}"
logger.info("ToolhiveClient port discovered", host=self.thv_host, port=self.thv_port)

async def _rediscover_port(self) -> bool:
"""
Attempt to rediscover the ToolHive port after a connection failure.
Expand All @@ -214,7 +248,7 @@ async def _rediscover_port(self) -> bool:
self.thv_port = None

try:
# Try the initial port first
# Try the initial port first using async version
await self._discover_port_async(self._initial_port)

if self.thv_port and self.thv_port != old_port:
Expand Down Expand Up @@ -327,47 +361,6 @@ async def wrapper(*args: Any, **kwargs: Any) -> T:

return wrapper

def _is_toolhive_available(self, host: str, port: int) -> bool:
"""Check if ToolHive is available at the given host and port."""
try:
response = httpx.get(f"http://{host}:{port}/api/v1beta/version", timeout=1.0)
response.raise_for_status()

# Validate that the response is actually from ToolHive
# by checking for the expected version field
try:
data = response.json()
if not isinstance(data, dict) or "version" not in data:
logger.debug(
"Port responded but not with ToolHive format",
host=host,
port=port,
response=data,
)
return False
return True
except (ValueError, KeyError):
logger.debug("Port responded but could not parse JSON", host=host, port=port)
return False
except (httpx.HTTPError, OSError):
return False

def _scan_for_toolhive(self, host: str, scan_port_start: int, scan_port_end: int) -> int:
"""Scan for ToolHive in the specified port range."""
logger.info(
"Scanning for ToolHive", host=host, port_range=f"{scan_port_start}-{scan_port_end}"
)

for port in range(scan_port_start, scan_port_end + 1):
if self._is_toolhive_available(host, port):
logger.info("Found ToolHive", host=host, port=port)
return port

# If no port found, raise an error
raise ConnectionError(
f"ToolHive not found on {host} in port range {scan_port_start}-{scan_port_end}"
)

async def __aenter__(self) -> Self:
"""Async context manager entry."""
self._client = httpx.AsyncClient(timeout=self.timeout)
Expand Down
12 changes: 8 additions & 4 deletions tests/test_polling_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from unittest.mock import AsyncMock, patch

import pytest
from semver import Version

from mcp_optimizer.db.config import DatabaseConfig
from mcp_optimizer.db.models import McpStatus, RegistryServer, WorkloadServer
from mcp_optimizer.embeddings import EmbeddingManager
from mcp_optimizer.polling_manager import PollingManager
from mcp_optimizer.toolhive.toolhive_client import ToolhiveClient
from mcp_optimizer.toolhive.toolhive_client import ToolhiveClient, ToolhiveScanError


@pytest.fixture
Expand All @@ -29,11 +30,14 @@ def embedding_manager():
def toolhive_client(monkeypatch):
"""Create a mock ToolhiveClient for testing."""

def mock_scan_for_toolhive(self, host, start_port, end_port):
async def mock_scan_for_toolhive(self, host, start_port, end_port):
return 8080 # Force return of 8080 for testing

def mock_is_toolhive_available(self, host, port):
return port == 8080 # Only consider 8080 as available
async def mock_is_toolhive_available(self, host, port):
# Return (Version, port) tuple as per new signature
if port == 8080:
return (Version.parse("1.0.0"), 8080)
raise ToolhiveScanError(f"Port {port} not available")

# Mock the methods before creating the client
monkeypatch.setattr(
Expand Down
Loading