diff --git a/src/lib.rs b/src/lib.rs index aab60cbc..679acbd2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,7 @@ mod edge_buffer; mod error; pub mod nested_forward_list; +mod samples_info; mod segment; mod simplification_buffers; mod simplification_common; @@ -54,6 +55,7 @@ mod tsdef; pub use edge_buffer::EdgeBuffer; pub use error::ForrusttsError; +pub use samples_info::SamplesInfo; pub use segment::Segment; pub use simplification_buffers::SimplificationBuffers; pub use simplification_flags::SimplificationFlags; diff --git a/src/samples_info.rs b/src/samples_info.rs new file mode 100644 index 00000000..2f7f1eb0 --- /dev/null +++ b/src/samples_info.rs @@ -0,0 +1,41 @@ +use crate::IdType; + +/// Information about samples used for +/// table simplification. +#[derive(Default)] +pub struct SamplesInfo { + /// A list of sample IDs. + /// Can include both "alive" and + /// "ancient/remembered/preserved" sample + /// nodes. + pub samples: Vec<IdType>, + /// When using [``EdgeBuffer``] to record transmission + /// events, this list must contain a list of all node IDs + /// alive the last time simplification happened. Here, + /// "alive" means "could leave more descendants". + /// At the *start* of a simulation, this should be filled + /// with a list of "founder" node IDs. + pub edge_buffer_founder_nodes: Vec<IdType>, +} + +impl SamplesInfo { + /// Generate a new instance.
+ pub fn new() -> Self { + SamplesInfo { + samples: vec![], + edge_buffer_founder_nodes: vec![], + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_default() { + let s: SamplesInfo = Default::default(); + assert!(s.samples.is_empty()); + assert!(s.edge_buffer_founder_nodes.is_empty()); + } +} diff --git a/src/simplification_common.rs b/src/simplification_common.rs index 123d77fe..0ccc65ce 100644 --- a/src/simplification_common.rs +++ b/src/simplification_common.rs @@ -2,6 +2,7 @@ /// functions use crate::simplification_logic; use crate::ForrusttsError; +use crate::SamplesInfo; use crate::SimplificationBuffers; use crate::SimplificationFlags; use crate::SimplificationOutput; @@ -14,7 +15,7 @@ fn setup_idmap(nodes: &[Node], idmap: &mut Vec) { } pub fn setup_simplification( - samples: &[IdType], + samples: &SamplesInfo, tables: &TableCollection, flags: SimplificationFlags, state: &mut SimplificationBuffers, @@ -38,7 +39,7 @@ pub fn setup_simplification( state.ancestry.reset(tables.num_nodes()); simplification_logic::record_sample_nodes( - &samples, + &samples.samples, &tables, &mut state.new_nodes, &mut state.ancestry, diff --git a/src/simplify_from_edge_buffer.rs b/src/simplify_from_edge_buffer.rs index 2db3fbc2..97926c34 100644 --- a/src/simplify_from_edge_buffer.rs +++ b/src/simplify_from_edge_buffer.rs @@ -3,6 +3,7 @@ use crate::simplification_logic; use crate::tables::*; use crate::EdgeBuffer; use crate::ForrusttsError; +use crate::SamplesInfo; use crate::Segment; use crate::SimplificationBuffers; use crate::SimplificationFlags; @@ -28,12 +29,12 @@ impl ParentLocation { fn find_pre_existing_edges( tables: &TableCollection, - alive_at_last_simplification: &[IdType], + edge_buffer_founder_nodes: &[IdType], edge_buffer: &EdgeBuffer, ) -> Result, ForrusttsError> { let mut alive_with_new_edges: Vec = vec![]; - for a in alive_at_last_simplification { + for a in edge_buffer_founder_nodes { if edge_buffer.head(*a)? 
!= EdgeBuffer::null() { alive_with_new_edges.push(*a); } @@ -130,8 +131,9 @@ fn process_births_from_buffer( /// /// # Parameters /// -/// * `samples`: -/// * `alive_at_last_simplification`: +/// * `samples`: Instance of [``SamplesInfo``]. The field +/// [``SamplesInfo::edge_buffer_founder_nodes``] +/// must be populated. See [``EdgeBuffer``] for details. /// * `flags`: modify the behavior of the simplification algorithm. /// * `state`: These are the internal data structures used /// by the simpilfication algorithm. @@ -146,8 +148,7 @@ fn process_births_from_buffer( /// The input tables must be sorted. /// See [``TableCollection::sort_tables_for_simplification``]. pub fn simplify_from_edge_buffer( - samples: &[IdType], - alive_at_last_simplification: &[IdType], + samples: &SamplesInfo, flags: SimplificationFlags, state: &mut SimplificationBuffers, edge_buffer: &mut EdgeBuffer, @@ -158,7 +159,7 @@ pub fn simplify_from_edge_buffer( // Process all edges since the last simplification. let mut max_time = Time::MIN; - for n in alive_at_last_simplification { + for n in samples.edge_buffer_founder_nodes.iter() { max_time = std::cmp::max(max_time, tables.node(*n).time); } for (i, _) in edge_buffer.head_itr().rev().enumerate() { @@ -185,7 +186,7 @@ pub fn simplify_from_edge_buffer( } let existing_edges = - find_pre_existing_edges(&tables, &alive_at_last_simplification, &edge_buffer)?; + find_pre_existing_edges(&tables, &samples.edge_buffer_founder_nodes, &edge_buffer)?; let mut edge_i = 0; let num_edges = tables.num_edges(); diff --git a/src/simplify_tables.rs b/src/simplify_tables.rs index fb8b2497..1120905f 100644 --- a/src/simplify_tables.rs +++ b/src/simplify_tables.rs @@ -1,7 +1,7 @@ use crate::simplification_common::*; use crate::tables::*; use crate::ForrusttsError; -use crate::IdType; +use crate::SamplesInfo; use crate::SimplificationBuffers; use crate::SimplificationFlags; use crate::SimplificationOutput; @@ -26,7 +26,7 @@ use crate::SimplificationOutput; /// 
[``simplify_tables``] to keep memory allocations /// persistent between simplifications. pub fn simplify_tables_without_state( - samples: &[IdType], + samples: &SamplesInfo, flags: SimplificationFlags, tables: &mut TableCollection, output: &mut SimplificationOutput, @@ -56,7 +56,7 @@ pub fn simplify_tables_without_state( /// The input tables must be sorted. /// See [``TableCollection::sort_tables_for_simplification``]. pub fn simplify_tables( - samples: &[IdType], + samples: &SamplesInfo, flags: SimplificationFlags, state: &mut SimplificationBuffers, tables: &mut TableCollection, diff --git a/src/test_simplify_tables.rs b/src/test_simplify_tables.rs index ef954199..ce92e044 100644 --- a/src/test_simplify_tables.rs +++ b/src/test_simplify_tables.rs @@ -8,6 +8,7 @@ mod test { use crate::tsdef::{IdType, Position, Time}; use crate::wright_fisher::*; use crate::ForrusttsError; + use crate::SamplesInfo; use crate::SimplificationFlags; use crate::SimplificationOutput; use crate::TableCollection; @@ -89,10 +90,10 @@ mod test { // Now, sort and simplify the tables we got from the sim: tables.sort_tables_for_simplification(); - let mut samples: Vec<IdType> = vec![]; + let mut samples = SamplesInfo::new(); for (i, n) in tables.nodes().iter().enumerate() { if n.time == num_generations { - samples.push(i as IdType); + samples.samples.push(i as IdType); } } @@ -122,8 +123,8 @@ mod test { assert!(rv == 0); let rv = tskr::tsk_table_collection_simplify( tsk_tables.as_mut_ptr(), - samples.as_ptr(), - samples.len() as u32, + samples.samples.as_ptr(), + samples.samples.len() as u32, 0, std::ptr::null_mut(), ); diff --git a/src/wright_fisher.rs b/src/wright_fisher.rs index 3da443dd..2b4bab19 100644 --- a/src/wright_fisher.rs +++ b/src/wright_fisher.rs @@ -11,6 +11,7 @@ use crate::tables::{validate_edge_table, TableCollection}; use crate::tsdef::*; use crate::EdgeBuffer; use crate::ForrusttsError; +use crate::SamplesInfo; use crate::Segment; use crate::SimplificationBuffers; use
crate::SimplificationFlags; @@ -65,7 +66,6 @@ type VecBirth = Vec; struct PopulationState { pub parents: VecParent, pub births: VecBirth, - pub alive_at_last_simplification: Vec<IdType>, pub edge_buffer: EdgeBuffer, pub tables: TableCollection, } @@ -75,7 +75,6 @@ impl PopulationState { PopulationState { parents: vec![], births: vec![], - alive_at_last_simplification: vec![], edge_buffer: EdgeBuffer::new(), tables: TableCollection::new(genome_length).unwrap(), } @@ -291,17 +290,17 @@ fn recombination_breakpoints( } } -fn fill_samples(parents: &[Parent], samples: &mut Vec<IdType>) { - samples.clear(); +fn fill_samples(parents: &[Parent], samples: &mut SamplesInfo) { + samples.samples.clear(); for p in parents { - samples.push(p.node0); - samples.push(p.node1); + samples.samples.push(p.node0); + samples.samples.push(p.node1); } } fn sort_and_simplify( flags: SimulationFlags, - samples: &[IdType], + samples: &SamplesInfo, state: &mut SimplificationBuffers, pop: &mut PopulationState, output: &mut SimplificationOutput, @@ -341,7 +340,6 @@ fn sort_and_simplify( } else { simplify_from_edge_buffer( samples, - &pop.alive_at_last_simplification, SimplificationFlags::empty(), state, &mut pop.edge_buffer, @@ -354,7 +352,7 @@ fn sort_and_simplify( fn simplify_and_remap_nodes( flags: SimulationFlags, - samples: &mut Vec<IdType>, + samples: &mut SamplesInfo, state: &mut SimplificationBuffers, pop: &mut PopulationState, output: &mut SimplificationOutput, @@ -368,10 +366,10 @@ fn simplify_and_remap_nodes( } if flags.contains(SimulationFlags::BUFFER_EDGES) { - pop.alive_at_last_simplification.clear(); + samples.edge_buffer_founder_nodes.clear(); for p in &pop.parents { - pop.alive_at_last_simplification.push(p.node0); - pop.alive_at_last_simplification.push(p.node1); + samples.edge_buffer_founder_nodes.push(p.node0); + samples.edge_buffer_founder_nodes.push(p.node1); } } } @@ -523,7 +521,7 @@ pub fn neutral_wf( rng.set(params.seed); let mut pop = PopulationState::new(pop_params.genome_length); - let mut
samples: Vec<IdType> = vec![]; + let mut samples: SamplesInfo = Default::default(); let mut breakpoints = vec![]; // Record nodes for the first generation @@ -535,7 +533,7 @@ pub fn neutral_wf( } for i in 0..pop.tables.num_nodes() { - pop.alive_at_last_simplification.push(i as IdType); + samples.edge_buffer_founder_nodes.push(i as IdType); } let mut simplified = false;