Skip to content

Commit

Permalink
Ensure inertia and silhouette score array initialization accounts for…
Browse files Browse the repository at this point in the history
… varying min_k (#820)
  • Loading branch information
alex-l-kong committed Nov 10, 2022
1 parent edeac2e commit f5d2c3d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
4 changes: 2 additions & 2 deletions ark/utils/spatial_analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,7 +502,7 @@ def compute_kmeans_inertia(neighbor_mat_data, min_k=2, max_k=10):
# create array we can store the results of each k for clustering
coords = [np.arange(min_k, max_k + 1)]
dims = ["cluster_num"]
stats_raw_data = np.zeros(max_k - 1)
stats_raw_data = np.zeros(max_k - min_k + 1)
cluster_stats = xr.DataArray(stats_raw_data, coords=coords, dims=dims)

# iterate over each k value
Expand Down Expand Up @@ -539,7 +539,7 @@ def compute_kmeans_silhouette(neighbor_mat_data, min_k=2, max_k=10, subsample=No
# create array we can store the results of each k for clustering
coords = [np.arange(min_k, max_k + 1)]
dims = ["cluster_num"]
stats_raw_data = np.zeros(max_k - 1)
stats_raw_data = np.zeros(max_k - min_k + 1)
cluster_stats = xr.DataArray(stats_raw_data, coords=coords, dims=dims)

# iterate over each k value
Expand Down
22 changes: 11 additions & 11 deletions ark/utils/spatial_analysis_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,29 +333,29 @@ def test_compute_neighbor_counts():
def test_compute_kmeans_inertia():
neighbor_mat = test_utils._make_neighborhood_matrix()[['feature1', 'feature2']]

neighbor_cluster_stats = spatial_analysis_utils.compute_kmeans_inertia(
neighbor_mat, max_k=3)
neighbor_inertia_stats = spatial_analysis_utils.compute_kmeans_inertia(
neighbor_mat, min_k=3, max_k=6)

# assert we have the right cluster_num values
assert list(neighbor_cluster_stats.coords["cluster_num"].values) == [2, 3]
assert list(neighbor_inertia_stats.coords["cluster_num"].values) == list(range(3, 7))

# assert k=3 produces the best inertia
three_cluster_score = neighbor_cluster_stats.loc[3].values
assert np.all(three_cluster_score <= neighbor_cluster_stats.values)
# assert k=6 produces the best inertia
best_inertia_score = neighbor_inertia_stats.loc[6].values
assert np.all(best_inertia_score <= neighbor_inertia_stats.values)


def test_compute_kmeans_silhouette():
neighbor_mat = test_utils._make_neighborhood_matrix()[['feature1', 'feature2']]

neighbor_cluster_stats = spatial_analysis_utils.compute_kmeans_silhouette(
neighbor_mat, max_k=3)
neighbor_silhouette_stats = spatial_analysis_utils.compute_kmeans_silhouette(
neighbor_mat, min_k=3, max_k=6)

# assert we have the right cluster_num values
assert list(neighbor_cluster_stats.coords["cluster_num"].values) == [2, 3]
assert list(neighbor_silhouette_stats.coords["cluster_num"].values) == list(range(3, 7))

# assert k=3 produces the best silhouette score
three_cluster_score = neighbor_cluster_stats.loc[3].values
assert np.all(three_cluster_score >= neighbor_cluster_stats.values)
best_silhouette_score = neighbor_silhouette_stats.loc[3].values
assert np.all(best_silhouette_score >= neighbor_silhouette_stats.values)


def test_generate_cluster_labels():
Expand Down

0 comments on commit f5d2c3d

Please sign in to comment.