Skip to content

Commit

Permalink
Add tests for KElbowVisualizer sample weights
Browse files Browse the repository at this point in the history
  • Loading branch information
Express50 committed Apr 19, 2020
1 parent b994bb0 commit 354b025
Showing 1 changed file with 38 additions and 0 deletions.
38 changes: 38 additions & 0 deletions tests/test_cluster/test_elbow.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,34 @@ def test_timings(self):

self.assert_images_similar(visualizer)

def test_sample_weights(self):
"""
Test that passing in sample weights correctly influences the clusterer's fit
"""
seed = 1234

# original data has 5 clusters
X, y = make_blobs(n_samples=[5, 30, 30, 30, 30], n_features=5, random_state=seed, shuffle=False)

visualizer = KElbowVisualizer(
KMeans(random_state=seed), k=(2, 12), timings=False
)
visualizer.fit(X)
visualizer.finalize()
assert visualizer.elbow_value_ == 5

# weights should push elbow down to 4
weights = np.concatenate(
[
np.ones(5) * 0.0001,
np.ones(120),
]
)

visualizer.fit(X, sample_weight=weights)
visualizer.finalize()
assert visualizer.elbow_value_ == 4

@pytest.mark.xfail(reason="images not close due to timing lines")
def test_quick_method(self):
"""
Expand All @@ -414,3 +442,13 @@ def test_quick_method(self):
assert isinstance(oz, KElbowVisualizer)

self.assert_images_similar(oz)

def test_quick_method_params(self):
"""
Test the quick method correctly consumes the user-provided parameters
"""
X, y = make_blobs(centers=3)
custom_title = "My custom title"
model = KMeans(3, random_state=13)
oz = kelbow_visualizer(model, X, sample_weight=np.ones(X.shape[0]), title=custom_title)
assert oz.title == custom_title

0 comments on commit 354b025

Please sign in to comment.