Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion async-openai/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,11 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Download and save images to ./data directory.
// Each url is downloaded and saved in dedicated Tokio task.
// Directory is created if it doesn't exist.
response.save("./data").await?;
let paths = response.save("./data").await?;

paths
.iter()
.for_each(|path| println!("Image file path: {}", path.display()));

Ok(())
}
Expand Down
15 changes: 9 additions & 6 deletions async-openai/src/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ fn create_paths<P: AsRef<Path>>(url: &Url, base_dir: P) -> (PathBuf, PathBuf) {
(dir, path)
}

pub(crate) async fn download_url<P: AsRef<Path>>(url: &str, dir: P) -> Result<(), OpenAIError> {
pub(crate) async fn download_url<P: AsRef<Path>>(
url: &str,
dir: P,
) -> Result<PathBuf, OpenAIError> {
let parsed_url = Url::parse(url).map_err(|e| OpenAIError::FileSaveError(e.to_string()))?;
let response = reqwest::get(url)
.await
Expand All @@ -41,7 +44,7 @@ pub(crate) async fn download_url<P: AsRef<Path>>(url: &str, dir: P) -> Result<()
.map_err(|e| OpenAIError::FileSaveError(e.to_string()))?;

tokio::fs::write(
file_path,
file_path.as_path(),
response
.bytes()
.await
Expand All @@ -50,10 +53,10 @@ pub(crate) async fn download_url<P: AsRef<Path>>(url: &str, dir: P) -> Result<()
.await
.map_err(|e| OpenAIError::FileSaveError(e.to_string()))?;

Ok(())
Ok(file_path)
}

pub(crate) async fn save_b64<P: AsRef<Path>>(b64: &str, dir: P) -> Result<(), OpenAIError> {
pub(crate) async fn save_b64<P: AsRef<Path>>(b64: &str, dir: P) -> Result<PathBuf, OpenAIError> {
let filename: String = rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(10)
Expand All @@ -65,11 +68,11 @@ pub(crate) async fn save_b64<P: AsRef<Path>>(b64: &str, dir: P) -> Result<(), Op
let path = PathBuf::from(dir.as_ref()).join(filename);

tokio::fs::write(
path,
path.as_path(),
base64::decode(b64).map_err(|e| OpenAIError::FileSaveError(e.to_string()))?,
)
.await
.map_err(|e| OpenAIError::FileSaveError(e.to_string()))?;

Ok(())
Ok(path)
}
40 changes: 22 additions & 18 deletions async-openai/src/types/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ impl Display for ResponseFormat {
}

impl ImageResponse {
pub async fn save<P: AsRef<Path>>(&self, dir: P) -> Result<(), OpenAIError> {
/// Save each image in a dedicated Tokio task and return paths to saved files.
/// For [ResponseFormat::Url] each file is downloaded in dedicated Tokio task.
pub async fn save<P: AsRef<Path>>(&self, dir: P) -> Result<Vec<PathBuf>, OpenAIError> {
let exists = match Path::try_exists(dir.as_ref()) {
Ok(exists) => exists,
Err(e) => return Err(OpenAIError::FileSaveError(e.to_string())),
Expand All @@ -135,38 +137,40 @@ impl ImageResponse {
handles.push(tokio::spawn(async move { id.save(dir_buf).await }));
}

let result = futures::future::join_all(handles).await;

let errors: Vec<OpenAIError> = result
.into_iter()
.filter(|r| r.is_err() || r.as_ref().ok().unwrap().is_err())
.map(|r| match r {
Err(e) => OpenAIError::FileSaveError(e.to_string()),
Ok(inner) => inner.err().unwrap(),
})
.collect();
let results = futures::future::join_all(handles).await;
let mut errors = vec![];
let mut paths = vec![];

for result in results {
match result {
Ok(inner) => match inner {
Ok(path) => paths.push(path),
Err(e) => errors.push(e),
},
Err(e) => errors.push(OpenAIError::FileSaveError(e.to_string())),
}
}

if errors.len() > 0 {
if errors.is_empty() {
Ok(paths)
} else {
Err(OpenAIError::FileSaveError(
errors
.into_iter()
.map(|e| e.to_string())
.collect::<Vec<String>>()
.join("; "),
))
} else {
Ok(())
}
}
}

impl ImageData {
async fn save<P: AsRef<Path>>(&self, dir: P) -> Result<(), OpenAIError> {
async fn save<P: AsRef<Path>>(&self, dir: P) -> Result<PathBuf, OpenAIError> {
match self {
ImageData::Url(url) => download_url(url, dir).await?,
ImageData::B64Json(b64_json) => save_b64(b64_json, dir).await?,
ImageData::Url(url) => download_url(url, dir).await,
ImageData::B64Json(b64_json) => save_b64(b64_json, dir).await,
}
Ok(())
}
}

Expand Down
6 changes: 5 additions & 1 deletion examples/create-image-b64-json/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Response already contains image data in base64 format.
// Save each image to ./data directory in dedicated Tokio task.
// Directory is created if it doesn't exist.
response.save("./data").await?;
let paths = response.save("./data").await?;

paths
.iter()
.for_each(|path| println!("Image file path: {}", path.display()));

Ok(())
}
6 changes: 5 additions & 1 deletion examples/create-image/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Download and save images to ./data directory.
// Each url is downloaded and saved in dedicated Tokio task.
// Directory is created if it doesn't exist.
response.save("./data").await?;
let paths = response.save("./data").await?;

paths
.iter()
.for_each(|path| println!("Image file path: {}", path.display()));

Ok(())
}