diff --git a/Cargo.lock b/Cargo.lock index 6554c89..7974fe6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -232,19 +232,6 @@ dependencies = [ "windows-sys 0.59.0", ] -[[package]] -name = "console" -version = "0.15.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b" -dependencies = [ - "encode_unicode", - "libc", - "once_cell", - "unicode-width 0.2.0", - "windows-sys 0.59.0", -] - [[package]] name = "console" version = "0.16.3" @@ -714,7 +701,7 @@ dependencies = [ "dirs", "futures", "http", - "indicatif 0.18.4", + "indicatif", "libc", "log", "native-tls", @@ -1020,26 +1007,13 @@ dependencies = [ "hashbrown", ] -[[package]] -name = "indicatif" -version = "0.17.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "183b3088984b400f4cfac3620d5e076c84da5364016b4f49473de574b2586235" -dependencies = [ - "console 0.15.10", - "number_prefix", - "portable-atomic", - "unicode-width 0.2.0", - "web-time", -] - [[package]] name = "indicatif" version = "0.18.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "25470f23803092da7d239834776d653104d551bc4d7eacaf31e6837854b8e9eb" dependencies = [ - "console 0.16.3", + "console", "portable-atomic", "unicode-width 0.2.0", "unit-prefix", @@ -1225,12 +1199,6 @@ dependencies = [ "libc", ] -[[package]] -name = "number_prefix" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" - [[package]] name = "object" version = "0.36.7" @@ -1406,7 +1374,7 @@ dependencies = [ "dirs", "env_logger", "hf-hub", - "indicatif 0.17.11", + "indicatif", "log", "prettytable-rs", "reqwest", diff --git a/Cargo.toml b/Cargo.toml index eab4ff7..e05dce3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ serde = { version = "1.0", features = ["derive"] } serde_derive = "1.0" env_logger = "0.11.6" log = 
"0.4.26" -indicatif = "0.17.11" +indicatif = "0.18" dirs = "6.0.0" hf-hub = { version = "0.5.0", features = ["tokio"] } colored = "2.1" diff --git a/Makefile b/Makefile index 61ed3f8..0e72459 100644 --- a/Makefile +++ b/Makefile @@ -7,3 +7,7 @@ test: lint: cargo fmt --all -- --check cargo clippy --all-targets --all-features -- -D warnings + +format: + cargo fmt --all + cargo clippy --fix --allow-dirty \ No newline at end of file diff --git a/src/cli/commands.rs b/src/cli/commands.rs index 8a04b59..128d137 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -5,7 +5,7 @@ use crate::downloader::downloader::Downloader; use crate::downloader::huggingface::HuggingFaceDownloader; use crate::registry::model_registry::ModelRegistry; use crate::system::system_info::SystemInfo; -use crate::util::format::{format_size, format_time_ago}; +use crate::utils::format::{format_size_decimal, format_time_ago}; #[derive(Parser)] #[command(name = "PUMA")] @@ -92,7 +92,7 @@ pub async fn run(cli: Cli) { table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "CREATED"]); for model in models { - let size_str = format_size(model.size); + let size_str = format_size_decimal(model.size); let revision_short = if model.revision.len() > 8 { &model.revision[..8] @@ -118,7 +118,7 @@ pub async fn run(cli: Cli) { Provider::Huggingface => { let downloader = HuggingFaceDownloader::new(); if let Err(e) = downloader.download_model(&args.model).await { - eprintln!("Error downloading model: {}", e); + eprintln!("❌ Error downloading model: {}", e); std::process::exit(1); } } diff --git a/src/downloader/huggingface.rs b/src/downloader/huggingface.rs index baaa7cd..0e16a3b 100644 --- a/src/downloader/huggingface.rs +++ b/src/downloader/huggingface.rs @@ -1,34 +1,31 @@ use colored::Colorize; -use log::{debug, info}; -use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::Arc; +use log::debug; use hf_hub::api::tokio::{ApiBuilder, Progress}; -use indicatif::{MultiProgress, ProgressBar, 
ProgressStyle}; use crate::downloader::downloader::{DownloadError, Downloader}; +use crate::downloader::progress::{DownloadProgressManager, FileProgress}; use crate::registry::model_registry::{ModelInfo, ModelRegistry}; -use crate::util::file; +use crate::utils::file::{self, format_model_name}; +/// Adapter to bridge HuggingFace's Progress trait with our FileProgress #[derive(Clone)] -struct FileProgressBar { - pb: ProgressBar, - total_size: Arc<AtomicU64>, -} +struct HfProgressAdapter { + progress: FileProgress, } -impl Progress for FileProgressBar { +impl Progress for HfProgressAdapter { async fn init(&mut self, size: usize, _filename: &str) { - self.pb.set_length(size as u64); - self.pb.reset(); - self.pb.tick(); // Force render with correct size - self.total_size.fetch_add(size as u64, Ordering::Relaxed); + self.progress.init(size as u64); } async fn update(&mut self, size: usize) { - self.pb.inc(size as u64); + self.progress.update(size as u64); } - async fn finish(&mut self) {} + async fn finish(&mut self) { + self.progress.finish(); + } } pub struct HuggingFaceDownloader; @@ -49,7 +46,7 @@ impl Downloader for HuggingFaceDownloader { async fn download_model(&self, name: &str) -> Result<(), DownloadError> { let start_time = std::time::Instant::now(); - info!("Downloading model {} from Hugging Face...", name); + debug!("Downloading model {} from Hugging Face...", name); // Use unified PUMA cache directory let cache_dir = file::huggingface_cache_dir(); @@ -65,6 +62,8 @@ impl Downloader for HuggingFaceDownloader { DownloadError::ApiError(format!("Failed to initialize Hugging Face API: {}", e)) })?; + println!("🐆 pulling manifest"); + // Download the entire model repository using snapshot download let repo = api.model(name.to_string()); @@ -84,9 +83,6 @@ impl Downloader for HuggingFaceDownloader { debug!("Model info for {}: {:?}", name, model_info); - // Create multi-progress for parallel downloads - let multi_progress = Arc::new(MultiProgress::new()); - // Calculate the 
longest filename for proper alignment let max_filename_len = model_info .siblings @@ -95,54 +91,59 @@ impl Downloader for HuggingFaceDownloader { .max() .unwrap_or(30); - // Progress bar style with block characters (chart-like, not #) - let template = format!( - "{{msg:<{width}}} [{{elapsed_precise}}] {{bar:60.white}} {{bytes}}/{{total_bytes}}", - width = max_filename_len - ); - let style = ProgressStyle::default_bar() - .template(&template) - .unwrap() - .progress_chars("▇▆▅▄▃▂▁ "); + // Create progress manager + let progress_manager = DownloadProgressManager::new(max_filename_len); - // Download all files in parallel - let mut tasks = Vec::new(); + // Calculate cache paths + let model_cache_path = cache_dir.join(format_model_name(name)); let sha = model_info.sha.clone(); - let total_size = Arc::new(AtomicU64::new(0)); + let snapshot_path = model_cache_path.join("snapshots").join(&sha); + + // Process all files in manifest order (cached files show as instantly complete) + let mut tasks = Vec::new(); for sibling in model_info.siblings { let api_clone = api.clone(); let model_name = name.to_string(); let filename = sibling.rfilename.clone(); - let total_size_clone = Arc::clone(&total_size); - - let pb = multi_progress.add(ProgressBar::hidden()); - pb.set_style(style.clone()); - pb.set_message(filename.clone()); + let progress_manager_clone = progress_manager.clone(); + let snapshot_path_clone = snapshot_path.clone(); let task = tokio::spawn(async move { - debug!("Downloading: {}", filename); - let repo = api_clone.model(model_name); - let progress = FileProgressBar { - pb: pb.clone(), - total_size: total_size_clone, - }; - let result = repo.download_with_progress(&filename, progress).await; + // Check if file exists in cache + let cached_file_path = snapshot_path_clone.join(&filename); + if cached_file_path.exists() { + debug!("File {} found in cache, showing as complete", filename); + + // Create progress bar and mark as instantly complete + let mut file_progress = 
progress_manager_clone.create_file_progress(&filename); + let file_size = cached_file_path.metadata().map(|m| m.len()).unwrap_or(0); + file_progress.init(file_size); + file_progress.update(file_size); + file_progress.finish(); - match &result { - Ok(_) => { - pb.finish(); - } - Err(_) => { - pb.abandon(); - } + return Ok(()); } - result.map_err(|e| { - DownloadError::NetworkError(format!("Failed to download {}: {}", filename, e)) - }) + // File not in cache, download with progress + debug!("Downloading: {}", filename); + let file_progress = progress_manager_clone.create_file_progress(&filename); + let progress = HfProgressAdapter { + progress: file_progress, + }; + + repo.download_with_progress(&filename, progress) + .await + .map_err(|e| { + DownloadError::NetworkError(format!( + "Failed to download {}: {}", + filename, e + )) + })?; + + Ok(()) }); tasks.push(task); @@ -157,8 +158,8 @@ impl Downloader for HuggingFaceDownloader { let elapsed_time = start_time.elapsed(); // Get accumulated size from downloads - let downloaded_size = total_size.load(Ordering::Relaxed); - let model_cache_path = cache_dir.join(format!("models--{}", name.replace("/", "--"))); + let downloaded_size = progress_manager.total_downloaded_bytes(); + let model_cache_path = cache_dir.join(format_model_name(name)); // Register the model let model_info_record = ModelInfo { diff --git a/src/downloader/mod.rs b/src/downloader/mod.rs index 2bdb3f5..39ef068 100644 --- a/src/downloader/mod.rs +++ b/src/downloader/mod.rs @@ -1,3 +1,4 @@ #[allow(clippy::module_inception)] pub mod downloader; pub mod huggingface; +pub mod progress; diff --git a/src/downloader/progress.rs b/src/downloader/progress.rs new file mode 100644 index 0000000..69f3f8d --- /dev/null +++ b/src/downloader/progress.rs @@ -0,0 +1,109 @@ +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; +use std::sync::atomic::{AtomicU64, Ordering}; +use std::sync::Arc; + +/// Manages multi-file download progress tracking +/// +/// # Example 
+/// ```rust +/// use puma::downloader::progress::DownloadProgressManager; +/// +/// let progress_manager = DownloadProgressManager::new(30); +/// let mut file_progress = progress_manager.create_file_progress("model.bin"); +/// +/// file_progress.init(1024 * 1024); // 1 MB +/// file_progress.update(512 * 1024); // Downloaded 512 KB +/// file_progress.finish(); +/// +/// let total = progress_manager.total_downloaded_bytes(); +/// ``` +#[derive(Clone)] +pub struct DownloadProgressManager { + multi_progress: Arc<MultiProgress>, + total_size: Arc<AtomicU64>, + style: ProgressStyle, +} + +impl DownloadProgressManager { + /// Create a new progress manager with aligned file names + pub fn new(max_filename_len: usize) -> Self { + let multi_progress = Arc::new(MultiProgress::new()); + + let template = format!( + "{{msg:<{width}}} [{{elapsed_precise}}] {{bar:60.white}} {{bytes}}/{{total_bytes}}", + width = max_filename_len + ); + let style = ProgressStyle::default_bar() + .template(&template) + .unwrap() + .progress_chars("▇▆▅▄▃▂▁ "); + + Self { + multi_progress, + total_size: Arc::new(AtomicU64::new(0)), + style, + } + } + + /// Create a new progress bar for a file download + pub fn create_file_progress(&self, filename: &str) -> FileProgress { + let pb = self.multi_progress.add(ProgressBar::hidden()); + pb.set_style(self.style.clone()); + pb.set_message(filename.to_string()); + + FileProgress { + pb, + total_size: Arc::clone(&self.total_size), + } + } + + /// Get the total accumulated download size + pub fn total_downloaded_bytes(&self) -> u64 { + self.total_size.load(Ordering::Relaxed) + } +} + +/// Tracks progress for a single file download +#[derive(Clone)] +pub struct FileProgress { + pb: ProgressBar, + total_size: Arc<AtomicU64>, +} + +impl FileProgress { + /// Initialize progress bar with file size + pub fn init(&mut self, size: u64) { + self.pb.set_length(size); + self.pb.reset(); + self.pb.tick(); + self.total_size.fetch_add(size, Ordering::Relaxed); + } + + /// Update progress with downloaded bytes + 
pub fn update(&mut self, bytes: u64) { + self.pb.inc(bytes); + } + + /// Mark download as complete + pub fn finish(&mut self) { + self.pb.finish(); + } + + /// Mark download as failed + #[allow(dead_code)] + pub fn abandon(&mut self) { + self.pb.abandon(); + } + + /// Get the inner progress bar (for provider-specific adapters) + #[allow(dead_code)] + pub fn progress_bar(&self) -> &ProgressBar { + &self.pb + } + + /// Get the total size tracker (for provider-specific adapters) + #[allow(dead_code)] + pub fn total_size_tracker(&self) -> Arc<AtomicU64> { + Arc::clone(&self.total_size) + } +} diff --git a/src/main.rs b/src/main.rs index 02d7f23..57e4eb5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,13 +2,13 @@ mod cli; mod downloader; mod registry; mod system; -mod util; +mod utils; use clap::Parser; use tokio::runtime::Builder; use crate::cli::commands::{run, Cli}; -use crate::util::file; +use crate::utils::file; fn main() { // Initialize logger. diff --git a/src/registry/model_registry.rs b/src/registry/model_registry.rs index ccc1eab..2154d6d 100644 --- a/src/registry/model_registry.rs +++ b/src/registry/model_registry.rs @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize}; use std::fs; use std::path::PathBuf; -use crate::util::file; +use crate::utils::file; #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ModelInfo { @@ -96,7 +96,7 @@ impl ModelRegistry { self.unregister_model(name)?; println!( - "\n{} {} {}", + "{} {} {}", "✓".green().bold(), "Successfully removed model".bright_white(), name.cyan().bold() diff --git a/src/system/system_info.rs b/src/system/system_info.rs index 8abb7df..b716949 100644 --- a/src/system/system_info.rs +++ b/src/system/system_info.rs @@ -5,8 +5,8 @@ use std::path::PathBuf; use sysinfo::System; use crate::registry::model_registry::ModelRegistry; -use crate::util::file; -use crate::util::format::format_size; +use crate::utils::file; +use crate::utils::format::format_size; #[derive(Debug, Serialize, Deserialize)] pub struct SystemInfo 
{ diff --git a/src/util/file.rs b/src/util/file.rs deleted file mode 100644 index 602a1ef..0000000 --- a/src/util/file.rs +++ /dev/null @@ -1,32 +0,0 @@ -use std::fs; -use std::path::PathBuf; - -use dirs::home_dir; - -pub fn create_folder_if_not_exists(folder_path: &PathBuf) -> std::io::Result<()> { - fs::create_dir_all(folder_path)?; - Ok(()) -} - -pub fn root_home() -> PathBuf { - // Allow tests to override PUMA home directory - if let Ok(test_home) = std::env::var("PUMA_HOME") { - PathBuf::from(test_home) - } else { - let home = home_dir().expect("Failed to get home directory"); - home.join(".puma") - } -} - -pub fn cache_dir() -> PathBuf { - root_home().join("cache") -} - -pub fn huggingface_cache_dir() -> PathBuf { - cache_dir().join("huggingface") -} - -#[allow(dead_code)] -pub fn modelscope_cache_dir() -> PathBuf { - cache_dir().join("modelscope") -} diff --git a/src/utils/file.rs b/src/utils/file.rs new file mode 100644 index 0000000..53ade50 --- /dev/null +++ b/src/utils/file.rs @@ -0,0 +1,95 @@ +use std::fs; +use std::path::PathBuf; + +use dirs::home_dir; + +pub fn create_folder_if_not_exists(folder_path: &PathBuf) -> std::io::Result<()> { + fs::create_dir_all(folder_path)?; + Ok(()) +} + +pub fn root_home() -> PathBuf { + // Allow tests to override PUMA home directory + if let Ok(test_home) = std::env::var("PUMA_HOME") { + PathBuf::from(test_home) + } else { + let home = home_dir().expect("Failed to get home directory"); + home.join(".puma") + } +} + +pub fn cache_dir() -> PathBuf { + root_home().join("cache") +} + +pub fn huggingface_cache_dir() -> PathBuf { + cache_dir().join("huggingface") +} + +#[allow(dead_code)] +pub fn modelscope_cache_dir() -> PathBuf { + cache_dir().join("modelscope") +} + +/// Format model name for HuggingFace cache directory +/// Converts "owner/model" to "models--owner--model" +pub fn format_model_name(name: &str) -> String { + format!("models--{}", name.replace("/", "--")) +} + +/// List all files recursively in a directory 
+#[allow(dead_code)] +pub fn list_files_recursive(dir: &std::path::Path) -> std::io::Result<Vec<PathBuf>> { + let mut files = Vec::new(); + if dir.is_dir() { + for entry in fs::read_dir(dir)? { + let entry = entry?; + let path = entry.path(); + if path.is_dir() { + files.extend(list_files_recursive(&path)?); + } else { + files.push(path); + } + } + } + Ok(files) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_format_model_name_basic() { + assert_eq!(format_model_name("owner/model"), "models--owner--model"); + } + + #[test] + fn test_format_model_name_complex() { + assert_eq!( + format_model_name("Qwen/Qwen3.5-2B"), + "models--Qwen--Qwen3.5-2B" + ); + } + + #[test] + fn test_format_model_name_multiple_slashes() { + assert_eq!( + format_model_name("org/team/model"), + "models--org--team--model" + ); + } + + #[test] + fn test_format_model_name_no_slash() { + assert_eq!(format_model_name("model"), "models--model"); + } + + #[test] + fn test_format_model_name_special_chars() { + assert_eq!( + format_model_name("InftyAI/tiny-random-gpt2"), + "models--InftyAI--tiny-random-gpt2" + ); + } +} diff --git a/src/util/format.rs b/src/utils/format.rs similarity index 76% rename from src/util/format.rs rename to src/utils/format.rs index 70e109b..0e193ab 100644 --- a/src/util/format.rs +++ b/src/utils/format.rs @@ -17,6 +17,23 @@ pub fn format_size(bytes: u64) -> String { } } +/// Format byte size to human-readable format using decimal units (B, KB, MB, GB) +pub fn format_size_decimal(bytes: u64) -> String { + const KB: f64 = 1000.0; + const MB: f64 = 1000.0 * 1000.0; + const GB: f64 = 1000.0 * 1000.0 * 1000.0; + + if bytes as f64 >= GB { + format!("{:.2} GB", bytes as f64 / GB) + } else if bytes as f64 >= MB { + format!("{:.2} MB", bytes as f64 / MB) + } else if bytes as f64 >= KB { + format!("{:.2} KB", bytes as f64 / KB) + } else { + format!("{} B", bytes) + } +} + /// Format RFC3339 timestamp to human-readable relative time (e.g., "2 hours ago") pub fn 
format_time_ago(timestamp: &str) -> String { // Try to parse as RFC3339 @@ -213,4 +230,47 @@ mod tests { let invalid = "not-a-timestamp"; assert_eq!(format_time_ago(invalid), "not-a-timestamp"); } + + #[test] + fn test_format_size_decimal_bytes() { + assert_eq!(format_size_decimal(0), "0 B"); + assert_eq!(format_size_decimal(1), "1 B"); + assert_eq!(format_size_decimal(999), "999 B"); + } + + #[test] + fn test_format_size_decimal_kilobytes() { + assert_eq!(format_size_decimal(1000), "1.00 KB"); + assert_eq!(format_size_decimal(1500), "1.50 KB"); + assert_eq!(format_size_decimal(10000), "10.00 KB"); + assert_eq!(format_size_decimal(999_999), "1000.00 KB"); + } + + #[test] + fn test_format_size_decimal_megabytes() { + assert_eq!(format_size_decimal(1_000_000), "1.00 MB"); + assert_eq!(format_size_decimal(1_500_000), "1.50 MB"); + assert_eq!(format_size_decimal(10_000_000), "10.00 MB"); + assert_eq!(format_size_decimal(500_000_000), "500.00 MB"); + } + + #[test] + fn test_format_size_decimal_gigabytes() { + assert_eq!(format_size_decimal(1_000_000_000), "1.00 GB"); + assert_eq!(format_size_decimal(1_500_000_000), "1.50 GB"); + assert_eq!(format_size_decimal(10_000_000_000), "10.00 GB"); + assert_eq!(format_size_decimal(100_000_000_000), "100.00 GB"); + } + + #[test] + fn test_format_size_decimal_realistic_model_sizes() { + // Small model (100 MB) + assert_eq!(format_size_decimal(100_000_000), "100.00 MB"); + + // Medium model (7 GB) + assert_eq!(format_size_decimal(7_000_000_000), "7.00 GB"); + + // Large model (65 GB) + assert_eq!(format_size_decimal(65_000_000_000), "65.00 GB"); + } } diff --git a/src/util/mod.rs b/src/utils/mod.rs similarity index 100% rename from src/util/mod.rs rename to src/utils/mod.rs