# Homework 6

Taking inspiration from the last two pictures in the notebook (07-convnets.ipynb), implement a U-Net-style CNN with the following specs:
 
1. All convolutions must use a 3 x 3 kernel and **leave the spatial dimensions (i.e., height and width) of the input untouched**.
2. Downsampling in the contracting part is performed via max pooling with a 2 x 2 kernel and a stride of 2.
3. Upsampling is performed by a deconvolution (transposed convolution) with a 2 x 2 kernel and a stride of 2. The PyTorch module that implements it is `nn.ConvTranspose2d`.
4. The final layer of the expanding part has only 1 channel.
    * How many classes are we discriminating between?
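Specs 1-3 can be checked in isolation by pushing a random tensor through one layer of each kind and printing the shapes. This is a minimal sketch, assuming that `padding=1` (with the default stride of 1) is how a 3 x 3 convolution leaves height and width untouched; the channel counts and the 32 x 32 input size are arbitrary.

```python
import torch
from torch import nn

x = torch.rand(1, 16, 32, 32)  # arbitrary input: batch of 1, 16 channels, 32 x 32

conv = nn.Conv2d(16, 32, kernel_size=3, padding=1)        # spec 1: stride 1 + padding 1 keeps 32 x 32
pool = nn.MaxPool2d(kernel_size=2, stride=2)              # spec 2: halves height and width
up = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)  # spec 3: doubles height and width

with torch.no_grad():
    h = conv(x)
    p = pool(h)
    u = up(p)

print(h.shape)  # torch.Size([1, 32, 32, 32])
print(p.shape)  # torch.Size([1, 32, 16, 16])
print(u.shape)  # torch.Size([1, 16, 32, 32])
```

With an odd side the pooling rounds down (7 becomes 3) and the transposed convolution then returns 6 rather than 7, which is why a U-Net needs either cropping or input sides divisible by 2 to the power of the number of poolings before concatenating skip connections.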
 
Create a network class with (at least) an `__init__` and a `forward` method. Feel free to use additional structures (e.g., `nn.Module`s, private methods...) if you believe they improve the readability of your code.
 
Test, at least with random data, that the network performs the correct tensor operations and that the output has the correct shape (e.g., use `assert`s in your code to check that intermediate and final results have the expected shapes).
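One way to check intermediate shapes without editing `forward` is to register forward hooks with `nn.Module.register_forward_hook`, record the output shape of each submodule, and `assert` on the recorded values. The toy `nn.Sequential` model below is a hypothetical stand-in for the network under test; the same loop works over `named_modules()` of a larger model.

```python
import torch
from torch import nn

# Hypothetical toy model standing in for the network under test
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.MaxPool2d(kernel_size=2, stride=2),
    nn.ConvTranspose2d(8, 1, kernel_size=2, stride=2),
)

shapes = {}

def record_shape(name):
    # Returns a hook that stores the output shape of the module it is attached to
    def hook(module, inputs, output):
        shapes[name] = tuple(output.shape)
    return hook

for name, module in model.named_children():  # children of a Sequential are named "0", "1", ...
    module.register_forward_hook(record_shape(name))

X = torch.rand(4, 3, 32, 32)
with torch.no_grad():
    y = model(X)

assert shapes["0"] == (4, 8, 32, 32)  # convolution keeps height and width
assert shapes["1"] == (4, 8, 16, 16)  # pooling halves them
assert y.shape == (4, 1, 32, 32)      # transposed convolution restores them
```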
 
Note: the overall organization of your work can greatly improve readability and understanding of your code by others. Please consider preparing your notebook in an organized fashion so that we can better understand (and correct) your implementation.

![](../labs/img/unet_small.png)

#### U-Net [3](https://arxiv.org/abs/1505.04597)

The architecture is a composition of two parts: a contracting module and an expanding module.

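The contracting module is a sequence of VGG-style convolutional blocks [5](https://arxiv.org/abs/1409.1556v6): each block applies two 3 x 3 convolutions, each followed by a ReLU, and consecutive blocks are separated by a 2 x 2 max pooling with stride 2 that halves the spatial resolution, while the number of channels doubles from one block to the next.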
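A minimal sketch of the first contracting block as described in [3], assuming the paper's unpadded 3 x 3 convolutions (`padding=0`) rather than the size-preserving ones required by the homework spec. The helper name `paper_vgg_block` is hypothetical; the single-channel 572 x 572 input tile and the 64 channels follow the paper's architecture figure.

```python
import torch
from torch import nn

def paper_vgg_block(in_channels, out_channels):
    # Hypothetical helper: two unpadded 3 x 3 convolutions, each followed by a ReLU
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3),  # padding=0 trims 1 pixel per border
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3),
        nn.ReLU(inplace=True),
    )

block = paper_vgg_block(1, 64)
pool = nn.MaxPool2d(kernel_size=2, stride=2)

x = torch.rand(1, 1, 572, 572)  # single-channel input tile
with torch.no_grad():
    features = block(x)          # kept for the skip connection: 572 -> 570 -> 568
    downsampled = pool(features) # input of the next block: 568 -> 284

print(features.shape)     # torch.Size([1, 64, 568, 568])
print(downsampled.shape)  # torch.Size([1, 64, 284, 284])
```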

After reaching the bottom of the network, we begin upsampling the feature maps following the inverse of the contracting scheme, with one additional operation: we concatenate the output of each upsampling step with the output of the last convolutional layer of the corresponding contracting block (as in the image). For example, if the upsampling yields a 256-channel output, we concatenate it with the output of the last 256-channel convolutional layer of the contracting module. This leaves us with a 512-channel tensor, which we convolve back to 256 channels. Note that, if the spatial dimensions of the data from the contracting module don't match those of the upsampled data, the former is center-cropped so that the two tensors can be safely concatenated.

Actually, the original implementation of U-Net performs semantic segmentation on a window whose side is approximately 2/3 of the side of the input image (hence a band of pixels along the border is left unsegmented). This happens because the original network uses unpadded convolutions, which shrink the feature maps at every layer. In the image below, the thin white lines delimit the area that will be segmented.

![](../labs/img/unet_crop.jpg)
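The shrinkage can be quantified with a little arithmetic. The hypothetical helper below assumes, as in the paper [3], two unpadded 3 x 3 convolutions per block (each trims one pixel from every border), four 2 x 2 poolings and four 2 x 2 up-convolutions, and recovers the paper's 572 x 572 input / 388 x 388 output pair.

```python
def unet_valid_output_size(size, depth=4, convs_per_block=2):
    """Hypothetical helper: output side of a U-Net made of unpadded 3 x 3 convolutions."""
    for _ in range(depth):                       # contracting path
        size -= 2 * convs_per_block              # each unpadded 3 x 3 conv removes 2 pixels per side
        assert size % 2 == 0, "2 x 2 max pooling needs an even side length"
        size //= 2                               # 2 x 2 max pooling, stride 2
    size -= 2 * convs_per_block                  # bottleneck block
    for _ in range(depth):                       # expanding path
        size = 2 * size - 2 * convs_per_block    # 2 x 2 up-convolution, then two convolutions
    return size

out = unet_valid_output_size(572)
print(out, round(out / 572, 2))  # 388 0.68
```

With size-preserving (padded) convolutions, as required by spec 1, the trimming terms vanish and the output is as large as the input, provided its side is divisible by 2 to the power of the number of poolings.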

As for the output, we end up with a tensor of shape $h^\prime \times w^\prime \times C$, where $C$ denotes the number of classes among which we want to segment (logically speaking, **if we want to classify each pixel, we wish to produce a softmax for each pixel**).

![](../labs/img/unet_last.jpg)
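A minimal sketch of the output layer, assuming $C = 2$ classes and a 64-channel last feature map of 388 x 388 as in the paper's figure, with the paper's 1 x 1 convolution as the classifier. Note that PyTorch stores the result channel-first, i.e. as $C \times h^\prime \times w^\prime$ per image, so the per-pixel softmax runs over `dim=1`.

```python
import torch
from torch import nn

C = 2                                   # assumed number of classes (e.g. foreground vs background)
features = torch.rand(1, 64, 388, 388)  # last feature map of the expanding path

head = nn.Conv2d(64, C, kernel_size=1)  # 1 x 1 convolution: one score per class for every pixel
with torch.no_grad():
    logits = head(features)               # shape (batch, C, h', w')
    probs = torch.softmax(logits, dim=1)  # softmax over classes, independently at each pixel
    prediction = probs.argmax(dim=1)      # predicted class per pixel, shape (batch, h', w')

# With two classes, softmax equals a sigmoid of the logit difference, so a single
# output channel followed by a sigmoid carries the same information
p1 = torch.sigmoid(logits[:, 1] - logits[:, 0])
print(torch.allclose(probs[:, 1], p1, atol=1e-6))  # True
```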

In [None]:
import torch
from torch import nn
import torchvision.transforms as T
from torchinfo import summary  # pip install torchinfo; input_size includes the batch dimension

In [None]:
class Contraction(nn.Module):
    """Contracting path: VGG-style blocks of 3 x 3 convolutions separated by 2 x 2 max pooling."""

    @staticmethod
    def _build_vgg_block(num_conv_layers, in_channels, out_channels, kernel_size=3, batchnorm=True, activation=nn.ReLU):
        layers = []
        for i in range(num_conv_layers):
            num_channels_in = in_channels if i == 0 else out_channels
            # stride 1 and padding = kernel_size // 2 leave height and width untouched (spec 1);
            # the bias is redundant when BatchNorm follows
            layers.append(nn.Conv2d(num_channels_in, out_channels, kernel_size=kernel_size,
                                    padding=kernel_size // 2, bias=not batchnorm))
            if batchnorm:
                layers.append(nn.BatchNorm2d(out_channels))
            layers.append(activation())
        return nn.Sequential(*layers)

    def __init__(self, in_channels=3, widths=(64, 128, 256, 512, 1024), activation=nn.ReLU):
        super().__init__()
        channels = [in_channels, *widths]
        self.blocks = nn.ModuleList([
            self._build_vgg_block(2, c_in, c_out, activation=activation)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])
        # 2 x 2 max pooling with stride 2 halves height and width (spec 2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, X):
        skips = []
        for block in self.blocks[:-1]:
            X = block(X)
            skips.append(X)  # kept for the skip connections of the expanding path
            X = self.pool(X)
        # the last (bottleneck) block is not followed by pooling
        return self.blocks[-1](X), skips

In [None]:
class Expansion(nn.Module):
    """Expanding path: 2 x 2 transposed convolution, concatenation with the matching skip, VGG-style block."""
    def __init__(self, widths=(1024, 512, 256, 128, 64), out_channels=1, activation=nn.ReLU):
        super().__init__()
        pairs = list(zip(widths[:-1], widths[1:]))
        # nn.ConvTranspose2d with a 2 x 2 kernel and stride 2 doubles height and width (spec 3) and halves the channels
        self.upsamples = nn.ModuleList([nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2) for c_in, c_out in pairs])
        # after concatenation there are 2 * c_out channels, which the block convolves back to c_out
        self.blocks = nn.ModuleList([Contraction._build_vgg_block(2, 2 * c_out, c_out, activation=activation) for _, c_out in pairs])
        # final layer with a single channel (spec 4); 3 x 3 to respect spec 1 (the paper uses a 1 x 1 convolution)
        self.head = nn.Conv2d(widths[-1], out_channels, kernel_size=3, padding=1)

    def forward(self, X, skips):
        # skips are stored from the top of the network down, so they are consumed in reverse order
        for upsample, block, skip in zip(self.upsamples, self.blocks, reversed(skips)):
            X = upsample(X)
            if skip.shape[-2:] != X.shape[-2:]:
                # center-crop the skip tensor when odd input sizes make the shapes disagree
                skip = T.CenterCrop(tuple(X.shape[-2:]))(skip)
            X = block(torch.cat([skip, X], dim=1))
        return self.head(X)

In [None]:
class UNet(nn.Module):
    """U-Net-style CNN: contracting path, bottleneck, expanding path with skip connections."""

    def __init__(self, in_channels=3, out_channels=1, widths=(64, 128, 256, 512, 1024), activation=nn.ReLU):
        super().__init__()
        self.contraction = Contraction(in_channels, widths, activation)
        self.expansion = Expansion(widths[::-1], out_channels, activation)

    def forward(self, X):
        bottleneck, skips = self.contraction(X)
        # one logit per pixel: apply torch.sigmoid (or train with nn.BCEWithLogitsLoss) for binary masks
        return self.expansion(bottleneck, skips)

In [None]:
net = UNet(in_channels=3, out_channels=1)
# with sides that are multiples of 2**4 = 16, the output keeps the input resolution
print(summary(net, input_size=(1, 3, 64, 64), verbose=0))

Test that the network works on random data and produces tensors of the expected shapes

In [None]:
X = torch.rand((4, 3, 64, 64))

# every pooling halves height and width, every block sets the number of channels
bottleneck, skips = net.contraction(X)
assert [tuple(s.shape) for s in skips] == [(4, 64, 64, 64), (4, 128, 32, 32), (4, 256, 16, 16), (4, 512, 8, 8)]
assert bottleneck.shape == (4, 1024, 4, 4)

# the output keeps the input resolution and has a single channel
y = net(X)
assert y.shape == (4, 1, 64, 64)
y.shape

### References

[1](https://arxiv.org/abs/1603.07285) Dumoulin, Vincent, and Francesco Visin. "A guide to convolution arithmetic for deep learning." arXiv, 2016.

[2](https://openaccess.thecvf.com/content_cvpr_2016/html/He_Deep_Residual_Learning_CVPR_2016_paper.html) He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2016.

[3](https://arxiv.org/abs/1505.04597) Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-Net: Convolutional networks for biomedical image segmentation." International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI), 2015.

[4](https://arxiv.org/abs/1605.07146) Zagoruyko, Sergey, and Nikos Komodakis. "Wide residual networks." arXiv, 2016.

[5](https://arxiv.org/abs/1409.1556v6) Simonyan, Karen, and Andrew Zisserman. "Very deep convolutional networks for large-scale image recognition." arXiv, 2014.