In [1]:
import contextlib

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from src.datasets.multi_build import build_dataset_from_keys
from src.models.segformer_baseline import load_model
from torch.utils.data import DataLoader, random_split
from tqdm.auto import tqdm

### 1. Dataset & DataLoader

In [None]:
BUILD_KEYS = ["tcr_phase1_build1", "tcr_phase1_build2"]
IMG_SIZE = 512
VAL_FRACTION = 0.1
BATCH_SIZE = 8
SPLIT_SEED = 42
# Pinned host memory only helps host->GPU copies; skip it on CPU-only machines.
PIN_MEMORY = torch.cuda.is_available()

# Build two views of the same samples: an augmented one for training and a
# clean one for validation. Splitting a single augmented dataset would also
# augment the validation subset and bias the validation metrics.
full_ds = build_dataset_from_keys(BUILD_KEYS, size=IMG_SIZE, augment=True)
eval_base_ds = build_dataset_from_keys(BUILD_KEYS, size=IMG_SIZE, augment=False)
# Assumes both builds enumerate samples in the same order (same keys, only the
# augmentation flag differs) — TODO confirm in build_dataset_from_keys.
assert len(eval_base_ds) == len(full_ds), "augmented and clean datasets differ in length"

# Split indices once, reproducibly.
n_val   = int(len(full_ds) * VAL_FRACTION)
n_train = len(full_ds) - n_val
train_ds, val_split = random_split(
    full_ds, [n_train, n_val], generator=torch.Generator().manual_seed(SPLIT_SEED)
)
# Same validation indices, but read from the non-augmented dataset.
val_ds = torch.utils.data.Subset(eval_base_ds, val_split.indices)

# DataLoaders
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=PIN_MEMORY
)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0, pin_memory=PIN_MEMORY
)

print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")


Train batches: 401, Val batches: 45


### 2. Model & Helpers

In [3]:
# Load the SegFormer model and its image processor via the project helper.
processor, model = load_model()
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Optimiser hyperparameters
LR = 3e-4
WEIGHT_DECAY = 1e-2

# Optimiser + mixed-precision grad scaler (CUDA only; None on CPU).
# NOTE: torch.amp.GradScaler takes `device=`; `device_type=` is torch.autocast's
# keyword and raises a TypeError here.
opt     = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
scaler  = torch.amp.GradScaler(device="cuda") if device.startswith("cuda") else None

  return func(*args, **kwargs)
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b0-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([1]) in the model instantiated
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([1, 256, 1, 1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
def train_one_epoch(model, loader, opt, scaler, device, desc="train"):
    """Run one training epoch and return ``(mean_loss, iou)``.

    Args:
        model: module called as ``model(pixel_values=imgs)`` whose output has a
            ``.logits`` tensor (presumably ``(B, 1, h, w)`` — TODO confirm).
        loader: iterable of ``(imgs, masks)`` batches. Masks are binary and
            presumably ``(B, H, W)`` so they match the squeezed logits.
        opt: optimiser, stepped once per batch.
        scaler: ``torch.amp.GradScaler`` for mixed precision, or ``None`` for a
            plain ``backward()`` / ``step()``.
        device: target device (``str`` or ``torch.device``). Autocast is
            enabled only when it is a CUDA device.
        desc: progress-bar label.

    Returns:
        ``(mean_batch_loss, iou)``. ``iou`` is the foreground IoU pooled over
        every pixel of the epoch (summed intersection / summed union), not a
        per-class mean. Both values are ``nan`` if ``loader`` yields no batches.
    """
    model.train()
    inter = union = loss_sum = 0.0
    n = 0

    # str() so torch.device objects work as well as plain strings.
    use_cuda = str(device).startswith("cuda")

    def autocast():
        # Mixed precision only on CUDA; elsewhere a no-op context manager.
        if use_cuda:
            return torch.amp.autocast(device_type="cuda")
        return contextlib.nullcontext()

    for imgs, masks in tqdm(loader, desc=desc, leave=False):
        imgs, masks = imgs.to(device), masks.to(device)

        with autocast():
            out = model(pixel_values=imgs).logits
            # Upsample logits to mask resolution, then drop dim 1 (size 1).
            logits = F.interpolate(
                out, size=masks.shape[-2:], mode="bilinear", align_corners=False
            ).squeeze(1)
            loss = F.binary_cross_entropy_with_logits(logits, masks.float())

        opt.zero_grad(set_to_none=True)
        # Branch on the scaler itself, so CUDA without a scaler still trains.
        if scaler is not None:
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
        else:
            loss.backward()
            opt.step()

        # Metrics don't need the autograd graph.
        preds = torch.sigmoid(logits.detach()) > 0.5
        inter += torch.logical_and(preds, masks).sum().item()
        union += torch.logical_or(preds, masks).sum().item()

        loss_sum += loss.item()
        n += 1

    if n == 0:
        # Empty loader: no meaningful loss/IoU instead of ZeroDivisionError.
        return float("nan"), float("nan")
    return loss_sum / n, inter / (union + 1e-6)

In [5]:
@torch.no_grad()
def eval_one_epoch(model, loader, device, desc="val"):
    """Run one validation epoch (no gradients) and return ``(mean_loss, iou)``.

    Args:
        model: module called as ``model(pixel_values=imgs)`` whose output has a
            ``.logits`` tensor (presumably ``(B, 1, h, w)`` — TODO confirm).
        loader: iterable of ``(imgs, masks)`` batches. Masks are binary and
            presumably ``(B, H, W)`` so they match the squeezed logits.
        device: device the batches are moved to.
        desc: progress-bar label.

    Returns:
        ``(mean_batch_loss, iou)``. ``iou`` is the foreground IoU pooled over
        every pixel of the epoch (summed intersection / summed union), not a
        per-class mean. Both values are ``nan`` if ``loader`` yields no batches.
    """
    model.eval()
    inter = union = loss_sum = 0.0
    n = 0

    for imgs, masks in tqdm(loader, desc=desc, leave=False):
        imgs, masks = imgs.to(device), masks.to(device)
        out = model(pixel_values=imgs).logits
        # Upsample logits to mask resolution, then drop dim 1 (size 1).
        logits = F.interpolate(
            out, size=masks.shape[-2:], mode="bilinear", align_corners=False
        ).squeeze(1)
        loss = F.binary_cross_entropy_with_logits(logits, masks.float())

        preds = torch.sigmoid(logits) > 0.5
        inter += torch.logical_and(preds, masks).sum().item()
        union += torch.logical_or(preds, masks).sum().item()

        loss_sum += loss.item()
        n += 1

    if n == 0:
        # Empty loader (e.g. a tiny dataset gives n_val == 0): avoid ZeroDivisionError.
        return float("nan"), float("nan")
    return loss_sum / n, inter / (union + 1e-6)

### 4. Quick Epoch Run & History

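With the current settings, a full run of `EPOCHS` epochs can take up to about 11 h 05 min. If the run is interrupted (as in the saved output below), `history` only holds the epochs that finished.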

In [6]:
EPOCHS = 5
history = {"train_loss": [], "train_iou": [], "val_loss": [], "val_iou": []}

# One train pass + one validation pass per epoch; metrics are appended to
# `history` (in key order) only after both passes of that epoch finish.
for epoch_idx in range(EPOCHS):
    tag = f"ep{epoch_idx:02d}"
    train_loss, train_iou = train_one_epoch(
        model, train_loader, opt, scaler, device, desc=f"{tag}_train"
    )
    val_loss, val_iou = eval_one_epoch(model, val_loader, device, desc=f"{tag}_val")

    epoch_metrics = (train_loss, train_iou, val_loss, val_iou)
    for key, value in zip(history, epoch_metrics):
        history[key].append(value)

    print(
        f"Epoch {epoch_idx:02d} ▶ train_loss={train_loss:.3f}, train_iou={train_iou:.3f} | "
        f"val_loss={val_loss:.3f}, val_iou={val_iou:.3f}"
    )



ep00_train:   0%|          | 0/401 [00:00<?, ?it/s]

KeyboardInterrupt: 

### 5. Plot Training Curves

In [None]:
# x-axis = epochs actually completed. `history` is shorter than EPOCHS when
# training was interrupted, and plotting against range(EPOCHS) would then fail
# with a length mismatch.
epochs = range(len(history["train_loss"]))
fig, (ax_loss, ax_iou) = plt.subplots(1, 2, figsize=(10, 4))

ax_loss.plot(epochs, history["train_loss"], marker="o", label="train")
ax_loss.plot(epochs, history["val_loss"],   marker="o", label="val")
ax_loss.set(title="BCE-Logits Loss", xlabel="Epoch", ylabel="Loss")
ax_loss.legend()

# The helpers pool intersection/union over all pixels of an epoch, i.e. a
# foreground IoU rather than a per-class mean IoU.
ax_iou.plot(epochs, history["train_iou"], marker="o", label="train")
ax_iou.plot(epochs, history["val_iou"],   marker="o", label="val")
ax_iou.set(title="Foreground IoU", xlabel="Epoch", ylabel="IoU")
ax_iou.legend()

fig.tight_layout()
plt.show()

### 6. Prediction Visualization

In [None]:
# Inference mode: train_one_epoch leaves the model in train() mode, and an
# interrupted run may never reach eval_one_epoch's model.eval().
model.eval()

# grab one batch from validation
imgs, masks = next(iter(val_loader))
imgs, masks = imgs.to(device), masks.to(device)

# forward & threshold
with torch.no_grad():
    out    = model(pixel_values=imgs).logits
    logits = F.interpolate(
        out, size=masks.shape[-2:], mode="bilinear", align_corners=False
    ).squeeze(1)
    preds  = (torch.sigmoid(logits) > 0.5).cpu()

# plot the first (up to) 4 samples: image | ground truth | prediction
N = min(4, imgs.size(0))
# squeeze=False keeps `axes` 2-D even when N == 1.
fig, axes = plt.subplots(N, 3, figsize=(12, 8), squeeze=False)
for i, (ax_img, ax_gt, ax_pr) in enumerate(axes):
    img_np = imgs[i].cpu().permute(1, 2, 0).numpy()
    # Images may be normalised (values outside [0, 1]); min-max rescale for
    # display only so imshow doesn't clip them.
    lo, hi = img_np.min(), img_np.max()
    img_np = (img_np - lo) / max(hi - lo, 1e-8)
    mask_gt = masks[i].cpu().numpy()
    mask_pr = preds[i].numpy()

    ax_img.imshow(img_np);                ax_img.set_title("Image");     ax_img.axis("off")
    ax_gt.imshow(mask_gt, cmap="gray");   ax_gt.set_title("GT Mask");    ax_gt.axis("off")
    ax_pr.imshow(mask_pr, cmap="gray");   ax_pr.set_title("Pred Mask");  ax_pr.axis("off")
fig.tight_layout()
plt.show()