Skip to content

Commit

Permalink
feat: Fine tune events stream; Most complete request-response Rust ty…
Browse files Browse the repository at this point in the history
…pes (#28)

* Client with SSE stream functionality; use that in fine-tunes list events stream and in create completions stream

* update readme

* updated types: the most complete and correct-bug-free types yet

* Updated examples with updated types

* updated readme and src/lib.rs doc comment

* updated comment

* remove 'use async_openai as openai' from everywhere
  • Loading branch information
64bit committed Dec 29, 2022
1 parent dd2e624 commit 3353700
Show file tree
Hide file tree
Showing 18 changed files with 240 additions and 127 deletions.
8 changes: 5 additions & 3 deletions async-openai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
<img src="https://docs.rs/async-openai/badge.svg" />
</a>
</div>
<div align="center">
<sub>Logo created by this <a href="https://github.com/64bit/async-openai/tree/main/examples/create-image-b64-json">repo itself</a></sub>
</div>

## Overview

Expand All @@ -24,7 +27,7 @@
- [x] Edits
- [x] Embeddings
- [x] Files (List, Upload, Delete, Retrieve, Retrieve Content)
- [x] Fine-Tuning
- [x] Fine-Tuning (including SSE streaming Fine-tuning events)
- [x] Images (Generation, Edit, Variation)
- [ ] Microsoft Azure Endpoints / AD Authentication
- [x] Models
Expand Down Expand Up @@ -55,8 +58,7 @@ $Env:OPENAI_API_KEY='sk-...'
```rust
use std::error::Error;

use async_openai as openai;
use openai::{
use async_openai::{
types::{CreateImageRequest, ImageSize, ResponseFormat},
Client, Image,
};
Expand Down
92 changes: 92 additions & 0 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
use std::pin::Pin;

use futures::{stream::StreamExt, Stream};
use reqwest::header::HeaderMap;
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Serialize};

use crate::error::{OpenAIError, WrappedError};
Expand Down Expand Up @@ -212,4 +216,92 @@ impl Client {
}
}
}

/// Make HTTP POST request to receive SSE
pub(crate) async fn post_stream<I, O>(
&self,
path: &str,
request: I,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>>>>
where
I: Serialize,
O: DeserializeOwned + std::marker::Send + 'static,
{
let event_source = reqwest::Client::new()
.post(format!("{}{path}", self.api_base()))
.headers(self.headers())
.bearer_auth(self.api_key())
.json(&request)
.eventsource()
.unwrap();

Client::stream(event_source).await
}

/// Make HTTP GET request to receive SSE
pub(crate) async fn get_stream<Q, O>(
&self,
path: &str,
query: &Q,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>>>>
where
Q: Serialize + ?Sized,
O: DeserializeOwned + std::marker::Send + 'static,
{
let event_source = reqwest::Client::new()
.get(format!("{}{path}", self.api_base()))
.query(query)
.headers(self.headers())
.bearer_auth(self.api_key())
.eventsource()
.unwrap();

Client::stream(event_source).await
}

/// Request which responds with SSE.
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
pub(crate) async fn stream<O>(
mut event_source: EventSource,
) -> Pin<Box<dyn Stream<Item = Result<O, OpenAIError>>>>
where
O: DeserializeOwned + std::marker::Send + 'static,
{
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

tokio::spawn(async move {
while let Some(ev) = event_source.next().await {
match ev {
Err(e) => {
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
// rx dropped
break;
}
}
Ok(event) => match event {
Event::Message(message) => {
if message.data == "[DONE]" {
break;
}

let response = match serde_json::from_str::<O>(&message.data) {
Err(e) => Err(OpenAIError::JSONDeserialize(e)),
Ok(output) => Ok(output),
};

if let Err(_e) = tx.send(response) {
// rx dropped
break;
}
}
Event::Open => continue,
},
}
}

event_source.close();
});

Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
}
}
52 changes: 1 addition & 51 deletions async-openai/src/completion.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
use futures::stream::StreamExt;
use reqwest_eventsource::{Event, RequestBuilderExt};

use crate::{
client::Client,
error::OpenAIError,
Expand Down Expand Up @@ -45,53 +42,6 @@ impl Completion {

request.stream = Some(true);

let mut event_source = reqwest::Client::new()
.post(format!("{}/completions", client.api_base()))
.bearer_auth(client.api_key())
.json(&request)
.eventsource()
.unwrap();

let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

tokio::spawn(async move {
while let Some(ev) = event_source.next().await {
match ev {
Err(e) => {
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
// rx dropped
break;
}
}
Ok(event) => match event {
Event::Message(message) => {
if message.data == "[DONE]" {
break;
}

let response = match serde_json::from_str::<CreateCompletionResponse>(
&message.data,
) {
Err(e) => Err(OpenAIError::JSONDeserialize(e)),
Ok(ccr) => Ok(ccr),
};

if let Err(_e) = tx.send(response) {
// rx dropped
break;
}
}
Event::Open => continue,
},
}
}

event_source.close();
});

Ok(
Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
as CompletionResponseStream,
)
Ok(client.post_stream("/completions", request).await)
}
}
2 changes: 1 addition & 1 deletion async-openai/src/embedding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ mod tests {
input: crate::types::EmbeddingInput::ArrayOfIntegerArray(vec![
vec![1, 2, 3],
vec![4, 5, 6],
vec![7, 8, 9],
vec![7, 8, 100257],
]),
..Default::default()
};
Expand Down
21 changes: 20 additions & 1 deletion async-openai/src/fine_tune.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::{
error::OpenAIError,
types::{CreateFineTuneRequest, FineTune, ListFineTuneEventsResponse, ListFineTuneResponse},
types::{
CreateFineTuneRequest, FineTune, FineTuneEventsResponseStream, ListFineTuneEventsResponse,
ListFineTuneResponse,
},
Client,
};

Expand Down Expand Up @@ -52,4 +55,20 @@ impl FineTunes {
.get(format!("/fine-tunes/{fine_tune_id}/events").as_str())
.await
}

/// Get fine-grained status updates for a fine-tune job.
///
/// Stream fine tuning events. [FineTuneEventsResponseStream] is a parsed SSE
/// stream until a \[DONE\] is received from server.
pub async fn list_events_stream(
client: &Client,
fine_tune_id: &str,
) -> Result<FineTuneEventsResponseStream, OpenAIError> {
Ok(client
.get_stream(
format!("/fine-tunes/{fine_tune_id}/events").as_str(),
&[("stream", true)],
)
.await)
}
}
16 changes: 8 additions & 8 deletions async-openai/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,32 +3,32 @@
//! ## Creating client
//!
//! ```
//! use async_openai as openai;
//! use async_openai::Client;
//!
//! // Create a client with api key from env var OPENAI_API_KEY and default base url.
//! let client = openai::Client::new();
//! let client = Client::new();
//!
//! // OR use API key from different source
//! let api_key = "sk-..."; // This secret could be from a file, or environment variable.
//! let client = openai::Client::new().with_api_key(api_key);
//! let client = Client::new().with_api_key(api_key);
//!
//! // Use organization different then default when making requests
//! let client = openai::Client::new().with_org_id("the-org");
//! // Use organization other than default when making requests
//! let client = Client::new().with_org_id("the-org");
//! ```
//!
//! ## Making requests
//!
//!```
//!# tokio_test::block_on(async {
//! use async_openai as openai;
//! use openai::{Client, Completion, types::{CreateCompletionRequest}};
//!
//! use async_openai::{Client, Completion, types::{CreateCompletionRequest, Prompt}};
//!
//! // Create client
//! let client = Client::new();
//! // Create request
//! let request = CreateCompletionRequest {
//! model: "text-davinci-003".to_owned(),
//! prompt: Some("Tell me a joke about the universe".to_owned()),
//! prompt: Some(Prompt::String("Tell me the recipe of alfredo pasta".to_owned())),
//! ..Default::default()
//! };
//! // Call API
Expand Down
Loading

0 comments on commit 3353700

Please sign in to comment.