# CNN Architectures

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F #contains some useful functions like activation functions & convolution operations you can use

device = torch.device("cuda: 0" if torch.cuda.is_available() else "cpu")


## ResNet with pytorch

ResNet was first introduced in 2016 as a way to deal with the gradient vanishing problem. This can occur when the network is too deep, and the gradients shrink to zero after a few back propagation steps. This can result in the parameter weights not being updated, since the gradient is zero.

ResNets can counter this problem by allowing the gradients to flow directly backwards, by adding the additive resnet connections.

An example of a resnet block is illustrated below:

![resnet-block](resnet-block.png)
source: https://d2l.ai/chapter_convolutional-modern/resnet.html

He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.

http://www.pabloruizruiz10.com/resources/CNNs/ResNet-PyTorch.html

https://towardsdatascience.com/understanding-and-visualizing-resnets-442284831be8


### Using existing ResNet 

It's possible to load existing networks using pytorch library torchvision - you can load these models using torchvision.models, which contains networks such as ResNet, Alexnet, VGG, Densenet, etc...
https://pytorch.org/docs/stable/torchvision/models.html

For example the following pretrained resnets models can be loaded in Pytorch:
```python
torchvision.models.resnet18(pretrained=True, **kwargs)
```

You can also load a model that hasn't been pretrained in the following way:
```python
torchvision.models.resnet18(pretrained=False, **kwargs)
```

You can find examples of how to use pretrained models in: https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html

However, you will find that using a pretrained model doesn't always suit your needs. For example, the resnet models shown above have been trained on RGB images (i.e. they are 3 channels), which means that you can't use them without adjustment on grayscale images, or on 3D medical data.


### Create custom ResNet layers 

We will therefore now cover how to create a custom ResNet layer that be used to build custom deep neural networks.

A resnet block (implemented as the BasicBlock class below) has the following steps:
1. Convolution, followed by batchnorm, followed by relu
2. Convolution, followed by batchnorm 
3. shortcut step, added to the output of the convolutions
4. relu

### Task 3 - implement the forward pass of the BasicBlock

All the layers have been initialized in the code below.
Implement a forward pass through the BasicBlock. 


In [17]:
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        #-----------------------------------------------------task 3 -----------------------------------------------------
        # Task 3: implement a forward pass
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        #-----------------------------------------------------------------------------------------------------------------
        return out


In [18]:
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, in_channels, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
    
def ResNet18(in_channels=1):
    return ResNet(BasicBlock, [2,2,2,2], in_channels=in_channels)



## Use ResNet for classification - MNIST

First - importing MNIST data into this notebook

In [22]:
import torchvision
import numpy as np
from torchvision import datasets, models, transforms

# This is used to transform the images to Tensor and normalize it
transform = transforms.Compose(
   [transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])])

