From 2fb5bf54a913239cf6c4a280ae61cd90b2791be7 Mon Sep 17 00:00:00 2001 From: kerthcet Date: Sat, 25 Apr 2026 11:17:12 +0100 Subject: [PATCH 1/6] support label filtering Signed-off-by: kerthcet --- Cargo.lock | 1 + Cargo.toml | 1 + src/cli/commands.rs | 72 +++++++++++-- src/registry/model_registry.rs | 17 +-- src/storage/sqlite.rs | 188 ++++++++++++++++++++++++++++++--- src/storage/storage_trait.rs | 6 +- src/system/system_info.rs | 2 +- 7 files changed, 254 insertions(+), 33 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5970e74..97cc3c6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1430,6 +1430,7 @@ dependencies = [ "indicatif", "log", "prettytable-rs", + "regex", "reqwest", "rusqlite", "rusqlite_migration", diff --git a/Cargo.toml b/Cargo.toml index 18b7ad5..7a010d4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ serde_json = "1.0" sysinfo = "0.32" rusqlite = { version = "0.32", features = ["bundled"] } rusqlite_migration = "1.3" +regex = "1.11" [dev-dependencies] tempfile = "3.12" diff --git a/src/cli/commands.rs b/src/cli/commands.rs index 18ab8dc..e1458bc 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -21,7 +21,7 @@ enum Commands { /// List running models PS, /// List local models - LS, + LS(LsArgs), /// Download a model from a model provider PULL(PullArgs), /// Create and run a new model @@ -38,9 +38,19 @@ enum Commands { VERSION, } +#[derive(Parser)] +struct LsArgs { + /// Optional model name pattern to filter (e.g., qwen, openai/*) + pattern: Option, + + /// Advanced filter using SQL WHERE conditions (e.g., author=inftyai,license=mit) + #[arg(short = 'l', long, value_name = "KEY=VALUE,...")] + query: Option, +} + #[derive(Parser)] struct PullArgs { - /// Model name to download (e.g., InftyAI/tiny-random-gpt2) + /// Model name to download (e.g., inftyai/tiny-random-gpt2) model: String, #[arg( short = 'p', @@ -54,13 +64,13 @@ struct PullArgs { #[derive(Parser)] struct RmArgs { - /// Model name to remove (e.g., InftyAI/tiny-random-gpt2) + /// Model name to remove (e.g., inftyai/tiny-random-gpt2) model: String, } #[derive(Parser)] struct InspectArgs { - /// Model name to inspect (e.g., InftyAI/tiny-random-gpt2) + /// Model name to inspect (e.g., inftyai/tiny-random-gpt2) model: String, } @@ -94,9 +104,55 @@ pub async fn run(cli: Cli) { table.printstd(); } - Commands::LS => { + Commands::LS(args) => { let registry = ModelRegistry::new(None); - let models = registry.load_models().unwrap_or_default(); + + // Parse query filters if provided (e.g., "author=inftyai,license=mit") + let mut query_filters = std::collections::HashMap::new(); + if let Some(query_str) = &args.query { + for pair in query_str.split(',') { + if let Some((key, value)) = pair.split_once('=') { + query_filters.insert(key.trim().to_string(), value.trim().to_string()); + } else { + eprintln!("Invalid query format: {}. Expected key=value pairs separated by commas.", pair); + std::process::exit(1); + } + } + } + + // Load models with optional SQL filters + let filter_ref = if query_filters.is_empty() { + None + } else { + Some(&query_filters) + }; + + let mut models = registry.load_models(filter_ref).unwrap_or_else(|e| { + eprintln!("Failed to query models: {}", e); + std::process::exit(1); + }); + + // Filter models by name pattern if provided (case-insensitive) + if let Some(pattern) = &args.pattern { + let pattern_lower = pattern.to_lowercase(); + models.retain(|model| { + let name_lower = model.name.to_lowercase(); + if pattern_lower.ends_with("/*") { + // Prefix match: "InftyAI/*" matches "InftyAI/model1", "InftyAI/model2" + let prefix = &pattern_lower[..pattern_lower.len() - 2]; + name_lower.starts_with(prefix) + } else if pattern_lower.contains('*') { + // Wildcard match (simple glob) + let regex_pattern = pattern_lower.replace('*', ".*"); + regex::Regex::new(®ex_pattern) + .map(|re| re.is_match(&name_lower)) + .unwrap_or(false) + } else { + // Exact or substring match + name_lower.contains(&pattern_lower) + } + }); + } let mut table = Table::new(); table.set_format( @@ -309,7 +365,7 @@ mod tests { let temp_dir = TempDir::new().unwrap(); let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); - let models = registry.load_models().unwrap_or_default(); + let models = registry.load_models(None).unwrap_or_default(); assert_eq!(models.len(), 0); } @@ -322,7 +378,7 @@ mod tests { registry.register_model(model).unwrap(); - let models = registry.load_models().unwrap(); + let models = registry.load_models(None).unwrap(); assert_eq!(models.len(), 1); assert_eq!(models[0].name, "test/model"); assert_eq!(models[0].provider, "huggingface"); diff --git a/src/registry/model_registry.rs b/src/registry/model_registry.rs index 632408d..2272302 100644 --- a/src/registry/model_registry.rs +++ b/src/registry/model_registry.rs @@ -1,5 +1,6 @@ use colored::Colorize; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::fs; use std::path::PathBuf; @@ -56,8 +57,8 @@ impl ModelRegistry { } } - pub fn load_models(&self) -> Result, std::io::Error> { - self.storage.load_models() + pub fn load_models(&self, filters: Option<&HashMap>) -> Result, std::io::Error> { + self.storage.load_models(filters) } pub fn register_model(&self, model: ModelInfo) -> Result<(), std::io::Error> { @@ -143,7 +144,7 @@ mod tests { registry.register_model(model.clone()).unwrap(); - let models = registry.load_models().unwrap(); + let models = registry.load_models(None).unwrap(); assert_eq!(models.len(), 1); assert_eq!(models[0].name, "test/model"); } @@ -156,10 +157,10 @@ mod tests { let model = create_test_model("test/model", "abc123"); registry.register_model(model).unwrap(); - assert_eq!(registry.load_models().unwrap().len(), 1); + assert_eq!(registry.load_models(None).unwrap().len(), 1); registry.unregister_model("test/model").unwrap(); - assert_eq!(registry.load_models().unwrap().len(), 0); + assert_eq!(registry.load_models(None).unwrap().len(), 0); } #[test] @@ -205,7 +206,7 @@ mod tests { registry.register_model(model2).unwrap(); - let models = registry.load_models().unwrap(); + let models = registry.load_models(None).unwrap(); assert_eq!(models.len(), 1); assert_eq!(models[0].metadata.artifact.revision, "def456"); assert_eq!(models[0].metadata.artifact.size, 2000); @@ -229,14 +230,14 @@ mod tests { model.metadata.artifact.path = cache_dir.to_string_lossy().to_string(); registry.register_model(model).unwrap(); - assert_eq!(registry.load_models().unwrap().len(), 1); + assert_eq!(registry.load_models(None).unwrap().len(), 1); assert!(cache_dir.exists()); // Delete model registry.remove_model("test/model").unwrap(); // Verify model removed from registry - assert_eq!(registry.load_models().unwrap().len(), 0); + assert_eq!(registry.load_models(None).unwrap().len(), 0); // Verify cache directory deleted assert!(!cache_dir.exists()); diff --git a/src/storage/sqlite.rs b/src/storage/sqlite.rs index 8b77f79..fd11ec0 100644 --- a/src/storage/sqlite.rs +++ b/src/storage/sqlite.rs @@ -2,6 +2,7 @@ use crate::registry::model_registry::{ModelInfo, ModelMetadata}; use crate::storage::ModelStorage; use rusqlite::{params, Connection, Result as SqlResult}; use rusqlite_migration::{Migrations, M}; +use std::collections::HashMap; use std::io; use std::path::PathBuf; @@ -55,19 +56,55 @@ impl SqliteStorage { } impl ModelStorage for SqliteStorage { - fn load_models(&self) -> Result, io::Error> { + fn load_models( + &self, + filters: Option<&HashMap>, + ) -> Result, io::Error> { let conn = self.get_connection()?; - let mut stmt = conn - .prepare( + // Build WHERE clause from filters + let mut where_clauses = Vec::new(); + let mut params: Vec = Vec::new(); + + if let Some(filter_map) = filters { + // Allowed columns for filtering (prevent SQL injection) + let allowed_columns = ["author", "type", "model_series", "provider", "license"]; + + for (key, value) in filter_map { + if allowed_columns.contains(&key.as_str()) { + where_clauses.push(format!("{} = ?", key)); + params.push(value.clone()); + } else { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid filter column: {}", key), + )); + } + } + } + + let query = if where_clauses.is_empty() { + "SELECT uuid, name, author, type, model_series, provider, license, + metadata, created_at, updated_at + FROM models" + .to_string() + } else { + format!( "SELECT uuid, name, author, type, model_series, provider, license, metadata, created_at, updated_at - FROM models", + FROM models + WHERE {}", + where_clauses.join(" AND ") ) - .map_err(io::Error::other)?; + }; + + let mut stmt = conn.prepare(&query).map_err(io::Error::other)?; + + let param_refs: Vec<&dyn rusqlite::ToSql> = + params.iter().map(|p| p as &dyn rusqlite::ToSql).collect(); let models = stmt - .query_map([], |row| { + .query_map(param_refs.as_slice(), |row| { let metadata_json: String = row.get(7)?; let metadata: ModelMetadata = serde_json::from_str(&metadata_json) .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?; @@ -98,6 +135,10 @@ impl ModelStorage for SqliteStorage { let metadata_json = serde_json::to_string(&model.metadata) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + // Normalize name and author to lowercase + let name_lower = model.name.to_lowercase(); + let author_lower = model.author.as_ref().map(|a| a.to_lowercase()); + conn.execute( "INSERT INTO models (uuid, name, author, type, model_series, provider, license, @@ -114,8 +155,8 @@ impl ModelStorage for SqliteStorage { updated_at = excluded.updated_at", params![ &model.uuid, - &model.name, - model.author.as_deref(), + &name_lower, + author_lower.as_deref(), model.r#type.as_deref(), model.model_series.as_deref(), &model.provider, @@ -133,7 +174,10 @@ impl ModelStorage for SqliteStorage { fn unregister_model(&self, name: &str) -> Result<(), io::Error> { let conn = self.get_connection()?; - conn.execute("DELETE FROM models WHERE name = ?1", params![name]) + // Normalize name to lowercase for case-insensitive lookup + let name_lower = name.to_lowercase(); + + conn.execute("DELETE FROM models WHERE name = ?1", params![name_lower]) .map_err(io::Error::other)?; Ok(()) @@ -142,11 +186,14 @@ impl ModelStorage for SqliteStorage { fn get_model(&self, name: &str) -> Result, io::Error> { let conn = self.get_connection()?; + // Normalize name to lowercase for case-insensitive lookup + let name_lower = name.to_lowercase(); + let result = conn.query_row( "SELECT uuid, name, author, type, model_series, provider, license, metadata, created_at, updated_at FROM models WHERE name = ?1", - params![name], + params![name_lower], |row| { let metadata_json: String = row.get(7)?; let metadata: ModelMetadata = serde_json::from_str(&metadata_json) @@ -218,7 +265,7 @@ mod tests { let model = create_test_model("test/model", "uuid123"); storage.register_model(model.clone()).unwrap(); - let models = storage.load_models().unwrap(); + let models = storage.load_models(None).unwrap(); assert_eq!(models.len(), 1); assert_eq!(models[0].name, "test/model"); assert_eq!(models[0].uuid, "uuid123"); @@ -249,10 +296,10 @@ mod tests { let model = create_test_model("test/model", "uuid123"); storage.register_model(model).unwrap(); - assert_eq!(storage.load_models().unwrap().len(), 1); + assert_eq!(storage.load_models(None).unwrap().len(), 1); storage.unregister_model("test/model").unwrap(); - assert_eq!(storage.load_models().unwrap().len(), 0); + assert_eq!(storage.load_models(None).unwrap().len(), 0); } #[test] @@ -269,7 +316,7 @@ mod tests { model2.updated_at = "2025-01-02T00:00:00Z".to_string(); storage.register_model(model2).unwrap(); - let models = storage.load_models().unwrap(); + let models = storage.load_models(None).unwrap(); assert_eq!(models.len(), 1); // created_at should be preserved assert_eq!(models[0].created_at, "2025-01-01T00:00:00Z"); @@ -297,4 +344,117 @@ mod tests { let st = retrieved.metadata.safetensors.unwrap(); assert_eq!(st.get("total").unwrap().as_u64().unwrap(), 1000); } + + #[test] + fn test_load_models_with_single_filter() { + let temp_dir = TempDir::new().unwrap(); + let db_path = temp_dir.path().join("test.db"); + let storage = SqliteStorage::new(db_path).unwrap(); + + let mut model1 = create_test_model("test/model1", "uuid1"); + model1.author = Some("author1".to_string()); + storage.register_model(model1).unwrap(); + + let mut model2 = create_test_model("test/model2", "uuid2"); + model2.author = Some("author2".to_string()); + storage.register_model(model2).unwrap(); + + let mut filters = HashMap::new(); + filters.insert("author".to_string(), "author1".to_string()); + + let models = storage.load_models(Some(&filters)).unwrap(); + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "test/model1"); + } + + #[test] + fn test_load_models_with_multiple_filters() { + let temp_dir = TempDir::new().unwrap(); + let db_path = temp_dir.path().join("test.db"); + let storage = SqliteStorage::new(db_path).unwrap(); + + let mut model1 = create_test_model("test/model1", "uuid1"); + model1.author = Some("InftyAI".to_string()); + model1.license = Some("mit".to_string()); + storage.register_model(model1).unwrap(); + + let mut model2 = create_test_model("test/model2", "uuid2"); + model2.author = Some("InftyAI".to_string()); + model2.license = Some("apache-2.0".to_string()); + storage.register_model(model2).unwrap(); + + let mut model3 = create_test_model("test/model3", "uuid3"); + model3.author = Some("other-author".to_string()); + model3.license = Some("mit".to_string()); + storage.register_model(model3).unwrap(); + + let mut filters = HashMap::new(); + filters.insert("author".to_string(), "inftyai".to_string()); + filters.insert("license".to_string(), "mit".to_string()); + + let models = storage.load_models(Some(&filters)).unwrap(); + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "test/model1"); + } + + #[test] + fn test_load_models_with_invalid_filter_column() { + let temp_dir = TempDir::new().unwrap(); + let db_path = temp_dir.path().join("test.db"); + let storage = SqliteStorage::new(db_path).unwrap(); + + let mut filters = HashMap::new(); + filters.insert("invalid_column".to_string(), "value".to_string()); + + let result = storage.load_models(Some(&filters)); + assert!(result.is_err()); + assert_eq!(result.unwrap_err().kind(), std::io::ErrorKind::InvalidInput); + } + + #[test] + fn test_name_and_author_stored_lowercase() { + let temp_dir = TempDir::new().unwrap(); + let db_path = temp_dir.path().join("test.db"); + let storage = SqliteStorage::new(db_path).unwrap(); + + let mut model = create_test_model("InftyAI/TestModel", "uuid123"); + model.author = Some("InftyAI".to_string()); + storage.register_model(model).unwrap(); + + // Query with original case should work + let retrieved = storage.get_model("InftyAI/TestModel").unwrap(); + assert!(retrieved.is_some()); + let model_info = retrieved.unwrap(); + // Verify stored as lowercase + assert_eq!(model_info.name, "inftyai/testmodel"); + assert_eq!(model_info.author, Some("inftyai".to_string())); + + // Query with different case should also work + let retrieved2 = storage.get_model("inftyai/testmodel").unwrap(); + assert!(retrieved2.is_some()); + + let retrieved3 = storage.get_model("INFTYAI/TESTMODEL").unwrap(); + assert!(retrieved3.is_some()); + } + + #[test] + fn test_author_filter_case_sensitive() { + let temp_dir = TempDir::new().unwrap(); + let db_path = temp_dir.path().join("test.db"); + let storage = SqliteStorage::new(db_path).unwrap(); + + let mut model = create_test_model("test/model", "uuid123"); + model.author = Some("InftyAI".to_string()); + storage.register_model(model).unwrap(); + + // Filter must use lowercase since data is stored in lowercase + let mut filters = HashMap::new(); + filters.insert("author".to_string(), "inftyai".to_string()); + assert_eq!(storage.load_models(Some(&filters)).unwrap().len(), 1); + + // Non-lowercase filter won't match + filters.clear(); + filters.insert("author".to_string(), "InftyAI".to_string()); + assert_eq!(storage.load_models(Some(&filters)).unwrap().len(), 0); + } } diff --git a/src/storage/storage_trait.rs b/src/storage/storage_trait.rs index 726cae8..bc39b0a 100644 --- a/src/storage/storage_trait.rs +++ b/src/storage/storage_trait.rs @@ -1,10 +1,12 @@ use crate::registry::model_registry::ModelInfo; use std::io; +use std::collections::HashMap; + /// Trait for model storage backends pub trait ModelStorage { - /// Load all models from storage - fn load_models(&self) -> Result, io::Error>; + /// Load models from storage with optional filtering by column values (e.g., author=InftyAI, license=mit) + fn load_models(&self, filters: Option<&HashMap>) -> Result, io::Error>; /// Register (insert or update) a single model fn register_model(&self, model: ModelInfo) -> Result<(), io::Error>; diff --git a/src/system/system_info.rs b/src/system/system_info.rs index 00b49cd..ff62ebc 100644 --- a/src/system/system_info.rs +++ b/src/system/system_info.rs @@ -39,7 +39,7 @@ impl SystemInfo { let cache_size = Self::calculate_cache_size(&cache_dir); let registry = ModelRegistry::new(None); - let models_count = registry.load_models().unwrap_or_default().len(); + let models_count = registry.load_models(None).unwrap_or_default().len(); let gpu_info = Self::detect_gpus(); From 49dbbf9e8c0f1f924d63072409c971559606adf2 Mon Sep 17 00:00:00 2001 From: kerthcet Date: Sat, 25 Apr 2026 11:31:24 +0100 Subject: [PATCH 2/6] add tests Signed-off-by: kerthcet --- src/cli/commands.rs | 179 ++++++++++++++++++++++++++++++++++++++------ 1 file changed, 155 insertions(+), 24 deletions(-) diff --git a/src/cli/commands.rs b/src/cli/commands.rs index e1458bc..6e60e48 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -132,26 +132,19 @@ pub async fn run(cli: Cli) { std::process::exit(1); }); - // Filter models by name pattern if provided (case-insensitive) + // Filter models by name pattern if provided (supports regex) + // Note: model names are already stored in lowercase in the database if let Some(pattern) = &args.pattern { let pattern_lower = pattern.to_lowercase(); - models.retain(|model| { - let name_lower = model.name.to_lowercase(); - if pattern_lower.ends_with("/*") { - // Prefix match: "InftyAI/*" matches "InftyAI/model1", "InftyAI/model2" - let prefix = &pattern_lower[..pattern_lower.len() - 2]; - name_lower.starts_with(prefix) - } else if pattern_lower.contains('*') { - // Wildcard match (simple glob) - let regex_pattern = pattern_lower.replace('*', ".*"); - regex::Regex::new(®ex_pattern) - .map(|re| re.is_match(&name_lower)) - .unwrap_or(false) - } else { - // Exact or substring match - name_lower.contains(&pattern_lower) + match regex::Regex::new(&pattern_lower) { + Ok(re) => { + models.retain(|model| re.is_match(&model.name)); + } + Err(e) => { + eprintln!("Invalid regex pattern '{}': {}", pattern, e); + std::process::exit(1); } - }); + } } let mut table = Table::new(); @@ -447,15 +440,25 @@ mod tests { let temp_dir = TempDir::new().unwrap(); let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); - let model = create_test_model("test/remove-model", "abc123"); + // Create a fake cache directory + let cache_dir = temp_dir.path().join("cache"); + std::fs::create_dir_all(&cache_dir).unwrap(); + std::fs::write(cache_dir.join("model.safetensors"), "fake data").unwrap(); + + let mut model = create_test_model("test/remove-model", "abc123"); + model.metadata.artifact.path = cache_dir.to_string_lossy().to_string(); registry.register_model(model).unwrap(); assert!(registry.get_model("test/remove-model").unwrap().is_some()); + assert!(cache_dir.exists()); - // Simulate RM command - let result = registry.get_model("test/remove-model"); - assert!(result.is_ok()); - assert!(result.unwrap().is_some()); + // Simulate RM command - actually remove the model + registry.remove_model("test/remove-model").unwrap(); + + // Verify model is removed from registry + assert!(registry.get_model("test/remove-model").unwrap().is_none()); + // Verify cache directory is deleted + assert!(!cache_dir.exists()); } #[test] @@ -463,9 +466,9 @@ mod tests { let temp_dir = TempDir::new().unwrap(); let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); - let result = registry.get_model("nonexistent/model"); + // Simulate RM command on non-existent model - should not error + let result = registry.remove_model("nonexistent/model"); assert!(result.is_ok()); - assert!(result.unwrap().is_none()); } #[test] @@ -512,4 +515,132 @@ mod tests { assert_eq!(result.metadata.artifact.revision, "v2"); assert_eq!(result.metadata.artifact.size, 2000); } + + #[test] + fn test_ls_with_pattern_substring() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + registry.register_model(create_test_model("inftyai/model1", "uuid1")).unwrap(); + registry.register_model(create_test_model("openai/gpt2", "uuid2")).unwrap(); + registry.register_model(create_test_model("inftyai/model2", "uuid3")).unwrap(); + + // Simulate: puma ls inftyai + let mut models = registry.load_models(None).unwrap(); + let pattern = "inftyai".to_lowercase(); + let re = regex::Regex::new(&pattern).unwrap(); + models.retain(|model| re.is_match(&model.name)); + + assert_eq!(models.len(), 2); + assert!(models.iter().all(|m| m.name.contains("inftyai"))); + } + + #[test] + fn test_ls_with_pattern_prefix() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + registry.register_model(create_test_model("inftyai/model1", "uuid1")).unwrap(); + registry.register_model(create_test_model("openai/gpt2", "uuid2")).unwrap(); + registry.register_model(create_test_model("meta/llama", "uuid3")).unwrap(); + + // Simulate: puma ls "^inftyai/" + let mut models = registry.load_models(None).unwrap(); + let pattern = "^inftyai/".to_lowercase(); + let re = regex::Regex::new(&pattern).unwrap(); + models.retain(|model| re.is_match(&model.name)); + + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "inftyai/model1"); + } + + #[test] + fn test_ls_with_pattern_case_insensitive() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + registry.register_model(create_test_model("InftyAI/Model1", "uuid1")).unwrap(); + registry.register_model(create_test_model("OpenAI/GPT2", "uuid2")).unwrap(); + + // Simulate: puma ls InftyAI (user input with mixed case) + let mut models = registry.load_models(None).unwrap(); + let pattern = "InftyAI".to_lowercase(); + let re = regex::Regex::new(&pattern).unwrap(); + models.retain(|model| re.is_match(&model.name)); + + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "inftyai/model1"); + } + + #[test] + fn test_ls_with_pattern_alternation() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + registry.register_model(create_test_model("meta/llama-2", "uuid1")).unwrap(); + registry.register_model(create_test_model("meta/llama-3", "uuid2")).unwrap(); + registry.register_model(create_test_model("meta/llama-4", "uuid3")).unwrap(); + + // Simulate: puma ls "llama-(2|3)" + let mut models = registry.load_models(None).unwrap(); + let pattern = "llama-(2|3)".to_lowercase(); + let re = regex::Regex::new(&pattern).unwrap(); + models.retain(|model| re.is_match(&model.name)); + + assert_eq!(models.len(), 2); + assert!(models.iter().any(|m| m.name == "meta/llama-2")); + assert!(models.iter().any(|m| m.name == "meta/llama-3")); + } + + #[test] + fn test_ls_with_sql_filter() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let mut model1 = create_test_model("inftyai/model1", "uuid1"); + model1.author = Some("inftyai".to_string()); + registry.register_model(model1).unwrap(); + + let mut model2 = create_test_model("openai/gpt2", "uuid2"); + model2.author = Some("openai".to_string()); + registry.register_model(model2).unwrap(); + + // Simulate: puma ls -l author=inftyai + let mut filters = std::collections::HashMap::new(); + filters.insert("author".to_string(), "inftyai".to_string()); + let models = registry.load_models(Some(&filters)).unwrap(); + + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "inftyai/model1"); + } + + #[test] + fn test_ls_with_pattern_and_sql_filter() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let mut model1 = create_test_model("inftyai/llama-2", "uuid1"); + model1.author = Some("inftyai".to_string()); + registry.register_model(model1).unwrap(); + + let mut model2 = create_test_model("inftyai/gpt2", "uuid2"); + model2.author = Some("inftyai".to_string()); + registry.register_model(model2).unwrap(); + + let mut model3 = create_test_model("openai/llama-2", "uuid3"); + model3.author = Some("openai".to_string()); + registry.register_model(model3).unwrap(); + + // Simulate: puma ls llama -l author=inftyai + let mut filters = std::collections::HashMap::new(); + filters.insert("author".to_string(), "inftyai".to_string()); + let mut models = registry.load_models(Some(&filters)).unwrap(); + + let pattern = "llama".to_lowercase(); + let re = regex::Regex::new(&pattern).unwrap(); + models.retain(|model| re.is_match(&model.name)); + + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "inftyai/llama-2"); + } } From 60ece644c62241dcbf575b1a743e866fc44f505e Mon Sep 17 00:00:00 2001 From: kerthcet Date: Sat, 25 Apr 2026 12:03:06 +0100 Subject: [PATCH 3/6] better organize the structure Signed-off-by: kerthcet --- src/cli/commands.rs | 303 ++------------------------------------ src/cli/inspect.rs | 160 ++++++++++++++++++++ src/cli/ls.rs | 173 ++++++++++++++++++++++ src/cli/mod.rs | 3 + src/cli/rm.rs | 80 ++++++++++ src/lib.rs | 3 +- tests/integration_test.rs | 281 +++++++++++++++++++++++++++++++++++ 7 files changed, 711 insertions(+), 292 deletions(-) create mode 100644 src/cli/inspect.rs create mode 100644 src/cli/ls.rs create mode 100644 src/cli/rm.rs create mode 100644 tests/integration_test.rs diff --git a/src/cli/commands.rs b/src/cli/commands.rs index 6e60e48..1c8f909 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -1,6 +1,7 @@ use clap::{Parser, Subcommand}; use prettytable::{format, row, Table}; +use crate::cli::{inspect, ls, rm}; use crate::downloader::downloader::Downloader; use crate::downloader::huggingface::HuggingFaceDownloader; use crate::registry::model_registry::ModelRegistry; @@ -107,46 +108,14 @@ pub async fn run(cli: Cli) { Commands::LS(args) => { let registry = ModelRegistry::new(None); - // Parse query filters if provided (e.g., "author=inftyai,license=mit") - let mut query_filters = std::collections::HashMap::new(); - if let Some(query_str) = &args.query { - for pair in query_str.split(',') { - if let Some((key, value)) = pair.split_once('=') { - query_filters.insert(key.trim().to_string(), value.trim().to_string()); - } else { - eprintln!("Invalid query format: {}. Expected key=value pairs separated by commas.", pair); - std::process::exit(1); - } + let models = match ls::execute(®istry, args.pattern.as_deref(), args.query.as_deref()) { + Ok(models) => models, + Err(e) => { + eprintln!("{}", e); + std::process::exit(1); } - } - - // Load models with optional SQL filters - let filter_ref = if query_filters.is_empty() { - None - } else { - Some(&query_filters) }; - let mut models = registry.load_models(filter_ref).unwrap_or_else(|e| { - eprintln!("Failed to query models: {}", e); - std::process::exit(1); - }); - - // Filter models by name pattern if provided (supports regex) - // Note: model names are already stored in lowercase in the database - if let Some(pattern) = &args.pattern { - let pattern_lower = pattern.to_lowercase(); - match regex::Regex::new(&pattern_lower) { - Ok(re) => { - models.retain(|model| re.is_match(&model.name)); - } - Err(e) => { - eprintln!("Invalid regex pattern '{}': {}", pattern, e); - std::process::exit(1); - } - } - } - let mut table = Table::new(); table.set_format( format::FormatBuilder::new() @@ -202,23 +171,9 @@ pub async fn run(cli: Cli) { Commands::RM(args) => { let registry = ModelRegistry::new(None); - // Check if model exists first - match registry.get_model(&args.model) { - Ok(Some(_)) => { - // Delete model (cache + registry) - if let Err(e) = registry.remove_model(&args.model) { - eprintln!("Failed to remove model: {}", e); - std::process::exit(1); - } - } - Ok(None) => { - eprintln!("Model not found: {}", args.model); - std::process::exit(1); - } - Err(e) => { - eprintln!("Failed to load registry: {}", e); - std::process::exit(1); - } + if let Err(e) = rm::execute(®istry, &args.model) { + eprintln!("{}", e); + std::process::exit(1); } } @@ -230,81 +185,10 @@ pub async fn run(cli: Cli) { Commands::INSPECT(args) => { let registry = ModelRegistry::new(None); - match registry.get_model(&args.model) { - Ok(Some(model)) => { - println!("Name: {}", model.name); - println!("Kind: Model"); - println!("Spec:"); - println!( - " Author: {}", - model.author.as_deref().unwrap_or("N/A") - ); - println!( - " Type: {}", - model.r#type.as_deref().unwrap_or("N/A") - ); - println!( - " License: {}", - model - .license - .as_ref() - .map(|s| s.to_uppercase()) - .unwrap_or_else(|| "N/A".to_string()) - ); - println!( - " Model Series: {}", - model.model_series.as_deref().unwrap_or("N/A") - ); - println!( - " Context Window: {}", - model - .metadata - .context_window - .map(|w| crate::utils::format::format_parameters(w as u64)) - .unwrap_or_else(|| "N/A".to_string()) - ); - if let Some(st) = &model.metadata.safetensors { - println!(" Safetensors:"); - if let Some(total) = st.get("total").and_then(|v| v.as_u64()) { - println!( - " Total: {}", - crate::utils::format::format_parameters(total) - ); - } - if let Some(params) = st.get("parameters").and_then(|v| v.as_object()) { - println!(" Parameters:"); - for (dtype, count) in params { - if let Some(num) = count.as_u64() { - println!( - " {:<12} {}", - format!("{}:", dtype), - crate::utils::format::format_parameters(num) - ); - } - } - } - } else { - println!(" Safetensors: N/A"); - } - // Artifact section - println!(" Artifact:"); - println!(" Provider: {}", model.provider); - println!(" Revision: {}", model.metadata.artifact.revision); - println!( - " Size: {}", - format_size_decimal(model.metadata.artifact.size) - ); - println!(" Cache Path: {}", model.metadata.artifact.path); - println!("Status:"); - println!(" Created: {}", format_time_ago(&model.created_at)); - println!(" Updated: {}", format_time_ago(&model.updated_at)); - } - Ok(None) => { - eprintln!("Model not found: {}", args.model); - std::process::exit(1); - } + match inspect::execute(®istry, &args.model) { + Ok(model) => inspect::display(&model), Err(e) => { - eprintln!("Failed to load registry: {}", e); + eprintln!("{}", e); std::process::exit(1); } } @@ -435,42 +319,6 @@ mod tests { assert!(model_info.metadata.safetensors.is_none()); } - #[test] - fn test_rm_command() { - let temp_dir = TempDir::new().unwrap(); - let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); - - // Create a fake cache directory - let cache_dir = temp_dir.path().join("cache"); - std::fs::create_dir_all(&cache_dir).unwrap(); - std::fs::write(cache_dir.join("model.safetensors"), "fake data").unwrap(); - - let mut model = create_test_model("test/remove-model", "abc123"); - model.metadata.artifact.path = cache_dir.to_string_lossy().to_string(); - - registry.register_model(model).unwrap(); - assert!(registry.get_model("test/remove-model").unwrap().is_some()); - assert!(cache_dir.exists()); - - // Simulate RM command - actually remove the model - registry.remove_model("test/remove-model").unwrap(); - - // Verify model is removed from registry - assert!(registry.get_model("test/remove-model").unwrap().is_none()); - // Verify cache directory is deleted - assert!(!cache_dir.exists()); - } - - #[test] - fn test_rm_command_nonexistent() { - let temp_dir = TempDir::new().unwrap(); - let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); - - // Simulate RM command on non-existent model - should not error - let result = registry.remove_model("nonexistent/model"); - assert!(result.is_ok()); - } - #[test] fn test_revision_truncation() { let long_revision = "abc123def456ghi789jkl012"; @@ -516,131 +364,4 @@ mod tests { assert_eq!(result.metadata.artifact.size, 2000); } - #[test] - fn test_ls_with_pattern_substring() { - let temp_dir = TempDir::new().unwrap(); - let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); - - registry.register_model(create_test_model("inftyai/model1", "uuid1")).unwrap(); - registry.register_model(create_test_model("openai/gpt2", "uuid2")).unwrap(); - registry.register_model(create_test_model("inftyai/model2", "uuid3")).unwrap(); - - // Simulate: puma ls inftyai - let mut models = registry.load_models(None).unwrap(); - let pattern = "inftyai".to_lowercase(); - let re = regex::Regex::new(&pattern).unwrap(); - models.retain(|model| re.is_match(&model.name)); - - assert_eq!(models.len(), 2); - assert!(models.iter().all(|m| m.name.contains("inftyai"))); - } - - #[test] - fn test_ls_with_pattern_prefix() { - let temp_dir = TempDir::new().unwrap(); - let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); - - registry.register_model(create_test_model("inftyai/model1", "uuid1")).unwrap(); - registry.register_model(create_test_model("openai/gpt2", "uuid2")).unwrap(); - registry.register_model(create_test_model("meta/llama", "uuid3")).unwrap(); - - // Simulate: puma ls "^inftyai/" - let mut models = registry.load_models(None).unwrap(); - let pattern = "^inftyai/".to_lowercase(); - let re = regex::Regex::new(&pattern).unwrap(); - models.retain(|model| re.is_match(&model.name)); - - assert_eq!(models.len(), 1); - assert_eq!(models[0].name, "inftyai/model1"); - } - - #[test] - fn test_ls_with_pattern_case_insensitive() { - let temp_dir = TempDir::new().unwrap(); - let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); - - registry.register_model(create_test_model("InftyAI/Model1", "uuid1")).unwrap(); - registry.register_model(create_test_model("OpenAI/GPT2", "uuid2")).unwrap(); - - // Simulate: puma ls InftyAI (user input with mixed case) - let mut models = registry.load_models(None).unwrap(); - let pattern = "InftyAI".to_lowercase(); - let re = regex::Regex::new(&pattern).unwrap(); - models.retain(|model| re.is_match(&model.name)); - - assert_eq!(models.len(), 1); - assert_eq!(models[0].name, "inftyai/model1"); - } - - #[test] - fn test_ls_with_pattern_alternation() { - let temp_dir = TempDir::new().unwrap(); - let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); - - registry.register_model(create_test_model("meta/llama-2", "uuid1")).unwrap(); - registry.register_model(create_test_model("meta/llama-3", "uuid2")).unwrap(); - registry.register_model(create_test_model("meta/llama-4", "uuid3")).unwrap(); - - // Simulate: puma ls "llama-(2|3)" - let mut models = registry.load_models(None).unwrap(); - let pattern = "llama-(2|3)".to_lowercase(); - let re = regex::Regex::new(&pattern).unwrap(); - models.retain(|model| re.is_match(&model.name)); - - assert_eq!(models.len(), 2); - assert!(models.iter().any(|m| m.name == "meta/llama-2")); - assert!(models.iter().any(|m| m.name == "meta/llama-3")); - } - - #[test] - fn test_ls_with_sql_filter() { - let temp_dir = TempDir::new().unwrap(); - let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); - - let mut model1 = create_test_model("inftyai/model1", "uuid1"); - model1.author = Some("inftyai".to_string()); - registry.register_model(model1).unwrap(); - - let mut model2 = create_test_model("openai/gpt2", "uuid2"); - model2.author = Some("openai".to_string()); - registry.register_model(model2).unwrap(); - - // Simulate: puma ls -l author=inftyai - let mut filters = std::collections::HashMap::new(); - filters.insert("author".to_string(), "inftyai".to_string()); - let models = registry.load_models(Some(&filters)).unwrap(); - - assert_eq!(models.len(), 1); - assert_eq!(models[0].name, "inftyai/model1"); - } - - #[test] - fn test_ls_with_pattern_and_sql_filter() { - let temp_dir = TempDir::new().unwrap(); - let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); - - let mut model1 = create_test_model("inftyai/llama-2", "uuid1"); - model1.author = Some("inftyai".to_string()); - registry.register_model(model1).unwrap(); - - let mut model2 = create_test_model("inftyai/gpt2", "uuid2"); - model2.author = Some("inftyai".to_string()); - registry.register_model(model2).unwrap(); - - let mut model3 = create_test_model("openai/llama-2", "uuid3"); - model3.author = Some("openai".to_string()); - registry.register_model(model3).unwrap(); - - // Simulate: puma ls llama -l author=inftyai - let mut filters = std::collections::HashMap::new(); - filters.insert("author".to_string(), "inftyai".to_string()); - let mut models = registry.load_models(Some(&filters)).unwrap(); - - let pattern = "llama".to_lowercase(); - let re = regex::Regex::new(&pattern).unwrap(); - models.retain(|model| re.is_match(&model.name)); - - assert_eq!(models.len(), 1); - assert_eq!(models[0].name, "inftyai/llama-2"); - } } diff --git a/src/cli/inspect.rs b/src/cli/inspect.rs new file mode 100644 index 0000000..fd2593d --- /dev/null +++ b/src/cli/inspect.rs @@ -0,0 +1,160 @@ +use crate::registry::model_registry::{ModelInfo, ModelRegistry}; +use crate::utils::format::{format_parameters, format_size_decimal, format_time_ago}; + +/// Execute the INSPECT command logic +pub fn execute(registry: &ModelRegistry, model_name: &str) -> Result { + match registry.get_model(model_name) { + Ok(Some(model)) => Ok(model), + Ok(None) => Err(format!("Model not found: {}", model_name)), + Err(e) => Err(format!("Failed to load registry: {}", e)), + } +} + +/// Display the model information +pub fn display(model: &ModelInfo) { + println!("Name: {}", model.name); + println!("Kind: Model"); + println!("Spec:"); + println!( + " Author: {}", + model.author.as_deref().unwrap_or("N/A") + ); + println!( + " Type: {}", + model.r#type.as_deref().unwrap_or("N/A") + ); + println!( + " License: {}", + model + .license + .as_ref() + .map(|s| s.to_uppercase()) + .unwrap_or_else(|| "N/A".to_string()) + ); + println!( + " Model Series: {}", + model.model_series.as_deref().unwrap_or("N/A") + ); + println!( + " Context Window: {}", + model + .metadata + .context_window + .map(|w| format_parameters(w as u64)) + .unwrap_or_else(|| "N/A".to_string()) + ); + + if let Some(st) = &model.metadata.safetensors { + println!(" Safetensors:"); + if let Some(total) = st.get("total").and_then(|v| v.as_u64()) { + println!(" Total: {}", format_parameters(total)); + } + if let Some(params) = st.get("parameters").and_then(|v| v.as_object()) { + println!(" Parameters:"); + for (dtype, count) in params { + if let Some(num) = count.as_u64() { + println!( + " {:<12} {}", + format!("{}:", dtype), + format_parameters(num) + ); + } + } + } + } else { + println!(" Safetensors: N/A"); + } + + // Artifact section + println!(" Artifact:"); + println!(" Provider: {}", model.provider); + println!(" Revision: {}", model.metadata.artifact.revision); + println!( + " Size: {}", + format_size_decimal(model.metadata.artifact.size) + ); + println!(" Cache Path: {}", model.metadata.artifact.path); + println!("Status:"); + println!(" Created: {}", format_time_ago(&model.created_at)); + println!(" Updated: {}", format_time_ago(&model.updated_at)); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::registry::model_registry::{ArtifactInfo, ModelInfo, ModelMetadata}; + use tempfile::TempDir; + + fn create_test_model(name: &str, uuid: &str) -> ModelInfo { + let safetensors = serde_json::json!({ + "parameters": {"F32": 7000000000u64}, + "total": 7000000000u64 + }); + + ModelInfo { + uuid: uuid.to_string(), + name: name.to_string(), + author: Some("test-author".to_string()), + r#type: Some("text-generation".to_string()), + model_series: Some("gpt2".to_string()), + provider: "huggingface".to_string(), + license: Some("mit".to_string()), + created_at: "2025-01-01T00:00:00Z".to_string(), + updated_at: "2025-01-01T00:00:00Z".to_string(), + metadata: ModelMetadata { + artifact: ArtifactInfo { + revision: uuid.to_string(), + size: 1000, + path: "/tmp/test".to_string(), + }, + context_window: Some(2048), + safetensors: Some(safetensors), + }, + } + } + + #[test] + fn test_execute_inspect() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model = create_test_model("inftyai/test-model", "abc123"); + registry.register_model(model).unwrap(); + + let result = execute(®istry, "inftyai/test-model"); + assert!(result.is_ok()); + + let model_info = result.unwrap(); + assert_eq!(model_info.name, "inftyai/test-model"); + assert_eq!(model_info.provider, "huggingface"); + } + + #[test] + fn test_execute_inspect_nonexistent() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let result = execute(®istry, "nonexistent/model"); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Model not found")); + } + + #[test] + fn test_execute_inspect_case_insensitive() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let model = create_test_model("InftyAI/TestModel", "abc123"); + registry.register_model(model).unwrap(); + + // Can query with different cases + let result = execute(®istry, "InftyAI/TestModel"); + assert!(result.is_ok()); + + let result = execute(®istry, "inftyai/testmodel"); + assert!(result.is_ok()); + + let result = execute(®istry, "INFTYAI/TESTMODEL"); + assert!(result.is_ok()); + } +} diff --git a/src/cli/ls.rs b/src/cli/ls.rs new file mode 100644 index 0000000..7d40034 --- /dev/null +++ b/src/cli/ls.rs @@ -0,0 +1,173 @@ +use crate::registry::model_registry::{ModelInfo, ModelRegistry}; +use std::collections::HashMap; + +/// Execute the LS command logic +pub fn execute( + registry: &ModelRegistry, + pattern: Option<&str>, + query: Option<&str>, +) -> Result, String> { + // Parse query filters if provided + let mut query_filters = HashMap::new(); + if let Some(query_str) = query { + for pair in query_str.split(',') { + if let Some((key, value)) = pair.split_once('=') { + query_filters.insert(key.trim().to_string(), value.trim().to_string()); + } else { + return Err(format!( + "Invalid query format: {}. Expected key=value pairs separated by commas.", + pair + )); + } + } + } + + // Load models with optional SQL filters + let filter_ref = if query_filters.is_empty() { + None + } else { + Some(&query_filters) + }; + + let mut models = registry + .load_models(filter_ref) + .map_err(|e| format!("Failed to query models: {}", e))?; + + // Filter models by name pattern if provided (supports regex) + if let Some(pattern_str) = pattern { + let pattern_lower = pattern_str.to_lowercase(); + match regex::Regex::new(&pattern_lower) { + Ok(re) => { + models.retain(|model| re.is_match(&model.name)); + } + Err(e) => { + return Err(format!("Invalid regex pattern '{}': {}", pattern_str, e)); + } + } + } + + Ok(models) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::registry::model_registry::{ArtifactInfo, ModelInfo, ModelMetadata}; + use tempfile::TempDir; + + fn create_test_model(name: &str, uuid: &str, author: &str) -> ModelInfo { + let safetensors = serde_json::json!({ + "parameters": {"F32": 7000000000u64}, + "total": 7000000000u64 + }); + + ModelInfo { + uuid: uuid.to_string(), + name: name.to_string(), + author: Some(author.to_string()), + r#type: Some("text-generation".to_string()), + model_series: Some("gpt2".to_string()), + provider: "huggingface".to_string(), + license: Some("mit".to_string()), + created_at: "2025-01-01T00:00:00Z".to_string(), + updated_at: "2025-01-01T00:00:00Z".to_string(), + metadata: ModelMetadata { + artifact: ArtifactInfo { + revision: uuid.to_string(), + size: 1000, + path: "/tmp/test".to_string(), + }, + context_window: Some(2048), + safetensors: Some(safetensors), + }, + } + } + + #[test] + fn test_execute_ls_substring() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + registry + .register_model(create_test_model("inftyai/model1", "uuid1", "inftyai")) + .unwrap(); + registry + .register_model(create_test_model("openai/gpt2", "uuid2", "openai")) + .unwrap(); + registry + .register_model(create_test_model("inftyai/model2", "uuid3", "inftyai")) + .unwrap(); + + let models = execute(®istry, Some("inftyai"), None).unwrap(); + assert_eq!(models.len(), 2); + assert!(models.iter().all(|m| m.name.contains("inftyai"))); + } + + #[test] + fn test_execute_ls_prefix() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + registry + .register_model(create_test_model("inftyai/model1", "uuid1", "inftyai")) + .unwrap(); + registry + .register_model(create_test_model("openai/gpt2", "uuid2", "openai")) + .unwrap(); + + let models = execute(®istry, Some("^inftyai/"), None).unwrap(); + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "inftyai/model1"); + } + + #[test] + fn test_execute_ls_case_insensitive() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + registry + .register_model(create_test_model("InftyAI/Model1", "uuid1", "InftyAI")) + .unwrap(); + + let models = execute(®istry, Some("InftyAI"), None).unwrap(); + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "inftyai/model1"); + } + + #[test] + fn test_execute_ls_sql_filter() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + registry + .register_model(create_test_model("inftyai/model1", "uuid1", "inftyai")) + .unwrap(); + registry + .register_model(create_test_model("openai/gpt2", "uuid2", "openai")) + .unwrap(); + + let models = execute(®istry, None, Some("author=inftyai")).unwrap(); + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "inftyai/model1"); + } + + #[test] + fn test_execute_ls_pattern_and_filter() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + registry + .register_model(create_test_model("inftyai/llama-2", "uuid1", "inftyai")) + .unwrap(); + registry + .register_model(create_test_model("inftyai/gpt2", "uuid2", "inftyai")) + .unwrap(); + registry + .register_model(create_test_model("openai/llama-2", "uuid3", "openai")) + .unwrap(); + + let models = execute(®istry, Some("llama"), Some("author=inftyai")).unwrap(); + assert_eq!(models.len(), 1); + assert_eq!(models[0].name, "inftyai/llama-2"); + } +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 82b6da3..65f2452 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -1 +1,4 @@ pub mod commands; +pub mod inspect; +pub mod ls; +pub mod rm; diff --git a/src/cli/rm.rs b/src/cli/rm.rs new file mode 100644 index 0000000..e4e44d7 --- /dev/null +++ b/src/cli/rm.rs @@ -0,0 +1,80 @@ +use crate::registry::model_registry::ModelRegistry; + +/// Execute the RM command logic +pub fn execute(registry: &ModelRegistry, model_name: &str) -> Result<(), String> { + match registry.get_model(model_name) { + Ok(Some(_)) => registry + .remove_model(model_name) + .map_err(|e| format!("Failed to remove model: {}", e)), + Ok(None) => Err(format!("Model not found: {}", model_name)), + Err(e) => Err(format!("Failed to load registry: {}", e)), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::registry::model_registry::{ArtifactInfo, ModelInfo, ModelMetadata}; + use tempfile::TempDir; + + fn create_test_model(name: &str, uuid: &str) -> ModelInfo { + let safetensors = serde_json::json!({ + "parameters": {"F32": 7000000000u64}, + "total": 7000000000u64 + }); + + ModelInfo { + uuid: uuid.to_string(), + name: name.to_string(), + author: Some("test-author".to_string()), + r#type: Some("text-generation".to_string()), + model_series: Some("gpt2".to_string()), + provider: "huggingface".to_string(), + license: Some("mit".to_string()), + created_at: "2025-01-01T00:00:00Z".to_string(), + updated_at: "2025-01-01T00:00:00Z".to_string(), + metadata: ModelMetadata { + artifact: ArtifactInfo { + revision: uuid.to_string(), + size: 1000, + path: "/tmp/test".to_string(), + }, + context_window: Some(2048), + safetensors: Some(safetensors), + }, + } + } + + #[test] + fn test_execute_rm() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let cache_dir = temp_dir.path().join("cache"); + std::fs::create_dir_all(&cache_dir).unwrap(); + std::fs::write(cache_dir.join("model.safetensors"), "fake data").unwrap(); + + let mut model = create_test_model("test/remove-model", "abc123"); + model.metadata.artifact.path = cache_dir.to_string_lossy().to_string(); + + registry.register_model(model).unwrap(); + assert!(registry.get_model("test/remove-model").unwrap().is_some()); + assert!(cache_dir.exists()); + + let result = execute(®istry, "test/remove-model"); + assert!(result.is_ok()); + + assert!(registry.get_model("test/remove-model").unwrap().is_none()); + assert!(!cache_dir.exists()); + } + + #[test] + fn test_execute_rm_nonexistent() { + let temp_dir = TempDir::new().unwrap(); + let registry = ModelRegistry::new(Some(temp_dir.path().to_path_buf())); + + let result = execute(®istry, "nonexistent/model"); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("Model not found")); + } +} diff --git a/src/lib.rs b/src/lib.rs index 8b13789..4e70de3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1 +1,2 @@ - +// lib.rs is intentionally minimal - puma is a binary-first application +// Internal modules are not exposed as public API diff --git a/tests/integration_test.rs b/tests/integration_test.rs new file mode 100644 index 0000000..a448967 --- /dev/null +++ b/tests/integration_test.rs @@ -0,0 +1,281 @@ +use std::process::Command; +use tempfile::TempDir; + +/// Helper to run puma command with custom PUMA_HOME +fn run_puma(home_dir: &str, args: &[&str]) -> std::process::Output { + Command::new(env!("CARGO_BIN_EXE_puma")) + .env("PUMA_HOME", home_dir) + .args(args) + .output() + .expect("Failed to execute puma command") +} + +/// Helper to check if output contains a string +fn output_contains(output: &std::process::Output, text: &str) -> bool { + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + stdout.contains(text) || stderr.contains(text) +} + +#[test] +fn test_pull_command_with_provider() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + // Pull a real model + let output = run_puma( + home, + &["pull", "inftyai/tiny-random-gpt2", "-p", "huggingface"], + ); + assert!(output.status.success()); + + // Verify model appears in ls + let output = run_puma(home, &["ls"]); + assert!(output.status.success()); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("inftyai/tiny-random-gpt2")); + + // Verify model can be inspected + let output = run_puma(home, &["inspect", "inftyai/tiny-random-gpt2"]); + assert!(output.status.success()); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("Name: inftyai/tiny-random-gpt2")); + assert!(stdout.contains("Provider")); + assert!(stdout.contains("huggingface")); + + // Verify model can be removed + let output = run_puma(home, &["rm", "inftyai/tiny-random-gpt2"]); + assert!(output.status.success()); + assert!(output_contains(&output, "Successfully removed model")); + + // Verify model is gone + let output = run_puma(home, &["ls"]); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(!stdout.contains("inftyai/tiny-random-gpt2")); +} + +#[test] +fn test_rm_nonexistent() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["rm", "nonexistent/model"]); + assert!(!output.status.success()); + assert!(output_contains(&output, "Model not found")); +} + +#[test] +fn test_inspect_nonexistent() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["inspect", "nonexistent/model"]); + assert!(!output.status.success()); + assert!(output_contains(&output, "Model not found")); +} + +#[test] +fn test_version() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["version"]); + assert!(output.status.success()); + assert!(output_contains(&output, "PUMA")); +} + +#[test] +fn test_info() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["info"]); + assert!(output.status.success()); + + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("Version")); + assert!(stdout.contains("Models")); +} + +#[test] +fn test_ls_with_invalid_regex() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["ls", "[invalid"]); + assert!(!output.status.success()); + assert!(output_contains(&output, "Invalid regex pattern")); +} + +#[test] +fn test_ls_with_invalid_query() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["ls", "-l", "invalid_format"]); + assert!(!output.status.success()); + assert!(output_contains(&output, "Invalid query format")); +} + +#[test] +fn test_ls_with_invalid_filter_column() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["ls", "-l", "invalid_column=value"]); + assert!(!output.status.success()); + assert!(output_contains(&output, "Invalid filter column")); +} + +#[test] +fn test_ps_command() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["ps"]); + assert!(output.status.success()); + + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("NAME")); + assert!(stdout.contains("PROVIDER")); + assert!(stdout.contains("MODEL")); +} + +#[test] +fn test_pull_command_invalid_model() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + // Pull with invalid model name should fail + let output = run_puma(home, &["pull", "invalid/nonexistent-model-12345"]); + assert!(!output.status.success()); +} + +#[test] +fn test_pull_command_modelscope_provider() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + // Test modelscope provider (currently just prints message) + let output = run_puma(home, &["pull", "test/model", "-p", "modelscope"]); + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("Downloading model from Modelscope") || !output.status.success()); +} + +#[test] +fn test_run_command() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["run"]); + assert!(output.status.success()); + assert!(output_contains(&output, "Creating and running a new model")); +} + +#[test] +fn test_stop_command() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["stop"]); + assert!(output.status.success()); + assert!(output_contains(&output, "Stopping one running model")); +} + +#[test] +fn test_ls_with_pattern_no_models() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + // Pattern matching on empty registry should succeed + let output = run_puma(home, &["ls", "test"]); + assert!(output.status.success()); +} + +#[test] +fn test_ls_with_sql_filter_no_models() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + // SQL filter on empty registry should succeed + let output = run_puma(home, &["ls", "-l", "author=test"]); + assert!(output.status.success()); +} + +#[test] +fn test_ls_with_multiple_filters() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + // Multiple filters separated by comma + let output = run_puma(home, &["ls", "-l", "author=test,license=mit"]); + assert!(output.status.success()); +} + +#[test] +fn test_ls_with_pattern_and_filter_combined() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + // Both pattern and filter should work together + let output = run_puma(home, &["ls", "test", "-l", "author=test"]); + assert!(output.status.success()); +} + +#[test] +fn test_invalid_command() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["invalid-command"]); + assert!(!output.status.success()); +} + +#[test] +fn test_help_command() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["--help"]); + assert!(output.status.success()); + + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("PUMA CLI")); + assert!(stdout.contains("Commands:")); +} + +#[test] +fn test_ls_help() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["ls", "--help"]); + assert!(output.status.success()); + + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("List local models")); +} + +#[test] +fn test_rm_help() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["rm", "--help"]); + assert!(output.status.success()); + + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("Remove one model")); +} + +#[test] +fn test_inspect_help() { + let temp_dir = TempDir::new().unwrap(); + let home = temp_dir.path().to_str().unwrap(); + + let output = run_puma(home, &["inspect", "--help"]); + assert!(output.status.success()); + + let stdout = String::from_utf8_lossy(&output.stdout); + assert!(stdout.contains("Return detailed information about a model")); +} From af777b993f34ef9bc48e2a407f48fe118704a7a5 Mon Sep 17 00:00:00 2001 From: kerthcet Date: Sat, 25 Apr 2026 12:25:52 +0100 Subject: [PATCH 4/6] fix task type error Signed-off-by: kerthcet --- src/cli/commands.rs | 29 +++++++++++++++++------------ src/cli/inspect.rs | 6 +++--- src/cli/ls.rs | 2 +- src/cli/rm.rs | 2 +- src/downloader/huggingface.rs | 2 +- src/registry/model_registry.rs | 4 ++-- src/storage/sqlite.rs | 24 ++++++++++++------------ 7 files changed, 37 insertions(+), 32 deletions(-) diff --git a/src/cli/commands.rs b/src/cli/commands.rs index 1c8f909..8f22390 100644 --- a/src/cli/commands.rs +++ b/src/cli/commands.rs @@ -108,13 +108,14 @@ pub async fn run(cli: Cli) { Commands::LS(args) => { let registry = ModelRegistry::new(None); - let models = match ls::execute(®istry, args.pattern.as_deref(), args.query.as_deref()) { - Ok(models) => models, - Err(e) => { - eprintln!("{}", e); - std::process::exit(1); - } - }; + let models = + match ls::execute(®istry, args.pattern.as_deref(), args.query.as_deref()) { + Ok(models) => models, + Err(e) => { + eprintln!("{}", e); + std::process::exit(1); + } + }; let mut table = Table::new(); table.set_format( @@ -123,7 +124,9 @@ pub async fn run(cli: Cli) { .padding(0, 1) .build(), ); - table.add_row(row!["MODEL", "PROVIDER", "REVISION", "SIZE", "AGE"]); + table.add_row(row![ + "MODEL", "TASK", "PROVIDER", "REVISION", "SIZE", "CREATED" + ]); for model in models { let size_str = format_size_decimal(model.metadata.artifact.size); @@ -135,8 +138,11 @@ pub async fn run(cli: Cli) { let created_str = format_time_ago(&model.created_at); + let model_task = model.task.as_deref().unwrap_or("N/A"); + table.add_row(row![ model.name, + model_task, model.provider, revision_short, size_str, @@ -219,7 +225,7 @@ mod tests { uuid: revision.to_string(), name: name.to_string(), author: Some("test-author".to_string()), - r#type: Some("text-generation".to_string()), + task: Some("text-generation".to_string()), model_series: Some("gpt2".to_string()), provider: "huggingface".to_string(), license: Some("mit".to_string()), @@ -268,7 +274,7 @@ mod tests { let mut model = create_test_model("test/gpt-model", "abc123def456"); model.author = Some("test-org".to_string()); - model.r#type = Some("text-generation".to_string()); + model.task = Some("text-generation".to_string()); model.license = Some("mit".to_string()); model.updated_at = "2025-01-02T00:00:00Z".to_string(); @@ -282,7 +288,7 @@ mod tests { assert_eq!(model_info.created_at, "2025-01-01T00:00:00Z"); assert_eq!(model_info.updated_at, "2025-01-02T00:00:00Z"); assert_eq!(model_info.author, Some("test-org".to_string())); - assert_eq!(model_info.r#type, Some("text-generation".to_string())); + assert_eq!(model_info.task, Some("text-generation".to_string())); assert_eq!(model_info.license, Some("mit".to_string())); assert_eq!(model_info.model_series, Some("gpt2".to_string())); assert_eq!(model_info.metadata.context_window, Some(2048)); @@ -363,5 +369,4 @@ mod tests { assert_eq!(result.metadata.artifact.revision, "v2"); assert_eq!(result.metadata.artifact.size, 2000); } - } diff --git a/src/cli/inspect.rs b/src/cli/inspect.rs index fd2593d..e67fbd5 100644 --- a/src/cli/inspect.rs +++ b/src/cli/inspect.rs @@ -20,8 +20,8 @@ pub fn display(model: &ModelInfo) { model.author.as_deref().unwrap_or("N/A") ); println!( - " Type: {}", - model.r#type.as_deref().unwrap_or("N/A") + " Task: {}", + model.task.as_deref().unwrap_or("N/A") ); println!( " License: {}", @@ -95,7 +95,7 @@ mod tests { uuid: uuid.to_string(), name: name.to_string(), author: Some("test-author".to_string()), - r#type: Some("text-generation".to_string()), + task: Some("text-generation".to_string()), model_series: Some("gpt2".to_string()), provider: "huggingface".to_string(), license: Some("mit".to_string()), diff --git a/src/cli/ls.rs b/src/cli/ls.rs index 7d40034..28a19f1 100644 --- a/src/cli/ls.rs +++ b/src/cli/ls.rs @@ -65,7 +65,7 @@ mod tests { uuid: uuid.to_string(), name: name.to_string(), author: Some(author.to_string()), - r#type: Some("text-generation".to_string()), + task: Some("text-generation".to_string()), model_series: Some("gpt2".to_string()), provider: "huggingface".to_string(), license: Some("mit".to_string()), diff --git a/src/cli/rm.rs b/src/cli/rm.rs index e4e44d7..818f490 100644 --- a/src/cli/rm.rs +++ b/src/cli/rm.rs @@ -27,7 +27,7 @@ mod tests { uuid: uuid.to_string(), name: name.to_string(), author: Some("test-author".to_string()), - r#type: Some("text-generation".to_string()), + task: Some("text-generation".to_string()), model_series: Some("gpt2".to_string()), provider: "huggingface".to_string(), license: Some("mit".to_string()), diff --git a/src/downloader/huggingface.rs b/src/downloader/huggingface.rs index 6720e23..72ffae1 100644 --- a/src/downloader/huggingface.rs +++ b/src/downloader/huggingface.rs @@ -301,7 +301,7 @@ impl Downloader for HuggingFaceDownloader { uuid: sha, // Use revision SHA as UUID for now name: name.to_string(), author: author_from_api, - r#type: task_from_api, + task: task_from_api, model_series: model_series_from_api, provider: "huggingface".to_string(), license: license_from_api, diff --git a/src/registry/model_registry.rs b/src/registry/model_registry.rs index 2272302..68de7e8 100644 --- a/src/registry/model_registry.rs +++ b/src/registry/model_registry.rs @@ -29,7 +29,7 @@ pub struct ModelInfo { pub name: String, pub author: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub r#type: Option, // Task type (image-text-to-text, text-generation) + pub task: Option, // Task type (image-text-to-text, text-generation) #[serde(skip_serializing_if = "Option::is_none")] pub model_series: Option, // Architecture series (qwen3_5, gpt2, llama3) pub provider: String, @@ -117,7 +117,7 @@ mod tests { uuid: revision.to_string(), name: name.to_string(), author: Some("test-author".to_string()), - r#type: Some("text-generation".to_string()), + task: Some("text-generation".to_string()), model_series: Some("gpt2".to_string()), provider: "huggingface".to_string(), license: Some("mit".to_string()), diff --git a/src/storage/sqlite.rs b/src/storage/sqlite.rs index fd11ec0..23a6923 100644 --- a/src/storage/sqlite.rs +++ b/src/storage/sqlite.rs @@ -26,7 +26,7 @@ impl SqliteStorage { uuid TEXT PRIMARY KEY, name TEXT NOT NULL UNIQUE, author TEXT, - type TEXT, + task TEXT, model_series TEXT, provider TEXT NOT NULL, license TEXT, @@ -36,7 +36,7 @@ impl SqliteStorage { CHECK(json_valid(metadata)) ); CREATE INDEX idx_author ON models(author); - CREATE INDEX idx_type ON models(type); + CREATE INDEX idx_task ON models(task); CREATE INDEX idx_model_series ON models(model_series); CREATE INDEX idx_provider ON models(provider); CREATE INDEX idx_license ON models(license); @@ -68,7 +68,7 @@ impl ModelStorage for SqliteStorage { if let Some(filter_map) = filters { // Allowed columns for filtering (prevent SQL injection) - let allowed_columns = ["author", "type", "model_series", "provider", "license"]; + let allowed_columns = ["author", "task", "model_series", "provider", "license"]; for (key, value) in filter_map { if allowed_columns.contains(&key.as_str()) { @@ -84,13 +84,13 @@ impl ModelStorage for SqliteStorage { } let query = if where_clauses.is_empty() { - "SELECT uuid, name, author, type, model_series, provider, license, + "SELECT uuid, name, author, task, model_series, provider, license, metadata, created_at, updated_at FROM models" .to_string() } else { format!( - "SELECT uuid, name, author, type, model_series, provider, license, + "SELECT uuid, name, author, task, model_series, provider, license, metadata, created_at, updated_at FROM models WHERE {}", @@ -113,7 +113,7 @@ impl ModelStorage for SqliteStorage { uuid: row.get(0)?, name: row.get(1)?, author: row.get(2)?, - r#type: row.get(3)?, + task: row.get(3)?, model_series: row.get(4)?, provider: row.get(5)?, license: row.get(6)?, @@ -141,13 +141,13 @@ impl ModelStorage for SqliteStorage { conn.execute( "INSERT INTO models - (uuid, name, author, type, model_series, provider, license, + (uuid, name, author, task, model_series, provider, license, metadata, created_at, updated_at) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10) ON CONFLICT(name) DO UPDATE SET uuid = excluded.uuid, author = excluded.author, - type = excluded.type, + task = excluded.task, model_series = excluded.model_series, provider = excluded.provider, license = excluded.license, @@ -157,7 +157,7 @@ impl ModelStorage for SqliteStorage { &model.uuid, &name_lower, author_lower.as_deref(), - model.r#type.as_deref(), + model.task.as_deref(), model.model_series.as_deref(), &model.provider, model.license.as_deref(), @@ -190,7 +190,7 @@ impl ModelStorage for SqliteStorage { let name_lower = name.to_lowercase(); let result = conn.query_row( - "SELECT uuid, name, author, type, model_series, provider, license, + "SELECT uuid, name, author, task, model_series, provider, license, metadata, created_at, updated_at FROM models WHERE name = ?1", params![name_lower], @@ -203,7 +203,7 @@ impl ModelStorage for SqliteStorage { uuid: row.get(0)?, name: row.get(1)?, author: row.get(2)?, - r#type: row.get(3)?, + task: row.get(3)?, model_series: row.get(4)?, provider: row.get(5)?, license: row.get(6)?, @@ -238,7 +238,7 @@ mod tests { uuid: uuid.to_string(), name: name.to_string(), author: Some("test-author".to_string()), - r#type: Some("text-generation".to_string()), + task: Some("text-generation".to_string()), model_series: Some("gpt2".to_string()), provider: "huggingface".to_string(), license: Some("mit".to_string()), From 26eabc3b0ea476cd30f9930e2e90e4b15cd34b49 Mon Sep 17 00:00:00 2001 From: kerthcet Date: Sat, 25 Apr 2026 12:44:53 +0100 Subject: [PATCH 5/6] support all Signed-off-by: kerthcet --- src/cli/inspect.rs | 42 +++++++++++++++++++-------------------- tests/integration_test.rs | 4 ++-- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/src/cli/inspect.rs b/src/cli/inspect.rs index e67fbd5..e8253a6 100644 --- a/src/cli/inspect.rs +++ b/src/cli/inspect.rs @@ -12,19 +12,19 @@ pub fn execute(registry: &ModelRegistry, model_name: &str) -> Result Date: Sat, 25 Apr 2026 12:46:17 +0100 Subject: [PATCH 6/6] fix lint Signed-off-by: kerthcet --- src/registry/model_registry.rs | 5 ++++- src/storage/storage_trait.rs | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/registry/model_registry.rs b/src/registry/model_registry.rs index 68de7e8..773fe68 100644 --- a/src/registry/model_registry.rs +++ b/src/registry/model_registry.rs @@ -57,7 +57,10 @@ impl ModelRegistry { } } - pub fn load_models(&self, filters: Option<&HashMap>) -> Result, std::io::Error> { + pub fn load_models( + &self, + filters: Option<&HashMap>, + ) -> Result, std::io::Error> { self.storage.load_models(filters) } diff --git a/src/storage/storage_trait.rs b/src/storage/storage_trait.rs index bc39b0a..3968704 100644 --- a/src/storage/storage_trait.rs +++ b/src/storage/storage_trait.rs @@ -6,7 +6,10 @@ use std::collections::HashMap; /// Trait for model storage backends pub trait ModelStorage { /// Load models from storage with optional filtering by column values (e.g., author=InftyAI, license=mit) - fn load_models(&self, filters: Option<&HashMap>) -> Result, io::Error>; + fn load_models( + &self, + filters: Option<&HashMap>, + ) -> Result, io::Error>; /// Register (insert or update) a single model fn register_model(&self, model: ModelInfo) -> Result<(), io::Error>;