Skip to content

Commit

Permalink
Merge pull request #12 from ErikPartridge/good-errors
Browse files Browse the repository at this point in the history
Good errors
  • Loading branch information
ErikPartridge committed Sep 28, 2018
2 parents 0744876 + ede67b0 commit 35cd1a0
Showing 1 changed file with 99 additions and 41 deletions.
140 changes: 99 additions & 41 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,26 @@
use std::collections::HashMap;
use std::collections::HashSet;
use std::hash::Hash;
use std::error::Error;
use std::fmt;

/// Error produced when the predicted labels and the ground-truth labels
/// differ in length.
///
/// The first field holds the length of the predictions and the second the
/// length of the ground truth, in the order the metric functions receive them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct LengthError(usize, usize);

impl fmt::Display for LengthError {
    /// Writes a human-readable message that includes both offending lengths.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `LengthError` is `Copy`, so both lengths can be pulled out by value.
        let LengthError(predicted_len, actual_len) = *self;
        write!(
            f,
            "Dataset lengths must be equal, found {} and {}",
            predicted_len, actual_len
        )
    }
}

impl Error for LengthError {
    /// Fixed, short summary of the failure; the `Display` output carries the
    /// actual lengths.
    fn description(&self) -> &str {
        "Dataset lengths do not match"
    }
}


/// Compute the gini impurity of a dataset.
///
Expand Down Expand Up @@ -41,24 +61,29 @@ where
/// Returns `Ok` with a float where 1.0 is a perfectly accurate dataset, or a `LengthError` if `pred` and `actual` differ in length
/// ```
/// use parsnip::categorical_accuracy;
/// # use parsnip::LengthError;
/// # fn main() -> Result<(), LengthError> {
/// let pred = vec![0, 0, 0 , 1, 2];
/// let actual = vec![1, 1, 1, 1, 2];
/// assert_eq!(categorical_accuracy(&pred, &actual), 0.4);
/// assert_eq!(categorical_accuracy(&pred, &actual)?, 0.4);
/// # Ok(())
/// # }
/// ```
pub fn categorical_accuracy<T>(pred: &[T], actual: &[T]) -> f32
pub fn categorical_accuracy<T>(pred: &[T], actual: &[T]) -> Result<f32, LengthError>
where
T: Eq,
{
assert_eq!(pred.len(), actual.len());
if pred.len() != actual.len(){
return Err(LengthError(pred.len(), actual.len()));
}
let truthy = pred.iter().zip(actual).filter(|(x, y)| x == y).count();
truthy as f32 / pred.len() as f32
Ok(truthy as f32 / pred.len() as f32)
}

