diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b54278d..5b9c0d31 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,9 +11,22 @@ Starting v0.1.0 for Microsoft.Azure.WebJobs.Extensions.OpenAI.AzureAISearch, it ## v0.19.0 - Unreleased +### Breaking + +- Renamed model properties to `chatModel` and `embeddingsModel` in AssistantPost, Embeddings and TextCompletion bindings. +- Renamed connectionName to `searchConnectionName` in SemanticSearch binding. +- Renamed connectionName to `storeConnectionName` in EmbeddingsStore binding. +- Renamed ChatMessage entity to AssistantMessage. +- Managed identity support through config section and binding parameter `aiConnectionName` in AssistantPost, Embeddings, EmbeddingsStore, SemanticSearch and TextCompletion bindings. + ### Changed -- Updated Azure.Data.Tables from 12.9.1 to 12.10.0, Azure.Identity from 1.12.1 to 1.13.2, Microsoft.Extensions.Azure from 1.7.5 to 1.8.0 +- Updated Azure.AI.OpenAI from 1.0.0-beta.15 to 2.1.0 +- Updated Azure.Data.Tables from 12.9.1 to 12.10.0, Azure.Identity from 1.12.1 to 1.13.2, Microsoft.Extensions.Azure from 1.7.5 to 1.10.0 + +### Added + +- Introduced experimental `isReasoningModel` property to support reasoning models. 
Setting of max_completion_tokens and reasoning_effort is not supported with current underlying Azure.AI.OpenAI ## v0.18.0 - 2024/10/08 diff --git a/OpenAI-Extension.sln b/OpenAI-Extension.sln index 6c86f9cc..2fee6c2d 100644 --- a/OpenAI-Extension.sln +++ b/OpenAI-Extension.sln @@ -118,6 +118,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "KustoSearchLegacy", "sample EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TextCompletionLegacy", "samples\textcompletion\csharp-legacy\TextCompletionLegacy.csproj", "{73C26271-7B8D-4B38-B454-D9902B92270F}" EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "WebJobsOpenAIUnitTests", "tests\UnitTests\WebJobsOpenAIUnitTests.csproj", "{52337999-5676-27FD-AE3D-CAAE406AE337}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -216,6 +218,10 @@ Global {73C26271-7B8D-4B38-B454-D9902B92270F}.Debug|Any CPU.Build.0 = Debug|Any CPU {73C26271-7B8D-4B38-B454-D9902B92270F}.Release|Any CPU.ActiveCfg = Release|Any CPU {73C26271-7B8D-4B38-B454-D9902B92270F}.Release|Any CPU.Build.0 = Release|Any CPU + {52337999-5676-27FD-AE3D-CAAE406AE337}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {52337999-5676-27FD-AE3D-CAAE406AE337}.Debug|Any CPU.Build.0 = Debug|Any CPU + {52337999-5676-27FD-AE3D-CAAE406AE337}.Release|Any CPU.ActiveCfg = Release|Any CPU + {52337999-5676-27FD-AE3D-CAAE406AE337}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -258,6 +264,7 @@ Global {BB20B2C0-35AE-CC75-61CE-C2713C83E4F6} = {4F208B03-C74C-458B-8300-9FF82F3C6325} {C5ACB61F-CDAC-E93F-895E-4D46F086DE61} = {B87CBFFB-8221-45E8-8631-4BB1685D50F0} {73C26271-7B8D-4B38-B454-D9902B92270F} = {5315AE81-BC33-49F2-935B-69287BB44CBD} + {52337999-5676-27FD-AE3D-CAAE406AE337} = {B99948E2-9853-4D9C-B784-19C2503CDA8B} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = 
{05A7A679-3A53-45B5-AE93-4313655E127D} diff --git a/README.md b/README.md index 537adecc..2294b0cb 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) [![Build](https://dev.azure.com/azfunc/Azure%20Functions/_apis/build/status%2FExtension-OpenAI%2FAzure%20Functions%20OpenAI%20Extension%20PR%20CI?branchName=main)](https://dev.azure.com/azfunc/Azure%20Functions/_build/latest?definitionId=303&branchName=main) -This project adds support for [OpenAI](https://platform.openai.com/) LLM (GPT-3.5-turbo, GPT-4) bindings in [Azure Functions](https://azure.microsoft.com/products/functions/). +This project adds support for [OpenAI](https://platform.openai.com/) LLM (GPT-3.5-turbo, GPT-4, o-series) bindings in [Azure Functions](https://azure.microsoft.com/products/functions/). This extension depends on the [Azure AI OpenAI SDK](https://github.com/Azure/azure-sdk-for-net/tree/main/sdk/openai/Azure.AI.OpenAI). @@ -29,18 +29,79 @@ Add following section to `host.json` of the function app for non dotnet language ## Requirements -* [.NET 6 SDK](https://dotnet.microsoft.com/download/dotnet/6.0) or greater (Visual Studio 2022 recommended) +* [.NET 8 SDK](https://dotnet.microsoft.com/download/dotnet/8.0) or greater (Visual Studio 2022 recommended) * [Azure Functions Core Tools v4.x](https://learn.microsoft.com/azure/azure-functions/functions-run-local?tabs=v4%2Cwindows%2Cnode%2Cportal%2Cbash) +* Azure Storage emulator such as [Azurite](https://learn.microsoft.com/azure/storage/common/storage-use-azurite) running in the background +* The target language runtime (e.g. dotnet, nodejs, powershell, python, java) installed on your machine. Refer the official supported versions. * Update settings in Azure Function or the `local.settings.json` file for local development with the following keys: - 1. 
For Azure, `AZURE_OPENAI_ENDPOINT` - [Azure OpenAI resource](https://learn.microsoft.com/azure/ai-services/openai/how-to/create-resource?pivots=web-portal) (e.g. `https://***.openai.azure.com/`) set. - 1. For Azure, assign the user or function app managed identity `Cognitive Services OpenAI User` role on the Azure OpenAI resource. It is strongly recommended to use managed identity to avoid overhead of secrets maintenance, however if there is a need for key based authentication add the setting `AZURE_OPENAI_KEY` and its value in the settings. - 1. For non- Azure, `OPENAI_API_KEY` - An OpenAI account and an [API key](https://platform.openai.com/account/api-keys) saved into a setting. - If using environment variables, Learn more in [.env readme](./env/README.md). 1. Update `CHAT_MODEL_DEPLOYMENT_NAME` and `EMBEDDING_MODEL_DEPLOYMENT_NAME` keys to Azure Deployment names or override default OpenAI model names. - 1. If using user assigned managed identity, add `AZURE_CLIENT_ID` to environment variable settings with value as client id of the managed identity. 1. Visit binding specific samples README for additional settings that might be required for each binding. -* Azure Storage emulator such as [Azurite](https://learn.microsoft.com/azure/storage/common/storage-use-azurite) running in the background -* The target language runtime (e.g. dotnet, nodejs, powershell, python, java) installed on your machine. Refer the official supported versions. + 1. 
Refer [Configuring AI Service Connections](#configuring-ai-service-connections) + +## Configuring AI Service Connections + +The Azure Functions OpenAI Extension provides flexible options for configuring connections to AI services through the `AIConnectionName` property in the AssistantPost, TextCompletion, SemanticSearch, EmbeddingsStore, Embeddings bindings + +### Managed Identity Role + +Strongly recommended to use managed identity and ensure the user or function app's managed identity has the role - `Cognitive Services OpenAI User` + +### AIConnectionName Property + +The optional `AIConnectionName` property specifies the name of a configuration section that contains connection details for the AI service: + +#### For Azure OpenAI Service + +* If specified, the extension looks for `Endpoint` and `Key` values in the named configuration section +* If not specified or the configuration section doesn't exist, the extension falls back to environment variables: + * `AZURE_OPENAI_ENDPOINT` and/or + * `AZURE_OPENAI_KEY` +* For user-assigned managed identity authentication, a configuration section is required + + ```json + "__endpoint": "Placeholder for the Azure OpenAI endpoint value", + "__credential": "managedidentity", + "__managedIdentityResourceId": "Resource Id of managed identity", + "__clientId": "Client Id of managed identity" + ``` + + * Only one of managedIdentityResourceId or clientId should be specified, not both. + * If no Resource Id or Client Id is specified, the system-assigned managed identity will be used by default. + * Pass the configured `ConnectionNamePrefix` value, example `AzureOpenAI` to the `AIConnectionName` property. 
+ +#### For OpenAI Service (non-Azure) + +* Set the `OPENAI_API_KEY` environment variable + +### Configuration Examples + +#### Example: Using a configuration section + +In `local.settings.json` or app environment variables: + +```json +"AzureOpenAI__endpoint": "Placeholder for the Azure OpenAI endpoint value", +"AzureOpenAI__credential": "managedidentity", +``` + +Specifying credential is optional for system assigned managed identity + +Function usage example: + +```csharp +[Function(nameof(PostUserResponse))] +public static IActionResult PostUserResponse( + [HttpTrigger(AuthorizationLevel.Function, "post", Route = "chats/{chatId}")] HttpRequestData req, + string chatId, + [AssistantPostInput("{chatId}", "{Query.message}", AIConnectionName = "AzureOpenAI", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state) +{ + return new OkObjectResult(state.RecentMessages.LastOrDefault()?.Content ?? "No response returned."); +} +``` + +## Using Reasoning Models + +If using reasoning models, set the `IsReasoningModel` property to true in `AssistantPost`, `SemanticSearch` and `TextCompletion` bindings. This is required due to difference in expected properties for reasoning models. ## Features @@ -270,7 +331,7 @@ public static async Task IngestFile( public class EmbeddingsStoreOutputResponse { - [EmbeddingsStoreOutput("{url}", InputType.Url, "AISearchEndpoint", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] + [EmbeddingsStoreOutput("{url}", InputType.Url, "AISearchEndpoint", "openai-index", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] public required SearchableDocument SearchableDocument { get; init; } public IActionResult? 
HttpResponse { get; set; } diff --git a/eng/ci/templates/build-local.yml b/eng/ci/templates/build-local.yml index a12251f8..b11122e1 100644 --- a/eng/ci/templates/build-local.yml +++ b/eng/ci/templates/build-local.yml @@ -20,8 +20,13 @@ jobs: dotnet build $(System.DefaultWorkingDirectory)/src/WebJobs.Extensions.OpenAI.Kusto/WebJobs.Extensions.OpenAI.Kusto.csproj --configuration $(config) -p:Version=$(fakeWebJobsPackageVersion) -p:AzureAISearchVersion=$(fakeWebJobsPackageVersion) -p:KustoVersion=$(fakeWebJobsPackageVersion) dotnet build $(System.DefaultWorkingDirectory)/src/WebJobs.Extensions.OpenAI.AzureAISearch/WebJobs.Extensions.OpenAI.AzureAISearch.csproj --configuration $(config) -p:Version=$(fakeWebJobsPackageVersion) -p:AzureAISearchVersion=$(fakeWebJobsPackageVersion) dotnet build $(System.DefaultWorkingDirectory)/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/WebJobs.Extensions.OpenAI.CosmosDBSearch.csproj --configuration $(config) -p:Version=$(fakeWebJobsPackageVersion) -p:CosmosDBSearchVersion=$(fakeWebJobsPackageVersion) + dotnet build $(System.DefaultWorkingDirectory)/tests/UnitTests/WebJobsOpenAIUnitTests.csproj --configuration $(config) -p:WebJobsVersion=$(fakeWebJobsPackageVersion) -p:Version=$(fakeWebJobsPackageVersion) displayName: Dotnet Build WebJobs.Extensions.OpenAI + - script: | + dotnet test $(System.DefaultWorkingDirectory)/tests/UnitTests/WebJobsOpenAIUnitTests.csproj --configuration $(config) --collect "Code Coverage" --no-build + displayName: Dotnet Test WebJobsOpenAIUnitTests + - task: CopyFiles@2 displayName: 'Copy NuGet WebJobs.Extensions.OpenAI to local directory' inputs: diff --git a/java-library/CHANGELOG.md b/java-library/CHANGELOG.md index 8e928b10..ba912f4d 100644 --- a/java-library/CHANGELOG.md +++ b/java-library/CHANGELOG.md @@ -7,9 +7,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## v0.5.0 - Unreleased +### Breaking + +- model properties renamed to `chatModel` and `embeddingsModel` in 
assistantPost, embeddings and textCompletion bindings. +- renamed connectionName to `searchConnectionName` in semanticSearch binding. +- renamed connectionName to `storeConnectionName` in embeddingsStore binding. +- renamed ChatMessage to `AssistantMessage`. +- managed identity support through config section and binding parameter `aiConnectionName` in assistantPost, embeddings, embeddingsStore, semanticSearch and textCompletion bindings. + ### Changed -- Update azure-ai-openai from 1.0.0-beta.11 to 1.0.0-beta.14 +- Update azure-ai-openai from 1.0.0-beta.11 to 1.0.0-beta.16 + +### Added + +- Introduced experimental `isReasoningModel` property to support reasoning models. Setting of max_completion_tokens and reasoning_effort is not supported with current underlying Azure.AI.OpenAI ## v0.4.0 - 2024/10/08 diff --git a/java-library/pom.xml b/java-library/pom.xml index bf26b3d7..221bf968 100644 --- a/java-library/pom.xml +++ b/java-library/pom.xml @@ -88,7 +88,7 @@ com.azure azure-ai-openai - 1.0.0-beta.14 + 1.0.0-beta.16 compile diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/ChatMessage.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantMessage.java similarity index 55% rename from java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/ChatMessage.java rename to java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantMessage.java index 1d246201..11627eda 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/ChatMessage.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantMessage.java @@ -11,35 +11,23 @@ * Chat Message Entity which contains the content of the message, the role of the chat agent, and the name of the calling function if applicable. *

*/ -public class ChatMessage { +public class AssistantMessage { private String content; private String role; - private String name; + private String toolCalls; /** - * Initializes a new instance of the ChatMessage class. + * Initializes a new instance of the AssistantMessage class. * - * @param content The content of the message. - * @param role The role of the chat agent. - */ - public ChatMessage(String content, String role) { - this.content = content; - this.role = role; - } - - - /** - * Initializes a new instance of the ChatMessage class. - * - * @param content The content of the message. - * @param role The role of the chat agent. - * @param name The name of the calling function if applicable. + * @param content The content of the message. + * @param role The role of the chat agent. + * @param toolCalls The toolCalls of the calling function if applicable. */ - public ChatMessage(String content, String role, String name) { + public AssistantMessage(String content, String role, String toolCalls) { this.content = content; this.role = role; - this.name = name; + this.toolCalls = toolCalls; } /** @@ -79,21 +67,20 @@ public void setRole(String role) { } /** - * Gets the name of the calling function if applicable. + * Gets the toolCalls of the calling function if applicable. * - * @return The name of the calling function if applicable. + * @return The toolCalls of the calling function if applicable. */ - public String getName() { - return name; + public String getToolCalls() { + return toolCalls; } /** - * Sets the name of the calling function if applicable. + * Sets the toolCalls of the calling function if applicable. * - * @param name The name of the calling function if applicable. + * @param toolCalls The toolCalls of the calling function if applicable. 
*/ - public void setName(String name) { - this.name = name; + public void setToolCalls(String toolCalls) { + this.toolCalls = toolCalls; } } - diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java index 67a3cc3f..16be1347 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java @@ -7,6 +7,7 @@ package com.microsoft.azure.functions.openai.annotation.assistant; import com.microsoft.azure.functions.annotation.CustomBinding; +import com.microsoft.azure.functions.openai.constants.ModelDefaults; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; @@ -61,7 +62,7 @@ * * @return The OpenAI chat model to use. */ - String model(); + String chatModel() default ModelDefaults.DEFAULT_CHAT_MODEL; /** * The user message that user has entered for assistant to respond to. @@ -70,6 +71,14 @@ */ String userMessage(); + /** + * The name of the configuration section for AI service connectivity settings. + * + * @return The name of the configuration section for AI service connectivity + * settings. + */ + String aiConnectionName() default ""; + /** * The configuration section name for the table settings for assistant chat * storage. @@ -86,4 +95,46 @@ * returns {@code DEFAULT_COLLECTION}. */ String collectionName() default DEFAULT_COLLECTION; + + /** + * The sampling temperature to use, between 0 and 2. Higher values like 0.8 will + * make the output + * more random, while lower values like 0.2 will make it more focused and + * deterministic. + * It's generally recommended to use this or {@link #topP()} but not both. + * + * @return The sampling temperature value. 
+ */ + String temperature() default "0.5"; + + /** + * An alternative to sampling with temperature, called nucleus sampling, where + * the model considers + * the results of the tokens with top_p probability mass. So 0.1 means only the + * tokens comprising the top 10% + * probability mass are considered. + * It's generally recommended to use this or {@link #temperature()} but not + * both. + * + * @return The topP value. + */ + String topP() default ""; + + /** + * The maximum number of tokens to generate in the completion. + * The token count of your prompt plus max_tokens cannot exceed the model's + * context length. + * Most models have a context length of 2048 tokens (except for the newest + * models, which support 4096). + * + * @return The maxTokens value. + */ + String maxTokens() default "100"; + + /** + * Indicates whether the assistant uses a reasoning model. + * + * @return {@code true} if the assistant is based on a reasoning model; {@code false} otherwise. + */ + boolean isReasoningModel() default false; } diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java index 610a1fcd..36f67dbb 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java @@ -14,54 +14,46 @@ import java.lang.annotation.Target; /** - *

- * Assistant skill trigger attribute. - *

- * - * @since 1.0.0 - */ - @Retention(RetentionPolicy.RUNTIME) - @Target(ElementType.PARAMETER) - @CustomBinding(direction = "in", name = "", type = "assistantSkillTrigger") - public @interface AssistantSkillTrigger { - - /** - * The variable name used in function.json. - * - * @return The variable name used in function.json. - */ - String name(); - - /** - * The name of the function to be invoked by the assistant. - * - * @return The name of the function to be invoked by the assistant. - */ - String functionName() default ""; - - /** - * The description of the assistant function, which is provided to the LLM. - * - * @return The description of the assistant function. - */ - String functionDescription(); - - /** - * The JSON description of the function parameter, which is provided to the LLM. - * If no description is provided, the description will be autogenerated. - * For more information on the syntax of the parameter description JSON, see the OpenAI API documentation: - * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools. - * - * @return The JSON description of the function parameter. - */ - String parameterDescriptionJson() default ""; - - /** - * The OpenAI chat model to use. - * When using Azure OpenAI, this should be the name of the model deployment. - * - * @return The OpenAI chat model to use. - */ - String model() default "gpt-3.5-turbo"; - - } + *

+ * Assistant skill trigger attribute. + *

+ * + * @since 1.0.0 + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.PARAMETER) +@CustomBinding(direction = "in", name = "", type = "assistantSkillTrigger") +public @interface AssistantSkillTrigger { + + /** + * The variable name used in function.json. + * + * @return The variable name used in function.json. + */ + String name(); + + /** + * The name of the function to be invoked by the assistant. + * + * @return The name of the function to be invoked by the assistant. + */ + String functionName() default ""; + + /** + * The description of the assistant function, which is provided to the LLM. + * + * @return The description of the assistant function. + */ + String functionDescription(); + + /** + * The JSON description of the function parameter, which is provided to the LLM. + * If no description is provided, the description will be autogenerated. + * For more information on the syntax of the parameter description JSON, see the + * OpenAI API documentation: + * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools. + * + * @return The JSON description of the function parameter. 
+ */ + String parameterDescriptionJson() default ""; +} diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantState.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantState.java index 4ebcb0ea..0f88cfc1 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantState.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantState.java @@ -20,11 +20,11 @@ public class AssistantState { private String lastUpdatedAt; private int totalMessages; private int totalTokens; - private List recentMessages; + private List recentMessages; public AssistantState(String id, boolean exists, - String createdAt, String lastUpdatedAt, - int totalMessages, int totalTokens, List recentMessages) { + String createdAt, String lastUpdatedAt, + int totalMessages, int totalTokens, List recentMessages) { this.id = id; this.exists = exists; this.createdAt = createdAt; @@ -56,7 +56,7 @@ public boolean isExists() { * Gets timestamp of when assistant is created. * * @return The timestamp of when assistant is created. - */ + */ public String getCreatedAt() { return createdAt; } @@ -93,7 +93,7 @@ public int getTotalTokens() { * * @return A list of the recent messages from the assistant. */ - public List getRecentMessages() { + public List getRecentMessages() { return recentMessages; } @@ -156,8 +156,8 @@ public void setTotalTokens(int totalTokens) { * * @param recentMessages A list of the recent messages from the assistant. 
*/ - public void setRecentMessages(List recentMessages) { + public void setRecentMessages(List recentMessages) { this.recentMessages = recentMessages; } - + } diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java index 1627b380..3e04c8cf 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java @@ -7,19 +7,19 @@ package com.microsoft.azure.functions.openai.annotation.embeddings; import com.azure.ai.openai.models.Embeddings; -import com.azure.ai.openai.models.EmbeddingsOptions; +import java.util.List; public class EmbeddingsContext { - private EmbeddingsOptions request; + private List request; private Embeddings response; private int count = 0; - public EmbeddingsOptions getRequest() { + public List getRequest() { return request; } - public void setRequest(EmbeddingsOptions request) { + public void setRequest(List request) { this.request = request; } @@ -38,7 +38,8 @@ public void setResponse(Embeddings response) { */ public int getCount() { return this.response != null && this.response.getData() != null - ? this.count : 0; + ? 
this.count + : 0; } } diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java index 39d46691..5d1d911b 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java @@ -7,12 +7,13 @@ package com.microsoft.azure.functions.openai.annotation.embeddings; import com.microsoft.azure.functions.annotation.CustomBinding; +import com.microsoft.azure.functions.openai.constants.ModelDefaults; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; - + /** *

* Embeddings input attribute which is used for embedding generation. @@ -34,17 +35,21 @@ /** * The ID of the model to use. - * Changing the default embeddings model is a breaking change, since any changes will be stored in a vector database for lookup. - * Changing the default model can cause the lookups to start misbehaving if they don't match the data that was previously ingested into the vector database. + * Changing the default embeddings model is a breaking change, since any changes + * will be stored in a vector database for lookup. + * Changing the default model can cause the lookups to start misbehaving if they + * don't match the data that was previously ingested into the vector database. * * @return The model ID. */ - String model() default "text-embedding-ada-002"; + String embeddingsModel() default ModelDefaults.DEFAULT_EMBEDDINGS_MODEL; /** * The maximum number of characters to chunk the input into. - * At the time of writing, the maximum input tokens allowed for second-generation input embedding models - * like text-embedding-ada-002 is 8191. 1 token is ~4 chars in English, which translates to roughly 32K + * At the time of writing, the maximum input tokens allowed for + * second-generation input embedding models + * like text-embedding-ada-002 is 8191. 1 token is ~4 chars in English, which + * translates to roughly 32K * characters of English input that can fit into a single chunk. * * @return The maximum number of characters to chunk the input into. @@ -65,11 +70,19 @@ */ String input(); + /** + * The name of the configuration section for AI service connectivity settings. + * + * @return The name of the configuration section for AI service connectivity + * settings. + */ + String aiConnectionName() default ""; + /** * The input type. * * @return The input type. 
*/ InputType inputType(); - - } + +} diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java index 87402766..ee333138 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java @@ -7,6 +7,7 @@ package com.microsoft.azure.functions.openai.annotation.embeddings; import com.microsoft.azure.functions.annotation.CustomBinding; +import com.microsoft.azure.functions.openai.constants.ModelDefaults; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; @@ -27,17 +28,21 @@ /** * The ID of the model to use. - * Changing the default embeddings model is a breaking change, since any changes will be stored in a vector database for lookup. - * Changing the default model can cause the lookups to start misbehaving if they don't match the data that was previously ingested into the vector database. + * Changing the default embeddings model is a breaking change, since any changes + * will be stored in a vector database for lookup. + * Changing the default model can cause the lookups to start misbehaving if they + * don't match the data that was previously ingested into the vector database. * * @return The model ID. */ - String model() default "text-embedding-ada-002"; + String embeddingsModel() default ModelDefaults.DEFAULT_EMBEDDINGS_MODEL; /** * The maximum number of characters to chunk the input into. - * At the time of writing, the maximum input tokens allowed for second-generation input embedding models - * like text-embedding-ada-002 is 8191. 
1 token is ~4 chars in English, which translates to roughly 32K + * At the time of writing, the maximum input tokens allowed for + * second-generation input embedding models + * like text-embedding-ada-002 is 8191. 1 token is ~4 chars in English, which + * translates to roughly 32K * characters of English input that can fit into a single chunk. * * @return The maximum number of characters to chunk the input into. @@ -66,13 +71,22 @@ InputType inputType(); /** - * The name of an app setting or environment variable which contains a connection string value. + * The name of the configuration section for AI service connectivity settings. + * + * @return The name of the configuration section for AI service connectivity + * settings. + */ + String aiConnectionName() default ""; + + /** + * The name of an app setting or environment variable which contains a + * connection string value. * This property supports binding expressions. * - * @return The connection name. + * @return The store connection name. */ - String connectionName(); - + String storeConnectionName(); + /** * The name of the collection or table to search. * This property supports binding expressions. 
diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java index 87941b85..d3d0c206 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java @@ -7,11 +7,12 @@ package com.microsoft.azure.functions.openai.annotation.search; import com.microsoft.azure.functions.annotation.CustomBinding; +import com.microsoft.azure.functions.openai.constants.ModelDefaults; - import java.lang.annotation.ElementType; - import java.lang.annotation.Retention; - import java.lang.annotation.RetentionPolicy; - import java.lang.annotation.Target; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; @Retention(RetentionPolicy.RUNTIME) @Target(ElementType.PARAMETER) @@ -26,12 +27,21 @@ String name(); /** - * The name of an app setting or environment variable which contains a connection string value. + * The name of the configuration section for AI service connectivity settings. + * + * @return The name of the configuration section for AI service connectivity + * settings. + */ + String aiConnectionName() default ""; + + /** + * The name of an app setting or environment variable which contains a + * connection string value. * This property supports binding expressions. * * @return The connection name. */ - String connectionName(); + String searchConnectionName(); /** * The name of the collection or table to search. @@ -50,7 +60,6 @@ */ String query() default ""; - /** * The model to use for embeddings. * The default value is "text-embedding-ada-002". @@ -58,21 +67,21 @@ * * @return The model to use for embeddings. 
*/ - String embeddingsModel() default "text-embedding-ada-002"; + String embeddingsModel() default ModelDefaults.DEFAULT_EMBEDDINGS_MODEL; /** * the name of the Large Language Model to invoke for chat responses. * The default value is "gpt-3.5-turbo". * This property supports binding expressions. - * + * * @return The name of the Large Language Model to invoke for chat responses. */ - String chatModel() default "gpt-3.5-turbo"; - + String chatModel() default ModelDefaults.DEFAULT_CHAT_MODEL; /** * The system prompt to use for prompting the large language model. - * The system prompt will be appended with knowledge that is fetched as a result of the Query. + * The system prompt will be appended with knowledge that is fetched as a result + * of the Query. * The combined prompt will then be sent to the OpenAI Chat API. * This property supports binding expressions. * @@ -94,4 +103,46 @@ String systemPrompt() default "You are a helpful assistant. You are responding t * @return The number of knowledge items to inject into the SystemPrompt. */ int maxKnowledgeLength() default 1; + + /** + * The sampling temperature to use, between 0 and 2. Higher values like 0.8 will + * make the output + * more random, while lower values like 0.2 will make it more focused and + * deterministic. + * It's generally recommended to use this or {@link #topP()} but not both. + * + * @return The sampling temperature value. + */ + String temperature() default "0.5"; + + /** + * An alternative to sampling with temperature, called nucleus sampling, where + * the model considers + * the results of the tokens with top_p probability mass. So 0.1 means only the + * tokens comprising the top 10% + * probability mass are considered. + * It's generally recommended to use this or {@link #temperature()} but not + * both. + * + * @return The topP value. + */ + String topP() default ""; + + /** + * The maximum number of tokens to generate in the completion. 
+ * The token count of your prompt plus max_tokens cannot exceed the model's + * context length. + * Most models have a context length of 2048 tokens (except for the newest + * models, which support 4096). + * + * @return The maxTokens value. + */ + String maxTokens() default "2048"; + + /** + * Indicates whether the assistant uses a reasoning model. + * + * @return {@code true} if the assistant is based on a reasoning model; {@code false} otherwise. + */ + boolean isReasoningModel(); } diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java index 1d024ce4..879b0c06 100644 --- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java @@ -7,15 +7,17 @@ package com.microsoft.azure.functions.openai.annotation.textcompletion; import com.microsoft.azure.functions.annotation.CustomBinding; +import com.microsoft.azure.functions.openai.constants.ModelDefaults; import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; - + /** *

- * Assistant query input attribute which is used query the Assistant to get current state. + * Assistant query input attribute which is used to query the Assistant to get + * current state. *

* * @since 1.0.0 @@ -39,16 +41,26 @@ */ String prompt(); + /** + * The name of the configuration section for AI service connectivity settings. + * + * @return The name of the configuration section for AI service connectivity + * settings. + */ + String aiConnectionName() default ""; + /** * The ID of the model to use. * * @return The model ID. */ - String model() default "gpt-3.5-turbo"; + String chatModel() default ModelDefaults.DEFAULT_CHAT_MODEL; /** - * The sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output - * more random, while lower values like 0.2 will make it more focused and deterministic. + * The sampling temperature to use, between 0 and 2. Higher values like 0.8 will + * make the output + * more random, while lower values like 0.2 will make it more focused and + * deterministic. * It's generally recommended to use this or {@link #topP()} but not both. * * @return The sampling temperature value. @@ -56,10 +68,13 @@ String temperature() default "0.5"; /** - * An alternative to sampling with temperature, called nucleus sampling, where the model considers - * the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + * An alternative to sampling with temperature, called nucleus sampling, where + * the model considers + * the results of the tokens with top_p probability mass. So 0.1 means only the + * tokens comprising the top 10% * probability mass are considered. - * It's generally recommended to use this or {@link #temperature()} but not both. + * It's generally recommended to use this or {@link #temperature()} but not + * both. * * @return The topP value. */ @@ -67,11 +82,19 @@ /** * The maximum number of tokens to generate in the completion. - * The token count of your prompt plus max_tokens cannot exceed the model's context length. - * Most models have a context length of 2048 tokens (except for the newest models, which support 4096). 
+ * The token count of your prompt plus max_tokens cannot exceed the model's + * context length. + * Most models have a context length of 2048 tokens (except for the newest + * models, which support 4096). * * @return The maxTokens value. */ String maxTokens() default "100"; - - } + + /** + * Indicates whether the assistant uses a reasoning model. + * + * @return {@code true} if the assistant is based on a reasoning model; {@code false} otherwise. + */ + boolean isReasoningModel(); +} diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/constants/ModelDefaults.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/constants/ModelDefaults.java new file mode 100644 index 00000000..0d420625 --- /dev/null +++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/constants/ModelDefaults.java @@ -0,0 +1,33 @@ +/** + * Copyright (c) Microsoft Corporation. All rights reserved. + * Licensed under the MIT License. See License.txt in the project root for + * license information. + */ + +package com.microsoft.azure.functions.openai.constants; + +/** + * Constants for default model values used throughout the library. + * Centralizing these values makes it easier to update them across the codebase. + * + * @since 1.0.0 + */ +public final class ModelDefaults { + + /** + * Private constructor to prevent instantiation. + */ + private ModelDefaults() { + // Utility class should not be instantiated + } + + /** + * The default chat completion model used for text generation. + */ + public static final String DEFAULT_CHAT_MODEL = "gpt-3.5-turbo"; + + /** + * The default embeddings model used for vector embeddings. 
+ */ + public static final String DEFAULT_EMBEDDINGS_MODEL = "text-embedding-ada-002"; +} diff --git a/samples/assistant/csharp-legacy/AssistantApis.cs b/samples/assistant/csharp-legacy/AssistantApis.cs index 16f9d5bb..4bdacd5e 100644 --- a/samples/assistant/csharp-legacy/AssistantApis.cs +++ b/samples/assistant/csharp-legacy/AssistantApis.cs @@ -51,7 +51,7 @@ public static async Task CreateAssistant( public static IActionResult PostUserQuery( [HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = "assistants/{assistantId}")] HttpRequest req, string assistantId, - [AssistantPost("{assistantId}", "{Query.message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState updatedState) + [AssistantPost("{assistantId}", "{Query.message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState updatedState) { return new OkObjectResult(updatedState.RecentMessages.Any() ? updatedState.RecentMessages[updatedState.RecentMessages.Count - 1].Content : "No response returned."); } diff --git a/samples/assistant/csharp-legacy/AssistantSample.csproj b/samples/assistant/csharp-legacy/AssistantSample.csproj index d60abd4a..711f9628 100644 --- a/samples/assistant/csharp-legacy/AssistantSample.csproj +++ b/samples/assistant/csharp-legacy/AssistantSample.csproj @@ -7,7 +7,7 @@ - + diff --git a/samples/assistant/csharp-ooproc/AssistantApis.cs b/samples/assistant/csharp-ooproc/AssistantApis.cs index 5d425c8f..a98b6fcd 100644 --- a/samples/assistant/csharp-ooproc/AssistantApis.cs +++ b/samples/assistant/csharp-ooproc/AssistantApis.cs @@ -59,10 +59,10 @@ public class CreateChatBotOutput /// HTTP POST function that sends user prompts to the assistant chat bot. 
/// [Function(nameof(PostUserQuery))] - public static async Task PostUserQuery( + public static IActionResult PostUserQuery( [HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = "assistants/{assistantId}")] HttpRequestData req, string assistantId, - [AssistantPostInput("{assistantId}", "{Query.message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state) + [AssistantPostInput("{assistantId}", "{Query.message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state) { return new OkObjectResult(state.RecentMessages.Any() ? state.RecentMessages[state.RecentMessages.Count - 1].Content : "No response returned."); } @@ -71,7 +71,7 @@ public static async Task PostUserQuery( /// HTTP GET function that queries the conversation history of the assistant chat bot. 
/// [Function(nameof(GetChatState))] - public static async Task GetChatState( + public static IActionResult GetChatState( [HttpTrigger(AuthorizationLevel.Anonymous, "get", Route = "assistants/{assistantId}")] HttpRequestData req, string assistantId, [AssistantQueryInput("{assistantId}", TimestampUtc = "{Query.timestampUTC}", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state) diff --git a/samples/assistant/csharp-ooproc/AssistantSample.csproj b/samples/assistant/csharp-ooproc/AssistantSample.csproj index d6e9140a..2087a8de 100644 --- a/samples/assistant/csharp-ooproc/AssistantSample.csproj +++ b/samples/assistant/csharp-ooproc/AssistantSample.csproj @@ -9,11 +9,11 @@ - + - + diff --git a/samples/assistant/java/extensions.csproj b/samples/assistant/java/extensions.csproj index eb243f13..ad5a93c6 100644 --- a/samples/assistant/java/extensions.csproj +++ b/samples/assistant/java/extensions.csproj @@ -1,6 +1,6 @@  - net60 + net80 ** target/azure-functions/azfs-java-openai-sample/bin diff --git a/samples/assistant/javascript/src/functions/assistantApis.js b/samples/assistant/javascript/src/functions/assistantApis.js index a8eaced4..3d167a10 100644 --- a/samples/assistant/javascript/src/functions/assistantApis.js +++ b/samples/assistant/javascript/src/functions/assistantApis.js @@ -36,7 +36,7 @@ app.http('CreateAssistant', { const assistantPostInput = input.generic({ type: 'assistantPost', id: '{assistantId}', - model: '%CHAT_MODEL_DEPLOYMENT_NAME%', + chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%', userMessage: '{Query.message}', chatStorageConnectionSetting: CHAT_STORAGE_CONNECTION_SETTING, collectionName: COLLECTION_NAME diff --git a/samples/assistant/powershell/PostUserQuery/function.json b/samples/assistant/powershell/PostUserQuery/function.json index 024cd24f..05b7702c 100644 --- a/samples/assistant/powershell/PostUserQuery/function.json +++ 
b/samples/assistant/powershell/PostUserQuery/function.json @@ -22,7 +22,7 @@ "dataType": "string", "id": "{assistantId}", "userMessage": "{Query.message}", - "model": "%CHAT_MODEL_DEPLOYMENT_NAME%", + "chatModel": "%CHAT_MODEL_DEPLOYMENT_NAME%", "chatStorageConnectionSetting": "AzureWebJobsStorage", "collectionName": "ChatState" } diff --git a/samples/assistant/typescript/src/functions/assistantApis.ts b/samples/assistant/typescript/src/functions/assistantApis.ts index 35bdd59c..c3bd70bf 100644 --- a/samples/assistant/typescript/src/functions/assistantApis.ts +++ b/samples/assistant/typescript/src/functions/assistantApis.ts @@ -36,7 +36,7 @@ app.http('CreateAssistant', { const assistantPostInput = input.generic({ type: 'assistantPost', id: '{assistantId}', - model: '%CHAT_MODEL_DEPLOYMENT_NAME%', + chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%', userMessage: '{Query.message}', chatStorageConnectionSetting: CHAT_STORAGE_CONNECTION_SETTING, collectionName: COLLECTION_NAME diff --git a/samples/chat/README.md b/samples/chat/README.md index d554a49b..1c030710 100644 --- a/samples/chat/README.md +++ b/samples/chat/README.md @@ -71,7 +71,7 @@ For additional details on using identity-based connections, refer to the [Azure public static async Task PostUserResponse( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "chats/{chatId}")] HttpRequestData req, string chatId, - [AssistantPostInput("{chatId}", "{message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%")] AssistantState state) + [AssistantPostInput("{chatId}", "{message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] AssistantState state) ``` diff --git a/samples/chat/csharp-legacy/ChatBot.cs b/samples/chat/csharp-legacy/ChatBot.cs index e7403df7..358d61d8 100644 --- a/samples/chat/csharp-legacy/ChatBot.cs +++ b/samples/chat/csharp-legacy/ChatBot.cs @@ -48,10 +48,10 @@ public static AssistantState GetChatState( } [FunctionName(nameof(PostUserResponse))] - public static async Task PostUserResponse( + public static 
IActionResult PostUserResponse( [HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = "chats/{chatId}")] HttpRequest req, string chatId, - [AssistantPost("{chatId}", "{Query.message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState updatedState) + [AssistantPost("{chatId}", "{Query.message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState updatedState) { return new OkObjectResult(updatedState.RecentMessages.Any() ? updatedState.RecentMessages[updatedState.RecentMessages.Count - 1].Content : "No response returned."); } diff --git a/samples/chat/csharp-ooproc/ChatBot.cs b/samples/chat/csharp-ooproc/ChatBot.cs index 2a94b4b4..9db059ce 100644 --- a/samples/chat/csharp-ooproc/ChatBot.cs +++ b/samples/chat/csharp-ooproc/ChatBot.cs @@ -23,8 +23,8 @@ public class CreateRequest [Function(nameof(CreateChatBot))] public static async Task CreateChatBot( - [HttpTrigger(AuthorizationLevel.Function, "put", Route = "chats/{chatId}")] HttpRequestData req, - string chatId) + [HttpTrigger(AuthorizationLevel.Function, "put", Route = "chats/{chatId}")] HttpRequestData req, + string chatId) { var responseJson = new { chatId }; CreateRequest? createRequestBody; @@ -39,7 +39,7 @@ public static async Task CreateChatBot( } catch (Exception ex) { - throw new ArgumentException("Invalid request body. Make sure that you pass in {\"instructions\": value } as the request body.", ex.Message); + throw new ArgumentException("Invalid request body. 
Make sure that you pass in {\"instructions\": value } as the request body.", ex); } return new CreateChatBotOutput @@ -63,16 +63,16 @@ public class CreateChatBotOutput } [Function(nameof(PostUserResponse))] - public static async Task PostUserResponse( + public static IActionResult PostUserResponse( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "chats/{chatId}")] HttpRequestData req, string chatId, - [AssistantPostInput("{chatId}", "{Query.message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state) + [AssistantPostInput("{chatId}", "{Query.message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state) { return new OkObjectResult(state.RecentMessages.LastOrDefault()?.Content ?? "No response returned."); } [Function(nameof(GetChatState))] - public static async Task GetChatState( + public static IActionResult GetChatState( [HttpTrigger(AuthorizationLevel.Function, "get", Route = "chats/{chatId}")] HttpRequestData req, string chatId, [AssistantQueryInput("{chatId}", TimestampUtc = "{Query.timestampUTC}", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state, diff --git a/samples/chat/csharp-ooproc/ChatBot.csproj b/samples/chat/csharp-ooproc/ChatBot.csproj index 12dd29a8..e45a1e8c 100644 --- a/samples/chat/csharp-ooproc/ChatBot.csproj +++ b/samples/chat/csharp-ooproc/ChatBot.csproj @@ -11,7 +11,7 @@ - + diff --git a/samples/chat/java/extensions.csproj b/samples/chat/java/extensions.csproj index eb243f13..ad5a93c6 100644 --- a/samples/chat/java/extensions.csproj +++ b/samples/chat/java/extensions.csproj @@ -1,6 +1,6 @@  - net60 + net80 ** target/azure-functions/azfs-java-openai-sample/bin diff --git a/samples/chat/javascript/src/app.js 
b/samples/chat/javascript/src/app.js index f0a86ddd..4b3b3409 100644 --- a/samples/chat/javascript/src/app.js +++ b/samples/chat/javascript/src/app.js @@ -52,7 +52,7 @@ app.http('GetChatState', { const assistantPostInput = input.generic({ type: 'assistantPost', id: '{chatID}', - model: '%CHAT_MODEL_DEPLOYMENT_NAME%', + chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%', userMessage: '{Query.message}', chatStorageConnectionSetting: CHAT_STORAGE_CONNECTION_SETTING, collectionName: COLLECTION_NAME diff --git a/samples/chat/powershell/PostUserResponse/function.json b/samples/chat/powershell/PostUserResponse/function.json index 16142e20..46a0ba09 100644 --- a/samples/chat/powershell/PostUserResponse/function.json +++ b/samples/chat/powershell/PostUserResponse/function.json @@ -20,7 +20,7 @@ "direction": "in", "name": "ChatBotState", "id": "{chatId}", - "model": "%CHAT_MODEL_DEPLOYMENT_NAME%", + "chatModel": "%CHAT_MODEL_DEPLOYMENT_NAME%", "userMessage": "{Query.message}", "chatStorageConnectionSetting": "AzureWebJobsStorage", "collectionName": "ChatState" diff --git a/samples/chat/powershell/extensions.csproj b/samples/chat/powershell/extensions.csproj index 8f73abd3..82ae1923 100644 --- a/samples/chat/powershell/extensions.csproj +++ b/samples/chat/powershell/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/chat/typescript/src/functions/app.ts b/samples/chat/typescript/src/functions/app.ts index 84a12075..a5cb025c 100644 --- a/samples/chat/typescript/src/functions/app.ts +++ b/samples/chat/typescript/src/functions/app.ts @@ -52,7 +52,7 @@ app.http('GetChatState', { const assistantPostInput = input.generic({ type: 'assistantPost', id: '{chatID}', - model: '%CHAT_MODEL_DEPLOYMENT_NAME%', + chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%', userMessage: '{Query.message}', chatStorageConnectionSetting: CHAT_STORAGE_CONNECTION_SETTING, collectionName: COLLECTION_NAME diff --git a/samples/embeddings/csharp-legacy/EmbeddingsLegacy.cs 
b/samples/embeddings/csharp-legacy/EmbeddingsLegacy.cs index 58688120..12baeeca 100644 --- a/samples/embeddings/csharp-legacy/EmbeddingsLegacy.cs +++ b/samples/embeddings/csharp-legacy/EmbeddingsLegacy.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using System.Linq; using Microsoft.Azure.WebJobs; using Microsoft.Azure.WebJobs.Extensions.Http; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; @@ -23,13 +22,13 @@ public record EmbeddingsRequest(string RawText, string FilePath, string Url); [FunctionName(nameof(GenerateEmbeddings_Http_Request))] public static void GenerateEmbeddings_Http_Request( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings")] EmbeddingsRequest req, - [Embeddings("{RawText}", InputType.RawText)] EmbeddingsContext embeddings, + [Embeddings("{RawText}", InputType.RawText, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings, ILogger logger) { logger.LogInformation( - "Received {count} embedding(s) for input text containing {length} characters.", - embeddings.Count, - req.RawText.Length); + "Received {count} embedding(s) for input text containing {length} characters.", + embeddings.Count, + req.RawText.Length); // TODO: Store the embeddings into a database or other storage. 
} @@ -41,7 +40,7 @@ public static void GenerateEmbeddings_Http_Request( [FunctionName(nameof(GetEmbeddings_Http_FilePath))] public static void GetEmbeddings_Http_FilePath( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings-from-file")] EmbeddingsRequest req, - [Embeddings("{FilePath}", InputType.FilePath, MaxChunkLength = 512)] EmbeddingsContext embeddings, + [Embeddings("{FilePath}", InputType.FilePath, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%", MaxChunkLength = 512)] EmbeddingsContext embeddings, ILogger logger) { logger.LogInformation( @@ -59,7 +58,7 @@ public static void GetEmbeddings_Http_FilePath( [FunctionName(nameof(GetEmbeddings_Http_Url))] public static void GetEmbeddings_Http_Url( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings-from-url")] EmbeddingsRequest req, - [Embeddings("{Url}", InputType.Url, MaxChunkLength = 512)] EmbeddingsContext embeddings, + [Embeddings("{Url}", InputType.Url, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%", MaxChunkLength = 512)] EmbeddingsContext embeddings, ILogger logger) { logger.LogInformation( diff --git a/samples/embeddings/csharp-ooproc/Embeddings/Embeddings.csproj b/samples/embeddings/csharp-ooproc/Embeddings/Embeddings.csproj index 61b2ebf6..464519f2 100644 --- a/samples/embeddings/csharp-ooproc/Embeddings/Embeddings.csproj +++ b/samples/embeddings/csharp-ooproc/Embeddings/Embeddings.csproj @@ -9,7 +9,7 @@ - + diff --git a/samples/embeddings/csharp-ooproc/Embeddings/EmbeddingsGenerator.cs b/samples/embeddings/csharp-ooproc/Embeddings/EmbeddingsGenerator.cs index 4b42dcca..84c5fa41 100644 --- a/samples/embeddings/csharp-ooproc/Embeddings/EmbeddingsGenerator.cs +++ b/samples/embeddings/csharp-ooproc/Embeddings/EmbeddingsGenerator.cs @@ -41,7 +41,7 @@ internal class EmbeddingsRequest [Function(nameof(GenerateEmbeddings_Http_RequestAsync))] public async Task GenerateEmbeddings_Http_RequestAsync( [HttpTrigger(AuthorizationLevel.Function, "post", Route = 
"embeddings")] HttpRequestData req, - [EmbeddingsInput("{rawText}", InputType.RawText, Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings) + [EmbeddingsInput("{rawText}", InputType.RawText, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings) { using StreamReader reader = new(req.Body); string request = await reader.ReadToEndAsync(); @@ -63,7 +63,7 @@ public async Task GenerateEmbeddings_Http_RequestAsync( [Function(nameof(GetEmbeddings_Http_FilePath))] public async Task GetEmbeddings_Http_FilePath( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings-from-file")] HttpRequestData req, - [EmbeddingsInput("{filePath}", InputType.FilePath, MaxChunkLength = 512, Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings) + [EmbeddingsInput("{filePath}", InputType.FilePath, MaxChunkLength = 512, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings) { using StreamReader reader = new(req.Body); string request = await reader.ReadToEndAsync(); @@ -84,7 +84,7 @@ public async Task GetEmbeddings_Http_FilePath( [Function(nameof(GetEmbeddings_Http_URL))] public async Task GetEmbeddings_Http_URL( [HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings-from-url")] HttpRequestData req, - [EmbeddingsInput("{url}", InputType.Url, MaxChunkLength = 512, Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings) + [EmbeddingsInput("{url}", InputType.Url, MaxChunkLength = 512, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings) { using StreamReader reader = new(req.Body); string request = await reader.ReadToEndAsync(); diff --git a/samples/embeddings/java/extensions.csproj b/samples/embeddings/java/extensions.csproj index eb243f13..ad5a93c6 100644 --- a/samples/embeddings/java/extensions.csproj +++ b/samples/embeddings/java/extensions.csproj @@ -1,6 +1,6 @@  - net60 + net80 ** 
target/azure-functions/azfs-java-openai-sample/bin diff --git a/samples/embeddings/javascript/src/app.js b/samples/embeddings/javascript/src/app.js index 7c53e652..6a52be79 100644 --- a/samples/embeddings/javascript/src/app.js +++ b/samples/embeddings/javascript/src/app.js @@ -4,7 +4,7 @@ const embeddingsHttpInput = input.generic({ input: '{rawText}', inputType: 'RawText', type: 'embeddings', - model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' + embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' }) app.http('generateEmbeddings', { @@ -31,7 +31,7 @@ const embeddingsFilePathInput = input.generic({ inputType: 'FilePath', type: 'embeddings', maxChunkLength: 512, - model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' + embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' }) app.http('getEmbeddingsFilePath', { @@ -58,7 +58,7 @@ const embeddingsUrlInput = input.generic({ inputType: 'Url', type: 'embeddings', maxChunkLength: 512, - model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' + embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' }) app.http('getEmbeddingsUrl', { diff --git a/samples/embeddings/powershell/GenerateEmbeddings/function.json b/samples/embeddings/powershell/GenerateEmbeddings/function.json index 57621818..0236fcdf 100644 --- a/samples/embeddings/powershell/GenerateEmbeddings/function.json +++ b/samples/embeddings/powershell/GenerateEmbeddings/function.json @@ -21,7 +21,7 @@ "direction": "in", "inputType": "RawText", "input": "{rawText}", - "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" } ] } \ No newline at end of file diff --git a/samples/embeddings/powershell/GetEmbeddingsFilePath/function.json b/samples/embeddings/powershell/GetEmbeddingsFilePath/function.json index 3209aed5..1d8d44ed 100644 --- a/samples/embeddings/powershell/GetEmbeddingsFilePath/function.json +++ b/samples/embeddings/powershell/GetEmbeddingsFilePath/function.json @@ -22,7 +22,7 @@ "inputType": "FilePath", "input": "{filePath}", "maxChunkLength": 512, 
- "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" } ] } \ No newline at end of file diff --git a/samples/embeddings/powershell/GetEmbeddingsURL/function.json b/samples/embeddings/powershell/GetEmbeddingsURL/function.json index ebf96844..de11ff4b 100644 --- a/samples/embeddings/powershell/GetEmbeddingsURL/function.json +++ b/samples/embeddings/powershell/GetEmbeddingsURL/function.json @@ -22,7 +22,7 @@ "inputType": "Url", "input": "{url}", "maxChunkLength": 512, - "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" } ] } \ No newline at end of file diff --git a/samples/embeddings/powershell/extensions.csproj b/samples/embeddings/powershell/extensions.csproj index 8f73abd3..82ae1923 100644 --- a/samples/embeddings/powershell/extensions.csproj +++ b/samples/embeddings/powershell/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/embeddings/typescript/src/app.ts b/samples/embeddings/typescript/src/app.ts index 9af76b96..f749cd0c 100644 --- a/samples/embeddings/typescript/src/app.ts +++ b/samples/embeddings/typescript/src/app.ts @@ -8,7 +8,7 @@ const embeddingsHttpInput = input.generic({ input: '{rawText}', inputType: 'RawText', type: 'embeddings', - model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' + embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' }) app.http('generateEmbeddings', { @@ -39,7 +39,7 @@ const embeddingsFilePathInput = input.generic({ inputType: 'FilePath', type: 'embeddings', maxChunkLength: 512, - model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' + embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' }) app.http('getEmbeddingsFilePath', { @@ -70,7 +70,7 @@ const embeddingsUrlInput = input.generic({ inputType: 'Url', type: 'embeddings', maxChunkLength: 512, - model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' + embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%' }) app.http('getEmbeddingsUrl', { diff --git a/samples/rag-aisearch/README.md 
b/samples/rag-aisearch/README.md index c56c2952..e9b0ce81 100644 --- a/samples/rag-aisearch/README.md +++ b/samples/rag-aisearch/README.md @@ -35,16 +35,14 @@ and optionally [enable semantic ranking](https://learn.microsoft.com/en-us/azure 1. User-Assigned Managed Identity: - // Note and TODO: Use of user-assigned managed identity for AI Search is breaking at the moment, since Azure OpenAI authentication also needs update to support a separate identity. - ```json "__endpoint": "https://.search.windows.net", "__credential": "managedidentity", "__managedIdentityResourceId": "Resource Id of managed identity", - "__managedIdentityClientId": "Client Id of managed identity" + "__clientId": "Client Id of managed identity" ``` - Only one of managedIdentityResourceId or managedIdentityClientId should be specified, not both. + Only one of managedIdentityResourceId or clientId should be specified, not both. 2. System-Assigned Managed Identity or local development: @@ -56,7 +54,7 @@ and optionally [enable semantic ranking](https://learn.microsoft.com/en-us/azure Specifying credential is optional for system assigned managed identity 4. Binding Configuration - - Pass the configured `ConnectionNamePrefix` value, example `AISearch` to the `connectionName` property in the `SemanticSearchInput` or `EmbeddingsStoreOutput` bindings. Default is `AISearchEndpoint` if just the endpoint is being configured in local.settings.json or environment variables to use DefaultAzureCredential. + Pass the configured `ConnectionNamePrefix` value, example `AISearch` to the `searchConnectionName` property in the `SemanticSearchInput` or `EmbeddingsStoreOutput` bindings. Default is `AISearchEndpoint` if just the endpoint is being configured in local.settings.json or environment variables to use DefaultAzureCredential. 
## Running the sample diff --git a/samples/rag-aisearch/csharp-legacy/FilePrompt.cs b/samples/rag-aisearch/csharp-legacy/FilePrompt.cs index df97a125..4c6b7f2f 100644 --- a/samples/rag-aisearch/csharp-legacy/FilePrompt.cs +++ b/samples/rag-aisearch/csharp-legacy/FilePrompt.cs @@ -20,7 +20,7 @@ public record SemanticSearchRequest(string Prompt); [FunctionName("IngestFile")] public static async Task IngestFile( [HttpTrigger(AuthorizationLevel.Function, "post")] EmbeddingsRequest req, - [EmbeddingsStore("{url}", InputType.Url, "AISearchEndpoint", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] + [EmbeddingsStore("{url}", InputType.Url, "AISearchEndpoint", "openai-index", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] IAsyncCollector output) { if (string.IsNullOrWhiteSpace(req.Url)) diff --git a/samples/rag-aisearch/csharp-ooproc/FilePrompt.cs b/samples/rag-aisearch/csharp-ooproc/FilePrompt.cs index d96b1d93..c9d85c6a 100644 --- a/samples/rag-aisearch/csharp-ooproc/FilePrompt.cs +++ b/samples/rag-aisearch/csharp-ooproc/FilePrompt.cs @@ -67,7 +67,7 @@ public static async Task IngestFile( public class EmbeddingsStoreOutputResponse { - [EmbeddingsStoreOutput("{url}", InputType.Url, "AISearchEndpoint", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] + [EmbeddingsStoreOutput("{url}", InputType.Url, "AISearchEndpoint", "openai-index", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] public required SearchableDocument SearchableDocument { get; init; } public IActionResult? 
HttpResponse { get; set; } diff --git a/samples/rag-aisearch/csharp-ooproc/SemanticAISearchEmbeddings.csproj b/samples/rag-aisearch/csharp-ooproc/SemanticAISearchEmbeddings.csproj index feef3710..2740dd2b 100644 --- a/samples/rag-aisearch/csharp-ooproc/SemanticAISearchEmbeddings.csproj +++ b/samples/rag-aisearch/csharp-ooproc/SemanticAISearchEmbeddings.csproj @@ -11,7 +11,7 @@ - + diff --git a/samples/rag-aisearch/java/extensions.csproj b/samples/rag-aisearch/java/extensions.csproj index fcbb0cac..29683eae 100644 --- a/samples/rag-aisearch/java/extensions.csproj +++ b/samples/rag-aisearch/java/extensions.csproj @@ -1,6 +1,6 @@  - net60 + net80 ** target/azure-functions/azfs-java-openai-sample/bin diff --git a/samples/rag-aisearch/javascript/src/app.js b/samples/rag-aisearch/javascript/src/app.js index 6bc7cc8f..c68b58eb 100644 --- a/samples/rag-aisearch/javascript/src/app.js +++ b/samples/rag-aisearch/javascript/src/app.js @@ -8,7 +8,7 @@ const embeddingsStoreOutput = output.generic({ inputType: "url", connectionName: "AISearchEndpoint", collection: "openai-index", - model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" }); app.http('IngestFile', { diff --git a/samples/rag-aisearch/powershell/IngestFile/function.json b/samples/rag-aisearch/powershell/IngestFile/function.json index 96a6a06e..be786a8b 100644 --- a/samples/rag-aisearch/powershell/IngestFile/function.json +++ b/samples/rag-aisearch/powershell/IngestFile/function.json @@ -22,7 +22,7 @@ "inputType": "Url", "connectionName": "AISearchEndpoint", "collection": "openai-index", - "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" } ] } \ No newline at end of file diff --git a/samples/rag-aisearch/powershell/extensions.csproj b/samples/rag-aisearch/powershell/extensions.csproj index 68c66ba4..1c321726 100644 --- a/samples/rag-aisearch/powershell/extensions.csproj +++ 
b/samples/rag-aisearch/powershell/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/rag-aisearch/typescript/src/app.ts b/samples/rag-aisearch/typescript/src/app.ts index ec430e24..176686d1 100644 --- a/samples/rag-aisearch/typescript/src/app.ts +++ b/samples/rag-aisearch/typescript/src/app.ts @@ -11,7 +11,7 @@ const embeddingsStoreOutput = output.generic({ inputType: "url", connectionName: "AISearchEndpoint", collection: "openai-index", - model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" }); app.http('IngestFile', { diff --git a/samples/rag-cosmosdb/csharp-legacy/FilePrompt.cs b/samples/rag-cosmosdb/csharp-legacy/FilePrompt.cs index 3411dcb0..ba4d823c 100644 --- a/samples/rag-cosmosdb/csharp-legacy/FilePrompt.cs +++ b/samples/rag-cosmosdb/csharp-legacy/FilePrompt.cs @@ -20,7 +20,7 @@ public record SemanticSearchRequest(string Prompt); [FunctionName("IngestFile")] public static async Task IngestFile( [HttpTrigger(AuthorizationLevel.Function, "post")] EmbeddingsRequest req, - [EmbeddingsStore("{url}", InputType.Url, "CosmosDBMongoVCoreConnectionString", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] + [EmbeddingsStore("{url}", InputType.Url, "CosmosDBMongoVCoreConnectionString", "openai-index", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] IAsyncCollector output) { if (string.IsNullOrWhiteSpace(req.Url)) diff --git a/samples/rag-cosmosdb/csharp-ooproc/FilePrompt.cs b/samples/rag-cosmosdb/csharp-ooproc/FilePrompt.cs index 09be8588..5e1b964d 100644 --- a/samples/rag-cosmosdb/csharp-ooproc/FilePrompt.cs +++ b/samples/rag-cosmosdb/csharp-ooproc/FilePrompt.cs @@ -66,7 +66,7 @@ public static async Task IngestFile( public class EmbeddingsStoreOutputResponse { - [EmbeddingsStoreOutput("{url}", InputType.Url, "CosmosDBMongoVCoreConnectionString", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] + [EmbeddingsStoreOutput("{url}", InputType.Url, 
"CosmosDBMongoVCoreConnectionString", "openai-index", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] public required SearchableDocument SearchableDocument { get; init; } public IActionResult? HttpResponse { get; set; } diff --git a/samples/rag-cosmosdb/csharp-ooproc/SemanticCosmosDBSearchEmbeddings.csproj b/samples/rag-cosmosdb/csharp-ooproc/SemanticCosmosDBSearchEmbeddings.csproj index 6c3e0147..9b114e87 100644 --- a/samples/rag-cosmosdb/csharp-ooproc/SemanticCosmosDBSearchEmbeddings.csproj +++ b/samples/rag-cosmosdb/csharp-ooproc/SemanticCosmosDBSearchEmbeddings.csproj @@ -11,7 +11,7 @@ - + diff --git a/samples/rag-cosmosdb/java/extensions.csproj b/samples/rag-cosmosdb/java/extensions.csproj index 010fdc9a..2dbe5f2b 100644 --- a/samples/rag-cosmosdb/java/extensions.csproj +++ b/samples/rag-cosmosdb/java/extensions.csproj @@ -1,6 +1,6 @@  - net60 + net80 ** target/azure-functions/azfs-java-openai-sample/bin diff --git a/samples/rag-cosmosdb/javascript/src/app.js b/samples/rag-cosmosdb/javascript/src/app.js index 274917c4..0116e652 100644 --- a/samples/rag-cosmosdb/javascript/src/app.js +++ b/samples/rag-cosmosdb/javascript/src/app.js @@ -8,7 +8,7 @@ const embeddingsStoreOutput = output.generic({ inputType: "url", connectionName: "CosmosDBMongoVCoreConnectionString", collection: "openai-index", - model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" }); app.http('IngestFile', { diff --git a/samples/rag-cosmosdb/powershell/IngestFile/function.json b/samples/rag-cosmosdb/powershell/IngestFile/function.json index 55b986e0..cdb86803 100644 --- a/samples/rag-cosmosdb/powershell/IngestFile/function.json +++ b/samples/rag-cosmosdb/powershell/IngestFile/function.json @@ -22,7 +22,7 @@ "inputType": "Url", "connectionName": "CosmosDBMongoVCoreConnectionString", "collection": "openai-index", - "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" } ] } \ No newline at end of 
file diff --git a/samples/rag-cosmosdb/powershell/extensions.csproj b/samples/rag-cosmosdb/powershell/extensions.csproj index bf9f1260..02592700 100644 --- a/samples/rag-cosmosdb/powershell/extensions.csproj +++ b/samples/rag-cosmosdb/powershell/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/rag-cosmosdb/python/extensions.csproj b/samples/rag-cosmosdb/python/extensions.csproj index 61aaed1f..e67d7e72 100644 --- a/samples/rag-cosmosdb/python/extensions.csproj +++ b/samples/rag-cosmosdb/python/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/rag-cosmosdb/typescript/src/app.ts b/samples/rag-cosmosdb/typescript/src/app.ts index c1808411..3ef9257e 100644 --- a/samples/rag-cosmosdb/typescript/src/app.ts +++ b/samples/rag-cosmosdb/typescript/src/app.ts @@ -11,7 +11,7 @@ const embeddingsStoreOutput = output.generic({ inputType: "url", connectionName: "CosmosDBMongoVCoreConnectionString", collection: "openai-index", - model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" }); app.http('IngestFile', { diff --git a/samples/rag-kusto/csharp-legacy/FilePrompt.cs b/samples/rag-kusto/csharp-legacy/FilePrompt.cs index 32600947..e63563e7 100644 --- a/samples/rag-kusto/csharp-legacy/FilePrompt.cs +++ b/samples/rag-kusto/csharp-legacy/FilePrompt.cs @@ -20,7 +20,7 @@ public record SemanticSearchRequest(string Prompt); [FunctionName("IngestFile")] public static async Task IngestFile( [HttpTrigger(AuthorizationLevel.Function, "post")] EmbeddingsRequest req, - [EmbeddingsStore("{url}", InputType.Url, "KustoConnectionString", "Documents", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] + [EmbeddingsStore("{url}", InputType.Url, "KustoConnectionString", "Documents", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] IAsyncCollector output) { if (string.IsNullOrWhiteSpace(req.Url)) diff --git a/samples/rag-kusto/csharp-ooproc/EmailPromptDemo.cs 
b/samples/rag-kusto/csharp-ooproc/EmailPromptDemo.cs index 7b78e9b8..747ef524 100644 --- a/samples/rag-kusto/csharp-ooproc/EmailPromptDemo.cs +++ b/samples/rag-kusto/csharp-ooproc/EmailPromptDemo.cs @@ -66,7 +66,7 @@ public async Task IngestEmail( } public class EmbeddingsStoreOutputResponse { - [EmbeddingsStoreOutput("{url}", InputType.Url, "KustoConnectionString", "Documents", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] + [EmbeddingsStoreOutput("{url}", InputType.Url, "KustoConnectionString", "Documents", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] public required SearchableDocument SearchableDocument { get; init; } [HttpResult] diff --git a/samples/rag-kusto/csharp-ooproc/SemanticSearchEmbeddings.csproj b/samples/rag-kusto/csharp-ooproc/SemanticSearchEmbeddings.csproj index cb5db855..71a9aac5 100644 --- a/samples/rag-kusto/csharp-ooproc/SemanticSearchEmbeddings.csproj +++ b/samples/rag-kusto/csharp-ooproc/SemanticSearchEmbeddings.csproj @@ -10,7 +10,7 @@ - + diff --git a/samples/rag-kusto/java/extensions.csproj b/samples/rag-kusto/java/extensions.csproj index e2323ac5..3fded17b 100644 --- a/samples/rag-kusto/java/extensions.csproj +++ b/samples/rag-kusto/java/extensions.csproj @@ -1,6 +1,6 @@  - net60 + net80 ** target/azure-functions/azfs-java-openai-sample/bin diff --git a/samples/rag-kusto/javascript/src/app.js b/samples/rag-kusto/javascript/src/app.js index 655e64e6..87b948d1 100644 --- a/samples/rag-kusto/javascript/src/app.js +++ b/samples/rag-kusto/javascript/src/app.js @@ -8,7 +8,7 @@ const embeddingsStoreOutput = output.generic({ inputType: "url", connectionName: "KustoConnectionString", collection: "Documents", - model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" }); app.http('IngestEmail', { diff --git a/samples/rag-kusto/powershell/IngestEmail/function.json b/samples/rag-kusto/powershell/IngestEmail/function.json index 62ad48c4..31bd2eaa 100644 --- 
a/samples/rag-kusto/powershell/IngestEmail/function.json +++ b/samples/rag-kusto/powershell/IngestEmail/function.json @@ -22,7 +22,7 @@ "inputType": "Url", "connectionName": "KustoConnectionString", "collection": "Documents", - "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" } ] } \ No newline at end of file diff --git a/samples/rag-kusto/powershell/extensions.csproj b/samples/rag-kusto/powershell/extensions.csproj index c8f3ee70..08c2647b 100644 --- a/samples/rag-kusto/powershell/extensions.csproj +++ b/samples/rag-kusto/powershell/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/rag-kusto/python/extensions.csproj b/samples/rag-kusto/python/extensions.csproj index 1346f6ff..05d7cdf4 100644 --- a/samples/rag-kusto/python/extensions.csproj +++ b/samples/rag-kusto/python/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/rag-kusto/typescript/src/app.ts b/samples/rag-kusto/typescript/src/app.ts index bd0df1d8..bfc17d70 100644 --- a/samples/rag-kusto/typescript/src/app.ts +++ b/samples/rag-kusto/typescript/src/app.ts @@ -11,7 +11,7 @@ const embeddingsStoreOutput = output.generic({ inputType: "url", connectionName: "KustoConnectionString", collection: "Documents", - model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" + embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%" }); app.http('IngestEmail', { diff --git a/samples/textcompletion/csharp-legacy/TextCompletionLegacy.cs b/samples/textcompletion/csharp-legacy/TextCompletionLegacy.cs index 12d96f83..95fa8567 100644 --- a/samples/textcompletion/csharp-legacy/TextCompletionLegacy.cs +++ b/samples/textcompletion/csharp-legacy/TextCompletionLegacy.cs @@ -24,7 +24,7 @@ public static class TextCompletionLegacy [FunctionName(nameof(WhoIs))] public static IActionResult WhoIs( [HttpTrigger(AuthorizationLevel.Function, Route = "whois/{name}")] HttpRequest req, - [TextCompletion("Who is {name}?", Model = 
"%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response) + [TextCompletion("Who is {name}?", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response) { return new OkObjectResult(response.Content); } @@ -36,7 +36,7 @@ public static IActionResult WhoIs( [FunctionName(nameof(GenericCompletion))] public static IActionResult GenericCompletion( [HttpTrigger(AuthorizationLevel.Function, "post")] PromptPayload payload, - [TextCompletion("{Prompt}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response, + [TextCompletion("{Prompt}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response, ILogger log) { string text = response.Content; diff --git a/samples/textcompletion/csharp-ooproc/TextCompletion.csproj b/samples/textcompletion/csharp-ooproc/TextCompletion.csproj index 74c6a35e..2b9b00ec 100644 --- a/samples/textcompletion/csharp-ooproc/TextCompletion.csproj +++ b/samples/textcompletion/csharp-ooproc/TextCompletion.csproj @@ -11,7 +11,7 @@ - + diff --git a/samples/textcompletion/csharp-ooproc/TextCompletions.cs b/samples/textcompletion/csharp-ooproc/TextCompletions.cs index a1a1ef7e..4bdbae7a 100644 --- a/samples/textcompletion/csharp-ooproc/TextCompletions.cs +++ b/samples/textcompletion/csharp-ooproc/TextCompletions.cs @@ -23,7 +23,7 @@ public static class TextCompletions [Function(nameof(WhoIs))] public static IActionResult WhoIs( [HttpTrigger(AuthorizationLevel.Function, Route = "whois/{name}")] HttpRequestData req, - [TextCompletionInput("Who is {name}?", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response) + [TextCompletionInput("Who is {name}?", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response) { return new OkObjectResult(response.Content); } @@ -35,7 +35,7 @@ public static IActionResult WhoIs( [Function(nameof(GenericCompletion))] public static IActionResult GenericCompletion( [HttpTrigger(AuthorizationLevel.Function, "post")] HttpRequestData req, 
- [TextCompletionInput("{Prompt}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response, + [TextCompletionInput("{Prompt}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response, ILogger log) { string text = response.Content; diff --git a/samples/textcompletion/java/extensions.csproj b/samples/textcompletion/java/extensions.csproj index eb243f13..ad5a93c6 100644 --- a/samples/textcompletion/java/extensions.csproj +++ b/samples/textcompletion/java/extensions.csproj @@ -1,6 +1,6 @@  - net60 + net80 ** target/azure-functions/azfs-java-openai-sample/bin diff --git a/samples/textcompletion/javascript/src/functions/whois.js b/samples/textcompletion/javascript/src/functions/whois.js index 665af065..96725512 100644 --- a/samples/textcompletion/javascript/src/functions/whois.js +++ b/samples/textcompletion/javascript/src/functions/whois.js @@ -5,7 +5,7 @@ const openAICompletionInput = input.generic({ prompt: 'Who is {name}?', maxTokens: '100', type: 'textCompletion', - model: '%CHAT_MODEL_DEPLOYMENT_NAME%' + chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%' }) app.http('whois', { diff --git a/samples/textcompletion/powershell/WhoIs/function.json b/samples/textcompletion/powershell/WhoIs/function.json index bdc381c7..6abf6c9b 100644 --- a/samples/textcompletion/powershell/WhoIs/function.json +++ b/samples/textcompletion/powershell/WhoIs/function.json @@ -21,7 +21,7 @@ "name": "TextCompletionResponse", "prompt": "Who is {name}?", "maxTokens": "100", - "model": "%CHAT_MODEL_DEPLOYMENT_NAME%" + "chatModel": "%CHAT_MODEL_DEPLOYMENT_NAME%" } ] } \ No newline at end of file diff --git a/samples/textcompletion/powershell/extensions.csproj b/samples/textcompletion/powershell/extensions.csproj index 8f73abd3..82ae1923 100644 --- a/samples/textcompletion/powershell/extensions.csproj +++ b/samples/textcompletion/powershell/extensions.csproj @@ -1,6 +1,6 @@ - net60 + net80 ** bin diff --git a/samples/textcompletion/typescript/src/functions/whois.ts 
b/samples/textcompletion/typescript/src/functions/whois.ts index 22f8c2ee..b910f51c 100644 --- a/samples/textcompletion/typescript/src/functions/whois.ts +++ b/samples/textcompletion/typescript/src/functions/whois.ts @@ -5,7 +5,7 @@ const openAICompletionInput = input.generic({ prompt: 'Who is {name}?', maxTokens: '100', type: 'textCompletion', - model: '%CHAT_MODEL_DEPLOYMENT_NAME%' + chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%' }) app.http('whois', { diff --git a/src/Directory.Build.props b/src/Directory.Build.props index 91466cb7..9f4ab3ef 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -33,7 +33,7 @@ $(VersionPrefix).$(FileVersionRevision) - 0.16.0-alpha + 0.17.0-alpha 0.4.0-alpha diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantMessage.cs similarity index 71% rename from src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs rename to src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantMessage.cs index 585fc205..cdad6ef4 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantMessage.cs @@ -8,18 +8,19 @@ namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Assistants; /// /// Chat Message Entity which contains the content of the message, the role of the chat agent, and the name of the calling function if applicable. /// -public class ChatMessage +public class AssistantMessage { /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The content of the message. /// The role of the chat agent. - public ChatMessage(string content, string role, string? name) + /// The tool calls. 
+ public AssistantMessage(string content, string role, string toolCalls) { this.Content = content; this.Role = role; - this.Name = name; + this.ToolCalls = toolCalls; } /// @@ -35,8 +36,8 @@ public ChatMessage(string content, string role, string? name) public string Role { get; set; } /// - /// Gets or sets the name of the calling function if applicable. + /// Gets or sets the tool calls. /// - [JsonPropertyName("name")] - public string? Name { get; set; } + [JsonPropertyName("toolCalls")] + public string ToolCalls { get; set; } } diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs index fda594e7..c29b462f 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs @@ -10,12 +10,34 @@ namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Assistants; /// public sealed class AssistantPostInputAttribute : InputBindingAttribute { - public AssistantPostInputAttribute(string id, string UserMessage) + /// + /// Initializes a new instance of the class. + /// + /// The assistant identifier. + /// The user message. + public AssistantPostInputAttribute(string id, string userMessage) { this.Id = id; - this.UserMessage = UserMessage; + this.UserMessage = userMessage; } + /// + /// Gets or sets the name of the configuration section for AI service connectivity settings. + /// + /// + /// This property specifies the name of the configuration section that contains connection details for the AI service. 
+ /// + /// For Azure OpenAI: + /// - If specified, looks for "Endpoint" and "Key" values in this configuration section + /// - If not specified or the section doesn't exist, falls back to environment variables: + /// AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY + /// - For user-assigned managed identity authentication, this property is required + /// + /// For OpenAI: + /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. + /// + public string AIConnectionName { get; set; } = ""; + /// /// Gets the ID of the assistant to update. /// @@ -27,7 +49,7 @@ public AssistantPostInputAttribute(string id, string UserMessage) /// /// When using Azure OpenAI, then should be the name of the model deployment. /// - public string? Model { get; set; } + public string? ChatModel { get; set; } /// /// Gets user message that user has entered for assistant to respond to. @@ -43,4 +65,40 @@ public AssistantPostInputAttribute(string id, string UserMessage) /// Table collection name for chat storage. /// public string CollectionName { get; set; } = "ChatState"; + + /// + /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output + /// more random, while lower values like 0.2 will make it more focused and deterministic. + /// + /// + /// It's generally recommend to use this or but not both. + /// + public string? Temperature { get; set; } = "0.5"; + + /// + /// Gets or sets an alternative to sampling with temperature, called nucleus sampling, where the model considers + /// the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + /// probability mass are considered. + /// + /// + /// It's generally recommend to use this or but not both. + /// + public string? TopP { get; set; } + + /// + /// Gets or sets the maximum number of tokens to output in the completion. Default is 100. 
+ /// + /// + /// The token count of your prompt plus max_tokens cannot exceed the model's context length. + /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096). + /// + public string? MaxTokens { get; set; } = "100"; + + /// + /// Gets or sets a value indicating whether the model is a reasoning model. + /// + /// + /// Warning: This is experimental and associated with the reasoning model until all models have parity in the expected properties. + /// + public bool IsReasoningModel { get; set; } } diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs index 6c6787ec..410eb73b 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs @@ -37,12 +37,4 @@ public AssistantSkillTriggerAttribute(string functionDescription) /// https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools. /// public string? ParameterDescriptionJson { get; set; } - - /// - /// Gets or sets the OpenAI chat model to use. - /// - /// - /// When using Azure OpenAI, then should be the name of the model deployment. - /// - public string Model { get; set; } = OpenAIModels.DefaultChatModel; } \ No newline at end of file diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantState.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantState.cs index 1ef18d72..7570ba87 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantState.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantState.cs @@ -50,5 +50,5 @@ public class AssistantState /// Gets a list of the recent messages from the assistant. 
/// [JsonPropertyName("recentMessages")] - public IReadOnlyList RecentMessages { get; set; } = Array.Empty(); + public IReadOnlyList RecentMessages { get; set; } = Array.Empty(); } diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs index a73974fa..f8e8e676 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs @@ -1,23 +1,27 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using Azure.AI.OpenAI; using System.ClientModel.Primitives; using System.Text.Json; using System.Text.Json.Serialization; +using OpenAI.Chat; namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Assistants; -public class ChatCompletionsJsonConverter : JsonConverter +public class ChatCompletionJsonConverter : JsonConverter { static readonly ModelReaderWriterOptions modelReaderWriterOptions = new("J"); - public override ChatCompletions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override ChatCompletion Read( + ref Utf8JsonReader reader, + Type typeToConvert, + JsonSerializerOptions options) { using JsonDocument jsonDocument = JsonDocument.ParseValue(ref reader); - return ModelReaderWriter.Read(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!; + return ModelReaderWriter.Read( + BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!; } - public override void Write(Utf8JsonWriter writer, ChatCompletions value, JsonSerializerOptions options) + public override void Write(Utf8JsonWriter writer, ChatCompletion value, JsonSerializerOptions options) { - ((IJsonModel)value).Write(writer, modelReaderWriterOptions); + ((IJsonModel)value).Write(writer, modelReaderWriterOptions); } } \ No newline at end of file diff --git 
a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs index 21259707..2f7dfa45 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs @@ -1,17 +1,13 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using OpenAISDK = Azure.AI.OpenAI; +using OpenAI.Embeddings; + namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; public class EmbeddingsContext { - /// - /// Binding target for the . - /// - /// The embeddings request that was sent to OpenAI. - /// The embeddings response that was received from OpenAI. - public EmbeddingsContext(OpenAISDK.EmbeddingsOptions Request, OpenAISDK.Embeddings Response) + public EmbeddingsContext(IList Request, OpenAIEmbeddingCollection? Response) { this.Request = Request; this.Response = Response; @@ -20,15 +16,15 @@ public EmbeddingsContext(OpenAISDK.EmbeddingsOptions Request, OpenAISDK.Embeddin /// /// Embeddings request sent to OpenAI. /// - public OpenAISDK.EmbeddingsOptions Request { get; set; } + public IList Request { get; set; } /// /// Embeddings response from OpenAI. /// - public OpenAISDK.Embeddings Response { get; set; } - + public OpenAIEmbeddingCollection? Response { get; set; } + /// /// Gets the number of embeddings that were returned in the response. /// - public int Count => this.Response.Data?.Count ?? 0; + public int Count => this.Response?.Count ?? 
0; } diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs index f9a2d073..a5943c6e 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs @@ -8,7 +8,7 @@ namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; public class EmbeddingsInputAttribute : InputBindingAttribute { /// - /// Initializes a new instance of the class with the specified input. + /// Initializes a new instance of the class with the specified input. /// /// The input source containing the data to generate embeddings for. /// The type of the input. @@ -19,13 +19,30 @@ public EmbeddingsInputAttribute(string input, InputType inputType) this.InputType = inputType; } + /// + /// Gets or sets the name of the configuration section for AI service connectivity settings. + /// + /// + /// This property specifies the name of the configuration section that contains connection details for the AI service. + /// + /// For Azure OpenAI: + /// - If specified, looks for "Endpoint" and "Key" values in this configuration section + /// - If not specified or the section doesn't exist, falls back to environment variables: + /// AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY + /// - For user-assigned managed identity authentication, this property is required + /// + /// For OpenAI: + /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. + /// + public string AIConnectionName { get; set; } = ""; + /// /// Gets or sets the ID of the model to use. /// /// /// Changing the default embeddings model is a breaking change, since any changes will be stored in a vector database for lookup. Changing the default model can cause the lookups to start misbehaving if they don't match the data that was previously ingested into the vector database. 
/// - public string Model { get; set; } = OpenAIModels.DefaultEmbeddingsModel; + public string EmbeddingsModel { get; set; } = OpenAIModels.DefaultEmbeddingsModel; /// /// Gets or sets the maximum number of characters to chunk the input into. diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsJsonConverter.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsJsonConverter.cs index 30a3679a..cfcc5d23 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsJsonConverter.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsJsonConverter.cs @@ -4,25 +4,25 @@ using System.ClientModel.Primitives; using System.Text.Json; using System.Text.Json.Serialization; -using OpenAISDK = Azure.AI.OpenAI; +using OpenAI.Embeddings; namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; /// -/// Embeddings JSON converter needed to serialize and deserialize the Embeddings object with the dotnet worker. +/// OpenAIEmbeddingCollection JSON converter needed to serialize and deserialize the OpenAIEmbeddingCollection object with the dotnet worker. 
/// -class EmbeddingsJsonConverter : JsonConverter +class EmbeddingsJsonConverter : JsonConverter { static readonly ModelReaderWriterOptions JsonOptions = new("J"); - public override OpenAISDK.Embeddings Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override OpenAIEmbeddingCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { using JsonDocument jsonDocument = JsonDocument.ParseValue(ref reader); - return ModelReaderWriter.Read(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!; + return ModelReaderWriter.Read(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!; } - public override void Write(Utf8JsonWriter writer, OpenAISDK.Embeddings value, JsonSerializerOptions options) + public override void Write(Utf8JsonWriter writer, OpenAIEmbeddingCollection value, JsonSerializerOptions options) { - ((IJsonModel)value).Write(writer, JsonOptions); + ((IJsonModel)value).Write(writer, JsonOptions); } } diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs index 0a8949dc..9ca6bdf8 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs @@ -4,25 +4,25 @@ using System.ClientModel.Primitives; using System.Text.Json; using System.Text.Json.Serialization; -using Azure.AI.OpenAI; +using OpenAI.Embeddings; namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; /// /// EmbeddingsOptions JSON converter needed to serialize and deserialize the EmbeddingsOptions object with the dotnet worker. 
/// -class EmbeddingsOptionsJsonConverter : JsonConverter +class EmbeddingsOptionsJsonConverter : JsonConverter { static readonly ModelReaderWriterOptions JsonOptions = new("J"); - public override EmbeddingsOptions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override EmbeddingGenerationOptions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { using JsonDocument jsonDocument = JsonDocument.ParseValue(ref reader); - return ModelReaderWriter.Read(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!; + return ModelReaderWriter.Read(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!; } - public override void Write(Utf8JsonWriter writer, EmbeddingsOptions value, JsonSerializerOptions options) + public override void Write(Utf8JsonWriter writer, EmbeddingGenerationOptions value, JsonSerializerOptions options) { - ((IJsonModel)value).Write(writer, JsonOptions); + ((IJsonModel)value).Write(writer, JsonOptions); } } diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs index 4be0d841..9adde53b 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs @@ -14,28 +14,45 @@ public class EmbeddingsStoreOutputAttribute : OutputBindingAttribute /// The input source containing the data to generate embeddings for /// and is interpreted based on the value for . /// The type of the input. - /// - /// The name of an app setting or environment variable which contains a connection string value. + /// + /// The name of an app setting or environment variable which contains a connection string value for embedding store. /// /// The name of the collection or table to search or store. /// - /// Thrown if or or are null. 
+ /// Thrown if or or are null. /// - public EmbeddingsStoreOutputAttribute(string input, InputType inputType, string connectionName, string collection) + public EmbeddingsStoreOutputAttribute(string input, InputType inputType, string storeConnectionName, string collection) { this.Input = input ?? throw new ArgumentNullException(nameof(input)); this.InputType = inputType; - this.ConnectionName = connectionName ?? throw new ArgumentNullException(nameof(connectionName)); + this.StoreConnectionName = storeConnectionName ?? throw new ArgumentNullException(nameof(storeConnectionName)); this.Collection = collection ?? throw new ArgumentNullException(nameof(collection)); } + /// + /// Gets or sets the name of the configuration section for AI service connectivity settings. + /// + /// + /// This property specifies the name of the configuration section that contains connection details for the AI service. + /// + /// For Azure OpenAI: + /// - If specified, looks for "Endpoint" and "Key" values in this configuration section + /// - If not specified or the section doesn't exist, falls back to environment variables: + /// AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY + /// - For user-assigned managed identity authentication, configuration section is required + /// + /// For OpenAI: + /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. + /// + public string AIConnectionName { get; set; } = ""; + /// /// Gets or sets the ID of the model to use. /// /// /// Changing the default embeddings model is a breaking change, since any changes will be stored in a vector database for lookup. Changing the default model can cause the lookups to start misbehaving if they don't match the data that was previously ingested into the vector database. 
/// - public string Model { get; set; } = OpenAIModels.DefaultEmbeddingsModel; + public string EmbeddingsModel { get; set; } = OpenAIModels.DefaultEmbeddingsModel; /// /// Gets or sets the maximum number of characters to chunk the input into. @@ -70,7 +87,7 @@ public EmbeddingsStoreOutputAttribute(string input, InputType inputType, string /// /// This property supports binding expressions. /// - public string ConnectionName { get; set; } + public string StoreConnectionName { get; set; } /// /// The name of the collection or table to search. diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs new file mode 100644 index 00000000..242a8fe0 --- /dev/null +++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs @@ -0,0 +1,49 @@ +using System.ClientModel.Primitives; +using System.Text.Json; + +namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; + +class JsonModelListWrapper : IJsonModel> +{ + readonly List list; + + public JsonModelListWrapper(List list) + { + this.list = list; + } + + public void Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + writer.WriteStartArray(); + foreach (string item in this.list) + { + writer.WriteStringValue(item); + } + writer.WriteEndArray(); + } + + public static JsonModelListWrapper FromList(List list) + { + return new JsonModelListWrapper(list); + } + + public List Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public BinaryData Write(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public List Create(BinaryData data, ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public string GetFormatFromOptions(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } +} \ No newline at end of file diff --git 
a/src/Functions.Worker.Extensions.OpenAI/Functions.Worker.Extensions.OpenAI.csproj b/src/Functions.Worker.Extensions.OpenAI/Functions.Worker.Extensions.OpenAI.csproj index a01f9810..aac240ec 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Functions.Worker.Extensions.OpenAI.csproj +++ b/src/Functions.Worker.Extensions.OpenAI/Functions.Worker.Extensions.OpenAI.csproj @@ -7,7 +7,7 @@ - + diff --git a/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs b/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs index 959912a5..5b6804e2 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs @@ -4,7 +4,8 @@ using System.ClientModel.Primitives; using System.Text.Json; using System.Text.Json.Serialization; -using OpenAISDK = Azure.AI.OpenAI; +using Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; +using OpenAI.Embeddings; namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Search; @@ -23,12 +24,13 @@ public override void Write(Utf8JsonWriter writer, SearchableDocument value, Json writer.WritePropertyName("embeddingsContext"u8); writer.WriteStartObject(); - if (value.EmbeddingsContext?.Request is IJsonModel request) + if (value.EmbeddingsContext?.Request is List inputList) { writer.WritePropertyName("request"u8); - request.Write(writer, modelReaderWriterOptions); + var inputWrapper = JsonModelListWrapper.FromList(inputList); + inputWrapper.Write(writer, modelReaderWriterOptions); } - if (value.EmbeddingsContext?.Response is IJsonModel response) + if (value.EmbeddingsContext?.Response is IJsonModel response) { writer.WritePropertyName("response"u8); response.Write(writer, modelReaderWriterOptions); diff --git a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchContext.cs b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchContext.cs index 
b40a0dc8..8fb1788d 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchContext.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchContext.cs @@ -2,8 +2,8 @@ // Licensed under the MIT License. using System.Text.Json.Serialization; -using Azure.AI.OpenAI; using Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings; +using OpenAI.Chat; namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Search; @@ -17,11 +17,11 @@ public class SemanticSearchContext /// /// The embeddings context associated with the semantic search. /// The chat response from the large language model. - public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletions Chat) + public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletion Chat) { this.Embeddings = Embeddings; this.Chat = Chat; - + } /// @@ -34,12 +34,11 @@ public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletions Chat) /// Gets the chat response from the large language model. /// [JsonPropertyName("chat")] - public ChatCompletions Chat { get; } - + public ChatCompletion Chat { get; } /// /// Gets the latest response message from the OpenAI Chat API. /// [JsonPropertyName("response")] - public string Response => this.Chat.Choices.Last().Message.Content; - } + public string Response => this.Chat.Content.Last().Text; +} diff --git a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs index e660d709..1b679607 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs @@ -14,26 +14,43 @@ public sealed class SemanticSearchInputAttribute : InputBindingAttribute /// Initializes a new instance of the class with the specified connection /// and collection names. 
/// - /// + /// /// The name of an app setting or environment variable which contains a connection string value. /// /// The name of the collection or table to search or store. /// - /// Thrown if either or are null. + /// Thrown if either or are null. /// - public SemanticSearchInputAttribute(string connectionName, string collection) + public SemanticSearchInputAttribute(string searchConnectionName, string collection) { - this.ConnectionName = connectionName ?? throw new ArgumentNullException(nameof(connectionName)); + this.SearchConnectionName = searchConnectionName ?? throw new ArgumentNullException(nameof(searchConnectionName)); this.Collection = collection ?? throw new ArgumentNullException(nameof(collection)); } + /// + /// Gets or sets the name of the configuration section for AI service connectivity settings. + /// + /// + /// This property specifies the name of the configuration section that contains connection details for the AI service. + /// + /// For Azure OpenAI: + /// - If specified, looks for "Endpoint" and "Key" values in this configuration section + /// - If not specified or the section doesn't exist, falls back to environment variables: + /// AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY + /// - For user-assigned managed identity authentication, this property is required + /// + /// For OpenAI: + /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. + /// + public string AIConnectionName { get; set; } = ""; + /// /// Gets or sets the name of an app setting or environment variable which contains a connection string value. /// /// /// This property supports binding expressions. /// - public string ConnectionName { get; set; } + public string SearchConnectionName { get; set; } /// /// The name of the collection or table to search. @@ -98,4 +115,40 @@ The following is a list of documents that you can refer to when answering questi /// Gets or sets the number of knowledge items to inject into the . 
/// public int MaxKnowledgeCount { get; set; } = 1; + + /// + /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output + /// more random, while lower values like 0.2 will make it more focused and deterministic. + /// + /// + /// It's generally recommend to use this or but not both. + /// + public string? Temperature { get; set; } = "0.5"; + + /// + /// Gets or sets an alternative to sampling with temperature, called nucleus sampling, where the model considers + /// the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + /// probability mass are considered. + /// + /// + /// It's generally recommend to use this or but not both. + /// + public string? TopP { get; set; } + + /// + /// Gets or sets the maximum number of tokens to output in the completion. Default is 2048. + /// + /// + /// The token count of your prompt plus max_tokens cannot exceed the model's context length. + /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096). + /// + public string? MaxTokens { get; set; } = "2048"; + + /// + /// Gets or sets a value indicating whether the model is a reasoning model. + /// + /// + /// Warning: This is experimental and associated with the reasoning model until all models have parity in the expected properties. 
+ /// + public bool IsReasoningModel { get; set; } } diff --git a/src/Functions.Worker.Extensions.OpenAI/Startup.cs b/src/Functions.Worker.Extensions.OpenAI/Startup.cs index 8bbd524b..e4ad38a9 100644 --- a/src/Functions.Worker.Extensions.OpenAI/Startup.cs +++ b/src/Functions.Worker.Extensions.OpenAI/Startup.cs @@ -25,7 +25,7 @@ public override void Configure(IFunctionsWorkerApplicationBuilder applicationBui jsonSerializerOptions.Converters.Add(new EmbeddingsJsonConverter()); jsonSerializerOptions.Converters.Add(new EmbeddingsOptionsJsonConverter()); jsonSerializerOptions.Converters.Add(new SearchableDocumentJsonConverter()); - jsonSerializerOptions.Converters.Add(new ChatCompletionsJsonConverter()); + jsonSerializerOptions.Converters.Add(new ChatCompletionJsonConverter()); }); } } diff --git a/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs index 463d1e96..d9eaa985 100644 --- a/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs +++ b/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs @@ -19,6 +19,23 @@ public TextCompletionInputAttribute(string prompt) this.Prompt = prompt ?? throw new ArgumentNullException(nameof(prompt)); } + /// + /// Gets or sets the name of the configuration section for AI service connectivity settings. + /// + /// + /// This property specifies the name of the configuration section that contains connection details for the AI service. 
+ /// + /// For Azure OpenAI: + /// - If specified, looks for "Endpoint" and "Key" values in this configuration section + /// - If not specified or the section doesn't exist, falls back to environment variables: + /// AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY + /// - For user-assigned managed identity authentication, this property is required + /// + /// For OpenAI: + /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. + /// + public string AIConnectionName { get; set; } = ""; + /// /// Gets or sets the prompt to generate completions for, encoded as a string. /// @@ -27,7 +44,7 @@ public TextCompletionInputAttribute(string prompt) /// /// Gets or sets the ID of the model to use. /// - public string Model { get; set; } = "gpt-3.5-turbo"; + public string ChatModel { get; set; } = "gpt-3.5-turbo"; /// /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output @@ -56,4 +73,12 @@ public TextCompletionInputAttribute(string prompt) /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096). /// public string? MaxTokens { get; set; } = "100"; + + /// + /// Gets or sets a value indicating whether the model is a reasoning model. + /// + /// + /// Warning: This is experimental and associated with the reasoning model until all models have parity in the expected properties. 
+ /// + public bool IsReasoningModel { get; set; } } diff --git a/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs b/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs index 4ad8280e..4c4d0134 100644 --- a/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs +++ b/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs @@ -246,16 +246,16 @@ async Task IndexSectionsAsync(SearchClient searchClient, SearchableDocument docu { int iteration = 0; IndexDocumentsBatch batch = new(); - for (int i = 0; i < document.Embeddings?.Response?.Data.Count; i++) + for (int i = 0; i < document.Embeddings?.Response?.Count; i++) { batch.Actions.Add(new IndexDocumentsAction( IndexActionType.MergeOrUpload, new SearchDocument { ["id"] = Guid.NewGuid().ToString("N"), - ["text"] = document.Embeddings.Request.Input![i], + ["text"] = document.Embeddings.Request![i], ["title"] = Path.GetFileNameWithoutExtension(document.Title), - ["embeddings"] = document.Embeddings.Response.Data[i].Embedding.ToArray() ?? Array.Empty(), + ["embeddings"] = document.Embeddings.Response[i].ToFloats().ToArray() ?? 
Array.Empty(), ["timestamp"] = DateTime.UtcNow })); iteration++; diff --git a/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md b/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md index f53671d3..5e1f2299 100644 --- a/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md +++ b/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Managed identity support and consistency established with other Azure Functions extensions +### Changed + +- Updated Microsoft.Azure.WebJobs.Extensions.OpenAI to 0.19.0 + ## v0.3.0 - 2024/10/08 ### Changed diff --git a/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CHANGELOG.md b/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CHANGELOG.md index bdef6010..187ff9f1 100644 --- a/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CHANGELOG.md +++ b/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CHANGELOG.md @@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added DiskANN support for CosmosDB (MongoDB) Search Provider. Refer [README](../../samples/rag-cosmosdb/README.md) for more information on usage. +### Changed + +- Updated Microsoft.Azure.WebJobs.Extensions.OpenAI to 0.19.0 + ## v0.3.0 - 2024/10/08 ### Added @@ -18,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added HNSW support for CosmosDB (MongoDB) Search Provider. Refer [README](../../samples/rag-cosmosdb/README.md) for more information on usage. 
### Changed + - Updated nuget dependencies ## v0.2.0 - 2024/05/06 diff --git a/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs b/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs index d4bc4870..82ad69dc 100644 --- a/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs +++ b/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs @@ -212,7 +212,7 @@ public void CreateVectorIndexIfNotExists(MongoClient cosmosClient) async Task UpsertVectorAsync(MongoClient cosmosClient, SearchableDocument document) { List list = new(); - for (int i = 0; i < document.Embeddings?.Response?.Data.Count; i++) + for (int i = 0; i < document.Embeddings?.Response?.Count; i++) { BsonDocument vectorDocument = new() @@ -220,15 +220,14 @@ async Task UpsertVectorAsync(MongoClient cosmosClient, SearchableDocument docume { "id", Guid.NewGuid().ToString("N") }, { this.cosmosDBSearchConfigOptions.Value.TextKey, - document.Embeddings.Request.Input![i] + document.Embeddings.Request![i] }, { "title", Path.GetFileNameWithoutExtension(document.Title) }, { this.cosmosDBSearchConfigOptions.Value.EmbeddingKey, new BsonArray( document - .Embeddings.Response.Data[i] - .Embedding.ToArray() + .Embeddings.Response[i].ToFloats().ToArray() .Select(e => new BsonDouble(Convert.ToDouble(e))) ) }, diff --git a/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md b/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md index 51430935..a7c67706 100644 --- a/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md +++ b/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+## v0.17.0 - Unreleased + +- Updated Microsoft.Azure.WebJobs.Extensions.OpenAI to 0.19.0 + ## v0.16.0 - 2024/10/08 ### Changed diff --git a/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs b/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs index 542454a3..fe46f501 100644 --- a/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs +++ b/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs @@ -78,13 +78,13 @@ public async Task AddDocumentAsync(SearchableDocument document, CancellationToke table.AppendColumn("Embeddings", typeof(object)); table.AppendColumn("Timestamp", typeof(DateTime)); - for (int i = 0; i < document.Embeddings?.Response?.Data.Count; i++) + for (int i = 0; i < document.Embeddings?.Response?.Count; i++) { table.Rows.Add( Guid.NewGuid().ToString("N"), Path.GetFileNameWithoutExtension(document.Title), - document.Embeddings.Request.Input![i], - GetEmbeddingsString(document.Embeddings.Response.Data[i].Embedding, true), + document.Embeddings.Request![i], + GetEmbeddingsString(document.Embeddings.Response[i].ToFloats().ToArray(), true), DateTime.UtcNow); } diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs new file mode 100644 index 00000000..9c8a4c21 --- /dev/null +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs @@ -0,0 +1,122 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using Microsoft.Azure.WebJobs.Description; +using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; +using OpenAI.Chat; + +namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; + +/// +/// Binding attribute for the OpenAI Assistant extension. This attribute is used to specify the configuration +/// settings for the OpenAI Assistant when used in a function parameter. It allows to set various chat completion options. 
+/// +[Binding] +[AttributeUsage(AttributeTargets.Parameter)] +public class AssistantBaseAttribute : Attribute +{ + /// + /// Gets or sets the name of the Large Language Model to invoke for chat responses. + /// The default value is "gpt-3.5-turbo". + /// + /// + /// This property supports binding expressions. + /// + [AutoResolve] + public string ChatModel { get; set; } = OpenAIModels.DefaultChatModel; + + /// + /// Gets or sets the name of the configuration section for AI service connectivity settings. + /// + /// + /// This property specifies the name of the configuration section that contains connection details for the AI service. + /// + /// For Azure OpenAI: + /// - If specified, looks for "Endpoint" and "Key" values in this configuration section + /// - If not specified or the section doesn't exist, falls back to environment variables: + /// AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY + /// - For user-assigned managed identity authentication, configuration section is required + /// + /// For OpenAI: + /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. + /// + public string AIConnectionName { get; set; } = ""; + + /// + /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output + /// more random, while lower values like 0.2 will make it more focused and deterministic. + /// + /// + /// It's generally recommend to use this or but not both. + /// + [AutoResolve] + public string? Temperature { get; set; } = "0.5"; + + /// + /// Gets or sets an alternative to sampling with temperature, called nucleus sampling, where the model considers + /// the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + /// probability mass are considered. + /// + /// + /// It's generally recommend to use this or but not both. + /// + [AutoResolve] + public string? 
TopP { get; set; } + + /// + /// Gets or sets a value indicating whether the model is a reasoning model. + /// + /// + /// Warning: This is experimental and associated with the reasoning model until all models have parity in the expected properties. + /// + public bool IsReasoningModel { get; set; } + + /// + /// Gets or sets the maximum number of tokens to output in the completion. Default value = 100. + /// + /// + /// The token count of your prompt plus max_tokens cannot exceed the model's context length. + /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096). + /// + [AutoResolve] + public string? MaxTokens { get; set; } = "100"; + + internal ChatCompletionOptions BuildRequest() + { + ChatCompletionOptions request = new(); + if (float.TryParse(this.TopP, out float topP)) + { + request.TopP = topP; + } + + if (this.IsReasoningModel) + { + this.MaxTokens = null; + this.Temperature = null; // property not supported for reasoning model. + } + else + { + if (int.TryParse(this.MaxTokens, out int maxTokens)) + { + request.MaxOutputTokenCount = maxTokens; + } + + if (float.TryParse(this.Temperature, out float temperature)) + { + request.Temperature = temperature; + } + } + + // ToDo: SetNewMaxCompletionTokensPropertyEnabled() has a bug in the current version + // of the Azure.AI.OpenAI SDK but is fixed in the next preview. + // + // This method doesn't swap max_tokens with max_completion_tokens and throws errors + // if max_tokens is set. max_completion_tokens is not configurable due to this bug. + // + // Hence, setting max_tokens to null for reasoning models until the fixed SDK version + // can be adopted. 
+ // request.SetNewMaxCompletionTokensPropertyEnabled(this.IsReasoningModel); + + return request; + } +} diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs index 56ed35fa..9e3c5faf 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs @@ -7,8 +7,13 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; [Binding] [AttributeUsage(AttributeTargets.Parameter)] -public sealed class AssistantPostAttribute : Attribute +public sealed class AssistantPostAttribute : AssistantBaseAttribute { + /// + /// Initializes a new instance of the class. + /// + /// The assistant identifier. + /// The user message. public AssistantPostAttribute(string id, string userMessage) { this.Id = id; @@ -21,15 +26,6 @@ public AssistantPostAttribute(string id, string userMessage) [AutoResolve] public string Id { get; } - /// - /// Gets or sets the OpenAI chat model to use. - /// - /// - /// When using Azure OpenAI, then should be the name of the model deployment. - /// - [AutoResolve] - public string? Model { get; set; } - /// /// Gets or sets the user message to OpenAI. /// diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantRuntimeState.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantRuntimeState.cs index 55dbd2f5..03b10f6d 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantRuntimeState.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantRuntimeState.cs @@ -6,13 +6,13 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; -record struct MessageRecord(DateTime Timestamp, ChatMessage ChatMessageEntity); +record struct MessageRecord(DateTime Timestamp, AssistantMessage ChatMessageEntity); [JsonObject(MemberSerialization.OptIn)] class AssistantRuntimeState { [JsonProperty("messages")] - public List? ChatMessages { get; set; } + public List? 
ChatMessages { get; set; } [JsonProperty("totalTokens")] public int TotalTokens { get; set; } = 0; diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs index e31e8a3b..cd5726b3 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs @@ -1,13 +1,14 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.ClientModel; using Azure; -using Azure.AI.OpenAI; using Azure.Data.Tables; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; using Microsoft.Extensions.Azure; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.Logging; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; @@ -28,7 +29,7 @@ record InternalChatState(string Id, AssistantStateEntity Metadata, List const int FunctionCallBatchLimit = 50; const string DefaultChatStorage = "AzureWebJobsStorage"; - readonly OpenAIClient openAIClient; + readonly OpenAIClientFactory openAIClientFactory; readonly IAssistantSkillInvoker skillInvoker; readonly ILogger logger; readonly AzureComponentFactory azureComponentFactory; @@ -37,7 +38,7 @@ record InternalChatState(string Id, AssistantStateEntity Metadata, List(); + this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory)); this.azureComponentFactory = azureComponentFactory ?? throw new ArgumentNullException(nameof(azureComponentFactory)); this.configuration = configuration ?? 
throw new ArgumentNullException(nameof(configuration)); } @@ -117,7 +117,8 @@ async Task DeleteBatch() partitionKey: request.Id, messageIndex: 1, // 1-based index content: request.Instructions, - role: ChatRole.System); + role: ChatMessageRole.System, + toolCalls: null); batch.Add(new TableTransactionAction(TableTransactionActionType.Add, chatMessageEntity)); } @@ -152,7 +153,7 @@ public async Task GetStateAsync(AssistantQueryAttribute assistan if (chatState is null) { this.logger.LogWarning("No assistant exists with ID = '{Id}'", id); - return new AssistantState(id, false, default, default, 0, 0, Array.Empty()); + return new AssistantState(id, false, default, default, 0, 0, Array.Empty()); } List filteredChatMessages = chatState.Messages @@ -171,110 +172,100 @@ public async Task GetStateAsync(AssistantQueryAttribute assistan chatState.Metadata.LastUpdatedAt, chatState.Metadata.TotalMessages, chatState.Metadata.TotalTokens, - filteredChatMessages.Select(msg => new ChatMessage(msg.Content, msg.Role, msg.Name)).ToList()); + filteredChatMessages.Select(msg => new AssistantMessage(msg.Content, msg.Role, msg.ToolCallsString)).ToList()); return state; } public async Task PostMessageAsync(AssistantPostAttribute attribute, CancellationToken cancellationToken) { + // Validate inputs and prepare for processing DateTime timeFilter = DateTime.UtcNow; - if (string.IsNullOrEmpty(attribute.Id)) - { - throw new ArgumentException("The assistant ID must be specified.", nameof(attribute)); - } - - if (string.IsNullOrEmpty(attribute.UserMessage)) - { - throw new ArgumentException("The assistant must have a user message", nameof(attribute)); - } + this.ValidateAttributes(attribute); this.logger.LogInformation("Posting message to assistant entity '{Id}'", attribute.Id); - TableClient tableClient = this.GetOrCreateTableClient(attribute.ChatStorageConnectionSetting, attribute.CollectionName); + // Load and validate chat state InternalChatState? 
chatState = await this.LoadChatStateAsync(attribute.Id, tableClient, cancellationToken); - - // Check if assistant has been deactivated if (chatState is null || !chatState.Metadata.Exists) { - this.logger.LogWarning("[{Id}] Ignoring request sent to nonexistent assistant.", attribute.Id); - return new AssistantState(attribute.Id, false, default, default, 0, 0, Array.Empty()); + return this.CreateNonExistentAssistantState(attribute.Id); } this.logger.LogInformation("[{Id}] Received message: {Text}", attribute.Id, attribute.UserMessage); - // Create a batch of table transaction actions + // Process the conversation List batch = new(); + this.AddUserMessageToChat(attribute, chatState, batch); + + await this.ProcessConversationWithLLM(attribute, chatState, batch, cancellationToken); - // Add the user message as a new Chat message entity + // Update state and persist changes + this.UpdateAssistantState(chatState, batch); + await tableClient.SubmitTransactionAsync(batch, cancellationToken); + + // Return results + return this.CreateAssistantStateResponse(attribute.Id, chatState, timeFilter); + } + + // Helper methods + void ValidateAttributes(AssistantPostAttribute attribute) + { + if (string.IsNullOrEmpty(attribute.Id)) + { + throw new ArgumentException("The assistant ID must be specified.", nameof(attribute)); + } + + if (string.IsNullOrEmpty(attribute.UserMessage)) + { + throw new ArgumentException("The assistant must have a user message", nameof(attribute)); + } + } + + AssistantState CreateNonExistentAssistantState(string id) + { + this.logger.LogWarning("[{Id}] Ignoring request sent to nonexistent assistant.", id); + return new AssistantState(id, false, default, default, 0, 0, Array.Empty()); + } + + void AddUserMessageToChat(AssistantPostAttribute attribute, InternalChatState chatState, List batch) + { ChatMessageTableEntity chatMessageEntity = new( partitionKey: attribute.Id, messageIndex: ++chatState.Metadata.TotalMessages, content: attribute.UserMessage, - role: 
ChatRole.User); + role: ChatMessageRole.User, + toolCalls: null); chatState.Messages.Add(chatMessageEntity); - - // Add the chat message to the batch batch.Add(new TableTransactionAction(TableTransactionActionType.Add, chatMessageEntity)); + } - string deploymentName = attribute.Model ?? OpenAIModels.DefaultChatModel; - IList? functions = this.skillInvoker.GetFunctionsDefinitions(); + async Task ProcessConversationWithLLM( + AssistantPostAttribute attribute, + InternalChatState chatState, + List batch, + CancellationToken cancellationToken) + { + IList? functions = this.skillInvoker.GetFunctionsDefinitions(); // We loop if the model returns function calls. Otherwise, we break after receiving a response. while (true) { - // Get the next response from the LLM - ChatCompletionsOptions chatRequest = new(deploymentName, ToOpenAIChatRequestMessages(chatState.Messages)); - if (functions is not null) - { - foreach (ChatCompletionsFunctionToolDefinition fn in functions) - { - chatRequest.Tools.Add(fn); - } - } - - Response response = await this.openAIClient.GetChatCompletionsAsync( - chatRequest, - cancellationToken); + // Get the LLM response + ClientResult response = await this.GetLLMResponse(attribute, chatState, functions, cancellationToken); - // We don't normally expect more than one message, but just in case we get multiple messages, - // return all of them separated by two newlines. 
- string replyMessage = string.Join( - Environment.NewLine + Environment.NewLine, - response.Value.Choices.Select(choice => choice.Message.Content)); - if (!string.IsNullOrWhiteSpace(replyMessage)) + // Process text response if available + string replyMessage = this.FormatReplyMessage(response); + if (!string.IsNullOrWhiteSpace(replyMessage) || response.Value.ToolCalls.Any()) { - this.logger.LogInformation( - "[{Id}] Got LLM response consisting of {Count} tokens: {Text}", - attribute.Id, - response.Value.Usage.CompletionTokens, - replyMessage); - - // Add the user message as a new Chat message entity - ChatMessageTableEntity replyFromAssistantEntity = new( - partitionKey: attribute.Id, - messageIndex: ++chatState.Metadata.TotalMessages, - content: replyMessage, - role: ChatRole.Assistant); - chatState.Messages.Add(replyFromAssistantEntity); - - // Add the reply from assistant chat message to the batch - batch.Add(new TableTransactionAction(TableTransactionActionType.Add, replyFromAssistantEntity)); - - this.logger.LogInformation( - "[{Id}] Chat length is now {Count} messages", - attribute.Id, - chatState.Metadata.TotalMessages); + this.LogAndAddAssistantReply(attribute.Id, replyMessage, response, chatState, batch); } - // Set the total tokens that have been consumed. - chatState.Metadata.TotalTokens = response.Value.Usage.TotalTokens; + // Update token count + chatState.Metadata.TotalTokens = response.Value.Usage.TotalTokenCount; - // Check for function calls (which are described in the API as tools) - List functionCalls = response.Value.Choices - .SelectMany(c => c.Message.ToolCalls) - .OfType() - .ToList(); + // Handle function calls + List functionCalls = response.Value.ToolCalls.OfType().ToList(); if (functionCalls.Count == 0) { // No function calls, so we're done @@ -283,107 +274,193 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib if (batch.Count > FunctionCallBatchLimit) { - // Too many function calls, something might be wrong. 
Break out of the loop - // to avoid infinite loops and to avoid exceeding the batch size limit of 100. - this.logger.LogWarning( - "[{Id}] Ignoring {Count} function call(s) in response due to exceeding the limit of {Limit}.", - attribute.Id, - functionCalls.Count, - FunctionCallBatchLimit); + // Too many function calls, something might be wrong + this.LogBatchLimitExceeded(attribute.Id, functionCalls.Count); break; } - // Loop case: found some functions to execute - this.logger.LogInformation( - "[{Id}] Found {Count} function call(s) in response", - attribute.Id, - functionCalls.Count); + // Process function calls + await this.ProcessFunctionCalls(attribute.Id, functionCalls, chatState, batch, cancellationToken); + } + } - // Invoke the function calls and add the responses to the chat history. - List> tasks = new(capacity: functionCalls.Count); - foreach (ChatCompletionsFunctionToolCall call in functionCalls) + async Task> GetLLMResponse( + AssistantPostAttribute attribute, + InternalChatState chatState, + IList? functions, + CancellationToken cancellationToken) + { + ChatCompletionOptions chatRequest = attribute.BuildRequest(); + if (functions is not null) + { + foreach (ChatTool fn in functions) { - // CONSIDER: Call these in parallel - this.logger.LogInformation( - "[{Id}] Calling function '{Name}' with arguments: {Args}", - attribute.Id, - call.Name, - call.Arguments); - - string? functionResult; - try - { - // NOTE: In Consumption plans, calling a function from another function results in double-billing. - // CONSIDER: Use a background thread to invoke the action to avoid double-billing. 
- functionResult = await this.skillInvoker.InvokeAsync(call, cancellationToken); - - this.logger.LogInformation( - "[{id}] Function '{Name}' returned the following content: {Content}", - attribute.Id, - call.Name, - functionResult); - } - catch (Exception ex) - { - this.logger.LogError( - ex, - "[{id}] Function '{Name}' failed with an unhandled exception", - attribute.Id, - call.Name); - - // CONSIDER: Automatic retries? - functionResult = "The function call failed. Let the user know and ask if they'd like you to try again"; - } - - if (string.IsNullOrWhiteSpace(functionResult)) - { - // When experimenting with gpt-4-0613, an empty result would cause the model to go into a - // function calling loop. By instead providing a result with some instructions, we were able - // to get the model to response to the user in a natural way. - functionResult = "The function call succeeded. Let the user know that you completed the action."; - } - - ChatMessageTableEntity functionResultEntity = new( - partitionKey: attribute.Id, - messageIndex: ++chatState.Metadata.TotalMessages, - content: functionResult, - role: ChatRole.Function, - name: call.Name); - chatState.Messages.Add(functionResultEntity); - - batch.Add(new TableTransactionAction(TableTransactionActionType.Add, functionResultEntity)); + chatRequest.Tools.Add(fn); } } - // Update the assistant state entity + IEnumerable chatMessages = ToOpenAIChatRequestMessages(chatState.Messages); + + return await this.openAIClientFactory.GetChatClient( + attribute.AIConnectionName, + attribute.ChatModel).CompleteChatAsync(chatMessages, chatRequest, cancellationToken: cancellationToken); + } + + string FormatReplyMessage(ClientResult response) + { + return string.Join( + Environment.NewLine + Environment.NewLine, + response.Value.Content.Select(message => message.Text)); + } + + void LogAndAddAssistantReply( + string assistantId, + string replyMessage, + ClientResult response, + InternalChatState chatState, + List batch) + { + 
this.logger.LogInformation( + "[{Id}] Got LLM response consisting of {TokenCount} tokens: [{Text}] && {ToolCallCount} ToolCalls", + assistantId, + response.Value.Usage.OutputTokenCount, + replyMessage, + response.Value.ToolCalls.Count); + + ChatMessageTableEntity replyFromAssistantEntity = new( + partitionKey: assistantId, + messageIndex: ++chatState.Metadata.TotalMessages, + content: replyMessage, + role: ChatMessageRole.Assistant, + toolCalls: response.Value.ToolCalls); + + chatState.Messages.Add(replyFromAssistantEntity); + batch.Add(new TableTransactionAction(TableTransactionActionType.Add, replyFromAssistantEntity)); + + this.logger.LogInformation( + "[{Id}] Chat length is now {Count} messages", + assistantId, + chatState.Metadata.TotalMessages); + } + + void LogBatchLimitExceeded(string assistantId, int functionCallCount) + { + this.logger.LogWarning( + "[{Id}] Ignoring {Count} function call(s) in response due to exceeding the limit of {Limit}.", + assistantId, + functionCallCount, + FunctionCallBatchLimit); + } + + async Task ProcessFunctionCalls( + string assistantId, + List functionCalls, + InternalChatState chatState, + List batch, + CancellationToken cancellationToken) + { + this.logger.LogInformation( + "[{Id}] Found {Count} function call(s) in response", + assistantId, + functionCalls.Count); + + foreach (ChatToolCall call in functionCalls) + { + await this.ProcessSingleFunctionCall(assistantId, call, chatState, batch, cancellationToken); + } + } + + async Task ProcessSingleFunctionCall( + string assistantId, + ChatToolCall call, + InternalChatState chatState, + List batch, + CancellationToken cancellationToken) + { + this.logger.LogInformation( + "[{Id}] Calling function '{Name}' with arguments: {Args}", + assistantId, + call.FunctionName, + call.FunctionArguments); + + string? 
functionResult = await this.InvokeFunctionWithErrorHandling(assistantId, call, cancellationToken); + + if (string.IsNullOrWhiteSpace(functionResult)) + { + functionResult = "The function call succeeded. Let the user know that you completed the action."; + } + + ChatMessageTableEntity functionResultEntity = new( + partitionKey: assistantId, + messageIndex: ++chatState.Metadata.TotalMessages, + content: $"Function Name: '{call.FunctionName}' and Function Result: '{functionResult}'", + role: ChatMessageRole.Tool, + name: call.Id, + toolCalls: null); + + chatState.Messages.Add(functionResultEntity); + batch.Add(new TableTransactionAction(TableTransactionActionType.Add, functionResultEntity)); + } + + async Task InvokeFunctionWithErrorHandling( + string assistantId, + ChatToolCall call, + CancellationToken cancellationToken) + { + try + { + // NOTE: In Consumption plans, calling a function from another function results in double-billing. + // CONSIDER: Use a background thread to invoke the action to avoid double-billing. + string? result = await this.skillInvoker.InvokeAsync(call, cancellationToken); + + this.logger.LogInformation( + "[{Id}] Function '{Name}' returned the following content: {Content}", + assistantId, + call.FunctionName, + result); + + return result; + } + catch (Exception ex) + { + this.logger.LogError( + ex, + "[{Id}] Function '{Name}' failed with an unhandled exception", + assistantId, + call.FunctionName); + + // CONSIDER: Automatic retries? + return "The function call failed. 
Let the user know and ask if they'd like you to try again"; + } + } + + void UpdateAssistantState(InternalChatState chatState, List batch) + { chatState.Metadata.TotalMessages = chatState.Messages.Count; chatState.Metadata.LastUpdatedAt = DateTime.UtcNow; batch.Add(new TableTransactionAction(TableTransactionActionType.UpdateMerge, chatState.Metadata)); + } - // Add the batch of table transaction actions to the table - await tableClient.SubmitTransactionAsync(batch, cancellationToken); - - // return the latest assistant message in the chat state + AssistantState CreateAssistantStateResponse(string assistantId, InternalChatState chatState, DateTime timeFilter) + { List filteredChatMessages = chatState.Messages - .Where(msg => msg.CreatedAt > timeFilter && msg.Role == ChatRole.Assistant) + .Where(msg => msg.CreatedAt > timeFilter && msg.Role == ChatMessageRole.Assistant.ToString()) .ToList(); this.logger.LogInformation( "Returning {Count}/{Total} chat messages from entity '{Id}'", filteredChatMessages.Count, chatState.Metadata.TotalMessages, - attribute.Id); + assistantId); - AssistantState state = new( - attribute.Id, + return new AssistantState( + assistantId, true, chatState.Metadata.CreatedAt, chatState.Metadata.LastUpdatedAt, chatState.Metadata.TotalMessages, chatState.Metadata.TotalTokens, - filteredChatMessages.Select(msg => new ChatMessage(msg.Content, msg.Role, msg.Name)).ToList()); - - return state; + filteredChatMessages.Select(msg => new AssistantMessage(msg.Content, msg.Role, msg.ToolCallsString)).ToList()); } async Task LoadChatStateAsync(string id, TableClient tableClient, CancellationToken cancellationToken) @@ -420,26 +497,30 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib return new InternalChatState(id, assistantStateEntity, chatMessageList); } - static IEnumerable ToOpenAIChatRequestMessages(IEnumerable entities) + static IEnumerable ToOpenAIChatRequestMessages(IEnumerable entities) { foreach (ChatMessageTableEntity entity in 
entities) { switch (entity.Role.ToLowerInvariant()) { case "user": - yield return new ChatRequestUserMessage(entity.Content); + yield return new UserChatMessage(entity.Content); break; case "assistant": - yield return new ChatRequestAssistantMessage(entity.Content); + if (entity.ToolCalls != null && entity.ToolCalls.Any()) + { + yield return new AssistantChatMessage(entity.ToolCalls); + } + else + { + yield return new AssistantChatMessage(entity.Content); + } break; case "system": - yield return new ChatRequestSystemMessage(entity.Content); - break; - case "function": - yield return new ChatRequestFunctionMessage(entity.Name, entity.Content); + yield return new SystemChatMessage(entity.Content); break; case "tool": - yield return new ChatRequestToolMessage(entity.Content, toolCallId: entity.Name); + yield return new ToolChatMessage(toolCallId: entity.Name, entity.Content); break; default: throw new InvalidOperationException($"Unknown chat role '{entity.Role}'"); @@ -483,7 +564,7 @@ TableClient GetOrCreateTableClient(string? chatStorageConnectionSetting, string? // Else, will use the connection string connectionStringName = chatStorageConnectionSetting ?? DefaultChatStorage; string connectionString = this.configuration.GetValue(connectionStringName); - + this.logger.LogInformation("using connection string for table service client"); this.tableServiceClient = new TableServiceClient(connectionString); diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs index 749e868b..25209b4c 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs @@ -42,12 +42,4 @@ public AssistantSkillTriggerAttribute(string functionDescription) /// https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools. /// public string? 
ParameterDescriptionJson { get; set; } - - /// - /// Gets or sets the OpenAI chat model to use. - /// - /// - /// When using Azure OpenAI, then should be the name of the model deployment. - /// - public string Model { get; set; } = "gpt-3.5-turbo"; } \ No newline at end of file diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs index 55eed508..8b2ee6f9 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs @@ -113,10 +113,10 @@ public Task BindAsync(object value, ValueBindingContext context) SkillInvocationContext skillInvocationContext = (SkillInvocationContext)value; object? convertedValue; - if (!string.IsNullOrEmpty(skillInvocationContext.Arguments)) + if (!string.IsNullOrEmpty(skillInvocationContext.Arguments?.ToString())) { // We expect that input to always be a string value in the form {"paramName":paramValue} - JObject argsJson = JObject.Parse(skillInvocationContext.Arguments); + JObject argsJson = JObject.Parse(skillInvocationContext.Arguments.ToString()); JToken? paramValue = argsJson[this.parameterInfo.Name]; convertedValue = paramValue?.ToObject(destinationType); } diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/BuiltInFunctionsProvider.cs b/src/WebJobs.Extensions.OpenAI/Assistants/BuiltInFunctionsProvider.cs deleted file mode 100644 index db996693..00000000 --- a/src/WebJobs.Extensions.OpenAI/Assistants/BuiltInFunctionsProvider.cs +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. 
- -using System.Collections.Immutable; -using System.Reflection; -using Microsoft.Azure.WebJobs.Script.Description; -using Newtonsoft.Json.Linq; - -namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; - -/// -/// Class that defines all the built-in functions for executing CNCF Serverless Workflows. -/// IMPORTANT: Renaming methods in this class is a breaking change! -/// -class BuiltInFunctionsProvider : IFunctionProvider -{ - /// - Task> IFunctionProvider.GetFunctionMetadataAsync() => - Task.FromResult(this.GetFunctionMetadata().ToImmutableArray()); - - // TODO: Not sure what this is for... - /// - ImmutableDictionary> IFunctionProvider.FunctionErrors => - new Dictionary>().ToImmutableDictionary(); - - - internal static string GetBuiltInFunctionName(string functionName) - { - return $"OpenAI::{functionName}"; - } - - /// - /// Returns an enumeration of all the function triggers defined in this class. - /// - IEnumerable GetFunctionMetadata() - { - foreach (MethodInfo method in this.GetType().GetMethods()) - { - if (method.GetCustomAttribute() is not FunctionNameAttribute) - { - // Not an Azure Function definition - continue; - } - - FunctionMetadata metadata = new() - { - // NOTE: We always use the method name and ignore the function name - Name = GetBuiltInFunctionName(method.Name), - ScriptFile = $"assembly:{method.ReflectedType.Assembly.FullName}", - EntryPoint = $"{method.ReflectedType.FullName}.{method.Name}", - Language = "DotNetAssembly", - }; - - // Scan the parameters for binding attributes and add them to the bindings collection - // so that we can register them with the Functions runtime. - foreach (ParameterInfo parameter in method.GetParameters()) - { - if (parameter.GetCustomAttribute() is not null) - { - // NOTE: We assume each OpenAI service function in this file defines the parameter name as "service". 
- metadata.Bindings.Add(BindingMetadata.Create(new JObject( - new JProperty("type", "openAIService"), - new JProperty("name", "service"), - new JProperty("direction", "in")))); - } - } - - yield return metadata; - } - } -} diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/ChatCompletionsJsonConverter.cs b/src/WebJobs.Extensions.OpenAI/Assistants/ChatCompletionsJsonConverter.cs index f8ad8c8c..6d0b7e49 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/ChatCompletionsJsonConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/ChatCompletionsJsonConverter.cs @@ -4,19 +4,19 @@ using System.ClientModel.Primitives; using System.Text.Json; using System.Text.Json.Serialization; -using Azure.AI.OpenAI; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; -class ChatCompletionsJsonConverter : JsonConverter +class ChatCompletionsJsonConverter : JsonConverter { static readonly ModelReaderWriterOptions modelReaderWriterOptions = new("J"); - public override ChatCompletions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override ChatCompletion Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { throw new NotImplementedException(); } - public override void Write(Utf8JsonWriter writer, ChatCompletions value, JsonSerializerOptions options) + public override void Write(Utf8JsonWriter writer, ChatCompletion value, JsonSerializerOptions options) { - ((IJsonModel)value).Write(writer, modelReaderWriterOptions); + ((IJsonModel)value).Write(writer, modelReaderWriterOptions); } } diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs b/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs index d3e539ba..c11f55bd 100644 --- a/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs +++ b/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs @@ -4,28 +4,28 @@ using System.Reflection; using 
System.Runtime.ExceptionServices; using System.Text; -using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Host.Executors; using Microsoft.Extensions.Logging; using Newtonsoft.Json; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; public interface IAssistantSkillInvoker { - IList? GetFunctionsDefinitions(); - Task InvokeAsync(ChatCompletionsFunctionToolCall call, CancellationToken cancellationToken); + IList? GetFunctionsDefinitions(); + Task InvokeAsync(ChatToolCall call, CancellationToken cancellationToken); } class SkillInvocationContext { - public SkillInvocationContext(string arguments) + public SkillInvocationContext(BinaryData arguments) { this.Arguments = arguments; } // The arguments are passed as a JSON object in the form of {"paramName":paramValue} - public string Arguments { get; } + public BinaryData Arguments { get; } // The result of the function invocation, if any public object? Result { get; set; } @@ -70,14 +70,14 @@ internal void UnregisterSkill(string name) this.skills.Remove(name); } - IList? IAssistantSkillInvoker.GetFunctionsDefinitions() + IList? IAssistantSkillInvoker.GetFunctionsDefinitions() { if (this.skills.Count == 0) { return null; } - List functions = new(capacity: this.skills.Count); + List functions = new(capacity: this.skills.Count); foreach (Skill skill in this.skills.Values) { // The parameters can be defined in the attribute JSON or can be inferred from @@ -85,12 +85,12 @@ internal void UnregisterSkill(string name) string parametersJson = skill.Attribute.ParameterDescriptionJson ?? 
JsonConvert.SerializeObject(GetParameterDefinition(skill)); - functions.Add(new ChatCompletionsFunctionToolDefinition - { - Name = skill.Name, - Description = skill.Attribute.FunctionDescription, - Parameters = BinaryData.FromBytes(Encoding.UTF8.GetBytes(parametersJson)), - }); + ChatTool chatTool = ChatTool.CreateFunctionTool( + skill.Name, + skill.Attribute.FunctionDescription, + BinaryData.FromBytes(Encoding.UTF8.GetBytes(parametersJson)) + ); + functions.Add(chatTool); } return functions; @@ -140,7 +140,7 @@ static Dictionary GetParameterDefinition(Skill skill) } async Task IAssistantSkillInvoker.InvokeAsync( - ChatCompletionsFunctionToolCall call, + ChatToolCall call, CancellationToken cancellationToken) { if (call is null) @@ -148,17 +148,17 @@ static Dictionary GetParameterDefinition(Skill skill) throw new ArgumentNullException(nameof(call)); } - if (call.Name is null) + if (call.FunctionName is null) { throw new ArgumentException("The function call must have a name", nameof(call)); } - if (!this.skills.TryGetValue(call.Name, out Skill? skill)) + if (!this.skills.TryGetValue(call.FunctionName, out Skill? skill)) { - throw new InvalidOperationException($"No skill registered with name '{call.Name}'"); + throw new InvalidOperationException($"No skill registered with name '{call.FunctionName}'"); } - SkillInvocationContext skillInvocationContext = new(call.Arguments); + SkillInvocationContext skillInvocationContext = new(call.FunctionArguments); // This call may throw if the Functions host is shutting down or if there is an internal error // in the Functions runtime. We don't currently try to handle these exceptions. @@ -170,7 +170,7 @@ static Dictionary GetParameterDefinition(Skill skill) InvokeHandler = async userCodeInvoker => { // Invoke the function and attempt to get the result. 
- this.logger.LogInformation("Invoking user-code function '{Name}'", call.Name); + this.logger.LogInformation("Invoking user-code function '{Name}'", call.FunctionName); Task invokeTask = userCodeInvoker.Invoke(); if (invokeTask is Task invokeTaskWithResult) { @@ -182,7 +182,7 @@ static Dictionary GetParameterDefinition(Skill skill) this.logger.LogWarning( "Unable to discover the return value (if any) for user-code function '{Name}'. " + "This is an internal error in the extension that may result in model hallucination.", - call.Name); + call.FunctionName); await invokeTask; } } @@ -205,7 +205,7 @@ static Dictionary GetParameterDefinition(Skill skill) // Convert the output to JSON string jsonResult = JsonConvert.SerializeObject(skillInvocationContext.Result); this.logger.LogInformation( - "Returning output of user-code function '{Name}' as JSON: {Json}", call.Name, jsonResult); + "Returning output of user-code function '{Name}' as JSON: {Json}", call.FunctionName, jsonResult); return jsonResult; } } \ No newline at end of file diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsBaseAttribute.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsBaseAttribute.cs index 7873474f..5ecac174 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsBaseAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsBaseAttribute.cs @@ -25,10 +25,29 @@ public class EmbeddingsBaseAttribute : Attribute /// Thrown if is null. public EmbeddingsBaseAttribute(string input, InputType inputType) { - this.Input = input ?? throw new ArgumentNullException(nameof(input)); + this.Input = string.IsNullOrEmpty(input) + ? throw new ArgumentException("Input cannot be null or empty.", nameof(input)) + : input; this.InputType = inputType; } + /// + /// Gets or sets the name of the configuration section for AI service connectivity settings. 
+ /// + /// + /// This property specifies the name of the configuration section that contains connection details for the AI service. + /// + /// For Azure OpenAI: + /// - If specified, looks for "Endpoint" and "Key" values in this configuration section + /// - If not specified or the section doesn't exist, falls back to environment variables: + /// AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY + /// - For user-assigned managed identity authentication, configuration section is required + /// + /// For OpenAI: + /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable. + /// + public string AIConnectionName { get; set; } = ""; + /// /// Gets or sets the ID of the model to use. /// @@ -36,7 +55,7 @@ public EmbeddingsBaseAttribute(string input, InputType inputType) /// Changing the default embeddings model is a breaking change, since any changes will be stored in a vector database for lookup. Changing the default model can cause the lookups to start misbehaving if they don't match the data that was previously ingested into the vector database. /// [AutoResolve] - public string Model { get; set; } = OpenAIModels.DefaultEmbeddingsModel; + public string EmbeddingsModel { get; set; } = OpenAIModels.DefaultEmbeddingsModel; /// /// Gets or sets the maximum number of characters to chunk the input into. diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs index c6b90c1a..a890fb57 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using OpenAISDK = Azure.AI.OpenAI; +using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; @@ -12,7 +12,7 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; /// The embeddings response that was received from OpenAI. 
public class EmbeddingsContext { - public EmbeddingsContext(OpenAISDK.EmbeddingsOptions Request, OpenAISDK.Embeddings? Response) + public EmbeddingsContext(IList Request, OpenAIEmbeddingCollection? Response) { this.Request = Request; this.Response = Response; @@ -21,15 +21,15 @@ public EmbeddingsContext(OpenAISDK.EmbeddingsOptions Request, OpenAISDK.Embeddin /// /// Embeddings request sent to OpenAI. /// - public OpenAISDK.EmbeddingsOptions Request { get; set; } + public IList Request { get; set; } /// /// Embeddings response from OpenAI. /// - public OpenAISDK.Embeddings? Response { get; set; } + public OpenAIEmbeddingCollection? Response { get; set; } /// /// Gets the number of embeddings that were returned in the response. /// - public int Count => this.Response?.Data?.Count ?? 0; + public int Count => this.Response?.Count ?? 0; } diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs index ebe439c7..47c8807e 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsContextConverter.cs @@ -4,7 +4,7 @@ using System.ClientModel.Primitives; using System.Text.Json; using System.Text.Json.Serialization; -using OpenAISDK = Azure.AI.OpenAI; +using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; @@ -23,9 +23,14 @@ public override void Write(Utf8JsonWriter writer, EmbeddingsContext value, JsonS { writer.WriteStartObject(); writer.WritePropertyName("request"u8); - ((IJsonModel)value.Request).Write(writer, modelReaderWriterOptions); - if (value.Response is IJsonModel response) + if (value.Request is List inputList) + { + var inputWrapper = JsonModelListWrapper.FromList(inputList); + inputWrapper.Write(writer, modelReaderWriterOptions); + } + + if (value.Response is IJsonModel response) { writer.WritePropertyName("response"u8); response.Write(writer, 
modelReaderWriterOptions); diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs index 28fc990b..364f552b 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsConverter.cs @@ -2,10 +2,8 @@ // Licensed under the MIT License. using System.Text.Json; -using Azure; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; using Microsoft.Extensions.Logging; -using OpenAISDK = Azure.AI.OpenAI; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; @@ -13,7 +11,7 @@ class EmbeddingsConverter : IAsyncConverter, IAsyncConverter { - readonly OpenAISDK.OpenAIClient openAIClient; + readonly OpenAIClientFactory openAIClientFactory; readonly ILogger logger; // Note: we need this converter as Azure.AI.OpenAI does not support System.Text.Json serialization since their constructors are internal @@ -22,9 +20,11 @@ class EmbeddingsConverter : Converters = { new EmbeddingsContextConverter(), new SearchableDocumentJsonConverter() } }; - public EmbeddingsConverter(OpenAISDK.OpenAIClient openAIClient, ILoggerFactory loggerFactory) + public EmbeddingsConverter( + OpenAIClientFactory openAIClientFactory, + ILoggerFactory loggerFactory) { - this.openAIClient = openAIClient ?? throw new ArgumentNullException(nameof(openAIClient)); + this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory)); this.logger = loggerFactory?.CreateLogger() ?? 
throw new ArgumentNullException(nameof(loggerFactory)); } @@ -47,11 +47,10 @@ async Task ConvertCoreAsync( EmbeddingsAttribute attribute, CancellationToken cancellationToken) { - OpenAISDK.EmbeddingsOptions request = await EmbeddingsHelper.BuildRequest(attribute.MaxOverlap, attribute.MaxChunkLength, attribute.Model, attribute.InputType, attribute.Input); - this.logger.LogInformation("Sending OpenAI embeddings request: {request}", request.Input); - Response response = await this.openAIClient.GetEmbeddingsAsync(request, cancellationToken); - this.logger.LogInformation("Received OpenAI embeddings count: {response}", response.Value.Data.Count); - - return new EmbeddingsContext(request, response); + return await EmbeddingsHelper.GenerateEmbeddingsAsync( + attribute, + this.openAIClientFactory, + this.logger, + cancellationToken); } } diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs index 43125c3e..0b26aac6 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsHelper.cs @@ -1,8 +1,10 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
+using System.ClientModel; using System.Diagnostics; -using Azure.AI.OpenAI; +using Microsoft.Extensions.Logging; +using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; static class EmbeddingsHelper @@ -17,16 +19,35 @@ static EmbeddingsHelper() httpClient.DefaultRequestHeaders.UserAgent.ParseAdd(UserAgent); } - public static async Task BuildRequest(int maxOverlap, int maxChunkLength, string model, InputType inputType, string input) + internal static async Task GenerateEmbeddingsAsync( + EmbeddingsBaseAttribute attribute, + OpenAIClientFactory openAIClientFactory, + ILogger logger, + CancellationToken cancellationToken = default) { - using TextReader reader = await GetTextReader(inputType, input); - if (maxOverlap >= maxChunkLength) + List chunks = await BuildRequest(attribute); + + logger.LogInformation("Sending OpenAI embeddings request"); + + ClientResult response = await openAIClientFactory.GetEmbeddingClient( + attribute.AIConnectionName, + attribute.EmbeddingsModel).GenerateEmbeddingsAsync(chunks, cancellationToken: cancellationToken); + + logger.LogInformation("Received OpenAI embeddings count: {count}", response.Value.Count); + + return new EmbeddingsContext(chunks, response); + } + + static async Task> BuildRequest(EmbeddingsBaseAttribute attribute) + { + using TextReader reader = await GetTextReader(attribute.InputType, attribute.Input); + if (attribute.MaxOverlap >= attribute.MaxChunkLength) { - throw new ArgumentOutOfRangeException($"MaxOverlap ({maxOverlap}) must be less than MaxChunkLength ({maxChunkLength})."); + throw new ArgumentOutOfRangeException($"MaxOverlap ({attribute.MaxOverlap}) must be less than MaxChunkLength ({attribute.MaxChunkLength})."); } - List chunks = GetTextChunks(reader, 0, maxChunkLength, maxOverlap).ToList(); - return new EmbeddingsOptions(model, chunks); + List chunks = GetTextChunks(reader, 0, attribute.MaxChunkLength, attribute.MaxOverlap).ToList(); + return chunks; } static async Task 
GetTextReader(InputType inputType, string input) @@ -41,6 +62,12 @@ static async Task GetTextReader(InputType inputType, string input) } else if (inputType == InputType.Url) { + if (!Uri.TryCreate(input, UriKind.Absolute, out Uri? uriResult) || + uriResult.Scheme != Uri.UriSchemeHttps) + { + throw new ArgumentException($"Invalid Url: {input}. Ensure it is a valid https Url."); + } + Stream stream = await httpClient.GetStreamAsync(input); return new StreamReader(stream); } diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs index e1705ac5..2b5fdacb 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs @@ -4,19 +4,19 @@ using System.ClientModel.Primitives; using System.Text.Json; using System.Text.Json.Serialization; -using Azure.AI.OpenAI; +using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; -class EmbeddingsOptionsJsonConverter : JsonConverter +class EmbeddingsOptionsJsonConverter : JsonConverter { static readonly ModelReaderWriterOptions modelReaderWriterOptions = new("J"); - public override EmbeddingsOptions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override EmbeddingGenerationOptions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { throw new NotImplementedException(); } - public override void Write(Utf8JsonWriter writer, EmbeddingsOptions value, JsonSerializerOptions options) + public override void Write(Utf8JsonWriter writer, EmbeddingGenerationOptions value, JsonSerializerOptions options) { - ((IJsonModel)value).Write(writer, modelReaderWriterOptions); + ((IJsonModel)value).Write(writer, modelReaderWriterOptions); } } diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreAttribute.cs 
b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreAttribute.cs index 1355404b..29ec64d3 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreAttribute.cs @@ -16,26 +16,26 @@ public sealed class EmbeddingsStoreAttribute : EmbeddingsBaseAttribute /// The input source containing the data to generate embeddings for /// and is interpreted based on the value for . /// The type of the input. - /// + /// /// The name of an app setting or environment variable which contains a connection string value. /// /// The name of the collection or table to search or store. /// /// Thrown if or or are null. /// - public EmbeddingsStoreAttribute(string input, InputType inputType, string connectionName, string collection) : base(input, inputType) + public EmbeddingsStoreAttribute(string input, InputType inputType, string storeConnectionName, string collection) : base(input, inputType) { - this.ConnectionName = connectionName ?? throw new ArgumentNullException(nameof(connectionName)); + this.StoreConnectionName = storeConnectionName ?? throw new ArgumentNullException(nameof(storeConnectionName)); this.Collection = collection ?? throw new ArgumentNullException(nameof(collection)); } /// - /// Gets or sets the name of an app setting or environment variable which contains a connection string value. + /// Gets or sets the name of an app setting or environment variable which contains a connection string value for embedding store. /// /// /// This property supports binding expressions. /// - public string ConnectionName { get; set; } + public string StoreConnectionName { get; set; } /// /// The name of the collection or table to search. 
diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs index f51285d5..73be52a4 100644 --- a/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/EmbeddingsStoreConverter.cs @@ -2,17 +2,15 @@ // Licensed under the MIT License. using System.Text.Json; -using Azure; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using OpenAISDK = Azure.AI.OpenAI; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; class EmbeddingsStoreConverter : IAsyncConverter> { - readonly OpenAISDK.OpenAIClient openAIClient; + readonly OpenAIClientFactory openAIClientFactory; readonly ILogger logger; readonly ISearchProvider? searchProvider; @@ -22,12 +20,13 @@ class EmbeddingsStoreConverter : Converters = { new EmbeddingsContextConverter(), new SearchableDocumentJsonConverter() } }; - public EmbeddingsStoreConverter(OpenAISDK.OpenAIClient openAIClient, + public EmbeddingsStoreConverter( + OpenAIClientFactory openAIClientFactory, ILoggerFactory loggerFactory, IEnumerable searchProviders, IOptions openAiConfigOptions) { - this.openAIClient = openAIClient ?? throw new ArgumentNullException(nameof(openAIClient)); + this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory)); this.logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); openAiConfigOptions.Value.SearchProvider.TryGetValue("type", out object value); this.searchProvider = searchProviders? @@ -41,7 +40,7 @@ public Task> ConvertAsync(EmbeddingsStoreAtt throw new InvalidOperationException( "No search provider is configured. Search providers are configured in the host.json file. 
For .NET apps, the appropriate nuget package must also be added to the app's project file."); } - IAsyncCollector collector = new SemanticDocumentCollector(input, this.searchProvider, this.openAIClient, this.logger); + IAsyncCollector collector = new SemanticDocumentCollector(input, this.searchProvider, this.openAIClientFactory, this.logger); return Task.FromResult(collector); } @@ -62,20 +61,23 @@ sealed class SemanticDocumentCollector : IAsyncCollector { readonly EmbeddingsStoreAttribute attribute; readonly ISearchProvider searchProvider; - readonly OpenAISDK.OpenAIClient openAIClient; + readonly OpenAIClientFactory openAIClientFactory; readonly ILogger logger; - public SemanticDocumentCollector(EmbeddingsStoreAttribute attribute, ISearchProvider searchProvider, OpenAISDK.OpenAIClient openAIClient, ILogger logger) + public SemanticDocumentCollector(EmbeddingsStoreAttribute attribute, + ISearchProvider searchProvider, + OpenAIClientFactory openAIClientFactory, + ILogger logger) { this.attribute = attribute; this.searchProvider = searchProvider; - this.openAIClient = openAIClient; + this.openAIClientFactory = openAIClientFactory; this.logger = logger; } public async Task AddAsync(SearchableDocument item, CancellationToken cancellationToken = default) { - if (string.IsNullOrEmpty(this.attribute.ConnectionName)) + if (string.IsNullOrEmpty(this.attribute.StoreConnectionName)) { throw new InvalidOperationException("No connection string information was provided."); } @@ -85,15 +87,12 @@ public async Task AddAsync(SearchableDocument item, CancellationToken cancellati } // Get embeddings from OpenAI - OpenAISDK.EmbeddingsOptions request = await EmbeddingsHelper.BuildRequest(this.attribute.MaxOverlap, this.attribute.MaxChunkLength, this.attribute.Model, this.attribute.InputType, this.attribute.Input); - this.logger.LogInformation("Sending OpenAI embeddings request to deployment: {deploymentName}", request.DeploymentName); - Response response = await 
this.openAIClient.GetEmbeddingsAsync(request, cancellationToken); - EmbeddingsContext embeddingsContext = new(request, response); - this.logger.LogInformation("Received OpenAI embeddings of count: {count}", embeddingsContext.Count); + EmbeddingsContext embeddingsContext = await EmbeddingsHelper. + GenerateEmbeddingsAsync(this.attribute, this.openAIClientFactory, this.logger, cancellationToken); // Add document to the embed store item.Embeddings = embeddingsContext; - item.ConnectionInfo = new ConnectionInfo(this.attribute.ConnectionName, this.attribute.Collection); + item.ConnectionInfo = new ConnectionInfo(this.attribute.StoreConnectionName, this.attribute.Collection); this.logger.LogInformation("Adding document to the embed store."); await this.searchProvider.AddDocumentAsync(item, cancellationToken); this.logger.LogInformation("Finished adding document to the embed store."); diff --git a/src/WebJobs.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs b/src/WebJobs.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs new file mode 100644 index 00000000..0a4277ee --- /dev/null +++ b/src/WebJobs.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs @@ -0,0 +1,49 @@ +using System.ClientModel.Primitives; +using System.Text.Json; + +namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; + +class JsonModelListWrapper : IJsonModel> +{ + readonly List list; + + public JsonModelListWrapper(List list) + { + this.list = list; + } + + public void Write(Utf8JsonWriter writer, ModelReaderWriterOptions options) + { + writer.WriteStartArray(); + foreach (string item in this.list) + { + writer.WriteStringValue(item); + } + writer.WriteEndArray(); + } + + public static JsonModelListWrapper FromList(List list) + { + return new JsonModelListWrapper(list); + } + + public List Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public BinaryData Write(ModelReaderWriterOptions options) + { + throw new 
NotImplementedException(); + } + + public List Create(BinaryData data, ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } + + public string GetFormatFromOptions(ModelReaderWriterOptions options) + { + throw new NotImplementedException(); + } +} \ No newline at end of file diff --git a/src/WebJobs.Extensions.OpenAI/Models/ChatMessage.cs b/src/WebJobs.Extensions.OpenAI/Models/AssistantMessage.cs similarity index 71% rename from src/WebJobs.Extensions.OpenAI/Models/ChatMessage.cs rename to src/WebJobs.Extensions.OpenAI/Models/AssistantMessage.cs index b52e0976..68bdc20b 100644 --- a/src/WebJobs.Extensions.OpenAI/Models/ChatMessage.cs +++ b/src/WebJobs.Extensions.OpenAI/Models/AssistantMessage.cs @@ -9,18 +9,19 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; /// Chat Message Entity which contains the content of the message, the role of the chat agent, and the name of the calling function if applicable. /// [JsonObject(MemberSerialization.OptIn)] -public class ChatMessage +public class AssistantMessage { /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// /// The content of the message. /// The role of the chat agent. - public ChatMessage(string content, string role, string? name) + /// The tool calls. + public AssistantMessage(string content, string role, string toolCalls) { this.Content = content; this.Role = role; - this.Name = name; + this.ToolCalls = toolCalls; } /// @@ -36,8 +37,8 @@ public ChatMessage(string content, string role, string? name) public string Role { get; set; } /// - /// Gets or sets the name of the calling function if applicable. + /// Gets or sets the tool calls. /// - [JsonProperty("name")] - public string? 
Name { get; set; } + [JsonProperty("toolCalls")] + public string ToolCalls { get; set; } } diff --git a/src/WebJobs.Extensions.OpenAI/Models/AssistantState.cs b/src/WebJobs.Extensions.OpenAI/Models/AssistantState.cs index bb1775d7..6dcd68d1 100644 --- a/src/WebJobs.Extensions.OpenAI/Models/AssistantState.cs +++ b/src/WebJobs.Extensions.OpenAI/Models/AssistantState.cs @@ -18,7 +18,7 @@ public AssistantState( DateTime LastUpdatedAt, int TotalMessages, int TotalTokens, - IReadOnlyList RecentMessages) + IReadOnlyList RecentMessages) { this.Id = Id; this.Exists = Exists; @@ -69,5 +69,5 @@ public AssistantState( /// Gets a list of the recent messages from the assistant. /// [JsonProperty("recentMessages")] - public IReadOnlyList RecentMessages { get; set; } = Array.Empty(); + public IReadOnlyList RecentMessages { get; set; } = Array.Empty(); } diff --git a/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs b/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs index 022350d9..00b72bb7 100644 --- a/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs +++ b/src/WebJobs.Extensions.OpenAI/Models/ChatMessageTableEntity.cs @@ -1,9 +1,11 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Runtime.Serialization; +using System.Text.Json; using Azure; -using Azure.AI.OpenAI; using Azure.Data.Tables; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; @@ -19,8 +21,9 @@ public ChatMessageTableEntity( string partitionKey, int messageIndex, string content, - ChatRole role, - string? name = null) + ChatMessageRole role, + string? name = null, + IEnumerable? 
toolCalls = null) { this.PartitionKey = partitionKey; this.RowKey = GetRowKey(messageIndex); @@ -28,6 +31,7 @@ public ChatMessageTableEntity( this.Role = role.ToString(); this.Name = name; this.CreatedAt = DateTime.UtcNow; + this.ToolCalls = toolCalls?.ToList(); } public ChatMessageTableEntity(TableEntity entity) @@ -40,6 +44,7 @@ public ChatMessageTableEntity(TableEntity entity) this.Role = entity.GetString(nameof(this.Role)); this.Name = entity.GetString(nameof(this.Name)); this.CreatedAt = DateTime.SpecifyKind(entity.GetDateTime(nameof(this.CreatedAt)).GetValueOrDefault(), DateTimeKind.Utc); + this.ToolCallsString = entity.GetString(nameof(this.ToolCalls)); } /// @@ -82,10 +87,98 @@ public ChatMessageTableEntity(TableEntity entity) /// public DateTime CreatedAt { get; set; } + /// + /// Gets or sets the ToolCalls for Assistant + /// + [IgnoreDataMember] + public IList? ToolCalls { get; set; } + // WARNING: Changing this is a breaking change! static string GetRowKey(int messageNumber) { // Example msg-001B return string.Concat(RowKeyPrefix, messageNumber.ToString("X4")); } + + /// + /// Converts the ToolCalls to a Json string for table storage + /// + [DataMember(Name = "ToolCalls")] + public string ToolCallsString + { + get + { + if (this.ToolCalls == null || this.ToolCalls.Count == 0) + { + return string.Empty; + } + + IList cloneList = this.SerializeChatTool(this.ToolCalls); + var options = new JsonSerializerOptions { WriteIndented = false, PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; + return JsonSerializer.Serialize(cloneList, options); + } + set + { + if (!string.IsNullOrEmpty(value)) + { + JsonSerializerOptions options = new() { PropertyNamingPolicy = JsonNamingPolicy.CamelCase }; + List? cloneList = JsonSerializer.Deserialize>(value, options); + this.ToolCalls = cloneList != null ? 
this.DeserializeChatTool(cloneList) : new List(); + } + else + { + this.ToolCalls = new List(); + } + } + } + + IList SerializeChatTool(IList toolCalls) + { + IList chatToolCloneList = new List(); + foreach (ChatToolCall toolCall in toolCalls) + { + ChatToolCallClone chatToolClone = new(toolCall.Id, toolCall.FunctionName, toolCall.FunctionArguments.ToString(), toolCall.Kind.ToString()); + chatToolCloneList.Add(chatToolClone); + } + return chatToolCloneList; + } + + IList DeserializeChatTool(IList clones) + { + IList result = new List(); + foreach (ChatToolCallClone clone in clones) + { + JsonElement functionArgs = JsonDocument.Parse(clone.FunctionArguments).RootElement; + ChatToolCall toolCall = ChatToolCall.CreateFunctionToolCall(clone.Id, clone.FunctionName, BinaryData.FromString(functionArgs.GetRawText())); + result.Add(toolCall); + } + return result; + } } + +class ChatToolCallClone +{ + public ChatToolCallClone() + { + this.Id = string.Empty; + this.FunctionName = string.Empty; + this.FunctionArguments = string.Empty; + this.Kind = string.Empty; + } + + internal ChatToolCallClone(string id, string functionName, string functionArguments, string kind) + { + this.Id = id; + this.FunctionName = functionName; + this.Kind = kind; + this.FunctionArguments = functionArguments; + } + + public string Id { get; set; } + + public string FunctionName { get; set; } + + public string FunctionArguments { get; set; } + + public string Kind { get; set; } +} \ No newline at end of file diff --git a/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs b/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs new file mode 100644 index 00000000..51adb201 --- /dev/null +++ b/src/WebJobs.Extensions.OpenAI/OpenAIClientFactory.cs @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +using System.Collections.Concurrent; +using Azure; +using Azure.AI.OpenAI; +using Azure.Core; +using Azure.Identity; +using Microsoft.Extensions.Azure; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using OpenAI; +using OpenAI.Chat; +using OpenAI.Embeddings; + +namespace Microsoft.Azure.WebJobs.Extensions.OpenAI; + +public class OpenAIClientFactory +{ + readonly IConfiguration configuration; + readonly AzureComponentFactory azureComponentFactory; + readonly ILogger logger; + readonly ConcurrentDictionary azureOpenAIclients = new(); + readonly ConcurrentDictionary openAIClients = new(); + readonly ConcurrentDictionary chatClients = new(); // key is ai endpoint, model + readonly ConcurrentDictionary embeddingClients = new(); // key is ai endpoint, model + string aiEndpoint = string.Empty; + + public OpenAIClientFactory( + IConfiguration configuration, + AzureComponentFactory azureComponentFactory, + ILoggerFactory loggerFactory) + { + this.configuration = configuration ?? throw new ArgumentNullException(nameof(configuration)); + this.azureComponentFactory = azureComponentFactory ?? throw new ArgumentNullException(nameof(azureComponentFactory)); + this.logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); + } + + public ChatClient GetChatClient(string aiConnectionName, string model) + { + HasOpenAIKey(out bool hasOpenAIKey, out string openAIKey); + ChatClient chatClient; + (chatClient, string endpoint, string chatModel) = this.chatClients.GetOrAdd( + hasOpenAIKey ? 
"OpenAI" : aiConnectionName, + name => + { + if (!hasOpenAIKey) + { + AzureOpenAIClient azureOpenAIClient = this.CreateClientFromConfigSection(aiConnectionName); + return (azureOpenAIClient.GetChatClient(model), this.aiEndpoint, model); + } + else + { + OpenAIClient openAIClient = this.CreateOpenAIClient(openAIKey); + return (openAIClient.GetChatClient(model), this.aiEndpoint, model); + } + }); + + return chatClient; + } + + public EmbeddingClient GetEmbeddingClient(string aiConnectionName, string model) + { + HasOpenAIKey(out bool hasOpenAIKey, out string openAIKey); + EmbeddingClient embeddingClient; + (embeddingClient, string endpoint, string embeddingModel) = this.embeddingClients.GetOrAdd( + hasOpenAIKey ? "OpenAI" : aiConnectionName, + name => + { + if (!hasOpenAIKey) + { + AzureOpenAIClient azureOpenAIClient = this.CreateClientFromConfigSection(aiConnectionName); + return (azureOpenAIClient.GetEmbeddingClient(model), this.aiEndpoint, model); + } + else + { + OpenAIClient openAIClient = this.CreateOpenAIClient(openAIKey); + return (openAIClient.GetEmbeddingClient(model), this.aiEndpoint, model); + } + }); + + return embeddingClient; + } + + static void HasOpenAIKey(out bool hasOpenAIKey, out string openAIKey) + { + hasOpenAIKey = false; + openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + if (!string.IsNullOrEmpty(openAIKey)) + { + hasOpenAIKey = true; + } + } + + AzureOpenAIClient CreateClientFromConfigSection(string aiConnectionName) + { + IConfigurationSection section = this.configuration.GetSection(aiConnectionName); + + if (!section.Exists()) + { + this.logger.LogInformation($"Configuration section for Azure OpenAI not found, trying fallback to environment variables - AZURE_OPENAI_ENDPOINT and/or AZURE_OPENAI_KEY"); + } + + this.aiEndpoint = section?.GetValue("Endpoint") ?? Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); + string? azureOpenAIKey = section?.GetValue("Key") ?? 
Environment.GetEnvironmentVariable("AZURE_OPENAI_KEY"); + + if (!string.IsNullOrEmpty(this.aiEndpoint)) + { + this.logger.LogInformation($"Using Azure OpenAI endpoint: {this.aiEndpoint}"); + if (!string.IsNullOrEmpty(azureOpenAIKey)) + { + this.logger.LogInformation($"Authenticating using Azure OpenAI Key."); + return this.CreateAzureOpenAIClient(this.aiEndpoint, azureOpenAIKey); + } + else + { + this.logger.LogInformation($"Authenticating using Azure OpenAI TokenCredential."); + + TokenCredential tokenCredential = section.Exists() ? + this.azureComponentFactory.CreateTokenCredential(section) : + new DefaultAzureCredential(); + return this.CreateAzureOpenAIClientWithTokenCredential(this.aiEndpoint, tokenCredential); + } + } + + string errorMessage = $"Configuration section '{aiConnectionName}' is missing required 'Endpoint' and/or 'Key' values."; + this.logger.LogError(errorMessage); + throw new InvalidOperationException(errorMessage); + } + + AzureOpenAIClient CreateAzureOpenAIClient(string endpoint, string apiKey) + { + string key = $"{endpoint}-{apiKey}"; + return this.azureOpenAIclients.GetOrAdd(key, _ => new AzureOpenAIClient(new Uri(endpoint), new AzureKeyCredential(apiKey))); + } + + AzureOpenAIClient CreateAzureOpenAIClientWithTokenCredential(string endpoint, TokenCredential tokenCredential) + { + return this.azureOpenAIclients.GetOrAdd(endpoint, _ => new AzureOpenAIClient(new Uri(endpoint), tokenCredential)); + } + + OpenAIClient CreateOpenAIClient(string openAIKey) + { + this.logger.LogInformation($"Authenticating using OpenAI Key."); + return this.openAIClients.GetOrAdd(openAIKey, _ => new OpenAIClient(openAIKey)); + } +} \ No newline at end of file diff --git a/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs b/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs index 4204b525..c9c31bc8 100644 --- a/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs +++ b/src/WebJobs.Extensions.OpenAI/OpenAIExtension.cs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. 
// Licensed under the MIT License. -using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Description; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; @@ -15,29 +14,29 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI; [Extension("OpenAI")] partial class OpenAIExtension : IExtensionConfigProvider { - readonly OpenAIClient openAIClient; + readonly OpenAIClientFactory openAIClientFactory; readonly TextCompletionConverter textCompletionConverter; readonly EmbeddingsConverter embeddingsConverter; readonly EmbeddingsStoreConverter embeddingsStoreConverter; readonly SemanticSearchConverter semanticSearchConverter; - readonly AssistantBindingConverter chatBotConverter; + readonly AssistantBindingConverter assistantConverter; readonly AssistantSkillTriggerBindingProvider assistantskillTriggerBindingProvider; public OpenAIExtension( - OpenAIClient openAIClient, + OpenAIClientFactory openAIClientFactory, TextCompletionConverter textCompletionConverter, EmbeddingsConverter embeddingsConverter, EmbeddingsStoreConverter embeddingsStoreConverter, SemanticSearchConverter semanticSearchConverter, - AssistantBindingConverter chatBotConverter, + AssistantBindingConverter assistantConverter, AssistantSkillTriggerBindingProvider assistantTriggerBindingProvider) { - this.openAIClient = openAIClient ?? throw new ArgumentNullException(nameof(openAIClient)); + this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory)); this.textCompletionConverter = textCompletionConverter ?? throw new ArgumentNullException(nameof(textCompletionConverter)); this.embeddingsConverter = embeddingsConverter ?? throw new ArgumentNullException(nameof(embeddingsConverter)); this.embeddingsStoreConverter = embeddingsStoreConverter ?? throw new ArgumentNullException(nameof(embeddingsStoreConverter)); this.semanticSearchConverter = semanticSearchConverter ?? 
throw new ArgumentNullException(nameof(semanticSearchConverter)); - this.chatBotConverter = chatBotConverter ?? throw new ArgumentNullException(nameof(chatBotConverter)); + this.assistantConverter = assistantConverter ?? throw new ArgumentNullException(nameof(assistantConverter)); this.assistantskillTriggerBindingProvider = assistantTriggerBindingProvider ?? throw new ArgumentNullException(nameof(assistantTriggerBindingProvider)); } @@ -64,24 +63,24 @@ void IExtensionConfigProvider.Initialize(ExtensionConfigContext context) semanticSearchRule.BindToInput(this.semanticSearchConverter); // Assistant support - var chatBotCreateRule = context.AddBindingRule(); - chatBotCreateRule.BindToCollector(this.chatBotConverter); - context.AddConverter(this.chatBotConverter.ToAssistantCreateRequest); - context.AddConverter(this.chatBotConverter.ToAssistantCreateRequest); + var assistantCreateRule = context.AddBindingRule(); + assistantCreateRule.BindToCollector(this.assistantConverter); + context.AddConverter(this.assistantConverter.ToAssistantCreateRequest); + context.AddConverter(this.assistantConverter.ToAssistantCreateRequest); - var chatBotPostRule = context.AddBindingRule(); - chatBotPostRule.BindToInput(this.chatBotConverter); - chatBotPostRule.BindToInput(this.chatBotConverter); + var assistantPostRule = context.AddBindingRule(); + assistantPostRule.BindToInput(this.assistantConverter); + assistantPostRule.BindToInput(this.assistantConverter); - var chatBotQueryRule = context.AddBindingRule(); - chatBotQueryRule.BindToInput(this.chatBotConverter); - chatBotQueryRule.BindToInput(this.chatBotConverter); + var assistantQueryRule = context.AddBindingRule(); + assistantQueryRule.BindToInput(this.assistantConverter); + assistantQueryRule.BindToInput(this.assistantConverter); // Assistant skill trigger support context.AddBindingRule() .BindToTrigger(this.assistantskillTriggerBindingProvider); // OpenAI service input binding support (NOTE: This may be removed in a future 
version.) - context.AddBindingRule().BindToInput(_ => this.openAIClient); + context.AddBindingRule().BindToInput(_ => this.openAIClientFactory); } } diff --git a/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs b/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs index 3b49603b..733d7956 100644 --- a/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs +++ b/src/WebJobs.Extensions.OpenAI/OpenAIWebJobsBuilderExtensions.cs @@ -1,9 +1,6 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using Azure; -using Azure.AI.OpenAI; -using Azure.Identity; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; @@ -31,31 +28,9 @@ public static IWebJobsBuilder AddOpenAIBindings(this IWebJobsBuilder builder) throw new ArgumentNullException(nameof(builder)); } - // Register the client for Azure Open AI - Uri? azureOpenAIEndpoint = GetAzureOpenAIEndpoint(); - string? openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); - string? azureOpenAIKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_KEY"); - - if (azureOpenAIEndpoint != null && !string.IsNullOrEmpty(azureOpenAIKey)) - { - RegisterAzureOpenAIClient(builder.Services, azureOpenAIEndpoint, azureOpenAIKey); - } - else if (azureOpenAIEndpoint != null) - { - RegisterAzureOpenAIADAuthClient(builder.Services, azureOpenAIEndpoint); - } - else if (!string.IsNullOrEmpty(openAIKey)) - { - RegisterOpenAIClient(builder.Services, openAIKey); - } - else - { - throw new InvalidOperationException("Must set AZURE_OPENAI_ENDPOINT or OPENAI_API_KEY environment variables."); - } - // Register the WebJobs extension, which enables the bindings. 
builder.AddExtension(); - + // Service objects that will be used by the extension builder.Services.AddSingleton(); builder.Services.AddSingleton(); @@ -77,36 +52,8 @@ public static IWebJobsBuilder AddOpenAIBindings(this IWebJobsBuilder builder) builder.Services.AddAzureClientsCore(); // Adds AzureComponentFactory - return builder; - } - - static Uri? GetAzureOpenAIEndpoint() - { - if (Uri.TryCreate(Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"), UriKind.Absolute, out var uri)) - { - return uri; - } - - return null; - } - - static void RegisterAzureOpenAIClient(IServiceCollection services, Uri azureOpenAIEndpoint, string azureOpenAIKey) - { - services.AddAzureClients(clientBuilder => - { - clientBuilder.AddOpenAIClient(azureOpenAIEndpoint, new AzureKeyCredential(azureOpenAIKey)); - }); - } + builder.Services.AddSingleton(); - static void RegisterAzureOpenAIADAuthClient(IServiceCollection services, Uri azureOpenAIEndpoint) - { - var managedIdentityClient = new OpenAIClient(azureOpenAIEndpoint, new DefaultAzureCredential()); - services.AddSingleton(managedIdentityClient); - } - - static void RegisterOpenAIClient(IServiceCollection services, string openAIKey) - { - var openAIClient = new OpenAIClient(openAIKey); - services.AddSingleton(openAIClient); + return builder; } } diff --git a/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs b/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs index 0bbf9b55..d81919b5 100644 --- a/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs @@ -5,7 +5,7 @@ using System.Text.Json; using System.Text.Json.Serialization; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; -using OpenAISDK = Azure.AI.OpenAI; +using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; class SearchableDocumentJsonConverter : JsonConverter @@ -16,8 +16,8 @@ public override 
SearchableDocument Read(ref Utf8JsonReader reader, Type typeToCo using JsonDocument jsonDocument = JsonDocument.ParseValue(ref reader); // Properties for SearchableDocument - OpenAISDK.EmbeddingsOptions embeddingsOptions = new(); - OpenAISDK.Embeddings? embeddings = null; + IList input = new List(); + OpenAIEmbeddingCollection? embeddings = null; int count; string title = string.Empty; string connectionName = string.Empty; @@ -31,11 +31,16 @@ public override SearchableDocument Read(ref Utf8JsonReader reader, Type typeToCo { if (embeddingContextItem.NameEquals("request"u8)) { - embeddingsOptions = ModelReaderWriter.Read(BinaryData.FromString(embeddingContextItem.Value.GetRawText()))!; + // Parse the array of string inputs + input = new List(); + foreach (JsonElement element in embeddingContextItem.Value.EnumerateArray()) + { + input.Add(element.GetString() ?? string.Empty); + } } if (embeddingContextItem.NameEquals("response"u8)) { - embeddings = ModelReaderWriter.Read(BinaryData.FromString(embeddingContextItem.Value.GetRawText()))!; + embeddings = ModelReaderWriter.Read(BinaryData.FromString(embeddingContextItem.Value.GetRawText()))!; } if (embeddingContextItem.NameEquals("count"u8)) { @@ -65,7 +70,7 @@ public override SearchableDocument Read(ref Utf8JsonReader reader, Type typeToCo } SearchableDocument searchableDocument = new SearchableDocument(title) { - Embeddings = new EmbeddingsContext(embeddingsOptions, embeddings), + Embeddings = new EmbeddingsContext(input, embeddings), ConnectionInfo = new ConnectionInfo(connectionName, collectionName), }; return searchableDocument; @@ -78,13 +83,14 @@ public override void Write(Utf8JsonWriter writer, SearchableDocument value, Json writer.WritePropertyName("embeddingsContext"u8); writer.WriteStartObject(); - if (value.Embeddings?.Request is IJsonModel request) + if (value.Embeddings?.Request is List inputList) { writer.WritePropertyName("request"u8); - request.Write(writer, modelReaderWriterOptions); + var inputWrapper = 
JsonModelListWrapper.FromList(inputList); + inputWrapper.Write(writer, modelReaderWriterOptions); } - if (value.Embeddings?.Response is IJsonModel response) + if (value.Embeddings?.Response is IJsonModel response) { writer.WritePropertyName("response"u8); response.Write(writer, modelReaderWriterOptions); diff --git a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchAttribute.cs b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchAttribute.cs index 70b3e742..dce23d5d 100644 --- a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchAttribute.cs @@ -2,6 +2,7 @@ // Licensed under the MIT License. using Microsoft.Azure.WebJobs.Description; +using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; @@ -11,22 +12,22 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; /// [Binding] [AttributeUsage(AttributeTargets.Parameter)] -public sealed class SemanticSearchAttribute : Attribute +public sealed class SemanticSearchAttribute : AssistantBaseAttribute { /// /// Initializes a new instance of the class with the specified connection /// and collection names. /// - /// - /// The name of an app setting or environment variable which contains a connection string value. + /// + /// The name of an app setting or environment variable which contains a connection string value of search provider. /// /// The name of the collection or table to search or store. /// - /// Thrown if either or are null. + /// Thrown if either or are null. /// - public SemanticSearchAttribute(string connectionName, string collection) + public SemanticSearchAttribute(string searchConnectionName, string collection) { - this.ConnectionName = connectionName ?? throw new ArgumentNullException(nameof(connectionName)); + this.SearchConnectionName = searchConnectionName ?? 
throw new ArgumentNullException(nameof(searchConnectionName)); this.Collection = collection ?? throw new ArgumentNullException(nameof(collection)); } @@ -37,7 +38,7 @@ public SemanticSearchAttribute(string connectionName, string collection) /// This property supports binding expressions. /// [AutoResolve] - public string ConnectionName { get; set; } + public string SearchConnectionName { get; set; } /// /// The name of the collection or table or index to search. @@ -68,16 +69,6 @@ public SemanticSearchAttribute(string connectionName, string collection) [AutoResolve] public string EmbeddingsModel { get; set; } = OpenAIModels.DefaultEmbeddingsModel; - /// - /// Gets or sets the name of the Large Language Model to invoke for chat responses. - /// The default value is "gpt-3.5-turbo". - /// - /// - /// This property supports binding expressions. - /// - [AutoResolve] - public string ChatModel { get; set; } = OpenAIModels.DefaultChatModel; - /// /// Gets or sets the system prompt to use for prompting the large language model. /// diff --git a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchContext.cs b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchContext.cs index adb000f7..a0bf8073 100644 --- a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchContext.cs +++ b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchContext.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; using Newtonsoft.Json; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; @@ -18,7 +18,7 @@ public class SemanticSearchContext /// /// The embeddings context associated with the semantic search. /// The chat response from the large language model. 
- public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletions Chat) + public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletion Chat) { this.Embeddings = Embeddings; this.Chat = Chat; @@ -34,11 +34,11 @@ public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletions Chat) /// Chat response from the chat completions request. /// [JsonProperty("chat")] - public ChatCompletions Chat { get; } + public ChatCompletion Chat { get; } /// /// Gets the latest response message from the OpenAI Chat API. /// [JsonProperty("response")] - public string Response => this.Chat.Choices.Last().Message.Content; + public string Response => this.Chat.Content.Last().Text; } diff --git a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs index b853036b..3385ca3d 100644 --- a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs @@ -1,14 +1,15 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.ClientModel; using System.Text; using System.Text.Json; -using Azure; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; -using OpenAISDK = Azure.AI.OpenAI; +using OpenAI.Chat; +using OpenAI.Embeddings; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Search; @@ -16,27 +17,27 @@ class SemanticSearchConverter : IAsyncConverter, IAsyncConverter { - readonly OpenAISDK.OpenAIClient openAIClient; + readonly OpenAIClientFactory openAIClientFactory; readonly ILogger logger; readonly ISearchProvider? 
searchProvider; static readonly JsonSerializerOptions options = new() { - Converters = { + Converters = { new SearchableDocumentJsonConverter(), new EmbeddingsContextConverter(), new EmbeddingsOptionsJsonConverter(), new ChatCompletionsJsonConverter()} - }; + }; public SemanticSearchConverter( - OpenAISDK.OpenAIClient openAIClient, + OpenAIClientFactory openAIClientFactory, ILoggerFactory loggerFactory, IEnumerable searchProviders, IOptions openAiConfigOptions) { - this.openAIClient = openAIClient ?? throw new ArgumentNullException(nameof(openAIClient)); this.logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); + this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory)); openAiConfigOptions.Value.SearchProvider.TryGetValue("type", out object value); this.logger.LogInformation("Type of the searchProvider configured in host file: {type}", value); @@ -61,14 +62,13 @@ async Task ConvertHelperAsync( } // Get the embeddings for the query, which will be used for doing a semantic search - OpenAISDK.EmbeddingsOptions embeddingsRequest = new(attribute.EmbeddingsModel, new List { attribute.Query }); + this.logger.LogInformation("Sending OpenAI embeddings request: {request}", attribute.Query); + ClientResult embedding = await this.openAIClientFactory.GetEmbeddingClient( + attribute.AIConnectionName, + attribute.EmbeddingsModel).GenerateEmbeddingAsync(attribute.Query, cancellationToken: cancellationToken); + this.logger.LogInformation("Received OpenAI embeddings"); - this.logger.LogInformation("Sending OpenAI embeddings request: {request}", embeddingsRequest.Input); - Response embeddingsResponse = await this.openAIClient.GetEmbeddingsAsync(embeddingsRequest, cancellationToken); - this.logger.LogInformation("Received OpenAI embeddings count: {response}", embeddingsResponse.Value.Data.Count); - - - ConnectionInfo connectionInfo = new(attribute.ConnectionName, attribute.Collection); + 
ConnectionInfo connectionInfo = new(attribute.SearchConnectionName, attribute.Collection); if (string.IsNullOrEmpty(connectionInfo.ConnectionName)) { throw new InvalidOperationException("No connection string information was provided."); @@ -81,7 +81,7 @@ async Task ConvertHelperAsync( // Search for relevant document snippets using the original query and the embeddings SearchRequest searchRequest = new( attribute.Query, - embeddingsResponse.Value.Data[0].Embedding, + embedding.Value.ToFloats(), attribute.MaxKnowledgeCount, connectionInfo); SearchResponse searchResponse = await this.searchProvider.SearchAsync(searchRequest); @@ -95,20 +95,17 @@ async Task ConvertHelperAsync( } // Call the chat API with the new combined prompt to get a response back - OpenAISDK.ChatCompletionsOptions chatCompletionsOptions = new() - { - DeploymentName = attribute.ChatModel, - Messages = + IList messages = new List() { - new OpenAISDK.ChatRequestSystemMessage(promptBuilder.ToString()), - new OpenAISDK.ChatRequestUserMessage(attribute.Query), - } - }; + new SystemChatMessage(promptBuilder.ToString()), + new UserChatMessage(attribute.Query), + }; - Response chatResponse = await this.openAIClient.GetChatCompletionsAsync(chatCompletionsOptions); + ChatCompletionOptions completionOptions = attribute.BuildRequest(); + ClientResult chatResponse = await this.openAIClientFactory.GetChatClient(attribute.AIConnectionName, attribute.ChatModel).CompleteChatAsync(messages, completionOptions); // Give the user the full context, including the embeddings information as well as the chat info - return new SemanticSearchContext(new EmbeddingsContext(embeddingsRequest, embeddingsResponse), chatResponse); + return new SemanticSearchContext(new EmbeddingsContext(new List { attribute.Query }, null), chatResponse); } async Task IAsyncConverter.ConvertAsync( diff --git a/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs b/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs index 2ad00ac8..ab3f2944 
100644 --- a/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs +++ b/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs @@ -1,9 +1,8 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -using Azure.AI.OpenAI; using Microsoft.Azure.WebJobs.Description; -using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; +using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI; @@ -12,7 +11,7 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI; /// [Binding] [AttributeUsage(AttributeTargets.Parameter)] -public sealed class TextCompletionAttribute : Attribute +public sealed class TextCompletionAttribute : AssistantBaseAttribute { /// /// Initializes a new instance of the class with the specified text prompt. @@ -20,7 +19,9 @@ public sealed class TextCompletionAttribute : Attribute /// The prompt to generate completions for, encoded as a string. public TextCompletionAttribute(string prompt) { - this.Prompt = prompt ?? throw new ArgumentNullException(nameof(prompt)); + this.Prompt = string.IsNullOrEmpty(prompt) + ? throw new ArgumentException("Input cannot be null or empty.", nameof(prompt)) + : prompt; } /// @@ -28,70 +29,4 @@ public TextCompletionAttribute(string prompt) /// [AutoResolve] public string Prompt { get; } - - /// - /// Gets or sets the ID of the model to use. - /// - [AutoResolve] - public string Model { get; set; } = OpenAIModels.DefaultChatModel; - - /// - /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output - /// more random, while lower values like 0.2 will make it more focused and deterministic. - /// - /// - /// It's generally recommend to use this or but not both. - /// - [AutoResolve] - public string? 
Temperature { get; set; } = "0.5"; - - /// - /// Gets or sets an alternative to sampling with temperature, called nucleus sampling, where the model considers - /// the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% - /// probability mass are considered. - /// - /// - /// It's generally recommend to use this or but not both. - /// - [AutoResolve] - public string? TopP { get; set; } - - /// - /// Gets or sets the maximum number of tokens to generate in the completion. - /// - /// - /// The token count of your prompt plus max_tokens cannot exceed the model's context length. - /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096). - /// - [AutoResolve] - public string? MaxTokens { get; set; } = "100"; - - internal ChatCompletionsOptions BuildRequest() - { - ChatCompletionsOptions request = new() - { - DeploymentName = this.Model, - Messages = - { - new ChatRequestUserMessage(this.Prompt), - } - }; - - if (int.TryParse(this.MaxTokens, out int maxTokens)) - { - request.MaxTokens = maxTokens; - } - - if (float.TryParse(this.Temperature, out float temperature)) - { - request.Temperature = temperature; - } - - if (float.TryParse(this.TopP, out float topP)) - { - request.NucleusSamplingFactor = topP; - } - - return request; - } } diff --git a/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs b/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs index b936ea44..6043dec5 100644 --- a/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs +++ b/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. 
-using Azure; -using Azure.AI.OpenAI; +using System.ClientModel; using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; using Microsoft.Extensions.Logging; using Newtonsoft.Json; +using OpenAI.Chat; namespace Microsoft.Azure.WebJobs.Extensions.OpenAI; @@ -13,12 +13,12 @@ class TextCompletionConverter : IAsyncConverter, IAsyncConverter { - readonly OpenAIClient openAIClient; + readonly OpenAIClientFactory openAIClientFactory; readonly ILogger logger; - public TextCompletionConverter(OpenAIClient openAIClient, ILoggerFactory loggerFactory) + public TextCompletionConverter(OpenAIClientFactory openAIClientFactory, ILoggerFactory loggerFactory) { - this.openAIClient = openAIClient ?? throw new ArgumentNullException(nameof(openAIClient)); + this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory)); this.logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory)); } @@ -43,18 +43,22 @@ async Task ConvertCoreAsync( TextCompletionAttribute attribute, CancellationToken cancellationToken) { - ChatCompletionsOptions options = attribute.BuildRequest(); - this.logger.LogInformation("Sending OpenAI completion request: {request}", options); + ChatCompletionOptions options = attribute.BuildRequest(); + this.logger.LogInformation("Sending OpenAI completion request with prompt: {request}", attribute.Prompt); - Response response = await this.openAIClient.GetChatCompletionsAsync( - options, - cancellationToken); + IList chatMessages = new List() + { + new UserChatMessage(attribute.Prompt) + }; + + ClientResult response = await this.openAIClientFactory.GetChatClient( + attribute.AIConnectionName, + attribute.ChatModel).CompleteChatAsync(chatMessages, options, cancellationToken: cancellationToken); string text = string.Join( Environment.NewLine + Environment.NewLine, - response.Value.Choices.Select(choice => choice.Message.Content)); - TextCompletionResponse textCompletionResponse = new(text, 
response.Value.Usage.TotalTokens); - + response.Value.Content[0].Text); + TextCompletionResponse textCompletionResponse = new(text, response.Value.Usage.TotalTokenCount); return textCompletionResponse; } } diff --git a/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj b/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj index 323285c0..0f4fbb48 100644 --- a/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj +++ b/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj @@ -5,13 +5,16 @@ - + - - + - + + + + + \ No newline at end of file diff --git a/tests/SampleValidation/AssistantTests.cs b/tests/SampleValidation/AssistantTests.cs index e6eb44f6..fc53f512 100644 --- a/tests/SampleValidation/AssistantTests.cs +++ b/tests/SampleValidation/AssistantTests.cs @@ -71,7 +71,7 @@ public async Task AddTodoTest() Assert.StartsWith("text/plain", questionResponse.Content.Headers.ContentType?.MediaType); // Ensure that the model responded and mentioned the new todo item. 
- await ValidateAssistantResponseAsync(expectedMessageCount: 4, expectedContent: "Buy milk", hasTotalTokens: true); + await ValidateAssistantResponseAsync(expectedMessageCount: 5, expectedContent: "Buy milk", hasTotalTokens: true); // Local function to validate each chat bot response async Task ValidateAssistantResponseAsync(int expectedMessageCount, string expectedContent, bool hasTotalTokens = false) @@ -107,7 +107,7 @@ async Task ValidateAssistantResponseAsync(int expectedMessageCount, string expec // The timestamp filter should ensure we only ever look at the most recent messages Assert.True(messageArray!.Count <= totalMessages); - Assert.True(messageArray!.Count <= 4); + Assert.True(messageArray!.Count <= 5); if (totalMessages >= expectedMessageCount) { @@ -118,13 +118,23 @@ async Task ValidateAssistantResponseAsync(int expectedMessageCount, string expec { // Make sure the first message is the system message JsonNode systemMessage = messageArray!.First()!; - Assert.Equal("system", systemMessage["role"]?.GetValue()); + Assert.Equal("system", systemMessage["role"]?.GetValue(), StringComparer.OrdinalIgnoreCase); } else { + // Validate the third message contains the toolcalls string with task description + if (messageArray.Count >= 2) + { + JsonNode thirdMessage = messageArray![1]!; + Assert.Equal("assistant", thirdMessage["role"]?.GetValue(), StringComparer.OrdinalIgnoreCase); + string? 
thirdMessageToolCalls = thirdMessage["toolCalls"]?.GetValue(); + Assert.NotNull(thirdMessageToolCalls); + Assert.Contains("AddTodo", thirdMessageToolCalls, StringComparison.OrdinalIgnoreCase); + } + // Make sure that the last message is from the chat bot (assistant) JsonNode lastMessage = messageArray![messageArray.Count - 1]!; - Assert.Equal("assistant", lastMessage["role"]?.GetValue()); + Assert.Equal("assistant", lastMessage["role"]?.GetValue(), StringComparer.OrdinalIgnoreCase); Assert.Contains("Buy milk", lastMessage!["content"]?.GetValue()); } diff --git a/tests/SampleValidation/Chat.cs b/tests/SampleValidation/Chat.cs index e6719e72..386b1f85 100644 --- a/tests/SampleValidation/Chat.cs +++ b/tests/SampleValidation/Chat.cs @@ -125,13 +125,13 @@ async Task ValidateAssistantResponseAsync(int expectedMessageCount, string expec { // Make sure the first message is the system message JsonNode systemMessage = messageArray!.First()!; - Assert.Equal("system", systemMessage["role"]?.GetValue()); + Assert.Equal("system", systemMessage["role"]?.GetValue(), StringComparer.OrdinalIgnoreCase); } else { // Make sure that the last message is from the chat bot (assistant) JsonNode lastMessage = messageArray![messageArray.Count - 1]!; - Assert.Equal("assistant", lastMessage["role"]?.GetValue()); + Assert.Equal("assistant", lastMessage["role"]?.GetValue(), StringComparer.OrdinalIgnoreCase); Assert.StartsWith("Yo!", lastMessage!["content"]?.GetValue()); } diff --git a/tests/SampleValidation/EmbeddingsTests.cs b/tests/SampleValidation/EmbeddingsTests.cs new file mode 100644 index 00000000..d9f215bd --- /dev/null +++ b/tests/SampleValidation/EmbeddingsTests.cs @@ -0,0 +1,147 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ +using System.Diagnostics; +using System.Net; +using System.Net.Http.Json; +using Xunit; +using Xunit.Abstractions; + +namespace SampleValidation; + +public class EmbeddingsTests +{ + readonly ITestOutputHelper output; + readonly string baseAddress; + readonly HttpClient client; + readonly CancellationTokenSource cts; + + public EmbeddingsTests(ITestOutputHelper output) + { + this.output = output; + this.client = new HttpClient(new LoggingHandler(this.output)); + this.cts = new CancellationTokenSource(delay: TimeSpan.FromMinutes(Debugger.IsAttached ? 5 : 1)); + + this.baseAddress = Environment.GetEnvironmentVariable("FUNC_BASE_ADDRESS") ?? "http://localhost:7071"; + +#if RELEASE + // Use the default key for the Azure Functions app in RELEASE mode; for local development, DEBUG mode can be used. + string functionKey = Environment.GetEnvironmentVariable("FUNC_DEFAULT_KEY") ?? throw new InvalidOperationException("Missing environment variable 'FUNC_DEFAULT_KEY'"); + client.DefaultRequestHeaders.Add("x-functions-key", functionKey); +#endif + } + + [Fact] + public async Task GenerateEmbeddings_Http_Request_Test() + { + // Arrange + var request = new + { + rawText = "This is a test for generating embeddings from raw text." 
+ }; + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/embeddings", + request, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.NoContent, response.StatusCode); + } + + [Fact] + public async Task EmbeddingsLegacy_Performance_Test() + { + // Create a large text for testing performance + string largeText = new('A', 10000); // 10KB text + + var request = new + { + rawText = largeText + }; + + // Measure time + var stopwatch = Stopwatch.StartNew(); + + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/embeddings", + request, + cancellationToken: this.cts.Token); + + stopwatch.Stop(); + + // Assert + Assert.Equal(HttpStatusCode.NoContent, response.StatusCode); + + // Log performance metrics + this.output.WriteLine($"Embeddings generation for 10KB text took {stopwatch.ElapsedMilliseconds}ms"); + + // If specific performance SLAs exist, add assertions like: + // Assert.True(stopwatch.ElapsedMilliseconds < 5000, "Embeddings generation took too long"); + } + + [Fact] + public async Task GetEmbeddings_Url_ReturnsNoContent() + { + // Arrange + var request = new { url = "https://github.com/Azure/azure-functions-openai-extension/blob/main/README.md" }; + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/embeddings-from-url", + request, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.NoContent, response.StatusCode); + } + + [Fact] + public async Task GenerateEmbeddings_InvalidText_ReturnsBadRequest() + { + // Arrange + var request = new { url = "" }; // Invalid: Empty text + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/embeddings", + request, + cancellationToken: this.cts.Token); + + // Assert + 
Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + } + + [Fact] + public async Task GetEmbeddings_InvalidFilePath_ReturnsBadRequest() + { + // Arrange + var request = new { filePath = "invalid/file/path.txt" }; // Invalid: Non-existent file path + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/embeddings-from-file", + request, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + } + + [Fact] + public async Task GetEmbeddings_InvalidUrl_ReturnsBadRequest() + { + // Arrange + var request = new { url = "invalid-url" }; // Invalid: Malformed URL + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/embeddings-from-url", + request, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + } +} \ No newline at end of file diff --git a/tests/SampleValidation/FilePromptsTests.cs b/tests/SampleValidation/FilePromptsTests.cs new file mode 100644 index 00000000..8eb01abd --- /dev/null +++ b/tests/SampleValidation/FilePromptsTests.cs @@ -0,0 +1,68 @@ +using System.Diagnostics; +using System.Net; +using System.Net.Http.Json; +using System.Text.Json.Nodes; +using Xunit; +using Xunit.Abstractions; + +namespace SampleValidation; + +public class FilePromptTests +{ + readonly ITestOutputHelper output; + readonly string baseAddress; + readonly HttpClient client; + readonly CancellationTokenSource cts; + + public FilePromptTests(ITestOutputHelper output) + { + this.output = output; + this.client = new HttpClient(new LoggingHandler(this.output)); + this.cts = new CancellationTokenSource(delay: TimeSpan.FromMinutes(Debugger.IsAttached ? 5 : 2)); + + this.baseAddress = Environment.GetEnvironmentVariable("FUNC_BASE_ADDRESS") ?? 
"http://localhost:7071"; + +#if RELEASE + // Use the default key for the Azure Functions app in RELEASE mode; for local development, DEBUG mode can be used. + string functionKey = Environment.GetEnvironmentVariable("FUNC_DEFAULT_KEY") ?? throw new InvalidOperationException("Missing environment variable 'FUNC_DEFAULT_KEY'"); + client.DefaultRequestHeaders.Add("x-functions-key", functionKey); +#endif + } + + [Fact] + public async Task Ingest_Prompt_File_Test() + { + // Step 1: Test IngestFile + var ingestRequest = new { url = "https://github.com/Azure/azure-functions-openai-extension/blob/main/README.md" }; + + using HttpResponseMessage ingestResponse = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/IngestFile", + ingestRequest, + cancellationToken: this.cts.Token); + + Assert.Equal(HttpStatusCode.OK, ingestResponse.StatusCode); + Assert.StartsWith("application/json", ingestResponse.Content.Headers.ContentType?.MediaType); + + string ingestResponseContent = await ingestResponse.Content.ReadAsStringAsync(this.cts.Token); + JsonNode? ingestJsonResponse = JsonNode.Parse(ingestResponseContent); + Assert.NotNull(ingestJsonResponse); + Assert.Equal("success", ingestJsonResponse!["status"]?.GetValue()); + Assert.Equal("README.md", ingestJsonResponse!["title"]?.GetValue()); + + // Step 2: Test PromptFile + var promptRequest = new { prompt = "How can the textCompletion input binding be used from Azure Functions OpenAI extension?" 
}; + + using HttpResponseMessage promptResponse = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/PromptFile", + promptRequest, + cancellationToken: this.cts.Token); + + Assert.Equal(HttpStatusCode.OK, promptResponse.StatusCode); + Assert.StartsWith("text/plain", promptResponse.Content.Headers.ContentType?.MediaType); + + string promptResponseContent = await promptResponse.Content.ReadAsStringAsync(this.cts.Token); + Assert.False(string.IsNullOrWhiteSpace(promptResponseContent)); + Assert.Contains("OpenAI Chat Completions API", promptResponseContent, StringComparison.OrdinalIgnoreCase); + Assert.Contains("README", promptResponseContent, StringComparison.OrdinalIgnoreCase); + } +} \ No newline at end of file diff --git a/tests/SampleValidation/TextCompletionTests.cs b/tests/SampleValidation/TextCompletionTests.cs new file mode 100644 index 00000000..2cb18ac4 --- /dev/null +++ b/tests/SampleValidation/TextCompletionTests.cs @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.Diagnostics; +using System.Net; +using System.Net.Http.Json; +using Xunit; +using Xunit.Abstractions; + +namespace SampleValidation; + +public class TextCompletionLegacyTests +{ + readonly ITestOutputHelper output; + readonly HttpClient client; + readonly string baseAddress; + readonly CancellationTokenSource cts; + + public TextCompletionLegacyTests(ITestOutputHelper output) + { + this.output = output; + this.client = new HttpClient(new LoggingHandler(this.output)); + this.baseAddress = Environment.GetEnvironmentVariable("FUNC_BASE_ADDRESS") ?? "http://localhost:7071"; + this.cts = new CancellationTokenSource(delay: TimeSpan.FromMinutes(Debugger.IsAttached ? 5 : 1)); + +#if RELEASE + // Use the default key for the Azure Functions app in RELEASE mode; for local development, DEBUG mode can be used. + string functionKey = Environment.GetEnvironmentVariable("FUNC_DEFAULT_KEY") ?? 
throw new InvalidOperationException("Missing environment variable 'FUNC_DEFAULT_KEY'"); + this.client.DefaultRequestHeaders.Add("x-functions-key", functionKey); +#endif + } + + [Fact] + public async Task WhoIs_ValidName_ReturnsContent() + { + // Arrange + string name = "Albert Einstein"; + + // Act + using HttpResponseMessage response = await this.client.GetAsync( + requestUri: $"{this.baseAddress}/api/whois/{name}", + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + string content = await response.Content.ReadAsStringAsync(this.cts.Token); + this.output.WriteLine($"Response content: {content}"); + + // Validate that content is not empty and contains some relevant information about Albert Einstein + Assert.NotEmpty(content); + Assert.Contains("physicist", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("relativity", content, StringComparison.OrdinalIgnoreCase); + } + + [Fact] + public async Task GenericCompletion_ValidPrompt_ReturnsContent() + { + // Arrange + var promptPayload = new { Prompt = "Write a haiku about programming" }; + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/GenericCompletion", + promptPayload, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + string content = await response.Content.ReadAsStringAsync(this.cts.Token); + this.output.WriteLine($"Response content: {content}"); + + // Validate that content is not empty and resembles a haiku structure + Assert.NotEmpty(content); + + // Verify we received some form of haiku (might contain line breaks) + // We can't verify exact content since AI responses vary, but we can check for reasonable length + // and common haiku-related formatting + Assert.True(content.Length >= 10, "Response should contain a meaningful haiku"); + Assert.True(content.Split('\n').Length >= 1, "Haiku should have at least 
one line"); + } + + [Fact] + public async Task GenericCompletion_EmptyPrompt_ReturnsError() + { + // Arrange + var emptyPrompt = new { Prompt = string.Empty }; + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/GenericCompletion", + emptyPrompt, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + } + + [Fact] + public async Task GenericCompletion_ComplexPrompt_ReturnsValidContent() + { + // Arrange + var complexPrompt = new + { + Prompt = "Explain the difference between synchronous and asynchronous programming in 3 bullet points" + }; + + // Act + using HttpResponseMessage response = await this.client.PostAsJsonAsync( + requestUri: $"{this.baseAddress}/api/GenericCompletion", + complexPrompt, + cancellationToken: this.cts.Token); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + string content = await response.Content.ReadAsStringAsync(this.cts.Token); + this.output.WriteLine($"Response content: {content}"); + + // Validate that content is not empty and contains relevant terms + Assert.NotEmpty(content); + Assert.Contains("synchronous", content, StringComparison.OrdinalIgnoreCase); + Assert.Contains("asynchronous", content, StringComparison.OrdinalIgnoreCase); + + // Check for bullet point formatting (could be -, *, or numbered) + bool hasBulletFormat = content.Contains('-') || + content.Contains('*') || + content.Contains("1.") || + content.Contains('•'); + + Assert.True(hasBulletFormat, "Response should contain bullet points"); + } +} diff --git a/tests/UnitTests/AssistantServiceTests.cs b/tests/UnitTests/AssistantServiceTests.cs new file mode 100644 index 00000000..eb40bb59 --- /dev/null +++ b/tests/UnitTests/AssistantServiceTests.cs @@ -0,0 +1,472 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ + +using System.ClientModel; +using System.Threading; +using Azure; +using Azure.Core; +using Azure.Data.Tables; +using Azure.Data.Tables.Models; +using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants; +using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models; +using Microsoft.Extensions.Azure; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Moq; +using OpenAI.Chat; +using Xunit; + +namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Tests.Assistants; + +public class DefaultAssistantServiceTests +{ + readonly Mock mockOpenAIClientFactory; + readonly Mock mockAzureComponentFactory; + readonly Mock mockConfiguration; + readonly Mock mockSkillInvoker; + readonly Mock mockLoggerFactory; + readonly Mock> mockLogger; + readonly Mock mockTableServiceClient; + readonly Mock mockTableClient; + + public DefaultAssistantServiceTests() + { + this.mockAzureComponentFactory = new Mock(); + this.mockConfiguration = new Mock(); + this.mockSkillInvoker = new Mock(); + this.mockLoggerFactory = new Mock(); + this.mockLogger = new Mock>(); + this.mockTableServiceClient = new Mock(); + this.mockTableClient = new Mock(); + this.mockOpenAIClientFactory = new Mock( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + this.mockLoggerFactory.Setup(x => x.CreateLogger(It.IsAny())) + .Returns(this.mockLogger.Object); + + // Setup table client + this.mockTableServiceClient.Setup(x => x.GetTableClient(It.IsAny())) + .Returns(this.mockTableClient.Object); + } + + [Fact] + public async Task CreateAssistantAsync_WithValidRequest_CreatesAssistantAndMessages() + { + // Arrange + var request = new AssistantCreateRequest("testId", "Test instructions") + { + CollectionName = "ChatState", + ChatStorageConnectionSetting = "AzureWebJobsStorage" + }; + + var mockQueryResult = new List(); + AsyncPageable mockQueryable = MockAsyncPageable.Create(mockQueryResult); + + this.mockTableClient.Setup(x => 
x.CreateIfNotExistsAsync(It.IsAny())) + .ReturnsAsync(Response.FromValue(new TableItem(request.CollectionName), new Mock().Object)); + + this.mockTableClient.Setup(x => x.QueryAsync( + It.Is(s => s == $"PartitionKey eq '{request.Id}'"), + null, + null, + It.IsAny())) + .Returns(mockQueryable); + + this.mockTableClient.Setup(x => x.SubmitTransactionAsync( + It.IsAny>(), + It.IsAny())) + .ReturnsAsync(Response.FromValue(new List() as IReadOnlyList, new Mock().Object)); + + //// Arrange + //Mock mockSection = CreateMockSection( + // exists: false, + // tableServiceUri: null); + //mockSection.Setup(s => s.Value).Returns("UseDevelopmentStorage=true"); + + //// Setup AzureWebJobsStorage directly + //this.mockConfiguration.Setup(c => c["AzureWebJobsStorage"]).Returns("UseDevelopmentStorage=true"); + //this.mockConfiguration.Setup(c => c.GetSection("AzureWebJobsStorage")).Returns(mockSection.Object); + + // Create the service under test + var assistantService = new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object); + + // Use reflection to set the tableClient field + System.Reflection.FieldInfo? 
tableClientField = typeof(DefaultAssistantService).GetField("tableClient", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + tableClientField?.SetValue(assistantService, this.mockTableClient.Object); + + // Act + await assistantService.CreateAssistantAsync(request, CancellationToken.None); + + // Assert + this.mockTableClient.Verify(x => x.CreateIfNotExistsAsync(It.IsAny()), Times.Once); + + this.mockTableClient.Verify(x => x.QueryAsync( + It.Is(s => s.Contains(request.Id)), + null, + null, + It.IsAny()), Times.Once); + + this.mockTableClient.Verify(x => x.SubmitTransactionAsync( + It.IsAny>(), + It.IsAny()), Times.Once); + } + + [Fact] + public async Task CreateAssistantAsync_WithExistingAssistant_DeletesOldEntitiesFirst() + { + // Arrange + var request = new AssistantCreateRequest("testId", "Test instructions") + { + CollectionName = "ChatState", + ChatStorageConnectionSetting = "AzureWebJobsStorage" + }; + + // Create existing entities + var existingEntities = new List + { + new("testId", "state"), + new("testId", "msg-001") + }; + + AsyncPageable mockQueryable = MockAsyncPageable.Create(existingEntities); + + this.mockTableClient.Setup(x => x.CreateIfNotExistsAsync(It.IsAny())) + .ReturnsAsync(Response.FromValue(new TableItem(request.CollectionName), new Mock().Object)); + + this.mockTableClient.Setup(x => x.QueryAsync( + It.Is(s => s == $"PartitionKey eq '{request.Id}'"), + null, + null, + It.IsAny())) + .Returns(mockQueryable); + + this.mockTableClient.Setup(x => x.SubmitTransactionAsync( + It.IsAny>(), + It.IsAny())) + .ReturnsAsync(Response.FromValue(new List() as IReadOnlyList, new Mock().Object)); + + // Create the service under test + var assistantService = new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object); + + // Use reflection to set the tableClient field + 
System.Reflection.FieldInfo? tableClientField = typeof(DefaultAssistantService).GetField("tableClient", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + tableClientField?.SetValue(assistantService, this.mockTableClient.Object); + + + // Act + await assistantService.CreateAssistantAsync(request, CancellationToken.None); + + // Assert + this.mockTableClient.Verify(x => x.SubmitTransactionAsync( + It.IsAny>(), + It.IsAny()), Times.AtLeastOnce); + + this.mockTableClient.Verify(x => x.SubmitTransactionAsync( + It.Is>( + actions => actions.Any(a => a.ActionType == TableTransactionActionType.Add)), + It.IsAny()), Times.AtLeastOnce); + } + + [Fact] + public async Task GetStateAsync_WithValidId_ReturnsCorrectState() + { + // Arrange + string id = "testId"; + string timestamp = DateTime.UtcNow.AddHours(-1).ToString("o"); + var attribute = new AssistantQueryAttribute(id) + { + TimestampUtc = timestamp, + CollectionName = "testCollection" + }; + + // Create mock entities + var stateEntity = new TableEntity(id, AssistantStateEntity.FixedRowKeyValue) + { + ["Exists"] = true, + ["CreatedAt"] = DateTime.UtcNow.AddDays(-1), + ["LastUpdatedAt"] = DateTime.UtcNow, + ["TotalMessages"] = 2, + ["TotalTokens"] = 100 + }; + + var message1 = new TableEntity(id, "msg-001") + { + ["Content"] = "Test message 1", + ["Role"] = ChatMessageRole.System.ToString(), + ["CreatedAt"] = DateTime.UtcNow.AddMinutes(-30), + ["ToolCalls"] = "" + }; + + var message2 = new TableEntity(id, "msg-002") + { + ["Content"] = "Test message 2", + ["Role"] = ChatMessageRole.Assistant.ToString(), + ["CreatedAt"] = DateTime.UtcNow.AddMinutes(-15), + ["ToolCalls"] = "" + }; + + var mockQueryResult = new List { stateEntity, message1, message2 }; + AsyncPageable mockQueryable = MockAsyncPageable.Create(mockQueryResult); + + this.mockTableClient.Setup(x => x.QueryAsync( + It.Is(s => s == $"PartitionKey eq '{id}'"), + null, + null, + It.IsAny())) + .Returns(mockQueryable); + + // 
Create the service under test + var assistantService = new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object); + + // Use reflection to set the tableClient field + System.Reflection.FieldInfo? tableClientField = typeof(DefaultAssistantService).GetField("tableClient", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + tableClientField?.SetValue(assistantService, this.mockTableClient.Object); + + // Act + AssistantState result = await assistantService.GetStateAsync(attribute, CancellationToken.None); + + // Assert + Assert.Equal(id, result.Id); + Assert.True(result.Exists); + Assert.Equal(2, result.TotalMessages); + Assert.Equal(100, result.TotalTokens); + + // Verify that we filter messages based on timestamp + DateTime parsedTimestamp = DateTime.Parse(Uri.UnescapeDataString(timestamp)).ToUniversalTime(); + Assert.All(result.RecentMessages, msg => + Assert.True(DateTime.UtcNow.AddMinutes(-30) > parsedTimestamp)); + } + + [Fact] + public async Task GetStateAsync_WithNonExistentId_ReturnsEmptyState() + { + // Arrange + string id = "nonExistentId"; + string timestamp = DateTime.UtcNow.AddHours(-1).ToString("o"); + var attribute = new AssistantQueryAttribute(id) + { + TimestampUtc = timestamp, + CollectionName = "testCollection" + }; + + var mockQueryResult = new List { }; + AsyncPageable mockQueryable = MockAsyncPageable.Create(mockQueryResult); + + this.mockTableClient.Setup(x => x.QueryAsync( + It.Is(s => s == $"PartitionKey eq '{id}'"), + null, + null, + It.IsAny())) + .Returns(mockQueryable); + + // Create the service under test + var assistantService = new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object); + + // Use reflection 
to set the tableClient field + System.Reflection.FieldInfo? tableClientField = typeof(DefaultAssistantService).GetField("tableClient", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + tableClientField?.SetValue(assistantService, this.mockTableClient.Object); + + // Act + AssistantState result = await assistantService.GetStateAsync(attribute, CancellationToken.None); + + // Assert + Assert.Equal(id, result.Id); + Assert.False(result.Exists); + Assert.Equal(0, result.TotalMessages); + Assert.Empty(result.RecentMessages); + } + + [Fact] + public void Constructor_WithNullArguments_ThrowsArgumentNullException() + { + // Act & Assert +#nullable disable + Assert.Throws(() => new DefaultAssistantService( + null, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object)); + + Assert.Throws(() => new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + null, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object)); + + Assert.Throws(() => new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + null, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object)); + + Assert.Throws(() => new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + null, + this.mockLoggerFactory.Object)); + + Assert.Throws(() => new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + null)); +#nullable restore + } + + [Fact] + public async Task PostMessageAsync_WithNonExistentAssistant_ReturnsEmptyState() + { + // Arrange + string assistantId = "nonExistentId"; + var attribute = new AssistantPostAttribute(assistantId, "Hello") + { + CollectionName = "ChatState", + 
ChatStorageConnectionSetting = "AzureWebJobsStorage" + }; + + var mockQueryResult = new List(); + AsyncPageable mockQueryable = MockAsyncPageable.Create(mockQueryResult); + + this.mockTableClient.Setup(x => x.QueryAsync( + It.Is(s => s == $"PartitionKey eq '{assistantId}'"), + null, + null, + It.IsAny())) + .Returns(mockQueryable); + + // Create the service under test + var assistantService = new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object); + + // Use reflection to set the tableClient field + System.Reflection.FieldInfo? tableClientField = typeof(DefaultAssistantService).GetField("tableClient", + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance); + tableClientField?.SetValue(assistantService, this.mockTableClient.Object); + + // Act + AssistantState result = await assistantService.PostMessageAsync(attribute, CancellationToken.None); + + // Assert + Assert.Equal(assistantId, result.Id); + Assert.False(result.Exists); + Assert.Equal(0, result.TotalMessages); + Assert.Empty(result.RecentMessages); + } + + [Fact] + public async Task PostMessageAsync_WithInvalidAttributes_ThrowsArgumentException() + { + // Arrange - Missing user message + var attributeNoMessage = new AssistantPostAttribute("testId", "") + { + CollectionName = "ChatState", + ChatStorageConnectionSetting = "AzureWebJobsStorage" + }; + + // Create the service under test + var assistantService = new DefaultAssistantService( + this.mockOpenAIClientFactory.Object, + this.mockAzureComponentFactory.Object, + this.mockConfiguration.Object, + this.mockSkillInvoker.Object, + this.mockLoggerFactory.Object); + + // Act & Assert - Empty message + await Assert.ThrowsAsync(() => + assistantService.PostMessageAsync(attributeNoMessage, CancellationToken.None)); + + // Arrange - Missing ID + var attributeNoId = new 
AssistantPostAttribute("", "Hello") + { + CollectionName = "ChatState", + ChatStorageConnectionSetting = "AzureWebJobsStorage" + }; + + // Act & Assert - Empty ID + await Assert.ThrowsAsync(() => + assistantService.PostMessageAsync(attributeNoId, CancellationToken.None)); + } + + static Mock CreateMockSection(bool exists, string? tableServiceUri = null) + { + var mockSection = new Mock(); + + if (!exists) + { + mockSection.Setup(s => s.Value).Returns(""); + mockSection.Setup(s => s.GetChildren()).Returns([]); + } + else + { + mockSection.Setup(s => s.Value).Returns("some-value"); + } + + var endpointSection = new Mock(); +#nullable disable + endpointSection.Setup(s => s.Value).Returns(tableServiceUri); +#nullable restore + mockSection.Setup(s => s.GetSection("tableServiceUri")).Returns(endpointSection.Object); + + return mockSection; + } +} + +// Update the MockAsyncPageable class to ensure TableEntity satisfies the 'notnull' constraint +public class MockAsyncPageable : AsyncPageable where T : notnull +{ + readonly IEnumerable _items; + + MockAsyncPageable(IEnumerable items) + { + this._items = items; + } + + public static AsyncPageable Create(IEnumerable items) + { + return new MockAsyncPageable(items); + } + + public override async IAsyncEnumerable> AsPages( + string? continuationToken = null, + int? pageSizeHint = null) + { + await Task.Yield(); + yield return Page.FromValues([.. this._items], null, new Mock().Object); + } +} diff --git a/tests/UnitTests/OpenAIClientFactoryTests.cs b/tests/UnitTests/OpenAIClientFactoryTests.cs new file mode 100644 index 00000000..37c5ba72 --- /dev/null +++ b/tests/UnitTests/OpenAIClientFactoryTests.cs @@ -0,0 +1,383 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. 
+ + +using System.Reflection; +using Azure.Core; +using Microsoft.Azure.WebJobs.Extensions.OpenAI; +using Microsoft.Extensions.Azure; +using Microsoft.Extensions.Configuration; +using Microsoft.Extensions.Logging; +using Moq; +using OpenAI.Chat; +using OpenAI.Embeddings; +using Xunit; + +namespace WebJobsOpenAIUnitTests; + +public class OpenAIClientFactoryTests +{ + readonly Mock mockConfiguration; + readonly Mock mockAzureComponentFactory; + readonly Mock mockLoggerFactory; + readonly Mock> mockLogger; + + public OpenAIClientFactoryTests() + { + this.mockConfiguration = new Mock(); + this.mockAzureComponentFactory = new Mock(); + this.mockLoggerFactory = new Mock(); + this.mockLogger = new Mock>(); + this.mockLoggerFactory + .Setup(f => f.CreateLogger(It.Is(s => s == typeof(OpenAIClientFactory).FullName))) + .Returns(this.mockLogger.Object); + } + + [Fact] + public void Constructor_NullConfiguration_ThrowsArgumentNullException() + { + // Arrange + IConfiguration? configuration = null; + +#nullable disable + // Act & Assert + ArgumentNullException exception = Assert.Throws(() => + new OpenAIClientFactory(configuration, this.mockAzureComponentFactory.Object, this.mockLoggerFactory.Object)); +#nullable restore + Assert.Equal("configuration", exception.ParamName); + } + + [Fact] + public void Constructor_NullAzureComponentFactory_ThrowsArgumentNullException() + { + // Arrange + AzureComponentFactory? azureComponentFactory = null; + +#nullable disable + // Act & Assert + ArgumentNullException exception = Assert.Throws(() => + new OpenAIClientFactory(this.mockConfiguration.Object, azureComponentFactory, this.mockLoggerFactory.Object)); +#nullable restore + Assert.Equal("azureComponentFactory", exception.ParamName); + } + + [Fact] + public void Constructor_NullLoggerFactory_ThrowsArgumentNullException() + { + // Arrange + ILoggerFactory? 
loggerFactory = null; + + // Act & Assert +#nullable disable + ArgumentNullException exception = Assert.Throws(() => + new OpenAIClientFactory(this.mockConfiguration.Object, this.mockAzureComponentFactory.Object, loggerFactory)); +#nullable restore + Assert.Equal("loggerFactory", exception.ParamName); + } + + [Fact] + public void GetChatClient_WithAzureOpenAI_ReturnsClient() + { + // Arrange + Mock mockSection = CreateMockSection( + exists: true, + endpoint: "https://test-endpoint.openai.azure.com/", + key: "test-key"); + + this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object); + + // Clear environment variable if set + string? originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + try + { + Environment.SetEnvironmentVariable("OPENAI_API_KEY", null); + + OpenAIClientFactory factory = new( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + // Act + ChatClient chatClient = factory.GetChatClient("TestConnection", "gpt-35-turbo"); + + // Assert + Assert.NotNull(chatClient); + } + finally + { + // Restore original environment value + Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue); + } + } + + [Fact] + public void GetChatClient_WithOpenAIKey_ReturnsClient() + { + // Arrange + string? 
originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + try + { + Environment.SetEnvironmentVariable("OPENAI_API_KEY", "test-openai-key"); + + OpenAIClientFactory factory = new( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + // Act + ChatClient chatClient = factory.GetChatClient("TestConnection", "gpt-4"); + + // Assert + Assert.NotNull(chatClient); + } + finally + { + // Restore original environment value + Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue); + } + } + + [Fact] + public void GetEmbeddingClient_WithAzureOpenAI_ReturnsClient() + { + // Arrange + Mock mockSection = CreateMockSection( + exists: true, + endpoint: "https://test-endpoint.openai.azure.com/", + key: "test-key"); + + this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object); + + // Clear environment variable if set + string? originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + try + { + Environment.SetEnvironmentVariable("OPENAI_API_KEY", null); + + OpenAIClientFactory factory = new( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + // Act + EmbeddingClient embeddingClient = factory.GetEmbeddingClient("TestConnection", "text-embedding-ada-002"); + + // Assert + Assert.NotNull(embeddingClient); + } + finally + { + // Restore original environment value + Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue); + } + } + + [Fact] + public void GetEmbeddingClient_WithOpenAIKey_ReturnsClient() + { + // Arrange + string? 
originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + try + { + Environment.SetEnvironmentVariable("OPENAI_API_KEY", "test-openai-key"); + + OpenAIClientFactory factory = new( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + // Act + EmbeddingClient embeddingClient = factory.GetEmbeddingClient("TestConnection", "text-embedding-ada-002"); + + // Assert + Assert.NotNull(embeddingClient); + } + finally + { + // Restore original environment value + Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue); + } + } + + [Fact] + public void HasOpenAIKey_KeyExists_SetsHasKeyToTrue() + { + // Arrange + string? originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + try + { + Environment.SetEnvironmentVariable("OPENAI_API_KEY", "test-key"); + + // Use reflection to access private static method + MethodInfo? methodInfo = typeof(OpenAIClientFactory).GetMethod("HasOpenAIKey", + BindingFlags.NonPublic | BindingFlags.Static); + + // Act +#nullable disable + object[] parameters = [false, null]; +#nullable restore + methodInfo?.Invoke(null, parameters); + + // Assert + Assert.True((bool)parameters[0]); + Assert.Equal("test-key", parameters[1]); + } + finally + { + // Restore original environment value + Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue); + } + } + + [Fact] + public void HasOpenAIKey_NoKeyExists_SetsHasKeyToFalse() + { + // Arrange + string? originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY"); + try + { + Environment.SetEnvironmentVariable("OPENAI_API_KEY", null); + + // Use reflection to access private static method + MethodInfo? 
methodInfo = typeof(OpenAIClientFactory).GetMethod("HasOpenAIKey", + BindingFlags.NonPublic | BindingFlags.Static); + + // Act + object[] parameters = [true, "some-value"]; + methodInfo?.Invoke(null, parameters); + + // Assert + Assert.False((bool)parameters[0]); + Assert.Null((string)parameters[1]); + } + finally + { + // Restore original environment value + Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue); + } + } + + [Fact] + public void CreateClientFromConfigSection_MissingEndpoint_ThrowsInvalidOperationException() + { + // Arrange + Mock mockSection = CreateMockSection( + exists: true, + endpoint: null, + key: "test-key"); + + this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object); + + // Clear environment variable if set + string? originalEndpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); + try + { + Environment.SetEnvironmentVariable("AZURE_OPENAI_ENDPOINT", null); + + OpenAIClientFactory factory = new( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + // Act & Assert + InvalidOperationException exception = Assert.Throws(() => + factory.GetChatClient("TestConnection", "gpt-35-turbo")); + Assert.Contains("missing required 'Endpoint'", exception.Message); + } + finally + { + // Restore original environment value + Environment.SetEnvironmentVariable("AZURE_OPENAI_ENDPOINT", originalEndpoint); + } + } + + [Fact] + public void CreateClientFromConfigSection_WithTokenCredential_ReturnsClient() + { + // Arrange + Mock mockSection = CreateMockSection( + exists: true, + endpoint: "https://test-endpoint.openai.azure.com/", + key: null); + + this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object); + + TokenCredential tokenCredential = new Mock().Object; + this.mockAzureComponentFactory.Setup(f => f.CreateTokenCredential(It.IsAny())) + .Returns(tokenCredential); + + OpenAIClientFactory factory = 
new( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + // Act + ChatClient chatClient = factory.GetChatClient("TestConnection", "gpt-35-turbo"); + + // Assert + Assert.NotNull(chatClient); + this.mockAzureComponentFactory.Verify(f => f.CreateTokenCredential(It.IsAny()), Times.Once); + } + + [Fact] + public void CreateClientFromConfigSection_WithEnvironmentVariables_UsesEnvironmentValues() + { + // Arrange + Mock mockSection = CreateMockSection( + exists: false, + endpoint: null, + key: null); + this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object); + + string? originalEndpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT"); + string? originalKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_KEY"); + try + { + Environment.SetEnvironmentVariable("AZURE_OPENAI_ENDPOINT", "https://env-endpoint.openai.azure.com/"); + Environment.SetEnvironmentVariable("AZURE_OPENAI_KEY", "env-key"); + + OpenAIClientFactory factory = new( + this.mockConfiguration.Object, + this.mockAzureComponentFactory.Object, + this.mockLoggerFactory.Object); + + // Act + ChatClient chatClient = factory.GetChatClient("TestConnection", "gpt-35-turbo"); + + // Assert + Assert.NotNull(chatClient); + } + finally + { + // Restore original environment values + Environment.SetEnvironmentVariable("AZURE_OPENAI_ENDPOINT", originalEndpoint); + Environment.SetEnvironmentVariable("AZURE_OPENAI_KEY", originalKey); + } + } + + static Mock CreateMockSection(bool exists, string? endpoint = null, string? 
key = null) + { + var mockSection = new Mock(); + + if (!exists) + { + mockSection.Setup(s => s.Value).Returns(""); + mockSection.Setup(s => s.GetChildren()).Returns([]); + } + else + { + mockSection.Setup(s => s.Value).Returns("some-value"); + } + +#nullable disable + var endpointSection = new Mock(); + endpointSection.Setup(s => s.Value).Returns(endpoint); + mockSection.Setup(s => s.GetSection("Endpoint")).Returns(endpointSection.Object); + var keySection = new Mock(); + keySection.Setup(s => s.Value).Returns(key); +#nullable restore + mockSection.Setup(s => s.GetSection("Key")).Returns(keySection.Object); + + return mockSection; + } +} \ No newline at end of file diff --git a/tests/UnitTests/WebJobsOpenAIUnitTests.csproj b/tests/UnitTests/WebJobsOpenAIUnitTests.csproj new file mode 100644 index 00000000..d3c1b04b --- /dev/null +++ b/tests/UnitTests/WebJobsOpenAIUnitTests.csproj @@ -0,0 +1,29 @@ + + + + net8.0 + enable + enable + + false + + + + + + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + runtime; build; native; contentfiles; analyzers; buildtransitive + all + + + + + + + +