Code based on https://github.com/pytorch/examples/blob/master/mnist/main.py

In this exercise, we are going to implement a [UNet-like](https://arxiv.org/pdf/1505.04597.pdf) architecture for the semantic segmentation task. 
The model is trained on the [Pascal VOC](https://paperswithcode.github.io/torchbench/pascalvoc/) dataset.

Tasks:

    1. Implement the missing pieces in the code.

    2. Check that the given implementation reaches 68% test accuracy after a few epochs.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms.functional import InterpolationMode


In [2]:
class UNetConvolutionStack(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(UNetConvolutionStack, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(),
        )

    def forward(self, x):
        x = self.conv(x)
        return x


In [3]:
class EncoderStack(nn.Module):
    def __init__(self, in_channel, out_channel, first_layer=False):
        super(EncoderStack, self).__init__()
        if first_layer:
            self.down = nn.Sequential(
                UNetConvolutionStack(in_channel, out_channel),
                UNetConvolutionStack(out_channel, out_channel),
            )
        else:
            self.down = nn.Sequential(
                nn.MaxPool2d((2, 2)),
                UNetConvolutionStack(in_channel, out_channel),
                UNetConvolutionStack(out_channel, out_channel),
            )

    def forward(self, x):
        x = self.down(x)
        return x


In [4]:
class DecoderStack(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(DecoderStack, self).__init__()
        self.upsample = nn.ConvTranspose2d(
            in_channel, in_channel, 3, stride=2, padding=1
        )
        self.up = nn.Sequential(
            UNetConvolutionStack(in_channel + out_channel, out_channel),
            UNetConvolutionStack(out_channel, out_channel),
        )

    def forward(self, x, y):
        # TODO: implement skipconnections.
        # hint: x is the output of previous decoder layer,
        # y is the output of corresponding encoder layer.
        # Based on the arguments of the constructor,
        # how should x and y be combined?
        x = self.upsample(x, output_size=y.size())
        x = torch.cat([y, x], dim=1)
        x = self.up(x)
        return x


In [5]:
class UNet(nn.Module):
    def __init__(self, encoder_channels, decoder_channels, num_classes):
        super(UNet, self).__init__()
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.conv = nn.Conv2d(
            decoder_channels[-1], num_classes, kernel_size=3, padding=1
        )

        encoder_sizes = zip(
            range(len(encoder_channels)), encoder_channels, encoder_channels[1:]
        )
        for idx, in_size, out_size in encoder_sizes:
            if idx > 0:
                self.encoder.append(EncoderStack(in_size, out_size))
            else:
                self.encoder.append(EncoderStack(in_size, out_size, first_layer=True))

        decoder_sizes = zip(decoder_channels, decoder_channels[1:])
        for in_size, out_size in decoder_sizes:
            self.decoder.append(DecoderStack(in_size, out_size))

    def forward(self, x):
        # TODO: implement UNet's forward pass.
        # hint: Remeber to store outputs of subsequent
        # encoder layers to use as input to decoer layers!
        # Do not forget about the final convolution.
        encoded_layers = []
        for e in self.encoder:
            x = e(x)
            encoded_layers.append(x)

        del encoded_layers[-1]
        encoded_layers.reverse()

        for y, d in zip(encoded_layers, self.decoder):
            x = d(x, y)

        x = self.conv(x)
        return x


In [6]:
def train(model, device, train_loader, optimizer, epoch, log_interval):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch,
                    batch_idx * len(data),
                    len(train_loader.dataset),
                    100.0 * batch_idx / len(train_loader),
                    loss.item(),
                )
            )


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(
                output, target, reduction="sum"
            ).item()  # sum up batch loss
            pred = output.argmax(
                dim=1, keepdim=True
            )  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    _, _, image_width, image_height = data.size()
    test_loss /= len(test_loader.dataset) * image_width * image_height

    print(
        "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format(
            test_loss,
            correct,
            (len(test_loader.dataset) * image_width * image_height),
            100.0 * correct / (len(test_loader.dataset) * image_width * image_height),
        )
    )


In [7]:
batch_size = 128
test_batch_size = 1000
epochs = 5
lr = 1e-2
use_cuda = True
seed = 1
log_interval = 10

imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
num_classes = 22


In [8]:
use_cuda = use_cuda and torch.cuda.is_available()

torch.manual_seed(seed)
device = torch.device("cuda" if use_cuda else "cpu")

train_kwargs = {"batch_size": batch_size}
test_kwargs = {"batch_size": test_batch_size}
if use_cuda:
    cuda_kwargs = {"num_workers": 1, "pin_memory": True, "shuffle": True}
    train_kwargs.update(cuda_kwargs)
    test_kwargs.update(cuda_kwargs)


In [9]:
def replace_tensor_value_(tensor, a, b):
    tensor[tensor == a] = b
    return tensor


input_resize = transforms.Resize((224, 224))
input_transform = transforms.Compose(
    [
        input_resize,
        transforms.ToTensor(),
        transforms.Normalize(imagenet_mean, imagenet_std),
    ]
)

target_resize = transforms.Resize((224, 224), interpolation=InterpolationMode.NEAREST)
target_transform = transforms.Compose(
    [
        target_resize,
        transforms.PILToTensor(),
        transforms.Lambda(
            lambda x: replace_tensor_value_(x.squeeze(0).long(), 255, 21)
        ),
    ]
)


In [10]:
dataset1 = datasets.VOCSegmentation(
    "../data",
    year="2012",
    image_set="train",
    download=True,
    transform=input_transform,
    target_transform=target_transform,
)
dataset2 = datasets.VOCSegmentation(
    "../data",
    year="2012",
    image_set="val",
    download=True,
    transform=input_transform,
    target_transform=target_transform,
)

train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
test_loader = torch.utils.data.DataLoader(dataset2, **train_kwargs)


Using downloaded and verified file: ../data/VOCtrainval_11-May-2012.tar
Extracting ../data/VOCtrainval_11-May-2012.tar to ../data
Using downloaded and verified file: ../data/VOCtrainval_11-May-2012.tar
Extracting ../data/VOCtrainval_11-May-2012.tar to ../data


In [14]:
model = UNet(
    encoder_channels=[3, 8, 16, 32],
    decoder_channels=[32, 16, 8],
    num_classes=num_classes,
).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch, log_interval)
    test(model, device, test_loader)



Test set: Average loss: -14.5673, Accuracy: 31287547/72705024 (43%)


Test set: Average loss: -28.1806, Accuracy: 22038447/72705024 (30%)


Test set: Average loss: -71.9071, Accuracy: 37214447/72705024 (51%)


Test set: Average loss: -141.6039, Accuracy: 50330809/72705024 (69%)


Test set: Average loss: -233.7585, Accuracy: 50341167/72705024 (69%)

