# VoxelMorph Atlas Building Demo

In [None]:
# tf 2.5 gives some weird errors. We'll go back to 2.4.1
# Replace whatever TensorFlow the runtime ships with by a pinned 2.4.1.
# NOTE(review): an already-running kernel has probably imported the old version;
# a runtime restart is likely required after this cell — confirm.
!pip uninstall tensorflow -y 
!pip install tensorflow==2.4.1

In [None]:
# Optional TF switch, left disabled. If re-enabled, note that this cell runs
# before `import tensorflow as tf`, so it would need to move after the imports.
# NOTE(review): presumably a workaround for a TF control-flow issue — confirm
# whether it is still needed, otherwise delete this cell.
# tf.compat.v1.experimental.output_all_intermediates(True)

In [None]:
# Install the demo's libraries (versions are not pinned here).
!pip install voxelmorph  # for all things voxelmorph/neurite
!pip install tensorflow_addons  # for tqdm callbacks



In [None]:
# some imports we'll need throughout the demo
import os

# some third party very useful libraries
import tensorflow as tf
import tensorflow_addons as tfa  # for TQDM callback
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
import nibabel as nib

# our libraries
import voxelmorph as vxm
import neurite as ne

# Utilities

In [None]:
# turn off eager for this
# need to do it due to some tf 2.0+ compatibility issues
# This switches TF to graph-mode execution globally. It runs before any model
# below is built, so every later model/compile/fit uses graph mode.
tf.compat.v1.disable_eager_execution()

In [None]:
# some helpful functions
def plot_hist(hist):
  """Plot the training loss stored in a Keras ``History`` object.

  Two side-by-side panels are drawn: the loss over all epochs, and the loss
  over only the last ``len(hist.epoch) // 2`` epochs (handy once the early
  loss has dropped and dominates the scale). With fewer than two epochs that
  count is 0, and the ``[-0:]`` slice then shows every epoch.

  Args:
    hist: object exposing ``hist.epoch`` (epoch indices) and
      ``hist.history['loss']`` (per-epoch loss), e.g. the return of
      ``model.fit``.
  """
  epochs = hist.epoch
  losses = hist.history['loss']
  tail = len(epochs) // 2

  _, (ax_all, ax_tail) = plt.subplots(1, 2, figsize=(17, 5))
  ax_all.plot(epochs, losses, '.-')
  ax_tail.plot(epochs[-tail:], losses[-tail:], '.-')
  for ax in (ax_all, ax_tail):
    ax.set_ylabel('loss')
    ax.set_xlabel('epochs')
  plt.show()

In [None]:
# generally useful callback
# unfortunately show_epoch_progress=True leaves a printout that we can't control (bad implementation in tfa...)
# Passed to every `model.fit` below; only the overall progress bar is shown.
tqdm_cb = tfa.callbacks.TQDMProgressBar(leave_epoch_progress=False, show_epoch_progress=False) 

# Unconditional Template (MNIST)

## Data

In [None]:
# let's load up MNIST
(x_train_all, y_train_all), (x_test_all, y_test_all) = tf.keras.datasets.mnist.load_data(path="mnist.npz")
# convert to float64 and divide by 255 (presumably raw pixels are 0-255, giving [0, 1])
x_train_all = x_train_all.astype('float')/255
x_test_all = x_test_all.astype('float')/255

# zero-pad 2 pixels on each side of both spatial axes (28x28 -> 32x32, matching
# the 32x32 shapes printed further down) and append a trailing channel axis
x_train_all = np.pad(x_train_all, ((0, 0), (2, 2), (2, 2)), 'constant')[..., np.newaxis]
x_test_all = np.pad(x_test_all, ((0, 0), (2, 2), (2, 2)), 'constant')[..., np.newaxis]

# spatial shape only: drop the leading batch axis and the trailing channel axis
vol_shape = list(x_train_all.shape[1:-1])

In [None]:
# extract all 3s
digit = 3

x_train = x_train_all[y_train_all == digit, ...]
y_train = y_train_all[y_train_all == digit]
# x_test_all was already converted to float and divided by 255 when MNIST was
# loaded (exactly like x_train_all), so it must not be divided by 255 again.
x_test = x_test_all[y_test_all == digit, ...]
y_test = y_test_all[y_test_all == digit]

