# Example of automatic prostate mpMRI registration

This Notebook can be used to automatically (or semi-automatically) register different imaging modalities. Here, it will be used to automatically register the prostate MR ADC map to the MR T2 sequence.

To get some images to test this Notebook on, please run ``ProstateX_processing.ipynb`` first, with ``apply_registration`` set to ``False``.

We will need ``plot_lib`` to visualize the images as they are processed:
 - `plot_lib`: https://github.com/OscarPellicer/plot_lib

In [1]:
#Import plot_lib
from pathlib import Path
import sys, os
sys.path.append(os.path.join(Path.home(), 'plot_lib'))
from plot_lib import plot_alpha, plot_multi_mask, plot, plot4

#Some CSS to allow images to display side by side by default
br= lambda: print(' '*100) #Insert a line that breaks flexbox wrapping
from IPython.display import display, HTML
CSS = """.output { flex-direction: row; flex-wrap: wrap; }
         .widget-hslider { width: auto; } """
HTML('<style>{}</style>'.format(CSS))

Load required libraries

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

#Basic libraries
import SimpleITK as sitk
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pydicom
import os
import glob
import pickle
from functools import partial

#Data processing lib
from preprocessing_lib import (ImageDataset, info as info_sitk, grow_regions_sitk,
                              join_sitk_images, join_masks, rescale_intensity, 
                              center_image, get_blank_image,
                              ProgressBar, EasyTimer)
from reg_lib import (register_rigid, register_spline, get_gradient_features, evaluate_registration, 
                     save_transform_auto, save_transform)

#Show interactive buttons
import ipywidgets as widgets
from IPython.display import display

## Configuration

### How it works

The last cell of this Notebook loops over ``patient_ids``, calling ``get_data(patient_id)`` for every patient. This function returns all the information needed to register that patient (at least a ``fixed_image`` and a ``moving_image``).

Then, the registration algorithm attempts to find the transform that makes ``moving_image`` match ``fixed_image`` as closely as possible in terms of the ``Mattes Mutual Information`` metric. Two registration methods are available: ``register_rigid`` and ``register_spline``. 

The algorithm repeats this registration ``RUNS`` (e.g. 50) times, with a different random initialization each time. The runs are then ranked by the sum of a custom metric (defined in function ``evaluate_registration``) and half of the ``Mattes Mutual Information`` metric, where lower is better, and the best ``PLOT_BEST_N`` (e.g. 2) registrations are plotted. 
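The ranking step can be illustrated in isolation. This minimal sketch mirrors the sort performed in the registration loop. Each run is stored as ``(custom_metric, transform, values, metric/2)``, and the runs are sorted in ascending order of ``custom_metric + metric/2``, so the first entry is the one that gets saved. The helper names and the numbers are made up for illustration.

```python
from typing import Any, List, Tuple

# One entry per run, as appended to `results` in the registration loop:
# (custom_metric, transform, values, mattes_metric / 2)
RunResult = Tuple[float, Any, Any, float]


def combined_score(run: RunResult) -> float:
    """Ranking score used by the loop: custom metric + half the Mattes metric."""
    custom_metric, _transform, _values, half_mattes = run
    return custom_metric + half_mattes


def rank_runs(results: List[RunResult], best_n: int) -> List[RunResult]:
    """Sort runs by ascending score (lower is better) and keep the best_n."""
    return sorted(results, key=combined_score)[:best_n]


# Made-up values for illustration; transforms are replaced by strings
runs = [
    (0.30, 'tfm_a', None, -0.40),  # score -0.10
    (0.10, 'tfm_b', None, -0.35),  # score -0.25
    (0.50, 'tfm_c', None, -0.20),  # score  0.30
]
best = rank_runs(runs, best_n=2)
print([run[1] for run in best])  # ['tfm_b', 'tfm_a']
```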

By default, the best registration is kept. However, after the Notebook has finished, buttons allow you to manually keep another registration that looks better (Button: ``Save this one``), or to keep none of them (Button: ``Original was better``), in which case an identity transformation is saved instead. Please note that these buttons will not work while the registration loop is still running!

In [3]:
#Base path
BASE_PATH= ''

#Path where images are located
IMAGES_PATH= os.path.join(BASE_PATH, 'out_unregistered')

#Create directory to save transforms
transform_dir= os.path.join(BASE_PATH, 'transforms2')
os.makedirs(transform_dir, exist_ok=True)

#Create directory to save registration samples to quickly visualize results if SAVE_IMGS==True
image_save_dir= os.path.join(BASE_PATH, 'registration_pngs')
os.makedirs(image_save_dir, exist_ok=True)

#Main configuration parameters
REGISTRATION_CHANNEL= 3  #Values: 1: B500, 2:B800+, 3:ADC | Default: 3
METHOD= register_spline  #Values: register_rigid, register_spline | Default: register_spline
RUNS= 50                 #Number of runs for the registration algorithm
SAVE_IMGS= False         #Save images with the registration results
PLOT_BEST_N= 2           #How many results to show

#Create an easy timer for timing registration times
ET= EasyTimer()

Configure the data loading routines. We must define at least the following function:
 - `get_data(patient_id)` returns:
    1. Fixed image with respect to which registration is performed
    1. Moving image to which the transformation is to be applied
    1. List of masks on which to evaluate the custom metric
    1. List of factors for the custom metric (one per mask)
 
Also, we must set the following list:
 - `patient_ids`: IDs over which to iterate and pass to `get_data`

In [4]:
#Define get_data function
def get_data(pid):
    #Load image and masks
    #Please, work with medical imaging formats (dicom, nrrd, etc.) when possible
    #to avoid problems with physical position, orientation, and spacing
    spacing= (0.5, 0.5, 3)
    img= sitk.GetImageFromArray(np.load(os.path.join(IMAGES_PATH, pid + '_img.npy')))
    img.SetSpacing(spacing)
    
    #The fixed image is the T2
    fixed_image= sitk.VectorIndexSelectionCast(img, 0)
    
    #The image to be registered is at the REGISTRATION_CHANNEL
    moving_image= sitk.VectorIndexSelectionCast(img, REGISTRATION_CHANNEL)
    
    #Obtain the masks to apply the custom metric using them
    cz_mask= sitk.VectorIndexSelectionCast(img, 6)
    pz_mask= sitk.VectorIndexSelectionCast(img, 7)
    
    return fixed_image, moving_image, [pz_mask, cz_mask], [1.5,0.75]
    
#Obtain the patient_ids from the names of the images in IMAGES_PATH
patient_ids= [f[:-len('_img.npy')] for f in os.listdir(IMAGES_PATH) if f.endswith('_img.npy')]

#Show some info
print('Patient IDs to register (N=%d):\n'%len(patient_ids), patient_ids)

Patient IDs to register (N=21):
 ['ProstateX-0000', 'ProstateX-0001', 'ProstateX-0002', 'ProstateX-0003', 'ProstateX-0004', 'ProstateX-0005', 'ProstateX-0006', 'ProstateX-0007', 'ProstateX-0008', 'ProstateX-0009', 'ProstateX-0010', 'ProstateX-0011', 'ProstateX-0012', 'ProstateX-0013', 'ProstateX-0014', 'ProstateX-0015', 'ProstateX-0016', 'ProstateX-0017', 'ProstateX-0018', 'ProstateX-0019', 'ProstateX-0020']


## Registration

Perform the registration

In [None]:
for pid in patient_ids:
    #Check if transform already exists and ignore patient
    if os.path.exists(os.path.join(transform_dir, pid + '.tfm')):
        continue
    else:
        #Load fixed, moving and mask
        print(pid, '-'*20); br()
        try:
            fixed_image, moving_image, mask_list, factors= get_data(pid)
        except Exception as e:
            print('An exception occurred loading the data:', e)
            save_transform_auto(pid, sitk.Euler3DTransform(), transform_dir)
            continue

        #Get lists of results + utilities
        results=[]
        ET.reset()
        PB= ProgressBar(RUNS)

        #Loop
        for i in range(RUNS):
            PB.go(i)
            try:
                #Gradient computation
                fixed_image_reg= get_gradient_features(fixed_image) 
                moving_image_reg= -get_gradient_features(moving_image)

                #Perform registration
                transform, metric= METHOD(fixed_image_reg, moving_image_reg, show_progress=False, verbose=False)
                registered_image= sitk.Resample(moving_image, fixed_image, transform, 
                         sitk.sitkLinear, 0.0, moving_image.GetPixelID())

                #Get custom metric
                registered_image_reg= -get_gradient_features(registered_image)
                custom_metric, values= evaluate_registration(fixed_image_reg, moving_image_reg, 
                                                             registered_image_reg, mask_list, factors)

                #Save in list
                results.append((custom_metric, transform, values, metric/2))

            except Exception as e:
                print('Exception:', e)
                raise e

        print(' ', end=''); ET.time()

        #Sort and save best
        results= sorted(results, key= lambda i: i[0] + i[3]) #0: custom metric, 3: Mattes metric / 2 (lower is better)
        save_transform_auto(pid, results[0][1], transform_dir) #First (0), transform item (1)
        br()

        #Plot T2
        plot(fixed_image, masks=mask_list, title='Reference', 
             save_as='%s/%s_T2'%(image_save_dir, pid) if SAVE_IMGS else None)

        #Plot the best PLOT_BEST_N registrations
        for i, (custom_metric, transform, values, metric) in enumerate(results[:PLOT_BEST_N]):

            print('Metrics:', metric, custom_metric); br()
            registered_image= sitk.Resample(moving_image, fixed_image, transform, 
                         sitk.sitkLinear, 0.0, moving_image.GetPixelID())

            #Plot images to compare
            plot(moving_image, masks=mask_list, title='Before', 
                 save_as='%s/%s_before'%(image_save_dir, pid) if i==0 and SAVE_IMGS else None)
            plot(registered_image, masks=mask_list, title='After', 
                 save_as='%s/%s_after'%(image_save_dir, pid) if i==0 and SAVE_IMGS else None)

            plot_alpha(fixed_image, moving_image, masks=mask_list, title='Before', color='r', alpha=0.5)
            plot_alpha(fixed_image, registered_image, masks=mask_list, title='After', color='r', alpha=0.5)

            #Add button to save transform
            button = widgets.Button(description='Save this one')
            button.on_click(partial(save_transform, pid=pid, transform=sitk.Transform(transform), 
                                    transform_dir=transform_dir))
            display(button)

        button = widgets.Button(description='Original was better')
        button.on_click(partial(save_transform, pid=pid, transform=sitk.Euler3DTransform(), 
                                transform_dir=transform_dir))
        display(button)

        #break