In [None]:
# Mount Google Drive so the /content/drive/MyDrive/FYP/... paths used below resolve.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [None]:
# Install timm (provides the ViT backbones created via timm.create_model below).
# NOTE(review): versions are unpinned; pin them for reproducible runs.
!pip install timm pandas --quiet

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd
import numpy as np
import timm
from sklearn.model_selection import train_test_split


# DINO 2: self-supervised pretraining (two augmented views per image) with a kNN probe

In [None]:
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import timm

# -----------------------
# Config
# -----------------------
import random

import numpy as np

IMAGE_SIZE = 224
LABEL_COL = "label"
CSV_PATH = "/content/drive/MyDrive/FYP/foci_labels_v02.csv"

# EXACT supervised split sizes (per class: train / val / test)
SUP_SPLIT = {
    0: {"train": 730,  "val": 50,  "test": 41},
    1: {"train": 2000, "val": 100, "test": 62},
}

# SSL + eval settings
BATCH_SSL = 64
BATCH_EVAL = 64
NUM_EPOCHS = 100
EVAL_EVERY = 10          # run kNN eval every N epochs
KNN_K = 20               # k in kNN

# Global seed so weight init, augmentations and DataLoader shuffling are
# reproducible across "Restart & Run All" (the pandas splits already use
# random_state=42 explicitly).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -----------------------
# Datasets
# -----------------------
class DINOUnlabeledDataset(Dataset):
    """Unlabeled image dataset yielding two independently augmented views per image.

    Args:
        df: DataFrame with an ``image_path`` column.
        transform_global: augmentation applied once per view.
    """

    def __init__(self, df, transform_global):
        self.transform_global = transform_global
        self.paths = df["image_path"].tolist()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Both views come from the same decoded RGB image; each transform call
        # draws its own random augmentation parameters.
        image = Image.open(self.paths[idx]).convert("RGB")
        first_view, second_view = (self.transform_global(image) for _ in range(2))
        return first_view, second_view

class LabeledImageDataset(Dataset):
    """Labeled image dataset returning ``(transformed_image, int_label)`` pairs.

    Args:
        df: DataFrame with ``image_path`` and ``LABEL_COL`` columns.
        transform: preprocessing applied to each RGB image.
    """

    def __init__(self, df, transform):
        self.transform = transform
        self.paths = df["image_path"].tolist()
        self.labels = df[LABEL_COL].astype(int).tolist()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        rgb = Image.open(self.paths[idx]).convert("RGB")
        return self.transform(rgb), self.labels[idx]

# -----------------------
# Transforms
# -----------------------
# SSL augmentation: every view is the whole image resized to IMAGE_SIZE x IMAGE_SIZE,
# then randomly flipped / rotated.
global_transform = T.Compose([
    # T.RandomResizedCrop(IMAGE_SIZE, scale=(0.6, 1.0)),
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomRotation(degrees=20),
    # T.ColorJitter(brightness=0.2, contrast=0.2),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25]),
])

# Deterministic transform for kNN feature extraction. It uses the same geometry as
# global_transform (full image resized to IMAGE_SIZE x IMAGE_SIZE), so the probe sees
# the field of view the backbone was trained on. Resize(256) + CenterCrop would
# discard the image border and, for non-square inputs, change the aspect ratio.
eval_transform = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25]),
])


df_full = (
    pd.read_csv(CSV_PATH)
    .drop_duplicates(subset=["image_path"])  # one row per image file (first occurrence kept)
    .reset_index(drop=True)
)

df_class0, df_class1 = (df_full[df_full[LABEL_COL] == lbl].copy() for lbl in (0, 1))

print("Total images:", len(df_full))
print("Label 0 count:", len(df_class0))
print("Label 1 count:", len(df_class1))

# Fail fast if any class cannot supply its requested train + val + test rows.
for lbl, req in SUP_SPLIT.items():
    have = int((df_full[LABEL_COL] == lbl).sum())
    need = sum(req[part] for part in ("train", "val", "test"))
    assert have >= need, f"Not enough samples for label {lbl}: have {have}, need {need}"

# -----------------------
# Build EXACT supervised splits (train/val/test)
# -----------------------
def split_exact(df_lbl: pd.DataFrame, n_train: int, n_val: int, n_test: int, seed: int = 42):
    """Randomly split one class's rows into disjoint train/val/test frames of exact sizes.

    Test rows are drawn first, then val rows from what remains; train is the rest.
    If the class has more rows than ``n_train + n_val + n_test``, exactly ``n_train``
    rows are sampled from the remainder and the surplus is left out of every split.
    The caller only checks ``have >= need``, so surplus must not abort the run.
    When the counts match exactly, train is the remainder in its original order.

    Args:
        df_lbl: rows of a single class (index labels must be unique, since rows are
            removed with ``drop`` by index).
        n_train, n_val, n_test: requested split sizes.
        seed: ``random_state`` used for every sampling step.

    Returns:
        ``(df_train, df_val, df_test)``.

    Raises:
        ValueError: if ``df_lbl`` has fewer than ``n_train + n_val + n_test`` rows.
    """
    need = n_train + n_val + n_test
    if len(df_lbl) < need:
        raise ValueError(f"Not enough rows: have {len(df_lbl)}, need {need}")

    df_test = df_lbl.sample(n=n_test, random_state=seed)
    df_rem = df_lbl.drop(df_test.index)
    df_val = df_rem.sample(n=n_val, random_state=seed)
    df_train = df_rem.drop(df_val.index)
    if len(df_train) > n_train:
        # Surplus rows: keep a reproducible random subset of exactly n_train.
        df_train = df_train.sample(n=n_train, random_state=seed)
    return df_train, df_val, df_test

# Stratified splits: sample each class separately, then merge and shuffle.
df0_train, df0_val, df0_test = split_exact(
    df_class0, *(SUP_SPLIT[0][part] for part in ("train", "val", "test")), seed=42
)
df1_train, df1_val, df1_test = split_exact(
    df_class1, *(SUP_SPLIT[1][part] for part in ("train", "val", "test")), seed=42
)

df_train, df_val, df_test = (
    pd.concat([part0, part1]).sample(frac=1, random_state=42).reset_index(drop=True)
    for part0, part1 in ((df0_train, df1_train), (df0_val, df1_val), (df0_test, df1_test))
)

print("\nSupervised splits:")
for tag, split_df in (("Train:", df_train), ("Val:  ", df_val), ("Test: ", df_test)):
    print(tag, len(split_df), " (label0:", (split_df[LABEL_COL]==0).sum(), "label1:", (split_df[LABEL_COL]==1).sum(), ")")

train_csv = "/content/drive/MyDrive/FYP/foci_supervised_train_v02.csv"
val_csv   = "/content/drive/MyDrive/FYP/foci_supervised_val_v02.csv"
test_csv  = "/content/drive/MyDrive/FYP/foci_supervised_test_v02.csv"
for split_df, path in ((df_train, train_csv), (df_val, val_csv), (df_test, test_csv)):
    split_df.to_csv(path, index=False)

print("\nSaved:")
for path in (train_csv, val_csv, test_csv):
    print(" -", path)

# -----------------------
# SSL dataset: use ALL images
# -----------------------
# NOTE(review): df_full contains every image, including the supervised val/test rows;
# they are used here without labels. Confirm this transductive use is intended.
df_unlabeled = df_full.loc[:, ["image_path"]].reset_index(drop=True)

unlab_dataset = DINOUnlabeledDataset(df_unlabeled, transform_global=global_transform)
unlab_loader = DataLoader(
    unlab_dataset,
    batch_size=BATCH_SSL,
    shuffle=True,
    drop_last=True,          # every SSL batch has exactly BATCH_SSL samples
    num_workers=2,
    pin_memory=True,
)
print("Unlabeled images used for SSL (ALL):", len(unlab_dataset))

