In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from torchvision import transforms
from torchvision.datasets import SBDataset
from tqdm import tqdm

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

In [None]:
class SBDSegmentationDataset(Dataset):
    """Wrap a dataset of ``(image, mask)`` pairs and transform each half.

    Args:
        base_dataset: Indexable, sized object whose items unpack into
            ``(image, mask)``.
        img_transform: Optional callable applied to the image only.
        mask_transform: Optional callable applied to the mask only.

    The two transforms are called independently of each other, so any random
    decision made inside ``img_transform`` is not shared with
    ``mask_transform``.
    """

    def __init__(self, base_dataset, img_transform=None, mask_transform=None):
        self.base = base_dataset
        self.img_transform = img_transform
        self.mask_transform = mask_transform

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.base)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` pair at ``idx`` after the optional transforms."""
        image, target = self.base[idx]

        # A missing (falsy) transform leaves the corresponding item untouched.
        image = self.img_transform(image) if self.img_transform else image
        target = self.mask_transform(target) if self.mask_transform else target

        return image, target

In [None]:
# Training-time image pipeline.
# Only photometric augmentation is used: SBDSegmentationDataset runs the image
# and mask transforms independently, so a random geometric transform such as
# RandomHorizontalFlip applied here would flip the image but not its mask and
# silently corrupt the pixel-to-label correspondence.
img_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor()
])

# Deterministic image pipeline for evaluation, so validation metrics are not
# perturbed by random augmentation.
eval_img_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

# Nearest-neighbour resizing keeps mask values as valid class ids (no blending
# of neighbouring labels); PILToTensor keeps the raw integer values.
mask_transform = transforms.Compose([
    transforms.Resize(
        (256, 256),
        interpolation=transforms.InterpolationMode.NEAREST
    ),
    transforms.PILToTensor()
])

train_base = SBDataset(
    root='./data',
    image_set='train',
    mode='segmentation',
    download=False
)

val_base = SBDataset(
    root='./data',
    image_set='val',
    mode='segmentation',
    download=False
)

train_dataset = SBDSegmentationDataset(train_base, img_transform, mask_transform)
# Same underlying training images, but without augmentation; used to serve
# the held-out split below.
train_eval_dataset = SBDSegmentationDataset(train_base, eval_img_transform, mask_transform)
val_dataset = SBDSegmentationDataset(val_base, eval_img_transform, mask_transform)

# Work on (at most) the first 4000 training images; the min() guards against a
# dataset that is shorter than the requested subset.
train_subset = Subset(train_dataset, range(min(4000, len(train_dataset))))

dataset_size = len(train_subset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size

# Fixed generator seed -> the same 80/20 split on every run.
train_ds, _val_split = random_split(
    train_subset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)

# train_subset maps position i to index i of train_dataset, so the split's
# indices can be used directly on the augmentation-free view of the same data.
val_ds = Subset(train_eval_dataset, _val_split.indices)

In [None]:
# Batches of 8; only the training loader is shuffled.
# Pinned (page-locked) host memory is only useful for host-to-GPU copies, so
# it is requested only when CUDA is actually available.
train_loader = DataLoader(
    train_ds,
    batch_size=8,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=False
)

val_loader = DataLoader(
    val_ds,
    batch_size=8,
    shuffle=False,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=False
)

In [None]:
class DoubleConv(nn.Module):
    """Two 3x3 convolutions (padding 1), each followed by an in-place ReLU.

    With kernel size 3 and padding 1 the spatial size is unchanged; only the
    channel count goes from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First conv maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU to ``x``."""
        features = self.net(x)
        return features

class UNet(nn.Module):
    """Three-level U-Net for 3-channel input producing per-pixel class logits.

    The encoder uses DoubleConv blocks of width 64/128/256 with 2x2 max-pooling
    between levels, followed by a 512-channel bottleneck. The decoder upsamples
    with stride-2 transposed convolutions and concatenates the matching encoder
    output before each DoubleConv. A final 1x1 convolution maps 64 channels to
    ``num_classes``.

    The input is pooled three times and each skip connection is concatenated
    with an upsampled map, so height and width need to be divisible by 8 for
    the shapes to line up.
    """

    def __init__(self, num_classes):
        super().__init__()

        # Contracting path.
        self.enc1 = DoubleConv(3, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)

        # One pooling layer (no parameters) is shared by all levels.
        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(256, 512)

        # Expanding path: each decoder block sees upsampled + skip channels,
        # hence twice the upsampler's output width as its input width.
        self.up3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.dec3 = DoubleConv(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.dec2 = DoubleConv(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.dec1 = DoubleConv(128, 64)

        # Per-pixel classifier.
        self.out = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        """Return class logits with ``num_classes`` channels for input ``x``."""
        # Encoder: keep each level's output for the skip connections.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))

        deepest = self.bottleneck(self.pool(skip3))

        # Decoder: upsample, concatenate the skip along channels, refine.
        merged3 = torch.cat([self.up3(deepest), skip3], dim=1)
        dec3_out = self.dec3(merged3)

        merged2 = torch.cat([self.up2(dec3_out), skip2], dim=1)
        dec2_out = self.dec2(merged2)

        merged1 = torch.cat([self.up1(dec2_out), skip1], dim=1)
        dec1_out = self.dec1(merged1)

        return self.out(dec1_out)

In [None]:
def dice_loss_no_bg(pred, target, smooth=1e-6):
    """Soft Dice loss over the non-background classes (channels 1 and up).

    Args:
        pred: Raw logits; softmax is taken over dim 1 and dims 2 and 3 are
            summed over, so a 4-D tensor with classes on dim 1 is expected.
        target: Integer class-index tensor accepted by ``one_hot`` (int64).
            Pixels equal to 255 are excluded from the loss.
        smooth: Constant added to numerator and denominator to avoid 0/0.

    Returns:
        Scalar tensor: ``1 - mean Dice`` where the mean runs over every
        (sample, foreground class) pair.
    """
    probs = torch.softmax(pred, dim=1)
    n_classes = probs.shape[1]

    keep = target != 255

    # one_hot cannot encode 255, so ignored pixels get a placeholder class on
    # a copy of the target; the `keep` mask removes them again below.
    safe_target = target.clone()
    safe_target[~keep] = 0

    onehot = nn.functional.one_hot(safe_target, num_classes=n_classes)
    onehot = onehot.permute(0, 3, 1, 2).float()

    # Drop channel 0 (background) and zero out the ignored pixels.
    keep_per_channel = keep.unsqueeze(1)
    probs_fg = probs[:, 1:, :, :] * keep_per_channel
    onehot_fg = onehot[:, 1:, :, :] * keep_per_channel

    spatial_dims = (2, 3)
    overlap = (probs_fg * onehot_fg).sum(dim=spatial_dims)
    total_mass = probs_fg.sum(dim=spatial_dims) + onehot_fg.sum(dim=spatial_dims)

    dice_score = (2 * overlap + smooth) / (total_mass + smooth)
    return 1 - dice_score.mean()


In [None]:
def iou(preds, target, num_classes=21, ignore_index=255, smooth=1e-6):
    """Mean intersection-over-union of one batch, over the classes present.

    Args:
        preds: Logits/scores with classes on dim 1; the predicted class is the
            argmax over that dim.
        target: Integer class-index tensor with the same shape as the argmax
            of ``preds``.
        num_classes: Number of class ids (``0 .. num_classes - 1``) to score.
        ignore_index: Target value marking pixels excluded from the metric.
        smooth: Small constant added to intersection and union.

    Returns:
        Python float: the mean IoU over classes that occur in the prediction
        or the target among non-ignored pixels. Returns ``nan`` when no class
        occurs at all (for example every pixel is ``ignore_index``) instead of
        raising ``ZeroDivisionError``.
    """
    pred_labels = preds.argmax(dim=1)

    # Loop-invariant: computed once instead of once per class.
    valid = target != ignore_index

    class_scores = []
    for cls in range(num_classes):
        pred_mask = (pred_labels == cls) & valid
        target_mask = (target == cls) & valid

        union = (pred_mask | target_mask).sum().item()
        if union == 0:
            # Class absent from both prediction and target: undefined, skip.
            continue

        intersection = (pred_mask & target_mask).sum().item()
        class_scores.append((intersection + smooth) / (union + smooth))

    if not class_scores:
        return float("nan")
    return sum(class_scores) / len(class_scores)

In [None]:
def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: Network called as ``model(imgs)`` that returns class logits
            with classes on dim 1.
        loader: Iterable of ``(imgs, masks)`` batches; masks carry a singleton
            dim 1 that is squeezed away, and the value 255 marks ignored pixels.
        criterion: Loss called as ``criterion(outputs, masks)``.

    Returns:
        Tuple ``(mean_loss, pixel_accuracy_percent, mean_batch_iou)``.
        ``mean_loss`` is ``criterion + dice_loss_no_bg``, the same objective
        ``train_model`` optimises and logs, so the training and validation
        loss curves are directly comparable. A component that cannot be
        computed (empty loader, no valid pixels, no finite IoU) is ``nan``
        rather than raising ``ZeroDivisionError``.

    The model is left in eval mode. Tensors are moved to the module-level
    ``device``.
    """
    model.eval()
    loss_sum, batches = 0.0, 0
    correct, total = 0, 0
    iou_sum, iou_batches = 0.0, 0

    with torch.no_grad():
        for imgs, masks in loader:
            imgs = imgs.to(device)
            masks = masks.squeeze(1).long().to(device)

            outputs = model(imgs)
            loss = criterion(outputs, masks) + dice_loss_no_bg(outputs, masks)
            loss_sum += loss.item()
            batches += 1

            # Pixel accuracy over non-ignored pixels only.
            preds = outputs.argmax(dim=1)
            valid = masks != 255
            correct += (preds[valid] == masks[valid]).sum().item()
            total += valid.sum().item()

            # The class count comes from the logits themselves rather than from
            # a notebook-level global that is only defined in a later cell.
            batch_iou = iou(outputs, masks, num_classes=outputs.shape[1])
            # Skip non-finite scores so one degenerate batch cannot poison the
            # epoch average.
            if not np.isnan(batch_iou):
                iou_sum += batch_iou
                iou_batches += 1

    undefined = float("nan")
    mean_loss = loss_sum / batches if batches else undefined
    accuracy = 100 * correct / total if total else undefined
    mean_iou = iou_sum / iou_batches if iou_batches else undefined
    return mean_loss, accuracy, mean_iou

In [None]:
def train_model(model, epochs, optimizer, criterion):
    """Train ``model`` for ``epochs`` epochs and validate after each one.

    Batches come from the module-level ``train_loader``; validation runs on
    the module-level ``val_loader`` via ``validate``. The optimised loss is
    ``criterion(outputs, masks)`` plus ``dice_loss_no_bg(outputs, masks)``.

    Args:
        model: Network to train; called as ``model(imgs)``.
        epochs: Number of passes over ``train_loader``.
        optimizer: Optimizer already bound to the model's parameters.
        criterion: Loss called as ``criterion(outputs, masks)``.

    Returns:
        Dict with the keys ``train_loss``, ``val_loss``, ``train_acc``,
        ``val_acc`` and ``iou``, each a list holding one value per epoch.
    """
    metric_names = ("train_loss", "val_loss", "train_acc", "val_acc", "iou")
    history = {metric: [] for metric in metric_names}

    for epoch_idx in range(epochs):
        model.train()
        running_loss = 0
        n_correct = 0
        n_valid = 0

        progress = tqdm(train_loader, desc=f"Epoch {epoch_idx + 1}/{epochs}")
        for imgs, masks in progress:
            imgs = imgs.to(device)
            masks = masks.squeeze(1).long().to(device)

            optimizer.zero_grad()
            outputs = model(imgs)

            ce_term = criterion(outputs, masks)
            # dice_loss_no_bg already returns a scalar; .mean() leaves it as is.
            dice_term = dice_loss_no_bg(outputs, masks).mean()
            loss = ce_term + dice_term

            loss.backward()
            optimizer.step()

            # Running pixel accuracy over non-ignored (!= 255) pixels.
            predicted = outputs.argmax(dim=1)
            keep = masks != 255
            n_correct += (predicted[keep] == masks[keep]).sum().item()
            n_valid += keep.sum().item()
            running_loss += loss.item()

        train_loss = running_loss / len(train_loader)
        train_acc = 100 * n_correct / n_valid
        val_loss, val_acc, val_iou = validate(model, val_loader, criterion)

        epoch_metrics = {
            "train_loss": train_loss,
            "val_loss": val_loss,
            "train_acc": train_acc,
            "val_acc": val_acc,
            "iou": val_iou,
        }
        for metric, value in epoch_metrics.items():
            history[metric].append(value)

        summary = (
            f"Epoch {epoch_idx+1}/{epochs} | "
            f"TrainLoss: {train_loss:.4f} | "
            f"ValLoss: {val_loss:.4f} | "
            f"TrainAcc: {train_acc:.2f}% | "
            f"ValAcc: {val_acc:.2f}% | "
            f"ValIoU: {val_iou:.4f}"
        )
        print(summary)

    return history

In [None]:
def plot_training_curves(results):
    """Draw one figure comparing the training histories of several runs.

    Args:
        results: Mapping from a run label to a history dict with the keys
            ``train_loss``, ``val_loss``, ``train_acc``, ``val_acc`` and
            ``iou`` (each a per-epoch sequence).

    Five panels are laid out on a 2x3 grid, one per metric; every run
    contributes one labelled line to each panel.
    """
    # (history key, y-axis label, panel title) in grid order.
    panels = [
        ("train_loss", "Loss", "Training Loss"),
        ("val_loss", "Loss", "Validation Loss"),
        ("train_acc", "Accuracy (%)", "Training Accuracy"),
        ("val_acc", "Accuracy (%)", "Validation Accuracy"),
        ("iou", "IoU", "Validation IoU"),
    ]

    plt.figure(figsize=(14, 8))

    for slot, (metric, y_label, title) in enumerate(panels, start=1):
        plt.subplot(2, 3, slot)
        for name, history in results.items():
            plt.plot(history[metric], label=f"{name}")
        plt.xlabel("Epochs")
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()

    plt.tight_layout()
    plt.show()

In [None]:
num_classes = 21

# Per-class weights for the cross-entropy term: class 0 gets weight 0.01, every
# other class weight 1. Built once because it does not change between runs.
class_weights = torch.ones(num_classes)
class_weights[0] = 0.01
class_weights = class_weights.to(device)

results = {}

# Train a fresh model for each epoch budget and keep its per-epoch history.
for n_epochs in [3, 5, 10]:
    print(f"Training with epochs={n_epochs}")

    # Re-seed PyTorch's global RNG so every run starts from the same random
    # state and the runs differ only in how long they train.
    torch.manual_seed(42)

    model = UNet(num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss(ignore_index=255, weight=class_weights)
    results[f"epochs={n_epochs}"] = train_model(
        model, n_epochs, optimizer, criterion
    )

In [None]:
# Plot the histories collected in `results`: one labelled curve per run in each panel.
plot_training_curves(results)

In [None]:
num_classes = 21
epochs = 5

# Per-class weights for the cross-entropy term: class 0 gets weight 0.01, every
# other class weight 1. Built once because it does not change between runs.
class_weights = torch.ones(num_classes)
class_weights[0] = 0.01
class_weights = class_weights.to(device)

results = {}

# Train a fresh model for each learning rate with a fixed epoch budget.
for lr in [1e-4, 1e-3, 1e-2]:
    print(f"\nTraining with LR = {lr}")

    # Re-seed PyTorch's global RNG so every run starts from the same random
    # state and the runs differ only in the learning rate.
    torch.manual_seed(42)

    model = UNet(num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    criterion = nn.CrossEntropyLoss(
        ignore_index=255,
        weight=class_weights
    )

    history = train_model(model, epochs, optimizer, criterion)
    results[f"lr={lr}"] = history

In [None]:
# Plot the histories collected in `results`: one labelled curve per run in each panel.
plot_training_curves(results)

In [None]:
def show_predictions(model, loader, device, n=3, ignore_index=255):
    """Show input, ground-truth mask and predicted mask for the first batch.

    Args:
        model: Network called as ``model(images)`` that returns class logits
            with classes on dim 1. It is switched to eval mode.
        loader: Iterable of ``(images, masks)`` batches; only the first batch
            is used. Masks carry a singleton dim 1 that is squeezed away.
        device: Device the batch is moved to for the forward pass.
        n: Number of samples (rows) to draw. Clamped to the batch size so a
            request larger than the batch no longer raises ``IndexError``.
        ignore_index: Mask value that marks unlabeled pixels; those pixels are
            masked out of the ground-truth panel.
    """
    model.eval()

    images, masks = next(iter(loader))
    images = images.to(device)
    masks = masks.squeeze(1).to(device)

    with torch.no_grad():
        outputs = model(images)
        preds = outputs.argmax(dim=1)

    num_classes = outputs.shape[1]

    images = images.cpu()
    masks = masks.cpu()
    preds = preds.cpu()

    n = min(n, images.size(0))

    plt.figure(figsize=(10, 3*n))

    # Both label panels share one fixed colour scale. With the default
    # auto-scaling each panel is normalised to its own value range (and the
    # ignore label stretches the mask's range), so the same class would get a
    # different colour in the two panels. Note tab20 has only 20 distinct
    # colours, so with more classes the highest ids share a colour.
    label_style = dict(cmap="tab20", vmin=0, vmax=num_classes - 1)

    for i in range(n):
        # Input image: channels-first tensor -> channels-last for imshow.
        plt.subplot(n, 3, 3*i + 1)
        img = images[i].permute(1, 2, 0)
        plt.imshow(img)
        plt.title("Input Image")
        plt.axis("off")

        # Ground truth, with ignored pixels masked so they are not drawn as a class.
        plt.subplot(n, 3, 3*i + 2)
        gt = np.ma.masked_equal(masks[i].numpy(), ignore_index)
        plt.imshow(gt, **label_style)
        plt.title("Ground Truth")
        plt.axis("off")

        # Prediction
        plt.subplot(n, 3, 3*i + 3)
        plt.imshow(preds[i].numpy(), **label_style)
        plt.title("Prediction")
        plt.axis("off")

    plt.tight_layout()
    plt.show()


In [None]:
# Seed PyTorch's global RNG so the final run is repeatable.
torch.manual_seed(42)

best_model = UNet(num_classes).to(device)

optimizer = torch.optim.Adam(best_model.parameters(), lr=0.001)

# Defined here explicitly so this cell does not depend on a variable that only
# exists as a leftover from the body of an earlier experiment loop.
# Class 0 gets weight 0.01, every other class weight 1.
class_weights = torch.ones(num_classes)
class_weights[0] = 0.01
class_weights = class_weights.to(device)

criterion = nn.CrossEntropyLoss(
    ignore_index=255,
    weight=class_weights
)

history = train_model(
    best_model,
    epochs=5,
    optimizer=optimizer,
    criterion=criterion
)

# Qualitative check on the first validation batch.
show_predictions(best_model, val_loader, device, n=3)