# Course/TP — Transformer-based Segmentation for Astrophysics
Generated on 2025-10-28T18:20:06.571472Z

> A complete, hands-on notebook that teaches Transformer-based image segmentation on astro-like data. Includes theory, architecture choices, code, metrics, and visual diagnostics. No internet needed.


## Learning Outcomes

By the end of this TP you will be able to:

1. Generate and inspect an **astro-like synthetic segmentation dataset** (galaxies on noisy backgrounds).
2. Implement a **Transformer-based segmentation model** (lite SegFormer-style encoder + simple decoder).
3. Understand **patch/overlap embeddings**, **self-attention**, **positional encodings**, and **token mixing** in a segmentation setting.
4. Train with **Dice + BCE** loss, monitor **IoU** and **Dice**, and visualize qualitative predictions.
5. Tune key hyperparameters (patch size, embedding dimension, number of heads/layers, learning rate, augmentations).



## Context

Astronomical surveys often require **pixel-wise segmentation** to separate sources from background, delineate galaxy disks, bulges, arms, or to flag artifacts. Transformers have become competitive for dense prediction tasks through encoder-decoder designs such as **SegFormer** and **Mask2Former**. This notebook builds a **minimal working prototype**, optimized for clarity over maximum performance.

If you later want to plug in **real data**, you can adapt the dataset class to read FITS or PNG images and load masks prepared by your pipeline (e.g., SExtractor, human annotations).


## 1) Environment Check & Imports

In [None]:

import os, math, random, time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# Reproducibility
SEED = 1234
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)



## 2) Synthetic Astro-like Dataset

We create simple grayscale scenes (H×W) with:
- Background noise (Gaussian + Poisson-like mix)
- A few **galaxy-like blobs** (elliptical Gaussian profiles)
- Optional **bar/arm hints** with simple parametric curves
- Final **binary mask**: 1 for galaxy pixels, 0 for background

This is intentionally simple but surprisingly effective for practicing segmentation.


In [None]:

H, W = 128, 128           # image size
N_TRAIN, N_VAL, N_TEST = 800, 120, 120
BATCH_SIZE = 16

def draw_elliptical_gaussian(img, msk, cx, cy, a, b, theta, peak):
    # Elliptical Gaussian blob centered at (cx,cy) with axes a,b and angle theta
    yy, xx = np.mgrid[0:H, 0:W]
    x0 = xx - cx
    y0 = yy - cy
    ct, st = np.cos(theta), np.sin(theta)
    xr =  ct * x0 + st * y0
    yr = -st * x0 + ct * y0
    g = peak * np.exp(-0.5 * ((xr / (a+1e-6))**2 + (yr / (b+1e-6))**2))
    img += g
    msk[:] = np.maximum(msk, (g > (0.1*peak)).astype(np.float32))

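def add_bar(img, cx, cy, length, width, theta, amp):
    # Bar along angle theta: Gaussian cross-section (sigma = width) with a hard cut at
    # |xr| < length, so `length` is the half-length. Added to the image only, not the mask.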
    yy, xx = np.mgrid[0:H, 0:W]
    x0 = xx - cx
    y0 = yy - cy
    ct, st = np.cos(theta), np.sin(theta)
    xr =  ct * x0 + st * y0
    yr = -st * x0 + ct * y0
    bar = np.exp(-0.5*((yr/width)**2)) * (np.abs(xr) < length).astype(np.float32)
    img += amp * bar

def add_spiral_hint(img, cx, cy, turns, amp):
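    # Very rough spiral hint: thresholded three-armed log-spiral pattern.
    # It has no radial falloff, so it spans the whole frame, and it is not added to the mask.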
    yy, xx = np.mgrid[0:H, 0:W]
    r = np.sqrt((xx-cx)**2 + (yy-cy)**2) + 1e-6
    ang = np.arctan2(yy-cy, xx-cx)
    # create ridges in angle-radius space
    k = turns * 0.5
    ridge = np.sin(k * np.log(r+1.0) + 3*ang)
    ridge = (ridge > 0.9).astype(np.float32)
    img += amp * ridge

def make_sample():
    img = np.zeros((H, W), dtype=np.float32)
    msk = np.zeros((H, W), dtype=np.float32)
    # background
    img += 0.02*np.random.randn(H, W).astype(np.float32)
    img += np.random.poisson(1.0, size=(H, W)).astype(np.float32) * 0.001

    # number of galaxies
    n_obj = np.random.randint(1, 4)
    for _ in range(n_obj):
        cx = np.random.uniform(20, W-20)
        cy = np.random.uniform(20, H-20)
        a  = np.random.uniform(4, 12)
        b  = np.random.uniform(3, 10)
        theta = np.random.uniform(0, np.pi)
        peak = np.random.uniform(0.7, 1.5)
        draw_elliptical_gaussian(img, msk, cx, cy, a, b, theta, peak)

        if np.random.rand() < 0.4:
            add_bar(img, cx, cy, length=np.random.uniform(6, 14),
                    width=np.random.uniform(1.5, 3.5), theta=theta, amp=0.2)

        if np.random.rand() < 0.3:
            add_spiral_hint(img, cx, cy, turns=np.random.uniform(1.0, 2.5), amp=0.1)

    # normalization to [0,1]
    img -= img.min()
    img /= (img.max() + 1e-6)
    return img.astype(np.float32), msk.astype(np.float32)

# Quick sanity check sample
img, msk = make_sample()
print("Sample pixel stats:", img.min(), img.max(), msk.mean())
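
To get a feel for how much of each image ends up labelled as galaxy, the 10% threshold in `draw_elliptical_gaussian` can be solved in closed form: `g > 0.1*peak` holds where `(xr/a)**2 + (yr/b)**2 < 2*ln(10)`, so each mask is an ellipse whose semi-axes are about 2.15 times `a` and `b`. The sketch below is only an estimate under simplifying assumptions (no clipping at the image border, no overlap between galaxies), and the helper names are hypothetical.

```python
import math

MASK_LEVEL = 0.1  # same relative threshold as in draw_elliptical_gaussian

def mask_semi_axes(a, b, level=MASK_LEVEL):
    """Semi-axes of the region where the Gaussian exceeds level * peak."""
    scale = math.sqrt(-2.0 * math.log(level))  # sqrt(2 ln 10) for level = 0.1
    return a * scale, b * scale

def mask_fraction(a, b, height=128, width=128, level=MASK_LEVEL):
    """Approximate fraction of the image covered by one galaxy mask."""
    ma, mb = mask_semi_axes(a, b, level)
    return math.pi * ma * mb / (height * width)

smallest = mask_fraction(4, 3)    # lower end of the sampled (a, b) ranges
largest = mask_fraction(12, 10)   # upper end of the sampled (a, b) ranges
print(f"foreground per galaxy: {smallest:.1%} to {largest:.1%}")
```

With one to three galaxies per scene, the foreground therefore stays a small minority of the pixels, which is the class imbalance the Dice term in the loss is meant to handle.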


In [None]:

# Dataset & DataLoader
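class ToyAstroSegDataset(Dataset):
    def __init__(self, n_samples):
        # Generate every sample once so that dataset[i] always returns the same
        # (image, mask) pair. This keeps validation/test metrics comparable across
        # epochs (about 128 KB per sample at 128x128 float32).
        self.samples = [make_sample() for _ in range(n_samples)]
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        img, msk = self.samples[idx]
        img = img[None, ...]   # add channel: (1, H, W)
        return torch.from_numpy(img), torch.from_numpy(msk)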

train_ds = ToyAstroSegDataset(N_TRAIN)
val_ds   = ToyAstroSegDataset(N_VAL)
test_ds  = ToyAstroSegDataset(N_TEST)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  num_workers=0)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

len(train_ds), len(val_ds), len(test_ds)



### Visualization

Always look at your data. Here we plot a few images with their masks.


In [None]:

def show_batch(dataset, n=6):
    n = min(n, len(dataset))
    cols = 3
    rows = int(math.ceil(n/cols))
    # Fetch each sample once so the image and mask panels show the same scenes
    samples = [dataset[i] for i in range(n)]

    plt.figure(figsize=(cols*4, rows*4))
    for i, (img, _) in enumerate(samples):
        plt.subplot(rows, cols, i+1)
        plt.imshow(img[0].numpy(), cmap="gray")
        plt.title("Image (grayscale)")
        plt.axis("off")
    plt.show()

    plt.figure(figsize=(cols*4, rows*4))
    for i, (_, msk) in enumerate(samples):
        plt.subplot(rows, cols, i+1)
        plt.imshow(msk.numpy(), cmap="gray")
        plt.title("Mask")
        plt.axis("off")
    plt.show()

show_batch(train_ds, n=6)



## 3) Transformer-based Segmentation Model

We implement a **lightweight SegFormer-like** architecture:

- **Overlapping Patch Embedding**: A small convolution projects the input into tokens with spatial stride (patch size) and overlap, improving local continuity.
- **Transformer Encoder Blocks**: Stacks of self-attention + MLP (with GELU), LayerNorm pre-norm, residual connections.
- **Positional Encoding**: 2D learnable positional embeddings added to token features.
- **Decoder**: Simple upsampling head that uses a sequence-to-image reshape and a few conv layers to output a dense mask.
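
Why the positional encoding matters can be seen with a toy version of self-attention. The sketch below is a deliberately simplified, pure-Python single head with identity projections (no learned query/key/value weights, unlike `nn.MultiheadAttention`); the function names are illustrative only. It shows that shuffling the input tokens merely shuffles the outputs, so without positional information the encoder cannot tell where a patch came from.

```python
import math

def softmax(row):
    m = max(row)                                  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(tokens):
    """Single-head scaled dot-product attention with Q = K = V = tokens."""
    d = len(tokens[0])
    scale = 1.0 / math.sqrt(d)
    outputs, weights = [], []
    for q in tokens:
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in tokens]
        w = softmax(scores)                       # one weight per token, summing to 1
        weights.append(w)
        # each output is a weighted average over ALL tokens (global mixing)
        outputs.append([sum(wj * v[c] for wj, v in zip(w, tokens)) for c in range(d)])
    return outputs, weights

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, weights = self_attention(tokens)

perm = [2, 0, 1]
out_perm, _ = self_attention([tokens[i] for i in perm])
same = all(
    math.isclose(a, b, abs_tol=1e-9)
    for i, p in enumerate(perm)
    for a, b in zip(out_perm[i], out[p])
)
print("permuting inputs permutes outputs:", same)
```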

### Key Parameters

- `patch_size` (stride): controls downsampling. Larger stride reduces memory but loses fine detail.
- `embed_dim`: token channel dimension. Higher improves capacity but costs memory.
- `depth`: number of Transformer layers.
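- `num_heads`: attention heads; must divide `embed_dim`. More heads split the embedding into more, smaller subspaces (`embed_dim / num_heads` channels each) and store more attention maps.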
- `mlp_ratio`: hidden size of MLP (`mlp_dim = embed_dim * mlp_ratio`).

We keep this model compact so it trains fast on CPU/GPU.
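
As a rough guide to how `patch_size` and `overlap` translate into tokens, the arithmetic used by `OverlapPatchEmbed` can be reproduced without PyTorch. This is a sketch that assumes the kernel rule of this notebook (kernel = patch size plus overlap, bumped to the next odd number, padding = kernel // 2); `token_grid` is a hypothetical helper name.

```python
def token_grid(img_size, patch_size, overlap=0.5):
    """Kernel size, padding and tokens per side, mirroring OverlapPatchEmbed."""
    kernel = int(patch_size + patch_size * overlap)
    if kernel % 2 == 0:
        kernel += 1
    padding = kernel // 2
    # standard convolution output-size formula with stride = patch_size
    per_side = (img_size + 2 * padding - kernel) // patch_size + 1
    return kernel, padding, per_side

for p in (4, 8, 16):
    k, pad, side = token_grid(128, p)
    n = side * side
    print(f"patch_size={p:2d} kernel={k:2d} padding={pad:2d} "
          f"tokens={side}x{side}={n:4d} attention entries per head={n * n}")
```

Halving the patch size quadruples the number of tokens and multiplies the size of each attention matrix by sixteen, which is why small patches become expensive quickly.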


In [None]:

class OverlapPatchEmbed(nn.Module):
    def __init__(self, in_ch=1, embed_dim=64, patch_size=8, overlap=0.5):
        super().__init__()
        stride = patch_size
        k = int(patch_size + patch_size*overlap)
        if k % 2 == 0:
            k += 1  # ensure odd kernel for centered receptive field
        padding = k // 2
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=k, stride=stride, padding=padding)
        self.norm = nn.LayerNorm(embed_dim)
        self.patch_size = patch_size

    def forward(self, x):
        # x: (B, C, H, W)
        x = self.proj(x)  # (B, E, H', W')
        B, E, Hp, Wp = x.shape
        x = x.flatten(2).transpose(1, 2)  # (B, N, E), N = Hp*Wp
        x = self.norm(x)
        return x, (Hp, Wp)

class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, dim)
        self.drop = nn.Dropout(drop)
    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads=4, mlp_ratio=4.0, attn_drop=0.0, drop=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=attn_drop, batch_first=True)
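        # Despite the "drop_path" names, these are plain element-wise dropout layers
        # on the residual branches, not stochastic depth.
        self.drop_path1 = nn.Dropout(drop)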

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_ratio=mlp_ratio, drop=drop)
        self.drop_path2 = nn.Dropout(drop)

    def forward(self, x):
        # Self-attention
        x_res = x
        x = self.norm1(x)
        attn_out, _ = self.attn(x, x, x, need_weights=False)
        x = x_res + self.drop_path1(attn_out)

        # MLP
        x_res = x
        x = self.norm2(x)
        x = x_res + self.drop_path2(self.mlp(x))
        return x

class TinySegFormer(nn.Module):
    def __init__(self, in_ch=1, embed_dim=96, depth=4, num_heads=4, mlp_ratio=3.0,
                 patch_size=8, overlap=0.5, num_classes=1, img_size=(H, W)):
        super().__init__()
        self.embed = OverlapPatchEmbed(in_ch, embed_dim, patch_size, overlap)
        # Token grid at the training resolution. With an odd kernel and
        # padding = kernel // 2, the strided conv yields ceil(size / stride) tokens per axis.
        self.grid = (math.ceil(img_size[0] / patch_size), math.ceil(img_size[1] / patch_size))
        # Learned positional embeddings, created here (not lazily in forward) so that
        # model.parameters(), and therefore the optimizer, includes them.
        self.pos_emb = nn.Parameter(torch.zeros(1, self.grid[0] * self.grid[1], embed_dim))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.num_classes = num_classes

        # Decoder: simple conv head applied after upsampling back to (H, W)
        self.decoder = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(embed_dim, embed_dim//2, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Conv2d(embed_dim//2, num_classes, kernel_size=1),
        )

        self.patch_size = patch_size

    def forward(self, x):
        B = x.shape[0]
        tokens, (Hp, Wp) = self.embed(x)        # (B, N, E)
        E = tokens.shape[2]

        pos = self.pos_emb
        if (Hp, Wp) != self.grid:
            # Input size differs from img_size: resample the learned grid bilinearly
            pos = pos.transpose(1, 2).reshape(1, E, *self.grid)
            pos = F.interpolate(pos, size=(Hp, Wp), mode="bilinear", align_corners=False)
            pos = pos.flatten(2).transpose(1, 2)

        z = tokens + pos                        # add positional info
        for blk in self.blocks:
            z = blk(z)
        z = self.norm(z)                        # (B, N, E)

        z = z.transpose(1, 2).reshape(B, E, Hp, Wp)  # back to spatial

        # upsample to the exact input resolution (also correct when H, W are not multiples of patch_size)
        z = F.interpolate(z, size=x.shape[-2:], mode="bilinear", align_corners=False)

        logits = self.decoder(z)                # (B, num_classes, H, W)
        return logits



### Why Overlapping Patches?

Pure non-overlapping patches can create **block artifacts** and miss local continuity. Overlapping patch embedding uses a strided convolution whose kernel is wider than its stride, so neighboring tokens share input pixels. This tends to improve segmentation edges and training stability at little extra compute.
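
The sharing of pixels between neighboring tokens can be made concrete along one image axis. The sketch below assumes the default setting of this notebook (stride 8, kernel 13, padding 6) and compares it with a plain 8×8 non-overlapping patchifier; `tokens_covering` is a hypothetical helper.

```python
def tokens_covering(pixel, kernel, stride, padding, n_tokens):
    """Indices of tokens (along one axis) whose receptive field contains `pixel`."""
    hits = []
    for t in range(n_tokens):
        start = t * stride - padding          # first input index seen by token t
        if start <= pixel < start + kernel:
            hits.append(t)
    return hits

n_tokens = 16  # 128 pixels at stride 8
for pixel in (4, 8, 13):
    plain = tokens_covering(pixel, kernel=8, stride=8, padding=0, n_tokens=n_tokens)
    overlapping = tokens_covering(pixel, kernel=13, stride=8, padding=6, n_tokens=n_tokens)
    print(f"pixel {pixel:2d}: non-overlapping -> {plain}, overlapping -> {overlapping}")
```

Pixels near a patch boundary (4 and 13 here) feed two tokens per axis in the overlapping case, so information about an edge that straddles two patches is available to both tokens.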


## 4) Losses & Metrics: Dice, BCE, IoU

In [None]:

def dice_loss(pred, target, eps=1e-6):
    # pred: logits (B,1,H,W) -> apply sigmoid inside
    prob = torch.sigmoid(pred)
    num = 2 * (prob * target).sum(dim=(1,2,3))
    den = (prob + target).sum(dim=(1,2,3)) + eps
    dice = 1 - (num / den)
    return dice.mean()

def bce_loss(pred, target):
    return F.binary_cross_entropy_with_logits(pred, target)

def combo_loss(pred, target, w_dice=0.6, w_bce=0.4):
    return w_dice * dice_loss(pred, target) + w_bce * bce_loss(pred, target)

@torch.no_grad()
def iou_score(pred, target, thresh=0.5, eps=1e-6):
    prob = torch.sigmoid(pred)
    pred_mask = (prob >= thresh).float()
    inter = (pred_mask * target).sum(dim=(1,2,3))
    union = (pred_mask + target - pred_mask*target).sum(dim=(1,2,3))
    # eps in numerator and denominator: an empty prediction on an empty mask scores 1, not 0
    iou = ((inter + eps) / (union + eps)).mean().item()
    return iou

@torch.no_grad()
def dice_coeff(pred, target, thresh=0.5, eps=1e-6):
    prob = torch.sigmoid(pred)
    pred_mask = (prob >= thresh).float()
    num = 2 * (pred_mask * target).sum(dim=(1,2,3))
    den = (pred_mask + target).sum(dim=(1,2,3))
    # same eps convention as iou_score
    return ((num + eps) / (den + eps)).mean().item()
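
Two properties of these losses and metrics are easy to check by hand. The pure-Python sketch below (flattened pixel lists instead of tensors, illustrative function names) first shows why BCE alone is weak under class imbalance: a model that outputs the 1% base rate everywhere gets a small BCE but a soft Dice loss close to 1. It then checks the identity Dice = 2·IoU / (1 + IoU) for thresholded masks.

```python
import math

def soft_dice_loss(probs, target, eps=1e-6):
    inter = sum(p * t for p, t in zip(probs, target))
    return 1.0 - 2.0 * inter / (sum(probs) + sum(target) + eps)

def bce(probs, target):
    return -sum(t * math.log(p) + (1 - t) * math.log(1 - p)
                for p, t in zip(probs, target)) / len(target)

def hard_iou_dice(pred, target):
    inter = sum(p * t for p, t in zip(pred, target))
    union = sum(pred) + sum(target) - inter
    return inter / union, 2.0 * inter / (sum(pred) + sum(target))

# 1% foreground; the "lazy" model predicts the base rate everywhere and finds nothing
n = 10_000
target = [1.0] * 100 + [0.0] * (n - 100)
lazy = [0.01] * n
lazy_bce = bce(lazy, target)
lazy_dice = soft_dice_loss(lazy, target)
print(f"lazy model: BCE = {lazy_bce:.3f}, Dice loss = {lazy_dice:.3f}")

# thresholded masks: all 100 true pixels found, plus 50 false positives
pred = [1.0] * 150 + [0.0] * (n - 150)
iou, dice = hard_iou_dice(pred, target)
print(f"IoU = {iou:.3f}, Dice = {dice:.3f}, 2*IoU/(1+IoU) = {2 * iou / (1 + iou):.3f}")
```

Because Dice is a monotonic function of IoU for a single mask, the two validation curves usually tell the same story; Dice simply reads higher.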



## 5) Training Loop

We use AdamW, a modest learning rate, and track validation IoU/Dice. Tune:

- `embed_dim`, `depth`, `num_heads`: capacity
- `patch_size`: resolution vs compute
- `lr`, `weight_decay`, `epochs`: optimization


In [None]:

model = TinySegFormer(in_ch=1, embed_dim=128, depth=4, num_heads=4, mlp_ratio=3.0,
                      patch_size=8, overlap=0.5, num_classes=1).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
EPOCHS = 10

def train_one_epoch(loader):
    model.train()
    total_loss = 0.0
    for img, msk in loader:
        img = img.to(device)
        msk = msk.to(device).unsqueeze(1)  # (B,1,H,W)
        optimizer.zero_grad()
        logits = model(img)
        loss = combo_loss(logits, msk)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * img.size(0)
    return total_loss / len(loader.dataset)

@torch.no_grad()
def evaluate(loader):
    model.eval()
    total_loss, total_iou, total_dice = 0.0, 0.0, 0.0
    for img, msk in loader:
        img = img.to(device)
        msk = msk.to(device).unsqueeze(1)
        logits = model(img)
        loss = combo_loss(logits, msk)
        # Weight each batch by its size so a smaller final batch is not over-counted
        n = img.size(0)
        total_loss += loss.item() * n
        total_iou  += iou_score(logits, msk) * n
        total_dice += dice_coeff(logits, msk) * n
    n_total = len(loader.dataset)
    return total_loss / n_total, total_iou / n_total, total_dice / n_total

train_hist = {"train_loss": [], "val_loss": [], "val_iou": [], "val_dice": []}
for epoch in range(1, EPOCHS+1):
    tr = train_one_epoch(train_loader)
    vl, viou, vdice = evaluate(val_loader)
    train_hist["train_loss"].append(tr)
    train_hist["val_loss"].append(vl)
    train_hist["val_iou"].append(viou)
    train_hist["val_dice"].append(vdice)
    print(f"Epoch {epoch:02d} | train {tr:.4f} | val {vl:.4f} | IoU {viou:.3f} | Dice {vdice:.3f}")


### Curves: Loss / IoU / Dice

In [None]:

# One figure per curve, using matplotlib's default colors
plt.figure()
plt.plot(train_hist["train_loss"], label="train loss")
plt.plot(train_hist["val_loss"], label="val loss")
plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Training vs Validation Loss"); plt.legend(); plt.show()

plt.figure()
plt.plot(train_hist["val_iou"], label="val IoU")
plt.xlabel("Epoch"); plt.ylabel("IoU"); plt.title("Validation IoU"); plt.legend(); plt.show()

plt.figure()
plt.plot(train_hist["val_dice"], label="val Dice")
plt.xlabel("Epoch"); plt.ylabel("Dice"); plt.title("Validation Dice"); plt.legend(); plt.show()


## 6) Test Evaluation & Qualitative Results

In [None]:

test_loss, test_iou, test_dice = evaluate(test_loader)
print(f"Test | loss {test_loss:.4f} | IoU {test_iou:.3f} | Dice {test_dice:.3f}")


In [None]:

@torch.no_grad()
def show_predictions(loader, n_batches=2):
    model.eval()
    shown = 0
    for img, msk in loader:
        img = img.to(device)
        logits = model(img)
        prob = torch.sigmoid(logits).cpu().numpy()
        img = img.cpu().numpy()
        msk = msk.numpy()
        B = img.shape[0]
        for b in range(B):
            if shown >= n_batches * BATCH_SIZE:
                return
            plt.figure()
            plt.imshow(img[b,0])
            plt.title("Input image")
            plt.axis("off")
            plt.show()

            plt.figure()
            plt.imshow(msk[b])
            plt.title("Ground truth mask")
            plt.axis("off")
            plt.show()

            plt.figure()
            plt.imshow((prob[b,0] >= 0.5).astype(np.float32))
            plt.title("Prediction (thresholded)")
            plt.axis("off")
            plt.show()

            shown += 1

show_predictions(test_loader, n_batches=1)



## 7) Advanced Topics & Extensions

1. **Multi-class Segmentation**: Change `num_classes` to K, train with `CrossEntropyLoss` on integer class-index masks of shape (B, H, W), and compute Dice per class from softmax probabilities against one-hot targets.
2. **Pyramid Features**: Build multiple stages with increasing stride (e.g., 4/8/16) and a lightweight decoder that fuses multiscale features.
3. **Relative Positional Bias**: Replace absolute positional embeddings with relative biases to better generalize to varying sizes.
4. **Data Augmentation**: Random rotations, flips, elastic deformations. For astronomy, prefer intensity-preserving transforms.
5. **Mixed Precision**: Use automatic mixed precision (`torch.autocast` plus a gradient scaler; `torch.cuda.amp` in older PyTorch releases) for faster training on GPUs.
6. **Regularization**: Stochastic depth, attention dropout, spatial dropout.
7. **Real FITS Data**: Replace the dataset with FITS readers. Normalize by exposure time, handle bad pixels with masks, and consider PSF variations.
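
For the augmentation item, flips and 90-degree rotations are a safe starting point because they only move pixels around and leave intensities untouched. The sketch below uses NumPy only; `random_flip_rot90` is a hypothetical helper, and the key point is that image and mask must receive the same transform.

```python
import numpy as np

def random_flip_rot90(img, msk, rng):
    """Apply the same random 90-degree rotation and horizontal flip to image and mask."""
    k = int(rng.integers(0, 4))                      # number of 90-degree turns
    img, msk = np.rot90(img, k), np.rot90(msk, k)
    if rng.random() < 0.5:
        img, msk = np.flip(img, axis=1), np.flip(msk, axis=1)
    # copy() makes the arrays contiguous; torch.from_numpy rejects negative strides
    return img.copy(), msk.copy()

rng = np.random.default_rng(0)
img = rng.random((128, 128), dtype=np.float32)
msk = (img > 0.9).astype(np.float32)                 # toy mask tied to the image content
aug_img, aug_msk = random_flip_rot90(img, msk, rng)

print("shapes:", aug_img.shape, aug_msk.shape)
print("mask still aligned:", np.array_equal(aug_msk, (aug_img > 0.9).astype(np.float32)))
```

In this notebook such a function would be called inside `__getitem__`, before the channel axis is added; rotations by arbitrary angles would additionally need interpolation, which changes pixel values.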



## 8) Key Hyperparameters: Practical Guide

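- `patch_size`: start at 8 or 4 for 128×128 images. A smaller patch captures finer edges, but halving it quadruples the token count (256 tokens at 8, 1024 at 4), and self-attention cost grows with the square of that count.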
- `embed_dim`: 64–256 is common for small problems. Larger boosts accuracy at compute cost.
- `depth`: 4–12 layers depending on capacity and data size.
- `num_heads`: 2–8 typically. More heads may help complex textures.
- `mlp_ratio`: 3–4 is a good default.
- `loss`: Dice + BCE is robust for class imbalance common in astronomy (small sources vs large background).
- `lr`: 3e-4 to 1e-3 with AdamW is a reasonable starting range; watch validation metrics to decide whether a decay schedule is needed.
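
To compare settings before training, a back-of-the-envelope budget is often enough. The sketch below counts the parameters of one `TransformerBlock` as defined in section 3 (fused QKV and output projections of `nn.MultiheadAttention`, the two-layer MLP, two LayerNorms) and the number of attention-matrix entries per image and layer. The helper names are hypothetical, and the token count assumes the image size is a multiple of `patch_size`.

```python
def block_params(embed_dim, mlp_ratio):
    """Trainable parameters in one pre-norm Transformer block."""
    hidden = int(embed_dim * mlp_ratio)
    attn = 4 * embed_dim * embed_dim + 4 * embed_dim   # QKV + output projection, with biases
    mlp = embed_dim * hidden + hidden + hidden * embed_dim + embed_dim
    norms = 2 * 2 * embed_dim                          # two LayerNorms (weight + bias)
    return attn + mlp + norms

def attention_entries(img_size, patch_size, num_heads):
    """Entries in the attention matrices for one image in one layer."""
    n_tokens = (img_size // patch_size) ** 2
    return num_heads * n_tokens * n_tokens

print("params per block (embed_dim=128, mlp_ratio=3):", block_params(128, 3.0))
for p in (8, 4):
    print(f"patch_size={p}: attention entries per layer = {attention_entries(128, p, 4)}")
```

Parameter count scales with `embed_dim` squared and is independent of `patch_size`, whereas activation memory is dominated by the attention matrices once the patch size gets small.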



## 9) Checklist for Real Datasets

- [ ] Data loading (FITS/PNG), normalization, bad-pixel masking
- [ ] Train/val/test split respecting fields/targets
- [ ] Metric selection aligned to science goals (e.g., high recall of faint sources)
- [ ] Calibration of predicted masks if used for photometry
- [ ] Uncertainty estimation (e.g., MC dropout) for downstream vetting
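
For the split item, one possible approach is to assign whole fields to a split, so that cutouts from the same pointing never appear in both training and evaluation data. The sketch below is a minimal illustration; the `field_ids` list, the 70/15/15 fractions and the function name are assumptions, not part of this notebook.

```python
import random
from collections import defaultdict

def split_by_field(field_ids, fractions=(0.7, 0.15, 0.15), seed=0):
    """Assign whole fields to train/val/test so no field is shared between splits."""
    by_field = defaultdict(list)
    for idx, fid in enumerate(field_ids):
        by_field[fid].append(idx)
    fields = sorted(by_field)                     # sort first so the shuffle is reproducible
    random.Random(seed).shuffle(fields)
    n_train = round(fractions[0] * len(fields))
    n_val = round(fractions[1] * len(fields))
    groups = {
        "train": fields[:n_train],
        "val": fields[n_train:n_train + n_val],
        "test": fields[n_train + n_val:],         # whatever remains
    }
    return {name: sorted(i for f in fs for i in by_field[f]) for name, fs in groups.items()}

# hypothetical catalogue: 200 cutouts drawn from 20 survey fields
field_ids = [f"field_{i % 20:02d}" for i in range(200)]
splits = split_by_field(field_ids)
print({name: len(idx) for name, idx in splits.items()})
```

The returned index lists can be passed to `torch.utils.data.Subset` to build the three datasets.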
