In [1]:
import os
import cv2
import glob
import numpy as np
import skimage.io as io
from tqdm import tqdm
from medpy.io import save
from matplotlib import pyplot as plt
from tta_wrapper import tta_segmentation
from sklearn.metrics import f1_score, jaccard_similarity_score

import tensorflow as tf
from keras import backend as K
from keras.models import *
from keras.preprocessing.image import ImageDataGenerator

from models.unet import *
from datahandler import DataHandler

%matplotlib inline

Using TensorFlow backend.


### 1. Load Model and Images

In [2]:
model = getUnet()
model.load_weights('logs/unet/unet_dice_nobells/unet_dice_nobells_weights.h5')

params = dict(
        h_flip=True,
        v_flip=True,
        rotation=(90, 180, 270),
        merge='mean')

tta_model = tta_segmentation(model, **params)

dh = DataHandler()

def resetSeed():
    np.random.seed(1)

def getGenerator(images, bs=32):
    resetSeed()

    image_datagen = ImageDataGenerator(rescale=1./255)
    image_datagen.fit(images, augment = True)
    image_generator = image_datagen.flow(x = images, batch_size=bs,
            shuffle = False)

    return image_generator

In [28]:
def ensemble_prediction(image):
    """Test time augmentation method using non-maximum supression"""

    masks = []    

    results = {}

    result = model.predict(image)
    print(result.shape)
    
    if result.shape[2] == 0:
        return result
    
    masks.append(result)

    temp_img = np.fliplr(image)
    result = model.predict(temp_img)
    mask = np.fliplr(result)
    masks.append(mask)

    temp_img = np.flipud(image)
    result = model.predict(temp_img)
    mask = np.flipud(result)
    masks.append(mask)
    
    temp_img = np.roll(image, 10, axis=0)
    result = model.predict(temp_img)
    mask = np.roll(result, -10, axis=0)
    masks.append(mask)
    
    temp_img = np.roll(image, 10, axis=1)
    result = model.predict(temp_img)
    mask = np.roll(result, -10, axis=1)
    masks.append(mask)

    angles = [90,180,270]
    for angle in angles:
        temp_img = np.rot90(image, k=angle, axes=(0, 1))
        result = model.predict(temp_img)
        mask = np.rot90(result, k=-angle, axes=(0, 1))
        masks.append(mask)

    masks = np.concatenate(masks, axis=-1)
    print(masks.shape)

    masks = np.sum(masks, axis=-1)
    masks[masks<6] = 0
    masks[masks>=6] = 1

    results['masks'] = mask

    return results

### 2. Model Evaluation

In [33]:
def destiny_directory(data_origin, dice_score):
    pre = './data/eval/unet_'+data_origin+'/'
    if dice_score >= 98:
        return pre + 'dice_98_100/'
    elif dice_score >= 96:
        return pre + 'dice_96_98/'
    elif dice_score >= 94:
        return pre + 'dice_94_96/'
    elif dice_score >= 92:
        return pre + 'dice_92_94/'
    elif dice_score >= 90:
        return pre + 'dice_90_92/'
    elif dice_score >= 88:
        return pre + 'dice_88_90/'
    elif dice_score >= 85:
        return pre + 'dice_85_88'
    elif dice_score >= 80:
        return pre + 'dice_80_85/'
    elif dice_score >= 70:
        return pre + 'dice_70_80/'
    elif dice_score >= 60:
        return pre + 'dice_60_70/'
    else:
        return pre + 'dice_less_60'
    
def getFileName(fname):
    original_name = fname.split('/')[-1]
    original_name = original_name[:original_name.index('.')]
    return original_name

def predictMask(image, tta=None, my_tta=False):
    if my_tta:
        results = []
        for i in range(image.shape[0]):
            img = image[i,:,:,:]
            results = ensemble_prediction(img)
            results[results>=0.5] = 1
            results[results<0.5] = 0
        return results
    
    if tta:
        images = []
        for im in image:
            images.append(im)
        
        images = np.asarray(images)
        test_gen = getGenerator(images, bs=1)
        results = tta.predict_generator(test_gen, len(images), verbose=0)
        results[results>=0.5] = 1
        results[results<0.5] = 0
        return results
    
    return model.predict(image)

