In [1]:
# Lab: U-Net for Oxford-IIIT Pet

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split

import torchvision.transforms as T
from torchvision.datasets import OxfordIIITPet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

input_size = (128, 128)


Using device: cpu


In [5]:
# TODO 1.1: image transform -> resize to input_size, to tensor [0,1]
transform_img = T.Compose([
    # TODO: fill
])

# TODO 1.2: mask transform -> resize NEAREST, keep ints, shift 1..3 -> 0..2, long
transform_mask = T.Compose([
    # TODO: fill
])

dataset = OxfordIIITPet(root='data', split='trainval', target_types='segmentation',
                        transform=transform_img, target_transform=transform_mask, download=True)
test_dataset = OxfordIIITPet(root='data', split='test', target_types='segmentation',
                             transform=transform_img, target_transform=transform_mask, download=True)

train_size = int(0.8 * len(dataset))
val_size   = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True,  num_workers=2 if device.type=='cuda' else 0)
val_loader   = DataLoader(val_dataset,   batch_size=16, shuffle=False, num_workers=2 if device.type=='cuda' else 0)

len(dataset), len(test_dataset), len(train_dataset), len(val_dataset)


100%|██████████| 792M/792M [00:20<00:00, 38.5MB/s]
100%|██████████| 19.2M/19.2M [00:01<00:00, 12.7MB/s]


(3680, 3669, 2944, 736)

In [8]:
from matplotlib.colors import ListedColormap
cmap = ListedColormap(['black', 'red', 'white'])

sample_img, sample_mask = train_dataset[1]
mask_array = sample_mask.squeeze().cpu().numpy().astype(np.uint8)

plt.figure(figsize=(6,3))
plt.subplot(1,2,1); plt.title("Input Image"); plt.imshow(sample_img.permute(1,2,0)); plt.axis('off')
plt.subplot(1,2,2); plt.title("Ground Truth Mask (0..2)"); plt.imshow(mask_array, cmap=cmap, vmin=0, vmax=2); plt.axis('off')
plt.show()
#nothing to do here just run

AttributeError: 'PngImageFile' object has no attribute 'squeeze'

In [None]:
class ConvBlock(nn.Module):
    def __init__(self, in_c, out_c, p_drop=0.1):
        super().__init__()
        # TODO 3.1: two 3x3 conv(+BN+ReLU) with padding=1, then Dropout
        self.net = nn.Sequential(
            # TODO: fill
        )
    def forward(self, x): return self.net(x)

class UpBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_c=out_c*2, out_c=out_c)
    def forward(self, x, skip):
        # TODO 3.2: upsample, size-fix to skip if needed, concat on channel dim, conv
        # x = ...
        # if ...: x = ...
        # x = torch.cat([x, skip], dim=1)
        # return self.conv(x)
        raise NotImplementedError


In [None]:
class UNetMedium(nn.Module):
    def __init__(self, n_classes=3, in_ch=3, p_drop=0.1):
        super().__init__()
        self.enc1 = ConvBlock(in_ch, 32, p_drop)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(32, 64, p_drop)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ConvBlock(64, 128, p_drop)
        self.pool3 = nn.MaxPool2d(2)

        self.bott = ConvBlock(128, 256, p_drop)

        self.up3 = UpBlock(256, 128)
        self.up2 = UpBlock(128, 64)
        self.up1 = UpBlock(64, 32)

        self.outc = nn.Conv2d(32, n_classes, kernel_size=1)

    def forward(self, x, return_feats=False):
        # TODO 4.1: wire encoder, bottleneck, decoder with skips
        # s1 = ...
        # p1 = ...
        # s2 = ...
        # p2 = ...
        # s3 = ...
        # p3 = ...
        # b  = ...
        # d3 = ...
        # d2 = ...
        # d1 = ...
        # out = ...
        raise NotImplementedError

model = UNetMedium(n_classes=3).to(device)


In [7]:
class DiceLossMC(nn.Module):
    def __init__(self, num_classes, smooth=1.0, eps=1e-7):
        super().__init__()
        self.num_classes, self.smooth, self.eps = num_classes, smooth, eps
    def forward(self, logits, target):
        probs = F.softmax(logits, dim=1)
        target_oh = F.one_hot(target, num_classes=self.num_classes).permute(0,3,1,2).float()
        dims = (0,2,3)
        inter = (probs * target_oh).sum(dims)
        union = probs.sum(dims) + target_oh.sum(dims)
        dice = (2*inter + self.smooth) / (union + self.smooth + self.eps)
        return 1 - dice.mean()

