In [1]:
import os
import cv2
import glob
import numpy as np
import skimage.io as io
from tqdm import tqdm
from medpy.io import save
from matplotlib import pyplot as plt
from tta_wrapper import tta_segmentation
from sklearn.metrics import f1_score, jaccard_similarity_score

import tensorflow as tf
from keras import backend as K
from keras.models import *
from keras.preprocessing.image import ImageDataGenerator

from models.unet import *
from datahandler import DataHandler

%matplotlib inline

Using TensorFlow backend.


### 1. Load Model and Images

In [2]:
model = getUnet()
model.load_weights('logs/unet/unet_dice_nobells/unet_dice_nobells_weights.h5')

params = dict(
        h_flip=True,
        v_flip=True,
        h_shift=None,
        v_shift=None,
        rotation=(90,180,270),
        contrast=None,
        add=None,
        mul=None,
        merge='mean')

tta_model = tta_segmentation(model, **params)

dh = DataHandler()

def resetSeed():
    np.random.seed(1)


In [3]:
def getGenerator(images, bs=1):
    resetSeed()

    image_datagen = ImageDataGenerator(rescale=1./255)
    image_datagen.fit(images, augment = True)
    image_generator = image_datagen.flow(x = images, batch_size=bs,
            shuffle = False)

    return image_generator

### 2. Model Evaluation

In [4]:
def destiny_directory(data_origin, dice_score):
    pre = './data/eval/unet_'+data_origin+'/'
    if dice_score >= 98:
        return pre + 'dice_98_100/'
    elif dice_score >= 96:
        return pre + 'dice_96_98/'
    elif dice_score >= 94:
        return pre + 'dice_94_96/'
    elif dice_score >= 92:
        return pre + 'dice_92_94/'
    elif dice_score >= 90:
        return pre + 'dice_90_92/'
    elif dice_score >= 88:
        return pre + 'dice_88_90/'
    elif dice_score >= 85:
        return pre + 'dice_85_88'
    elif dice_score >= 80:
        return pre + 'dice_80_85/'
    elif dice_score >= 70:
        return pre + 'dice_70_80/'
    elif dice_score >= 60:
        return pre + 'dice_60_70/'
    else:
        return pre + 'dice_less_60'
    
def getFileName(fname):
    original_name = fname.split('/')[-1]
    original_name = original_name[:original_name.index('.')]
    return original_name

def evaluateMask(ground_truth, prediction):
    #convert to boolean values and flatten
    ground_truth = np.asarray(ground_truth, dtype=np.bool).flatten()
    prediction = np.asarray(prediction, dtype=np.bool).flatten()    
    return f1_score(ground_truth, prediction)

def saveAll(data_origin, fname, hdr, image, gt_mask, pred_mask, score):
    fname = getFileName(fname)
    dice_score = int(score * 100)
    
    save_path = destiny_directory(data_origin, dice_score)
        
    save(pred_mask, os.path.join(save_path, fname + '_pred_' 
        + str(dice_score) + '.nii'), hdr)
    save(image, os.path.join(save_path, fname + '_img.nii'), hdr)
    save(gt_mask, os.path.join(save_path, fname + '_mask.nii'), hdr)
    

### 2.1 Refinement

In [5]:
def chooseComponent(image):
    image = image.astype('uint8')
    new_image = np.zeros(image.shape)
    for i in range(image.shape[0]):
        image_slice = image[i,:,:,:]
        nb_components, output, stats, centroids = cv2.connectedComponentsWithStats(
        image_slice, connectivity=4)
        sizes = stats[:, -1]
 
        max_label = 1
        if len(sizes) < 3:
            return image
        
        max_size = sizes[1]
        
        for j in range(2, nb_components):
            if sizes[j] > max_size:
                max_label = j
                max_size = sizes[j]
 
        new_slice = np.zeros(output.shape)
        new_slice[output == max_label] = 1
        new_slice = new_slice[..., np.newaxis]
        
        new_image[i,:,:,:] = new_slice
    
    return new_image

In [6]:
import itertools

def vflip(image):
    return np.flipud(image)

def hflip(image):
    return np.fliplr(image)

