Skip to content

Commit

Permalink
Add speech endpoint (#130)
Browse files Browse the repository at this point in the history
* Add speech endpoint

* Add voice parameter and example

* Add voice enum
  • Loading branch information
m1guelpf committed Nov 6, 2023
1 parent 50019ca commit e085d30
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 9 deletions.
1 change: 1 addition & 0 deletions async-openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ tracing = "0.1.37"
derive_builder = "0.12.0"
async-convert = "1.0.0"
secrecy = { version = "0.8.0", features=["serde"] }
bytes = "1.5.0"

[dev-dependencies]
tokio-test = "0.4.2"
2 changes: 1 addition & 1 deletion async-openai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

- It's based on [OpenAI OpenAPI spec](https://github.com/openai/openai-openapi)
- Current features:
- [x] Audio
- [x] Audio (Whisper/TTS)
- [x] Chat
- [x] Completions (Legacy)
- [x] Edits (Deprecated)
Expand Down
10 changes: 8 additions & 2 deletions async-openai/src/audio.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use bytes::Bytes;

use crate::{
config::Config,
error::OpenAIError,
types::{
CreateTranscriptionRequest, CreateTranscriptionResponse, CreateTranslationRequest,
CreateTranslationResponse,
CreateSpeechRequest, CreateTranscriptionRequest, CreateTranscriptionResponse,
CreateTranslationRequest, CreateTranslationResponse,
},
Client,
};
Expand Down Expand Up @@ -36,4 +38,8 @@ impl<'c, C: Config> Audio<'c, C> {
) -> Result<CreateTranslationResponse, OpenAIError> {
self.client.post_form("/audio/translations", request).await
}

pub async fn speech(&self, request: CreateSpeechRequest) -> Result<Bytes, OpenAIError> {
self.client.post_raw("/audio/speech", request).await
}
}
46 changes: 40 additions & 6 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::pin::Pin;

use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Serialize};
Expand Down Expand Up @@ -174,6 +175,24 @@ impl<C: Config> Client<C> {
self.execute(request_maker).await
}

/// Make a POST request to {path} and return the response body
pub(crate) async fn post_raw<I>(&self, path: &str, request: I) -> Result<Bytes, OpenAIError>
where
I: Serialize,
{
let request_maker = || async {
Ok(self
.http_client
.post(self.config.url(path))
.query(&self.config.query())
.headers(self.config.headers())
.json(&request)
.build()?)
};

self.execute_raw(request_maker).await
}

/// Make a POST request to {path} and deserialize the response body
pub(crate) async fn post<I, O>(&self, path: &str, request: I) -> Result<O, OpenAIError>
where
Expand Down Expand Up @@ -218,9 +237,8 @@ impl<C: Config> Client<C> {
/// request_maker serves one purpose: to be able to create request again
/// to retry API call after getting rate limited. request_maker is async because
/// reqwest::multipart::Form is created by async calls to read files for uploads.
async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
async fn execute_raw<M, Fut>(&self, request_maker: M) -> Result<Bytes, OpenAIError>
where
O: DeserializeOwned,
M: Fn() -> Fut,
Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
{
Expand Down Expand Up @@ -265,14 +283,30 @@ impl<C: Config> Client<C> {
}
}

let response: O = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(backoff::Error::Permanent)?;
Ok(response)
Ok(bytes)
})
.await
}

/// Execute a HTTP request and retry on rate limit
///
/// request_maker serves one purpose: to be able to create request again
/// to retry API call after getting rate limited. request_maker is async because
/// reqwest::multipart::Form is created by async calls to read files for uploads.
async fn execute<O, M, Fut>(&self, request_maker: M) -> Result<O, OpenAIError>
where
O: DeserializeOwned,
M: Fn() -> Fut,
Fut: core::future::Future<Output = Result<reqwest::Request, OpenAIError>>,
{
let bytes = self.execute_raw(request_maker).await?;

let response: O = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;

Ok(response)
}

/// Make HTTP POST request to receive SSE
pub(crate) async fn post_stream<I, O>(
&self,
Expand Down
52 changes: 52 additions & 0 deletions async-openai/src/types/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1225,6 +1225,30 @@ pub enum AudioResponseFormat {
Vtt,
}

#[derive(Debug, Serialize, Default, Clone, Copy, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum SpeechResponseFormat {
#[default]
Mp3,
Opus,
Aac,
Flac,
}

#[derive(Debug, Serialize, Clone, PartialEq)]
#[serde(rename_all = "snake_case")]
#[serde(untagged)]
#[non_exhaustive]
pub enum Voice {
Alloy,
Echo,
Fable,
Onyx,
Nova,
Shimmer,
Other(String),
}

#[derive(Clone, Default, Debug, Builder, PartialEq)]
#[builder(name = "CreateTranscriptionRequestArgs")]
#[builder(pattern = "mutable")]
Expand Down Expand Up @@ -1256,6 +1280,29 @@ pub struct CreateTranscriptionResponse {
pub text: String,
}

#[derive(Clone, Debug, Builder, PartialEq, Serialize)]
#[builder(name = "CreateSpeechRequestArgs")]
#[builder(pattern = "mutable")]
#[builder(setter(into, strip_option))]
#[builder(derive(Debug))]
#[builder(build_fn(error = "OpenAIError"))]
pub struct CreateSpeechRequest {
/// The text to generate audio for. The maximum length is 4096 characters.
pub input: String,

/// ID of the model to use. Only `tts-1` and `tts-1-hd` are currently available.
pub model: String,

/// The voice to use when generating the audio. Supported voices are `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`.
pub voice: Voice,

/// The format to audio in. Supported formats are mp3, opus, aac, and flac.
pub response_format: Option<SpeechResponseFormat>,

/// The speed of the generated audio. Select a value from 0.25 to 4.0. 1.0 is the default.
pub speed: Option<f32>, // default: 1.0
}

#[derive(Clone, Default, Debug, Builder, PartialEq)]
#[builder(name = "CreateTranslationRequestArgs")]
#[builder(pattern = "mutable")]
Expand Down Expand Up @@ -1283,3 +1330,8 @@ pub struct CreateTranslationRequest {
pub struct CreateTranslationResponse {
pub text: String,
}

#[derive(Debug, Deserialize, Clone, PartialEq, Serialize)]
pub struct CreateSpeechResponse {
pub text: String,
}
1 change: 1 addition & 0 deletions examples/audio-speech/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
audio.mp3
9 changes: 9 additions & 0 deletions examples/audio-speech/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
[package]
name = "audio-speech"
version = "0.1.0"
edition = "2021"
publish = false

[dependencies]
async-openai = {path = "../../async-openai"}
tokio = { version = "1.25.0", features = ["full"] }
3 changes: 3 additions & 0 deletions examples/audio-speech/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
### Output (as an mp3 file)

> Today is a wonderful day to build something people love!
22 changes: 22 additions & 0 deletions examples/audio-speech/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use async_openai::{
types::{CreateSpeechRequestArgs, Voice},
Client,
};
use std::error::Error;

#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
let client = Client::new();

let request = CreateSpeechRequestArgs::default()
.input("Today is a wonderful day to build something people love!".to_string())
.voice(Voice::Alloy)
.model("tts-1")
.build()?;

let response = client.audio().speech(request).await?;

std::fs::write("audio.mp3", response)?;

Ok(())
}

0 comments on commit e085d30

Please sign in to comment.