Skip to content

Commit

Permalink
predicted error regressor
Browse files Browse the repository at this point in the history
  • Loading branch information
bbengfort committed Jun 3, 2016
1 parent 602cabd commit be63645
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 27 deletions.
91 changes: 72 additions & 19 deletions examples/examples.ipynb

Large diffs are not rendered by default.

9 changes: 5 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Dependencies
matplotlib==1.5.1
scipy==0.17.1
scikit-learn==0.17.1
numpy==1.11.0
cycler==0.10.0

## Utilities
#cycler==0.10.0
#pyparsing==2.1.4
#pytz==2016.4
#python-dateutil==2.5.3
Expand All @@ -13,13 +14,13 @@ numpy==1.11.0

## Testing Requirements (uncomment for development)
#nose==1.3.7
#coverage==4.0.3
#coverage==4.1

## Build Requirements (uncomment for deployment)
#wheel==0.29.0

## Pip stuff (ignore)
#Python==2.7.11
#Python==3.5.1
#pip==8.1.2
#setuptools==21.0.0
#setuptools==22.0.5
#wsgiref==0.1.2
52 changes: 52 additions & 0 deletions tests/test_regressor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# tests.test_regressor
# Ensure that the regressor visualizations work.
#
# Author: Benjamin Bengfort <bbengfort@districtdatalabs.com>
# Created: Fri Jun 03 14:20:02 2016 -0700
#
# Copyright (C) 2016 District Data Labs
# For license information, see LICENSE.txt
#
# ID: test_regressor.py [] benjamin@bengfort.com $

"""
Ensure that the regressor visualizations work.
"""

##########################################################################
## Imports
##########################################################################

import unittest

from yellowbrick.regressor import *
from yellowbrick.utils import *

from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import SVR

##########################################################################
## Prediction error test case
##########################################################################

class PredictionErrorTests(unittest.TestCase):
    """
    Checks how PredictionError stores the models and names it is given.
    """

    def test_init_pe_viz(self):
        """
        A list of models and a bare model are both accepted at construction
        """
        multiple = PredictionError([RandomForestRegressor(), SVR()])
        self.assertEqual(len(multiple.models), 2)

        # A single estimator is expected to be wrapped into a collection.
        single = PredictionError(SVR())
        self.assertEqual(len(single.models), 1)

    def test_init_pe_names(self):
        """
        Names default to the class names of the supplied models
        """
        multiple = PredictionError([RandomForestRegressor(), SVR()])
        self.assertEqual(multiple.names, ["RandomForestRegressor", "SVR"])

        single = PredictionError(SVR())
        self.assertEqual(single.names, ["SVR"])
50 changes: 47 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,30 +24,74 @@
from sklearn.pipeline import Pipeline
import unittest

from yellowbrick.utils import get_model_name
from yellowbrick.utils import get_model_name, isestimator


class ModelNameTests(unittest.TestCase):
    """
    Tests for the model detection utilities: get_model_name and isestimator.
    """

    def test_real_model(self):
        """
        Test that model name works for sklearn estimators
        """
        model1 = LassoCV()
        model2 = LSHForest()
        self.assertEqual(get_model_name(model1), 'LassoCV')
        self.assertEqual(get_model_name(model2), 'LSHForest')

    def test_pipeline(self):
        """
        Test that model name works for sklearn pipelines
        """
        pipeline = Pipeline([('reduce_dim', PCA()),
                             ('linreg', LinearRegression())])
        self.assertEqual(get_model_name(pipeline), 'LinearRegression')

    def test_int_input(self):
        """
        Assert a type error is raised when an int is passed to model name.
        """
        self.assertRaises(TypeError, get_model_name, 1)

    def test_str_input(self):
        """
        Assert a type error is raised when a str is passed to model name.
        """
        self.assertRaises(TypeError, get_model_name, 'helloworld')

    def test_estimator_instance(self):
        """
        Test that isestimator works for instances
        """
        model = LinearRegression()
        self.assertTrue(isestimator(model))

    def test_pipeline_instance(self):
        """
        Test that isestimator works for pipelines
        """
        model = Pipeline([
            ('reduce_dim', PCA()),
            ('linreg', LinearRegression())
        ])

        self.assertTrue(isestimator(model))

    def test_estimator_class(self):
        """
        Test that isestimator works for classes
        """
        # Pass the class object itself (not an instance) through isestimator;
        # asserting on the bare class would always be truthy and test nothing.
        self.assertTrue(isestimator(LinearRegression))

    def test_collection_not_estimator(self):
        """
        Make sure that a collection is not an estimator
        """
        # Collection classes ...
        for cls in (list, dict, tuple, set):
            self.assertFalse(isestimator(cls))

        # ... and a collection instance.
        things = ['pepper', 'sauce', 'queen']
        self.assertFalse(isestimator(things))


# Run the test cases in this module when it is executed as a script.
if "__main__" == __name__:
    unittest.main()
1 change: 1 addition & 0 deletions yellowbrick/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .version import get_version
from .anscombe import anscombe
from .classifier import crplot, rocplot_compare
from .regressor import peplot

