# Run Pipeline - Simple UNet

In [1]:
"""
A bite-size U-Net demo that prints tensor shapes to illustrate
concatenating skip connections.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import math

# ---------- 1. Reproducibility ------------------------------------------------
# Fix the seed of every RNG imported by this demo so repeated runs match.
random.seed(0)          # Python RNG: picks each circle's radius and centre
np.random.seed(0)       # NumPy RNG: not used by the visible code, seeded anyway
torch.manual_seed(0)    # PyTorch RNG: used for the per-pixel noise (torch.rand_like)

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# ---------- 2. Synthetic Segmentation Dataset --------------------------------
class CircleDataset(Dataset):
    """Synthetic segmentation data: one filled circle on a lightly noisy canvas.

    Each item is an ``(image, mask)`` pair of float32 tensors shaped
    ``(1, size, size)``. The mask is 1 inside the circle and 0 elsewhere; the
    image is the mask plus uniform noise in ``[0, 0.1)``, clamped to ``[0, 1]``.

    Samples are drawn from the global ``random`` / ``torch`` RNGs on every
    access, so asking for the same index twice yields two different circles.

    Args:
        n_samples: Number of items the dataset reports via ``len()``.
        size: Side length of the square image in pixels (default 64).

    Raises:
        ValueError: If ``n_samples`` is negative or ``size`` is too small to
            hold the largest circle with its margin.
    """

    MIN_RADIUS = 5
    MAX_RADIUS = 15

    def __init__(self, n_samples: int, size: int = 64):
        if n_samples < 0:
            raise ValueError(f"n_samples must be >= 0, got {n_samples}")
        # The centre is drawn from [radius + 1, size - radius - 2]; that range
        # is non-empty for every radius only when size >= 2 * MAX_RADIUS + 3.
        min_size = 2 * self.MAX_RADIUS + 3
        if size < min_size:
            raise ValueError(f"size must be >= {min_size}, got {size}")

        self.n = n_samples
        self.size = size
        # The pixel coordinate grids never change, so build them once here
        # instead of on every __getitem__ call.
        self._yy, self._xx = torch.meshgrid(
            torch.arange(size), torch.arange(size), indexing='ij'
        )

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # Without this check Python's legacy sequence protocol (used by
        # ``for x in ds`` / ``list(ds)``) would never see an IndexError and
        # would iterate forever.
        if not -self.n <= idx < self.n:
            raise IndexError(
                f"index {idx} is out of range for a dataset of length {self.n}"
            )

        H = W = self.size
        img = torch.zeros((H, W), dtype=torch.float32)
        mask = torch.zeros((H, W), dtype=torch.float32)

        # Random circle that stays at least one pixel away from the border.
        radius = random.randint(self.MIN_RADIUS, self.MAX_RADIUS)
        cx = random.randint(radius + 1, W - radius - 2)
        cy = random.randint(radius + 1, H - radius - 2)

        circle = ((self._xx - cx) ** 2 + (self._yy - cy) ** 2) <= radius ** 2
        mask[circle] = 1.0
        img[circle] = 1.0    # base signal

        # Slight uniform noise, then clamp back into the valid intensity range.
        img += 0.1 * torch.rand_like(img)
        img = img.clamp(0.0, 1.0)

        return img.unsqueeze(0), mask.unsqueeze(0)  # add channel dim

In [3]:
# ---------- 3. Network Building Blocks ---------------------------------------
class DoubleConv(nn.Module):
    """Two stacked 3×3 convolutions, each followed by an in-place ReLU.

    Both convolutions use ``padding=1`` and no bias, so height and width are
    preserved and only the channel count changes: in_ch → out_ch → out_ch.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First conv maps in_ch → out_ch, second keeps out_ch → out_ch.
        for stage_in in (in_ch, out_ch):
            stages.append(
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=False)
            )
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        features = self.net(x)
        return features

