# Image segmentation with PyTorch using U-net

U-net was developed in 2015 by Ronneberger et al. as a segmentation network for biomedical image analysis.
It has been extremely successful, with 9,000+ citations, and many subsequent methods have built on the U-net architecture.


The U-net architecture is built around skip connections: feature maps from the contracting (downsampling) path are concatenated with those of the expanding (upsampling) path at each resolution level, so that the network retains both high-level and low-level features.

Here is the architecture of a U-net:

---

![U-net](unet.png)
Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.

## Two-photon microscopy dataset of cortical axons

In this tutorial we use a dataset of cortical neurons with their corresponding binary segmentation labels.

These images were collected using in vivo two-photon microscopy of the mouse somatosensory cortex. Each 2D image is a maximum intensity projection of a 3D stack. The labels are binary segmentation maps of the axons.
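A maximum intensity projection keeps, for every (y, x) position, the brightest voxel along the depth axis. The minimal NumPy sketch below assumes the stack is stored as a `(depth, height, width)` array; the random stack and the helper name `max_projection` are hypothetical stand-ins, not part of the dataset code.

```python
import numpy as np

def max_projection(stack):
    """Collapse a (depth, height, width) stack into a 2D image by taking the max over depth."""
    return stack.max(axis=0)

# Hypothetical 3D stack: 5 optical sections of 256x256 pixels
rng = np.random.default_rng(0)
stack = rng.random((5, 256, 256)).astype(np.float32)

projection = max_projection(stack)
print(projection.shape)  # (256, 256)
```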

Here we use 100 crops of 64x64 pixels for training and validation.

These are some example images (256x256 pixels) from the original dataset:
![axon_dataset](axon_dataset.png)

Bass, Cher, et al. "Image synthesis with a convolutional capsule generative adversarial network." Medical Imaging with Deep Learning (2019).


```python
# Load modules
import os
import time
import numpy as np
import matplotlib.pyplot as plt  # needed for the validation Dice plot
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch
from torch.autograd import Variable
from load_memmap import *   # local helper module for this tutorial
from AxonDataset import *   # local module providing the AxonDataset class
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
import torchvision.utils as vutils
```

```python
# Setting parameters
timestr = time.strftime("%d%m%Y-%H%M")

# Base directory: the current working directory (the notebook's folder)
__location__ = os.path.realpath(os.getcwd())

# Folder for saved images and plots; the pretrained model.pt is also loaded from here
path = os.path.join(__location__, 'results')
os.makedirs(path, exist_ok=True)

# Define your batch_size
batch_size = 16
```


## Creating a dataloader

In this example we use a custom dataset class, `AxonDataset`, imported from `AxonDataset.py`, and wrap it in PyTorch `DataLoader`s.

We create a dataset and split it into training and validation sets (80%/20%).

### Task 1

Create a list of random indices for the training and validation sets.

```python
# First we create a dataset for our example data: two-photon microscopy images of axons
axon_dataset = AxonDataset(data_name='org64', type='train')

# ----------------------------------------- task 1 -----------------------------------------
# Task 1: create a random list of indices for training and validation with an 80%/20% split

# We need to further split our training dataset into training and validation sets.
# Define the indices
indices = list(range(len(axon_dataset)))  # start with all the indices in the training set
split = int(len(indices) * 0.2)  # size of the validation set (20%)

# Draw the validation indices at random (without replacement); the rest are used for training
validation_idx = np.random.choice(indices, size=split, replace=False)
train_idx = list(set(indices) - set(validation_idx))
# -------------------------------------------------------------------------------------------

# Feed the indices into the samplers
train_sampler = SubsetRandomSampler(train_idx)
validation_sampler = SubsetRandomSampler(validation_idx)

# Create the dataloaders: both wrap the same dataset but draw from disjoint index sets
train_loader = torch.utils.data.DataLoader(axon_dataset, batch_size=batch_size,
                                           sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(axon_dataset, batch_size=batch_size,
                                         sampler=validation_sampler)
```
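An alternative to building the index lists by hand is `torch.utils.data.random_split`, which returns two `Subset` objects; passing a seeded `torch.Generator` makes the split reproducible. This sketch uses a hypothetical `TensorDataset` of random tensors in place of `AxonDataset`, and the `demo_` names are chosen so they do not overwrite the notebook's loaders. Unlike the notebook, the validation loader here does not shuffle.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Hypothetical stand-in for axon_dataset: 100 single-channel 64x64 crops with binary masks
images = torch.rand(100, 1, 64, 64)
masks = (torch.rand(100, 1, 64, 64) > 0.5).float()
dataset = TensorDataset(images, masks)

n_val = int(len(dataset) * 0.2)   # 20% for validation
n_train = len(dataset) - n_val    # remaining 80% for training

# A seeded generator makes the split identical across runs
generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(dataset, [n_train, n_val], generator=generator)

demo_train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
demo_val_loader = DataLoader(val_set, batch_size=16, shuffle=False)

print(len(train_set), len(val_set))  # 80 20
```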


## Build a U-net

Next, we build our U-net.

First we define a block, `double_conv`, that performs two 3x3 convolutions, each followed by a ReLU.

```python
# Define U-net building block: two 3x3 convolutions, each followed by a ReLU
def double_conv(in_channels, out_channels, padding=1):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=padding),
        nn.ReLU(inplace=True)
    )
```
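A quick shape check shows why `padding=1` is used: a 3x3 kernel with one pixel of zero padding keeps height and width unchanged, so only the number of channels changes. With `padding=0` each convolution trims one pixel from every border, as in the original U-net paper, which used unpadded convolutions. The block is repeated here so the example runs on its own.

```python
import torch
import torch.nn as nn

def double_conv(in_channels, out_channels, padding=1):
    # Same block as above: two 3x3 convolutions, each followed by a ReLU
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=padding),
        nn.ReLU(inplace=True)
    )

block = double_conv(1, 32)
x = torch.randn(2, 1, 64, 64)      # batch of 2 single-channel 64x64 crops
print(block(x).shape)              # torch.Size([2, 32, 64, 64])

# Without padding, each 3x3 convolution removes 2 pixels from height and width
print(double_conv(1, 32, padding=0)(x).shape)  # torch.Size([2, 32, 60, 60])
```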


### Define neural network
We then define our U-net network.

We initialise all the different layers of the network in `__init__`:
1. `self.dconv_down1` is a double convolutional layer.
2. `self.maxpool` is a max pooling layer that halves the height and width of its input, which increases the receptive field of the layers that follow.
3. `self.upsample` is a bilinear upsampling layer that doubles the height and width of its input.
4. `self.dropout` is a 2D dropout layer (p = 0.5) applied to regularise training.
5. `self.dconv_up4` is also a double convolutional layer; note that it takes in additional channels from earlier layers (i.e. the skip connections), hence its 256 + 512 input channels.
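The behaviour of the pooling, upsampling and dropout layers listed above can be checked on a hypothetical feature map of ones. `nn.Dropout2d` zeroes entire channels rather than individual pixels and, in training mode, rescales the surviving channels by 1/(1 - p); in eval mode it does nothing.

```python
import torch
import torch.nn as nn

x = torch.ones(1, 4, 64, 64)  # hypothetical feature map: 4 channels of 64x64

maxpool = nn.MaxPool2d(2)
upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
dropout = nn.Dropout2d(0.5)

print(maxpool(x).shape)   # torch.Size([1, 4, 32, 32])
print(upsample(x).shape)  # torch.Size([1, 4, 128, 128])

# Training mode: each channel is either dropped (all 0) or kept and scaled by 1/(1-0.5) = 2
torch.manual_seed(0)
dropout.train()
y = dropout(x)
channel_max = y.amax(dim=(2, 3)).flatten()
print(channel_max)  # every entry is 0. or 2.

# Eval mode: dropout is the identity
dropout.eval()
print(torch.equal(dropout(x), x))  # True
```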

Skip connections are easily implemented by concatenating the output of an earlier convolution (from the contracting path) with the current input along the channel dimension, e.g. `torch.cat([x, conv4], dim=1)`.
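For `(batch, channels, height, width)` tensors, `dim=1` is the channel dimension, so the concatenation stacks feature maps without changing their spatial size. The hypothetical tensors below have the shapes of the deepest skip connection for a batch of 16 crops of 64x64, which is why `dconv_up4` expects 256 + 512 input channels.

```python
import torch

# Shapes at the deepest skip connection for 64x64 inputs (64 -> 32 -> 16 -> 8 -> 4 -> upsampled to 8)
x = torch.randn(16, 512, 8, 8)      # upsampled output of dconv_down5
conv4 = torch.randn(16, 256, 8, 8)  # encoder features saved before the last max pooling

skip = torch.cat([x, conv4], dim=1)  # concatenate along the channel dimension
print(skip.shape)  # torch.Size([16, 768, 8, 8])
```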

### Task 2 - implement skip connections
Implement skip connections for conv3, conv2, and conv1.

See the conv4 example below:

```python
class UNet(nn.Module):

    def __init__(self):
        super().__init__()

        # Contracting path: the number of channels doubles at each level
        self.dconv_down1 = double_conv(1, 32)
        self.dconv_down2 = double_conv(32, 64)
        self.dconv_down3 = double_conv(64, 128)
        self.dconv_down4 = double_conv(128, 256)

        # Bottleneck
        self.dconv_down5 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)
        # align_corners=False is the default behaviour; setting it explicitly silences the Upsample warning
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dropout = nn.Dropout2d(0.5)

        # Expanding path: input channels = skip-connection channels + upsampled channels
        self.dconv_up4 = double_conv(256 + 512, 256)
        self.dconv_up3 = double_conv(128 + 256, 128)
        self.dconv_up2 = double_conv(128 + 64, 64)
        self.dconv_up1 = double_conv(64 + 32, 32)

        # 1x1 convolution to a single output channel
        self.conv_last = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        conv1 = self.dropout(conv1)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        conv2 = self.dropout(conv2)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        conv3 = self.dropout(conv3)
        x = self.maxpool(conv3)

        conv4 = self.dconv_down4(x)
        conv4 = self.dropout(conv4)
        x = self.maxpool(conv4)

        conv5 = self.dconv_down5(x)
        conv5 = self.dropout(conv5)

        x = self.upsample(conv5)

        # example of skip connection with conv4
        x = torch.cat([x, conv4], dim=1)

        x = self.dconv_up4(x)
        x = self.dropout(x)

        x = self.upsample(x)

        # ------------------------------------------- task 2 -------------------------------------------
        # Task 2: implement skip connection with conv3
        x = torch.cat([x, conv3], dim=1)
        # ----------------------------------------------------------------------------------------------
        x = self.dconv_up3(x)
        x = self.dropout(x)

        x = self.upsample(x)

        # ------------------------------------------- task 2 -------------------------------------------
        # Task 2: implement skip connection with conv2
        x = torch.cat([x, conv2], dim=1)
        # ----------------------------------------------------------------------------------------------

        x = self.dconv_up2(x)
        x = self.dropout(x)
        x = self.upsample(x)

        # ------------------------------------------- task 2 -------------------------------------------
        # Task 2: implement skip connection with conv1
        x = torch.cat([x, conv1], dim=1)
        # ----------------------------------------------------------------------------------------------

        x = self.dconv_up1(x)
        x = self.dropout(x)

        # Sigmoid maps each pixel to a foreground probability in [0, 1]
        out = torch.sigmoid(self.conv_last(x))

        return out
```
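Because the encoder halves the spatial size four times and the decoder doubles it four times, each `torch.cat` only works if height and width are divisible by 2^4 = 16; otherwise max pooling rounds down and the upsampled tensor ends up one pixel short of its skip connection. The 64x64 crops and the 256x256 images both satisfy this. The helper `skip_shapes_match` below is a hypothetical check that traces a square input through pooling and upsampling layers configured like those in `UNet`.

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(2)
up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

def skip_shapes_match(size, levels=4):
    """Return True if every skip connection lines up for a square input of side `size`."""
    x = torch.zeros(1, 1, size, size)
    skips = []
    for _ in range(levels):
        skips.append(x.shape[-2:])  # spatial size saved for the skip connection
        x = pool(x)
    for skip in reversed(skips):
        x = up(x)
        if x.shape[-2:] != skip:    # torch.cat along dim=1 would fail here
            return False
    return True

print(skip_shapes_match(64))  # True
print(skip_shapes_match(72))  # False: 72 -> 36 -> 18 -> 9 -> 4, and 4 upsampled is 8, not 9
```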

We initialise the network with the weights of a previously trained model.

*We do this for practical reasons: training this network from scratch would take too long and require large computational resources.*

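```python
# Initialise the network and load the pretrained weights from results/model.pt
net = UNet()
# map_location='cpu' lets weights saved on a GPU load on a CPU-only machine
net.load_state_dict(torch.load(os.path.join(path, 'model.pt'), map_location='cpu'))
# Output: <All keys matched successfully>
```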
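The loading step assumes `model.pt` holds a state dict (a mapping from parameter names to tensors), presumably written with `torch.save(net.state_dict(), ...)`; the saving code is not part of this notebook, so that is an assumption. The round trip below uses a hypothetical one-layer model in a temporary directory. Loading only works when the architecture matches: by default (`strict=True`) `load_state_dict` raises an error on missing or unexpected keys.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical tiny model standing in for UNet
model = nn.Conv2d(1, 1, kernel_size=3, padding=1)

with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, 'model.pt')
    torch.save(model.state_dict(), ckpt)  # save only the parameters, not the class

    restored = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # same architecture, fresh weights
    result = restored.load_state_dict(torch.load(ckpt, map_location='cpu'))
    print(result)  # <All keys matched successfully>

same = all(torch.equal(a, b) for a, b in
           zip(model.state_dict().values(), restored.state_dict().values()))
print(same)  # True
```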

## Defining an appropriate loss function
We next define our loss function - in this case we use Dice loss, a commonly used loss for image segmentation.

The Dice coefficient can be used as a loss function; it measures the overlap between two samples.

The Dice coefficient ranges from 0 to 1, where 1 denotes perfect and complete overlap. It was originally developed for binary data, and can be calculated as:

$Dice = \dfrac{2|A\cap B|}{|A| + |B|}$

where $|A\cap B|$ is the number of elements common to sets $A$ and $B$, and $|A|$ is the number of elements in set $A$ (and likewise for set $B$).
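A worked example of the formula on two hypothetical binary masks of 8 pixels: 3 pixels are foreground in both masks and each mask has 4 foreground pixels, so $Dice = 2 \cdot 3 / (4 + 4) = 0.75$.

```python
import numpy as np

# Two hypothetical binary masks of 8 pixels each (1 = axon, 0 = background)
A = np.array([1, 1, 1, 0, 0, 0, 1, 0])  # prediction
B = np.array([1, 1, 0, 0, 0, 1, 1, 0])  # ground truth

intersection = np.sum(A & B)  # pixels that are 1 in both masks
dice = 2 * intersection / (A.sum() + B.sum())
print(intersection, A.sum(), B.sum(), dice)  # 3 4 4 0.75
```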

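When evaluating the Dice coefficient on predicted segmentation masks, we can approximate $|A\cap B|$ by multiplying the prediction and target masks element-wise and summing the result. This soft version works with real-valued predictions and is differentiable, so $1 - Dice$ can be minimised as a loss. In the `dice_coeff` function below, $|A|$ and $|B|$ are computed as sums of squared values, and a smoothing term of 1 is added to the numerator and denominator to avoid division by zero.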
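A minimal sketch of the soft Dice on hypothetical probabilities for 4 pixels. For simplicity it uses plain sums for $|A|$ and $|B|$, as in the formula above, rather than the sums of squares and smoothing term of `dice_coeff`. The gradients show why $1 - Dice$ works as a loss: they are negative for the axon pixels, so a gradient descent step raises those probabilities, and positive for the background pixels.

```python
import torch

# Hypothetical predicted probabilities and binary target for 4 pixels
pred = torch.tensor([0.9, 0.8, 0.2, 0.1], requires_grad=True)
target = torch.tensor([1.0, 1.0, 0.0, 0.0])

intersection = (pred * target).sum()                   # soft intersection: 0.9 + 0.8 = 1.7
dice = 2 * intersection / (pred.sum() + target.sum())  # 3.4 / (2.0 + 2.0) = 0.85
loss = 1 - dice
loss.backward()

print(round(dice.item(), 4))  # 0.85
print(pred.grad)  # negative for the two axon pixels, positive for the two background pixels
```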

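An **alternative loss** function is pixel-wise cross-entropy loss, which examines each pixel individually and compares the class predictions (the depth-wise pixel vector) with the one-hot encoded target vector. With a single output channel, as here, this reduces to binary cross-entropy between the predicted probability and the binary label at each pixel.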
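Pixel-wise binary cross-entropy averages $-[y \log p + (1 - y) \log(1 - p)]$ over all pixels. The sketch below, on hypothetical logits and labels, checks this against PyTorch: `nn.BCELoss` expects probabilities, whereas `nn.BCEWithLogitsLoss` expects raw logits and applies the sigmoid itself. A network whose `forward` already ends in a sigmoid therefore pairs with `nn.BCELoss`.

```python
import torch
import torch.nn as nn

# Hypothetical raw network outputs (logits) and binary labels for 4 pixels
logits = torch.tensor([2.0, 0.5, -1.0, -3.0])
target = torch.tensor([1.0, 0.0, 0.0, 1.0])
prob = torch.sigmoid(logits)

# Pixel-wise binary cross-entropy, averaged over pixels
manual = -(target * torch.log(prob) + (1 - target) * torch.log(1 - prob)).mean()

bce = nn.BCELoss()(prob, target)                     # expects probabilities
bce_logits = nn.BCEWithLogitsLoss()(logits, target)  # expects logits, applies the sigmoid itself

print(torch.allclose(manual, bce, atol=1e-6), torch.allclose(bce, bce_logits, atol=1e-6))  # True True
```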


```python
# Dice coefficient (soft, differentiable)
def dice_coeff(pred, target):
    """Soft Dice coefficient that generalizes to real-valued pred and target tensors.
    It is differentiable, so 1 - dice_coeff can be used as a loss.
    pred: tensor with first dimension as batch
    target: tensor with first dimension as batch
    All pixels in the batch are pooled, giving a single score per batch.
    """

    smooth = 1.
    epsilon = 1e-7

    # contiguous() is needed because the inputs may come from a torch.view op
    iflat = pred.contiguous().view(-1)
    tflat = target.contiguous().view(-1)
    intersection = (iflat * tflat).sum()

    # |A| and |B| approximated by sums of squared values
    A_sum = torch.sum(iflat * iflat)
    B_sum = torch.sum(tflat * tflat)

    dice = (2. * intersection + smooth) / (A_sum + B_sum + smooth)
    dice = torch.clamp(dice, 0, 1.0 - epsilon)

    return dice

# Binary cross-entropy loss. UNet.forward already applies a sigmoid, so use BCELoss:
# BCEWithLogitsLoss expects raw logits and would apply a second sigmoid.
loss_BCE = nn.BCELoss()
```
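`dice_coeff` pools every pixel in the batch into one score, so large or easy samples dominate. A per-sample variant that averages the individual scores is sketched below; `dice_coeff_per_sample` is a hypothetical alternative, not what the notebook uses. For the two toy samples, the per-sample mean is (1 + 1/3) / 2 = 2/3, while the pooled formula gives (2·2 + 1) / (3 + 3 + 1) = 5/7 ≈ 0.714.

```python
import torch

def dice_coeff_per_sample(pred, target, smooth=1.0):
    """Hypothetical variant: one soft Dice score per sample, then the batch mean.
    pred, target: tensors of shape (batch, ...)."""
    iflat = pred.reshape(pred.size(0), -1)
    tflat = target.reshape(target.size(0), -1)
    intersection = (iflat * tflat).sum(dim=1)
    A_sum = (iflat * iflat).sum(dim=1)
    B_sum = (tflat * tflat).sum(dim=1)
    dice = (2. * intersection + smooth) / (A_sum + B_sum + smooth)
    return dice.mean()

pred = torch.tensor([[1., 1., 0., 0.], [1., 0., 0., 0.]])
target = torch.tensor([[1., 1., 0., 0.], [0., 0., 0., 1.]])
print(dice_coeff_per_sample(pred, target))  # tensor(0.6667): mean of 1.0 and 1/3
```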


As before, we define the optimiser to train our network; here we use Adam.


```python
# Define your optimiser: Adam over all trainable parameters
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=1e-05, betas=(0.5, 0.999))
optimizer.zero_grad()
```


## Training and evaluating our segmentation network
Next, we train and evaluate our network.

Note that the results are saved to the `results` folder, so please check there.

```python
epochs = 10
save_every = 10  # save example images and log the training loss every `save_every` iterations
all_error = np.zeros(0)
all_dice = np.zeros(0)
all_val_dice = np.zeros(1)  # starts with a 0 placeholder, so the score after epoch k is plotted at x = k
all_val_error = np.zeros(0)

for epoch in range(epochs):

    ##########
    # Train
    ##########
    # Train mode enables dropout (eval mode, used for validation, disables it)
    net.train()
    t0 = time.time()
    for i, (data, label) in enumerate(train_loader):
        optimizer.zero_grad()
        pred = net(data)

        # ------------------------------------------- task 3 -------------------------------------------
        # Task 3: change loss function here (dice loss = 1 - dice_coeff)
        err = 1 - dice_coeff(pred, label)
        # err = loss_BCE(pred, label)
        # ----------------------------------------------------------------------------------------------

        dice_value = dice_coeff(pred, label).item()

        err.backward()
        optimizer.step()

        time_elapsed = time.time() - t0
        print('[{:d}/{:d}][{:d}/{:d}] Elapsed_time: {:.0f}m{:.0f}s Loss: {:.4f} Dice: {:.4f}'
              .format(epoch, epochs, i, len(train_loader), time_elapsed // 60, time_elapsed % 60,
                      err.item(), dice_value))

        if i % save_every == 0:
            # Save input, label and prediction of the current training batch
            vutils.save_image(data, '%s/epoch_%03d_i_%03d_train_data.png' % (path, epoch, i),
                              normalize=True)
            vutils.save_image(label, '%s/epoch_%03d_i_%03d_train_label.png' % (path, epoch, i),
                              normalize=True)
            vutils.save_image(pred.detach(), '%s/epoch_%03d_i_%03d_train_pred.png' % (path, epoch, i),
                              normalize=True)

            all_error = np.append(all_error, err.item())
            all_dice = np.append(all_dice, dice_value)

    ##############
    # Validation
    ##############
    mean_error = np.zeros(0)
    mean_dice = np.zeros(0)
    t0 = time.time()
    # Eval mode disables dropout; no_grad skips gradient tracking during evaluation
    net.eval()
    with torch.no_grad():
        for i, (data, label) in enumerate(val_loader):
            pred = net(data)

            # --------------------------------------- task 3 ---------------------------------------
            # Task 3: change loss function here
            err = 1 - dice_coeff(pred, label)
            # err = loss_BCE(pred, label)
            # --------------------------------------------------------------------------------------

            # Dice between prediction and label is the evaluation metric
            dice_value = dice_coeff(pred, label).item()

            if i == 0:
                vutils.save_image(data, '%s/epoch_%03d_i_%03d_val_data.png' % (path, epoch, i),
                                  normalize=True)
                vutils.save_image(label, '%s/epoch_%03d_i_%03d_val_label.png' % (path, epoch, i),
                                  normalize=True)
                vutils.save_image(pred, '%s/epoch_%03d_i_%03d_val_pred.png' % (path, epoch, i),
                                  normalize=True)

            mean_error = np.append(mean_error, err.item())
            mean_dice = np.append(mean_dice, dice_value)

    all_val_error = np.append(all_val_error, np.mean(mean_error))
    all_val_dice = np.append(all_val_dice, np.mean(mean_dice))

    time_elapsed = time.time() - t0
    print('Elapsed_time: {:.0f}m{:.0f}s Val dice: {:.4f}'
          .format(time_elapsed // 60, time_elapsed % 60, mean_dice.mean()))

    # x-axis values for plotting (epochs_train is only needed to plot the logged training loss)
    num_it_per_epoch_train = ((train_loader.dataset.x_data.shape[0] * (1 - 0.2)) // (
            save_every * batch_size)) + 1
    epochs_train = np.arange(1, all_error.size + 1) / num_it_per_epoch_train
    epochs_val = np.arange(0, all_val_dice.size)

    # Plot the validation Dice score per epoch and save it to the results folder
    plt.figure()
    plt.plot(epochs_val, all_val_dice, label='dice_val')
    plt.xlabel('epochs')
    plt.legend()
    plt.title('Dice score')
    plt.savefig(os.path.join(path, 'dice_val.png'))
    plt.close()
```

```text
[0/10][0/20] Elapsed_time: 0m2s Loss: 0.3897 Dice: 0.6103
[0/10][1/20] Elapsed_time: 0m3s Loss: 0.3224 Dice: 0.6776
[0/10][2/20] Elapsed_time: 0m5s Loss: 0.3065 Dice: 0.6935
[0/10][3/20] Elapsed_time: 0m6s Loss: 0.3955 Dice: 0.6045
[0/10][4/20] Elapsed_time: 0m8s Loss: 0.3504 Dice: 0.6496
[0/10][5/20] Elapsed_time: 0m9s Loss: 0.3148 Dice: 0.6852
[0/10][6/20] Elapsed_time: 0m10s Loss: 0.3657 Dice: 0.6343
[0/10][7/20] Elapsed_time: 0m12s Loss: 0.3506 Dice: 0.6494
[0/10][8/20] Elapsed_time: 0m13s Loss: 0.3435 Dice: 0.6565
[0/10][9/20] Elapsed_time: 0m14s Loss: 0.3312 Dice: 0.6688
[0/10][10/20] Elapsed_time: 0m16s Loss: 0.3087 Dice: 0.6913
[0/10][11/20] Elapsed_time: 0m17s Loss: 0.3534 Dice: 0.6466
[0/10][12/20] Elapsed_time: 0m19s Loss: 0.3108 Dice: 0.6892
[0/10][13/20] Elapsed_time: 0m20s Loss: 0.3470 Dice: 0.6530
[0/10][14/20] Elapsed_time: 0m22s Loss: 0.3565 Dice: 0.6435
[0/10][15/20] Elapsed_time: 0m23s Loss: 0.3281 Dice: 0.6719
[0/10][16/20] Elapsed_time: 0m25s Loss: 0.3561 Dice: 0.6
```
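The validation score in the training loop is a soft Dice computed on the sigmoid probabilities. For reporting, a common alternative is to threshold the probabilities and compute Dice on the resulting binary mask. The helper `hard_dice` and its 0.5 threshold are assumptions for illustration; the notebook does not use them.

```python
import torch

def hard_dice(prob, target, threshold=0.5):
    """Hypothetical reporting helper: Dice on binarised predictions (not used for training)."""
    pred_mask = (prob > threshold).float()
    intersection = (pred_mask * target).sum()
    total = pred_mask.sum() + target.sum()
    if total == 0:
        return 1.0  # both masks empty: treat as perfect agreement
    return (2 * intersection / total).item()

prob = torch.tensor([0.9, 0.6, 0.4, 0.1])
target = torch.tensor([1.0, 0.0, 1.0, 0.0])
print(hard_dice(prob, target))  # 0.5: mask [1, 1, 0, 0] overlaps the target in one pixel
```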

## Results
The results are saved to the `results` folder, so please check there.

For both training and validation, the results are saved each epoch as images of the
1. real input data,
2. binary labels,
3. predicted labels.

Because we trained on a small sample of the data (100 crops), the results are far from optimal and the network is likely to overfit.

### Task 3

1. Change the Dice loss to a cross-entropy loss in the code (the `task 3` lines in the training and validation loops). Is Dice loss or cross-entropy loss better?
2. Run the training with and without dropout. What's the effect?

**Note down your Dice validation score for each experiment before making the next change.**
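For the dropout experiment, one way to switch dropout off without editing the class is to change the `p` attribute of the existing `nn.Dropout2d` modules before training. `UNet` reuses a single `self.dropout` module, so one change affects every place it is applied. The helper `set_dropout` is hypothetical, and the small `nn.Sequential` model stands in for the notebook's `net` (with which you would call `set_dropout(net, 0.0)`).

```python
import torch.nn as nn

def set_dropout(model, p):
    """Set the drop probability of every Dropout2d layer in `model` (p=0 disables dropout)."""
    for module in model.modules():
        if isinstance(module, nn.Dropout2d):
            module.p = p

# Hypothetical small model standing in for the U-net
model = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.Dropout2d(0.5), nn.ReLU())
set_dropout(model, 0.0)
print([m.p for m in model.modules() if isinstance(m, nn.Dropout2d)])  # [0.0]
```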
