diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index 2c30cbc0..fe558259 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -619,9 +619,18 @@ where while let Some(ev) = event_source.next().await { match ev { Err(e) => { - if let Err(_e) = tx.send(Err(map_stream_error(e).await)) { - // rx dropped - break; + // Handle StreamEnded gracefully - it's a normal end of stream, not an error + // https://github.com/64bit/async-openai/issues/456 + match &e { + EventSourceError::StreamEnded => { + break; + } + _ => { + if let Err(_e) = tx.send(Err(map_stream_error(e).await)) { + // rx dropped + break; + } + } } } Ok(event) => match event { @@ -664,9 +673,18 @@ where while let Some(ev) = event_source.next().await { match ev { Err(e) => { - if let Err(_e) = tx.send(Err(map_stream_error(e).await)) { - // rx dropped - break; + // Handle StreamEnded gracefully - it's a normal end of stream, not an error + // https://github.com/64bit/async-openai/issues/456 + match &e { + EventSourceError::StreamEnded => { + break; + } + _ => { + if let Err(_e) = tx.send(Err(map_stream_error(e).await)) { + // rx dropped + break; + } + } } } Ok(event) => match event { diff --git a/async-openai/src/impls.rs b/async-openai/src/impls.rs new file mode 100644 index 00000000..2875411a --- /dev/null +++ b/async-openai/src/impls.rs @@ -0,0 +1,133 @@ +use crate::{ + admin_api_keys::AdminAPIKeys, + assistants::Assistants, + audio::Audio, + audit_logs::AuditLogs, + batches::Batches, + certificates::Certificates, + chat::Chat, + chatkit::{Chatkit, ChatkitSessions, ChatkitThreads}, + completion::Completions, + container_files::ContainerFiles, + containers::Containers, + conversation_items::ConversationItems, + conversations::Conversations, + embedding::Embeddings, + eval_run_output_items::EvalRunOutputItems, + eval_runs::EvalRuns, + evals::Evals, + file::Files, + fine_tuning::FineTuning, + group_roles::GroupRoles, + group_users::GroupUsers, + groups::Groups, + 
image::Images, + invites::Invites, + messages::Messages, + model::Models, + moderation::Moderations, + project_api_keys::ProjectAPIKeys, + project_certificates::ProjectCertificates, + project_group_roles::ProjectGroupRoles, + project_groups::ProjectGroups, + project_rate_limits::ProjectRateLimits, + project_roles::ProjectRoles, + project_service_accounts::ProjectServiceAccounts, + project_user_roles::ProjectUserRoles, + project_users::ProjectUsers, + projects::Projects, + responses::Responses, + roles::Roles, + runs::Runs, + speech::Speech, + steps::Steps, + threads::Threads, + transcriptions::Transcriptions, + translations::Translations, + uploads::Uploads, + usage::Usage, + user_roles::UserRoles, + users::Users, + vector_store_file_batches::VectorStoreFileBatches, + vector_store_files::VectorStoreFiles, + vector_stores::VectorStores, + video::Videos, +}; + +// request builder impls macro + +/// Macro to implement `RequestOptionsBuilder` for wrapper types containing `RequestOptions` +macro_rules! 
impl_request_options_builder { + ($type:ident) => { + impl<'c, C: crate::config::Config> crate::traits::RequestOptionsBuilder for $type<'c, C> { + fn options_mut(&mut self) -> &mut crate::RequestOptions { + &mut self.request_options + } + + fn options(&self) -> &crate::RequestOptions { + &self.request_options + } + } + }; +} + +#[cfg(feature = "realtime")] +use crate::Realtime; + +impl_request_options_builder!(AdminAPIKeys); +impl_request_options_builder!(Assistants); +impl_request_options_builder!(Audio); +impl_request_options_builder!(AuditLogs); +impl_request_options_builder!(Batches); +impl_request_options_builder!(Certificates); +impl_request_options_builder!(Chat); +impl_request_options_builder!(Chatkit); +impl_request_options_builder!(ChatkitSessions); +impl_request_options_builder!(ChatkitThreads); +impl_request_options_builder!(Completions); +impl_request_options_builder!(ContainerFiles); +impl_request_options_builder!(Containers); +impl_request_options_builder!(ConversationItems); +impl_request_options_builder!(Conversations); +impl_request_options_builder!(Embeddings); +impl_request_options_builder!(Evals); +impl_request_options_builder!(EvalRunOutputItems); +impl_request_options_builder!(EvalRuns); +impl_request_options_builder!(Files); +impl_request_options_builder!(FineTuning); +impl_request_options_builder!(GroupRoles); +impl_request_options_builder!(GroupUsers); +impl_request_options_builder!(Groups); +impl_request_options_builder!(Images); +impl_request_options_builder!(Invites); +impl_request_options_builder!(Messages); +impl_request_options_builder!(Models); +impl_request_options_builder!(Moderations); +impl_request_options_builder!(Projects); +impl_request_options_builder!(ProjectGroupRoles); +impl_request_options_builder!(ProjectGroups); +impl_request_options_builder!(ProjectRoles); +impl_request_options_builder!(ProjectUserRoles); +impl_request_options_builder!(ProjectUsers); +impl_request_options_builder!(ProjectServiceAccounts); 
+impl_request_options_builder!(ProjectAPIKeys); +impl_request_options_builder!(ProjectRateLimits); +impl_request_options_builder!(ProjectCertificates); +impl_request_options_builder!(Roles); +#[cfg(feature = "realtime")] +impl_request_options_builder!(Realtime); +impl_request_options_builder!(Responses); +impl_request_options_builder!(Runs); +impl_request_options_builder!(Speech); +impl_request_options_builder!(Steps); +impl_request_options_builder!(Threads); +impl_request_options_builder!(Transcriptions); +impl_request_options_builder!(Translations); +impl_request_options_builder!(Uploads); +impl_request_options_builder!(Usage); +impl_request_options_builder!(UserRoles); +impl_request_options_builder!(Users); +impl_request_options_builder!(VectorStoreFileBatches); +impl_request_options_builder!(VectorStoreFiles); +impl_request_options_builder!(VectorStores); +impl_request_options_builder!(Videos); diff --git a/async-openai/src/lib.rs b/async-openai/src/lib.rs index 200ac49f..fd17d1cb 100644 --- a/async-openai/src/lib.rs +++ b/async-openai/src/lib.rs @@ -168,6 +168,7 @@ mod group_roles; mod group_users; mod groups; mod image; +mod impls; mod invites; mod messages; mod model; diff --git a/async-openai/src/types/assistants/assistant_stream.rs b/async-openai/src/types/assistants/assistant_stream.rs index 1037b8ba..cbf39a86 100644 --- a/async-openai/src/types/assistants/assistant_stream.rs +++ b/async-openai/src/types/assistants/assistant_stream.rs @@ -4,7 +4,7 @@ use futures::Stream; use serde::Deserialize; use crate::error::{map_deserialization_error, ApiError, OpenAIError, StreamError}; - +use crate::traits::EventType; use crate::types::assistants::{ MessageDeltaObject, MessageObject, RunObject, RunStepDeltaObject, RunStepObject, ThreadObject, }; @@ -213,3 +213,35 @@ impl TryFrom for AssistantStreamEvent { } } } + +impl EventType for AssistantStreamEvent { + fn event_type(&self) -> &'static str { + match self { + AssistantStreamEvent::ThreadCreated(_) => 
"thread.created", + AssistantStreamEvent::ThreadRunCreated(_) => "thread.run.created", + AssistantStreamEvent::ThreadRunQueued(_) => "thread.run.queued", + AssistantStreamEvent::ThreadRunInProgress(_) => "thread.run.in_progress", + AssistantStreamEvent::ThreadRunRequiresAction(_) => "thread.run.requires_action", + AssistantStreamEvent::ThreadRunCompleted(_) => "thread.run.completed", + AssistantStreamEvent::ThreadRunIncomplete(_) => "thread.run.incomplete", + AssistantStreamEvent::ThreadRunFailed(_) => "thread.run.failed", + AssistantStreamEvent::ThreadRunCancelling(_) => "thread.run.cancelling", + AssistantStreamEvent::ThreadRunCancelled(_) => "thread.run.cancelled", + AssistantStreamEvent::ThreadRunExpired(_) => "thread.run.expired", + AssistantStreamEvent::ThreadRunStepCreated(_) => "thread.run.step.created", + AssistantStreamEvent::ThreadRunStepInProgress(_) => "thread.run.step.in_progress", + AssistantStreamEvent::ThreadRunStepDelta(_) => "thread.run.step.delta", + AssistantStreamEvent::ThreadRunStepCompleted(_) => "thread.run.step.completed", + AssistantStreamEvent::ThreadRunStepFailed(_) => "thread.run.step.failed", + AssistantStreamEvent::ThreadRunStepCancelled(_) => "thread.run.step.cancelled", + AssistantStreamEvent::ThreadRunStepExpired(_) => "thread.run.step.expired", + AssistantStreamEvent::ThreadMessageCreated(_) => "thread.message.created", + AssistantStreamEvent::ThreadMessageInProgress(_) => "thread.message.in_progress", + AssistantStreamEvent::ThreadMessageDelta(_) => "thread.message.delta", + AssistantStreamEvent::ThreadMessageCompleted(_) => "thread.message.completed", + AssistantStreamEvent::ThreadMessageIncomplete(_) => "thread.message.incomplete", + AssistantStreamEvent::ErrorEvent(_) => "error", + AssistantStreamEvent::Done(_) => "done", + } + } +} diff --git a/async-openai/src/types/assistants/impls.rs b/async-openai/src/types/assistants/impls.rs new file mode 100644 index 00000000..6daa4e1a --- /dev/null +++ 
b/async-openai/src/types/assistants/impls.rs @@ -0,0 +1,19 @@ +use crate::types::assistants::CreateMessageRequestContent; + +impl From for CreateMessageRequestContent { + fn from(value: String) -> Self { + Self::Content(value) + } +} + +impl From<&str> for CreateMessageRequestContent { + fn from(value: &str) -> Self { + Self::Content(value.to_string()) + } +} + +impl Default for CreateMessageRequestContent { + fn default() -> Self { + Self::Content("".into()) + } +} diff --git a/async-openai/src/types/assistants/mod.rs b/async-openai/src/types/assistants/mod.rs index 3f645f7a..303b7c05 100644 --- a/async-openai/src/types/assistants/mod.rs +++ b/async-openai/src/types/assistants/mod.rs @@ -2,6 +2,7 @@ mod api; mod assistant; mod assistant_impls; mod assistant_stream; +mod impls; mod message; mod run; mod step; diff --git a/async-openai/src/types/audio/form.rs b/async-openai/src/types/audio/form.rs new file mode 100644 index 00000000..35c17163 --- /dev/null +++ b/async-openai/src/types/audio/form.rs @@ -0,0 +1,105 @@ +use crate::{ + error::OpenAIError, + traits::AsyncTryFrom, + types::audio::{ + CreateTranscriptionRequest, CreateTranslationRequest, TranscriptionChunkingStrategy, + }, + util::create_file_part, +}; + +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from(request: CreateTranscriptionRequest) -> Result { + let audio_part = create_file_part(request.file.source).await?; + + let mut form = reqwest::multipart::Form::new() + .part("file", audio_part) + .text("model", request.model); + + if let Some(language) = request.language { + form = form.text("language", language); + } + + if let Some(prompt) = request.prompt { + form = form.text("prompt", prompt); + } + + if let Some(response_format) = request.response_format { + form = form.text("response_format", response_format.to_string()) + } + + if let Some(temperature) = request.temperature { + form = form.text("temperature", temperature.to_string()) + } + + if let 
Some(include) = request.include { + for inc in include { + form = form.text("include[]", inc.to_string()); + } + } + + if let Some(timestamp_granularities) = request.timestamp_granularities { + for tg in timestamp_granularities { + form = form.text("timestamp_granularities[]", tg.to_string()); + } + } + + if let Some(stream) = request.stream { + form = form.text("stream", stream.to_string()); + } + + if let Some(chunking_strategy) = request.chunking_strategy { + match chunking_strategy { + TranscriptionChunkingStrategy::Auto => { + form = form.text("chunking_strategy", "auto"); + } + TranscriptionChunkingStrategy::ServerVad(vad_config) => { + form = form.text( + "chunking_strategy", + serde_json::to_string(&vad_config).unwrap().to_string(), + ); + } + } + } + + if let Some(known_speaker_names) = request.known_speaker_names { + for kn in known_speaker_names { + form = form.text("known_speaker_names[]", kn.to_string()); + } + } + + if let Some(known_speaker_references) = request.known_speaker_references { + for kn in known_speaker_references { + form = form.text("known_speaker_references[]", kn.to_string()); + } + } + + Ok(form) + } +} + +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from(request: CreateTranslationRequest) -> Result { + let audio_part = create_file_part(request.file.source).await?; + + let mut form = reqwest::multipart::Form::new() + .part("file", audio_part) + .text("model", request.model); + + if let Some(prompt) = request.prompt { + form = form.text("prompt", prompt); + } + + if let Some(response_format) = request.response_format { + form = form.text("response_format", response_format.to_string()) + } + + if let Some(temperature) = request.temperature { + form = form.text("temperature", temperature.to_string()) + } + Ok(form) + } +} diff --git a/async-openai/src/types/audio/impls.rs b/async-openai/src/types/audio/impls.rs new file mode 100644 index 00000000..4c190354 --- /dev/null +++ 
b/async-openai/src/types/audio/impls.rs @@ -0,0 +1,63 @@ +use std::fmt::Display; + +use crate::types::audio::{ + AudioResponseFormat, TimestampGranularity, TranscriptionInclude, TranslationResponseFormat, +}; + +impl Display for AudioResponseFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + AudioResponseFormat::Json => "json", + AudioResponseFormat::Srt => "srt", + AudioResponseFormat::Text => "text", + AudioResponseFormat::VerboseJson => "verbose_json", + AudioResponseFormat::Vtt => "vtt", + AudioResponseFormat::DiarizedJson => "diarized_json", + } + ) + } +} + +impl Display for TranslationResponseFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + TranslationResponseFormat::Json => "json", + TranslationResponseFormat::Srt => "srt", + TranslationResponseFormat::Text => "text", + TranslationResponseFormat::VerboseJson => "verbose_json", + TranslationResponseFormat::Vtt => "vtt", + } + ) + } +} + +impl Display for TimestampGranularity { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + TimestampGranularity::Word => "word", + TimestampGranularity::Segment => "segment", + } + ) + } +} + +impl Display for TranscriptionInclude { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + TranscriptionInclude::Logprobs => "logprobs", + } + ) + } +} diff --git a/async-openai/src/types/audio/mod.rs b/async-openai/src/types/audio/mod.rs index ac1eb734..368627c7 100644 --- a/async-openai/src/types/audio/mod.rs +++ b/async-openai/src/types/audio/mod.rs @@ -1,4 +1,7 @@ mod audio_types; +mod form; +mod impls; +mod sdk; mod stream; pub use audio_types::*; diff --git a/async-openai/src/types/audio/sdk.rs b/async-openai/src/types/audio/sdk.rs new file mode 100644 index 00000000..165cb982 --- /dev/null +++ 
b/async-openai/src/types/audio/sdk.rs @@ -0,0 +1,18 @@ +use crate::{error::OpenAIError, types::audio::CreateSpeechResponse, util::create_all_dir}; +use std::path::Path; + +impl CreateSpeechResponse { + pub async fn save>(&self, file_path: P) -> Result<(), OpenAIError> { + let dir = file_path.as_ref().parent(); + + if let Some(dir) = dir { + create_all_dir(dir)?; + } + + tokio::fs::write(file_path, &self.bytes) + .await + .map_err(|e| OpenAIError::FileSaveError(e.to_string()))?; + + Ok(()) + } +} diff --git a/async-openai/src/types/chat/chat_types.rs b/async-openai/src/types/chat/chat_types.rs index ee4212f3..a77dfbba 100644 --- a/async-openai/src/types/chat/chat_types.rs +++ b/async-openai/src/types/chat/chat_types.rs @@ -521,7 +521,6 @@ pub struct ChatCompletionResponseMessage { #[builder(setter(into, strip_option), default)] #[builder(derive(Debug))] #[builder(build_fn(error = "OpenAIError"))] -#[deprecated] pub struct ChatCompletionFunctions { /// The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. 
pub name: String, diff --git a/async-openai/src/types/chat/impls.rs b/async-openai/src/types/chat/impls.rs new file mode 100644 index 00000000..8869f41a --- /dev/null +++ b/async-openai/src/types/chat/impls.rs @@ -0,0 +1,334 @@ +use std::fmt::Display; + +use crate::types::chat::{ + ChatCompletionFunctionCall, ChatCompletionNamedToolChoice, + ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent, + ChatCompletionRequestDeveloperMessage, ChatCompletionRequestDeveloperMessageContent, + ChatCompletionRequestFunctionMessage, ChatCompletionRequestMessage, + ChatCompletionRequestMessageContentPartAudio, ChatCompletionRequestMessageContentPartImage, + ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessage, + ChatCompletionRequestSystemMessageContent, ChatCompletionRequestToolMessage, + ChatCompletionRequestToolMessageContent, ChatCompletionRequestUserMessage, + ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart, + FunctionName, ImageUrl, Role, +}; + +impl From for ChatCompletionRequestMessage { + fn from(value: ChatCompletionRequestUserMessage) -> Self { + Self::User(value) + } +} + +impl From for ChatCompletionRequestMessage { + fn from(value: ChatCompletionRequestSystemMessage) -> Self { + Self::System(value) + } +} + +impl From for ChatCompletionRequestMessage { + fn from(value: ChatCompletionRequestDeveloperMessage) -> Self { + Self::Developer(value) + } +} + +impl From for ChatCompletionRequestMessage { + fn from(value: ChatCompletionRequestAssistantMessage) -> Self { + Self::Assistant(value) + } +} + +impl From for ChatCompletionRequestMessage { + fn from(value: ChatCompletionRequestFunctionMessage) -> Self { + Self::Function(value) + } +} + +impl From for ChatCompletionRequestMessage { + fn from(value: ChatCompletionRequestToolMessage) -> Self { + Self::Tool(value) + } +} + +impl From for ChatCompletionRequestUserMessage { + fn from(value: 
ChatCompletionRequestUserMessageContent) -> Self { + Self { + content: value, + name: None, + } + } +} + +impl From for ChatCompletionRequestSystemMessage { + fn from(value: ChatCompletionRequestSystemMessageContent) -> Self { + Self { + content: value, + name: None, + } + } +} + +impl From for ChatCompletionRequestDeveloperMessage { + fn from(value: ChatCompletionRequestDeveloperMessageContent) -> Self { + Self { + content: value, + name: None, + } + } +} + +impl From for ChatCompletionRequestAssistantMessage { + fn from(value: ChatCompletionRequestAssistantMessageContent) -> Self { + Self { + content: Some(value), + ..Default::default() + } + } +} + +impl From<&str> for ChatCompletionRequestUserMessageContent { + fn from(value: &str) -> Self { + ChatCompletionRequestUserMessageContent::Text(value.into()) + } +} + +impl From for ChatCompletionRequestUserMessageContent { + fn from(value: String) -> Self { + ChatCompletionRequestUserMessageContent::Text(value) + } +} + +impl From<&str> for ChatCompletionRequestSystemMessageContent { + fn from(value: &str) -> Self { + ChatCompletionRequestSystemMessageContent::Text(value.into()) + } +} + +impl From for ChatCompletionRequestSystemMessageContent { + fn from(value: String) -> Self { + ChatCompletionRequestSystemMessageContent::Text(value) + } +} + +impl From<&str> for ChatCompletionRequestDeveloperMessageContent { + fn from(value: &str) -> Self { + ChatCompletionRequestDeveloperMessageContent::Text(value.into()) + } +} + +impl From for ChatCompletionRequestDeveloperMessageContent { + fn from(value: String) -> Self { + ChatCompletionRequestDeveloperMessageContent::Text(value) + } +} + +impl From<&str> for ChatCompletionRequestAssistantMessageContent { + fn from(value: &str) -> Self { + ChatCompletionRequestAssistantMessageContent::Text(value.into()) + } +} + +impl From for ChatCompletionRequestAssistantMessageContent { + fn from(value: String) -> Self { + ChatCompletionRequestAssistantMessageContent::Text(value) + } +} + 
+impl From<&str> for ChatCompletionRequestToolMessageContent { + fn from(value: &str) -> Self { + ChatCompletionRequestToolMessageContent::Text(value.into()) + } +} + +impl From for ChatCompletionRequestToolMessageContent { + fn from(value: String) -> Self { + ChatCompletionRequestToolMessageContent::Text(value) + } +} + +impl From<&str> for ChatCompletionRequestUserMessage { + fn from(value: &str) -> Self { + ChatCompletionRequestUserMessageContent::Text(value.into()).into() + } +} + +impl From for ChatCompletionRequestUserMessage { + fn from(value: String) -> Self { + value.as_str().into() + } +} + +impl From<&str> for ChatCompletionRequestSystemMessage { + fn from(value: &str) -> Self { + ChatCompletionRequestSystemMessageContent::Text(value.into()).into() + } +} + +impl From<&str> for ChatCompletionRequestDeveloperMessage { + fn from(value: &str) -> Self { + ChatCompletionRequestDeveloperMessageContent::Text(value.into()).into() + } +} + +impl From for ChatCompletionRequestSystemMessage { + fn from(value: String) -> Self { + value.as_str().into() + } +} + +impl From for ChatCompletionRequestDeveloperMessage { + fn from(value: String) -> Self { + value.as_str().into() + } +} + +impl From<&str> for ChatCompletionRequestAssistantMessage { + fn from(value: &str) -> Self { + ChatCompletionRequestAssistantMessageContent::Text(value.into()).into() + } +} + +impl From for ChatCompletionRequestAssistantMessage { + fn from(value: String) -> Self { + value.as_str().into() + } +} + +impl From> + for ChatCompletionRequestUserMessageContent +{ + fn from(value: Vec) -> Self { + ChatCompletionRequestUserMessageContent::Array(value) + } +} + +impl From + for ChatCompletionRequestUserMessageContentPart +{ + fn from(value: ChatCompletionRequestMessageContentPartText) -> Self { + ChatCompletionRequestUserMessageContentPart::Text(value) + } +} + +impl From + for ChatCompletionRequestUserMessageContentPart +{ + fn from(value: ChatCompletionRequestMessageContentPartImage) -> Self { + 
ChatCompletionRequestUserMessageContentPart::ImageUrl(value) + } +} + +impl From + for ChatCompletionRequestUserMessageContentPart +{ + fn from(value: ChatCompletionRequestMessageContentPartAudio) -> Self { + ChatCompletionRequestUserMessageContentPart::InputAudio(value) + } +} + +impl From<&str> for ChatCompletionRequestMessageContentPartText { + fn from(value: &str) -> Self { + ChatCompletionRequestMessageContentPartText { text: value.into() } + } +} + +impl From for ChatCompletionRequestMessageContentPartText { + fn from(value: String) -> Self { + ChatCompletionRequestMessageContentPartText { text: value } + } +} + +impl From<&str> for ChatCompletionFunctionCall { + fn from(value: &str) -> Self { + match value { + "auto" => Self::Auto, + "none" => Self::None, + _ => Self::Function { name: value.into() }, + } + } +} + +impl From<&str> for FunctionName { + fn from(value: &str) -> Self { + Self { name: value.into() } + } +} + +impl From for FunctionName { + fn from(value: String) -> Self { + Self { name: value } + } +} + +impl From<&str> for ChatCompletionNamedToolChoice { + fn from(value: &str) -> Self { + Self { + function: value.into(), + } + } +} + +impl From for ChatCompletionNamedToolChoice { + fn from(value: String) -> Self { + Self { + function: value.into(), + } + } +} + +impl Default for ChatCompletionRequestDeveloperMessageContent { + fn default() -> Self { + ChatCompletionRequestDeveloperMessageContent::Text("".into()) + } +} + +impl Default for ChatCompletionRequestSystemMessageContent { + fn default() -> Self { + ChatCompletionRequestSystemMessageContent::Text("".into()) + } +} + +impl Default for ChatCompletionRequestToolMessageContent { + fn default() -> Self { + ChatCompletionRequestToolMessageContent::Text("".into()) + } +} + +impl Default for ChatCompletionRequestUserMessageContent { + fn default() -> Self { + ChatCompletionRequestUserMessageContent::Text("".into()) + } +} + +impl From<&str> for ImageUrl { + fn from(value: &str) -> Self { + Self 
{ + url: value.into(), + detail: Default::default(), + } + } +} + +impl From for ImageUrl { + fn from(value: String) -> Self { + Self { + url: value, + detail: Default::default(), + } + } +} + +impl Display for Role { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Role::User => "user", + Role::System => "system", + Role::Assistant => "assistant", + Role::Function => "function", + Role::Tool => "tool", + } + ) + } +} diff --git a/async-openai/src/types/chat/mod.rs b/async-openai/src/types/chat/mod.rs index f3f50df4..74a5e910 100644 --- a/async-openai/src/types/chat/mod.rs +++ b/async-openai/src/types/chat/mod.rs @@ -1,5 +1,6 @@ mod api; mod chat_types; +mod impls; pub use api::*; pub use chat_types::*; diff --git a/async-openai/src/types/containers/form.rs b/async-openai/src/types/containers/form.rs new file mode 100644 index 00000000..1a0c99f6 --- /dev/null +++ b/async-openai/src/types/containers/form.rs @@ -0,0 +1,22 @@ +use crate::{ + error::OpenAIError, traits::AsyncTryFrom, types::containers::CreateContainerFileRequest, + util::create_file_part, +}; + +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from(request: CreateContainerFileRequest) -> Result { + let mut form = reqwest::multipart::Form::new(); + + // Either file or file_id should be provided + if let Some(file_source) = request.file { + let file_part = create_file_part(file_source).await?; + form = form.part("file", file_part); + } else if let Some(file_id) = request.file_id { + form = form.text("file_id", file_id); + } + + Ok(form) + } +} diff --git a/async-openai/src/types/containers/mod.rs b/async-openai/src/types/containers/mod.rs index b7ab5455..50ff2afb 100644 --- a/async-openai/src/types/containers/mod.rs +++ b/async-openai/src/types/containers/mod.rs @@ -1,5 +1,6 @@ mod api; mod container; +mod form; pub use api::*; pub use container::*; diff --git a/async-openai/src/types/files/form.rs 
b/async-openai/src/types/files/form.rs new file mode 100644 index 00000000..b80428d7 --- /dev/null +++ b/async-openai/src/types/files/form.rs @@ -0,0 +1,22 @@ +use crate::{ + error::OpenAIError, traits::AsyncTryFrom, types::files::CreateFileRequest, + util::create_file_part, +}; + +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from(request: CreateFileRequest) -> Result { + let file_part = create_file_part(request.file.source).await?; + let mut form = reqwest::multipart::Form::new() + .part("file", file_part) + .text("purpose", request.purpose.to_string()); + + if let Some(expires_after) = request.expires_after { + form = form + .text("expires_after[anchor]", expires_after.anchor.to_string()) + .text("expires_after[seconds]", expires_after.seconds.to_string()); + } + Ok(form) + } +} diff --git a/async-openai/src/types/files/impls.rs b/async-openai/src/types/files/impls.rs new file mode 100644 index 00000000..06eb774b --- /dev/null +++ b/async-openai/src/types/files/impls.rs @@ -0,0 +1,32 @@ +use std::fmt::Display; + +use crate::types::files::{FileExpirationAfterAnchor, FilePurpose}; + +impl Display for FilePurpose { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::Assistants => "assistants", + Self::Batch => "batch", + Self::FineTune => "fine-tune", + Self::Vision => "vision", + Self::UserData => "user_data", + Self::Evals => "evals", + } + ) + } +} + +impl Display for FileExpirationAfterAnchor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::CreatedAt => "created_at", + } + ) + } +} diff --git a/async-openai/src/types/files/mod.rs b/async-openai/src/types/files/mod.rs index add19077..f2e99112 100644 --- a/async-openai/src/types/files/mod.rs +++ b/async-openai/src/types/files/mod.rs @@ -1,5 +1,7 @@ mod api; mod file; +mod form; +mod impls; pub use api::*; pub use file::*; diff --git 
a/async-openai/src/types/images/form.rs b/async-openai/src/types/images/form.rs new file mode 100644 index 00000000..5c885d71 --- /dev/null +++ b/async-openai/src/types/images/form.rs @@ -0,0 +1,120 @@ +use crate::{ + error::OpenAIError, + traits::AsyncTryFrom, + types::images::{CreateImageEditRequest, CreateImageVariationRequest, ImageEditInput}, + util::create_file_part, +}; + +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from(request: CreateImageEditRequest) -> Result { + let mut form = reqwest::multipart::Form::new().text("prompt", request.prompt); + + match request.image { + ImageEditInput::Image(image) => { + let image_part = create_file_part(image.source).await?; + form = form.part("image", image_part); + } + ImageEditInput::Images(images) => { + for image in images { + let image_part = create_file_part(image.source).await?; + form = form.part("image[]", image_part); + } + } + } + + if let Some(mask) = request.mask { + let mask_part = create_file_part(mask.source).await?; + form = form.part("mask", mask_part); + } + + if let Some(background) = request.background { + form = form.text("background", background.to_string()) + } + + if let Some(model) = request.model { + form = form.text("model", model.to_string()) + } + + if let Some(n) = request.n { + form = form.text("n", n.to_string()) + } + + if let Some(size) = request.size { + form = form.text("size", size.to_string()) + } + + if let Some(response_format) = request.response_format { + form = form.text("response_format", response_format.to_string()) + } + + if let Some(output_format) = request.output_format { + form = form.text("output_format", output_format.to_string()) + } + + if let Some(output_compression) = request.output_compression { + form = form.text("output_compression", output_compression.to_string()) + } + + if let Some(output_compression) = request.output_compression { + form = form.text("output_compression", output_compression.to_string()) + } + 
+ if let Some(user) = request.user { + form = form.text("user", user) + } + + if let Some(input_fidelity) = request.input_fidelity { + form = form.text("input_fidelity", input_fidelity.to_string()) + } + + if let Some(stream) = request.stream { + form = form.text("stream", stream.to_string()) + } + + if let Some(partial_images) = request.partial_images { + form = form.text("partial_images", partial_images.to_string()) + } + + if let Some(quality) = request.quality { + form = form.text("quality", quality.to_string()) + } + + Ok(form) + } +} + +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from(request: CreateImageVariationRequest) -> Result { + let image_part = create_file_part(request.image.source).await?; + + let mut form = reqwest::multipart::Form::new().part("image", image_part); + + if let Some(model) = request.model { + form = form.text("model", model.to_string()) + } + + if request.n.is_some() { + form = form.text("n", request.n.unwrap().to_string()) + } + + if request.size.is_some() { + form = form.text("size", request.size.unwrap().to_string()) + } + + if request.response_format.is_some() { + form = form.text( + "response_format", + request.response_format.unwrap().to_string(), + ) + } + + if request.user.is_some() { + form = form.text("user", request.user.unwrap()) + } + Ok(form) + } +} diff --git a/async-openai/src/types/images/impls.rs b/async-openai/src/types/images/impls.rs new file mode 100644 index 00000000..e7ff0014 --- /dev/null +++ b/async-openai/src/types/images/impls.rs @@ -0,0 +1,222 @@ +use std::{ + fmt::Display, + path::{Path, PathBuf}, +}; + +use crate::types::images::{ + DallE2ImageSize, ImageBackground, ImageEditInput, ImageInput, ImageModel, ImageOutputFormat, + ImageQuality, ImageResponseFormat, ImageSize, InputFidelity, +}; + +impl Display for ImageSize { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::S256x256 => 
"256x256", + Self::S512x512 => "512x512", + Self::S1024x1024 => "1024x1024", + Self::S1792x1024 => "1792x1024", + Self::S1024x1792 => "1024x1792", + Self::S1536x1024 => "1536x1024", + Self::S1024x1536 => "1024x1536", + Self::Auto => "auto", + } + ) + } +} + +impl Display for DallE2ImageSize { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::S256x256 => "256x256", + Self::S512x512 => "512x512", + Self::S1024x1024 => "1024x1024", + } + ) + } +} + +impl Display for ImageModel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::DallE2 => "dall-e-2", + Self::DallE3 => "dall-e-3", + Self::GptImage1 => "gpt-image-1", + Self::GptImage1Mini => "gpt-image-1-mini", + Self::Other(other) => other, + } + ) + } +} + +impl Display for ImageBackground { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::Transparent => "transparent", + Self::Opaque => "opaque", + Self::Auto => "auto", + } + ) + } +} + +impl Display for ImageOutputFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::Png => "png", + Self::Jpeg => "jpeg", + Self::Webp => "webp", + } + ) + } +} + +impl Display for InputFidelity { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::High => "high", + Self::Low => "low", + } + ) + } +} + +impl Display for ImageQuality { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::Low => "low", + Self::Medium => "medium", + Self::High => "high", + Self::Auto => "auto", + Self::Standard => "standard", + Self::HD => "hd", + } + ) + } +} + +impl Display for ImageResponseFormat { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + 
Self::Url => "url", + Self::B64Json => "b64_json", + } + ) + } +} + +impl Default for ImageEditInput { + fn default() -> Self { + Self::Image(ImageInput::default()) + } +} + +impl From for ImageEditInput { + fn from(value: ImageInput) -> Self { + Self::Image(value) + } +} + +impl From> for ImageEditInput { + fn from(value: Vec) -> Self { + Self::Images(value) + } +} + +// Single path-like values +impl From<&str> for ImageEditInput { + fn from(value: &str) -> Self { + Self::Image(value.into()) + } +} + +impl From for ImageEditInput { + fn from(value: String) -> Self { + Self::Image(value.into()) + } +} + +impl From<&Path> for ImageEditInput { + fn from(value: &Path) -> Self { + Self::Image(value.into()) + } +} + +impl From for ImageEditInput { + fn from(value: PathBuf) -> Self { + Self::Image(value.into()) + } +} + +// Arrays of path-like values +impl From<[&str; N]> for ImageEditInput { + fn from(value: [&str; N]) -> Self { + Self::Images(value.into_iter().map(ImageInput::from).collect()) + } +} + +impl From<[String; N]> for ImageEditInput { + fn from(value: [String; N]) -> Self { + Self::Images(value.into_iter().map(ImageInput::from).collect()) + } +} + +impl From<[&Path; N]> for ImageEditInput { + fn from(value: [&Path; N]) -> Self { + Self::Images(value.into_iter().map(ImageInput::from).collect()) + } +} + +impl From<[PathBuf; N]> for ImageEditInput { + fn from(value: [PathBuf; N]) -> Self { + Self::Images(value.into_iter().map(ImageInput::from).collect()) + } +} + +// Vectors of path-like values +impl<'a> From> for ImageEditInput { + fn from(value: Vec<&'a str>) -> Self { + Self::Images(value.into_iter().map(ImageInput::from).collect()) + } +} + +impl From> for ImageEditInput { + fn from(value: Vec) -> Self { + Self::Images(value.into_iter().map(ImageInput::from).collect()) + } +} + +impl From> for ImageEditInput { + fn from(value: Vec<&Path>) -> Self { + Self::Images(value.into_iter().map(ImageInput::from).collect()) + } +} + +impl From> for ImageEditInput { + 
fn from(value: Vec) -> Self { + Self::Images(value.into_iter().map(ImageInput::from).collect()) + } +} diff --git a/async-openai/src/types/images/mod.rs b/async-openai/src/types/images/mod.rs index 13a6e0e1..1610ce8f 100644 --- a/async-openai/src/types/images/mod.rs +++ b/async-openai/src/types/images/mod.rs @@ -1,4 +1,7 @@ +mod form; mod image; +mod impls; +mod sdk; mod stream; pub use image::*; diff --git a/async-openai/src/types/images/sdk.rs b/async-openai/src/types/images/sdk.rs new file mode 100644 index 00000000..c87c53b9 --- /dev/null +++ b/async-openai/src/types/images/sdk.rs @@ -0,0 +1,56 @@ +use crate::{ + download::{download_url, save_b64}, + error::OpenAIError, + types::images::{Image, ImagesResponse}, + util::create_all_dir, +}; +use std::path::{Path, PathBuf}; + +impl ImagesResponse { + /// Save each image in a dedicated Tokio task and return paths to saved files. + /// For [ResponseFormat::Url] each file is downloaded in dedicated Tokio task. + pub async fn save>(&self, dir: P) -> Result, OpenAIError> { + create_all_dir(dir.as_ref())?; + + let mut handles = vec![]; + for id in self.data.clone() { + let dir_buf = PathBuf::from(dir.as_ref()); + handles.push(tokio::spawn(async move { id.save(dir_buf).await })); + } + + let results = futures::future::join_all(handles).await; + let mut errors = vec![]; + let mut paths = vec![]; + + for result in results { + match result { + Ok(inner) => match inner { + Ok(path) => paths.push(path), + Err(e) => errors.push(e), + }, + Err(e) => errors.push(OpenAIError::FileSaveError(e.to_string())), + } + } + + if errors.is_empty() { + Ok(paths) + } else { + Err(OpenAIError::FileSaveError( + errors + .into_iter() + .map(|e| e.to_string()) + .collect::>() + .join("; "), + )) + } + } +} + +impl Image { + async fn save>(&self, dir: P) -> Result { + match self { + Image::Url { url, .. } => download_url(url, dir).await, + Image::B64Json { b64_json, .. 
} => save_b64(b64_json, dir).await, + } + } +} diff --git a/async-openai/src/types/impls.rs b/async-openai/src/types/impls.rs index 868e227a..7de6e902 100644 --- a/async-openai/src/types/impls.rs +++ b/async-openai/src/types/impls.rs @@ -1,49 +1,13 @@ -use std::{ - fmt::Display, - path::{Path, PathBuf}, -}; - -use crate::{ - download::{download_url, save_b64}, - error::OpenAIError, - traits::AsyncTryFrom, - types::{ - assistants::CreateMessageRequestContent, - audio::{ - AudioInput, AudioResponseFormat, CreateSpeechResponse, CreateTranscriptionRequest, - CreateTranslationRequest, TimestampGranularity, TranscriptionInclude, - }, - audio::{TranscriptionChunkingStrategy, TranslationResponseFormat}, - chat::{ - ChatCompletionFunctionCall, ChatCompletionFunctions, ChatCompletionNamedToolChoice, - }, - chat::{ - ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageContent, - ChatCompletionRequestDeveloperMessage, ChatCompletionRequestDeveloperMessageContent, - ChatCompletionRequestFunctionMessage, ChatCompletionRequestMessage, - ChatCompletionRequestMessageContentPartAudio, - ChatCompletionRequestMessageContentPartImage, - ChatCompletionRequestMessageContentPartText, ChatCompletionRequestSystemMessage, - ChatCompletionRequestSystemMessageContent, ChatCompletionRequestToolMessage, - ChatCompletionRequestToolMessageContent, ChatCompletionRequestUserMessage, - ChatCompletionRequestUserMessageContent, ChatCompletionRequestUserMessageContentPart, - FunctionName, ImageUrl, Prompt, Role, StopConfiguration, - }, - containers::CreateContainerFileRequest, - embeddings::EmbeddingInput, - files::{CreateFileRequest, FileExpirationAfterAnchor, FileInput, FilePurpose}, - images::{ - CreateImageEditRequest, CreateImageVariationRequest, DallE2ImageSize, Image, - ImageInput, ImageModel, ImageResponseFormat, ImageSize, ImagesResponse, - }, - images::{ImageBackground, ImageEditInput, ImageOutputFormat, ImageQuality, InputFidelity}, - moderations::ModerationInput, - 
responses::EasyInputContent, - uploads::AddUploadPartRequest, - videos::{CreateVideoRequest, VideoSize}, - InputSource, - }, - util::{create_all_dir, create_file_part}, +use std::path::{Path, PathBuf}; + +use crate::types::{ + audio::AudioInput, + chat::{Prompt, StopConfiguration}, + embeddings::EmbeddingInput, + files::FileInput, + images::ImageInput, + moderations::ModerationInput, + InputSource, }; use bytes::Bytes; @@ -177,402 +141,6 @@ impl_input!(AudioInput); impl_input!(FileInput); impl_input!(ImageInput); -impl Default for ImageEditInput { - fn default() -> Self { - Self::Image(ImageInput::default()) - } -} - -impl From for ImageEditInput { - fn from(value: ImageInput) -> Self { - Self::Image(value) - } -} - -impl From> for ImageEditInput { - fn from(value: Vec) -> Self { - Self::Images(value) - } -} - -// Single path-like values -impl From<&str> for ImageEditInput { - fn from(value: &str) -> Self { - Self::Image(value.into()) - } -} - -impl From for ImageEditInput { - fn from(value: String) -> Self { - Self::Image(value.into()) - } -} - -impl From<&Path> for ImageEditInput { - fn from(value: &Path) -> Self { - Self::Image(value.into()) - } -} - -impl From for ImageEditInput { - fn from(value: PathBuf) -> Self { - Self::Image(value.into()) - } -} - -// Arrays of path-like values -impl From<[&str; N]> for ImageEditInput { - fn from(value: [&str; N]) -> Self { - Self::Images(value.into_iter().map(ImageInput::from).collect()) - } -} - -impl From<[String; N]> for ImageEditInput { - fn from(value: [String; N]) -> Self { - Self::Images(value.into_iter().map(ImageInput::from).collect()) - } -} - -impl From<[&Path; N]> for ImageEditInput { - fn from(value: [&Path; N]) -> Self { - Self::Images(value.into_iter().map(ImageInput::from).collect()) - } -} - -impl From<[PathBuf; N]> for ImageEditInput { - fn from(value: [PathBuf; N]) -> Self { - Self::Images(value.into_iter().map(ImageInput::from).collect()) - } -} - -// Vectors of path-like values -impl<'a> From> for 
ImageEditInput { - fn from(value: Vec<&'a str>) -> Self { - Self::Images(value.into_iter().map(ImageInput::from).collect()) - } -} - -impl From> for ImageEditInput { - fn from(value: Vec) -> Self { - Self::Images(value.into_iter().map(ImageInput::from).collect()) - } -} - -impl From> for ImageEditInput { - fn from(value: Vec<&Path>) -> Self { - Self::Images(value.into_iter().map(ImageInput::from).collect()) - } -} - -impl From> for ImageEditInput { - fn from(value: Vec) -> Self { - Self::Images(value.into_iter().map(ImageInput::from).collect()) - } -} - -impl Display for VideoSize { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Self::S720x1280 => "720x1280", - Self::S1280x720 => "1280x720", - Self::S1024x1792 => "1024x1792", - Self::S1792x1024 => "1792x1024", - } - ) - } -} - -impl Display for ImageSize { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Self::S256x256 => "256x256", - Self::S512x512 => "512x512", - Self::S1024x1024 => "1024x1024", - Self::S1792x1024 => "1792x1024", - Self::S1024x1792 => "1024x1792", - Self::S1536x1024 => "1536x1024", - Self::S1024x1536 => "1024x1536", - Self::Auto => "auto", - } - ) - } -} - -impl Display for DallE2ImageSize { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Self::S256x256 => "256x256", - Self::S512x512 => "512x512", - Self::S1024x1024 => "1024x1024", - } - ) - } -} - -impl Display for ImageModel { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Self::DallE2 => "dall-e-2", - Self::DallE3 => "dall-e-3", - Self::GptImage1 => "gpt-image-1", - Self::GptImage1Mini => "gpt-image-1-mini", - Self::Other(other) => other, - } - ) - } -} - -impl Display for ImageBackground { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self 
{ - Self::Transparent => "transparent", - Self::Opaque => "opaque", - Self::Auto => "auto", - } - ) - } -} - -impl Display for ImageOutputFormat { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Self::Png => "png", - Self::Jpeg => "jpeg", - Self::Webp => "webp", - } - ) - } -} - -impl Display for InputFidelity { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Self::High => "high", - Self::Low => "low", - } - ) - } -} - -impl Display for ImageQuality { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Self::Low => "low", - Self::Medium => "medium", - Self::High => "high", - Self::Auto => "auto", - Self::Standard => "standard", - Self::HD => "hd", - } - ) - } -} - -impl Display for ImageResponseFormat { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Self::Url => "url", - Self::B64Json => "b64_json", - } - ) - } -} - -impl Display for AudioResponseFormat { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - AudioResponseFormat::Json => "json", - AudioResponseFormat::Srt => "srt", - AudioResponseFormat::Text => "text", - AudioResponseFormat::VerboseJson => "verbose_json", - AudioResponseFormat::Vtt => "vtt", - AudioResponseFormat::DiarizedJson => "diarized_json", - } - ) - } -} - -impl Display for TranslationResponseFormat { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - TranslationResponseFormat::Json => "json", - TranslationResponseFormat::Srt => "srt", - TranslationResponseFormat::Text => "text", - TranslationResponseFormat::VerboseJson => "verbose_json", - TranslationResponseFormat::Vtt => "vtt", - } - ) - } -} - -impl Display for TimestampGranularity { - fn fmt(&self, f: &mut 
std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - TimestampGranularity::Word => "word", - TimestampGranularity::Segment => "segment", - } - ) - } -} - -impl Display for TranscriptionInclude { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - TranscriptionInclude::Logprobs => "logprobs", - } - ) - } -} - -impl Display for Role { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Role::User => "user", - Role::System => "system", - Role::Assistant => "assistant", - Role::Function => "function", - Role::Tool => "tool", - } - ) - } -} - -impl Display for FilePurpose { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Self::Assistants => "assistants", - Self::Batch => "batch", - Self::FineTune => "fine-tune", - Self::Vision => "vision", - Self::UserData => "user_data", - Self::Evals => "evals", - } - ) - } -} - -impl Display for FileExpirationAfterAnchor { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}", - match self { - Self::CreatedAt => "created_at", - } - ) - } -} - -impl ImagesResponse { - /// Save each image in a dedicated Tokio task and return paths to saved files. - /// For [ResponseFormat::Url] each file is downloaded in dedicated Tokio task. 
- pub async fn save>(&self, dir: P) -> Result, OpenAIError> { - create_all_dir(dir.as_ref())?; - - let mut handles = vec![]; - for id in self.data.clone() { - let dir_buf = PathBuf::from(dir.as_ref()); - handles.push(tokio::spawn(async move { id.save(dir_buf).await })); - } - - let results = futures::future::join_all(handles).await; - let mut errors = vec![]; - let mut paths = vec![]; - - for result in results { - match result { - Ok(inner) => match inner { - Ok(path) => paths.push(path), - Err(e) => errors.push(e), - }, - Err(e) => errors.push(OpenAIError::FileSaveError(e.to_string())), - } - } - - if errors.is_empty() { - Ok(paths) - } else { - Err(OpenAIError::FileSaveError( - errors - .into_iter() - .map(|e| e.to_string()) - .collect::>() - .join("; "), - )) - } - } -} - -impl CreateSpeechResponse { - pub async fn save>(&self, file_path: P) -> Result<(), OpenAIError> { - let dir = file_path.as_ref().parent(); - - if let Some(dir) = dir { - create_all_dir(dir)?; - } - - tokio::fs::write(file_path, &self.bytes) - .await - .map_err(|e| OpenAIError::FileSaveError(e.to_string()))?; - - Ok(()) - } -} - -impl Image { - async fn save>(&self, dir: P) -> Result { - match self { - Image::Url { url, .. } => download_url(url, dir).await, - Image::B64Json { b64_json, .. } => save_b64(b64_json, dir).await, - } - } -} - macro_rules! impl_from_for_integer_array { ($from_typ:ty, $to_typ:ty) => { impl From<[$from_typ; N]> for $to_typ { @@ -700,790 +268,3 @@ macro_rules! 
impl_from_for_array_of_integer_array { impl_from_for_array_of_integer_array!(u32, EmbeddingInput); impl_from_for_array_of_integer_array!(u32, Prompt); - -impl From<&str> for ChatCompletionFunctionCall { - fn from(value: &str) -> Self { - match value { - "auto" => Self::Auto, - "none" => Self::None, - _ => Self::Function { name: value.into() }, - } - } -} - -impl From<&str> for FunctionName { - fn from(value: &str) -> Self { - Self { name: value.into() } - } -} - -impl From for FunctionName { - fn from(value: String) -> Self { - Self { name: value } - } -} - -impl From<&str> for ChatCompletionNamedToolChoice { - fn from(value: &str) -> Self { - Self { - function: value.into(), - } - } -} - -impl From for ChatCompletionNamedToolChoice { - fn from(value: String) -> Self { - Self { - function: value.into(), - } - } -} - -impl From<(String, serde_json::Value)> for ChatCompletionFunctions { - fn from(value: (String, serde_json::Value)) -> Self { - Self { - name: value.0, - description: None, - parameters: value.1, - } - } -} - -// todo: write macro for bunch of same looking From trait implementations below - -impl From for ChatCompletionRequestMessage { - fn from(value: ChatCompletionRequestUserMessage) -> Self { - Self::User(value) - } -} - -impl From for ChatCompletionRequestMessage { - fn from(value: ChatCompletionRequestSystemMessage) -> Self { - Self::System(value) - } -} - -impl From for ChatCompletionRequestMessage { - fn from(value: ChatCompletionRequestDeveloperMessage) -> Self { - Self::Developer(value) - } -} - -impl From for ChatCompletionRequestMessage { - fn from(value: ChatCompletionRequestAssistantMessage) -> Self { - Self::Assistant(value) - } -} - -impl From for ChatCompletionRequestMessage { - fn from(value: ChatCompletionRequestFunctionMessage) -> Self { - Self::Function(value) - } -} - -impl From for ChatCompletionRequestMessage { - fn from(value: ChatCompletionRequestToolMessage) -> Self { - Self::Tool(value) - } -} - -impl From for 
ChatCompletionRequestUserMessage { - fn from(value: ChatCompletionRequestUserMessageContent) -> Self { - Self { - content: value, - name: None, - } - } -} - -impl From for ChatCompletionRequestSystemMessage { - fn from(value: ChatCompletionRequestSystemMessageContent) -> Self { - Self { - content: value, - name: None, - } - } -} - -impl From for ChatCompletionRequestDeveloperMessage { - fn from(value: ChatCompletionRequestDeveloperMessageContent) -> Self { - Self { - content: value, - name: None, - } - } -} - -impl From for ChatCompletionRequestAssistantMessage { - fn from(value: ChatCompletionRequestAssistantMessageContent) -> Self { - Self { - content: Some(value), - ..Default::default() - } - } -} - -impl From<&str> for ChatCompletionRequestUserMessageContent { - fn from(value: &str) -> Self { - ChatCompletionRequestUserMessageContent::Text(value.into()) - } -} - -impl From for ChatCompletionRequestUserMessageContent { - fn from(value: String) -> Self { - ChatCompletionRequestUserMessageContent::Text(value) - } -} - -impl From<&str> for ChatCompletionRequestSystemMessageContent { - fn from(value: &str) -> Self { - ChatCompletionRequestSystemMessageContent::Text(value.into()) - } -} - -impl From for ChatCompletionRequestSystemMessageContent { - fn from(value: String) -> Self { - ChatCompletionRequestSystemMessageContent::Text(value) - } -} - -impl From<&str> for ChatCompletionRequestDeveloperMessageContent { - fn from(value: &str) -> Self { - ChatCompletionRequestDeveloperMessageContent::Text(value.into()) - } -} - -impl From for ChatCompletionRequestDeveloperMessageContent { - fn from(value: String) -> Self { - ChatCompletionRequestDeveloperMessageContent::Text(value) - } -} - -impl From<&str> for ChatCompletionRequestAssistantMessageContent { - fn from(value: &str) -> Self { - ChatCompletionRequestAssistantMessageContent::Text(value.into()) - } -} - -impl From for ChatCompletionRequestAssistantMessageContent { - fn from(value: String) -> Self { - 
ChatCompletionRequestAssistantMessageContent::Text(value) - } -} - -impl From<&str> for ChatCompletionRequestToolMessageContent { - fn from(value: &str) -> Self { - ChatCompletionRequestToolMessageContent::Text(value.into()) - } -} - -impl From for ChatCompletionRequestToolMessageContent { - fn from(value: String) -> Self { - ChatCompletionRequestToolMessageContent::Text(value) - } -} - -impl From<&str> for ChatCompletionRequestUserMessage { - fn from(value: &str) -> Self { - ChatCompletionRequestUserMessageContent::Text(value.into()).into() - } -} - -impl From for ChatCompletionRequestUserMessage { - fn from(value: String) -> Self { - value.as_str().into() - } -} - -impl From<&str> for ChatCompletionRequestSystemMessage { - fn from(value: &str) -> Self { - ChatCompletionRequestSystemMessageContent::Text(value.into()).into() - } -} - -impl From<&str> for ChatCompletionRequestDeveloperMessage { - fn from(value: &str) -> Self { - ChatCompletionRequestDeveloperMessageContent::Text(value.into()).into() - } -} - -impl From for ChatCompletionRequestSystemMessage { - fn from(value: String) -> Self { - value.as_str().into() - } -} - -impl From for ChatCompletionRequestDeveloperMessage { - fn from(value: String) -> Self { - value.as_str().into() - } -} - -impl From<&str> for ChatCompletionRequestAssistantMessage { - fn from(value: &str) -> Self { - ChatCompletionRequestAssistantMessageContent::Text(value.into()).into() - } -} - -impl From for ChatCompletionRequestAssistantMessage { - fn from(value: String) -> Self { - value.as_str().into() - } -} - -impl From> - for ChatCompletionRequestUserMessageContent -{ - fn from(value: Vec) -> Self { - ChatCompletionRequestUserMessageContent::Array(value) - } -} - -impl From - for ChatCompletionRequestUserMessageContentPart -{ - fn from(value: ChatCompletionRequestMessageContentPartText) -> Self { - ChatCompletionRequestUserMessageContentPart::Text(value) - } -} - -impl From - for ChatCompletionRequestUserMessageContentPart -{ - fn 
from(value: ChatCompletionRequestMessageContentPartImage) -> Self { - ChatCompletionRequestUserMessageContentPart::ImageUrl(value) - } -} - -impl From - for ChatCompletionRequestUserMessageContentPart -{ - fn from(value: ChatCompletionRequestMessageContentPartAudio) -> Self { - ChatCompletionRequestUserMessageContentPart::InputAudio(value) - } -} - -impl From<&str> for ChatCompletionRequestMessageContentPartText { - fn from(value: &str) -> Self { - ChatCompletionRequestMessageContentPartText { text: value.into() } - } -} - -impl From for ChatCompletionRequestMessageContentPartText { - fn from(value: String) -> Self { - ChatCompletionRequestMessageContentPartText { text: value } - } -} - -impl From<&str> for ImageUrl { - fn from(value: &str) -> Self { - Self { - url: value.into(), - detail: Default::default(), - } - } -} - -impl From for ImageUrl { - fn from(value: String) -> Self { - Self { - url: value, - detail: Default::default(), - } - } -} - -impl From for CreateMessageRequestContent { - fn from(value: String) -> Self { - Self::Content(value) - } -} - -impl From<&str> for CreateMessageRequestContent { - fn from(value: &str) -> Self { - Self::Content(value.to_string()) - } -} - -impl Default for ChatCompletionRequestUserMessageContent { - fn default() -> Self { - ChatCompletionRequestUserMessageContent::Text("".into()) - } -} - -impl Default for CreateMessageRequestContent { - fn default() -> Self { - Self::Content("".into()) - } -} - -impl Default for ChatCompletionRequestDeveloperMessageContent { - fn default() -> Self { - ChatCompletionRequestDeveloperMessageContent::Text("".into()) - } -} - -impl Default for ChatCompletionRequestSystemMessageContent { - fn default() -> Self { - ChatCompletionRequestSystemMessageContent::Text("".into()) - } -} - -impl Default for ChatCompletionRequestToolMessageContent { - fn default() -> Self { - ChatCompletionRequestToolMessageContent::Text("".into()) - } -} - -// start: types to multipart from - -impl AsyncTryFrom for 
reqwest::multipart::Form { - type Error = OpenAIError; - - async fn try_from(request: CreateTranscriptionRequest) -> Result { - let audio_part = create_file_part(request.file.source).await?; - - let mut form = reqwest::multipart::Form::new() - .part("file", audio_part) - .text("model", request.model); - - if let Some(language) = request.language { - form = form.text("language", language); - } - - if let Some(prompt) = request.prompt { - form = form.text("prompt", prompt); - } - - if let Some(response_format) = request.response_format { - form = form.text("response_format", response_format.to_string()) - } - - if let Some(temperature) = request.temperature { - form = form.text("temperature", temperature.to_string()) - } - - if let Some(include) = request.include { - for inc in include { - form = form.text("include[]", inc.to_string()); - } - } - - if let Some(timestamp_granularities) = request.timestamp_granularities { - for tg in timestamp_granularities { - form = form.text("timestamp_granularities[]", tg.to_string()); - } - } - - if let Some(stream) = request.stream { - form = form.text("stream", stream.to_string()); - } - - if let Some(chunking_strategy) = request.chunking_strategy { - match chunking_strategy { - TranscriptionChunkingStrategy::Auto => { - form = form.text("chunking_strategy", "auto"); - } - TranscriptionChunkingStrategy::ServerVad(vad_config) => { - form = form.text( - "chunking_strategy", - serde_json::to_string(&vad_config).unwrap().to_string(), - ); - } - } - } - - if let Some(known_speaker_names) = request.known_speaker_names { - for kn in known_speaker_names { - form = form.text("known_speaker_names[]", kn.to_string()); - } - } - - if let Some(known_speaker_references) = request.known_speaker_references { - for kn in known_speaker_references { - form = form.text("known_speaker_references[]", kn.to_string()); - } - } - - Ok(form) - } -} - -impl AsyncTryFrom for reqwest::multipart::Form { - type Error = OpenAIError; - - async fn 
try_from(request: CreateTranslationRequest) -> Result { - let audio_part = create_file_part(request.file.source).await?; - - let mut form = reqwest::multipart::Form::new() - .part("file", audio_part) - .text("model", request.model); - - if let Some(prompt) = request.prompt { - form = form.text("prompt", prompt); - } - - if let Some(response_format) = request.response_format { - form = form.text("response_format", response_format.to_string()) - } - - if let Some(temperature) = request.temperature { - form = form.text("temperature", temperature.to_string()) - } - Ok(form) - } -} - -impl AsyncTryFrom for reqwest::multipart::Form { - type Error = OpenAIError; - - async fn try_from(request: CreateImageEditRequest) -> Result { - let mut form = reqwest::multipart::Form::new().text("prompt", request.prompt); - - match request.image { - ImageEditInput::Image(image) => { - let image_part = create_file_part(image.source).await?; - form = form.part("image", image_part); - } - ImageEditInput::Images(images) => { - for image in images { - let image_part = create_file_part(image.source).await?; - form = form.part("image[]", image_part); - } - } - } - - if let Some(mask) = request.mask { - let mask_part = create_file_part(mask.source).await?; - form = form.part("mask", mask_part); - } - - if let Some(background) = request.background { - form = form.text("background", background.to_string()) - } - - if let Some(model) = request.model { - form = form.text("model", model.to_string()) - } - - if let Some(n) = request.n { - form = form.text("n", n.to_string()) - } - - if let Some(size) = request.size { - form = form.text("size", size.to_string()) - } - - if let Some(response_format) = request.response_format { - form = form.text("response_format", response_format.to_string()) - } - - if let Some(output_format) = request.output_format { - form = form.text("output_format", output_format.to_string()) - } - - if let Some(output_compression) = request.output_compression { - form = 
form.text("output_compression", output_compression.to_string()) - } - - if let Some(output_compression) = request.output_compression { - form = form.text("output_compression", output_compression.to_string()) - } - - if let Some(user) = request.user { - form = form.text("user", user) - } - - if let Some(input_fidelity) = request.input_fidelity { - form = form.text("input_fidelity", input_fidelity.to_string()) - } - - if let Some(stream) = request.stream { - form = form.text("stream", stream.to_string()) - } - - if let Some(partial_images) = request.partial_images { - form = form.text("partial_images", partial_images.to_string()) - } - - if let Some(quality) = request.quality { - form = form.text("quality", quality.to_string()) - } - - Ok(form) - } -} - -impl AsyncTryFrom for reqwest::multipart::Form { - type Error = OpenAIError; - - async fn try_from(request: CreateImageVariationRequest) -> Result { - let image_part = create_file_part(request.image.source).await?; - - let mut form = reqwest::multipart::Form::new().part("image", image_part); - - if let Some(model) = request.model { - form = form.text("model", model.to_string()) - } - - if request.n.is_some() { - form = form.text("n", request.n.unwrap().to_string()) - } - - if request.size.is_some() { - form = form.text("size", request.size.unwrap().to_string()) - } - - if request.response_format.is_some() { - form = form.text( - "response_format", - request.response_format.unwrap().to_string(), - ) - } - - if request.user.is_some() { - form = form.text("user", request.user.unwrap()) - } - Ok(form) - } -} - -impl AsyncTryFrom for reqwest::multipart::Form { - type Error = OpenAIError; - - async fn try_from(request: CreateFileRequest) -> Result { - let file_part = create_file_part(request.file.source).await?; - let mut form = reqwest::multipart::Form::new() - .part("file", file_part) - .text("purpose", request.purpose.to_string()); - - if let Some(expires_after) = request.expires_after { - form = form - 
.text("expires_after[anchor]", expires_after.anchor.to_string()) - .text("expires_after[seconds]", expires_after.seconds.to_string()); - } - Ok(form) - } -} - -impl AsyncTryFrom for reqwest::multipart::Form { - type Error = OpenAIError; - - async fn try_from(request: AddUploadPartRequest) -> Result { - let file_part = create_file_part(request.data).await?; - let form = reqwest::multipart::Form::new().part("data", file_part); - Ok(form) - } -} - -impl AsyncTryFrom for reqwest::multipart::Form { - type Error = OpenAIError; - - async fn try_from(request: CreateContainerFileRequest) -> Result { - let mut form = reqwest::multipart::Form::new(); - - // Either file or file_id should be provided - if let Some(file_source) = request.file { - let file_part = create_file_part(file_source).await?; - form = form.part("file", file_part); - } else if let Some(file_id) = request.file_id { - form = form.text("file_id", file_id); - } - - Ok(form) - } -} - -impl AsyncTryFrom for reqwest::multipart::Form { - type Error = OpenAIError; - - async fn try_from(request: CreateVideoRequest) -> Result { - let mut form = reqwest::multipart::Form::new().text("model", request.model); - - form = form.text("prompt", request.prompt); - - if request.size.is_some() { - form = form.text("size", request.size.unwrap().to_string()); - } - - if request.seconds.is_some() { - form = form.text("seconds", request.seconds.unwrap()); - } - - if request.input_reference.is_some() { - let image_part = create_file_part(request.input_reference.unwrap().source).await?; - form = form.part("input_reference", image_part); - } - - Ok(form) - } -} - -#[cfg(feature = "realtime")] -impl AsyncTryFrom for reqwest::multipart::Form { - type Error = OpenAIError; - - async fn try_from( - request: crate::types::realtime::RealtimeCallCreateRequest, - ) -> Result { - use reqwest::multipart::Part; - - // Create SDP part with content type application/sdp - let sdp_part = Part::text(request.sdp) - .mime_str("application/sdp") - 
.map_err(|e| OpenAIError::InvalidArgument(format!("Invalid content type: {}", e)))?; - - let mut form = reqwest::multipart::Form::new().part("sdp", sdp_part); - - // Add session as JSON if present - if let Some(session) = request.session { - let session_json = serde_json::to_string(&session).map_err(|e| { - OpenAIError::InvalidArgument(format!("Failed to serialize session: {}", e)) - })?; - let session_part = Part::text(session_json) - .mime_str("application/json") - .map_err(|e| { - OpenAIError::InvalidArgument(format!("Invalid content type: {}", e)) - })?; - form = form.part("session", session_part); - } - - Ok(form) - } -} - -// end: types to multipart form - -impl Default for EasyInputContent { - fn default() -> Self { - Self::Text("".to_string()) - } -} - -impl From for EasyInputContent { - fn from(value: String) -> Self { - Self::Text(value) - } -} - -impl From<&str> for EasyInputContent { - fn from(value: &str) -> Self { - Self::Text(value.to_owned()) - } -} - -// request builder impls macro - -/// Macro to implement `RequestOptionsBuilder` for wrapper types containing `RequestOptions` -macro_rules! 
impl_request_options_builder { - ($type:ident) => { - impl<'c, C: crate::config::Config> crate::traits::RequestOptionsBuilder for $type<'c, C> { - fn options_mut(&mut self) -> &mut crate::RequestOptions { - &mut self.request_options - } - - fn options(&self) -> &crate::RequestOptions { - &self.request_options - } - } - }; -} - -use crate::{ - admin_api_keys::AdminAPIKeys, - assistants::Assistants, - audio::Audio, - audit_logs::AuditLogs, - batches::Batches, - certificates::Certificates, - chat::Chat, - chatkit::{Chatkit, ChatkitSessions, ChatkitThreads}, - completion::Completions, - container_files::ContainerFiles, - containers::Containers, - conversation_items::ConversationItems, - conversations::Conversations, - embedding::Embeddings, - eval_run_output_items::EvalRunOutputItems, - eval_runs::EvalRuns, - evals::Evals, - file::Files, - fine_tuning::FineTuning, - image::Images, - invites::Invites, - messages::Messages, - model::Models, - moderation::Moderations, - project_api_keys::ProjectAPIKeys, - project_certificates::ProjectCertificates, - project_rate_limits::ProjectRateLimits, - project_service_accounts::ProjectServiceAccounts, - project_users::ProjectUsers, - projects::Projects, - responses::Responses, - runs::Runs, - speech::Speech, - steps::Steps, - threads::Threads, - transcriptions::Transcriptions, - translations::Translations, - uploads::Uploads, - usage::Usage, - users::Users, - vector_store_file_batches::VectorStoreFileBatches, - vector_store_files::VectorStoreFiles, - vector_stores::VectorStores, - video::Videos, -}; - -#[cfg(feature = "realtime")] -use crate::Realtime; - -impl_request_options_builder!(AdminAPIKeys); -impl_request_options_builder!(Assistants); -impl_request_options_builder!(Audio); -impl_request_options_builder!(AuditLogs); -impl_request_options_builder!(Batches); -impl_request_options_builder!(Certificates); -impl_request_options_builder!(Chat); -impl_request_options_builder!(Chatkit); -impl_request_options_builder!(ChatkitSessions); 
-impl_request_options_builder!(ChatkitThreads); -impl_request_options_builder!(Completions); -impl_request_options_builder!(ContainerFiles); -impl_request_options_builder!(Containers); -impl_request_options_builder!(ConversationItems); -impl_request_options_builder!(Conversations); -impl_request_options_builder!(Embeddings); -impl_request_options_builder!(Evals); -impl_request_options_builder!(EvalRunOutputItems); -impl_request_options_builder!(EvalRuns); -impl_request_options_builder!(Files); -impl_request_options_builder!(FineTuning); -impl_request_options_builder!(Images); -impl_request_options_builder!(Invites); -impl_request_options_builder!(Messages); -impl_request_options_builder!(Models); -impl_request_options_builder!(Moderations); -impl_request_options_builder!(Projects); -impl_request_options_builder!(ProjectUsers); -impl_request_options_builder!(ProjectServiceAccounts); -impl_request_options_builder!(ProjectAPIKeys); -impl_request_options_builder!(ProjectRateLimits); -impl_request_options_builder!(ProjectCertificates); -#[cfg(feature = "realtime")] -impl_request_options_builder!(Realtime); -impl_request_options_builder!(Responses); -impl_request_options_builder!(Runs); -impl_request_options_builder!(Speech); -impl_request_options_builder!(Steps); -impl_request_options_builder!(Threads); -impl_request_options_builder!(Transcriptions); -impl_request_options_builder!(Translations); -impl_request_options_builder!(Uploads); -impl_request_options_builder!(Usage); -impl_request_options_builder!(Users); -impl_request_options_builder!(VectorStoreFileBatches); -impl_request_options_builder!(VectorStoreFiles); -impl_request_options_builder!(VectorStores); -impl_request_options_builder!(Videos); diff --git a/async-openai/src/types/mcp.rs b/async-openai/src/types/mcp.rs index fae078a1..deaad094 100644 --- a/async-openai/src/types/mcp.rs +++ b/async-openai/src/types/mcp.rs @@ -135,3 +135,73 @@ pub struct MCPListToolsTool { #[serde(skip_serializing_if = 
"Option::is_none")] pub description: Option, } + +// MCPToolRequireApproval ergonomics + +impl From for MCPToolRequireApproval { + fn from(setting: MCPToolApprovalSetting) -> Self { + MCPToolRequireApproval::ApprovalSetting(setting) + } +} + +impl From for MCPToolRequireApproval { + fn from(filter: MCPToolApprovalFilter) -> Self { + MCPToolRequireApproval::Filter(filter) + } +} + +// MCPToolAllowedTools ergonomics + +impl From for MCPToolAllowedTools { + fn from(filter: MCPToolFilter) -> Self { + MCPToolAllowedTools::Filter(filter) + } +} + +impl From> for MCPToolAllowedTools { + fn from(tools: Vec) -> Self { + MCPToolAllowedTools::List(tools) + } +} + +impl From> for MCPToolAllowedTools { + fn from(tools: Vec<&str>) -> Self { + MCPToolAllowedTools::List(tools.into_iter().map(|s| s.to_string()).collect()) + } +} + +impl From<&[&str]> for MCPToolAllowedTools { + fn from(tools: &[&str]) -> Self { + MCPToolAllowedTools::List(tools.iter().map(|s| s.to_string()).collect()) + } +} + +impl From<[&str; N]> for MCPToolAllowedTools { + fn from(tools: [&str; N]) -> Self { + MCPToolAllowedTools::List(tools.iter().map(|s| s.to_string()).collect()) + } +} + +impl From<&Vec> for MCPToolAllowedTools { + fn from(tools: &Vec) -> Self { + MCPToolAllowedTools::List(tools.clone()) + } +} + +impl From<&Vec<&str>> for MCPToolAllowedTools { + fn from(tools: &Vec<&str>) -> Self { + MCPToolAllowedTools::List(tools.iter().map(|s| s.to_string()).collect()) + } +} + +impl From<&str> for MCPToolAllowedTools { + fn from(tool: &str) -> Self { + MCPToolAllowedTools::List(vec![tool.to_string()]) + } +} + +impl From for MCPToolAllowedTools { + fn from(tool: String) -> Self { + MCPToolAllowedTools::List(vec![tool]) + } +} diff --git a/async-openai/src/types/realtime/client_event.rs b/async-openai/src/types/realtime/client_event.rs index 0881b9b9..b9089f7b 100644 --- a/async-openai/src/types/realtime/client_event.rs +++ b/async-openai/src/types/realtime/client_event.rs @@ -1,6 +1,7 @@ use 
serde::{Deserialize, Serialize}; use tokio_tungstenite::tungstenite::Message; +use crate::traits::EventType; use crate::types::realtime::{RealtimeConversationItem, RealtimeResponseCreateParams, Session}; #[derive(Debug, Serialize, Deserialize, Clone)] @@ -365,3 +366,49 @@ impl From for RealtimeClientEventConversationItemCreat } } } + +// Implement EventType trait for all event types in this file + +macro_rules! impl_event_type { + ($($ty:ty => $event_type:expr),* $(,)?) => { + $( + impl EventType for $ty { + fn event_type(&self) -> &'static str { + $event_type + } + } + )* + }; +} + +impl_event_type! { + RealtimeClientEventSessionUpdate => "session.update", + RealtimeClientEventInputAudioBufferAppend => "input_audio_buffer.append", + RealtimeClientEventInputAudioBufferCommit => "input_audio_buffer.commit", + RealtimeClientEventInputAudioBufferClear => "input_audio_buffer.clear", + RealtimeClientEventConversationItemCreate => "conversation.item.create", + RealtimeClientEventConversationItemRetrieve => "conversation.item.retrieve", + RealtimeClientEventConversationItemTruncate => "conversation.item.truncate", + RealtimeClientEventConversationItemDelete => "conversation.item.delete", + RealtimeClientEventResponseCreate => "response.create", + RealtimeClientEventResponseCancel => "response.cancel", + RealtimeClientEventOutputAudioBufferClear => "output_audio_buffer.clear", +} + +impl EventType for RealtimeClientEvent { + fn event_type(&self) -> &'static str { + match self { + RealtimeClientEvent::SessionUpdate(e) => e.event_type(), + RealtimeClientEvent::InputAudioBufferAppend(e) => e.event_type(), + RealtimeClientEvent::InputAudioBufferCommit(e) => e.event_type(), + RealtimeClientEvent::InputAudioBufferClear(e) => e.event_type(), + RealtimeClientEvent::ConversationItemCreate(e) => e.event_type(), + RealtimeClientEvent::ConversationItemRetrieve(e) => e.event_type(), + RealtimeClientEvent::ConversationItemTruncate(e) => e.event_type(), + 
RealtimeClientEvent::ConversationItemDelete(e) => e.event_type(), + RealtimeClientEvent::ResponseCreate(e) => e.event_type(), + RealtimeClientEvent::ResponseCancel(e) => e.event_type(), + RealtimeClientEvent::OutputAudioBufferClear(e) => e.event_type(), + } + } +} diff --git a/async-openai/src/types/realtime/form.rs b/async-openai/src/types/realtime/form.rs new file mode 100644 index 00000000..96720736 --- /dev/null +++ b/async-openai/src/types/realtime/form.rs @@ -0,0 +1,35 @@ +#[cfg(feature = "realtime")] +use crate::{error::OpenAIError, traits::AsyncTryFrom}; + +#[cfg(feature = "realtime")] +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from( + request: crate::types::realtime::RealtimeCallCreateRequest, + ) -> Result { + use reqwest::multipart::Part; + + // Create SDP part with content type application/sdp + let sdp_part = Part::text(request.sdp) + .mime_str("application/sdp") + .map_err(|e| OpenAIError::InvalidArgument(format!("Invalid content type: {}", e)))?; + + let mut form = reqwest::multipart::Form::new().part("sdp", sdp_part); + + // Add session as JSON if present + if let Some(session) = request.session { + let session_json = serde_json::to_string(&session).map_err(|e| { + OpenAIError::InvalidArgument(format!("Failed to serialize session: {}", e)) + })?; + let session_part = Part::text(session_json) + .mime_str("application/json") + .map_err(|e| { + OpenAIError::InvalidArgument(format!("Invalid content type: {}", e)) + })?; + form = form.part("session", session_part); + } + + Ok(form) + } +} diff --git a/async-openai/src/types/realtime/mod.rs b/async-openai/src/types/realtime/mod.rs index d3cc8d27..3aacf4f7 100644 --- a/async-openai/src/types/realtime/mod.rs +++ b/async-openai/src/types/realtime/mod.rs @@ -2,6 +2,7 @@ mod api; mod client_event; mod conversation_item; mod error; +mod form; mod response; mod server_event; mod session; diff --git a/async-openai/src/types/realtime/server_event.rs 
b/async-openai/src/types/realtime/server_event.rs index 8be5c3de..bc328304 100644 --- a/async-openai/src/types/realtime/server_event.rs +++ b/async-openai/src/types/realtime/server_event.rs @@ -1,5 +1,6 @@ use serde::{Deserialize, Serialize}; +use crate::traits::EventType; use crate::types::{audio::TranscriptionUsage, LogProbProperties}; use super::{ @@ -804,3 +805,117 @@ pub enum RealtimeServerEvent { #[serde(rename = "rate_limits.updated")] RateLimitsUpdated(RealtimeServerEventRateLimitsUpdated), } + +// Implement EventType trait for all event types in this file + +macro_rules! impl_event_type { + ($($ty:ty => $event_type:expr),* $(,)?) => { + $( + impl EventType for $ty { + fn event_type(&self) -> &'static str { + $event_type + } + } + )* + }; +} + +impl_event_type! { + RealtimeServerEventError => "error", + RealtimeServerEventSessionCreated => "session.created", + RealtimeServerEventSessionUpdated => "session.updated", + RealtimeServerEventConversationItemAdded => "conversation.item.added", + RealtimeServerEventConversationItemDone => "conversation.item.done", + RealtimeServerEventInputAudioBufferCommitted => "input_audio_buffer.committed", + RealtimeServerEventInputAudioBufferCleared => "input_audio_buffer.cleared", + RealtimeServerEventInputAudioBufferSpeechStarted => "input_audio_buffer.speech_started", + RealtimeServerEventInputAudioBufferSpeechStopped => "input_audio_buffer.speech_stopped", + RealtimeServerEventInputAudioBufferTimeoutTriggered => "input_audio_buffer.timeout_triggered", + RealtimeServerEventOutputAudioBufferStarted => "output_audio_buffer.started", + RealtimeServerEventOutputAudioBufferStopped => "output_audio_buffer.stopped", + RealtimeServerEventOutputAudioBufferCleared => "output_audio_buffer.cleared", + RealtimeServerEventConversationItemInputAudioTranscriptionCompleted => "conversation.item.input_audio_transcription.completed", + RealtimeServerEventConversationItemInputAudioTranscriptionDelta => 
"conversation.item.input_audio_transcription.delta", + RealtimeServerEventConversationItemInputAudioTranscriptionFailed => "conversation.item.input_audio_transcription.failed", + RealtimeServerEventConversationItemTruncated => "conversation.item.truncated", + RealtimeServerEventConversationItemDeleted => "conversation.item.deleted", + RealtimeServerEventConversationItemRetrieved => "conversation.item.retrieved", + RealtimeServerEventConversationItemInputAudioTranscriptionSegment => "conversation.item.input_audio_transcription.segment", + RealtimeServerEventResponseCreated => "response.created", + RealtimeServerEventResponseDone => "response.done", + RealtimeServerEventResponseOutputItemAdded => "response.output_item.added", + RealtimeServerEventResponseOutputItemDone => "response.output_item.done", + RealtimeServerEventResponseContentPartAdded => "response.content_part.added", + RealtimeServerEventResponseContentPartDone => "response.content_part.done", + RealtimeServerEventResponseTextDelta => "response.output_text.delta", + RealtimeServerEventResponseTextDone => "response.output_text.done", + RealtimeServerEventResponseAudioTranscriptDelta => "response.output_audio_transcript.delta", + RealtimeServerEventResponseAudioTranscriptDone => "response.output_audio_transcript.done", + RealtimeServerEventResponseAudioDelta => "response.output_audio.delta", + RealtimeServerEventResponseAudioDone => "response.output_audio.done", + RealtimeServerEventResponseFunctionCallArgumentsDelta => "response.function_call_arguments.delta", + RealtimeServerEventResponseFunctionCallArgumentsDone => "response.function_call_arguments.done", + RealtimeServerEventResponseMCPCallArgumentsDelta => "response.mcp_call_arguments.delta", + RealtimeServerEventResponseMCPCallArgumentsDone => "response.mcp_call_arguments.done", + RealtimeServerEventResponseMCPCallInProgress => "response.mcp_call.in_progress", + RealtimeServerEventResponseMCPCallCompleted => "response.mcp_call.completed", + 
RealtimeServerEventResponseMCPCallFailed => "response.mcp_call.failed", + RealtimeServerEventMCPListToolsInProgress => "mcp_list_tools.in_progress", + RealtimeServerEventMCPListToolsCompleted => "mcp_list_tools.completed", + RealtimeServerEventMCPListToolsFailed => "mcp_list_tools.failed", + RealtimeServerEventRateLimitsUpdated => "rate_limits.updated", +} + +impl EventType for RealtimeServerEvent { + fn event_type(&self) -> &'static str { + match self { + RealtimeServerEvent::Error(e) => e.event_type(), + RealtimeServerEvent::SessionCreated(e) => e.event_type(), + RealtimeServerEvent::SessionUpdated(e) => e.event_type(), + RealtimeServerEvent::ConversationItemAdded(e) => e.event_type(), + RealtimeServerEvent::ConversationItemDone(e) => e.event_type(), + RealtimeServerEvent::InputAudioBufferCommitted(e) => e.event_type(), + RealtimeServerEvent::InputAudioBufferCleared(e) => e.event_type(), + RealtimeServerEvent::InputAudioBufferSpeechStarted(e) => e.event_type(), + RealtimeServerEvent::InputAudioBufferSpeechStopped(e) => e.event_type(), + RealtimeServerEvent::InputAudioBufferTimeoutTriggered(e) => e.event_type(), + RealtimeServerEvent::OutputAudioBufferStarted(e) => e.event_type(), + RealtimeServerEvent::OutputAudioBufferStopped(e) => e.event_type(), + RealtimeServerEvent::OutputAudioBufferCleared(e) => e.event_type(), + RealtimeServerEvent::ConversationItemInputAudioTranscriptionCompleted(e) => { + e.event_type() + } + RealtimeServerEvent::ConversationItemInputAudioTranscriptionDelta(e) => e.event_type(), + RealtimeServerEvent::ConversationItemInputAudioTranscriptionFailed(e) => e.event_type(), + RealtimeServerEvent::ConversationItemTruncated(e) => e.event_type(), + RealtimeServerEvent::ConversationItemDeleted(e) => e.event_type(), + RealtimeServerEvent::ConversationItemRetrieved(e) => e.event_type(), + RealtimeServerEvent::ConversationItemInputAudioTranscriptionSegment(e) => { + e.event_type() + } + RealtimeServerEvent::ResponseCreated(e) => e.event_type(), + 
RealtimeServerEvent::ResponseDone(e) => e.event_type(), + RealtimeServerEvent::ResponseOutputItemAdded(e) => e.event_type(), + RealtimeServerEvent::ResponseOutputItemDone(e) => e.event_type(), + RealtimeServerEvent::ResponseContentPartAdded(e) => e.event_type(), + RealtimeServerEvent::ResponseContentPartDone(e) => e.event_type(), + RealtimeServerEvent::ResponseOutputTextDelta(e) => e.event_type(), + RealtimeServerEvent::ResponseOutputTextDone(e) => e.event_type(), + RealtimeServerEvent::ResponseOutputAudioTranscriptDelta(e) => e.event_type(), + RealtimeServerEvent::ResponseOutputAudioTranscriptDone(e) => e.event_type(), + RealtimeServerEvent::ResponseOutputAudioDelta(e) => e.event_type(), + RealtimeServerEvent::ResponseOutputAudioDone(e) => e.event_type(), + RealtimeServerEvent::ResponseFunctionCallArgumentsDelta(e) => e.event_type(), + RealtimeServerEvent::ResponseFunctionCallArgumentsDone(e) => e.event_type(), + RealtimeServerEvent::ResponseMCPCallArgumentsDelta(e) => e.event_type(), + RealtimeServerEvent::ResponseMCPCallArgumentsDone(e) => e.event_type(), + RealtimeServerEvent::ResponseMCPCallInProgress(e) => e.event_type(), + RealtimeServerEvent::ResponseMCPCallCompleted(e) => e.event_type(), + RealtimeServerEvent::ResponseMCPCallFailed(e) => e.event_type(), + RealtimeServerEvent::MCPListToolsInProgress(e) => e.event_type(), + RealtimeServerEvent::MCPListToolsCompleted(e) => e.event_type(), + RealtimeServerEvent::MCPListToolsFailed(e) => e.event_type(), + RealtimeServerEvent::RateLimitsUpdated(e) => e.event_type(), + } + } +} diff --git a/async-openai/src/types/responses/impls.rs b/async-openai/src/types/responses/impls.rs new file mode 100644 index 00000000..9e7c755e --- /dev/null +++ b/async-openai/src/types/responses/impls.rs @@ -0,0 +1,676 @@ +use crate::types::responses::{ + ApplyPatchToolCallItemParam, ApplyPatchToolCallOutputItemParam, CodeInterpreterContainerAuto, + CodeInterpreterTool, CodeInterpreterToolCall, CodeInterpreterToolContainer, + 
ComputerCallOutputItemParam, ComputerToolCall, ComputerUsePreviewTool, ConversationParam, + CustomToolCall, CustomToolCallOutput, CustomToolParam, EasyInputContent, EasyInputMessage, + FileSearchTool, FileSearchToolCall, FunctionCallOutput, FunctionCallOutputItemParam, + FunctionShellCallItemParam, FunctionShellCallOutputItemParam, FunctionTool, FunctionToolCall, + ImageGenTool, ImageGenToolCall, InputContent, InputFileContent, InputImageContent, InputItem, + InputMessage, InputParam, InputTextContent, Item, ItemReference, ItemReferenceType, + LocalShellToolCall, LocalShellToolCallOutput, MCPApprovalRequest, MCPApprovalResponse, + MCPListTools, MCPToolCall, MessageItem, MessageType, OutputMessage, OutputMessageContent, + OutputTextContent, Prompt, Reasoning, ReasoningEffort, ReasoningItem, ReasoningSummary, + RefusalContent, ResponsePromptVariables, ResponseStreamOptions, ResponseTextParam, Role, + TextResponseFormatConfiguration, Tool, ToolChoiceCustom, ToolChoiceFunction, ToolChoiceMCP, + ToolChoiceOptions, ToolChoiceParam, ToolChoiceTypes, WebSearchTool, WebSearchToolCall, +}; +use crate::types::{chat::ResponseFormatJsonSchema, MCPTool}; + +impl> From for EasyInputMessage { + fn from(value: S) -> Self { + EasyInputMessage { + r#type: MessageType::Message, + role: Role::User, + content: EasyInputContent::Text(value.into()), + } + } +} + +impl From for InputItem { + fn from(msg: EasyInputMessage) -> Self { + InputItem::EasyMessage(msg) + } +} + +// InputItem ergonomics + +impl From for InputItem { + fn from(msg: InputMessage) -> Self { + InputItem::Item(Item::Message(MessageItem::Input(msg))) + } +} + +impl From for InputItem { + fn from(item: Item) -> Self { + InputItem::Item(item) + } +} + +impl From for InputItem { + fn from(item: ItemReference) -> Self { + InputItem::ItemReference(item) + } +} + +// InputParam ergonomics: from InputItem + +impl From for InputParam { + fn from(item: InputItem) -> Self { + InputParam::Items(vec![item]) + } +} + +impl From for 
InputParam { + fn from(item: Item) -> Self { + InputParam::Items(vec![InputItem::Item(item)]) + } +} + +impl From for InputParam { + fn from(item: MessageItem) -> Self { + InputParam::Items(vec![InputItem::Item(Item::Message(item))]) + } +} + +impl From for InputParam { + fn from(msg: InputMessage) -> Self { + InputParam::Items(vec![InputItem::Item(Item::Message(MessageItem::Input( + msg, + )))]) + } +} + +impl> From> for InputParam { + fn from(items: Vec) -> Self { + InputParam::Items(items.into_iter().map(|item| item.into()).collect()) + } +} + +impl, const N: usize> From<[I; N]> for InputParam { + fn from(items: [I; N]) -> Self { + InputParam::Items(items.into_iter().map(|item| item.into()).collect()) + } +} + +// InputParam ergonomics: from string "family" + +impl From<&str> for InputParam { + fn from(value: &str) -> Self { + InputParam::Text(value.into()) + } +} + +impl From for InputParam { + fn from(value: String) -> Self { + InputParam::Text(value) + } +} + +impl From<&String> for InputParam { + fn from(value: &String) -> Self { + InputParam::Text(value.clone()) + } +} + +// InputParam ergonomics: from vector family + +macro_rules! 
impl_inputparam_easy_from_collection { + // Vec + ($t:ty, $map:expr, $clone:expr) => { + impl From> for InputParam { + fn from(values: Vec<$t>) -> Self { + InputParam::Items( + values + .into_iter() + .map(|value| { + InputItem::EasyMessage(EasyInputMessage { + r#type: MessageType::Message, + role: Role::User, + content: EasyInputContent::Text($map(value)), + }) + }) + .collect(), + ) + } + } + // &[T; N] + impl From<[$t; N]> for InputParam { + fn from(values: [$t; N]) -> Self { + InputParam::Items( + values + .into_iter() + .map(|value| { + InputItem::EasyMessage(EasyInputMessage { + r#type: MessageType::Message, + role: Role::User, + content: EasyInputContent::Text($map(value)), + }) + }) + .collect(), + ) + } + } + // &Vec + impl From<&Vec<$t>> for InputParam { + fn from(values: &Vec<$t>) -> Self { + InputParam::Items( + values + .iter() + .map(|value| { + InputItem::EasyMessage(EasyInputMessage { + r#type: MessageType::Message, + role: Role::User, + content: EasyInputContent::Text($clone(value)), + }) + }) + .collect(), + ) + } + } + }; +} + +// Apply for &str +impl_inputparam_easy_from_collection!(&str, |v: &str| v.to_string(), |v: &str| v.to_string()); +// Apply for String +impl_inputparam_easy_from_collection!(String, |v: String| v, |v: &String| v.clone()); +// Apply for &String +impl_inputparam_easy_from_collection!(&String, |v: &String| v.clone(), |v: &String| v.clone()); + +// ConversationParam ergonomics + +impl> From for ConversationParam { + fn from(id: S) -> Self { + ConversationParam::ConversationID(id.into()) + } +} + +// ToolChoiceParam ergonomics + +impl From for ToolChoiceParam { + fn from(mode: ToolChoiceOptions) -> Self { + ToolChoiceParam::Mode(mode) + } +} + +impl From for ToolChoiceParam { + fn from(tool_type: ToolChoiceTypes) -> Self { + ToolChoiceParam::Hosted(tool_type) + } +} + +impl> From for ToolChoiceParam { + fn from(name: S) -> Self { + ToolChoiceParam::Function(ToolChoiceFunction { name: name.into() }) + } +} + +impl From for 
ToolChoiceParam { + fn from(function: ToolChoiceFunction) -> Self { + ToolChoiceParam::Function(function) + } +} + +impl From for ToolChoiceParam { + fn from(mcp: ToolChoiceMCP) -> Self { + ToolChoiceParam::Mcp(mcp) + } +} + +impl From for ToolChoiceParam { + fn from(custom: ToolChoiceCustom) -> Self { + ToolChoiceParam::Custom(custom) + } +} + +// ResponseTextParam ergonomics + +impl From for ResponseTextParam { + fn from(format: TextResponseFormatConfiguration) -> Self { + ResponseTextParam { + format, + verbosity: None, + } + } +} + +impl From for ResponseTextParam { + fn from(schema: ResponseFormatJsonSchema) -> Self { + ResponseTextParam { + format: TextResponseFormatConfiguration::JsonSchema(schema), + verbosity: None, + } + } +} + +// ResponseStreamOptions ergonomics + +impl From for ResponseStreamOptions { + fn from(include_obfuscation: bool) -> Self { + ResponseStreamOptions { + include_obfuscation: Some(include_obfuscation), + } + } +} + +// Reasoning ergonomics + +impl From for Reasoning { + fn from(effort: ReasoningEffort) -> Self { + Reasoning { + effort: Some(effort), + summary: None, + } + } +} + +impl From for Reasoning { + fn from(summary: ReasoningSummary) -> Self { + Reasoning { + effort: None, + summary: Some(summary), + } + } +} + +// Prompt ergonomics + +impl> From for Prompt { + fn from(id: S) -> Self { + Prompt { + id: id.into(), + version: None, + variables: None, + } + } +} + +// InputTextContent ergonomics + +impl> From for InputTextContent { + fn from(text: S) -> Self { + InputTextContent { text: text.into() } + } +} + +// InputContent ergonomics + +impl From for InputContent { + fn from(content: InputTextContent) -> Self { + InputContent::InputText(content) + } +} + +impl From for InputContent { + fn from(content: InputImageContent) -> Self { + InputContent::InputImage(content) + } +} + +impl From for InputContent { + fn from(content: InputFileContent) -> Self { + InputContent::InputFile(content) + } +} + +impl> From for InputContent { 
+ fn from(text: S) -> Self { + InputContent::InputText(InputTextContent { text: text.into() }) + } +} + +// ResponsePromptVariables ergonomics + +impl From for ResponsePromptVariables { + fn from(content: InputContent) -> Self { + ResponsePromptVariables::Content(content) + } +} + +impl> From for ResponsePromptVariables { + fn from(text: S) -> Self { + ResponsePromptVariables::String(text.into()) + } +} + +// MessageItem ergonomics + +impl From for MessageItem { + fn from(msg: InputMessage) -> Self { + MessageItem::Input(msg) + } +} + +impl From for MessageItem { + fn from(msg: OutputMessage) -> Self { + MessageItem::Output(msg) + } +} + +// FunctionCallOutput ergonomics + +impl From<&str> for FunctionCallOutput { + fn from(text: &str) -> Self { + FunctionCallOutput::Text(text.to_string()) + } +} + +impl From for FunctionCallOutput { + fn from(text: String) -> Self { + FunctionCallOutput::Text(text) + } +} + +impl From> for FunctionCallOutput { + fn from(content: Vec) -> Self { + FunctionCallOutput::Content(content) + } +} + +// RefusalContent ergonomics + +impl> From for RefusalContent { + fn from(refusal: S) -> Self { + RefusalContent { + refusal: refusal.into(), + } + } +} + +// OutputMessageContent ergonomics + +impl From for OutputMessageContent { + fn from(content: OutputTextContent) -> Self { + OutputMessageContent::OutputText(content) + } +} + +impl From for OutputMessageContent { + fn from(content: RefusalContent) -> Self { + OutputMessageContent::Refusal(content) + } +} + +// Item ergonomics + +impl From for Item { + fn from(item: MessageItem) -> Self { + Item::Message(item) + } +} + +impl From for Item { + fn from(call: FileSearchToolCall) -> Self { + Item::FileSearchCall(call) + } +} + +impl From for Item { + fn from(call: ComputerToolCall) -> Self { + Item::ComputerCall(call) + } +} + +impl From for Item { + fn from(output: ComputerCallOutputItemParam) -> Self { + Item::ComputerCallOutput(output) + } +} + +impl From for Item { + fn from(call: 
WebSearchToolCall) -> Self { + Item::WebSearchCall(call) + } +} + +impl From for Item { + fn from(call: FunctionToolCall) -> Self { + Item::FunctionCall(call) + } +} + +impl From for Item { + fn from(output: FunctionCallOutputItemParam) -> Self { + Item::FunctionCallOutput(output) + } +} + +impl From for Item { + fn from(item: ReasoningItem) -> Self { + Item::Reasoning(item) + } +} + +impl From for Item { + fn from(call: ImageGenToolCall) -> Self { + Item::ImageGenerationCall(call) + } +} + +impl From for Item { + fn from(call: CodeInterpreterToolCall) -> Self { + Item::CodeInterpreterCall(call) + } +} + +impl From for Item { + fn from(call: LocalShellToolCall) -> Self { + Item::LocalShellCall(call) + } +} + +impl From for Item { + fn from(output: LocalShellToolCallOutput) -> Self { + Item::LocalShellCallOutput(output) + } +} + +impl From for Item { + fn from(call: FunctionShellCallItemParam) -> Self { + Item::FunctionShellCall(call) + } +} + +impl From for Item { + fn from(output: FunctionShellCallOutputItemParam) -> Self { + Item::FunctionShellCallOutput(output) + } +} + +impl From for Item { + fn from(call: ApplyPatchToolCallItemParam) -> Self { + Item::ApplyPatchCall(call) + } +} + +impl From for Item { + fn from(output: ApplyPatchToolCallOutputItemParam) -> Self { + Item::ApplyPatchCallOutput(output) + } +} + +impl From for Item { + fn from(tools: MCPListTools) -> Self { + Item::McpListTools(tools) + } +} + +impl From for Item { + fn from(request: MCPApprovalRequest) -> Self { + Item::McpApprovalRequest(request) + } +} + +impl From for Item { + fn from(response: MCPApprovalResponse) -> Self { + Item::McpApprovalResponse(response) + } +} + +impl From for Item { + fn from(call: MCPToolCall) -> Self { + Item::McpCall(call) + } +} + +impl From for Item { + fn from(output: CustomToolCallOutput) -> Self { + Item::CustomToolCallOutput(output) + } +} + +impl From for Item { + fn from(call: CustomToolCall) -> Self { + Item::CustomToolCall(call) + } +} + +// Tool 
ergonomics + +impl From for Tool { + fn from(tool: FunctionTool) -> Self { + Tool::Function(tool) + } +} + +impl From for Tool { + fn from(tool: FileSearchTool) -> Self { + Tool::FileSearch(tool) + } +} + +impl From for Tool { + fn from(tool: ComputerUsePreviewTool) -> Self { + Tool::ComputerUsePreview(tool) + } +} + +impl From for Tool { + fn from(tool: WebSearchTool) -> Self { + Tool::WebSearch(tool) + } +} + +impl From for Tool { + fn from(tool: MCPTool) -> Self { + Tool::Mcp(tool) + } +} + +impl From for Tool { + fn from(tool: CodeInterpreterTool) -> Self { + Tool::CodeInterpreter(tool) + } +} + +impl From for Tool { + fn from(tool: ImageGenTool) -> Self { + Tool::ImageGeneration(tool) + } +} + +impl From for Tool { + fn from(tool: CustomToolParam) -> Self { + Tool::Custom(tool) + } +} + +// Vec ergonomics + +impl From for Vec { + fn from(tool: Tool) -> Self { + vec![tool] + } +} + +impl From for Vec { + fn from(tool: FunctionTool) -> Self { + vec![Tool::Function(tool)] + } +} + +impl From for Vec { + fn from(tool: FileSearchTool) -> Self { + vec![Tool::FileSearch(tool)] + } +} + +impl From for Vec { + fn from(tool: ComputerUsePreviewTool) -> Self { + vec![Tool::ComputerUsePreview(tool)] + } +} + +impl From for Vec { + fn from(tool: WebSearchTool) -> Self { + vec![Tool::WebSearch(tool)] + } +} + +impl From for Vec { + fn from(tool: MCPTool) -> Self { + vec![Tool::Mcp(tool)] + } +} + +impl From for Vec { + fn from(tool: CodeInterpreterTool) -> Self { + vec![Tool::CodeInterpreter(tool)] + } +} + +impl From for Vec { + fn from(tool: ImageGenTool) -> Self { + vec![Tool::ImageGeneration(tool)] + } +} + +impl From for Vec { + fn from(tool: CustomToolParam) -> Self { + vec![Tool::Custom(tool)] + } +} + +// EasyInputContent ergonomics + +impl Default for EasyInputContent { + fn default() -> Self { + Self::Text("".to_string()) + } +} + +impl From for EasyInputContent { + fn from(value: String) -> Self { + Self::Text(value) + } +} + +impl From<&str> for EasyInputContent 
{ + fn from(value: &str) -> Self { + Self::Text(value.to_owned()) + } +} + +// Defaults + +impl Default for CodeInterpreterToolContainer { + fn default() -> Self { + Self::Auto(CodeInterpreterContainerAuto::default()) + } +} + +impl Default for InputParam { + fn default() -> Self { + Self::Text(String::new()) + } +} + +impl ItemReference { + /// Create a new item reference with the given ID. + pub fn new(id: impl Into) -> Self { + Self { + r#type: Some(ItemReferenceType::ItemReference), + id: id.into(), + } + } +} diff --git a/async-openai/src/types/responses/mod.rs b/async-openai/src/types/responses/mod.rs index 65003d87..938a82f0 100644 --- a/async-openai/src/types/responses/mod.rs +++ b/async-openai/src/types/responses/mod.rs @@ -1,6 +1,8 @@ mod api; mod conversation; +mod impls; mod response; +mod sdk; mod stream; pub use api::*; diff --git a/async-openai/src/types/responses/response.rs b/async-openai/src/types/responses/response.rs index 89387d2a..62b548e2 100644 --- a/async-openai/src/types/responses/response.rs +++ b/async-openai/src/types/responses/response.rs @@ -39,12 +39,6 @@ pub enum InputParam { Items(Vec), } -impl Default for InputParam { - fn default() -> Self { - Self::Text(String::new()) - } -} - /// Content item used to generate a response. /// /// This is a properly discriminated union based on the `type` field, using Rust's @@ -174,32 +168,6 @@ pub enum InputItem { EasyMessage(EasyInputMessage), } -impl InputItem { - /// Creates an InputItem from an item reference ID. - pub fn from_reference(id: impl Into) -> Self { - Self::ItemReference(ItemReference::new(id)) - } - - /// Creates an InputItem from a structured Item. - pub fn from_item(item: Item) -> Self { - Self::Item(item) - } - - /// Creates an InputItem from an EasyInputMessage. - pub fn from_easy_message(message: EasyInputMessage) -> Self { - Self::EasyMessage(message) - } - - /// Creates a simple text message with the given role and content. 
- pub fn text_message(role: Role, content: impl Into) -> Self { - Self::EasyMessage(EasyInputMessage { - r#type: MessageType::Message, - role, - content: EasyInputContent::Text(content.into()), - }) - } -} - /// A message item used within the `Item` enum. /// /// Both InputMessage and OutputMessage have `type: "message"`, so we use an untagged @@ -235,16 +203,6 @@ pub struct ItemReference { pub id: String, } -impl ItemReference { - /// Create a new item reference with the given ID. - pub fn new(id: impl Into) -> Self { - Self { - r#type: Some(ItemReferenceType::ItemReference), - id: id.into(), - } - } -} - #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "snake_case")] pub enum ItemReferenceType { @@ -411,8 +369,7 @@ pub struct EasyInputMessage { /// A structured message input to the model (InputMessage in the OpenAPI spec). /// /// This variant requires structured content (not a simple string) and does not support -/// the `assistant` role (use OutputMessage for that). Used when items are returned via API -/// with additional metadata. +/// the `assistant` role (use OutputMessage for that). status is populated when items are returned via API. #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default, Builder)] #[builder( name = "InputMessageArgs", @@ -486,14 +443,14 @@ pub struct InputTextContent { pub struct InputImageContent { /// The detail level of the image to be sent to the model. One of `high`, `low`, or `auto`. /// Defaults to `auto`. - detail: ImageDetail, + pub detail: ImageDetail, /// The ID of the file to be sent to the model. #[serde(skip_serializing_if = "Option::is_none")] - file_id: Option, + pub file_id: Option, /// The URL of the image to be sent to the model. A fully qualified URL or base64 encoded image /// in a data URL. 
#[serde(skip_serializing_if = "Option::is_none")] - image_url: Option, + pub image_url: Option, } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default, Builder)] @@ -1264,12 +1221,6 @@ pub enum CodeInterpreterToolContainer { ContainerID(String), } -impl Default for CodeInterpreterToolContainer { - fn default() -> Self { - Self::Auto(CodeInterpreterContainerAuto::default()) - } -} - /// Auto configuration for code interpreter container. #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Default)] pub struct CodeInterpreterContainerAuto { @@ -1455,21 +1406,21 @@ pub enum ToolChoiceTypes { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct ToolChoiceFunction { /// The name of the function to call. - name: String, + pub name: String, } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct ToolChoiceMCP { /// The name of the tool to call on the server. - name: String, + pub name: String, /// The label of the MCP server to use. - server_label: String, + pub server_label: String, } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct ToolChoiceCustom { /// The name of the custom tool to call. - name: String, + pub name: String, } #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] @@ -2715,36 +2666,6 @@ pub struct Response { pub usage: Option, } -impl Response { - /// SDK-only convenience property that contains the aggregated text output from all - /// `output_text` items in the `output` array, if any are present. 
- pub fn output_text(&self) -> Option { - let output = self - .output - .iter() - .filter_map(|item| match item { - OutputItem::Message(msg) => Some( - msg.content - .iter() - .filter_map(|content| match content { - OutputMessageContent::OutputText(ot) => Some(ot.text.clone()), - _ => None, - }) - .collect::>(), - ), - _ => None, - }) - .flatten() - .collect::>() - .join(""); - if output.is_empty() { - None - } else { - Some(output) - } - } -} - #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(rename_all = "snake_case")] pub enum Status { diff --git a/async-openai/src/types/responses/sdk.rs b/async-openai/src/types/responses/sdk.rs new file mode 100644 index 00000000..834335c3 --- /dev/null +++ b/async-openai/src/types/responses/sdk.rs @@ -0,0 +1,31 @@ +use crate::types::responses::{OutputItem, OutputMessageContent, Response}; + +impl Response { + /// SDK-only convenience property that contains the aggregated text output from all + /// `output_text` items in the `output` array, if any are present. 
+ pub fn output_text(&self) -> Option { + let output = self + .output + .iter() + .filter_map(|item| match item { + OutputItem::Message(msg) => Some( + msg.content + .iter() + .filter_map(|content| match content { + OutputMessageContent::OutputText(ot) => Some(ot.text.clone()), + _ => None, + }) + .collect::>(), + ), + _ => None, + }) + .flatten() + .collect::>() + .join(""); + if output.is_empty() { + None + } else { + Some(output) + } + } +} diff --git a/async-openai/src/types/responses/stream.rs b/async-openai/src/types/responses/stream.rs index fed7c221..6e58dff2 100644 --- a/async-openai/src/types/responses/stream.rs +++ b/async-openai/src/types/responses/stream.rs @@ -275,7 +275,8 @@ pub struct ResponseFunctionCallArgumentsDeltaEvent { #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub struct ResponseFunctionCallArgumentsDoneEvent { - pub name: String, + /// https://github.com/64bit/async-openai/issues/472 + pub name: Option, pub sequence_number: u64, pub item_id: String, pub output_index: u32, @@ -542,3 +543,132 @@ pub struct ResponseErrorEvent { pub message: String, pub param: Option, } + +use crate::traits::EventType; + +// Implement EventType trait for all event types in this file + +macro_rules! impl_event_type { + ($($ty:ty => $event_type:expr),* $(,)?) => { + $( + impl EventType for $ty { + fn event_type(&self) -> &'static str { + $event_type + } + } + )* + }; +} + +// Apply macro for each event struct type in this file. +impl_event_type! 
{ + ResponseCreatedEvent => "response.created", + ResponseInProgressEvent => "response.in_progress", + ResponseCompletedEvent => "response.completed", + ResponseFailedEvent => "response.failed", + ResponseIncompleteEvent => "response.incomplete", + ResponseOutputItemAddedEvent => "response.output_item.added", + ResponseOutputItemDoneEvent => "response.output_item.done", + ResponseContentPartAddedEvent => "response.content_part.added", + ResponseContentPartDoneEvent => "response.content_part.done", + ResponseTextDeltaEvent => "response.output_text.delta", + ResponseTextDoneEvent => "response.output_text.done", + ResponseRefusalDeltaEvent => "response.refusal.delta", + ResponseRefusalDoneEvent => "response.refusal.done", + ResponseFunctionCallArgumentsDeltaEvent => "response.function_call_arguments.delta", + ResponseFunctionCallArgumentsDoneEvent => "response.function_call_arguments.done", + ResponseFileSearchCallInProgressEvent => "response.file_search_call.in_progress", + ResponseFileSearchCallSearchingEvent => "response.file_search_call.searching", + ResponseFileSearchCallCompletedEvent => "response.file_search_call.completed", + ResponseWebSearchCallInProgressEvent => "response.web_search_call.in_progress", + ResponseWebSearchCallSearchingEvent => "response.web_search_call.searching", + ResponseWebSearchCallCompletedEvent => "response.web_search_call.completed", + ResponseReasoningSummaryPartAddedEvent => "response.reasoning_summary_part.added", + ResponseReasoningSummaryPartDoneEvent => "response.reasoning_summary_part.done", + ResponseReasoningSummaryTextDeltaEvent => "response.reasoning_summary_text.delta", + ResponseReasoningSummaryTextDoneEvent => "response.reasoning_summary_text.done", + ResponseReasoningTextDeltaEvent => "response.reasoning_text.delta", + ResponseReasoningTextDoneEvent => "response.reasoning_text.done", + ResponseImageGenCallCompletedEvent => "response.image_generation_call.completed", + ResponseImageGenCallGeneratingEvent => 
"response.image_generation_call.generating", + ResponseImageGenCallInProgressEvent => "response.image_generation_call.in_progress", + ResponseImageGenCallPartialImageEvent => "response.image_generation_call.partial_image", + ResponseMCPCallArgumentsDeltaEvent => "response.mcp_call_arguments.delta", + ResponseMCPCallArgumentsDoneEvent => "response.mcp_call_arguments.done", + ResponseMCPCallCompletedEvent => "response.mcp_call.completed", + ResponseMCPCallFailedEvent => "response.mcp_call.failed", + ResponseMCPCallInProgressEvent => "response.mcp_call.in_progress", + ResponseMCPListToolsCompletedEvent => "response.mcp_list_tools.completed", + ResponseMCPListToolsFailedEvent => "response.mcp_list_tools.failed", + ResponseMCPListToolsInProgressEvent => "response.mcp_list_tools.in_progress", + ResponseCodeInterpreterCallInProgressEvent => "response.code_interpreter_call.in_progress", + ResponseCodeInterpreterCallInterpretingEvent => "response.code_interpreter_call.interpreting", + ResponseCodeInterpreterCallCompletedEvent => "response.code_interpreter_call.completed", + ResponseCodeInterpreterCallCodeDeltaEvent => "response.code_interpreter_call_code.delta", + ResponseCodeInterpreterCallCodeDoneEvent => "response.code_interpreter_call_code.done", + ResponseOutputTextAnnotationAddedEvent => "response.output_text.annotation.added", + ResponseQueuedEvent => "response.queued", + ResponseCustomToolCallInputDeltaEvent => "response.custom_tool_call_input.delta", + ResponseCustomToolCallInputDoneEvent => "response.custom_tool_call_input.done", + ResponseErrorEvent => "error", +} + +impl EventType for ResponseStreamEvent { + fn event_type(&self) -> &'static str { + match self { + ResponseStreamEvent::ResponseCreated(event) => event.event_type(), + ResponseStreamEvent::ResponseInProgress(event) => event.event_type(), + ResponseStreamEvent::ResponseCompleted(event) => event.event_type(), + ResponseStreamEvent::ResponseFailed(event) => event.event_type(), + 
ResponseStreamEvent::ResponseIncomplete(event) => event.event_type(), + ResponseStreamEvent::ResponseOutputItemAdded(event) => event.event_type(), + ResponseStreamEvent::ResponseOutputItemDone(event) => event.event_type(), + ResponseStreamEvent::ResponseContentPartAdded(event) => event.event_type(), + ResponseStreamEvent::ResponseContentPartDone(event) => event.event_type(), + ResponseStreamEvent::ResponseOutputTextDelta(event) => event.event_type(), + ResponseStreamEvent::ResponseOutputTextDone(event) => event.event_type(), + ResponseStreamEvent::ResponseRefusalDelta(event) => event.event_type(), + ResponseStreamEvent::ResponseRefusalDone(event) => event.event_type(), + ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(event) => event.event_type(), + ResponseStreamEvent::ResponseFunctionCallArgumentsDone(event) => event.event_type(), + ResponseStreamEvent::ResponseFileSearchCallInProgress(event) => event.event_type(), + ResponseStreamEvent::ResponseFileSearchCallSearching(event) => event.event_type(), + ResponseStreamEvent::ResponseFileSearchCallCompleted(event) => event.event_type(), + ResponseStreamEvent::ResponseWebSearchCallInProgress(event) => event.event_type(), + ResponseStreamEvent::ResponseWebSearchCallSearching(event) => event.event_type(), + ResponseStreamEvent::ResponseWebSearchCallCompleted(event) => event.event_type(), + ResponseStreamEvent::ResponseReasoningSummaryPartAdded(event) => event.event_type(), + ResponseStreamEvent::ResponseReasoningSummaryPartDone(event) => event.event_type(), + ResponseStreamEvent::ResponseReasoningSummaryTextDelta(event) => event.event_type(), + ResponseStreamEvent::ResponseReasoningSummaryTextDone(event) => event.event_type(), + ResponseStreamEvent::ResponseReasoningTextDelta(event) => event.event_type(), + ResponseStreamEvent::ResponseReasoningTextDone(event) => event.event_type(), + ResponseStreamEvent::ResponseImageGenerationCallCompleted(event) => event.event_type(), + 
ResponseStreamEvent::ResponseImageGenerationCallGenerating(event) => event.event_type(), + ResponseStreamEvent::ResponseImageGenerationCallInProgress(event) => event.event_type(), + ResponseStreamEvent::ResponseImageGenerationCallPartialImage(event) => { + event.event_type() + } + ResponseStreamEvent::ResponseMCPCallArgumentsDelta(event) => event.event_type(), + ResponseStreamEvent::ResponseMCPCallArgumentsDone(event) => event.event_type(), + ResponseStreamEvent::ResponseMCPCallCompleted(event) => event.event_type(), + ResponseStreamEvent::ResponseMCPCallFailed(event) => event.event_type(), + ResponseStreamEvent::ResponseMCPCallInProgress(event) => event.event_type(), + ResponseStreamEvent::ResponseMCPListToolsCompleted(event) => event.event_type(), + ResponseStreamEvent::ResponseMCPListToolsFailed(event) => event.event_type(), + ResponseStreamEvent::ResponseMCPListToolsInProgress(event) => event.event_type(), + ResponseStreamEvent::ResponseCodeInterpreterCallInProgress(event) => event.event_type(), + ResponseStreamEvent::ResponseCodeInterpreterCallInterpreting(event) => { + event.event_type() + } + ResponseStreamEvent::ResponseCodeInterpreterCallCompleted(event) => event.event_type(), + ResponseStreamEvent::ResponseCodeInterpreterCallCodeDelta(event) => event.event_type(), + ResponseStreamEvent::ResponseCodeInterpreterCallCodeDone(event) => event.event_type(), + ResponseStreamEvent::ResponseOutputTextAnnotationAdded(event) => event.event_type(), + ResponseStreamEvent::ResponseQueued(event) => event.event_type(), + ResponseStreamEvent::ResponseCustomToolCallInputDelta(event) => event.event_type(), + ResponseStreamEvent::ResponseCustomToolCallInputDone(event) => event.event_type(), + ResponseStreamEvent::ResponseError(event) => event.event_type(), + } + } +} diff --git a/async-openai/src/types/uploads/form.rs b/async-openai/src/types/uploads/form.rs new file mode 100644 index 00000000..d2204686 --- /dev/null +++ b/async-openai/src/types/uploads/form.rs @@ -0,0 +1,14 
@@ +use crate::{ + error::OpenAIError, traits::AsyncTryFrom, types::uploads::AddUploadPartRequest, + util::create_file_part, +}; + +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from(request: AddUploadPartRequest) -> Result { + let file_part = create_file_part(request.data).await?; + let form = reqwest::multipart::Form::new().part("data", file_part); + Ok(form) + } +} diff --git a/async-openai/src/types/uploads/mod.rs b/async-openai/src/types/uploads/mod.rs index a90e3638..879d91d8 100644 --- a/async-openai/src/types/uploads/mod.rs +++ b/async-openai/src/types/uploads/mod.rs @@ -1,3 +1,4 @@ +mod form; mod upload; pub use upload::*; diff --git a/async-openai/src/types/videos/form.rs b/async-openai/src/types/videos/form.rs new file mode 100644 index 00000000..0c68a028 --- /dev/null +++ b/async-openai/src/types/videos/form.rs @@ -0,0 +1,29 @@ +use crate::{ + error::OpenAIError, traits::AsyncTryFrom, types::videos::CreateVideoRequest, + util::create_file_part, +}; + +impl AsyncTryFrom for reqwest::multipart::Form { + type Error = OpenAIError; + + async fn try_from(request: CreateVideoRequest) -> Result { + let mut form = reqwest::multipart::Form::new().text("model", request.model); + + form = form.text("prompt", request.prompt); + + if request.size.is_some() { + form = form.text("size", request.size.unwrap().to_string()); + } + + if request.seconds.is_some() { + form = form.text("seconds", request.seconds.unwrap()); + } + + if request.input_reference.is_some() { + let image_part = create_file_part(request.input_reference.unwrap().source).await?; + form = form.part("input_reference", image_part); + } + + Ok(form) + } +} diff --git a/async-openai/src/types/videos/impls.rs b/async-openai/src/types/videos/impls.rs new file mode 100644 index 00000000..eb52edb5 --- /dev/null +++ b/async-openai/src/types/videos/impls.rs @@ -0,0 +1,18 @@ +use std::fmt::Display; + +use crate::types::videos::VideoSize; + +impl Display for VideoSize 
{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "{}", + match self { + Self::S720x1280 => "720x1280", + Self::S1280x720 => "1280x720", + Self::S1024x1792 => "1024x1792", + Self::S1792x1024 => "1792x1024", + } + ) + } +} diff --git a/async-openai/src/types/videos/mod.rs b/async-openai/src/types/videos/mod.rs index 7bcc8a28..847df591 100644 --- a/async-openai/src/types/videos/mod.rs +++ b/async-openai/src/types/videos/mod.rs @@ -1,4 +1,6 @@ mod api; +mod form; +mod impls; mod video; pub use api::*; diff --git a/examples/conversations/src/main.rs b/examples/conversations/src/main.rs index 88306882..cdbcf3bb 100644 --- a/examples/conversations/src/main.rs +++ b/examples/conversations/src/main.rs @@ -2,8 +2,7 @@ use async_openai::{ traits::RequestOptionsBuilder, types::responses::{ ConversationItem, CreateConversationItemsRequestArgs, CreateConversationRequestArgs, - EasyInputContent, EasyInputMessage, InputItem, ListConversationItemsQuery, MessageType, - Role, UpdateConversationRequestArgs, + EasyInputMessage, ListConversationItemsQuery, UpdateConversationRequestArgs, }, Client, }; @@ -25,13 +24,10 @@ async fn main() -> Result<(), Box> { .metadata(json!({ "topic": "demo", })) - .items(vec![InputItem::from_easy_message(EasyInputMessage { - r#type: MessageType::Message, - role: Role::User, - content: EasyInputContent::Text( - "Hello! Can you help me understand conversations?".to_string(), - ), - })]) + .items(vec![EasyInputMessage::from( + "Hello! 
Can you help me understand conversations?", + ) + .into()]) .build()?, ) .await?; @@ -48,16 +44,8 @@ async fn main() -> Result<(), Box> { .create( CreateConversationItemsRequestArgs::default() .items(vec![ - InputItem::from_easy_message(EasyInputMessage { - r#type: MessageType::Message, - role: Role::User, - content: EasyInputContent::Text("What are the main features?".to_string()), - }), - InputItem::from_easy_message(EasyInputMessage { - r#type: MessageType::Message, - role: Role::User, - content: EasyInputContent::Text("Can you give me an example?".to_string()), - }), + EasyInputMessage::from("What are the main features?").into(), + EasyInputMessage::from("Can you give me an example?").into(), ]) .build()?, ) diff --git a/examples/function-call-stream/Cargo.toml b/examples/function-call-stream/Cargo.toml deleted file mode 100644 index 00153331..00000000 --- a/examples/function-call-stream/Cargo.toml +++ /dev/null @@ -1,13 +0,0 @@ -[package] -name = "function-call-stream" -version = "0.1.0" -edition = "2021" -publish = false - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - -[dependencies] -async-openai = {path = "../../async-openai"} -serde_json = "1.0.135" -tokio = { version = "1.43.0", features = ["full"] } -futures = "0.3.31" diff --git a/examples/function-call-stream/src/main.rs b/examples/function-call-stream/src/main.rs deleted file mode 100644 index 376d93fc..00000000 --- a/examples/function-call-stream/src/main.rs +++ /dev/null @@ -1,153 +0,0 @@ -use std::collections::HashMap; -use std::error::Error; -use std::io::{stdout, Write}; - -use async_openai::types::chat::{ - ChatCompletionRequestFunctionMessageArgs, ChatCompletionRequestUserMessageArgs, FinishReason, -}; -use async_openai::{ - types::chat::{ChatCompletionFunctionsArgs, CreateChatCompletionRequestArgs}, - Client, -}; - -use async_openai::config::OpenAIConfig; -use futures::StreamExt; -use serde_json::json; - -#[tokio::main] -async fn main() -> 
Result<(), Box> { - let client = Client::new(); - - let model = "gpt-4o-mini"; - - let request = CreateChatCompletionRequestArgs::default() - .max_tokens(512u32) - .model(model) - .messages([ChatCompletionRequestUserMessageArgs::default() - .content("What's the weather like in Boston?") - .build()? - .into()]) - .functions([ChatCompletionFunctionsArgs::default() - .name("get_current_weather") - .description("Get the current weather in a given location") - .parameters(json!({ - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": { "type": "string", "enum": ["celsius", "fahrenheit"] }, - }, - "required": ["location"], - })) - .build()?]) - .function_call("auto") - .build()?; - - let mut stream = client.chat().create_stream(request).await?; - - let mut fn_name = String::new(); - let mut fn_args = String::new(); - - let mut lock = stdout().lock(); - while let Some(result) = stream.next().await { - match result { - Ok(response) => { - for chat_choice in response.choices { - if let Some(fn_call) = &chat_choice.delta.function_call { - writeln!(lock, "function_call: {:?}", fn_call).unwrap(); - if let Some(name) = &fn_call.name { - fn_name.clone_from(name); - } - if let Some(args) = &fn_call.arguments { - fn_args.push_str(args); - } - } - if let Some(finish_reason) = &chat_choice.finish_reason { - if matches!(finish_reason, FinishReason::FunctionCall) { - call_fn(&client, &fn_name, &fn_args).await?; - } - } else if let Some(content) = &chat_choice.delta.content { - write!(lock, "{}", content).unwrap(); - } - } - } - Err(err) => { - writeln!(lock, "error: {err:?}").unwrap(); - } - } - stdout().flush()?; - } - - Ok(()) -} - -async fn call_fn( - client: &Client, - name: &str, - args: &str, -) -> Result<(), Box> { - let mut available_functions: HashMap<&str, fn(&str, &str) -> serde_json::Value> = - HashMap::new(); - available_functions.insert("get_current_weather", 
get_current_weather); - - let function_args: serde_json::Value = args.parse().unwrap(); - - let model = "gpt-4o-mini"; - let location = function_args["location"].as_str().unwrap(); - let unit = function_args["unit"].as_str().unwrap_or("fahrenheit"); - let function = available_functions.get(name).unwrap(); - let function_response = function(location, unit); // call the function - - let message = vec![ - ChatCompletionRequestUserMessageArgs::default() - .content("What's the weather like in Boston?") - .build()? - .into(), - ChatCompletionRequestFunctionMessageArgs::default() - .content(function_response.to_string()) - .name(name) - .build()? - .into(), - ]; - - let request = CreateChatCompletionRequestArgs::default() - .max_tokens(512u32) - .model(model) - .messages(message) - .build()?; - - // Now stream received response from model, which essentially formats the function response - let mut stream = client.chat().create_stream(request).await?; - - let mut lock = stdout().lock(); - while let Some(result) = stream.next().await { - match result { - Ok(response) => { - response.choices.iter().for_each(|chat_choice| { - if let Some(ref content) = chat_choice.delta.content { - write!(lock, "{}", content).unwrap(); - } - }); - } - Err(err) => { - writeln!(lock, "error: {err:?}").unwrap(); - } - } - stdout().flush()?; - } - println!("\n"); - Ok(()) -} - -fn get_current_weather(location: &str, unit: &str) -> serde_json::Value { - let weather_info = json!({ - "location": location, - "temperature": "72", - "unit": unit, - "forecast": ["sunny", "windy"] - }); - - weather_info -} diff --git a/examples/function-call/Cargo.toml b/examples/function-call/Cargo.toml deleted file mode 100644 index c1ff6094..00000000 --- a/examples/function-call/Cargo.toml +++ /dev/null @@ -1,13 +0,0 @@ -[package] -name = "function-call" -version = "0.1.0" -edition = "2021" -publish = false - -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html - 
-[dependencies] -async-openai = {path = "../../async-openai"} -serde_json = "1.0.135" -tokio = { version = "1.43.0", features = ["full"] } -tracing-subscriber = { version = "0.3.19", features = ["env-filter"]} diff --git a/examples/function-call/README.md b/examples/function-call/README.md deleted file mode 100644 index 9b86dd8d..00000000 --- a/examples/function-call/README.md +++ /dev/null @@ -1,5 +0,0 @@ -### Output - -> Response: -> -> 0: Role: assistant Content: Some("The current weather in Boston is sunny and windy with a temperature of 72 degrees Fahrenheit.") diff --git a/examples/function-call/src/main.rs b/examples/function-call/src/main.rs deleted file mode 100644 index 3ddd8a88..00000000 --- a/examples/function-call/src/main.rs +++ /dev/null @@ -1,118 +0,0 @@ -use async_openai::{ - types::chat::{ - ChatCompletionFunctionsArgs, ChatCompletionRequestFunctionMessageArgs, - ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs, - }, - Client, -}; -use serde_json::json; -use std::collections::HashMap; -use std::error::Error; -use tracing_subscriber::{fmt, prelude::*, EnvFilter}; - -#[tokio::main] -async fn main() -> Result<(), Box> { - // This should come from env var outside the program - std::env::set_var("RUST_LOG", "warn"); - - // Setup tracing subscriber so that library can log the rate limited message - tracing_subscriber::registry() - .with(fmt::layer()) - .with(EnvFilter::from_default_env()) - .init(); - - let client = Client::new(); - - let model = "gpt-4o-mini"; - - let request = CreateChatCompletionRequestArgs::default() - .max_tokens(512u32) - .model(model) - .messages([ChatCompletionRequestUserMessageArgs::default() - .content("What's the weather like in Boston?") - .build()? 
- .into()]) - .functions([ChatCompletionFunctionsArgs::default() - .name("get_current_weather") - .description("Get the current weather in a given location") - .parameters(json!({ - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The city and state, e.g. San Francisco, CA", - }, - "unit": { "type": "string", "enum": ["celsius", "fahrenheit"] }, - }, - "required": ["location"], - })) - .build()?]) - .function_call("auto") - .build()?; - - let response_message = client - .chat() - .create(request) - .await? - .choices - .first() - .unwrap() - .message - .clone(); - - if let Some(function_call) = response_message.function_call { - let mut available_functions: HashMap<&str, fn(&str, &str) -> serde_json::Value> = - HashMap::new(); - available_functions.insert("get_current_weather", get_current_weather); - let function_name = function_call.name; - let function_args: serde_json::Value = function_call.arguments.parse().unwrap(); - - let location = function_args["location"].as_str().unwrap(); - let unit = "fahrenheit"; - let function = available_functions.get(function_name.as_str()).unwrap(); - let function_response = function(location, unit); - - let message = vec![ - ChatCompletionRequestUserMessageArgs::default() - .content("What's the weather like in Boston?") - .build()? - .into(), - ChatCompletionRequestFunctionMessageArgs::default() - .content(function_response.to_string()) - .name(function_name) - .build()? 
- .into(), - ]; - - println!("{}", serde_json::to_string(&message).unwrap()); - - let request = CreateChatCompletionRequestArgs::default() - .max_tokens(512u32) - .model(model) - .messages(message) - .build()?; - - let response = client.chat().create(request).await?; - - println!("\nResponse:\n"); - for choice in response.choices { - println!( - "{}: Role: {} Content: {:?}", - choice.index, choice.message.role, choice.message.content - ); - } - } - - Ok(()) -} - -fn get_current_weather(location: &str, unit: &str) -> serde_json::Value { - let weather_info = json!({ - "location": location, - "temperature": "72", - "unit": unit, - "forecast": ["sunny", "windy"] - }); - - weather_info -} diff --git a/examples/gemini-openai-compatibility/Cargo.toml b/examples/gemini-openai-compatibility/Cargo.toml index fefe9f7f..f75bbda2 100644 --- a/examples/gemini-openai-compatibility/Cargo.toml +++ b/examples/gemini-openai-compatibility/Cargo.toml @@ -2,7 +2,7 @@ name = "gemini-openai-compatibility" version = "0.1.0" edition = "2021" -rust-version.workspace = true +publish = false [dependencies] async-openai = {path = "../../async-openai", features = ["byot"]} diff --git a/examples/realtime/src/main.rs b/examples/realtime/src/main.rs index 87881734..bc124a0f 100644 --- a/examples/realtime/src/main.rs +++ b/examples/realtime/src/main.rs @@ -6,6 +6,7 @@ use async_openai::types::realtime::{ }; use futures_util::{future, pin_mut, StreamExt}; +use async_openai::traits::EventType; use tokio::io::AsyncReadExt; use tokio_tungstenite::{ connect_async, @@ -48,11 +49,7 @@ async fn main() { serde_json::from_slice(&data); match server_event { Ok(server_event) => { - let value = serde_json::to_value(&server_event).unwrap(); - let event_type = value["type"].clone(); - - eprint!("{:32} | ", event_type.as_str().unwrap()); - + eprint!("{:32} | ", server_event.event_type()); match server_event { RealtimeServerEvent::ResponseOutputItemDone(event) => { eprint!("{event:?}"); diff --git 
a/examples/responses-function-call/Cargo.toml b/examples/responses-function-call/Cargo.toml index b576a1f2..fae73d83 100644 --- a/examples/responses-function-call/Cargo.toml +++ b/examples/responses-function-call/Cargo.toml @@ -9,3 +9,5 @@ async-openai = {path = "../../async-openai"} serde_json = "1.0.135" tokio = { version = "1.43.0", features = ["full"] } serde = { version = "1.0.219", features = ["derive"] } +clap = { version = "4", features = ["derive"] } +futures = "0.3" diff --git a/examples/responses-function-call/src/main.rs b/examples/responses-function-call/src/main.rs index 0dcfc3e2..87816261 100644 --- a/examples/responses-function-call/src/main.rs +++ b/examples/responses-function-call/src/main.rs @@ -1,13 +1,18 @@ use async_openai::{ + traits::EventType, types::responses::{ - CreateResponseArgs, EasyInputContent, EasyInputMessage, FunctionCallOutput, - FunctionCallOutputItemParam, FunctionTool, FunctionToolCall, InputItem, InputParam, Item, - MessageType, OutputItem, Role, Tool, + CreateResponseArgs, EasyInputMessage, FunctionCallOutput, FunctionCallOutputItemParam, + FunctionTool, FunctionToolCall, InputItem, InputParam, Item, OutputItem, + ResponseStreamEvent, Tool, }, Client, }; +use clap::Parser; +use futures::StreamExt; use serde::Deserialize; +use std::collections::HashMap; use std::error::Error; +use std::io::{stdout, Write}; #[derive(Debug, Deserialize)] struct WeatherFunctionArgs { @@ -19,8 +24,27 @@ fn check_weather(location: String, units: String) -> String { format!("The weather in {location} is 25 {units}") } -#[tokio::main] -async fn main() -> Result<(), Box> { +#[derive(Parser, Debug)] +#[command(name = "responses-function-call")] +#[command(about = "Example demonstrating function calls with the Responses API")] +struct Args { + #[command(subcommand)] + command: Command, +} + +#[derive(clap::Subcommand, Debug)] +enum Command { + /// Run non-streaming function call example + NonStreaming, + /// Run streaming function call example + 
Streaming, + /// Run both streaming and non-streaming examples + All, +} + +async fn run_non_streaming() -> Result<(), Box> { + println!("=== Non-Streaming Function Call Example ===\n"); + let client = Client::new(); let tools = vec![Tool::Function(FunctionTool { @@ -53,20 +77,18 @@ async fn main() -> Result<(), Box> { strict: None, })]; - let mut input_messages = vec![InputItem::EasyMessage(EasyInputMessage { - r#type: MessageType::Message, - role: Role::User, - content: EasyInputContent::Text("What's the weather like in Paris today?".to_string()), - })]; + let mut input_items: Vec = + vec![EasyInputMessage::from("What's the weather like in Paris today?").into()]; let request = CreateResponseArgs::default() .max_output_tokens(512u32) .model("gpt-4.1") - .input(InputParam::Items(input_messages.clone())) + .input(InputParam::Items(input_items.clone())) .tools(tools.clone()) .build()?; - println!("{}", serde_json::to_string(&request).unwrap()); + println!("Request: {}", serde_json::to_string(&request)?); + println!("\n---\n"); let response = client.responses().create(request).await?; @@ -85,6 +107,11 @@ async fn main() -> Result<(), Box> { return Ok(()); }; + println!( + "Function call requested: {} with arguments: {}", + function_call_request.name, function_call_request.arguments + ); + let function_result = match function_call_request.name.as_str() { "get_weather" => { let args: WeatherFunctionArgs = serde_json::from_str(&function_call_request.arguments)?; @@ -96,13 +123,15 @@ async fn main() -> Result<(), Box> { } }; + println!("Function result: {}\n", function_result); + // Add the function call from the assistant back to the conversation - input_messages.push(InputItem::Item(Item::FunctionCall( + input_items.push(InputItem::Item(Item::FunctionCall( function_call_request.clone(), ))); // Add the function call output back to the conversation - input_messages.push(InputItem::Item(Item::FunctionCallOutput( + input_items.push(InputItem::Item(Item::FunctionCallOutput( 
FunctionCallOutputItemParam { call_id: function_call_request.call_id.clone(), output: FunctionCallOutput::Text(function_result), @@ -114,15 +143,229 @@ async fn main() -> Result<(), Box> { let request = CreateResponseArgs::default() .max_output_tokens(512u32) .model("gpt-4.1") - .input(InputParam::Items(input_messages)) + .input(InputParam::Items(input_items)) .tools(tools) .build()?; - println!("request 2 {}", serde_json::to_string(&request).unwrap()); + println!("Second request: {}", serde_json::to_string(&request)?); + println!("\n---\n"); let response = client.responses().create(request).await?; - println!("{}", serde_json::to_string(&response).unwrap()); + println!("Final response: {}", serde_json::to_string(&response)?); + + Ok(()) +} + +async fn run_streaming() -> Result<(), Box> { + println!("=== Streaming Function Call Example ===\n"); + + let client = Client::new(); + + let tools = vec![Tool::Function(FunctionTool { + name: "get_weather".to_string(), + description: Some("Retrieves current weather for the given location".to_string()), + parameters: Some(serde_json::json!( + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "City and country e.g. Bogotá, Colombia" + }, + "units": { + "type": "string", + "enum": [ + "celsius", + "fahrenheit" + ], + "description": "Units the temperature will be returned in." 
+ } + }, + "required": [ + "location", + "units" + ], + "additionalProperties": false + } + )), + strict: None, + })]; + + let mut input_items: Vec = + vec![EasyInputMessage::from("What's the weather like in Paris today?").into()]; + + let request = CreateResponseArgs::default() + .max_output_tokens(512u32) + .model("gpt-4.1") + .stream(true) + .input(InputParam::Items(input_items.clone())) + .tools(tools.clone()) + .build()?; + + println!("Request: {}", serde_json::to_string(&request)?); + println!("\n---\n"); + + let mut stream = client.responses().create_stream(request).await?; + + // Track function call arguments as they stream in + let mut function_call_args: HashMap = HashMap::new(); + // Track function call metadata (name, call_id) by item_id + let mut function_call_metadata: HashMap = HashMap::new(); + let mut function_call_request: Option = None; + let mut stdout_lock = stdout().lock(); + + while let Some(result) = stream.next().await { + match result { + Ok(event) => { + match &event { + ResponseStreamEvent::ResponseOutputItemAdded(added) => { + // When a function call item is added, extract the call_id + if let OutputItem::FunctionCall(fc) = &added.item { + let item_id = fc.id.clone().unwrap_or_default(); + function_call_metadata + .insert(item_id.clone(), (fc.name.clone(), fc.call_id.clone())); + writeln!(stdout_lock, "{}: {}\n", added.event_type(), fc.name)?; + } + } + ResponseStreamEvent::ResponseFunctionCallArgumentsDelta(delta) => { + // Accumulate function call arguments + let args = function_call_args + .entry(delta.item_id.clone()) + .or_insert_with(String::new); + args.push_str(&delta.delta); + write!(stdout_lock, "{}: {}\n", delta.event_type(), delta.delta)?; + stdout().flush()?; + } + ResponseStreamEvent::ResponseFunctionCallArgumentsDone(done) => { + // Function call arguments are complete + if let Some((name, call_id)) = function_call_metadata.get(&done.item_id) { + let arguments = function_call_args + .remove(&done.item_id) + 
.unwrap_or_else(|| done.arguments.clone()); + + writeln!( + stdout_lock, + "{}: [Function call complete: {}]", + done.event_type(), + name + )?; + writeln!( + stdout_lock, + "{}: Arguments: {}\n", + done.event_type(), + arguments + )?; + + // Create the function call request + function_call_request = Some(FunctionToolCall { + name: name.clone(), + arguments: arguments, + call_id: call_id.clone(), + id: Some(done.item_id.clone()), + status: None, + }); + } + } + ResponseStreamEvent::ResponseOutputTextDelta(delta) => { + write!(stdout_lock, "{}: {}\n", delta.event_type(), delta.delta)?; + stdout().flush()?; + } + _ => { + writeln!(stdout_lock, "{}: skipping\n", event.event_type())?; + } + } + } + Err(e) => { + writeln!(stdout_lock, "\nError: {:?}", e)?; + return Err(Box::new(e)); + } + } + } + + // Execute the function call if we have one + let Some(function_call_request) = function_call_request else { + println!("\nNo function_call request found"); + return Ok(()); + }; + + println!("\n---\n"); + + let function_result = match function_call_request.name.as_str() { + "get_weather" => { + let args: WeatherFunctionArgs = serde_json::from_str(&function_call_request.arguments)?; + check_weather(args.location, args.units) + } + _ => { + println!("Unknown function {}", function_call_request.name); + return Ok(()); + } + }; + + println!("Function result: {}\n", function_result); + + // Add the function call from the assistant back to the conversation + input_items.push(InputItem::Item(Item::FunctionCall( + function_call_request.clone(), + ))); + + // Add the function call output back to the conversation + input_items.push(InputItem::Item(Item::FunctionCallOutput( + FunctionCallOutputItemParam { + call_id: function_call_request.call_id.clone(), + output: FunctionCallOutput::Text(function_result), + id: None, + status: None, + }, + ))); + + let request = CreateResponseArgs::default() + .max_output_tokens(512u32) + .model("gpt-4.1") + .stream(true) + 
.input(InputParam::Items(input_items)) + .tools(tools) + .build()?; + + println!("Second request: {}", serde_json::to_string(&request)?); + println!("\n---\n"); + println!("Final response (streaming):\n"); + + let mut stream = client.responses().create_stream(request).await?; + let mut stdout_lock = stdout().lock(); + + while let Some(result) = stream.next().await { + match result { + Ok(event) => match &event { + ResponseStreamEvent::ResponseOutputTextDelta(delta) => { + write!(stdout_lock, "{}: {}\n", delta.event_type(), delta.delta)?; + stdout().flush()?; + } + _ => { + writeln!(stdout_lock, "{}: skipping\n", event.event_type())?; + } + }, + Err(e) => { + writeln!(stdout_lock, "\nError: {:?}", e)?; + return Err(Box::new(e)); + } + } + } Ok(()) } + +#[tokio::main] +async fn main() -> Result<(), Box> { + let args = Args::parse(); + + match args.command { + Command::NonStreaming => run_non_streaming().await, + Command::Streaming => run_streaming().await, + Command::All => { + run_non_streaming().await?; + println!("\n\n"); + run_streaming().await + } + } +} diff --git a/examples/responses-images-and-vision/Cargo.toml b/examples/responses-images-and-vision/Cargo.toml new file mode 100644 index 00000000..5f01cf38 --- /dev/null +++ b/examples/responses-images-and-vision/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "responses-images-and-vision" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +async-openai = { path = "../../async-openai" } +tokio = { version = "1.0", features = ["full"] } +futures = "0.3" +base64 = "0.22.1" +serde_json = "1.0" \ No newline at end of file diff --git a/examples/responses-images-and-vision/README.md b/examples/responses-images-and-vision/README.md new file mode 100644 index 00000000..ee580d8f --- /dev/null +++ b/examples/responses-images-and-vision/README.md @@ -0,0 +1,5 @@ +## Overview + +This example exercises as many Responses API capabilities as possible. + +Image Credit: 
https://unsplash.com/photos/pride-of-lion-on-field-L4-BDd01wmM \ No newline at end of file diff --git a/examples/responses-images-and-vision/src/main.rs b/examples/responses-images-and-vision/src/main.rs new file mode 100644 index 00000000..e6d7043b --- /dev/null +++ b/examples/responses-images-and-vision/src/main.rs @@ -0,0 +1,113 @@ +use std::error::Error; + +use async_openai::{ + config::OpenAIConfig, + types::{ + chat::ImageDetail, + responses::{ + CreateResponseArgs, ImageGenTool, InputContent, InputImageContent, InputMessage, + InputRole, OutputItem, OutputMessageContent, + }, + }, + Client, +}; + +use base64::{engine::general_purpose::STANDARD as BASE64_STANDARD, Engine as _}; +use std::fs::OpenOptions; +use std::io::Write; + +async fn analyze_image_url(client: &Client) -> Result<(), Box> { + let image_url = + "https://images.unsplash.com/photo-1554990772-0bea55d510d5?q=80&w=512&auto=format"; + let request = CreateResponseArgs::default() + .model("gpt-4.1-mini") + .input(InputMessage { + content: vec![ + "what is in this image? Along with count of objects in the image?".into(), + InputContent::InputImage(InputImageContent { + detail: ImageDetail::Auto, + image_url: Some(image_url.to_string()), + file_id: None, + }), + ], + role: InputRole::User, + status: None, + }) + .build()?; + + println!( + "analyze_image_url request:\n{}", + serde_json::to_string(&request)? 
+ ); + + let response = client.responses().create(request).await?; + + for output in response.output { + match output { + OutputItem::Message(message) => { + for content in message.content { + match content { + OutputMessageContent::OutputText(text) => { + println!("Text: {:?}", text.text); + } + OutputMessageContent::Refusal(refusal) => { + println!("Refusal: {:?}", refusal.refusal); + } + } + } + } + _ => println!("Other output: {:?}", output), + } + } + + Ok(()) +} + +async fn generate_image(client: &Client) -> Result<(), Box> { + let request = CreateResponseArgs::default() + .model("gpt-4.1-mini") + .input("Generate an image of gray tabby cat hugging an otter with an orange scarf") + .tools(ImageGenTool::default()) + .build()?; + + println!( + "generate_image request:\n{}", + serde_json::to_string(&request)? + ); + + let response = client.responses().create(request).await?; + + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open("./data/image.png")?; + + for output in response.output { + match output { + OutputItem::ImageGenerationCall(image_gen_call) => { + if let Some(result) = image_gen_call.result { + println!("Image generation call has result"); + let decoded = BASE64_STANDARD.decode(&result)?; + file.write_all(&decoded)?; + } else { + println!("Image generation call has no result"); + } + } + _ => println!("Other output: {:?}", output), + } + } + + Ok(()) +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let client = Client::new(); + + std::fs::create_dir_all("./data")?; + + generate_image(&client).await?; + analyze_image_url(&client).await?; + + Ok(()) +} diff --git a/examples/responses-stream/Cargo.toml b/examples/responses-stream/Cargo.toml index 82eb90a7..5a276ba0 100644 --- a/examples/responses-stream/Cargo.toml +++ b/examples/responses-stream/Cargo.toml @@ -2,6 +2,7 @@ name = "responses-stream" version = "0.1.0" edition = "2024" +publish = false [dependencies] async-openai = { path = "../../async-openai" } diff --git 
a/examples/responses-stream/src/main.rs b/examples/responses-stream/src/main.rs index 37be90c6..c01b0bd6 100644 --- a/examples/responses-stream/src/main.rs +++ b/examples/responses-stream/src/main.rs @@ -1,11 +1,10 @@ use async_openai::{ Client, - types::responses::{ - CreateResponseArgs, EasyInputContent, EasyInputMessage, InputItem, InputParam, MessageType, - ResponseStreamEvent, Role, - }, + traits::EventType, + types::responses::{CreateResponseArgs, ResponseStreamEvent}, }; use futures::StreamExt; +use std::io::{Write, stdout}; #[tokio::main] async fn main() -> Result<(), Box> { @@ -14,36 +13,28 @@ async fn main() -> Result<(), Box> { let request = CreateResponseArgs::default() .model("gpt-4.1") .stream(true) - .input(InputParam::Items(vec![InputItem::EasyMessage( - EasyInputMessage { - r#type: MessageType::Message, - role: Role::User, - content: EasyInputContent::Text("Write a haiku about programming.".to_string()), - }, - )])) + .input("Write a haiku about programming.") .build()?; let mut stream = client.responses().create_stream(request).await?; + let mut lock = stdout().lock(); + while let Some(result) = stream.next().await { match result { Ok(response_event) => match &response_event { ResponseStreamEvent::ResponseOutputTextDelta(delta) => { - print!("{}", delta.delta); - } - ResponseStreamEvent::ResponseCompleted(_) - | ResponseStreamEvent::ResponseIncomplete(_) - | ResponseStreamEvent::ResponseFailed(_) => { - break; + write!(lock, "{}", delta.delta)?; } _ => { - println!("{response_event:#?}"); + writeln!(lock, "\n{}: skipping\n", response_event.event_type())?; } }, Err(e) => { - eprintln!("{e:#?}"); + eprintln!("\n{e:#?}"); } } + stdout().flush()?; } Ok(()) diff --git a/examples/responses-structured-outputs/Cargo.toml b/examples/responses-structured-outputs/Cargo.toml new file mode 100644 index 00000000..869d3436 --- /dev/null +++ b/examples/responses-structured-outputs/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "responses-structured-outputs" 
+version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +async-openai = { path = "../../async-openai" } +serde_json = "1.0" +tokio = { version = "1", features = ["full"] } +clap = { version = "4", features = ["derive"] } +futures = "0.3" diff --git a/examples/responses-structured-outputs/src/main.rs b/examples/responses-structured-outputs/src/main.rs new file mode 100644 index 00000000..eb7fe0f7 --- /dev/null +++ b/examples/responses-structured-outputs/src/main.rs @@ -0,0 +1,456 @@ +use std::error::Error; + +use async_openai::{ + config::OpenAIConfig, + traits::EventType, + types::{ + chat::ResponseFormatJsonSchema, + responses::{ + CreateResponseArgs, InputMessage, InputRole, OutputItem, OutputMessageContent, + ResponseStreamEvent, + }, + }, + Client, +}; +use clap::Parser; +use futures::StreamExt; +use serde_json::json; +use std::io::{stdout, Write}; + +/// Chain of thought example: Guides the model through step-by-step reasoning +async fn chain_of_thought(client: &Client) -> Result<(), Box> { + println!("=== Chain of Thought Example ===\n"); + + let schema = json!({ + "type": "object", + "properties": { + "steps": { + "type": "array", + "items": { + "type": "object", + "properties": { + "explanation": { "type": "string" }, + "output": { "type": "string" } + }, + "required": ["explanation", "output"], + "additionalProperties": false + } + }, + "final_answer": { "type": "string" } + }, + "required": ["steps", "final_answer"], + "additionalProperties": false + }); + + let request = CreateResponseArgs::default() + .model("gpt-4o-2024-08-06") + .max_output_tokens(512u32) + .text(ResponseFormatJsonSchema { + description: Some( + "A step-by-step reasoning process for solving math problems".to_string(), + ), + name: "math_reasoning".to_string(), + schema: Some(schema), + strict: Some(true), + }) + .input(vec![ + InputMessage { + role: InputRole::System, + content: vec![ + "You are a helpful math tutor. Guide the user through the solution step by step." 
+ .into(), + ], + status: None, + }, + InputMessage { + role: InputRole::User, + content: vec!["How can I solve 8x + 7 = -23?".into()], + status: None, + }, + ]) + .build()?; + + let response = client.responses().create(request).await?; + + for output in response.output { + if let OutputItem::Message(message) = output { + for content in message.content { + if let OutputMessageContent::OutputText(text) = content { + println!("Response:\n{}\n", text.text); + } + } + } + } + + Ok(()) +} + +/// Structured data extraction example: Extracts specific fields from unstructured text +async fn structured_data_extraction(client: &Client) -> Result<(), Box> { + println!("=== Structured Data Extraction Example ===\n"); + + let schema = json!({ + "type": "object", + "properties": { + "name": { "type": "string" }, + "age": { "type": "integer" }, + "occupation": { "type": "string" }, + "location": { "type": "string" }, + "email": { "type": "string" } + }, + "required": ["name", "age", "occupation", "email", "location"], + "additionalProperties": false + }); + + let text = "Hi, I'm Sarah Johnson. I'm 28 years old and I work as a software engineer in San Francisco. You can reach me at sarah.johnson@email.com."; + + let request = CreateResponseArgs::default() + .model("gpt-4o-2024-08-06") + .max_output_tokens(256u32) + .text(ResponseFormatJsonSchema { + description: Some("Extract structured information from text".to_string()), + name: "person_info".to_string(), + schema: Some(schema), + strict: Some(true), + }) + .input(vec![ + InputMessage { + role: InputRole::System, + content: vec!["Extract the following information from the user's text: name, age, occupation, location, and email. 
If any information is not present, omit that field.".into()], + status: None, + }, + InputMessage { + role: InputRole::User, + content: vec![text.into()], + status: None, + }, + ]) + .build()?; + + let response = client.responses().create(request).await?; + + println!("Input text: {}\n", text); + for output in response.output { + if let OutputItem::Message(message) = output { + for content in message.content { + if let OutputMessageContent::OutputText(text) = content { + println!("Extracted data:\n{}\n", text.text); + } + } + } + } + + Ok(()) +} + +/// UI generation example: Generates UI component code based on description +async fn ui_generation(client: &Client) -> Result<(), Box> { + println!("=== UI Generation Example ===\n"); + + let schema = json!({ + "type": "object", + "properties": { + "type": { + "type": "string", + "description": "The type of the UI component", + "enum": ["div", "button", "header", "section", "field", "form"] + }, + "label": { + "type": "string", + "description": "The label of the UI component, used for buttons or form fields" + }, + "children": { + "type": "array", + "description": "Nested UI components", + "items": {"$ref": "#"} + }, + "attributes": { + "type": "array", + "description": "Arbitrary attributes for the UI component, suitable for any element", + "items": { + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "The name of the attribute, for example onClick or className" + }, + "value": { + "type": "string", + "description": "The value of the attribute" + } + }, + "required": ["name", "value"], + "additionalProperties": false + } + } + }, + "required": ["type", "label", "children", "attributes"], + "additionalProperties": false + + }); + + let request = CreateResponseArgs::default() + .model("gpt-4o-2024-08-06") + .max_output_tokens(1024u32) + .text(ResponseFormatJsonSchema { + description: Some("Generate HTML and CSS code for UI components".to_string()), + name: "ui_component".to_string(), + 
schema: Some(schema), + strict: Some(true), + }) + .input(vec![ + InputMessage { + role: InputRole::System, + content: vec!["You are a UI designer. Generate clean, modern HTML and CSS code for the requested UI component. The HTML should be semantic and accessible, and the CSS should be well-organized.".into()], + status: None, + }, + InputMessage { + role: InputRole::User, + content: vec!["Create a login form with email and password fields, a submit button, and a link to sign up. Use a modern, clean design with a light color scheme.".into()], + status: None, + }, + ]) + .build()?; + + let response = client.responses().create(request).await?; + + for output in response.output { + if let OutputItem::Message(message) = output { + for content in message.content { + if let OutputMessageContent::OutputText(text) = content { + println!("Generated UI Component:\n{}\n", text.text); + } + } + } + } + + Ok(()) +} + +/// Moderation example: Analyzes content for policy violations with structured output +async fn moderation(client: &Client) -> Result<(), Box> { + println!("=== Moderation Example ===\n"); + + let schema = json!({ + "type": "object", + "properties": { + "is_violating": { + "type": "boolean", + "description": "Indicates if the content is violating guidelines" + }, + "category": { + "type": ["string", "null"], + "description": "Type of violation, if the content is violating guidelines. 
Null otherwise.", + "enum": ["violence", "sexual", "self_harm"] + }, + "explanation_if_violating": { + "type": ["string", "null"], + "description": "Explanation of why the content is violating" + } + }, + "required": ["is_violating", "category", "explanation_if_violating"], + "additionalProperties": false + }); + + let content_to_moderate = "How do I prepare for a job interview?"; + + let request = CreateResponseArgs::default() + .model("gpt-4o-2024-08-06") + .max_output_tokens(256u32) + .text(ResponseFormatJsonSchema { + description: Some("Analyze content for policy violations and provide structured moderation results".to_string()), + name: "moderation_result".to_string(), + schema: Some(schema), + strict: Some(true), + }) + .input(vec![ + InputMessage { + role: InputRole::System, + content: vec!["Determine if the user input violates specific guidelines and explain if they do.".into()], + status: None, + }, + InputMessage { + role: InputRole::User, + content: vec![content_to_moderate.into()], + status: None, + }, + ]) + .build()?; + + let response = client.responses().create(request).await?; + + for output in response.output { + if let OutputItem::Message(message) = output { + for content in message.content { + if let OutputMessageContent::OutputText(text) = content { + println!("Content analyzed: {}\n", content_to_moderate); + println!("Moderation result:\n{}\n", text.text); + } + } + } + } + + Ok(()) +} + +/// Streaming structured output example: Extract entities from text with streaming +async fn streaming_structured_output(client: &Client) -> Result<(), Box> { + println!("=== Streaming Structured Output Example ===\n"); + + let schema = json!({ + "type": "object", + "properties": { + "attributes": { + "type": "array", + "items": { "type": "string" } + }, + "colors": { + "type": "array", + "items": { "type": "string" } + }, + "animals": { + "type": "array", + "items": { "type": "string" } + } + }, + "required": ["attributes", "colors", "animals"], + 
"additionalProperties": false + }); + + let request = CreateResponseArgs::default() + .model("gpt-4.1") + .stream(true) + .text(ResponseFormatJsonSchema { + description: Some("Extract entities from the input text".to_string()), + name: "entities".to_string(), + schema: Some(schema), + strict: Some(true), + }) + .input(vec![ + InputMessage { + role: InputRole::System, + content: vec!["Extract entities from the input text".into()], + status: None, + }, + InputMessage { + role: InputRole::User, + content: vec![ + "The quick brown fox jumps over the lazy dog with piercing blue eyes".into(), + ], + status: None, + }, + ]) + .build()?; + + let mut stream = client.responses().create_stream(request).await?; + let mut lock = stdout().lock(); + let mut final_response = None; + + while let Some(result) = stream.next().await { + match result { + Ok(event) => match event { + ResponseStreamEvent::ResponseRefusalDelta(delta) => { + write!(lock, "{}", delta.delta)?; + lock.flush()?; + } + ResponseStreamEvent::ResponseOutputTextDelta(delta) => { + write!(lock, "{}", delta.delta)?; + lock.flush()?; + } + ResponseStreamEvent::ResponseError(error) => { + writeln!(lock, "\nError: {}", error.message)?; + if let Some(code) = &error.code { + writeln!(lock, "Code: {}", code)?; + } + if let Some(param) = &error.param { + writeln!(lock, "Param: {}", param)?; + } + } + ResponseStreamEvent::ResponseCompleted(completed) => { + writeln!(lock, "\nCompleted")?; + final_response = Some(completed.response); + break; + } + _ => { + writeln!(lock, "\n{}: skipping\n", event.event_type())?; + } + }, + Err(e) => { + writeln!(lock, "\nStream error: {:#?}", e)?; + } + } + } + + if let Some(response) = final_response { + writeln!(lock, "\nFinal response:")?; + for output in response.output { + if let OutputItem::Message(message) = output { + for content in message.content { + if let OutputMessageContent::OutputText(text) = content { + writeln!(lock, "{}", text.text)?; + } + } + } + } + } + + Ok(()) +} + 
+#[derive(Parser, Debug)] +#[command(name = "responses-structured-outputs")] +#[command(about = "Examples of structured outputs using the Responses API", long_about = None)] +struct Cli { + /// Which example to run + #[arg(value_enum)] + example: Example, +} + +#[derive(clap::ValueEnum, Clone, Debug)] +enum Example { + /// Chain of thought: Step-by-step reasoning for math problems + ChainOfThought, + /// Structured data extraction: Extract fields from unstructured text + DataExtraction, + /// UI generation: Generate HTML and CSS for UI components + UiGeneration, + /// Moderation: Analyze content for policy violations + Moderation, + /// Streaming structured output: Extract entities with streaming + Streaming, + /// Run all examples + All, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let cli = Cli::parse(); + let client = Client::new(); + + match cli.example { + Example::ChainOfThought => { + chain_of_thought(&client).await?; + } + Example::DataExtraction => { + structured_data_extraction(&client).await?; + } + Example::UiGeneration => { + ui_generation(&client).await?; + } + Example::Moderation => { + moderation(&client).await?; + } + Example::Streaming => { + streaming_structured_output(&client).await?; + } + Example::All => { + chain_of_thought(&client).await?; + structured_data_extraction(&client).await?; + ui_generation(&client).await?; + moderation(&client).await?; + streaming_structured_output(&client).await?; + } + } + + Ok(()) +} diff --git a/examples/responses/src/main.rs b/examples/responses/src/main.rs index 0ee2b16c..e96082f4 100644 --- a/examples/responses/src/main.rs +++ b/examples/responses/src/main.rs @@ -3,11 +3,10 @@ use std::error::Error; use async_openai::{ types::{ responses::{ - CreateResponseArgs, EasyInputContent, EasyInputMessage, InputItem, InputParam, - MessageType, ResponseTextParam, Role, TextResponseFormatConfiguration, Tool, Verbosity, - WebSearchToolArgs, + CreateResponseArgs, ResponseTextParam, 
TextResponseFormatConfiguration, Tool, + Verbosity, WebSearchTool, }, - MCPToolAllowedTools, MCPToolApprovalSetting, MCPToolArgs, MCPToolRequireApproval, + MCPToolApprovalSetting, MCPToolArgs, }, Client, }; @@ -23,30 +22,27 @@ async fn main() -> Result<(), Box> { format: TextResponseFormatConfiguration::Text, verbosity: Some(Verbosity::Medium), // only here to test the config, but gpt-4.1 only supports medium }) - .input(InputParam::Items(vec![InputItem::EasyMessage( - EasyInputMessage { - r#type: MessageType::Message, - role: Role::User, - content: EasyInputContent::Text("What transport protocols does the 2025-03-26 version of the MCP spec (modelcontextprotocol/modelcontextprotocol) support?".to_string()), - } - )])) + .input([ + "What transport protocols does the 2025-03-26 version of the MCP spec (modelcontextprotocol/modelcontextprotocol) support?", + "what is MCP?" + ]) .tools(vec![ - Tool::WebSearchPreview(WebSearchToolArgs::default().build()?), + Tool::WebSearchPreview(WebSearchTool::default()), Tool::Mcp(MCPToolArgs::default() .server_label("deepwiki") .server_url("https://mcp.deepwiki.com/mcp") - .require_approval(MCPToolRequireApproval::ApprovalSetting(MCPToolApprovalSetting::Never)) - .allowed_tools(MCPToolAllowedTools::List(vec!["ask_question".to_string()])) + .require_approval(MCPToolApprovalSetting::Never) + .allowed_tools(["ask_question"]) .build()?), ]) .build()?; - println!("{}", serde_json::to_string(&request).unwrap()); + println!("Request:\n{}", serde_json::to_string(&request).unwrap()); let response = client.responses().create(request).await?; - let output_text = response.output_text().unwrap_or("Empty text output".into()); - println!("\nOutput Text: {output_text:?}\n",); + println!("\n SDK: output_text()\n: {:?}", response.output_text()); + for output in response.output { println!("\nOutput: {:?}\n", output); } diff --git a/examples/webhooks/Cargo.toml b/examples/webhooks/Cargo.toml index 9a44206f..fb004089 100644 --- 
a/examples/webhooks/Cargo.toml +++ b/examples/webhooks/Cargo.toml @@ -2,6 +2,7 @@ name = "webhooks" version = "0.1.0" edition = "2021" +publish = false [dependencies] async-openai = { path = "../../async-openai", features = ["webhook"] }