# Exercise 5: Scene-Dependent Image Segmentation

The goal of this homework is to implement a model that separates foreground and background objects for a specific scene.  
We will use the highway scene from the Change Detection dataset:  
http://jacarini.dinf.usherbrooke.ca/dataset2014#

![input image](highway/input/in001600.jpg "Title") ![gt image](highway/groundtruth/gt001600.png "Title")

## Task 1: Create a custom (Pytorch) dataset


https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
You need to create a class that inherits from **torch.utils.data.Dataset** and implements two methods:
- **def \_\_len\_\_(self)**:  returns the length of the dataset
- **def \_\_getitem\_\_(self, idx)**: given an integer idx returns the data x,y
    - x is the image as a float tensor of shape: $(3,H,W)$ 
    - y is the label image as a mask of shape $(H,W)$, where each pixel contains the label 0 (background) or 1 (foreground). It is recommended to use the type torch.long.
    
**Tips**:
- The first 470 images are not labeled. Just ignore these images. 
- If possible load all images into memory or even directly to GPU to increase speed.
- You can change the resolution to fit your model or your memory
- Add data augmentation to increase the data size and model robustness

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt

import glob

In [2]:
class ChangeDetectionDataset(Dataset):
    """Map-style dataset pairing pre-loaded images with their label masks.

    Args:
        data: indexable collection of input images, aligned with ``labels``.
        labels: indexable collection of label masks.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        # The number of samples is defined by the label collection.
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.data[idx]
        mask = self.labels[idx]
        return image, mask

## Task 2: Create a custom Segmentation Model

- input: a batch of images $(B,3,H,W)$ 
- output: a batch of pixel-wise class predictions $(B,C,H,W)$, where $C=2$

Tips:
- It is recommended to use a Fully-Convolutional Neural Network, because it is flexible with respect to the input and output resolution.
- Use Residual Blocks with convolutional layers.
- Base your model on established segmentation models:
    - U-Net: https://arxiv.org/abs/1505.04597
    - Deeplab: https://arxiv.org/abs/1606.00915

In [3]:
class ResidualBlock(nn.Module):
    """Conv -> (BatchNorm) -> ReLU block with an optional skip connection.

    With the default ``residual=False`` the block computes
    ``relu(bn2(conv(x)))`` (or ``relu(conv(x))`` when ``batch_norm=False``).
    Despite the class name, no skip connection is applied in that mode; the
    default keeps existing models and their state_dict layout unchanged.

    With ``residual=True`` the input is added back before the ReLU:
    ``relu(bn2(conv(x)) + shortcut(x))``. ``shortcut`` is the identity when
    channels and stride are preserved, otherwise a 1x1 projection convolution
    (followed by ``bn1`` when batch norm is enabled).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel size of the main convolution.
        stride: stride of the main convolution (and of the projection).
        padding: zero padding of the main convolution.
        batch_norm: whether to apply batch normalization.
        residual: whether to add the skip connection.

    Raises:
        ValueError: if ``residual=True`` but ``padding`` does not preserve the
            spatial size (``2 * padding != kernel_size - 1``), because the
            shortcut could then not be added to the main path.
    """

    def __init__(self, 
        in_channels, 
        out_channels, 
        kernel_size=3,
        stride=1,
        padding=0,
        batch_norm=True,
        residual=False,
    ):
        super(ResidualBlock, self).__init__()
        if residual and 2 * padding != kernel_size - 1:
            raise ValueError(
                "residual=True requires size-preserving padding "
                "(2 * padding == kernel_size - 1); got "
                f"kernel_size={kernel_size}, padding={padding}"
            )
        self.conv = nn.Conv2d(
            in_channels, 
            out_channels, 
            kernel_size=kernel_size,
            stride=stride, 
            padding=padding
        )

        self.relu = nn.ReLU()
        # bn1 normalizes the projected shortcut; it is created even when unused
        # so the state_dict layout matches models trained without residuals.
        self.bn1 = nn.BatchNorm2d(out_channels) if batch_norm else None
        self.bn2 = nn.BatchNorm2d(out_channels) if batch_norm else None

        self.residual = residual
        # A projection is only needed when the shortcut would not match the
        # main path's channel count or (strided) spatial size.
        self.downsample = None
        if residual and (in_channels != out_channels or stride != 1):
            self.downsample = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                padding=0
            )

    def forward(self, x):
        out = self.conv(x)
        if self.bn2 is not None:
            out = self.bn2(out)

        if self.residual:
            shortcut = x
            if self.downsample is not None:
                shortcut = self.downsample(x)
                if self.bn1 is not None:
                    shortcut = self.bn1(shortcut)
            out = out + shortcut  # skip connection, added before the activation

        return self.relu(out)

def createConvSequential(
    in_channels,
    out_channels,
    kernel_size=3,
    stride=1,
    padding=1,
    batch_norm=True
):
    """Build two chained ResidualBlocks: in_channels -> out_channels -> out_channels.

    Both blocks share the same kernel size, stride, padding and batch-norm
    setting. The first block changes the channel count; the second keeps it.

    Returns:
        An ``nn.Sequential`` containing the two blocks in order.
    """
    shared = (kernel_size, stride, padding, batch_norm)
    entry = ResidualBlock(in_channels, out_channels, *shared)
    refine = ResidualBlock(out_channels, out_channels, *shared)
    return nn.Sequential(entry, refine)


In [18]:
class UNet(nn.Module):
    """U-Net style encoder/decoder for two-class (background/foreground) segmentation.

    Input: a batch of images ``(B, 3, H, W)``. Output: per-pixel class logits
    ``(B, 2, H_out, W_out)``. ``(H_out, W_out)`` is ``output_size`` if given,
    otherwise the spatial size of the input.

    The encoder applies four 2x2 max-pooling steps, so ``H`` and ``W`` should
    be at least 16. Each decoder stage does three things:
    - upsamples by 2 (``self.mode`` interpolation);
    - applies a 2x2 transposed convolution with stride 1, which grows each
      side by one pixel;
    - fuses the matching encoder feature map via ``add_residual``.

    Args:
        output_size: optional fixed ``(H_out, W_out)`` for the logits. The
            default ``None`` matches the input resolution. Pass ``(240, 320)``
            to reproduce the previous hard-coded behaviour.
    """

    def __init__(self, output_size=None):
        super(UNet, self).__init__()
        # channel widths of the input, of each encoder stage, and of the output
        self.inp_ch = 3
        self.ch1 = 64
        self.ch2 = 128
        self.ch3 = 256
        self.ch4 = 512
        self.ch5 = 1024
        self.out_ch = 2
        # downsampling conv-layers
        self.ds1 = createConvSequential(self.inp_ch, self.ch1)
        self.ds2 = createConvSequential(self.ch1, self.ch2)
        self.ds3 = createConvSequential(self.ch2, self.ch3)
        self.ds4 = createConvSequential(self.ch3, self.ch4)
        self.ds5 = createConvSequential(self.ch4, self.ch5)
        # upsampling conv-layers (input = skip channels + upsampled channels)
        self.us1 = createConvSequential(self.ch5, self.ch4)
        self.us2 = createConvSequential(self.ch4, self.ch3)
        self.us3 = createConvSequential(self.ch3, self.ch2)
        self.us4 = createConvSequential(self.ch2, self.ch1)
        # modules for changing resolution / channel sizes
        self.max_pool = nn.MaxPool2d(kernel_size=2)
        self.trans_conv1 = nn.ConvTranspose2d(self.ch5, self.ch4, kernel_size=2)
        self.trans_conv2 = nn.ConvTranspose2d(self.ch4, self.ch3, kernel_size=2)
        self.trans_conv3 = nn.ConvTranspose2d(self.ch3, self.ch2, kernel_size=2)
        self.trans_conv4 = nn.ConvTranspose2d(self.ch2, self.ch1, kernel_size=2)
        self.last_conv = nn.Conv2d(self.ch1, self.out_ch, kernel_size=1)
        self.mode = "bilinear"
        # None -> logits are resampled to the input's spatial size in forward()
        self.output_size = output_size

    def add_residual(self, ds_res, x):
        """Concatenate encoder feature map ``ds_res`` with decoder map ``x``.

        ``ds_res`` is first brought to ``x``'s spatial size. It is
        center-cropped when it is at least as large in both dimensions;
        otherwise it is resized with ``self.mode`` interpolation. The result
        holds the channels of ``ds_res`` followed by those of ``x``.
        """
        target_h, target_w = x.size(2), x.size(3)
        input_h, input_w = ds_res.size(2), ds_res.size(3)

        h_start = (input_h - target_h) // 2
        w_start = (input_w - target_w) // 2

        if h_start < 0 or w_start < 0:
            # skip map is smaller in some dimension, so cropping is impossible: resize it
            skip = F.interpolate(ds_res, size=(target_h, target_w), mode=self.mode)
        else:
            skip = ds_res[:, :, h_start:h_start + target_h, w_start:w_start + target_w]

        return torch.cat((skip, x), dim=1)

    def _up_step(self, x, trans_conv, skip, conv):
        """Run one decoder stage: upsample x, fuse with encoder map skip, refine with conv."""
        x = F.interpolate(x, scale_factor=2, mode=self.mode)
        x = trans_conv(x)
        x = self.add_residual(skip, x)
        return conv(x)

    def forward(self, x):
        if self.output_size is not None:
            out_size = tuple(self.output_size)
        else:
            out_size = tuple(x.shape[-2:])

        # downsampling path: keep each stage's output for the skip connections
        skips = []
        for down in (self.ds1, self.ds2, self.ds3, self.ds4):
            x = down(x)
            skips.append(x)
            x = self.max_pool(x)
        x = self.ds5(x)

        # upsampling path, consuming the deepest skip connection first
        decoder = (
            (self.trans_conv1, self.us1),
            (self.trans_conv2, self.us2),
            (self.trans_conv3, self.us3),
            (self.trans_conv4, self.us4),
        )
        for (trans_conv, conv), skip in zip(decoder, reversed(skips)):
            x = self._up_step(x, trans_conv, skip, conv)

        x = self.last_conv(x)  # one logit map per class (background, foreground)
        # The decoder's resolution generally differs from the requested one
        # (each stage adds a pixel), so resample the logits to out_size.
        x = F.interpolate(x, size=out_size, mode="bilinear")

        return x

## Task 3: Create a training loop
- split data into training and test data, e.g. 80% training data and 20% test data using your custom dataset.
- Create a Dataloader for your custom datasets 
- Define a training loop for a single epoch:
    - forward pass
    - Loss function, e.g. cross entropy
    - optimizer 
    - backward pass
    - logging
- Define validation loop:
    - forward pass
    - extract binary labels, e.g. threshold or argmax for each pixel.
    - compute evaluation metrics: Accuracy, Precision, Recall and Intersection over Union for each image

In [19]:
def train_model(
    model, 
    device,
    train_loader,
    validation_loader,
    train_func,
    validation_func,
    lr=0.001, 
    momentum=0,
    step_size=5, 
    gamma=0.1, 
    epochs=10,
    **additional_params
):
    """Train ``model`` with SGD + StepLR, validating after every epoch.

    Each epoch calls ``train_func(model, device, train_loader, optimizer,
    epoch, **train_params)``, then ``validation_func(model, device,
    validation_loader, epoch, **validation_params)``, then steps the scheduler.
    Epochs are numbered from 1.

    Args:
        lr, momentum: SGD hyper-parameters.
        step_size, gamma: StepLR schedule (multiply lr by ``gamma`` every
            ``step_size`` epochs).
        epochs: number of epochs to run.
        **additional_params: only ``train_params`` and ``validation_params``
            (dicts of extra keyword arguments for the two callbacks) are
            accepted.

    Returns:
        ``{"train": [...], "validation": [...]}`` with the per-epoch return
        values of ``train_func`` and ``validation_func``, e.g. for plotting.

    Raises:
        TypeError: on any other keyword argument, so typos are not silently
            ignored.
    """
    train_params = additional_params.pop("train_params", {})
    validation_params = additional_params.pop("validation_params", {})
    if additional_params:
        raise TypeError(
            "train_model() got unexpected keyword argument(s): "
            + ", ".join(sorted(additional_params))
        )

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    history = {"train": [], "validation": []}
    for epoch in range(1, epochs + 1):
        history["train"].append(
            train_func(model, device, train_loader, optimizer, epoch, **train_params)
        )
        history["validation"].append(
            validation_func(model, device, validation_loader, epoch, **validation_params)
        )
        scheduler.step()
    return history

In [27]:
def train_segmentation_model(
    model, 
    device, 
    train_loader, 
    optimizer, 
    epoch
):
    """Train ``model`` for one epoch with pixel-wise cross-entropy.

    Args:
        model: network mapping ``data`` to class logits accepted by
            ``F.cross_entropy`` together with ``target``.
        device: device the batches are moved to.
        train_loader: loader yielding ``(data, target)`` batches; its
            ``dataset`` length is used for progress output.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only for logging.

    Returns:
        The epoch's mean training loss. It is each batch's mean cross-entropy,
        weighted by batch size, or ``nan`` if the loader yields no batches.
    """
    model.train()
    total_data = len(train_loader.dataset)
    loss_sum = 0.0
    seen = 0
    for batch_idx, (data, target) in enumerate(train_loader):  # target: ground-truth mask
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        loss_sum += batch_loss * len(data)
        seen += len(data)

        if batch_idx % 10 == 0:
            progress = 100. * seen / total_data
            print("Train Epoch {}: [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                epoch, seen, total_data, progress, batch_loss
            ), end="\r")

    return loss_sum / seen if seen else float("nan")


def validate_segmentation_model(
    model, 
    device, 
    test_loader, 
    epoch
):
    """Evaluate a binary segmentation model and print/return pixel-level metrics.

    Predictions are the per-pixel argmax over the class dimension. Class 1 is
    treated as foreground for precision, recall and IoU. All metrics are
    aggregated over every pixel of every batch, not averaged per image.

    Args:
        model: network returning logits of shape ``(B, C, H, W)``.
        device: device the batches are moved to.
        test_loader: loader yielding ``(data, target)`` with integer masks
            ``(B, H, W)``.
        epoch: epoch number, used only for logging.

    Returns:
        Dict with keys ``loss`` (mean cross-entropy per pixel), ``accuracy``,
        ``precision``, ``recall`` and ``iou``. A ratio whose denominator is 0
        is reported as 0.0.
    """
    model.eval()
    loss_sum = 0.
    tp = fp = fn = tn = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss_sum += F.cross_entropy(output, target, reduction="sum").item()  # summed over pixels
            pred_fg = output.argmax(dim=1) == 1
            true_fg = target == 1
            tp += (pred_fg & true_fg).sum().item()
            fp += (pred_fg & ~true_fg).sum().item()
            fn += (~pred_fg & true_fg).sum().item()
            tn += (~pred_fg & ~true_fg).sum().item()

    def _ratio(num, den):
        return num / den if den else 0.0

    total_pixels = tp + fp + fn + tn
    correct = tp + tn
    metrics = {
        "loss": _ratio(loss_sum, total_pixels),
        "accuracy": _ratio(correct, total_pixels),
        "precision": _ratio(tp, tp + fp),
        "recall": _ratio(tp, tp + fn),
        "iou": _ratio(tp, tp + fp + fn),
    }
    print("Test Set (Epoch {}): Average Loss: {:.4f}\tAccuracy: {}/{} ({:.0f}%)"
          "\tPrecision: {:.4f}\tRecall: {:.4f}\tIoU: {:.4f}{}".format(
        epoch, metrics["loss"], correct, total_pixels, 100. * metrics["accuracy"],
        metrics["precision"], metrics["recall"], metrics["iou"], 80 * " "
    ))
    return metrics

In [28]:
# load dataset
# The first 470 frames have no usable ground truth (see Task 1 tips) -> skip them.
NUM_UNLABELED_FRAMES = 470
img_paths = sorted(glob.glob("highway/input/*.jpg"))[NUM_UNLABELED_FRAMES:]
img_label_paths = sorted(glob.glob("highway/groundtruth/*.png"))[NUM_UNLABELED_FRAMES:]
if len(img_paths) != len(img_label_paths):
    raise ValueError(
        "found {} input images but {} ground-truth masks".format(len(img_paths), len(img_label_paths))
    )

# Scale pixel intensities from [0, 255] to [0, 1].
imgs = [read_image(img_path).float() / 255.0 for img_path in img_paths]
# CDnet ground truth uses grey levels (0 static, 50 shadow, 85 outside ROI,
# 170 unknown motion, 255 motion). Only 255 is foreground; a plain `.bool()`
# would also mark shadows and unknown regions as foreground.
img_labels = [
    (read_image(img_label_path).squeeze(0) == 255).long()
    for img_label_path in img_label_paths
]

# split dataset (80/20)
# NOTE(review): frames come from one video, so a random split puts near-identical
# neighbouring frames in both sets; validation scores are likely optimistic.
X_train, X_validation, y_train, y_validation = train_test_split(imgs, img_labels, test_size=0.2, random_state=42)

In [29]:
# hyperparameters
batch_size = 8
momentum = 0.99

# create training and validation loaders
train_dataset = ChangeDetectionDataset(X_train, y_train)
validation_dataset = ChangeDetectionDataset(X_validation, y_validation)
# reshuffle the training set every epoch so SGD does not see batches in a fixed order
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size)

# run training; keep whatever train_model returns (per-epoch results) for the report
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet_model = UNet().to(device)
history = train_model(
    unet_model,
    device,
    train_loader,
    validation_loader,
    train_segmentation_model,
    validate_segmentation_model,
    momentum=momentum
)



KeyboardInterrupt: 

## Task 4: Small Report of your model and training
- visualize training and test error over each epoch
- report the evaluation metrics of the final model