# -----------------------
# DINO backbone + head
# -----------------------
class ViTBackbone(nn.Module):
    """timm ViT wrapper that returns one (B, D) feature vector per image: the CLS token.

    The model is created with ``num_classes=0`` and ``global_pool=""``, so timm applies
    neither a classifier nor pooling; CLS selection happens in ``forward``.
    """

    def __init__(self, name="vit_tiny_patch16_224", pretrained=True):
        super().__init__()
        self.vit = timm.create_model(name, pretrained=pretrained, num_classes=0, global_pool="")

    def forward(self, x):
        out = self.vit.forward_features(x)
        # Defensive handling of a dict return: prefer the token sequence, else the pooled vector.
        if isinstance(out, dict):
            key = "x" if "x" in out else ("pooled" if "pooled" in out else None)
            if key is not None:
                out = out[key]
        # Token sequence (B, N, D) -> CLS token at position 0 -> (B, D).
        return out[:, 0, :] if out.dim() == 3 else out

class DINOHead(nn.Module):
    """Projection MLP: in_dim -> hidden_dim -> bottleneck_dim -> out_dim with GELU in between.

    There is no L2-normalised bottleneck or weight-normalised last layer; the
    outputs are raw logits.
    """

    def __init__(self, in_dim, out_dim=256, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        dims = [in_dim, hidden_dim, bottleneck_dim, out_dim]
        layers = []
        for i, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
            if i > 0:
                layers.append(nn.GELU())
            layers.append(nn.Linear(d_in, d_out))
        # Linear, GELU, Linear, GELU, Linear -> state_dict keys mlp.0 / mlp.2 / mlp.4
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        return self.mlp(x)

class DINOStudentTeacher(nn.Module):
    """ViT backbone + DINO projection head, used for both the student and the EMA teacher.

    ``forward`` maps a batch of images to head logits of size ``out_dim``.
    """

    def __init__(self,
                 backbone_name="vit_tiny_patch16_224",
                 backbone_pretrained=True,
                 out_dim=256,
                 hidden_dim=2048,
                 bottleneck_dim=256):
        super().__init__()
        self.backbone = ViTBackbone(backbone_name, pretrained=backbone_pretrained)
        self.head = DINOHead(self._feature_dim(), out_dim, hidden_dim, bottleneck_dim)

    @torch.no_grad()
    def _feature_dim(self):
        """Infer the backbone embedding size with one dummy forward pass on CPU."""
        probe = torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE)
        return self.backbone(probe).shape[-1]

    def forward(self, x):
        return self.head(self.backbone(x))

# -----------------------
# DINO loss
# -----------------------
@dataclass
class DINOLossConfig:
    """Hyper-parameters for ``DINOLoss``."""
    out_dim: int = 256             # size of the head output (and of the center buffer)
    teacher_temp: float = 0.04     # temperature for the (sharpened) teacher softmax
    student_temp: float = 0.1      # temperature for the student softmax
    center_momentum: float = 0.9   # EMA momentum of the teacher-output center

class DINOLoss(nn.Module):
    """DINO cross-view distillation loss with teacher centering and sharpening.

    For every (teacher view i, student view j) pair with i != j the term is the
    cross-entropy ``-sum(p_t * log p_s)`` with
    ``p_t = softmax((t_i - center) / teacher_temp)`` and
    ``log p_s = log_softmax(s_j / student_temp)``. Each term is averaged over the
    batch, then all terms are averaged. The ``center`` buffer is an EMA of the raw
    teacher outputs, i.e. of the same quantity it is subtracted from.
    """

    def __init__(self, cfg: DINOLossConfig):
        super().__init__()
        self.cfg = cfg
        self.register_buffer("center", torch.zeros(1, cfg.out_dim))

    def forward(self, student_outputs, teacher_outputs):
        """Compute the loss, then update the center.

        Args:
            student_outputs: list of student logits, one (B, out_dim) tensor per view.
            teacher_outputs: list of teacher logits for the same views, same order.

        Returns:
            Scalar loss (differentiable w.r.t. the student outputs).
        """
        # log_softmax instead of log(softmax + eps): numerically exact, and it does not
        # cap the loss (and zero the gradient) for confidently wrong student predictions.
        student_logp = [F.log_softmax(s / self.cfg.student_temp, dim=-1) for s in student_outputs]

        with torch.no_grad():
            teacher_p = [
                F.softmax((t - self.center) / self.cfg.teacher_temp, dim=-1)
                for t in teacher_outputs
            ]

        total_loss = 0.0
        n_terms = 0
        for i, t_p in enumerate(teacher_p):
            for j, s_logp in enumerate(student_logp):
                if i == j:
                    continue  # only distill across different views
                total_loss = total_loss + torch.sum(-t_p * s_logp, dim=-1).mean()
                n_terms += 1
        total_loss = total_loss / max(1, n_terms)

        self.update_center(teacher_outputs)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_outputs):
        """EMA update of the center from the raw (un-centered, un-softmaxed) teacher outputs."""
        batch_center = torch.cat([t.detach() for t in teacher_outputs], dim=0).mean(dim=0, keepdim=True)
        m = self.cfg.center_momentum
        self.center.mul_(m).add_(batch_center, alpha=1.0 - m)

# -----------------------
# Build student/teacher, optimizer
# -----------------------
import copy

OUT_DIM = 256
student = DINOStudentTeacher("vit_tiny_patch16_224", True, OUT_DIM).to(device)

# The teacher starts as an exact copy of the student and is only updated by EMA.
# deepcopy avoids building a second model (a second pretrained-weight load plus a
# random head init) whose weights would immediately be overwritten anyway.
teacher = copy.deepcopy(student)
for p in teacher.parameters():
    p.requires_grad = False

dino_loss_fn = DINOLoss(DINOLossConfig(out_dim=OUT_DIM)).to(device)
# Only the student is optimized; the teacher receives no gradients.
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-4, weight_decay=1e-4)
teacher_momentum = 0.996  # EMA momentum passed to update_teacher (constant, no schedule)

@torch.no_grad()
def update_teacher(student, teacher, m):
    """EMA step: teacher <- m * teacher + (1 - m) * student, parameter by parameter.

    Both modules must expose their parameters in the same order (same architecture).
    """
    student_weight = 1.0 - m
    for t_param, s_param in zip(teacher.parameters(), student.parameters()):
        t_param.data.mul_(m)
        t_param.data.add_(s_param.data, alpha=student_weight)

# -----------------------
# kNN probe (teacher backbone) for val monitoring
# -----------------------
@torch.no_grad()
def extract_features(backbone: nn.Module, loader: DataLoader):
    """Embed every batch of ``loader`` with ``backbone`` in eval mode.

    Returns:
        (features, labels): L2-normalised features on CPU (so a dot product is a
        cosine similarity) concatenated over all batches, and the matching labels.
        Assumes ``backbone`` returns one (B, D) vector per image.
    """
    backbone.eval()
    feature_chunks = []
    label_chunks = []
    for images, targets in tqdm(loader, desc="Extract feats", leave=False):
        embeddings = backbone(images.to(device, non_blocking=True))
        feature_chunks.append(F.normalize(embeddings, dim=-1).cpu())
        label_chunks.append(targets.cpu())
    return torch.cat(feature_chunks, dim=0), torch.cat(label_chunks, dim=0)

@torch.no_grad()
def knn_accuracy(train_feats, train_labels, val_feats, val_labels, k=20):
    """Majority-vote k-NN accuracy using dot-product similarity.

    With L2-normalised features (as produced by ``extract_features``) the dot
    product is cosine similarity.

    Args:
        train_feats: (N_train, D) reference features.
        train_labels: (N_train,) non-negative integer labels.
        val_feats: (N_val, D) query features.
        val_labels: (N_val,) integer labels.
        k: number of neighbours; clamped to N_train so a small reference set does
            not make ``topk`` raise.

    Returns:
        Fraction of queries whose majority neighbour label equals the true label.
        Ties go to the smallest label.
    """
    k = min(k, train_feats.shape[0])
    num_classes = int(train_labels.max().item()) + 1   # loop-invariant, computed once
    sims = val_feats @ train_feats.T                    # (N_val, N_train)
    neighbour_labels = train_labels[sims.topk(k=k, dim=1).indices].long()  # (N_val, k)
    # Per-class vote counts (N_val, num_classes); argmax returns the first maximum,
    # so ties resolve to the smallest label, as bincount().argmax() did.
    votes = F.one_hot(neighbour_labels, num_classes).sum(dim=1)
    preds = votes.argmax(dim=1)
    return (preds == val_labels).float().mean().item()

# Deterministic datasets/loaders (eval_transform, no shuffling) for the kNN probe.
train_eval_ds, val_eval_ds = (
    LabeledImageDataset(split_df, transform=eval_transform) for split_df in (df_train, df_val)
)
train_eval_loader, val_eval_loader = (
    DataLoader(ds, batch_size=BATCH_EVAL, shuffle=False, num_workers=2, pin_memory=True)
    for ds in (train_eval_ds, val_eval_ds)
)

# -----------------------
# SSL training loop + periodic kNN eval
# -----------------------
# Mean SSL loss per epoch, and (epoch, kNN val accuracy) pairs, both used for plotting.
ssl_loss_history = []
knn_val_history = []

# The student is trained by backprop; the teacher is only updated via update_teacher (EMA).
student.train()
teacher.eval()

for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0
    n_batches = 0

    # Each batch yields two independently augmented views of the same images.
    for v1, v2 in tqdm(unlab_loader, desc=f"SSL Epoch {epoch+1}/{NUM_EPOCHS}"):
        v1 = v1.to(device, non_blocking=True)
        v2 = v2.to(device, non_blocking=True)

        s1 = student(v1)
        s2 = student(v2)

        # Teacher targets carry no gradient.
        with torch.no_grad():
            t1 = teacher(v1)
            t2 = teacher(v2)

        # Cross-view loss: teacher view i supervises student view j only when i != j.
        loss = dino_loss_fn([s1, s2], [t1, t2])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # EMA update of the teacher from the just-updated student weights.
        update_teacher(student, teacher, teacher_momentum)

        epoch_loss += loss.item()
        n_batches += 1

    avg_loss = epoch_loss / max(1, n_batches)
    ssl_loss_history.append(avg_loss)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} - SSL DINO Loss: {avg_loss:.4f}")

    # ---- periodic kNN monitoring on supervised VAL ----
    # Runs after the first epoch and then every EVAL_EVERY epochs, on teacher features.
    if (epoch + 1) % EVAL_EVERY == 0 or (epoch == 0):
        print(f"\n[kNN probe] Epoch {epoch+1}: extracting features from teacher backbone...")
        tr_f, tr_y = extract_features(teacher.backbone, train_eval_loader)
        va_f, va_y = extract_features(teacher.backbone, val_eval_loader)
        acc = knn_accuracy(tr_f, tr_y, va_f, va_y, k=KNN_K)
        knn_val_history.append((epoch + 1, acc))
        print(f"[kNN probe] VAL accuracy @k={KNN_K}: {acc*100:.2f}%\n")

# SSL loss curve, one point per epoch (1-based).
fig, ax = plt.subplots()
ax.plot(range(1, len(ssl_loss_history) + 1), ssl_loss_history, marker="o")
ax.set_xlabel("Epoch")
ax.set_ylabel("SSL DINO Loss")
ax.set_title("DINO SSL Training Loss (ALL images)")
ax.grid(True)
plt.show()

# kNN probe accuracy at each evaluated epoch.
if knn_val_history:
    eval_epochs, eval_accs = zip(*knn_val_history)
    fig, ax = plt.subplots()
    ax.plot(eval_epochs, eval_accs, marker="o")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("kNN Val Accuracy")
    ax.set_title(f"kNN Probe on Val (k={KNN_K}) using Teacher Backbone Features")
    ax.grid(True)
    plt.show()

# -----------------------
# Save only backbone from EMA teacher
# -----------------------
ssl_ckpt_path = "/content/drive/MyDrive/FYP/foci_dino_backbone_ALLimgs_v03.pth"
torch.save(teacher.backbone.state_dict(), ssl_ckpt_path)
print("\nSaved DINO-pretrained backbone to:", ssl_ckpt_path)


In [None]:
import os
import pandas as pd
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from PIL import Image
import torchvision.transforms as T
from tqdm import tqdm
import matplotlib.pyplot as plt
import timm

IMAGE_SIZE = 224
LABEL_COL = "label"

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# EMA-teacher backbone state_dict written by the SSL cell; must match the path
# that cell saved to (ssl_ckpt_path).
dino_backbone_ckpt = "/content/drive/MyDrive/FYP/foci_dino_backbone_ALLimgs_v03.pth"

# NOTE(review): not referenced by any cell visible here.
heldout_csv_path = "/content/drive/MyDrive/FYP/foci_heldout_for_supervised.csv"

# Destination of the best (by val accuracy) supervised classifier.
supervised_ckpt_path = "/content/drive/MyDrive/FYP/foci_dino_supervised_classifier_v03.pth"

# ---- Transforms ----
_normalize = T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25])

# Supervised training: full image resized to IMAGE_SIZE x IMAGE_SIZE plus random flips.
train_transform = T.Compose([
    # T.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.ToTensor(),
    _normalize,
])

# Deterministic evaluation: same resize, no augmentation.
eval_transform = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.ToTensor(),
    _normalize,
])



In [None]:
# -------------------------------
# Dataset for labeled classification
# -------------------------------
class FociLabeledDataset(Dataset):
    """Binary classification dataset yielding ``(transformed RGB image, float32 label)``.

    Labels are float so they can be passed directly to ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, df, transform):
        self.transform = transform
        self.paths = df["image_path"].tolist()
        self.labels = df[LABEL_COL].astype(float).tolist()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        image = self.transform(Image.open(self.paths[idx]).convert("RGB"))
        target = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, target


# -------------------------------
# ViT backbone (same as SSL)
# -------------------------------
import timm
import torch
import torch.nn as nn

IMAGE_SIZE = 224

class ViTBackbone(nn.Module):
    """timm ViT wrapper returning one (B, D) vector per image: the CLS token.

    ``pretrained`` defaults to False because the weights are loaded afterwards from
    the DINO checkpoint instead of ImageNet. The attribute name ``vit`` must stay
    unchanged so checkpoint keys (``vit.*``) match.
    """

    def __init__(self, name="vit_tiny_patch16_224", pretrained=False):
        super().__init__()
        self.vit = timm.create_model(name, pretrained=pretrained, num_classes=0, global_pool="")

    def forward(self, x):
        feats = self.vit.forward_features(x)  # tensor, or possibly a dict
        if isinstance(feats, dict):
            # Prefer the token sequence (B, N, D) over a pooled vector (B, D).
            for key in ("x", "pooled"):
                if key in feats:
                    feats = feats[key]
                    break
        # (B, N, D) with N = 1 CLS + patch tokens -> keep the CLS token -> (B, D).
        return feats[:, 0, :] if feats.dim() == 3 else feats


# -------------------------------
# Classifier using DINO backbone
# -------------------------------
class FociDINOClassifier(nn.Module):
    """Binary classifier: frozen DINO-pretrained ViT backbone + trainable MLP head.

    Args:
        backbone_ckpt: path to a ``ViTBackbone`` state_dict (the EMA-teacher
            backbone saved by the SSL cell).
        backbone_name: timm model name; must match the checkpoint's architecture.

    ``forward`` returns one logit per image, shape (B,).
    """

    def __init__(self, backbone_ckpt, backbone_name="vit_tiny_patch16_224"):
        super().__init__()
        # CLS-pooling backbone; weights come from the DINO checkpoint, not ImageNet.
        self.backbone = ViTBackbone(backbone_name, pretrained=False)

        # The checkpoint is a plain tensor state_dict, so refuse arbitrary pickled objects.
        state = torch.load(backbone_ckpt, map_location="cpu", weights_only=True)
        self.backbone.load_state_dict(state)

        # Frozen backbone: only the head is trained.
        for p in self.backbone.parameters():
            p.requires_grad = False
        self.backbone.eval()

        # Infer embedding dim D from a dummy forward pass.
        dummy = torch.zeros(1, 3, IMAGE_SIZE, IMAGE_SIZE)
        with torch.no_grad():
            in_dim = self.backbone(dummy).shape[-1]   # (1, D)

        # Classification head: a small MLP (not a single linear layer).
        self.head = nn.Sequential(
            nn.Linear(in_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1)  # binary logit
        )

    def train(self, mode: bool = True):
        """Set train/eval mode for the head while keeping the frozen backbone in eval mode.

        ``train_one_epoch`` calls ``model.train()``. Without this override that would
        also put the frozen backbone in training mode, enabling any dropout /
        stochastic-depth layers it contains and making its features non-deterministic.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        feats = self.backbone(x)              # (B, D)
        logit = self.head(feats).squeeze(-1)  # (B,)
        return logit

