**NOTE:** before starting this notebook, go to Edit -> Notebook settings and select GPU as the hardware accelerator. Now you're ready to go!

#  Domain adaptation and curriculum learning

In this notebook, we will see in practice how a shift in the data distribution affects the performance of deep learning models, and try some state-of-the-art strategies to address this problem. Specifically, we will cover:


* Supervised Domain Adaptation
* Augmentation-based Domain adaptation (ADA)
* Adversarial Domain Adaptation 
* Scheduling losses and curriculum learning

We focus in particular on medical imaging data.



In [None]:
%matplotlib inline
import os

Let's download the data locally, import the libraries we need, and define some plotting functions.

In [None]:
file_download_link = "https://github.com/KCL-BMEIS/AdvancedMachineLearningCourse/blob/main/Week4-Multitask_and_domain_shift/Data/WMLesions.zip?raw=true"
!wget -nc -O WMLesions.zip --no-check-certificate "$file_download_link"
!unzip -n WMLesions.zip

In [None]:
# Setting up various plot functions to be used throughout the notebook

import matplotlib.pyplot as plt
import numpy as np
import random


def plot_images(images, title=None, figsize=(15,15)):
    f, axes = plt.subplots(1, len(images), figsize=figsize)
    for image_id, image in enumerate(images):
        axes[image_id].imshow(np.rot90(image, k=3), cmap='gray')  # rotate by 270 degrees for display
        axes[image_id].axis('off')
        if not title:
            axes[image_id].set_title('Clinic {}'.format(image_id+1),
                                     fontsize=20)
        elif len(title) == 1:
            axes[image_id].set_title('Image {}: {}'.format(
                image_id, title[0]), fontsize=20)
        else:
            axes[image_id].set_title('Image {}: {}'.format(
                image_id, title[image_id]), fontsize=20)
    f.tight_layout()
    

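def plot_grids(grids, figsize=(15,15)):
    # Grids are SimpleITK images; imported here because only this helper needs SimpleITK
    import SimpleITK as sitk
    f, axes = plt.subplots(1, len(grids), figsize=figsize)
    for grid_id, grid in enumerate(grids):
        grid_array = sitk.GetArrayViewFromImage(grid)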
        axes[grid_id].imshow(np.flip(grid_array, axis=0),
               interpolation='hamming',
               cmap='gray',
               origin='lower')
        axes[grid_id].set_title('Grid {}'.format(grid_id), fontsize=20)
    f.tight_layout()

    
def plot_histograms(images, figsize=(15,7.5)):
    f, axes = plt.subplots(1, len(images), figsize=figsize)
    for image_id, image in enumerate(images):
        # Histogram of the voxel intensities of each image (40 bins)
        axes[image_id].hist(np.ravel(image), bins=40)
        axes[image_id].set_xlim([0, 140])
        axes[image_id].set_title('Clinic {} Histogram'.format(image_id + 1),
                                 fontsize=20)
    f.tight_layout()

# PART 1: Differences in Data distributions


Supervised learning models assume that training and test data come from the same distribution. When this assumption is not fulfilled, their performance drops.


In medical imaging in particular, differences in distribution can come from different acquisition protocols, modalities, or scanner settings. This is a common scenario in clinical practice, since different clinics usually use different settings for image acquisition. Let's load 3 datasets acquired with different protocols and plot samples from each one.



<br>
<div>
<center>
<img src="https://raw.githubusercontent.com/MauricioOrbes/AML_lecture_5/master/images/datashiftComparison.png" width="1000"/>
</center>
</div>

> __Figure__: Example of the data-shift problem in medical image segmentation


Let's plot samples from 3 different clinics

In [None]:
import nibabel as nib
import os

def read_file(filename):
    img = nib.load(filename)
    data = img.get_fdata()
    aff = img.affine
    return data, aff

# Build a small 2D dataset: one slice from the first volume of each clinic
dataset = []
clinic1_dir = 'WMH_3D/Clinic1'
clinic2_dir = 'WMH_3D/Clinic2'
clinic3_dir = 'WMH_3D/Clinic3'
for clinic_dir in [clinic1_dir, clinic2_dir,clinic3_dir]:
    # Sort the directory listing so the same volume is picked on every run
    image, _ = read_file(os.path.join(clinic_dir, sorted(os.listdir(clinic_dir))[0]))
    # Take the middle slice (index 25)
    dataset.append(image[:, :, 25])

In [None]:
plot_images(dataset[:3])


As you can see, the images acquired at the three clinics are qualitatively different. Let's look at their intensity histograms!

In [None]:
plot_histograms(dataset)



<!-- As you can see not only are the images qualitatively different we can also see a difference when looking at the image histograms.   -->

# Let's see how these differences affect the performance of a supervised learning model

We use a neural network to perform white matter hyperintensity (WMH) segmentation. The network is trained on labeled samples from **clinic 1**, and inference is then carried out on the validation sets of both **clinic 1** (same domain) and **clinic 2** (different domain).

For this task we use the well-known [U-Net](https://arxiv.org/pdf/1505.04597.pdf) architecture, which has achieved outstanding results in semantic segmentation.

Let's define the model in PyTorch. 


In [None]:
# import required libraries
import torch 
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F

def initialize_weights(*models):
    for model in models:
        for module in model.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)  # in-place init; kaiming_normal is deprecated
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

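# Encoder block: two conv + BatchNorm + LeakyReLU layers. Each conv has two BatchNorm
# layers, BN*a for the source domain and BN*b for the target domain, selected via `d`.
class _EncoderBN(nn.Module):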
    def __init__(self, in_channels, out_channels,k ,padd , dropout=False):
        super(_EncoderBN, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=k, padding=padd)
        self.BN1a = nn.BatchNorm2d(out_channels)
        self.BN1b = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=k, padding=padd)
        self.BN2a = nn.BatchNorm2d(out_channels)
        self.BN2b = nn.BatchNorm2d(out_channels)

    def forward(self,x,d):
        if d == 'source':
            x = F.leaky_relu(self.BN1a(self.conv1(x)),inplace=True)
            x = F.leaky_relu(self.BN2a(self.conv2(x)),inplace=True)
        elif d == 'target':
            x = F.leaky_relu(self.BN1b(self.conv1(x)),inplace=True)
            x = F.leaky_relu(self.BN2b(self.conv2(x)),inplace=True)
        return x

class _DecoderBN(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super(_DecoderBN, self).__init__()

        self.conv1 = nn.Conv2d(in_channels, middle_channels, kernel_size=3,padding=1)
        self.BN1a = nn.BatchNorm2d(middle_channels)
        self.BN1b = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, middle_channels, kernel_size=3, padding=1)
        self.BN2a = nn.BatchNorm2d(middle_channels)
        self.BN2b = nn.BatchNorm2d(middle_channels)
        self.convT = nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=2, stride=2)

    def forward(self, x, d):
        if d == 'source':
            x = F.leaky_relu(self.BN1a(self.conv1(x)),inplace=True)
            x = F.leaky_relu(self.BN2a(self.conv2(x)),inplace=True)
        elif d == 'target':
            x = F.leaky_relu(self.BN1b(self.conv1(x)),inplace=True)
            x = F.leaky_relu(self.BN2b(self.conv2(x)),inplace=True)

        return self.convT(x)

class prefinalBN(nn.Module):
    def __init__(self, in_channels,out_channels):
        super(prefinalBN,self).__init__()

        self.conv1 = nn.Conv2d(in_channels , out_channels, kernel_size=3, padding=1)
        self.BN1a = nn.BatchNorm2d(out_channels)
        self.BN1b = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.BN2a = nn.BatchNorm2d(out_channels)
        self.BN2b = nn.BatchNorm2d(out_channels)

    def forward(self, x ,d):
        if d == 'source':
            x = F.leaky_relu(self.BN1a(self.conv1(x)),inplace=True)
            x = F.leaky_relu(self.BN2a(self.conv2(x)),inplace=True)
        elif d == 'target':
            x = F.leaky_relu(self.BN1b(self.conv1(x)),inplace=True)
            x = F.leaky_relu(self.BN2b(self.conv2(x)),inplace=True)
        return x

class ADABN(nn.Module):
    def __init__(self, num_classes, num_channels):
        super(ADABN, self).__init__()

        self.enc1 = _EncoderBN(num_channels, 64, 5, 2)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc2 = _EncoderBN(64, 96, 3, 1)
        self.enc3 = _EncoderBN(96, 128, 3, 1)
        self.enc4 = _EncoderBN(128, 256, 3, 1)
        self.center = _DecoderBN(256, 512, 256)
        self.dec4 = _DecoderBN(512, 256, 128)
        self.dec3 = _DecoderBN(256, 128, 96)
        self.dec2 = _DecoderBN(96 * 2, 96, 64)
        self.dec1 = prefinalBN(128,64)
        self.final = nn.Conv2d(64, num_classes, kernel_size=1)
        initialize_weights(self)

    def forward(self, x,d='source'):
        enc1 = self.enc1(x,d)
        enc2 = self.enc2(self.pool(enc1),d)
        enc3 = self.enc3(self.pool(enc2),d)
        # training=self.training disables dropout after model.eval()
        enc4 = F.dropout(self.enc4(self.pool(enc3),d), training=self.training)

        center = self.center(self.pool(enc4),d)

        dec4 = self.dec4(torch.cat([center, enc4], 1),d)
        dec3 = self.dec3(torch.cat([dec4, enc3], 1),d)
        dec2 = self.dec2(torch.cat([dec3, enc2], 1),d)
        dec1 = self.dec1(torch.cat([dec2, enc1], 1),d)

        final = self.final(dec1)
        return (final, enc1, enc2, enc3, enc4, center, dec4, dec3, dec2, dec1)

model = ADABN(1,1)


<!-- Lets load some weights from a model trained only on clinic 1 -->
Training a segmentation network requires many iterations, so to save time we load weights from a model pre-trained on clinic 1 only.



In [None]:
# map_location='cpu' lets weights saved on a GPU be loaded into this CPU model
model.load_state_dict(torch.load('clinic1B2assource2load.pt', map_location='cpu'))


Our model is now ready to segment the validation sets from clinic 1 and clinic 2. Before doing so, we need to set up the training and validation data for both clinics. To save time, we use `torchio`, a great package developed at KCL.


<!-- # Lets run that model on clinic 1 and clinic 2 and analyse the results

We're going to use a great package called `torchio` developed at KCL to save time. -->


In [None]:
!pip install torchio

In [None]:
import torchio
from torchio import Subject
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import Compose
from torchvision.transforms import RandomCrop
import re
def sorted_aphanumeric(data):
    convert = lambda text: int(text) if text.isdigit() else text.lower()
    alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
    return sorted(data, key=alphanum_key)


# Define the training and validation sets for the source and target domains

# CLINIC1 is the source domain and CLINIC2 is the target domain.
# We create a list of torchio subjects (FLAIR image + WMH label map) for each split

def getiosubjects(path_dir):
  subjects_list = sorted_aphanumeric(os.listdir( path_dir + '/flair/'))
  
  iosubjects = []
  for slices in subjects_list:
   
      subject = torchio.Subject(
          flair=torchio.ScalarImage(path_dir + '/flair/' + slices),
          label=torchio.LabelMap(path_dir + '/labels/wmh' +  slices.split('FLAIR')[1])
      )
      iosubjects.append(subject)
  
  return iosubjects
                       
source_training_dir = 'WMH_DATABASE/Clinic1/Training'
source_validation_dir = 'WMH_DATABASE/Clinic1/Validation'

target_training_dir = 'WMH_DATABASE/Clinic2/Training'
target_validation_dir= 'WMH_DATABASE/Clinic2/Validation'

source_training_subjects = getiosubjects(source_training_dir)
source_validation_subjects = getiosubjects(source_validation_dir)

target_training_subjects = getiosubjects(target_training_dir) 
target_validation_subjects = getiosubjects(target_validation_dir) 


# Training Sets
source_dataset_training = torchio.SubjectsDataset(source_training_subjects)
source_training_loader = DataLoader(source_dataset_training, shuffle=True, batch_size=6)

target_dataset_training = torchio.SubjectsDataset(target_training_subjects)
target_training_loader = DataLoader(target_dataset_training, shuffle=True, batch_size=6)

# Validation Sets 
source_dataset_validation = torchio.SubjectsDataset(source_validation_subjects)
source_validation_loader = DataLoader(source_dataset_validation, shuffle=False, batch_size=6)

target_dataset_validation = torchio.SubjectsDataset(target_validation_subjects)
target_validation_loader = DataLoader(target_dataset_validation, shuffle=False, batch_size=6)
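
The `sorted_aphanumeric` helper above orders file names numerically rather than lexicographically, which keeps the (unshuffled) validation slices in a predictable order for the `inference` function used later. As a standalone illustration, here is the same key logic under a new name, applied to hypothetical file names (the dataset's real naming scheme is not shown here):

```python
import re


def sorted_alphanumeric(names):
    """Natural sort: digit runs are compared as integers, other text case-insensitively."""
    def convert(chunk):
        return int(chunk) if chunk.isdigit() else chunk.lower()

    def key(name):
        # 'FLAIR_10.nii.gz' -> ['flair_', 10, '.nii.gz']
        return [convert(chunk) for chunk in re.split('([0-9]+)', name)]

    return sorted(names, key=key)


files = ['FLAIR_10.nii.gz', 'FLAIR_2.nii.gz', 'FLAIR_1.nii.gz']  # hypothetical names
print(sorted(files))               # ['FLAIR_1.nii.gz', 'FLAIR_10.nii.gz', 'FLAIR_2.nii.gz']
print(sorted_alphanumeric(files))  # ['FLAIR_1.nii.gz', 'FLAIR_2.nii.gz', 'FLAIR_10.nii.gz']
```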



### Let's define a similarity measure
To evaluate and compare performance, we need a similarity measure between the ground-truth and predicted segmentations. To this end, we use the [Dice similarity coefficient](https://en.wikipedia.org/wiki/S%C3%B8rensen%E2%80%93Dice_coefficient). Let's define it in PyTorch:
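
For a predicted binary mask $A$ and a ground-truth mask $B$, the Dice coefficient is $DSC = \frac{2|A \cap B|}{|A| + |B|}$: 1 for perfect overlap, 0 for no overlap. As a quick illustration, here is a NumPy sketch on a made-up 2x2 example (not part of the notebook's pipeline) that thresholds probabilities at 0.8, as `pairwisedice` does; the small `eps` plays the role of the smoothing constant `s`.

```python
import numpy as np


def dice(pred_mask, true_mask, eps=1e-19):
    """Dice coefficient between two binary masks (eps avoids 0/0 for empty masks)."""
    pred_mask = np.asarray(pred_mask, dtype=bool)
    true_mask = np.asarray(true_mask, dtype=bool)
    intersect = np.logical_and(pred_mask, true_mask).sum()
    return 2.0 * intersect / (pred_mask.sum() + true_mask.sum() + eps)


# Toy predicted probabilities and ground truth (hypothetical values)
probabilities = np.array([[0.95, 0.85],
                          [0.30, 0.90]])
ground_truth = np.array([[1, 0],
                         [0, 1]])

predicted = probabilities > 0.8        # 3 foreground pixels, 2 of them correct
score = dice(predicted, ground_truth)  # 2 * 2 / (3 + 2) = 0.8
print(score)  # 0.8
```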


In [None]:

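def pairwisedice(output,target):
    """Dice between a thresholded probability map and a binary reference mask."""
    s = 1e-19  # small constant that avoids division by zero for empty masks
    output = (output > 0.8).float()  # binarize the predicted probabilities
    target = (target == 1).float()   # binarize the reference labels

    intersect = torch.sum(output * target)

    dice = (2 * intersect) / (torch.sum(output) + torch.sum(target) + s)
    
    return dice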




In [None]:
import random
from torch.autograd import Variable
import os
import time
import matplotlib.pyplot as plt
import matplotlib
from tqdm import tqdm
import sys


# The inference function predicts batch by batch, stacks the slices of each subject
# and returns one Dice score per subject. It assumes the loader is not shuffled and
# that each subject contributes exactly `slices_per_image` consecutive slices.
def inference(data_loader,batch_size=6,slices_per_image=24):
  slice_counter = 0   
  dice = []
  model.eval()  # evaluation mode: BatchNorm uses its running statistics
  # no_grad avoids storing activations for backpropagation during inference
  with torch.no_grad(), tqdm(total=len(data_loader), file=sys.stdout) as pbar:
    for slices in data_loader: 
        if (slice_counter % slices_per_image) == 0:
          pred = torch.Tensor([])
          ref = torch.Tensor([])        
        batch_images = slices['flair']['data'][...,0].float()
        batch_labels = slices['label']['data'][...,0].float()
        outputs, _, _, _, _, _, _, _, _, _ = model(batch_images)
        
        pred = torch.cat((pred,torch.sigmoid(outputs)),0)
        ref = torch.cat((ref,batch_labels),0)
       
        if pred.size(0)==slices_per_image:
          dice.append(pairwisedice(torch.squeeze(pred,1),torch.squeeze(ref,1)).item())
        slice_counter += batch_size
        pbar.update(1)
  return dice

# First we get the inference for  Clinic1 validation set
dice_source_noadaptation=inference(source_validation_loader)
# We do the same for Clinic 2 validation set
dice_target_noadaptation=inference(target_validation_loader)

fig1, ax1 = plt.subplots()
ax1.set_title('Source vs Target Performance')
ax1.boxplot([dice_source_noadaptation,dice_target_noadaptation]);

ax1.set_xticklabels(['Source', 'Target'])
ax1.get_xaxis().tick_bottom()
ax1.set_ylabel('Dice')

print('Source dice performance {}' .format(np.mean(dice_source_noadaptation)))
print('Target dice performance {}' .format(np.mean(dice_target_noadaptation)))




As you can see, the model performs considerably worse on data from a different domain (target) than on data from the domain used during training (source). This motivates the development of domain adaptation techniques, which aim to make a model perform well across different domains.

**What do we mean by domain adaptation?**

Domain adaptation is the process of adapting a model from a source domain to a target domain in order to avoid this drop in performance.

Domain adaptation usually uses some information from the target domain. Depending on the kind of information available, adaptation can be:

**Supervised domain adaptation:** some labeled data from the target domain is available.

**Unsupervised domain adaptation:** only unlabeled data from the target domain is available.


<br>
<div>
<center>
<img src="https://raw.githubusercontent.com/MauricioOrbes/AML_lecture_5/master/images/SupVsUnsupervised.png" width="1200"/>
</center>
</div>

> __Figure__: Supervised vs. unsupervised domain adaptation


# PART 2: Let's try supervised domain adaptation between clinic 1 and clinic 2 and analyse the results

A straightforward way to carry out this adaptation is to use a small amount of annotated data from the target domain to refine a model pre-trained on the source domain. This process is also known as [fine-tuning](https://).

Assume a model $f(\cdot)$ has been trained on a source domain using the following training loss:

* $Source\_loss = soft\_dice(f(source\_images),source\_labels) \qquad (1)$

The adaptation is done by computing the same loss on the target images and adding it to the previous one:

* $Target\_loss = soft\_dice(f(target\_images),target\_labels) \qquad (2)$

* $Total\_loss = \alpha * Source\_loss + \beta * Target\_loss \qquad (3)$
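
To make eqs. (1) to (3) concrete, here is a small NumPy sketch with made-up probabilities for one source and one target image. `soft_dice_loss` mirrors the notebook's `dice_soft_loss` (one minus the soft Dice, computed on probabilities without thresholding); the values and $\alpha = \beta = 0.5$ are only illustrative.

```python
import numpy as np


def soft_dice_loss(pred, target, eps=1e-19):
    """1 - soft Dice between probabilities `pred` and labels `target`."""
    intersect = np.sum(pred * target)
    return 1.0 - 2.0 * intersect / (np.sum(pred) + np.sum(target) + eps)


# Hypothetical sigmoid outputs and labels (4 pixels each)
source_pred = np.array([0.9, 0.8, 0.1, 0.0])
source_label = np.array([1.0, 1.0, 0.0, 0.0])
target_pred = np.array([0.6, 0.2, 0.5, 0.0])  # the model is less accurate on the target domain
target_label = np.array([1.0, 1.0, 0.0, 0.0])

alpha, beta = 0.5, 0.5
source_loss = soft_dice_loss(source_pred, source_label)  # eq (1)
target_loss = soft_dice_loss(target_pred, target_label)  # eq (2)
total_loss = alpha * source_loss + beta * target_loss    # eq (3)
print(round(source_loss, 3), round(target_loss, 3), round(total_loss, 3))
```

Raising $\beta$ relative to $\alpha$ makes the target term dominate the gradient, which is the trade-off the exercise below asks you to explore.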

**Exercise:** Write the loss function to carry out domain adaptation. 

The code below already continues the optimization using the source loss. Trained on the source loss alone, the model tends to overfit the source domain, which widens the gap between source and target performance.

Your task is to implement the **$Target\_loss$**, which is added to the **$Source\_loss$**. Also, try modifying the parameters $\alpha$ and $\beta$ and see what impact they have on performance.

In [None]:
# Supervised domain adaptation

import random
import torch.optim as optim
from torch.autograd import Variable
import os
import time

# Define the optimizer (Adam with learning rate 1e-4)

optimizer = optim.Adam(model.parameters(), lr=1e-4 )

# Soft dice is used as cost function
def dice_soft_loss(output, target):
    s = (10e-20)

    intersect = torch.sum(output * target)
    dice = (2 * intersect) / (torch.sum(output) + torch.sum(target) + s)

    return 1 - dice


# Set the number of epochs here:
number_of_epochs = 2

# EXERCISE: try different values for these parameters

alpha = 0.5 # YOUR CODE HERE
beta  = 0.5 # YOUR CODE HERE

model.train()
for epoch in range(number_of_epochs):
  
  with tqdm(total=len(source_training_loader), file=sys.stdout) as pbar:

    start_time = time.time()
    running_loss = 0
    indb =0
    for patch_s,patch_t in zip(source_training_loader,target_training_loader):
         
        # Get a batch of source slices (clinic 1)
        source_batch_images = patch_s['flair']['data'][...,0].float()
        source_batch_labels = patch_s['label']['data'][...,0].float()
        # Get a batch of labeled target slices (clinic 2)
        target_batch_images = patch_t['flair']['data'][...,0].float()
        target_batch_labels = patch_t['label']['data'][...,0].float()
        
        # Supervised source loss, eq (1)
        outputs_source, _, _, _, _, _, _, _, _, _ = model(source_batch_images)
        supervised_source_loss = dice_soft_loss(torch.sigmoid(outputs_source), source_batch_labels)
        
        # Supervised adaptation: predictions for the target batch
        outputs_target, _, _, _, _, _, _, _, _, _ = model(target_batch_images)
        
        # EXERCISE: implement the supervised target loss according to eq (2)
        supervised_target_loss = dice_soft_loss(
            torch.sigmoid(outputs_target), target_batch_labels) # YOUR CODE HERE 
        
        # Weighted total loss, eq (3)
        total_loss = alpha * supervised_source_loss + beta * supervised_target_loss
        
        model.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()
        indb += 1
        pbar.update(1)

    end_time = time.time()

    print('Training: [epoch %d, loss %.3f] time: %.3f min' % (epoch + 1, running_loss / indb, (end_time-start_time) / 60 ))



In [None]:
# Now let's run inference again and compare against no adaptation.

dice_source_supervised_adaptation=inference(source_validation_loader)
dice_target_supervised_adaptation=inference(target_validation_loader)

fig1, ax1 = plt.subplots()
ax1.set_title('No adaptation vs Supervised Adaptation')
ax1.boxplot([dice_target_noadaptation, dice_target_supervised_adaptation]);

ax1.set_xticklabels(['No adaptation', 'Supervised Adaptation'])
ax1.get_xaxis().tick_bottom()
ax1.set_ylabel('Dice')


print('No adaptation dice performance {}' .format(np.mean(dice_target_noadaptation)))
print('Supervised adaptation dice performance {}' .format(np.mean(dice_target_supervised_adaptation)))

As you can see, the model now performs better on target-domain data. However, labeled data from the target domain is not always available, which motivates the development of unsupervised domain adaptation.

<!-- # Lets run try unsupervised domain adaptation between clinic 1 and clinic 2 and analyse the results -->

# PART 3: Augmentation-based unsupervised Domain Adaptation


One smart way to perform domain adaptation when no labels are available for the target domain is through a consistency loss. This idea was first used for classification [here](https://arxiv.org/pdf/1904.12848.pdf) and later adapted for segmentation [here](https://arxiv.org/pdf/1904.12848.pdf).

This loss enforces consistency between the model's predictions for an image and for a perturbed (augmented) version of it. Since consistency is measured between predictions, no labels are needed, so it can be computed on unlabeled data. As in supervised domain adaptation, the optimization minimizes the sum of a supervised loss and a consistency loss:
* $Total\_loss = \alpha * Source\_loss + \beta * Consistency\_loss \qquad (4)$

where

* $Source\_loss = soft\_dice(f(source\_images),source\_labels) \qquad (5)$

* $Consistency\_loss = soft\_dice(f(target\_images),f(\phi(target\_images))) \qquad (6)$

and $\phi(\cdot)$ applies a transformation (augmentation) to the input image.

The consistency loss in eq. $(6)$ works for classification tasks (e.g., a rotated or recolored car is still a car). In segmentation, however, a spatial transformation (e.g., translation, rotation, or a similar spatial augmentation) applied to the input image also moves the structures in the prediction, so $f(target\_images)$ and $f(\phi(target\_images))$ are no longer spatially aligned and cannot be compared directly.

This problem can be solved with a simple trick: the same transformation $\phi(\cdot)$ that is applied to the image before feeding it to the model is also applied to the model's prediction for the original image. The consistency loss can then be rewritten as:

$Consistency\_loss = soft\_dice(\phi(f(target\_images)),f(\phi(target\_images))) \qquad (7)$
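
The difference between eqs. (6) and (7) can be seen on a toy example. The NumPy sketch below makes simplifying assumptions: a hypothetical pixel-wise "segmenter" `f` that just thresholds intensities, and a horizontal flip as the spatial transformation $\phi$ (the exercise below uses random affine transformations instead). Because the lesion moves under $\phi$, the naive loss of eq. (6) reports complete disagreement even though `f` behaves consistently, whereas the paired loss of eq. (7) is zero.

```python
import numpy as np


def soft_dice_loss(pred, target, eps=1e-19):
    # 1 - soft Dice, as in the notebook's dice_soft_loss
    intersect = np.sum(pred * target)
    return 1.0 - 2.0 * intersect / (np.sum(pred) + np.sum(target) + eps)


def f(image):
    # Hypothetical toy segmenter: a pixel-wise threshold, so it commutes with any flip
    return (image > 0.5).astype(float)


def phi(image):
    # Spatial transformation: horizontal flip (stand-in for a random affine warp)
    return np.flip(image, axis=-1)


image = np.zeros((8, 8))
image[2:5, 0:3] = 1.0  # bright 3x3 "lesion" near the left border

naive_loss = soft_dice_loss(f(image), f(phi(image)))        # eq (6): predictions not aligned
paired_loss = soft_dice_loss(phi(f(image)), f(phi(image)))  # eq (7): both in the warped frame
print(naive_loss, paired_loss)  # 1.0 0.0
```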

## Exercise: Let's perform augmentation-based unsupervised adaptation between clinic 1 and clinic 2


First, we need to define a transformation to perturb the images. For simplicity, we use [affine](https://en.wikipedia.org/wiki/Affine_transformation) transformations; here they combine rotation, scaling, and shearing.
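
Here is how `GenerateAffine` (next cell) composes its matrices, written as a NumPy sketch for a single image with fixed, hand-picked parameters instead of random ones (`compose_affine` is a hypothetical name). Each transformation is a 3x3 homogeneous matrix; they are multiplied in the same order as in `GenerateAffine`, and only the top two rows are kept, which is the 2x3 format expected by `torch.nn.functional.affine_grid`. As in the notebook, no translation is included.

```python
import numpy as np


def compose_affine(angle_deg, scale_xy, shear_xy):
    """Return the 2x3 affine matrix and its inverse (rotation, scaling, shearing)."""
    a = np.deg2rad(angle_deg)
    rotation = np.array([[np.cos(a), np.sin(a), 0.0],
                         [-np.sin(a), np.cos(a), 0.0],
                         [0.0, 0.0, 1.0]])
    scaling = np.diag([scale_xy[0], scale_xy[1], 1.0])
    shearing = np.array([[1.0, shear_xy[0], 0.0],
                         [shear_xy[1], 1.0, 0.0],
                         [0.0, 0.0, 1.0]])
    # Same order as GenerateAffine: shearing @ (rotation @ scaling)
    theta = shearing @ rotation @ scaling
    theta_inv = np.linalg.inv(theta)
    # affine_grid only needs the top two rows of each homogeneous matrix
    return theta[:2, :], theta_inv[:2, :]


theta, theta_inv = compose_affine(angle_deg=5.0, scale_xy=(1.1, 0.95), shear_xy=(0.01, -0.01))
print(theta.shape)  # (2, 3)
```

The inverse is returned because it undoes the warp, e.g. to map a prediction made on the augmented image back to the original frame.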




In [None]:
# Generate random affine transformations
# This function generates one affine transformation matrix (and its inverse) per image in the batch

def GenerateAffine(inputs, degreeFreedom=5, scale=[0.9, 1.2], shearingScale=[0.01, 0.01], Ngpu=0):
    degree = torch.FloatTensor(inputs.size(0)).uniform_(-degreeFreedom, degreeFreedom) * 3.1416 / 180;
    Theta_rotations = torch.zeros(inputs.size(0), 3, 3)

    Theta_rotations[:, 0, 0] = torch.cos(degree);
    Theta_rotations[:, 0, 1] = torch.sin(degree);
    Theta_rotations[:, 1, 0] = -torch.sin(degree);
    Theta_rotations[:, 1, 1] = torch.cos(degree);
    Theta_rotations[:, 2, 2] = 1

    degree = torch.FloatTensor(inputs.size(0), 2).uniform_(scale[0], scale[1])

    Theta_scale = torch.zeros(inputs.size(0), 3, 3)

    Theta_scale[:, 0, 0] = degree[:, 0]
    Theta_scale[:, 0, 1] = 0
    Theta_scale[:, 1, 0] = 0
    Theta_scale[:, 1, 1] = degree[:, 1]
    Theta_scale[:, 2, 2] = 1

    degree = torch.cat((torch.FloatTensor(inputs.size(0), 1).uniform_(-shearingScale[0], shearingScale[0]),
                        torch.FloatTensor(inputs.size(0), 1).uniform_(-shearingScale[1], shearingScale[1])), 1)

    Theta_shearing = torch.zeros(inputs.size(0), 3, 3)

    Theta_shearing[:, 0, 0] = 1
    Theta_shearing[:, 0, 1] = degree[:, 0]
    Theta_shearing[:, 1, 0] = degree[:, 1]
    Theta_shearing[:, 1, 1] = 1
    Theta_shearing[:, 2, 2] = 1

    Theta = torch.matmul(Theta_rotations, Theta_scale)
    Theta = torch.matmul(Theta_shearing, Theta)

    Theta_inv = torch.inverse(Theta)

    Theta = Theta[:, 0:2, :]
    Theta_inv = Theta_inv[:, 0:2, :]

    return Theta, Theta_inv


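# Warp a batch of shape (N, C, H, W) with a batch of 2x3 affine matrices `theta`
# (mode='nearest' can be used for label maps; 'bilinear' is the default for images)

def apply_trasform(inputs, theta, mode='bilinear'):
    # align_corners=False is the PyTorch default; passing it explicitly silences the warning
    grid = F.affine_grid(theta, inputs.size(), align_corners=False)
    # padding_mode="border" repeats edge values instead of padding with zeros
    outputs = F.grid_sample(inputs, grid, mode=mode, padding_mode="border",
                            align_corners=False)
    return outputs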



<!-- ## Let's perform Aumentation-based Domain adaptation between clinic1 and clinic2 

We are going to use affine transformations to perturb the images -->

Now implement the consistency loss, which is added to the supervised source loss. Note that you will have to apply the transformations as in eq. $(7)$.


In [None]:
# The model was fine-tuned with target labels in Part 2, so we reload the
# pre-trained (clinic 1 only) weights for a fair comparison.
model.load_state_dict(torch.load('clinic1B2assource2load.pt', map_location='cpu'))
optimizer = optim.Adam(model.parameters(), lr=1e-4 )

# Set parameters here:
number_of_epochs = 2
alpha =0.5
beta  =0.5
model.train()
for epoch in range(number_of_epochs):
  
  with tqdm(total=len(source_training_loader), file=sys.stdout) as pbar:

    start_time = time.time()
    running_loss = 0
    indb =0
    for patch_s,patch_t in zip(source_training_loader,target_training_loader):

        # Labeled source batch (clinic 1)
        source_batch_images = patch_s['flair']['data'][...,0].float()
        source_batch_labels = patch_s['label']['data'][...,0].float()
        # Unlabeled target batch (clinic 2): only the images are used
        target_batch_images = patch_t['flair']['data'][...,0].float()
        
        # Get the predictions for the source batch and compute the supervised loss, eq (5)
        outputs_source, _, _, _, _, _, _, _, _, _ = model(source_batch_images)
        supervised_source_loss = dice_soft_loss(torch.sigmoid(outputs_source), source_batch_labels)
        

        #EXERCISE: adaptation starts here

        # Predictions for the original target images: f(target_images)
        outputs_target, _, _, _, _, _, _, _, _, _ = model(target_batch_images)
        # Random affine matrices (and inverses), one per image in the batch
        Theta, Theta_inv = GenerateAffine(target_batch_images)
    
        # Perturbed version of the target batch: phi(target_images)
        target_batch_images_aug = apply_trasform(target_batch_images, Theta) # YOUR CODE HERE 
        

        # Predictions for the perturbed batch: f(phi(target_images))
        outputs_target_aug,_, _, _, _, _, _, _, _, _ = model(target_batch_images_aug)
        
        # Apply the same transformation to the original predictions: phi(f(target_images))
        outputs_target_transformed = apply_trasform(outputs_target, Theta) # YOUR CODE HERE

        
        # Consistency loss as in eq (7)
        consistency_loss = dice_soft_loss(
            torch.sigmoid(outputs_target_aug),torch.sigmoid(outputs_target_transformed)) # YOUR CODE HERE 
        
        total_loss = alpha * supervised_source_loss + beta * consistency_loss

        model.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()
        indb += 1
        pbar.update(1)

    end_time = time.time()

    print('Training: [epoch %d, loss %.3f] time: %.3f min' % (epoch + 1, running_loss / indb, (end_time-start_time) / 60 ))





As before, we run inference on the validation sets to evaluate performance.

In [None]:

dice_source_auda=inference(source_validation_loader)
dice_target_auda=inference(target_validation_loader)




fig1, ax1 = plt.subplots()
ax1.set_title('No adaptation vs Supervised Adaptation vs Unsupervised adaptation')
ax1.boxplot([dice_target_noadaptation, dice_target_supervised_adaptation,dice_target_auda]);

ax1.set_xticklabels(['No adaptation', 'Supervised Adaptation','ADA'])
ax1.get_xaxis().tick_bottom()
ax1.set_ylabel('Dice')


print('No adaptation dice performance {}' .format(np.mean(dice_target_noadaptation)))
print('Supervised adaptation dice performance {}' .format(np.mean(dice_target_supervised_adaptation)))
print('Augmentation-based unsupervised adaptation dice performance {}' .format(np.mean(dice_target_auda)))






As we can see, the performance on the target domain has increased: adaptation using only unlabeled target data has been successful! Next, we look at another popular method for unsupervised domain adaptation.


# PART 4: Adversarial Domain Adaptation

One popular solution for unsupervised domain adaptation is [adversarial learning](https://arxiv.org/abs/1612.08894).

The overall idea of adversarial domain adaptation is to learn feature representations that are agnostic to the data domain.

This is achieved by training an adversarial network (a domain discriminator) that tries to classify the domain of input data coming from both domains, while the segmenter learns features that make this classification difficult.

<!-- The accuracy of a binary classifier that distinguishes between samples from two domains can serve as a proxy of the divergence of distributions p(Xs) and P(Xt) which otherwise is not straightforward to compute.  THis idea sas first introduce in....Insiperd by this , the authors of presented a method for simultaneously learning a domain-invariant representation and a task-related by a single network. this is done by minimizing the accuracy of an auxiliary network , a domain-discriminator that processes a hidden representation of the main network and tries to classify the domain of the input sample. -->


<br>
<div>
<center>
<img src="https://raw.githubusercontent.com/MauricioOrbes/AML_lecture_5/master/images/adversarial.png" width="800"/>
</center>
</div>

> __Figure__: Adversarial domain adaptation scheme. The accuracy of a binary classifier that distinguishes source from target samples can be used to measure the difference between the two distributions, which may not be straightforward to compute by other means. A single network simultaneously learns a domain-invariant representation $h_\theta( \cdot)$ and the main task (segmentation). This is done by training it to minimize the accuracy of an auxiliary network, the domain discriminator $d_\theta(\cdot)$, which processes the hidden representation $h_\theta( \cdot)$ of the main network and tries to classify the domain of the input sample.

The optimization loss used to carry out the adaptation is the addition between a segmentation loss and the adversarial loss as follows

* $Total\_loss = Segmentation\_loss + \alpha * Adversarial\_loss \ \ \ (8)$

Where the segmentation loss is the same as in equations (1) and (5). 

* $Segmentation\_loss =   soft\_dice(f(source\_images),source\_labels) \ \ \ \ \ \ \ \ \ \ \  \ \ \ \ \ \ \ \ \ \ \ \ \ \  \ \ \ (9)$

For the adversarial loss, we use cross-entropy (**CE**) as the cost function, since it is the standard choice for classification tasks. Note that the output of the discriminator is compared against the domain labels $[0,1]$, as we have assigned label $0$ to the source domain and label $1$ to the target domain: every source slice in the batch is labelled $0$ and every target slice $1$.

* $Adversarial\_loss = CE(d(h(source\_images, target\_images)), [0,1]) \qquad (10)$
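To make eq. (10) concrete, the sketch below spells out the mean cross-entropy of two-class discriminator logits against the domain labels in plain Python, without PyTorch. It computes the same quantity as `nn.CrossEntropyLoss(reduction='mean')` for unweighted classes; the logit values are made up purely for illustration. A discriminator that separates the domains well gets a loss close to 0, while one at chance level, which cannot tell source from target, scores $\ln 2 \approx 0.693$.

```python
import math


def cross_entropy(logits, labels):
    """Mean cross-entropy of 2-class logits against integer domain labels
    (0 = source, 1 = target)."""
    total = 0.0
    for row, y in zip(logits, labels):
        # log-sum-exp trick for a numerically stable log-softmax
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[y]
    return total / len(labels)


# Two source slices followed by two target slices (illustrative values)
labels = [0, 0, 1, 1]

confident = [[4.0, -4.0], [3.0, -3.0], [-3.0, 3.0], [-4.0, 4.0]]  # domains easy to separate
confused = [[0.0, 0.0]] * 4                                       # discriminator at chance level

print(cross_entropy(confident, labels))  # close to 0
print(cross_entropy(confused, labels))   # ln 2
```

Values well below $\ln 2$ indicate that the features $h(\cdot)$ still carry enough information to tell the domains apart.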

## Training schedule:
A complication of adversarial training is how strongly the segmenter adapts its features to counter the discriminator, which is controlled by the parameter $\alpha$ in eq. $(8)$. Setting $\alpha=0$ lets both networks learn independently. This allows the segmenter to first learn features for segmenting the source domain without being affected by noisy adversarial gradients from a domain discriminator that initially performs poorly. As proposed in [Kamnitsas et al. 2017](https://), a good training schedule follows these steps:

1. Set $\alpha = 0$ during the first $e_1$ epochs.

2. After $e_1$, increase $\alpha$ according to a linear schedule ($e_{curr}$ = current epoch):

   $$\alpha = \alpha_{max}\frac{e_{curr}-e_1}{e_2-e_1} \qquad (11)$$

   where $\alpha_{max}$ is the maximum weight.

3. Set $\alpha = \alpha_{max}$ after epoch $e_2$.
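The three steps above can be sketched as a small function of the (0-indexed) epoch. `alpha_schedule` is a hypothetical helper name, not part of the notebook; the example values are the ones used in the exercise cell below.

```python
def alpha_schedule(epoch, alpha_max, e1, e2):
    """Adversarial weight alpha for a given 0-indexed epoch (eq. 11):
    0 for the first e1 epochs, a linear ramp from e1 to e2, then alpha_max."""
    if epoch < e1:
        return 0.0
    if epoch < e2:
        return alpha_max * (epoch - e1) / (e2 - e1)
    return alpha_max


# Example with alpha_max=0.3, e1=2, e2=5
for epoch in range(8):
    print(epoch, round(alpha_schedule(epoch, 0.3, 2, 5), 3))
```

With these values $\alpha$ stays at 0 for epochs 0 to 2, ramps through 0.1 and 0.2, and is held at 0.3 from epoch 5 onwards. Using `epoch < e2` in the middle branch also means that choosing `e1 == e2` gives a hard switch instead of a division by zero.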


## Exercise: Schedule adversarial domain adaptation training

The code below performs adversarial domain adaptation. However, as currently implemented, both networks train simultaneously from the first epoch. Implement the steps above to obtain a proper training schedule. Try different values for $e_1$, $e_2$, and $\alpha_{max}$ and observe the impact on performance.
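`DiscriminatorDomain` in the code below flattens its last feature map with a hard-coded 7 × 7 spatial size (`fc1` expects `64 * 7 * 7 * complexity` inputs). Because all decoder features are resized to the size of `dec2` before being concatenated, the spatial size of `dec2` must be compatible with that assumption. The hypothetical helper below, which is not part of the original notebook, applies the `nn.Conv2d` output-size formula for the discriminator's four 3 × 3, stride-2, unpadded convolutions. Assuming square feature maps, only inputs between 127 and 142 pixels per side end at 7 × 7; for other sizes `x.view(...)` either fails or produces a batch dimension that no longer matches the domain labels.

```python
def conv_out(n, kernel_size=3, stride=2, padding=0, dilation=1):
    # Output length of nn.Conv2d along one spatial axis
    return (n + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def spatial_after_convs(n, num_convs=4):
    # DiscriminatorDomain applies four 3x3, stride-2 convolutions without padding
    for _ in range(num_convs):
        n = conv_out(n)
    return n


# Square input sizes that reach the 7x7 map assumed by fc1 and x.view(...)
compatible = [n for n in range(1, 300) if spatial_after_convs(n) == 7]
print(compatible[0], compatible[-1])
```

If your patches produce a different `dec2` size, the 7s in `fc1` and `x.view(...)` have to be adjusted accordingly.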










In [None]:


import os
import random
import sys
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from tqdm import tqdm

# Adversarial Domain Adaptation

# let's define the Discriminator model

class DiscriminatorDomain(nn.Module):
    def __init__(self, num_channels,num_classes,complexity):
        super(DiscriminatorDomain, self).__init__()
        self.conv1 = nn.Conv2d(num_channels, int(8*complexity), kernel_size=3, stride=2)
        self.BN1   = nn.BatchNorm2d(int(8*complexity))
        self.conv2 = nn.Conv2d(int(8*complexity), int(16 * complexity) , kernel_size=3, stride=2)
        self.BN2   = nn.BatchNorm2d(int(16* complexity))
        self.conv3 = nn.Conv2d(int(16* complexity), int(32* complexity),  kernel_size=3, stride=2 )
        self.BN3   = nn.BatchNorm2d(int(32* complexity))
        self.conv4 = nn.Conv2d(int(32* complexity),int(64* complexity),kernel_size=3, stride=2)
        self.BN4   = nn.BatchNorm2d(int(64* complexity))
        
        self.fc1 = nn.Linear(int(64 * 7 * 7 * complexity), int(128* complexity))
        self.drop_1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(int(128* complexity), int(64* complexity))
        self.drop_2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(int(64* complexity), num_classes)
        
    def forward(self, x):
        
        x = F.leaky_relu(self.BN1(self.conv1(x)),0.2)
        x = F.leaky_relu(self.BN2(self.conv2(x)),0.2)
        x = F.leaky_relu(self.BN3(self.conv3(x)),0.2)

        x = F.leaky_relu(self.BN4(self.conv4(x)),0.2)
        complexity =x.size(1)
        x = x.view(-1, int( 7 * 7 * complexity))

        x = F.relu(self.drop_1(self.fc1(x)))
        x = F.relu(self.drop_2(self.fc2(x)))
        x = self.fc3(x)
        
        return x


Ninitialfilters = 4
complexity = Ninitialfilters / 8 

# Set the discriminator model: its input is the 352-channel concatenation of the
# four decoder feature maps h(x), and it predicts 2 classes (0: source, 1: target).
# It is placed on the same device as the segmentation model.
device = next(model.parameters()).device
discriminator = DiscriminatorDomain(352, 2, complexity).to(device)

# Let's load the pretrained model again
model.load_state_dict(torch.load('clinic1B2assource2load.pt', map_location=device))

number_of_epochs = 10

# EXERCISE: initialize the parameters alpha_max, e1, and e2

alpha_max = 0.3  # YOUR CODE HERE
e1 = 2  # YOUR CODE HERE
e2 = 5  # YOUR CODE HERE

# We need one optimizer per network
optimizer_model = optim.Adam(model.parameters(), lr=1e-4)
optimizer_discriminator = optim.Adam(discriminator.parameters(), lr=1e-4)

# Loss function used to train the discriminator (mean over the batch)
lf_discriminator = nn.CrossEntropyLoss(reduction='mean')


def discriminator_input(model, images):
    # h(x): the last four decoder representations of the U-Net, resized to the
    # spatial size of dec2 and concatenated along the channel axis
    _, _, _, _, _, _, dec4, dec3, dec2, dec1 = model(images)
    size = dec2.size()[2:]
    features = [F.interpolate(d, size=size, mode='bilinear') for d in (dec1, dec2, dec3, dec4)]
    return torch.cat(features, 1)


for epoch in range(number_of_epochs):

    # The schedule is governed by alpha as a function of the epoch
    # EXERCISE: fill in each of the following conditions to schedule the training.
    if epoch < e1:
        alpha = 0  # YOUR CODE HERE
    elif epoch < e2:
        alpha = alpha_max * (epoch - e1) / (e2 - e1)  # YOUR CODE HERE
    else:
        alpha = alpha_max  # YOUR CODE HERE

    with tqdm(total=len(source_training_loader), file=sys.stdout) as pbar:

        start_time = time.time()
        running_loss = 0
        indb = 0
        for patch_s, patch_t in zip(source_training_loader, target_training_loader):

            # Source slices come with labels; target slices are used without labels
            source_batch_images = patch_s['flair']['data'][..., 0].float().to(device)
            source_batch_labels = patch_s['label']['data'][..., 0].float().to(device)
            target_batch_images = patch_t['flair']['data'][..., 0].float().to(device)

            # The discriminator sees both source and target slices and has to
            # classify their domain (0: source domain, 1: target domain)
            inputs_model_adv = torch.cat((source_batch_images, target_batch_images), 0)
            labels_discriminator = torch.cat((torch.zeros(source_batch_images.size(0)),
                                              torch.ones(target_batch_images.size(0)))).long().to(device)

            # STEP 1: TRAIN THE DISCRIMINATOR
            # The segmenter is frozen here, so h(x) is computed without gradients
            model.eval()
            discriminator.train()
            with torch.no_grad():
                inputs_discriminator = discriminator_input(model, inputs_model_adv)
            outputs_discriminator = discriminator(inputs_discriminator)
            loss_classifier = lf_discriminator(outputs_discriminator, labels_discriminator)

            optimizer_discriminator.zero_grad()
            loss_classifier.backward()
            optimizer_discriminator.step()

            # STEP 2: TRAIN THE SEGMENTER
            # Ideally a different batch should be used to feed the network; for
            # simplicity we reuse the same one (in previous experiments this did
            # not seem to affect performance much).
            model.train()
            discriminator.eval()

            outputs_source = model(source_batch_images)[0]
            outputs_discriminator = discriminator(discriminator_input(model, inputs_model_adv))

            supervised_source_loss = dice_soft_loss(torch.sigmoid(outputs_source), source_batch_labels)
            loss_adv = lf_discriminator(outputs_discriminator, labels_discriminator)

            # The adversarial term is subtracted: the segmenter minimizes the
            # segmentation loss while maximizing the discriminator's loss
            total_loss = supervised_source_loss - alpha * loss_adv

            optimizer_model.zero_grad()
            total_loss.backward()
            optimizer_model.step()

            running_loss += total_loss.item()
            indb += 1
            pbar.update(1)

        end_time = time.time()

        print('Training: [epoch %d, loss %.3f] time:%.3f min' % (epoch + 1, running_loss / indb, (end_time - start_time) / 60))



Let's do inference and compare with the previous methods. 

In [None]:


dice_source_adversarial=inference(source_validation_loader)
dice_target_adversarial=inference(target_validation_loader)

fig1, ax1 = plt.subplots()
ax1.set_title('No adaptation vs Supervised Adaptation vs ADA vs Adversarial')
ax1.boxplot([dice_target_noadaptation, dice_target_supervised_adaptation,dice_target_auda,dice_target_adversarial]);

ax1.set_xticklabels(['No adaptation', 'Supervised Adaptation','ADA', 'Adversarial'])
ax1.get_xaxis().tick_bottom()
ax1.set_ylabel('Dice')


print('No adaptation dice performance {}'.format(np.mean(dice_target_noadaptation)))
print('Supervised adaptation dice performance {}'.format(np.mean(dice_target_supervised_adaptation)))
print('Augmentation-based unsupervised adaptation dice performance {}'.format(np.mean(dice_target_auda)))
print('Adversarial adaptation dice performance {}'.format(np.mean(dice_target_adversarial)))
