In [73]:
#Codes for validating the WMH Challenge training Datasets. The algorithm won the WMH Challenge.
#Codes are written by Mr. Hongwei Li (h.l.li@dundee.ac.uk), Mr. Gongfa Jiang and Miss. Zhaolei Wang from Sun Yat-sen University and University of Dundee.
#

from __future__ import print_function
import os
import numpy as np
from random import shuffle
import tensorflow as tf
import difflib
import SimpleITK as sitk
import scipy.spatial
from keras.models import Model
from keras.layers import Input, merge, Convolution2D, MaxPooling2D, UpSampling2D, Cropping2D, ZeroPadding2D
from keras.optimizers import Adam
# from evaluation import getDSC, getHausdorff, getLesionDetection, getAVD, getImages  #please download evaluation.py from the WMH website
from keras.callbacks import ModelCheckpoint
from keras import backend as K
# from show import imshow
from scipy import ndimage
#from sklearn.utils import class_weight
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
# import plotly.plotly as py
# import plotly.figure_factory as ff
# import plotly.graph_objs as go 

### ----define loss function for U-net ------------
smooth = 1.

In [42]:
def Utrecht_preprocessing(FLAIR_image, T1_image,labelArray):

    channel_num = 2
    print(np.shape(FLAIR_image))
    num_selected_slice = np.shape(FLAIR_image)[0]
    image_rows_Dataset = np.shape(FLAIR_image)[1]
    image_cols_Dataset = np.shape(FLAIR_image)[2]
    T1_image = np.float32(T1_image)

    brain_mask_FLAIR = np.ndarray((np.shape(FLAIR_image)[0],image_rows_Dataset, image_cols_Dataset), dtype=np.float32)
    brain_mask_T1 = np.ndarray((np.shape(FLAIR_image)[0],image_rows_Dataset, image_cols_Dataset), dtype=np.float32)
    imgs_two_channels = np.ndarray((num_selected_slice, rows_standard, cols_standard, channel_num), dtype=np.float32)
    brain_label = np.ndarray((np.shape(FLAIR_image)[0],image_rows_Dataset, image_cols_Dataset), dtype=np.float32)
    imgs_mask_two_channels = np.ndarray((num_selected_slice, rows_standard, cols_standard,1), dtype=np.float32)

    # FLAIR --------------------------------------------
    brain_mask_FLAIR[FLAIR_image >=thresh_FLAIR] = 1
    brain_mask_FLAIR[FLAIR_image < thresh_FLAIR] = 0
    for iii in range(np.shape(FLAIR_image)[0]):
        brain_mask_FLAIR[iii,:,:] = scipy.ndimage.morphology.binary_fill_holes(brain_mask_FLAIR[iii,:,:])  #fill the holes inside brain
    print(int(image_rows_Dataset/2-rows_standard/2),int(image_rows_Dataset/2+rows_standard/2), int(image_cols_Dataset/2-cols_standard/2),int(image_cols_Dataset/2+cols_standard/2))
    print(FLAIR_image.shape)
    
    FLAIR_image = FLAIR_image[:, int(image_rows_Dataset/2-rows_standard/2):int(image_rows_Dataset/2+rows_standard/2), int(image_cols_Dataset/2-cols_standard/2):int(image_cols_Dataset/2+cols_standard/2)]
    brain_mask_FLAIR = brain_mask_FLAIR[:, int(image_rows_Dataset/2-rows_standard/2):int(image_rows_Dataset/2+rows_standard/2), int(image_cols_Dataset/2-cols_standard/2):int(image_cols_Dataset/2+cols_standard/2)]
    ###------Gaussion Normalization here
    FLAIR_image -=np.mean(FLAIR_image[brain_mask_FLAIR == 1])      #Gaussion Normalization
    FLAIR_image /=np.std(FLAIR_image[brain_mask_FLAIR == 1])
    # T1 -----------------------------------------------
    brain_mask_T1[T1_image >=thresh_T1] = 1
    brain_mask_T1[T1_image < thresh_T1] = 0
    for iii in range(np.shape(T1_image)[0]):
        brain_mask_T1[iii,:,:] = scipy.ndimage.morphology.binary_fill_holes(brain_mask_T1[iii,:,:])  #fill the holes inside brain
    T1_image = T1_image[:, int(image_rows_Dataset/2-rows_standard/2):int(image_rows_Dataset/2+rows_standard/2), int(image_cols_Dataset/2-cols_standard/2):int(image_cols_Dataset/2+cols_standard/2)]
    brain_mask_T1 = brain_mask_T1[:, int(image_rows_Dataset/2-rows_standard/2):int(image_rows_Dataset/2+rows_standard/2), int(image_cols_Dataset/2-cols_standard/2):int(image_cols_Dataset/2+cols_standard/2)]
    #------Gaussion Normalization
    T1_image -=np.mean(T1_image[brain_mask_T1 == 1])      
    T1_image /=np.std(T1_image[brain_mask_T1 == 1])
    # lable----------------
    brain_label[labelArray == 1] = 1
    brain_label[labelArray != 1] = 0
    imgs_mask_two_channels = brain_label[:, int(image_rows_Dataset/2-rows_standard/2):int(image_rows_Dataset/2+rows_standard/2), int(image_cols_Dataset/2-cols_standard/2):int(image_cols_Dataset/2+cols_standard/2)]
    #---------------------------------------------------
    FLAIR_image  = FLAIR_image[..., np.newaxis]
    T1_image  = T1_image[..., np.newaxis]
    imgs_two_channels = np.concatenate((FLAIR_image, T1_image), axis = 3)
    print(np.shape(imgs_mask_two_channels))
    print(np.shape(imgs_two_channels))
    maskArray = imgs_mask_two_channels > 0

    return imgs_two_channels,imgs_mask_two_channels, maskArray

