Skip to content

Commit

Permalink
Add support for sample_weight in KElbowVisualizer (#1059)
Browse files Browse the repository at this point in the history
This PR closes #1057 by correctly passing scikit-learn estimator kwargs when calling estimator.fit in KElbowVisualizer, as well as introducing a new utility for determining parameters for methods that is used to achieve the same effect in the kelbow quick method.
  • Loading branch information
Express50 committed Apr 21, 2020
1 parent 779487c commit 674cad2
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 13 deletions.
53 changes: 43 additions & 10 deletions tests/test_cluster/test_elbow.py
Expand Up @@ -306,7 +306,10 @@ def test_calinski_harabasz_metric(self):
self.assert_images_similar(visualizer)
assert_array_almost_equal(visualizer.k_scores_, expected)

@pytest.mark.xfail(IS_WINDOWS_OR_CONDA, reason="computation of k_scores_ varies by 2.867 max absolute difference")
@pytest.mark.xfail(
IS_WINDOWS_OR_CONDA,
reason="computation of k_scores_ varies by 2.867 max absolute difference",
)
def test_locate_elbow(self):
"""
Test the addition of locate_elbow to an image
Expand All @@ -325,15 +328,7 @@ def test_locate_elbow(self):
visualizer.fit(X)
assert len(visualizer.k_scores_) == 5
assert visualizer.elbow_value_ == 3
expected = np.array(
[
4286.5,
12463.4,
8763.8,
6939.3,
5858.8,
]
)
expected = np.array([4286.5, 12463.4, 8763.8, 6939.3, 5858.8])

visualizer.finalize()
self.assert_images_similar(visualizer, tol=0.5, windows_tol=2.2)
Expand Down Expand Up @@ -400,6 +395,32 @@ def test_timings(self):

self.assert_images_similar(visualizer)

def test_sample_weights(self):
"""
Test that passing in sample weights correctly influences the clusterer's fit
"""
seed = 1234

# original data has 5 clusters
X, y = make_blobs(
n_samples=[5, 30, 30, 30, 30],
n_features=5,
random_state=seed,
shuffle=False,
)

visualizer = KElbowVisualizer(
KMeans(random_state=seed), k=(2, 12), timings=False
)
visualizer.fit(X)
assert visualizer.elbow_value_ == 5

# weights should push elbow down to 4
weights = np.concatenate([np.ones(5) * 0.0001, np.ones(120)])

visualizer.fit(X, sample_weight=weights)
assert visualizer.elbow_value_ == 4

@pytest.mark.xfail(reason="images not close due to timing lines")
def test_quick_method(self):
"""
Expand All @@ -414,3 +435,15 @@ def test_quick_method(self):
assert isinstance(oz, KElbowVisualizer)

self.assert_images_similar(oz)

def test_quick_method_params(self):
"""
Test the quick method correctly consumes the user-provided parameters
"""
X, y = make_blobs(centers=3)
custom_title = "My custom title"
model = KMeans(3, random_state=13)
oz = kelbow_visualizer(
model, X, sample_weight=np.ones(X.shape[0]), title=custom_title
)
assert oz.title == custom_title
36 changes: 36 additions & 0 deletions tests/test_utils/test_helpers.py
Expand Up @@ -173,6 +173,42 @@ def test_check_fitted(self):
assert check_fitted(model, is_fitted_by=True) is True
assert check_fitted(model, is_fitted_by=False) is False

@pytest.mark.parametrize(
"estimator",
[
SVC,
SVR,
Ridge,
KMeans,
RidgeCV,
GaussianNB,
MiniBatchKMeans,
LinearRegression,
],
ids=[
"SVC",
"SVR",
"Ridge",
"KMeans",
"RidgeCV",
"GaussianNB",
"MiniBatchKMeans",
"LinearRegression",
],
)
def test_get_param_names(self, estimator):
"""
Assert we successfully extract the parameters from sklearn estimators
"""
assert "sample_weight" in get_param_names(estimator.fit)

def test_get_param_names_type(self):
"""
Assert a type error is raised when passing a non-method
"""
with pytest.raises(TypeError):
get_param_names("test")


##########################################################################
## Numeric Function Tests
Expand Down
14 changes: 11 additions & 3 deletions yellowbrick/cluster/elbow.py
Expand Up @@ -28,7 +28,7 @@
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics.pairwise import pairwise_distances

from yellowbrick.utils import KneeLocator
from yellowbrick.utils import KneeLocator, get_param_names
from yellowbrick.style.palettes import LINE_COLOR
from yellowbrick.cluster.base import ClusteringScoreVisualizer
from yellowbrick.exceptions import YellowbrickValueError, YellowbrickWarning
Expand Down Expand Up @@ -309,7 +309,7 @@ def fit(self, X, y=None, **kwargs):

# Set the k value and fit the model
self.estimator.set_params(n_clusters=k)
self.estimator.fit(X)
self.estimator.fit(X, **kwargs)

# Append the time and score to our plottable metrics
self.k_timers_.append(time.time() - start)
Expand Down Expand Up @@ -415,6 +415,7 @@ def finalize(self):
## Quick Method
##########################################################################


def kelbow_visualizer(
model,
X,
Expand Down Expand Up @@ -487,6 +488,13 @@ def kelbow_visualizer(
viz : KElbowVisualizer
The kelbow visualizer, fitted and finalized.
"""
klass = type(model)

# figure out which kwargs correspond to fit method
fit_params = get_param_names(klass.fit)

fit_kwargs = {key: kwargs.pop(key) for key in fit_params if key in kwargs}

oz = KElbow(
model,
ax=ax,
Expand All @@ -496,7 +504,7 @@ def kelbow_visualizer(
locate_elbow=locate_elbow,
**kwargs
)
oz.fit(X, y)
oz.fit(X, y, **fit_kwargs)

if show:
oz.show()
Expand Down
30 changes: 30 additions & 0 deletions yellowbrick/utils/helpers.py
Expand Up @@ -19,6 +19,7 @@
##########################################################################

import re
import inspect
import sklearn
import numpy as np

Expand Down Expand Up @@ -185,6 +186,35 @@ def is_monotonic(a, increasing=True):
return np.all(a[1:] <= a[:-1], axis=0)


def get_param_names(method):
"""
Returns a list of keyword-only parameter names that may be
passed into method.
Parameters
----------
method : function
The method for which to return keyword-only parameters.
Returns
-------
parameters : list
A list of keyword-only parameter names for method.
"""
try:
signature = inspect.signature(method)
except (ValueError, TypeError) as e:
raise e

parameters = [
p
for p in signature.parameters.values()
if p.name != "self" and p.kind != p.VAR_KEYWORD
]

return sorted([p.name for p in parameters])


##########################################################################
## Numeric Computations
##########################################################################
Expand Down

0 comments on commit 674cad2

Please sign in to comment.