# Task2: Exploring Unsupervised Registration using Voxelmorph

Medical image registration has been a cornerstone of research in medical image computing and computer-assisted intervention, and underpins many clinical applications. Whilst machine learning methods have long been important in developing pairwise algorithms, recently proposed deep-learning-based frameworks directly infer displacement fields for unseen image pairs without iterative optimisation, using neural networks trained on large population data. These approaches promise to tackle several of the most challenging aspects of classical pairwise methods, such as high computational cost, limited robustness and generalisation, and the lack of inter-modality similarity measures.

Note that the package makes use of `Neurite`, a neural network toolbox based on TensorFlow rather than PyTorch. **Due to compatibility issues, you must make the following changes:**

1. Follow https://github.com/tensorflow/tensorflow/issues/70796 so that `import voxelmorph` can be run successfully.
2. Replace all `get_shape()` with `shape` in `<voxelmorph package>/tf/networks.py`.

## Registration1: MNIST Dataset

### Loading the Dataset

We will learn the preliminaries of (unsupervised) deep learning (DL) image registration through the simple MNIST dataset and the popular *VoxelMorph* framework. VoxelMorph is a fast learning-based framework for deformable, pairwise medical image registration. Instead of optimizing an objective function for each pair of images, which can be time-consuming, it formulates registration as a function that maps an input image pair to a deformation field that aligns the two images. Given a new pair of scans, VoxelMorph rapidly computes a deformation field by directly evaluating the function.

> Balakrishnan, Guha, et al. “VoxelMorph: A Learning Framework for Deformable Medical Image Registration.” IEEE Transactions on Medical Imaging, vol. 38, no. 8, Aug. 2019, pp. 1788–800. DOI.org (Crossref), https://doi.org/10.1109/TMI.2019.2897538.


In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
os.environ['TF_USE_LEGACY_KERAS'] = '1'

In [None]:
import neurite as ne
import voxelmorph as vxm

In [None]:
from torchvision import datasets

train_loader = datasets.MNIST('../data', train=True, download=True)
test_loader = datasets.MNIST('../data', train=False, download=True)

In [None]:
x_train = np.array([np.array(image) for image, _ in train_loader])
y_train = np.array([label for _, label in train_loader], dtype=object)
x_test = np.array([np.array(image) for image, _ in test_loader])
y_test = np.array([label for _, label in test_loader], dtype=object)

# Create validation set as well
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# let's get some shapes to understand what we loaded.
print('shape of x_train: ', x_train.shape)
print('shape of y_train: ', y_train.shape)
print('shape of x_val: ', x_val.shape)
print('shape of y_val: ', y_val.shape)
print('shape of x_test: ', x_test.shape)
print('shape of y_test: ', y_test.shape)

In [None]:
# We select first 5 images from the training set and visualize them

fig, ax = plt.subplots(1, 5, figsize=(15, 5))
for i in range(5):
    ax[i].imshow(x_train[i], cmap='gray')
    ax[i].axis('off')
    ax[i].set_title('Label: ' + str(y_train[i]))

As the last step of loading, we pre-process the data to make it suitable for the VoxelMorph framework. For demonstration, we will create standalone datasets that **contain only the digits 4, 5, and 7**.

In [None]:
x_train = x_train/255
x_val = x_val/255
x_test = x_test/255

# Change size from 28*28 to 32*32
pad_amount = ((0, 0), (2,2), (2,2))
x_train = np.pad(x_train, pad_amount, 'constant') # pad with constant values
x_val = np.pad(x_val, pad_amount, 'constant')
x_test = np.pad(x_test, pad_amount, 'constant')

In [None]:
x_train_4 = x_train[y_train == 4]
x_val_4 = x_val[y_val == 4]
x_test_4 = x_test[y_test == 4]

x_train_5 = x_train[y_train == 5]
x_val_5 = x_val[y_val == 5]
x_test_5 = x_test[y_test == 5]

x_train_7 = x_train[y_train == 7]
x_val_7 = x_val[y_val == 7]
x_test_7 = x_test[y_test == 7]

In [None]:
# visualize again
fig, ax = plt.subplots(1, 5, figsize=(15, 5))
for i in range(5):
    ax[i].imshow(x_train_7[i], cmap='gray')
    ax[i].axis('off')

### Defining Neural Networks used to Estimate Deformation

In learning-based methods, we use a neural network that takes in two images, a moving image $m$ and a fixed image $f$ (e.g. MNIST digits of size 32x32), and outputs a dense deformation $\phi$ (e.g. of size 32x32x2, because at each pixel we want a vector telling us where to go). In this case, we take the UNet model from VoxelMorph.

Note that we only consider non-rigid deformation here as there are lots of handy tools to conduct affine registrations.
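
To make the role of $\phi$ concrete, here is a minimal NumPy sketch of how a dense displacement field could be applied to a moving image. The helper `warp_image` is hypothetical and not part of VoxelMorph. It assumes a "pull" convention (each output pixel samples the moving image at its own position plus the displacement), bilinear interpolation, and clamping at the image border, which may differ from what the spatial transformer layer used later does.

```python
import numpy as np

def warp_image(moving, displacement):
    """Warp a 2D image [H, W] with a displacement field [H, W, 2].

    displacement[..., 0] holds row offsets and displacement[..., 1] holds
    column offsets, both in pixels (an assumed convention for this sketch).
    """
    H, W = moving.shape
    rows, cols = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')

    # sampling locations, clamped to the image border
    r = np.clip(rows + displacement[..., 0], 0, H - 1)
    c = np.clip(cols + displacement[..., 1], 0, W - 1)

    # integer neighbours and fractional weights for bilinear interpolation
    r0 = np.floor(r).astype(int)
    c0 = np.floor(c).astype(int)
    r1 = np.minimum(r0 + 1, H - 1)
    c1 = np.minimum(c0 + 1, W - 1)
    wr = r - r0
    wc = c - c0

    top = (1 - wc) * moving[r0, c0] + wc * moving[r0, c1]
    bottom = (1 - wc) * moving[r1, c0] + wc * moving[r1, c1]
    return (1 - wr) * top + wr * bottom

# toy example: a single bright pixel, pulled one column to the left
toy_moving = np.zeros((4, 4))
toy_moving[1, 2] = 1.0
toy_shift = np.zeros((4, 4, 2))
toy_shift[..., 1] = 1.0  # every output pixel samples one column to its right
toy_moved = warp_image(toy_moving, toy_shift)
```

With a zero displacement field the function returns the moving image unchanged, which is the identity deformation.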

In [None]:
N_DIM = 2
INPUT_FEATURES = 2

input_shape = (*x_train.shape[1:], INPUT_FEATURES)
print('input shape: ', input_shape)

n_encoder_features = [32, 32, 32, 32]
n_decoder_features = [32, 32, 32, 32, 32, 16]
n_features = [n_encoder_features, n_decoder_features]

# inshape: Input tensor shape
# nb_features: UNet convolutional features, specified via a list of lists of the form [[encoder feats], [decoder feats]]
UNet = vxm.networks.Unet(inshape=input_shape, nb_features=n_features)

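# UNet.input and UNet.output are single tensors, so we print their shapes directly (batch dimension included)
print(f"Shape of input: {UNet.input.shape}")
print(f"Shape of output: {UNet.output.shape}")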

We can then use the UNet to replicate the structure presented in VoxelMorph.

In [None]:
# Deformation Model: UNet + Conv2D
# The UNet above is a Keras model, so the rest of the network is built with Keras layers
import tensorflow as tf

# A convolution maps the UNet features to a displacement field with N_DIM channels
deformation_tensor = tf.keras.layers.Conv2D(N_DIM, kernel_size=3, padding='same', name='disp')(UNet.output)

In [None]:
# Full Model: UNet + Conv2D + Spatial Transformer
spatial_transformer = vxm.layers.SpatialTransformer(name='transformer')

# The moving image is the first channel of the UNet input
moving_image = tf.expand_dims(UNet.input[..., 0], axis=-1)

# Warp the moving image with the predicted displacement field
moved_image_tensor = spatial_transformer([moving_image, deformation_tensor])

# The model returns both the moved image and the displacement field
out_tensor = [moved_image_tensor, deformation_tensor]
vxm_model = tf.keras.models.Model(inputs=UNet.inputs, outputs=out_tensor)

The model we just created represents the standard, dense VoxelMorph architecture, with **a UNet component, displacement field, and final spatial transformer layer**.

<img src="../img/VoxelMorph.png" width="800" height="400">

In practice, we do not need to repeat this code every time. VoxelMorph has a built-in `VxmDense` class to define such models.

In [None]:
# We no longer need to append the 2 input features to input_shape, as the model is automatically configured with two input tensors (moving and fixed) instead of one
vxm_model = vxm.networks.VxmDense(inshape=input_shape[:2], nb_unet_features=n_features, int_steps=0) # Here int_steps=0 option disables diffeomorphism
print(f"Shape of input: {[str(t.shape) for t in vxm_model.inputs]}")
print(f"Shape of output: {[str(t.shape) for t in vxm_model.outputs]}")

### Defining Loss Function

In unsupervised registration where no gold standard is provided, how could we know whether the deformation is good or not?

1. First, make sure that $m\circ \phi$, the moving image warped by the deformation $\phi$, is close to the fixed image $f$.
2. Then, regularize $\phi$ so that the deformation is smooth and free of erratic behavior.

Let us define two losses: a simple MSE and a spatial gradient of the displacement.
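
The two terms can be sketched in plain NumPy to show what is being balanced. This is only an illustration under stated assumptions: the function names are hypothetical, the gradient is approximated with forward differences, and the exact normalization used by the VoxelMorph loss classes may differ.

```python
import numpy as np

def mse_loss(fixed, moved):
    """Image similarity: mean squared intensity difference."""
    return np.mean((fixed - moved) ** 2)

def gradient_l2_loss(displacement):
    """Smoothness: mean squared forward difference of a field [H, W, 2]."""
    d_rows = np.diff(displacement, axis=0)  # change between neighbouring rows
    d_cols = np.diff(displacement, axis=1)  # change between neighbouring columns
    return 0.5 * (np.mean(d_rows ** 2) + np.mean(d_cols ** 2))

def registration_loss(fixed, moved, displacement, lambda_param=0.05):
    """Weighted sum of the similarity and smoothness terms."""
    return mse_loss(fixed, moved) + lambda_param * gradient_l2_loss(displacement)

# a constant shift is perfectly smooth, so only the image term contributes
toy_fixed = np.ones((8, 8))
toy_moved = np.zeros((8, 8))
constant_shift = np.full((8, 8, 2), 3.0)
toy_loss = registration_loss(toy_fixed, toy_moved, constant_shift)
```

A larger `lambda_param` favours smoother deformations at the cost of a looser match between the moved and fixed images.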

In [None]:
# voxelmorph has a variety of custom loss classes
losses = [vxm.losses.MSE().loss, vxm.losses.Grad('l2').loss]
lambda_param = 0.05
loss_weights = [1, lambda_param]  # balance the two losses through hyperparameter

vxm_model.compile(optimizer='Adam', loss=losses, loss_weights=loss_weights)

### Train the Model

After creating an instance of the model, we need to make sure the data is in the right format by using a Python generator that yields batches of data.

In [None]:
def vxm_data_generator(x, batch_size=32):
    """
    Generator that takes in data of size [N, H, W], and yields data for
    our custom vxm model. Note that we need to provide numpy data for each
    input, and each output.

    inputs:  moving [bs, H, W, 1], fixed image [bs, H, W, 1]
    outputs: moved image [bs, H, W, 1], zero-gradient [bs, H, W, 2]
    """

    # preliminary sizing
    volume_shape = x.shape[1:] # extract data shape
    n_dim = len(volume_shape)
    
    # prepare a zero array the size of the deformation; the gradient loss ignores
    # its target, so this array only serves as a placeholder for the second output
    zero_phi = np.zeros([batch_size, *volume_shape, n_dim])
    
    while True:
        # randomly select pairs of images for training
        idx1 = np.random.randint(0, x.shape[0], size=batch_size)
        moving_images = x[idx1, ..., np.newaxis]
        idx2 = np.random.randint(0, x.shape[0], size=batch_size)
        fixed_images = x[idx2, ..., np.newaxis]

        inputs = [moving_images, fixed_images]
        outputs = [fixed_images, zero_phi]
        
        yield (inputs, outputs)

In [None]:
# Make use of the generator
train_generator = vxm_data_generator(x_train_7)
in_sample, out_sample = next(train_generator)

print(in_sample[0].shape)
images = [img[0, :, :, 0] for img in in_sample + out_sample] 
titles = ['Moving', 'Fixed', 'Moved ground-truth (Fixed)', 'Zero Deformation Field']
fig, ax = plt.subplots(1, 4, figsize=(15, 5))
for i in range(4):
    ax[i].imshow(images[i], cmap='gray')
    ax[i].axis('off')
    ax[i].set_title(titles[i])

In [None]:
epochs = 10
steps_per_epoch = 100
# fit() accepts Python generators directly; fit_generator() is deprecated
history = vxm_model.fit(train_generator, epochs=epochs, steps_per_epoch=steps_per_epoch, verbose=2)

In [None]:
# Visualize the loss
plt.figure()
plt.plot(history.epoch, history.history['loss'], '.-')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()

### Start Registration

In [None]:
# test the trained model using validation set
val_generator = vxm_data_generator(x_val_7, batch_size = 1)
val_input, _ = next(val_generator)
val_pred = vxm_model.predict(val_input)

In [None]:
# visualize the prediction
images = [img[0, :, :, 0] for img in val_input + val_pred] 
titles = ['Moving', 'Fixed', 'Moved', 'Flow']
fig, ax = plt.subplots(1, 4, figsize=(15, 5))
for i in range(4):
    ax[i].imshow(images[i], cmap='gray')
    ax[i].axis('off')
    ax[i].set_title(titles[i])

Visualizing the flow field can be a bit tricky, since each pixel holds a vector rather than a single value. The `neurite` package provides a handy function `plot.flow` to visualize the displacement field.
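
The grayscale "Flow" panel above shows only one component of the displacement. As a complementary view, one could reduce the field to a per-pixel displacement length, which can then be shown with `imshow`. The sketch below uses a toy field and a hypothetical helper name, and assumes the flow has shape `[H, W, 2]`.

```python
import numpy as np

def flow_magnitude(flow):
    """Per-pixel displacement length for a flow field of shape [H, W, 2]."""
    return np.linalg.norm(flow, axis=-1)

# toy field: a single pixel displaced by (3, 4) pixels
toy_flow = np.zeros((2, 2, 2))
toy_flow[0, 0] = [3.0, 4.0]
toy_magnitude = flow_magnitude(toy_flow)
```

Applied to `val_pred[1].squeeze()`, this would give a 2D map in which bright regions mark large displacements.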

In [None]:
ne.plot.flow([val_pred[1].squeeze()], width=4)

You might then wonder whether the learned model can be used for other digits such as 4 or 5. Let us run some experiments.

In [None]:
val_generator = vxm_data_generator(x_val_4, batch_size = 1)
val_input, _ = next(val_generator)
val_pred = vxm_model.predict(val_input)

# visualize the prediction
images = [img[0, :, :, 0] for img in val_input + val_pred]
titles = ['Moving', 'Fixed', 'Moved', 'Flow']
fig, ax = plt.subplots(1, 4, figsize=(15, 5))
for i in range(4):
    ax[i].imshow(images[i], cmap='gray')
    ax[i].axis('off')
    ax[i].set_title(titles[i])

In [None]:
val_generator = vxm_data_generator(x_val_5, batch_size = 1)
val_input, _ = next(val_generator)
val_pred = vxm_model.predict(val_input)

# visualize the prediction
images = [img[0, :, :, 0] for img in val_input + val_pred]
titles = ['Moving', 'Fixed', 'Moved', 'Flow']
fig, ax = plt.subplots(1, 4, figsize=(15, 5))
for i in range(4):
    ax[i].imshow(images[i], cmap='gray')
    ax[i].axis('off')
    ax[i].set_title(titles[i])

Interestingly, it still works, although somewhat worse, so the model generalizes further than we might expect. Why? Locally, parts of the 7s look similar to parts of the 4s and 5s, so the registration algorithm still tries to match local neighborhoods.

## Registration2: Brain MRI

Clearly, MNIST is just a toy dataset. Let us try a more realistic example: brain MRI images. You can find them in `data/Brain`.

These brain images have been intensity-normalized, affinely aligned, and skull-stripped with FreeSurfer, so that we can concentrate on deformable registration.

In [None]:
brain_data = np.load('../data/Brain/tutorial_data.npz')
x_train = brain_data['train']
x_val = brain_data['validate']

# There are 208 volumes, each of size 160x192
volume_shape = x_train.shape[1:]
print('Shape of x_train:', x_train.shape)

In [None]:
# Make some visualizations for first 5 volumes
fig, ax = plt.subplots(1, 5, figsize=(20, 5))
for i in range(5):
    im = ax[i].imshow(x_train[i, ...], cmap='gray')
    ax[i].axis('off')
    cbar = plt.colorbar(im, ax=ax[i])  # reuse the image handle instead of drawing the image twice
    cbar.ax.tick_params(labelsize=8)


We will repeat the same setup as for the MNIST dataset above, with the tuning parameter $\lambda$ adjusted to 0.01.

In [None]:
# Prepare the VoxelMorph model
# todo
vxm_model = vxm.networks.VxmDense(volume_shape, n_features, int_steps=0)

losses = [vxm.losses.MSE().loss, vxm.losses.Grad('l2').loss]
lambda_param = 0.01
loss_weights = [1, lambda_param]

vxm_model.compile(optimizer='Adam', loss=losses, loss_weights=loss_weights)

In [None]:
# Prepare the data generator
# todo
train_generator = None
in_sample, out_sample = None

# visualize
# todo
# Hint: Follow the same procedure as in the MNIST example
images = None
titles = None
fig, ax = plt.subplots(1, 4, figsize=(15, 5))
for i in range(4):
    ax[i].imshow(images[i], cmap='gray')
    ax[i].axis('off')
    ax[i].set_title(titles[i])

As the brains are much more complex than the MNIST digits, the `epochs` should be increased to 200. We adjust `steps_per_epoch` to $\lceil 208/8\rceil=26$.
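
The step count follows from a simple ceiling division. The short sketch below reproduces it, assuming a batch size of 8, which the formula implies but the text does not state explicitly.

```python
import math

n_volumes = 208   # number of training volumes, as stated above
batch_size = 8    # assumed batch size implied by the formula

# number of batches needed to cover the training set once
steps = math.ceil(n_volumes / batch_size)
print(steps)  # 26
```

Since the generator draws image pairs at random, an epoch here is a nominal unit of 26 batches rather than an exact pass over the data.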

In [None]:
epochs = 200
steps_per_epoch = 26
# todo
history = None

In [None]:
# Visualize the loss
plt.figure()
plt.plot(history.epoch, history.history['loss'], '.-')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()

In [None]:
# todo
val_generator = None
val_input, _ = None

val_pred = None

In [None]:
# Visualize both the images and the flow as before
# todo
images = None
titles = None
fig, ax = plt.subplots(1, 4, figsize=(15, 5))
for i in range(4):
    ax[i].imshow(images[i], cmap='gray')
    ax[i].axis('off')
    ax[i].set_title(titles[i])


flow = val_pred[1].squeeze()[::3,::3]
ne.plot.flow([flow], width=5);

At first glance, the MSE-only model matches the fixed image better. However, the resulting deformation field is very noisy and unlikely to be anatomically meaningful. To evaluate the registration properly, we might need to make use of anatomical segmentations.
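
One common way to use segmentations for this purpose is the Dice overlap between the warped moving labels and the fixed labels. The sketch below is an illustration only: the tutorial data does not include segmentations, so the label maps are toy arrays and `dice_score` is a hypothetical helper.

```python
import numpy as np

def dice_score(seg_a, seg_b, label):
    """Dice overlap of one label between two integer label maps."""
    a = seg_a == label
    b = seg_b == label
    denom = a.sum() + b.sum()
    if denom == 0:
        return float('nan')  # label absent from both maps
    return 2.0 * np.logical_and(a, b).sum() / denom

# toy label maps standing in for warped-moving and fixed segmentations
toy_seg_moved = np.array([[1, 1], [0, 0]])
toy_seg_fixed = np.array([[1, 0], [0, 0]])
toy_dice = dice_score(toy_seg_moved, toy_seg_fixed, label=1)
```

A score of 1 means perfect overlap for that structure and 0 means none. When warping label maps, nearest-neighbour interpolation is usually preferred so that labels stay integer-valued.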

## Concluding Remarks

Here are some tutorials if you want to dive deep into how to use deep learning for medical image registration:

+ Visualize warp: https://colab.research.google.com/drive/1F8f1imh5WfyBv-crllfeJBFY16-KHl9c?usp=sharing
+ Template(Atlas) construction: https://colab.research.google.com/drive/1SkQbrWTQHpQFrG4J2WoBgGZC9yAzUas2?usp=sharing#scrollTo=ADanmU8xde-N

Evaluating registration is tricky. For more details, here is a useful reference:

> Song, Joo Hyun. Methods for Evaluating Image Registration. 2017. University of Iowa, Doctor of Philosophy. DOI.org (Crossref), https://doi.org/10.17077/etd.v0vailob.