In [25]:
def Utrecht_postprocessing(FLAIR_array, pred):
    start_slice = 6
    num_selected_slice = np.shape(FLAIR_array)[0]
    image_rows_Dataset = np.shape(FLAIR_array)[1]
    image_cols_Dataset = np.shape(FLAIR_array)[2]
    original_pred = np.ndarray(np.shape(FLAIR_array), dtype=np.float32)
    original_pred[:,(image_rows_Dataset-rows_standard)/2:(image_rows_Dataset+rows_standard)/2,(image_cols_Dataset-cols_standard)/2:(image_cols_Dataset+cols_standard)/2] = pred[:,:,:,0]
    
    original_pred[0:start_slice, :, :] = 0
    original_pred[(num_selected_slice-start_slice-1):(num_selected_slice-1), :, :] = 0
    return original_pred

In [44]:
def GE3T_preprocessing(FLAIR_image, T1_image,labelArray):

  #  start_slice = 10
    channel_num = 2
    start_cut = 46
    print(np.shape(FLAIR_image))
    num_selected_slice = np.shape(FLAIR_image)[0]
    image_rows_Dataset = np.shape(FLAIR_image)[1]
    image_cols_Dataset = np.shape(FLAIR_image)[2]
    FLAIR_image = np.float32(FLAIR_image)
    T1_image = np.float32(T1_image)

    brain_mask_FLAIR = np.ndarray((np.shape(FLAIR_image)[0],image_rows_Dataset, image_cols_Dataset), dtype=np.float32)
    brain_mask_T1 = np.ndarray((np.shape(FLAIR_image)[0],image_rows_Dataset, image_cols_Dataset), dtype=np.float32)
    imgs_two_channels = np.ndarray((num_selected_slice, rows_standard, cols_standard, channel_num), dtype=np.float32)
    
    FLAIR_image_suitable = np.ndarray((num_selected_slice, rows_standard, cols_standard), dtype=np.float32)
    T1_image_suitable = np.ndarray((num_selected_slice, rows_standard, cols_standard), dtype=np.float32)
    brain_label = np.ndarray((np.shape(FLAIR_image)[0],image_rows_Dataset, image_cols_Dataset), dtype=np.float32)
    imgs_mask_two_channels = np.ndarray((num_selected_slice, rows_standard, cols_standard), dtype=np.float32)

    # FLAIR --------------------------------------------
    brain_mask_FLAIR[FLAIR_image >=thresh_FLAIR] = 1
    brain_mask_FLAIR[FLAIR_image < thresh_FLAIR] = 0
    for iii in range(np.shape(FLAIR_image)[0]):
  
        brain_mask_FLAIR[iii,:,:] = scipy.ndimage.morphology.binary_fill_holes(brain_mask_FLAIR[iii,:,:])  #fill the holes inside brain
        #------Gaussion Normalization
    FLAIR_image -=np.mean(FLAIR_image[brain_mask_FLAIR == 1])      #Gaussion Normalization
    FLAIR_image /=np.std(FLAIR_image[brain_mask_FLAIR == 1])

    FLAIR_image_suitable[...] = np.min(FLAIR_image)
    FLAIR_image_suitable[:, :, int(cols_standard/2-image_cols_Dataset/2):int(cols_standard/2+image_cols_Dataset/2)] = FLAIR_image[:, start_cut:start_cut+rows_standard, :]
   
    # T1 -----------------------------------------------
    brain_mask_T1[T1_image >=thresh_T1] = 1
    brain_mask_T1[T1_image < thresh_T1] = 0
    for iii in range(np.shape(T1_image)[0]):
 
        brain_mask_T1[iii,:,:] = scipy.ndimage.morphology.binary_fill_holes(brain_mask_T1[iii,:,:])  #fill the holes inside brain
        #------Gaussion Normalization
    T1_image -=np.mean(T1_image[brain_mask_T1 == 1])      #Gaussion Normalization
    T1_image /=np.std(T1_image[brain_mask_T1 == 1])

    T1_image_suitable[...] = np.min(T1_image)
    T1_image_suitable[:, :, int((cols_standard-image_cols_Dataset)/2):int((cols_standard+image_cols_Dataset)/2)] = T1_image[:, start_cut:start_cut+rows_standard, :]
    # lable----------------
    brain_label[labelArray == 1] = 1
    brain_label[labelArray != 1] = 0
    imgs_mask_two_channels[:, :, int((cols_standard-image_cols_Dataset)/2):int((cols_standard+image_cols_Dataset)/2)] = brain_label[:, start_cut:start_cut+rows_standard, :]
    
    #---------------------------------------------------
    FLAIR_image_suitable  = FLAIR_image_suitable[..., np.newaxis]
    T1_image_suitable  = T1_image_suitable[..., np.newaxis]
    
    imgs_two_channels = np.concatenate((FLAIR_image_suitable, T1_image_suitable), axis = 3)
    maskArray = imgs_mask_two_channels > 0
    print(np.shape(imgs_two_channels))
    return imgs_two_channels,imgs_mask_two_channels,maskArray

