In [1]:
from collections import namedtuple
import numpy as np
from sklearn.metrics import jaccard_similarity_score
from sklearn.metrics import precision_recall_fscore_support
from sklearn.metrics import log_loss, brier_score_loss

import tensorflow as tf
# import keras.backend as K
from keras import backend as K
from keras import objectives
from keras.losses import binary_crossentropy

_EPSILON = K.epsilon()
EPS = np.finfo(np.float32).eps
print(_EPSILON)
print(EPS)

1e-07
1.1920929e-07


Using TensorFlow backend.


$$E = \sum_{\pmb{x}\in\Omega} w(\pmb{x}) log(p_{l(\pmb{x})}(\pmb{x}))$$

$$w(\pmb{x}) = w_c(\pmb{x}) + w_0 \mathrm{exp}\left(-\frac{(d_1(\pmb{x}) + d_2(\pmb{x}))^2)}{2\sigma^2}\right)$$

In [None]:
def get_border_weights(border_mask, w0=10, sigma=5):
    g = np.exp(-)
    return w0 * g


1. Recall score
2. IoU score
3. Jaccard index
4. F$_\beta$ score
5. Dice score
6. Cos Dice Loss

In [2]:
def perf_measure0(y_true, y_pred, axis=None):
    yt = y_true.astype(bool)
    yp = y_pred.astype(bool)
    TP = np.sum(yt & yp, axis=axis)
    TN = np.sum(~yt & ~yp, axis=axis)
    FP = np.sum(~yt & yp, axis=axis)
    FN = np.sum(yt & ~yp, axis=axis)
    return TP, FP, TN, FN

def perf_measure(y_true, y_pred, axis=None):
    TP = np.sum(y_true * y_pred, axis=axis)
    TN = np.sum((1 - y_true) * (1 - y_pred), axis=axis)
    FP = np.sum((1 - y_true) * y_pred, axis=axis)
    FN = np.sum(y_true * (1 - y_pred), axis=axis)
    return TP, FP, TN, FN
    
def confusion_matrix_variants(TP, FP, TN, FN):
    TPR = TP / (TP + FN)  # Sensitivity, hit rate, recall, or true positive rate
    TNR = TN / (TN + FP)  # Specificity or true negative rate
    PPV = TP / (TP + FP)  # Precision or positive predictive value
    NPV = TN / (TN + FN)  # Negative predictive value
    FPR = FP / (FP + TN)  # Fall out or false positive rate
    FNR = FN / (TP + FN)  # False negative rate
    FDR = FP / (TP + FP)  # False discovery rate
    ACC = (TP + TN) / (TP + FP + FN + TN)  # Overall accuracy

    #     print(f'{TP:>5}, {TN:>6}, {FP:>5}, {FN:>5}, {TPR:>.5f}, {TNR:>.5f}, {PPV:>.5f}, {NPV:>.5f}, {FPR:>.5f}, {FNR:>.5f}, {FDR:>.5f}, {ACC:>.5f}')
    return TPR, TNR, PPV, NPV, FPR, FNR, FDR, ACC

def recall(y_true, y_pred, axis=None, smooth=1e-3):
    return (np.sum(y_true * y_pred, axis=axis) + smooth) / (np.sum(y_true, axis=axis) + smooth)

def fbeta(y_true, y_pred, beta=2, axis=None, smooth=1e-3):
    tp, fp, tn, fn = perf_measure(y_true, y_pred, axis=axis)
    return ((beta**2 + 1) * tp + smooth) / ((beta**2 + 1) * tp + beta**2 * fn + fp + smooth)

def iou(y_true, y_pred, axis=None):
    i = np.sum((y_true * y_pred) > 0.5, axis=axis)
    u = np.sum((y_true + y_pred) > 0.5, axis=axis) + EPS  # avoid division by zero
    return i / u

def dice_score(y_true, y_pred, axis=None, smooth=1e-3):
    AB = np.sum(y_true * y_pred, axis=axis)
    A = np.sum(y_true, axis=axis)
    B = np.sum(y_pred, axis=axis)
    return (2 * AB + smooth) / (A + B + smooth)

def brier_loss(y_true, y_pred, **kwargs):
    return np.sum((y_pred - y_true)**2, axis=-1)

def print_headers(metrics):
    res = [m.header for m in metrics]
    fmts = ' '.join([m.hfmt for m in metrics])
    print(fmts.format(*res))
    
def print_metrics(y_true, y_pred, metrics):
    res = [m.func(y_true, y_pred, **m.kwargs) for m in metrics]
    fmts = ' '.join([m.fmt for m in metrics])
    print(fmts.format(*res))

def print_losses(y_true, y_pred, losses):
    res = [np.mean(m.func(y_true, y_pred, **m.kwargs)) for m in losses]
    fmts = ' '.join([m.fmt for m in losses])
    print(fmts.format(*res))

def enlarge(arr1, arr2):
    top0 = np.hstack([arr1, arr2])
    bot0 = np.hstack([arr2, arr2])
    return np.vstack([top0, bot0])

In [2]:
def bce_tensor(y_true, y_pred):
    y_pred = K.clip(y_pred, _EPSILON, 1.0 - _EPSILON)
    out = -(y_true * K.log(y_pred) + (1.0 - y_true) * K.log(1.0 - y_pred))
    return K.mean(out, axis=-1)

def bce_np(y_true, y_pred, **kwargs):
    y_pred = np.clip(y_pred, _EPSILON, 1.0 - _EPSILON)
    out = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return np.mean(out, axis=-1)

