In [None]:
from __future__ import print_function
from __future__ import absolute_import

import os
import shutil
import numpy as np
import SimpleITK as sitk

from skimage.filters import gaussian
from skimage.exposure import equalize_adapthist

from matplotlib import pyplot as plt
from IPython import display

%matplotlib inline

# Apply the preprocessing proposed in [1] (subtraction of a Gaussian-smoothed image
# followed by the CLAHE method of [2]) to all data
# [1] Stollenga et al. (2015). Parallel Multi-Dimensional LSTM, With Application to 
#     Fast Biomedical Volumetric Image Segmentation. NIPS 2015.
# [2] Pizer et al. (1987) Adaptive Histogram Equalization and Its Variations. 
#     Comput. Vision Graph. Image Process.

# base data path: adapt it to your local setup; the preprocessed
# volumes are written to its 'prep/' sub-folder
data_path = '../../data/MRBrainS13DataNii/'
out_path = data_path + 'prep/'

In [None]:
def normalise(img):
    """ Normalisation of axial slices according to [1].

    Each slice ``img[i, :, :]`` along the first axis is shifted to zero mean,
    divided by its variance and clipped to [-1, 1].

    Parameters
    ----------
    img : ndarray
        Volume indexed as ``img[i, :, :]``; the first axis is iterated as the
        slice axis. The input is not modified.

    Returns
    -------
    ndarray
        Array with the same shape as ``img``. The dtype matches ``img`` for
        floating-point input and is float64 otherwise, so integer volumes are
        not truncated. Constant slices (zero variance) are returned as zeros
        instead of NaN.
    """
    out_dtype = img.dtype if np.issubdtype(img.dtype, np.floating) else np.float64
    out = np.zeros(img.shape, dtype=out_dtype)

    for i in range(img.shape[0]):
        ax_slice = img[i, :, :]
        centred = ax_slice - np.mean(ax_slice)
        var = np.var(ax_slice)
        # A constant slice (e.g. empty background) would give 0/0 = NaN; leave
        # it at 0, which is what the mean-subtracted slice already is.
        if var != 0:
            # NOTE(review): this divides by the variance, not the standard
            # deviation; confirm this matches the normalisation in [1].
            out[i, :, :] = np.clip(centred / var, -1, 1)

    return out
        

def pre_process(img):
    """ Preprocessing of axial slices according to [1].

    For every slice ``img[i, :, :]`` along the first axis, a Gaussian-smoothed
    copy (sigma=5.0) is subtracted and CLAHE (``equalize_adapthist`` with
    kernel_size=16, clip_limit=2.0) is applied to the difference.

    Parameters
    ----------
    img : ndarray
        Volume indexed as ``img[i, :, :]``. It is NOT modified.

    Returns
    -------
    ndarray
        Array with the same shape as ``img``. The dtype matches ``img`` for
        floating-point input and is float64 otherwise, so the CLAHE output is
        not truncated to integers.
    """
    out_dtype = img.dtype if np.issubdtype(img.dtype, np.floating) else np.float64
    out = np.zeros(img.shape, dtype=out_dtype)

    for i in range(img.shape[0]):
        # Float view/copy of the slice; it is only read from below.
        ax_slice = np.asarray(img[i, :, :], dtype=out_dtype)

        # Subtract a Gaussian smoothed image. Build a new array rather than using
        # `ax_slice -= ...`: the slice is a view into `img`, so an in-place update
        # would overwrite the caller's input volume.
        detail = ax_slice - gaussian(ax_slice, sigma=5.0)

        # Apply CLAHE.
        # NOTE(review): skimage's equalize_adapthist converts its input with
        # img_as_uint, which for float input raises on values outside [-1, 1]
        # and maps negative values to 0. Confirm this is acceptable for the
        # zero-centred `detail` image.
        out[i, :, :] = equalize_adapthist(detail, kernel_size=16, clip_limit=2.0)

    return out