training = torchvision.datasets.MNIST(root='./data', train=True,
                                       download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(training, batch_size=8,
                                         shuffle=True, num_workers=2)

testing = torchvision.datasets.MNIST(root='./data', train=False,
                                      download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testing, batch_size=8,
                                        shuffle=False, num_workers=2)

classes = ('0', '1', '2', '3',
          '4', '5', '6', '7', '8', '9')


### Task 4 - run ResNet on MNIST for classification
Create a ResNet network and run with the same code as above for classification, and then test.
Remember to define your loss function, optimizer, dataloaders, and your resnet network. 
Then run the training and testing, as before.



In [24]:
#-----------------------------------------------------task 4 -----------------------------------------------------
# Task 4: Train and test ResNet on MNIST dataset for classification
# hints: define your resnet network, loss function, optimizer and dataloaders. 
# Then you can run the same training and testing code as above.
# ----------------------------------------------------------------------------------------------------------------

In [25]:
import torch.optim as optim

resnet = ResNet18(in_channels=1)
resnet = resnet.to(device)

loss_fun = nn.CrossEntropyLoss()
loss_fun = loss_fun.to(device)

optimizer = optim.SGD(resnet.parameters(), lr=0.001, momentum=0.9)

In [26]:
epochs = 1
for epoch in range(epochs): 

    # enumerate can be used to output iteration index i, as well as the data 
    for i, (data, labels) in enumerate(train_loader, 0):
        data = data.to(device)
        labels = labels.to(device)
        # clear the gradient
        optimizer.zero_grad()

        #feed the input and acquire the output from network
        outputs = resnet(data)

        #calculating the predicted and the expected loss
        loss = loss_fun(outputs, labels)

        #compute the gradient
        loss.backward()

        #update the parameters
        optimizer.step()

        # print statistics
        ce_loss = loss.item()
        if i % 10 == 0:
            print('[%d, %5d] loss: %.3f' %
                 (epoch + 1, i + 1, ce_loss))


[1,     1] loss: 2.199
[1,    11] loss: 2.178
[1,    21] loss: 1.931
[1,    31] loss: 2.066
[1,    41] loss: 2.093
[1,    51] loss: 0.878
[1,    61] loss: 0.833
[1,    71] loss: 0.736
[1,    81] loss: 0.985
[1,    91] loss: 0.976
[1,   101] loss: 0.869
[1,   111] loss: 1.298
[1,   121] loss: 1.882
[1,   131] loss: 0.221
[1,   141] loss: 0.600
[1,   151] loss: 0.119
[1,   161] loss: 0.652
[1,   171] loss: 0.109
[1,   181] loss: 0.194
[1,   191] loss: 0.247
[1,   201] loss: 0.297
[1,   211] loss: 0.171
[1,   221] loss: 0.148
[1,   231] loss: 0.257
[1,   241] loss: 0.583
[1,   251] loss: 0.538
[1,   261] loss: 0.341
[1,   271] loss: 0.180
[1,   281] loss: 0.238
[1,   291] loss: 0.049
[1,   301] loss: 0.171
[1,   311] loss: 0.110
[1,   321] loss: 0.510
[1,   331] loss: 0.016
[1,   341] loss: 0.194
[1,   351] loss: 0.116
[1,   361] loss: 0.021
[1,   371] loss: 0.136
[1,   381] loss: 0.078
[1,   391] loss: 0.035
[1,   401] loss: 0.021
[1,   411] loss: 0.026
[1,   421] loss: 0.121
[1,   431] 

[1,  3571] loss: 0.028
[1,  3581] loss: 0.006
[1,  3591] loss: 0.206
[1,  3601] loss: 0.004
[1,  3611] loss: 0.027
[1,  3621] loss: 0.003
[1,  3631] loss: 0.018
[1,  3641] loss: 0.001
[1,  3651] loss: 0.343
[1,  3661] loss: 0.009
[1,  3671] loss: 0.003
[1,  3681] loss: 0.738
[1,  3691] loss: 0.036
[1,  3701] loss: 0.470
[1,  3711] loss: 0.511
[1,  3721] loss: 0.006
[1,  3731] loss: 0.004
[1,  3741] loss: 0.003
[1,  3751] loss: 0.009
[1,  3761] loss: 0.004
[1,  3771] loss: 0.014
[1,  3781] loss: 0.002
[1,  3791] loss: 0.034
[1,  3801] loss: 0.153
[1,  3811] loss: 0.002
[1,  3821] loss: 0.038
[1,  3831] loss: 0.038
[1,  3841] loss: 0.002
[1,  3851] loss: 0.003
[1,  3861] loss: 0.051
[1,  3871] loss: 0.284
[1,  3881] loss: 0.006
[1,  3891] loss: 0.242
[1,  3901] loss: 0.043
[1,  3911] loss: 0.006
[1,  3921] loss: 0.008
[1,  3931] loss: 0.002
[1,  3941] loss: 0.009
[1,  3951] loss: 0.145
[1,  3961] loss: 0.008
[1,  3971] loss: 0.004
[1,  3981] loss: 0.029
[1,  3991] loss: 0.060
[1,  4001] 

[1,  7141] loss: 0.001
[1,  7151] loss: 0.006
[1,  7161] loss: 0.029
[1,  7171] loss: 0.001
[1,  7181] loss: 0.002
[1,  7191] loss: 0.003
[1,  7201] loss: 0.007
[1,  7211] loss: 0.001
[1,  7221] loss: 0.033
[1,  7231] loss: 0.011
[1,  7241] loss: 0.092
[1,  7251] loss: 0.024
[1,  7261] loss: 0.063
[1,  7271] loss: 0.001
[1,  7281] loss: 0.017
[1,  7291] loss: 0.159
[1,  7301] loss: 0.000
[1,  7311] loss: 0.004
[1,  7321] loss: 0.075
[1,  7331] loss: 0.000
[1,  7341] loss: 0.002
[1,  7351] loss: 0.070
[1,  7361] loss: 0.001
[1,  7371] loss: 0.002
[1,  7381] loss: 0.035
[1,  7391] loss: 0.002
[1,  7401] loss: 0.043
[1,  7411] loss: 0.003
[1,  7421] loss: 0.013
[1,  7431] loss: 0.772
[1,  7441] loss: 0.003
[1,  7451] loss: 0.360
[1,  7461] loss: 0.015
[1,  7471] loss: 0.001
[1,  7481] loss: 0.008
[1,  7491] loss: 0.002


In [None]:
#make an iterator from test_loader
#Get a batch of testing images
test_iterator = iter(test_loader)
images, labels = test_iterator.next()

images = images.to(device)
labels = labels.to(device)

y_score = resnet(images)
# get predicted class from the class probabilities
_, y_pred = torch.max(y_score, 1)

print('Predicted: ', ' '.join('%5s' % classes[y_pred[j]] for j in range(8)))

# plot y_score - true label (t) vs predicted label (p)
fig2 = plt.figure()
for i in range(8):
    fig2.add_subplot(rows, columns, i+1)
    plt.title('t: ' + classes[labels[i].cpu()] + ' p: ' + classes[y_pred[i].cpu()])
    img = images[i] / 2 + 0.5     # this is to unnormalize the image
    img = torchvision.transforms.ToPILImage()(img.cpu())
    plt.axis('off')
    plt.imshow(img)
plt.show()


In [None]:
y_true = labels.data.cpu().numpy()
y_pred = y_pred.data.cpu().numpy()

In [None]:
accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='macro')
precision = precision_score(y_true, y_pred, average='macro')
recall = recall_score(y_true, y_pred, average='macro')
print('accuracy:', accuracy, ', f1 score:', f1, ', precision:', precision, ', recall:', recall)

