Merge 5f2bbae into 994cc20
rebeccabilbro committed Jun 10, 2020
2 parents 994cc20 + 5f2bbae commit 9448675
Showing 5 changed files with 216 additions and 8 deletions.
84 changes: 79 additions & 5 deletions docs/api/regressor/alphas.rst
@@ -5,7 +5,7 @@ Alpha Selection

Regularization is designed to penalize model complexity; therefore, the higher the alpha, the less complex the model, which decreases the error due to variance (overfitting). Alphas that are too high, on the other hand, increase the error due to bias (underfitting). It is therefore important to choose an optimal alpha such that the error is minimized in both directions.

The AlphaSelection Visualizer demonstrates how different values of alpha influence model selection during the regularization of linear models. Generally speaking, alpha increases the effect of regularization: if alpha is zero there is no regularization, and the higher the alpha, the more the regularization parameter influences the final model.
The ``AlphaSelection`` Visualizer demonstrates how different values of alpha influence model selection during the regularization of linear models. Generally speaking, alpha increases the effect of regularization: if alpha is zero there is no regularization, and the higher the alpha, the more the regularization parameter influences the final model.
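
To make the effect of alpha concrete, here is a minimal sketch that is not part of Yellowbrick or of this commit. It assumes a one-feature ridge regression without an intercept, where the coefficient has the closed form w = sum(x*y) / (sum(x*x) + alpha). The helper name ``ridge_coefficient`` and the toy data are hypothetical and chosen only for illustration.

```python
# Sketch only: one-feature ridge regression without an intercept.
# `ridge_coefficient` is a hypothetical helper, not a Yellowbrick or scikit-learn API.
def ridge_coefficient(xs, ys, alpha):
    """Closed-form minimizer of sum((y - w*x)**2) + alpha * w**2."""
    sxy = sum(x * y for x, y in zip(xs, ys))
    sxx = sum(x * x for x in xs)
    return sxy / (sxx + alpha)


# Toy data that roughly follows y = 2x
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.1, 3.9, 6.2, 7.8]

# alpha = 0 is ordinary least squares; larger alphas shrink the coefficient toward zero
coefficients = {
    alpha: ridge_coefficient(xs, ys, alpha) for alpha in (0.0, 1.0, 10.0, 100.0)
}

for alpha, w in coefficients.items():
    print(f"alpha={alpha:>6}: w={w:.4f}")
```

With alpha at zero the coefficient is the unregularized least-squares fit, and each larger alpha pulls it closer to zero, which is the behaviour the paragraph above describes.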

================= ==============================
Visualizer :class:`~yellowbrick.regressor.alphas.AlphaSelection`
@@ -14,6 +14,22 @@ Models Regression
Workflow Model selection, Hyperparameter tuning
================= ==============================

For Estimators *with* Built-in Cross-Validation
-----------------------------------------------

The ``AlphaSelection`` visualizer wraps a "RegressorCV" model and
visualizes the alpha/error curve. Use this visualization to detect whether
the model is responding to regularization, i.e. whether the error changes
smoothly as you increase or decrease alpha. If the
visualization shows a jagged or random plot, then the model may
not be sensitive to that type of regularization and another may be required
(e.g. L1 or ``Lasso`` regularization).

.. NOTE::
    The ``AlphaSelection`` visualizer requires a "RegressorCV" model, i.e.
    a specialized class that performs cross-validated alpha selection
    on behalf of the model. See the ``ManualAlphaSelection`` visualizer if
    your regression model does not include cross-validation.

.. plot::
    :context: close-figs
@@ -22,8 +38,8 @@ Workflow Model selection, Hyperparameter tuning
    import numpy as np

    from sklearn.linear_model import LassoCV
    from yellowbrick.regressor import AlphaSelection
    from yellowbrick.datasets import load_concrete
    from yellowbrick.regressor import AlphaSelection

    # Load the regression dataset
    X, y = load_concrete()
@@ -37,9 +53,46 @@ Workflow Model selection, Hyperparameter tuning
    visualizer.fit(X, y)
    visualizer.show()

For Estimators *without* Built-in Cross-Validation
--------------------------------------------------

Most scikit-learn ``Estimators`` with ``alpha`` parameters
have a version with built-in cross-validation. However, if the
regressor you wish to use doesn't have an associated "CV" estimator,
or if you would like more control over the
alpha selection process, then you can use the ``ManualAlphaSelection``
visualizer. This visualizer is essentially a wrapper for scikit-learn's
``cross_val_score`` function, fitting a model for each alpha specified.
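
The idea of fitting and cross-validating one model per alpha can be sketched in plain Python. This is an illustration under stated assumptions, not Yellowbrick's implementation: it uses a one-feature ridge model with a closed-form fit, an interleaved k-fold split, and plain mean squared error instead of scikit-learn's ``cross_val_score``. The names ``fit_ridge_1d`` and ``cv_mse`` and the toy data are hypothetical.

```python
# Sketch only: cross-validate a one-feature ridge model for each candidate alpha.
# `fit_ridge_1d` and `cv_mse` are hypothetical helpers, not library APIs.
def fit_ridge_1d(xs, ys, alpha):
    """Closed-form ridge coefficient for y ~ w * x (no intercept)."""
    sxy = sum(x * y for x, y in zip(xs, ys))
    sxx = sum(x * x for x in xs)
    return sxy / (sxx + alpha)


def cv_mse(xs, ys, alpha, k=3):
    """Mean squared error averaged over k interleaved folds."""
    n = len(xs)
    fold_errors = []
    for fold in range(k):
        # Every k-th sample, starting at `fold`, is held out for testing
        test_idx = set(range(fold, n, k))
        train_x = [xs[i] for i in range(n) if i not in test_idx]
        train_y = [ys[i] for i in range(n) if i not in test_idx]
        w = fit_ridge_1d(train_x, train_y, alpha)
        mse = sum((ys[i] - w * xs[i]) ** 2 for i in test_idx) / len(test_idx)
        fold_errors.append(mse)
    return sum(fold_errors) / k


# Toy data that roughly follows y = 2x
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
ys = [2.2, 3.9, 6.1, 8.3, 9.8, 12.1, 13.9, 16.2, 17.8]

# One cross-validated error per alpha, on a log-spaced grid
alphas = [10.0 ** exponent for exponent in range(-3, 4)]
errors = [cv_mse(xs, ys, alpha) for alpha in alphas]
best_alpha = alphas[errors.index(min(errors))]

for alpha, error in zip(alphas, errors):
    print(f"alpha={alpha:>8}: mean CV MSE={error:.4f}")
print("best alpha:", best_alpha)
```

Plotting ``errors`` against ``alphas`` gives the kind of alpha/error curve the visualizer draws. Note that a scorer such as ``"neg_mean_squared_error"`` returns negated errors, so the sign convention differs from this sketch.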

.. plot::
    :context: close-figs
    :alt: Manual alpha selection on the concrete data set

    import numpy as np

    from sklearn.linear_model import Ridge
    from yellowbrick.datasets import load_concrete
    from yellowbrick.regressor import ManualAlphaSelection

    # Load the regression dataset
    X, y = load_concrete()

    # Create a list of alphas to cross-validate against
    alphas = np.logspace(1, 4, 50)

    # Instantiate the visualizer
    visualizer = ManualAlphaSelection(
        Ridge(),
        alphas=alphas,
        cv=12,
        scoring="neg_mean_squared_error"
    )

Quick Method
------------
    visualizer.fit(X, y)
    visualizer.show()

Quick Methods
-------------

The same functionality can be achieved with the associated quick method ``alphas``. This method will build the ``AlphaSelection`` Visualizer object with the associated arguments, fit it, then (optionally) immediately show it.

@@ -60,10 +113,31 @@ The same functionality above can be achieved with the associated quick method `a
    alphas(LassoCV(random_state=0), X, y)


The ``ManualAlphaSelection`` visualizer can also be used as a one-liner:

.. plot::
    :context: close-figs
    :alt: manual alphas on the energy dataset

    from sklearn.linear_model import ElasticNet
    from yellowbrick.regressor.alphas import manual_alphas

    from yellowbrick.datasets import load_energy

    # Load dataset
    X, y = load_energy()

    # Instantiate a model
    model = ElasticNet(tol=0.01, max_iter=10000)

    # Use the quick method and immediately show the figure
    manual_alphas(model, X, y, cv=6)


API Reference
-------------

.. automodule:: yellowbrick.regressor.alphas
    :members: AlphaSelection, ManualAlphaSelection, alphas
    :members: AlphaSelection, ManualAlphaSelection, alphas, manual_alphas
    :undoc-members:
    :show-inheritance:
50 changes: 50 additions & 0 deletions tests/test_regressor/test_alphas.py
@@ -28,6 +28,8 @@
from yellowbrick.exceptions import YellowbrickTypeError
from yellowbrick.exceptions import YellowbrickValueError
from yellowbrick.regressor.alphas import AlphaSelection, alphas
from yellowbrick.regressor.alphas import ManualAlphaSelection, manual_alphas


from sklearn.svm import SVR, SVC
from sklearn.cluster import KMeans
@@ -167,3 +169,51 @@ def test_quick_method(self):
        )
        assert isinstance(visualizer, AlphaSelection)
        self.assert_images_similar(visualizer)


class TestManualAlphaSelection(VisualTestCase):
    """
    Test the ManualAlphaSelection visualizer
    """
    def test_similar_image_manual(self):
        """
        Integration test with image similarity comparison
        """

        visualizer = ManualAlphaSelection(Lasso(random_state=0), cv=5)

        X, y = make_regression(random_state=0)
        visualizer.fit(X, y)
        visualizer.finalize()

        self.assert_images_similar(visualizer)

    @pytest.mark.parametrize("model", [RidgeCV, LassoCV, LassoLarsCV, ElasticNetCV])
    def test_regressor_nocv_manual(self, model):
        """
        Ensure only non-CV regressors are allowed
        """
        with pytest.raises(YellowbrickTypeError):
            ManualAlphaSelection(model())

    @pytest.mark.parametrize("model", [SVR, Ridge, Lasso, LassoLars, ElasticNet])
    def test_regressor_cv_manual(self, model):
        """
        Ensure non-CV regressors are allowed
        """
        try:
            ManualAlphaSelection(model())
        except YellowbrickTypeError:
            pytest.fail("could not instantiate Regressor on alpha selection")

    def test_quick_method_manual(self):
        """
        Test the manual alphas quick method producing a valid visualization
        """
        X, y = load_energy(return_dataset=True).to_numpy()

        visualizer = manual_alphas(
            ElasticNet(random_state=0), X, y, cv=3, is_fitted=False, show=False
        )
        assert isinstance(visualizer, ManualAlphaSelection)
        self.assert_images_similar(visualizer)
90 changes: 87 additions & 3 deletions yellowbrick/regressor/alphas.py
@@ -119,7 +119,7 @@ def __init__(self, model, ax=None, is_fitted="auto", **kwargs):

        # Check to make sure this is a "RegressorCV"
        name = model.__class__.__name__
        if not name.endswith("CV"):
        if not name.endswith("CV") and not isinstance(self, ManualAlphaSelection):
            raise YellowbrickTypeError(
                (
                    "'{}' is not a CV regularization model;"
@@ -314,7 +314,10 @@ def __init__(self, model, ax=None, alphas=None, cv=None, scoring=None, **kwargs)
        super(ManualAlphaSelection, self).__init__(model, ax=ax, **kwargs)

        # Set manual alpha selection parameters
        self.alphas = alphas or np.logspace(-10, -2, 200)
        if alphas is not None:
            self.alphas = alphas
        else:
            self.alphas = np.logspace(-10, -2, 200)
        self.errors = None
        self.score_method = partial(cross_val_score, cv=cv, scoring=scoring)
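
The commit does not state why ``alphas or np.logspace(...)`` was replaced with an explicit ``is not None`` check. One plausible reason is that NumPy arrays with more than one element raise ``ValueError`` when truth-tested, so the ``or`` idiom fails as soon as a caller passes an array of alphas. The sketch below illustrates that pitfall in plain Python; ``ArrayLike`` is a hypothetical stand-in that mimics the NumPy behaviour, and ``DEFAULT_ALPHAS`` and both helper functions are invented for illustration.

```python
# Sketch only: why `value or default` is fragile for array-like arguments.
class ArrayLike:
    """Hypothetical stand-in that refuses truth-testing, as multi-element NumPy arrays do."""

    def __init__(self, values):
        self.values = list(values)

    def __bool__(self):
        raise ValueError("truth value of a multi-element array is ambiguous")


DEFAULT_ALPHAS = [0.1, 1.0, 10.0]  # hypothetical default grid


def pick_with_or(alphas=None):
    # Truth-tests `alphas`, so array-likes raise and empty inputs are silently replaced
    return alphas or DEFAULT_ALPHAS


def pick_with_none_check(alphas=None):
    # Only falls back to the default when nothing was passed at all
    if alphas is not None:
        return alphas
    return DEFAULT_ALPHAS


custom = ArrayLike([1.0, 2.0, 3.0])

try:
    pick_with_or(custom)
    or_failed = False
except ValueError:
    or_failed = True

print("`or` idiom raised:", or_failed)
print("None check keeps the caller's object:", pick_with_none_check(custom) is custom)
```

The explicit check also avoids discarding falsy but deliberate inputs, such as an empty list, which the ``or`` idiom would swap for the default without warning.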

@@ -361,7 +364,7 @@ def draw(self):


##########################################################################
## Quick Method
## Quick Methods
##########################################################################


@@ -426,3 +429,84 @@ def alphas(model, X, y=None, ax=None, is_fitted="auto", show=True, **kwargs):

    # Return the visualizer
    return visualizer


def manual_alphas(
    model,
    X,
    y=None,
    ax=None,
    alphas=None,
    cv=None,
    scoring=None,
    show=True,
    **kwargs
):
    """Quick Method:
    The Manual Alpha Selection Visualizer demonstrates how different values of alpha
    influence model selection during the regularization of linear models.
    Generally speaking, alpha increases the effect of regularization, e.g. if
    alpha is zero there is no regularization and the higher the alpha, the
    more the regularization parameter influences the final model.

    Parameters
    ----------
    model : an unfitted Scikit-Learn regressor
        Should be an instance of an unfitted regressor, and specifically one
        whose name doesn't end with "CV". The regressor must support a call to
        ``set_params(alpha=alpha)`` and be fit multiple times. If the
        regressor name ends with "CV" a ``YellowbrickTypeError`` is raised.

    ax : matplotlib Axes, default: None
        The axes to plot the figure on. If None is passed in the current axes
        will be used (or generated if required).

    alphas : ndarray or Series, default: np.logspace(-10, -2, 200)
        An array of alphas to fit each model with

    cv : int, cross-validation generator or an iterable, optional
        Determines the cross-validation splitting strategy.
        Possible inputs for cv are:

        - None, to use the default 3-fold cross validation,
        - integer, to specify the number of folds in a `(Stratified)KFold`,
        - An object to be used as a cross-validation generator.
        - An iterable yielding train, test splits.

        This argument is passed to the
        ``sklearn.model_selection.cross_val_score`` method to produce the
        cross validated score for each alpha.

    scoring : string, callable or None, optional, default: None
        A string (see model evaluation documentation) or
        a scorer callable object / function with signature
        ``scorer(estimator, X, y)``.

        This argument is passed to the
        ``sklearn.model_selection.cross_val_score`` method to produce the
        cross validated score for each alpha.

    kwargs : dict
        Keyword arguments that are passed to the base class and may influence
        the visualization as defined in other Visualizers.

    Returns
    -------
    visualizer : ManualAlphaSelection
        Returns the manual alpha selection visualizer
    """
    # Instantiate the visualizer
    visualizer = ManualAlphaSelection(
        model, ax, alphas=alphas, scoring=scoring, cv=cv, **kwargs
    )

    visualizer.fit(X, y)

    if show:
        visualizer.show()
    else:
        visualizer.finalize()

    # Return the visualizer
    return visualizer
