In [None]:
import os
import pandas as pd
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform

# --- Filesystem locations ---
# Metadata table, read into a DataFrame below.
CSV_PATH        = "/home/iambrink/NOH_Thyroid_Cancer_Data/CSV-files/Thyroid_Cancer_TAN&NOH_file.csv"
# Root directory handed to the datasets; image paths are joined onto it.
BASE_IMAGE_PATH = "/home/iambrink/NOH_Thyroid_Cancer_Data/superdata/"

# --- Model / optimisation settings ---
MODEL_NAME  = "hf-hub:paige-ai/Virchow2"
NUM_CLASSES = 2
BATCH_SIZE  = 16
NUM_EPOCHS  = 1000
LR          = 1e-5   # learning rate
WD          = 1e-4   # weight decay
NUM_WORKERS = 8      # DataLoader worker processes

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# --- Read the metadata and carve out a validation split ---
# Rows without a label are dropped before splitting.
df = pd.read_csv(CSV_PATH)
df = df.dropna(subset=["Surgery diagnosis in number"])

# 80/20 split with a fixed seed, stratified on the label column.
# NOTE(review): the split is per CSV row; if several rows can belong to the
# same patient, a grouped split may be needed to avoid leakage — confirm.
train_df, val_df = train_test_split(
    df,
    test_size=0.2,
    stratify=df["Surgery diagnosis in number"],
    random_state=42,
)

