diff --git a/docs/code_reference/mcp.md b/docs/code_reference/mcp.md index b27bbb60f..cbabce846 100644 --- a/docs/code_reference/mcp.md +++ b/docs/code_reference/mcp.md @@ -4,7 +4,7 @@ The `mcp` module defines configuration and execution classes for tool use via MC ## Configuration Classes -[MCPProvider](#data_designer.config.mcp.MCPProvider) configures remote MCP servers via SSE transport. [LocalStdioMCPProvider](#data_designer.config.mcp.LocalStdioMCPProvider) configures local MCP servers as subprocesses via stdio transport. [ToolConfig](#data_designer.config.mcp.ToolConfig) defines which tools are available for LLM columns and how they are constrained. +[MCPProvider](#data_designer.config.mcp.MCPProvider) configures remote MCP servers via SSE or Streamable HTTP transport. [LocalStdioMCPProvider](#data_designer.config.mcp.LocalStdioMCPProvider) configures local MCP servers as subprocesses via stdio transport. [ToolConfig](#data_designer.config.mcp.ToolConfig) defines which tools are available for LLM columns and how they are constrained. For user-facing guides, see: diff --git a/docs/concepts/mcp/configure-mcp-cli.md b/docs/concepts/mcp/configure-mcp-cli.md index ba1757f60..72b357e99 100644 --- a/docs/concepts/mcp/configure-mcp-cli.md +++ b/docs/concepts/mcp/configure-mcp-cli.md @@ -49,14 +49,15 @@ data-designer config mcp The wizard first asks you to choose a provider type: 1. **Remote SSE**: Connect to a pre-existing MCP server via HTTP Server-Sent Events -2. **Local stdio subprocess**: Launch an MCP server as a subprocess +2. **Remote Streamable HTTP**: Connect to a pre-existing MCP server via Streamable HTTP +3. **Local stdio subprocess**: Launch an MCP server as a subprocess -### Remote SSE Configuration +### Remote SSE / Streamable HTTP Configuration -When configuring a Remote SSE provider, you'll be prompted for: +When configuring a remote provider (SSE or Streamable HTTP), you'll be prompted for: - **Name**: Unique identifier (e.g., `"doc-search"`) -- **Endpoint**: SSE endpoint URL (e.g., `"http://localhost:8080/sse"`) +- **Endpoint**: Server endpoint URL (e.g., `"http://localhost:8080/sse"` or `"https://mcp.example.com/mcp"`) - **API Key**: Optional API key or environment variable name ### Local Stdio Configuration diff --git a/docs/concepts/mcp/mcp-providers.md b/docs/concepts/mcp/mcp-providers.md index 3e6247edb..0c03ded91 100644 --- a/docs/concepts/mcp/mcp-providers.md +++ b/docs/concepts/mcp/mcp-providers.md @@ -8,26 +8,35 @@ An MCP provider defines how Data Designer connects to a tool server. Data Design | Provider Class | Connection Method | Use Case | |---------------|-------------------|----------| -| `MCPProvider` | HTTP Server-Sent Events | Connect to a pre-existing MCP server | +| `MCPProvider` | SSE or Streamable HTTP | Connect to a pre-existing MCP server | | `LocalStdioMCPProvider` | Subprocess via stdin/stdout | Launch an MCP server as a subprocess | When you create a `ToolConfig`, you reference providers by name, and Data Designer uses those provider settings to communicate with the appropriate MCP servers. -## MCPProvider (Remote SSE) +## MCPProvider (Remote) -Use `MCPProvider` to connect to a pre-existing MCP server via Server-Sent Events: +Use `MCPProvider` to connect to a pre-existing MCP server. Both SSE (Server-Sent Events) and Streamable HTTP transports are supported: ```python import data_designer.config as dd from data_designer.interface import DataDesigner -mcp_provider = dd.MCPProvider( +# SSE transport (default) +sse_provider = dd.MCPProvider( name="remote-mcp", endpoint="http://localhost:8080/sse", api_key="MCP_API_KEY", # Environment variable name ) -data_designer = DataDesigner(mcp_providers=[mcp_provider]) +# Streamable HTTP transport +http_provider = dd.MCPProvider( + name="remote-tools", + endpoint="https://mcp.example.com/mcp", + api_key="MCP_API_KEY", + provider_type="streamable_http", +) + +data_designer = DataDesigner(mcp_providers=[sse_provider, http_provider]) ``` ### MCPProvider Fields @@ -35,9 +44,9 @@ data_designer = DataDesigner(mcp_providers=[mcp_provider]) | Field | Type | Required | Description | |-------|------|----------|-------------| | `name` | `str` | Yes | Unique identifier for the provider | -| `endpoint` | `str` | Yes | SSE endpoint URL (e.g., `"http://localhost:8080/sse"`) | +| `endpoint` | `str` | Yes | Endpoint URL for the remote MCP server | | `api_key` | `str` | No | API key or environment variable name | -| `provider_type` | `str` | No | Always `"sse"` (set automatically) | +| `provider_type` | `str` | No | Transport type: `"sse"` (default) or `"streamable_http"` | ## LocalStdioMCPProvider (Subprocess) @@ -103,6 +112,12 @@ providers: endpoint: http://localhost:8080/sse api_key: ${MCP_API_KEY} + # Remote Streamable HTTP provider + - name: remote-tools + provider_type: streamable_http + endpoint: https://mcp.example.com/mcp + api_key: ${MCP_API_KEY} + # Local stdio provider - name: local-tools provider_type: stdio diff --git a/packages/data-designer-config/src/data_designer/config/mcp.py b/packages/data-designer-config/src/data_designer/config/mcp.py index fe870fa86..f98f75929 100644 --- a/packages/data-designer-config/src/data_designer/config/mcp.py +++ b/packages/data-designer-config/src/data_designer/config/mcp.py @@ -15,25 +15,36 @@ class MCPProvider(ConfigBase): """Configuration for a remote MCP server connection. MCPProvider is used to connect to pre-existing MCP servers via SSE (Server-Sent Events) - transport. For local subprocess-based MCP servers, use LocalStdioMCPProvider instead. + or Streamable HTTP transport. For local subprocess-based MCP servers, use + LocalStdioMCPProvider instead. Attributes: name (str): Unique name used to reference this MCP provider. - endpoint (str): SSE endpoint URL for connecting to the remote MCP server. + endpoint (str): Endpoint URL for connecting to the remote MCP server. api_key (str | None): Optional API key for authentication. Defaults to None. - provider_type (Literal["sse"]): Transport type discriminator, always "sse". + provider_type (Literal["sse", "streamable_http"]): Transport type discriminator. + Defaults to ``"sse"``. Examples: - Remote SSE transport: + Remote SSE transport (default): >>> MCPProvider( ... name="remote-mcp", ... endpoint="http://localhost:8080/sse", ... api_key="your-api-key", ... ) + + Remote Streamable HTTP transport: + + >>> MCPProvider( + ... name="remote-mcp", + ... endpoint="https://api.example.com/mcp", + ... api_key="your-api-key", + ... provider_type="streamable_http", + ... ) """ - provider_type: Literal["sse"] = "sse" + provider_type: Literal["sse", "streamable_http"] = "sse" name: str endpoint: str api_key: str | None = None diff --git a/packages/data-designer-config/tests/config/test_mcp.py b/packages/data-designer-config/tests/config/test_mcp.py index 7aefd2d6b..55b386709 100644 --- a/packages/data-designer-config/tests/config/test_mcp.py +++ b/packages/data-designer-config/tests/config/test_mcp.py @@ -14,11 +14,32 @@ def test_mcp_provider_requires_endpoint() -> None: provider = MCPProvider(name="sse", endpoint="http://localhost:8080") assert provider.endpoint == "http://localhost:8080" assert provider.api_key is None + assert provider.provider_type == "sse" provider_with_key = MCPProvider(name="sse-auth", endpoint="http://localhost:8080", api_key="secret") assert provider_with_key.api_key == "secret" +def test_mcp_provider_streamable_http() -> None: + provider = MCPProvider( + name="streamable", + endpoint="https://api.example.com/mcp", + provider_type="streamable_http", + ) + assert provider.provider_type == "streamable_http" + assert provider.endpoint == "https://api.example.com/mcp" + assert provider.api_key is None + + provider_with_key = MCPProvider( + name="streamable-auth", + endpoint="https://api.example.com/mcp", + provider_type="streamable_http", + api_key="secret", + ) + assert provider_with_key.api_key == "secret" + assert provider_with_key.provider_type == "streamable_http" + + def test_local_stdio_mcp_provider_requires_command() -> None: with pytest.raises(ValidationError): LocalStdioMCPProvider(name="missing-command") diff --git a/packages/data-designer-engine/src/data_designer/engine/mcp/io.py b/packages/data-designer-engine/src/data_designer/engine/mcp/io.py index 5ef4702ce..60a46b257 100644 --- a/packages/data-designer-engine/src/data_designer/engine/mcp/io.py +++ b/packages/data-designer-engine/src/data_designer/engine/mcp/io.py @@ -39,8 +39,9 @@ from mcp import ClientSession, StdioServerParameters from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client +from mcp.client.streamable_http import streamablehttp_client -from data_designer.config.mcp import LocalStdioMCPProvider, MCPProviderT +from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, MCPProviderT from data_designer.engine.mcp.errors import MCPToolError from data_designer.engine.mcp.registry import MCPToolDefinition, MCPToolResult @@ -211,11 +212,15 @@ async def create_session() -> ClientSession: env=provider.env, ) ctx = stdio_client(params) + elif isinstance(provider, MCPProvider) and provider.provider_type == "streamable_http": + headers = _build_auth_headers(provider.api_key) + ctx = streamablehttp_client(provider.endpoint, headers=headers) else: headers = _build_auth_headers(provider.api_key) ctx = sse_client(provider.endpoint, headers=headers) - read, write = await ctx.__aenter__() + ctx_result = await ctx.__aenter__() + read, write = ctx_result[0], ctx_result[1] new_session = ClientSession(read, write) await new_session.__aenter__() await new_session.initialize() @@ -399,6 +404,11 @@ def list_tools(provider: MCPProviderT, timeout_sec: float | None = None) -> tupl return _MCP_IO_SERVICE.list_tools(provider, timeout_sec=timeout_sec) +def list_tool_names(provider: MCPProviderT, timeout_sec: float) -> list[str]: + """Return the names of all tools available on an MCP provider.""" + return [t.name for t in _MCP_IO_SERVICE.list_tools(provider, timeout_sec=timeout_sec)] + + def call_tools( calls: list[tuple[MCPProviderT, str, dict[str, Any]]], *, @@ -434,7 +444,7 @@ def get_session_pool_info() -> dict[str, Any]: def _build_auth_headers(api_key: str | None) -> dict[str, Any] | None: - """Build authentication headers for SSE client.""" + """Build authentication headers for remote MCP clients.""" if not api_key: return None return {"Authorization": f"Bearer {api_key}"} diff --git a/packages/data-designer-engine/src/data_designer/engine/testing/fixtures.py b/packages/data-designer-engine/src/data_designer/engine/testing/fixtures.py index 7e694f575..af7d3ebfc 100644 --- a/packages/data-designer-engine/src/data_designer/engine/testing/fixtures.py +++ b/packages/data-designer-engine/src/data_designer/engine/testing/fixtures.py @@ -81,6 +81,17 @@ def stub_sse_provider() -> MCPProvider: ) +@pytest.fixture +def stub_streamable_http_provider() -> MCPProvider: + """Create a stub Streamable HTTP MCP provider for testing.""" + return MCPProvider( + name="test-streamable-http", + endpoint="https://api.example.com/mcp", + api_key="test-key", + provider_type="streamable_http", + ) + + # ============================================================================= # Tool config fixtures # ============================================================================= diff --git a/packages/data-designer-engine/tests/engine/mcp/test_mcp_io.py b/packages/data-designer-engine/tests/engine/mcp/test_mcp_io.py index 411dbd35b..21a357090 100644 --- a/packages/data-designer-engine/tests/engine/mcp/test_mcp_io.py +++ b/packages/data-designer-engine/tests/engine/mcp/test_mcp_io.py @@ -693,6 +693,59 @@ def mock_stdio_client(params: Any) -> MockContextManager: mcp_io.clear_session_pool() +# ============================================================================= +# Streamable HTTP session creation tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_get_or_create_session_for_streamable_http_provider( + monkeypatch: pytest.MonkeyPatch, stub_streamable_http_provider: MCPProvider +) -> None: + """Test that _get_or_create_session uses streamablehttp_client for streamable_http providers.""" + mcp_io.clear_session_pool() + + class MockContextManager: + async def __aenter__(self) -> tuple[Any, Any, Any]: + return ("mock_read", "mock_write", "mock_get_session_id") + + async def __aexit__(self, *args: Any) -> None: + pass + + class MockSession: + async def __aenter__(self) -> "MockSession": + return self + + async def __aexit__(self, *args: Any) -> None: + pass + + async def initialize(self) -> None: + pass + + streamable_http_client_called = False + received_endpoint: str | None = None + received_headers: dict[str, Any] | None = None + + def mock_streamablehttp_client(endpoint: str, headers: dict[str, Any] | None = None) -> MockContextManager: + nonlocal streamable_http_client_called, received_endpoint, received_headers + streamable_http_client_called = True + received_endpoint = endpoint + received_headers = headers + return MockContextManager() + + monkeypatch.setattr(mcp_io, "streamablehttp_client", mock_streamablehttp_client) + monkeypatch.setattr(mcp_io, "ClientSession", lambda r, w: MockSession()) + + session = await mcp_io._MCP_IO_SERVICE._get_or_create_session(stub_streamable_http_provider) + + assert streamable_http_client_called + assert received_endpoint == stub_streamable_http_provider.endpoint + assert received_headers == {"Authorization": "Bearer test-key"} + assert session is not None + + mcp_io.clear_session_pool() + + # ============================================================================= # Session cleanup exception handling tests # ============================================================================= diff --git a/packages/data-designer/src/data_designer/cli/controllers/mcp_provider_controller.py b/packages/data-designer/src/data_designer/cli/controllers/mcp_provider_controller.py index 0f16fb9cb..3d95b5630 100644 --- a/packages/data-designer/src/data_designer/cli/controllers/mcp_provider_controller.py +++ b/packages/data-designer/src/data_designer/cli/controllers/mcp_provider_controller.py @@ -218,7 +218,8 @@ def _select_provider(self, providers: list[MCPProviderT], prompt: str, default: options = {} for p in providers: if isinstance(p, MCPProvider): - options[p.name] = f"{p.name} (SSE: {p.endpoint})" + transport_label = "Streamable HTTP" if p.provider_type == "streamable_http" else "SSE" + options[p.name] = f"{p.name} ({transport_label}: {p.endpoint})" elif isinstance(p, LocalStdioMCPProvider): options[p.name] = f"{p.name} (stdio: {p.command})" else: diff --git a/packages/data-designer/src/data_designer/cli/forms/mcp_provider_builder.py b/packages/data-designer/src/data_designer/cli/forms/mcp_provider_builder.py index 27ad95d21..040f95914 100644 --- a/packages/data-designer/src/data_designer/cli/forms/mcp_provider_builder.py +++ b/packages/data-designer/src/data_designer/cli/forms/mcp_provider_builder.py @@ -3,7 +3,7 @@ from __future__ import annotations -from typing import Any +from typing import Any, Literal from data_designer.cli.forms.field import TextField from data_designer.cli.forms.form import Form @@ -42,8 +42,8 @@ def run(self, initial_data: dict[str, Any] | None = None) -> MCPProviderT | None return None # Run appropriate form based on provider type - if provider_type == "sse": - result = self._run_sse_form(initial_data) + if provider_type in ("sse", "streamable_http"): + result = self._run_remote_form(provider_type, initial_data) else: # stdio result = self._run_stdio_form(initial_data) @@ -59,6 +59,7 @@ def _select_provider_type(self) -> str | None: """Prompt user to select provider type.""" options = { "sse": "Remote SSE server (connect to existing server)", + "streamable_http": "Remote Streamable HTTP server (connect to existing server)", "stdio": "Local stdio subprocess (launch server as subprocess)", } @@ -70,8 +71,11 @@ def _select_provider_type(self) -> str | None: allow_back=True, ) - def _run_sse_form(self, initial_data: dict[str, Any] | None = None) -> MCPProvider | None: - """Run form for remote SSE provider.""" + def _run_remote_form( + self, provider_type: Literal["sse", "streamable_http"], initial_data: dict[str, Any] | None = None + ) -> MCPProvider | None: + """Run form for a remote MCP provider (SSE or Streamable HTTP).""" + transport_label = "SSE" if provider_type == "sse" else "Streamable HTTP" fields = [ TextField( "name", @@ -82,7 +86,7 @@ def _run_sse_form(self, initial_data: dict[str, Any] | None = None) -> MCPProvid ), TextField( "endpoint", - "SSE endpoint URL", + f"{transport_label} endpoint URL", default=initial_data.get("endpoint") if initial_data else None, required=True, validator=self._validate_endpoint, @@ -95,7 +99,7 @@ def _run_sse_form(self, initial_data: dict[str, Any] | None = None) -> MCPProvid ), ] - form = Form("Remote SSE Provider", fields) + form = Form(f"Remote {transport_label} Provider", fields) if initial_data: form.set_values(initial_data) @@ -108,6 +112,7 @@ def _run_sse_form(self, initial_data: dict[str, Any] | None = None) -> MCPProvid name=result["name"], endpoint=result["endpoint"], api_key=result.get("api_key") or None, + provider_type=provider_type, ) except Exception as e: print_error(f"Configuration error: {e}") diff --git a/packages/data-designer/src/data_designer/interface/data_designer.py b/packages/data-designer/src/data_designer/interface/data_designer.py index 62970e3f8..a2b7f74d3 100644 --- a/packages/data-designer/src/data_designer/interface/data_designer.py +++ b/packages/data-designer/src/data_designer/interface/data_designer.py @@ -36,6 +36,7 @@ from data_designer.engine.analysis.dataset_profiler import DataDesignerDatasetProfiler, DatasetProfilerConfig from data_designer.engine.compiler import compile_data_designer_config from data_designer.engine.dataset_builders.column_wise_builder import ColumnWiseDatasetBuilder +from data_designer.engine.mcp.io import list_tool_names from data_designer.engine.model_provider import resolve_model_provider_registry from data_designer.engine.resources.managed_storage import init_managed_blob_storage from data_designer.engine.resources.resource_provider import ResourceProvider, create_resource_provider @@ -149,6 +150,25 @@ def info(self) -> InterfaceInfo: """ return self._get_interface_info(self._model_providers) + def list_mcp_tool_names(self, mcp_provider_name: str, *, timeout_sec: float = 10.0) -> list[str]: + """Connect to a configured MCP provider and return the names of its available tools. + + Args: + mcp_provider_name: The ``name`` field of an MCP provider passed to the constructor. + timeout_sec: Timeout in seconds for the MCP handshake. Defaults to 10. + + Returns: + A list of tool name strings exposed by the MCP server. + + Raises: + ValueError: If no provider with the given name was configured. + """ + for provider in self._mcp_providers: + if provider.name == mcp_provider_name: + return list_tool_names(provider, timeout_sec=timeout_sec) + configured = [p.name for p in self._mcp_providers] + raise ValueError(f"No MCP provider named {mcp_provider_name!r}. Configured providers: {configured}") + def create( self, config_builder: DataDesignerConfigBuilder, diff --git a/packages/data-designer/tests/cli/forms/test_mcp_provider_builder.py b/packages/data-designer/tests/cli/forms/test_mcp_provider_builder.py index 3e5ec8f62..182e6ddde 100644 --- a/packages/data-designer/tests/cli/forms/test_mcp_provider_builder.py +++ b/packages/data-designer/tests/cli/forms/test_mcp_provider_builder.py @@ -179,12 +179,12 @@ def test_select_provider_type_returns_stdio(mock_select: MagicMock) -> None: # ============================================================================= -# SSE form tests +# Remote form tests (SSE and Streamable HTTP) # ============================================================================= -def test_run_sse_form_creates_mcp_provider() -> None: - """Test _run_sse_form creates valid MCPProvider.""" +def test_run_remote_form_creates_sse_provider() -> None: + """Test _run_remote_form creates valid MCPProvider with SSE transport.""" builder = MCPProviderFormBuilder() form_result = { "name": "my-server", @@ -196,29 +196,52 @@ def test_run_sse_form_creates_mcp_provider() -> None: mock_form.prompt_all.return_value = form_result with patch("data_designer.cli.forms.mcp_provider_builder.Form", return_value=mock_form): - result = builder._run_sse_form() + result = builder._run_remote_form("sse") assert isinstance(result, MCPProvider) assert result.name == "my-server" assert result.endpoint == "http://localhost:8080/sse" assert result.api_key == "secret-key" + assert result.provider_type == "sse" -def test_run_sse_form_returns_none_on_cancel() -> None: - """Test _run_sse_form returns None when user cancels.""" +def test_run_remote_form_creates_streamable_http_provider() -> None: + """Test _run_remote_form creates valid MCPProvider with Streamable HTTP transport.""" + builder = MCPProviderFormBuilder() + form_result = { + "name": "my-server", + "endpoint": "https://api.example.com/mcp", + "api_key": "secret-key", + } + + mock_form = MagicMock() + mock_form.prompt_all.return_value = form_result + + with patch("data_designer.cli.forms.mcp_provider_builder.Form", return_value=mock_form): + result = builder._run_remote_form("streamable_http") + + assert isinstance(result, MCPProvider) + assert result.name == "my-server" + assert result.endpoint == "https://api.example.com/mcp" + assert result.api_key == "secret-key" + assert result.provider_type == "streamable_http" + + +def test_run_remote_form_returns_none_on_cancel() -> None: + """Test _run_remote_form returns None when user cancels.""" builder = MCPProviderFormBuilder() mock_form = MagicMock() mock_form.prompt_all.return_value = None with patch("data_designer.cli.forms.mcp_provider_builder.Form", return_value=mock_form): - result = builder._run_sse_form() + result = builder._run_remote_form("sse") assert result is None -def test_run_sse_form_handles_optional_api_key() -> None: - """Test _run_sse_form handles missing/empty api_key.""" +def test_run_remote_form_handles_optional_api_key() -> None: + """Test _run_remote_form handles missing/empty api_key.""" builder = MCPProviderFormBuilder() form_result = { "name": "my-server", @@ -230,14 +253,14 @@ def test_run_sse_form_handles_optional_api_key() -> None: mock_form.prompt_all.return_value = form_result with patch("data_designer.cli.forms.mcp_provider_builder.Form", return_value=mock_form): - result = builder._run_sse_form() + result = builder._run_remote_form("sse") assert isinstance(result, MCPProvider) assert result.api_key is None -def test_run_sse_form_uses_initial_data() -> None: - """Test _run_sse_form populates form with initial data.""" +def test_run_remote_form_uses_initial_data() -> None: + """Test _run_remote_form populates form with initial data.""" builder = MCPProviderFormBuilder() initial_data = { "name": "existing-server", @@ -249,14 +272,14 @@ def test_run_sse_form_uses_initial_data() -> None: mock_form.prompt_all.return_value = initial_data with patch("data_designer.cli.forms.mcp_provider_builder.Form", return_value=mock_form): - builder._run_sse_form(initial_data) + builder._run_remote_form("sse", initial_data) mock_form.set_values.assert_called_once_with(initial_data) @patch("data_designer.cli.forms.mcp_provider_builder.print_error") -def test_run_sse_form_handles_exception(mock_print_error: MagicMock) -> None: - """Test _run_sse_form handles validation exceptions gracefully.""" +def test_run_remote_form_handles_exception(mock_print_error: MagicMock) -> None: + """Test _run_remote_form handles validation exceptions gracefully.""" builder = MCPProviderFormBuilder() form_result = { "name": "", # Invalid: empty name will cause exception @@ -274,7 +297,7 @@ def test_run_sse_form_handles_exception(mock_print_error: MagicMock) -> None: side_effect=Exception("Validation error"), ), ): - result = builder._run_sse_form() + result = builder._run_remote_form("sse") assert result is None mock_print_error.assert_called()