# <font style="color:blue">Soft-Dice Loss</font>

In the semantic-segmentation task, a particular group of loss functions is commonly used. They all build on one observation: an object of a given class usually occupies a region of many adjacent pixels, so a good prediction is one whose region overlaps well with the ground-truth region. Here we introduce one such loss function, the `soft-Dice loss`: we briefly describe its key concept and then implement it.

# <font style="color:green">1. Soft-Dice Loss</font>

As you may remember, the `Dice coefficient` (also known as the `F1-score`) is twice the number of true positives divided by the total number of pixels in the ground-truth and predicted segmentation masks combined. True negatives do not appear in it: for a given class, they are simply pixels of other classes. For each class, the ground-truth mask consists of true positives and false negatives, and the predicted mask consists of true positives and false positives, so $DC = \frac{2TP}{(TP + FN) + (TP + FP)}$.
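As a quick arithmetic check of this definition, the snippet below computes the Dice coefficient from hypothetical pixel counts (made up for illustration) and compares it with the F1-score computed from precision and recall. The helper names are illustrative only.

```python
def dice_from_counts(tp: int, fp: int, fn: int) -> float:
    """Dice = 2*TP / ((TP + FP) + (TP + FN)): predicted pixels plus ground-truth pixels."""
    return 2 * tp / ((tp + fp) + (tp + fn))


def f1_from_counts(tp: int, fp: int, fn: int) -> float:
    """F1 = harmonic mean of precision and recall."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# hypothetical counts for one class: 30 correctly predicted pixels,
# 10 wrongly predicted pixels, 20 missed pixels
tp, fp, fn = 30, 10, 20
print(round(dice_from_counts(tp, fp, fn), 4))  # 0.6667
print(round(f1_from_counts(tp, fp, fn), 4))    # 0.6667
```

Note that true negatives enter neither expression, which is why the number of background pixels does not affect the score.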


**Here is its formula:**

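$$ DC = \frac{2\cdot{\sum_{i=1}^{N}{p_{i}y_{i}}}}{\sum_{i=1}^{N}{p_{i}} + \sum_{i=1}^{N}{y_{i}}},$$


where

$p_{i}$ is the prediction for pixel $i$ (`1` if the pixel is predicted to belong to the class, `0` otherwise),

$y_{i}$ is the ground truth for pixel $i$ (`1` if the pixel belongs to the class, `0` otherwise),

$N$ is the total number of pixels in the image.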
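The formula maps directly onto array operations. Below is a minimal NumPy sketch on a hypothetical 4×4 image with binary masks; `dice_coefficient` is an illustrative helper, not part of any library.

```python
import numpy as np


def dice_coefficient(p: np.ndarray, y: np.ndarray) -> float:
    """Dice coefficient of a binary prediction mask p and a binary ground-truth mask y."""
    intersection = np.sum(p * y)  # sum_i p_i * y_i, i.e. the true positives
    return float(2 * intersection / (np.sum(p) + np.sum(y)))


# hypothetical 4x4 masks: the ground truth covers the top-left 2x2 block (4 pixels),
# the prediction covers the top two rows (8 pixels)
y = np.zeros((4, 4))
y[:2, :2] = 1
p = np.zeros((4, 4))
p[:2, :] = 1

print(dice_coefficient(p, y))  # 2 * 4 / (8 + 4) = 2/3
```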


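Now, let's discuss how to turn this metric into a loss function. The Dice coefficient above compares two binary masks: every prediction $p_{i}$ is thresholded to `0` or `1`, so the coefficient changes in jumps and its gradient with respect to the network outputs is zero almost everywhere. To obtain a differentiable function, we plug the predicted probabilities in the range `[0, 1]` directly into the formula instead of thresholded values; this is what makes the Dice coefficient "soft". The soft Dice coefficient is still a score to maximize (`1` means perfect overlap), while training minimizes a loss, so we take its negative logarithm. Classification cross-entropy uses the negative logarithm for the same reason: it maps a value in `(0, 1]` to a non-negative loss that equals `0` at the optimum and grows quickly as the value approaches `0`.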
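To see why the thresholded version cannot be trained with gradients while the soft one can, the NumPy sketch below (hypothetical values, illustrative helper names) nudges one predicted probability and compares how both versions react. It also checks the analytic derivative of the soft Dice coefficient against a finite difference; with $I = \sum_i p_i y_i$ and $S = \sum_i p_i + \sum_i y_i$, the quotient rule gives $\partial DC / \partial p_j = (2 y_j S - 2I) / S^2$.

```python
import numpy as np


def soft_dice(p: np.ndarray, y: np.ndarray) -> float:
    """Dice coefficient computed on probabilities p in [0, 1] (no thresholding)."""
    return float(2 * np.sum(p * y) / (np.sum(p) + np.sum(y)))


def hard_dice(p: np.ndarray, y: np.ndarray, threshold: float = 0.5) -> float:
    """Dice coefficient computed after thresholding p into a 0/1 mask."""
    return soft_dice((p > threshold).astype(float), y)


def soft_dice_grad(p: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Analytic gradient of soft_dice with respect to p (quotient rule)."""
    intersection = np.sum(p * y)
    total = np.sum(p) + np.sum(y)
    return (2 * y * total - 2 * intersection) / total ** 2


# hypothetical probabilities for 4 pixels and their ground-truth labels
p = np.array([0.9, 0.6, 0.3, 0.2])
y = np.array([1.0, 1.0, 0.0, 0.0])

delta = 1e-6
p_shifted = p.copy()
p_shifted[1] += delta  # nudge the probability of pixel 1 slightly

# the hard Dice does not move at all, so it provides no gradient signal
print(hard_dice(p_shifted, y) - hard_dice(p, y))  # 0.0

# the soft Dice changes smoothly; the finite difference matches the analytic gradient (5/16)
finite_diff = (soft_dice(p_shifted, y) - soft_dice(p, y)) / delta
print(finite_diff, soft_dice_grad(p, y)[1])
```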

Taking all of the above into account, our loss function, the `Soft-Dice loss`, can be written as:


$$loss_{soft-dice} = -\log{\frac{2\cdot{\sum_{i=1}^{N}{p_{i}y_{i}}} + \epsilon}{\sum_{i=1}^{N}{p_{i}} + \sum_{i=1}^{N}{y_{i}} + \epsilon}}.$$

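Note: we also added a small constant $\epsilon$ (epsilon) to both the numerator and the denominator to keep the computation numerically stable. Without it, a class absent from both the ground truth and the prediction gives a division by zero ($0/0$), and a prediction with no overlap with the ground truth gives the logarithm of zero, which is undefined. Any sufficiently small value works, for example `1e-5`. For multi-class segmentation, the loss is computed for each class separately and then averaged over the classes, which is what the implementation below does.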
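A quick NumPy sketch of what $\epsilon$ buys us. The helper `soft_dice_loss` is an illustrative single-class version of the formula above, not the class implemented later; the masks are hypothetical.

```python
import numpy as np


def soft_dice_loss(p: np.ndarray, y: np.ndarray, eps: float = 1e-5) -> float:
    """Single-class soft-Dice loss: -log((2*sum(p*y) + eps) / (sum(p) + sum(y) + eps))."""
    dice = (2 * np.sum(p * y) + eps) / (np.sum(p) + np.sum(y) + eps)
    return float(-np.log(dice))


# a class that appears neither in the ground truth nor in the prediction
p = np.zeros(16)
y = np.zeros(16)

with np.errstate(divide="ignore", invalid="ignore"):
    print(soft_dice_loss(p, y, eps=0.0))  # nan: 0/0 is undefined
print(soft_dice_loss(p, y))                # zero loss: eps / eps = 1 and -log(1) = 0

# a class present in the ground truth but predicted with zero probability everywhere
y[:4] = 1
with np.errstate(divide="ignore"):
    print(soft_dice_loss(p, y, eps=0.0))  # inf: -log(0)
print(soft_dice_loss(p, y))                # finite: -log(1e-5 / 4.00001)
```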

# <font style="color:green">2. Soft-Dice Loss Implementation</font>

**Let's start implementing it.**

```python
import torch
import torch.nn as nn

import numpy as np

from dataclasses import dataclass
import random
```

**Configuration for reproducible results.**

```python
@dataclass
class SystemConfig:
    seed: int = 42  # seed number to set the state of all random number generators
    cudnn_benchmark_enabled: bool = False  # enable cuDNN benchmark mode for performance (can break reproducibility)
    cudnn_deterministic: bool = True  # make cuDNN deterministic (reproducible training)
```

```python
def setup_system(system_config: SystemConfig) -> None:
    # seed every random number generator in use
    torch.manual_seed(system_config.seed)
    np.random.seed(system_config.seed)
    random.seed(system_config.seed)
    torch.set_printoptions(precision=10)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(system_config.seed)
        # cuDNN settings only matter when running on a GPU
        torch.backends.cudnn.benchmark = system_config.cudnn_benchmark_enabled
        torch.backends.cudnn.deterministic = system_config.cudnn_deterministic
```
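The point of reseeding every generator is that the same seed yields the same sequence of random numbers. Here is a minimal sketch of that property using only Python's `random` and NumPy, so it runs without PyTorch or a GPU (`torch.manual_seed` plays the same role for PyTorch's generator). `draw_after_seeding` is a hypothetical helper for illustration.

```python
import random

import numpy as np


def draw_after_seeding(seed: int) -> tuple:
    """Reseed both generators, then draw one number from each."""
    random.seed(seed)
    np.random.seed(seed)
    return random.random(), float(np.random.rand())


first = draw_after_seeding(42)
second = draw_after_seeding(42)
other = draw_after_seeding(7)

print(first == second)  # True: same seed, same draws
print(first == other)   # False: a different seed gives different draws
```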

**Create a `SoftDiceLoss` class that inherits from `nn.Module`.**

```python
# create the SoftDiceLoss class, inheriting from nn.Module
class SoftDiceLoss(nn.Module):
    """
        Implementation of the Soft-Dice Loss function.

        Arguments:
            num_classes (int): number of classes.
            eps (float): small constant added to the numerator and the denominator.
    """
    def __init__(self, num_classes, eps=1e-5):
        super().__init__()
        # init class fields
        self.num_classes = num_classes
        self.eps = eps

    # define the forward pass
    def forward(self, preds, targets):
        """
            Compute Soft-Dice Loss.

            Arguments:
                preds (torch.FloatTensor):
                    tensor of predicted class probabilities (e.g. softmax outputs).
                    The shape of the tensor is (B, num_classes, H, W).
                targets (torch.LongTensor):
                    tensor of ground-truth class indices. The shape of the tensor is (B, H, W).
            Returns:
                loss (torch.Tensor): scalar loss averaged over classes. It stays attached
                    to the autograd graph, so loss.backward() can be called on it.
        """
        loss = 0
        # iterate over all classes
        for cls in range(self.num_classes):
            # get the binary ground-truth mask for the current class, shape (B, H, W)
            target = (targets == cls).float()

            # get the predicted probabilities for the current class, shape (B, H, W)
            pred = preds[:, cls]

            # calculate intersection
            intersection = (pred * target).sum()

            # compute dice coefficient
            dice = (2 * intersection + self.eps) / (pred.sum() + target.sum() + self.eps)

            # compute negative logarithm of the obtained dice coefficient
            loss = loss - dice.log()

        # average the loss over classes
        loss = loss / self.num_classes

        # return a tensor (not loss.item()) so that gradients can flow back to preds
        return loss
```

**Check the implementation.**

```python
# apply system settings
setup_system(SystemConfig())

# generate ground-truth class indices of shape (B, H, W)
ground_truth = torch.zeros(1, 224, 224, dtype=torch.long)
ground_truth[:, :50, :50] = 1
ground_truth[:, 50:100, 50:100] = 2

# generate random predictions and use softmax to get probabilities
prediction = torch.zeros(1, 3, 224, 224).uniform_().softmax(dim=1)

# create an instance of a SoftDiceLoss class
soft_dice_loss = SoftDiceLoss(num_classes=3)

# get the loss value (a scalar tensor)
loss = soft_dice_loss(prediction, ground_truth)

print('Loss: {}'.format(loss.item()))
```

```
Loss: 1.8738912343978882
```
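For training, the loss must stay a tensor attached to the autograd graph; calling `.item()` inside the loss would detach it into a Python float and block backpropagation. As a cross-check, here is a sketch of a vectorized variant that one-hot encodes the targets with `torch.nn.functional.one_hot` instead of looping over classes. `soft_dice_loss_vectorized` and `soft_dice_loss_loop` are hypothetical names; like the class above, both sum over the whole batch rather than per image. The snippet is self-contained: it compares the vectorized value with a per-class loop and confirms that gradients reach the logits.

```python
import torch
import torch.nn.functional as F


def soft_dice_loss_vectorized(preds: torch.Tensor, targets: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """
    Soft-Dice loss averaged over classes (hypothetical vectorized variant).

    preds:   (B, C, H, W) class probabilities
    targets: (B, H, W) integer class indices
    """
    num_classes = preds.shape[1]
    # one_hot gives (B, H, W, C); move the class axis next to the batch axis -> (B, C, H, W)
    one_hot = F.one_hot(targets.long(), num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)  # sum over batch and pixels, keep the class axis
    intersection = (preds * one_hot).sum(dim=dims)
    denominator = preds.sum(dim=dims) + one_hot.sum(dim=dims)
    dice = (2 * intersection + eps) / (denominator + eps)  # shape (C,)
    return (-dice.log()).mean()


def soft_dice_loss_loop(preds: torch.Tensor, targets: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Per-class loop performing the same computation as SoftDiceLoss.forward."""
    loss = 0
    for cls in range(preds.shape[1]):
        target = (targets == cls).float()
        pred = preds[:, cls]
        dice = (2 * (pred * target).sum() + eps) / (pred.sum() + target.sum() + eps)
        loss = loss - dice.log()
    return loss / preds.shape[1]


torch.manual_seed(0)
logits = torch.randn(2, 3, 8, 8, requires_grad=True)  # hypothetical network outputs
targets = torch.randint(0, 3, (2, 8, 8))
probs = logits.softmax(dim=1)

loss_vec = soft_dice_loss_vectorized(probs, targets)
loss_loop = soft_dice_loss_loop(probs, targets)
print(torch.allclose(loss_vec, loss_loop))  # True

loss_vec.backward()  # works because the loss is a tensor, not a Python float
print(logits.grad.shape)  # torch.Size([2, 3, 8, 8])
```

The vectorized form trades the Python loop for one extra `(B, C, H, W)` tensor, which is usually a good deal on a GPU when the number of classes is small.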
