Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/code_reference/mcp.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ The `mcp` module defines configuration and execution classes for tool use via MC

## Configuration Classes

[MCPProvider](#data_designer.config.mcp.MCPProvider) configures remote MCP servers via SSE transport. [LocalStdioMCPProvider](#data_designer.config.mcp.LocalStdioMCPProvider) configures local MCP servers as subprocesses via stdio transport. [ToolConfig](#data_designer.config.mcp.ToolConfig) defines which tools are available for LLM columns and how they are constrained.
[MCPProvider](#data_designer.config.mcp.MCPProvider) configures remote MCP servers via SSE or Streamable HTTP transport. [LocalStdioMCPProvider](#data_designer.config.mcp.LocalStdioMCPProvider) configures local MCP servers as subprocesses via stdio transport. [ToolConfig](#data_designer.config.mcp.ToolConfig) defines which tools are available for LLM columns and how they are constrained.

For user-facing guides, see:

Expand Down
9 changes: 5 additions & 4 deletions docs/concepts/mcp/configure-mcp-cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,15 @@ data-designer config mcp
The wizard first asks you to choose a provider type:

1. **Remote SSE**: Connect to a pre-existing MCP server via HTTP Server-Sent Events
2. **Local stdio subprocess**: Launch an MCP server as a subprocess
2. **Remote Streamable HTTP**: Connect to a pre-existing MCP server via Streamable HTTP
3. **Local stdio subprocess**: Launch an MCP server as a subprocess

### Remote SSE Configuration
### Remote SSE / Streamable HTTP Configuration

When configuring a Remote SSE provider, you'll be prompted for:
When configuring a remote provider (SSE or Streamable HTTP), you'll be prompted for:

- **Name**: Unique identifier (e.g., `"doc-search"`)
- **Endpoint**: SSE endpoint URL (e.g., `"http://localhost:8080/sse"`)
- **Endpoint**: Server endpoint URL (e.g., `"http://localhost:8080/sse"` or `"https://mcp.example.com/mcp"`)
- **API Key**: Optional API key or environment variable name

### Local Stdio Configuration
Expand Down
29 changes: 22 additions & 7 deletions docs/concepts/mcp/mcp-providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,45 @@ An MCP provider defines how Data Designer connects to a tool server. Data Design

| Provider Class | Connection Method | Use Case |
|---------------|-------------------|----------|
| `MCPProvider` | HTTP Server-Sent Events | Connect to a pre-existing MCP server |
| `MCPProvider` | SSE or Streamable HTTP | Connect to a pre-existing MCP server |
| `LocalStdioMCPProvider` | Subprocess via stdin/stdout | Launch an MCP server as a subprocess |

When you create a `ToolConfig`, you reference providers by name, and Data Designer uses those provider settings to communicate with the appropriate MCP servers.

## MCPProvider (Remote SSE)
## MCPProvider (Remote)

Use `MCPProvider` to connect to a pre-existing MCP server via Server-Sent Events:
Use `MCPProvider` to connect to a pre-existing MCP server. Both SSE (Server-Sent Events) and Streamable HTTP transports are supported:

```python
import data_designer.config as dd
from data_designer.interface import DataDesigner

mcp_provider = dd.MCPProvider(
# SSE transport (default)
sse_provider = dd.MCPProvider(
name="remote-mcp",
endpoint="http://localhost:8080/sse",
api_key="MCP_API_KEY", # Environment variable name
)

data_designer = DataDesigner(mcp_providers=[mcp_provider])
# Streamable HTTP transport
http_provider = dd.MCPProvider(
name="remote-tools",
endpoint="https://mcp.example.com/mcp",
api_key="MCP_API_KEY",
provider_type="streamable_http",
)

data_designer = DataDesigner(mcp_providers=[sse_provider, http_provider])
```

### MCPProvider Fields

| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `name` | `str` | Yes | Unique identifier for the provider |
| `endpoint` | `str` | Yes | SSE endpoint URL (e.g., `"http://localhost:8080/sse"`) |
| `endpoint` | `str` | Yes | Endpoint URL for the remote MCP server |
| `api_key` | `str` | No | API key or environment variable name |
| `provider_type` | `str` | No | Always `"sse"` (set automatically) |
| `provider_type` | `str` | No | Transport type: `"sse"` (default) or `"streamable_http"` |

## LocalStdioMCPProvider (Subprocess)

Expand Down Expand Up @@ -103,6 +112,12 @@ providers:
endpoint: http://localhost:8080/sse
api_key: ${MCP_API_KEY}

# Remote Streamable HTTP provider
- name: remote-tools
provider_type: streamable_http
endpoint: https://mcp.example.com/mcp
api_key: ${MCP_API_KEY}

# Local stdio provider
- name: local-tools
provider_type: stdio
Expand Down
21 changes: 16 additions & 5 deletions packages/data-designer-config/src/data_designer/config/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,25 +15,36 @@ class MCPProvider(ConfigBase):
"""Configuration for a remote MCP server connection.

MCPProvider is used to connect to pre-existing MCP servers via SSE (Server-Sent Events)
transport. For local subprocess-based MCP servers, use LocalStdioMCPProvider instead.
or Streamable HTTP transport. For local subprocess-based MCP servers, use
LocalStdioMCPProvider instead.

Attributes:
name (str): Unique name used to reference this MCP provider.
endpoint (str): SSE endpoint URL for connecting to the remote MCP server.
endpoint (str): Endpoint URL for connecting to the remote MCP server.
api_key (str | None): Optional API key for authentication. Defaults to None.
provider_type (Literal["sse"]): Transport type discriminator, always "sse".
provider_type (Literal["sse", "streamable_http"]): Transport type discriminator.
Defaults to ``"sse"``.

Examples:
Remote SSE transport:
Remote SSE transport (default):

>>> MCPProvider(
... name="remote-mcp",
... endpoint="http://localhost:8080/sse",
... api_key="your-api-key",
... )

Remote Streamable HTTP transport:

>>> MCPProvider(
... name="remote-mcp",
... endpoint="https://api.example.com/mcp",
... api_key="your-api-key",
... provider_type="streamable_http",
... )
"""

provider_type: Literal["sse"] = "sse"
provider_type: Literal["sse", "streamable_http"] = "sse"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Do we opt for StrEnum?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Do we opt for StrEnum?

We use a mix of string literals and string enums in this code base. We can definitely come back and change this to a StrEnum later.

name: str
endpoint: str
api_key: str | None = None
Expand Down
21 changes: 21 additions & 0 deletions packages/data-designer-config/tests/config/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,32 @@ def test_mcp_provider_requires_endpoint() -> None:
provider = MCPProvider(name="sse", endpoint="http://localhost:8080")
assert provider.endpoint == "http://localhost:8080"
assert provider.api_key is None
assert provider.provider_type == "sse"

provider_with_key = MCPProvider(name="sse-auth", endpoint="http://localhost:8080", api_key="secret")
assert provider_with_key.api_key == "secret"


def test_mcp_provider_streamable_http() -> None:
provider = MCPProvider(
name="streamable",
endpoint="https://api.example.com/mcp",
provider_type="streamable_http",
)
assert provider.provider_type == "streamable_http"
assert provider.endpoint == "https://api.example.com/mcp"
assert provider.api_key is None

provider_with_key = MCPProvider(
name="streamable-auth",
endpoint="https://api.example.com/mcp",
provider_type="streamable_http",
api_key="secret",
)
assert provider_with_key.api_key == "secret"
assert provider_with_key.provider_type == "streamable_http"


def test_local_stdio_mcp_provider_requires_command() -> None:
with pytest.raises(ValidationError):
LocalStdioMCPProvider(name="missing-command")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client

from data_designer.config.mcp import LocalStdioMCPProvider, MCPProviderT
from data_designer.config.mcp import LocalStdioMCPProvider, MCPProvider, MCPProviderT
from data_designer.engine.mcp.errors import MCPToolError
from data_designer.engine.mcp.registry import MCPToolDefinition, MCPToolResult

Expand Down Expand Up @@ -211,11 +212,15 @@ async def create_session() -> ClientSession:
env=provider.env,
)
ctx = stdio_client(params)
elif isinstance(provider, MCPProvider) and provider.provider_type == "streamable_http":
headers = _build_auth_headers(provider.api_key)
ctx = streamablehttp_client(provider.endpoint, headers=headers)
else:
headers = _build_auth_headers(provider.api_key)
ctx = sse_client(provider.endpoint, headers=headers)

read, write = await ctx.__aenter__()
ctx_result = await ctx.__aenter__()
read, write = ctx_result[0], ctx_result[1]
new_session = ClientSession(read, write)
await new_session.__aenter__()
await new_session.initialize()
Expand Down Expand Up @@ -399,6 +404,11 @@ def list_tools(provider: MCPProviderT, timeout_sec: float | None = None) -> tupl
return _MCP_IO_SERVICE.list_tools(provider, timeout_sec=timeout_sec)


def list_tool_names(provider: MCPProviderT, timeout_sec: float) -> list[str]:
"""Return the names of all tools available on an MCP provider."""
return [t.name for t in _MCP_IO_SERVICE.list_tools(provider, timeout_sec=timeout_sec)]


def call_tools(
calls: list[tuple[MCPProviderT, str, dict[str, Any]]],
*,
Expand Down Expand Up @@ -434,7 +444,7 @@ def get_session_pool_info() -> dict[str, Any]:


def _build_auth_headers(api_key: str | None) -> dict[str, Any] | None:
"""Build authentication headers for SSE client."""
"""Build authentication headers for remote MCP clients."""
if not api_key:
return None
return {"Authorization": f"Bearer {api_key}"}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ def stub_sse_provider() -> MCPProvider:
)


@pytest.fixture
def stub_streamable_http_provider() -> MCPProvider:
"""Create a stub Streamable HTTP MCP provider for testing."""
return MCPProvider(
name="test-streamable-http",
endpoint="https://api.example.com/mcp",
api_key="test-key",
provider_type="streamable_http",
)


# =============================================================================
# Tool config fixtures
# =============================================================================
Expand Down
53 changes: 53 additions & 0 deletions packages/data-designer-engine/tests/engine/mcp/test_mcp_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,59 @@ def mock_stdio_client(params: Any) -> MockContextManager:
mcp_io.clear_session_pool()


# =============================================================================
# Streamable HTTP session creation tests
# =============================================================================


@pytest.mark.asyncio
async def test_get_or_create_session_for_streamable_http_provider(
monkeypatch: pytest.MonkeyPatch, stub_streamable_http_provider: MCPProvider
) -> None:
"""Test that _get_or_create_session uses streamablehttp_client for streamable_http providers."""
mcp_io.clear_session_pool()

class MockContextManager:
async def __aenter__(self) -> tuple[Any, Any, Any]:
return ("mock_read", "mock_write", "mock_get_session_id")

async def __aexit__(self, *args: Any) -> None:
pass

class MockSession:
async def __aenter__(self) -> "MockSession":
return self

async def __aexit__(self, *args: Any) -> None:
pass

async def initialize(self) -> None:
pass

streamable_http_client_called = False
received_endpoint: str | None = None
received_headers: dict[str, Any] | None = None

def mock_streamablehttp_client(endpoint: str, headers: dict[str, Any] | None = None) -> MockContextManager:
nonlocal streamable_http_client_called, received_endpoint, received_headers
streamable_http_client_called = True
received_endpoint = endpoint
received_headers = headers
return MockContextManager()

monkeypatch.setattr(mcp_io, "streamablehttp_client", mock_streamablehttp_client)
monkeypatch.setattr(mcp_io, "ClientSession", lambda r, w: MockSession())

session = await mcp_io._MCP_IO_SERVICE._get_or_create_session(stub_streamable_http_provider)

assert streamable_http_client_called
assert received_endpoint == stub_streamable_http_provider.endpoint
assert received_headers == {"Authorization": "Bearer test-key"}
assert session is not None

mcp_io.clear_session_pool()


# =============================================================================
# Session cleanup exception handling tests
# =============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ def _select_provider(self, providers: list[MCPProviderT], prompt: str, default:
options = {}
for p in providers:
if isinstance(p, MCPProvider):
options[p.name] = f"{p.name} (SSE: {p.endpoint})"
transport_label = "Streamable HTTP" if p.provider_type == "streamable_http" else "SSE"
options[p.name] = f"{p.name} ({transport_label}: {p.endpoint})"
elif isinstance(p, LocalStdioMCPProvider):
options[p.name] = f"{p.name} (stdio: {p.command})"
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from __future__ import annotations

from typing import Any
from typing import Any, Literal

from data_designer.cli.forms.field import TextField
from data_designer.cli.forms.form import Form
Expand Down Expand Up @@ -42,8 +42,8 @@ def run(self, initial_data: dict[str, Any] | None = None) -> MCPProviderT | None
return None

# Run appropriate form based on provider type
if provider_type == "sse":
result = self._run_sse_form(initial_data)
if provider_type in ("sse", "streamable_http"):
result = self._run_remote_form(provider_type, initial_data)
else: # stdio
result = self._run_stdio_form(initial_data)

Expand All @@ -59,6 +59,7 @@ def _select_provider_type(self) -> str | None:
"""Prompt user to select provider type."""
options = {
"sse": "Remote SSE server (connect to existing server)",
"streamable_http": "Remote Streamable HTTP server (connect to existing server)",
"stdio": "Local stdio subprocess (launch server as subprocess)",
}

Expand All @@ -70,8 +71,11 @@ def _select_provider_type(self) -> str | None:
allow_back=True,
)

