From 073cb8f7b797b5408615cf6c955c8d8a7f263baf Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Wed, 2 Apr 2025 20:25:15 -0700 Subject: [PATCH 01/21] migrate to azure open ai --- .../Assistants/ChatMessage.cs | 7 -- .../AzureAISearchProvider.cs | 6 +- .../CosmosDBSearchProvider.cs | 7 +- .../KustoSearchProvider.cs | 6 +- .../Assistants/AssistantRuntimeState.cs | 4 +- .../Assistants/AssistantService.cs | 98 +++++++++-------- .../AssistantSkillTriggerBindingProvider.cs | 4 +- .../ChatCompletionsJsonConverter.cs | 10 +- .../Assistants/IAssistantSkillInvoker.cs | 42 ++++---- .../Embeddings/EmbeddingsContext.cs | 13 ++- .../Embeddings/EmbeddingsContextConverter.cs | 8 +- .../Embeddings/EmbeddingsConverter.cs | 22 ++-- .../Embeddings/EmbeddingsHelper.cs | 6 +- .../EmbeddingsOptionsJsonConverter.cs | 10 +- .../Embeddings/EmbeddingsStoreConverter.cs | 27 ++--- .../{ChatMessage.cs => AssistantMessage.cs} | 13 +-- .../Models/AssistantState.cs | 4 +- .../Models/ChatMessageTableEntity.cs | 100 +++++++++++++++++- .../OpenAIClientFactory.cs | 61 +++++++++++ .../OpenAIExtension.cs | 30 +++--- .../OpenAIWebJobsBuilderExtensions.cs | 30 ++---- .../Search/SearchableDocumentJsonConverter.cs | 20 ++-- .../Search/SemanticSearchContext.cs | 8 +- .../Search/SemanticSearchConverter.cs | 41 +++---- .../TextCompletionAttribute.cs | 18 +--- .../TextCompletionConverter.cs | 28 ++--- .../WebJobs.Extensions.OpenAI.csproj | 4 +- 27 files changed, 382 insertions(+), 245 deletions(-) rename src/WebJobs.Extensions.OpenAI/Models/{ChatMessage.cs => AssistantMessage.cs} (71%) create mode 100644 src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs index 585fc205..bbd14017 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs 
@@ -19,7 +19,6 @@ public ChatMessage(string content, string role, string? name) { this.Content = content; this.Role = role; - this.Name = name; } /// @@ -33,10 +32,4 @@ public ChatMessage(string content, string role, string? name) /// [JsonPropertyName("role")] public string Role { get; set; } - - /// - /// Gets or sets the name of the calling function if applicable. - /// - [JsonPropertyName("name")] - public string? Name { get; set; } } diff --git a/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs b/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs index 4ad8280e..14abae5b 100644 --- a/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs +++ b/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs @@ -246,16 +246,16 @@ async Task IndexSectionsAsync(SearchClient searchClient, SearchableDocument docu { int iteration = 0; IndexDocumentsBatch batch = new(); - for (int i = 0; i < document.Embeddings?.Response?.Data.Count; i++) + for (int i = 0; i < document.Embeddings?.Response?.Count; i++) { batch.Actions.Add(new IndexDocumentsAction( IndexActionType.MergeOrUpload, new SearchDocument { ["id"] = Guid.NewGuid().ToString("N"), - ["text"] = document.Embeddings.Request.Input![i], + ["text"] = document.Embeddings.Input![i], ["title"] = Path.GetFileNameWithoutExtension(document.Title), - ["embeddings"] = document.Embeddings.Response.Data[i].Embedding.ToArray() ?? Array.Empty(), + ["embeddings"] = document.Embeddings.Response[i].ToFloats().ToArray() ?? 
Array.Empty(), ["timestamp"] = DateTime.UtcNow })); iteration++; diff --git a/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs b/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs index 8dbfc5f8..cd5cc65e 100644 --- a/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs +++ b/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs @@ -211,7 +211,7 @@ public void CreateVectorIndexIfNotExists(MongoClient cosmosClient) async Task UpsertVectorAsync(MongoClient cosmosClient, SearchableDocument document) { List list = new(); - for (int i = 0; i < document.Embeddings?.Response?.Data.Count; i++) + for (int i = 0; i < document.Embeddings?.Response?.Count; i++) { BsonDocument vectorDocument = new() @@ -219,15 +219,14 @@ async Task UpsertVectorAsync(MongoClient cosmosClient, SearchableDocument docume { "id", Guid.NewGuid().ToString("N") }, { this.cosmosDBSearchConfigOptions.Value.TextKey, - document.Embeddings.Request.Input![i] + document.Embeddings.Input![i] }, { "title", Path.GetFileNameWithoutExtension(document.Title) }, { this.cosmosDBSearchConfigOptions.Value.EmbeddingKey, new BsonArray( document - .Embeddings.Response.Data[i] - .Embedding.ToArray() + .Embeddings.Response[i].ToFloats().ToArray() .Select(e => new BsonDouble(Convert.ToDouble(e))) ) }, diff --git a/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs b/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs index 542454a3..156868bb 100644 --- a/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs +++ b/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs @@ -78,13 +78,13 @@ public async Task AddDocumentAsync(SearchableDocument document, CancellationToke table.AppendColumn("Embeddings", typeof(object)); table.AppendColumn("Timestamp", typeof(DateTime)); - for (int i = 0; i < document.Embeddings?.Response?.Data.Count; i++) + for (int i = 0; i < document.Embeddings?.Response?.Count; i++) { table.Rows.Add( 
Guid.NewGuid().ToString("N"), Path.GetFileNameWithoutExtension(document.Title), - document.Embeddings.Request.Input![i], - GetEmbeddingsString(document.Embeddings.Response.Data[i].Embedding, true), + document.Embeddings.Input![i], + GetEmbeddingsString(document.Embeddings.Response[i].ToFloats().ToArray(), true), DateTime.UtcNow); } diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantRuntimeState.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantRuntimeState.cs index 55dbd2f5..03b10f6d 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantRuntimeState.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantRuntimeState.cs @@ -6,13 +6,13 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; -record struct MessageRecord(DateTime Timestamp, ChatMessage ChatMessageEntity); +record struct MessageRecord(DateTime Timestamp, AssistantMessage ChatMessageEntity); [JsonObject(MemberSerialization.OptIn)] class AssistantRuntimeState { [JsonProperty("messages")] - public List? ChatMessages { get; set; } + public List? ChatMessages { get; set; } [JsonProperty("totalTokens")] public int TotalTokens { get; set; } = 0; diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs index e31e8a3b..56f4abdb 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
+using System.ClientModel; using Azure; using Azure.AI.OpenAI; using Azure.Data.Tables; @@ -8,6 +9,7 @@ using Microsoft.Extensions.Azure; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Logging; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; @@ -28,7 +30,7 @@ record InternalChatState(string Id, AssistantStateEntity Metadata, List const int FunctionCallBatchLimit = 50; const string DefaultChatStorage = "AzureWebJobsStorage"; - readonly OpenAIClient openAIClient; + readonly ChatClient chatClient; readonly IAssistantSkillInvoker skillInvoker; readonly ILogger logger; readonly AzureComponentFactory azureComponentFactory; @@ -37,7 +39,7 @@ record InternalChatState(string Id, AssistantStateEntity Metadata, List(); this.azureComponentFactory = azureComponentFactory ?? throw new ArgumentNullException(nameof(azureComponentFactory)); @@ -117,7 +121,8 @@ async Task DeleteBatch() partitionKey: request.Id, messageIndex: 1, // 1-based index content: request.Instructions, - role: ChatRole.System); + role: ChatMessageRole.System, + toolCalls: null); batch.Add(new TableTransactionAction(TableTransactionActionType.Add, chatMessageEntity)); } @@ -152,7 +157,7 @@ public async Task GetStateAsync(AssistantQueryAttribute assistan if (chatState is null) { this.logger.LogWarning("No assistant exists with ID = '{Id}'", id); - return new AssistantState(id, false, default, default, 0, 0, Array.Empty()); + return new AssistantState(id, false, default, default, 0, 0, Array.Empty()); } List filteredChatMessages = chatState.Messages @@ -171,7 +176,7 @@ public async Task GetStateAsync(AssistantQueryAttribute assistan chatState.Metadata.LastUpdatedAt, chatState.Metadata.TotalMessages, chatState.Metadata.TotalTokens, - filteredChatMessages.Select(msg => new ChatMessage(msg.Content, msg.Role, msg.Name)).ToList()); + filteredChatMessages.Select(msg => new AssistantMessage(msg.Content, msg.Role)).ToList()); return state; } @@ -198,7 +203,7 @@ 
public async Task PostMessageAsync(AssistantPostAttribute attrib if (chatState is null || !chatState.Metadata.Exists) { this.logger.LogWarning("[{Id}] Ignoring request sent to nonexistent assistant.", attribute.Id); - return new AssistantState(attribute.Id, false, default, default, 0, 0, Array.Empty()); + return new AssistantState(attribute.Id, false, default, default, 0, 0, Array.Empty()); } this.logger.LogInformation("[{Id}] Received message: {Text}", attribute.Id, attribute.UserMessage); @@ -211,51 +216,55 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib partitionKey: attribute.Id, messageIndex: ++chatState.Metadata.TotalMessages, content: attribute.UserMessage, - role: ChatRole.User); + role: ChatMessageRole.User, + toolCalls: null); chatState.Messages.Add(chatMessageEntity); // Add the chat message to the batch batch.Add(new TableTransactionAction(TableTransactionActionType.Add, chatMessageEntity)); string deploymentName = attribute.Model ?? OpenAIModels.DefaultChatModel; - IList? functions = this.skillInvoker.GetFunctionsDefinitions(); + IList? functions = this.skillInvoker.GetFunctionsDefinitions(); // We loop if the model returns function calls. Otherwise, we break after receiving a response. 
while (true) { // Get the next response from the LLM - ChatCompletionsOptions chatRequest = new(deploymentName, ToOpenAIChatRequestMessages(chatState.Messages)); + ChatCompletionOptions chatRequest = new ChatCompletionOptions(); if (functions is not null) { - foreach (ChatCompletionsFunctionToolDefinition fn in functions) + foreach (ChatTool fn in functions) { chatRequest.Tools.Add(fn); } } + chatRequest.ToolChoice = ChatToolChoice.CreateAutoChoice(); + IEnumerable chatMessages = ToOpenAIChatRequestMessages(chatState.Messages); - Response response = await this.openAIClient.GetChatCompletionsAsync( - chatRequest, - cancellationToken); + // ToDo: Pass more ChatCompletionOptions like TextCompletion + ClientResult response = await this.chatClient.CompleteChatAsync(chatMessages, chatRequest); // We don't normally expect more than one message, but just in case we get multiple messages, // return all of them separated by two newlines. string replyMessage = string.Join( Environment.NewLine + Environment.NewLine, - response.Value.Choices.Select(choice => choice.Message.Content)); - if (!string.IsNullOrWhiteSpace(replyMessage)) + response.Value.Content.Select(message => message.Text)); + if (!string.IsNullOrWhiteSpace(replyMessage) || response.Value.ToolCalls.Any()) { this.logger.LogInformation( - "[{Id}] Got LLM response consisting of {Count} tokens: {Text}", + "[{Id}] Got LLM response consisting of {Count} tokens: [{Text}] && {Count} ToolCalls", attribute.Id, - response.Value.Usage.CompletionTokens, - replyMessage); + response.Value.Usage.OutputTokenCount, + replyMessage, + response.Value.ToolCalls.Count); // Add the user message as a new Chat message entity ChatMessageTableEntity replyFromAssistantEntity = new( partitionKey: attribute.Id, messageIndex: ++chatState.Metadata.TotalMessages, content: replyMessage, - role: ChatRole.Assistant); + role: ChatMessageRole.Assistant, + toolCalls: response.Value.ToolCalls); chatState.Messages.Add(replyFromAssistantEntity); // Add the 
reply from assistant chat message to the batch @@ -268,12 +277,11 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib } // Set the total tokens that have been consumed. - chatState.Metadata.TotalTokens = response.Value.Usage.TotalTokens; + chatState.Metadata.TotalTokens = response.Value.Usage.TotalTokenCount; // Check for function calls (which are described in the API as tools) - List functionCalls = response.Value.Choices - .SelectMany(c => c.Message.ToolCalls) - .OfType() + List functionCalls = response.Value.ToolCalls + .OfType() .ToList(); if (functionCalls.Count == 0) { @@ -299,16 +307,17 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib attribute.Id, functionCalls.Count); + // Invoke the function calls and add the responses to the chat history. List> tasks = new(capacity: functionCalls.Count); - foreach (ChatCompletionsFunctionToolCall call in functionCalls) + foreach (ChatToolCall call in functionCalls) { // CONSIDER: Call these in parallel this.logger.LogInformation( "[{Id}] Calling function '{Name}' with arguments: {Args}", attribute.Id, - call.Name, - call.Arguments); + call.FunctionName, + call.FunctionArguments); string? functionResult; try @@ -320,7 +329,7 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib this.logger.LogInformation( "[{id}] Function '{Name}' returned the following content: {Content}", attribute.Id, - call.Name, + call.FunctionName, functionResult); } catch (Exception ex) @@ -329,7 +338,7 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib ex, "[{id}] Function '{Name}' failed with an unhandled exception", attribute.Id, - call.Name); + call.FunctionName); // CONSIDER: Automatic retries? functionResult = "The function call failed. 
Let the user know and ask if they'd like you to try again"; @@ -346,9 +355,10 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib ChatMessageTableEntity functionResultEntity = new( partitionKey: attribute.Id, messageIndex: ++chatState.Metadata.TotalMessages, - content: functionResult, - role: ChatRole.Function, - name: call.Name); + content: $"Function Name: '{call.FunctionName}' and Function Result: '{functionResult}'", + role: ChatMessageRole.Tool, + name: call.Id, + toolCalls: null); chatState.Messages.Add(functionResultEntity); batch.Add(new TableTransactionAction(TableTransactionActionType.Add, functionResultEntity)); @@ -365,7 +375,7 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib // return the latest assistant message in the chat state List filteredChatMessages = chatState.Messages - .Where(msg => msg.CreatedAt > timeFilter && msg.Role == ChatRole.Assistant) + .Where(msg => msg.CreatedAt > timeFilter && msg.Role == ChatMessageRole.Assistant.ToString()) .ToList(); this.logger.LogInformation( @@ -381,7 +391,7 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib chatState.Metadata.LastUpdatedAt, chatState.Metadata.TotalMessages, chatState.Metadata.TotalTokens, - filteredChatMessages.Select(msg => new ChatMessage(msg.Content, msg.Role, msg.Name)).ToList()); + filteredChatMessages.Select(msg => new AssistantMessage(msg.Content, msg.Role)).ToList()); return state; } @@ -420,26 +430,30 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib return new InternalChatState(id, assistantStateEntity, chatMessageList); } - static IEnumerable ToOpenAIChatRequestMessages(IEnumerable entities) + static IEnumerable ToOpenAIChatRequestMessages(IEnumerable entities) { foreach (ChatMessageTableEntity entity in entities) { switch (entity.Role.ToLowerInvariant()) { case "user": - yield return new ChatRequestUserMessage(entity.Content); + yield return new UserChatMessage(entity.Content); break; case "assistant": - yield 
return new ChatRequestAssistantMessage(entity.Content); + if (entity.ToolCalls != null && entity.ToolCalls.Any()) + { + yield return new AssistantChatMessage(entity.ToolCalls); + } + else + { + yield return new AssistantChatMessage(entity.Content); + } break; case "system": - yield return new ChatRequestSystemMessage(entity.Content); - break; - case "function": - yield return new ChatRequestFunctionMessage(entity.Name, entity.Content); + yield return new SystemChatMessage(entity.Content); break; case "tool": - yield return new ChatRequestToolMessage(entity.Content, toolCallId: entity.Name); + yield return new ToolChatMessage(toolCallId: entity.Name, entity.Content); break; default: throw new InvalidOperationException($"Unknown chat role '{entity.Role}'"); diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs index 55eed508..c4dd70f8 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs @@ -113,10 +113,10 @@ public Task BindAsync(object value, ValueBindingContext context) SkillInvocationContext skillInvocationContext = (SkillInvocationContext)value; object? convertedValue; - if (!string.IsNullOrEmpty(skillInvocationContext.Arguments)) + if (!string.IsNullOrEmpty(skillInvocationContext.Arguments.ToString())) { // We expect that input to always be a string value in the form {"paramName":paramValue} - JObject argsJson = JObject.Parse(skillInvocationContext.Arguments); + JObject argsJson = JObject.Parse(skillInvocationContext.Arguments.ToString()); JToken? 
paramValue = argsJson[this.parameterInfo.Name]; convertedValue = paramValue?.ToObject(destinationType); } diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/ChatCompletionsJsonConverter.cs b/src/WebJobs.Extensions.OpenAI/Assistants/ChatCompletionsJsonConverter.cs index f8ad8c8c..6d0b7e49 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/ChatCompletionsJsonConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/ChatCompletionsJsonConverter.cs @@ -4,19 +4,19 @@ using System.ClientModel.Primitives; using System.Text.Json; using System.Text.Json.Serialization; -using Azure.AI.OpenAI; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; -class ChatCompletionsJsonConverter : JsonConverter +class ChatCompletionsJsonConverter : JsonConverter { static readonly ModelReaderWriterOptions modelReaderWriterOptions = new("J"); - public override ChatCompletions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override ChatCompletion Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { throw new NotImplementedException(); } - public override void Write(Utf8JsonWriter writer, ChatCompletions value, JsonSerializerOptions options) + public override void Write(Utf8JsonWriter writer, ChatCompletion value, JsonSerializerOptions options) { - ((IJsonModel)value).Write(writer, modelReaderWriterOptions); + ((IJsonModel)value).Write(writer, modelReaderWriterOptions); } } diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs b/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs index d3e539ba..c11f55bd 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs @@ -4,28 +4,28 @@ using System.Reflection; using System.Runtime.ExceptionServices; using System.Text; -using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Host.Executors; using 
Microsoft.Extensions.Logging; using Newtonsoft.Json; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; public interface IAssistantSkillInvoker { - IList? GetFunctionsDefinitions(); - Task InvokeAsync(ChatCompletionsFunctionToolCall call, CancellationToken cancellationToken); + IList? GetFunctionsDefinitions(); + Task InvokeAsync(ChatToolCall call, CancellationToken cancellationToken); } class SkillInvocationContext { - public SkillInvocationContext(string arguments) + public SkillInvocationContext(BinaryData arguments) { this.Arguments = arguments; } // The arguments are passed as a JSON object in the form of {"paramName":paramValue} - public string Arguments { get; } + public BinaryData Arguments { get; } // The result of the function invocation, if any public object? Result { get; set; } @@ -70,14 +70,14 @@ internal void UnregisterSkill(string name) this.skills.Remove(name); } - IList? IAssistantSkillInvoker.GetFunctionsDefinitions() + IList? IAssistantSkillInvoker.GetFunctionsDefinitions() { if (this.skills.Count == 0) { return null; } - List functions = new(capacity: this.skills.Count); + List functions = new(capacity: this.skills.Count); foreach (Skill skill in this.skills.Values) { // The parameters can be defined in the attribute JSON or can be inferred from @@ -85,12 +85,12 @@ internal void UnregisterSkill(string name) string parametersJson = skill.Attribute.ParameterDescriptionJson ?? 
JsonConvert.SerializeObject(GetParameterDefinition(skill)); - functions.Add(new ChatCompletionsFunctionToolDefinition - { - Name = skill.Name, - Description = skill.Attribute.FunctionDescription, - Parameters = BinaryData.FromBytes(Encoding.UTF8.GetBytes(parametersJson)), - }); + ChatTool chatTool = ChatTool.CreateFunctionTool( + skill.Name, + skill.Attribute.FunctionDescription, + BinaryData.FromBytes(Encoding.UTF8.GetBytes(parametersJson)) + ); + functions.Add(chatTool); } return functions; @@ -140,7 +140,7 @@ static Dictionary GetParameterDefinition(Skill skill) } async Task IAssistantSkillInvoker.InvokeAsync( - ChatCompletionsFunctionToolCall call, + ChatToolCall call, CancellationToken cancellationToken) { if (call is null) @@ -148,17 +148,17 @@ static Dictionary GetParameterDefinition(Skill skill) throw new ArgumentNullException(nameof(call)); } - if (call.Name is null) + if (call.FunctionName is null) { throw new ArgumentException("The function call must have a name", nameof(call)); } - if (!this.skills.TryGetValue(call.Name, out Skill? skill)) + if (!this.skills.TryGetValue(call.FunctionName, out Skill? skill)) { - throw new InvalidOperationException($"No skill registered with name '{call.Name}'"); + throw new InvalidOperationException($"No skill registered with name '{call.FunctionName}'"); } - SkillInvocationContext skillInvocationContext = new(call.Arguments); + SkillInvocationContext skillInvocationContext = new(call.FunctionArguments); // This call may throw if the Functions host is shutting down or if there is an internal error // in the Functions runtime. We don't currently try to handle these exceptions. @@ -170,7 +170,7 @@ static Dictionary GetParameterDefinition(Skill skill) InvokeHandler = async userCodeInvoker => { // Invoke the function and attempt to get the result. 
- this.logger.LogInformation("Invoking user-code function '{Name}'", call.Name); + this.logger.LogInformation("Invoking user-code function '{Name}'", call.FunctionName); Task invokeTask = userCodeInvoker.Invoke(); if (invokeTask is Task invokeTaskWithResult) { @@ -182,7 +182,7 @@ static Dictionary GetParameterDefinition(Skill skill) this.logger.LogWarning( "Unable to discover the return value (if any) for user-code function '{Name}'. " + "This is an internal error in the extension that may result in model hallucination.", - call.Name); + call.FunctionName); await invokeTask; } } @@ -205,7 +205,7 @@ static Dictionary GetParameterDefinition(Skill skill) // Convert the output to JSON string jsonResult = JsonConvert.SerializeObject(skillInvocationContext.Result); this.logger.LogInformation( - "Returning output of user-code function '{Name}' as JSON: {Json}", call.Name, jsonResult); + "Returning output of user-code function '{Name}' as JSON: {Json}", call.FunctionName, jsonResult); return jsonResult; } } \ No newline at end of file diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs index c6b90c1a..ab59b3d0 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs @@ -1,8 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using OpenAISDK = Azure.AI.OpenAI; - +using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; /// @@ -12,24 +11,24 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; /// The embeddings response that was received from OpenAI. public class EmbeddingsContext { - public EmbeddingsContext(OpenAISDK.EmbeddingsOptions Request, OpenAISDK.Embeddings? Response) + public EmbeddingsContext(IList Input, OpenAIEmbeddingCollection? 
Response) { - this.Request = Request; + this.Input = Input; this.Response = Response; } /// /// Embeddings request sent to OpenAI. /// - public OpenAISDK.EmbeddingsOptions Request { get; set; } + public IList Input { get; set; } /// /// Embeddings response from OpenAI. /// - public OpenAISDK.Embeddings? Response { get; set; } + public OpenAIEmbeddingCollection? Response { get; set; } /// /// Gets the number of embeddings that were returned in the response. /// - public int Count => this.Response?.Data?.Count ?? 0; + public int Count => this.Response?.Count ?? 0; } diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs index ebe439c7..6634064a 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs @@ -4,7 +4,7 @@ using System.ClientModel.Primitives; using System.Text.Json; using System.Text.Json.Serialization; -using OpenAISDK = Azure.AI.OpenAI; +using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; @@ -22,10 +22,10 @@ public override EmbeddingsContext Read(ref Utf8JsonReader reader, Type typeToCon public override void Write(Utf8JsonWriter writer, EmbeddingsContext value, JsonSerializerOptions options) { writer.WriteStartObject(); - writer.WritePropertyName("request"u8); - ((IJsonModel)value.Request).Write(writer, modelReaderWriterOptions); + writer.WritePropertyName("input"u8); + ((IJsonModel)value.Input).Write(writer, modelReaderWriterOptions); - if (value.Response is IJsonModel response) + if (value.Response is IJsonModel response) { writer.WritePropertyName("response"u8); response.Write(writer, modelReaderWriterOptions); diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs index 28fc990b..96f6fca3 100644 --- 
a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs @@ -1,11 +1,12 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.ClientModel; using System.Text.Json; -using Azure; +using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; using Microsoft.Extensions.Logging; -using OpenAISDK = Azure.AI.OpenAI; +using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; @@ -13,7 +14,7 @@ class EmbeddingsConverter : IAsyncConverter, IAsyncConverter { - readonly OpenAISDK.OpenAIClient openAIClient; + readonly EmbeddingClient embeddingClient; readonly ILogger logger; // Note: we need this converter as Azure.AI.OpenAI does not support System.Text.Json serialization since their constructors are internal @@ -22,9 +23,10 @@ class EmbeddingsConverter : Converters = { new EmbeddingsContextConverter(), new SearchableDocumentJsonConverter() } }; - public EmbeddingsConverter(OpenAISDK.OpenAIClient openAIClient, ILoggerFactory loggerFactory) + public EmbeddingsConverter(AzureOpenAIClient azureOpenAIClient, ILoggerFactory loggerFactory) { - this.openAIClient = openAIClient ?? throw new ArgumentNullException(nameof(openAIClient)); + // ToDo: Handle the deployment name retrieval better + this.embeddingClient = azureOpenAIClient.GetEmbeddingClient("embedding") ?? throw new ArgumentNullException(nameof(azureOpenAIClient)); this.logger = loggerFactory?.CreateLogger() ?? 
throw new ArgumentNullException(nameof(loggerFactory)); } @@ -47,11 +49,11 @@ async Task ConvertCoreAsync( EmbeddingsAttribute attribute, CancellationToken cancellationToken) { - OpenAISDK.EmbeddingsOptions request = await EmbeddingsHelper.BuildRequest(attribute.MaxOverlap, attribute.MaxChunkLength, attribute.Model, attribute.InputType, attribute.Input); - this.logger.LogInformation("Sending OpenAI embeddings request: {request}", request.Input); - Response response = await this.openAIClient.GetEmbeddingsAsync(request, cancellationToken); - this.logger.LogInformation("Received OpenAI embeddings count: {response}", response.Value.Data.Count); + List input = await EmbeddingsHelper.BuildRequest(attribute.MaxOverlap, attribute.MaxChunkLength, attribute.Model, attribute.InputType, attribute.Input); + this.logger.LogInformation("Sending OpenAI embeddings request"); + ClientResult response = await this.embeddingClient.GenerateEmbeddingsAsync(input); + this.logger.LogInformation("Received OpenAI embeddings count: {response}", response.Value.Count); - return new EmbeddingsContext(request, response); + return new EmbeddingsContext(input, response); } } diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs index 43125c3e..319c2820 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
using System.Diagnostics; -using Azure.AI.OpenAI; +using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; static class EmbeddingsHelper @@ -17,7 +17,7 @@ static EmbeddingsHelper() httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(UserAgent); } - public static async Task BuildRequest(int maxOverlap, int maxChunkLength, string model, InputType inputType, string input) + public static async Task> BuildRequest(int maxOverlap, int maxChunkLength, string model, InputType inputType, string input) { using TextReader reader = await GetTextReader(inputType, input); if (maxOverlap >= maxChunkLength) @@ -26,7 +26,7 @@ public static async Task BuildRequest(int maxOverlap, int max } List chunks = GetTextChunks(reader, 0, maxChunkLength, maxOverlap).ToList(); - return new EmbeddingsOptions(model, chunks); + return chunks; } static async Task GetTextReader(InputType inputType, string input) diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs index e1705ac5..2b5fdacb 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs @@ -4,19 +4,19 @@ using System.ClientModel.Primitives; using System.Text.Json; using System.Text.Json.Serialization; -using Azure.AI.OpenAI; +using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; -class EmbeddingsOptionsJsonConverter : JsonConverter +class EmbeddingsOptionsJsonConverter : JsonConverter { static readonly ModelReaderWriterOptions modelReaderWriterOptions = new("J"); - public override EmbeddingsOptions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override EmbeddingGenerationOptions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { throw new NotImplementedException(); } - public 
override void Write(Utf8JsonWriter writer, EmbeddingsOptions value, JsonSerializerOptions options) + public override void Write(Utf8JsonWriter writer, EmbeddingGenerationOptions value, JsonSerializerOptions options) { - ((IJsonModel)value).Write(writer, modelReaderWriterOptions); + ((IJsonModel)value).Write(writer, modelReaderWriterOptions); } } diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs index f51285d5..b995152b 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs @@ -1,19 +1,21 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.ClientModel; using System.Text.Json; using Azure; +using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using OpenAI.Embeddings; using OpenAISDK = Azure.AI.OpenAI; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; class EmbeddingsStoreConverter : IAsyncConverter> { - readonly OpenAISDK.OpenAIClient openAIClient; - readonly ILogger logger; + readonly EmbeddingClient embeddingClient; readonly ILogger logger; readonly ISearchProvider? searchProvider; // Note: we need this converter as Azure.AI.OpenAI does not support System.Text.Json serialization since their constructors are internal @@ -22,12 +24,13 @@ class EmbeddingsStoreConverter : Converters = { new EmbeddingsContextConverter(), new SearchableDocumentJsonConverter() } }; - public EmbeddingsStoreConverter(OpenAISDK.OpenAIClient openAIClient, + public EmbeddingsStoreConverter(AzureOpenAIClient azureOpenAIClient, ILoggerFactory loggerFactory, IEnumerable searchProviders, IOptions openAiConfigOptions) { - this.openAIClient = openAIClient ?? 
throw new ArgumentNullException(nameof(openAIClient)); + // ToDo: Handle the deployment name retrieval better + this.embeddingClient = azureOpenAIClient.GetEmbeddingClient("embedding") ?? throw new ArgumentNullException(nameof(azureOpenAIClient)); this.logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); openAiConfigOptions.Value.SearchProvider.TryGetValue("type", out object value); this.searchProvider = searchProviders? @@ -41,7 +44,7 @@ public Task> ConvertAsync(EmbeddingsStoreAtt throw new InvalidOperationException( "No search provider is configured. Search providers are configured in the host.json file. For .NET apps, the appropriate nuget package must also be added to the app's project file."); } - IAsyncCollector collector = new SemanticDocumentCollector(input, this.searchProvider, this.openAIClient, this.logger); + IAsyncCollector collector = new SemanticDocumentCollector(input, this.searchProvider, this.embeddingClient, this.logger); return Task.FromResult(collector); } @@ -62,14 +65,14 @@ sealed class SemanticDocumentCollector : IAsyncCollector { readonly EmbeddingsStoreAttribute attribute; readonly ISearchProvider searchProvider; - readonly OpenAISDK.OpenAIClient openAIClient; + readonly EmbeddingClient embeddingClient; readonly ILogger logger; - public SemanticDocumentCollector(EmbeddingsStoreAttribute attribute, ISearchProvider searchProvider, OpenAISDK.OpenAIClient openAIClient, ILogger logger) + public SemanticDocumentCollector(EmbeddingsStoreAttribute attribute, ISearchProvider searchProvider, EmbeddingClient embeddingClient, ILogger logger) { this.attribute = attribute; this.searchProvider = searchProvider; - this.openAIClient = openAIClient; + this.embeddingClient = embeddingClient; this.logger = logger; } @@ -85,10 +88,10 @@ public async Task AddAsync(SearchableDocument item, CancellationToken cancellati } // Get embeddings from OpenAI - OpenAISDK.EmbeddingsOptions request = await 
EmbeddingsHelper.BuildRequest(this.attribute.MaxOverlap, this.attribute.MaxChunkLength, this.attribute.Model, this.attribute.InputType, this.attribute.Input); - this.logger.LogInformation("Sending OpenAI embeddings request to deployment: {deploymentName}", request.DeploymentName); - Response response = await this.openAIClient.GetEmbeddingsAsync(request, cancellationToken); - EmbeddingsContext embeddingsContext = new(request, response); + List input = await EmbeddingsHelper.BuildRequest(this.attribute.MaxOverlap, this.attribute.MaxChunkLength, this.attribute.Model, this.attribute.InputType, this.attribute.Input); + this.logger.LogInformation("Sending OpenAI embeddings request"); + ClientResult response = await this.embeddingClient.GenerateEmbeddingsAsync(input); + EmbeddingsContext embeddingsContext = new (input, response); this.logger.LogInformation("Received OpenAI embeddings of count: {count}", embeddingsContext.Count); // Add document to the embed store diff --git a/src/WebJobs.Extensions.OpenAI/Models/ChatMessage.cs b/src/WebJobs.Extensions.OpenAI/Models/AssistantMessage.cs similarity index 71% rename from src/WebJobs.Extensions.OpenAI/Models/ChatMessage.cs rename to src/WebJobs.Extensions.OpenAI/Models/AssistantMessage.cs index b52e0976..6ae82c58 100644 --- a/src/WebJobs.Extensions.OpenAI/Models/ChatMessage.cs +++ b/src/WebJobs.Extensions.OpenAI/Models/AssistantMessage.cs @@ -9,18 +9,17 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; /// Chat Message Entity which contains the content of the message, the role of the chat agent, and the name of the calling function if applicable. /// [JsonObject(MemberSerialization.OptIn)] -public class ChatMessage +public class AssistantMessage { /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The content of the message. /// The role of the chat agent. - public ChatMessage(string content, string role, string? 
name) + public AssistantMessage(string content, string role) { this.Content = content; this.Role = role; - this.Name = name; } /// @@ -34,10 +33,4 @@ public ChatMessage(string content, string role, string? name) /// [JsonProperty("role")] public string Role { get; set; } - - /// - /// Gets or sets the name of the calling function if applicable. - /// - [JsonProperty("name")] - public string? Name { get; set; } } diff --git a/src/WebJobs.Extensions.OpenAI/Models/AssistantState.cs b/src/WebJobs.Extensions.OpenAI/Models/AssistantState.cs index bb1775d7..6dcd68d1 100644 --- a/src/WebJobs.Extensions.OpenAI/Models/AssistantState.cs +++ b/src/WebJobs.Extensions.OpenAI/Models/AssistantState.cs @@ -18,7 +18,7 @@ public AssistantState( DateTime LastUpdatedAt, int TotalMessages, int TotalTokens, - IReadOnlyList RecentMessages) + IReadOnlyList RecentMessages) { this.Id = Id; this.Exists = Exists; @@ -69,5 +69,5 @@ public AssistantState( /// Gets a list of the recent messages from the assistant. /// [JsonProperty("recentMessages")] - public IReadOnlyList RecentMessages { get; set; } = Array.Empty(); + public IReadOnlyList RecentMessages { get; set; } = Array.Empty(); } diff --git a/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs b/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs index 022350d9..e38e5f12 100644 --- a/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs +++ b/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs @@ -1,9 +1,11 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Runtime.Serialization; +using System.Text.Json; using Azure; -using Azure.AI.OpenAI; using Azure.Data.Tables; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; @@ -19,8 +21,9 @@ public ChatMessageTableEntity( string partitionKey, int messageIndex, string content, - ChatRole role, - string? name = null) + ChatMessageRole role, + string? name = null, + IEnumerable? 
toolCalls = null) { this.PartitionKey = partitionKey; this.RowKey = GetRowKey(messageIndex); @@ -28,6 +31,7 @@ public ChatMessageTableEntity( this.Role = role.ToString(); this.Name = name; this.CreatedAt = DateTime.UtcNow; + this.ToolCalls = toolCalls?.ToList(); } public ChatMessageTableEntity(TableEntity entity) @@ -40,6 +44,7 @@ public ChatMessageTableEntity(TableEntity entity) this.Role = entity.GetString(nameof(this.Role)); this.Name = entity.GetString(nameof(this.Name)); this.CreatedAt = DateTime.SpecifyKind(entity.GetDateTime(nameof(this.CreatedAt)).GetValueOrDefault(), DateTimeKind.Utc); + this.ToolCallsString = entity.GetString(nameof(this.ToolCalls)); } /// @@ -82,10 +87,99 @@ public ChatMessageTableEntity(TableEntity entity) /// public DateTime CreatedAt { get; set; } + /// + /// Gets or sets the ToolCalls for Assistant + /// + [IgnoreDataMember] + public IList? ToolCalls { get; set; } + // WARNING: Changing this is a breaking change! static string GetRowKey(int messageNumber) { // Example msg-001B return string.Concat(RowKeyPrefix, messageNumber.ToString("X4")); } + + /// + /// Converts the ToolCalls to a Json string for table storage + /// + [DataMember(Name = "ToolCalls")] + public string ToolCallsString + { + get + { + if (this.ToolCalls == null || this.ToolCalls.Count == 0) + { + return string.Empty; + } + + IList cloneList = this.SerializeChatTool(this.ToolCalls); + var options = new JsonSerializerOptions { WriteIndented = false, PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; + return JsonSerializer.Serialize(cloneList, options); + } + set + { + if (!string.IsNullOrEmpty(value)) + { + var options = new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; + var cloneList = JsonSerializer.Deserialize>(value, options); + this.ToolCalls = cloneList != null ? 
this.DeserializeChatTool(cloneList) : new List(); + } + else + { + this.ToolCalls = new List(); + } + } + } + + IList SerializeChatTool(IList toolCalls) + { + IList chatToolCloneList = new List(); + foreach (ChatToolCall toolCall in toolCalls) + { + ChatToolCallClone chatToolClone = new ChatToolCallClone(toolCall.Id, toolCall.FunctionName, toolCall.FunctionArguments.ToString(), toolCall.Kind.ToString()); + chatToolCloneList.Add(chatToolClone); + } + return chatToolCloneList; + } + + IList DeserializeChatTool(IList clones) + { + IList result = new List(); + foreach (var clone in clones) + { + var kind = Enum.Parse(clone.Kind); + var functionArgs = JsonDocument.Parse(clone.FunctionArguments).RootElement; + var toolCall = ChatToolCall.CreateFunctionToolCall(clone.Id, clone.FunctionName, BinaryData.FromString(functionArgs.GetRawText())); + result.Add(toolCall); + } + return result; + } } + +class ChatToolCallClone +{ + public ChatToolCallClone() + { + this.Id = string.Empty; + this.FunctionName = string.Empty; + this.FunctionArguments = string.Empty; + this.Kind = string.Empty; + } + + internal ChatToolCallClone(string id, string functionName, string functionArguments, string kind) + { + this.Id = id; + this.FunctionName = functionName; + this.Kind = kind; + this.FunctionArguments = functionArguments; + } + + public string Id { get; set; } + + public string FunctionName { get; set; } + + public string FunctionArguments { get; set; } + + public string Kind { get; set; } +} \ No newline at end of file diff --git a/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs b/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs new file mode 100644 index 00000000..059ec8fe --- /dev/null +++ b/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs @@ -0,0 +1,61 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +using System.Collections.Concurrent; +using Azure; +using Azure.AI.OpenAI; +using Azure.Identity; +using Microsoft.Extensions.DependencyInjection; +using OpenAI; + +namespace Microsoft.Azure.WebJobs.Extensions.OpenAI; + +public class OpenAIClientFactory +{ + static readonly ConcurrentDictionary _azureOpenAIclients = new(); + static readonly ConcurrentDictionary _openAIClients = new(); + + public static AzureOpenAIClient CreateAzureOpenAIClient(string endpoint, string apiKey) + { + string key = $"{endpoint}-{apiKey}"; + return _azureOpenAIclients.GetOrAdd(key, _ => new AzureOpenAIClient(new Uri(endpoint), new AzureKeyCredential(apiKey))); + } + + public static AzureOpenAIClient CreateAzureOpenAIClientWithDefaultAzureCredential(string endpoint) + { + return _azureOpenAIclients.GetOrAdd(endpoint, _ => new AzureOpenAIClient(new Uri(endpoint), new DefaultAzureCredential())); + } + + public static OpenAIClient CreateOpenAIClient(string apiKey) + { + return _openAIClients.GetOrAdd(apiKey, _ => new OpenAIClient(apiKey)); + } +} + +public static class ServiceCollectionExtensions +{ + public static IServiceCollection AddAzureOpenAIClient( + this IServiceCollection services, + string endpoint, + string apiKey) + { + services.AddSingleton(sp => OpenAIClientFactory.CreateAzureOpenAIClient(endpoint, apiKey)); + return services; + } + + public static IServiceCollection AddAzureOpenAIClientWithDefaultAzureCredential( + this IServiceCollection services, + string endpoint) + { + services.AddSingleton(sp => OpenAIClientFactory.CreateAzureOpenAIClientWithDefaultAzureCredential(endpoint)); + return services; + } + + public static IServiceCollection AddOpenAIClient( + this IServiceCollection services, + string apiKey) + { + services.AddSingleton(sp => OpenAIClientFactory.CreateOpenAIClient(apiKey)); + return services; + } +} \ No newline at end of file diff --git a/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs b/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs index 
4204b525..16e82ed0 100644 --- a/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs +++ b/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs @@ -15,21 +15,21 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI; [Extension("OpenAI")] partial class OpenAIExtension : IExtensionConfigProvider { - readonly OpenAIClient openAIClient; + readonly AzureOpenAIClient openAIClient; readonly TextCompletionConverter textCompletionConverter; readonly EmbeddingsConverter embeddingsConverter; readonly EmbeddingsStoreConverter embeddingsStoreConverter; readonly SemanticSearchConverter semanticSearchConverter; - readonly AssistantBindingConverter chatBotConverter; + readonly AssistantBindingConverter assistantConverter; readonly AssistantSkillTriggerBindingProvider assistantskillTriggerBindingProvider; public OpenAIExtension( - OpenAIClient openAIClient, + AzureOpenAIClient openAIClient, TextCompletionConverter textCompletionConverter, EmbeddingsConverter embeddingsConverter, EmbeddingsStoreConverter embeddingsStoreConverter, SemanticSearchConverter semanticSearchConverter, - AssistantBindingConverter chatBotConverter, + AssistantBindingConverter assistantConverter, AssistantSkillTriggerBindingProvider assistantTriggerBindingProvider) { this.openAIClient = openAIClient ?? throw new ArgumentNullException(nameof(openAIClient)); @@ -37,7 +37,7 @@ public OpenAIExtension( this.embeddingsConverter = embeddingsConverter ?? throw new ArgumentNullException(nameof(embeddingsConverter)); this.embeddingsStoreConverter = embeddingsStoreConverter ?? throw new ArgumentNullException(nameof(embeddingsStoreConverter)); this.semanticSearchConverter = semanticSearchConverter ?? throw new ArgumentNullException(nameof(semanticSearchConverter)); - this.chatBotConverter = chatBotConverter ?? throw new ArgumentNullException(nameof(chatBotConverter)); + this.assistantConverter = assistantConverter ?? 
throw new ArgumentNullException(nameof(assistantConverter)); this.assistantskillTriggerBindingProvider = assistantTriggerBindingProvider ?? throw new ArgumentNullException(nameof(assistantTriggerBindingProvider)); } @@ -64,18 +64,18 @@ void IExtensionConfigProvider.Initialize(ExtensionConfigContext context) semanticSearchRule.BindToInput(this.semanticSearchConverter); // Assistant support - var chatBotCreateRule = context.AddBindingRule(); - chatBotCreateRule.BindToCollector(this.chatBotConverter); - context.AddConverter(this.chatBotConverter.ToAssistantCreateRequest); - context.AddConverter(this.chatBotConverter.ToAssistantCreateRequest); + var assistantCreateRule = context.AddBindingRule(); + assistantCreateRule.BindToCollector(this.assistantConverter); + context.AddConverter(this.assistantConverter.ToAssistantCreateRequest); + context.AddConverter(this.assistantConverter.ToAssistantCreateRequest); - var chatBotPostRule = context.AddBindingRule(); - chatBotPostRule.BindToInput(this.chatBotConverter); - chatBotPostRule.BindToInput(this.chatBotConverter); + var assistantPostRule = context.AddBindingRule(); + assistantPostRule.BindToInput(this.assistantConverter); + assistantPostRule.BindToInput(this.assistantConverter); - var chatBotQueryRule = context.AddBindingRule(); - chatBotQueryRule.BindToInput(this.chatBotConverter); - chatBotQueryRule.BindToInput(this.chatBotConverter); + var assistantQueryRule = context.AddBindingRule(); + assistantQueryRule.BindToInput(this.assistantConverter); + assistantQueryRule.BindToInput(this.assistantConverter); // Assistant skill trigger support context.AddBindingRule() diff --git a/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs b/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs index 3b49603b..a1fd498d 100644 --- a/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs +++ b/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs @@ -1,9 +1,6 @@ // Copyright (c) Microsoft 
Corporation. // Licensed under the MIT License. -using Azure; -using Azure.AI.OpenAI; -using Azure.Identity; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; @@ -32,7 +29,7 @@ public static IWebJobsBuilder AddOpenAIBindings(this IWebJobsBuilder builder) } // Register the client for Azure Open AI - Uri? azureOpenAIEndpoint = GetAzureOpenAIEndpoint(); + string? azureOpenAIEndpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); string? openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); string? azureOpenAIKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_KEY"); @@ -80,33 +77,18 @@ public static IWebJobsBuilder AddOpenAIBindings(this IWebJobsBuilder builder) return builder; } - static Uri? GetAzureOpenAIEndpoint() + static void RegisterAzureOpenAIClient(IServiceCollection services, string azureOpenAIEndpoint, string azureOpenAIKey) { - if (Uri.TryCreate(Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"), UriKind.Absolute, out var uri)) - { - return uri; - } - - return null; - } - - static void RegisterAzureOpenAIClient(IServiceCollection services, Uri azureOpenAIEndpoint, string azureOpenAIKey) - { - services.AddAzureClients(clientBuilder => - { - clientBuilder.AddOpenAIClient(azureOpenAIEndpoint, new AzureKeyCredential(azureOpenAIKey)); - }); + services.AddAzureOpenAIClient(azureOpenAIEndpoint, azureOpenAIKey); } - static void RegisterAzureOpenAIADAuthClient(IServiceCollection services, Uri azureOpenAIEndpoint) + static void RegisterAzureOpenAIADAuthClient(IServiceCollection services, string azureOpenAIEndpoint) { - var managedIdentityClient = new OpenAIClient(azureOpenAIEndpoint, new DefaultAzureCredential()); - services.AddSingleton(managedIdentityClient); + services.AddAzureOpenAIClientWithDefaultAzureCredential(azureOpenAIEndpoint); } static void RegisterOpenAIClient(IServiceCollection services, 
string openAIKey) { - var openAIClient = new OpenAIClient(openAIKey); - services.AddSingleton(openAIClient); + services.AddOpenAIClient(openAIKey); } } diff --git a/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs b/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs index 0bbf9b55..5bbda29c 100644 --- a/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs @@ -5,7 +5,7 @@ using System.Text.Json; using System.Text.Json.Serialization; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; -using OpenAISDK = Azure.AI.OpenAI; +using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; class SearchableDocumentJsonConverter : JsonConverter @@ -16,8 +16,8 @@ public override SearchableDocument Read(ref Utf8JsonReader reader, Type typeToCo using JsonDocument jsonDocument = JsonDocument.ParseValue(ref reader); // Properties for SearchableDocument - OpenAISDK.EmbeddingsOptions embeddingsOptions = new(); - OpenAISDK.Embeddings? embeddings = null; + IList input = new List(); + OpenAIEmbeddingCollection? 
embeddings = null; int count; string title = string.Empty; string connectionName = string.Empty; @@ -29,13 +29,13 @@ public override SearchableDocument Read(ref Utf8JsonReader reader, Type typeToCo { foreach (JsonProperty embeddingContextItem in item.Value.EnumerateObject()) { - if (embeddingContextItem.NameEquals("request"u8)) + if (embeddingContextItem.NameEquals("input"u8)) { - embeddingsOptions = ModelReaderWriter.Read(BinaryData.FromString(embeddingContextItem.Value.GetRawText()))!; + input = new List(); // ToDo: revisit } if (embeddingContextItem.NameEquals("response"u8)) { - embeddings = ModelReaderWriter.Read(BinaryData.FromString(embeddingContextItem.Value.GetRawText()))!; + embeddings = ModelReaderWriter.Read(BinaryData.FromString(embeddingContextItem.Value.GetRawText()))!; } if (embeddingContextItem.NameEquals("count"u8)) { @@ -65,7 +65,7 @@ public override SearchableDocument Read(ref Utf8JsonReader reader, Type typeToCo } SearchableDocument searchableDocument = new SearchableDocument(title) { - Embeddings = new EmbeddingsContext(embeddingsOptions, embeddings), + Embeddings = new EmbeddingsContext(input, embeddings), ConnectionInfo = new ConnectionInfo(connectionName, collectionName), }; return searchableDocument; @@ -78,13 +78,13 @@ public override void Write(Utf8JsonWriter writer, SearchableDocument value, Json writer.WritePropertyName("embeddingsContext"u8); writer.WriteStartObject(); - if (value.Embeddings?.Request is IJsonModel request) + if (value.Embeddings?.Input is IJsonModel request) { - writer.WritePropertyName("request"u8); + writer.WritePropertyName("input"u8); request.Write(writer, modelReaderWriterOptions); } - if (value.Embeddings?.Response is IJsonModel response) + if (value.Embeddings?.Response is IJsonModel response) { writer.WritePropertyName("response"u8); response.Write(writer, modelReaderWriterOptions); diff --git a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchContext.cs 
b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchContext.cs index adb000f7..a51b8fc9 100644 --- a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchContext.cs +++ b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchContext.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; using Newtonsoft.Json; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; @@ -18,7 +18,7 @@ public class SemanticSearchContext /// /// The embeddings context associated with the semantic search. /// The chat response from the large language model. - public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletions Chat) + public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletion Chat) { this.Embeddings = Embeddings; this.Chat = Chat; @@ -34,11 +34,11 @@ public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletions Chat) /// Chat response from the chat completions request. /// [JsonProperty("chat")] - public ChatCompletions Chat { get; } + public ChatCompletion Chat { get; } /// /// Gets the latest response message from the OpenAI Chat API. /// [JsonProperty("response")] - public string Response => this.Chat.Choices.Last().Message.Content; + public string Response => this.Chat.Content.LastOrDefault().Text; } diff --git a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs index b853036b..b3799884 100644 --- a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs @@ -1,13 +1,18 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
+using System.ClientModel; using System.Text; using System.Text.Json; using Azure; +using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using OpenAI; +using OpenAI.Chat; +using OpenAI.Embeddings; using OpenAISDK = Azure.AI.OpenAI; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; @@ -16,7 +21,8 @@ class SemanticSearchConverter : IAsyncConverter, IAsyncConverter { - readonly OpenAISDK.OpenAIClient openAIClient; + readonly ChatClient chatClient; + readonly EmbeddingClient embeddingClient; readonly ILogger logger; readonly ISearchProvider? searchProvider; @@ -30,12 +36,14 @@ class SemanticSearchConverter : }; public SemanticSearchConverter( - OpenAISDK.OpenAIClient openAIClient, + AzureOpenAIClient azureOpenAIClient, ILoggerFactory loggerFactory, IEnumerable searchProviders, IOptions openAiConfigOptions) { - this.openAIClient = openAIClient ?? throw new ArgumentNullException(nameof(openAIClient)); + //ToDo: Handle retrieval of the model better + this.chatClient = azureOpenAIClient.GetChatClient(deploymentName: Environment.GetEnvironmentVariable("CHAT_MODEL_DEPLOYMENT_NAME")) ?? throw new ArgumentNullException(nameof(azureOpenAIClient)); + this.embeddingClient = azureOpenAIClient.GetEmbeddingClient(deploymentName: Environment.GetEnvironmentVariable("EMBEDDING_MODEL_DEPLOYMENT_NAME")) ?? throw new ArgumentNullException(nameof(azureOpenAIClient)); this.logger = loggerFactory?.CreateLogger() ?? 
throw new ArgumentNullException(nameof(loggerFactory)); openAiConfigOptions.Value.SearchProvider.TryGetValue("type", out object value); @@ -61,12 +69,9 @@ async Task ConvertHelperAsync( } // Get the embeddings for the query, which will be used for doing a semantic search - OpenAISDK.EmbeddingsOptions embeddingsRequest = new(attribute.EmbeddingsModel, new List { attribute.Query }); - - this.logger.LogInformation("Sending OpenAI embeddings request: {request}", embeddingsRequest.Input); - Response embeddingsResponse = await this.openAIClient.GetEmbeddingsAsync(embeddingsRequest, cancellationToken); - this.logger.LogInformation("Received OpenAI embeddings count: {response}", embeddingsResponse.Value.Data.Count); - + this.logger.LogInformation("Sending OpenAI embeddings request: {request}", attribute.Query); + ClientResult embedding = await this.embeddingClient.GenerateEmbeddingAsync(attribute.Query, cancellationToken: cancellationToken); + this.logger.LogInformation("Received OpenAI embeddings"); ConnectionInfo connectionInfo = new(attribute.ConnectionName, attribute.Collection); if (string.IsNullOrEmpty(connectionInfo.ConnectionName)) @@ -81,7 +86,7 @@ async Task ConvertHelperAsync( // Search for relevant document snippets using the original query and the embeddings SearchRequest searchRequest = new( attribute.Query, - embeddingsResponse.Value.Data[0].Embedding, + embedding.Value.ToFloats(), attribute.MaxKnowledgeCount, connectionInfo); SearchResponse searchResponse = await this.searchProvider.SearchAsync(searchRequest); @@ -95,20 +100,16 @@ async Task ConvertHelperAsync( } // Call the chat API with the new combined prompt to get a response back - OpenAISDK.ChatCompletionsOptions chatCompletionsOptions = new() - { - DeploymentName = attribute.ChatModel, - Messages = + IList messages = new List() { - new OpenAISDK.ChatRequestSystemMessage(promptBuilder.ToString()), - new OpenAISDK.ChatRequestUserMessage(attribute.Query), - } - }; + new 
SystemChatMessage(promptBuilder.ToString()), + new UserChatMessage(attribute.Query), + }; - Response chatResponse = await this.openAIClient.GetChatCompletionsAsync(chatCompletionsOptions); + ClientResult chatResponse = await this.chatClient.CompleteChatAsync(messages); // Give the user the full context, including the embeddings information as well as the chat info - return new SemanticSearchContext(new EmbeddingsContext(embeddingsRequest, embeddingsResponse), chatResponse); + return new SemanticSearchContext(new EmbeddingsContext(new List { attribute.Query }, null), chatResponse); } async Task IAsyncConverter.ConvertAsync( diff --git a/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs b/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs index 2ad00ac8..7e6bf10d 100644 --- a/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Description; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI; @@ -66,20 +66,12 @@ public TextCompletionAttribute(string prompt) [AutoResolve] public string? 
MaxTokens { get; set; } = "100"; - internal ChatCompletionsOptions BuildRequest() + internal ChatCompletionOptions BuildRequest() { - ChatCompletionsOptions request = new() - { - DeploymentName = this.Model, - Messages = - { - new ChatRequestUserMessage(this.Prompt), - } - }; - + ChatCompletionOptions request = new(); if (int.TryParse(this.MaxTokens, out int maxTokens)) { - request.MaxTokens = maxTokens; + request.MaxOutputTokenCount = maxTokens; } if (float.TryParse(this.Temperature, out float temperature)) @@ -89,7 +81,7 @@ internal ChatCompletionsOptions BuildRequest() if (float.TryParse(this.TopP, out float topP)) { - request.NucleusSamplingFactor = topP; + request.TopP = topP; } return request; diff --git a/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs b/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs index b936ea44..8f183be5 100644 --- a/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs @@ -1,11 +1,12 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using Azure; +using System.ClientModel; using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; using Microsoft.Extensions.Logging; using Newtonsoft.Json; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI; @@ -13,12 +14,13 @@ class TextCompletionConverter : IAsyncConverter, IAsyncConverter { - readonly OpenAIClient openAIClient; + readonly ChatClient chatClient; readonly ILogger logger; - public TextCompletionConverter(OpenAIClient openAIClient, ILoggerFactory loggerFactory) + public TextCompletionConverter(AzureOpenAIClient openAIClient, ILoggerFactory loggerFactory) { - this.openAIClient = openAIClient ?? throw new ArgumentNullException(nameof(openAIClient)); + // ToDo: Handle the model retrieval better + this.chatClient = openAIClient.GetChatClient(deploymentName: Environment.GetEnvironmentVariable("CHAT_MODEL_DEPLOYMENT_NAME")) ?? 
throw new ArgumentNullException(nameof(openAIClient)); this.logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); } @@ -43,18 +45,20 @@ async Task ConvertCoreAsync( TextCompletionAttribute attribute, CancellationToken cancellationToken) { - ChatCompletionsOptions options = attribute.BuildRequest(); - this.logger.LogInformation("Sending OpenAI completion request: {request}", options); + ChatCompletionOptions options = attribute.BuildRequest(); + this.logger.LogInformation("Sending OpenAI completion request with prompt: {request}", attribute.Prompt); - Response response = await this.openAIClient.GetChatCompletionsAsync( - options, - cancellationToken); + IList chatMessages = new List() + { + new UserChatMessage(attribute.Prompt) + }; + + ClientResult response = await this.chatClient.CompleteChatAsync(chatMessages, options); string text = string.Join( Environment.NewLine + Environment.NewLine, - response.Value.Choices.Select(choice => choice.Message.Content)); - TextCompletionResponse textCompletionResponse = new(text, response.Value.Usage.TotalTokens); - + response.Value.Content[0].Text); + TextCompletionResponse textCompletionResponse = new(text, response.Value.Usage.TotalTokenCount); return textCompletionResponse; } } diff --git a/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj b/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj index 323285c0..e802f7d3 100644 --- a/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj +++ b/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj @@ -5,7 +5,7 @@ - + @@ -14,4 +14,4 @@ - + \ No newline at end of file From 2424e728b28b6b947a1f4bd6b623430c7b2fee3a Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Fri, 11 Apr 2025 15:42:33 -0700 Subject: [PATCH 02/21] managed identity update for azure openai --- README.md | 75 +++++++-- .../assistant/csharp-legacy/AssistantApis.cs | 2 +- 
.../assistant/csharp-ooproc/AssistantApis.cs | 6 +- samples/chat/README.md | 2 +- samples/chat/csharp-legacy/ChatBot.cs | 4 +- samples/chat/csharp-ooproc/ChatBot.cs | 12 +- .../csharp-legacy/EmbeddingsLegacy.cs | 13 +- .../Embeddings/EmbeddingsGenerator.cs | 6 +- samples/rag-aisearch/README.md | 4 +- .../rag-aisearch/csharp-legacy/FilePrompt.cs | 2 +- .../rag-aisearch/csharp-ooproc/FilePrompt.cs | 2 +- .../rag-cosmosdb/csharp-legacy/FilePrompt.cs | 2 +- .../rag-cosmosdb/csharp-ooproc/FilePrompt.cs | 2 +- samples/rag-kusto/csharp-legacy/FilePrompt.cs | 2 +- .../csharp-ooproc/EmailPromptDemo.cs | 2 +- .../csharp-legacy/TextCompletionLegacy.cs | 4 +- .../csharp-ooproc/TextCompletions.cs | 4 +- .../Assistants/AssistantPostInputAttribute.cs | 58 ++++++- .../Assistants/ChatMessage.cs | 10 +- .../Embeddings/EmbeddingsInputAttribute.cs | 25 ++- .../EmbeddingsStoreOutputAttribute.cs | 33 +++- .../Search/SemanticSearchInputAttribute.cs | 57 ++++++- .../TextCompletionInputAttribute.cs | 23 ++- .../Assistants/AssistantBaseAttribute.cs | 105 +++++++++++++ .../Assistants/AssistantPostAttribute.cs | 19 +-- .../Assistants/AssistantService.cs | 24 ++- .../Embeddings/EmbeddingsAttribute.cs | 3 +- .../Embeddings/EmbeddingsBaseAttribute.cs | 27 +++- .../Embeddings/EmbeddingsConverter.cs | 19 ++- .../Embeddings/EmbeddingsHelper.cs | 3 +- .../Embeddings/EmbeddingsStoreAttribute.cs | 11 +- .../Embeddings/EmbeddingsStoreConverter.cs | 38 +++-- .../Models/AssistantMessage.cs | 10 +- .../OpenAIClientFactory.cs | 144 ++++++++++++++---- .../OpenAIExtension.cs | 8 +- .../OpenAIWebJobsBuilderExtensions.cs | 41 +---- .../Search/SemanticSearchAttribute.cs | 26 ++-- .../Search/SemanticSearchConverter.cs | 26 ++-- .../TextCompletionAttribute.cs | 66 +------- .../TextCompletionConverter.cs | 12 +- tests/SampleValidation/AssistantTests.cs | 18 ++- tests/SampleValidation/Chat.cs | 4 +- tests/SampleValidation/EmbeddingsTests.cs | 83 ++++++++++ 43 files changed, 734 insertions(+), 303 deletions(-) 
create mode 100644 src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs create mode 100644 tests/SampleValidation/EmbeddingsTests.cs diff --git a/README.md b/README.md index 537adecc..55d0440c 100644 --- a/README.md +++ b/README.md @@ -29,18 +29,75 @@ Add following section to `host.json` of the function app for non dotnet language ## Requirements -* [.NET 6 SDK](https://dotnet.microsoft.com/download/dotnet/6.0) or greater (Visual Studio 2022 recommended) +* [.NET 8 SDK](https://dotnet.microsoft.com/download/dotnet/8.0) or greater (Visual Studio 2022 recommended) * [Azure Functions Core Tools v4.x](https://learn.microsoft.com/azure/azure-functions/functions-run-local?tabs=v4%2Cwindows%2Cnode%2Cportal%2Cbash) +* Azure Storage emulator such as [Azurite](https://learn.microsoft.com/azure/storage/common/storage-use-azurite) running in the background +* The target language runtime (e.g. dotnet, nodejs, powershell, python, java) installed on your machine. Refer the official supported versions. * Update settings in Azure Function or the `local.settings.json` file for local development with the following keys: - 1. For Azure, `AZURE_OPENAI_ENDPOINT` - [Azure OpenAI resource](https://learn.microsoft.com/azure/ai-services/openai/how-to/create-resource?pivots=web-portal) (e.g. `https://***.openai.azure.com/`) set. - 1. For Azure, assign the user or function app managed identity `Cognitive Services OpenAI User` role on the Azure OpenAI resource. It is strongly recommended to use managed identity to avoid overhead of secrets maintenance, however if there is a need for key based authentication add the setting `AZURE_OPENAI_KEY` and its value in the settings. - 1. For non- Azure, `OPENAI_API_KEY` - An OpenAI account and an [API key](https://platform.openai.com/account/api-keys) saved into a setting. - If using environment variables, Learn more in [.env readme](./env/README.md). 1. 
Update `CHAT_MODEL_DEPLOYMENT_NAME` and `EMBEDDING_MODEL_DEPLOYMENT_NAME` keys to Azure Deployment names or override default OpenAI model names. - 1. If using user assigned managed identity, add `AZURE_CLIENT_ID` to environment variable settings with value as client id of the managed identity. 1. Visit binding specific samples README for additional settings that might be required for each binding. -* Azure Storage emulator such as [Azurite](https://learn.microsoft.com/azure/storage/common/storage-use-azurite) running in the background -* The target language runtime (e.g. dotnet, nodejs, powershell, python, java) installed on your machine. Refer the official supported versions. + 1. Refer [Configuring AI Service Connections](#configuring-ai-service-connections) + +## Configuring AI Service Connections + +The Azure Functions OpenAI Extension provides flexible options for configuring connections to AI services through the `AIConnectionName` property in the AssistantPost, TextCompletion, SemanticSearch, EmbeddingsStore, Embeddings bindings + +### Managed Identity Role + +Strongly recommended to use managed identity and ensure the user or function app's managed identity has the role - `Cognitive Services OpenAI User` + +### AIConnectionName Property + +The optional `AIConnectionName` property specifies the name of a configuration section that contains connection details for the AI service: + +#### For Azure OpenAI Service + +* If specified, the extension looks for `Endpoint` and `Key` values in the named configuration section +* If not specified or the configuration section doesn't exist, the extension falls back to environment variables: + * `AZURE_OPENAI_ENDPOINT` and/or + * `AZURE_OPENAI_KEY` +* For user-assigned managed identity authentication, a configuration section is required + + ```json + "__endpoint": "Placeholder for the Azure OpenAI endpoint value", + "__credential": "managedidentity", + "__managedIdentityResourceId": "Resource Id of managed identity", + 
"__managedIdentityClientId": "Client Id of managed identity" + ``` + + * Only one of managedIdentityResourceId or managedIdentityClientId should be specified, not both. + * If no Resource Id or Client Id is specified, the system-assigned managed identity will be used by default. + * Pass the configured `ConnectionNamePrefix` value, example `AzureOpenAI` to the `AIConnectionName` property. + +#### For OpenAI Service (non-Azure) + +* Set the `OPENAI_API_KEY` environment variable + +### Configuration Examples + +#### Example: Using a configuration section + +In `local.settings.json` or app environment variables: + +```json +"AzureOpenAI__endpoint": "Placeholder for the Azure OpenAI endpoint value", +"AzureOpenAI__credential": "managedidentity", +``` + +Specifying credential is optional for system assigned managed identity + +Function usage example: + +```csharp +[Function(nameof(PostUserResponse))] +public static IActionResult PostUserResponse( + [HttpTrigger(AuthorizationLevel.Function, "post", Route = "chats/{chatId}")] HttpRequestData req, + string chatId, + [AssistantPostInput("{chatId}", "{Query.message}", "AzureOpenAI", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state) +{ + return new OkObjectResult(state.RecentMessages.LastOrDefault()?.Content ?? "No response returned."); +} +``` ## Features @@ -270,7 +327,7 @@ public static async Task IngestFile( public class EmbeddingsStoreOutputResponse { - [EmbeddingsStoreOutput("{url}", InputType.Url, "AISearchEndpoint", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] + [EmbeddingsStoreOutput("{url}", InputType.Url, "AISearchEndpoint", "openai-index", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] public required SearchableDocument SearchableDocument { get; init; } public IActionResult? 
HttpResponse { get; set; } diff --git a/samples/assistant/csharp-legacy/AssistantApis.cs b/samples/assistant/csharp-legacy/AssistantApis.cs index 16f9d5bb..4bdacd5e 100644 --- a/samples/assistant/csharp-legacy/AssistantApis.cs +++ b/samples/assistant/csharp-legacy/AssistantApis.cs @@ -51,7 +51,7 @@ public static async Task CreateAssistant( public static IActionResult PostUserQuery( [HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = "assistants/{assistantId}")] HttpRequest req, string assistantId, - [AssistantPost("{assistantId}", "{Query.message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState updatedState) + [AssistantPost("{assistantId}", "{Query.message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState updatedState) { return new OkObjectResult(updatedState.RecentMessages.Any() ? updatedState.RecentMessages[updatedState.RecentMessages.Count - 1].Content : "No response returned."); } diff --git a/samples/assistant/csharp-ooproc/AssistantApis.cs b/samples/assistant/csharp-ooproc/AssistantApis.cs index 5d425c8f..a98b6fcd 100644 --- a/samples/assistant/csharp-ooproc/AssistantApis.cs +++ b/samples/assistant/csharp-ooproc/AssistantApis.cs @@ -59,10 +59,10 @@ public class CreateChatBotOutput /// HTTP POST function that sends user prompts to the assistant chat bot. 
/// [Function(nameof(PostUserQuery))] - public static async Task PostUserQuery( + public static IActionResult PostUserQuery( [HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = "assistants/{assistantId}")] HttpRequestData req, string assistantId, - [AssistantPostInput("{assistantId}", "{Query.message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state) + [AssistantPostInput("{assistantId}", "{Query.message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state) { return new OkObjectResult(state.RecentMessages.Any() ? state.RecentMessages[state.RecentMessages.Count - 1].Content : "No response returned."); } @@ -71,7 +71,7 @@ public static async Task PostUserQuery( /// HTTP GET function that queries the conversation history of the assistant chat bot. 
/// [Function(nameof(GetChatState))] - public static async Task GetChatState( + public static IActionResult GetChatState( [HttpTrigger(AuthorizationLevel.Anonymous, "get", Route = "assistants/{assistantId}")] HttpRequestData req, string assistantId, [AssistantQueryInput("{assistantId}", TimestampUtc = "{Query.timestampUTC}", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state) diff --git a/samples/chat/README.md b/samples/chat/README.md index d554a49b..1c030710 100644 --- a/samples/chat/README.md +++ b/samples/chat/README.md @@ -71,7 +71,7 @@ For additional details on using identity-based connections, refer to the [Azure public static async Task PostUserResponse( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "chats/{chatId}")] HttpRequestData req, string chatId, - [AssistantPostInput("{chatId}", "{message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%")] AssistantState state) + [AssistantPostInput("{chatId}", "{message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] AssistantState state) ``` diff --git a/samples/chat/csharp-legacy/ChatBot.cs b/samples/chat/csharp-legacy/ChatBot.cs index e7403df7..358d61d8 100644 --- a/samples/chat/csharp-legacy/ChatBot.cs +++ b/samples/chat/csharp-legacy/ChatBot.cs @@ -48,10 +48,10 @@ public static AssistantState GetChatState( } [FunctionName(nameof(PostUserResponse))] - public static async Task PostUserResponse( + public static IActionResult PostUserResponse( [HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = "chats/{chatId}")] HttpRequest req, string chatId, - [AssistantPost("{chatId}", "{Query.message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState updatedState) + [AssistantPost("{chatId}", "{Query.message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, 
CollectionName = DefaultCollectionName)] AssistantState updatedState) { return new OkObjectResult(updatedState.RecentMessages.Any() ? updatedState.RecentMessages[updatedState.RecentMessages.Count - 1].Content : "No response returned."); } diff --git a/samples/chat/csharp-ooproc/ChatBot.cs b/samples/chat/csharp-ooproc/ChatBot.cs index 2a94b4b4..9db059ce 100644 --- a/samples/chat/csharp-ooproc/ChatBot.cs +++ b/samples/chat/csharp-ooproc/ChatBot.cs @@ -23,8 +23,8 @@ public class CreateRequest [Function(nameof(CreateChatBot))] public static async Task CreateChatBot( - [HttpTrigger(AuthorizationLevel.Function, "put", Route = "chats/{chatId}")] HttpRequestData req, - string chatId) + [HttpTrigger(AuthorizationLevel.Function, "put", Route = "chats/{chatId}")] HttpRequestData req, + string chatId) { var responseJson = new { chatId }; CreateRequest? createRequestBody; @@ -39,7 +39,7 @@ public static async Task CreateChatBot( } catch (Exception ex) { - throw new ArgumentException("Invalid request body. Make sure that you pass in {\"instructions\": value } as the request body.", ex.Message); + throw new ArgumentException("Invalid request body. 
Make sure that you pass in {\"instructions\": value } as the request body.", ex); } return new CreateChatBotOutput @@ -63,16 +63,16 @@ public class CreateChatBotOutput } [Function(nameof(PostUserResponse))] - public static async Task PostUserResponse( + public static IActionResult PostUserResponse( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "chats/{chatId}")] HttpRequestData req, string chatId, - [AssistantPostInput("{chatId}", "{Query.message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state) + [AssistantPostInput("{chatId}", "{Query.message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state) { return new OkObjectResult(state.RecentMessages.LastOrDefault()?.Content ?? "No response returned."); } [Function(nameof(GetChatState))] - public static async Task GetChatState( + public static IActionResult GetChatState( [HttpTrigger(AuthorizationLevel.Function, "get", Route = "chats/{chatId}")] HttpRequestData req, string chatId, [AssistantQueryInput("{chatId}", TimestampUtc = "{Query.timestampUTC}", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state, diff --git a/samples/embeddings/csharp-legacy/EmbeddingsLegacy.cs b/samples/embeddings/csharp-legacy/EmbeddingsLegacy.cs index 58688120..12baeeca 100644 --- a/samples/embeddings/csharp-legacy/EmbeddingsLegacy.cs +++ b/samples/embeddings/csharp-legacy/EmbeddingsLegacy.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
-using System.Linq; using Microsoft.Azure.WebJobs; using Microsoft.Azure.WebJobs.Extensions.Http; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; @@ -23,13 +22,13 @@ public record EmbeddingsRequest(string RawText, string FilePath, string Url); [FunctionName(nameof(GenerateEmbeddings_Http_Request))] public static void GenerateEmbeddings_Http_Request( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings")] EmbeddingsRequest req, - [Embeddings("{RawText}", InputType.RawText)] EmbeddingsContext embeddings, + [Embeddings("{RawText}", InputType.RawText, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings, ILogger logger) { logger.LogInformation( - "Received {count} embedding(s) for input text containing {length} characters.", - embeddings.Count, - req.RawText.Length); + "Received {count} embedding(s) for input text containing {length} characters.", + embeddings.Count, + req.RawText.Length); // TODO: Store the embeddings into a database or other storage. 
} @@ -41,7 +40,7 @@ public static void GenerateEmbeddings_Http_Request( [FunctionName(nameof(GetEmbeddings_Http_FilePath))] public static void GetEmbeddings_Http_FilePath( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings-from-file")] EmbeddingsRequest req, - [Embeddings("{FilePath}", InputType.FilePath, MaxChunkLength = 512)] EmbeddingsContext embeddings, + [Embeddings("{FilePath}", InputType.FilePath, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%", MaxChunkLength = 512)] EmbeddingsContext embeddings, ILogger logger) { logger.LogInformation( @@ -59,7 +58,7 @@ public static void GetEmbeddings_Http_FilePath( [FunctionName(nameof(GetEmbeddings_Http_Url))] public static void GetEmbeddings_Http_Url( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings-from-url")] EmbeddingsRequest req, - [Embeddings("{Url}", InputType.Url, MaxChunkLength = 512)] EmbeddingsContext embeddings, + [Embeddings("{Url}", InputType.Url, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%", MaxChunkLength = 512)] EmbeddingsContext embeddings, ILogger logger) { logger.LogInformation( diff --git a/samples/embeddings/csharp-ooproc/Embeddings/EmbeddingsGenerator.cs b/samples/embeddings/csharp-ooproc/Embeddings/EmbeddingsGenerator.cs index 4b42dcca..84c5fa41 100644 --- a/samples/embeddings/csharp-ooproc/Embeddings/EmbeddingsGenerator.cs +++ b/samples/embeddings/csharp-ooproc/Embeddings/EmbeddingsGenerator.cs @@ -41,7 +41,7 @@ internal class EmbeddingsRequest [Function(nameof(GenerateEmbeddings_Http_RequestAsync))] public async Task GenerateEmbeddings_Http_RequestAsync( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings")] HttpRequestData req, - [EmbeddingsInput("{rawText}", InputType.RawText, Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings) + [EmbeddingsInput("{rawText}", InputType.RawText, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings) { using StreamReader reader = 
new(req.Body); string request = await reader.ReadToEndAsync(); @@ -63,7 +63,7 @@ public async Task GenerateEmbeddings_Http_RequestAsync( [Function(nameof(GetEmbeddings_Http_FilePath))] public async Task GetEmbeddings_Http_FilePath( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings-from-file")] HttpRequestData req, - [EmbeddingsInput("{filePath}", InputType.FilePath, MaxChunkLength = 512, Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings) + [EmbeddingsInput("{filePath}", InputType.FilePath, MaxChunkLength = 512, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings) { using StreamReader reader = new(req.Body); string request = await reader.ReadToEndAsync(); @@ -84,7 +84,7 @@ public async Task GetEmbeddings_Http_FilePath( [Function(nameof(GetEmbeddings_Http_URL))] public async Task GetEmbeddings_Http_URL( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings-from-url")] HttpRequestData req, - [EmbeddingsInput("{url}", InputType.Url, MaxChunkLength = 512, Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings) + [EmbeddingsInput("{url}", InputType.Url, MaxChunkLength = 512, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings) { using StreamReader reader = new(req.Body); string request = await reader.ReadToEndAsync(); diff --git a/samples/rag-aisearch/README.md b/samples/rag-aisearch/README.md index c56c2952..84e6609d 100644 --- a/samples/rag-aisearch/README.md +++ b/samples/rag-aisearch/README.md @@ -35,8 +35,6 @@ and optionally [enable semantic ranking](https://learn.microsoft.com/en-us/azure 1. User-Assigned Managed Identity: - // Note and TODO: Use of user-assigned managed identity for AI Search is breaking at the moment, since Azure OpenAI authentication also needs update to support a separate identity. 
- ```json "__endpoint": "https://.search.windows.net", "__credential": "managedidentity", @@ -56,7 +54,7 @@ and optionally [enable semantic ranking](https://learn.microsoft.com/en-us/azure Specifying credential is optional for system assigned managed identity 4. Binding Configuration - - Pass the configured `ConnectionNamePrefix` value, example `AISearch` to the `connectionName` property in the `SemanticSearchInput` or `EmbeddingsStoreOutput` bindings. Default is `AISearchEndpoint` if just the endpoint is being configured in local.settings.json or environment variables to use DefaultAzureCredential. + Pass the configured `ConnectionNamePrefix` value, example `AISearch` to the `searchConnectionName` property in the `SemanticSearchInput` or `EmbeddingsStoreOutput` bindings. Default is `AISearchEndpoint` if just the endpoint is being configured in local.settings.json or environment variables to use DefaultAzureCredential. ## Running the sample diff --git a/samples/rag-aisearch/csharp-legacy/FilePrompt.cs b/samples/rag-aisearch/csharp-legacy/FilePrompt.cs index df97a125..4c6b7f2f 100644 --- a/samples/rag-aisearch/csharp-legacy/FilePrompt.cs +++ b/samples/rag-aisearch/csharp-legacy/FilePrompt.cs @@ -20,7 +20,7 @@ public record SemanticSearchRequest(string Prompt); [FunctionName("IngestFile")] public static async Task IngestFile( [HttpTrigger(AuthorizationLevel.Function, "post")] EmbeddingsRequest req, - [EmbeddingsStore("{url}", InputType.Url, "AISearchEndpoint", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] + [EmbeddingsStore("{url}", InputType.Url, "AISearchEndpoint", "openai-index", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] IAsyncCollector output) { if (string.IsNullOrWhiteSpace(req.Url)) diff --git a/samples/rag-aisearch/csharp-ooproc/FilePrompt.cs b/samples/rag-aisearch/csharp-ooproc/FilePrompt.cs index d96b1d93..c9d85c6a 100644 --- a/samples/rag-aisearch/csharp-ooproc/FilePrompt.cs +++ 
b/samples/rag-aisearch/csharp-ooproc/FilePrompt.cs @@ -67,7 +67,7 @@ public static async Task IngestFile( public class EmbeddingsStoreOutputResponse { - [EmbeddingsStoreOutput("{url}", InputType.Url, "AISearchEndpoint", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] + [EmbeddingsStoreOutput("{url}", InputType.Url, "AISearchEndpoint", "openai-index", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] public required SearchableDocument SearchableDocument { get; init; } public IActionResult? HttpResponse { get; set; } diff --git a/samples/rag-cosmosdb/csharp-legacy/FilePrompt.cs b/samples/rag-cosmosdb/csharp-legacy/FilePrompt.cs index 3411dcb0..ba4d823c 100644 --- a/samples/rag-cosmosdb/csharp-legacy/FilePrompt.cs +++ b/samples/rag-cosmosdb/csharp-legacy/FilePrompt.cs @@ -20,7 +20,7 @@ public record SemanticSearchRequest(string Prompt); [FunctionName("IngestFile")] public static async Task IngestFile( [HttpTrigger(AuthorizationLevel.Function, "post")] EmbeddingsRequest req, - [EmbeddingsStore("{url}", InputType.Url, "CosmosDBMongoVCoreConnectionString", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] + [EmbeddingsStore("{url}", InputType.Url, "CosmosDBMongoVCoreConnectionString", "openai-index", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] IAsyncCollector output) { if (string.IsNullOrWhiteSpace(req.Url)) diff --git a/samples/rag-cosmosdb/csharp-ooproc/FilePrompt.cs b/samples/rag-cosmosdb/csharp-ooproc/FilePrompt.cs index 09be8588..5e1b964d 100644 --- a/samples/rag-cosmosdb/csharp-ooproc/FilePrompt.cs +++ b/samples/rag-cosmosdb/csharp-ooproc/FilePrompt.cs @@ -66,7 +66,7 @@ public static async Task IngestFile( public class EmbeddingsStoreOutputResponse { - [EmbeddingsStoreOutput("{url}", InputType.Url, "CosmosDBMongoVCoreConnectionString", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] + [EmbeddingsStoreOutput("{url}", InputType.Url, "CosmosDBMongoVCoreConnectionString", "openai-index", EmbeddingsModel = 
"%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] public required SearchableDocument SearchableDocument { get; init; } public IActionResult? HttpResponse { get; set; } diff --git a/samples/rag-kusto/csharp-legacy/FilePrompt.cs b/samples/rag-kusto/csharp-legacy/FilePrompt.cs index 32600947..e63563e7 100644 --- a/samples/rag-kusto/csharp-legacy/FilePrompt.cs +++ b/samples/rag-kusto/csharp-legacy/FilePrompt.cs @@ -20,7 +20,7 @@ public record SemanticSearchRequest(string Prompt); [FunctionName("IngestFile")] public static async Task IngestFile( [HttpTrigger(AuthorizationLevel.Function, "post")] EmbeddingsRequest req, - [EmbeddingsStore("{url}", InputType.Url, "KustoConnectionString", "Documents", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] + [EmbeddingsStore("{url}", InputType.Url, "KustoConnectionString", "Documents", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] IAsyncCollector output) { if (string.IsNullOrWhiteSpace(req.Url)) diff --git a/samples/rag-kusto/csharp-ooproc/EmailPromptDemo.cs b/samples/rag-kusto/csharp-ooproc/EmailPromptDemo.cs index 7b78e9b8..747ef524 100644 --- a/samples/rag-kusto/csharp-ooproc/EmailPromptDemo.cs +++ b/samples/rag-kusto/csharp-ooproc/EmailPromptDemo.cs @@ -66,7 +66,7 @@ public async Task IngestEmail( } public class EmbeddingsStoreOutputResponse { - [EmbeddingsStoreOutput("{url}", InputType.Url, "KustoConnectionString", "Documents", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] + [EmbeddingsStoreOutput("{url}", InputType.Url, "KustoConnectionString", "Documents", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] public required SearchableDocument SearchableDocument { get; init; } [HttpResult] diff --git a/samples/textcompletion/csharp-legacy/TextCompletionLegacy.cs b/samples/textcompletion/csharp-legacy/TextCompletionLegacy.cs index 12d96f83..95fa8567 100644 --- a/samples/textcompletion/csharp-legacy/TextCompletionLegacy.cs +++ b/samples/textcompletion/csharp-legacy/TextCompletionLegacy.cs @@ -24,7 +24,7 @@ public static class 
TextCompletionLegacy [FunctionName(nameof(WhoIs))] public static IActionResult WhoIs( [HttpTrigger(AuthorizationLevel.Function, Route = "whois/{name}")] HttpRequest req, - [TextCompletion("Who is {name}?", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response) + [TextCompletion("Who is {name}?", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response) { return new OkObjectResult(response.Content); } @@ -36,7 +36,7 @@ public static IActionResult WhoIs( [FunctionName(nameof(GenericCompletion))] public static IActionResult GenericCompletion( [HttpTrigger(AuthorizationLevel.Function, "post")] PromptPayload payload, - [TextCompletion("{Prompt}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response, + [TextCompletion("{Prompt}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response, ILogger log) { string text = response.Content; diff --git a/samples/textcompletion/csharp-ooproc/TextCompletions.cs b/samples/textcompletion/csharp-ooproc/TextCompletions.cs index a1a1ef7e..4bdbae7a 100644 --- a/samples/textcompletion/csharp-ooproc/TextCompletions.cs +++ b/samples/textcompletion/csharp-ooproc/TextCompletions.cs @@ -23,7 +23,7 @@ public static class TextCompletions [Function(nameof(WhoIs))] public static IActionResult WhoIs( [HttpTrigger(AuthorizationLevel.Function, Route = "whois/{name}")] HttpRequestData req, - [TextCompletionInput("Who is {name}?", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response) + [TextCompletionInput("Who is {name}?", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response) { return new OkObjectResult(response.Content); } @@ -35,7 +35,7 @@ public static IActionResult WhoIs( [Function(nameof(GenericCompletion))] public static IActionResult GenericCompletion( [HttpTrigger(AuthorizationLevel.Function, "post")] HttpRequestData req, - [TextCompletionInput("{Prompt}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse 
response, + [TextCompletionInput("{Prompt}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response, ILogger log) { string text = response.Content; diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs index fda594e7..cf6f8179 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs @@ -10,12 +10,36 @@ namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Assistants; /// public sealed class AssistantPostInputAttribute : InputBindingAttribute { - public AssistantPostInputAttribute(string id, string UserMessage) + /// + /// Initializes a new instance of the class. + /// + /// The assistant identifier. + /// The user message. + /// The name of the configuration section for AI service connectivity settings. + public AssistantPostInputAttribute(string id, string userMessage, string aiConnectionName = "") { this.Id = id; - this.UserMessage = UserMessage; + this.UserMessage = userMessage; + this.AIConnectionName = aiConnectionName; } + /// + /// Gets or sets the name of the configuration section for AI service connectivity settings. + /// + /// + /// This property specifies the name of the configuration section that contains connection details for the AI service. + /// + /// For Azure OpenAI: + /// - If specified, looks for "Endpoint" and "Key" values in this configuration section + /// - If not specified or the section doesn't exist, falls back to environment variables: + /// AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY + /// - For user-assigned managed identity authentication, this property is required + /// + /// For OpenAI: + /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. 
+ /// + public string AIConnectionName { get; set; } + /// /// Gets the ID of the assistant to update. /// @@ -27,7 +51,7 @@ public AssistantPostInputAttribute(string id, string UserMessage) /// /// When using Azure OpenAI, then should be the name of the model deployment. /// - public string? Model { get; set; } + public string? ChatModel { get; set; } /// /// Gets user message that user has entered for assistant to respond to. @@ -43,4 +67,32 @@ public AssistantPostInputAttribute(string id, string UserMessage) /// Table collection name for chat storage. /// public string CollectionName { get; set; } = "ChatState"; + + /// + /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output + /// more random, while lower values like 0.2 will make it more focused and deterministic. + /// + /// + /// It's generally recommend to use this or but not both. + /// + public string? Temperature { get; set; } = "0.5"; + + /// + /// Gets or sets an alternative to sampling with temperature, called nucleus sampling, where the model considers + /// the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + /// probability mass are considered. + /// + /// + /// It's generally recommend to use this or but not both. + /// + public string? TopP { get; set; } + + /// + /// Gets or sets the maximum number of tokens to output in the completion. Default is 100. + /// + /// + /// The token count of your prompt plus max_tokens cannot exceed the model's context length. + /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096). + /// + public string? 
MaxTokens { get; set; } = "100"; } diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs index bbd14017..73595e69 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs @@ -15,10 +15,12 @@ public class ChatMessage /// /// The content of the message. /// The role of the chat agent. - public ChatMessage(string content, string role, string? name) + /// The tool calls. + public ChatMessage(string content, string role, string toolCalls) { this.Content = content; this.Role = role; + this.ToolCalls = toolCalls; } /// @@ -32,4 +34,10 @@ public ChatMessage(string content, string role, string? name) /// [JsonPropertyName("role")] public string Role { get; set; } + + /// + /// Gets or sets the tool calls. + /// + [JsonPropertyName("toolCalls")] + public string ToolCalls { get; set; } } diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs index f9a2d073..5c2465f1 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs @@ -8,24 +8,43 @@ namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; public class EmbeddingsInputAttribute : InputBindingAttribute { /// - /// Initializes a new instance of the class with the specified input. + /// Initializes a new instance of the class with the specified input. /// /// The input source containing the data to generate embeddings for. /// The type of the input. + /// The name of the configuration section for AI service connectivity settings. /// Thrown if is null. 
- public EmbeddingsInputAttribute(string input, InputType inputType) + public EmbeddingsInputAttribute(string input, InputType inputType, string aiConnectionName = "") { this.Input = input ?? throw new ArgumentNullException(nameof(input)); this.InputType = inputType; + this.AIConnectionName = aiConnectionName; } + /// + /// Gets or sets the name of the configuration section for AI service connectivity settings. + /// + /// + /// This property specifies the name of the configuration section that contains connection details for the AI service. + /// + /// For Azure OpenAI: + /// - If specified, looks for "Endpoint" and "Key" values in this configuration section + /// - If not specified or the section doesn't exist, falls back to environment variables: + /// AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY + /// - For user-assigned managed identity authentication, this property is required + /// + /// For OpenAI: + /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. + /// + public string AIConnectionName { get; set; } + /// /// Gets or sets the ID of the model to use. /// /// /// Changing the default embeddings model is a breaking change, since any changes will be stored in a vector database for lookup. Changing the default model can cause the lookups to start misbehaving if they don't match the data that was previously ingested into the vector database. /// - public string Model { get; set; } = OpenAIModels.DefaultEmbeddingsModel; + public string EmbeddingsModel { get; set; } = OpenAIModels.DefaultEmbeddingsModel; /// /// Gets or sets the maximum number of characters to chunk the input into. 
diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs index 4be0d841..bbcaf326 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs @@ -14,28 +14,47 @@ public class EmbeddingsStoreOutputAttribute : OutputBindingAttribute /// The input source containing the data to generate embeddings for /// and is interpreted based on the value for . /// The type of the input. - /// - /// The name of an app setting or environment variable which contains a connection string value. + /// + /// The name of an app setting or environment variable which contains a connection string value for embedding store. /// /// The name of the collection or table to search or store. + /// The name of the configuration section for AI service connectivity settings. /// - /// Thrown if or or are null. + /// Thrown if or or are null. /// - public EmbeddingsStoreOutputAttribute(string input, InputType inputType, string connectionName, string collection) + public EmbeddingsStoreOutputAttribute(string input, InputType inputType, string storeConnectionName, string collection, string aiConnectionName = "") { this.Input = input ?? throw new ArgumentNullException(nameof(input)); this.InputType = inputType; - this.ConnectionName = connectionName ?? throw new ArgumentNullException(nameof(connectionName)); + this.StoreConnectionName = storeConnectionName ?? throw new ArgumentNullException(nameof(storeConnectionName)); this.Collection = collection ?? throw new ArgumentNullException(nameof(collection)); + this.AIConnectionName = aiConnectionName; } + /// + /// Gets or sets the name of the configuration section for AI service connectivity settings. 
+ /// + /// + /// This property specifies the name of the configuration section that contains connection details for the AI service. + /// + /// For Azure OpenAI: + /// - If specified, looks for "Endpoint" and "Key" values in this configuration section + /// - If not specified or the section doesn't exist, falls back to environment variables: + /// AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY + /// - For user-assigned managed identity authentication, configuration section is required + /// + /// For OpenAI: + /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. + /// + public string AIConnectionName { get; set; } + /// /// Gets or sets the ID of the model to use. /// /// /// Changing the default embeddings model is a breaking change, since any changes will be stored in a vector database for lookup. Changing the default model can cause the lookups to start misbehaving if they don't match the data that was previously ingested into the vector database. /// - public string Model { get; set; } = OpenAIModels.DefaultEmbeddingsModel; + public string EmbeddingsModel { get; set; } = OpenAIModels.DefaultEmbeddingsModel; /// /// Gets or sets the maximum number of characters to chunk the input into. @@ -70,7 +89,7 @@ public EmbeddingsStoreOutputAttribute(string input, InputType inputType, string /// /// This property supports binding expressions. /// - public string ConnectionName { get; set; } + public string StoreConnectionName { get; set; } /// /// The name of the collection or table to search. 
diff --git a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs index e660d709..a4a50f61 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs @@ -14,26 +14,45 @@ public sealed class SemanticSearchInputAttribute : InputBindingAttribute /// Initializes a new instance of the class with the specified connection /// and collection names. /// - /// + /// /// The name of an app setting or environment variable which contains a connection string value. /// /// The name of the collection or table to search or store. + /// The name of the configuration section for AI service connectivity settings. /// - /// Thrown if either or are null. + /// Thrown if either or are null. /// - public SemanticSearchInputAttribute(string connectionName, string collection) + public SemanticSearchInputAttribute(string searchConnectionName, string collection, string aiConnectionName = "") { - this.ConnectionName = connectionName ?? throw new ArgumentNullException(nameof(connectionName)); + this.SearchConnectionName = searchConnectionName ?? throw new ArgumentNullException(nameof(searchConnectionName)); this.Collection = collection ?? throw new ArgumentNullException(nameof(collection)); + this.AIConnectionName = aiConnectionName; } + /// + /// Gets or sets the name of the configuration section for AI service connectivity settings. + /// + /// + /// This property specifies the name of the configuration section that contains connection details for the AI service. 
+ /// + /// For Azure OpenAI: + /// - If specified, looks for "Endpoint" and "Key" values in this configuration section + /// - If not specified or the section doesn't exist, falls back to environment variables: + /// AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY + /// - For user-assigned managed identity authentication, this property is required + /// + /// For OpenAI: + /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. + /// + public string AIConnectionName { get; set; } + /// /// Gets or sets the name of an app setting or environment variable which contains a connection string value. /// /// /// This property supports binding expressions. /// - public string ConnectionName { get; set; } + public string SearchConnectionName { get; set; } /// /// The name of the collection or table to search. @@ -98,4 +117,32 @@ The following is a list of documents that you can refer to when answering questi /// Gets or sets the number of knowledge items to inject into the . /// public int MaxKnowledgeCount { get; set; } = 1; + + /// + /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output + /// more random, while lower values like 0.2 will make it more focused and deterministic. + /// + /// + /// It's generally recommend to use this or but not both. + /// + public string? Temperature { get; set; } = "0.5"; + + /// + /// Gets or sets an alternative to sampling with temperature, called nucleus sampling, where the model considers + /// the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + /// probability mass are considered. + /// + /// + /// It's generally recommend to use this or but not both. + /// + public string? TopP { get; set; } + + /// + /// Gets or sets the maximum number of tokens to output in the completion. Default is 2048. + /// + /// + /// The token count of your prompt plus max_tokens cannot exceed the model's context length. 
+ /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096). + /// + public string? MaxTokens { get; set; } = "2048"; } diff --git a/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs index 463d1e96..d7c49268 100644 --- a/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs @@ -14,11 +14,30 @@ public sealed class TextCompletionInputAttribute : InputBindingAttribute /// Initializes a new instance of the class with the specified text prompt. /// /// The prompt to generate completions for, encoded as a string. - public TextCompletionInputAttribute(string prompt) + /// The name of the configuration section for AI service connectivity settings. + public TextCompletionInputAttribute(string prompt, string aiConnectionName = "") { this.Prompt = prompt ?? throw new ArgumentNullException(nameof(prompt)); + this.AIConnectionName = aiConnectionName; } + /// + /// Gets or sets the name of the configuration section for AI service connectivity settings. + /// + /// + /// This property specifies the name of the configuration section that contains connection details for the AI service. + /// + /// For Azure OpenAI: + /// - If specified, looks for "Endpoint" and "Key" values in this configuration section + /// - If not specified or the section doesn't exist, falls back to environment variables: + /// AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY + /// - For user-assigned managed identity authentication, this property is required + /// + /// For OpenAI: + /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. + /// + public string AIConnectionName { get; set; } + /// /// Gets or sets the prompt to generate completions for, encoded as a string. 
/// @@ -27,7 +46,7 @@ public TextCompletionInputAttribute(string prompt) /// /// Gets or sets the ID of the model to use. /// - public string Model { get; set; } = "gpt-3.5-turbo"; + public string ChatModel { get; set; } = "gpt-3.5-turbo"; /// /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs new file mode 100644 index 00000000..4c08ea38 --- /dev/null +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Azure.WebJobs.Description; +using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; +using OpenAI.Chat; + +namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; + +/// +/// Binding attribute for the OpenAI Assistant extension. This attribute is used to specify the configuration +/// settings for the OpenAI Assistant when used in a function parameter. It allows to set various chat completion options. +/// +[Binding] +[AttributeUsage(AttributeTargets.Parameter)] +public class AssistantBaseAttribute : Attribute +{ + /// + /// Initializes a new instance of the class. + /// + /// The name of the configuration section for AI service connectivity settings. + public AssistantBaseAttribute(string aiConnectionName = "") + { + this.AIConnectionName = aiConnectionName; + } + + /// + /// Gets or sets the name of the Large Language Model to invoke for chat responses. + /// The default value is "gpt-3.5-turbo". + /// + /// + /// This property supports binding expressions. + /// + [AutoResolve] + public string ChatModel { get; set; } = OpenAIModels.DefaultChatModel; + + /// + /// Gets or sets the name of the configuration section for AI service connectivity settings. 
+ /// + /// + /// This property specifies the name of the configuration section that contains connection details for the AI service. + /// + /// For Azure OpenAI: + /// - If specified, looks for "Endpoint" and "Key" values in this configuration section + /// - If not specified or the section doesn't exist, falls back to environment variables: + /// AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY + /// - For user-assigned managed identity authentication, configuration section is required + /// + /// For OpenAI: + /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. + /// + public string AIConnectionName { get; set; } + + /// + /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output + /// more random, while lower values like 0.2 will make it more focused and deterministic. + /// + /// + /// It's generally recommend to use this or but not both. + /// + [AutoResolve] + public string? Temperature { get; set; } = "0.5"; + + /// + /// Gets or sets an alternative to sampling with temperature, called nucleus sampling, where the model considers + /// the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + /// probability mass are considered. + /// + /// + /// It's generally recommend to use this or but not both. + /// + [AutoResolve] + public string? TopP { get; set; } + + /// + /// Gets or sets the maximum number of tokens to output in the completion. Default value = 100. + /// + /// + /// The token count of your prompt plus max_tokens cannot exceed the model's context length. + /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096). + /// + [AutoResolve] + public string? 
MaxTokens { get; set; } = "100"; + + internal ChatCompletionOptions BuildRequest() + { + ChatCompletionOptions request = new(); + if (int.TryParse(this.MaxTokens, out int maxTokens)) + { + request.MaxOutputTokenCount = maxTokens; + } + + if (float.TryParse(this.Temperature, out float temperature)) + { + request.Temperature = temperature; + } + + if (float.TryParse(this.TopP, out float topP)) + { + request.TopP = topP; + } + + return request; + } +} diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs index 56ed35fa..c2b42734 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs @@ -7,9 +7,15 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; [Binding] [AttributeUsage(AttributeTargets.Parameter)] -public sealed class AssistantPostAttribute : Attribute +public sealed class AssistantPostAttribute : AssistantBaseAttribute { - public AssistantPostAttribute(string id, string userMessage) + /// + /// Initializes a new instance of the class. + /// + /// The assistant identifier. + /// The user message. + /// The name of the configuration section for AI service connectivity settings. + public AssistantPostAttribute(string id, string userMessage, string aiConnectionName = "") : base(aiConnectionName) { this.Id = id; this.UserMessage = userMessage; @@ -21,15 +27,6 @@ public AssistantPostAttribute(string id, string userMessage) [AutoResolve] public string Id { get; } - /// - /// Gets or sets the OpenAI chat model to use. - /// - /// - /// When using Azure OpenAI, then should be the name of the model deployment. - /// - [AutoResolve] - public string? Model { get; set; } - /// /// Gets or sets the user message to OpenAI. 
/// diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs index 56f4abdb..8f8baf22 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs @@ -3,7 +3,6 @@ using System.ClientModel; using Azure; -using Azure.AI.OpenAI; using Azure.Data.Tables; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; using Microsoft.Extensions.Azure; @@ -30,7 +29,7 @@ record InternalChatState(string Id, AssistantStateEntity Metadata, List const int FunctionCallBatchLimit = 50; const string DefaultChatStorage = "AzureWebJobsStorage"; - readonly ChatClient chatClient; + readonly OpenAIClientFactory openAIClientFactory; readonly IAssistantSkillInvoker skillInvoker; readonly ILogger logger; readonly AzureComponentFactory azureComponentFactory; @@ -39,7 +38,7 @@ record InternalChatState(string Id, AssistantStateEntity Metadata, List(); + this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory)); this.azureComponentFactory = azureComponentFactory ?? throw new ArgumentNullException(nameof(azureComponentFactory)); this.configuration = configuration ?? throw new ArgumentNullException(nameof(configuration)); } @@ -176,7 +172,7 @@ public async Task GetStateAsync(AssistantQueryAttribute assistan chatState.Metadata.LastUpdatedAt, chatState.Metadata.TotalMessages, chatState.Metadata.TotalTokens, - filteredChatMessages.Select(msg => new AssistantMessage(msg.Content, msg.Role)).ToList()); + filteredChatMessages.Select(msg => new AssistantMessage(msg.Content, msg.Role, msg.ToolCallsString)).ToList()); return state; } @@ -223,14 +219,13 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib // Add the chat message to the batch batch.Add(new TableTransactionAction(TableTransactionActionType.Add, chatMessageEntity)); - string deploymentName = attribute.Model ?? 
OpenAIModels.DefaultChatModel; IList? functions = this.skillInvoker.GetFunctionsDefinitions(); // We loop if the model returns function calls. Otherwise, we break after receiving a response. while (true) { // Get the next response from the LLM - ChatCompletionOptions chatRequest = new ChatCompletionOptions(); + ChatCompletionOptions chatRequest = attribute.BuildRequest(); if (functions is not null) { foreach (ChatTool fn in functions) @@ -238,11 +233,12 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib chatRequest.Tools.Add(fn); } } - chatRequest.ToolChoice = ChatToolChoice.CreateAutoChoice(); + IEnumerable chatMessages = ToOpenAIChatRequestMessages(chatState.Messages); - // ToDo: Pass more ChatCompletionOptions like TextCompletion - ClientResult response = await this.chatClient.CompleteChatAsync(chatMessages, chatRequest); + ClientResult response = await this.openAIClientFactory.GetChatClient( + attribute.AIConnectionName, + attribute.ChatModel).CompleteChatAsync(chatMessages, chatRequest, cancellationToken: cancellationToken); // We don't normally expect more than one message, but just in case we get multiple messages, // return all of them separated by two newlines. 
@@ -391,7 +387,7 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib chatState.Metadata.LastUpdatedAt, chatState.Metadata.TotalMessages, chatState.Metadata.TotalTokens, - filteredChatMessages.Select(msg => new AssistantMessage(msg.Content, msg.Role)).ToList()); + filteredChatMessages.Select(msg => new AssistantMessage(msg.Content, msg.Role, msg.ToolCallsString)).ToList()); return state; } diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsAttribute.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsAttribute.cs index bd15b6ae..8c8c733b 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsAttribute.cs @@ -21,8 +21,9 @@ public sealed class EmbeddingsAttribute : EmbeddingsBaseAttribute /// /// The input source containing the data to generate embeddings for. /// The type of the input. + /// The name of the configuration section for AI service connectivity settings. /// Thrown if is null. - public EmbeddingsAttribute(string input, InputType inputType) : base(input, inputType) + public EmbeddingsAttribute(string input, InputType inputType, string aiConnectionName = "") : base(input, inputType, aiConnectionName) { } } diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsBaseAttribute.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsBaseAttribute.cs index 7873474f..25f516b1 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsBaseAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsBaseAttribute.cs @@ -22,13 +22,34 @@ public class EmbeddingsBaseAttribute : Attribute /// /// The input source containing the data to generate embeddings for. /// The type of the input. + /// The name of the configuration section for AI service connectivity settings. /// Thrown if is null. 
- public EmbeddingsBaseAttribute(string input, InputType inputType) + public EmbeddingsBaseAttribute(string input, InputType inputType, string aiConnectionName = "") { - this.Input = input ?? throw new ArgumentNullException(nameof(input)); + this.Input = string.IsNullOrEmpty(input) + ? throw new ArgumentException("Input cannot be null or empty.", nameof(input)) + : input; this.InputType = inputType; + this.AIConnectionName = aiConnectionName; } + /// + /// Gets or sets the name of the configuration section for AI service connectivity settings. + /// + /// + /// This property specifies the name of the configuration section that contains connection details for the AI service. + /// + /// For Azure OpenAI: + /// - If specified, looks for "Endpoint" and "Key" values in this configuration section + /// - If not specified or the section doesn't exist, falls back to environment variables: + /// AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY + /// - For user-assigned managed identity authentication, configuration section is required + /// + /// For OpenAI: + /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. + /// + public string AIConnectionName { get; set; } + /// /// Gets or sets the ID of the model to use. /// @@ -36,7 +57,7 @@ public EmbeddingsBaseAttribute(string input, InputType inputType) /// Changing the default embeddings model is a breaking change, since any changes will be stored in a vector database for lookup. Changing the default model can cause the lookups to start misbehaving if they don't match the data that was previously ingested into the vector database. /// [AutoResolve] - public string Model { get; set; } = OpenAIModels.DefaultEmbeddingsModel; + public string EmbeddingsModel { get; set; } = OpenAIModels.DefaultEmbeddingsModel; /// /// Gets or sets the maximum number of characters to chunk the input into. 
diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs index 96f6fca3..df7d1d87 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs @@ -3,7 +3,6 @@ using System.ClientModel; using System.Text.Json; -using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; using Microsoft.Extensions.Logging; using OpenAI.Embeddings; @@ -14,7 +13,7 @@ class EmbeddingsConverter : IAsyncConverter, IAsyncConverter { - readonly EmbeddingClient embeddingClient; + readonly OpenAIClientFactory openAIClientFactory; readonly ILogger logger; // Note: we need this converter as Azure.AI.OpenAI does not support System.Text.Json serialization since their constructors are internal @@ -23,10 +22,11 @@ class EmbeddingsConverter : Converters = { new EmbeddingsContextConverter(), new SearchableDocumentJsonConverter() } }; - public EmbeddingsConverter(AzureOpenAIClient azureOpenAIClient, ILoggerFactory loggerFactory) + public EmbeddingsConverter( + OpenAIClientFactory openAIClientFactory, + ILoggerFactory loggerFactory) { - // ToDo: Handle the deployment name retrieval better - this.embeddingClient = azureOpenAIClient.GetEmbeddingClient("embedding") ?? throw new ArgumentNullException(nameof(azureOpenAIClient)); + this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory)); this.logger = loggerFactory?.CreateLogger() ?? 
throw new ArgumentNullException(nameof(loggerFactory)); } @@ -49,9 +49,14 @@ async Task ConvertCoreAsync( EmbeddingsAttribute attribute, CancellationToken cancellationToken) { - List input = await EmbeddingsHelper.BuildRequest(attribute.MaxOverlap, attribute.MaxChunkLength, attribute.Model, attribute.InputType, attribute.Input); + List input = await EmbeddingsHelper.BuildRequest(attribute.MaxOverlap, + attribute.MaxChunkLength, + attribute.InputType, + attribute.Input); this.logger.LogInformation("Sending OpenAI embeddings request"); - ClientResult response = await this.embeddingClient.GenerateEmbeddingsAsync(input); + ClientResult response = await this.openAIClientFactory.GetEmbeddingClient( + attribute.AIConnectionName, + attribute.EmbeddingsModel).GenerateEmbeddingsAsync(input, cancellationToken: cancellationToken); this.logger.LogInformation("Received OpenAI embeddings count: {response}", response.Value.Count); return new EmbeddingsContext(input, response); diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs index 319c2820..5ce78d9e 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. 
using System.Diagnostics; -using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; static class EmbeddingsHelper @@ -17,7 +16,7 @@ static EmbeddingsHelper() httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(UserAgent); } - public static async Task> BuildRequest(int maxOverlap, int maxChunkLength, string model, InputType inputType, string input) + public static async Task> BuildRequest(int maxOverlap, int maxChunkLength, InputType inputType, string input) { using TextReader reader = await GetTextReader(inputType, input); if (maxOverlap >= maxChunkLength) diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreAttribute.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreAttribute.cs index 1355404b..def0ddb2 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreAttribute.cs @@ -16,26 +16,27 @@ public sealed class EmbeddingsStoreAttribute : EmbeddingsBaseAttribute /// The input source containing the data to generate embeddings for /// and is interpreted based on the value for . /// The type of the input. - /// + /// /// The name of an app setting or environment variable which contains a connection string value. /// /// The name of the collection or table to search or store. + /// The name of the configuration section for AI service connectivity settings. /// /// Thrown if or or are null. /// - public EmbeddingsStoreAttribute(string input, InputType inputType, string connectionName, string collection) : base(input, inputType) + public EmbeddingsStoreAttribute(string input, InputType inputType, string storeConnectionName, string collection, string aiConnectionName = "") : base(input, inputType, aiConnectionName) { - this.ConnectionName = connectionName ?? throw new ArgumentNullException(nameof(connectionName)); + this.StoreConnectionName = storeConnectionName ?? 
throw new ArgumentNullException(nameof(storeConnectionName)); this.Collection = collection ?? throw new ArgumentNullException(nameof(collection)); } /// - /// Gets or sets the name of an app setting or environment variable which contains a connection string value. + /// Gets or sets the name of an app setting or environment variable which contains a connection string value for embedding store. /// /// /// This property supports binding expressions. /// - public string ConnectionName { get; set; } + public string StoreConnectionName { get; set; } /// /// The name of the collection or table to search. diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs index b995152b..7cd02216 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs @@ -3,19 +3,17 @@ using System.ClientModel; using System.Text.Json; -using Azure; -using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using OpenAI.Embeddings; -using OpenAISDK = Azure.AI.OpenAI; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; class EmbeddingsStoreConverter : IAsyncConverter> { - readonly EmbeddingClient embeddingClient; readonly ILogger logger; + readonly OpenAIClientFactory openAIClientFactory; + // NOTE(review): removed duplicate 'readonly ILogger logger;' added here; the field is already declared above readonly ISearchProvider?
searchProvider; // Note: we need this converter as Azure.AI.OpenAI does not support System.Text.Json serialization since their constructors are internal @@ -24,13 +22,13 @@ class EmbeddingsStoreConverter : Converters = { new EmbeddingsContextConverter(), new SearchableDocumentJsonConverter() } }; - public EmbeddingsStoreConverter(AzureOpenAIClient azureOpenAIClient, + public EmbeddingsStoreConverter( + OpenAIClientFactory openAIClientFactory, ILoggerFactory loggerFactory, IEnumerable searchProviders, IOptions openAiConfigOptions) { - // ToDo: Handle the deployment name retrieval better - this.embeddingClient = azureOpenAIClient.GetEmbeddingClient("embedding") ?? throw new ArgumentNullException(nameof(azureOpenAIClient)); + this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory)); this.logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); openAiConfigOptions.Value.SearchProvider.TryGetValue("type", out object value); this.searchProvider = searchProviders? @@ -44,7 +42,7 @@ public Task> ConvertAsync(EmbeddingsStoreAtt throw new InvalidOperationException( "No search provider is configured. Search providers are configured in the host.json file. 
For .NET apps, the appropriate nuget package must also be added to the app's project file."); } - IAsyncCollector collector = new SemanticDocumentCollector(input, this.searchProvider, this.embeddingClient, this.logger); + IAsyncCollector collector = new SemanticDocumentCollector(input, this.searchProvider, this.openAIClientFactory, this.logger); return Task.FromResult(collector); } @@ -65,20 +63,23 @@ sealed class SemanticDocumentCollector : IAsyncCollector { readonly EmbeddingsStoreAttribute attribute; readonly ISearchProvider searchProvider; - readonly EmbeddingClient embeddingClient; + readonly OpenAIClientFactory openAIClientFactory; readonly ILogger logger; - public SemanticDocumentCollector(EmbeddingsStoreAttribute attribute, ISearchProvider searchProvider, EmbeddingClient embeddingClient, ILogger logger) + public SemanticDocumentCollector(EmbeddingsStoreAttribute attribute, + ISearchProvider searchProvider, + OpenAIClientFactory openAIClientFactory, + ILogger logger) { this.attribute = attribute; this.searchProvider = searchProvider; - this.embeddingClient = embeddingClient; + this.openAIClientFactory = openAIClientFactory; this.logger = logger; } public async Task AddAsync(SearchableDocument item, CancellationToken cancellationToken = default) { - if (string.IsNullOrEmpty(this.attribute.ConnectionName)) + if (string.IsNullOrEmpty(this.attribute.StoreConnectionName)) { throw new InvalidOperationException("No connection string information was provided."); } @@ -88,15 +89,20 @@ public async Task AddAsync(SearchableDocument item, CancellationToken cancellati } // Get embeddings from OpenAI - List input = await EmbeddingsHelper.BuildRequest(this.attribute.MaxOverlap, this.attribute.MaxChunkLength, this.attribute.Model, this.attribute.InputType, this.attribute.Input); + List input = await EmbeddingsHelper.BuildRequest(this.attribute.MaxOverlap, + this.attribute.MaxChunkLength, + this.attribute.InputType, + this.attribute.Input); 
this.logger.LogInformation("Sending OpenAI embeddings request"); - ClientResult response = await this.embeddingClient.GenerateEmbeddingsAsync(input); - EmbeddingsContext embeddingsContext = new (input, response); + ClientResult response = await this.openAIClientFactory.GetEmbeddingClient( + this.attribute.AIConnectionName, + this.attribute.EmbeddingsModel).GenerateEmbeddingsAsync(input, cancellationToken: cancellationToken); + EmbeddingsContext embeddingsContext = new(input, response); this.logger.LogInformation("Received OpenAI embeddings of count: {count}", embeddingsContext.Count); // Add document to the embed store item.Embeddings = embeddingsContext; - item.ConnectionInfo = new ConnectionInfo(this.attribute.ConnectionName, this.attribute.Collection); + item.ConnectionInfo = new ConnectionInfo(this.attribute.StoreConnectionName, this.attribute.Collection); this.logger.LogInformation("Adding document to the embed store."); await this.searchProvider.AddDocumentAsync(item, cancellationToken); this.logger.LogInformation("Finished adding document to the embed store."); diff --git a/src/WebJobs.Extensions.OpenAI/Models/AssistantMessage.cs b/src/WebJobs.Extensions.OpenAI/Models/AssistantMessage.cs index 6ae82c58..a6166d1f 100644 --- a/src/WebJobs.Extensions.OpenAI/Models/AssistantMessage.cs +++ b/src/WebJobs.Extensions.OpenAI/Models/AssistantMessage.cs @@ -16,10 +16,12 @@ public class AssistantMessage /// /// The content of the message. /// The role of the chat agent. - public AssistantMessage(string content, string role) + /// The tool calls. + public AssistantMessage(string content, string role, string toolCalls) { this.Content = content; this.Role = role; + this.ToolCalls = toolCalls; } /// @@ -33,4 +35,10 @@ public AssistantMessage(string content, string role) /// [JsonProperty("role")] public string Role { get; set; } + + /// + /// Gets the tool calls.
+ /// + [JsonProperty("toolCalls")] + public string ToolCalls { get; } } diff --git a/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs b/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs index 059ec8fe..180b508b 100644 --- a/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs +++ b/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs @@ -4,58 +4,144 @@ using System.Collections.Concurrent; using Azure; using Azure.AI.OpenAI; +using Azure.Core; using Azure.Identity; -using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Azure; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; using OpenAI; +using OpenAI.Chat; +using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI; public class OpenAIClientFactory { - static readonly ConcurrentDictionary _azureOpenAIclients = new(); - static readonly ConcurrentDictionary _openAIClients = new(); + readonly IConfiguration configuration; + readonly AzureComponentFactory azureComponentFactory; + readonly ILogger logger; + readonly ConcurrentDictionary azureOpenAIclients = new(); + readonly ConcurrentDictionary openAIClients = new(); + readonly ConcurrentDictionary chatClients = new(); // key is ai endpoint, model + readonly ConcurrentDictionary embeddingClients = new(); // key is ai endpoint, model + string aiEndpoint = string.Empty; - public static AzureOpenAIClient CreateAzureOpenAIClient(string endpoint, string apiKey) + public OpenAIClientFactory( + IConfiguration configuration, + AzureComponentFactory azureComponentFactory, + ILoggerFactory loggerFactory) { - string key = $"{endpoint}-{apiKey}"; - return _azureOpenAIclients.GetOrAdd(key, _ => new AzureOpenAIClient(new Uri(endpoint), new AzureKeyCredential(apiKey))); + this.configuration = configuration ?? throw new ArgumentNullException(nameof(configuration)); + this.azureComponentFactory = azureComponentFactory ?? 
throw new ArgumentNullException(nameof(azureComponentFactory)); + this.logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); } - public static AzureOpenAIClient CreateAzureOpenAIClientWithDefaultAzureCredential(string endpoint) + public ChatClient GetChatClient(string aiConnectionName, string model) { - return _azureOpenAIclients.GetOrAdd(endpoint, _ => new AzureOpenAIClient(new Uri(endpoint), new DefaultAzureCredential())); + HasOpenAIKey(out bool hasOpenAIKey, out string openAIKey); + ChatClient chatClient; + (chatClient, string endpoint, string chatModel) = this.chatClients.GetOrAdd( + hasOpenAIKey ? "OpenAI" : aiConnectionName, + name => + { + if (!hasOpenAIKey) + { + AzureOpenAIClient azureOpenAIClient = this.CreateClientFromConfigSection(aiConnectionName); + return (azureOpenAIClient.GetChatClient(model), this.aiEndpoint, model); + } + else + { + OpenAIClient openAIClient = this.CreateOpenAIClient(openAIKey); + return (openAIClient.GetChatClient(model), this.aiEndpoint, model); + } + }); + + return chatClient; } - public static OpenAIClient CreateOpenAIClient(string apiKey) + public EmbeddingClient GetEmbeddingClient(string aiConnectionName, string model) { - return _openAIClients.GetOrAdd(apiKey, _ => new OpenAIClient(apiKey)); + HasOpenAIKey(out bool hasOpenAIKey, out string openAIKey); + EmbeddingClient embeddingClient; + (embeddingClient, string endpoint, string embeddingModel) = this.embeddingClients.GetOrAdd( + hasOpenAIKey ? 
"OpenAI" : aiConnectionName, + name => + { + if (!hasOpenAIKey) + { + AzureOpenAIClient azureOpenAIClient = this.CreateClientFromConfigSection(aiConnectionName); + return (azureOpenAIClient.GetEmbeddingClient(model), this.aiEndpoint, model); + } + else + { + OpenAIClient openAIClient = this.CreateOpenAIClient(openAIKey); + return (openAIClient.GetEmbeddingClient(model), this.aiEndpoint, model); + } + }); + + return embeddingClient; } -} -public static class ServiceCollectionExtensions -{ - public static IServiceCollection AddAzureOpenAIClient( - this IServiceCollection services, - string endpoint, - string apiKey) + static void HasOpenAIKey(out bool hasOpenAIKey, out string openAIKey) + { + hasOpenAIKey = false; + openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + if (!string.IsNullOrEmpty(openAIKey)) + { + hasOpenAIKey = true; + } + } + + AzureOpenAIClient CreateClientFromConfigSection(string aiConnectionName) + { + IConfigurationSection section = this.configuration.GetSection(aiConnectionName); + + if (!section.Exists()) + { + this.logger.LogInformation($"Configuration section for Azure OpenAI not found, trying fallback to environment variables - AZURE_OPENAI_ENDPOINT and/or AZURE_OPENAI_KEY"); + } + + this.aiEndpoint = section?.GetValue("Endpoint") ?? Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); + string? azureOpenAIKey = section?.GetValue("Key") ?? Environment.GetEnvironmentVariable("AZURE_OPENAI_KEY"); + + if (!string.IsNullOrEmpty(this.aiEndpoint)) + { + this.logger.LogInformation($"Using Azure OpenAI endpoint: {this.aiEndpoint}"); + if (!string.IsNullOrEmpty(azureOpenAIKey)) + { + this.logger.LogInformation($"Authenticating using Azure OpenAI Key."); + return this.CreateAzureOpenAIClient(this.aiEndpoint, azureOpenAIKey); + } + else + { + this.logger.LogInformation($"Authenticating using Azure OpenAI TokenCredential."); + + TokenCredential tokenCredential = section.Exists() ? 
+ this.azureComponentFactory.CreateTokenCredential(section) : + new DefaultAzureCredential(); + return this.CreateAzureOpenAIClientWithTokenCredential(this.aiEndpoint, tokenCredential); + } + } + + string errorMessage = $"Configuration section '{aiConnectionName}' is missing required 'Endpoint' or 'Key' values."; + this.logger.LogError(errorMessage); + throw new InvalidOperationException(errorMessage); + } + + AzureOpenAIClient CreateAzureOpenAIClient(string endpoint, string apiKey) { - services.AddSingleton(sp => OpenAIClientFactory.CreateAzureOpenAIClient(endpoint, apiKey)); - return services; + string key = $"{endpoint}-{apiKey}"; + return this.azureOpenAIclients.GetOrAdd(key, _ => new AzureOpenAIClient(new Uri(endpoint), new AzureKeyCredential(apiKey))); } - public static IServiceCollection AddAzureOpenAIClientWithDefaultAzureCredential( - this IServiceCollection services, - string endpoint) + AzureOpenAIClient CreateAzureOpenAIClientWithTokenCredential(string endpoint, TokenCredential tokenCredential) { - services.AddSingleton(sp => OpenAIClientFactory.CreateAzureOpenAIClientWithDefaultAzureCredential(endpoint)); - return services; + return this.azureOpenAIclients.GetOrAdd(endpoint, _ => new AzureOpenAIClient(new Uri(endpoint), tokenCredential)); } - public static IServiceCollection AddOpenAIClient( - this IServiceCollection services, - string apiKey) + OpenAIClient CreateOpenAIClient(string openAIKey) { - services.AddSingleton(sp => OpenAIClientFactory.CreateOpenAIClient(apiKey)); - return services; + this.logger.LogInformation($"Authenticating using OpenAI Key."); + return this.openAIClients.GetOrAdd(openAIKey, _ => new OpenAIClient(openAIKey)); } } \ No newline at end of file diff --git a/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs b/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs index 16e82ed0..d7fa61ca 100644 --- a/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs +++ b/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs @@ -15,7 +15,7 @@ namespace 
Microsoft.Azure.WebJobs.Extensions.OpenAI; [Extension("OpenAI")] partial class OpenAIExtension : IExtensionConfigProvider { - readonly AzureOpenAIClient openAIClient; + readonly OpenAIClientFactory openAIClientFactory; readonly TextCompletionConverter textCompletionConverter; readonly EmbeddingsConverter embeddingsConverter; readonly EmbeddingsStoreConverter embeddingsStoreConverter; @@ -24,7 +24,7 @@ partial class OpenAIExtension : IExtensionConfigProvider readonly AssistantSkillTriggerBindingProvider assistantskillTriggerBindingProvider; public OpenAIExtension( - AzureOpenAIClient openAIClient, + OpenAIClientFactory openAIClientFactory, TextCompletionConverter textCompletionConverter, EmbeddingsConverter embeddingsConverter, EmbeddingsStoreConverter embeddingsStoreConverter, @@ -32,7 +32,7 @@ public OpenAIExtension( AssistantBindingConverter assistantConverter, AssistantSkillTriggerBindingProvider assistantTriggerBindingProvider) { - this.openAIClient = openAIClient ?? throw new ArgumentNullException(nameof(openAIClient)); + this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory)); this.textCompletionConverter = textCompletionConverter ?? throw new ArgumentNullException(nameof(textCompletionConverter)); this.embeddingsConverter = embeddingsConverter ?? throw new ArgumentNullException(nameof(embeddingsConverter)); this.embeddingsStoreConverter = embeddingsStoreConverter ?? throw new ArgumentNullException(nameof(embeddingsStoreConverter)); @@ -82,6 +82,6 @@ void IExtensionConfigProvider.Initialize(ExtensionConfigContext context) .BindToTrigger(this.assistantskillTriggerBindingProvider); // OpenAI service input binding support (NOTE: This may be removed in a future version.) 
- context.AddBindingRule().BindToInput(_ => this.openAIClient); + context.AddBindingRule().BindToInput(_ => this.openAIClientFactory); } } diff --git a/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs b/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs index a1fd498d..733d7956 100644 --- a/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs +++ b/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs @@ -28,31 +28,9 @@ public static IWebJobsBuilder AddOpenAIBindings(this IWebJobsBuilder builder) throw new ArgumentNullException(nameof(builder)); } - // Register the client for Azure Open AI - string? azureOpenAIEndpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); - string? openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); - string? azureOpenAIKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_KEY"); - - if (azureOpenAIEndpoint != null && !string.IsNullOrEmpty(azureOpenAIKey)) - { - RegisterAzureOpenAIClient(builder.Services, azureOpenAIEndpoint, azureOpenAIKey); - } - else if (azureOpenAIEndpoint != null) - { - RegisterAzureOpenAIADAuthClient(builder.Services, azureOpenAIEndpoint); - } - else if (!string.IsNullOrEmpty(openAIKey)) - { - RegisterOpenAIClient(builder.Services, openAIKey); - } - else - { - throw new InvalidOperationException("Must set AZURE_OPENAI_ENDPOINT or OPENAI_API_KEY environment variables."); - } - // Register the WebJobs extension, which enables the bindings. 
builder.AddExtension(); - + // Service objects that will be used by the extension builder.Services.AddSingleton(); builder.Services.AddSingleton(); @@ -74,21 +52,8 @@ public static IWebJobsBuilder AddOpenAIBindings(this IWebJobsBuilder builder) builder.Services.AddAzureClientsCore(); // Adds AzureComponentFactory - return builder; - } - - static void RegisterAzureOpenAIClient(IServiceCollection services, string azureOpenAIEndpoint, string azureOpenAIKey) - { - services.AddAzureOpenAIClient(azureOpenAIEndpoint, azureOpenAIKey); - } + builder.Services.AddSingleton(); - static void RegisterAzureOpenAIADAuthClient(IServiceCollection services, string azureOpenAIEndpoint) - { - services.AddAzureOpenAIClientWithDefaultAzureCredential(azureOpenAIEndpoint); - } - - static void RegisterOpenAIClient(IServiceCollection services, string openAIKey) - { - services.AddOpenAIClient(openAIKey); + return builder; } } diff --git a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchAttribute.cs b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchAttribute.cs index 70b3e742..3edbc5f1 100644 --- a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchAttribute.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using Microsoft.Azure.WebJobs.Description; +using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; @@ -11,22 +12,23 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; /// [Binding] [AttributeUsage(AttributeTargets.Parameter)] -public sealed class SemanticSearchAttribute : Attribute +public sealed class SemanticSearchAttribute : AssistantBaseAttribute { /// /// Initializes a new instance of the class with the specified connection /// and collection names. /// - /// - /// The name of an app setting or environment variable which contains a connection string value. 
+ /// + /// The name of an app setting or environment variable which contains a connection string value of search provider. /// /// The name of the collection or table to search or store. + /// The name of the configuration section for AI service connectivity settings. /// - /// Thrown if either or are null. + /// Thrown if either or are null. /// - public SemanticSearchAttribute(string connectionName, string collection) + public SemanticSearchAttribute(string searchConnectionName, string collection, string aiConnectionName = "") : base(aiConnectionName) { - this.ConnectionName = connectionName ?? throw new ArgumentNullException(nameof(connectionName)); + this.SearchConnectionName = searchConnectionName ?? throw new ArgumentNullException(nameof(searchConnectionName)); this.Collection = collection ?? throw new ArgumentNullException(nameof(collection)); } @@ -37,7 +39,7 @@ public SemanticSearchAttribute(string connectionName, string collection) /// This property supports binding expressions. /// [AutoResolve] - public string ConnectionName { get; set; } + public string SearchConnectionName { get; set; } /// /// The name of the collection or table or index to search. @@ -68,16 +70,6 @@ public SemanticSearchAttribute(string connectionName, string collection) [AutoResolve] public string EmbeddingsModel { get; set; } = OpenAIModels.DefaultEmbeddingsModel; - /// - /// Gets or sets the name of the Large Language Model to invoke for chat responses. - /// The default value is "gpt-3.5-turbo". - /// - /// - /// This property supports binding expressions. - /// - [AutoResolve] - public string ChatModel { get; set; } = OpenAIModels.DefaultChatModel; - /// /// Gets or sets the system prompt to use for prompting the large language model. 
/// diff --git a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs index b3799884..3385ca3d 100644 --- a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs @@ -4,16 +4,12 @@ using System.ClientModel; using System.Text; using System.Text.Json; -using Azure; -using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using OpenAI; using OpenAI.Chat; using OpenAI.Embeddings; -using OpenAISDK = Azure.AI.OpenAI; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; @@ -21,30 +17,27 @@ class SemanticSearchConverter : IAsyncConverter, IAsyncConverter { - readonly ChatClient chatClient; - readonly EmbeddingClient embeddingClient; + readonly OpenAIClientFactory openAIClientFactory; readonly ILogger logger; readonly ISearchProvider? searchProvider; static readonly JsonSerializerOptions options = new() { - Converters = { + Converters = { new SearchableDocumentJsonConverter(), new EmbeddingsContextConverter(), new EmbeddingsOptionsJsonConverter(), new ChatCompletionsJsonConverter()} - }; + }; public SemanticSearchConverter( - AzureOpenAIClient azureOpenAIClient, + OpenAIClientFactory openAIClientFactory, ILoggerFactory loggerFactory, IEnumerable searchProviders, IOptions openAiConfigOptions) { - //ToDo: Handle retrieval of the model better - this.chatClient = azureOpenAIClient.GetChatClient(deploymentName: Environment.GetEnvironmentVariable("CHAT_MODEL_DEPLOYMENT_NAME")) ?? throw new ArgumentNullException(nameof(azureOpenAIClient)); - this.embeddingClient = azureOpenAIClient.GetEmbeddingClient(deploymentName: Environment.GetEnvironmentVariable("EMBEDDING_MODEL_DEPLOYMENT_NAME")) ?? 
throw new ArgumentNullException(nameof(azureOpenAIClient)); this.logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); + this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory)); openAiConfigOptions.Value.SearchProvider.TryGetValue("type", out object value); this.logger.LogInformation("Type of the searchProvider configured in host file: {type}", value); @@ -70,10 +63,12 @@ async Task ConvertHelperAsync( // Get the embeddings for the query, which will be used for doing a semantic search this.logger.LogInformation("Sending OpenAI embeddings request: {request}", attribute.Query); - ClientResult embedding = await this.embeddingClient.GenerateEmbeddingAsync(attribute.Query, cancellationToken: cancellationToken); + ClientResult embedding = await this.openAIClientFactory.GetEmbeddingClient( + attribute.AIConnectionName, + attribute.EmbeddingsModel).GenerateEmbeddingAsync(attribute.Query, cancellationToken: cancellationToken); this.logger.LogInformation("Received OpenAI embeddings"); - ConnectionInfo connectionInfo = new(attribute.ConnectionName, attribute.Collection); + ConnectionInfo connectionInfo = new(attribute.SearchConnectionName, attribute.Collection); if (string.IsNullOrEmpty(connectionInfo.ConnectionName)) { throw new InvalidOperationException("No connection string information was provided."); @@ -106,7 +101,8 @@ async Task ConvertHelperAsync( new UserChatMessage(attribute.Query), }; - ClientResult chatResponse = await this.chatClient.CompleteChatAsync(messages); + ChatCompletionOptions completionOptions = attribute.BuildRequest(); + ClientResult chatResponse = await this.openAIClientFactory.GetChatClient(attribute.AIConnectionName, attribute.ChatModel).CompleteChatAsync(messages, completionOptions); // Give the user the full context, including the embeddings information as well as the chat info return new SemanticSearchContext(new EmbeddingsContext(new List { 
attribute.Query }, null), chatResponse); diff --git a/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs b/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs index 7e6bf10d..0d319171 100644 --- a/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs @@ -2,8 +2,7 @@ // Licensed under the MIT License. using Microsoft.Azure.WebJobs.Description; -using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; -using OpenAI.Chat; +using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI; @@ -12,13 +11,14 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI; /// [Binding] [AttributeUsage(AttributeTargets.Parameter)] -public sealed class TextCompletionAttribute : Attribute +public sealed class TextCompletionAttribute : AssistantBaseAttribute { /// /// Initializes a new instance of the class with the specified text prompt. /// /// The prompt to generate completions for, encoded as a string. - public TextCompletionAttribute(string prompt) + /// The name of the configuration section for AI service connectivity settings. + public TextCompletionAttribute(string prompt, string aiConnectionName = "") : base(aiConnectionName) { this.Prompt = prompt ?? throw new ArgumentNullException(nameof(prompt)); } @@ -28,62 +28,4 @@ public TextCompletionAttribute(string prompt) /// [AutoResolve] public string Prompt { get; } - - /// - /// Gets or sets the ID of the model to use. - /// - [AutoResolve] - public string Model { get; set; } = OpenAIModels.DefaultChatModel; - - /// - /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output - /// more random, while lower values like 0.2 will make it more focused and deterministic. - /// - /// - /// It's generally recommend to use this or but not both. - /// - [AutoResolve] - public string? 
Temperature { get; set; } = "0.5"; - - /// - /// Gets or sets an alternative to sampling with temperature, called nucleus sampling, where the model considers - /// the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% - /// probability mass are considered. - /// - /// - /// It's generally recommend to use this or but not both. - /// - [AutoResolve] - public string? TopP { get; set; } - - /// - /// Gets or sets the maximum number of tokens to generate in the completion. - /// - /// - /// The token count of your prompt plus max_tokens cannot exceed the model's context length. - /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096). - /// - [AutoResolve] - public string? MaxTokens { get; set; } = "100"; - - internal ChatCompletionOptions BuildRequest() - { - ChatCompletionOptions request = new(); - if (int.TryParse(this.MaxTokens, out int maxTokens)) - { - request.MaxOutputTokenCount = maxTokens; - } - - if (float.TryParse(this.Temperature, out float temperature)) - { - request.Temperature = temperature; - } - - if (float.TryParse(this.TopP, out float topP)) - { - request.TopP = topP; - } - - return request; - } } diff --git a/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs b/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs index 8f183be5..d489ad65 100644 --- a/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs @@ -2,7 +2,6 @@ // Licensed under the MIT License. 
using System.ClientModel; -using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; using Microsoft.Extensions.Logging; using Newtonsoft.Json; @@ -14,13 +13,12 @@ class TextCompletionConverter : IAsyncConverter, IAsyncConverter { - readonly ChatClient chatClient; + readonly OpenAIClientFactory openAIClientFactory; readonly ILogger logger; - public TextCompletionConverter(AzureOpenAIClient openAIClient, ILoggerFactory loggerFactory) + public TextCompletionConverter(OpenAIClientFactory openAIClientFactory, ILoggerFactory loggerFactory) { - // ToDo: Handle the model retrieval better - this.chatClient = openAIClient.GetChatClient(deploymentName: Environment.GetEnvironmentVariable("CHAT_MODEL_DEPLOYMENT_NAME")) ?? throw new ArgumentNullException(nameof(openAIClient)); + this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory)); this.logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); } @@ -53,7 +51,9 @@ async Task ConvertCoreAsync( new UserChatMessage(attribute.Prompt) }; - ClientResult response = await this.chatClient.CompleteChatAsync(chatMessages, options); + ClientResult response = await this.openAIClientFactory.GetChatClient( + attribute.AIConnectionName, + attribute.ChatModel).CompleteChatAsync(chatMessages, options); string text = string.Join( Environment.NewLine + Environment.NewLine, diff --git a/tests/SampleValidation/AssistantTests.cs b/tests/SampleValidation/AssistantTests.cs index e6eb44f6..fc53f512 100644 --- a/tests/SampleValidation/AssistantTests.cs +++ b/tests/SampleValidation/AssistantTests.cs @@ -71,7 +71,7 @@ public async Task AddTodoTest() Assert.StartsWith("text/plain", questionResponse.Content.Headers.ContentType?.MediaType); // Ensure that the model responded and mentioned the new todo item. 
- await ValidateAssistantResponseAsync(expectedMessageCount: 4, expectedContent: "Buy milk", hasTotalTokens: true); + await ValidateAssistantResponseAsync(expectedMessageCount: 5, expectedContent: "Buy milk", hasTotalTokens: true); // Local function to validate each chat bot response async Task ValidateAssistantResponseAsync(int expectedMessageCount, string expectedContent, bool hasTotalTokens = false) @@ -107,7 +107,7 @@ async Task ValidateAssistantResponseAsync(int expectedMessageCount, string expec // The timestamp filter should ensure we only ever look at the most recent messages Assert.True(messageArray!.Count <= totalMessages); - Assert.True(messageArray!.Count <= 4); + Assert.True(messageArray!.Count <= 5); if (totalMessages >= expectedMessageCount) { @@ -118,13 +118,23 @@ async Task ValidateAssistantResponseAsync(int expectedMessageCount, string expec { // Make sure the first message is the system message JsonNode systemMessage = messageArray!.First()!; - Assert.Equal("system", systemMessage["role"]?.GetValue()); + Assert.Equal("system", systemMessage["role"]?.GetValue(), StringComparer.OrdinalIgnoreCase); } else { + // Validate the third message contains the toolcalls string with task description + if (messageArray.Count >= 2) + { + JsonNode thirdMessage = messageArray![1]!; + Assert.Equal("assistant", thirdMessage["role"]?.GetValue(), StringComparer.OrdinalIgnoreCase); + string? 
thirdMessageToolCalls = thirdMessage["toolCalls"]?.GetValue(); + Assert.NotNull(thirdMessageToolCalls); + Assert.Contains("AddTodo", thirdMessageToolCalls, StringComparison.OrdinalIgnoreCase); + } + // Make sure that the last message is from the chat bot (assistant) JsonNode lastMessage = messageArray![messageArray.Count - 1]!; - Assert.Equal("assistant", lastMessage["role"]?.GetValue()); + Assert.Equal("assistant", lastMessage["role"]?.GetValue(), StringComparer.OrdinalIgnoreCase); Assert.Contains("Buy milk", lastMessage!["content"]?.GetValue()); } diff --git a/tests/SampleValidation/Chat.cs b/tests/SampleValidation/Chat.cs index e6719e72..386b1f85 100644 --- a/tests/SampleValidation/Chat.cs +++ b/tests/SampleValidation/Chat.cs @@ -125,13 +125,13 @@ async Task ValidateAssistantResponseAsync(int expectedMessageCount, string expec { // Make sure the first message is the system message JsonNode systemMessage = messageArray!.First()!; - Assert.Equal("system", systemMessage["role"]?.GetValue()); + Assert.Equal("system", systemMessage["role"]?.GetValue(), StringComparer.OrdinalIgnoreCase); } else { // Make sure that the last message is from the chat bot (assistant) JsonNode lastMessage = messageArray![messageArray.Count - 1]!; - Assert.Equal("assistant", lastMessage["role"]?.GetValue()); + Assert.Equal("assistant", lastMessage["role"]?.GetValue(), StringComparer.OrdinalIgnoreCase); Assert.StartsWith("Yo!", lastMessage!["content"]?.GetValue()); } diff --git a/tests/SampleValidation/EmbeddingsTests.cs b/tests/SampleValidation/EmbeddingsTests.cs new file mode 100644 index 00000000..4e45fe69 --- /dev/null +++ b/tests/SampleValidation/EmbeddingsTests.cs @@ -0,0 +1,83 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +using System.Diagnostics; +using System.Net; +using System.Net.Http.Json; +using Xunit; +using Xunit.Abstractions; + +namespace SampleValidation; + +public class EmbeddingsTests +{ + readonly ITestOutputHelper output; + readonly string baseAddress; + readonly HttpClient client; + readonly CancellationTokenSource cts; + + public EmbeddingsTests(ITestOutputHelper output) + { + this.output = output; + this.client = new HttpClient(new LoggingHandler(this.output)); + this.cts = new CancellationTokenSource(delay: TimeSpan.FromMinutes(Debugger.IsAttached ? 5 : 1)); + + this.baseAddress = Environment.GetEnvironmentVariable("FUNC_BASE_ADDRESS") ?? "http://localhost:7071"; + +#if RELEASE + // Use the default key for the Azure Functions app in RELEASE mode; for local development, DEBUG mode can be used. + string functionKey = Environment.GetEnvironmentVariable("FUNC_DEFAULT_KEY") ?? throw new InvalidOperationException("Missing environment variable 'FUNC_DEFAULT_KEY'"); + client.DefaultRequestHeaders.Add("x-functions-key", functionKey); +#endif + } + + [Fact] + public async Task GenerateEmbeddings_Http_Request_Test() + { + // Arrange + var request = new + { + rawText = "This is a test for generating embeddings from raw text." 
+ }; + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/embeddings", + request, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.NoContent, response.StatusCode); + } + + [Fact] + public async Task EmbeddingsLegacy_Performance_Test() + { + // Create a large text for testing performance + string largeText = new string('A', 10000); // 10KB text + + var request = new + { + rawText = largeText + }; + + // Measure time + var stopwatch = Stopwatch.StartNew(); + + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/embeddings", + request, + cancellationToken: this.cts.Token); + + stopwatch.Stop(); + + // Assert + Assert.Equal(HttpStatusCode.NoContent, response.StatusCode); + + // Log performance metrics + this.output.WriteLine($"Embeddings generation for 10KB text took {stopwatch.ElapsedMilliseconds}ms"); + + // If specific performance SLAs exist, add assertions like: + // Assert.True(stopwatch.ElapsedMilliseconds < 5000, "Embeddings generation took too long"); + } +} \ No newline at end of file From f46d77f996aade635b0bcf5184bcca0be601c6d7 Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Fri, 11 Apr 2025 16:04:26 -0700 Subject: [PATCH 03/21] add text completion tests --- .../TextCompletionAttribute.cs | 4 +- tests/SampleValidation/TextCompletionTests.cs | 131 ++++++++++++++++++ 2 files changed, 134 insertions(+), 1 deletion(-) create mode 100644 tests/SampleValidation/TextCompletionTests.cs diff --git a/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs b/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs index 0d319171..aac4242e 100644 --- a/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs @@ -20,7 +20,9 @@ public sealed class TextCompletionAttribute : AssistantBaseAttribute /// The 
name of the configuration section for AI service connectivity settings. public TextCompletionAttribute(string prompt, string aiConnectionName = "") : base(aiConnectionName) { - this.Prompt = prompt ?? throw new ArgumentNullException(nameof(prompt)); + this.Prompt = string.IsNullOrEmpty(prompt) + ? throw new ArgumentException("Input cannot be null or empty.", nameof(prompt)) + : prompt; } /// diff --git a/tests/SampleValidation/TextCompletionTests.cs b/tests/SampleValidation/TextCompletionTests.cs new file mode 100644 index 00000000..2cb18ac4 --- /dev/null +++ b/tests/SampleValidation/TextCompletionTests.cs @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Net; +using System.Net.Http.Json; +using Xunit; +using Xunit.Abstractions; + +namespace SampleValidation; + +public class TextCompletionLegacyTests +{ + readonly ITestOutputHelper output; + readonly HttpClient client; + readonly string baseAddress; + readonly CancellationTokenSource cts; + + public TextCompletionLegacyTests(ITestOutputHelper output) + { + this.output = output; + this.client = new HttpClient(new LoggingHandler(this.output)); + this.baseAddress = Environment.GetEnvironmentVariable("FUNC_BASE_ADDRESS") ?? "http://localhost:7071"; + this.cts = new CancellationTokenSource(delay: TimeSpan.FromMinutes(Debugger.IsAttached ? 5 : 1)); + +#if RELEASE + // Use the default key for the Azure Functions app in RELEASE mode; for local development, DEBUG mode can be used. + string functionKey = Environment.GetEnvironmentVariable("FUNC_DEFAULT_KEY") ?? 
throw new InvalidOperationException("Missing environment variable 'FUNC_DEFAULT_KEY'"); + this.client.DefaultRequestHeaders.Add("x-functions-key", functionKey); +#endif + } + + [Fact] + public async Task WhoIs_ValidName_ReturnsContent() + { + // Arrange + string name = "Albert Einstein"; + + // Act + using HttpResponseMessage response = await this.client.GetAsync( + requestUri: $"{this.baseAddress}/api/whois/{name}", + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + string content = await response.Content.ReadAsStringAsync(this.cts.Token); + this.output.WriteLine($"Response content: {content}"); + + // Validate that content is not empty and contains some relevant information about Albert Einstein + Assert.NotEmpty(content); + Assert.Contains("physicist", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("relativity", content, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task GenericCompletion_ValidPrompt_ReturnsContent() + { + // Arrange + var promptPayload = new { Prompt = "Write a haiku about programming" }; + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/GenericCompletion", + promptPayload, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + string content = await response.Content.ReadAsStringAsync(this.cts.Token); + this.output.WriteLine($"Response content: {content}"); + + // Validate that content is not empty and resembles a haiku structure + Assert.NotEmpty(content); + + // Verify we received some form of haiku (might contain line breaks) + // We can't verify exact content since AI responses vary, but we can check for reasonable length + // and common haiku-related formatting + Assert.True(content.Length >= 10, "Response should contain a meaningful haiku"); + Assert.True(content.Split('\n').Length >= 1, "Haiku should have at least 
one line"); + } + + [Fact] + public async Task GenericCompletion_EmptyPrompt_ReturnsError() + { + // Arrange + var emptyPrompt = new { Prompt = string.Empty }; + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/GenericCompletion", + emptyPrompt, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + } + + [Fact] + public async Task GenericCompletion_ComplexPrompt_ReturnsValidContent() + { + // Arrange + var complexPrompt = new + { + Prompt = "Explain the difference between synchronous and asynchronous programming in 3 bullet points" + }; + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/GenericCompletion", + complexPrompt, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + string content = await response.Content.ReadAsStringAsync(this.cts.Token); + this.output.WriteLine($"Response content: {content}"); + + // Validate that content is not empty and contains relevant terms + Assert.NotEmpty(content); + Assert.Contains("synchronous", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("asynchronous", content, StringComparison.OrdinalIgnoreCase); + + // Check for bullet point formatting (could be -, *, or numbered) + bool hasBulletFormat = content.Contains('-') || + content.Contains('*') || + content.Contains("1.") || + content.Contains('•'); + + Assert.True(hasBulletFormat, "Response should contain bullet points"); + } +} From a0c58cf26515bd0ea85bc261500049c58cb6e383 Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Fri, 11 Apr 2025 16:43:55 -0700 Subject: [PATCH 04/21] add file prompts tests --- tests/SampleValidation/FilePromptsTests.cs | 76 ++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 
tests/SampleValidation/FilePromptsTests.cs diff --git a/tests/SampleValidation/FilePromptsTests.cs b/tests/SampleValidation/FilePromptsTests.cs new file mode 100644 index 00000000..efc3ef68 --- /dev/null +++ b/tests/SampleValidation/FilePromptsTests.cs @@ -0,0 +1,76 @@ +using System.Diagnostics; +using System.Net; +using System.Net.Http.Json; +using System.Text.Json.Nodes; +using Xunit; +using Xunit.Abstractions; + +namespace SampleValidation; + +public class FilePromptTests +{ + readonly ITestOutputHelper output; + readonly string baseAddress; + readonly HttpClient client; + readonly CancellationTokenSource cts; + + public FilePromptTests(ITestOutputHelper output) + { + this.output = output; + this.client = new HttpClient(new LoggingHandler(this.output)); + this.cts = new CancellationTokenSource(delay: TimeSpan.FromMinutes(Debugger.IsAttached ? 5 : 1)); + + this.baseAddress = Environment.GetEnvironmentVariable("FUNC_BASE_ADDRESS") ?? "http://localhost:7071"; + +#if RELEASE + // Use the default key for the Azure Functions app in RELEASE mode; for local development, DEBUG mode can be used. + string functionKey = Environment.GetEnvironmentVariable("FUNC_DEFAULT_KEY") ?? 
throw new InvalidOperationException("Missing environment variable 'FUNC_DEFAULT_KEY'"); + client.DefaultRequestHeaders.Add("x-functions-key", functionKey); +#endif + } + + [Fact] + public async Task IngestFile_ValidUrl_ReturnsSuccess() + { + // Prepare the request + var request = new { url = "https://github.com/Azure/azure-functions-openai-extension/blob/main/README.md" }; + + // Send the POST request to IngestFile + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/IngestFile", + request, + cancellationToken: this.cts.Token); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.StartsWith("application/json", response.Content.Headers.ContentType?.MediaType); + + // Validate the response content + string responseContent = await response.Content.ReadAsStringAsync(this.cts.Token); + JsonNode? jsonResponse = JsonNode.Parse(responseContent); + Assert.NotNull(jsonResponse); + Assert.Equal("success", jsonResponse!["status"]?.GetValue()); + Assert.Equal("README.md", jsonResponse!["title"]?.GetValue()); + } + + [Fact] + public async Task PromptFile_ValidPrompt_ReturnsResponse() + { + // Prepare the request + var request = new { prompt = "How can the textCompletion input binding be used from Azure Functions OpenAI extension?" 
}; + + // Send the POST request to PromptFile + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/PromptFile", + request, + cancellationToken: this.cts.Token); + + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.StartsWith("text/plain", response.Content.Headers.ContentType?.MediaType); + + // Validate the response content + string responseContent = await response.Content.ReadAsStringAsync(this.cts.Token); + Assert.False(string.IsNullOrWhiteSpace(responseContent)); + Assert.Contains("OpenAI Chat Completions API", responseContent, StringComparison.OrdinalIgnoreCase); + Assert.Contains("README", responseContent, StringComparison.OrdinalIgnoreCase); + } +} \ No newline at end of file From 05c61a746d2a95c81dcb045b02b58033d755907e Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Sun, 13 Apr 2025 01:05:05 -0700 Subject: [PATCH 05/21] refactor and add unit tests --- OpenAI-Extension.sln | 7 + eng/ci/templates/build-local.yml | 4 + .../csharp-legacy/AssistantSample.csproj | 2 +- .../csharp-ooproc/AssistantSample.csproj | 4 +- samples/chat/csharp-ooproc/ChatBot.csproj | 2 +- .../Embeddings/Embeddings.csproj | 2 +- .../SemanticAISearchEmbeddings.csproj | 2 +- .../SemanticCosmosDBSearchEmbeddings.csproj | 2 +- .../SemanticSearchEmbeddings.csproj | 2 +- .../csharp-ooproc/TextCompletion.csproj | 2 +- .../Functions.Worker.Extensions.OpenAI.csproj | 2 +- .../Assistants/AssistantService.cs | 357 +++++++----- .../Embeddings/EmbeddingsHelper.cs | 6 + .../Models/ChatMessageTableEntity.cs | 13 +- .../OpenAIClientFactory.cs | 2 +- .../OpenAIExtension.cs | 1 - .../Properties/AssemblyInfo.cs | 7 + .../TextCompletionConverter.cs | 2 +- .../WebJobs.Extensions.OpenAI.csproj | 2 +- tests/SampleValidation/EmbeddingsTests.cs | 66 ++- tests/SampleValidation/FilePromptsTests.cs | 56 +- tests/UnitTests/AssistantServiceTests.cs | 520 ++++++++++++++++++ 
tests/UnitTests/OpenAIClientFactoryTests.cs | 383 +++++++++++++ tests/UnitTests/WebJobsOpenAIUnitTests.csproj | 29 + 24 files changed, 1278 insertions(+), 197 deletions(-) create mode 100644 src/WebJobs.Extensions.OpenAI/Properties/AssemblyInfo.cs create mode 100644 tests/UnitTests/AssistantServiceTests.cs create mode 100644 tests/UnitTests/OpenAIClientFactoryTests.cs create mode 100644 tests/UnitTests/WebJobsOpenAIUnitTests.csproj diff --git a/OpenAI-Extension.sln b/OpenAI-Extension.sln index 6c86f9cc..2fee6c2d 100644 --- a/OpenAI-Extension.sln +++ b/OpenAI-Extension.sln @@ -118,6 +118,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "KustoSearchLegacy", "sample EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TextCompletionLegacy", "samples\textcompletion\csharp-legacy\TextCompletionLegacy.csproj", "{73C26271-7B8D-4B38-B454-D9902B92270F}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "WebJobsOpenAIUnitTests", "tests\UnitTests\WebJobsOpenAIUnitTests.csproj", "{52337999-5676-27FD-AE3D-CAAE406AE337}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -216,6 +218,10 @@ Global {73C26271-7B8D-4B38-B454-D9902B92270F}.Debug|Any CPU.Build.0 = Debug|Any CPU {73C26271-7B8D-4B38-B454-D9902B92270F}.Release|Any CPU.ActiveCfg = Release|Any CPU {73C26271-7B8D-4B38-B454-D9902B92270F}.Release|Any CPU.Build.0 = Release|Any CPU + {52337999-5676-27FD-AE3D-CAAE406AE337}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {52337999-5676-27FD-AE3D-CAAE406AE337}.Debug|Any CPU.Build.0 = Debug|Any CPU + {52337999-5676-27FD-AE3D-CAAE406AE337}.Release|Any CPU.ActiveCfg = Release|Any CPU + {52337999-5676-27FD-AE3D-CAAE406AE337}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -258,6 +264,7 @@ Global {BB20B2C0-35AE-CC75-61CE-C2713C83E4F6} = {4F208B03-C74C-458B-8300-9FF82F3C6325} {C5ACB61F-CDAC-E93F-895E-4D46F086DE61} = 
{B87CBFFB-8221-45E8-8631-4BB1685D50F0} {73C26271-7B8D-4B38-B454-D9902B92270F} = {5315AE81-BC33-49F2-935B-69287BB44CBD} + {52337999-5676-27FD-AE3D-CAAE406AE337} = {B99948E2-9853-4D9C-B784-19C2503CDA8B} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {05A7A679-3A53-45B5-AE93-4313655E127D} diff --git a/eng/ci/templates/build-local.yml b/eng/ci/templates/build-local.yml index a12251f8..80337141 100644 --- a/eng/ci/templates/build-local.yml +++ b/eng/ci/templates/build-local.yml @@ -22,6 +22,10 @@ jobs: dotnet build $(System.DefaultWorkingDirectory)/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/WebJobs.Extensions.OpenAI.CosmosDBSearch.csproj --configuration $(config) -p:Version=$(fakeWebJobsPackageVersion) -p:CosmosDBSearchVersion=$(fakeWebJobsPackageVersion) displayName: Dotnet Build WebJobs.Extensions.OpenAI + - script: | + dotnet test $(System.DefaultWorkingDirectory)/tests/UnitTests/WebJobsOpenAIUnitTests.csproj --configuration $(config) --collect "Code Coverage" + displayName: Dotnet Test WebJobsOpenAIUnitTests + - task: CopyFiles@2 displayName: 'Copy NuGet WebJobs.Extensions.OpenAI to local directory' inputs: diff --git a/samples/assistant/csharp-legacy/AssistantSample.csproj b/samples/assistant/csharp-legacy/AssistantSample.csproj index d60abd4a..711f9628 100644 --- a/samples/assistant/csharp-legacy/AssistantSample.csproj +++ b/samples/assistant/csharp-legacy/AssistantSample.csproj @@ -7,7 +7,7 @@ - + diff --git a/samples/assistant/csharp-ooproc/AssistantSample.csproj b/samples/assistant/csharp-ooproc/AssistantSample.csproj index d6e9140a..2087a8de 100644 --- a/samples/assistant/csharp-ooproc/AssistantSample.csproj +++ b/samples/assistant/csharp-ooproc/AssistantSample.csproj @@ -9,11 +9,11 @@ - + - + diff --git a/samples/chat/csharp-ooproc/ChatBot.csproj b/samples/chat/csharp-ooproc/ChatBot.csproj index 12dd29a8..e45a1e8c 100644 --- a/samples/chat/csharp-ooproc/ChatBot.csproj +++ b/samples/chat/csharp-ooproc/ChatBot.csproj @@ 
-11,7 +11,7 @@ - + diff --git a/samples/embeddings/csharp-ooproc/Embeddings/Embeddings.csproj b/samples/embeddings/csharp-ooproc/Embeddings/Embeddings.csproj index 61b2ebf6..464519f2 100644 --- a/samples/embeddings/csharp-ooproc/Embeddings/Embeddings.csproj +++ b/samples/embeddings/csharp-ooproc/Embeddings/Embeddings.csproj @@ -9,7 +9,7 @@ - + diff --git a/samples/rag-aisearch/csharp-ooproc/SemanticAISearchEmbeddings.csproj b/samples/rag-aisearch/csharp-ooproc/SemanticAISearchEmbeddings.csproj index feef3710..2740dd2b 100644 --- a/samples/rag-aisearch/csharp-ooproc/SemanticAISearchEmbeddings.csproj +++ b/samples/rag-aisearch/csharp-ooproc/SemanticAISearchEmbeddings.csproj @@ -11,7 +11,7 @@ - + diff --git a/samples/rag-cosmosdb/csharp-ooproc/SemanticCosmosDBSearchEmbeddings.csproj b/samples/rag-cosmosdb/csharp-ooproc/SemanticCosmosDBSearchEmbeddings.csproj index 6c3e0147..9b114e87 100644 --- a/samples/rag-cosmosdb/csharp-ooproc/SemanticCosmosDBSearchEmbeddings.csproj +++ b/samples/rag-cosmosdb/csharp-ooproc/SemanticCosmosDBSearchEmbeddings.csproj @@ -11,7 +11,7 @@ - + diff --git a/samples/rag-kusto/csharp-ooproc/SemanticSearchEmbeddings.csproj b/samples/rag-kusto/csharp-ooproc/SemanticSearchEmbeddings.csproj index cb5db855..71a9aac5 100644 --- a/samples/rag-kusto/csharp-ooproc/SemanticSearchEmbeddings.csproj +++ b/samples/rag-kusto/csharp-ooproc/SemanticSearchEmbeddings.csproj @@ -10,7 +10,7 @@ - + diff --git a/samples/textcompletion/csharp-ooproc/TextCompletion.csproj b/samples/textcompletion/csharp-ooproc/TextCompletion.csproj index 74c6a35e..2b9b00ec 100644 --- a/samples/textcompletion/csharp-ooproc/TextCompletion.csproj +++ b/samples/textcompletion/csharp-ooproc/TextCompletion.csproj @@ -11,7 +11,7 @@ - + diff --git a/src/Functions.Worker.Extensions.OpenAI/Functions.Worker.Extensions.OpenAI.csproj b/src/Functions.Worker.Extensions.OpenAI/Functions.Worker.Extensions.OpenAI.csproj index a01f9810..aac240ec 100644 --- 
a/src/Functions.Worker.Extensions.OpenAI/Functions.Worker.Extensions.OpenAI.csproj +++ b/src/Functions.Worker.Extensions.OpenAI/Functions.Worker.Extensions.OpenAI.csproj @@ -7,7 +7,7 @@ - + diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs index 8f8baf22..cd5726b3 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs @@ -178,36 +178,58 @@ public async Task GetStateAsync(AssistantQueryAttribute assistan public async Task PostMessageAsync(AssistantPostAttribute attribute, CancellationToken cancellationToken) { + // Validate inputs and prepare for processing DateTime timeFilter = DateTime.UtcNow; - if (string.IsNullOrEmpty(attribute.Id)) - { - throw new ArgumentException("The assistant ID must be specified.", nameof(attribute)); - } - - if (string.IsNullOrEmpty(attribute.UserMessage)) - { - throw new ArgumentException("The assistant must have a user message", nameof(attribute)); - } + this.ValidateAttributes(attribute); this.logger.LogInformation("Posting message to assistant entity '{Id}'", attribute.Id); - TableClient tableClient = this.GetOrCreateTableClient(attribute.ChatStorageConnectionSetting, attribute.CollectionName); + // Load and validate chat state InternalChatState? 
chatState = await this.LoadChatStateAsync(attribute.Id, tableClient, cancellationToken); - - // Check if assistant has been deactivated if (chatState is null || !chatState.Metadata.Exists) { - this.logger.LogWarning("[{Id}] Ignoring request sent to nonexistent assistant.", attribute.Id); - return new AssistantState(attribute.Id, false, default, default, 0, 0, Array.Empty()); + return this.CreateNonExistentAssistantState(attribute.Id); } this.logger.LogInformation("[{Id}] Received message: {Text}", attribute.Id, attribute.UserMessage); - // Create a batch of table transaction actions + // Process the conversation List batch = new(); + this.AddUserMessageToChat(attribute, chatState, batch); + + await this.ProcessConversationWithLLM(attribute, chatState, batch, cancellationToken); - // Add the user message as a new Chat message entity + // Update state and persist changes + this.UpdateAssistantState(chatState, batch); + await tableClient.SubmitTransactionAsync(batch, cancellationToken); + + // Return results + return this.CreateAssistantStateResponse(attribute.Id, chatState, timeFilter); + } + + // Helper methods + void ValidateAttributes(AssistantPostAttribute attribute) + { + if (string.IsNullOrEmpty(attribute.Id)) + { + throw new ArgumentException("The assistant ID must be specified.", nameof(attribute)); + } + + if (string.IsNullOrEmpty(attribute.UserMessage)) + { + throw new ArgumentException("The assistant must have a user message", nameof(attribute)); + } + } + + AssistantState CreateNonExistentAssistantState(string id) + { + this.logger.LogWarning("[{Id}] Ignoring request sent to nonexistent assistant.", id); + return new AssistantState(id, false, default, default, 0, 0, Array.Empty()); + } + + void AddUserMessageToChat(AssistantPostAttribute attribute, InternalChatState chatState, List batch) + { ChatMessageTableEntity chatMessageEntity = new( partitionKey: attribute.Id, messageIndex: ++chatState.Metadata.TotalMessages, @@ -215,70 +237,35 @@ public async Task 
PostMessageAsync(AssistantPostAttribute attrib role: ChatMessageRole.User, toolCalls: null); chatState.Messages.Add(chatMessageEntity); - - // Add the chat message to the batch batch.Add(new TableTransactionAction(TableTransactionActionType.Add, chatMessageEntity)); + } + async Task ProcessConversationWithLLM( + AssistantPostAttribute attribute, + InternalChatState chatState, + List batch, + CancellationToken cancellationToken) + { IList? functions = this.skillInvoker.GetFunctionsDefinitions(); // We loop if the model returns function calls. Otherwise, we break after receiving a response. while (true) { - // Get the next response from the LLM - ChatCompletionOptions chatRequest = attribute.BuildRequest(); - if (functions is not null) - { - foreach (ChatTool fn in functions) - { - chatRequest.Tools.Add(fn); - } - } - - IEnumerable chatMessages = ToOpenAIChatRequestMessages(chatState.Messages); - - ClientResult response = await this.openAIClientFactory.GetChatClient( - attribute.AIConnectionName, - attribute.ChatModel).CompleteChatAsync(chatMessages, chatRequest, cancellationToken: cancellationToken); + // Get the LLM response + ClientResult response = await this.GetLLMResponse(attribute, chatState, functions, cancellationToken); - // We don't normally expect more than one message, but just in case we get multiple messages, - // return all of them separated by two newlines. 
- string replyMessage = string.Join( - Environment.NewLine + Environment.NewLine, - response.Value.Content.Select(message => message.Text)); + // Process text response if available + string replyMessage = this.FormatReplyMessage(response); if (!string.IsNullOrWhiteSpace(replyMessage) || response.Value.ToolCalls.Any()) { - this.logger.LogInformation( - "[{Id}] Got LLM response consisting of {Count} tokens: [{Text}] && {Count} ToolCalls", - attribute.Id, - response.Value.Usage.OutputTokenCount, - replyMessage, - response.Value.ToolCalls.Count); - - // Add the user message as a new Chat message entity - ChatMessageTableEntity replyFromAssistantEntity = new( - partitionKey: attribute.Id, - messageIndex: ++chatState.Metadata.TotalMessages, - content: replyMessage, - role: ChatMessageRole.Assistant, - toolCalls: response.Value.ToolCalls); - chatState.Messages.Add(replyFromAssistantEntity); - - // Add the reply from assistant chat message to the batch - batch.Add(new TableTransactionAction(TableTransactionActionType.Add, replyFromAssistantEntity)); - - this.logger.LogInformation( - "[{Id}] Chat length is now {Count} messages", - attribute.Id, - chatState.Metadata.TotalMessages); + this.LogAndAddAssistantReply(attribute.Id, replyMessage, response, chatState, batch); } - // Set the total tokens that have been consumed. + // Update token count chatState.Metadata.TotalTokens = response.Value.Usage.TotalTokenCount; - // Check for function calls (which are described in the API as tools) - List functionCalls = response.Value.ToolCalls - .OfType() - .ToList(); + // Handle function calls + List functionCalls = response.Value.ToolCalls.OfType().ToList(); if (functionCalls.Count == 0) { // No function calls, so we're done @@ -287,89 +274,175 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib if (batch.Count > FunctionCallBatchLimit) { - // Too many function calls, something might be wrong. 
Break out of the loop - // to avoid infinite loops and to avoid exceeding the batch size limit of 100. - this.logger.LogWarning( - "[{Id}] Ignoring {Count} function call(s) in response due to exceeding the limit of {Limit}.", - attribute.Id, - functionCalls.Count, - FunctionCallBatchLimit); + // Too many function calls, something might be wrong + this.LogBatchLimitExceeded(attribute.Id, functionCalls.Count); break; } - // Loop case: found some functions to execute - this.logger.LogInformation( - "[{Id}] Found {Count} function call(s) in response", - attribute.Id, - functionCalls.Count); - + // Process function calls + await this.ProcessFunctionCalls(attribute.Id, functionCalls, chatState, batch, cancellationToken); + } + } - // Invoke the function calls and add the responses to the chat history. - List> tasks = new(capacity: functionCalls.Count); - foreach (ChatToolCall call in functionCalls) + async Task> GetLLMResponse( + AssistantPostAttribute attribute, + InternalChatState chatState, + IList? functions, + CancellationToken cancellationToken) + { + ChatCompletionOptions chatRequest = attribute.BuildRequest(); + if (functions is not null) + { + foreach (ChatTool fn in functions) { - // CONSIDER: Call these in parallel - this.logger.LogInformation( - "[{Id}] Calling function '{Name}' with arguments: {Args}", - attribute.Id, - call.FunctionName, - call.FunctionArguments); - - string? functionResult; - try - { - // NOTE: In Consumption plans, calling a function from another function results in double-billing. - // CONSIDER: Use a background thread to invoke the action to avoid double-billing. 
- functionResult = await this.skillInvoker.InvokeAsync(call, cancellationToken); - - this.logger.LogInformation( - "[{id}] Function '{Name}' returned the following content: {Content}", - attribute.Id, - call.FunctionName, - functionResult); - } - catch (Exception ex) - { - this.logger.LogError( - ex, - "[{id}] Function '{Name}' failed with an unhandled exception", - attribute.Id, - call.FunctionName); - - // CONSIDER: Automatic retries? - functionResult = "The function call failed. Let the user know and ask if they'd like you to try again"; - } - - if (string.IsNullOrWhiteSpace(functionResult)) - { - // When experimenting with gpt-4-0613, an empty result would cause the model to go into a - // function calling loop. By instead providing a result with some instructions, we were able - // to get the model to response to the user in a natural way. - functionResult = "The function call succeeded. Let the user know that you completed the action."; - } - - ChatMessageTableEntity functionResultEntity = new( - partitionKey: attribute.Id, - messageIndex: ++chatState.Metadata.TotalMessages, - content: $"Function Name: '{call.FunctionName}' and Function Result: '{functionResult}'", - role: ChatMessageRole.Tool, - name: call.Id, - toolCalls: null); - chatState.Messages.Add(functionResultEntity); - - batch.Add(new TableTransactionAction(TableTransactionActionType.Add, functionResultEntity)); + chatRequest.Tools.Add(fn); } } - // Update the assistant state entity + IEnumerable chatMessages = ToOpenAIChatRequestMessages(chatState.Messages); + + return await this.openAIClientFactory.GetChatClient( + attribute.AIConnectionName, + attribute.ChatModel).CompleteChatAsync(chatMessages, chatRequest, cancellationToken: cancellationToken); + } + + string FormatReplyMessage(ClientResult response) + { + return string.Join( + Environment.NewLine + Environment.NewLine, + response.Value.Content.Select(message => message.Text)); + } + + void LogAndAddAssistantReply( + string assistantId, + 
string replyMessage, + ClientResult response, + InternalChatState chatState, + List batch) + { + this.logger.LogInformation( + "[{Id}] Got LLM response consisting of {Count} tokens: [{Text}] && {Count} ToolCalls", + assistantId, + response.Value.Usage.OutputTokenCount, + replyMessage, + response.Value.ToolCalls.Count); + + ChatMessageTableEntity replyFromAssistantEntity = new( + partitionKey: assistantId, + messageIndex: ++chatState.Metadata.TotalMessages, + content: replyMessage, + role: ChatMessageRole.Assistant, + toolCalls: response.Value.ToolCalls); + + chatState.Messages.Add(replyFromAssistantEntity); + batch.Add(new TableTransactionAction(TableTransactionActionType.Add, replyFromAssistantEntity)); + + this.logger.LogInformation( + "[{Id}] Chat length is now {Count} messages", + assistantId, + chatState.Metadata.TotalMessages); + } + + void LogBatchLimitExceeded(string assistantId, int functionCallCount) + { + this.logger.LogWarning( + "[{Id}] Ignoring {Count} function call(s) in response due to exceeding the limit of {Limit}.", + assistantId, + functionCallCount, + FunctionCallBatchLimit); + } + + async Task ProcessFunctionCalls( + string assistantId, + List functionCalls, + InternalChatState chatState, + List batch, + CancellationToken cancellationToken) + { + this.logger.LogInformation( + "[{Id}] Found {Count} function call(s) in response", + assistantId, + functionCalls.Count); + + foreach (ChatToolCall call in functionCalls) + { + await this.ProcessSingleFunctionCall(assistantId, call, chatState, batch, cancellationToken); + } + } + + async Task ProcessSingleFunctionCall( + string assistantId, + ChatToolCall call, + InternalChatState chatState, + List batch, + CancellationToken cancellationToken) + { + this.logger.LogInformation( + "[{Id}] Calling function '{Name}' with arguments: {Args}", + assistantId, + call.FunctionName, + call.FunctionArguments); + + string? 
functionResult = await this.InvokeFunctionWithErrorHandling(assistantId, call, cancellationToken); + + if (string.IsNullOrWhiteSpace(functionResult)) + { + functionResult = "The function call succeeded. Let the user know that you completed the action."; + } + + ChatMessageTableEntity functionResultEntity = new( + partitionKey: assistantId, + messageIndex: ++chatState.Metadata.TotalMessages, + content: $"Function Name: '{call.FunctionName}' and Function Result: '{functionResult}'", + role: ChatMessageRole.Tool, + name: call.Id, + toolCalls: null); + + chatState.Messages.Add(functionResultEntity); + batch.Add(new TableTransactionAction(TableTransactionActionType.Add, functionResultEntity)); + } + + async Task InvokeFunctionWithErrorHandling( + string assistantId, + ChatToolCall call, + CancellationToken cancellationToken) + { + try + { + // NOTE: In Consumption plans, calling a function from another function results in double-billing. + // CONSIDER: Use a background thread to invoke the action to avoid double-billing. + string? result = await this.skillInvoker.InvokeAsync(call, cancellationToken); + + this.logger.LogInformation( + "[{id}] Function '{Name}' returned the following content: {Content}", + assistantId, + call.FunctionName, + result); + + return result; + } + catch (Exception ex) + { + this.logger.LogError( + ex, + "[{id}] Function '{Name}' failed with an unhandled exception", + assistantId, + call.FunctionName); + + // CONSIDER: Automatic retries? + return "The function call failed. 
Let the user know and ask if they'd like you to try again"; + } + } + + void UpdateAssistantState(InternalChatState chatState, List batch) + { chatState.Metadata.TotalMessages = chatState.Messages.Count; chatState.Metadata.LastUpdatedAt = DateTime.UtcNow; batch.Add(new TableTransactionAction(TableTransactionActionType.UpdateMerge, chatState.Metadata)); + } - // Add the batch of table transaction actions to the table - await tableClient.SubmitTransactionAsync(batch, cancellationToken); - - // return the latest assistant message in the chat state + AssistantState CreateAssistantStateResponse(string assistantId, InternalChatState chatState, DateTime timeFilter) + { List filteredChatMessages = chatState.Messages .Where(msg => msg.CreatedAt > timeFilter && msg.Role == ChatMessageRole.Assistant.ToString()) .ToList(); @@ -378,18 +451,16 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib "Returning {Count}/{Total} chat messages from entity '{Id}'", filteredChatMessages.Count, chatState.Metadata.TotalMessages, - attribute.Id); + assistantId); - AssistantState state = new( - attribute.Id, + return new AssistantState( + assistantId, true, chatState.Metadata.CreatedAt, chatState.Metadata.LastUpdatedAt, chatState.Metadata.TotalMessages, chatState.Metadata.TotalTokens, filteredChatMessages.Select(msg => new AssistantMessage(msg.Content, msg.Role, msg.ToolCallsString)).ToList()); - - return state; } async Task LoadChatStateAsync(string id, TableClient tableClient, CancellationToken cancellationToken) @@ -493,7 +564,7 @@ TableClient GetOrCreateTableClient(string? chatStorageConnectionSetting, string? // Else, will use the connection string connectionStringName = chatStorageConnectionSetting ?? 
DefaultChatStorage; string connectionString = this.configuration.GetValue(connectionStringName); - + this.logger.LogInformation("using connection string for table service client"); this.tableServiceClient = new TableServiceClient(connectionString); diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs index 5ce78d9e..a8cb6166 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs @@ -40,6 +40,12 @@ static async Task GetTextReader(InputType inputType, string input) } else if (inputType == InputType.Url) { + if (!Uri.TryCreate(input, UriKind.Absolute, out Uri? uriResult) || + uriResult.Scheme != Uri.UriSchemeHttps) + { + throw new ArgumentException($"Invalid Url: {input}. Ensure it is a valid https Url."); + } + Stream stream = await httpClient.GetStreamAsync(input); return new StreamReader(stream); } diff --git a/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs b/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs index e38e5f12..00b72bb7 100644 --- a/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs +++ b/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs @@ -121,8 +121,8 @@ public string ToolCallsString { if (!string.IsNullOrEmpty(value)) { - var options = new JsonSerializerOptions { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; - var cloneList = JsonSerializer.Deserialize>(value, options); + JsonSerializerOptions options = new() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; + List? cloneList = JsonSerializer.Deserialize>(value, options); this.ToolCalls = cloneList != null ? 
this.DeserializeChatTool(cloneList) : new List(); } else @@ -137,7 +137,7 @@ IList SerializeChatTool(IList toolCalls) IList chatToolCloneList = new List(); foreach (ChatToolCall toolCall in toolCalls) { - ChatToolCallClone chatToolClone = new ChatToolCallClone(toolCall.Id, toolCall.FunctionName, toolCall.FunctionArguments.ToString(), toolCall.Kind.ToString()); + ChatToolCallClone chatToolClone = new(toolCall.Id, toolCall.FunctionName, toolCall.FunctionArguments.ToString(), toolCall.Kind.ToString()); chatToolCloneList.Add(chatToolClone); } return chatToolCloneList; @@ -146,11 +146,10 @@ IList SerializeChatTool(IList toolCalls) IList DeserializeChatTool(IList clones) { IList result = new List(); - foreach (var clone in clones) + foreach (ChatToolCallClone clone in clones) { - var kind = Enum.Parse(clone.Kind); - var functionArgs = JsonDocument.Parse(clone.FunctionArguments).RootElement; - var toolCall = ChatToolCall.CreateFunctionToolCall(clone.Id, clone.FunctionName, BinaryData.FromString(functionArgs.GetRawText())); + JsonElement functionArgs = JsonDocument.Parse(clone.FunctionArguments).RootElement; + ChatToolCall toolCall = ChatToolCall.CreateFunctionToolCall(clone.Id, clone.FunctionName, BinaryData.FromString(functionArgs.GetRawText())); result.Add(toolCall); } return result; diff --git a/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs b/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs index 180b508b..51adb201 100644 --- a/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs +++ b/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs @@ -123,7 +123,7 @@ AzureOpenAIClient CreateClientFromConfigSection(string aiConnectionName) } } - string errorMessage = $"Configuration section '{aiConnectionName}' is missing required 'Endpoint' or 'Key' values."; + string errorMessage = $"Configuration section '{aiConnectionName}' is missing required 'Endpoint' and/or 'Key' values."; this.logger.LogError(errorMessage); throw new InvalidOperationException(errorMessage); } 
diff --git a/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs b/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs index d7fa61ca..c9c31bc8 100644 --- a/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs +++ b/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Description; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; diff --git a/src/WebJobs.Extensions.OpenAI/Properties/AssemblyInfo.cs b/src/WebJobs.Extensions.OpenAI/Properties/AssemblyInfo.cs new file mode 100644 index 00000000..5b5af38e --- /dev/null +++ b/src/WebJobs.Extensions.OpenAI/Properties/AssemblyInfo.cs @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Runtime.CompilerServices; + +// Make internals visible to the test project +[assembly: InternalsVisibleTo("WebJobsOpenAIUnitTests")] \ No newline at end of file diff --git a/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs b/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs index d489ad65..6043dec5 100644 --- a/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs @@ -53,7 +53,7 @@ async Task ConvertCoreAsync( ClientResult response = await this.openAIClientFactory.GetChatClient( attribute.AIConnectionName, - attribute.ChatModel).CompleteChatAsync(chatMessages, options); + attribute.ChatModel).CompleteChatAsync(chatMessages, options, cancellationToken: cancellationToken); string text = string.Join( Environment.NewLine + Environment.NewLine, diff --git a/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj b/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj index e802f7d3..c7a17757 100644 --- a/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj +++ 
b/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj @@ -11,7 +11,7 @@ - + \ No newline at end of file diff --git a/tests/SampleValidation/EmbeddingsTests.cs b/tests/SampleValidation/EmbeddingsTests.cs index 4e45fe69..d9f215bd 100644 --- a/tests/SampleValidation/EmbeddingsTests.cs +++ b/tests/SampleValidation/EmbeddingsTests.cs @@ -54,7 +54,7 @@ public async Task GenerateEmbeddings_Http_Request_Test() public async Task EmbeddingsLegacy_Performance_Test() { // Create a large text for testing performance - string largeText = new string('A', 10000); // 10KB text + string largeText = new('A', 10000); // 10KB text var request = new { @@ -80,4 +80,68 @@ public async Task EmbeddingsLegacy_Performance_Test() // If specific performance SLAs exist, add assertions like: // Assert.True(stopwatch.ElapsedMilliseconds < 5000, "Embeddings generation took too long"); } + + [Fact] + public async Task GetEmbeddings_Url_ReturnsNoContent() + { + // Arrange + var request = new { url = "https://github.com/Azure/azure-functions-openai-extension/blob/main/README.md" }; + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/embeddings-from-url", + request, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.NoContent, response.StatusCode); + } + + [Fact] + public async Task GenerateEmbeddings_InvalidText_ReturnsBadRequest() + { + // Arrange + var request = new { url = "" }; // Invalid: Empty text + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/embeddings", + request, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + } + + [Fact] + public async Task GetEmbeddings_InvalidFilePath_ReturnsBadRequest() + { + // Arrange + var request = new { filePath = "invalid/file/path.txt" }; // Invalid: Non-existent file path + + // Act + using 
HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/embeddings-from-file", + request, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + } + + [Fact] + public async Task GetEmbeddings_InvalidUrl_ReturnsBadRequest() + { + // Arrange + var request = new { url = "invalid-url" }; // Invalid: Malformed URL + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/embeddings-from-url", + request, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + } } \ No newline at end of file diff --git a/tests/SampleValidation/FilePromptsTests.cs b/tests/SampleValidation/FilePromptsTests.cs index efc3ef68..8eb01abd 100644 --- a/tests/SampleValidation/FilePromptsTests.cs +++ b/tests/SampleValidation/FilePromptsTests.cs @@ -18,7 +18,7 @@ public FilePromptTests(ITestOutputHelper output) { this.output = output; this.client = new HttpClient(new LoggingHandler(this.output)); - this.cts = new CancellationTokenSource(delay: TimeSpan.FromMinutes(Debugger.IsAttached ? 5 : 1)); + this.cts = new CancellationTokenSource(delay: TimeSpan.FromMinutes(Debugger.IsAttached ? 5 : 2)); this.baseAddress = Environment.GetEnvironmentVariable("FUNC_BASE_ADDRESS") ?? 
"http://localhost:7071"; @@ -30,47 +30,39 @@ public FilePromptTests(ITestOutputHelper output) } [Fact] - public async Task IngestFile_ValidUrl_ReturnsSuccess() - { - // Prepare the request - var request = new { url = "https://github.com/Azure/azure-functions-openai-extension/blob/main/README.md" }; + public async Task Ingest_Prompt_File_Test() + { + // Step 1: Test IngestFile + var ingestRequest = new { url = "https://github.com/Azure/azure-functions-openai-extension/blob/main/README.md" }; - // Send the POST request to IngestFile - using HttpResponseMessage response = await this.client.PostAsJsonAsync( + using HttpResponseMessage ingestResponse = await this.client.PostAsJsonAsync( requestUri: $"{this.baseAddress}/api/IngestFile", - request, + ingestRequest, cancellationToken: this.cts.Token); - Assert.Equal(HttpStatusCode.OK, response.StatusCode); - Assert.StartsWith("application/json", response.Content.Headers.ContentType?.MediaType); + Assert.Equal(HttpStatusCode.OK, ingestResponse.StatusCode); + Assert.StartsWith("application/json", ingestResponse.Content.Headers.ContentType?.MediaType); - // Validate the response content - string responseContent = await response.Content.ReadAsStringAsync(this.cts.Token); - JsonNode? jsonResponse = JsonNode.Parse(responseContent); - Assert.NotNull(jsonResponse); - Assert.Equal("success", jsonResponse!["status"]?.GetValue()); - Assert.Equal("README.md", jsonResponse!["title"]?.GetValue()); - } + string ingestResponseContent = await ingestResponse.Content.ReadAsStringAsync(this.cts.Token); + JsonNode? 
ingestJsonResponse = JsonNode.Parse(ingestResponseContent); + Assert.NotNull(ingestJsonResponse); + Assert.Equal("success", ingestJsonResponse!["status"]?.GetValue()); + Assert.Equal("README.md", ingestJsonResponse!["title"]?.GetValue()); - [Fact] - public async Task PromptFile_ValidPrompt_ReturnsResponse() - { - // Prepare the request - var request = new { prompt = "How can the textCompletion input binding be used from Azure Functions OpenAI extension?" }; + // Step 2: Test PromptFile + var promptRequest = new { prompt = "How can the textCompletion input binding be used from Azure Functions OpenAI extension?" }; - // Send the POST request to PromptFile - using HttpResponseMessage response = await this.client.PostAsJsonAsync( + using HttpResponseMessage promptResponse = await this.client.PostAsJsonAsync( requestUri: $"{this.baseAddress}/api/PromptFile", - request, + promptRequest, cancellationToken: this.cts.Token); - Assert.Equal(HttpStatusCode.OK, response.StatusCode); - Assert.StartsWith("text/plain", response.Content.Headers.ContentType?.MediaType); + Assert.Equal(HttpStatusCode.OK, promptResponse.StatusCode); + Assert.StartsWith("text/plain", promptResponse.Content.Headers.ContentType?.MediaType); - // Validate the response content - string responseContent = await response.Content.ReadAsStringAsync(this.cts.Token); - Assert.False(string.IsNullOrWhiteSpace(responseContent)); - Assert.Contains("OpenAI Chat Completions API", responseContent, StringComparison.OrdinalIgnoreCase); - Assert.Contains("README", responseContent, StringComparison.OrdinalIgnoreCase); + string promptResponseContent = await promptResponse.Content.ReadAsStringAsync(this.cts.Token); + Assert.False(string.IsNullOrWhiteSpace(promptResponseContent)); + Assert.Contains("OpenAI Chat Completions API", promptResponseContent, StringComparison.OrdinalIgnoreCase); + Assert.Contains("README", promptResponseContent, StringComparison.OrdinalIgnoreCase); } } \ No newline at end of file diff --git 
a/tests/UnitTests/AssistantServiceTests.cs b/tests/UnitTests/AssistantServiceTests.cs new file mode 100644 index 00000000..0377e21a --- /dev/null +++ b/tests/UnitTests/AssistantServiceTests.cs @@ -0,0 +1,520 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + + +using System.ClientModel; +using System.Threading; +using Azure; +using Azure.Core; +using Azure.Data.Tables; +using Azure.Data.Tables.Models; +using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; +using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; +using Microsoft.Extensions.Azure; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Moq; +using OpenAI.Chat; +using Xunit; + +namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Tests.Assistants; + +public class DefaultAssistantServiceTests +{ + readonly Mock mockOpenAIClientFactory; + readonly Mock mockAzureComponentFactory; + readonly Mock mockConfiguration; + readonly Mock mockSkillInvoker; + readonly Mock mockLoggerFactory; + readonly Mock> mockLogger; + readonly Mock mockTableServiceClient; + readonly Mock mockTableClient; + + public DefaultAssistantServiceTests() + { + this.mockAzureComponentFactory = new Mock(); + this.mockConfiguration = new Mock(); + this.mockSkillInvoker = new Mock(); + this.mockLoggerFactory = new Mock(); + this.mockLogger = new Mock>(); + this.mockTableServiceClient = new Mock(); + this.mockTableClient = new Mock(); + this.mockOpenAIClientFactory = new Mock( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + this.mockLoggerFactory.Setup(x => x.CreateLogger(It.IsAny())) + .Returns(this.mockLogger.Object); + + // Setup table client + this.mockTableServiceClient.Setup(x => x.GetTableClient(It.IsAny())) + .Returns(this.mockTableClient.Object); + } + + [Fact] + public async Task CreateAssistantAsync_WithValidRequest_CreatesAssistantAndMessages() + { + // Arrange + var request = new 
AssistantCreateRequest("testId", "Test instructions") + { + CollectionName = "ChatState", + ChatStorageConnectionSetting = "AzureWebJobsStorage" + }; + + var mockQueryResult = new List(); + AsyncPageable mockQueryable = MockAsyncPageable.Create(mockQueryResult); + + this.mockTableClient.Setup(x => x.CreateIfNotExistsAsync(It.IsAny())) + .ReturnsAsync(Response.FromValue(new TableItem(request.CollectionName), new Mock().Object)); + + this.mockTableClient.Setup(x => x.QueryAsync( + It.Is(s => s == $"PartitionKey eq '{request.Id}'"), + null, + null, + It.IsAny())) + .Returns(mockQueryable); + + this.mockTableClient.Setup(x => x.SubmitTransactionAsync( + It.IsAny>(), + It.IsAny())) + .ReturnsAsync(Response.FromValue(new List() as IReadOnlyList, new Mock().Object)); + + //// Arrange + //Mock mockSection = CreateMockSection( + // exists: false, + // tableServiceUri: null); + //mockSection.Setup(s => s.Value).Returns("UseDevelopmentStorage=true"); + + //// Setup AzureWebJobsStorage directly + //this.mockConfiguration.Setup(c => c["AzureWebJobsStorage"]).Returns("UseDevelopmentStorage=true"); + //this.mockConfiguration.Setup(c => c.GetSection("AzureWebJobsStorage")).Returns(mockSection.Object); + + // Create the service under test + var assistantService = new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object); + + // Use reflection to set the tableClient field + System.Reflection.FieldInfo? 
tableClientField = typeof(DefaultAssistantService).GetField("tableClient", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + tableClientField?.SetValue(assistantService, this.mockTableClient.Object); + + // Act + await assistantService.CreateAssistantAsync(request, CancellationToken.None); + + // Assert + this.mockTableClient.Verify(x => x.CreateIfNotExistsAsync(It.IsAny()), Times.Once); + + this.mockTableClient.Verify(x => x.QueryAsync( + It.Is(s => s.Contains(request.Id)), + null, + null, + It.IsAny()), Times.Once); + + this.mockTableClient.Verify(x => x.SubmitTransactionAsync( + It.IsAny>(), + It.IsAny()), Times.Once); + } + + [Fact] + public async Task CreateAssistantAsync_WithLocalDevStorage_CreatesAssistantAndMessages() + { + // Arrange + var request = new AssistantCreateRequest("testId", "Test instructions") + { + CollectionName = "ChatState", + ChatStorageConnectionSetting = "AzureWebJobsStorage" + }; + + var mockQueryResult = new List(); + AsyncPageable mockQueryable = MockAsyncPageable.Create(mockQueryResult); + + // Arrange + Mock mockSection = CreateMockSection( + exists: false, + tableServiceUri: null); + mockSection.Setup(s => s.Value).Returns("UseDevelopmentStorage=true"); + + // Setup AzureWebJobsStorage directly + this.mockConfiguration.Setup(c => c["AzureWebJobsStorage"]).Returns("UseDevelopmentStorage=true"); + this.mockConfiguration.Setup(c => c.GetSection("AzureWebJobsStorage")).Returns(mockSection.Object); + + // Create the service under test + var assistantService = new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object); + + // Act + await assistantService.CreateAssistantAsync(request, CancellationToken.None); + + // Assert + this.mockTableClient.Verify(x => x.CreateIfNotExistsAsync(It.IsAny()), Times.Never); + + this.mockTableClient.Verify(x => 
x.QueryAsync( + It.Is(s => s.Contains(request.Id)), + null, + null, + It.IsAny()), Times.Never); + + this.mockTableClient.Verify(x => x.SubmitTransactionAsync( + It.IsAny>(), + It.IsAny()), Times.Never); + } + + [Fact] + public async Task CreateAssistantAsync_WithExistingAssistant_DeletesOldEntitiesFirst() + { + // Arrange + var request = new AssistantCreateRequest("testId", "Test instructions") + { + CollectionName = "ChatState", + ChatStorageConnectionSetting = "AzureWebJobsStorage" + }; + + // Create existing entities + var existingEntities = new List + { + new("testId", "state"), + new("testId", "msg-001") + }; + + AsyncPageable mockQueryable = MockAsyncPageable.Create(existingEntities); + + this.mockTableClient.Setup(x => x.CreateIfNotExistsAsync(It.IsAny())) + .ReturnsAsync(Response.FromValue(new TableItem(request.CollectionName), new Mock().Object)); + + this.mockTableClient.Setup(x => x.QueryAsync( + It.Is(s => s == $"PartitionKey eq '{request.Id}'"), + null, + null, + It.IsAny())) + .Returns(mockQueryable); + + this.mockTableClient.Setup(x => x.SubmitTransactionAsync( + It.IsAny>(), + It.IsAny())) + .ReturnsAsync(Response.FromValue(new List() as IReadOnlyList, new Mock().Object)); + + // Create the service under test + var assistantService = new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object); + + // Use reflection to set the tableClient field + System.Reflection.FieldInfo? 
tableClientField = typeof(DefaultAssistantService).GetField("tableClient", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + tableClientField?.SetValue(assistantService, this.mockTableClient.Object); + + + // Act + await assistantService.CreateAssistantAsync(request, CancellationToken.None); + + // Assert + this.mockTableClient.Verify(x => x.SubmitTransactionAsync( + It.IsAny>(), + It.IsAny()), Times.AtLeastOnce); + + this.mockTableClient.Verify(x => x.SubmitTransactionAsync( + It.Is>( + actions => actions.Any(a => a.ActionType == TableTransactionActionType.Add)), + It.IsAny()), Times.AtLeastOnce); + } + + [Fact] + public async Task GetStateAsync_WithValidId_ReturnsCorrectState() + { + // Arrange + string id = "testId"; + string timestamp = DateTime.UtcNow.AddHours(-1).ToString("o"); + var attribute = new AssistantQueryAttribute(id) + { + TimestampUtc = timestamp, + CollectionName = "testCollection" + }; + + // Create mock entities + var stateEntity = new TableEntity(id, AssistantStateEntity.FixedRowKeyValue) + { + ["Exists"] = true, + ["CreatedAt"] = DateTime.UtcNow.AddDays(-1), + ["LastUpdatedAt"] = DateTime.UtcNow, + ["TotalMessages"] = 2, + ["TotalTokens"] = 100 + }; + + var message1 = new TableEntity(id, "msg-001") + { + ["Content"] = "Test message 1", + ["Role"] = ChatMessageRole.System.ToString(), + ["CreatedAt"] = DateTime.UtcNow.AddMinutes(-30), + ["ToolCalls"] = "" + }; + + var message2 = new TableEntity(id, "msg-002") + { + ["Content"] = "Test message 2", + ["Role"] = ChatMessageRole.Assistant.ToString(), + ["CreatedAt"] = DateTime.UtcNow.AddMinutes(-15), + ["ToolCalls"] = "" + }; + + var mockQueryResult = new List { stateEntity, message1, message2 }; + AsyncPageable mockQueryable = MockAsyncPageable.Create(mockQueryResult); + + this.mockTableClient.Setup(x => x.QueryAsync( + It.Is(s => s == $"PartitionKey eq '{id}'"), + null, + null, + It.IsAny())) + .Returns(mockQueryable); + + // Create the service under test + 
var assistantService = new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object); + + // Use reflection to set the tableClient field + System.Reflection.FieldInfo? tableClientField = typeof(DefaultAssistantService).GetField("tableClient", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + tableClientField?.SetValue(assistantService, this.mockTableClient.Object); + + // Act + AssistantState result = await assistantService.GetStateAsync(attribute, CancellationToken.None); + + // Assert + Assert.Equal(id, result.Id); + Assert.True(result.Exists); + Assert.Equal(2, result.TotalMessages); + Assert.Equal(100, result.TotalTokens); + + // Verify that we filter messages based on timestamp + DateTime parsedTimestamp = DateTime.Parse(Uri.UnescapeDataString(timestamp)).ToUniversalTime(); + Assert.All(result.RecentMessages, msg => + Assert.True(DateTime.UtcNow.AddMinutes(-30) > parsedTimestamp)); + } + + [Fact] + public async Task GetStateAsync_WithNonExistentId_ReturnsEmptyState() + { + // Arrange + string id = "nonExistentId"; + string timestamp = DateTime.UtcNow.AddHours(-1).ToString("o"); + var attribute = new AssistantQueryAttribute(id) + { + TimestampUtc = timestamp, + CollectionName = "testCollection" + }; + + var mockQueryResult = new List { }; + AsyncPageable mockQueryable = MockAsyncPageable.Create(mockQueryResult); + + this.mockTableClient.Setup(x => x.QueryAsync( + It.Is(s => s == $"PartitionKey eq '{id}'"), + null, + null, + It.IsAny())) + .Returns(mockQueryable); + + // Create the service under test + var assistantService = new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object); + + // Use reflection to set the tableClient field + 
System.Reflection.FieldInfo? tableClientField = typeof(DefaultAssistantService).GetField("tableClient", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + tableClientField?.SetValue(assistantService, this.mockTableClient.Object); + + // Act + AssistantState result = await assistantService.GetStateAsync(attribute, CancellationToken.None); + + // Assert + Assert.Equal(id, result.Id); + Assert.False(result.Exists); + Assert.Equal(0, result.TotalMessages); + Assert.Empty(result.RecentMessages); + } + + [Fact] + public void Constructor_WithNullArguments_ThrowsArgumentNullException() + { + // Act & Assert +#nullable disable + Assert.Throws(() => new DefaultAssistantService( + null, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object)); + + Assert.Throws(() => new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + null, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object)); + + Assert.Throws(() => new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + null, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object)); + + Assert.Throws(() => new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + null, + this.mockLoggerFactory.Object)); + + Assert.Throws(() => new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + null)); +#nullable restore + } + + [Fact] + public async Task PostMessageAsync_WithNonExistentAssistant_ReturnsEmptyState() + { + // Arrange + string assistantId = "nonExistentId"; + var attribute = new AssistantPostAttribute(assistantId, "Hello") + { + CollectionName = "ChatState", + ChatStorageConnectionSetting = 
"AzureWebJobsStorage" + }; + + var mockQueryResult = new List(); + AsyncPageable mockQueryable = MockAsyncPageable.Create(mockQueryResult); + + this.mockTableClient.Setup(x => x.QueryAsync( + It.Is(s => s == $"PartitionKey eq '{assistantId}'"), + null, + null, + It.IsAny())) + .Returns(mockQueryable); + + // Create the service under test + var assistantService = new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object); + + // Use reflection to set the tableClient field + System.Reflection.FieldInfo? tableClientField = typeof(DefaultAssistantService).GetField("tableClient", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + tableClientField?.SetValue(assistantService, this.mockTableClient.Object); + + // Act + AssistantState result = await assistantService.PostMessageAsync(attribute, CancellationToken.None); + + // Assert + Assert.Equal(assistantId, result.Id); + Assert.False(result.Exists); + Assert.Equal(0, result.TotalMessages); + Assert.Empty(result.RecentMessages); + } + + [Fact] + public async Task PostMessageAsync_WithInvalidAttributes_ThrowsArgumentException() + { + // Arrange - Missing user message + var attributeNoMessage = new AssistantPostAttribute("testId", "") + { + CollectionName = "ChatState", + ChatStorageConnectionSetting = "AzureWebJobsStorage" + }; + + // Create the service under test + var assistantService = new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object); + + // Act & Assert - Empty message + await Assert.ThrowsAsync(() => + assistantService.PostMessageAsync(attributeNoMessage, CancellationToken.None)); + + // Arrange - Missing ID + var attributeNoId = new AssistantPostAttribute("", "Hello") + { + 
CollectionName = "ChatState", + ChatStorageConnectionSetting = "AzureWebJobsStorage" + }; + + // Act & Assert - Empty ID + await Assert.ThrowsAsync(() => + assistantService.PostMessageAsync(attributeNoId, CancellationToken.None)); + } + + static Mock CreateMockSection(bool exists, string? tableServiceUri = null) + { + var mockSection = new Mock(); + + if (!exists) + { + mockSection.Setup(s => s.Value).Returns(""); + mockSection.Setup(s => s.GetChildren()).Returns([]); + } + else + { + mockSection.Setup(s => s.Value).Returns("some-value"); + } + + var endpointSection = new Mock(); +#nullable disable + endpointSection.Setup(s => s.Value).Returns(tableServiceUri); +#nullable restore + mockSection.Setup(s => s.GetSection("tableServiceUri")).Returns(endpointSection.Object); + + return mockSection; + } +} + +// Update the MockAsyncPageable class to ensure TableEntity satisfies the 'notnull' constraint +public class MockAsyncPageable : AsyncPageable where T : notnull +{ + readonly IEnumerable _items; + + MockAsyncPageable(IEnumerable items) + { + this._items = items; + } + + public static AsyncPageable Create(IEnumerable items) + { + return new MockAsyncPageable(items); + } + + public override async IAsyncEnumerable> AsPages( + string? continuationToken = null, + int? pageSizeHint = null) + { + await Task.Yield(); + yield return Page.FromValues([.. this._items], null, new Mock().Object); + } +} diff --git a/tests/UnitTests/OpenAIClientFactoryTests.cs b/tests/UnitTests/OpenAIClientFactoryTests.cs new file mode 100644 index 00000000..37c5ba72 --- /dev/null +++ b/tests/UnitTests/OpenAIClientFactoryTests.cs @@ -0,0 +1,383 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ + +using System.Reflection; +using Azure.Core; +using Microsoft.Azure.WebJobs.Extensions.OpenAI; +using Microsoft.Extensions.Azure; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Moq; +using OpenAI.Chat; +using OpenAI.Embeddings; +using Xunit; + +namespace WebJobsOpenAIUnitTests; + +public class OpenAIClientFactoryTests +{ + readonly Mock mockConfiguration; + readonly Mock mockAzureComponentFactory; + readonly Mock mockLoggerFactory; + readonly Mock> mockLogger; + + public OpenAIClientFactoryTests() + { + this.mockConfiguration = new Mock(); + this.mockAzureComponentFactory = new Mock(); + this.mockLoggerFactory = new Mock(); + this.mockLogger = new Mock>(); + this.mockLoggerFactory + .Setup(f => f.CreateLogger(It.Is(s => s == typeof(OpenAIClientFactory).FullName))) + .Returns(this.mockLogger.Object); + } + + [Fact] + public void Constructor_NullConfiguration_ThrowsArgumentNullException() + { + // Arrange + IConfiguration? configuration = null; + +#nullable disable + // Act & Assert + ArgumentNullException exception = Assert.Throws(() => + new OpenAIClientFactory(configuration, this.mockAzureComponentFactory.Object, this.mockLoggerFactory.Object)); +#nullable restore + Assert.Equal("configuration", exception.ParamName); + } + + [Fact] + public void Constructor_NullAzureComponentFactory_ThrowsArgumentNullException() + { + // Arrange + AzureComponentFactory? azureComponentFactory = null; + +#nullable disable + // Act & Assert + ArgumentNullException exception = Assert.Throws(() => + new OpenAIClientFactory(this.mockConfiguration.Object, azureComponentFactory, this.mockLoggerFactory.Object)); +#nullable restore + Assert.Equal("azureComponentFactory", exception.ParamName); + } + + [Fact] + public void Constructor_NullLoggerFactory_ThrowsArgumentNullException() + { + // Arrange + ILoggerFactory? 
loggerFactory = null; + + // Act & Assert +#nullable disable + ArgumentNullException exception = Assert.Throws(() => + new OpenAIClientFactory(this.mockConfiguration.Object, this.mockAzureComponentFactory.Object, loggerFactory)); +#nullable restore + Assert.Equal("loggerFactory", exception.ParamName); + } + + [Fact] + public void GetChatClient_WithAzureOpenAI_ReturnsClient() + { + // Arrange + Mock mockSection = CreateMockSection( + exists: true, + endpoint: "https://test-endpoint.openai.azure.com/", + key: "test-key"); + + this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object); + + // Clear environment variable if set + string? originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + try + { + Environment.SetEnvironmentVariable("OPENAI_API_KEY", null); + + OpenAIClientFactory factory = new( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + // Act + ChatClient chatClient = factory.GetChatClient("TestConnection", "gpt-35-turbo"); + + // Assert + Assert.NotNull(chatClient); + } + finally + { + // Restore original environment value + Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue); + } + } + + [Fact] + public void GetChatClient_WithOpenAIKey_ReturnsClient() + { + // Arrange + string? 
originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + try + { + Environment.SetEnvironmentVariable("OPENAI_API_KEY", "test-openai-key"); + + OpenAIClientFactory factory = new( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + // Act + ChatClient chatClient = factory.GetChatClient("TestConnection", "gpt-4"); + + // Assert + Assert.NotNull(chatClient); + } + finally + { + // Restore original environment value + Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue); + } + } + + [Fact] + public void GetEmbeddingClient_WithAzureOpenAI_ReturnsClient() + { + // Arrange + Mock mockSection = CreateMockSection( + exists: true, + endpoint: "https://test-endpoint.openai.azure.com/", + key: "test-key"); + + this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object); + + // Clear environment variable if set + string? originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + try + { + Environment.SetEnvironmentVariable("OPENAI_API_KEY", null); + + OpenAIClientFactory factory = new( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + // Act + EmbeddingClient embeddingClient = factory.GetEmbeddingClient("TestConnection", "text-embedding-ada-002"); + + // Assert + Assert.NotNull(embeddingClient); + } + finally + { + // Restore original environment value + Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue); + } + } + + [Fact] + public void GetEmbeddingClient_WithOpenAIKey_ReturnsClient() + { + // Arrange + string? 
originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + try + { + Environment.SetEnvironmentVariable("OPENAI_API_KEY", "test-openai-key"); + + OpenAIClientFactory factory = new( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + // Act + EmbeddingClient embeddingClient = factory.GetEmbeddingClient("TestConnection", "text-embedding-ada-002"); + + // Assert + Assert.NotNull(embeddingClient); + } + finally + { + // Restore original environment value + Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue); + } + } + + [Fact] + public void HasOpenAIKey_KeyExists_SetsHasKeyToTrue() + { + // Arrange + string? originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + try + { + Environment.SetEnvironmentVariable("OPENAI_API_KEY", "test-key"); + + // Use reflection to access private static method + MethodInfo? methodInfo = typeof(OpenAIClientFactory).GetMethod("HasOpenAIKey", + BindingFlags.NonPublic | BindingFlags.Static); + + // Act +#nullable disable + object[] parameters = [false, null]; +#nullable restore + methodInfo?.Invoke(null, parameters); + + // Assert + Assert.True((bool)parameters[0]); + Assert.Equal("test-key", parameters[1]); + } + finally + { + // Restore original environment value + Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue); + } + } + + [Fact] + public void HasOpenAIKey_NoKeyExists_SetsHasKeyToFalse() + { + // Arrange + string? originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + try + { + Environment.SetEnvironmentVariable("OPENAI_API_KEY", null); + + // Use reflection to access private static method + MethodInfo? 
methodInfo = typeof(OpenAIClientFactory).GetMethod("HasOpenAIKey", + BindingFlags.NonPublic | BindingFlags.Static); + + // Act + object[] parameters = [true, "some-value"]; + methodInfo?.Invoke(null, parameters); + + // Assert + Assert.False((bool)parameters[0]); + Assert.Null((string)parameters[1]); + } + finally + { + // Restore original environment value + Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue); + } + } + + [Fact] + public void CreateClientFromConfigSection_MissingEndpoint_ThrowsInvalidOperationException() + { + // Arrange + Mock mockSection = CreateMockSection( + exists: true, + endpoint: null, + key: "test-key"); + + this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object); + + // Clear environment variable if set + string? originalEndpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); + try + { + Environment.SetEnvironmentVariable("AZURE_OPENAI_ENDPOINT", null); + + OpenAIClientFactory factory = new( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + // Act & Assert + InvalidOperationException exception = Assert.Throws(() => + factory.GetChatClient("TestConnection", "gpt-35-turbo")); + Assert.Contains("missing required 'Endpoint'", exception.Message); + } + finally + { + // Restore original environment value + Environment.SetEnvironmentVariable("AZURE_OPENAI_ENDPOINT", originalEndpoint); + } + } + + [Fact] + public void CreateClientFromConfigSection_WithTokenCredential_ReturnsClient() + { + // Arrange + Mock mockSection = CreateMockSection( + exists: true, + endpoint: "https://test-endpoint.openai.azure.com/", + key: null); + + this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object); + + TokenCredential tokenCredential = new Mock().Object; + this.mockAzureComponentFactory.Setup(f => f.CreateTokenCredential(It.IsAny())) + .Returns(tokenCredential); + + OpenAIClientFactory factory = 
new( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + // Act + ChatClient chatClient = factory.GetChatClient("TestConnection", "gpt-35-turbo"); + + // Assert + Assert.NotNull(chatClient); + this.mockAzureComponentFactory.Verify(f => f.CreateTokenCredential(It.IsAny()), Times.Once); + } + + [Fact] + public void CreateClientFromConfigSection_WithEnvironmentVariables_UsesEnvironmentValues() + { + // Arrange + Mock mockSection = CreateMockSection( + exists: false, + endpoint: null, + key: null); + this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object); + + string? originalEndpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); + string? originalKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_KEY"); + try + { + Environment.SetEnvironmentVariable("AZURE_OPENAI_ENDPOINT", "https://env-endpoint.openai.azure.com/"); + Environment.SetEnvironmentVariable("AZURE_OPENAI_KEY", "env-key"); + + OpenAIClientFactory factory = new( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + // Act + ChatClient chatClient = factory.GetChatClient("TestConnection", "gpt-35-turbo"); + + // Assert + Assert.NotNull(chatClient); + } + finally + { + // Restore original environment values + Environment.SetEnvironmentVariable("AZURE_OPENAI_ENDPOINT", originalEndpoint); + Environment.SetEnvironmentVariable("AZURE_OPENAI_KEY", originalKey); + } + } + + static Mock CreateMockSection(bool exists, string? endpoint = null, string? 
key = null) + { + var mockSection = new Mock(); + + if (!exists) + { + mockSection.Setup(s => s.Value).Returns(""); + mockSection.Setup(s => s.GetChildren()).Returns([]); + } + else + { + mockSection.Setup(s => s.Value).Returns("some-value"); + } + +#nullable disable + var endpointSection = new Mock(); + endpointSection.Setup(s => s.Value).Returns(endpoint); + mockSection.Setup(s => s.GetSection("Endpoint")).Returns(endpointSection.Object); + var keySection = new Mock(); + keySection.Setup(s => s.Value).Returns(key); +#nullable restore + mockSection.Setup(s => s.GetSection("Key")).Returns(keySection.Object); + + return mockSection; + } +} \ No newline at end of file diff --git a/tests/UnitTests/WebJobsOpenAIUnitTests.csproj b/tests/UnitTests/WebJobsOpenAIUnitTests.csproj new file mode 100644 index 00000000..d3c1b04b --- /dev/null +++ b/tests/UnitTests/WebJobsOpenAIUnitTests.csproj @@ -0,0 +1,29 @@ + + + + net8.0 + enable + enable + + false + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + + From c55bec794fd5fda7d9add3a171afd649c4173ad4 Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Sun, 13 Apr 2025 01:24:15 -0700 Subject: [PATCH 06/21] remove local dev test --- tests/UnitTests/AssistantServiceTests.cs | 48 ------------------------ 1 file changed, 48 deletions(-) diff --git a/tests/UnitTests/AssistantServiceTests.cs b/tests/UnitTests/AssistantServiceTests.cs index 0377e21a..eb40bb59 100644 --- a/tests/UnitTests/AssistantServiceTests.cs +++ b/tests/UnitTests/AssistantServiceTests.cs @@ -120,54 +120,6 @@ public async Task CreateAssistantAsync_WithValidRequest_CreatesAssistantAndMessa It.IsAny()), Times.Once); } - [Fact] - public async Task CreateAssistantAsync_WithLocalDevStorage_CreatesAssistantAndMessages() - { - // Arrange - var request = new AssistantCreateRequest("testId", "Test instructions") 
- { - CollectionName = "ChatState", - ChatStorageConnectionSetting = "AzureWebJobsStorage" - }; - - var mockQueryResult = new List(); - AsyncPageable mockQueryable = MockAsyncPageable.Create(mockQueryResult); - - // Arrange - Mock mockSection = CreateMockSection( - exists: false, - tableServiceUri: null); - mockSection.Setup(s => s.Value).Returns("UseDevelopmentStorage=true"); - - // Setup AzureWebJobsStorage directly - this.mockConfiguration.Setup(c => c["AzureWebJobsStorage"]).Returns("UseDevelopmentStorage=true"); - this.mockConfiguration.Setup(c => c.GetSection("AzureWebJobsStorage")).Returns(mockSection.Object); - - // Create the service under test - var assistantService = new DefaultAssistantService( - this.mockOpenAIClientFactory.Object, - this.mockAzureComponentFactory.Object, - this.mockConfiguration.Object, - this.mockSkillInvoker.Object, - this.mockLoggerFactory.Object); - - // Act - await assistantService.CreateAssistantAsync(request, CancellationToken.None); - - // Assert - this.mockTableClient.Verify(x => x.CreateIfNotExistsAsync(It.IsAny()), Times.Never); - - this.mockTableClient.Verify(x => x.QueryAsync( - It.Is(s => s.Contains(request.Id)), - null, - null, - It.IsAny()), Times.Never); - - this.mockTableClient.Verify(x => x.SubmitTransactionAsync( - It.IsAny>(), - It.IsAny()), Times.Never); - } - [Fact] public async Task CreateAssistantAsync_WithExistingAssistant_DeletesOldEntitiesFirst() { From 8a080de7797ed1802f16a83cd0406eae0a7f6f9b Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Sun, 13 Apr 2025 01:32:06 -0700 Subject: [PATCH 07/21] add no build to test task --- eng/ci/templates/build-local.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/eng/ci/templates/build-local.yml b/eng/ci/templates/build-local.yml index 80337141..b11122e1 100644 --- a/eng/ci/templates/build-local.yml +++ b/eng/ci/templates/build-local.yml @@ -20,10 +20,11 @@ jobs: dotnet build 
$(System.DefaultWorkingDirectory)/src/WebJobs.Extensions.OpenAI.Kusto/WebJobs.Extensions.OpenAI.Kusto.csproj --configuration $(config) -p:Version=$(fakeWebJobsPackageVersion) -p:AzureAISearchVersion=$(fakeWebJobsPackageVersion) -p:KustoVersion=$(fakeWebJobsPackageVersion) dotnet build $(System.DefaultWorkingDirectory)/src/WebJobs.Extensions.OpenAI.AzureAISearch/WebJobs.Extensions.OpenAI.AzureAISearch.csproj --configuration $(config) -p:Version=$(fakeWebJobsPackageVersion) -p:AzureAISearchVersion=$(fakeWebJobsPackageVersion) dotnet build $(System.DefaultWorkingDirectory)/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/WebJobs.Extensions.OpenAI.CosmosDBSearch.csproj --configuration $(config) -p:Version=$(fakeWebJobsPackageVersion) -p:CosmosDBSearchVersion=$(fakeWebJobsPackageVersion) + dotnet build $(System.DefaultWorkingDirectory)/tests/UnitTests/WebJobsOpenAIUnitTests.csproj --configuration $(config) -p:WebJobsVersion=$(fakeWebJobsPackageVersion) -p:Version=$(fakeWebJobsPackageVersion) displayName: Dotnet Build WebJobs.Extensions.OpenAI - script: | - dotnet test $(System.DefaultWorkingDirectory)/tests/UnitTests/WebJobsOpenAIUnitTests.csproj --configuration $(config) --collect "Code Coverage" + dotnet test $(System.DefaultWorkingDirectory)/tests/UnitTests/WebJobsOpenAIUnitTests.csproj --configuration $(config) --collect "Code Coverage" --no-build displayName: Dotnet Test WebJobsOpenAIUnitTests - task: CopyFiles@2 From f6bad6ee967c958d8efe9117bf8bcadf1c8389b3 Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Sun, 13 Apr 2025 21:02:50 -0700 Subject: [PATCH 08/21] update out of proc worker and changelog --- CHANGELOG.md | 6 +++ .../Assistants/ChatCompletionJsonConverter.cs | 12 ++--- .../Embeddings/EmbeddingsContext.cs | 20 +++----- .../Embeddings/EmbeddingsJsonConverter.cs | 14 +++--- .../EmbeddingsOptionsJsonConverter.cs | 12 ++--- .../Embeddings/JsonModelListWrapper.cs | 49 +++++++++++++++++++ 
.../Search/SearchableDocumentJsonConverter.cs | 12 +++-- .../Search/SemanticSearchContext.cs | 13 +++-- .../Startup.cs | 2 +- .../Embeddings/EmbeddingsContext.cs | 3 +- .../Embeddings/EmbeddingsContextConverter.cs | 7 ++- .../Embeddings/JsonModelListWrapper.cs | 49 +++++++++++++++++++ .../Search/SearchableDocumentJsonConverter.cs | 6 +-- .../Search/SemanticSearchContext.cs | 2 +- 14 files changed, 157 insertions(+), 50 deletions(-) create mode 100644 src/Functions.Worker.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs create mode 100644 src/WebJobs.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b54278d..7139b41f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,8 +11,14 @@ Starting v0.1.0 for Microsoft.Azure.WebJobs.Extensions.OpenAI.AzureAISearch, it ## v0.19.0 - Unreleased +### Breaking + +- Model properties named to ChatModel and EmbeddingsModel in related bindings +- Managed identity support through config section and binding parameter AIConnectionName. + ### Changed +- Updated Azure.AI.OpenAI from 1.0.0-beta.15 to 2.1.0 - Updated Azure.Data.Tables from 12.9.1 to 12.10.0, Azure.Identity from 1.12.1 to 1.13.2, Microsoft.Extensions.Azure from 1.7.5 to 1.8.0 ## v0.18.0 - 2024/10/08 diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs index a73974fa..cf471922 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs @@ -1,23 +1,23 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
-using Azure.AI.OpenAI; using System.ClientModel.Primitives; using System.Text.Json; using System.Text.Json.Serialization; +using OpenAI.Chat; namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Assistants; -public class ChatCompletionsJsonConverter : JsonConverter +public class ChatCompletionJsonConverter : JsonConverter { static readonly ModelReaderWriterOptions modelReaderWriterOptions = new("J"); - public override ChatCompletions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override ChatCompletion Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { using JsonDocument jsonDocument = JsonDocument.ParseValue(ref reader); - return ModelReaderWriter.Read(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!; + return ModelReaderWriter.Read(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!; } - public override void Write(Utf8JsonWriter writer, ChatCompletions value, JsonSerializerOptions options) + public override void Write(Utf8JsonWriter writer, ChatCompletion value, JsonSerializerOptions options) { - ((IJsonModel)value).Write(writer, modelReaderWriterOptions); + ((IJsonModel)value).Write(writer, modelReaderWriterOptions); } } \ No newline at end of file diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs index 21259707..4218199d 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs @@ -1,34 +1,30 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using OpenAISDK = Azure.AI.OpenAI; +using OpenAI.Embeddings; + namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; public class EmbeddingsContext { - /// - /// Binding target for the . - /// - /// The embeddings request that was sent to OpenAI. 
- /// The embeddings response that was received from OpenAI. - public EmbeddingsContext(OpenAISDK.EmbeddingsOptions Request, OpenAISDK.Embeddings Response) + public EmbeddingsContext(IList Input, OpenAIEmbeddingCollection? Response) { - this.Request = Request; + this.Input = Input; this.Response = Response; } /// /// Embeddings request sent to OpenAI. /// - public OpenAISDK.EmbeddingsOptions Request { get; set; } + public IList Input { get; set; } /// /// Embeddings response from OpenAI. /// - public OpenAISDK.Embeddings Response { get; set; } - + public OpenAIEmbeddingCollection? Response { get; set; } + /// /// Gets the number of embeddings that were returned in the response. /// - public int Count => this.Response.Data?.Count ?? 0; + public int Count => this.Response?.Count ?? 0; } diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsJsonConverter.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsJsonConverter.cs index 30a3679a..cfcc5d23 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsJsonConverter.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsJsonConverter.cs @@ -4,25 +4,25 @@ using System.ClientModel.Primitives; using System.Text.Json; using System.Text.Json.Serialization; -using OpenAISDK = Azure.AI.OpenAI; +using OpenAI.Embeddings; namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; /// -/// Embeddings JSON converter needed to serialize and deserialize the Embeddings object with the dotnet worker. +/// OpenAIEmbeddingCollection JSON converter needed to serialize and deserialize the OpenAIEmbeddingCollection object with the dotnet worker. 
/// -class EmbeddingsJsonConverter : JsonConverter +class EmbeddingsJsonConverter : JsonConverter { static readonly ModelReaderWriterOptions JsonOptions = new("J"); - public override OpenAISDK.Embeddings Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override OpenAIEmbeddingCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { using JsonDocument jsonDocument = JsonDocument.ParseValue(ref reader); - return ModelReaderWriter.Read(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!; + return ModelReaderWriter.Read(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!; } - public override void Write(Utf8JsonWriter writer, OpenAISDK.Embeddings value, JsonSerializerOptions options) + public override void Write(Utf8JsonWriter writer, OpenAIEmbeddingCollection value, JsonSerializerOptions options) { - ((IJsonModel)value).Write(writer, JsonOptions); + ((IJsonModel)value).Write(writer, JsonOptions); } } diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs index 0a8949dc..9ca6bdf8 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs @@ -4,25 +4,25 @@ using System.ClientModel.Primitives; using System.Text.Json; using System.Text.Json.Serialization; -using Azure.AI.OpenAI; +using OpenAI.Embeddings; namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; /// /// EmbeddingsOptions JSON converter needed to serialize and deserialize the EmbeddingsOptions object with the dotnet worker. 
/// -class EmbeddingsOptionsJsonConverter : JsonConverter +class EmbeddingsOptionsJsonConverter : JsonConverter { static readonly ModelReaderWriterOptions JsonOptions = new("J"); - public override EmbeddingsOptions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override EmbeddingGenerationOptions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { using JsonDocument jsonDocument = JsonDocument.ParseValue(ref reader); - return ModelReaderWriter.Read(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!; + return ModelReaderWriter.Read(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!; } - public override void Write(Utf8JsonWriter writer, EmbeddingsOptions value, JsonSerializerOptions options) + public override void Write(Utf8JsonWriter writer, EmbeddingGenerationOptions value, JsonSerializerOptions options) { - ((IJsonModel)value).Write(writer, JsonOptions); + ((IJsonModel)value).Write(writer, JsonOptions); } } diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs new file mode 100644 index 00000000..242a8fe0 --- /dev/null +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs @@ -0,0 +1,49 @@ +using System.ClientModel.Primitives; +using System.Text.Json; + +namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; + +class JsonModelListWrapper : IJsonModel> +{ + readonly List list; + + public JsonModelListWrapper(List list) + { + this.list = list; + } + + public void Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + writer.WriteStartArray(); + foreach (string item in this.list) + { + writer.WriteStringValue(item); + } + writer.WriteEndArray(); + } + + public static JsonModelListWrapper FromList(List list) + { + return new JsonModelListWrapper(list); + } + + public List Create(ref Utf8JsonReader reader, 
ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public BinaryData Write(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public List Create(BinaryData data, ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public string GetFormatFromOptions(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } +} \ No newline at end of file diff --git a/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs b/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs index 959912a5..ac8e25fa 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs @@ -4,7 +4,8 @@ using System.ClientModel.Primitives; using System.Text.Json; using System.Text.Json.Serialization; -using OpenAISDK = Azure.AI.OpenAI; +using Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; +using OpenAI.Embeddings; namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Search; @@ -23,12 +24,13 @@ public override void Write(Utf8JsonWriter writer, SearchableDocument value, Json writer.WritePropertyName("embeddingsContext"u8); writer.WriteStartObject(); - if (value.EmbeddingsContext?.Request is IJsonModel request) + if (value.EmbeddingsContext?.Input is List inputList) { - writer.WritePropertyName("request"u8); - request.Write(writer, modelReaderWriterOptions); + writer.WritePropertyName("input"u8); + var inputWrapper = JsonModelListWrapper.FromList(inputList); + inputWrapper.Write(writer, modelReaderWriterOptions); } - if (value.EmbeddingsContext?.Response is IJsonModel response) + if (value.EmbeddingsContext?.Response is IJsonModel response) { writer.WritePropertyName("response"u8); response.Write(writer, modelReaderWriterOptions); diff --git 
a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchContext.cs b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchContext.cs index b40a0dc8..8fb1788d 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchContext.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchContext.cs @@ -2,8 +2,8 @@ // Licensed under the MIT License. using System.Text.Json.Serialization; -using Azure.AI.OpenAI; using Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; +using OpenAI.Chat; namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Search; @@ -17,11 +17,11 @@ public class SemanticSearchContext /// /// The embeddings context associated with the semantic search. /// The chat response from the large language model. - public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletions Chat) + public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletion Chat) { this.Embeddings = Embeddings; this.Chat = Chat; - + } /// @@ -34,12 +34,11 @@ public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletions Chat) /// Gets the chat response from the large language model. /// [JsonPropertyName("chat")] - public ChatCompletions Chat { get; } - + public ChatCompletion Chat { get; } /// /// Gets the latest response message from the OpenAI Chat API. 
/// [JsonPropertyName("response")] - public string Response => this.Chat.Choices.Last().Message.Content; - } + public string Response => this.Chat.Content.Last().Text; +} diff --git a/src/Functions.Worker.Extensions.OpenAI/Startup.cs b/src/Functions.Worker.Extensions.OpenAI/Startup.cs index 8bbd524b..e4ad38a9 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Startup.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Startup.cs @@ -25,7 +25,7 @@ public override void Configure(IFunctionsWorkerApplicationBuilder applicationBui jsonSerializerOptions.Converters.Add(new EmbeddingsJsonConverter()); jsonSerializerOptions.Converters.Add(new EmbeddingsOptionsJsonConverter()); jsonSerializerOptions.Converters.Add(new SearchableDocumentJsonConverter()); - jsonSerializerOptions.Converters.Add(new ChatCompletionsJsonConverter()); + jsonSerializerOptions.Converters.Add(new ChatCompletionJsonConverter()); }); } } diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs index ab59b3d0..db182d4f 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs @@ -2,12 +2,13 @@ // Licensed under the MIT License. using OpenAI.Embeddings; + namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; /// /// Binding target for the . /// -/// The embeddings request that was sent to OpenAI. +/// The embeddings input that was sent to OpenAI. /// The embeddings response that was received from OpenAI. 
public class EmbeddingsContext { diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs index 6634064a..d848d22f 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs @@ -23,7 +23,12 @@ public override void Write(Utf8JsonWriter writer, EmbeddingsContext value, JsonS { writer.WriteStartObject(); writer.WritePropertyName("input"u8); - ((IJsonModel)value.Input).Write(writer, modelReaderWriterOptions); + + if (value.Input is List inputList) + { + var inputWrapper = JsonModelListWrapper.FromList(inputList); + inputWrapper.Write(writer, modelReaderWriterOptions); + } if (value.Response is IJsonModel response) { diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs new file mode 100644 index 00000000..0a4277ee --- /dev/null +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs @@ -0,0 +1,49 @@ +using System.ClientModel.Primitives; +using System.Text.Json; + +namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; + +class JsonModelListWrapper : IJsonModel> +{ + readonly List list; + + public JsonModelListWrapper(List list) + { + this.list = list; + } + + public void Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + writer.WriteStartArray(); + foreach (string item in this.list) + { + writer.WriteStringValue(item); + } + writer.WriteEndArray(); + } + + public static JsonModelListWrapper FromList(List list) + { + return new JsonModelListWrapper(list); + } + + public List Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public BinaryData Write(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public List Create(BinaryData data, 
ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public string GetFormatFromOptions(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } +} \ No newline at end of file diff --git a/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs b/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs index 5bbda29c..3dd4910b 100644 --- a/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs @@ -78,10 +78,10 @@ public override void Write(Utf8JsonWriter writer, SearchableDocument value, Json writer.WritePropertyName("embeddingsContext"u8); writer.WriteStartObject(); - if (value.Embeddings?.Input is IJsonModel request) + if (value.Embeddings?.Input is List inputList) { - writer.WritePropertyName("input"u8); - request.Write(writer, modelReaderWriterOptions); + var inputWrapper = JsonModelListWrapper.FromList(inputList); + inputWrapper.Write(writer, modelReaderWriterOptions); } if (value.Embeddings?.Response is IJsonModel response) diff --git a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchContext.cs b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchContext.cs index a51b8fc9..a0bf8073 100644 --- a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchContext.cs +++ b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchContext.cs @@ -40,5 +40,5 @@ public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletion Chat) /// Gets the latest response message from the OpenAI Chat API. 
/// [JsonProperty("response")] - public string Response => this.Chat.Content.LastOrDefault().Text; + public string Response => this.Chat.Content.Last().Text; } From 7ec128ca751ef7a92220f5cf84e99cbf9fff996c Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Sun, 13 Apr 2025 21:15:13 -0700 Subject: [PATCH 09/21] update powershell and node samples --- samples/assistant/javascript/src/functions/assistantApis.js | 2 +- samples/assistant/powershell/PostUserQuery/function.json | 2 +- samples/assistant/typescript/src/functions/assistantApis.ts | 2 +- samples/chat/javascript/src/app.js | 2 +- samples/chat/typescript/src/functions/app.ts | 2 +- samples/embeddings/javascript/src/app.js | 6 +++--- .../embeddings/powershell/GenerateEmbeddings/function.json | 2 +- .../powershell/GetEmbeddingsFilePath/function.json | 2 +- .../embeddings/powershell/GetEmbeddingsURL/function.json | 2 +- samples/embeddings/typescript/src/app.ts | 6 +++--- samples/rag-aisearch/javascript/src/app.js | 2 +- samples/rag-aisearch/powershell/IngestFile/function.json | 2 +- samples/rag-aisearch/typescript/src/app.ts | 2 +- samples/rag-cosmosdb/javascript/src/app.js | 2 +- samples/rag-cosmosdb/powershell/IngestFile/function.json | 2 +- samples/rag-cosmosdb/typescript/src/app.ts | 2 +- samples/rag-kusto/javascript/src/app.js | 2 +- samples/rag-kusto/powershell/IngestEmail/function.json | 2 +- samples/rag-kusto/typescript/src/app.ts | 2 +- samples/textcompletion/javascript/src/functions/whois.js | 2 +- samples/textcompletion/powershell/WhoIs/function.json | 2 +- samples/textcompletion/typescript/src/functions/whois.ts | 2 +- 22 files changed, 26 insertions(+), 26 deletions(-) diff --git a/samples/assistant/javascript/src/functions/assistantApis.js b/samples/assistant/javascript/src/functions/assistantApis.js index a8eaced4..3d167a10 100644 --- a/samples/assistant/javascript/src/functions/assistantApis.js +++ b/samples/assistant/javascript/src/functions/assistantApis.js 
@@ -36,7 +36,7 @@ app.http('CreateAssistant', { const assistantPostInput = input.generic({ type: 'assistantPost', id: '{assistantId}', - model: '%CHAT_MODEL_DEPLOYMENT_NAME%', + chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%', userMessage: '{Query.message}', chatStorageConnectionSetting: CHAT_STORAGE_CONNECTION_SETTING, collectionName: COLLECTION_NAME diff --git a/samples/assistant/powershell/PostUserQuery/function.json b/samples/assistant/powershell/PostUserQuery/function.json index 024cd24f..05b7702c 100644 --- a/samples/assistant/powershell/PostUserQuery/function.json +++ b/samples/assistant/powershell/PostUserQuery/function.json @@ -22,7 +22,7 @@ "dataType": "string", "id": "{assistantId}", "userMessage": "{Query.message}", - "model": "%CHAT_MODEL_DEPLOYMENT_NAME%", + "chatModel": "%CHAT_MODEL_DEPLOYMENT_NAME%", "chatStorageConnectionSetting": "AzureWebJobsStorage", "collectionName": "ChatState" } diff --git a/samples/assistant/typescript/src/functions/assistantApis.ts b/samples/assistant/typescript/src/functions/assistantApis.ts index 35bdd59c..c3bd70bf 100644 --- a/samples/assistant/typescript/src/functions/assistantApis.ts +++ b/samples/assistant/typescript/src/functions/assistantApis.ts @@ -36,7 +36,7 @@ app.http('CreateAssistant', { const assistantPostInput = input.generic({ type: 'assistantPost', id: '{assistantId}', - model: '%CHAT_MODEL_DEPLOYMENT_NAME%', + chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%', userMessage: '{Query.message}', chatStorageConnectionSetting: CHAT_STORAGE_CONNECTION_SETTING, collectionName: COLLECTION_NAME diff --git a/samples/chat/javascript/src/app.js b/samples/chat/javascript/src/app.js index f0a86ddd..4b3b3409 100644 --- a/samples/chat/javascript/src/app.js +++ b/samples/chat/javascript/src/app.js @@ -52,7 +52,7 @@ app.http('GetChatState', { const assistantPostInput = input.generic({ type: 'assistantPost', id: '{chatID}', - model: '%CHAT_MODEL_DEPLOYMENT_NAME%', + chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%', userMessage: '{Query.message}', 
chatStorageConnectionSetting: CHAT_STORAGE_CONNECTION_SETTING, collectionName: COLLECTION_NAME diff --git a/samples/chat/typescript/src/functions/app.ts b/samples/chat/typescript/src/functions/app.ts index 84a12075..a5cb025c 100644 --- a/samples/chat/typescript/src/functions/app.ts +++ b/samples/chat/typescript/src/functions/app.ts @@ -52,7 +52,7 @@ app.http('GetChatState', { const assistantPostInput = input.generic({ type: 'assistantPost', id: '{chatID}', - model: '%CHAT_MODEL_DEPLOYMENT_NAME%', + chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%', userMessage: '{Query.message}', chatStorageConnectionSetting: CHAT_STORAGE_CONNECTION_SETTING, collectionName: COLLECTION_NAME diff --git a/samples/embeddings/javascript/src/app.js b/samples/embeddings/javascript/src/app.js index 7c53e652..6a52be79 100644 --- a/samples/embeddings/javascript/src/app.js +++ b/samples/embeddings/javascript/src/app.js @@ -4,7 +4,7 @@ const embeddingsHttpInput = input.generic({ input: '{rawText}', inputType: 'RawText', type: 'embeddings', - model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' + embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' }) app.http('generateEmbeddings', { @@ -31,7 +31,7 @@ const embeddingsFilePathInput = input.generic({ inputType: 'FilePath', type: 'embeddings', maxChunkLength: 512, - model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' + embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' }) app.http('getEmbeddingsFilePath', { @@ -58,7 +58,7 @@ const embeddingsUrlInput = input.generic({ inputType: 'Url', type: 'embeddings', maxChunkLength: 512, - model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' + embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' }) app.http('getEmbeddingsUrl', { diff --git a/samples/embeddings/powershell/GenerateEmbeddings/function.json b/samples/embeddings/powershell/GenerateEmbeddings/function.json index 57621818..0236fcdf 100644 --- a/samples/embeddings/powershell/GenerateEmbeddings/function.json +++ b/samples/embeddings/powershell/GenerateEmbeddings/function.json @@ -21,7 
+21,7 @@ "direction": "in", "inputType": "RawText", "input": "{rawText}", - "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" } ] } \ No newline at end of file diff --git a/samples/embeddings/powershell/GetEmbeddingsFilePath/function.json b/samples/embeddings/powershell/GetEmbeddingsFilePath/function.json index 3209aed5..1d8d44ed 100644 --- a/samples/embeddings/powershell/GetEmbeddingsFilePath/function.json +++ b/samples/embeddings/powershell/GetEmbeddingsFilePath/function.json @@ -22,7 +22,7 @@ "inputType": "FilePath", "input": "{filePath}", "maxChunkLength": 512, - "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" } ] } \ No newline at end of file diff --git a/samples/embeddings/powershell/GetEmbeddingsURL/function.json b/samples/embeddings/powershell/GetEmbeddingsURL/function.json index ebf96844..de11ff4b 100644 --- a/samples/embeddings/powershell/GetEmbeddingsURL/function.json +++ b/samples/embeddings/powershell/GetEmbeddingsURL/function.json @@ -22,7 +22,7 @@ "inputType": "Url", "input": "{url}", "maxChunkLength": 512, - "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" } ] } \ No newline at end of file diff --git a/samples/embeddings/typescript/src/app.ts b/samples/embeddings/typescript/src/app.ts index 9af76b96..f749cd0c 100644 --- a/samples/embeddings/typescript/src/app.ts +++ b/samples/embeddings/typescript/src/app.ts @@ -8,7 +8,7 @@ const embeddingsHttpInput = input.generic({ input: '{rawText}', inputType: 'RawText', type: 'embeddings', - model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' + embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' }) app.http('generateEmbeddings', { @@ -39,7 +39,7 @@ const embeddingsFilePathInput = input.generic({ inputType: 'FilePath', type: 'embeddings', maxChunkLength: 512, - model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' + embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' }) 
app.http('getEmbeddingsFilePath', { @@ -70,7 +70,7 @@ const embeddingsUrlInput = input.generic({ inputType: 'Url', type: 'embeddings', maxChunkLength: 512, - model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' + embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' }) app.http('getEmbeddingsUrl', { diff --git a/samples/rag-aisearch/javascript/src/app.js b/samples/rag-aisearch/javascript/src/app.js index 6bc7cc8f..c68b58eb 100644 --- a/samples/rag-aisearch/javascript/src/app.js +++ b/samples/rag-aisearch/javascript/src/app.js @@ -8,7 +8,7 @@ const embeddingsStoreOutput = output.generic({ inputType: "url", connectionName: "AISearchEndpoint", collection: "openai-index", - model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" }); app.http('IngestFile', { diff --git a/samples/rag-aisearch/powershell/IngestFile/function.json b/samples/rag-aisearch/powershell/IngestFile/function.json index 96a6a06e..be786a8b 100644 --- a/samples/rag-aisearch/powershell/IngestFile/function.json +++ b/samples/rag-aisearch/powershell/IngestFile/function.json @@ -22,7 +22,7 @@ "inputType": "Url", "connectionName": "AISearchEndpoint", "collection": "openai-index", - "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" } ] } \ No newline at end of file diff --git a/samples/rag-aisearch/typescript/src/app.ts b/samples/rag-aisearch/typescript/src/app.ts index ec430e24..176686d1 100644 --- a/samples/rag-aisearch/typescript/src/app.ts +++ b/samples/rag-aisearch/typescript/src/app.ts @@ -11,7 +11,7 @@ const embeddingsStoreOutput = output.generic({ inputType: "url", connectionName: "AISearchEndpoint", collection: "openai-index", - model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" }); app.http('IngestFile', { diff --git a/samples/rag-cosmosdb/javascript/src/app.js b/samples/rag-cosmosdb/javascript/src/app.js index 274917c4..0116e652 100644 --- 
a/samples/rag-cosmosdb/javascript/src/app.js +++ b/samples/rag-cosmosdb/javascript/src/app.js @@ -8,7 +8,7 @@ const embeddingsStoreOutput = output.generic({ inputType: "url", connectionName: "CosmosDBMongoVCoreConnectionString", collection: "openai-index", - model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" }); app.http('IngestFile', { diff --git a/samples/rag-cosmosdb/powershell/IngestFile/function.json b/samples/rag-cosmosdb/powershell/IngestFile/function.json index 55b986e0..cdb86803 100644 --- a/samples/rag-cosmosdb/powershell/IngestFile/function.json +++ b/samples/rag-cosmosdb/powershell/IngestFile/function.json @@ -22,7 +22,7 @@ "inputType": "Url", "connectionName": "CosmosDBMongoVCoreConnectionString", "collection": "openai-index", - "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" } ] } \ No newline at end of file diff --git a/samples/rag-cosmosdb/typescript/src/app.ts b/samples/rag-cosmosdb/typescript/src/app.ts index c1808411..3ef9257e 100644 --- a/samples/rag-cosmosdb/typescript/src/app.ts +++ b/samples/rag-cosmosdb/typescript/src/app.ts @@ -11,7 +11,7 @@ const embeddingsStoreOutput = output.generic({ inputType: "url", connectionName: "CosmosDBMongoVCoreConnectionString", collection: "openai-index", - model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" }); app.http('IngestFile', { diff --git a/samples/rag-kusto/javascript/src/app.js b/samples/rag-kusto/javascript/src/app.js index 655e64e6..87b948d1 100644 --- a/samples/rag-kusto/javascript/src/app.js +++ b/samples/rag-kusto/javascript/src/app.js @@ -8,7 +8,7 @@ const embeddingsStoreOutput = output.generic({ inputType: "url", connectionName: "KustoConnectionString", collection: "Documents", - model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" }); app.http('IngestEmail', { diff --git 
a/samples/rag-kusto/powershell/IngestEmail/function.json b/samples/rag-kusto/powershell/IngestEmail/function.json index 62ad48c4..31bd2eaa 100644 --- a/samples/rag-kusto/powershell/IngestEmail/function.json +++ b/samples/rag-kusto/powershell/IngestEmail/function.json @@ -22,7 +22,7 @@ "inputType": "Url", "connectionName": "KustoConnectionString", "collection": "Documents", - "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" } ] } \ No newline at end of file diff --git a/samples/rag-kusto/typescript/src/app.ts b/samples/rag-kusto/typescript/src/app.ts index bd0df1d8..bfc17d70 100644 --- a/samples/rag-kusto/typescript/src/app.ts +++ b/samples/rag-kusto/typescript/src/app.ts @@ -11,7 +11,7 @@ const embeddingsStoreOutput = output.generic({ inputType: "url", connectionName: "KustoConnectionString", collection: "Documents", - model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" }); app.http('IngestEmail', { diff --git a/samples/textcompletion/javascript/src/functions/whois.js b/samples/textcompletion/javascript/src/functions/whois.js index 665af065..96725512 100644 --- a/samples/textcompletion/javascript/src/functions/whois.js +++ b/samples/textcompletion/javascript/src/functions/whois.js @@ -5,7 +5,7 @@ const openAICompletionInput = input.generic({ prompt: 'Who is {name}?', maxTokens: '100', type: 'textCompletion', - model: '%CHAT_MODEL_DEPLOYMENT_NAME%' + chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%' }) app.http('whois', { diff --git a/samples/textcompletion/powershell/WhoIs/function.json b/samples/textcompletion/powershell/WhoIs/function.json index bdc381c7..6abf6c9b 100644 --- a/samples/textcompletion/powershell/WhoIs/function.json +++ b/samples/textcompletion/powershell/WhoIs/function.json @@ -21,7 +21,7 @@ "name": "TextCompletionResponse", "prompt": "Who is {name}?", "maxTokens": "100", - "model": "%CHAT_MODEL_DEPLOYMENT_NAME%" + "chatModel": "%CHAT_MODEL_DEPLOYMENT_NAME%" 
} ] } \ No newline at end of file diff --git a/samples/textcompletion/typescript/src/functions/whois.ts b/samples/textcompletion/typescript/src/functions/whois.ts index 22f8c2ee..b910f51c 100644 --- a/samples/textcompletion/typescript/src/functions/whois.ts +++ b/samples/textcompletion/typescript/src/functions/whois.ts @@ -5,7 +5,7 @@ const openAICompletionInput = input.generic({ prompt: 'Who is {name}?', maxTokens: '100', type: 'textCompletion', - model: '%CHAT_MODEL_DEPLOYMENT_NAME%' + chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%' }) app.http('whois', { From 004de6ede4b855a005cee3a66fe6d49c6e94b34f Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Sun, 13 Apr 2025 21:29:58 -0700 Subject: [PATCH 10/21] update java library --- .../annotation/assistant/AssistantPost.java | 45 ++++++++- .../assistant/AssistantSkillTrigger.java | 95 ++++++++++--------- .../annotation/assistant/ChatMessage.java | 33 ++++--- .../embeddings/EmbeddingsInput.java | 28 ++++-- .../embeddings/EmbeddingsStoreOutput.java | 31 ++++-- .../annotation/search/SemanticSearch.java | 61 ++++++++++-- .../textcompletion/TextCompletion.java | 40 +++++--- 7 files changed, 230 insertions(+), 103 deletions(-) diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java index 67a3cc3f..c6065206 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java @@ -61,7 +61,7 @@ * * @return The OpenAI chat model to use. */ - String model(); + String chatModel(); /** * The user message that user has entered for assistant to respond to. 
@@ -70,6 +70,14 @@ */ String userMessage(); + /** + * The name of the configuration section for AI service connectivity settings. + * + * @return The name of the configuration section for AI service connectivity + * settings. + */ + String aiConnectionName() default ""; + /** * The configuration section name for the table settings for assistant chat * storage. @@ -86,4 +94,39 @@ * returns {@code DEFAULT_COLLECTION}. */ String collectionName() default DEFAULT_COLLECTION; + + /** + * The sampling temperature to use, between 0 and 2. Higher values like 0.8 will + * make the output + * more random, while lower values like 0.2 will make it more focused and + * deterministic. + * It's generally recommended to use this or {@link #topP()} but not both. + * + * @return The sampling temperature value. + */ + String temperature() default "0.5"; + + /** + * An alternative to sampling with temperature, called nucleus sampling, where + * the model considers + * the results of the tokens with top_p probability mass. So 0.1 means only the + * tokens comprising the top 10% + * probability mass are considered. + * It's generally recommended to use this or {@link #temperature()} but not + * both. + * + * @return The topP value. + */ + String topP() default ""; + + /** + * The maximum number of tokens to generate in the completion. + * The token count of your prompt plus max_tokens cannot exceed the model's + * context length. + * Most models have a context length of 2048 tokens (except for the newest + * models, which support 4096). + * + * @return The maxTokens value. 
+ */ + String maxTokens() default "100"; } diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java index 610a1fcd..ec6b29d8 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java @@ -14,54 +14,55 @@ import java.lang.annotation.Target; /** - *

- * Assistant skill trigger attribute. - *

- * - * @since 1.0.0 - */ - @Retention(RetentionPolicy.RUNTIME) - @Target(ElementType.PARAMETER) - @CustomBinding(direction = "in", name = "", type = "assistantSkillTrigger") - public @interface AssistantSkillTrigger { - - /** - * The variable name used in function.json. - * - * @return The variable name used in function.json. - */ - String name(); + *

+ * Assistant skill trigger attribute. + *

+ * + * @since 1.0.0 + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.PARAMETER) +@CustomBinding(direction = "in", name = "", type = "assistantSkillTrigger") +public @interface AssistantSkillTrigger { + + /** + * The variable name used in function.json. + * + * @return The variable name used in function.json. + */ + String name(); + + /** + * The name of the function to be invoked by the assistant. + * + * @return The name of the function to be invoked by the assistant. + */ + String functionName() default ""; - /** - * The name of the function to be invoked by the assistant. - * - * @return The name of the function to be invoked by the assistant. - */ - String functionName() default ""; + /** + * The description of the assistant function, which is provided to the LLM. + * + * @return The description of the assistant function. + */ + String functionDescription(); - /** - * The description of the assistant function, which is provided to the LLM. - * - * @return The description of the assistant function. - */ - String functionDescription(); + /** + * The JSON description of the function parameter, which is provided to the LLM. + * If no description is provided, the description will be autogenerated. + * For more information on the syntax of the parameter description JSON, see the + * OpenAI API documentation: + * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools. + * + * @return The JSON description of the function parameter. + */ + String parameterDescriptionJson() default ""; - /** - * The JSON description of the function parameter, which is provided to the LLM. - * If no description is provided, the description will be autogenerated. - * For more information on the syntax of the parameter description JSON, see the OpenAI API documentation: - * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools. - * - * @return The JSON description of the function parameter. 
- */ - String parameterDescriptionJson() default ""; + /** + * The OpenAI chat model to use. + * When using Azure OpenAI, this should be the name of the model deployment. + * + * @return The OpenAI chat model to use. + */ + String chatModel() default "gpt-3.5-turbo"; - /** - * The OpenAI chat model to use. - * When using Azure OpenAI, this should be the name of the model deployment. - * - * @return The OpenAI chat model to use. - */ - String model() default "gpt-3.5-turbo"; - - } +} diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/ChatMessage.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/ChatMessage.java index 1d246201..9e1dfbcb 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/ChatMessage.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/ChatMessage.java @@ -8,14 +8,15 @@ /** *

- * Chat Message Entity which contains the content of the message, the role of the chat agent, and the name of the calling function if applicable. + * Chat Message Entity which contains the content of the message, the role of + * the chat agent, and the name of the calling function if applicable. *

*/ public class ChatMessage { private String content; private String role; - private String name; + private String toolCalls; /** * Initializes a new instance of the ChatMessage class. @@ -28,18 +29,17 @@ public ChatMessage(String content, String role) { this.role = role; } - /** * Initializes a new instance of the ChatMessage class. * - * @param content The content of the message. - * @param role The role of the chat agent. - * @param name The name of the calling function if applicable. + * @param content The content of the message. + * @param role The role of the chat agent. + * @param toolCalls The toolCalls of the calling function if applicable. */ - public ChatMessage(String content, String role, String name) { + public ChatMessage(String content, String role, String toolCalls) { this.content = content; this.role = role; - this.name = name; + this.toolCalls = toolCalls; } /** @@ -79,21 +79,20 @@ public void setRole(String role) { } /** - * Gets the name of the calling function if applicable. + * Gets the toolCalls of the calling function if applicable. * - * @return The name of the calling function if applicable. + * @return The toolCalls of the calling function if applicable. */ - public String getName() { - return name; + public String getToolCalls() { + return toolCalls; } /** - * Sets the name of the calling function if applicable. + * Sets the toolCalls of the calling function if applicable. * - * @param name The name of the calling function if applicable. + * @param name The toolCalls of the calling function if applicable. 
*/ - public void setName(String name) { - this.name = name; + public void setToolCalls(String toolCalls) { + this.toolCalls = toolCalls; } } - diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java index 39d46691..c3a65ac5 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java @@ -12,7 +12,7 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; - + /** *

* Embeddings input attribute which is used for embedding generation. @@ -34,17 +34,21 @@ /** * The ID of the model to use. - * Changing the default embeddings model is a breaking change, since any changes will be stored in a vector database for lookup. - * Changing the default model can cause the lookups to start misbehaving if they don't match the data that was previously ingested into the vector database. + * Changing the default embeddings model is a breaking change, since any changes + * will be stored in a vector database for lookup. + * Changing the default model can cause the lookups to start misbehaving if they + * don't match the data that was previously ingested into the vector database. * * @return The model ID. */ - String model() default "text-embedding-ada-002"; + String embeddingsModel() default "text-embedding-ada-002"; /** * The maximum number of characters to chunk the input into. - * At the time of writing, the maximum input tokens allowed for second-generation input embedding models - * like text-embedding-ada-002 is 8191. 1 token is ~4 chars in English, which translates to roughly 32K + * At the time of writing, the maximum input tokens allowed for + * second-generation input embedding models + * like text-embedding-ada-002 is 8191. 1 token is ~4 chars in English, which + * translates to roughly 32K * characters of English input that can fit into a single chunk. * * @return The maximum number of characters to chunk the input into. @@ -65,11 +69,19 @@ */ String input(); + /** + * The name of the configuration section for AI service connectivity settings. + * + * @return The name of the configuration section for AI service connectivity + * settings. + */ + String aiConnectionName() default ""; + /** * The input type. * * @return The input type. 
*/ InputType inputType(); - - } + +} diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java index 87402766..7c01a724 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java @@ -27,17 +27,21 @@ /** * The ID of the model to use. - * Changing the default embeddings model is a breaking change, since any changes will be stored in a vector database for lookup. - * Changing the default model can cause the lookups to start misbehaving if they don't match the data that was previously ingested into the vector database. + * Changing the default embeddings model is a breaking change, since any changes + * will be stored in a vector database for lookup. + * Changing the default model can cause the lookups to start misbehaving if they + * don't match the data that was previously ingested into the vector database. * * @return The model ID. */ - String model() default "text-embedding-ada-002"; + String embeddingsModel() default "text-embedding-ada-002"; /** * The maximum number of characters to chunk the input into. - * At the time of writing, the maximum input tokens allowed for second-generation input embedding models - * like text-embedding-ada-002 is 8191. 1 token is ~4 chars in English, which translates to roughly 32K + * At the time of writing, the maximum input tokens allowed for + * second-generation input embedding models + * like text-embedding-ada-002 is 8191. 1 token is ~4 chars in English, which + * translates to roughly 32K * characters of English input that can fit into a single chunk. * * @return The maximum number of characters to chunk the input into. 
@@ -66,13 +70,22 @@ InputType inputType(); /** - * The name of an app setting or environment variable which contains a connection string value. + * The name of the configuration section for AI service connectivity settings. + * + * @return The name of the configuration section for AI service connectivity + * settings. + */ + String aiConnectionName() default ""; + + /** + * The name of an app setting or environment variable which contains a + * connection string value. * This property supports binding expressions. * - * @return The connection name. + * @return The store connection name. */ - String connectionName(); - + String storeConnectionName(); + /** * The name of the collection or table to search. * This property supports binding expressions. diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java index 87941b85..a6669175 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java @@ -8,10 +8,10 @@ import com.microsoft.azure.functions.annotation.CustomBinding; - import java.lang.annotation.ElementType; - import java.lang.annotation.Retention; - import java.lang.annotation.RetentionPolicy; - import java.lang.annotation.Target; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.PARAMETER) @@ -26,12 +26,21 @@ String name(); /** - * The name of an app setting or environment variable which contains a connection string value. + * The name of the configuration section for AI service connectivity settings. 
+ * + * @return The name of the configuration section for AI service connectivity + * settings. + */ + String aiConnectionName() default ""; + + /** + * The name of an app setting or environment variable which contains a + * connection string value. * This property supports binding expressions. * * @return The connection name. */ - String connectionName(); + String searchConnectionName(); /** * The name of the collection or table to search. @@ -50,7 +59,6 @@ */ String query() default ""; - /** * The model to use for embeddings. * The default value is "text-embedding-ada-002". @@ -69,10 +77,10 @@ */ String chatModel() default "gpt-3.5-turbo"; - /** * The system prompt to use for prompting the large language model. - * The system prompt will be appended with knowledge that is fetched as a result of the Query. + * The system prompt will be appended with knowledge that is fetched as a result + * of the Query. * The combined prompt will then be sent to the OpenAI Chat API. * This property supports binding expressions. * @@ -94,4 +102,39 @@ String systemPrompt() default "You are a helpful assistant. You are responding t * @return The number of knowledge items to inject into the SystemPrompt. */ int maxKnowledgeLength() default 1; + + /** + * The sampling temperature to use, between 0 and 2. Higher values like 0.8 will + * make the output + * more random, while lower values like 0.2 will make it more focused and + * deterministic. + * It's generally recommended to use this or {@link #topP()} but not both. + * + * @return The sampling temperature value. + */ + String temperature() default "0.5"; + + /** + * An alternative to sampling with temperature, called nucleus sampling, where + * the model considers + * the results of the tokens with top_p probability mass. So 0.1 means only the + * tokens comprising the top 10% + * probability mass are considered. + * It's generally recommended to use this or {@link #temperature()} but not + * both. + * + * @return The topP value. 
+ */ + String topP() default ""; + + /** + * The maximum number of tokens to generate in the completion. + * The token count of your prompt plus max_tokens cannot exceed the model's + * context length. + * Most models have a context length of 2048 tokens (except for the newest + * models, which support 4096). + * + * @return The maxTokens value. + */ + String maxTokens() default "2048"; } diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java index 1d024ce4..7f67659b 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java @@ -12,10 +12,11 @@ import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; - + /** *

- * Assistant query input attribute which is used query the Assistant to get current state. + * Assistant query input attribute which is used query the Assistant to get + * current state. *

* * @since 1.0.0 @@ -39,16 +40,26 @@ */ String prompt(); + /** + * The name of the configuration section for AI service connectivity settings. + * + * @return The name of the configuration section for AI service connectivity + * settings. + */ + String aiConnectionName() default ""; + /** * The ID of the model to use. * * @return The model ID. */ - String model() default "gpt-3.5-turbo"; + String chatModel() default "gpt-3.5-turbo"; /** - * The sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output - * more random, while lower values like 0.2 will make it more focused and deterministic. + * The sampling temperature to use, between 0 and 2. Higher values like 0.8 will + * make the output + * more random, while lower values like 0.2 will make it more focused and + * deterministic. * It's generally recommended to use this or {@link #topP()} but not both. * * @return The sampling temperature value. @@ -56,10 +67,13 @@ String temperature() default "0.5"; /** - * An alternative to sampling with temperature, called nucleus sampling, where the model considers - * the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + * An alternative to sampling with temperature, called nucleus sampling, where + * the model considers + * the results of the tokens with top_p probability mass. So 0.1 means only the + * tokens comprising the top 10% * probability mass are considered. - * It's generally recommended to use this or {@link #temperature()} but not both. + * It's generally recommended to use this or {@link #temperature()} but not + * both. * * @return The topP value. */ @@ -67,11 +81,13 @@ /** * The maximum number of tokens to generate in the completion. - * The token count of your prompt plus max_tokens cannot exceed the model's context length. - * Most models have a context length of 2048 tokens (except for the newest models, which support 4096). 
+ * The token count of your prompt plus max_tokens cannot exceed the model's + * context length. + * Most models have a context length of 2048 tokens (except for the newest + * models, which support 4096). * * @return The maxTokens value. */ String maxTokens() default "100"; - - } + +} From 95f10dc21d3a1b79718c0caa4deddd8d5a46ef6d Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Sun, 13 Apr 2025 23:49:02 -0700 Subject: [PATCH 11/21] update assistant message --- CHANGELOG.md | 9 ++++++--- java-library/CHANGELOG.md | 10 +++++++++- java-library/pom.xml | 2 +- ...ChatMessage.java => AssistantMessage.java} | 19 ++++--------------- .../annotation/assistant/AssistantState.java | 14 +++++++------- .../embeddings/EmbeddingsContext.java | 12 ++++++------ .../{ChatMessage.cs => AssistantMessage.cs} | 6 +++--- .../Assistants/AssistantState.cs | 2 +- 8 files changed, 37 insertions(+), 37 deletions(-) rename java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/{ChatMessage.java => AssistantMessage.java} (78%) rename src/Functions.Worker.Extensions.OpenAI/Assistants/{ChatMessage.cs => AssistantMessage.cs} (86%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7139b41f..2049a173 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,13 +13,16 @@ Starting v0.1.0 for Microsoft.Azure.WebJobs.Extensions.OpenAI.AzureAISearch, it ### Breaking -- Model properties named to ChatModel and EmbeddingsModel in related bindings -- Managed identity support through config section and binding parameter AIConnectionName. +- Renamed model properties to `chatModel` and `embeddingsModel` in AssistantPost, Embeddings and TextCompletion bindings. +- Renamed connectionName to `searchConnectionName` in SemanticSearch binding. +- Renamed connectionName to `storeConnectionName` in EmbeddingsStore binding. +- Renamed ChatMessage entity to AssistantMessage.
+- Managed identity support through config section and binding parameter `aiConnectionName` in AssistantPost, Embeddings, EmbeddingsStore, SemanticSearch and TextCompletion bindings. ### Changed - Updated Azure.AI.OpenAI from 1.0.0-beta.15 to 2.1.0 -- Updated Azure.Data.Tables from 12.9.1 to 12.10.0, Azure.Identity from 1.12.1 to 1.13.2, Microsoft.Extensions.Azure from 1.7.5 to 1.8.0 +- Updated Azure.Data.Tables from 12.9.1 to 12.10.0, Azure.Identity from 1.12.1 to 1.13.2, Microsoft.Extensions.Azure from 1.7.5 to 1.10.0 ## v0.18.0 - 2024/10/08 diff --git a/java-library/CHANGELOG.md b/java-library/CHANGELOG.md index 8e928b10..77a7e203 100644 --- a/java-library/CHANGELOG.md +++ b/java-library/CHANGELOG.md @@ -7,9 +7,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## v0.5.0 - Unreleased +### Breaking + +- model properties renamed to `chatModel` and `embeddingsModel` in assistantPost, embeddings and textCompletion bindings. +- renamed connectionName to `searchConnectionName` in semanticSearch binding. +- renamed connectionName to `storeConnectionName` in embeddingsStore binding. +- renamed ChatMessage to `AssistantMessage`. +- managed identity support through config section and binding parameter `aiConnectionName` in assistantPost, embeddings, embeddingsStore, semanticSearch and textCompletion bindings. 
+ ### Changed -- Update azure-ai-openai from 1.0.0-beta.11 to 1.0.0-beta.14 +- Update azure-ai-openai from 1.0.0-beta.11 to 1.0.0-beta.16 ## v0.4.0 - 2024/10/08 diff --git a/java-library/pom.xml b/java-library/pom.xml index bf26b3d7..221bf968 100644 --- a/java-library/pom.xml +++ b/java-library/pom.xml @@ -88,7 +88,7 @@ com.azure azure-ai-openai - 1.0.0-beta.14 + 1.0.0-beta.16 compile diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/ChatMessage.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantMessage.java similarity index 78% rename from java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/ChatMessage.java rename to java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantMessage.java index 9e1dfbcb..32b8281f 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/ChatMessage.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantMessage.java @@ -12,31 +12,20 @@ * the chat agent, and the name of the calling function if applicable. *

*/ -public class ChatMessage { +public class AssistantMessage { private String content; private String role; private String toolCalls; /** - * Initializes a new instance of the ChatMessage class. - * - * @param content The content of the message. - * @param role The role of the chat agent. - */ - public ChatMessage(String content, String role) { - this.content = content; - this.role = role; - } - - /** - * Initializes a new instance of the ChatMessage class. + * Initializes a new instance of the AssistantMessage class. * * @param content The content of the message. * @param role The role of the chat agent. * @param toolCalls The toolCalls of the calling function if applicable. */ - public ChatMessage(String content, String role, String toolCalls) { + public AssistantMessage(String content, String role, String toolCalls) { this.content = content; this.role = role; this.toolCalls = toolCalls; @@ -90,7 +79,7 @@ public String getToolCalls() { /** * Sets the toolCalls of the calling function if applicable. * - * @param name The toolCalls of the calling function if applicable. + * @param toolCalls The toolCalls of the calling function if applicable. 
*/ public void setToolCalls(String toolCalls) { this.toolCalls = toolCalls; diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantState.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantState.java index 4ebcb0ea..0f88cfc1 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantState.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantState.java @@ -20,11 +20,11 @@ public class AssistantState { private String lastUpdatedAt; private int totalMessages; private int totalTokens; - private List recentMessages; + private List recentMessages; public AssistantState(String id, boolean exists, - String createdAt, String lastUpdatedAt, - int totalMessages, int totalTokens, List recentMessages) { + String createdAt, String lastUpdatedAt, + int totalMessages, int totalTokens, List recentMessages) { this.id = id; this.exists = exists; this.createdAt = createdAt; @@ -56,7 +56,7 @@ public boolean isExists() { * Gets timestamp of when assistant is created. * * @return The timestamp of when assistant is created. - */ + */ public String getCreatedAt() { return createdAt; } @@ -93,7 +93,7 @@ public int getTotalTokens() { * * @return A list of the recent messages from the assistant. */ - public List getRecentMessages() { + public List getRecentMessages() { return recentMessages; } @@ -156,8 +156,8 @@ public void setTotalTokens(int totalTokens) { * * @param recentMessages A list of the recent messages from the assistant. 
*/ - public void setRecentMessages(List recentMessages) { + public void setRecentMessages(List recentMessages) { this.recentMessages = recentMessages; } - + } diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java index 1627b380..20a348e9 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java @@ -7,20 +7,20 @@ package com.microsoft.azure.functions.openai.annotation.embeddings; import com.azure.ai.openai.models.Embeddings; -import com.azure.ai.openai.models.EmbeddingsOptions; +import java.util.List; public class EmbeddingsContext { - private EmbeddingsOptions request; + private List input; private Embeddings response; private int count = 0; - public EmbeddingsOptions getRequest() { - return request; + public List getInput() { + return input; } - public void setRequest(EmbeddingsOptions request) { - this.request = request; + public void setInput(List input) { + this.input = input; } public Embeddings getResponse() { diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantMessage.cs similarity index 86% rename from src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs rename to src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantMessage.cs index 73595e69..cdad6ef4 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantMessage.cs @@ -8,15 +8,15 @@ namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Assistants; /// /// Chat Message Entity which contains the content of the message, the role of the chat agent, and the name 
of the calling function if applicable. /// -public class ChatMessage +public class AssistantMessage { /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The content of the message. /// The role of the chat agent. /// The tool calls. - public ChatMessage(string content, string role, string toolCalls) + public AssistantMessage(string content, string role, string toolCalls) { this.Content = content; this.Role = role; diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantState.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantState.cs index 1ef18d72..7570ba87 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantState.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantState.cs @@ -50,5 +50,5 @@ public class AssistantState /// Gets a list of the recent messages from the assistant. /// [JsonPropertyName("recentMessages")] - public IReadOnlyList RecentMessages { get; set; } = Array.Empty(); + public IReadOnlyList RecentMessages { get; set; } = Array.Empty(); } From e94e84f224c4438490c9e9175f49d3e94e6fece8 Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Mon, 14 Apr 2025 00:47:26 -0700 Subject: [PATCH 12/21] remove constructor --- README.md | 2 +- samples/assistant/java/extensions.csproj | 2 +- samples/chat/java/extensions.csproj | 2 +- .../chat/powershell/PostUserResponse/function.json | 2 +- samples/chat/powershell/extensions.csproj | 2 +- samples/embeddings/java/extensions.csproj | 2 +- samples/embeddings/powershell/extensions.csproj | 2 +- samples/rag-aisearch/java/extensions.csproj | 2 +- samples/rag-aisearch/powershell/extensions.csproj | 2 +- samples/rag-cosmosdb/java/extensions.csproj | 2 +- samples/rag-cosmosdb/powershell/extensions.csproj | 2 +- samples/rag-cosmosdb/python/extensions.csproj | 2 +- samples/rag-kusto/java/extensions.csproj | 2 +- samples/rag-kusto/powershell/extensions.csproj | 2 +- 
samples/rag-kusto/python/extensions.csproj | 2 +- samples/textcompletion/java/extensions.csproj | 2 +- samples/textcompletion/powershell/extensions.csproj | 2 +- .../Assistants/AssistantPostInputAttribute.cs | 6 ++---- .../Assistants/AssistantSkillTriggerAttribute.cs | 8 -------- .../Embeddings/EmbeddingsInputAttribute.cs | 6 ++---- .../Embeddings/EmbeddingsStoreOutputAttribute.cs | 6 ++---- .../Search/SemanticSearchInputAttribute.cs | 6 ++---- .../TextCompletion/TextCompletionInputAttribute.cs | 6 ++---- .../Assistants/AssistantBaseAttribute.cs | 11 +---------- .../Assistants/AssistantPostAttribute.cs | 3 +-- .../Assistants/AssistantSkillTriggerAttribute.cs | 8 -------- .../Embeddings/EmbeddingsAttribute.cs | 3 +-- .../Embeddings/EmbeddingsBaseAttribute.cs | 6 ++---- .../Embeddings/EmbeddingsStoreAttribute.cs | 3 +-- .../Models/AssistantMessage.cs | 2 +- .../Search/SemanticSearchAttribute.cs | 3 +-- .../TextCompletionAttribute.cs | 3 +-- 32 files changed, 36 insertions(+), 78 deletions(-) diff --git a/README.md b/README.md index 55d0440c..aaa3fcdf 100644 --- a/README.md +++ b/README.md @@ -93,7 +93,7 @@ Function usage example: public static IActionResult PostUserResponse( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "chats/{chatId}")] HttpRequestData req, string chatId, - [AssistantPostInput("{chatId}", "{Query.message}", "AzureOpenAI", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state) + [AssistantPostInput("{chatId}", "{Query.message}", AIConnectionName = "AzureOpenAI", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state) { return new OkObjectResult(state.RecentMessages.LastOrDefault()?.Content ?? 
"No response returned."); } diff --git a/samples/assistant/java/extensions.csproj b/samples/assistant/java/extensions.csproj index eb243f13..ad5a93c6 100644 --- a/samples/assistant/java/extensions.csproj +++ b/samples/assistant/java/extensions.csproj @@ -1,6 +1,6 @@  - net60 + net80 ** target/azure-functions/azfs-java-openai-sample/bin diff --git a/samples/chat/java/extensions.csproj b/samples/chat/java/extensions.csproj index eb243f13..ad5a93c6 100644 --- a/samples/chat/java/extensions.csproj +++ b/samples/chat/java/extensions.csproj @@ -1,6 +1,6 @@  - net60 + net80 ** target/azure-functions/azfs-java-openai-sample/bin diff --git a/samples/chat/powershell/PostUserResponse/function.json b/samples/chat/powershell/PostUserResponse/function.json index 16142e20..46a0ba09 100644 --- a/samples/chat/powershell/PostUserResponse/function.json +++ b/samples/chat/powershell/PostUserResponse/function.json @@ -20,7 +20,7 @@ "direction": "in", "name": "ChatBotState", "id": "{chatId}", - "model": "%CHAT_MODEL_DEPLOYMENT_NAME%", + "chatModel": "%CHAT_MODEL_DEPLOYMENT_NAME%", "userMessage": "{Query.message}", "chatStorageConnectionSetting": "AzureWebJobsStorage", "collectionName": "ChatState" diff --git a/samples/chat/powershell/extensions.csproj b/samples/chat/powershell/extensions.csproj index 8f73abd3..82ae1923 100644 --- a/samples/chat/powershell/extensions.csproj +++ b/samples/chat/powershell/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/embeddings/java/extensions.csproj b/samples/embeddings/java/extensions.csproj index eb243f13..ad5a93c6 100644 --- a/samples/embeddings/java/extensions.csproj +++ b/samples/embeddings/java/extensions.csproj @@ -1,6 +1,6 @@  - net60 + net80 ** target/azure-functions/azfs-java-openai-sample/bin diff --git a/samples/embeddings/powershell/extensions.csproj b/samples/embeddings/powershell/extensions.csproj index 8f73abd3..82ae1923 100644 --- a/samples/embeddings/powershell/extensions.csproj +++ 
b/samples/embeddings/powershell/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/rag-aisearch/java/extensions.csproj b/samples/rag-aisearch/java/extensions.csproj index fcbb0cac..29683eae 100644 --- a/samples/rag-aisearch/java/extensions.csproj +++ b/samples/rag-aisearch/java/extensions.csproj @@ -1,6 +1,6 @@  - net60 + net80 ** target/azure-functions/azfs-java-openai-sample/bin diff --git a/samples/rag-aisearch/powershell/extensions.csproj b/samples/rag-aisearch/powershell/extensions.csproj index 68c66ba4..1c321726 100644 --- a/samples/rag-aisearch/powershell/extensions.csproj +++ b/samples/rag-aisearch/powershell/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/rag-cosmosdb/java/extensions.csproj b/samples/rag-cosmosdb/java/extensions.csproj index 010fdc9a..2dbe5f2b 100644 --- a/samples/rag-cosmosdb/java/extensions.csproj +++ b/samples/rag-cosmosdb/java/extensions.csproj @@ -1,6 +1,6 @@  - net60 + net80 ** target/azure-functions/azfs-java-openai-sample/bin diff --git a/samples/rag-cosmosdb/powershell/extensions.csproj b/samples/rag-cosmosdb/powershell/extensions.csproj index bf9f1260..02592700 100644 --- a/samples/rag-cosmosdb/powershell/extensions.csproj +++ b/samples/rag-cosmosdb/powershell/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/rag-cosmosdb/python/extensions.csproj b/samples/rag-cosmosdb/python/extensions.csproj index 61aaed1f..e67d7e72 100644 --- a/samples/rag-cosmosdb/python/extensions.csproj +++ b/samples/rag-cosmosdb/python/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/rag-kusto/java/extensions.csproj b/samples/rag-kusto/java/extensions.csproj index e2323ac5..3fded17b 100644 --- a/samples/rag-kusto/java/extensions.csproj +++ b/samples/rag-kusto/java/extensions.csproj @@ -1,6 +1,6 @@  - net60 + net80 ** target/azure-functions/azfs-java-openai-sample/bin diff --git a/samples/rag-kusto/powershell/extensions.csproj 
b/samples/rag-kusto/powershell/extensions.csproj index c8f3ee70..08c2647b 100644 --- a/samples/rag-kusto/powershell/extensions.csproj +++ b/samples/rag-kusto/powershell/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/rag-kusto/python/extensions.csproj b/samples/rag-kusto/python/extensions.csproj index 1346f6ff..05d7cdf4 100644 --- a/samples/rag-kusto/python/extensions.csproj +++ b/samples/rag-kusto/python/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/textcompletion/java/extensions.csproj b/samples/textcompletion/java/extensions.csproj index eb243f13..ad5a93c6 100644 --- a/samples/textcompletion/java/extensions.csproj +++ b/samples/textcompletion/java/extensions.csproj @@ -1,6 +1,6 @@  - net60 + net80 ** target/azure-functions/azfs-java-openai-sample/bin diff --git a/samples/textcompletion/powershell/extensions.csproj b/samples/textcompletion/powershell/extensions.csproj index 8f73abd3..82ae1923 100644 --- a/samples/textcompletion/powershell/extensions.csproj +++ b/samples/textcompletion/powershell/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs index cf6f8179..e102a0f5 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs @@ -15,12 +15,10 @@ public sealed class AssistantPostInputAttribute : InputBindingAttribute /// /// The assistant identifier. /// The user message. - /// The name of the configuration section for AI service connectivity settings. 
- public AssistantPostInputAttribute(string id, string userMessage, string aiConnectionName = "") + public AssistantPostInputAttribute(string id, string userMessage) { this.Id = id; this.UserMessage = userMessage; - this.AIConnectionName = aiConnectionName; } /// @@ -38,7 +36,7 @@ public AssistantPostInputAttribute(string id, string userMessage, string aiConne /// For OpenAI: /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. /// - public string AIConnectionName { get; set; } + public string AIConnectionName { get; set; } = ""; /// /// Gets the ID of the assistant to update. diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs index 6c6787ec..410eb73b 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs @@ -37,12 +37,4 @@ public AssistantSkillTriggerAttribute(string functionDescription) /// https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools. /// public string? ParameterDescriptionJson { get; set; } - - /// - /// Gets or sets the OpenAI chat model to use. - /// - /// - /// When using Azure OpenAI, then should be the name of the model deployment. 
- /// - public string Model { get; set; } = OpenAIModels.DefaultChatModel; } \ No newline at end of file diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs index 5c2465f1..a5943c6e 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs @@ -12,13 +12,11 @@ public class EmbeddingsInputAttribute : InputBindingAttribute /// /// The input source containing the data to generate embeddings for. /// The type of the input. - /// The name of the configuration section for AI service connectivity settings. /// Thrown if is null. - public EmbeddingsInputAttribute(string input, InputType inputType, string aiConnectionName = "") + public EmbeddingsInputAttribute(string input, InputType inputType) { this.Input = input ?? throw new ArgumentNullException(nameof(input)); this.InputType = inputType; - this.AIConnectionName = aiConnectionName; } /// @@ -36,7 +34,7 @@ public EmbeddingsInputAttribute(string input, InputType inputType, string aiConn /// For OpenAI: /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. /// - public string AIConnectionName { get; set; } + public string AIConnectionName { get; set; } = ""; /// /// Gets or sets the ID of the model to use. 
diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs index bbcaf326..9adde53b 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs @@ -18,17 +18,15 @@ public class EmbeddingsStoreOutputAttribute : OutputBindingAttribute /// The name of an app setting or environment variable which contains a connection string value for embedding store. /// /// The name of the collection or table to search or store. - /// The name of the configuration section for AI service connectivity settings. /// /// Thrown if or or are null. /// - public EmbeddingsStoreOutputAttribute(string input, InputType inputType, string storeConnectionName, string collection, string aiConnectionName = "") + public EmbeddingsStoreOutputAttribute(string input, InputType inputType, string storeConnectionName, string collection) { this.Input = input ?? throw new ArgumentNullException(nameof(input)); this.InputType = inputType; this.StoreConnectionName = storeConnectionName ?? throw new ArgumentNullException(nameof(storeConnectionName)); this.Collection = collection ?? throw new ArgumentNullException(nameof(collection)); - this.AIConnectionName = aiConnectionName; } /// @@ -46,7 +44,7 @@ public EmbeddingsStoreOutputAttribute(string input, InputType inputType, string /// For OpenAI: /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. /// - public string AIConnectionName { get; set; } + public string AIConnectionName { get; set; } = ""; /// /// Gets or sets the ID of the model to use. 
diff --git a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs index a4a50f61..560e420a 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs @@ -18,15 +18,13 @@ public sealed class SemanticSearchInputAttribute : InputBindingAttribute /// The name of an app setting or environment variable which contains a connection string value. /// /// The name of the collection or table to search or store. - /// The name of the configuration section for AI service connectivity settings. /// /// Thrown if either or are null. /// - public SemanticSearchInputAttribute(string searchConnectionName, string collection, string aiConnectionName = "") + public SemanticSearchInputAttribute(string searchConnectionName, string collection) { this.SearchConnectionName = searchConnectionName ?? throw new ArgumentNullException(nameof(searchConnectionName)); this.Collection = collection ?? throw new ArgumentNullException(nameof(collection)); - this.AIConnectionName = aiConnectionName; } /// @@ -44,7 +42,7 @@ public SemanticSearchInputAttribute(string searchConnectionName, string collecti /// For OpenAI: /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. /// - public string AIConnectionName { get; set; } + public string AIConnectionName { get; set; } = ""; /// /// Gets or sets the name of an app setting or environment variable which contains a connection string value. 
diff --git a/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs index d7c49268..a9538709 100644 --- a/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs @@ -14,11 +14,9 @@ public sealed class TextCompletionInputAttribute : InputBindingAttribute /// Initializes a new instance of the class with the specified text prompt. /// /// The prompt to generate completions for, encoded as a string. - /// The name of the configuration section for AI service connectivity settings. - public TextCompletionInputAttribute(string prompt, string aiConnectionName = "") + public TextCompletionInputAttribute(string prompt) { this.Prompt = prompt ?? throw new ArgumentNullException(nameof(prompt)); - this.AIConnectionName = aiConnectionName; } /// @@ -36,7 +34,7 @@ public TextCompletionInputAttribute(string prompt, string aiConnectionName = "") /// For OpenAI: /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. /// - public string AIConnectionName { get; set; } + public string AIConnectionName { get; set; } = ""; /// /// Gets or sets the prompt to generate completions for, encoded as a string. diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs index 4c08ea38..9dea44de 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs @@ -15,15 +15,6 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; [AttributeUsage(AttributeTargets.Parameter)] public class AssistantBaseAttribute : Attribute { - /// - /// Initializes a new instance of the class. 
- /// - /// The name of the configuration section for AI service connectivity settings. - public AssistantBaseAttribute(string aiConnectionName = "") - { - this.AIConnectionName = aiConnectionName; - } - /// /// Gets or sets the name of the Large Language Model to invoke for chat responses. /// The default value is "gpt-3.5-turbo". @@ -49,7 +40,7 @@ public AssistantBaseAttribute(string aiConnectionName = "") /// For OpenAI: /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. /// - public string AIConnectionName { get; set; } + public string AIConnectionName { get; set; } = ""; /// /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs index c2b42734..9e3c5faf 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs @@ -14,8 +14,7 @@ public sealed class AssistantPostAttribute : AssistantBaseAttribute /// /// The assistant identifier. /// The user message. - /// The name of the configuration section for AI service connectivity settings. 
- public AssistantPostAttribute(string id, string userMessage, string aiConnectionName = "") : base(aiConnectionName) + public AssistantPostAttribute(string id, string userMessage) { this.Id = id; this.UserMessage = userMessage; diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs index 749e868b..25209b4c 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs @@ -42,12 +42,4 @@ public AssistantSkillTriggerAttribute(string functionDescription) /// https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools. /// public string? ParameterDescriptionJson { get; set; } - - /// - /// Gets or sets the OpenAI chat model to use. - /// - /// - /// When using Azure OpenAI, then should be the name of the model deployment. - /// - public string Model { get; set; } = "gpt-3.5-turbo"; } \ No newline at end of file diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsAttribute.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsAttribute.cs index 8c8c733b..bd15b6ae 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsAttribute.cs @@ -21,9 +21,8 @@ public sealed class EmbeddingsAttribute : EmbeddingsBaseAttribute /// /// The input source containing the data to generate embeddings for. /// The type of the input. - /// The name of the configuration section for AI service connectivity settings. /// Thrown if is null. 
- public EmbeddingsAttribute(string input, InputType inputType, string aiConnectionName = "") : base(input, inputType, aiConnectionName) + public EmbeddingsAttribute(string input, InputType inputType) : base(input, inputType) { } } diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsBaseAttribute.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsBaseAttribute.cs index 25f516b1..5ecac174 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsBaseAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsBaseAttribute.cs @@ -22,15 +22,13 @@ public class EmbeddingsBaseAttribute : Attribute /// /// The input source containing the data to generate embeddings for. /// The type of the input. - /// The name of the configuration section for AI service connectivity settings. /// Thrown if is null. - public EmbeddingsBaseAttribute(string input, InputType inputType, string aiConnectionName = "") + public EmbeddingsBaseAttribute(string input, InputType inputType) { this.Input = string.IsNullOrEmpty(input) ? throw new ArgumentException("Input cannot be null or empty.", nameof(input)) : input; this.InputType = inputType; - this.AIConnectionName = aiConnectionName; } /// @@ -48,7 +46,7 @@ public EmbeddingsBaseAttribute(string input, InputType inputType, string aiConne /// For OpenAI: /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. /// - public string AIConnectionName { get; set; } + public string AIConnectionName { get; set; } = ""; /// /// Gets or sets the ID of the model to use. 
diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreAttribute.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreAttribute.cs index def0ddb2..29ec64d3 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreAttribute.cs @@ -20,11 +20,10 @@ public sealed class EmbeddingsStoreAttribute : EmbeddingsBaseAttribute /// The name of an app setting or environment variable which contains a connection string value. /// /// The name of the collection or table to search or store. - /// The name of the configuration section for AI service connectivity settings. /// /// Thrown if or or are null. /// - public EmbeddingsStoreAttribute(string input, InputType inputType, string storeConnectionName, string collection, string aiConnectionName = "") : base(input, inputType, aiConnectionName) + public EmbeddingsStoreAttribute(string input, InputType inputType, string storeConnectionName, string collection) : base(input, inputType) { this.StoreConnectionName = storeConnectionName ?? throw new ArgumentNullException(nameof(storeConnectionName)); this.Collection = collection ?? throw new ArgumentNullException(nameof(collection)); diff --git a/src/WebJobs.Extensions.OpenAI/Models/AssistantMessage.cs b/src/WebJobs.Extensions.OpenAI/Models/AssistantMessage.cs index a6166d1f..68bdc20b 100644 --- a/src/WebJobs.Extensions.OpenAI/Models/AssistantMessage.cs +++ b/src/WebJobs.Extensions.OpenAI/Models/AssistantMessage.cs @@ -40,5 +40,5 @@ public AssistantMessage(string content, string role, string toolCalls) /// Gets or sets the tool calls. 
/// [JsonProperty("toolCalls")] - public string ToolCalls { get; } + public string ToolCalls { get; set; } } diff --git a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchAttribute.cs b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchAttribute.cs index 3edbc5f1..dce23d5d 100644 --- a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchAttribute.cs @@ -22,11 +22,10 @@ public sealed class SemanticSearchAttribute : AssistantBaseAttribute /// The name of an app setting or environment variable which contains a connection string value of search provider. /// /// The name of the collection or table to search or store. - /// The name of the configuration section for AI service connectivity settings. /// /// Thrown if either or are null. /// - public SemanticSearchAttribute(string searchConnectionName, string collection, string aiConnectionName = "") : base(aiConnectionName) + public SemanticSearchAttribute(string searchConnectionName, string collection) { this.SearchConnectionName = searchConnectionName ?? throw new ArgumentNullException(nameof(searchConnectionName)); this.Collection = collection ?? throw new ArgumentNullException(nameof(collection)); diff --git a/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs b/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs index aac4242e..ab3f2944 100644 --- a/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs @@ -17,8 +17,7 @@ public sealed class TextCompletionAttribute : AssistantBaseAttribute /// Initializes a new instance of the class with the specified text prompt. /// /// The prompt to generate completions for, encoded as a string. - /// The name of the configuration section for AI service connectivity settings. 
- public TextCompletionAttribute(string prompt, string aiConnectionName = "") : base(aiConnectionName) + public TextCompletionAttribute(string prompt) { this.Prompt = string.IsNullOrEmpty(prompt) ? throw new ArgumentException("Input cannot be null or empty.", nameof(prompt)) From 9166d72e326d7de45c991eaf5849f1586d3c862a Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Mon, 14 Apr 2025 14:10:52 -0700 Subject: [PATCH 13/21] update changelog files --- .../annotation/assistant/AssistantSkillTrigger.java | 9 --------- src/Directory.Build.props | 2 +- src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md | 4 ++++ .../CHANGELOG.md | 5 +++++ src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md | 6 +++++- 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java index ec6b29d8..36f67dbb 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java @@ -56,13 +56,4 @@ * @return The JSON description of the function parameter. */ String parameterDescriptionJson() default ""; - - /** - * The OpenAI chat model to use. - * When using Azure OpenAI, this should be the name of the model deployment. - * - * @return The OpenAI chat model to use. 
- */ - String chatModel() default "gpt-3.5-turbo"; - } diff --git a/src/Directory.Build.props b/src/Directory.Build.props index 91466cb7..9f4ab3ef 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -33,7 +33,7 @@ $(VersionPrefix).$(FileVersionRevision) - 0.16.0-alpha + 0.17.0-alpha 0.4.0-alpha diff --git a/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md b/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md index f53671d3..5e1f2299 100644 --- a/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md +++ b/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Managed identity support and consistency established with other Azure Functions extensions +### Changed + +- Updated Microsoft.Azure.WebJobs.Extensions.OpenAI to 0.19.0 + ## v0.3.0 - 2024/10/08 ### Changed diff --git a/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CHANGELOG.md b/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CHANGELOG.md index bdef6010..187ff9f1 100644 --- a/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CHANGELOG.md +++ b/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CHANGELOG.md @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added DiskANN support for CosmosDB (MongoDB) Search Provider. Refer [README](../../samples/rag-cosmosdb/README.md) for more information on usage. +### Changed + +- Updated Microsoft.Azure.WebJobs.Extensions.OpenAI to 0.19.0 + ## v0.3.0 - 2024/10/08 ### Added @@ -18,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added HNSW support for CosmosDB (MongoDB) Search Provider. Refer [README](../../samples/rag-cosmosdb/README.md) for more information on usage. 
### Changed + - Updated nuget dependencies ## v0.2.0 - 2024/05/06 diff --git a/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md b/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md index 51430935..7bf7f1be 100644 --- a/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md +++ b/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md @@ -3,7 +3,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).\ + +## v0.17.0 - Unreleased + +- Updated Microsoft.Azure.WebJobs.Extensions.OpenAI to 0.19.0 ## v0.16.0 - 2024/10/08 From e1b0689d50735c7a77b56fae0a813d319b4d644b Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Wed, 16 Apr 2025 15:02:21 -0700 Subject: [PATCH 14/21] remove reference to Microsoft.Azure.WebJobs.Script.Abstractions --- .../Assistants/BuiltInFunctionsProvider.cs | 71 ------------------- .../WebJobs.Extensions.OpenAI.csproj | 1 - 2 files changed, 72 deletions(-) delete mode 100644 src/WebJobs.Extensions.OpenAI/Assistants/BuiltInFunctionsProvider.cs diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/BuiltInFunctionsProvider.cs b/src/WebJobs.Extensions.OpenAI/Assistants/BuiltInFunctionsProvider.cs deleted file mode 100644 index db996693..00000000 --- a/src/WebJobs.Extensions.OpenAI/Assistants/BuiltInFunctionsProvider.cs +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -using System.Collections.Immutable; -using System.Reflection; -using Microsoft.Azure.WebJobs.Script.Description; -using Newtonsoft.Json.Linq; - -namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; - -/// -/// Class that defines all the built-in functions for executing CNCF Serverless Workflows. 
-/// IMPORTANT: Renaming methods in this class is a breaking change! -/// -class BuiltInFunctionsProvider : IFunctionProvider -{ - /// - Task> IFunctionProvider.GetFunctionMetadataAsync() => - Task.FromResult(this.GetFunctionMetadata().ToImmutableArray()); - - // TODO: Not sure what this is for... - /// - ImmutableDictionary> IFunctionProvider.FunctionErrors => - new Dictionary>().ToImmutableDictionary(); - - - internal static string GetBuiltInFunctionName(string functionName) - { - return $"OpenAI::{functionName}"; - } - - /// - /// Returns an enumeration of all the function triggers defined in this class. - /// - IEnumerable GetFunctionMetadata() - { - foreach (MethodInfo method in this.GetType().GetMethods()) - { - if (method.GetCustomAttribute() is not FunctionNameAttribute) - { - // Not an Azure Function definition - continue; - } - - FunctionMetadata metadata = new() - { - // NOTE: We always use the method name and ignore the function name - Name = GetBuiltInFunctionName(method.Name), - ScriptFile = $"assembly:{method.ReflectedType.Assembly.FullName}", - EntryPoint = $"{method.ReflectedType.FullName}.{method.Name}", - Language = "DotNetAssembly", - }; - - // Scan the parameters for binding attributes and add them to the bindings collection - // so that we can register them with the Functions runtime. - foreach (ParameterInfo parameter in method.GetParameters()) - { - if (parameter.GetCustomAttribute() is not null) - { - // NOTE: We assume each OpenAI service function in this file defines the parameter name as "service". 
- metadata.Bindings.Add(BindingMetadata.Create(new JObject( - new JProperty("type", "openAIService"), - new JProperty("name", "service"), - new JProperty("direction", "in")))); - } - } - - yield return metadata; - } - } -} diff --git a/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj b/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj index c7a17757..efa2f556 100644 --- a/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj +++ b/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj @@ -10,7 +10,6 @@ - From 33ab08cc845a2429a0a1039c81bec3a74bee136b Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Thu, 24 Apr 2025 15:38:02 -0700 Subject: [PATCH 15/21] address review comments --- .../assistant/AssistantMessage.java | 3 +- .../Assistants/ChatCompletionJsonConverter.cs | 8 +++-- .../Embeddings/EmbeddingsConverter.cs | 18 +++------- .../Embeddings/EmbeddingsHelper.cs | 34 +++++++++++++++---- .../Embeddings/EmbeddingsStoreConverter.cs | 14 ++------ 5 files changed, 42 insertions(+), 35 deletions(-) diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantMessage.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantMessage.java index 32b8281f..11627eda 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantMessage.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantMessage.java @@ -8,8 +8,7 @@ /** *

- * Chat Message Entity which contains the content of the message, the role of - * the chat agent, and the name of the calling function if applicable. + * Chat Message Entity which contains the content of the message, the role of the chat agent, and the name of the calling function if applicable. *

*/ public class AssistantMessage { diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs index cf471922..f8e8e676 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs @@ -10,10 +10,14 @@ namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Assistants; public class ChatCompletionJsonConverter : JsonConverter { static readonly ModelReaderWriterOptions modelReaderWriterOptions = new("J"); - public override ChatCompletion Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override ChatCompletion Read( + ref Utf8JsonReader reader, + Type typeToConvert, + JsonSerializerOptions options) { using JsonDocument jsonDocument = JsonDocument.ParseValue(ref reader); - return ModelReaderWriter.Read(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!; + return ModelReaderWriter.Read( + BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!; } public override void Write(Utf8JsonWriter writer, ChatCompletion value, JsonSerializerOptions options) diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs index df7d1d87..364f552b 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs @@ -1,11 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
-using System.ClientModel; using System.Text.Json; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; using Microsoft.Extensions.Logging; -using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; @@ -49,16 +47,10 @@ async Task ConvertCoreAsync( EmbeddingsAttribute attribute, CancellationToken cancellationToken) { - List input = await EmbeddingsHelper.BuildRequest(attribute.MaxOverlap, - attribute.MaxChunkLength, - attribute.InputType, - attribute.Input); - this.logger.LogInformation("Sending OpenAI embeddings request"); - ClientResult response = await this.openAIClientFactory.GetEmbeddingClient( - attribute.AIConnectionName, - attribute.EmbeddingsModel).GenerateEmbeddingsAsync(input, cancellationToken: cancellationToken); - this.logger.LogInformation("Received OpenAI embeddings count: {response}", response.Value.Count); - - return new EmbeddingsContext(input, response); + return await EmbeddingsHelper.GenerateEmbeddingsAsync( + attribute, + this.openAIClientFactory, + this.logger, + cancellationToken); } } diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs index a8cb6166..0b26aac6 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs @@ -1,7 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
+using System.ClientModel; using System.Diagnostics; +using Microsoft.Extensions.Logging; +using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; static class EmbeddingsHelper @@ -16,15 +19,34 @@ static EmbeddingsHelper() httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(UserAgent); } - public static async Task> BuildRequest(int maxOverlap, int maxChunkLength, InputType inputType, string input) + internal static async Task GenerateEmbeddingsAsync( + EmbeddingsBaseAttribute attribute, + OpenAIClientFactory openAIClientFactory, + ILogger logger, + CancellationToken cancellationToken = default) { - using TextReader reader = await GetTextReader(inputType, input); - if (maxOverlap >= maxChunkLength) + List chunks = await BuildRequest(attribute); + + logger.LogInformation("Sending OpenAI embeddings request"); + + ClientResult response = await openAIClientFactory.GetEmbeddingClient( + attribute.AIConnectionName, + attribute.EmbeddingsModel).GenerateEmbeddingsAsync(chunks, cancellationToken: cancellationToken); + + logger.LogInformation("Received OpenAI embeddings count: {count}", response.Value.Count); + + return new EmbeddingsContext(chunks, response); + } + + static async Task> BuildRequest(EmbeddingsBaseAttribute attribute) + { + using TextReader reader = await GetTextReader(attribute.InputType, attribute.Input); + if (attribute.MaxOverlap >= attribute.MaxChunkLength) { - throw new ArgumentOutOfRangeException($"MaxOverlap ({maxOverlap}) must be less than MaxChunkLength ({maxChunkLength})."); + throw new ArgumentOutOfRangeException($"MaxOverlap ({attribute.MaxOverlap}) must be less than MaxChunkLength ({attribute.MaxChunkLength})."); } - List chunks = GetTextChunks(reader, 0, maxChunkLength, maxOverlap).ToList(); + List chunks = GetTextChunks(reader, 0, attribute.MaxChunkLength, attribute.MaxOverlap).ToList(); return chunks; } @@ -40,7 +62,7 @@ static async Task GetTextReader(InputType inputType, string input) } else if (inputType == 
InputType.Url) { - if (!Uri.TryCreate(input, UriKind.Absolute, out Uri? uriResult) || + if (!Uri.TryCreate(input, UriKind.Absolute, out Uri? uriResult) || uriResult.Scheme != Uri.UriSchemeHttps) { throw new ArgumentException($"Invalid Url: {input}. Ensure it is a valid https Url."); diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs index 7cd02216..73be52a4 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs @@ -1,12 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using System.ClientModel; using System.Text.Json; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; class EmbeddingsStoreConverter : @@ -89,16 +87,8 @@ public async Task AddAsync(SearchableDocument item, CancellationToken cancellati } // Get embeddings from OpenAI - List input = await EmbeddingsHelper.BuildRequest(this.attribute.MaxOverlap, - this.attribute.MaxChunkLength, - this.attribute.InputType, - this.attribute.Input); - this.logger.LogInformation("Sending OpenAI embeddings request"); - ClientResult response = await this.openAIClientFactory.GetEmbeddingClient( - this.attribute.AIConnectionName, - this.attribute.EmbeddingsModel).GenerateEmbeddingsAsync(input, cancellationToken: cancellationToken); - EmbeddingsContext embeddingsContext = new(input, response); - this.logger.LogInformation("Received OpenAI embeddings of count: {count}", embeddingsContext.Count); + EmbeddingsContext embeddingsContext = await EmbeddingsHelper. 
+ GenerateEmbeddingsAsync(this.attribute, this.openAIClientFactory, this.logger, cancellationToken); // Add document to the embed store item.Embeddings = embeddingsContext; From 6e85da4ecb446864c8ada3a7383def10dd401792 Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Sun, 27 Apr 2025 14:37:35 -0700 Subject: [PATCH 16/21] Input to Request, internalsvisible in csproj --- .../annotation/embeddings/EmbeddingsContext.java | 13 +++++++------ .../Embeddings/EmbeddingsContext.cs | 6 +++--- .../Search/SearchableDocumentJsonConverter.cs | 4 ++-- .../AzureAISearchProvider.cs | 2 +- .../CosmosDBSearchProvider.cs | 2 +- src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md | 2 +- .../KustoSearchProvider.cs | 2 +- .../Embeddings/EmbeddingsContext.cs | 8 ++++---- .../Embeddings/EmbeddingsContextConverter.cs | 4 ++-- .../Properties/AssemblyInfo.cs | 7 ------- .../Search/SearchableDocumentJsonConverter.cs | 12 +++++++++--- .../WebJobs.Extensions.OpenAI.csproj | 4 ++++ 12 files changed, 35 insertions(+), 31 deletions(-) delete mode 100644 src/WebJobs.Extensions.OpenAI/Properties/AssemblyInfo.cs diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java index 20a348e9..3e04c8cf 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java @@ -11,16 +11,16 @@ public class EmbeddingsContext { - private List input; + private List request; private Embeddings response; private int count = 0; - public List getInput() { - return input; + public List getRequest() { + return request; } - public void setInput(List input) { - this.input = input; + public void setRequest(List request) { + this.request = request; } 
public Embeddings getResponse() { @@ -38,7 +38,8 @@ public void setResponse(Embeddings response) { */ public int getCount() { return this.response != null && this.response.getData() != null - ? this.count : 0; + ? this.count + : 0; } } diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs index 4218199d..2f7dfa45 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs @@ -7,16 +7,16 @@ namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; public class EmbeddingsContext { - public EmbeddingsContext(IList Input, OpenAIEmbeddingCollection? Response) + public EmbeddingsContext(IList Request, OpenAIEmbeddingCollection? Response) { - this.Input = Input; + this.Request = Request; this.Response = Response; } /// /// Embeddings request sent to OpenAI. /// - public IList Input { get; set; } + public IList Request { get; set; } /// /// Embeddings response from OpenAI. 
diff --git a/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs b/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs index ac8e25fa..5b6804e2 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs @@ -24,9 +24,9 @@ public override void Write(Utf8JsonWriter writer, SearchableDocument value, Json writer.WritePropertyName("embeddingsContext"u8); writer.WriteStartObject(); - if (value.EmbeddingsContext?.Input is List inputList) + if (value.EmbeddingsContext?.Request is List inputList) { - writer.WritePropertyName("input"u8); + writer.WritePropertyName("request"u8); var inputWrapper = JsonModelListWrapper.FromList(inputList); inputWrapper.Write(writer, modelReaderWriterOptions); } diff --git a/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs b/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs index 14abae5b..4c4d0134 100644 --- a/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs +++ b/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs @@ -253,7 +253,7 @@ async Task IndexSectionsAsync(SearchClient searchClient, SearchableDocument docu new SearchDocument { ["id"] = Guid.NewGuid().ToString("N"), - ["text"] = document.Embeddings.Input![i], + ["text"] = document.Embeddings.Request![i], ["title"] = Path.GetFileNameWithoutExtension(document.Title), ["embeddings"] = document.Embeddings.Response[i].ToFloats().ToArray() ?? 
Array.Empty(), ["timestamp"] = DateTime.UtcNow diff --git a/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs b/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs index 76f4d1ca..82ad69dc 100644 --- a/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs +++ b/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs @@ -220,7 +220,7 @@ async Task UpsertVectorAsync(MongoClient cosmosClient, SearchableDocument docume { "id", Guid.NewGuid().ToString("N") }, { this.cosmosDBSearchConfigOptions.Value.TextKey, - document.Embeddings.Input![i] + document.Embeddings.Request![i] }, { "title", Path.GetFileNameWithoutExtension(document.Title) }, { diff --git a/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md b/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md index 7bf7f1be..a7c67706 100644 --- a/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md +++ b/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md @@ -3,7 +3,7 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), -and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).\ +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
## v0.17.0 - Unreleased diff --git a/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs b/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs index 156868bb..fe46f501 100644 --- a/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs +++ b/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs @@ -83,7 +83,7 @@ public async Task AddDocumentAsync(SearchableDocument document, CancellationToke table.Rows.Add( Guid.NewGuid().ToString("N"), Path.GetFileNameWithoutExtension(document.Title), - document.Embeddings.Input![i], + document.Embeddings.Request![i], GetEmbeddingsString(document.Embeddings.Response[i].ToFloats().ToArray(), true), DateTime.UtcNow); } diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs index db182d4f..a890fb57 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs @@ -8,20 +8,20 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; /// /// Binding target for the . /// -/// The embeddings input that was sent to OpenAI. +/// The embeddings request that was sent to OpenAI. /// The embeddings response that was received from OpenAI. public class EmbeddingsContext { - public EmbeddingsContext(IList Input, OpenAIEmbeddingCollection? Response) + public EmbeddingsContext(IList Request, OpenAIEmbeddingCollection? Response) { - this.Input = Input; + this.Request = Request; this.Response = Response; } /// /// Embeddings request sent to OpenAI. /// - public IList Input { get; set; } + public IList Request { get; set; } /// /// Embeddings response from OpenAI. 
diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs index d848d22f..47c8807e 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs @@ -22,9 +22,9 @@ public override EmbeddingsContext Read(ref Utf8JsonReader reader, Type typeToCon public override void Write(Utf8JsonWriter writer, EmbeddingsContext value, JsonSerializerOptions options) { writer.WriteStartObject(); - writer.WritePropertyName("input"u8); + writer.WritePropertyName("request"u8); - if (value.Input is List inputList) + if (value.Request is List inputList) { var inputWrapper = JsonModelListWrapper.FromList(inputList); inputWrapper.Write(writer, modelReaderWriterOptions); diff --git a/src/WebJobs.Extensions.OpenAI/Properties/AssemblyInfo.cs b/src/WebJobs.Extensions.OpenAI/Properties/AssemblyInfo.cs deleted file mode 100644 index 5b5af38e..00000000 --- a/src/WebJobs.Extensions.OpenAI/Properties/AssemblyInfo.cs +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
- -using System.Runtime.CompilerServices; - -// Make internals visible to the test project -[assembly: InternalsVisibleTo("WebJobsOpenAIUnitTests")] \ No newline at end of file diff --git a/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs b/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs index 3dd4910b..d81919b5 100644 --- a/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs @@ -29,9 +29,14 @@ public override SearchableDocument Read(ref Utf8JsonReader reader, Type typeToCo { foreach (JsonProperty embeddingContextItem in item.Value.EnumerateObject()) { - if (embeddingContextItem.NameEquals("input"u8)) + if (embeddingContextItem.NameEquals("request"u8)) { - input = new List(); // ToDo: revisit + // Parse the array of string inputs + input = new List(); + foreach (JsonElement element in embeddingContextItem.Value.EnumerateArray()) + { + input.Add(element.GetString() ?? 
string.Empty); + } } if (embeddingContextItem.NameEquals("response"u8)) { @@ -78,8 +83,9 @@ public override void Write(Utf8JsonWriter writer, SearchableDocument value, Json writer.WritePropertyName("embeddingsContext"u8); writer.WriteStartObject(); - if (value.Embeddings?.Input is List inputList) + if (value.Embeddings?.Request is List inputList) { + writer.WritePropertyName("request"u8); var inputWrapper = JsonModelListWrapper.FromList(inputList); inputWrapper.Write(writer, modelReaderWriterOptions); } diff --git a/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj b/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj index efa2f556..0f4fbb48 100644 --- a/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj +++ b/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj @@ -13,4 +13,8 @@ + + + + \ No newline at end of file From 8bf1d4f58b691aa668371ede61df4a2d36d7cdeb Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Sun, 27 Apr 2025 15:15:39 -0700 Subject: [PATCH 17/21] use model default constants --- .../annotation/assistant/AssistantPost.java | 3 +- .../embeddings/EmbeddingsInput.java | 3 +- .../embeddings/EmbeddingsStoreOutput.java | 3 +- .../annotation/search/SemanticSearch.java | 7 ++-- .../textcompletion/TextCompletion.java | 3 +- .../openai/constants/ModelDefaults.java | 33 +++++++++++++++++++ 6 files changed, 45 insertions(+), 7 deletions(-) create mode 100644 java-library/src/main/java/com/microsoft/azure/functions/openai/constants/ModelDefaults.java diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java index c6065206..3a5eaf95 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java +++ 
b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java @@ -7,6 +7,7 @@ package com.microsoft.azure.functions.openai.annotation.assistant; import com.microsoft.azure.functions.annotation.CustomBinding; +import com.microsoft.azure.functions.openai.constants.ModelDefaults; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; @@ -61,7 +62,7 @@ * * @return The OpenAI chat model to use. */ - String chatModel(); + String chatModel() default ModelDefaults.DEFAULT_CHAT_MODEL; /** * The user message that user has entered for assistant to respond to. diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java index c3a65ac5..5d1d911b 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java @@ -7,6 +7,7 @@ package com.microsoft.azure.functions.openai.annotation.embeddings; import com.microsoft.azure.functions.annotation.CustomBinding; +import com.microsoft.azure.functions.openai.constants.ModelDefaults; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; @@ -41,7 +42,7 @@ * * @return The model ID. */ - String embeddingsModel() default "text-embedding-ada-002"; + String embeddingsModel() default ModelDefaults.DEFAULT_EMBEDDINGS_MODEL; /** * The maximum number of characters to chunk the input into. 
diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java index 7c01a724..ee333138 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java @@ -7,6 +7,7 @@ package com.microsoft.azure.functions.openai.annotation.embeddings; import com.microsoft.azure.functions.annotation.CustomBinding; +import com.microsoft.azure.functions.openai.constants.ModelDefaults; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; @@ -34,7 +35,7 @@ * * @return The model ID. */ - String embeddingsModel() default "text-embedding-ada-002"; + String embeddingsModel() default ModelDefaults.DEFAULT_EMBEDDINGS_MODEL; /** * The maximum number of characters to chunk the input into. diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java index a6669175..cd8e09ad 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java @@ -7,6 +7,7 @@ package com.microsoft.azure.functions.openai.annotation.search; import com.microsoft.azure.functions.annotation.CustomBinding; +import com.microsoft.azure.functions.openai.constants.ModelDefaults; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; @@ -66,16 +67,16 @@ * * @return The model to use for embeddings. 
*/ - String embeddingsModel() default "text-embedding-ada-002"; + String embeddingsModel() default ModelDefaults.DEFAULT_EMBEDDINGS_MODEL; /** * the name of the Large Language Model to invoke for chat responses. * The default value is "gpt-3.5-turbo". * This property supports binding expressions. - * + * * @return The name of the Large Language Model to invoke for chat responses. */ - String chatModel() default "gpt-3.5-turbo"; + String chatModel() default ModelDefaults.DEFAULT_CHAT_MODEL; /** * The system prompt to use for prompting the large language model. diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java index 7f67659b..d59e24a8 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java @@ -7,6 +7,7 @@ package com.microsoft.azure.functions.openai.annotation.textcompletion; import com.microsoft.azure.functions.annotation.CustomBinding; +import com.microsoft.azure.functions.openai.constants.ModelDefaults; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; @@ -53,7 +54,7 @@ * * @return The model ID. */ - String chatModel() default "gpt-3.5-turbo"; + String chatModel() default ModelDefaults.DEFAULT_CHAT_MODEL; /** * The sampling temperature to use, between 0 and 2. 
Higher values like 0.8 will diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/constants/ModelDefaults.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/constants/ModelDefaults.java new file mode 100644 index 00000000..0d420625 --- /dev/null +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/constants/ModelDefaults.java @@ -0,0 +1,33 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for + * license information. + */ + +package com.microsoft.azure.functions.openai.constants; + +/** + * Constants for default model values used throughout the library. + * Centralizing these values makes it easier to update them across the codebase. + * + * @since 1.0.0 + */ +public final class ModelDefaults { + + /** + * Private constructor to prevent instantiation. + */ + private ModelDefaults() { + // Utility class should not be instantiated + } + + /** + * The default chat completion model used for text generation. + */ + public static final String DEFAULT_CHAT_MODEL = "gpt-3.5-turbo"; + + /** + * The default embeddings model used for vector embeddings. 
+ */ + public static final String DEFAULT_EMBEDDINGS_MODEL = "text-embedding-ada-002"; +} From 675eb3f11fbff0c72b6c418c691cb7c0ada19b31 Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Sun, 27 Apr 2025 15:16:42 -0700 Subject: [PATCH 18/21] null check --- .../Assistants/AssistantSkillTriggerBindingProvider.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs index c4dd70f8..8b2ee6f9 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs @@ -113,7 +113,7 @@ public Task BindAsync(object value, ValueBindingContext context) SkillInvocationContext skillInvocationContext = (SkillInvocationContext)value; object? convertedValue; - if (!string.IsNullOrEmpty(skillInvocationContext.Arguments.ToString())) + if (!string.IsNullOrEmpty(skillInvocationContext.Arguments?.ToString())) { // We expect that input to always be a string value in the form {"paramName":paramValue} JObject argsJson = JObject.Parse(skillInvocationContext.Arguments.ToString()); From 8b253680861c37530ddc73234a6c42dc6c345b79 Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Sun, 27 Apr 2025 23:41:01 -0700 Subject: [PATCH 19/21] support reasoning models --- CHANGELOG.md | 4 ++ java-library/CHANGELOG.md | 4 ++ .../annotation/assistant/AssistantPost.java | 7 ++++ .../annotation/search/SemanticSearch.java | 7 ++++ .../textcompletion/TextCompletion.java | 6 +++ .../Assistants/AssistantPostInputAttribute.cs | 8 ++++ .../Embeddings/EmbeddingsRequestContext.cs | 0 .../Search/SemanticSearchInputAttribute.cs | 8 ++++ .../TextCompletionInputAttribute.cs | 8 ++++ .../Assistants/AssistantBaseAttribute.cs | 40 +++++++++++++++---- 10 files 
changed, 85 insertions(+), 7 deletions(-) create mode 100644 src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsRequestContext.cs diff --git a/CHANGELOG.md b/CHANGELOG.md index 2049a173..5b9c0d31 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,6 +24,10 @@ Starting v0.1.0 for Microsoft.Azure.WebJobs.Extensions.OpenAI.AzureAISearch, it - Updated Azure.AI.OpenAI from 1.0.0-beta.15 to 2.1.0 - Updated Azure.Data.Tables from 12.9.1 to 12.10.0, Azure.Identity from 1.12.1 to 1.13.2, Microsoft.Extensions.Azure from 1.7.5 to 1.10.0 +### Added + +- Introduced experimental `isReasoningModel` property to support reasoning models. Setting of max_completion_tokens and reasoning_effort is not supported with current underlying Azure.AI.OpenAI + ## v0.18.0 - 2024/10/08 ### Changed diff --git a/java-library/CHANGELOG.md b/java-library/CHANGELOG.md index 77a7e203..ba912f4d 100644 --- a/java-library/CHANGELOG.md +++ b/java-library/CHANGELOG.md @@ -19,6 +19,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Update azure-ai-openai from 1.0.0-beta.11 to 1.0.0-beta.16 +### Added + +- Introduced experimental `isReasoningModel` property to support reasoning models. Setting of max_completion_tokens and reasoning_effort is not supported with current underlying Azure.AI.OpenAI + ## v0.4.0 - 2024/10/08 ### Changed diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java index 3a5eaf95..16be1347 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java @@ -130,4 +130,11 @@ * @return The maxTokens value. */ String maxTokens() default "100"; + + /** + * Indicates whether the assistant uses a reasoning model. 
+ * + * @return {@code true} if the assistant is based on a reasoning model; {@code false} otherwise. + */ + boolean isReasoningModel() default false; } diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java index cd8e09ad..d3d0c206 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java @@ -138,4 +138,11 @@ String systemPrompt() default "You are a helpful assistant. You are responding t * @return The maxTokens value. */ String maxTokens() default "2048"; + + /** + * Indicates whether the assistant uses a reasoning model. + * + * @return {@code true} if the assistant is based on a reasoning model; {@code false} otherwise. + */ + boolean isReasoningModel() default false; } diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java index d59e24a8..879b0c06 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java @@ -91,4 +91,10 @@ */ String maxTokens() default "100"; + /** + * Indicates whether the assistant uses a reasoning model. + * + * @return {@code true} if the assistant is based on a reasoning model; {@code false} otherwise.
+ */ + boolean isReasoningModel() default false; } diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs index e102a0f5..c29b462f 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs @@ -93,4 +93,12 @@ public AssistantPostInputAttribute(string id, string userMessage) /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096). /// public string? MaxTokens { get; set; } = "100"; + + /// + /// Gets or sets a value indicating whether the model is a reasoning model. + /// + /// + /// Warning: This is experimental and associated with the reasoning model until all models have parity in the expected properties. + /// + public bool IsReasoningModel { get; set; } } diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsRequestContext.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsRequestContext.cs new file mode 100644 index 00000000..e69de29b diff --git a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs index 560e420a..1b679607 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs @@ -143,4 +143,12 @@ The following is a list of documents that you can refer to when answering questi /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096). /// public string? MaxTokens { get; set; } = "2048"; + + /// + /// Gets or sets a value indicating whether the model is a reasoning model.
+ /// + /// + /// Warning: This is experimental and associated with the reasoning model until all models have parity in the expected properties. + /// + public bool IsReasoningModel { get; set; } } diff --git a/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs index a9538709..d9eaa985 100644 --- a/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs @@ -73,4 +73,12 @@ public TextCompletionInputAttribute(string prompt) /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096). /// public string? MaxTokens { get; set; } = "100"; + + /// + /// Gets or sets a value indicating whether the model is a reasoning model. + /// + /// + /// Warning: This is experimental and associated with the reasoning model until all models have parity in the expected properties. + /// + public bool IsReasoningModel { get; set; } } diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs index 9dea44de..9c8a4c21 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs @@ -63,6 +63,14 @@ public class AssistantBaseAttribute : Attribute [AutoResolve] public string? TopP { get; set; } + /// + /// Gets or sets a value indicating whether the model is a reasoning model. + /// + /// + /// Warning: This is experimental and associated with the reasoning model until all models have parity in the expected properties. + /// + public bool IsReasoningModel { get; set; } + /// /// Gets or sets the maximum number of tokens to output in the completion. Default value = 100. 
/// @@ -76,21 +84,39 @@ public class AssistantBaseAttribute : Attribute internal ChatCompletionOptions BuildRequest() { ChatCompletionOptions request = new(); - if (int.TryParse(this.MaxTokens, out int maxTokens)) + if (float.TryParse(this.TopP, out float topP)) { - request.MaxOutputTokenCount = maxTokens; + request.TopP = topP; } - if (float.TryParse(this.Temperature, out float temperature)) + if (this.IsReasoningModel) { - request.Temperature = temperature; + this.MaxTokens = null; + this.Temperature = null; // property not supported for reasoning model. } - - if (float.TryParse(this.TopP, out float topP)) + else { - request.TopP = topP; + if (int.TryParse(this.MaxTokens, out int maxTokens)) + { + request.MaxOutputTokenCount = maxTokens; + } + + if (float.TryParse(this.Temperature, out float temperature)) + { + request.Temperature = temperature; + } } + // ToDo: SetNewMaxCompletionTokensPropertyEnabled() has a bug in the current version + // of the Azure.AI.OpenAI SDK but is fixed in the next preview. + // + // This method doesn't swap max_tokens with max_completion_tokens and throws errors + // if max_tokens is set. max_completion_tokens is not configurable due to this bug. + // + // Hence, setting max_tokens to null for reasoning models until the fixed SDK version + // can be adopted. 
+ // request.SetNewMaxCompletionTokensPropertyEnabled(this.IsReasoningModel); + return request; } } From 24f80cb9043e5ba8ed6d5fa44b7120302a5a5308 Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Wed, 30 Apr 2025 11:56:49 -0700 Subject: [PATCH 20/21] update readme files --- README.md | 10 +++++++--- samples/rag-aisearch/README.md | 4 ++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index aaa3fcdf..2294b0cb 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![Build](https://dev.azure.com/azfunc/Azure%20Functions/_apis/build/status%2FExtension-OpenAI%2FAzure%20Functions%20OpenAI%20Extension%20PR%20CI?branchName=main)](https://dev.azure.com/azfunc/Azure%20Functions/_build/latest?definitionId=303&branchName=main) -This project adds support for [OpenAI](https://platform.openai.com/) LLM (GPT-3.5-turbo, GPT-4) bindings in [Azure Functions](https://azure.microsoft.com/products/functions/). +This project adds support for [OpenAI](https://platform.openai.com/) LLM (GPT-3.5-turbo, GPT-4, o-series) bindings in [Azure Functions](https://azure.microsoft.com/products/functions/). This extension depends on the [Azure AI OpenAI SDK](https://github.com/Azure/azure-sdk-for-net/tree/main/sdk/openai/Azure.AI.OpenAI). @@ -62,10 +62,10 @@ The optional `AIConnectionName` property specifies the name of a configuration s "__endpoint": "Placeholder for the Azure OpenAI endpoint value", "__credential": "managedidentity", "__managedIdentityResourceId": "Resource Id of managed identity", - "__managedIdentityClientId": "Client Id of managed identity" + "__clientId": "Client Id of managed identity" ``` - * Only one of managedIdentityResourceId or managedIdentityClientId should be specified, not both. + * Only one of managedIdentityResourceId or clientId should be specified, not both. 
* If no Resource Id or Client Id is specified, the system-assigned managed identity will be used by default. * Pass the configured `ConnectionNamePrefix` value, example `AzureOpenAI` to the `AIConnectionName` property. @@ -99,6 +99,10 @@ public static IActionResult PostUserResponse( } ``` +## Using Reasoning Models + +If using reasoning models, set the `IsReasoningModel` property to true in `AssistantPost`, `SemanticSearch` and `TextCompletion` bindings. This is required due to difference in expected properties for reasoning models. + ## Features The following features are currently available. More features will be slowly added over time. The language stack specific samples are also available in this repo for dotnet-isolated, java, nodejs, powershell and python. Visit the feature specific folder for utilising those. diff --git a/samples/rag-aisearch/README.md b/samples/rag-aisearch/README.md index 84e6609d..e9b0ce81 100644 --- a/samples/rag-aisearch/README.md +++ b/samples/rag-aisearch/README.md @@ -39,10 +39,10 @@ and optionally [enable semantic ranking](https://learn.microsoft.com/en-us/azure "__endpoint": "https://.search.windows.net", "__credential": "managedidentity", "__managedIdentityResourceId": "Resource Id of managed identity", - "__managedIdentityClientId": "Client Id of managed identity" + "__clientId": "Client Id of managed identity" ``` - Only one of managedIdentityResourceId or managedIdentityClientId should be specified, not both. + Only one of managedIdentityResourceId or clientId should be specified, not both. 2. 
System-Assigned Managed Identity or local development: From 9795b4e67c5f33f474410d01a3322f09eba17fed Mon Sep 17 00:00:00 2001 From: manvkaur <67894494+manvkaur@users.noreply.github.com> Date: Wed, 30 Apr 2025 13:00:53 -0700 Subject: [PATCH 21/21] remove unintended file --- .../Embeddings/EmbeddingsRequestContext.cs | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsRequestContext.cs diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsRequestContext.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsRequestContext.cs deleted file mode 100644 index e69de29b..00000000