Skip to content

Commit

Permalink
[Metric] Add quadratic kappa (#1104)
Browse files Browse the repository at this point in the history
* add test case

* Update test_classification_metrics.py

* Update __init__.py

* Update __init__.py

* Update classification_metrics.py
  • Loading branch information
sxjscience committed May 12, 2021
1 parent cbd4374 commit 405f175
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 3 deletions.
6 changes: 5 additions & 1 deletion core/src/autogluon/core/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,9 @@ def rmse_func(y_true, y_pred):
recall = make_scorer('recall',
sklearn.metrics.recall_score)

# Register other metrics
quadratic_kappa = make_scorer('quadratic_kappa', quadratic_kappa, needs_proba=False)


def customized_log_loss(y_true, y_pred, eps=1e-15):
"""
Expand Down Expand Up @@ -528,7 +531,8 @@ def customized_log_loss(y_true, y_pred, eps=1e-15):
QUANTILE_METRICS[alias] = scorer

CLASSIFICATION_METRICS = dict()
for scorer in [accuracy, balanced_accuracy, mcc, roc_auc, roc_auc_ovo_macro, average_precision, log_loss, pac_score]:
for scorer in [accuracy, balanced_accuracy, mcc, roc_auc, roc_auc_ovo_macro, average_precision,
log_loss, pac_score, quadratic_kappa]:
CLASSIFICATION_METRICS[scorer.name] = scorer
for alias in scorer.alias:
CLASSIFICATION_METRICS[alias] = scorer
Expand Down
36 changes: 35 additions & 1 deletion core/src/autogluon/core/metrics/classification_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from scipy.sparse import coo_matrix
from sklearn.utils.multiclass import unique_labels
from sklearn.utils import check_consistent_length
from sklearn.metrics import cohen_kappa_score

try:
from sklearn.metrics._classification import _check_targets, type_of_target
except:
Expand Down Expand Up @@ -331,4 +333,36 @@ def confusion_matrix(solution, prediction, labels=None, weights=None, normalize=
cm_df = pd.DataFrame(data=cm, index=labels, columns=labels)
return cm_df
else:
return cm
return cm

# TODO Add the "labels" option to metrics that will require the label map.
# We will need to update how we use those metrics accordingly.
def quadratic_kappa(y_true, y_pred):
"""Calculate the cohen kappa score with quadratic weighting scheme.
This is also known as "quadratic kappa" in the Kaggle competitions
such as petfinder: https://www.kaggle.com/c/petfinder-adoption-prediction/overview/evaluation
We will also support probabilistic input to ensure that the function knows
the number of possible classes.
Parameters
----------
y_true
Shape (#samples,)
y_pred
Shape (#samples, #class) or (#samples,)
Returns
-------
score
scalar score
"""
labels = None
if y_pred.ndim > 1:
if labels is not None:
assert len(labels) == y_pred.shape[1]
else:
labels = np.arange(y_pred.shape[1])
y_pred = np.argmax(y_pred, axis=-1)
return cohen_kappa_score(y_true, y_pred, labels=labels, weights='quadratic')
17 changes: 16 additions & 1 deletion core/tests/unittests/metrics/test_classification_metrics.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import pytest
import sklearn
from autogluon.core.metrics import confusion_matrix, log_loss
from autogluon.core.metrics import confusion_matrix, log_loss, quadratic_kappa


def test_confusion_matrix_with_valid_inputs_without_labels_and_weights():
Expand Down Expand Up @@ -173,3 +173,18 @@ def test_log_loss_with_sklearn(gt, probs):

ag_loss_as_sklearn = log_loss.convert_score_to_sklearn_val(ag_loss)
np.testing.assert_allclose(ag_loss_as_sklearn, sklearn_log_loss)


def test_quadratic_kappa():
actuals = np.array([4, 4, 3, 4, 4, 4, 1, 1, 2, 1])
preds = np.array([0, 2, 1, 0, 0, 0, 1, 1, 2, 1])
value = quadratic_kappa(actuals, preds)
assert round(value, 3) == -0.139

actuals = np.array([0, 1, 0, 1])
preds = np.array([[0.8, 0.1, 0.1],
[0.7, 0.1, 0.2],
[0.1, 0.8, 0.1],
[0.1, 0.1, 0.8]])
value = quadratic_kappa(actuals, preds)
assert value == 0.25

0 comments on commit 405f175

Please sign in to comment.