w_bg, w_border, w_pet = 0.5, 2.0, 1.5
ce_loss   = nn.CrossEntropyLoss(weight=torch.tensor([w_bg, w_border, w_pet], device=device))
dice_loss = DiceLossMC(num_classes=3)

# TODO 5.1: combine CE and Dice (e.g., CE + 0.5*Dice)
def loss_fn(logits, y):
    # return ...
    raise NotImplementedError


In [None]:
from torch.cuda.amp import autocast, GradScaler
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3, weight_decay=1e-4)
scaler = GradScaler(enabled=(device.type=='cuda'))

def train_epoch():
    model.train()
    tot = 0.0
    for x, y in train_loader:
        x = x.to(device)
        y = y.squeeze(1).long().to(device)
        optimizer.zero_grad(set_to_none=True)
        # TODO 6.1: AMP forward + loss + backward + step + scaler.update
        with autocast(enabled=(device.type=='cuda')):
            logits = model(x)
            loss = loss_fn(logits, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        tot += loss.item()
    return tot / max(1, len(train_loader))

@torch.no_grad()
def eval_epoch(loader):
    model.eval()
    tot = 0.0
    for x, y in loader:
        x = x.to(device)
        y = y.squeeze(1).long().to(device)
        logits = model(x)
        tot += loss_fn(logits, y).item()
    return tot / max(1, len(loader))

EPOCHS = 60
best = 1e9
for e in range(1, EPOCHS+1):
    import time; t0 = time.time()
    tr = train_epoch()
    va = eval_epoch(val_loader)
    print(f"Epoch {e:02d} | train {tr:.4f} | val {va:.4f} | {time.time()-t0:.1f}s")
    if va < best:
        best = va
        torch.save(model.state_dict(), "unet_small_best.pt")


In [None]:
from torch.cuda.amp import autocast, GradScaler
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3, weight_decay=1e-4)
scaler = GradScaler(enabled=(device.type=='cuda'))

