In [1]:
from keras.models import Model
from keras.layers import Input, MaxPooling2D, Dropout, Conv2D, Conv2DTranspose, TimeDistributed, Bidirectional, ConvLSTM2D
from keras import backend as K
import tensorflow as tf
from keras.optimizers import RMSprop, Adam, SGD
from keras.losses import binary_crossentropy
from losses import *
import math

from datahandler import DataHandler
from models import *

from generator import *
from params import *
from callbacks import getCallbacks
from kfold_data_loader import *

from tqdm import tqdm
import os
import skimage.io as io

from keras.models import *
from keras import backend as K

import argparse
import sys
import random


import numpy as np

from keras.models import *
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator

from models.unet import *
from models.unet_se import *

from datahandler import DataHandler
from kfold_data_loader import *
from params import *
import os
import cv2
import skimage.io as io
from tqdm import tqdm

from medpy.io import save

from math import ceil, floor
from matplotlib import pyplot as plt
from sklearn.metrics import f1_score, jaccard_similarity_score

from scipy.ndimage import _ni_support
from scipy.ndimage.morphology import distance_transform_edt, binary_erosion,\
     generate_binary_structure

import warnings
warnings.filterwarnings("ignore")


lstm_graph = tf.get_default_graph()

Using TensorFlow backend.


In [12]:
def lstmGenerator(images, batch_size, pre_model, pre_graph):
    i=0
    c=0
    bs = batch_size
    while True:
        with pre_graph.as_default():
            batch_features = []

            j = i

            if j == 0:
                res1 =  np.expand_dims(np.zeros(images[j].shape), axis=0)
            else:
                img1 = np.expand_dims(images[j-1], axis=0)
                res1 = pre_model.predict(img1)

            img2 = np.expand_dims(images[j], axis=0)
            res2 = pre_model.predict(img2)

            if j == images.shape[0]-1:
                res3 = np.expand_dims(np.zeros(images[j].shape), axis=0)
            else:
                img3 = np.expand_dims(images[j+1], axis=0)
                res3 = pre_model.predict(img3)

            res = np.concatenate((res1,res2,res3), axis=0)
            res[res>=0.5] = 1
            res[res<0.5] = 0

            batch_features.append(res)
            print(np.array(batch_features).shape)
            i += 1
            yield np.array(batch_features)

def lstmModel():

    with lstm_graph.as_default():

        inputs = Input((3, 256, 256, 1))
        bclstm = Bidirectional(ConvLSTM2D(32, 3, return_sequences = True, padding='same', activation = 'relu'))(inputs)
        pool = TimeDistributed(MaxPooling2D(pool_size=2))(bclstm)
        bclstm = Bidirectional(ConvLSTM2D(64, 3, padding='same', activation = 'relu'))(pool)
        up = Conv2DTranspose(64,3, strides=2, padding='same', activation = 'relu')(bclstm)
        drop = Dropout(0.5)(up)
        outputs = Conv2D(1, (1,1), activation = 'sigmoid')(drop)

        model = Model(input = inputs, output = outputs)

        model.compile(optimizer = Adam(lr = 1e-4),
                loss = binary_crossentropy, metrics = [dice_coef])

        return model

In [13]:
def getDiceScore(ground_truth, prediction):
    #convert to boolean values and flatten
    ground_truth = np.asarray(ground_truth, dtype=np.bool).flatten()
    prediction = np.asarray(prediction, dtype=np.bool).flatten()    
    return f1_score(ground_truth, prediction)


In [14]:
 def hd(result, reference, voxelspacing=None, connectivity=1):
    hd1 = __surface_distances(result, reference, voxelspacing, connectivity).max()
    hd2 = __surface_distances(reference, result, voxelspacing, connectivity).max()
    hd = max(hd1, hd2)
    return hd

def hd95(result, reference, voxelspacing=None, connectivity=1):
    hd1 = __surface_distances(result, reference, voxelspacing, connectivity)
    hd2 = __surface_distances(reference, result, voxelspacing, connectivity)
    hd95 = np.percentile(np.hstack((hd1, hd2)), 95)
    return hd95

def __surface_distances(result, reference, voxelspacing=None, connectivity=1):
    result = np.atleast_1d(result.astype(np.bool))
    reference = np.atleast_1d(reference.astype(np.bool))
    if voxelspacing is not None:
        voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim)
        voxelspacing = np.asarray(voxelspacing, dtype=np.float64)
        if not voxelspacing.flags.contiguous:
            voxelspacing = voxelspacing.copy()

    footprint = generate_binary_structure(result.ndim, connectivity)

    if 0 == np.count_nonzero(result):
        raise RuntimeError('The first supplied array does not contain any binary object.')
    if 0 == np.count_nonzero(reference):
        raise RuntimeError('The second supplied array does not contain any binary object.')

    result_border = result ^ binary_erosion(result, structure=footprint, iterations=1)
    reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1)

    dt = distance_transform_edt(~reference_border, sampling=voxelspacing)
    sds = dt[result_border]

    return sds

