Skip to content

Commit 00684c4

Browse files
committed
remove compare fn to make impl clear
1 parent 1c55b00 commit 00684c4

File tree

19 files changed

+1462
-200
lines changed

19 files changed

+1462
-200
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ clippy:
1010
$(CARGO) clippy
1111

1212
test:
13-
#$(TEST) -- --nocapture --test-threads=1
13+
#$(TEST) -- --nocapture
1414
$(TEST)
1515

1616
test_dp:

src/common/heap.rs

Lines changed: 23 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,16 @@ pub struct BinaryHeap<K> {
3535
keys: Vec<K>,
3636
}
3737

38-
impl<K> BinaryHeap<K> {
39-
pub fn new<F>(mut keys: Vec<K>, test: &F) -> Self
40-
where
41-
F: Fn(&K, &K) -> bool,
42-
{
43-
build_heap(&mut keys, test);
38+
impl<K> BinaryHeap<K>
39+
where
40+
K: Ord,
41+
{
42+
pub fn new(mut keys: Vec<K>) -> Self {
43+
build_heap(&mut keys);
4444
Self { keys }
4545
}
4646

47-
pub fn pop<F>(&mut self, test: &F) -> Option<K>
48-
where
49-
F: Fn(&K, &K) -> bool,
50-
{
47+
pub fn pop(&mut self) -> Option<K> {
5148
let len = self.keys.len();
5249
if len > 0 {
5350
// 从长度为 n 的数组中删除第一个元素需要线性时间 O(n)。
@@ -57,33 +54,27 @@ impl<K> BinaryHeap<K> {
5754
// 个元素,然后将数组的长度减一。
5855
self.keys.swap(0, len - 1);
5956
let key = self.keys.pop();
60-
heapify(&mut self.keys, 0, test);
57+
heapify(&mut self.keys, 0);
6158
key
6259
} else {
6360
None
6461
}
6562
}
6663

67-
pub fn set<F>(&mut self, i: usize, key: K, test: &F)
68-
where
69-
F: Fn(&K, &K) -> bool,
70-
{
64+
pub fn set(&mut self, i: usize, key: K) {
7165
match self.keys.get(i) {
72-
Some(v) if test(&key, v) => {
66+
Some(v) if &key >= v => {
7367
self.keys[i] = key;
74-
heap_fix(&mut self.keys, i, test);
68+
heap_fix(&mut self.keys, i);
7569
}
7670
_ => (),
7771
}
7872
}
7973

80-
pub fn insert<F>(&mut self, key: K, test: &F)
81-
where
82-
F: Fn(&K, &K) -> bool,
83-
{
74+
pub fn insert(&mut self, key: K) {
8475
let i = self.keys.len();
8576
self.keys.push(key);
86-
heap_fix(&mut self.keys, i, test);
77+
heap_fix(&mut self.keys, i);
8778
}
8879

8980
//for test
@@ -92,9 +83,9 @@ impl<K> BinaryHeap<K> {
9283
}
9384
}
9485

95-
pub fn heapify<K, F>(keys: &mut [K], mut i: usize, test: &F)
86+
pub fn heapify<K>(keys: &mut [K], mut i: usize)
9687
where
97-
F: Fn(&K, &K) -> bool,
88+
K: Ord,
9889
{
9990
let n = keys.len();
10091
loop {
@@ -103,13 +94,13 @@ where
10394
let mut m = i;
10495

10596
if let Some(v) = keys.get(l) {
106-
if l < n && test(v, &keys[m]) {
97+
if l < n && v >= &keys[m] {
10798
m = l;
10899
}
109100
}
110101

111102
if let Some(v) = keys.get(r) {
112-
if r < n && test(v, &keys[m]) {
103+
if r < n && v >= &keys[m] {
113104
m = r;
114105
}
115106
}
@@ -123,9 +114,9 @@ where
123114
}
124115
}
125116

126-
pub fn build_heap<K, F>(keys: &mut [K], compare: &F)
117+
pub fn build_heap<K>(keys: &mut [K])
127118
where
128-
F: Fn(&K, &K) -> bool,
119+
K: Ord,
129120
{
130121
// i以 n / 2作为第一个分支节点,开始构建heap。
131122
// 因为叶子结点,已经满足堆定义,所以从二叉树倒数第二层最后一个节点
@@ -135,21 +126,21 @@ where
135126
// index = 2 ^ (p - 1) - 1 = 2 ^ ( log(n) - 1) - 1 <= n / 2
136127
let mut i = keys.len() as i32 / 2;
137128
while i >= 0 {
138-
heapify(keys, i as usize, compare);
129+
heapify(keys, i as usize);
139130
i -= 1;
140131
}
141132
}
142133

143134
// 与heapify的区别:
144135
// heapify 是从i节点开始,调整子树 (向下调整)
145136
// heap_fix 是从i节点开始,调整父节点(向上调整)
146-
fn heap_fix<K, F>(keys: &mut [K], mut i: usize, test: &F)
137+
fn heap_fix<K>(keys: &mut [K], mut i: usize)
147138
where
148-
F: Fn(&K, &K) -> bool,
139+
K: Ord,
149140
{
150141
while i > 0 {
151142
let parent = parent!(i);
152-
if test(&keys[i], &keys[parent]) {
143+
if keys[i] >= keys[parent] {
153144
keys.swap(i, parent);
154145
i = parent;
155146
}

src/math/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
pub mod mysqrt;
2+
pub mod sparse_vector;

src/tree/binary/sparse_vector.rs renamed to src/math/sparse_vector.rs

Lines changed: 38 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,15 @@
33
/// It includes methods for addition, subtraction,
44
/// dot product, scalar product, unit vector, and Euclidean norm.
55
///
6-
/// The implementation is a symbol table of indices and values for which the vector
7-
/// coordinates are nonzero. This makes it efficient when most of the vector coordindates
8-
/// are zero.
6+
/// The implementation is a symbol table (Red Black Tree) of indices and values
7+
/// for which the vector coordinates are nonzero. This makes it efficient when
8+
/// most of the vector coordindates are zero.
99
///
1010
/// ref: https://github.com/kevin-wayne/algs4.git
1111
use crate::tree::binary::rb2::RedBlackTreeV2;
1212
use crate::tree::binary::Tree;
1313
use std::ops::{Add, Sub};
1414

15-
#[derive(Debug)]
16-
pub enum Err {
17-
Dimension,
18-
}
19-
2015
pub struct SparseVector {
2116
d: usize,
2217
st: Tree<usize, f64>,
@@ -40,7 +35,7 @@ impl SparseVector {
4035

4136
/// Returns the ith coordinate of this vector
4237
pub fn get(&self, i: usize) -> f64 {
43-
self.st.get(&i).cloned().unwrap_or(0.0)
38+
*self.st.get(&i).unwrap_or(&0.0)
4439
}
4540

4641
/// Returns the number of nonzero entries in this vector.
@@ -58,7 +53,7 @@ impl SparseVector {
5853
if self.d != that.d {
5954
Err(Err::Dimension)
6055
} else {
61-
let keys = if self.st.size() <= that.st.size() {
56+
let keys = if self.nnz() <= that.nnz() {
6257
self.st.keys()
6358
} else {
6459
that.st.keys()
@@ -85,34 +80,20 @@ impl SparseVector {
8580
/// Returns the scalar-vector product of this vector with the specified scalar.
8681
pub fn scale(&self, alpha: f64) -> Self {
8782
let mut c = Self::new(self.d);
88-
for i in self.st.keys() {
89-
c.put(*i, alpha * self.get(*i));
83+
for &i in self.st.keys() {
84+
c.put(i, alpha * self.get(i));
9085
}
9186
c
9287
}
9388
}
9489

95-
impl ToString for SparseVector {
96-
fn to_string(&self) -> String {
97-
let keys = self.st.keys();
98-
let mut v = Vec::with_capacity(keys.len());
99-
for i in self.st.keys() {
100-
v.push(format!("({}, {})", i, self.get(*i)));
101-
}
102-
v.join("")
103-
}
104-
}
105-
10690
impl Add for SparseVector {
10791
type Output = Self;
10892

10993
fn add(self, rhs: Self) -> Self::Output {
110-
let mut c = Self::new(self.d);
111-
for i in self.st.keys() {
112-
c.put(*i, self.get(*i));
113-
}
114-
for i in rhs.st.keys() {
115-
c.put(*i, c.get(*i) + rhs.get(*i));
94+
let mut c = self.clone();
95+
for &i in rhs.st.keys() {
96+
c.put(i, c.get(i) + rhs.get(i));
11697
}
11798
c
11899
}
@@ -122,13 +103,36 @@ impl Sub for SparseVector {
122103
type Output = Self;
123104

124105
fn sub(self, rhs: Self) -> Self::Output {
125-
let mut c = Self::new(self.d);
126-
for i in self.st.keys() {
127-
c.put(*i, self.get(*i));
106+
let mut c = self.clone();
107+
for &i in rhs.st.keys() {
108+
c.put(i, c.get(i) - rhs.get(i));
109+
}
110+
c
111+
}
112+
}
113+
114+
impl ToString for SparseVector {
115+
fn to_string(&self) -> String {
116+
let keys = self.st.keys();
117+
let mut v = Vec::with_capacity(keys.len());
118+
for &i in keys {
119+
v.push(format!("({}, {})", i, self.get(i)));
128120
}
129-
for i in rhs.st.keys() {
130-
c.put(*i, c.get(*i) - rhs.get(*i));
121+
v.join("")
122+
}
123+
}
124+
125+
impl Clone for SparseVector {
126+
fn clone(&self) -> Self {
127+
let mut c = Self::new(self.d);
128+
for &i in self.st.keys() {
129+
c.put(i, self.get(i));
131130
}
132131
c
133132
}
134133
}
134+
135+
#[derive(Debug)]
136+
pub enum Err {
137+
Dimension,
138+
}

src/search/binary.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,16 @@ use std::cmp::Ordering;
44

55
pub fn search<K>(xs: &[K], k: K) -> Option<usize>
66
where
7-
K: std::cmp::PartialOrd,
7+
K: Ord,
88
{
99
let mut l = 0;
1010
let mut u = xs.len();
1111

1212
while l < u {
1313
let m = (l + u) / 2;
14-
match xs[m].partial_cmp(&k) {
15-
Some(Ordering::Equal) => return Some(m),
16-
Some(Ordering::Less) => l = m + 1,
14+
match xs[m].cmp(&k) {
15+
Ordering::Equal => return Some(m),
16+
Ordering::Less => l = m + 1,
1717
_ => u = m,
1818
}
1919
}

src/sort/bubble.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
//! 若初始状态反序,时间复杂度 O(n^2)
1414
//! 稳定排序算法
1515
16-
pub fn sort<T, F>(a: &mut [T], compare: F)
16+
pub fn sort<T>(a: &mut [T])
1717
where
18-
F: Fn(&T, &T) -> bool,
18+
T: Ord,
1919
{
2020
let len = a.len();
2121
for i in 0..len.saturating_sub(1) {
2222
let mut swapped = false;
2323

2424
for j in 0..(len - 1 - i) {
25-
if compare(&a[j], &a[j + 1]) {
25+
if a[j] > a[j + 1] {
2626
a.swap(j, j + 1);
2727
swapped = true;
2828
}

src/sort/heap.rs

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
55
66
use crate::common::heap::{self, BinaryHeap};
77

8-
pub fn sort<T, F>(a: &[T], compare: &F) -> Vec<T>
8+
pub fn sort<T>(a: &[T]) -> Vec<T>
99
where
10-
T: Copy,
11-
F: Fn(&T, &T) -> bool,
10+
T: Ord + Copy,
1211
{
1312
let data = Vec::from(a);
14-
let mut heap = BinaryHeap::new(data, compare);
13+
let mut heap = BinaryHeap::new(data);
1514
let mut res = Vec::with_capacity(a.len());
16-
while let Some(v) = heap.pop(compare) {
15+
while let Some(v) = heap.pop() {
1716
res.push(v);
1817
}
1918
res
@@ -28,16 +27,15 @@ where
2827
/// 就地排序,小 -> 大
2928
pub fn floyd_sort<T>(a: &mut [T])
3029
where
31-
T: std::cmp::PartialOrd,
30+
T: Ord,
3231
{
3332
// 构建最大堆
34-
let compare = |x: &T, y: &T| x >= y;
35-
heap::build_heap(a, &compare);
33+
heap::build_heap(a);
3634

3735
let mut i = a.len();
3836
while i > 1 {
3937
i -= 1;
4038
a.swap(0, i);
41-
heap::heapify(&mut a[0..i], 0, &compare);
39+
heap::heapify(&mut a[0..i], 0);
4240
}
4341
}

src/sort/insert.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
//!
33
//! 基本思想是将一个value插入到有序表中
44
5-
pub fn sort<T, F>(a: &mut [T], compare: F)
5+
pub fn sort<T>(a: &mut [T])
66
where
7-
F: Fn(&T, &T) -> bool,
7+
T: Ord,
88
{
99
let len = a.len();
1010
// 注意起始索引
1111
for i in 1..len {
1212
// 将a[i]插入到a[i-1],a[i-2],a[i-3]……之中
1313
let mut j = i;
14-
while j > 0 && compare(&a[j], &a[j - 1]) {
14+
while j > 0 && a[j] < a[j - 1] {
1515
a.swap(j, j - 1);
1616
j -= 1;
1717
}

0 commit comments

Comments
 (0)