Skip to content

Commit

Permalink
residuals plots
Browse files Browse the repository at this point in the history
  • Loading branch information
bbengfort committed Jun 3, 2016
1 parent be63645 commit 742f937
Show file tree
Hide file tree
Showing 4 changed files with 167 additions and 42 deletions.
68 changes: 53 additions & 15 deletions examples/examples.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion yellowbrick/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from .version import get_version
from .anscombe import anscombe
from .classifier import crplot, rocplot_compare
from .regressor import peplot
from .regressor import peplot, residuals_plot

##########################################################################
## Package Version
Expand Down
35 changes: 35 additions & 0 deletions yellowbrick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
Abstract base classes and interface for Yellowbrick.
"""

import matplotlib.pyplot as plt

from .exceptions import YellowbrickTypeError
from .utils import get_model_name, isestimator
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.cross_validation import cross_val_predict as cvp

##########################################################################
## Base class hierarchy
Expand Down Expand Up @@ -79,3 +83,34 @@ def render(self, model=None):
raise NotImplementedError(
"Please specify how to render the model visualization"
)


class MultiModelMixin(object):
    """
    Mixin for visualizations that draw one subplot per model.

    Keeps a collection of models together with a display name for each,
    creates a matplotlib subplot per model, and yields cross-validated
    predictions for each model.
    """

    def __init__(self, models, **kwargs):
        """
        Store the models and their display names.

        Parameters
        ----------
        models : estimator or collection of estimators
            A single estimator (as judged by ``isestimator``) is wrapped in
            a list so that the rest of the API can always iterate.
        names : list, optional (keyword only)
            One display name per model. Defaults to ``get_model_name``
            applied to each model.
        """
        # Ensure models is a collection, if it's a single estimator then we
        # wrap it in a list so that the API doesn't break during render.
        if isestimator(models):
            models = [models]

        # Keep track of the models
        self.models = models
        self.names = kwargs.pop('names', list(map(get_model_name, models)))

    def generate_subplots(self):
        """
        Generates the subplots for the number of given models.

        Returns a flat array of axes, one per model, sharing both the x and
        y axes. The array form is returned even for a single model so that
        callers can always iterate over the result.
        """
        # Without squeeze=False, plt.subplots(1) hands back a bare Axes
        # object (not an array), which breaks callers that enumerate the
        # result when only one model was supplied.
        _, axes = plt.subplots(
            len(self.models), sharex=True, sharey=True, squeeze=False
        )
        return axes.ravel()

    def predict(self, X, y):
        """
        Returns a generator containing the predictions for each of the
        internal models (using cross_val_predict and a CV=12).
        """
        for model in self.models:
            yield cvp(model, X, y, cv=12)
104 changes: 78 additions & 26 deletions yellowbrick/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
import matplotlib as mpl
import matplotlib.pyplot as plt

from .base import ModelVisualization
from .utils import get_model_name, isestimator
from sklearn.cross_validation import cross_val_predict as cvp
from .base import ModelVisualization, MultiModelMixin
from sklearn.cross_validation import train_test_split as tts

##########################################################################
## Regression Visualization Base Object
Expand All @@ -36,41 +36,19 @@ class RegressorVisualization(ModelVisualization):
## Prediction Error Plots
##########################################################################

class PredictionError(RegressorVisualization):
class PredictionError(MultiModelMixin, RegressorVisualization):

def __init__(self, models, **kwargs):
"""
Pass in a collection of models to generate prediction error graphs.
"""
super(PredictionError, self).__init__(models, **kwargs)

# Ensure models is a collection, if it's a single estimator then we
# wrap it in a list so that the API doesn't break during render.
if isestimator(models):
models = [models]

# Keep track of the models
self.models = models
self.names = kwargs.pop('names', list(map(get_model_name, models)))
self.colors = {
'point': kwargs.pop('point_color', '#F2BE2C'),
'line': kwargs.pop('line_color', '#2B94E9'),
}

def generate_subplots(self):
"""
Generates the subplots for the number of given models.
"""
_, axes = plt.subplots(len(self.models), sharex=True, sharey=True)
return axes

def predict(self, X, y):
"""
Returns a generator containing the predictions for each of the
internal models (using cross_val_predict and a CV=12).
"""
for model in self.models:
yield cvp(model, X, y, cv=12)

def render(self, X, y):
"""
Renders each of the scatter plots per matrix.
Expand All @@ -93,5 +71,79 @@ def render(self, X, y):


