In [1]:
import numpy as np

from keras.models import *
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator

from models.unet_se import *
from models.unet import *
from models.resnet_fcn import *
from models.resnet_se_fcn import *
from models.resnet_fcn import *
from models.vgg19_fcn import *
from models.vgg19_se_fcn import *
from models.unet_resnet_se import *
from models.unet_upconv import *
from models.unet_upconv_se import *
from models.unet_resnet_upconv_se import *

from datahandler import DataHandler
from data_loader import *
from params import *
import os
import cv2
import skimage.io as io
from tqdm import tqdm

from medpy.io import save

from math import ceil, floor
from matplotlib import pyplot as plt
from sklearn.metrics import f1_score, jaccard_similarity_score

from scipy.ndimage import _ni_support
from scipy.ndimage.morphology import distance_transform_edt, binary_erosion,\
     generate_binary_structure

from skimage.morphology import cube, binary_closing
from skimage.measure import label

import warnings
warnings.filterwarnings("ignore")

plt.gray()

Using TensorFlow backend.


In [2]:
def destiny_directory(model_name, dice_score):
    """Return the output folder for a case, bucketed by its Dice score.

    Args:
        model_name: model identifier, used as a sub-folder of './data/eval/'.
        dice_score: Dice score in percent (predictAll passes
            floor(dice * 100)).

    Returns:
        './data/eval/<model_name>/dice_<lo>_<hi>/' for the highest bucket whose
        lower bound is <= dice_score, or './data/eval/<model_name>/dice_less_60/'
        otherwise. Every returned path ends with '/', so it can be either
        concatenated with a file name or passed to os.path.join.
    """
    # (inclusive lower bound, folder name), checked from the top bucket down.
    buckets = (
        (98, 'dice_98_100'),
        (96, 'dice_96_98'),
        (94, 'dice_94_96'),
        (92, 'dice_92_94'),
        (90, 'dice_90_92'),
        (88, 'dice_88_90'),
        (85, 'dice_85_88'),
        (80, 'dice_80_85'),
        (70, 'dice_70_80'),
        (60, 'dice_60_70'),
    )
    prefix = './data/eval/' + model_name + '/'
    for lower_bound, folder in buckets:
        if dice_score >= lower_bound:
            return prefix + folder + '/'
    return prefix + 'dice_less_60/'

In [3]:
def getModel(name):
    """Build and return the network registered under `name`.

    Prints which model is being built. For an unrecognised name it prints
    'error' and returns -1 instead of a model.
    """
    print('Working with %s'%name)

    # Each entry is a zero-argument lambda, so constructor names are only
    # looked up and called for the selected model.
    builders = {
        'unet': lambda: getUnet(),
        'unet_focal': lambda: getUnet(),
        'unet_se': lambda: getSEUnet(),
        'unet_upconv': lambda: getUnetUpconv(),
        'unet_upconv_se': lambda: getSEUnetUpconv(),
        'resnetFCN': lambda: getResnet50FCN(),
        'resnetSEFCN': lambda: getResnetSE50FCN(),
        'vgg19FCN': lambda: getVGG19FCN(),
        'vgg19SEFCN': lambda: getVGG19SEFCN(),
        'unet_resnet': lambda: getUnetRes(),
        'unet_resnet_se': lambda: getUnetRes(se_version = True),
        'unet_resnet_upconv': lambda: getUnetResUpconv(),
        'unet_resnet_upconv_se': lambda: getUnetResUpconv(se_version = True),
    }

    build = builders.get(name)
    if build is None:
        print('error')
        return -1
    return build()


In [4]:
def getGenerator(images, bs=1):
    """Wrap `images` in a non-shuffling Keras iterator that rescales by 1/255.

    Args:
        images: array accepted by ImageDataGenerator.flow (Keras requires a
            rank-4 array here).
        bs: batch size of the returned iterator.

    Returns:
        The iterator from ImageDataGenerator.flow, yielding batches in input
        order (shuffle=False).
    """
    # Only `rescale` is configured. ImageDataGenerator.fit() is needed solely for
    # featurewise centering/normalisation or ZCA whitening, none of which are
    # enabled. The previous fit(images, augment=True) call built an augmented
    # copy of the whole volume and had no effect on the output, so it is omitted.
    image_datagen = ImageDataGenerator(rescale=1./255)
    return image_datagen.flow(x = images, batch_size=bs, shuffle = False)