In [90]:
def check_loss(_shape):
    if _shape == '2d':
        shape = (6, 7)
    elif _shape == '3d':
        shape = (5, 6, 7)
    elif _shape == '4d':
        shape = (8, 5, 6, 7)
    elif _shape == '5d':
        shape = (9, 8, 5, 6, 7)

    y_a = np.random.random(shape)
    y_b = np.random.random(shape)

    out1 = K.eval(binary_crossentropy(K.variable(y_a), K.variable(y_b)))
    out2 = K.eval(bce_tensor(K.variable(y_a), K.variable(y_b)))
    out3 = bce_np(y_a, y_b)
    
    assert out1.shape == out2.shape
    assert out1.shape == out3.shape
    assert out1.shape == shape[:-1]
    print(out1.shape)
    print(np.linalg.norm(out1))
    print(np.linalg.norm(out2))
    print(np.linalg.norm(out3))
    print(np.linalg.norm(out1-out2))

def test_loss():
    shape_list = ['2d', '3d', '4d', '5d']
    for _shape in shape_list:
        check_loss(_shape)
        print('======================')

In [None]:
test_loss()

In [None]:
# We can recover the bce by excluding the weighting factor alpha 
# and taking the mean of the focal_loss along the -1 axis.
# Assuming that the last axis is the label axis as is the case in Tensorflow.
def focal_loss_to_bce_wrapper(gamma=0, alpha=None, axis=-1):
    def focal_loss_to_bce(y_true, y_pred):
        y_pred_c = np.clip(y_pred, EPS, 1. - EPS)
        pt_1 = np.where(np.equal(y_true, 1), y_pred_c, np.ones_like(y_pred))
        pt_0 = np.where(np.equal(y_true, 0), y_pred_c, np.zeros_like(y_pred))
        res1 = np.power(1. - pt_1, gamma) * np.log(     pt_1)
        res0 = np.power(     pt_0, gamma) * np.log(1. - pt_0)
        return -np.mean(res1 + res0, axis=axis)
    return focal_loss_to_bce

def focal_loss_wrapper_k1(gamma=2., alpha=.25, axis=-1):
    def focal_loss_k1(y_true, y_pred):
        y_pred_c = K.clip(y_pred, _EPSILON, 1.0 - _EPSILON)
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred_c, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred_c, tf.zeros_like(y_pred))
        res1 =      alpha  * K.pow(1. - pt_1, gamma) * K.log(     pt_1)
        res0 = (1 - alpha) * K.pow(     pt_0, gamma) * K.log(1. - pt_0)
        return -K.mean(res1 + res0, axis=axis)
    return focal_loss_k1

def focal_loss_wrapper_n1(gamma=2., alpha=.25, axis=-1):
    def focal_loss_n1(y_true, y_pred):
        y_pred_c = np.clip(y_pred, EPS, 1. - EPS)
        pt_1 = np.where(np.equal(y_true, 1), y_pred_c, np.ones_like(y_pred))
        pt_0 = np.where(np.equal(y_true, 0), y_pred_c, np.zeros_like(y_pred))
        res1 =      alpha  * np.power(1. - pt_1, gamma) * np.log(     pt_1)
        res0 = (1 - alpha) * np.power(     pt_0, gamma) * np.log(1. - pt_0)
        return -np.mean(res1 + res0, axis=axis)
    return focal_loss_n1

def focal_loss_wrapper_ns(gamma=2., alpha=.25, axis=-1):
    def focal_loss_ns(y_true, y_pred):
        y_pred_c = np.clip(y_pred, EPS, 1. - EPS)
        pt_1 = np.where(np.equal(y_true, 1), y_pred_c, np.ones_like(y_pred))
        pt_0 = np.where(np.equal(y_true, 0), y_pred_c, np.zeros_like(y_pred))
        res1 =      alpha  * np.power(1. - pt_1, gamma) * np.log(     pt_1)
        res0 = (1 - alpha) * np.power(     pt_0, gamma) * np.log(1. - pt_0)
        loss = -np.mean(res1 + res0, axis=axis)
        return np.sum(loss)
    return focal_loss_ns

def focal_loss_wrapper_k2(gamma=2., alpha=.25, axis=-1):
    def focal_loss_k2(y_true, y_pred):
        pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        pt_1 = K.clip(pt_1, _EPSILON, 1.0 - _EPSILON)
        pt_0 = K.clip(pt_0, _EPSILON, 1.0 - _EPSILON)
        return -K.mean(alpha * K.pow(1.-pt_1, gamma) * K.log(pt_1) + (1-alpha) * K.pow(pt_0, gamma) * K.log(1.-pt_0), axis=axis)
    return focal_loss_k2

def focal_loss_wrapper_n2(gamma=2., alpha=.25, axis=-1):
    def focal_loss_n2(y_true, y_pred):
        pt_1 = np.where(np.equal(y_true, 1), y_pred, np.ones_like(y_pred))
        pt_0 = np.where(np.equal(y_true, 0), y_pred, np.zeros_like(y_pred))
        pt_1 = np.clip(pt_1, EPS, 1. - EPS)
        pt_0 = np.clip(pt_0, EPS, 1. - EPS)
        return -np.mean(alpha * np.power(1.-pt_1, gamma) * np.log(pt_1) + (1-alpha) * np.power(pt_0, gamma) * np.log(1.-pt_0), axis=axis)
    return focal_loss_n2
    

