diff --git a/async-openai/README.md b/async-openai/README.md index 471fc1b9..e723e3aa 100644 --- a/async-openai/README.md +++ b/async-openai/README.md @@ -23,24 +23,28 @@ - It's based on [OpenAI OpenAPI spec](https://github.com/openai/openai-openapi) - Current features: - - [x] Assistants (v2) + - [x] Administration (partially implemented) + - [x] Assistants (beta) - [x] Audio - [x] Batch - [x] Chat - - [x] Completions (Legacy) + - [x] ChatKit (beta) + - [x] Completions (legacy) - [x] Conversations - - [x] Containers | Container Files + - [x] Containers - [x] Embeddings + - [x] Evals - [x] Files - [x] Fine-Tuning - [x] Images - [x] Models - [x] Moderations - - [x] Organizations | Administration (partially implemented) - - [x] Realtime GA (partially implemented) + - [x] Realtime (partially implemented) - [x] Responses - [x] Uploads + - [x] Vector Stores - [x] Videos + - [x] Webhooks - Bring your own custom types for Request or Response objects. - SSE streaming on available APIs - Requests (except SSE streaming) including form submissions are retried with exponential backoff when [rate limited](https://platform.openai.com/docs/guides/rate-limits). diff --git a/async-openai/src/chatkit.rs b/async-openai/src/chatkit.rs new file mode 100644 index 00000000..3eb795fb --- /dev/null +++ b/async-openai/src/chatkit.rs @@ -0,0 +1,116 @@ +use serde::Serialize; + +use crate::{ + config::Config, + error::OpenAIError, + types::chatkit::{ + ChatSessionResource, CreateChatSessionBody, DeletedThreadResource, ThreadItemListResource, + ThreadListResource, ThreadResource, + }, + Client, +}; + +/// ChatKit API for managing sessions and threads. +/// +/// Related guide: [ChatKit](https://platform.openai.com/docs/api-reference/chatkit) +pub struct Chatkit<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> Chatkit<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Access sessions API. + pub fn sessions(&self) -> ChatkitSessions<'_, C> { + ChatkitSessions::new(self.client) + } + + /// Access threads API. + pub fn threads(&self) -> ChatkitThreads<'_, C> { + ChatkitThreads::new(self.client) + } +} + +/// ChatKit sessions API. +pub struct ChatkitSessions<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> ChatkitSessions<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// Create a ChatKit session. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn create( + &self, + request: CreateChatSessionBody, + ) -> Result { + self.client.post("/chatkit/sessions", request).await + } + + /// Cancel a ChatKit session. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn cancel(&self, session_id: &str) -> Result { + self.client + .post( + &format!("/chatkit/sessions/{session_id}/cancel"), + serde_json::json!({}), + ) + .await + } +} + +/// ChatKit threads API. +pub struct ChatkitThreads<'c, C: Config> { + client: &'c Client, +} + +impl<'c, C: Config> ChatkitThreads<'c, C> { + pub fn new(client: &'c Client) -> Self { + Self { client } + } + + /// List ChatKit threads. + #[crate::byot(T0 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list(&self, query: &Q) -> Result + where + Q: Serialize + ?Sized, + { + self.client.get_with_query("/chatkit/threads", &query).await + } + + /// Retrieve a ChatKit thread. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn retrieve(&self, thread_id: &str) -> Result { + self.client + .get(&format!("/chatkit/threads/{thread_id}")) + .await + } + + /// Delete a ChatKit thread. + #[crate::byot(T0 = std::fmt::Display, R = serde::de::DeserializeOwned)] + pub async fn delete(&self, thread_id: &str) -> Result { + self.client + .delete(&format!("/chatkit/threads/{thread_id}")) + .await + } + + /// List ChatKit thread items. + #[crate::byot(T0 = std::fmt::Display, T1 = serde::Serialize, R = serde::de::DeserializeOwned)] + pub async fn list_items( + &self, + thread_id: &str, + query: &Q, + ) -> Result + where + Q: Serialize + ?Sized, + { + self.client + .get_with_query(&format!("/chatkit/threads/{thread_id}/items"), &query) + .await + } +} diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index dadef6a5..b78c72fb 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -7,6 +7,7 @@ use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, Request use serde::{de::DeserializeOwned, Serialize}; use crate::{ + chatkit::Chatkit, config::{Config, OpenAIConfig}, error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError}, file::Files, @@ -188,6 +189,10 @@ impl Client { Evals::new(self) } + pub fn chatkit(&self) -> Chatkit<'_, C> { + Chatkit::new(self) + } + pub fn config(&self) -> &C { &self.config } diff --git a/async-openai/src/config.rs b/async-openai/src/config.rs index 294c8780..d694c891 100644 --- a/async-openai/src/config.rs +++ b/async-openai/src/config.rs @@ -3,6 +3,8 @@ use reqwest::header::{HeaderMap, AUTHORIZATION}; use secrecy::{ExposeSecret, SecretString}; use serde::Deserialize; +use crate::error::OpenAIError; + /// Default v1 API base url pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1"; /// Organization header @@ -59,6 +61,8 @@ pub struct OpenAIConfig { api_key: SecretString, org_id: String, project_id: String, + #[serde(skip)] + custom_headers: HeaderMap, } impl Default for OpenAIConfig { @@ -70,6 +74,7 @@ impl Default for OpenAIConfig { .into(), org_id: Default::default(), project_id: Default::default(), + custom_headers: HeaderMap::new(), } } } @@ -104,6 +109,21 @@ impl OpenAIConfig { self } + /// Add a custom header that will be included in all requests. + /// Headers are merged with existing headers, with custom headers taking precedence. + pub fn with_header(mut self, key: K, value: V) -> Result + where + K: reqwest::header::IntoHeaderName, + V: TryInto, + V::Error: Into, + { + let header_value = value.try_into().map_err(|e| { + OpenAIError::InvalidArgument(format!("Invalid header value: {}", e.into())) + })?; + self.custom_headers.insert(key, header_value); + Ok(self) + } + pub fn org_id(&self) -> &str { &self.org_id } @@ -134,9 +154,10 @@ impl Config for OpenAIConfig { .unwrap(), ); - // hack for Assistants APIs - // Calls to the Assistants API require that you pass a Beta header - // headers.insert(OPENAI_BETA_HEADER, "assistants=v2".parse().unwrap()); + // Merge custom headers, with custom headers taking precedence + for (key, value) in self.custom_headers.iter() { + headers.insert(key, value.clone()); + } headers } diff --git a/async-openai/src/lib.rs b/async-openai/src/lib.rs index b58904ae..f570ba63 100644 --- a/async-openai/src/lib.rs +++ b/async-openai/src/lib.rs @@ -145,6 +145,7 @@ mod audio; mod audit_logs; mod batches; mod chat; +mod chatkit; mod client; mod completion; pub mod config; @@ -193,6 +194,7 @@ pub use audio::Audio; pub use audit_logs::AuditLogs; pub use batches::Batches; pub use chat::Chat; +pub use chatkit::Chatkit; pub use client::Client; pub use completion::Completions; pub use container_files::ContainerFiles; diff --git a/async-openai/src/types/chatkit/mod.rs b/async-openai/src/types/chatkit/mod.rs new file mode 100644 index 00000000..ad660ea1 --- /dev/null +++ b/async-openai/src/types/chatkit/mod.rs @@ -0,0 +1,5 @@ +mod session; +mod thread; + +pub use session::*; +pub use thread::*; diff --git a/async-openai/src/types/chatkit/session.rs b/async-openai/src/types/chatkit/session.rs new file mode 100644 index 00000000..4eed1dcf --- /dev/null +++ b/async-openai/src/types/chatkit/session.rs @@ -0,0 +1,272 @@ +use std::collections::HashMap; + +use derive_builder::Builder; +use serde::{Deserialize, Serialize}; + +use crate::error::OpenAIError; + +/// Represents a ChatKit session and its resolved configuration. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ChatSessionResource { + /// Identifier for the ChatKit session. + pub id: String, + /// Type discriminator that is always `chatkit.session`. + #[serde(default = "default_session_object")] + pub object: String, + /// Unix timestamp (in seconds) for when the session expires. + pub expires_at: i64, + /// Ephemeral client secret that authenticates session requests. + pub client_secret: String, + /// Workflow metadata for the session. + pub workflow: ChatkitWorkflow, + /// User identifier associated with the session. + pub user: String, + /// Resolved rate limit values. + pub rate_limits: ChatSessionRateLimits, + /// Convenience copy of the per-minute request limit. + pub max_requests_per_1_minute: i32, + /// Current lifecycle state of the session. + pub status: ChatSessionStatus, + /// Resolved ChatKit feature configuration for the session. + pub chatkit_configuration: ChatSessionChatkitConfiguration, +} + +fn default_session_object() -> String { + "chatkit.session".to_string() +} + +/// Workflow metadata and state returned for the session. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ChatkitWorkflow { + /// Identifier of the workflow backing the session. + pub id: String, + /// Specific workflow version used for the session. Defaults to null when using the latest deployment. + #[serde(skip_serializing_if = "Option::is_none")] + pub version: Option, + /// State variable key-value pairs applied when invoking the workflow. Defaults to null when no overrides were provided. + #[serde(skip_serializing_if = "Option::is_none")] + pub state_variables: Option>, + /// Tracing settings applied to the workflow. + pub tracing: ChatkitWorkflowTracing, +} + +/// Controls diagnostic tracing during the session. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ChatkitWorkflowTracing { + /// Indicates whether tracing is enabled. + pub enabled: bool, +} + +/// Active per-minute request limit for the session. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ChatSessionRateLimits { + /// Maximum allowed requests per one-minute window. + pub max_requests_per_1_minute: i32, +} + +/// Current lifecycle state of the session. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum ChatSessionStatus { + Active, + Expired, + Cancelled, +} + +/// ChatKit configuration for the session. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ChatSessionChatkitConfiguration { + /// Automatic thread titling preferences. + pub automatic_thread_titling: ChatSessionAutomaticThreadTitling, + /// Upload settings for the session. + pub file_upload: ChatSessionFileUpload, + /// History retention configuration. + pub history: ChatSessionHistory, +} + +/// Automatic thread title preferences for the session. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ChatSessionAutomaticThreadTitling { + /// Whether automatic thread titling is enabled. + pub enabled: bool, +} + +/// Upload permissions and limits applied to the session. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ChatSessionFileUpload { + /// Indicates if uploads are enabled for the session. + pub enabled: bool, + /// Maximum upload size in megabytes. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_file_size: Option, + /// Maximum number of uploads allowed during the session. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_files: Option, +} + +/// History retention preferences returned for the session. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ChatSessionHistory { + /// Indicates if chat history is persisted for the session. + pub enabled: bool, + /// Number of prior threads surfaced in history views. Defaults to null when all history is retained. + #[serde(skip_serializing_if = "Option::is_none")] + pub recent_threads: Option, +} + +/// Parameters for provisioning a new ChatKit session. +#[derive(Clone, Serialize, Debug, Deserialize, Builder, PartialEq, Default)] +#[builder(name = "CreateChatSessionRequestArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct CreateChatSessionBody { + /// Workflow that powers the session. + pub workflow: WorkflowParam, + /// A free-form string that identifies your end user; ensures this Session can access other objects that have the same `user` scope. + pub user: String, + /// Optional override for session expiration timing in seconds from creation. Defaults to 10 minutes. + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_after: Option, + /// Optional override for per-minute request limits. When omitted, defaults to 10. + #[serde(skip_serializing_if = "Option::is_none")] + pub rate_limits: Option, + /// Optional overrides for ChatKit runtime configuration features + #[serde(skip_serializing_if = "Option::is_none")] + pub chatkit_configuration: Option, +} + +/// Workflow reference and overrides applied to the chat session. +#[derive(Clone, Serialize, Debug, Deserialize, Builder, PartialEq, Default)] +#[builder(name = "WorkflowParamArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct WorkflowParam { + /// Identifier for the workflow invoked by the session. + pub id: String, + /// Specific workflow version to run. Defaults to the latest deployed version. + #[serde(skip_serializing_if = "Option::is_none")] + pub version: Option, + /// State variables forwarded to the workflow. Keys may be up to 64 characters, values must be primitive types, and the map defaults to an empty object. + #[serde(skip_serializing_if = "Option::is_none")] + pub state_variables: Option>, + /// Optional tracing overrides for the workflow invocation. When omitted, tracing is enabled by default. + #[serde(skip_serializing_if = "Option::is_none")] + pub tracing: Option, +} + +/// Controls diagnostic tracing during the session. +#[derive(Clone, Serialize, Default, Debug, Deserialize, Builder, PartialEq)] +#[builder(name = "WorkflowTracingParamArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct WorkflowTracingParam { + /// Whether tracing is enabled during the session. Defaults to true. + #[serde(skip_serializing_if = "Option::is_none")] + pub enabled: Option, +} + +/// Controls when the session expires relative to an anchor timestamp. +#[derive(Clone, Serialize, Debug, Deserialize, Builder, PartialEq, Default)] +#[builder(name = "ExpiresAfterParamArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ExpiresAfterParam { + /// Base timestamp used to calculate expiration. Currently fixed to `created_at`. + #[serde(default = "default_anchor")] + #[builder(default = "default_anchor()")] + pub anchor: String, + /// Number of seconds after the anchor when the session expires. + pub seconds: i32, +} + +fn default_anchor() -> String { + "created_at".to_string() +} + +/// Controls request rate limits for the session. +#[derive(Clone, Serialize, Debug, Deserialize, Builder, PartialEq, Default)] +#[builder(name = "RateLimitsParamArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct RateLimitsParam { + /// Maximum number of requests allowed per minute for the session. Defaults to 10. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_requests_per_1_minute: Option, +} + +/// Optional per-session configuration settings for ChatKit behavior. +#[derive(Clone, Serialize, Debug, Deserialize, Builder, PartialEq, Default)] +#[builder(name = "ChatkitConfigurationParamArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct ChatkitConfigurationParam { + /// Configuration for automatic thread titling. When omitted, automatic thread titling is enabled by default. + #[serde(skip_serializing_if = "Option::is_none")] + pub automatic_thread_titling: Option, + /// Configuration for upload enablement and limits. When omitted, uploads are disabled by default (max_files 10, max_file_size 512 MB). + #[serde(skip_serializing_if = "Option::is_none")] + pub file_upload: Option, + /// Configuration for chat history retention. When omitted, history is enabled by default with no limit on recent_threads (null). + #[serde(skip_serializing_if = "Option::is_none")] + pub history: Option, +} + +/// Controls whether ChatKit automatically generates thread titles. +#[derive(Clone, Serialize, Debug, Deserialize, Builder, PartialEq, Default)] +#[builder(name = "AutomaticThreadTitlingParamArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct AutomaticThreadTitlingParam { + /// Enable automatic thread title generation. Defaults to true. + #[serde(skip_serializing_if = "Option::is_none")] + pub enabled: Option, +} + +/// Controls whether users can upload files. +#[derive(Clone, Serialize, Debug, Deserialize, Builder, PartialEq, Default)] +#[builder(name = "FileUploadParamArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct FileUploadParam { + /// Enable uploads for this session. Defaults to false. + #[serde(skip_serializing_if = "Option::is_none")] + pub enabled: Option, + /// Maximum size in megabytes for each uploaded file. Defaults to 512 MB, which is the maximum allowable size. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_file_size: Option, + /// Maximum number of files that can be uploaded to the session. Defaults to 10. + #[serde(skip_serializing_if = "Option::is_none")] + pub max_files: Option, +} + +/// Controls how much historical context is retained for the session. +#[derive(Clone, Serialize, Debug, Deserialize, Builder, PartialEq, Default)] +#[builder(name = "HistoryParamArgs")] +#[builder(pattern = "mutable")] +#[builder(setter(into, strip_option), default)] +#[builder(derive(Debug))] +#[builder(build_fn(error = "OpenAIError"))] +pub struct HistoryParam { + /// Enables chat users to access previous ChatKit threads. Defaults to true. + #[serde(skip_serializing_if = "Option::is_none")] + pub enabled: Option, + /// Number of recent ChatKit threads users have access to. Defaults to unlimited when unset. + #[serde(skip_serializing_if = "Option::is_none")] + pub recent_threads: Option, +} diff --git a/async-openai/src/types/chatkit/thread.rs b/async-openai/src/types/chatkit/thread.rs new file mode 100644 index 00000000..cf2671da --- /dev/null +++ b/async-openai/src/types/chatkit/thread.rs @@ -0,0 +1,399 @@ +use serde::{Deserialize, Serialize}; + +/// Represents a ChatKit thread and its current status. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ThreadResource { + /// Identifier of the thread. + pub id: String, + /// Type discriminator that is always `chatkit.thread`. + #[serde(default = "default_thread_object")] + pub object: String, + /// Unix timestamp (in seconds) for when the thread was created. + pub created_at: i64, + /// Optional human-readable title for the thread. Defaults to null when no title has been generated. + pub title: Option, + /// Current status for the thread. Defaults to `active` for newly created threads. + #[serde(flatten)] + pub status: ThreadStatus, + /// Free-form string that identifies your end user who owns the thread. + pub user: String, + /// Thread items (only present when retrieving a thread) + #[serde(skip_serializing_if = "Option::is_none")] + pub items: Option, +} + +fn default_thread_object() -> String { + "chatkit.thread".to_string() +} + +/// Current status for the thread. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ThreadStatus { + /// Indicates that a thread is active. + Active, + /// Indicates that a thread is locked and cannot accept new input. + Locked { reason: Option }, + /// Indicates that a thread has been closed. + Closed { reason: Option }, +} + +/// A paginated list of ChatKit threads. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +pub struct ThreadListResource { + /// The type of object returned, must be `list`. + #[serde(default = "default_list_object")] + pub object: String, + /// A list of items + pub data: Vec, + /// The ID of the first item in the list. + pub first_id: Option, + /// The ID of the last item in the list. + pub last_id: Option, + /// Whether there are more items available. + pub has_more: bool, +} + +fn default_list_object() -> String { + "list".to_string() +} + +/// Confirmation payload returned after deleting a thread. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +pub struct DeletedThreadResource { + /// Identifier of the deleted thread. + pub id: String, + /// Type discriminator that is always `chatkit.thread.deleted`. + #[serde(default = "default_deleted_object")] + pub object: String, + /// Indicates that the thread has been deleted. + pub deleted: bool, +} + +fn default_deleted_object() -> String { + "chatkit.thread.deleted".to_string() +} + +/// A paginated list of thread items rendered for the ChatKit API. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +pub struct ThreadItemListResource { + /// The type of object returned, must be `list`. + #[serde(default = "default_list_object")] + pub object: String, + /// A list of items + pub data: Vec, + /// The ID of the first item in the list. + pub first_id: Option, + /// The ID of the last item in the list. + pub last_id: Option, + /// Whether there are more items available. + pub has_more: bool, +} + +/// The thread item - discriminated union based on type field. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ThreadItem { + /// User-authored messages within a thread. + #[serde(rename = "chatkit.user_message")] + UserMessage(UserMessageItem), + /// Assistant-authored message within a thread. + #[serde(rename = "chatkit.assistant_message")] + AssistantMessage(AssistantMessageItem), + /// Thread item that renders a widget payload. + #[serde(rename = "chatkit.widget")] + Widget(WidgetMessageItem), + /// Record of a client side tool invocation initiated by the assistant. + #[serde(rename = "chatkit.client_tool_call")] + ClientToolCall(ClientToolCallItem), + /// Task emitted by the workflow to show progress and status updates. + #[serde(rename = "chatkit.task")] + Task(TaskItem), + /// Collection of workflow tasks grouped together in the thread. + #[serde(rename = "chatkit.task_group")] + TaskGroup(TaskGroupItem), +} + +/// User-authored messages within a thread. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +pub struct UserMessageItem { + /// Identifier of the thread item. + pub id: String, + /// Type discriminator that is always `chatkit.thread_item`. + #[serde(default = "default_thread_item_object")] + pub object: String, + /// Unix timestamp (in seconds) for when the item was created. + pub created_at: i64, + /// Identifier of the parent thread. + pub thread_id: String, + /// Ordered content elements supplied by the user. + pub content: Vec, + /// Attachments associated with the user message. Defaults to an empty list. + #[serde(default)] + pub attachments: Vec, + /// Inference overrides applied to the message. Defaults to null when unset. + #[serde(skip_serializing_if = "Option::is_none")] + pub inference_options: Option, +} + +fn default_thread_item_object() -> String { + "chatkit.thread_item".to_string() +} + +/// Content blocks that comprise a user message. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum UserMessageContent { + /// Text block that a user contributed to the thread. + #[serde(rename = "input_text")] + InputText { text: String }, + /// Quoted snippet that the user referenced in their message. + #[serde(rename = "quoted_text")] + QuotedText { text: String }, +} + +/// Attachment metadata included on thread items. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct Attachment { + /// Attachment discriminator. + #[serde(rename = "type")] + pub attachment_type: AttachmentType, + /// Identifier for the attachment. + pub id: String, + /// Original display name for the attachment. + pub name: String, + /// MIME type of the attachment. + pub mime_type: String, + /// Preview URL for rendering the attachment inline. + pub preview_url: Option, +} + +/// Attachment discriminator. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum AttachmentType { + Image, + File, +} + +/// Model and tool overrides applied when generating the assistant response. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +pub struct InferenceOptions { + /// Preferred tool to invoke. Defaults to null when ChatKit should auto-select. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + /// Model name that generated the response. Defaults to null when using the session default. + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, +} + +/// Tool selection that the assistant should honor when executing the item. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +pub struct ToolChoice { + /// Identifier of the requested tool. + pub id: String, +} + +/// Assistant-authored message within a thread. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq, Default)] +pub struct AssistantMessageItem { + /// Identifier of the thread item. + pub id: String, + /// Type discriminator that is always `chatkit.thread_item`. + #[serde(default = "default_thread_item_object")] + pub object: String, + /// Unix timestamp (in seconds) for when the item was created. + pub created_at: i64, + /// Identifier of the parent thread. + pub thread_id: String, + /// Ordered assistant response segments. + pub content: Vec, +} + +/// Assistant response text accompanied by optional annotations. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ResponseOutputText { + /// Type discriminator that is always `output_text`. + #[serde(default = "default_output_text_type")] + pub r#type: String, + /// Assistant generated text. + pub text: String, + /// Ordered list of annotations attached to the response text. + #[serde(default)] + pub annotations: Vec, +} + +fn default_output_text_type() -> String { + "output_text".to_string() +} + +/// Annotation object describing a cited source. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum Annotation { + /// Annotation that references an uploaded file. + #[serde(rename = "file")] + File(FileAnnotation), + /// Annotation that references a URL. + #[serde(rename = "url")] + Url(UrlAnnotation), +} + +/// Annotation that references an uploaded file. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct FileAnnotation { + /// Type discriminator that is always `file` for this annotation. + #[serde(default = "default_file_annotation_type")] + pub r#type: String, + /// File attachment referenced by the annotation. + pub source: FileAnnotationSource, +} + +fn default_file_annotation_type() -> String { + "file".to_string() +} + +/// Attachment source referenced by an annotation. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct FileAnnotationSource { + /// Type discriminator that is always `file`. + #[serde(default = "default_file_source_type")] + pub r#type: String, + /// Filename referenced by the annotation. + pub filename: String, +} + +fn default_file_source_type() -> String { + "file".to_string() +} + +/// Annotation that references a URL. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct UrlAnnotation { + /// Type discriminator that is always `url` for this annotation. + #[serde(default = "default_url_annotation_type")] + pub r#type: String, + /// URL referenced by the annotation. + pub source: UrlAnnotationSource, +} + +fn default_url_annotation_type() -> String { + "url".to_string() +} + +/// URL backing an annotation entry. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct UrlAnnotationSource { + /// Type discriminator that is always `url`. + #[serde(default = "default_url_source_type")] + pub r#type: String, + /// URL referenced by the annotation. + pub url: String, +} + +fn default_url_source_type() -> String { + "url".to_string() +} + +/// Thread item that renders a widget payload. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct WidgetMessageItem { + /// Identifier of the thread item. + pub id: String, + /// Type discriminator that is always `chatkit.thread_item`. + #[serde(default = "default_thread_item_object")] + pub object: String, + /// Unix timestamp (in seconds) for when the item was created. + pub created_at: i64, + /// Identifier of the parent thread. + pub thread_id: String, + /// Serialized widget payload rendered in the UI. + pub widget: String, +} + +/// Record of a client side tool invocation initiated by the assistant. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct ClientToolCallItem { + /// Identifier of the thread item. + pub id: String, + /// Type discriminator that is always `chatkit.thread_item`. + #[serde(default = "default_thread_item_object")] + pub object: String, + /// Unix timestamp (in seconds) for when the item was created. + pub created_at: i64, + /// Identifier of the parent thread. + pub thread_id: String, + /// Execution status for the tool call. + pub status: ClientToolCallStatus, + /// Identifier for the client tool call. + pub call_id: String, + /// Tool name that was invoked. + pub name: String, + /// JSON-encoded arguments that were sent to the tool. + pub arguments: String, + /// JSON-encoded output captured from the tool. Defaults to null while execution is in progress. + pub output: Option, +} + +/// Execution status for the tool call. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum ClientToolCallStatus { + InProgress, + Completed, +} + +/// Task emitted by the workflow to show progress and status updates. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct TaskItem { + /// Identifier of the thread item. + pub id: String, + /// Type discriminator that is always `chatkit.thread_item`. + #[serde(default = "default_thread_item_object")] + pub object: String, + /// Unix timestamp (in seconds) for when the item was created. + pub created_at: i64, + /// Identifier of the parent thread. + pub thread_id: String, + /// Subtype for the task. + pub task_type: TaskType, + /// Optional heading for the task. Defaults to null when not provided. + pub heading: Option, + /// Optional summary that describes the task. Defaults to null when omitted. + pub summary: Option, +} + +/// Subtype for the task. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +#[serde(rename_all = "snake_case")] +pub enum TaskType { + Custom, + Thought, +} + +/// Collection of workflow tasks grouped together in the thread. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct TaskGroupItem { + /// Identifier of the thread item. + pub id: String, + /// Type discriminator that is always `chatkit.thread_item`. + #[serde(default = "default_thread_item_object")] + pub object: String, + /// Unix timestamp (in seconds) for when the item was created. + pub created_at: i64, + /// Identifier of the parent thread. + pub thread_id: String, + /// Tasks included in the group. + pub tasks: Vec, +} + +/// Task entry that appears within a TaskGroup. +#[derive(Clone, Serialize, Debug, Deserialize, PartialEq)] +pub struct TaskGroupTask { + /// Subtype for the grouped task. + pub task_type: TaskType, + /// Optional heading for the grouped task. Defaults to null when not provided. + pub heading: Option, + /// Optional summary that describes the grouped task. Defaults to null when omitted. + pub summary: Option, +} diff --git a/async-openai/src/types/mod.rs b/async-openai/src/types/mod.rs index f28556e5..f5fe4b39 100644 --- a/async-openai/src/types/mod.rs +++ b/async-openai/src/types/mod.rs @@ -7,6 +7,7 @@ pub mod audio; mod audit_log; pub mod batches; pub mod chat; +pub mod chatkit; mod common; mod completion; pub mod containers; diff --git a/examples/chatkit/Cargo.toml b/examples/chatkit/Cargo.toml new file mode 100644 index 00000000..7ff906d9 --- /dev/null +++ b/examples/chatkit/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "chatkit" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +async-openai = {path = "../../async-openai"} +serde_json = "1.0.135" +tokio = { version = "1.43.0", features = ["full"] } + diff --git a/examples/chatkit/src/main.rs b/examples/chatkit/src/main.rs new file mode 100644 index 00000000..16b7de9e --- /dev/null +++ b/examples/chatkit/src/main.rs @@ -0,0 +1,53 @@ +use std::error::Error; + +use async_openai::{ + config::OpenAIConfig, + types::chatkit::{CreateChatSessionRequestArgs, WorkflowParamArgs}, + Client, +}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Get workflow_id from command line arguments + let args: Vec = std::env::args().collect(); + if args.len() < 2 { + eprintln!("Usage: {} ", args[0]); + std::process::exit(1); + } + let workflow_id = &args[1]; + + println!("Using workflow_id: {}", workflow_id); + + let config = OpenAIConfig::default() + .with_header("OpenAI-Beta", "chatkit_beta=v1") + .unwrap(); + let client = Client::with_config(config); + + // 1. Create a ChatKit session + println!("\n=== Creating ChatKit Session ==="); + let workflow = WorkflowParamArgs::default() + .id(workflow_id.clone()) + .build()?; + + let session_request = CreateChatSessionRequestArgs::default() + .workflow(workflow) + .user("example_user".to_string()) + .build()?; + + let session = client.chatkit().sessions().create(session_request).await?; + println!("Created session:"); + println!(" ID: {}", session.id); + println!(" Status: {:?}", session.status); + println!(" Expires at: {}", session.expires_at); + println!(" Client secret: {}", session.client_secret); + println!(" Workflow ID: {}", session.workflow.id); + println!(" User: {}", session.user); + + // 2. Cancel the session (cleanup) + println!("\n=== Cancelling Session ==="); + let cancelled_session = client.chatkit().sessions().cancel(&session.id).await?; + println!("Cancelled session: {}", cancelled_session.id); + println!(" Status: {:?}", cancelled_session.status); + + Ok(()) +}