Skip to content

Commit

Permalink
Merge pull request #10 from Enet4/imp/average-enum
Browse files Browse the repository at this point in the history
Static enforcement of score averaging
  • Loading branch information
ErikPartridge committed Sep 21, 2018
2 parents f429a1a + 34db42b commit 73a8767
Showing 1 changed file with 40 additions and 39 deletions.
79 changes: 40 additions & 39 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,23 @@ where
.sum()
}

/// The type of score averaging strategy employed in the calculation of
/// precision, recall, or F-measure.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub enum Average {
/// Macro averaging (averaged across classes or labels).
Macro,
/// Averaging across classes, weighted by the number of true instances.
Weighted,
}

impl Default for Average {
/// The default average strategy is `Average::Macro`.
fn default() -> Self {
Average::Macro
}
}

/// The precision of a dataset
///
/// Returns a float where a 1.0 is a perfectly precise result set
Expand All @@ -118,25 +135,21 @@ where
/// ```
/// # extern crate parsnip;
/// #[macro_use] extern crate approx; // for approximate equality check
/// use parsnip::precision;
/// use parsnip::{Average, precision};
///
/// let actual = vec![0, 1, 2, 0, 1, 2];
/// let pred = vec![0, 2, 1, 0, 0, 1];
///
/// assert_ulps_eq!(precision(&pred, &actual, Some("macro".to_string())), 0.22222222);
/// assert_ulps_eq!(precision(&pred, &actual, Average::Macro), 0.22222222);
/// ```
pub fn precision<T>(pred: &[T], actual: &[T], average: Option<String>) -> f32
pub fn precision<T>(pred: &[T], actual: &[T], average: Average) -> f32
where
T: Eq,
T: Hash,
{
match average {
None => macro_precision(pred, actual),
Some(string) => match string.as_ref() {
"macro" => macro_precision(pred, actual),
"weighted" => weighted_precision(pred, actual),
_ => panic!("invalid averaging type"),
},
Average::Macro => macro_precision(pred, actual),
Average::Weighted => weighted_precision(pred, actual),
}
}

Expand Down Expand Up @@ -199,25 +212,21 @@ where
/// ```
/// # extern crate parsnip;
/// #[macro_use] extern crate approx; // for approximate equality check
/// use parsnip::recall;
/// use parsnip::{Average, recall};
///
/// let actual = vec![0, 1, 2, 0, 1, 2];
/// let pred = vec![0, 2, 1, 0, 0, 1];
///
/// assert_ulps_eq!(recall(&pred, &actual, Some("macro".to_string())), 0.333333333);
/// assert_ulps_eq!(recall(&pred, &actual, Average::Macro), 0.333333333);
/// ```
pub fn recall<T>(pred: &[T], actual: &[T], average: Option<String>) -> f32
pub fn recall<T>(pred: &[T], actual: &[T], average: Average) -> f32
where
T: Eq,
T: Hash,
{
match average {
None => macro_recall(pred, actual),
Some(string) => match string.as_ref() {
"macro" => macro_recall(pred, actual),
"weighted" => weighted_recall(pred, actual),
_ => panic!("invalid averaging type"),
},
Average::Macro => macro_recall(pred, actual),
Average::Weighted => weighted_recall(pred, actual),
}
}

Expand Down Expand Up @@ -249,26 +258,22 @@ where
/// ```
/// # extern crate parsnip;
/// #[macro_use] extern crate approx; // for approximate equality check
/// use parsnip::f1_score;
/// use parsnip::{Average, f1_score};
///
/// let actual = vec![0, 1, 2, 0, 1, 2];
/// let pred = vec![0, 2, 1, 0, 0, 1];
///
/// assert_ulps_eq!(f1_score(&pred, &actual, Some("macro".to_string())), 0.26666666);
/// assert_ulps_eq!(f1_score(&pred, &actual, Some("weighted".to_string())), 0.26666666);
/// assert_ulps_eq!(f1_score(&pred, &actual, Average::Macro), 0.26666666);
/// assert_ulps_eq!(f1_score(&pred, &actual, Average::Weighted), 0.26666666);
/// ```
pub fn f1_score<T>(pred: &[T], actual: &[T], average: Option<String>) -> f32
pub fn f1_score<T>(pred: &[T], actual: &[T], average: Average) -> f32
where
T: Eq,
T: Hash,
{
match average {
None => macro_f1(pred, actual),
Some(string) => match string.as_ref() {
"macro" => macro_f1(pred, actual),
"weighted" => weighted_f1(pred, actual),
_ => panic!("invalid averaging type"),
},
Average::Macro => macro_f1(pred, actual),
Average::Weighted => weighted_f1(pred, actual),
}
}

Expand Down Expand Up @@ -324,26 +329,22 @@ where
/// ```
/// # extern crate parsnip;
/// #[macro_use] extern crate approx; // for approximate equality check
/// use parsnip::fbeta_score;
/// use parsnip::{Average, fbeta_score};
///
/// let actual = vec![0, 1, 2, 0, 1, 2];
/// let pred = vec![0, 2, 1, 0, 0, 1];
///
/// assert_ulps_eq!(fbeta_score(&pred, &actual, 0.5, Some("macro".to_string())), 0.23809524);
/// assert_ulps_eq!(fbeta_score(&pred, &actual, 0.5, Some("weighted".to_string())), 0.23809527);
/// assert_ulps_eq!(fbeta_score(&pred, &actual, 0.5, Average::Macro), 0.23809524);
/// assert_ulps_eq!(fbeta_score(&pred, &actual, 0.5, Average::Weighted), 0.23809527);
/// ```
pub fn fbeta_score<T>(pred: &[T], actual: &[T], beta: f32, average: Option<String>) -> f32
pub fn fbeta_score<T>(pred: &[T], actual: &[T], beta: f32, average: Average) -> f32
where
T: Eq,
T: Hash,
{
match average {
None => macro_fbeta_score(pred, actual, beta),
Some(string) => match string.as_ref() {
"macro" => macro_fbeta_score(pred, actual, beta),
"weighted" => weighted_fbeta_score(pred, actual, beta),
_ => panic!("invalid averaging type"),
},
Average::Macro => macro_fbeta_score(pred, actual, beta),
Average::Weighted => weighted_fbeta_score(pred, actual, beta),
}
}

Expand Down Expand Up @@ -437,6 +438,6 @@ mod tests {
fn test_f1_score() {
let actual = vec![0, 1, 2, 0, 1, 2];
let pred = vec![0, 2, 1, 0, 0, 1];
assert_ulps_eq!(f1_score(&pred, &actual, Some("macro".to_string())), 0.26666666);
assert_ulps_eq!(f1_score(&pred, &actual, Average::Macro), 0.26666665);
}
}

0 comments on commit 73a8767

Please sign in to comment.