Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 68 additions & 48 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use std::pin::Pin;

use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use reqwest::multipart::Form;
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
use reqwest::{multipart::Form, Response};
use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Serialize};

use crate::{
config::{Config, OpenAIConfig},
error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
file::Files,
image::Images,
moderation::Moderations,
Expand Down Expand Up @@ -335,52 +335,34 @@ impl<C: Config> Client<C> {
.map_err(backoff::Error::Permanent)?;

let status = response.status();
let bytes = response
.bytes()
.await
.map_err(OpenAIError::Reqwest)
.map_err(backoff::Error::Permanent)?;

if status.is_server_error() {
// OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
let message: String = String::from_utf8_lossy(&bytes).into_owned();
tracing::warn!("Server error: {status} - {message}");
return Err(backoff::Error::Transient {
err: OpenAIError::ApiError(ApiError {
message,
r#type: None,
param: None,
code: None,
}),
retry_after: None,
});
}

// Deserialize response body from either error object or actual response object
if !status.is_success() {
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(backoff::Error::Permanent)?;

if status.as_u16() == 429
// API returns 429 also when:
// "You exceeded your current quota, please check your plan and billing details."
&& wrapped_error.error.r#type != Some("insufficient_quota".to_string())
{
// Rate limited retry...
tracing::warn!("Rate limited: {}", wrapped_error.error.message);
return Err(backoff::Error::Transient {
err: OpenAIError::ApiError(wrapped_error.error),
retry_after: None,
});
} else {
return Err(backoff::Error::Permanent(OpenAIError::ApiError(
wrapped_error.error,
)));
match read_response(response).await {
Ok(bytes) => Ok(bytes),
Err(e) => {
match e {
OpenAIError::ApiError(api_error) => {
if status.is_server_error() {
Err(backoff::Error::Transient {
err: OpenAIError::ApiError(api_error),
retry_after: None,
})
} else if status.as_u16() == 429
&& api_error.r#type != Some("insufficient_quota".to_string())
{
// Rate limited retry...
tracing::warn!("Rate limited: {}", api_error.message);
Err(backoff::Error::Transient {
err: OpenAIError::ApiError(api_error),
retry_after: None,
})
} else {
Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
}
}
_ => Err(backoff::Error::Permanent(e)),
}
}
}

Ok(bytes)
})
.await
}
Expand Down Expand Up @@ -471,6 +453,44 @@ impl<C: Config> Client<C> {
}
}

async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
let status = response.status();
let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;

if status.is_server_error() {
// OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
let message: String = String::from_utf8_lossy(&bytes).into_owned();
tracing::warn!("Server error: {status} - {message}");
return Err(OpenAIError::ApiError(ApiError {
message,
r#type: None,
param: None,
code: None,
}));
}

// Deserialize response body from either error object or actual response object
if !status.is_success() {
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;

return Err(OpenAIError::ApiError(wrapped_error.error));
}

Ok(bytes)
}

async fn map_stream_error(value: EventSourceError) -> OpenAIError {
match value {
EventSourceError::InvalidStatusCode(status_code, response) => {
read_response(response).await.expect_err(&format!(
"Unreachable because read_response returns err when status_code {status_code} is invalid"
))
}
_ => OpenAIError::StreamError(StreamError::ReqwestEventSource(value.into())),
}
}

/// Request which responds with SSE.
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
pub(crate) async fn stream<O>(
Expand All @@ -485,7 +505,7 @@ where
while let Some(ev) = event_source.next().await {
match ev {
Err(e) => {
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
// rx dropped
break;
}
Expand Down Expand Up @@ -530,7 +550,7 @@ where
while let Some(ev) = event_source.next().await {
match ev {
Err(e) => {
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
// rx dropped
break;
}
Expand Down
15 changes: 14 additions & 1 deletion async-openai/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
//! Errors originating from API calls, parsing responses, and reading-or-writing to the file system.
use std::string::FromUtf8Error;

use reqwest::{header::HeaderValue, Response};
use serde::{Deserialize, Serialize};

#[derive(Debug, thiserror::Error)]
Expand All @@ -20,13 +23,23 @@ pub enum OpenAIError {
FileReadError(String),
/// Error on SSE streaming
#[error("stream failed: {0}")]
StreamError(String),
StreamError(StreamError),
/// Error from client side validation
/// or when builder fails to build request before making API call
#[error("invalid args: {0}")]
InvalidArgument(String),
}

#[derive(Debug, thiserror::Error)]
pub enum StreamError {
/// Underlying error from reqwest_eventsource library when reading the stream
#[error("{0}")]
ReqwestEventSource(#[from] reqwest_eventsource::Error),
/// Error when a stream event does not match one of the expected values
#[error("Unknown event: {0:#?}")]
UnknownEvent(eventsource_stream::Event),
}

/// OpenAI API returns error object on failure
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ApiError {
Expand Down
6 changes: 2 additions & 4 deletions async-openai/src/types/assistant_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::pin::Pin;
use futures::Stream;
use serde::Deserialize;

use crate::error::{map_deserialization_error, ApiError, OpenAIError};
use crate::error::{map_deserialization_error, ApiError, OpenAIError, StreamError};

use super::{
MessageDeltaObject, MessageObject, RunObject, RunStepDeltaObject, RunStepObject, ThreadObject,
Expand Down Expand Up @@ -207,9 +207,7 @@ impl TryFrom<eventsource_stream::Event> for AssistantStreamEvent {
.map(AssistantStreamEvent::ErrorEvent),
"done" => Ok(AssistantStreamEvent::Done(value.data)),

_ => Err(OpenAIError::StreamError(
"Unrecognized event: {value:?#}".into(),
)),
_ => Err(OpenAIError::StreamError(StreamError::UnknownEvent(value))),
}
}
}
2 changes: 1 addition & 1 deletion examples/chat-stream/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
});
}
Err(err) => {
writeln!(lock, "error: {err}").unwrap();
writeln!(lock, "error: {err:?}").unwrap();
}
}
stdout().flush()?;
Expand Down
2 changes: 1 addition & 1 deletion examples/completions-stream/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
Ok(ccr) => ccr.choices.iter().for_each(|c| {
print!("{}", c.text);
}),
Err(e) => eprintln!("{}", e),
Err(e) => eprintln!("{e:?}"),
}
}

Expand Down
4 changes: 2 additions & 2 deletions examples/function-call-stream/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
}
}
Err(err) => {
writeln!(lock, "error: {err}").unwrap();
writeln!(lock, "error: {err:?}").unwrap();
}
}
stdout().flush()?;
Expand Down Expand Up @@ -132,7 +132,7 @@ async fn call_fn(
});
}
Err(err) => {
writeln!(lock, "error: {err}").unwrap();
writeln!(lock, "error: {err:?}").unwrap();
}
}
stdout().flush()?;
Expand Down
7 changes: 3 additions & 4 deletions examples/responses-stream/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
| ResponseEvent::ResponseFailed(_) => {
break;
}
_ => { println!("{response_event:#?}"); }
_ => {
println!("{response_event:#?}");
}
},
Err(e) => {
eprintln!("{e:#?}");
// When a stream ends, it returns Err(OpenAIError::StreamError("Stream ended"))
// Without this, the stream will never end
break;
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion examples/tool-call-stream/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
}
Err(err) => {
let mut lock = stdout().lock();
writeln!(lock, "error: {err}").unwrap();
writeln!(lock, "error: {err:?}").unwrap();
}
}
stdout()
Expand Down