Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds a wrapper for estimators that implement the scikit-learn API but do not extend BaseEstimator. If the estimator is missing required properties (generally the learned attributes) then a sensible error is raised. Includes documentation about how to use non-sklearn estimators with Yellowbrick. Fixes #1098 Fixes #1099 Related to #397 Related to #1066
- Loading branch information
Showing
6 changed files
with
277 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# tests.test_contrib.test_wrapper | ||
# Tests third-party estimator wrapper | ||
# | ||
# Author: Benjamin Bengfort | ||
# Created: Fri Oct 02 17:27:50 2020 -0400 | ||
# | ||
# Copyright (C) 2020 The scikit-yb developers | ||
# For license information, see LICENSE.txt | ||
# | ||
# ID: test_wrapper.py [] benjamin@bengfort.com $ | ||
|
||
""" | ||
Tests third-party estimator wrapper | ||
""" | ||
|
||
########################################################################## | ||
## Imports | ||
########################################################################## | ||
|
||
import pytest | ||
|
||
|
||
from yellowbrick.contrib.wrapper import * | ||
from yellowbrick.utils.helpers import get_model_name | ||
from yellowbrick.exceptions import YellowbrickAttributeError | ||
from yellowbrick.utils.types import is_estimator, is_probabilistic | ||
from yellowbrick.utils.types import is_classifier, is_clusterer, is_regressor | ||
|
||
from sklearn.base import is_regressor as sk_is_regressor | ||
from sklearn.base import is_classifier as sk_is_classifier | ||
from sklearn.base import is_outlier_detector as sk_is_outlier_detector | ||
|
||
|
||
########################################################################## | ||
## Mocks and Fixtures | ||
########################################################################## | ||
|
||
class ThirdPartyEstimator(object):
    """
    Minimal stand-in for an estimator that follows the scikit-learn calling
    conventions without extending ``BaseEstimator``. Keyword arguments become
    instance attributes, and the methods return sentinel values so that tests
    can verify that calls were proxied to this object.
    """

    def __init__(self, **params):
        # Store every keyword argument as an attribute of the same name
        vars(self).update(params)

    def fit(self, X, y=None):
        # Sentinel return value used to detect a proxied ``fit`` call
        return 42

    def predict_proba(self, X, y=None):
        # Sentinel return value; the method's presence is what matters
        return 24
|
||
|
||
########################################################################## | ||
## Test Suite | ||
########################################################################## | ||
|
||
class TestContribWrapper(object):
    """
    Test suite for the third party ContribEstimator wrapper
    """

    def test_wrapped_estimator(self):
        """
        Method calls and the estimator type pass through the wrapper
        """
        inner = ThirdPartyEstimator(foo="bar")
        wrapped = ContribEstimator(inner, "foo")
        assert wrapped.fit([1,2,3]) == 42
        assert wrapped._estimator_type == "foo"

    def test_attribute_error(self):
        """
        Accessing a missing attribute raises a YellowbrickAttributeError
        """
        wrapped = ContribEstimator(ThirdPartyEstimator())
        with pytest.raises(YellowbrickAttributeError):
            wrapped.foo

    def test_get_model_name(self):
        """
        The model name of a wrapped estimator is the underlying class name
        """
        wrapped = ContribEstimator(ThirdPartyEstimator())
        assert get_model_name(wrapped) == "ThirdPartyEstimator"

    def test_wraps_is_estimator(self):
        """
        A wrapped estimator satisfies the is_estimator check
        """
        wrapped = wrap(ThirdPartyEstimator())
        assert is_estimator(wrapped)

    def test_wraps_is_classifier(self):
        """
        A wrapped classifier satisfies both is_classifier checks
        """
        wrapped = classifier(ThirdPartyEstimator())
        assert is_classifier(wrapped)
        assert sk_is_classifier(wrapped)

    def test_wraps_is_regressor(self):
        """
        A wrapped regressor satisfies both is_regressor checks
        """
        wrapped = regressor(ThirdPartyEstimator())
        assert is_regressor(wrapped)
        assert sk_is_regressor(wrapped)

    def test_wraps_is_clusterer(self):
        """
        A wrapped clusterer satisfies the is_clusterer check
        """
        wrapped = clusterer(ThirdPartyEstimator())
        assert is_clusterer(wrapped)

    def test_wraps_is_outlier_detector(self):
        """
        An estimator wrapped as an outlier detector satisfies the sklearn check
        """
        wrapped = wrap(ThirdPartyEstimator(), OUTLIER_DETECTOR)
        assert sk_is_outlier_detector(wrapped)

    def test_wraps_is_probabilistic(self):
        """
        A wrapped estimator satisfies the is_probabilistic check
        """
        wrapped = wrap(ThirdPartyEstimator())
        assert is_probabilistic(wrapped)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