In [140]:
def check_loss(_shape):
    if _shape == '2d':
        shape = (6, 7)
    elif _shape == '3d':
        shape = (5, 6, 7)
    elif _shape == '4d':
        shape = (8, 5, 6, 7)
    elif _shape == '5d':
        shape = (9, 8, 5, 6, 7)

    y_true = np.random.randint(0, 2, (shape)) * 1.0
    y_pred = np.random.random(shape)

    focal_loss_fc = focal_loss_to_bce_wrapper(gamma=0, alpha=0)
    focal_loss_k1 = focal_loss_wrapper_k1(gamma=0, alpha=0.5)
    focal_loss_k2 = focal_loss_wrapper_k2(gamma=0, alpha=0.5)
    focal_loss_n1 = focal_loss_wrapper_n1(gamma=0, alpha=0.5)
    focal_loss_n2 = focal_loss_wrapper_n2(gamma=0, alpha=0.5)

    out_ce = K.eval(binary_crossentropy(K.variable(y_true), K.variable(y_pred)))
    out_fc = focal_loss_fc(y_true, y_pred)
    out_k1 = K.eval(focal_loss_k1(K.variable(y_true), K.variable(y_pred)))
    out_k2 = K.eval(focal_loss_k2(K.variable(y_true), K.variable(y_pred)))
    out_n1 = focal_loss_n1(y_true, y_pred)
    out_n2 = focal_loss_n2(y_true, y_pred)
    
    print('ce', np.sum(out_ce), out_ce.shape)
    print('fc', np.sum(out_fc), out_fc.shape)
    print('k1', np.linalg.norm(out_k1), out_k1.shape)
    print('k2', np.linalg.norm(out_k2), out_k2.shape)
    print('n1', np.linalg.norm(out_n1), out_n1.shape)
    print('n2', np.linalg.norm(out_n2), out_n2.shape)
    
    assert out_k1.shape == out_ce.shape
    assert out_k1.shape == out_fc.shape
    assert out_k1.shape == out_k2.shape
    assert out_k1.shape == out_n1.shape
    assert out_k1.shape == out_n2.shape
    assert out_k1.shape == shape[:-1]
    print(np.linalg.norm(out_k2-out_n2))

def test_loss():
    shape_list = ['2d', '3d', '4d', '5d']
    for _shape in shape_list:
        check_loss(_shape)
        print('======================')

In [103]:
test_loss()

ce 6.094289 (6,)
fc 6.094289466447662 (6,)
k1 1.3171521 (6,)
k2 1.3171524 (6,)
n1 1.3171522854057975 (6,)
n2 1.3171524232972118 (6,)
1.2360156827758254e-07
ce 28.42365 (5, 6)
fc 28.423650374142237 (5, 6)
k1 2.8164084 (5, 6)
k2 2.8164086 (5, 6)
n1 2.816408187125766 (5, 6)
n2 2.816408487895681 (5, 6)
2.167798827917469e-07
ce 242.01654 (8, 5, 6)
fc 242.01654257703984 (8, 5, 6)
k1 8.441095 (8, 5, 6)
k2 8.441096 (8, 5, 6)
n1 8.441095835796608 (8, 5, 6)
n2 8.441096690265704 (8, 5, 6)
1.4951486450247199e-06
ce 2150.2935 (9, 8, 5, 6)
fc 2150.2934860677224 (9, 8, 5, 6)
k1 24.600975 (9, 8, 5, 6)
k2 24.600979 (9, 8, 5, 6)
n1 24.60097522692255 (9, 8, 5, 6)
n2 24.600977831849523 (9, 8, 5, 6)
1.63639511444481e-05


In [134]:
Metric = namedtuple('Metric', 'header hfmt func fmt kwargs')

metrics = [
    Metric('recall', "{:>8}", recall, "{:8.5f}", {}), 
    Metric('f2', "{:>8}", fbeta, "{:8.5f}", {}), 
    Metric('iou', "{:>8}", iou, "{:8.5f}", {}), 
    Metric('dice', "{:>8}", dice_score, "{:8.5f}", {})]

