# 🩺 SimCLR for Dermatology — Colab Notebook
A minimal, **self-contained** pipeline:  
* custom EfficientNet-B0 (no external imports)  
* 2×B×2×B InfoNCE loss  
* class-balanced sampling  
* adjustable label subset & per-class cap  
* automatic checkpoints  


In [1]:
#@title 📦 Install packages (run once) { display-mode: "form" }
!pip -q install albumentations==1.4.3 tqdm torchvision
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m137.0/137.0 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m117.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m51.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m57.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m42.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
#@title 🔗 Mount Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:

#@title 🛠️ Imports & utility funcs
from __future__ import annotations
import os, random, re
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms as T
from torch.cuda.amp import GradScaler, autocast

def set_seed(seed:int = 123):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# ───── Label normaliser used EVERYWHERE ──────
def _clean(name:str) -> str:
    """
    • turn underscores into spaces
    • collapse duplicate whitespace
    • Title-case the result
    """
    name = name.replace("_", " ")
    name = re.sub(r"\s+", " ", name).strip()
    return name.title()

In [4]:

#@title 📊 Build dataset summaries (creates 2 Excel files) { display-mode: "code" }
"""
Creates:
  • dataset_summary.xlsx  – one row per label
  • dataset_summary_full.xlsx – one row per label × split × modality
Both are saved beside the dataset.
"""
import re, pandas as pd
from collections import defaultdict, Counter
from pathlib import Path

ROOT = Path("/content/drive/MyDrive/derm-mmmodal/final_divided")
OUT1 = ROOT / "dataset_summary.xlsx"
OUT2 = ROOT / "dataset_summary_full.xlsx"

def count_images(folder:Path):
    exts = {".png", ".jpg", ".jpeg", ".bmp", ".gif", ".tif", ".tiff"}
    return sum(1 for f in folder.iterdir() if f.suffix.lower() in exts)

records  = []
totals   = defaultdict(Counter)

for split in ("train","val","test"):
    # 1️⃣ image_only
    for d in (ROOT/"image_only"/split).iterdir():
        if not d.is_dir(): continue
        lbl = _clean(d.name)
        n   = count_images(d)
        records.append(dict(label=lbl, split=split, modality="image_only", n=n))
        totals[lbl][f"image_only_{split}"] += n

    # 2️⃣ img_text  (count images via the XLSX)
    for xl in (ROOT/"img_text"/split).glob("*.xlsx"):
        df = pd.read_excel(xl)
        col = "label" if "label" in df.columns else "cat"
        df[col] = df[col].apply(_clean)
        for lbl, cnt in df[col].value_counts().items():
            records.append(dict(label=lbl, split=split,
                                modality="img_text", n=int(cnt)))
            totals[lbl][f"img_text_{split}"] += int(cnt)

    # 3️⃣ synthetic  (text only – counts matter for NLP tasks)
    for xl in (ROOT/"synthetic"/split).glob("*.xlsx"):
        df = pd.read_excel(xl)
        df["label"] = df["label"].apply(_clean)
        for lbl, cnt in df["label"].value_counts().items():
            records.append(dict(label=lbl, split=split,
                                modality="synthetic", n=int(cnt)))
            totals[lbl][f"synthetic_{split}"] += int(cnt)

    # 4️⃣ text_only  (text only – counts matter for NLP tasks)
    for xl in (ROOT/"text_only"/split).glob("*.xlsx"):
        df = pd.read_excel(xl)
        df["label"] = df["label"].apply(_clean)
        for lbl, cnt in df["label"].value_counts().items():
            records.append(dict(label=lbl, split=split,
                                modality="text_only", n=int(cnt)))
            totals[lbl][f"text_only_{split}"] += int(cnt)

# save long form
pd.DataFrame(records).to_excel(OUT2, index=False)

# save wide form
rows = []
for lbl, cnts in totals.items():
    row = {"label": lbl, **cnts, "total": sum(cnts.values())}
    rows.append(row)
(pd.DataFrame(rows).fillna(0)
   .astype({"total":int})
   .sort_values("label")
   .to_excel(OUT1, index=False))

print("✅  Saved:", OUT1.name, "&", OUT2.name)

✅  Saved: dataset_summary.xlsx & dataset_summary_full.xlsx


In [5]:

#@title 🗂️ Dataset + loaders (image branches only) { display-mode: "code" }
import random
from collections import Counter
from pathlib import Path

class ImgDS(Dataset):
    def __init__(self, paths, labels, tfm=None):
        self.paths, self.labels, self.tfm = paths, labels, tfm
    def __len__(self):  return len(self.paths)
    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        if self.tfm: img = self.tfm(img)
        return img, self.labels[idx]

class TwoCrop:
    """Return two correlated views for SimCLR."""
    def __init__(self, tfm): self.tfm = tfm
    def __call__(self, x):   return self.tfm(x), self.tfm(x)

class SimCLRDS(Dataset):
    def __init__(self, base: ImgDS, two_crop: TwoCrop):
        self.base, self.two_crop = base, two_crop
    def __len__(self):  return len(self.base)
    def __getitem__(self, idx):
        img, lbl = self.base[idx]
        v1, v2   = self.two_crop(img)
        return v1, v2, lbl

# ─────────────────────────────────────────────────────────────────────
def _collect_split(root:Path, split:str, focus:set|None, cap:int|None):
    """Collect (paths, labels) for one split under image_only + img_text."""
    paths, labels = [], []

    # image_only
    for d in (root/"image_only"/split).iterdir():
        if not d.is_dir(): continue
        cls = _clean(d.name)
        if focus and cls not in focus: continue
        files = list(d.iterdir()); random.shuffle(files)
        take  = files[:cap] if cap else files
        paths.extend(take)
        labels.extend([cls]*len(take))

    # img_text
    xl_dir = root/"img_text"/split
    if xl_dir.exists():
        for xl in xl_dir.glob("*.xlsx"):
            df   = pd.read_excel(xl)
            col  = "label" if "label" in df.columns else "cat"
            df[col] = df[col].apply(_clean)
            if focus: df = df[df[col].isin(focus)]
            if cap:   df = df.groupby(col).head(cap)
            for _, row in df.iterrows():
                f = xl.parent/"images"/row["image"]
                if f.exists():
                    paths.append(f)
                    labels.append(row[col])

    return paths, labels

# ─────────────────────────────────────────────────────────────────────
def build_loaders(
    root_dir:str,
    wanted_labels:list[str]|None = None,
    img_sz:int   = 224,
    batch:int    = 64,
    cap:int|None = None,
    workers:int  = 2,
):
    root  = Path(root_dir)
    focus = set(map(_clean, wanted_labels)) if wanted_labels else None

    tr_p, tr_l = _collect_split(root, "train", focus, cap)
    va_p, va_l = _collect_split(root, "val",   focus, cap)
    te_p, te_l = _collect_split(root, "test",  focus, cap)

    classes = sorted({_clean(l) for l in tr_l+va_l+te_l})
    to_idx  = {c:i for i,c in enumerate(classes)}
    tr_l = [to_idx[l] for l in tr_l]
    va_l = [to_idx[l] for l in va_l]
    te_l = [to_idx[l] for l in te_l]

    norm   = T.Normalize([0.5]*3, [0.5]*3)
    aug    = T.Compose([
        T.RandomResizedCrop(img_sz, scale=(0.08,1.0)),
        T.RandomHorizontalFlip(),
        T.RandomApply([T.ColorJitter(0.8,0.8,0.8,0.2)], p=0.8),
        T.RandomGrayscale(0.2),
        T.GaussianBlur(23, sigma=(0.1,2.0)),
        T.RandomSolarize(128, p=0.2),
        T.ToTensor(), norm
    ])
    eval_t = T.Compose([
        T.Resize(int(img_sz*256/224)),
        T.CenterCrop(img_sz),
        T.ToTensor(), norm
    ])

    # class-balanced sampler for SimCLR
    β, cnt = 0.9999, Counter(tr_l)
    eff    = {c:(1-β**cnt[c])/(1-β) for c in cnt}
    wts    = [1/eff[l] for l in tr_l]
    sampler= WeightedRandomSampler(wts, len(wts), replacement=True)

    sim_ds   = SimCLRDS(ImgDS(tr_p, tr_l, None), TwoCrop(aug))
    probe_ds = ImgDS(tr_p, tr_l, eval_t)
    val_ds   = ImgDS(va_p, va_l, eval_t)
    test_ds  = ImgDS(te_p, te_l, eval_t)

    sim_dl   = DataLoader(sim_ds,   batch_size=batch, sampler=sampler,
                          num_workers=workers, drop_last=True, pin_memory=True)
    probe_dl = DataLoader(probe_ds, batch_size=batch, shuffle=True,
                          num_workers=workers, pin_memory=True)
    val_dl   = DataLoader(val_ds,   batch_size=batch, shuffle=False,
                          num_workers=workers, pin_memory=True)
    test_dl  = DataLoader(test_ds,  batch_size=batch, shuffle=False,
                          num_workers=workers, pin_memory=True)

    return sim_dl, probe_dl, val_dl, test_dl, classes

In [6]:

#@title 🧩 EfficientNet-B0 + SimCLR wrapper & FP16-safe InfoNCE { display-mode: "code" }
import torch
from torchvision import models

class EfficientNetB0(nn.Module):
    def __init__(self, num_classes:int=0, pretrained:bool=True):
        super().__init__()
        b0 = models.efficientnet_b0(
            weights = models.EfficientNet_B0_Weights.IMAGENET1K_V1
            if pretrained else None
        )
        self.features, self.avgpool = b0.features, b0.avgpool
        in_feats = b0.classifier[1].in_features
        self.classifier = (nn.Linear(in_feats, num_classes)
                           if num_classes else nn.Identity())
        self.num_classes = num_classes

    def extract_features(self, x):
        return torch.flatten(self.avgpool(self.features(x)), 1)

    def forward(self, x):
        feats = self.extract_features(x)
        return self.classifier(feats) if self.num_classes else feats

class SimCLRModel(nn.Module):
    def __init__(self, proj_dim:int=128, num_classes:int=0):
        super().__init__()
        self.backbone  = EfficientNetB0(num_classes, pretrained=True)
        self.projector = nn.Sequential(
            nn.Linear(1280,1280,bias=False),
            nn.BatchNorm1d(1280), nn.ReLU(inplace=True),
            nn.Linear(1280,proj_dim)
        )

    def forward(self, x, *, return_logits=False):
        v = self.backbone.extract_features(x)
        z = F.normalize(self.projector(v), dim=1)
        lg= self.backbone.classifier(v) if (return_logits and self.backbone.num_classes) else None
        return v, z, lg

class InfoNCELoss(nn.Module):
    def __init__(self, temperature:float = 0.2):
        super().__init__()
        self.t = temperature
    def forward(self, z1, z2):
        B = z1.size(0)
        z  = torch.cat([z1,z2],0)         # (2B,D)
        sim= (z @ z.T)/self.t             # cosine-scaled
        mask = torch.eye(2*B,device=sim.device,dtype=torch.bool)
        sim = sim.masked_fill(mask, torch.finfo(sim.dtype).min)
        targets = (torch.arange(2*B, device=sim.device) + B) % (2*B)
        return F.cross_entropy(sim, targets)

In [7]:

#@title 🚀 Train SimCLR (backbone) { display-mode: "form" }
DATA_ROOT       = "/content/drive/MyDrive/derm-mmmodal/final_divided"  #@param {type:"string"}
SELECTED_LABELS = [
    "Fungal Infection",
    "Vitiligo",
    "Psoriasis",
    "Impetigo",
    "Urticaria",
]
CAP_PER_CLASS   = 450   #@param {type:"integer"}
EPOCHS          = 30    #@param {type:"integer"}
BATCH           = 128    #@param {type:"integer"}

import torch, os
from tqdm.notebook import trange, tqdm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
set_seed(123)
os.makedirs("checkpoints_final", exist_ok=True)

sim_dl, probe_dl, val_dl, test_dl, classes = build_loaders(
    DATA_ROOT, SELECTED_LABELS or None,
    img_sz=224, batch=BATCH, cap=CAP_PER_CLASS, workers=2
)

model   = SimCLRModel(proj_dim=128, num_classes=len(classes)).to(DEVICE)
loss_fn = InfoNCELoss(temperature=0.2)
opt     = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-6)
sched   = torch.optim.lr_scheduler.CosineAnnealingLR(
            opt, T_max=EPOCHS*len(sim_dl))
scaler  = GradScaler()

for ep in trange(1, EPOCHS+1, desc="SimCLR epochs"):
    model.train(); running = 0.0
    pbar = tqdm(sim_dl, desc="Batches", leave=False)
    for x1,x2,_ in pbar:
        x1,x2 = x1.to(DEVICE), x2.to(DEVICE)
        opt.zero_grad()
        with autocast():
            _,z1,_ = model(x1)
            _,z2,_ = model(x2)
            loss   = loss_fn(z1,z2)
        scaler.scale(loss).backward()
        scaler.step(opt); scaler.update(); sched.step()
        running += loss.item()
        pbar.set_postfix(loss=f"{loss.item():.4f}", refresh=False)

    avg = running/len(sim_dl)
    print(f"Epoch {ep:02d}/{EPOCHS} – InfoNCE {avg:.4f}")
    torch.save(model.state_dict(), f"checkpoints_final/pretrain_ep{ep:03d}.pth")


Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to /root/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth
100%|██████████| 20.5M/20.5M [00:00<00:00, 191MB/s]
  scaler  = GradScaler()


SimCLR epochs:   0%|          | 0/30 [00:00<?, ?it/s]

Batches:   0%|          | 0/10 [00:00<?, ?it/s]

  with autocast():


Epoch 01/30 – InfoNCE 4.7980


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 02/30 – InfoNCE 4.1030


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 03/30 – InfoNCE 3.7724


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 04/30 – InfoNCE 3.5356


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 05/30 – InfoNCE 3.4060


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 06/30 – InfoNCE 3.2335


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 07/30 – InfoNCE 3.1963


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 08/30 – InfoNCE 3.0613


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 09/30 – InfoNCE 2.9410


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 10/30 – InfoNCE 2.9066


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 11/30 – InfoNCE 2.8444


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 12/30 – InfoNCE 2.8716


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 13/30 – InfoNCE 2.7647


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 14/30 – InfoNCE 2.7363


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 15/30 – InfoNCE 2.7078


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 16/30 – InfoNCE 2.6889


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 17/30 – InfoNCE 2.7297


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 18/30 – InfoNCE 2.6509


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 19/30 – InfoNCE 2.6085


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 20/30 – InfoNCE 2.6049


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 21/30 – InfoNCE 2.6321


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 22/30 – InfoNCE 2.5961


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 23/30 – InfoNCE 2.5738


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 24/30 – InfoNCE 2.5782


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 25/30 – InfoNCE 2.5650


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 26/30 – InfoNCE 2.5731


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 27/30 – InfoNCE 2.5717


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 28/30 – InfoNCE 2.5178


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 29/30 – InfoNCE 2.5454


Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 30/30 – InfoNCE 2.5339


In [8]:

#@title ▶️ Linear probe (on frozen features) { display-mode: "code" }
import torch, os
import torch.nn.functional as F
from torch import nn
from tqdm.notebook import trange, tqdm

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# loaders (same params)
_, probe_dl, val_dl, _, class_names = build_loaders(
    DATA_ROOT, SELECTED_LABELS or None,
    img_sz=224, batch=BATCH, cap=CAP_PER_CLASS, workers=2
)

# backbone + load last epoch
model = SimCLRModel(proj_dim=128, num_classes=len(class_names)).to(DEVICE)
model.load_state_dict(torch.load(f"checkpoints_final/pretrain_ep{EPOCHS:03d}.pth"))
model.eval();  [p.requires_grad_(False) for p in model.parameters()]

probe  = nn.Linear(1280, len(class_names)).to(DEVICE)
opt_p  = torch.optim.AdamW(probe.parameters(), lr=3e-4)
best   = 0.0

