In [None]:
import os
from datetime import datetime

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import silhouette_score, mean_squared_error, adjusted_rand_score
from sklearn.mixture import GaussianMixture
from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import StandardScaler

In [None]:
# ---- Run configuration: edit these values to select which input file is clustered ----

# Feature selection / dimensionality reduction
correlation_threshold = 0.8   # threshold value encoded in the input/output file names
dim_red_method = "umap"       # one of {"umap", "pca", "tsne"}
perplexity = 50               # only relevant when dim_red_method == "tsne"
standardization = True        # selects the with_/without_standardization input file

# Cohort
group = "all"                 # one of {"male", "female", "all"}
survey_category = "stress"    # one of {"stress", "depression", "needs"}

# Reproducibility
seed = 1                      # random_state for the split, the CV folds and every GMM

In [None]:
# Load the feature file for the configured run. Columns 0-1 are identifiers
# (USER_ID, WEEK_START); the remaining columns are the clustering features.

# Fail fast with a clear message instead of a FileNotFoundError on a malformed path.
for _param_name, _param_value, _allowed in (
    ("dim_red_method", dim_red_method, {"umap", "pca", "tsne"}),
    ("group", group, {"male", "female", "all"}),
    ("survey_category", survey_category, {"stress", "depression", "needs"}),
):
    if _param_value not in _allowed:
        raise ValueError(f"{_param_name}={_param_value!r}; expected one of {sorted(_allowed)}")

dim_red_method_upper = dim_red_method.upper()

# t-SNE files additionally encode the perplexity in their name.
perplexity_part = f"_perplexity_{perplexity}" if dim_red_method == "tsne" else ""
standardization_part = "with_standardization" if standardization else "without_standardization"
features_path = (f"working_data/{dim_red_method}_features_correlation_threshold_{correlation_threshold}"
                 f"{perplexity_part}_{standardization_part}_{group}_{survey_category}.csv")

df = pd.read_csv(features_path)
X = df.iloc[:, 2:].to_numpy()
ids = df.iloc[:, :2] # USER_ID and WEEK_START as identifiers

# Later cells plot and save the first two feature columns, so at least two are required.
if X.shape[1] < 2:
    raise ValueError(f"{features_path} has {X.shape[1]} feature column(s); at least 2 are required")

In [None]:
# Model selection: hold out 20% of rows as a test set, then score every candidate
# number of GMM components with 5-fold CV on the remaining train+val rows.

X_trainval, X_test = train_test_split(X, test_size=0.20, random_state=seed, stratify=None)

scaler = StandardScaler().fit(X_trainval)  # fit only on train+val so the test split stays unseen
X_trainval_sc = scaler.transform(X_trainval)
X_test_sc = scaler.transform(X_test)  # for final eval / plotting

components_range = range(2, 11)
aic_list, bic_list, mean_silhouettes, mean_aris = [], [], [], []

kf = KFold(n_splits=5, shuffle=True, random_state=seed)


def _safe_silhouette(X_part, labels):
    """Return silhouette_score(X_part, labels), or NaN when the score is undefined.

    silhouette_score raises ValueError unless 2 <= n_labels <= n_samples - 1; a GMM
    can assign every point of a validation fold to a single component.
    """
    n_labels = len(np.unique(labels))
    if n_labels < 2 or n_labels > len(labels) - 1:
        return np.nan
    return silhouette_score(X_part, labels)


