In [10]:
#Codes for validating the WMH Challenge training Datasets. The algorithm won the WMH Challenge.
#Codes are written by Mr. Hongwei Li (h.l.li@dundee.ac.uk), Mr. Gongfa Jiang and Miss. Zhaolei Wang from Sun Yat-sen University and University of Dundee.
#

from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
import difflib
import SimpleITK as sitk
import scipy.spatial
from keras.models import Model
from keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D, Cropping2D, ZeroPadding2D
from keras.optimizers import Adam
# from evaluation import getDSC, getHausdorff, getLesionDetection, getAVD, getImages  #please download evaluation.py from the WMH website
from keras.callbacks import ModelCheckpoint
from keras import backend as K
# from show import imshow
from scipy import ndimage
#from sklearn.utils import class_weight
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# import plotly.plotly as py
# import plotly.figure_factory as ff
# import plotly.graph_objs as go 

### ----define loss function for U-net ------------
smooth = 1.

In [33]:
def Utrecht_preprocessing(FLAIR_image, T1_image):

    channel_num = 2
    print(np.shape(FLAIR_image))
    num_selected_slice = np.shape(FLAIR_image)[0]
    image_rows_Dataset = np.shape(FLAIR_image)[1]
    image_cols_Dataset = np.shape(FLAIR_image)[2]
    T1_image = np.float32(T1_image)

    brain_mask_FLAIR = np.ndarray((np.shape(FLAIR_image)[0],image_rows_Dataset, image_cols_Dataset), dtype=np.float32)
    brain_mask_T1 = np.ndarray((np.shape(FLAIR_image)[0],image_rows_Dataset, image_cols_Dataset), dtype=np.float32)
    imgs_two_channels = np.ndarray((num_selected_slice, rows_standard, cols_standard, channel_num), dtype=np.float32)
    imgs_mask_two_channels = np.ndarray((num_selected_slice, rows_standard, cols_standard,1), dtype=np.float32)

    # FLAIR --------------------------------------------
    brain_mask_FLAIR[FLAIR_image >=thresh_FLAIR] = 1
    brain_mask_FLAIR[FLAIR_image < thresh_FLAIR] = 0
    for iii in range(np.shape(FLAIR_image)[0]):
        brain_mask_FLAIR[iii,:,:] = scipy.ndimage.morphology.binary_fill_holes(brain_mask_FLAIR[iii,:,:])  #fill the holes inside brain
    print(int(image_rows_Dataset/2-rows_standard/2),int(image_rows_Dataset/2+rows_standard/2), int(image_cols_Dataset/2-cols_standard/2),int(image_cols_Dataset/2+cols_standard/2))
    print(FLAIR_image.shape)
    
    FLAIR_image = FLAIR_image[:, int(image_rows_Dataset/2-rows_standard/2):int(image_rows_Dataset/2+rows_standard/2), int(image_cols_Dataset/2-cols_standard/2):int(image_cols_Dataset/2+cols_standard/2)]
    brain_mask_FLAIR = brain_mask_FLAIR[:, int(image_rows_Dataset/2-rows_standard/2):int(image_rows_Dataset/2+rows_standard/2), int(image_cols_Dataset/2-cols_standard/2):int(image_cols_Dataset/2+cols_standard/2)]
    ###------Gaussion Normalization here
    FLAIR_image -=np.mean(FLAIR_image[brain_mask_FLAIR == 1])      #Gaussion Normalization
    FLAIR_image /=np.std(FLAIR_image[brain_mask_FLAIR == 1])
    # T1 -----------------------------------------------
    brain_mask_T1[T1_image >=thresh_T1] = 1
    brain_mask_T1[T1_image < thresh_T1] = 0
    for iii in range(np.shape(T1_image)[0]):
        brain_mask_T1[iii,:,:] = scipy.ndimage.morphology.binary_fill_holes(brain_mask_T1[iii,:,:])  #fill the holes inside brain
    T1_image = T1_image[:, int(image_rows_Dataset/2-rows_standard/2):int(image_rows_Dataset/2+rows_standard/2), int(image_cols_Dataset/2-cols_standard/2):int(image_cols_Dataset/2+cols_standard/2)]
    brain_mask_T1 = brain_mask_T1[:, int(image_rows_Dataset/2-rows_standard/2):int(image_rows_Dataset/2+rows_standard/2), int(image_cols_Dataset/2-cols_standard/2):int(image_cols_Dataset/2+cols_standard/2)]
    #------Gaussion Normalization
    T1_image -=np.mean(T1_image[brain_mask_T1 == 1])      
    T1_image /=np.std(T1_image[brain_mask_T1 == 1])
    #---------------------------------------------------
    FLAIR_image  = FLAIR_image[..., np.newaxis]
    T1_image  = T1_image[..., np.newaxis]
    imgs_two_channels = np.concatenate((FLAIR_image, T1_image), axis = 3)
    print(np.shape(imgs_two_channels))
    return imgs_two_channels