## Use ResNet for classification - CIFAR10

### Task 5 - run ResNet on CIFAR10 for classification 

Create a ResNet network and train for classification on the CIFAR10 dataset.

Here are some example images from the CIFAR10 datasets- we have 10 classes:

![cifar10](cifar10.jpg)
source: https://appliedmachinelearning.blog/2018/03/24/achieving-90-accuracy-in-object-recognition-task-on-cifar-10-dataset-with-keras-convolutional-neural-networks/

You can load the CIFAR10 dataset using torchvision in the following way:
```python
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=8,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=8,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
```
You can use this tutorial as a reference for training on CIFAR10 - https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

Remember to define your loss function, optimizer, dataloaders, and your resnet network. 
Then run the training and testing, same as with MNIST.

In [2]:
#-----------------------------------------------------task 5 -----------------------------------------------------
# Task 5: Train and test ResNet on CIFAR10 dataset for classification
# hints: define your resnet network, loss function, optimizer and dataloaders. 
# Then you can run the same training and testing code as above.
# ----------------------------------------------------------------------------------------------------------------

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=8,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=8,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [6]:
import torch.optim as optim

resnet_cifar = ResNet18(in_channels=3)
resnet_cifar = resnet_cifar.to(device)

loss_fun = nn.CrossEntropyLoss()
loss_fun = loss_fun.to(device)

