From 93f399434ef2899fa10331b943254bb40ab382fe Mon Sep 17 00:00:00 2001 From: Fr0benius Date: Tue, 1 Sep 2020 18:51:53 -0700 Subject: [PATCH 1/4] Convex hull trick (#14) Finished implementation of convex hull trick with sqrt decomposition --- src/range_query/sqrt_decomp.rs | 80 +++++++++++++++++++++++++++++----- 1 file changed, 68 insertions(+), 12 deletions(-) diff --git a/src/range_query/sqrt_decomp.rs b/src/range_query/sqrt_decomp.rs index ead4a39..e00bdef 100644 --- a/src/range_query/sqrt_decomp.rs +++ b/src/range_query/sqrt_decomp.rs @@ -103,9 +103,11 @@ impl MoState for DistinctVals { /// Represents a minimum (lower envelope) of a collection of linear functions of a variable, /// evaluated using the convex hull trick with square root decomposition. +#[derive(Debug)] pub struct PiecewiseLinearFn { - sorted_lines: Vec<(i64, i64)>, - recent_lines: Vec<(i64, i64)>, + sorted_lines: Vec<(f64, f64)>, + intersections: Vec, + recent_lines: Vec<(f64, f64)>, merge_threshold: usize, } @@ -116,31 +118,67 @@ impl PiecewiseLinearFn { pub fn with_merge_threshold(merge_threshold: usize) -> Self { Self { sorted_lines: vec![], + intersections: vec![], recent_lines: vec![], merge_threshold, } } /// Replaces this function with the minimum of itself and a provided line - pub fn min_with(&mut self, slope: i64, intercept: i64) { + pub fn min_with(&mut self, slope: f64, intercept: f64) { self.recent_lines.push((slope, intercept)); } fn update_envelope(&mut self) { self.recent_lines.extend(self.sorted_lines.drain(..)); - self.recent_lines.sort_unstable(); - for (slope, intercept) in self.recent_lines.drain(..) { - // TODO: do convex hull trick algorithm - self.sorted_lines.push((slope, intercept)); + self.recent_lines + .sort_unstable_by(|x, y| y.partial_cmp(&x).unwrap()); + self.intersections.clear(); + + for (m1, b1) in self.recent_lines.drain(..) { + while let Some(&(m2, b2)) = self.sorted_lines.last() { + // If slopes are equal, the later line will always have lower + // intercept, so we can get rid of the old one. + if (m1 - m2).abs() > 1e-10f64 { + let new_intersection = (b1 - b2) / (m2 - m1); + if &new_intersection > self.intersections.last().unwrap_or(&f64::MIN) { + self.intersections.push(new_intersection); + break; + } + } + self.intersections.pop(); + self.sorted_lines.pop(); + } + self.sorted_lines.push((m1, b1)); } } - fn eval_helper(&self, x: i64) -> i64 { - 0 // TODO: pick actual minimum, or infinity if empty + fn eval_in_envelope(&self, x: f64) -> f64 { + if self.sorted_lines.is_empty() { + return f64::MAX; + } + let idx = match self + .intersections + .binary_search_by(|y| y.partial_cmp(&x).unwrap()) + { + Ok(k) => k, + Err(k) => k, + }; + let (m, b) = self.sorted_lines[idx]; + m * x + b + } + + fn eval_helper(&self, x: f64) -> f64 { + self.recent_lines + .iter() + .map(|&(m, b)| m * x + b) + .min_by(|x, y| x.partial_cmp(y).unwrap()) + .unwrap_or(f64::MAX) + .min(self.eval_in_envelope(x)) } /// Evaluates the function at x - pub fn evaluate(&mut self, x: i64) -> i64 { + pub fn evaluate(&mut self, x: f64) -> f64 { if self.recent_lines.len() > self.merge_threshold { self.update_envelope(); } @@ -164,7 +202,25 @@ mod test { #[test] fn test_convex_hull_trick() { - let mut func = PiecewiseLinearFn::with_merge_threshold(3); - // TODO: make test + let lines = [(0, 3), (1, 0), (-1, 8), (2, -1), (-1, 4)]; + let xs = [0, 1, 2, 3, 4, 5]; + // results[i] consists of the expected y-coordinates after processing + // the first i+1 lines. + let results = [ + [3, 3, 3, 3, 3, 3], + [0, 1, 2, 3, 3, 3], + [0, 1, 2, 3, 3, 3], + [-1, 1, 2, 3, 3, 3], + [-1, 1, 2, 1, 0, -1], + ]; + for threshold in 0..=lines.len() { + let mut func = PiecewiseLinearFn::with_merge_threshold(threshold); + assert_eq!(func.evaluate(0.0), f64::MAX); + for (&(slope, intercept), expected) in lines.iter().zip(results.iter()) { + func.min_with(slope as f64, intercept as f64); + let ys: Vec = xs.iter().map(|&x| func.evaluate(x as f64) as i64).collect(); + assert_eq!(expected, &ys[..]); + } + } } } From 51416ce11da20cc0f273437715b3a784f035d399 Mon Sep 17 00:00:00 2001 From: Aram Ebtekar Date: Wed, 2 Sep 2020 11:46:26 -0700 Subject: [PATCH 2/4] PartialOrd utility functions, README fix --- README.md | 4 ++-- src/range_query/mod.rs | 21 +++++++++++-------- src/range_query/sqrt_decomp.rs | 37 ++++++++++------------------------ 3 files changed, 26 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 0bf9d50..9f33f28 100644 --- a/README.md +++ b/README.md @@ -47,8 +47,8 @@ Rather than try to persuade you with words, this repository aims to show by exam - [Elementary graph algorithms](src/graph/util.rs): minimum spanning tree, Euler path, Dijkstra's algorithm, DFS iteration - [Network flows](src/graph/flow.rs): Dinic's blocking flow, Hopcroft-Karp bipartite matching, min cost max flow - [Connected components](src/graph/connectivity.rs): 2-edge-, 2-vertex- and strongly connected components, bridges, articulation points, topological sort, 2-SAT -- [Associative range query](src/range_query): known colloquially as *segtrees*, coordinate compression, and Mo's query square root decomposition -- [Number thery](src/math/mod.rs): canonical solution to Bezout's identity, Miller's primality test +- [Associative range query](src/range_query): known colloquially as *segtrees*, coordinate compression, convex hull trick, and Mo's query square root decomposition +- [Number theory](src/math/mod.rs): canonical solution to Bezout's identity, Miller's primality test - [Arithmetic](src/math/num.rs): rational and complex numbers, linear algebra, safe modular arithmetic - [FFT](src/math/fft.rs): fast Fourier transform, number theoretic transform, convolution - [Scanner](src/scanner.rs): utility for reading input data ergonomically diff --git a/src/range_query/mod.rs b/src/range_query/mod.rs index 57155c3..64f0e8a 100644 --- a/src/range_query/mod.rs +++ b/src/range_query/mod.rs @@ -6,19 +6,24 @@ pub use dynamic_arq::{ArqView, DynamicArq}; pub use specs::ArqSpec; pub use static_arq::StaticArq; -/// Assuming slice is sorted, returns the minimum i for which slice[i] >= key, -/// or slice.len() if no such i exists -pub fn slice_lower_bound(slice: &[T], key: &T) -> usize { +/// A comparator on partially ordered elements, that panics if they are incomparable +pub fn asserting_cmp(a: &T, b: &T) -> std::cmp::Ordering { + a.partial_cmp(b).expect("Comparing incomparable elements") +} + +/// Assuming slice is totally ordered and sorted, returns the minimum i for which +/// slice[i] >= key, or slice.len() if no such i exists +pub fn slice_lower_bound(slice: &[T], key: &T) -> usize { slice - .binary_search_by(|x| x.cmp(key).then(std::cmp::Ordering::Greater)) + .binary_search_by(|x| asserting_cmp(x, key).then(std::cmp::Ordering::Greater)) .unwrap_err() } -/// Assuming slice is sorted, returns the minimum i for which slice[i] > key, -/// or slice.len() if no such i exists -pub fn slice_upper_bound(slice: &[T], key: &T) -> usize { +/// Assuming slice is totally ordered and sorted, returns the minimum i for which +/// slice[i] > key, or slice.len() if no such i exists +pub fn slice_upper_bound(slice: &[T], key: &T) -> usize { slice - .binary_search_by(|x| x.cmp(key).then(std::cmp::Ordering::Less)) + .binary_search_by(|x| asserting_cmp(x, key).then(std::cmp::Ordering::Less)) .unwrap_err() } diff --git a/src/range_query/sqrt_decomp.rs b/src/range_query/sqrt_decomp.rs index e00bdef..28f8f78 100644 --- a/src/range_query/sqrt_decomp.rs +++ b/src/range_query/sqrt_decomp.rs @@ -114,7 +114,7 @@ pub struct PiecewiseLinearFn { impl PiecewiseLinearFn { /// For N inserts interleaved with Q queries, a threshold of N/sqrt(Q) yields /// O(N sqrt Q + Q log N) time complexity. If all queries come after all inserts, - /// a threshold of 0 yields O(N + Q log N) time complexity. + /// any threshold less than N (e.g., 0) yields O(N + Q log N) time complexity. pub fn with_merge_threshold(merge_threshold: usize) -> Self { Self { sorted_lines: vec![], @@ -131,15 +131,14 @@ impl PiecewiseLinearFn { fn update_envelope(&mut self) { self.recent_lines.extend(self.sorted_lines.drain(..)); - self.recent_lines - .sort_unstable_by(|x, y| y.partial_cmp(&x).unwrap()); + self.recent_lines.sort_unstable_by(super::asserting_cmp); self.intersections.clear(); - for (m1, b1) in self.recent_lines.drain(..) { + for (m1, b1) in self.recent_lines.drain(..).rev() { while let Some(&(m2, b2)) = self.sorted_lines.last() { // If slopes are equal, the later line will always have lower // intercept, so we can get rid of the old one. - if (m1 - m2).abs() > 1e-10f64 { + if (m1 - m2).abs() > 1e-9 { let new_intersection = (b1 - b2) / (m2 - m1); if &new_intersection > self.intersections.last().unwrap_or(&f64::MIN) { self.intersections.push(new_intersection); @@ -153,28 +152,14 @@ impl PiecewiseLinearFn { } } - fn eval_in_envelope(&self, x: f64) -> f64 { - if self.sorted_lines.is_empty() { - return f64::MAX; - } - let idx = match self - .intersections - .binary_search_by(|y| y.partial_cmp(&x).unwrap()) - { - Ok(k) => k, - Err(k) => k, - }; - let (m, b) = self.sorted_lines[idx]; - m * x + b - } - fn eval_helper(&self, x: f64) -> f64 { - self.recent_lines - .iter() + let idx = super::slice_lower_bound(&self.intersections, &x); + std::iter::once(self.sorted_lines.get(idx)) + .flatten() + .chain(self.recent_lines.iter()) .map(|&(m, b)| m * x + b) - .min_by(|x, y| x.partial_cmp(y).unwrap()) - .unwrap_or(f64::MAX) - .min(self.eval_in_envelope(x)) + .min_by(super::asserting_cmp) + .unwrap_or(1e18) } /// Evaluates the function at x @@ -215,7 +200,7 @@ mod test { ]; for threshold in 0..=lines.len() { let mut func = PiecewiseLinearFn::with_merge_threshold(threshold); - assert_eq!(func.evaluate(0.0), f64::MAX); + assert_eq!(func.evaluate(0.0), 1e18); for (&(slope, intercept), expected) in lines.iter().zip(results.iter()) { func.min_with(slope as f64, intercept as f64); let ys: Vec = xs.iter().map(|&x| func.evaluate(x as f64) as i64).collect(); From dafd95beaa8e350d7cd19a4810958a8557cf9db3 Mon Sep 17 00:00:00 2001 From: Aram Ebtekar Date: Wed, 2 Sep 2020 12:07:29 -0700 Subject: [PATCH 3/4] Moved some algorithms to a new file misc.rs --- README.md | 3 +- src/lib.rs | 1 + src/misc.rs | 177 +++++++++++++++++++++++++++++++++ src/range_query/mod.rs | 79 --------------- src/range_query/sqrt_decomp.rs | 94 ----------------- 5 files changed, 180 insertions(+), 174 deletions(-) create mode 100644 src/misc.rs diff --git a/README.md b/README.md index 9f33f28..1b1e7fa 100644 --- a/README.md +++ b/README.md @@ -47,9 +47,10 @@ Rather than try to persuade you with words, this repository aims to show by exam - [Elementary graph algorithms](src/graph/util.rs): minimum spanning tree, Euler path, Dijkstra's algorithm, DFS iteration - [Network flows](src/graph/flow.rs): Dinic's blocking flow, Hopcroft-Karp bipartite matching, min cost max flow - [Connected components](src/graph/connectivity.rs): 2-edge-, 2-vertex- and strongly connected components, bridges, articulation points, topological sort, 2-SAT -- [Associative range query](src/range_query): known colloquially as *segtrees*, coordinate compression, convex hull trick, and Mo's query square root decomposition +- [Associative range query](src/range_query): known colloquially as *segtrees*, as well as Mo's query square root decomposition - [Number theory](src/math/mod.rs): canonical solution to Bezout's identity, Miller's primality test - [Arithmetic](src/math/num.rs): rational and complex numbers, linear algebra, safe modular arithmetic - [FFT](src/math/fft.rs): fast Fourier transform, number theoretic transform, convolution - [Scanner](src/scanner.rs): utility for reading input data ergonomically - [String processing](src/string_proc.rs): Knuth-Morris-Pratt and Aho-Corasick string matching, suffix array, Manacher's linear-time palindrome search +- [Miscellaneous algorithms](src/misc.rs): slice binary search, coordinate compression, convex hull trick with sqrt decomposition diff --git a/src/lib.rs b/src/lib.rs index 2e50a88..7a157f2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ //! Algorithms Cookbook in Rust. pub mod graph; pub mod math; +pub mod misc; pub mod range_query; pub mod scanner; pub mod string_proc; diff --git a/src/misc.rs b/src/misc.rs new file mode 100644 index 0000000..4332296 --- /dev/null +++ b/src/misc.rs @@ -0,0 +1,177 @@ +//! Miscellaneous algorithms. + +/// A comparator on partially ordered elements, that panics if they are incomparable +pub fn asserting_cmp(a: &T, b: &T) -> std::cmp::Ordering { + a.partial_cmp(b).expect("Comparing incomparable elements") +} + +/// Assuming slice is totally ordered and sorted, returns the minimum i for which +/// slice[i] >= key, or slice.len() if no such i exists +pub fn slice_lower_bound(slice: &[T], key: &T) -> usize { + slice + .binary_search_by(|x| asserting_cmp(x, key).then(std::cmp::Ordering::Greater)) + .unwrap_err() +} + +/// Assuming slice is totally ordered and sorted, returns the minimum i for which +/// slice[i] > key, or slice.len() if no such i exists +pub fn slice_upper_bound(slice: &[T], key: &T) -> usize { + slice + .binary_search_by(|x| asserting_cmp(x, key).then(std::cmp::Ordering::Less)) + .unwrap_err() +} + +/// A simple data structure for coordinate compression +pub struct SparseIndex { + coords: Vec, +} + +impl SparseIndex { + /// Build an index, given the full set of coordinates to compress. + pub fn new(mut coords: Vec) -> Self { + coords.sort_unstable(); + coords.dedup(); + Self { coords } + } + + /// Return Ok(i) if the coordinate q appears at index i + /// Return Err(i) if q appears between indices i-1 and i + pub fn compress(&self, q: i64) -> Result { + self.coords.binary_search(&q) + } +} + +/// Represents a minimum (lower envelope) of a collection of linear functions of a variable, +/// evaluated using the convex hull trick with square root decomposition. +pub struct PiecewiseLinearFn { + sorted_lines: Vec<(f64, f64)>, + intersections: Vec, + recent_lines: Vec<(f64, f64)>, + merge_threshold: usize, +} + +impl PiecewiseLinearFn { + /// For N inserts interleaved with Q queries, a threshold of N/sqrt(Q) yields + /// O(N sqrt Q + Q log N) time complexity. If all queries come after all inserts, + /// any threshold less than N (e.g., 0) yields O(N + Q log N) time complexity. + pub fn with_merge_threshold(merge_threshold: usize) -> Self { + Self { + sorted_lines: vec![], + intersections: vec![], + recent_lines: vec![], + merge_threshold, + } + } + + /// Replaces the represented function with the minimum of itself and a provided line + pub fn min_with(&mut self, slope: f64, intercept: f64) { + self.recent_lines.push((slope, intercept)); + } + + fn update_envelope(&mut self) { + self.recent_lines.extend(self.sorted_lines.drain(..)); + self.recent_lines.sort_unstable_by(asserting_cmp); + self.intersections.clear(); + + for (new_m, new_b) in self.recent_lines.drain(..).rev() { + while let Some(&(last_m, last_b)) = self.sorted_lines.last() { + // If slopes are equal, get rid of the old line as its intercept is higher + if (new_m - last_m).abs() > 1e-9 { + let intr = (new_b - last_b) / (last_m - new_m); + if self.intersections.last().map(|&x| x < intr).unwrap_or(true) { + self.intersections.push(intr); + break; + } + } + self.intersections.pop(); + self.sorted_lines.pop(); + } + self.sorted_lines.push((new_m, new_b)); + } + } + + fn eval_helper(&self, x: f64) -> f64 { + let idx = slice_lower_bound(&self.intersections, &x); + std::iter::once(self.sorted_lines.get(idx)) + .flatten() + .chain(self.recent_lines.iter()) + .map(|&(m, b)| m * x + b) + .min_by(asserting_cmp) + .unwrap_or(1e18) + } + + /// Evaluates the function at x + pub fn evaluate(&mut self, x: f64) -> f64 { + if self.recent_lines.len() > self.merge_threshold { + self.update_envelope(); + } + self.eval_helper(x) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_bounds() { + let mut vals = vec![16, 45, 45, 45, 82]; + + assert_eq!(slice_upper_bound(&vals, &44), 1); + assert_eq!(slice_lower_bound(&vals, &45), 1); + assert_eq!(slice_upper_bound(&vals, &45), 4); + assert_eq!(slice_lower_bound(&vals, &46), 4); + + vals.dedup(); + for (i, q) in vals.iter().enumerate() { + assert_eq!(slice_lower_bound(&vals, q), i); + assert_eq!(slice_upper_bound(&vals, q), i + 1); + } + } + + #[test] + fn test_coord_compress() { + let mut coords = vec![16, 99, 45, 18]; + let index = SparseIndex::new(coords.clone()); + + coords.sort_unstable(); + for (i, q) in coords.into_iter().enumerate() { + assert_eq!(index.compress(q - 1), Err(i)); + assert_eq!(index.compress(q), Ok(i)); + assert_eq!(index.compress(q + 1), Err(i + 1)); + } + } + + #[test] + fn test_range_compress() { + let queries = vec![(0, 10), (10, 19), (20, 29)]; + let coords = queries.iter().flat_map(|&(i, j)| vec![i, j + 1]).collect(); + let index = SparseIndex::new(coords); + + assert_eq!(index.coords, vec![0, 10, 11, 20, 30]); + } + + #[test] + fn test_convex_hull_trick() { + let lines = [(0, 3), (1, 0), (-1, 8), (2, -1), (-1, 4)]; + let xs = [0, 1, 2, 3, 4, 5]; + // results[i] consists of the expected y-coordinates after processing + // the first i+1 lines. + let results = [ + [3, 3, 3, 3, 3, 3], + [0, 1, 2, 3, 3, 3], + [0, 1, 2, 3, 3, 3], + [-1, 1, 2, 3, 3, 3], + [-1, 1, 2, 1, 0, -1], + ]; + for threshold in 0..=lines.len() { + let mut func = PiecewiseLinearFn::with_merge_threshold(threshold); + assert_eq!(func.evaluate(0.0), 1e18); + for (&(slope, intercept), expected) in lines.iter().zip(results.iter()) { + func.min_with(slope as f64, intercept as f64); + let ys: Vec = xs.iter().map(|&x| func.evaluate(x as f64) as i64).collect(); + assert_eq!(expected, &ys[..]); + } + } + } +} diff --git a/src/range_query/mod.rs b/src/range_query/mod.rs index 64f0e8a..332fb59 100644 --- a/src/range_query/mod.rs +++ b/src/range_query/mod.rs @@ -6,90 +6,11 @@ pub use dynamic_arq::{ArqView, DynamicArq}; pub use specs::ArqSpec; pub use static_arq::StaticArq; -/// A comparator on partially ordered elements, that panics if they are incomparable -pub fn asserting_cmp(a: &T, b: &T) -> std::cmp::Ordering { - a.partial_cmp(b).expect("Comparing incomparable elements") -} - -/// Assuming slice is totally ordered and sorted, returns the minimum i for which -/// slice[i] >= key, or slice.len() if no such i exists -pub fn slice_lower_bound(slice: &[T], key: &T) -> usize { - slice - .binary_search_by(|x| asserting_cmp(x, key).then(std::cmp::Ordering::Greater)) - .unwrap_err() -} - -/// Assuming slice is totally ordered and sorted, returns the minimum i for which -/// slice[i] > key, or slice.len() if no such i exists -pub fn slice_upper_bound(slice: &[T], key: &T) -> usize { - slice - .binary_search_by(|x| asserting_cmp(x, key).then(std::cmp::Ordering::Less)) - .unwrap_err() -} - -/// A simple data structure for coordinate compression -pub struct SparseIndex { - coords: Vec, -} - -impl SparseIndex { - /// Build an index, given the full set of coordinates to compress. - pub fn new(mut coords: Vec) -> Self { - coords.sort_unstable(); - coords.dedup(); - Self { coords } - } - - /// Return Ok(i) if the coordinate q appears at index i - /// Return Err(i) if q appears between indices i-1 and i - pub fn compress(&self, q: i64) -> Result { - self.coords.binary_search(&q) - } -} - #[cfg(test)] mod test { use super::specs::*; use super::*; - #[test] - fn test_bounds() { - let mut vals = vec![16, 45, 45, 45, 82]; - - assert_eq!(slice_upper_bound(&vals, &44), 1); - assert_eq!(slice_lower_bound(&vals, &45), 1); - assert_eq!(slice_upper_bound(&vals, &45), 4); - assert_eq!(slice_lower_bound(&vals, &46), 4); - - vals.dedup(); - for (i, q) in vals.iter().enumerate() { - assert_eq!(slice_lower_bound(&vals, q), i); - assert_eq!(slice_upper_bound(&vals, q), i + 1); - } - } - - #[test] - fn test_coord_compress() { - let mut coords = vec![16, 99, 45, 18]; - let index = SparseIndex::new(coords.clone()); - - coords.sort_unstable(); - for (i, q) in coords.into_iter().enumerate() { - assert_eq!(index.compress(q - 1), Err(i)); - assert_eq!(index.compress(q), Ok(i)); - assert_eq!(index.compress(q + 1), Err(i + 1)); - } - } - - #[test] - fn test_range_compress() { - let queries = vec![(0, 10), (10, 19), (20, 29)]; - let coords = queries.iter().flat_map(|&(i, j)| vec![i, j + 1]).collect(); - let index = SparseIndex::new(coords); - - assert_eq!(index.coords, vec![0, 10, 11, 20, 30]); - } - #[test] fn test_rmq() { let mut arq = StaticArq::::new(&[0; 10]); diff --git a/src/range_query/sqrt_decomp.rs b/src/range_query/sqrt_decomp.rs index 28f8f78..9accc7e 100644 --- a/src/range_query/sqrt_decomp.rs +++ b/src/range_query/sqrt_decomp.rs @@ -101,76 +101,6 @@ impl MoState for DistinctVals { } } -/// Represents a minimum (lower envelope) of a collection of linear functions of a variable, -/// evaluated using the convex hull trick with square root decomposition. -#[derive(Debug)] -pub struct PiecewiseLinearFn { - sorted_lines: Vec<(f64, f64)>, - intersections: Vec, - recent_lines: Vec<(f64, f64)>, - merge_threshold: usize, -} - -impl PiecewiseLinearFn { - /// For N inserts interleaved with Q queries, a threshold of N/sqrt(Q) yields - /// O(N sqrt Q + Q log N) time complexity. If all queries come after all inserts, - /// any threshold less than N (e.g., 0) yields O(N + Q log N) time complexity. - pub fn with_merge_threshold(merge_threshold: usize) -> Self { - Self { - sorted_lines: vec![], - intersections: vec![], - recent_lines: vec![], - merge_threshold, - } - } - - /// Replaces this function with the minimum of itself and a provided line - pub fn min_with(&mut self, slope: f64, intercept: f64) { - self.recent_lines.push((slope, intercept)); - } - - fn update_envelope(&mut self) { - self.recent_lines.extend(self.sorted_lines.drain(..)); - self.recent_lines.sort_unstable_by(super::asserting_cmp); - self.intersections.clear(); - - for (m1, b1) in self.recent_lines.drain(..).rev() { - while let Some(&(m2, b2)) = self.sorted_lines.last() { - // If slopes are equal, the later line will always have lower - // intercept, so we can get rid of the old one. - if (m1 - m2).abs() > 1e-9 { - let new_intersection = (b1 - b2) / (m2 - m1); - if &new_intersection > self.intersections.last().unwrap_or(&f64::MIN) { - self.intersections.push(new_intersection); - break; - } - } - self.intersections.pop(); - self.sorted_lines.pop(); - } - self.sorted_lines.push((m1, b1)); - } - } - - fn eval_helper(&self, x: f64) -> f64 { - let idx = super::slice_lower_bound(&self.intersections, &x); - std::iter::once(self.sorted_lines.get(idx)) - .flatten() - .chain(self.recent_lines.iter()) - .map(|&(m, b)| m * x + b) - .min_by(super::asserting_cmp) - .unwrap_or(1e18) - } - - /// Evaluates the function at x - pub fn evaluate(&mut self, x: f64) -> f64 { - if self.recent_lines.len() > self.merge_threshold { - self.update_envelope(); - } - self.eval_helper(x) - } -} - #[cfg(test)] mod test { use super::*; @@ -184,28 +114,4 @@ mod test { assert_eq!(answers, vec![2, 1, 5, 5]); } - - #[test] - fn test_convex_hull_trick() { - let lines = [(0, 3), (1, 0), (-1, 8), (2, -1), (-1, 4)]; - let xs = [0, 1, 2, 3, 4, 5]; - // results[i] consists of the expected y-coordinates after processing - // the first i+1 lines. - let results = [ - [3, 3, 3, 3, 3, 3], - [0, 1, 2, 3, 3, 3], - [0, 1, 2, 3, 3, 3], - [-1, 1, 2, 3, 3, 3], - [-1, 1, 2, 1, 0, -1], - ]; - for threshold in 0..=lines.len() { - let mut func = PiecewiseLinearFn::with_merge_threshold(threshold); - assert_eq!(func.evaluate(0.0), 1e18); - for (&(slope, intercept), expected) in lines.iter().zip(results.iter()) { - func.min_with(slope as f64, intercept as f64); - let ys: Vec = xs.iter().map(|&x| func.evaluate(x as f64) as i64).collect(); - assert_eq!(expected, &ys[..]); - } - } - } } From f9f6eb1ca312e76af8a5bc2cdfd7db972769afec Mon Sep 17 00:00:00 2001 From: Aram Ebtekar Date: Sun, 6 Sep 2020 17:05:35 -0700 Subject: [PATCH 4/4] cleanup using iterator magic --- src/misc.rs | 10 ++--- src/string_proc.rs | 103 +++++++++++++++++++++------------------------ 2 files changed, 53 insertions(+), 60 deletions(-) diff --git a/src/misc.rs b/src/misc.rs index 4332296..c3b41bd 100644 --- a/src/misc.rs +++ b/src/misc.rs @@ -70,7 +70,7 @@ impl PiecewiseLinearFn { fn update_envelope(&mut self) { self.recent_lines.extend(self.sorted_lines.drain(..)); - self.recent_lines.sort_unstable_by(asserting_cmp); + self.recent_lines.sort_unstable_by(asserting_cmp); // TODO: switch to O(n) merge self.intersections.clear(); for (new_m, new_b) in self.recent_lines.drain(..).rev() { @@ -78,7 +78,7 @@ impl PiecewiseLinearFn { // If slopes are equal, get rid of the old line as its intercept is higher if (new_m - last_m).abs() > 1e-9 { let intr = (new_b - last_b) / (last_m - new_m); - if self.intersections.last().map(|&x| x < intr).unwrap_or(true) { + if self.intersections.last() < Some(&intr) { self.intersections.push(intr); break; } @@ -92,9 +92,9 @@ impl PiecewiseLinearFn { fn eval_helper(&self, x: f64) -> f64 { let idx = slice_lower_bound(&self.intersections, &x); - std::iter::once(self.sorted_lines.get(idx)) - .flatten() - .chain(self.recent_lines.iter()) + self.recent_lines + .iter() + .chain(self.sorted_lines.get(idx)) .map(|&(m, b)| m * x + b) .min_by(asserting_cmp) .unwrap_or(1e18) diff --git a/src/string_proc.rs b/src/string_proc.rs index da3348e..9a9a692 100644 --- a/src/string_proc.rs +++ b/src/string_proc.rs @@ -61,13 +61,12 @@ impl<'a, C: Eq> Matcher<'a, C> { /// /// ``` /// use contest_algorithms::string_proc::Matcher; - /// let utf8_string = "hello"; - /// - /// let match_from_byte_literal = Matcher::new(b"hello"); - /// - /// let match_from_bytes = Matcher::new(utf8_string.as_bytes()); - /// + /// let byte_string: &[u8] = b"hello"; + /// let utf8_string: &str = "hello"; /// let vec_char: Vec = utf8_string.chars().collect(); + /// + /// let match_from_byte_literal = Matcher::new(byte_string); + /// let match_from_utf8 = Matcher::new(utf8_string.as_bytes()); /// let match_from_chars = Matcher::new(&vec_char); /// /// let vec_int = vec![4, -3, 1]; @@ -93,24 +92,24 @@ impl<'a, C: Eq> Matcher<'a, C> { Self { pattern, fail } } - /// KMP algorithm, sets match_lens[i] = length of longest prefix of pattern + /// KMP algorithm, sets @return[i] = length of longest prefix of pattern /// matching a suffix of text[0..=i]. - pub fn kmp_match(&self, text: &[C]) -> Vec { - let mut match_lens = Vec::with_capacity(text.len()); + pub fn kmp_match(&self, text: impl IntoIterator) -> Vec { let mut len = 0; - for ch in text { - if len == self.pattern.len() { - len = self.fail[len - 1]; - } - while len > 0 && self.pattern[len] != *ch { - len = self.fail[len - 1]; - } - if self.pattern[len] == *ch { - len += 1; - } - match_lens.push(len); - } - match_lens + text.into_iter() + .map(|ch| { + if len == self.pattern.len() { + len = self.fail[len - 1]; + } + while len > 0 && self.pattern[len] != ch { + len = self.fail[len - 1]; + } + if self.pattern[len] == ch { + len += 1; + } + len + }) + .collect() } } @@ -141,7 +140,7 @@ impl MultiMatcher { /// Precomputes the automaton that allows linear-time string matching. /// If there are duplicate patterns, all but one copy will be ignored. - pub fn new(patterns: Vec>) -> Self { + pub fn new(patterns: impl IntoIterator>) -> Self { let mut trie = Trie::default(); let pat_nodes: Vec = patterns.into_iter().map(|pat| trie.insert(pat)).collect(); @@ -171,16 +170,16 @@ impl MultiMatcher { } } - /// Aho-Corasick algorithm, sets match_nodes[i] = node corresponding to + /// Aho-Corasick algorithm, sets @return[i] = node corresponding to /// longest prefix of some pattern matching a suffix of text[0..=i]. - pub fn ac_match(&self, text: &[C]) -> Vec { - let mut match_nodes = Vec::with_capacity(text.len()); + pub fn ac_match(&self, text: impl IntoIterator) -> Vec { let mut node = 0; - for ch in text { - node = Self::next(&self.trie, &self.fail, node, &ch); - match_nodes.push(node); - } - match_nodes + text.into_iter() + .map(|ch| { + node = Self::next(&self.trie, &self.fail, node, &ch); + node + }) + .collect() } /// For each non-empty match, returns where in the text it ends, and the index @@ -235,9 +234,9 @@ impl SuffixArray { } /// Suffix array construction in O(n log n) time. - pub fn new(text: &[u8]) -> Self { - let n = text.len(); - let init_rank = text.iter().map(|&ch| ch as usize).collect::>(); + pub fn new(text: impl IntoIterator) -> Self { + let init_rank = text.into_iter().map(|ch| ch as usize).collect::>(); + let n = init_rank.len(); let mut sfx = Self::counting_sort(0..n, &init_rank, 256); let mut rank = vec![init_rank]; // Invariant at the start of every loop iteration: @@ -291,7 +290,7 @@ impl SuffixArray { /// # Panics /// /// Panics if text is empty. -pub fn palindromes(text: &[T]) -> Vec { +pub fn palindromes(text: &[impl Eq]) -> Vec { let mut pal = Vec::with_capacity(2 * text.len() - 1); pal.push(1); while pal.len() < pal.capacity() { @@ -339,27 +338,21 @@ mod test { #[test] fn test_kmp_matching() { - let text = b"banana"; - let pattern = b"ana"; + let pattern = "ana"; + let text = "banana"; - let matches = Matcher::new(pattern).kmp_match(text); + let matches = Matcher::new(pattern.as_bytes()).kmp_match(text.bytes()); assert_eq!(matches, vec![0, 1, 2, 3, 2, 3]); } #[test] fn test_ac_matching() { - let text = b"banana bans, apple benefits."; - let dict = vec![ - "banana".bytes(), - "benefit".bytes(), - "banapple".bytes(), - "ban".bytes(), - "fit".bytes(), - ]; - - let matcher = MultiMatcher::new(dict); - let match_nodes = matcher.ac_match(text); + let dict = vec!["banana", "benefit", "banapple", "ban", "fit"]; + let text = "banana bans, apple benefits."; + + let matcher = MultiMatcher::new(dict.iter().map(|s| s.bytes())); + let match_nodes = matcher.ac_match(text.bytes()); let end_pos_and_id = matcher.get_end_pos_and_pat_id(&match_nodes); assert_eq!( @@ -370,11 +363,11 @@ mod test { #[test] fn test_suffix_array() { - let text1 = b"bobocel"; - let text2 = b"banana"; + let text1 = "bobocel"; + let text2 = "banana"; - let sfx1 = SuffixArray::new(text1); - let sfx2 = SuffixArray::new(text2); + let sfx1 = SuffixArray::new(text1.bytes()); + let sfx2 = SuffixArray::new(text2.bytes()); assert_eq!(sfx1.sfx, vec![0, 2, 4, 5, 6, 1, 3]); assert_eq!(sfx2.sfx, vec![5, 3, 1, 0, 4, 2]); @@ -393,9 +386,9 @@ mod test { #[test] fn test_palindrome() { - let text = b"banana"; + let text = "banana"; - let pal_len = palindromes(text); + let pal_len = palindromes(text.as_bytes()); assert_eq!(pal_len, vec![1, 0, 1, 0, 3, 0, 5, 0, 3, 0, 1]); }