# Loss Functions

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
import math

## How to define & extend a loss in PyTorch

#### (A) Write a callable or an nn.Module

In [None]:
def my_loss(pred, target):
    """Return the mean absolute error (L1 loss) between ``pred`` and ``target``.

    The body uses only differentiable tensor operations, so autograd derives
    the backward pass automatically; ``pred`` is expected to require grad.
    """
    residual = pred - target
    return residual.abs().mean()

class MyLoss(nn.Module):
    """``nn.Module`` version of the mean absolute error (L1) loss."""

    def forward(self, pred, target):
        error = pred - target
        return error.abs().mean()

Because every operation in the body is a differentiable PyTorch primitive that autograd records in the computation graph, you do not need to implement `.backward()` yourself.

#### (B) Use it exactly like a built-in loss

In [None]:
# Usage demo: assumes `logits` and `masks` were created in an earlier cell (not shown here).
criterion = MyLoss()
loss = criterion(logits, masks)      # ← 0-dim tensor: MyLoss reduces with .mean()
loss.backward()                      # autograd back-propagates through the ops inside MyLoss.forward

#### (C) Augment a built-in loss
You can keep the original value and add any penalty:

In [None]:
# Usage demo: assumes `logits`, `masks` and `model` were created in an earlier
# cell (not shown here).
bce = nn.BCEWithLogitsLoss()
loss = bce(logits, masks)

# Explicit L2 penalty: sum of squared entries of every tensor yielded by
# model.parameters() (no parameters are excluded).
l2_penalty = 0.0
for p in model.parameters():
    l2_penalty += p.pow(2).sum()

loss = loss + 1e-4 * l2_penalty      # combined objective; 1e-4 scales the penalty

## Loss Functions

### Segmentation (UNet)
Four widely-used segmentation losses:
1. Dice Loss
2. Focal Loss
3. Tversky Loss
4. BCE

