diff --git a/tests/test_cluster/test_elbow.py b/tests/test_cluster/test_elbow.py index 894169612..e5a0e6672 100644 --- a/tests/test_cluster/test_elbow.py +++ b/tests/test_cluster/test_elbow.py @@ -306,7 +306,10 @@ def test_calinski_harabasz_metric(self): self.assert_images_similar(visualizer) assert_array_almost_equal(visualizer.k_scores_, expected) - @pytest.mark.xfail(IS_WINDOWS_OR_CONDA, reason="computation of k_scores_ varies by 2.867 max absolute difference") + @pytest.mark.xfail( + IS_WINDOWS_OR_CONDA, + reason="computation of k_scores_ varies by 2.867 max absolute difference", + ) def test_locate_elbow(self): """ Test the addition of locate_elbow to an image @@ -325,15 +328,7 @@ def test_locate_elbow(self): visualizer.fit(X) assert len(visualizer.k_scores_) == 5 assert visualizer.elbow_value_ == 3 - expected = np.array( - [ - 4286.5, - 12463.4, - 8763.8, - 6939.3, - 5858.8, - ] - ) + expected = np.array([4286.5, 12463.4, 8763.8, 6939.3, 5858.8]) visualizer.finalize() self.assert_images_similar(visualizer, tol=0.5, windows_tol=2.2) @@ -400,6 +395,32 @@ def test_timings(self): self.assert_images_similar(visualizer) + def test_sample_weights(self): + """ + Test that passing in sample weights correctly influences the clusterer's fit + """ + seed = 1234 + + # original data has 5 clusters + X, y = make_blobs( + n_samples=[5, 30, 30, 30, 30], + n_features=5, + random_state=seed, + shuffle=False, + ) + + visualizer = KElbowVisualizer( + KMeans(random_state=seed), k=(2, 12), timings=False + ) + visualizer.fit(X) + assert visualizer.elbow_value_ == 5 + + # weights should push elbow down to 4 + weights = np.concatenate([np.ones(5) * 0.0001, np.ones(120)]) + + visualizer.fit(X, sample_weight=weights) + assert visualizer.elbow_value_ == 4 + @pytest.mark.xfail(reason="images not close due to timing lines") def test_quick_method(self): """ @@ -414,3 +435,15 @@ def test_quick_method(self): assert isinstance(oz, KElbowVisualizer) self.assert_images_similar(oz) + + def 
test_quick_method_params(self): + """ + Test the quick method correctly consumes the user-provided parameters + """ + X, y = make_blobs(centers=3) + custom_title = "My custom title" + model = KMeans(3, random_state=13) + oz = kelbow_visualizer( + model, X, sample_weight=np.ones(X.shape[0]), title=custom_title + ) + assert oz.title == custom_title diff --git a/tests/test_utils/test_helpers.py b/tests/test_utils/test_helpers.py index 91d761ebe..2c92fc802 100644 --- a/tests/test_utils/test_helpers.py +++ b/tests/test_utils/test_helpers.py @@ -173,6 +173,42 @@ def test_check_fitted(self): assert check_fitted(model, is_fitted_by=True) is True assert check_fitted(model, is_fitted_by=False) is False + @pytest.mark.parametrize( + "estimator", + [ + SVC, + SVR, + Ridge, + KMeans, + RidgeCV, + GaussianNB, + MiniBatchKMeans, + LinearRegression, + ], + ids=[ + "SVC", + "SVR", + "Ridge", + "KMeans", + "RidgeCV", + "GaussianNB", + "MiniBatchKMeans", + "LinearRegression", + ], + ) + def test_get_param_names(self, estimator): + """ + Assert we successfully extract the parameters from sklearn estimators + """ + assert "sample_weight" in get_param_names(estimator.fit) + + def test_get_param_names_type(self): + """ + Assert a type error is raised when passing a non-method + """ + with pytest.raises(TypeError): + get_param_names("test") + ########################################################################## ## Numeric Function Tests diff --git a/yellowbrick/cluster/elbow.py b/yellowbrick/cluster/elbow.py index 9c720e49d..fa90f0190 100644 --- a/yellowbrick/cluster/elbow.py +++ b/yellowbrick/cluster/elbow.py @@ -28,7 +28,7 @@ from sklearn.preprocessing import LabelEncoder from sklearn.metrics.pairwise import pairwise_distances -from yellowbrick.utils import KneeLocator +from yellowbrick.utils import KneeLocator, get_param_names from yellowbrick.style.palettes import LINE_COLOR from yellowbrick.cluster.base import ClusteringScoreVisualizer from yellowbrick.exceptions import 
YellowbrickValueError, YellowbrickWarning @@ -309,7 +309,7 @@ def fit(self, X, y=None, **kwargs): # Set the k value and fit the model self.estimator.set_params(n_clusters=k) - self.estimator.fit(X) + self.estimator.fit(X, **kwargs) # Append the time and score to our plottable metrics self.k_timers_.append(time.time() - start) @@ -415,6 +415,7 @@ def finalize(self): ## Quick Method ########################################################################## + def kelbow_visualizer( model, X, @@ -487,6 +488,13 @@ def kelbow_visualizer( viz : KElbowVisualizer The kelbow visualizer, fitted and finalized. """ + klass = type(model) + + # figure out which kwargs correspond to fit method + fit_params = get_param_names(klass.fit) + + fit_kwargs = {key: kwargs.pop(key) for key in fit_params if key in kwargs} + oz = KElbow( model, ax=ax, @@ -496,7 +504,7 @@ def kelbow_visualizer( locate_elbow=locate_elbow, **kwargs ) - oz.fit(X, y) + oz.fit(X, y, **fit_kwargs) if show: oz.show() diff --git a/yellowbrick/utils/helpers.py b/yellowbrick/utils/helpers.py index 01294ab47..edaa0bb6b 100644 --- a/yellowbrick/utils/helpers.py +++ b/yellowbrick/utils/helpers.py @@ -19,6 +19,7 @@ ########################################################################## import re +import inspect import sklearn import numpy as np @@ -185,6 +186,35 @@ def is_monotonic(a, increasing=True): return np.all(a[1:] <= a[:-1], axis=0) +def get_param_names(method): + """ + Returns a sorted list of the parameter names that may be + passed into method, excluding self and any **kwargs. + + Parameters + ---------- + method : function + The method for which to return the parameter names. + + Returns + ------- + parameters : list + A sorted list of parameter names for method.
+ """ + try: + signature = inspect.signature(method) + except (ValueError, TypeError) as e: + raise e + + parameters = [ + p + for p in signature.parameters.values() + if p.name != "self" and p.kind != p.VAR_KEYWORD + ] + + return sorted([p.name for p in parameters]) + + ########################################################################## ## Numeric Computations ##########################################################################