# Medical Deep Learning


## Exercise 3: Learning-based multi-modal 3D registration (20 points)

### (Hanna Siebert, Marian Himstedt & Matthias Heinrich)

**Please upload your commented solution (`mdl_exercise3.ipynb`) to Moodle by Tuesday 25.5. 23:59.** Check whether the code can be executed after restarting the kernel and running through all cells sequentially.

The aim of this exercise is to introduce you to deep-learning based image registration and also explore mutual information as a metric to supervise the training of multi-modal feature networks. The method we want to implement comprises three parts: 

1.   A global mutual information loss function
2.   A correlation layer to robustly estimate large rigid transformations
3.   A compact 3D CNN network (with some modality specific and some shared layers) to predict features suitable for multi-modal CT/MR registration

**Provided functions and data loading**

The following cells provide the fundamentals for loading and augmenting the data. You are also given a number of functions for the subsequent registration tasks. You are invited to read and understand the code if you are interested.


In [None]:
# Download train and test data (segmentations/masks as .npz, image volumes as .pth)
!wget https://cloud.imi.uni-luebeck.de/s/76KJ7RBqpsdjSbw/download -O mdl3_masks.npz
!wget https://cloud.imi.uni-luebeck.de/s/yAZNkTBRGoeePZa/download -O mdl3_imgs.pth

# Download feature data for testing task 1 & 2
!wget https://cloud.imi.uni-luebeck.de/s/nSixsneJ6fDbfBB/download -O mdl3_exercise_task12.pth

# Download an additional python file providing utility functions

!wget https://cloud.imi.uni-luebeck.de/s/X8A8Dixgj62Qwtf/download -O mdl_exercise3_utils.py

# NOTE(review): the star import presumably supplies the helpers used further down
# that are not defined in this notebook (dice_coeff, get_dice_all, overlaySegment,
# parameter_count, gpu_usage) -- confirm against mdl_exercise3_utils.py.
from mdl_exercise3_utils import *


--2021-05-21 20:25:36--  https://cloud.imi.uni-luebeck.de/s/76KJ7RBqpsdjSbw/download
Resolving cloud.imi.uni-luebeck.de (cloud.imi.uni-luebeck.de)... 141.83.20.118
Connecting to cloud.imi.uni-luebeck.de (cloud.imi.uni-luebeck.de)|141.83.20.118|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1638735 (1.6M) [application/octet-stream]
Saving to: ‘mdl3_masks.npz’


2021-05-21 20:25:38 (2.11 MB/s) - ‘mdl3_masks.npz’ saved [1638735/1638735]

--2021-05-21 20:25:39--  https://cloud.imi.uni-luebeck.de/s/yAZNkTBRGoeePZa/download
Resolving cloud.imi.uni-luebeck.de (cloud.imi.uni-luebeck.de)... 141.83.20.118
Connecting to cloud.imi.uni-luebeck.de (cloud.imi.uni-luebeck.de)|141.83.20.118|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 188744833 (180M) [application/octet-stream]
Saving to: ‘mdl3_imgs.pth’


2021-05-21 20:25:50 (17.4 MB/s) - ‘mdl3_imgs.pth’ saved [188744833/188744833]

--2021-05-21 20:25:50--  https://cloud.imi.uni-luebeck.de/s/nSix

In [None]:
# Registration hyper-parameters and volume geometry shared by the cells below.

# control-point spacing in voxels (Task 2 asks for a spacing of 12)
grid_step = 12

# displacement search: (2 * disp_radius + 1) candidates per axis,
# neighbouring candidates are disp_step voxels apart
disp_radius, disp_step = 4, 5

# scaling of the negated costs inside the softmax of robust_rigid_fit
beta = 25

# volume size, used as (D, H, W) in the affine_grid calls
D, H, W = 192, 160, 192