In [None]:
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import timm

import random

import numpy as np

IMAGE_SIZE = 224
LABEL_COL = "label"
CSV_PATH = "/content/drive/MyDrive/FYP/foci_labels_v02.csv"

# EXACT supervised split sizes (per class: train / val / test)
SUP_SPLIT = {
    0: {"train": 730,  "val": 50,  "test": 41},
    1: {"train": 2000, "val": 100, "test": 62},
}

# SSL + eval settings
BATCH_SSL = 64
BATCH_EVAL = 64
NUM_EPOCHS = 100
EVAL_EVERY = 10          # run kNN eval every N epochs
KNN_K = 20               # k in kNN

# Global seed so head init, augmentations and DataLoader shuffling are reproducible
# across "Restart & Run All" (the pandas splits already use random_state=42).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -----------------------
# Datasets
# -----------------------
class DINOUnlabeledDataset(Dataset):
    """Unlabeled image dataset yielding two independently augmented views per image.

    Args:
        df: DataFrame with an ``image_path`` column.
        transform_global: augmentation applied once per view.
    """

    def __init__(self, df, transform_global):
        self.transform_global = transform_global
        self.paths = df["image_path"].tolist()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Both views come from the same decoded RGB image; each transform call
        # draws its own random augmentation parameters.
        image = Image.open(self.paths[idx]).convert("RGB")
        first_view, second_view = (self.transform_global(image) for _ in range(2))
        return first_view, second_view

class LabeledImageDataset(Dataset):
    """Labeled image dataset returning ``(transformed_image, int_label)`` pairs.

    Args:
        df: DataFrame with ``image_path`` and ``LABEL_COL`` columns.
        transform: preprocessing applied to each RGB image.
    """

    def __init__(self, df, transform):
        self.transform = transform
        self.paths = df["image_path"].tolist()
        self.labels = df[LABEL_COL].astype(int).tolist()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        rgb = Image.open(self.paths[idx]).convert("RGB")
        return self.transform(rgb), self.labels[idx]

# -----------------------
# Transforms
# -----------------------
# SSL augmentation: whole image resized to IMAGE_SIZE x IMAGE_SIZE, random flips / rotation.
global_transform = T.Compose([
    # T.RandomResizedCrop(IMAGE_SIZE, scale=(0.6, 1.0)),
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    T.RandomRotation(degrees=20),
    # T.ColorJitter(brightness=0.2, contrast=0.2),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25]),
])

# Deterministic transform for feature extraction / eval.
# This cell runs after the supervised-setup cell, so this definition replaces that
# cell's eval_transform and is the one val_dataset / test_dataset use below. It must
# match train_transform's geometry (full image resized to IMAGE_SIZE x IMAGE_SIZE).
# A Resize(256) + CenterCrop(224) here would evaluate the classifier on center crops
# it never saw during training.
eval_transform = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25]),
])


# global_transform = T.Compose([
#     # T.RandomResizedCrop(IMAGE_SIZE, scale=(0.6, 1.0)),
#     T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
#     T.RandomHorizontalFlip(),
#     T.RandomVerticalFlip(),
#     T.RandomRotation(degrees=20),
#     # T.ColorJitter(brightness=0.2, contrast=0.2),
#     T.ToTensor(),
#     T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25]),
# ])

# # Deterministic transform for feature extraction / eval
# eval_transform = T.Compose([
#     T.Resize(256),
#     T.CenterCrop(IMAGE_SIZE),
#     T.ToTensor(),
#     T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.25, 0.25, 0.25]),
# ])


df_full = (
    pd.read_csv(CSV_PATH)
    .drop_duplicates(subset=["image_path"])  # one row per image file (first occurrence kept)
    .reset_index(drop=True)
)

df_class0, df_class1 = (df_full[df_full[LABEL_COL] == lbl].copy() for lbl in (0, 1))

print("Total images:", len(df_full))
print("Label 0 count:", len(df_class0))
print("Label 1 count:", len(df_class1))

# Fail fast if any class cannot supply its requested train + val + test rows.
for lbl, req in SUP_SPLIT.items():
    have = int((df_full[LABEL_COL] == lbl).sum())
    need = sum(req[part] for part in ("train", "val", "test"))
    assert have >= need, f"Not enough samples for label {lbl}: have {have}, need {need}"

# -----------------------
# Build EXACT supervised splits (train/val/test)
# -----------------------
def split_exact(df_lbl: pd.DataFrame, n_train: int, n_val: int, n_test: int, seed: int = 42):
    """Randomly split one class's rows into disjoint train/val/test frames of exact sizes.

    Test rows are drawn first, then val rows from what remains; train is the rest.
    If the class has more rows than ``n_train + n_val + n_test``, exactly ``n_train``
    rows are sampled from the remainder and the surplus is left out of every split.
    The caller only checks ``have >= need``, so surplus must not abort the run.
    When the counts match exactly, train is the remainder in its original order.

    Args:
        df_lbl: rows of a single class (index labels must be unique, since rows are
            removed with ``drop`` by index).
        n_train, n_val, n_test: requested split sizes.
        seed: ``random_state`` used for every sampling step.

    Returns:
        ``(df_train, df_val, df_test)``.

    Raises:
        ValueError: if ``df_lbl`` has fewer than ``n_train + n_val + n_test`` rows.
    """
    need = n_train + n_val + n_test
    if len(df_lbl) < need:
        raise ValueError(f"Not enough rows: have {len(df_lbl)}, need {need}")

    df_test = df_lbl.sample(n=n_test, random_state=seed)
    df_rem = df_lbl.drop(df_test.index)
    df_val = df_rem.sample(n=n_val, random_state=seed)
    df_train = df_rem.drop(df_val.index)
    if len(df_train) > n_train:
        # Surplus rows: keep a reproducible random subset of exactly n_train.
        df_train = df_train.sample(n=n_train, random_state=seed)
    return df_train, df_val, df_test

# Stratified splits: sample each class separately, then merge and shuffle.
df0_train, df0_val, df0_test = split_exact(
    df_class0, *(SUP_SPLIT[0][part] for part in ("train", "val", "test")), seed=42
)
df1_train, df1_val, df1_test = split_exact(
    df_class1, *(SUP_SPLIT[1][part] for part in ("train", "val", "test")), seed=42
)

df_train, df_val, df_test = (
    pd.concat([part0, part1]).sample(frac=1, random_state=42).reset_index(drop=True)
    for part0, part1 in ((df0_train, df1_train), (df0_val, df1_val), (df0_test, df1_test))
)

print("\nSupervised splits:")
for tag, split_df in (("Train:", df_train), ("Val:  ", df_val), ("Test: ", df_test)):
    print(tag, len(split_df), " (label0:", (split_df[LABEL_COL]==0).sum(), "label1:", (split_df[LABEL_COL]==1).sum(), ")")

train_csv = "/content/drive/MyDrive/FYP/foci_supervised_train_v02.csv"
val_csv   = "/content/drive/MyDrive/FYP/foci_supervised_val_v02.csv"
test_csv  = "/content/drive/MyDrive/FYP/foci_supervised_test_v02.csv"
for split_df, path in ((df_train, train_csv), (df_val, val_csv), (df_test, test_csv)):
    split_df.to_csv(path, index=False)

print("\nSaved:")
for path in (train_csv, val_csv, test_csv):
    print(" -", path)

for tag, split_df in (("Train size:", df_train), ("Val size:", df_val), ("Test size:", df_test)):
    print(tag, len(split_df))

# Augmentation for training only; val/test use the deterministic eval_transform.
train_dataset = FociLabeledDataset(df_train, transform=train_transform)
val_dataset, test_dataset = (
    FociLabeledDataset(split_df, transform=eval_transform) for split_df in (df_val, df_test)
)

_loader_kwargs = dict(batch_size=16, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader   = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader  = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)


Train / eval helper functions

In [None]:
def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Args:
        model: maps images to one logit per image, shape (B,).
        loader: yields ``(images, float_labels)`` batches.
        optimizer: steps the trainable parameters.
        criterion: loss on ``(logits, labels)``.

    Returns:
        (per-sample average loss, accuracy at probability threshold 0.5).
        The loss average assumes ``criterion`` returns a batch mean.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for imgs, labels in tqdm(loader, desc="Train", leave=False):
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(imgs)  # (B,)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size to get a per-sample average.
        loss_sum += loss.item() * imgs.size(0)
        hard_preds = (torch.sigmoid(logits) >= 0.5).float()
        n_correct += (hard_preds == labels).sum().item()
        n_seen += labels.size(0)

    denom = max(1, n_seen)
    return loss_sum / denom, n_correct / denom


@torch.no_grad()
def evaluate(model, loader, criterion, prefix="Val"):
    model.eval()
    epoch_loss = 0.0
    correct = 0
    total = 0

    for imgs, labels in tqdm(loader, desc=prefix, leave=False):
        imgs = imgs.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        logits = model(imgs)
        loss = criterion(logits, labels)

        epoch_loss += loss.item() * imgs.size(0)

        probs = torch.sigmoid(logits)
        preds = (probs >= 0.5).float()
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    avg_loss = epoch_loss / max(1, total)
    acc = correct / max(1, total)
    return avg_loss, acc


Supervised training loop + plots

In [None]:
# Display the SSL backbone checkpoint path the classifier below will load.
dino_backbone_ckpt

In [None]:
# Display where the best supervised classifier will be saved.
supervised_ckpt_path

In [None]:
# -------------------------------
# Supervised training loop (with val)
# -------------------------------
model = FociDINOClassifier(backbone_ckpt=dino_backbone_ckpt).to(device)

# Only the head's parameters are optimized; the backbone is frozen inside
# FociDINOClassifier, so this trains a small MLP head on frozen DINO features.
optimizer = torch.optim.AdamW(model.head.parameters(), lr=1e-4, weight_decay=1e-4)
criterion = nn.BCEWithLogitsLoss()

num_epochs_sup = 100
train_losses, val_losses = [], []
train_accs, val_accs = [], []

# Start at -inf (not 0.0) so the first epoch always saves a checkpoint. Otherwise a
# run whose val accuracy never exceeds 0 would leave supervised_ckpt_path missing or
# stale for the test cell.
best_val_acc = float("-inf")

for epoch in range(num_epochs_sup):
    print(f"\nSupervised Epoch {epoch+1}/{num_epochs_sup}")
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_acc = evaluate(model, val_loader, criterion, prefix="Val")

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

    print(f"  Train loss: {train_loss:.4f}, acc: {train_acc:.4f}")
    print(f"  Val   loss: {val_loss:.4f}, acc: {val_acc:.4f}")

    # Strict improvement keeps the earliest epoch among ties.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), supervised_ckpt_path)
        print(f"  --> New best model saved with val acc = {best_val_acc:.4f}")

# ---- Plots ----
# 1-based epoch axis so the curves line up with the "Supervised Epoch N" log lines.
epoch_axis = range(1, len(train_losses) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_axis, train_losses, label="Train loss")
ax.plot(epoch_axis, val_losses, label="Val loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Supervised training vs validation loss")
ax.legend()
ax.grid(True)
plt.show()

fig, ax = plt.subplots()
ax.plot(epoch_axis, train_accs, label="Train acc")
ax.plot(epoch_axis, val_accs, label="Val acc")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.set_title("Supervised training vs validation accuracy")
ax.legend()
ax.grid(True)
plt.show()

print("Best val acc:", best_val_acc)
print("Best classifier saved at:", supervised_ckpt_path)


Test evaluation (using best model)

In [None]:
# Display which classifier checkpoint the test evaluation below will load.
supervised_ckpt_path

In [None]:
# Optionally evaluate an older, previously saved model instead of the one trained above.
# Off by default so that "Restart & Run All" tests the checkpoint this notebook just
# produced. Set to True to redirect the test cell to the legacy files below.
USE_LEGACY_CHECKPOINTS = False
LEGACY_DINO_BACKBONE_CKPT = "/content/drive/MyDrive/FYP/foci_dino_backbone_70_70.pth"
LEGACY_SUPERVISED_CKPT = "/content/drive/MyDrive/FYP/foci_dino_supervised_classifier.pth"

if USE_LEGACY_CHECKPOINTS:
    dino_backbone_ckpt = LEGACY_DINO_BACKBONE_CKPT
    supervised_ckpt_path = LEGACY_SUPERVISED_CKPT

In [None]:
criterion = nn.BCEWithLogitsLoss()

# Rebuild the architecture (this also loads the backbone checkpoint), then restore
# the full classifier weights saved at the best validation epoch.
best_model = FociDINOClassifier(backbone_ckpt=dino_backbone_ckpt).to(device)
best_state = torch.load(supervised_ckpt_path, map_location=device)
best_model.load_state_dict(best_state)

test_loss, test_acc = evaluate(best_model, test_loader, criterion, prefix="Test")
print(f"\nTEST RESULTS -> loss: {test_loss:.4f}, acc: {test_acc:.4f}")

Inference helpers + example usage

In [None]:
# # -------------------------------
# # INFERENCE HELPERS
# # -------------------------------
# @torch.no_grad()
# def infer_on_loader(model, loader, threshold=0.5):
#     model.eval()
#     all_probs = []
#     all_preds = []
#     for imgs, _ in tqdm(loader, desc="Infer", leave=False):
#         imgs = imgs.to(device, non_blocking=True)
#         logits = model(imgs)
#         probs = torch.sigmoid(logits)
#         preds = (probs >= threshold).float()
#         all_probs.extend(probs.cpu().tolist())
#         all_preds.extend(preds.cpu().tolist())
#     return all_probs, all_preds


# @torch.no_grad()
# def infer_on_csv(model, csv_path, transform=eval_transform,
#                  out_csv=None, threshold=0.5):
#     df = pd.read_csv(csv_path)

#     class InferenceDataset(Dataset):
#         def __init__(self, df, transform):
#             self.paths = df["image_path"].tolist()
#             self.transform = transform
#         def __len__(self):
#             return len(self.paths)
#         def __getitem__(self, idx):
#             img = Image.open(self.paths[idx]).convert("RGB")
#             img = self.transform(img)
#             return img, 0.0  # dummy label

