Skip to content

Commit

Permalink
better support for azure openai and chat models
Browse files Browse the repository at this point in the history
  • Loading branch information
asklar committed Aug 18, 2023
1 parent dd30471 commit 6f9b808
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 8 deletions.
2 changes: 1 addition & 1 deletion OpenAI.WinRT.nuspec
Expand Up @@ -2,7 +2,7 @@
<package >
<metadata>
<id>OpenAI.WinRT</id>
<version>0.0.15</version>
<version>0.0.16</version>
<title>OpenAI.WinRT</title>
<authors>Alexander Sklar</authors>
<requireLicenseAcceptance>false</requireLicenseAcceptance>
Expand Down
24 changes: 19 additions & 5 deletions OpenAIClient.cpp
Expand Up @@ -28,6 +28,20 @@ namespace winrt::OpenAI::implementation
}
}

winrt::OpenAI::OpenAIClient OpenAIClient::CreateAzureOpenAIClient(winrt::Windows::Foundation::Uri const& endpoint, winrt::hstring deployment, winrt::hstring const& apiKey)
{
auto client = winrt::OpenAI::OpenAIClient();
client.UseBearerTokenAuthorization(false);
client.ApiKey(apiKey);
constexpr std::wstring_view uriTemplate = L"{}/openai/deployments/{}/chat/completions?api-version=2023-03-15-preview";

auto uri = winrt::Windows::Foundation::Uri{ std::vformat(uriTemplate, std::make_wformat_args(endpoint.AbsoluteUri(), deployment))};
client.CompletionUri(uri);
client.IsChatModel(true);

return client;
}

void OpenAIClient::ApiKey(winrt::hstring v) noexcept
{
m_apiKey = v;
Expand Down Expand Up @@ -66,8 +80,8 @@ namespace winrt::OpenAI::implementation
std::wstring requestJson;
auto modelString = EscapeStringForJson(request.Model());
const std::wstring_view model{ modelString };
const auto isGpt35Turbo = request.Model() == L"gpt-3.5-turbo";
if (isGpt35Turbo) {
const auto isChatModel = request.Model() == L"gpt-3.5-turbo" || m_isChatModel;
if (isChatModel) {
constexpr std::wstring_view requestTemplate{ LR"({{ {}
"messages": [{{ "role": "user", "content": {} }}],
"temperature": {},
Expand Down Expand Up @@ -101,7 +115,7 @@ namespace winrt::OpenAI::implementation
));
}
auto content = winrt::HttpStringContent(requestJson, winrt::UnicodeEncoding::Utf8, L"application/json");
auto uri = isGpt35Turbo && UseBearerTokenAuthorization() ? Windows::Foundation::Uri{gpt35turboEndpoint} : CompletionUri();
auto uri = isChatModel && UseBearerTokenAuthorization() ? Windows::Foundation::Uri{gpt35turboEndpoint} : CompletionUri();
auto response = co_await m_client.PostAsync(uri, content);
auto responseJsonStr = co_await response.Content().ReadAsStringAsync();
statusCode = response.StatusCode();
Expand All @@ -114,7 +128,7 @@ namespace winrt::OpenAI::implementation
const auto& choice = c.GetObject();
auto retChoice = winrt::make<Choice>();
auto retChoiceImpl = winrt::get_self<Choice>(retChoice);
if (!isGpt35Turbo) {
if (!isChatModel) {
retChoiceImpl->m_text = choice.GetNamedString(L"text");
} else {
auto msg = choice.GetNamedObject(L"message");
Expand Down Expand Up @@ -152,7 +166,7 @@ namespace winrt::OpenAI::implementation
auto choices = json.GetNamedArray(L"choices");
auto choice = choices.GetObjectAt(0);
winrt::hstring text;
if (!isGpt35Turbo) {
if (!isChatModel) {
text = choice.GetNamedString(L"text");
} else {
auto msg = choice.GetNamedObject(L"message");
Expand Down
6 changes: 4 additions & 2 deletions OpenAIClient.h
Expand Up @@ -18,9 +18,11 @@ namespace winrt::OpenAI::implementation
winrt::hstring ApiKey() const noexcept { return m_apiKey; }
void ApiKey(winrt::hstring v) noexcept;


static winrt::OpenAI::OpenAIClient CreateAzureOpenAIClient(winrt::Windows::Foundation::Uri const& uri, winrt::hstring deployment, winrt::hstring const& apiKey);
static constexpr std::wstring_view gpt35turboEndpoint = L"https://api.openai.com/v1/chat/completions";

bool m_isChatModel = false;
bool IsChatModel() const noexcept { return m_isChatModel; }
void IsChatModel(bool v) noexcept { m_isChatModel = v; }
winrt::Windows::Foundation::Uri CompletionUri() const noexcept { return m_completionUri; }
void CompletionUri(winrt::Windows::Foundation::Uri v) noexcept { m_completionUri = v; }
Windows::Foundation::IAsyncOperation<winrt::Windows::Foundation::Collections::IVector<winrt::OpenAI::Choice>> GetCompletionAsync(winrt::hstring prompt, winrt::hstring model);
Expand Down
2 changes: 2 additions & 0 deletions OpenAIClient.idl
Expand Up @@ -35,6 +35,7 @@ namespace OpenAI
runtimeclass OpenAIClient
{
OpenAIClient();
static OpenAIClient CreateAzureOpenAIClient(Windows.Foundation.Uri endpoint, String deployment, String apiKey);
String ApiKey;
Windows.Foundation.Uri CompletionUri;

Expand All @@ -50,6 +51,7 @@ namespace OpenAI
FewShotTemplate CreateFewShotTemplate(Windows.Foundation.Collections.IVectorView<String> parameters);

Windows.Foundation.IAsyncOperation< Windows.Foundation.Collections.IVector<Choice> > GetChatResponseAsync(ChatRequest request);
Boolean IsChatModel;
}

enum Similarity
Expand Down

0 comments on commit 6f9b808

Please sign in to comment.