Skip to content

Commit 0424240

Browse files
munckymagikLukeMathWalker
authored andcommitted
Remove redundant call to abs in sq_l2_dist (#57)
* Add a benchmark for sq_l2_dist in deviation * Removes a redundant call to abs in sq_l2_dist
1 parent d9483ea commit 0424240

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

Cargo.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,8 @@ harness = false
3939

4040
[[bench]]
4141
name = "summary_statistics"
42-
harness = false
42+
harness = false
43+
44+
[[bench]]
45+
name = "deviation"
46+
harness = false

benches/deviation.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
use criterion::{
2+
black_box, criterion_group, criterion_main, AxisScale, Criterion, ParameterizedBenchmark,
3+
PlotConfiguration,
4+
};
5+
use ndarray::prelude::*;
6+
use ndarray_rand::RandomExt;
7+
use ndarray_stats::DeviationExt;
8+
use rand::distributions::Uniform;
9+
10+
fn sq_l2_dist(c: &mut Criterion) {
11+
let lens = vec![10, 100, 1000, 10000];
12+
let benchmark = ParameterizedBenchmark::new(
13+
"sq_l2_dist",
14+
|bencher, &len| {
15+
let data = Array::random(len, Uniform::new(0.0, 1.0));
16+
let data2 = Array::random(len, Uniform::new(0.0, 1.0));
17+
18+
bencher.iter(|| black_box(data.sq_l2_dist(&data2).unwrap()))
19+
},
20+
lens,
21+
)
22+
.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
23+
c.bench("sq_l2_dist", benchmark);
24+
}
25+
26+
criterion_group! {
27+
name = benches;
28+
config = Criterion::default();
29+
targets = sq_l2_dist
30+
}
31+
criterion_main!(benches);

src/deviation.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,8 @@ where
261261

262262
Zip::from(self).and(other).apply(|self_i, other_i| {
263263
let (a, b) = (self_i.clone(), other_i.clone());
264-
let abs_diff = (a - b).abs();
265-
result += abs_diff.clone() * abs_diff;
264+
let diff = a - b;
265+
result += diff.clone() * diff;
266266
});
267267

268268
Ok(result)

0 commit comments

Comments
 (0)