# Notebook 1-UNI2-h: Embedding extraction

***Install & Imports***

In [2]:
# Remove the already-installed cuGraph / RAFT / RMM CUDA-12 wheels from the environment.
# NOTE(review): presumably done to avoid dependency conflicts with the packages
# imported below -- confirm and record the actual reason here.
!pip uninstall -y \
  libcugraph-cu12 pylibcugraph-cu12 \
  libraft-cu12 pylibraft-cu12 rmm-cu12


Found existing installation: libcugraph-cu12 25.6.0
Uninstalling libcugraph-cu12-25.6.0:
  Successfully uninstalled libcugraph-cu12-25.6.0
Found existing installation: pylibcugraph-cu12 25.6.0
Uninstalling pylibcugraph-cu12-25.6.0:
  Successfully uninstalled pylibcugraph-cu12-25.6.0
Found existing installation: libraft-cu12 25.2.0
Uninstalling libraft-cu12-25.2.0:
  Successfully uninstalled libraft-cu12-25.2.0
Found existing installation: pylibraft-cu12 25.2.0
Uninstalling pylibraft-cu12-25.2.0:
  Successfully uninstalled pylibraft-cu12-25.2.0
Found existing installation: rmm-cu12 25.2.0
Uninstalling rmm-cu12-25.2.0:
  Successfully uninstalled rmm-cu12-25.2.0


In [3]:
import os
import torch
import timm
import pandas as pd
from PIL import Image
from tqdm import tqdm
from huggingface_hub import login
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform



***Hugging Face Authentication***

In [4]:
# Interactive Hugging Face Hub login (renders a token-entry widget in the notebook).
# NOTE(review): presumably needed because the UNI2-h weights loaded further down
# are access-gated -- confirm.
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

***Load Metadata (CRITICAL)***

In [5]:
CSV_PATH = "/kaggle/input/01-wsi-level-stratified-split/patches_metadata_with_split.csv"
PATCH_DIR = "/kaggle/input/camelyon-prepro-v4-patches/patches"

df = pd.read_csv(CSV_PATH)

# Sanity checks
# 1) Column presence is checked first, so a malformed CSV fails with a clear
#    message instead of a bare KeyError on df["split"].
required_cols = {"patch_path", "wsi_id", "label", "split"}
missing_cols = required_cols - set(df.columns)
assert not missing_cols, f"Metadata CSV is missing columns: {sorted(missing_cols)}"

# 2) Exactly the three expected splits must be present.
found_splits = set(df["split"].unique())
assert found_splits == {"train", "val", "test"}, f"Unexpected split values: {found_splits}"

# 3) The patch directory must be mounted, since every later cell reads from it.
assert os.path.isdir(PATCH_DIR), f"Patch directory not found: {PATCH_DIR}"

print(df.head())
print(df["split"].value_counts())


                                          patch_path      wsi_id      x  \
0  /kaggle/working/patches/normal_074_x54016_y698...  normal_074  54016   
1  /kaggle/working/patches/normal_074_x18688_y744...  normal_074  18688   
2  /kaggle/working/patches/normal_074_x48640_y757...  normal_074  48640   
3  /kaggle/working/patches/normal_074_x50688_y757...  normal_074  50688   
4  /kaggle/working/patches/normal_074_x50944_y757...  normal_074  50944   

       y  label  split  
0  69888      0  train  
1  74496      0  train  
2  75776      0  train  
3  75776      0  train  
4  75776      0  train  
split
train    11700
test      2700
val       2400
Name: count, dtype: int64


***Load UNI Model (Feature-only)***

In [6]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Architecture arguments forwarded to timm for the UNI2-h checkpoint.
timm_kwargs = dict(
    img_size=224,
    patch_size=14,
    depth=24,
    num_heads=24,
    init_values=1e-5,
    embed_dim=1536,
    mlp_ratio=2.66667 * 2,
    num_classes=0,  # no classifier head -> the forward pass returns features
    no_embed_class=True,
    mlp_layer=timm.layers.SwiGLUPacked,
    act_layer=torch.nn.SiLU,
    reg_tokens=8,
    dynamic_img_size=True,
)

model = timm.create_model("hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs)

# Inference only: move to the target device, switch to eval mode and freeze
# every parameter.
model = model.to(device).eval()
model.requires_grad_(False)


config.json:   0%|          | 0.00/587 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.73G [00:00<?, ?B/s]

***UNI Transform (Model-correct)***

In [7]:
# Build the preprocessing pipeline from the model's own pretrained config
# rather than hard-coding resize / normalisation values.
data_config = resolve_data_config(model.pretrained_cfg, model=model)
transform = create_transform(**data_config)

print(transform)

Compose(
    Resize(size=224, interpolation=bilinear, max_size=None, antialias=True)
    CenterCrop(size=(224, 224))
    MaybeToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)


***Sanity Check (Single Patch)***

In [8]:
# The CSV stores the path the patch had when it was written; keep only the file
# name and resolve it against PATCH_DIR.
first_patch_path = df["patch_path"].iloc[0]
fname = os.path.basename(first_patch_path)
img_path = os.path.join(PATCH_DIR, fname)

# Load one patch and turn it into a single-item batch on the model's device.
img = Image.open(img_path).convert("RGB")
x = transform(img)
x = x.unsqueeze(0).to(device)

with torch.inference_mode():
    emb = model(x)

print("Embedding shape:", emb.shape)  # [1, 1536]


Embedding shape: torch.Size([1, 1536])


***Patch → Embedding Extraction (CORE LOOP)***

In [9]:
OUT_PATH = "/kaggle/working/uni2_patch_embeddings.pt"

# The CSV stores the path each patch had when it was written; only the file
# name is reused and resolved against PATCH_DIR.
patch_files = [os.path.basename(p) for p in df["patch_path"]]

# Fail fast: verify every patch exists before starting the long extraction loop,
# so a missing file does not abort the run partway through.
missing_files = [f for f in patch_files if not os.path.isfile(os.path.join(PATCH_DIR, f))]
assert not missing_files, (
    f"{len(missing_files)} patch files not found under {PATCH_DIR}, e.g. {missing_files[:3]}"
)

embeddings = []  # one feature tensor per patch, in CSV row order
meta_rows = []   # metadata dicts, index-aligned with `embeddings`

model.eval()

with torch.inference_mode():
    # Iterate over the needed columns directly instead of df.iterrows(), which
    # builds a Series per row.
    rows = zip(patch_files, df["wsi_id"], df["label"], df["split"])
    for fname, wsi_id, label, split in tqdm(rows, total=len(df)):
        img_path = os.path.join(PATCH_DIR, fname)

        # Context manager releases the file handle deterministically.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        x = transform(img).unsqueeze(0).to(device)

        feat = model(x).squeeze(0).cpu()  # [1536]
        embeddings.append(feat)

        patch_id = os.path.splitext(fname)[0]

        meta_rows.append({
            "patch_id": patch_id,
            "wsi_id": wsi_id,
            "label": int(label),
            "split": split,
        })


100%|██████████| 16800/16800 [25:46<00:00, 10.87it/s]


***Stack & Save (IDENTICAL FORMAT)***

In [10]:
# `embeddings` is rebound from a list of per-patch tensors to one stacked tensor.
# The guard keeps this cell re-runnable: without it a second run would call
# torch.stack on the already-stacked tensor.
if not torch.is_tensor(embeddings):
    embeddings = torch.stack(embeddings)  # [16800, 1536]

# Row i of `embeddings` must describe the same patch as meta_rows[i].
assert embeddings.shape[0] == len(meta_rows), (
    f"embeddings/metadata misaligned: {embeddings.shape[0]} vs {len(meta_rows)}"
)

torch.save(
    {
        "embeddings": embeddings,
        "metadata": meta_rows,
    },
    OUT_PATH
)

print("Saved to:", OUT_PATH)
print("Embeddings shape:", embeddings.shape)


Saved to: /kaggle/working/uni2_patch_embeddings.pt
Embeddings shape: torch.Size([16800, 1536])


***Checkpoint Verification (Must Match H-Optimus)***

In [11]:
# The checkpoint holds only a tensor plus a list of dicts of str/int values, so
# it can be read with the restricted unpickler. Passing weights_only=True
# explicitly keeps the behaviour the same across torch versions and avoids
# executing arbitrary pickled code.
ckpt = torch.load(OUT_PATH, map_location="cpu", weights_only=True)

print("Keys:", ckpt.keys())
print("Embeddings shape:", ckpt["embeddings"].shape)
print("Embeddings dtype:", ckpt["embeddings"].dtype)
print("Metadata entries:", len(ckpt["metadata"]))
print("First 3 metadata rows:", ckpt["metadata"][:3])

# Turn the visual check into one that can fail: rows must be aligned and the
# features must not contain NaN/Inf.
assert ckpt["embeddings"].shape[0] == len(ckpt["metadata"]), (
    "checkpoint embeddings and metadata have different lengths"
)
assert bool(torch.isfinite(ckpt["embeddings"]).all()), "checkpoint contains NaN/Inf embeddings"


Keys: dict_keys(['embeddings', 'metadata'])
Embeddings shape: torch.Size([16800, 1536])
Embeddings dtype: torch.float32
Metadata entries: 16800
First 3 metadata rows: [{'patch_id': 'normal_074_x54016_y69888', 'wsi_id': 'normal_074', 'label': 0, 'split': 'train'}, {'patch_id': 'normal_074_x18688_y74496', 'wsi_id': 'normal_074', 'label': 0, 'split': 'train'}, {'patch_id': 'normal_074_x48640_y75776', 'wsi_id': 'normal_074', 'label': 0, 'split': 'train'}]


***WSI-level Sanity Check (IMPORTANT)***

In [12]:
from collections import defaultdict  # unused in this cell; kept in case other cells rely on it

meta_df = pd.DataFrame(ckpt["metadata"])

# Group patches by slide, keeping first-appearance order (sort=False).
per_wsi = meta_df.groupby("wsi_id", sort=False)[["split", "label"]]
distinct_per_wsi = per_wsi.nunique()

# A slide whose patches fall into more than one split is train/val/test leakage.
# The old row-by-row dict overwrite silently hid this case.
leaky_wsis = distinct_per_wsi.index[distinct_per_wsi["split"] > 1].tolist()
assert not leaky_wsis, f"WSIs assigned to more than one split: {leaky_wsis}"

# Mixed patch labels within a slide are reported but not treated as fatal;
# the Normal/Tumor counts below then use the last label seen for that slide.
mixed_label_wsis = distinct_per_wsi.index[distinct_per_wsi["label"] > 1].tolist()
if mixed_label_wsis:
    print(
        f"WARNING: {len(mixed_label_wsis)} WSIs carry more than one patch label; "
        "counts below use the last label seen per WSI"
    )

# Same mappings as before (last row per WSI wins), built without a Python loop.
last_per_wsi = per_wsi.last()
wsi_splits = last_per_wsi["split"].to_dict()
wsi_labels = last_per_wsi["label"].to_dict()

for split in ["train", "val", "test"]:
    wsis = [w for w, s in wsi_splits.items() if s == split]
    print(
        f"{split}: {len(wsis)} WSIs | "
        f"Normal: {sum(wsi_labels[w]==0 for w in wsis)} | "
        f"Tumor: {sum(wsi_labels[w]==1 for w in wsis)}"
    )


train: 39 WSIs | Normal: 17 | Tumor: 22
val: 8 WSIs | Normal: 3 | Tumor: 5
test: 9 WSIs | Normal: 4 | Tumor: 5
