diff --git a/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs new file mode 100644 index 0000000000..db1c761d2f --- /dev/null +++ b/src/Azure.DataApiBuilder.Mcp/BuiltInTools/ReadRecordsTool.cs @@ -0,0 +1,422 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Data.Common; +using System.Text.Json; +using Azure.DataApiBuilder.Auth; +using Azure.DataApiBuilder.Config.DatabasePrimitives; +using Azure.DataApiBuilder.Config.ObjectModel; +using Azure.DataApiBuilder.Core.Authorization; +using Azure.DataApiBuilder.Core.Configurations; +using Azure.DataApiBuilder.Core.Models; +using Azure.DataApiBuilder.Core.Parsers; +using Azure.DataApiBuilder.Core.Resolvers; +using Azure.DataApiBuilder.Core.Resolvers.Factories; +using Azure.DataApiBuilder.Core.Services; +using Azure.DataApiBuilder.Core.Services.MetadataProviders; +using Azure.DataApiBuilder.Mcp.Model; +using Azure.DataApiBuilder.Service.Exceptions; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using ModelContextProtocol.Protocol; +using static Azure.DataApiBuilder.Mcp.Model.McpEnums; + +namespace Azure.DataApiBuilder.Mcp.BuiltInTools +{ + public class ReadRecordsTool : IMcpTool + { + public ToolType ToolType { get; } = ToolType.BuiltIn; + + public Tool GetToolMetadata() + { + return new Tool + { + Name = "read_records", + Description = "Retrieves records from a given entity.", + InputSchema = JsonSerializer.Deserialize( + @"{ + ""type"": ""object"", + ""properties"": { + ""entity"": { + ""type"": ""string"", + ""description"": ""The name of the entity to read, as provided by the describe_entities tool. Required."" + }, + ""select"": { + ""type"": ""string"", + ""description"": ""A comma-separated list of field names to include in the response. If omitted, all fields are returned. Optional."" + }, + ""filter"": { + ""type"": ""string"", + ""description"": ""A case-insensitive OData-like expression that defines a query predicate. Supports logical grouping with parentheses and the operators eq, ne, gt, ge, lt, le, and, or, not. Examples: year ge 1990, date lt 2025-01-01T00:00:00Z, (title eq 'Foundation') and (available ne false). Optional."" + }, + ""first"": { + ""type"": ""integer"", + ""description"": ""The maximum number of records to return in the current page. Optional."" + }, + ""orderby"": { + ""type"": ""array"", + ""items"": { ""type"": ""string"" }, + ""description"": ""A list of field names and directions for sorting, for example 'name asc' or 'year desc'. Optional."" + }, + ""after"": { + ""type"": ""string"", + ""description"": ""A cursor token for retrieving the next page of results. Returned as 'after' in the previous response. Optional."" + } + } + }" + ) + }; + } + + public async Task ExecuteAsync( + JsonDocument? arguments, + IServiceProvider serviceProvider, + CancellationToken cancellationToken = default) + { + ILogger? logger = serviceProvider.GetService>(); + + // Get runtime config + RuntimeConfigProvider runtimeConfigProvider = serviceProvider.GetRequiredService(); + RuntimeConfig runtimeConfig = runtimeConfigProvider.GetConfig(); + + if (runtimeConfig.McpDmlTools?.ReadRecords is not true) + { + return BuildErrorResult( + "ToolDisabled", + "The read_records tool is disabled in the configuration.", + logger); + } + + try + { + cancellationToken.ThrowIfCancellationRequested(); + + string entityName; + string? select = null; + string? filter = null; + int? first = null; + IEnumerable? orderby = null; + string? after = null; + + // Extract arguments + if (arguments == null) + { + return BuildErrorResult("InvalidArguments", "No arguments provided.", logger); + } + + JsonElement root = arguments.RootElement; + + if (!root.TryGetProperty("entity", out JsonElement entityElement) || string.IsNullOrWhiteSpace(entityElement.GetString())) + { + return BuildErrorResult("InvalidArguments", "Missing required argument 'entity'.", logger); + } + + entityName = entityElement.GetString()!; + + if (root.TryGetProperty("select", out JsonElement selectElement)) + { + select = selectElement.GetString(); + } + + if (root.TryGetProperty("filter", out JsonElement filterElement)) + { + filter = filterElement.GetString(); + } + + if (root.TryGetProperty("first", out JsonElement firstElement)) + { + first = firstElement.GetInt32(); + } + + if (root.TryGetProperty("orderby", out JsonElement orderbyElement)) + { + orderby = (IEnumerable?)orderbyElement.EnumerateArray().Select(e => e.GetString()); + } + + if (root.TryGetProperty("after", out JsonElement afterElement)) + { + after = afterElement.GetString(); + } + + // Get required services & configuration + IQueryEngineFactory queryEngineFactory = serviceProvider.GetRequiredService(); + IMetadataProviderFactory metadataProviderFactory = serviceProvider.GetRequiredService(); + + // Check metadata for entity exists + string dataSourceName; + ISqlMetadataProvider sqlMetadataProvider; + + try + { + dataSourceName = runtimeConfig.GetDataSourceNameFromEntityName(entityName); + sqlMetadataProvider = metadataProviderFactory.GetMetadataProvider(dataSourceName); + } + catch (Exception) + { + return BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + } + + if (!sqlMetadataProvider.EntityToDatabaseObject.TryGetValue(entityName, out DatabaseObject? dbObject) || dbObject is null) + { + return BuildErrorResult("EntityNotFound", $"Entity '{entityName}' is not defined in the configuration.", logger); + } + + // Authorization check in the existing entity + IAuthorizationResolver authResolver = serviceProvider.GetRequiredService(); + IAuthorizationService authorizationService = serviceProvider.GetRequiredService(); + IHttpContextAccessor httpContextAccessor = serviceProvider.GetRequiredService(); + HttpContext? httpContext = httpContextAccessor.HttpContext; + + if (httpContext is null || !authResolver.IsValidRoleContext(httpContext)) + { + return BuildErrorResult("PermissionDenied", $"You do not have permission to read records for entity '{entityName}'.", logger); + } + + if (!TryResolveAuthorizedRole(httpContext, authResolver, entityName, out string? effectiveRole, out string authError)) + { + return BuildErrorResult("PermissionDenied", authError, logger); + } + + // Build and validate Find context + RequestValidator requestValidator = new(metadataProviderFactory, runtimeConfigProvider); + FindRequestContext context = new(entityName, dbObject, true); + httpContext.Request.Method = "GET"; + + requestValidator.ValidateEntity(entityName); + + if (!string.IsNullOrWhiteSpace(select)) + { + // Update the context to specify which fields will be returned from the entity. + IEnumerable fieldsReturnedForFind = select.Split(",").ToList(); + context.UpdateReturnFields(fieldsReturnedForFind); + } + + if (!string.IsNullOrWhiteSpace(filter)) + { + string filterQueryString = $"?{RequestParser.FILTER_URL}={filter}"; + context.FilterClauseInUrl = sqlMetadataProvider.GetODataParser().GetFilterClause(filterQueryString, $"{context.EntityName}.{context.DatabaseObject.FullName}"); + } + + if (orderby is not null && orderby.Count() != 0) + { + string sortQueryString = $"?{RequestParser.SORT_URL}="; + foreach (string param in orderby) + { + if (string.IsNullOrWhiteSpace(param)) + { + return BuildErrorResult("InvalidArguments", "Parameters inside 'orderby' argument cannot be empty or null.", logger); + } + + sortQueryString += $"{param}, "; + } + + sortQueryString = sortQueryString.Substring(0, sortQueryString.Length - 2); + (context.OrderByClauseInUrl, context.OrderByClauseOfBackingColumns) = RequestParser.GenerateOrderByLists(context, sqlMetadataProvider, sortQueryString); + } + + context.First = first; + context.After = after; + + // The final authorization check on columns occurs after the request is fully parsed and validated. + requestValidator.ValidateRequestContext(context); + + AuthorizationResult authorizationResult = await authorizationService.AuthorizeAsync( + user: httpContext.User, + resource: context, + requirements: new[] { new ColumnsPermissionsRequirement() }); + if (!authorizationResult.Succeeded) + { + return BuildErrorResult("PermissionDenied", DataApiBuilderException.AUTHORIZATION_FAILURE, logger); + } + + // Execute + IQueryEngine queryEngine = queryEngineFactory.GetQueryEngine(sqlMetadataProvider.GetDatabaseType()); + JsonDocument? queryResult = await queryEngine.ExecuteAsync(context); + IActionResult actionResult = queryResult is null ? SqlResponseHelpers.FormatFindResult(JsonDocument.Parse("[]").RootElement.Clone(), context, metadataProviderFactory.GetMetadataProvider(dataSourceName), runtimeConfigProvider.GetConfig(), httpContext, true) + : SqlResponseHelpers.FormatFindResult(queryResult.RootElement.Clone(), context, metadataProviderFactory.GetMetadataProvider(dataSourceName), runtimeConfigProvider.GetConfig(), httpContext, true); + + // Normalize response + string rawPayloadJson = ExtractResultJson(actionResult); + JsonDocument result = JsonDocument.Parse(rawPayloadJson); + JsonElement queryRoot = result.RootElement; + + return BuildSuccessResult( + entityName, + queryRoot.Clone(), + logger); + } + catch (OperationCanceledException) + { + return BuildErrorResult("OperationCanceled", "The read operation was canceled.", logger); + } + catch (DbException argEx) + { + return BuildErrorResult("DatabaseOperationFailed", argEx.Message, logger); + } + catch (ArgumentException argEx) + { + return BuildErrorResult("InvalidArguments", argEx.Message, logger); + } + catch (DataApiBuilderException argEx) + { + return BuildErrorResult(argEx.StatusCode.ToString(), argEx.Message, logger); + } + catch (Exception) + { + return BuildErrorResult("UnexpectedError", "Unexpected error occurred in ReadRecordsTool.", logger); + } + } + + /// + /// Ensures that the role used on the request has the necessary authorizations. + /// + /// Contains request headers and metadata of the user. + /// Resolver used to check if role has necessary authorizations. + /// Name of the entity used in the request. + /// Role defined in client role header. + /// Error message given to the user. + /// True if the user role is authorized, along with the role. + private static bool TryResolveAuthorizedRole( + HttpContext httpContext, + IAuthorizationResolver authorizationResolver, + string entityName, + out string? effectiveRole, + out string error) + { + effectiveRole = null; + error = string.Empty; + + string roleHeader = httpContext.Request.Headers[AuthorizationResolver.CLIENT_ROLE_HEADER].ToString(); + + if (string.IsNullOrWhiteSpace(roleHeader)) + { + error = $"Client role header '{AuthorizationResolver.CLIENT_ROLE_HEADER}' is missing or empty."; + return false; + } + + string[] roles = roleHeader + .Split(',', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries) + .Distinct(StringComparer.OrdinalIgnoreCase) + .ToArray(); + + if (roles.Length == 0) + { + error = $"Client role header '{AuthorizationResolver.CLIENT_ROLE_HEADER}' is missing or empty."; + return false; + } + + foreach (string role in roles) + { + bool allowed = authorizationResolver.AreRoleAndOperationDefinedForEntity( + entityName, role, EntityActionOperation.Read); + + if (allowed) + { + effectiveRole = role; + return true; + } + } + + error = $"You do not have permission to read records for entity '{entityName}'."; + return false; + } + + /// + /// Returns a result from the query in the case that it was successfully ran. + /// + /// Name of the entity used in the request. + /// Query result from engine. + /// MCP logger that returns all logged events. + private static CallToolResult BuildSuccessResult( + string entityName, + JsonElement engineRootElement, + ILogger? logger) + { + // Build normalized response + Dictionary normalized = new() + { + ["status"] = "success", + ["result"] = engineRootElement // only requested values + }; + + string output = JsonSerializer.Serialize(normalized, new JsonSerializerOptions { WriteIndented = true }); + + logger?.LogInformation("ReadRecordsTool success for entity {Entity}.", entityName); + + return new CallToolResult + { + Content = new List + { + new TextContentBlock { Type = "text", Text = output } + } + }; + } + + /// + /// Returns an error if the query failed to run at any point. + /// + /// Type of error that is encountered. + /// Error message given to the user. + /// MCP logger that returns all logged events. + private static CallToolResult BuildErrorResult( + string errorType, + string message, + ILogger? logger) + { + Dictionary errorObj = new() + { + ["status"] = "error", + ["error"] = new Dictionary + { + ["type"] = errorType, + ["message"] = message + } + }; + + string output = JsonSerializer.Serialize(errorObj); + + logger?.LogError("ReadRecordsTool error {ErrorType}: {Message}", errorType, message); + + return new CallToolResult + { + Content = + [ + new TextContentBlock { Type = "text", Text = output } + ], + IsError = true + }; + } + + /// + /// Extracts a JSON string from a typical IActionResult. + /// Falls back to "{}" for unsupported/empty cases to avoid leaking internals. + /// + private static string ExtractResultJson(IActionResult? result) + { + switch (result) + { + case ObjectResult obj: + if (obj.Value is JsonElement je) + { + return je.GetRawText(); + } + + if (obj.Value is JsonDocument jd) + { + return jd.RootElement.GetRawText(); + } + + return JsonSerializer.Serialize(obj.Value ?? new object()); + + case ContentResult content: + return string.IsNullOrWhiteSpace(content.Content) ? "{}" : content.Content; + + default: + return "{}"; + } + } + } +} diff --git a/src/Core/Parsers/FilterParser.cs b/src/Core/Parsers/FilterParser.cs index ec765e26a6..c9cfc1eb53 100644 --- a/src/Core/Parsers/FilterParser.cs +++ b/src/Core/Parsers/FilterParser.cs @@ -44,7 +44,6 @@ public FilterClause GetFilterClause(string filterQueryString, string resourcePat { if (_model is null) { - throw new DataApiBuilderException( message: "The runtime has not been initialized with an Edm model.", statusCode: HttpStatusCode.InternalServerError, diff --git a/src/Core/Parsers/RequestParser.cs b/src/Core/Parsers/RequestParser.cs index bb4dd8d51e..6402ce4ecb 100644 --- a/src/Core/Parsers/RequestParser.cs +++ b/src/Core/Parsers/RequestParser.cs @@ -147,7 +147,7 @@ public static void ParseQueryString(RestRequestContext context, ISqlMetadataProv /// associated with the sort param. /// A List /// - private static (List?, List?) GenerateOrderByLists(RestRequestContext context, + public static (List?, List?) GenerateOrderByLists(RestRequestContext context, ISqlMetadataProvider sqlMetadataProvider, string sortQueryString) { diff --git a/src/Core/Resolvers/SqlResponseHelpers.cs b/src/Core/Resolvers/SqlResponseHelpers.cs index 8b0a0edb67..d0bf768281 100644 --- a/src/Core/Resolvers/SqlResponseHelpers.cs +++ b/src/Core/Resolvers/SqlResponseHelpers.cs @@ -23,21 +23,23 @@ public class SqlResponseHelpers /// /// Format the results from a Find operation. Check if there is a requirement - /// for a nextLink, and if so, add this value to the array of JsonElements to + /// for a nextLink/after, and if so, add this value to the array of JsonElements to /// be used as part of the response. /// /// The JsonDocument from the query. /// The RequestContext. - /// the metadataprovider. + /// The metadataprovider. /// Runtimeconfig object /// HTTP context associated with the API request + /// True if request is done through MCP endpoint /// An OkObjectResult from a Find operation that has been correctly formatted. public static OkObjectResult FormatFindResult( JsonElement findOperationResponse, FindRequestContext context, ISqlMetadataProvider sqlMetadataProvider, RuntimeConfig runtimeConfig, - HttpContext httpContext) + HttpContext httpContext, + bool? isMcpRequest = null) { // When there are no rows returned from the database, the jsonElement will be an empty array. @@ -55,7 +57,7 @@ public static OkObjectResult FormatFindResult( uint maxPageSize = runtimeConfig.MaxPageSize(); // If the results are not a collection or if the query does not have a next page - // no nextLink is needed. So, the response is returned after removing the extra fields. + // no nextLink/after is needed. So, the response is returned after removing the extra fields. if (findOperationResponse.ValueKind is not JsonValueKind.Array || !SqlPaginationUtil.HasNext(findOperationResponse, context.First, defaultPageSize, maxPageSize)) { // If there are no additional fields present, the response is returned directly. When there @@ -89,27 +91,43 @@ public static OkObjectResult FormatFindResult( tableName: context.DatabaseObject.Name, sqlMetadataProvider: sqlMetadataProvider); - string basePaginationUri = SqlPaginationUtil.ConstructBaseUriForPagination(httpContext, runtimeConfig.Runtime?.BaseRoute); - - // Build the query string with the $after token. - string queryString = SqlPaginationUtil.BuildQueryStringWithAfterToken( - queryStringParameters: context!.ParsedQueryString, - newAfterPayload: after); - - // Get the final consolidated nextLink for the pagination. - JsonElement nextLink = SqlPaginationUtil.GetConsolidatedNextLinkForPagination( - baseUri: basePaginationUri, - queryString: queryString, - isNextLinkRelative: runtimeConfig.NextLinkRelative()); - // When there are extra fields present, they are removed before returning the response. if (extraFieldsInResponse.Count > 0) { rootEnumerated = RemoveExtraFieldsInResponseWithMultipleItems(rootEnumerated, extraFieldsInResponse); } - rootEnumerated.Add(nextLink); - return OkResponse(JsonSerializer.SerializeToElement(rootEnumerated)); + // Create an 'after' object if the request comes from MCP endpoint. + if (isMcpRequest is true) + { + string jsonString = JsonSerializer.Serialize(new[] + { + new { after = after } + }); + JsonElement afterElement = JsonSerializer.Deserialize(jsonString); + + rootEnumerated.Add(afterElement); + } + // Create a 'nextLink' object if the request comes from REST endpoint. + else + { + string basePaginationUri = SqlPaginationUtil.ConstructBaseUriForPagination(httpContext, runtimeConfig.Runtime?.BaseRoute); + + // Build the query string with the $after token. + string queryString = SqlPaginationUtil.BuildQueryStringWithAfterToken( + queryStringParameters: context!.ParsedQueryString, + newAfterPayload: after); + + // Get the final consolidated nextLink for the pagination. + JsonElement nextLink = SqlPaginationUtil.GetConsolidatedNextLinkForPagination( + baseUri: basePaginationUri, + queryString: queryString, + isNextLinkRelative: runtimeConfig.NextLinkRelative()); + + rootEnumerated.Add(nextLink); + } + + return OkResponse(JsonSerializer.SerializeToElement(rootEnumerated), isMcpRequest); } /// @@ -186,8 +204,9 @@ private static JsonElement RemoveExtraFieldsInResponseWithSingleItem(JsonElement /// form that complies with vNext Api guidelines. /// /// Value representing the Json results of the client's request. + /// True if request is done through MCP endpoint. /// Correctly formatted OkObjectResult. - public static OkObjectResult OkResponse(JsonElement jsonResult) + public static OkObjectResult OkResponse(JsonElement jsonResult, bool? isMcpRequest = null) { // For consistency we return all values as type Array if (jsonResult.ValueKind != JsonValueKind.Array) @@ -200,20 +219,34 @@ public static OkObjectResult OkResponse(JsonElement jsonResult) // More than 0 records, and the last element is of type array, then we have pagination if (resultEnumerated.Count > 0 && resultEnumerated[resultEnumerated.Count - 1].ValueKind == JsonValueKind.Array) { - // Get the nextLink + // Get the 'nextLink' or 'after' // resultEnumerated will be an array of the form - // [{object1}, {object2},...{objectlimit}, [{nextLinkObject}]] - // if the last element is of type array, we know it is nextLink - // we strip the "[" and "]" and then save the nextLink element - // into a dictionary with a key of "nextLink" and a value that - // represents the nextLink data we require. - string nextLinkJsonString = JsonSerializer.Serialize(resultEnumerated[resultEnumerated.Count - 1]); - Dictionary nextLink = JsonSerializer.Deserialize>(nextLinkJsonString[1..^1])!; + // [{object1}, {object2},...{objectlimit}, [{nextLinkObject/afterObject}]] + // if the last element is of type array, we know it is 'nextLink' + // if the request is done through the REST endpoint and it is + // 'after' if the request is done through the MCP endpoint, + // we strip the "[" and "]" and then save the element + // into a dictionary with a key of "nextLinkAfter" and a value that + // represents the nextLink/after data we require. + string nextLinkAfterJsonString = JsonSerializer.Serialize(resultEnumerated[resultEnumerated.Count - 1]); + Dictionary nextLinkAfter = JsonSerializer.Deserialize>(nextLinkAfterJsonString[1..^1])!; IEnumerable value = resultEnumerated.Take(resultEnumerated.Count - 1); + + // Check 'after' object if request is done through MCP endpoint. + if (isMcpRequest is true) + { + return new OkObjectResult(new + { + value = value, + after = nextLinkAfter["after"] + }); + } + + // Check 'nextLink' object if request is done through REST endpoint. return new OkObjectResult(new { value = value, - @nextLink = nextLink["nextLink"] + @nextLink = nextLinkAfter["nextLink"] }); }