In [1]:
import contextlib
import os

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.datasets.multi_build import build_dataset_from_keys
from src.models.segformer_baseline import load_model
from torch.utils.data import DataLoader, random_split
from transformers import Trainer, TrainingArguments


In [2]:
BUILD_KEYS = ["tcr_phase1_build1", "tcr_phase1_build2"]

### 1. Dataset & DataLoader

Just 10 layers for now

In [None]:
# Build & split
full_ds = build_dataset_from_keys(
    # BUILD_KEYS, size=256, augment=True, layers=range(0, 100, 10)
    BUILD_KEYS, size=256, augment=True, layers=None
)


val_split = 0.1
n_val   = int(len(full_ds) * val_split)
n_train = len(full_ds) - n_val

g = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(full_ds, [n_train, n_val], generator=g)


In [4]:
# sanity 
print(len(train_ds), len(val_ds))
sample_img, sample_mask = train_ds[0]
print(sample_img.shape, sample_mask.shape)


6423 713


torch.Size([3, 256, 256]) torch.Size([256, 256])


In [5]:
#  Device
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Device:", device)

Device: cuda


### 2. Model & Helpers

In [6]:
# Load ViT‐SegFormer
processor, model = load_model()

in_ch = model.decode_head.classifier.in_channels
model.decode_head.classifier = nn.Conv2d(in_ch, 3, kernel_size=1)

model.config.num_labels = 3
model.config.id2label  = {0: "background", 1: "streak", 2: "spatter"}
model.config.label2id  = {v: k for k, v in model.config.id2label.items()}

model.to(device)

  return func(*args, **kwargs)
Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b0-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([1]) in the model instantiated
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([1, 256, 1, 1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


SegformerForSemanticSegmentation(
  (segformer): SegformerModel(
    (encoder): SegformerEncoder(
      (patch_embeddings): ModuleList(
        (0): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(3, 32, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
          (layer_norm): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
        )
        (1): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (2): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(64, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((160,), eps=1e-05, elementwise_affine=True)
        )
        (3): SegformerOverlapPatchEmbeddings(
          (proj): Conv2d(160, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (layer_norm): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  

### 4. Quick Epoch Run & History

In [7]:
torch.backends.cuda.matmul.allow_tf32 = True

with contextlib.suppress(ValueError):
    model.gradient_checkpointing_enable() # saves VRAM

model = torch.compile(model)                # PyTorch 2.x dynamic compile

In [8]:
training_args = TrainingArguments(
    output_dir="../runs/segformer",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=30,
    fp16=True,
    optim="adamw_torch_fused",
    # tf32=True,  # Disabled: causes ValueError if not supported by hardware
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    gradient_accumulation_steps=2,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
    dataloader_num_workers=os.cpu_count(),
    dataloader_pin_memory=True,
    ddp_find_unused_parameters=False,
    report_to=["mlflow"],          # comment out if MLflow not set up
)

In [9]:

def collate_fn(batch):
    imgs  = torch.stack([b[0] for b in batch])
    masks = torch.stack([b[1] for b in batch])
    return {"pixel_values": imgs, "labels": masks}

In [10]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
)

In [None]:
trainer.train()

In [None]:
# Tiny dataloader for visualization
viz_loader = DataLoader(
    val_ds,
    batch_size=4,
    shuffle=False,
    num_workers=min(2, os.cpu_count()),
    pin_memory=torch.cuda.is_available(),
    collate_fn=collate_fn,
)

imgs, masks = next(iter(viz_loader))
imgs, masks = imgs.to(model.device), masks.to(model.device)

### 5. Plot Training Curves

In [None]:
df = pd.DataFrame(trainer.state.log_history).dropna(subset=["epoch"])

plt.figure(figsize=(7, 4))
plt.subplot(1, 2, 1)
plt.plot(
	df["epoch"], 
	df["bce"], 
	label="BCE (val)" if "eval" in df.columns else "BCE"
	)

plt.title("Loss")
plt.xlabel("epoch")
plt.legend()

plt.subplot(1, 2, 2)
plt.plot(
	df["epoch"], 
	df["mean_iou"], 
	label="mIoU"
	)
plt.title("Mean IoU")
plt.xlabel("epoch")
plt.legend()
plt.tight_layout()

### 6. Prediction Visualization

In [None]:
model.eval()                         # IMPORTANT: disable dropout etc.

with torch.no_grad():
    logits = model(pixel_values=imgs).logits         # [B, C, H, W]
    preds  = logits.argmax(dim=1).cpu().numpy()      # class map per pixel

N = min(imgs.shape[0], 4)            # show at most 4 samples
fig, axes = plt.subplots(N, 3, figsize=(12, 3 * N))

rgba = {0: (0, 0, 0, 0.0),           # transparent background
        1: (0, 0, 1, 0.4),           # streak is blue
        2: (1, 0, 0, 0.4)}           # spatter is red

for i in range(N):
    ax_img, ax_gt, ax_pr = axes[i]

    im = imgs[i].cpu().permute(1, 2, 0).numpy()
    gt = masks[i].cpu().numpy()
    pr = preds[i]

    overlay_gt = np.zeros((*gt.shape, 4))
    overlay_pr = np.zeros_like(overlay_gt)
    for cls in (1, 2):
        overlay_gt[gt == cls] = rgba[cls]
        overlay_pr[pr == cls] = rgba[cls]

    # image
    ax_img.imshow(im)     
    ax_img.set_title("Image")             
    ax_img.axis("off")
    # ground–truth
    ax_gt.imshow(im)
    ax_gt.imshow(overlay_gt)
    ax_gt.set_title("Ground truth")
    ax_gt.axis("off")
    # prediction
    ax_pr.imshow(im)
    ax_pr.imshow(overlay_pr)
    ax_pr.set_title("Prediction")
    ax_pr.axis("off")

fig.legend(
    handles=[mpatches.Patch(color='blue', alpha=0.4, label='Streak'),
             mpatches.Patch(color='red',  alpha=0.4, label='Spatter')],
    loc="upper center", ncol=2, frameon=False, bbox_to_anchor=(0.5, 1.05)
)
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