# yellowbrick.contrib.wrapper | ||
# Wrapper for third-party estimators that implement the sklearn API | ||
# | ||
# Author: Benjamin Bengfort | ||
# Created: Fri Oct 02 14:47:54 2020 -0400 | ||
# | ||
# Copyright (C) 2020 The scikit-yb developers | ||
# For license information, see LICENSE.txt | ||
# | ||
# ID: wrapper.py [] benjamin@bengfort.com $ | ||
|
||
""" | ||
Wrapper for third-party estimators that implement the sklearn API but do not directly | ||
subclass the ``sklearn.base.BaseEstimator`` class. This method is a quick way to get | ||
other estimators into Yellowbrick, while avoiding weird errors and issues. | ||
""" | ||
|
||
########################################################################## | ||
## Imports | ||
########################################################################## | ||
|
||
from yellowbrick.exceptions import YellowbrickAttributeError | ||
|
||
|
||
########################################################################## | ||
## Module Constants | ||
########################################################################## | ||
|
||
# Estimator type strings that may be passed as ``estimator_type`` to ``wrap``
# and ``ContribEstimator``; the wrapper stores the chosen value on its
# ``_estimator_type`` attribute.
CLASSIFIER = "classifier"
REGRESSOR = "regressor"
CLUSTERER = "clusterer"
DENSITY_ESTIMATOR = "DensityEstimator"
OUTLIER_DETECTOR = "outlier_detector"
|
||
|
||
########################################################################## | ||
## Functional API | ||
########################################################################## | ||
|
||
def wrap(estimator, estimator_type=None):
    """
    Make a third-party estimator that implements portions of the scikit-learn
    API usable by Yellowbrick visualizers by wrapping it in a
    ``ContribEstimator``.

    Parameters
    ----------
    estimator : object
        The non-sklearn estimator to wrap and use for Visualizers

    estimator_type : str, optional
        One of "classifier", "regressor", "clusterer", "DensityEstimator", or
        "outlier_detector" that allows the contrib estimator to pass the
        scikit-learn ``is_classifier``, etc. functions. If not specified, the
        _estimator_type attr is passed through to the underlying estimator.

    Returns
    -------
    wrapped : ContribEstimator
        The wrapper around the specified estimator
    """
    wrapped = ContribEstimator(estimator, estimator_type=estimator_type)
    return wrapped
|
||
|
||
def classifier(estimator):
    """
    Wrap a third-party estimator and mark it as a classifier so that it can be
    used with Yellowbrick visualizers.

    Parameters
    ----------
    estimator : object
        The non-sklearn classifier to wrap and use for Visualizers

    Returns
    -------
    wrapped : object
        The result of ``wrap`` called with the classifier estimator type
    """
    return wrap(estimator, estimator_type=CLASSIFIER)
|
||
|
||
def regressor(estimator):
    """
    Wrap a third-party estimator and mark it as a regressor so that it can be
    used with Yellowbrick visualizers.

    Parameters
    ----------
    estimator : object
        The non-sklearn regressor to wrap and use for Visualizers

    Returns
    -------
    wrapped : object
        The result of ``wrap`` called with the regressor estimator type
    """
    return wrap(estimator, estimator_type=REGRESSOR)
|
||
|
||
def clusterer(estimator):
    """
    Wrap a third-party estimator and mark it as a clusterer so that it can be
    used with Yellowbrick visualizers.

    Parameters
    ----------
    estimator : object
        The non-sklearn clusterer to wrap and use for Visualizers

    Returns
    -------
    wrapped : object
        The result of ``wrap`` called with the clusterer estimator type
    """
    return wrap(estimator, estimator_type=CLUSTERER)
|
||
|
||
########################################################################## | ||
## ContribEstimator - Third Party Estimator Wrapper | ||
########################################################################## | ||
|
||
class ContribEstimator(object):
    """
    Wraps a third party estimator that implements the scikit-learn API and therefore
    could be used with Yellowbrick but doesn't subclass ``BaseEstimator``. Since there
    are a number of pitfalls, this object provides sensible errors and warnings rather
    than completely blowing up, allowing contrib users to identify issues and fix them,
    smoothing the path to getting third party estimators into the Yellowbrick ecosystem.

    Attribute access that is not resolved on the wrapper itself is proxied to the
    wrapped estimator.

    Parameters
    ----------
    estimator : object
        The non-sklearn estimator to wrap and use for Visualizers

    estimator_type : str, optional
        One of "classifier", "regressor", "clusterer", "DensityEstimator", or
        "outlier_detector" that allows the contrib estimator to pass the scikit-learn
        ``is_classifier``, etc. functions. If not specified, the _estimator_type attr
        is passed through to the underlying estimator.

    Raises
    ------
    YellowbrickAttributeError
        On access to an attribute that neither the wrapper nor the wrapped
        estimator provides.
    """

    def __init__(self, estimator, estimator_type=None):
        self.estimator = estimator
        # Do not set estimator type if not specified to allow passthrough
        if estimator_type:
            self._estimator_type = estimator_type

    def __getattr__(self, attr):
        # __getattr__ only runs after normal attribute lookup has failed. Check the
        # instance dict directly rather than reading ``self.estimator``: if the
        # wrapper has no estimator yet (e.g. copy/pickle probing ``__setstate__`` on
        # a freshly created instance), ``self.estimator`` would re-enter this method
        # and recurse until a RecursionError instead of raising AttributeError.
        state = vars(self)
        if "estimator" not in state:
            raise AttributeError(
                "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
            )

        # proxy to the wrapped object
        try:
            return getattr(state["estimator"], attr)
        except AttributeError:
            raise YellowbrickAttributeError((
                "estimator is missing the '{}' attribute, which is required for this "
                "visualizer - please see the third party estimators documentation."
            ).format(attr))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters