# MegaMedical multi-dataset embedding clustering

Load images from multiple MegaMedical tasks, embed with 5 encoders, then visualize PCA and k-means clustering.


In [None]:
# --- Imports ---
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

# Ensure repo and vendored deps are on path
repo_root = Path("/data/ddmg/mvseg-ordering/")
for path in [repo_root, repo_root / "UniverSeg", repo_root / "MultiverSeg"]:
    if str(path) not in sys.path:
        sys.path.append(str(path))

from experiments.dataset.mega_medical_dataset import MegaMedicalDataset
from experiments.encoders.multiverseg_encoder import MultiverSegEncoder
from experiments.encoders.clip import CLIPEncoder
from experiments.encoders.vit import ViTEncoder
from experiments.encoders.dinov2 import DinoV2Encoder
from experiments.encoders.medsam import MedSAMEncoder


In [None]:
# --- Config ---

# Pick device
device = torch.device('cpu')

# Use MegaMedical target indices for easy editing.
# Each entry is (target_index, n_samples).
# Edit indices + counts to control which datasets and how many images.
mega_targets = [
    (0, 3),
    (29, 3),
    (12, 4),
]

# Random seed (used for sampling within each target dataset)
seed = 23
np.random.seed(seed)
torch.manual_seed(seed)


In [7]:
# --- Load images across MegaMedical target indices ---

def load_target_images(target_index, n_samples, *, split='train', seed=23):
    ds = MegaMedicalDataset(dataset_target=target_index, split=split, seed=seed)
    indices = ds.get_data_indices()
    print(indices)
    rng = np.random.default_rng(seed + int(target_index))
    if n_samples > len(indices):
        n_samples = len(indices)
    pick = rng.choice(indices, size=n_samples, replace=False).tolist()
    images = []
    for idx in pick:
        img, _ = ds.get_item_by_data_index(idx)
        images.append(img)
    return images, pick

all_images = []
all_labels = []
all_meta = []

for target_index, n_samples in mega_targets:
    imgs, indices = load_target_images(target_index, n_samples, seed=seed)
    all_images.extend(imgs)
    all_labels.extend([f"target_{target_index}"] * len(imgs))
    for idx in indices:
        all_meta.append({
            'mega_target_index': int(target_index),
            'index': int(idx),
        })

print(f"Loaded {len(all_images)} images across {len(mega_targets)} MegaMedical targets")


No updates to index
Filtered task_df: 1248
got task df: 1248


The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  df["label_type"].fillna("soft", inplace=True)


target_datasets: 1248
[11 19 24 21  7  1 12 22  2  3 13 14  5 18  8  6 10 23  4  9 16  0 17 15
 20]
No updates to index


The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  df["label_type"].fillna("soft", inplace=True)


Filtered task_df: 1248
got task df: 1248
target_datasets: 1248
[245 166 240  21  45  25 163  61 236  10 102 257  38 118 242  60 156 183
 233  85  36 244   1 221  42  87 246 199 250 220  90 202 215  67 134  41
 149 186  70 146 253 141 112 173 108 212 175 140 206 135 160 180 101  35
 237 205  89 195  91 241 208 116  39   2 182 104 103  32 138  44 159 259
 151  17 153  93  71 235  30  47 203 225 181  73 193  72 258 172  82 254
 248 152  18 228  15  24 158  92 210  96  75 232  27 143 179   3 105 128
  23  22  57 127 256 214  83  74 223 136 229 167  63 211  68 190  50 122
   7 123 157 107   5 168 191 184  46  55 189 219 100 209 115 226 197  51
   0 129 188 249 216 165 114 139 224  95  94  33 142  98  86  28 218 252
  66  76  77 207 137 194 171  20  40 124 155 230 227 176 106 162 144 169
   9 130  84  81 177 231 113 147  26 109 164   8  88  80 243 111 170  29
 126  97 174  49 255  59  99  13  78 117 154 150  48 161 185 198 204  58
 238 125  79  16   6 119  37 133  64 131 200  12 148 217  65 

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  df["label_type"].fillna("soft", inplace=True)


Filtered task_df: 1248
got task df: 1248
target_datasets: 1248
[44 45  7 29 55  2 13 40 38 26 36 39 18 10 33 57 24 59 22 35 54 56 27 15
 16  6 23 28 46 37  3  8 12 20 25  1 21  0  9 30 14 34 19 51 43 11  5 53
 49 31 42 58 50  4 41 48 32 17 47 52]
Loaded 10 images across 3 MegaMedical targets


In [None]:
# --- Visualize selected images ---
n = len(all_images)
cols = min(5, n)
rows = (n + cols - 1) // cols
fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows))
if rows == 1:
    axes = [axes] if cols == 1 else axes