losses1 = [
    Metric('brier2', "{:>8}", brier_score_loss, "{:8.5f}", {}),
]
losses3 = [
#     Metric('bce', "{:>8}", bce_np, "{:8.4f}", {}),
#     Metric('brier1', "{:>8}", brier_loss, "{:8.4f}", {}),
#     Metric('fl_00_0', "{:>8}", focal_loss_wrapper_n1(gamma=0, alpha=0), "{:8.4f}", {}),
    Metric('fl_00_1', "{:>8}", focal_loss_wrapper_ns(gamma=0, alpha=0.25), "{:8.4f}", {}),
    Metric('fl_00_2', "{:>8}", focal_loss_wrapper_ns(gamma=0, alpha=0.5), "{:8.4f}", {}),
    Metric('fl_00_3', "{:>8}", focal_loss_wrapper_ns(gamma=0, alpha=0.75), "{:8.4f}", {}),
#     Metric('fl_00_4', "{:>8}", focal_loss_wrapper_n1(gamma=0, alpha=1), "{:8.4f}", {}),
#     Metric('fl_05_0', "{:>8}", focal_loss_wrapper_n1(gamma=0.5, alpha=0), "{:8.4f}", {}),
    Metric('fl_05_1', "{:>8}", focal_loss_wrapper_ns(gamma=0.5, alpha=0.25), "{:8.4f}", {}),
    Metric('fl_05_2', "{:>8}", focal_loss_wrapper_ns(gamma=0.5, alpha=0.5), "{:8.4f}", {}),
    Metric('fl_05_3', "{:>8}", focal_loss_wrapper_ns(gamma=0.5, alpha=0.75), "{:8.4f}", {}),
#     Metric('fl_05_4', "{:>8}", focal_loss_wrapper_n1(gamma=0.5, alpha=1), "{:8.4f}", {}),
#     Metric('fl_10_0', "{:>8}", focal_loss_wrapper_n1(gamma=1, alpha=0), "{:8.4f}", {}),
    Metric('fl_10_1', "{:>8}", focal_loss_wrapper_ns(gamma=1, alpha=0.25), "{:8.4f}", {}),
    Metric('fl_10_2', "{:>8}", focal_loss_wrapper_ns(gamma=1, alpha=0.5), "{:8.4f}", {}),
    Metric('fl_10_3', "{:>8}", focal_loss_wrapper_ns(gamma=1, alpha=0.75), "{:8.4f}", {}),
#     Metric('fl_10_4', "{:>8}", focal_loss_wrapper_n1(gamma=1, alpha=1), "{:8.4f}", {}),
#     Metric('fl_20_0', "{:>8}", focal_loss_wrapper_ns(gamma=2, alpha=0), "{:8.4f}", {}),
    Metric('fl_20_1', "{:>8}", focal_loss_wrapper_ns(gamma=2, alpha=0.25), "{:8.4f}", {}),
    Metric('fl_20_2', "{:>8}", focal_loss_wrapper_ns(gamma=2, alpha=0.5), "{:8.4f}", {}),
    Metric('fl_20_3', "{:>8}", focal_loss_wrapper_ns(gamma=2, alpha=0.75), "{:8.4f}", {}),
#     Metric('fl_20_4', "{:>8}", focal_loss_wrapper_ns(gamma=2, alpha=1), "{:8.4f}", {}),
]


In [128]:
y_none = np.array([
    [0,0,0,0,0,0],
    [0,0,0,0,0,0],
    [0,0,0,0,0,0],
    [0,0,0,0,0,0],
    [0,0,0,0,0,0],
    [0,0,0,0,0,0]]
)

y_sm1 = np.array([
    [0,0,0,0,0,0],
    [0,1,1,0,0,0],
    [0,1,1,0,0,0],
    [0,0,0,0,0,0],
    [0,0,0,0,0,0],
    [0,0,0,0,0,0]]
)

y_sm2 = np.array([
    [0,0,0,0,0,0],
    [0,0,0,0,0,0],
    [0,0,0,0,0,0],
    [0,0,0,1,1,0],
    [0,0,0,1,1,0],
    [0,0,0,0,0,0]]
)

y_big = np.array([
    [0,0,0,0,0,0],
    [0,1,1,1,1,0],
    [0,1,1,1,1,0],
    [0,1,1,1,1,0],
    [0,1,1,1,1,0],
    [0,0,0,0,0,0]]
)

In [129]:
y_none_2 = enlarge(y_none, y_none)
y_sm1_2 = enlarge(y_sm1, y_none)
y_sm2_2 = enlarge(y_sm2, y_none)
y_big_2 = enlarge(y_big, y_none)

In [50]:
scale_factor = 0.99 
print_headers(metrics)
print_metrics(y_none, y_none*scale_factor, metrics)
print_metrics(y_none, y_big*scale_factor, metrics)
print_metrics(y_big, y_big*scale_factor, metrics)
print_metrics(y_sm1, y_sm2*scale_factor, metrics)
print_metrics(y_sm1, y_big*scale_factor, metrics)
print_metrics(y_big, y_sm2*scale_factor, metrics)
print_metrics(y_big, y_none*scale_factor, metrics)
print('')
print_headers(metrics)
print_metrics(y_none_2, y_none_2*scale_factor, metrics)
print_metrics(y_none_2, y_big_2*scale_factor, metrics)
print_metrics(y_big_2, y_big_2*scale_factor, metrics)
print_metrics(y_sm1_2, y_sm2_2*scale_factor, metrics)
print_metrics(y_sm1_2, y_big_2*scale_factor, metrics)
print_metrics(y_big_2, y_sm2_2*scale_factor, metrics)
print_metrics(y_big_2, y_none_2*scale_factor, metrics)

  recall       f2      iou     dice
 1.00000  1.00000  0.00000  1.00000
 1.00000  0.00006  0.00000  0.00006
 0.99000  0.99198  1.00000  0.99498
 0.00025  0.00005  0.00000  0.00013
 0.99000  0.62187  0.25000  0.39922
 0.24755  0.29136  0.25000  0.39682
 0.00006  0.00002  0.00000  0.00006

  recall       f2      iou     dice
 1.00000  1.00000  0.00000  1.00000
 1.00000  0.00006  0.00000  0.00006
 0.99000  0.99198  1.00000  0.99498
 0.00025  0.00005  0.00000  0.00013
 0.99000  0.62187  0.25000  0.39922
 0.24755  0.29136  0.25000  0.39682
 0.00006  0.00002  0.00000  0.00006


