Allow user to specify colors for PrecisionRecallCurve (#1051)
This PR allows the user to specify colors for PrecisionRecallCurve and does some additional polishing of the visualizer for the multiclass case. In addition, it renames the colors property on the ClassificationScoreVisualizer class to class_colors_.
VladSkripniuk committed May 9, 2020
1 parent 8689a2e commit 1879734
Showing 18 changed files with 116 additions and 38 deletions.
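
For reference, the new usage looks like the following minimal sketch, adapted from the updated documentation in this diff; the dataset, encoders, and color choices mirror the docs example below, and nothing here goes beyond what the diff introduces.

from sklearn.naive_bayes import MultinomialNB
from sklearn.model_selection import train_test_split as tts
from sklearn.preprocessing import LabelEncoder, OrdinalEncoder

from yellowbrick.datasets import load_game
from yellowbrick.classifier import PrecisionRecallCurve

# Load the game dataset and encode the categorical features and target
X, y = load_game()
X = OrdinalEncoder().fit_transform(X)
encoder = LabelEncoder()
y = encoder.fit_transform(y)

X_train, X_test, y_train, y_test = tts(X, y, test_size=0.2, shuffle=True)

# New in this PR: explicit per-class colors via ``colors`` (a colormap can
# be passed via ``cmap`` instead; if both are given, ``cmap`` is ignored)
viz = PrecisionRecallCurve(
    MultinomialNB(),
    classes=encoder.classes_,
    colors=["purple", "cyan", "blue"],
    iso_f1_curves=True,
    per_class=True,
    micro=False,
)
viz.fit(X_train, y_train)
viz.score(X_test, y_test)
viz.show()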
18 changes: 14 additions & 4 deletions docs/api/classifier/prcurve.rst
@@ -48,7 +48,7 @@ The base case for precision-recall curves is the binary classification case, and
Multi-Label Classification
--------------------------

To support multi-label classification, the estimator is wrapped in a `OneVsRestClassifier <http://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html>`_ to produce binary comparisons for each class (e.g. the positive case is the class and the negative case is any other class). The Precision-Recall curve is then computed as the micro-average of the precision and recall for all classes:
To support multi-label classification, the estimator is wrapped in a `OneVsRestClassifier <http://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html>`_ to produce binary comparisons for each class (e.g. the positive case is the class and the negative case is any other class). The precision-recall curve can then be computed as the micro-average of the precision and recall for all classes (by setting ``micro=True``), or individual curves can be plotted for each class (by setting ``per_class=True``):

.. plot::
:context: close-figs
@@ -68,7 +68,11 @@ To support multi-label classification, the estimator is wrapped in a `OneVsRestC
X_train, X_test, y_train, y_test = tts(X, y, test_size=0.2, shuffle=True)

# Create the visualizer, fit, score, and show it
viz = PrecisionRecallCurve(RandomForestClassifier(n_estimators=10))
viz = PrecisionRecallCurve(
RandomForestClassifier(n_estimators=10),
per_class=True,
cmap="Set1"
)
viz.fit(X_train, y_train)
viz.score(X_test, y_test)
viz.show()
@@ -89,15 +93,21 @@ A more complex Precision-Recall curve can be computed, however, displaying the e
# Load dataset and encode categorical variables
X, y = load_game()
X = OrdinalEncoder().fit_transform(X)

# Encode the target (we'll use the encoder to retrieve the class labels)
encoder = LabelEncoder()
y = encoder.fit_transform(y)

X_train, X_test, y_train, y_test = tts(X, y, test_size=0.2, shuffle=True)

# Create the visualizer, fit, score, and show it
viz = PrecisionRecallCurve(
MultinomialNB(), per_class=True, iso_f1_curves=True,
fill_area=False, micro=False, classes=encoder.classes_
MultinomialNB(),
classes=encoder.classes_,
colors=["purple", "cyan", "blue"],
iso_f1_curves=True,
per_class=True,
micro=False
)
viz.fit(X_train, y_train)
viz.score(X_test, y_test)
10 changed files could not be displayed (invalid or binary files).
4 changes: 2 additions & 2 deletions tests/test_classifier/test_base.py
@@ -106,10 +106,10 @@ def test_colors_property(self):
oz = ClassificationScoreVisualizer(GaussianNB())

with pytest.raises(NotFitted, match="cannot determine colors before fit"):
oz.colors
oz.class_colors_

oz.fit(self.multiclass.X.train, self.multiclass.y.train)
assert len(oz.colors) == len(oz.classes_)
assert len(oz.class_colors_) == len(oz.classes_)

def test_decode_labels_warning(self):
"""
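
Downstream code that previously read ``oz.colors`` must now use ``oz.class_colors_``. A minimal sketch of the renamed property, following the test above; ``make_classification`` is just a stand-in for any multiclass dataset, and the import path is taken from yellowbrick/classifier/base.py below.

from sklearn.datasets import make_classification
from sklearn.naive_bayes import GaussianNB

from yellowbrick.classifier.base import ClassificationScoreVisualizer

# Illustrative multiclass data; any classification dataset works
X, y = make_classification(
    n_classes=3, n_informative=4, n_samples=200, random_state=42
)

oz = ClassificationScoreVisualizer(GaussianNB())
# Accessing class_colors_ before fit() raises NotFitted, as tested above
oz.fit(X, y)
# After fitting there is one color per discovered class
assert len(oz.class_colors_) == len(oz.classes_)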
21 changes: 15 additions & 6 deletions tests/test_classifier/test_prcurve.py
@@ -351,7 +351,7 @@ def test_quick_method(self):
fill_area=False,
iso_f1_curves=True,
ap_score=False,
show=False
show=False,
)

assert isinstance(oz, PrecisionRecallCurve)
@@ -435,11 +435,7 @@ def test_quick_method_with_test_set(self):
)

viz = precision_recall_curve(
RandomForestClassifier(random_state=72),
X_train,
y_train,
X_test,
y_test,
RandomForestClassifier(random_state=72), X_train, y_train, X_test, y_test
)
self.assert_images_similar(viz)

@@ -462,3 +458,16 @@ def test_missing_test_data_in_quick_method(self):

with pytest.raises(YellowbrickValueError, match=emsg):
precision_recall_curve(RandomForestClassifier(), X_train, y_train, X_test)

def test_per_class_and_micro(self):
"""
Test that when both per_class and micro are set to True, the user gets a warning that micro is ignored
"""
msg = (
"micro=True is ignored;"
"specify per_class=False to draw a PR curve after micro-averaging"
)
with pytest.warns(YellowbrickWarning, match=msg):
PrecisionRecallCurve(
RidgeClassifier(random_state=13), micro=True, per_class=True
)
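
From the user's side the new guard behaves as sketched below; RidgeClassifier is just a stand-in estimator, and the printed text matches the warning message added in prcurve.py further down.

import warnings

from sklearn.linear_model import RidgeClassifier
from yellowbrick.classifier import PrecisionRecallCurve

# Requesting both per-class curves and a micro-average triggers the new
# warning at construction time; micro is then ignored when drawing
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    PrecisionRecallCurve(RidgeClassifier(), micro=True, per_class=True)

print(caught[0].message)  # micro=True is ignored; specify per_class=False ...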
2 changes: 1 addition & 1 deletion yellowbrick/classifier/base.py
@@ -132,7 +132,7 @@ def __init__(
self.set_params(classes=classes, encoder=encoder, force_model=force_model)

@property
def colors(self):
def class_colors_(self):
"""
Returns ``_colors`` if it exists, otherwise computes a categorical color
per class based on the matplotlib color cycle. If the visualizer is not
3 changes: 2 additions & 1 deletion yellowbrick/classifier/class_prediction_error.py
@@ -199,7 +199,7 @@ def draw(self):
self.ax,
labels=list(self.classes_),
ticks=self.classes_,
colors=self.colors,
colors=self.class_colors_,
legend_kws=legend_kws,
)
return self.ax
@@ -230,6 +230,7 @@ def finalize(self, **kwargs):
# Ensure the legend fits on the figure
self.fig.tight_layout(rect=[0, 0, 0.90, 1])


##########################################################################
## Quick Method
##########################################################################
81 changes: 69 additions & 12 deletions yellowbrick/classifier/prcurve.py
@@ -17,6 +17,7 @@
## Imports
##########################################################################

import warnings
import numpy as np

from sklearn.preprocessing import label_binarize
@@ -25,6 +26,8 @@
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_curve as sk_precision_recall_curve

from yellowbrick.style.colors import resolve_colors
from yellowbrick.exceptions import YellowbrickWarning
from yellowbrick.exceptions import ModelError, NotFitted
from yellowbrick.exceptions import YellowbrickValueError
from yellowbrick.classifier.base import ClassificationScoreVisualizer
@@ -71,14 +74,26 @@ class PrecisionRecallCurve(ClassificationScoreVisualizer):
The axes to plot the figure on. If not specified the current axes will be
used (or generated if required).
classes : list of str, defult: None
classes : list of str, default: None
The class labels to use for the legend ordered by the index of the sorted
classes discovered in the ``fit()`` method. Specifying classes in this
manner is used to change the class names to a more specific format or
to label encoded integer classes. Some visualizers may also use this
field to filter the visualization for specific classes. For more advanced
usage specify an encoder rather than class labels.
colors : list of strings, default: None
An optional list or tuple of colors to colorize the curves when
``per_class=True``. If ``per_class=False``, this parameter will
be ignored. If both ``colors`` and ``cmap`` are provided,
``cmap`` will be ignored.
cmap : string or Matplotlib colormap, default: None
An optional string or Matplotlib colormap to colorize the curves
when ``per_class=True``. If ``per_class=False``, this parameter
will be ignored. If both ``colors`` and ``cmap`` are provided,
``cmap`` will be ignored.
encoder : dict or LabelEncoder, default: None
A mapping of classes to human readable labels. Often there is a mismatch
between desired class labels and those contained in the target variable
@@ -187,6 +202,8 @@ def __init__(
model,
ax=None,
classes=None,
colors=None,
cmap=None,
encoder=None,
fill_area=True,
ap_score=True,
@@ -214,6 +231,8 @@
self.set_params(
fill_area=fill_area,
ap_score=ap_score,
colors=colors,
cmap=cmap,
micro=micro,
iso_f1_curves=iso_f1_curves,
iso_f1_values=set(iso_f1_values),
@@ -222,6 +241,13 @@
line_opacity=line_opacity,
)

if self.micro and self.per_class:
warnings.warn(
"micro=True is ignored;"
"specify per_class=False to draw a PR curve after micro-averaging",
YellowbrickWarning,
)

def fit(self, X, y=None):
"""
Fit the classification model; if y is multi-class, then the estimator
@@ -312,6 +338,11 @@ def draw(self):
"""
Draws the precision-recall curves computed in score on the axes.
"""
# set the colors
self._colors = resolve_colors(
n_colors=len(self.classes_), colormap=self.cmap, colors=self.colors
)

if self.iso_f1_curves:
for f1 in self.iso_f1_values:
x = np.linspace(0.01, 1)
@@ -327,46 +358,54 @@ def _draw_binary(self):
"""
Draw the precision-recall curves in the binary case
"""
self._draw_pr_curve(self.recall_, self.precision_, label="binary PR curve")
self._draw_pr_curve(self.recall_, self.precision_, label="Binary PR curve")
self._draw_ap_score(self.score_)

def _draw_multiclass(self):
"""
Draw the precision-recall curves in the multiclass case
"""
# TODO: handle colors better with a mapping and user input
if self.per_class:

colors = dict(zip(self.classes_, self._colors))

for cls in self.classes_:
precision = self.precision_[cls]
recall = self.recall_[cls]

label = "PR for class {} (area={:0.2f})".format(cls, self.score_[cls])
self._draw_pr_curve(recall, precision, label=label)
self._draw_pr_curve(recall, precision, label=label, color=colors[cls])

if self.micro:
elif self.micro:
precision = self.precision_[MICRO]
recall = self.recall_[MICRO]
self._draw_pr_curve(recall, precision)
label = "Micro-average PR for all classes"
self._draw_pr_curve(recall, precision, label=label)

self._draw_ap_score(self.score_[MICRO])

def _draw_pr_curve(self, recall, precision, label=None):
def _draw_pr_curve(self, recall, precision, label=None, color=None):
"""
Helper function to draw a precision-recall curve with specified settings
"""
self.ax.step(
recall, precision, alpha=self.line_opacity, where="post", label=label
recall,
precision,
alpha=self.line_opacity,
where="post",
label=label,
color=color,
)
if self.fill_area:
if self.fill_area and not self.per_class:
self.ax.fill_between(
recall, precision, step="post", alpha=self.fill_opacity
recall, precision, step="post", alpha=self.fill_opacity, color=color
)

def _draw_ap_score(self, score, label=None):
"""
Helper function to draw the AP score annotation
"""
label = label or "Avg Precision={:0.2f}".format(score)
label = label or "Avg. precision={:0.2f}".format(score)
if self.ap_score:
self.ax.axhline(y=score, color="r", ls="--", label=label)

@@ -383,6 +422,8 @@ def finalize(self):
self.ax.set_ylabel("Precision")
self.ax.set_xlabel("Recall")

self.ax.grid(False)

def _get_y_scores(self, X):
"""
The ``precision_recall_curve`` metric requires target scores that
@@ -439,6 +480,8 @@ def precision_recall_curve(
y_test=None,
ax=None,
classes=None,
colors=None,
cmap=None,
encoder=None,
fill_area=True,
ap_score=True,
@@ -492,14 +535,26 @@
The axes to plot the figure on. If not specified the current axes will be
used (or generated if required).
classes : list of str, defult: None
classes : list of str, default: None
The class labels to use for the legend ordered by the index of the sorted
classes discovered in the ``fit()`` method. Specifying classes in this
manner is used to change the class names to a more specific format or
to label encoded integer classes. Some visualizers may also use this
field to filter the visualization for specific classes. For more advanced
usage specify an encoder rather than class labels.
colors : list of strings, default: None
An optional list or tuple of colors to colorize the curves when
``per_class=True``. If ``per_class=False``, this parameter will
be ignored. If both ``colors`` and ``cmap`` are provided,
``cmap`` will be ignored.
cmap : string or Matplotlib colormap, default: None
An optional string or Matplotlib colormap to colorize the curves
when ``per_class=True``. If ``per_class=False``, this parameter
will be ignored. If both ``colors`` and ``cmap`` are provided,
``cmap`` will be ignored.
encoder : dict or LabelEncoder, default: None
A mapping of classes to human readable labels. Often there is a mismatch
between desired class labels and those contained in the target variable
@@ -570,6 +625,8 @@ def precision_recall_curve(
model,
ax=ax,
classes=classes,
colors=colors,
cmap=cmap,
encoder=encoder,
fill_area=fill_area,
ap_score=ap_score,
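
The quick method gains the same ``colors`` and ``cmap`` parameters. A short sketch under the same dataset assumptions as the docs example above; ``cmap="Set1"`` is purely illustrative and only takes effect because ``per_class=True``.

from sklearn.model_selection import train_test_split as tts
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import LabelEncoder, OrdinalEncoder

from yellowbrick.datasets import load_game
from yellowbrick.classifier import precision_recall_curve

X, y = load_game()
X = OrdinalEncoder().fit_transform(X)
y = LabelEncoder().fit_transform(y)
X_train, X_test, y_train, y_test = tts(X, y, test_size=0.2, shuffle=True)

# The quick method fits, scores, and (by default) shows the visualizer;
# per-class curves are colored from the given colormap
viz = precision_recall_curve(
    MultinomialNB(),
    X_train, y_train, X_test, y_test,
    per_class=True,
    cmap="Set1",
)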