class DownBlock(nn.Module):
    """Encoder stage: a DoubleConv, optionally followed by 2×2 max-pooling.

    With ``pool=True`` the forward pass returns ``(features, pooled)`` so the
    un-pooled features can serve as a skip connection. With ``pool=False``
    (the bottleneck) it returns the features alone.
    """
    def __init__(self, in_ch, out_ch, pool=True):
        super().__init__()
        self.double_conv = DoubleConv(in_ch, out_ch)
        if pool:
            self.pool = nn.MaxPool2d(2)
        else:
            self.pool = None

    def forward(self, x, verbose=False, tag=""):
        features = self.double_conv(x)
        if verbose:
            print(f"{tag} feat: {features.shape}")

        if self.pool is not None:
            downsampled = self.pool(features)
            if verbose:
                print(f"{tag} pool: {downsampled.shape}")
            return features, downsampled

        # Bottleneck: nothing to pool and no skip tensor to hand back.
        return features

class UpBlock(nn.Module):
    """Decoder stage: up-sample, concatenate the skip tensor, then DoubleConv.

    The transposed convolution doubles height/width and halves the channel
    count (``in_ch`` → ``in_ch // 2``). The result is concatenated with
    ``skip`` along the channel axis and reduced to ``out_ch`` channels.

    Args:
        in_ch: Channels of the tensor coming from the deeper level.
        skip_ch: Channels of the skip tensor from the matching encoder stage.
        out_ch: Channels produced by this stage.
    """
    def __init__(self, in_ch, skip_ch, out_ch):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2)
        self.double_conv = DoubleConv(in_ch // 2 + skip_ch, out_ch)

    def forward(self, x, skip, verbose=False, tag=""):
        x = self.up(x)
        if verbose:
            print(f"{tag} up   : {x.shape}")
            print(f"{tag} skip : {skip.shape}")

        # Max-pooling floors odd sizes, so after up-sampling x can be one
        # pixel smaller than the skip tensor. Pad x to the skip's spatial size
        # so the concat below works for any input resolution; this is a no-op
        # when the sizes already agree.
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h or diff_w:
            x = F.pad(
                x,
                [diff_w // 2, diff_w - diff_w // 2,
                 diff_h // 2, diff_h - diff_h // 2],
            )

        x = torch.cat([x, skip], dim=1)   # concat along channels
        if verbose:
            print(f"{tag} cat  : {x.shape}")
        x = self.double_conv(x)
        if verbose:
            print(f"{tag} out  : {x.shape}")
        return x

# ---------- 4. Mini-U-Net -----------------------------------------------------
class MiniUNet(nn.Module):
    """Two-level U-Net for single-channel input and single-channel logits.

    Encoder widths are 1 → 16 → 32 with a 64-channel bottleneck; the decoder
    mirrors them (64 → 32 → 16) and a 1×1 convolution produces the logits.
    The very first forward pass prints the shape of every intermediate
    tensor; later passes are silent.
    """
    def __init__(self):
        super().__init__()
        # Encoder: two pooled stages, then a bottleneck without pooling.
        self.d1 = DownBlock(in_ch=1, out_ch=16, pool=True)
        self.d2 = DownBlock(in_ch=16, out_ch=32, pool=True)
        self.bottom = DownBlock(in_ch=32, out_ch=64, pool=False)

        # Decoder: each stage consumes the skip of the matching encoder stage.
        self.u2 = UpBlock(in_ch=64, skip_ch=32, out_ch=32)
        self.u1 = UpBlock(in_ch=32, skip_ch=16, out_ch=16)

        self.final = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1)
        self._verbose_once = True  # cleared after the first traced forward

    def forward(self, x):
        # Trace shapes only on the first call, then switch the flag off.
        trace = bool(self._verbose_once)
        if trace:
            self._verbose_once = False

        skip_hi, h = self.d1(x, verbose=trace, tag="Down1")
        skip_lo, h = self.d2(h, verbose=trace, tag="Down2")
        h = self.bottom(h, verbose=trace, tag="Bottom")  # bottleneck returns a single tensor

        h = self.u2(h, skip_lo, verbose=trace, tag="Up2")
        h = self.u1(h, skip_hi, verbose=trace, tag="Up1")

        if trace:
            print(f"Final conv input: {h.shape}")
        logits = self.final(h)   # one output channel; (B,1,64,64) for 64×64 input
        if trace:
            print(f"Logits out: {logits.shape}\n")
        return logits

In [4]:
# ---------- 5. Metrics --------------------------------------------------------
def dice_accuracy(logits, masks):
    """Return the fraction of pixels whose thresholded prediction equals the mask.

    Despite the name this is plain pixel accuracy, not a Dice score: logits
    go through a sigmoid, are binarised at 0.5, and compared element-wise
    with ``masks``. The result is a Python float.
    """
    probabilities = torch.sigmoid(logits)
    binary_preds = probabilities.gt(0.5).to(torch.float32)
    matches = torch.eq(binary_preds, masks).to(torch.float32)
    return matches.mean().item()

# ---------- 6. Training -------------------------------------------------------
def train_epoch(model, loader, crit, opt, epoch):
    """Train ``model`` for one pass over ``loader`` and print the averages.

    Args:
        model: Network to optimise; switched to train mode.
        loader: Iterable yielding ``(images, masks)`` batches.
        crit: Loss callable taking ``(logits, masks)``.
        opt: Optimiser stepping the model's parameters.
        epoch: Epoch number, used only in the printed line.

    Returns:
        ``(mean_loss, mean_acc)`` weighted by batch size, matching what is
        printed. Both are NaN if the loader yields no samples.
    """
    model.train()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    for imgs, masks in loader:
        imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
        batch_size = imgs.size(0)

        opt.zero_grad()
        logits = model(imgs)
        loss = crit(logits, masks)
        loss.backward()
        opt.step()

        # Weight by batch size so a short final batch does not skew the mean.
        loss_sum += loss.item() * batch_size
        acc_sum += dice_accuracy(logits, masks) * batch_size
        seen += batch_size

    # Divide by the samples actually processed rather than len(loader.dataset):
    # the two differ with drop_last or a subset sampler, and counting here
    # also avoids a ZeroDivisionError on an empty loader.
    mean_loss = loss_sum / seen if seen else float("nan")
    mean_acc = acc_sum / seen if seen else float("nan")
    print(f"[Epoch {epoch:02}] "
          f"loss={mean_loss:.4f}  acc={mean_acc:.3f}")
    return mean_loss, mean_acc

def validate(model, loader, crit):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Network to evaluate; switched to eval mode.
        loader: Iterable yielding ``(images, masks)`` batches.
        crit: Loss callable taking ``(logits, masks)``.

    Returns:
        ``(mean_loss, mean_acc)`` weighted by batch size. Both are NaN if the
        loader yields no samples.
    """
    model.eval()
    loss_sum, acc_sum, seen = 0.0, 0.0, 0
    with torch.no_grad():
        for imgs, masks in loader:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)
            batch_size = imgs.size(0)
            logits = model(imgs)
            loss_sum += crit(logits, masks).item() * batch_size
            acc_sum += dice_accuracy(logits, masks) * batch_size
            seen += batch_size

    # Average over the samples actually seen: len(loader.dataset) is wrong
    # with drop_last or a subset sampler and is zero for an empty dataset.
    if not seen:
        return float("nan"), float("nan")
    return loss_sum / seen, acc_sum / seen

In [5]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed from raw logits.

    Logits are passed through a sigmoid, a per-sample Dice coefficient is
    computed over every non-batch dimension, and ``1 - mean(dice)`` is
    returned so that minimising the loss maximises overlap.

    Args:
        eps: Smoothing term added to numerator and denominator; it keeps the
            ratio defined when both prediction and target are empty.
    """
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        """Return the scalar loss for ``logits`` and same-shaped ``targets``.

        Raises:
            ValueError: If the shapes differ or the input has no non-batch
                dimension to reduce over.
        """
        if logits.shape != targets.shape:
            # Broadcasting here would silently compute a meaningless overlap.
            raise ValueError(
                f"logits {tuple(logits.shape)} and targets "
                f"{tuple(targets.shape)} must have the same shape"
            )
        if logits.dim() < 2:
            raise ValueError("expected at least one dimension after the batch axis")

        # Reduce over everything except the batch axis; for (B, C, H, W) input
        # this is exactly dim=(1, 2, 3).
        reduce_dims = tuple(range(1, logits.dim()))

        probs = torch.sigmoid(logits)
        num = 2 * (probs * targets).sum(dim=reduce_dims) + self.eps
        den = probs.sum(dim=reduce_dims) + targets.sum(dim=reduce_dims) + self.eps
        dice = num / den
        return 1 - dice.mean()                      # minimise (1 - dice)

In [6]:
def main(*, n_train=200, n_val=50, batch_size=16, epochs=5, lr=1e-3):
    """Train a MiniUNet on synthetic circles and print per-epoch metrics.

    All arguments are keyword-only and default to the values the demo was
    originally hard-coded with, so ``main()`` behaves as before.

    Args:
        n_train: Number of training samples.
        n_val: Number of validation samples.
        batch_size: Batch size for both loaders.
        epochs: Number of training epochs.
        lr: Adam learning rate.
    """
    train_ds = CircleDataset(n_train)
    val_ds = CircleDataset(n_val)
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)

    model = MiniUNet().to(DEVICE)
    # DiceLoss is the custom criterion used here; nn.BCEWithLogitsLoss() takes
    # the same (logits, masks) arguments and can be swapped in.
    criterion = DiceLoss()
    optimiser = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(1, epochs + 1):
        train_epoch(model, train_loader, criterion, optimiser, epoch)
        v_loss, v_acc = validate(model, val_loader, criterion)
        print(f"          ↳ val_loss={v_loss:.4f}  val_acc={v_acc:.3f}\n")


