From e441f9202a49f41ad37613461df8c970ee06c131 Mon Sep 17 00:00:00 2001 From: molpopgen Date: Mon, 29 Mar 2021 10:50:40 -0700 Subject: [PATCH] Move all simplification code into one source file --- src/edge_buffer.rs | 52 -- src/lib.rs | 24 +- src/simplification.rs | 1048 ++++++++++++++++++++++++++++++ src/simplification_buffers.rs | 40 -- src/simplification_common.rs | 83 --- src/simplification_flags.rs | 33 - src/simplification_logic.rs | 372 ----------- src/simplification_output.rs | 33 - src/simplify_from_edge_buffer.rs | 296 --------- src/simplify_tables.rs | 133 ---- src/wright_fisher.rs | 5 +- 11 files changed, 1059 insertions(+), 1060 deletions(-) delete mode 100644 src/edge_buffer.rs create mode 100644 src/simplification.rs delete mode 100644 src/simplification_buffers.rs delete mode 100644 src/simplification_common.rs delete mode 100644 src/simplification_flags.rs delete mode 100644 src/simplification_logic.rs delete mode 100644 src/simplification_output.rs delete mode 100644 src/simplify_from_edge_buffer.rs delete mode 100644 src/simplify_tables.rs diff --git a/src/edge_buffer.rs b/src/edge_buffer.rs deleted file mode 100644 index a8090fc4..00000000 --- a/src/edge_buffer.rs +++ /dev/null @@ -1,52 +0,0 @@ -use crate::nested_forward_list::NestedForwardList; -use crate::segment::Segment; - -/// Data type used for edge buffering. -/// Simplification of simulated data happens -/// via [``crate::simplify_from_edge_buffer()``]. -/// -/// # Overview -/// -/// The typical tree sequence recording workflow -/// goes like this. When a new node is "born", -/// we: -/// -/// 1. Add a new node to the node table. This -/// new node is a "child". -/// 2. Add edges to the edge table representing -/// the genomic intervals passed on from -/// various parents to this child node. -/// -/// We repeat `1` and `2` for a while, then -/// we [``sort the tables``](crate::TableCollection::sort_tables_for_simplification). 
-/// After sorting, we [``simplify``](crate::simplify_tables()) the tables. -/// -/// We can avoid the sorting step using this type. -/// To start, we record the list of currently-alive nodes -/// [``here``](crate::SamplesInfo::edge_buffer_founder_nodes). -/// -/// Then, we use `parent` ID values as the -/// `head` values for linked lists stored in a -/// [``NestedForwardList``](crate::nested_forward_list::NestedForwardList). -/// -/// By its very nature, offspring are generated by birth order. -/// Further, a well-behaved forward simulation is capable of calculating -/// edges from left-to-right along a genome. Thus, we can -/// [``extend``](crate::nested_forward_list::NestedForwardList::extend) -/// the data for each chain with [``Segment``] instances representing -/// transmission events. The segment's [``node``](Segment::node) field represents -/// the child. -/// -/// After recording for a while, we call -/// [``simplify_from_edge_buffer``](crate::simplify_from_edge_buffer()) to simplify -/// the tables. After simplification, the client code must re-populate -/// [``the list``](crate::SamplesInfo::edge_buffer_founder_nodes) of alive nodes. -/// Once that is done, we can keep recording, etc.. 
-/// -/// # Example -/// -/// For a full example of use in simulation, -/// see the source code for -/// [``wright_fisher::neutral_wf``](crate::wright_fisher::neutral_wf) -/// -pub type EdgeBuffer = NestedForwardList; diff --git a/src/lib.rs b/src/lib.rs index 942c9cee..93ca9bb1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,30 +41,22 @@ // stuff that needs documenting: // #![warn(missing_docs)] -mod edge_buffer; mod error; pub mod nested_forward_list; -mod samples_info; mod segment; -mod simplification_buffers; -mod simplification_common; -mod simplification_flags; -mod simplification_logic; -mod simplification_output; -mod simplify_from_edge_buffer; -mod simplify_tables; +mod simplification; mod tables; mod tsdef; -pub use edge_buffer::EdgeBuffer; pub use error::ForrusttsError; -pub use samples_info::SamplesInfo; pub use segment::Segment; -pub use simplification_buffers::SimplificationBuffers; -pub use simplification_flags::SimplificationFlags; -pub use simplification_output::SimplificationOutput; -pub use simplify_from_edge_buffer::simplify_from_edge_buffer; -pub use simplify_tables::{simplify_tables, simplify_tables_without_state}; +pub use simplification::simplify_from_edge_buffer; +pub use simplification::EdgeBuffer; +pub use simplification::SamplesInfo; +pub use simplification::SimplificationBuffers; +pub use simplification::SimplificationFlags; +pub use simplification::SimplificationOutput; +pub use simplification::{simplify_tables, simplify_tables_without_state}; pub use tables::*; pub use tsdef::*; diff --git a/src/simplification.rs b/src/simplification.rs new file mode 100644 index 00000000..63ff0422 --- /dev/null +++ b/src/simplification.rs @@ -0,0 +1,1048 @@ +use crate::nested_forward_list::NestedForwardList; +use crate::tables::*; +use crate::tsdef::{IdType, Position, Time, NULL_ID}; +use crate::ForrusttsError; +use crate::Segment; +use bitflags::bitflags; + +struct SegmentOverlapper { + segment_queue: Vec, + overlapping: Vec, + left: Position, + 
right: Position, + qbeg: usize, + qend: usize, + obeg: usize, + oend: usize, +} + +impl SegmentOverlapper { + fn set_partition(&mut self) -> Position { + let mut tright = Position::MAX; + let mut b: usize = 0; + + for i in 0..self.oend { + if self.overlapping[i].right > self.left { + self.overlapping[b] = self.overlapping[i]; + tright = std::cmp::min(tright, self.overlapping[b].right); + b += 1; + } + } + + self.oend = b; + + tright + } + + fn num_overlaps(&self) -> usize { + assert!( + self.oend - self.obeg <= self.overlapping.len(), + "overlap details = {} {} {}", + self.oend, + self.obeg, + self.overlapping.len() + ); + self.oend - self.obeg + } + + // Public interface below + + const fn new() -> SegmentOverlapper { + SegmentOverlapper { + segment_queue: vec![], + overlapping: vec![], + left: 0, + right: Position::MAX, + qbeg: std::usize::MAX, + qend: std::usize::MAX, + obeg: std::usize::MAX, + oend: std::usize::MAX, + } + } + + fn init(&mut self) { + self.qbeg = 0; + self.qend = self.segment_queue.len() - 1; + assert!(self.qend < self.segment_queue.len()); + self.obeg = 0; + self.oend = 0; + self.overlapping.clear(); + } + + fn enqueue(&mut self, left: Position, right: Position, node: IdType) { + self.segment_queue.push(Segment { left, right, node }); + } + + fn finalize_queue(&mut self, maxlen: Position) { + self.segment_queue.sort_by(|a, b| a.left.cmp(&b.left)); + self.segment_queue.push(Segment { + left: maxlen, + right: maxlen + 1, + node: NULL_ID, + }); + } + + fn advance(&mut self) -> bool { + let mut rv = false; + + if self.qbeg < self.qend { + self.left = self.right; + let mut tright = self.set_partition(); + if self.num_overlaps() == 0 { + self.left = self.segment_queue[self.qbeg].left; + } + while self.qbeg < self.qend && self.segment_queue[self.qbeg].left == self.left { + tright = std::cmp::min(tright, self.segment_queue[self.qbeg].right); + // NOTE: I wonder how efficient this is vs C++? 
+ self.overlapping + .insert(self.oend, self.segment_queue[self.qbeg]); + self.oend += 1; + self.qbeg += 1; + } + self.right = std::cmp::min(self.segment_queue[self.qbeg].left, tright); + rv = true; + } else { + self.left = self.right; + self.right = Position::MAX; + let tright = self.set_partition(); + if self.num_overlaps() > 0 { + self.right = tright; + rv = true + } + } + + rv + } + + fn get_left(&self) -> Position { + self.left + } + + fn get_right(&self) -> Position { + self.right + } + + fn clear_queue(&mut self) { + self.segment_queue.clear(); + } + + fn overlap(&self, i: usize) -> &Segment { + &self.overlapping[i] + } +} + +type AncestryList = NestedForwardList; + +fn find_parent_child_segment_overlap( + edges: &[Edge], + edge_index: usize, + num_edges: usize, + maxlen: Position, + u: IdType, + ancestry: &mut AncestryList, + overlapper: &mut SegmentOverlapper, +) -> Result { + overlapper.clear_queue(); + + let mut i = edge_index; + + while i < num_edges && edges[i].parent == u { + let edge = &edges[i]; + + ancestry.for_each(edges[i].child, |seg: &Segment| { + if seg.right > edge.left && edge.right > seg.left { + overlapper.enqueue( + std::cmp::max(seg.left, edge.left), + std::cmp::min(seg.right, edge.right), + seg.node, + ); + } + true + })?; + + i += 1; + } + overlapper.finalize_queue(maxlen); + Ok(i) +} + +fn add_ancestry( + input_id: IdType, + left: Position, + right: Position, + node: IdType, + ancestry: &mut AncestryList, +) -> Result<(), ForrusttsError> { + let head = ancestry.head(input_id)?; + if head == AncestryList::null() { + let seg = Segment { left, right, node }; + ancestry.extend(input_id, seg)?; + } else { + let last_idx = ancestry.tail(input_id)?; + if last_idx == AncestryList::null() { + return Err(ForrusttsError::SimplificationError { + value: "last_idx is NULL_ID".to_string(), + }); + } + let last = ancestry.fetch_mut(last_idx)?; + if last.right == left && last.node == node { + last.right = right; + } else { + let seg = Segment { left, 
right, node }; + ancestry.extend(input_id, seg)?; + } + } + Ok(()) +} + +fn buffer_edge( + left: Position, + right: Position, + parent: IdType, + child: IdType, + temp_edge_buffer: &mut EdgeTable, +) { + let i = temp_edge_buffer + .iter() + .rposition(|e: &Edge| e.child == child); + + match i { + None => temp_edge_buffer.push(Edge { + left, + right, + parent, + child, + }), + Some(x) => { + if temp_edge_buffer[x].right == left { + temp_edge_buffer[x].right = right; + } else { + temp_edge_buffer.push(Edge { + left, + right, + parent, + child, + }); + } + } + } +} + +fn output_buffered_edges(temp_edge_buffer: &mut EdgeTable, new_edges: &mut EdgeTable) -> usize { + temp_edge_buffer.sort_by(|a, b| a.child.cmp(&b.child)); + + // Need to store size here b/c + // append drains contents of input!!! + let rv = temp_edge_buffer.len(); + new_edges.append(temp_edge_buffer); + + rv +} + +fn merge_ancestors( + input_nodes: &[Node], + maxlen: Position, + parent_input_id: IdType, + state: &mut SimplificationBuffers, + idmap: &mut [IdType], +) -> Result<(), ForrusttsError> { + let mut output_id = idmap[parent_input_id as usize]; + let is_sample = output_id != NULL_ID; + + if is_sample { + state.ancestry.nullify_list(parent_input_id)?; + } + + let mut previous_right: Position = 0; + let mut ancestry_node: IdType; + state.overlapper.init(); + state.temp_edge_buffer.clear(); + + while state.overlapper.advance() { + if state.overlapper.num_overlaps() == 1 { + ancestry_node = state.overlapper.overlap(0).node; + if is_sample { + buffer_edge( + state.overlapper.get_left(), + state.overlapper.get_right(), + output_id, + ancestry_node, + &mut state.temp_edge_buffer, + ); + ancestry_node = output_id; + } + } else { + if output_id == NULL_ID { + state.new_nodes.push(Node { + time: input_nodes[parent_input_id as usize].time, + deme: input_nodes[parent_input_id as usize].deme, + }); + output_id = (state.new_nodes.len() - 1) as IdType; + idmap[parent_input_id as usize] = output_id; + } + 
ancestry_node = output_id; + for i in 0..state.overlapper.num_overlaps() as usize { + let o = &state.overlapper.overlap(i); + buffer_edge( + state.overlapper.get_left(), + state.overlapper.get_right(), + output_id, + o.node, + &mut state.temp_edge_buffer, + ); + } + } + if is_sample && state.overlapper.get_left() != previous_right { + add_ancestry( + parent_input_id, + previous_right, + state.overlapper.get_left(), + output_id, + &mut state.ancestry, + )?; + } + add_ancestry( + parent_input_id, + state.overlapper.get_left(), + state.overlapper.get_right(), + ancestry_node, + &mut state.ancestry, + )?; + previous_right = state.overlapper.get_right(); + } + if is_sample && previous_right != maxlen { + add_ancestry( + parent_input_id, + previous_right, + maxlen, + output_id, + &mut state.ancestry, + )?; + } + + if output_id != NULL_ID { + let n = output_buffered_edges(&mut state.temp_edge_buffer, &mut state.new_edges); + + if n == 0 && !is_sample { + assert!(output_id < state.new_nodes.len() as IdType); + state.new_nodes.truncate(output_id as usize); + idmap[parent_input_id as usize] = NULL_ID; + } + } + Ok(()) +} + +fn record_sample_nodes( + samples: &[IdType], + tables: &TableCollection, + new_nodes: &mut NodeTable, + ancestry: &mut AncestryList, + idmap: &mut [IdType], +) -> Result<(), ForrusttsError> { + for sample in samples.iter() { + assert!(*sample >= 0); + // NOTE: the following can be debug_assert? 
+ if *sample == NULL_ID { + return Err(ForrusttsError::SimplificationError { + value: "sample node is NULL_ID".to_string(), + }); + } + if idmap[*sample as usize] != NULL_ID { + return Err(ForrusttsError::SimplificationError { + value: "invalid sample list!".to_string(), + }); + } + let n = tables.node(*sample); + new_nodes.push(Node { + time: n.time, + deme: n.deme, + }); + + add_ancestry( + *sample, + 0, + tables.genome_length(), + (new_nodes.len() - 1) as IdType, + ancestry, + )?; + + idmap[*sample as usize] = (new_nodes.len() - 1) as IdType; + } + Ok(()) +} + +fn validate_tables( + tables: &TableCollection, + flags: &SimplificationFlags, +) -> Result<(), ForrusttsError> { + if flags.contains(SimplificationFlags::VALIDATE_EDGES) { + validate_edge_table(tables.genome_length(), tables.edges(), tables.nodes())?; + } + Ok(()) +} + +fn setup_idmap(nodes: &[Node], idmap: &mut Vec) { + idmap.resize(nodes.len(), NULL_ID); + idmap.iter_mut().for_each(|x| *x = NULL_ID); +} + +fn setup_simplification( + samples: &SamplesInfo, + tables: &TableCollection, + flags: SimplificationFlags, + state: &mut SimplificationBuffers, + output: &mut SimplificationOutput, +) -> Result<(), ForrusttsError> { + if !tables.sites_.is_empty() || !tables.mutations_.is_empty() { + return Err(ForrusttsError::SimplificationError { + value: "mutation simplification not yet implemented".to_string(), + }); + } + + validate_tables(tables, &flags)?; + setup_idmap(&tables.nodes_, &mut output.idmap); + + state.clear(); + state.ancestry.reset(tables.num_nodes()); + + record_sample_nodes( + &samples.samples, + &tables, + &mut state.new_nodes, + &mut state.ancestry, + &mut output.idmap, + )?; + + Ok(()) +} + +fn process_parent( + u: IdType, + (edge_index, num_edges): (usize, usize), + tables: &TableCollection, + state: &mut SimplificationBuffers, + output: &mut SimplificationOutput, +) -> Result { + let edge_i = find_parent_child_segment_overlap( + &tables.edges_, + edge_index, + num_edges, + 
tables.genome_length(), + u, + &mut state.ancestry, + &mut state.overlapper, + )?; + + merge_ancestors( + &tables.nodes_, + tables.genome_length(), + u, + state, + &mut output.idmap, + )?; + Ok(edge_i) +} + +struct ParentLocation { + parent: IdType, + start: usize, + stop: usize, +} + +// TODO: validate input and return errors. +impl ParentLocation { + fn new(parent: IdType, start: usize, stop: usize) -> Self { + ParentLocation { + parent, + start, + stop, + } + } +} + +fn find_pre_existing_edges( + tables: &TableCollection, + edge_buffer_founder_nodes: &[IdType], + edge_buffer: &EdgeBuffer, +) -> Result, ForrusttsError> { + let mut alive_with_new_edges: Vec = vec![]; + + for a in edge_buffer_founder_nodes { + if edge_buffer.head(*a)? != EdgeBuffer::null() { + alive_with_new_edges.push(*a); + } + } + if alive_with_new_edges.is_empty() { + return Ok(vec![]); + } + + let mut starts = vec![usize::MAX; tables.num_nodes()]; + let mut stops = vec![usize::MAX; tables.num_nodes()]; + + for (i, e) in tables.enumerate_edges() { + if starts[e.parent as usize] == usize::MAX { + starts[e.parent as usize] = i; + stops[e.parent as usize] = i + 1; + } else { + stops[e.parent as usize] = i + 1; + } + } + + let mut rv = vec![]; + for a in alive_with_new_edges { + rv.push(ParentLocation::new( + a, + starts[a as usize], + stops[a as usize], + )); + } + + rv.sort_by(|a, b| { + let ta = tables.nodes_[a.parent as usize].time; + let tb = tables.nodes_[b.parent as usize].time; + if ta == tb { + if a.start == b.start { + return a.parent.cmp(&b.parent); + } + return a.start.cmp(&b.start); + } + ta.cmp(&tb).reverse() + }); + + // TODO: this could eventually be called in a debug_assert + if !rv.is_empty() { + for i in 1..rv.len() { + let t0 = tables.nodes_[rv[i - 1].parent as usize].time; + let t1 = tables.nodes_[rv[i].parent as usize].time; + if t0 < t1 { + return Err(ForrusttsError::SimplificationError { + value: "existing edges not properly sorted by time".to_string(), + }); + } + } + } + 
Ok(rv) +} + +fn queue_children( + child: IdType, + left: Position, + right: Position, + ancestry: &mut AncestryList, + overlapper: &mut SegmentOverlapper, +) -> Result<(), ForrusttsError> { + Ok(ancestry.for_each(child, |seg: &Segment| { + if seg.right > left && right > seg.left { + overlapper.enqueue( + std::cmp::max(seg.left, left), + std::cmp::min(seg.right, right), + seg.node, + ); + } + true + })?) +} + +fn process_births_from_buffer( + head: IdType, + edge_buffer: &EdgeBuffer, + state: &mut SimplificationBuffers, +) -> Result<(), ForrusttsError> { + // Have to take references here to + // make the borrow checker happy. + let a = &mut state.ancestry; + let o = &mut state.overlapper; + Ok(edge_buffer.for_each(head, |seg: &Segment| { + queue_children(seg.node, seg.left, seg.right, a, o).unwrap(); + true + })?) +} + +bitflags! { + /// Boolean flags affecting simplification + /// behavior. + /// + /// # Example + /// + /// ``` + /// let e = forrustts::SimplificationFlags::empty(); + /// assert_eq!(e.bits(), 0); + /// ``` + #[derive(Default)] + pub struct SimplificationFlags: u32 { + /// Validate that input edges are sorted + const VALIDATE_EDGES = 1 << 0; + /// Validate that input mutations are sorted + const VALIDATE_MUTATIONS = 1 << 1; + /// Validate all tables. + const VALIDATE_ALL = Self::VALIDATE_EDGES.bits | Self::VALIDATE_MUTATIONS.bits; + } +} + +/// Information about samples used for +/// table simpilfication. +#[derive(Default)] +pub struct SamplesInfo { + /// A list of sample IDs. + /// Can include both "alive" and + /// "ancient/remembered/preserved" sample + /// nodes. + pub samples: Vec, + /// When using [``EdgeBuffer``](type.EdgeBuffer.html) + /// to record transmission + /// events, this list must contain a list of all node IDs + /// alive the last time simplification happened. Here, + /// "alive" means "could leave more descendants". + /// At the *start* of a simulation, this should be filled + /// with a list of "founder" node IDs. 
+ pub edge_buffer_founder_nodes: Vec, +} + +impl SamplesInfo { + /// Generate a new instance. + pub fn new() -> Self { + SamplesInfo { + samples: vec![], + edge_buffer_founder_nodes: vec![], + } + } +} + +/// Useful information output by table +/// simplification. +pub struct SimplificationOutput { + /// Maps input node ID to output ID. + /// Values are set to [``NULL_ID``](crate::NULL_ID) + /// for input nodes that "simplify out". + pub idmap: Vec, +} + +impl SimplificationOutput { + /// Create a new instance. + pub fn new() -> Self { + SimplificationOutput { idmap: vec![] } + } +} + +impl Default for SimplificationOutput { + fn default() -> Self { + SimplificationOutput::new() + } +} + +/// Holds internal memory used by +/// simplification machinery. +/// +/// During simplification, several large +/// memory blocks are required. This type +/// allows those allocations to be re-used +/// in subsequent calls to +/// [simplify_tables_with_state](fn.simplify_tables_with_state.html). +/// Doing so typically improves run times at +/// the cost of higher peak memory consumption. +pub struct SimplificationBuffers { + new_edges: EdgeTable, + temp_edge_buffer: EdgeTable, + new_nodes: NodeTable, + overlapper: SegmentOverlapper, + ancestry: AncestryList, +} + +impl SimplificationBuffers { + /// Create a new instance. + pub const fn new() -> SimplificationBuffers { + SimplificationBuffers { + new_edges: EdgeTable::new(), + temp_edge_buffer: EdgeTable::new(), + new_nodes: NodeTable::new(), + overlapper: SegmentOverlapper::new(), + ancestry: AncestryList::new(), + } + } + + // NOTE: should this be fully pub? + fn clear(&mut self) { + self.new_edges.clear(); + self.temp_edge_buffer.clear(); + self.new_nodes.clear(); + } +} + +/// Simplify a [``TableCollection``]. +/// +/// # Parameters +/// +/// * `samples`: +/// * `flags`: modify the behavior of the simplification algorithm. +/// * `tables`: a [``TableCollection``] to simplify. 
+/// * `output`: Where simplification output gets written. +/// See [``SimplificationOutput``]. +/// +/// # Notes +/// +/// The input tables must be sorted. +/// See [``TableCollection::sort_tables_for_simplification``]. +/// +/// It is common to simplify many times during a simulation. +/// To avoid making big allocations each time, see +/// [``simplify_tables``] to keep memory allocations +/// persistent between simplifications. +pub fn simplify_tables_without_state( + samples: &SamplesInfo, + flags: SimplificationFlags, + tables: &mut TableCollection, + output: &mut SimplificationOutput, +) -> Result<(), ForrusttsError> { + let mut state = SimplificationBuffers::new(); + simplify_tables(samples, flags, &mut state, tables, output) +} + +/// Simplify a [``TableCollection``]. +/// +/// This differs from [``simplify_tables_without_state``] in that the big memory +/// allocations made during simplification are preserved in +/// an instance of [``SimplificationBuffers``]. +/// +/// # Parameters +/// +/// * `samples`: +/// * `flags`: modify the behavior of the simplification algorithm. +/// * `state`: These are the internal data structures used +/// by the simplification algorithm. +/// * `tables`: a [``TableCollection``] to simplify. +/// * `output`: Where simplification output gets written. +/// See [``SimplificationOutput``]. +/// +/// # Notes +/// +/// The input tables must be sorted. +/// See [``TableCollection::sort_tables_for_simplification``]. 
+pub fn simplify_tables( + samples: &SamplesInfo, + flags: SimplificationFlags, + state: &mut SimplificationBuffers, + tables: &mut TableCollection, + output: &mut SimplificationOutput, +) -> Result<(), ForrusttsError> { + setup_simplification(samples, tables, flags, state, output)?; + + let mut edge_i = 0; + let num_edges = tables.num_edges(); + let mut new_edges_inserted: usize = 0; + while edge_i < num_edges { + edge_i = process_parent( + tables.edges_[edge_i].parent, + (edge_i, num_edges), + &tables, + state, + output, + )?; + + if state.new_edges.len() >= 1024 && new_edges_inserted + state.new_edges.len() < edge_i { + for i in state.new_edges.drain(..) { + tables.edges_[new_edges_inserted] = i; + new_edges_inserted += 1; + } + assert_eq!(state.new_edges.len(), 0); + } + } + + tables.edges_.truncate(new_edges_inserted); + tables.edges_.append(&mut state.new_edges); + std::mem::swap(&mut tables.nodes_, &mut state.new_nodes); + + Ok(()) +} + +/// Data type used for edge buffering. +/// Simplification of simulated data happens +/// via [``crate::simplify_from_edge_buffer()``]. +/// +/// # Overview +/// +/// The typical tree sequence recording workflow +/// goes like this. When a new node is "born", +/// we: +/// +/// 1. Add a new node to the node table. This +/// new node is a "child". +/// 2. Add edges to the edge table representing +/// the genomic intervals passed on from +/// various parents to this child node. +/// +/// We repeat `1` and `2` for a while, then +/// we [``sort the tables``](crate::TableCollection::sort_tables_for_simplification). +/// After sorting, we [``simplify``](crate::simplify_tables()) the tables. +/// +/// We can avoid the sorting step using this type. +/// To start, we record the list of currently-alive nodes +/// [``here``](crate::SamplesInfo::edge_buffer_founder_nodes). 
+/// +/// Then, we use `parent` ID values as the +/// `head` values for linked lists stored in a +/// [``NestedForwardList``](crate::nested_forward_list::NestedForwardList). +/// +/// By its very nature, offspring are generated by birth order. +/// Further, a well-behaved forward simulation is capable of calculating +/// edges from left-to-right along a genome. Thus, we can +/// [``extend``](crate::nested_forward_list::NestedForwardList::extend) +/// the data for each chain with [``Segment``] instances representing +/// transmission events. The segment's [``node``](Segment::node) field represents +/// the child. +/// +/// After recording for a while, we call +/// [``simplify_from_edge_buffer``](crate::simplify_from_edge_buffer()) to simplify +/// the tables. After simplification, the client code must re-populate +/// [``the list``](crate::SamplesInfo::edge_buffer_founder_nodes) of alive nodes. +/// Once that is done, we can keep recording, etc.. +/// +/// # Example +/// +/// For a full example of use in simulation, +/// see the source code for +/// [``wright_fisher::neutral_wf``](crate::wright_fisher::neutral_wf) +/// +pub type EdgeBuffer = NestedForwardList; + +/// Simplify a [``TableCollection``] from an [``EdgeBuffer``]. +/// +/// See [``EdgeBuffer``] for discussion. +/// +/// # Parameters +/// +/// * `samples`: Instance of [``SamplesInfo``]. The field +/// [``SamplesInfo::edge_buffer_founder_nodes``] +/// must be populated. See [``EdgeBuffer``] for details. +/// * `flags`: modify the behavior of the simplification algorithm. +/// * `state`: These are the internal data structures used +/// by the simplification algorithm. +/// * `edge_buffer`: An [``EdgeBuffer``] recording births since the last +/// simplification. +/// * `tables`: a [``TableCollection``] to simplify. +/// * `output`: Where simplification output gets written. +/// See [``SimplificationOutput``]. +/// +/// # Notes +/// +/// The input tables must be sorted. 
+/// See [``TableCollection::sort_tables_for_simplification``]. +/// +/// # Limitations +/// +/// The simplification code does not currently validate +/// that "buffered" edges do indeed represent a valid sort order. +pub fn simplify_from_edge_buffer( + samples: &SamplesInfo, + flags: SimplificationFlags, + state: &mut SimplificationBuffers, + edge_buffer: &mut EdgeBuffer, + tables: &mut TableCollection, + output: &mut SimplificationOutput, +) -> Result<(), ForrusttsError> { + setup_simplification(samples, tables, flags, state, output)?; + + // Process all edges since the last simplification. + let mut max_time = Time::MIN; + for n in samples.edge_buffer_founder_nodes.iter() { + max_time = std::cmp::max(max_time, tables.node(*n).time); + } + for (i, _) in edge_buffer.head_itr().rev().enumerate() { + let head = (edge_buffer.len() - i - 1) as i32; + let ptime = tables.node(head).time; + if ptime > max_time + // Then this is a parent who is: + // 1. Born since the last simplification. + // 2. 
Left offspring + { + state.overlapper.clear_queue(); + process_births_from_buffer(head, edge_buffer, state)?; + state.overlapper.finalize_queue(tables.genome_length()); + merge_ancestors( + &tables.nodes(), + tables.genome_length(), + head, + state, + &mut output.idmap, + )?; + } else if ptime <= max_time { + break; + } + } + + let existing_edges = + find_pre_existing_edges(&tables, &samples.edge_buffer_founder_nodes, &edge_buffer)?; + + let mut edge_i = 0; + let num_edges = tables.num_edges(); + + for ex in existing_edges { + while edge_i < num_edges + && tables.nodes_[tables.edges_[edge_i].parent as usize].time + > tables.nodes_[ex.parent as usize].time + { + edge_i = process_parent( + tables.edges_[edge_i].parent, + (edge_i, num_edges), + &tables, + state, + output, + )?; + } + if ex.start != usize::MAX { + while (edge_i as usize) < ex.start + && tables.nodes_[tables.edges_[edge_i].parent as usize].time + >= tables.nodes_[ex.parent as usize].time + { + edge_i = process_parent( + tables.edges_[edge_i].parent, + (edge_i, num_edges), + &tables, + state, + output, + )?; + } + } + // now, handle ex.parent + state.overlapper.clear_queue(); + if ex.start != usize::MAX { + while edge_i < ex.stop { + // TODO: a debug assert or regular assert? 
+ if tables.edges_[edge_i].parent != ex.parent { + return Err(ForrusttsError::SimplificationError { + value: "Unexpected parent node".to_string(), + }); + } + let a = &mut state.ancestry; + let o = &mut state.overlapper; + queue_children( + tables.edges_[edge_i].child, + tables.edges_[edge_i].left, + tables.edges_[edge_i].right, + a, + o, + )?; + edge_i += 1; + } + if edge_i < num_edges && tables.edges_[edge_i].parent == ex.parent { + return Err(ForrusttsError::SimplificationError { + value: "error traversing pre-existing edges for parent".to_string(), + }); + } + } + process_births_from_buffer(ex.parent, edge_buffer, state)?; + state.overlapper.finalize_queue(tables.genome_length()); + merge_ancestors( + &tables.nodes_, + tables.genome_length(), + ex.parent, + state, + &mut output.idmap, + )?; + } + + // Handle remaining edges. + while edge_i < num_edges { + edge_i = process_parent( + tables.edges_[edge_i].parent, + (edge_i, num_edges), + &tables, + state, + output, + )?; + } + + std::mem::swap(&mut tables.edges_, &mut state.new_edges); + std::mem::swap(&mut tables.nodes_, &mut state.new_nodes); + edge_buffer.reset(tables.num_nodes()); + + Ok(()) +} + +#[cfg(test)] +mod test_samples_info { + use super::SamplesInfo; + + #[test] + fn test_default() { + let s: SamplesInfo = Default::default(); + assert!(s.samples.is_empty()); + assert!(s.edge_buffer_founder_nodes.is_empty()); + } +} + +#[cfg(test)] +mod test_simplification_output { + use super::SimplificationOutput; + + #[test] + fn test_defaul() { + let x: SimplificationOutput = Default::default(); + assert_eq!(x.idmap.is_empty(), true); + } +} + +#[cfg(test)] +mod test_simplification_flags { + use super::SimplificationFlags; + + #[test] + fn test_empty_simplification_flags() { + let e = SimplificationFlags::empty(); + assert_eq!(e.bits(), 0); + } +} + +#[cfg(test)] +mod test_simpify_tables { + use super::*; + + // TODO: we need lots more tests of these validations! 
+ + #[test] + fn test_simplify_tables_unsorted_edges() { + let mut tables = TableCollection::new(1000).unwrap(); + + tables.add_node(0, 0).unwrap(); // parent + tables.add_node(1, 0).unwrap(); // child + tables.add_edge(100, tables.genome_length(), 0, 1).unwrap(); + tables.add_edge(0, 100, 0, 1).unwrap(); + + let mut output = SimplificationOutput::new(); + + let mut samples = SamplesInfo::new(); + samples.samples.push(1); + + let _ = simplify_tables_without_state( + &samples, + SimplificationFlags::VALIDATE_ALL, + &mut tables, + &mut output, + ) + .map_or_else( + |x: ForrusttsError| { + assert_eq!( + x, + ForrusttsError::TablesError { + value: TablesError::EdgesNotSortedByLeft + } + ) + }, + |_| panic!(), + ); + } +} + +#[cfg(test)] +mod test_simpify_table_from_edge_buffer { + use super::{process_births_from_buffer, EdgeBuffer, ForrusttsError, SimplificationBuffers}; + + // This shows that the closure error gets propagated + // as the result type. + #[test] + fn test_process_births_from_buffer_closure_error() { + let b = EdgeBuffer::new(); + let mut s = SimplificationBuffers::new(); + assert!(process_births_from_buffer(-1, &b, &mut s) + .map_or_else(|_: ForrusttsError| true, |_| false)); + } +} diff --git a/src/simplification_buffers.rs b/src/simplification_buffers.rs deleted file mode 100644 index 13067f74..00000000 --- a/src/simplification_buffers.rs +++ /dev/null @@ -1,40 +0,0 @@ -use crate::simplification_logic::{AncestryList, SegmentOverlapper}; -use crate::tables::{EdgeTable, NodeTable}; - -/// Holds internal memory used by -/// simplification machinery. -/// -/// During simplification, several large -/// memory blocks are required. This type -/// allows those allocations to be re-used -/// in subsequent calls to -/// [simplify_tables_with_state](fn.simplify_tables_with_state.html). -/// Doing so typically improves run times at -/// the cost of higher peak memory consumption. 
-pub struct SimplificationBuffers { - pub(crate) new_edges: EdgeTable, - pub(crate) temp_edge_buffer: EdgeTable, - pub(crate) new_nodes: NodeTable, - pub(crate) overlapper: SegmentOverlapper, - pub(crate) ancestry: AncestryList, -} - -impl SimplificationBuffers { - /// Create a new instance. - pub const fn new() -> SimplificationBuffers { - SimplificationBuffers { - new_edges: EdgeTable::new(), - temp_edge_buffer: EdgeTable::new(), - new_nodes: NodeTable::new(), - overlapper: SegmentOverlapper::new(), - ancestry: AncestryList::new(), - } - } - - // NOTE: should this be fully pub? - pub(crate) fn clear(&mut self) { - self.new_edges.clear(); - self.temp_edge_buffer.clear(); - self.new_nodes.clear(); - } -} diff --git a/src/simplification_common.rs b/src/simplification_common.rs deleted file mode 100644 index 2f694071..00000000 --- a/src/simplification_common.rs +++ /dev/null @@ -1,83 +0,0 @@ -/// Common functions to reuse in various "simplify tables" -/// functions -use crate::simplification_logic; -use crate::validate_edge_table; -use crate::ForrusttsError; -use crate::SamplesInfo; -use crate::SimplificationBuffers; -use crate::SimplificationFlags; -use crate::SimplificationOutput; -use crate::{IdType, NULL_ID}; -use crate::{Node, TableCollection}; - -pub fn validate_tables( - tables: &TableCollection, - flags: &SimplificationFlags, -) -> Result<(), ForrusttsError> { - if flags.contains(SimplificationFlags::VALIDATE_EDGES) { - validate_edge_table(tables.genome_length(), tables.edges(), tables.nodes())?; - } - Ok(()) -} - -fn setup_idmap(nodes: &[Node], idmap: &mut Vec) { - idmap.resize(nodes.len(), NULL_ID); - idmap.iter_mut().for_each(|x| *x = NULL_ID); -} - -pub fn setup_simplification( - samples: &SamplesInfo, - tables: &TableCollection, - flags: SimplificationFlags, - state: &mut SimplificationBuffers, - output: &mut SimplificationOutput, -) -> Result<(), ForrusttsError> { - if !tables.sites_.is_empty() || !tables.mutations_.is_empty() { - return 
Err(ForrusttsError::SimplificationError { - value: "mutation simplification not yet implemented".to_string(), - }); - } - - validate_tables(tables, &flags)?; - setup_idmap(&tables.nodes_, &mut output.idmap); - - state.clear(); - state.ancestry.reset(tables.num_nodes()); - - simplification_logic::record_sample_nodes( - &samples.samples, - &tables, - &mut state.new_nodes, - &mut state.ancestry, - &mut output.idmap, - )?; - - Ok(()) -} - -pub fn process_parent( - u: IdType, - (edge_index, num_edges): (usize, usize), - tables: &TableCollection, - state: &mut SimplificationBuffers, - output: &mut SimplificationOutput, -) -> Result { - let edge_i = simplification_logic::find_parent_child_segment_overlap( - &tables.edges_, - edge_index, - num_edges, - tables.genome_length(), - u, - &mut state.ancestry, - &mut state.overlapper, - )?; - - simplification_logic::merge_ancestors( - &tables.nodes_, - tables.genome_length(), - u, - state, - &mut output.idmap, - )?; - Ok(edge_i) -} diff --git a/src/simplification_flags.rs b/src/simplification_flags.rs deleted file mode 100644 index c29e1721..00000000 --- a/src/simplification_flags.rs +++ /dev/null @@ -1,33 +0,0 @@ -use bitflags::bitflags; - -bitflags! { - /// Boolean flags affecting simplification - /// behavior. - /// - /// # Example - /// - /// ``` - /// let e = forrustts::SimplificationFlags::empty(); - /// assert_eq!(e.bits(), 0); - /// ``` - #[derive(Default)] - pub struct SimplificationFlags: u32 { - /// Validate that input edges are sorted - const VALIDATE_EDGES = 1 << 0; - /// Validate that input mutations are sorted - const VALIDATE_MUTATIONS = 1 << 1; - /// Validate all tables. 
- const VALIDATE_ALL = Self::VALIDATE_EDGES.bits | Self::VALIDATE_MUTATIONS.bits; - } -} - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_empty() { - let e = SimplificationFlags::empty(); - assert_eq!(e.bits(), 0); - } -} diff --git a/src/simplification_logic.rs b/src/simplification_logic.rs deleted file mode 100644 index 060590c5..00000000 --- a/src/simplification_logic.rs +++ /dev/null @@ -1,372 +0,0 @@ -use crate::nested_forward_list::NestedForwardList; -use crate::segment::Segment; -use crate::simplification_buffers::SimplificationBuffers; -use crate::tables::*; -use crate::tsdef::{IdType, Position, NULL_ID}; -use crate::ForrusttsError; - -pub struct SegmentOverlapper { - segment_queue: Vec, - overlapping: Vec, - left: Position, - right: Position, - qbeg: usize, - qend: usize, - obeg: usize, - oend: usize, -} - -impl SegmentOverlapper { - fn set_partition(&mut self) -> Position { - let mut tright = Position::MAX; - let mut b: usize = 0; - - for i in 0..self.oend { - if self.overlapping[i].right > self.left { - self.overlapping[b] = self.overlapping[i]; - tright = std::cmp::min(tright, self.overlapping[b].right); - b += 1; - } - } - - self.oend = b; - - tright - } - - fn num_overlaps(&self) -> usize { - assert!( - self.oend - self.obeg <= self.overlapping.len(), - "overlap details = {} {} {}", - self.oend, - self.obeg, - self.overlapping.len() - ); - self.oend - self.obeg - } - - // Public interface below - - pub const fn new() -> SegmentOverlapper { - SegmentOverlapper { - segment_queue: vec![], - overlapping: vec![], - left: 0, - right: Position::MAX, - qbeg: std::usize::MAX, - qend: std::usize::MAX, - obeg: std::usize::MAX, - oend: std::usize::MAX, - } - } - - pub fn init(&mut self) { - self.qbeg = 0; - self.qend = self.segment_queue.len() - 1; - assert!(self.qend < self.segment_queue.len()); - self.obeg = 0; - self.oend = 0; - self.overlapping.clear(); - } - - pub fn enqueue(&mut self, left: Position, right: Position, node: IdType) { - 
self.segment_queue.push(Segment { left, right, node }); - } - - pub fn finalize_queue(&mut self, maxlen: Position) { - self.segment_queue.sort_by(|a, b| a.left.cmp(&b.left)); - self.segment_queue.push(Segment { - left: maxlen, - right: maxlen + 1, - node: NULL_ID, - }); - } - - pub fn advance(&mut self) -> bool { - let mut rv = false; - - if self.qbeg < self.qend { - self.left = self.right; - let mut tright = self.set_partition(); - if self.num_overlaps() == 0 { - self.left = self.segment_queue[self.qbeg].left; - } - while self.qbeg < self.qend && self.segment_queue[self.qbeg].left == self.left { - tright = std::cmp::min(tright, self.segment_queue[self.qbeg].right); - // NOTE: I wonder how efficient this is vs C++? - self.overlapping - .insert(self.oend, self.segment_queue[self.qbeg]); - self.oend += 1; - self.qbeg += 1; - } - self.right = std::cmp::min(self.segment_queue[self.qbeg].left, tright); - rv = true; - } else { - self.left = self.right; - self.right = Position::MAX; - let tright = self.set_partition(); - if self.num_overlaps() > 0 { - self.right = tright; - rv = true - } - } - - rv - } - - pub fn get_left(&self) -> Position { - self.left - } - - pub fn get_right(&self) -> Position { - self.right - } - - pub fn clear_queue(&mut self) { - self.segment_queue.clear(); - } - - pub fn overlap(&self, i: usize) -> &Segment { - &self.overlapping[i] - } -} - -pub type AncestryList = NestedForwardList; - -pub fn find_parent_child_segment_overlap( - edges: &[Edge], - edge_index: usize, - num_edges: usize, - maxlen: Position, - u: IdType, - ancestry: &mut AncestryList, - overlapper: &mut SegmentOverlapper, -) -> Result { - overlapper.clear_queue(); - - let mut i = edge_index; - - while i < num_edges && edges[i].parent == u { - let edge = &edges[i]; - - ancestry.for_each(edges[i].child, |seg: &Segment| { - if seg.right > edge.left && edge.right > seg.left { - overlapper.enqueue( - std::cmp::max(seg.left, edge.left), - std::cmp::min(seg.right, edge.right), - seg.node, - 
); - } - true - })?; - - i += 1; - } - overlapper.finalize_queue(maxlen); - Ok(i) -} - -fn add_ancestry( - input_id: IdType, - left: Position, - right: Position, - node: IdType, - ancestry: &mut AncestryList, -) -> Result<(), ForrusttsError> { - let head = ancestry.head(input_id)?; - if head == AncestryList::null() { - let seg = Segment { left, right, node }; - ancestry.extend(input_id, seg)?; - } else { - let last_idx = ancestry.tail(input_id)?; - if last_idx == AncestryList::null() { - return Err(ForrusttsError::SimplificationError { - value: "last_idx is NULL_ID".to_string(), - }); - } - let last = ancestry.fetch_mut(last_idx)?; - if last.right == left && last.node == node { - last.right = right; - } else { - let seg = Segment { left, right, node }; - ancestry.extend(input_id, seg)?; - } - } - Ok(()) -} - -fn buffer_edge( - left: Position, - right: Position, - parent: IdType, - child: IdType, - temp_edge_buffer: &mut EdgeTable, -) { - let i = temp_edge_buffer - .iter() - .rposition(|e: &Edge| e.child == child); - - match i { - None => temp_edge_buffer.push(Edge { - left, - right, - parent, - child, - }), - Some(x) => { - if temp_edge_buffer[x].right == left { - temp_edge_buffer[x].right = right; - } else { - temp_edge_buffer.push(Edge { - left, - right, - parent, - child, - }); - } - } - } -} - -fn output_buffered_edges(temp_edge_buffer: &mut EdgeTable, new_edges: &mut EdgeTable) -> usize { - temp_edge_buffer.sort_by(|a, b| a.child.cmp(&b.child)); - - // Need to store size here b/c - // append drains contents of input!!! 
- let rv = temp_edge_buffer.len(); - new_edges.append(temp_edge_buffer); - - rv -} - -pub fn merge_ancestors( - input_nodes: &[Node], - maxlen: Position, - parent_input_id: IdType, - state: &mut SimplificationBuffers, - idmap: &mut [IdType], -) -> Result<(), ForrusttsError> { - let mut output_id = idmap[parent_input_id as usize]; - let is_sample = output_id != NULL_ID; - - if is_sample { - state.ancestry.nullify_list(parent_input_id)?; - } - - let mut previous_right: Position = 0; - let mut ancestry_node: IdType; - state.overlapper.init(); - state.temp_edge_buffer.clear(); - - while state.overlapper.advance() { - if state.overlapper.num_overlaps() == 1 { - ancestry_node = state.overlapper.overlap(0).node; - if is_sample { - buffer_edge( - state.overlapper.get_left(), - state.overlapper.get_right(), - output_id, - ancestry_node, - &mut state.temp_edge_buffer, - ); - ancestry_node = output_id; - } - } else { - if output_id == NULL_ID { - state.new_nodes.push(Node { - time: input_nodes[parent_input_id as usize].time, - deme: input_nodes[parent_input_id as usize].deme, - }); - output_id = (state.new_nodes.len() - 1) as IdType; - idmap[parent_input_id as usize] = output_id; - } - ancestry_node = output_id; - for i in 0..state.overlapper.num_overlaps() as usize { - let o = &state.overlapper.overlap(i); - buffer_edge( - state.overlapper.get_left(), - state.overlapper.get_right(), - output_id, - o.node, - &mut state.temp_edge_buffer, - ); - } - } - if is_sample && state.overlapper.get_left() != previous_right { - add_ancestry( - parent_input_id, - previous_right, - state.overlapper.get_left(), - output_id, - &mut state.ancestry, - )?; - } - add_ancestry( - parent_input_id, - state.overlapper.get_left(), - state.overlapper.get_right(), - ancestry_node, - &mut state.ancestry, - )?; - previous_right = state.overlapper.get_right(); - } - if is_sample && previous_right != maxlen { - add_ancestry( - parent_input_id, - previous_right, - maxlen, - output_id, - &mut state.ancestry, 
- )?; - } - - if output_id != NULL_ID { - let n = output_buffered_edges(&mut state.temp_edge_buffer, &mut state.new_edges); - - if n == 0 && !is_sample { - assert!(output_id < state.new_nodes.len() as IdType); - state.new_nodes.truncate(output_id as usize); - idmap[parent_input_id as usize] = NULL_ID; - } - } - Ok(()) -} - -pub fn record_sample_nodes( - samples: &[IdType], - tables: &TableCollection, - new_nodes: &mut NodeTable, - ancestry: &mut AncestryList, - idmap: &mut [IdType], -) -> Result<(), ForrusttsError> { - for sample in samples.iter() { - assert!(*sample >= 0); - // NOTE: the following can be debug_assert? - if *sample == NULL_ID { - return Err(ForrusttsError::SimplificationError { - value: "sample node is NULL_ID".to_string(), - }); - } - if idmap[*sample as usize] != NULL_ID { - return Err(ForrusttsError::SimplificationError { - value: "invalid sample list!".to_string(), - }); - } - let n = tables.node(*sample); - new_nodes.push(Node { - time: n.time, - deme: n.deme, - }); - - add_ancestry( - *sample, - 0, - tables.genome_length(), - (new_nodes.len() - 1) as IdType, - ancestry, - )?; - - idmap[*sample as usize] = (new_nodes.len() - 1) as IdType; - } - Ok(()) -} diff --git a/src/simplification_output.rs b/src/simplification_output.rs deleted file mode 100644 index 5fb7bc8d..00000000 --- a/src/simplification_output.rs +++ /dev/null @@ -1,33 +0,0 @@ -/// Useful information output by table -/// simplification. -pub struct SimplificationOutput { - /// Maps input node ID to output ID. - /// Values are set to [``NULL_ID``](crate::NULL_ID) - /// for input nodes that "simplify out". - pub idmap: Vec, -} - -impl SimplificationOutput { - /// Create a new instance. 
- pub fn new() -> Self { - SimplificationOutput { idmap: vec![] } - } -} - -impl Default for SimplificationOutput { - fn default() -> Self { - SimplificationOutput::new() - } -} - -#[cfg(test)] -mod test { - - use super::*; - - #[test] - fn test_default() { - let x: SimplificationOutput = Default::default(); - assert_eq!(x.idmap.is_empty(), true); - } -} diff --git a/src/simplify_from_edge_buffer.rs b/src/simplify_from_edge_buffer.rs deleted file mode 100644 index 6cdf4e0c..00000000 --- a/src/simplify_from_edge_buffer.rs +++ /dev/null @@ -1,296 +0,0 @@ -use crate::simplification_common::*; -use crate::simplification_logic; -use crate::tables::*; -use crate::EdgeBuffer; -use crate::ForrusttsError; -use crate::SamplesInfo; -use crate::Segment; -use crate::SimplificationBuffers; -use crate::SimplificationFlags; -use crate::SimplificationOutput; -use crate::{IdType, Position, Time}; - -struct ParentLocation { - parent: IdType, - start: usize, - stop: usize, -} - -// TODO: validate input and return errors. -impl ParentLocation { - fn new(parent: IdType, start: usize, stop: usize) -> Self { - ParentLocation { - parent, - start, - stop, - } - } -} - -fn find_pre_existing_edges( - tables: &TableCollection, - edge_buffer_founder_nodes: &[IdType], - edge_buffer: &EdgeBuffer, -) -> Result, ForrusttsError> { - let mut alive_with_new_edges: Vec = vec![]; - - for a in edge_buffer_founder_nodes { - if edge_buffer.head(*a)? 
!= EdgeBuffer::null() { - alive_with_new_edges.push(*a); - } - } - if alive_with_new_edges.is_empty() { - return Ok(vec![]); - } - - let mut starts = vec![usize::MAX; tables.num_nodes()]; - let mut stops = vec![usize::MAX; tables.num_nodes()]; - - for (i, e) in tables.enumerate_edges() { - if starts[e.parent as usize] == usize::MAX { - starts[e.parent as usize] = i; - stops[e.parent as usize] = i + 1; - } else { - stops[e.parent as usize] = i + 1; - } - } - - let mut rv = vec![]; - for a in alive_with_new_edges { - rv.push(ParentLocation::new( - a, - starts[a as usize], - stops[a as usize], - )); - } - - rv.sort_by(|a, b| { - let ta = tables.nodes_[a.parent as usize].time; - let tb = tables.nodes_[b.parent as usize].time; - if ta == tb { - if a.start == b.start { - return a.parent.cmp(&b.parent); - } - return a.start.cmp(&b.start); - } - ta.cmp(&tb).reverse() - }); - - // TODO: this could eventually be called in a debug_assert - if !rv.is_empty() { - for i in 1..rv.len() { - let t0 = tables.nodes_[rv[i - 1].parent as usize].time; - let t1 = tables.nodes_[rv[i].parent as usize].time; - if t0 < t1 { - return Err(ForrusttsError::SimplificationError { - value: "existing edges not properly sorted by time".to_string(), - }); - } - } - } - Ok(rv) -} - -fn queue_children( - child: IdType, - left: Position, - right: Position, - ancestry: &mut simplification_logic::AncestryList, - overlapper: &mut simplification_logic::SegmentOverlapper, -) -> Result<(), ForrusttsError> { - Ok(ancestry.for_each(child, |seg: &Segment| { - if seg.right > left && right > seg.left { - overlapper.enqueue( - std::cmp::max(seg.left, left), - std::cmp::min(seg.right, right), - seg.node, - ); - } - true - })?) -} - -fn process_births_from_buffer( - head: IdType, - edge_buffer: &EdgeBuffer, - state: &mut SimplificationBuffers, -) -> Result<(), ForrusttsError> { - // Have to take references here to - // make the borrow checker happy. 
- let a = &mut state.ancestry; - let o = &mut state.overlapper; - Ok(edge_buffer.for_each(head, |seg: &Segment| { - queue_children(seg.node, seg.left, seg.right, a, o).unwrap(); - true - })?) -} - -/// Simplify a [``TableCollection``] from an [``EdgeBuffer``]. -/// -/// See [``EdgeBuffer``] for discussion. -/// -/// # Parameters -/// -/// * `samples`: Instance of [``SamplesInfo``]. The field -/// [``SamplesInfo::edge_buffer_founder_nodes``] -/// must be populated. See [``EdgeBuffer``] for details. -/// * `flags`: modify the behavior of the simplification algorithm. -/// * `state`: These are the internal data structures used -/// by the simpilfication algorithm. -/// * `edge_buffer`: An [``EdgeBuffer``] recording births since the last -/// simplification. -/// * `tables`: a [``TableCollection``] to simplify. -/// * `output`: Where simplification output gets written. -/// See [``SimplificationOutput``]. -/// -/// # Notes -/// -/// The input tables must be sorted. -/// See [``TableCollection::sort_tables_for_simplification``]. -/// -/// # Limitations -/// -/// The simplification code does not currently validate -/// that "buffered" edges do indeed represent a valid sort order. -pub fn simplify_from_edge_buffer( - samples: &SamplesInfo, - flags: SimplificationFlags, - state: &mut SimplificationBuffers, - edge_buffer: &mut EdgeBuffer, - tables: &mut TableCollection, - output: &mut SimplificationOutput, -) -> Result<(), ForrusttsError> { - setup_simplification(samples, tables, flags, state, output)?; - - // Process all edges since the last simplification. - let mut max_time = Time::MIN; - for n in samples.edge_buffer_founder_nodes.iter() { - max_time = std::cmp::max(max_time, tables.node(*n).time); - } - for (i, _) in edge_buffer.head_itr().rev().enumerate() { - let head = (edge_buffer.len() - i - 1) as i32; - let ptime = tables.node(head).time; - if ptime > max_time - // Then this is a parent who is: - // 1. Born since the last simplification. - // 2. 
Left offspring - { - state.overlapper.clear_queue(); - process_births_from_buffer(head, edge_buffer, state)?; - state.overlapper.finalize_queue(tables.genome_length()); - simplification_logic::merge_ancestors( - &tables.nodes(), - tables.genome_length(), - head, - state, - &mut output.idmap, - )?; - } else if ptime <= max_time { - break; - } - } - - let existing_edges = - find_pre_existing_edges(&tables, &samples.edge_buffer_founder_nodes, &edge_buffer)?; - - let mut edge_i = 0; - let num_edges = tables.num_edges(); - - for ex in existing_edges { - while edge_i < num_edges - && tables.nodes_[tables.edges_[edge_i].parent as usize].time - > tables.nodes_[ex.parent as usize].time - { - edge_i = process_parent( - tables.edges_[edge_i].parent, - (edge_i, num_edges), - &tables, - state, - output, - )?; - } - if ex.start != usize::MAX { - while (edge_i as usize) < ex.start - && tables.nodes_[tables.edges_[edge_i].parent as usize].time - >= tables.nodes_[ex.parent as usize].time - { - edge_i = process_parent( - tables.edges_[edge_i].parent, - (edge_i, num_edges), - &tables, - state, - output, - )?; - } - } - // now, handle ex.parent - state.overlapper.clear_queue(); - if ex.start != usize::MAX { - while edge_i < ex.stop { - // TODO: a debug assert or regular assert? 
- if tables.edges_[edge_i].parent != ex.parent { - return Err(ForrusttsError::SimplificationError { - value: "Unexpected parent node".to_string(), - }); - } - let a = &mut state.ancestry; - let o = &mut state.overlapper; - queue_children( - tables.edges_[edge_i].child, - tables.edges_[edge_i].left, - tables.edges_[edge_i].right, - a, - o, - )?; - edge_i += 1; - } - if edge_i < num_edges && tables.edges_[edge_i].parent == ex.parent { - return Err(ForrusttsError::SimplificationError { - value: "error traversing pre-existing edges for parent".to_string(), - }); - } - } - process_births_from_buffer(ex.parent, edge_buffer, state)?; - state.overlapper.finalize_queue(tables.genome_length()); - simplification_logic::merge_ancestors( - &tables.nodes_, - tables.genome_length(), - ex.parent, - state, - &mut output.idmap, - )?; - } - - // Handle remaining edges. - while edge_i < num_edges { - edge_i = process_parent( - tables.edges_[edge_i].parent, - (edge_i, num_edges), - &tables, - state, - output, - )?; - } - - std::mem::swap(&mut tables.edges_, &mut state.new_edges); - std::mem::swap(&mut tables.nodes_, &mut state.new_nodes); - edge_buffer.reset(tables.num_nodes()); - - Ok(()) -} - -#[cfg(test)] -mod test { - - use super::*; - - // This shows that the closure error gets propagated - // as the result type. - #[test] - fn test_process_births_from_buffer_closure_error() { - let b = EdgeBuffer::new(); - let mut s = SimplificationBuffers::new(); - assert!(process_births_from_buffer(-1, &b, &mut s) - .map_or_else(|_: ForrusttsError| true, |_| false)); - } -} diff --git a/src/simplify_tables.rs b/src/simplify_tables.rs deleted file mode 100644 index ed7a3899..00000000 --- a/src/simplify_tables.rs +++ /dev/null @@ -1,133 +0,0 @@ -use crate::simplification_common::*; -use crate::tables::*; -use crate::ForrusttsError; -use crate::SamplesInfo; -use crate::SimplificationBuffers; -use crate::SimplificationFlags; -use crate::SimplificationOutput; - -/// Simplify a [``TableCollection``]. 
-/// -/// # Parameters -/// -/// * `samples`: -/// * `flags`: modify the behavior of the simplification algorithm. -/// * `tables`: a [``TableCollection``] to simplify. -/// * `output`: Where simplification output gets written. -/// See [``SimplificationOutput``]. -/// -/// # Notes -/// -/// The input tables must be sorted. -/// See [``TableCollection::sort_tables_for_simplification``]. -/// -/// It is common to simplify many times during a simulation. -/// To avoid making big allocations each time, see -/// [``simplify_tables``] to keep memory allocations -/// persistent between simplifications. -pub fn simplify_tables_without_state( - samples: &SamplesInfo, - flags: SimplificationFlags, - tables: &mut TableCollection, - output: &mut SimplificationOutput, -) -> Result<(), ForrusttsError> { - let mut state = SimplificationBuffers::new(); - simplify_tables(samples, flags, &mut state, tables, output) -} - -/// Simplify a [``TableCollection``]. -/// -/// This differs from [``simplify_tables_without_state``] in that the big memory -/// allocations made during simplification are preserved in -/// an instance of [``SimplificationBuffers``]. -/// -/// # Parameters -/// -/// * `samples`: -/// * `flags`: modify the behavior of the simplification algorithm. -/// * `state`: These are the internal data structures used -/// by the simpilfication algorithm. -/// * `tables`: a [``TableCollection``] to simplify. -/// * `output`: Where simplification output gets written. -/// See [``SimplificationOutput``]. -/// -/// # Notes -/// -/// The input tables must be sorted. -/// See [``TableCollection::sort_tables_for_simplification``]. 
-pub fn simplify_tables( - samples: &SamplesInfo, - flags: SimplificationFlags, - state: &mut SimplificationBuffers, - tables: &mut TableCollection, - output: &mut SimplificationOutput, -) -> Result<(), ForrusttsError> { - setup_simplification(samples, tables, flags, state, output)?; - - let mut edge_i = 0; - let num_edges = tables.num_edges(); - let mut new_edges_inserted: usize = 0; - while edge_i < num_edges { - edge_i = process_parent( - tables.edges_[edge_i].parent, - (edge_i, num_edges), - &tables, - state, - output, - )?; - - if state.new_edges.len() >= 1024 && new_edges_inserted + state.new_edges.len() < edge_i { - for i in state.new_edges.drain(..) { - tables.edges_[new_edges_inserted] = i; - new_edges_inserted += 1; - } - assert_eq!(state.new_edges.len(), 0); - } - } - - tables.edges_.truncate(new_edges_inserted); - tables.edges_.append(&mut state.new_edges); - std::mem::swap(&mut tables.nodes_, &mut state.new_nodes); - - Ok(()) -} - -#[cfg(test)] -mod tests { - use super::*; - - // TODO: we need lots more tests of these validations! - - #[test] - fn test_simplify_tables_unsorted_edges() { - let mut tables = TableCollection::new(1000).unwrap(); - - tables.add_node(0, 0).unwrap(); // parent - tables.add_node(1, 0).unwrap(); // child - tables.add_edge(100, tables.genome_length(), 0, 1).unwrap(); - tables.add_edge(0, 100, 0, 1).unwrap(); - - let mut output = SimplificationOutput::new(); - - let mut samples = SamplesInfo::new(); - samples.samples.push(1); - - let _ = simplify_tables_without_state( - &samples, - SimplificationFlags::VALIDATE_ALL, - &mut tables, - &mut output, - ) - .map_or_else( - |x: ForrusttsError| { - assert_eq!( - x, - ForrusttsError::TablesError { - value: TablesError::EdgesNotSortedByLeft - } - ) - }, - |_| panic!(), - ); - } -} diff --git a/src/wright_fisher.rs b/src/wright_fisher.rs index 92d8b252..0f388c78 100644 --- a/src/wright_fisher.rs +++ b/src/wright_fisher.rs @@ -5,8 +5,9 @@ //! code and benchmarking utilities. 
However, some of //! the concepts here that are *not* public may be useful //! to others. Feel free to copy them! -use crate::simplify_from_edge_buffer::simplify_from_edge_buffer; -use crate::simplify_tables::*; +use crate::simplification::simplify_from_edge_buffer; +use crate::simplification::simplify_tables; +use crate::simplification::simplify_tables_without_state; use crate::tables::TableCollection; use crate::tsdef::*; use crate::EdgeBuffer;