In [138]:
scale_factor = 0.99 
losses = losses3
print_headers(losses)
print_losses(y_none, y_none*scale_factor, losses)
print_losses(y_big, y_big*scale_factor, losses)
print_losses(y_sm1, y_big*scale_factor, losses)
print_losses(y_none, y_big*scale_factor, losses)
print_losses(y_sm1, y_sm2*scale_factor, losses)
print_losses(y_big, y_sm2*scale_factor, losses)
print_losses(y_big, y_none*scale_factor, losses)
print('')
print_headers(losses)
print_losses(y_none_2, y_none_2*scale_factor, losses)
print_losses(y_big_2, y_big_2*scale_factor, losses)
print_losses(y_sm1_2, y_big_2*scale_factor, losses)
print_losses(y_none_2, y_big_2*scale_factor, losses)
print_losses(y_sm1_2, y_sm2_2*scale_factor, losses)
print_losses(y_big_2, y_sm2_2*scale_factor, losses)
print_losses(y_big_2, y_none_2*scale_factor, losses)

 fl_00_1  fl_00_2  fl_00_3  fl_05_1  fl_05_2  fl_05_3  fl_10_1  fl_10_2  fl_10_3  fl_20_1  fl_20_2  fl_20_3
  0.0000   0.0000   0.0000   0.0000   0.0000   0.0000   0.0000   0.0000   0.0000   0.0000   0.0000   0.0000
  0.0067   0.0134   0.0201   0.0007   0.0013   0.0020   0.0001   0.0001   0.0002   0.0000   0.0000   0.0000
  6.9094   4.6085   2.3076   6.8733   4.5824   2.2915   6.8387   4.5592   2.2796   6.7703   4.5135   2.2568
  9.2103   6.1402   3.0701   9.1642   6.1094   3.0547   9.1182   6.0788   3.0394   9.0271   6.0180   3.0090
  4.9596   6.8492   8.7387   4.9481   6.8415   8.7349   4.9366   6.8338   8.7310   4.9138   6.8186   8.7234
  7.9729  15.9457  23.9186   7.9714  15.9427  23.9141   7.9712  15.9424  23.9136   7.9712  15.9424  23.9136
 10.6283  21.2565  31.8848  10.6283  21.2565  31.8848  10.6283  21.2565  31.8848  10.6283  21.2565  31.8848

 fl_00_1  fl_00_2  fl_00_3  fl_05_1  fl_05_2  fl_05_3  fl_10_1  fl_10_2  fl_10_3  fl_20_1  fl_20_2  fl_20_3
  0.0000   0.0000   0.0000 

In [139]:
scale_factor = 0.99 
losses = losses3
print_headers(losses)
print_losses(y_none[..., None], y_none[..., None]*scale_factor, losses)
print_losses(y_big[..., None], y_big[..., None]*scale_factor, losses)
print_losses(y_sm1[..., None], y_big[..., None]*scale_factor, losses)
print_losses(y_none[..., None], y_big[..., None]*scale_factor, losses)
print_losses(y_sm1[..., None], y_sm2[..., None]*scale_factor, losses)
print_losses(y_big[..., None], y_sm2[..., None]*scale_factor, losses)
print_losses(y_big[..., None], y_none[..., None]*scale_factor, losses)
print('')
print_headers(losses)
print_losses(y_none_2[..., None], y_none_2[..., None]*scale_factor, losses)
print_losses(y_big_2[..., None], y_big_2[..., None]*scale_factor, losses)
print_losses(y_sm1_2[..., None], y_big_2[..., None]*scale_factor, losses)
print_losses(y_none_2[..., None], y_big_2[..., None]*scale_factor, losses)
print_losses(y_sm1_2[..., None], y_sm2_2[..., None]*scale_factor, losses)
print_losses(y_big_2[..., None], y_sm2_2[..., None]*scale_factor, losses)
print_losses(y_big_2[..., None], y_none_2[..., None]*scale_factor, losses)

 fl_00_1  fl_00_2  fl_00_3  fl_05_1  fl_05_2  fl_05_3  fl_10_1  fl_10_2  fl_10_3  fl_20_1  fl_20_2  fl_20_3
  0.0000   0.0000   0.0000   0.0000   0.0000   0.0000   0.0000   0.0000   0.0000   0.0000   0.0000   0.0000
  0.0402   0.0804   0.1206   0.0040   0.0080   0.0121   0.0004   0.0008   0.0012   0.0000   0.0000   0.0000
 41.4566  27.6511  13.8457  41.2398  27.4945  13.7493  41.0322  27.3549  13.6777  40.6217  27.0812  13.5406
 55.2620  36.8414  18.4207  54.9850  36.6567  18.3283  54.7094  36.4729  18.2365  54.1623  36.1082  18.0541
 29.7579  41.0951  52.4323  29.6886  41.0489  52.4092  29.6197  41.0030  52.3863  29.4830  40.9118  52.3407
 47.8372  95.6744 143.5116  47.8282  95.6563 143.4845  47.8273  95.6545 143.4818  47.8271  95.6543 143.4814
 63.7695 127.5391 191.3086  63.7695 127.5391 191.3086  63.7695 127.5391 191.3086  63.7695 127.5391 191.3086

 fl_00_1  fl_00_2  fl_00_3  fl_05_1  fl_05_2  fl_05_3  fl_10_1  fl_10_2  fl_10_3  fl_20_1  fl_20_2  fl_20_3
  0.0000   0.0000   0.0000 