In [12]:
def Utrecht_postprocessing(FLAIR_array, pred):
    start_slice = 6
    num_selected_slice = np.shape(FLAIR_array)[0]
    image_rows_Dataset = np.shape(FLAIR_array)[1]
    image_cols_Dataset = np.shape(FLAIR_array)[2]
    original_pred = np.ndarray(np.shape(FLAIR_array), dtype=np.float32)
    original_pred[:,(image_rows_Dataset-rows_standard)/2:(image_rows_Dataset+rows_standard)/2,(image_cols_Dataset-cols_standard)/2:(image_cols_Dataset+cols_standard)/2] = pred[:,:,:,0]
    
    original_pred[0:start_slice, :, :] = 0
    original_pred[(num_selected_slice-start_slice-1):(num_selected_slice-1), :, :] = 0
    return original_pred

In [52]:
def GE3T_preprocessing(FLAIR_image, T1_image):

  #  start_slice = 10
    channel_num = 2
    start_cut = 46
    print(np.shape(FLAIR_image))
    num_selected_slice = np.shape(FLAIR_image)[0]
    image_rows_Dataset = np.shape(FLAIR_image)[1]
    image_cols_Dataset = np.shape(FLAIR_image)[2]
    FLAIR_image = np.float32(FLAIR_image)
    T1_image = np.float32(T1_image)

    brain_mask_FLAIR = np.ndarray((np.shape(FLAIR_image)[0],image_rows_Dataset, image_cols_Dataset), dtype=np.float32)
    brain_mask_T1 = np.ndarray((np.shape(FLAIR_image)[0],image_rows_Dataset, image_cols_Dataset), dtype=np.float32)
    imgs_two_channels = np.ndarray((num_selected_slice, rows_standard, cols_standard, channel_num), dtype=np.float32)
    imgs_mask_two_channels = np.ndarray((num_selected_slice, rows_standard, cols_standard,1), dtype=np.float32)
    FLAIR_image_suitable = np.ndarray((num_selected_slice, rows_standard, cols_standard), dtype=np.float32)
    T1_image_suitable = np.ndarray((num_selected_slice, rows_standard, cols_standard), dtype=np.float32)

    # FLAIR --------------------------------------------
    brain_mask_FLAIR[FLAIR_image >=thresh_FLAIR] = 1
    brain_mask_FLAIR[FLAIR_image < thresh_FLAIR] = 0
    for iii in range(np.shape(FLAIR_image)[0]):
  
        brain_mask_FLAIR[iii,:,:] = scipy.ndimage.morphology.binary_fill_holes(brain_mask_FLAIR[iii,:,:])  #fill the holes inside brain
        #------Gaussion Normalization
    FLAIR_image -=np.mean(FLAIR_image[brain_mask_FLAIR == 1])      #Gaussion Normalization
    FLAIR_image /=np.std(FLAIR_image[brain_mask_FLAIR == 1])

    FLAIR_image_suitable[...] = np.min(FLAIR_image)
    FLAIR_image_suitable[:, :, int(cols_standard/2-image_cols_Dataset/2):int(cols_standard/2+image_cols_Dataset/2)] = FLAIR_image[:, start_cut:start_cut+rows_standard, :]
   
    # T1 -----------------------------------------------
    brain_mask_T1[T1_image >=thresh_T1] = 1
    brain_mask_T1[T1_image < thresh_T1] = 0
    for iii in range(np.shape(T1_image)[0]):
 
        brain_mask_T1[iii,:,:] = scipy.ndimage.morphology.binary_fill_holes(brain_mask_T1[iii,:,:])  #fill the holes inside brain
        #------Gaussion Normalization
    T1_image -=np.mean(T1_image[brain_mask_T1 == 1])      #Gaussion Normalization
    T1_image /=np.std(T1_image[brain_mask_T1 == 1])

    T1_image_suitable[...] = np.min(T1_image)
    T1_image_suitable[:, :, int((cols_standard-image_cols_Dataset)/2):int((cols_standard+image_cols_Dataset)/2)] = T1_image[:, start_cut:start_cut+rows_standard, :]
    #---------------------------------------------------
    FLAIR_image_suitable  = FLAIR_image_suitable[..., np.newaxis]
    T1_image_suitable  = T1_image_suitable[..., np.newaxis]
    
    imgs_two_channels = np.concatenate((FLAIR_image_suitable, T1_image_suitable), axis = 3)
    print(np.shape(imgs_two_channels))
    return imgs_two_channels

