# Loss Functions Deep Dive

Loss functions turn objectives into gradients. Understanding their behaviour is essential before we tackle attention mechanisms, transformers, and Mixture-of-Experts models.

## Learning Objectives

- Match task types (regression, classification, probabilistic) to appropriate losses.
- Interpret loss landscapes to anticipate optimization behaviour.
- Implement masking and custom losses where built-ins fall short.
- Build intuition for the losses used in attention and transformer architectures.

## Loss Landscape Overview

| Task | Common Loss | Notes |
|------|-------------|-------|
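| Regression | `MSELoss`, `L1Loss`, `SmoothL1Loss` | Choose L1 for robustness to outliers and MSE for gradients that shrink smoothly near zero error; `SmoothL1Loss` is quadratic for small errors and linear for large ones. |
| Binary classification | `BCEWithLogitsLoss` | Fuses sigmoid and BCE into one numerically stable operation. |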
| Multi-class classification | `CrossEntropyLoss` | Expects raw logits and class indices. |
| Multi-label | `BCEWithLogitsLoss` | Independent sigmoid per label. |
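| Probabilistic | `KLDivLoss`, `NLLLoss` | Inputs must be log-probabilities; `NLLLoss` targets are class indices, `KLDivLoss` targets are probabilities (or log-probabilities with `log_target=True`). |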
| Sequence modeling | `CrossEntropyLoss` + masking | Ignore padding tokens to avoid skewed gradients. |

Whenever you introduce a new objective, fit it into this table first.
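The "Probabilistic" row is where input formats most often go wrong. A minimal sketch with made-up random logits: it checks that `nll_loss` on `log_softmax` outputs reproduces `cross_entropy`, and that `kl_div` gives the same value whether the target is passed as probabilities (the default) or as log-probabilities with `log_target=True`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 5)            # 4 examples, 5 classes (made-up values)
labels = torch.tensor([1, 0, 4, 2])

# NLLLoss expects log-probabilities, not raw logits
log_probs = F.log_softmax(logits, dim=-1)
nll = F.nll_loss(log_probs, labels)
ce = F.cross_entropy(logits, labels)
print(f"NLL on log-softmax: {nll.item():.6f}  cross_entropy: {ce.item():.6f}")

# KLDivLoss: the input is log-probabilities; the target is probabilities by default...
teacher_logits = torch.randn(4, 5)
kl_probs = F.kl_div(log_probs, F.softmax(teacher_logits, dim=-1), reduction="batchmean")
# ...or log-probabilities when log_target=True
kl_logs = F.kl_div(log_probs, F.log_softmax(teacher_logits, dim=-1),
                   reduction="batchmean", log_target=True)
print(f"KL (prob target): {kl_probs.item():.6f}  KL (log target): {kl_logs.item():.6f}")
```

Passing raw logits to `nll_loss` or probabilities as the `kl_div` input does not raise an error; it simply computes a different, meaningless number.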

```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np

# Toy regression predictions and targets
preds_reg = torch.tensor([2.5, 0.0, -1.5])
targets_reg = torch.tensor([3.0, -0.5, -1.0])
mse = F.mse_loss(preds_reg, targets_reg)  # mean of squared errors
mae = F.l1_loss(preds_reg, targets_reg)   # mean of absolute errors
print(f"MSE: {mse.item():.4f}")
print(f"MAE: {mae.item():.4f}")

# Per-element loss contribution as a function of the error
diffs = np.linspace(-3, 3, 200)
plt.plot(diffs, diffs ** 2, label="MSE contribution")
plt.plot(diffs, np.abs(diffs), label="MAE contribution")
plt.xlabel("Error")
plt.ylabel("Loss contribution")
plt.legend()
plt.grid(True)
plt.show()
```
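The advice to pick L1 for outliers and MSE for smooth gradients comes down to the gradient each loss sends back. A small autograd sketch on made-up errors, one of them an outlier: MSE's gradient grows linearly with the error, L1's is the constant sign of the error, and `smooth_l1_loss` (default `beta=1.0`) behaves like MSE below `beta` and like L1 above it. `reduction="sum"` is used only so the printed gradients are not divided by the number of elements.

```python
import torch
import torch.nn.functional as F

targets = torch.zeros(3)
errors = torch.tensor([0.5, -0.2, 8.0])  # the last element is an outlier

grads = {}
for name, loss_fn in [("MSE", F.mse_loss), ("L1", F.l1_loss), ("SmoothL1", F.smooth_l1_loss)]:
    preds = errors.clone().requires_grad_(True)  # targets are zero, so preds equal the errors
    loss = loss_fn(preds, targets, reduction="sum")
    loss.backward()
    grads[name] = preds.grad
    print(f"{name:8s} grads: {[round(g, 3) for g in preds.grad.tolist()]}")
```

The outlier dominates the MSE gradient (16 versus at most 1 for the other elements), while L1 and Smooth L1 cap its influence.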


### Classification Essentials

`CrossEntropyLoss` expects raw logits and integer class labels. Do not apply softmax manually; the loss combines log-softmax and negative log-likelihood in a single stable operation.

```python
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 2.3, 1.0]])
targets = torch.tensor([0, 2])
loss = F.cross_entropy(logits, targets)
print(f"Cross entropy: {loss:.4f}")
```
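To make "combines log-softmax and negative log-likelihood" concrete, and to show why applying softmax first is a bug, here is a short check on the same toy logits. Feeding softmax probabilities into `cross_entropy` silently applies softmax a second time, which flattens the distribution and yields a different loss without any error being raised.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 2.3, 1.0]])
targets = torch.tensor([0, 2])

fused = F.cross_entropy(logits, targets)
two_step = F.nll_loss(F.log_softmax(logits, dim=-1), targets)
double_softmax = F.cross_entropy(F.softmax(logits, dim=-1), targets)  # common mistake

print(f"cross_entropy:          {fused.item():.4f}")
print(f"log_softmax + nll_loss: {two_step.item():.4f}")
print(f"softmax applied twice:  {double_softmax.item():.4f}")
```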


### Masking Padding Tokens

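Batched sequences are padded to a common length, and those padded positions should not count toward the loss. Set their targets to `ignore_index` (default `-100`): `cross_entropy` then excludes them from both the sum and the averaging denominator, so padding does not skew gradient updates.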

```python
logits = torch.tensor([[2.0, 0.1, -1.0], [1.2, 1.5, 0.3]])
targets = torch.tensor([0, -100])  # -100 marks the second position as padding
masked_loss = F.cross_entropy(logits, targets, ignore_index=-100)
print(f"Masked loss: {masked_loss:.4f}")
```
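Real language-model outputs have shape `(batch, seq_len, vocab)`, whereas `cross_entropy` expects the class dimension second. A minimal sketch with random logits and a made-up padding layout: flatten batch and time, let `ignore_index` drop the padded positions, and confirm the result equals a manual masked mean, i.e. padding is excluded from both the numerator and the token count.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, vocab = 2, 4, 6
logits = torch.randn(batch, seq_len, vocab)
targets = torch.tensor([[1, 3, 5, -100],      # last position is padding
                        [2, 0, -100, -100]])  # last two positions are padding

# Flatten so logits are (N, vocab) and targets are (N,)
loss = F.cross_entropy(logits.reshape(-1, vocab), targets.reshape(-1), ignore_index=-100)

# Manual equivalent: per-token loss, zero out padding, divide by the number of real tokens
mask = targets != -100
per_token = F.cross_entropy(logits.reshape(-1, vocab), targets.clamp(min=0).reshape(-1),
                            reduction="none").reshape(batch, seq_len)
manual = (per_token * mask).sum() / mask.sum()

print(f"ignore_index loss:  {loss.item():.6f}")
print(f"manual masked mean: {manual.item():.6f}")
```

An alternative to flattening is `F.cross_entropy(logits.transpose(1, 2), targets)`, which moves the vocabulary dimension into position 1.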


## Mini Task – Binary Cross-Entropy Stability

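Compute the binary cross-entropy loss manually (`torch.sigmoid` followed by `F.binary_cross_entropy`) and compare it with `F.binary_cross_entropy_with_logits`, the functional form of `BCEWithLogitsLoss`. For moderate logits the two agree; the fused version matters once logits are large enough in magnitude for the sigmoid to saturate.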

Attempt the starter cell before revealing the hidden solution.

```python
logits = torch.tensor([[1.2], [-0.7], [0.4]])
targets = torch.tensor([[1.0], [0.0], [1.0]])

# TODO: compute manual BCE (sigmoid + binary_cross_entropy) and stable BCE with logits
# print both values
```


```python
logits = torch.tensor([[1.2], [-0.7], [0.4]])
targets = torch.tensor([[1.0], [0.0], [1.0]])

manual = F.binary_cross_entropy(torch.sigmoid(logits), targets)
stable = F.binary_cross_entropy_with_logits(logits, targets)
print(f"Manual BCE: {manual.item():.6f}")
print(f"Stable BCE: {stable.item():.6f}")
```
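With the moderate logits above, both values match. The gap appears for a confidently wrong prediction: in float32, `torch.sigmoid(20.0)` rounds to exactly `1.0`, so `log(1 - p)` is `-inf`, which `binary_cross_entropy` clamps to `-100` (a documented safeguard in `BCELoss`). The fused version works from the logit directly and returns the correct value of about 20.

```python
import torch
import torch.nn.functional as F

# One confidently wrong prediction: large positive logit, negative label
logits = torch.tensor([20.0])
targets = torch.tensor([0.0])

probs = torch.sigmoid(logits)                                  # saturates to 1.0 in float32
manual = F.binary_cross_entropy(probs, targets)                # log(1 - 1.0) is clamped to -100
stable = F.binary_cross_entropy_with_logits(logits, targets)   # computed in logit space

print(f"sigmoid(20) == 1.0: {probs.item() == 1.0}")
print(f"Manual BCE: {manual.item():.4f}")
print(f"Stable BCE: {stable.item():.4f}")
```

The manual loss is both wrong in value and, because of the clamp, carries no useful gradient for this example.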


### Custom Loss Example (Focal Loss)

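Focal loss (Lin et al., 2017) scales each example's BCE by `(1 - p_t) ** gamma`, where `p_t` is the predicted probability of the true class, so confidently correct (easy) examples contribute little to the gradient. An optional weight `alpha` rebalances positive and negative examples. It was introduced for dense object detection, where easy background examples vastly outnumber objects, and can help with class imbalance more generally.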

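```python
def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # Per-element BCE; reduction="none" keeps one value per prediction
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    probs = torch.sigmoid(logits)
    # p_t: predicted probability of the true class
    pt = torch.where(targets == 1, probs, 1 - probs)
    # alpha_t: alpha for positives, 1 - alpha for negatives (Lin et al., 2017)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # (1 - p_t) ** gamma shrinks the loss of well-classified examples
    return (alpha_t * (1 - pt) ** gamma * bce).mean()

print(focal_loss(torch.tensor([[1.2], [-0.7], [0.4]]), torch.tensor([[1.0], [0.0], [1.0]])))
```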
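Two sanity checks follow from the definition. With `gamma=0` the modulating factor is 1, so focal loss reduces to `alpha_t`-weighted BCE; with `gamma=2`, an easy example (high `p_t`) is scaled down far more than a hard one. The sketch redefines the function so it runs standalone; the `alpha_t` weighting (`alpha` for positives, `1 - alpha` for negatives) follows Lin et al. (2017), and the extra `reduction` argument is an illustrative addition.

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0, reduction="mean"):
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    probs = torch.sigmoid(logits)
    pt = torch.where(targets == 1, probs, 1 - probs)          # probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)   # alpha for positives, 1 - alpha for negatives
    loss = alpha_t * (1 - pt) ** gamma * bce
    return loss.mean() if reduction == "mean" else loss

logits = torch.tensor([[4.0], [-0.5]])   # made-up: an easy positive and a hard positive
targets = torch.tensor([[1.0], [1.0]])

# Check 1: gamma=0 with alpha=0.5 equals half of plain BCE
focal0 = focal_loss(logits, targets, alpha=0.5, gamma=0.0)
half_bce = 0.5 * F.binary_cross_entropy_with_logits(logits, targets)
print(f"focal(gamma=0, alpha=0.5): {focal0.item():.6f}  0.5 * BCE: {half_bce.item():.6f}")

# Check 2: per-example ratio of focal loss to BCE with the defaults (alpha=0.25, gamma=2)
ratio = focal_loss(logits, targets, reduction="none") / F.binary_cross_entropy_with_logits(
    logits, targets, reduction="none")
print(f"easy example scale: {ratio[0].item():.6f}  hard example scale: {ratio[1].item():.6f}")
```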


## Comprehensive Exercise – Loss Selection Utility

Implement `choose_loss(task_type, **details)` that returns both the initialized loss function and a short justification. Cover regression, multi-class, multi-label, language modeling with padding, and knowledge distillation.

```python
def choose_loss(task_type: str, **details):
    # TODO: map task types to loss functions and explanations
    raise NotImplementedError

examples = [
    ("regression", {}),
    ("multiclass", {"num_classes": 5}),
    ("multilabel", {"num_labels": 3}),
    ("language_modeling", {"ignore_index": -100}),
    ("distillation", {}),
]

for task, kwargs in examples:
    loss_fn, reason = choose_loss(task, **kwargs)
    print(task, type(loss_fn), reason)
```


```python
def choose_loss(task_type: str, **details):
    task_type = task_type.lower()
    if task_type == "regression":
        return torch.nn.MSELoss(), "Squared error supplies smooth gradients for regression."
    if task_type == "multiclass":
        if "num_classes" not in details:
            raise ValueError("Specify num_classes for multiclass classification.")
        return torch.nn.CrossEntropyLoss(), "CrossEntropy combines log-softmax with NLL."
    if task_type == "multilabel":
        return torch.nn.BCEWithLogitsLoss(), "Independent sigmoid heads support multi-label targets."
    if task_type == "language_modeling":
        ignore_index = details.get("ignore_index", -100)
        return torch.nn.CrossEntropyLoss(ignore_index=ignore_index), "Ignore padding tokens when computing loss."
    if task_type == "distillation":
        return (
            lambda student, teacher: F.kl_div(
                F.log_softmax(student, dim=-1),
                F.softmax(teacher, dim=-1),
                reduction="batchmean",
            ),
            "KL divergence aligns student distributions with teacher outputs.",
        )
    raise ValueError(f"Unknown task type: {task_type}")

examples = [
    ("regression", {}),
    ("multiclass", {"num_classes": 5}),
    ("multilabel", {"num_labels": 3}),
    ("language_modeling", {"ignore_index": -100}),
    ("distillation", {}),
]

for task, kwargs in examples:
    loss_fn, reason = choose_loss(task, **kwargs)
    print(task, type(loss_fn), reason)
```
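The `distillation` branch compares the two softmax distributions at temperature 1. A common variant from Hinton et al. (2015) softens both with a temperature `T`, multiplies the KL term by `T ** 2` so its gradient scale stays comparable as `T` changes, and mixes it with ordinary cross-entropy on the hard labels. A sketch under those assumptions; `distillation_loss`, `T=2.0`, and the mixing weight `alpha=0.5` are hypothetical, illustrative choices rather than part of the exercise.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, T=2.0, alpha=0.5):
    """Hypothetical helper: soft-target KL at temperature T mixed with hard-label cross-entropy."""
    soft_student = F.log_softmax(student_logits / T, dim=-1)
    soft_teacher = F.softmax(teacher_logits / T, dim=-1)
    # Scaling by T**2 keeps the soft-target gradient magnitude roughly independent of T
    kd = F.kl_div(soft_student, soft_teacher, reduction="batchmean") * T ** 2
    ce = F.cross_entropy(student_logits, labels)
    return alpha * kd + (1 - alpha) * ce

torch.manual_seed(0)
student = torch.randn(8, 10, requires_grad=True)  # made-up student logits
teacher = torch.randn(8, 10)                      # made-up teacher logits (no gradient)
labels = torch.randint(0, 10, (8,))

loss = distillation_loss(student, teacher, labels)
loss.backward()
print(f"Distillation loss: {loss.item():.4f}")
print(f"Student grad shape: {tuple(student.grad.shape)}")
```

When the student exactly matches the teacher, the KL term is zero and only the hard-label term remains.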


## Further Reading

- PyTorch Loss Functions: https://pytorch.org/docs/stable/nn.html#loss-functions
- Lin et al. (2017) – Focal Loss for Dense Object Detection
- Label smoothing strategies in transformer-based language models