In [108]:
scale_factor = 0.99 
losses = losses1
print_headers(losses)
print_metrics(y_none.flatten(), y_none.flatten()*scale_factor, losses)
print_metrics(y_none.flatten(), y_big.flatten()*scale_factor, losses)
print_metrics(y_big.flatten(), y_big.flatten()*scale_factor, losses)
print_metrics(y_sm1.flatten(), y_sm2.flatten()*scale_factor, losses)
print_metrics(y_sm1.flatten(), y_big.flatten()*scale_factor, losses)
print_metrics(y_big.flatten(), y_sm2.flatten()*scale_factor, losses)
print_metrics(y_big.flatten(), y_none.flatten()*scale_factor, losses)
print('')
print_headers(losses)
print_metrics(y_none_2.flatten(), y_none_2.flatten()*scale_factor, losses)
print_metrics(y_none_2.flatten(), y_big_2.flatten()*scale_factor, losses)
print_metrics(y_big_2.flatten(), y_big_2.flatten()*scale_factor, losses)
print_metrics(y_sm1_2.flatten(), y_sm2_2.flatten()*scale_factor, losses)
print_metrics(y_sm1_2.flatten(), y_big_2.flatten()*scale_factor, losses)
print_metrics(y_big_2.flatten(), y_sm2_2.flatten()*scale_factor, losses)
print_metrics(y_big_2.flatten(), y_none_2.flatten()*scale_factor, losses)

  brier2
 0.00000
 0.43560
 0.00004
 0.22001
 0.32671
 0.33334
 0.44444

  brier2
 0.00000
 0.10890
 0.00001
 0.05500
 0.08168
 0.08334
 0.11111


## Test soft dice loss

In [3]:
def soft_dice_coef_n0(y_true, y_pred, axis=-1, smooth=1e-3):
    AB = np.sum(y_true * y_pred, axis=axis)
    A = np.sum(y_true + y_pred, axis=axis)
#     B = np.sum(y_pred, axis=axis)
    return (2 * AB + smooth) / (A + smooth)

def soft_dice_coef_k0(y_true, y_pred, axis=-1, smooth=1e-3):
    AB = K.sum(y_true * y_pred, axis=axis)
    A = K.sum(y_true, axis=axis)
    B = K.sum(y_pred, axis=axis)
    return (2. * AB + smooth) / (A + B + smooth)

def soft_dice_loss_n0(y_true, y_pred, axis=-1, smooth=1e-3):
    return 1 - soft_dice_coef_n0(y_true, y_pred, axis=axis, smooth=smooth)

def soft_dice_loss_k0(y_true, y_pred, axis=-1, smooth=1e-3):
    return 1 - soft_dice_coef_k0(y_true, y_pred, axis=axis, smooth=smooth)

# https://www.jeremyjordan.me/semantic-segmentation/
def soft_dice_loss_n2(y_true, y_pred, epsilon=1e-6):
    """
    Soft dice loss calculation for arbitrary batch size, number of classes, and number of spatial dimensions.
    Assumes the `channels_last` format.

    # Arguments
        y_true: b x X x Y( x Z...) x c One hot encoding of ground truth
        y_pred: b x X x Y( x Z...) x c Network output, must sum to 1 over c channel (such as after softmax)
        epsilon: Used for numerical stability to avoid divide by zero errors

    # References
        V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
        https://arxiv.org/abs/1606.04797
        More details on Dice loss formulation
        https://mediatum.ub.tum.de/doc/1395260/1395260.pdf (page 72)

        Adapted from https://github.com/Lasagne/Recipes/issues/99#issuecomment-347775022
    """

    # skip the batch and class axis for calculating Dice score
    axes = tuple(range(1, len(y_pred.shape) - 1))
    numerator = 2. * np.sum(y_pred * y_true, axes)
    denominator = np.sum(np.square(y_pred) + np.square(y_true), axes)
    return 1 - numerator / (denominator + epsilon)  # average over classes and batch

def soft_dice_loss_k2(y_true, y_pred, epsilon=1e-6):
    axes = tuple(range(1, len(y_pred.shape) - 1))
    numerator = 2. * K.sum(y_pred * y_true, axes)
    denominator = K.sum(K.square(y_pred) + K.square(y_true), axes)
    return 1 - numerator / (denominator + epsilon)

In [4]:
def check_loss(_shape):
    if _shape == '2d':
        shape = (6, 7)
    elif _shape == '3d':
        shape = (5, 6, 7)
    elif _shape == '4d':
        shape = (8, 5, 6, 7)
    elif _shape == '5d':
        shape = (9, 8, 5, 6, 7)

    y_true = np.random.randint(0, 2, (shape)) * 1.0
    y_pred = np.random.random(shape)

    out_n0 = soft_dice_loss_n0(y_true, y_pred)
    out_k0 = K.eval(soft_dice_loss_k0(K.variable(y_true), K.variable(y_pred)))
    
    print('n0', out_n0.shape, np.mean(out_n0))
    print('k0', out_k0.shape, K.eval(K.mean(K.variable(out_k0))))
    
    assert out_k0.shape == out_n0.shape
    assert out_k0.shape == shape[:-1]
    print(np.linalg.norm(out_k0-out_n0))

def test_loss():
    shape_list = ['2d', '3d', '4d', '5d']
    for _shape in shape_list:
        check_loss(_shape)
        print('======================')

In [5]:
test_loss()

n0 (6,) 0.5589039955482429
k0 (6,) 0.558904
8.952415159745165e-08
n0 (5, 6) 0.447953447155229
k0 (5, 6) 0.44795343
2.1955494995942978e-07
n0 (8, 5, 6) 0.5294329270967677
k0 (8, 5, 6) 0.52943295
5.972190077369066e-07
n0 (9, 8, 5, 6) 0.525747611982823
k0 (9, 8, 5, 6) 0.5257476
1.6635018929073494e-06


## Test mixing losses

In [11]:
def bce_soft_dice_loss_n0(y_true, y_pred, frac=0.5):
    return bce_np(y_true, y_pred) * frac + soft_dice_loss_n0(y_true, y_pred) * (1 - frac)

