Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 3 additions & 35 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ serde = { version = "1.0", features = ["derive"] }
serde_derive = "1.0"
env_logger = "0.11.6"
log = "0.4.26"
indicatif = "0.17.11"
indicatif = "0.18"
dirs = "6.0.0"
hf-hub = { version = "0.5.0", features = ["tokio"] }
colored = "2.1"
Expand Down
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ test:
lint:
cargo fmt --all -- --check
cargo clippy --all-targets --all-features -- -D warnings

format:
cargo fmt --all
cargo clippy --fix --allow-dirty
6 changes: 3 additions & 3 deletions src/cli/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::downloader::downloader::Downloader;
use crate::downloader::huggingface::HuggingFaceDownloader;
use crate::registry::model_registry::ModelRegistry;
use crate::system::system_info::SystemInfo;
use crate::util::format::{format_size, format_time_ago};
use crate::utils::format::{format_size_decimal, format_time_ago};

#[derive(Parser)]
#[command(name = "PUMA")]
Expand Down Expand Up @@ -92,7 +92,7 @@ pub async fn run(cli: Cli) {
table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "CREATED"]);

for model in models {
let size_str = format_size(model.size);
let size_str = format_size_decimal(model.size);

let revision_short = if model.revision.len() > 8 {
&model.revision[..8]
Expand All @@ -118,7 +118,7 @@ pub async fn run(cli: Cli) {
Provider::Huggingface => {
let downloader = HuggingFaceDownloader::new();
if let Err(e) = downloader.download_model(&args.model).await {
eprintln!("Error downloading model: {}", e);
eprintln!("Error downloading model: {}", e);
std::process::exit(1);
}
}
Expand Down
111 changes: 56 additions & 55 deletions src/downloader/huggingface.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,31 @@
use colored::Colorize;
use log::{debug, info};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use log::debug;

use hf_hub::api::tokio::{ApiBuilder, Progress};
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};

use crate::downloader::downloader::{DownloadError, Downloader};
use crate::downloader::progress::{DownloadProgressManager, FileProgress};
use crate::registry::model_registry::{ModelInfo, ModelRegistry};
use crate::util::file;
use crate::utils::file::{self, format_model_name};

/// Adapter to bridge HuggingFace's Progress trait with our FileProgress
#[derive(Clone)]
struct FileProgressBar {
pb: ProgressBar,
total_size: Arc<AtomicU64>,
struct HfProgressAdapter {
progress: FileProgress,
}

impl Progress for FileProgressBar {
impl Progress for HfProgressAdapter {
async fn init(&mut self, size: usize, _filename: &str) {
self.pb.set_length(size as u64);
self.pb.reset();
self.pb.tick(); // Force render with correct size
self.total_size.fetch_add(size as u64, Ordering::Relaxed);
self.progress.init(size as u64);
}

async fn update(&mut self, size: usize) {
self.pb.inc(size as u64);
self.progress.update(size as u64);
}

async fn finish(&mut self) {}
async fn finish(&mut self) {
self.progress.finish();
}
}

pub struct HuggingFaceDownloader;
Expand All @@ -49,7 +46,7 @@ impl Downloader for HuggingFaceDownloader {
async fn download_model(&self, name: &str) -> Result<(), DownloadError> {
let start_time = std::time::Instant::now();

info!("Downloading model {} from Hugging Face...", name);
debug!("Downloading model {} from Hugging Face...", name);

// Use unified PUMA cache directory
let cache_dir = file::huggingface_cache_dir();
Expand All @@ -65,6 +62,8 @@ impl Downloader for HuggingFaceDownloader {
DownloadError::ApiError(format!("Failed to initialize Hugging Face API: {}", e))
})?;

println!("🐆 pulling manifest");

// Download the entire model repository using snapshot download
let repo = api.model(name.to_string());

Expand All @@ -84,9 +83,6 @@ impl Downloader for HuggingFaceDownloader {

debug!("Model info for {}: {:?}", name, model_info);

// Create multi-progress for parallel downloads
let multi_progress = Arc::new(MultiProgress::new());

// Calculate the longest filename for proper alignment
let max_filename_len = model_info
.siblings
Expand All @@ -95,54 +91,59 @@ impl Downloader for HuggingFaceDownloader {
.max()
.unwrap_or(30);

// Progress bar style with block characters (chart-like, not #)
let template = format!(
"{{msg:<{width}}} [{{elapsed_precise}}] {{bar:60.white}} {{bytes}}/{{total_bytes}}",
width = max_filename_len
);
let style = ProgressStyle::default_bar()
.template(&template)
.unwrap()
.progress_chars("▇▆▅▄▃▂▁ ");
// Create progress manager
let progress_manager = DownloadProgressManager::new(max_filename_len);

// Download all files in parallel
let mut tasks = Vec::new();
// Calculate cache paths
let model_cache_path = cache_dir.join(format_model_name(name));
let sha = model_info.sha.clone();
let total_size = Arc::new(AtomicU64::new(0));
let snapshot_path = model_cache_path.join("snapshots").join(&sha);

// Process all files in manifest order (cached files show as instantly complete)
let mut tasks = Vec::new();

for sibling in model_info.siblings {
let api_clone = api.clone();
let model_name = name.to_string();
let filename = sibling.rfilename.clone();
let total_size_clone = Arc::clone(&total_size);

let pb = multi_progress.add(ProgressBar::hidden());
pb.set_style(style.clone());
pb.set_message(filename.clone());
let progress_manager_clone = progress_manager.clone();
let snapshot_path_clone = snapshot_path.clone();

let task = tokio::spawn(async move {
debug!("Downloading: {}", filename);

let repo = api_clone.model(model_name);
let progress = FileProgressBar {
pb: pb.clone(),
total_size: total_size_clone,
};

let result = repo.download_with_progress(&filename, progress).await;
// Check if file exists in cache
let cached_file_path = snapshot_path_clone.join(&filename);
if cached_file_path.exists() {
debug!("File {} found in cache, showing as complete", filename);

// Create progress bar and mark as instantly complete
let mut file_progress = progress_manager_clone.create_file_progress(&filename);
let file_size = cached_file_path.metadata().map(|m| m.len()).unwrap_or(0);
file_progress.init(file_size);
file_progress.update(file_size);
file_progress.finish();

match &result {
Ok(_) => {
pb.finish();
}
Err(_) => {
pb.abandon();
}
return Ok(());
}

result.map_err(|e| {
DownloadError::NetworkError(format!("Failed to download {}: {}", filename, e))
})
// File not in cache, download with progress
debug!("Downloading: {}", filename);
let file_progress = progress_manager_clone.create_file_progress(&filename);
let progress = HfProgressAdapter {
progress: file_progress,
};

repo.download_with_progress(&filename, progress)
.await
.map_err(|e| {
DownloadError::NetworkError(format!(
"Failed to download {}: {}",
filename, e
))
})?;

Ok(())
});

tasks.push(task);
Expand All @@ -157,8 +158,8 @@ impl Downloader for HuggingFaceDownloader {
let elapsed_time = start_time.elapsed();

// Get accumulated size from downloads
let downloaded_size = total_size.load(Ordering::Relaxed);
let model_cache_path = cache_dir.join(format!("models--{}", name.replace("/", "--")));
let downloaded_size = progress_manager.total_downloaded_bytes();
let model_cache_path = cache_dir.join(format_model_name(name));

// Register the model
let model_info_record = ModelInfo {
Expand Down
1 change: 1 addition & 0 deletions src/downloader/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[allow(clippy::module_inception)]
pub mod downloader;
pub mod huggingface;
pub mod progress;
Loading
Loading