In [53]:
int(3/2)


1

In [54]:
def GE3T_postprocessing(FLAIR_array, pred):
    start_slice = 11
    start_cut = 46
    num_selected_slice = np.shape(FLAIR_array)[0]
    image_rows_Dataset = np.shape(FLAIR_array)[1]
    image_cols_Dataset = np.shape(FLAIR_array)[2]
    original_pred = np.ndarray(np.shape(FLAIR_array), dtype=np.float32)
    original_pred[:, start_cut:start_cut+rows_standard,:] = pred[:,:, (rows_standard-image_cols_Dataset)/2:(rows_standard+image_cols_Dataset)/2,0]

    original_pred[0:start_slice, :, :] = 0
    original_pred[(num_selected_slice-start_slice-1):(num_selected_slice-1), :, :] = 0
    return original_pred

In [55]:
###---Here comes the main funtion--------------------------------------------
###---Leave one patient out validation--------------------------------------------

patient_num =45
patient_count = 0
rows_standard = 200
cols_standard = 200
thresh_FLAIR = 70      #to mask the brain
thresh_T1 = 30
para_array = [[0.958, 0.958, 3], [1.00, 1.00, 3], [1.20, 0.977, 3]]    # parameters of the scanner
para_array = np.array(para_array, dtype=np.float32)


#read the dirs of test data 
input_dir_1 = '../data/train/Utrecht'
input_dir_2 = '../data/train/Singapore'
input_dir_3 = '../data/train/Amsterdam'
###---dir to save results---------
outputDir = 'evaluation_result_LOOV'
#-------------------------------------------
dirs = os.listdir(input_dir_1) + os.listdir(input_dir_2) + os.listdir(input_dir_3)
# #All the slices and the corresponding patients id
# imgs_three_datasets_two_channels = np.load('imgs_three_datasets_two_channels.npy')
# imgs_mask_three_datasets_two_channels = np.load('imgs_mask_three_datasets_two_channels.npy')
# slices_patient_id_label = np.load('slices_patient_id_label.npy')


for dir_name in dirs:
    print('dir_name is:')
    print(dir_name)
    
    if patient_count < 15:
        inputDir = input_dir_1
    elif patient_count > 14 and patient_count < 30:
        inputDir = input_dir_2
    elif patient_count >= 30:
        inputDir = input_dir_3
    FLAIR_image = sitk.ReadImage(os.path.join(inputDir, dir_name, 'pre', 'FLAIR.nii.gz'))
    T1_image = sitk.ReadImage(os.path.join(inputDir, dir_name, 'pre', 'T1.nii.gz'))
    
    FLAIR_array = sitk.GetArrayFromImage(FLAIR_image)
    T1_array = sitk.GetArrayFromImage(T1_image)
    #Proccess testing data-----
    para_FLAIR = np.ndarray((1,3), dtype=np.float32)
    para_FLAIR_ = FLAIR_image.GetSpacing()
    para_FLAIR[0,0] = round(para_FLAIR_[0],3)   # get spacing parameters of the data
    para_FLAIR[0,1] = round(para_FLAIR_[1],3)  
    para_FLAIR[0,2] = round(para_FLAIR_[2],3) 
    if np.array_equal(para_FLAIR[0], para_array[0]) :
        print('From Utrecht!')
        imgs_test = Utrecht_preprocessing(FLAIR_array, T1_array)
    elif np.array_equal(para_FLAIR[0], para_array[1]):
        print('From Singapore!')
        imgs_test = Utrecht_preprocessing(FLAIR_array, T1_array)
    elif np.array_equal(para_FLAIR[0], para_array[2]):
        print('From Amsterdam!')
        imgs_test = GE3T_preprocessing(FLAIR_array, T1_array)
    print(imgs_test.shape)
    patient_count+=1
