Merge pull request #187 from abstractqqq/better_roc_extreme_case
better extreme case handling for roc auc
abstractqqq committed Jun 22, 2024
2 parents 5899ae6 + 94aff5a commit 7b2fcbc
Showing 3 changed files with 70 additions and 41 deletions.
9 changes: 6 additions & 3 deletions python/polars_ds/metrics.py
@@ -260,7 +260,8 @@ def query_roc_auc(
     Computes ROC AUC using self as actual and pred as predictions.

     Self must be binary and castable to type UInt32. If self is not all 0s and 1s or not binary,
-    the result will not make sense, or some error may occur.
+    the result will not make sense, or some error may occur. If no positive class exists in the data,
+    NaN will be returned.

     Parameters
     ----------
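To make the new contract concrete, here is a minimal usage sketch of the behavior documented above (assuming polars is imported as pl and polars_ds as pds, mirroring the test added at the bottom of this diff):

import polars as pl
import polars_ds as pds

df = pl.DataFrame({
    "target": [0, 0, 0, 0],        # no positive class present
    "pred": [0.1, 0.4, 0.6, 0.9],
})

# Before this change the expression raised a ComputeError; it now yields NaN.
print(df.select(pds.query_roc_auc("target", "pred")).item())  # nan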
@@ -281,7 +282,8 @@ def query_gini(actual: StrOrExpr, pred: StrOrExpr) -> pl.Expr:
     Computes the Gini coefficient. This is 2 * AUC - 1.

     Self must be binary and castable to type UInt32. If self is not all 0s and 1s or not binary,
-    the result will not make sense, or some error may occur.
+    the result will not make sense, or some error may occur. If no positive class exists in the data,
+    NaN will be returned.

     Parameters
     ----------
@@ -379,7 +381,8 @@ def query_binary_metrics(actual: StrOrExpr, pred: StrOrExpr, threshold: float =
     having the names as given here.

     Self must be binary and castable to type UInt32. If self is not all 0s and 1s,
-    the result will not make sense, or some error may occur.
+    the result will not make sense, or some error may occur. If there is no positive class in the data,
+    NaN or other numerical errors may occur.

     Average precision is computed using Sum (R_n - R_{n-1}) * P_{n-1}, which is not the textbook definition,
     but is consistent with Scikit-learn. For more information, see
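As a side note on the average-precision identity quoted above, a hedged NumPy sketch of that sum (the function name and signature are illustrative, not the library's internals):

import numpy as np

def average_precision(recall: np.ndarray, precision: np.ndarray) -> float:
    # Sum over n of (R_n - R_{n-1}) * P_{n-1}, as stated in the docstring.
    # Assumes recall is sorted ascending and aligned with precision.
    return float(np.sum(np.diff(recall) * precision[:-1]))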
81 changes: 43 additions & 38 deletions src/num/tp_fp.rs
@@ -14,7 +14,7 @@ fn combo_output(_: &[Field]) -> PolarsResult<Field> {
     Ok(Field::new("", DataType::Struct(v)))
 }

-fn tp_fp_frame(predicted: &Series, actual: &Series, as_ratio: bool) -> PolarsResult<LazyFrame> {
+fn tp_fp_frame(predicted: &Series, actual: &Series, positive_count: u32, as_ratio: bool) -> PolarsResult<LazyFrame> {
     // Checking for data quality issues
     if (actual.len() != predicted.len())
         || actual.is_empty()
@@ -28,13 +28,6 @@ fn tp_fp_frame(predicted: &Series, actual: &Series, as_ratio: bool) -> PolarsRes
         ));
     }

-    let positive_counts = actual.sum::<u32>().unwrap_or(0);
-    if positive_counts == 0 {
-        return Err(PolarsError::ComputeError(
-            "No positives in actual, or actual cannot be turned into integers.".into(),
-        ));
-    }
-
     // Start computing
     let n = predicted.len() as u32;
     let df = df!(
@@ -52,8 +45,8 @@ fn tp_fp_frame(predicted: &Series, actual: &Series, as_ratio: bool) -> PolarsRes
         .sort(["threshold"], Default::default())
         .with_columns([
             (lit(n) - col("cnt").cum_sum(false) + col("cnt")).alias("predicted_positive"),
-            (lit(positive_counts) - col("pos_cnt_at_threshold").cum_sum(false))
-                .shift_and_fill(1, positive_counts)
+            (lit(positive_count) - col("pos_cnt_at_threshold").cum_sum(false))
+                .shift_and_fill(1, positive_count)
                 .alias("tp"),
         ])
         .select([
@@ -89,7 +82,12 @@ fn pl_combo_b(inputs: &[Series]) -> PolarsResult<Series> {
     let threshold = inputs[2].f64()?;
     let threshold = threshold.get(0).unwrap_or(0.5);

-    let mut binding = tp_fp_frame(predicted, actual, true)?.collect()?;
+    let positive_count = actual.sum::<u32>().unwrap_or(0);
+    if positive_count == 0 {
+        return Ok(Series::from_iter([f64::NAN]));
+    }
+
+    let mut binding = tp_fp_frame(predicted, actual, positive_count, true)?.collect()?;
     let frame = binding.align_chunks();

     let tpr = frame.drop_in_place("tpr").unwrap();
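For intuition, the cumulative-sum bookkeeping in tp_fp_frame above amounts to the standard ROC construction sketched below in NumPy (illustrative only, ignoring tied scores; not the library's Polars-expression implementation):

import numpy as np

def roc_points(actual: np.ndarray, predicted: np.ndarray):
    """TPR and FPR at each score cutoff, highest scores first."""
    order = np.argsort(-predicted)       # sort by predicted score, descending
    hits = actual[order].astype(float)
    tp = np.cumsum(hits)                 # true positives above each cutoff
    fp = np.cumsum(1.0 - hits)           # false positives above each cutoff
    # hits.sum() == 0 is exactly the degenerate case this PR handles with NaN.
    return tp / hits.sum(), fp / (len(hits) - hits.sum())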
@@ -149,6 +147,39 @@ fn binary_confusion_matrix(combined_series: &UInt32Chunked) -> [u32; 4] {
     output
 }

+
+#[polars_expr(output_type=Float64)]
+fn pl_roc_auc(inputs: &[Series]) -> PolarsResult<Series> {
+    // actual, when passed in, is always u32 (done in Python extension side)
+    let actual = &inputs[0];
+    let predicted = &inputs[1];
+
+    let positive_count = actual.sum::<u32>().unwrap_or(0);
+    if positive_count == 0 {
+        return Ok(Series::from_iter([f64::NAN]));
+    }
+
+    let mut binding = tp_fp_frame(predicted, actual, positive_count, true)?
+        .select([col("tpr"), col("fpr")])
+        .collect()?;
+    let frame = binding.align_chunks();
+
+    let tpr = frame.drop_in_place("tpr").unwrap();
+    let fpr = frame.drop_in_place("fpr").unwrap();
+
+    // Should be contiguous. No need to rechunk
+    let y = tpr.f64().unwrap();
+    let x = fpr.f64().unwrap();
+
+    let y: ArrayView1<f64> = y.to_ndarray()?;
+    let x: ArrayView1<f64> = x.to_ndarray()?;
+
+    let out: f64 = -super::trapz::trapz(y, x);
+
+    Ok(Series::from_iter([out]))
+}
+
+
 // bcm = binary confusion matrix
 fn bcm_output(_: &[Field]) -> PolarsResult<Field> {
     let tp = Field::new("tp", DataType::UInt32);
@@ -272,30 +303,4 @@ fn pl_binary_confusion_matrix(inputs: &[Series]) -> PolarsResult<Series> {

     let out = result.into_struct("confusion_matrix");
     Ok(out.into_series())
 }
-
-#[polars_expr(output_type=Float64)]
-fn pl_roc_auc(inputs: &[Series]) -> PolarsResult<Series> {
-    // actual, when passed in, is always u32 (done in Python extension side)
-    let actual = &inputs[0];
-    let predicted = &inputs[1];
-
-    let mut binding = tp_fp_frame(predicted, actual, true)?
-        .select([col("tpr"), col("fpr")])
-        .collect()?;
-    let frame = binding.align_chunks();
-
-    let tpr = frame.drop_in_place("tpr").unwrap();
-    let fpr = frame.drop_in_place("fpr").unwrap();
-
-    // Should be contiguous. No need to rechunk
-    let y = tpr.f64().unwrap();
-    let x = fpr.f64().unwrap();
-
-    let y: ArrayView1<f64> = y.to_ndarray()?;
-    let x: ArrayView1<f64> = x.to_ndarray()?;
-
-    let out: f64 = -super::trapz::trapz(y, x);
-
-    Ok(Series::from_iter([out]))
-}
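pl_roc_auc delegates the integration to super::trapz::trapz, which is not shown in this diff. For reference, a minimal sketch of a trapezoidal rule over slices (an assumption about its shape; the real helper takes ndarray views and may differ in signature and details):

/// Trapezoidal integral of y with respect to x (sketch only).
fn trapz(y: &[f64], x: &[f64]) -> f64 {
    x.windows(2)
        .zip(y.windows(2))
        .map(|(xw, yw)| (xw[1] - xw[0]) * (yw[1] + yw[0]) * 0.5)
        .sum()
}

fn main() {
    // Integrating y = x over [0, 1] gives 0.5; a descending x flips the sign,
    // which is why pl_roc_auc negates the result (fpr runs high-to-low after
    // the threshold sort).
    let x = [0.0, 0.5, 1.0];
    let y = [0.0, 0.5, 1.0];
    assert!((trapz(&y, &x) - 0.5).abs() < 1e-12);
}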
21 changes: 21 additions & 0 deletions tests/test_metrics.py
@@ -121,6 +121,27 @@ def test_confusion_matrix(y_true, y_score):
    assert res == pytest.approx(ref)


+def test_roc_auc():
+    from sklearn.metrics import roc_auc_score
+
+    df = pds.random_data(size=2000, n_cols=0).select(
+        pds.random(0.0, 1.0).alias("predictions"),
+        pds.random(0.0, 1.0).round().cast(pl.Int32).alias("target"),
+        pl.lit(0).alias("zero_target"),
+    )
+
+    roc_auc = df.select(pds.query_roc_auc("target", "predictions")).item(0, 0)
+
+    answer = roc_auc_score(df["target"].to_numpy(), df["predictions"].to_numpy())
+
+    assert np.isclose(roc_auc, answer)
+
+    # When no positive class is present, roc_auc returns NaN
+    nan_roc = df.select(pds.query_roc_auc("zero_target", "predictions")).item(0, 0)
+
+    assert np.isnan(nan_roc)

def test_multiclass_roc_auc():
from sklearn.metrics import roc_auc_score