def _run_sse_form(self, initial_data: dict[str, Any] | None = None) -> MCPProvider | None:
"""Run form for remote SSE provider."""
def _run_remote_form(
self, provider_type: Literal["sse", "streamable_http"], initial_data: dict[str, Any] | None = None
) -> MCPProvider | None:
"""Run form for a remote MCP provider (SSE or Streamable HTTP)."""
transport_label = "SSE" if provider_type == "sse" else "Streamable HTTP"
fields = [
TextField(
"name",
Expand All @@ -82,7 +86,7 @@ def _run_sse_form(self, initial_data: dict[str, Any] | None = None) -> MCPProvid
),
TextField(
"endpoint",
"SSE endpoint URL",
f"{transport_label} endpoint URL",
default=initial_data.get("endpoint") if initial_data else None,
required=True,
validator=self._validate_endpoint,
Expand All @@ -95,7 +99,7 @@ def _run_sse_form(self, initial_data: dict[str, Any] | None = None) -> MCPProvid
),
]

form = Form("Remote SSE Provider", fields)
form = Form(f"Remote {transport_label} Provider", fields)
if initial_data:
form.set_values(initial_data)

Expand All @@ -108,6 +112,7 @@ def _run_sse_form(self, initial_data: dict[str, Any] | None = None) -> MCPProvid
name=result["name"],
endpoint=result["endpoint"],
api_key=result.get("api_key") or None,
provider_type=provider_type,
)
except Exception as e:
print_error(f"Configuration error: {e}")
Expand Down
Loading