optimizer = optim.SGD(resnet_cifar.parameters(), lr=0.001, momentum=0.9)

NameError: name 'ResNet18' is not defined

In [None]:
epochs = 1
for epoch in range(epochs): 

    # enumerate can be used to output iteration index i, as well as the data 
    for i, (data, labels) in enumerate(train_loader, 0):
        data = data.to(device)
        labels = labels.to(device)
        # clear the gradient
        optimizer.zero_grad()

        #feed the input and acquire the output from network
        outputs = resnet_cifar(data)

        #calculating the predicted and the expected loss
        loss = loss_fun(outputs, labels)

        #compute the gradient
        loss.backward()

        #update the parameters
        optimizer.step()

        # print statistics
        ce_loss = loss.item()
        if i % 10 == 0:
            print('[%d, %5d] loss: %.3f' %
                 (epoch + 1, i + 1, ce_loss))


170500096it [00:50, 6205325.08it/s]                               

In [None]:
#make an iterator from test_loader
#Get a batch of testing images
test_iterator = iter(test_loader)
images, labels = test_iterator.next()
images = images.to(device)
labels = labels.to(device)

y_score = resnet_cifar(images)
# get predicted class from the class probabilities
_, y_pred = torch.max(y_score, 1)

print('Predicted: ', ' '.join('%5s' % classes[y_pred[j]] for j in range(8)))

# plot y_score - true label (t) vs predicted label (p)
fig2 = plt.figure()
for i in range(8):
    fig2.add_subplot(rows, columns, i+1)
    plt.title('t: ' + classes[labels[i].cpu()] + ' p: ' + classes[y_pred[i].cpu()])
    img = images[i] / 2 + 0.5     # this is to unnormalize the image
    img = torchvision.transforms.ToPILImage()(img.cpu())
    plt.axis('off')
    plt.imshow(img)
plt.show()


In [None]:
y_true = labels.data.cpu().numpy()
y_pred = y_pred.data.cpu().numpy()

In [None]:
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score

accuracy = accuracy_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred, average='macro')
precision = precision_score(y_true, y_pred, average='macro')
recall = recall_score(y_true, y_pred, average='macro')
print('accuracy:', accuracy, ', f1 score:', f1, ', precision:', precision, ', recall:', recall)

# Image segmentation with pytorch using U-net

U-net was first developed in 2015 by Ronneberger et al., as a segmentation network for biomedical image analysis.
It has been extremely successful, with 9,000+ citations, and many new methods that have used the U-net architecture since.


The architecture of U-net is based on the idea of using skip connections (i.e. concatenating) at different levels of the network to retain high, and low level features.

Here is the architecture of a U-net:

---

![U-net](unet.png)
Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.

## Two-photon microscopy dataset of cortical axons

In this tutorial we use a dataset of cortical neurons with their corresponding segmentation binary labels.

These images were collected using in-vivo two-photon microscopy from the mouse somatosensory cortex. To generate the 2D images, a max projection was used over the 3D stack. The labels are binary segmentation maps of the axons.

Here we will use 100 [64x64] crops during training and validation. 

These are some example images [256x256] from the original dataset:
![axon_dataset](axon_dataset.png)

Bass, Cher, et al. "Image synthesis with a convolutional capsule generative adversarial network." Medical Imaging with Deep Learning (2019).


In [10]:
#load modules
from __future__ import print_function
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch
from torch.autograd import Variable
from AxonDataset import AxonDataset
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
import time
import torch.nn.functional as F
import torchvision.utils as vutils
import os
import matplotlib.pyplot as plt


In [2]:
# Setting parameters
timestr = time.strftime("%d%m%Y-%H%M")
__location__ = os.path.realpath(
    os.path.join(os.getcwd(), os.path.dirname('__file__')))

