In [None]:
# Imports
import numpy as np
from scipy.spatial.distance import pdist, cdist, squareform
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
import sys
sys.path.append("..")

from sklearn.cluster import KMeans, SpectralClustering, AgglomerativeClustering
from hdbscan import HDBSCAN
from sklearn.metrics import silhouette_score, davies_bouldin_score

from utils import get_model_checkpoint_filepaths

In [None]:
# Get last word embeddings from training.
# `get_model_checkpoint_filepaths` is a project helper; the dictionary it returns is
# indexed below by "intermediate_embedding_weight_filepaths".
checkpoint_filepaths_dict = get_model_checkpoint_filepaths(
    output_dir="../output/word2vec_training/03-Oct-2020_15-00-00",
    model_name="word2vec_sgns",
    dataset_name="enwiki",
)
# NOTE(review): presumably the filepaths are ordered chronologically, so [-1] is the
# most recent checkpoint — confirm against `get_model_checkpoint_filepaths`.
last_embedding_weights_filepath = checkpoint_filepaths_dict["intermediate_embedding_weight_filepaths"][-1]
# The file is opened memory-mapped, but `.astype(np.float64)` then copies the whole
# array into memory as float64.
last_embedding_weights = np.load(last_embedding_weights_filepath, mmap_mode="r").astype(np.float64)

In [None]:
def dunn_index(X, labels):
    """
    Computes the Dunn index of a clustering.

    The index is the smallest Euclidean distance between two points that lie
    in different clusters (separation) divided by the largest Euclidean
    distance between two points that lie in the same cluster (diameter).
    Larger values mean more compact and better separated clusters.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Data points.
    labels : array-like of shape (n_samples,)
        Cluster label of each data point. Every distinct value is treated as
        its own cluster.

    Returns
    -------
    dunn : float
        Separation divided by diameter. If every cluster has zero diameter
        (e.g. all clusters are singletons) the division is by zero and NumPy
        returns inf/nan with a RuntimeWarning.

    Raises
    ------
    ValueError
        If `labels` contains fewer than two distinct values.
    """
    X = np.asarray(X)
    labels = np.asarray(labels)
    unique_labels = np.unique(labels)
    if len(unique_labels) < 2:
        raise ValueError(
            "dunn_index requires at least two clusters, "
            f"got {len(unique_labels)}"
        )

    # Split the data once instead of re-masking X for every pair of clusters.
    clusters = [X[labels == lab] for lab in unique_labels]

    # initial=0 lets singleton clusters (no pairwise distances) count as diameter 0.
    diam = np.max([np.max(pdist(cluster), initial=0) for cluster in clusters])

    # Closest pair of points over all unordered pairs of distinct clusters.
    sep = np.min([
        np.min(cdist(clusters[i], clusters[j]))
        for i in range(len(clusters))
        for j in range(i)
    ])
    return sep / diam

def sd_validity_index(X, labels, alpha=1.0):
    """
    Computes the SD validity index of a clustering (lower is better).

    The index is `alpha * scat + dis`, where

    - `scat` is the mean norm of the per-cluster feature-variance vectors
      divided by the norm of the feature-variance vector of the whole data
      set (average scattering within clusters), and
    - `dis` is `max_dist / min_dist * sum_k 1 / sum_z ||c_k - c_z||`,
      computed from the Euclidean distances between cluster centers (total
      separation between clusters).

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        Data points.
    labels : array-like of shape (n_samples,)
        Cluster label of each data point. Every distinct value is treated as
        its own cluster.
    alpha : float, optional
        Weight of the scattering term. The original formulation (Halkidi et
        al., 2000) sets it to `dis` evaluated for the largest number of
        clusters under consideration; the default of 1.0 gives the unweighted
        sum `scat + dis`.

    Returns
    -------
    sd : float
        The SD validity index.

    Raises
    ------
    ValueError
        If `labels` contains fewer than two distinct values, since no
        distance between cluster centers exists in that case.
    """
    X = np.asarray(X)
    labels = np.asarray(labels)
    unique_labels = np.unique(labels)
    if len(unique_labels) < 2:
        raise ValueError(
            "sd_validity_index requires at least two clusters, "
            f"got {len(unique_labels)}"
        )

    # Split the data once; both terms below need the per-cluster points.
    clusters = [X[labels == lab] for lab in unique_labels]

    # Average scattering: within-cluster variance relative to overall variance.
    cluster_var_norms = [np.linalg.norm(np.var(cluster, axis=0)) for cluster in clusters]
    scat = np.mean(cluster_var_norms) / np.linalg.norm(np.var(X, axis=0))

    # Total separation: based on pairwise distances between cluster centers.
    centers = np.array([np.mean(cluster, axis=0) for cluster in clusters])
    center_dists = pdist(centers)
    # Column k of the square matrix sums the distances from center k to all others.
    summed_center_dists = np.sum(squareform(center_dists), axis=0)
    dis = np.sum(1 / summed_center_dists) * np.max(center_dists) / np.min(center_dists)
    return alpha * scat + dis