##########################################################################
## Package Version
Expand Down
71 changes: 70 additions & 1 deletion yellowbrick/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,81 @@
## Imports
##########################################################################

from .base import ModelVisualization
import matplotlib as mpl
import matplotlib.pyplot as plt

from .base import ModelVisualization
from .utils import get_model_name, isestimator
from sklearn.cross_validation import cross_val_predict as cvp

##########################################################################
## Regression Visualization Base Object
##########################################################################

class RegressorVisualization(ModelVisualization):
    """
    Base object for regression visualizations; it currently adds nothing
    to ModelVisualization.
    """


##########################################################################
## Prediction Error Plots
##########################################################################

class PredictionError(RegressorVisualization):
    """
    Draws one scatter plot per model of cross-validated predictions against
    the actual target values, with a dashed reference line from the minimum
    to the maximum of the actual values.
    """

    def __init__(self, models, **kwargs):
        """
        Pass in a collection of models to generate prediction error graphs.

        models may be a single estimator or a collection of estimators.
        Recognised keyword arguments:

            names       -- labels used in the subplot titles; defaults to
                           the result of get_model_name for each model
            point_color -- colour of the scatter points (default '#F2BE2C')
            line_color  -- colour of the reference line (default '#2B94E9')
            cv          -- value passed as cv to cross_val_predict
                           (default 12)
        """

        # Ensure models is a collection, if it's a single estimator then we
        # wrap it in a list so that the API doesn't break during render.
        if isestimator(models):
            models = [models]

        # Keep track of the models
        self.models = models

        # Only derive the names when the caller did not supply any, so that
        # explicit names never trigger a get_model_name call.
        names = kwargs.pop('names', None)
        if names is None:
            names = [get_model_name(model) for model in models]
        self.names = names

        self.colors = {
            'point': kwargs.pop('point_color', '#F2BE2C'),
            'line': kwargs.pop('line_color', '#2B94E9'),
        }

        # Previously hard-coded to 12 inside predict.
        self.cv = kwargs.pop('cv', 12)

    def generate_subplots(self):
        """
        Generates the subplots for the number of given models.

        Returns a one-dimensional array holding one axes per model. The
        array form is kept even for a single model so that callers can
        always iterate over the result.
        """
        # squeeze=False stops pyplot from collapsing a single subplot into a
        # bare Axes object; the grid has one column, so take that column.
        _, axes = plt.subplots(
            len(self.models), sharex=True, sharey=True, squeeze=False
        )
        return axes[:, 0]

    def predict(self, X, y):
        """
        Returns a generator containing the predictions for each of the
        internal models (using cross_val_predict with cv=self.cv).
        """
        for model in self.models:
            yield cvp(model, X, y, cv=self.cv)

    def render(self, X, y):
        """
        Renders one scatter plot per model and returns the last axes drawn.

        y must provide min() and max(), which give the end points of the
        reference line.
        """
        # The reference line end points are the same for every subplot.
        low, high = y.min(), y.max()

        pairs = zip(self.generate_subplots(), self.predict(X, y))
        for idx, (axe, y_pred) in enumerate(pairs):
            # Plot the correct values
            axe.scatter(y, y_pred, c=self.colors['point'])

            # Draw the best fit line
            # TODO: Add best fit line computation metric
            # The format string carries only the dash style; the colour comes
            # from the keyword so the two never disagree.
            axe.plot([low, high], [low, high], '--', lw=4,
                     c=self.colors['line'])

            # Set the title and the y-axis label
            axe.set_title(
                "Predicted vs. Actual Values for {}".format(self.names[idx])
            )
            axe.set_ylabel('Predicted Value')

        # Finalize figure: label the x-axis on the last subplot explicitly
        # instead of relying on pyplot's notion of the current axes.
        axe.set_xlabel('Actual Value')
        return axe  # TODO: We shouldn't return the last axis


def peplot(models, X, y, **kwargs):
    """
    Builds a PredictionError visualization from models (forwarding any
    keyword arguments to it) and returns the result of rendering X and y.
    """
    return PredictionError(models, **kwargs).render(X, y)
13 changes: 13 additions & 0 deletions yellowbrick/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
##########################################################################
## Imports
##########################################################################

from sklearn.pipeline import Pipeline
from sklearn.base import BaseEstimator

##########################################################################
## Model detection utilities
##########################################################################

def get_model_name(model):
"""
Expand All @@ -28,3 +32,12 @@ def get_model_name(model):
return model.steps[-1][-1].__class__.__name__
else:
return model.__class__.__name__

def isestimator(model):
    """
    Determines if a model is an estimator using issubclass and isinstance.

    Returns True when model is either a subclass of BaseEstimator or an
    instance of one, and False for anything else.
    """
    # isinstance(model, type) also matches classes created by a metaclass
    # (for example abc.ABCMeta); an exact type(model) == type comparison
    # misses those and wrongly sends them down the instance check below.
    if isinstance(model, type):
        return issubclass(model, BaseEstimator)

    return isinstance(model, BaseEstimator)

0 comments on commit be63645

Please sign in to comment.