In [None]:
import os
import umap

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

from sklearn.metrics import davies_bouldin_score, adjusted_rand_score
from sklearn.mixture import GaussianMixture

In [None]:
# (panel title, embeddings CSV) pairs, declared together so the two lists
# derived below always stay aligned index-by-index.
model_specs = [
    ("UNI", "uni_full_embeddings.csv"),
    ("CONCH", "conch_full_embeddings.csv"),
    ("CTransPath", "ctranspath_full_embeddings.csv"),
]
titles = [panel_title for panel_title, _ in model_specs]
models = [csv_name for _, csv_name in model_specs]

# Metadata column holding the class label used for colouring and scoring.
label_col = "isup"

# Path to the metadata CSV; must be filled in before running this cell.
meta_path = "" #Here you need to include the path to the metadata file

# The metadata table is joined onto each embedding table via "img_path" below.
meta_df = pd.read_csv(meta_path)

In [None]:
# Draw one UMAP scatter panel per embedding model, coloured by `label_col`,
# annotate each panel with ARI / DBI computed on the full embedding features,
# and save the combined figure to disk.

# Set DEBUG_SUBSET to an integer (e.g. 50) to run on a random sample of rows
# for quick iteration; None uses every merged row.
DEBUG_SUBSET = None
cmap_name = "Set2"
output_dir = "umap_embeddings"

# Columns that come from the metadata and are not embedding features.
meta_cols = ["img_path", label_col, "BLOCK#"]

# One shared, ordered label set for the whole figure. Codes, colours and the
# legend are all derived from it, so a given grade gets the same colour in
# every panel even if a panel (or a debug subset) is missing some grades.
unique_labels = sorted(meta_df[label_col].dropna().unique())
label_dtype = pd.CategoricalDtype(categories=unique_labels)
norm = plt.Normalize(vmin=0, vmax=len(unique_labels) - 1)

# Axes-fraction anchor for the metrics box of each panel (cycled if there are
# more panels than positions).
text_positions = [(0.50, 0.90), (0.10, 0.10), (0.10, 0.85)]

# squeeze=False + ravel() yields a 1-D array of axes for any number of models.
fig, axes = plt.subplots(1, len(models), figsize=(10, 6), squeeze=False)
axes = axes.ravel()

for i, filename in enumerate(models):
    # 1. Load and Prepare Data
    df = pd.read_csv(os.path.join("embeddings", filename))
    # Transpose so every original CSV column becomes a row keyed by "img_path"
    # (presumably the CSV stores one column per image -- confirm).
    pro_df = df.T.reset_index().rename(columns={"index": "img_path"})
    merged_df = pd.merge(pro_df, meta_df[meta_cols], on="img_path", how="inner")

    # Number of mixture components = number of distinct labels for this model,
    # counted before any debug subsetting.
    n_classes = merged_df[label_col].nunique(dropna=False)

    if DEBUG_SUBSET is not None:
        merged_df = merged_df.sample(n=min(DEBUG_SUBSET, len(merged_df)), random_state=161)

    X = merged_df.drop(meta_cols, axis=1)
    y = merged_df[label_col]
    # Integer codes follow the shared label order (missing labels become -1).
    y_codes = y.astype(label_dtype).cat.codes

    # 2. Dimensionality Reduction & Clustering
    reducer = umap.UMAP(random_state=161)
    embedding = reducer.fit_transform(X)

    # Both metrics use the original feature matrix X, not the UMAP projection.
    gmm = GaussianMixture(n_components=n_classes, covariance_type="full", random_state=161, n_init=5)
    clusters = gmm.fit_predict(X)
    ari = adjusted_rand_score(y_codes, clusters)
    dbi = davies_bouldin_score(X, y_codes)

    # 3. Plot
    ax = axes[i]
    # Explicit norm keeps the code -> colour mapping identical to the legend.
    ax.scatter(embedding[:, 0], embedding[:, 1], c=y_codes, cmap=cmap_name, norm=norm, s=5, alpha=1.0)

    ax.set_axis_off()
    ax.set_title(titles[i])

    textstr = f"ARI: {ari:.3f}\nDBI: {dbi:.3f}"
    props = dict(boxstyle="round", facecolor="white", alpha=0.6, edgecolor="gray")
    pos_x, pos_y = text_positions[i % len(text_positions)]
    ax.text(pos_x, pos_y, textstr, transform=ax.transAxes, fontsize=9, va="bottom", ha="center", bbox=props)

# Legend patches use the same colormap and norm as the scatter points.
color_mapper = plt.get_cmap(cmap_name)
legend_handles = [mpatches.Patch(color=color_mapper(norm(idx)), label=f"ISUP {label}")
                  for idx, label in enumerate(unique_labels)]

fig.legend(handles=legend_handles, loc="lower center", ncol=len(unique_labels),
           bbox_to_anchor=(0.5, 0.02), title="ISUP Grades")

# Reserve space at the bottom of the figure for the legend.
plt.tight_layout(rect=[0, 0.08, 1, 0.95])

# savefig does not create missing directories, so make sure it exists first.
os.makedirs(output_dir, exist_ok=True)
plt.savefig(os.path.join(output_dir, "umap_embeddings.png"), dpi=300)