# --- Dataset ---
class ThyroidDataset(Dataset):
    """Image/label dataset driven by a metadata DataFrame.

    Every row must provide an ``image_path`` column (a path relative to
    ``base_path``; backslash separators are converted to ``/``) and a
    ``Surgery diagnosis in number`` column that is cast to ``int`` and used
    as the class label.

    Args:
        df: Metadata table. A copy with a fresh 0..n-1 index is stored so
            that positional lookups work after filtering or splitting.
        base_path: Directory the relative image paths are joined onto.
        transform: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, df, base_path, transform=None):
        self.df = df.reset_index(drop=True)
        self.base = base_path
        self.tf = transform

    def __len__(self):
        """Return the number of rows (samples) in the metadata table."""
        return len(self.df)

    def __getitem__(self, i):
        """Load sample ``i``.

        Returns:
            ``(image, label)`` where ``image`` is the (optionally
            transformed) RGB image and ``label`` is a scalar ``torch.long``
            tensor.
        """
        row = self.df.iloc[i]
        img_path = os.path.join(self.base, row["image_path"].replace("\\", "/"))

        # convert() returns a fully loaded copy, so the underlying file can be
        # closed deterministically instead of lingering in DataLoader workers.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        label = int(row["Surgery diagnosis in number"])  # 0 or 1

        # Compare against None: a callable transform may legitimately be
        # falsy (e.g. an empty container) and must still be applied.
        if self.tf is not None:
            img = self.tf(img)
        return img, torch.tensor(label, dtype=torch.long)

# --- Model & freeze backbone ---
# The model is built before the transforms so that the preprocessing config
# can be resolved from the loaded model instead of from its name.
model = timm.create_model(
    MODEL_NAME,
    pretrained=True,
    num_classes=NUM_CLASSES,
    mlp_layer=timm.layers.SwiGLUPacked,
    act_layer=torch.nn.SiLU,
).to(DEVICE)

# Freeze everything, then re-enable gradients for the classifier head only.
for p in model.parameters():
    p.requires_grad = False
for p in model.get_classifier().parameters():
    p.requires_grad = True

# --- Transforms & DataLoaders ---
# resolve_data_config reads the preprocessing settings from the model object;
# a bare name string carries no pretrained config, so passing MODEL_NAME here
# would yield timm's generic defaults rather than this checkpoint's settings.
config          = resolve_data_config({}, model=model)
train_transform = create_transform(**config, is_training=True)
val_transform   = create_transform(**config, is_training=False)

train_ds = ThyroidDataset(train_df, BASE_IMAGE_PATH, train_transform)
val_ds   = ThyroidDataset(val_df,   BASE_IMAGE_PATH, val_transform)

# Pinned host memory only helps host-to-GPU copies, so tie it to the device.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS,
                          pin_memory=(DEVICE.type == "cuda"))
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=NUM_WORKERS,
                          pin_memory=(DEVICE.type == "cuda"))

# --- Loss, optimizer, scheduler, scaler ---
criterion = nn.CrossEntropyLoss()
# Only parameters that still require gradients (the head) are optimised.
optimizer = optim.AdamW(
    [p for p in model.parameters() if p.requires_grad],
    lr=LR, weight_decay=WD
)
# One cosine half-period spread over the whole run (stepped once per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)
# torch.cuda.amp.GradScaler() is deprecated; the device-generic form is used
# and disabled explicitly when no GPU is available.
scaler    = torch.amp.GradScaler("cuda", enabled=(DEVICE.type == "cuda"))
torch.backends.cudnn.benchmark = True

# Best validation accuracy seen so far; a checkpoint is written on improvement.
best_val_acc = 0.0

# Mixed precision is only requested on CUDA; elsewhere autocast is a no-op.
use_amp = DEVICE.type == "cuda"

# --- Training & Validation ---
for epoch in range(1, NUM_EPOCHS + 1):
    # — Train —
    model.train()
    total_loss = 0.0
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch} Train"):
        imgs   = imgs.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)   # [B], LongTensor

        optimizer.zero_grad(set_to_none=True)
        # torch.cuda.amp.autocast() is deprecated in favour of torch.amp.autocast.
        with torch.amp.autocast(device_type=DEVICE.type, enabled=use_amp):
            # The model returns one logit vector per token; they are averaged
            # into a single prediction per image.
            # NOTE(review): the mean runs over every token the model emits,
            # which may include class/register tokens — confirm this is intended.
            patch_logits = model(imgs)                    # [B, num_tokens, 2]
            logits       = patch_logits.mean(dim=1)       # [B, 2]
            loss         = criterion(logits, labels)      # labels [B]

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight by batch size so the epoch average is per-sample.
        total_loss += loss.item() * imgs.size(0)

    # The learning-rate schedule advances once per epoch.
    scheduler.step()
    avg_train_loss = total_loss / len(train_ds)

    # — Validate —
    model.eval()
    val_loss = 0.0
    correct  = 0
    total    = 0
    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc=f"Epoch {epoch} Val"):
            imgs   = imgs.to(DEVICE, non_blocking=True)
            labels = labels.to(DEVICE, non_blocking=True)

            with torch.amp.autocast(device_type=DEVICE.type, enabled=use_amp):
                patch_logits = model(imgs)
                logits       = patch_logits.mean(dim=1)
                loss         = criterion(logits, labels)

            val_loss += loss.item() * imgs.size(0)
            preds    = logits.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total   += labels.size(0)

    avg_val_loss = val_loss / len(val_ds)
    val_acc      = correct / total

    print(
        f"Epoch {epoch:2d} | "
        f"Train Loss: {avg_train_loss:.4f} | "
        f"Val Loss:   {avg_val_loss:.4f} | "
        f"Val Acc:    {val_acc:.4f}"
    )

    # Keep only the weights of the best-scoring epoch (strict improvement).
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_virchow2.pth")
        print(f"→ Saved new best model (Acc: {best_val_acc:.4f})")


  scaler    = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
Epoch 1 Train: 100%|██████████| 163/163 [00:23<00:00,  6.94it/s]
  with torch.cuda.amp.autocast():
Epoch 1 Val: 100%|██████████| 41/41 [00:06<00:00,  6.82it/s]


Epoch  1 | Train Loss: 0.7812 | Val Loss:   0.7179 | Val Acc:    0.5031
→ Saved new best model (Acc: 0.5031)


Epoch 2 Train: 100%|██████████| 163/163 [00:23<00:00,  6.86it/s]
Epoch 2 Val: 100%|██████████| 41/41 [00:06<00:00,  6.83it/s]


Epoch  2 | Train Loss: 0.6873 | Val Loss:   0.6677 | Val Acc:    0.6092
→ Saved new best model (Acc: 0.6092)


Epoch 3 Train: 100%|██████████| 163/163 [00:24<00:00,  6.78it/s]
Epoch 3 Val: 100%|██████████| 41/41 [00:06<00:00,  6.82it/s]


Epoch  3 | Train Loss: 0.6539 | Val Loss:   0.6396 | Val Acc:    0.7077
→ Saved new best model (Acc: 0.7077)


Epoch 4 Train: 100%|██████████| 163/163 [00:23<00:00,  6.86it/s]
Epoch 4 Val: 100%|██████████| 41/41 [00:06<00:00,  6.82it/s]


Epoch  4 | Train Loss: 0.6301 | Val Loss:   0.6171 | Val Acc:    0.7600
→ Saved new best model (Acc: 0.7600)


Epoch 5 Train: 100%|██████████| 163/163 [00:23<00:00,  6.86it/s]
Epoch 5 Val: 100%|██████████| 41/41 [00:05<00:00,  6.84it/s]


Epoch  5 | Train Loss: 0.6044 | Val Loss:   0.5972 | Val Acc:    0.7677
→ Saved new best model (Acc: 0.7677)


Epoch 6 Train: 100%|██████████| 163/163 [00:23<00:00,  6.87it/s]
Epoch 6 Val: 100%|██████████| 41/41 [00:05<00:00,  6.83it/s]


Epoch  6 | Train Loss: 0.5911 | Val Loss:   0.5814 | Val Acc:    0.7846
→ Saved new best model (Acc: 0.7846)


Epoch 7 Train: 100%|██████████| 163/163 [00:23<00:00,  6.86it/s]
Epoch 7 Val: 100%|██████████| 41/41 [00:06<00:00,  6.82it/s]


Epoch  7 | Train Loss: 0.5760 | Val Loss:   0.5674 | Val Acc:    0.7877
→ Saved new best model (Acc: 0.7877)


Epoch 8 Train: 100%|██████████| 163/163 [00:23<00:00,  6.84it/s]
Epoch 8 Val: 100%|██████████| 41/41 [00:06<00:00,  6.81it/s]


Epoch  8 | Train Loss: 0.5582 | Val Loss:   0.5543 | Val Acc:    0.7877


Epoch 9 Train: 100%|██████████| 163/163 [00:23<00:00,  6.92it/s]
Epoch 9 Val: 100%|██████████| 41/41 [00:06<00:00,  6.82it/s]


Epoch  9 | Train Loss: 0.5474 | Val Loss:   0.5442 | Val Acc:    0.7877


Epoch 10 Train: 100%|██████████| 163/163 [00:23<00:00,  6.92it/s]
Epoch 10 Val: 100%|██████████| 41/41 [00:05<00:00,  6.85it/s]


Epoch 10 | Train Loss: 0.5342 | Val Loss:   0.5356 | Val Acc:    0.7923
→ Saved new best model (Acc: 0.7923)


Epoch 11 Train: 100%|██████████| 163/163 [00:23<00:00,  6.87it/s]
Epoch 11 Val: 100%|██████████| 41/41 [00:06<00:00,  6.82it/s]


Epoch 11 | Train Loss: 0.5270 | Val Loss:   0.5271 | Val Acc:    0.7969
→ Saved new best model (Acc: 0.7969)


Epoch 12 Train: 100%|██████████| 163/163 [00:23<00:00,  6.87it/s]
Epoch 12 Val: 100%|██████████| 41/41 [00:06<00:00,  6.82it/s]


Epoch 12 | Train Loss: 0.5240 | Val Loss:   0.5206 | Val Acc:    0.8031
→ Saved new best model (Acc: 0.8031)


Epoch 13 Train: 100%|██████████| 163/163 [00:23<00:00,  6.88it/s]
Epoch 13 Val: 100%|██████████| 41/41 [00:05<00:00,  6.84it/s]


Epoch 13 | Train Loss: 0.5099 | Val Loss:   0.5139 | Val Acc:    0.8062
→ Saved new best model (Acc: 0.8062)


Epoch 14 Train: 100%|██████████| 163/163 [00:23<00:00,  6.89it/s]
Epoch 14 Val: 100%|██████████| 41/41 [00:05<00:00,  6.83it/s]


Epoch 14 | Train Loss: 0.5065 | Val Loss:   0.5085 | Val Acc:    0.8092
→ Saved new best model (Acc: 0.8092)


Epoch 15 Train: 100%|██████████| 163/163 [00:23<00:00,  6.87it/s]
Epoch 15 Val: 100%|██████████| 41/41 [00:05<00:00,  6.83it/s]


Epoch 15 | Train Loss: 0.4977 | Val Loss:   0.5042 | Val Acc:    0.8138
→ Saved new best model (Acc: 0.8138)


Epoch 16 Train: 100%|██████████| 163/163 [00:23<00:00,  6.87it/s]
Epoch 16 Val: 100%|██████████| 41/41 [00:06<00:00,  6.82it/s]


Epoch 16 | Train Loss: 0.4933 | Val Loss:   0.4993 | Val Acc:    0.8154
→ Saved new best model (Acc: 0.8154)


Epoch 17 Train: 100%|██████████| 163/163 [00:23<00:00,  6.87it/s]
Epoch 17 Val: 100%|██████████| 41/41 [00:06<00:00,  6.82it/s]


Epoch 17 | Train Loss: 0.4886 | Val Loss:   0.4951 | Val Acc:    0.8154


Epoch 18 Train: 100%|██████████| 163/163 [00:23<00:00,  6.91it/s]
Epoch 18 Val: 100%|██████████| 41/41 [00:05<00:00,  6.84it/s]


Epoch 18 | Train Loss: 0.4785 | Val Loss:   0.4917 | Val Acc:    0.8169
→ Saved new best model (Acc: 0.8169)


Epoch 19 Train: 100%|██████████| 163/163 [00:23<00:00,  6.87it/s]
Epoch 19 Val: 100%|██████████| 41/41 [00:06<00:00,  6.81it/s]


Epoch 19 | Train Loss: 0.4792 | Val Loss:   0.4888 | Val Acc:    0.8154


Epoch 20 Train: 100%|██████████| 163/163 [00:23<00:00,  6.92it/s]
Epoch 20 Val: 100%|██████████| 41/41 [00:06<00:00,  6.82it/s]


Epoch 20 | Train Loss: 0.4748 | Val Loss:   0.4863 | Val Acc:    0.8200
→ Saved new best model (Acc: 0.8200)


Epoch 21 Train:  62%|██████▏   | 101/163 [00:14<00:08,  6.92it/s]