#     dataset = InferenceDataset(df, transform)
#     loader = DataLoader(dataset, batch_size=16, shuffle=False,
#                         num_workers=2, pin_memory=True)

#     probs, preds = infer_on_loader(model, loader, threshold=threshold)
#     df["prob_pos"] = probs
#     df["pred_label"] = preds

#     if out_csv is not None:
#         df.to_csv(out_csv, index=False)
#         print("Saved inference results to:", out_csv)

#     return df


# @torch.no_grad()
# def infer_single_image(model, img_path, transform=eval_transform, threshold=0.5):
#     img = Image.open(img_path).convert("RGB")
#     img = transform(img)
#     img = img.unsqueeze(0).to(device)
#     logit = model(img)
#     prob = torch.sigmoid(logit).item()
#     pred = 1.0 if prob >= threshold else 0.0
#     return prob, pred


# import os
# import numpy as np
# from PIL import Image
# import matplotlib.pyplot as plt

# def run_inference_and_visualize_csv(model,
#                                     csv_path,
#                                     transform=eval_transform,
#                                     threshold=0.5,
#                                     out_csv=None,
#                                     max_rows=None):
#     """
#     For each row in csv_path:
#       - loads image_path
#       - runs model -> prob, pred
#       - computes gradient-based saliency map wrt input
#       - displays [image | image + heatmap] with file path
#     Also returns df with prob_pos & pred_label columns and optionally saves it.
#     """
#     df = pd.read_csv(csv_path)
#     print("Num rows in CSV:", len(df))

#     probs_all = []
#     preds_all = []

#     model.eval()

#     # Optionally limit number of rows visualized
#     if max_rows is None:
#         max_rows = len(df)

#     for idx, row in df.iloc[:max_rows].iterrows():
#         img_path = row["image_path"]
#         true_label = row.get("label", None)  # might be missing for unlabeled CSV

#         # --- load + transform image ---
#         img_pil = Image.open(img_path).convert("RGB")
#         x = transform(img_pil).unsqueeze(0).to(device)  # (1, 3, H, W)
#         x.requires_grad_(True)

#         # --- forward pass ---
#         model.zero_grad()
#         logit = model(x)            # (1,) from FociDINOClassifier
#         prob = torch.sigmoid(logit)[0]
#         pred = 1.0 if prob.item() >= threshold else 0.0

#         # --- backward to get saliency wrt input ---
#         prob.backward()
#         grad = x.grad.detach()      # (1, 3, H, W)
#         grad = grad.abs().max(dim=1)[0]  # (1, H, W) → max over channels
#         grad = grad[0].cpu().numpy()

#         # normalize heatmap to [0, 1]
#         g_min, g_max = grad.min(), grad.max()
#         if g_max > g_min:
#             grad_norm = (grad - g_min) / (g_max - g_min)
#         else:
#             grad_norm = np.zeros_like(grad)

#         probs_all.append(prob.item())
#         preds_all.append(pred)

#         # --- PLOT: image + heatmap overlay ---
#         fig, axes = plt.subplots(1, 2, figsize=(8, 4))

#         # left: original image
#         axes[0].imshow(img_pil)
#         axes[0].set_title("Original", fontsize=10)
#         axes[0].axis("off")

#         # right: overlay saliency
#         axes[1].imshow(img_pil)
#         axes[1].imshow(grad_norm, cmap="jet", alpha=0.5)
#         axes[1].set_title("Saliency / attention heatmap", fontsize=10)
#         axes[1].axis("off")

#         # big title with file path + labels
#         title_lines = [f"path: {img_path}"]
#         if true_label is not None:
#             title_lines.append(f"true={int(true_label)}, pred={int(pred)}, prob={prob.item():.3f}")
#         else:
#             title_lines.append(f"pred={int(pred)}, prob={prob.item():.3f}")

#         fig.suptitle("\n".join(title_lines), fontsize=8)
#         plt.tight_layout()
#         plt.show()

#     # add predictions to df
#     df = df.copy()
#     df["prob_pos"] = probs_all
#     df["pred_label"] = preds_all

#     if out_csv is not None:
#         df.to_csv(out_csv, index=False)
#         print("Saved inference results to:", out_csv)

#     return df

import os
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# -------------------------------
# INFERENCE HELPERS
# -------------------------------
@torch.no_grad()
def infer_on_loader(model, loader, threshold=0.5):
    """Run batched inference over ``loader`` and threshold the sigmoid outputs.

    Args:
        model: binary classifier producing one logit per sample; the output is
            flattened with ``view(-1)`` so ``(B,)`` and ``(B, 1)`` both work.
        loader: iterable of ``(images, _)`` batches; the second item is ignored.
        threshold: probability at or above which a sample is labelled 1.

    Returns:
        ``(probs, preds)``: two Python lists with one float per sample, in
        loader order. ``preds`` contains 0.0 / 1.0.
    """
    model.eval()
    probs_out, preds_out = [], []
    for batch_imgs, _unused in tqdm(loader, desc="Infer", leave=False):
        batch_logits = model(batch_imgs.to(device, non_blocking=True)).view(-1)
        batch_probs = torch.sigmoid(batch_logits)
        probs_out += batch_probs.cpu().tolist()
        preds_out += batch_probs.ge(threshold).float().cpu().tolist()
    return probs_out, preds_out


class _CsvInferenceDataset(Dataset):
    """Yields ``(transformed_image, 0.0)`` for each ``image_path`` of a frame.

    Defined at module level rather than inside ``infer_on_csv`` so DataLoader
    worker processes started with ``spawn``/``forkserver`` can pickle it
    (locally defined classes are not picklable).
    """

    def __init__(self, df_, transform_):
        self.paths = df_["image_path"].tolist()
        self.transform = transform_

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        # Context manager closes the underlying file once the pixels are converted.
        with Image.open(self.paths[idx]) as img_file:
            img = img_file.convert("RGB")
        return self.transform(img), 0.0  # dummy label; infer_on_loader ignores it


@torch.no_grad()
def infer_on_csv(model, csv_path, transform,
                 out_csv=None, threshold=0.5, batch_size=16, num_workers=2):
    """Predict every image listed in a CSV and attach the results as columns.

    Args:
        model: binary classifier (see ``infer_on_loader``).
        csv_path: CSV file with an ``image_path`` column.
        transform: callable mapping a PIL RGB image to a tensor.
        out_csv: if given, the augmented frame is also written here.
        threshold: probability at or above which ``pred_label`` is 1.0.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes (default 2, as before).

    Returns:
        The CSV contents plus ``prob_pos`` and ``pred_label`` columns, in
        the original row order.
    """
    df = pd.read_csv(csv_path)

    loader = DataLoader(
        _CsvInferenceDataset(df, transform),
        batch_size=batch_size,
        shuffle=False,  # keep row order so predictions line up with df
        num_workers=num_workers,
        # Pinned memory only helps host->GPU copies; avoid the CPU-only warning.
        pin_memory=torch.cuda.is_available(),
    )

    probs, preds = infer_on_loader(model, loader, threshold=threshold)
    df = df.assign(prob_pos=probs, pred_label=preds)

    if out_csv is not None:
        df.to_csv(out_csv, index=False)
        print("Saved inference results to:", out_csv)

    return df


@torch.no_grad()
def infer_single_image(model, img_path, transform, threshold=0.5):
    """Classify a single image file.

    Returns:
        ``(prob, pred)``: the sigmoid probability of class 1 as a Python float
        and the thresholded label (1.0 when ``prob >= threshold``, else 0.0).
    """
    model.eval()
    batch = transform(Image.open(img_path).convert("RGB")).unsqueeze(0).to(device)
    prob = torch.sigmoid(model(batch).view(-1)[0]).item()
    return prob, float(prob >= threshold)