for n in components_range:
    # Reference model on all train+val rows: provides AIC/BIC and the ARI baseline.
    gmm_ref = GaussianMixture(n_components=n, random_state=seed).fit(X_trainval_sc)

    aic_list.append(gmm_ref.aic(X_trainval_sc))
    bic_list.append(gmm_ref.bic(X_trainval_sc))

    fold_sil, fold_ari = [], []

    for tr_idx, val_idx in kf.split(X_trainval_sc):
        X_tr, X_val = X_trainval_sc[tr_idx], X_trainval_sc[val_idx]

        gmm_cv = GaussianMixture(n_components=n, random_state=seed).fit(X_tr)
        lbl_val_cv = gmm_cv.predict(X_val)
        lbl_val_ref = gmm_ref.predict(X_val)

        # Silhouette of the fold model's assignments, computed on the standardized columns of X.
        fold_sil.append(_safe_silhouette(X_val, lbl_val_cv))

        # Stability: agreement between the fold model and the reference model on the
        # validation rows. Note that gmm_ref was also fit on these rows.
        fold_ari.append(adjusted_rand_score(lbl_val_ref, lbl_val_cv))

    # Average only folds where the silhouette is defined; NaN if none are.
    valid_sil = [s for s in fold_sil if not np.isnan(s)]
    mean_silhouettes.append(np.mean(valid_sil) if valid_sil else np.nan)
    mean_aris.append(np.mean(fold_ari))

    print(f"n={n:>2} | Silhouette={mean_silhouettes[-1]:.3f} | "
          f"AIC={aic_list[-1]:.0f} | BIC={bic_list[-1]:.0f} | "
          f"ARI={mean_aris[-1]:.3f}")

# nanargmax skips candidates with an undefined silhouette (np.argmax would select a NaN).
best_n = components_range[int(np.nanargmax(mean_silhouettes))]
print(f"Chosen n_components (by CV silhouette): {best_n}")

In [None]:
# Model-selection diagnostics: AIC/BIC, CV silhouette and CV ARI against the number
# of components. Each figure marks the chosen best_n, is saved under
# clustering_plots/ and is shown inline.

show_titles = False  # titles were disabled in the original figures; set True to draw them


def _selection_plot_path(metric_prefix):
    """Return the PNG path for a model-selection plot.

    t-SNE runs additionally encode the perplexity in the file name.
    """
    perplexity_part = f"_perplexity_{perplexity}" if dim_red_method == "tsne" else ""
    return (f"clustering_plots/{metric_prefix}_GMM_on_{dim_red_method_upper}"
            f"_correlation_{correlation_threshold}{perplexity_part}_{group}_{survey_category}.png")


def _plot_selection_curve(series, ylabel, metric_prefix, title=None):
    """Plot metric curves over components_range, mark best_n, then save and show.

    Parameters
    ----------
    series : list of (values, dict)
        One entry per curve: y-values aligned with components_range and the keyword
        arguments forwarded to ``Axes.plot`` (label, marker, color, ...).
    ylabel : str
        Y-axis label.
    metric_prefix : str
        File-name prefix, e.g. ``"aic_bic"``.
    title : str or None
        Optional figure title; no title is drawn when None.
    """
    fig, ax = plt.subplots(figsize=(10, 4))
    for values, plot_kwargs in series:
        ax.plot(components_range, values, **plot_kwargs)
    ax.axvline(best_n, ls='--', c='gray')
    ax.set_xlabel('Number of Clusters')
    ax.set_ylabel(ylabel)
    if title:
        ax.set_title(title)
    ax.legend()
    ax.grid(True)
    os.makedirs('clustering_plots', exist_ok=True)  # savefig does not create missing directories
    fig.savefig(_selection_plot_path(metric_prefix),
                dpi=300, bbox_inches='tight', pad_inches=0.1, facecolor='white')
    plt.show()
    plt.close(fig)


_title_suffix = f"({dim_red_method_upper}, {group}, {survey_category})"

# AIC and BIC
_plot_selection_curve(
    [(aic_list, dict(label='AIC', marker='o')),
     (bic_list, dict(label='BIC', marker='s'))],
    ylabel='Information Criterion',
    metric_prefix='aic_bic',
    title=f'GMM AIC/BIC vs. Number of Components {_title_suffix}' if show_titles else None,
)

# Silhouette
_plot_selection_curve(
    [(mean_silhouettes, dict(label='CV Silhouette Score', marker='o', color='tab:red'))],
    ylabel='Silhouette Score',
    metric_prefix='silhouette',
    title=f'GMM CV Silhouette Score vs. Number of Components {_title_suffix}' if show_titles else None,
)