axes_flat = axes if isinstance(axes, (list, tuple)) else axes.flatten()
for i in range(rows * cols):
    ax = axes_flat[i]
    if i < n:
        img = all_images[i].detach().cpu().squeeze().numpy()
        ax.imshow(img, cmap='gray')
        ax.set_title(all_labels[i])
    ax.axis('off')
fig.tight_layout()
plt.show()


In [None]:
# --- Build encoders ---
encoders = {
    'multiverseg': MultiverSegEncoder(pooling='gap_gmp'),
    'clip': CLIPEncoder(model_name='ViT-B-32', pretrained='openai'),
    'vit': ViTEncoder(model_name='vit_b_16', pretrained=True),
    'dinov2': DinoV2Encoder(model_name='facebook/dinov2-base'),
    'medsam': MedSAMEncoder(model_type='vit_b'),
}

for k in encoders:
    encoders[k] = encoders[k].to(device).eval()


In [6]:
# --- Compute embeddings ---
@torch.no_grad()
def embed_images(encoder, images):
    embs = []
    for img in images:
        img = img.to(device)
        emb = encoder(img)
        if emb.dim() > 1:
            emb = emb.squeeze(0)
        embs.append(emb.detach().cpu().numpy())
    return np.stack(embs, axis=0)

embeddings = {name: embed_images(enc, all_images) for name, enc in encoders.items()}


KeyboardInterrupt: 

In [None]:
# --- PCA + KMeans visualization ---

label_names = sorted(set(all_labels))
label_to_id = {l: i for i, l in enumerate(label_names)}
colors = [label_to_id[l] for l in all_labels]

for name, emb in embeddings.items():
    pca = PCA(n_components=2, random_state=seed)
    emb2 = pca.fit_transform(emb)

    # k-means clustering (k = number of datasets)
    k = len(label_names)
    km = KMeans(n_clusters=k, random_state=seed, n_init='auto')
    clusters = km.fit_predict(emb2)

    fig, axes = plt.subplots(1, 2, figsize=(10, 4))
    axes[0].scatter(emb2[:, 0], emb2[:, 1], c=colors, cmap='tab10')
    axes[0].set_title(f"{name} PCA (colored by dataset)")
    axes[1].scatter(emb2[:, 0], emb2[:, 1], c=clusters, cmap='tab10')
    axes[1].set_title(f"{name} PCA (k-means clusters)")
    for ax in axes:
        ax.set_xlabel('PC1')
        ax.set_ylabel('PC2')
    fig.tight_layout()
    plt.show()


In [None]:
# --- Per-cluster image thumbnails (for each encoder) ---
# Shows all images per cluster.
for encoder_name, emb in embeddings.items():
    pca = PCA(n_components=2, random_state=seed)
    emb2 = pca.fit_transform(emb)
    k = len(set(all_labels))
    km = KMeans(n_clusters=k, random_state=seed, n_init='auto')
    clusters = km.fit_predict(emb2)

    cluster_ids = sorted(set(clusters))
    max_per_cluster = max((clusters == cid).sum() for cid in cluster_ids)
    fig, axes = plt.subplots(len(cluster_ids), max_per_cluster, figsize=(3 * max_per_cluster, 3 * len(cluster_ids)))
    if len(cluster_ids) == 1:
        axes = [axes]
    for row, cid in enumerate(cluster_ids):
        idxs = [i for i, c in enumerate(clusters) if c == cid]
        for col in range(max_per_cluster):
            ax = axes[row][col] if len(cluster_ids) > 1 else axes[col]
            if col < len(idxs):
                img = all_images[idxs[col]].detach().cpu().squeeze().numpy()
                ax.imshow(img, cmap='gray')
                ax.set_title(f"cluster {cid}")
            ax.axis('off')
    fig.suptitle(f'Cluster thumbnails ({encoder_name})', y=1.02)
    fig.tight_layout()
    plt.show()


**Notes**
- Edit `task_specs` to select different tasks/labels/slices.
- If you want more samples, increase `n_samples` per task.
- CLIP/ViT/DINOv2 can be slow on CPU; use GPU if available.
