diff --git a/async-openai/Cargo.toml b/async-openai/Cargo.toml index 8fa4e438..31e24e3e 100644 --- a/async-openai/Cargo.toml +++ b/async-openai/Cargo.toml @@ -40,6 +40,7 @@ tracing = "0.1.37" derive_builder = "0.12.0" async-convert = "1.0.0" secrecy = { version = "0.8.0", features=["serde"] } +bytes = "1.5.0" [dev-dependencies] tokio-test = "0.4.2" diff --git a/async-openai/README.md b/async-openai/README.md index 1c74574b..95a2461e 100644 --- a/async-openai/README.md +++ b/async-openai/README.md @@ -23,7 +23,7 @@ - It's based on [OpenAI OpenAPI spec](https://github.com/openai/openai-openapi) - Current features: - - [x] Audio + - [x] Audio (Whisper/TTS) - [x] Chat - [x] Completions (Legacy) - [x] Edits (Deprecated) diff --git a/async-openai/src/audio.rs b/async-openai/src/audio.rs index 137d5cf6..79d40201 100644 --- a/async-openai/src/audio.rs +++ b/async-openai/src/audio.rs @@ -1,9 +1,11 @@ +use bytes::Bytes; + use crate::{ config::Config, error::OpenAIError, types::{ - CreateTranscriptionRequest, CreateTranscriptionResponse, CreateTranslationRequest, - CreateTranslationResponse, + CreateSpeechRequest, CreateTranscriptionRequest, CreateTranscriptionResponse, + CreateTranslationRequest, CreateTranslationResponse, }, Client, }; @@ -36,4 +38,8 @@ impl<'c, C: Config> Audio<'c, C> { ) -> Result { self.client.post_form("/audio/translations", request).await } + + pub async fn speech(&self, request: CreateSpeechRequest) -> Result { + self.client.post_raw("/audio/speech", request).await + } } diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index 8ec14f4d..86436185 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -1,5 +1,6 @@ use std::pin::Pin; +use bytes::Bytes; use futures::{stream::StreamExt, Stream}; use reqwest_eventsource::{Event, EventSource, RequestBuilderExt}; use serde::{de::DeserializeOwned, Serialize}; @@ -174,6 +175,24 @@ impl Client { self.execute(request_maker).await } + /// Make a POST request to {path} and 
return the response body + pub(crate) async fn post_raw(&self, path: &str, request: I) -> Result + where + I: Serialize, + { + let request_maker = || async { + Ok(self + .http_client + .post(self.config.url(path)) + .query(&self.config.query()) + .headers(self.config.headers()) + .json(&request) + .build()?) + }; + + self.execute_raw(request_maker).await + } + /// Make a POST request to {path} and deserialize the response body pub(crate) async fn post(&self, path: &str, request: I) -> Result where @@ -218,9 +237,8 @@ impl Client { /// request_maker serves one purpose: to be able to create request again /// to retry API call after getting rate limited. request_maker is async because /// reqwest::multipart::Form is created by async calls to read files for uploads. - async fn execute(&self, request_maker: M) -> Result + async fn execute_raw(&self, request_maker: M) -> Result where - O: DeserializeOwned, M: Fn() -> Fut, Fut: core::future::Future>, { @@ -265,14 +283,30 @@ impl Client { } } - let response: O = serde_json::from_slice(bytes.as_ref()) - .map_err(|e| map_deserialization_error(e, bytes.as_ref())) - .map_err(backoff::Error::Permanent)?; - Ok(response) + Ok(bytes) }) .await } + /// Execute an HTTP request and retry on rate limit + /// + /// request_maker serves one purpose: to be able to create request again + /// to retry API call after getting rate limited. request_maker is async because + /// reqwest::multipart::Form is created by async calls to read files for uploads. 
+ async fn execute(&self, request_maker: M) -> Result + where + O: DeserializeOwned, + M: Fn() -> Fut, + Fut: core::future::Future>, + { + let bytes = self.execute_raw(request_maker).await?; + + let response: O = serde_json::from_slice(bytes.as_ref()) + .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?; + + Ok(response) + } + /// Make HTTP POST request to receive SSE pub(crate) async fn post_stream( &self, diff --git a/async-openai/src/types/types.rs b/async-openai/src/types/types.rs index a72484c1..53689cf0 100644 --- a/async-openai/src/types/types.rs +++ b/async-openai/src/types/types.rs @@ -1225,6 +1225,30 @@ pub enum AudioResponseFormat { Vtt, } +#[derive(Debug, Serialize, Default, Clone, Copy, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum SpeechResponseFormat { + #[default] + Mp3, + Opus, + Aac, + Flac, +} + +#[derive(Debug, Serialize, Clone, PartialEq)] +#[serde(rename_all = "snake_case")] +#[serde(untagged)] +#[non_exhaustive] +pub enum Voice { + Alloy, + Echo, + Fable, + Onyx, + Nova, + Shimmer, + Other(String), +} + #[derive(Clone, Default, Debug, Builder, PartialEq)] #[builder(name = "CreateTranscriptionRequestArgs")] #[builder(pattern = "mutable")] @@ -1256,6 +1280,29 @@ pub struct CreateTranscriptionResponse { pub text: String, } +#[derive(Clone, Debug, Builder, PartialEq, Serialize)] +#[builder(name = "CreateSpeechRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option))] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateSpeechRequest { + /// The text to generate audio for. The maximum length is 4096 characters. + pub input: String, + + /// ID of the model to use. Only `tts-1` and `tts-1-hd` are currently available. + pub model: String, + + /// The voice to use when generating the audio. Supported voices are `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`. + pub voice: Voice, + + /// The format to return the audio in. Supported formats are mp3, opus, aac, and flac. 
+ pub response_format: Option, + + /// The speed of the generated audio. Select a value from 0.25 to 4.0. 1.0 is the default. + pub speed: Option, // default: 1.0 +} + #[derive(Clone, Default, Debug, Builder, PartialEq)] #[builder(name = "CreateTranslationRequestArgs")] #[builder(pattern = "mutable")] @@ -1283,3 +1330,8 @@ pub struct CreateTranslationRequest { pub struct CreateTranslationResponse { pub text: String, } + +#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)] +pub struct CreateSpeechResponse { + pub text: String, +} diff --git a/examples/audio-speech/.gitignore b/examples/audio-speech/.gitignore new file mode 100644 index 00000000..cbf36313 --- /dev/null +++ b/examples/audio-speech/.gitignore @@ -0,0 +1 @@ +audio.mp3 diff --git a/examples/audio-speech/Cargo.toml b/examples/audio-speech/Cargo.toml new file mode 100644 index 00000000..1ce87cb7 --- /dev/null +++ b/examples/audio-speech/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "audio-speech" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +async-openai = {path = "../../async-openai"} +tokio = { version = "1.25.0", features = ["full"] } diff --git a/examples/audio-speech/README.md b/examples/audio-speech/README.md new file mode 100644 index 00000000..ef30a3c7 --- /dev/null +++ b/examples/audio-speech/README.md @@ -0,0 +1,3 @@ +### Output (as an mp3 file) + +> Today is a wonderful day to build something people love! 
diff --git a/examples/audio-speech/src/main.rs b/examples/audio-speech/src/main.rs new file mode 100644 index 00000000..b96581d3 --- /dev/null +++ b/examples/audio-speech/src/main.rs @@ -0,0 +1,22 @@ +use async_openai::{ + types::{CreateSpeechRequestArgs, Voice}, + Client, +}; +use std::error::Error; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = Client::new(); + + let request = CreateSpeechRequestArgs::default() + .input("Today is a wonderful day to build something people love!".to_string()) + .voice(Voice::Alloy) + .model("tts-1") + .build()?; + + let response = client.audio().speech(request).await?; + + std::fs::write("audio.mp3", response)?; + + Ok(()) +}