Support for ChatGPT API #43

Merged: 3 commits, Mar 1, 2023

4 changes: 2 additions & 2 deletions async-openai/README.md
@@ -23,7 +23,7 @@

- It's based on [OpenAI OpenAPI spec](https://github.com/openai/openai-openapi)
- Current features:
- [x] Completions (including SSE streaming)
- [x] Completions (including SSE streaming & Chat)
- [x] Edits
- [x] Embeddings
- [x] Files
@@ -35,7 +35,7 @@
- Non-streaming requests are retried with exponential backoff when [rate limited](https://platform.openai.com/docs/guides/rate-limits) by the API server.
- Ergonomic Rust library with builder pattern for all request objects.

*Being a young project there could be rough edges.*
_Being a young project there could be rough edges._

## Usage

52 changes: 52 additions & 0 deletions async-openai/src/chat.rs
@@ -0,0 +1,52 @@
use crate::{
client::Client,
error::OpenAIError,
types::{ChatResponseStream, CreateChatRequest, CreateChatResponse},
};

/// Given a series of messages, the model will return one or more predicted
/// completion messages.
pub struct Chat<'c> {
client: &'c Client,
}

impl<'c> Chat<'c> {
pub fn new(client: &'c Client) -> Self {
Self { client }
}

/// Creates a completion for the provided messages and parameters
pub async fn create(
&self,
request: CreateChatRequest,
) -> Result<CreateChatResponse, OpenAIError> {
if request.stream.is_some() && request.stream.unwrap() {
return Err(OpenAIError::InvalidArgument(
"When stream is true, use Chat::create_stream".into(),
));
}
self.client.post("/chat/completions", request).await
}

/// Creates a completion request for the provided messages and parameters
///
/// Stream back partial progress. Tokens will be sent as data-only
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
/// as they become available, with the stream terminated by a data: \[DONE\] message.
///
/// [ChatResponseStream] is a parsed SSE stream until a \[DONE\] is received from server.
pub async fn create_stream(
&self,
mut request: CreateChatRequest,
) -> Result<ChatResponseStream, OpenAIError> {
if request.stream.is_some() && !request.stream.unwrap() {
return Err(OpenAIError::InvalidArgument(
"When stream is false, use Chat::create".into(),
));
}

request.stream = Some(true);

Ok(self.client.post_stream("/chat/completions", request).await)
}
}
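
A minimal end-to-end sketch of the non-streaming path added here, in the style of the crate's existing examples. It assumes `Message`, `MessageRole`, `MessageArgs`, and `CreateChatRequestArgs` are re-exported from `async_openai::types` like the existing request types, that a tokio runtime is available, and it uses `gpt-3.5-turbo` as a placeholder model name (the PR itself does not pin a model):

```rust
use async_openai::{
    types::{CreateChatRequestArgs, MessageArgs, MessageRole},
    Client,
};

#[tokio::main]
async fn main() {
    // Reads the API key from the OPENAI_API_KEY environment variable.
    let client = Client::new();

    // Build the conversation from the Message/MessageRole types added in this PR.
    let messages = vec![
        MessageArgs::default()
            .role(MessageRole::System)
            .content("You are a helpful assistant.".to_string())
            .build()
            .unwrap(),
        MessageArgs::default()
            .role(MessageRole::User)
            .content("Say hello in one sentence.".to_string())
            .build()
            .unwrap(),
    ];

    // "gpt-3.5-turbo" is an assumed model name, not something this diff specifies.
    let request = CreateChatRequestArgs::default()
        .model("gpt-3.5-turbo")
        .messages(messages)
        .build()
        .unwrap();

    // Non-streaming call; per the guard above, a request with stream == Some(true)
    // would instead return OpenAIError::InvalidArgument from Chat::create.
    let response = client.chat().create(request).await.unwrap();
    for choice in response.choices {
        println!("{:?}: {}", choice.message.role, choice.message.content);
    }
}
```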
6 changes: 6 additions & 0 deletions async-openai/src/client.rs
@@ -6,6 +6,7 @@ use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Serialize};

use crate::{
chat::Chat,
edit::Edits,
error::{OpenAIError, WrappedError},
file::Files,
@@ -91,6 +92,11 @@ impl Client {
Completions::new(self)
}

/// To call [Chat] group related APIs using this client.
pub fn chat(&self) -> Chat {
Chat::new(self)
}

/// To call [Edits] group related APIs using this client.
pub fn edits(&self) -> Edits {
Edits::new(self)
2 changes: 1 addition & 1 deletion async-openai/src/file.rs
@@ -104,6 +104,6 @@ mod tests {
let delete_response = client.files().delete(&openai_file.id).await.unwrap();

assert_eq!(openai_file.id, delete_response.id);
assert_eq!(delete_response.deleted, true);
assert!(delete_response.deleted);
}
}
1 change: 1 addition & 0 deletions async-openai/src/lib.rs
@@ -48,6 +48,7 @@
//! ## Examples
//! For full working examples for all supported features see [examples](https://github.com/64bit/async-openai/tree/main/examples) directory in the repository.
//!
mod chat;
mod client;
mod completion;
mod download;
122 changes: 121 additions & 1 deletion async-openai/src/types/types.rs
@@ -1,4 +1,9 @@
use std::{collections::HashMap, path::PathBuf, pin::Pin};
use std::{
collections::HashMap,
fmt::{Display, Formatter},
path::PathBuf,
pin::Pin,
};

use derive_builder::Builder;
use futures::Stream;
@@ -134,6 +139,101 @@ pub struct CreateCompletionRequest {
pub user: Option<String>,
}

#[derive(Clone, Serialize, Debug, Deserialize)]
pub enum MessageRole {
#[serde(rename = "assistant")]
Assistant,
#[serde(rename = "system")]
System,
#[serde(rename = "user")]
User,
}

#[derive(Clone, Serialize, Deserialize, Debug, Builder)]
#[builder(name = "MessageArgs")]
#[builder(pattern = "mutable")]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct Message {
pub role: MessageRole,
pub content: String,
}

#[derive(Clone, Serialize, Default, Debug, Builder)]
#[builder(name = "CreateChatRequestArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option), default)]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateChatRequest {
/// ID of the model to use. You can use the [List models](https://platform.openai.com/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](https://platform.openai.com/docs/models/overview) for descriptions of them.
pub model: String,

/// The message(s) to generate a response to, encoded as an array of the message type.
///
/// Note that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.
#[serde(skip_serializing_if = "Option::is_none")]
pub messages: Option<Vec<Message>>,

/// What [sampling temperature](https://towardsdatascience.com/how-to-sample-from-language-models-682bceb97277) to use. Higher values means the model will take more risks. Try 0.9 for more creative applications, and 0 (argmax sampling) for ones with a well-defined answer.
///
/// We generally recommend altering this or `top_p` but not both.
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>, // min: 0, max: 2, default: 1,

/// An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.
///
/// We generally recommend altering this or `temperature` but not both.
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>, // min: 0, max: 1, default: 1

/// How many completions to generate for each prompt.

/// **Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.
///
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u8>, // min:1 max: 128, default: 1

/// Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format)
/// as they become available, with the stream terminated by a `data: [DONE]` message.
#[serde(skip_serializing_if = "Option::is_none")]
pub stream: Option<bool>, // nullable: true

/// Include the log probabilities on the `logprobs` most likely tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response.

/// The maximum value for `logprobs` is 5. If you need more than this, please contact us through our [Help center](https://help.openai.com) and describe your use case.
#[serde(skip_serializing_if = "Option::is_none")]
pub logprobs: Option<u8>, // min:0 , max: 5, default: null, nullable: true

/// Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Stop>,

/// Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
///
/// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>, // min: -2.0, max: 2.0, default 0

/// Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.
///
/// [See more information about frequency and presence penalties.](https://platform.openai.com/docs/api-reference/parameter-details)
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>, // min: -2.0, max: 2.0, default: 0

/// Modify the likelihood of specified tokens appearing in the completion.
///
/// Accepts a json object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) (which works for both GPT-2 and GPT-3) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.
///
/// As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.
#[serde(skip_serializing_if = "Option::is_none")]
pub logit_bias: Option<HashMap<String, serde_json::Value>>, // default: null

/// A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse. [Learn more](https://platform.openai.com/docs/usage-policies/end-user-ids).
#[serde(skip_serializing_if = "Option::is_none")]
pub user: Option<String>,
}

#[derive(Debug, Deserialize)]
pub struct Logprobs {
pub tokens: Vec<String>,
@@ -150,6 +250,13 @@ pub struct Choice {
pub finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
pub struct ChatChoice {
pub message: Message,
pub index: u32,
pub finish_reason: Option<String>,
}

#[derive(Debug, Deserialize)]
pub struct Usage {
pub prompt_tokens: u32,
@@ -167,10 +274,23 @@ pub struct CreateCompletionResponse {
pub usage: Option<Usage>,
}

#[derive(Debug, Deserialize)]
pub struct CreateChatResponse {
pub id: String,
pub object: String,
pub created: u32,
pub choices: Vec<ChatChoice>,
pub usage: Option<Usage>,
}

/// Parsed server-sent events stream until a \[DONE\] is received from the server.
pub type CompletionResponseStream =
Pin<Box<dyn Stream<Item = Result<CreateCompletionResponse, OpenAIError>> + Send>>;

/// Parsed server-sent events stream until a \[DONE\] is received from the server.
pub type ChatResponseStream =
Pin<Box<dyn Stream<Item = Result<CreateChatResponse, OpenAIError>> + Send>>;

#[derive(Debug, Clone, Serialize, Default, Builder)]
#[builder(name = "CreateEditRequestArgs")]
#[builder(pattern = "mutable")]
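
To complement the new `ChatResponseStream` alias, a hedged consumption sketch mirroring the existing completion-stream test below. It makes the same assumptions as the earlier sketch (types re-exported from `async_openai::types`, a tokio runtime, a placeholder model name); each streamed item is typed as a full `CreateChatResponse`, exactly as the alias in this diff declares:

```rust
use async_openai::{
    types::{CreateChatRequestArgs, MessageArgs, MessageRole},
    Client,
};
use futures::StreamExt;

#[tokio::main]
async fn main() {
    let client = Client::new();

    let request = CreateChatRequestArgs::default()
        .model("gpt-3.5-turbo") // placeholder model name, not pinned by this diff
        .messages(vec![MessageArgs::default()
            .role(MessageRole::User)
            .content("Write a haiku about Rust.".to_string())
            .build()
            .unwrap()])
        .build()
        .unwrap();

    // Chat::create_stream forces stream = Some(true); the resulting ChatResponseStream
    // yields parsed SSE chunks until the server sends [DONE].
    let mut stream = client.chat().create_stream(request).await.unwrap();

    while let Some(item) = stream.next().await {
        match item {
            Ok(response) => {
                for choice in response.choices {
                    print!("{}", choice.message.content);
                }
            }
            Err(e) => eprintln!("stream error: {e}"),
        }
    }
}
```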
19 changes: 9 additions & 10 deletions async-openai/tests/boxed_future.rs
@@ -1,13 +1,11 @@

use futures::StreamExt;
use futures::future::{BoxFuture, FutureExt};
use futures::StreamExt;

use async_openai::Client;
use async_openai::types::{CompletionResponseStream, CreateCompletionRequestArgs};
use async_openai::Client;

#[tokio::test]
async fn boxed_future_test() {

fn interpret_bool(token_stream: &mut CompletionResponseStream) -> BoxFuture<'_, bool> {
async move {
while let Some(response) = token_stream.next().await {
@@ -17,12 +15,13 @@ async fn boxed_future_test() {
if !token_str.is_empty() {
return token_str.contains("yes") || token_str.contains("Yes");
}
},
}
Err(e) => eprintln!("Error: {e}"),
}
}
false
}.boxed()
}
.boxed()
}

let client = Client::new();
@@ -34,11 +33,11 @@
.stream(true)
.logprobs(3)
.max_tokens(64_u16)
.build().unwrap();
.build()
.unwrap();

let mut stream = client.completions().create_stream(request).await.unwrap();

let result = interpret_bool(&mut stream).await;
assert_eq!(result, true);

}
assert!(result);
}