Skip to content

Commit

Permalink
refactor: parser improvements to be idiomatic for clap 3.2.x
Browse files Browse the repository at this point in the history
  • Loading branch information
Rob Patro committed Jul 14, 2022
1 parent b9769a6 commit 347c342
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 57 deletions.
17 changes: 9 additions & 8 deletions src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use sprs::TriMatI;
use std::collections::HashSet;
use std::fs;
use std::io::{BufRead, BufReader, BufWriter, Write};
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
Expand All @@ -31,13 +32,13 @@ pub fn infer(
//num_bootstraps,
//init_uniform,
//summary_stat,
count_mat_file: String,
eq_label_file: String,
count_mat_file: &PathBuf,
eq_label_file: &PathBuf,
usa_mode: bool,
_use_mtx: bool,
num_threads: u32,
filter_list: Option<String>,
output_dir: String,
filter_list: Option<&PathBuf>,
output_dir: &PathBuf,
log: &slog::Logger,
) -> anyhow::Result<()> {
info!(
Expand All @@ -46,7 +47,7 @@ pub fn infer(
);

// get the path for the equivalence class count matrix
let count_mat_path = std::path::Path::new(&count_mat_file);
let count_mat_path = std::path::Path::new(count_mat_file);
let count_mat_parent = count_mat_path
.parent()
.unwrap_or_else(|| panic!("cannot get parent path of {:?}", count_mat_path));
Expand All @@ -71,7 +72,7 @@ pub fn infer(
let mut num_cells = count_mat.rows();

// read in the global equivalence class representation
let eq_label_path = std::path::Path::new(&eq_label_file);
let eq_label_path = std::path::Path::new(eq_label_file);
let global_eq_classes = Arc::new(crate::eq_class::IndexedEqList::init_from_eqc_file(
eq_label_path,
));
Expand Down Expand Up @@ -117,7 +118,7 @@ pub fn infer(

if let Some(fname) = filter_list {
// read in the filter list
match read_filter_list(&fname, bc_len) {
match read_filter_list(fname, bc_len) {
Ok(fset) => {
// the number of cells we expect to
// actually process
Expand Down Expand Up @@ -316,7 +317,7 @@ pub fn infer(
}

// create our output directory
let output_path = std::path::Path::new(&output_dir);
let output_path = std::path::Path::new(output_dir);
fs::create_dir_all(output_path)?;

let in_col_path = count_mat_parent.join("quants_mat_cols.txt");
Expand Down
49 changes: 17 additions & 32 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ use itertools::Itertools;
use mimalloc::MiMalloc;
use rand::Rng;
use slog::{crit, o, warn, Drain};
use std::borrow::ToOwned;
use std::path::{Path, PathBuf};

use alevin_fry::cellfilter::{generate_permit_list, CellFilterMethod};
Expand Down Expand Up @@ -74,19 +73,6 @@ fn pathbuf_file_exists_validator(v: &str) -> Result<PathBuf, String> {
}
}

/// Validates that `v` names a directory that exists on disk.
///
/// Returns `Ok` with an owned copy of `v` when `Path::is_dir` reports
/// true for it. In every other case (the path is wrong, permissions
/// prevent inspection, or a symlink cannot be resolved) it returns
/// `Err` carrying a fixed, human-readable message.
fn directory_exists_validator(v: &str) -> Result<String, String> {
    let candidate = Path::new(v);
    if candidate.is_dir() {
        return Ok(v.to_owned());
    }
    Err("No valid directory was found at this path.".to_string())
}

/// Checks if the path pointed to by v exists and is
/// a valid directory on disk. If there is any issue
/// with permissions or failure to properly
Expand Down Expand Up @@ -196,17 +182,17 @@ fn main() -> anyhow::Result<()> {
.version(version)
.author(crate_authors)
.arg(arg!(-i --"input-dir" <INPUTDIR> "input directory containing collated RAD file")
.value_parser(directory_exists_validator))
.arg(arg!(-m --"tg-map" <TGMAP> "transcript to gene map").value_parser(file_exists_validator))
.arg(arg!(-o --"output-dir" <OUTPUTDIR> "output directory where quantification results will be written"))
.value_parser(pathbuf_directory_exists_validator))
.arg(arg!(-m --"tg-map" <TGMAP> "transcript to gene map").value_parser(pathbuf_file_exists_validator))
.arg(arg!(-o --"output-dir" <OUTPUTDIR> "output directory where quantification results will be written").value_parser(value_parser!(PathBuf)))
.arg(arg!(-t --threads <THREADS> "number of threads to use for processing").value_parser(value_parser!(u32)).default_value(&max_num_threads))
.arg(arg!(-d --"dump-eqclasses" "flag for dumping equivalence classes").takes_value(false).required(false))
.arg(arg!(-b --"num-bootstraps" <NUMBOOTSTRAPS> "number of bootstraps to use").value_parser(value_parser!(u32)).default_value("0"))
.arg(arg!(--"init-uniform" "flag for uniform sampling").requires("num-bootstraps").takes_value(false).required(false))
.arg(arg!(--"summary-stat" "flag for storing only summary statistics").requires("num-bootstraps").takes_value(false).required(false))
.arg(arg!(--"use-mtx" "flag for writing output matrix in matrix market format (default)").takes_value(false).required(false))
.arg(arg!(--"use-eds" "flag for writing output matrix in EDS format").takes_value(false).required(false).conflicts_with("use-mtx"))
.arg(arg!(--"quant-subset" <SFILE> "file containing list of barcodes to quantify, those not in this list will be ignored").required(false).value_parser(file_exists_validator))
.arg(arg!(--"quant-subset" <SFILE> "file containing list of barcodes to quantify, those not in this list will be ignored").required(false).value_parser(pathbuf_file_exists_validator))
.arg(arg!(-r --resolution <RESOLUTION> "the resolution strategy by which molecules will be counted")
.ignore_case(true)
.value_parser(["full", "trivial", "cr-like", "cr-like-em", "parsimony", "parsimony-em", "parsimony-gene", "parsimony-gene-em"]))
Expand Down Expand Up @@ -251,14 +237,14 @@ fn main() -> anyhow::Result<()> {
.version(version)
.author(crate_authors)
.arg(arg!(-c --"count-mat" <EQCMAT> "matrix of cells by equivalence class counts")
.value_parser(file_exists_validator).takes_value(true).required(true))
.value_parser(pathbuf_file_exists_validator).takes_value(true).required(true))
//.arg(arg!(-b --barcodes=<barcodes> "file containing the barcodes labeling the matrix rows").takes_value(true).required(true))
.arg(arg!(-e --"eq-labels" <EQLABELS> "file containing the gene labels of the equivalence classes")
.value_parser(file_exists_validator).takes_value(true).required(true))
.arg(arg!(-o --"output-dir" <OUTPUTDIR> "output directory where quantification results will be written").takes_value(true).required(true))
.value_parser(pathbuf_file_exists_validator).takes_value(true).required(true))
.arg(arg!(-o --"output-dir" <OUTPUTDIR> "output directory where quantification results will be written").value_parser(value_parser!(PathBuf)).takes_value(true).required(true))
.arg(arg!(-t --threads <THREADS> "number of threads to use for processing").value_parser(value_parser!(u32)).default_value(&max_num_threads))
.arg(arg!(--usa "flag specifying that input equivalence classes were computed in USA mode").takes_value(false).required(false))
.arg(arg!(--"quant-subset" <SFILE> "file containing list of barcodes to quantify, those not in this list will be ignored").required(false).value_parser(file_exists_validator))
.arg(arg!(--"quant-subset" <SFILE> "file containing list of barcodes to quantify, those not in this list will be ignored").required(false).value_parser(pathbuf_file_exists_validator))
.arg(arg!(--"use-mtx" "flag for writing output matrix in matrix market format (default)").takes_value(false).required(false))
.arg(arg!(--"use-eds" "flag for writing output matrix in EDS format").takes_value(false).required(false).conflicts_with("use-mtx"));

Expand Down Expand Up @@ -453,13 +439,13 @@ fn main() -> anyhow::Result<()> {
let summary_stat = t.is_present("summary-stat");
let dump_eq = t.is_present("dump-eqclasses");
let use_mtx = !t.is_present("use-eds");
let input_dir: String = t.get_one::<String>("input-dir").unwrap().clone();
let output_dir: String = t.get_one::<String>("output-dir").unwrap().clone();
let tg_map: String = t.get_one::<String>("tg-map").unwrap().clone();
let input_dir: &PathBuf = t.get_one("input-dir").unwrap();
let output_dir: &PathBuf = t.get_one("output-dir").unwrap();
let tg_map: &PathBuf = t.get_one("tg-map").unwrap();
let resolution: ResolutionStrategy = *t.get_one("resolution").unwrap();
let sa_model: SplicedAmbiguityModel = *t.get_one("sa-model").unwrap();
let small_thresh = *t.get_one("small-thresh").unwrap();
let filter_list = t.get_one::<String>("quant-subset").map(ToOwned::to_owned);
let filter_list: Option<&PathBuf> = t.get_one("quant-subset");
let large_graph_thresh: usize = *t.get_one("large-graph-thresh").unwrap();
let umi_edit_dist: u32 = *t.get_one("umi-edit-dist").unwrap();
let mut pug_exact_umi = false;
Expand Down Expand Up @@ -570,7 +556,7 @@ fn main() -> anyhow::Result<()> {
// if the input directory contains the valid json file we want
// then proceed. otherwise print a critical error.
if json_path.exists() {
let velo_mode = alevin_fry::utils::is_velo_mode(quant_opts.input_dir.to_string());
let velo_mode = alevin_fry::utils::is_velo_mode(quant_opts.input_dir);
if velo_mode {
match alevin_fry::quant::velo_quantify(quant_opts) {
// if we're all good; then great!
Expand Down Expand Up @@ -634,11 +620,10 @@ fn main() -> anyhow::Result<()> {
if let Some(t) = opts.subcommand_matches("infer") {
let num_threads = *t.get_one("threads").unwrap();
let use_mtx = !t.is_present("use-eds");
let output_dir = t.get_one::<String>("output-dir").unwrap().clone();
let count_mat = t.get_one::<String>("count-mat").unwrap().clone();
let eq_label_file = t.get_one::<String>("eq-labels").unwrap().clone();
let filter_list: Option<String> =
t.get_one::<String>("quant-subset").map(ToOwned::to_owned);
let output_dir = t.get_one("output-dir").unwrap();
let count_mat = t.get_one("count-mat").unwrap();
let eq_label_file = t.get_one("eq-labels").unwrap();
let filter_list: Option<&PathBuf> = t.get_one("quant-subset");
let usa_mode = t.is_present("usa");

alevin_fry::infer::infer(
Expand Down
16 changes: 8 additions & 8 deletions src/prog_opts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ use std::path::PathBuf;

#[derive(TypedBuilder, Debug)]
//#[builder(name = "QuantOptsBuilder")]
pub struct QuantOpts<'b, 'c, 'd> {
pub input_dir: String,
pub tg_map: String,
pub output_dir: String,
pub struct QuantOpts<'a, 'b, 'c, 'd, 'e, 'f, 'g> {
pub input_dir: &'a PathBuf,
pub tg_map: &'b PathBuf,
pub output_dir: &'c PathBuf,
pub num_threads: u32,
pub num_bootstraps: u32,
pub init_uniform: bool,
Expand All @@ -25,10 +25,10 @@ pub struct QuantOpts<'b, 'c, 'd> {
pub sa_model: SplicedAmbiguityModel,
pub small_thresh: usize,
pub large_graph_thresh: usize,
pub filter_list: Option<String>,
pub cmdline: &'b str,
pub version: &'c str,
pub log: &'d slog::Logger,
pub filter_list: Option<&'d PathBuf>,
pub cmdline: &'e str,
pub version: &'f str,
pub log: &'g slog::Logger,
}

#[derive(TypedBuilder, Debug)]
Expand Down
10 changes: 5 additions & 5 deletions src/quant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ fn write_eqc_counts(
// TODO: see if we'd rather pass a structure
// with these options
pub fn quantify(quant_opts: QuantOpts) -> anyhow::Result<()> {
let parent = std::path::Path::new(&quant_opts.input_dir);
let parent = std::path::Path::new(quant_opts.input_dir);
let log = quant_opts.log;

// read the collate metadata
Expand Down Expand Up @@ -352,7 +352,7 @@ pub fn quantify(quant_opts: QuantOpts) -> anyhow::Result<()> {
// TODO: see if we'd rather pass a structure
// with these options
pub fn do_quantify<T: Read>(mut br: T, quant_opts: QuantOpts) -> anyhow::Result<()> {
let parent = std::path::Path::new(&quant_opts.input_dir);
let parent = std::path::Path::new(quant_opts.input_dir);
let hdr = rad_types::RadHeader::from_bytes(&mut br);

let init_uniform = quant_opts.init_uniform;
Expand Down Expand Up @@ -409,7 +409,7 @@ pub fn do_quantify<T: Read>(mut br: T, quant_opts: QuantOpts) -> anyhow::Result<
// both spliced and unspliced. The type will be automatically
// determined.
match afutils::parse_tg_map(
&quant_opts.tg_map,
quant_opts.tg_map,
hdr.ref_count as usize,
&rname_to_id,
&mut gene_names,
Expand Down Expand Up @@ -488,7 +488,7 @@ pub fn do_quantify<T: Read>(mut br: T, quant_opts: QuantOpts) -> anyhow::Result<
// if we have a filter list, extract it here
let mut retained_bc: Option<HashSet<u64, ahash::RandomState>> = None;
if let Some(fname) = filter_list {
match afutils::read_filter_list(&fname, ft_vals.bclen) {
match afutils::read_filter_list(fname, ft_vals.bclen) {
Ok(fset) => {
// the number of cells we expect to
// actually process
Expand Down Expand Up @@ -548,7 +548,7 @@ pub fn do_quantify<T: Read>(mut br: T, quant_opts: QuantOpts) -> anyhow::Result<
let num_genes = gene_name_to_id.len();

// create our output directory
let output_path = std::path::Path::new(&quant_opts.output_dir);
let output_path = std::path::Path::new(quant_opts.output_dir);
fs::create_dir_all(output_path)?;

// create sub-directory for matrix
Expand Down
9 changes: 5 additions & 4 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use std::collections::{HashMap, HashSet};
use std::error::Error;
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::PathBuf;
use std::str::FromStr;
use thiserror::Error;

Expand Down Expand Up @@ -240,7 +241,7 @@ fn parse_tg_spliced(
}

pub fn parse_tg_map(
tg_map: &str,
tg_map: &PathBuf,
ref_count: usize,
rname_to_id: &HashMap<String, u32, ahash::RandomState>,
gene_names: &mut Vec<String>,
Expand Down Expand Up @@ -693,7 +694,7 @@ pub fn generate_permitlist_map(
/// a HashSet containing the k-mer encoding of all barcodes or
/// the Error that was encountered parsing the file.
pub fn read_filter_list(
flist: &str,
flist: &PathBuf,
bclen: u16,
) -> anyhow::Result<HashSet<u64, ahash::RandomState>> {
let s = ahash::RandomState::with_seeds(2u64, 7u64, 1u64, 8u64);
Expand All @@ -715,8 +716,8 @@ pub fn read_filter_list(
Ok(fset)
}

pub fn is_velo_mode(input_dir: String) -> bool {
let parent = std::path::Path::new(&input_dir);
pub fn is_velo_mode(input_dir: &PathBuf) -> bool {
let parent = std::path::Path::new(input_dir);
// open the metadata file and read the json
let meta_data_file = File::open(parent.join("generate_permit_list.json"))
.expect("could not open the generate_permit_list.json file.");
Expand Down

0 comments on commit 347c342

Please sign in to comment.