In [27]:
int(3/2)


1

In [28]:
def GE3T_postprocessing(FLAIR_array, pred):
    start_slice = 11
    start_cut = 46
    num_selected_slice = np.shape(FLAIR_array)[0]
    image_rows_Dataset = np.shape(FLAIR_array)[1]
    image_cols_Dataset = np.shape(FLAIR_array)[2]
    original_pred = np.ndarray(np.shape(FLAIR_array), dtype=np.float32)
    original_pred[:, start_cut:start_cut+rows_standard,:] = pred[:,:, (rows_standard-image_cols_Dataset)/2:(rows_standard+image_cols_Dataset)/2,0]

    original_pred[0:start_slice, :, :] = 0
    original_pred[(num_selected_slice-start_slice-1):(num_selected_slice-1), :, :] = 0
    return original_pred

In [69]:
###---Here comes the main funtion--------------------------------------------
###---Leave one patient out validation--------------------------------------------

patient_num =45
patient_count = 0
rows_standard = 200
cols_standard = 200
thresh_FLAIR = 70      #to mask the brain
thresh_T1 = 30
para_array = [[0.958, 0.958, 3], [1.00, 1.00, 3], [1.20, 0.977, 3]]    # parameters of the scanner
para_array = np.array(para_array, dtype=np.float32)

images = None # shape: (numImages, z, y, x, channels=1)
labels = None
masks  = None

validationimages = []
validationlables = []

#read the dirs of test data 
input_dir_1 = '../data/validation/Utrecht'
input_dir_2 = '../data/validation/Singapore'
input_dir_3 = '../data/validation/Amsterdam'
###---dir to save results---------
outputDir = 'evaluation_result_LOOV'
  