#     ###---train u-net models-------------------------------------------------------------------------------
#     training_index = slices_patient_id_label != patient_count
#     test_index = slices_patient_id_label == patient_count
#     dim_training = sum(training_index)
#     dim_test = sum(test_index)
#     print('the dim of training set:')
#     print(dim_training[0])
#     imgs_train = np.ndarray((dim_training[0], rows_standard, cols_standard, 2), dtype=np.float32)
#     imgs_test_selected = np.ndarray((dim_test[0], rows_standard, cols_standard, 2), dtype=np.float32)

#     imgs_mask_train = np.ndarray((dim_training[0], rows_standard, cols_standard, 1), dtype=np.float32)
#     imgs_mask_test_selected = np.ndarray((dim_test[0], rows_standard, cols_standard, 1), dtype=np.float32)
#     count_index_train = 0
#     count_index_test = 0
#     for iii in range(training_index.shape[0]):
#         if training_index[iii] == 1:
#             imgs_train[count_index_train, ...] = imgs_three_datasets_two_channels[iii, ...]
#             imgs_mask_train[count_index_train, ...] = imgs_mask_three_datasets_two_channels[iii, ...]
#             count_index_train = count_index_train + 1
#         if training_index[iii] == 0:
#             imgs_test_selected[count_index_test, ...] = imgs_three_datasets_two_channels[iii, ...]
#             imgs_mask_test_selected[count_index_test, ...] = imgs_mask_three_datasets_two_channels[iii, ...]
#             count_index_test = count_index_test + 1

#     print('training dataset dimension:')
#     print(imgs_train.shape[0])
#     img_shape=(rows_standard, cols_standard, 2)


#     print('-'*30)
#     print('Fitting model...')
#     print('-'*30)

dir_name is:
11
From Utrecht!
(48, 240, 240)
20 220 20 220
(48, 240, 240)
(48, 200, 200, 2)
(48, 200, 200, 2)
dir_name is:
17
From Utrecht!
(48, 240, 240)
20 220 20 220
(48, 240, 240)
(48, 200, 200, 2)
(48, 200, 200, 2)
dir_name is:
19
From Utrecht!
(48, 240, 240)
20 220 20 220
(48, 240, 240)
(48, 200, 200, 2)
(48, 200, 200, 2)
dir_name is:
2
From Utrecht!
(48, 240, 240)
20 220 20 220
(48, 240, 240)
(48, 200, 200, 2)
(48, 200, 200, 2)
dir_name is:
25
From Utrecht!
(48, 240, 240)
20 220 20 220
(48, 240, 240)
(48, 200, 200, 2)
(48, 200, 200, 2)
dir_name is:
27
From Utrecht!
(48, 240, 240)
20 220 20 220
(48, 240, 240)
(48, 200, 200, 2)
(48, 200, 200, 2)
dir_name is:
29
From Utrecht!
(48, 240, 240)
20 220 20 220
(48, 240, 240)
(48, 200, 200, 2)
(48, 200, 200, 2)
dir_name is:
33
From Utrecht!
(48, 240, 240)
20 220 20 220
(48, 240, 240)
(48, 200, 200, 2)
(48, 200, 200, 2)
dir_name is:
37
From Utrecht!
(48, 240, 240)
20 220 20 220
(48, 240, 240)
(48, 200, 200, 2)
(48, 200, 200, 2)
dir_name is