diff --git a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs index bfe33dbda368..3c828ca4b27b 100644 --- a/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs +++ b/extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs @@ -18,8 +18,11 @@ using Amazon.Runtime.Internal.Util; using Microsoft.Extensions.AI; using System; +using System.Buffers; using System.Collections.Generic; using System.Diagnostics; +using System.Globalization; +using System.IO; using System.Linq; using System.Runtime.CompilerServices; using System.Text; @@ -35,6 +38,11 @@ internal sealed partial class BedrockChatClient : IChatClient /// A default logger to use. private static readonly ILogger DefaultLogger = Logger.GetLogger(typeof(BedrockChatClient)); + /// The name used for the synthetic tool that enforces response format. + private const string ResponseFormatToolName = "generate_response"; + /// The description used for the synthetic tool that enforces response format. + private const string ResponseFormatToolDescription = "Generate response in specified format"; + /// The wrapped instance. private readonly IAmazonBedrockRuntime _runtime; /// Default model ID to use when no model is specified in the request. @@ -63,6 +71,12 @@ public void Dispose() } /// + /// + /// When is specified, the model must support + /// the ToolChoice feature. Models without this support will throw . + /// If the model fails to return the expected structured output, + /// is thrown. + /// public async Task GetResponseAsync( IEnumerable messages, ChatOptions? options = null, CancellationToken cancellationToken = default) { @@ -79,7 +93,29 @@ public async Task GetResponseAsync( request.InferenceConfig = CreateInferenceConfiguration(request.InferenceConfig, options); request.AdditionalModelRequestFields = CreateAdditionalModelRequestFields(request.AdditionalModelRequestFields, options); - var response = await _runtime.ConverseAsync(request, cancellationToken).ConfigureAwait(false); + ConverseResponse response; + try + { + response = await _runtime.ConverseAsync(request, cancellationToken).ConfigureAwait(false); + } + // Transforms ValidationException to NotSupportedException when error message indicates model lacks tool use support (required for ResponseFormat). + // This detection relies on error message text which may change in future Bedrock API versions. + catch (AmazonBedrockRuntimeException ex) when (options?.ResponseFormat is ChatResponseFormatJson) + { + // Detect unsupported model: ValidationException with specific tool support error messages + if (ex.ErrorCode == "ValidationException" && + (ex.Message.IndexOf("toolChoice is not supported by this model", StringComparison.OrdinalIgnoreCase) >= 0 || + ex.Message.IndexOf("This model doesn't support tool use", StringComparison.OrdinalIgnoreCase) >= 0)) + { + throw new NotSupportedException( + $"The model '{request.ModelId}' does not support ResponseFormat. " + + $"ResponseFormat requires ToolChoice support, which is only available in Claude 3+ and Mistral Large models. " + + $"See: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference-supported-models-features.html", + ex); + } + + throw; + } ChatMessage result = new() { @@ -89,6 +125,42 @@ public async Task GetResponseAsync( MessageId = Guid.NewGuid().ToString("N"), }; + // Check if ResponseFormat was used and extract structured content + bool usingResponseFormat = options?.ResponseFormat is ChatResponseFormatJson; + if (usingResponseFormat) + { + var structuredContent = ExtractResponseFormatContent(response.Output?.Message); + if (structuredContent is not null) + { + // Replace the content with the extracted JSON as a TextContent + result.Contents.Add(new TextContent(structuredContent) { RawRepresentation = response.Output?.Message }); + + // Skip normal content processing since we've extracted the structured response + if (DocumentToDictionary(response.AdditionalModelResponseFields) is { } responseFieldsDict) + { + result.AdditionalProperties = new(responseFieldsDict); + } + + return new(result) + { + CreatedAt = result.CreatedAt, + FinishReason = response.StopReason is not null ? GetChatFinishReason(response.StopReason) : null, + Usage = response.Usage is TokenUsage tokenUsage ? CreateUsageDetails(tokenUsage) : null, + RawRepresentation = response, + }; + } + else + { + // Model succeeded but did not return expected structured output + throw new InvalidOperationException( + $"Model '{request.ModelId}' did not return structured output as requested. " + + $"This may indicate the model refused to follow the tool use instruction, " + + $"the schema was too complex, or the prompt conflicted with the requirement. " + + $"StopReason: {response.StopReason?.Value ?? "unknown"}."); + } + } + + // Normal content processing when not using ResponseFormat or extraction failed if (response.Output?.Message?.Content is { } contents) { foreach (var content in contents) @@ -182,6 +254,14 @@ public async IAsyncEnumerable GetStreamingResponseAsync( throw new ArgumentNullException(nameof(messages)); } + // Check if ResponseFormat is set - not supported for streaming yet + if (options?.ResponseFormat is ChatResponseFormatJson) + { + throw new NotSupportedException( + "ResponseFormat is not yet supported for streaming responses with Amazon Bedrock. " + + "Please use GetResponseAsync for structured output."); + } + ConverseStreamRequest request = options?.RawRepresentationFactory?.Invoke(this) as ConverseStreamRequest ?? new(); request.ModelId ??= options?.ModelId ?? _modelId; request.Messages = CreateMessages(request.Messages, messages); @@ -794,7 +874,11 @@ private static Document ToDocument(JsonElement json) } } - /// Creates an from the specified options. + /// Creates a from the specified options. + /// + /// When ResponseFormat is specified, creates a synthetic tool to enforce structured output. + /// This conflicts with user-provided tools as Bedrock only supports a single ToolChoice value. + /// private static ToolConfiguration? CreateToolConfig(ToolConfiguration? toolConfig, ChatOptions? options) { if (options?.Tools is { Count: > 0 } tools) @@ -857,6 +941,56 @@ private static Document ToDocument(JsonElement json) } } + // Handle ResponseFormat by creating a synthetic tool + if (options?.ResponseFormat is ChatResponseFormatJson jsonFormat) + { + // Check for conflict with user-provided tools + if (toolConfig?.Tools?.Count > 0) + { + throw new ArgumentException( + "ResponseFormat cannot be used with Tools in Amazon Bedrock. " + + "ResponseFormat uses Bedrock's tool mechanism for structured output, " + + "which conflicts with user-provided tools."); + } + + // Create the synthetic tool with the schema from ResponseFormat + toolConfig ??= new(); + toolConfig.Tools ??= []; + + // Parse the schema if provided, otherwise create an empty object schema + Document schemaDoc; + if (jsonFormat.Schema.HasValue) + { + // Schema is already a JsonElement (parsed JSON), convert directly to Document + schemaDoc = ToDocument(jsonFormat.Schema.Value); + } + else + { + // For JSON mode without schema, create a generic object schema + schemaDoc = new Document(new Dictionary + { + ["type"] = new Document("object"), + ["additionalProperties"] = new Document(true) + }); + } + + toolConfig.Tools.Add(new Tool + { + ToolSpec = new ToolSpecification + { + Name = ResponseFormatToolName, + Description = jsonFormat.SchemaDescription ?? ResponseFormatToolDescription, + InputSchema = new ToolInputSchema + { + Json = schemaDoc + } + } + }); + + // Force the model to use the synthetic tool + toolConfig.ToolChoice = new ToolChoice { Tool = new() { Name = ResponseFormatToolName } }; + } + if (toolConfig?.Tools is { Count: > 0 } && toolConfig.ToolChoice is null) { switch (options!.ToolMode) @@ -872,6 +1006,43 @@ private static Document ToDocument(JsonElement json) return toolConfig; } + /// Extracts JSON content from the synthetic ResponseFormat tool use, if present. + private static string? ExtractResponseFormatContent(Message? message) + { + if (message?.Content is null) + { + return null; + } + + foreach (var content in message.Content) + { + if (content.ToolUse is ToolUseBlock toolUse && + toolUse.Name == ResponseFormatToolName && + toolUse.Input != default) + { + // Convert the Document back to JSON string + return DocumentToJsonString(toolUse.Input); + } + } + + return null; + } + + /// + /// Converts a to a JSON string using the SDK's standard DocumentMarshaller. + /// Note: Document is a struct (value type), so circular references are structurally impossible. + /// + private static string DocumentToJsonString(Document document) + { + using var stream = new MemoryStream(); + using (var writer = new Utf8JsonWriter(stream, new JsonWriterOptions { Indented = false })) + { + Amazon.Runtime.Documents.Internal.Transform.DocumentMarshaller.Instance.Write(writer, document); + } + return Encoding.UTF8.GetString(stream.ToArray()); + } + + /// Creates an from the specified options. private static InferenceConfiguration CreateInferenceConfiguration(InferenceConfiguration config, ChatOptions? options) { diff --git a/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs b/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs index 8f5099c973d8..4b180da27c7c 100644 --- a/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs +++ b/extensions/test/BedrockMEAITests/BedrockChatClientTests.cs @@ -1,11 +1,44 @@ -using Microsoft.Extensions.AI; +using Amazon.BedrockRuntime.Model; +using Amazon.Runtime; +using Amazon.Runtime.Documents; +using Amazon.Runtime.Internal; +using Amazon.Runtime.Internal.Transform; +using Microsoft.Extensions.AI; +using Moq; using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Reflection; +using System.Text; +using System.Text.Json; +using System.Threading; +using System.Threading.Tasks; using Xunit; namespace Amazon.BedrockRuntime; +// Simple test implementation of AIFunctionDeclaration +internal sealed class TestAIFunction : AIFunctionDeclaration +{ + public TestAIFunction(string name, string description, JsonElement jsonSchema) + { + Name = name; + Description = description; + JsonSchema = jsonSchema; + } + + public override string Name { get; } + public override string Description { get; } + public override JsonElement JsonSchema { get; } +} + public class BedrockChatClientTests { + #region Basic Client Tests + [Fact] [Trait("UnitTest", "BedrockRuntime")] public void AsIChatClient_InvalidArguments_Throws() @@ -19,8 +52,8 @@ public void AsIChatClient_InvalidArguments_Throws() [InlineData("claude")] public void AsIChatClient_ReturnsInstance(string modelId) { - IAmazonBedrockRuntime runtime = new AmazonBedrockRuntimeClient("awsAccessKeyId", "awsSecretAccessKey", RegionEndpoint.USEast1); - IChatClient client = runtime.AsIChatClient(modelId); + var mockRuntime = new Mock(); + IChatClient client = mockRuntime.Object.AsIChatClient(modelId); Assert.NotNull(client); Assert.Equal("aws.bedrock", client.GetService()?.ProviderName); @@ -31,17 +64,1112 @@ public void AsIChatClient_ReturnsInstance(string modelId) [Trait("UnitTest", "BedrockRuntime")] public void AsIChatClient_GetService() { - IAmazonBedrockRuntime runtime = new AmazonBedrockRuntimeClient("awsAccessKeyId", "awsSecretAccessKey", RegionEndpoint.USEast1); - IChatClient client = runtime.AsIChatClient(); + var mockRuntime = new Mock(); + IChatClient client = mockRuntime.Object.AsIChatClient(); - Assert.Same(runtime, client.GetService()); - Assert.Same(runtime, client.GetService()); + Assert.Same(mockRuntime.Object, client.GetService()); Assert.Same(client, client.GetService()); - Assert.Null(client.GetService()); - - Assert.Null(client.GetService("key")); Assert.Null(client.GetService("key")); - Assert.Null(client.GetService("key")); } + + #endregion + + #region ResponseFormat Tests + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WithSchema_CreatesSyntheticToolWithCorrectSchema() + { + // Arrange + var mockRuntime = new Mock(); + ConverseRequest capturedRequest = null; + + mockRuntime + .Setup(x => x.ConverseAsync(It.IsAny(), It.IsAny())) + .Callback((req, ct) => capturedRequest = req) + .ReturnsAsync(new ConverseResponse + { + Output = new ConverseOutput + { + Message = new Message + { + Role = ConversationRole.Assistant, + Content = new List + { + new ContentBlock + { + ToolUse = new ToolUseBlock + { + ToolUseId = "test-id", + Name = "generate_response", + Input = new Document(new Dictionary + { + ["name"] = new Document("John Doe"), + ["age"] = new Document(30) + }) + } + } + } + } + }, + StopReason = new StopReason("tool_use") + }); + + var client = mockRuntime.Object.AsIChatClient("claude-3"); + var messages = new[] { new ChatMessage(ChatRole.User, "Test") }; + + var schemaJson = """ + { + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "number" } + }, + "required": ["name"] + } + """; + var schemaElement = JsonDocument.Parse(schemaJson).RootElement; + var options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.ForJsonSchema(schemaElement, + schemaName: "PersonSchema", + schemaDescription: "A person object") + }; + + // Act + await client.GetResponseAsync(messages, options); + + // Assert + Assert.NotNull(capturedRequest); + var tool = capturedRequest.ToolConfig.Tools[0]; + Assert.Equal("generate_response", tool.ToolSpec.Name); + Assert.Equal("A person object", tool.ToolSpec.Description); + + // Verify schema structure matches input + var schema = tool.ToolSpec.InputSchema.Json; + Assert.True(schema.IsDictionary()); + var schemaDict = schema.AsDictionary(); + + Assert.Equal("object", schemaDict["type"].AsString()); + Assert.True(schemaDict.ContainsKey("properties")); + + var properties = schemaDict["properties"].AsDictionary(); + Assert.True(properties.ContainsKey("name")); + Assert.True(properties.ContainsKey("age")); + Assert.Equal("string", properties["name"].AsDictionary()["type"].AsString()); + Assert.Equal("number", properties["age"].AsDictionary()["type"].AsString()); + + Assert.True(schemaDict.ContainsKey("required")); + var required = schemaDict["required"].AsList(); + Assert.Single(required); + Assert.Equal("name", required[0].AsString()); + + // Verify the mock was called + mockRuntime.Verify(x => x.ConverseAsync(It.IsAny(), It.IsAny()), Times.Once); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_ModelReturnsToolUse_ExtractsJsonCorrectly() + { + // Arrange + var mockRuntime = new Mock(); + + // Setup mock to return tool use with structured data + mockRuntime + .Setup(x => x.ConverseAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new ConverseResponse + { + Output = new ConverseOutput + { + Message = new Message + { + Role = ConversationRole.Assistant, + Content = new List + { + new ContentBlock + { + ToolUse = new ToolUseBlock + { + ToolUseId = "test-id", + Name = "generate_response", + Input = new Document(new Dictionary + { + ["city"] = new Document("Seattle"), + ["temperature"] = new Document(72), + ["conditions"] = new Document("sunny") + }) + } + } + } + } + }, + StopReason = new StopReason("tool_use"), + Usage = new TokenUsage { InputTokens = 10, OutputTokens = 20, TotalTokens = 30 } + }); + + var client = mockRuntime.Object.AsIChatClient("claude-3"); + var messages = new[] { new ChatMessage(ChatRole.User, "Get weather") }; + var options = new ChatOptions { ResponseFormat = ChatResponseFormat.Json }; + + // Act + var response = await client.GetResponseAsync(messages, options); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Text); + + // Parse the JSON to verify structure + var json = JsonDocument.Parse(response.Text); + Assert.Equal("Seattle", json.RootElement.GetProperty("city").GetString()); + Assert.Equal(72, json.RootElement.GetProperty("temperature").GetInt32()); + Assert.Equal("sunny", json.RootElement.GetProperty("conditions").GetString()); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WithTools_ThrowsArgumentException() + { + // Arrange + var mockRuntime = new Mock(); + var client = mockRuntime.Object.AsIChatClient("claude-3"); + var messages = new[] { new ChatMessage(ChatRole.User, "Test") }; + + // Create test tool + var tool = new TestAIFunction("test", "Test tool", JsonDocument.Parse("{}").RootElement); + + var options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.Json, + Tools = new[] { tool } + }; + + // Act & Assert + await Assert.ThrowsAsync(async () => + await client.GetResponseAsync(messages, options)); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_UnsupportedModel_ThrowsNotSupportedException() + { + // Arrange + var mockRuntime = new Mock(); + + // Setup mock to throw BedrockRuntimeException with toolChoice error + mockRuntime + .Setup(x => x.ConverseAsync(It.IsAny(), It.IsAny())) + .ThrowsAsync(new AmazonBedrockRuntimeException("ValidationException: toolChoice is not supported by this model") + { + ErrorCode = "ValidationException" + }); + + var client = mockRuntime.Object.AsIChatClient("titan"); + var messages = new[] { new ChatMessage(ChatRole.User, "Test") }; + var options = new ChatOptions { ResponseFormat = ChatResponseFormat.Json }; + + // Act & Assert + var ex = await Assert.ThrowsAsync(async () => + await client.GetResponseAsync(messages, options)); + + Assert.Contains("does not support ResponseFormat", ex.Message); + Assert.Contains("ToolChoice", ex.Message); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_ForStreaming_ThrowsNotSupportedException() + { + // Arrange + var mockRuntime = new Mock(); + var client = mockRuntime.Object.AsIChatClient("claude-3"); + var messages = new[] { new ChatMessage(ChatRole.User, "Test") }; + var options = new ChatOptions { ResponseFormat = ChatResponseFormat.Json }; + + // Act & Assert + await Assert.ThrowsAsync(async () => + { + await foreach (var update in client.GetStreamingResponseAsync(messages, options)) + { + // Should not reach here + } + }); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_ModelReturnsText_ThrowsInvalidOperationException() + { + // Arrange - Model returns text instead of tool_use + var mockRuntime = new Mock(); + + mockRuntime + .Setup(x => x.ConverseAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new ConverseResponse + { + Output = new ConverseOutput + { + Message = new Message + { + Role = ConversationRole.Assistant, + Content = new List + { + new ContentBlock { Text = "Here is some text" } + } + } + }, + StopReason = new StopReason("end_turn") + }); + + var client = mockRuntime.Object.AsIChatClient("claude-3"); + var messages = new[] { new ChatMessage(ChatRole.User, "Generate data") }; + var options = new ChatOptions { ResponseFormat = ChatResponseFormat.Json }; + + // Act & Assert + var ex = await Assert.ThrowsAsync(async () => + await client.GetResponseAsync(messages, options)); + + Assert.Contains("did not return structured output", ex.Message); + Assert.Contains("end_turn", ex.Message); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WrongToolName_ThrowsInvalidOperationException() + { + // Arrange - Model uses wrong tool name + var mockRuntime = new Mock(); + + mockRuntime + .Setup(x => x.ConverseAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new ConverseResponse + { + Output = new ConverseOutput + { + Message = new Message + { + Role = ConversationRole.Assistant, + Content = new List + { + new ContentBlock + { + ToolUse = new ToolUseBlock + { + ToolUseId = "wrong-id", + Name = "wrong_tool_name", + Input = new Document(new Dictionary + { + ["data"] = new Document("value") + }) + } + } + } + } + }, + StopReason = new StopReason("tool_use") + }); + + var client = mockRuntime.Object.AsIChatClient("claude-3"); + var messages = new[] { new ChatMessage(ChatRole.User, "Generate data") }; + var options = new ChatOptions { ResponseFormat = ChatResponseFormat.Json }; + + // Act & Assert + var ex = await Assert.ThrowsAsync(async () => + await client.GetResponseAsync(messages, options)); + + Assert.Contains("did not return structured output", ex.Message); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_EmptyToolInput_ReturnsEmptyJson() + { + // Arrange - Tool with empty input + var mockRuntime = new Mock(); + + mockRuntime + .Setup(x => x.ConverseAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(new ConverseResponse + { + Output = new ConverseOutput + { + Message = new Message + { + Role = ConversationRole.Assistant, + Content = new List + { + new ContentBlock + { + ToolUse = new ToolUseBlock + { + ToolUseId = "empty-id", + Name = "generate_response", + Input = new Document(new Dictionary()) + } + } + } + } + }, + StopReason = new StopReason("tool_use") + }); + + var client = mockRuntime.Object.AsIChatClient("claude-3"); + var messages = new[] { new ChatMessage(ChatRole.User, "Generate data") }; + var options = new ChatOptions { ResponseFormat = ChatResponseFormat.Json }; + + // Act + var response = await client.GetResponseAsync(messages, options); + + // Assert - Empty object is valid JSON + Assert.NotNull(response.Text); + var json = JsonDocument.Parse(response.Text); + Assert.Equal(JsonValueKind.Object, json.RootElement.ValueKind); + } + + #endregion } + +/// +/// Tests using HTTP-layer mocking to test actual Converse API response scenarios. +/// This allows testing beyond the happy path with realistic service responses. +/// Based on Peter's (peterrsongg) suggestion to test different response structures. +/// +public class BedrockChatClientHttpMockedTests : IClassFixture +{ + private readonly HttpMockFixture _fixture; + + public BedrockChatClientHttpMockedTests(HttpMockFixture fixture) + { + _fixture = fixture; + } + + /// + /// Helper method to inject stubbed web response data into a request's state + /// + private static void InjectMockedResponse(ConverseRequest request, StubWebResponseData webResponseData) + { + var interfaceType = typeof(IAmazonWebServiceRequest); + var requestStatePropertyInfo = interfaceType.GetProperty("RequestState"); + var requestState = (Dictionary)requestStatePropertyInfo.GetValue(request); + requestState["response"] = webResponseData; + } + + #region HTTP Mocking Infrastructure (Based on Peter's Working Code) + + /// + /// Pipeline customizer that replaces the HTTP handler with a mock implementation + /// + private class MockPipelineCustomizer : IRuntimePipelineCustomizer + { + public string UniqueName => "BedrockMEAIMockPipeline"; + + public void Customize(Type type, RuntimePipeline pipeline) + { +#if BCL + // On .NET Framework, use Stream + pipeline.ReplaceHandler>( + new HttpHandler(new MockHttpRequestFactory(), new object())); +#else + // On .NET Core/.NET 5+, use HttpContent + pipeline.ReplaceHandler>( + new HttpHandler(new MockHttpRequestFactory(), new object())); +#endif + } + } + + /// + /// Factory for creating mock HTTP requests + /// +#if BCL + private class MockHttpRequestFactory : IHttpRequestFactory + { + public IHttpRequest CreateHttpRequest(Uri requestUri) + { + return new MockHttpRequest(requestUri); + } +#else + private class MockHttpRequestFactory : IHttpRequestFactory + { + public IHttpRequest CreateHttpRequest(Uri requestUri) + { + return new MockHttpRequest(requestUri); + } +#endif + + public void Dispose() + { + // No resources to dispose + } + } + + /// + /// Mock HTTP request that retrieves stubbed response data from request state + /// +#if BCL + private class MockHttpRequest : IHttpRequest +#else + private class MockHttpRequest : IHttpRequest +#endif + { + private IWebResponseData _webResponseData; + + public MockHttpRequest(Uri requestUri) + { + RequestUri = requestUri; + } + + public string Method { get; set; } + public Uri RequestUri { get; set; } + public Version HttpProtocolVersion { get; set; } + + public void ConfigureRequest(IRequestContext requestContext) + { + // Retrieve the stubbed response from request state + // This is the critical line that Peter identified (line 60 in his comment) + var request = requestContext.OriginalRequest as IAmazonWebServiceRequest; + if (request != null && request.RequestState.ContainsKey("response")) + { + _webResponseData = request.RequestState["response"] as IWebResponseData; + } + } + + public void SetRequestHeaders(IDictionary headers) + { + // Not needed for mock + } + +#if BCL + public Stream GetRequestContent() + { + return new MemoryStream(); + } +#else + public HttpContent GetRequestContent() + { + return null; + } +#endif + + public IWebResponseData GetResponse() + { + return GetResponseAsync(CancellationToken.None).Result; + } + + public Task GetResponseAsync(CancellationToken cancellationToken) + { + return Task.FromResult(_webResponseData); + } + +#if BCL + public void WriteToRequestBody(Stream requestContent, Stream contentStream, + IDictionary contentHeaders, IRequestContext requestContext) + { + // Not needed for mock + } + + public void WriteToRequestBody(Stream requestContent, byte[] content, + IDictionary contentHeaders) + { + // Not needed for mock + } + + public Task WriteToRequestBodyAsync(Stream requestContent, Stream contentStream, + IDictionary contentHeaders, IRequestContext requestContext) + { + return Task.CompletedTask; + } + + public Task WriteToRequestBodyAsync(Stream requestContent, byte[] content, + IDictionary contentHeaders, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } +#else + public void WriteToRequestBody(HttpContent requestContent, Stream contentStream, + IDictionary contentHeaders, IRequestContext requestContext) + { + // Not needed for mock + } + + public void WriteToRequestBody(HttpContent requestContent, byte[] content, + IDictionary contentHeaders) + { + // Not needed for mock + } + + public Task WriteToRequestBodyAsync(HttpContent requestContent, Stream contentStream, + IDictionary contentHeaders, IRequestContext requestContext) + { + return Task.CompletedTask; + } + + public Task WriteToRequestBodyAsync(HttpContent requestContent, byte[] content, + IDictionary contentHeaders, CancellationToken cancellationToken = default) + { + return Task.CompletedTask; + } +#endif + + public IHttpRequestStreamHandle SetupHttpRequestStreamPublisher( + IDictionary contentHeaders, IHttpRequestStreamPublisher publisher) + { + throw new NotImplementedException(); + } + + public void Abort() + { + // Not needed for mock + } + +#if BCL + public Task GetRequestContentAsync() + { + return Task.FromResult(new MemoryStream()); + } + + public Task GetRequestContentAsync(CancellationToken cancellationToken) + { + return Task.FromResult(new MemoryStream()); + } +#else + public Task GetRequestContentAsync() + { + return Task.FromResult(null); + } + + public Task GetRequestContentAsync(CancellationToken cancellationToken) + { + return Task.FromResult(null); + } +#endif + + public Stream SetupProgressListeners(Stream originalStream, long progressUpdateInterval, + object sender, EventHandler callback) + { + return originalStream; + } + + public void Dispose() + { + // Nothing to dispose + } + } + + /// + /// Stubbed web response data for testing different response scenarios + /// + private class StubWebResponseData : IWebResponseData + { + private readonly IHttpResponseBody _httpResponseBody; + + public StubWebResponseData(string jsonResponse, Dictionary headers = null, + HttpStatusCode statusCode = HttpStatusCode.OK) + { + StatusCode = statusCode; + IsSuccessStatusCode = (int)statusCode >= 200 && (int)statusCode < 300; + JsonResponse = jsonResponse; + Headers = headers ?? new Dictionary(StringComparer.OrdinalIgnoreCase); + ContentType = "application/json"; + ContentLength = jsonResponse?.Length ?? 0; + + _httpResponseBody = new HttpResponseBody(jsonResponse); + } + + public Dictionary Headers { get; set; } + public string JsonResponse { get; } + public long ContentLength { get; set; } + public string ContentType { get; set; } + public HttpStatusCode StatusCode { get; set; } + public bool IsSuccessStatusCode { get; set; } + + public IHttpResponseBody ResponseBody => _httpResponseBody; + + public string[] GetHeaderNames() + { + return Headers.Keys.ToArray(); + } + + public bool IsHeaderPresent(string headerName) + { + return Headers.ContainsKey(headerName); + } + + public string GetHeaderValue(string headerName) + { + return Headers.ContainsKey(headerName) ? Headers[headerName] : null; + } + } + + /// + /// HTTP response body implementation for stubbed responses + /// + private class HttpResponseBody : IHttpResponseBody + { + private readonly string _jsonResponse; + private Stream _stream; + + public HttpResponseBody(string jsonResponse) + { + _jsonResponse = jsonResponse ?? string.Empty; + } + + public void Dispose() + { + _stream?.Dispose(); + } + + public Stream OpenResponse() + { + _stream = new MemoryStream(Encoding.UTF8.GetBytes(_jsonResponse)); + return _stream; + } + + public Task OpenResponseAsync() + { + return Task.FromResult(OpenResponse()); + } + } + + #endregion + + #region ResponseFormat with HTTP Mocking Tests + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WithActualConverseResponse_ParsesCorrectly() + { + // Arrange - This is a real Converse API response with tool_use + var converseResponse = """ + { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_12345", + "name": "generate_response", + "input": { + "name": "Alice Johnson", + "age": 28, + "city": "Seattle" + } + } + } + ] + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 125, + "outputTokens": 45, + "totalTokens": 170 + } + } + """; + + var chatClient = _fixture.BedrockRuntimeClient.AsIChatClient("anthropic.claude-3-sonnet-20240229-v1:0"); + var messages = new[] { new ChatMessage(ChatRole.User, "Generate a person") }; + + var schemaJson = """ + { + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "number" }, + "city": { "type": "string" } + }, + "required": ["name", "age"] + } + """; + var schemaElement = JsonDocument.Parse(schemaJson).RootElement; + + var request = new ConverseRequest(); + var options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.ForJsonSchema(schemaElement, + schemaName: "PersonSchema", + schemaDescription: "A person with demographic information"), + RawRepresentationFactory = _ => request + }; + + // Inject the stubbed response + var webResponseData = new StubWebResponseData(converseResponse); + InjectMockedResponse(request, webResponseData); + + // Act + var response = await chatClient.GetResponseAsync(messages, options); + + // Assert + Assert.NotNull(response); + Assert.NotNull(response.Text); + + // Verify the JSON structure + var json = JsonDocument.Parse(response.Text); + Assert.Equal("Alice Johnson", json.RootElement.GetProperty("name").GetString()); + Assert.Equal(28, json.RootElement.GetProperty("age").GetInt32()); + Assert.Equal("Seattle", json.RootElement.GetProperty("city").GetString()); + + // Verify usage metadata + var usage = response.Usage; + Assert.NotNull(usage); + Assert.Equal(125, usage.InputTokenCount); + Assert.Equal(45, usage.OutputTokenCount); + Assert.Equal(170, usage.TotalTokenCount); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WithNestedObjects_ParsesCorrectly() + { + // Arrange - Test with nested JSON structure + var converseResponse = """ + { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_nested", + "name": "generate_response", + "input": { + "user": { + "name": "Bob Smith", + "contact": { + "email": "bob@example.com", + "phone": "555-0123" + } + }, + "metadata": { + "timestamp": "2024-01-15T10:30:00Z", + "version": 1 + } + } + } + } + ] + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 200, + "outputTokens": 80, + "totalTokens": 280 + } + } + """; + + var chatClient = _fixture.BedrockRuntimeClient.AsIChatClient("anthropic.claude-3-sonnet-20240229-v1:0"); + var messages = new[] { new ChatMessage(ChatRole.User, "Generate user data") }; + + var request = new ConverseRequest(); + var options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.Json, + RawRepresentationFactory = _ => request + }; + + var webResponseData = new StubWebResponseData(converseResponse); + InjectMockedResponse(request, webResponseData); + + // Act + var response = await chatClient.GetResponseAsync(messages, options); + + // Assert + Assert.NotNull(response.Text); + var json = JsonDocument.Parse(response.Text); + + var user = json.RootElement.GetProperty("user"); + Assert.Equal("Bob Smith", user.GetProperty("name").GetString()); + + var contact = user.GetProperty("contact"); + Assert.Equal("bob@example.com", contact.GetProperty("email").GetString()); + Assert.Equal("555-0123", contact.GetProperty("phone").GetString()); + + var metadata = json.RootElement.GetProperty("metadata"); + Assert.Equal("2024-01-15T10:30:00Z", metadata.GetProperty("timestamp").GetString()); + Assert.Equal(1, metadata.GetProperty("version").GetInt32()); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WithArrayData_ParsesCorrectly() + { + // Arrange - Test with arrays in JSON response + var converseResponse = """ + { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_array", + "name": "generate_response", + "input": { + "items": ["apple", "banana", "orange"], + "prices": [1.99, 0.99, 2.49], + "inventory": { + "warehouse": "A", + "quantities": [100, 250, 75] + } + } + } + } + ] + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 50, + "outputTokens": 30, + "totalTokens": 80 + } + } + """; + + var chatClient = _fixture.BedrockRuntimeClient.AsIChatClient("anthropic.claude-3-sonnet-20240229-v1:0"); + var messages = new[] { new ChatMessage(ChatRole.User, "List items") }; + + var request = new ConverseRequest(); + var options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.Json, + RawRepresentationFactory = _ => request + }; + + var webResponseData = new StubWebResponseData(converseResponse); + InjectMockedResponse(request, webResponseData); + + // Act + var response = await chatClient.GetResponseAsync(messages, options); + + // Assert + Assert.NotNull(response.Text); + var json = JsonDocument.Parse(response.Text); + + var items = json.RootElement.GetProperty("items"); + Assert.Equal(JsonValueKind.Array, items.ValueKind); + Assert.Equal(3, items.GetArrayLength()); + Assert.Equal("apple", items[0].GetString()); + Assert.Equal("banana", items[1].GetString()); + Assert.Equal("orange", items[2].GetString()); + + var prices = json.RootElement.GetProperty("prices"); + Assert.Equal(3, prices.GetArrayLength()); + Assert.Equal(1.99, prices[0].GetDouble(), precision: 2); + + var inventory = json.RootElement.GetProperty("inventory"); + var quantities = inventory.GetProperty("quantities"); + Assert.Equal(3, quantities.GetArrayLength()); + Assert.Equal(100, quantities[0].GetInt32()); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WithMinimalSchema_ParsesCorrectly() + { + // Arrange - Test simple JSON response + var converseResponse = """ + { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_simple", + "name": "generate_response", + "input": { + "message": "Hello, World!", + "status": "success" + } + } + } + ] + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 10, + "outputTokens": 5, + "totalTokens": 15 + } + } + """; + + var chatClient = _fixture.BedrockRuntimeClient.AsIChatClient("anthropic.claude-3-haiku-20240307-v1:0"); + var messages = new[] { new ChatMessage(ChatRole.User, "Say hello") }; + + var request = new ConverseRequest(); + var options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.Json, + RawRepresentationFactory = _ => request + }; + + var webResponseData = new StubWebResponseData(converseResponse); + InjectMockedResponse(request, webResponseData); + + // Act + var response = await chatClient.GetResponseAsync(messages, options); + + // Assert + Assert.NotNull(response.Text); + var json = JsonDocument.Parse(response.Text); + Assert.Equal("Hello, World!", json.RootElement.GetProperty("message").GetString()); + Assert.Equal("success", json.RootElement.GetProperty("status").GetString()); + } + + [Fact] + [Trait("UnitTest", "BedrockRuntime")] + public async Task ResponseFormat_Json_WithComplexSchema_ValidatesStructure() + { + // Arrange - Test with detailed schema validation + var converseResponse = """ + { + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_complex", + "name": "generate_response", + "input": { + "id": "usr_123", + "username": "testuser", + "email": "test@example.com", + "profile": { + "firstName": "Test", + "lastName": "User", + "age": 25, + "preferences": { + "theme": "dark", + "notifications": true + } + }, + "roles": ["admin", "user"], + "active": true + } + } + } + ] + } + }, + "stopReason": "tool_use", + "usage": { + "inputTokens": 300, + "outputTokens": 150, + "totalTokens": 450 + } + } + """; + + var chatClient = _fixture.BedrockRuntimeClient.AsIChatClient("anthropic.claude-3-sonnet-20240229-v1:0"); + var messages = new[] { new ChatMessage(ChatRole.User, "Generate user profile") }; + + var schemaJson = """ + { + "type": "object", + "properties": { + "id": { "type": "string" }, + "username": { "type": "string" }, + "email": { "type": "string", "format": "email" }, + "profile": { + "type": "object", + "properties": { + "firstName": { "type": "string" }, + "lastName": { "type": "string" }, + "age": { "type": "number" }, + "preferences": { "type": "object" } + }, + "required": ["firstName", "lastName"] + }, + "roles": { + "type": "array", + "items": { "type": "string" } + }, + "active": { "type": "boolean" } + }, + "required": ["id", "username", "email"] + } + """; + var schemaElement = JsonDocument.Parse(schemaJson).RootElement; + + var request = new ConverseRequest(); + var options = new ChatOptions + { + ResponseFormat = ChatResponseFormat.ForJsonSchema(schemaElement, + schemaName: "UserProfile", + schemaDescription: "Complete user profile with preferences"), + RawRepresentationFactory = _ => request + }; + + var webResponseData = new StubWebResponseData(converseResponse); + InjectMockedResponse(request, webResponseData); + + // Act + var response = await chatClient.GetResponseAsync(messages, options); + + // Assert + Assert.NotNull(response.Text); + var json = JsonDocument.Parse(response.Text); + + // Verify required fields + Assert.Equal("usr_123", json.RootElement.GetProperty("id").GetString()); + Assert.Equal("testuser", json.RootElement.GetProperty("username").GetString()); + Assert.Equal("test@example.com", json.RootElement.GetProperty("email").GetString()); + + // Verify nested profile + var profile = json.RootElement.GetProperty("profile"); + Assert.Equal("Test", profile.GetProperty("firstName").GetString()); + Assert.Equal("User", profile.GetProperty("lastName").GetString()); + Assert.Equal(25, profile.GetProperty("age").GetInt32()); + + // Verify nested preferences + var preferences = profile.GetProperty("preferences"); + Assert.Equal("dark", preferences.GetProperty("theme").GetString()); + Assert.True(preferences.GetProperty("notifications").GetBoolean()); + + // Verify array + var roles = json.RootElement.GetProperty("roles"); + Assert.Equal(2, roles.GetArrayLength()); + Assert.Equal("admin", roles[0].GetString()); + Assert.Equal("user", roles[1].GetString()); + + // Verify boolean + Assert.True(json.RootElement.GetProperty("active").GetBoolean()); + } + + #endregion + + /// + /// Test fixture that registers the HTTP mocking pipeline customizer + /// + public class HttpMockFixture : IDisposable + { + private readonly MockPipelineCustomizer _customizer; + + public HttpMockFixture() + { + // Register the mock pipeline customizer globally + _customizer = new MockPipelineCustomizer(); + Runtime.Internal.RuntimePipelineCustomizerRegistry.Instance.Register(_customizer); + + // Create the Bedrock Runtime client - it will use the mocked pipeline + BedrockRuntimeClient = new AmazonBedrockRuntimeClient(); + } + + public IAmazonBedrockRuntime BedrockRuntimeClient { get; private set; } + + public void Dispose() + { + // Clean up + Runtime.Internal.RuntimePipelineCustomizerRegistry.Instance.Deregister(_customizer); + BedrockRuntimeClient?.Dispose(); + } + } +} \ No newline at end of file diff --git a/extensions/test/BedrockMEAITests/BedrockMEAITests.NetFramework.csproj b/extensions/test/BedrockMEAITests/BedrockMEAITests.NetFramework.csproj index dd9de35ce4a5..1289d63cea0d 100644 --- a/extensions/test/BedrockMEAITests/BedrockMEAITests.NetFramework.csproj +++ b/extensions/test/BedrockMEAITests/BedrockMEAITests.NetFramework.csproj @@ -1,6 +1,7 @@  net472 + $(DefineConstants);BCL BedrockMEAITests BedrockMEAITests @@ -19,6 +20,7 @@ + diff --git a/generator/.DevConfigs/12b83a1f-1d6b-4e96-bd62-f0e0b7e4df6d.json b/generator/.DevConfigs/12b83a1f-1d6b-4e96-bd62-f0e0b7e4df6d.json new file mode 100644 index 000000000000..297c2483daf0 --- /dev/null +++ b/generator/.DevConfigs/12b83a1f-1d6b-4e96-bd62-f0e0b7e4df6d.json @@ -0,0 +1,11 @@ +{ + "extensions": [ + { + "extensionName": "Extensions.Bedrock.MEAI", + "type": "minor", + "changeLogMessages": [ + "Add support for ChatOptions.ResponseFormat to enable structured JSON responses using JSON Schema." + ] + } + ] +}