def peplot(models, X, y, **kwargs):
    """
    Quick method: build a PredictionError visualization and render it.

    ``models`` and any extra keyword arguments are handed to the
    ``PredictionError`` constructor; ``X`` and ``y`` are passed on to its
    ``render`` method, whose result is returned unchanged.
    """
    return PredictionError(models, **kwargs).render(X, y)

##########################################################################
## Residuals Plots
##########################################################################

class ResidualsPlot(MultiModelMixin, RegressorVisualization):
    """
    Draws train/test residual scatter plots, one subplot per model.

    Each model is fit on a training split and its residuals
    (predicted - actual) are plotted against the predicted value for both
    the training and the test split.

    NOTE(review): an earlier note here said this viz takes model classes
    rather than instances, but ``fit`` calls ``model.fit(X, y)`` directly on
    each entry, which looks like it requires instances -- confirm.
    TODO: Fitted vs. Unfitted API (e.g. a FittedRegressorVisualization).
    """

    def __init__(self, models, **kwargs):
        """
        Pass in a collection of models to generate train/test residual
        plots.

        Optional keyword arguments ``train_point_color``,
        ``test_point_color`` and ``line_color`` override the default colors
        of the training points, the test points and the zero line.
        """
        super(ResidualsPlot, self).__init__(models, **kwargs)

        # TODO: the names for the color arguments are _long_.
        self.colors = {
            'train_point': kwargs.pop('train_point_color', '#2B94E9'),
            'test_point': kwargs.pop('test_point_color', '#94BA65'),
            'line': kwargs.pop('line_color', '#333333'),
        }

    def fit(self, X, y):
        """
        Split the data, fit every model on the training portion, and store
        the train/test splits on the visualization for use by ``render``.

        Returns ``self`` so that calls can be chained.
        TODO: move to MultiModelMixin.
        """
        # TODO: make test size a parameter and do better data storage on viz.
        splits = tts(X, y, test_size=0.2)
        self.X_train, self.X_test, self.y_train, self.y_test = splits

        # Stores whatever each model's fit() returns; this presumably relies
        # on the scikit-learn convention that fit() returns the estimator.
        self.models = [
            model.fit(self.X_train, self.y_train) for model in self.models
        ]
        return self

    def render(self):
        """
        Renders each residual plot across each axis.

        ``fit`` must have been called first, since the stored train/test
        splits are read here. Returns the last axis that was drawn.
        """

        for idx, axe in enumerate(self.generate_subplots()):
            # Get the information for this axis
            model = self.models[idx]
            name = self.names[idx]

            # TODO: less procedural?
            # Add the training residuals
            y_train_pred = model.predict(self.X_train)
            axe.scatter(
                y_train_pred, y_train_pred - self.y_train,
                c=self.colors['train_point'], s=40, alpha=0.5
            )

            # Add the test residuals
            y_test_pred = model.predict(self.X_test)
            axe.scatter(
                y_test_pred, y_test_pred - self.y_test,
                c=self.colors['test_point'], s=40
            )

            # Add the zero line and other axis elements. axhline spans the
            # full width of the axis regardless of the data range (the
            # previous fixed 0..100 segment did not), and the configured
            # line color is actually applied.
            axe.axhline(y=0, color=self.colors['line'])
            axe.set_title(name)
            axe.set_ylabel('Residuals')

        # Finalize the residuals plot
        # TODO: adjust the x and y ranges in order to compare (or use normalize)
        plt.xlabel("Predicted Value")
        return axe  # TODO: We shouldn't return the last axis


def residuals_plot(models, X, y, **kwargs):
    """
    Quick method: build a ResidualsPlot, fit it, and render it.

    ``models`` and any extra keyword arguments are handed to the
    ``ResidualsPlot`` constructor, the visualization is fit on ``X`` and
    ``y``, and the result of its ``render`` method is returned unchanged.
    """
    visualizer = ResidualsPlot(models, **kwargs)
    visualizer.fit(X, y)
    return visualizer.render()

0 comments on commit 742f937

Please sign in to comment.