Skip to content

Commit

Permalink
Merge pull request #5 from COMBINE-lab/bootstrapping
Browse files Browse the repository at this point in the history
merging common bootstrapping option to develop branch
  • Loading branch information
hiraksarkar committed Aug 18, 2020
2 parents 8b84a8c + edaa587 commit c921562
Show file tree
Hide file tree
Showing 3 changed files with 223 additions and 9 deletions.
73 changes: 69 additions & 4 deletions libradicl/src/em.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,37 @@ const MIN_ITER: u32 = 50;
const MAX_ITER: u32 = 10_000;
const REL_DIFF_TOLERANCE: f32 = 1e-2;

#[allow(dead_code)]
fn mean(data: &[f64]) -> Option<f64> {
let sum = data.iter().sum::<f64>() as f64;
let count = data.len();

match count {
positive if positive > 0 => Some(sum / count as f64),
_ => None,
}
}

#[allow(dead_code)]
fn std_deviation(data: &[f64]) -> Option<f64> {
match (mean(data), data.len()) {
(Some(data_mean), count) if count > 0 => {
let variance = data
.iter()
.map(|value| {
let diff = data_mean - (*value as f64);

diff * diff
})
.sum::<f64>()
/ count as f64;

Some(variance.sqrt())
}
_ => None,
}
}