In [None]:
# ------------------------------------------------------------------ #
class DiceLoss(nn.Module):
    """Soft Dice loss on raw logits (same input convention as ``BCEWithLogitsLoss``).

    The Dice coefficient is computed per sample over every non-batch
    dimension, averaged over the batch, and ``1 - mean Dice`` is returned.
    Any input of shape ``(N, ...)`` with at least two dimensions is accepted,
    e.g. ``(N, C, H, W)``, ``(N, H, W)`` or ``(N, C, D, H, W)``.

    Args:
        eps: Smoothing term added to numerator and denominator; keeps the
            ratio defined (and equal to 1) when prediction and target are
            both empty.
    """
    def __init__(self, eps: float = 1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        """Return ``1 - mean Dice`` as a 0-dim tensor.

        Args:
            logits: Raw scores of shape ``(N, ...)``; sigmoid is applied here.
            targets: Binary (or soft) masks with the same shape as ``logits``.
        """
        # Flatten everything after the batch dim so the reduction no longer
        # hard-codes a 4-D layout; a 1-D input still raises IndexError here.
        probs = torch.sigmoid(logits).flatten(1)            # (N, rest)
        targets = targets.flatten(1)                         # (N, rest)
        num = 2 * (probs * targets).sum(dim=1) + self.eps
        den = probs.sum(dim=1) + targets.sum(dim=1) + self.eps
        dice = num / den
        return 1 - dice.mean()                               # minimise (1 - dice)
    
    
# ------------------------------------------------------------------ #
class FocalLoss(nn.Module):
    """
    γ: focusing parameter; α: optional class-balancing weight (scalar or tensor)
    """
    def __init__(self, gamma: float = 2.0, alpha: float | None = 0.25,
                 reduction: str = "mean"):
        super().__init__()
        self.gamma, self.alpha, self.reduction = gamma, alpha, reduction

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        pt    = torch.where(targets == 1, probs, 1 - probs)   # p_t
        log_pt = torch.log(pt.clamp(min=1e-6))
        loss  = -(1 - pt) ** self.gamma * log_pt
        if self.alpha is not None:
            at = torch.where(targets == 1, self.alpha, 1 - self.alpha)
            loss = at * loss
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss                                           # 'none'
    
    

# ------------------------------------------------------------------ #
class TverskyLoss(nn.Module):
    """Tversky loss on raw logits: 1 - mean(TP / (TP + α·FP + β·FN)).

    TP/FP/FN are soft (probability-weighted) counts taken per sample over
    every non-batch dimension, so any ``(N, ...)`` input with at least two
    dimensions works. α weights false positives and β false negatives:
    α = β = 0.5 gives Dice; α = β = 1 gives IoU (Jaccard).

    Args:
        alpha: Weight of false positives.
        beta: Weight of false negatives.
        eps: Smoothing term keeping the ratio defined for empty masks.
    """
    def __init__(self, alpha: float = 0.5, beta: float = 0.5,
                 eps: float = 1e-6):
        super().__init__()
        self.alpha, self.beta, self.eps = alpha, beta, eps

    def forward(self, logits, targets):
        # Flatten everything after the batch dim instead of hard-coding a
        # 4-D layout; targets must have the same shape as logits.
        probs = torch.sigmoid(logits).flatten(1)            # (N, rest)
        targets = targets.flatten(1)                         # (N, rest)
        tp = (probs * targets).sum(dim=1)
        fp = (probs * (1 - targets)).sum(dim=1)
        fn = ((1 - probs) * targets).sum(dim=1)
        tversky = (tp + self.eps) / (tp + self.alpha * fp + self.beta * fn + self.eps)
        return 1 - tversky.mean()
    

# ------------------------------------------------------------------ #
class CustomBCEWithLogitsLoss(nn.Module):
    """
    Binary cross-entropy computed directly from raw logits.

    Uses the overflow-free identity
        BCE(x, t) = max(x, 0) - x * t + log(1 + exp(-|x|)),
    so exp() is only ever applied to non-positive arguments.

    Args
    ----
    reduction: 'mean' | 'sum' | 'none'
    """
    def __init__(self, reduction: str = "mean"):
        super().__init__()
        if reduction not in ("mean", "sum", "none"):
            raise ValueError("reduction must be 'mean', 'sum', or 'none'")
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor):
        # targets: float 0/1 tensor with the same shape as logits
        log_term = torch.log1p(torch.exp(-logits.abs()))
        elementwise = logits.clamp(min=0.0) - logits * targets + log_term
        reducer = {"mean": torch.mean, "sum": torch.sum}.get(self.reduction)
        # anything other than 'mean'/'sum' returns the unreduced tensor ('none')
        return elementwise if reducer is None else reducer(elementwise)

All of these losses share the same `criterion(logits, targets)` call signature, so they can be used interchangeably:

In [None]:
# Usage demo: assumes `outputs`, `masks`, `pred_logits` and `target_masks`
# were created in an earlier cell (not shown here).
criterion = DiceLoss()          # or FocalLoss(), TverskyLoss()
loss = criterion(outputs, masks)

# for U-Net / binary segmentation
criterion = CustomBCEWithLogitsLoss()
loss = criterion(pred_logits, target_masks)   # takes raw logits; no sigmoid beforehand

### ResNet
Four widely-used classification losses:
1. Focal Loss
2. Label-Smoothing Cross-Entropy
3. Multi-Class Hinge (SVM) Loss
4. Cross-Entropy

In [None]:
# ------------------------------------------------------------------ #
class FocalLossMC(nn.Module):
    """
    Multiclass focal loss working on raw logits.
    gamma: focusing parameter; alpha: weight per class or scalar
    """
    def __init__(self, gamma: float = 2.0,
                 alpha: torch.Tensor | float | None = None,
                 reduction: str = "mean"):
        super().__init__()
        if isinstance(alpha, float):
            alpha = torch.tensor([alpha])
        self.gamma, self.register_buffer('alpha', alpha if alpha is not None else None)
        self.reduction = reduction

    def forward(self, logits, targets):
        log_probs = F.log_softmax(logits, dim=1)              # (N,C)
        probs = log_probs.exp()
        idx   = targets.unsqueeze(1)                         # (N,1)
        pt    = probs.gather(1, idx).squeeze(1)              # (N,)
        log_pt = log_probs.gather(1, idx).squeeze(1)
        loss = - (1 - pt) ** self.gamma * log_pt
        if self.alpha is not None:
            at = self.alpha.gather(0, targets)
            loss = at * loss
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss

# ------------------------------------------------------------------ #
class LabelSmoothingCrossEntropy(nn.Module):
    """
    Cross-entropy against a label-smoothed target distribution.

    The one-hot target is blended with a uniform distribution over C classes:
        loss = (1 - ε) · (-log p_y) + ε · mean_c(-log p_c)

    ε: smoothing factor; reduction: 'mean'|'sum'|'none'
    """
    def __init__(self, eps: float = 0.1, reduction: str = "mean"):
        super().__init__()
        self.eps = eps
        self.reduction = reduction

    def forward(self, logits, targets):
        log_p = F.log_softmax(logits, dim=1)
        num_classes = logits.size(1)

        # uniform part: ε/C times the sum of -log p over all classes
        uniform_term = log_p.sum(dim=1).neg() * (self.eps / num_classes)
        # one-hot part: negative log-likelihood of the true class
        true_nll = log_p.gather(1, targets.unsqueeze(1)).squeeze(1).neg()
        per_sample = uniform_term + (1 - self.eps) * true_nll

        if self.reduction == "sum":
            return per_sample.sum()
        if self.reduction == "mean":
            return per_sample.mean()
        return per_sample

# ------------------------------------------------------------------ #
class MultiClassHingeLoss(nn.Module):
    """
    Crammer-Singer multiclass hinge loss on raw scores:
        L = mean( max(0, max_{j≠y} (margin + s_j - s_y)) )

    Args:
        reduction: 'mean' | 'sum' | 'none'.
        margin: Required gap between the true-class score and every other
            class score (1.0 in the standard formulation).
    """
    def __init__(self, reduction: str = "mean", margin: float = 1.0):
        super().__init__()
        if reduction not in {"mean", "sum", "none"}:
            raise ValueError("reduction must be 'mean', 'sum', or 'none'")
        self.reduction = reduction
        self.margin = margin

    def forward(self, logits, targets):
        """logits: (N, C) scores; targets: (N,) int64 class indices."""
        if logits.dim() != 2:
            raise ValueError("logits must have shape (N, C)")
        idx = targets.view(-1, 1)
        s_y = logits.gather(1, idx)                          # (N, 1) true-class score
        margins = logits + self.margin - s_y                 # broadcast over classes
        # Zero the true-class entry so it never wins the max; combined with
        # the clamp below this realises max(0, max_{j≠y} ...).
        margins = margins.scatter(1, idx, 0.0)
        loss = margins.clamp(min=0.0).max(dim=1).values      # (N,)
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss
    

# ------------------------------------------------------------------ #
class CustomCrossEntropyLoss(nn.Module):
    """
    Plain multi-class cross-entropy computed from raw logits.

    Applies log-softmax over dim 1 (the class dim, as in PyTorch's CE) and
    returns the negative log-probability of each target class.

    Args
    ----
    reduction: 'mean' | 'sum' | 'none'
    """
    def __init__(self, reduction: str = "mean"):
        super().__init__()
        allowed = ("mean", "sum", "none")
        if reduction not in allowed:
            raise ValueError("reduction must be 'mean', 'sum', or 'none'")
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor):
        # logits: (N, C, ...); targets: (N, ...) int64 class indices
        index = targets.unsqueeze(1)
        picked = F.log_softmax(logits, dim=1).gather(1, index)
        per_sample = picked.squeeze(1).neg()
        if self.reduction == "mean":
            return per_sample.mean()
        if self.reduction == "sum":
            return per_sample.sum()
        return per_sample                                   # 'none'

In [None]:
# Usage demo: assumes `logits`, `labels`, `pred_logits` and `target_labels`
# were created in an earlier cell (not shown here); the labels are presumably
# int64 class indices, since both losses index log-probabilities with .gather.
criterion = LabelSmoothingCrossEntropy(eps=0.05)
loss = criterion(logits, labels)
loss.backward()

# for ResNet / multi-class classification
criterion = CustomCrossEntropyLoss()
loss = criterion(pred_logits, target_labels)