Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 38 additions & 21 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::collections::HashMap;
use futures_util::{Stream, StreamExt as _};
use reqwest::Client;
use reqwest_eventsource::{Event, RequestBuilderExt as _};
use serde_json::Value;
use types::{
Content, ContentData, FunctionResponse, FunctionResponsePayload, GenerateContentRequest,
GenerateContentResponse, Role,
Expand All @@ -19,13 +20,36 @@ pub enum GeminiError {
#[error("Streaming Event Error: {0}")]
EventSource(#[from] reqwest_eventsource::Error),
#[error("API Error: {0}")]
Api(String),
Api(Value),
#[error("JSON Error: {0}")]
Json(#[from] serde_json::Error),
#[error("Function execution error: {0}")]
FunctionExecution(String),
}

impl GeminiError {
async fn from_response(
response: reqwest::Response,
context: Option<serde_json::Value>,
) -> Self {
let status = response.status();
let text = match response.text().await {
Ok(text) => text,
Err(error) => return Self::Http(error),
};
let message = match serde_json::from_str::<Value>(&text) {
Ok(error) => error,
Err(_) => serde_json::Value::String(text),
};

Self::Api(serde_json::json!({
"status": status.as_u16(),
"message": message,
"context": context.unwrap_or_default(),
}))
}
}

#[derive(Debug, Clone)]
pub struct GeminiClient {
api_key: String,
Expand Down Expand Up @@ -85,7 +109,7 @@ impl GeminiClient {

let response = self.http_client.get(&url).send().await?;
if !response.status().is_success() {
return handle_error::<Vec<types::Model>>(response).await;
return Err(GeminiError::from_response(response, None).await);
}

let mut models = vec![];
Expand All @@ -98,8 +122,7 @@ impl GeminiClient {

let response = self.http_client.get(&url).send().await?;
if !response.status().is_success() {
let error_text = response.text().await?;
return Err(GeminiError::Api(error_text));
return Err(GeminiError::from_response(response, None).await);
}

let response: Response = response.json().await?;
Expand Down Expand Up @@ -132,7 +155,7 @@ impl GeminiClient {

let response = self.http_client.post(&url).json(request).send().await?;
if !response.status().is_success() {
return handle_error::<GenerateContentResponse>(response).await;
return Err(GeminiError::from_response(response, None).await);
}

Ok(response.json().await?)
Expand Down Expand Up @@ -171,16 +194,17 @@ impl GeminiClient {
reqwest_eventsource::Error::StreamEnded => stream.close(),
reqwest_eventsource::Error::InvalidContentType(content_type, response) => {
let header = content_type.to_str().unwrap_or_default();
let body = response.text().await?;
yield Err(GeminiError::Api(format!(
"Invalid content type {header}: {body}"
)))
yield Err(GeminiError::from_response(
response,
Some(serde_json::json!({
"cause": "Invalid content type",
"header": header
}))).await)
}
reqwest_eventsource::Error::InvalidStatusCode(code, response) => {
let body = response.text().await?;
yield Err(GeminiError::Api(format!(
"Invalid status code {code}: {body}"
)))
reqwest_eventsource::Error::InvalidStatusCode(_, response) => {
yield Err(GeminiError::from_response(
response,
Some(serde_json::json!({"cause": "Invalid status code"}))).await)
}
_ => yield Err(e.into()),
}
Expand Down Expand Up @@ -247,10 +271,3 @@ impl GeminiClient {
}
}
}

async fn handle_error<T>(response: reqwest::Response) -> Result<T, GeminiError> {
let status = response.status();
let error_text = response.text().await?;

Err(GeminiError::Api(format!("status {status}: {error_text}")))
}