Skip to content

Commit

Permalink
roc auc optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
ekorman committed Apr 20, 2024
1 parent 762c8c7 commit 35bcdff
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -675,16 +675,15 @@ def test_compute_roc_auc_with_label_map(
evaluation_type=enums.TaskType.CLASSIFICATION,
)

assert (
_compute_roc_auc(
db=db,
prediction_filter=prediction_filter,
groundtruth_filter=groundtruth_filter,
grouper_key="animal",
grouper_mappings=grouper_mappings,
)
== 0.7777777777777779
roc_auc = _compute_roc_auc(
db=db,
prediction_filter=prediction_filter,
groundtruth_filter=groundtruth_filter,
grouper_key="animal",
grouper_mappings=grouper_mappings,
)
assert roc_auc is not None
assert abs(roc_auc - 0.7777777777777779) < 1e-6


def test_compute_classification(
Expand Down
71 changes: 42 additions & 29 deletions api/valor_api/backend/metrics/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Sequence

import numpy as np
from sqlalchemy import Float, Integer, Subquery
from sqlalchemy import Integer, Subquery
from sqlalchemy.orm import Bundle, Session
from sqlalchemy.sql import and_, case, func, select
from sqlalchemy.sql.selectable import NamedFromClause
Expand Down Expand Up @@ -299,52 +299,65 @@ def _compute_binary_roc_auc(
if n - n_pos == 0:
return 1.0

# true positive rates
tprs = (
func.sum(
(gts_query.c.label_value == label.value).cast(Integer).cast(Float)
).over(order_by=-preds_query.c.score)
/ n_pos
)

# false positive rates
fprs = func.sum(
(gts_query.c.label_value != label.value).cast(Integer).cast(Float)
).over(order_by=-preds_query.c.score) / (n - n_pos)

tprs_fprs_query = (
basic_counts_query = (
select(
tprs.label("tprs"),
fprs.label("fprs"),
preds_query.c.datum_id,
preds_query.c.score,
).join(
preds_query, # type: ignore - SQLAlchemy Subquery is incompatible with join type
gts_query.c.datum_id == preds_query.c.datum_id,
(gts_query.c.label_value == label.value)
.cast(Integer)
.label("is_true_positive"),
(gts_query.c.label_value != label.value)
.cast(Integer)
.label("is_false_positive"),
)
.select_from(
preds_query.join( # type: ignore
gts_query, preds_query.c.datum_id == gts_query.c.datum_id # type: ignore
)
)
).subquery()
.alias("basic_counts")
)

tpr_fpr_cumulative = select(
basic_counts_query.c.score,
func.sum(basic_counts_query.c.is_true_positive)
.over(order_by=basic_counts_query.c.score.desc())
.label("cumulative_tp"),
func.sum(basic_counts_query.c.is_false_positive)
.over(order_by=basic_counts_query.c.score.desc())
.label("cumulative_fp"),
).alias("tpr_fpr_cumulative")

tpr_fpr_rates = select(
tpr_fpr_cumulative.c.score,
(tpr_fpr_cumulative.c.cumulative_tp / n_pos).label("tpr"),
(tpr_fpr_cumulative.c.cumulative_fp / (n - n_pos)).label("fpr"),
).alias("tpr_fpr_rates")

trap_areas = select(
(
0.5
* (
tprs_fprs_query.c.tprs
+ func.lag(tprs_fprs_query.c.tprs).over(
order_by=-tprs_fprs_query.c.score
tpr_fpr_rates.c.tpr
+ func.lag(tpr_fpr_rates.c.tpr).over(
order_by=tpr_fpr_rates.c.score.desc()
)
)
* (
tprs_fprs_query.c.fprs
- func.lag(tprs_fprs_query.c.fprs).over(
order_by=-tprs_fprs_query.c.score
tpr_fpr_rates.c.fpr
- func.lag(tpr_fpr_rates.c.fpr).over(
order_by=tpr_fpr_rates.c.score.desc()
)
)
).label("trap_area")
).subquery()

ret = db.scalar(func.sum(trap_areas.c.trap_area))

if ret is None:
return np.nan
return ret

return float(ret)


def _compute_roc_auc(
Expand Down Expand Up @@ -386,7 +399,7 @@ def _compute_roc_auc(
sum_roc_aucs = 0
label_count = 0

for grouper_value, labels in value_to_labels_mapping.items():
for _, labels in value_to_labels_mapping.items():
label_filter = groundtruth_filter.model_copy()
label_filter.label_ids = [label.id for label in labels]

Expand Down

0 comments on commit 35bcdff

Please sign in to comment.