In [None]:
def evaluate_cluster_methods(
    word_embeddings: np.ndarray,
    vocab_size: int,
    cluster_classes: list,
    cluster_metrics: list,
    cluster_numbers: list,
) -> dict:
    """
    Fits several clustering algorithms on word embeddings and scores every
    resulting clustering with every given metric.

    Clusterers are fitted once per entry of `cluster_numbers` and receive it
    as `n_clusters`. The clusterer named "HDBSCAN" is the exception: it is
    fitted a single time without `n_clusters`.

    Parameters
    ----------
    word_embeddings : np.ndarray
        Word embedding matrix; only the first `vocab_size` rows are clustered.
    vocab_size : int
        Number of leading rows of `word_embeddings` to use.
    cluster_classes : list
        Tuples of (name, clusterer class, keyword arguments). The class is
        instantiated with the keyword arguments and must provide
        `fit_predict(X)`.
    cluster_metrics : list
        Tuples of (name, metric function, optimal index function). The metric
        function is called as `metric(X, labels)`; the optimal index function
        receives the list of scores and returns the index of the best one
        (e.g. `np.argmax` or `np.argmin`).
    cluster_numbers : list
        Numbers of clusters to evaluate for clusterers taking `n_clusters`.

    Returns
    -------
    clusterer_results : dict
        Maps each clusterer name to a dictionary with the keys
        - "labels": dictionary from number of clusters to predicted labels,
          or the predicted labels directly for "HDBSCAN",
        - "metric_scores": dictionary from metric name to list of scores (one
          per number of clusters, or a single score for "HDBSCAN"),
        - "optimal_metric_score_indices": dictionary from metric name to the
          index of the best score in the corresponding score list.
    """
    X = word_embeddings[:vocab_size]
    clusterer_results = {}
    for cluster_name, cluster_cls, kwargs in cluster_classes:
        print(f"--- Evaluating {cluster_name}... ---")
        metric_scores = {metric_name: [] for metric_name, _, _ in cluster_metrics}

        if cluster_name != "HDBSCAN":
            # Use the `cluster_numbers` argument, not a notebook-level global.
            labels = {}
            for k in tqdm(cluster_numbers, desc="Fitting and predicting"):
                cls = cluster_cls(n_clusters=k, **kwargs)
                cluster_labels = cls.fit_predict(X)
                labels[k] = cluster_labels

                for metric_name, metric_func, _ in cluster_metrics:
                    metric_scores[metric_name].append(metric_func(X, cluster_labels))
        else:
            # HDBSCAN takes no `n_clusters`, but its configured kwargs still apply.
            cls = cluster_cls(**kwargs)
            labels = cls.fit_predict(X)

            for metric_name, metric_func, _ in cluster_metrics:
                metric_score = metric_func(X, labels)
                print(f"- {metric_name}: {metric_score:.3f}")
                metric_scores[metric_name].append(metric_score)

        # Pick the best score once all scores are known. Skipped when nothing
        # was scored (empty `cluster_numbers`), where no optimum exists.
        optimal_metric_score_indices = {}
        for metric_name, _, metric_opt_value_idx_func in cluster_metrics:
            scores = metric_scores[metric_name]
            if scores:
                optimal_metric_score_indices[metric_name] = metric_opt_value_idx_func(scores)

        clusterer_results[cluster_name] = {
            "labels": labels,
            "metric_scores": metric_scores,
            "optimal_metric_score_indices": optimal_metric_score_indices,
        }
    return clusterer_results

In [None]:
# Seed for the stochastic clusterers so that re-running the notebook reproduces
# the same clusterings and metric curves.
RANDOM_STATE = 42

