Merge 96ff189 into 8689a2e
VladSkripniuk committed May 8, 2020
2 parents 8689a2e + 96ff189 commit e6b9464
Showing 17 changed files with 81 additions and 24 deletions.
18 changes: 14 additions & 4 deletions docs/api/classifier/prcurve.rst
@@ -48,7 +48,7 @@ The base case for precision-recall curves is the binary classification case, and
Multi-Label Classification
--------------------------

To support multi-label classification, the estimator is wrapped in a `OneVsRestClassifier <http://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html>`_ to produce binary comparisons for each class (e.g. the positive case is the class and the negative case is any other class). The Precision-Recall curve is then computed as the micro-average of the precision and recall for all classes:
To support multi-label classification, the estimator is wrapped in a `OneVsRestClassifier <http://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html>`_ to produce binary comparisons for each class (e.g. the positive case is the class and the negative case is any other class). The precision-recall curve can then be computed as the micro-average of the precision and recall for all classes (by setting ``micro=True``), or individual curves can be plotted for each class (by setting ``per_class=True``):

.. plot::
:context: close-figs
@@ -68,7 +68,11 @@ To support multi-label classification, the estimator is wrapped in a `OneVsRestC
X_train, X_test, y_train, y_test = tts(X, y, test_size=0.2, shuffle=True)

# Create the visualizer, fit, score, and show it
viz = PrecisionRecallCurve(RandomForestClassifier(n_estimators=10))
viz = PrecisionRecallCurve(
RandomForestClassifier(n_estimators=10),
per_class=True,
cmap="Set1"
)
viz.fit(X_train, y_train)
viz.score(X_test, y_test)
viz.show()
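
For comparison, a minimal sketch (not part of this diff) of the micro-averaged variant described above, reusing the same imports and train/test split as the example:

    # Hedged sketch: micro-average PR curve, per the docs text above
    viz = PrecisionRecallCurve(
        RandomForestClassifier(n_estimators=10),
        micro=True
    )
    viz.fit(X_train, y_train)
    viz.score(X_test, y_test)
    viz.show()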
@@ -89,15 +93,21 @@ A more complex Precision-Recall curve can be computed, however, displaying the e
# Load dataset and encode categorical variables
X, y = load_game()
X = OrdinalEncoder().fit_transform(X)

# Encode the target (we'll use the encoder to retrieve the class labels)
encoder = LabelEncoder()
y = encoder.fit_transform(y)

X_train, X_test, y_train, y_test = tts(X, y, test_size=0.2, shuffle=True)

# Create the visualizer, fit, score, and show it
viz = PrecisionRecallCurve(
MultinomialNB(), per_class=True, iso_f1_curves=True,
fill_area=False, micro=False, classes=encoder.classes_
MultinomialNB(),
classes=encoder.classes_,
colors=["purple", "cyan", "blue"],
iso_f1_curves=True,
per_class=True,
micro=False
)
viz.fit(X_train, y_train)
viz.score(X_test, y_test)
10 binary files changed (not displayable in the diff view)
4 changes: 2 additions & 2 deletions tests/test_classifier/test_base.py
@@ -106,10 +106,10 @@ def test_colors_property(self):
oz = ClassificationScoreVisualizer(GaussianNB())

with pytest.raises(NotFitted, match="cannot determine colors before fit"):
oz.colors
oz.class_colors_

oz.fit(self.multiclass.X.train, self.multiclass.y.train)
assert len(oz.colors) == len(oz.classes_)
assert len(oz.class_colors_) == len(oz.classes_)

def test_decode_labels_warning(self):
"""
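For context, a hypothetical sketch (toy data, not part of this diff) showing how the renamed property reads after fit:

    from sklearn.datasets import make_classification
    from sklearn.naive_bayes import GaussianNB
    from yellowbrick.classifier.base import ClassificationScoreVisualizer

    X, y = make_classification(n_classes=3, n_informative=4)
    oz = ClassificationScoreVisualizer(GaussianNB())
    oz.fit(X, y)  # classes_ and the color mapping are set during fit
    assert len(oz.class_colors_) == len(oz.classes_)  # one color per class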
2 changes: 1 addition & 1 deletion yellowbrick/classifier/base.py
@@ -132,7 +132,7 @@ def __init__(
self.set_params(classes=classes, encoder=encoder, force_model=force_model)

@property
def colors(self):
def class_colors_(self):
"""
Returns ``_colors`` if it exists, otherwise computes a categorical color
per class based on the matplotlib color cycle. If the visualizer is not
2 changes: 1 addition & 1 deletion yellowbrick/classifier/class_prediction_error.py
@@ -199,7 +199,7 @@ def draw(self):
self.ax,
labels=list(self.classes_),
ticks=self.classes_,
colors=self.colors,
colors=self.class_colors_,
legend_kws=legend_kws,
)
return self.ax
71 changes: 59 additions & 12 deletions yellowbrick/classifier/prcurve.py
@@ -28,6 +28,7 @@
from yellowbrick.exceptions import ModelError, NotFitted
from yellowbrick.exceptions import YellowbrickValueError
from yellowbrick.classifier.base import ClassificationScoreVisualizer
from yellowbrick.style.colors import resolve_colors


# Target Type Constants
@@ -71,14 +72,26 @@ class PrecisionRecallCurve(ClassificationScoreVisualizer):
The axes to plot the figure on. If not specified the current axes will be
used (or generated if required).
classes : list of str, defult: None
classes : list of str, default: None
The class labels to use for the legend ordered by the index of the sorted
classes discovered in the ``fit()`` method. Specifying classes in this
manner is used to change the class names to a more specific format or
to label encoded integer classes. Some visualizers may also use this
field to filter the visualization for specific classes. For more advanced
usage specify an encoder rather than class labels.
colors : list of strings, default: None
An optional list or tuple of colors to colorize the curves when
``per_class=True``. If ``per_class=False``, this parameter will
be ignored. If both ``colors`` and ``cmap`` are provided,
``cmap`` will be ignored.
cmap : string or Matplotlib colormap, default: None
An optional string or Matplotlib colormap to colorize the curves
when ``per_class=True``. If ``per_class=False``, this parameter
will be ignored. If both ``colors`` and ``cmap`` are provided,
``cmap`` will be ignored.
encoder : dict or LabelEncoder, default: None
A mapping of classes to human readable labels. Often there is a mismatch
between desired class labels and those contained in the target variable
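
A quick sketch (hypothetical argument values, not part of this diff) of the precedence rule documented above, where explicit ``colors`` win and ``cmap`` is ignored:

    from sklearn.naive_bayes import MultinomialNB
    from yellowbrick.classifier import PrecisionRecallCurve

    viz = PrecisionRecallCurve(
        MultinomialNB(),
        per_class=True,
        colors=["purple", "cyan", "blue"],  # used
        cmap="Set1",  # ignored because colors is also provided
    )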
@@ -187,6 +200,8 @@ def __init__(
model,
ax=None,
classes=None,
colors=None,
cmap=None,
encoder=None,
fill_area=True,
ap_score=True,
@@ -214,6 +229,8 @@
self.set_params(
fill_area=fill_area,
ap_score=ap_score,
colors=colors,
cmap=cmap,
micro=micro,
iso_f1_curves=iso_f1_curves,
iso_f1_values=set(iso_f1_values),
@@ -312,6 +329,13 @@ def draw(self):
"""
Draws the precision-recall curves computed in score on the axes.
"""
# set the colors
self._colors = resolve_colors(
n_colors=len(self.classes_),
colormap=self.cmap,
colors=self.colors
)

if self.iso_f1_curves:
for f1 in self.iso_f1_values:
x = np.linspace(0.01, 1)
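
For context on the truncated loop above: an iso-F1 curve fixes F1 and traces the precision that yields that score at each recall. Since F1 = 2pr / (p + r), solving for precision gives p = F1 * r / (2r - F1). A standalone sketch (not this diff's code):

    import numpy as np

    f1 = 0.8
    x = np.linspace(0.01, 1)   # recall values
    y = f1 * x / (2 * x - f1)  # precision holding F1 constant
    # only points with 0 < y <= 1 are meaningful; x < f1 / 2 gives y < 0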
@@ -327,46 +351,51 @@ def _draw_binary(self):
"""
Draw the precision-recall curves in the binary case
"""
self._draw_pr_curve(self.recall_, self.precision_, label="binary PR curve")
self._draw_pr_curve(self.recall_, self.precision_, label="Binary PR curve")
self._draw_ap_score(self.score_)

def _draw_multiclass(self):
"""
Draw the precision-recall curves in the multiclass case
"""
# TODO: handle colors better with a mapping and user input
if self.per_class:

colors = dict(zip(self.classes_, self._colors))

for cls in self.classes_:
precision = self.precision_[cls]
recall = self.recall_[cls]

label = "PR for class {} (area={:0.2f})".format(cls, self.score_[cls])
self._draw_pr_curve(recall, precision, label=label)
self._draw_pr_curve(recall, precision, label=label, color=colors[cls])

if self.micro:
elif self.micro:
precision = self.precision_[MICRO]
recall = self.recall_[MICRO]
self._draw_pr_curve(recall, precision)
label = "Micro-average PR for all classes"
self._draw_pr_curve(recall, precision, label=label)

self._draw_ap_score(self.score_[MICRO])

def _draw_pr_curve(self, recall, precision, label=None):
def _draw_pr_curve(self, recall, precision, label=None, color=None):
"""
Helper function to draw a precision-recall curve with specified settings
"""
self.ax.step(
recall, precision, alpha=self.line_opacity, where="post", label=label
recall, precision, alpha=self.line_opacity, where="post",
label=label, color=color
)
if self.fill_area:
if self.fill_area and not self.per_class:
self.ax.fill_between(
recall, precision, step="post", alpha=self.fill_opacity
recall, precision, step="post", alpha=self.fill_opacity,
color=color
)

def _draw_ap_score(self, score, label=None):
"""
Helper function to draw the AP score annotation
"""
label = label or "Avg Precision={:0.2f}".format(score)
label = label or "Avg. precision={:0.2f}".format(score)
if self.ap_score:
self.ax.axhline(y=score, color="r", ls="--", label=label)

@@ -383,6 +412,8 @@ def finalize(self):
self.ax.set_ylabel("Precision")
self.ax.set_xlabel("Recall")

self.ax.grid(False)

def _get_y_scores(self, X):
"""
The ``precision_recall_curve`` metric requires target scores that
@@ -439,6 +470,8 @@ def precision_recall_curve(
y_test=None,
ax=None,
classes=None,
colors=None,
cmap=None,
encoder=None,
fill_area=True,
ap_score=True,
@@ -492,14 +525,26 @@
The axes to plot the figure on. If not specified the current axes will be
used (or generated if required).
classes : list of str, defult: None
classes : list of str, default: None
The class labels to use for the legend ordered by the index of the sorted
classes discovered in the ``fit()`` method. Specifying classes in this
manner is used to change the class names to a more specific format or
to label encoded integer classes. Some visualizers may also use this
field to filter the visualization for specific classes. For more advanced
usage specify an encoder rather than class labels.
colors : list of strings, default: None
An optional list or tuple of colors to colorize the curves when
``per_class=True``. If ``per_class=False``, this parameter will
be ignored. If both ``colors`` and ``cmap`` are provided,
``cmap`` will be ignored.
cmap : string or Matplotlib colormap, default: None
An optional string or Matplotlib colormap to colorize the curves
when ``per_class=True``. If ``per_class=False``, this parameter
will be ignored. If both ``colors`` and ``cmap`` are provided,
``cmap`` will be ignored.
encoder : dict or LabelEncoder, default: None
A mapping of classes to human readable labels. Often there is a mismatch
between desired class labels and those contained in the target variable
@@ -570,6 +615,8 @@ def precision_recall_curve(
model,
ax=ax,
classes=classes,
colors=colors,
cmap=cmap,
encoder=encoder,
fill_area=fill_area,
ap_score=ap_score,
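Putting the new parameters together, a hypothetical one-call sketch of the quick method (argument order assumed from the signature fragment above; a train/test split like the docs example is reused):

    from sklearn.naive_bayes import MultinomialNB
    from yellowbrick.classifier.prcurve import precision_recall_curve

    viz = precision_recall_curve(
        MultinomialNB(), X_train, y_train, X_test, y_test,
        per_class=True, cmap="Set1"
    )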
6 changes: 3 additions & 3 deletions yellowbrick/classifier/rocauc.py
@@ -356,7 +356,7 @@ def draw(self):
-------
ax : the axis with the plotted figure
"""
colors = self.colors[0 : len(self.classes_)]
colors = self.class_colors_[0 : len(self.classes_)]
n_classes = len(colors)

# If it's a binary decision, plot the single ROC curve
@@ -385,7 +385,7 @@ def draw(self):
self.fpr[MICRO],
self.tpr[MICRO],
linestyle="--",
color=self.colors[len(self.classes_) - 1],
color=self.class_colors_[len(self.classes_) - 1],
label="micro-average ROC curve, AUC = {:0.2f}".format(
self.roc_auc["micro"]
),
@@ -397,7 +397,7 @@ def draw(self):
self.fpr[MACRO],
self.tpr[MACRO],
linestyle="--",
color=self.colors[len(self.classes_) - 1],
color=self.class_colors_[len(self.classes_) - 1],
label="macro-average ROC curve, AUC = {:0.2f}".format(
self.roc_auc["macro"]
),
2 changes: 1 addition & 1 deletion yellowbrick/contrib/classifier/boundaries.py
@@ -428,7 +428,7 @@ def draw(self, X, y=None, **kwargs):
X = self._select_feature_columns(X)

color_cycle = iter(
resolve_colors(colors=self.colors, n_colors=len(self.classes_))
resolve_colors(colors=self.class_colors_, n_colors=len(self.classes_))
)
colors = OrderedDict([(c, next(color_cycle)) for c in self.classes_.keys()])

