Initial version of ConfusionMatrix visualizer using the matplotlib pcolormesh.

- Allows for percent or raw count representation of the predictions (see the usage sketch after this list)
- Implements a heatmap with white = 0, green = 100%, and a yellow-orange-red scale for everything else
- Allows zooming in on the confusion matrix using a passed list of classes, with accurate %-of-all-true calculations
- Tested for moderately large class counts (30+)
- A diagonal line indicates accurate predictions
- Documentation added to docs/examples/methods.rst for one example matrix
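
A minimal usage sketch of the display options above (a sketch only; it assumes a classifier model and train/test splits already exist, and the names follow the API added in this commit):

    cm = ConfusionMatrix(model, classes=[0, 1, 2])   # zoom in on a subset of classes
    cm.fit(X_train, y_train)
    cm.score(X_test, y_test, percent=False)          # raw counts instead of percent-of-true
    cm.poof()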

Suggested future improvements:
- Resize font based on image size + class count
- Allow custom color coding, including custom colors for _over and _under values (e.g. zero and 100%)
- Vary text font color based on background color
- While this branch currently adds an example to methods.rst, the examples/confusionMatrix.ipynb notebook has additional examples using different combinations of the passed parameters. This should probably also be exported as rst and added to the docs, but there was not an obvious place to put it, so I am excluding that for now.

Note this commit squashes all previous commits on this branch
NealHumphrey committed Mar 16, 2017
1 parent 942a070 commit ceee7f8
Showing 6 changed files with 677 additions and 0 deletions.
Binary file added docs/examples/images/confusionMatrix_3_0.png
59 changes: 59 additions & 0 deletions docs/examples/methods.rst
@@ -432,6 +432,65 @@ heatmap in order for easy interpretation and detection.
.. image:: images/examples_32_0.png


Confusion Matrix Visualizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~

The ``ConfusionMatrix`` visualizer is a ScoreVisualizer that takes a
fitted scikit-learn classifier and a set of test X and y values and
returns a report showing how each test value's predicted class
compares to its actual class. Data scientists use confusion matrices
to understand which classes are most easily confused. They provide
similar information to a ClassificationReport, but rather than
top-level scores they offer deeper insight into the classification
of individual data points.

Below are a few examples of using the ConfusionMatrix visualizer; more
information can be found in the ``sklearn.metrics.confusion_matrix``
documentation.
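
For orientation, the underlying function returns an array whose rows
correspond to true classes and whose columns to predicted classes; a
minimal illustration (the data here is arbitrary):

.. code:: python

    from sklearn.metrics import confusion_matrix

    y_true = [0, 0, 1, 1, 2]
    y_pred = [0, 1, 1, 1, 2]
    confusion_matrix(y_true, y_pred)
    # array([[1, 1, 0],
    #        [0, 2, 0],
    #        [0, 0, 1]])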

.. code:: python

    #First do our imports
    import yellowbrick
    from sklearn.datasets import load_digits
    from sklearn.cross_validation import train_test_split
    from sklearn.linear_model import LogisticRegression
    from yellowbrick.classifier import ConfusionMatrix
.. code:: python

    # We'll use the handwritten digits data set from scikit-learn.
    # Each feature of this dataset is an 8x8 pixel image of a handwritten number.
    # Digits.data converts these 64 pixels into a single array of features
    digits = load_digits()
    X = digits.data
    y = digits.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=11)

    model = LogisticRegression()

    #The ConfusionMatrix visualizer takes a model
    cm = ConfusionMatrix(model, classes=[0,1,2,3,4,5,6,7,8,9])

    #Fit fits the passed model. This is unnecessary if you pass the visualizer a pre-fitted model
    cm.fit(X_train, y_train)

    #To create the ConfusionMatrix, we need some test data. Score runs predict() on the data
    #and then creates the confusion_matrix from scikit-learn.
    cm.score(X_test, y_test)

    #How did we do?
    cm.poof()
.. image:: images/confusionMatrix_3_0.png
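
By default each cell shows the percent of its true class; passing
``percent=False`` to ``score()`` shows raw counts instead (a short
sketch reusing the fitted visualizer from above):

.. code:: python

    # Display raw prediction counts rather than percents
    cm.score(X_test, y_test, percent=False)
    cm.poof()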



ROCAUC
~~~~~~

313 changes: 313 additions & 0 deletions examples/confusionMatrix.ipynb

Large diffs are not rendered by default.

55 changes: 55 additions & 0 deletions tests/test_classifier.py
@@ -22,6 +22,9 @@

from sklearn.svm import LinearSVC
from sklearn.metrics import *
from sklearn.datasets import load_digits
from sklearn.cross_validation import train_test_split
from sklearn.linear_model import LogisticRegression

##########################################################################
## Data
@@ -68,3 +71,55 @@ def test_class_report(self):
model.fit(X,y)
visualizer = ClassificationReport(model, classes=["A", "B"])
visualizer.score(X,y)

class ConfusionMatrixTests(VisualTestCase):
def __init__(self, *args, **kwargs):
super(ConfusionMatrixTests, self).__init__(*args, **kwargs)
#Use the same data for all the tests
self.digits = load_digits()

X = self.digits.data
y = self.digits.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=11)
self.X_train = X_train
self.X_test = X_test
self.y_train = y_train
self.y_test = y_test

def test_confusion_matrix(self):
model = LogisticRegression()
cm = ConfusionMatrix(model, classes=[0,1,2,3,4,5,6,7,8,9])
cm.fit(self.X_train, self.y_train)
cm.score(self.X_test, self.y_test)

def test_no_classes_provided(self):
model = LogisticRegression()
cm = ConfusionMatrix(model)
cm.fit(self.X_train, self.y_train)
cm.score(self.X_test, self.y_test)

def test_raw_count_mode(self):
model = LogisticRegression()
cm = ConfusionMatrix(model)
cm.fit(self.X_train, self.y_train)
cm.score(self.X_test, self.y_test, percent=False)

def test_zoomed_in(self):
model = LogisticRegression()
cm = ConfusionMatrix(model, classes=[0,1,2])
cm.fit(self.X_train, self.y_train)
cm.score(self.X_test, self.y_test)

def test_extra_classes(self):
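# Class 11 never appears in the digits target values (0-9), so its count should be zero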
model = LogisticRegression()
cm = ConfusionMatrix(model, classes=[0,1,2,11])
cm.fit(self.X_train, self.y_train)
cm.score(self.X_test, self.y_test)
self.assertEqual(cm.selected_class_counts[3], 0)

def test_one_class(self):
model = LogisticRegression()
cm = ConfusionMatrix(model, classes=[0])
cm.fit(self.X_train, self.y_train)
cm.score(self.X_test, self.y_test)
238 changes: 238 additions & 0 deletions yellowbrick/classifier.py
@@ -25,12 +25,15 @@
from sklearn.cross_validation import train_test_split
from sklearn.metrics import auc, roc_auc_score, roc_curve
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import confusion_matrix

from .exceptions import YellowbrickTypeError
from .utils import get_model_name, isestimator, isclassifier
from .base import Visualizer, ScoreVisualizer, MultiModelMixin
from .style.palettes import color_sequence, color_palette, LINE_COLOR

from .utils import numpy_div0


##########################################################################
## Classification Visualization Base Object
@@ -50,6 +53,241 @@ def __init__(self, model, ax=None, **kwargs):

super(ClassificationScoreVisualizer, self).__init__(model, ax=ax, **kwargs)

#TODO during refactoring this can be used to generalize ClassBalance
def class_counts(self, y):
unique, counts = np.unique(y, return_counts=True)
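# e.g. y = [0, 0, 1] yields unique = [0, 1], counts = [2, 1] -> {0: 2, 1: 1}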
return dict(zip(unique, counts))

##########################################################################
## ConfusionMatrix
##########################################################################

class ConfusionMatrix(ClassificationScoreVisualizer):
"""
Creates a heatmap visualization of the sklearn.metrics.confusion_matrix().
Initialization: Requires a classification model
"""
def __init__(self, model, ax=None, classes=None, **kwargs):
"""
Provide a classifier model
Parameters
----------
:param model: the Scikit-Learn estimator
Should be an instance of a classifier, else the __init__ will
raise an error.
:param ax: the matplotlib axis to plot the figure on (if None, a new axis will be created)
:param classes: a list of class names to use in the confusion_matrix.
This is passed to the 'labels' parameter of sklearn.metrics.confusion_matrix(), and follows the behaviour
indicated by that function. It may be used to reorder or select a subset of labels.
If None, values that appear at least once in y_true or y_pred are used in sorted order.
Default: None
"""
super(ConfusionMatrix, self).__init__(model, ax=ax, classes=None, **kwargs)
#Parameters provided by super (for reference during development only):
#self.ax
#self.size
#self.color
#self.title
#self.estimator
#self.name

#Initialize all the other attributes we'll use (for coder clarity)
self.confusion_matrix = None

self.cmap = color_sequence(kwargs.pop('cmap', 'YlOrRd'))
self.cmap.set_under(color='w')
self.cmap.set_over(color='#2a7d4f')
self.edgecolors=[] #used to draw diagonal line for predicted class = true class


#Convert list to array if necessary, since estimator.classes_ returns nparray
self._classes = None if classes is None else np.array(classes)

#TODO hoist this to shared confusion matrix / classification report heatmap class
@property
def classes(self):
'''
Returns a numpy array of the classes in y
Matches the user-provided list if one was given in __init__
If no list was provided, tries to obtain the classes from the fitted estimator
'''
if self._classes is None:
try:
return self.estimator.classes_
except AttributeError:
return None
return self._classes

@classes.setter
def classes(self, value):
self._classes = value

#todo hoist
def fit(self, X, y=None, **kwargs):
"""
Parameters
----------
X : ndarray or DataFrame of shape n x m
A matrix of n instances with m features
y : ndarray or Series of length n
An array or series of target or class values
kwargs: keyword arguments passed to Scikit-Learn API.
"""
super(ConfusionMatrix, self).fit(X, y, **kwargs)
if self._classes is None:
self.classes = self.estimator.classes_
return self

def score(self, X, y, sample_weight=None, percent=True):
"""
Generates the Scikit-Learn confusion_matrix and applies this to the appropriate axis
Parameters
----------
X : ndarray or DataFrame of shape n x m
A matrix of n instances with m features
y : ndarray or Series of length n
An array or series of target or class values
sample_weight: optional, passed to the confusion_matrix
percent: optional, Boolean. Determines whether the confusion_matrix
is displayed as raw counts or as a percent of the true
predictions. Note: if using a subset of classes in __init__, percent should
be set to False or inaccurate percents will be displayed.
"""
y_pred = self.predict(X)

self.confusion_matrix = confusion_matrix(y_true = y, y_pred = y_pred, labels=self.classes, sample_weight=sample_weight)
self._class_counts = self.class_counts(y)

#Make array of only the classes actually being used.
#Needed because sklearn confusion_matrix only returns counts for selected classes
#but percent should be calculated based on all classes
selected_class_counts = []
for c in self.classes:
try:
selected_class_counts.append(self._class_counts[c])
except KeyError:
selected_class_counts.append(0)
self.selected_class_counts = np.array(selected_class_counts)

return self.draw(percent)

def draw(self, percent=True):
"""
Renders the confusion matrix
Should only be called internally, as it uses values calculated in score(),
which calls this method.
Parameters
----------
percent: Boolean
Whether the heatmap should represent "% of True" or raw counts
"""
# Create the axis if it doesn't exist
if self.ax is None:
self.ax = plt.gca()

if percent == True:
#Convert confusion matrix to percent of each row, i.e. the predicted as a percent of true in each class
#numpy_div0 function returns 0 instead of NAN.
self._confusion_matrix_display = numpy_div0(
self.confusion_matrix,
self.selected_class_counts
)
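# For reference: numpy_div0 (imported from .utils) is assumed to implement
# element-wise division that returns 0 where the denominator is 0, roughly:
#   out = np.true_divide(a, b); out[~np.isfinite(out)] = 0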
self._confusion_matrix_display = np.round(self._confusion_matrix_display * 100, decimals=0)
else:
self._confusion_matrix_display = self.confusion_matrix

#Y axis should be sorted top to bottom in pcolormesh
self._confusion_matrix_plottable = self._confusion_matrix_display[::-1,::]

#Set up the dimensions of the pcolormesh
X = np.linspace(start=0, stop=len(self.classes), num=len(self.classes)+1)
Y = np.linspace(start=0, stop=len(self.classes), num=len(self.classes)+1)
self.ax.set_ylim(bottom=0, top=self._confusion_matrix_plottable.shape[0])
self.ax.set_xlim(left=0, right=self._confusion_matrix_plottable.shape[1])

#Put in custom axis labels
self.xticklabels = self.classes
self.yticklabels = self.classes[::-1]
self.xticks = np.arange(0, len(self.classes), 1) + .5
self.yticks = np.arange(0, len(self.classes), 1) + .5
self.ax.set(xticks=self.xticks, yticks=self.yticks)
self.ax.set_xticklabels(self.xticklabels, rotation="vertical", fontsize=8)
self.ax.set_yticklabels(self.yticklabels, fontsize=8)

######################
# Add the data labels to each square
######################
for x_index, x in np.ndenumerate(X):
#np.ndenumerate returns a tuple for the index, must access first element using [0]
x_index = x_index[0]
for y_index, y in np.ndenumerate(Y):
#Clean up our iterators
#numpy doesn't like non integers as indexes; also np.ndenumerate returns tuple
x_int = int(x)
y_int = int(y)
y_index = y_index[0]

#X and Y are one element longer than the confusion_matrix. Don't want to add text for the last X or Y
if x_index == X[-1] or y_index == Y[-1]:
break

#center the text in the middle of the block
text_x = x + 0.5
text_y = y + 0.5

#make zero values more subtle
#TODO also add the background color-based logic from .util as in ticket #154
text_color = "0.75" if self._confusion_matrix_plottable[x_int,y_int] == 0 else "black"

#Put the data labels in the middle of the heatmap square
self.ax.text(text_y,
text_x,
"{:.0f}{}".format(self._confusion_matrix_plottable[x_int,y_int],"%" if percent==True else ""),
va='center',
ha='center',
fontsize=8,
color=text_color)

#If the prediction is correct, put a bounding box around that square to better highlight it to the user
#This will be used in ax.pcolormesh, setting now since we're iterating over the matrix
#ticklabels are conveniently already reversed properly to match the _confusion_matrix_plottable order
if self.xticklabels[x_int] == self.yticklabels[y_int]:
self.edgecolors.append('black')
else:
self.edgecolors.append('w')

# Draw the heatmap. vmin and vmax operate in tandem with the cmap.set_under and cmap.set_over to alter the color of 0 and 100
highest_count = self._confusion_matrix_plottable.max()
vmax = 99.999 if percent == True else highest_count
mesh = self.ax.pcolormesh(X,
Y,
self._confusion_matrix_plottable,
vmin=0.00001,
vmax=vmax,
edgecolor=self.edgecolors,
cmap=self.cmap,
linewidth=0.01)
return self.ax

def finalize(self, **kwargs):
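"""
Finalize the figure by setting the title and the True/Predicted Class axis labels.
"""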
self.set_title('{} Confusion Matrix'.format(self.name))
self.ax.set_ylabel('True Class')
self.ax.set_xlabel('Predicted Class')

##########################################################################
## Classification Report
