Skip to content

Commit

Permalink
Third Party Estimator Wrapper
Browse files Browse the repository at this point in the history
Adds a wrapper for estimators that implement the scikit-learn API but do
not extend BaseEstimator. If the estimator is missing required
properties (generally the learned attributes) then a sensible error is
raised. Includes documentation about how to use non-sklearn estimators
with Yellowbrick.

Fixes #1098
Fixes #1099
Related to #397
Related to #1066
  • Loading branch information
bbengfort committed Oct 3, 2020
1 parent 0b24d16 commit f4576eb
Show file tree
Hide file tree
Showing 6 changed files with 277 additions and 10 deletions.
125 changes: 125 additions & 0 deletions tests/test_contrib/test_wrapper.py
@@ -0,0 +1,125 @@
# tests.test_contrib.test_wrapper
# Tests third-party estimator wrapper
#
# Author: Benjamin Bengfort
# Created: Fri Oct 02 17:27:50 2020 -0400
#
# Copyright (C) 2020 The scikit-yb developers
# For license information, see LICENSE.txt
#
# ID: test_wrapper.py [] benjamin@bengfort.com $

"""
Tests third-party estimator wrapper
"""

##########################################################################
## Imports
##########################################################################

import pytest


from yellowbrick.contrib.wrapper import *
from yellowbrick.utils.helpers import get_model_name
from yellowbrick.exceptions import YellowbrickAttributeError
from yellowbrick.utils.types import is_estimator, is_probabilistic
from yellowbrick.utils.types import is_classifier, is_clusterer, is_regressor

from sklearn.base import is_regressor as sk_is_regressor
from sklearn.base import is_classifier as sk_is_classifier
from sklearn.base import is_outlier_detector as sk_is_outlier_detector


##########################################################################
## Mocks and Fixtures
##########################################################################

class ThirdPartyEstimator(object):
    """
    Minimal stand-in for an estimator that does not extend ``BaseEstimator``.

    Every keyword argument passed to the constructor becomes an instance attribute.
    ``fit`` and ``predict_proba`` ignore their inputs and return fixed sentinel
    values so that tests can recognize a call that was passed through a wrapper.
    """

    def __init__(self, **params):
        # Expose each keyword argument as an attribute of the same name
        for name in params:
            setattr(self, name, params[name])

    def fit(self, X, y=None):
        # Fixed sentinel; the inputs are not used
        return 42

    def predict_proba(self, X, y=None):
        # Fixed sentinel; the inputs are not used
        return 24


##########################################################################
## Test Suite
##########################################################################

class TestContribWrapper(object):
    """
    Tests for the third party ContribEstimator wrapper and its helper functions
    """

    def test_wrapped_estimator(self):
        """
        Method calls and the explicit estimator type pass through the wrapper
        """
        wrapped = ContribEstimator(ThirdPartyEstimator(foo="bar"), "foo")
        assert wrapped.fit([1, 2, 3]) == 42
        assert wrapped._estimator_type == "foo"

    def test_attribute_error(self):
        """
        Accessing an attribute the estimator lacks raises the Yellowbrick error
        """
        wrapped = ContribEstimator(ThirdPartyEstimator())
        with pytest.raises(YellowbrickAttributeError):
            getattr(wrapped, "foo")

    def test_get_model_name(self):
        """
        get_model_name reports the class name of the wrapped estimator
        """
        wrapped = ContribEstimator(ThirdPartyEstimator())
        assert get_model_name(wrapped) == "ThirdPartyEstimator"

    def test_wraps_is_estimator(self):
        """
        An estimator wrapped with wrap() passes the is_estimator check
        """
        wrapped = wrap(ThirdPartyEstimator())
        assert is_estimator(wrapped)

    def test_wraps_is_classifier(self):
        """
        classifier() output passes the Yellowbrick and sklearn is_classifier checks
        """
        model = classifier(ThirdPartyEstimator())
        assert is_classifier(model)
        assert sk_is_classifier(model)

    def test_wraps_is_regressor(self):
        """
        regressor() output passes the Yellowbrick and sklearn is_regressor checks
        """
        model = regressor(ThirdPartyEstimator())
        assert is_regressor(model)
        assert sk_is_regressor(model)

    def test_wraps_is_clusterer(self):
        """
        clusterer() output passes the is_clusterer check
        """
        model = clusterer(ThirdPartyEstimator())
        assert is_clusterer(model)

    def test_wraps_is_outlier_detector(self):
        """
        wrap() with OUTLIER_DETECTOR passes the sklearn is_outlier_detector check
        """
        model = wrap(ThirdPartyEstimator(), estimator_type=OUTLIER_DETECTOR)
        assert sk_is_outlier_detector(model)

    def test_wraps_is_probabilistic(self):
        """
        A wrapped estimator exposing predict_proba passes the is_probabilistic check
        """
        wrapped = wrap(ThirdPartyEstimator())
        assert is_probabilistic(wrapped)