path = os.path.join(__location__,'results')
if not os.path.exists(path):
    os.makedirs(path)
    
# Define your batch_size
batch_size = 16


## Creating a dataloader

In this example, a custom dataloader was created, and we import it from `AxonDataset.py`

we create a dataset, and split into a train and validation set with 80%, 20% split

### Task 1

create a list of random indices for the train and validation sets

In [3]:
#First we create a dataloader for our example dataset- two photon microscopy with axons
axon_dataset = AxonDataset(data_name='org64', type='train')

# -----------------------------------------------------task 1----------------------------------------------------------------
# Task 1: create a random list of incides for training and testing with a 80%,20% split

# We need to further split our training dataset into training and validation sets.
# Define the indices
indices = list(range(len(axon_dataset)))  # start with all the indices in training set
split = int(len(indices)*0.2)  # define the split size

# Get indices for train and validation datasets, and split the data
validation_idx = np.random.choice(indices, size=split, replace=False)
train_idx = list(set(indices) - set(validation_idx))
# ----------------------------------------------------------------------------------------------------------------------------

# feed indices into the sampler
train_sampler = SubsetRandomSampler(train_idx)
validation_sampler = SubsetRandomSampler(validation_idx)

# Create a dataloader instance 
train_loader = torch.utils.data.DataLoader(axon_dataset, batch_size = batch_size,
                                           sampler=train_sampler) 
val_loader = torch.utils.data.DataLoader(axon_dataset, batch_size = batch_size,
                                        sampler=validation_sampler) 


## Build a U-net 

We next build our u-net network.

First we define a layer `double_conv` that performs 2 sets of convolution followed by ReLu.

In [4]:
# define U-net
def double_conv(in_channels, out_channels, padding=1):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=padding),
        nn.ReLU(inplace=True)
    )


### Define neural network
We then define our U-net network.

We initialise all the different layers in the network in `__init__`:
1. `self.dconv_down1` is a double convolutional layer
2. `self.maxpool` is a max pooling layer that is used to reduce the size of the input, and decrease the reptive field
3. `self.upsample` is an upsampling layer that is used to increase the size of the input
4. `dropout` is a dropout layer that is applied to regulise the training
5. `dconv_up4` is also a double convolutional layer- note that it takes in additional channels from previous layers (i.e. the skip connections).

skip connection are easily implemented by concatenating the result of a previous convolution with the current input, 

using e.g. `torch.cat([x, conv4], dim=1)`

### Task 2 - implement skip connections
Implement skip connections for conv3, conv2, and conv1.

See conv4 example below:

In [5]:

class UNet(nn.Module):

    def __init__(self):
        super().__init__()
        
        self.dconv_down1 = double_conv(1, 32)
        self.dconv_down2 = double_conv(32, 64)
        self.dconv_down3 = double_conv(64, 128)
        self.dconv_down4 = double_conv(128, 256)
        self.dconv_down5 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear')
        self.dropout = nn.Dropout2d(0.5)
        self.dconv_up4 = double_conv(256 + 512, 256)
        self.dconv_up3 = double_conv(128 + 256, 128)
        self.dconv_up2 = double_conv(128 + 64, 64)
        self.dconv_up1 = double_conv(64 + 32, 32)

        self.conv_last = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        conv1 = self.dropout(conv1)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        conv2 = self.dropout(conv2)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        conv3 = self.dropout(conv3)
        x = self.maxpool(conv3)

        conv4 = self.dconv_down4(x)
        conv4 = self.dropout(conv4)
        x = self.maxpool(conv4)

        conv5 = self.dconv_down5(x)
        conv5 = self.dropout(conv5)

        x = self.upsample(conv5)
        
        # example of skip connection with conv4
        x = torch.cat([x, conv4], dim=1)
        
        x = self.dconv_up4(x)
        x = self.dropout(x)

        x = self.upsample(x)
        
        # --------------------------------------------------- task 2 ----------------------------------------------------------
        # Task 2: implement skip connection with conv3
        x = torch.cat([x, conv3], dim=1)
        # ---------------------------------------------------------------------------------------------------------------------
        x = self.dconv_up3(x)
        x = self.dropout(x)

        x = self.upsample(x)
        
        # --------------------------------------------------- task 2 ----------------------------------------------------------
        # Task 2: implement skip connection with conv2
        x = torch.cat([x, conv2], dim=1)
        # ---------------------------------------------------------------------------------------------------------------------

        x = self.dconv_up2(x)
        x = self.dropout(x)
        x = self.upsample(x)
        
        # --------------------------------------------------- task 2 ----------------------------------------------------------
        # Task 2: implement skip connection with conv1
        x = torch.cat([x, conv1], dim=1)
        # ---------------------------------------------------------------------------------------------------------------------

        x = self.dconv_up1(x)
        x = self.dropout(x)

        out = F.sigmoid(self.conv_last(x))

        return out

