diff --git a/docs/api/contrib/index.rst b/docs/api/contrib/index.rst index d79960a3c..fec517dcb 100644 --- a/docs/api/contrib/index.rst +++ b/docs/api/contrib/index.rst @@ -1,14 +1,19 @@ .. -*- mode: rst -*- -Yellowbrick Contrib =================== +Contrib and Third-Party Libraries ================================= -The ``yellowbrick.contrib`` package contains a variety of extra tools and experimental visualizers that are outside of core support or are still in development. Here is a listing of the contrib modules currently available: +Yellowbrick's primary dependencies are scikit-learn and matplotlib; however, the data science landscape in Python is large and there are many opportunities to use Yellowbrick with other machine learning and data analysis frameworks. The ``yellowbrick.contrib`` package contains several methods and mechanisms to support non-scikit-learn machine learning as well as extra tools and experimental visualizers that are outside of core support or are still in development. + +.. note:: If you're interested in using a non-scikit-learn estimator with Yellowbrick, please see the :doc:`wrapper` documentation. If the wrapper doesn't work out of the box, we welcome contributions to this module to include other libraries! + +The following contrib packages are currently available: .. toctree:: - :maxdepth: 2 + :maxdepth: 1 - boundaries + wrapper statsmodels + boundaries scatter missing/index diff --git a/docs/api/contrib/scatter.rst b/docs/api/contrib/scatter.rst index d6f754956..f978d0b1c 100644 --- a/docs/api/contrib/scatter.rst +++ b/docs/api/contrib/scatter.rst @@ -11,7 +11,7 @@ A scatter visualizer simply plots two features against each other and colors the :context: close-figs :alt: ScatterVisualizer on occupancy dataset - from yellowbrick.contrib import ScatterVisualizer + from yellowbrick.contrib.scatter import ScatterVisualizer from yellowbrick.datasets import load_occupancy # Load the classification dataset diff --git a/docs/api/contrib/statsmodels.rst b/docs/api/contrib/statsmodels.rst index 7fdf64034..d799d51c0 100644 --- a/docs/api/contrib/statsmodels.rst +++ b/docs/api/contrib/statsmodels.rst @@ -1,8 +1,54 @@ .. -*- mode: rst -*- -StatsModels Visualizers +statsmodels Visualizers ======================= +`statsmodels `_ is a Python library that provides utilities for the estimation of several statistical models and includes extensive results and metrics for each estimator. In particular, statsmodels excels at generalized linear models (GLMs) which are far superior to scikit-learn's implementation of ordinary least squares. + +This contrib module allows statsmodels users to take advantage of Yellowbrick visualizers by creating a wrapper class that implements the scikit-learn ``BaseEstimator``. Using the wrapper class, statsmodels estimators can be passed directly to many visualizers, customized for the scoring and metric functionality required. + +.. warning:: The statsmodels wrapper is currently a prototype and as such is a bit trivial. Many options and extra functionality such as weights are not currently handled. We are actively looking for statsmodels users to contribute to this package! + +Using the statsmodels wrapper: + +.. 
code:: python + + import statsmodels.api as sm + + from functools import partial + from yellowbrick.regressor import ResidualsPlot + from yellowbrick.contrib.statsmodels import StatsModelsWrapper + + glm_gaussian_partial = partial(sm.GLM, family=sm.families.Gaussian()) + model = StatsModelsWrapper(glm_gaussian_partial) + + viz = ResidualsPlot(model) + viz.fit(X_train, y_train) + viz.score(X_test, y_test) + viz.show() + +You can also use fitted estimators with the wrapper to avoid having to pass a partial function: + +.. code:: python + + from yellowbrick.regressor import prediction_error + + # Create the OLS model + model = sm.OLS(y, X) + + # Get the detailed results + results = model.fit() + print(results.summary()) + + # Visualize the prediction error + prediction_error(StatsModelsWrapper(model), X, y, is_fitted=True) + +This example also shows the use of a Yellowbrick oneliner, which is often more suited to the analytical style of statsmodels. + + +API Reference +------------- + .. automodule:: yellowbrick.contrib.statsmodels.base :members: StatsModelsWrapper :undoc-members: diff --git a/docs/api/contrib/wrapper.rst b/docs/api/contrib/wrapper.rst new file mode 100644 index 000000000..d365f9303 --- /dev/null +++ b/docs/api/contrib/wrapper.rst @@ -0,0 +1,104 @@ +.. -*- mode: rst -*- + +Using Third-Party Estimators +============================ + +Many machine learning libraries implement the scikit-learn estimator API to easily integrate alternative optimization or decision methods into a data science workflow. Because of this, it seems like it should be simple to drop a non-scikit-learn estimator into a Yellowbrick visualizer, and in principle, it is. However, the reality is a bit more complicated. + +Yellowbrick visualizers often utilize more than just the method interface of estimators (e.g. ``fit()`` and ``predict()``), relying on the learned attributes (object properties with a single underscore suffix, e.g. ``coef_``). The issue is that when a third-party estimator does not expose these attributes, truly gnarly exceptions and tracebacks occur. Yellowbrick is meant to aid reasoning about machine learning diagnostics; therefore, instead of just allowing drop-in functionality that may cause confusion, we've created wrapper functionality that is a bit kinder with its messaging. + +But first, an example. + +.. code:: python + + # Import the wrap function and a Yellowbrick visualizer + from yellowbrick.contrib.wrapper import wrap + from yellowbrick.model_selection import feature_importances + + # Instantiate the third party estimator and wrap it, optionally fitting it + model = wrap(ThirdPartyEstimator()) + model.fit(X_train, y_train) + + # Use the visualizer + oz = feature_importances(model, X_test, y_test, is_fitted=True) + +The ``wrap`` function initializes the third party model as a ``ContribEstimator``, which passes through all functionality to the underlying estimator; however, if an error occurs, the exception that will be raised looks like: + +.. code:: text + + yellowbrick.exceptions.YellowbrickAttributeError: estimator is missing the 'fit' + attribute, which is required for this visualizer - please see the third party + estimators documentation. + +Some estimators are required to pass type checking, for example the estimator must be a classifier, regressor, clusterer, density estimator, or outlier detector. A second argument can be passed to the ``wrap`` function declaring the type of estimator: + +.. 
code:: python + + from yellowbrick.classifier import precision_recall_curve + from yellowbrick.contrib.wrapper import wrap, CLASSIFIER + + model = wrap(ThirdPartyClassifier(), CLASSIFIER) + precision_recall_curve(model, X, y) + +Or you can simply use the wrap helper function for the specific estimator type: + +.. code:: python + + from yellowbrick.contrib.wrapper import regressor, classifier, clusterer + from yellowbrick.regressor import prediction_error + from yellowbrick.classifier import classification_report + from yellowbrick.cluster import intercluster_distance + + reg = regressor(ThirdPartyRegressor()) + prediction_error(reg, X, y) + + clf = classifier(ThirdPartyClassifier()) + classification_report(clf, X, y) + + ctr = clusterer(ThirdPartyClusterer()) + intercluster_distance(ctr, X) + + +So what should you do if a required attribute is missing from your estimator? The simplest and quickest thing to do is to subclass ``ContribEstimator`` and add the required functionality. + +.. code:: python + + from yellowbrick.contrib.wrapper import ContribEstimator, CLASSIFIER + + class MyWrapper(ContribEstimator): + + _estimator_type = CLASSIFIER + + @property + def feature_importances_(self): + return self.estimator.tree_feature_importances() + + + model = MyWrapper(ThirdPartyEstimator()) + feature_importances(model, X, y) + + +This is certainly less than ideal - but we'd welcome a contrib PR to add more native functionality to Yellowbrick! + +Tested Libraries +---------------- + +The following libraries have been tested with the Yellowbrick wrapper. + +- `xgboost `_: both the ``XGBRFRegressor`` and ``XGBRFClassifier`` have been tested with Yellowbrick both with and without the wrapper functionality. +- `CatBoost `_: the ``CatBoostClassifier`` has been tested with the ``ClassificationReport`` visualizer. + +The following libraries have been partially tested and will likely work without too much additional effort: + +- `cuML `_: it is likely that clustering, classification, and regression cuML estimators will work with Yellowbrick visualizers. However, the cuDF datasets have not been tested with Yellowbrick. +- `Spark MLlib `_: The Spark DataFrame API and estimators should work with Yellowbrick visualizers in a local notebook context after collection. + +.. note:: If you have used a Python machine learning library not listed here with Yellowbrick, please let us know - we'd love to add it to the list! Also if you're using a library that is not wholly compatible, please open an issue so that we can explore how to integrate it with the ``yellowbrick.contrib`` module! + +API Reference +------------- + +.. automodule:: yellowbrick.contrib.wrapper + :members: wrap, classifier, regressor, clusterer, ContribEstimator + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/changelog.rst b/docs/changelog.rst index 1f9fabf1f..4ffb5ba27 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,12 +3,47 @@ Changelog ========= +Version 1.2 +----------- + +* Tag: v1.2_ +* Deployed Friday, October 9, 2020 +* Current Contributors: Rebecca Bilbro, Larry Gray, Vladislav Skripniuk, David Landsman, Prema Roman, @aldermartinez, Tan Tran, Benjamin Bengfort, Kellen Donohue, Kristen McIntyre, Tony Ojeda, Edwin Schmierer, Adam Morris, Nathan Danielsen + +Major Changes: - Added Q-Q plot as side-by-side option to the ``ResidualsPlot`` visualizer. 
+ - More robust handling of binary classification in ``ROCAUC`` visualization, standardizing the way that classifiers with ``predict_proba`` and ``decision_function`` methods are handled. A ``binary`` hyperparameter was added to the visualizer to ensure correct interpretation of binary ROCAUC plots. + - Fixes to ``ManualAlphaSelection`` to move it from prototype to prime time including documentation, tests, and quick method. This method allows users to perform alpha selection visualization on non-CV estimators. + - Removal of AppVeyor from the CI matrix after too many out-of-core (non-Yellowbrick) failures with setup and installation on the VisualStudio images. Yellowbrick CI currently omits Windows and Miniconda from the test matrix and we are actively looking for new solutions. + - Third party estimator wrapper in contrib to provide enhanced support for non-scikit-learn estimators such as those in Keras, CatBoost, and cuML. + +Minor Changes: + - Allow users to specify colors for the ``PrecisionRecallCurve``. + - Update ``ClassificationScoreVisualizer`` base class to have a ``class_colors_`` learned attribute instead of a ``colors`` property; additional polishing of multi-class colors in ``PrecisionRecallCurve``, ``ROCAUC``, and ``ClassPredictionError``. + - Update ``KElbowVisualizer`` fit method and quick method to allow passing ``sample_weight`` parameter through the visualizer. + - Enhancements to classification documentation to better discuss precision and recall and to diagnose with ``PrecisionRecallCurve`` and ``ClassificationReport`` visualizers. + - Improvements to ``CooksDistance`` visualizer documentation. + - Corrected ``KElbowVisualizer`` label and legend formatting. + - Typo fixes to ``ROCAUC`` documentation, labels, and legend. Typo fix to ``Manifold`` documentation. + - Use of ``tight_layout`` accessing the Visualizer figure property to finalize images and resolve discrepancies in plot directive images in documentation. + - Add ``get_param_names`` helper function to identify keyword-only parameters that belong to a specific method. + - Splits package namespace for ``yellowbrick.regressor.residuals`` to move ``PredictionError`` to its own module, ``yellowbrick.regressor.prediction_error``. + - Update tests to use ``SVC`` instead of ``LinearSVC`` and correct ``KMeans`` scores based on updates to scikit-learn v0.23. + - Continued maintenance and management of baseline images following dependency updates; removal of mpl.cbook dependency. + - Explicitly include license file in source distribution via ``MANIFEST.in``. + - Fixes to some deprecation warnings from ``sklearn.metrics``. + - Testing requirements depend on Pandas v1.0.4 or later. + - Reintegrates pytest-spec and verbose test logging, updates pytest dependency to v0.5.0 or later. + - Added Pandas v0.20 or later to documentation dependencies. + +.. 
_v1.2: https://github.com/DistrictDataLabs/yellowbrick/releases/tag/v1.2 + Version 1.1 ----------- * Tag: v1.1_ * Deployed Wednesday, February 12, 2020 -* Contributors: Benjamin Bengfort, Rebecca Bilbro, Kristen McIntyre, Larry Gray, Prema Roman, Adam Morris, Shivendra Sharma, Michael Chestnut, Michael Garod, Naresh Bachwani, Piyush Gautam, Daniel Navarrete, Molly Morrison, Emma Kwiecinska, Sarthak Jain, Tony Ojeda, Edwin Schmier, Nathan Danielsen +* Contributors: Benjamin Bengfort, Rebecca Bilbro, Kristen McIntyre, Larry Gray, Prema Roman, Adam Morris, Shivendra Sharma, Michael Chestnut, Michael Garod, Naresh Bachwani, Piyush Gautam, Daniel Navarrete, Molly Morrison, Emma Kwiecinska, Sarthak Jain, Tony Ojeda, Edwin Schmierer, Nathan Danielsen Major Changes: - Quick methods (aka Oneliners), which return a fully fitted finalized visualizer object in only a single line, are now implemented for all Yellowbrick Visualizers. Test coverage has been added for all quick methods. The documentation has been updated to document and demonstrate the usage of the quick methods. @@ -47,7 +82,7 @@ Version 1.0 * Tag: v1.0_ * Deployed Wednesday, August 28, 2019 -* Contributors: Benjamin Bengfort, Rebecca Bilbro, Nathan Danielsen, Kristen McIntyre, Larry Gray, Prema Roman, Adam Morris, Tony Ojeda, Edwin Schmier, Carl Dawson, Daniel Navarrete, Francois Dion, Halee Mason, Jeff Hale, Jiayi Zhang, Jimmy Shah, John Healy, Justin Ormont, Kevin Arvai, Michael Garod, Mike Curry, Nabanita Dash, Naresh Bachwani, Nicholas A. Brown, Piyush Gautam, Pradeep Singh, Rohit Ganapathy, Ry Whittington, Sangarshanan, Sourav Singh, Thomas J Fan, Zijie (ZJ) Poh, Zonghan, Xie +* Contributors: Benjamin Bengfort, Rebecca Bilbro, Nathan Danielsen, Kristen McIntyre, Larry Gray, Prema Roman, Adam Morris, Tony Ojeda, Edwin Schmierer, Carl Dawson, Daniel Navarrete, Francois Dion, Halee Mason, Jeff Hale, Jiayi Zhang, Jimmy Shah, John Healy, Justin Ormont, Kevin Arvai, Michael Garod, Mike Curry, Nabanita Dash, Naresh Bachwani, Nicholas A. Brown, Piyush Gautam, Pradeep Singh, Rohit Ganapathy, Ry Whittington, Sangarshanan, Sourav Singh, Thomas J Fan, Zijie (ZJ) Poh, Zonghan, Xie .. warning:: **Python 2 Deprecation**: Please note that this release deprecates Yellowbrick's support for Python 2.7. After careful consideration and following the lead of our primary dependencies (NumPy, scikit-learn, and Matplolib), we have chosen to move forward with the community and support Python 3.4 and later. diff --git a/docs/faq.rst b/docs/faq.rst index 218fda939..24dbcac5c 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -130,3 +130,11 @@ How can I access the sample datasets used in the examples? ---------------------------------------------------------- Visit the :doc:`api/datasets/index` page. + + +Can I use Yellowbrick with libraries other than scikit-learn? +------------------------------------------------------------- + +Potentially! Yellowbrick visualizers rely on the internal model implementing the scikit-learn API (e.g. having a ``fit()`` and ``predict()`` method), and often expect to be able to retrieve learned attributes from the model (e.g. ``coef_``). Some third-party estimators fully implement the scikit-learn API, but not all do. + +When using third-party libraries with Yellowbrick, we encourage you to ``wrap`` the model using the ``yellowbrick.contrib.wrapper`` module. Visit the :doc:`api/contrib/wrapper` page for all the details! 
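To make the FAQ answer above concrete, here is a minimal sketch of the wrapping workflow that the new ``docs/api/contrib/wrapper.rst`` describes. It assumes xgboost is installed (it is listed later in this patch under Tested Libraries); the synthetic dataset and variable names are illustrative only and not part of the patch.

.. code:: python

    from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split

    from xgboost import XGBClassifier  # any estimator exposing the sklearn API
    from yellowbrick.classifier import classification_report
    from yellowbrick.contrib.wrapper import classifier

    # Synthetic data stands in for a real classification problem
    X, y = make_classification(n_samples=500, n_features=12, random_state=42)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

    # Wrap the third-party estimator so it passes Yellowbrick's type checks
    model = classifier(XGBClassifier())

    # The quick method fits the wrapped model and renders the report on the test split
    viz = classification_report(model, X_train, y_train, X_test, y_test)

The call signature mirrors the one exercised in ``tests/test_contrib/test_wrapper.py`` below.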
diff --git a/tests/__init__.py b/tests/__init__.py index 533cdee65..ac99a9781 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -27,7 +27,7 @@ ## Test Constants ########################################################################## -EXPECTED_VERSION = "1.1" +EXPECTED_VERSION = "1.2" ########################################################################## diff --git a/tests/requirements.txt b/tests/requirements.txt index ee172df6a..8eb66e424 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -18,3 +18,6 @@ nltk>=3.2 pandas>=1.0.4 umap-learn==0.3.9 # reminder to bump periodically! +# Third-Party Estimator Tests +# xgboost==1.2.0 +# catboost==0.24.1 \ No newline at end of file diff --git a/tests/test_contrib/test_wrapper.py b/tests/test_contrib/test_wrapper.py new file mode 100644 index 000000000..cbb59b90c --- /dev/null +++ b/tests/test_contrib/test_wrapper.py @@ -0,0 +1,231 @@ +# tests.test_contrib.test_wrapper +# Tests third-party estimator wrapper +# +# Author: Benjamin Bengfort +# Created: Fri Oct 02 17:27:50 2020 -0400 +# +# Copyright (C) 2020 The scikit-yb developers +# For license information, see LICENSE.txt +# +# ID: test_wrapper.py [] benjamin@bengfort.com $ + +""" +Tests third-party estimator wrapper +""" + +########################################################################## +## Imports +########################################################################## + +import pytest + +from yellowbrick.contrib.wrapper import * +from yellowbrick.regressor import residuals_plot +from yellowbrick.classifier import classification_report +from yellowbrick.exceptions import YellowbrickAttributeError +from yellowbrick.utils.helpers import get_model_name, is_fitted +from yellowbrick.utils.types import is_estimator, is_probabilistic +from yellowbrick.utils.types import is_classifier, is_clusterer, is_regressor + +from sklearn.base import is_regressor as sk_is_regressor +from sklearn.base import is_classifier as sk_is_classifier +from sklearn.model_selection import train_test_split as tts +from sklearn.datasets import make_classification, make_regression +from sklearn.base import is_outlier_detector as sk_is_outlier_detector + + +try: + import xgboost as xgb +except ImportError: + xgb = None + +try: + import cudf + from cuml.ensemble import RandomForestClassifier as curfc +except ImportError: + curfc = None + +try: + import catboost +except ImportError: + catboost = None + + +########################################################################## +## Mocks and Fixtures +########################################################################## + +class ThirdPartyEstimator(object): + + def __init__(self, **params): + for attr, param in params.items(): + setattr(self, attr, param) + + def fit(self, X, y=None): + return 42 + + def predict_proba(self, X, y=None): + return 24 + + +########################################################################## +## Test Suite +########################################################################## + +class TestContribWrapper(object): + """ + Third party ContribEstimator wrapper + """ + + def test_wrapped_estimator(self): + """ + Check that the contrib wrapper passes through correctly + """ + tpe = ContribEstimator(ThirdPartyEstimator(foo="bar"), "foo") + assert tpe.fit([1,2,3]) == 42 + assert tpe._estimator_type == "foo" + + def test_attribute_error(self): + """ + Assert a correct exception is raised on failed access + """ + tpe = ContribEstimator(ThirdPartyEstimator()) + with 
pytest.raises(YellowbrickAttributeError): + tpe.foo + + def test_get_model_name(self): + """ + ContribWrapper should return the underlying model name + """ + tpe = ContribEstimator(ThirdPartyEstimator()) + assert get_model_name(tpe) == "ThirdPartyEstimator" + + def test_wraps_is_estimator(self): + """ + Assert a wrapped estimator passes is_estimator check + """ + tpe = wrap(ThirdPartyEstimator()) + assert is_estimator(tpe) + + def test_wraps_is_classifier(self): + """ + Assert a wrapped estimator passes is_classifier check + """ + tpe = classifier(ThirdPartyEstimator()) + assert is_classifier(tpe) + assert sk_is_classifier(tpe) + + def test_wraps_is_regressor(self): + """ + Assert a wrapped estimator passes is_regressor check + """ + tpe = regressor(ThirdPartyEstimator()) + assert is_regressor(tpe) + assert sk_is_regressor(tpe) + + def test_wraps_is_clusterer(self): + """ + Assert a wrapped estimator passes is_clusterer check + """ + tpe = clusterer(ThirdPartyEstimator()) + assert is_clusterer(tpe) + + def test_wraps_is_outlier_detector(self): + """ + Assert a wrapped estimator passes is_outlier_detector check + """ + tpe = wrap(ThirdPartyEstimator(), OUTLIER_DETECTOR) + assert sk_is_outlier_detector(tpe) + + def test_wraps_is_probabilistic(self): + """ + Assert a wrapped estimator passes is_probabilistic check + """ + tpe = wrap(ThirdPartyEstimator()) + assert is_probabilistic(tpe) + + +########################################################################## +## Test Non-sklearn Estimators +########################################################################## + +class TestNonSklearnEstimators(object): + """ + Check various non-sklearn estimators to see if the wrapper works for them + """ + + @pytest.mark.skipif(xgb is None, reason="requires xgboost") + def test_xgboost_regressor(self): + """ + Validate xgboost regressor with wrapper + """ + X, y = make_regression( + n_samples=500, n_features=22, n_informative=8, random_state=8311982 + ) + X_train, X_test, y_train, y_test = tts(X, y) + + model = regressor(xgb.XGBRFRegressor()) + oz = residuals_plot(model, X_train, y_train, X_test, y_test, show=False) + assert is_fitted(oz) + + + @pytest.mark.skipif(xgb is None, reason="requires xgboost") + def test_xgboost_regressor_unwrapped(self): + """ + Validate xgboost regressor without wrapper + """ + X, y = make_regression( + n_samples=500, n_features=22, n_informative=8, random_state=8311982 + ) + X_train, X_test, y_train, y_test = tts(X, y) + + model = xgb.XGBRFRegressor() + oz = residuals_plot(model, X_train, y_train, X_test, y_test, show=False) + assert is_fitted(oz) + + @pytest.mark.skipif(curfc is None, reason="requires cuML") + def test_cuml_classifier(self): + """ + Validate cuML classifier with wrapper + """ + # NOTE: this is currently untested as I wasn't able to install cuML + X, y = make_classification( + n_samples=400, n_features=10, n_informative=2, n_redundant=3, + n_classes=2, n_clusters_per_class=2, random_state=8311982 + ) + X_train, X_test, y_train, y_test = tts(X, y) + + # Convert to cudf dataframes + X_train = cudf.DataFrame(X_train) + y_train = cudf.Series(y_train) + X_test = cudf.DataFrame(X_test) + y_test = cudf.Series(y_test) + + model = classifier(curfc(n_estimators=40, max_depth=8, max_features=1)) + oz = classification_report(model, X_train, y_train, X_test, y_test, show=False) + assert is_fitted(oz) + + @pytest.mark.skipif(catboost is None, reason="requires CatBoost") + def test_catboost_classifier(self): + """ + Validate CatBoost classifier with wrapper + """ + 
X, y = make_classification( + n_samples=400, n_features=10, n_informative=2, n_redundant=3, + n_classes=2, n_clusters_per_class=2, random_state=8311982 + ) + X_train, X_test, y_train, y_test = tts(X, y) + + model = classifier(catboost.CatBoostClassifier( + iterations=2, depth=2, learning_rate=1, loss_function='Logloss' + )) + + # For some reason, this works if you call fit directly and pass is_fitted to + # the visualizer, but does not work if you rely on the visualizer to fit the + # model on the data. I can't tell if this is a catboost or Yellowbrick issue. + model.fit(X_train, y_train) + + oz = classification_report( + model, X_train, y_train, X_test, y_test, is_fitted=True, show=False + ) + assert is_fitted(oz) diff --git a/yellowbrick/contrib/__init__.py b/yellowbrick/contrib/__init__.py index e9e66dfe4..0e8ef1173 100644 --- a/yellowbrick/contrib/__init__.py +++ b/yellowbrick/contrib/__init__.py @@ -5,6 +5,3 @@ # # # ID: __init__.py [a60bc41] nathan.danielsen@gmail.com $ - - -from .scatter import ScatterViz, ScatterVisualizer, scatterviz diff --git a/yellowbrick/contrib/wrapper.py b/yellowbrick/contrib/wrapper.py new file mode 100644 index 000000000..3db8c35d6 --- /dev/null +++ b/yellowbrick/contrib/wrapper.py @@ -0,0 +1,134 @@ +# yellowbrick.contrib.wrapper +# Wrapper for third-party estimators that implement the sklearn API +# +# Author: Benjamin Bengfort +# Created: Fri Oct 02 14:47:54 2020 -0400 +# +# Copyright (C) 2020 The scikit-yb developers +# For license information, see LICENSE.txt +# +# ID: wrapper.py [] benjamin@bengfort.com $ + +""" +Wrapper for third-party estimators that implement the sklearn API but do not directly +subclass the ``sklearn.base.BaseEstimator`` class. This method is a quick way to get +other estimators into Yellowbrick, while avoiding weird errors and issues. +""" + +########################################################################## +## Imports +########################################################################## + +from yellowbrick.exceptions import YellowbrickAttributeError + + +########################################################################## +## Module Constants +########################################################################## + +CLASSIFIER = "classifier" +REGRESSOR = "regressor" +CLUSTERER = "clusterer" +DENSITY_ESTIMATOR = "DensityEstimator" +OUTLIER_DETECTOR = "outlier_detector" + + +########################################################################## +## Functional API +########################################################################## + +def wrap(estimator, estimator_type=None): + """ + Wrap a third-party estimator that implements portions of the scikit-learn API to + make it available to Yellowbrick visualizers. If the Yellowbrick visualizer cannot + succeed, then a sensible error is raised instead. + + Parameters + ---------- + estimator : object + The non-sklearn estimator to wrap and use for Visualizers + + estimator_type : str, optional + One of "classifier", "regressor", "clusterer", "DensityEstimator", or + "outlier_detector" that allows the contrib estimator to pass the scikit-learn + ``is_classifier``, etc. functions. If not specified, the _estimator_type attr + is passed through to the underlying estimator. + """ + return ContribEstimator(estimator, estimator_type) + + +def classifier(estimator): + """ + Wrap a third-party classifier to make it available to Yellowbrick visualizers. 
+ + Parameters + ---------- + estimator : object + The non-sklearn classifier to wrap and use for Visualizers + """ + return wrap(estimator, CLASSIFIER) + + +def regressor(estimator): + """ + Wrap a third-party regressor to make it available to Yellowbrick visualizers. + + Parameters + ---------- + estimator : object + The non-sklearn regressor to wrap and use for Visualizers + """ + return wrap(estimator, REGRESSOR) + + +def clusterer(estimator): + """ + Wrap a third-party clusterer to make it available to Yellowbrick visualizers. + + Parameters + ---------- + estimator : object + The non-sklearn clusterer to wrap and use for Visualizers + """ + return wrap(estimator, CLUSTERER) + + +########################################################################## +## ContribEstimator - Third Party Estimator Wrapper +########################################################################## + +class ContribEstimator(object): + """ + Wraps a third party estimator that implements the scikit-learn API and therefore + could be used with Yellowbrick but doesn't subclass ``BaseEstimator``. Since there + are a number of pitfalls, this object provides sensible errors and warnings rather + than completely blowing up, allowing contrib users to identify issues and fix them, + smoothing the path to getting third party estimators into the Yellowbrick ecosystem. + + Parameters + ---------- + estimator : object + The non-sklearn estimator to wrap and use for Visualizers + + estimator_type : str, optional + One of "classifier", "regressor", "clusterer", "DensityEstimator", or + "outlier_detector" that allows the contrib estimator to pass the scikit-learn + ``is_classifier``, etc. functions. If not specified, the _estimator_type attr + is passed through to the underlying estimator. + """ + + def __init__(self, estimator, estimator_type=None): + self.estimator = estimator + # Do not set estimator type if not specified to allow passthrough + if estimator_type: + self._estimator_type = estimator_type + + def __getattr__(self, attr): + # proxy to the wrapped object + try: + return getattr(self.estimator, attr) + except AttributeError: + raise YellowbrickAttributeError(( + "estimator is missing the '{}' attribute, which is required for this " + "visualizer - please see the third party estimators documentation." + ).format(attr)) diff --git a/yellowbrick/exceptions.py b/yellowbrick/exceptions.py index 2bf03247e..0e21d1c56 100644 --- a/yellowbrick/exceptions.py +++ b/yellowbrick/exceptions.py @@ -89,6 +89,14 @@ class YellowbrickKeyError(YellowbrickError, KeyError): pass +class YellowbrickAttributeError(YellowbrickError, AttributeError): + """ + A required attribute is missing on the estimator. 
+ """ + + pass + + ########################################################################## ## Assertions ########################################################################## diff --git a/yellowbrick/utils/helpers.py b/yellowbrick/utils/helpers.py index edaa0bb6b..6cb43672c 100644 --- a/yellowbrick/utils/helpers.py +++ b/yellowbrick/utils/helpers.py @@ -28,12 +28,13 @@ from yellowbrick.utils.types import is_estimator from yellowbrick.exceptions import YellowbrickTypeError +from yellowbrick.contrib.wrapper import ContribEstimator + ########################################################################## ## Model and Feature Information ########################################################################## - def is_fitted(estimator): """ In order to ensure that we don't call ``fit`` on an already-fitted model, @@ -137,11 +138,12 @@ def get_model_name(model): "Cannot detect the model name for non estimator: '{}'".format(type(model)) ) + if isinstance(model, Pipeline): + return get_model_name(model.steps[-1][-1]) + elif isinstance(model, ContribEstimator): + return model.estimator.__class__.__name__ else: - if isinstance(model, Pipeline): - return get_model_name(model.steps[-1][-1]) - else: - return model.__class__.__name__ + return model.__class__.__name__ def has_ndarray_int_columns(features, X): diff --git a/yellowbrick/utils/types.py b/yellowbrick/utils/types.py index 04bb67cec..4fc1a24fc 100644 --- a/yellowbrick/utils/types.py +++ b/yellowbrick/utils/types.py @@ -21,6 +21,7 @@ import numpy as np from sklearn.base import BaseEstimator +from yellowbrick.contrib.wrapper import ContribEstimator ########################################################################## @@ -39,9 +40,9 @@ def is_estimator(model): Scikit-Learn estimator or Yellowbrick visualizer """ if inspect.isclass(model): - return issubclass(model, BaseEstimator) + return issubclass(model, (BaseEstimator, ContribEstimator)) - return isinstance(model, BaseEstimator) + return isinstance(model, (BaseEstimator, ContribEstimator)) # Alias for closer name to isinstance and issubclass diff --git a/yellowbrick/version.py b/yellowbrick/version.py index 95f134ddc..fb72e2027 100644 --- a/yellowbrick/version.py +++ b/yellowbrick/version.py @@ -19,11 +19,11 @@ __version_info__ = { "major": 1, - "minor": 1, + "minor": 2, "micro": 0, "releaselevel": "final", "post": 0, - "serial": 17, + "serial": 18, } ##########################################################################