[MRG] Adds multiclass ROC AUC (scikit-learn#12789)
thomasjpfan authored and TomDLT committed Jul 25, 2019
1 parent 6da9b14 commit 1d9f033
Showing 9 changed files with 593 additions and 62 deletions.
74 changes: 70 additions & 4 deletions doc/modules/model_evaluation.rst
@@ -313,6 +313,7 @@ Others also work in the multiclass case:
confusion_matrix
hinge_loss
matthews_corrcoef
roc_auc_score


Some also work in the multilabel case:
@@ -331,6 +332,7 @@ Some also work in the multilabel case:
precision_recall_fscore_support
precision_score
recall_score
roc_auc_score
zero_one_loss

And some work with binary and multilabel (but not multiclass) problems:
@@ -339,7 +341,6 @@ And some work with binary and multilabel (but not multiclass) problems:
:template: function.rst

average_precision_score
roc_auc_score


In the following sub-sections, we will describe each of those functions,
@@ -1313,9 +1314,52 @@ In multi-label classification, the :func:`roc_auc_score` function is
extended by averaging over the labels as :ref:`above <average>`.

Compared to metrics such as the subset accuracy, the Hamming loss, or the
F1 score, ROC doesn't require optimizing a threshold for each label. The
:func:`roc_auc_score` function can also be used in multi-class classification,
if the predicted outputs have been binarized.
F1 score, ROC doesn't require optimizing a threshold for each label.

The :func:`roc_auc_score` function can also be used in multi-class
classification. Two averaging strategies are currently supported: the
one-vs-one algorithm computes the average of the pairwise ROC AUC scores, and
the one-vs-rest algorithm computes the average of the ROC AUC scores for each
class against all other classes. In both cases, the predicted labels are
provided in an array with values from 0 to ``n_classes - 1``, and the scores
correspond to the probability estimates that a sample belongs to a particular
class. The OvO and OvR algorithms support weighting uniformly
(``average='macro'``) and weighting by the prevalence (``average='weighted'``).
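
A minimal usage sketch (the fitted probabilistic classifier ``clf`` and the
held-out arrays ``X_test`` and ``y_test`` below are illustrative names, not
part of this section)::

    from sklearn.metrics import roc_auc_score

    y_score = clf.predict_proba(X_test)   # shape (n_samples, n_classes)
    macro_ovo_auc = roc_auc_score(y_test, y_score,
                                  multi_class='ovo', average='macro')
    weighted_ovr_auc = roc_auc_score(y_test, y_score,
                                     multi_class='ovr', average='weighted')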

**One-vs-one Algorithm**: Computes the average AUC of all possible pairwise
combinations of classes. [HT2001]_ defines a multiclass AUC metric weighted
uniformly:

.. math::

   \frac{1}{c(c-1)}\sum_{j=1}^{c}\sum_{k > j}^c (\text{AUC}(j | k) +
   \text{AUC}(k | j))

where :math:`c` is the number of classes and :math:`\text{AUC}(j | k)` is the
AUC with class :math:`j` as the positive class and class :math:`k` as the
negative class. In general,
:math:`\text{AUC}(j | k) \neq \text{AUC}(k | j)` in the multiclass
case. This algorithm is used by setting the keyword argument ``multi_class``
to ``'ovo'`` and ``average`` to ``'macro'``.
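
To make the definition concrete, a small sketch (illustrative only; it assumes
``y_test`` and ``y_score`` are NumPy arrays with labels encoded as
``0 .. n_classes - 1``, so a label also indexes the corresponding probability
column) computes the same uniformly weighted average from binary AUCs::

    from itertools import combinations
    import numpy as np
    from sklearn.metrics import roc_auc_score

    classes = np.unique(y_test)
    pair_scores = []
    for a, b in combinations(classes, 2):
        mask = np.isin(y_test, [a, b])           # keep only samples from this pair
        auc_a = roc_auc_score(y_test[mask] == a, y_score[mask, a])
        auc_b = roc_auc_score(y_test[mask] == b, y_score[mask, b])
        pair_scores.append((auc_a + auc_b) / 2)  # average of AUC(a|b) and AUC(b|a)
    macro_ovo_auc = np.mean(pair_scores)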

The [HT2001]_ multiclass AUC metric can be extended to be weighted by the
prevalence:

.. math::

   \frac{1}{2(c-1)}\sum_{j=1}^{c}\sum_{k > j}^c p(j \cup k)(
   \text{AUC}(j | k) + \text{AUC}(k | j))

where :math:`c` is the number of classes and :math:`p(j \cup k)` is the
prevalence of the pair, i.e. the fraction of samples whose true label is
either :math:`j` or :math:`k` (these weights sum to :math:`c - 1` over all
pairs). This algorithm is used by setting the keyword argument ``multi_class``
to ``'ovo'`` and ``average`` to ``'weighted'``. The ``'weighted'`` option
returns a prevalence-weighted average as described in [FC2009]_.
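
Continuing the sketch above, the prevalence-weighted variant replaces the
unweighted mean with an average weighted by the fraction of samples in each
pair of classes (again purely illustrative, reusing ``classes``,
``pair_scores`` and ``y_test`` from the previous snippet)::

    prevalence = [np.mean(np.isin(y_test, [a, b]))
                  for a, b in combinations(classes, 2)]
    weighted_ovo_auc = np.average(pair_scores, weights=prevalence)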

**One-vs-rest Algorithm**: Computes the AUC of each class against the rest.
The algorithm is functionally the same as the multilabel case. To enable this
algorithm, set the keyword argument ``multi_class`` to ``'ovr'``. Like
OvO, OvR supports two types of averaging: ``'macro'`` [F2006]_ and
``'weighted'`` [F2001]_.
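
As a rough consistency check (illustrative; it reuses ``numpy`` and
:func:`roc_auc_score` from the sketches above and assumes labels encoded as
``0 .. n_classes - 1``), the OvR score can also be obtained by binarizing the
target and using the multilabel path::

    from sklearn.preprocessing import label_binarize

    ovr_auc = roc_auc_score(y_test, y_score, multi_class='ovr', average='macro')
    y_test_bin = label_binarize(y_test, classes=np.arange(y_score.shape[1]))
    # should match ovr_auc up to floating point error
    ovr_auc_bin = roc_auc_score(y_test_bin, y_score, average='macro')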

In applications where a high false positive rate is not tolerable, the parameter
``max_fpr`` of :func:`roc_auc_score` can be used to summarize the ROC curve up
@@ -1341,6 +1385,28 @@ to the given limit.
for an example of using ROC to
model species distribution.
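
A minimal sketch of the ``max_fpr`` option (binary case; ``y_true`` and
``y_scores`` are illustrative names for the true binary labels and the
decision scores)::

    from sklearn.metrics import roc_auc_score

    # standardized partial AUC over the region where FPR <= 0.2
    partial_auc = roc_auc_score(y_true, y_scores, max_fpr=0.2)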

.. topic:: References:

  .. [HT2001] Hand, D.J. and Till, R.J., 2001. `A simple generalisation
     of the area under the ROC curve for multiple class classification problems.
     <http://link.springer.com/article/10.1023/A:1010920819831>`_
     Machine Learning, 45(2), pp. 171-186.

  .. [FC2009] Ferri, C., Hernandez-Orallo, J. and Modroiu, R., 2009.
     `An Experimental Comparison of Performance Measures for Classification.
     <https://www.math.ucdavis.edu/~saito/data/roc/ferri-class-perf-metrics.pdf>`_
     Pattern Recognition Letters, 30, pp. 27-38.

  .. [F2006] Fawcett, T., 2006. `An introduction to ROC analysis.
     <http://www.sciencedirect.com/science/article/pii/S016786550500303X>`_
     Pattern Recognition Letters, 27(8), pp. 861-874.

  .. [F2001] Fawcett, T., 2001. `Using rule sets to maximize
     ROC performance <http://ieeexplore.ieee.org/document/989510/>`_
     In Data Mining, 2001. Proceedings IEEE International Conference, pp. 131-138.

.. _zero_one_loss:

Zero one loss
7 changes: 7 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -131,6 +131,13 @@ Changelog
- |API| Deprecate ``training_data_`` unused attribute in
:class:`manifold.Isomap`. :issue:`10482` by `Tom Dupre la Tour`_.

:mod:`sklearn.metrics`
......................

- |Feature| Added multiclass support to :func:`metrics.roc_auc_score`.
:issue:`12789` by :user:`Kathy Chen <kathyxchen>`,
:user:`Mohamed Maskani <maskani-moh>`, and :user:`Thomas Fan <thomasjpfan>`.

:mod:`sklearn.model_selection`
..............................

42 changes: 33 additions & 9 deletions examples/model_selection/plot_roc.py
@@ -15,24 +15,21 @@
The "steepness" of ROC curves is also important, since it is ideal to maximize
the true positive rate while minimizing the false positive rate.
Multiclass settings
-------------------
ROC curves are typically used in binary classification to study the output of
a classifier. In order to extend ROC curve and ROC area to multi-class
or multi-label classification, it is necessary to binarize the output. One ROC
a classifier. In order to extend ROC curve and ROC area to multi-label
classification, it is necessary to binarize the output. One ROC
curve can be drawn per label, but one can also draw a ROC curve by considering
each element of the label indicator matrix as a binary prediction
(micro-averaging).
Another evaluation measure for multi-class classification is
Another evaluation measure for multi-label classification is
macro-averaging, which gives equal weight to the classification of each
label.
.. note::
See also :func:`sklearn.metrics.roc_auc_score`,
:ref:`sphx_glr_auto_examples_model_selection_plot_roc_crossval.py`.
:ref:`sphx_glr_auto_examples_model_selection_plot_roc_crossval.py`
"""
print(__doc__)
@@ -47,6 +44,7 @@
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from scipy import interp
from sklearn.metrics import roc_auc_score

# Import some data to play with
iris = datasets.load_iris()
@@ -101,8 +99,8 @@


##############################################################################
# Plot ROC curves for the multiclass problem

# Plot ROC curves for the multilabel problem
# ..........................................
# Compute macro-average ROC curve and ROC area

# First aggregate all false positive rates
@@ -146,3 +144,29 @@
plt.title('Some extension of Receiver operating characteristic to multi-class')
plt.legend(loc="lower right")
plt.show()


##############################################################################
# Area under ROC for the multiclass problem
# .........................................
# The :func:`sklearn.metrics.roc_auc_score` function can be used for
# multi-class classification. The multiclass One-vs-One scheme compares every
# unique pairwise combination of classes. In this section, we calculate the AUC
# using the OvR and OvO schemes. We report a macro average and a
# prevalence-weighted average.
y_prob = classifier.predict_proba(X_test)

macro_roc_auc_ovo = roc_auc_score(y_test, y_prob, multi_class="ovo",
average="macro")
weighted_roc_auc_ovo = roc_auc_score(y_test, y_prob, multi_class="ovo",
average="weighted")
macro_roc_auc_ovr = roc_auc_score(y_test, y_prob, multi_class="ovr",
average="macro")
weighted_roc_auc_ovr = roc_auc_score(y_test, y_prob, multi_class="ovr",
average="weighted")
print("One-vs-One ROC AUC scores:\n{:.6f} (macro),\n{:.6f} "
"(weighted by prevalence)"
.format(macro_roc_auc_ovo, weighted_roc_auc_ovo))
print("One-vs-Rest ROC AUC scores:\n{:.6f} (macro),\n{:.6f} "
"(weighted by prevalence)"
.format(macro_roc_auc_ovr, weighted_roc_auc_ovr))
72 changes: 72 additions & 0 deletions sklearn/metrics/base.py
@@ -12,6 +12,7 @@
# Noel Dawe <noel@dawe.me>
# License: BSD 3 clause

from itertools import combinations

import numpy as np

@@ -123,3 +124,74 @@ def _average_binary_score(binary_metric, y_true, y_score, average,
return np.average(score, weights=average_weight)
else:
return score


def _average_multiclass_ovo_score(binary_metric, y_true, y_score,
average='macro'):
"""Average one-versus-one scores for multiclass classification.
Uses the binary metric for one-vs-one multiclass classification,
where the score is computed according to the Hand & Till (2001) algorithm.
Parameters
----------
binary_metric : callable
The binary metric function to use that accepts the following as input
y_true_target : array, shape = [n_samples_target]
Some sub-array of y_true for a pair of classes designated
positive and negative in the one-vs-one scheme.
y_score_target : array, shape = [n_samples_target]
Scores corresponding to the probability estimates
of a sample belonging to the designated positive class label
y_true : array-like, shape = (n_samples, )
True multiclass labels.
y_score : array-like, shape = (n_samples, n_classes)
Target scores corresponding to probability estimates of a sample
belonging to a particular class
average : 'macro' or 'weighted', optional (default='macro')
Determines the type of averaging performed on the pairwise binary
metric scores
``'macro'``:
Calculate metrics for each label, and find their unweighted
mean. This does not take label imbalance into account. Classes
are assumed to be uniformly distributed.
``'weighted'``:
Calculate metrics for each label, taking into account the
prevalence of the classes.
Returns
-------
score : float
Average of the pairwise binary metric scores
"""
check_consistent_length(y_true, y_score)

y_true_unique = np.unique(y_true)
n_classes = y_true_unique.shape[0]
n_pairs = n_classes * (n_classes - 1) // 2
pair_scores = np.empty(n_pairs)

is_weighted = average == "weighted"
prevalence = np.empty(n_pairs) if is_weighted else None

# Compute scores treating a as positive class and b as negative class,
# then b as positive class and a as negative class
for ix, (a, b) in enumerate(combinations(y_true_unique, 2)):
a_mask = y_true == a
b_mask = y_true == b
ab_mask = np.logical_or(a_mask, b_mask)

if is_weighted:
prevalence[ix] = np.average(ab_mask)

a_true = a_mask[ab_mask]
b_true = b_mask[ab_mask]

a_true_score = binary_metric(a_true, y_score[ab_mask, a])
b_true_score = binary_metric(b_true, y_score[ab_mask, b])
pair_scores[ix] = (a_true_score + b_true_score) / 2

return np.average(pair_scores, weights=prevalence)
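
# Illustrative usage sketch (not part of this module): the helper can be driven
# directly by a binary scorer such as roc_auc_score, assuming ``y`` is a label
# vector encoded as 0..n_classes-1 and ``proba`` is the matching
# (n_samples, n_classes) probability matrix, e.g.
#
#     from sklearn.metrics import roc_auc_score
#     ovo_macro = _average_multiclass_ovo_score(roc_auc_score, y, proba,
#                                               average='macro')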