# Clusterers to evaluate: (display name, clusterer class, constructor kwargs).
# The name "HDBSCAN" is special-cased by the evaluation and visualization code.
cluster_classes = [
    ("K-means clustering", KMeans, {"random_state": RANDOM_STATE}),
    ("Spectral clustering", SpectralClustering, {"n_jobs": -1, "random_state": RANDOM_STATE}),
    ("Agglomerative clustering", AgglomerativeClustering, {}),
    ("HDBSCAN", HDBSCAN, {"core_dist_n_jobs": -1})
]

# Metrics to compute: (display name, metric(X, labels), function returning the
# index of the best score). np.argmax is used where a higher score is treated
# as better, np.argmin where a lower score is.
cluster_metrics = [
    ("Average silhouette score", silhouette_score, np.argmax),
    ("Davies-Bouldin score", davies_bouldin_score, np.argmin),
    ("Dunn index", dunn_index, np.argmax),
    ("SD validity index", sd_validity_index, np.argmin)
]

# Candidate numbers of clusters: 2, 3, ..., max_cluster_num (inclusive).
max_cluster_num = 100
ks = list(range(2, max_cluster_num + 1))

In [None]:
# Perform evaluation: cluster the first 1000 rows of the embedding matrix with
# every configured clusterer and score each clustering with every configured metric.
# NOTE(review): presumably the rows are ordered by word frequency, so these are the
# 1000 most frequent words — confirm against the training code.
cluster_results_dict = evaluate_cluster_methods(
    word_embeddings=last_embedding_weights,
    vocab_size=1000,
    cluster_classes=cluster_classes,
    cluster_metrics=cluster_metrics,
    cluster_numbers=ks
)

In [None]:
def visualize_cluster_results(cluster_results: dict) -> None:
    """
    Visualizes the metric scores of evaluated clusterers.

    For every clusterer except the one named "HDBSCAN", one figure is shown
    with a subplot per metric: the metric score is plotted against the number
    of clusters and the optimal score is marked with a red dot. For "HDBSCAN"
    the first score of every metric is printed instead.

    Parameters
    ----------
    cluster_results : dict
        Maps each clusterer name to a dictionary with the keys
        - "labels": dictionary whose keys are the evaluated numbers of
          clusters (not read for "HDBSCAN"),
        - "metric_scores": dictionary from metric name to list of scores,
        - "optimal_metric_score_indices": dictionary from metric name to the
          index of the optimal score (not read for "HDBSCAN").
    """
    # Iterate over the argument, not a notebook-level global.
    for cluster_name, cluster_content in cluster_results.items():
        metric_names = list(cluster_content["metric_scores"].keys())

        if cluster_name != "HDBSCAN":
            print(f"-- Visualizing metrics for {cluster_name} --")
            cluster_numbers = list(cluster_content["labels"].keys())
            # squeeze=False keeps `axes` an array even when there is only one
            # metric; otherwise a bare Axes is returned and `.ravel()` fails.
            fig, axes = plt.subplots(
                nrows=1,
                ncols=len(metric_names),
                figsize=(3.25 * len(metric_names), 3),
                squeeze=False,
            )
            for metric_name, ax in zip(metric_names, axes.ravel()):
                metric_scores = cluster_content["metric_scores"][metric_name]
                optimal_metric_score_idx = cluster_content["optimal_metric_score_indices"][metric_name]
                ax.set_title(metric_name)
                ax.set_xlabel("Cluster number")
                ax.set_ylabel("Metric value")
                ax.scatter(cluster_numbers, metric_scores)
                ax.plot(cluster_numbers, metric_scores)
                # Mark the optimal score with a red dot.
                ax.plot(
                    cluster_numbers[optimal_metric_score_idx],
                    metric_scores[optimal_metric_score_idx],
                    'ro',
                )
            fig.tight_layout()
            plt.show()
        else:
            print(f"-- Printing metrics for {cluster_name} --")
            for metric_name in metric_names:
                # Only the first entry of each score list is reported here.
                metric_value = cluster_content["metric_scores"][metric_name][0]
                print(f"{metric_name}: {metric_value}")

In [None]:
# Plot the metric scores per clusterer (and print the HDBSCAN scores).
visualize_cluster_results(cluster_results_dict)

In [None]:
# TODO: Evaluate the clusterings on synonyms: synonymous words should end up in the same cluster.