diff --git a/src/lib.rs b/src/lib.rs index 2a3ab6b..e56e1db 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ use std::collections::HashMap; use futures_util::{Stream, StreamExt as _}; use reqwest::Client; use reqwest_eventsource::{Event, RequestBuilderExt as _}; +use serde_json::Value; use types::{ Content, ContentData, FunctionResponse, FunctionResponsePayload, GenerateContentRequest, GenerateContentResponse, Role, @@ -19,13 +20,36 @@ pub enum GeminiError { #[error("Streaming Event Error: {0}")] EventSource(#[from] reqwest_eventsource::Error), #[error("API Error: {0}")] - Api(String), + Api(Value), #[error("JSON Error: {0}")] Json(#[from] serde_json::Error), #[error("Function execution error: {0}")] FunctionExecution(String), } +impl GeminiError { + async fn from_response( + response: reqwest::Response, + context: Option, + ) -> Self { + let status = response.status(); + let text = match response.text().await { + Ok(text) => text, + Err(error) => return Self::Http(error), + }; + let message = match serde_json::from_str::(&text) { + Ok(error) => error, + Err(_) => serde_json::Value::String(text), + }; + + Self::Api(serde_json::json!({ + "status": status.as_u16(), + "message": message, + "context": context.unwrap_or_default(), + })) + } +} + #[derive(Debug, Clone)] pub struct GeminiClient { api_key: String, @@ -85,7 +109,7 @@ impl GeminiClient { let response = self.http_client.get(&url).send().await?; if !response.status().is_success() { - return handle_error::>(response).await; + return Err(GeminiError::from_response(response, None).await); } let mut models = vec![]; @@ -98,8 +122,7 @@ impl GeminiClient { let response = self.http_client.get(&url).send().await?; if !response.status().is_success() { - let error_text = response.text().await?; - return Err(GeminiError::Api(error_text)); + return Err(GeminiError::from_response(response, None).await); } let response: Response = response.json().await?; @@ -132,7 +155,7 @@ impl GeminiClient { let response = self.http_client.post(&url).json(request).send().await?; if !response.status().is_success() { - return handle_error::(response).await; + return Err(GeminiError::from_response(response, None).await); } Ok(response.json().await?) @@ -171,16 +194,17 @@ impl GeminiClient { reqwest_eventsource::Error::StreamEnded => stream.close(), reqwest_eventsource::Error::InvalidContentType(content_type, response) => { let header = content_type.to_str().unwrap_or_default(); - let body = response.text().await?; - yield Err(GeminiError::Api(format!( - "Invalid content type {header}: {body}" - ))) + yield Err(GeminiError::from_response( + response, + Some(serde_json::json!({ + "cause": "Invalid content type", + "header": header + }))).await) } - reqwest_eventsource::Error::InvalidStatusCode(code, response) => { - let body = response.text().await?; - yield Err(GeminiError::Api(format!( - "Invalid status code {code}: {body}" - ))) + reqwest_eventsource::Error::InvalidStatusCode(_, response) => { + yield Err(GeminiError::from_response( + response, + Some(serde_json::json!({"cause": "Invalid status code"}))).await) } _ => yield Err(e.into()), } @@ -247,10 +271,3 @@ impl GeminiClient { } } } - -async fn handle_error(response: reqwest::Response) -> Result { - let status = response.status(); - let error_text = response.text().await?; - - Err(GeminiError::Api(format!("status {status}: {error_text}"))) -}