diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6b54278d..5b9c0d31 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -11,9 +11,22 @@ Starting v0.1.0 for Microsoft.Azure.WebJobs.Extensions.OpenAI.AzureAISearch, it
## v0.19.0 - Unreleased
+### Breaking
+
+- Renamed model properties to `chatModel` and `embeddingsModel` in AssistantPost, Embeddings, and TextCompletion bindings.
+- Renamed `connectionName` to `searchConnectionName` in SemanticSearch binding.
+- Renamed `connectionName` to `storeConnectionName` in EmbeddingsStore binding.
+- Renamed the `ChatMessage` entity to `AssistantMessage`.
+- Added managed identity support through a configuration section and the `aiConnectionName` binding parameter in AssistantPost, Embeddings, EmbeddingsStore, SemanticSearch, and TextCompletion bindings.
+
### Changed
-- Updated Azure.Data.Tables from 12.9.1 to 12.10.0, Azure.Identity from 1.12.1 to 1.13.2, Microsoft.Extensions.Azure from 1.7.5 to 1.8.0
+- Updated Azure.AI.OpenAI from 1.0.0-beta.15 to 2.1.0
+- Updated Azure.Data.Tables from 12.9.1 to 12.10.0, Azure.Identity from 1.12.1 to 1.13.2, Microsoft.Extensions.Azure from 1.7.5 to 1.10.0
+
+### Added
+
+- Introduced an experimental `isReasoningModel` property to support reasoning models. Setting `max_completion_tokens` and `reasoning_effort` is not yet supported by the current underlying Azure.AI.OpenAI SDK.
## v0.18.0 - 2024/10/08
diff --git a/OpenAI-Extension.sln b/OpenAI-Extension.sln
index 6c86f9cc..2fee6c2d 100644
--- a/OpenAI-Extension.sln
+++ b/OpenAI-Extension.sln
@@ -118,6 +118,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "KustoSearchLegacy", "sample
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TextCompletionLegacy", "samples\textcompletion\csharp-legacy\TextCompletionLegacy.csproj", "{73C26271-7B8D-4B38-B454-D9902B92270F}"
EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "WebJobsOpenAIUnitTests", "tests\UnitTests\WebJobsOpenAIUnitTests.csproj", "{52337999-5676-27FD-AE3D-CAAE406AE337}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -216,6 +218,10 @@ Global
{73C26271-7B8D-4B38-B454-D9902B92270F}.Debug|Any CPU.Build.0 = Debug|Any CPU
{73C26271-7B8D-4B38-B454-D9902B92270F}.Release|Any CPU.ActiveCfg = Release|Any CPU
{73C26271-7B8D-4B38-B454-D9902B92270F}.Release|Any CPU.Build.0 = Release|Any CPU
+ {52337999-5676-27FD-AE3D-CAAE406AE337}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {52337999-5676-27FD-AE3D-CAAE406AE337}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {52337999-5676-27FD-AE3D-CAAE406AE337}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {52337999-5676-27FD-AE3D-CAAE406AE337}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -258,6 +264,7 @@ Global
{BB20B2C0-35AE-CC75-61CE-C2713C83E4F6} = {4F208B03-C74C-458B-8300-9FF82F3C6325}
{C5ACB61F-CDAC-E93F-895E-4D46F086DE61} = {B87CBFFB-8221-45E8-8631-4BB1685D50F0}
{73C26271-7B8D-4B38-B454-D9902B92270F} = {5315AE81-BC33-49F2-935B-69287BB44CBD}
+ {52337999-5676-27FD-AE3D-CAAE406AE337} = {B99948E2-9853-4D9C-B784-19C2503CDA8B}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {05A7A679-3A53-45B5-AE93-4313655E127D}
diff --git a/README.md b/README.md
index 537adecc..2294b0cb 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
[](https://opensource.org/licenses/MIT)
[](https://dev.azure.com/azfunc/Azure%20Functions/_build/latest?definitionId=303&branchName=main)
-This project adds support for [OpenAI](https://platform.openai.com/) LLM (GPT-3.5-turbo, GPT-4) bindings in [Azure Functions](https://azure.microsoft.com/products/functions/).
+This project adds support for [OpenAI](https://platform.openai.com/) LLM (GPT-3.5-turbo, GPT-4, o-series) bindings in [Azure Functions](https://azure.microsoft.com/products/functions/).
This extension depends on the [Azure AI OpenAI SDK](https://github.com/Azure/azure-sdk-for-net/tree/main/sdk/openai/Azure.AI.OpenAI).
@@ -29,18 +29,79 @@ Add following section to `host.json` of the function app for non dotnet language
## Requirements
-* [.NET 6 SDK](https://dotnet.microsoft.com/download/dotnet/6.0) or greater (Visual Studio 2022 recommended)
+* [.NET 8 SDK](https://dotnet.microsoft.com/download/dotnet/8.0) or greater (Visual Studio 2022 recommended)
* [Azure Functions Core Tools v4.x](https://learn.microsoft.com/azure/azure-functions/functions-run-local?tabs=v4%2Cwindows%2Cnode%2Cportal%2Cbash)
+* Azure Storage emulator such as [Azurite](https://learn.microsoft.com/azure/storage/common/storage-use-azurite) running in the background
+* The target language runtime (e.g. dotnet, nodejs, powershell, python, java) installed on your machine. Refer to the officially supported versions.
* Update settings in Azure Function or the `local.settings.json` file for local development with the following keys:
- 1. For Azure, `AZURE_OPENAI_ENDPOINT` - [Azure OpenAI resource](https://learn.microsoft.com/azure/ai-services/openai/how-to/create-resource?pivots=web-portal) (e.g. `https://***.openai.azure.com/`) set.
- 1. For Azure, assign the user or function app managed identity `Cognitive Services OpenAI User` role on the Azure OpenAI resource. It is strongly recommended to use managed identity to avoid overhead of secrets maintenance, however if there is a need for key based authentication add the setting `AZURE_OPENAI_KEY` and its value in the settings.
- 1. For non- Azure, `OPENAI_API_KEY` - An OpenAI account and an [API key](https://platform.openai.com/account/api-keys) saved into a setting.
- If using environment variables, Learn more in [.env readme](./env/README.md).
1. Update `CHAT_MODEL_DEPLOYMENT_NAME` and `EMBEDDING_MODEL_DEPLOYMENT_NAME` keys to Azure Deployment names or override default OpenAI model names.
- 1. If using user assigned managed identity, add `AZURE_CLIENT_ID` to environment variable settings with value as client id of the managed identity.
1. Visit binding specific samples README for additional settings that might be required for each binding.
-* Azure Storage emulator such as [Azurite](https://learn.microsoft.com/azure/storage/common/storage-use-azurite) running in the background
-* The target language runtime (e.g. dotnet, nodejs, powershell, python, java) installed on your machine. Refer the official supported versions.
+  1. Refer to [Configuring AI Service Connections](#configuring-ai-service-connections) for the connection settings.
+
+## Configuring AI Service Connections
+
+The Azure Functions OpenAI Extension provides flexible options for configuring connections to AI services through the `AIConnectionName` property in the AssistantPost, TextCompletion, SemanticSearch, EmbeddingsStore, and Embeddings bindings.
+
+### Managed Identity Role
+
+It is strongly recommended to use managed identity. Ensure the user or the function app's managed identity is assigned the `Cognitive Services OpenAI User` role on the Azure OpenAI resource.
+
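+One way to assign this role is with the Azure CLI; the sketch below uses placeholders for the identity's principal ID and the Azure OpenAI resource ID:
+
+```bash
+az role assignment create \
+  --role "Cognitive Services OpenAI User" \
+  --assignee "<principal-id-of-user-or-managed-identity>" \
+  --scope "<resource-id-of-azure-openai-resource>"
+```
+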
+### AIConnectionName Property
+
+The optional `AIConnectionName` property specifies the name of a configuration section that contains connection details for the AI service:
+
+#### For Azure OpenAI Service
+
+* If specified, the extension looks for `Endpoint` and `Key` values in the named configuration section
+* If not specified or the configuration section doesn't exist, the extension falls back to environment variables:
+ * `AZURE_OPENAI_ENDPOINT` and/or
+ * `AZURE_OPENAI_KEY`
+* For user-assigned managed identity authentication, a configuration section is required:
+
+ ```json
+ "__endpoint": "Placeholder for the Azure OpenAI endpoint value",
+ "__credential": "managedidentity",
+ "__managedIdentityResourceId": "Resource Id of managed identity",
+ "__clientId": "Client Id of managed identity"
+ ```
+
+  * Only one of `managedIdentityResourceId` or `clientId` should be specified, not both.
+  * If neither a resource ID nor a client ID is specified, the system-assigned managed identity is used by default.
+  * Pass the configured connection name prefix (for example, `AzureOpenAI`) to the `AIConnectionName` property.
+
+#### For OpenAI Service (non-Azure)
+
+* Set the `OPENAI_API_KEY` environment variable (see the example below).
+
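+For local development, this typically goes in the `Values` section of `local.settings.json`; the value below is a placeholder:
+
+```json
+"OPENAI_API_KEY": "Placeholder for your OpenAI API key"
+```
+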
+### Configuration Examples
+
+#### Example: Using a configuration section
+
+In `local.settings.json` or app environment variables:
+
+```json
+"AzureOpenAI__endpoint": "Placeholder for the Azure OpenAI endpoint value",
+"AzureOpenAI__credential": "managedidentity",
+```
+
+Specifying `credential` is optional for system-assigned managed identity.
+
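+For user-assigned managed identity, a configuration section might look like the following sketch; the values are placeholders, and only one of `clientId` or `managedIdentityResourceId` should be set:
+
+```json
+"AzureOpenAI__endpoint": "Placeholder for the Azure OpenAI endpoint value",
+"AzureOpenAI__credential": "managedidentity",
+"AzureOpenAI__clientId": "Placeholder for the client ID of the user-assigned managed identity"
+```
+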
+Function usage example:
+
+```csharp
+[Function(nameof(PostUserResponse))]
+public static IActionResult PostUserResponse(
+ [HttpTrigger(AuthorizationLevel.Function, "post", Route = "chats/{chatId}")] HttpRequestData req,
+ string chatId,
+ [AssistantPostInput("{chatId}", "{Query.message}", AIConnectionName = "AzureOpenAI", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state)
+{
+ return new OkObjectResult(state.RecentMessages.LastOrDefault()?.Content ?? "No response returned.");
+}
+```
+
+## Using Reasoning Models
+
+If using reasoning models, set the `IsReasoningModel` property to `true` in the `AssistantPost`, `SemanticSearch`, and `TextCompletion` bindings. This is required because reasoning models expect a different set of request properties.
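+
+A minimal sketch for the isolated worker model, assuming `%CHAT_MODEL_DEPLOYMENT_NAME%` points to a reasoning model deployment (for example, an o-series model):
+
+```csharp
+[Function(nameof(WhoIs))]
+public static string WhoIs(
+    [HttpTrigger(AuthorizationLevel.Function, "get", Route = "whois/{name}")] HttpRequestData req,
+    // IsReasoningModel switches the request to the properties reasoning models expect.
+    [TextCompletionInput("Who is {name}?", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", IsReasoningModel = true)] TextCompletionResponse response)
+{
+    return response.Content;
+}
+```
+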
## Features
@@ -270,7 +331,7 @@ public static async Task IngestFile(
public class EmbeddingsStoreOutputResponse
{
- [EmbeddingsStoreOutput("{url}", InputType.Url, "AISearchEndpoint", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")]
+ [EmbeddingsStoreOutput("{url}", InputType.Url, "AISearchEndpoint", "openai-index", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")]
public required SearchableDocument SearchableDocument { get; init; }
public IActionResult? HttpResponse { get; set; }
diff --git a/eng/ci/templates/build-local.yml b/eng/ci/templates/build-local.yml
index a12251f8..b11122e1 100644
--- a/eng/ci/templates/build-local.yml
+++ b/eng/ci/templates/build-local.yml
@@ -20,8 +20,13 @@ jobs:
dotnet build $(System.DefaultWorkingDirectory)/src/WebJobs.Extensions.OpenAI.Kusto/WebJobs.Extensions.OpenAI.Kusto.csproj --configuration $(config) -p:Version=$(fakeWebJobsPackageVersion) -p:AzureAISearchVersion=$(fakeWebJobsPackageVersion) -p:KustoVersion=$(fakeWebJobsPackageVersion)
dotnet build $(System.DefaultWorkingDirectory)/src/WebJobs.Extensions.OpenAI.AzureAISearch/WebJobs.Extensions.OpenAI.AzureAISearch.csproj --configuration $(config) -p:Version=$(fakeWebJobsPackageVersion) -p:AzureAISearchVersion=$(fakeWebJobsPackageVersion)
dotnet build $(System.DefaultWorkingDirectory)/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/WebJobs.Extensions.OpenAI.CosmosDBSearch.csproj --configuration $(config) -p:Version=$(fakeWebJobsPackageVersion) -p:CosmosDBSearchVersion=$(fakeWebJobsPackageVersion)
+ dotnet build $(System.DefaultWorkingDirectory)/tests/UnitTests/WebJobsOpenAIUnitTests.csproj --configuration $(config) -p:WebJobsVersion=$(fakeWebJobsPackageVersion) -p:Version=$(fakeWebJobsPackageVersion)
displayName: Dotnet Build WebJobs.Extensions.OpenAI
+ - script: |
+ dotnet test $(System.DefaultWorkingDirectory)/tests/UnitTests/WebJobsOpenAIUnitTests.csproj --configuration $(config) --collect "Code Coverage" --no-build
+ displayName: Dotnet Test WebJobsOpenAIUnitTests
+
- task: CopyFiles@2
displayName: 'Copy NuGet WebJobs.Extensions.OpenAI to local directory'
inputs:
diff --git a/java-library/CHANGELOG.md b/java-library/CHANGELOG.md
index 8e928b10..ba912f4d 100644
--- a/java-library/CHANGELOG.md
+++ b/java-library/CHANGELOG.md
@@ -7,9 +7,21 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## v0.5.0 - Unreleased
+### Breaking
+
+- Renamed model properties to `chatModel` and `embeddingsModel` in assistantPost, embeddings, and textCompletion bindings.
+- Renamed `connectionName` to `searchConnectionName` in semanticSearch binding.
+- Renamed `connectionName` to `storeConnectionName` in embeddingsStore binding.
+- Renamed `ChatMessage` to `AssistantMessage`.
+- Added managed identity support through a configuration section and the `aiConnectionName` binding parameter in assistantPost, embeddings, embeddingsStore, semanticSearch, and textCompletion bindings.
+
### Changed
-- Update azure-ai-openai from 1.0.0-beta.11 to 1.0.0-beta.14
+- Updated azure-ai-openai from 1.0.0-beta.11 to 1.0.0-beta.16
+
+### Added
+
+- Introduced an experimental `isReasoningModel` property to support reasoning models. Setting `max_completion_tokens` and `reasoning_effort` is not yet supported by the current underlying Azure AI OpenAI SDK.
## v0.4.0 - 2024/10/08
diff --git a/java-library/pom.xml b/java-library/pom.xml
index bf26b3d7..221bf968 100644
--- a/java-library/pom.xml
+++ b/java-library/pom.xml
@@ -88,7 +88,7 @@
        <groupId>com.azure</groupId>
        <artifactId>azure-ai-openai</artifactId>
-       <version>1.0.0-beta.14</version>
+       <version>1.0.0-beta.16</version>
        <scope>compile</scope>
diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/ChatMessage.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantMessage.java
similarity index 55%
rename from java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/ChatMessage.java
rename to java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantMessage.java
index 1d246201..11627eda 100644
--- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/ChatMessage.java
+++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantMessage.java
@@ -11,35 +11,23 @@
* Chat Message Entity which contains the content of the message, the role of the chat agent, and the name of the calling function if applicable.
*
*/
-public class ChatMessage {
+public class AssistantMessage {
private String content;
private String role;
- private String name;
+ private String toolCalls;
/**
- * Initializes a new instance of the ChatMessage class.
+ * Initializes a new instance of the AssistantMessage class.
*
- * @param content The content of the message.
- * @param role The role of the chat agent.
- */
- public ChatMessage(String content, String role) {
- this.content = content;
- this.role = role;
- }
-
-
- /**
- * Initializes a new instance of the ChatMessage class.
- *
- * @param content The content of the message.
- * @param role The role of the chat agent.
- * @param name The name of the calling function if applicable.
+ * @param content The content of the message.
+ * @param role The role of the chat agent.
+     * @param toolCalls The tool calls made by the assistant, if applicable.
*/
- public ChatMessage(String content, String role, String name) {
+ public AssistantMessage(String content, String role, String toolCalls) {
this.content = content;
this.role = role;
- this.name = name;
+ this.toolCalls = toolCalls;
}
/**
@@ -79,21 +67,20 @@ public void setRole(String role) {
}
/**
- * Gets the name of the calling function if applicable.
+     * Gets the tool calls made by the assistant, if applicable.
      *
-     * @return The name of the calling function if applicable.
+     * @return The tool calls made by the assistant, if applicable.
*/
- public String getName() {
- return name;
+ public String getToolCalls() {
+ return toolCalls;
}
/**
- * Sets the name of the calling function if applicable.
+     * Sets the tool calls made by the assistant, if applicable.
      *
-     * @param name The name of the calling function if applicable.
+     * @param toolCalls The tool calls made by the assistant, if applicable.
*/
- public void setName(String name) {
- this.name = name;
+ public void setToolCalls(String toolCalls) {
+ this.toolCalls = toolCalls;
}
}
-
diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java
index 67a3cc3f..16be1347 100644
--- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java
+++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantPost.java
@@ -7,6 +7,7 @@
package com.microsoft.azure.functions.openai.annotation.assistant;
import com.microsoft.azure.functions.annotation.CustomBinding;
+import com.microsoft.azure.functions.openai.constants.ModelDefaults;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
@@ -61,7 +62,7 @@
*
* @return The OpenAI chat model to use.
*/
- String model();
+ String chatModel() default ModelDefaults.DEFAULT_CHAT_MODEL;
/**
* The user message that user has entered for assistant to respond to.
@@ -70,6 +71,14 @@
*/
String userMessage();
+ /**
+ * The name of the configuration section for AI service connectivity settings.
+ *
+ * @return The name of the configuration section for AI service connectivity
+ * settings.
+ */
+ String aiConnectionName() default "";
+
/**
* The configuration section name for the table settings for assistant chat
* storage.
@@ -86,4 +95,46 @@
* returns {@code DEFAULT_COLLECTION}.
*/
String collectionName() default DEFAULT_COLLECTION;
+
+ /**
+ * The sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ * make the output
+ * more random, while lower values like 0.2 will make it more focused and
+ * deterministic.
+ * It's generally recommended to use this or {@link #topP()} but not both.
+ *
+ * @return The sampling temperature value.
+ */
+ String temperature() default "0.5";
+
+ /**
+ * An alternative to sampling with temperature, called nucleus sampling, where
+ * the model considers
+ * the results of the tokens with top_p probability mass. So 0.1 means only the
+ * tokens comprising the top 10%
+ * probability mass are considered.
+ * It's generally recommended to use this or {@link #temperature()} but not
+ * both.
+ *
+ * @return The topP value.
+ */
+ String topP() default "";
+
+ /**
+ * The maximum number of tokens to generate in the completion.
+ * The token count of your prompt plus max_tokens cannot exceed the model's
+ * context length.
+ * Most models have a context length of 2048 tokens (except for the newest
+ * models, which support 4096).
+ *
+ * @return The maxTokens value.
+ */
+ String maxTokens() default "100";
+
+ /**
+ * Indicates whether the assistant uses a reasoning model.
+ *
+ * @return {@code true} if the assistant is based on a reasoning model; {@code false} otherwise.
+ */
+ boolean isReasoningModel() default false;
}
diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java
index 610a1fcd..36f67dbb 100644
--- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java
+++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantSkillTrigger.java
@@ -14,54 +14,46 @@
import java.lang.annotation.Target;
/**
- *
- * Assistant skill trigger attribute.
- *
- *
- * @since 1.0.0
- */
- @Retention(RetentionPolicy.RUNTIME)
- @Target(ElementType.PARAMETER)
- @CustomBinding(direction = "in", name = "", type = "assistantSkillTrigger")
- public @interface AssistantSkillTrigger {
-
- /**
- * The variable name used in function.json.
- *
- * @return The variable name used in function.json.
- */
- String name();
-
- /**
- * The name of the function to be invoked by the assistant.
- *
- * @return The name of the function to be invoked by the assistant.
- */
- String functionName() default "";
-
- /**
- * The description of the assistant function, which is provided to the LLM.
- *
- * @return The description of the assistant function.
- */
- String functionDescription();
-
- /**
- * The JSON description of the function parameter, which is provided to the LLM.
- * If no description is provided, the description will be autogenerated.
- * For more information on the syntax of the parameter description JSON, see the OpenAI API documentation:
- * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools.
- *
- * @return The JSON description of the function parameter.
- */
- String parameterDescriptionJson() default "";
-
- /**
- * The OpenAI chat model to use.
- * When using Azure OpenAI, this should be the name of the model deployment.
- *
- * @return The OpenAI chat model to use.
- */
- String model() default "gpt-3.5-turbo";
-
- }
+ *
+ * Assistant skill trigger attribute.
+ *
+ *
+ * @since 1.0.0
+ */
+@Retention(RetentionPolicy.RUNTIME)
+@Target(ElementType.PARAMETER)
+@CustomBinding(direction = "in", name = "", type = "assistantSkillTrigger")
+public @interface AssistantSkillTrigger {
+
+ /**
+ * The variable name used in function.json.
+ *
+ * @return The variable name used in function.json.
+ */
+ String name();
+
+ /**
+ * The name of the function to be invoked by the assistant.
+ *
+ * @return The name of the function to be invoked by the assistant.
+ */
+ String functionName() default "";
+
+ /**
+ * The description of the assistant function, which is provided to the LLM.
+ *
+ * @return The description of the assistant function.
+ */
+ String functionDescription();
+
+ /**
+ * The JSON description of the function parameter, which is provided to the LLM.
+ * If no description is provided, the description will be autogenerated.
+ * For more information on the syntax of the parameter description JSON, see the
+ * OpenAI API documentation:
+ * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools.
+ *
+ * @return The JSON description of the function parameter.
+ */
+ String parameterDescriptionJson() default "";
+}
diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantState.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantState.java
index 4ebcb0ea..0f88cfc1 100644
--- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantState.java
+++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/assistant/AssistantState.java
@@ -20,11 +20,11 @@ public class AssistantState {
private String lastUpdatedAt;
private int totalMessages;
private int totalTokens;
-    private List<ChatMessage> recentMessages;
+    private List<AssistantMessage> recentMessages;
     public AssistantState(String id, boolean exists,
-        String createdAt, String lastUpdatedAt,
-        int totalMessages, int totalTokens, List<ChatMessage> recentMessages) {
+            String createdAt, String lastUpdatedAt,
+            int totalMessages, int totalTokens, List<AssistantMessage> recentMessages) {
this.id = id;
this.exists = exists;
this.createdAt = createdAt;
@@ -56,7 +56,7 @@ public boolean isExists() {
* Gets timestamp of when assistant is created.
*
* @return The timestamp of when assistant is created.
- */
+ */
public String getCreatedAt() {
return createdAt;
}
@@ -93,7 +93,7 @@ public int getTotalTokens() {
*
* @return A list of the recent messages from the assistant.
*/
-    public List<ChatMessage> getRecentMessages() {
+    public List<AssistantMessage> getRecentMessages() {
return recentMessages;
}
@@ -156,8 +156,8 @@ public void setTotalTokens(int totalTokens) {
*
* @param recentMessages A list of the recent messages from the assistant.
*/
-    public void setRecentMessages(List<ChatMessage> recentMessages) {
+    public void setRecentMessages(List<AssistantMessage> recentMessages) {
this.recentMessages = recentMessages;
}
-
+
}
diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java
index 1627b380..3e04c8cf 100644
--- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java
+++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsContext.java
@@ -7,19 +7,19 @@
package com.microsoft.azure.functions.openai.annotation.embeddings;
import com.azure.ai.openai.models.Embeddings;
-import com.azure.ai.openai.models.EmbeddingsOptions;
+import java.util.List;
public class EmbeddingsContext {
- private EmbeddingsOptions request;
+    private List<String> request;
private Embeddings response;
private int count = 0;
- public EmbeddingsOptions getRequest() {
+    public List<String> getRequest() {
return request;
}
- public void setRequest(EmbeddingsOptions request) {
+    public void setRequest(List<String> request) {
this.request = request;
}
@@ -38,7 +38,8 @@ public void setResponse(Embeddings response) {
*/
public int getCount() {
return this.response != null && this.response.getData() != null
- ? this.count : 0;
+ ? this.count
+ : 0;
}
}
diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java
index 39d46691..5d1d911b 100644
--- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java
+++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsInput.java
@@ -7,12 +7,13 @@
package com.microsoft.azure.functions.openai.annotation.embeddings;
import com.microsoft.azure.functions.annotation.CustomBinding;
+import com.microsoft.azure.functions.openai.constants.ModelDefaults;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
-
+
/**
*
* Embeddings input attribute which is used for embedding generation.
@@ -34,17 +35,21 @@
/**
* The ID of the model to use.
- * Changing the default embeddings model is a breaking change, since any changes will be stored in a vector database for lookup.
- * Changing the default model can cause the lookups to start misbehaving if they don't match the data that was previously ingested into the vector database.
+ * Changing the default embeddings model is a breaking change, since any changes
+ * will be stored in a vector database for lookup.
+ * Changing the default model can cause the lookups to start misbehaving if they
+ * don't match the data that was previously ingested into the vector database.
*
* @return The model ID.
*/
- String model() default "text-embedding-ada-002";
+ String embeddingsModel() default ModelDefaults.DEFAULT_EMBEDDINGS_MODEL;
/**
* The maximum number of characters to chunk the input into.
- * At the time of writing, the maximum input tokens allowed for second-generation input embedding models
- * like text-embedding-ada-002 is 8191. 1 token is ~4 chars in English, which translates to roughly 32K
+ * At the time of writing, the maximum input tokens allowed for
+ * second-generation input embedding models
+ * like text-embedding-ada-002 is 8191. 1 token is ~4 chars in English, which
+ * translates to roughly 32K
* characters of English input that can fit into a single chunk.
*
* @return The maximum number of characters to chunk the input into.
@@ -65,11 +70,19 @@
*/
String input();
+ /**
+ * The name of the configuration section for AI service connectivity settings.
+ *
+ * @return The name of the configuration section for AI service connectivity
+ * settings.
+ */
+ String aiConnectionName() default "";
+
/**
* The input type.
*
* @return The input type.
*/
InputType inputType();
-
- }
+
+}
diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java
index 87402766..ee333138 100644
--- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java
+++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/embeddings/EmbeddingsStoreOutput.java
@@ -7,6 +7,7 @@
package com.microsoft.azure.functions.openai.annotation.embeddings;
import com.microsoft.azure.functions.annotation.CustomBinding;
+import com.microsoft.azure.functions.openai.constants.ModelDefaults;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
@@ -27,17 +28,21 @@
/**
* The ID of the model to use.
- * Changing the default embeddings model is a breaking change, since any changes will be stored in a vector database for lookup.
- * Changing the default model can cause the lookups to start misbehaving if they don't match the data that was previously ingested into the vector database.
+ * Changing the default embeddings model is a breaking change, since any changes
+ * will be stored in a vector database for lookup.
+ * Changing the default model can cause the lookups to start misbehaving if they
+ * don't match the data that was previously ingested into the vector database.
*
* @return The model ID.
*/
- String model() default "text-embedding-ada-002";
+ String embeddingsModel() default ModelDefaults.DEFAULT_EMBEDDINGS_MODEL;
/**
* The maximum number of characters to chunk the input into.
- * At the time of writing, the maximum input tokens allowed for second-generation input embedding models
- * like text-embedding-ada-002 is 8191. 1 token is ~4 chars in English, which translates to roughly 32K
+ * At the time of writing, the maximum input tokens allowed for
+ * second-generation input embedding models
+ * like text-embedding-ada-002 is 8191. 1 token is ~4 chars in English, which
+ * translates to roughly 32K
* characters of English input that can fit into a single chunk.
*
* @return The maximum number of characters to chunk the input into.
@@ -66,13 +71,22 @@
InputType inputType();
/**
- * The name of an app setting or environment variable which contains a connection string value.
+ * The name of the configuration section for AI service connectivity settings.
+ *
+ * @return The name of the configuration section for AI service connectivity
+ * settings.
+ */
+ String aiConnectionName() default "";
+
+ /**
+ * The name of an app setting or environment variable which contains a
+ * connection string value.
* This property supports binding expressions.
*
- * @return The connection name.
+ * @return The store connection name.
*/
- String connectionName();
-
+ String storeConnectionName();
+
/**
* The name of the collection or table to search.
* This property supports binding expressions.
diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java
index 87941b85..d3d0c206 100644
--- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java
+++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/search/SemanticSearch.java
@@ -7,11 +7,12 @@
package com.microsoft.azure.functions.openai.annotation.search;
import com.microsoft.azure.functions.annotation.CustomBinding;
+import com.microsoft.azure.functions.openai.constants.ModelDefaults;
- import java.lang.annotation.ElementType;
- import java.lang.annotation.Retention;
- import java.lang.annotation.RetentionPolicy;
- import java.lang.annotation.Target;
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
@@ -26,12 +27,21 @@
String name();
/**
- * The name of an app setting or environment variable which contains a connection string value.
+ * The name of the configuration section for AI service connectivity settings.
+ *
+ * @return The name of the configuration section for AI service connectivity
+ * settings.
+ */
+ String aiConnectionName() default "";
+
+ /**
+ * The name of an app setting or environment variable which contains a
+ * connection string value.
* This property supports binding expressions.
*
* @return The connection name.
*/
- String connectionName();
+ String searchConnectionName();
/**
* The name of the collection or table to search.
@@ -50,7 +60,6 @@
*/
String query() default "";
-
/**
* The model to use for embeddings.
* The default value is "text-embedding-ada-002".
@@ -58,21 +67,21 @@
*
* @return The model to use for embeddings.
*/
- String embeddingsModel() default "text-embedding-ada-002";
+ String embeddingsModel() default ModelDefaults.DEFAULT_EMBEDDINGS_MODEL;
/**
* the name of the Large Language Model to invoke for chat responses.
* The default value is "gpt-3.5-turbo".
* This property supports binding expressions.
- *
+ *
* @return The name of the Large Language Model to invoke for chat responses.
*/
- String chatModel() default "gpt-3.5-turbo";
-
+ String chatModel() default ModelDefaults.DEFAULT_CHAT_MODEL;
/**
* The system prompt to use for prompting the large language model.
- * The system prompt will be appended with knowledge that is fetched as a result of the Query.
+ * The system prompt will be appended with knowledge that is fetched as a result
+ * of the Query.
* The combined prompt will then be sent to the OpenAI Chat API.
* This property supports binding expressions.
*
@@ -94,4 +103,46 @@ String systemPrompt() default "You are a helpful assistant. You are responding t
* @return The number of knowledge items to inject into the SystemPrompt.
*/
int maxKnowledgeLength() default 1;
+
+ /**
+ * The sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ * make the output
+ * more random, while lower values like 0.2 will make it more focused and
+ * deterministic.
+ * It's generally recommended to use this or {@link #topP()} but not both.
+ *
+ * @return The sampling temperature value.
+ */
+ String temperature() default "0.5";
+
+ /**
+ * An alternative to sampling with temperature, called nucleus sampling, where
+ * the model considers
+ * the results of the tokens with top_p probability mass. So 0.1 means only the
+ * tokens comprising the top 10%
+ * probability mass are considered.
+ * It's generally recommended to use this or {@link #temperature()} but not
+ * both.
+ *
+ * @return The topP value.
+ */
+ String topP() default "";
+
+ /**
+ * The maximum number of tokens to generate in the completion.
+ * The token count of your prompt plus max_tokens cannot exceed the model's
+ * context length.
+ * Most models have a context length of 2048 tokens (except for the newest
+ * models, which support 4096).
+ *
+ * @return The maxTokens value.
+ */
+ String maxTokens() default "2048";
+
+ /**
+ * Indicates whether the assistant uses a reasoning model.
+ *
+ * @return {@code true} if the assistant is based on a reasoning model; {@code false} otherwise.
+ */
+    boolean isReasoningModel() default false;
}
diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java
index 1d024ce4..879b0c06 100644
--- a/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java
+++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/annotation/textcompletion/TextCompletion.java
@@ -7,15 +7,17 @@
package com.microsoft.azure.functions.openai.annotation.textcompletion;
import com.microsoft.azure.functions.annotation.CustomBinding;
+import com.microsoft.azure.functions.openai.constants.ModelDefaults;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
-
+
/**
*
- * Assistant query input attribute which is used query the Assistant to get current state.
+ * Assistant query input attribute which is used to query the Assistant to get
+ * current state.
*
*
* @since 1.0.0
@@ -39,16 +41,26 @@
*/
String prompt();
+ /**
+ * The name of the configuration section for AI service connectivity settings.
+ *
+ * @return The name of the configuration section for AI service connectivity
+ * settings.
+ */
+ String aiConnectionName() default "";
+
/**
* The ID of the model to use.
*
* @return The model ID.
*/
- String model() default "gpt-3.5-turbo";
+ String chatModel() default ModelDefaults.DEFAULT_CHAT_MODEL;
/**
- * The sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output
- * more random, while lower values like 0.2 will make it more focused and deterministic.
+ * The sampling temperature to use, between 0 and 2. Higher values like 0.8 will
+ * make the output
+ * more random, while lower values like 0.2 will make it more focused and
+ * deterministic.
* It's generally recommended to use this or {@link #topP()} but not both.
*
* @return The sampling temperature value.
@@ -56,10 +68,13 @@
String temperature() default "0.5";
/**
- * An alternative to sampling with temperature, called nucleus sampling, where the model considers
- * the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%
+ * An alternative to sampling with temperature, called nucleus sampling, where
+ * the model considers
+ * the results of the tokens with top_p probability mass. So 0.1 means only the
+ * tokens comprising the top 10%
* probability mass are considered.
- * It's generally recommended to use this or {@link #temperature()} but not both.
+ * It's generally recommended to use this or {@link #temperature()} but not
+ * both.
*
* @return The topP value.
*/
@@ -67,11 +82,19 @@
/**
* The maximum number of tokens to generate in the completion.
- * The token count of your prompt plus max_tokens cannot exceed the model's context length.
- * Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
+ * The token count of your prompt plus max_tokens cannot exceed the model's
+ * context length.
+ * Most models have a context length of 2048 tokens (except for the newest
+ * models, which support 4096).
*
* @return The maxTokens value.
*/
String maxTokens() default "100";
-
- }
+
+ /**
+ * Indicates whether the assistant uses a reasoning model.
+ *
+ * @return {@code true} if the assistant is based on a reasoning model; {@code false} otherwise.
+ */
+    boolean isReasoningModel() default false;
+}
diff --git a/java-library/src/main/java/com/microsoft/azure/functions/openai/constants/ModelDefaults.java b/java-library/src/main/java/com/microsoft/azure/functions/openai/constants/ModelDefaults.java
new file mode 100644
index 00000000..0d420625
--- /dev/null
+++ b/java-library/src/main/java/com/microsoft/azure/functions/openai/constants/ModelDefaults.java
@@ -0,0 +1,33 @@
+/**
+ * Copyright (c) Microsoft Corporation. All rights reserved.
+ * Licensed under the MIT License. See License.txt in the project root for
+ * license information.
+ */
+
+package com.microsoft.azure.functions.openai.constants;
+
+/**
+ * Constants for default model values used throughout the library.
+ * Centralizing these values makes it easier to update them across the codebase.
+ *
+ * @since 1.0.0
+ */
+public final class ModelDefaults {
+
+ /**
+ * Private constructor to prevent instantiation.
+ */
+ private ModelDefaults() {
+ // Utility class should not be instantiated
+ }
+
+ /**
+ * The default chat completion model used for text generation.
+ */
+ public static final String DEFAULT_CHAT_MODEL = "gpt-3.5-turbo";
+
+ /**
+ * The default embeddings model used for vector embeddings.
+ */
+ public static final String DEFAULT_EMBEDDINGS_MODEL = "text-embedding-ada-002";
+}
diff --git a/samples/assistant/csharp-legacy/AssistantApis.cs b/samples/assistant/csharp-legacy/AssistantApis.cs
index 16f9d5bb..4bdacd5e 100644
--- a/samples/assistant/csharp-legacy/AssistantApis.cs
+++ b/samples/assistant/csharp-legacy/AssistantApis.cs
@@ -51,7 +51,7 @@ public static async Task CreateAssistant(
public static IActionResult PostUserQuery(
[HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = "assistants/{assistantId}")] HttpRequest req,
string assistantId,
- [AssistantPost("{assistantId}", "{Query.message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState updatedState)
+ [AssistantPost("{assistantId}", "{Query.message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState updatedState)
{
return new OkObjectResult(updatedState.RecentMessages.Any() ? updatedState.RecentMessages[updatedState.RecentMessages.Count - 1].Content : "No response returned.");
}
diff --git a/samples/assistant/csharp-legacy/AssistantSample.csproj b/samples/assistant/csharp-legacy/AssistantSample.csproj
index d60abd4a..711f9628 100644
--- a/samples/assistant/csharp-legacy/AssistantSample.csproj
+++ b/samples/assistant/csharp-legacy/AssistantSample.csproj
@@ -7,7 +7,7 @@
-
+
diff --git a/samples/assistant/csharp-ooproc/AssistantApis.cs b/samples/assistant/csharp-ooproc/AssistantApis.cs
index 5d425c8f..a98b6fcd 100644
--- a/samples/assistant/csharp-ooproc/AssistantApis.cs
+++ b/samples/assistant/csharp-ooproc/AssistantApis.cs
@@ -59,10 +59,10 @@ public class CreateChatBotOutput
/// HTTP POST function that sends user prompts to the assistant chat bot.
///
[Function(nameof(PostUserQuery))]
-    public static async Task<IActionResult> PostUserQuery(
+ public static IActionResult PostUserQuery(
[HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = "assistants/{assistantId}")] HttpRequestData req,
string assistantId,
- [AssistantPostInput("{assistantId}", "{Query.message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state)
+ [AssistantPostInput("{assistantId}", "{Query.message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state)
{
return new OkObjectResult(state.RecentMessages.Any() ? state.RecentMessages[state.RecentMessages.Count - 1].Content : "No response returned.");
}
@@ -71,7 +71,7 @@ public static async Task PostUserQuery(
/// HTTP GET function that queries the conversation history of the assistant chat bot.
///
[Function(nameof(GetChatState))]
-    public static async Task<IActionResult> GetChatState(
+ public static IActionResult GetChatState(
[HttpTrigger(AuthorizationLevel.Anonymous, "get", Route = "assistants/{assistantId}")] HttpRequestData req,
string assistantId,
[AssistantQueryInput("{assistantId}", TimestampUtc = "{Query.timestampUTC}", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state)
diff --git a/samples/assistant/csharp-ooproc/AssistantSample.csproj b/samples/assistant/csharp-ooproc/AssistantSample.csproj
index d6e9140a..2087a8de 100644
--- a/samples/assistant/csharp-ooproc/AssistantSample.csproj
+++ b/samples/assistant/csharp-ooproc/AssistantSample.csproj
@@ -9,11 +9,11 @@
-
+
-
+
diff --git a/samples/assistant/java/extensions.csproj b/samples/assistant/java/extensions.csproj
index eb243f13..ad5a93c6 100644
--- a/samples/assistant/java/extensions.csproj
+++ b/samples/assistant/java/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
diff --git a/samples/assistant/javascript/src/functions/assistantApis.js b/samples/assistant/javascript/src/functions/assistantApis.js
index a8eaced4..3d167a10 100644
--- a/samples/assistant/javascript/src/functions/assistantApis.js
+++ b/samples/assistant/javascript/src/functions/assistantApis.js
@@ -36,7 +36,7 @@ app.http('CreateAssistant', {
const assistantPostInput = input.generic({
type: 'assistantPost',
id: '{assistantId}',
- model: '%CHAT_MODEL_DEPLOYMENT_NAME%',
+ chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%',
userMessage: '{Query.message}',
chatStorageConnectionSetting: CHAT_STORAGE_CONNECTION_SETTING,
collectionName: COLLECTION_NAME
diff --git a/samples/assistant/powershell/PostUserQuery/function.json b/samples/assistant/powershell/PostUserQuery/function.json
index 024cd24f..05b7702c 100644
--- a/samples/assistant/powershell/PostUserQuery/function.json
+++ b/samples/assistant/powershell/PostUserQuery/function.json
@@ -22,7 +22,7 @@
"dataType": "string",
"id": "{assistantId}",
"userMessage": "{Query.message}",
- "model": "%CHAT_MODEL_DEPLOYMENT_NAME%",
+ "chatModel": "%CHAT_MODEL_DEPLOYMENT_NAME%",
"chatStorageConnectionSetting": "AzureWebJobsStorage",
"collectionName": "ChatState"
}
diff --git a/samples/assistant/typescript/src/functions/assistantApis.ts b/samples/assistant/typescript/src/functions/assistantApis.ts
index 35bdd59c..c3bd70bf 100644
--- a/samples/assistant/typescript/src/functions/assistantApis.ts
+++ b/samples/assistant/typescript/src/functions/assistantApis.ts
@@ -36,7 +36,7 @@ app.http('CreateAssistant', {
const assistantPostInput = input.generic({
type: 'assistantPost',
id: '{assistantId}',
- model: '%CHAT_MODEL_DEPLOYMENT_NAME%',
+ chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%',
userMessage: '{Query.message}',
chatStorageConnectionSetting: CHAT_STORAGE_CONNECTION_SETTING,
collectionName: COLLECTION_NAME
diff --git a/samples/chat/README.md b/samples/chat/README.md
index d554a49b..1c030710 100644
--- a/samples/chat/README.md
+++ b/samples/chat/README.md
@@ -71,7 +71,7 @@ For additional details on using identity-based connections, refer to the [Azure
 public static async Task<IActionResult> PostUserResponse(
[HttpTrigger(AuthorizationLevel.Function, "post", Route = "chats/{chatId}")] HttpRequestData req,
string chatId,
- [AssistantPostInput("{chatId}", "{message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%")] AssistantState state)
+ [AssistantPostInput("{chatId}", "{message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] AssistantState state)
```
diff --git a/samples/chat/csharp-legacy/ChatBot.cs b/samples/chat/csharp-legacy/ChatBot.cs
index e7403df7..358d61d8 100644
--- a/samples/chat/csharp-legacy/ChatBot.cs
+++ b/samples/chat/csharp-legacy/ChatBot.cs
@@ -48,10 +48,10 @@ public static AssistantState GetChatState(
}
[FunctionName(nameof(PostUserResponse))]
-    public static async Task<IActionResult> PostUserResponse(
+ public static IActionResult PostUserResponse(
[HttpTrigger(AuthorizationLevel.Anonymous, "post", Route = "chats/{chatId}")] HttpRequest req,
string chatId,
- [AssistantPost("{chatId}", "{Query.message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState updatedState)
+ [AssistantPost("{chatId}", "{Query.message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState updatedState)
{
return new OkObjectResult(updatedState.RecentMessages.Any() ? updatedState.RecentMessages[updatedState.RecentMessages.Count - 1].Content : "No response returned.");
}
diff --git a/samples/chat/csharp-ooproc/ChatBot.cs b/samples/chat/csharp-ooproc/ChatBot.cs
index 2a94b4b4..9db059ce 100644
--- a/samples/chat/csharp-ooproc/ChatBot.cs
+++ b/samples/chat/csharp-ooproc/ChatBot.cs
@@ -23,8 +23,8 @@ public class CreateRequest
[Function(nameof(CreateChatBot))]
     public static async Task<CreateChatBotOutput> CreateChatBot(
- [HttpTrigger(AuthorizationLevel.Function, "put", Route = "chats/{chatId}")] HttpRequestData req,
- string chatId)
+ [HttpTrigger(AuthorizationLevel.Function, "put", Route = "chats/{chatId}")] HttpRequestData req,
+ string chatId)
{
var responseJson = new { chatId };
CreateRequest? createRequestBody;
@@ -39,7 +39,7 @@ public static async Task CreateChatBot(
}
catch (Exception ex)
{
- throw new ArgumentException("Invalid request body. Make sure that you pass in {\"instructions\": value } as the request body.", ex.Message);
+ throw new ArgumentException("Invalid request body. Make sure that you pass in {\"instructions\": value } as the request body.", ex);
}
return new CreateChatBotOutput
@@ -63,16 +63,16 @@ public class CreateChatBotOutput
}
[Function(nameof(PostUserResponse))]
-    public static async Task<IActionResult> PostUserResponse(
+ public static IActionResult PostUserResponse(
[HttpTrigger(AuthorizationLevel.Function, "post", Route = "chats/{chatId}")] HttpRequestData req,
string chatId,
- [AssistantPostInput("{chatId}", "{Query.message}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state)
+ [AssistantPostInput("{chatId}", "{Query.message}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state)
{
return new OkObjectResult(state.RecentMessages.LastOrDefault()?.Content ?? "No response returned.");
}
[Function(nameof(GetChatState))]
-    public static async Task<IActionResult> GetChatState(
+ public static IActionResult GetChatState(
[HttpTrigger(AuthorizationLevel.Function, "get", Route = "chats/{chatId}")] HttpRequestData req,
string chatId,
[AssistantQueryInput("{chatId}", TimestampUtc = "{Query.timestampUTC}", ChatStorageConnectionSetting = DefaultChatStorageConnectionSetting, CollectionName = DefaultCollectionName)] AssistantState state,
diff --git a/samples/chat/csharp-ooproc/ChatBot.csproj b/samples/chat/csharp-ooproc/ChatBot.csproj
index 12dd29a8..e45a1e8c 100644
--- a/samples/chat/csharp-ooproc/ChatBot.csproj
+++ b/samples/chat/csharp-ooproc/ChatBot.csproj
@@ -11,7 +11,7 @@
-
+
diff --git a/samples/chat/java/extensions.csproj b/samples/chat/java/extensions.csproj
index eb243f13..ad5a93c6 100644
--- a/samples/chat/java/extensions.csproj
+++ b/samples/chat/java/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
diff --git a/samples/chat/javascript/src/app.js b/samples/chat/javascript/src/app.js
index f0a86ddd..4b3b3409 100644
--- a/samples/chat/javascript/src/app.js
+++ b/samples/chat/javascript/src/app.js
@@ -52,7 +52,7 @@ app.http('GetChatState', {
const assistantPostInput = input.generic({
type: 'assistantPost',
id: '{chatID}',
- model: '%CHAT_MODEL_DEPLOYMENT_NAME%',
+ chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%',
userMessage: '{Query.message}',
chatStorageConnectionSetting: CHAT_STORAGE_CONNECTION_SETTING,
collectionName: COLLECTION_NAME
diff --git a/samples/chat/powershell/PostUserResponse/function.json b/samples/chat/powershell/PostUserResponse/function.json
index 16142e20..46a0ba09 100644
--- a/samples/chat/powershell/PostUserResponse/function.json
+++ b/samples/chat/powershell/PostUserResponse/function.json
@@ -20,7 +20,7 @@
"direction": "in",
"name": "ChatBotState",
"id": "{chatId}",
- "model": "%CHAT_MODEL_DEPLOYMENT_NAME%",
+ "chatModel": "%CHAT_MODEL_DEPLOYMENT_NAME%",
"userMessage": "{Query.message}",
"chatStorageConnectionSetting": "AzureWebJobsStorage",
"collectionName": "ChatState"
diff --git a/samples/chat/powershell/extensions.csproj b/samples/chat/powershell/extensions.csproj
index 8f73abd3..82ae1923 100644
--- a/samples/chat/powershell/extensions.csproj
+++ b/samples/chat/powershell/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
diff --git a/samples/chat/typescript/src/functions/app.ts b/samples/chat/typescript/src/functions/app.ts
index 84a12075..a5cb025c 100644
--- a/samples/chat/typescript/src/functions/app.ts
+++ b/samples/chat/typescript/src/functions/app.ts
@@ -52,7 +52,7 @@ app.http('GetChatState', {
const assistantPostInput = input.generic({
type: 'assistantPost',
id: '{chatID}',
- model: '%CHAT_MODEL_DEPLOYMENT_NAME%',
+ chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%',
userMessage: '{Query.message}',
chatStorageConnectionSetting: CHAT_STORAGE_CONNECTION_SETTING,
collectionName: COLLECTION_NAME
diff --git a/samples/embeddings/csharp-legacy/EmbeddingsLegacy.cs b/samples/embeddings/csharp-legacy/EmbeddingsLegacy.cs
index 58688120..12baeeca 100644
--- a/samples/embeddings/csharp-legacy/EmbeddingsLegacy.cs
+++ b/samples/embeddings/csharp-legacy/EmbeddingsLegacy.cs
@@ -1,7 +1,6 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
-using System.Linq;
using Microsoft.Azure.WebJobs;
using Microsoft.Azure.WebJobs.Extensions.Http;
using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings;
@@ -23,13 +22,13 @@ public record EmbeddingsRequest(string RawText, string FilePath, string Url);
[FunctionName(nameof(GenerateEmbeddings_Http_Request))]
public static void GenerateEmbeddings_Http_Request(
[HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings")] EmbeddingsRequest req,
- [Embeddings("{RawText}", InputType.RawText)] EmbeddingsContext embeddings,
+ [Embeddings("{RawText}", InputType.RawText, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings,
ILogger logger)
{
logger.LogInformation(
- "Received {count} embedding(s) for input text containing {length} characters.",
- embeddings.Count,
- req.RawText.Length);
+ "Received {count} embedding(s) for input text containing {length} characters.",
+ embeddings.Count,
+ req.RawText.Length);
// TODO: Store the embeddings into a database or other storage.
}
@@ -41,7 +40,7 @@ public static void GenerateEmbeddings_Http_Request(
[FunctionName(nameof(GetEmbeddings_Http_FilePath))]
public static void GetEmbeddings_Http_FilePath(
[HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings-from-file")] EmbeddingsRequest req,
- [Embeddings("{FilePath}", InputType.FilePath, MaxChunkLength = 512)] EmbeddingsContext embeddings,
+ [Embeddings("{FilePath}", InputType.FilePath, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%", MaxChunkLength = 512)] EmbeddingsContext embeddings,
ILogger logger)
{
logger.LogInformation(
@@ -59,7 +58,7 @@ public static void GetEmbeddings_Http_FilePath(
[FunctionName(nameof(GetEmbeddings_Http_Url))]
public static void GetEmbeddings_Http_Url(
[HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings-from-url")] EmbeddingsRequest req,
- [Embeddings("{Url}", InputType.Url, MaxChunkLength = 512)] EmbeddingsContext embeddings,
+ [Embeddings("{Url}", InputType.Url, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%", MaxChunkLength = 512)] EmbeddingsContext embeddings,
ILogger logger)
{
logger.LogInformation(
diff --git a/samples/embeddings/csharp-ooproc/Embeddings/Embeddings.csproj b/samples/embeddings/csharp-ooproc/Embeddings/Embeddings.csproj
index 61b2ebf6..464519f2 100644
--- a/samples/embeddings/csharp-ooproc/Embeddings/Embeddings.csproj
+++ b/samples/embeddings/csharp-ooproc/Embeddings/Embeddings.csproj
@@ -9,7 +9,7 @@
-
+
diff --git a/samples/embeddings/csharp-ooproc/Embeddings/EmbeddingsGenerator.cs b/samples/embeddings/csharp-ooproc/Embeddings/EmbeddingsGenerator.cs
index 4b42dcca..84c5fa41 100644
--- a/samples/embeddings/csharp-ooproc/Embeddings/EmbeddingsGenerator.cs
+++ b/samples/embeddings/csharp-ooproc/Embeddings/EmbeddingsGenerator.cs
@@ -41,7 +41,7 @@ internal class EmbeddingsRequest
[Function(nameof(GenerateEmbeddings_Http_RequestAsync))]
public async Task GenerateEmbeddings_Http_RequestAsync(
[HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings")] HttpRequestData req,
- [EmbeddingsInput("{rawText}", InputType.RawText, Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings)
+ [EmbeddingsInput("{rawText}", InputType.RawText, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings)
{
using StreamReader reader = new(req.Body);
string request = await reader.ReadToEndAsync();
@@ -63,7 +63,7 @@ public async Task GenerateEmbeddings_Http_RequestAsync(
[Function(nameof(GetEmbeddings_Http_FilePath))]
public async Task GetEmbeddings_Http_FilePath(
[HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings-from-file")] HttpRequestData req,
- [EmbeddingsInput("{filePath}", InputType.FilePath, MaxChunkLength = 512, Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings)
+ [EmbeddingsInput("{filePath}", InputType.FilePath, MaxChunkLength = 512, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings)
{
using StreamReader reader = new(req.Body);
string request = await reader.ReadToEndAsync();
@@ -84,7 +84,7 @@ public async Task GetEmbeddings_Http_FilePath(
[Function(nameof(GetEmbeddings_Http_URL))]
public async Task GetEmbeddings_Http_URL(
[HttpTrigger(AuthorizationLevel.Function, "post", Route = "embeddings-from-url")] HttpRequestData req,
- [EmbeddingsInput("{url}", InputType.Url, MaxChunkLength = 512, Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings)
+ [EmbeddingsInput("{url}", InputType.Url, MaxChunkLength = 512, EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")] EmbeddingsContext embeddings)
{
using StreamReader reader = new(req.Body);
string request = await reader.ReadToEndAsync();
diff --git a/samples/embeddings/java/extensions.csproj b/samples/embeddings/java/extensions.csproj
index eb243f13..ad5a93c6 100644
--- a/samples/embeddings/java/extensions.csproj
+++ b/samples/embeddings/java/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
     <OutputPath>target/azure-functions/azfs-java-openai-sample/bin</OutputPath>
diff --git a/samples/embeddings/javascript/src/app.js b/samples/embeddings/javascript/src/app.js
index 7c53e652..6a52be79 100644
--- a/samples/embeddings/javascript/src/app.js
+++ b/samples/embeddings/javascript/src/app.js
@@ -4,7 +4,7 @@ const embeddingsHttpInput = input.generic({
input: '{rawText}',
inputType: 'RawText',
type: 'embeddings',
- model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%'
+ embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%'
})
app.http('generateEmbeddings', {
@@ -31,7 +31,7 @@ const embeddingsFilePathInput = input.generic({
inputType: 'FilePath',
type: 'embeddings',
maxChunkLength: 512,
- model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%'
+ embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%'
})
app.http('getEmbeddingsFilePath', {
@@ -58,7 +58,7 @@ const embeddingsUrlInput = input.generic({
inputType: 'Url',
type: 'embeddings',
maxChunkLength: 512,
- model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%'
+ embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%'
})
app.http('getEmbeddingsUrl', {
diff --git a/samples/embeddings/powershell/GenerateEmbeddings/function.json b/samples/embeddings/powershell/GenerateEmbeddings/function.json
index 57621818..0236fcdf 100644
--- a/samples/embeddings/powershell/GenerateEmbeddings/function.json
+++ b/samples/embeddings/powershell/GenerateEmbeddings/function.json
@@ -21,7 +21,7 @@
"direction": "in",
"inputType": "RawText",
"input": "{rawText}",
- "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
+ "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
}
]
}
\ No newline at end of file
diff --git a/samples/embeddings/powershell/GetEmbeddingsFilePath/function.json b/samples/embeddings/powershell/GetEmbeddingsFilePath/function.json
index 3209aed5..1d8d44ed 100644
--- a/samples/embeddings/powershell/GetEmbeddingsFilePath/function.json
+++ b/samples/embeddings/powershell/GetEmbeddingsFilePath/function.json
@@ -22,7 +22,7 @@
"inputType": "FilePath",
"input": "{filePath}",
"maxChunkLength": 512,
- "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
+ "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
}
]
}
\ No newline at end of file
diff --git a/samples/embeddings/powershell/GetEmbeddingsURL/function.json b/samples/embeddings/powershell/GetEmbeddingsURL/function.json
index ebf96844..de11ff4b 100644
--- a/samples/embeddings/powershell/GetEmbeddingsURL/function.json
+++ b/samples/embeddings/powershell/GetEmbeddingsURL/function.json
@@ -22,7 +22,7 @@
"inputType": "Url",
"input": "{url}",
"maxChunkLength": 512,
- "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
+ "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
}
]
}
\ No newline at end of file
diff --git a/samples/embeddings/powershell/extensions.csproj b/samples/embeddings/powershell/extensions.csproj
index 8f73abd3..82ae1923 100644
--- a/samples/embeddings/powershell/extensions.csproj
+++ b/samples/embeddings/powershell/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
     <OutputPath>bin</OutputPath>
diff --git a/samples/embeddings/typescript/src/app.ts b/samples/embeddings/typescript/src/app.ts
index 9af76b96..f749cd0c 100644
--- a/samples/embeddings/typescript/src/app.ts
+++ b/samples/embeddings/typescript/src/app.ts
@@ -8,7 +8,7 @@ const embeddingsHttpInput = input.generic({
input: '{rawText}',
inputType: 'RawText',
type: 'embeddings',
- model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%'
+ embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%'
})
app.http('generateEmbeddings', {
@@ -39,7 +39,7 @@ const embeddingsFilePathInput = input.generic({
inputType: 'FilePath',
type: 'embeddings',
maxChunkLength: 512,
- model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%'
+ embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%'
})
app.http('getEmbeddingsFilePath', {
@@ -70,7 +70,7 @@ const embeddingsUrlInput = input.generic({
inputType: 'Url',
type: 'embeddings',
maxChunkLength: 512,
- model: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%'
+ embeddingsModel: '%EMBEDDING_MODEL_DEPLOYMENT_NAME%'
})
app.http('getEmbeddingsUrl', {
diff --git a/samples/rag-aisearch/README.md b/samples/rag-aisearch/README.md
index c56c2952..e9b0ce81 100644
--- a/samples/rag-aisearch/README.md
+++ b/samples/rag-aisearch/README.md
@@ -35,16 +35,14 @@ and optionally [enable semantic ranking](https://learn.microsoft.com/en-us/azure
1. User-Assigned Managed Identity:
- // Note and TODO: Use of user-assigned managed identity for AI Search is breaking at the moment, since Azure OpenAI authentication also needs update to support a separate identity.
-
```json
"__endpoint": "https://.search.windows.net",
"__credential": "managedidentity",
"__managedIdentityResourceId": "Resource Id of managed identity",
- "__managedIdentityClientId": "Client Id of managed identity"
+ "__clientId": "Client Id of managed identity"
```
- Only one of managedIdentityResourceId or managedIdentityClientId should be specified, not both.
+ Only one of managedIdentityResourceId or clientId should be specified, not both.
2. System-Assigned Managed Identity or local development:
@@ -56,7 +54,7 @@ and optionally [enable semantic ranking](https://learn.microsoft.com/en-us/azure
Specifying credential is optional for system assigned managed identity
4. Binding Configuration -
- Pass the configured `ConnectionNamePrefix` value, example `AISearch` to the `connectionName` property in the `SemanticSearchInput` or `EmbeddingsStoreOutput` bindings. Default is `AISearchEndpoint` if just the endpoint is being configured in local.settings.json or environment variables to use DefaultAzureCredential.
+ Pass the configured `ConnectionNamePrefix` value (for example, `AISearch`) to the `searchConnectionName` property in the `SemanticSearchInput` or `EmbeddingsStoreOutput` bindings. The default is `AISearchEndpoint` when only the endpoint is configured in local.settings.json or environment variables, in which case DefaultAzureCredential is used.
## Running the sample
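To make the rename concrete, here is a minimal worker-side sketch (not part of this diff) of a function consuming the `searchConnectionName`-based binding; the `AISearch` prefix, index name, `{prompt}` expression, and the `Query` property usage are assumptions modeled on this repo's samples:

```csharp
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Search;
using Microsoft.Azure.Functions.Worker.Http;

public static class PromptFileSketch
{
    [Function("PromptFile")]
    public static string PromptFile(
        [HttpTrigger(AuthorizationLevel.Function, "post")] HttpRequestData req,
        // "AISearch" is the assumed ConnectionNamePrefix, so the extension would
        // resolve AISearch__endpoint (and optionally AISearch__credential /
        // AISearch__clientId) from app settings or local.settings.json.
        [SemanticSearchInput("AISearch", "openai-index", Query = "{prompt}")]
        SemanticSearchContext context)
    {
        // SemanticSearchContext.Response surfaces the latest chat message text.
        return context.Response;
    }
}
```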
diff --git a/samples/rag-aisearch/csharp-legacy/FilePrompt.cs b/samples/rag-aisearch/csharp-legacy/FilePrompt.cs
index df97a125..4c6b7f2f 100644
--- a/samples/rag-aisearch/csharp-legacy/FilePrompt.cs
+++ b/samples/rag-aisearch/csharp-legacy/FilePrompt.cs
@@ -20,7 +20,7 @@ public record SemanticSearchRequest(string Prompt);
[FunctionName("IngestFile")]
public static async Task IngestFile(
[HttpTrigger(AuthorizationLevel.Function, "post")] EmbeddingsRequest req,
- [EmbeddingsStore("{url}", InputType.Url, "AISearchEndpoint", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")]
+ [EmbeddingsStore("{url}", InputType.Url, "AISearchEndpoint", "openai-index", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")]
IAsyncCollector output)
{
if (string.IsNullOrWhiteSpace(req.Url))
diff --git a/samples/rag-aisearch/csharp-ooproc/FilePrompt.cs b/samples/rag-aisearch/csharp-ooproc/FilePrompt.cs
index d96b1d93..c9d85c6a 100644
--- a/samples/rag-aisearch/csharp-ooproc/FilePrompt.cs
+++ b/samples/rag-aisearch/csharp-ooproc/FilePrompt.cs
@@ -67,7 +67,7 @@ public static async Task IngestFile(
public class EmbeddingsStoreOutputResponse
{
- [EmbeddingsStoreOutput("{url}", InputType.Url, "AISearchEndpoint", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")]
+ [EmbeddingsStoreOutput("{url}", InputType.Url, "AISearchEndpoint", "openai-index", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")]
public required SearchableDocument SearchableDocument { get; init; }
public IActionResult? HttpResponse { get; set; }
diff --git a/samples/rag-aisearch/csharp-ooproc/SemanticAISearchEmbeddings.csproj b/samples/rag-aisearch/csharp-ooproc/SemanticAISearchEmbeddings.csproj
index feef3710..2740dd2b 100644
--- a/samples/rag-aisearch/csharp-ooproc/SemanticAISearchEmbeddings.csproj
+++ b/samples/rag-aisearch/csharp-ooproc/SemanticAISearchEmbeddings.csproj
@@ -11,7 +11,7 @@
-
+
diff --git a/samples/rag-aisearch/java/extensions.csproj b/samples/rag-aisearch/java/extensions.csproj
index fcbb0cac..29683eae 100644
--- a/samples/rag-aisearch/java/extensions.csproj
+++ b/samples/rag-aisearch/java/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
     <OutputPath>target/azure-functions/azfs-java-openai-sample/bin</OutputPath>
diff --git a/samples/rag-aisearch/javascript/src/app.js b/samples/rag-aisearch/javascript/src/app.js
index 6bc7cc8f..c68b58eb 100644
--- a/samples/rag-aisearch/javascript/src/app.js
+++ b/samples/rag-aisearch/javascript/src/app.js
@@ -8,7 +8,7 @@ const embeddingsStoreOutput = output.generic({
inputType: "url",
connectionName: "AISearchEndpoint",
collection: "openai-index",
- model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
+ embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
});
app.http('IngestFile', {
diff --git a/samples/rag-aisearch/powershell/IngestFile/function.json b/samples/rag-aisearch/powershell/IngestFile/function.json
index 96a6a06e..be786a8b 100644
--- a/samples/rag-aisearch/powershell/IngestFile/function.json
+++ b/samples/rag-aisearch/powershell/IngestFile/function.json
@@ -22,7 +22,7 @@
"inputType": "Url",
"connectionName": "AISearchEndpoint",
"collection": "openai-index",
- "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
+ "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
}
]
}
\ No newline at end of file
diff --git a/samples/rag-aisearch/powershell/extensions.csproj b/samples/rag-aisearch/powershell/extensions.csproj
index 68c66ba4..1c321726 100644
--- a/samples/rag-aisearch/powershell/extensions.csproj
+++ b/samples/rag-aisearch/powershell/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
     <OutputPath>bin</OutputPath>
diff --git a/samples/rag-aisearch/typescript/src/app.ts b/samples/rag-aisearch/typescript/src/app.ts
index ec430e24..176686d1 100644
--- a/samples/rag-aisearch/typescript/src/app.ts
+++ b/samples/rag-aisearch/typescript/src/app.ts
@@ -11,7 +11,7 @@ const embeddingsStoreOutput = output.generic({
inputType: "url",
connectionName: "AISearchEndpoint",
collection: "openai-index",
- model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
+ embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
});
app.http('IngestFile', {
diff --git a/samples/rag-cosmosdb/csharp-legacy/FilePrompt.cs b/samples/rag-cosmosdb/csharp-legacy/FilePrompt.cs
index 3411dcb0..ba4d823c 100644
--- a/samples/rag-cosmosdb/csharp-legacy/FilePrompt.cs
+++ b/samples/rag-cosmosdb/csharp-legacy/FilePrompt.cs
@@ -20,7 +20,7 @@ public record SemanticSearchRequest(string Prompt);
[FunctionName("IngestFile")]
public static async Task IngestFile(
[HttpTrigger(AuthorizationLevel.Function, "post")] EmbeddingsRequest req,
- [EmbeddingsStore("{url}", InputType.Url, "CosmosDBMongoVCoreConnectionString", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")]
+ [EmbeddingsStore("{url}", InputType.Url, "CosmosDBMongoVCoreConnectionString", "openai-index", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")]
IAsyncCollector output)
{
if (string.IsNullOrWhiteSpace(req.Url))
diff --git a/samples/rag-cosmosdb/csharp-ooproc/FilePrompt.cs b/samples/rag-cosmosdb/csharp-ooproc/FilePrompt.cs
index 09be8588..5e1b964d 100644
--- a/samples/rag-cosmosdb/csharp-ooproc/FilePrompt.cs
+++ b/samples/rag-cosmosdb/csharp-ooproc/FilePrompt.cs
@@ -66,7 +66,7 @@ public static async Task IngestFile(
public class EmbeddingsStoreOutputResponse
{
- [EmbeddingsStoreOutput("{url}", InputType.Url, "CosmosDBMongoVCoreConnectionString", "openai-index", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")]
+ [EmbeddingsStoreOutput("{url}", InputType.Url, "CosmosDBMongoVCoreConnectionString", "openai-index", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")]
public required SearchableDocument SearchableDocument { get; init; }
public IActionResult? HttpResponse { get; set; }
diff --git a/samples/rag-cosmosdb/csharp-ooproc/SemanticCosmosDBSearchEmbeddings.csproj b/samples/rag-cosmosdb/csharp-ooproc/SemanticCosmosDBSearchEmbeddings.csproj
index 6c3e0147..9b114e87 100644
--- a/samples/rag-cosmosdb/csharp-ooproc/SemanticCosmosDBSearchEmbeddings.csproj
+++ b/samples/rag-cosmosdb/csharp-ooproc/SemanticCosmosDBSearchEmbeddings.csproj
@@ -11,7 +11,7 @@
-
+
diff --git a/samples/rag-cosmosdb/java/extensions.csproj b/samples/rag-cosmosdb/java/extensions.csproj
index 010fdc9a..2dbe5f2b 100644
--- a/samples/rag-cosmosdb/java/extensions.csproj
+++ b/samples/rag-cosmosdb/java/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
     <OutputPath>target/azure-functions/azfs-java-openai-sample/bin</OutputPath>
diff --git a/samples/rag-cosmosdb/javascript/src/app.js b/samples/rag-cosmosdb/javascript/src/app.js
index 274917c4..0116e652 100644
--- a/samples/rag-cosmosdb/javascript/src/app.js
+++ b/samples/rag-cosmosdb/javascript/src/app.js
@@ -8,7 +8,7 @@ const embeddingsStoreOutput = output.generic({
inputType: "url",
connectionName: "CosmosDBMongoVCoreConnectionString",
collection: "openai-index",
- model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
+ embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
});
app.http('IngestFile', {
diff --git a/samples/rag-cosmosdb/powershell/IngestFile/function.json b/samples/rag-cosmosdb/powershell/IngestFile/function.json
index 55b986e0..cdb86803 100644
--- a/samples/rag-cosmosdb/powershell/IngestFile/function.json
+++ b/samples/rag-cosmosdb/powershell/IngestFile/function.json
@@ -22,7 +22,7 @@
"inputType": "Url",
"connectionName": "CosmosDBMongoVCoreConnectionString",
"collection": "openai-index",
- "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
+ "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
}
]
}
\ No newline at end of file
diff --git a/samples/rag-cosmosdb/powershell/extensions.csproj b/samples/rag-cosmosdb/powershell/extensions.csproj
index bf9f1260..02592700 100644
--- a/samples/rag-cosmosdb/powershell/extensions.csproj
+++ b/samples/rag-cosmosdb/powershell/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
     <OutputPath>bin</OutputPath>
diff --git a/samples/rag-cosmosdb/python/extensions.csproj b/samples/rag-cosmosdb/python/extensions.csproj
index 61aaed1f..e67d7e72 100644
--- a/samples/rag-cosmosdb/python/extensions.csproj
+++ b/samples/rag-cosmosdb/python/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
     <OutputPath>bin</OutputPath>
diff --git a/samples/rag-cosmosdb/typescript/src/app.ts b/samples/rag-cosmosdb/typescript/src/app.ts
index c1808411..3ef9257e 100644
--- a/samples/rag-cosmosdb/typescript/src/app.ts
+++ b/samples/rag-cosmosdb/typescript/src/app.ts
@@ -11,7 +11,7 @@ const embeddingsStoreOutput = output.generic({
inputType: "url",
connectionName: "CosmosDBMongoVCoreConnectionString",
collection: "openai-index",
- model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
+ embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
});
app.http('IngestFile', {
diff --git a/samples/rag-kusto/csharp-legacy/FilePrompt.cs b/samples/rag-kusto/csharp-legacy/FilePrompt.cs
index 32600947..e63563e7 100644
--- a/samples/rag-kusto/csharp-legacy/FilePrompt.cs
+++ b/samples/rag-kusto/csharp-legacy/FilePrompt.cs
@@ -20,7 +20,7 @@ public record SemanticSearchRequest(string Prompt);
[FunctionName("IngestFile")]
public static async Task IngestFile(
[HttpTrigger(AuthorizationLevel.Function, "post")] EmbeddingsRequest req,
- [EmbeddingsStore("{url}", InputType.Url, "KustoConnectionString", "Documents", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")]
+ [EmbeddingsStore("{url}", InputType.Url, "KustoConnectionString", "Documents", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")]
IAsyncCollector output)
{
if (string.IsNullOrWhiteSpace(req.Url))
diff --git a/samples/rag-kusto/csharp-ooproc/EmailPromptDemo.cs b/samples/rag-kusto/csharp-ooproc/EmailPromptDemo.cs
index 7b78e9b8..747ef524 100644
--- a/samples/rag-kusto/csharp-ooproc/EmailPromptDemo.cs
+++ b/samples/rag-kusto/csharp-ooproc/EmailPromptDemo.cs
@@ -66,7 +66,7 @@ public async Task IngestEmail(
}
public class EmbeddingsStoreOutputResponse
{
- [EmbeddingsStoreOutput("{url}", InputType.Url, "KustoConnectionString", "Documents", Model = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")]
+ [EmbeddingsStoreOutput("{url}", InputType.Url, "KustoConnectionString", "Documents", EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%")]
public required SearchableDocument SearchableDocument { get; init; }
[HttpResult]
diff --git a/samples/rag-kusto/csharp-ooproc/SemanticSearchEmbeddings.csproj b/samples/rag-kusto/csharp-ooproc/SemanticSearchEmbeddings.csproj
index cb5db855..71a9aac5 100644
--- a/samples/rag-kusto/csharp-ooproc/SemanticSearchEmbeddings.csproj
+++ b/samples/rag-kusto/csharp-ooproc/SemanticSearchEmbeddings.csproj
@@ -10,7 +10,7 @@
-
+
diff --git a/samples/rag-kusto/java/extensions.csproj b/samples/rag-kusto/java/extensions.csproj
index e2323ac5..3fded17b 100644
--- a/samples/rag-kusto/java/extensions.csproj
+++ b/samples/rag-kusto/java/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
     <OutputPath>target/azure-functions/azfs-java-openai-sample/bin</OutputPath>
diff --git a/samples/rag-kusto/javascript/src/app.js b/samples/rag-kusto/javascript/src/app.js
index 655e64e6..87b948d1 100644
--- a/samples/rag-kusto/javascript/src/app.js
+++ b/samples/rag-kusto/javascript/src/app.js
@@ -8,7 +8,7 @@ const embeddingsStoreOutput = output.generic({
inputType: "url",
connectionName: "KustoConnectionString",
collection: "Documents",
- model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
+ embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
});
app.http('IngestEmail', {
diff --git a/samples/rag-kusto/powershell/IngestEmail/function.json b/samples/rag-kusto/powershell/IngestEmail/function.json
index 62ad48c4..31bd2eaa 100644
--- a/samples/rag-kusto/powershell/IngestEmail/function.json
+++ b/samples/rag-kusto/powershell/IngestEmail/function.json
@@ -22,7 +22,7 @@
"inputType": "Url",
"connectionName": "KustoConnectionString",
"collection": "Documents",
- "model": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
+ "embeddingsModel": "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
}
]
}
\ No newline at end of file
diff --git a/samples/rag-kusto/powershell/extensions.csproj b/samples/rag-kusto/powershell/extensions.csproj
index c8f3ee70..08c2647b 100644
--- a/samples/rag-kusto/powershell/extensions.csproj
+++ b/samples/rag-kusto/powershell/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
     <OutputPath>bin</OutputPath>
diff --git a/samples/rag-kusto/python/extensions.csproj b/samples/rag-kusto/python/extensions.csproj
index 1346f6ff..05d7cdf4 100644
--- a/samples/rag-kusto/python/extensions.csproj
+++ b/samples/rag-kusto/python/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
     <OutputPath>bin</OutputPath>
diff --git a/samples/rag-kusto/typescript/src/app.ts b/samples/rag-kusto/typescript/src/app.ts
index bd0df1d8..bfc17d70 100644
--- a/samples/rag-kusto/typescript/src/app.ts
+++ b/samples/rag-kusto/typescript/src/app.ts
@@ -11,7 +11,7 @@ const embeddingsStoreOutput = output.generic({
inputType: "url",
connectionName: "KustoConnectionString",
collection: "Documents",
- model: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
+ embeddingsModel: "%EMBEDDING_MODEL_DEPLOYMENT_NAME%"
});
app.http('IngestEmail', {
diff --git a/samples/textcompletion/csharp-legacy/TextCompletionLegacy.cs b/samples/textcompletion/csharp-legacy/TextCompletionLegacy.cs
index 12d96f83..95fa8567 100644
--- a/samples/textcompletion/csharp-legacy/TextCompletionLegacy.cs
+++ b/samples/textcompletion/csharp-legacy/TextCompletionLegacy.cs
@@ -24,7 +24,7 @@ public static class TextCompletionLegacy
[FunctionName(nameof(WhoIs))]
public static IActionResult WhoIs(
[HttpTrigger(AuthorizationLevel.Function, Route = "whois/{name}")] HttpRequest req,
- [TextCompletion("Who is {name}?", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response)
+ [TextCompletion("Who is {name}?", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response)
{
return new OkObjectResult(response.Content);
}
@@ -36,7 +36,7 @@ public static IActionResult WhoIs(
[FunctionName(nameof(GenericCompletion))]
public static IActionResult GenericCompletion(
[HttpTrigger(AuthorizationLevel.Function, "post")] PromptPayload payload,
- [TextCompletion("{Prompt}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response,
+ [TextCompletion("{Prompt}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response,
ILogger log)
{
string text = response.Content;
diff --git a/samples/textcompletion/csharp-ooproc/TextCompletion.csproj b/samples/textcompletion/csharp-ooproc/TextCompletion.csproj
index 74c6a35e..2b9b00ec 100644
--- a/samples/textcompletion/csharp-ooproc/TextCompletion.csproj
+++ b/samples/textcompletion/csharp-ooproc/TextCompletion.csproj
@@ -11,7 +11,7 @@
-
+
diff --git a/samples/textcompletion/csharp-ooproc/TextCompletions.cs b/samples/textcompletion/csharp-ooproc/TextCompletions.cs
index a1a1ef7e..4bdbae7a 100644
--- a/samples/textcompletion/csharp-ooproc/TextCompletions.cs
+++ b/samples/textcompletion/csharp-ooproc/TextCompletions.cs
@@ -23,7 +23,7 @@ public static class TextCompletions
[Function(nameof(WhoIs))]
public static IActionResult WhoIs(
[HttpTrigger(AuthorizationLevel.Function, Route = "whois/{name}")] HttpRequestData req,
- [TextCompletionInput("Who is {name}?", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response)
+ [TextCompletionInput("Who is {name}?", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response)
{
return new OkObjectResult(response.Content);
}
@@ -35,7 +35,7 @@ public static IActionResult WhoIs(
[Function(nameof(GenericCompletion))]
public static IActionResult GenericCompletion(
[HttpTrigger(AuthorizationLevel.Function, "post")] HttpRequestData req,
- [TextCompletionInput("{Prompt}", Model = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response,
+ [TextCompletionInput("{Prompt}", ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%")] TextCompletionResponse response,
ILogger log)
{
string text = response.Content;
diff --git a/samples/textcompletion/java/extensions.csproj b/samples/textcompletion/java/extensions.csproj
index eb243f13..ad5a93c6 100644
--- a/samples/textcompletion/java/extensions.csproj
+++ b/samples/textcompletion/java/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
     <OutputPath>target/azure-functions/azfs-java-openai-sample/bin</OutputPath>
diff --git a/samples/textcompletion/javascript/src/functions/whois.js b/samples/textcompletion/javascript/src/functions/whois.js
index 665af065..96725512 100644
--- a/samples/textcompletion/javascript/src/functions/whois.js
+++ b/samples/textcompletion/javascript/src/functions/whois.js
@@ -5,7 +5,7 @@ const openAICompletionInput = input.generic({
prompt: 'Who is {name}?',
maxTokens: '100',
type: 'textCompletion',
- model: '%CHAT_MODEL_DEPLOYMENT_NAME%'
+ chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%'
})
app.http('whois', {
diff --git a/samples/textcompletion/powershell/WhoIs/function.json b/samples/textcompletion/powershell/WhoIs/function.json
index bdc381c7..6abf6c9b 100644
--- a/samples/textcompletion/powershell/WhoIs/function.json
+++ b/samples/textcompletion/powershell/WhoIs/function.json
@@ -21,7 +21,7 @@
"name": "TextCompletionResponse",
"prompt": "Who is {name}?",
"maxTokens": "100",
- "model": "%CHAT_MODEL_DEPLOYMENT_NAME%"
+ "chatModel": "%CHAT_MODEL_DEPLOYMENT_NAME%"
}
]
}
\ No newline at end of file
diff --git a/samples/textcompletion/powershell/extensions.csproj b/samples/textcompletion/powershell/extensions.csproj
index 8f73abd3..82ae1923 100644
--- a/samples/textcompletion/powershell/extensions.csproj
+++ b/samples/textcompletion/powershell/extensions.csproj
@@ -1,6 +1,6 @@
-    <TargetFramework>net6.0</TargetFramework>
+    <TargetFramework>net8.0</TargetFramework>
     <OutputPath>bin</OutputPath>
diff --git a/samples/textcompletion/typescript/src/functions/whois.ts b/samples/textcompletion/typescript/src/functions/whois.ts
index 22f8c2ee..b910f51c 100644
--- a/samples/textcompletion/typescript/src/functions/whois.ts
+++ b/samples/textcompletion/typescript/src/functions/whois.ts
@@ -5,7 +5,7 @@ const openAICompletionInput = input.generic({
prompt: 'Who is {name}?',
maxTokens: '100',
type: 'textCompletion',
- model: '%CHAT_MODEL_DEPLOYMENT_NAME%'
+ chatModel: '%CHAT_MODEL_DEPLOYMENT_NAME%'
})
app.http('whois', {
diff --git a/src/Directory.Build.props b/src/Directory.Build.props
index 91466cb7..9f4ab3ef 100644
--- a/src/Directory.Build.props
+++ b/src/Directory.Build.props
@@ -33,7 +33,7 @@
$(VersionPrefix).$(FileVersionRevision)
- 0.16.0-alpha
+ 0.17.0-alpha
+ 0.4.0-alpha
diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantMessage.cs
similarity index 71%
rename from src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs
rename to src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantMessage.cs
index 585fc205..cdad6ef4 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatMessage.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantMessage.cs
@@ -8,18 +8,19 @@ namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Assistants;
/// <summary>
/// Chat Message Entity which contains the content of the message, the role of the chat agent, and the name of the calling function if applicable.
/// </summary>
-public class ChatMessage
+public class AssistantMessage
{
/// <summary>
- /// Initializes a new instance of the <see cref="ChatMessage"/> class.
+ /// Initializes a new instance of the <see cref="AssistantMessage"/> class.
/// </summary>
/// <param name="content">The content of the message.</param>
/// <param name="role">The role of the chat agent.</param>
- public ChatMessage(string content, string role, string? name)
+ /// <param name="toolCalls">The tool calls.</param>
+ public AssistantMessage(string content, string role, string toolCalls)
{
this.Content = content;
this.Role = role;
- this.Name = name;
+ this.ToolCalls = toolCalls;
}
///
@@ -35,8 +36,8 @@ public ChatMessage(string content, string role, string? name)
public string Role { get; set; }
/// <summary>
- /// Gets or sets the name of the calling function if applicable.
+ /// Gets or sets the tool calls.
/// </summary>
- [JsonPropertyName("name")]
- public string? Name { get; set; }
+ [JsonPropertyName("toolCalls")]
+ public string ToolCalls { get; set; }
}
diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs
index fda594e7..c29b462f 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantPostInputAttribute.cs
@@ -10,12 +10,34 @@ namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Assistants;
///
public sealed class AssistantPostInputAttribute : InputBindingAttribute
{
- public AssistantPostInputAttribute(string id, string UserMessage)
+ /// <summary>
+ /// Initializes a new instance of the <see cref="AssistantPostInputAttribute"/> class.
+ /// </summary>
+ /// <param name="id">The assistant identifier.</param>
+ /// <param name="userMessage">The user message.</param>
+ public AssistantPostInputAttribute(string id, string userMessage)
{
this.Id = id;
- this.UserMessage = UserMessage;
+ this.UserMessage = userMessage;
}
+ /// <summary>
+ /// Gets or sets the name of the configuration section for AI service connectivity settings.
+ /// </summary>
+ /// <remarks>
+ /// This property specifies the name of the configuration section that contains connection details for the AI service.
+ ///
+ /// For Azure OpenAI:
+ /// - If specified, looks for "Endpoint" and "Key" values in this configuration section
+ /// - If not specified or the section doesn't exist, falls back to environment variables:
+ ///   AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY
+ /// - For user-assigned managed identity authentication, this property is required
+ ///
+ /// For OpenAI:
+ /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable.
+ /// </remarks>
+ public string AIConnectionName { get; set; } = "";
+
/// <summary>
/// Gets the ID of the assistant to update.
/// </summary>
@@ -27,7 +49,7 @@ public AssistantPostInputAttribute(string id, string UserMessage)
/// <remarks>
/// When using Azure OpenAI, this should be the name of the model deployment.
/// </remarks>
- public string? Model { get; set; }
+ public string? ChatModel { get; set; }
/// <summary>
/// Gets the user message that the user has entered for the assistant to respond to.
@@ -43,4 +65,40 @@ public AssistantPostInputAttribute(string id, string UserMessage)
/// <summary>
/// Table collection name for chat storage.
/// </summary>
+
+ /// <summary>
+ /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output
+ /// more random, while lower values like 0.2 will make it more focused and deterministic.
+ /// </summary>
+ /// <remarks>
+ /// It's generally recommended to use this or <see cref="TopP"/>, but not both.
+ /// </remarks>
+ public string? Temperature { get; set; } = "0.5";
+
+ /// <summary>
+ /// Gets or sets an alternative to sampling with temperature, called nucleus sampling, where the model considers
+ /// the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%
+ /// probability mass are considered.
+ /// </summary>
+ /// <remarks>
+ /// It's generally recommended to use this or <see cref="Temperature"/>, but not both.
+ /// </remarks>
+ public string? TopP { get; set; }
+
+ /// <summary>
+ /// Gets or sets the maximum number of tokens to output in the completion. Default is 100.
+ /// </summary>
+ /// <remarks>
+ /// The token count of your prompt plus max_tokens cannot exceed the model's context length.
+ /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
+ /// </remarks>
+ public string? MaxTokens { get; set; } = "100";
+
+ /// <summary>
+ /// Gets or sets a value indicating whether the model is a reasoning model.
+ /// </summary>
+ /// <remarks>
+ /// Warning: This is experimental and associated with the reasoning model until all models have parity in the expected properties.
+ /// </remarks>
+ public bool IsReasoningModel { get; set; }
}
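A hedged sketch of what these additions look like at a call site; the route, the `{message}` query-string binding expression, and the `MyAzureOpenAI` section name are illustrative assumptions, not part of this diff:

```csharp
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Assistants;
using Microsoft.Azure.Functions.Worker.Http;

public static class AssistantPostSketch
{
    [Function("PostUserQuery")]
    public static AssistantState PostUserQuery(
        [HttpTrigger(AuthorizationLevel.Function, "post", Route = "assistants/{assistantId}")]
        HttpRequestData req,
        // ChatModel replaces the old Model property; AIConnectionName names a config
        // section (assumed "MyAzureOpenAI") holding Endpoint/Key or identity settings.
        [AssistantPostInput("{assistantId}", "{message}",
            ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%",
            AIConnectionName = "MyAzureOpenAI")]
        AssistantState state)
    {
        // The updated conversation state, including recent AssistantMessage entries.
        return state;
    }
}
```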
diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs
index 6c6787ec..410eb73b 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs
@@ -37,12 +37,4 @@ public AssistantSkillTriggerAttribute(string functionDescription)
/// https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools.
///
public string? ParameterDescriptionJson { get; set; }
-
- /// <summary>
- /// Gets or sets the OpenAI chat model to use.
- /// </summary>
- /// <remarks>
- /// When using Azure OpenAI, then should be the name of the model deployment.
- /// </remarks>
- public string Model { get; set; } = OpenAIModels.DefaultChatModel;
}
\ No newline at end of file
diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantState.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantState.cs
index 1ef18d72..7570ba87 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantState.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/AssistantState.cs
@@ -50,5 +50,5 @@ public class AssistantState
/// Gets a list of the recent messages from the assistant.
/// </summary>
[JsonPropertyName("recentMessages")]
- public IReadOnlyList<ChatMessage> RecentMessages { get; set; } = Array.Empty<ChatMessage>();
+ public IReadOnlyList<AssistantMessage> RecentMessages { get; set; } = Array.Empty<AssistantMessage>();
}
diff --git a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs b/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs
index a73974fa..f8e8e676 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/Assistants/ChatCompletionJsonConverter.cs
@@ -1,23 +1,27 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
-using Azure.AI.OpenAI;
using System.ClientModel.Primitives;
using System.Text.Json;
using System.Text.Json.Serialization;
+using OpenAI.Chat;
namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Assistants;
-public class ChatCompletionsJsonConverter : JsonConverter<ChatCompletions>
+public class ChatCompletionJsonConverter : JsonConverter<ChatCompletion>
{
static readonly ModelReaderWriterOptions modelReaderWriterOptions = new("J");
- public override ChatCompletions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ public override ChatCompletion Read(
+ ref Utf8JsonReader reader,
+ Type typeToConvert,
+ JsonSerializerOptions options)
{
using JsonDocument jsonDocument = JsonDocument.ParseValue(ref reader);
- return ModelReaderWriter.Read<ChatCompletions>(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!;
+ return ModelReaderWriter.Read<ChatCompletion>(
+ BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!;
}
- public override void Write(Utf8JsonWriter writer, ChatCompletions value, JsonSerializerOptions options)
+ public override void Write(Utf8JsonWriter writer, ChatCompletion value, JsonSerializerOptions options)
{
- ((IJsonModel<ChatCompletions>)value).Write(writer, modelReaderWriterOptions);
+ ((IJsonModel<ChatCompletion>)value).Write(writer, modelReaderWriterOptions);
}
}
\ No newline at end of file
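For context, the renamed converter round-trips the SDK model through its `IJsonModel<T>` wire format. A rough usage sketch; the payload file name is a hypothetical stand-in, not from this repo:

```csharp
using System.Text.Json;
using Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Assistants;
using OpenAI.Chat;

string payloadJson = File.ReadAllText("completion.json"); // hypothetical sample payload

var options = new JsonSerializerOptions();
options.Converters.Add(new ChatCompletionJsonConverter());

// Read defers to ModelReaderWriter with the "J" (JSON wire) format, so the
// payload must match the service's ChatCompletion schema.
ChatCompletion? completion = JsonSerializer.Deserialize<ChatCompletion>(payloadJson, options);
```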
diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs
index 21259707..2f7dfa45 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsContext.cs
@@ -1,17 +1,13 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
-using OpenAISDK = Azure.AI.OpenAI;
+using OpenAI.Embeddings;
+
namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings;
public class EmbeddingsContext
{
- /// <summary>
- /// Binding target for the <see cref="EmbeddingsInputAttribute"/>.
- /// </summary>
- /// <param name="Request">The embeddings request that was sent to OpenAI.</param>
- /// <param name="Response">The embeddings response that was received from OpenAI.</param>
- public EmbeddingsContext(OpenAISDK.EmbeddingsOptions Request, OpenAISDK.Embeddings Response)
+ public EmbeddingsContext(IList<string> Request, OpenAIEmbeddingCollection? Response)
{
this.Request = Request;
this.Response = Response;
@@ -20,15 +16,15 @@ public EmbeddingsContext(OpenAISDK.EmbeddingsOptions Request, OpenAISDK.Embeddin
///
/// Embeddings request sent to OpenAI.
///
- public OpenAISDK.EmbeddingsOptions Request { get; set; }
+ public IList<string> Request { get; set; }
///
/// Embeddings response from OpenAI.
///
- public OpenAISDK.Embeddings Response { get; set; }
-
+ public OpenAIEmbeddingCollection? Response { get; set; }
+
///
/// Gets the number of embeddings that were returned in the response.
///
- public int Count => this.Response.Data?.Count ?? 0;
+ public int Count => this.Response?.Count ?? 0;
}
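The reshaped type is easiest to see in use: `Request` is now the raw list of input chunks and `Response` is the SDK collection, so pairing them is a straight index walk. A minimal sketch; names are illustrative:

```csharp
using Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings;

static class EmbeddingsContextSketch
{
    public static void Dump(EmbeddingsContext embeddings)
    {
        // Count is derived from Response, which may be null before a call completes.
        for (int i = 0; i < embeddings.Count; i++)
        {
            string chunk = embeddings.Request[i];
            // OpenAIEmbedding.ToFloats() exposes the vector in the 2.x SDK.
            ReadOnlyMemory<float> vector = embeddings.Response![i].ToFloats();
            Console.WriteLine($"chunk {i}: {chunk.Length} chars -> {vector.Length} dims");
        }
    }
}
```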
diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs
index f9a2d073..a5943c6e 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsInputAttribute.cs
@@ -8,7 +8,7 @@ namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings;
public class EmbeddingsInputAttribute : InputBindingAttribute
{
///
- /// Initializes a new instance of the class with the specified input.
+ /// Initializes a new instance of the class with the specified input.
///
/// The input source containing the data to generate embeddings for.
/// The type of the input.
@@ -19,13 +19,30 @@ public EmbeddingsInputAttribute(string input, InputType inputType)
this.InputType = inputType;
}
+ /// <summary>
+ /// Gets or sets the name of the configuration section for AI service connectivity settings.
+ /// </summary>
+ /// <remarks>
+ /// This property specifies the name of the configuration section that contains connection details for the AI service.
+ ///
+ /// For Azure OpenAI:
+ /// - If specified, looks for "Endpoint" and "Key" values in this configuration section
+ /// - If not specified or the section doesn't exist, falls back to environment variables:
+ ///   AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY
+ /// - For user-assigned managed identity authentication, this property is required
+ ///
+ /// For OpenAI:
+ /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable.
+ /// </remarks>
+ public string AIConnectionName { get; set; } = "";
+
/// <summary>
/// Gets or sets the ID of the model to use.
/// </summary>
/// <remarks>
/// Changing the default embeddings model is a breaking change, since any changes will be stored in a vector database for lookup. Changing the default model can cause the lookups to start misbehaving if they don't match the data that was previously ingested into the vector database.
/// </remarks>
- public string Model { get; set; } = OpenAIModels.DefaultEmbeddingsModel;
+ public string EmbeddingsModel { get; set; } = OpenAIModels.DefaultEmbeddingsModel;
///
/// Gets or sets the maximum number of characters to chunk the input into.
diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsJsonConverter.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsJsonConverter.cs
index 30a3679a..cfcc5d23 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsJsonConverter.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsJsonConverter.cs
@@ -4,25 +4,25 @@
using System.ClientModel.Primitives;
using System.Text.Json;
using System.Text.Json.Serialization;
-using OpenAISDK = Azure.AI.OpenAI;
+using OpenAI.Embeddings;
namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings;
/// <summary>
-/// Embeddings JSON converter needed to serialize and deserialize the Embeddings object with the dotnet worker.
+/// OpenAIEmbeddingCollection JSON converter needed to serialize and deserialize the OpenAIEmbeddingCollection object with the dotnet worker.
/// </summary>
-class EmbeddingsJsonConverter : JsonConverter<OpenAISDK.Embeddings>
+class EmbeddingsJsonConverter : JsonConverter<OpenAIEmbeddingCollection>
{
static readonly ModelReaderWriterOptions JsonOptions = new("J");
- public override OpenAISDK.Embeddings Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ public override OpenAIEmbeddingCollection Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
using JsonDocument jsonDocument = JsonDocument.ParseValue(ref reader);
- return ModelReaderWriter.Read<OpenAISDK.Embeddings>(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!;
+ return ModelReaderWriter.Read<OpenAIEmbeddingCollection>(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!;
}
- public override void Write(Utf8JsonWriter writer, OpenAISDK.Embeddings value, JsonSerializerOptions options)
+ public override void Write(Utf8JsonWriter writer, OpenAIEmbeddingCollection value, JsonSerializerOptions options)
{
- ((IJsonModel<OpenAISDK.Embeddings>)value).Write(writer, JsonOptions);
+ ((IJsonModel<OpenAIEmbeddingCollection>)value).Write(writer, JsonOptions);
}
}
diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs
index 0a8949dc..9ca6bdf8 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsOptionsJsonConverter.cs
@@ -4,25 +4,25 @@
using System.ClientModel.Primitives;
using System.Text.Json;
using System.Text.Json.Serialization;
-using Azure.AI.OpenAI;
+using OpenAI.Embeddings;
namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings;
/// <summary>
/// EmbeddingsOptions JSON converter needed to serialize and deserialize the EmbeddingsOptions object with the dotnet worker.
/// </summary>
-class EmbeddingsOptionsJsonConverter : JsonConverter<EmbeddingsOptions>
+class EmbeddingsOptionsJsonConverter : JsonConverter<EmbeddingGenerationOptions>
{
static readonly ModelReaderWriterOptions JsonOptions = new("J");
- public override EmbeddingsOptions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ public override EmbeddingGenerationOptions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
using JsonDocument jsonDocument = JsonDocument.ParseValue(ref reader);
- return ModelReaderWriter.Read<EmbeddingsOptions>(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!;
+ return ModelReaderWriter.Read<EmbeddingGenerationOptions>(BinaryData.FromString(jsonDocument.RootElement.GetRawText()))!;
}
- public override void Write(Utf8JsonWriter writer, EmbeddingsOptions value, JsonSerializerOptions options)
+ public override void Write(Utf8JsonWriter writer, EmbeddingGenerationOptions value, JsonSerializerOptions options)
{
- ((IJsonModel<EmbeddingsOptions>)value).Write(writer, JsonOptions);
+ ((IJsonModel<EmbeddingGenerationOptions>)value).Write(writer, JsonOptions);
}
}
diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs
index 4be0d841..9adde53b 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/EmbeddingsStoreOutputAttribute.cs
@@ -14,28 +14,45 @@ public class EmbeddingsStoreOutputAttribute : OutputBindingAttribute
/// <param name="input">The input source containing the data to generate embeddings for
/// and is interpreted based on the value for <paramref name="inputType"/>.</param>
/// <param name="inputType">The type of the input.</param>
- /// <param name="connectionName">
- /// The name of an app setting or environment variable which contains a connection string value.
+ /// <param name="storeConnectionName">
+ /// The name of an app setting or environment variable which contains a connection string value for the embeddings store.
/// </param>
/// <param name="collection">The name of the collection or table to search or store.</param>
/// <exception cref="ArgumentNullException">
- /// Thrown if <paramref name="input"/> or <paramref name="connectionName"/> or <paramref name="collection"/> are null.
+ /// Thrown if <paramref name="input"/> or <paramref name="storeConnectionName"/> or <paramref name="collection"/> are null.
/// </exception>
- public EmbeddingsStoreOutputAttribute(string input, InputType inputType, string connectionName, string collection)
+ public EmbeddingsStoreOutputAttribute(string input, InputType inputType, string storeConnectionName, string collection)
{
this.Input = input ?? throw new ArgumentNullException(nameof(input));
this.InputType = inputType;
- this.ConnectionName = connectionName ?? throw new ArgumentNullException(nameof(connectionName));
+ this.StoreConnectionName = storeConnectionName ?? throw new ArgumentNullException(nameof(storeConnectionName));
this.Collection = collection ?? throw new ArgumentNullException(nameof(collection));
}
+ /// <summary>
+ /// Gets or sets the name of the configuration section for AI service connectivity settings.
+ /// </summary>
+ /// <remarks>
+ /// This property specifies the name of the configuration section that contains connection details for the AI service.
+ ///
+ /// For Azure OpenAI:
+ /// - If specified, looks for "Endpoint" and "Key" values in this configuration section
+ /// - If not specified or the section doesn't exist, falls back to environment variables:
+ ///   AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY
+ /// - For user-assigned managed identity authentication, the configuration section is required
+ ///
+ /// For OpenAI:
+ /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable.
+ /// </remarks>
+ public string AIConnectionName { get; set; } = "";
+
/// <summary>
/// Gets or sets the ID of the model to use.
/// </summary>
/// <remarks>
/// Changing the default embeddings model is a breaking change, since any changes will be stored in a vector database for lookup. Changing the default model can cause the lookups to start misbehaving if they don't match the data that was previously ingested into the vector database.
/// </remarks>
- public string Model { get; set; } = OpenAIModels.DefaultEmbeddingsModel;
+ public string EmbeddingsModel { get; set; } = OpenAIModels.DefaultEmbeddingsModel;
///
/// Gets or sets the maximum number of characters to chunk the input into.
@@ -70,7 +87,7 @@ public EmbeddingsStoreOutputAttribute(string input, InputType inputType, string
/// <remarks>
/// This property supports binding expressions.
/// </remarks>
- public string ConnectionName { get; set; }
+ public string StoreConnectionName { get; set; }
///
/// The name of the collection or table to search.
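A hedged sketch of the renamed constructor argument alongside the new property, modeled on the rag-aisearch sample earlier in this diff; the `MyAzureOpenAI` section name and the usings are assumptions:

```csharp
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings;
using Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Search;

public class IngestFileSketchResponse
{
    // "AISearchEndpoint" is the storeConnectionName (vector store); AIConnectionName
    // separately selects the Azure OpenAI connection used to generate the embeddings.
    [EmbeddingsStoreOutput("{url}", InputType.Url, "AISearchEndpoint", "openai-index",
        EmbeddingsModel = "%EMBEDDING_MODEL_DEPLOYMENT_NAME%",
        AIConnectionName = "MyAzureOpenAI")]
    public required SearchableDocument SearchableDocument { get; init; }
}
```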
diff --git a/src/Functions.Worker.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs b/src/Functions.Worker.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs
new file mode 100644
index 00000000..242a8fe0
--- /dev/null
+++ b/src/Functions.Worker.Extensions.OpenAI/Embeddings/JsonModelListWrapper.cs
@@ -0,0 +1,49 @@
+using System.ClientModel.Primitives;
+using System.Text.Json;
+
+namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings;
+
+class JsonModelListWrapper : IJsonModel<List<string>>
+{
+ readonly List<string> list;
+
+ public JsonModelListWrapper(List<string> list)
+ {
+ this.list = list;
+ }
+
+ public void Write(Utf8JsonWriter writer, ModelReaderWriterOptions options)
+ {
+ writer.WriteStartArray();
+ foreach (string item in this.list)
+ {
+ writer.WriteStringValue(item);
+ }
+ writer.WriteEndArray();
+ }
+
+ public static JsonModelListWrapper FromList(List<string> list)
+ {
+ return new JsonModelListWrapper(list);
+ }
+
+ public List<string> Create(ref Utf8JsonReader reader, ModelReaderWriterOptions options)
+ {
+ throw new NotImplementedException();
+ }
+
+ public BinaryData Write(ModelReaderWriterOptions options)
+ {
+ throw new NotImplementedException();
+ }
+
+ public List<string> Create(BinaryData data, ModelReaderWriterOptions options)
+ {
+ throw new NotImplementedException();
+ }
+
+ public string GetFormatFromOptions(ModelReaderWriterOptions options)
+ {
+ throw new NotImplementedException();
+ }
+}
\ No newline at end of file
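The wrapper is write-only by design: the worker only needs to serialize the request chunks, so the read-side `IJsonModel<T>` members throw. What it emits is just a JSON string array, as in this hypothetical round-trip (usable only within the extension assembly, since the class is internal):

```csharp
using System.ClientModel.Primitives;
using System.Text;
using System.Text.Json;

var wrapper = JsonModelListWrapper.FromList(new List<string> { "chunk one", "chunk two" });

using var stream = new MemoryStream();
using (var writer = new Utf8JsonWriter(stream))
{
    // Matches the "J" (JSON wire) options used by the other converters.
    wrapper.Write(writer, new ModelReaderWriterOptions("J"));
}

// Prints: ["chunk one","chunk two"]
Console.WriteLine(Encoding.UTF8.GetString(stream.ToArray()));
```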
diff --git a/src/Functions.Worker.Extensions.OpenAI/Functions.Worker.Extensions.OpenAI.csproj b/src/Functions.Worker.Extensions.OpenAI/Functions.Worker.Extensions.OpenAI.csproj
index a01f9810..aac240ec 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Functions.Worker.Extensions.OpenAI.csproj
+++ b/src/Functions.Worker.Extensions.OpenAI/Functions.Worker.Extensions.OpenAI.csproj
@@ -7,7 +7,7 @@
-
+
diff --git a/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs b/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs
index 959912a5..5b6804e2 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/Search/SearchableDocumentJsonConverter.cs
@@ -4,7 +4,8 @@
using System.ClientModel.Primitives;
using System.Text.Json;
using System.Text.Json.Serialization;
-using OpenAISDK = Azure.AI.OpenAI;
+using Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings;
+using OpenAI.Embeddings;
namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Search;
@@ -23,12 +24,13 @@ public override void Write(Utf8JsonWriter writer, SearchableDocument value, Json
writer.WritePropertyName("embeddingsContext"u8);
writer.WriteStartObject();
- if (value.EmbeddingsContext?.Request is IJsonModel<OpenAISDK.EmbeddingsOptions> request)
+ if (value.EmbeddingsContext?.Request is List<string> inputList)
{
writer.WritePropertyName("request"u8);
- request.Write(writer, modelReaderWriterOptions);
+ var inputWrapper = JsonModelListWrapper.FromList(inputList);
+ inputWrapper.Write(writer, modelReaderWriterOptions);
}
- if (value.EmbeddingsContext?.Response is IJsonModel<OpenAISDK.Embeddings> response)
+ if (value.EmbeddingsContext?.Response is IJsonModel<OpenAIEmbeddingCollection> response)
{
writer.WritePropertyName("response"u8);
response.Write(writer, modelReaderWriterOptions);
diff --git a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchContext.cs b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchContext.cs
index b40a0dc8..8fb1788d 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchContext.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchContext.cs
@@ -2,8 +2,8 @@
// Licensed under the MIT License.
using System.Text.Json.Serialization;
-using Azure.AI.OpenAI;
using Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Embeddings;
+using OpenAI.Chat;
namespace Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Search;
@@ -17,11 +17,11 @@ public class SemanticSearchContext
/// </summary>
/// <param name="Embeddings">The embeddings context associated with the semantic search.</param>
/// <param name="Chat">The chat response from the large language model.</param>
- public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletions Chat)
+ public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletion Chat)
{
this.Embeddings = Embeddings;
this.Chat = Chat;
-
+
}
///
@@ -34,12 +34,11 @@ public SemanticSearchContext(EmbeddingsContext Embeddings, ChatCompletions Chat)
/// Gets the chat response from the large language model.
/// </summary>
[JsonPropertyName("chat")]
- public ChatCompletions Chat { get; }
-
+ public ChatCompletion Chat { get; }
/// <summary>
/// Gets the latest response message from the OpenAI Chat API.
/// </summary>
[JsonPropertyName("response")]
- public string Response => this.Chat.Choices.Last().Message.Content;
- }
+ public string Response => this.Chat.Content.Last().Text;
+}
diff --git a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs
index e660d709..1b679607 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/Search/SemanticSearchInputAttribute.cs
@@ -14,26 +14,43 @@ public sealed class SemanticSearchInputAttribute : InputBindingAttribute
/// Initializes a new instance of the <see cref="SemanticSearchInputAttribute"/> class with the specified connection
/// and collection names.
/// </summary>
- /// <param name="connectionName">
+ /// <param name="searchConnectionName">
/// The name of an app setting or environment variable which contains a connection string value.
/// </param>
/// <param name="collection">The name of the collection or table to search or store.</param>
/// <exception cref="ArgumentNullException">
- /// Thrown if either <paramref name="connectionName"/> or <paramref name="collection"/> are null.
+ /// Thrown if either <paramref name="searchConnectionName"/> or <paramref name="collection"/> are null.
/// </exception>
- public SemanticSearchInputAttribute(string connectionName, string collection)
+ public SemanticSearchInputAttribute(string searchConnectionName, string collection)
{
- this.ConnectionName = connectionName ?? throw new ArgumentNullException(nameof(connectionName));
+ this.SearchConnectionName = searchConnectionName ?? throw new ArgumentNullException(nameof(searchConnectionName));
this.Collection = collection ?? throw new ArgumentNullException(nameof(collection));
}
+ /// <summary>
+ /// Gets or sets the name of the configuration section for AI service connectivity settings.
+ /// </summary>
+ /// <remarks>
+ /// This property specifies the name of the configuration section that contains connection details for the AI service.
+ ///
+ /// For Azure OpenAI:
+ /// - If specified, looks for "Endpoint" and "Key" values in this configuration section
+ /// - If not specified or the section doesn't exist, falls back to environment variables:
+ ///   AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY
+ /// - For user-assigned managed identity authentication, this property is required
+ ///
+ /// For OpenAI:
+ /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable.
+ /// </remarks>
+ public string AIConnectionName { get; set; } = "";
+
/// <summary>
/// Gets or sets the name of an app setting or environment variable which contains a connection string value.
/// </summary>
/// <remarks>
/// This property supports binding expressions.
/// </remarks>
- public string ConnectionName { get; set; }
+ public string SearchConnectionName { get; set; }
///
/// The name of the collection or table to search.
@@ -98,4 +115,40 @@ The following is a list of documents that you can refer to when answering questi
/// <summary>
/// Gets or sets the number of knowledge items to inject into the <see cref="SystemPrompt"/>.
/// </summary>
+
+ /// <summary>
+ /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output
+ /// more random, while lower values like 0.2 will make it more focused and deterministic.
+ /// </summary>
+ /// <remarks>
+ /// It's generally recommended to use this or <see cref="TopP"/>, but not both.
+ /// </remarks>
+ public string? Temperature { get; set; } = "0.5";
+
+ /// <summary>
+ /// Gets or sets an alternative to sampling with temperature, called nucleus sampling, where the model considers
+ /// the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%
+ /// probability mass are considered.
+ /// </summary>
+ /// <remarks>
+ /// It's generally recommended to use this or <see cref="Temperature"/>, but not both.
+ /// </remarks>
+ public string? TopP { get; set; }
+
+ /// <summary>
+ /// Gets or sets the maximum number of tokens to output in the completion. Default is 2048.
+ /// </summary>
+ /// <remarks>
+ /// The token count of your prompt plus max_tokens cannot exceed the model's context length.
+ /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
+ /// </remarks>
+ public string? MaxTokens { get; set; } = "2048";
+
+ /// <summary>
+ /// Gets or sets a value indicating whether the model is a reasoning model.
+ /// </summary>
+ /// <remarks>
+ /// Warning: This is experimental and associated with the reasoning model until all models have parity in the expected properties.
+ /// </remarks>
+ public bool IsReasoningModel { get; set; }
}
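Only the properties added in this file are shown in the sketch below; the `Query` expression, connection name, and index follow the samples earlier in this diff, and the values are illustrative:

```csharp
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Extensions.OpenAI.Search;
using Microsoft.Azure.Functions.Worker.Http;

public static class PromptTuningSketch
{
    [Function("PromptDocs")]
    public static string PromptDocs(
        [HttpTrigger(AuthorizationLevel.Function, "post")] HttpRequestData req,
        [SemanticSearchInput("AISearchEndpoint", "openai-index",
            Query = "{prompt}",
            Temperature = "0.2",          // favor focused answers; don't also set TopP
            MaxTokens = "1024",
            IsReasoningModel = false)]    // experimental flag, per the remarks above
        SemanticSearchContext context)
    {
        return context.Response;
    }
}
```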
diff --git a/src/Functions.Worker.Extensions.OpenAI/Startup.cs b/src/Functions.Worker.Extensions.OpenAI/Startup.cs
index 8bbd524b..e4ad38a9 100644
--- a/src/Functions.Worker.Extensions.OpenAI/Startup.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/Startup.cs
@@ -25,7 +25,7 @@ public override void Configure(IFunctionsWorkerApplicationBuilder applicationBui
jsonSerializerOptions.Converters.Add(new EmbeddingsJsonConverter());
jsonSerializerOptions.Converters.Add(new EmbeddingsOptionsJsonConverter());
jsonSerializerOptions.Converters.Add(new SearchableDocumentJsonConverter());
- jsonSerializerOptions.Converters.Add(new ChatCompletionsJsonConverter());
+ jsonSerializerOptions.Converters.Add(new ChatCompletionJsonConverter());
});
}
}
diff --git a/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs b/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs
index 463d1e96..d9eaa985 100644
--- a/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs
+++ b/src/Functions.Worker.Extensions.OpenAI/TextCompletion/TextCompletionInputAttribute.cs
@@ -19,6 +19,23 @@ public TextCompletionInputAttribute(string prompt)
this.Prompt = prompt ?? throw new ArgumentNullException(nameof(prompt));
}
+ /// <summary>
+ /// Gets or sets the name of the configuration section for AI service connectivity settings.
+ /// </summary>
+ /// <remarks>
+ /// This property specifies the name of the configuration section that contains connection details for the AI service.
+ ///
+ /// For Azure OpenAI:
+ /// - If specified, looks for "Endpoint" and "Key" values in this configuration section
+ /// - If not specified or the section doesn't exist, falls back to environment variables:
+ ///   AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY
+ /// - For user-assigned managed identity authentication, this property is required
+ ///
+ /// For OpenAI:
+ /// - For OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable.
+ /// </remarks>
+ public string AIConnectionName { get; set; } = "";
+
/// <summary>
/// Gets or sets the prompt to generate completions for, encoded as a string.
/// </summary>
@@ -27,7 +44,7 @@ public TextCompletionInputAttribute(string prompt)
/// <summary>
/// Gets or sets the ID of the model to use.
/// </summary>
- public string Model { get; set; } = "gpt-3.5-turbo";
+ public string ChatModel { get; set; } = "gpt-3.5-turbo";
///
/// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output
@@ -56,4 +73,12 @@ public TextCompletionInputAttribute(string prompt)
/// Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
/// </remarks>
public string? MaxTokens { get; set; } = "100";
+
+ /// <summary>
+ /// Gets or sets a value indicating whether the model is a reasoning model.
+ /// </summary>
+ /// <remarks>
+ /// Warning: This is experimental and associated with the reasoning model until all models have parity in the expected properties.
+ /// </remarks>
+ public bool IsReasoningModel { get; set; }
}
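And the text-completion counterpart; whether a given deployment is a reasoning model is deployment-specific, so `IsReasoningModel = true` here is purely illustrative:

```csharp
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Extensions.OpenAI.TextCompletion;
using Microsoft.Azure.Functions.Worker.Http;

public static class WhoIsReasoningSketch
{
    [Function("WhoIsReasoning")]
    public static string WhoIsReasoning(
        [HttpTrigger(AuthorizationLevel.Function, Route = "whois-reasoning/{name}")]
        HttpRequestData req,
        // ChatModel replaces the old Model property.
        [TextCompletionInput("Who is {name}?",
            ChatModel = "%CHAT_MODEL_DEPLOYMENT_NAME%",
            IsReasoningModel = true)]
        TextCompletionResponse response)
    {
        return response.Content;
    }
}
```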
diff --git a/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs b/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs
index 4ad8280e..4c4d0134 100644
--- a/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs
+++ b/src/WebJobs.Extensions.OpenAI.AzureAISearch/AzureAISearchProvider.cs
@@ -246,16 +246,16 @@ async Task IndexSectionsAsync(SearchClient searchClient, SearchableDocument docu
{
int iteration = 0;
IndexDocumentsBatch<SearchDocument> batch = new();
- for (int i = 0; i < document.Embeddings?.Response?.Data.Count; i++)
+ for (int i = 0; i < document.Embeddings?.Response?.Count; i++)
{
batch.Actions.Add(new IndexDocumentsAction<SearchDocument>(
IndexActionType.MergeOrUpload,
new SearchDocument
{
["id"] = Guid.NewGuid().ToString("N"),
- ["text"] = document.Embeddings.Request.Input![i],
+ ["text"] = document.Embeddings.Request![i],
["title"] = Path.GetFileNameWithoutExtension(document.Title),
- ["embeddings"] = document.Embeddings.Response.Data[i].Embedding.ToArray() ?? Array.Empty(),
+ ["embeddings"] = document.Embeddings.Response[i].ToFloats().ToArray() ?? Array.Empty(),
["timestamp"] = DateTime.UtcNow
}));
iteration++;
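The indexing change above is mechanical: the 2.x SDK's `OpenAIEmbedding` exposes its vector through `ToFloats()` instead of the old `Data[i].Embedding` path. A one-liner sketch of the conversion:

```csharp
using OpenAI.Embeddings;

static class EmbeddingVectorSketch
{
    // ReadOnlyMemory<float> -> float[] for providers that need a materialized array.
    public static float[] ToVector(OpenAIEmbedding embedding) =>
        embedding.ToFloats().ToArray();
}
```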
diff --git a/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md b/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md
index f53671d3..5e1f2299 100644
--- a/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md
+++ b/src/WebJobs.Extensions.OpenAI.AzureAISearch/CHANGELOG.md
@@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Managed identity support and consistency established with other Azure Functions extensions
+### Changed
+
+- Updated Microsoft.Azure.WebJobs.Extensions.OpenAI to 0.19.0
+
## v0.3.0 - 2024/10/08
### Changed
diff --git a/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CHANGELOG.md b/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CHANGELOG.md
index bdef6010..187ff9f1 100644
--- a/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CHANGELOG.md
+++ b/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CHANGELOG.md
@@ -11,6 +11,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added DiskANN support for CosmosDB (MongoDB) Search Provider. Refer [README](../../samples/rag-cosmosdb/README.md) for more information on usage.
+### Changed
+
+- Updated Microsoft.Azure.WebJobs.Extensions.OpenAI to 0.19.0
+
## v0.3.0 - 2024/10/08
### Added
@@ -18,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added HNSW support for CosmosDB (MongoDB) Search Provider. Refer [README](../../samples/rag-cosmosdb/README.md) for more information on usage.
### Changed
+
- Updated nuget dependencies
## v0.2.0 - 2024/05/06
diff --git a/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs b/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs
index d4bc4870..82ad69dc 100644
--- a/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs
+++ b/src/WebJobs.Extensions.OpenAI.CosmosDBSearch/CosmosDBSearchProvider.cs
@@ -212,7 +212,7 @@ public void CreateVectorIndexIfNotExists(MongoClient cosmosClient)
async Task UpsertVectorAsync(MongoClient cosmosClient, SearchableDocument document)
{
List list = new();
- for (int i = 0; i < document.Embeddings?.Response?.Data.Count; i++)
+ for (int i = 0; i < document.Embeddings?.Response?.Count; i++)
{
BsonDocument vectorDocument =
new()
@@ -220,15 +220,14 @@ async Task UpsertVectorAsync(MongoClient cosmosClient, SearchableDocument docume
{ "id", Guid.NewGuid().ToString("N") },
{
this.cosmosDBSearchConfigOptions.Value.TextKey,
- document.Embeddings.Request.Input![i]
+ document.Embeddings.Request![i]
},
{ "title", Path.GetFileNameWithoutExtension(document.Title) },
{
this.cosmosDBSearchConfigOptions.Value.EmbeddingKey,
new BsonArray(
document
- .Embeddings.Response.Data[i]
- .Embedding.ToArray()
+ .Embeddings.Response[i].ToFloats().ToArray()
.Select(e => new BsonDouble(Convert.ToDouble(e)))
)
},
diff --git a/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md b/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md
index 51430935..a7c67706 100644
--- a/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md
+++ b/src/WebJobs.Extensions.OpenAI.Kusto/CHANGELOG.md
@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## v0.17.0 - Unreleased
+
+### Changed
+
+- Updated Microsoft.Azure.WebJobs.Extensions.OpenAI to 0.19.0
+
## v0.16.0 - 2024/10/08
### Changed
diff --git a/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs b/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs
index 542454a3..fe46f501 100644
--- a/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs
+++ b/src/WebJobs.Extensions.OpenAI.Kusto/KustoSearchProvider.cs
@@ -78,13 +78,13 @@ public async Task AddDocumentAsync(SearchableDocument document, CancellationToke
table.AppendColumn("Embeddings", typeof(object));
table.AppendColumn("Timestamp", typeof(DateTime));
- for (int i = 0; i < document.Embeddings?.Response?.Data.Count; i++)
+ for (int i = 0; i < document.Embeddings?.Response?.Count; i++)
{
table.Rows.Add(
Guid.NewGuid().ToString("N"),
Path.GetFileNameWithoutExtension(document.Title),
- document.Embeddings.Request.Input![i],
- GetEmbeddingsString(document.Embeddings.Response.Data[i].Embedding, true),
+ document.Embeddings.Request![i],
+ GetEmbeddingsString(document.Embeddings.Response[i].ToFloats().ToArray(), true),
DateTime.UtcNow);
}
diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs
new file mode 100644
index 00000000..9c8a4c21
--- /dev/null
+++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantBaseAttribute.cs
@@ -0,0 +1,122 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using Microsoft.Azure.WebJobs.Description;
+using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models;
+using OpenAI.Chat;
+
+namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants;
+
+/// <summary>
+/// Binding attribute for the OpenAI Assistant extension. This attribute is used to specify the configuration
+/// settings for the OpenAI Assistant when used in a function parameter. It allows setting various chat completion options.
+/// </summary>
+[Binding]
+[AttributeUsage(AttributeTargets.Parameter)]
+public class AssistantBaseAttribute : Attribute
+{
+    /// <summary>
+    /// Gets or sets the name of the Large Language Model to invoke for chat responses.
+    /// The default value is "gpt-3.5-turbo".
+    /// </summary>
+    /// <remarks>
+    /// This property supports binding expressions.
+    /// </remarks>
+    [AutoResolve]
+    public string ChatModel { get; set; } = OpenAIModels.DefaultChatModel;
+
+    /// <summary>
+    /// Gets or sets the name of the configuration section for AI service connectivity settings.
+    /// </summary>
+    /// <remarks>
+    /// This property specifies the name of the configuration section that contains connection details for the AI service.
+    ///
+    /// For Azure OpenAI:
+    /// - If specified, looks for "Endpoint" and "Key" values in this configuration section
+    /// - If not specified or the section doesn't exist, falls back to environment variables:
+    ///   AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY
+    /// - For user-assigned managed identity authentication, the configuration section is required
+    ///
+    /// For OpenAI:
+    /// - For the OpenAI service (non-Azure), set the OPENAI_API_KEY environment variable.
+    /// </remarks>
+    public string AIConnectionName { get; set; } = "";
+
+    /// <summary>
+    /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output
+    /// more random, while lower values like 0.2 will make it more focused and deterministic.
+    /// </summary>
+    /// <remarks>
+    /// It's generally recommended to use this or <see cref="TopP"/> but not both.
+    /// </remarks>
+    [AutoResolve]
+    public string? Temperature { get; set; } = "0.5";
+
+    /// <summary>
+    /// Gets or sets an alternative to sampling with temperature, called nucleus sampling, where the model considers
+    /// the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%
+    /// probability mass are considered.
+    /// </summary>
+    /// <remarks>
+    /// It's generally recommended to use this or <see cref="Temperature"/> but not both.
+    /// </remarks>
+    [AutoResolve]
+    public string? TopP { get; set; }
+
+    /// <summary>
+    /// Gets or sets a value indicating whether the model is a reasoning model.
+    /// </summary>
+    /// <remarks>
+    /// Warning: This property is experimental and applies to reasoning models until all models have parity in the expected properties.
+    /// </remarks>
+    public bool IsReasoningModel { get; set; }
+
+    /// <summary>
+    /// Gets or sets the maximum number of tokens to output in the completion. Default value = 100.
+    /// </summary>
+    /// <remarks>
+    /// The token count of your prompt plus max_tokens cannot exceed the model's context length.
+    /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
+    /// </remarks>
+    [AutoResolve]
+    public string? MaxTokens { get; set; } = "100";
+
+ internal ChatCompletionOptions BuildRequest()
+ {
+ ChatCompletionOptions request = new();
+ if (float.TryParse(this.TopP, out float topP))
+ {
+ request.TopP = topP;
+ }
+
+ if (this.IsReasoningModel)
+ {
+ this.MaxTokens = null;
+ this.Temperature = null; // property not supported for reasoning model.
+ }
+ else
+ {
+ if (int.TryParse(this.MaxTokens, out int maxTokens))
+ {
+ request.MaxOutputTokenCount = maxTokens;
+ }
+
+ if (float.TryParse(this.Temperature, out float temperature))
+ {
+ request.Temperature = temperature;
+ }
+ }
+
+ // ToDo: SetNewMaxCompletionTokensPropertyEnabled() has a bug in the current version
+ // of the Azure.AI.OpenAI SDK but is fixed in the next preview.
+ //
+ // This method doesn't swap max_tokens with max_completion_tokens and throws errors
+ // if max_tokens is set. max_completion_tokens is not configurable due to this bug.
+ //
+ // Hence, setting max_tokens to null for reasoning models until the fixed SDK version
+ // can be adopted.
+ // request.SetNewMaxCompletionTokensPropertyEnabled(this.IsReasoningModel);
+
+ return request;
+ }
+}
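
Since every derived binding now inherits these options, a hedged usage sketch of the in-process AssistantPost binding with the new properties; the function shape, config section name, and "o3-mini" deployment are illustrative, and the RecentMessages/Content property names follow the AssistantState and AssistantMessage types used elsewhere in this change:

    using System.Linq;
    using Microsoft.AspNetCore.Http;
    using Microsoft.AspNetCore.Mvc;
    using Microsoft.Azure.WebJobs;
    using Microsoft.Azure.WebJobs.Extensions.Http;
    using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants;

    public static class AssistantPostSample
    {
        [FunctionName("PostUserQuery")]
        public static IActionResult Post(
            [HttpTrigger(AuthorizationLevel.Function, "post", Route = "assistants/{chatId}")] HttpRequest req,
            [AssistantPost("{chatId}", "{Query.message}",
                AIConnectionName = "MyAzureOpenAI", // hypothetical config section
                ChatModel = "o3-mini",              // a reasoning-model deployment
                IsReasoningModel = true)]           // BuildRequest() then omits Temperature/MaxTokens
            AssistantState updatedState)
        {
            return new OkObjectResult(updatedState.RecentMessages.LastOrDefault()?.Content ?? "No response.");
        }
    }
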
diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs
index 56ed35fa..9e3c5faf 100644
--- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs
+++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantPostAttribute.cs
@@ -7,8 +7,13 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants;
[Binding]
[AttributeUsage(AttributeTargets.Parameter)]
-public sealed class AssistantPostAttribute : Attribute
+public sealed class AssistantPostAttribute : AssistantBaseAttribute
{
+    /// <summary>
+    /// Initializes a new instance of the <see cref="AssistantPostAttribute"/> class.
+    /// </summary>
+    /// <param name="id">The assistant identifier.</param>
+    /// <param name="userMessage">The user message.</param>
public AssistantPostAttribute(string id, string userMessage)
{
this.Id = id;
@@ -21,15 +26,6 @@ public AssistantPostAttribute(string id, string userMessage)
[AutoResolve]
public string Id { get; }
- ///
- /// Gets or sets the OpenAI chat model to use.
- ///
- ///
- /// When using Azure OpenAI, then should be the name of the model deployment.
- ///
- [AutoResolve]
- public string? Model { get; set; }
-
///
/// Gets or sets the user message to OpenAI.
///
diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantRuntimeState.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantRuntimeState.cs
index 55dbd2f5..03b10f6d 100644
--- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantRuntimeState.cs
+++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantRuntimeState.cs
@@ -6,13 +6,13 @@
namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants;
-record struct MessageRecord(DateTime Timestamp, ChatMessage ChatMessageEntity);
+record struct MessageRecord(DateTime Timestamp, AssistantMessage ChatMessageEntity);
[JsonObject(MemberSerialization.OptIn)]
class AssistantRuntimeState
{
[JsonProperty("messages")]
- public List? ChatMessages { get; set; }
+    public List<MessageRecord>? ChatMessages { get; set; }
[JsonProperty("totalTokens")]
public int TotalTokens { get; set; } = 0;
diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs
index e31e8a3b..cd5726b3 100644
--- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs
+++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantService.cs
@@ -1,13 +1,14 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
+using System.ClientModel;
using Azure;
-using Azure.AI.OpenAI;
using Azure.Data.Tables;
using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models;
using Microsoft.Extensions.Azure;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.Logging;
+using OpenAI.Chat;
namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants;
@@ -28,7 +29,7 @@ record InternalChatState(string Id, AssistantStateEntity Metadata, List
const int FunctionCallBatchLimit = 50;
const string DefaultChatStorage = "AzureWebJobsStorage";
- readonly OpenAIClient openAIClient;
+ readonly OpenAIClientFactory openAIClientFactory;
readonly IAssistantSkillInvoker skillInvoker;
readonly ILogger logger;
readonly AzureComponentFactory azureComponentFactory;
@@ -37,7 +38,7 @@ record InternalChatState(string Id, AssistantStateEntity Metadata, List();
+ this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory));
this.azureComponentFactory = azureComponentFactory ?? throw new ArgumentNullException(nameof(azureComponentFactory));
this.configuration = configuration ?? throw new ArgumentNullException(nameof(configuration));
}
@@ -117,7 +117,8 @@ async Task DeleteBatch()
partitionKey: request.Id,
messageIndex: 1, // 1-based index
content: request.Instructions,
- role: ChatRole.System);
+ role: ChatMessageRole.System,
+ toolCalls: null);
batch.Add(new TableTransactionAction(TableTransactionActionType.Add, chatMessageEntity));
}
@@ -152,7 +153,7 @@ public async Task GetStateAsync(AssistantQueryAttribute assistan
if (chatState is null)
{
this.logger.LogWarning("No assistant exists with ID = '{Id}'", id);
- return new AssistantState(id, false, default, default, 0, 0, Array.Empty());
+            return new AssistantState(id, false, default, default, 0, 0, Array.Empty<AssistantMessage>());
}
List filteredChatMessages = chatState.Messages
@@ -171,110 +172,100 @@ public async Task GetStateAsync(AssistantQueryAttribute assistan
chatState.Metadata.LastUpdatedAt,
chatState.Metadata.TotalMessages,
chatState.Metadata.TotalTokens,
- filteredChatMessages.Select(msg => new ChatMessage(msg.Content, msg.Role, msg.Name)).ToList());
+ filteredChatMessages.Select(msg => new AssistantMessage(msg.Content, msg.Role, msg.ToolCallsString)).ToList());
return state;
}
public async Task PostMessageAsync(AssistantPostAttribute attribute, CancellationToken cancellationToken)
{
+ // Validate inputs and prepare for processing
DateTime timeFilter = DateTime.UtcNow;
- if (string.IsNullOrEmpty(attribute.Id))
- {
- throw new ArgumentException("The assistant ID must be specified.", nameof(attribute));
- }
-
- if (string.IsNullOrEmpty(attribute.UserMessage))
- {
- throw new ArgumentException("The assistant must have a user message", nameof(attribute));
- }
+ this.ValidateAttributes(attribute);
this.logger.LogInformation("Posting message to assistant entity '{Id}'", attribute.Id);
-
TableClient tableClient = this.GetOrCreateTableClient(attribute.ChatStorageConnectionSetting, attribute.CollectionName);
+ // Load and validate chat state
InternalChatState? chatState = await this.LoadChatStateAsync(attribute.Id, tableClient, cancellationToken);
-
- // Check if assistant has been deactivated
if (chatState is null || !chatState.Metadata.Exists)
{
- this.logger.LogWarning("[{Id}] Ignoring request sent to nonexistent assistant.", attribute.Id);
- return new AssistantState(attribute.Id, false, default, default, 0, 0, Array.Empty());
+ return this.CreateNonExistentAssistantState(attribute.Id);
}
this.logger.LogInformation("[{Id}] Received message: {Text}", attribute.Id, attribute.UserMessage);
- // Create a batch of table transaction actions
+ // Process the conversation
List batch = new();
+ this.AddUserMessageToChat(attribute, chatState, batch);
+
+ await this.ProcessConversationWithLLM(attribute, chatState, batch, cancellationToken);
- // Add the user message as a new Chat message entity
+ // Update state and persist changes
+ this.UpdateAssistantState(chatState, batch);
+ await tableClient.SubmitTransactionAsync(batch, cancellationToken);
+
+ // Return results
+ return this.CreateAssistantStateResponse(attribute.Id, chatState, timeFilter);
+ }
+
+ // Helper methods
+ void ValidateAttributes(AssistantPostAttribute attribute)
+ {
+ if (string.IsNullOrEmpty(attribute.Id))
+ {
+ throw new ArgumentException("The assistant ID must be specified.", nameof(attribute));
+ }
+
+ if (string.IsNullOrEmpty(attribute.UserMessage))
+ {
+ throw new ArgumentException("The assistant must have a user message", nameof(attribute));
+ }
+ }
+
+ AssistantState CreateNonExistentAssistantState(string id)
+ {
+ this.logger.LogWarning("[{Id}] Ignoring request sent to nonexistent assistant.", id);
+        return new AssistantState(id, false, default, default, 0, 0, Array.Empty<AssistantMessage>());
+ }
+
+    void AddUserMessageToChat(AssistantPostAttribute attribute, InternalChatState chatState, List<TableTransactionAction> batch)
+ {
ChatMessageTableEntity chatMessageEntity = new(
partitionKey: attribute.Id,
messageIndex: ++chatState.Metadata.TotalMessages,
content: attribute.UserMessage,
- role: ChatRole.User);
+ role: ChatMessageRole.User,
+ toolCalls: null);
chatState.Messages.Add(chatMessageEntity);
-
- // Add the chat message to the batch
batch.Add(new TableTransactionAction(TableTransactionActionType.Add, chatMessageEntity));
+ }
- string deploymentName = attribute.Model ?? OpenAIModels.DefaultChatModel;
- IList? functions = this.skillInvoker.GetFunctionsDefinitions();
+    async Task ProcessConversationWithLLM(
+        AssistantPostAttribute attribute,
+        InternalChatState chatState,
+        List<TableTransactionAction> batch,
+        CancellationToken cancellationToken)
+    {
+        IList<ChatTool>? functions = this.skillInvoker.GetFunctionsDefinitions();
// We loop if the model returns function calls. Otherwise, we break after receiving a response.
while (true)
{
- // Get the next response from the LLM
- ChatCompletionsOptions chatRequest = new(deploymentName, ToOpenAIChatRequestMessages(chatState.Messages));
- if (functions is not null)
- {
- foreach (ChatCompletionsFunctionToolDefinition fn in functions)
- {
- chatRequest.Tools.Add(fn);
- }
- }
-
- Response response = await this.openAIClient.GetChatCompletionsAsync(
- chatRequest,
- cancellationToken);
+ // Get the LLM response
+            ClientResult<ChatCompletion> response = await this.GetLLMResponse(attribute, chatState, functions, cancellationToken);
- // We don't normally expect more than one message, but just in case we get multiple messages,
- // return all of them separated by two newlines.
- string replyMessage = string.Join(
- Environment.NewLine + Environment.NewLine,
- response.Value.Choices.Select(choice => choice.Message.Content));
- if (!string.IsNullOrWhiteSpace(replyMessage))
+ // Process text response if available
+ string replyMessage = this.FormatReplyMessage(response);
+ if (!string.IsNullOrWhiteSpace(replyMessage) || response.Value.ToolCalls.Any())
{
- this.logger.LogInformation(
- "[{Id}] Got LLM response consisting of {Count} tokens: {Text}",
- attribute.Id,
- response.Value.Usage.CompletionTokens,
- replyMessage);
-
- // Add the user message as a new Chat message entity
- ChatMessageTableEntity replyFromAssistantEntity = new(
- partitionKey: attribute.Id,
- messageIndex: ++chatState.Metadata.TotalMessages,
- content: replyMessage,
- role: ChatRole.Assistant);
- chatState.Messages.Add(replyFromAssistantEntity);
-
- // Add the reply from assistant chat message to the batch
- batch.Add(new TableTransactionAction(TableTransactionActionType.Add, replyFromAssistantEntity));
-
- this.logger.LogInformation(
- "[{Id}] Chat length is now {Count} messages",
- attribute.Id,
- chatState.Metadata.TotalMessages);
+ this.LogAndAddAssistantReply(attribute.Id, replyMessage, response, chatState, batch);
}
- // Set the total tokens that have been consumed.
- chatState.Metadata.TotalTokens = response.Value.Usage.TotalTokens;
+ // Update token count
+ chatState.Metadata.TotalTokens = response.Value.Usage.TotalTokenCount;
- // Check for function calls (which are described in the API as tools)
- List functionCalls = response.Value.Choices
- .SelectMany(c => c.Message.ToolCalls)
- .OfType()
- .ToList();
+ // Handle function calls
+            List<ChatToolCall> functionCalls = response.Value.ToolCalls.OfType<ChatToolCall>().ToList();
if (functionCalls.Count == 0)
{
// No function calls, so we're done
@@ -283,107 +274,193 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib
if (batch.Count > FunctionCallBatchLimit)
{
- // Too many function calls, something might be wrong. Break out of the loop
- // to avoid infinite loops and to avoid exceeding the batch size limit of 100.
- this.logger.LogWarning(
- "[{Id}] Ignoring {Count} function call(s) in response due to exceeding the limit of {Limit}.",
- attribute.Id,
- functionCalls.Count,
- FunctionCallBatchLimit);
+ // Too many function calls, something might be wrong
+ this.LogBatchLimitExceeded(attribute.Id, functionCalls.Count);
break;
}
- // Loop case: found some functions to execute
- this.logger.LogInformation(
- "[{Id}] Found {Count} function call(s) in response",
- attribute.Id,
- functionCalls.Count);
+ // Process function calls
+ await this.ProcessFunctionCalls(attribute.Id, functionCalls, chatState, batch, cancellationToken);
+ }
+ }
- // Invoke the function calls and add the responses to the chat history.
- List> tasks = new(capacity: functionCalls.Count);
- foreach (ChatCompletionsFunctionToolCall call in functionCalls)
+    async Task<ClientResult<ChatCompletion>> GetLLMResponse(
+        AssistantPostAttribute attribute,
+        InternalChatState chatState,
+        IList<ChatTool>? functions,
+        CancellationToken cancellationToken)
+ {
+ ChatCompletionOptions chatRequest = attribute.BuildRequest();
+ if (functions is not null)
+ {
+ foreach (ChatTool fn in functions)
{
- // CONSIDER: Call these in parallel
- this.logger.LogInformation(
- "[{Id}] Calling function '{Name}' with arguments: {Args}",
- attribute.Id,
- call.Name,
- call.Arguments);
-
- string? functionResult;
- try
- {
- // NOTE: In Consumption plans, calling a function from another function results in double-billing.
- // CONSIDER: Use a background thread to invoke the action to avoid double-billing.
- functionResult = await this.skillInvoker.InvokeAsync(call, cancellationToken);
-
- this.logger.LogInformation(
- "[{id}] Function '{Name}' returned the following content: {Content}",
- attribute.Id,
- call.Name,
- functionResult);
- }
- catch (Exception ex)
- {
- this.logger.LogError(
- ex,
- "[{id}] Function '{Name}' failed with an unhandled exception",
- attribute.Id,
- call.Name);
-
- // CONSIDER: Automatic retries?
- functionResult = "The function call failed. Let the user know and ask if they'd like you to try again";
- }
-
- if (string.IsNullOrWhiteSpace(functionResult))
- {
- // When experimenting with gpt-4-0613, an empty result would cause the model to go into a
- // function calling loop. By instead providing a result with some instructions, we were able
- // to get the model to response to the user in a natural way.
- functionResult = "The function call succeeded. Let the user know that you completed the action.";
- }
-
- ChatMessageTableEntity functionResultEntity = new(
- partitionKey: attribute.Id,
- messageIndex: ++chatState.Metadata.TotalMessages,
- content: functionResult,
- role: ChatRole.Function,
- name: call.Name);
- chatState.Messages.Add(functionResultEntity);
-
- batch.Add(new TableTransactionAction(TableTransactionActionType.Add, functionResultEntity));
+ chatRequest.Tools.Add(fn);
}
}
- // Update the assistant state entity
+        IEnumerable<ChatMessage> chatMessages = ToOpenAIChatRequestMessages(chatState.Messages);
+
+ return await this.openAIClientFactory.GetChatClient(
+ attribute.AIConnectionName,
+ attribute.ChatModel).CompleteChatAsync(chatMessages, chatRequest, cancellationToken: cancellationToken);
+ }
+
+    string FormatReplyMessage(ClientResult<ChatCompletion> response)
+ {
+ return string.Join(
+ Environment.NewLine + Environment.NewLine,
+ response.Value.Content.Select(message => message.Text));
+ }
+
+    void LogAndAddAssistantReply(
+        string assistantId,
+        string replyMessage,
+        ClientResult<ChatCompletion> response,
+        InternalChatState chatState,
+        List<TableTransactionAction> batch)
+ {
+ this.logger.LogInformation(
+ "[{Id}] Got LLM response consisting of {Count} tokens: [{Text}] && {Count} ToolCalls",
+ assistantId,
+ response.Value.Usage.OutputTokenCount,
+ replyMessage,
+ response.Value.ToolCalls.Count);
+
+ ChatMessageTableEntity replyFromAssistantEntity = new(
+ partitionKey: assistantId,
+ messageIndex: ++chatState.Metadata.TotalMessages,
+ content: replyMessage,
+ role: ChatMessageRole.Assistant,
+ toolCalls: response.Value.ToolCalls);
+
+ chatState.Messages.Add(replyFromAssistantEntity);
+ batch.Add(new TableTransactionAction(TableTransactionActionType.Add, replyFromAssistantEntity));
+
+ this.logger.LogInformation(
+ "[{Id}] Chat length is now {Count} messages",
+ assistantId,
+ chatState.Metadata.TotalMessages);
+ }
+
+ void LogBatchLimitExceeded(string assistantId, int functionCallCount)
+ {
+ this.logger.LogWarning(
+ "[{Id}] Ignoring {Count} function call(s) in response due to exceeding the limit of {Limit}.",
+ assistantId,
+ functionCallCount,
+ FunctionCallBatchLimit);
+ }
+
+    async Task ProcessFunctionCalls(
+        string assistantId,
+        List<ChatToolCall> functionCalls,
+        InternalChatState chatState,
+        List<TableTransactionAction> batch,
+        CancellationToken cancellationToken)
+ {
+ this.logger.LogInformation(
+ "[{Id}] Found {Count} function call(s) in response",
+ assistantId,
+ functionCalls.Count);
+
+ foreach (ChatToolCall call in functionCalls)
+ {
+ await this.ProcessSingleFunctionCall(assistantId, call, chatState, batch, cancellationToken);
+ }
+ }
+
+ async Task ProcessSingleFunctionCall(
+ string assistantId,
+ ChatToolCall call,
+ InternalChatState chatState,
+        List<TableTransactionAction> batch,
+ CancellationToken cancellationToken)
+ {
+ this.logger.LogInformation(
+ "[{Id}] Calling function '{Name}' with arguments: {Args}",
+ assistantId,
+ call.FunctionName,
+ call.FunctionArguments);
+
+ string? functionResult = await this.InvokeFunctionWithErrorHandling(assistantId, call, cancellationToken);
+
+ if (string.IsNullOrWhiteSpace(functionResult))
+ {
+ functionResult = "The function call succeeded. Let the user know that you completed the action.";
+ }
+
+ ChatMessageTableEntity functionResultEntity = new(
+ partitionKey: assistantId,
+ messageIndex: ++chatState.Metadata.TotalMessages,
+ content: $"Function Name: '{call.FunctionName}' and Function Result: '{functionResult}'",
+ role: ChatMessageRole.Tool,
+ name: call.Id,
+ toolCalls: null);
+
+ chatState.Messages.Add(functionResultEntity);
+ batch.Add(new TableTransactionAction(TableTransactionActionType.Add, functionResultEntity));
+ }
+
+    async Task<string?> InvokeFunctionWithErrorHandling(
+ string assistantId,
+ ChatToolCall call,
+ CancellationToken cancellationToken)
+ {
+ try
+ {
+ // NOTE: In Consumption plans, calling a function from another function results in double-billing.
+ // CONSIDER: Use a background thread to invoke the action to avoid double-billing.
+ string? result = await this.skillInvoker.InvokeAsync(call, cancellationToken);
+
+ this.logger.LogInformation(
+ "[{id}] Function '{Name}' returned the following content: {Content}",
+ assistantId,
+ call.FunctionName,
+ result);
+
+ return result;
+ }
+ catch (Exception ex)
+ {
+ this.logger.LogError(
+ ex,
+ "[{id}] Function '{Name}' failed with an unhandled exception",
+ assistantId,
+ call.FunctionName);
+
+ // CONSIDER: Automatic retries?
+ return "The function call failed. Let the user know and ask if they'd like you to try again";
+ }
+ }
+
+    void UpdateAssistantState(InternalChatState chatState, List<TableTransactionAction> batch)
+ {
chatState.Metadata.TotalMessages = chatState.Messages.Count;
chatState.Metadata.LastUpdatedAt = DateTime.UtcNow;
batch.Add(new TableTransactionAction(TableTransactionActionType.UpdateMerge, chatState.Metadata));
+ }
- // Add the batch of table transaction actions to the table
- await tableClient.SubmitTransactionAsync(batch, cancellationToken);
-
- // return the latest assistant message in the chat state
+ AssistantState CreateAssistantStateResponse(string assistantId, InternalChatState chatState, DateTime timeFilter)
+ {
List filteredChatMessages = chatState.Messages
- .Where(msg => msg.CreatedAt > timeFilter && msg.Role == ChatRole.Assistant)
+ .Where(msg => msg.CreatedAt > timeFilter && msg.Role == ChatMessageRole.Assistant.ToString())
.ToList();
this.logger.LogInformation(
"Returning {Count}/{Total} chat messages from entity '{Id}'",
filteredChatMessages.Count,
chatState.Metadata.TotalMessages,
- attribute.Id);
+ assistantId);
- AssistantState state = new(
- attribute.Id,
+ return new AssistantState(
+ assistantId,
true,
chatState.Metadata.CreatedAt,
chatState.Metadata.LastUpdatedAt,
chatState.Metadata.TotalMessages,
chatState.Metadata.TotalTokens,
- filteredChatMessages.Select(msg => new ChatMessage(msg.Content, msg.Role, msg.Name)).ToList());
-
- return state;
+ filteredChatMessages.Select(msg => new AssistantMessage(msg.Content, msg.Role, msg.ToolCallsString)).ToList());
}
async Task LoadChatStateAsync(string id, TableClient tableClient, CancellationToken cancellationToken)
@@ -420,26 +497,30 @@ public async Task PostMessageAsync(AssistantPostAttribute attrib
return new InternalChatState(id, assistantStateEntity, chatMessageList);
}
- static IEnumerable ToOpenAIChatRequestMessages(IEnumerable entities)
+    static IEnumerable<ChatMessage> ToOpenAIChatRequestMessages(IEnumerable<ChatMessageTableEntity> entities)
{
foreach (ChatMessageTableEntity entity in entities)
{
switch (entity.Role.ToLowerInvariant())
{
case "user":
- yield return new ChatRequestUserMessage(entity.Content);
+ yield return new UserChatMessage(entity.Content);
break;
case "assistant":
- yield return new ChatRequestAssistantMessage(entity.Content);
+ if (entity.ToolCalls != null && entity.ToolCalls.Any())
+ {
+ yield return new AssistantChatMessage(entity.ToolCalls);
+ }
+ else
+ {
+ yield return new AssistantChatMessage(entity.Content);
+ }
break;
case "system":
- yield return new ChatRequestSystemMessage(entity.Content);
- break;
- case "function":
- yield return new ChatRequestFunctionMessage(entity.Name, entity.Content);
+ yield return new SystemChatMessage(entity.Content);
break;
case "tool":
- yield return new ChatRequestToolMessage(entity.Content, toolCallId: entity.Name);
+ yield return new ToolChatMessage(toolCallId: entity.Name, entity.Content);
break;
default:
throw new InvalidOperationException($"Unknown chat role '{entity.Role}'");
@@ -483,7 +564,7 @@ TableClient GetOrCreateTableClient(string? chatStorageConnectionSetting, string?
// Else, will use the connection string
connectionStringName = chatStorageConnectionSetting ?? DefaultChatStorage;
string connectionString = this.configuration.GetValue(connectionStringName);
-
+
this.logger.LogInformation("using connection string for table service client");
this.tableServiceClient = new TableServiceClient(connectionString);
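
The refactored PostMessageAsync loop above implements the standard OpenAI 2.x tool-calling round trip: send messages plus tool definitions, replay the assistant's tool-call message, answer each call with a ToolChatMessage, and repeat until the model stops requesting tools. A condensed, self-contained sketch of that protocol against the plain OpenAI client (the model, tool schema, and stubbed result are illustrative):

    using System.ClientModel;
    using OpenAI.Chat;

    ChatClient client = new("gpt-4o", new ApiKeyCredential(Environment.GetEnvironmentVariable("OPENAI_API_KEY")!));

    ChatTool getWeather = ChatTool.CreateFunctionTool(
        "GetWeather",
        "Returns the current weather for a city",
        BinaryData.FromString("""{"type":"object","properties":{"city":{"type":"string"}},"required":["city"]}"""));

    List<ChatMessage> messages = new() { new UserChatMessage("What's the weather in Oslo?") };
    ChatCompletionOptions options = new();
    options.Tools.Add(getWeather);

    ClientResult<ChatCompletion> result = await client.CompleteChatAsync(messages, options);
    while (result.Value.FinishReason == ChatFinishReason.ToolCalls)
    {
        // Echo the assistant's tool-call message, then answer each call.
        messages.Add(new AssistantChatMessage(result.Value));
        foreach (ChatToolCall call in result.Value.ToolCalls)
        {
            messages.Add(new ToolChatMessage(call.Id, "Sunny, 18C")); // stubbed tool result
        }

        result = await client.CompleteChatAsync(messages, options);
    }

    Console.WriteLine(result.Value.Content[0].Text);
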
diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs
index 749e868b..25209b4c 100644
--- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs
+++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerAttribute.cs
@@ -42,12 +42,4 @@ public AssistantSkillTriggerAttribute(string functionDescription)
/// https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools.
///
public string? ParameterDescriptionJson { get; set; }
-
- ///
- /// Gets or sets the OpenAI chat model to use.
- ///
- ///
- /// When using Azure OpenAI, then should be the name of the model deployment.
- ///
- public string Model { get; set; } = "gpt-3.5-turbo";
}
\ No newline at end of file
diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs
index 55eed508..8b2ee6f9 100644
--- a/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs
+++ b/src/WebJobs.Extensions.OpenAI/Assistants/AssistantSkillTriggerBindingProvider.cs
@@ -113,10 +113,10 @@ public Task BindAsync(object value, ValueBindingContext context)
SkillInvocationContext skillInvocationContext = (SkillInvocationContext)value;
object? convertedValue;
- if (!string.IsNullOrEmpty(skillInvocationContext.Arguments))
+ if (!string.IsNullOrEmpty(skillInvocationContext.Arguments?.ToString()))
{
// We expect that input to always be a string value in the form {"paramName":paramValue}
- JObject argsJson = JObject.Parse(skillInvocationContext.Arguments);
+ JObject argsJson = JObject.Parse(skillInvocationContext.Arguments.ToString());
JToken? paramValue = argsJson[this.parameterInfo.Name];
convertedValue = paramValue?.ToObject(destinationType);
}
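
The change above follows ChatToolCall.FunctionArguments becoming BinaryData in the 2.x SDK. A tiny sketch of the parse path, under the same assumption that arguments arrive as {"paramName":paramValue}; the parameter name is illustrative:

    using Newtonsoft.Json.Linq;

    BinaryData arguments = BinaryData.FromString("""{"taskDescription":"Buy milk"}""");
    JObject argsJson = JObject.Parse(arguments.ToString());
    string? taskDescription = argsJson["taskDescription"]?.ToObject<string>();
    Console.WriteLine(taskDescription); // prints: Buy milk
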
diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/BuiltInFunctionsProvider.cs b/src/WebJobs.Extensions.OpenAI/Assistants/BuiltInFunctionsProvider.cs
deleted file mode 100644
index db996693..00000000
--- a/src/WebJobs.Extensions.OpenAI/Assistants/BuiltInFunctionsProvider.cs
+++ /dev/null
@@ -1,71 +0,0 @@
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT License.
-
-using System.Collections.Immutable;
-using System.Reflection;
-using Microsoft.Azure.WebJobs.Script.Description;
-using Newtonsoft.Json.Linq;
-
-namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants;
-
-///
-/// Class that defines all the built-in functions for executing CNCF Serverless Workflows.
-/// IMPORTANT: Renaming methods in this class is a breaking change!
-///
-class BuiltInFunctionsProvider : IFunctionProvider
-{
- ///
- Task> IFunctionProvider.GetFunctionMetadataAsync() =>
- Task.FromResult(this.GetFunctionMetadata().ToImmutableArray());
-
- // TODO: Not sure what this is for...
- ///
- ImmutableDictionary> IFunctionProvider.FunctionErrors =>
- new Dictionary>().ToImmutableDictionary();
-
-
- internal static string GetBuiltInFunctionName(string functionName)
- {
- return $"OpenAI::{functionName}";
- }
-
- ///
- /// Returns an enumeration of all the function triggers defined in this class.
- ///
- IEnumerable GetFunctionMetadata()
- {
- foreach (MethodInfo method in this.GetType().GetMethods())
- {
- if (method.GetCustomAttribute() is not FunctionNameAttribute)
- {
- // Not an Azure Function definition
- continue;
- }
-
- FunctionMetadata metadata = new()
- {
- // NOTE: We always use the method name and ignore the function name
- Name = GetBuiltInFunctionName(method.Name),
- ScriptFile = $"assembly:{method.ReflectedType.Assembly.FullName}",
- EntryPoint = $"{method.ReflectedType.FullName}.{method.Name}",
- Language = "DotNetAssembly",
- };
-
- // Scan the parameters for binding attributes and add them to the bindings collection
- // so that we can register them with the Functions runtime.
- foreach (ParameterInfo parameter in method.GetParameters())
- {
- if (parameter.GetCustomAttribute() is not null)
- {
- // NOTE: We assume each OpenAI service function in this file defines the parameter name as "service".
- metadata.Bindings.Add(BindingMetadata.Create(new JObject(
- new JProperty("type", "openAIService"),
- new JProperty("name", "service"),
- new JProperty("direction", "in"))));
- }
- }
-
- yield return metadata;
- }
- }
-}
diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/ChatCompletionsJsonConverter.cs b/src/WebJobs.Extensions.OpenAI/Assistants/ChatCompletionsJsonConverter.cs
index f8ad8c8c..6d0b7e49 100644
--- a/src/WebJobs.Extensions.OpenAI/Assistants/ChatCompletionsJsonConverter.cs
+++ b/src/WebJobs.Extensions.OpenAI/Assistants/ChatCompletionsJsonConverter.cs
@@ -4,19 +4,19 @@
using System.ClientModel.Primitives;
using System.Text.Json;
using System.Text.Json.Serialization;
-using Azure.AI.OpenAI;
+using OpenAI.Chat;
namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants;
-class ChatCompletionsJsonConverter : JsonConverter
+class ChatCompletionsJsonConverter : JsonConverter<ChatCompletion>
{
static readonly ModelReaderWriterOptions modelReaderWriterOptions = new("J");
- public override ChatCompletions Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
+ public override ChatCompletion Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
throw new NotImplementedException();
}
- public override void Write(Utf8JsonWriter writer, ChatCompletions value, JsonSerializerOptions options)
+ public override void Write(Utf8JsonWriter writer, ChatCompletion value, JsonSerializerOptions options)
{
- ((IJsonModel)value).Write(writer, modelReaderWriterOptions);
+        ((IJsonModel<ChatCompletion>)value).Write(writer, modelReaderWriterOptions);
}
}
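
The converter above writes through the SDK's IJsonModel plumbing; an equivalent hedged sketch using ModelReaderWriter, the documented entry point for round-tripping OpenAI 2.x models with the "J" (JSON) wire format:

    using System.ClientModel.Primitives;
    using OpenAI.Chat;

    static string ToWireJson(ChatCompletion completion) =>
        ModelReaderWriter.Write(completion, new ModelReaderWriterOptions("J")).ToString();
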
diff --git a/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs b/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs
index d3e539ba..c11f55bd 100644
--- a/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs
+++ b/src/WebJobs.Extensions.OpenAI/Assistants/IAssistantSkillInvoker.cs
@@ -4,28 +4,28 @@
using System.Reflection;
using System.Runtime.ExceptionServices;
using System.Text;
-using Azure.AI.OpenAI;
using Microsoft.Azure.WebJobs.Host.Executors;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
+using OpenAI.Chat;
namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants;
public interface IAssistantSkillInvoker
{
- IList? GetFunctionsDefinitions();
- Task InvokeAsync(ChatCompletionsFunctionToolCall call, CancellationToken cancellationToken);
+    IList<ChatTool>? GetFunctionsDefinitions();
+    Task<string?> InvokeAsync(ChatToolCall call, CancellationToken cancellationToken);
}
class SkillInvocationContext
{
- public SkillInvocationContext(string arguments)
+ public SkillInvocationContext(BinaryData arguments)
{
this.Arguments = arguments;
}
// The arguments are passed as a JSON object in the form of {"paramName":paramValue}
- public string Arguments { get; }
+ public BinaryData Arguments { get; }
// The result of the function invocation, if any
public object? Result { get; set; }
@@ -70,14 +70,14 @@ internal void UnregisterSkill(string name)
this.skills.Remove(name);
}
- IList? IAssistantSkillInvoker.GetFunctionsDefinitions()
+    IList<ChatTool>? IAssistantSkillInvoker.GetFunctionsDefinitions()
{
if (this.skills.Count == 0)
{
return null;
}
- List functions = new(capacity: this.skills.Count);
+        List<ChatTool> functions = new(capacity: this.skills.Count);
foreach (Skill skill in this.skills.Values)
{
// The parameters can be defined in the attribute JSON or can be inferred from
@@ -85,12 +85,12 @@ internal void UnregisterSkill(string name)
string parametersJson = skill.Attribute.ParameterDescriptionJson ??
JsonConvert.SerializeObject(GetParameterDefinition(skill));
- functions.Add(new ChatCompletionsFunctionToolDefinition
- {
- Name = skill.Name,
- Description = skill.Attribute.FunctionDescription,
- Parameters = BinaryData.FromBytes(Encoding.UTF8.GetBytes(parametersJson)),
- });
+ ChatTool chatTool = ChatTool.CreateFunctionTool(
+ skill.Name,
+ skill.Attribute.FunctionDescription,
+ BinaryData.FromBytes(Encoding.UTF8.GetBytes(parametersJson))
+ );
+ functions.Add(chatTool);
}
return functions;
@@ -140,7 +140,7 @@ static Dictionary GetParameterDefinition(Skill skill)
}
async Task IAssistantSkillInvoker.InvokeAsync(
- ChatCompletionsFunctionToolCall call,
+ ChatToolCall call,
CancellationToken cancellationToken)
{
if (call is null)
@@ -148,17 +148,17 @@ static Dictionary GetParameterDefinition(Skill skill)
throw new ArgumentNullException(nameof(call));
}
- if (call.Name is null)
+ if (call.FunctionName is null)
{
throw new ArgumentException("The function call must have a name", nameof(call));
}
- if (!this.skills.TryGetValue(call.Name, out Skill? skill))
+ if (!this.skills.TryGetValue(call.FunctionName, out Skill? skill))
{
- throw new InvalidOperationException($"No skill registered with name '{call.Name}'");
+ throw new InvalidOperationException($"No skill registered with name '{call.FunctionName}'");
}
- SkillInvocationContext skillInvocationContext = new(call.Arguments);
+ SkillInvocationContext skillInvocationContext = new(call.FunctionArguments);
// This call may throw if the Functions host is shutting down or if there is an internal error
// in the Functions runtime. We don't currently try to handle these exceptions.
@@ -170,7 +170,7 @@ static Dictionary GetParameterDefinition(Skill skill)
InvokeHandler = async userCodeInvoker =>
{
// Invoke the function and attempt to get the result.
- this.logger.LogInformation("Invoking user-code function '{Name}'", call.Name);
+ this.logger.LogInformation("Invoking user-code function '{Name}'", call.FunctionName);
Task invokeTask = userCodeInvoker.Invoke();
if (invokeTask is Task
[JsonProperty("chat")]
- public ChatCompletions Chat { get; }
+ public ChatCompletion Chat { get; }
///
/// Gets the latest response message from the OpenAI Chat API.
///
[JsonProperty("response")]
- public string Response => this.Chat.Choices.Last().Message.Content;
+ public string Response => this.Chat.Content.Last().Text;
}
diff --git a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs
index b853036b..3385ca3d 100644
--- a/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs
+++ b/src/WebJobs.Extensions.OpenAI/Search/SemanticSearchConverter.cs
@@ -1,14 +1,15 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
+using System.ClientModel;
using System.Text;
using System.Text.Json;
-using Azure;
using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants;
using Microsoft.Azure.WebJobs.Extensions.OpenAI.Embeddings;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
-using OpenAISDK = Azure.AI.OpenAI;
+using OpenAI.Chat;
+using OpenAI.Embeddings;
namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Search;
@@ -16,27 +17,27 @@ class SemanticSearchConverter :
IAsyncConverter,
IAsyncConverter
{
- readonly OpenAISDK.OpenAIClient openAIClient;
+ readonly OpenAIClientFactory openAIClientFactory;
readonly ILogger logger;
readonly ISearchProvider? searchProvider;
static readonly JsonSerializerOptions options = new()
{
- Converters = {
+ Converters = {
new SearchableDocumentJsonConverter(),
new EmbeddingsContextConverter(),
new EmbeddingsOptionsJsonConverter(),
new ChatCompletionsJsonConverter()}
- };
+ };
public SemanticSearchConverter(
- OpenAISDK.OpenAIClient openAIClient,
+ OpenAIClientFactory openAIClientFactory,
ILoggerFactory loggerFactory,
IEnumerable searchProviders,
IOptions openAiConfigOptions)
{
- this.openAIClient = openAIClient ?? throw new ArgumentNullException(nameof(openAIClient));
this.logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory));
+ this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory));
openAiConfigOptions.Value.SearchProvider.TryGetValue("type", out object value);
this.logger.LogInformation("Type of the searchProvider configured in host file: {type}", value);
@@ -61,14 +62,13 @@ async Task ConvertHelperAsync(
}
// Get the embeddings for the query, which will be used for doing a semantic search
- OpenAISDK.EmbeddingsOptions embeddingsRequest = new(attribute.EmbeddingsModel, new List { attribute.Query });
+ this.logger.LogInformation("Sending OpenAI embeddings request: {request}", attribute.Query);
+        ClientResult<OpenAIEmbedding> embedding = await this.openAIClientFactory.GetEmbeddingClient(
+ attribute.AIConnectionName,
+ attribute.EmbeddingsModel).GenerateEmbeddingAsync(attribute.Query, cancellationToken: cancellationToken);
+ this.logger.LogInformation("Received OpenAI embeddings");
- this.logger.LogInformation("Sending OpenAI embeddings request: {request}", embeddingsRequest.Input);
- Response embeddingsResponse = await this.openAIClient.GetEmbeddingsAsync(embeddingsRequest, cancellationToken);
- this.logger.LogInformation("Received OpenAI embeddings count: {response}", embeddingsResponse.Value.Data.Count);
-
-
- ConnectionInfo connectionInfo = new(attribute.ConnectionName, attribute.Collection);
+ ConnectionInfo connectionInfo = new(attribute.SearchConnectionName, attribute.Collection);
if (string.IsNullOrEmpty(connectionInfo.ConnectionName))
{
throw new InvalidOperationException("No connection string information was provided.");
@@ -81,7 +81,7 @@ async Task ConvertHelperAsync(
// Search for relevant document snippets using the original query and the embeddings
SearchRequest searchRequest = new(
attribute.Query,
- embeddingsResponse.Value.Data[0].Embedding,
+ embedding.Value.ToFloats(),
attribute.MaxKnowledgeCount,
connectionInfo);
SearchResponse searchResponse = await this.searchProvider.SearchAsync(searchRequest);
@@ -95,20 +95,17 @@ async Task ConvertHelperAsync(
}
// Call the chat API with the new combined prompt to get a response back
- OpenAISDK.ChatCompletionsOptions chatCompletionsOptions = new()
- {
- DeploymentName = attribute.ChatModel,
- Messages =
+        IList<ChatMessage> messages = new List<ChatMessage>()
{
- new OpenAISDK.ChatRequestSystemMessage(promptBuilder.ToString()),
- new OpenAISDK.ChatRequestUserMessage(attribute.Query),
- }
- };
+ new SystemChatMessage(promptBuilder.ToString()),
+ new UserChatMessage(attribute.Query),
+ };
- Response chatResponse = await this.openAIClient.GetChatCompletionsAsync(chatCompletionsOptions);
+ ChatCompletionOptions completionOptions = attribute.BuildRequest();
+        ClientResult<ChatCompletion> chatResponse = await this.openAIClientFactory.GetChatClient(attribute.AIConnectionName, attribute.ChatModel).CompleteChatAsync(messages, completionOptions);
// Give the user the full context, including the embeddings information as well as the chat info
- return new SemanticSearchContext(new EmbeddingsContext(embeddingsRequest, embeddingsResponse), chatResponse);
+        return new SemanticSearchContext(new EmbeddingsContext(new List<string> { attribute.Query }, null), chatResponse);
}
async Task IAsyncConverter.ConvertAsync(
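
A hedged end-to-end usage sketch for the refactored semantic-search flow above, modeled on the in-process RAG samples; the connection setting, collection name, config section, and model deployments are illustrative assumptions:

    using Microsoft.AspNetCore.Mvc;
    using Microsoft.Azure.WebJobs;
    using Microsoft.Azure.WebJobs.Extensions.Http;
    using Microsoft.Azure.WebJobs.Extensions.OpenAI.Search;

    public class SemanticSearchRequest
    {
        public string? Prompt { get; set; }
    }

    public static class PromptFileSample
    {
        [FunctionName("PromptFile")]
        public static IActionResult Prompt(
            [HttpTrigger(AuthorizationLevel.Function, "post")] SemanticSearchRequest unused,
            [SemanticSearch("KustoConnectionString", "Documents",
                Query = "{Prompt}",
                ChatModel = "gpt-4o",
                EmbeddingsModel = "text-embedding-3-small",
                AIConnectionName = "MyAzureOpenAI")] SemanticSearchContext result)
        {
            return new ContentResult { Content = result.Response, ContentType = "text/plain" };
        }
    }
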
diff --git a/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs b/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs
index 2ad00ac8..ab3f2944 100644
--- a/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs
+++ b/src/WebJobs.Extensions.OpenAI/TextCompletionAttribute.cs
@@ -1,9 +1,8 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
-using Azure.AI.OpenAI;
using Microsoft.Azure.WebJobs.Description;
-using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models;
+using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants;
namespace Microsoft.Azure.WebJobs.Extensions.OpenAI;
@@ -12,7 +11,7 @@ namespace Microsoft.Azure.WebJobs.Extensions.OpenAI;
///
[Binding]
[AttributeUsage(AttributeTargets.Parameter)]
-public sealed class TextCompletionAttribute : Attribute
+public sealed class TextCompletionAttribute : AssistantBaseAttribute
{
///
/// Initializes a new instance of the class with the specified text prompt.
@@ -20,7 +19,9 @@ public sealed class TextCompletionAttribute : Attribute
/// The prompt to generate completions for, encoded as a string.
public TextCompletionAttribute(string prompt)
{
- this.Prompt = prompt ?? throw new ArgumentNullException(nameof(prompt));
+        this.Prompt = string.IsNullOrEmpty(prompt)
+            ? throw new ArgumentException("The prompt cannot be null or empty.", nameof(prompt))
+            : prompt;
}
///
@@ -28,70 +29,4 @@ public TextCompletionAttribute(string prompt)
///
[AutoResolve]
public string Prompt { get; }
-
- ///
- /// Gets or sets the ID of the model to use.
- ///
- [AutoResolve]
- public string Model { get; set; } = OpenAIModels.DefaultChatModel;
-
- ///
- /// Gets or sets the sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output
- /// more random, while lower values like 0.2 will make it more focused and deterministic.
- ///
- ///
- /// It's generally recommend to use this or but not both.
- ///
- [AutoResolve]
- public string? Temperature { get; set; } = "0.5";
-
- ///
- /// Gets or sets an alternative to sampling with temperature, called nucleus sampling, where the model considers
- /// the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%
- /// probability mass are considered.
- ///
- ///
- /// It's generally recommend to use this or but not both.
- ///
- [AutoResolve]
- public string? TopP { get; set; }
-
- ///
- /// Gets or sets the maximum number of tokens to generate in the completion.
- ///
- ///
- /// The token count of your prompt plus max_tokens cannot exceed the model's context length.
- /// Most models have a context length of 2048 tokens (except for the newest models, which support 4096).
- ///
- [AutoResolve]
- public string? MaxTokens { get; set; } = "100";
-
- internal ChatCompletionsOptions BuildRequest()
- {
- ChatCompletionsOptions request = new()
- {
- DeploymentName = this.Model,
- Messages =
- {
- new ChatRequestUserMessage(this.Prompt),
- }
- };
-
- if (int.TryParse(this.MaxTokens, out int maxTokens))
- {
- request.MaxTokens = maxTokens;
- }
-
- if (float.TryParse(this.Temperature, out float temperature))
- {
- request.Temperature = temperature;
- }
-
- if (float.TryParse(this.TopP, out float topP))
- {
- request.NucleusSamplingFactor = topP;
- }
-
- return request;
- }
}
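
For callers, the practical effect of this diff is the Model-to-ChatModel rename plus the options moving to the shared base attribute; a hedged before/after sketch for the in-process binding (the function shape and deployment name are illustrative):

    using Microsoft.AspNetCore.Http;
    using Microsoft.Azure.WebJobs;
    using Microsoft.Azure.WebJobs.Extensions.Http;
    using Microsoft.Azure.WebJobs.Extensions.OpenAI;

    public static class WhoIsSample
    {
        // Before (v0.18.x): [TextCompletion("Who is {name}?", Model = "gpt-35-turbo")]
        // After (v0.19.0), with the options now inherited from AssistantBaseAttribute:
        [FunctionName("WhoIs")]
        public static string WhoIs(
            [HttpTrigger(AuthorizationLevel.Function, "get", Route = "whois/{name}")] HttpRequest req,
            [TextCompletion("Who is {name}?", ChatModel = "gpt-35-turbo", Temperature = "0.3")] TextCompletionResponse response)
        {
            return response.Content;
        }
    }
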
diff --git a/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs b/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs
index b936ea44..6043dec5 100644
--- a/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs
+++ b/src/WebJobs.Extensions.OpenAI/TextCompletionConverter.cs
@@ -1,11 +1,11 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
-using Azure;
-using Azure.AI.OpenAI;
+using System.ClientModel;
using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
+using OpenAI.Chat;
namespace Microsoft.Azure.WebJobs.Extensions.OpenAI;
@@ -13,12 +13,12 @@ class TextCompletionConverter :
IAsyncConverter,
IAsyncConverter
{
- readonly OpenAIClient openAIClient;
+ readonly OpenAIClientFactory openAIClientFactory;
readonly ILogger logger;
- public TextCompletionConverter(OpenAIClient openAIClient, ILoggerFactory loggerFactory)
+ public TextCompletionConverter(OpenAIClientFactory openAIClientFactory, ILoggerFactory loggerFactory)
{
- this.openAIClient = openAIClient ?? throw new ArgumentNullException(nameof(openAIClient));
+ this.openAIClientFactory = openAIClientFactory ?? throw new ArgumentNullException(nameof(openAIClientFactory));
this.logger = loggerFactory?.CreateLogger() ?? throw new ArgumentNullException(nameof(loggerFactory));
}
@@ -43,18 +43,22 @@ async Task ConvertCoreAsync(
TextCompletionAttribute attribute,
CancellationToken cancellationToken)
{
- ChatCompletionsOptions options = attribute.BuildRequest();
- this.logger.LogInformation("Sending OpenAI completion request: {request}", options);
+ ChatCompletionOptions options = attribute.BuildRequest();
+ this.logger.LogInformation("Sending OpenAI completion request with prompt: {request}", attribute.Prompt);
- Response response = await this.openAIClient.GetChatCompletionsAsync(
- options,
- cancellationToken);
+        IList<ChatMessage> chatMessages = new List<ChatMessage>()
+ {
+ new UserChatMessage(attribute.Prompt)
+ };
+
+        ClientResult<ChatCompletion> response = await this.openAIClientFactory.GetChatClient(
+ attribute.AIConnectionName,
+ attribute.ChatModel).CompleteChatAsync(chatMessages, options, cancellationToken: cancellationToken);
string text = string.Join(
Environment.NewLine + Environment.NewLine,
- response.Value.Choices.Select(choice => choice.Message.Content));
- TextCompletionResponse textCompletionResponse = new(text, response.Value.Usage.TotalTokens);
-
+ response.Value.Content[0].Text);
+ TextCompletionResponse textCompletionResponse = new(text, response.Value.Usage.TotalTokenCount);
return textCompletionResponse;
}
}
diff --git a/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj b/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj
index 323285c0..0f4fbb48 100644
--- a/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj
+++ b/src/WebJobs.Extensions.OpenAI/WebJobs.Extensions.OpenAI.csproj
@@ -5,13 +5,16 @@
-
+
-
-
+
-
+
+
+
+
+
\ No newline at end of file
diff --git a/tests/SampleValidation/AssistantTests.cs b/tests/SampleValidation/AssistantTests.cs
index e6eb44f6..fc53f512 100644
--- a/tests/SampleValidation/AssistantTests.cs
+++ b/tests/SampleValidation/AssistantTests.cs
@@ -71,7 +71,7 @@ public async Task AddTodoTest()
Assert.StartsWith("text/plain", questionResponse.Content.Headers.ContentType?.MediaType);
// Ensure that the model responded and mentioned the new todo item.
- await ValidateAssistantResponseAsync(expectedMessageCount: 4, expectedContent: "Buy milk", hasTotalTokens: true);
+ await ValidateAssistantResponseAsync(expectedMessageCount: 5, expectedContent: "Buy milk", hasTotalTokens: true);
// Local function to validate each chat bot response
async Task ValidateAssistantResponseAsync(int expectedMessageCount, string expectedContent, bool hasTotalTokens = false)
@@ -107,7 +107,7 @@ async Task ValidateAssistantResponseAsync(int expectedMessageCount, string expec
// The timestamp filter should ensure we only ever look at the most recent messages
Assert.True(messageArray!.Count <= totalMessages);
- Assert.True(messageArray!.Count <= 4);
+ Assert.True(messageArray!.Count <= 5);
if (totalMessages >= expectedMessageCount)
{
@@ -118,13 +118,23 @@ async Task ValidateAssistantResponseAsync(int expectedMessageCount, string expec
{
// Make sure the first message is the system message
JsonNode systemMessage = messageArray!.First()!;
- Assert.Equal("system", systemMessage["role"]?.GetValue());
+ Assert.Equal("system", systemMessage["role"]?.GetValue(), StringComparer.OrdinalIgnoreCase);
}
else
{
+                    // Validate that the second message is the assistant's tool-call message containing the task description
+                    if (messageArray.Count >= 2)
+                    {
+                        JsonNode toolCallMessage = messageArray![1]!;
+                        Assert.Equal("assistant", toolCallMessage["role"]?.GetValue<string>(), StringComparer.OrdinalIgnoreCase);
+                        string? toolCalls = toolCallMessage["toolCalls"]?.GetValue<string>();
+                        Assert.NotNull(toolCalls);
+                        Assert.Contains("AddTodo", toolCalls, StringComparison.OrdinalIgnoreCase);
+                    }
+
// Make sure that the last message is from the chat bot (assistant)
JsonNode lastMessage = messageArray![messageArray.Count - 1]!;
- Assert.Equal("assistant", lastMessage["role"]?.GetValue());
+ Assert.Equal("assistant", lastMessage["role"]?.GetValue(), StringComparer.OrdinalIgnoreCase);
Assert.Contains("Buy milk", lastMessage!["content"]?.GetValue());
}
diff --git a/tests/SampleValidation/Chat.cs b/tests/SampleValidation/Chat.cs
index e6719e72..386b1f85 100644
--- a/tests/SampleValidation/Chat.cs
+++ b/tests/SampleValidation/Chat.cs
@@ -125,13 +125,13 @@ async Task ValidateAssistantResponseAsync(int expectedMessageCount, string expec
{
// Make sure the first message is the system message
JsonNode systemMessage = messageArray!.First()!;
- Assert.Equal("system", systemMessage["role"]?.GetValue());
+ Assert.Equal("system", systemMessage["role"]?.GetValue(), StringComparer.OrdinalIgnoreCase);
}
else
{
// Make sure that the last message is from the chat bot (assistant)
JsonNode lastMessage = messageArray![messageArray.Count - 1]!;
- Assert.Equal("assistant", lastMessage["role"]?.GetValue());
+ Assert.Equal("assistant", lastMessage["role"]?.GetValue(), StringComparer.OrdinalIgnoreCase);
Assert.StartsWith("Yo!", lastMessage!["content"]?.GetValue());
}
diff --git a/tests/SampleValidation/EmbeddingsTests.cs b/tests/SampleValidation/EmbeddingsTests.cs
new file mode 100644
index 00000000..d9f215bd
--- /dev/null
+++ b/tests/SampleValidation/EmbeddingsTests.cs
@@ -0,0 +1,147 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.Diagnostics;
+using System.Net;
+using System.Net.Http.Json;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace SampleValidation;
+
+public class EmbeddingsTests
+{
+ readonly ITestOutputHelper output;
+ readonly string baseAddress;
+ readonly HttpClient client;
+ readonly CancellationTokenSource cts;
+
+ public EmbeddingsTests(ITestOutputHelper output)
+ {
+ this.output = output;
+ this.client = new HttpClient(new LoggingHandler(this.output));
+ this.cts = new CancellationTokenSource(delay: TimeSpan.FromMinutes(Debugger.IsAttached ? 5 : 1));
+
+ this.baseAddress = Environment.GetEnvironmentVariable("FUNC_BASE_ADDRESS") ?? "http://localhost:7071";
+
+#if RELEASE
+ // Use the default key for the Azure Functions app in RELEASE mode; for local development, DEBUG mode can be used.
+ string functionKey = Environment.GetEnvironmentVariable("FUNC_DEFAULT_KEY") ?? throw new InvalidOperationException("Missing environment variable 'FUNC_DEFAULT_KEY'");
+ client.DefaultRequestHeaders.Add("x-functions-key", functionKey);
+#endif
+ }
+
+ [Fact]
+ public async Task GenerateEmbeddings_Http_Request_Test()
+ {
+ // Arrange
+ var request = new
+ {
+ rawText = "This is a test for generating embeddings from raw text."
+ };
+
+ // Act
+ using HttpResponseMessage response = await this.client.PostAsJsonAsync(
+ requestUri: $"{this.baseAddress}/api/embeddings",
+ request,
+ cancellationToken: this.cts.Token);
+
+ // Assert
+ Assert.Equal(HttpStatusCode.NoContent, response.StatusCode);
+ }
+
+ [Fact]
+ public async Task EmbeddingsLegacy_Performance_Test()
+ {
+ // Create a large text for testing performance
+ string largeText = new('A', 10000); // 10KB text
+
+ var request = new
+ {
+ rawText = largeText
+ };
+
+ // Measure time
+ var stopwatch = Stopwatch.StartNew();
+
+ using HttpResponseMessage response = await this.client.PostAsJsonAsync(
+ requestUri: $"{this.baseAddress}/api/embeddings",
+ request,
+ cancellationToken: this.cts.Token);
+
+ stopwatch.Stop();
+
+ // Assert
+ Assert.Equal(HttpStatusCode.NoContent, response.StatusCode);
+
+ // Log performance metrics
+ this.output.WriteLine($"Embeddings generation for 10KB text took {stopwatch.ElapsedMilliseconds}ms");
+
+ // If specific performance SLAs exist, add assertions like:
+ // Assert.True(stopwatch.ElapsedMilliseconds < 5000, "Embeddings generation took too long");
+ }
+
+ [Fact]
+ public async Task GetEmbeddings_Url_ReturnsNoContent()
+ {
+ // Arrange
+ var request = new { url = "https://github.com/Azure/azure-functions-openai-extension/blob/main/README.md" };
+
+ // Act
+ using HttpResponseMessage response = await this.client.PostAsJsonAsync(
+ requestUri: $"{this.baseAddress}/api/embeddings-from-url",
+ request,
+ cancellationToken: this.cts.Token);
+
+ // Assert
+ Assert.Equal(HttpStatusCode.NoContent, response.StatusCode);
+ }
+
+ [Fact]
+ public async Task GenerateEmbeddings_InvalidText_ReturnsInternalServerError()
+ {
+ // Arrange
+ var request = new { rawText = "" }; // Invalid: empty text
+
+ // Act
+ using HttpResponseMessage response = await this.client.PostAsJsonAsync(
+ requestUri: $"{this.baseAddress}/api/embeddings",
+ request,
+ cancellationToken: this.cts.Token);
+
+ // Assert
+ Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode);
+ }
+
+ [Fact]
+ public async Task GetEmbeddings_InvalidFilePath_ReturnsInternalServerError()
+ {
+ // Arrange
+ var request = new { filePath = "invalid/file/path.txt" }; // Invalid: Non-existent file path
+
+ // Act
+ using HttpResponseMessage response = await this.client.PostAsJsonAsync(
+ requestUri: $"{this.baseAddress}/api/embeddings-from-file",
+ request,
+ cancellationToken: this.cts.Token);
+
+ // Assert
+ Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode);
+ }
+
+ [Fact]
+ public async Task GetEmbeddings_InvalidUrl_ReturnsInternalServerError()
+ {
+ // Arrange
+ var request = new { url = "invalid-url" }; // Invalid: Malformed URL
+
+ // Act
+ using HttpResponseMessage response = await this.client.PostAsJsonAsync(
+ requestUri: $"{this.baseAddress}/api/embeddings-from-url",
+ request,
+ cancellationToken: this.cts.Token);
+
+ // Assert
+ Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode);
+ }
+}
\ No newline at end of file
diff --git a/tests/SampleValidation/FilePromptsTests.cs b/tests/SampleValidation/FilePromptsTests.cs
new file mode 100644
index 00000000..8eb01abd
--- /dev/null
+++ b/tests/SampleValidation/FilePromptsTests.cs
@@ -0,0 +1,68 @@
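+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+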
+using System.Diagnostics;
+using System.Net;
+using System.Net.Http.Json;
+using System.Text.Json.Nodes;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace SampleValidation;
+
+public class FilePromptTests
+{
+ readonly ITestOutputHelper output;
+ readonly string baseAddress;
+ readonly HttpClient client;
+ readonly CancellationTokenSource cts;
+
+ public FilePromptTests(ITestOutputHelper output)
+ {
+ this.output = output;
+ this.client = new HttpClient(new LoggingHandler(this.output));
+ this.cts = new CancellationTokenSource(delay: TimeSpan.FromMinutes(Debugger.IsAttached ? 5 : 2));
+
+ this.baseAddress = Environment.GetEnvironmentVariable("FUNC_BASE_ADDRESS") ?? "http://localhost:7071";
+
+#if RELEASE
+ // Use the default key for the Azure Functions app in RELEASE mode; for local development, DEBUG mode can be used.
+ string functionKey = Environment.GetEnvironmentVariable("FUNC_DEFAULT_KEY") ?? throw new InvalidOperationException("Missing environment variable 'FUNC_DEFAULT_KEY'");
+ client.DefaultRequestHeaders.Add("x-functions-key", functionKey);
+#endif
+ }
+
+ [Fact]
+ public async Task Ingest_Prompt_File_Test()
+ {
+ // Step 1: Test IngestFile
+ var ingestRequest = new { url = "https://github.com/Azure/azure-functions-openai-extension/blob/main/README.md" };
+
+ using HttpResponseMessage ingestResponse = await this.client.PostAsJsonAsync(
+ requestUri: $"{this.baseAddress}/api/IngestFile",
+ ingestRequest,
+ cancellationToken: this.cts.Token);
+
+ Assert.Equal(HttpStatusCode.OK, ingestResponse.StatusCode);
+ Assert.StartsWith("application/json", ingestResponse.Content.Headers.ContentType?.MediaType);
+
+ string ingestResponseContent = await ingestResponse.Content.ReadAsStringAsync(this.cts.Token);
+ JsonNode? ingestJsonResponse = JsonNode.Parse(ingestResponseContent);
+ Assert.NotNull(ingestJsonResponse);
+ Assert.Equal("success", ingestJsonResponse!["status"]?.GetValue());
+ Assert.Equal("README.md", ingestJsonResponse!["title"]?.GetValue());
+
+ // Step 2: Test PromptFile
+ var promptRequest = new { prompt = "How can the textCompletion input binding be used from Azure Functions OpenAI extension?" };
+
+ using HttpResponseMessage promptResponse = await this.client.PostAsJsonAsync(
+ requestUri: $"{this.baseAddress}/api/PromptFile",
+ promptRequest,
+ cancellationToken: this.cts.Token);
+
+ Assert.Equal(HttpStatusCode.OK, promptResponse.StatusCode);
+ Assert.StartsWith("text/plain", promptResponse.Content.Headers.ContentType?.MediaType);
+
+ string promptResponseContent = await promptResponse.Content.ReadAsStringAsync(this.cts.Token);
+ Assert.False(string.IsNullOrWhiteSpace(promptResponseContent));
+ Assert.Contains("OpenAI Chat Completions API", promptResponseContent, StringComparison.OrdinalIgnoreCase);
+ Assert.Contains("README", promptResponseContent, StringComparison.OrdinalIgnoreCase);
+ }
+}
\ No newline at end of file
diff --git a/tests/SampleValidation/TextCompletionTests.cs b/tests/SampleValidation/TextCompletionTests.cs
new file mode 100644
index 00000000..2cb18ac4
--- /dev/null
+++ b/tests/SampleValidation/TextCompletionTests.cs
@@ -0,0 +1,131 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.Diagnostics;
+using System.Net;
+using System.Net.Http.Json;
+using Xunit;
+using Xunit.Abstractions;
+
+namespace SampleValidation;
+
+public class TextCompletionLegacyTests
+{
+ readonly ITestOutputHelper output;
+ readonly HttpClient client;
+ readonly string baseAddress;
+ readonly CancellationTokenSource cts;
+
+ public TextCompletionLegacyTests(ITestOutputHelper output)
+ {
+ this.output = output;
+ this.client = new HttpClient(new LoggingHandler(this.output));
+ this.baseAddress = Environment.GetEnvironmentVariable("FUNC_BASE_ADDRESS") ?? "http://localhost:7071";
+ this.cts = new CancellationTokenSource(delay: TimeSpan.FromMinutes(Debugger.IsAttached ? 5 : 1));
+
+#if RELEASE
+ // Use the default key for the Azure Functions app in RELEASE mode; for local development, DEBUG mode can be used.
+ string functionKey = Environment.GetEnvironmentVariable("FUNC_DEFAULT_KEY") ?? throw new InvalidOperationException("Missing environment variable 'FUNC_DEFAULT_KEY'");
+ this.client.DefaultRequestHeaders.Add("x-functions-key", functionKey);
+#endif
+ }
+
+ [Fact]
+ public async Task WhoIs_ValidName_ReturnsContent()
+ {
+ // Arrange
+ string name = "Albert Einstein";
+
+ // Act
+ using HttpResponseMessage response = await this.client.GetAsync(
+ requestUri: $"{this.baseAddress}/api/whois/{name}",
+ cancellationToken: this.cts.Token);
+
+ // Assert
+ Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+ string content = await response.Content.ReadAsStringAsync(this.cts.Token);
+ this.output.WriteLine($"Response content: {content}");
+
+ // Validate that content is not empty and contains some relevant information about Albert Einstein
+ Assert.NotEmpty(content);
+ Assert.Contains("physicist", content, StringComparison.OrdinalIgnoreCase);
+ Assert.Contains("relativity", content, StringComparison.OrdinalIgnoreCase);
+ }
+
+ [Fact]
+ public async Task GenericCompletion_ValidPrompt_ReturnsContent()
+ {
+ // Arrange
+ var promptPayload = new { Prompt = "Write a haiku about programming" };
+
+ // Act
+ using HttpResponseMessage response = await this.client.PostAsJsonAsync(
+ requestUri: $"{this.baseAddress}/api/GenericCompletion",
+ promptPayload,
+ cancellationToken: this.cts.Token);
+
+ // Assert
+ Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+ string content = await response.Content.ReadAsStringAsync(this.cts.Token);
+ this.output.WriteLine($"Response content: {content}");
+
+ // Validate that content is not empty and resembles a haiku structure
+ Assert.NotEmpty(content);
+
+ // Verify we received some form of haiku (might contain line breaks)
+ // We can't verify exact content since AI responses vary, but we can check for reasonable length
+ // and common haiku-related formatting
+ Assert.True(content.Length >= 10, "Response should contain a meaningful haiku");
+ Assert.True(content.Split('\n').Length >= 1, "Haiku should have at least one line");
+ }
+
+ [Fact]
+ public async Task GenericCompletion_EmptyPrompt_ReturnsError()
+ {
+ // Arrange
+ var emptyPrompt = new { Prompt = string.Empty };
+
+ // Act
+ using HttpResponseMessage response = await this.client.PostAsJsonAsync(
+ requestUri: $"{this.baseAddress}/api/GenericCompletion",
+ emptyPrompt,
+ cancellationToken: this.cts.Token);
+
+ // Assert
+ Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode);
+ }
+
+ [Fact]
+ public async Task GenericCompletion_ComplexPrompt_ReturnsValidContent()
+ {
+ // Arrange
+ var complexPrompt = new
+ {
+ Prompt = "Explain the difference between synchronous and asynchronous programming in 3 bullet points"
+ };
+
+ // Act
+ using HttpResponseMessage response = await this.client.PostAsJsonAsync(
+ requestUri: $"{this.baseAddress}/api/GenericCompletion",
+ complexPrompt,
+ cancellationToken: this.cts.Token);
+
+ // Assert
+ Assert.Equal(HttpStatusCode.OK, response.StatusCode);
+ string content = await response.Content.ReadAsStringAsync(this.cts.Token);
+ this.output.WriteLine($"Response content: {content}");
+
+ // Validate that content is not empty and contains relevant terms
+ Assert.NotEmpty(content);
+ Assert.Contains("synchronous", content, StringComparison.OrdinalIgnoreCase);
+ Assert.Contains("asynchronous", content, StringComparison.OrdinalIgnoreCase);
+
+ // Check for bullet point formatting (could be -, *, or numbered)
+ bool hasBulletFormat = content.Contains('-') ||
+ content.Contains('*') ||
+ content.Contains("1.") ||
+ content.Contains('•');
+
+ Assert.True(hasBulletFormat, "Response should contain bullet points");
+ }
+}
diff --git a/tests/UnitTests/AssistantServiceTests.cs b/tests/UnitTests/AssistantServiceTests.cs
new file mode 100644
index 00000000..eb40bb59
--- /dev/null
+++ b/tests/UnitTests/AssistantServiceTests.cs
@@ -0,0 +1,472 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.ClientModel;
+using System.Threading;
+using Azure;
+using Azure.Core;
+using Azure.Data.Tables;
+using Azure.Data.Tables.Models;
+using Microsoft.Azure.WebJobs.Extensions.OpenAI.Assistants;
+using Microsoft.Azure.WebJobs.Extensions.OpenAI.Models;
+using Microsoft.Extensions.Azure;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.Logging;
+using Moq;
+using OpenAI.Chat;
+using Xunit;
+
+namespace Microsoft.Azure.WebJobs.Extensions.OpenAI.Tests.Assistants;
+
+public class DefaultAssistantServiceTests
+{
+ readonly Mock<OpenAIClientFactory> mockOpenAIClientFactory;
+ readonly Mock<AzureComponentFactory> mockAzureComponentFactory;
+ readonly Mock<IConfiguration> mockConfiguration;
+ readonly Mock<ISkillInvoker> mockSkillInvoker;
+ readonly Mock<ILoggerFactory> mockLoggerFactory;
+ readonly Mock<ILogger<DefaultAssistantService>> mockLogger;
+ readonly Mock<TableServiceClient> mockTableServiceClient;
+ readonly Mock<TableClient> mockTableClient;
+
+ public DefaultAssistantServiceTests()
+ {
+ this.mockAzureComponentFactory = new Mock<AzureComponentFactory>();
+ this.mockConfiguration = new Mock<IConfiguration>();
+ this.mockSkillInvoker = new Mock<ISkillInvoker>();
+ this.mockLoggerFactory = new Mock<ILoggerFactory>();
+ this.mockLogger = new Mock<ILogger<DefaultAssistantService>>();
+ this.mockTableServiceClient = new Mock<TableServiceClient>();
+ this.mockTableClient = new Mock<TableClient>();
+ this.mockOpenAIClientFactory = new Mock<OpenAIClientFactory>(
+ this.mockConfiguration.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockLoggerFactory.Object);
+
+ this.mockLoggerFactory.Setup(x => x.CreateLogger(It.IsAny<string>()))
+ .Returns(this.mockLogger.Object);
+
+ // Setup table client
+ this.mockTableServiceClient.Setup(x => x.GetTableClient(It.IsAny<string>()))
+ .Returns(this.mockTableClient.Object);
+ }
+
+ [Fact]
+ public async Task CreateAssistantAsync_WithValidRequest_CreatesAssistantAndMessages()
+ {
+ // Arrange
+ var request = new AssistantCreateRequest("testId", "Test instructions")
+ {
+ CollectionName = "ChatState",
+ ChatStorageConnectionSetting = "AzureWebJobsStorage"
+ };
+
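+ // An empty query result simulates a brand-new assistant with no existing chat state rows.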
+ var mockQueryResult = new List<TableEntity>();
+ AsyncPageable<TableEntity> mockQueryable = MockAsyncPageable<TableEntity>.Create(mockQueryResult);
+
+ this.mockTableClient.Setup(x => x.CreateIfNotExistsAsync(It.IsAny<CancellationToken>()))
+ .ReturnsAsync(Response.FromValue(new TableItem(request.CollectionName), new Mock<Response>().Object));
+
+ this.mockTableClient.Setup(x => x.QueryAsync<TableEntity>(
+ It.Is<string>(s => s == $"PartitionKey eq '{request.Id}'"),
+ null,
+ null,
+ It.IsAny<CancellationToken>()))
+ .Returns(mockQueryable);
+
+ this.mockTableClient.Setup(x => x.SubmitTransactionAsync(
+ It.IsAny<IEnumerable<TableTransactionAction>>(),
+ It.IsAny<CancellationToken>()))
+ .ReturnsAsync(Response.FromValue(new List<Response>() as IReadOnlyList<Response>, new Mock<Response>().Object));
+
+ // Create the service under test
+ var assistantService = new DefaultAssistantService(
+ this.mockOpenAIClientFactory.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockConfiguration.Object,
+ this.mockSkillInvoker.Object,
+ this.mockLoggerFactory.Object);
+
+ // Use reflection to set the tableClient field
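+ // (injecting the mock keeps the service from building a real TableClient out of the storage connection setting)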
+ System.Reflection.FieldInfo? tableClientField = typeof(DefaultAssistantService).GetField("tableClient",
+ System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
+ tableClientField?.SetValue(assistantService, this.mockTableClient.Object);
+
+ // Act
+ await assistantService.CreateAssistantAsync(request, CancellationToken.None);
+
+ // Assert
+ this.mockTableClient.Verify(x => x.CreateIfNotExistsAsync(It.IsAny()), Times.Once);
+
+ this.mockTableClient.Verify(x => x.QueryAsync<TableEntity>(
+ It.Is<string>(s => s.Contains(request.Id)),
+ null,
+ null,
+ It.IsAny<CancellationToken>()), Times.Once);
+
+ this.mockTableClient.Verify(x => x.SubmitTransactionAsync(
+ It.IsAny<IEnumerable<TableTransactionAction>>(),
+ It.IsAny<CancellationToken>()), Times.Once);
+ }
+
+ [Fact]
+ public async Task CreateAssistantAsync_WithExistingAssistant_DeletesOldEntitiesFirst()
+ {
+ // Arrange
+ var request = new AssistantCreateRequest("testId", "Test instructions")
+ {
+ CollectionName = "ChatState",
+ ChatStorageConnectionSetting = "AzureWebJobsStorage"
+ };
+
+ // Create existing entities
+ var existingEntities = new List<TableEntity>
+ {
+ new("testId", "state"),
+ new("testId", "msg-001")
+ };
+
+ AsyncPageable<TableEntity> mockQueryable = MockAsyncPageable<TableEntity>.Create(existingEntities);
+
+ this.mockTableClient.Setup(x => x.CreateIfNotExistsAsync(It.IsAny<CancellationToken>()))
+ .ReturnsAsync(Response.FromValue(new TableItem(request.CollectionName), new Mock<Response>().Object));
+
+ this.mockTableClient.Setup(x => x.QueryAsync<TableEntity>(
+ It.Is<string>(s => s == $"PartitionKey eq '{request.Id}'"),
+ null,
+ null,
+ It.IsAny<CancellationToken>()))
+ .Returns(mockQueryable);
+
+ this.mockTableClient.Setup(x => x.SubmitTransactionAsync(
+ It.IsAny<IEnumerable<TableTransactionAction>>(),
+ It.IsAny<CancellationToken>()))
+ .ReturnsAsync(Response.FromValue(new List<Response>() as IReadOnlyList<Response>, new Mock<Response>().Object));
+
+ // Create the service under test
+ var assistantService = new DefaultAssistantService(
+ this.mockOpenAIClientFactory.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockConfiguration.Object,
+ this.mockSkillInvoker.Object,
+ this.mockLoggerFactory.Object);
+
+ // Use reflection to set the tableClient field
+ System.Reflection.FieldInfo? tableClientField = typeof(DefaultAssistantService).GetField("tableClient",
+ System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
+ tableClientField?.SetValue(assistantService, this.mockTableClient.Object);
+
+ // Act
+ await assistantService.CreateAssistantAsync(request, CancellationToken.None);
+
+ // Assert
+ this.mockTableClient.Verify(x => x.SubmitTransactionAsync(
+ It.IsAny<IEnumerable<TableTransactionAction>>(),
+ It.IsAny<CancellationToken>()), Times.AtLeastOnce);
+
+ this.mockTableClient.Verify(x => x.SubmitTransactionAsync(
+ It.Is<IEnumerable<TableTransactionAction>>(
+ actions => actions.Any(a => a.ActionType == TableTransactionActionType.Add)),
+ It.IsAny<CancellationToken>()), Times.AtLeastOnce);
+ }
+
+ [Fact]
+ public async Task GetStateAsync_WithValidId_ReturnsCorrectState()
+ {
+ // Arrange
+ string id = "testId";
+ string timestamp = DateTime.UtcNow.AddHours(-1).ToString("o");
+ var attribute = new AssistantQueryAttribute(id)
+ {
+ TimestampUtc = timestamp,
+ CollectionName = "testCollection"
+ };
+
+ // Create mock entities
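+ // (a fixed "state" row plus one row per chat message, mirroring how the extension persists chat state)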
+ var stateEntity = new TableEntity(id, AssistantStateEntity.FixedRowKeyValue)
+ {
+ ["Exists"] = true,
+ ["CreatedAt"] = DateTime.UtcNow.AddDays(-1),
+ ["LastUpdatedAt"] = DateTime.UtcNow,
+ ["TotalMessages"] = 2,
+ ["TotalTokens"] = 100
+ };
+
+ var message1 = new TableEntity(id, "msg-001")
+ {
+ ["Content"] = "Test message 1",
+ ["Role"] = ChatMessageRole.System.ToString(),
+ ["CreatedAt"] = DateTime.UtcNow.AddMinutes(-30),
+ ["ToolCalls"] = ""
+ };
+
+ var message2 = new TableEntity(id, "msg-002")
+ {
+ ["Content"] = "Test message 2",
+ ["Role"] = ChatMessageRole.Assistant.ToString(),
+ ["CreatedAt"] = DateTime.UtcNow.AddMinutes(-15),
+ ["ToolCalls"] = ""
+ };
+
+ var mockQueryResult = new List<TableEntity> { stateEntity, message1, message2 };
+ AsyncPageable<TableEntity> mockQueryable = MockAsyncPageable<TableEntity>.Create(mockQueryResult);
+
+ this.mockTableClient.Setup(x => x.QueryAsync<TableEntity>(
+ It.Is<string>(s => s == $"PartitionKey eq '{id}'"),
+ null,
+ null,
+ It.IsAny<CancellationToken>()))
+ .Returns(mockQueryable);
+
+ // Create the service under test
+ var assistantService = new DefaultAssistantService(
+ this.mockOpenAIClientFactory.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockConfiguration.Object,
+ this.mockSkillInvoker.Object,
+ this.mockLoggerFactory.Object);
+
+ // Use reflection to set the tableClient field
+ System.Reflection.FieldInfo? tableClientField = typeof(DefaultAssistantService).GetField("tableClient",
+ System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
+ tableClientField?.SetValue(assistantService, this.mockTableClient.Object);
+
+ // Act
+ AssistantState result = await assistantService.GetStateAsync(attribute, CancellationToken.None);
+
+ // Assert
+ Assert.Equal(id, result.Id);
+ Assert.True(result.Exists);
+ Assert.Equal(2, result.TotalMessages);
+ Assert.Equal(100, result.TotalTokens);
+
+ // Sanity-check the timestamp cutoff: the mock messages (created 15-30 minutes ago) all fall after the one-hour-old query timestamp.
+ DateTime parsedTimestamp = DateTime.Parse(Uri.UnescapeDataString(timestamp)).ToUniversalTime();
+ Assert.All(result.RecentMessages, msg =>
+ Assert.True(DateTime.UtcNow.AddMinutes(-30) > parsedTimestamp));
+ }
+
+ [Fact]
+ public async Task GetStateAsync_WithNonExistentId_ReturnsEmptyState()
+ {
+ // Arrange
+ string id = "nonExistentId";
+ string timestamp = DateTime.UtcNow.AddHours(-1).ToString("o");
+ var attribute = new AssistantQueryAttribute(id)
+ {
+ TimestampUtc = timestamp,
+ CollectionName = "testCollection"
+ };
+
+ var mockQueryResult = new List<TableEntity>();
+ AsyncPageable<TableEntity> mockQueryable = MockAsyncPageable<TableEntity>.Create(mockQueryResult);
+
+ this.mockTableClient.Setup(x => x.QueryAsync<TableEntity>(
+ It.Is<string>(s => s == $"PartitionKey eq '{id}'"),
+ null,
+ null,
+ It.IsAny<CancellationToken>()))
+ .Returns(mockQueryable);
+
+ // Create the service under test
+ var assistantService = new DefaultAssistantService(
+ this.mockOpenAIClientFactory.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockConfiguration.Object,
+ this.mockSkillInvoker.Object,
+ this.mockLoggerFactory.Object);
+
+ // Use reflection to set the tableClient field
+ System.Reflection.FieldInfo? tableClientField = typeof(DefaultAssistantService).GetField("tableClient",
+ System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
+ tableClientField?.SetValue(assistantService, this.mockTableClient.Object);
+
+ // Act
+ AssistantState result = await assistantService.GetStateAsync(attribute, CancellationToken.None);
+
+ // Assert
+ Assert.Equal(id, result.Id);
+ Assert.False(result.Exists);
+ Assert.Equal(0, result.TotalMessages);
+ Assert.Empty(result.RecentMessages);
+ }
+
+ [Fact]
+ public void Constructor_WithNullArguments_ThrowsArgumentNullException()
+ {
+ // Act & Assert
+#nullable disable
+ Assert.Throws<ArgumentNullException>(() => new DefaultAssistantService(
+ null,
+ this.mockAzureComponentFactory.Object,
+ this.mockConfiguration.Object,
+ this.mockSkillInvoker.Object,
+ this.mockLoggerFactory.Object));
+
+ Assert.Throws<ArgumentNullException>(() => new DefaultAssistantService(
+ this.mockOpenAIClientFactory.Object,
+ null,
+ this.mockConfiguration.Object,
+ this.mockSkillInvoker.Object,
+ this.mockLoggerFactory.Object));
+
+ Assert.Throws<ArgumentNullException>(() => new DefaultAssistantService(
+ this.mockOpenAIClientFactory.Object,
+ this.mockAzureComponentFactory.Object,
+ null,
+ this.mockSkillInvoker.Object,
+ this.mockLoggerFactory.Object));
+
+ Assert.Throws<ArgumentNullException>(() => new DefaultAssistantService(
+ this.mockOpenAIClientFactory.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockConfiguration.Object,
+ null,
+ this.mockLoggerFactory.Object));
+
+ Assert.Throws<ArgumentNullException>(() => new DefaultAssistantService(
+ this.mockOpenAIClientFactory.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockConfiguration.Object,
+ this.mockSkillInvoker.Object,
+ null));
+#nullable restore
+ }
+
+ [Fact]
+ public async Task PostMessageAsync_WithNonExistentAssistant_ReturnsEmptyState()
+ {
+ // Arrange
+ string assistantId = "nonExistentId";
+ var attribute = new AssistantPostAttribute(assistantId, "Hello")
+ {
+ CollectionName = "ChatState",
+ ChatStorageConnectionSetting = "AzureWebJobsStorage"
+ };
+
+ var mockQueryResult = new List<TableEntity>();
+ AsyncPageable<TableEntity> mockQueryable = MockAsyncPageable<TableEntity>.Create(mockQueryResult);
+
+ this.mockTableClient.Setup(x => x.QueryAsync<TableEntity>(
+ It.Is<string>(s => s == $"PartitionKey eq '{assistantId}'"),
+ null,
+ null,
+ It.IsAny<CancellationToken>()))
+ .Returns(mockQueryable);
+
+ // Create the service under test
+ var assistantService = new DefaultAssistantService(
+ this.mockOpenAIClientFactory.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockConfiguration.Object,
+ this.mockSkillInvoker.Object,
+ this.mockLoggerFactory.Object);
+
+ // Use reflection to set the tableClient field
+ System.Reflection.FieldInfo? tableClientField = typeof(DefaultAssistantService).GetField("tableClient",
+ System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
+ tableClientField?.SetValue(assistantService, this.mockTableClient.Object);
+
+ // Act
+ AssistantState result = await assistantService.PostMessageAsync(attribute, CancellationToken.None);
+
+ // Assert
+ Assert.Equal(assistantId, result.Id);
+ Assert.False(result.Exists);
+ Assert.Equal(0, result.TotalMessages);
+ Assert.Empty(result.RecentMessages);
+ }
+
+ [Fact]
+ public async Task PostMessageAsync_WithInvalidAttributes_ThrowsArgumentException()
+ {
+ // Arrange - Missing user message
+ var attributeNoMessage = new AssistantPostAttribute("testId", "")
+ {
+ CollectionName = "ChatState",
+ ChatStorageConnectionSetting = "AzureWebJobsStorage"
+ };
+
+ // Create the service under test
+ var assistantService = new DefaultAssistantService(
+ this.mockOpenAIClientFactory.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockConfiguration.Object,
+ this.mockSkillInvoker.Object,
+ this.mockLoggerFactory.Object);
+
+ // Act & Assert - Empty message
+ await Assert.ThrowsAsync<ArgumentException>(() =>
+ assistantService.PostMessageAsync(attributeNoMessage, CancellationToken.None));
+
+ // Arrange - Missing ID
+ var attributeNoId = new AssistantPostAttribute("", "Hello")
+ {
+ CollectionName = "ChatState",
+ ChatStorageConnectionSetting = "AzureWebJobsStorage"
+ };
+
+ // Act & Assert - Empty ID
+ await Assert.ThrowsAsync<ArgumentException>(() =>
+ assistantService.PostMessageAsync(attributeNoId, CancellationToken.None));
+ }
+
+ static Mock<IConfigurationSection> CreateMockSection(bool exists, string? tableServiceUri = null)
+ {
+ var mockSection = new Mock<IConfigurationSection>();
+
+ if (!exists)
+ {
+ mockSection.Setup(s => s.Value).Returns("");
+ mockSection.Setup(s => s.GetChildren()).Returns([]);
+ }
+ else
+ {
+ mockSection.Setup(s => s.Value).Returns("some-value");
+ }
+
+ var endpointSection = new Mock<IConfigurationSection>();
+#nullable disable
+ endpointSection.Setup(s => s.Value).Returns(tableServiceUri);
+#nullable restore
+ mockSection.Setup(s => s.GetSection("tableServiceUri")).Returns(endpointSection.Object);
+
+ return mockSection;
+ }
+}
+
+// Minimal AsyncPageable<T> stub for mocking table query results; T is constrained to 'notnull' to match AsyncPageable<T>.
+public class MockAsyncPageable<T> : AsyncPageable<T> where T : notnull
+{
+ readonly IEnumerable<T> _items;
+
+ MockAsyncPageable(IEnumerable<T> items)
+ {
+ this._items = items;
+ }
+
+ public static AsyncPageable<T> Create(IEnumerable<T> items)
+ {
+ return new MockAsyncPageable<T>(items);
+ }
+
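+ // Returns all items as a single page; the Response argument is satisfied with a Moq placeholder.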
+ public override async IAsyncEnumerable<Page<T>> AsPages(
+ string? continuationToken = null,
+ int? pageSizeHint = null)
+ {
+ await Task.Yield();
+ yield return Page<T>.FromValues([.. this._items], null, new Mock<Response>().Object);
+ }
+}
diff --git a/tests/UnitTests/OpenAIClientFactoryTests.cs b/tests/UnitTests/OpenAIClientFactoryTests.cs
new file mode 100644
index 00000000..37c5ba72
--- /dev/null
+++ b/tests/UnitTests/OpenAIClientFactoryTests.cs
@@ -0,0 +1,383 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System.Reflection;
+using Azure.Core;
+using Microsoft.Azure.WebJobs.Extensions.OpenAI;
+using Microsoft.Extensions.Azure;
+using Microsoft.Extensions.Configuration;
+using Microsoft.Extensions.Logging;
+using Moq;
+using OpenAI.Chat;
+using OpenAI.Embeddings;
+using Xunit;
+
+namespace WebJobsOpenAIUnitTests;
+
+public class OpenAIClientFactoryTests
+{
+ readonly Mock<IConfiguration> mockConfiguration;
+ readonly Mock<AzureComponentFactory> mockAzureComponentFactory;
+ readonly Mock<ILoggerFactory> mockLoggerFactory;
+ readonly Mock<ILogger<OpenAIClientFactory>> mockLogger;
+
+ public OpenAIClientFactoryTests()
+ {
+ this.mockConfiguration = new Mock<IConfiguration>();
+ this.mockAzureComponentFactory = new Mock<AzureComponentFactory>();
+ this.mockLoggerFactory = new Mock<ILoggerFactory>();
+ this.mockLogger = new Mock<ILogger<OpenAIClientFactory>>();
+ this.mockLoggerFactory
+ .Setup(f => f.CreateLogger(It.Is<string>(s => s == typeof(OpenAIClientFactory).FullName)))
+ .Returns(this.mockLogger.Object);
+ }
+
+ [Fact]
+ public void Constructor_NullConfiguration_ThrowsArgumentNullException()
+ {
+ // Arrange
+ IConfiguration? configuration = null;
+
+#nullable disable
+ // Act & Assert
+ ArgumentNullException exception = Assert.Throws<ArgumentNullException>(() =>
+ new OpenAIClientFactory(configuration, this.mockAzureComponentFactory.Object, this.mockLoggerFactory.Object));
+#nullable restore
+ Assert.Equal("configuration", exception.ParamName);
+ }
+
+ [Fact]
+ public void Constructor_NullAzureComponentFactory_ThrowsArgumentNullException()
+ {
+ // Arrange
+ AzureComponentFactory? azureComponentFactory = null;
+
+#nullable disable
+ // Act & Assert
+ ArgumentNullException exception = Assert.Throws<ArgumentNullException>(() =>
+ new OpenAIClientFactory(this.mockConfiguration.Object, azureComponentFactory, this.mockLoggerFactory.Object));
+#nullable restore
+ Assert.Equal("azureComponentFactory", exception.ParamName);
+ }
+
+ [Fact]
+ public void Constructor_NullLoggerFactory_ThrowsArgumentNullException()
+ {
+ // Arrange
+ ILoggerFactory? loggerFactory = null;
+
+ // Act & Assert
+#nullable disable
+ ArgumentNullException exception = Assert.Throws<ArgumentNullException>(() =>
+ new OpenAIClientFactory(this.mockConfiguration.Object, this.mockAzureComponentFactory.Object, loggerFactory));
+#nullable restore
+ Assert.Equal("loggerFactory", exception.ParamName);
+ }
+
+ [Fact]
+ public void GetChatClient_WithAzureOpenAI_ReturnsClient()
+ {
+ // Arrange
+ Mock<IConfigurationSection> mockSection = CreateMockSection(
+ exists: true,
+ endpoint: "https://test-endpoint.openai.azure.com/",
+ key: "test-key");
+
+ this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object);
+
+ // Clear environment variable if set
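+ // (forces the factory down the Azure config-section path rather than the OPENAI_API_KEY path)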
+ string? originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY");
+ try
+ {
+ Environment.SetEnvironmentVariable("OPENAI_API_KEY", null);
+
+ OpenAIClientFactory factory = new(
+ this.mockConfiguration.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockLoggerFactory.Object);
+
+ // Act
+ ChatClient chatClient = factory.GetChatClient("TestConnection", "gpt-35-turbo");
+
+ // Assert
+ Assert.NotNull(chatClient);
+ }
+ finally
+ {
+ // Restore original environment value
+ Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue);
+ }
+ }
+
+ [Fact]
+ public void GetChatClient_WithOpenAIKey_ReturnsClient()
+ {
+ // Arrange
+ string? originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY");
+ try
+ {
+ Environment.SetEnvironmentVariable("OPENAI_API_KEY", "test-openai-key");
+
+ OpenAIClientFactory factory = new(
+ this.mockConfiguration.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockLoggerFactory.Object);
+
+ // Act
+ ChatClient chatClient = factory.GetChatClient("TestConnection", "gpt-4");
+
+ // Assert
+ Assert.NotNull(chatClient);
+ }
+ finally
+ {
+ // Restore original environment value
+ Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue);
+ }
+ }
+
+ [Fact]
+ public void GetEmbeddingClient_WithAzureOpenAI_ReturnsClient()
+ {
+ // Arrange
+ Mock<IConfigurationSection> mockSection = CreateMockSection(
+ exists: true,
+ endpoint: "https://test-endpoint.openai.azure.com/",
+ key: "test-key");
+
+ this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object);
+
+ // Clear environment variable if set
+ string? originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY");
+ try
+ {
+ Environment.SetEnvironmentVariable("OPENAI_API_KEY", null);
+
+ OpenAIClientFactory factory = new(
+ this.mockConfiguration.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockLoggerFactory.Object);
+
+ // Act
+ EmbeddingClient embeddingClient = factory.GetEmbeddingClient("TestConnection", "text-embedding-ada-002");
+
+ // Assert
+ Assert.NotNull(embeddingClient);
+ }
+ finally
+ {
+ // Restore original environment value
+ Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue);
+ }
+ }
+
+ [Fact]
+ public void GetEmbeddingClient_WithOpenAIKey_ReturnsClient()
+ {
+ // Arrange
+ string? originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY");
+ try
+ {
+ Environment.SetEnvironmentVariable("OPENAI_API_KEY", "test-openai-key");
+
+ OpenAIClientFactory factory = new(
+ this.mockConfiguration.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockLoggerFactory.Object);
+
+ // Act
+ EmbeddingClient embeddingClient = factory.GetEmbeddingClient("TestConnection", "text-embedding-ada-002");
+
+ // Assert
+ Assert.NotNull(embeddingClient);
+ }
+ finally
+ {
+ // Restore original environment value
+ Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue);
+ }
+ }
+
+ [Fact]
+ public void HasOpenAIKey_KeyExists_SetsHasKeyToTrue()
+ {
+ // Arrange
+ string? originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY");
+ try
+ {
+ Environment.SetEnvironmentVariable("OPENAI_API_KEY", "test-key");
+
+ // Use reflection to access private static method
+ MethodInfo? methodInfo = typeof(OpenAIClientFactory).GetMethod("HasOpenAIKey",
+ BindingFlags.NonPublic | BindingFlags.Static);
+
+ // Act
+#nullable disable
+ object[] parameters = [false, null];
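+ // HasOpenAIKey presumably takes ref/out parameters; Invoke writes the results back into this array.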
+#nullable restore
+ methodInfo?.Invoke(null, parameters);
+
+ // Assert
+ Assert.True((bool)parameters[0]);
+ Assert.Equal("test-key", parameters[1]);
+ }
+ finally
+ {
+ // Restore original environment value
+ Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue);
+ }
+ }
+
+ [Fact]
+ public void HasOpenAIKey_NoKeyExists_SetsHasKeyToFalse()
+ {
+ // Arrange
+ string? originalValue = Environment.GetEnvironmentVariable("OPENAI_API_KEY");
+ try
+ {
+ Environment.SetEnvironmentVariable("OPENAI_API_KEY", null);
+
+ // Use reflection to access private static method
+ MethodInfo? methodInfo = typeof(OpenAIClientFactory).GetMethod("HasOpenAIKey",
+ BindingFlags.NonPublic | BindingFlags.Static);
+
+ // Act
+ object[] parameters = [true, "some-value"];
+ methodInfo?.Invoke(null, parameters);
+
+ // Assert
+ Assert.False((bool)parameters[0]);
+ Assert.Null((string)parameters[1]);
+ }
+ finally
+ {
+ // Restore original environment value
+ Environment.SetEnvironmentVariable("OPENAI_API_KEY", originalValue);
+ }
+ }
+
+ [Fact]
+ public void CreateClientFromConfigSection_MissingEndpoint_ThrowsInvalidOperationException()
+ {
+ // Arrange
+ Mock<IConfigurationSection> mockSection = CreateMockSection(
+ exists: true,
+ endpoint: null,
+ key: "test-key");
+
+ this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object);
+
+ // Clear environment variable if set
+ string? originalEndpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT");
+ try
+ {
+ Environment.SetEnvironmentVariable("AZURE_OPENAI_ENDPOINT", null);
+
+ OpenAIClientFactory factory = new(
+ this.mockConfiguration.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockLoggerFactory.Object);
+
+ // Act & Assert
+ InvalidOperationException exception = Assert.Throws<InvalidOperationException>(() =>
+ factory.GetChatClient("TestConnection", "gpt-35-turbo"));
+ Assert.Contains("missing required 'Endpoint'", exception.Message);
+ }
+ finally
+ {
+ // Restore original environment value
+ Environment.SetEnvironmentVariable("AZURE_OPENAI_ENDPOINT", originalEndpoint);
+ }
+ }
+
+ [Fact]
+ public void CreateClientFromConfigSection_WithTokenCredential_ReturnsClient()
+ {
+ // Arrange
+ Mock<IConfigurationSection> mockSection = CreateMockSection(
+ exists: true,
+ endpoint: "https://test-endpoint.openai.azure.com/",
+ key: null);
+
+ this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object);
+
+ TokenCredential tokenCredential = new Mock<TokenCredential>().Object;
+ this.mockAzureComponentFactory.Setup(f => f.CreateTokenCredential(It.IsAny<IConfiguration>()))
+ .Returns(tokenCredential);
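+ // No Key in the config section, so the factory should fall back to the TokenCredential (managed identity path).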
+
+ OpenAIClientFactory factory = new(
+ this.mockConfiguration.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockLoggerFactory.Object);
+
+ // Act
+ ChatClient chatClient = factory.GetChatClient("TestConnection", "gpt-35-turbo");
+
+ // Assert
+ Assert.NotNull(chatClient);
+ this.mockAzureComponentFactory.Verify(f => f.CreateTokenCredential(It.IsAny<IConfiguration>()), Times.Once);
+ }
+
+ [Fact]
+ public void CreateClientFromConfigSection_WithEnvironmentVariables_UsesEnvironmentValues()
+ {
+ // Arrange
+ Mock<IConfigurationSection> mockSection = CreateMockSection(
+ exists: false,
+ endpoint: null,
+ key: null);
+ this.mockConfiguration.Setup(c => c.GetSection("TestConnection")).Returns(mockSection.Object);
+
+ string? originalEndpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT");
+ string? originalKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_KEY");
+ try
+ {
+ Environment.SetEnvironmentVariable("AZURE_OPENAI_ENDPOINT", "https://env-endpoint.openai.azure.com/");
+ Environment.SetEnvironmentVariable("AZURE_OPENAI_KEY", "env-key");
+
+ OpenAIClientFactory factory = new(
+ this.mockConfiguration.Object,
+ this.mockAzureComponentFactory.Object,
+ this.mockLoggerFactory.Object);
+
+ // Act
+ ChatClient chatClient = factory.GetChatClient("TestConnection", "gpt-35-turbo");
+
+ // Assert
+ Assert.NotNull(chatClient);
+ }
+ finally
+ {
+ // Restore original environment values
+ Environment.SetEnvironmentVariable("AZURE_OPENAI_ENDPOINT", originalEndpoint);
+ Environment.SetEnvironmentVariable("AZURE_OPENAI_KEY", originalKey);
+ }
+ }
+
+ static Mock<IConfigurationSection> CreateMockSection(bool exists, string? endpoint = null, string? key = null)
+ {
+ var mockSection = new Mock<IConfigurationSection>();
+
+ if (!exists)
+ {
+ mockSection.Setup(s => s.Value).Returns("");
+ mockSection.Setup(s => s.GetChildren()).Returns([]);
+ }
+ else
+ {
+ mockSection.Setup(s => s.Value).Returns("some-value");
+ }
+
+#nullable disable
+ var endpointSection = new Mock<IConfigurationSection>();
+ endpointSection.Setup(s => s.Value).Returns(endpoint);
+ mockSection.Setup(s => s.GetSection("Endpoint")).Returns(endpointSection.Object);
+ var keySection = new Mock<IConfigurationSection>();
+ keySection.Setup(s => s.Value).Returns(key);
+#nullable restore
+ mockSection.Setup(s => s.GetSection("Key")).Returns(keySection.Object);
+
+ return mockSection;
+ }
+}
\ No newline at end of file
diff --git a/tests/UnitTests/WebJobsOpenAIUnitTests.csproj b/tests/UnitTests/WebJobsOpenAIUnitTests.csproj
new file mode 100644
index 00000000..d3c1b04b
--- /dev/null
+++ b/tests/UnitTests/WebJobsOpenAIUnitTests.csproj
@@ -0,0 +1,29 @@
+<Project Sdk="Microsoft.NET.Sdk">
+
+  <PropertyGroup>
+    <TargetFramework>net8.0</TargetFramework>
+    <ImplicitUsings>enable</ImplicitUsings>
+    <Nullable>enable</Nullable>
+
+    <IsPackable>false</IsPackable>
+  </PropertyGroup>
+
+  <ItemGroup>
+    <!-- Package names and versions were stripped in extraction; only the asset metadata below survived.
+         The tests themselves reference xunit and Moq. -->
+    <PackageReference>
+      <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
+      <PrivateAssets>all</PrivateAssets>
+    </PackageReference>
+    <PackageReference>
+      <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
+      <PrivateAssets>all</PrivateAssets>
+    </PackageReference>
+  </ItemGroup>
+
+</Project>