#-------------------------------------------   
dirs = os.listdir(input_dir_1) + os.listdir(input_dir_2) + os.listdir(input_dir_3)
# #All the slices and the corresponding patients id
# imgs_three_datasets_two_channels = np.load('imgs_three_datasets_two_channels.npy')
# imgs_mask_three_datasets_two_channels = np.load('imgs_mask_three_datasets_two_channels.npy')
# slices_patient_id_label = np.load('slices_patient_id_label.npy')
dirs = [f for f in dirs if not f.startswith('.')]
print(dirs)
for dir_name in dirs:
    print('dir_name is:')
    print(dir_name)
    
    if patient_count < 4:
        inputDir = input_dir_1
    elif patient_count > 3 and patient_count < 9:
        inputDir = input_dir_2
    elif patient_count >= 9:
        inputDir = input_dir_3
    FLAIR_image = sitk.ReadImage(os.path.join(inputDir, dir_name, 'pre', 'FLAIR.nii.gz'))
    T1_image = sitk.ReadImage(os.path.join(inputDir, dir_name, 'pre', 'T1.nii.gz'))
    label_image= sitk.ReadImage(os.path.join(inputDir, dir_name, "wmh.nii.gz"))
    FLAIR_array = sitk.GetArrayFromImage(FLAIR_image)
    T1_array = sitk.GetArrayFromImage(T1_image)
    labelArray = sitk.GetArrayFromImage(label_image)
    
    #Proccess testing data-----
    para_FLAIR = np.ndarray((1,3), dtype=np.float32)
    para_FLAIR_ = FLAIR_image.GetSpacing()
    para_FLAIR[0,0] = round(para_FLAIR_[0],3)   # get spacing parameters of the data
    para_FLAIR[0,1] = round(para_FLAIR_[1],3)  
    para_FLAIR[0,2] = round(para_FLAIR_[2],3) 
    if np.array_equal(para_FLAIR[0], para_array[0]) :
        print('From Utrecht!')
        imgs_test,label,maskArray = Utrecht_preprocessing(FLAIR_array, T1_array, labelArray)
    elif np.array_equal(para_FLAIR[0], para_array[1]):
        print('From Singapore!')
        imgs_test,label,maskArray  = Utrecht_preprocessing(FLAIR_array, T1_array, labelArray)
    elif np.array_equal(para_FLAIR[0], para_array[2]):
        print('From Amsterdam!')
        imgs_test,label,maskArray  = GE3T_preprocessing(FLAIR_array, T1_array, labelArray)
    print(imgs_test.shape,label.shape,maskArray.shape)
    patient_count+=1
    
           
    # Add to the images/labels array
    images = imgs_test.reshape([1] + list(imgs_test.shape) )
    labels = label.reshape([1] + list(label.shape) + [1])
    masks  = maskArray.reshape([1] + list(maskArray.shape) + [1])
    print("hehe",images.shape, labels.shape, masks.shape,images.max())
    for i in range(masks.shape[0]):
        for j in range(masks.shape[1]):
            if not np.all(masks[i,j,:,:,0]== False):
                validationlables.append(labels[i,j,:,:,:])
#                 print("max:",images[i,j,:,:,:].max())
                validationimages.append(images[i,j,:,:,:])

      
    print("t:", np.asarray(validationlables).shape,np.asarray(validationimages).shape,np.asarray(validationimages).max() )
print(images.shape, labels.shape, masks.shape)   
print(np.asarray(validationlables).shape,np.asarray(validationimages).shape)
#     ###---train u-net models-------------------------------------------------------------------------------
#     training_index = slices_patient_id_label != patient_count
#     test_index = slices_patient_id_label == patient_count
#     dim_training = sum(training_index)
#     dim_test = sum(test_index)
#     print('the dim of training set:')
#     print(dim_training[0])
#     imgs_train = np.ndarray((dim_training[0], rows_standard, cols_standard, 2), dtype=np.float32)
#     imgs_test_selected = np.ndarray((dim_test[0], rows_standard, cols_standard, 2), dtype=np.float32)

#     imgs_mask_train = np.ndarray((dim_training[0], rows_standard, cols_standard, 1), dtype=np.float32)
#     imgs_mask_test_selected = np.ndarray((dim_test[0], rows_standard, cols_standard, 1), dtype=np.float32)
#     count_index_train = 0
#     count_index_test = 0
#     for iii in range(training_index.shape[0]):
#         if training_index[iii] == 1:
#             imgs_train[count_index_train, ...] = imgs_three_datasets_two_channels[iii, ...]
#             imgs_mask_train[count_index_train, ...] = imgs_mask_three_datasets_two_channels[iii, ...]
#             count_index_train = count_index_train + 1
#         if training_index[iii] == 0:
#             imgs_test_selected[count_index_test, ...] = imgs_three_datasets_two_channels[iii, ...]
#             imgs_mask_test_selected[count_index_test, ...] = imgs_mask_three_datasets_two_channels[iii, ...]
#             count_index_test = count_index_test + 1