3 changes: 0 additions & 3 deletions yellowbrick/contrib/__init__.py
Expand Up @@ -5,6 +5,3 @@
#
#
# ID: __init__.py [a60bc41] nathan.danielsen@gmail.com $


from .scatter import ScatterViz, ScatterVisualizer, scatterviz
134 changes: 134 additions & 0 deletions yellowbrick/contrib/wrapper.py
@@ -0,0 +1,134 @@
# yellowbrick.contrib.wrapper
# Wrapper for third-party estimators that implement the sklearn API
#
# Author: Benjamin Bengfort
# Created: Fri Oct 02 14:47:54 2020 -0400
#
# Copyright (C) 2020 The scikit-yb developers
# For license information, see LICENSE.txt
#
# ID: wrapper.py [] benjamin@bengfort.com $

"""
Wrapper for third-party estimators that implement the sklearn API but do not directly
subclass the ``sklearn.base.BaseEstimator`` class. This wrapper is a quick way to get
other estimators into Yellowbrick, while avoiding weird errors and issues.
"""

##########################################################################
## Imports
##########################################################################

from yellowbrick.exceptions import YellowbrickAttributeError


##########################################################################
## Module Constants
##########################################################################

# Values accepted for the ``estimator_type`` argument of ``wrap`` and
# ``ContribEstimator``; when supplied, the value is stored on the wrapper as its
# ``_estimator_type`` attribute.
CLASSIFIER = "classifier"
REGRESSOR = "regressor"
CLUSTERER = "clusterer"
# NOTE(review): CamelCase unlike the other values - presumably mirrors the spelling
# scikit-learn itself uses for density estimators; confirm before changing.
DENSITY_ESTIMATOR = "DensityEstimator"
OUTLIER_DETECTOR = "outlier_detector"


##########################################################################
## Functional API
##########################################################################

def wrap(estimator, estimator_type=None):
    """
    Make a third-party estimator that implements portions of the scikit-learn API
    usable with Yellowbrick visualizers by wrapping it in a ``ContribEstimator``. If
    the Yellowbrick visualizer cannot succeed, a sensible error is raised instead.

    Parameters
    ----------
    estimator : object
        The non-sklearn estimator to wrap and use for Visualizers

    estimator_type : str, optional
        One of "classifier", "regressor", "clusterer", "DensityEstimator", or
        "outlier_detector" that allows the contrib estimator to pass the scikit-learn
        ``is_classifier``, etc. functions. If not specified, the _estimator_type attr
        is passed through to the underlying estimator.

    Returns
    -------
    wrapped : ContribEstimator
        The wrapper around ``estimator``
    """
    wrapped = ContribEstimator(estimator, estimator_type=estimator_type)
    return wrapped


def classifier(estimator):
    """
    Shortcut for ``wrap`` that marks the third-party estimator as a classifier.

    Parameters
    ----------
    estimator : object
        The non-sklearn classifier to wrap and use for Visualizers

    Returns
    -------
    wrapped : object
        Whatever ``wrap`` returns for ``estimator`` with the classifier type
    """
    wrapped = wrap(estimator, estimator_type=CLASSIFIER)
    return wrapped


def regressor(estimator):
    """
    Shortcut for ``wrap`` that marks the third-party estimator as a regressor.

    Parameters
    ----------
    estimator : object
        The non-sklearn regressor to wrap and use for Visualizers

    Returns
    -------
    wrapped : object
        Whatever ``wrap`` returns for ``estimator`` with the regressor type
    """
    wrapped = wrap(estimator, estimator_type=REGRESSOR)
    return wrapped


def clusterer(estimator):
    """
    Shortcut for ``wrap`` that marks the third-party estimator as a clusterer.

    Parameters
    ----------
    estimator : object
        The non-sklearn clusterer to wrap and use for Visualizers

    Returns
    -------
    wrapped : object
        Whatever ``wrap`` returns for ``estimator`` with the clusterer type
    """
    wrapped = wrap(estimator, estimator_type=CLUSTERER)
    return wrapped


