In [2]:
import contextlib

import fiftyone as fo
import fiftyone.utils.annotations as foa
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.datasets.multi_build import build_dataset_from_keys
from src.models.segformer_baseline import load_model
from torch.utils.data import DataLoader, random_split
from tqdm.auto import tqdm


In [3]:
# Build identifiers passed to build_dataset_from_keys below
BUILD_KEYS = [
    "tcr_phase1_build1",
    "tcr_phase1_build2",
]

### 1. Dataset & DataLoader

Only 10 layers are used for now: every 10th layer from 0 to 90 (`layers=range(0, 100, 10)`).

In [4]:
# Build & split.
# Two views of the same data: the augmented one feeds training, while the
# validation split must see un-augmented samples. Otherwise val metrics are
# measured on randomly transformed images (augment=True used to leak into val_ds).
LAYERS = range(0, 100, 10)   # every 10th layer -> 10 layers in total
VAL_FRACTION = 0.1
SPLIT_SEED = 42

full_ds = build_dataset_from_keys(
    BUILD_KEYS, size=512, augment=True, layers=LAYERS
)
full_ds_eval = build_dataset_from_keys(
    BUILD_KEYS, size=512, augment=False, layers=LAYERS
)
# NOTE(review): assumes both calls enumerate samples in the same order for
# identical arguments, so indices are interchangeable between them — TODO confirm.
assert len(full_ds) == len(full_ds_eval), "augmented / eval datasets differ in size"

