Skip to content

Commit

Permalink
Use random_state parameter in KMeans clustering for reproducibility (
Browse files Browse the repository at this point in the history
…#834)

* Add random seeds to k-means clustering steps for neighborhood analysis

* Make sure subsampling for silhouette also gets a random seed

* Fix random_state parameter position in sample function for subsampling
  • Loading branch information
alex-l-kong committed Nov 17, 2022
1 parent 29fe021 commit e1af451
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 17 deletions.
29 changes: 19 additions & 10 deletions ark/analysis/spatial_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,8 +536,9 @@ def create_neighborhood_matrix(all_data, dist_mat_dir, included_fovs=None, distl
return cell_neighbor_counts, cell_neighbor_freqs


def generate_cluster_matrix_results(all_data, neighbor_mat, cluster_num, excluded_channels=None,
included_fovs=None, cluster_label_col=settings.KMEANS_CLUSTER,
def generate_cluster_matrix_results(all_data, neighbor_mat, cluster_num, seed=42,
excluded_channels=None, included_fovs=None,
cluster_label_col=settings.KMEANS_CLUSTER,
fov_col=settings.FOV_ID, cell_type_col=settings.CELL_TYPE,
label_col=settings.CELL_LABEL,
pre_channel_col=settings.PRE_CHANNEL_COL,
Expand All @@ -554,6 +555,8 @@ def generate_cluster_matrix_results(all_data, neighbor_mat, cluster_num, exclude
cluster_num (int):
the optimal k to pass into k-means clustering to generate the final clusters
and corresponding results
seed (int):
the random seed to set for k-means clustering
excluded_channels (list):
all channel names to be excluded from analysis
included_fovs (list):
Expand Down Expand Up @@ -607,7 +610,7 @@ def generate_cluster_matrix_results(all_data, neighbor_mat, cluster_num, exclude

# generate cluster labels
cluster_labels = spatial_analysis_utils.generate_cluster_labels(
neighbor_mat_data, cluster_num)
neighbor_mat_data, cluster_num, seed=seed)

# add labels to neighbor mat
neighbor_mat_data_all[cluster_label_col] = cluster_labels
Expand Down Expand Up @@ -651,8 +654,9 @@ def generate_cluster_matrix_results(all_data, neighbor_mat, cluster_num, exclude
return all_data_clusters, num_cell_type_per_cluster, mean_marker_exp_per_cluster


def compute_cluster_metrics_inertia(neighbor_mat, min_k=2, max_k=10, included_fovs=None,
fov_col=settings.FOV_ID, label_col=settings.CELL_LABEL):
def compute_cluster_metrics_inertia(neighbor_mat, min_k=2, max_k=10, seed=42,
included_fovs=None, fov_col=settings.FOV_ID,
label_col=settings.CELL_LABEL):
"""Produce k-means clustering metrics to help identify optimal number of clusters using
inertia
Expand All @@ -663,6 +667,8 @@ def compute_cluster_metrics_inertia(neighbor_mat, min_k=2, max_k=10, included_fo
the minimum k we want to generate cluster statistics for, must be at least 2
max_k (int):
the maximum k we want to generate cluster statistics for, must be at least 2
seed (int):
the random seed to set for k-means clustering
included_fovs (list):
fovs to include in analysis. If argument is none, default is all fovs used.
fov_col (str):
Expand Down Expand Up @@ -695,14 +701,14 @@ def compute_cluster_metrics_inertia(neighbor_mat, min_k=2, max_k=10, included_fo

# generate the cluster score information
neighbor_cluster_stats = spatial_analysis_utils.compute_kmeans_inertia(
neighbor_mat_data=neighbor_mat_data, min_k=min_k, max_k=max_k)
neighbor_mat_data=neighbor_mat_data, min_k=min_k, max_k=max_k, seed=seed)

return neighbor_cluster_stats


def compute_cluster_metrics_silhouette(neighbor_mat, min_k=2, max_k=10, included_fovs=None,
fov_col=settings.FOV_ID, label_col=settings.CELL_LABEL,
subsample=None):
def compute_cluster_metrics_silhouette(neighbor_mat, min_k=2, max_k=10, seed=42,
included_fovs=None, fov_col=settings.FOV_ID,
label_col=settings.CELL_LABEL, subsample=None):
"""Produce k-means clustering metrics to help identify optimal number of clusters using
Silhouette score
Expand All @@ -713,6 +719,8 @@ def compute_cluster_metrics_silhouette(neighbor_mat, min_k=2, max_k=10, included
the minimum k we want to generate cluster statistics for, must be at least 2
max_k (int):
the maximum k we want to generate cluster statistics for, must be at least 2
seed (int):
the random seed to set for k-means clustering
included_fovs (list):
fovs to include in analysis. If argument is none, default is all fovs used.
fov_col (str):
Expand Down Expand Up @@ -749,7 +757,8 @@ def compute_cluster_metrics_silhouette(neighbor_mat, min_k=2, max_k=10, included

# generate the cluster score information
neighbor_cluster_stats = spatial_analysis_utils.compute_kmeans_silhouette(
neighbor_mat_data=neighbor_mat_data, min_k=min_k, max_k=max_k, subsample=subsample
neighbor_mat_data=neighbor_mat_data, min_k=min_k, max_k=max_k,
seed=seed, subsample=subsample
)

return neighbor_cluster_stats
22 changes: 15 additions & 7 deletions ark/utils/spatial_analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def compute_neighbor_counts(current_fov_neighborhood_data, dist_matrix, distlim,
return counts_pd, freqs_pd


def compute_kmeans_inertia(neighbor_mat_data, min_k=2, max_k=10):
def compute_kmeans_inertia(neighbor_mat_data, min_k=2, max_k=10, seed=42):
"""For a given neighborhood matrix, cluster and compute inertia using k-means clustering
from the range of k=min_k to max_k
Expand All @@ -492,6 +492,8 @@ def compute_kmeans_inertia(neighbor_mat_data, min_k=2, max_k=10):
the minimum k we want to generate cluster statistics for, must be at least 2
max_k (int):
the maximum k we want to generate cluster statistics for, must be at least 2
seed (int):
the random seed to set for k-means clustering
Returns:
xarray.DataArray:
Expand All @@ -508,13 +510,13 @@ def compute_kmeans_inertia(neighbor_mat_data, min_k=2, max_k=10):
# iterate over each k value
pb_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'
for n in tqdm(range(min_k, max_k + 1), bar_format=pb_format):
cluster_fit = KMeans(n_clusters=n).fit(neighbor_mat_data)
cluster_fit = KMeans(n_clusters=n, random_state=seed).fit(neighbor_mat_data)
cluster_stats.loc[n] = cluster_fit.inertia_

return cluster_stats


def compute_kmeans_silhouette(neighbor_mat_data, min_k=2, max_k=10, subsample=None):
def compute_kmeans_silhouette(neighbor_mat_data, min_k=2, max_k=10, seed=42, subsample=None):
"""For a given neighborhood matrix, cluster and compute Silhouette score using k-means
from the range of k=min_k to max_k
Expand All @@ -525,6 +527,8 @@ def compute_kmeans_silhouette(neighbor_mat_data, min_k=2, max_k=10, subsample=No
the minimum k we want to generate cluster statistics for, must be at least 2
max_k (int):
the maximum k we want to generate cluster statistics for, must be at least 2
seed (int):
the random seed to set for k-means clustering
subsample (int):
the number of cells that will be sampled from each neighborhood cluster for
calculating Silhouette score
Expand All @@ -545,7 +549,7 @@ def compute_kmeans_silhouette(neighbor_mat_data, min_k=2, max_k=10, subsample=No
# iterate over each k value
pb_format = '{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'
for n in tqdm(range(min_k, max_k + 1), bar_format=pb_format):
cluster_fit = KMeans(n_clusters=n).fit(neighbor_mat_data)
cluster_fit = KMeans(n_clusters=n, random_state=seed).fit(neighbor_mat_data)
cluster_labels = cluster_fit.labels_

sub_dat = neighbor_mat_data.copy()
Expand All @@ -554,7 +558,9 @@ def compute_kmeans_silhouette(neighbor_mat_data, min_k=2, max_k=10, subsample=No
if subsample is not None:
# Subsample each cluster
sub_dat = sub_dat.groupby("cluster").apply(
lambda x: x.sample(subsample, replace=len(x) < subsample)).reset_index(drop=True)
lambda x: x.sample(
subsample, replace=len(x) < subsample, random_state=seed)
).reset_index(drop=True)

cluster_score = sklearn.metrics.silhouette_score(sub_dat.drop("cluster", axis=1),
sub_dat["cluster"],
Expand All @@ -564,7 +570,7 @@ def compute_kmeans_silhouette(neighbor_mat_data, min_k=2, max_k=10, subsample=No
return cluster_stats


def generate_cluster_labels(neighbor_mat_data, cluster_num):
def generate_cluster_labels(neighbor_mat_data, cluster_num, seed=42):
"""Run k-means clustering with k=cluster_num
Give the same data, given several runs the clusters will always be the same,
Expand All @@ -575,13 +581,15 @@ def generate_cluster_labels(neighbor_mat_data, cluster_num):
neighborhood matrix data with only the desired fovs
cluster_num (int):
the k we want to use when running k-means clustering
seed (int):
the random seed to set for k-means clustering
Returns:
numpy.ndarray:
the neighborhood cluster labels assigned to each cell in neighbor_mat_data
"""

cluster_fit = KMeans(n_clusters=cluster_num).fit(neighbor_mat_data)
cluster_fit = KMeans(n_clusters=cluster_num, random_state=seed).fit(neighbor_mat_data)
# Add 1 to avoid cluster number 0
cluster_labels = cluster_fit.labels_ + 1

Expand Down

0 comments on commit e1af451

Please sign in to comment.