In [None]:
# Visualise all images and preprocessing
def visualise(T1, T1_IR, T2_FLAIR, T1pp, T1_IRpp, T2_FLAIRpp, slice_idx=24):
    """ Show one axial slice of every modality before and after preprocessing.

    The top row shows T1, T1 IR and T2 FLAIR as passed in. The bottom row shows
    their preprocessed (CLAHE) counterparts. The figure replaces the current
    output of the running cell via IPython's display.

    Parameters
    ----------
    T1, T1_IR, T2_FLAIR : ndarray
        Volumes indexed as ``vol[slice_idx, :, :]``.
    T1pp, T1_IRpp, T2_FLAIRpp : ndarray
        Preprocessed volumes with the same indexing.
    slice_idx : int, optional
        Index along the first axis of the slice to display (default 24).
    """
    # (volume, title) for each subplot, laid out as the 2x3 grid
    panels = [
        [(T1, 'T1'), (T1_IR, 'T1 IR'), (T2_FLAIR, 'T2 FLAIR')],
        [(T1pp, 'T1 CLAHE'), (T1_IRpp, 'T1 IR CLAHE'), (T2_FLAIRpp, 'T2 FLAIR CLAHE')],
    ]

    f, axarr = plt.subplots(2, 3, figsize=(16, 8))
    for row, row_panels in enumerate(panels):
        for col, (vol, title) in enumerate(row_panels):
            ax = axarr[row, col]
            ax.imshow(np.squeeze(vol[slice_idx, :, :]), cmap='gray')
            ax.set_title(title)
            ax.axis('off')

    display.clear_output(wait=True)
    display.display(f)
    # Close the figure once shown so figures do not accumulate across calls and
    # the inline backend does not render the last one a second time at cell end.
    plt.close(f)

In [None]:
# iterate through all training and test data
VISUALISE = True     # show a slice of every subject while processing
WRITE_OUTPUT = True  # write the normalised and preprocessed volumes to out_path
MODALITIES = ['T1', 'T1_IR', 'T2_FLAIR']

for subset in ['TrainingData', 'TestData']:
    print(subset)
    subset_path = os.path.join(data_path, subset)
    # subject folders only (stray files such as .DS_Store are skipped),
    # processed in a deterministic order
    subject_ids = sorted(name for name in os.listdir(subset_path)
                         if os.path.isdir(os.path.join(subset_path, name)))

    for subject_id in subject_ids:
        originals = {}  # modality -> SimpleITK image as read from disk
        norm_vols = {}  # modality -> normalised ndarray
        prep_vols = {}  # modality -> fully preprocessed ndarray

        for mod in MODALITIES:
            # read the original file
            originals[mod] = sitk.ReadImage(os.path.join(subset_path, subject_id, mod + '.nii'))

            # cast to float64, see
            # https://github.com/scikit-image/scikit-image/issues/2383
            volume = sitk.GetArrayFromImage(originals[mod]).astype(np.float64)

            # only normalise
            norm_vols[mod] = normalise(volume)
            # fully preprocess
            prep_vols[mod] = pre_process(norm_vols[mod])

        if VISUALISE:
            visualise(norm_vols['T1'], norm_vols['T1_IR'], norm_vols['T2_FLAIR'],
                      prep_vols['T1'], prep_vols['T1_IR'], prep_vols['T2_FLAIR'])

        if WRITE_OUTPUT:
            # create a fresh output folder for this subject (previous results
            # are removed); no shell involved, so paths with spaces are safe
            write_path = os.path.join(out_path, subset, subject_id)
            if os.path.isdir(write_path):
                shutil.rmtree(write_path)
            os.makedirs(write_path)

            for mod in MODALITIES:
                # writes e.g. T1.nii (normalised) and T1pp.nii (preprocessed)
                for suffix, volume in (('', norm_vols[mod]), ('pp', prep_vols[mod])):
                    img_out = sitk.GetImageFromArray(volume)
                    # copy the header information (spacing, origin, direction)
                    # from the original data; GetImageFromArray does not set it
                    img_out.CopyInformation(originals[mod])
                    sitk.WriteImage(img_out, os.path.join(write_path, mod + suffix + '.nii'))
            