In [1]:
import os,time
import glob
import SimpleITK as sitk
import matplotlib.pyplot as plt
import numpy as np
inputDir = '../data/train'
testDir = '../data/validation'
outputDir = 'output'
from random import shuffle


In [2]:
def myshow(img, title=None, margin=0.05, dpi=100):
    nda = sitk.GetArrayViewFromImage(img)
    spacing = img.GetSpacing()
        
    if nda.ndim == 3:
        # fastest dim, either component or x
        c = nda.shape[-1]
        
        # the the number of components is 3 or 4 consider it an RGB image
        if not c in (3,4):
            nda = nda[nda.shape[0]//2,:,:]
    
    elif nda.ndim == 4:
        c = nda.shape[-1]
        
        if not c in (3,4):
            raise Runtime("Unable to show 3D-vector Image")
            
        # take a z-slice
        nda = nda[nda.shape[0]//2,:,:,:]
            
    ysize = nda.shape[0]
    xsize = nda.shape[1]
      
    # Make a figure big enough to accommodate an axis of xpixels by ypixels
    # as well as the ticklabels, etc...
    figsize = (4 + margin) * ysize / dpi, (4 + margin) * xsize / dpi

    fig = plt.figure(figsize=figsize, dpi=dpi)
    # Make the axis the right size...
    ax = fig.add_axes([margin, margin, 1 - 2*margin, 1 - 2*margin])
   
    extent = (0, xsize*spacing[1], ysize*spacing[0], 0)
    
    t = ax.imshow(nda,extent=extent,interpolation=None)
    
    if nda.ndim == 2:
        t.set_cmap("gray")
    
    if(title):
        plt.title(title)

In [3]:
img_rows =240
img_cols =240

smooth = 1.

batchSize = 20
batchShape = (batchSize, 240,240, 2)

In [4]:
def preprocess(imgs):
#   print("pree",imgs.shape, imgs.shape[:-1])
    imgs_p = np.ndarray((imgs.shape[0],img_rows, img_cols,imgs.shape[-1]), dtype=np.uint8)
#     print("imgs.shape[0]",imgs.shape[0])
    for i in range(imgs.shape[0]):
        imgs_p[i] = resize(imgs[i],(img_cols, img_rows,imgs.shape[-1]), preserve_range=True)
  
#   print("imgs_p",imgs_p.shape)
    return imgs_p

In [5]:
def loadData( inputDir,city, padding=0 ):
    imageflairFilenames = glob.glob(os.path.join(inputDir, city, "*","pre","FLAIR.nii.gz"))
    imageT1Filenames = glob.glob(os.path.join(inputDir, city, "*","pre","T1.nii.gz"))
    labelFilenames= glob.glob(os.path.join(inputDir, city, "*","wmh.nii.gz"))
    images = None # shape: (numImages, z, y, x, channels=1)
    labels = None
    masks  = None
    print(len(imageflairFilenames))
    for imageflairFilenames, imageT1Filenames, labelFilenames in zip(imageflairFilenames, imageT1Filenames, labelFilenames):
        # Load the images
        flairImage = sitk.ReadImage(imageflairFilenames)
        T1Image = sitk.ReadImage(imageT1Filenames)
        labelImage = sitk.ReadImage(labelFilenames)
        # Convert to arrays
        flairArray = np.pad(sitk.GetArrayFromImage(flairImage), [(0,0),(padding,padding),(padding,padding)], 'constant')
#         flairArray = preprocess(flairArray)
        T1Array = np.pad(sitk.GetArrayFromImage(T1Image), [(0,0),(padding,padding),(padding,padding)], 'constant')
#         T1Array = preprocess(T1Array)
        labelArray = np.pad(sitk.GetArrayFromImage(labelImage), [(0,0),(padding,padding),(padding,padding)], 'constant')
#         labelArray = preprocess(labelArray)
        maskArray = labelArray > 0
        print("flairArray",flairArray.shape)
        print("T1Array",T1Array.shape)
        print("labelArray",labelArray.shape)
        print("maskArray",maskArray.shape)
        # Add to the images/labels array
        if images is None:
            images = flairArray.reshape([1] + list(flairArray.shape) + [1])
            images = np.concatenate([images, T1Array.reshape([1] + list(T1Array.shape) + [1])], axis=4)
            labels = labelArray.reshape([1] + list(labelArray.shape) + [1])
            masks  = maskArray.reshape([1] + list(maskArray.shape) + [1])
        else:
            tempArray = np.concatenate([flairArray.reshape([1] + list(flairArray.shape) + [1]), T1Array.reshape([1] + list(T1Array.shape) + [1])], axis=4)
            print("tempArray:",tempArray.shape,"images:",images.shape)
            images = np.concatenate([images, tempArray])
            
            labels = np.concatenate([labels, labelArray.reshape([1] + list(labelArray.shape) + [1])])
            masks  = np.concatenate([masks, maskArray.reshape([1] + list(maskArray.shape) + [1])])
                    
    return images, labels, masks

In [None]:
city = "Amsterdam"
testImages, testLabels, testMasks = loadData(testDir,city)
testNonZeroIdx = np.nonzero(testMasks)
trainImages, trainLabels, trainMasks = loadData(inputDir,city)
trainNonZeroIdx = np.nonzero(testMasks)

In [9]:
cities = ["Utrecht","Amsterdam"]
for city in cities:
    testImages, testLabels, testMasks = loadData(testDir,city)
    testNonZeroIdx = np.nonzero(testMasks)
    trainImages, trainLabels, trainMasks = loadData(inputDir,city)
    trainNonZeroIdx = np.nonzero(testMasks)
    
    trainimages = []
    trainlables = []
    for i in range(trainMasks.shape[0]):
        for j in range(trainMasks.shape[1]):
            if not np.all(trainMasks[i,j,:,:,0]== False):
                trainlables.append(trainLabels[i,j,:,:,:])
                trainimages.append(trainImages[i,j,:,:,:])

    testimages = []
    testlables = []
    for i in range(testMasks.shape[0]):
        for j in range(testMasks.shape[1]):
            if not np.all(testMasks[i,j,:,:,0]== False):
                testlables.append(testLabels[i,j,:,:,:])
                testimages.append(testImages[i,j,:,:,:])
    if city == "Utrecht":
        X = np.asarray(trainimages)
        y = np.asarray(trainlables).astype(int)
        y[y >1] = 0
        Xtest = np.asarray(testimages)
        ytest =  np.asarray(testlables).astype(int)
        ytest[ytest >1] = 0
    else:
        X = np.concatenate([X,np.asarray(trainimages)])
        y = np.concatenate([y,np.asarray(trainlables).astype(int)])
        y[y >1] = 0
        print(X.shape, X.min(), X.max()) # (240, 240, 4) -0.380588 2.62761
        print(y.shape, y.min(), y.max()) # (240, 240, 1) 0 1 

        Xtest = np.concatenate([Xtest,np.asarray(testimages)])
        ytest =  np.concatenate([ytest,np.asarray(testlables).astype(int)])
        ytest[ytest >1] = 0
        print(Xtest.shape, Xtest.min(), Xtest.max()) # (240, 240, 4) -0.380588 2.62761
        print(ytest.shape, ytest.min(), ytest.max()) # (240, 240, 1) 0 1 


5
flairArray (48, 240, 240)
T1Array (48, 240, 240)
labelArray (48, 240, 240)
maskArray (48, 240, 240)
flairArray (48, 240, 240)
T1Array (48, 240, 240)
labelArray (48, 240, 240)
maskArray (48, 240, 240)
tempArray: (1, 48, 240, 240, 2) images: (1, 48, 240, 240, 2)
flairArray (48, 240, 240)
T1Array (48, 240, 240)
labelArray (48, 240, 240)
maskArray (48, 240, 240)
tempArray: (1, 48, 240, 240, 2) images: (2, 48, 240, 240, 2)
flairArray (48, 240, 240)
T1Array (48, 240, 240)
labelArray (48, 240, 240)
maskArray (48, 240, 240)
tempArray: (1, 48, 240, 240, 2) images: (3, 48, 240, 240, 2)
flairArray (48, 240, 240)
T1Array (48, 240, 240)
labelArray (48, 240, 240)
maskArray (48, 240, 240)
tempArray: (1, 48, 240, 240, 2) images: (4, 48, 240, 240, 2)
15
flairArray (48, 240, 240)
T1Array (48, 240, 240)
labelArray (48, 240, 240)
maskArray (48, 240, 240)
flairArray (48, 240, 240)
T1Array (48, 240, 240)
labelArray (48, 240, 240)
maskArray (48, 240, 240)
tempArray: (1, 48, 240, 240, 2) images: (1, 48, 240

ValueError: all the input array dimensions except for the concatenation axis must match exactly

In [10]:
X.shape

(402, 240, 240, 2)

In [None]:
trainimages = []
trainlables = []
for i in range(trainMasks.shape[0]):
    for j in range(trainMasks.shape[1]):
        if not np.all(trainMasks[i,j,:,:,0]== False):
            trainlables.append(trainLabels[i,j,:,:,:])
            trainimages.append(trainImages[i,j,:,:,:])

testimages = []
testlables = []
for i in range(testMasks.shape[0]):
    for j in range(testMasks.shape[1]):
        if not np.all(testMasks[i,j,:,:,0]== False):
            testlables.append(testLabels[i,j,:,:,:])
            testimages.append(testImages[i,j,:,:,:])

In [None]:
X = np.asarray(trainimages)
y = np.asarray(trainlables).astype(int)
y[y >1] = 0
print(X.shape, X.min(), X.max()) # (240, 240, 4) -0.380588 2.62761
print(y.shape, y.min(), y.max()) # (240, 240, 1) 0 1 

Xtest = np.asarray(testimages)
ytest = np.asarray(testlables).astype(int)
ytest[ytest >1] = 0
print(Xtest.shape, Xtest.min(), Xtest.max()) # (240, 240, 4) -0.380588 2.62761
print(ytest.shape, ytest.min(), ytest.max()) # (240, 240, 1) 0 1 

In [None]:

def shuffle_list(*ls):
    l =list(zip(*ls))
    shuffle(l)
    return zip(*l)

Xs,ys = shuffle_list(X,y)
Xs= np.array(Xs)
ys= np.array(ys)
print(Xs.shape)
print(ys.shape)

In [None]:
from __future__ import print_function

import os
from skimage.transform import resize
from skimage.io import imsave
import numpy as np
from keras.models import Model
from keras.layers import Input, concatenate, Conv2D, MaxPooling2D, Conv2DTranspose,ZeroPadding2D, Dropout,UpSampling2D,Activation, Cropping2D
from keras.optimizers import Adam
from keras.callbacks import ModelCheckpoint
from keras import backend as K

In [None]:
def dice_coef_for_training(y_true, y_pred):
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def dice_coef_loss(y_true, y_pred):
    return 1.-dice_coef_for_training(y_true, y_pred)

In [None]:
def conv_bn_relu(nd, k=3, inputs=None):
    conv = Conv2D(nd, k, padding='same')(inputs) #, kernel_initializer='he_normal'
    #bn = BatchNormalization()(conv)
    relu = Activation('relu')(conv)
    return relu

In [None]:
def get_crop_shape(target, refer):
        # width, the 3rd dimension
        cw = (target.get_shape()[2] - refer.get_shape()[2]).value
        assert (cw >= 0)
        if cw % 2 != 0:
            cw1, cw2 = int(cw/2), int(cw/2) + 1
        else:
            cw1, cw2 = int(cw/2), int(cw/2)
        # height, the 2nd dimension
        ch = (target.get_shape()[1] - refer.get_shape()[1]).value
        assert (ch >= 0)
        if ch % 2 != 0:
            ch1, ch2 = int(ch/2), int(ch/2) + 1
        else:
            ch1, ch2 = int(ch/2), int(ch/2)

        return (ch1, ch2), (cw1, cw2)

In [None]:
def get_unet():
#     inputs = Input(batchShape[1:])
#     conv1 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(inputs)
#     conv1 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv1)
#     pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
#     conv2 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool1)
#     conv2 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv2)
#     pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)
#     conv3 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool2)
#     conv3 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv3)
#     pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)
#     conv4 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool3)
#     conv4 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv4)
#     drop4 = Dropout(0.5)(conv4)
#     pool4 = MaxPooling2D(pool_size=(2, 2))(drop4)

#     conv5 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(pool4)
#     conv5 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv5)
#     drop5 = Dropout(0.5)(conv5)

#     up6 = Conv2D(32, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(drop5))
#     merge6 = concatenate([drop4,up6], axis = -1)
#     conv6 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge6)
#     conv6 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv6)

#     up7 = Conv2D(32, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv6))
#     merge7 = concatenate([conv3,up7], axis = -1)
#     conv7 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge7)
#     conv7 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv7)

#     up8 = Conv2D(32, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv7))
#     merge8 = concatenate([conv2,up8], axis = -1)
#     conv8 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge8)
#     conv8 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv8)

#     up9 = Conv2D(32, 2, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(UpSampling2D(size = (2,2))(conv8))
#     merge9 = concatenate([conv1,up9], axis = -1)
#     conv9 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(merge9)
#     conv9 = Conv2D(32, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
#     conv9 = Conv2D(2, 3, activation = 'relu', padding = 'same', kernel_initializer = 'he_normal')(conv9)
#     conv10 = Conv2D(1, 1, activation = 'sigmoid')(conv9)


#     model = Model(inputs = inputs, outputs = conv10)

#     model.compile(optimizer = Adam(lr = 1e-4), loss=dice_coef_loss)

#     model.summary()
    concat_axis = -1
    filters = 3
    inputs = Input(batchShape[1:])    
    conv1 = conv_bn_relu(64, filters, inputs)
    conv1 = conv_bn_relu(64, filters, conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = conv_bn_relu(96, 3, pool1)
    conv2 = conv_bn_relu(96, 3, conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = conv_bn_relu(128, 3, pool2)
    conv3 = conv_bn_relu(128, 3, conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = conv_bn_relu(256, 3, pool3)
    conv4 = conv_bn_relu(256, 4, conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = conv_bn_relu(512, 3, pool4)
    conv5 = conv_bn_relu(512, 3, conv5)

    up_conv5 = UpSampling2D(size=(2, 2))(conv5)
    ch, cw = get_crop_shape(conv4, up_conv5)
    crop_conv4 = Cropping2D(cropping=(ch,cw))(conv4)
    up6 = concatenate([up_conv5, crop_conv4], axis=concat_axis)
    conv6 = conv_bn_relu(256, 3, up6)
    conv6 = conv_bn_relu(256, 3, conv6)

    up_conv6 = UpSampling2D(size=(2, 2))(conv6)
    ch, cw = get_crop_shape(conv3, up_conv6)
    crop_conv3 = Cropping2D(cropping=(ch,cw))(conv3)
    up7 = concatenate([up_conv6, crop_conv3], axis=concat_axis)
    conv7 = conv_bn_relu(128, 3, up7)
    conv7 = conv_bn_relu(128, 3, conv7)

    up_conv7 = UpSampling2D(size=(2, 2))(conv7)
    ch, cw = get_crop_shape(conv2, up_conv7)
    crop_conv2 = Cropping2D(cropping=(ch,cw))(conv2)
    up8 = concatenate([up_conv7, crop_conv2], axis=concat_axis)
    conv8 = conv_bn_relu(96, 3, up8)
    conv8 = conv_bn_relu(96, 3, conv8)

    up_conv8 = UpSampling2D(size=(2, 2))(conv8)
    ch, cw = get_crop_shape(conv1, up_conv8)
    crop_conv1 = Cropping2D(cropping=(ch,cw))(conv1)
    up9 = concatenate([up_conv8, crop_conv1], axis=concat_axis)
    conv9 = conv_bn_relu(64, 3, up9)
    conv9 = conv_bn_relu(64, 3, conv9)

    ch, cw = get_crop_shape(inputs, conv9)
    conv9 = ZeroPadding2D(padding=(ch, cw))(conv9)
    conv10 = Conv2D(1, 1, activation='sigmoid', padding='same')(conv9) #, kernel_initializer='he_normal'
    model = Model(inputs=inputs, outputs=conv10)
    model.compile(optimizer=Adam(lr=(2e-4)), loss=dice_coef_loss)
    model.summary()
    return model

In [None]:
def train(X,y):
    print('-'*30)
    print('Loading and preprocessing train data...')
    print('-'*30)
    imgs_train, imgs_mask_train = np.array(X), np.array(y)#load_train_data()
#     myshow(sitk.GetImageFromArray(imgs_train[:,:,0]))
#     myshow(sitk.GetImageFromArray(imgs_train[:,:,1]))
#     myshow(sitk.GetImageFromArray(imgs_mask_train))
#     print("shape before",imgs_train.shape,imgs_mask_train.shape)
#     myshow(sitk.GetImageFromArray(imgs_train))

    imgs_train = preprocess(imgs_train)
    imgs_mask_train = preprocess(imgs_mask_train)
#     print("shapeagfter",imgs_train.shape,imgs_mask_train.shape)
#     myshow(sitk.GetImageFromArray(imgs_train))

    imgs_train = imgs_train.astype('float32')
    mean = np.mean(imgs_train)  # mean for data centering
    std = np.std(imgs_train)  # std for data normalization

    imgs_train -= mean
    imgs_train /= std

    imgs_mask_train = imgs_mask_train.astype('float32')
    imgs_mask_train /= 255.  # scale masks to [0, 1]

    print('-'*30)
    print('Creating and compiling model...')
    print('-'*30)
    model = get_unet()
    model_checkpoint = ModelCheckpoint('weights.h5', monitor='val_loss', save_best_only=True)

    print('-'*30)
    print('Fitting model...')
    print('-'*30)
# #     myshow(sitk.GetImageFromArray(imgs_train))
#     myshow(sitk.GetImageFromArray(imgs_mask_train))
#     print("before model",imgs_train.shape,imgs_mask_train.shape)
    model.fit(imgs_train, imgs_mask_train, batch_size=20, nb_epoch=1, verbose=1, shuffle=True,
              validation_split=0.2,
              callbacks=[model_checkpoint])
    
    return model
   
def predict(model):    
    print('-'*30)
    print('Loading and preprocessing test data...')
    print('-'*30)
    testrange = range(len(ytest))
    imgs_test, imgs_id_test = Xtest, ytest
#     print("before test pre",imgs_test.shape,imgs_id_test)
    imgs_test = preprocess(imgs_test)
#     print("after test pre",imgs_test.shape,imgs_id_test)
    imgs_test = imgs_test.astype('float32')
    mean = np.mean(imgs_test)  # mean for data centering
    std = np.std(imgs_test)  # std for data normalization
    imgs_test -= mean
    imgs_test /= std

    print('-'*30)
    print('Loading saved weights...')
    print('-'*30)
    model.load_weights('weights.h5')

    print('-'*30)
    print('Predicting masks on test data...')
    print('-'*30)
#     print("test model",imgs_test.shape)
    imgs_mask_test = model.predict(imgs_test, verbose=1)
    print("test model finished",imgs_mask_test.shape)
    np.save('imgs_mask_test.npy', imgs_mask_test)
    myshow(sitk.GetImageFromArray(imgs_test[50,:,:,1]))
    myshow(sitk.GetImageFromArray(imgs_id_test[50,:,:,0]))
    myshow(sitk.GetImageFromArray(imgs_mask_test[50,:,:,0]))
    print('-' * 30)
    print('Saving predicted masks to files...')
    print('-' * 30)
    pred_dir = 'preds'
    if not os.path.exists(pred_dir):
        os.mkdir(pred_dir)
    for image, image_id in zip(imgs_mask_test, testrange):
#         print(image_id)
        nn = 0
#         print(image.shape)
        image = (image[:, :, 0] * 255.).astype(np.uint8)
        imsave(os.path.join(pred_dir,str(image_id) + '_pred.png'), image)
        nn+=1
    print(imgs_id_test.shape,imgs_mask_test.shape)
    return imgs_id_test,imgs_mask_test

In [None]:
Xchunks = [Xs[x:x+100] for x in range(0, len(Xs), 100)]
ychunks = [ys[x:x+100] for x in range(0, len(ys), 100)]


In [None]:
if __name__ == '__main__':
#     for i,j in zip(Xchunks[-5:],ychunks[-5:]):
#         model = train(i,j)
    model = train(Xchunks[2],ychunks[2])
    testFilename, resultFilename = predict(model)