def evaluateMask(ground_truth, prediction):
    #convert to boolean values and flatten
    ground_truth = np.asarray(ground_truth, dtype=np.bool).flatten()
    prediction = np.asarray(prediction, dtype=np.bool).flatten()    
    return f1_score(ground_truth, prediction)

def saveAll(data_origin, fname, hdr, image, gt_mask, pred_mask, score):
    fname = getFileName(fname)
    dice_score = int(score * 100)
    
    save_path = destiny_directory(data_origin, dice_score)
        
    save(pred_mask, os.path.join(save_path, fname + '_pred_' 
        + str(dice_score) + '.nii'), hdr)
    save(image, os.path.join(save_path, fname + '_img.nii'), hdr)
    save(gt_mask, os.path.join(save_path, fname + '_mask.nii'), hdr)
    

### 2.1 Refinement

In [5]:
def chooseComponent(image):
    image = image.astype('uint8')
    new_image = np.zeros(image.shape)
    for i in range(image.shape[0]):
        image_slice = image[i,:,:,:]
        nb_components, output, stats, centroids = cv2.connectedComponentsWithStats(
        image_slice, connectivity=4)
        sizes = stats[:, -1]
 
        max_label = 1
        if len(sizes) < 3:
            return image
        
        max_size = sizes[1]
        
        for j in range(2, nb_components):
            if sizes[j] > max_size:
                max_label = j
                max_size = sizes[j]
 
        new_slice = np.zeros(output.shape)
        new_slice[output == max_label] = 1
        new_slice = new_slice[..., np.newaxis]
        
        new_image[i,:,:,:] = new_slice
    
    return new_image

### 3. Predict masks and evaluate

In [23]:
def predictAll(data_origin, dict_prefix, keep_main_only=False, tta=None, my_tta=False):
    print(my_tta)
    images_dir = './data/'+data_origin+'/images/*'
    masks_dir = './data/'+data_origin+'/masks/*'
    
    data_origin+=dict_prefix
    
    glob_images = sorted(glob.glob(images_dir))
    glob_masks = sorted(glob.glob(masks_dir))
    
    inputs = zip(glob_images, glob_masks)
    dice_scores = []

    for image_file, mask_file in tqdm(inputs, total=len(glob_images)):
        image, hdr = dh.getImageData(image_file)
        gt_mask, _ = dh.getImageData(mask_file, is_mask=True)

        assert image.shape == gt_mask.shape
        if image.shape[1] != 256:
            continue

        pred_mask = predictMask(image, tta, my_tta)
        
        if keep_main_only:
            pred_mask = chooseComponent(pred_mask)
            
        score = evaluateMask(gt_mask, pred_mask)
        dice_scores.append(score)
            
        saveAll(data_origin, image_file, hdr, image, 
                gt_mask, pred_mask, score)

    print('Number of images %d'%len(dice_scores))
    print(np.mean(dice_scores))

### Vanilla Run

In [7]:
data_origin = 'test'
dict_prefix = ''
predictAll(data_origin, dict_prefix)

100%|██████████| 43/43 [01:55<00:00,  2.51s/it]

Number of images 43
0.8337489323247731





### Keeping largest component

In [8]:
data_origin = 'test'
dict_prefix = 'CC'
predictAll(data_origin, dict_prefix, keep_main_only=True)

100%|██████████| 43/43 [01:45<00:00,  2.41s/it]

Number of images 43
0.8562939566958778





### TTA

In [9]:
data_origin = 'test'
dict_prefix = 'TTA'
predictAll(data_origin, dict_prefix, tta=tta_model)

100%|██████████| 43/43 [04:09<00:00,  5.69s/it]

Number of images 43
0.923457878437851





In [34]:
data_origin = 'test'
dict_prefix = 'TTA_MINE'
predictAll(data_origin, dict_prefix, my_tta=True)






  0%|          | 0/43 [00:00<?, ?it/s][A[A[A[A[A

True


ValueError: Error when checking input: expected input_1 to have 4 dimensions, but got array with shape (256, 256, 1)

### TTA and CC

In [None]:
data_origin = 'test'
dict_prefix = 'TTACC'
predictAll(data_origin, dict_prefix, keep_main_only=True, tta=tta_model)