we initialise the network with a previously trained network by loading the weights

*for practical reasons training this network from scratch will take too long, and require large computational resources*

In [6]:
# initialise network - and load weights
net = UNet()
net.load_state_dict(torch.load(path+'/'+'model.pt')) #this function loads a pretrained network

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

## Defining an appropriate loss function
We next define our loss function - in this case we use Dice loss, a commonly used loss for image segmentation.

The Dice coefficient can be used as a loss function, and is essentially a measure of overlap between two samples.

Dice is in the range of 0 to 1, where a Dice coefficient of 1 denotes perfect and complete overlap. The Dice coefficient was originally developed for binary data, and can be calculated as:

$Dice = \dfrac{2|A\cap B|}{|A| + |B|}$

where $|A\cap B|$ represents the common elements between sets $A$ and $B$, and $|A|$ represents the number of elements in set $A$ (and likewise for set $B$).

For the case of evaluating a Dice coefficient on predicted segmentation masks, we can approximate  $|A\cap B|$ as the element-wise multiplication between the prediction and target mask, and then sum the resulting matrix.

An **alternative loss** function would be pixel-wise cross entropy loss. It would examine each pixel individually, comparing the class predictions (depth-wise pixel vector) to our one-hot encoded target vector.


In [7]:
# dice loss
def dice_coeff(pred, target):
    """This definition generalize to real valued pred and target vector.
    This should be differentiable.
    pred: tensor with first dimension as batch
    target: tensor with first dimension as batch
    """

    smooth = 1.
    epsilon = 10e-8

    # have to use contiguous since they may from a torch.view op
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()

    A_sum = torch.sum(iflat * iflat)
    B_sum = torch.sum(tflat * tflat)

    dice = (2. * intersection + smooth) / (A_sum + B_sum + smooth)
    dice = dice.mean(dim=0)
    dice = torch.clamp(dice, 0, 1.0-epsilon)

    return  dice

# cross entropy loss
loss_BCE = nn.BCEWithLogitsLoss()


as before, we define the optimiser to train our network - here we use Adam.


In [8]:
#define your optimiser
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=1e-05, betas=(0.5, 0.999))
optimizer.zero_grad()


## Training and evaluating our segmentation network
We next train and evaluate our network 

note that the results are saved to a folder \results - so please check that

In [11]:
epochs=10
save_every=10
all_error = np.zeros(0)
all_error_L1 = np.zeros(0)
all_error_dice = np.zeros(0)
all_dice = np.zeros(0)
all_val_dice = np.zeros(1)
all_val_error = np.zeros(0)

