# Preprocess MR & US data

 1. Read all MR-US image pairs
 1. Register MR rigidly (only translation) to US using the segmentation masks
 1. Preprocess them: intensity normalization, same size, etc.
 1. Save them as: `ID_MR_img.nrrd`, `ID_MR_msk.nrrd`, `ID_US_img.nrrd`, `ID_US_msk.nrrd`, where ID is the ID of the patient   

### Load required libraries

In [1]:
%load_ext autoreload
%autoreload 2

#Main imports
import os, glob, sys, pickle, numpy as np, pandas as pd,\
       matplotlib.pyplot as plt, SimpleITK as sitk

#Preprocessing functions
from lib.processing.preprocessing import (rescale_intensity, make_dirs, resample_image, 
                                          get_minimum_slices_transform, detect_border)

#Plot lib
from pathlib import Path
sys.path.append(os.path.join(Path.home(), 'plot_lib'))
from plot_lib import plot

### Set up some configuration

In [2]:
#Global configuration
# Plot every patient (patients with DSC < 0.8 are plotted regardless of this flag)
PLOT = True
# Output size handed to resample_image, one entry per axis
SIZE = [160, 160, 160]
# Output spacing handed to resample_image, one entry per axis
SPACING = [0.5, 0.5, 0.5]
# Apply morphological opening + closing to the final masks
CLEAN = True
# Skip a pid when its processed MR image is already present in PATH_OUT
CONTINUE_IF_EXISTS = False

#Main paths
PATH_IN = 'registration_data/images'         # raw input volumes
PATH_OUT = 'registration_data/preprocessed'  # destination of the processed volumes
# Presumably creates the output folder if it is missing (see lib.processing.preprocessing)
make_dirs(PATH_OUT)

#Simple function for custom-saving images
def save_image(img, name, path=PATH_OUT, normalize=True, overwrite=True):
    """Write `img` to `<path>/<pid>_<name>.nrrd`, optionally intensity-normalized.

    NOTE: `pid` is not a parameter. It is read from the enclosing (notebook
    global) scope, so it must be defined before this function is called.

    Parameters
    ----------
    img : SimpleITK image to be written.
    name : str, suffix appended to the patient ID in the file name.
    path : str, output directory (defaults to PATH_OUT).
    normalize : bool, if True the voxel array is cast to float32 and passed
        through `rescale_intensity` with thres=(.9, 99.9); the original image
        information is then copied onto the rescaled image.
    overwrite : bool, if False an already existing file is left untouched.

    Returns
    -------
    None
    """
    full_path= os.path.join(path, pid + '_' + name + '.nrrd')

    #Decide first whether anything will be written, so that the normalization
    #is not computed only to be discarded
    if os.path.exists(full_path) and not overwrite:
        return

    #Normalize
    if normalize:
        img_arr= rescale_intensity(sitk.GetArrayFromImage(img).astype(np.float32), thres=(.9, 99.9))
        normalized= sitk.GetImageFromArray(img_arr)
        #GetImageFromArray does not carry over the source image information
        normalized.CopyInformation(img)
        img= normalized

    #Save (third positional argument is SimpleITK's useCompression flag)
    sitk.WriteImage(img, full_path, True)

### Process the images

In [3]:
for pid in ['ID0001', 'ID0002']:
    print(pid)

    #Optionally skip patients whose processed MR image is already on disk
    if CONTINUE_IF_EXISTS:
        MR_PATH = os.path.join(PATH_OUT, pid + '_MR_img.nrrd')
        if os.path.exists(MR_PATH):
            print(f"{MR_PATH} found. Continuing")
            continue

    #Read the MR volumes (image, mask, HD mask) and the US volumes (image, mask)
    mr_img, mr_msk, mr_msk_HD = [
        sitk.ReadImage(os.path.join(PATH_IN, pid + suffix))
        for suffix in ('_MR_img.nrrd', '_MR_msk.nrrd', '_MR_msk_HD.nrrd')
    ]
    us_img, us_msk = [
        sitk.ReadImage(os.path.join(PATH_IN, pid + suffix))
        for suffix in ('_US_img.nrrd', '_US_msk.nrrd')
    ]

    #Resample both modalities to SIZE / SPACING in 'roi' mode, driven by a mask
    #NOTE(review): the MR crop uses the HD mask, and the mask returned here
    #replaces the `mr_msk` that was just read from '_MR_msk.nrrd'
    mr_img, mr_msk = resample_image(mr_img, SIZE, SPACING, mode='roi', mask=mr_msk_HD)
    us_img, us_msk = resample_image(us_img, SIZE, SPACING, mode='roi', mask=us_msk)

    #Sanity check 1: Dice overlap between the two resampled masks
    #(computed before the optional cleaning below). A very low value hints at
    #a mismatch, e.g. MR and US not belonging to the same patient
    mr_msk_arr = sitk.GetArrayFromImage(mr_msk).astype(np.uint8)
    us_msk_arr = sitk.GetArrayFromImage(us_msk).astype(np.uint8)
    msk_overlap = (mr_msk_arr * us_msk_arr).sum()
    msk_total = (mr_msk_arr.sum() + us_msk_arr.sum()).astype(np.float32)
    dsc = 2*msk_overlap / msk_total
    print(' - DSC: %.4f'%(dsc))

    #Sanity check 2: combined mask close to the volume border
    #(e.g. the US prostate might not have been completely captured).
    #Detections on the z-axis are not reported
    detections = detect_border(us_msk_arr + mr_msk_arr, threshold=1)
    border_names = ['z-axis up', 'y-axis up', 'x-axis up',
                    'z-axis down', 'y-axis down', 'x-axis down']
    for border, detection in zip(border_names, detections):
        if not detection or 'z' in border:
            continue
        print(' > Warning: Mask very near to the %s border'%border)

    #Optional clean-up: opening followed by closing, same kernel for both masks
    if CLEAN:
        clean_params = ([7]*3, sitk.sitkBall)
        mr_msk = sitk.BinaryMorphologicalClosing(
            sitk.BinaryMorphologicalOpening(mr_msk, *clean_params), *clean_params)
        us_msk = sitk.BinaryMorphologicalClosing(
            sitk.BinaryMorphologicalOpening(us_msk, *clean_params), *clean_params)

    #Show each modality with both masks overlaid; forced when the DSC is low
    if PLOT or dsc < 0.8:
        for plot_img, plot_tag in ((mr_img, ' (MR)'), (us_img, ' (US)')):
            plot(plot_img, masks=[mr_msk, us_msk], title=pid + plot_tag)

    #Write the four outputs; only the intensity images are normalized
    for out_img, out_name, out_normalize in ((mr_img, 'MR_img', True),
                                             (mr_msk, 'MR_msk', False),
                                             (us_img, 'US_img', True),
                                             (us_msk, 'US_msk', False)):
        save_image(out_img, out_name, normalize=out_normalize)

ID0001
 - DSC: 0.8556


interactive(children=(IntSlider(value=80, description='z', max=159, style=SliderStyle(handle_color='lightblue'…

interactive(children=(IntSlider(value=80, description='z', max=159, style=SliderStyle(handle_color='lightblue'…

ID0002
 - DSC: 0.7925


interactive(children=(IntSlider(value=80, description='z', max=159, style=SliderStyle(handle_color='lightblue'…

interactive(children=(IntSlider(value=80, description='z', max=159, style=SliderStyle(handle_color='lightblue'…