##########################################################################
## ContribEstimator - Third Party Estimator Wrapper
##########################################################################

class ContribEstimator(object):
    """
    Wraps a third party estimator that implements the scikit-learn API and therefore
    could be used with Yellowbrick but doesn't subclass ``BaseEstimator``. Since there
    are a number of pitfalls, this object provides sensible errors and warnings rather
    than completely blowing up, allowing contrib users to identify issues and fix them,
    smoothing the path to getting third party estimators into the Yellowbrick ecosystem.

    Any attribute that is not found on the wrapper itself is looked up on the wrapped
    estimator.

    Parameters
    ----------
    estimator : object
        The non-sklearn estimator to wrap and use for Visualizers

    estimator_type : str, optional
        One of "classifier", "regressor", "clusterer", "DensityEstimator", or
        "outlier_detector" that allows the contrib estimator to pass the scikit-learn
        ``is_classifier``, etc. functions. If not specified, the _estimator_type attr
        is passed through to the underlying estimator.

    Raises
    ------
    YellowbrickAttributeError
        On access to an attribute that neither the wrapper nor the wrapped estimator
        provides.
    """

    def __init__(self, estimator, estimator_type=None):
        self.estimator = estimator

        # Do not set estimator type if not specified to allow passthrough
        if estimator_type:
            self._estimator_type = estimator_type

    def __getattr__(self, attr):
        # __getattr__ only runs after normal lookup has failed. When the instance
        # dict has no "estimator" yet (copy and pickle create the object without
        # calling __init__ and then probe it for __setstate__), reading
        # self.estimator below would re-enter this method until a RecursionError;
        # report an ordinary missing attribute instead.
        if "estimator" not in self.__dict__:
            raise AttributeError(
                "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
            )

        # proxy to the wrapped object
        try:
            return getattr(self.estimator, attr)
        except AttributeError as err:
            raise YellowbrickAttributeError((
                "estimator is missing the '{}' attribute, which is required for this "
                "visualizer - please see the third party estimators documentation."
            ).format(attr)) from err
8 changes: 8 additions & 0 deletions yellowbrick/exceptions.py
Expand Up @@ -89,6 +89,14 @@ class YellowbrickKeyError(YellowbrickError, KeyError):
pass


class YellowbrickAttributeError(YellowbrickError, AttributeError):
    """
    Raised when an attribute that is required of the estimator is missing.
    """


##########################################################################
## Assertions
##########################################################################
Expand Down
12 changes: 7 additions & 5 deletions yellowbrick/utils/helpers.py
Expand Up @@ -28,12 +28,13 @@

from yellowbrick.utils.types import is_estimator
from yellowbrick.exceptions import YellowbrickTypeError
from yellowbrick.contrib.wrapper import ContribEstimator


##########################################################################
## Model and Feature Information
##########################################################################


def is_fitted(estimator):
"""
In order to ensure that we don't call ``fit`` on an already-fitted model,
Expand Down Expand Up @@ -137,11 +138,12 @@ def get_model_name(model):
"Cannot detect the model name for non estimator: '{}'".format(type(model))
)

if isinstance(model, Pipeline):
return get_model_name(model.steps[-1][-1])
elif isinstance(model, ContribEstimator):
return model.estimator.__class__.__name__
else:
if isinstance(model, Pipeline):
return get_model_name(model.steps[-1][-1])
else:
return model.__class__.__name__
return model.__class__.__name__


def has_ndarray_int_columns(features, X):
Expand Down
5 changes: 3 additions & 2 deletions yellowbrick/utils/types.py
Expand Up @@ -21,6 +21,7 @@
import numpy as np

from sklearn.base import BaseEstimator
from yellowbrick.contrib.wrapper import ContribEstimator


##########################################################################
Expand All @@ -39,9 +40,9 @@ def is_estimator(model):
Scikit-Learn estimator or Yellowbrick visualizer
"""
if inspect.isclass(model):
return issubclass(model, BaseEstimator)
return issubclass(model, (BaseEstimator, ContribEstimator))

return isinstance(model, BaseEstimator)
return isinstance(model, (BaseEstimator, ContribEstimator))


# Alias for closer name to isinstance and issubclass
Expand Down

0 comments on commit f4576eb

Please sign in to comment.