# Uncertainty estimation for images

Spend the next 15 minutes reading this paper:
https://arxiv.org/abs/1703.04977

Now, let's implement it. 



In [None]:
import torch

# Select the compute device: the first CUDA GPU when one is available,
# otherwise the CPU. The device string must be written without whitespace
# ("cuda:0"); "cuda: 0" is rejected by torch.device as an invalid device string.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## The data

### Two-photon microscopy dataset of cortical axons

As in lecture 3, in this tutorial we use a dataset of cortical neurons with their corresponding binary segmentation labels.

These images were collected using in-vivo two-photon microscopy from the mouse somatosensory cortex. To generate the 2D images, a max projection was used over the 3D stack. The labels are binary segmentation maps of the axons.

Here we will use 100 [64x64] crops during training and validation. 

Bass, Cher, et al. "Image synthesis with a convolutional capsule generative adversarial network." Medical Imaging with Deep Learning (2019).


In [None]:
# Download the Week 9 data archive and unpack it into the current working
# directory. The lines starting with "!" are IPython shell escapes, so this
# cell only runs inside a notebook environment.
file_download_link = "https://github.com/KCL-BMEIS/AdvancedMachineLearningCourse/blob/main/Week9_Uncertainty/Data/Week9_data.zip?raw=true"
# -O names the local file; --no-check-certificate skips TLS certificate validation.
!wget -O Week9_data.zip --no-check-certificate "$file_download_link"
# -q suppresses unzip's per-file output.
!unzip -q Week9_data.zip

In [None]:
#load modules
from __future__ import print_function
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch
from torch.autograd import Variable
from Data.AxonDataset import AxonDataset
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
import time
import torch.nn.functional as F
import torchvision.utils as vutils
import os
import matplotlib.pyplot as plt


In [None]:

# Setting parameters

# Timestamp of this run, formatted as DDMMYYYY-HHMM.
timestr = time.strftime("%d%m%Y-%H%M")

# NOTE: '__file__' is a string literal here (notebooks do not define
# __file__), so os.path.dirname() returns '' and __location__ resolves to
# the real path of the current working directory.
__location__ = os.path.realpath(
    os.path.join(os.getcwd(), os.path.dirname('__file__')))

print(__location__)

# Output directory for saved images and plots. exist_ok=True makes the
# creation idempotent and avoids the check-then-create race of testing
# os.path.exists() first.
path = os.path.join(__location__, 'results')
os.makedirs(path, exist_ok=True)

# Define your batch_size
batch_size = 32


### Creating a dataloader

In this example, a custom dataset class (`AxonDataset`) was written for this data; we import it from `AxonDataset.py` and wrap it in PyTorch dataloaders.

In [None]:

# Build the example dataset: two-photon microscopy crops of axons.
axon_dataset = AxonDataset(data_name='org64', type='train', folder='/content/Data')

# Every sample index of the dataset; 20% of them are held out for validation.
indices = list(range(len(axon_dataset)))
split = int(len(indices) * 0.2)

# Draw the validation indices without replacement; whatever is left is
# used for training.
validation_idx = np.random.choice(indices, size=split, replace=False)
train_idx = list(set(indices) - set(validation_idx))

# Each sampler visits only its own subset of indices, in random order.
train_sampler = SubsetRandomSampler(train_idx)
validation_sampler = SubsetRandomSampler(validation_idx)

# Both loaders read from the same dataset object; the samplers keep the
# training and validation subsets disjoint.
train_loader = DataLoader(
    axon_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
)
val_loader = DataLoader(
    axon_dataset,
    batch_size=batch_size,
    sampler=validation_sampler,
)



## The network

Let's implement a UNET as per lecture 3. 

Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. "U-net: Convolutional networks for biomedical image segmentation." International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.

First we define a layer `double_conv` that performs two sets of convolution, each followed by a ReLU. This is set up as an `nn.Sequential` block.

In [None]:
# Building block of the U-net
def double_conv(in_channels, out_channels, padding=1):
    """Build a block of two 3x3 convolutions, each followed by an in-place ReLU.

    Args:
        in_channels: number of channels of the block's input.
        out_channels: number of channels produced by both convolutions.
        padding: padding passed to both convolutions (the default of 1
            keeps the spatial size unchanged for a 3x3 kernel).

    Returns:
        An ``nn.Sequential`` holding, in order: conv, ReLU, conv, ReLU.
    """
    channel_pairs = (
        (in_channels, out_channels),
        (out_channels, out_channels),
    )
    layers = []
    for n_in, n_out in channel_pairs:
        layers.append(nn.Conv2d(n_in, n_out, kernel_size=3, padding=padding))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)


We then define our U-net network.

We first initialise all the different layers in the network in `__init__`:
1. `self.dconv_down1` is a double convolutional layer (defined above)
2. `self.maxpool` is a max pooling layer that is used to reduce the size of the input, and increase the receptive field
3. `self.upsample` is an upsampling layer that is used to increase the size of the input
4. `dropout` is a dropout layer that is applied to regularise the training
5. `dconv_up4` is also a double convolutional layer- note that it takes in additional channels from previous layers (i.e. the skip connections).


In [None]:

class UNet(nn.Module):
    """U-Net with one shared encoder and two parallel decoders.

    The encoder (four ``double_conv`` + max-pool stages and a bottleneck) is
    shared. Two structurally identical decoders with separate weights read
    the same bottleneck and skip connections:

    * the main decoder ends in ``conv_last`` followed by a sigmoid;
    * the ``*_uncert`` decoder ends in ``conv_last_uncert`` with no
      activation.

    ``Dropout2d(0.3)`` is applied after every ``double_conv`` block in the
    encoder and in both decoders.

    The input must have one channel. Because the input is halved four times
    and then doubled four times, its height and width need to be divisible
    by 16 for the skip concatenations to line up.
    """

    def __init__(self):
        super().__init__()

        # Encoder blocks (channel widths 1 -> 32 -> 64 -> 128 -> 256 -> 512).
        self.dconv_down1 = double_conv(1, 32)
        self.dconv_down2 = double_conv(32, 64)
        self.dconv_down3 = double_conv(64, 128)
        self.dconv_down4 = double_conv(128, 256)
        self.dconv_down5 = double_conv(256, 512)

        # Parameter-free layers shared by the encoder and both decoders.
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.dropout = nn.Dropout2d(0.3)

        # Main decoder; input widths are skip channels + upsampled channels.
        self.dconv_up4 = double_conv(256 + 512, 256)
        self.dconv_up3 = double_conv(128 + 256, 128)
        self.dconv_up2 = double_conv(128 + 64, 64)
        self.dconv_up1 = double_conv(64 + 32, 32)

        # Second decoder with its own weights, same layout as the main one.
        self.dconv_up4_uncert = double_conv(256 + 512, 256)
        self.dconv_up3_uncert = double_conv(128 + 256, 128)
        self.dconv_up2_uncert = double_conv(128 + 64, 64)
        self.dconv_up1_uncert = double_conv(64 + 32, 32)

        # 1x1 output heads, one per decoder.
        self.conv_last = nn.Conv2d(32, 1, 1)
        self.conv_last_uncert = nn.Conv2d(32, 1, 1)

    def _decode(self, bottleneck, skips, blocks):
        """Run one decoder path from the bottleneck up to input resolution.

        Args:
            bottleneck: feature map produced by the deepest encoder block.
            skips: encoder feature maps, ordered deepest first.
            blocks: the decoder ``double_conv`` modules, ordered deepest first.

        Returns:
            The feature map produced by the last decoder block, after dropout.
        """
        features = bottleneck
        for skip, block in zip(skips, blocks):
            features = self.upsample(features)
            features = torch.cat([features, skip], dim=1)
            features = block(features)
            features = self.dropout(features)
        return features

    def forward(self, x):
        """Compute both network outputs for a batch of single-channel images.

        Args:
            x: input tensor with one channel (dimension 1 has size 1).

        Returns:
            Tuple ``(out1, out2)``, both single-channel maps: ``out1`` is the
            sigmoid of the main head, ``out2`` is the raw output of the
            second head.
        """
        #######   ENCODER ###############
        skips = []
        encoder_blocks = (
            self.dconv_down1,
            self.dconv_down2,
            self.dconv_down3,
            self.dconv_down4,
        )
        for block in encoder_blocks:
            x = self.dropout(block(x))
            skips.append(x)  # kept for the skip connection (after dropout)
            x = self.maxpool(x)

        # Bottleneck
        conv5 = self.dropout(self.dconv_down5(x))

        #######   DECODER ###############
        # The decoders consume the skips from the deepest level upwards.
        skips.reverse()

        deconv1 = self._decode(
            conv5,
            skips,
            (self.dconv_up4, self.dconv_up3, self.dconv_up2, self.dconv_up1),
        )
        out1 = torch.sigmoid(self.conv_last(deconv1))

        deconv1_uncert = self._decode(
            conv5,
            skips,
            (
                self.dconv_up4_uncert,
                self.dconv_up3_uncert,
                self.dconv_up2_uncert,
                self.dconv_up1_uncert,
            ),
        )
        out2 = self.conv_last_uncert(deconv1_uncert)

        return out1, out2

To save time we initialise the network with a previously trained network by loading the weights

*for practical reasons training this network from scratch will take too long, and require large computational resources*

In [None]:
# Create the U-Net and move it to the selected device in one step.
# NOTE: this cell does not load any saved weights; the network keeps the
# default initialisation of its layers.
net = UNet().to(device)

## The Loss Function
We next define a Dice loss as a tracking metric, and the CE uncertainty loss as the optimisation metric. 

In [None]:
# Soft Dice coefficient (the Dice loss is 1 - dice_coeff)
def dice_coeff(pred, target):
    """Return the soft Dice coefficient between ``pred`` and ``target``.

    Both tensors are flattened completely, so a single score is computed
    over all elements (the batch dimension included). Real-valued inputs
    are supported because squared sums are used in the denominator.

    pred: tensor of predictions
    target: tensor of labels with the same number of elements as ``pred``

    Returns a 0-dimensional tensor clamped to ``[0, 1 - 1e-7]``.
    """
    smooth = 1.
    epsilon = 10e-8

    # contiguous() is required before view() for inputs that are not
    # stored contiguously in memory.
    pred_flat = pred.contiguous().view(-1)
    target_flat = target.contiguous().view(-1)

    overlap = torch.sum(pred_flat * target_flat)
    pred_sq_sum = (pred_flat * pred_flat).sum()
    target_sq_sum = (target_flat * target_flat).sum()

    numerator = 2. * overlap + smooth
    denominator = pred_sq_sum + target_sq_sum + smooth

    return torch.clamp(numerator / denominator, 0, 1.0 - epsilon)




In [None]:
def ce_with_uncertainty(pred, log_var, target, regul=1, epsilon=0.005):
    """Binary cross-entropy attenuated by a predicted per-element variance.

    Per element the loss is

        0.5 * CE / (exp(log_var) + epsilon)
        + regul * 0.5 * log(exp(log_var) + epsilon)

    and the mean over all elements is returned.

    Args:
        pred: predicted probabilities.
        log_var: predicted log-variance, broadcastable against ``pred``.
        target: ground-truth labels.
        regul: weight of the log-variance penalty term.
        epsilon: constant added to the variance, which keeps the division
            and the logarithm away from zero.

    Returns:
        Scalar tensor with the mean loss.
    """
    # Small constant inside the logarithms so that pred == 0 or pred == 1
    # does not produce log(0).
    lab_smoothing = 0.00001

    positive_term = torch.log(pred + lab_smoothing) * target
    negative_term = torch.log((1 - pred) + lab_smoothing) * (1 - target)
    ce_loss = -positive_term - negative_term

    variance = torch.exp(log_var) + epsilon
    attenuated_ce = 0.5 * ce_loss / variance
    variance_penalty = regul * 0.5 * torch.log(variance)

    return (attenuated_ce + variance_penalty).mean()
    

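Here the small constant `lab_smoothing` is added inside the logarithms so that a confident mistake (a prediction of exactly 0 or 1 for the wrong class) does not produce an infinite loss, i.e. the network is not penalised too harshly for any single mistake.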

As before, we define the optimiser to train our network - here we use Adam.


In [None]:
# Adam optimiser over the trainable parameters of the network only
# (parameters with requires_grad=False are left out).
optimizer = torch.optim.Adam(
    (param for param in net.parameters() if param.requires_grad),
    lr=1e-04,
    betas=(0.5, 0.999),
)
# Start from clean gradients.
optimizer.zero_grad()


## Training and evaluating the network
We next train and evaluate our network 

Note that the results are saved to the `results` folder, so please check it.

In [None]:
# Training configuration and metric histories.
epochs=300
save_every=50
all_error = np.zeros(0)
all_error_L1 = np.zeros(0)
all_error_dice = np.zeros(0)
all_dice = np.zeros(0)
# Starts with a single 0 entry, so the plotted validation curve begins at 0.
all_val_dice = np.zeros(1)
all_val_error = np.zeros(0)
import os
cwd = os.getcwd()
path = os.path.join(cwd, 'results')
# Make sure the output folder exists even if the set-up cell was skipped.
os.makedirs(path, exist_ok=True)

t0 = time.time()
for epoch in range(epochs):

    ##########
    # Train
    ##########
    # Training mode enables dropout. It has to be set again every epoch
    # because the validation pass below switches the network to eval mode.
    net.train()
    mean_error = np.zeros(0)
    mean_dice = np.zeros(0)
    for i, (data, label) in enumerate(train_loader):
        data = data.to(device)
        label = label.to(device)
        optimizer.zero_grad()

        pred, log_var = net(data)

        # ----------------------------------------------- task 3 ------------------------------------------------------------
        # Task 3: change loss function here
        # epsilon is annealed linearly from 0.001 towards 0.0001 over the epochs.
        err = ce_with_uncertainty(pred, log_var, label, regul=3, epsilon=(0.001-0.0009*epoch/epochs))
        # -------------------------------------------------------------------------------------------------------------------

        # Dice is only tracked as a metric, so it is computed without gradients.
        dice_value = dice_coeff(pred.detach(), label).item()

        err.backward()
        optimizer.step()

        mean_error = np.append(mean_error, err.item())
        mean_dice = np.append(mean_dice, dice_value)

    # Report the mean training loss and Dice every 5 epochs; the timer is
    # reset here, so the elapsed time covers the period since the last report.
    if epoch % 5 == 0:
        time_elapsed = time.time() - t0

        print('[{:d}/{:d}] Elapsed_time: {:.0f}m{:.0f}s Loss: {:.4f} Dice: {:.4f}'
              .format(epoch, epochs, time_elapsed // 60, time_elapsed % 60,
                      np.mean(mean_error), np.mean(mean_dice)))
        t0 = time.time()

    # #############
    # # Validation
    # #############
    # Eval mode disables dropout; no_grad avoids building autograd graphs
    # for forward passes that are never back-propagated.
    net.eval()
    mean_error = np.zeros(0)
    mean_dice = np.zeros(0)
    with torch.no_grad():
        for i, (data, label) in enumerate(val_loader):
            data = data.to(device)
            label = label.to(device)

            pred, log_var = net(data)

            # Dice is computed once; the validation error is the Dice loss.
            dice = dice_coeff(pred, label)
            err = 1 - dice
            dice_value = dice.item()

            # Every 100 epochs, save the first validation batch for inspection.
            if epoch % 100 == 0 and i == 0:
                vutils.save_image(data, '%s/epoch_val_data.png' % (path),
                                  normalize=True)
                vutils.save_image(label, '%s/epoch_val_label.png' % (path),
                                  normalize=True)
                vutils.save_image(pred, '%s/epoch_val_pred.png' % (path),
                                  normalize=True)
                vutils.save_image(torch.exp(log_var), '%s/epoch_val_aleatoric_uncert.png' % (path),
                                  normalize=True)

            mean_error = np.append(mean_error, err.item())
            mean_dice = np.append(mean_dice, dice_value)

    all_val_error = np.append(all_val_error, np.mean(mean_error))
    all_val_dice = np.append(all_val_dice, np.mean(mean_dice))

    # NOTE(review): all_error is never appended to in this cell, so
    # epochs_train stays empty and is not used by the plot below. Kept so
    # that the names remain defined.
    num_it_per_epoch_train = ((train_loader.dataset.x_data.shape[0] * (1 - 0.2)) // (
            save_every * batch_size)) + 1
    epochs_train = np.arange(1,all_error.size+1) / num_it_per_epoch_train
    epochs_val = np.arange(0,all_val_dice.size)

    # Redraw the validation Dice curve after every epoch.
    plt.figure()
    plt.plot(epochs_val, all_val_dice, label='dice_val')
    plt.xlabel('epochs')
    plt.legend()
    plt.title('Dice score')
    plt.savefig(path + '/dice_val.png')
    plt.close()



## Dropout Sampling

Now, let's sample from the model with dropout at test time. A few tricks are needed, though...



In [None]:
def enable_dropout(model):
    """Put every dropout layer of ``model`` into training mode.

    A sub-module counts as a dropout layer when its class name starts with
    ``'Dropout'``. All other modules keep their current mode, so calling
    this on a network in eval mode re-activates only the dropout layers.
    """
    dropout_layers = (
        layer
        for layer in model.modules()
        if type(layer).__name__.startswith('Dropout')
    )
    for layer in dropout_layers:
        layer.train()
  
# #############
# # Epistemic Evaluation
# #############
# Monte Carlo dropout: run the network `dropout_samples` times on the same
# validation batch with dropout active and use the spread of the predictions
# as the epistemic uncertainty.
dropout_samples = 100
mean_error = np.zeros(0)
mean_dice = np.zeros(0)

# Only the first validation batch is evaluated; the loop ends with `break`
# so no further batch is loaded or sampled.
for i, (data, label) in enumerate(val_loader):
    data = data.to(device)
    label = label.to(device)
    batch_size = data.size()[0]

    # Eval mode everywhere, then re-activate only the dropout layers.
    net.eval()
    enable_dropout(net)

    # The samples are collected in lists and stacked afterwards, so the
    # buffers follow the real batch size instead of assuming 32 samples.
    pred_samples = []
    var_samples = []
    with torch.no_grad():
        for sample in range(dropout_samples):
            pred, log_var = net(data)
            pred_samples.append(pred.cpu())
            # The second output is a log-variance; exp() gives the variance.
            var_samples.append(torch.exp(log_var.cpu()))
    # Shape: (dropout_samples, batch, ...). Despite its name, samples_logvar
    # holds variances (exp of the predicted log-variance).
    samples_array = torch.stack(pred_samples)
    samples_logvar = torch.stack(var_samples)

    # Deterministic pass with dropout switched off again.
    net.eval()
    with torch.no_grad():
        pred, log_var = net(data)

    # Variance and mean across the dropout samples, per pixel.
    epistemic_var = torch.var(samples_array, dim=0)
    mean_pred = torch.mean(samples_array, dim=0)
    mean_logval = torch.mean(samples_logvar, dim=0)

    # These file names are the same ones the training loop writes, so the
    # images saved during training are overwritten.
    vutils.save_image(data, '%s/epoch_val_data.png' % (path),
                      normalize=True)
    vutils.save_image(label, '%s/epoch_val_label.png' % (path),
                      normalize=True)
    vutils.save_image(mean_pred, '%s/epoch_val_pred.png' % (path),
                      normalize=True)
    vutils.save_image(mean_logval, '%s/epoch_val_aleatoric_uncert.png' % (path),
                      normalize=False)
    vutils.save_image(epistemic_var, '%s/epoch_val_epistemic_unc.png' % (path),
                      normalize=False)
    break