fn class_precision<T>(pred: &[T], actual: &[T], class: &T) -> f32
where
T: Eq,
{
assert_eq!(pred.len(), actual.len());
//First, get the map of all true positives
let true_positives = pred
.iter()
Expand All @@ -78,7 +103,6 @@ where
T: Eq,
T: Hash,
{
assert_eq!(pred.len(), actual.len());
let classes: HashSet<_> = pred.into_iter().collect();
let mut class_weights = HashMap::new();
for value in &classes {
Expand All @@ -87,6 +111,7 @@ where
actual.iter().filter(|a| *a == *value).count() as f32 / actual.len() as f32,
);
}

classes
.iter()
.map(|c| class_precision(pred, actual, &c) * class_weights[c])
Expand All @@ -98,7 +123,6 @@ where
T: Eq,
T: Hash,
{
assert_eq!(pred.len(), actual.len());
let classes: HashSet<_> = pred.into_iter().collect();
let mut class_weights = HashMap::new();
for value in classes.clone() {
Expand Down Expand Up @@ -136,28 +160,33 @@ impl Default for Average {
/// # extern crate parsnip;
/// #[macro_use] extern crate approx; // for approximate equality check
/// use parsnip::{Average, precision};
///
/// # use parsnip::LengthError;
/// # fn main() -> Result<(), LengthError> {
/// let actual = vec![0, 1, 2, 0, 1, 2];
/// let pred = vec![0, 2, 1, 0, 0, 1];
///
/// assert_ulps_eq!(precision(&pred, &actual, Average::Macro), 0.22222222);
/// assert_ulps_eq!(precision(&pred, &actual, Average::Macro)?, 0.22222222);
/// # Ok(())
/// # }
/// ```
pub fn precision<T>(pred: &[T], actual: &[T], average: Average) -> f32
pub fn precision<T>(pred: &[T], actual: &[T], average: Average) -> Result<f32, LengthError>
where
T: Eq,
T: Hash,
{
if pred.len() != actual.len(){
return Err(LengthError(pred.len(), actual.len()));
}
match average {
Average::Macro => macro_precision(pred, actual),
Average::Weighted => weighted_precision(pred, actual),
Average::Macro => Ok(macro_precision(pred, actual)),
Average::Weighted => Ok(weighted_precision(pred, actual)),
}
}

fn class_recall<T>(pred: &[T], actual: &[T], class: &T) -> f32
where
T: Eq,
{
assert_eq!(pred.len(), actual.len());
let true_positives = pred
.iter()
.zip(actual)
Expand All @@ -176,7 +205,6 @@ where
T: Eq,
T: Hash,
{
assert_eq!(pred.len(), actual.len());
let classes: HashSet<_> = pred.into_iter().collect();
let mut class_weights = HashMap::new();
for value in &classes {
Expand All @@ -196,7 +224,6 @@ where
T: Eq,
T: Hash,
{
assert_eq!(pred.len(), actual.len());
let classes: HashSet<_> = pred.into_iter().collect();
classes
.iter()
Expand All @@ -213,20 +240,27 @@ where
/// # extern crate parsnip;
/// #[macro_use] extern crate approx; // for approximate equality check
/// use parsnip::{Average, recall};
///
/// # use parsnip::LengthError;
/// # fn main() -> Result<(), LengthError> {
/// let actual = vec![0, 1, 2, 0, 1, 2];
/// let pred = vec![0, 2, 1, 0, 0, 1];
///
/// assert_ulps_eq!(recall(&pred, &actual, Average::Macro), 0.333333333);
/// assert_ulps_eq!(recall(&pred, &actual, Average::Macro)?, 0.333333333);
/// # Ok(())
/// # }
/// ```
pub fn recall<T>(pred: &[T], actual: &[T], average: Average) -> f32
pub fn recall<T>(pred: &[T], actual: &[T], average: Average) -> Result<f32, LengthError>
where
T: Eq,
T: Hash,
{
if pred.len() != actual.len(){
return Err(LengthError(pred.len(), actual.len()));
}

match average {
Average::Macro => macro_recall(pred, actual),
Average::Weighted => weighted_recall(pred, actual),
Average::Macro => Ok(macro_recall(pred, actual)),
Average::Weighted => Ok(weighted_recall(pred, actual)),
}
}

Expand Down Expand Up @@ -259,21 +293,27 @@ where
/// # extern crate parsnip;
/// #[macro_use] extern crate approx; // for approximate equality check
/// use parsnip::{Average, f1_score};
///
/// # use parsnip::LengthError;
/// # fn main() -> Result<(), LengthError> {
/// let actual = vec![0, 1, 2, 0, 1, 2];
/// let pred = vec![0, 2, 1, 0, 0, 1];
///
/// assert_ulps_eq!(f1_score(&pred, &actual, Average::Macro), 0.26666666);
/// assert_ulps_eq!(f1_score(&pred, &actual, Average::Weighted), 0.26666666);
/// assert_ulps_eq!(f1_score(&pred, &actual, Average::Macro)?, 0.26666666);
/// assert_ulps_eq!(f1_score(&pred, &actual, Average::Weighted)?, 0.26666666);
/// # Ok(())
/// # }
/// ```
pub fn f1_score<T>(pred: &[T], actual: &[T], average: Average) -> f32
pub fn f1_score<T>(pred: &[T], actual: &[T], average: Average) -> Result<f32, LengthError>
where
T: Eq,
T: Hash,
{
if pred.len() != actual.len() {
return Err(LengthError(pred.len(), actual.len()));
}
match average {
Average::Macro => macro_f1(pred, actual),
Average::Weighted => weighted_f1(pred, actual),
Average::Macro => Ok(macro_f1(pred, actual)),
Average::Weighted => Ok(weighted_f1(pred, actual)),
}
}

Expand All @@ -285,16 +325,21 @@ where
/// ```
/// use parsnip::hamming_loss;
///
/// # use parsnip::LengthError;
/// # fn main() -> Result<(), LengthError> {
/// let actual = vec![0, 1, 2, 0, 0];
/// let pred = vec![0, 2, 1, 0, 1];
///
/// assert_eq!(hamming_loss(&pred, &actual), 0.6);
/// assert_eq!(hamming_loss(&pred, &actual)?, 0.6);
/// # Ok(())
/// # }
/// ```
pub fn hamming_loss<T>(pred: &[T], actual: &[T]) -> f32
pub fn hamming_loss<T>(pred: &[T], actual: &[T]) -> Result<f32, LengthError>
where
T: Eq,
{
1.0 - categorical_accuracy(pred, actual)
let cat_acc = categorical_accuracy(pred, actual)?;
Ok(1. - cat_acc)
}

fn macro_fbeta_score<T>(pred: &[T], actual: &[T], beta: f32) -> f32
Expand Down Expand Up @@ -329,22 +374,28 @@ where
/// ```
/// # extern crate parsnip;
/// #[macro_use] extern crate approx; // for approximate equality check
/// use parsnip::{Average, fbeta_score};
///
/// use parsnip::{Average, fbeta_score, LengthError};
/// # fn main() -> Result<(), LengthError> {
/// let actual = vec![0, 1, 2, 0, 1, 2];
/// let pred = vec![0, 2, 1, 0, 0, 1];
///
/// assert_ulps_eq!(fbeta_score(&pred, &actual, 0.5, Average::Macro), 0.23809524);
/// assert_ulps_eq!(fbeta_score(&pred, &actual, 0.5, Average::Weighted), 0.23809527);
/// assert_ulps_eq!(fbeta_score(&pred, &actual, 0.5, Average::Macro)?, 0.23809524);
/// assert_ulps_eq!(fbeta_score(&pred, &actual, 0.5, Average::Weighted)?, 0.23809527);
/// # Ok(())
/// # }
/// ```
pub fn fbeta_score<T>(pred: &[T], actual: &[T], beta: f32, average: Average) -> f32
pub fn fbeta_score<T>(pred: &[T], actual: &[T], beta: f32, average: Average) -> Result<f32, LengthError>
where
T: Eq,
T: Hash,
{
if pred.len() != actual.len(){
return Err(LengthError(pred.len(), actual.len()));
}

match average {
Average::Macro => macro_fbeta_score(pred, actual, beta),
Average::Weighted => weighted_fbeta_score(pred, actual, beta),
Average::Macro => Ok(macro_fbeta_score(pred, actual, beta)),
Average::Weighted => Ok(weighted_fbeta_score(pred, actual, beta)),
}
}

Expand All @@ -355,13 +406,16 @@ where
/// Supports macro and weighted averages
/// ```
/// use parsnip::jaccard_similiarity_score;
///
/// # use parsnip::LengthError;
/// # fn main() -> Result<(), LengthError> {
/// let actual = vec![0, 2, 1, 3];
/// let pred = vec![0, 1, 2, 3];
///
/// assert_eq!(jaccard_similiarity_score(&pred, &actual), 0.5);
/// assert_eq!(jaccard_similiarity_score(&pred, &actual)?, 0.5);
/// # Ok(())
/// # }
/// ```
pub fn jaccard_similiarity_score<T>(pred: &[T], actual: &[T]) -> f32
pub fn jaccard_similiarity_score<T>(pred: &[T], actual: &[T]) -> Result<f32, LengthError>
where
T: Eq,
{
Expand Down Expand Up @@ -389,7 +443,9 @@ mod tests {
fn test_categorical_accuracy() {
let pred = vec![0, 1, 0, 1, 0, 1];
let real = vec![0, 0, 0, 0, 1, 0];
assert_ulps_eq!(0.33333333, categorical_accuracy(&pred, &real));
assert_ulps_eq!(0.33333333, categorical_accuracy(&pred, &real).unwrap());
let pred_short = vec![0];
assert!(categorical_accuracy(&pred_short, &real).is_err());
}

#[test]
Expand Down Expand Up @@ -438,6 +494,8 @@ mod tests {
fn test_f1_score() {
let actual = vec![0, 1, 2, 0, 1, 2];
let pred = vec![0, 2, 1, 0, 0, 1];
assert_ulps_eq!(f1_score(&pred, &actual, Average::Macro), 0.26666665);
assert_ulps_eq!(f1_score(&pred, &actual, Average::Macro).unwrap(), 0.26666665);
let pred_short = vec![0];
assert!(f1_score(&pred_short, &actual, Average::Weighted).is_err());
}
}

0 comments on commit 35cd1a0

Please sign in to comment.