#     print('training dataset dimension:')
#     print(imgs_train.shape[0])
#     img_shape=(rows_standard, cols_standard, 2)


#     print('-'*30)
#     print('Fitting model...')
#     print('-'*30)

['21', '23', '31', '35', '50', '55', '56', '60', '62', '100', '105', '106', '110', '113']
dir_name is:
21
From Utrecht!
(48, 240, 240)
20 220 20 220
(48, 240, 240)
(48, 200, 200)
(48, 200, 200, 2)
(48, 200, 200, 2) (48, 200, 200) (48, 200, 200)
hehe (1, 48, 200, 200, 2) (1, 48, 200, 200, 1) (1, 48, 200, 200, 1) 5.7240167
max: 4.362728
max: 5.1427183
max: 4.776386
max: 4.3169985
max: 4.740326
max: 4.793553
max: 5.0530953
max: 4.8989935
max: 4.3028836
max: 4.4358244
max: 4.3551927
max: 4.3547955
max: 4.3620734
max: 4.6707845
max: 4.6877866
max: 4.1151795
max: 4.7104096
max: 4.3052025
max: 4.801093
max: 4.496967
max: 4.9930186
max: 4.4287314
max: 4.4444366
max: 4.2412643
max: 4.12495
max: 3.9727085
max: 3.7560663
max: 3.6715212
max: 3.9394782
max: 3.9315321
max: 3.4742377
max: 3.4050465
max: 4.097337
max: 3.7792242
t: (34, 200, 200, 1) (34, 200, 200, 2) 5.1427183
dir_name is:
23
From Utrecht!
(48, 240, 240)
20 220 20 220
(48, 240, 240)
(48, 200, 200)
(48, 200, 200, 2)
(48, 200, 200, 2) (4

t: (324, 200, 200, 1) (324, 200, 200, 2) 7.7884636
(1, 83, 200, 200, 2) (1, 83, 200, 200, 1) (1, 83, 200, 200, 1)
(324, 200, 200, 1) (324, 200, 200, 2)


In [57]:
np.asarray(trainlables).shape

(1033, 200, 200, 1)

In [58]:
np.save('trainlables.npy', np.asarray(trainlables))    
np.save('trainimages.npy', np.asarray(trainimages)) 


In [64]:
np.asarray(validationimages).max()

nan

In [70]:
np.save('validationlables.npy', np.asarray(validationlables))    
np.save('validationimages.npy', np.asarray(validationimages)) 

### U net


In [91]:
def myshow(img, title=None, margin=0.05, dpi=100):
    nda = sitk.GetArrayViewFromImage(img)
    spacing = img.GetSpacing()
        
    if nda.ndim == 3:
        # fastest dim, either component or x
        c = nda.shape[-1]
        
        # the the number of components is 3 or 4 consider it an RGB image
        if not c in (3,4):
            nda = nda[nda.shape[0]//2,:,:]
    
    elif nda.ndim == 4:
        c = nda.shape[-1]
        
        if not c in (3,4):
            raise Runtime("Unable to show 3D-vector Image")
            
        # take a z-slice
        nda = nda[nda.shape[0]//2,:,:,:]
            
    ysize = nda.shape[0]
    xsize = nda.shape[1]
      
    # Make a figure big enough to accommodate an axis of xpixels by ypixels
    # as well as the ticklabels, etc...
    figsize = (4 + margin) * ysize / dpi, (4 + margin) * xsize / dpi

    fig = plt.figure(figsize=figsize, dpi=dpi)
    # Make the axis the right size...
    ax = fig.add_axes([margin, margin, 1 - 2*margin, 1 - 2*margin])
   
    extent = (0, xsize*spacing[1], ysize*spacing[0], 0)
    
    t = ax.imshow(nda,extent=extent,interpolation=None)
    
    if nda.ndim == 2:
        t.set_cmap("gray")
    
    if(title):
        plt.title(title)

In [71]:
X = np.load('trainimages.npy')
y = np.load('trainlables.npy')
y[y >1] = 0
print(X.shape, X.min(), X.max()) # (240, 240, 4) -0.380588 2.62761
print(y.shape, y.min(), y.max()) # (240, 240, 1) 0 1 

Xtest = np.load('validationimages.npy')
ytest = np.load('validationlables.npy')
ytest[ytest >1] = 0
print(Xtest.shape, Xtest.min(), Xtest.max()) # (240, 240, 4) -0.380588 2.62761
print(ytest.shape, ytest.min(), ytest.max()) # (240, 240, 1) 0 1 

(1033, 200, 200, 2) -2.6653488 16.262873
(1033, 200, 200, 1) 0.0 1.0
(324, 200, 200, 2) -1.9722795 7.7884636
(324, 200, 200, 1) 0.0 1.0


In [74]:
def shuffle_list(*ls):
    l =list(zip(*ls))
    shuffle(l)
    return zip(*l)

Xs,ys = shuffle_list(X,y)
Xs= np.array(Xs)
ys= np.array(ys)
print(Xs.shape)
print(ys.shape)

(1033, 200, 200, 2)
(1033, 200, 200, 1)


In [75]:
from __future__ import print_function

import os
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose,ZeroPadding2D, Dropout,UpSampling2D,Activation, Cropping2D
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from keras import backend as K

In [76]:
def dice_coef_for_training(y_true, y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    return 1.-dice_coef_for_training(y_true, y_pred)

In [77]:
def conv_bn_relu(nd, k=3, inputs=None):
    conv = Conv2D(nd, k, padding='same')(inputs) #, kernel_initializer='he_normal'
    #bn = BatchNormalization()(conv)
    relu = Activation('relu')(conv)
    return relu

In [78]:
def get_crop_shape(target, refer):
        # width, the 3rd dimension
        cw = (target.get_shape()[2] - refer.get_shape()[2]).value
        assert (cw >= 0)
        if cw % 2 != 0:
            cw1, cw2 = int(cw/2), int(cw/2) + 1
        else:
            cw1, cw2 = int(cw/2), int(cw/2)
        # height, the 2nd dimension
        ch = (target.get_shape()[1] - refer.get_shape()[1]).value
        assert (ch >= 0)
        if ch % 2 != 0:
            ch1, ch2 = int(ch/2), int(ch/2) + 1
        else:
            ch1, ch2 = int(ch/2), int(ch/2)

        return (ch1, ch2), (cw1, cw2)

In [79]:
def get_unet():
    concat_axis = -1
    filters = 3
    inputs = Input(batchShape[1:])    
    conv1 = conv_bn_relu(64, filters, inputs)
    conv1 = conv_bn_relu(64, filters, conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = conv_bn_relu(96, 3, pool1)
    conv2 = conv_bn_relu(96, 3, conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = conv_bn_relu(128, 3, pool2)
    conv3 = conv_bn_relu(128, 3, conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = conv_bn_relu(256, 3, pool3)
    conv4 = conv_bn_relu(256, 4, conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = conv_bn_relu(512, 3, pool4)
    conv5 = conv_bn_relu(512, 3, conv5)

    up_conv5 = UpSampling2D(size=(2, 2))(conv5)
    ch, cw = get_crop_shape(conv4, up_conv5)
    crop_conv4 = Cropping2D(cropping=(ch,cw))(conv4)
    up6 = concatenate([up_conv5, crop_conv4], axis=concat_axis)
    conv6 = conv_bn_relu(256, 3, up6)
    conv6 = conv_bn_relu(256, 3, conv6)

    up_conv6 = UpSampling2D(size=(2, 2))(conv6)
    ch, cw = get_crop_shape(conv3, up_conv6)
    crop_conv3 = Cropping2D(cropping=(ch,cw))(conv3)
    up7 = concatenate([up_conv6, crop_conv3], axis=concat_axis)
    conv7 = conv_bn_relu(128, 3, up7)
    conv7 = conv_bn_relu(128, 3, conv7)

    up_conv7 = UpSampling2D(size=(2, 2))(conv7)
    ch, cw = get_crop_shape(conv2, up_conv7)
    crop_conv2 = Cropping2D(cropping=(ch,cw))(conv2)
    up8 = concatenate([up_conv7, crop_conv2], axis=concat_axis)
    conv8 = conv_bn_relu(96, 3, up8)
    conv8 = conv_bn_relu(96, 3, conv8)

    up_conv8 = UpSampling2D(size=(2, 2))(conv8)
    ch, cw = get_crop_shape(conv1, up_conv8)
    crop_conv1 = Cropping2D(cropping=(ch,cw))(conv1)
    up9 = concatenate([up_conv8, crop_conv1], axis=concat_axis)
    conv9 = conv_bn_relu(64, 3, up9)
    conv9 = conv_bn_relu(64, 3, conv9)

    ch, cw = get_crop_shape(inputs, conv9)
    conv9 = ZeroPadding2D(padding=(ch, cw))(conv9)
    conv10 = Conv2D(1, 1, activation='sigmoid', padding='same')(conv9) #, kernel_initializer='he_normal'
    model = Model(inputs=inputs, outputs=conv10)
    model.compile(optimizer=Adam(lr=(2e-4)), loss=dice_coef_loss)
    model.summary()
    return model

In [90]:
img_rows =240
img_cols =240

smooth = 1.

batchSize = 20
batchShape = (batchSize, 240,240, 2)

In [86]:
def preprocess(imgs):
#   print("pree",imgs.shape, imgs.shape[:-1])
    imgs_p = np.ndarray((imgs.shape[0],img_rows, img_cols,imgs.shape[-1]), dtype=np.uint8)
#     print("imgs.shape[0]",imgs.shape[0])
    for i in range(imgs.shape[0]):
        imgs_p[i] = resize(imgs[i],(img_cols, img_rows,imgs.shape[-1]), preserve_range=True)
  
#   print("imgs_p",imgs_p.shape)
    return imgs_p

In [99]:
def train(X,y):
    print('-'*30)
    print('Loading and preprocessing train data...')
    print('-'*30)
    imgs_train, imgs_mask_train = np.array(X), np.array(y)#load_train_data()
#     myshow(sitk.GetImageFromArray(imgs_train[:,:,0]))
#     myshow(sitk.GetImageFromArray(imgs_train[:,:,1]))
#     myshow(sitk.GetImageFromArray(imgs_mask_train))
#     print("shape before",imgs_train.shape,imgs_mask_train.shape)
#     myshow(sitk.GetImageFromArray(imgs_train))

    imgs_train = preprocess(imgs_train)
    imgs_mask_train = preprocess(imgs_mask_train)
#     print("shapeagfter",imgs_train.shape,imgs_mask_train.shape)
#     myshow(sitk.GetImageFromArray(imgs_train))

    imgs_train = imgs_train.astype('float32')
    mean = np.mean(imgs_train)  # mean for data centering
    std = np.std(imgs_train)  # std for data normalization

    imgs_train -= mean
    imgs_train /= std

    imgs_mask_train = imgs_mask_train.astype('float32')

    print('-'*30)
    print('Creating and compiling model...')
    print('-'*30)
    model = get_unet()
    model_checkpoint = ModelCheckpoint('weights.h5', monitor='val_loss', save_best_only=True)

    print('-'*30)
    print('Fitting model...')
    print('-'*30)
# #     myshow(sitk.GetImageFromArray(imgs_train))
#     myshow(sitk.GetImageFromArray(imgs_mask_train))
#     print("before model",imgs_train.shape,imgs_mask_train.shape)
    model.fit(imgs_train, imgs_mask_train, batch_size=20, nb_epoch=2, verbose=1, shuffle=True,
              validation_split=0.2,
              callbacks=[model_checkpoint])
    
    return model
   
def predict(model):    
    print('-'*30)
    print('Loading and preprocessing test data...')
    print('-'*30)
    testrange = range(len(ytest))
    imgs_test, imgs_id_test = Xtest[:50,:,:,:], ytest[:50,:,:,:]
#     print("before test pre",imgs_test.shape,imgs_id_test)
    imgs_test = preprocess(imgs_test)
#     print("after test pre",imgs_test.shape,imgs_id_test)
    imgs_test = imgs_test.astype('float32')
    mean = np.mean(imgs_test)  # mean for data centering
    std = np.std(imgs_test)  # std for data normalization
    imgs_test -= mean
    imgs_test /= std

    print('-'*30)
    print('Loading saved weights...')
    print('-'*30)
    model.load_weights('weights.h5')

    print('-'*30)
    print('Predicting masks on test data...')
    print('-'*30)
#     print("test model",imgs_test.shape)
    imgs_mask_test = model.predict(imgs_test, verbose=1)
    print("test model finished",imgs_mask_test.shape)
    np.save('imgs_mask_test.npy', imgs_mask_test)
    myshow(sitk.GetImageFromArray(imgs_test[20,:,:,1]))
    myshow(sitk.GetImageFromArray(imgs_id_test[20,:,:,0]))
    myshow(sitk.GetImageFromArray(imgs_mask_test[20,:,:,0]))
    print('-' * 30)
    print('Saving predicted masks to files...')
    print('-' * 30)
    pred_dir = 'preds'
    if not os.path.exists(pred_dir):
        os.mkdir(pred_dir)
    for image, image_id in zip(imgs_mask_test, testrange):
#         print(image_id)
        nn = 0
#         print(image.shape)
        image = (image[:, :, 0] * 255.).astype(np.uint8)
        imsave(os.path.join(pred_dir,str(image_id) + '_pred.png'), image)
        nn+=1
    print(imgs_id_test.shape,imgs_mask_test.shape)
    return imgs_id_test,imgs_mask_test

In [96]:
nn= 30
Xchunks = [Xs[x:x+nn] for x in range(0, len(Xs), nn)]
ychunks = [ys[x:x+nn] for x in range(0, len(ys), nn)]

In [97]:
models = []
if __name__ == '__main__':
    for i,j in zip(Xchunks[:2],ychunks[:2]):
        model = train(i,j)
        models.append(model)
#     model = train(Xchunks[2],ychunks[2])
    for model in models:
        testFilename, resultFilename = predict(model)

------------------------------
Loading and preprocessing train data...
------------------------------


  warn("The default mode, 'constant', will be changed to 'reflect' in "
  warn("Anti-aliasing will be enabled by default in skimage 0.15 to "


------------------------------
Creating and compiling model...
------------------------------
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            (None, 240, 240, 2)  0                                            
__________________________________________________________________________________________________
conv2d_39 (Conv2D)              (None, 240, 240, 64) 1216        input_3[0][0]                    
__________________________________________________________________________________________________
activation_37 (Activation)      (None, 240, 240, 64) 0           conv2d_39[0][0]                  
__________________________________________________________________________________________________
conv2d_40 (Conv2D)              (None, 240, 240, 64) 36928       activation_37[0][0]              
_______________



Train on 24 samples, validate on 6 samples
Epoch 1/2
Epoch 2/2
------------------------------
Loading and preprocessing train data...
------------------------------
------------------------------
Creating and compiling model...
------------------------------
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 240, 240, 2)  0                                            
__________________________________________________________________________________________________
conv2d_58 (Conv2D)              (None, 240, 240, 64) 1216        input_4[0][0]                    
__________________________________________________________________________________________________
activation_55 (Activation)      (None, 240, 240, 64) 0           conv2d_58[0][0]                  
________________________________________________

Train on 24 samples, validate on 6 samples
Epoch 1/2
Epoch 2/2
------------------------------
Loading and preprocessing test data...
------------------------------
------------------------------
Loading saved weights...
------------------------------
------------------------------
Predicting masks on test data...
------------------------------

KeyboardInterrupt: 

In [101]:
len(models)

2

In [102]:
for model in models:
        testFilename, resultFilename = predict(model)

------------------------------
Loading and preprocessing test data...
------------------------------


  warn("The default mode, 'constant', will be changed to 'reflect' in "
  warn("Anti-aliasing will be enabled by default in skimage 0.15 to "


------------------------------
Loading saved weights...
------------------------------
------------------------------
Predicting masks on test data...
------------------------------
test model finished (50, 240, 240, 1)
------------------------------
Saving predicted masks to files...
------------------------------
(50, 200, 200, 1) (50, 240, 240, 1)
------------------------------
Loading and preprocessing test data...
------------------------------
------------------------------
Loading saved weights...
------------------------------
------------------------------
Predicting masks on test data...
------------------------------
test model finished (50, 240, 240, 1)
------------------------------
Saving predicted masks to files...
------------------------------
(50, 200, 200, 1) (50, 240, 240, 1)
