diff --git a/tests/test_contrib/test_wrapper.py b/tests/test_contrib/test_wrapper.py new file mode 100644 index 000000000..a3fade9dd --- /dev/null +++ b/tests/test_contrib/test_wrapper.py @@ -0,0 +1,125 @@ +# tests.test_contrib.test_wrapper +# Tests third-party estimator wrapper +# +# Author: Benjamin Bengfort +# Created: Fri Oct 02 17:27:50 2020 -0400 +# +# Copyright (C) 2020 The scikit-yb developers +# For license information, see LICENSE.txt +# +# ID: test_wrapper.py [] benjamin@bengfort.com $ + +""" +Tests third-party estimator wrapper +""" + +########################################################################## +## Imports +########################################################################## + +import pytest + + +from yellowbrick.contrib.wrapper import * +from yellowbrick.utils.helpers import get_model_name +from yellowbrick.exceptions import YellowbrickAttributeError +from yellowbrick.utils.types import is_estimator, is_probabilistic +from yellowbrick.utils.types import is_classifier, is_clusterer, is_regressor + +from sklearn.base import is_regressor as sk_is_regressor +from sklearn.base import is_classifier as sk_is_classifier +from sklearn.base import is_outlier_detector as sk_is_outlier_detector + + +########################################################################## +## Mocks and Fixtures +########################################################################## + +class ThirdPartyEstimator(object): + + def __init__(self, **params): + for attr, param in params.items(): + setattr(self, attr, param) + + def fit(self, X, y=None): + return 42 + + def predict_proba(self, X, y=None): + return 24 + + +########################################################################## +## Test Suite +########################################################################## + +class TestContribWrapper(object): + """ + Third party ContribEstimator wrapper + """ + + def test_wrapped_estimator(self): + """ + Check that the contrib wrapper passes through 
correctly + """ + tpe = ContribEstimator(ThirdPartyEstimator(foo="bar"), "foo") + assert tpe.fit([1,2,3]) == 42 + assert tpe._estimator_type == "foo" + + def test_attribute_error(self): + """ + Assert a correct exception is raised on failed access + """ + tpe = ContribEstimator(ThirdPartyEstimator()) + with pytest.raises(YellowbrickAttributeError): + tpe.foo + + def test_get_model_name(self): + """ + ContribWrapper should return the underlying model name + """ + tpe = ContribEstimator(ThirdPartyEstimator()) + assert get_model_name(tpe) == "ThirdPartyEstimator" + + def test_wraps_is_estimator(self): + """ + Assert a wrapped estimator passes is_estimator check + """ + tpe = wrap(ThirdPartyEstimator()) + assert is_estimator(tpe) + + def test_wraps_is_classifier(self): + """ + Assert a wrapped estimator passes is_classifier check + """ + tpe = classifier(ThirdPartyEstimator()) + assert is_classifier(tpe) + assert sk_is_classifier(tpe) + + def test_wraps_is_regressor(self): + """ + Assert a wrapped estimator passes is_regressor check + """ + tpe = regressor(ThirdPartyEstimator()) + assert is_regressor(tpe) + assert sk_is_regressor(tpe) + + def test_wraps_is_clusterer(self): + """ + Assert a wrapped estimator passes is_clusterer check + """ + tpe = clusterer(ThirdPartyEstimator()) + assert is_clusterer(tpe) + + def test_wraps_is_outlier_detector(self): + """ + Assert a wrapped estimator passes is_outlier_detector check + """ + tpe = wrap(ThirdPartyEstimator(), OUTLIER_DETECTOR) + assert sk_is_outlier_detector(tpe) + + def test_wraps_is_probabilistic(self): + """ + Assert a wrapped estimator passes is_probabilistic check + """ + tpe = wrap(ThirdPartyEstimator()) + assert is_probabilistic(tpe) diff --git a/yellowbrick/contrib/__init__.py b/yellowbrick/contrib/__init__.py index e9e66dfe4..0e8ef1173 100644 --- a/yellowbrick/contrib/__init__.py +++ b/yellowbrick/contrib/__init__.py @@ -5,6 +5,3 @@ # # # ID: __init__.py [a60bc41] nathan.danielsen@gmail.com $ - - -from .scatter 
import ScatterViz, ScatterVisualizer, scatterviz diff --git a/yellowbrick/contrib/wrapper.py b/yellowbrick/contrib/wrapper.py new file mode 100644 index 000000000..3db8c35d6 --- /dev/null +++ b/yellowbrick/contrib/wrapper.py @@ -0,0 +1,134 @@ +# yellowbrick.contrib.wrapper +# Wrapper for third-party estimators that implement the sklearn API +# +# Author: Benjamin Bengfort +# Created: Fri Oct 02 14:47:54 2020 -0400 +# +# Copyright (C) 2020 The scikit-yb developers +# For license information, see LICENSE.txt +# +# ID: wrapper.py [] benjamin@bengfort.com $ + +""" +Wrapper for third-party estimators that implement the sklearn API but do not directly +subclass the ``sklearn.base.BaseEstimator`` class. This method is a quick way to get +other estimators into Yellowbrick, while avoiding weird errors and issues. +""" + +########################################################################## +## Imports +########################################################################## + +from yellowbrick.exceptions import YellowbrickAttributeError + + +########################################################################## +## Module Constants +########################################################################## + +CLASSIFIER = "classifier" +REGRESSOR = "regressor" +CLUSTERER = "clusterer" +DENSITY_ESTIMATOR = "DensityEstimator" +OUTLIER_DETECTOR = "outlier_detector" + + +########################################################################## +## Functional API +########################################################################## + +def wrap(estimator, estimator_type=None): + """ + Wrap a third-party estimator that implements portions of the scikit-learn API to + make it available to Yellowbrick visualizers. If the Yellowbrick visualizer cannot + succeed, then a sensible error is raised instead. 
+ + Parameters + ---------- + estimator : object + The non-sklearn estimator to wrap and use for Visualizers + + estimator_type : str, optional + One of "classifier", "regressor", "clusterer", "DensityEstimator", or + "outlier_detector" that allows the contrib estimator to pass the scikit-learn + ``is_classifier``, etc. functions. If not specified, the _estimator_type attr + is passed through to the underlying estimator. + """ + return ContribEstimator(estimator, estimator_type) + + +def classifier(estimator): + """ + Wrap a third-party classifier to make it available to Yellowbrick visualizers. + + Parameters + ---------- + estimator : object + The non-sklearn classifier to wrap and use for Visualizers + """ + return wrap(estimator, CLASSIFIER) + + +def regressor(estimator): + """ + Wrap a third-party regressor to make it available to Yellowbrick visualizers. + + Parameters + ---------- + estimator : object + The non-sklearn regressor to wrap and use for Visualizers + """ + return wrap(estimator, REGRESSOR) + + +def clusterer(estimator): + """ + Wrap a third-party clusterer to make it available to Yellowbrick visualizers. + + Parameters + ---------- + estimator : object + The non-sklearn clusterer to wrap and use for Visualizers + """ + return wrap(estimator, CLUSTERER) + + +########################################################################## +## ContribEstimator - Third Party Estimator Wrapper +########################################################################## + +class ContribEstimator(object): + """ + Wraps a third party estimator that implements the scikit-learn API and therefore + could be used with Yellowbrick but doesn't subclass ``BaseEstimator``. Since there + are a number of pitfalls, this object provides sensible errors and warnings rather + than completely blowing up, allowing contrib users to identify issues and fix them, + smoothing the path to getting third party estimators into the Yellowbrick ecosystem. 
+ + Parameters + ---------- + estimator : object + The non-sklearn estimator to wrap and use for Visualizers + + estimator_type : str, optional + One of "classifier", "regressor", "clusterer", "DensityEstimator", or + "outlier_detector" that allows the contrib estimator to pass the scikit-learn + ``is_classifier``, etc. functions. If not specified, the _estimator_type attr + is passed through to the underlying estimator. + """ + + def __init__(self, estimator, estimator_type=None): + self.estimator = estimator + # Do not set estimator type if not specified to allow passthrough + if estimator_type: + self._estimator_type = estimator_type + + def __getattr__(self, attr): + # proxy to the wrapped object + try: + return getattr(self.estimator, attr) + except AttributeError: + raise YellowbrickAttributeError(( + "estimator is missing the '{}' attribute, which is required for this " + "visualizer - please see the third party estimators documentation." + ).format(attr)) diff --git a/yellowbrick/exceptions.py b/yellowbrick/exceptions.py index 2bf03247e..0e21d1c56 100644 --- a/yellowbrick/exceptions.py +++ b/yellowbrick/exceptions.py @@ -89,6 +89,14 @@ class YellowbrickKeyError(YellowbrickError, KeyError): pass +class YellowbrickAttributeError(YellowbrickError, AttributeError): + """ + A required attribute is missing on the estimator. 
+ """ + + pass + + ########################################################################## ## Assertions ########################################################################## diff --git a/yellowbrick/utils/helpers.py b/yellowbrick/utils/helpers.py index edaa0bb6b..6cb43672c 100644 --- a/yellowbrick/utils/helpers.py +++ b/yellowbrick/utils/helpers.py @@ -28,12 +28,13 @@ from yellowbrick.utils.types import is_estimator from yellowbrick.exceptions import YellowbrickTypeError +from yellowbrick.contrib.wrapper import ContribEstimator + ########################################################################## ## Model and Feature Information ########################################################################## - def is_fitted(estimator): """ In order to ensure that we don't call ``fit`` on an already-fitted model, @@ -137,11 +138,12 @@ def get_model_name(model): "Cannot detect the model name for non estimator: '{}'".format(type(model)) ) + if isinstance(model, Pipeline): + return get_model_name(model.steps[-1][-1]) + elif isinstance(model, ContribEstimator): + return model.estimator.__class__.__name__ else: - if isinstance(model, Pipeline): - return get_model_name(model.steps[-1][-1]) - else: - return model.__class__.__name__ + return model.__class__.__name__ def has_ndarray_int_columns(features, X): diff --git a/yellowbrick/utils/types.py b/yellowbrick/utils/types.py index 04bb67cec..4fc1a24fc 100644 --- a/yellowbrick/utils/types.py +++ b/yellowbrick/utils/types.py @@ -21,6 +21,7 @@ import numpy as np from sklearn.base import BaseEstimator +from yellowbrick.contrib.wrapper import ContribEstimator ########################################################################## @@ -39,9 +40,9 @@ def is_estimator(model): Scikit-Learn estimator or Yellowbrick visualizer """ if inspect.isclass(model): - return issubclass(model, BaseEstimator) + return issubclass(model, (BaseEstimator, ContribEstimator)) - return isinstance(model, BaseEstimator) + return 
isinstance(model, (BaseEstimator, ContribEstimator)) # Alias for closer name to isinstance and issubclass