# -------------------------------
# SALIENCY + CSV VISUALIZATION (no checkerboard/grid artifacts)
# -------------------------------
def run_inference_and_visualize_csv(model,
                                    csv_path,
                                    transform,
                                    threshold=0.5,
                                    out_csv=None,
                                    max_rows=None,
                                    display_size=None):
    """Predict each image in a CSV and plot an input-gradient saliency overlay.

    For each row (up to ``max_rows``) the image at ``image_path`` is passed
    through ``transform`` and ``model``. The gradient of the sigmoid
    probability w.r.t. the input pixels (abs, max over channels, min-max
    normalised) is overlaid on the image.

    Args:
        model: binary classifier returning a single logit for a 1-image batch.
        csv_path: CSV with an ``image_path`` column and an optional ``label``.
        transform: callable mapping a PIL RGB image to a ``(3, H, W)`` tensor.
        threshold: probability at or above which the prediction is 1.0.
        out_csv: optional path to write the processed rows with predictions.
        max_rows: only the first ``max_rows`` rows are processed (all if None).
        display_size: ``(width, height)`` for the displayed image; defaults to
            ``(IMAGE_SIZE, IMAGE_SIZE)``. The heatmap is stretched over the same
            extent, so the overlay stays aligned even if the transform outputs
            a different resolution.

    Returns:
        Copy of the processed rows with ``prob_pos`` and ``pred_label`` added.
    """
    df = pd.read_csv(csv_path)
    print("Num rows in CSV:", len(df))

    if max_rows is None:
        max_rows = len(df)
    if display_size is None:
        display_size = (IMAGE_SIZE, IMAGE_SIZE)
    disp_w, disp_h = display_size
    # imshow's default pixel-centred extent for a (disp_h, disp_w) image; reused
    # for the heatmap so both layers cover exactly the same area.
    heat_extent = (-0.5, disp_w - 0.5, disp_h - 0.5, -0.5)

    probs_all = []
    preds_all = []

    model.eval()

    for _, row in df.iloc[:max_rows].iterrows():
        img_path = row["image_path"]
        true_label = row.get("label", None)

        with Image.open(img_path) as img_file:
            img_pil = img_file.convert("RGB")

        # Fresh leaf tensor each iteration so backward() fills x.grad.
        x = transform(img_pil).unsqueeze(0).to(device)  # (1,3,H,W)
        x.requires_grad_(True)

        model.zero_grad(set_to_none=True)

        logit = model(x).view(-1)[0]          # scalar
        prob = torch.sigmoid(logit)           # scalar in (0,1)
        prob_val = prob.item()
        pred = 1.0 if prob_val >= threshold else 0.0

        prob.backward()

        grad = x.grad.detach().abs().amax(dim=1)[0]   # (H,W) max over channels

        # normalize to [0,1]; a flat map becomes all zeros
        g_min, g_max = grad.min(), grad.max()
        if (g_max - g_min) > 1e-12:
            grad_norm = (grad - g_min) / (g_max - g_min)
        else:
            grad_norm = torch.zeros_like(grad)
        grad_norm = grad_norm.cpu().numpy()   # (H,W)

        probs_all.append(prob_val)
        preds_all.append(pred)

        img_disp = img_pil.resize(display_size, resample=Image.BILINEAR)

        fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True)

        # left: original (resized for display)
        axes[0].imshow(img_disp, interpolation="bilinear")
        axes[0].set_title("Original", fontsize=10)
        axes[0].set_axis_off()

        # right: overlay (smooth interpolation prevents checkerboard)
        axes[1].imshow(img_disp, interpolation="bilinear")
        axes[1].imshow(
            grad_norm,
            cmap="jet",
            alpha=0.45,
            interpolation="bilinear",
            vmin=0.0, vmax=1.0,
            extent=heat_extent,
        )
        axes[1].set_title("Saliency heatmap", fontsize=10)
        axes[1].set_axis_off()

        title_lines = [f"path: {img_path}"]
        if true_label is not None and not pd.isna(true_label):
            title_lines.append(f"true={int(true_label)}, pred={int(pred)}, prob={prob_val:.3f}")
        else:
            title_lines.append(f"pred={int(pred)}, prob={prob_val:.3f}")

        fig.suptitle("\n".join(title_lines), fontsize=8)
        plt.show()
        # Close explicitly so figures do not accumulate in non-inline backends.
        plt.close(fig)

    # Release parameter gradients left over from the last backward pass.
    model.zero_grad(set_to_none=True)

    df_out = df.iloc[:max_rows].copy()
    df_out["prob_pos"] = probs_all
    df_out["pred_label"] = preds_all

    if out_csv is not None:
        df_out.to_csv(out_csv, index=False)
        print("Saved inference results to:", out_csv)

    return df_out

In [None]:
import os
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter
import torch
import torchvision.transforms.functional as TF

def pad_to_square(pil_img, fill=0):
    """Centre ``pil_img`` on a square canvas by padding its shorter side.

    When the padding is odd, the extra pixel goes on the right/bottom edge.

    Returns:
        ``(padded_img, pad_left, pad_top, side)``, where ``side`` is
        ``max(width, height)``.
    """
    width, height = pil_img.size
    side = max(width, height)
    extra_w, extra_h = side - width, side - height
    left, top = extra_w // 2, extra_h // 2
    borders = [left, top, extra_w - left, extra_h - top]  # left, top, right, bottom
    return TF.pad(pil_img, padding=borders, fill=fill), left, top, side

def unpad_map(map_sq, pad_left, pad_top, orig_w, orig_h, side):
    """Crop the padding back off a square saliency map.

    ``map_sq`` is a 2-D numpy array indexed ``[row, col]`` on the padded
    square canvas. The result is the ``orig_h x orig_w`` window whose top-left
    corner is at ``(pad_top, pad_left)``. ``side`` is accepted for symmetry
    with ``pad_to_square``'s return value but is not used.
    """
    rows = slice(pad_top, pad_top + orig_h)
    cols = slice(pad_left, pad_left + orig_w)
    return map_sq[rows, cols]

