Skip to content

Commit

Permalink
more utility tests and some new decorators, closed #89
Browse files Browse the repository at this point in the history
  • Loading branch information
bbengfort committed Nov 7, 2016
1 parent b184134 commit 5861cf1
Show file tree
Hide file tree
Showing 4 changed files with 433 additions and 18 deletions.
306 changes: 298 additions & 8 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,42 @@
## Imports
##########################################################################

import inspect
import unittest

from sklearn.cluster import KMeans
from sklearn.pipeline import Pipeline
from sklearn.decomposition import PCA
from sklearn.neighbors import LSHForest
from sklearn.linear_model import LassoCV
from sklearn.linear_model import LinearRegression
from sklearn.neighbors import LSHForest
from sklearn.pipeline import Pipeline
import unittest
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

from yellowbrick.utils import get_model_name, isestimator
from yellowbrick.utils import *
from yellowbrick.base import Visualizer, ScoreVisualizer, ModelVisualizer


class ModelNameTests(unittest.TestCase):
##########################################################################
## Model Utility Tests
##########################################################################

class ModelUtilityTests(unittest.TestCase):

##////////////////////////////////////////////////////////////////////
## get_model_name testing
##////////////////////////////////////////////////////////////////////

def test_real_model(self):
"""
Test that model name works for sklearn estimators
"""
model1 = LassoCV()
model2 = LSHForest()
model3 = KMeans()
self.assertEqual(get_model_name(model1), 'LassoCV')
self.assertEqual(get_model_name(model2), 'LSHForest')
self.assertEqual(get_model_name(model3), 'KMeans')

def test_pipeline(self):
"""
Expand All @@ -58,12 +74,35 @@ def test_str_input(self):
"""
self.assertRaises(TypeError, get_model_name, 'helloworld')

##////////////////////////////////////////////////////////////////////
## isestimator testing
##////////////////////////////////////////////////////////////////////

def test_estimator_alias(self):
"""
Assert is_estimator aliases isestimator
"""
self.assertEqual(
is_estimator(LinearRegression), isestimator(LinearRegression)
)

def test_estimator_instance(self):
"""
Test that isestimator works for instances
"""
model = LinearRegression()
self.assertTrue(isestimator(model))

models = (
LinearRegression(),
LogisticRegression(),
KMeans(),
LSHForest(),
PCA(),
LassoCV(),
RandomForestClassifier(),
)

for model in models:
self.assertTrue(isestimator(model))

def test_pipeline_instance(self):
"""
Expand All @@ -80,7 +119,19 @@ def test_estimator_class(self):
"""
Test that isestimator works for classes
"""
self.assertTrue(LinearRegression)
models = (
LinearRegression,
LogisticRegression,
KMeans,
LSHForest,
PCA,
LassoCV,
RandomForestClassifier,
)

for model in models:
self.assertTrue(inspect.isclass(model))
self.assertTrue(isestimator(model))

def test_collection_not_estimator(self):
"""
Expand All @@ -92,6 +143,245 @@ def test_collection_not_estimator(self):
things = ['pepper', 'sauce', 'queen']
self.assertFalse(isestimator(things))

def test_visualizer_is_estimator(self):
"""
Assert that a Visualizer is an estimator
"""
self.assertTrue(is_estimator(Visualizer))
self.assertTrue(is_estimator(Visualizer()))
self.assertTrue(is_estimator(ScoreVisualizer))
self.assertTrue(is_estimator(ScoreVisualizer(LinearRegression())))
self.assertTrue(is_estimator(ModelVisualizer))
self.assertTrue(is_estimator(ModelVisualizer(LogisticRegression())))

##////////////////////////////////////////////////////////////////////
## isregressor testing
##////////////////////////////////////////////////////////////////////

def test_regressor_alias(self):
"""
Assert is_regressor aliases isregressor
"""
instance = LinearRegression()
self.assertEqual(is_regressor(instance), isregressor(instance))

def test_regressor_instance(self):
"""
Test that is_regressor works for instances
"""

# Test regressors are identified correctly
regressors = (
LassoCV,
LinearRegression,
)

for model in regressors:
instance = model()
self.assertTrue(is_regressor(instance))

# Test that non-regressors are identified correctly
notregressors = (
KMeans,
PCA,
LSHForest,
LogisticRegression,
RandomForestClassifier,
)

for model in notregressors:
instance = model()
self.assertFalse(is_regressor(instance))

def test_regressor_class(self):
"""
Test that is_regressor works for classes
"""

# Test regressors are identified correctly
regressors = (
LassoCV,
LinearRegression,
)

for klass in regressors:
self.assertTrue(inspect.isclass(klass))
self.assertTrue(is_regressor(klass))

