# Applied AI for Health Research

# Practical 5: CNN Architectures

Tutorial by Cher Bass and Emma Robinson. Edited by Mariana da Silva.

Let's start by importing the modules and data that we need for the notebook. We start by training and testing on the MNIST dataset, which we have previously used in other Practicals.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F # Contains some useful functions like activation 
                                # functions & convolution operations you can use

import torchvision
import numpy as np
from torchvision import datasets, models, transforms

# This is used to transform the images to Tensor and normalize it
transform = transforms.Compose(
   [transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])])

training = torchvision.datasets.MNIST(root='./data', train=True,
                                       download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(training, batch_size=32,
                                         shuffle=True, num_workers=2)

testing = torchvision.datasets.MNIST(root='./data', train=False,
                                      download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testing, batch_size=32,
                                        shuffle=False, num_workers=2)

classes = ('0', '1', '2', '3',
          '4', '5', '6', '7', '8', '9')

Now set your device to CUDA if a GPU is available, falling back to the CPU otherwise:


In [None]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
    
print(device) 



---




## Exercise 1: ResNet

**ResNet with PyTorch**

ResNet was introduced in 2015 to support the training of deeper networks by making it easy for layers to learn identity mappings. It does this through the use of residual blocks.

An example of a resnet block (from the original [2015 paper](https://arxiv.org/abs/1512.03385)) is illustrated below (see [image source](t)):

<figure>
<img src="https://drive.google.com/uc?id=1NQ_sLsu0GsXQsVEQ9Rtm5Mnvm2BuucCZ" alt="Drawing" style="width: 800px;"/>
<figcaption align = "centre"> Fig 1. ResNet 2015 residual block  </figcaption>
</figure>

Here, input data passes down two paths. In one, it is passed through two convolutional (weight) layers; in the other, it skips these layers and is added to their output. This shortcut operation is the identity mapping. If there are no gains to be made by learning more kernel weights (the network is already deep enough), then the network can simply learn to pass the input unchanged through the block (an identity transform) by pushing these kernel weights to zero.
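
A minimal sketch of this argument, assuming a simplified block with a single convolution and no batchnorm (the names `conv` and `x` are illustrative and not part of the practical): when the kernel weights are zero, the block should return its input unchanged, provided the input is non-negative, as it would be after a preceding ReLU.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# A single 3x3 convolution standing in for the weight layers of the block
conv = nn.Conv2d(4, 4, kernel_size=3, padding=1, bias=False)
nn.init.zeros_(conv.weight)  # push the kernel weights to zero

# Non-negative input, as produced by a preceding ReLU
x = torch.rand(2, 4, 8, 8)

# Residual connection: weight path plus identity shortcut, then ReLU
out = F.relu(conv(x) + x)

print(torch.equal(out, x))
```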



---



**Using PyTorch implementation**

Torchvision offers default implementations of popular networks.

For example, the following pretrained ResNet model can be loaded in PyTorch:
```python
import torchvision
torchvision.models.resnet18(pretrained=True, **kwargs)
```

You can also load a model that hasn't been pretrained in the following way:
```python
torchvision.models.resnet18(pretrained=False, **kwargs)
```

To see more examples, including networks such as ResNet, Alexnet, VGG, Densenet, see [torchvision models](https://pytorch.org/docs/stable/torchvision/models.html) and, for usage, see the official [tutorial](https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html)

However, these pretrained models will not always suit your needs. For example, the ResNet models are designed for 3-channel, two-dimensional input (i.e. natural RGB images); this means that you can't use them without adjustment on grayscale images or on 3D medical data.




---


**Coding the residual block**

The first thing we need to do is implement a `ResidualBlock` class, which will implement a single ResNet (2015) block. This includes the following steps (see Fig 1): 

1. **(strided) Convolution, followed by batchnorm**, followed by relu: with option to downsample through stride=2 and increase the number of output channels
2. **Convolution, followed by batchnorm:** stride 1; input and output channels constant
3. **shortcut step**, where the input is first transformed through a strided $1 \times 1$ convolutional operation to match the dimensions of the output of the residual block and then added to the output of the convolutions. 
4. **relu**

Note, **only the first convolution of each block offers the option of upsampling the channel dimension and downsampling the data through striding**. Further, several residual blocks are typically chained together between downsampling steps (see the lilac, green and red groups); therefore downsampling is not implemented for all blocks.

<img src="https://drive.google.com/uc?id=1SVmOrg7uxRowWQNtr5jWmz2re9fLID4Q" alt="Drawing" style="width: 800px;"/>

### Ex 1.1 - Create the Residual  block


The most challenging bit of coding up a residual block is implementing the reshaping of the shortcut step. **So let's start by ignoring it to create the main body of the residual block. This will work _provided we maintain input dimensions_**. 

Let us create a `ResidualBlock` and define (parametrise) the required `Conv2d` and `BatchNorm2d` steps in the constructor `__init__(self, channels1,channels2,res_stride=1)`; here `res_stride` is the intended stride, `channels1` is the number of channels of the incoming activations, and `channels2` is the number of output channels. Note, `res_stride=1` by default and this should only change if the block is intended as a downsampling block.

Note, biases are set to `False` in the block as they are instead handled by the batchnorm layer. Also, observe that the Relu layer is implemented in the forward pass function.
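
To see why a convolution bias is redundant here, the following sketch (illustrative names, not part of the practical) adds a constant per-channel offset to some activations and passes both versions through a batchnorm layer in training mode. Because batchnorm subtracts the per-channel batch mean, the offset should cancel.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm2d(3)   # training mode by default: normalises with batch statistics
x = torch.randn(8, 3, 16, 16)

# A constant per-channel offset, playing the role of a convolution bias
offset = torch.tensor([0.5, -2.0, 3.0]).view(1, 3, 1, 1)

out_plain = bn(x)
out_shifted = bn(x + offset)

# The offset is removed by the mean subtraction, up to floating-point error
print(torch.allclose(out_plain, out_shifted, atol=1e-4))
```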

**Task 1.1.1** Edit `__init__` to define:

1. `self.conv1` a 2D convolution with the power to: a) downsample spatial dimensions (with stride `res_stride`); and b) upsample channel dimensions (to `channels2`). Set arguments `kernel_size=3, stride=res_stride, padding=1, bias=False`
2. `self.bn1` a 2D batchnorm layer to follow the first convolutional layer. What does it expect for the number of input features (`num_features`)?
3. `self.conv2` the second convolutional layer. What should its stride, input and output channel dimensions be given **only the first convolution can change output dimensions**? (set `kernel_size=3, padding=1, bias=False` as before)
4. `self.bn2` a 2D batchnorm layer to follow the second convolutional layer. **Note, a different batch normalisation instance is needed each time as each stores learnable parameters**.

**Task 1.1.2** Edit `forward(self, x)` to **implement the shortcut** (a single line, marked by a comment in the code). Here the identity mapping `self.shortcut(x)` must be _added_ to the output of the weight layers. 

For PyTorch documentation, see [nn.Conv2d](https://pytorch.org/docs/stable/nn.html#torch.nn.Conv2d) and [nn.BatchNorm2d](https://pytorch.org/docs/stable/nn.html#torch.nn.BatchNorm2d)

**Make sure you understand what all lines of the forward function are doing**. Note, that the output of the first operation is assigned to variable `out` in order to preserve the input `x` for the shortcut (identity) mapping.


In [None]:
class ResidualBlock(nn.Module):

    def __init__(self, channels1,channels2,res_stride=1):
        super(ResidualBlock, self).__init__()
        self.inplanes=channels1

        # Task 1.1.1 construct the block 
        # implement conv1 (with option for reshaping), conv2 (no reshaping) and a batchnorm layer after each
        self.conv1 = nn.Conv2d(channels1, channels2, kernel_size=3, stride=res_stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels2)
        self.conv2 = nn.Conv2d(channels2, channels2, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels2)  

        self.shortcut=nn.Sequential()

    def forward(self, x):
        
        # forward pass: Conv2d > BatchNorm2d > ReLU > Conv2D >  BatchNorm2d > ADD > ReLU
        out=self.conv1(x)
        out=self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)

        # Task 1.1.2 - implement the shortcut (1 line)
        out += self.shortcut(x)

        # final ReLu
        out = F.relu(out)

        return out


### Ex 1.2 . Perform a test forward pass (_keeping input and output dimensions constant_ )


**Task 1.2.1.** Instantiate an instance of class ResidualBlock (create a network called `blk`) with input channels = 3 and output channels = 3; leave `res_stride` as default (1). **We have not implemented a shortcut with downsampling yet so running with stride will fail.**

**Task 1.2.2.** create a random tensor of size $5 \times 3 \times 100 \times 100$ (which matches expected input dimensions $N,C_{in},H,W$)

**Task 1.2.3.** Pass the input through a forward pass and print input and output shape.

HINT: look at how this was done in previous training loops. Remember, you don't need to explicitly call the forward function.

In [None]:
# Task 1.2.1
blk = ResidualBlock(3,3)


# Task 1.2.2
data = torch.randint(0, 255, (5,3,100,100)).to(torch.float)


# Task 1.2.3
output=blk(data)
print(data.shape,output.shape)

### Ex 1.3. Implement the shortcut

Next, let's implement a shortcut with downsampling. 

Specifically, **if this is a reshaping residual block** (`channels2` $\ne$ `channels1` or `res_stride` $\neq$ 1) then **we will also need to reshape the input as it is passed through the shortcut**. Edit the `ResidualBlock` constructor to complete the shortcut function, which will downsample the input as it is passed through the shortcut. 

**Task 1.3.1. Change `ResidualBlock.__init__()` to implement a `nn.Sequential()` block with two steps:**
1. A $1 \times 1 $ `nn.Conv2d` layer with `stride = res_stride, bias = False`.  This will support changes of spatial dimensions through strided convolutions and changes of feature dimensions through $1 \times 1 $ convolutions. What should your input and output channels be to make it equivalent to the output of a _reshaping_ residual block?
2. Batch normalisation. Think carefully about the input dimension. 
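
As a check on the spatial dimensions, the standard output-size formula for a convolution, $\lfloor (H + 2p - k)/s \rfloor + 1$, can be evaluated for both paths of a reshaping block. The helper `conv_out_size` below is a hypothetical name introduced for illustration.

```python
def conv_out_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution along one dimension."""
    return (size + 2 * padding - kernel_size) // stride + 1

h = 100  # input height used in the test forward pass

# Main path: strided 3x3 convolution (padding 1), then a 3x3 convolution with stride 1
main = conv_out_size(conv_out_size(h, 3, stride=2, padding=1), 3, stride=1, padding=1)

# Shortcut path: strided 1x1 convolution, no padding
shortcut = conv_out_size(h, 1, stride=2, padding=0)

print(main, shortcut)  # 50 50
```

Both paths give the same spatial size, which is what allows the two tensors to be added.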


In [None]:
class ResidualBlock(nn.Module):

    def __init__(self, channels1,channels2,res_stride=1):
        super(ResidualBlock, self).__init__()
        self.inplanes=channels1

        self.conv1 = nn.Conv2d(channels1, channels2, kernel_size=3, stride=res_stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels2)
        self.conv2 = nn.Conv2d(channels2, channels2, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels2)

        if res_stride != 1 or channels2 != channels1:
            # Exercise 1.3 
            # create an nn.Sequential() block with one 1x1 conv2D and one batchnorm
            self.shortcut=nn.Sequential(
                nn.Conv2d(channels1, channels2, kernel_size=1, stride=res_stride, bias=False),
                nn.BatchNorm2d(channels2)
            )

        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        
        out=self.conv1(x)
        out=self.bn1(out)
        out = F.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(x)
        out = F.relu(out)

        return out

**Task 1.3.2.** Test the network again, but this time implement `stride = 2` and change the number of output channels to 10.



 

In [None]:
blk = ResidualBlock(3,10,2)

output=blk(data)

print(data.shape,output.shape)



We now have all the building blocks we need to build a residual network. In what follows we will construct a ResNet with four residual layers, each of which chains several residual blocks together. 


### Ex 1.4 - Create a Residual Network class


In the original paper the network starts with a convolutional layer with a $7 \times 7 $ kernel, followed by a batchnorm. However, as we intend to test on MNIST (whose images are very small), let's change the $7 \times 7 $ kernel to a $3 \times 3 $ one.

We will implement 4 residual layers (groups of blocks), where the residual block class is passed to the network class as the argument `block`, and the output channels and strides for each layer are parametrised by lists (also passed to the constructor) as `num_features` and `num_strides` respectively.

**Task 1.4.1.** 
- Initialise the network with a 3 × 3 convolutional layer, with input channels = `in_channels`, output channel = `num_features[0]`,  stride = `num_strides[0]`, padding=1 and bias false; 
- Implement a batch normalisation layer to follow this.

**Task 1.4.2.** 

- Comment the function `_make_layer`. What is each line doing? 
- Make sure you understand how this is used to create residual blocks in the constructor 

**Task 1.4.3.** 

- The penultimate layer of the network is an average pool which averages over spatial dimensions to return a flattened vector of length equal to the number of channels of the tensor passed to it. 
- The network must output 10 class predictions
- Bearing that in mind, implement the final linear layer of the network 

**hint** if you remain unsure you can always print the shape of all the tensors in the network
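
One way to print these shapes without editing `forward` is to attach forward hooks, sketched below on a small stand-in network (the names `net`, `shapes` and `record_shape` are illustrative). The same hooks could in principle be registered on the layers of your own network.

```python
import torch
import torch.nn as nn

# A small stand-in network; the same hooks can be attached to any nn.Module
net = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)

shapes = []

def record_shape(module, inputs, output):
    # Called after each module's forward pass
    shapes.append((module.__class__.__name__, tuple(output.shape)))

hooks = [layer.register_forward_hook(record_shape) for layer in net]

net(torch.randn(4, 1, 28, 28))

for name, shape in shapes:
    print(name, shape)

# Remove the hooks once they are no longer needed
for h in hooks:
    h.remove()
```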

In [None]:
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_strides, num_features, in_channels, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = num_features[0] # in_planes stores the number channels output from first convolution
        
        # TASK 1.4.1. replace 'None' to initialise the network with a 3x3 conv and batch norm (2 lines):
        self.conv1 = nn.Conv2d(in_channels, num_features[0], kernel_size=3, stride=num_strides[0], padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(num_features[0])

        # Creating 4 residual layers - each layer has one (optionally reshaping) block followed by num_blocks further blocks (num_blocks is an integer)
        self.layer1 = self._make_layer(block, num_features[1], num_blocks, stride=num_strides[1])
        self.layer2 = self._make_layer(block, num_features[2], num_blocks, stride=num_strides[2])
        self.layer3 = self._make_layer(block, num_features[3], num_blocks, stride=num_strides[3])
        self.layer4 = self._make_layer(block, num_features[4], num_blocks, stride=num_strides[4])
        
        # TASK 1.4.3. create linear layer:
        self.linear = nn.Linear(num_features[4], num_classes)

    # TASK 1.4.2. comment the function:
    def _make_layer(self, block, planes, num_blocks, stride):
        layers = []
        # create initial block with the option of downsampling and increasing channels
        layers.append(block(self.in_planes, planes, stride))
        # then append num_blocks further blocks with constant dimensions
        for i in range(num_blocks):
            layers.append(block(planes, planes))
        
        # update class attribute in_planes which is keeping track of input channels
        self.in_planes = planes 
              
        return nn.Sequential(*layers) # return sequential object combining layers
        
    def forward(self, x):
        # initial convolution and batch norm
        out = F.relu(self.bn1(self.conv1(x)))
        # residual blocks 
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        
        # average pool (flattens spatial dimensions)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)       
        out = self.linear(out)

        return out

### Ex 1.5 - Train on MNIST for classification



Below we create an instance of our `ResNet` class with four residual layers. With `num_blocks = 3`, `_make_layer` builds one (optionally reshaping) block followed by 3 further blocks, giving 4 residual blocks per layer.

In [None]:
import torch.optim as optim

resnet = ResNet(ResidualBlock,3, [1,1,2,2,2], [64,64,128,256,512], in_channels=1)

# see how the network is loaded to the device (GPU)
# this allows the optimisation to be run on GPU
resnet = resnet.to(device) 

**Task 1.5.1.** Create a suitable loss function for classification

In [None]:
loss_fun = nn.CrossEntropyLoss()

**Task 1.5.2.** Create an SGD optimiser with momentum, and assign learning rate as 0.001

In [None]:
optimizer = optim.SGD(resnet.parameters(), lr=0.001, momentum=0.9)

**Task 1.5.3.** Complete the training function - don't forget to set runtime to GPU and to push input data and labels (from each batch) to the device.

In [None]:
epochs = 1
for epoch in range(epochs): 

    # enumerate can be used to output iteration index i, as well as the data 
    for i, (data, labels) in enumerate(train_loader, 0):
        # TASK 1.5.3. Complete training loop:
        data = data.to(device)
        labels = labels.to(device)
        # clear the gradient
        optimizer.zero_grad()

        #feed the input and acquire the output from network
        outputs = resnet(data)

        #compute the loss between the network outputs and the labels
        loss = loss_fun(outputs, labels)

        #compute the gradient
        loss.backward()

        #update the parameters
        optimizer.step()
        
        # Print statistics of loss tensor
        ce_loss = loss.item()
        if i % 100 == 0:
            print('[%d, %5d] loss: %.3f' %
                 (epoch + 1, i + 1, ce_loss))

**Task 1.5.4.** Test performance of your network by running the validation code in the cells below:

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Make an iterator from test_loader
test_iterator = iter(test_loader)

# Get a batch of testing images
images, labels = next(test_iterator)

images = images.to(device)
labels = labels.to(device)

y_score = resnet(images)

# Get predicted class from the class probabilities
_, y_pred = torch.max(y_score, 1)

print('Predicted: ', ' '.join('%5s' % classes[y_pred[j]] for j in range(8)))

# Plot true label (t) vs predicted label (p)
rows = 2
columns = 4
fig2 = plt.figure()
for i in range(8):
    fig2.add_subplot(rows, columns, i+1)
    plt.title('t: ' + classes[labels[i].cpu()] + ' p: ' + classes[y_pred[i].cpu()])
    img = images[i] / 2 + 0.5     # this is to unnormalize the image
    img = torchvision.transforms.ToPILImage()(img.cpu())
    plt.axis('off')
    plt.imshow(img)
plt.show()

In [None]:
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score

y_true = labels.data.cpu().numpy()
y_pred = y_pred.data.cpu().numpy()

accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='macro')
precision = precision_score(y_true, y_pred, average='macro')
recall = recall_score(y_true, y_pred, average='macro')

print('accuracy:', accuracy, ', f1 score:', f1, ', precision:', precision, ', recall:', recall)
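
The metrics above are computed on a single batch of 32 images. A sketch of evaluating accuracy over a whole loader follows. It uses a hypothetical helper `evaluate_accuracy` together with a toy model and dataset so that the block runs on its own; in the practical you would pass `resnet` and `test_loader` instead. Switching to `eval()` mode makes batchnorm use its running statistics, and `torch.no_grad()` avoids storing gradients.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate_accuracy(model, loader, device):
    """Fraction of correctly classified examples over every batch in `loader`."""
    model.eval()                      # use running batchnorm statistics
    correct, total = 0, 0
    with torch.no_grad():             # no gradients are needed for evaluation
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Toy stand-ins so the sketch is self-contained
torch.manual_seed(0)
toy_data = TensorDataset(torch.randn(50, 1, 28, 28), torch.randint(0, 10, (50,)))
toy_loader = DataLoader(toy_data, batch_size=16)
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

acc = evaluate_accuracy(toy_model, toy_loader, torch.device('cpu'))
print(acc)
```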



---



## Exercise 2: Image segmentation with PyTorch using U-net




U-net was first developed in 2015 by [Ronneberger et al.](https://arxiv.org/abs/1505.04597) as a segmentation network for biomedical image analysis.
It has been extremely successful, with 9,000+ citations, and many newer methods have adopted the U-net architecture.


The architecture of U-net is based on the idea of using skip connections (i.e. concatenation) at different levels of the network to retain both high- and low-level features.

Here is the architecture of a U-net:

<img src="https://drive.google.com/uc?id=1zUKKrbcB1BZxJ7-hEYpteCVlVFRJ1nRg" alt="Drawing" style="width: 800px;"/>



---


**Dataset**

In this tutorial we use a dataset of cortical neurons with their corresponding segmentation binary labels.

These images were collected using in-vivo two-photon microscopy from the mouse somatosensory cortex. To generate the 2D images, a max projection was used over the 3D stack. The labels are binary segmentation maps of the axons.

Here we will use 100 [64x64] crops during training and validation. Below are some example images [256x256] from the original dataset, taken from [Bass et al 2019](http://proceedings.mlr.press/v102/bass19a.html)

<img src="https://drive.google.com/uc?id=1YRJev88nBr4aqyaHRU27KWxFaX4JwUJz" alt="Drawing" style="width: 800px;"/>

Run the cell below to download the data and auxiliary scripts to your Colab working directory:

In [None]:
!wget -nv https://github.com/IS-pillar-3/datasets/raw/main/session_5/AxonDataset.py
!wget -nv https://github.com/IS-pillar-3/datasets/raw/main/session_5/org64_data_train.npy
!wget -nv https://github.com/IS-pillar-3/datasets/raw/main/session_5/org64_mask_train.npy

In [None]:
!pwd /content

#load modules
from __future__ import print_function
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
import time
import torch.nn.functional as F
import torchvision.utils as vutils
import os
import matplotlib.pyplot as plt

from AxonDataset import *

# Setting parameters
timestr = time.strftime("%d%m%Y-%H%M")
__location__ = os.path.realpath(
    os.path.join(os.getcwd(), os.path.dirname('__file__')))

path = __location__

results_path = os.path.join(__location__,'results')
if not os.path.exists(results_path):
    os.makedirs(results_path)

print(path)

### Ex 2.1 Creating a dataloader

In this example, a custom dataset was created, and we import it from `AxonDataset.py`. 

Then, rather than create a separate test and train instance, we override the default DataLoader `shuffle` option to instead implement a  `torch.utils.data.sampler.SubsetRandomSampler` to 'sample elements randomly from a given list of indices, without replacement' (see [PyTorch documentation](https://pytorch.org/docs/stable/data.html#torch.utils.data.SubsetRandomSampler)).

To do this:
1. data is randomly split into 80% train and 20% validation sets
2. these lists are passed to class `SubsetRandomSampler` to create a sampling instance for each group
3. train and validation DataLoaders are created from the same dataset by passing a different sampler to each

This is a good way of randomly separating your own data in cases where PyTorch does not provide a ready-made Dataset.
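
An alternative that achieves a similar split is `torch.utils.data.random_split`, sketched here on a toy `TensorDataset` rather than the `AxonDataset`. It returns two `Subset` objects, each of which can be given its own `DataLoader` with the usual `shuffle` option. The `toy_` names are illustrative.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Toy dataset of 100 image/mask pairs standing in for the axon crops
toy_dataset = TensorDataset(torch.randn(100, 1, 64, 64), torch.zeros(100, 1, 64, 64))

n_val = int(len(toy_dataset) * 0.2)
n_train = len(toy_dataset) - n_val

# A seeded generator makes the split reproducible
generator = torch.Generator().manual_seed(0)
train_set, val_set = random_split(toy_dataset, [n_train, n_val], generator=generator)

toy_train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
toy_val_loader = DataLoader(val_set, batch_size=16, shuffle=False)

print(len(train_set), len(val_set))  # 80 20
```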

**Task 2.1.1.** Comment the lines of code below to verify you understand how the bespoke samplers are implemented

In [None]:
#First we create a custom dataset of two photon microscopy images of axons
axon_dataset = AxonDataset(data_name='org64', type='train')

# We need to further split our training dataset into training and validation sets.
# Determine the number of examples in train and validation and create list of all indices
indices = list(range(len(axon_dataset)))  
# define the split size
split = int(len(indices)*0.2)  

# Get random list of indices for validation
validation_idx = np.random.choice(indices, size=split, replace=False)
# training examples are remainder 
train_idx = list(set(indices) - set(validation_idx))

# feed indices into the SubsetRandomSampler to create a sampling instance for each of train and validation
train_sampler = SubsetRandomSampler(train_idx)
validation_sampler = SubsetRandomSampler(validation_idx)

batch_size = 16

# Create a dataloader instance overriding shuffle to pass the bespoke samplers
train_loader = torch.utils.data.DataLoader(axon_dataset, batch_size = batch_size,
                                           sampler=train_sampler) 
val_loader = torch.utils.data.DataLoader(axon_dataset, batch_size = batch_size,
                                        sampler=validation_sampler) 




---


### Building a U-net

The original U-net encoder performs downsampling through a $2 \times 2 $ max pool (however, strided convolutions are equally viable). Thus, in what follows, a single level of encoding can be represented as:

```
 conv1 = self.dconv_down1(x)
 conv1 = self.dropout(conv1)
 x = self.maxpool(conv1)
 ```
Here, a dropout layer is inserted between the convolutional layer and the maxpool for regularisation. An alternative approach is to insert a batchnorm between the `nn.Conv2d` and the `nn.ReLU` e.g. [see](https://github.com/milesial/Pytorch-UNet)

Next we need to define how we perform an upsampling step. This is done with [`nn.Upsample`](https://pytorch.org/docs/stable/nn.html#torch.nn.Upsample), which interpolates the data onto a higher-resolution grid. The layer must be created in the constructor (see `self.upsample` in the `UNet` class below) and expects the arguments `scale_factor` and (interpolation) `mode`. There are several options for the interpolation mode; we recommend bilinear. In this example we upsample by a `scale_factor` of 2 each time (to match the $2\times 2$ max pool used during downsampling). 
        
Then, a single level of decoding may be represented as:

```
 deconv4 = self.upsample(conv5)
 deconv4  = self.dconv_up4(deconv4)
 deconv4 = self.dropout(deconv4)
 ```
 
However, we are still missing something vital...

**Skip connections**

The U-net is a symmetric network with equal numbers of encoding and decoding layers. These form pairs where the spatial dimensions of each encoder/decoder layer in the pair are consistent.

A key feature of the U-net is that, to support segmentation of sharp boundaries with preservation of high-resolution features, it passes features learnt during encoding across the network. The theory is that the early layers, with their small receptive fields, learn the high-spatial-frequency information (i.e. they act as edge detectors and/or texture filters). As the receptive field increases during encoding, spatial specificity is lost, but spatial localisation (where class-relevant objects broadly are in the image) is gained. In order to import the high-spatial-frequency information of the early encoding layers into the final decoding layers, the *activations* learnt during encoding are directly concatenated onto the upsampled activations of the paired decoding layer.

In other words, the first decoding layer (which for a 5-layer U-Net is the layer that directly follows the bottleneck `conv5`) is:

```
 deconv4 = self.upsample(conv5)
 deconv4 = torch.cat([deconv4, conv4], dim=1)
 deconv4  = self.dconv_up4(deconv4)
 deconv4 = self.dropout(deconv4)
 ```
 
The activations (output) of convolution layer 4 (`conv4`) are concatenated directly onto the output of `self.upsample`, with the concatenation performed along the channel axis (`dim=1`). Thus, putting this all together...

### Ex 2.2 - Creating the U-net Class

First we define a layer double_conv that performs 2 sets of convolution followed by ReLu. This is set up as a nn.Sequential() block.

In [None]:
def double_conv(in_channels, out_channels, padding=1):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=padding),
        nn.ReLU(inplace=True)
    )

We then initialise all the different layers in the network in `__init__`:
1. `self.dconv_down1` is a double convolutional layer (defined above)
2. `self.maxpool` is a max pooling layer that is used to reduce the size of the input, and increase the receptive field
3. `self.upsample` is an upsampling layer that is used to increase the size of the input
4. `dropout` is a dropout layer that is applied to regularise the training
5. `dconv_up4` is also a double convolutional layer- note that it takes in additional channels from previous layers (i.e. the skip connections).
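
The channel count expected by `dconv_up4` can be checked with a short sketch. It assumes a 64×64 input, so that `conv4` has spatial size 8×8 and the bottleneck `conv5` has size 4×4. The random tensors below stand in for real activations.

```python
import torch
import torch.nn as nn

upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

conv4 = torch.randn(1, 256, 8, 8)   # encoder activations kept for the skip connection
conv5 = torch.randn(1, 512, 4, 4)   # bottleneck activations

up = upsample(conv5)                      # spatial size doubles: 4x4 -> 8x8
merged = torch.cat([up, conv4], dim=1)    # channels add up: 512 + 256

print(up.shape, merged.shape)
```

The concatenated tensor has 768 channels, which is why the first argument of the corresponding `double_conv` is `256 + 512`.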


**Task 2.2.1.** Following the example for `conv1` complete encoder layers 2, 3 and 4. How many features does each layer have?

**Task 2.2.2.** Complete layer `conv5`; this is the bottleneck layer (the bottom of the network) and thus **has no maxpool**.

**Task 2.2.3.**  Using the upsampling and skip connection example above implement the decoder layers `deconv4`, `deconv3`, `deconv2`, `deconv1`.

**Task 2.2.4.** We are expecting class labels as output, so the output requires a sigmoid transformation. Check that you understand what this does.

In [None]:
class UNet(nn.Module):

    def __init__(self):
        super().__init__()
        
        self.dconv_down1 = double_conv(1, 32)
        self.dconv_down2 = double_conv(32, 64)
        self.dconv_down3 = double_conv(64, 128)
        self.dconv_down4 = double_conv(128, 256)
        self.dconv_down5 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
        self.dropout = nn.Dropout2d(0.5)
        self.dconv_up4 = double_conv(256 + 512, 256)
        self.dconv_up3 = double_conv(128 + 256, 128)
        self.dconv_up2 = double_conv(128 + 64, 64)
        self.dconv_up1 = double_conv(64 + 32, 32)

        self.conv_last = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        
        ####### ENCODER #######
        
        # layer 1
        conv1 = self.dconv_down1(x)
        conv1 = self.dropout(conv1)
        x = self.maxpool(conv1)

        # TASK 2.2.1. Replace 'nones' to implement encoder layers conv2, conv3 and conv4
        # layer 2
        conv2 = self.dconv_down2(x)
        conv2 = self.dropout(conv2)
        x = self.maxpool(conv2)

        # layer 3
        conv3 = self.dconv_down3(x)
        conv3 = self.dropout(conv3)
        x = self.maxpool(conv3)

        # layer 4
        conv4 = self.dconv_down4(x)
        conv4 = self.dropout(conv4)
        x = self.maxpool(conv4)

        # TASK 2.2.2. Replace 'nones' to implement bottleneck (hint: 2 lines)
        # layer 5
        conv5 = self.dconv_down5(x)
        conv5 = self.dropout(conv5) 
        

        ####### DECODER #######

        # TASK 2.2.3. Implement the decoding layers
        deconv4 = self.upsample(conv5)
        deconv4 = torch.cat([deconv4, conv4], dim=1)  
        deconv4  = self.dconv_up4(deconv4)
        deconv4 = self.dropout(deconv4)

        deconv3 = self.upsample(deconv4 )       
        deconv3 = torch.cat([deconv3, conv3], dim=1)
        deconv3 = self.dconv_up3(deconv3)
        deconv3 = self.dropout(deconv3)

        deconv2 = self.upsample(deconv3)      
        deconv2 = torch.cat([deconv2, conv2], dim=1)
        deconv2 = self.dconv_up2(deconv2)
        deconv2 = self.dropout(deconv2)
       
        deconv1 = self.upsample(deconv2)   
        deconv1 = torch.cat([deconv1, conv1], dim=1)
        deconv1 = self.dconv_up1(deconv1)
        deconv1 = self.dropout(deconv1)

        out = torch.sigmoid(self.conv_last(deconv1))  # F.sigmoid is deprecated

        return out

### Exercise 2.3 - Create loss function


We next define our loss function - in this case we use Dice loss, a commonly used loss for image segmentation that is essentially a measure of overlap between two samples.

Dice is in the range of 0 to 1, where a Dice coefficient of 1 denotes perfect and complete overlap. The Dice coefficient was originally developed for binary data, and can be calculated as:

$Dice = \dfrac{2|A\cap B|}{|A| + |B|}$

where $|A\cap B|$ represents the common elements between sets $A$ and $B$, and $|A|$ represents the number of elements in set $A$ (and likewise for set $B$).

For the case of evaluating a Dice coefficient on predicted segmentation masks, we can get the intersection $|A\cap B|$ as the element-wise multiplication between the prediction and target mask, and then sum the resulting matrix.
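
A small worked example of this calculation, using two hand-written binary masks and no smoothing term (the tensors are illustrative, not taken from the dataset):

```python
import torch

pred = torch.tensor([[1., 1., 0., 0.],
                     [0., 1., 0., 0.]])
target = torch.tensor([[1., 0., 0., 0.],
                       [0., 1., 0., 0.]])

# Element-wise product is 1 only where both masks are 1; summing counts the overlap
intersection = (pred * target).sum()                      # 2 overlapping pixels

# |A| = 3 foreground pixels in pred, |B| = 2 in target
dice = 2 * intersection / (pred.sum() + target.sum())     # 2*2 / (3 + 2)

print(dice.item())
```

Here the overlap is 2 pixels out of 3 predicted and 2 target pixels, giving a Dice coefficient of 4/5.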

**Task 2.3.1.** Implement the calculation of the Dice coefficient in the function below:

In [None]:
def dice_coeff(pred, target):
    smooth = 1. # smoothing term: avoids division by zero when both masks are empty
    epsilon = 10e-8 # keeps the returned value strictly below 1

    # flatten prediction and target into 1D vectors
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    # intersection: element-wise product, then sum
    intersection = (iflat * tflat).sum()

    A_sum = torch.sum(iflat) 
    B_sum = torch.sum(tflat)

    # TASK 2.3.1. replace 'None' to implement the dice coefficient 
    # smooth is applied once in the numerator and once in the denominator
    dice = (2. * intersection + smooth) / (A_sum + B_sum + smooth)
    dice = torch.clamp(dice, 0, 1.0-epsilon)

    return  dice

An **alternative loss** function would be pixel-wise binary cross entropy (BCE) loss. It would examine each pixel individually, comparing the class predictions (depth-wise pixel vector) to our one-hot encoded target vector.

**Task 2.3.2.** Implement the binary cross entropy loss using the corresponding PyTorch function.

In [None]:
# TASK 2.3.2. implement the BCE loss
loss_BCE = nn.BCELoss()



---



### Saving and loading models
Training this network from scratch would take too long and require large computational resources. To save time, we initialise the network with previously trained weights, loaded in the following way:

In [None]:
# Initialise network - and load weights
net = UNet()

# This function loads a pretrained network
net.load_state_dict(torch.load(path + '/' + 'model.pt', map_location=torch.device(device)))
net = net.to(device)

# Example how to save a model - check in your results path
torch.save(net.state_dict(), path + '/model_save_test.pt')

The [PyTorch documentation](https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html#save-the-general-checkpoint) advises saving and loading not just the network parameters but also the state of the optimiser, the current loss and the epoch:

<figure>
<img src="https://drive.google.com/uc?id=1hU498xlA_DbstHSSUwqfm9U9fqtGCZw1" alt="Drawing" width="800px;"/>
</figure>

<figure>
<img src="https://drive.google.com/uc?id=1mQllAgFWxZ9ViaXJmPjeStXJViroBQHi" alt="Drawing" width="800px;"/>
</figure>

More details on options for saving and loading are provided [here](https://pytorch.org/tutorials/beginner/saving_loading_models.html). These provide an example of the model and optimiser `state_dict()` where you can see these are python dictionaries with:

- **in the case of the model**: keys which store the current state of weight and bias tensors of the model. In more general terms the model dict will store all parameter tensors required to restart the model
- **in the case of the optimiser**: this dict stores the hyper-parameters of the optimiser e.g. learning rate, momentum, weight decay etc as well as the current state of the optimiser object.

Note, **if saving for inference _only_, it is only necessary to save the `model.state_dict()`**

A common PyTorch convention is to save models using either a .pt or .pth file extension (see the example of `SAVE_PATH` above).
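
The save and load pattern described above can be sketched in text form as follows. This is a minimal sketch under stated assumptions: `toy_net`, `toy_opt`, the temporary `SAVE_PATH` and the epoch number are hypothetical stand-ins, not the U-Net, optimiser or paths used in this practical.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Toy stand-ins for the network and optimiser (hypothetical, illustration only)
toy_net = nn.Linear(4, 1)
toy_opt = torch.optim.SGD(toy_net.parameters(), lr=0.1, momentum=0.9)

# One optimisation step so the optimiser holds internal state (momentum buffers)
toy_loss = toy_net(torch.randn(8, 4)).pow(2).mean()
toy_loss.backward()
toy_opt.step()

SAVE_PATH = os.path.join(tempfile.mkdtemp(), 'toy_checkpoint.pt')

# Save: model weights, optimiser state, epoch and loss in one dictionary
torch.save({
    'epoch': 3,
    'model_state_dict': toy_net.state_dict(),
    'optimizer_state_dict': toy_opt.state_dict(),
    'loss': toy_loss.item(),
}, SAVE_PATH)

# Load: rebuild the objects first, then restore their state
restored_net = nn.Linear(4, 1)
restored_opt = torch.optim.SGD(restored_net.parameters(), lr=0.1, momentum=0.9)

toy_checkpoint = torch.load(SAVE_PATH)
restored_net.load_state_dict(toy_checkpoint['model_state_dict'])
restored_opt.load_state_dict(toy_checkpoint['optimizer_state_dict'])
start_epoch = toy_checkpoint['epoch'] + 1  # resume from the following epoch

print(list(toy_checkpoint.keys()), start_epoch)
```

Rebuilding the optimiser before calling `load_state_dict` matters because the saved state is matched to the parameters of the newly created model.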

### Exercise 2.4 - Training and Evaluation

Review the stages of training with:
- Network set as `net.train()` for training and `net.eval()` for validation
- Clearing of gradients
- Loss as `loss = 1 - dice_coeff(pred,target)` where `pred` is the output of the forward pass
- Backpropagation and update
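
The stages listed above can be sketched on a toy problem as follows. This is only an illustration under stated assumptions: `demo_net`, the random data and the simplified `soft_dice` helper are hypothetical stand-ins for the U-Net, the data loaders and `dice_coeff` used in this practical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (hypothetical): a one-layer "segmentation" net and random data
demo_net = nn.Sequential(nn.Conv2d(1, 1, kernel_size=3, padding=1), nn.Sigmoid())
demo_optimizer = torch.optim.Adam(demo_net.parameters(), lr=1e-2)
demo_data = torch.rand(4, 1, 8, 8)
demo_label = (demo_data > 0.5).float()

def soft_dice(pred, target, smooth=1.):
    # Simplified soft Dice for this sketch; not the practical's dice_coeff
    inter = (pred * target).sum()
    return (2. * inter + smooth) / (pred.sum() + target.sum() + smooth)

weights_before = demo_net[0].weight.detach().clone()

demo_net.train()                                   # training mode
demo_optimizer.zero_grad()                         # clear gradients from the previous step
demo_pred = demo_net(demo_data)                    # forward pass
demo_loss = 1 - soft_dice(demo_pred, demo_label)   # Dice loss
demo_loss.backward()                               # backpropagation
demo_optimizer.step()                              # parameter update

demo_net.eval()                                    # validation mode: no update steps
with torch.no_grad():
    val_dice = soft_dice(demo_net(demo_data), demo_label)

print(demo_loss.item(), val_dice.item())
```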


The results for both training and validation are saved in the `content/results/` folder as images of the:

* real data,
* binary labels,
* predicted labels.

In this example, since we train on a small sample of the data (100 crops), the results are far from optimal and the network is likely to overfit.

In [None]:
import time  # used for the elapsed-time printouts below
epochs = 10
save_every = 10
all_error = np.zeros(0)
all_error_L1 = np.zeros(0)
all_error_dice = np.zeros(0)
all_dice = np.zeros(0)
all_val_dice = np.zeros(1)
all_val_error = np.zeros(0)

for epoch in range(epochs):

    ##########
    # Train
    ##########

    # set network to train prior to training loop 
    net.train() 
    t0 = time.time()
    for i, (data, label) in enumerate(train_loader):
        data = data.to(device)
        label= label.to(device)
        
        optimizer.zero_grad() 

        batch_size = data.size()[0]
        pred = net(data)

        dice_value = dice_coeff(pred, label)
        
        loss = 1 - dice_value

        loss.backward()
        optimizer.step()
        
        time_elapsed = time.time() - t0
        print('[{:d}/{:d}][{:d}/{:d}] Elapsed_time: {:.0f}m{:.0f}s Loss: {:.4f} Dice: {:.4f}'
              .format(epoch, epochs, i, len(train_loader), time_elapsed // 60, time_elapsed % 60,
                      loss.item(), dice_value.item()))

        if i % save_every == 0:

            vutils.save_image(data.data, '%s/epoch_%03d_i_%03d_train_data.png' % (path, epoch, i),
                                  normalize=True)
            vutils.save_image(label.data, '%s/epoch_%03d_i_%03d_train_label.png' % (path, epoch, i),
                                  normalize=True)
            vutils.save_image(pred.data, '%s/epoch_%03d_i_%03d_train_pred.png' % (path, epoch, i),
                                  normalize=True)
            error = loss.item()

            all_error = np.append(all_error, error)
            all_dice = np.append(all_dice, dice_value.item())

    # Task 2.4.2 - Checkpointing: save model and optimiser state every epoch
    torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss.item(),  # last training loss of this epoch
            }, path+'/model_ch.pt')

    # #############
    # # Validation
    # #############
    mean_error = np.zeros(0)
    mean_dice = np.zeros(0)
    t0 = time.time()

    # set network to eval prior to validation loop
    net.eval()
    for i, (data, label) in enumerate(val_loader):
        data=data.to(device)
        label=label.to(device)     
        batch_size = data.size()[0]

        pred = net(data)

        dice_value = dice_coeff(pred, label)

        err = 1 - dice_value

        if i == 0:
            vutils.save_image(data.data, '%s/epoch_%03d_i_%03d_val_data.png' % (path, epoch, i),
                              normalize=True)
            vutils.save_image(label.data, '%s/epoch_%03d_i_%03d_val_label.png' % (path, epoch, i),
                              normalize=True)
            vutils.save_image(pred.data, '%s/epoch_%03d_i_%03d_val_pred.png' % (path, epoch, i),
                              normalize=True)

        error = err.item()
        mean_error = np.append(mean_error, error)
        mean_dice = np.append(mean_dice, dice_value.item())

    all_val_error = np.append(all_val_error, np.mean(mean_error))
    all_val_dice = np.append(all_val_dice, np.mean(mean_dice))

    time_elapsed = time.time() - t0

    print('Elapsed_time: {:.0f}m{:.0f}s Val dice: {:.4f}'
          .format(time_elapsed // 60, time_elapsed % 60, mean_dice.mean()))
    
    num_it_per_epoch_train = ((train_loader.dataset.x_data.shape[0] * (1 - 0.2)) // 
                              (save_every * batch_size)) + 1

    epochs_train = np.arange(1,all_error.size+1) / num_it_per_epoch_train
    epochs_val = np.arange(0,all_val_dice.size)

**Task 2.4.1.** Which operations are implemented only in the training loop and not during validation? Answer in the cell below:

In [None]:
# optimizer.zero_grad(), loss.backward() and optimizer.step()
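
Two related points can be illustrated with a short sketch: `eval()` changes the behaviour of layers such as dropout, and validation can optionally be wrapped in `torch.no_grad()` so that no autograd graph is built. The loops in this practical do not use `torch.no_grad()`; it is shown here as an optional addition, and the tensors `x` and `w` are made-up values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
out_train = drop(x)   # some entries zeroed, the rest scaled by 1 / (1 - p) = 2

drop.eval()
out_eval = drop(x)    # identity: dropout is disabled in eval mode

w = torch.ones(3, requires_grad=True)
with torch.no_grad():
    y = (w * 2).sum()  # computed without recording an autograd graph

print((out_train == 0).float().mean().item(), torch.equal(out_eval, x), y.requires_grad)
```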

**Task 2.4.2.** Edit the training/validation code to add checkpointing using torch.save() - i.e. save the model and optimiser state_dict every epoch. 

**Task 2.4.3.** Plot the Dice coefficient of the validation set using the plotting function below, and note down your validation Dice scores.

In [None]:
%matplotlib inline

plt.figure()
plt.plot(epochs_val, all_val_dice, label='dice_val')
plt.xlabel('epochs')
plt.legend()
plt.title('Dice score')
plt.savefig(path + '/dice_val.png')
plt.close()

**Task 2.4.4.** Change the dice loss to a binary cross entropy loss in the code and plot the results again - is dice loss or cross entropy loss better?

In [None]:
import time  # used for the elapsed-time printouts below
epochs = 10
save_every = 10
all_error = np.zeros(0)
all_error_L1 = np.zeros(0)
all_error_dice = np.zeros(0)
all_dice = np.zeros(0)
all_val_dice = np.zeros(1)
all_val_error = np.zeros(0)

for epoch in range(epochs):

    ##########
    # Train
    ##########

    # set network to train prior to training loop 
    net.train() 
    t0 = time.time()
    for i, (data, label) in enumerate(train_loader):
        data = data.to(device)
        label= label.to(device)
        
        optimizer.zero_grad() 

        batch_size = data.size()[0]
        pred = net(data)

        dice_value = dice_coeff(pred, label)
        
        loss = loss_BCE(pred, label)

        loss.backward()
        optimizer.step()
        
        time_elapsed = time.time() - t0
        print('[{:d}/{:d}][{:d}/{:d}] Elapsed_time: {:.0f}m{:.0f}s Loss: {:.4f} Dice: {:.4f}'
              .format(epoch, epochs, i, len(train_loader), time_elapsed // 60, time_elapsed % 60,
                      loss.item(), dice_value.item()))

        if i % save_every == 0:

            vutils.save_image(data.data, '%s/epoch_%03d_i_%03d_train_data.png' % (path, epoch, i),
                                  normalize=True)
            vutils.save_image(label.data, '%s/epoch_%03d_i_%03d_train_label.png' % (path, epoch, i),
                                  normalize=True)
            vutils.save_image(pred.data, '%s/epoch_%03d_i_%03d_train_pred.png' % (path, epoch, i),
                                  normalize=True)
            error = loss.item()

            all_error = np.append(all_error, error)
            all_dice = np.append(all_dice, dice_value.item())


    # #############
    # # Validation
    # #############
    mean_error = np.zeros(0)
    mean_dice = np.zeros(0)
    t0 = time.time()

    # set network to eval prior to validation loop
    net.eval()
    for i, (data, label) in enumerate(val_loader):
        data=data.to(device)
        label=label.to(device)     
        batch_size = data.size()[0]

        pred = net(data)

        dice_value = dice_coeff(pred, label)

        err = loss_BCE(pred, label)

        if i == 0:
            vutils.save_image(data.data, '%s/epoch_%03d_i_%03d_val_data.png' % (path, epoch, i),
                              normalize=True)
            vutils.save_image(label.data, '%s/epoch_%03d_i_%03d_val_label.png' % (path, epoch, i),
                              normalize=True)
            vutils.save_image(pred.data, '%s/epoch_%03d_i_%03d_val_pred.png' % (path, epoch, i),
                              normalize=True)

        error = err.item()
        mean_error = np.append(mean_error, error)
        mean_dice = np.append(mean_dice, dice_value.item())

    all_val_error = np.append(all_val_error, np.mean(mean_error))
    all_val_dice = np.append(all_val_dice, np.mean(mean_dice))

    time_elapsed = time.time() - t0

    print('Elapsed_time: {:.0f}m{:.0f}s Val dice: {:.4f}'
          .format(time_elapsed // 60, time_elapsed % 60, mean_dice.mean()))
    
    num_it_per_epoch_train = ((train_loader.dataset.x_data.shape[0] * (1 - 0.2)) // 
                              (save_every * batch_size)) + 1

    epochs_train = np.arange(1,all_error.size+1) / num_it_per_epoch_train
    epochs_val = np.arange(0,all_val_dice.size)

In [None]:
%matplotlib inline

plt.figure()
plt.plot(epochs_val, all_val_dice, label='dice_val')
plt.xlabel('epochs')
plt.legend()
plt.title('Dice score')
plt.savefig(path + '/dice_val_bce.png')  # separate file so the Dice-loss plot is not overwritten
plt.close()
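
When comparing the two losses, it can help to look at how each responds to class imbalance. The sketch below uses a made-up 10x10 target with two foreground pixels and a prediction that labels almost everything as background. All values are illustrative assumptions, not results from this practical.

```python
import torch
import torch.nn.functional as F

# Hypothetical sparse target: 2 foreground pixels out of 100
sparse_target = torch.zeros(10, 10)
sparse_target[4, 4] = 1.
sparse_target[4, 5] = 1.

# Hypothetical prediction that calls (almost) everything background
background_pred = torch.full((10, 10), 0.01)

# BCE is averaged over all 100 pixels
bce_value = F.binary_cross_entropy(background_pred, sparse_target)

# Soft Dice only rewards overlap with the foreground
overlap = (background_pred * sparse_target).sum()
soft_dice_value = 2. * overlap / (background_pred.sum() + sparse_target.sum())
dice_loss_value = 1 - soft_dice_value

print(bce_value.item(), dice_loss_value.item())
```

Because BCE is averaged over all pixels, the 98 correctly classified background pixels dominate and the loss is small even though both foreground pixels are missed, whereas the Dice loss stays close to 1. Which loss works better in practice still depends on the data and should be judged from your own validation curves.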

**Task 2.4.5.** Re-load and re-train from your saved model, making sure to load the state dict for the model *and* the optimiser.

In [None]:
net = UNet().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-05, betas=(0.5, 0.999))

# restore the model and optimiser state from the checkpoint saved in Task 2.4.2
checkpoint = torch.load(path+'/model_ch.pt', map_location=device)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch_saved = checkpoint['epoch']
loss = checkpoint['loss']

epochs=20
all_error = np.zeros(0)
all_error_L1 = np.zeros(0)
all_error_dice = np.zeros(0)
all_dice = np.zeros(0)
all_val_dice = np.zeros(1)
all_val_error = np.zeros(0)

for epoch in range(epoch_saved + 1, epochs):  # resume from the epoch after the saved one

    ##########
    # Train
    ##########

    # set network to train prior to training loop 
    net.train() # enables training-time behaviour of layers such as dropout
    t0 = time.time()
    for i, (data, label) in enumerate(train_loader):
        data = data.to(device)
        label= label.to(device)
        
        optimizer.zero_grad() 

        batch_size = data.size()[0]
        pred = net(data)
        
        dice = dice_coeff(pred, label)
        err = 1 - dice  # alternatively: loss_BCE(pred, label)

        dice_value = dice.item()

        err.backward()
        optimizer.step()
        

        time_elapsed = time.time() - t0
        print('[{:d}/{:d}][{:d}/{:d}] Elapsed_time: {:.0f}m{:.0f}s Loss: {:.4f} Dice: {:.4f}'
              .format(epoch, epochs, i, len(train_loader), time_elapsed // 60, time_elapsed % 60,
                      err.item(), dice_value))

        if i % save_every == 0:

            vutils.save_image(data.data, '%s/epoch_%03d_i_%03d_train_data.png' % (path, epoch, i),
                                  normalize=True)
            vutils.save_image(label.data, '%s/epoch_%03d_i_%03d_train_label.png' % (path, epoch, i),
                                  normalize=True)
            vutils.save_image(pred.data, '%s/epoch_%03d_i_%03d_train_pred.png' % (path, epoch, i),
                                  normalize=True)

            error = err.item()

            all_error = np.append(all_error, error)
            all_dice = np.append(all_dice, dice_value)

   
    # #############
    # # Validation
    # #############
    mean_error = np.zeros(0)
    mean_dice = np.zeros(0)
    t0 = time.time()

    # set network to eval prior to validation loop
    net.eval()
    for i, (data, label) in enumerate(val_loader):
        data=data.to(device)
        label=label.to(device)     
        batch_size = data.size()[0]

        pred = net(data)
        
        # compare prediction to label - metric
        dice = dice_coeff(pred, label)
        err = 1 - dice  # alternatively: loss_BCE(pred, label)
        dice_value = dice.item()

        if i == 0:
            vutils.save_image(data.data, '%s/epoch_%03d_i_%03d_val_data.png' % (path, epoch, i),
                              normalize=True)
            vutils.save_image(label.data, '%s/epoch_%03d_i_%03d_val_label.png' % (path, epoch, i),
                              normalize=True)
            vutils.save_image(pred.data, '%s/epoch_%03d_i_%03d_val_pred.png' % (path, epoch, i),
                              normalize=True)

        error = err.item()
        mean_error = np.append(mean_error, error)
        mean_dice = np.append(mean_dice, dice_value)

    all_val_error = np.append(all_val_error, np.mean(mean_error))
    all_val_dice = np.append(all_val_dice, np.mean(mean_dice))

    time_elapsed = time.time() - t0

    print('Elapsed_time: {:.0f}m{:.0f}s Val dice: {:.4f}'
          .format(time_elapsed // 60, time_elapsed % 60, mean_dice.mean()))
    
    
    num_it_per_epoch_train = ((train_loader.dataset.x_data.shape[0] * (1 - 0.2)) // (
            save_every * batch_size)) + 1
    epochs_train = np.arange(1,all_error.size+1) / num_it_per_epoch_train
    epochs_val = np.arange(0,all_val_dice.size)