In [5]:
def getDiceScore(ground_truth, prediction):
    """Dice coefficient between two binary masks, compared element-wise.

    Any non-zero value counts as foreground. This equals sklearn's binary
    f1_score with True as the positive class: 2|A & B| / (|A| + |B|).

    Args:
        ground_truth: array-like reference mask.
        prediction: array-like predicted mask with the same number of elements.

    Returns:
        Dice score as a float in [0, 1]. Returns 0.0 when neither mask has
        any foreground, matching f1_score's default.

    Raises:
        ValueError: if the two masks have different numbers of elements.
    """
    # Builtin `bool`: the deprecated `np.bool` alias was removed in NumPy 1.24.
    gt = np.asarray(ground_truth, dtype=bool).ravel()
    pred = np.asarray(prediction, dtype=bool).ravel()
    if gt.size != pred.size:
        raise ValueError('masks have different sizes: %d vs %d' % (gt.size, pred.size))

    denominator = np.count_nonzero(gt) + np.count_nonzero(pred)
    if denominator == 0:
        return 0.0
    return float(2.0 * np.count_nonzero(gt & pred) / denominator)


In [6]:
 def hd(result, reference, voxelspacing=None, connectivity=1):
    hd1 = __surface_distances(result, reference, voxelspacing, connectivity).max()
    hd2 = __surface_distances(reference, result, voxelspacing, connectivity).max()
    hd = max(hd1, hd2)
    return hd

def hd95(result, reference, voxelspacing=None, connectivity=1):
    hd1 = __surface_distances(result, reference, voxelspacing, connectivity)
    hd2 = __surface_distances(reference, result, voxelspacing, connectivity)
    hd95 = np.percentile(np.hstack((hd1, hd2)), 95)
    return hd95

def __surface_distances(result, reference, voxelspacing=None, connectivity=1):
    result = np.atleast_1d(result.astype(np.bool))
    reference = np.atleast_1d(reference.astype(np.bool))
    if voxelspacing is not None:
        voxelspacing = _ni_support._normalize_sequence(voxelspacing, result.ndim)
        voxelspacing = np.asarray(voxelspacing, dtype=np.float64)
        if not voxelspacing.flags.contiguous:
            voxelspacing = voxelspacing.copy()

    footprint = generate_binary_structure(result.ndim, connectivity)

    if 0 == np.count_nonzero(result):
        raise RuntimeError('The first supplied array does not contain any binary object.')
    if 0 == np.count_nonzero(reference):
        raise RuntimeError('The second supplied array does not contain any binary object.')

    result_border = result ^ binary_erosion(result, structure=footprint, iterations=1)
    reference_border = reference ^ binary_erosion(reference, structure=footprint, iterations=1)

    dt = distance_transform_edt(~reference_border, sampling=voxelspacing)
    sds = dt[result_border]

    return sds

In [7]:
image_files, mask_files = load_data_files('data/kfold_data/')
print(len(image_files))
print(len(mask_files))

# Split the file list into folds (n=10) and keep, for each fold in order,
# the train/val index arrays that the evaluation loop below indexes into.
skf = getKFolds(image_files, mask_files, n=10)
kfold_indices = [
    {'train': fold_train, 'val': fold_val}
    for fold_train, fold_val in skf.split(image_files, mask_files)
]

291
291


