In [7]:
import os
import numpy as np

from keras.models import load_model
from keras import backend as K

from network import extractor, generator, model, utils

In [8]:
# Per-class weights, indexed by the class channel on the last axis of y_true / y_pred.
# weighted_dice_coefficient scales channel i's Dice term by (1 - class_weights[i]),
# so the class with the larger entry here contributes less to the score.
# NOTE(review): presumably these are class frequencies (background, lesion) measured
# on the training masks -- confirm where the values come from.
class_weights = [0.9939077556150256, 0.006092244384974352]

def dice_coef(y_true, y_pred, smooth=1):
    """Smoothed Dice coefficient, reduced over the last axis.

    Args:
        y_true: Ground-truth tensor.
        y_pred: Predicted tensor, multiplied element-wise with ``y_true``.
        smooth: Constant added to both numerator and denominator.

    Returns:
        Tensor of Dice values with the last axis summed out.
    """
    overlap = K.sum(K.abs(y_true * y_pred), axis=-1)
    true_sq_sum = K.sum(K.square(y_true), -1)
    pred_sq_sum = K.sum(K.square(y_pred), -1)

    numerator = 2. * overlap + smooth
    denominator = true_sq_sum + pred_sq_sum + smooth
    return numerator / denominator


def weighted_dice_coefficient(y_true, y_pred):
    """Class-weighted sum of per-channel smoothed Dice scores.

    Both tensors are indexed as ``[:, :, :, i]``, i.e. the class channel is the
    fourth axis. For every entry of the module-level ``class_weights`` the
    matching channel is flattened, its Dice score (smoothing constant 1) is
    computed, and the score is scaled by ``1 - class_weights[i]``.

    Args:
        y_true: Ground-truth tensor with one channel per class.
        y_pred: Predicted tensor with the same layout.

    Returns:
        The sum of the weighted per-channel Dice scores.
    """
    weighted_total = 0.0
    for channel, class_weight in enumerate(class_weights):
        # Each channel is scaled by the complement of its class_weights entry.
        channel_weight = 1 - class_weight

        true_flat = K.flatten(y_true[:, :, :, channel])
        pred_flat = K.flatten(y_pred[:, :, :, channel])

        overlap = K.sum(true_flat * pred_flat)
        numerator = 2 * overlap + 1
        denominator = K.sum(true_flat) + K.sum(pred_flat) + 1
        weighted_total += (numerator / denominator) * channel_weight

    return weighted_total


def weighted_dice_coefficient_loss(y_true, y_pred):
    """Loss form of the weighted Dice score.

    Returns the negated ``weighted_dice_coefficient`` so that minimising the
    loss maximises the score.
    """
    score = weighted_dice_coefficient(y_true, y_pred)
    return -score

In [15]:
# Load the saved network from disk. The Dice metric/loss functions defined in this
# notebook are handed to Keras by name through `custom_objects`.
# NOTE(review): this assignment rebinds `model`, hiding the `model` module imported
# from `network` in the imports cell for the rest of the notebook.
model = load_model('../models/wmh.h5', custom_objects={
    'dice_coef': dice_coef,
    'weighted_dice_coefficient': weighted_dice_coefficient,
    'weighted_dice_coefficient_loss': weighted_dice_coefficient_loss
})

In [34]:
# create_paths returns six values; only the last two are kept here.
# NOTE(review): presumably these are the test-split input and label paths and the
# discarded four are the train/validation ones -- confirm in network.utils.
_, _, _, _, x_test, y_test = utils.create_paths(
  'whi_mat_hyp', '2d', 'small', 'fl'
)

# NOTE(review): 64 is presumably the batch size, and shuffle=False presumably keeps
# samples in a fixed order so predictions line up with `y` -- confirm against
# generator.DataSequence.
test_generator = generator.DataSequence(
  x_test, y_test, '2d', 64, shuffle=False
)

# Pull the generator's contents into `x` (inputs) and `y` (targets).
x, y = utils.extract_generator(test_generator)

In [42]:
# Run the loaded model on the extracted inputs, converted to a single NumPy array.
y_prediction = model.predict(np.array(x))

In [49]:
def decode_prediction(prediction):
    """Convert per-class scores into a float32 label mask.

    Args:
        prediction: NumPy array whose last axis holds one score per class.

    Returns:
        float32 array (or scalar for 1-D input) with the last axis removed;
        each element is the index of the highest-scoring class, e.g. 0. or 1.
        for a two-class prediction.
    """
    # argmax already yields the class index, so no remapping of 0/1 is needed;
    # the former masked assignments were no-ops and failed on scalar results.
    return prediction.argmax(axis=-1).astype(np.float32)

# Decode every sample's class scores into a label mask and stack the results.
# The loop variable is deliberately not called `y`, which names the ground truth below.
y_pred = np.array([decode_prediction(sample_scores) for sample_scores in y_prediction])

# Squeeze inputs, targets and predictions together before scoring.
x, y, y_pred = utils.squeeze_all(x, y, y_pred)

# Confusion-matrix counts of the predictions against the ground truth.
tp, fn, fp, tn = utils.calc_conf_matrix(y, y_pred)

# Precision, recall and F1 of the predictions against the ground truth.
prec, rec, f1 = utils.calc_metrics(y, y_pred)

print('prec: ', round(prec, 4), ', rec: ', round(rec, 4), ', f1: ', round(f1, 4))

prec:  0.757 , rec:  0.3372 , f1:  0.4665