def bce_soft_dice_loss_wrapper_k0(frac=0.5):
    def bce_soft_dice_loss_k0(y_true, y_pred):
        return binary_crossentropy(y_true, y_pred) * frac + soft_dice_loss_k0(y_true, y_pred) * (1 - frac)
    return bce_soft_dice_loss_k0

In [17]:
def check_loss(_shape):
    if _shape == '2d':
        shape = (6, 7)
    elif _shape == '3d':
        shape = (5, 6, 7)
    elif _shape == '4d':
        shape = (8, 5, 6, 7)
    elif _shape == '5d':
        shape = (9, 8, 5, 6, 7)

    y_true = np.random.randint(0, 2, (shape)) * 1.0
    y_pred = np.random.random(shape)

    frac = 0.3
    
    bce_soft_dice_loss_k0 = bce_soft_dice_loss_wrapper_k0(frac=frac)
    
    bce_n1 = bce_np(y_true, y_pred)
    dice_n1 = soft_dice_loss_n0(y_true, y_pred)
    out_n1 = np.mean(bce_n1) * frac + np.mean(dice_n1) * (1 - frac)
    print('bce__n1', bce_n1.shape, np.mean(bce_n1))
    print('dice_n1', dice_n1.shape, np.mean(dice_n1))
    print('out__n1', out_n1.shape, out_n1)

    out_n0 = bce_soft_dice_loss_n0(y_true, y_pred, frac=frac)
    print('out__n0', out_n0.shape, np.mean(out_n0))

    bce_k1 = K.eval(binary_crossentropy(K.variable(y_true), K.variable(y_pred)))
    dice_k1 = K.eval(soft_dice_loss_k0(K.variable(y_true), K.variable(y_pred)))
    out_k1 = K.eval(K.mean(K.variable(bce_k1))) * frac + K.eval(K.mean(K.variable(dice_k1))) * (1 - frac)
    print('bce__k1', bce_k1.shape, K.eval(K.mean(K.variable(bce_n1))))
    print('dice_k1', dice_k1.shape, K.eval(K.mean(K.variable(dice_n1))))
    print('out__k1', out_k1.shape, out_k1)

    out_k0 = K.eval(bce_soft_dice_loss_k0(K.variable(y_true), K.variable(y_pred)))
    print('out__k0', out_k0.shape, K.eval(K.mean(K.variable(out_k0))))
    
#     assert out_k1.shape == out_k0.shape
#     assert out_k1.shape == out_k2.shape
#     assert out_k1.shape == out_n0.shape
#     assert out_k1.shape == out_n1.shape
#     assert out_k1.shape == out_n2.shape
    assert out_k0.shape == shape[:-1]
    print(np.linalg.norm(out_k1-out_n1))

def test_loss():
    shape_list = ['2d', '3d', '4d', '5d']
    for _shape in shape_list:
        check_loss(_shape)
        print('======================')


In [18]:
test_loss()

bce__n1 (6,) 1.0835870316421043
dice_n1 (6,) 0.4906300013296098
out__n1 () 0.6685171104233582
out__n0 (6,) 0.6685171104233582
bce__k1 (6,) 1.083587
dice_k1 (6,) 0.49063
out__k1 () 0.6685171157121659
out__k0 (6,) 0.6685171
5.288807725101208e-09
bce__n1 (5, 6) 1.1214160428333548
dice_n1 (5, 6) 0.5585305986973864
out__n1 () 0.7273962319381768
out__n0 (5, 6) 0.727396231938177
bce__k1 (5, 6) 1.1214161
dice_k1 (5, 6) 0.55853057
out__k1 () 0.7273962676525116
out__k0 (5, 6) 0.72739625
3.571433482285613e-08
bce__n1 (8, 5, 6) 1.0148600433635961
dice_n1 (8, 5, 6) 0.5185417804015048
out__n1 () 0.6674372592901321
out__n0 (8, 5, 6) 0.6674372592901321
bce__k1 (8, 5, 6) 1.01486
dice_k1 (8, 5, 6) 0.5185418
out__k1 () 0.6674319803714752
out__k0 (8, 5, 6) 0.66743207
5.278918656870246e-06
bce__n1 (9, 8, 5, 6) 1.002251914616841
dice_n1 (9, 8, 5, 6) 0.5241681726755963
out__n1 () 0.6675932952579697
out__n0 (9, 8, 5, 6) 0.6675932952579698
bce__k1 (9, 8, 5, 6) 1.002252
dice_k1 (9, 8, 5, 6) 0.5241682
out__k1 ()

In [7]:
def focal_soft_dice_loss_wrapper(gamma=2., alpha=.25, frac=0.5):
    focal_loss = focal_loss_wrapper(gamma=gamma, alpha=alpha)
    def focal_soft_dice_loss(y_true, y_pred, frac=frac):
        return focal_loss(y_true, y_pred) * frac + soft_dice_loss(y_true, y_pred) * (1 - frac)
    return focal_soft_dice_loss

In [71]:
rand_batch_true = np.random.choice(2, (5, 100, 100, 1))
rand_batch_pred = np.random.choice(2, (5, 100, 100, 1))
average_sample_loss(rand_batch_true, rand_batch_pred, fbeta, axis=(1, 2, 3)) / rand_batch_true.shape[0]

array([0.10176404, 0.1008238 , 0.10009978, 0.10044594, 0.10178322])

