added updated ClassificationReport and ROCAUC classes to classifier.py
rebeccabilbro committed Oct 1, 2016
1 parent 482c08b commit 92bba31
Showing 1 changed file with 80 additions and 101 deletions.
181 changes: 80 additions & 101 deletions yellowbrick/classifier.py
@@ -21,169 +21,148 @@
 import matplotlib.pyplot as plt

 from sklearn.pipeline import Pipeline
-from sklearn.metrics import roc_curve, auc
-from sklearn.metrics import classification_report
+from sklearn.metrics import auc, roc_auc_score, roc_curve
+from sklearn.metrics import precision_recall_fscore_support

 from .color_utils import ddlheatmap
 from .utils import get_model_name, isestimator
-from .base import ModelVisualization, MultiModelMixin
+from .base import Visualizer, ScoreVisualizer, MultiModelMixin

 ##########################################################################
 ## Classification Visualization Base Object
 ##########################################################################

-class ClassifierVisualization(ModelVisualization):
-    pass
+class ClassificationScoreVisualizer(ScoreVisualizer):
+
+    def __init__(self, model):
+        """
+        Check to see if the model is an instance of a classifier.
+        Should raise a metrics mismatch error if it isn't.
+        """
+        pass
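        # A minimal sketch of the check described in the docstring; this is
        # illustrative and not part of this commit (sklearn.base.is_classifier
        # is a real scikit-learn helper):
        #
        #     from sklearn.base import is_classifier
        #     if not is_classifier(model):
        #         raise TypeError('{} is not a classifier'.format(
        #             get_model_name(model)))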

 ##########################################################################
 ## Classification Report
 ##########################################################################

-class ClassifierReport(ClassifierVisualization):
+class ClassificationReport(ClassificationScoreVisualizer):
"""
Classification report that shows the precision, recall, and F1 scores
for the model. Integrates numerical scores as well color-coded heatmap.
"""

     def __init__(self, model, **kwargs):
-        self.model = model
-        self.cmap = kwargs.pop('cmap', ddlheatmap)
-        self.name = kwargs.pop('name', get_model_name(model))
-        self.report = None

-    def parse_report(self):
         """
-        Custom classification_report parsing utility
+        Pass in a fitted model to generate a classification report.
         """
-        if self.report is None:
-            raise ModelError("Call score() before generating the model for parsing.")

-        # TODO: make a bit more robust, or look for the sklearn util that doesn't stringify
-        lines = self.report.split('\n')
-        classes = []
-        matrix = []

-        for line in lines[2:(len(lines)-3)]:
-            s = line.split()
-            classes.append(s[0])
-            value = [float(x) for x in s[1: len(s) - 1]]
-            matrix.append(value)

-        return matrix, classes
+        self.estimator = model
+        self.name = get_model_name(self.estimator)
+        self.cmap = kwargs.pop('cmap', ddlheatmap)
+        self.classes = model.classes_


-    def score(self, y_true, y_pred, **kwargs):
+    def score(self, y, y_pred=None, **kwargs):
         """
         Generates the Scikit-Learn classification_report
         """
-        # TODO: Do a better job of guessing defaults from the model
-        cr_kwargs = {
-            'labels': kwargs.pop('labels', None),
-            'target_names': kwargs.pop('target_names', None),
-            'sample_weight': kwargs.pop('sample_weight', None),
-            'digits': kwargs.pop('digits', 2)
-        }
+        self.keys = ('precision', 'recall', 'f1')
+        self.scores = precision_recall_fscore_support(y, y_pred, labels=self.classes)
+        self.scores = map(lambda s: dict(zip(self.classes, s)), self.scores[0:3])
+        self.scores = dict(zip(self.keys, self.scores))
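        # At this point self.scores is a nested dictionary keyed first by
        # measure and then by class. For classes (0, 1) it would look
        # something like this (values illustrative):
        #
        #     {'precision': {0: 0.94, 1: 0.88},
        #      'recall':    {0: 0.90, 1: 0.93},
        #      'f1':        {0: 0.92, 1: 0.90}}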
+        self._draw(y, y_pred)

-        self.report = classification_report(y_true, y_pred, **cr_kwargs)


-    def render(self):
+    def _draw(self, y, y_pred):
         """
         Renders the classification report across each axis.
         """
-        title = '{} Classification Report'.format(self.name)
-        matrix, classes = self.parse_report()

         fig, ax = plt.subplots(1)

-        for column in range(len(matrix)+1):
-            for row in range(len(classes)):
-                txt = matrix[row][column]
-                ax.text(column,row,matrix[row][column],va='center',ha='center')
+        self.matrix = []
+        for cls in self.classes:
+            self.matrix.append([self.scores['precision'][cls], self.scores['recall'][cls], self.scores['f1'][cls]])

-        fig = plt.imshow(matrix, interpolation='nearest', cmap=self.cmap)
-        plt.title(title)
-        plt.colorbar()
-        x_tick_marks = np.arange(len(classes)+1)
-        y_tick_marks = np.arange(len(classes))
-        plt.xticks(x_tick_marks, ['precision', 'recall', 'f1-score'], rotation=45)
-        plt.yticks(y_tick_marks, classes)
-        plt.ylabel('Classes')
-        plt.xlabel('Measures')
+        # Write each score into its cell; there is one column per measure.
+        for column in range(len(self.keys)):
+            for row in range(len(self.classes)):
+                ax.text(column, row, self.matrix[row][column], va='center', ha='center')

+        fig = plt.imshow(self.matrix, interpolation='nearest', cmap=self.cmap)
+        return ax


-def crplot(model, y_true, y_pred, **kwargs):
-    """
-    Plots a classification report as a heatmap. (More to follow).
-    """
-    viz = ClassifierReport(model, **kwargs)
-    viz.score(y_true, y_pred, **kwargs)
+    def poof(self):
+        """
+        Plots a classification report as a heatmap.
+        """
+        plt.title('{} Classification Report'.format(self.name))
+        plt.colorbar()
+        x_tick_marks = np.arange(len(self.keys))
+        y_tick_marks = np.arange(len(self.classes))
+        plt.xticks(x_tick_marks, ['precision', 'recall', 'f1-score'], rotation=45)
+        plt.yticks(y_tick_marks, self.classes)
+        plt.ylabel('Classes')
+        plt.xlabel('Measures')

-    return viz.render()
+        return plt
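# A usage sketch for the new ClassificationReport API above; the estimator
# and the train/test variables are illustrative, not part of this module:
#
#     from sklearn.linear_model import LogisticRegression
#     model = LogisticRegression().fit(X_train, y_train)
#     viz = ClassificationReport(model)
#     viz.score(y_test, model.predict(X_test))
#     viz.poof()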


 ##########################################################################
 ## Receiver Operating Characteristics
 ##########################################################################

-class ROCAUC(MultiModelMixin, ClassifierVisualization):
+class ROCAUC(ClassificationScoreVisualizer):
     """
     Plot the ROC to visualize the tradeoff between the classifier's
     sensitivity and specificity.
     """
-    def __init__(self, models, **kwargs):
+    def __init__(self, model, **kwargs):
         """
-        Pass in a collection of models to generate ROC curves.
+        Pass in a model to generate a ROC curve.
         """
-        super(ROCAUC, self).__init__(models, **kwargs)
+        self.estimator = model
+        self.name = get_model_name(self.estimator)
+        super(ROCAUC, self).__init__(model, **kwargs)
         self.colors = {
             'roc': kwargs.pop('roc_color', '#2B94E9'),
             'diagonal': kwargs.pop('diagonal_color', '#666666'),
         }

-    def fit(self, X, y):
-        """
-        Custom fit method
-        """
-        self.models = list(map(lambda model: model.fit(X, y), self.models))

-    def render(self, X, y):
+    def fit(self):
+        pass
+
+    def predict(self):
+        pass
+
+    def score(self, y, y_pred=None):
+        self.fpr, self.tpr, self.thresholds = roc_curve(y, y_pred)
+        self.roc_auc = auc(self.fpr, self.tpr)
+        self._draw(y, y_pred)
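        # Note: roc_curve expects continuous scores (e.g. predict_proba or
        # decision_function output); hard predict() labels will produce a
        # coarse, single-threshold curve.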

+    def _draw(self, y, y_pred):
         """
-        Renders each ROC-AUC plot across each axis.
+        Renders the ROC-AUC plot.
+        Called internally by score, possibly more than once.
         """
-        for idx, axe in enumerate(self.generate_subplots()):
-            # Get the information for this axis
-            name = self.names[idx]
-            model = self.models[idx]
-            y_pred = model.predict(X)
-            fpr, tpr, thresholds = roc_curve(y, y_pred)
-            roc_auc = auc(fpr, tpr)
+        plt.figure()
+        plt.plot(self.fpr, self.tpr, c=self.colors['roc'], label='AUC = {:0.2f}'.format(self.roc_auc))

-            axe.plot(fpr, tpr, c=self.colors['roc'], label='AUC = {:0.2f}'.format(roc_auc))
+        # Plot the line of no discrimination to compare the curve to.
+        plt.plot([0, 1], [0, 1], 'm--', c=self.colors['diagonal'])

-            # Plot the line of no discrimination to compare the curve to.
-            axe.plot([0,1],[0,1],'m--',c=self.colors['diagonal'])

-            axe.set_title('ROC for {}'.format(name))
-            axe.legend(loc='lower right')

-        plt.xlim([0,1])
-        plt.ylim([0,1.1])
-        return axe

-def rocplot(models, X, y, **kwargs):
-    """
-    Take in the model, data and labels as input and generate a multi-plot of
-    the ROC plots with AUC metrics embedded.
-    """
-    viz = ROCAUC(models, **kwargs)
-    viz.fit(X, y)
-    return viz.render(X, y)
+    def poof(self, **kwargs):
+        """
+        Called by the user; takes no arguments beyond self.
+        Generates the ROC plot with the AUC metric embedded.
+        """
+        plt.title('ROC for {}'.format(self.name))
+        plt.legend(loc='lower right')

+        plt.xlim([-0.02, 1])
+        plt.ylim([0, 1.1])

+        return plt
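# A usage sketch for the new ROCAUC API above (binary case; the estimator
# and the train/test variables are illustrative, not part of this module):
#
#     from sklearn.linear_model import LogisticRegression
#     model = LogisticRegression().fit(X_train, y_train)
#     viz = ROCAUC(model)
#     viz.score(y_test, model.predict_proba(X_test)[:, 1])
#     viz.poof()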