pub fn em_update(
alphas_in: &[f32],
alphas_out: &mut Vec<f32>,
Expand Down Expand Up @@ -149,11 +180,18 @@ pub fn run_bootstrap(
// num_alphas: usize,
// only_unique: bool,
init_uniform: bool,
summary_stat: bool,
_log: &slog::Logger,
) -> Vec<Vec<f32>> {
let mut total_fragments = 0u64;

// println!("In bootstrapping");

let mut alphas: Vec<f32> = vec![0.0; gene_alpha.len()];
let mut alphas_mean: Vec<f32> = vec![0.0; gene_alpha.len()];
let mut alphas_square: Vec<f32> = vec![0.0; gene_alpha.len()];
let mut sample_mean: Vec<f32> = vec![0.0; gene_alpha.len()];
let mut sample_var: Vec<f32> = vec![0.0; gene_alpha.len()];
let mut alphas_prime: Vec<f32> = vec![0.0; gene_alpha.len()];
// let mut means: Vec<f32> = vec![0.0; gene_alpha.len()];
// let mut square_means: Vec<f32> = vec![0.0; gene_alpha.len()];
Expand All @@ -173,6 +211,8 @@ pub fn run_bootstrap(
.or_insert_with(|| labels.to_vec());
}

// println!("total fragments {:?}", total_fragments);

// a new hashmap to be updated in each bootstrap
let s = fasthash::RandomState::<Hash64>::new();
let mut eqclass_bootstrap: HashMap<Vec<u32>, u32, fasthash::RandomState<Hash64>> =
Expand All @@ -184,13 +224,16 @@ pub fn run_bootstrap(
let mut bootstraps = Vec::new();

// bootstrap loop starts
// let mut old_resampled_counts = Vec::new();
for _bs_num in 0..num_bootstraps {
// resample from multinomial
let resampled_counts = thread_rng().sample(dist.clone());

for (eq_id, labels) in &eqclasses_serialize {
eqclass_bootstrap
.entry(labels.to_vec())
.or_insert(resampled_counts[*eq_id].round() as u32);
eqclass_bootstrap.insert(labels.to_vec(), resampled_counts[*eq_id].round() as u32);
// eqclass_bootstrap
// .entry(labels.to_vec())
// .or_insert(resampled_counts[*eq_id].round() as u32);
}

// fill new alpha
Expand All @@ -202,6 +245,9 @@ pub fn run_bootstrap(
}
}

// let alpha_sum : f32 = alphas.iter().sum();
// println!("Bootstrap num ... {:?}, alpha_sum ... {:?}", _bs_num, alpha_sum);

let mut it_num: u32 = 0;
let mut converged: bool = false;
while it_num < MIN_ITER || (it_num < MAX_ITER && !converged) {
Expand Down Expand Up @@ -242,7 +288,26 @@ pub fn run_bootstrap(

let alphas_sum: f32 = alphas.iter().sum();
assert!(alphas_sum > 0.0, "Alpha Sum too small");
bootstraps.push(alphas.clone());
if summary_stat {
for i in 0..gene_alpha.len() {
alphas_mean[i] += alphas[i];
alphas_square[i] += alphas[i] * alphas[i];
}
} else {
bootstraps.push(alphas.clone());
}
// println!("After alpha sum: {:?}, it_num: {:?}", alphas_sum, it_num);
// old_resampled_counts = resampled_counts.clone();
}
if summary_stat {
for i in 0..gene_alpha.len() {
let mean_alpha = alphas_mean[i] / num_bootstraps as f32;
sample_mean[i] = mean_alpha;
sample_var[i] = (alphas_square[i] / num_bootstraps as f32) - (mean_alpha * mean_alpha);
}

bootstraps.push(sample_mean.clone());
bootstraps.push(sample_var.clone());
}

bootstraps
Expand Down
139 changes: 136 additions & 3 deletions libradicl/src/quant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ use std::string::ToString;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
//use std::ptr;

use flate2::write::GzEncoder;
use flate2::Compression;

use self::libradicl::em::em_optimize;
use self::libradicl::em::{em_optimize, run_bootstrap};
use self::libradicl::pugutils;
use self::libradicl::schema::{EqMap, PUGEdgeType, ResolutionStrategy};
use self::libradicl::utils::*;
Expand Down Expand Up @@ -254,6 +255,9 @@ pub fn quantify(
tg_map: String,
output_dir: String,
num_threads: u32,
num_bootstraps: u32,
init_uniform: bool,
summary_stat: bool,
resolution: ResolutionStrategy,
//no_em: bool,
//naive: bool,
Expand Down Expand Up @@ -399,13 +403,53 @@ pub fn quantify(
let bc_file = fs::File::create(bc_path)?;

let mat_path = output_path.join("counts.eds.gz");
//let bootstrap_path_1 = output_path.join("bootstraps_1.eds.gz");

let bootstrap_path = output_path.join("bootstraps.eds.gz");
let bootstrap_mean_path = output_path.join("bootstraps_mean.eds.gz");
let bootstrap_var_path = output_path.join("bootstraps_var.eds.gz");
let buffered = GzEncoder::new(fs::File::create(mat_path)?, Compression::default());

let ff_path = output_path.join("features.txt");
let mut ff_file = fs::File::create(ff_path)?;
writeln!(ff_file, "cell_num\tnum_mapped\ttot_umi\tdedup_rate\tmean_by_max\ttotal_expressed_genes\tnum_genes_over_mean")?;

let alt_res_cells = Arc::new(Mutex::new(Vec::<u64>::new()));
let mut bt_writer_optional: Arc<Mutex<Option<BufWriter<GzEncoder<fs::File>>>>> =
Arc::new(Mutex::new(None));
let mut bt_summary_writer_optional: Arc<
Mutex<
Option<(
BufWriter<GzEncoder<fs::File>>,
BufWriter<GzEncoder<fs::File>>,
)>,
>,
> = Arc::new(Mutex::new(None));
if num_bootstraps > 0 {
if summary_stat {
let bt_mean_buffered = GzEncoder::new(
fs::File::create(bootstrap_mean_path)?,
Compression::default(),
);
let bt_var_buffered = GzEncoder::new(
fs::File::create(bootstrap_var_path)?,
Compression::default(),
);
bt_summary_writer_optional = Arc::new(Mutex::new(Some((
BufWriter::new(bt_mean_buffered),
BufWriter::new(bt_var_buffered),
))));
} else {
let bt_buffered =
GzEncoder::new(fs::File::create(bootstrap_path)?, Compression::default());
bt_writer_optional = Arc::new(Mutex::new(Some(BufWriter::new(bt_buffered))));
}
}

// let bt_buffered = GzEncoder::new(fs::File::create(bootstrap_path)?, Compression::default());
// let bt_writer = Arc::new(Mutex::new(
// BufWriter::new(bt_buffered),
// ));

let bc_writer = Arc::new(Mutex::new((
BufWriter::new(bc_file),
Expand All @@ -429,6 +473,26 @@ pub fn quantify(
let umi_type = umi_type;
// and the file writer
let bcout = bc_writer.clone();
// and the bootstrap file writer
let mut btcout_optional: Arc<Mutex<Option<BufWriter<GzEncoder<fs::File>>>>> =
Arc::new(Mutex::new(None));
let mut btcout_summary_optional: Arc<
Mutex<
Option<(
BufWriter<GzEncoder<fs::File>>,
BufWriter<GzEncoder<fs::File>>,
)>,
>,
> = Arc::new(Mutex::new(None));
if num_bootstraps > 0 {
if summary_stat {
btcout_summary_optional = bt_summary_writer_optional.clone();
} else {
btcout_optional = bt_writer_optional.clone();
}
}

//let btcout = bt_writer.clone();
// and will need to know the barcode length
let bclen = ft_vals.bclen;
let alt_res_cells = alt_res_cells.clone();
Expand All @@ -455,6 +519,7 @@ pub fn quantify(
let counts: Vec<f32>;
let mut alt_resolution = false;

let mut bootstraps: Vec<Vec<f32>> = Vec::new();
match resolution {
ResolutionStrategy::CellRangerLike => {
let gene_eqc = pugutils::get_num_molecules_cell_ranger_like(
Expand Down Expand Up @@ -514,6 +579,16 @@ pub fn quantify(
true,
&log,
);
if num_bootstraps > 0 {
bootstraps = run_bootstrap(
&gene_eqc,
num_bootstraps,
&counts,
init_uniform,
summary_stat,
&log,
);
}
}
ResolutionStrategy::Full => {
let g = extract_graph(&eq_map, &log);
Expand All @@ -530,9 +605,20 @@ pub fn quantify(
&mut unique_evidence,
&mut no_ambiguity,
num_genes,
false,
init_uniform,
&log,
);

if num_bootstraps > 0 {
bootstraps = run_bootstrap(
&gene_eqc,
num_bootstraps,
&counts,
init_uniform,
summary_stat,
&log,
);
}
}
}
// clear our local variables
Expand Down Expand Up @@ -582,7 +668,8 @@ pub fn quantify(
let eds_bytes: Vec<u8> = sce::eds::as_bytes(&counts, num_genes)
.expect("can't conver vector to eds");

let writer = &mut *bcout.lock().unwrap();
let writer_deref = bcout.lock();
let writer = &mut *writer_deref.unwrap();
// write to barcode file
writeln!(&mut writer.0, "{}\t{}", cell_num, unsafe {
std::str::from_utf8_unchecked(&bitmer_to_bytes(bc_mer)[..])
Expand All @@ -608,6 +695,52 @@ pub fn quantify(
)
.expect("can't write to feature file");
}

// write bootstraps
if num_bootstraps > 0 {
// flatten the bootstraps
if summary_stat {
let eds_mean_bytes: Vec<u8> =
sce::eds::as_bytes(&bootstraps[0], num_genes)
.expect("can't convert vector to eds");
let eds_var_bytes: Vec<u8> =
sce::eds::as_bytes(&bootstraps[1], num_genes)
.expect("can't convert vector to eds");

let mut writer_deref = btcout_summary_optional.lock().unwrap();
match &mut *writer_deref {
Some(writer) => {
writer
.0
.write_all(&eds_mean_bytes)
.expect("can't write to matrix file.");
writer
.1
.write_all(&eds_var_bytes)
.expect("can't write to matrix file.");
}
None => {}
}
} else {
let mut bt_eds_bytes: Vec<u8> = Vec::new();
for i in 0..num_bootstraps {
let bt_eds_bytes_slice =
sce::eds::as_bytes(&bootstraps[i as usize], num_genes)
.expect("can't convert vector to eds");
bt_eds_bytes.append(&mut bt_eds_bytes_slice.clone());
}

let mut writer_deref = btcout_optional.lock().unwrap();
match &mut *writer_deref {
Some(writer) => {
writer
.write_all(&bt_eds_bytes)
.expect("can't write to matrix file.");
}
None => {}
}
}
}
}
}
});
Expand Down
20 changes: 18 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ fn main() {
.arg(Arg::from("-m, --tg-map=<tg-map> 'transcript to gene map'"))
.arg(Arg::from("-o, --output-dir=<output-dir> 'output directory where quantification results will be written'"))
.arg(Arg::from("-t, --threads 'number of threads to use for processing'").default_value(&max_num_threads))
.arg(Arg::from("-b, --num-bootstraps 'number of bootstraps to use'").default_value("0"))
.arg(Arg::from("--init-uniform 'flag for uniform sampling'").requires("num-bootstraps").takes_value(false).required(false))
.arg(Arg::from("--summary-stat 'flag for storing only summary statistics'").requires("num-bootstraps").takes_value(false).required(false))
.arg(Arg::from("-r, --resolution 'the resolution strategy by which molecules will be counted'")
.possible_values(&["full", "trivial", "cr-like", "cr-like-em", "parsimony"])
.default_value("full")
Expand Down Expand Up @@ -207,11 +210,24 @@ fn main() {

if let Some(ref t) = opts.subcommand_matches("quant") {
let num_threads = t.value_of_t("threads").unwrap();
let num_bootstraps = t.value_of_t("num-bootstraps").unwrap();
let init_uniform = t.is_present("init-uniform");
let summary_stat = t.is_present("summary-stat");
let input_dir = t.value_of_t("input-dir").unwrap();
let output_dir = t.value_of_t("output-dir").unwrap();
let tg_map = t.value_of_t("tg-map").unwrap();
let resolution: ResolutionStrategy = t.value_of_t("resolution").unwrap();
libradicl::quant::quantify(input_dir, tg_map, output_dir, num_threads, resolution, &log)
.expect("could not quantify rad file.");
libradicl::quant::quantify(
input_dir,
tg_map,
output_dir,
num_threads,
num_bootstraps,
init_uniform,
summary_stat,
resolution,
&log,
)
.expect("could not quantify rad file.");
}
}

0 comments on commit c921562

Please sign in to comment.