In [180]:
def get_normalizer(old_norm_weights, index_array, new_batch_scores, alpha=None):
    if alpha is None:
        baseline = old_norm_weights * np.sum(new_batch_scores - old_norm_weights[index_array])
    else:
        baseline = old_norm_weights * alpha
    
    other_examples = np.ones((len(old_norm_weights),), dtype=int)
    other_examples[index_array] = 0
    index_other = np.where(other_examples)[0]
    
    baseline[index_array] = new_batch_scores
    print(baseline, np.sum(baseline))
    return baseline

In [196]:
n_samples = 6
n_subs = 3
index_array = np.array([0, 1, 2])
# index_array = np.random.choice(n_samples, n_subs, replace=False)
print(index_array)
other_examples = np.ones((n_samples,), dtype=int)
other_examples[index_array] = 0
index_other = np.where(other_examples)[0]


[0 1 2]


In [208]:
print(old_weights)

[0.96342001 0.54880108 0.7793285  0.10600293 0.99062654 0.61411048]


In [204]:
old_weights = np.random.random((n_samples,))
print(old_weights)
old_norm_weights = old_weights / np.sum(old_weights)
print(old_norm_weights)
old_norm_sample = old_norm_weights[index_array]
print(old_norm_sample)

[0.24071722 0.13712178 0.19472067 0.02648557 0.24751496 0.15343979]
[0.24071722 0.13712178 0.19472067]


In [240]:
batch_scores = np.random.random((12,))
print(batch_scores)

[0.24724735 0.37993896 0.31944128 0.09919289 0.86815457 0.61496554
 0.72502535 0.7692111  0.93714747 0.59123661 0.082467   0.05772703]


In [None]:
batch_scores *= 0.1

In [241]:
scaled_scores = batch_scores / np.sum(batch_scores)
print(scaled_scores, np.sum(scaled_scores))

[0.04343956 0.06675251 0.05612351 0.01742747 0.15252845 0.10804497
 0.12738168 0.1351448  0.16465    0.10387597 0.01448885 0.01014222] 1.0


In [242]:
scaled_scores = (batch_scores - np.min(batch_scores)) / (np.max(batch_scores) - np.min(batch_scores))
print(scaled_scores, np.sum(scaled_scores))

[0.21550592 0.36639121 0.29759854 0.04715134 0.9215473  0.63364288
 0.75879328 0.80903745 1.         0.60666042 0.02813213 0.        ] 5.684460484113896


In [243]:
norm_scores = batch_scores / np.linalg.norm(batch_scores)
print(norm_scores, np.sum(norm_scores))

[0.1267597  0.19478854 0.16377236 0.05085458 0.44508876 0.31528285
 0.37170874 0.39436205 0.4804603  0.30311742 0.04227949 0.02959571] 2.918070515389947


In [236]:
batch_scores *= 0.1

In [237]:
scaled_scores = batch_scores / np.sum(batch_scores)
print(scaled_scores, np.sum(scaled_scores))

[0.13613562 0.20192301 0.13856829 0.14336397 0.15893467 0.22107444] 1.0000000000000002


In [238]:
scaled_scores = (batch_scores - np.min(batch_scores)) / (np.max(batch_scores) - np.min(batch_scores))
print(scaled_scores, np.sum(scaled_scores))

[0.         0.7745268  0.0286402  0.08510061 0.26841729 1.        ] 2.156684898648847


In [239]:
norm_scores = batch_scores / np.linalg.norm(batch_scores)
print(norm_scores, np.sum(norm_scores))

[0.32711837 0.48519796 0.33296378 0.34448726 0.38190187 0.53121665] 2.4028858817096452


In [235]:
new_batch_scores = (scaled_scores + old_norm_sample) / 2.0
print(new_batch_scores, np.sum(new_batch_scores))

ValueError: operands could not be broadcast together with shapes (6,) (3,) 

In [201]:
# print(new_batch_scores - old_norm_sample)
print(new_batch_scores - old_norm_sample, np.sum(new_batch_scores - old_norm_sample))

[ 0.42688128 -0.05222697  0.32290874] 0.6975630530832811


In [202]:
s1 = 1 - np.sum(new_batch_scores)
s0 = np.sum(old_norm_weights[index_other])
alpha = s1 / s0
s1, s0, alpha

(-0.2312239964377052, 0.46633905664557596, -0.4958280743219811)

In [203]:
get_normalizer(old_norm_weights, index_array, new_batch_scores, 0.36)
get_normalizer(old_norm_weights, index_array, new_batch_scores, 0.35)
get_normalizer(old_norm_weights, index_array, new_batch_scores, alpha)
get_normalizer(old_norm_weights, index_array, new_batch_scores, 0.34)
get_normalizer(old_norm_weights, index_array, new_batch_scores, 0.33)


[0.57311872 0.05222697 0.60587831 0.03260416 0.06483034 0.07044757] 1.3991060568301126
[0.57311872 0.05222697 0.60587831 0.03169849 0.06302949 0.06849069] 1.3944426662636566
[ 0.57311872  0.05222697  0.60587831 -0.04490571 -0.08929083 -0.09702745] 0.9999999999999999
[0.57311872 0.05222697 0.60587831 0.03079282 0.06122865 0.06653381] 1.389779275697201
[0.57311872 0.05222697 0.60587831 0.02988714 0.05942781 0.06457694] 1.3851158851307452


array([0.57311872, 0.05222697, 0.60587831, 0.02988714, 0.05942781,
       0.06457694])

In [135]:
np.sum(old_norm_weights[3:] - new_norm_weights[3:])

0.2612719302317831

In [137]:
1.1825/3

0.3941666666666667