PROBE_EPOCHS = 30
for ep in trange(1, PROBE_EPOCHS+1, desc="Probe epochs"):
    probe.train()
    for x,y in tqdm(probe_dl, desc="Probe batches", leave=False):
        x,y = x.to(DEVICE), y.to(DEVICE)
        with torch.no_grad():
            feats,_,_ = model(x)
        loss = F.cross_entropy(probe(feats), y)
        opt_p.zero_grad(); loss.backward(); opt_p.step()

    # val
    probe.eval(); correct=total=0
    with torch.no_grad():
        for x,y in val_dl:
            x,y = x.to(DEVICE), y.to(DEVICE)
            preds = probe(model(x)[0]).argmax(1)
            correct += (preds==y).sum().item(); total += y.size(0)
    acc = 100*correct/total
    if acc>best:
        best = acc; torch.save(probe.state_dict(),"checkpoints_final/best_probe.pth")
    print(f"Epoch {ep:02d}/{PROBE_EPOCHS} – Val {acc:.2f}% (best {best:.2f}%)")



Probe epochs:   0%|          | 0/30 [00:00<?, ?it/s]

Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 01/30 – Val 55.23% (best 55.23%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 02/30 – Val 65.00% (best 65.00%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 03/30 – Val 67.27% (best 67.27%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 04/30 – Val 70.23% (best 70.23%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 05/30 – Val 70.00% (best 70.23%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 06/30 – Val 72.05% (best 72.05%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 07/30 – Val 72.50% (best 72.50%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 08/30 – Val 73.41% (best 73.41%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 09/30 – Val 72.73% (best 73.41%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 10/30 – Val 74.32% (best 74.32%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 11/30 – Val 73.86% (best 74.32%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 12/30 – Val 74.55% (best 74.55%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 13/30 – Val 75.23% (best 75.23%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 14/30 – Val 75.23% (best 75.23%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 15/30 – Val 74.77% (best 75.23%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 16/30 – Val 74.77% (best 75.23%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 17/30 – Val 74.77% (best 75.23%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 18/30 – Val 75.23% (best 75.23%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 19/30 – Val 75.23% (best 75.23%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 20/30 – Val 75.45% (best 75.45%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 21/30 – Val 75.45% (best 75.45%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 22/30 – Val 75.68% (best 75.68%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 23/30 – Val 75.91% (best 75.91%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 24/30 – Val 76.14% (best 76.14%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 25/30 – Val 76.82% (best 76.82%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 26/30 – Val 77.05% (best 77.05%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 27/30 – Val 77.05% (best 77.05%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 28/30 – Val 77.50% (best 77.50%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 29/30 – Val 77.27% (best 77.50%)


Probe batches:   0%|          | 0/11 [00:00<?, ?it/s]

Epoch 30/30 – Val 77.27% (best 77.50%)


In [9]:

#@title 🧪 Evaluate best probe on test split { display-mode: "code" }
import torch
from tqdm.notebook import tqdm
from torch import nn

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# rebuild test DL
_,_,_, test_dl, classes = build_loaders(
    DATA_ROOT, SELECTED_LABELS or None,
    img_sz=224, batch=BATCH, cap=CAP_PER_CLASS, workers=2
)

model = SimCLRModel(proj_dim=128, num_classes=len(classes)).to(DEVICE)
model.load_state_dict(torch.load(f"checkpoints_final/pretrain_ep{EPOCHS:03d}.pth"))
model.eval()

probe = nn.Linear(1280, len(classes)).to(DEVICE)
probe.load_state_dict(torch.load("checkpoints_final/best_probe.pth"))
probe.eval()

correct=total=0
for x,y in tqdm(test_dl, desc="Test batches"):
    x,y = x.to(DEVICE), y.to(DEVICE)
    preds = probe(model(x)[0]).argmax(1)
    correct += (preds==y).sum().item(); total += y.size(0)

print(f"🏁 Final test accuracy: {100*correct/total:.2f}%")




Test batches:   0%|          | 0/4 [00:00<?, ?it/s]

🏁 Final test accuracy: 74.45%


In [10]:
from google.colab import drive
import shutil
import os

# 2️⃣ Define source and destination paths
src  = '/content/checkpoints_final'
dst  = '/content/drive/MyDrive/checkpoints_final'

# 3️⃣ Move (or rename) the folder in one go
if os.path.exists(src):
    # If dst already exists, remove it first (optional)
    if os.path.exists(dst):
        shutil.rmtree(dst)
    shutil.move(src, dst)
    print(f"✔️  Moved `{src}` → `{dst}`")
else:
    print(f"❌  Source folder not found: {src}")


✔️  Moved `/content/checkpoints_final` → `/content/drive/MyDrive/checkpoints_final`