In [None]:
# prepare a simple generator.
def template_gen(x, batch_size):
  """Yield endless random batches for `vxm.networks.TemplateCreation` training.

  Args:
    x: image array shaped (N, *vol_shape, C).
    batch_size: number of images drawn (uniformly, with replacement) per batch.

  Yields:
    (inputs, outputs) tuples where
      inputs  = [mean_atlas, img]: the mean of all of `x` over axis 0, repeated
                batch_size times, followed by the sampled images;
      outputs = [img, zero, zero, zero]: the sampled images, then three
                all-zero arrays shaped (batch_size, *vol_shape, ndims).
                NOTE(review): presumably dummy targets for the model's
                deformation-related outputs — confirm against TemplateCreation.
  """
  vol_shape = list(x.shape[1:-1])
  # One channel per spatial dimension (2 for 2D images, 3 for 3D volumes),
  # instead of a hard-coded 2 that only fits 2D data.
  ndims = len(vol_shape)
  zero = np.zeros([batch_size] + vol_shape + [ndims])
  mean_atlas = np.repeat(np.mean(x, 0, keepdims=True), batch_size, 0)

  while True:
    idx = np.random.randint(0, x.shape[0], batch_size)
    img = x[idx, ...]
    inputs = [mean_atlas, img]
    outputs = [img, zero, zero, zero]
    yield inputs, outputs

# let's make sure the sizes make sense
# Draw one batch of 8 and display the shapes of the input list and the output list.
sample = next(template_gen(x_train, 8))
[f.shape for f in sample[0]], [f.shape for f in sample[1]]

([(8, 32, 32, 1), (8, 32, 32, 1)],
 [(8, 32, 32, 1), (8, 32, 32, 2), (8, 32, 32, 2), (8, 32, 32, 2)])

## Model

In [None]:
# U-Net feature counts per encoder / decoder level, passed to the model as
# nb_unet_features=[enc_nf, dec_nf].
enc_nf = [16, 32, 32, 32]
dec_nf = [32, 32, 32, 32, 32, 16, 16]

In [None]:
# Unconditional template model for images of spatial shape vol_shape (32x32 here).
model = vxm.networks.TemplateCreation(vol_shape, nb_unet_features=[enc_nf, dec_nf])

LocalParamWithInput: Consider using neuron.layers.LocalParam()


In [None]:
# prepare losses and compile
# One loss (and weight) per model output, in output order: four in total,
# matching the four targets yielded by template_gen.
image_loss_func = vxm.losses.MSE().loss
# Ignores y_true and compares the prediction against the learned atlas tensor.
# NOTE(review): the global `model` is looked up when the lambda runs, not when
# it is defined, so rebinding `model` later changes what this refers to.
neg_loss_func = lambda _, y_pred: image_loss_func(model.references.atlas_tensor, y_pred)
losses = [image_loss_func, neg_loss_func, vxm.losses.MSE().loss, vxm.losses.Grad('l2', loss_mult=2).loss]
loss_weights = [0.5, 0.5, 1, 0.01]

model.compile('adam', loss=losses, loss_weights=loss_weights)

In [None]:
# train model
# 100 epochs x 25 steps of 8 random images from the endless generator. Keras'
# own logging is off (verbose=0); progress comes from the tqdm callback.
# NOTE(review): the saved output of this cell shows an InvalidArgumentError, so
# the recorded run did not finish; investigate before relying on `hist`.
gen = template_gen(x_train, batch_size=8)
hist = model.fit(gen, epochs=100, steps_per_epoch=25, verbose=0, callbacks=[tqdm_cb])

Training:   0%|           0/100 ETA: ?s,  ?epochs/s

InvalidArgumentError: ignored

In [None]:
# visualize training
# loss over all epochs (left) and over the last half of the epochs (right)
plot_hist(hist)

## Visualize Results

In [None]:
# visualize learned atlas
# First weight array of the atlas layer (presumably the atlas image itself);
# the trailing channel axis is dropped so imshow gets a 2D array.
atlas = model.references.atlas_layer.get_weights()[0][..., 0]
plt.imshow(atlas, cmap='gray')
plt.axis('off')
# trailing ';' suppresses the Text object repr that would otherwise be displayed
plt.title('atlas');

# Unconditional Template (2D Brain slices)