def train_epoch():
    model.train()
    tot = 0.0
    for x, y in train_loader:
        x = x.to(device)
        y = y.squeeze(1).long().to(device)
        optimizer.zero_grad(set_to_none=True)
        # TODO 6.1: AMP forward + loss + backward + step + scaler.update
        with autocast(enabled=(device.type=='cuda')):
            logits = model(x)
            loss = loss_fn(logits, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        tot += loss.item()
    return tot / max(1, len(train_loader))

@torch.no_grad()
def eval_epoch(loader):
    model.eval()
    tot = 0.0
    for x, y in loader:
        x = x.to(device)
        y = y.squeeze(1).long().to(device)
        logits = model(x)
        tot += loss_fn(logits, y).item()
    return tot / max(1, len(loader))

EPOCHS = 60
best = 1e9
for e in range(1, EPOCHS+1):
    import time; t0 = time.time()
    tr = train_epoch()
    va = eval_epoch(val_loader)
    print(f"Epoch {e:02d} | train {tr:.4f} | val {va:.4f} | {time.time()-t0:.1f}s")
    if va < best:
        best = va
        torch.save(model.state_dict(), "unet_small_best.pt")


In [None]:
from matplotlib.colors import ListedColormap
cmap = ListedColormap(['black','red','white'])

def show_batch_preds(ds, n=3):
    model.eval()
    for i in range(n):
        img, mask = ds[i]
        x = img.unsqueeze(0).to(device)
        with torch.no_grad():
            logits = model(x)
            pred = logits.argmax(1).squeeze(0).cpu().numpy().astype('uint8')
        gt = mask.squeeze().cpu().numpy().astype('uint8')
        fig, ax = plt.subplots(1,3, figsize=(9,3))
        ax[0].set_title("Input"); ax[0].imshow(img.permute(1,2,0)); ax[0].axis('off')
        ax[1].set_title("Ground Truth"); ax[1].imshow(gt, cmap=cmap, vmin=0, vmax=2); ax[1].axis('off')
        ax[2].set_title("Predicted"); ax[2].imshow(pred, cmap=cmap, vmin=0, vmax=2); ax[2].axis('off')
        plt.show()


In [None]:
import numpy as np, matplotlib.pyplot as plt, torch
from matplotlib.colors import ListedColormap

cmap = ListedColormap(['black','red','white'])

def _to_numpy_img(x_1x3hw):
    return x_1x3hw.detach().cpu().squeeze(0).permute(1,2,0).numpy()

def _norm01(arr, eps=1e-6):
    return (arr - arr.min()) / (arr.max() - arr.min() + eps)

def visualize_feature_map(feature_tensor, title, max_channels=16, channels=None, tight=True):
    ft = feature_tensor.detach().cpu().squeeze(0)
    C = ft.shape[0]
    if channels is None:
        Cshow = min(C, max_channels); ch_idx = list(range(Cshow))
    else:
        ch_idx = [i for i in channels if 0 <= i < C]; Cshow = len(ch_idx)
    cols = int(np.ceil(np.sqrt(Cshow))); rows = int(np.ceil(Cshow / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(cols*2, rows*2))
    fig.suptitle(f"{title}  (feat shape: {tuple(feature_tensor.shape)})", fontsize=12)
    axes = np.atleast_1d(axes).ravel()
    for i in range(rows*cols):
        ax = axes[i]; ax.axis('off')
        if i < Cshow:
            fm = ft[ch_idx[i]].numpy()
            ax.imshow(_norm01(fm), cmap='gray'); ax.set_title(f"ch {ch_idx[i]}", fontsize=8)
    if tight: plt.tight_layout()
    plt.show()

def overlay_activation(rgb, fmap, title, reduce='mean'):
    img = _to_numpy_img(rgb)
    ft = fmap.detach().cpu().squeeze(0)
    act = ft.max(0).values.numpy() if reduce=='max' else ft.mean(0).numpy()
    act = _norm01(act)
    plt.figure(figsize=(6,3))
    plt.subplot(1,2,1); plt.title("Input"); plt.imshow(img); plt.axis('off')
    plt.subplot(1,2,2); plt.title(title); plt.imshow(img); plt.imshow(act, alpha=0.45, cmap='jet'); plt.axis('off')
    plt.show()

def visualize_concat(up, skip, concat, title, k_each=6):
    upC, skC = up.shape[1], skip.shape[1]
    k_up = min(k_each, upC); k_sk = min(k_each, skC)
    panels = [up[:, i:i+1] for i in range(k_up)] + [skip[:, i:i+1] for i in range(k_sk)]
    labels = [f"up {i}" for i in range(k_up)] + [f"skip {i}" for i in range(k_sk)]
    k_cat = min(k_up + k_sk, concat.shape[1])
    panels += [concat[:, i:i+1] for i in range(k_cat)]
    labels += [f"cat {i}" for i in range(k_cat)]
    n = len(panels); cols = max(k_each, 3); rows = int(np.ceil(n / cols))
    plt.figure(figsize=(cols*2, rows*2))
    plt.suptitle(f"{title}\nup:{tuple(up.shape)}  skip:{tuple(skip.shape)}  concat:{tuple(concat.shape)}", fontsize=11)
    for idx, tens in enumerate(panels):
        plt.subplot(rows, cols, idx+1)
        fm = tens.detach().cpu().squeeze().numpy()
        plt.imshow(_norm01(fm), cmap='gray'); plt.axis('off'); plt.title(labels[idx], fontsize=8)
    plt.tight_layout(); plt.show()

def show_prediction(model, sample_img, sample_mask, show_confidence=True):
    x = sample_img.to(device).unsqueeze(0)
    model.eval()
    with torch.no_grad():
        logits = model(x)
        probs  = torch.softmax(logits, dim=1)
        pred   = probs.argmax(1).squeeze(0).cpu().numpy().astype('uint8')
        conf   = probs.max(1).values.squeeze(0).cpu().numpy()
    img_np = sample_img.permute(1,2,0).cpu().numpy()
    gt = sample_mask.squeeze().cpu().numpy().astype('uint8')
    ncols = 4 if show_confidence else 3
    plt.figure(figsize=(3*ncols,3))
    plt.subplot(1,ncols,1); plt.title("Input"); plt.imshow(img_np); plt.axis('off')
    plt.subplot(1,ncols,2); plt.title("Ground Truth"); plt.imshow(gt, cmap=cmap, vmin=0, vmax=2); plt.axis('off')
    plt.subplot(1,ncols,3); plt.title("Predicted"); plt.imshow(pred, cmap=cmap, vmin=0, vmax=2); plt.axis('off')
    if show_confidence:
        from matplotlib.colors import ListedColormap
        plt.subplot(1,ncols,4); plt.title("Top-class confidence"); plt.imshow(_norm01(conf), cmap='magma'); plt.axis('off')
    plt.show()
