diff --git a/src/Azure.DataApiBuilder.Mcp/Core/McpProtocolDefaults.cs b/src/Azure.DataApiBuilder.Mcp/Core/McpProtocolDefaults.cs index 48b235f480..289f6aa72d 100644 --- a/src/Azure.DataApiBuilder.Mcp/Core/McpProtocolDefaults.cs +++ b/src/Azure.DataApiBuilder.Mcp/Core/McpProtocolDefaults.cs @@ -1,3 +1,4 @@ +using System.Globalization; using Azure.DataApiBuilder.Product; using Microsoft.Extensions.Configuration; @@ -19,13 +20,18 @@ public static class McpProtocolDefaults /// /// Default MCP protocol version advertised when no configuration override is provided. /// - public const string DEFAULT_PROTOCOL_VERSION = "2025-06-18"; + public const string DEFAULT_PROTOCOL_VERSION = "2025-11-25"; /// /// Configuration key used to override the MCP protocol version. /// public const string PROTOCOL_VERSION_CONFIG_KEY = "MCP:ProtocolVersion"; + /// + /// Protocol version where MCP initialize server description is expected under serverInfo.description. + /// + public const string SERVER_INFO_DESCRIPTION_PROTOCOL_VERSION = "2025-11-25"; + /// /// Helper to resolve the effective protocol version from configuration. /// Falls back to when the key is not set. @@ -34,6 +40,44 @@ public static string ResolveProtocolVersion(IConfiguration? configuration) { return configuration?.GetValue(PROTOCOL_VERSION_CONFIG_KEY) ?? DEFAULT_PROTOCOL_VERSION; } + + /// + /// Resolves the protocol version to send in initialize response as the + /// greatest version that does not exceed the client requested version. + /// + /// The server's effective supported protocol version. + /// The protocol version requested by the client. + /// The protocol version to return to the client. + public static string ResolveInitializeResponseProtocolVersion(string supportedProtocolVersion, string? clientRequestedProtocolVersion) + { + if (string.IsNullOrWhiteSpace(clientRequestedProtocolVersion)) + { + return supportedProtocolVersion; + } + + return CompareProtocolVersions(supportedProtocolVersion, clientRequestedProtocolVersion) <= 0 + ? supportedProtocolVersion + : clientRequestedProtocolVersion; + } + + /// + /// Indicates whether initialize response metadata should use serverInfo.description instead of top-level instructions. + /// + public static bool ShouldUseServerInfoDescription(string protocolVersion) + { + return CompareProtocolVersions(protocolVersion, SERVER_INFO_DESCRIPTION_PROTOCOL_VERSION) >= 0; + } + + private static int CompareProtocolVersions(string leftVersion, string rightVersion) + { + const string PROTOCOL_VERSION_DATE_FORMAT = "yyyy-MM-dd"; + if (DateOnly.TryParseExact(leftVersion, PROTOCOL_VERSION_DATE_FORMAT, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateOnly leftDate) && + DateOnly.TryParseExact(rightVersion, PROTOCOL_VERSION_DATE_FORMAT, CultureInfo.InvariantCulture, DateTimeStyles.None, out DateOnly rightDate)) + { + return leftDate.CompareTo(rightDate); + } + + return string.Compare(leftVersion, rightVersion, StringComparison.Ordinal); + } } } - diff --git a/src/Azure.DataApiBuilder.Mcp/Core/McpStdioServer.cs b/src/Azure.DataApiBuilder.Mcp/Core/McpStdioServer.cs index 36b018549f..343b11f15d 100644 --- a/src/Azure.DataApiBuilder.Mcp/Core/McpStdioServer.cs +++ b/src/Azure.DataApiBuilder.Mcp/Core/McpStdioServer.cs @@ -122,7 +122,7 @@ public async Task RunAsync(CancellationToken cancellationToken) switch (method) { case "initialize": - HandleInitialize(id); + HandleInitialize(id, root); break; case "notifications/initialized": @@ -167,22 +167,27 @@ public async Task RunAsync(CancellationToken cancellationToken) /// /// The request identifier extracted from the incoming JSON-RPC request. Used to correlate the response with the request. /// + /// The incoming initialize request payload. /// - /// This method constructs and writes the MCP "initialize" response to STDOUT. It uses the protocol version defined by PROTOCOL_VERSION - /// and includes supported capabilities and server information. No notifications are sent here; the server waits for the client to send - /// "notifications/initialized" before sending any notifications. + /// This method constructs and writes the MCP "initialize" response to STDOUT. It negotiates the response protocol version from the + /// server-supported version and client-requested version, and includes supported capabilities and server information. No notifications + /// are sent here; the server waits for the client to send "notifications/initialized" before sending any notifications. /// - private void HandleInitialize(JsonElement? id) + private void HandleInitialize(JsonElement? id, JsonElement root) { + string? clientRequestedProtocolVersion = GetClientProtocolVersion(root); + string negotiatedProtocolVersion = + McpProtocolDefaults.ResolveInitializeResponseProtocolVersion(_protocolVersion, clientRequestedProtocolVersion); + // Get the description from runtime config if available - string? instructions = null; + string? description = null; RuntimeConfigProvider? runtimeConfigProvider = _serviceProvider.GetService(); if (runtimeConfigProvider != null) { try { RuntimeConfig runtimeConfig = runtimeConfigProvider.GetConfig(); - instructions = runtimeConfig.Runtime?.Mcp?.Description; + description = runtimeConfig.Runtime?.Mcp?.Description; } catch (Exception) { @@ -191,13 +196,33 @@ private void HandleInitialize(JsonElement? id) } } - // Create the initialize response - only include instructions if non-empty + bool shouldUseServerInfoDescription = McpProtocolDefaults.ShouldUseServerInfoDescription(negotiatedProtocolVersion); + + // Create the initialize response - only include description/instructions if non-empty object result; - if (!string.IsNullOrWhiteSpace(instructions)) + if (!string.IsNullOrWhiteSpace(description) && shouldUseServerInfoDescription) { result = new { - protocolVersion = _protocolVersion, + protocolVersion = negotiatedProtocolVersion, + capabilities = new + { + tools = new { listChanged = true }, + logging = new { } + }, + serverInfo = new + { + name = McpProtocolDefaults.MCP_SERVER_NAME, + version = McpProtocolDefaults.MCP_SERVER_VERSION, + description = description + } + }; + } + else if (!string.IsNullOrWhiteSpace(description)) + { + result = new + { + protocolVersion = negotiatedProtocolVersion, capabilities = new { tools = new { listChanged = true }, @@ -208,14 +233,14 @@ private void HandleInitialize(JsonElement? id) name = McpProtocolDefaults.MCP_SERVER_NAME, version = McpProtocolDefaults.MCP_SERVER_VERSION }, - instructions = instructions + instructions = description }; } else { result = new { - protocolVersion = _protocolVersion, + protocolVersion = negotiatedProtocolVersion, capabilities = new { tools = new { listChanged = true }, @@ -232,6 +257,22 @@ private void HandleInitialize(JsonElement? id) WriteResult(id, result); } + private static string? GetClientProtocolVersion(JsonElement root) + { + if (!root.TryGetProperty("params", out JsonElement paramsElement) || paramsElement.ValueKind != JsonValueKind.Object) + { + return null; + } + + if (!paramsElement.TryGetProperty("protocolVersion", out JsonElement protocolVersionElement) || + protocolVersionElement.ValueKind != JsonValueKind.String) + { + return null; + } + + return protocolVersionElement.GetString(); + } + /// /// Handles the "tools/list" JSON-RPC method by sending the list of available tools to the client. /// diff --git a/src/Service.Tests/UnitTests/McpProtocolDefaultsTests.cs b/src/Service.Tests/UnitTests/McpProtocolDefaultsTests.cs new file mode 100644 index 0000000000..341f17986a --- /dev/null +++ b/src/Service.Tests/UnitTests/McpProtocolDefaultsTests.cs @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Azure.DataApiBuilder.Mcp.Core; +using Microsoft.Extensions.Configuration; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests +{ + [TestClass] + public class McpProtocolDefaultsTests + { + [TestMethod] + public void ResolveProtocolVersion_WithoutOverride_UsesLatestDefault() + { + IConfiguration config = new ConfigurationBuilder().Build(); + + string resolved = McpProtocolDefaults.ResolveProtocolVersion(config); + + Assert.AreEqual("2025-11-25", resolved); + } + + [DataTestMethod] + [DataRow("2025-11-25", "2026-01-01", "2025-11-25")] + [DataRow("2025-11-25", "2025-06-18", "2025-06-18")] + [DataRow("2025-11-25", "2025-11-25", "2025-11-25")] + [DataRow("2025-11-25", null, "2025-11-25")] + [DataRow("a-version", "z-version", "a-version")] + public void ResolveInitializeResponseProtocolVersion_ReturnsExpectedNegotiatedVersion( + string supportedProtocolVersion, + string clientRequestedProtocolVersion, + string expectedVersion) + { + string resolved = McpProtocolDefaults.ResolveInitializeResponseProtocolVersion( + supportedProtocolVersion, + clientRequestedProtocolVersion); + + Assert.AreEqual(expectedVersion, resolved); + } + + [TestMethod] + public void ShouldUseServerInfoDescription_AtOrAboveThreshold_ReturnsTrue() + { + Assert.IsTrue(McpProtocolDefaults.ShouldUseServerInfoDescription("2025-11-25")); + Assert.IsTrue(McpProtocolDefaults.ShouldUseServerInfoDescription("2025-12-01")); + } + + [TestMethod] + public void ShouldUseServerInfoDescription_BelowThreshold_ReturnsFalse() + { + Assert.IsFalse(McpProtocolDefaults.ShouldUseServerInfoDescription("2025-06-18")); + } + } +} diff --git a/src/Service.Tests/UnitTests/McpStdioServerInitializeTests.cs b/src/Service.Tests/UnitTests/McpStdioServerInitializeTests.cs new file mode 100644 index 0000000000..53cff20188 --- /dev/null +++ b/src/Service.Tests/UnitTests/McpStdioServerInitializeTests.cs @@ -0,0 +1,200 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +#nullable enable + +using System; +using System.Collections.Generic; +using System.Diagnostics.CodeAnalysis; +using System.IO; +using System.Reflection; +using System.Text.Json; +using Azure.DataApiBuilder.Config; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Mcp.Core; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.VisualStudio.TestTools.UnitTesting; + +namespace Azure.DataApiBuilder.Service.Tests.UnitTests +{ + [TestClass] + public class McpStdioServerInitializeTests + { + [TestMethod] + public void HandleInitialize_ClientRequests2025_11_25_WithDescription_UsesServerInfoDescription() + { + const string DESCRIPTION = "mcp description"; + McpStdioServer server = CreateServer(description: DESCRIPTION, out StringWriter stdoutCapture); + + JsonElement responseRoot = InvokeHandleInitialize( + server, + stdoutCapture, + """ + {"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"client","version":"1.0.0"}}} + """); + + AssertInitializeEnvelopeAndCapabilities(responseRoot, expectedId: 1, expectedProtocolVersion: "2025-11-25"); + JsonElement result = responseRoot.GetProperty("result"); + + Assert.IsTrue(result.TryGetProperty("serverInfo", out JsonElement serverInfo), "Expected result.serverInfo."); + Assert.AreEqual(DESCRIPTION, serverInfo.GetProperty("description").GetString()); + Assert.IsFalse(result.TryGetProperty("instructions", out _), "Did not expect top-level instructions for 2025-11-25."); + Assert.AreEqual(1, CountOutputLines(stdoutCapture)); + } + + [TestMethod] + public void HandleInitialize_ClientRequests2025_06_18_WithDescription_UsesTopLevelInstructions() + { + const string DESCRIPTION = "legacy instruction text"; + McpStdioServer server = CreateServer(description: DESCRIPTION, out StringWriter stdoutCapture); + + JsonElement responseRoot = InvokeHandleInitialize( + server, + stdoutCapture, + """ + {"jsonrpc":"2.0","id":"abc","method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"client","version":"1.0.0"}}} + """); + + AssertInitializeEnvelopeAndCapabilities(responseRoot, expectedId: "abc", expectedProtocolVersion: "2025-06-18"); + JsonElement result = responseRoot.GetProperty("result"); + + Assert.AreEqual(DESCRIPTION, result.GetProperty("instructions").GetString()); + Assert.IsFalse(result.GetProperty("serverInfo").TryGetProperty("description", out _), "Did not expect serverInfo.description for 2025-06-18."); + Assert.AreEqual(1, CountOutputLines(stdoutCapture)); + } + + [TestMethod] + public void HandleInitialize_ClientRequests2025_11_25_WithoutDescription_EmitsNeitherField() + { + McpStdioServer server = CreateServer(description: null, out StringWriter stdoutCapture); + + JsonElement responseRoot = InvokeHandleInitialize( + server, + stdoutCapture, + """ + {"jsonrpc":"2.0","id":2,"method":"initialize","params":{"protocolVersion":"2025-11-25","capabilities":{},"clientInfo":{"name":"client","version":"1.0.0"}}} + """); + + AssertInitializeEnvelopeAndCapabilities(responseRoot, expectedId: 2, expectedProtocolVersion: "2025-11-25"); + JsonElement result = responseRoot.GetProperty("result"); + + Assert.IsFalse(result.TryGetProperty("instructions", out _), "Did not expect top-level instructions when description is not configured."); + Assert.IsFalse(result.GetProperty("serverInfo").TryGetProperty("description", out _), "Did not expect serverInfo.description when description is not configured."); + Assert.AreEqual(1, CountOutputLines(stdoutCapture)); + } + + private static McpStdioServer CreateServer(string? description, out StringWriter stdoutCapture) + { + stdoutCapture = new StringWriter(); + McpStdoutWriter stdoutWriter = new(stdoutCapture); + + RuntimeConfig runtimeConfig = new( + Schema: RuntimeConfig.DEFAULT_CONFIG_SCHEMA_LINK, + DataSource: null, + Entities: new RuntimeEntities(new Dictionary()), + Runtime: new RuntimeOptions( + Rest: null, + GraphQL: null, + Mcp: new McpRuntimeOptions(Description: description), + Host: null)); + RuntimeConfigProvider runtimeConfigProvider = new StubRuntimeConfigProvider(runtimeConfig); + + IConfiguration configuration = new ConfigurationBuilder().Build(); + ServiceProvider serviceProvider = new ServiceCollection() + .AddSingleton(configuration) + .AddSingleton(stdoutWriter) + .AddSingleton(runtimeConfigProvider) + .BuildServiceProvider(); + + return new McpStdioServer(new McpToolRegistry(), serviceProvider); + } + + private static JsonElement InvokeHandleInitialize(McpStdioServer server, StringWriter stdoutCapture, string initializeRequestJson) + { + MethodInfo? handleInitialize = typeof(McpStdioServer).GetMethod("HandleInitialize", BindingFlags.NonPublic | BindingFlags.Instance); + Assert.IsNotNull(handleInitialize, "Expected private HandleInitialize method to exist."); + + using JsonDocument request = JsonDocument.Parse(initializeRequestJson); + JsonElement requestRoot = request.RootElement; + JsonElement? id = requestRoot.TryGetProperty("id", out JsonElement idElement) ? idElement : null; + + handleInitialize.Invoke(server, new object?[] { id, requestRoot }); + + string output = ExtractSingleOutputLine(stdoutCapture); + using JsonDocument response = JsonDocument.Parse(output); + return response.RootElement.Clone(); + } + + private static void AssertInitializeEnvelopeAndCapabilities(JsonElement responseRoot, object expectedId, string expectedProtocolVersion) + { + Assert.AreEqual("2.0", responseRoot.GetProperty("jsonrpc").GetString()); + if (expectedId is int expectedNumericId) + { + Assert.AreEqual(expectedNumericId, responseRoot.GetProperty("id").GetInt32()); + } + else + { + Assert.AreEqual(expectedId, responseRoot.GetProperty("id").GetString()); + } + + JsonElement result = responseRoot.GetProperty("result"); + Assert.AreEqual(expectedProtocolVersion, result.GetProperty("protocolVersion").GetString()); + + JsonElement capabilities = result.GetProperty("capabilities"); + Assert.IsTrue(capabilities.GetProperty("tools").GetProperty("listChanged").GetBoolean()); + Assert.AreEqual(JsonValueKind.Object, capabilities.GetProperty("logging").ValueKind); + + JsonElement serverInfo = result.GetProperty("serverInfo"); + Assert.AreEqual(McpProtocolDefaults.MCP_SERVER_NAME, serverInfo.GetProperty("name").GetString()); + Assert.AreEqual(McpProtocolDefaults.MCP_SERVER_VERSION, serverInfo.GetProperty("version").GetString()); + } + + private static int CountOutputLines(StringWriter stdoutCapture) + { + return stdoutCapture + .ToString() + .Split(Environment.NewLine, StringSplitOptions.RemoveEmptyEntries) + .Length; + } + + private static string ExtractSingleOutputLine(StringWriter stdoutCapture) + { + string[] lines = stdoutCapture + .ToString() + .Split(Environment.NewLine, StringSplitOptions.RemoveEmptyEntries); + Assert.AreEqual(1, lines.Length, "Expected a single JSON-RPC response line."); + return lines[0]; + } + + private sealed class StubRuntimeConfigProvider : RuntimeConfigProvider + { + private readonly RuntimeConfig _runtimeConfig; + + public StubRuntimeConfigProvider(RuntimeConfig runtimeConfig) : base(new StubRuntimeConfigLoader()) + { + _runtimeConfig = runtimeConfig; + } + + public override RuntimeConfig GetConfig() + { + return _runtimeConfig; + } + } + + private sealed class StubRuntimeConfigLoader : RuntimeConfigLoader + { + public override bool TryLoadKnownConfig([NotNullWhen(true)] out RuntimeConfig? config, bool replaceEnvVar = false) + { + config = null; + return false; + } + + public override string GetPublishedDraftSchemaLink() + { + return RuntimeConfig.DEFAULT_CONFIG_SCHEMA_LINK; + } + } + } +}