diff --git a/async-openai/README.md b/async-openai/README.md index feebe3f5..b61799e3 100644 --- a/async-openai/README.md +++ b/async-openai/README.md @@ -81,7 +81,11 @@ async fn main() -> Result<(), Box> { // Download and save images to ./data directory. // Each url is downloaded and saved in dedicated Tokio task. // Directory is created if it doesn't exist. - response.save("./data").await?; + let paths = response.save("./data").await?; + + paths + .iter() + .for_each(|path| println!("Image file path: {}", path.display())); Ok(()) } diff --git a/async-openai/src/download.rs b/async-openai/src/download.rs index a2ec8dcb..c0029e9a 100644 --- a/async-openai/src/download.rs +++ b/async-openai/src/download.rs @@ -21,7 +21,10 @@ fn create_paths>(url: &Url, base_dir: P) -> (PathBuf, PathBuf) { (dir, path) } -pub(crate) async fn download_url>(url: &str, dir: P) -> Result<(), OpenAIError> { +pub(crate) async fn download_url>( + url: &str, + dir: P, +) -> Result { let parsed_url = Url::parse(url).map_err(|e| OpenAIError::FileSaveError(e.to_string()))?; let response = reqwest::get(url) .await @@ -41,7 +44,7 @@ pub(crate) async fn download_url>(url: &str, dir: P) -> Result<() .map_err(|e| OpenAIError::FileSaveError(e.to_string()))?; tokio::fs::write( - file_path, + file_path.as_path(), response .bytes() .await @@ -50,10 +53,10 @@ pub(crate) async fn download_url>(url: &str, dir: P) -> Result<() .await .map_err(|e| OpenAIError::FileSaveError(e.to_string()))?; - Ok(()) + Ok(file_path) } -pub(crate) async fn save_b64>(b64: &str, dir: P) -> Result<(), OpenAIError> { +pub(crate) async fn save_b64>(b64: &str, dir: P) -> Result { let filename: String = rand::thread_rng() .sample_iter(&Alphanumeric) .take(10) @@ -65,11 +68,11 @@ pub(crate) async fn save_b64>(b64: &str, dir: P) -> Result<(), Op let path = PathBuf::from(dir.as_ref()).join(filename); tokio::fs::write( - path, + path.as_path(), base64::decode(b64).map_err(|e| OpenAIError::FileSaveError(e.to_string()))?, ) .await .map_err(|e| OpenAIError::FileSaveError(e.to_string()))?; - Ok(()) + Ok(path) } diff --git a/async-openai/src/types/impls.rs b/async-openai/src/types/impls.rs index 39327a93..1b767e55 100644 --- a/async-openai/src/types/impls.rs +++ b/async-openai/src/types/impls.rs @@ -118,7 +118,9 @@ impl Display for ResponseFormat { } impl ImageResponse { - pub async fn save>(&self, dir: P) -> Result<(), OpenAIError> { + /// Save each image in a dedicated Tokio task and return paths to saved files. + /// For [ResponseFormat::Url] each file is downloaded in dedicated Tokio task. + pub async fn save>(&self, dir: P) -> Result, OpenAIError> { let exists = match Path::try_exists(dir.as_ref()) { Ok(exists) => exists, Err(e) => return Err(OpenAIError::FileSaveError(e.to_string())), @@ -135,18 +137,23 @@ impl ImageResponse { handles.push(tokio::spawn(async move { id.save(dir_buf).await })); } - let result = futures::future::join_all(handles).await; - - let errors: Vec = result - .into_iter() - .filter(|r| r.is_err() || r.as_ref().ok().unwrap().is_err()) - .map(|r| match r { - Err(e) => OpenAIError::FileSaveError(e.to_string()), - Ok(inner) => inner.err().unwrap(), - }) - .collect(); + let results = futures::future::join_all(handles).await; + let mut errors = vec![]; + let mut paths = vec![]; + + for result in results { + match result { + Ok(inner) => match inner { + Ok(path) => paths.push(path), + Err(e) => errors.push(e), + }, + Err(e) => errors.push(OpenAIError::FileSaveError(e.to_string())), + } + } - if errors.len() > 0 { + if errors.is_empty() { + Ok(paths) + } else { Err(OpenAIError::FileSaveError( errors .into_iter() @@ -154,19 +161,16 @@ impl ImageResponse { .collect::>() .join("; "), )) - } else { - Ok(()) } } } impl ImageData { - async fn save>(&self, dir: P) -> Result<(), OpenAIError> { + async fn save>(&self, dir: P) -> Result { match self { - ImageData::Url(url) => download_url(url, dir).await?, - ImageData::B64Json(b64_json) => save_b64(b64_json, dir).await?, + ImageData::Url(url) => download_url(url, dir).await, + ImageData::B64Json(b64_json) => save_b64(b64_json, dir).await, } - Ok(()) } } diff --git a/examples/create-image-b64-json/src/main.rs b/examples/create-image-b64-json/src/main.rs index abd18e53..a552126d 100644 --- a/examples/create-image-b64-json/src/main.rs +++ b/examples/create-image-b64-json/src/main.rs @@ -22,7 +22,11 @@ async fn main() -> Result<(), Box> { // Response already contains image data in base64 format. // Save each image to ./data directory in dedicated Tokio task. // Directory is created if it doesn't exist. - response.save("./data").await?; + let paths = response.save("./data").await?; + + paths + .iter() + .for_each(|path| println!("Image file path: {}", path.display())); Ok(()) } diff --git a/examples/create-image/src/main.rs b/examples/create-image/src/main.rs index 45e1de11..7bc6064c 100644 --- a/examples/create-image/src/main.rs +++ b/examples/create-image/src/main.rs @@ -22,7 +22,11 @@ async fn main() -> Result<(), Box> { // Download and save images to ./data directory. // Each url is downloaded and saved in dedicated Tokio task. // Directory is created if it doesn't exist. - response.save("./data").await?; + let paths = response.save("./data").await?; + + paths + .iter() + .for_each(|path| println!("Image file path: {}", path.display())); Ok(()) }