In [None]:
import torch
from torch.nn.modules.loss import _Loss 
import torch.nn as nn

class SoftDiceLoss(_Loss):
    """Soft Dice loss, averaged over every (batch item, channel) pair.

    Soft_Dice = 2*|dot(A, B)| / (|dot(A, A)| + |dot(B, B)| + eps)

    eps is a small constant that avoids division by zero. It is added only to
    the denominator, so a channel where both prediction and target are all
    zeros scores a Dice of 0 (i.e. contributes loss 1), not 1.

    Args:
        new_loss: if truthy, channels are merged into nested regions before
            the Dice score is computed (see ``_merge_regions``).
    """

    def __init__(self, new_loss):
        super().__init__()
        self.new_loss = new_loss

    @staticmethod
    def _merge_regions(y):
        """Return a NEW tensor with nested-region channels; ``y`` is not modified.

        Output channel 0 is the sum of all input channels, channel 1 is the
        sum of channels 1 onwards, and channels 2+ are copied unchanged. The
        result keeps ``y``'s dtype, matching what in-place assignment did.
        """
        whole = torch.sum(y, dim=1, keepdim=True).to(y.dtype)
        tail = torch.sum(y[:, 1:], dim=1, keepdim=True).to(y.dtype)
        return torch.cat([whole, tail, y[:, 2:]], dim=1)

    def forward(self, y_pred, y_true, eps=1e-8):
        """Return ``1 - mean(dice)``; smaller is better.

        Args:
            y_pred: predicted values; the Dice reduction sums over the last
                three dimensions, presumably (N, C, D, H, W) — confirm with
                the caller.
            y_true: target tensor, broadcast-compatible with ``y_pred``.
            eps: constant added to the Dice denominator.

        Returns:
            Scalar tensor with the soft Dice loss.
        """
        # Label order per the original note: channels are (2, 1, 4), so the
        # merged channels are 1+2+4: WT, 1+4: TC, 4: ET.
        if self.new_loss:
            # Rebind to fresh tensors instead of writing into the inputs:
            # in-place slice assignment corrupted the caller's y_true and broke
            # backward when y_pred was saved by autograd (e.g. sigmoid output).
            y_pred = self._merge_regions(y_pred)
            y_true = self._merge_regions(y_true)

        spatial = [-3, -2, -1]
        intersection = torch.sum(torch.mul(y_pred, y_true), dim=spatial)
        union = (torch.sum(torch.mul(y_pred, y_pred), dim=spatial)
                 + torch.sum(torch.mul(y_true, y_true), dim=spatial)
                 + eps)

        dice = 2 * intersection / union   # (N, C) for 5-D inputs
        return 1 - torch.mean(dice)

