## Image Registration

#### Register the images so that the ROIs are spatially aligned (the same ROI covers the same region in every scan).
- 1) 6 healthy volunteers: 1.5T vs. 3T.
- 2) For 1.5T with matrix = 256x256: compare FOV = 24 and FOV = 18.
- 3.1) Compare different matrices: for 1.5T with FOV = 24, compare matrix = 256x256, 256x128, and 128x128.
- 3.2) Compare different matrices: for 3T with FOV = 25.6, compare matrix = 320x320, 256x256, and 128x128.


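Registration tool: ANTsPy (`ants.registration` with a rigid transform). An earlier SimpleElastix implementation is kept, commented out, below; SimpleElastix documentation: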
https://simpleelastix.readthedocs.io/GettingStarted.html

Visualize the registration results:
https://officeguide.cc/python-simpleitk-tutorial-combine-scalar-images-create-color-image/


In [None]:
import SimpleITK as sitk
import pickle
import numpy as np
import ants
import gc

import shutil
import os
import nrrd
import nibabel as nib
from PIL import Image 
import  matplotlib.pyplot as plt

import sys
sys.path.append('../')
from utils.brainImageUtils import *
from utils.myUtils import *
from mySettings import get_preprocess_normalize_path_dict, get_preprocessing_type_list, get_experiment_key, get_path_dict_by_datatype

In [None]:
def register_images(fixed_image_path, moving_image_path, save_registered_dir, transform,
                    type_of_transform='Rigid'):
    """Register the moving image onto the fixed image with ANTsPy and save both.

    The fixed image is written back unchanged, and the registered (warped)
    moving image is written as well. Both go into ``save_registered_dir``
    under their original basenames. The fixed image is therefore rewritten
    on every call.

    Args:
        fixed_image_path: Path of the reference (fixed) image.
        moving_image_path: Path of the image to align to the fixed image.
        save_registered_dir: Directory the two output images are written to.
            This function does not create it.
        transform: Transform name taken from the registration info.
            NOTE: it does NOT select the registration type; it is kept only
            so existing callers keep working. Use ``type_of_transform``.
        type_of_transform: Forwarded to ``ants.registration``'s
            ``type_of_transform`` argument. Defaults to ``'Rigid'``, the value
            that was previously hard-coded.

    Returns:
        Tuple ``(saved_fixed_image_path, saved_registered_moving_image_path)``.
    """
    fixed_img = ants.image_read(fixed_image_path)
    moving_img = ants.image_read(moving_image_path)
    registered_results = ants.registration(fixed=fixed_img, moving=moving_img,
                                           type_of_transform=type_of_transform)
    # 'warpedmovout' is the moving image resampled into the fixed image's space.
    registered_moving_image = registered_results['warpedmovout']

    # Re-save the fixed image next to the registered moving image so that each
    # registered pair lives in the same directory.
    resave_registered_fixed_image_path = os.path.join(
        save_registered_dir, os.path.basename(fixed_image_path))
    resave_registered_moving_image_path = os.path.join(
        save_registered_dir, os.path.basename(moving_image_path))
    ants.image_write(fixed_img, resave_registered_fixed_image_path)
    ants.image_write(registered_moving_image, resave_registered_moving_image_path)

    return resave_registered_fixed_image_path, resave_registered_moving_image_path

    
    

In [None]:
# '''
# Register images and save the results.
# '''
# def register_images(fixed_image_path, moving_image_path, save_registered_dir, transform):

#     fixedImage = sitk.ReadImage(fixed_image_path)
#     movingImage = sitk.ReadImage(moving_image_path)

#     elastixImageFilter = sitk.ElastixImageFilter()
#     elastixImageFilter.SetFixedImage(fixedImage)
#     elastixImageFilter.SetMovingImage(movingImage)

#     if transform=='rigid':
#         elastixImageFilter.SetParameterMap(sitk.GetDefaultParameterMap('rigid'))
        
#     elif transform=='affine':
#         parameterMapVector = sitk.VectorOfParameterMap()
#         parameterMapVector.append(sitk.GetDefaultParameterMap("rigid"))
#         parameterMapVector.append(sitk.GetDefaultParameterMap("affine"))
#         elastixImageFilter.SetParameterMap(parameterMapVector)

#     elif transform=='non-rigid':
#         parameterMapVector = sitk.VectorOfParameterMap()
#         parameterMapVector.append(sitk.GetDefaultParameterMap("rigid"))
#         parameterMapVector.append(sitk.GetDefaultParameterMap("bspline"))
#         elastixImageFilter.SetParameterMap(parameterMapVector)


#     elastixImageFilter.Execute()
    
    
#     #resave the fixed image and the registered moving image.
#     resave_registered_fixed_image_path=save_registered_dir +'/' + os.path.basename(fixed_image_path)
#     resave_registered_moving_image_path=save_registered_dir +'/' + os.path.basename(moving_image_path)
#     sitk.WriteImage(fixedImage, resave_registered_fixed_image_path)
#     sitk.WriteImage(elastixImageFilter.GetResultImage(), resave_registered_moving_image_path)
    
#     return  resave_registered_fixed_image_path, resave_registered_moving_image_path

In [None]:
def check_image_hist(image_path, show=False):
    """Print the intensity range of an image and optionally plot its histogram.

    Args:
        image_path: Path of an image readable by ``sitk.ReadImage``.
        show: If True, also draw a 5x5 histogram of all voxel values and
            display it with ``plt.show()``.
    """
    loaded = sitk.ReadImage(image_path)
    # Keep ``loaded`` alive while its array view is in use.
    voxels = sitk.GetArrayViewFromImage(loaded)
    lowest = voxels.min()
    highest = voxels.max()
    print("Image: ({}, {})".format(lowest, highest))

    if not show:
        return
    plt.figure(figsize=(5, 5))
    plt.hist(voxels.flatten())
    plt.show()


In [None]:
def _image_stem(image_path):
    """Return the basename of ``image_path`` without its image-file extension.

    Handles the two-part ``.nii.gz`` extension as well as single extensions
    such as ``.nii`` or ``.nrrd``.
    """
    basename = os.path.basename(image_path)
    if basename.endswith('.nii.gz'):
        return basename[:-len('.nii.gz')]
    return os.path.splitext(basename)[0]


def _window_to_uint8(image):
    """Linearly map a SimpleITK image's full [min, max] range onto [0, 255] as UInt8."""
    # ``image`` stays referenced for the duration of this call, so the view is valid.
    pixels = sitk.GetArrayViewFromImage(image)
    windowed = sitk.IntensityWindowing(image,
                                       windowMinimum=float(pixels.min()),
                                       windowMaximum=float(pixels.max()),
                                       outputMinimum=0.0, outputMaximum=255.0)
    return sitk.Cast(windowed, sitk.sitkUInt8)


def show_registered_images(fixed_image_path, moving_image_path, registered_image_path, save_visualize_basic_path):
    """Save one PNG per slice comparing fixed, moving and registered images.

    Each PNG has four panels: the fixed image, the moving image, the
    registered image, and an RGB overlay. In the overlay, red is the fixed
    image, blue is the registered image, and green is their average.
    Where the two images agree, the pixel is grey.

    Files are named ``<moving image name without extension>_<slice>.png``
    and written into ``save_visualize_basic_path``.

    Args:
        fixed_image_path: Path of the fixed (reference) image.
        moving_image_path: Path of the original moving image, before registration.
        registered_image_path: Path of the moving image after registration.
        save_visualize_basic_path: Directory the PNG files are written to.
    """
    # nibabel arrays are indexed [x, y, z]; the third axis is the slice index.
    fixed_image = nib.load(fixed_image_path).get_fdata()
    moving_image = nib.load(moving_image_path).get_fdata()
    registered_image = nib.load(registered_image_path).get_fdata()

    # 8-bit versions are only needed for the RGB overlay (fixed vs. registered).
    fixed_uint8 = _window_to_uint8(sitk.ReadImage(fixed_image_path))
    registered_uint8 = _window_to_uint8(sitk.ReadImage(registered_image_path))

    # Previously basename[:-5], which mangled '.nii.gz' names into 'name.n'.
    stem = _image_stem(moving_image_path)
    # Every array indexed in the loop must have the slice.
    number_of_slices = min(fixed_image.shape[2], moving_image.shape[2], registered_image.shape[2])

    for index in range(number_of_slices):
        fixed_slice = fixed_uint8[:, :, index]
        registered_slice = registered_uint8[:, :, index]
        overlap = sitk.Cast(sitk.Compose(fixed_slice,
                                         fixed_slice * 0.5 + registered_slice * 0.5,
                                         registered_slice), sitk.sitkVectorUInt8)

        fig = plt.figure(figsize=(20, 5))
        # Transpose the nibabel slices so they display in the same
        # orientation as the SimpleITK array view, which is [y, x].
        panels = (("Fixed Image", fixed_image[:, :, index].T),
                  ("Moving Image", moving_image[:, :, index].T),
                  ("Registered Image", registered_image[:, :, index].T),
                  ("Overlap after registration", sitk.GetArrayViewFromImage(overlap)))
        for position, (title, pixels) in enumerate(panels, start=1):
            sub_fig = plt.subplot(1, 4, position)
            sub_fig.set_title(title, fontsize=20)
            plt.axis('off')
            plt.imshow(pixels)

        plt.subplots_adjust(left=0.03, bottom=0.03, right=0.97, top=0.9, wspace=0.03, hspace=0)
        save_path = os.path.join(save_visualize_basic_path, '{}_{}.png'.format(stem, index))
        plt.savefig(save_path)

        # One figure is created per slice. Close it explicitly (and only it)
        # to keep memory bounded over many slices.
        plt.close(fig)
        gc.collect()

In [None]:
def perform_registration(image_dir, save_registered_dir, registration_info):
    """Register each patient's moving images to that patient's fixed image and visualize them.

    For every moving image, this function:
      1. registers it to the fixed image;
      2. prints the intensity ranges of the fixed, original moving and
         registered images;
      3. saves per-slice comparison PNGs into
         ``save_registered_dir/<moving image name>``.

    Args:
        image_dir: Directory containing the input images, stored as
            ``<name>.nii.gz``.
        save_registered_dir: Directory that receives the registered images
            and the per-moving-image visualization folders.
        registration_info: Passed to ``parse_registration_info_txt``; TODO
            confirm its exact form (presumably a path to a text file). The
            parsed result maps each patient to a dict with the keys
            ``"transform"``, ``"fixed_image"`` and ``"moving_image_list"``.
    """
    # Use a distinct name so the per-patient dict below does not shadow the
    # ``registration_info`` parameter.
    patient_info_dict = parse_registration_info_txt(registration_info)
    num_patients = len(patient_info_dict)
    for i, (patient, patient_info) in enumerate(patient_info_dict.items(), start=1):
        print("\n=====registration: {}/{}=====".format(i, num_patients))

        transform = patient_info["transform"]
        fixed_image_path = os.path.join(image_dir, patient_info["fixed_image"] + '.nii.gz')

        for moving_image_name in patient_info["moving_image_list"]:
            moving_image_path = os.path.join(image_dir, moving_image_name + '.nii.gz')
            print('{}:{} \n Fixed image:{} \n Moving image:{} .'.format(patient, transform, fixed_image_path, moving_image_path))

            # Register images.
            resave_registered_fixed_image_path, resave_registered_moving_image_path = register_images(
                fixed_image_path, moving_image_path, save_registered_dir, transform)

            # Check intensity ranges: fixed, original moving, registered moving.
            check_image_hist(resave_registered_fixed_image_path)
            check_image_hist(moving_image_path)
            check_image_hist(resave_registered_moving_image_path)

            # Visualize the registration results.
            save_visualize_basic_path = os.path.join(save_registered_dir, os.path.basename(moving_image_name))
            mkdir(save_visualize_basic_path)
            show_registered_images(resave_registered_fixed_image_path, moving_image_path,
                                   resave_registered_moving_image_path, save_visualize_basic_path)

### main

In [None]:
def main_registration():
    """Run registration for every preprocessing type and every experiment key.

    For each (preprocessing type, experiment key) pair, the input directory,
    output directory and registration-info source are looked up through the
    ``mySettings`` helpers and handed to ``perform_registration``.
    """
    preprocess_types = get_preprocessing_type_list()
    experiment_keys = get_experiment_key()
    registration_info_by_key = get_path_dict_by_datatype(data_type='registration_info')

    for preprocess_type in preprocess_types:
        print("====================================preprocess_type={}================================================".format(preprocess_type))
        input_dirs = get_preprocess_normalize_path_dict(path_type='preprocessed_data', preprocess_type=preprocess_type)
        output_dirs = get_preprocess_normalize_path_dict(path_type='registered_images', preprocess_type=preprocess_type)

        for experiment_key in experiment_keys:
            print("============={}=============".format(experiment_key))
            perform_registration(input_dirs[experiment_key],
                                 output_dirs[experiment_key],
                                 registration_info_by_key[experiment_key])
            

In [None]:
# Guarded so that importing the exported .py file does not start a full
# registration run. In a notebook cell __name__ is "__main__", so it still runs.
if __name__ == "__main__":
    main_registration()