# Keep at least one validation sample so evaluation never runs on an empty loader
n_val   = max(1, int(len(full_ds) * VAL_FRACTION))
n_train = len(full_ds) - n_val
train_ds, val_split = random_split(
    full_ds,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
# Same validation indices as the split above, but drawn from the un-augmented dataset
val_ds = torch.utils.data.Subset(full_ds_eval, val_split.indices)

print("dataset sizes:", len(train_ds), len(val_ds))

dataset sizes: 18 2


In [5]:
# Device: prefer CUDA, then Apple-silicon MPS, otherwise fall back to CPU
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print("Device:", device)

Device: cpu


In [6]:
# DataLoaders: train and val share every setting except shuffling.
# pin_memory only speeds up host -> CUDA copies, so enable it on CUDA only.
# On CPU/MPS it gives no benefit and PyTorch emits a warning.
loader_kwargs = dict(
    batch_size=8,
    num_workers=4,
    pin_memory=(device == "cuda"),
    prefetch_factor=4,           # batches pre-loaded per worker
    persistent_workers=True,     # keep workers alive between epochs
)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

In [7]:
# Alternative (disabled) loader config: batch size and worker count chosen per
# device, pin_memory only on CUDA. Nothing here executes — the active loaders
# are built in the previous cell. TODO: delete once loader settings are final.
# BATCH   = 8 if device!="cpu" else 4
# WORKERS = 4 if device=="cpu" else 8
# train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True,
#                           num_workers=WORKERS, pin_memory=(device=="cuda"))
# val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False,
#                           num_workers=WORKERS, pin_memory=(device=="cuda"))

In [8]:

# Sanity check: number of batches per epoch for each split
n_train_batches, n_val_batches = len(train_loader), len(val_loader)
print(f"Train batches: {n_train_batches}, Val batches: {n_val_batches}")

Train batches: 3, Val batches: 1


### 2. Model & Helpers

In [9]:
# Load the SegFormer model and replace its classifier with a head sized for our classes.
CLASS_NAMES = ["streak", "spatter"]
NUM_CLASSES = len(CLASS_NAMES)
# NOTE(review): no background class is declared here, yet the visualisation cell
# draws class 0 as transparent/background — confirm the mask encoding.

processor, model = load_model()
in_ch = model.decode_head.classifier.in_channels
model.decode_head.classifier = nn.Conv2d(in_ch, NUM_CLASSES, kernel_size=1)
model.config.num_labels = NUM_CLASSES
model.config.id2label = dict(enumerate(CLASS_NAMES))
model.config.label2id = {name: idx for idx, name in enumerate(CLASS_NAMES)}
model.to(device)

# Optimiser + AMP gradient scaler (mixed precision is only used on CUDA).
# torch.amp.GradScaler takes the device as its first argument (`device=`);
# the previous `device_type=` keyword raised TypeError on CUDA machines.
opt    = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
scaler = torch.amp.GradScaler("cuda") if device == "cuda" else None

  return func(*args, **kwargs)
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b0-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([1]) in the model instantiated
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([1, 256, 1, 1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
def train_one_epoch(model, loader, opt, scaler, device, desc="train", num_classes=2):
    """Train ``model`` for one pass over ``loader``.

    Args:
        model: segmentation model called as ``model(pixel_values=imgs)``; its
            output must expose ``.logits`` of shape [B, C, h, w].
        loader: iterable of ``(imgs, masks)`` batches. ``masks`` holds a class id
            per pixel; any numeric dtype is accepted (it is cast to ``long``).
        opt: optimiser over ``model``'s parameters.
        scaler: ``torch.amp.GradScaler`` enabling mixed precision on CUDA, or
            ``None`` to train in full precision.
        device: target device, as a string or ``torch.device``.
        desc: tqdm progress-bar label.
        num_classes: classes ``0 .. num_classes-1`` included in the mIoU.

    Returns:
        ``(mean_loss, miou)``: mean per-batch loss (``nan`` if ``loader`` is
        empty) and the per-class IoU averaged over ``num_classes`` classes.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    inter = [0] * num_classes
    union = [0] * num_classes

    # Mixed precision only on CUDA and only with a GradScaler: fp16 autocast
    # without loss scaling risks gradient underflow (and scaler.scale would crash).
    use_amp = torch.device(device).type == "cuda" and scaler is not None

    for imgs, masks in tqdm(loader, desc=desc, leave=False):
        imgs = imgs.to(device)
        # F.cross_entropy needs integer (long) class-index targets
        masks = masks.to(device).long()

        amp_ctx = (
            torch.amp.autocast(device_type="cuda") if use_amp else contextlib.nullcontext()
        )
        with amp_ctx:
            logits = model(pixel_values=imgs).logits  # [B, C, h, w]
            # Logits may be at a lower resolution than the masks; upsample to match
            logits = F.interpolate(
                logits, size=masks.shape[-2:], mode="bilinear", align_corners=False
            )
            loss = F.cross_entropy(logits, masks)

        opt.zero_grad(set_to_none=True)
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()
        else:
            loss.backward()
            opt.step()

        # Accumulate per-class intersection/union over the whole epoch
        with torch.no_grad():
            preds = logits.argmax(dim=1)
            for cls in range(num_classes):
                pred_c = preds == cls
                true_c = masks == cls
                inter[cls] += int((pred_c & true_c).sum())
                union[cls] += int((pred_c | true_c).sum())

        total_loss += loss.item()
        n_batches += 1

    mean_loss = total_loss / n_batches if n_batches else float("nan")
    # Epsilon avoids 0/0 for a class that never appears (its IoU then counts as 0)
    miou = sum(i / (u + 1e-6) for i, u in zip(inter, union)) / num_classes
    return mean_loss, miou

In [11]:
@torch.no_grad()
def eval_one_epoch(model, loader, device, desc="val"):
    model.eval()
    total_loss=0.0; inter=[0,0]; union=[0,0]; n_batches=0
    for imgs, masks in tqdm(loader, desc=desc, leave=False):
        imgs, masks = imgs.to(device), masks.to(device)
        out = model(pixel_values=imgs).logits
        out = F.interpolate(out, 
                            size=masks.shape[-2:], 
                            mode="bilinear", 
                            align_corners=False
                            )
        loss=F.cross_entropy(out, masks)

        preds = out.argmax(dim=1)
        for cls in (0,1):
            inter[cls] += int(((preds==cls)&(masks==cls)).sum())
            union[cls] += int(((preds==cls)|(masks==cls)).sum())

        total_loss += loss.item()
        n_batches +=1

    miou = sum(inter[i]/(union[i]+1e-6) for i in (0,1))/2
    return total_loss/n_batches, miou

### 4. Quick Epoch Run & History

In [12]:
# Short training run; per-epoch metrics are collected in `hist` for plotting.
# Keys: tl/ti = train loss/mIoU, vl/vi = val loss/mIoU.
EPOCHS = 5
hist = {"tl": [], "ti": [], "vl": [], "vi": []}

for ep in range(EPOCHS):
    train_metrics = train_one_epoch(
        model, train_loader, opt, scaler, device, desc=f"ep{ep}_tr"
    )
    val_metrics = eval_one_epoch(model, val_loader, device, desc=f"ep{ep}_vl")
    tl, ti = train_metrics
    vl, vi = val_metrics
    # hist keys iterate in insertion order: tl, ti, vl, vi
    for key, value in zip(hist, (tl, ti, vl, vi)):
        hist[key].append(value)
    print(
        f"Epoch {ep:02d} ▶ train_loss={tl:.3f}, train_iou={ti:.3f} | "
        f"val_loss={vl:.3f}, val_iou={vi:.3f}"
    )



ep0_tr:   0%|          | 0/3 [00:00<?, ?it/s]

RuntimeError: expected scalar type Long but found Float

### 5. Plot Training Curves

In [None]:
# Loss and mIoU per epoch, train vs. val (metrics recorded in `hist`)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
panels = (
    (("tl", "vl"), "loss"),
    (("ti", "vi"), "mIoU"),
)
for ax, ((train_key, val_key), metric) in zip(axes, panels):
    ax.plot(hist[train_key], '-o', label="train")
    ax.plot(hist[val_key], '-o', label="val")
    ax.set(title=metric, xlabel="epoch", ylabel=metric)
    ax.legend()
plt.tight_layout()

### 6. Prediction Visualization

In [None]:
# Visualise predictions on the first validation batch:
# one row per sample -> image | ground-truth overlay | prediction overlay.
model.eval()  # the cell may be run on its own, after the model was left in train mode
imgs, masks = next(iter(val_loader))
imgs = imgs.to(device)  # masks stay on the host; only used for plotting

with torch.no_grad():
    logits = model(pixel_values=imgs).logits
    logits = F.interpolate(logits, size=masks.shape[-2:], mode="bilinear", align_corners=False)
    preds = logits.argmax(dim=1).cpu().numpy()

# Class 0 is drawn transparent (treated as background, as before); every other
# class the model predicts gets a semi-transparent colour. Deriving the classes
# from the logits keeps this in sync with the model head (2 classes at present).
PALETTE = [(0, 0, 1, 0.4), (1, 0, 0, 0.4), (0, 1, 0, 0.4), (1, 1, 0, 0.4)]
fg_classes = range(1, logits.shape[1])

# The validation batch can hold fewer than 4 samples (the val split has only 2)
N = min(4, imgs.shape[0])
plt.figure(figsize=(12, 2.25 * N))
for i in range(N):
    im = imgs[i].cpu().permute(1, 2, 0).numpy()
    # Inputs may be normalised outside [0, 1] (TODO confirm the dataset transform);
    # rescale for display so imshow does not clip them.
    im = (im - im.min()) / (im.max() - im.min() + 1e-8)
    gt = masks[i].cpu().numpy()
    pr = preds[i]

    overlay_gt = np.zeros((gt.shape[0], gt.shape[1], 4))
    overlay_pr = np.zeros_like(overlay_gt)
    for cls in fg_classes:
        colour = PALETTE[(cls - 1) % len(PALETTE)]
        overlay_gt[gt == cls] = colour
        overlay_pr[pr == cls] = colour

    plt.subplot(N, 3, 3 * i + 1); plt.imshow(im)
    plt.title("Image"); plt.axis("off")
    plt.subplot(N, 3, 3 * i + 2); plt.imshow(im)
    plt.imshow(overlay_gt); plt.title("GT"); plt.axis("off")
    plt.subplot(N, 3, 3 * i + 3); plt.imshow(im)
    plt.imshow(overlay_pr); plt.title("Pred"); plt.axis("off")

plt.tight_layout()