# ARI
_plot_selection_curve(
    [(mean_aris, dict(label='CV ARI', marker='v', color='tab:green'))],
    ylabel='Adjusted Rand Index',
    metric_prefix='ari',
    title=f'GMM CV ARI vs. Number of Components {_title_suffix}' if show_titles else None,
)


In [None]:
# Final model: refit a GMM with the selected number of components on all train+val
# rows, score it on the held-out test split and plot the test assignments.

best_n_components = best_n  # selected by CV silhouette in the model-selection cell
# best_n_components = 3     # uncomment to force a specific number of clusters
show_titles = False  # the title was disabled in the original figure; set True to draw it

final_gmm = GaussianMixture(n_components=best_n_components, random_state=seed)
final_gmm.fit(X_trainval_sc)
labels_test = final_gmm.predict(X_test_sc) + 1  # shift to 1-based cluster ids

# silhouette_score raises ValueError unless 2 <= n_labels <= n_samples - 1.
n_test_clusters = len(np.unique(labels_test))
if 2 <= n_test_clusters <= len(labels_test) - 1:
    test_silhouette = silhouette_score(X_test_sc, labels_test)
else:
    test_silhouette = np.nan

print(f"Best number of components (based on CV silhouette): {best_n_components}")
print(f"Silhouette score on test set: {test_silhouette:.3f}")

fig, ax = plt.subplots()
# Plot the unscaled X_test so the axes carry the raw component values the labels
# name; standardization is a per-axis affine map, so cluster geometry is only rescaled.
scatter = ax.scatter(X_test[:, 0], X_test[:, 1], c=labels_test, cmap='turbo', s=1, alpha=0.2)
ax.set_xlabel(f"{dim_red_method_upper} Component 1")
ax.set_ylabel(f"{dim_red_method_upper} Component 2")
if show_titles:
    ax.set_title(f"GMM with {best_n_components} clusters on {dim_red_method_upper} ({group}, {survey_category})")

# Opaque legend patches using the same colormap/normalization as the scatter.
unique_labels = np.unique(labels_test)
legend_patches = [
    mpatches.Patch(color=scatter.cmap(scatter.norm(label)), label=f"Cluster {label}")
    for label in unique_labels
]
ax.legend(handles=legend_patches)

# t-SNE runs additionally encode the perplexity in the file name.
scatter_perplexity_part = f"_and_perplexity_{perplexity}" if dim_red_method == "tsne" else ""
scatter_path = (f"clustering_plots/GMM_with_{best_n_components}_clusters_on_{dim_red_method_upper}"
                f"_with_correlation_threshold_{correlation_threshold}{scatter_perplexity_part}"
                f"_{group}_{survey_category}.png")
os.makedirs(os.path.dirname(scatter_path), exist_ok=True)  # savefig does not create directories
fig.savefig(scatter_path, dpi=300, bbox_inches='tight', pad_inches=0.1, facecolor='white')

plt.show()

In [None]:
# Assign every row of X (train+val and test) to a cluster with the final model and
# save the identifiers, the first two feature columns and the label to CSV.

labels_all = final_gmm.predict(scaler.transform(X)) + 1  # same 1-based shift as the test labels

gmm_labels_df = ids.copy()  # USER_ID and WEEK_START
gmm_labels_df[f"{dim_red_method_upper}_1"] = X[:, 0]
gmm_labels_df[f"{dim_red_method_upper}_2"] = X[:, 1]
gmm_labels_df["cluster_label"] = labels_all

# t-SNE runs additionally encode the perplexity in the file name.
labels_perplexity_part = f"_perplexity_{perplexity}" if dim_red_method == "tsne" else ""
label_path = (f"working_data/cluster_labels/GMM_labels_{best_n_components}_clusters_on_{dim_red_method_upper}"
              f"_correlation_threshold_{correlation_threshold}{labels_perplexity_part}"
              f"_{group}_{survey_category}.csv")

os.makedirs(os.path.dirname(label_path), exist_ok=True)  # to_csv does not create directories
gmm_labels_df.to_csv(label_path, index=False)