Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ PUMA Information:
| `stop` | 🚧 | Stop a running model | `puma stop <model-id>` |
| `rm` | ✅ | Remove a model | `puma rm InftyAI/tiny-random-gpt2` |
| `info` | ✅ | Display system-wide information | `puma info` |
| `inspect` | 🚧 | Return detailed information about a model or service | `puma inspect InftyAI/tiny-random-gpt2` |
| `inspect` | | Return detailed information about a model or service | `puma inspect InftyAI/tiny-random-gpt2` |
| `version` | ✅ | Show PUMA version | `puma version` |
| `help` | ✅ | Show help information | `puma help` |

Expand Down
69 changes: 64 additions & 5 deletions src/cli/commands.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ enum Commands {
/// Display system-wide information
INFO,
/// Return detailed information about a model
INSPECT,
INSPECT(InspectArgs),
/// Returns the version of PUMA.
VERSION,
}
Expand All @@ -58,6 +58,12 @@ struct RmArgs {
model: String,
}

#[derive(Parser)]
struct InspectArgs {
/// Model name to inspect (e.g., InftyAI/tiny-random-gpt2)
model: String,
}

#[derive(Debug, Clone, Default, clap::ValueEnum)]
pub enum Provider {
#[default]
Expand All @@ -70,7 +76,12 @@ pub async fn run(cli: Cli) {
match cli.command {
Commands::PS => {
let mut table = Table::new();
table.set_format(*format::consts::FORMAT_CLEAN);
table.set_format(
format::FormatBuilder::new()
.column_separator(' ')
.padding(0, 1)
.build(),
);
table.add_row(row!["NAME", "PROVIDER", "MODEL", "STATUS", "AGE"]);
table.add_row(row![
"deepseek-r1",
Expand All @@ -88,7 +99,12 @@ pub async fn run(cli: Cli) {
let models = registry.load_models().unwrap_or_default();

let mut table = Table::new();
table.set_format(*format::consts::FORMAT_CLEAN);
table.set_format(
format::FormatBuilder::new()
.column_separator(' ')
.padding(0, 1)
.build(),
);
table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "MODIFIED"]);

for model in models {
Expand Down Expand Up @@ -163,8 +179,51 @@ pub async fn run(cli: Cli) {
info.display();
}

Commands::INSPECT => {
println!("Returning detailed information about model...");
Commands::INSPECT(args) => {
let registry = ModelRegistry::new(None);

match registry.get_model(&args.model) {
Ok(Some(model)) => {
println!("Name: {}", model.name);
println!("Kind: Model");

println!("Spec:");
// Architecture section (only if info is available)
if let Some(arch) = &model.arch {
println!(" Architecture:");
if let Some(model_type) = &arch.model_type {
Comment on lines +190 to +194
println!(" Type: {}", model_type);
}
if let Some(classes) = &arch.classes {
println!(" Classes: {}", classes.join(", "));
}
if let Some(parameters) = &arch.parameters {
println!(" Parameters: {}", parameters);
}
if let Some(context_window) = arch.context_window {
println!(" Context Window: {}", context_window);
}
}
// Registry section
println!(" Registry:");
println!(" Provider: {}", model.provider);
println!(" Revision: {}", model.revision);
println!(" Size: {}", format_size_decimal(model.size));
println!(
" Modified: {}",
format_time_ago(&model.modified_at)
);
println!(" Cache Path: {}", model.cache_path);
}
Ok(None) => {
eprintln!("Model not found: {}", args.model);
std::process::exit(1);
}
Err(e) => {
eprintln!("Failed to load registry: {}", e);
std::process::exit(1);
}
}
}

Commands::VERSION => {
Expand Down
75 changes: 60 additions & 15 deletions src/downloader/huggingface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use colored::Colorize;
use log::debug;

use hf_hub::api::tokio::{ApiBuilder, Progress};
use indicatif::{ProgressBar, ProgressStyle};

use crate::downloader::downloader::{DownloadError, Downloader};
use crate::downloader::progress::{DownloadProgressManager, FileProgress};
use crate::registry::model_registry::{ModelInfo, ModelRegistry};
use crate::registry::model_registry::{ModelArchitecture, ModelInfo, ModelRegistry};
use crate::utils::file::{self, format_model_name};

/// Adapter to bridge HuggingFace's Progress trait with our FileProgress
Expand Down Expand Up @@ -62,7 +63,15 @@ impl Downloader for HuggingFaceDownloader {
DownloadError::ApiError(format!("Failed to initialize Hugging Face API: {}", e))
})?;

println!("🐆 pulling manifest");
// Create a simple spinner for manifest pulling
let manifest_spinner = ProgressBar::new_spinner();
manifest_spinner.set_style(
ProgressStyle::default_spinner()
.tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏")
.template("pulling manifest {spinner:.white}")
.unwrap(),
);
manifest_spinner.enable_steady_tick(std::time::Duration::from_millis(80));

// Download the entire model repository using snapshot download
let repo = api.model(name.to_string());
Expand All @@ -81,6 +90,10 @@ impl Downloader for HuggingFaceDownloader {
}
})?;
Comment on lines 66 to 91

// Stop manifest spinner and print clean message
manifest_spinner.finish_and_clear();
println!("pulling manifest");

debug!("Model info for {}: {:?}", name, model_info);

// Calculate the longest filename for proper alignment
Expand All @@ -91,6 +104,8 @@ impl Downloader for HuggingFaceDownloader {
.max()
.unwrap_or(30);

// Add extra space for "pulling " prefix
let max_filename_len = max_filename_len + 8;
// Create progress manager
let progress_manager = DownloadProgressManager::new(max_filename_len);

Expand Down Expand Up @@ -124,8 +139,9 @@ impl Downloader for HuggingFaceDownloader {
debug!("File {} found in cache, showing as complete", filename);

// Create progress bar for cached file (no speed display)
let display_name = format!("pulling {}", filename);
let mut file_progress =
progress_manager_clone.create_cached_file_progress(&filename);
progress_manager_clone.create_cached_file_progress(&display_name);
let file_size = cached_file_path.metadata().map(|m| m.len()).unwrap_or(0);
file_progress.init(file_size);
file_progress.update(file_size);
Expand All @@ -136,7 +152,8 @@ impl Downloader for HuggingFaceDownloader {

// File not in cache, download with progress
debug!("Downloading: {}", filename);
let file_progress = progress_manager_clone.create_file_progress(&filename);
let display_name = format!("pulling {}", filename);
let file_progress = progress_manager_clone.create_file_progress(&display_name);
let progress = HfProgressAdapter {
progress: file_progress,
};
Expand All @@ -156,37 +173,65 @@ impl Downloader for HuggingFaceDownloader {
tasks.push(task);
}

// Give tasks a moment to start and create their progress bars
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
Comment on lines +176 to +177

// Show spinner at the bottom after all progress bars are created (only if not fully cached)
let spinner = if !model_totally_cached {
Some(progress_manager.create_spinner())
} else {
None
};
Comment on lines +176 to +184

// Wait for all downloads to complete
for task in tasks {
task.await
.map_err(|e| DownloadError::ApiError(format!("Task join error: {}", e)))??;
}

// Finish spinner after downloads complete
if let Some(spinner) = &spinner {
spinner.finish_and_clear();
}

Comment on lines 186 to +196
let elapsed_time = start_time.elapsed();

// Get accumulated size from downloads
let downloaded_size = progress_manager.total_downloaded_bytes();
let model_cache_path = cache_dir.join(format_model_name(name));

// Register the model
let model_info_record = ModelInfo {
name: name.to_string(),
provider: "huggingface".to_string(),
revision: sha,
size: downloaded_size,
modified_at: chrono::Local::now().to_rfc3339(),
cache_path: model_cache_path.to_string_lossy().to_string(),
};

// Register the model only if not totally cached
if !model_totally_cached {
// Extract architecture info from config.json
Comment on lines +203 to +205
let config_path = snapshot_path.join("config.json");
let arch = if config_path.exists() {
std::fs::read_to_string(&config_path)
.ok()
.and_then(|content| serde_json::from_str::<serde_json::Value>(&content).ok())
.and_then(|config| ModelArchitecture::from_config(&config))
} else {
None
};

let model_info_record = ModelInfo {
name: name.to_string(),
provider: "huggingface".to_string(),
revision: sha,
size: downloaded_size,
modified_at: chrono::Local::now().to_rfc3339(),
cache_path: model_cache_path.to_string_lossy().to_string(),
arch,
};

let registry = ModelRegistry::new(None);
registry
.register_model(model_info_record)
.map_err(|e| DownloadError::ApiError(format!("Failed to register model: {}", e)))?;
}
Comment on lines +203 to 230

// Print success message
println!(
"\n{} {} {} {} {:.2?}",
"{} {} {} {} {:.2?}",
"✓".green().bold(),
"Successfully downloaded model".bright_white(),
name.cyan().bold(),
Expand Down
13 changes: 13 additions & 0 deletions src/downloader/progress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,19 @@ impl DownloadProgressManager {
pub fn total_downloaded_bytes(&self) -> u64 {
self.total_size.load(Ordering::Relaxed)
}

/// Create a spinner progress bar (for post-download operations)
pub fn create_spinner(&self) -> ProgressBar {
let pb = self.multi_progress.add(ProgressBar::new_spinner());
pb.set_style(
ProgressStyle::default_spinner()
.tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏")
.template("{spinner} ")
.unwrap(),
);
pb.enable_steady_tick(std::time::Duration::from_millis(80));
pb
}
}

/// Tracks progress for a single file download
Expand Down
Loading
Loading