In [15]:
image_files, mask_files = load_data_files('data/kfold_data/')

skf = getKFolds(image_files, mask_files, n=10)

kfold_indices = []
for train_index, val_index in skf.split(image_files, mask_files):
    kfold_indices.append({'train': train_index, 'val': val_index})

In [16]:
def predictMask(model, cur_graph, pre_model, pre_graph, image): 
    
    prediction = image
    
    
    
    image_gen = lstmGenerator(image, 1, pre_model, pre_graph)
    return model.predict_generator(image_gen, steps=len(image))

In [17]:
def predictAll(model, cur_graph, data, num_data=0):
    dice_scores = []
    hd_scores = []
    hd95_scores = []
    
    pre_graph = tf.get_default_graph()
    with pre_graph.as_default():
        pre_model = getUnet()
        print('loading pre weights %d'%i)
        pre_model.load_weights('logs/unet/kfold_unet/kfold_unet_dice_DA_K%d/kfold_unet_dice_DA_K%d_weights.h5'%(i,i))


    for image_file, mask_file in tqdm(data, total=num_data):
        
        fname = image_file[image_file.rindex('/')+1 : image_file.index('.')]
        
        image, hdr = dh.getImageData(image_file)
        gt_mask, _ = dh.getImageData(mask_file, is_mask=True)

        assert image.shape == gt_mask.shape
        
        if image.shape[1] != 256:
            continue
        
        pred_mask = predictMask(model, cur_graph, pre_model, pre_graph, image)
        pred_mask[pred_mask>=0.5] = 1
        pred_mask[pred_mask<0.5] = 0
            
        dice_score = getDiceScore(gt_mask, pred_mask)
        
        if dice_score == 0:
            continue
            
        dice_scores.append(dice_score)
        
        hd_score = hd(gt_mask, pred_mask)
        hd_scores.append(hd_score)
        
        hd95_score = hd95(gt_mask, pred_mask)
        hd95_scores.append(hd95_score)

    return dice_scores, hd_scores, hd95_scores

In [18]:
#Get data and generators

unet_type = 'unet'
dh = DataHandler()
all_dice = []
all_hd = []
all_hd95 = []

for i in range(7,9):
    exp_name = 'kfold_%s_BiCLSTM_K%d'%(unet_type, i)

    #get parameters
    params = getParams(exp_name, unet_type=unet_type, is_lstm = True)
    
    val_img_files = np.take(image_files, kfold_indices[i]['val'])
    val_mask_files = np.take(mask_files, kfold_indices[i]['val'])
    
    
    
    with lstm_graph.as_default():
        model = lstmModel()
        print('loading weights from %s'%params['checkpoint']['name'])
        model.load_weights(params['checkpoint']['name'])
        
    data = zip(val_img_files, val_mask_files)
    dice_score, hd_score, hd95_score = predictAll(model, lstm_graph, data, num_data=len(val_mask_files))
    
    print('Finished K%d'%i)
    
    all_dice += dice_score
    all_hd += hd_score
    all_hd95 += hd95_score

print('dice')
for i in range(len(all_dice)):
    print(all_dice[i])
print()

print('hd')
for i in range(len(all_hd)):
    print(all_hd[i])
print()

print('hd95')
for i in range(len(all_hd95)):
    print(all_hd95[i])
print()
    
print('Final results for %s'%unet_type)
print('dice %f'%np.mean(all_dice))
print('hd %f'%np.mean(all_hd))
print('hd95 %f'%np.mean(all_hd95))


loading weights from ./logs/unet_LSTM/kfold_unet_LSTM/kfold_unet_BiCLSTM_K7/kfold_unet_BiCLSTM_K7_weights.h5
loading pre weights 7


  0%|          | 0/29 [00:00<?, ?it/s]

(57, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1

  3%|▎         | 1/29 [00:08<03:52,  8.30s/it]

(53, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1

  7%|▋         | 2/29 [00:15<03:37,  8.05s/it]

(41, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)


 10%|█         | 3/29 [00:21<03:11,  7.36s/it]

(40, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)
(1, 3, 256, 256, 1)


 14%|█▍        | 4/29 [00:27<02:50,  6.83s/it]

KeyboardInterrupt: 