From 7826ade12b50ce377d0856a3e3448ea36269a4c7 Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:25:11 +0000 Subject: [PATCH 1/4] test: convert test_integration.py to in-memory transport (fix flaky) (#2277) --- src/mcp/server/session.py | 2 +- tests/server/mcpserver/test_integration.py | 744 +++++++-------------- 2 files changed, 231 insertions(+), 515 deletions(-) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 759d2131a..ce467e6c9 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -474,7 +474,7 @@ async def send_progress_notification( related_request_id, ) - async def send_resource_list_changed(self) -> None: # pragma: no cover + async def send_resource_list_changed(self) -> None: """Send a resource list changed notification.""" await self.send_notification(types.ResourceListChangedNotification()) diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index c4ea2dad6..90e333b77 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -1,7 +1,7 @@ """Integration tests for MCPServer server functionality. These tests validate the proper functioning of MCPServer features using focused, -single-feature servers across different transports (SSE and StreamableHTTP). +single-feature example servers over an in-memory transport. """ # TODO(Marcelo): The `examples` package is not being imported as package. We need to solve this. # pyright: reportUnknownMemberType=false @@ -10,12 +10,8 @@ # pyright: reportUnknownArgumentType=false import json -import multiprocessing -import socket -from collections.abc import Generator import pytest -import uvicorn from inline_snapshot import snapshot from examples.snippets.servers import ( @@ -30,9 +26,8 @@ structured_output, tool_progress, ) +from mcp.client import Client from mcp.client.session import ClientSession -from mcp.client.sse import sse_client -from mcp.client.streamable_http import streamable_http_client from mcp.shared._context import RequestContext from mcp.shared.session import RequestResponder from mcp.types import ( @@ -42,7 +37,6 @@ ElicitRequestParams, ElicitResult, GetPromptResult, - InitializeResult, LoggingMessageNotification, LoggingMessageNotificationParams, NotificationParams, @@ -58,7 +52,8 @@ TextResourceContents, ToolListChangedNotification, ) -from tests.test_helpers import wait_for_server + +pytestmark = pytest.mark.anyio class NotificationCollector: @@ -85,105 +80,6 @@ async def handle_generic_notification( self.tool_notifications.append(message.params) -# Common fixtures -@pytest.fixture -def server_port() -> int: - """Get a free port for testing.""" - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - """Get the server URL for testing.""" - return f"http://127.0.0.1:{server_port}" - - -def run_server_with_transport(module_name: str, port: int, transport: str) -> None: # pragma: no cover - """Run server with specified transport.""" - # Get the MCP instance based on module name - if module_name == "basic_tool": - mcp = basic_tool.mcp - elif module_name == "basic_resource": - mcp = basic_resource.mcp - elif module_name == "basic_prompt": - mcp = basic_prompt.mcp - elif module_name == "tool_progress": - mcp = tool_progress.mcp - elif module_name == "sampling": - mcp = sampling.mcp - elif module_name == "elicitation": - mcp = elicitation.mcp - elif module_name == "completion": - mcp = completion.mcp - elif module_name == "notifications": - mcp = notifications.mcp - elif module_name == "mcpserver_quickstart": - mcp = mcpserver_quickstart.mcp - elif module_name == "structured_output": - mcp = structured_output.mcp - else: - raise ImportError(f"Unknown module: {module_name}") - - # Create app based on transport type - if transport == "sse": - app = mcp.sse_app() - elif transport == "streamable-http": - app = mcp.streamable_http_app() - else: - raise ValueError(f"Invalid transport for test server: {transport}") - - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=port, log_level="error")) - print(f"Starting {transport} server on port {port}") - server.run() - - -@pytest.fixture -def server_transport(request: pytest.FixtureRequest, server_port: int) -> Generator[str, None, None]: - """Start server in a separate process with specified MCP instance and transport. - - Args: - request: pytest request with param tuple of (module_name, transport) - server_port: Port to run the server on - - Yields: - str: The transport type ('sse' or 'streamable_http') - """ - module_name, transport = request.param - - proc = multiprocessing.Process( - target=run_server_with_transport, - args=(module_name, server_port, transport), - daemon=True, - ) - proc.start() - - # Wait for server to be ready - wait_for_server(server_port) - - yield transport - - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("Server process failed to terminate") - - -# Helper function to create client based on transport -def create_client_for_transport(transport: str, server_url: str): - """Create the appropriate client context manager based on transport type.""" - if transport == "sse": - endpoint = f"{server_url}/sse" - return sse_client(endpoint) - elif transport == "streamable-http": - endpoint = f"{server_url}/mcp" - return streamable_http_client(endpoint) - else: # pragma: no cover - raise ValueError(f"Invalid transport: {transport}") - - -# Callback functions for testing async def sampling_callback( context: RequestContext[ClientSession], params: CreateMessageRequestParams ) -> CreateMessageResult: @@ -210,147 +106,85 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E return ElicitResult(action="decline") -# Test basic tools -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("basic_tool", "sse"), - ("basic_tool", "streamable-http"), - ], - indirect=True, -) -async def test_basic_tools(server_transport: str, server_url: str) -> None: +async def test_basic_tools() -> None: """Test basic tool functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Tool Example" - assert result.capabilities.tools is not None - - # Test sum tool - tool_result = await session.call_tool("sum", {"a": 5, "b": 3}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "8" - - # Test weather tool - weather_result = await session.call_tool("get_weather", {"city": "London"}) - assert len(weather_result.content) == 1 - assert isinstance(weather_result.content[0], TextContent) - assert "Weather in London: 22degreesC" in weather_result.content[0].text - - -# Test resources -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("basic_resource", "sse"), - ("basic_resource", "streamable-http"), - ], - indirect=True, -) -async def test_basic_resources(server_transport: str, server_url: str) -> None: + async with Client(basic_tool.mcp) as client: + assert client.server_capabilities is not None + assert client.server_capabilities.tools is not None + + # Test sum tool + tool_result = await client.call_tool("sum", {"a": 5, "b": 3}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "8" + + # Test weather tool + weather_result = await client.call_tool("get_weather", {"city": "London"}) + assert len(weather_result.content) == 1 + assert isinstance(weather_result.content[0], TextContent) + assert "Weather in London: 22degreesC" in weather_result.content[0].text + + +async def test_basic_resources() -> None: """Test basic resource functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Resource Example" - assert result.capabilities.resources is not None - - # Test document resource - doc_content = await session.read_resource("file://documents/readme") - assert isinstance(doc_content, ReadResourceResult) - assert len(doc_content.contents) == 1 - assert isinstance(doc_content.contents[0], TextResourceContents) - assert "Content of readme" in doc_content.contents[0].text - - # Test settings resource - settings_content = await session.read_resource("config://settings") - assert isinstance(settings_content, ReadResourceResult) - assert len(settings_content.contents) == 1 - assert isinstance(settings_content.contents[0], TextResourceContents) - settings_json = json.loads(settings_content.contents[0].text) - assert settings_json["theme"] == "dark" - assert settings_json["language"] == "en" - - -# Test prompts -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("basic_prompt", "sse"), - ("basic_prompt", "streamable-http"), - ], - indirect=True, -) -async def test_basic_prompts(server_transport: str, server_url: str) -> None: + async with Client(basic_resource.mcp) as client: + assert client.server_capabilities is not None + assert client.server_capabilities.resources is not None + + # Test document resource + doc_content = await client.read_resource("file://documents/readme") + assert isinstance(doc_content, ReadResourceResult) + assert len(doc_content.contents) == 1 + assert isinstance(doc_content.contents[0], TextResourceContents) + assert "Content of readme" in doc_content.contents[0].text + + # Test settings resource + settings_content = await client.read_resource("config://settings") + assert isinstance(settings_content, ReadResourceResult) + assert len(settings_content.contents) == 1 + assert isinstance(settings_content.contents[0], TextResourceContents) + settings_json = json.loads(settings_content.contents[0].text) + assert settings_json["theme"] == "dark" + assert settings_json["language"] == "en" + + +async def test_basic_prompts() -> None: """Test basic prompt functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Prompt Example" - assert result.capabilities.prompts is not None - - # Test review_code prompt - prompts = await session.list_prompts() - review_prompt = next((p for p in prompts.prompts if p.name == "review_code"), None) - assert review_prompt is not None - - prompt_result = await session.get_prompt("review_code", {"code": "def hello():\n print('Hello')"}) - assert isinstance(prompt_result, GetPromptResult) - assert len(prompt_result.messages) == 1 - assert isinstance(prompt_result.messages[0].content, TextContent) - assert "Please review this code:" in prompt_result.messages[0].content.text - assert "def hello():" in prompt_result.messages[0].content.text - - # Test debug_error prompt - debug_result = await session.get_prompt( - "debug_error", {"error": "TypeError: 'NoneType' object is not subscriptable"} - ) - assert isinstance(debug_result, GetPromptResult) - assert len(debug_result.messages) == 3 - assert debug_result.messages[0].role == "user" - assert isinstance(debug_result.messages[0].content, TextContent) - assert "I'm seeing this error:" in debug_result.messages[0].content.text - assert debug_result.messages[1].role == "user" - assert isinstance(debug_result.messages[1].content, TextContent) - assert "TypeError" in debug_result.messages[1].content.text - assert debug_result.messages[2].role == "assistant" - assert isinstance(debug_result.messages[2].content, TextContent) - assert "I'll help debug that" in debug_result.messages[2].content.text - - -# Test progress reporting -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("tool_progress", "sse"), - ("tool_progress", "streamable-http"), - ], - indirect=True, -) -async def test_tool_progress(server_transport: str, server_url: str) -> None: + async with Client(basic_prompt.mcp) as client: + assert client.server_capabilities is not None + assert client.server_capabilities.prompts is not None + + # Test review_code prompt + prompts = await client.list_prompts() + review_prompt = next((p for p in prompts.prompts if p.name == "review_code"), None) + assert review_prompt is not None + + prompt_result = await client.get_prompt("review_code", {"code": "def hello():\n print('Hello')"}) + assert isinstance(prompt_result, GetPromptResult) + assert len(prompt_result.messages) == 1 + assert isinstance(prompt_result.messages[0].content, TextContent) + assert "Please review this code:" in prompt_result.messages[0].content.text + assert "def hello():" in prompt_result.messages[0].content.text + + # Test debug_error prompt + debug_result = await client.get_prompt( + "debug_error", {"error": "TypeError: 'NoneType' object is not subscriptable"} + ) + assert isinstance(debug_result, GetPromptResult) + assert len(debug_result.messages) == 3 + assert debug_result.messages[0].role == "user" + assert isinstance(debug_result.messages[0].content, TextContent) + assert "I'm seeing this error:" in debug_result.messages[0].content.text + assert debug_result.messages[1].role == "user" + assert isinstance(debug_result.messages[1].content, TextContent) + assert "TypeError" in debug_result.messages[1].content.text + assert debug_result.messages[2].role == "assistant" + assert isinstance(debug_result.messages[2].content, TextContent) + assert "I'll help debug that" in debug_result.messages[2].content.text + + +async def test_tool_progress() -> None: """Test tool progress reporting.""" - transport = server_transport collector = NotificationCollector() async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): @@ -358,134 +192,79 @@ async def message_handler(message: RequestResponder[ServerRequest, ClientResult] if isinstance(message, Exception): # pragma: no cover raise message - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Progress Example" - - # Test progress callback - progress_updates = [] - - async def progress_callback(progress: float, total: float | None, message: str | None) -> None: - progress_updates.append((progress, total, message)) - - # Call tool with progress - steps = 3 - tool_result = await session.call_tool( - "long_running_task", - {"task_name": "Test Task", "steps": steps}, - progress_callback=progress_callback, - ) - assert tool_result.content == snapshot([TextContent(text="Task 'Test Task' completed")]) - - # Verify progress updates - assert len(progress_updates) == steps - for i, (progress, total, message) in enumerate(progress_updates): - expected_progress = (i + 1) / steps - assert abs(progress - expected_progress) < 0.01 - assert total == 1.0 - assert f"Step {i + 1}/{steps}" in message - - # Verify log messages - assert len(collector.log_messages) > 0 - - -# Test sampling -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("sampling", "sse"), - ("sampling", "streamable-http"), - ], - indirect=True, -) -async def test_sampling(server_transport: str, server_url: str) -> None: + async with Client(tool_progress.mcp, message_handler=message_handler) as client: + # Test progress callback + progress_updates = [] + + async def progress_callback(progress: float, total: float | None, message: str | None) -> None: + progress_updates.append((progress, total, message)) + + # Call tool with progress + steps = 3 + tool_result = await client.call_tool( + "long_running_task", + {"task_name": "Test Task", "steps": steps}, + progress_callback=progress_callback, + ) + assert tool_result.content == snapshot([TextContent(text="Task 'Test Task' completed")]) + + # Verify progress updates + assert len(progress_updates) == steps + for i, (progress, total, message) in enumerate(progress_updates): + expected_progress = (i + 1) / steps + assert abs(progress - expected_progress) < 0.01 + assert total == 1.0 + assert f"Step {i + 1}/{steps}" in message + + # Verify log messages + assert len(collector.log_messages) > 0 + + +async def test_sampling() -> None: """Test sampling (LLM interaction) functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, sampling_callback=sampling_callback) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Sampling Example" - assert result.capabilities.tools is not None - - # Test sampling tool - sampling_result = await session.call_tool("generate_poem", {"topic": "nature"}) - assert len(sampling_result.content) == 1 - assert isinstance(sampling_result.content[0], TextContent) - assert "This is a simulated LLM response" in sampling_result.content[0].text - - -# Test elicitation -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("elicitation", "sse"), - ("elicitation", "streamable-http"), - ], - indirect=True, -) -async def test_elicitation(server_transport: str, server_url: str) -> None: + async with Client(sampling.mcp, sampling_callback=sampling_callback) as client: + assert client.server_capabilities is not None + assert client.server_capabilities.tools is not None + + # Test sampling tool + sampling_result = await client.call_tool("generate_poem", {"topic": "nature"}) + assert len(sampling_result.content) == 1 + assert isinstance(sampling_result.content[0], TextContent) + assert "This is a simulated LLM response" in sampling_result.content[0].text + + +async def test_elicitation() -> None: """Test elicitation (user interaction) functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, elicitation_callback=elicitation_callback) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Elicitation Example" - - # Test booking with unavailable date (triggers elicitation) - booking_result = await session.call_tool( - "book_table", - { - "date": "2024-12-25", # Unavailable date - "time": "19:00", - "party_size": 4, - }, - ) - assert len(booking_result.content) == 1 - assert isinstance(booking_result.content[0], TextContent) - assert "[SUCCESS] Booked for 2024-12-26" in booking_result.content[0].text - - # Test booking with available date (no elicitation) - booking_result = await session.call_tool( - "book_table", - { - "date": "2024-12-20", # Available date - "time": "20:00", - "party_size": 2, - }, - ) - assert len(booking_result.content) == 1 - assert isinstance(booking_result.content[0], TextContent) - assert "[SUCCESS] Booked for 2024-12-20 at 20:00" in booking_result.content[0].text - - -# Test notifications -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("notifications", "sse"), - ("notifications", "streamable-http"), - ], - indirect=True, -) -async def test_notifications(server_transport: str, server_url: str) -> None: + async with Client(elicitation.mcp, elicitation_callback=elicitation_callback) as client: + # Test booking with unavailable date (triggers elicitation) + booking_result = await client.call_tool( + "book_table", + { + "date": "2024-12-25", # Unavailable date + "time": "19:00", + "party_size": 4, + }, + ) + assert len(booking_result.content) == 1 + assert isinstance(booking_result.content[0], TextContent) + assert "[SUCCESS] Booked for 2024-12-26" in booking_result.content[0].text + + # Test booking with available date (no elicitation) + booking_result = await client.call_tool( + "book_table", + { + "date": "2024-12-20", # Available date + "time": "20:00", + "party_size": 2, + }, + ) + assert len(booking_result.content) == 1 + assert isinstance(booking_result.content[0], TextContent) + assert "[SUCCESS] Booked for 2024-12-20 at 20:00" in booking_result.content[0].text + + +async def test_notifications() -> None: """Test notifications and logging functionality.""" - transport = server_transport collector = NotificationCollector() async def message_handler(message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception): @@ -493,150 +272,87 @@ async def message_handler(message: RequestResponder[ServerRequest, ClientResult] if isinstance(message, Exception): # pragma: no cover raise message - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream, message_handler=message_handler) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Notifications Example" - - # Call tool that generates notifications - tool_result = await session.call_tool("process_data", {"data": "test_data"}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert "Processed: test_data" in tool_result.content[0].text - - # Verify log messages at different levels - assert len(collector.log_messages) >= 4 - log_levels = {msg.level for msg in collector.log_messages} - assert "debug" in log_levels - assert "info" in log_levels - assert "warning" in log_levels - assert "error" in log_levels - - # Verify resource list changed notification - assert len(collector.resource_notifications) > 0 - - -# Test completion -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("completion", "sse"), - ("completion", "streamable-http"), - ], - indirect=True, -) -async def test_completion(server_transport: str, server_url: str) -> None: + async with Client(notifications.mcp, message_handler=message_handler) as client: + # Call tool that generates notifications + tool_result = await client.call_tool("process_data", {"data": "test_data"}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert "Processed: test_data" in tool_result.content[0].text + + # Verify log messages at different levels + assert len(collector.log_messages) >= 4 + log_levels = {msg.level for msg in collector.log_messages} + assert "debug" in log_levels + assert "info" in log_levels + assert "warning" in log_levels + assert "error" in log_levels + + # Verify resource list changed notification + assert len(collector.resource_notifications) > 0 + + +async def test_completion() -> None: """Test completion (autocomplete) functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Example" - assert result.capabilities.resources is not None - assert result.capabilities.prompts is not None - - # Test resource completion - completion_result = await session.complete( - ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), - argument={"name": "repo", "value": ""}, - context_arguments={"owner": "modelcontextprotocol"}, - ) - - assert completion_result is not None - assert hasattr(completion_result, "completion") - assert completion_result.completion is not None - assert len(completion_result.completion.values) == 3 - assert "python-sdk" in completion_result.completion.values - assert "typescript-sdk" in completion_result.completion.values - assert "specification" in completion_result.completion.values - - # Test prompt completion - completion_result = await session.complete( - ref=PromptReference(type="ref/prompt", name="review_code"), - argument={"name": "language", "value": "py"}, - ) - - assert completion_result is not None - assert hasattr(completion_result, "completion") - assert completion_result.completion is not None - assert "python" in completion_result.completion.values - assert all(lang.startswith("py") for lang in completion_result.completion.values) - - -# Test MCPServer quickstart example -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("mcpserver_quickstart", "sse"), - ("mcpserver_quickstart", "streamable-http"), - ], - indirect=True, -) -async def test_mcpserver_quickstart(server_transport: str, server_url: str) -> None: + async with Client(completion.mcp) as client: + assert client.server_capabilities is not None + assert client.server_capabilities.resources is not None + assert client.server_capabilities.prompts is not None + + # Test resource completion + completion_result = await client.complete( + ref=ResourceTemplateReference(type="ref/resource", uri="github://repos/{owner}/{repo}"), + argument={"name": "repo", "value": ""}, + context_arguments={"owner": "modelcontextprotocol"}, + ) + + assert completion_result is not None + assert hasattr(completion_result, "completion") + assert completion_result.completion is not None + assert len(completion_result.completion.values) == 3 + assert "python-sdk" in completion_result.completion.values + assert "typescript-sdk" in completion_result.completion.values + assert "specification" in completion_result.completion.values + + # Test prompt completion + completion_result = await client.complete( + ref=PromptReference(type="ref/prompt", name="review_code"), + argument={"name": "language", "value": "py"}, + ) + + assert completion_result is not None + assert hasattr(completion_result, "completion") + assert completion_result.completion is not None + assert "python" in completion_result.completion.values + assert all(lang.startswith("py") for lang in completion_result.completion.values) + + +async def test_mcpserver_quickstart() -> None: """Test MCPServer quickstart example.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Demo" - - # Test add tool - tool_result = await session.call_tool("add", {"a": 10, "b": 20}) - assert len(tool_result.content) == 1 - assert isinstance(tool_result.content[0], TextContent) - assert tool_result.content[0].text == "30" - - # Test greeting resource directly - resource_result = await session.read_resource("greeting://Alice") - assert len(resource_result.contents) == 1 - assert isinstance(resource_result.contents[0], TextResourceContents) - assert resource_result.contents[0].text == "Hello, Alice!" - - -# Test structured output example -@pytest.mark.anyio -@pytest.mark.parametrize( - "server_transport", - [ - ("structured_output", "sse"), - ("structured_output", "streamable-http"), - ], - indirect=True, -) -async def test_structured_output(server_transport: str, server_url: str) -> None: + async with Client(mcpserver_quickstart.mcp) as client: + # Test add tool + tool_result = await client.call_tool("add", {"a": 10, "b": 20}) + assert len(tool_result.content) == 1 + assert isinstance(tool_result.content[0], TextContent) + assert tool_result.content[0].text == "30" + + # Test greeting resource directly + resource_result = await client.read_resource("greeting://Alice") + assert len(resource_result.contents) == 1 + assert isinstance(resource_result.contents[0], TextResourceContents) + assert resource_result.contents[0].text == "Hello, Alice!" + + +async def test_structured_output() -> None: """Test structured output functionality.""" - transport = server_transport - client_cm = create_client_for_transport(transport, server_url) - - async with client_cm as (read_stream, write_stream): - async with ClientSession(read_stream, write_stream) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == "Structured Output Example" - - # Test get_weather tool - weather_result = await session.call_tool("get_weather", {"city": "New York"}) - assert len(weather_result.content) == 1 - assert isinstance(weather_result.content[0], TextContent) - - # Check that the result contains expected weather data - result_text = weather_result.content[0].text - assert "22.5" in result_text # temperature - assert "sunny" in result_text # condition - assert "45" in result_text # humidity - assert "5.2" in result_text # wind_speed + async with Client(structured_output.mcp) as client: + # Test get_weather tool + weather_result = await client.call_tool("get_weather", {"city": "New York"}) + assert len(weather_result.content) == 1 + assert isinstance(weather_result.content[0], TextContent) + + # Check that the result contains expected weather data + result_text = weather_result.content[0].text + assert "22.5" in result_text # temperature + assert "sunny" in result_text # condition + assert "45" in result_text # humidity + assert "5.2" in result_text # wind_speed From 67201a9bbd9c31419716886d9c1036e6370f83dc Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Mar 2026 15:48:30 +0000 Subject: [PATCH 2/4] test: fix WS test port race; narrow to single smoke test covering both transport ends (#2267) --- src/mcp/server/websocket.py | 8 +- tests/client/test_client.py | 17 ++- tests/shared/test_ws.py | 205 ++++-------------------------------- tests/test_helpers.py | 54 ++++++++++ 4 files changed, 97 insertions(+), 187 deletions(-) diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 3e675da5f..32b50560c 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -10,7 +10,7 @@ from mcp.shared.message import SessionMessage -@asynccontextmanager # pragma: no cover +@asynccontextmanager async def websocket_server(scope: Scope, receive: Receive, send: Send): """WebSocket server transport for MCP. This is an ASGI application, suitable for use with a framework like Starlette and a server like Hypercorn. @@ -34,13 +34,13 @@ async def ws_reader(): async for msg in websocket.iter_text(): try: client_message = types.jsonrpc_message_adapter.validate_json(msg, by_name=False) - except ValidationError as exc: + except ValidationError as exc: # pragma: no cover await read_stream_writer.send(exc) continue session_message = SessionMessage(client_message) await read_stream_writer.send(session_message) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: no cover await websocket.close() async def ws_writer(): @@ -49,7 +49,7 @@ async def ws_writer(): async for session_message in write_stream_reader: obj = session_message.message.model_dump_json(by_alias=True, exclude_unset=True) await websocket.send_text(obj) - except anyio.ClosedResourceError: + except anyio.ClosedResourceError: # pragma: no cover await websocket.close() async with anyio.create_task_group() as tg: diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 45300063a..3bdd30570 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -8,7 +8,7 @@ import pytest from inline_snapshot import snapshot -from mcp import types +from mcp import MCPError, types from mcp.client._memory import InMemoryTransport from mcp.client.client import Client from mcp.server import Server, ServerRequestContext @@ -175,6 +175,21 @@ async def test_read_resource(app: MCPServer): ) +async def test_read_resource_error_propagates(): + """MCPError raised by a server handler propagates to the client with its code intact.""" + + async def handle_read_resource( + ctx: ServerRequestContext, params: types.ReadResourceRequestParams + ) -> ReadResourceResult: + raise MCPError(code=404, message="no resource with that URI was found") + + server = Server("test", on_read_resource=handle_read_resource) + async with Client(server) as client: + with pytest.raises(MCPError) as exc_info: + await client.read_resource("unknown://example") + assert exc_info.value.error.code == 404 + + async def test_get_prompt(app: MCPServer): """Test getting a prompt.""" async with Client(app) as client: diff --git a/tests/shared/test_ws.py b/tests/shared/test_ws.py index 9addb661d..482dcdcf3 100644 --- a/tests/shared/test_ws.py +++ b/tests/shared/test_ws.py @@ -1,210 +1,51 @@ -import multiprocessing -import socket -from collections.abc import AsyncGenerator, Generator -from urllib.parse import urlparse +"""Smoke test for the WebSocket transport. + +Runs the full WS stack end-to-end over a real TCP connection, covering both +``src/mcp/client/websocket.py`` and ``src/mcp/server/websocket.py``. MCP +semantics (error propagation, timeouts, etc.) are transport-agnostic and are +covered in ``tests/client/test_client.py`` and ``tests/issues/test_88_random_error.py``. +""" + +from collections.abc import Generator -import anyio import pytest -import uvicorn from starlette.applications import Starlette from starlette.routing import WebSocketRoute from starlette.websockets import WebSocket -from mcp import MCPError from mcp.client.session import ClientSession from mcp.client.websocket import websocket_client -from mcp.server import Server, ServerRequestContext +from mcp.server import Server from mcp.server.websocket import websocket_server -from mcp.types import ( - CallToolRequestParams, - CallToolResult, - EmptyResult, - InitializeResult, - ListToolsResult, - PaginatedRequestParams, - ReadResourceRequestParams, - ReadResourceResult, - TextContent, - TextResourceContents, - Tool, -) -from tests.test_helpers import wait_for_server +from mcp.types import EmptyResult, InitializeResult +from tests.test_helpers import run_uvicorn_in_thread SERVER_NAME = "test_server_for_WS" -@pytest.fixture -def server_port() -> int: - with socket.socket() as s: - s.bind(("127.0.0.1", 0)) - return s.getsockname()[1] - - -@pytest.fixture -def server_url(server_port: int) -> str: - return f"ws://127.0.0.1:{server_port}" - - -async def handle_read_resource( # pragma: no cover - ctx: ServerRequestContext, params: ReadResourceRequestParams -) -> ReadResourceResult: - parsed = urlparse(str(params.uri)) - if parsed.scheme == "foobar": - return ReadResourceResult( - contents=[TextResourceContents(uri=str(params.uri), text=f"Read {parsed.netloc}", mime_type="text/plain")] - ) - elif parsed.scheme == "slow": - await anyio.sleep(2.0) - return ReadResourceResult( - contents=[ - TextResourceContents( - uri=str(params.uri), text=f"Slow response from {parsed.netloc}", mime_type="text/plain" - ) - ] - ) - raise MCPError(code=404, message="OOPS! no resource with that URI was found") - - -async def handle_list_tools( # pragma: no cover - ctx: ServerRequestContext, params: PaginatedRequestParams | None -) -> ListToolsResult: - return ListToolsResult( - tools=[ - Tool( - name="test_tool", - description="A test tool", - input_schema={"type": "object", "properties": {}}, - ) - ] - ) - +def make_server_app() -> Starlette: + srv = Server(SERVER_NAME) -async def handle_call_tool( # pragma: no cover - ctx: ServerRequestContext, params: CallToolRequestParams -) -> CallToolResult: - return CallToolResult(content=[TextContent(type="text", text=f"Called {params.name}")]) - - -def _create_server() -> Server: # pragma: no cover - return Server( - SERVER_NAME, - on_read_resource=handle_read_resource, - on_list_tools=handle_list_tools, - on_call_tool=handle_call_tool, - ) - - -# Test fixtures -def make_server_app() -> Starlette: # pragma: no cover - """Create test Starlette app with WebSocket transport""" - server = _create_server() - - async def handle_ws(websocket: WebSocket): + async def handle_ws(websocket: WebSocket) -> None: async with websocket_server(websocket.scope, websocket.receive, websocket.send) as streams: - await server.run(streams[0], streams[1], server.create_initialization_options()) - - app = Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) - return app - - -def run_server(server_port: int) -> None: # pragma: no cover - app = make_server_app() - server = uvicorn.Server(config=uvicorn.Config(app=app, host="127.0.0.1", port=server_port, log_level="error")) - print(f"starting server on {server_port}") - server.run() - - -@pytest.fixture() -def server(server_port: int) -> Generator[None, None, None]: - proc = multiprocessing.Process(target=run_server, kwargs={"server_port": server_port}, daemon=True) - print("starting process") - proc.start() + await srv.run(streams[0], streams[1], srv.create_initialization_options()) - # Wait for server to be running - print("waiting for server to start") - wait_for_server(server_port) + return Starlette(routes=[WebSocketRoute("/ws", endpoint=handle_ws)]) - yield - print("killing server") - # Signal the server to stop - proc.kill() - proc.join(timeout=2) - if proc.is_alive(): # pragma: no cover - print("server process failed to terminate") - - -@pytest.fixture() -async def initialized_ws_client_session(server: None, server_url: str) -> AsyncGenerator[ClientSession, None]: - """Create and initialize a WebSocket client session""" - async with websocket_client(server_url + "/ws") as streams: - async with ClientSession(*streams) as session: - # Test initialization - result = await session.initialize() - assert isinstance(result, InitializeResult) - assert result.server_info.name == SERVER_NAME - - # Test ping - ping_result = await session.send_ping() - assert isinstance(ping_result, EmptyResult) - - yield session +@pytest.fixture +def ws_server_url() -> Generator[str, None, None]: + with run_uvicorn_in_thread(make_server_app()) as base_url: + yield base_url.replace("http://", "ws://") + "/ws" -# Tests @pytest.mark.anyio -async def test_ws_client_basic_connection(server: None, server_url: str) -> None: - """Test the WebSocket connection establishment""" - async with websocket_client(server_url + "/ws") as streams: +async def test_ws_client_basic_connection(ws_server_url: str) -> None: + async with websocket_client(ws_server_url) as streams: async with ClientSession(*streams) as session: - # Test initialization result = await session.initialize() assert isinstance(result, InitializeResult) assert result.server_info.name == SERVER_NAME - # Test ping ping_result = await session.send_ping() assert isinstance(ping_result, EmptyResult) - - -@pytest.mark.anyio -async def test_ws_client_happy_request_and_response( - initialized_ws_client_session: ClientSession, -) -> None: - """Test a successful request and response via WebSocket""" - result = await initialized_ws_client_session.read_resource("foobar://example") - assert isinstance(result, ReadResourceResult) - assert isinstance(result.contents, list) - assert len(result.contents) > 0 - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "Read example" - - -@pytest.mark.anyio -async def test_ws_client_exception_handling( - initialized_ws_client_session: ClientSession, -) -> None: - """Test exception handling in WebSocket communication""" - with pytest.raises(MCPError) as exc_info: - await initialized_ws_client_session.read_resource("unknown://example") - assert exc_info.value.error.code == 404 - - -@pytest.mark.anyio -async def test_ws_client_timeout( - initialized_ws_client_session: ClientSession, -) -> None: - """Test timeout handling in WebSocket communication""" - # Set a very short timeout to trigger a timeout exception - with pytest.raises(TimeoutError): - with anyio.fail_after(0.1): # 100ms timeout - await initialized_ws_client_session.read_resource("slow://example") - - # Now test that we can still use the session after a timeout - with anyio.fail_after(5): # Longer timeout to allow completion - result = await initialized_ws_client_session.read_resource("foobar://example") - assert isinstance(result, ReadResourceResult) - assert isinstance(result.contents, list) - assert len(result.contents) > 0 - assert isinstance(result.contents[0], TextResourceContents) - assert result.contents[0].text == "Read example" diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 5c04c269f..810c72820 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,7 +1,61 @@ """Common test utilities for MCP server tests.""" import socket +import threading import time +from collections.abc import Generator +from contextlib import contextmanager +from typing import Any + +import uvicorn + +_SERVER_SHUTDOWN_TIMEOUT_S = 5.0 + + +@contextmanager +def run_uvicorn_in_thread(app: Any, **config_kwargs: Any) -> Generator[str, None, None]: + """Run a uvicorn server in a background thread on an ephemeral port. + + The socket is bound and put into listening state *before* the thread + starts, so the port is known immediately with no wait. The kernel's + listen queue buffers any connections that arrive before uvicorn's event + loop reaches ``accept()``, so callers can connect as soon as this + function yields — no polling, no sleeps, no startup race. + + This also avoids the TOCTOU race of the old pick-a-port-then-rebind + pattern: the socket passed here is the one uvicorn serves on, with no + gap where another pytest-xdist worker could claim it. + + Args: + app: ASGI application to serve. + **config_kwargs: Additional keyword arguments for :class:`uvicorn.Config` + (e.g. ``log_level``). ``host``/``port`` are ignored since the + socket is pre-bound. + + Yields: + The base URL of the running server, e.g. ``http://127.0.0.1:54321``. + """ + host = "127.0.0.1" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind((host, 0)) + sock.listen() + port = sock.getsockname()[1] + + config_kwargs.setdefault("log_level", "error") + # Uvicorn's interface autodetection calls asyncio.iscoroutinefunction, + # which Python 3.14 deprecates. Under filterwarnings=error this crashes + # the server thread silently. Starlette is asgi3; skip the autodetect. + config_kwargs.setdefault("interface", "asgi3") + server = uvicorn.Server(config=uvicorn.Config(app=app, **config_kwargs)) + + thread = threading.Thread(target=server.run, kwargs={"sockets": [sock]}, daemon=True) + thread.start() + try: + yield f"http://{host}:{port}" + finally: + server.should_exit = True + thread.join(timeout=_SERVER_SHUTDOWN_TIMEOUT_S) def wait_for_server(port: int, timeout: float = 20.0) -> None: From 20dd94632ef1b8a8e2db5ca9c0c38dab845fa5bb Mon Sep 17 00:00:00 2001 From: Max Isbey <224885523+maxisbey@users.noreply.github.com> Date: Wed, 18 Mar 2026 17:31:26 +0000 Subject: [PATCH 3/4] feat(client): store InitializeResult as initialize_result (#2300) --- docs/migration.md | 24 +++++++++++++++++++ src/mcp/client/client.py | 15 ++++++++---- src/mcp/client/session.py | 13 ++++++----- tests/client/test_client.py | 3 ++- tests/client/test_session.py | 27 +++++++++++----------- tests/client/transports/test_memory.py | 2 +- tests/server/mcpserver/test_integration.py | 17 +++++--------- 7 files changed, 64 insertions(+), 37 deletions(-) diff --git a/docs/migration.md b/docs/migration.md index 7cf032553..dd6a7a18f 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -169,6 +169,30 @@ result = await session.list_resources(params=PaginatedRequestParams(cursor="next result = await session.list_tools(params=PaginatedRequestParams(cursor="next_page_token")) ``` +### `ClientSession.get_server_capabilities()` replaced by `initialize_result` property + +`ClientSession` now stores the full `InitializeResult` via an `initialize_result` property. This provides access to `server_info`, `capabilities`, `instructions`, and the negotiated `protocol_version` through a single property. The `get_server_capabilities()` method has been removed. + +**Before (v1):** + +```python +capabilities = session.get_server_capabilities() +# server_info, instructions, protocol_version were not stored — had to capture initialize() return value +``` + +**After (v2):** + +```python +result = session.initialize_result +if result is not None: + capabilities = result.capabilities + server_info = result.server_info + instructions = result.instructions + version = result.protocol_version +``` + +The high-level `Client.initialize_result` returns the same `InitializeResult` but is non-nullable — initialization is guaranteed inside the context manager, so no `None` check is needed. This replaces v1's `Client.server_capabilities`; use `client.initialize_result.capabilities` instead. + ### `McpError` renamed to `MCPError` The `McpError` exception class has been renamed to `MCPError` for consistent naming with the MCP acronym style used throughout the SDK. diff --git a/src/mcp/client/client.py b/src/mcp/client/client.py index 7dc67c584..34d6a360f 100644 --- a/src/mcp/client/client.py +++ b/src/mcp/client/client.py @@ -19,6 +19,7 @@ EmptyResult, GetPromptResult, Implementation, + InitializeResult, ListPromptsResult, ListResourcesResult, ListResourceTemplatesResult, @@ -29,7 +30,6 @@ ReadResourceResult, RequestParamsMeta, ResourceTemplateReference, - ServerCapabilities, ) @@ -155,9 +155,16 @@ def session(self) -> ClientSession: return self._session @property - def server_capabilities(self) -> ServerCapabilities | None: - """The server capabilities received during initialization, or None if not yet initialized.""" - return self.session.get_server_capabilities() + def initialize_result(self) -> InitializeResult: + """The server's InitializeResult. + + Contains server_info, capabilities, instructions, and the negotiated protocol_version. + Raises RuntimeError if accessed outside the context manager. + """ + result = self.session.initialize_result + if result is None: # pragma: no cover + raise RuntimeError("Client must be used within an async context manager") + return result async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> EmptyResult: """Send a ping request to the server.""" diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index a0ca751bd..7c964a334 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -131,7 +131,7 @@ def __init__( self._logging_callback = logging_callback or _default_logging_callback self._message_handler = message_handler or _default_message_handler self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} - self._server_capabilities: types.ServerCapabilities | None = None + self._initialize_result: types.InitializeResult | None = None self._experimental_features: ExperimentalClientFeatures | None = None # Experimental: Task handlers (use defaults if not provided) @@ -185,18 +185,19 @@ async def initialize(self) -> types.InitializeResult: if result.protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: raise RuntimeError(f"Unsupported protocol version from the server: {result.protocol_version}") - self._server_capabilities = result.capabilities + self._initialize_result = result await self.send_notification(types.InitializedNotification()) return result - def get_server_capabilities(self) -> types.ServerCapabilities | None: - """Return the server capabilities received during initialization. + @property + def initialize_result(self) -> types.InitializeResult | None: + """The server's InitializeResult. None until initialize() has been called. - Returns None if the session has not been initialized yet. + Contains server_info, capabilities, instructions, and the negotiated protocol_version. """ - return self._server_capabilities + return self._initialize_result @property def experimental(self) -> ExperimentalClientFeatures: diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 3bdd30570..18368e6bb 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -99,7 +99,7 @@ def greeting_prompt(name: str) -> str: async def test_client_is_initialized(app: MCPServer): """Test that the client is initialized after entering context.""" async with Client(app) as client: - assert client.server_capabilities == snapshot( + assert client.initialize_result.capabilities == snapshot( ServerCapabilities( experimental={}, prompts=PromptsCapability(list_changed=False), @@ -107,6 +107,7 @@ async def test_client_is_initialized(app: MCPServer): tools=ToolsCapability(list_changed=False), ) ) + assert client.initialize_result.server_info.name == "test" async def test_client_with_simple_server(simple_server: Server): diff --git a/tests/client/test_session.py b/tests/client/test_session.py index d6d13e273..f25c964f0 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -540,8 +540,8 @@ async def mock_server(): @pytest.mark.anyio -async def test_get_server_capabilities(): - """Test that get_server_capabilities returns None before init and capabilities after""" +async def test_initialize_result(): + """Test that initialize_result is None before init and contains the full result after.""" client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[SessionMessage](1) server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[SessionMessage](1) @@ -551,6 +551,8 @@ async def test_get_server_capabilities(): resources=types.ResourcesCapability(subscribe=True, list_changed=True), tools=types.ToolsCapability(list_changed=False), ) + expected_server_info = Implementation(name="mock-server", version="0.1.0") + expected_instructions = "Use the tools wisely." async def mock_server(): session_message = await client_to_server_receive.receive() @@ -564,7 +566,8 @@ async def mock_server(): result = InitializeResult( protocol_version=LATEST_PROTOCOL_VERSION, capabilities=expected_capabilities, - server_info=Implementation(name="mock-server", version="0.1.0"), + server_info=expected_server_info, + instructions=expected_instructions, ) async with server_to_client_send: @@ -590,21 +593,17 @@ async def mock_server(): server_to_client_send, server_to_client_receive, ): - assert session.get_server_capabilities() is None + assert session.initialize_result is None tg.start_soon(mock_server) await session.initialize() - capabilities = session.get_server_capabilities() - assert capabilities is not None - assert capabilities == expected_capabilities - assert capabilities.logging is not None - assert capabilities.prompts is not None - assert capabilities.prompts.list_changed is True - assert capabilities.resources is not None - assert capabilities.resources.subscribe is True - assert capabilities.tools is not None - assert capabilities.tools.list_changed is False + result = session.initialize_result + assert result is not None + assert result.server_info == expected_server_info + assert result.capabilities == expected_capabilities + assert result.instructions == expected_instructions + assert result.protocol_version == LATEST_PROTOCOL_VERSION @pytest.mark.anyio diff --git a/tests/client/transports/test_memory.py b/tests/client/transports/test_memory.py index 47be3e208..c8fc41fd5 100644 --- a/tests/client/transports/test_memory.py +++ b/tests/client/transports/test_memory.py @@ -69,7 +69,7 @@ async def test_with_mcpserver(mcpserver_server: MCPServer): async def test_server_is_running(mcpserver_server: MCPServer): """Test that the server is running and responding to requests.""" async with Client(mcpserver_server) as client: - assert client.server_capabilities is not None + assert client.initialize_result.capabilities.tools is not None async def test_list_tools(mcpserver_server: MCPServer): diff --git a/tests/server/mcpserver/test_integration.py b/tests/server/mcpserver/test_integration.py index 90e333b77..f71c0574c 100644 --- a/tests/server/mcpserver/test_integration.py +++ b/tests/server/mcpserver/test_integration.py @@ -109,8 +109,7 @@ async def elicitation_callback(context: RequestContext[ClientSession], params: E async def test_basic_tools() -> None: """Test basic tool functionality.""" async with Client(basic_tool.mcp) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.tools is not None + assert client.initialize_result.capabilities.tools is not None # Test sum tool tool_result = await client.call_tool("sum", {"a": 5, "b": 3}) @@ -128,8 +127,7 @@ async def test_basic_tools() -> None: async def test_basic_resources() -> None: """Test basic resource functionality.""" async with Client(basic_resource.mcp) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.resources is not None + assert client.initialize_result.capabilities.resources is not None # Test document resource doc_content = await client.read_resource("file://documents/readme") @@ -151,8 +149,7 @@ async def test_basic_resources() -> None: async def test_basic_prompts() -> None: """Test basic prompt functionality.""" async with Client(basic_prompt.mcp) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.prompts is not None + assert client.initialize_result.capabilities.prompts is not None # Test review_code prompt prompts = await client.list_prompts() @@ -223,8 +220,7 @@ async def progress_callback(progress: float, total: float | None, message: str | async def test_sampling() -> None: """Test sampling (LLM interaction) functionality.""" async with Client(sampling.mcp, sampling_callback=sampling_callback) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.tools is not None + assert client.initialize_result.capabilities.tools is not None # Test sampling tool sampling_result = await client.call_tool("generate_poem", {"topic": "nature"}) @@ -294,9 +290,8 @@ async def message_handler(message: RequestResponder[ServerRequest, ClientResult] async def test_completion() -> None: """Test completion (autocomplete) functionality.""" async with Client(completion.mcp) as client: - assert client.server_capabilities is not None - assert client.server_capabilities.resources is not None - assert client.server_capabilities.prompts is not None + assert client.initialize_result.capabilities.resources is not None + assert client.initialize_result.capabilities.prompts is not None # Test resource completion completion_result = await client.complete( From 5388bea53ad7b13db7d031cdfd27392474c89007 Mon Sep 17 00:00:00 2001 From: Jonathan Hefner Date: Wed, 18 Mar 2026 13:15:17 -0500 Subject: [PATCH 4/4] docs: generate hierarchical per-module API reference pages (#2103) --- docs/api.md | 1 - docs/hooks/gen_ref_pages.py | 35 +++++++++++++++++++++++++++++++++++ docs/index.md | 2 +- docs/migration.md | 2 +- mkdocs.yml | 11 +++++++---- pyproject.toml | 2 ++ uv.lock | 28 ++++++++++++++++++++++++++++ 7 files changed, 74 insertions(+), 7 deletions(-) delete mode 100644 docs/api.md create mode 100644 docs/hooks/gen_ref_pages.py diff --git a/docs/api.md b/docs/api.md deleted file mode 100644 index 3f696af54..000000000 --- a/docs/api.md +++ /dev/null @@ -1 +0,0 @@ -::: mcp diff --git a/docs/hooks/gen_ref_pages.py b/docs/hooks/gen_ref_pages.py new file mode 100644 index 000000000..ad8c19b45 --- /dev/null +++ b/docs/hooks/gen_ref_pages.py @@ -0,0 +1,35 @@ +"""Generate the code reference pages and navigation.""" + +from pathlib import Path + +import mkdocs_gen_files + +nav = mkdocs_gen_files.Nav() + +root = Path(__file__).parent.parent.parent +src = root / "src" + +for path in sorted(src.rglob("*.py")): + module_path = path.relative_to(src).with_suffix("") + doc_path = path.relative_to(src).with_suffix(".md") + full_doc_path = Path("api", doc_path) + + parts = tuple(module_path.parts) + + if parts[-1] == "__init__": + parts = parts[:-1] + doc_path = doc_path.with_name("index.md") + full_doc_path = full_doc_path.with_name("index.md") + elif parts[-1].startswith("_"): + continue + + nav[parts] = doc_path.as_posix() + + with mkdocs_gen_files.open(full_doc_path, "w") as fd: + ident = ".".join(parts) + fd.write(f"::: {ident}") + + mkdocs_gen_files.set_edit_path(full_doc_path, path.relative_to(root)) + +with mkdocs_gen_files.open("api/SUMMARY.md", "w") as nav_file: + nav_file.writelines(nav.build_literate_nav()) diff --git a/docs/index.md b/docs/index.md index e096d910b..436d1c8fc 100644 --- a/docs/index.md +++ b/docs/index.md @@ -64,4 +64,4 @@ npx -y @modelcontextprotocol/inspector ## API Reference -Full API documentation is available in the [API Reference](api.md). +Full API documentation is available in the [API Reference](api/mcp/index.md). diff --git a/docs/migration.md b/docs/migration.md index dd6a7a18f..3b47f9aad 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -883,6 +883,6 @@ The lowlevel `Server` also now exposes a `session_manager` property to access th If you encounter issues during migration: -1. Check the [API Reference](api.md) for updated method signatures +1. Check the [API Reference](api/mcp/index.md) for updated method signatures 2. Review the [examples](https://github.com/modelcontextprotocol/python-sdk/tree/main/examples) for updated usage patterns 3. Open an issue on [GitHub](https://github.com/modelcontextprotocol/python-sdk/issues) if you find a bug or need further assistance diff --git a/mkdocs.yml b/mkdocs.yml index 070c533e3..3a555785a 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -25,7 +25,7 @@ nav: - Introduction: experimental/tasks.md - Server Implementation: experimental/tasks-server.md - Client Usage: experimental/tasks-client.md - - API Reference: api.md + - API Reference: api/ theme: name: "material" @@ -115,10 +115,15 @@ plugins: - social: enabled: !ENV [ENABLE_SOCIAL_CARDS, false] - glightbox + - gen-files: + scripts: + - docs/hooks/gen_ref_pages.py + - literate-nav: + nav_file: SUMMARY.md - mkdocstrings: handlers: python: - paths: [src/mcp] + paths: [src] options: relative_crossrefs: true members_order: source @@ -126,8 +131,6 @@ plugins: show_signature_annotations: true signature_crossrefs: true group_by_category: false - # 3 because docs are in pages with an H2 just above them - heading_level: 3 inventories: - url: https://docs.python.org/3/objects.inv - url: https://docs.pydantic.dev/latest/objects.inv diff --git a/pyproject.toml b/pyproject.toml index f275b90cf..624ade170 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,9 @@ dev = [ ] docs = [ "mkdocs>=1.6.1", + "mkdocs-gen-files>=0.5.0", "mkdocs-glightbox>=0.4.0", + "mkdocs-literate-nav>=0.6.1", "mkdocs-material[imaging]>=9.5.45", "mkdocstrings-python>=2.0.1", ] diff --git a/uv.lock b/uv.lock index c25047e48..4af3532ea 100644 --- a/uv.lock +++ b/uv.lock @@ -840,7 +840,9 @@ dev = [ ] docs = [ { name = "mkdocs" }, + { name = "mkdocs-gen-files" }, { name = "mkdocs-glightbox" }, + { name = "mkdocs-literate-nav" }, { name = "mkdocs-material", extra = ["imaging"] }, { name = "mkdocstrings-python" }, ] @@ -888,7 +890,9 @@ dev = [ ] docs = [ { name = "mkdocs", specifier = ">=1.6.1" }, + { name = "mkdocs-gen-files", specifier = ">=0.5.0" }, { name = "mkdocs-glightbox", specifier = ">=0.4.0" }, + { name = "mkdocs-literate-nav", specifier = ">=0.6.1" }, { name = "mkdocs-material", extras = ["imaging"], specifier = ">=9.5.45" }, { name = "mkdocstrings-python", specifier = ">=2.0.1" }, ] @@ -1501,6 +1505,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/4d/7123b6fa2278000688ebd338e2a06d16870aaf9eceae6ba047ea05f92df1/mkdocs_autorefs-1.4.3-py3-none-any.whl", hash = "sha256:469d85eb3114801d08e9cc55d102b3ba65917a869b893403b8987b601cf55dc9", size = 25034, upload-time = "2025-08-26T14:23:15.906Z" }, ] +[[package]] +name = "mkdocs-gen-files" +version = "0.6.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mkdocs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/61/35/f26349f7fa18414eb2e25d75a6fa9c7e3186c36e1d227c0b2d785a7bd5c4/mkdocs_gen_files-0.6.0.tar.gz", hash = "sha256:52022dc14dcc0451e05e54a8f5d5e7760351b6701eff816d1e9739577ec5635e", size = 8642, upload-time = "2025-11-23T12:13:22.124Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8d/ec/72417415563c60ae01b36f0d497f1f4c803972f447ef4fb7f7746d6e07db/mkdocs_gen_files-0.6.0-py3-none-any.whl", hash = "sha256:815af15f3e2dbfda379629c1b95c02c8e6f232edf2a901186ea3b204ab1135b2", size = 8182, upload-time = "2025-11-23T12:13:20.756Z" }, +] + [[package]] name = "mkdocs-get-deps" version = "0.2.0" @@ -1527,6 +1543,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/30/cf/e9a0ce9da269746906fdc595c030f6df66793dad1487abd1699af2ba44f1/mkdocs_glightbox-0.5.1-py3-none-any.whl", hash = "sha256:f47af0daff164edf8d36e553338425be3aab6e34b987d9cbbc2ae7819a98cb01", size = 26431, upload-time = "2025-09-04T13:10:27.933Z" }, ] +[[package]] +name = "mkdocs-literate-nav" +version = "0.6.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mkdocs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f6/5f/99aa379b305cd1c2084d42db3d26f6de0ea9bf2cc1d10ed17f61aff35b9a/mkdocs_literate_nav-0.6.2.tar.gz", hash = "sha256:760e1708aa4be86af81a2b56e82c739d5a8388a0eab1517ecfd8e5aa40810a75", size = 17419, upload-time = "2025-03-18T21:53:09.711Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/84/b5b14d2745e4dd1a90115186284e9ee1b4d0863104011ab46abb7355a1c3/mkdocs_literate_nav-0.6.2-py3-none-any.whl", hash = "sha256:0a6489a26ec7598477b56fa112056a5e3a6c15729f0214bea8a4dbc55bd5f630", size = 13261, upload-time = "2025-03-18T21:53:08.1Z" }, +] + [[package]] name = "mkdocs-material" version = "9.7.2"