# Tutorial 4
## May 24, 2023
In the previous tutorials, you familiarized yourself with PyTorch, MONAI, and Weights & Biases. In last week's lectures 4 and 5, you heard about image reconstruction and image registration with (convolutional) neural networks. This week, you again get the chance to put what you have learned into practice. The tutorial consists of two parts. First, you will develop, train, and evaluate a CNN for denoising (synthetic) CT images. Second, you will develop, train, and evaluate a CNN that learns to perform deformable image registration on the chest X-ray images that we also used in the third tutorial. Along the way, there will be questions (❓) and <b style='background-color:rgba(80,255,80,0.4); padding:2px'>
exercises.</b>

First, let's take care of the necessities:
- If you're using Google Colab, make sure to select a GPU Runtime.
- Connect to Weights & Biases using the code below.
- Install a few libraries that we will use in this tutorial.

In [None]:
import os
import wandb

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
wandb.login()

In [None]:
!pip install dival
!pip install kornia
!pip install monai

## Part 1: Reconstruction
In the first part of this tutorial, you will reconstruct CT images. To save disk space, we will synthesize images on the fly using the Deep Inversion Validation Library [(dival)](https://github.com/jleuschn/dival). These are 2D images of $128\times 128$ pixels that contain a random number of ellipses with random sizes and random intensities. 

First, we create the ellipses dataset. This returns an object from which we can draw samples through a generator. Next, we look at what this dataset contains: we use the <code>generator</code> to request a sample. Each sample contains a sinogram and a ground-truth (original) synthetic image, both of which we can visualize. You may recall from the lecture that the sinogram is made up of line integrals along projections. The horizontal axis of the sinogram corresponds to the location $s$ along the detector, the vertical axis to the projection angle $\theta$.

<img src="https://upload.wikimedia.org/wikipedia/commons/0/0c/Tomographic_fig1.png" width="400px"></img>
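To make the link between line integrals and the sinogram concrete, here is a minimal NumPy sketch of a parallel-beam Radon transform with nearest-neighbour sampling. It is a teaching aid under stated assumptions, not dival's or scikit-image's implementation: the helper names `ellipse_phantom` and `radon_nn` are hypothetical, the angle and orientation conventions are our own, and real implementations interpolate more carefully. Each row of the result is one projection at angle $\theta$ and each column one detector position $s$, the same layout as the sinogram plotted in the next cell.

```python
import numpy as np

def ellipse_phantom(n=64):
    """Hypothetical test image: a dim ellipse containing a small bright one."""
    c = (n - 1) / 2
    y, x = np.mgrid[0:n, 0:n] - c            # pixel coordinates relative to the centre
    img = np.zeros((n, n))
    img[(x / (0.4 * n)) ** 2 + (y / (0.3 * n)) ** 2 <= 1] = 0.5
    img[((x - 5) / (0.1 * n)) ** 2 + ((y + 3) / (0.15 * n)) ** 2 <= 1] = 1.0
    return img

def radon_nn(img, theta_deg):
    """Parallel-beam Radon transform with nearest-neighbour sampling (sketch).

    Returns an array of shape (len(theta_deg), n): one row per angle theta,
    one column per detector position s.
    """
    n = img.shape[0]
    c = (n - 1) / 2
    s = np.arange(n) - c                       # detector positions
    t = np.arange(n) - c                       # sample positions along each ray
    S, T = np.meshgrid(s, t, indexing='ij')    # S varies along axis 0, T along axis 1
    sino = np.zeros((len(theta_deg), n))
    for k, th in enumerate(np.deg2rad(theta_deg)):
        # Ray at offset s with direction (-sin th, cos th), sampled at steps t
        x = S * np.cos(th) - T * np.sin(th)
        y = S * np.sin(th) + T * np.cos(th)
        j = np.rint(x + c).astype(int)         # nearest column index
        i = np.rint(y + c).astype(int)         # nearest row index
        inside = (i >= 0) & (i < n) & (j >= 0) & (j < n)
        vals = np.zeros_like(x)
        vals[inside] = img[i[inside], j[inside]]
        sino[k] = vals.sum(axis=1)             # line integral: sum over t
    return sino

img = ellipse_phantom(64)
theta = np.linspace(0., 180., 60, endpoint=False)
sino = radon_nn(img, theta)
print(sino.shape)  # (60, 64)
```

A quick sanity check: at $\theta = 0°$ each projection value is a column sum of the image, and at $\theta = 90°$ a row sum.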

In [None]:
import dival

dataset = dival.get_standard_dataset('ellipses', impl='skimage')
dat_gen = dataset.generator(part='train')

Run the cell below to show a sinogram and image in the dataset.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Get a sample from the generator
sinogram, ground_truth = next(dat_gen)
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

# Show the sinogram
axs[0].imshow(sinogram, cmap='gray', extent=[0, 183, -90, 90])
axs[0].set_title('Sinogram')
axs[0].set_xlabel('$s$')
axs[0].set_ylabel(r'$\theta$')

# Show the ground truth image
axs[1].imshow(ground_truth, cmap='gray')
axs[1].set_title('Ground truth')
axs[1].set_xlabel('$x$')
axs[1].set_ylabel('$y$')
plt.show()   

> ❓ What kind of CT reconstruction problem is this? Limited-view or sparse-angle CT? Why?

Answer:

Not only does the sinogram contain few angles, it also contains added white noise. If we simply backproject the sinogram to the image domain we end up with a low-quality image. Let's give it a try using the standard [Filtered Backprojection](https://en.wikipedia.org/wiki/Radon_transform#Reconstruction_approaches) (FBP) algorithm for CT and its implementation in [scikit-image](https://scikit-image.org/).
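The "filtered" in FBP refers to a ramp filter $|\omega|$ applied to each projection along the detector axis before backprojection; it compensates for the blurring that plain backprojection introduces. Below is a minimal NumPy sketch of only that filtering step, assuming a sinogram laid out as (angles, detector positions). The helper `ramp_filter` is hypothetical, and it ignores details such as the zero-padding that scikit-image's `iradon` adds before filtering.

```python
import numpy as np

def ramp_filter(sinogram):
    """Apply an ideal ramp filter |omega| to every projection (row) of a sinogram.

    sinogram: array of shape (num_angles, num_detectors).
    """
    n = sinogram.shape[-1]
    freqs = np.fft.fftfreq(n)                  # frequency of each FFT bin, in cycles per sample
    spectrum = np.fft.fft(sinogram, axis=-1)   # 1D FFT along the detector axis
    return np.real(np.fft.ifft(spectrum * np.abs(freqs), axis=-1))

# Projection of a single bright point: after filtering, the peak stays positive,
# negative side lobes appear next to it and the mean (zero frequency) is removed.
proj = np.zeros((1, 32))
proj[0, 16] = 1.0
filtered = ramp_filter(proj)
print(filtered[0, 14:19].round(3))
```

The negative side lobes are what cancel the overlapping "smear" of neighbouring backprojected rays; without them, plain backprojection produces a strongly blurred image.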

In [None]:
import skimage.transform as sktr

# Get a sample from the generator
sinogram, ground_truth = next(dat_gen)
sinogram = np.asarray(sinogram).transpose()

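# This defines the projection angles
theta = np.linspace(-90., 90., sinogram.shape[1], endpoint=True)

# Perform FBP. iradon returns a square image as wide as the detector (183 pixels here);
# we keep its central 128x128 region so that it matches the ground-truth image.
fbp_recon = sktr.iradon(sinogram, theta=theta, filter_name='ramp')[28:-27, 28:-27]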
fig, axs = plt.subplots(1, 3, figsize=(12, 4))
axs[0].imshow(sinogram.transpose(), cmap='gray', extent=[0, 183, -90, 90])
axs[0].set_title('Sinogram')
axs[0].set_xlabel('$s$')
axs[0].set_ylabel(r'$\theta$')
axs[1].imshow(ground_truth, cmap='gray', clim=[0, 1])
axs[1].set_title('Ground truth')
axs[1].set_xlabel('$x$')
axs[1].set_ylabel('$y$')
axs[2].imshow(fbp_recon, cmap='gray', clim=[0, 1])
axs[2].set_title('FBP')
axs[2].set_xlabel('$x$')
axs[2].set_ylabel('$y$')
plt.show()

<div style='background-color:rgba(80,255,80,0.4); padding:20px'>
   ⌨ <b>Exercise</b>: What do you think of the quality of the FBP reconstruction? Use the cell below to quantify the similarity between the images using the structural similarity index (SSIM). Does this reflect your intuition? Also compute the peak signal-to-noise ratio (PSNR) using the <a href="https://scikit-image.org/docs/stable/api/skimage.metrics.html#skimage.metrics.peak_signal_noise_ratio"><code>peak_signal_noise_ratio</code></a> method in scikit-image.
</div>

In [None]:
import skimage.metrics as skme

print('SSIM = {:.2f}'.format(skme.structural_similarity(np.asarray(ground_truth), fbp_recon, data_range=np.max(ground_truth)-np.min(ground_truth))))
# ⌨ FILL IN
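
For intuition about what PSNR measures, it can be written directly from its definition, $\mathrm{PSNR} = 10 \log_{10}(R^2 / \mathrm{MSE})$, where $R$ is the data range of the images. The helper `psnr` below is a hypothetical sketch of that formula; for the exercise, use scikit-image's `peak_signal_noise_ratio`. Higher is better, and identical images give an infinite PSNR because the MSE is zero.

```python
import numpy as np

def psnr(reference, estimate, data_range):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range**2 / MSE)."""
    reference = np.asarray(reference, dtype=float)
    estimate = np.asarray(estimate, dtype=float)
    mse = np.mean((reference - estimate) ** 2)
    return 10 * np.log10(data_range ** 2 / mse)

reference = np.zeros((128, 128))
estimate = reference + 0.1         # constant error of 0.1 everywhere -> MSE = 0.01
print(psnr(reference, estimate, data_range=1.0))  # ≈ 20 dB
```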

### Datasets and dataloaders

Our (or your) goal now is to obtain higher-quality reconstructed images from the sinogram measurements. As you have seen in the lecture, there are four main ways to do this:
1. **Direct reconstruction** Train a method that maps directly from the measurement (sinogram) domain to the image domain.
2. **Preprocessing** Clean up the sinogram using a neural network, then backproject to the image domain.
3. **Postprocessing** First backproject to the image domain, then improve the reconstruction using a neural network.
4. **Iterative methods** Use methods that integrate data consistency into the reconstruction.

Here, we will follow the third approach, postprocessing. We reconstruct images from the generated sinograms using filtered backprojection and train a neural network to learn corrections to these FBP images, as shown in the image below. To train this network, we need the FBP reconstructions and the corresponding ground-truth images from the dival dataset. 
<img src='https://imgur.com/df4RYzE.png'></img>

We will make a training dataset of 512 samples from the ellipses dival dataset and store it in a MONAI <code>CacheDataset</code>. The code below does this in four steps:
1. Create a dival generator that yields sinograms and ground-truth images.
2. Make a dictionary (like we did in the previous tutorial) that contains the ground-truth images and the FBP reconstructions as separate keys.
3. Define the transforms for the data (also like in the previous tutorial). In this case, we require an additional channel dimension, as the neural network expects one. We do not use any data augmentation.
4. Construct the dataset using the dictionary and the defined transform.

In [None]:
import tqdm
import monai

theta = np.linspace(-90., 90., sinogram.shape[1], endpoint=True)

# Make a generator for the training part of the dataset
train_gen = dataset.generator(part='train')
train_samples = []

# Make a list of (in this case) 512 random training samples. We store the filtered backprojection (FBP) and ground truth image
# in a dictionary for each sample, and add these to a list.
for ns in tqdm.tqdm(range(512)):
    sinogram, ground_truth = next(train_gen)
    sinogram = np.asarray(sinogram).transpose()
    fbp_recon = sktr.iradon(sinogram, theta=theta, filter_name='ramp')[28:-27, 28:-27]
    train_samples.append({'fbp': fbp_recon, 'ground_truth': np.asarray(ground_truth)})

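# You can add or remove transforms here. EnsureChannelFirstd with channel_dim='no_channel' adds
# the channel dimension that the network expects (older MONAI versions used AddChanneld for this)
train_transform = monai.transforms.Compose([
    monai.transforms.EnsureChannelFirstd(keys=['fbp', 'ground_truth'], channel_dim='no_channel')
])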

# Use the list of dictionaries and the transform to initialize a MONAI CacheDataset
train_dataset = monai.data.CacheDataset(train_samples, transform=train_transform)    

<div style='background-color:rgba(80,255,80,0.4); padding:20px'>
    ⌨ <b>Exercise</b>: Also make a validation dataset and call it <code>val_dataset</code>. This dataset can be smaller, e.g., 64 or 128 samples.

</div>

In [None]:
# Your code goes here

<div style='background-color:rgba(80,255,80,0.4); padding:20px'>
    ⌨ <b>Exercise</b>: Now, make a dataloader for both the validation and training data, called <code>train_loader</code> and <code>validation_loader</code>, that we can use for sampling batches during training of the network. Give them a reasonable batch size, e.g., 16.
</div>

In [None]:
# ⌨️ FILL IN
train_loader = ...
validation_loader = ...

### Model
Now that we have datasets and dataloaders, the next step is to define a model, optimizer and criterion. Because we want to improve the FBP-reconstructed image, we are dealing with an image-to-image task. A standard U-Net as implemented in MONAI is therefore a good starting point. First, make sure that you are using the GPU (CUDA), otherwise training will be extremely slow.

In [None]:
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = "cpu"
print(f'The used device is {device}')

<div style='background-color:rgba(80,255,80,0.4); padding:20px'>
    ⌨ <b>Exercise</b>: Initialize a U-Net with the correct settings, e.g. channels and dimensions, and call it <code>model</code>. Here, it's convenient to use the <a href="https://docs.monai.io/en/stable/networks.html#monai.networks.nets.BasicUNet"><code>BasicUNet</code></a> as implemented in MONAI.

</div>

### Loss function
An important aspect is the loss function that you will use to optimize the model. The problem that we are trying to solve using a neural network is a *regression* problem, which differs from the *classification* approach we covered in the segmentation tutorial. Instead of classifying each pixel as a certain class, we alter their intensities to obtain a better overall reconstruction of the image. 

Because this task is substantially different, we need a different loss function. In the previous tutorial we used the Dice loss, which measures the overlap for each of the classes to segment. Here, an L2 (mean squared error) or L1 (mean absolute error) loss suits our objective. Alternatively, we can use a loss that aims to maximize the structural similarity (SSIM). For this, we use the [kornia](https://kornia.readthedocs.io/en/latest/) library.
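To see how these three objectives differ, here is a small NumPy sketch on synthetic images. The helpers are hypothetical: `global_ssim` computes SSIM over the whole image as a single window (scikit-image and kornia instead average SSIM over small local windows), and `ssim_loss` uses the structural dissimilarity form $(1 - \mathrm{SSIM})/2$ that kornia documents for its SSIM loss. The two corrupted images have the same MSE, yet L1 penalizes the error spread over all pixels eight times more than the error concentrated in a few pixels.

```python
import numpy as np

def mse_loss(pred, target):
    """L2 loss: mean squared error."""
    return np.mean((pred - target) ** 2)

def l1_loss(pred, target):
    """L1 loss: mean absolute error."""
    return np.mean(np.abs(pred - target))

def global_ssim(x, y, data_range=1.0):
    """SSIM with the whole image as one window (simplification of local-window SSIM)."""
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    mu_x, mu_y = x.mean(), y.mean()
    cov_xy = np.mean((x - mu_x) * (y - mu_y))
    return ((2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)) / (
        (mu_x ** 2 + mu_y ** 2 + c1) * (x.var() + y.var() + c2))

def ssim_loss(pred, target):
    """Structural dissimilarity (1 - SSIM) / 2: 0 for identical images."""
    return (1 - global_ssim(pred, target)) / 2

rng = np.random.default_rng(0)
target = rng.random((64, 64))
spread = target + 0.1          # error of 0.1 on all 4096 pixels
sparse = target.copy()
sparse[:8, :8] += 0.8          # error of 0.8 on only 64 pixels

for name, pred in [("spread", spread), ("sparse", sparse)]:
    print(name, mse_loss(pred, target), l1_loss(pred, target), ssim_loss(pred, target))
```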

In [None]:
import kornia 

# Three loss functions, turn them on or off by commenting

loss_function = torch.nn.MSELoss()
# loss_function = torch.nn.L1Loss()
# loss_function = kornia.losses.SSIMLoss(window_size=3)

As in previous tutorials, we use the Adam optimizer, an adaptive variant of stochastic gradient descent, to train our network. In this tutorial, we add a [learning rate scheduler](https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html). This scheduler multiplies the learning rate by *gamma* every *step_size* calls to `scheduler.step()`, which we make once per epoch. As a result, the optimizer takes smaller steps in the direction of the gradient later in training, which can help it settle into a better local minimum for the weights of the neural network.
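The resulting schedule has a simple closed form, $\mathrm{lr}(e) = \mathrm{lr}_0 \cdot \gamma^{\lfloor e / \text{step\_size} \rfloor}$. The hypothetical helper below evaluates it with the settings used in the next cell, assuming `scheduler.step()` is called once at the end of every epoch.

```python
def step_lr(epoch, base_lr=1e-3, step_size=50, gamma=0.1):
    """Learning rate during `epoch` (counting from 0) for a StepLR-style schedule,
    assuming scheduler.step() is called once at the end of every epoch."""
    return base_lr * gamma ** (epoch // step_size)

for epoch in [0, 49, 50, 74]:
    print(f"epoch {epoch:3d}: lr = {step_lr(epoch):.0e}")
```

With the 75 epochs used in the training loop below, the learning rate is therefore lowered only once, at epoch 50.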

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

<div style='background-color:rgba(80,255,80,0.4); padding:20px'>
    ⌨ <b>Exercise</b>: Complete the code below and train the U-Net.
    
❓ What does the model learn? Look carefully at how we determine the output of the model. Can you describe what happens in the following line: <code>outputs = model(batch_data['fbp'].float().to(device)) + batch_data["fbp"].float().to(device)</code>?

</div>

In [None]:
from tqdm.notebook import tqdm
import wandb
from skimage.metrics import structural_similarity as ssim


run = wandb.init(
    project='tutorial4_reconstruction',
    name='test',
    config={
        'loss function': str(loss_function), 
        'lr': optimizer.param_groups[0]["lr"],
        'batch_size': train_loader.batch_size,
    }
)
# Do not hesitate to enrich this list of settings to be able to correctly keep track of your experiments!
# For example you should include information on your model architecture

run_id = run.id # We remember here the run ID to be able to write the evaluation metrics

def log_to_wandb(epoch, train_loss, val_loss, batch_data, outputs):
    """ Function that logs ongoing training variables to W&B """

    # Mean SSIM between ground truth and network output over the last validation batch
    gt_batch = batch_data['ground_truth'].detach().cpu().numpy()
    out_batch = outputs.detach().cpu().numpy()
    val_ssim = []
    for im_id in range(gt_batch.shape[0]):
        gt = gt_batch[im_id, 0, :, :]
        val_ssim.append(ssim(gt, out_batch[im_id, 0, :, :], data_range=gt.max() - gt.min()))
    val_ssim = np.mean(np.asarray(val_ssim))
    # Send epoch, losses and images to W&B
    wandb.log({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss, 'val_ssim': val_ssim}) 
    
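for epoch in tqdm(range(75)):
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        optimizer.zero_grad()
        outputs = model(batch_data["fbp"].float().to(device)) + batch_data["fbp"].float().to(device)
        # FILL IN: compute the loss w.r.t. batch_data["ground_truth"], backpropagate, update the
        # weights and accumulate epoch_loss. Define train_loss (mean loss over the epoch) after the loop.
    # validation part: evaluation mode, no gradients needed
    model.eval()
    step = 0
    val_loss = 0
    with torch.no_grad():
        for batch_data in validation_loader:
            step += 1
            outputs = model(batch_data['fbp'].float().to(device)) + batch_data["fbp"].float().to(device)
            # FILL IN: accumulate the validation loss and define val_loss as its mean over the batches
    log_to_wandb(epoch, train_loss, val_loss, batch_data, outputs)
    # The scheduler makes one step per epoch
    scheduler.step()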

# Store the network parameters        
torch.save(model.state_dict(), r'trainedUNet.pt')
run.finish()

<div style='background-color:rgba(80,255,80,0.4); padding:20px'>
    ⌨ <b>Exercise</b>: Now make a <code>Dataset</code> and <code>DataLoader</code> for the test set. Just a handful of images should be enough.

</div>

In [None]:
import tqdm

test_gen = dataset.generator(part='test')
# ⌨️ FILL IN: build a list of {'fbp': ..., 'ground_truth': ...} samples, as for the training set
test_dataset = ...

test_loader = monai.data.DataLoader(test_dataset, batch_size=1)

> ❓ Visualize a number of reconstructions from the neural network and compare them to the FBP reconstructions using the code below. Performance is evaluated with the structural similarity [function](https://scikit-image.org/docs/stable/api/skimage.metrics.html#skimage.metrics.structural_similarity) in scikit-image. Does the neural network improve this metric much compared to filtered backprojection?

In [None]:
model.eval()

with torch.no_grad():  # no gradients are needed for inference
    for test_sample in test_loader:
        fbp = test_sample['fbp'].float().to(device)  # the network expects float32 input
        output = model(fbp) + fbp
        output = output.cpu().numpy()[0, 0, :, :].squeeze()
        ground_truth = test_sample['ground_truth'][0, 0, :, :].squeeze().cpu().numpy()
        fbp_recon = test_sample['fbp'][0, 0, :, :].squeeze().cpu().numpy()
        data_range = ground_truth.max() - ground_truth.min()
        fig, axs = plt.subplots(1, 3, figsize=(12, 4))
        axs[0].imshow(fbp_recon, cmap='gray', clim=[0, 1])
        axs[0].set_title('FBP SSIM={:.2f}'.format(ssim(ground_truth, fbp_recon, data_range=data_range)))
        axs[0].set_xlabel('$x$')
        axs[0].set_ylabel('$y$')
        axs[1].imshow(ground_truth, cmap='gray', clim=[0, 1])
        axs[1].set_title('Ground truth')
        axs[1].set_xlabel('$x$')
        axs[1].set_ylabel('$y$')
        axs[2].imshow(output, cmap='gray', clim=[0, 1])
        axs[2].set_title('CNN SSIM={:.2f}'.format(ssim(ground_truth, output, data_range=data_range)))
        axs[2].set_xlabel('$x$')
        axs[2].set_ylabel('$y$')
        plt.show()

<div style='background-color:rgba(80,255,80,0.4); padding:20px'>
⌨ <b>Exercise</b>: 
Instead of a U-Net, try a different model, e.g., a <a href="https://docs.monai.io/en/stable/networks.html#segresnet">SegResNet</a> in MONAI.
Evaluate how the different loss functions affect the performance of the network. Note that the SSIM on the validation set is also written to Weights & Biases during training. Which loss leads to the best SSIM scores? Which loss results in the worst SSIM scores?
    </div>

Answer:

## Part 2: Registration

In [None]:
import monai
import numpy as np
import matplotlib.pyplot as plt
import torch
import wandb

In the second part of the tutorial, we will register chest X-ray images. We will reuse the data of Tutorial 3. As always, we first set the paths. This should be the path ending in 'ribs'.

In [None]:
# ONLY IF YOU USE JUPYTER: ADD PATH ⌨️
data_path = r'/Users/jmwolterink/Downloads/ribs'# WHEREDIDYOUPUTTHEDATA?

In [None]:
# ONLY IF YOU USE COLAB: ADD PATH ⌨️
from google.colab import drive

drive.mount('/content/drive')
data_path = r'/content/drive/My Drive/Tutorial3'

In [None]:
# check if data_path exists:
import os

if not os.path.exists(data_path):
    print("Please update your data path to an existing folder.")
elif not set(["train", "val", "test"]).issubset(set(os.listdir(data_path))):
    print("Please update your data path to the correct folder (should contain train, val and test folders).")
else:
    print("Congrats! You selected the correct folder :)")

### Data management

In this part we prepare all the tools needed to load and visualize our samples. One thing we *could* do is perform **inter**-patient registration, i.e., register two chest X-ray images of different patients. However, this is a very challenging problem. Instead, to make our life a bit easier, we will perform **intra**-patient registration: register two images of the same patient. For each patient, we make a synthetic moving image by applying random elastic deformations. To build this data set, we use the [Rand2DElasticd](https://docs.monai.io/en/stable/transforms.html#rand2delastic) transform on both the image and the mask. We will use a neural network to learn the deformation field between the fixed image and the moving image.
<img src='https://i.imgur.com/OmoOZ5w.png'></img>

As in Tutorial 3, make a dictionary of the image file names.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import glob
import monai
from PIL import Image
import torch

def build_dict_ribs(data_path, mode='train'):
    """
    This function returns a list of dictionaries, each containing the keys 'fixed', 'moving',
    'fixed_mask' and 'moving_mask' with the paths to the corresponding files. The fixed and moving
    entries initially point to the same image; the moving image is deformed later by a transform.
    
    Args:
        data_path (str): path to the root folder of the data set.
        mode (str): subset used. Must correspond to 'train', 'val' or 'test'.
        
    Returns:
        (List[Dict[str, str]]) list of the dictionaries containing the paths of X-ray images and masks.
    """
    # test if mode is correct
    if mode not in ["train", "val", "test"]:
        raise ValueError(f"Please choose a mode in ['train', 'val', 'test']. Current mode is {mode}.")
    
    # define empty list
    dicts = []
    # list all .png files in directory, including the path
    paths_xray = glob.glob(os.path.join(data_path, mode, 'img', '*.png'))
    # make a corresponding list for all the mask files
    for xray_path in paths_xray:
        if mode == 'test':
            suffix = 'val'
        else:
            suffix = mode
        # find the binary mask that belongs to the original image, based on indexing in the filename
        image_index = os.path.split(xray_path)[1].split('_')[-1].split('.')[0]
        # define path to mask file based on this index and add to list of mask paths
        mask_path = os.path.join(data_path, mode, 'mask', f'VinDr_RibCXR_{suffix}_{image_index}.png')
        if os.path.exists(mask_path):
            dicts.append({'fixed': xray_path, 'moving': xray_path, 'fixed_mask': mask_path, 'moving_mask': mask_path})
    return dicts

class LoadRibData(monai.transforms.Transform):
    """
    This custom Monai transform loads the data from the rib segmentation dataset.
    Defining a custom transform is simple: just override the __init__ and __call__ functions.
    """
    def __init__(self, keys=None):
        pass

    def __call__(self, sample):
        fixed = Image.open(sample['fixed']).convert('L') # import as grayscale image
        fixed = np.array(fixed, dtype=np.uint8)
        moving = Image.open(sample['moving']).convert('L') # import as grayscale image
        moving = np.array(moving, dtype=np.uint8)        
        fixed_mask = Image.open(sample['fixed_mask']).convert('L') # import as grayscale image
        fixed_mask = np.array(fixed_mask, dtype=np.uint8)
        moving_mask = Image.open(sample['moving_mask']).convert('L') # import as grayscale image
        moving_mask = np.array(moving_mask, dtype=np.uint8)        
        # mask has value 255 on rib pixels. Convert to binary array
        fixed_mask[np.where(fixed_mask==255)] = 1
        moving_mask[np.where(moving_mask==255)] = 1        
        return {'fixed': fixed, 'moving': moving, 'fixed_mask': fixed_mask, 'moving_mask': moving_mask, 'img_meta_dict': {'affine': np.eye(2)}, 
                'mask_meta_dict': {'affine': np.eye(2)}}

Then we make a training dataset like before. The <code>Rand2DElasticd</code> transform here determines how much deformation is in the 'moving' image. 
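Roughly speaking, `spacing` sets the distance in pixels between control points of a coarse grid, and `magnitude_range` the range from which random offsets at those points are drawn; the offsets are then interpolated to a dense, smooth displacement field. The NumPy sketch below illustrates that idea with bilinear interpolation between control points. It is a simplification, not MONAI's implementation, which interpolates the control grid differently and can add random affine components; the helper name `random_elastic_ddf` is hypothetical.

```python
import numpy as np

def random_elastic_ddf(shape=(256, 256), spacing=(64, 64), magnitude_range=(-8, 8), seed=0):
    """Hypothetical helper: smooth random displacement field of shape (2, H, W).

    Random offsets are drawn at control points every `spacing` pixels and
    bilinearly interpolated to every pixel.
    """
    rng = np.random.default_rng(seed)
    h, w = shape
    # Control-point grid covering the whole image (extra points at the far border)
    gh, gw = h // spacing[0] + 2, w // spacing[1] + 2
    coarse = rng.uniform(magnitude_range[0], magnitude_range[1], size=(2, gh, gw))
    # Pixel positions expressed in control-grid units
    ys = np.arange(h) / spacing[0]
    xs = np.arange(w) / spacing[1]
    y0, x0 = np.floor(ys).astype(int), np.floor(xs).astype(int)
    wy, wx = (ys - y0)[:, None], (xs - x0)[None, :]
    ddf = np.empty((2, h, w))
    for c in range(2):                      # one displacement component at a time
        g = coarse[c]
        top = g[y0][:, x0] * (1 - wx) + g[y0][:, x0 + 1] * wx
        bottom = g[y0 + 1][:, x0] * (1 - wx) + g[y0 + 1][:, x0 + 1] * wx
        ddf[c] = top * (1 - wy) + bottom * wy
    return ddf

ddf = random_elastic_ddf()
print(ddf.shape, np.abs(ddf).max())
```

With control points 64 pixels apart and offsets of at most 8 pixels, neighbouring pixels move by almost the same amount, so the deformation is smooth rather than noisy.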

In [None]:
train_dict_list = build_dict_ribs(data_path, mode='train')

# constructDataset from list of paths + transform
transform = monai.transforms.Compose(
[
    LoadRibData(),
    monai.transforms.EnsureChannelFirstd(keys=['fixed', 'moving', 'fixed_mask', 'moving_mask'], channel_dim='no_channel'),
    monai.transforms.Resized(keys=['fixed', 'moving', 'fixed_mask', 'moving_mask'], spatial_size=(256, 256),  mode=['bilinear', 'bilinear', 'nearest', 'nearest']),
    monai.transforms.HistogramNormalized(keys=['fixed', 'moving']),
    monai.transforms.ScaleIntensityd(keys=['fixed', 'moving'], minv=0.0, maxv=1.0),
    monai.transforms.Rand2DElasticd(keys=['moving', 'moving_mask'], spacing=(64, 64), 
                                    magnitude_range=(-8, 8), prob=1, mode=['bilinear', 'nearest']),    
])
train_dataset = monai.data.Dataset(train_dict_list, transform=transform)

<div style='background-color:rgba(80,255,80,0.4); padding:20px'>
  ⌨ <b>Exercise</b>: Visualize fixed and moving training images together with their comparison image using <code>visualize_fmc_sample</code>.
    <p>Try different methods to create the comparison image. How well does each method allow you to qualitatively assess the quality of the registration?</p>
     <p>More information on these methods is available in the <a href="https://scikit-image.org/docs/stable/api/skimage.util.html#skimage.util.compare_images">scikit-image documentation</a>.</p>

</div>
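The three comparison methods are simple pixel-wise operations. The hypothetical helper `compare` below is a NumPy sketch in the spirit of `skimage.util.compare_images`; scikit-image's handling of tiles at the image border may differ slightly.

```python
import numpy as np

def compare(img1, img2, method="checkerboard", n_tiles=(8, 8)):
    """Sketch of three ways to combine two images into one comparison image."""
    img1 = np.asarray(img1, dtype=float)
    img2 = np.asarray(img2, dtype=float)
    if method == "diff":
        return np.abs(img1 - img2)          # bright where the images disagree
    if method == "blend":
        return 0.5 * (img1 + img2)          # misalignment shows up as ghosting
    if method == "checkerboard":
        h, w = img1.shape
        ty, tx = h // n_tiles[0], w // n_tiles[1]
        rows = (np.arange(h) // ty)[:, None]
        cols = (np.arange(w) // tx)[None, :]
        mask = (rows + cols) % 2 == 0       # True on "even" tiles
        return np.where(mask, img1, img2)   # edges should continue across tile borders
    raise ValueError(f"Unknown method {method}")

a, b = np.zeros((16, 16)), np.ones((16, 16))
print(compare(a, b, "checkerboard", n_tiles=(4, 4))[::4, ::4])
```

For a well-registered pair, the difference image is dark, the blend looks sharp, and anatomical edges such as the ribs run continuously across the checkerboard tile borders.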

In [None]:
def visualize_fmc_sample(sample, method="checkerboard"):
    """
    Plot three images: fixed, moving and comparison.
    
    Args:
        sample (dict): a sample from a dataset built with `build_dict_ribs` and `monai.data.Dataset`.
        method (str): method used by `skimage.util.compare_images`.
    """
    import skimage.util as skut 
    
    skut_methods = ["diff", "blend", "checkerboard"]
    if method not in skut_methods:
        raise ValueError(f"Method must be chosen in {skut_methods}.\n"
                         f"Current value is {method}.")
    
    
    fixed = np.squeeze(sample['fixed'])
    moving = np.squeeze(sample['moving'])
    comp_checker = skut.compare_images(fixed, moving, method=method)
    axs = plt.figure(constrained_layout=True, figsize=(15, 5)).subplot_mosaic("FMC")
    axs['F'].imshow(fixed, cmap='gray')
    axs['F'].set_title('Fixed')
    axs['M'].imshow(moving, cmap='gray')
    axs['M'].set_title('Moving')
    axs['C'].imshow(comp_checker, cmap='gray')
    axs['C'].set_title('Comparison')
    plt.show()

In [None]:
sample = train_dataset[0]
for method in ["diff", "blend", "checkerboard"]:
    print(f"Method {method}")
    visualize_fmc_sample(sample, method=method)

Now we apply a little trick. Because applying the random deformation in each training iteration would be very costly, we apply the deformation only once and make a new dataset from the deformed images. Running the cell below may take a few minutes.

In [None]:
import tqdm

train_loader = monai.data.DataLoader(train_dataset, batch_size=1, shuffle=False)

samples = []
for train_batch in tqdm.tqdm(train_loader):
    samples.append(train_batch)

# Make a new dataset and dataloader using the transformed images
train_dataset = monai.data.Dataset(samples, transform=monai.transforms.SqueezeDimd(keys=['fixed', 'moving', 'fixed_mask', 'moving_mask']))
train_loader = monai.data.DataLoader(train_dataset, batch_size=16, shuffle=True)  # shuffle training batches every epoch

<div style='background-color:rgba(80,255,80,0.4); padding:20px'>
    ⌨ <b>Exercise</b>: Create <code>val_dataset</code> and <code>val_loader</code>, corresponding to the Dataset and DataLoader for your validation set. The transforms can be the same as in the training set.
</div>

In [None]:
# Your code goes here

### Model

We again use the U-Net architecture, as in the previous tutorial. However, this time our input / output structure is quite different:
- the network takes as input two images: the *moving* and *fixed* images.
- it outputs one tensor representing the *deformation field*.

<img src='https://i.imgur.com/rvZfwr2.png' width=600></img>


This *deformation field* can be applied to the *moving* image with the `monai.networks.blocks.Warp` block of Monai.
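Conceptually, warping with a dense displacement field (DDF) means that output pixel $(i, j)$ takes the value of the moving image at the displaced location $(i + u_y, j + u_x)$, found by interpolation. The NumPy sketch below does this with bilinear interpolation and repeats border values for locations outside the image. The helper `warp_bilinear`, its channel order (row displacement first) and its border handling are assumptions of this sketch; check the MONAI documentation before relying on them matching `Warp`.

```python
import numpy as np

def warp_bilinear(moving, ddf):
    """Warp a 2D image with a dense displacement field.

    moving: (H, W) image; ddf: (2, H, W) displacements in pixels,
    ddf[0] along rows (y), ddf[1] along columns (x).
    Output pixel (i, j) samples `moving` at (i + ddf[0, i, j], j + ddf[1, i, j]).
    """
    h, w = moving.shape
    ii, jj = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    # Sampling locations, clamped to the image so border values are repeated
    y = np.clip(ii + ddf[0], 0, h - 1)
    x = np.clip(jj + ddf[1], 0, w - 1)
    y0, x0 = np.floor(y).astype(int), np.floor(x).astype(int)
    y1, x1 = np.minimum(y0 + 1, h - 1), np.minimum(x0 + 1, w - 1)
    wy, wx = y - y0, x - x0                  # bilinear interpolation weights
    top = moving[y0, x0] * (1 - wx) + moving[y0, x1] * wx
    bottom = moving[y1, x0] * (1 - wx) + moving[y1, x1] * wx
    return top * (1 - wy) + bottom * wy

moving = np.arange(25, dtype=float).reshape(5, 5)
ddf = np.zeros((2, 5, 5))
ddf[1] = 1.0                 # sample one pixel to the right everywhere
warped = warp_bilinear(moving, ddf)
print(warped[0])             # [1. 2. 3. 4. 4.]
```

Because every step is differentiable with respect to the displacements, the image loss can be backpropagated through the warp into the network that predicts the DDF.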

<img src='https://i.imgur.com/gj7JnOy.png' width=500></img>


This deformed moving image is then compared to the *fixed* image: if they are similar, the deformation field is correctly registering the moving image on the fixed image. Keep in mind that this is done on **training** data, and we want the U-Net to learn to predict a proper deformation field given two new and unseen images. So we're not optimizing for a pair of images as would be done in conventional iterative registration, but training a model that can generalize.

<img src='https://i.imgur.com/aM7OrR4.png' width=300></img>


Before starting, let's check that you can work on a GPU by running the following cell:
- if the device is "cuda" (or "mps" on Apple silicon), you are working on a GPU,
- if the device is "cpu", call a teacher.

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"]="1"
else:
    device = "cpu"
print(f'The used device is {device}')

<div style='background-color:rgba(80,255,80,0.4); padding:20px'>
    ⌨ <b>Exercise</b>: Construct a U-Net with suitable settings and name it <code>model</code>.
    <p>Check that you can correctly apply its output to the input moving image with the <code>warp_layer</code>!</p>
</div>

In [None]:
model = ...  # ⌨️ FILL IN (and move the model to device)

warp_layer = monai.networks.blocks.Warp().to(device)

### Objective function

We evaluate the similarity between the fixed image and the deformed moving image with the `MSELoss()`. The L1 or SSIM losses seen in the previous section could also be used. Furthermore, the deformation field is regularized with `BendingEnergyLoss`. This is a penalty that takes the smoothness of the deformation field into account: if it's not smooth enough, the bending energy is high. Thus, our model will favor smooth deformation fields.
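In 2D, the bending energy of a displacement component $u$ is $u_{yy}^2 + 2u_{xy}^2 + u_{xx}^2$, averaged over the image and summed over components. The NumPy sketch below approximates it with central finite differences (`np.gradient`); the helper `bending_energy` is hypothetical, and MONAI's `BendingEnergyLoss` may scale or normalize the result differently.

```python
import numpy as np

def bending_energy(ddf):
    """Mean 2D bending energy of a displacement field of shape (2, H, W),
    using finite differences with unit pixel spacing."""
    energy = 0.0
    for u in ddf:                          # each displacement component
        u_y, u_x = np.gradient(u)          # first derivatives along rows and columns
        u_yy, u_yx = np.gradient(u_y)
        u_xy, u_xx = np.gradient(u_x)
        # u_xy**2 + u_yx**2 approximates the 2 * u_xy**2 cross term
        energy += np.mean(u_yy ** 2 + u_xx ** 2 + u_xy ** 2 + u_yx ** 2)
    return energy

ii, jj = np.meshgrid(np.arange(32.0), np.arange(32.0), indexing="ij")
affine = np.stack([0.1 * ii + 0.05 * jj, -0.02 * ii + 3.0])   # linear field: no bending
noisy = np.random.default_rng(0).normal(size=(2, 32, 32))     # rough field
print(bending_energy(affine), bending_energy(noisy))
```

A purely affine field has (numerically) zero bending energy, while a noisy field is heavily penalized, which is why increasing the regularization weight pushes the network toward smooth deformations.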

Finally, we pick an optimizer, in this case again an Adam optimizer.

In [None]:
image_loss = torch.nn.MSELoss()
regularization = monai.losses.BendingEnergyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-3)

<div style='background-color:rgba(80,255,80,0.4); padding:20px'>
    ⌨ <b>Exercise</b>: Add a learning rate scheduler that lowers the learning rate by a factor of ten every 100 epochs.
</div>

In [None]:
# Your code goes here

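To warp the moving image, we define a forward function that predicts the deformation field (`ddf`) from the moving and fixed images and applies it to the moving image with the `warp_layer`. It returns both the deformation field and the deformed moving image, `pred_image`; the loss between `pred_image` and the fixed image is computed in the training loop.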

In [None]:
def forward(batch_data, model):
    """
    Applies the model to a batch of data.
    
    Args:
        batch_data (dict): a batch of samples computed by a DataLoader.
        model (Module): a model computing the deformation field.
    
    Returns:
        ddf (Tensor): batch of deformation fields.
        pred_image (Tensor): batch of deformed moving images.
    
    """
    fixed_image = batch_data["fixed"].to(device).float()
    moving_image = batch_data["moving"].to(device).float()
    
    # predict DDF
    ddf = model(torch.cat((moving_image, fixed_image), dim=1))

    # warp the moving image with the predicted ddf
    pred_image = warp_layer(moving_image, ddf)

    return ddf, pred_image

You can monitor the training process in W&B: at each epoch, a batch of validation images is used to compute the comparison images of your choice, based on the parameter `method`.

In [None]:
def log_to_wandb(epoch, train_loss, val_loss, pred_batch, fixed_batch, method="checkerboard"):
    """ Function that logs ongoing training variables to W&B """
    import skimage.util as skut
    
    log_imgs = []
    for fixed_pt, pred_pt in zip(fixed_batch, pred_batch):
        fixed_np = np.squeeze(fixed_pt.cpu().detach().numpy())
        pred_np = np.squeeze(pred_pt.cpu().detach().numpy())
        comp_checker = skut.compare_images(fixed_np, pred_np, method=method)
        log_imgs.append(wandb.Image(comp_checker))

    # Send epoch, losses and images to W&B
    wandb.log({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss, 'results': log_imgs})

### Training time

Use the following cells to train your network. You may choose different parameters to improve the performance!

In [None]:
# Choose your parameters

max_epochs = 200
reg_weight = 0 # By default 0, but you can investigate what it does

In [None]:
from tqdm import tqdm

run = wandb.init(
    project='tutorial4_registration',
    config={
        'lr': optimizer.param_groups[0]["lr"],
        'batch_size': train_loader.batch_size,
        'regularization': reg_weight,
        'loss_function': str(image_loss)
    }
)
# Do not hesitate to enrich this list of settings to be able to correctly keep track of your experiments!
# For example you should add information on your model...

run_id = run.id # We remember here the run ID to be able to write the evaluation metrics

for epoch in tqdm(range(max_epochs)):    
    model.train()
    epoch_loss = 0
    for batch_data in train_loader:
        optimizer.zero_grad()

        ddf, pred_image = forward(batch_data, model)

        fixed_image = batch_data["fixed"].to(device).float()
        reg = regularization(ddf)
        loss = image_loss(pred_image, fixed_image) + reg_weight * reg
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    epoch_loss /= len(train_loader)

    model.eval()
    val_epoch_loss = 0
    with torch.no_grad():  # no gradients are needed for validation
        for batch_data in val_loader:
            ddf, pred_image = forward(batch_data, model)
            fixed_image = batch_data["fixed"].to(device).float()
            reg = regularization(ddf)
            loss = image_loss(pred_image, fixed_image) + reg_weight * reg
            val_epoch_loss += loss.item()
    val_epoch_loss /= len(val_loader)

    log_to_wandb(epoch, epoch_loss, val_epoch_loss, pred_image, fixed_image)
    # If you added a learning rate scheduler, step it once per epoch here: scheduler.step()
    
run.finish()    

### Evaluation of the trained model

Now that the model has been trained, it's time to evaluate its performance. Use the code below to visualize samples and deformation fields. 
> ❓ Are you satisfied with these registration results? Do they seem anatomically plausible? Try out different regularization factors (<code>reg_weight</code>) and see what they do to the registration.

Answer: 

In [None]:
def visualize_prediction(sample, model, method="checkerboard"):
    """
    Plot five images: fixed, moving, deformed moving, comparison and displacement magnitude.
    
    Args:
        sample (dict): a sample from a dataset built with `build_dict_ribs` and `monai.data.Dataset`.
        model (Module): a model computing the deformation field.
        method (str): method used by `skimage.util.compare_images`.
    """
    import skimage.util as skut 
    
    skut_methods = ["diff", "blend", "checkerboard"]
    if method not in skut_methods:
        raise ValueError(f"Method must be chosen in {skut_methods}.\n"
                         f"Current value is {method}.")
        
    model.eval()
    
    # Compute deformation field + deformed image (no gradients needed)
    batch_data = {
        "fixed": sample["fixed"].unsqueeze(0),
        "moving": sample["moving"].unsqueeze(0),
    }
    with torch.no_grad():
        ddf, pred_image = forward(batch_data, model)
    # Displacement magnitude per pixel (ddf has shape (2, H, W) after squeezing)
    ddf = np.linalg.norm(ddf.cpu().numpy().squeeze(), axis=0)
    
    # Squeeze images
    fixed = np.squeeze(np.asarray(sample["fixed"]))
    moving = np.squeeze(np.asarray(sample["moving"]))
    deformed = np.squeeze(pred_image.cpu().numpy())
    
    # Generate comparison image
    comp_checker = skut.compare_images(fixed, deformed, method=method, n_tiles=(4, 4))
    
    # Plot everything
    fig, axs = plt.subplots(1, 5, figsize=(18, 5))    
    axs[0].imshow(fixed, cmap='gray')
    axs[0].set_title('Fixed')
    axs[1].imshow(moving, cmap='gray')
    axs[1].set_title('Moving')
    axs[2].imshow(deformed, cmap='gray')
    axs[2].set_title('Deformed')
    axs[3].imshow(comp_checker, cmap='gray')
    axs[3].set_title('Comparison')    
    dpl = axs[4].imshow(ddf, clim=(0, 10))
    axs[4].set_title('Displacement magnitude')
    fig.colorbar(dpl, ax=axs[4])
    plt.show()

for sample in val_dataset:
    visualize_prediction(sample, model)

<div style='background-color:rgba(80,255,80,0.4); padding:20px'>
    ⌨ <b>Bonus exercise</b>: Compute the Jacobian determinant of the deformation at each image pixel. How many of these are negative? Can you improve upon this?
</div>
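As a starting point for the bonus exercise: for the mapping $p \mapsto p + u(p)$, the Jacobian matrix is $I + \nabla u$, and a negative determinant means the deformation locally folds the image onto itself, which is not anatomically plausible. The NumPy sketch below uses finite differences; the helper `jacobian_determinant` is hypothetical, and it assumes `ddf[0]` is the row (y) and `ddf[1]` the column (x) displacement, which you should verify against MONAI's `Warp` convention before applying it to a predicted field.

```python
import numpy as np

def jacobian_determinant(ddf):
    """Jacobian determinant of p -> p + u(p) for a 2D ddf of shape (2, H, W),
    with ddf[0] the row (y) and ddf[1] the column (x) displacement."""
    duy_dy, duy_dx = np.gradient(ddf[0])
    dux_dy, dux_dx = np.gradient(ddf[1])
    return (1 + duy_dy) * (1 + dux_dx) - duy_dx * dux_dy

h, w = 32, 32
ii, jj = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
ddf = np.stack([0.1 * ii, 0.1 * jj]).astype(float)   # uniform 10% expansion
det = jacobian_determinant(ddf)
print(det.min(), det.max())                          # both ≈ 1.21 = 1.1 ** 2
num_folds = np.sum(det < 0)                          # pixels where the mapping folds
```

For a predicted field, you could pass `ddf.detach().cpu().numpy()[0]` for the first sample in a batch.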