Third Party Estimator Wrapper (#1103)
This PR introduces a wrapper for estimators that implement the scikit-learn API but do
not extend BaseEstimator. If the estimator is missing required properties (generally the learned attributes) then a sensible error is raised. Includes documentation about how to use non-sklearn estimators with Yellowbrick. Closes #1098, closes #1099, closes #397
bbengfort committed Oct 5, 2020
1 parent 0b24d16 commit 1329281
Showing 11 changed files with 548 additions and 17 deletions.
15 changes: 10 additions & 5 deletions docs/api/contrib/index.rst
@@ -1,14 +1,19 @@
.. -*- mode: rst -*-
Yellowbrick Contrib
===================
Contrib and Third-Party Libraries
=================================

The ``yellowbrick.contrib`` package contains a variety of extra tools and experimental visualizers that are outside of core support or are still in development. Here is a listing of the contrib modules currently available:
Yellowbrick's primary dependencies are scikit-learn and matplotlib; however, the data science landscape in Python is large and there are many opportunities to use Yellowbrick with other machine learning and data analysis frameworks. The ``yellowbrick.contrib`` package contains several methods and mechanisms to support non-scikit-learn machine learning as well as extra tools and experimental visualizers that are outside of core support or are still in development.

.. note:: If you're interested in using a non-scikit-learn estimator with Yellowbrick, please see the :doc:`wrapper` documentation. If the wrapper doesn't work out of the box, we welcome contributions to this module to include other libraries!

The following contrib packages are currently available:

.. toctree::
:maxdepth: 2
:maxdepth: 1

boundaries
wrapper
statsmodels
boundaries
scatter
missing/index
2 changes: 1 addition & 1 deletion docs/api/contrib/scatter.rst
@@ -11,7 +11,7 @@ A scatter visualizer simply plots two features against each other and colors the
:context: close-figs
:alt: ScatterVisualizer on occupancy dataset

from yellowbrick.contrib import ScatterVisualizer
from yellowbrick.contrib.scatter import ScatterVisualizer
from yellowbrick.datasets import load_occupancy

# Load the classification dataset
48 changes: 47 additions & 1 deletion docs/api/contrib/statsmodels.rst
@@ -1,8 +1,54 @@
.. -*- mode: rst -*-
StatsModels Visualizers
statsmodels Visualizers
=======================

`statsmodels <https://www.statsmodels.org/stable/index.html>`_ is a Python library that provides utilities for the estimation of several statistical models and includes extensive results and metrics for each estimator. In particular, statsmodels excels at generalized linear models (GLMs), which are far superior to scikit-learn's implementation of ordinary least squares.

This contrib module allows statsmodels users to take advantage of Yellowbrick visualizers by creating a wrapper class that implements the scikit-learn ``BaseEstimator``. Using the wrapper class, statsmodels estimators can be passed directly to many visualizers, customized for the scoring and metric functionality required.

.. warning:: The statsmodels wrapper is currently a prototype and as such is a bit trivial. Many options and extra functionality such as weights are not currently handled. We are actively looking for statsmodels users to contribute to this package!

Using the statsmodels wrapper:

.. code:: python

    import statsmodels.api as sm

    from functools import partial
    from yellowbrick.regressor import ResidualsPlot
    from yellowbrick.contrib.statsmodels import StatsModelsWrapper

    glm_gaussian_partial = partial(sm.GLM, family=sm.families.Gaussian())
    model = StatsModelsWrapper(glm_gaussian_partial)

    viz = ResidualsPlot(model)
    viz.fit(X_train, y_train)
    viz.score(X_test, y_test)
    viz.show()

You can also use fitted estimators with the wrapper to avoid having to pass a partial function:

.. code:: python

    import statsmodels.api as sm

    from yellowbrick.regressor import prediction_error
    from yellowbrick.contrib.statsmodels import StatsModelsWrapper

    # Create the OLS model
    model = sm.OLS(y, X)

    # Get the detailed results
    results = model.fit()
    print(results.summary())

    # Visualize the prediction error
    prediction_error(StatsModelsWrapper(model), X, y, is_fitted=True)

This example also shows the use of a Yellowbrick oneliner, which is often more suited to the analytical style of statsmodels.


API Reference
-------------

.. automodule:: yellowbrick.contrib.statsmodels.base
:members: StatsModelsWrapper
:undoc-members:
104 changes: 104 additions & 0 deletions docs/api/contrib/wrapper.rst
@@ -0,0 +1,104 @@
.. -*- mode: rst -*-
Using Third-Party Estimators
============================

Many machine learning libraries implement the scikit-learn estimator API to easily integrate alternative optimization or decision methods into a data science workflow. Because of this, it seems like it should be simple to drop a non-scikit-learn estimator into a Yellowbrick visualizer, and in principle, it is. However, the reality is a bit more complicated.

Yellowbrick visualizers often utilize more than just the method interface of estimators (e.g. ``fit()`` and ``predict()``), relying on the learned attributes (object properties with a single underscore suffix, e.g. ``coef_``). The issue is that when a third-party estimator does not expose these attributes, truly gnarly exceptions and tracebacks occur. Yellowbrick is meant to aid reasoning about machine learning diagnostics; therefore, instead of just allowing drop-in functionality that may cause confusion, we've created wrapper functionality that is a bit kinder with its messaging.

But first, an example.

.. code:: python

    # Import the wrap function and a Yellowbrick visualizer
    from yellowbrick.contrib.wrapper import wrap
    from yellowbrick.model_selection import feature_importances

    # Instantiate the third party estimator and wrap it, optionally fitting it
    model = wrap(ThirdPartyEstimator())
    model.fit(X_train, y_train)

    # Use the visualizer
    oz = feature_importances(model, X_test, y_test, is_fitted=True)

The ``wrap`` function initializes the third party model as a ``ContribEstimator``, which passes all functionality through to the underlying estimator. However, if an error occurs, the exception that is raised looks like:

.. code:: text

    yellowbrick.exceptions.YellowbrickAttributeError: estimator is missing the 'fit'
    attribute, which is required for this visualizer - please see the third party
    estimators documentation.

Some visualizers also require the estimator to pass a type check; for example, the estimator must be a classifier, regressor, clusterer, density estimator, or outlier detector. A second argument can be passed to the ``wrap`` function declaring the type of estimator:

.. code:: python

    from yellowbrick.classifier import precision_recall_curve
    from yellowbrick.contrib.wrapper import wrap, CLASSIFIER

    model = wrap(ThirdPartyClassifier(), CLASSIFIER)
    precision_recall_curve(model, X, y)

Or you can simply use the wrapper helper functions for the specific estimator type:

.. code:: python

    from yellowbrick.contrib.wrapper import regressor, classifier, clusterer

    from yellowbrick.regressor import prediction_error
    from yellowbrick.classifier import classification_report
    from yellowbrick.cluster import intercluster_distance

    reg = regressor(ThirdPartyRegressor())
    prediction_error(reg, X, y)

    clf = classifier(ThirdPartyClassifier())
    classification_report(clf, X, y)

    ctr = clusterer(ThirdPartyClusterer())
    intercluster_distance(ctr, X)

So what should you do if a required attribute is missing from your estimator? The simplest and quickest thing to do is to subclass ``ContribEstimator`` and add the required functionality.

.. code:: python

    from yellowbrick.contrib.wrapper import ContribEstimator, CLASSIFIER

    class MyWrapper(ContribEstimator):

        _estimator_type = CLASSIFIER

        @property
        def feature_importances_(self):
            return self.estimator.tree_feature_importances()

    model = MyWrapper(ThirdPartyEstimator())
    feature_importances(model, X, y)

This is certainly less than ideal - but we'd welcome a contrib PR to add more native functionality to Yellowbrick!

Tested Libraries
----------------

The following libraries have been tested with the Yellowbrick wrapper:

- `xgboost <https://xgboost.readthedocs.io/en/latest/index.html>`_: both the ``XGBRFRegressor`` and ``XGBRFClassifier`` have been tested with Yellowbrick, both with and without the wrapper functionality.
- `CatBoost <https://catboost.ai/>`_: the ``CatBoostClassifier`` has been tested with the ``ClassificationReport`` visualizer (a minimal sketch is shown just after this list).
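For example, a CatBoost model can be wrapped with the ``classifier`` helper and passed to the ``ClassificationReport`` visualizer. The sketch below assumes ``catboost`` is installed and that ``X_train``, ``y_train``, ``X_test``, and ``y_test`` are already defined; the ``CatBoostClassifier`` parameters are illustrative only:

.. code:: python

    from catboost import CatBoostClassifier

    from yellowbrick.classifier import ClassificationReport
    from yellowbrick.contrib.wrapper import classifier

    # Wrap the CatBoost estimator so Yellowbrick treats it as a classifier
    model = classifier(CatBoostClassifier(iterations=100, verbose=False))

    # Use the wrapped estimator exactly as you would a scikit-learn model
    viz = ClassificationReport(model)
    viz.fit(X_train, y_train)
    viz.score(X_test, y_test)
    viz.show()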
The following libraries have been partially tested and will likely work without too much additional effort:

- `cuML <https://github.com/rapidsai/cuml>`_: it is likely that clustering, classification, and regression cuML estimators will work with Yellowbrick visualizers (see the speculative sketch after this list). However, the cuDF datasets have not been tested with Yellowbrick.
- `Spark MLlib <https://spark.apache.org/docs/latest/ml-guide.html>`_: the Spark DataFrame API and estimators should work with Yellowbrick visualizers in a local notebook context after collection.
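As a speculative illustration of the cuML case, the sketch below wraps a cuML ``KMeans`` estimator with the ``clusterer`` helper. It assumes scikit-learn-style behavior from ``cuml.cluster.KMeans`` and NumPy (host) inputs rather than cuDF objects, and has not been verified here:

.. code:: python

    from cuml.cluster import KMeans

    from yellowbrick.cluster import intercluster_distance
    from yellowbrick.contrib.wrapper import clusterer

    # Wrap the cuML estimator so Yellowbrick treats it as a clusterer;
    # n_clusters=5 is an arbitrary choice for illustration
    model = clusterer(KMeans(n_clusters=5))

    # X is assumed to be a NumPy array on the host, not a cuDF DataFrame
    intercluster_distance(model, X)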
.. note:: If you have used a Python machine learning library not listed here with Yellowbrick, please let us know - we'd love to add it to the list! Also if you're using a library that is not wholly compatible, please open an issue so that we can explore how to integrate it with the ``yellowbrick.contrib`` module!

API Reference
-------------

.. automodule:: yellowbrick.contrib.wrapper
:members: wrap, classifier, regressor, clusterer, ContribEstimator
:undoc-members:
:show-inheritance:
3 changes: 3 additions & 0 deletions tests/requirements.txt
@@ -18,3 +18,6 @@ nltk>=3.2
pandas>=1.0.4
umap-learn==0.3.9 # reminder to bump periodically!

# Third-Party Estimator Tests
# xgboost==1.2.0
# catboost==0.24.1
