From 3633e68cae93b44ee79c8d01147d193f9b599fa2 Mon Sep 17 00:00:00 2001 From: Bas Brekelmans Date: Thu, 30 Mar 2023 22:55:23 -0700 Subject: [PATCH] Allow an HttpClientFactory to be assigned to OpenAIAPI --- OpenAI_API/EndpointBase.cs | 16 +++--- OpenAI_API/OpenAIAPI.cs | 8 ++- OpenAI_API/OpenAI_API.csproj | 5 ++ OpenAI_Tests/HttpClientResolutionTests.cs | 66 +++++++++++++++++++++++ OpenAI_Tests/OpenAI_Tests.csproj | 1 + 5 files changed, 87 insertions(+), 9 deletions(-) create mode 100644 OpenAI_Tests/HttpClientResolutionTests.cs diff --git a/OpenAI_API/EndpointBase.cs b/OpenAI_API/EndpointBase.cs index 66a4ee0..792727a 100644 --- a/OpenAI_API/EndpointBase.cs +++ b/OpenAI_API/EndpointBase.cs @@ -59,16 +59,18 @@ protected HttpClient GetClient() { throw new AuthenticationException("You must provide API authentication. Please refer to https://github.com/OkGoDoIt/OpenAI-API-dotnet#authentication for details."); } - - /* - if (_Api.SharedHttpClient==null) + + HttpClient client; + var clientFactory = _Api.HttpClientFactory; + if (clientFactory != null) + { + client = clientFactory.CreateClient(); + } + else { - _Api.SharedHttpClient = new HttpClient(); - _Api.SharedHttpClient. + client = new HttpClient(); } - */ - HttpClient client = new HttpClient(); client.DefaultRequestHeaders.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", _Api.Auth.ApiKey); // Further authentication-header used for Azure openAI service client.DefaultRequestHeaders.Add("api-key", _Api.Auth.ApiKey); diff --git a/OpenAI_API/OpenAIAPI.cs b/OpenAI_API/OpenAIAPI.cs index f415410..ac503dc 100644 --- a/OpenAI_API/OpenAIAPI.cs +++ b/OpenAI_API/OpenAIAPI.cs @@ -5,7 +5,7 @@ using OpenAI_API.Images; using OpenAI_API.Models; using OpenAI_API.Moderation; -using System.Xml.Linq; +using System.Net.Http; namespace OpenAI_API { @@ -31,6 +31,11 @@ public class OpenAIAPI : IOpenAIAPI /// public APIAuthentication Auth { get; set; } + /// + /// Optionally provide an IHttpClientFactory to create the client to send requests. + /// + public IHttpClientFactory HttpClientFactory { get; set; } + /// /// Creates a new entry point to the OpenAPI API, handling auth and allowing access to the various API endpoints /// @@ -96,6 +101,5 @@ public static OpenAIAPI ForAzure(string YourResourceName, string deploymentId, A /// The API lets you do operations with images. You can Given a prompt and/or an input image, the model will generate a new image. /// public ImageGenerationEndpoint ImageGenerations { get; } - } } diff --git a/OpenAI_API/OpenAI_API.csproj b/OpenAI_API/OpenAI_API.csproj index 4846444..cf95141 100644 --- a/OpenAI_API/OpenAI_API.csproj +++ b/OpenAI_API/OpenAI_API.csproj @@ -42,10 +42,15 @@ + + + + + diff --git a/OpenAI_Tests/HttpClientResolutionTests.cs b/OpenAI_Tests/HttpClientResolutionTests.cs new file mode 100644 index 0000000..3cea930 --- /dev/null +++ b/OpenAI_Tests/HttpClientResolutionTests.cs @@ -0,0 +1,66 @@ +using Microsoft.Extensions.Options; +using Moq; +using NUnit.Framework; +using OpenAI_API; +using System; +using System.Linq; +using System.Net.Http; + +namespace OpenAI_Tests +{ + public class HttpClientResolutionTests + { + [Test] + public void GetHttpClient_NoFactory() + { + var api = new OpenAIAPI(new APIAuthentication("fake-key")); + var endpoint = new TestEndpoint(api); + + var client = endpoint.GetHttpClient(); + Assert.IsNotNull(client); + } + + [Test] + public void GetHttpClient_WithFactory() + { + var expectedClient1 = new HttpClient(); + var mockedFactory1 = Mock.Of(f => f.CreateClient(Options.DefaultName) == expectedClient1); + + var expectedClient2 = new HttpClient(); + var mockedFactory2 = Mock.Of(f => f.CreateClient(Options.DefaultName) == expectedClient2); + + var api = new OpenAIAPI(new APIAuthentication("fake-key")); + var endpoint = new TestEndpoint(api); + + api.HttpClientFactory = mockedFactory1; + var actualClient1 = endpoint.GetHttpClient(); + + api.HttpClientFactory = mockedFactory2; + var actualClient2 = endpoint.GetHttpClient(); + + Assert.AreSame(expectedClient1, actualClient1); + Assert.AreSame(expectedClient2, actualClient2); + + api.HttpClientFactory = null; + var actualClient3 = endpoint.GetHttpClient(); + + Assert.NotNull(actualClient3); + Assert.AreNotSame(expectedClient1, actualClient3); + Assert.AreNotSame(expectedClient2, actualClient3); + } + + private class TestEndpoint : EndpointBase + { + public TestEndpoint(OpenAIAPI api) : base(api) + { + } + + protected override string Endpoint => throw new System.NotSupportedException(); + + public HttpClient GetHttpClient() + { + return base.GetClient(); + } + } + } +} diff --git a/OpenAI_Tests/OpenAI_Tests.csproj b/OpenAI_Tests/OpenAI_Tests.csproj index 805cc44..f26766e 100644 --- a/OpenAI_Tests/OpenAI_Tests.csproj +++ b/OpenAI_Tests/OpenAI_Tests.csproj @@ -8,6 +8,7 @@ +