for epoch in range(epochs):

    ##########
    # Train
    ##########
    t0 = time.time()
    for i, (data, label) in enumerate(train_loader):
        
        # setting your network to train will ensure that parameters will be updated during training, 
        # and that dropout will be used
        net.train()
        net.zero_grad()

        target_real = torch.ones(data.size()[0])
        batch_size = data.size()[0]
        pred = net(data)
        
        # dice loss = 1-dice_coeff
        # ----------------------------------------------- task 3 ------------------------------------------------------------
        # Task 3: change loss function here
        err = 1- dice_coeff(pred, label)
        err = loss_BCE(pred, label)
        # -------------------------------------------------------------------------------------------------------------------

        dice_value = dice_coeff(pred, label).item()

        err.backward()
        optimizer.step()
        optimizer.zero_grad()

        time_elapsed = time.time() - t0
        print('[{:d}/{:d}][{:d}/{:d}] Elapsed_time: {:.0f}m{:.0f}s Loss: {:.4f} Dice: {:.4f}'
              .format(epoch, epochs, i, len(train_loader), time_elapsed // 60, time_elapsed % 60,
                      err.item(), dice_value))

        if i % save_every == 0:
            # setting your network to eval mode to remove dropout during testing
            net.eval()

            vutils.save_image(data.data, '%s/epoch_%03d_i_%03d_train_data.png' % (path, epoch, i),
                                  normalize=True)
            vutils.save_image(label.data, '%s/epoch_%03d_i_%03d_train_label.png' % (path, epoch, i),
                                  normalize=True)
            vutils.save_image(pred.data, '%s/epoch_%03d_i_%03d_train_pred.png' % (path, epoch, i),
                                  normalize=True)

            error = err.item()

            all_error = np.append(all_error, error)
            all_dice = np.append(all_dice, dice_value)

    # #############
    # # Validation
    # #############
    mean_error = np.zeros(0)
    mean_dice = np.zeros(0)
    t0 = time.time()
    for i, (data, label) in enumerate(val_loader):

        net.eval()
        batch_size = data.size()[0]

        data, label = Variable(data), Variable(label)
        pred = net(data)
        
        # ----------------------------------------------- task 3 ------------------------------------------------------------
        # Task 3: change loss function here
        err = 1-dice_coeff(pred, label)
        # err = loss_BCE(pred, label)
        # -------------------------------------------------------------------------------------------------------------------

        # compare generated image to data-  metric
        dice_value = dice_coeff(pred, label).item()

        if i == 0:
            vutils.save_image(data.data, '%s/epoch_%03d_i_%03d_val_data.png' % (path, epoch, i),
                              normalize=True)
            vutils.save_image(label.data, '%s/epoch_%03d_i_%03d_val_label.png' % (path, epoch, i),
                              normalize=True)
            vutils.save_image(pred.data, '%s/epoch_%03d_i_%03d_val_pred.png' % (path, epoch, i),
                              normalize=True)

        error = err.item()
        mean_error = np.append(mean_error, error)
        mean_dice = np.append(mean_dice, dice_value)

    all_val_error = np.append(all_val_error, np.mean(mean_error))
    all_val_dice = np.append(all_val_dice, np.mean(mean_dice))

    time_elapsed = time.time() - t0

    print('Elapsed_time: {:.0f}m{:.0f}s Val dice: {:.4f}'
          .format(time_elapsed // 60, time_elapsed % 60, mean_dice.mean()))
    
    
    num_it_per_epoch_train = ((train_loader.dataset.x_data.shape[0] * (1 - 0.2)) // (
            save_every * batch_size)) + 1
    epochs_train = np.arange(1,all_error.size+1) / num_it_per_epoch_train
    epochs_val = np.arange(0,all_val_dice.size)

    plt.figure()
    plt.plot(epochs_val, all_val_dice, label='dice_val')
    plt.xlabel('epochs')
    plt.legend()
    plt.title('Dice score')
    plt.savefig(path + '/dice_val.png')
    plt.close()



  "See the documentation of nn.Upsample for details.".format(mode))