In [8]:
def predictMask(model, image, batch_size=1):
    """Run `model` over every slice of `image` and return the stacked predictions.

    Args:
        model: model exposing `predict_generator` (Keras).
        image: array of slices. It goes through getGenerator, which rescales
            by 1/255 and keeps the input order.
        batch_size: slices per forward pass. The default of 1 reproduces the
            previous behaviour; larger values are faster when memory allows.

    Returns:
        Whatever `model.predict_generator` returns for all len(image) slices.
    """
    image_gen = getGenerator(image, bs=batch_size)
    # Ceiling division, so a final partial batch is still predicted.
    steps = -(-len(image) // batch_size)
    return model.predict_generator(image_gen, steps=steps)

In [9]:
def prepareForSaving(image):
    """Return `image` with singleton axes dropped and its first/last axes swapped.

    Used to reorder volumes before they are written out as .nii files.
    """
    squeezed = np.squeeze(image)
    return squeezed.swapaxes(0, -1)

def _keep_largest_component(mask):
    """Close `mask` with cube(2) and keep only its largest connected component.

    If the closed mask has no foreground, it is returned unchanged.
    """
    closed = binary_closing(np.squeeze(mask), cube(2))
    labels = label(closed)
    component_sizes = np.bincount(labels.ravel())[1:]  # drop background label 0
    if component_sizes.size == 0:
        return closed
    return (labels == np.argmax(component_sizes) + 1).astype(int)

def _save_case(save_path, fname, model_name, int_dice_score, image, gt_mask, pred_mask, hdr):
    """Write prediction, image and ground truth as .nii files under `save_path`."""
    os.makedirs(save_path, exist_ok=True)
    save(prepareForSaving(pred_mask), os.path.join(save_path, fname + '_' + model_name + '_'
        + str(int_dice_score) + '.nii'), hdr)
    save(prepareForSaving(image), os.path.join(save_path, fname + '_img.nii'), hdr)
    save(prepareForSaving(gt_mask), os.path.join(save_path, fname + '_mask.nii'), hdr)

def predictAll(model, model_name, data, num_data=0, data_handler=None,
               postprocess=False, save_results=False, failure_distance=200):
    """Segment each (image, mask) pair in `data` and score it against its mask.

    Args:
        model: model passed to predictMask.
        model_name: used in output folder and file names when saving.
        data: iterable of (image_file, mask_file) path pairs.
        num_data: expected number of pairs; only used for the progress bar.
        data_handler: object providing getImageData. Defaults to the notebook
            global `dh`, as before.
        postprocess: if True, apply morphological closing and keep only the
            largest connected component before scoring. Off by default.
        save_results: if True, write image, ground truth and prediction as
            .nii files into the Dice-bucket folder from destiny_directory.
        failure_distance: HD / HD95 value recorded for cases with Dice 0.

    Returns:
        (dice_scores, hd_scores, hd95_scores, names): four parallel lists with
        one entry per evaluated case. Volumes whose axis-1 size is not 256 are
        skipped and do not appear.
    """
    handler = dh if data_handler is None else data_handler
    dice_scores = []
    names = []
    hd_scores = []
    hd95_scores = []

    for image_file, mask_file in tqdm(data, total=num_data or None):
        # Case id = file name up to its first '.', ignoring dots in directories.
        fname = os.path.basename(image_file).split('.', 1)[0]
        image, hdr = handler.getImageData(image_file)
        gt_mask, _ = handler.getImageData(mask_file, is_mask=True)

        assert image.shape == gt_mask.shape

        # Only volumes with axis-1 size 256 are evaluated (presumably the
        # model's input size -- confirm); others are skipped.
        if image.shape[1] != 256:
            continue

        pred_mask = predictMask(model, image)
        pred_mask[pred_mask>=0.5] = 1
        pred_mask[pred_mask<0.5] = 0

        if postprocess:
            pred_mask = _keep_largest_component(pred_mask)

        gt_mask = np.squeeze(gt_mask)
        dice_score = getDiceScore(gt_mask, pred_mask)

        # Append the name for every scored case so `names` stays aligned
        # with the score lists.
        names.append(fname)
        dice_scores.append(dice_score)

        if dice_score == 0:
            # No overlap: record a fixed penalty instead of computing the
            # distances (which raise RuntimeError when a mask is empty).
            hd_scores.append(failure_distance)
            hd95_scores.append(failure_distance)
            continue

        hd_scores.append(hd(gt_mask, pred_mask))
        hd95_scores.append(hd95(gt_mask, pred_mask))

        if save_results:
            int_dice_score = floor(dice_score * 100)
            _save_case(destiny_directory(model_name, int_dice_score), fname, model_name,
                       int_dice_score, image, gt_mask, pred_mask, hdr)

    return dice_scores, hd_scores, hd95_scores, names

In [10]:
# Evaluate every fold's checkpoint for each model type and report Dice / HD / HD95.

model_types = ['unet_resnet_upconv', 'unet_resnet_upconv_se']
for model_type in model_types:
    dh = DataHandler()
    all_dice = []
    all_hd = []
    all_hd95 = []

    for i, fold in enumerate(kfold_indices):
        exp_name = 'kfold_%s_dice_DA_K%d'%(model_type, i)

        # get parameters (checkpoint location) for this fold's experiment
        params = getParams(exp_name, unet_type=model_type)

        val_img_files = np.take(image_files, fold['val'])
        val_mask_files = np.take(mask_files, fold['val'])

        model = getModel(model_type)
        if isinstance(model, int):
            # getModel signals an unknown name by returning -1.
            raise ValueError('Unknown model type: %s' % model_type)

        print('loading weights from %s'%params['checkpoint']['name'])
        model.load_weights(params['checkpoint']['name'])

        data = zip(val_img_files, val_mask_files)

        dice_score, hd_score, hd95_score, names = predictAll(model, model_type, data, num_data=len(val_mask_files))

        print('Finished K%d'%i)

        all_dice += dice_score
        all_hd += hd_score
        all_hd95 += hd95_score

        # A new model is built for every fold. Clear the Keras global state so
        # graphs don't accumulate in memory across folds and model types.
        del model
        K.clear_session()

    for metric_name, values in (('dice', all_dice), ('hd', all_hd), ('hd95', all_hd95)):
        print(metric_name)
        for value in values:
            print(value)
        print()

    print('Final results for %s'%model_type)
    print('dice %f'%np.mean(all_dice))
    print('hd %f'%np.mean(all_hd))
    print('hd95 %f'%np.mean(all_hd95))


Working with unet_resnet_upconv
loading weights from ./logs/unet_resnet_upconv/kfold_unet_resnet_upconv/kfold_unet_resnet_upconv_dice_DA_K0/kfold_unet_resnet_upconv_dice_DA_K0_weights.h5


100%|██████████| 30/30 [01:44<00:00,  3.16s/it]


Finished K0
Working with unet_resnet_upconv
loading weights from ./logs/unet_resnet_upconv/kfold_unet_resnet_upconv/kfold_unet_resnet_upconv_dice_DA_K1/kfold_unet_resnet_upconv_dice_DA_K1_weights.h5


100%|██████████| 29/29 [01:49<00:00,  3.64s/it]


Finished K1
Working with unet_resnet_upconv
loading weights from ./logs/unet_resnet_upconv/kfold_unet_resnet_upconv/kfold_unet_resnet_upconv_dice_DA_K2/kfold_unet_resnet_upconv_dice_DA_K2_weights.h5


100%|██████████| 29/29 [01:40<00:00,  3.10s/it]


Finished K2
Working with unet_resnet_upconv
loading weights from ./logs/unet_resnet_upconv/kfold_unet_resnet_upconv/kfold_unet_resnet_upconv_dice_DA_K3/kfold_unet_resnet_upconv_dice_DA_K3_weights.h5


100%|██████████| 29/29 [01:41<00:00,  3.24s/it]


Finished K3
Working with unet_resnet_upconv
loading weights from ./logs/unet_resnet_upconv/kfold_unet_resnet_upconv/kfold_unet_resnet_upconv_dice_DA_K4/kfold_unet_resnet_upconv_dice_DA_K4_weights.h5


100%|██████████| 29/29 [01:45<00:00,  3.40s/it]


Finished K4
Working with unet_resnet_upconv
loading weights from ./logs/unet_resnet_upconv/kfold_unet_resnet_upconv/kfold_unet_resnet_upconv_dice_DA_K5/kfold_unet_resnet_upconv_dice_DA_K5_weights.h5


100%|██████████| 29/29 [01:35<00:00,  3.45s/it]


Finished K5
Working with unet_resnet_upconv
loading weights from ./logs/unet_resnet_upconv/kfold_unet_resnet_upconv/kfold_unet_resnet_upconv_dice_DA_K6/kfold_unet_resnet_upconv_dice_DA_K6_weights.h5


100%|██████████| 29/29 [01:35<00:00,  3.24s/it]


Finished K6
Working with unet_resnet_upconv
loading weights from ./logs/unet_resnet_upconv/kfold_unet_resnet_upconv/kfold_unet_resnet_upconv_dice_DA_K7/kfold_unet_resnet_upconv_dice_DA_K7_weights.h5


100%|██████████| 29/29 [01:47<00:00,  3.55s/it]


Finished K7
Working with unet_resnet_upconv
loading weights from ./logs/unet_resnet_upconv/kfold_unet_resnet_upconv/kfold_unet_resnet_upconv_dice_DA_K8/kfold_unet_resnet_upconv_dice_DA_K8_weights.h5


100%|██████████| 29/29 [01:54<00:00,  4.59s/it]


Finished K8
Working with unet_resnet_upconv
loading weights from ./logs/unet_resnet_upconv/kfold_unet_resnet_upconv/kfold_unet_resnet_upconv_dice_DA_K9/kfold_unet_resnet_upconv_dice_DA_K9_weights.h5


100%|██████████| 29/29 [01:41<00:00,  3.10s/it]


Finished K9
dice
0.9478326551497283
0.945354116763791
0.9067519365119567
0.8874737779105741
0.9415164850516904
0.9741883420193758
0.9652293761706371
0.9335139781879261
0.964328265496278
0.9499000559686576
0.9661800221249279
0.9499613767058621
0.9492382155885649
0.9707188191218535
0.9622806627211915
0.9436770831455
0.9256578976170775
0.9437265527303043
0.955413298967628
0.9601442094434536
0.5868848644665917
0.9719744757766356
0.971868340908883
0.9662115260718664
0.9427340294883898
0.8690021252606688
0.9481557277201189
0.9325912956478238
0.9586915945044521
0.8889945626669836
0.9463348838131888
0.9494010323669771
0.9313404558916987
0.9627200438716753
0.9533234157413655
0.9641454215377874
0.9308375856794217
0.9640489316896576
0.374798669650888
0.9476035018266779
0.9560212609266434
0.9588184075590958
0.9418285802053641
0.9425644932596873
0.9678841641148678
0.9141542052987426
0.9549898350714321
0.9667414559231475
0.958178321378273
0.9587022378579108
0.9664908870695863
0.9550192023986573
0.94

loading weights from ./logs/unet_resnet_upconv_se/kfold_unet_resnet_upconv_se/kfold_unet_resnet_upconv_se_dice_DA_K0/kfold_unet_resnet_upconv_se_dice_DA_K0_weights.h5


100%|██████████| 30/30 [01:45<00:00,  3.29s/it]


Finished K0
Working with unet_resnet_upconv_se
loading weights from ./logs/unet_resnet_upconv_se/kfold_unet_resnet_upconv_se/kfold_unet_resnet_upconv_se_dice_DA_K1/kfold_unet_resnet_upconv_se_dice_DA_K1_weights.h5


100%|██████████| 29/29 [01:50<00:00,  3.56s/it]


Finished K1
Working with unet_resnet_upconv_se
loading weights from ./logs/unet_resnet_upconv_se/kfold_unet_resnet_upconv_se/kfold_unet_resnet_upconv_se_dice_DA_K2/kfold_unet_resnet_upconv_se_dice_DA_K2_weights.h5


100%|██████████| 29/29 [01:41<00:00,  3.08s/it]


Finished K2
Working with unet_resnet_upconv_se
loading weights from ./logs/unet_resnet_upconv_se/kfold_unet_resnet_upconv_se/kfold_unet_resnet_upconv_se_dice_DA_K3/kfold_unet_resnet_upconv_se_dice_DA_K3_weights.h5


100%|██████████| 29/29 [01:44<00:00,  3.26s/it]


Finished K3
Working with unet_resnet_upconv_se
loading weights from ./logs/unet_resnet_upconv_se/kfold_unet_resnet_upconv_se/kfold_unet_resnet_upconv_se_dice_DA_K4/kfold_unet_resnet_upconv_se_dice_DA_K4_weights.h5


100%|██████████| 29/29 [01:49<00:00,  3.51s/it]


Finished K4
Working with unet_resnet_upconv_se
loading weights from ./logs/unet_resnet_upconv_se/kfold_unet_resnet_upconv_se/kfold_unet_resnet_upconv_se_dice_DA_K5/kfold_unet_resnet_upconv_se_dice_DA_K5_weights.h5


100%|██████████| 29/29 [01:41<00:00,  3.63s/it]


Finished K5
Working with unet_resnet_upconv_se
loading weights from ./logs/unet_resnet_upconv_se/kfold_unet_resnet_upconv_se/kfold_unet_resnet_upconv_se_dice_DA_K6/kfold_unet_resnet_upconv_se_dice_DA_K6_weights.h5


100%|██████████| 29/29 [01:40<00:00,  3.39s/it]


Finished K6
Working with unet_resnet_upconv_se
loading weights from ./logs/unet_resnet_upconv_se/kfold_unet_resnet_upconv_se/kfold_unet_resnet_upconv_se_dice_DA_K7/kfold_unet_resnet_upconv_se_dice_DA_K7_weights.h5


100%|██████████| 29/29 [01:54<00:00,  3.68s/it]


Finished K7
Working with unet_resnet_upconv_se
loading weights from ./logs/unet_resnet_upconv_se/kfold_unet_resnet_upconv_se/kfold_unet_resnet_upconv_se_dice_DA_K8/kfold_unet_resnet_upconv_se_dice_DA_K8_weights.h5


100%|██████████| 29/29 [01:59<00:00,  4.83s/it]


Finished K8
Working with unet_resnet_upconv_se
loading weights from ./logs/unet_resnet_upconv_se/kfold_unet_resnet_upconv_se/kfold_unet_resnet_upconv_se_dice_DA_K9/kfold_unet_resnet_upconv_se_dice_DA_K9_weights.h5


100%|██████████| 29/29 [01:45<00:00,  3.23s/it]

Finished K9
dice
0.9492059437135437
0.9486701221371308
0.9105819770755493
0.9087238483327537
0.9470540272263188
0.9664292899135586
0.9672575173265089
0.9382936333303996
0.9674374560894401
0.9626367031565635
0.9585846674363613
0.9420977807763667
0.9482957025514861
0.9693861554654782
0.9658771735655082
0.9435571581365333
0.9456179505742351
0.9294204874315372
0.658466018330129
0.9625963732662202
0.6166037735849057
0.9708138519651581
0.973203619930252
0.9668629545887477
0.947348133342851
0.860817958582716
0.9549015301266565
0.9551892755178175
0.9358310518589058
0.9155095858696967
0.9073666241662365
0.9566176422028647
0.9222553540891103
0.9551963649574086
0.9401930423026741
0.9671035940803381
0.9430113833594675
0.952261171802259
0.23779434494044915
0.937324617738254
0.9564038669145043
0.949413982460454
0.9410949536485623
0.9231928990062007
0.964115161847079
0.9058328920207075
0.9424755087356789
0.9491512104928642
0.9589301306698939
0.9552595243188581
0.9584993984827033
0.9533971217795668
0.


