From a90ef661e4dac31a9f3defd2df3454d70d420504 Mon Sep 17 00:00:00 2001 From: Himanshu Neema Date: Sat, 1 Apr 2023 23:18:21 -0700 Subject: [PATCH 1/3] derive PartialEq for all types in types.rs --- async-openai/src/types/types.rs | 110 ++++++++++++++++---------------- 1 file changed, 55 insertions(+), 55 deletions(-) diff --git a/async-openai/src/types/types.rs b/async-openai/src/types/types.rs index 5492135f..384032cf 100644 --- a/async-openai/src/types/types.rs +++ b/async-openai/src/types/types.rs @@ -6,7 +6,7 @@ use serde::{Deserialize, Serialize}; use crate::error::OpenAIError; -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct Model { pub id: String, pub object: String, @@ -14,13 +14,13 @@ pub struct Model { pub owned_by: String, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct ListModelResponse { pub object: String, pub data: Vec, } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(untagged)] pub enum Prompt { String(String), @@ -30,14 +30,14 @@ pub enum Prompt { ArrayOfIntegerArray(Vec>), } -#[derive(Debug, Serialize, Deserialize, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] #[serde(untagged)] pub enum Stop { String(String), // nullable: true StringArray(Vec), // minItems: 1; maxItems: 4 } -#[derive(Clone, Serialize, Default, Debug, Builder)] +#[derive(Clone, Serialize, Default, Debug, Builder, PartialEq)] #[builder(name = "CreateCompletionRequestArgs")] #[builder(pattern = "mutable")] #[builder(setter(into, strip_option), default)] @@ -134,7 +134,7 @@ pub struct CreateCompletionRequest { pub user: Option, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct Logprobs { pub tokens: Vec, pub token_logprobs: Vec>, // Option is to account for null value in the list @@ -142,7 +142,7 @@ pub struct Logprobs { pub text_offset: Vec, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct Choice { pub text: String, pub index: u32, @@ -150,14 +150,14 @@ pub struct Choice { pub finish_reason: Option, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct Usage { pub prompt_tokens: u32, pub completion_tokens: u32, pub total_tokens: u32, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct CreateCompletionResponse { pub id: String, pub object: String, @@ -171,7 +171,7 @@ pub struct CreateCompletionResponse { pub type CompletionResponseStream = Pin> + Send>>; -#[derive(Debug, Clone, Serialize, Default, Builder)] +#[derive(Debug, Clone, Serialize, Default, Builder, PartialEq)] #[builder(name = "CreateEditRequestArgs")] #[builder(pattern = "mutable")] #[builder(setter(into, strip_option), default)] @@ -205,7 +205,7 @@ pub struct CreateEditRequest { pub top_p: Option, // min: 0, max: 1, default: 1 } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct CreateEditResponse { pub object: String, pub created: u32, @@ -213,7 +213,7 @@ pub struct CreateEditResponse { pub usage: Usage, } -#[derive(Default, Debug, Serialize, Clone)] +#[derive(Default, Debug, Serialize, Clone, PartialEq)] pub enum ImageSize { #[serde(rename = "256x256")] S256x256, @@ -224,7 +224,7 @@ pub enum ImageSize { S1024x1024, } -#[derive(Debug, Serialize, Default, Clone)] +#[derive(Debug, Serialize, Default, Clone, PartialEq)] #[serde(rename_all = "lowercase")] pub enum ResponseFormat { #[default] @@ -233,7 +233,7 @@ pub enum ResponseFormat { B64Json, } -#[derive(Debug, Clone, Serialize, Default, Builder)] +#[derive(Debug, Clone, Serialize, Default, Builder, PartialEq)] #[builder(name = "CreateImageRequestArgs")] #[builder(pattern = "mutable")] #[builder(setter(into, strip_option), default)] @@ -260,7 +260,7 @@ pub struct CreateImageRequest { pub user: Option, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] #[serde(rename_all = "lowercase")] pub enum ImageData { Url(std::sync::Arc), @@ -268,18 +268,18 @@ pub enum ImageData { B64Json(std::sync::Arc), } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct ImageResponse { pub created: u32, pub data: Vec>, } -#[derive(Debug, Default, Clone)] +#[derive(Debug, Default, Clone, PartialEq)] pub struct ImageInput { pub path: PathBuf, } -#[derive(Debug, Clone, Default, Builder)] +#[derive(Debug, Clone, Default, Builder, PartialEq)] #[builder(name = "CreateImageEditRequestArgs")] #[builder(pattern = "mutable")] #[builder(setter(into, strip_option), default)] @@ -308,7 +308,7 @@ pub struct CreateImageEditRequest { pub user: Option, } -#[derive(Debug, Default, Clone, Builder)] +#[derive(Debug, Default, Clone, Builder, PartialEq)] #[builder(name = "CreateImageVariationRequestArgs")] #[builder(pattern = "mutable")] #[builder(setter(into, strip_option), default)] @@ -331,14 +331,14 @@ pub struct CreateImageVariationRequest { pub user: Option, } -#[derive(Debug, Serialize, Clone)] +#[derive(Debug, Serialize, Clone, PartialEq)] #[serde(untagged)] pub enum ModerationInput { String(String), StringArray(Vec), } -#[derive(Debug, Serialize, Default, Clone)] +#[derive(Debug, Serialize, Default, Clone, PartialEq)] pub enum TextModerationModel { #[default] #[serde(rename = "text-moderation-latest")] @@ -347,7 +347,7 @@ pub enum TextModerationModel { Stable, } -#[derive(Debug, Default, Clone, Serialize, Builder)] +#[derive(Debug, Default, Clone, Serialize, Builder, PartialEq)] #[builder(name = "CreateModerationRequestArgs")] #[builder(pattern = "mutable")] #[builder(setter(into, strip_option), default)] @@ -364,7 +364,7 @@ pub struct CreateModerationRequest { pub model: Option, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct Category { pub hate: bool, #[serde(rename = "hate/threatening")] @@ -379,7 +379,7 @@ pub struct Category { pub violence_graphic: bool, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct CategoryScore { pub hate: f32, #[serde(rename = "hate/threatening")] @@ -394,26 +394,26 @@ pub struct CategoryScore { pub violence_graphic: f32, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct ContentModerationResult { pub flagged: bool, pub categories: Category, pub category_scores: CategoryScore, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct CreateModerationResponse { pub id: String, pub model: String, pub results: Vec, } -#[derive(Debug, Default, Clone)] +#[derive(Debug, Default, Clone, PartialEq)] pub struct FileInput { pub path: PathBuf, } -#[derive(Debug, Default, Clone, Builder)] +#[derive(Debug, Default, Clone, Builder, PartialEq)] #[builder(name = "CreateFileRequestArgs")] #[builder(pattern = "mutable")] #[builder(setter(into, strip_option), default)] @@ -431,20 +431,20 @@ pub struct CreateFileRequest { pub purpose: String, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct ListFilesResponse { pub object: String, pub data: Vec, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct DeleteFileResponse { pub id: String, pub object: String, pub deleted: bool, } -#[derive(Debug, Deserialize, PartialEq, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct OpenAIFile { pub id: String, pub object: String, @@ -456,7 +456,7 @@ pub struct OpenAIFile { pub status_details: Option, // nullable: true } -#[derive(Debug, Serialize, Clone, Default, Builder)] +#[derive(Debug, Serialize, Clone, Default, Builder, PartialEq)] #[builder(name = "CreateFineTuneRequestArgs")] #[builder(pattern = "mutable")] #[builder(setter(into, strip_option), default)] @@ -575,13 +575,13 @@ pub struct CreateFineTuneRequest { pub suffix: Option, // default: null, minLength:1, maxLength:40 } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct ListFineTuneResponse { pub object: String, pub data: Vec, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct FineTune { pub id: String, pub object: String, @@ -598,7 +598,7 @@ pub struct FineTune { pub events: Option>, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct FineTuneEvent { pub object: String, pub created_at: u32, @@ -606,7 +606,7 @@ pub struct FineTuneEvent { pub message: String, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct ListFineTuneEventsResponse { pub object: String, pub data: Vec, @@ -616,14 +616,14 @@ pub struct ListFineTuneEventsResponse { pub type FineTuneEventsResponseStream = Pin> + Send>>; -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct DeleteModelResponse { pub id: String, pub object: String, pub deleted: bool, } -#[derive(Debug, Serialize, Clone)] +#[derive(Debug, Serialize, Clone, PartialEq)] #[serde(untagged)] pub enum EmbeddingInput { String(String), @@ -633,7 +633,7 @@ pub enum EmbeddingInput { ArrayOfIntegerArray(Vec>), } -#[derive(Debug, Serialize, Default, Clone, Builder)] +#[derive(Debug, Serialize, Default, Clone, Builder, PartialEq)] #[builder(name = "CreateEmbeddingRequestArgs")] #[builder(pattern = "mutable")] #[builder(setter(into, strip_option), default)] @@ -659,20 +659,20 @@ pub struct CreateEmbeddingRequest { pub user: Option, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct Embedding { pub index: u32, pub object: String, pub embedding: Vec, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct EmbeddingUsage { pub prompt_tokens: u32, pub total_tokens: u32, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct CreateEmbeddingResponse { pub object: String, pub model: String, @@ -689,7 +689,7 @@ pub enum Role { Assistant, } -#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder)] +#[derive(Debug, Serialize, Deserialize, Default, Clone, Builder, PartialEq)] #[builder(name = "ChatCompletionRequestMessageArgs")] #[builder(pattern = "mutable")] #[builder(setter(into, strip_option), default)] @@ -705,13 +705,13 @@ pub struct ChatCompletionRequestMessage { pub name: Option, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct ChatCompletionResponseMessage { pub role: Role, pub content: String, } -#[derive(Clone, Serialize, Default, Debug, Builder, Deserialize)] +#[derive(Clone, Serialize, Default, Debug, Builder, Deserialize, PartialEq)] #[builder(name = "CreateChatCompletionRequestArgs")] #[builder(pattern = "mutable")] #[builder(setter(into, strip_option), default)] @@ -775,14 +775,14 @@ pub struct CreateChatCompletionRequest { pub user: Option, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct ChatChoice { pub index: u32, pub message: ChatCompletionResponseMessage, pub finish_reason: Option, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct CreateChatCompletionResponse { pub id: String, pub object: String, @@ -798,20 +798,20 @@ pub type ChatCompletionResponseStream = // For reason (not documented by OpenAI) the response from stream is different -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct ChatCompletionResponseStreamMessage { pub content: Option, pub role: Option, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct ChatChoiceDelta { pub index: u32, pub delta: ChatCompletionResponseStreamMessage, pub finish_reason: Option, } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct CreateChatCompletionStreamResponse { pub id: Option, pub object: String, @@ -821,12 +821,12 @@ pub struct CreateChatCompletionStreamResponse { pub usage: Option, } -#[derive(Debug, Default, Clone)] +#[derive(Debug, Default, Clone, PartialEq)] pub struct AudioInput { pub path: PathBuf, } -#[derive(Debug, Serialize, Default, Clone)] +#[derive(Debug, Serialize, Default, Clone, PartialEq)] #[serde(rename_all = "snake_case")] pub enum AudioResponseFormat { #[default] @@ -837,7 +837,7 @@ pub enum AudioResponseFormat { Vtt, } -#[derive(Clone, Default, Debug, Builder)] +#[derive(Clone, Default, Debug, Builder, PartialEq)] #[builder(name = "CreateTranscriptionRequestArgs")] #[builder(pattern = "mutable")] #[builder(setter(into, strip_option), default)] @@ -868,7 +868,7 @@ pub struct CreateTranscriptionResponse { pub text: String, } -#[derive(Clone, Default, Debug, Builder)] +#[derive(Clone, Default, Debug, Builder, PartialEq)] #[builder(name = "CreateTranslationRequestArgs")] #[builder(pattern = "mutable")] #[builder(setter(into, strip_option), default)] @@ -891,7 +891,7 @@ pub struct CreateTranslationRequest { pub temperature: Option, // default: 0 } -#[derive(Debug, Deserialize, Clone)] +#[derive(Debug, Deserialize, Clone, PartialEq)] pub struct CreateTranslationResponse { pub text: String, } From 351bf3781785dc0b38f39377886352fcef1b43f8 Mon Sep 17 00:00:00 2001 From: Himanshu Neema Date: Sat, 1 Apr 2023 23:19:44 -0700 Subject: [PATCH 2/3] log deserialization error --- async-openai/src/client.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index fd167ed7..85688393 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -232,8 +232,8 @@ impl Client { return Err(OpenAIError::ApiError(wrapped_error.error)); } - let response: O = - serde_json::from_slice(bytes.as_ref()).map_err(OpenAIError::JSONDeserialize)?; + let response: O = serde_json::from_slice(bytes.as_ref()) + .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?; Ok(response) } From 9d16125fc3f4e3347a735cf9b2122185b1207985 Mon Sep 17 00:00:00 2001 From: Himanshu Neema Date: Sat, 1 Apr 2023 23:22:32 -0700 Subject: [PATCH 3/3] log deserialization error --- async-openai/src/client.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index 85688393..3006a416 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -369,7 +369,9 @@ impl Client { } let response = match serde_json::from_str::(&message.data) { - Err(e) => Err(OpenAIError::JSONDeserialize(e)), + Err(e) => { + Err(map_deserialization_error(e, &message.data.as_bytes())) + } Ok(output) => Ok(output), };