# Dense Prediction
---
In this part, you will study the problem of semantic segmentation. The goal of this assignment is to study, implement, and compare the components of dense prediction models: **data augmentation**, **backbones**, **classifiers**, and **losses**.

This assignment requires training multiple neural networks, so using a **GPU** accelerator is advised.

## Dataset

We will use a simplified version of the [MasonryWallAnalysis](http://mplab.sztaki.hu/geocomp/masonryWallAnalysis) dataset.

## Part 1. Code


### `dataset`
**TODO: implement and apply data augmentations**

You'll need to study a popular augmentation library, [Albumentations](https://albumentations.ai/), and implement the requested augmentations. Remember that geometric augmentations must be applied to both the image and its mask at the same time; Albumentations has [native support](https://albumentations.ai/docs/getting_started/mask_augmentation/) for that.
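To see why the image and the mask must share the same geometric parameters, here is a minimal sketch in plain NumPy (deliberately not Albumentations itself). The helper `paired_random_crop_flip` and the toy data are hypothetical and only illustrate the principle: the crop offsets and the flip decision are sampled once and applied to both arrays.

```python
import numpy as np


def paired_random_crop_flip(image, mask, size, rng):
    """Apply one random crop and one random horizontal flip to both arrays."""
    h, w = mask.shape
    # Sample the crop position once, so image and mask stay aligned.
    top = int(rng.integers(0, h - size + 1))
    left = int(rng.integers(0, w - size + 1))
    image = image[top:top + size, left:left + size]
    mask = mask[top:top + size, left:left + size]
    # Sample the flip decision once and reuse it for both arrays.
    if rng.random() < 0.5:
        image = image[:, ::-1]
        mask = mask[:, ::-1]
    return image, mask


rng = np.random.default_rng(0)
# Toy data: the image simply encodes the mask, which makes alignment checkable.
mask = (np.arange(64).reshape(8, 8) % 3 == 0).astype(np.uint8)
image = np.stack([mask * 255] * 3, axis=-1)

img_aug, mask_aug = paired_random_crop_flip(image, mask, size=4, rng=rng)
print(img_aug.shape, mask_aug.shape)
```

Color augmentations, in contrast, should touch only the image, since changing pixel intensities does not move any object.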

In [1]:
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from torchvision.transforms import ToTensor
import os
from PIL import Image
import numpy as np
import torch
import cv2

MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

class UnNormalize:
    
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

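    def __call__(self, tensor):
        """
        Args:
            tensor (Tensor): Normalized image of size (C, H, W) or batch of size (B, C, H, W).
        Returns:
            Tensor: Un-normalized image (the input is modified in place).
        """
        # Reshape the statistics to (C, 1, 1) so they broadcast over H and W
        # (and over the batch dimension, if present).
        mean = torch.as_tensor(self.mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        std = torch.as_tensor(self.std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
        # Inverse of the normalization t.sub_(m).div_(s)
        tensor.mul_(std).add_(mean)
        return tensor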

class FloodNet(Dataset):
    """
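    Label semantics listed for the eight-class setting:
    0: Background, 1: Building, 2: Road, 3: Water, 4: Tree, 5: Vehicle, 6: Pool, 7: Grass

    In this assignment the masks are clipped to {0, 1} in __getitem__,
    so only two classes are used (num_classes = 2).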
    """
    def __init__(
        self,
        data_path: str,
        phase: str,
        augment: bool,
        img_size: int,
    ):
        self.num_classes = 2
        self.data_path = data_path
        self.phase = phase
        self.augment = augment
        self.img_size = img_size
        
        if phase == 'test':
            self.phase = f'{self.phase}/Test Set (1)'

        self.items = [
            filename.split('.')[0]
            for filename in os.listdir(f'{self.data_path}/{self.phase}/images')
            if len(filename.split('.')[0]) > 0
        ]
        
        # TODO: implement augmentations (3.5 points)
        if augment:
            # TODO:
            # Random resize
            # Random crop (within image borders, output size = img_size)
            # Random rotation
            # Random horizontal and vertical Flip
            # Random color augmentation
            self.transform = A.Compose([
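                A.RandomScale(
                    # Note: Albumentations adds 1 to scale_limit, so (0.5, 1.5)
                    # samples a scale factor from [1.5, 2.5] (upscaling only).
                    scale_limit=(0.5, 1.5),
                    p=0.5,
                ),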
                A.RandomCrop(
                    width=self.img_size,
                    height=self.img_size,
                    p=1.,
                ),
                A.RandomRotate90(
                    p=0.5,
                ),
                A.VerticalFlip(
                    p=0.5,
                ),
                A.HorizontalFlip(
                    p=0.5,
                ),
                A.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    p=0.5,
                ),
                A.Blur(
                    p=0.5,
                ),
                A.Normalize(
                    mean=MEAN,
                    std=STD,
                ),
            ])

        else:
            # TODO: random crop to img_size
            self.transform = A.Compose([
                A.RandomCrop(
                    width=self.img_size,
                    height=self.img_size,
                    p=1.,
                ),
                A.Normalize(
                    mean=MEAN,
                    std=STD,
                ),
            ])
            
        self.base_tf = A.Compose([
            A.Resize(
                width=self.img_size,
                height=self.img_size,
                p=1.,
            ),
            A.Normalize(
                mean=MEAN,
                std=STD,
            ),
        ])
        
        self.to_tensor = ToTensor()

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        image = np.asarray(Image.open(f'{self.data_path}/{self.phase}/images/{self.items[index]}.png'))
        mask = np.asarray(Image.open(f'{self.data_path}/{self.phase}/labels/{self.items[index]}.png')).clip(0, 1)
        
        if self.phase == 'train':
            # TODO: apply transform to both image and mask (0.5 points)

            transformed = self.transform(
                image=image,
                mask=mask,
            )

            image = transformed['image']
            mask = transformed['mask']
        else:
            
            transformed = self.base_tf(
                image=image,
                mask=mask,
            )

            image = transformed['image']
            mask = transformed['mask']
            
        image = self.to_tensor(image.copy())
        mask = torch.from_numpy(mask.copy()).long()

        if self.phase == 'train':
            assert isinstance(image, torch.FloatTensor) and image.shape == (3, self.img_size, self.img_size)
            assert isinstance(mask, torch.LongTensor) and mask.shape == (self.img_size, self.img_size)

        return image, mask

### `model`
**TODO: Implement the required models.**

Typically, segmentation networks consist of an encoder and a decoder. Below is a diagram of the popular DeepLab v3 architecture:

<img src="https://i.imgur.com/cdlkxvp.png" />

The encoder consists of a convolutional backbone, typically making extensive use of dilated (atrous) convolutions, and a head that further enlarges the receptive field. As you can see, the general idea for encoders is to have as large a receptive field as possible.
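The effect of dilation on the receptive field can be checked with a little arithmetic. The sketch below uses the standard recurrence for stacked convolutions; the helper name `receptive_field` and the two layer stacks are illustrative assumptions, not part of the assignment code.

```python
def receptive_field(layers):
    """Receptive field of stacked convs given (kernel, stride, dilation) tuples."""
    rf, jump = 1, 1  # jump = distance between adjacent outputs, in input pixels
    for kernel, stride, dilation in layers:
        # A dilated kernel covers dilation * (kernel - 1) + 1 input positions.
        effective_kernel = dilation * (kernel - 1) + 1
        rf += (effective_kernel - 1) * jump
        jump *= stride
    return rf


plain = [(3, 1, 1), (3, 1, 1), (3, 1, 1)]   # three ordinary 3x3 convs
atrous = [(3, 1, 1), (3, 1, 2), (3, 1, 4)]  # same depth, dilations 1, 2, 4

print(receptive_field(plain))   # 7
print(receptive_field(atrous))  # 15
```

Both stacks have the same number of parameters, but the dilated one sees roughly twice as far in each direction.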

The decoder upsamples either with convolutions (as in the scheme above, or in UNets) or by simply interpolating the outputs of the encoder.
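The simplest of these decoders, plain interpolation, could look like the sketch below. The tensor `coarse_logits` is a hypothetical stand-in for an encoder output at 1/8 of the input resolution; the shapes are chosen only for illustration.

```python
import torch
from torch.nn import functional as F

# Hypothetical encoder output: logits for 2 classes at 1/8 of a 64x64 input.
coarse_logits = torch.randn(1, 2, 8, 8)

# Bilinear interpolation back to the input resolution (no learned parameters).
logits = F.interpolate(
    coarse_logits,
    size=(64, 64),
    mode='bilinear',
    align_corners=True,
)

# Per-pixel class indices, shape (1, 64, 64).
prediction = logits.argmax(dim=1)
print(logits.shape, prediction.shape)
```

A learned decoder (for example transposed convolutions with skip connections, as in UNet) can usually recover finer boundaries, at the cost of extra parameters and compute.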

In this assignment, you will need to implement **UNet** and **DeepLab** models. An example UNet looks like this:

<img src="https://i.imgur.com/uVdcE4e.png" />

For the **DeepLab** model we will use three backbone variants: **ResNet18**, **VGG11 (with BatchNorm)**, and **MobileNet v3 (small)**. Use `torchvision.models` to obtain pre-trained versions of these backbones and simply extract their convolutional parts. To familiarize yourself with the **MobileNet v3** model, follow this [link](https://paperswithcode.com/paper/searching-for-mobilenetv3).

We will also use an **Atrous Spatial Pyramid Pooling (ASPP)** head. Its scheme is shown in the DeepLab v3 architecture above. ASPP is one of the blocks that greatly increases the receptive field of the model, and hence boosts the model's performance. For more details, refer to this [link](https://paperswithcode.com/method/aspp).
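One practical detail of the parallel ASPP branches is that they must all keep the spatial size of the feature map so that their outputs can be concatenated. A small check with the standard convolution output-size formula shows that, for a 3x3 kernel, setting `padding = dilation` achieves this. The helper `conv_output_size` and the 16x16 feature map are assumptions for illustration; the rates 12, 24, 36 are the ones used in the code of this assignment.

```python
def conv_output_size(size, kernel, padding, dilation, stride=1):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


feature_size = 16  # e.g. a 512x512 input after 32x downsampling
for rate in (12, 24, 36):
    out = conv_output_size(feature_size, kernel=3, padding=rate, dilation=rate)
    extent = rate * (3 - 1) + 1  # input span covered by the dilated 3x3 kernel
    print(f'rate={rate}: output size {out}, kernel extent {extent}')
```

Note that on such a small feature map the kernel extent exceeds the map itself for all three rates, so many kernel taps land in the zero padding. This is one reason the image-pooling branch is useful for capturing global context.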

In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import models

class DoubleBlock(nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        middle_channels: int=None,
        dropout: bool=False,
        p: float=0.2,
    ) -> None:

        super().__init__()

        if middle_channels is None:
            middle_channels = out_channels

        self.first_step = nn.Sequential(
            nn.Dropout(p) if dropout else nn.Identity(),
            nn.Conv2d(
                in_channels,
                middle_channels,
                kernel_size=(3, 3),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(middle_channels),
            nn.ReLU(inplace=True),
        )
        self.second_step = nn.Sequential(
            nn.Conv2d(
                middle_channels,
                out_channels,
                kernel_size=(3, 3),
                padding=(1, 1),
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:

        x = self.first_step(x)
        x = self.second_step(x)

        return x

class Encoder(nn.Module):

    def __init__(
        self,
        channels: list,
        max_channels: int,
    ) -> None:

        super().__init__()
        
        # The last block is the bottleneck and uses max_channels in the middle.
        # There are len(channels) - 1 blocks, so its index is len(channels) - 2.
        last_block = len(channels) - 2

        self.blocks = nn.ModuleList([
            DoubleBlock(
                in_channels=channels[i],
                out_channels=channels[i + 1],
                middle_channels=max_channels if i == last_block else channels[i + 1],
            )
            for i in range(len(channels) - 1)
        ])

        self.downsample = nn.MaxPool2d(
            kernel_size=(2, 2),
        )
    
    def forward(
        self,
        x: torch.Tensor,
    ) -> list:

        features = []

        for block in self.blocks:

            x = block(x)
            features.append(x)
            x = self.downsample(x)

        return features

class Decoder(nn.Module):

    def __init__(
        self,
        channels: list,
        p: float=0.2,
    ) -> None:

        super().__init__()

        self.channels = channels

        self.blocks = nn.ModuleList([
            DoubleBlock(
                in_channels=self.channels[i],
                out_channels=self.channels[i + 1],
                dropout=True,
                p=p,
            )
            for i in range(len(self.channels) - 1)
        ])

        self.upsample = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=self.channels[i],
                out_channels=self.channels[i + 1],
                kernel_size=(2, 2),
                stride=(2, 2),
            )
            for i in range(len(self.channels) - 1)
        ])
    
    def forward(
        self,
        x: torch.Tensor,
        features: list,
    ) -> torch.Tensor:
        
        for i in range(len(self.channels) - 1):

            x = self.upsample[i](x)

            _, _, H, W = x.shape
            _, _, H_feat, W_feat = features[i].shape
            
            if H != H_feat or W != W_feat:

                features[i] = torch.nn.functional.interpolate(
                    input=features[i],
                    size=(H, W),
                    mode='bilinear',
                    align_corners=True,
                )

            x = torch.cat(
                tensors=[x, features[i]],
                dim=1,
            )

            x = self.blocks[i](x)

        return x


class UNet(nn.Module):
    """
    TODO: 8 points

    A standard UNet network (with padding in convs).

    For reference, see the scheme in materials/unet.png
    - Use batch norm between conv and relu
    - Use max pooling for downsampling
    - Use conv transpose with kernel size = 3, stride = 2, padding = 1, and output padding = 1 for upsampling
    - Use 0.5 dropout after concat

    Args:
      - num_classes: number of output classes
      - min_channels: minimum number of channels in conv layers
      - max_channels: number of channels in the bottleneck block
      - num_down_blocks: number of blocks which end with downsampling

    The full architecture includes downsampling blocks, a bottleneck block and upsampling blocks

    You also need to account for inputs whose size is not divisible by 2**num_down_blocks:
    interpolate them before feeding into the blocks to the nearest size divisible by 2**num_down_blocks,
    and interpolate output logits back to the original shape
    """
    def __init__(
        self, 
        num_classes,
        min_channels: int=32,
        max_channels: int=512, 
        num_down_blocks: int=5,
        input_channels: int=3,
    ) -> None:

        super(UNet, self).__init__()
        
        self.num_classes = num_classes

        channels = self.make_channels(
            min_channels=min_channels,
            num_down_blocks=num_down_blocks,
        )
        
        self.decoder = Decoder(
            channels=channels[::-1],
        )
        self.encoder = Encoder(
            channels=[input_channels, ] + channels,
            max_channels=max_channels,
        )
        self.head = nn.Conv2d(
            in_channels=channels[0],
            out_channels=num_classes,
            kernel_size=(1, 1),
        )

        self.init_weights()

    @staticmethod
    def make_channels(
        min_channels: int,
        num_down_blocks: int,
    ) -> list:

        output = [
            min_channels*(2**i)
            for i in range(num_down_blocks)
        ]

        return output

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:

        B, C, H, W = x.shape
        
        features = self.encoder(x)[::-1]

        x, last = features[0], features[1:]

        out = self.decoder(
            x=x,
            features=last,
        )

        logits = self.head(out)
        
        _, _, H_feat, W_feat = logits.shape
            
        if H != H_feat or W != W_feat:

            logits = torch.nn.functional.interpolate(
                input=logits,
                size=(H, W),
                mode='bilinear',
                align_corners=True,
            )

        assert logits.shape == (B, self.num_classes, H, W), \
            f'Wrong shape of the logits. Got: {logits.shape}, expected: {(B, self.num_classes, H, W)}'
        
        return logits

    def init_weights(
        self,
    ) -> None:

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(
                    m.weight,
                    mode='fan_out',
                    nonlinearity='relu',
                )
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)


class DeepLab(nn.Module):
    """
    TODO: 6 points

    (simplified) DeepLab segmentation network.
    
    Args:
      - backbone: ['resnet18', 'vgg11_bn', 'mobilenet_v3_small'],
      - aspp: use aspp module
      - num classes: num output classes

    During forward pass:
      - Pass inputs through the backbone to obtain features
      - Apply ASPP (if needed)
      - Apply head
      - Upsample logits back to the shape of the inputs
    """
    def __init__(
        self,
        backbone: str,
        aspp: bool,
        num_classes: int,
    ) -> None:
        
        super(DeepLab, self).__init__()
        
        self.backbone = backbone
        self.num_classes = num_classes
        
        self.backbone, self.out_features = self.get_features(
            backbone_name=backbone,
            freeze=True,
            pretrained=True,
        )

        if aspp:
            self.aspp = ASPP(self.out_features, 256, [12, 24, 36])
        else:
            self.aspp = None

        self.head = DeepLabHead(self.out_features, num_classes)
    
    @staticmethod
    def get_features(
        backbone_name: str,
        freeze: bool=False,
        pretrained: bool=False,
    ) -> tuple:
        
        backbone = getattr(models, backbone_name)(
            pretrained=pretrained,
        )
        
        for param in backbone.parameters():
            param.requires_grad = not freeze
            
        blocks = list(backbone.children())
        
        convs, avg, clf = blocks[:-2], blocks[-2], blocks[-1]
        
        if isinstance(clf, nn.Sequential):
            clf = clf[0]
        
        features = torch.nn.Sequential(*convs)
        
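        # The classifier's input size is channels * pooled_height * pooled_width,
        # so dividing by the pooled area recovers the number of feature channels.
        if isinstance(avg.output_size, (tuple, list)):
            divisor = avg.output_size[0] * avg.output_size[1]
        else:
            divisor = avg.output_size * avg.output_size
            
        output_features = clf.in_features // divisor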
        
        return features, output_features
        
    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        
        B, C, H, W = x.shape
        
        features = self.backbone(x)
        
        if self.aspp is not None:
            features = self.aspp(features)
            
        logits = self.head(features)
        
        _, _, H_out, W_out = logits.shape
        
        if H != H_out or W != W_out:

            logits = torch.nn.functional.interpolate(
                input=logits,
                size=(H, W),
                mode='bilinear',
                align_corners=True,
            )

        assert logits.shape == (B, self.num_classes, H, W), 'Wrong shape of the logits'
        return logits


class DeepLabHead(nn.Sequential):
    
    def __init__(self, in_channels, num_classes):
        
        super(DeepLabHead, self).__init__(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels, num_classes, 1),
        )

class ASPPBlock(nn.Sequential):
    
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        padding: int,
        dilation: int,
    ) -> None:
        
        super(ASPPBlock, self).__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

class ASPP(nn.Module):
    """
    TODO: 8 points

    Atrous Spatial Pyramid Pooling module
    with given atrous_rates and out_channels for each head
    Description: https://paperswithcode.com/method/aspp
    
    Detailed scheme: materials/deeplabv3.png
      - "Rates" are defined by atrous_rates
      - "Conv" denotes a Conv-BN-ReLU block
      - "Image pooling" denotes a global average pooling, followed by a 1x1 "conv" block and bilinear upsampling
      - The last layer of ASPP block should be Dropout with p = 0.5

    Args:
      - in_channels: number of input and output channels
      - num_channels: number of output channels in each intermediate "conv" block
      - atrous_rates: a list with dilation values
    """
    def __init__(self, in_channels, num_channels, atrous_rates):
        
        super(ASPP, self).__init__()
        
        self.blocks = nn.ModuleList([
            ASPPBlock(
                in_channels=in_channels,
                out_channels=num_channels,
                kernel_size=(3, 3),
                padding=atrous_rate,
                dilation=atrous_rate,
            )
            for atrous_rate in atrous_rates
        ])
        
        self.pooling = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, num_channels, 1, stride=1, bias=False),
            nn.BatchNorm2d(num_channels),
            nn.ReLU(),
        )
        
        self.conv1x1 = nn.Sequential(
            nn.Conv2d(num_channels*4, in_channels, 1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.PReLU(in_channels),
            nn.Dropout(0.5),
        )
        
    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        # TODO: forward pass through the ASPP module
        
        features = [
            block(x)
            for block in self.blocks
        ]
        
        features.append(
            torch.nn.functional.interpolate(
                input=self.pooling(x),
                size=features[-1].shape[2:],
                mode='bilinear',
                align_corners=True,
            )
        )
        
        out = torch.cat(
            tensors=features,
            dim=1,
        )
        
        res = self.conv1x1(out)
            
        assert res.shape[1] == x.shape[1], 'Wrong number of output channels'
        assert res.shape[2] == x.shape[2] and res.shape[3] == x.shape[3], 'Wrong spatial size'
        return res

### `loss`
**TODO: implement validation metrics.**

For validation, we will use three metrics. 
- Mean intersection over union: **mIoU**,
- Mean class accuracy: **classAcc**,
- Accuracy: **Acc**.

To calculate **IoU**, use this formula for binary segmentation masks for each class, and then average w.r.t. all classes:

$$ \text{IoU} = \frac{ \text{area of intersection} }{ \text{area of union} } = \frac{ \| \hat{m} \cap m  \| }{ \| \hat{m} \cup m \| }, \quad \text{$\hat{m}$ — predicted binary mask},\ \text{$m$ — target binary mask}.$$

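For mean class accuracy (**classAcc**, i.e. mean per-class recall, **mRecall**), compute the recall of each class with the following formula, and then average over all classes:

$$
    \text{Recall} = \frac{ \| \hat{m} \cap m \| }{ \| m \| }
$$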

Finally, **accuracy** is the fraction of correctly classified pixels in the image.
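The formulas above can be checked by hand on a tiny example. The sketch below treats each binary mask as a set of pixel coordinates, so the set operators map directly onto the intersection and union in the formulas. The two 2x4 masks and the helper names are made up for illustration.

```python
pred = [[0, 0, 1, 1],
        [0, 1, 1, 1]]
target = [[0, 0, 0, 1],
          [0, 0, 1, 1]]


def pixels_of(mask, cls):
    """Set of (row, col) coordinates where the mask equals cls."""
    return {(r, c) for r, row in enumerate(mask) for c, v in enumerate(row) if v == cls}


def class_metrics(pred, target, cls):
    """IoU and recall of one class (assumes the class is present in target)."""
    p, t = pixels_of(pred, cls), pixels_of(target, cls)
    return len(p & t) / len(p | t), len(p & t) / len(t)


per_class = [class_metrics(pred, target, cls) for cls in (0, 1)]
mean_iou = sum(iou for iou, _ in per_class) / len(per_class)
mean_recall = sum(rec for _, rec in per_class) / len(per_class)

# Accuracy: fraction of pixels where prediction and target agree.
flat = [(p, t) for pr, tr in zip(pred, target) for p, t in zip(pr, tr)]
accuracy = sum(p == t for p, t in flat) / len(flat)

print(mean_iou, mean_recall, accuracy)  # 0.6 0.8 0.75
```

Here both classes have IoU 3/5, while recall is 3/5 for class 0 and 3/3 for class 1, so the three metrics already disagree on an eight-pixel image.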

Generally, we want our models to maximize accuracy, since high accuracy means few mistakes. However, most segmentation problems have imbalanced classes, so models tend to underfit the rare classes. We therefore also need to measure the mean performance of the model across all classes (mean IoU or mean class accuracy). In practice, these metrics, rather than plain accuracy, are the go-to benchmarks for segmentation models.
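A short back-of-the-envelope calculation illustrates the imbalance problem. Assume a hypothetical image with 1000 pixels, 50 of which belong to the rare class, and a degenerate model that predicts background everywhere:

```python
total_pixels = 1000
rare_pixels = 50
background_pixels = total_pixels - rare_pixels

# The model predicts "background" for every pixel.
accuracy = background_pixels / total_pixels

# Background: intersection = 950 pixels, union = all 1000 pixels.
iou_background = background_pixels / total_pixels
recall_background = background_pixels / background_pixels

# Rare class: nothing is predicted, so the intersection is empty.
iou_rare = 0 / rare_pixels
recall_rare = 0 / rare_pixels

mean_iou = (iou_background + iou_rare) / 2
mean_recall = (recall_background + recall_rare) / 2

print(accuracy, mean_iou, mean_recall)  # 0.95 0.475 0.5
```

Accuracy looks excellent even though the model never finds the rare class, while mean IoU and mean recall expose the failure.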

In [3]:
def calc_val_data(preds, masks, num_classes):
    preds = torch.argmax(preds, dim=1)
    
    intersection = torch.stack(
        tensors=[
            ((masks == i) & (preds == i)).sum((1, 2))
            for i in range(num_classes)
        ],
        dim=-1,
    ) #calc intersection for each class
    union = torch.stack(
        tensors=[
            ((masks == i) | (preds == i)).sum((1, 2))
            for i in range(num_classes)
        ],
        dim=-1,
    ) #calc union for each class

    target = torch.stack(
        tensors=[
            (masks == i).sum((1, 2))
            for i in range(num_classes)
        ],
        dim=-1,
    ) #calc number of pixels in groundtruth mask per class
    # Output shapes: B x num_classes

    assert isinstance(intersection, torch.Tensor), 'Output should be a tensor'
    assert isinstance(union, torch.Tensor), 'Output should be a tensor'
    assert isinstance(target, torch.Tensor), 'Output should be a tensor'

    assert intersection.shape == union.shape == target.shape, 'Wrong output shape'
    assert union.shape[0] == masks.shape[0] and union.shape[1] == num_classes, 'Wrong output shape'

    return intersection, union, target

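def calc_val_loss(intersection, union, target, eps = 1e-7):
    # All inputs have shape B x num_classes (per-image, per-class pixel counts).
    # `eps` is unused: 0 / 0 = NaN for classes with an empty union (IoU) or an
    # empty target (recall), and torch.nanmean skips those entries.
    iou = torch.nanmean(
        intersection / union,
        dim=1,
    )
    rec = torch.nanmean(
        intersection / target,
        dim=1,
    )
    # Each misclassified pixel is counted twice in (union - intersection):
    # once for its true class and once for its predicted class.
    acc = 1 - (union - intersection).sum(dim=1) / (2*target.sum(dim=1))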
    
    mean_iou = iou.mean() # TODO: calc mean class iou
    mean_class_rec = rec.mean() # TODO: calc mean class recall
    mean_acc = acc.mean() # TODO: calc mean accuracy

    return mean_iou, mean_class_rec, mean_acc

### `train`
**TODO: define optimizer and learning rate scheduler.**

You need to experiment with different optimizers and schedulers and pick the one of each that works best. Since the grading is partially based on the validation performance of your models, we strongly advise running some preliminary experiments and picking the configuration with the best results.
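One possible way to make such experiments convenient is to select the optimizer and scheduler by name. The sketch below is only an illustration: the helpers `build_optimizer` and `build_scheduler`, the option names `'sgd'`, `'adamw'`, and `'cosine'`, and the hyperparameter values are assumptions rather than part of the assignment.

```python
import torch
from torch import nn


def build_optimizer(params, name, lr):
    """Return an optimizer chosen by a (hypothetical) string option."""
    if name == 'sgd':
        return torch.optim.SGD(params, lr=lr, momentum=0.9)
    if name == 'adamw':
        return torch.optim.AdamW(params, lr=lr)
    return torch.optim.Adam(params, lr=lr)  # fallback, e.g. for 'adam'


def build_scheduler(optimizer, name, max_epochs=100):
    """Return a learning rate scheduler chosen by a (hypothetical) string option."""
    if name == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.1)


net = nn.Linear(4, 2)  # stand-in for the segmentation network
opt = build_optimizer(net.parameters(), 'sgd', lr=0.01)
sch = build_scheduler(opt, 'cosine')
print(type(opt).__name__, type(sch).__name__)
```

Inside a `LightningModule`, the same two calls could be made from `configure_optimizers`, returning `[opt], [sch]`.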

In [4]:
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Modifications Copyright Skoltech Deep Learning Course.

import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader


class SegModel(pl.LightningModule):
    def __init__(
        self,
        model: str,
        backbone: str,
        aspp: bool,
        augment_data: bool,
        optimizer: str = 'default',
        scheduler: str = 'default',
        lr: float = None,
        batch_size: int = 16,
        data_path: str = '../data/PPKE-SZTAKI-MasonryBenchmark',
        image_size: int = 512,
        num_classes: int = 2,
    ) -> None:
        
        super(SegModel, self).__init__()
        self.num_classes = num_classes

        if model == 'unet':
            self.net = UNet(num_classes=self.num_classes)
        elif model == 'deeplab':
            self.net = DeepLab(backbone, aspp, self.num_classes)

        self.train_dataset = FloodNet(data_path, 'train', augment_data, image_size)
        self.test_dataset = FloodNet(data_path, 'test', augment_data, image_size)

        self.batch_size = batch_size
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.lr = lr
        self.eps = 1e-7

        self.unnorm = UnNormalize(
            mean=MEAN,
            std=STD,
        )

        # Visualization
        self.color_map = torch.FloatTensor(
            [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1],
             [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]])

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        img, mask = batch
        pred = self.forward(img)

        train_loss = F.cross_entropy(pred, mask)

        self.log('train_loss', train_loss, prog_bar=True)

        return train_loss

    def validation_step(self, batch, batch_idx):
        img, mask = batch
        pred = self.forward(img)
        
        val_loss = F.cross_entropy(pred, mask)

        self.log('val_loss', val_loss, prog_bar=True)

        intersection, union, target = calc_val_data(pred, mask, self.num_classes)

        return {'intersection': intersection, 'union': union, 'target': target, 'img': self.unnorm(img), 'pred': pred, 'mask': mask}

    def validation_epoch_end(self, outputs):
        intersection = torch.cat([x['intersection'] for x in outputs])
        union = torch.cat([x['union'] for x in outputs])
        target = torch.cat([x['target'] for x in outputs])

        mean_iou, mean_class_rec, mean_acc = calc_val_loss(intersection, union, target, self.eps)

        log_dict = {'mean_iou': mean_iou, 'mean_class_rec': mean_class_rec, 'mean_acc': mean_acc}

        for k, v in log_dict.items():
            self.log(k, v, prog_bar=True)

        # Visualize results
        img = torch.cat([x['img'] for x in outputs]).cpu()
        pred = torch.cat([x['pred'] for x in outputs]).cpu()
        mask = torch.cat([x['mask'] for x in outputs]).cpu()

        pred_vis = self.visualize_mask(torch.argmax(pred, dim=1))
        mask_vis = self.visualize_mask(mask)

        results = torch.cat(torch.cat([img, pred_vis, mask_vis], dim=3).split(1, dim=0), dim=2)
        results_thumbnail = F.interpolate(results, scale_factor=0.25, mode='bilinear', align_corners=True)[0]

        self.logger.experiment.add_image('results', results_thumbnail, self.current_epoch)

    def visualize_mask(self, mask):
        b, h, w = mask.shape
        mask_ = mask.view(-1)

        if self.color_map.device != mask.device:
            self.color_map = self.color_map.to(mask.device)

        mask_vis = self.color_map[mask_].view(b, h, w, 3).permute(0, 3, 1, 2).clone()

        return mask_vis

    def configure_optimizers(self):
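        # TODO: 2 points
        # Use self.optimizer and self.scheduler to call different optimizers
        # Note: as written, self.optimizer and self.scheduler are ignored;
        # Adam with StepLR (step_size=20) is always used.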
        opt = torch.optim.Adam(self.net.parameters(), lr=self.lr) # TODO: init optimizer
        sch = torch.optim.lr_scheduler.StepLR(opt, step_size=20) # TODO: init learning rate scheduler
        return [opt], [sch]

    def train_dataloader(self):
        return DataLoader(self.train_dataset, num_workers=8, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.test_dataset, num_workers=8, batch_size=1, shuffle=False)

## Part 2. Train and benchmark

In this part of the assignment, you need to train the following models and measure their training time:
- **UNet** (with and without data augmentation),
- **DeepLab** with **ResNet18** backbone (with **ASPP** = True and False),
- **DeepLab** with the remaining backbones you implemented (with **ASPP** = True).

To get the full mark for this assignment, all the required models must be trained (with their checkpoints provided) and reach an accuracy of at least 0.5.

After the models are trained, evaluate their inference time on both GPU and CPU.
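A minimal sketch of how inference time could be measured is shown below. The helper `measure_inference_time`, the toy model, and the warm-up and repetition counts are assumptions for illustration. The important details are switching to `eval()` mode, disabling gradients, warming up, and synchronizing CUDA before reading the clock, since GPU kernels run asynchronously.

```python
import time

import torch
from torch import nn


def measure_inference_time(model, input_shape=(1, 3, 64, 64), device='cpu', warmup=2, runs=5):
    """Average forward-pass time in seconds on the given device."""
    model = model.to(device).eval()
    x = torch.randn(*input_shape, device=device)
    with torch.no_grad():
        # Warm-up passes exclude one-off costs (memory allocation, cuDNN autotuning).
        for _ in range(warmup):
            model(x)
        if device.startswith('cuda'):
            torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(runs):
            model(x)
        # Wait for queued GPU kernels to finish before stopping the timer.
        if device.startswith('cuda'):
            torch.cuda.synchronize()
    return (time.perf_counter() - start) / runs


toy = nn.Conv2d(3, 2, kernel_size=3, padding=1)  # stand-in for a trained model
cpu_time = measure_inference_time(toy, device='cpu')
print(f'CPU: {cpu_time * 1000:.3f} ms per image')

if torch.cuda.is_available():
    gpu_time = measure_inference_time(toy, device='cuda')
    print(f'GPU: {gpu_time * 1000:.3f} ms per image')
```

For a fair comparison between models, use the same input size and batch size for every measurement.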

Example training code is below.

In [5]:
import pytorch_lightning as pl
import time
import torch


def define_model(
    model_name: str, 
    backbone: str, 
    aspp: bool, 
    augment_data: bool, 
    optimizer: str, 
    scheduler: str, 
    lr: float, 
    checkpoint_name: str = '', 
    batch_size: int = 8,
    num_classes: int = 2,
):
    assignment_dir = 'semantic_segmentation'
    experiment_name = f'{model_name}_{backbone}_augment={augment_data}_aspp={aspp}'
    model_name = model_name.lower()
    backbone = backbone.lower() if backbone is not None else backbone
    
    model = SegModel(
        model_name, 
        backbone, 
        aspp, 
        augment_data,
        optimizer,
        scheduler,
        lr,
        batch_size,
        data_path='../data/PPKE-SZTAKI-MasonryBenchmark',
        image_size=512,
        num_classes=num_classes,
    )

    if checkpoint_name:
        model.load_state_dict(torch.load(f'{assignment_dir}/logs/{experiment_name}/{checkpoint_name}')['state_dict'])
    
    return model, experiment_name

def train(model, experiment_name, use_gpu):
    assignment_dir = 'semantic_segmentation'

    logger = pl.loggers.TensorBoardLogger(save_dir=f'{assignment_dir}/logs', name=experiment_name)

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor='mean_iou',
        dirpath=f'{assignment_dir}/logs/{experiment_name}',
        filename='{epoch:02d}-{mean_iou:.3f}',
        mode='max',
    )
    
    trainer = pl.Trainer(
        max_epochs=100, 
        gpus=1 if use_gpu else None, 
        benchmark=True, 
        check_val_every_n_epoch=5, 
        logger=logger, 
        callbacks=[checkpoint_callback],
    )

    time_start = time.time()
    
    trainer.fit(model)
    
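    # Wait for pending GPU work before stopping the timer (only if CUDA is present)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    time_end = time.time()
    
    training_time = (time_end - time_start) / 60  # minutes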
    
    return training_time

In [6]:
model, experiment_name = define_model(
    model_name='Unet',
    backbone=None,
    aspp=None,
    augment_data=True,
    optimizer='adam', # use these options to experiment
    scheduler='step_lr', # with optimizers and schedulers
    lr=0.01, # experiment to find the best LR
    num_classes=2,
)

training_time = train(model, experiment_name, use_gpu=True)
print('Training time:', training_time)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type | Params
------------------------------
0 | net  | UNet | 7.8 M 
------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.052    Total estimated model params size (MB)

Training time: 24.335324863592785


After training, the loss curves and validation images with their segmentation masks can be viewed using the TensorBoard extension:

In [7]:
%reload_ext tensorboard
%tensorboard --logdir semantic_segmentation/logs