Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ serde_json = "1.0"
sysinfo = "0.32"
rusqlite = { version = "0.32", features = ["bundled"] }
rusqlite_migration = "1.3"
regex = "1.11"

[dev-dependencies]
tempfile = "3.12"
171 changes: 42 additions & 129 deletions src/cli/commands.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use clap::{Parser, Subcommand};
use prettytable::{format, row, Table};

use crate::cli::{inspect, ls, rm};
use crate::downloader::downloader::Downloader;
use crate::downloader::huggingface::HuggingFaceDownloader;
use crate::registry::model_registry::ModelRegistry;
Expand All @@ -21,7 +22,7 @@ enum Commands {
/// List running models
PS,
/// List local models
LS,
LS(LsArgs),
/// Download a model from a model provider
PULL(PullArgs),
/// Create and run a new model
Expand All @@ -38,9 +39,19 @@ enum Commands {
VERSION,
}

#[derive(Parser)]
struct LsArgs {
/// Optional model name pattern to filter (e.g., qwen, openai/*)
Copy link

Copilot AI Apr 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ls help text suggests patterns like openai/*, but ls treats this as a regex and openai/* is not a valid regex (it will error because * repeats the preceding token /). Either update the examples to valid regex (e.g., ^openai/.*) or implement glob-style matching if that's the intended UX.

Suggested change
/// Optional model name pattern to filter (e.g., qwen, openai/*)
/// Optional model name regex to filter (e.g., qwen, ^openai/.*)

Copilot uses AI. Check for mistakes.
pattern: Option<String>,

/// Advanced filter using SQL WHERE conditions (e.g., author=inftyai,license=mit)
#[arg(short = 'l', long, value_name = "KEY=VALUE,...")]
query: Option<String>,
}

#[derive(Parser)]
struct PullArgs {
/// Model name to download (e.g., InftyAI/tiny-random-gpt2)
/// Model name to download (e.g., inftyai/tiny-random-gpt2)
model: String,
#[arg(
short = 'p',
Expand All @@ -54,13 +65,13 @@ struct PullArgs {

#[derive(Parser)]
struct RmArgs {
/// Model name to remove (e.g., InftyAI/tiny-random-gpt2)
/// Model name to remove (e.g., inftyai/tiny-random-gpt2)
model: String,
}

#[derive(Parser)]
struct InspectArgs {
/// Model name to inspect (e.g., InftyAI/tiny-random-gpt2)
/// Model name to inspect (e.g., inftyai/tiny-random-gpt2)
model: String,
}

Expand Down Expand Up @@ -94,9 +105,17 @@ pub async fn run(cli: Cli) {
table.printstd();
}

Commands::LS => {
Commands::LS(args) => {
let registry = ModelRegistry::new(None);
let models = registry.load_models().unwrap_or_default();

let models =
match ls::execute(&registry, args.pattern.as_deref(), args.query.as_deref()) {
Ok(models) => models,
Err(e) => {
eprintln!("{}", e);
std::process::exit(1);
}
};

let mut table = Table::new();
table.set_format(
Expand All @@ -105,7 +124,9 @@ pub async fn run(cli: Cli) {
.padding(0, 1)
.build(),
);
table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "AGE"]);
table.add_row(row![
"MODEL", "TASK", "PROVIDER", "REVISION", "SIZE", "CREATED"
]);
for model in models {
let size_str = format_size_decimal(model.metadata.artifact.size);

Expand All @@ -117,8 +138,11 @@ pub async fn run(cli: Cli) {

let created_str = format_time_ago(&model.created_at);

let model_task = model.task.as_deref().unwrap_or("N/A");

table.add_row(row![
model.name,
model_task,
model.provider,
revision_short,
size_str,
Expand Down Expand Up @@ -153,23 +177,9 @@ pub async fn run(cli: Cli) {
Commands::RM(args) => {
let registry = ModelRegistry::new(None);

// Check if model exists first
match registry.get_model(&args.model) {
Ok(Some(_)) => {
// Delete model (cache + registry)
if let Err(e) = registry.remove_model(&args.model) {
eprintln!("Failed to remove model: {}", e);
std::process::exit(1);
}
}
Ok(None) => {
eprintln!("Model not found: {}", args.model);
std::process::exit(1);
}
Err(e) => {
eprintln!("Failed to load registry: {}", e);
std::process::exit(1);
}
if let Err(e) = rm::execute(&registry, &args.model) {
eprintln!("{}", e);
std::process::exit(1);
}
}

Expand All @@ -181,81 +191,10 @@ pub async fn run(cli: Cli) {
Commands::INSPECT(args) => {
let registry = ModelRegistry::new(None);

match registry.get_model(&args.model) {
Ok(Some(model)) => {
println!("Name: {}", model.name);
println!("Kind: Model");
println!("Spec:");
println!(
" Author: {}",
model.author.as_deref().unwrap_or("N/A")
);
println!(
" Type: {}",
model.r#type.as_deref().unwrap_or("N/A")
);
println!(
" License: {}",
model
.license
.as_ref()
.map(|s| s.to_uppercase())
.unwrap_or_else(|| "N/A".to_string())
);
println!(
" Model Series: {}",
model.model_series.as_deref().unwrap_or("N/A")
);
println!(
" Context Window: {}",
model
.metadata
.context_window
.map(|w| crate::utils::format::format_parameters(w as u64))
.unwrap_or_else(|| "N/A".to_string())
);
if let Some(st) = &model.metadata.safetensors {
println!(" Safetensors:");
if let Some(total) = st.get("total").and_then(|v| v.as_u64()) {
println!(
" Total: {}",
crate::utils::format::format_parameters(total)
);
}
if let Some(params) = st.get("parameters").and_then(|v| v.as_object()) {
println!(" Parameters:");
for (dtype, count) in params {
if let Some(num) = count.as_u64() {
println!(
" {:<12} {}",
format!("{}:", dtype),
crate::utils::format::format_parameters(num)
);
}
}
}
} else {
println!(" Safetensors: N/A");
}
// Artifact section
println!(" Artifact:");
println!(" Provider: {}", model.provider);
println!(" Revision: {}", model.metadata.artifact.revision);
println!(
" Size: {}",
format_size_decimal(model.metadata.artifact.size)
);
println!(" Cache Path: {}", model.metadata.artifact.path);
println!("Status:");
println!(" Created: {}", format_time_ago(&model.created_at));
println!(" Updated: {}", format_time_ago(&model.updated_at));
}
Ok(None) => {
eprintln!("Model not found: {}", args.model);
std::process::exit(1);
}
match inspect::execute(&registry, &args.model) {
Ok(model) => inspect::display(&model),
Err(e) => {
eprintln!("Failed to load registry: {}", e);
eprintln!("{}", e);
std::process::exit(1);
}
}
Expand Down Expand Up @@ -286,7 +225,7 @@ mod tests {
uuid: revision.to_string(),
name: name.to_string(),
author: Some("test-author".to_string()),
r#type: Some("text-generation".to_string()),
task: Some("text-generation".to_string()),
model_series: Some("gpt2".to_string()),
provider: "huggingface".to_string(),
license: Some("mit".to_string()),
Expand All @@ -309,7 +248,7 @@ mod tests {
let temp_dir = TempDir::new().unwrap();
let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf()));

let models = registry.load_models().unwrap_or_default();
let models = registry.load_models(None).unwrap_or_default();
assert_eq!(models.len(), 0);
}

Expand All @@ -322,7 +261,7 @@ mod tests {

registry.register_model(model).unwrap();

let models = registry.load_models().unwrap();
let models = registry.load_models(None).unwrap();
assert_eq!(models.len(), 1);
assert_eq!(models[0].name, "test/model");
assert_eq!(models[0].provider, "huggingface");
Expand All @@ -335,7 +274,7 @@ mod tests {

let mut model = create_test_model("test/gpt-model", "abc123def456");
model.author = Some("test-org".to_string());
model.r#type = Some("text-generation".to_string());
model.task = Some("text-generation".to_string());
model.license = Some("mit".to_string());
model.updated_at = "2025-01-02T00:00:00Z".to_string();

Expand All @@ -349,7 +288,7 @@ mod tests {
assert_eq!(model_info.created_at, "2025-01-01T00:00:00Z");
assert_eq!(model_info.updated_at, "2025-01-02T00:00:00Z");
assert_eq!(model_info.author, Some("test-org".to_string()));
assert_eq!(model_info.r#type, Some("text-generation".to_string()));
assert_eq!(model_info.task, Some("text-generation".to_string()));
assert_eq!(model_info.license, Some("mit".to_string()));
assert_eq!(model_info.model_series, Some("gpt2".to_string()));
assert_eq!(model_info.metadata.context_window, Some(2048));
Expand Down Expand Up @@ -386,32 +325,6 @@ mod tests {
assert!(model_info.metadata.safetensors.is_none());
}

#[test]
fn test_rm_command() {
let temp_dir = TempDir::new().unwrap();
let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf()));

let model = create_test_model("test/remove-model", "abc123");

registry.register_model(model).unwrap();
assert!(registry.get_model("test/remove-model").unwrap().is_some());

// Simulate RM command
let result = registry.get_model("test/remove-model");
assert!(result.is_ok());
assert!(result.unwrap().is_some());
}

#[test]
fn test_rm_command_nonexistent() {
let temp_dir = TempDir::new().unwrap();
let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf()));

let result = registry.get_model("nonexistent/model");
assert!(result.is_ok());
assert!(result.unwrap().is_none());
}

#[test]
fn test_revision_truncation() {
let long_revision = "abc123def456ghi789jkl012";
Expand Down
Loading
Loading