# Test that non-regressors are identified correctly
notregressors = (
KMeans,
PCA,
LSHForest,
LogisticRegression,
RandomForestClassifier,
)

for klass in notregressors:
self.assertTrue(inspect.isclass(klass))
self.assertFalse(is_regressor(klass))

def test_regressor_pipeline(self):
"""
Test that is_regressor works for pipelines
"""
model = Pipeline([
('reduce_dim', PCA()),
('linreg', LinearRegression())
])

self.assertTrue(is_regressor(model))

def test_regressor_visualizer(self):
"""
Test that is_regressor works on visualizers
"""
model = ScoreVisualizer(LinearRegression())
self.assertTrue(is_regressor(model))

##////////////////////////////////////////////////////////////////////
## isclassifier testing
##////////////////////////////////////////////////////////////////////

def test_classifier_alias(self):
"""
Assert is_classifier aliases isclassifier
"""
instance = LogisticRegression()
self.assertEqual(is_classifier(instance), isclassifier(instance))

def test_classifier_instance(self):
"""
Test that is_classifier works for instances
"""

# Test classifiers are identified correctly
classifiers = (
LogisticRegression,
RandomForestClassifier,
)

for model in classifiers:
instance = model()
self.assertTrue(is_classifier(instance))

# Test that non-classifiers are identified correctly
notclassifiers = (
KMeans,
PCA,
LSHForest,
LinearRegression,
LassoCV,
)

for model in notclassifiers:
instance = model()
self.assertFalse(is_classifier(instance))

def test_classifier_class(self):
"""
Test that is_classifier works for classes
"""

# Test classifiers are identified correctly
classifiers = (
RandomForestClassifier,
LogisticRegression,
)

for klass in classifiers:
self.assertTrue(inspect.isclass(klass))
self.assertTrue(is_classifier(klass))

# Test that non-regressors are identified correctly
notclassifiers = (
KMeans,
PCA,
LSHForest,
LassoCV,
LinearRegression,
)

for klass in notclassifiers:
self.assertTrue(inspect.isclass(klass))
self.assertFalse(is_classifier(klass))

def test_classifier_pipeline(self):
"""
Test that is_regressor works for pipelines
"""
model = Pipeline([
('reduce_dim', PCA()),
('linreg', LogisticRegression())
])

self.assertTrue(is_classifier(model))

def test_classifier_visualizer(self):
"""
Test that is_classifier works on visualizers
"""
model = ScoreVisualizer(RandomForestClassifier())
self.assertTrue(is_classifier(model))

##########################################################################
## Decorator Tests
##########################################################################

class DecoratorTests(unittest.TestCase):
"""
Tests for the decorator utilities.
"""

def test_docutil(self):
"""
Test the docutil docstring copying methodology.
"""

class Visualizer(object):

def __init__(self):
"""
This is the correct docstring.
"""
pass


def undecorated(*args, **kwargs):
"""
This is an undecorated function string.
"""
pass

# Test the undecorated string to protect from magic
self.assertEqual(
undecorated.__doc__.strip(), "This is an undecorated function string."
)

# Decorate manually and test the newly decorated return function.
decorated = docutil(Visualizer.__init__)(undecorated)
self.assertEqual(
decorated.__doc__.strip(), "This is the correct docstring."
)

# Assert that decoration modifies the original function.
self.assertEqual(
undecorated.__doc__.strip(), "This is the correct docstring."
)

@docutil(Visualizer.__init__)
def sugar(*args, **kwargs):
pass

# Assert that syntactic sugar works as expected.
self.assertEqual(
sugar.__doc__.strip(), "This is the correct docstring."
)


##########################################################################
## Execute Tests
##########################################################################

if __name__ == "__main__":
unittest.main()
4 changes: 2 additions & 2 deletions yellowbrick/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ class ModelVisualizer(Visualizer):
and enables the user to visualize the performance of models across a range
of hyperparameter values (e.g. using VisualGridsearch and ValidationCurve).
"""
def __init__(self, ax=None, **kwargs):
def __init__(self, model, ax=None, **kwargs):
"""
Parameters
----------
Expand All @@ -222,7 +222,7 @@ def __init__(self, ax=None, **kwargs):
These parameters can be influenced later on in the visualization
process, but can and should be set as early as possible.
"""
super(ScoreVisualizer, self).__init__(ax=ax, **kwargs)
super(ModelVisualizer, self).__init__(ax=ax, **kwargs)
self.estimator = model


Expand Down
1 change: 0 additions & 1 deletion yellowbrick/features/pcoords.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from yellowbrick.exceptions import YellowbrickTypeError
from yellowbrick.style.colors import resolve_colors, get_color_cycle


##########################################################################
## Quick Methods
##########################################################################
Expand Down

0 comments on commit 5861cf1

Please sign in to comment.