def run_inference_and_visualize_csv_whole_image(
    model,
    csv_path,
    threshold=0.5,
    out_csv=None,
    max_rows=None,
    blur_sigma=5,
    output_dir=None,
    save_fig=True,
    show_fig=True,
    dpi=150,
    input_size=224,
    mean=(0.5, 0.5, 0.5),
    std=(0.25, 0.25, 0.25),
):
    """Predict each CSV image and plot a saliency map aligned to the full image.

    Each image is padded to a square (so resizing keeps its aspect ratio),
    resized to ``input_size``, and normalised with ``mean``/``std``. It is then
    classified. The gradient of the sigmoid probability w.r.t. the input
    (abs, max over channels) is Gaussian-blurred and min-max normalised. It is
    then mapped back onto the original, unpadded image for display.

    Args:
        model: binary classifier returning a single logit for a 1-image batch.
        csv_path: CSV with an ``image_path`` column and an optional ``label``.
        threshold: probability at or above which the prediction is 1.0.
        out_csv: optional path to write the processed rows with predictions.
        max_rows: only the first ``max_rows`` rows are processed (all if None).
        blur_sigma: Gaussian sigma (in input-resolution pixels) for the map.
        output_dir: directory for saved panels (created if missing).
        save_fig: save each panel when ``output_dir`` is given.
        show_fig: display each panel with ``plt.show()``.
        dpi: resolution for saved panels.
        input_size: square side length fed to the model.
        mean, std: per-channel normalisation. The defaults reproduce the
            previous hard-coded values. NOTE(review): confirm they match the
            normalisation used when the classifier was trained.

    Returns:
        Rows that loaded successfully, with ``prob_pos`` and ``pred_label``.
    """
    df = pd.read_csv(csv_path)
    print("Num rows in CSV:", len(df))

    probs_all, preds_all, kept_indices = [], [], []
    model.eval()

    if max_rows is None:
        max_rows = len(df)

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)
        print("Saving figures to:", output_dir)

    for idx, row in df.iloc[:max_rows].iterrows():
        img_path = row["image_path"]
        true_label = row.get("label", None)

        try:
            with Image.open(img_path) as img_file:
                img_pil = img_file.convert("RGB")
        except Exception as e:
            print(f"[SKIP] Error loading {img_path}: {e}")
            continue

        orig_w, orig_h = img_pil.size

        img_sq, pad_left, pad_top, side = pad_to_square(img_pil, fill=0)
        img_in = img_sq.resize((input_size, input_size), resample=Image.BILINEAR)

        x = TF.to_tensor(img_in)  # (3,H,W) float [0,1]
        x = TF.normalize(x, mean=list(mean), std=list(std))
        x = x.unsqueeze(0).to(device)
        x.requires_grad_(True)

        model.zero_grad(set_to_none=True)

        logit = model(x).view(-1)[0]
        prob = torch.sigmoid(logit)
        prob_val = float(prob.item())
        pred_val = 1.0 if prob_val >= threshold else 0.0

        prob.backward()

        grad = x.grad.detach()                 # (1,3,input_size,input_size)
        grad = grad.abs().amax(dim=1)[0]       # (input_size,input_size)
        grad = grad.cpu().numpy()

        grad = gaussian_filter(grad, sigma=blur_sigma)

        g_min, g_max = grad.min(), grad.max()
        grad_norm = (grad - g_min) / (g_max - g_min + 1e-12)

        kept_indices.append(idx)
        probs_all.append(prob_val)
        preds_all.append(pred_val)

        # Map back to the original image in float precision: upsample once to
        # the padded square, then crop the padding; the crop is already
        # (orig_h, orig_w), so no second resize or uint8 quantisation is needed.
        map_sq = Image.fromarray(grad_norm.astype(np.float32)).resize(
            (side, side), resample=Image.BILINEAR
        )
        map_sq = np.asarray(map_sq, dtype=np.float32)
        map_final = np.clip(
            unpad_map(map_sq, pad_left, pad_top, orig_w, orig_h, side), 0.0, 1.0
        )

        assert map_final.shape == (orig_h, orig_w), f"Mismatch map {map_final.shape} vs img {(orig_h, orig_w)}"

        fig, axes = plt.subplots(1, 3, figsize=(12, 4), constrained_layout=True)

        axes[0].imshow(img_pil)
        axes[0].set_title("Original (full image)", fontsize=10)
        axes[0].axis("off")

        axes[1].imshow(map_final, cmap="jet", vmin=0.0, vmax=1.0)
        axes[1].set_title("Saliency (full img)", fontsize=10)
        axes[1].axis("off")

        axes[2].imshow(img_pil)
        axes[2].imshow(map_final, cmap="jet", alpha=0.45, vmin=0.0, vmax=1.0)
        axes[2].set_title("Overlay (aligned to full image)", fontsize=10)
        axes[2].axis("off")

        title_lines = [f"path: {os.path.basename(img_path)}"]
        if true_label is not None and not pd.isna(true_label):
            title_lines.append(f"true={int(true_label)}, pred={int(pred_val)}, prob={prob_val:.3f}")
        else:
            title_lines.append(f"pred={int(pred_val)}, prob={prob_val:.3f}")
        fig.suptitle("\n".join(title_lines), fontsize=9)

        if save_fig and output_dir is not None:
            base = os.path.splitext(os.path.basename(img_path))[0]
            save_name = f"{idx:05d}_{base}_pred{int(pred_val)}_p{prob_val:.3f}.png"
            fig.savefig(os.path.join(output_dir, save_name), dpi=dpi, bbox_inches="tight")

        if show_fig:
            plt.show()
        plt.close(fig)

    # Release parameter gradients left over from the last backward pass.
    model.zero_grad(set_to_none=True)

    df_out = df.loc[kept_indices].copy()
    df_out["prob_pos"] = probs_all
    df_out["pred_label"] = preds_all

    if out_csv is not None:
        df_out.to_csv(out_csv, index=False)
        print("Saved inference results to:", out_csv)

    return df_out


In [None]:
# `supervised_ckpt_path` is assigned in a later cell; look it up defensively so
# this inspection cell does not raise NameError on Restart Kernel -> Run All.
globals().get("supervised_ckpt_path", "<supervised_ckpt_path not set yet>")

In [None]:
# v02 checkpoint pair: (DINO backbone weights, supervised classifier weights).
dino_backbone_ckpt, supervised_ckpt_path = (
    "/content/drive/MyDrive/FYP/foci_dino_backbone_ALLimgs_v02.pth",
    "/content/drive/MyDrive/FYP/foci_dino_supervised_classifier_v02.pth",
)

In [None]:
# Rebuild the classifier on the DINO backbone and load the fine-tuned weights.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code when loaded; a state
# dict (which is what load_state_dict expects here) loads fine under it.
best_model = FociDINOClassifier(backbone_ckpt=dino_backbone_ckpt).to(device)
best_model.load_state_dict(
    torch.load(supervised_ckpt_path, map_location=device, weights_only=True)
)

df_pred = run_inference_and_visualize_csv_whole_image(
    best_model,
    csv_path="/content/drive/MyDrive/FYP/foci_supervised_test_v02.csv",
    threshold=0.5,
    out_csv="/content/drive/MyDrive/FYP/foci_supervised_test_v02_pred.csv",
    max_rows=None,
    blur_sigma=5,
    output_dir="/content/drive/MyDrive/FYP/saliency_panels_fullimage_v02",
    save_fig=True,
    show_fig=True,
    input_size=224
)


In [None]:
# Available (DINO backbone, supervised classifier) checkpoint pairs. Pick one
# with CKPT_VERSION instead of assigning several paths and overwriting them.
# NOTE: `best_model` is built from these two variables in the cell that calls
# FociDINOClassifier; re-run that cell after changing CKPT_VERSION, otherwise
# later cells keep using the previously loaded weights.
CHECKPOINTS = {
    "v03": (
        "/content/drive/MyDrive/FYP/foci_dino_backbone_ALLimgs_v03.pth",
        "/content/drive/MyDrive/FYP/foci_dino_supervised_classifier_v03.pth",
    ),
    "70_70": (
        "/content/drive/MyDrive/FYP/foci_dino_backbone_70_70.pth",
        "/content/drive/MyDrive/FYP/foci_dino_supervised_classifier.pth",
    ),
}
CKPT_VERSION = "70_70"
dino_backbone_ckpt, supervised_ckpt_path = CHECKPOINTS[CKPT_VERSION]

In [None]:
def _whole_image_transform(img, input_size=224,
                           mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25)):
    """Preprocess a PIL RGB image the way the whole-image saliency loop does.

    Applies the default preprocessing of
    ``run_inference_and_visualize_csv_whole_image``: pad to a square with
    black, bilinear-resize to ``input_size``, convert to a tensor, then
    normalise. NOTE(review): confirm this matches the training-time eval
    transform.
    """
    img_sq, _, _, _ = pad_to_square(img, fill=0)
    img_in = img_sq.resize((input_size, input_size), resample=Image.BILINEAR)
    return TF.normalize(TF.to_tensor(img_in), mean=list(mean), std=list(std))


# infer_single_image requires a `transform`; omitting it raised a TypeError.
prob, pred = infer_single_image(
    best_model,
    "/content/drive/MyDrive/FYP/groundtruth/4Gy_gH2AX_4Hr/cell_0207.png",
    transform=_whole_image_transform,
    threshold=0.5,
)
print("Prob(class=1):", prob, "Predicted label:", pred)