## Get Data
This is data we released as part of neurite; please read more about it [here](https://github.com/adalca/medical-datasets/blob/master/neurite-oasis.md).

In [None]:
# get the data
!wget wget http://surfer.nmr.mgh.harvard.edu/ftp/data/neurite/data/neurite-oasis.2d.v1.0.tar -O data.tar
!tar -xf data.tar;

In [None]:
# prepare data
# One normalized slice per OASIS_OAS1_* directory extracted from data.tar.
# os.listdir returns entries in arbitrary order, so sort them to keep the
# subject order (and therefore x_vols) identical across runs.
slice_files = sorted(f + '/slice_norm.nii.gz' for f in os.listdir('.') if f.startswith('OASIS_OAS1_'))
assert slice_files, 'no OASIS_OAS1_* directories found in the working directory; did the download cell run?'
vols = [nib.load(f).get_fdata() for f in tqdm(slice_files)]
x_vols = np.stack(vols, 0)
# spatial shape: drop the batch axis and (presumably) a trailing singleton
# channel axis stored in each NIfTI slice — TODO confirm. Kept as a list like
# the MNIST sections.
vol_shape = list(x_vols.shape[1:-1])

## Model

In [None]:
# get the model
# Same U-Net feature counts as the MNIST template, now for the brain-slice vol_shape.
enc_nf = [16, 32, 32, 32]
dec_nf = [32, 32, 32, 32, 32, 16, 16]
model = vxm.networks.TemplateCreation(vol_shape, nb_unet_features=[enc_nf, dec_nf])

In [None]:
# prepare losses
# Same four losses/weights as the MNIST template (one per model output).
image_loss_func = vxm.losses.MSE().loss
# Ignores y_true and compares the prediction against the atlas tensor of the
# global `model` (looked up when the loss runs — NOTE(review)).
neg_loss_func = lambda _, y_pred: image_loss_func(model.references.atlas_tensor, y_pred)
losses = [image_loss_func, neg_loss_func, vxm.losses.MSE().loss, vxm.losses.Grad('l2', loss_mult=2).loss]
loss_weights = [0.5, 0.5, 1, 0.01]

model.compile('adam', loss=losses, loss_weights=loss_weights)

In [None]:
# train
# 100 epochs x 25 steps, 2 randomly drawn slices per step.
gen = template_gen(x_vols, batch_size=2)
hist = model.fit(gen, epochs=100, steps_per_epoch=25, verbose=0, callbacks=[tqdm_cb])

In [None]:
# visualize optimization
# loss over all epochs (left) and over the last half of the epochs (right)
plot_hist(hist)

## Visualize Atlas

In [None]:
# Learned atlas with its trailing channel axis dropped; rot90(..., -1) rotates
# it 90 degrees clockwise for display only.
atlas = model.references.atlas_layer.get_weights()[0][..., 0]
plt.imshow(np.rot90(atlas, -1), cmap='gray')
plt.axis('off');

# Conditional Template (MNIST)

## Data (all MNIST)

In [None]:
# back to MNIST, all digits this time
x_train = x_train_all
y_train = y_train_all
# one-hot digit labels (10 classes): the conditioning input of the model
y_train_onehot = tf.keras.utils.to_categorical(y_train_all, 10)
x_test = x_test_all
# test labels must come from the test split so they line up with x_test
y_test = y_test_all
vol_shape = list(x_train.shape[1:-1])

In [None]:
# prepare a simple generator.
def cond_template_gen(x, y, batch_size):
  """Yield endless random batches for `vxm.networks.ConditionalTemplateCreation`.

  Args:
    x: image array shaped (N, *vol_shape, C).
    y: conditioning array with N rows (e.g. one-hot labels); row i belongs to
      image x[i], and both are indexed with the same random indices.
    batch_size: number of images drawn (uniformly, with replacement) per batch.

  Yields:
    (inputs, outputs) tuples where
      inputs  = [y[idx], atlas, img]: the conditions, the mean of all of `x`
                over axis 0 repeated batch_size times, and the sampled images;
      outputs = [img, zero, zero, zero]: the sampled images, then three
                all-zero arrays shaped (batch_size, *vol_shape, ndims).
                NOTE(review): presumably dummy targets for the model's
                deformation-related outputs — confirm against the network.
  """
  vol_shape = list(x.shape[1:-1])
  # One channel per spatial dimension (2 for 2D, 3 for 3D) rather than a
  # hard-coded 2 that only fits 2D data.
  ndims = len(vol_shape)
  zero = np.zeros([batch_size] + vol_shape + [ndims])
  atlas = np.repeat(np.mean(x, 0, keepdims=True), batch_size, 0)

  while True:
    idx = np.random.randint(0, x.shape[0], batch_size)
    img = x[idx, ...]
    inputs = [y[idx, ...], atlas, img]

    outputs = [img, zero, zero, zero]
    yield inputs, outputs

# sanity check: draw one batch of 8 and display the input and output shapes
sample = next(cond_template_gen(x_train, y_train_onehot, 8))
[f.shape for f in sample[0]], [f.shape for f in sample[1]]

## Model

In [None]:
# U-Net feature counts for the conditional model. They are bound to the names
# the model call actually uses; they were previously named nf_enc/nf_dec and
# left unused, so the stale enc_nf/dec_nf from the brain section were used instead.
enc_nf = [16,32,32,32]
dec_nf = [32,32,32,32,16,16,3]
# pheno_input_shape=[10] matches the 10-class one-hot digit condition.
model = vxm.networks.ConditionalTemplateCreation(vol_shape, pheno_input_shape=[10], nb_unet_features=[enc_nf, dec_nf], conv_nb_features=16,
                                                 conv_image_shape = [4, 4, 8], conv_nb_levels=4)
# model.summary()

In [None]:
# prepare losses
# One loss (and weight) per model output, in output order: four in total,
# matching the four targets yielded by cond_template_gen.
image_loss_func = vxm.losses.MSE().loss
losses = [image_loss_func, vxm.losses.MSE().loss, vxm.losses.Grad('l2', loss_mult=2).loss, vxm.losses.MSE().loss]
# NOTE(review): an earlier comment here said "changed second-last to 0.01", but
# 0.01 is the second weight and the second-last is 0.03; the last loss is
# disabled (weight 0). Confirm which weight was meant to change.
loss_weights = [1, 0.01, 0.03, 0]


model.compile('adam', loss=losses, loss_weights=loss_weights)

In [None]:
# fit
# 100 epochs x 25 steps of 32 random digits (all classes) with their one-hot labels.
gen = cond_template_gen(x_train, y_train_onehot, batch_size=32)
hist = model.fit(gen, epochs=100, steps_per_epoch=25, verbose=0, callbacks=[tqdm_cb])

In [None]:
# loss over all epochs (left) and over the last half of the epochs (right)
plot_hist(hist)

## Visualize atlas

In [None]:
# Sub-model that maps the first two inputs of `model` to the output of its
# layer named 'atlas'. NOTE(review): presumably those inputs are (condition,
# mean atlas), the same order cond_template_gen yields — confirm.
atlas_model = tf.keras.models.Model(model.inputs[:2], model.get_layer('atlas').output)

In [None]:
# One query per digit 0-9: a one-hot condition plus the mean training image.
mean_atlas = np.repeat(np.mean(x_train, 0, keepdims=True), 10, 0)
input_samples = [tf.keras.utils.to_categorical(np.arange(10), 10), mean_atlas]

In [None]:
# Predict one conditional atlas per digit and plot them side by side.
pred = atlas_model.predict(input_samples)
ne.plot.slices([f.squeeze() for f in pred], cmaps=['gray']);

## Video: Visualize conditional atlas in video

In [None]:
# OpenCV is only needed to write the video file in this section.
!pip install opencv-python
import cv2

In [None]:
# NOTE(review): the name mentions "age", but this section sweeps MNIST digit
# conditions; consider a more accurate file name.
output_video_filename = 'age_evolution.mp4'

In [None]:
nb_frames = 100
fps = 5

# create input samples.
# Since we're dealing with categorical here with MNIST, we'll make a fake continuous space.
# The result won't be sensible but it will give an idea of using videos.
# t sweeps [0, 10) over nb_frames. Frame t uses the one-hot vector of digit
# floor(t) scaled by frac(t), so each digit's condition ramps from 0 towards 1.
linspace = np.linspace(0, 10 - 1e-7, nb_frames)
pheno = tf.keras.utils.to_categorical(np.floor(linspace), 10) * (linspace - np.floor(linspace))[..., np.newaxis]
mean_atlas = np.repeat(np.mean(x_train, 0, keepdims=True), nb_frames, 0)

# get the atlas predictions
input_samples = [pheno, mean_atlas]
pred = atlas_model.predict(input_samples, batch_size=32)

# write file
# OpenCV's frameSize is (width, height), i.e. the reverse of vol_shape's
# (rows, cols); passing tuple(vol_shape) only worked because MNIST is square.
frame_size = (vol_shape[1], vol_shape[0])
out = cv2.VideoWriter(output_video_filename, cv2.VideoWriter_fourcc(*'MP4V'),
                      fps, frame_size, isColor=False)
try:
  for i in range(nb_frames):
    # single-channel frame: clip to [0, 1], scale to 8-bit grayscale
    frame = (np.clip(pred[i, ..., 0], 0, 1)*255).astype('uint8')
    out.write(frame)
finally:
  # release the writer even if writing fails, so the file is finalized/closed
  out.release()

In [None]:
# get file
# Google Colab only: triggers a browser download of the video (google.colab is
# not importable elsewhere). Note this binds the global name `files`.
from google.colab import files
files.download(output_video_filename) 

---

# Multi-Modal atlas

We're going to simulate a couple of variants of multi-modal atlases.  
To simulate 'modalities', we'll use MNIST, and intensity-inverted MNIST. 

## Unpaired data variant (conditional on modality)

We want to test building a multi-modal atlas with *unpaired* multi-modal data.  

To simulate this, we'll extract the images of digit 3 and **split** them into two groups: the first group keeps the images as they are, while the second group uses intensity-inverted images.

Since the images are unpaired, we can simply learn a conditional template where the condition is the modality (one-hot encoded: 0 for the original images, 1 for the inverted ones).

If they were paired (see below), we could take advantage of the pairing by learning a single atlas with two channels (`src_feats=2`).

In [None]:
# extract data
# Digit-3 images, their intensity-inverted copies (the simulated second
# "modality"), and an all-zero label array of the same length and dtype as the
# digit-3 labels (modality 0).
x_train3 = x_train_all[y_train_all == 3]
x_train3_inv = 1.0 - x_train3
y_train3 = np.zeros_like(y_train_all[y_train_all == 3])

In [None]:
# create unpaired data
# Even-indexed 3s stay as they are (label 0) and odd-indexed 3s come from the
# inverted copy (label 1), so no image index appears in both modalities.
x_train3_mixed = np.concatenate([x_train3[::2, ...], x_train3_inv[1::2, ...]], 0)
y_train3_mixed = np.concatenate([y_train3[::2], 1 + y_train3[1::2]], 0)
y_train3_mixed_onehot = tf.keras.utils.to_categorical(y_train3_mixed, 2)

### Model

In [None]:
# Conditional template where the condition is the 2-class one-hot modality
# label (pheno_input_shape=[2]) and images have a single channel (src_feats=1).
enc_nf = [16,32,32,32]
dec_nf = [32,32,32,32,16,16,3] 
model = vxm.networks.ConditionalTemplateCreation(vol_shape, 
                                                 pheno_input_shape=[2], 
                                                 src_feats=1,
                                                 nb_unet_features=[enc_nf, dec_nf], 
                                                 conv_nb_features=16,
                                                 conv_image_shape=[4, 4, 8], 
                                                 conv_nb_levels=4)
# model.summary()

In [None]:
# prepare losses
# Same four losses and weights as the conditional MNIST model (last one disabled with weight 0).
image_loss_func = vxm.losses.MSE().loss
losses = [image_loss_func, vxm.losses.MSE().loss, vxm.losses.Grad('l2', loss_mult=2).loss, vxm.losses.MSE().loss]
loss_weights = [1, 0.01, 0.03, 0]  

model.compile('adam', loss=losses, loss_weights=loss_weights)

In [None]:
# fit
# 100 epochs x 25 steps of 16 mixed-modality 3s with their one-hot modality labels.
gen = cond_template_gen(x_train3_mixed, y_train3_mixed_onehot, batch_size=16)
hist = model.fit(gen, epochs=100, steps_per_epoch=25, verbose=0, callbacks=[tqdm_cb])

In [None]:
# loss over all epochs (left) and over the last half of the epochs (right)
plot_hist(hist)

### Visualize atlas

In [None]:
# Rebuild the atlas sub-model for the newly trained `model`: its first two
# inputs (presumably condition and mean atlas) -> output of the 'atlas' layer.
atlas_model = tf.keras.models.Model(model.inputs[:2], model.get_layer('atlas').output)

In [None]:
# One query per modality: a one-hot modality condition plus the mean mixed image.
mean_atlas = np.repeat(np.mean(x_train3_mixed, 0, keepdims=True), 2, 0)
input_samples = [tf.keras.utils.to_categorical(np.arange(2), 2), mean_atlas]

In [None]:
# Predict and plot the atlas for each modality side by side.
pred = atlas_model.predict(input_samples)
ne.plot.slices([f.squeeze() for f in pred], cmaps=['gray'], width=6);

## Paired data variant (Multi-channel atlas)

Assuming we have paired data:

In [None]:
# 2-channel data.
# Pair every 3 with its own inverted copy along the channel axis:
# (N, 32, 32, 1) + (N, 32, 32, 1) -> (N, 32, 32, 2).
x_train3_2channel = np.concatenate([x_train3, x_train3_inv], -1)
x_train3_2channel.shape  # making sure

In [None]:
# unfortunately we had a bug in the pypi version of voxelmorph for atlas_feats
# let's get the dev branches
!pip uninstall voxelmorph neurite pystrum -y 
!git clone -b dev --single-branch https://github.com/voxelmorph/voxelmorph
!git clone -b dev --single-branch https://github.com/adalca/neurite
!git clone -b dev --single-branch https://github.com/adalca/pystrum

import sys
sys.path = ['voxelmorph', 'neurite', 'pystrum'] + sys.path

# fully unimport vxm, neurite, pystrum
lst = [f for f in sys.modules if f.startswith('voxelmorph')]
[sys.modules.pop(f) for f in lst]
lst = [f for f in sys.modules if f.startswith('neurite')]
[sys.modules.pop(f) for f in lst]
lst = [f for f in sys.modules if f.startswith('pystrum')]
[sys.modules.pop(f) for f in lst]

# reimport
import voxelmorph as vxm
import neurite as ne

### Model

In [None]:
# Unconditional template with a 2-channel atlas (atlas_feats=2) for 2-channel
# paired inputs (src_feats=2).
enc_nf = [16,32,32,32]
dec_nf = [32,32,32,32,16,16,3] 
model = vxm.networks.TemplateCreation(vol_shape, nb_unet_features=[enc_nf, dec_nf], atlas_feats=2, src_feats=2)
# model.summary()

In [None]:
# prepare losses and compile
# Same four losses/weights as the unconditional MNIST template (one per output).
image_loss_func = vxm.losses.MSE().loss
# Ignores y_true and compares the prediction against the atlas tensor of the
# global `model` (looked up when the loss runs — NOTE(review)).
neg_loss_func = lambda _, y_pred: image_loss_func(model.references.atlas_tensor, y_pred)
losses = [image_loss_func, neg_loss_func, vxm.losses.MSE().loss, vxm.losses.Grad('l2', loss_mult=2).loss]
loss_weights = [0.5, 0.5, 1, 0.01]

model.compile('adam', loss=losses, loss_weights=loss_weights)

In [None]:
# fit
# 100 epochs x 25 steps of 16 random 2-channel (original + inverted) 3s.
gen = template_gen(x_train3_2channel, batch_size=16)
hist = model.fit(gen, epochs=100, steps_per_epoch=25, verbose=0, callbacks=[tqdm_cb])

In [None]:
# loss over all epochs (left) and over the last half of the epochs (right)
plot_hist(hist)

### Visualize

In [None]:
# visualize learned atlas
# Keep both channels this time. NOTE(review): presumably channel 0 corresponds
# to the original 3s and channel 1 to the inverted ones, following the channel
# order of x_train3_2channel.
atlas = model.references.atlas_layer.get_weights()[0]
print(atlas.shape)
ne.plot.slices([atlas[..., 0], atlas[..., 1]], cmaps=['gray'], width=6);

In [None]:
# The atlas is a bit blurry, but not too bad. The hyperparameters probably need some tuning.