[0/10][0/20] Elapsed_time: 0m1s Loss: 0.6918 Dice: 0.3308
[0/10][1/20] Elapsed_time: 0m1s Loss: 0.6926 Dice: 0.3762
[0/10][2/20] Elapsed_time: 0m2s Loss: 0.6926 Dice: 0.4082
[0/10][3/20] Elapsed_time: 0m3s Loss: 0.6921 Dice: 0.3713
[0/10][4/20] Elapsed_time: 0m3s Loss: 0.6929 Dice: 0.4716
[0/10][5/20] Elapsed_time: 0m4s Loss: 0.6918 Dice: 0.3139
[0/10][6/20] Elapsed_time: 0m5s Loss: 0.6902 Dice: 0.3906
[0/10][7/20] Elapsed_time: 0m5s Loss: 0.6906 Dice: 0.2503
[0/10][8/20] Elapsed_time: 0m6s Loss: 0.6891 Dice: 0.4556
[0/10][9/20] Elapsed_time: 0m7s Loss: 0.6932 Dice: 0.4262
[0/10][10/20] Elapsed_time: 0m7s Loss: 0.6914 Dice: 0.3525
[0/10][11/20] Elapsed_time: 0m8s Loss: 0.6917 Dice: 0.3306
[0/10][12/20] Elapsed_time: 0m9s Loss: 0.6912 Dice: 0.3525
[0/10][13/20] Elapsed_time: 0m9s Loss: 0.6923 Dice: 0.3962
[0/10][14/20] Elapsed_time: 0m10s Loss: 0.6909 Dice: 0.4460
[0/10][15/20] Elapsed_time: 0m11s Loss: 0.6908 Dice: 0.4187
[0/10][16/20] Elapsed_time: 0m11s Loss: 0.6912 Dice: 0.4453
[0/1

[6/10][16/20] Elapsed_time: 0m13s Loss: 0.6902 Dice: 0.4482
[6/10][17/20] Elapsed_time: 0m14s Loss: 0.6915 Dice: 0.3506
[6/10][18/20] Elapsed_time: 0m14s Loss: 0.6904 Dice: 0.4503
[6/10][19/20] Elapsed_time: 0m15s Loss: 0.6901 Dice: 0.4841
Elapsed_time: 0m1s Val dice: 0.5425
[7/10][0/20] Elapsed_time: 0m1s Loss: 0.6905 Dice: 0.3740
[7/10][1/20] Elapsed_time: 0m1s Loss: 0.6903 Dice: 0.4503
[7/10][2/20] Elapsed_time: 0m2s Loss: 0.6910 Dice: 0.4864
[7/10][3/20] Elapsed_time: 0m3s Loss: 0.6912 Dice: 0.3169
[7/10][4/20] Elapsed_time: 0m4s Loss: 0.6890 Dice: 0.4592
[7/10][5/20] Elapsed_time: 0m4s Loss: 0.6898 Dice: 0.4876
[7/10][6/20] Elapsed_time: 0m5s Loss: 0.6918 Dice: 0.3694
[7/10][7/20] Elapsed_time: 0m6s Loss: 0.6884 Dice: 0.4061
[7/10][8/20] Elapsed_time: 0m7s Loss: 0.6930 Dice: 0.4311
[7/10][9/20] Elapsed_time: 0m7s Loss: 0.6907 Dice: 0.4738
[7/10][10/20] Elapsed_time: 0m8s Loss: 0.6900 Dice: 0.5504
[7/10][11/20] Elapsed_time: 0m9s Loss: 0.6910 Dice: 0.3993
[7/10][12/20] Elapsed_time

## Results 
the results are saved to a folder \results - so please check that:

The results are saved per epoch for both training and validation, and are saved as the 
1. real data, 
2. binary labels, 
3. predicted labels. 

In this example since we trained on a small sample of the data (100 crops) the results are far from optimal, and are likely to overfit to the data.

### Task 3

1. Change the dice loss to a cross entropy loss in the code - is dice loss or cross entropy loss better?
2. run the training with dropout - what's the effect?

**Note down your dice validation scores for each experiment, then change**
