# VoxelMorph Inference

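This example is based on the [Learn2Reg](https://www.kaggle.com/adalca/learn2reg) notebook. It registers pairs of scans with a pretrained VoxelMorph model and scores each registration by the Dice overlap (DSC) between the fixed segmentation and the warped moving segmentation.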

## Config

In [None]:
# Pretrained VoxelMorph model file read by the "Load Model" cell.
vxm_model_file = "vxm/models/cvpr2018_vm2_l2.h5" # MICCAI18: "vxm/models/miccai2018_10_02_init1.h5"

# Directory of subject scans. Its sorted listing is consumed in pairs (the
# segmentation is expected to sort directly before its image). Keep the
# trailing slash: file names are appended by plain string concatenation.
scans_path = "/mnt/h/ml/ventricles/scans/"
# Directory holding the MNI ICBM152 template, mask and CerebrA label files.
mni_path = "/mnt/h/ml/"

# DSC matrix (fixed x moving) after registration.
registered_output = "registered.csv"
# DSC matrix (fixed x moving) without registration.
baseline_output = "unregistered.csv"
# Element-wise registered - unregistered DSC.
difference_output = "diff.csv"
# DSC of the warped MNI atlas segmentation against each scan.
mni_output = "mni.csv"

## Set Up

In [None]:
!pip install numpy scipy sklearn nibabel matplotlib pprint tqdm

# Imports
import os, sys
import csv

# Third party imports
import numpy as np
import nibabel as nib
import keras.layers
import keras.models
from scipy import ndimage as nd

# Local imports
sys.path.append("vxm/ext/pynd-lib/")
sys.path.append("vxm/ext/pytools-lib/")
sys.path.append("vxm/ext/neuron/")
sys.path.append("vxm/voxelmorph/")
sys.path.append("vxm/voxelmorph/voxelmorph/")

import voxelmorph as vxm
import neuron
import neuron.layers as nrn_layers

# Limit GPU memory: this process may claim at most a 0.7 fraction of it
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session

gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7)
config = tf.ConfigProto(gpu_options=gpu_options)
set_session(tf.Session(config=config))

# Utilities
def padwidth(width):
    """
    Splits a size difference into (before, after) padding widths

    Negative differences are clamped to zero. For an odd difference the
    first element of the returned pair receives the extra unit.
    """
    half = np.max((0, width)) / 2
    before = int(np.ceil(half))
    after = int(np.floor(half))
    return before, after

def cropwidth(width):
    """
    Splits a size difference into (before, after) crop widths

    Only negative differences (input larger than the target) require
    cropping; zero and positive differences yield (0, 0). For an odd
    difference the first element of the returned pair receives the extra
    unit.
    """
    # The original read an undefined name here, so any call raised NameError.
    excess = np.abs(np.min((0, width))) / 2
    return int(np.ceil(excess)), int(np.floor(excess))

def padcrop(img, x=256, y=256, z=256):
    """
    Pads or crops a rescaled scan to given target dimensions x,y,z

    Axes that are too long are trimmed on both sides (see `cropwidth`), axes
    that are too short are zero-padded on both sides (see `padwidth`). The
    result is always a new array of shape (x, y, z) created with `np.zeros`.
    """
    out = np.zeros((x, y, z))
    deltas = np.array((x, y, z)) - np.array(img.shape)

    # Trim first so the padding step only ever sees axes that fit
    if np.any(deltas < 0):
        (x0, x1), (y0, y1), (z0, z1) = [cropwidth(delta) for delta in deltas]
        body = img[x0:(img.shape[0] - x1),
                   y0:(img.shape[1] - y1),
                   z0:(img.shape[2] - z1)]
    else:
        body = img

    if np.any(deltas > 0):
        body = np.pad(body, [padwidth(delta) for delta in deltas], mode="constant")

    out[:, :, :] = body
    return out

def load_nii_unnormalized(path):
    """
    Loads a NIfTI volume and resamples it onto the working grid, leaving
    its values unscaled

    The volume is rotated by 90 degrees in the plane of axes 1 and 2, zoomed
    by one uniform factor (corrected by the per-axis voxel size from the
    header) chosen so that no axis exceeds `vol_shape`, and finally
    padded/cropped to exactly `vol_shape`.

    Relies on the module-level `vol_shape`, which must be defined before the
    first call.

    path: file to load with nibabel
    returns: array of shape `vol_shape`
    """
    nii = nib.load(path)
    pixdim = nii.header["pixdim"]
    # dtype-preserving replacement for the deprecated nii.get_data()
    img = np.asanyarray(nii.dataobj)
    img = nd.rotate(img, 90, (1, 2))

    # Voxel sizes of axes 1 and 2 trade places to follow the rotation above
    pixdim = [pixdim[1], pixdim[3], pixdim[2]]
    scale_factor = min(a / (b * c) for (a, b, c) in zip(vol_shape, img.shape, pixdim))
    img = nd.zoom(
        img,
        [scale_factor * size for size in pixdim],
        mode="nearest",
        prefilter=False)
    # Pad/crop to the same shape the scale factor was derived from
    return padcrop(img, *vol_shape)

def load_nii(path):
    """
    Loads a volume with `load_nii_unnormalized` and divides it by its maximum

    An all-zero volume is returned as float zeros rather than being divided
    by zero, which would fill it with NaNs.
    """
    img = load_nii_unnormalized(path).astype("float")
    peak = np.max(img)
    if peak == 0:
        return img
    return img / peak

def dsc(y_true, y_pred):
    """
    Dice similarity coefficient: 2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred))

    The denominator is floored at 1e-5 so two empty inputs give 0 instead of
    a division by zero.
    """
    overlap = np.sum(y_true * y_pred)
    total = np.sum(y_true) + np.sum(y_pred)
    denominator = max(total, 1e-5)
    return 2 * overlap / denominator

## Load Model

In [None]:
# Our data will be of shape 160 x 192 x 224
vol_shape = [160, 192, 224]
ndims = len(vol_shape)

# Custom layer classes that Keras must resolve by name when loading the model
custom_objects = {
    "SpatialTransformer": nrn_layers.SpatialTransformer,
    "VecInt": nrn_layers.VecInt,
    "Sample": vxm.networks.Sample,
    "Rescale": vxm.networks.RescaleDouble,
    "Resize": vxm.networks.ResizeDouble,
    "Negate": vxm.networks.Negate,
}
# Loss functions are registered the same way; the constructor values shouldn't matter
custom_objects.update({
    "recon_loss": vxm.losses.Miccai2018(0.02, 10).recon_loss,
    "kl_loss": vxm.losses.Miccai2018(0.02, 10).kl_loss,
})

# Registration network, plus a separate model used to apply a predicted warp
vxm_model = keras.models.load_model(vxm_model_file, custom_objects=custom_objects)
warp_model = vxm.networks.nn_trf(vol_shape)

## Test

In [None]:
# Pairwise registration: every scan is used as the fixed image once and every
# scan is warped onto it; the DSC of the warped segmentation is written as a
# matrix (row = fixed, column = moving).
scan_names = sorted(os.listdir(scans_path))
# The sorted listing is read in pairs: index k = segmentation, k + 1 = image
pair_starts = range(0, len(scan_names) - 1, 2)

# newline="" lets the csv module control line endings itself
with open(registered_output, "w", newline="") as csvfile:
    csvwriter = csv.writer(csvfile)
    # "F\\M" is the same text as before, without the invalid "\M" escape
    csvwriter.writerow(["F\\M"] + [i / 2 for i in pair_starts])

    for i in pair_starts:
        fixed_val = load_nii(os.path.join(scans_path, scan_names[i + 1]))
        fixed_seg = load_nii(os.path.join(scans_path, scan_names[i]))

        print(scan_names[i + 1])
        row = [i / 2]
        for j in pair_starts:
            moving_val = load_nii(os.path.join(scans_path, scan_names[j + 1]))
            moving_seg = load_nii(os.path.join(scans_path, scan_names[j]))

            # The model takes [moving, fixed], each with batch and channel axes added
            val_input = [moving_val[np.newaxis, ..., np.newaxis], fixed_val[np.newaxis, ..., np.newaxis]]
            val_pred = vxm_model.predict(val_input)
            # Second model output is reused below as the warp for the segmentation
            pred_warp = val_pred[1]

            warped_seg = warp_model.predict([moving_seg[np.newaxis, ..., np.newaxis], pred_warp])
            dice = dsc(fixed_seg, warped_seg.squeeze())
            row.append(dice)
        csvwriter.writerow(row)

In [None]:
# Baseline: DSC between every pair of segmentations without any registration,
# written as a matrix (row = fixed, column = moving).
scan_names = sorted(os.listdir(scans_path))
# The sorted listing is read in pairs: index k = segmentation, k + 1 = image
pair_starts = range(0, len(scan_names) - 1, 2)

# newline="" lets the csv module control line endings itself
with open(baseline_output, "w", newline="") as csvfile:
    csvwriter = csv.writer(csvfile)
    # "F\\M" is the same text as before, without the invalid "\M" escape
    csvwriter.writerow(["F\\M"] + [i / 2 for i in pair_starts])

    for i in pair_starts:
        fixed_seg = load_nii(os.path.join(scans_path, scan_names[i]))

        # Progress output names the image that belongs to this segmentation
        print(scan_names[i + 1])
        row = [i / 2]
        for j in pair_starts:
            moving_seg = load_nii(os.path.join(scans_path, scan_names[j]))

            dice = dsc(fixed_seg, moving_seg)
            row.append(dice)
        csvwriter.writerow(row)

In [None]:
# Element-wise improvement of the registered DSC over the unregistered baseline.
# Both inputs are read row by row in lockstep; the header and the row labels
# are taken from the baseline file.
with open(baseline_output, "r", newline="") as unreg, \
        open(registered_output, "r", newline="") as reg, \
        open(difference_output, "w", newline="") as diff:
    reader_u = csv.reader(unreg)
    reader_r = csv.reader(reg)
    writer = csv.writer(diff)

    writer.writerow(next(reader_u))
    next(reader_r)  # skip the registered header
    for row_u in reader_u:
        row_r = next(reader_r)
        # The first column is the row label, so only the remaining cells are numeric
        deltas = [float(r) - float(u) for (u, r) in zip(row_u[1:], row_r[1:])]
        writer.writerow([row_u[0]] + deltas)
        

In [None]:
# Atlas-based segmentation: warp the MNI atlas labels onto every scan and
# record the DSC against that scan's own segmentation.
scan_names = sorted(os.listdir(scans_path))

# Moving image: the MNI T1 template multiplied by its mask
moving_val = load_nii(os.path.join(mni_path, "mni_icbm152_t1_tal_nlin_sym_09a.nii")) * \
    load_nii(os.path.join(mni_path, "mni_icbm152_t1_tal_nlin_sym_09a_mask.nii"))

# Loaded unnormalized so the raw label values survive; labels 41 and 92 form
# the binary moving segmentation (presumably the structure of interest in
# both hemispheres - confirm against the CerebrA label table).
moving_seg = load_nii_unnormalized(os.path.join(mni_path, "mni_icbm152_CerebrA_tal_nlin_sym_09c.nii"))
moving_seg = np.logical_or(moving_seg == 41, moving_seg == 92)

# newline="" lets the csv module control line endings itself
with open(mni_output, "w", newline="") as csvfile:
    csvwriter = csv.writer(csvfile)
    csvwriter.writerow(["Fixed Image", "Accuracy [DSC]"])

    # The sorted listing is read in pairs: index j = segmentation, j + 1 = image
    for j in range(0, len(scan_names) - 1, 2):
        # Fixed: these lines used the undefined name `scan_path` and raised NameError
        fixed_val = load_nii(os.path.join(scans_path, scan_names[j + 1]))
        fixed_seg = load_nii(os.path.join(scans_path, scan_names[j]))

        val_input = [moving_val[np.newaxis, ..., np.newaxis], fixed_val[np.newaxis, ..., np.newaxis]]
        val_pred = vxm_model.predict(val_input)
        # Second model output is reused below as the warp for the segmentation
        pred_warp = val_pred[1]

        warped_seg = warp_model.predict([moving_seg[np.newaxis, ..., np.newaxis], pred_warp])
        dice = dsc(fixed_seg, warped_seg.squeeze())
        csvwriter.writerow([scan_names[j + 1], dice])

## Visualize

In [None]:
import matplotlib.pyplot as plt

# Example subject, resolved through the configured scan directory instead of
# a second hard-coded copy of that path
fixed_val = load_nii(os.path.join(scans_path, "ventricle_1_t1.nii.gz"))
fixed_seg = load_nii(os.path.join(scans_path, "ventricle_1_seg.nii.gz"))

# Moving image: the MNI T1 template multiplied by its mask
moving_val = load_nii(os.path.join(mni_path, "mni_icbm152_t1_tal_nlin_sym_09a.nii")) * \
    load_nii(os.path.join(mni_path, "mni_icbm152_t1_tal_nlin_sym_09a_mask.nii"))
# Binary moving segmentation built from atlas labels 41 and 92
moving_seg = load_nii_unnormalized(os.path.join(mni_path, "mni_icbm152_CerebrA_tal_nlin_sym_09c.nii"))
moving_seg = np.logical_or(moving_seg == 41, moving_seg == 92)

# Predict
val_input = [moving_val[np.newaxis, ..., np.newaxis], fixed_val[np.newaxis, ..., np.newaxis]]
val_pred = vxm_model.predict(val_input)
# First output: the moved image; second output: the warp reused for the segmentation
moved_pred = val_pred[0].squeeze()
pred_warp = val_pred[1]
warped_seg = warp_model.predict([moving_seg[np.newaxis, ..., np.newaxis], pred_warp])

# Save output
# nib.save(nib.Nifti1Image((fixed_val[0,...,0] * 256).astype(int), None), "fixed_val.nii.gz")
# nib.save(nib.Nifti1Image((fixed_seg[0,...,0] * 256).astype(int), None), "fixed_seg.nii.gz")
# nib.save(nib.Nifti1Image((warped_seg[0,...,0] * 256).astype(int), None), "warped_seg.nii.gz")

# Extract slices & plot
def _mid_slices(volume):
    """
    Returns the central slice of `volume` along each of the `ndims` axes,
    with the second and third slices rotated for display
    """
    slices = [np.take(volume, vol_shape[d]//2, axis=d) for d in range(ndims)]
    slices[1] = np.rot90(slices[1], 1)
    slices[2] = np.rot90(slices[2], -1)
    return slices

mid_slices_fixed = _mid_slices(fixed_val)
mid_slices_moving = _mid_slices(moving_val)
mid_slices_pred = _mid_slices(moved_pred)

# Grid rows: fixed, moved (prediction), moving
neuron.plot.slices(mid_slices_fixed + mid_slices_pred + mid_slices_moving, cmaps=["gray"], do_colorbars=True, grid=[3,3])

# Visualize DVF
flow = pred_warp[0, :, :, :, :]
flow_sd = np.std(flow)
# Symmetric colour range of one standard deviation around zero
v_args = {"cmap": "RdBu", "vmin": -flow_sd, "vmax": +flow_sd}
fig, m_axs = plt.subplots(3, 3, figsize=(20, 10))
# One row per flow component, one column per axis the component is averaged over
for i, (ax1, ax2, ax3) in enumerate(m_axs):
    component = flow[:, :, :, i]
    ax1.set_title("xyz"[i] + " flow")
    for mean_axis, panel in enumerate((ax1, ax2, ax3)):
        panel.imshow(np.mean(component, mean_axis), **v_args)
    
def meshgridnd_like(in_img, rng_func=range):
    """
    Builds one coordinate grid per axis of `in_img`, each with the same shape
    as `in_img`

    `rng_func` maps an axis length to the coordinate values along that axis
    (default: range). np.meshgrid's default layout has the first two axes
    exchanged, so they are swapped back.
    """
    axis_coords = [rng_func(axis_len) for axis_len in in_img.shape]
    grids = np.meshgrid(*axis_coords)
    return tuple(np.swapaxes(grid, 0, 1) for grid in grids)

# 3-D quiver plot of the displacement field, subsampled for legibility
from mpl_toolkits.mplot3d import axes3d
DS_FACTOR = 16  # keep every 16th voxel along each axis
c_xx, c_yy, c_zz = [x.flatten()
                    for x in
                    meshgridnd_like(flow[::DS_FACTOR, ::DS_FACTOR, ::DS_FACTOR, 0])]


def get_flow(i):
    """Returns the subsampled, flattened flow component `i`"""
    return flow[::DS_FACTOR, ::DS_FACTOR, ::DS_FACTOR, i].flatten()


fig = plt.figure(figsize=(10, 10))
# fig.gca(projection=...) is no longer accepted by current Matplotlib;
# add_subplot creates the same single 3-D axes
ax = fig.add_subplot(111, projection="3d")

ax.quiver(c_xx,
          c_yy,
          c_zz,
          get_flow(0),
          get_flow(1),
          get_flow(2),
          length=0.5,
          normalize=True)

In [None]:
from ipywidgets import IntSlider, interact

def plot(x):
    """
    Shows slice `x` along axis 1 for the images (top row: moving, moved,
    fixed) and for the matching segmentations (bottom row)
    """
    volumes = [moving_val, moved_pred, fixed_val,
               moving_seg, warped_seg.squeeze(), fixed_seg]
    fig, axs = neuron.plot.slices(
        [np.rot90(np.take(volume, x, axis=1), 1) for volume in volumes],
        cmaps=["gray"], do_colorbars=True, grid=[2,3])
    # fig.savefig("f0_m1_s{}_3up.png".format(x),bbox_inches='tight', dpi=300)

# A bare `x = 80` makes interact build a slider over [-80, 240], which runs
# past the end of axis 1; restrict it to valid slice indices, still starting at 80.
interact(plot, x=IntSlider(min=0, max=vol_shape[1] - 1, value=80))