In [25]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import random

from networkx.algorithms.shortest_paths.unweighted import predecessor
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from torchvision import datasets, transforms
from torchvision.datasets import SBDataset
from tqdm import tqdm

In [26]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [27]:
class SBDSegmentationDataset(Dataset):
    def __init__(self, base_dataset, img_transform=None, mask_transform=None):
        self.base = base_dataset
        self.img_transform = img_transform
        self.mask_transform = mask_transform

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        img, mask = self.base[idx]

        if self.img_transform:
            img = self.img_transform(img)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return img, mask

In [28]:
img_transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor()
])

mask_transform = transforms.Compose([
    transforms.Resize((256,256),
    interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor()
])

train_base = SBDataset(
    root='./data',
    image_set='train',
    mode='segmentation',
    download=False
)

val_base = SBDataset(
    root='./data',
    image_set='val',
    mode='segmentation',
    download=False
)

train_dataset = SBDSegmentationDataset(train_base, img_transform, mask_transform)
val_dataset = SBDSegmentationDataset(val_base, img_transform, mask_transform)

train_subset = Subset(train_dataset, range(4000))

dataset_size = len(train_subset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size

train_ds, val_ds = random_split(train_subset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

In [29]:
train_loader = DataLoader(
    train_ds,
    batch_size=8,
    shuffle=True,
    pin_memory=True,
    persistent_workers=False
)

val_loader = DataLoader(
    val_ds,
    batch_size=8,
    shuffle=False,
    pin_memory=True,
    persistent_workers=False
)

In [30]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.net(x)

class UNet(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.enc1 = DoubleConv(3, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)

        self.pool = nn.MaxPool2d(2)
        self.bottleneck = DoubleConv(256, 512)

        self.up3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.dec3 = DoubleConv(512, 256)

        self.up2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.dec2 = DoubleConv(256, 128)

        self.up1 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.dec1 = DoubleConv(128, 64)

        self.out = nn.Conv2d(64, num_classes, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))

        b = self.bottleneck(self.pool(e3))

        d3 = self.dec3(torch.cat([self.up3(b), e3], dim=1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        return self.out(d1)

In [31]:
def dice_loss(pred, target, smooth=1e-6):
    pred = torch.softmax(pred, dim=1)

    valid = target != 255
    target = target.clone()
    target[~valid] = 0

    target_one_hot = nn.functional.one_hot(target, num_classes=pred.shape[1]).permute(0, 3, 1, 2).float()

    pred = pred * valid.unsqueeze(1)

    intersection = (pred * target_one_hot).sum(dim=(2, 3))
    union = pred.sum(dim=(2, 3)) + target_one_hot.sum(dim=(2, 3))

    return 1 - (2 * intersection + smooth) / (union + smooth)


In [32]:
def iou(preds, target, num_classes=21, ignore_index=255, smooth=1e-6):
    preds = preds.argmax(dim=1)
    ious = []

    for cls in range(num_classes):
        pred_inds = (preds == cls)
        target_inds = (target == cls)
        valid = target != ignore_index

        intersection = ((pred_inds & target_inds) & valid).sum().float().cpu().item()
        union = ((pred_inds | target_inds) & valid).sum().float().cpu().item()

        if union == 0:
            ious.append(float('nan'))
        else:
            ious.append((intersection + smooth) / (union + smooth))

    ious = [iou for iou in ious if not np.isnan(iou)]
    return sum(ious) / len(ious)

In [33]:
def validate(model, loader, criterion):
    model.eval()
    loss_sum, correct, total = 0, 0, 0
    iou_sum = 0
    batches = 0

    with torch.no_grad():
        for imgs, masks in loader:
            imgs = imgs.to(device)
            masks = masks.squeeze(1).long().to(device)

            outputs = model(imgs)
            loss = criterion(outputs, masks)

            # accuracy
            preds = outputs.argmax(dim=1)
            valid = masks != 255
            correct += (preds[valid] == masks[valid]).sum().item()
            total += valid.sum().item()

            # IoU
            iou_sum += iou(outputs, masks, num_classes=num_classes)
            batches += 1

            loss_sum += loss.item()

    return loss_sum / len(loader), 100 * correct / total, iou_sum / batches

In [34]:
def train_model(model, epochs, optimizer, criterion):

    history = {
        "train_loss": [],
        "val_loss": [],
        "train_acc": [],
        "val_acc": [],
        "iou": []
    }

    for epoch in range(epochs):
        model.train()
        loss_sum, correct, total = 0, 0, 0

        for imgs, masks in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}"):
            imgs = imgs.to(device)
            masks = masks.squeeze(1).long().to(device)

            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, masks) + dice_loss(outputs, masks).mean()
            loss.backward()
            optimizer.step()

            preds = outputs.argmax(dim=1)
            valid = masks != 255

            correct += (preds[valid] == masks[valid]).sum().item()
            total += valid.sum().item()
            loss_sum += loss.item()

        train_loss = loss_sum / len(train_loader)
        train_acc = 100 * correct / total
        val_loss, val_acc, val_iou = validate(model, val_loader, criterion)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)
        history["iou"].append(val_iou)

        print(
            f"Epoch {epoch+1}/{epochs} | "
            f"TrainLoss: {train_loss:.4f} | "
            f"ValLoss: {val_loss:.4f} | "
            f"TrainAcc: {train_acc:.2f}% | "
            f"ValAcc: {val_acc:.2f}% | "
            f"ValIoU: {val_iou:.4f}"
        )

    return history

In [35]:
num_classes = 21

results = {}

for epoch in [5, 10]:
    print(f"Training with epochs={epoch}")
    model = UNet(num_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    class_weights = torch.ones(num_classes)
    class_weights[0] = 0.3
    class_weights = class_weights.to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=255, weight=class_weights)
    results[f"epochs={epoch}"] = train_model(
        model, epoch, optimizer, criterion
    )

Training with epochs=5


Epoch 1/5: 100%|██████████| 400/400 [05:18<00:00,  1.26it/s]


Epoch 1/5 | TrainLoss: 3.1719 | ValLoss: 2.1971 | TrainAcc: 69.99% | ValAcc: 70.33% | ValIoU: 0.0842


Epoch 2/5: 100%|██████████| 400/400 [06:35<00:00,  1.01it/s]


Epoch 2/5 | TrainLoss: 3.1072 | ValLoss: 2.1238 | TrainAcc: 70.46% | ValAcc: 70.33% | ValIoU: 0.0842


Epoch 3/5: 100%|██████████| 400/400 [06:47<00:00,  1.02s/it]


Epoch 3/5 | TrainLoss: 3.0924 | ValLoss: 2.1198 | TrainAcc: 70.47% | ValAcc: 70.33% | ValIoU: 0.0842


Epoch 4/5: 100%|██████████| 400/400 [06:42<00:00,  1.01s/it]


Epoch 4/5 | TrainLoss: 3.0705 | ValLoss: 2.0810 | TrainAcc: 70.46% | ValAcc: 70.33% | ValIoU: 0.0842


Epoch 5/5: 100%|██████████| 400/400 [06:43<00:00,  1.01s/it]


Epoch 5/5 | TrainLoss: 3.0385 | ValLoss: 2.1313 | TrainAcc: 68.19% | ValAcc: 61.73% | ValIoU: 0.0938
Training with epochs=10


Epoch 1/10: 100%|██████████| 400/400 [06:43<00:00,  1.01s/it]


Epoch 1/10 | TrainLoss: 3.1584 | ValLoss: 2.1469 | TrainAcc: 70.05% | ValAcc: 70.33% | ValIoU: 0.0842


Epoch 2/10: 100%|██████████| 400/400 [06:42<00:00,  1.01s/it]


Epoch 2/10 | TrainLoss: 3.1125 | ValLoss: 2.1994 | TrainAcc: 70.46% | ValAcc: 70.33% | ValIoU: 0.0842


Epoch 3/10: 100%|██████████| 400/400 [06:43<00:00,  1.01s/it]


Epoch 3/10 | TrainLoss: 3.0828 | ValLoss: 2.0957 | TrainAcc: 68.63% | ValAcc: 69.32% | ValIoU: 0.0880


Epoch 4/10: 100%|██████████| 400/400 [06:43<00:00,  1.01s/it]


Epoch 4/10 | TrainLoss: 3.0540 | ValLoss: 2.1063 | TrainAcc: 66.68% | ValAcc: 62.45% | ValIoU: 0.0933


Epoch 5/10: 100%|██████████| 400/400 [06:43<00:00,  1.01s/it]


Epoch 5/10 | TrainLoss: 3.0438 | ValLoss: 2.0935 | TrainAcc: 66.40% | ValAcc: 69.82% | ValIoU: 0.0900


Epoch 6/10: 100%|██████████| 400/400 [06:42<00:00,  1.01s/it]


Epoch 6/10 | TrainLoss: 3.0282 | ValLoss: 2.0830 | TrainAcc: 67.31% | ValAcc: 57.33% | ValIoU: 0.0900


Epoch 7/10: 100%|██████████| 400/400 [06:43<00:00,  1.01s/it]


Epoch 7/10 | TrainLoss: 3.0137 | ValLoss: 2.0475 | TrainAcc: 67.74% | ValAcc: 69.15% | ValIoU: 0.0983


Epoch 8/10: 100%|██████████| 400/400 [06:42<00:00,  1.01s/it]


Epoch 8/10 | TrainLoss: 3.0131 | ValLoss: 2.0404 | TrainAcc: 67.17% | ValAcc: 64.12% | ValIoU: 0.0977


Epoch 9/10: 100%|██████████| 400/400 [06:43<00:00,  1.01s/it]


Epoch 9/10 | TrainLoss: 2.9939 | ValLoss: 2.0216 | TrainAcc: 67.05% | ValAcc: 68.61% | ValIoU: 0.1004


Epoch 10/10: 100%|██████████| 400/400 [06:43<00:00,  1.01s/it]


Epoch 10/10 | TrainLoss: 2.9940 | ValLoss: 2.0669 | TrainAcc: 67.27% | ValAcc: 69.29% | ValIoU: 0.0920