In [7]:
# Run the demo: the first forward pass prints the tensor-shape trace, then
# each epoch prints its training and validation loss/accuracy.
main()

Down1 feat: torch.Size([16, 16, 64, 64])
Down1 pool: torch.Size([16, 16, 32, 32])
Down2 feat: torch.Size([16, 32, 32, 32])
Down2 pool: torch.Size([16, 32, 16, 16])
Bottom feat: torch.Size([16, 64, 16, 16])
Up2 up   : torch.Size([16, 32, 32, 32])
Up2 skip : torch.Size([16, 32, 32, 32])
Up2 cat  : torch.Size([16, 64, 32, 32])
Up2 out  : torch.Size([16, 32, 32, 32])
Up1 up   : torch.Size([16, 16, 64, 64])
Up1 skip : torch.Size([16, 16, 64, 64])
Up1 cat  : torch.Size([16, 32, 64, 64])
Up1 out  : torch.Size([16, 16, 64, 64])
Final conv input: torch.Size([16, 16, 64, 64])
Logits out: torch.Size([16, 1, 64, 64])

[Epoch 01] loss=0.8576  acc=0.081
          ↳ val_loss=0.8265  val_acc=0.079

[Epoch 02] loss=0.7792  acc=0.100
          ↳ val_loss=0.7928  val_acc=0.138

[Epoch 03] loss=0.7454  acc=0.598
          ↳ val_loss=0.6697  val_acc=0.954

[Epoch 04] loss=0.6836  acc=0.963
          ↳ val_loss=0.5108  val_acc=0.994

[Epoch 05] loss=0.2216  acc=0.991
          ↳ val_loss=0.0401  val_acc=0.9