Experimenting with different loss functions and re-implementing them, checking each re-implementation against PyTorch's built-in version.

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

#### Binary Cross-Entropy / Log Loss

Binary cross-entropy (BCE), also called log loss, is a loss function used in **binary classification problems**. It quantifies the difference between the actual class labels (0 or 1) and the predicted probabilities output by the model: the lower the BCE value, the better the model’s predictions align with the true labels.

When the predicted probability is close to the actual label, the BCE value is low; when it deviates significantly from the label, the value is high. Because of the logarithm, the penalty grows without bound as the probability assigned to the true class approaches 0, so confident wrong predictions are penalized far more heavily than mildly wrong ones.
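To make that asymmetry concrete, the sketch below evaluates the per-sample BCE term $-[y \log p + (1-y)\log(1-p)]$ for a positive example ($y = 1$) at several predicted probabilities. The helper name `bce_per_sample` and the probabilities are illustrative choices, not outputs of any model.

```python
import numpy as np

def bce_per_sample(y, p):
    # Per-sample binary cross-entropy: -[y*log(p) + (1 - y)*log(1 - p)]
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

y = 1.0  # true label is the positive class
probs = np.array([0.99, 0.9, 0.5, 0.1, 0.01])  # hypothetical predicted P(y = 1)

for p, loss in zip(probs, bce_per_sample(y, probs)):
    print(f"p = {p:.2f} -> loss = {loss:.3f}")
```

Going from p = 0.5 to p = 0.1 adds about 1.6 to the loss, and going from 0.1 to 0.01 adds another 2.3: the more confident the wrong prediction, the steeper the penalty.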

In [3]:
y_true = torch.tensor([0, 1, 1, 1], dtype=torch.float32)
y_pred = torch.tensor([0.1, 0.9, 0.8, 0.3], dtype=torch.float32) # model’s output is a probability between 0 and 1

y_true_np = y_true.numpy()
y_pred_np = y_pred.numpy()

In [4]:
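def bce(y_true, y_pred):
    # eps must survive float32 rounding: 1 - 1e-9 rounds to exactly 1.0, which would let
    # log(1 - y_pred) become log(0). 1 - 1e-7 stays below 1.0 in float32.
    # (PyTorch instead clamps log outputs at -100, so the two can differ for predictions
    # within about 1e-7 of 0 or 1.)
    eps = 1e-7
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))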

loss = F.binary_cross_entropy(y_pred, y_true)
loss_bce = bce(y_true_np, y_pred_np)

assert np.allclose(loss, loss_bce), "Incorrect implementation"
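The clipping epsilon in a NumPy BCE has to be large enough to survive float32 rounding. Near 1.0, consecutive float32 values are about 6e-8 apart, so `1 - 1e-9` rounds to exactly `1.0` and clipping to that bound cannot keep `log(1 - p)` finite, whereas `1 - 1e-7` stays below 1.0. PyTorch's `F.binary_cross_entropy` takes a different route: its documentation states that it clamps log outputs to be at least -100, so a fully confident wrong prediction yields a loss of 100 rather than `inf`. A small check of both points, with the probe values chosen only for illustration:

```python
import numpy as np
import torch
import torch.nn.functional as F

# float32 cannot represent 1 - 1e-9; it rounds to exactly 1.0
assert np.float32(1 - 1e-9) == np.float32(1.0)

# 1 - 1e-7 survives the cast, so log(1 - p) stays finite after clipping
p_safe = np.float32(1 - 1e-7)
print(p_safe < 1.0, np.log(np.float32(1.0) - p_safe))

# PyTorch clamps the log at -100 instead of clipping the input
loss = F.binary_cross_entropy(torch.tensor([1.0]), torch.tensor([0.0]))
print(loss.item())
```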

#### Mean Squared Error (MSE) / L2 Loss

Mean squared error quantifies the error between a model's predictions and the target values as the average of the squared differences between them. It is commonly used for **regression tasks**, particularly when larger errors should be penalized more heavily: squaring makes each error's contribution grow quadratically with its size.
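Because each error is squared, an error ten times larger contributes a hundred times more to MSE. The residuals below are made-up values chosen only to show that scaling.

```python
import numpy as np

errors = np.array([0.1, 1.0, 10.0])  # hypothetical residuals (y_true - y_pred)
squared = errors ** 2                # per-sample contribution to MSE

for e, s in zip(errors, squared):
    print(f"error = {e:5.1f} -> squared error = {s:7.2f}")
```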

In [5]:
def mse(y_true, y_pred): 
    return np.mean((y_true - y_pred) ** 2)

loss = F.mse_loss(y_pred, y_true)
loss_mse = mse(y_true_np, y_pred_np)

assert np.allclose(loss, loss_mse), "Incorrect implementation"

#### Mean Absolute Error (MAE) / L1 Loss

Mean absolute error is used in **regression tasks** and is the average of the absolute differences between a model's predicted values and the actual target values. Unlike MSE, MAE does not square the differences, so each error contributes in proportion to its magnitude rather than quadratically. As a result, a few large errors (outliers) dominate MAE far less than they dominate MSE, which makes MAE less sensitive to outliers.
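One way to see the difference in outlier sensitivity is to add a single large error to otherwise small ones and compare how much each loss moves. The targets and predictions here are illustrative values only; the sketch uses PyTorch's `F.mse_loss` and `F.l1_loss`, the same built-ins this notebook checks its re-implementations against.

```python
import torch
import torch.nn.functional as F

target = torch.tensor([1.0, 2.0, 3.0, 4.0])
pred_clean = torch.tensor([1.1, 1.9, 3.2, 3.8])     # every error is at most 0.2
pred_outlier = torch.tensor([1.1, 1.9, 3.2, 14.0])  # last prediction is off by 10

results = {}
for name, pred in [("clean", pred_clean), ("outlier", pred_outlier)]:
    mse_val = F.mse_loss(pred, target).item()
    mae_val = F.l1_loss(pred, target).item()
    results[name] = (mse_val, mae_val)
    print(f"{name:8s} MSE = {mse_val:.3f}  MAE = {mae_val:.3f}")
```

With one outlier, MSE rises from about 0.025 to about 25 (roughly a thousandfold), while MAE rises from 0.15 to 2.6.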

In [6]:
def mae(y_true, y_pred): 
    return np.mean(np.abs(y_true - y_pred))

loss = F.l1_loss(y_pred, y_true)
loss_mae = mae(y_true_np, y_pred_np)

assert np.allclose(loss, loss_mae), "Incorrect implementation"

#### Dice Loss

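Dice loss is used in **image segmentation tasks** and is especially common for segmenting both 2D and 3D medical images. It is based on the Dice coefficient, which measures the overlap between the predicted mask and the ground-truth mask; the loss is 1 minus that coefficient, so perfect overlap gives a loss of 0.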
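The Dice coefficient between a predicted mask $A$ and a ground-truth mask $B$ is $2|A \cap B| / (|A| + |B|)$, and Dice loss is $1 - \text{Dice}$. For binary masks, the intersection is the number of pixels that are 1 in both masks, which equals the sum of their element-wise product. A hand-checkable example on two hypothetical 1D masks:

```python
import numpy as np

pred_mask = np.array([1, 1, 0, 0, 1, 0])  # hypothetical predicted mask
true_mask = np.array([1, 0, 0, 0, 1, 1])  # hypothetical ground-truth mask

# Pixels that are 1 in both masks (positions 0 and 4)
intersection = np.sum(pred_mask * true_mask)

# Dice = 2|A ∩ B| / (|A| + |B|) = 2*2 / (3 + 3)
dice = 2 * intersection / (pred_mask.sum() + true_mask.sum())
dice_loss_value = 1 - dice

print(intersection, dice, dice_loss_value)
```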

In [81]:
class DiceLoss(nn.Module):
    """Soft Dice loss. Expects `preds` as probabilities in [0, 1] (e.g. after a sigmoid)
    with the same shape as `targets`: [N, C, H, W] or [N, C, D, H, W]."""

    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth  # avoids division by zero when both masks are empty

    def forward(self, preds, targets):
        num = preds.size(0)  # batch size

        # Flatten each sample into shape [num, N], where N = C × H × W (2D) or C × D × H × W (3D)
        preds = preds.contiguous().view(num, -1)
        targets = targets.contiguous().view(num, -1)

        # Per-sample Dice coefficient: (2|A ∩ B| + smooth) / (|A| + |B| + smooth)
        intersection = (preds * targets).sum(dim=1)
        dice = (2. * intersection + self.smooth) / \
               (preds.sum(dim=1) + targets.sum(dim=1) + self.smooth)

        loss = 1 - dice  # per-sample Dice loss
        return loss.mean()  # averaged over the batch

dice_loss = DiceLoss()

In [83]:
# [2, 1, 4, 4]: batch of 2, single channel, 4x4
pred_2d = torch.tensor([[[[0, 1, 0, 1],
                            [1, 1, 0, 0],
                            [0, 0, 1, 1],
                            [1, 0, 1, 0]]],
                          
                          [[[0, 1, 1, 0],
                            [1, 0, 0, 1],
                            [0, 1, 1, 0],
                            [1, 1, 0, 0]]]], dtype=torch.float32)

target_2d = torch.tensor([[[[0, 1, 0, 1],
                              [1, 1, 0, 0],
                              [0, 0, 1, 1],
                              [1, 0, 1, 0]]],
                            
                            [[[1, 1, 1, 0],
                              [1, 0, 0, 1],
                              [0, 1, 0, 0],
                              [1, 1, 0, 0]]]], dtype=torch.float32)

print("pred_2d shape: ", pred_2d.shape)
print("target_2d shape: ", target_2d.shape)

# [1, 1, 2, 3, 3]: batch of 1, single channel, depth=2, 3x3
# (N, C, D, H, W) = (batch_size=1, channels=1, depth=2 slices, height=3, width=3)
pred_3d = torch.tensor([[[[[0, 1, 1],
                          [1, 0, 0],
                          [0, 1, 0]],
                         
                         [[1, 1, 0],
                          [0, 1, 0],
                          [1, 0, 1]]]]], dtype=torch.float32)

target_3d = torch.tensor([[[[[0, 1, 1],
                            [1, 0, 0],
                            [0, 1, 0]],
                           
                           [[1, 1, 0],
                            [1, 1, 0],
                            [1, 0, 1]]]]], dtype=torch.float32)

print("pred_3d shape: ", pred_3d.shape)
print("target_3d shape: ", target_3d.shape)

loss_value_2d = dice_loss(pred_2d, target_2d)
loss_value_3d = dice_loss(pred_3d, target_3d)

print(loss_value_2d, loss_value_3d)

pred_2d shape:  torch.Size([2, 1, 4, 4])
target_2d shape:  torch.Size([2, 1, 4, 4])
pred_3d shape:  torch.Size([1, 1, 2, 3, 3])
target_3d shape:  torch.Size([1, 1, 2, 3, 3])
tensor(0.0625) tensor(0.0526)
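The cells above feed hard 0/1 masks, which makes the arithmetic easy to check. During training, predictions are typically soft probabilities, for instance a sigmoid applied to raw network outputs, so the loss stays differentiable. A minimal sketch under that assumption, with random logits standing in for a model's output; `soft_dice_loss` is a hypothetical functional equivalent of the `DiceLoss` class, written out so the cell runs on its own:

```python
import torch

def soft_dice_loss(probs, targets, smooth=1e-6):
    # Same computation as DiceLoss.forward: flatten per sample, 1 - Dice, mean over batch
    num = probs.size(0)
    probs = probs.reshape(num, -1)
    targets = targets.reshape(num, -1)
    intersection = (probs * targets).sum(dim=1)
    dice = (2.0 * intersection + smooth) / (probs.sum(dim=1) + targets.sum(dim=1) + smooth)
    return (1 - dice).mean()

torch.manual_seed(0)
logits = torch.randn(2, 1, 4, 4, requires_grad=True)  # stand-in for raw model outputs
targets = (torch.rand(2, 1, 4, 4) > 0.5).float()      # random binary ground-truth masks

probs = torch.sigmoid(logits)  # soft predictions in (0, 1)
loss_soft = soft_dice_loss(probs, targets)
loss_soft.backward()           # gradients flow back to the logits

print(loss_soft.item(), logits.grad.shape)
```

Since each soft prediction lies in (0, 1) and each target in {0, 1}, the intersection never exceeds either sum, so the loss stays within [0, 1] just as it does for hard masks.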
