In [31]:
import torch
import torch.nn.functional as F
import tensorflow as tf
import numpy as np

In [6]:
def compute_dice_torch(logits, labels, epsilon=1e-10):
    '''
    Computes the soft dice score between logits and labels (channels-first layout).

    The class axis is dim 1 (softmax is taken over it); every axis after it is
    treated as spatial and summed out, so 1-D, 2-D and 3-D inputs are all
    supported, e.g. [8, 2, 144, 112, 48] for a batch of 3-D volumes.

    Known limitation: if a certain label is missing in the GT of a certain image
    and also in the prediction, dice[this_image, this_label] is computed as zero
    whereas it should be 1.

    :param logits: Network output before softmax, shape (batch, classes, *spatial)
    :param labels: ground truth label masks, same shape as logits
    :param epsilon: A small constant to avoid division by 0
    :return: tuple (dice, mean_dice, mean_fg_dice) where dice has shape
             (batch, classes), mean_dice is the mean over all of dice, and
             mean_fg_dice is the mean over dice[:, 1:] (label 0 excluded)
    :raises ValueError: if logits has fewer than 3 dimensions or the shapes of
                        logits and labels differ
    '''

    if logits.dim() < 3:
        # With no spatial axis the reduction list would be empty, and torch.sum
        # with an empty dim list does not mean "reduce nothing" in every release.
        raise ValueError(
            'logits must have shape (batch, classes, *spatial); got %s'
            % (tuple(logits.shape),)
        )
    if logits.shape != labels.shape:
        # Broadcasting mismatched shapes would silently give meaningless scores.
        raise ValueError(
            'logits and labels must have the same shape; got %s and %s'
            % (tuple(logits.shape), tuple(labels.shape))
        )

    # class probabilities; dim=1 is the class axis in the channels-first layout
    prediction = F.softmax(logits, dim=1)
    intersection = torch.mul(prediction, labels)

    # reduce over every spatial axis, i.e. everything after (batch, classes)
    reduction_axes = tuple(range(2, logits.dim()))

    # compute area of intersection, area of GT, area of prediction (per image per label)
    tp = torch.sum(intersection, dim=reduction_axes)
    tp_plus_fp = torch.sum(prediction, dim=reduction_axes)
    tp_plus_fn = torch.sum(labels, dim=reduction_axes)

    # compute dice (per image per label)
    dice = 2 * tp / (tp_plus_fp + tp_plus_fn + epsilon)

    # mean over all images in the batch and over all labels.
    mean_dice = torch.mean(dice)

    # mean over all images in the batch and over all foreground labels.
    mean_fg_dice = torch.mean(dice[:, 1:])

    return dice, mean_dice, mean_fg_dice


In [7]:
## ======================================================================
def compute_dice(logits, labels, epsilon=1e-10):
    '''
    Computes the soft dice score between logits and labels (channels-last layout).

    The class axis is the last axis (tf.nn.softmax is applied along its default
    axis, -1); every axis between the batch axis and the class axis is treated
    as spatial and summed out, e.g. [8, 144, 112, 48, 2] for 3-D volumes.

    Known limitation: if a certain label is missing in the GT of a certain image
    and also in the prediction, dice[this_image, this_label] is computed as zero
    whereas it should be 1.

    :param logits: Network output before softmax, shape (batch, *spatial, classes)
    :param labels: ground truth label masks, same shape as logits
    :param epsilon: A small constant to avoid division by 0
    :return: tuple (dice, mean_dice, mean_fg_dice) where dice is per image and
             per label, mean_dice is the mean over all of dice, and mean_fg_dice
             is the mean over dice[:, 1:] (label 0 excluded)
    '''

    with tf.name_scope('dice'):

        prediction = tf.nn.softmax(logits)
        intersection = tf.multiply(prediction, labels)

        # Reduce over every spatial axis (all axes between batch and classes).
        # The rank is read from the softmax output so that array-like inputs
        # work too; if it is statically unknown, keep the original 3-D axes.
        rank = prediction.shape.rank
        reduction_axes = [1, 2, 3] if rank is None else list(range(1, rank - 1))

        # compute area of intersection, area of GT, area of prediction (per image per label)
        tp = tf.reduce_sum(intersection, axis=reduction_axes)
        tp_plus_fp = tf.reduce_sum(prediction, axis=reduction_axes)
        tp_plus_fn = tf.reduce_sum(labels, axis=reduction_axes)

        # compute dice (per image per label)
        dice = 2 * tp / (tp_plus_fp + tp_plus_fn + epsilon)

        # mean over all images in the batch and over all labels.
        mean_dice = tf.reduce_mean(dice)

        # mean over all images in the batch and over all foreground labels.
        mean_fg_dice = tf.reduce_mean(dice[:, 1:])

    return dice, mean_dice, mean_fg_dice

## ======================================================================

In [32]:
# Seeded generator so that Restart & Run All reproduces the same test volumes.
rng = np.random.default_rng(0)

# Channels-first test data: (batch, classes, then three spatial axes).
# Both arrays are uniform floats in [0, 1); the labels are NOT one-hot masks,
# which is enough for comparing the two dice implementations against each other.
volume_shape = (8, 2, 144, 112, 48)
np_input = rng.random(volume_shape)
np_labels = rng.random(volume_shape)

In [34]:
# Channels-last copies for the TensorFlow implementation:
# move the class axis (1) to the end, keeping the other axes in order.
np_tf_input = np.moveaxis(np_input, 1, -1)
np_tf_labels = np.moveaxis(np_labels, 1, -1)

In [40]:
# Wrap the channels-last arrays as float64 TensorFlow tensors.
logits, labels = (
    tf.convert_to_tensor(array, dtype=tf.float64)
    for array in (np_tf_input, np_tf_labels)
)

In [41]:
# TensorFlow reference result: (per-image/per-label dice, mean dice, mean foreground dice).
compute_dice(logits=logits, labels=labels)

(<tf.Tensor: shape=(8, 2), dtype=float64, numpy=
 array([[0.49994906, 0.49997942],
        [0.49998991, 0.50003429],
        [0.49985629, 0.50030285],
        [0.49991537, 0.5000407 ],
        [0.50032763, 0.50015519],
        [0.49999868, 0.50010628],
        [0.49980817, 0.49987502],
        [0.49971379, 0.49980308]])>,
 <tf.Tensor: shape=(), dtype=float64, numpy=0.49999098155824306>,
 <tf.Tensor: shape=(), dtype=float64, numpy=0.5000371016073819>)

In [42]:
# Wrap the channels-first arrays as torch tensors (from_numpy shares memory, no copy).
logits_torch, labels_torch = map(torch.from_numpy, (np_input, np_labels))

In [43]:
# PyTorch result on the same data; should agree with the TensorFlow output.
compute_dice_torch(logits=logits_torch, labels=labels_torch)

(tensor([[0.4999, 0.5000],
         [0.5000, 0.5000],
         [0.4999, 0.5003],
         [0.4999, 0.5000],
         [0.5003, 0.5002],
         [0.5000, 0.5001],
         [0.4998, 0.4999],
         [0.4997, 0.4998]], dtype=torch.float64),
 tensor(0.5000, dtype=torch.float64),
 tensor(0.5000, dtype=torch.float64))