Skip to content

Commit c5403ef

Browse files
Deviation - Allow mixed ownership (#48)
* Deviation measures can now be computed for pairs of arrays with different ownership over their data * Revert test suite * Add a test dedicated to mixed ownership
1 parent 6f16dd8 commit c5403ef

File tree

2 files changed

+70
-32
lines changed

2 files changed

+70
-32
lines changed

src/deviation.rs

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@ where
1919
///
2020
/// * `MultiInputError::EmptyInput` if `self` is empty
2121
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
22-
fn count_eq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
22+
fn count_eq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
2323
where
24-
A: PartialEq;
24+
A: PartialEq,
25+
T: Data<Elem = A>;
2526

2627
/// Counts the number of indices at which the elements of the arrays `self`
2728
/// and `other` are not equal.
@@ -30,9 +31,10 @@ where
3031
///
3132
/// * `MultiInputError::EmptyInput` if `self` is empty
3233
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
33-
fn count_neq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
34+
fn count_neq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
3435
where
35-
A: PartialEq;
36+
A: PartialEq,
37+
T: Data<Elem = A>;
3638

3739
/// Computes the [squared L2 distance] between `self` and `other`.
3840
///
@@ -50,9 +52,10 @@ where
5052
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
5153
///
5254
/// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance
53-
fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
55+
fn sq_l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
5456
where
55-
A: AddAssign + Clone + Signed;
57+
A: AddAssign + Clone + Signed,
58+
T: Data<Elem = A>;
5659

5760
/// Computes the [L2 distance] between `self` and `other`.
5861
///
@@ -72,9 +75,10 @@ where
7275
/// **Panics** if the type cast from `A` to `f64` fails.
7376
///
7477
/// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance
75-
fn l2_dist(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
78+
fn l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
7679
where
77-
A: AddAssign + Clone + Signed + ToPrimitive;
80+
A: AddAssign + Clone + Signed + ToPrimitive,
81+
T: Data<Elem = A>;
7882

7983
/// Computes the [L1 distance] between `self` and `other`.
8084
///
@@ -92,9 +96,10 @@ where
9296
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
9397
///
9498
/// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry
95-
fn l1_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
99+
fn l1_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
96100
where
97-
A: AddAssign + Clone + Signed;
101+
A: AddAssign + Clone + Signed,
102+
T: Data<Elem = A>;
98103

99104
/// Computes the [L∞ distance] between `self` and `other`.
100105
///
@@ -111,9 +116,10 @@ where
111116
/// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape
112117
///
113118
/// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance
114-
fn linf_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
119+
fn linf_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
115120
where
116-
A: Clone + PartialOrd + Signed;
121+
A: Clone + PartialOrd + Signed,
122+
T: Data<Elem = A>;
117123

118124
/// Computes the [mean absolute error] between `self` and `other`.
119125
///
@@ -133,9 +139,10 @@ where
133139
/// **Panics** if the type cast from `A` to `f64` fails.
134140
///
135141
/// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error
136-
fn mean_abs_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
142+
fn mean_abs_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
137143
where
138-
A: AddAssign + Clone + Signed + ToPrimitive;
144+
A: AddAssign + Clone + Signed + ToPrimitive,
145+
T: Data<Elem = A>;
139146

140147
/// Computes the [mean squared error] between `self` and `other`.
141148
///
@@ -155,9 +162,10 @@ where
155162
/// **Panics** if the type cast from `A` to `f64` fails.
156163
///
157164
/// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error
158-
fn mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
165+
fn mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
159166
where
160-
A: AddAssign + Clone + Signed + ToPrimitive;
167+
A: AddAssign + Clone + Signed + ToPrimitive,
168+
T: Data<Elem = A>;
161169

162170
/// Computes the unnormalized [root-mean-square error] between `self` and `other`.
163171
///
@@ -175,9 +183,10 @@ where
175183
/// **Panics** if the type cast from `A` to `f64` fails.
176184
///
177185
/// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation
178-
fn root_mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
186+
fn root_mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
179187
where
180-
A: AddAssign + Clone + Signed + ToPrimitive;
188+
A: AddAssign + Clone + Signed + ToPrimitive,
189+
T: Data<Elem = A>;
181190

182191
/// Computes the [peak signal-to-noise ratio] between `self` and `other`.
183192
///
@@ -196,13 +205,14 @@ where
196205
/// **Panics** if the type cast from `A` to `f64` fails.
197206
///
198207
/// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
199-
fn peak_signal_to_noise_ratio(
208+
fn peak_signal_to_noise_ratio<T>(
200209
&self,
201-
other: &ArrayBase<S, D>,
210+
other: &ArrayBase<T, D>,
202211
maxv: A,
203212
) -> Result<f64, MultiInputError>
204213
where
205-
A: AddAssign + Clone + Signed + ToPrimitive;
214+
A: AddAssign + Clone + Signed + ToPrimitive,
215+
T: Data<Elem = A>;
206216

207217
private_decl! {}
208218
}
@@ -231,9 +241,10 @@ where
231241
S: Data<Elem = A>,
232242
D: Dimension,
233243
{
234-
fn count_eq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
244+
fn count_eq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
235245
where
236246
A: PartialEq,
247+
T: Data<Elem = A>,
237248
{
238249
return_err_if_empty!(self);
239250
return_err_unless_same_shape!(self, other);
@@ -249,16 +260,18 @@ where
249260
Ok(count)
250261
}
251262

252-
fn count_neq(&self, other: &ArrayBase<S, D>) -> Result<usize, MultiInputError>
263+
fn count_neq<T>(&self, other: &ArrayBase<T, D>) -> Result<usize, MultiInputError>
253264
where
254265
A: PartialEq,
266+
T: Data<Elem = A>,
255267
{
256268
self.count_eq(other).map(|n_eq| self.len() - n_eq)
257269
}
258270

259-
fn sq_l2_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
271+
fn sq_l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
260272
where
261273
A: AddAssign + Clone + Signed,
274+
T: Data<Elem = A>,
262275
{
263276
return_err_if_empty!(self);
264277
return_err_unless_same_shape!(self, other);
@@ -274,9 +287,10 @@ where
274287
Ok(result)
275288
}
276289

277-
fn l2_dist(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
290+
fn l2_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
278291
where
279292
A: AddAssign + Clone + Signed + ToPrimitive,
293+
T: Data<Elem = A>,
280294
{
281295
let sq_l2_dist = self
282296
.sq_l2_dist(other)?
@@ -286,9 +300,10 @@ where
286300
Ok(sq_l2_dist.sqrt())
287301
}
288302

289-
fn l1_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
303+
fn l1_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
290304
where
291305
A: AddAssign + Clone + Signed,
306+
T: Data<Elem = A>,
292307
{
293308
return_err_if_empty!(self);
294309
return_err_unless_same_shape!(self, other);
@@ -303,9 +318,10 @@ where
303318
Ok(result)
304319
}
305320

306-
fn linf_dist(&self, other: &ArrayBase<S, D>) -> Result<A, MultiInputError>
321+
fn linf_dist<T>(&self, other: &ArrayBase<T, D>) -> Result<A, MultiInputError>
307322
where
308323
A: Clone + PartialOrd + Signed,
324+
T: Data<Elem = A>,
309325
{
310326
return_err_if_empty!(self);
311327
return_err_unless_same_shape!(self, other);
@@ -323,9 +339,10 @@ where
323339
Ok(max)
324340
}
325341

326-
fn mean_abs_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
342+
fn mean_abs_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
327343
where
328344
A: AddAssign + Clone + Signed + ToPrimitive,
345+
T: Data<Elem = A>,
329346
{
330347
let l1_dist = self
331348
.l1_dist(other)?
@@ -336,9 +353,10 @@ where
336353
Ok(l1_dist / n)
337354
}
338355

339-
fn mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
356+
fn mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
340357
where
341358
A: AddAssign + Clone + Signed + ToPrimitive,
359+
T: Data<Elem = A>,
342360
{
343361
let sq_l2_dist = self
344362
.sq_l2_dist(other)?
@@ -349,21 +367,23 @@ where
349367
Ok(sq_l2_dist / n)
350368
}
351369

352-
fn root_mean_sq_err(&self, other: &ArrayBase<S, D>) -> Result<f64, MultiInputError>
370+
fn root_mean_sq_err<T>(&self, other: &ArrayBase<T, D>) -> Result<f64, MultiInputError>
353371
where
354372
A: AddAssign + Clone + Signed + ToPrimitive,
373+
T: Data<Elem = A>,
355374
{
356375
let msd = self.mean_sq_err(other)?;
357376
Ok(msd.sqrt())
358377
}
359378

360-
fn peak_signal_to_noise_ratio(
379+
fn peak_signal_to_noise_ratio<T>(
361380
&self,
362-
other: &ArrayBase<S, D>,
381+
other: &ArrayBase<T, D>,
363382
maxv: A,
364383
) -> Result<f64, MultiInputError>
365384
where
366385
A: AddAssign + Clone + Signed + ToPrimitive,
386+
T: Data<Elem = A>,
367387
{
368388
let maxv_f = maxv.to_f64().expect("failed cast from type A to f64");
369389
let msd = self.mean_sq_err(&other)?;

tests/deviation.rs

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,24 @@ use num_traits::Float;
88

99
use std::f64;
1010

11+
#[test]
12+
fn test_deviation_computation_for_mixed_ownership() {
13+
// It's enough to check that the code compiles!
14+
let a = array![0., 0.];
15+
let b = array![1., 0.];
16+
17+
let _ = a.count_eq(&b.view());
18+
let _ = a.count_neq(&b.view());
19+
let _ = a.l2_dist(&b.view());
20+
let _ = a.sq_l2_dist(&b.view());
21+
let _ = a.l1_dist(&b.view());
22+
let _ = a.linf_dist(&b.view());
23+
let _ = a.mean_abs_err(&b.view());
24+
let _ = a.mean_sq_err(&b.view());
25+
let _ = a.root_mean_sq_err(&b.view());
26+
let _ = a.peak_signal_to_noise_ratio(&b.view(), 10.);
27+
}
28+
1129
#[test]
1230
fn test_count_eq() -> Result<(), MultiInputError> {
1331
let a = array![0., 0.];

0 commit comments

Comments
 (0)