# <font style="color:blue">LinkNet Architecture</font>

In this notebook, we will examine the LinkNet architecture, explore its building blocks, and implement it.

# <font style="color:green">1. LinkNet</font>

Let's briefly review the LinkNet architecture. [LinkNet](https://arxiv.org/pdf/1707.03718.pdf) was
introduced in 2017 by A. Chaurasia and E. Culurciello as a lightweight deep neural network for semantic
segmentation that learns without a significant increase in the number of parameters:

---

<img src='https://www.learnopencv.com/wp-content/uploads/2020/04/c3-w11-LinkNet_architecture.png'>

---

In the picture above, `/2` means downsampling the feature map by a factor of `2`, which is achieved with a strided convolution, and `∗2` denotes upsampling by a factor of `2`.


The encoder is the left half of the network, whereas the right half is the decoder.


The picture below shows the encoder scheme. The initial block of the encoder consists of a convolution with a `7×7` kernel and a stride of `2`, followed by a max-pooling layer with a `3×3` kernel and also a stride of `2`. **The later encoder parts consist of residual blocks similar to the `ResNet18` architecture:**

---

<img src='https://www.learnopencv.com/wp-content/uploads/2020/04/c3-w11-encoder.png'>

---
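
To see how the initial block shrinks the input, the standard output-size formula for convolution and pooling layers can be applied by hand. The sketch below assumes the padding values used by torchvision's ResNet18 (`padding=3` for the `7×7` convolution and `padding=1` for the max-pooling layer) and a `320`-pixel input side; the helper name `conv_out_size` is hypothetical.

```python
def conv_out_size(size, kernel, stride, padding):
    """Output length of a convolution or pooling layer along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 320  # example input height/width (an assumption for illustration)

# 7x7 convolution with stride 2 (padding=3 assumed, as in torchvision's ResNet18)
after_conv = conv_out_size(size, kernel=7, stride=2, padding=3)
# 3x3 max-pooling with stride 2 (padding=1 assumed, as in torchvision's ResNet18)
after_pool = conv_out_size(after_conv, kernel=3, stride=2, padding=1)

print(after_conv, after_pool)
```

Under these assumptions, each of the two layers halves the spatial size, so the initial block reduces it by a factor of `4` overall.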


**Let's take a look at the decoder:**

---

<img src='https://www.learnopencv.com/wp-content/uploads/2020/04/c3-w11-decoder.png'>

---

- The decoder uses the full-convolution technique. The novelty of LinkNet lies in the connection between each encoder block and its corresponding decoder block.


- Applying multiple downsampling steps in the encoder usually loses part of the spatial information, which is then difficult to recover.


- In **`LinkNet`**, the input of each encoder block is bypassed to the output of the corresponding decoder block. This helps recover the lost spatial information. Such links also share knowledge between the encoder and the decoder at every layer, so the decoder can use fewer parameters while still restoring image feature information well.


- It is also worth noting that, to process an image of size `640×360×3`, the network uses `11.5 million` parameters and `21.2 GFLOPs`.
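
The linking idea from the list above can be sketched with a toy encoder/decoder pair. This is only an illustration, not the LinkNet blocks themselves: the layer types, channel counts, and tensor size are arbitrary assumptions chosen so that the shapes line up.

```python
import torch
import torch.nn as nn

# toy encoder/decoder pair; channel counts are arbitrary illustration values
encoder = nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1)           # halves H and W
decoder = nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1)  # doubles H and W

x = torch.randn(1, 8, 32, 32)  # input of the encoder block
encoded = encoder(x)           # downsampled feature map
# the link: the encoder input is added to the output of the corresponding decoder
linked = decoder(encoded) + x

print(encoded.shape, linked.shape)
```

The addition only works when the decoder output has exactly the same shape as the bypassed encoder input, which is why the decoder has to undo both the channel change and the downsampling of its encoder counterpart.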

# <font style="color:green">2. LinkNet Implementation</font>

Now that we have discussed the ideas behind the architecture, let's implement it.
We will need the following imports:

In [1]:
import torch
# torch neural network (NN) module for building and training nets
import torch.nn as nn
# module with various model definitions
import torchvision.models as models

import numpy as np
from dataclasses import dataclass
import random

**Configuration for reproducible results.**

In [2]:
@dataclass
class SystemConfig:
    seed: int = 42  # seed number to set the state of all random number generators
    cudnn_benchmark_enabled: bool = False  # enable CuDNN benchmark for the sake of performance
    cudnn_deterministic: bool = True  # make cudnn deterministic (reproducible training)

In [3]:
def setup_system(system_config: SystemConfig) -> None:
    torch.manual_seed(system_config.seed)
    np.random.seed(system_config.seed)
    random.seed(system_config.seed)
    torch.set_printoptions(precision=10)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(system_config.seed)
        torch.backends.cudnn.benchmark = system_config.cudnn_benchmark_enabled
        torch.backends.cudnn.deterministic = system_config.cudnn_deterministic
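
A quick way to check that seeding behaves as expected is to reseed and draw again. The snippet below is a minimal sketch that only exercises `torch.manual_seed`; the seed value `42` mirrors the default in `SystemConfig`, and the variable names are illustrative.

```python
import torch

torch.manual_seed(42)
first = torch.rand(3)   # first draw after seeding

torch.manual_seed(42)
second = torch.rand(3)  # same draw after reseeding with the same value

print(torch.equal(first, second))
```

Identical draws after reseeding indicate that the generator state is fully determined by the seed.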

## <font style="color:green">2.1. Decoder</font>

The block presented below is a decoder block. It takes a feature map with `channels_in` channels, upsamples it by a factor of `2`, and returns a feature map with `channels_out` channels.

We use `ConvTranspose2d` for upsampling; details are available [here](https://pytorch.org/docs/stable/nn.html#torch.nn.ConvTranspose2d).
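
The output size of a transposed convolution follows the formula given in the PyTorch documentation for `ConvTranspose2d`. The sketch below evaluates it for a kernel size of `4`, a stride of `2`, and a padding of `1`; the helper name `deconv_out_size` and the sample sizes are hypothetical.

```python
def deconv_out_size(size, kernel, stride, padding=0, output_padding=0, dilation=1):
    """Output length of a transposed convolution along one dimension."""
    return (size - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1


# kernel_size=4, stride=2, padding=1 doubles the spatial size exactly
sizes = [10, 20, 40, 80]
upsampled = [deconv_out_size(s, kernel=4, stride=2, padding=1) for s in sizes]
print(upsampled)
```

With these settings the formula reduces to `2 * size`, so no `output_padding` is needed to get an exact doubling.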

In [4]:
# create decoder block inherited from nn.Module
class DecoderBlock(nn.Module):
    def __init__(self, channels_in, channels_out):
        super().__init__()

        # 1x1 projection module to reduce channels
        self.proj = nn.Sequential(
            # convolution
            nn.Conv2d(channels_in, channels_in // 4, kernel_size=1, bias=False),
            # batch normalization
            nn.BatchNorm2d(channels_in // 4),
            # relu activation
            nn.ReLU()
        )

        # full-convolution (transposed convolution) module that upsamples by a factor of 2
        self.deconv = nn.Sequential(
            # deconvolution
            nn.ConvTranspose2d(
                channels_in // 4,
                channels_in // 4,
                kernel_size=4,
                stride=2,
                padding=1,
                output_padding=0,
                groups=channels_in // 4,
                bias=False
            ),
            # batch normalization
            nn.BatchNorm2d(channels_in // 4),
            # relu activation
            nn.ReLU()
        )

        # 1x1 unprojection module to increase channels
        self.unproj = nn.Sequential(
            # convolution
            nn.Conv2d(channels_in // 4, channels_out, kernel_size=1, bias=False),
            # batch normalization
            nn.BatchNorm2d(channels_out),
            # relu activation
            nn.ReLU()
        )

    # stack layers and perform a forward pass
    def forward(self, x):

        proj = self.proj(x)
        deconv = self.deconv(proj)
        unproj = self.unproj(deconv)

        return unproj
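
To get a feeling for why the block first projects to `channels_in // 4`, the weight counts can be worked out by hand. The sketch below ignores the batch-norm parameters and compares the three convolutions of the block with a single dense transposed convolution; that dense layer is a hypothetical alternative used only for comparison, and the helper `conv_weights` and the channel counts `512` and `256` are illustrative assumptions.

```python
def conv_weights(c_in, c_out, kernel, groups=1):
    """Number of weights in a bias-free convolution or transposed convolution."""
    return c_in * (c_out // groups) * kernel * kernel


c_in, c_out = 512, 256  # example channel counts
mid = c_in // 4         # reduced channel count inside the block

proj = conv_weights(c_in, mid, kernel=1)                # 1x1 projection
deconv = conv_weights(mid, mid, kernel=4, groups=mid)   # grouped 4x4 transposed convolution
unproj = conv_weights(mid, c_out, kernel=1)             # 1x1 unprojection
bottleneck_total = proj + deconv + unproj

# hypothetical alternative: one dense 4x4 transposed convolution from c_in to c_out
dense_total = conv_weights(c_in, c_out, kernel=4)

print(bottleneck_total, dense_total)
```

Under these assumptions the bottleneck design needs roughly twenty times fewer weights than the dense alternative, which is consistent with the lightweight decoder described earlier.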

## <font style="color:green">2.2. LinkNet</font>

To define the network, we reuse blocks from the pretrained PyTorch
[ResNet18](https://pytorch.org/docs/stable/_modules/torchvision/models/resnet.html#resnet18), since the LinkNet encoder is similar to `ResNet18`.

In [5]:
# create LinkNet model with ResNet18 encoder
class LinkNet(nn.Module):
    def __init__(self, num_classes, encoder="resnet18"):
        super().__init__()
        assert hasattr(models, encoder), "Undefined encoder type"
        # prepare feature extractor from `torchvision` ResNet model
        feature_extractor = getattr(models, encoder)(pretrained=True)
        # Init block: get configured Conv2d, BatchNorm2d layers and ReLU from torch ResNet class
        self.init = nn.Sequential(feature_extractor.conv1, feature_extractor.bn1, feature_extractor.relu)
        self.maxpool = feature_extractor.maxpool

        # Encoder's blocks: torch ResNet18 blocks initialization
        self.layer1 = feature_extractor.layer1
        self.layer2 = feature_extractor.layer2
        self.layer3 = feature_extractor.layer3
        self.layer4 = feature_extractor.layer4

        # Decoder's block: DecoderBlock module
        self.up4 = DecoderBlock(self._num_channels(self.layer4), self._num_channels(self.layer3))
        self.up3 = DecoderBlock(self._num_channels(self.layer3), self._num_channels(self.layer2))
        self.up2 = DecoderBlock(self._num_channels(self.layer2), self._num_channels(self.layer1))
        self.up1 = DecoderBlock(self._num_channels(self.layer1), self._num_channels(self.layer1))

        # Classification block: define a classifier module
        self.classifier = nn.Sequential(
            # deconvolution layer
            nn.ConvTranspose2d(self._num_channels(self.layer1), 32, 3, stride=2, bias=False),
            # batch normalization with num_features = 32
            nn.BatchNorm2d(32),
            # activation function
            nn.ReLU(),
            # convolutional layer
            nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False),
            # batch normalization with num_features = 32
            nn.BatchNorm2d(32),
            # activation function
            nn.ReLU(),
            # convolutional layer
            nn.Conv2d(32, num_classes, kernel_size=2, padding=0)
        )

    # get a compatible number of channels to stack all of the LinkNet's blocks together
    @staticmethod
    def _num_channels(block):
        """
        Extract the number of output channels (batch-norm `num_features`) from the input block.

        Arguments:
            block: torchvision ResNet layer (a sequence of residual blocks).
        """
        # a BasicBlock (ResNet18/34) ends with the `bn2` batch-norm layer
        if isinstance(block[-1], models.resnet.BasicBlock):
            return block[-1].bn2.weight.size(0)
        # otherwise assume a Bottleneck block, which ends with the `bn3` batch-norm layer
        return block[-1].bn3.weight.size(0)

    # define the forward pass
    def forward(self, x):

        # output size = (64, 160, 160)
        init = self.init(x)
        # output size = (64, 80, 80)
        maxpool = self.maxpool(init)
        # output size = (64, 80, 80)
        layer1 = self.layer1(maxpool)
        # output size = (128, 40, 40)
        layer2 = self.layer2(layer1)
        # output size = (256, 20, 20)
        layer3 = self.layer3(layer2)
        # output size = (512, 10, 10)
        layer4 = self.layer4(layer3)
        # output size = (256, 20, 20)
        up4 = self.up4(layer4) + layer3
        # output size = (128, 40, 40)
        up3 = self.up3(up4) + layer2
        # output size = (64, 80, 80)
        up2 = self.up2(up3) + layer1
        # output size = (64, 160, 160)
        up1 = self.up1(up2)
        # output size = (5, 320, 320), where 5 is the predefined number of classes
        output = self.classifier(up1)

        return output

## <font style="color:green">2.3. Check the Implementation</font>

**Let's check the implementation.**

In [6]:
# apply system settings
setup_system(SystemConfig())

# input data for model check
input_tensor = torch.zeros(1, 3, 320, 320)

# LinkNet architecture
model = LinkNet(num_classes=5, encoder="resnet18")

# examining the prediction size
pred = model(input_tensor)
print('Prediction Size: {}'.format(pred.size()))

Prediction Size: torch.Size([1, 5, 320, 320])


**The input width and height are the same as the output width and height because semantic segmentation predicts a label for each pixel.**
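
The sizes only match when every skip addition lines up, which depends on the input size. The sketch below traces one spatial dimension through the layers of the implementation above, assuming the strides and paddings of torchvision's ResNet18; the helper names `halve` and `trace_linknet_size` are hypothetical.

```python
def halve(size):
    """Size after a stride-2 layer of the ResNet18 encoder (rounds up)."""
    return (size - 1) // 2 + 1


def trace_linknet_size(size):
    """Return the output size along one dimension, or None if a skip addition would not line up."""
    init = halve(size)      # 7x7 convolution, stride 2
    layer1 = halve(init)    # max-pooling, stride 2; layer1 keeps this size
    layer2 = halve(layer1)
    layer3 = halve(layer2)
    layer4 = halve(layer3)
    up4 = 2 * layer4        # each decoder block doubles the size
    if up4 != layer3:
        return None
    up3 = 2 * up4
    if up3 != layer2:
        return None
    up2 = 2 * up3
    if up2 != layer1:
        return None
    up1 = 2 * up2
    # classifier: transposed conv (k=3, s=2) gives 2n + 1, the 3x3 conv keeps it, the 2x2 conv gives n - 1
    return (2 * up1 + 1) - 1


for size in (320, 640, 360):
    print(size, trace_linknet_size(size))
```

Under this trace, sizes such as `320` and `640` come back unchanged, while `360` fails at the first skip addition because it is not divisible by `32`.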

# <font style="color:green">3. Model Profiler</font>

**Let's find the number of parameters and the number of floating-point operations of our model.**

Let's write a class and a wrapper function around it for model profiling.

**Note that to count the number of parameters, we need only the model (its parameters), whereas to count floating-point operations, we also need the input image size, because the convolution output size depends on it.**
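
This difference can be illustrated on a single convolution. The sketch below uses one common counting rule (weights per output element times the number of output elements, which counts multiply-accumulate operations); the layer settings correspond to a `7×7` convolution with `3` input and `64` output channels, and the helper names are hypothetical.

```python
def conv_params(c_in, c_out, kernel):
    """Number of weights in a bias-free convolution; independent of the input size."""
    return c_out * c_in * kernel * kernel


def conv_flops(c_in, c_out, kernel, out_h, out_w):
    """Multiply-accumulates: weights per output element times the number of output elements."""
    return (c_in * kernel * kernel) * (c_out * out_h * out_w)


params = conv_params(3, 64, 7)
# a stride-2 convolution on a 320x320 input produces a 160x160 output
flops_320 = conv_flops(3, 64, 7, out_h=160, out_w=160)
# the same layer on a 640x640 input produces a 320x320 output
flops_640 = conv_flops(3, 64, 7, out_h=320, out_w=320)

print(params, flops_320, flops_640)
```

The parameter count stays the same for both inputs, while doubling each input side multiplies the operation count by four.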

In [7]:
class ModelProfiler(nn.Module):
    """ Profile PyTorch models.

    Compute FLOPs (floating-point operations) and the number of trainable parameters of a model.

    Arguments:
        model (nn.Module): model which will be profiled.

    Example:
        model = torchvision.models.resnet50()
        profiler = ModelProfiler(model)
        var = torch.zeros(1, 3, 224, 224)
        profiler(var)
        print("FLOPs: {0:.5}; #Params: {1:.5}".format(profiler.get_flops('G'), profiler.get_params('M')))

    Warning:
        Model profiler doesn't work with models, wrapped by torch.nn.DataParallel.
    """
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.flops = 0
        self.units = {'K': 10.**3, 'M': 10.**6, 'G': 10.**9}
        self.hooks = None
        self._remove_hooks()

    def get_flops(self, units='G'):
        """ Get number of floating operations per inference.

        Arguments:
            units (string): units of the flops value ('K': Kilo (10^3), 'M': Mega (10^6), 'G': Giga (10^9)).

        Returns:
            Floating-point operations per inference in the chosen units.
        """
        assert units in self.units
        return self.flops / self.units[units]

    def get_params(self, units='G'):
        """ Get number of trainable parameters of the model.

        Arguments:
            units (string): units of the parameter count ('K': Kilo (10^3), 'M': Mega (10^6), 'G': Giga (10^9)).

        Returns:
            Number of trainable parameters of the model in the chosen units.
        """
        assert units in self.units
        params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        if units is not None:
            params = params / self.units[units]
        return params

    def forward(self, *args, **kwargs):
        self.flops = 0
        self._init_hooks()
        output = self.model(*args, **kwargs)
        self._remove_hooks()
        return output

    def _remove_hooks(self):
        if self.hooks is not None:
            for hook in self.hooks:
                hook.remove()
        self.hooks = None

    def _init_hooks(self):
        self.hooks = []

        def hook_compute_flop(module, _, output):
            # weights per output element times output elements (batch dimension excluded)
            self.flops += module.weight.size()[1:].numel() * output.size()[1:].numel()

        def add_hooks(module):
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                self.hooks.append(module.register_forward_hook(hook_compute_flop))

        self.model.apply(add_hooks)
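
One consequence of the `isinstance` check in `add_hooks` is worth keeping in mind: `nn.ConvTranspose2d` is not a subclass of `nn.Conv2d`, so transposed convolutions are not hooked. A minimal check, with arbitrary layer settings chosen only for illustration:

```python
import torch.nn as nn

conv = nn.Conv2d(4, 4, kernel_size=3, padding=1)
deconv = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1)

counted_types = (nn.Conv2d, nn.Linear)  # same tuple as in add_hooks

is_conv_counted = isinstance(conv, counted_types)
is_deconv_counted = isinstance(deconv, counted_types)
print(is_conv_counted, is_deconv_counted)
```

If this holds for the installed PyTorch version, the FLOPs figure reported by the profiler leaves out the transposed convolutions in the decoder and the classifier, so it is best read as a lower bound.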

In [8]:
def profile_model(model, input_size, cuda):
    """ Compute FLOPS and #Params of the CNN.

    Arguments:
        model (nn.Module): model which should be profiled.
        input_size (tuple): size of the input variable.
        cuda (bool): if True, the input variable is uploaded to the GPU.

    Returns:
        dict:
            dict["flops"] (float): number of GFLOPs.
            dict["params"] (float): number of parameters, in millions.
    """
    profiler = ModelProfiler(model)
    var = torch.zeros(input_size)
    if cuda:
        var = var.cuda()
    profiler(var)
    return {"flops": profiler.get_flops('G'), "params": profiler.get_params('M')}

**Let's calculate the `GFLOPs` and the `number of parameters` of the model.**

In [9]:
# input data for model check
input_tensor = torch.zeros(1, 3, 640, 320)

flops, params = profile_model(model, input_tensor.size(), False).values()

print('GFLOPs:\t\t\t\t{}\nNo. of params (in million):\t{}'.format(flops, params))


GFLOPs:				9.613157376
No. of params (in million):	11.341829
