Skip to content

Commit

Permalink
Use registry pattern to make classification metric set
Browse files Browse the repository at this point in the history
  • Loading branch information
cthoyt committed Jan 4, 2022
1 parent 852819e commit b62e0ed
Showing 1 changed file with 9 additions and 24 deletions.
33 changes: 9 additions & 24 deletions rexmex/metricset.py
@@ -1,18 +1,6 @@
from typing import Collection, List, Tuple

from rexmex.metrics.classification import (
accuracy_score,
average_precision_score,
balanced_accuracy_score,
f1_score,
fowlkes_mallows_index,
matthews_correlation_coefficient,
pr_auc_score,
precision_score,
recall_score,
roc_auc_score,
specificity,
)
from rexmex.metrics.classification import classifications
from rexmex.metrics.rating import (
mean_absolute_error,
mean_absolute_percentage_error,
Expand Down Expand Up @@ -104,17 +92,14 @@ class ClassificationMetricSet(MetricSet):
"""

def __init__(self):
self["roc_auc"] = roc_auc_score
self["pr_auc"] = pr_auc_score
self["average_precision"] = average_precision_score
self["f1_score"] = binarize(f1_score)
self["matthews_correlation_coefficent"] = binarize(matthews_correlation_coefficient)
self["fowlkes_mallows_index"] = binarize(fowlkes_mallows_index)
self["precision"] = binarize(precision_score)
self["recall"] = binarize(recall_score)
self["specificity"] = binarize(specificity)
self["accuracy"] = binarize(accuracy_score)
self["balanced_accuracy"] = binarize(balanced_accuracy_score)
super().__init__()
for func in classifications:
name = func.__name__
if name.endswith("_score"):
name = name[: -len("_score")]
if func.binarize:
func = binarize(func)
self[name] = func

def __repr__(self):
"""
Expand Down

0 comments on commit b62e0ed

Please sign in to comment.