def rotate(image, k):
    return np.rot90(image, k, axes=(1,2))


def ensemble_prediction(image, image_gen):
    """Test time augmentation method using non-maximum supression"""
    prediction = np.zeros(image.shape)
    slice_count = len(image_gen)
    for i in range(slice_count):
        img_slice = image_gen[i]
        masks = []
    
        result = model.predict(img_slice)
        masks.append(np.squeeze(result, axis=0))
    
        flip_v = [True, False]
        flip_h = [True, False]
        rotations = [0,1,2,3]
        
        transformations = list(itertools.product(flip_v, flip_h, rotations))
        
        for fv, fh, r in transformations:
            result = img_slice
            result = vflip(result) if fv else result
            result = hflip(result) if fh else result
            result = rotate(result, r)
            result = model.predict(result)
            result = rotate(result, -r)
            result = hflip(result) if fh else result
            result = vflip(result) if fv else result
            masks.append(np.squeeze(result, axis=0))

        masks = np.concatenate(masks, axis=-1)
        masks = np.mean(masks, axis=-1, keepdims=True)
        
        prediction[i] = masks
    return prediction

### 3. Predict masks and evaluate

In [7]:
def predictMask(image, tta=None, my_tta=False):
    image_gen = getGenerator(image)
    
    if my_tta:
        results = ensemble_prediction(image, image_gen)
        return results
    
    if tta:
        return tta.predict_generator(image_gen)    
    
    return model.predict_generator(image_gen, steps=len(image))

def predictAll(data_origin, dict_prefix, keep_main_only=False, tta=None, my_tta=False):
    images_dir = './data/'+data_origin+'/images/*'
    masks_dir = './data/'+data_origin+'/masks/*'
    
    data_origin+=dict_prefix
    
    glob_images = sorted(glob.glob(images_dir))
    glob_masks = sorted(glob.glob(masks_dir))
    
    inputs = zip(glob_images, glob_masks)
    dice_scores = []

    for image_file, mask_file in tqdm(inputs, total=len(glob_images)):
        image, hdr = dh.getImageData(image_file)
        gt_mask, _ = dh.getImageData(mask_file, is_mask=True)

        assert image.shape == gt_mask.shape
        if image.shape[1] != 256:
            continue

        pred_mask = predictMask(image, tta, my_tta)
        pred_mask[pred_mask>=0.7] = 1
        pred_mask[pred_mask<0.7] = 0
        
        if keep_main_only:
            pred_mask = chooseComponent(pred_mask)
            
        score = evaluateMask(gt_mask, pred_mask)
        dice_scores.append(score)
            
        saveAll(data_origin, image_file, hdr, image, 
                gt_mask, pred_mask, score)

    print('Number of images %d'%len(dice_scores))
    print(np.mean(dice_scores))

### Vanilla Run

In [8]:
data_origin = 'test'
dict_prefix = ''
predictAll(data_origin, dict_prefix)

100%|██████████| 43/43 [01:57<00:00,  2.69s/it]

Number of images 43
0.9277734986264862





### Keeping largest component

In [9]:
data_origin = 'test'
dict_prefix = 'CC'
predictAll(data_origin, dict_prefix, keep_main_only=True)

100%|██████████| 43/43 [01:55<00:00,  2.66s/it]

Number of images 43
0.9277734986264862





### TTA

In [10]:
data_origin = 'test'
dict_prefix = 'TTA'
predictAll(data_origin, dict_prefix, tta=tta_model)

100%|██████████| 43/43 [04:11<00:00,  5.72s/it]

Number of images 43
0.9148326878501317





In [11]:
data_origin = 'test'
dict_prefix = 'TTA_MINE'
predictAll(data_origin, dict_prefix, my_tta=True)

100%|██████████| 43/43 [05:32<00:00,  7.52s/it]

Number of images 43
0.9150217444187126





### TTA and CC

In [12]:
data_origin = 'test'
dict_prefix = 'TTACC'
predictAll(data_origin, dict_prefix, keep_main_only=True, tta=tta_model)

100%|██████████| 43/43 [04:09<00:00,  5.69s/it]

Number of images 43
0.9148326878501317



