<a href="https://colab.research.google.com/github/DerManjuel/MDL/blob/main/MDL_Exercise3_registration.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Medical Deep Learning


## Exercise 3: Learning-based multi-modal 3D registration


The aim of this exercise is to introduce you to deep-learning based image registration and also explore mutual information as a metric to supervise the training of multi-modal feature networks. The method we want to implement comprises three parts: 

1.   A global mutual information loss function
2.   A correlation layer to robustly estimate large rigid transformations
3.   A compact 3D CNN network (with some modality specific and some shared layers) to predict features suitable for multi-modal CT/MR registration

**Provided functions and data loading**

![thorax](https://drive.google.com/uc?export=view&id=1b8Y6YRMl-YTe6tdSysY2C1dKA9Vvqh6y)

We have prepared a dataset with CT and MRI scan pairs of same patients from TCIA (the Cancer Imaging Archive) and also provide you with manual annotations to evaluate the method (those are not necessary for training).
The image dimensions are $192\times160\times192$, there are 8 (pre-aligned) scan pairs that will be augmented to 64 pairs using random rigid transformations and each scan has 4 anatomical labels: liver, spleen, right kidney and left kidney (some patients have only one kidney).

The following cells provide the fundamental for loading and augmenting the data. Also, you are given a number of functions for the subsequent registration tasks. You are invited to read and understand the code if you are interested.

In [1]:
!pip install torchinfo
import torchinfo
# Download train and test data
!wget -nc https://cloud.imi.uni-luebeck.de/s/76KJ7RBqpsdjSbw/download -O mdl3_masks.npz
!wget -nc https://cloud.imi.uni-luebeck.de/s/yAZNkTBRGoeePZa/download -O mdl3_imgs.pth

# Download feature data for testing task 1 & 2
!wget -nc https://cloud.imi.uni-luebeck.de/s/nSixsneJ6fDbfBB/download -O mdl3_exercise_task12.pth

# Download and import an additional python file providing utility functions
!wget -nc https://cloud.imi.uni-luebeck.de/s/X8A8Dixgj62Qwtf/download -O mdl_exercise3_utils.py

from mdl_exercise3_utils import *

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.8.0
--2023-05-31 11:31:55--  https://cloud.imi.uni-luebeck.de/s/76KJ7RBqpsdjSbw/download
Resolving cloud.imi.uni-luebeck.de (cloud.imi.uni-luebeck.de)... 141.83.20.118
Connecting to cloud.imi.uni-luebeck.de (cloud.imi.uni-luebeck.de)|141.83.20.118|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1638735 (1.6M) [application/octet-stream]
Saving to: ‘mdl3_masks.npz’


2023-05-31 11:31:56 (2.43 MB/s) - ‘mdl3_masks.npz’ saved [1638735/1638735]

--2023-05-31 11:31:56--  https://cloud.imi.uni-luebeck.de/s/yAZNkTBRGoeePZa/download
Resolving cloud.imi.uni-luebeck.de (cloud.imi.uni-luebeck.de)... 141.83.20.118
Connecting to cloud.imi.uni-luebeck.de (cloud.imi.uni-luebeck.de)|141.83.20.118|:443... connected.
HTTP request sen

In [2]:
# some parameters

grid_step = 12
disp_radius = 4
disp_step = 5
beta = 25

W = D = 192
H = 160

In [3]:
def least_trimmed_rigid(fixed_pts, moving_pts, iter=5):
    idx = torch.arange(fixed_pts.shape[0]).to(fixed_pts.device)
    for i in range(iter):
        x = find_rigid_3d(fixed_pts[idx,:], moving_pts[idx,:]).t()
        residual = torch.sqrt(torch.sum(torch.pow(moving_pts - torch.mm(fixed_pts, x), 2), 1))
        _, idx = torch.topk(residual, fixed_pts.shape[0]//2, largest=False)
    return x.t().to(fixed_pts.dtype)

def find_rigid_3d(x, y):
    x_mean = x[:, :3].mean(0)
    y_mean = y[:, :3].mean(0)
    u, s, v = torch.svd(torch.matmul((x[:, :3]-x_mean).t(), (y[:, :3]-y_mean)).float())
    m = torch.eye(v.shape[0], v.shape[0]).to(x.device)
    m[-1,-1] = torch.det(torch.matmul(v, u.t()).float())
    rotation = torch.matmul(torch.matmul(v, m), u.t())
    translation = y_mean - torch.matmul(rotation, x_mean)
    T = torch.eye(4).to(x.device)
    T[:3,:3] = rotation
    T[:3, 3] = translation
    return T

def generate_random_rigid_3d(strength=.3):
    x = torch.randn(12,3).to(device)
    y = x + strength*torch.randn(12,3).to(device)
    return find_rigid_3d(x, y)

disp = torch.stack(torch.meshgrid(
    torch.arange(- disp_step * disp_radius, disp_step * disp_radius + 1, disp_step),
    torch.arange(- disp_step * disp_radius , disp_step * disp_radius  + 1, disp_step),
    torch.arange(- disp_step * disp_radius , disp_step * disp_radius  + 1, disp_step))).permute(1, 2, 3, 0).contiguous().view(1, 1, -1, 1, 3).float()

disp = (disp.flip(-1) * 2 / (torch.tensor([W, H, D]) - 1))#.to(dtype).to(device)

def get_displacement():
    return disp
    
disp_width = disp_radius * 2 + 1

#finding 50% best matches and computing soft-correspondences is provided as "robust_rigid_fit"
#in initial notebook, it receives the ssd_cost tensor N x 729 from Task 2 and returns a 4x4 matrix R

def robust_rigid_fit(ssd_cost, kpts_fixed, feat_fix):
    ssd_cost = ssd_cost.view(1,-1,(disp_radius*2+1)**3)
    kpts_fixed = kpts_fixed.view(1,-1,3)
    #mask_fix, feat_fix, feat_mov):#, grid_step, disp_radius, disp_step, beta=15):
    #use predefined set of displacements
    disp1 = disp.to(ssd_cost.device).to(ssd_cost.dtype)
    
    #remove 50% least reliable control points based on the minimum cost of their respective values
    ssd_val, ssd_idx = torch.min(ssd_cost.squeeze(), 1)
    idx_best = torch.sort(ssd_val, dim=0, descending=False)[1][:kpts_fixed.shape[1]//2]
    #compute a weighted soft correspondence (displacement)
    #this step is crucial to keep the loss differentiable!
    disp_best = torch.sum(torch.softmax(-beta*ssd_cost.squeeze(0).unsqueeze(2),1) * disp1.view(1, -1, 3), 1)
    disp_best = disp_best[idx_best,:]
    
    #compute absolute coordinates for coresspondences and run least trimmed squares fitting
    fixed_pts = torch.cat((kpts_fixed[0,idx_best,:], torch.ones(idx_best.size(0),1).to(feat_fix.device).to(feat_fix.dtype)),1)
    moving_pts = torch.cat((kpts_fixed[0,idx_best,:] + disp_best, torch.ones(idx_best.size(0),1).to(feat_fix.device).to(feat_fix.dtype)),1)
    with torch.cuda.amp.autocast(enabled=False): #SVD is not available/stable with FP16
        R = least_trimmed_rigid(fixed_pts.float(), moving_pts.float())
    return R[:3,:4].unsqueeze(0)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [4]:

# Run all required imports

import os
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
import numpy as np
from tqdm.notebook import tqdm

In [5]:
#load data and create eight augmented versions each (64 CT/MR pairs in total)

mdl3_imgs = torch.load('mdl3_imgs.pth')
mdl3_masks = np.load('mdl3_masks.npz')

def load_case(case):
    img_fix = mdl3_imgs['mdl3_img_fix'][case].float().cpu()
    img_mov = mdl3_imgs['mdl3_img_mov'][case].float().cpu()
    mask_fix = torch.from_numpy(mdl3_masks['mdl3_mask_fix'][case]).cpu()
    mask_mov = torch.from_numpy(mdl3_masks['mdl3_mask_mov'][case]).cpu()
    seg_fix = torch.from_numpy(mdl3_masks['mdl3_seg_fix'][case]).cpu().long()
    seg_mov = torch.from_numpy(mdl3_masks['mdl3_seg_mov'][case]).cpu().long()
    return img_fix, img_mov, seg_fix, seg_mov, mask_fix, mask_mov


TRAIN_CASES = torch.arange(8) 

imgs_fix_train = torch.zeros(len(TRAIN_CASES), 8, D, H, W).float()
imgs_mov_train = torch.zeros(len(TRAIN_CASES), 1, D, H, W).float()
segs_fix_train = torch.zeros(len(TRAIN_CASES), 8, D, H, W).int()
segs_mov_train = torch.zeros(len(TRAIN_CASES), 1, D, H, W).int()
masks_fix_train = torch.zeros(len(TRAIN_CASES), 8, D, H, W).bool()
masks_mov_train = torch.zeros(len(TRAIN_CASES), 1, D, H, W).bool()
for i, case in enumerate(tqdm(TRAIN_CASES, desc='load cases')):
    img_fix, img_mov, seg_fix, seg_mov, mask_fix, mask_mov = load_case(case)
    device = img_fix.device
    img_fix = img_fix.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
    img_mov = img_mov.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
    seg_fix = seg_fix.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
    seg_mov = seg_mov.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
    mask_fix = mask_fix.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
    mask_mov = mask_mov.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
    
    imgs_mov_train[i:i+1] = img_mov
    segs_mov_train[i:i+1] = seg_mov
    masks_mov_train[i:i+1] = mask_mov
    for j in range(8):
        with torch.no_grad():
            R = generate_random_rigid_3d()
            grid = F.affine_grid(R[:3,:4].unsqueeze(0).cuda(), (1,1,D,H,W))
            img_fix_ = F.grid_sample(img_fix.cuda(), grid, padding_mode='border')
            seg_fix_ = F.grid_sample(F.one_hot(seg_fix[0, 0]).permute(3, 0, 1, 2).unsqueeze(0).float().cuda(), grid).argmax(1, keepdim=True).int()
            mask_fix_ = F.grid_sample(mask_fix.float().cuda(), grid)>0.5

            imgs_fix_train[i:i+1, j:j+1] = img_fix_.cpu()
            segs_fix_train[i:i+1, j:j+1] = seg_fix_.cpu()
            masks_fix_train[i:i+1, j:j+1] = mask_fix_.cpu()

load cases:   0%|          | 0/8 [00:00<?, ?it/s]



### Task 1: Computation of (joint) histograms and mutual information (25 points)

Let's start with implementing mutual information as a loss function, that we can use later for supervised training. All input tensors have the shape $[N\times C\times D\times H\times W]$.

* The values $v_i$ should be sampled of the (fixed) mask to exclude background locations.

* Define a range for the histogram bins $b_i$ with `torch.linspace` with 64 steps and the minimum and maximum value of the respective image.

* Compute the histograms $h_i$ using a Parzen window weighting with $\sigma=0.015$: $$h_i=\exp\left(-\frac{(v_i-b_i)^2}{2 \cdot \sigma^2}\right)$$ (Broadcasting becomes again very handy here).

* After this step your $h_{\text{fix}}$ and $h_{\text{mov}}$ should have a shape of  $[64 \times N]$, where $N$ is the number of pixels (could be obtained with the sum over `mask_fix`).

* Calculate the marginal (individual) $\rho_i$ by summing/averaging over the pixels and normalizing the resulting vector by its sum (add a small $\epsilon=1e-10$ for numeric stability). For the joint histogram the pairwise sums are implicitly obtained using a matrix multiplication of $h_{\text{fix}}$ and transposed $h_{\text{mov}}$. Do not forget to calculate its marginal too.

* Calculate the entropy $E_i$ as $$E_i=-\sum \rho_i \cdot \log_2(\rho_i + \epsilon)$$

* Return the mutual information loss `mi`$= -(E_\text{fix} + E_\text{mov} - E_\text{joint})$

* Test your implementation with the testcase below. It should yield `tensor(-0.0271)`.

**Helpful functions:** `torch.linspace, torch.max, torch.min, torch.exp, torch.pow, torch.sum`

In [93]:
# all tensor are of shape [1, 1, 192, 160, 192].
# We convert mask_fix to boolean, so you can use it directly for indexing.
def mutual_inf(mask_fix, img_fix, img_mov):
    sigma = 0.015
    epsilon = 1e-10
    mask_fix = mask_fix.to(torch.bool)
    
    # TODO: draw samples from mask_fix
    v_fix = img_fix[mask_fix]
    v_mov = img_mov[mask_fix]

    # TODO: define bins
    with torch.no_grad():
        b_fix = torch.linspace(start=torch.min(v_fix), end=torch.max(v_fix), steps=64).cuda()
        b_mov = torch.linspace(start=torch.min(v_mov), end=torch.max(v_mov), steps=64).cuda()

    # TODO: estimate histograms
    h_fix = torch.exp(-((torch.pow(v_fix.reshape(1,-1) - b_fix.view(-1,1), 2) / (2 * pow(sigma,2)))))
    h_mov = torch.exp(-((torch.pow(v_mov.reshape(1,-1) - b_mov.view(-1,1), 2) / (2 * pow(sigma,2)))))
    #print(h_fix.shape)
    #print(torch.sum(mask_fix))
    
    # TODO: estimate marginal
    p_fix = torch.mean(h_fix, dim=1) / (torch.sum(torch.mean(h_fix,dim=1)) + epsilon)
    p_mov = torch.mean(h_mov, dim =1) / (torch.sum(torch.mean(h_mov,dim=1)) + epsilon)
    s = torch.matmul(h_fix, torch.transpose(h_mov,1,0))
    p_joint = s / (torch.sum(s) + epsilon)
    

    # TODO: estimate entropies
    E_fix = - torch.sum(p_fix * torch.log2(p_fix + epsilon))
    E_mov = - torch.sum(p_mov * torch.log2(p_mov + epsilon))
    E_joint = - torch.sum(p_joint * torch.log2(p_joint + epsilon))
    mi = -(E_fix + E_mov - E_joint)
    return mi


# testing
task12 = torch.load('mdl3_exercise_task12.pth')
mutual_inf(mask_fix.cuda(), task12['img_fix'].float().cuda(), img_mov.cuda())

tensor(-0.1047, device='cuda:0')

### Task 2: Correlation layer (35 points)
In this task you should compute similarity scores (sum of squared differences) for a set of grid points and a wide range of potential displacements. Your function takes fixed and moving features as well as a mask for the fixed scan as input. The predefined function `get_displacement()` will return a tensor that specifies $9\times9\times9=729$ 3D displacements with a maximal range of 20 voxels in each direction.

* Define a grid of control points `kpts_fixed` using `F.affine_grid` with `align_corners=True`, a spacing of 12 voxels and range of .925 (to exclude points near the image boundaries). Move it to the same device as `mask_fix`.

* Now use `kpts_fixed` with `f.grid_sample` on `mask_fix` to create a down sampled version of it. Use again `align_corners=True` and think of the correct interpolation mode. Cast the result back to boolean.

* Next, we want to exclude all key points in `kpts_fixed` that are pointing on the image background. Therefor, view it as $[1, N, 3]$ and index it with `mask_fix_downsampled` also viewed as a 1D tensor.

* To store the similarity scores, we create an emtpy tensor `ssd` of shape $[N\times 9^3]$ on the same device as `feat_fixed`.

* Use 32 chunks for unrolling and compute the (dis)similarity as follows:
    * get the current subset of grid points `subsampled_kpts_fixed` by indexing `kpts_fixed`with `idx` and view it so all key points are in $D_\text{out}$ (have a look into `F.grid_sample` for the detailed shape).
    * sample the fixed features `feat_patch_fixed` at (the current subset of) grid points and the moving feature tensor `feat_displacements` at the combined (added) coordinates of absolute grid points and relative displacements.
   + Square and sum over the channel dimension (C) and save it to the indexed `ssd` tensor.    

The final registration is done with the provided function `robust_rigid_fit` (you do not need to call it here), which searches for the most probable correspondences, by filtering out potentially erroneous ones based on their similarity score and the residual of a globally rigid least-square fit.

**Helpful notes:**:
+ functions: `F.affine_grid, F.grid_sample, view`
+ chunk along with view and common indexing, because we will unroll the function call (computed with a loop of several chunks) to save memory.

In [69]:
# Task2 correlation layer
def correlation(mask_fix, feat_fixed, feat_moving):

    # parameters
    grid_step = 12 ### <---
    grid_range = 0.925 ### <---
    B, _, D, H, W = mask_fix.shape  # -> [1, 1, 192, 160, 192]

    # TODO: create grid
    theta = torch.eye(3, 4, device=mask_fix.device).unsqueeze(0) * grid_range
    kpts_fixed = F.affine_grid(theta, size=(B, 1, D//grid_step, H//grid_step, W//grid_step), align_corners=True)

    # TODO: sample from mask_fix
    mask_fix_downsampled = F.grid_sample(mask_fix.float(), kpts_fixed, align_corners=True, mode='bilinear').bool()
    #print('mask_fix_downsampled.shape', mask_fix_downsampled.shape)
    # TODO: exclude all invalid coordinates/grid points
    #print('kpts_fixed.shape', kpts_fixed.shape)
    kpts_fixed = kpts_fixed.view(1,-1,3)
    kpts_fixed = kpts_fixed[:,mask_fix_downsampled.view(-1),:]
    #print('kpts_fixed.shape', kpts_fixed.shape)

    # TODO: create empty tensor
    N = mask_fix_downsampled.sum()
    ssd = torch.zeros((N,9**3),device=feat_fixed.device)

    unroll_factor = 32
    displacements = get_displacement().to(feat_fixed.device).to(feat_fixed.dtype)
    for idx in tqdm(torch.tensor_split(torch.arange(B), unroll_factor), disable=True):

        # TODO: get subset of grid points
        #print('idx', idx)
        subsampled_kpts_fixed = kpts_fixed[:,idx,:].view(1,-1,1,1,3)
        #print('subsampled_kpts_fixed.shape', subsampled_kpts_fixed.shape) # 1,40,1,1,3

        # sample the fixed features feat_patch_fixed at (the current subset of) grid points and
        # the moving feature tensor feat_displacements at
        # the combined (added) coordinates of absolute grid points and relative displacements.
        # Square and sum over the channel dimension (C) and save it to the indexed ssd tensor.
        # TODO: sample from features
        feat_patch_fixed = F.grid_sample(feat_fixed, subsampled_kpts_fixed, align_corners=True, mode='bilinear')
        feat_displacements = F.grid_sample(feat_moving, subsampled_kpts_fixed + displacements, align_corners=True, mode='bilinear')
        
        # has to be fulfilled
        assert list(feat_patch_fixed.shape) == [B, 64, len(idx), 1, 1]
        assert list(feat_displacements.shape) == [B, 64, len(idx), 729, 1]
        
        # TODO: calculate similarity
        ssd[idx] = torch.sum((feat_patch_fixed - feat_displacements)**2).squeeze()
        #ssd[idx] = torch.sum(diff**2, dim=0)
        
    return ssd.unsqueeze(0), kpts_fixed

# testing
correlation(task12['mask_fix'],task12['feat_fix'].to(torch.float),task12['feat_mov'].to(torch.float))

(tensor([[[607.7994, 607.7994, 607.7994,  ..., 607.7994, 607.7994, 607.7994],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          ...,
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
          [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]]]),
 tensor([[[ 4.3167e-01, -6.1667e-01, -8.0167e-01],
          [-6.7833e-01, -4.6250e-01, -8.0167e-01],
          [ 4.3167e-01, -4.6250e-01, -8.0167e-01],
          ...,
          [ 4.3167e-01, -2.7567e-08,  9.2500e-01],
          [ 5.5500e-01, -2.7567e-08,  9.2500e-01],
          [ 3.0833e-01,  1.5417e-01,  9.2500e-01]]]))

**Optional**: The following code enables to verify your solutions for tasks 1 & 2 improving the mean dice score from ~0.49 to ~0.74:

In [94]:
#test for task1 and task2

task12 = torch.load('mdl3_exercise_task12.pth')

print(task12.keys())
with torch.no_grad():
    mask_fix = task12['mask_fix'].float().cuda()
    img_mov = task12['img_mov'].float().cuda()
    #call mutual information before transform
    loss0 = mutual_inf(mask_fix, task12['img_fix'].float().cuda(), img_mov)
    print('mi before',loss0)

    with torch.cuda.amp.autocast(enabled=True):
            #call your own implementation of correlation layer
            feat_fix = task12['feat_fix'].half().cuda()
            feat_mov = task12['feat_mov'].half().cuda()
            cost,kpts_fixed = correlation(mask_fix,feat_fix,feat_mov)
            
            #provided function for robust fitting
            R = robust_rigid_fit(cost.cuda(),kpts_fixed.cuda(),feat_fix)
            
    #mutual information requires 32bit precision
    grid = F.affine_grid(R, (1,1,D,H,W))
    img_warped = F.grid_sample(img_mov,grid.float(),mode='bilinear')
    loss1 = mutual_inf(mask_fix,task12['img_fix'].float().cuda(),img_warped)
    print('mi after',loss1)  
    seg_fix = task12['seg_fix'].float().cuda()
    seg_mov = task12['seg_mov'].float().cuda()

    seg_mov_warped = F.grid_sample(seg_mov.float(), grid, mode='nearest')

    d0 = dice_coeff(seg_fix.cpu(),seg_mov.cpu(),5) 
    print(d0)
    print('mean dice before: ',d0.mean().item())
    d1 = dice_coeff(seg_fix.cpu(),seg_mov_warped.cpu(),5)
    print(d1)
    print('mean dice after: ', d1.mean().item())
    print()


dict_keys(['feat_fix', 'feat_mov', 'seg_mov', 'mask_fix', 'R', 'seg_fix', 'img_fix', 'img_mov'])
mi before tensor(-0.1047, device='cuda:0')




mi after tensor(-0.1047, device='cuda:0')
tensor([0.6773, 0.4854, 0.5103, 0.2932])
mean dice before:  0.49155986309051514
tensor([0.6773, 0.4854, 0.5103, 0.2932])
mean dice after:  0.49155986309051514



### Task 3: Define network and train CNNs with mutual information (40 points)
To extract suitable features for registration both scans should be fed into a 3D CNN with trainable parameters. To  account for the differences between CT and MRI, but also encourage the learning of multi-modal relationships you should build your network based on three modules: one for CT, one for MRI (each with two blocks) and one that is shared for both (with three blocks).

#### CIR block
First, let's build a `CIR` function that implement two convolutional building block and simply return them as a `nn.Sequential`. A convolutional building block consists of
+ `nn.Conv3d` with `kernel_size=3` and padding mode 'same'.
+ `nn.InstanceNorm3d`
+ `nn.LeakyReLU`

The first convolution should get the option to have a `stride=2`.

In [95]:
# TODO
def CIR(in_channels:int , out_channels:int, stride=1):
    return nn.Sequential(
        nn.Conv3d(in_channels=in_channels,out_channels=out_channels, kernel_size=2,stride=2,padding=8),
        nn.InstanceNorm3d(out_channels),
        nn.LeakyReLU(),

        nn.Conv3d(in_channels=in_channels,out_channels=out_channels, kernel_size=3,stride=stride,padding=8),
        nn.InstanceNorm3d(out_channels),
        nn.LeakyReLU()
    )

#### ModalityNet
The modality specific modules should both receive a single-channel input and start with `base=16` feature maps that are doubled to 32 in the second `CIR` block. 

In [96]:
# TODO
class ModalityNet(nn.Module):
    def __init__(self, base):
        super(ModalityNet, self).__init__()
        self.base_channels = 16
        self.cir1 = CIR(self.base_channels,16)
        self.cir2 = CIR(16, 32)

    def forward(self, x):
        x = self.cir1(x)
        x = self.cir2(x)

        return x

#### SharedNet
The shared subnetwork should receive the 32-channel feature maps and comprise three blocks of the same pattern as above: doubling of channels and `stride=2` in second block. The last block keeps the channel dimension and has no stride. This yields 64-channel feature that will be mapped into a range of 0 to 1 using a `nn.Sigmoid`. These feature tensors (for fixed = MRI and moving = CT) will be the input of your correlation layer from Task 2.

In [97]:
# TODO
class SharedNet(nn.Module):
    def __init__(self, base, out_channels):
        super(SharedNet, self).__init__()

        self.cir1 = CIR(in_channels=32, out_channels=32)
        self.cir2=CIR(32,64,stride=2)
        self.cir3=CIR(64,64,stride=0)
        self.sig=nn.Sigmoid()


    def forward(self, x):
        x= self.cir1(x)
        x=self.cir2(x)
        x=self.cir3(x)
        x=self.sig(x)

        return x   

#### FeatureNet

The `FeatureNet` joins the two individual `ModalityNet` and the `SharedNet`. The last one will be - as its name suggests - shared between the two modalities.

In [98]:
class FeatureNet(nn.Module):
    def __init__(self):
        super(FeatureNet, self).__init__()
        
        base = 16
        out_channels = 64
        
        self.modality1_net = ModalityNet(base)
        self.modality2_net = ModalityNet(base)
        self.shared_net = SharedNet(base, out_channels)

    def forward(self, x, y):
        x = self.modality1_net(x)
        y = self.modality2_net(y)
        x = self.shared_net(x)
        y = self.shared_net(y)
        return x, y
    

During training, the resulting SSD tensor is again fed into robust_rigid_fit, which return a 1x3x4 matrix to obtain a deformation grid that will be applied to the moving scan. Afterwards the mutual information can be computed and minimised as loss. The training loop is pre-defined, we use a batch-size of 1 (hence InstanceNorm) and train for 30 epochs (240 iterations).

During training, you see visualisations of the estimated transforms (using it to warp the moving
segmentation) and after training you can run the provided evaluation functions and see that the Dice overlap should increase from 43% to above 60%.

**Optional**: Plot the initial dice values

In [99]:
get_dice_all(TRAIN_CASES,segs_fix_train,segs_mov_train)

Case:  tensor(0)
Initial Dice: 0.60, 0.36, 0.66, 0.00 (mean: 0.40)
--

Case:  tensor(1)
Initial Dice: 0.61, 0.21, 0.08, 0.00 (mean: 0.23)
--

Case:  tensor(2)
Initial Dice: 0.48, 0.08, 0.28, 0.15 (mean: 0.25)
--

Case:  tensor(3)
Initial Dice: 0.72, 0.28, 0.31, 0.12 (mean: 0.36)
--

Case:  tensor(4)
Initial Dice: 0.63, 0.26, 0.27, 0.08 (mean: 0.31)
--

Case:  tensor(5)
Initial Dice: 0.81, 0.38, 0.59, 0.54 (mean: 0.58)
--

Case:  tensor(6)
Initial Dice: 0.57, 0.23, 0.32, 0.21 (mean: 0.33)
--

Case:  tensor(7)
Initial Dice: 0.61, 0.43, 0.34, 0.31 (mean: 0.42)
--

Initial Dice (all): 0.63, 0.28, 0.36, 0.18 (mean: 0.36)


In [105]:
# training loop 

num_epochs = 30
init_lr = 0.001
device = 'cuda'

net = FeatureNet().to(device)
parameter_count(net)
optimizer = optim.Adam(net.parameters(), lr=init_lr)

scaler = torch.cuda.amp.GradScaler()
losses = torch.zeros(num_epochs)

for epoch in range(num_epochs):
    net.train()
    torch.cuda.synchronize()
    t0 = time.time()
    running_loss = 0
    rand_idx = torch.randperm(len(TRAIN_CASES))
    for idx in rand_idx:
        optimizer.zero_grad()
        
        rand_idx1 = torch.randint(8, (1,))[0]
        img_fix = imgs_fix_train[idx:idx+1, rand_idx1:rand_idx1+1].to(device,non_blocking=True)# + 1
        img_mov = imgs_mov_train[idx:idx+1].to(device,non_blocking=True)
        seg_fix = segs_fix_train[idx:idx+1, rand_idx1:rand_idx1+1].long().to(device,non_blocking=True)
        seg_mov = segs_mov_train[idx:idx+1].to(device,non_blocking=True).long()
        mask_fix = masks_fix_train[idx:idx+1, rand_idx1:rand_idx1+1].to(device,non_blocking=True)
        mask_mov = masks_mov_train[idx:idx+1].to(device,non_blocking=True)
        
        with torch.cuda.amp.autocast(enabled=True):

            #TODO: call your own implementation of network architecture and correlation layer
            feat_fix, feat_mov = task12['feat_fix'].half().cuda(), task12['feat_mov'].half().cuda()
            cost,kpts_fixed = correlation(mask_fix, feat_fix, feat_mov)
            
            #provided function for robust fitting
            R = robust_rigid_fit(cost,kpts_fixed,feat_fix)
            
        #mutual information requires 32bit precision
        with torch.cuda.amp.autocast(enabled=False): 
            grid = F.affine_grid(R, (1,1,D,H,W))
            img_warped = F.grid_sample(img_mov,grid.float(),mode='bilinear')
            
            #TODO: call your own implementation for mutual information loss
            loss = mutual_inf(mask_fix.cuda(), task12['img_fix'].float().cuda(), img_mov.cuda())

        seg_mov_warped = F.grid_sample(F.one_hot(seg_mov, 5).view(1, D, H, W, -1).permute(0, 4, 1, 2, 3).float(), grid.float(), mode='bilinear').argmax(1)
        if(rand_idx1==4):
            plt.figure()
            q100 = float(torch.topk(img_fix[0,0,:,60,:].reshape(-1),100)[0].cpu().data[-1:])
            gray1 = torch.clamp(img_fix[0,0,:,60,:].data.cpu().t().flip([0,1]),0,q100)/q100
            rgb = overlaySegment(gray1,seg_mov_warped[0,:,60,:].long().data.cpu().t().flip([0,1]))
            plt.imshow(rgb)
            plt.show()
            
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
    
    running_loss /= len(TRAIN_CASES)
    losses[epoch] = running_loss
    torch.cuda.synchronize()
    t1 = time.time()

    print('epoch (train): {:02d} -- loss: {:.3f} -- time(s): {:.1f}'.format(epoch, running_loss, t1-t0))
    gpu_usage()

# parameters: 305152




RuntimeError: ignored

Plot the loss and store network weights:

In [None]:
plt.plot(losses)
FOLD = 3
torch.save(net.cpu().state_dict(), 'net_mi_fold{}.pth'.format(FOLD))

**Evaluation**: Run the following code for a final evaluation. You dice should increase from 43% to above 60%.

In [None]:
#quantitaive evaluation of the trained network, should return around 60% Dice after registration
net = FeatureNet().to(device)
net.load_state_dict(torch.load('net_mi_fold{}.pth'.format(FOLD)))
net.eval()
parameter_count(net)

torch.manual_seed(30)
TEST_CASES = TRAIN_CASES
with torch.no_grad():
    dice_init_all = torch.zeros(4,len(TEST_CASES),4)
    dice_all = torch.zeros(4,len(TEST_CASES),4)
    for i in range(4):
        for j, case in enumerate(TEST_CASES):
            img_fix, img_mov, seg_fix, seg_mov, mask_fix, mask_mov = load_case(case)

            img_fix = img_fix.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)# + 1
            img_mov = img_mov.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
            seg_fix = seg_fix.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
            seg_mov = seg_mov.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
            mask_fix = mask_fix.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
            mask_mov = mask_mov.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)

            R = generate_random_rigid_3d()
            grid = F.affine_grid(R[:3,:4].unsqueeze(0), (1,1,D,H,W))
            img_fix_ = F.grid_sample(img_fix, grid)
            seg_fix_ = F.grid_sample(F.one_hot(seg_fix[0, 0]).permute(3, 0, 1, 2).unsqueeze(0).float(), grid).argmax(1, keepdim=True)
            mask_fix_ = (F.grid_sample(mask_fix.float(), grid)>0.5).float()
            with torch.cuda.amp.autocast():
                feat_fix, feat_mov = net(img_fix_.contiguous(), img_mov.contiguous())
                ssd_cost,kpts_fix = correlation(mask_fix_,feat_fix,feat_mov)
            R = robust_rigid_fit(ssd_cost,kpts_fix,feat_fix)
            
            grid = F.affine_grid(R, (1,1,D,H,W))
            seg_mov_warped = F.grid_sample(seg_mov.float(), grid, mode='nearest')

            d = dice_coeff(seg_fix.cpu(),seg_mov.cpu(),5); print(d,d.mean())
            d0 = dice_coeff(seg_fix_.cpu(),seg_mov.cpu(),5); print(d0,d0.mean())
            d1 = dice_coeff(seg_fix_.cpu(),seg_mov_warped.cpu(),5); print(d1,d1.mean())
            print()

            dice_init_all[i, j] = d0
            dice_all[i, j] = d1

In [None]:
# Print dice 

print('Initial dice: ', (dice_init_all.sum()/(dice_init_all>0).sum()).item())
print('Dice after reg:', (dice_all.sum()/(dice_init_all>0).sum()).item())

### Bonus Task: Deformable Image Registration

Deformable image registration is the process of finding correspondence between images that are not linked by simple rigid shifts and rotations. It is more realistic setup, because patients have many degrees of freedom and can move and deform due to many processes including simply being lying in a slightly different position from day to day, weight loss, tumor shrinkage, normal tissue shrinkage, inflammation of normal tissue, and motion due to respiration.

In the bonus task we will use method proposed in [Label-driven weakly-supervised learning for multimodal deformable image registration](https://arxiv.org/pdf/1711.01666.pdf) paper to register CT images taken at inspiration and expiration respiratory phases of the same patient.

You can notice that in the provided template we are using several methods from [MONAI](https://docs.monai.io/en/stable/). MONAI is a PyTorch-based, open-source framework for deep learning in healthcare imaging. It contains variety of very useful tools that can save a lot of time during development of deep-learning based solution for medical image processing. You are more than welcome to explore it.

#### Environment

Import required packages:

In [None]:
import torchinfo
from torch.nn import MSELoss

from monai.apps import download_and_extract

from monai.networks.blocks import Warp

from monai.losses import DiceLoss, BendingEnergyLoss
from monai.metrics import DiceMetric

from monai.data import DataLoader, Dataset
from monai.transforms import Compose, LoadImaged, Resized, ScaleIntensityRanged
from monai.utils import first

#### Data

Download and extract the dataset:

In [None]:
resource = "https://zenodo.org/record/3835682/files/training.zip"

compressed_file = "paired_ct_lung.zip"
data_dir = "paired_ct_lung"
if not os.path.exists(os.path.join('.', data_dir)):
    download_and_extract(resource, compressed_file)
    os.rename(os.path.join('.', "training"), data_dir)

Create training and validation data dictionaries:

In [None]:
data_dicts = [
    {
        "fixed_image": os.path.join(data_dir, "scans/case_%03d_exp.nii.gz" % idx),
        "moving_image": os.path.join(data_dir, "scans/case_%03d_insp.nii.gz" % idx),
        "fixed_label": os.path.join(data_dir, "lungMasks/case_%03d_exp.nii.gz" % idx),
        "moving_label": os.path.join(data_dir, "lungMasks/case_%03d_insp.nii.gz" % idx),
    }
    for idx in range(1, 21)
]

train_files, val_files = data_dicts[:18], data_dicts[18:]

Define data processing pipeline using MONAI transforms:
- LoadImaged: loads the lung CT images and labels from NIfTI format files, "ensure_channel_first=True" ensure that the first dim is channel.
- ScaleIntensityRanged: extracts intensity range [-285, 3770] and scales to [0, 1].
- Resized: resize images to the same size.

In [None]:
train_transforms = Compose(
    [
        LoadImaged(keys=["fixed_image", "moving_image", "fixed_label", "moving_label"], ensure_channel_first=True),
        ScaleIntensityRanged(
            keys=["fixed_image", "moving_image"],
            a_min=-285,
            a_max=3770,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        Resized(
            keys=["fixed_image", "moving_image", "fixed_label", "moving_label"],
            mode=("trilinear", "trilinear", "nearest", "nearest"),
            align_corners=(True, True, None, None),
            spatial_size=(96, 96, 104),
        ),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["fixed_image", "moving_image", "fixed_label", "moving_label"], ensure_channel_first=True),
        ScaleIntensityRanged(
            keys=["fixed_image", "moving_image"],
            a_min=-285,
            a_max=3770,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        Resized(
            keys=["fixed_image", "moving_image", "fixed_label", "moving_label"],
            mode=("trilinear", "trilinear", "nearest", "nearest"),
            align_corners=(True, True, None, None),
            spatial_size=(96, 96, 104),
        ),
    ]
)

Show example of inspiration and expiration respiratory phases of the same patient:

In [None]:
check_ds = Dataset(data=train_files, transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=1)
check_data = first(check_loader)
fixed_image = check_data["fixed_image"][0][0].permute(1, 0, 2)
fixed_label = check_data["fixed_label"][0][0].permute(1, 0, 2)
moving_image = check_data["moving_image"][0][0].permute(1, 0, 2)
moving_label = check_data["moving_label"][0][0].permute(1, 0, 2)

print(f"moving_image shape: {moving_image.shape}, " f"moving_label shape: {moving_label.shape}")
print(f"fixed_image shape: {fixed_image.shape}, " f"fixed_label shape: {fixed_label.shape}")

plt.figure("check", (12, 6))
plt.subplot(1, 4, 1)
plt.title("moving_image")
plt.imshow(moving_image[:, :, 50], cmap="gray")
plt.subplot(1, 4, 2)
plt.title("moving_label")
plt.imshow(moving_label[:, :, 50])
plt.subplot(1, 4, 3)
plt.title("fixed_image")
plt.imshow(fixed_image[:, :, 50], cmap="gray")
plt.subplot(1, 4, 4)
plt.title("fixed_label")
plt.imshow(fixed_label[:, :, 50])

plt.show()
plt.show()

Create training and validation data loaders:

In [None]:
train_ds = Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4)

val_ds = Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)

#### Model

A neural network takes as an input concatenation of moving and fixed images on channel dimension and computes a dense displacement field (DDF). Moving image warped with the DDF should correspond to fixed image. We will refer to this neural network as "LocalNet" because it computes local non-rigid deformations.

LocalNet designed as a 3D convolutional neural network with three residual blocks.

Each residual block consists of three convolution groups:
- a `nn.Conv3d` with same padding.
- a `nn.InstanceNorm3d` for normalisation.
- a `nn.LeakyReLU` as non-linearity.

Residual connection is branched out after first convolution group and summed up with output of normalization layer of third convolution group.
Number of channels is changed by first convolution group and remains the same after that.
Kernel size is the same for all convolution groups inside the residual block.
Output of third convolution group downsampled by factor 2 using`nn.MaxPool3d` .

Implement residual block:


In [None]:

class ResidualEncoderBlock(nn.Module):

    # TODO

    def __init__(self, in_channels, out_channels, kernel_size):

        super().__init__()


    def forward(self, x: torch.Tensor) -> torch.Tensor:

        return x

First residual block increases number of channels to 32, second - to 64, and third - to 128.
Kernel size of first residual block is equal to 7, while for second and third residual blocks kernel size are equal to 3.
Output of third residual block is processed by addition convolution group:
- a `nn.Conv3d` with number of output channels equal to 256, kernel size equal to 3 and same padding.
- a `nn.InstanceNorm3d` for normalisation.
- a `nn.LeakyReLU` as non-linearity.

Output of convolution group is processed by another convolution layer (with kernel size 3 and same padding) that reduces number of channels to 3 and computes coarse dense displacement map.
Finally, coarse displacement map is interpolated to match dimensions of moving image and get displacement for its each voxel.

Implement LocalNet:

In [None]:
class LocalNet(nn.Module):

    # TODO

    def __init__(self):
        super().__init__()

    def forward(self, x):

        return x

Instantiate LocalNet and print its summary.
The number of trainable parameters should be equal to 3,013,219.

In [None]:
model = LocalNet()
torchinfo.summary(model, (1, 2, 96, 96, 104))

#### Training loop

During the training process we can use loss function that measures the correspondence based on:
- voxel-level labels - more intuitive way, it can be implemented, for example, using MSE loss between warped moving image and fixed image.
- anatomy-level labels - it can be implemented, for example, using Dice loss between warped label of moving image and label of fixed image.

You will compare both of these loss functions.

Below you can find the template of the training loop method.

Pay attention that we are using again handy methods from MONAI library: warp layer, dice loss function and dice metric. You can read more about the usage in [documentation](https://docs.monai.io/en/stable/) of MONAI library.

Complete missing parts of the training loop. You can use forward method, adam optimizer and dice metric defined below in your implementation:

In [None]:
def train_model(loss_function, anatomy_loss=False, max_epochs = 30):

    device = torch.device("cuda:0")

    model = LocalNet().to(device)
    warp_layer = Warp().to(device)
    def forward(batch_data, model):
        fixed_image = batch_data["fixed_image"].to(device)
        moving_image = batch_data["moving_image"].to(device)
        moving_label = batch_data["moving_label"].to(device)

        # predict DDF through LocalNet
        ddf = model(torch.cat((moving_image, fixed_image), dim=1))

        # warp moving image and label with the predicted ddf
        pred_image = warp_layer(moving_image, ddf)
        pred_label = warp_layer(moving_label, ddf)

        return ddf, pred_image, pred_label


    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
    regularization = BendingEnergyLoss()

    best_metric = 0
    epoch_loss_values = []
    metric_values = []
    for epoch in range(max_epochs):

        model.train()
        for batch_data in train_loader:

            # TODO: implement training for current epoch append loss to loss values

            pred_label =
            pred_image =

            if anatomy_loss:
                loss = loss_function(pred_label, fixed_label)
            else:
                loss = loss_function(pred_image, fixed_image)  + 10 * regularization(ddf)

        model.eval()
        with torch.no_grad():
            for val_data in val_loader:

                # TODO: complete validation and append dice to metric values

    return epoch_loss_values, metric_values

Train model with voxel-level loss:

In [None]:
voxel_loss = MSELoss()
epoch_values_with_voxel_loss, metric_values_with_voxel_loss = train_model(voxel_loss, anatomy_loss=False)

Train model with anatomy-level loss:

In [None]:
anatomy_loss = DiceLoss()
epoch_values_with_anatomy_loss, metric_values_with_anatomy_loss = train_model(anatomy_loss, anatomy_loss=True)

Voxel-level labels for correspondence learning are impossible to reliably obtain from medical image data.
In contrast, loss function computed on anatomy-level labels enforce model to learn high level semantic correspondence between images that is much easier to the model to understand during the training.

You should notice the difference:

In [None]:
plt.figure("train", (12, 6))
plt.title("Val Mean Dice")
x = [(i + 1) for i in range(len(metric_values_with_anatomy_loss))]
y_1 = metric_values_with_anatomy_loss
y_2 = metric_values_with_voxel_loss
plt.xlabel("epoch")
plt.plot(x, y_1, label='dice_for_anatomy_loss')
plt.plot(x, y_2, label='dice_for_voxel_loss')
plt.legend()
plt.show()