In [None]:
def least_trimmed_rigid(fixed_pts, moving_pts, iter=5):
    """Fit a rigid transform with iterative trimming of the worst correspondences.

    Each round fits ``find_rigid_3d`` on the currently selected rows, measures
    the Euclidean residual of *all* correspondences under that fit and keeps
    the half with the smallest residual for the next round.

    Args:
        fixed_pts: N x 4 tensor of points (rows are multiplied with the
            transposed 4 x 4 transform, i.e. homogeneous coordinates).
        moving_pts: N x 4 tensor of corresponding points.
        iter: number of fit/trim rounds.

    Returns:
        The 4 x 4 transform of the last round, cast to ``fixed_pts.dtype``.
    """
    num_pts = fixed_pts.shape[0]
    # the first round uses every correspondence
    keep = torch.arange(num_pts).to(fixed_pts.device)
    for _round in range(iter):
        # transposed so that row vectors can be mapped with a right-multiplication
        transform_t = find_rigid_3d(fixed_pts[keep, :], moving_pts[keep, :]).t()
        mismatch = moving_pts - torch.mm(fixed_pts, transform_t)
        residual = torch.sqrt(torch.sum(torch.pow(mismatch, 2), 1))
        # indices of the better-fitting half, used by the next round
        keep = torch.topk(residual, num_pts // 2, largest=False)[1]
    return transform_t.t().to(fixed_pts.dtype)

def find_rigid_3d(x, y):
    """Least-squares rigid transform (rotation + translation) mapping x onto y.

    Only the first three columns of both inputs are used, so homogeneous
    N x 4 point sets are accepted as well as plain N x 3 ones.

    Args:
        x: N x (>=3) tensor of source points.
        y: N x (>=3) tensor of target points, row-wise corresponding to x.

    Returns:
        4 x 4 homogeneous matrix T on ``x.device`` with ``T[:3, :3]`` the
        rotation and ``T[:3, 3]`` the translation.
    """
    src = x[:, :3]
    dst = y[:, :3]
    src_centroid = src.mean(0)
    dst_centroid = dst.mean(0)

    # SVD of the cross-covariance of the centred point sets
    cross_cov = torch.matmul((src - src_centroid).t(), (dst - dst_centroid)).float()
    u, _, v = torch.svd(cross_cov)

    # put det(v u^T) on the last diagonal entry so that a reflection is
    # turned into a proper rotation
    correction = torch.eye(v.shape[0], v.shape[0]).to(x.device)
    correction[-1, -1] = torch.det(torch.matmul(v, u.t()).float())
    rotation = torch.matmul(torch.matmul(v, correction), u.t())
    translation = dst_centroid - torch.matmul(rotation, src_centroid)

    T = torch.eye(4).to(x.device)
    T[:3, :3] = rotation
    T[:3, 3] = translation
    return T

def generate_random_rigid_3d(strength=.3):
    """Draw a random 4 x 4 rigid transform.

    Twelve Gaussian 3-D points are perturbed with Gaussian noise scaled by
    ``strength``; the rigid fit between the clean and the perturbed set is
    returned. Tensors are created on the module-level ``device``.
    """
    anchors = torch.randn(12, 3).to(device)
    jitter = torch.randn(12, 3).to(device)
    return find_rigid_3d(anchors, anchors + strength * jitter)

# Dense displacement search space: every combination of per-axis offsets
# -disp_step*disp_radius, ..., +disp_step*disp_radius (in steps of disp_step),
# flattened to shape 1 x 1 x (2*disp_radius+1)^3 x 1 x 3.
disp = torch.stack(
    torch.meshgrid(
        *(3 * [torch.arange(-disp_step * disp_radius, disp_step * disp_radius + 1, disp_step)])
    )
).permute(1, 2, 3, 0).contiguous().view(1, 1, -1, 1, 3).float()

# reverse the component order to line up with [W, H, D] and rescale the
# voxel offsets by 2 / (size - 1)
disp = disp.flip(-1) * 2 / (torch.tensor([W, H, D]) - 1)

# number of displacement candidates along one axis
disp_width = 2 * disp_radius + 1

def robust_rigid_fit(ssd_cost,kpts_fixed,feat_fix):
    """Turn a dissimilarity volume into a rigid transform.

    Keeps the 50% of control points with the lowest minimum cost, converts the
    costs of each kept point into a soft (softmax-weighted) displacement and
    runs ``least_trimmed_rigid`` on the resulting correspondences.

    Args:
        ssd_cost: dissimilarity tensor from Task 2, viewable as
            N x (2*disp_radius+1)^3.
        kpts_fixed: control point coordinates, viewable as N x 3.
        feat_fix: only its device and dtype are used (for the homogeneous ones).

    Returns:
        1 x 3 x 4 matrix holding the upper three rows of the fitted transform.
    """
    num_disp = (disp_radius*2+1)**3
    ssd_cost = ssd_cost.view(1,-1,num_disp)
    kpts_fixed = kpts_fixed.view(1,-1,3)
    # predefined set of displacements, matched to the cost tensor
    disp1 = disp.to(ssd_cost.device).to(ssd_cost.dtype)

    # rank control points by their best (minimum) cost and keep the better half
    ssd_val = torch.min(ssd_cost.squeeze(), 1)[0]
    num_keep = kpts_fixed.shape[1]//2
    idx_best = torch.sort(ssd_val, dim=0, descending=False)[1][:num_keep]

    # soft correspondence: softmax-weighted mean displacement per control point
    # (this keeps the whole fit differentiable)
    weights = torch.softmax(-beta*ssd_cost.squeeze(0).unsqueeze(2),1)
    disp_best = torch.sum(weights * disp1.view(1, -1, 3), 1)[idx_best,:]

    # homogeneous coordinates of the kept points and of their displaced partners
    ones = torch.ones(idx_best.size(0),1).to(feat_fix.device).to(feat_fix.dtype)
    anchors = kpts_fixed[0,idx_best,:]
    fixed_pts = torch.cat((anchors, ones),1)
    moving_pts = torch.cat((anchors + disp_best, ones),1)

    # SVD is not available/stable with FP16, so the fit runs outside autocast
    with torch.cuda.amp.autocast(enabled=False):
        R = least_trimmed_rigid(fixed_pts.float(), moving_pts.float())
    return R[:3,:4].unsqueeze(0)


In [None]:
# Run all required imports

import os
import matplotlib.pyplot as plt
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import math

In [None]:
# Load the downloaded image volumes (torch file) and the mask / segmentation
# arrays (npz archive). The cells below create eight augmented versions of
# each fixed scan (64 CT/MR pairs in total).

mdl3_imgs = torch.load('mdl3_imgs.pth')
mdl3_masks = np.load('mdl3_masks.npz')

def load_case(case):
    """Fetch one case from the loaded data as CPU tensors.

    Returns the tuple
    ``(img_fix, img_mov, seg_fix, seg_mov, mask_fix, mask_mov)``;
    images are float tensors, segmentations are long tensors and the masks
    keep the dtype stored in the archive.
    """
    def _from_archive(key):
        # numpy array of this case -> CPU tensor
        return torch.from_numpy(mdl3_masks[key][case]).cpu()

    img_fix = mdl3_imgs['mdl3_img_fix'][case].float().cpu()
    img_mov = mdl3_imgs['mdl3_img_mov'][case].float().cpu()
    mask_fix = _from_archive('mdl3_mask_fix')
    mask_mov = _from_archive('mdl3_mask_mov')
    seg_fix = _from_archive('mdl3_seg_fix').long()
    seg_mov = _from_archive('mdl3_seg_mov').long()
    return img_fix, img_mov, seg_fix, seg_mov, mask_fix, mask_mov


# Indices of the cases used for training.
TRAIN_CASES = torch.arange(8) 

# Pre-allocated pinned host buffers: the fixed scans get 8 randomly
# transformed versions per case (second dimension), the moving scans are
# stored once per case.
imgs_fix_train = torch.zeros(len(TRAIN_CASES), 8, D, H, W).float().pin_memory()
imgs_mov_train = torch.zeros(len(TRAIN_CASES), 1, D, H, W).float().pin_memory()
segs_fix_train = torch.zeros(len(TRAIN_CASES), 8, D, H, W).int().pin_memory()
segs_mov_train = torch.zeros(len(TRAIN_CASES), 1, D, H, W).int().pin_memory()
masks_fix_train = torch.zeros(len(TRAIN_CASES), 8, D, H, W).bool().pin_memory()
masks_mov_train = torch.zeros(len(TRAIN_CASES), 1, D, H, W).bool().pin_memory()
for i, case in enumerate(TRAIN_CASES):
    print('process case', i)
    img_fix, img_mov, seg_fix, seg_mov, mask_fix, mask_mov = load_case(case)
    # NOTE: this also sets the module-level `device` that
    # generate_random_rigid_3d() reads; load_case returns CPU tensors, so the
    # .to(device) calls below keep the data where it is.
    device = img_fix.device
    # add batch and channel dimensions -> 1 x 1 x D x H x W
    img_fix = img_fix.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
    img_mov = img_mov.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
    seg_fix = seg_fix.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
    seg_mov = seg_mov.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
    mask_fix = mask_fix.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
    mask_mov = mask_mov.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
    
    # the moving scan is stored untransformed
    imgs_mov_train[i:i+1] = img_mov
    segs_mov_train[i:i+1] = seg_mov
    masks_mov_train[i:i+1] = mask_mov
    for j in range(8):
        with torch.no_grad():
            # random rigid transform, applied on the GPU to image, labels and mask
            R = generate_random_rigid_3d()
            grid = F.affine_grid(R[:3,:4].unsqueeze(0).cuda(), (1,1,D,H,W))
            img_fix_ = F.grid_sample(img_fix.cuda(), grid, padding_mode='border')
            # labels are warped as one-hot channels and converted back with argmax
            seg_fix_ = F.grid_sample(F.one_hot(seg_fix[0, 0]).permute(3, 0, 1, 2).unsqueeze(0).float().cuda(), grid).argmax(1, keepdim=True).int()
            # interpolated mask is binarised again at 0.5
            mask_fix_ = F.grid_sample(mask_fix.float().cuda(), grid)>0.5

            imgs_fix_train[i:i+1, j:j+1] = img_fix_.cpu()
            segs_fix_train[i:i+1, j:j+1] = seg_fix_.cpu()
            masks_fix_train[i:i+1, j:j+1] = mask_fix_.cpu()

process case 0


  "Default grid_sample and affine_grid behavior has changed "
  "Default grid_sample and affine_grid behavior has changed "


process case 1
process case 2
process case 3
process case 4
process case 5
process case 6
process case 7


### Task 1 (9 points): Computation of (joint) histograms and mutual information


The values should be sampled inside the (fixed) mask to exclude background locations.

✔ Define a range for the histogram bins with linspace with 64 steps and the minimum and maximum value of the respective image.

✔ Compute a Parzen window weighting with σ=0.015 as $\exp\left(-\frac{(\text{value} - \text{bin})^2}{2\sigma^2}\right)$; after this step you should have two 2D tensors of size 64 x N (where N is the number of pixels).

✔ Calculate the marginal (individual) and the joint histogram by summing/averaging over the pixels and dividing the resulting vector by its sum.  For the joint histogram the pairwise sums are implicitly obtained using a matrix multiplication of fixed and transposed moving histograms.

✔ Use E = - ∑ p ⋅ log2(p + ε) to compute entropy and -($E_{fix}$ + $E_{mov}$ - $E_{joint}$) as MI loss, ε = $10^{-6}$

In [None]:
def _parzen_histogram(values, sigma, num_bins=64):
    """Return the num_bins x N matrix of Gaussian Parzen-window weights of `values`."""
    # bin centres span the value range; .item() makes them plain constants
    bins = torch.linspace(values.min().item(), values.max().item(), num_bins,
                          device=values.device)
    # broadcasting (num_bins x 1) against (1 x N) avoids materialising repeats
    return torch.exp(-(values.unsqueeze(0) - bins.unsqueeze(1)) ** 2 / (2 * sigma ** 2))


def _entropy(prob, eps=1e-6):
    """Return the entropy -sum(p * log2(p + eps)) of a normalised histogram."""
    prob = prob.reshape(-1)
    return -torch.sum(prob * torch.log2(prob + eps))


def mutual_inf(mask_fix,img_fix,img_mov):
    """Negative mutual information between two images inside a mask.

    A random 20% subset (drawn with replacement via ``np.random``) of the
    voxels with ``mask_fix > 0`` is used. Marginal and joint histograms are
    estimated with 64 bins and a Gaussian Parzen window (sigma = 0.015).

    Args:
        mask_fix: mask with as many elements as ``img_fix``; voxels > 0 are used.
        img_fix: fixed image (any shape, flattened internally).
        img_mov: moving image with the same number of elements.

    Returns:
        0-dim tensor ``-(E_fix + E_mov - E_joint)`` on the device of
        ``img_fix``; lower values mean better alignment.

    Raises:
        ValueError: if the mask selects no voxel.
    """
    sigma = 0.015
    device = img_fix.device  # work wherever the input lives (CPU or GPU)

    # draw samples from mask_fix
    inside = (mask_fix.reshape(-1) > 0).to(device)
    fix_values = img_fix.reshape(-1)[inside]
    mov_values = img_mov.reshape(-1).to(device)[inside]
    num_inside = fix_values.numel()
    if num_inside == 0:
        raise ValueError('mask_fix does not contain any foreground voxel')
    random_indices = np.random.randint(num_inside, size=max(1, int(num_inside * 0.2)))
    fix_sampled = fix_values[random_indices]
    mov_sampled = mov_values[random_indices]

    # Parzen-window weights, each 64 x N
    hist_fix = _parzen_histogram(fix_sampled, sigma)
    hist_mov = _parzen_histogram(mov_sampled, sigma)

    # marginal histograms: sum over the samples, normalise to sum 1
    marg_fix = hist_fix.sum(dim=1)
    marg_fix = marg_fix / marg_fix.sum()
    marg_mov = hist_mov.sum(dim=1)
    marg_mov = marg_mov / marg_mov.sum()

    # joint histogram: the matrix product sums the pairwise weights over samples
    joint_hist = torch.matmul(hist_fix, hist_mov.t())
    joint_hist = joint_hist / joint_hist.sum()

    # entropies and the negated mutual information as loss
    E_fix = _entropy(marg_fix)
    E_mov = _entropy(marg_mov)
    E_joint = _entropy(joint_hist)
    return -(E_fix + E_mov - E_joint)

### Task 2 (5 points): Correlation layer  

✔ Define a grid of control points using affine_grid with a spacing of 12 voxels and range of .925 (to exclude points near the image boundaries). Again sample the values of the (fixed) mask.  

✔ Create an empty tensor (with same type and device as the fixed features) of size N x 729 (N = number of control points, 729 = $9^3$ displacements), use 32 chunks for unrolling and compute the (dis)similarity as follows: sample the fixed features at (the current subset of) grid points and the moving feature tensor at the combined (added) coordinates of absolute grid points and relative displacements yielding a C x N/32 x 729 tensor. Square and sum over the channel dimension (C).    

✔ The resulting SSD tensor should be fed to the provided function robust_rigid_fit, which searches for the most probable correspondences, by filtering out potentially erroneous ones based on their similarity score and the residual of a globally rigid least-square fit.

In [None]:
# Task2 correlation layer
def correlation(mask_fix, feat_fixed, feat_moving, unroll_factor=32, voxel_spacing=12, displacements=None):
    """Sum-of-squared-differences correlation layer over a displacement search space.

    A regular grid of control points (spacing ``voxel_spacing`` voxels of the
    mask volume, range .925) is created with ``affine_grid``; only points that
    fall inside ``mask_fix`` are kept. For every kept point the fixed features
    are compared with the moving features at all displaced locations.

    Args:
        mask_fix: fixed mask whose last three dimensions are D x H x W
            (bool or numeric).
        feat_fixed: 1 x C x d x h x w feature tensor of the fixed scan.
        feat_moving: feature tensor of the moving scan, same shape and dtype.
        unroll_factor: number of chunks the control points are split into to
            bound the memory of the intermediate C x n x K tensor.
        voxel_spacing: control point spacing in voxels.
        displacements: relative offsets in normalised grid coordinates,
            viewable as 1 x 1 x K x 1 x 3; defaults to the module-level ``disp``.

    Returns:
        ``(ssd, kpts_fixed)``: ``ssd`` is N x K with the dtype/device of
        ``feat_fixed``; ``kpts_fixed`` is 1 x N x 3 with the normalised
        coordinates of the N control points inside the mask.
    """
    if displacements is None:
        displacements = disp
    device, dtype = feat_fixed.device, feat_fixed.dtype

    # regular control point grid in normalised [-1, 1] coordinates
    depth, height, width = mask_fix.shape[-3:]
    theta = 0.925 * torch.eye(3, 4, device=device).unsqueeze(0)
    grid_size = (1, 1, depth // voxel_spacing, height // voxel_spacing, width // voxel_spacing)
    kpts_all = F.affine_grid(theta, grid_size, align_corners=False).view(1, -1, 3)

    # keep only control points inside the fixed mask
    # (grid_sample needs a floating point input, the mask may be bool)
    mask = mask_fix.to(device).float().reshape(1, 1, depth, height, width)
    inside = F.grid_sample(mask, kpts_all.view(1, -1, 1, 1, 3), align_corners=False).view(-1) > 0.5
    kpts_fixed = kpts_all[:, inside, :]

    offsets = displacements.to(device=device, dtype=dtype).reshape(1, 1, -1, 1, 3)
    num_kpts = kpts_fixed.shape[1]
    num_disp = offsets.shape[2]
    ssd = torch.empty(num_kpts, num_disp, device=device, dtype=dtype)

    # process the control points chunk by chunk to limit peak memory
    chunk_size = max(1, math.ceil(num_kpts / unroll_factor))
    for start in range(0, num_kpts, chunk_size):
        stop = start + chunk_size
        pts = kpts_fixed[:, start:stop].to(dtype).reshape(1, -1, 1, 1, 3)
        # fixed features at the control points: 1 x C x n x 1 x 1
        sampled_fixed = F.grid_sample(feat_fixed, pts, align_corners=False)
        # moving features at control point + every displacement: 1 x C x n x K x 1
        sampled_moving = F.grid_sample(feat_moving, pts + offsets, align_corners=False)
        # squared difference summed over the channel dimension
        ssd[start:stop] = (sampled_fixed - sampled_moving).pow(2).sum(1).reshape(-1, num_disp)

    return ssd, kpts_fixed


**Optional**: The following code enables to verify your solutions for tasks 1 & 2:

In [None]:
#test for task1 and task2
# Registers the provided CT/MR feature pair with the correlation layer and
# reports mutual information and Dice overlap before and after warping.

task12 = torch.load('mdl3_exercise_task12.pth')

print(task12.keys())
with torch.no_grad():
    mask_fix = task12['mask_fix'].float().cuda()
    img_fix = task12['img_fix'].float().cuda()
    img_mov = task12['img_mov'].float().cuda()
    #call mutual information before transform
    loss0 = mutual_inf(mask_fix,img_fix[0,0],img_mov[0,0])
    print('mi before',loss0)

    with torch.cuda.amp.autocast(enabled=True):
        #call your own implementation of correlation layer
        feat_fix = task12['feat_fix'].half().cuda()
        feat_mov = task12['feat_mov'].half().cuda()
        cost,kpts_fixed = correlation(mask_fix,feat_fix,feat_mov)

        #provided function for robust fitting
        R = robust_rigid_fit(cost.cuda(),kpts_fixed.cuda(),feat_fix)

    #mutual information requires 32bit precision
    grid = F.affine_grid(R, (1,1,D,H,W))
    img_warped = F.grid_sample(img_mov,grid.float(),mode='bilinear')
    #evaluate the warped moving image; the unwarped one would only repeat 'mi before'
    loss1 = mutual_inf(mask_fix,img_fix[0,0],img_warped[0,0])
    print('mi after',loss1)
    seg_fix = task12['seg_fix'].float().cuda()
    seg_mov = task12['seg_mov'].float().cuda()

    #nearest-neighbour sampling keeps the label values intact
    seg_mov_warped = F.grid_sample(seg_mov.float(), grid, mode='nearest')

    d0 = dice_coeff(seg_fix.cpu(),seg_mov.cpu(),5)
    print(d0)
    print('mean dice before: ',d0.mean().item())
    d1 = dice_coeff(seg_fix.cpu(),seg_mov_warped.cpu(),5)
    print(d1)
    print('mean dice after: ', d1.mean().item())
    print()


dict_keys(['feat_fix', 'feat_mov', 'seg_mov', 'mask_fix', 'R', 'seg_fix', 'img_fix', 'img_mov'])
mi before tensor(-0.1097, device='cuda:0')


  "Default grid_sample and affine_grid behavior has changed "
  "Default grid_sample and affine_grid behavior has changed "


torch.Size([1, 64, 48, 40, 48]) torch.Size([1, 16, 13, 16, 3]) torch.Size([64, 3328]) torch.Size([3328, 729])
tensor([[[ 0.0055,  0.0055,  0.0055,  ...,  0.0055,  0.0055,  0.0055],
         [ 0.0054,  0.0054,  0.0054,  ...,  0.0054,  0.0054,  0.0054],
         [ 0.0054,  0.0054,  0.0054,  ...,  0.0054,  0.0054,  0.0054],
         ...,
         [ 0.0043,  0.0043,  0.0043,  ...,  0.0043,  0.0043,  0.0043],
         [-0.0172, -0.0172, -0.0172,  ..., -0.0172, -0.0172, -0.0172],
         [-0.0225, -0.0225, -0.0225,  ..., -0.0225, -0.0225, -0.0225]],

        [[ 0.2365,  0.2365,  0.2365,  ...,  0.2365,  0.2365,  0.2365],
         [ 0.2365,  0.2365,  0.2365,  ...,  0.2365,  0.2365,  0.2365],
         [ 0.2365,  0.2365,  0.2365,  ...,  0.2365,  0.2365,  0.2365],
         ...,
         [ 0.1517,  0.1517,  0.1517,  ...,  0.1517,  0.1517,  0.1517],
         [ 0.1295,  0.1295,  0.1295,  ...,  0.1295,  0.1295,  0.1295],
         [ 0.0978,  0.0978,  0.0978,  ...,  0.0978,  0.0978,  0.0978]],

      

RuntimeError: ignored

### Task 3 (6 points): Define network and train CNNs with mutual information

✔ The modality specific modules should both receive a single-channel input and start with 16 feature maps that are doubled to 32 in the second block. All convolutions should be 3D with 3x3x3 kernels and padding=1, followed by InstanceNorm3d and a LeakyReLU. The second block should have a stride=2. 

✔ The shared sub-network should receive the 32-channel feature maps and comprise three blocks of the same pattern as above: doubling of channels and stride=2 in second block. This yields 64-channel feature that will be squeezed into a range of 0 to 1 using a sigmoid. These feature tensors (for fixed = MRI and moving = CT) will be the input of your correlation layer from Task 2. 
   
✔ The resulting SSD tensor is again fed into robust_rigid_fit, which returns a 1x3x4 matrix to obtain a deformation grid that will be applied to the moving scan. Afterwards the mutual information can be computed and minimised as loss. The training loop is pre-defined, we use a batch-size of 1 (hence InstanceNorm) and train for 30 epochs (240 iterations). 

In [None]:
class ModalityNet(nn.Module):
    """Modality-specific feature extractor (one instance per modality).

    Two blocks of Conv3d(3x3x3, padding=1) -> InstanceNorm3d -> LeakyReLU:
    the first maps the single input channel to ``base`` feature maps, the
    second doubles the channels to ``2 * base`` and halves the resolution
    with stride 2.

    Args:
        base: number of feature maps of the first block (16 in the exercise).
    """

    def __init__(self, base):
        super(ModalityNet, self).__init__()
        # honour the constructor argument instead of overwriting it
        self.conv1 = self._block(1, base, stride=1)
        self.conv2 = self._block(base, 2 * base, stride=2)

    @staticmethod
    def _block(in_channels, out_channels, stride):
        """Conv3d -> InstanceNorm3d -> LeakyReLU building block."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU()
        )

    def forward(self, x):
        """Map a B x 1 x D x H x W volume to B x 2*base x D/2 x H/2 x W/2 features."""
        return self.conv2(self.conv1(x))
    
class SharedNet(nn.Module):
    """Shared feature extractor applied to the outputs of both ModalityNets.

    Three blocks of Conv3d(3x3x3, padding=1) -> InstanceNorm3d -> LeakyReLU
    (channels doubled and stride 2 in the second block), followed by a plain
    3x3x3 convolution producing ``out_channels`` maps.

    The output is NOT squashed here: FeatureNet applies the sigmoid itself,
    and a second sigmoid inside this module would compress all features into
    roughly (0.5, 0.73) instead of the intended (0, 1) range.

    Args:
        base: width of the first ModalityNet block; this network expects
            ``2 * base`` input channels (32 in the exercise).
        out_channels: number of output feature maps (64 in the exercise).
    """

    def __init__(self, base, out_channels):
        super(SharedNet, self).__init__()
        in_channels = 2 * base
        wide = 2 * in_channels

        self.conv1 = self._block(in_channels, in_channels, stride=1)
        self.conv2 = self._block(in_channels, wide, stride=2)
        self.conv3 = self._block(wide, wide, stride=1)
        # kept as a Sequential so that the parameter names stay 'feature.0.*'
        self.feature = nn.Sequential(
            nn.Conv3d(wide, out_channels, kernel_size=3, stride=1, padding=1)
        )

    @staticmethod
    def _block(in_channels, out_channels, stride):
        """Conv3d -> InstanceNorm3d -> LeakyReLU building block."""
        return nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
            nn.InstanceNorm3d(out_channels),
            nn.LeakyReLU()
        )

    def forward(self, x):
        """Return un-squashed ``out_channels`` feature maps at half the input resolution."""
        return self.feature(self.conv3(self.conv2(self.conv1(x))))

# This architecture is given.
class FeatureNet(nn.Module):
    """Two-branch feature network: one ModalityNet per input, one SharedNet for both.

    ``forward(x, y)`` sends x through ``modality1_net`` and y through
    ``modality2_net``, passes both results through the same ``shared_net``
    and returns the two sigmoid-squashed feature tensors.
    """

    def __init__(self):
        super().__init__()

        base, out_channels = 16, 64

        self.modality1_net = ModalityNet(base)
        self.modality2_net = ModalityNet(base)
        self.shared_net = SharedNet(base, out_channels)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, y):
        feat_x = self.shared_net(self.modality1_net(x))
        feat_y = self.shared_net(self.modality2_net(y))
        return self.sigmoid(feat_x), self.sigmoid(feat_y)
    


**Optional**: Plot the initial dice values

In [None]:
# Dice overlap of the (augmented) fixed and the moving segmentations before any
# registration. NOTE(review): get_dice_all is not defined in this notebook --
# presumably it comes from mdl_exercise3_utils; confirm there what it plots/returns.
get_dice_all(TRAIN_CASES,segs_fix_train,segs_mov_train)


In [None]:
# training loop
# One epoch visits every training case once (batch size 1) with a randomly
# chosen augmented version of the fixed scan. The network is supervised only
# by the mutual information between the fixed and the warped moving scan.

num_epochs = 30
init_lr = 0.001
device = 'cuda'

net = FeatureNet().to(device)
parameter_count(net)
optimizer = optim.Adam(net.parameters(), lr=init_lr)

scaler = torch.cuda.amp.GradScaler()
losses = torch.zeros(num_epochs)

for epoch in range(num_epochs):
    net.train()
    torch.cuda.synchronize()
    t0 = time.time()
    running_loss = 0
    rand_idx = torch.randperm(len(TRAIN_CASES))
    for idx in rand_idx:
        optimizer.zero_grad()

        # pick one of the 8 augmented fixed scans of this case
        rand_idx1 = torch.randint(8, (1,))[0]
        img_fix = imgs_fix_train[idx:idx+1, rand_idx1:rand_idx1+1].to(device,non_blocking=True)
        img_mov = imgs_mov_train[idx:idx+1].to(device,non_blocking=True)
        seg_fix = segs_fix_train[idx:idx+1, rand_idx1:rand_idx1+1].long().to(device,non_blocking=True)
        seg_mov = segs_mov_train[idx:idx+1].to(device,non_blocking=True).long()
        # the masks are stored as bool; grid_sample (used inside the correlation
        # layer) needs a floating point tensor, so convert once here
        mask_fix = masks_fix_train[idx:idx+1, rand_idx1:rand_idx1+1].to(device,non_blocking=True).float()
        mask_mov = masks_mov_train[idx:idx+1].to(device,non_blocking=True)

        with torch.cuda.amp.autocast(enabled=True):
            #call your own implementation of network architecture and correlation layer
            feat_fix, feat_mov = net(img_fix, img_mov)
            cost,kpts_fixed = correlation(mask_fix,feat_fix,feat_mov)

            #provided function for robust fitting
            R = robust_rigid_fit(cost,kpts_fixed,feat_fix)

        #mutual information requires 32bit precision
        with torch.cuda.amp.autocast(enabled=False):
            grid = F.affine_grid(R, (1,1,D,H,W))
            img_warped = F.grid_sample(img_mov,grid.float(),mode='bilinear')

            #call your own implementation for mutual information loss
            loss = mutual_inf(mask_fix,img_fix[0,0],img_warped[0,0])

        if(rand_idx1==4):
            # visual check: the warped labels are only needed for this plot, so
            # they are computed here and without building an autograd graph
            with torch.no_grad():
                seg_mov_warped = F.grid_sample(F.one_hot(seg_mov, 5).view(1, D, H, W, -1).permute(0, 4, 1, 2, 3).float(), grid.float(), mode='bilinear').argmax(1)
            plt.figure()
            # clamp the slice at its 100th largest intensity for display
            q100 = float(torch.topk(img_fix[0,0,:,60,:].reshape(-1),100)[0].cpu().data[-1:])
            gray1 = torch.clamp(img_fix[0,0,:,60,:].data.cpu().t().flip([0,1]),0,q100)/q100
            rgb = overlaySegment(gray1,seg_mov_warped[0,:,60,:].long().data.cpu().t().flip([0,1]))
            plt.imshow(rgb)
            plt.show()

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    # mean loss over the cases of this epoch
    running_loss /= len(TRAIN_CASES)
    losses[epoch] = running_loss
    torch.cuda.synchronize()
    t1 = time.time()

    print('epoch (train): {:02d} -- loss: {:.3f} -- time(s): {:.1f}'.format(epoch, running_loss, t1-t0))
    gpu_usage()

Plot the loss and store network weights:

In [None]:
# fold id that is part of the checkpoint file name
FOLD = 3

# per-epoch mean training loss
plt.plot(losses)

# net.cpu() moves the module itself to the CPU before its weights are written
torch.save(net.cpu().state_dict(), 'net_mi_fold{}.pth'.format(FOLD))

**Evaluation**: Run the following code for a final evaluation. Your Dice score should increase from 43% to above 60%.

In [None]:
#quantitative evaluation of the trained network, should return around 60% Dice after registration
# Every case is evaluated 4 times, each time with a new random rigid
# transform applied to the fixed scan.
net = FeatureNet().to(device)
net.load_state_dict(torch.load('net_mi_fold{}.pth'.format(FOLD)))
net.eval()
parameter_count(net)

torch.manual_seed(30)
TEST_CASES = TRAIN_CASES
with torch.no_grad():
    dice_init_all = torch.zeros(4,len(TEST_CASES),4)
    dice_all = torch.zeros(4,len(TEST_CASES),4)
    for i in range(4):
        for j, case in enumerate(TEST_CASES):
            img_fix, img_mov, seg_fix, seg_mov, mask_fix, mask_mov = load_case(case)

            img_fix = img_fix.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
            img_mov = img_mov.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
            seg_fix = seg_fix.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
            seg_mov = seg_mov.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
            mask_fix = mask_fix.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)
            mask_mov = mask_mov.to(device,non_blocking=True).unsqueeze(0).unsqueeze(0)

            # random rigid misalignment of the fixed scan, its labels and its mask
            R = generate_random_rigid_3d()
            grid = F.affine_grid(R[:3,:4].unsqueeze(0), (1,1,D,H,W))
            img_fix_ = F.grid_sample(img_fix, grid)
            seg_fix_ = F.grid_sample(F.one_hot(seg_fix[0, 0]).permute(3, 0, 1, 2).unsqueeze(0).float(), grid).argmax(1, keepdim=True)
            mask_fix_ = (F.grid_sample(mask_fix.float(), grid)>0.5).float()
            with torch.cuda.amp.autocast():
                feat_fix, feat_mov = net(img_fix_.contiguous(), img_mov.contiguous())
                ssd_cost,kpts_fix = correlation(mask_fix_,feat_fix,feat_mov)
            # robust_rigid_fit takes the fixed features as third argument
            # (for device/dtype); calling it without them raises a TypeError
            R = robust_rigid_fit(ssd_cost,kpts_fix,feat_fix)

            grid = F.affine_grid(R, (1,1,D,H,W))
            seg_mov_warped = F.grid_sample(seg_mov.float(), grid, mode='nearest')

            # Dice: original fixed vs moving, transformed fixed vs moving (initial),
            # transformed fixed vs registered moving (result)
            d = dice_coeff(seg_fix.cpu(),seg_mov.cpu(),5); print(d,d.mean())
            d0 = dice_coeff(seg_fix_.cpu(),seg_mov.cpu(),5); print(d0,d0.mean())
            d1 = dice_coeff(seg_fix_.cpu(),seg_mov_warped.cpu(),5); print(d1,d1.mean())
            print()

            dice_init_all[i, j] = d0
            dice_all[i, j] = d1

In [None]:
# Report the average Dice before and after registration. Both sums are
# divided by the number of entries whose initial Dice is greater than zero.
num_valid = (dice_init_all>0).sum()

print('Initial dice: ', (dice_init_all.sum()/num_valid).item())
print('Dice after reg:', (dice_all.sum()/num_valid).item())