# Topic 4: Loss Functions - A Comprehensive Guide

## Learning Objectives

By the end of this notebook, you will:
- Understand WHY loss functions are the "North Star" of neural network training
- Know WHEN to use each major loss function (with real-world scenarios)
- Understand the mathematics behind common loss functions
- Implement custom loss functions
- Avoid common loss function mistakes
- Combine multiple loss functions

---

## 1. The Big Picture: Why Loss Functions?

### The Training Problem

You've built a neural network, but how do you make it better? You need:
1. A way to **measure** how wrong your predictions are
2. A **single number** to optimize (multiple goals must be combined into one scalar)
3. A function that's **differentiable** almost everywhere (so we can compute gradients)

**This is the loss function!**

### Loss Function = Objective Function = Cost Function

The terms are often used interchangeably, though some authors distinguish them:
- **Loss**: Error on a single example
- **Cost**: Average loss over the dataset (or batch)
- **Objective**: The quantity being optimized (here, minimized)

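The loss/cost distinction can be seen directly in PyTorch through the `reduction` argument. The following is a minimal sketch with made-up numbers chosen only for illustration:

```python
import torch
import torch.nn as nn

# Illustrative values only (not from a real model)
predictions = torch.tensor([2.0, 4.0, 7.0])
targets = torch.tensor([1.0, 4.0, 5.0])

# reduction='none' keeps one loss value per example
per_example_loss = nn.MSELoss(reduction='none')(predictions, targets)

# The default reduction='mean' averages them into a single scalar (the "cost")
cost = nn.MSELoss()(predictions, targets)

print(f"Per-example losses: {per_example_loss}")
print(f"Cost (mean):        {cost.item():.4f}")
```

The scalar is what `backward()` is called on; the per-example values are mainly useful for inspection or for custom weighting.
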
### The Training Loop (Preview)

```python
for epoch in range(num_epochs):
    # 0. Reset gradients left over from the previous step
    #    (PyTorch accumulates gradients by default)
    optimizer.zero_grad()

    # 1. Forward pass: make predictions
    predictions = model(inputs)
    
    # 2. Compute loss: how wrong are we?
    loss = loss_function(predictions, targets)
    
    # 3. Backward pass: compute gradients
    loss.backward()
    
    # 4. Update weights
    optimizer.step()
```

**Loss is the bridge between predictions and learning!**

### Key Principle: Choose Loss Based on Problem Type

| Problem Type | Loss Function | Why |
|-------------|---------------|-----|
| Binary Classification | BCELoss, BCEWithLogitsLoss | Probability of two classes |
| Multi-class Classification | CrossEntropyLoss | Probability distribution over classes |
| Regression | MSELoss, L1Loss, SmoothL1Loss | Minimize distance to target |
| Ranking/Similarity | TripletMarginLoss, CosineEmbeddingLoss | Learn embeddings |

We'll explore the first three in detail!

In [None]:
# Setup
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

print(f"PyTorch version: {torch.__version__}")

# Set random seed
torch.manual_seed(42)
np.random.seed(42)

---

## 2. Regression Loss Functions

### When to Use Regression?

Predicting **continuous values**:
- House prices
- Temperature
- Stock prices
- Age estimation
- Object coordinates (bounding boxes)

### 2.1 Mean Squared Error (MSE) Loss

**Formula**: $\text{MSE} = \frac{1}{n}\sum_{i=1}^{n}(y_i - \hat{y}_i)^2$

**When to use**:
- Standard regression problems
- When large errors should be penalized heavily (due to squaring)
- When your target values are continuous and unbounded

**Pros**:
- Smooth, differentiable everywhere
- Heavily penalizes large errors
- Well-studied, standard choice

**Cons**:
- Sensitive to outliers (squared term)
- Loss and gradients can become very large when errors are large

**Real-world scenarios**:
- Predicting house prices (no extreme outliers)
- Image reconstruction (pixel values 0-255)
- Temperature forecasting

In [None]:
# MSE Loss example
predictions = torch.tensor([2.5, 0.0, 2.0, 8.0])
targets = torch.tensor([3.0, -0.5, 2.0, 7.0])

# Method 1: Using nn.MSELoss
mse_loss = nn.MSELoss()
loss = mse_loss(predictions, targets)
print(f"MSE Loss: {loss.item():.4f}")

# Method 2: Manual calculation
manual_loss = ((predictions - targets) ** 2).mean()
print(f"Manual MSE: {manual_loss.item():.4f}")

# Show individual squared errors
squared_errors = (predictions - targets) ** 2
print(f"\nIndividual squared errors: {squared_errors}")
print(f"Mean: {squared_errors.mean().item():.4f}")

### 2.2 Mean Absolute Error (MAE) / L1 Loss

**Formula**: $\text{MAE} = \frac{1}{n}\sum_{i=1}^{n}|y_i - \hat{y}_i|$

**When to use**:
- When your data has outliers
- When all errors should be weighted equally (no squaring)
- When you want more robust predictions

**Pros**:
- Robust to outliers
- Same units as your target variable (interpretable)
- Uniform penalty for errors

**Cons**:
- Not differentiable at zero, and the gradient has constant magnitude (can oscillate around the optimum)
- Doesn't penalize large errors as much as MSE

**Real-world scenarios**:
- Predicting house prices with outliers (mansions)
- Financial forecasting (outlier events common)
- Sensor data with occasional bad readings

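The robustness claim above can be checked with a small worked example. This sketch uses invented numbers: four predictions that are each off by 0.5, and then the same predictions with one replaced by a large miss.

```python
import torch
import torch.nn as nn

targets = torch.tensor([1.0, 2.0, 3.0, 4.0])
clean_preds = torch.tensor([1.5, 2.5, 3.5, 4.5])     # every error is 0.5
outlier_preds = torch.tensor([1.5, 2.5, 3.5, 14.0])  # last error is 10.0

mse, mae = nn.MSELoss(), nn.L1Loss()

mse_clean, mae_clean = mse(clean_preds, targets), mae(clean_preds, targets)
mse_out, mae_out = mse(outlier_preds, targets), mae(outlier_preds, targets)

print(f"clean:       MSE={mse_clean.item():.4f}  MAE={mae_clean.item():.4f}")
print(f"one outlier: MSE={mse_out.item():.4f}  MAE={mae_out.item():.4f}")
print(f"MSE grew by {mse_out.item() / mse_clean.item():.1f}x, "
      f"MAE grew by {mae_out.item() / mae_clean.item():.1f}x")
```

A single bad point multiplies the MSE far more than the MAE, which is why MSE-trained models tend to chase outliers.
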
In [None]:
# L1 Loss example
predictions = torch.tensor([2.5, 0.0, 2.0, 8.0])
targets = torch.tensor([3.0, -0.5, 2.0, 7.0])

# Using nn.L1Loss
l1_loss = nn.L1Loss()
loss = l1_loss(predictions, targets)
print(f"L1 Loss: {loss.item():.4f}")

# Manual calculation
manual_loss = (predictions - targets).abs().mean()
print(f"Manual L1: {manual_loss.item():.4f}")

# Compare with MSE
mse = nn.MSELoss()(predictions, targets)
print(f"\nFor comparison, MSE: {mse.item():.4f}")
print("Here L1 is larger than MSE: every error is <= 1, and squaring shrinks such errors")

### 2.3 Smooth L1 Loss (Huber Loss)

**Formula**: 
$$\text{SmoothL1}(x) = \begin{cases}
0.5x^2 & \text{if } |x| < 1 \\
|x| - 0.5 & \text{otherwise}
\end{cases}$$

**When to use**:
- Best of both worlds: MSE for small errors, L1 for large errors
- Object detection (bounding box regression)
- When you want robustness to outliers but smooth gradients

**Pros**:
- Combines advantages of MSE and L1
- Smooth gradients everywhere
- Less sensitive to outliers than MSE

**Cons**:
- Slightly more complex
- Threshold parameter (beta) needs tuning

**Real-world scenarios**:
- Object detection (Faster R-CNN uses this)
- Reinforcement learning (Q-learning)
- Any regression with potential outliers

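The heading treats Smooth L1 and Huber as the same loss, which holds for the default threshold of 1. A short sketch of how the two relate for other thresholds, assuming a PyTorch version that ships `nn.HuberLoss` (1.9 or later); the error values and `beta = 2.0` are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

errors = torch.tensor([-3.0, -0.5, 0.0, 0.5, 3.0])
zeros = torch.zeros_like(errors)
beta = 2.0  # illustrative threshold; nn.SmoothL1Loss defaults to beta=1.0

# Per-element losses, so the two curves can be compared point by point
smooth_l1 = nn.SmoothL1Loss(beta=beta, reduction='none')(errors, zeros)
huber = nn.HuberLoss(delta=beta, reduction='none')(errors, zeros)

print(f"SmoothL1 (beta={beta}): {smooth_l1}")
print(f"Huber   (delta={beta}): {huber}")

# With delta == beta, Huber is SmoothL1 scaled by beta
print(f"Huber == beta * SmoothL1: {torch.allclose(huber, beta * smooth_l1)}")
```

With `beta = delta = 1` the scale factor is 1 and the two losses coincide.
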
In [None]:
# Smooth L1 Loss example
predictions = torch.tensor([2.5, 0.0, 2.0, 8.0])
targets = torch.tensor([3.0, -0.5, 2.0, 7.0])

smooth_l1 = nn.SmoothL1Loss()
loss = smooth_l1(predictions, targets)
print(f"Smooth L1 Loss: {loss.item():.4f}")

# Compare all three
mse = nn.MSELoss()(predictions, targets)
l1 = nn.L1Loss()(predictions, targets)

print(f"\nComparison:")
print(f"MSE Loss:      {mse.item():.4f}")
print(f"L1 Loss:       {l1.item():.4f}")
print(f"Smooth L1:     {loss.item():.4f}")
print("\nSmooth L1 is the smallest here: it is 0.5*x^2 for |x| < 1 and |x| - 0.5 otherwise")

In [None]:
# Visualize the three regression losses
errors = torch.linspace(-3, 3, 100)

# Compute losses for different error values
mse_values = errors ** 2
l1_values = errors.abs()
smooth_l1_values = torch.where(
    errors.abs() < 1,
    0.5 * errors ** 2,
    errors.abs() - 0.5
)

plt.figure(figsize=(12, 5))

# Plot losses
plt.subplot(1, 2, 1)
plt.plot(errors.numpy(), mse_values.numpy(), 'b-', linewidth=2, label='MSE (x²)')
plt.plot(errors.numpy(), l1_values.numpy(), 'r-', linewidth=2, label='L1 (|x|)')
plt.plot(errors.numpy(), smooth_l1_values.numpy(), 'g-', linewidth=2, label='Smooth L1')
plt.xlabel('Error (prediction - target)', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.title('Regression Loss Functions', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
plt.axvline(x=0, color='k', linestyle='--', alpha=0.3)

# Plot gradients
plt.subplot(1, 2, 2)
# MSE gradient: 2x
mse_grad = 2 * errors
# L1 gradient: sign(x)
l1_grad = torch.sign(errors)
# Smooth L1 gradient
smooth_l1_grad = torch.where(
    errors.abs() < 1,
    errors,
    torch.sign(errors)
)

plt.plot(errors.numpy(), mse_grad.numpy(), 'b-', linewidth=2, label='MSE gradient')
plt.plot(errors.numpy(), l1_grad.numpy(), 'r-', linewidth=2, label='L1 gradient')
plt.plot(errors.numpy(), smooth_l1_grad.numpy(), 'g-', linewidth=2, label='Smooth L1 gradient')
plt.xlabel('Error', fontsize=12)
plt.ylabel('Gradient', fontsize=12)
plt.title('Loss Gradients', fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
plt.axvline(x=0, color='k', linestyle='--', alpha=0.3)
plt.ylim(-3, 3)

plt.tight_layout()
plt.show()

print("Key observations:")
print("1. MSE grows quadratically (large errors dominate)")
print("2. L1 grows linearly (treats all errors equally)")
print("3. Smooth L1 combines both (quadratic near 0, linear far away)")
print("4. MSE gradient grows linearly (can explode!)")
print("5. L1 gradient has constant magnitude (can overshoot near the optimum)")
print("6. Smooth L1 gradient is linear near 0 and capped at +/-1 far away")

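The gradient formulas plotted above (2x, sign(x), and the clipped version) were written by hand. As a cross-check, this sketch asks autograd for the same gradients at a few arbitrary error values; `grad_of` is a helper name introduced here for illustration.

```python
import torch
import torch.nn.functional as F

errors = torch.tensor([-2.0, -0.5, 0.5, 2.0])
target = torch.zeros_like(errors)

def grad_of(loss_fn):
    """Gradient of the summed loss with respect to each prediction."""
    pred = errors.clone().requires_grad_(True)
    loss_fn(pred, target, reduction='sum').backward()
    return pred.grad

mse_grad = grad_of(F.mse_loss)           # expected: 2 * error
l1_grad = grad_of(F.l1_loss)             # expected: sign(error)
smooth_grad = grad_of(F.smooth_l1_loss)  # expected: error clipped to [-1, 1]

print(f"MSE gradient:       {mse_grad}")
print(f"L1 gradient:        {l1_grad}")
print(f"Smooth L1 gradient: {smooth_grad}")
```

`reduction='sum'` is used so each prediction's gradient is not divided by the number of elements.
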
---

## 3. Binary Classification Loss Functions

### When to Use Binary Classification?

Predicting **one of two classes**:
- Spam detection (spam/not spam)
- Medical diagnosis (disease/healthy)
- Sentiment analysis (positive/negative)
- Fraud detection (fraud/legitimate)
- Image contains cat (yes/no)

### 3.1 Binary Cross Entropy (BCE) Loss

**Formula**: $\text{BCE} = -\frac{1}{n}\sum_{i=1}^{n}[y_i \log(\hat{y}_i) + (1-y_i)\log(1-\hat{y}_i)]$

**When to use**:
- Binary classification where predictions are probabilities (0 to 1)
- After sigmoid activation
- Multi-label classification (each label independently binary)

**Pros**:
- Probabilistic interpretation
- Works well with sigmoid output
- Standard for binary problems

**Cons**:
- Requires predictions in [0, 1] range
- Can be numerically unstable (log of small numbers)

**Real-world scenarios**:
- Email spam classifier
- Medical test (disease present?)
- Ad click prediction (will user click?)

In [None]:
# BCE Loss example
# Predictions MUST be probabilities (0-1)
predictions = torch.tensor([0.9, 0.2, 0.8, 0.1])  # After sigmoid
targets = torch.tensor([1.0, 0.0, 1.0, 0.0])     # Ground truth labels

bce_loss = nn.BCELoss()
loss = bce_loss(predictions, targets)
print(f"BCE Loss: {loss.item():.4f}")

# Manual calculation
manual_loss = -(targets * torch.log(predictions) + 
                (1 - targets) * torch.log(1 - predictions)).mean()
print(f"Manual BCE: {manual_loss.item():.4f}")

# Show why this makes sense
print(f"\nBreakdown:")
for i in range(len(predictions)):
    pred, target = predictions[i].item(), targets[i].item()
    if target == 1:
        loss_i = -np.log(pred)
        print(f"Sample {i}: target=1, pred={pred:.2f} → loss={loss_i:.4f}")
    else:
        loss_i = -np.log(1 - pred)
        print(f"Sample {i}: target=0, pred={pred:.2f} → loss={loss_i:.4f}")

print(f"\nKey insight: High confidence correct predictions have low loss!")

### 3.2 BCE With Logits Loss (RECOMMENDED)

**Formula**: Combines sigmoid + BCE in one function

**When to use**:
- **Prefer this over BCE Loss** whenever your model outputs raw logits
- Binary classification with raw network outputs (logits)
- More numerically stable than sigmoid → BCE

**Pros**:
- Numerically stable (uses log-sum-exp trick)
- No need to apply sigmoid manually
- Better gradients

**Cons**:
- None in practice; apply `torch.sigmoid` yourself only when you need probabilities (e.g., at inference)

**Real-world scenarios**:
- Same as BCE, but this is the preferred implementation

**Important**: Your model should output logits (raw values), not probabilities!

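To make the stability claim concrete, here is a small sketch comparing the two routes on a confidently wrong prediction. The logit value 30.0 is an arbitrary choice; in float32, `torch.sigmoid(30.0)` rounds to 1.0, so the two-step route has already lost the information it needs. The third computation writes out one standard stable formulation by hand, `max(x, 0) - x*y + log(1 + exp(-|x|))`, which should be read as an illustration rather than PyTorch's exact internal code.

```python
import torch
import torch.nn.functional as F

# A confidently wrong prediction: large positive logit, target 0
logit = torch.tensor([30.0])
target = torch.tensor([0.0])

# Route 1: sigmoid followed by BCE on the probability
x1 = logit.clone().requires_grad_(True)
loss_two_step = F.binary_cross_entropy(torch.sigmoid(x1), target)
loss_two_step.backward()

# Route 2: BCE computed directly from the logit
x2 = logit.clone().requires_grad_(True)
loss_fused = F.binary_cross_entropy_with_logits(x2, target)
loss_fused.backward()

# Route 3: a stable formula written out by hand
loss_manual = (logit.clamp(min=0) - logit * target
               + torch.log1p(torch.exp(-logit.abs()))).mean()

print(f"sigmoid -> BCE:   loss={loss_two_step.item():.4f}, grad={x1.grad.item():.4f}")
print(f"BCE with logits:  loss={loss_fused.item():.4f}, grad={x2.grad.item():.4f}")
print(f"Manual stable:    loss={loss_manual.item():.4f}")
```

The fused route keeps a useful gradient for the logit, while the two-step route's gradient is lost once the sigmoid saturates.
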
In [None]:
# BCE With Logits example
# Predictions are RAW outputs (logits), not probabilities
logits = torch.tensor([2.5, -1.0, 1.5, -2.0])  # Raw network output
targets = torch.tensor([1.0, 0.0, 1.0, 0.0])

bce_with_logits = nn.BCEWithLogitsLoss()
loss = bce_with_logits(logits, targets)
print(f"BCE With Logits Loss: {loss.item():.4f}")

# Compare with manual sigmoid + BCE
predictions = torch.sigmoid(logits)
bce = nn.BCELoss()(predictions, targets)
print(f"Sigmoid + BCE Loss:   {bce.item():.4f}")
print(f"\nThey're the same! But BCEWithLogitsLoss is more stable.")

# Show the probabilities for interpretation
print(f"\nLogits → Probabilities:")
for i, (logit, prob, target) in enumerate(zip(logits, predictions, targets)):
    print(f"Sample {i}: logit={logit:.2f} → prob={prob:.4f}, target={target:.0f}")

In [None]:
# Visualize BCE loss
predictions = torch.linspace(0.01, 0.99, 100)  # Avoid 0 and 1 (log undefined)

# Loss when target = 1
loss_target_1 = -torch.log(predictions)

# Loss when target = 0
loss_target_0 = -torch.log(1 - predictions)

plt.figure(figsize=(12, 5))

# Plot losses
plt.subplot(1, 2, 1)
plt.plot(predictions.numpy(), loss_target_1.numpy(), 'b-', linewidth=2, label='Target = 1')
plt.plot(predictions.numpy(), loss_target_0.numpy(), 'r-', linewidth=2, label='Target = 0')
plt.xlabel('Prediction (probability)', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.title('Binary Cross Entropy Loss', fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.ylim(0, 5)

# Add annotations
plt.annotate('Wrong prediction\nhigh loss', xy=(0.1, 2.3), fontsize=10,
            bbox=dict(boxstyle='round', facecolor='yellow', alpha=0.5))
plt.annotate('Correct prediction\nlow loss', xy=(0.85, 0.2), fontsize=10,
            bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.5))

# Plot decision boundary
plt.subplot(1, 2, 2)
plt.axvline(x=0.5, color='k', linestyle='--', linewidth=2, label='Decision boundary')
plt.fill_between([0, 0.5], 0, 10, alpha=0.3, color='red', label='Predict 0')
plt.fill_between([0.5, 1], 0, 10, alpha=0.3, color='blue', label='Predict 1')
plt.xlabel('Prediction (probability)', fontsize=12)
plt.ylabel('Loss', fontsize=12)
plt.title('Decision Regions', fontsize=14)
plt.legend(fontsize=12)
plt.xlim(0, 1)
plt.ylim(0, 5)

plt.tight_layout()
plt.show()

print("Key observations:")
print("1. Loss is 0 only when prediction perfectly matches target")
print("2. Loss → ∞ as prediction → wrong answer")
print("3. Confident wrong predictions are penalized far more than hesitant ones!")
print("4. Encourages high confidence correct predictions")

---

## 4. Multi-Class Classification Loss

### When to Use Multi-Class Classification?

Predicting **one class from many**:
- Image classification (cat/dog/bird/...)
- Digit recognition (0-9)
- Language identification (English/Spanish/French/...)
- Medical diagnosis (disease A/B/C/healthy)

**Key difference from binary**: Exactly one class is correct.

### 4.1 Cross Entropy Loss (MOST IMPORTANT)

**Formula**: $\text{CE} = -\sum_{i=1}^{C} y_i \log(\hat{y}_i)$

Where $y_i$ is 1 for correct class, 0 otherwise.

**When to use**:
- **Default choice for multi-class classification**
- Image classification
- Text classification
- Any problem with mutually exclusive classes

**Pros**:
- Combines LogSoftmax + NLLLoss (numerically stable)
- Directly optimizes probability distribution
- Standard in all deep learning frameworks

**Important**: 
- Expects **logits** (raw outputs), not probabilities
- Applies softmax internally
- Targets are class indices (0, 1, 2, ..., C-1); since PyTorch 1.10, class probabilities of shape (N, C) are also accepted

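Because the target format is a frequent source of errors, here is a brief sketch of converting one-hot labels into the class indices described above. The one-hot tensor stands in for a hypothetical dataset that stores labels that way.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.tensor([[2.0, 1.0, 0.1],
                       [0.5, 2.5, 1.0]])

# Hypothetical dataset that stores labels as one-hot rows
one_hot_targets = torch.tensor([[1, 0, 0],
                                [0, 1, 0]])

# Convert to class indices; argmax returns dtype torch.long,
# which is what CrossEntropyLoss expects for index targets
index_targets = one_hot_targets.argmax(dim=1)

loss = nn.CrossEntropyLoss()(logits, index_targets)
print(f"Index targets: {index_targets} (dtype={index_targets.dtype})")
print(f"Loss: {loss.item():.4f}")

# Round trip back to one-hot, e.g. for metrics that need it
round_trip = F.one_hot(index_targets, num_classes=3)
print(f"Round trip matches: {torch.equal(round_trip, one_hot_targets)}")
```
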
**Real-world scenarios**:
- MNIST digit classification (10 classes)
- ImageNet (1000 classes)
- Sentiment classification (positive/neutral/negative)

In [None]:
# Cross Entropy Loss example
# Model outputs: raw logits for 3 classes
logits = torch.tensor([
    [2.0, 1.0, 0.1],  # Sample 1
    [0.5, 2.5, 1.0],  # Sample 2
    [0.1, 0.2, 3.0]   # Sample 3
])

# Targets: class indices (NOT one-hot!)
targets = torch.tensor([0, 1, 2])  # Correct classes

ce_loss = nn.CrossEntropyLoss()
loss = ce_loss(logits, targets)
print(f"Cross Entropy Loss: {loss.item():.4f}")

# Show what's happening
print(f"\nLogits:")
print(logits)
print(f"\nTargets (class indices): {targets}")

# Convert logits to probabilities (for interpretation)
probs = F.softmax(logits, dim=1)
print(f"\nProbabilities (after softmax):")
print(probs)

# Show prediction vs target
predicted_classes = logits.argmax(dim=1)
print(f"\nPredicted classes: {predicted_classes}")
print(f"Target classes:    {targets}")
print(f"Correct: {(predicted_classes == targets).sum().item()}/{len(targets)}")

In [None]:
# Understanding CrossEntropyLoss internals
# It's actually: LogSoftmax + NLLLoss

logits = torch.tensor([
    [2.0, 1.0, 0.1],
    [0.5, 2.5, 1.0],
    [0.1, 0.2, 3.0]
])
targets = torch.tensor([0, 1, 2])

# Method 1: CrossEntropyLoss (recommended)
ce_loss = nn.CrossEntropyLoss()(logits, targets)

# Method 2: Manual (LogSoftmax + NLLLoss)
log_probs = F.log_softmax(logits, dim=1)
nll_loss = F.nll_loss(log_probs, targets)

print(f"CrossEntropyLoss:  {ce_loss.item():.6f}")
print(f"LogSoftmax + NLL:  {nll_loss.item():.6f}")
print(f"\nThey're identical!")

# Manual calculation (fully explicit)
probs = F.softmax(logits, dim=1)
log_probs_manual = torch.log(probs)
manual_loss = -log_probs_manual[range(len(targets)), targets].mean()

print(f"Fully manual:      {manual_loss.item():.6f}")
print(f"\nConclusion: CrossEntropyLoss is just a convenient wrapper!")

### 4.2 Common Mistake: Softmax + CrossEntropy

**DO NOT** apply softmax before CrossEntropyLoss!

In [None]:
# WRONG: Applying softmax before CrossEntropyLoss
logits = torch.tensor([[2.0, 1.0, 0.1]])
targets = torch.tensor([0])

# Correct way
loss_correct = nn.CrossEntropyLoss()(logits, targets)
print(f"Correct (logits → CE): {loss_correct.item():.4f}")

# WRONG way (double softmax!)
probs = F.softmax(logits, dim=1)  # DON'T DO THIS!
loss_wrong = nn.CrossEntropyLoss()(probs, targets)
print(f"Wrong (softmax → CE): {loss_wrong.item():.4f}")

print(f"\nThe losses are different! CrossEntropyLoss applies softmax internally.")
print(f"Always pass raw logits to CrossEntropyLoss!")

---

## 5. Multi-Label Classification Loss

### Multi-Class vs Multi-Label

- **Multi-class**: Exactly ONE class (cat XOR dog XOR bird)
- **Multi-label**: MULTIPLE classes possible (cat AND dog AND bird)

**When to use multi-label**:
- Image tagging (photo contains: person, dog, outdoor)
- Text categorization (article is: politics, economy, international)
- Medical diagnosis (patient has: diabetes, hypertension, obesity)

**Key difference**: Each label is independent binary classification!

### 5.1 BCE Loss for Multi-Label

**When to use**: Multi-label classification (each label independent)

**Approach**: Apply BCE to each label independently

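"Independently" can be verified numerically. In this sketch (small invented tensors), calling the loss once on the whole matrix gives the same value as computing a separate binary loss for each label column and averaging them.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, -1.0, 0.5],
                       [-0.5, 2.0, -1.0]])  # 2 samples, 3 labels
targets = torch.tensor([[1.0, 0.0, 1.0],
                        [0.0, 1.0, 0.0]])

# One call over the whole (N, L) matrix
joint_loss = F.binary_cross_entropy_with_logits(logits, targets)

# The same computation, one label (column) at a time
per_label = torch.stack([
    F.binary_cross_entropy_with_logits(logits[:, j], targets[:, j])
    for j in range(logits.shape[1])
])

print(f"Joint loss:          {joint_loss.item():.4f}")
print(f"Per-label losses:    {per_label}")
print(f"Mean of per-label:   {per_label.mean().item():.4f}")
```

The per-label view is also handy for spotting which labels the model struggles with.
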
In [None]:
# Multi-label classification example
# Each sample can have multiple labels

# Raw outputs (logits) for 4 labels
logits = torch.tensor([
    [2.0, -1.0, 0.5, 1.5],   # Sample 1: has labels 0, 2, 3
    [-0.5, 2.0, -1.0, 0.8],  # Sample 2: has labels 1, 3
    [1.0, 1.0, 1.0, -2.0]    # Sample 3: has labels 0, 1, 2
])

# Targets: binary for each label (can have multiple 1s)
targets = torch.tensor([
    [1.0, 0.0, 1.0, 1.0],  # Sample 1
    [0.0, 1.0, 0.0, 1.0],  # Sample 2
    [1.0, 1.0, 1.0, 0.0]   # Sample 3
])

# Use BCEWithLogitsLoss (treats each label independently)
loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(logits, targets)
print(f"Multi-label BCE Loss: {loss.item():.4f}")

# Show predictions
probs = torch.sigmoid(logits)
predictions = (probs > 0.5).float()  # Threshold at 0.5

print(f"\nProbabilities:")
print(probs)
print(f"\nPredictions (>0.5):")
print(predictions)
print(f"\nTargets:")
print(targets)

# Compute accuracy per label
correct = (predictions == targets).sum(dim=0)
total = len(targets)
print(f"\nAccuracy per label: {correct / total}")

---

## 6. Decision Guide: Which Loss Function?

### Flowchart

```
What's your task?
│
├─ Regression (continuous values)
│  ├─ Standard case → MSELoss
│  ├─ Has outliers → L1Loss or SmoothL1Loss
│  └─ Object detection → SmoothL1Loss
│
├─ Binary Classification (2 classes)
│  └─ ALWAYS use BCEWithLogitsLoss (most stable)
│
├─ Multi-Class Classification (1 of N classes)
│  └─ ALWAYS use CrossEntropyLoss
│
└─ Multi-Label Classification (multiple classes)
   └─ Use BCEWithLogitsLoss
```

### Quick Reference Table

| Task | Loss Function | Model Output | Target Format | Example |
|------|---------------|--------------|---------------|----------|
| Regression | `MSELoss` | Raw values | Float values | House price prediction |
| Regression (outliers) | `L1Loss` or `SmoothL1Loss` | Raw values | Float values | Stock price with outliers |
| Binary Classification | `BCEWithLogitsLoss` | Logits (1 value) | 0 or 1 | Spam detection |
| Multi-Class | `CrossEntropyLoss` | Logits (C values) | Class index (0 to C-1) | MNIST digits |
| Multi-Label | `BCEWithLogitsLoss` | Logits (L values) | Binary per label | Image tagging |

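The "Model Output" and "Target Format" columns translate into tensor shapes and dtypes. The sketch below feeds random tensors of the right form into each loss; the sizes `N`, `C`, and `L` are arbitrary, and the random values carry no meaning.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C, L = 4, 3, 5  # batch size, number of classes, number of labels (arbitrary)

# name -> (loss function, model output, target)
cases = {
    "Regression / MSELoss": (
        nn.MSELoss(), torch.randn(N), torch.randn(N)),
    "Binary / BCEWithLogitsLoss": (
        nn.BCEWithLogitsLoss(), torch.randn(N),
        torch.randint(0, 2, (N,)).float()),       # float 0./1.
    "Multi-class / CrossEntropyLoss": (
        nn.CrossEntropyLoss(), torch.randn(N, C),
        torch.randint(0, C, (N,))),               # long class indices
    "Multi-label / BCEWithLogitsLoss": (
        nn.BCEWithLogitsLoss(), torch.randn(N, L),
        torch.randint(0, 2, (N, L)).float()),     # float 0./1. per label
}

results = {}
for name, (loss_fn, output, target) in cases.items():
    results[name] = loss_fn(output, target)
    print(f"{name:32s} output={tuple(output.shape)} "
          f"target={tuple(target.shape)} ({target.dtype})")
```

Note that the BCE variants take float targets with the same shape as the output, while `CrossEntropyLoss` takes integer indices with one fewer dimension.
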
---

## 7. Advanced: Custom Loss Functions

Sometimes you need to create your own loss function!

In [None]:
# Example 1: Weighted MSE (some samples more important)
class WeightedMSELoss(nn.Module):
    """MSE where each sample has a weight"""
    
    def __init__(self):
        super().__init__()
    
    def forward(self, predictions, targets, weights):
        """
        Args:
            predictions: (N,) tensor
            targets: (N,) tensor
            weights: (N,) tensor of importance weights
        """
        squared_error = (predictions - targets) ** 2
        weighted_error = squared_error * weights
        return weighted_error.mean()

# Test it
predictions = torch.tensor([1.0, 2.0, 3.0])
targets = torch.tensor([1.2, 2.5, 2.8])
weights = torch.tensor([1.0, 2.0, 0.5])  # Middle sample is more important

loss_fn = WeightedMSELoss()
loss = loss_fn(predictions, targets, weights)
print(f"Weighted MSE Loss: {loss.item():.4f}")

# Compare with regular MSE
regular_mse = nn.MSELoss()(predictions, targets)
print(f"Regular MSE Loss:  {regular_mse.item():.4f}")
print(f"\nWeighted version emphasizes important samples!")

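One design choice in a weighted loss is the normalizer. Averaging over N, as in the cell above, means that scaling all weights also scales the loss. A possible variant, sketched here under the hypothetical name `normalized_weighted_mse`, divides by the total weight instead, so uniform weights of any size reproduce plain MSE.

```python
import torch
import torch.nn as nn

def normalized_weighted_mse(predictions, targets, weights):
    """Weighted MSE divided by the total weight (illustrative variant)."""
    squared_error = (predictions - targets) ** 2
    return (squared_error * weights).sum() / weights.sum()

predictions = torch.tensor([1.0, 2.0, 3.0])
targets = torch.tensor([1.2, 2.5, 2.8])

# Every sample weighted equally, but with a weight other than 1.0
uniform = torch.full((3,), 10.0)
loss_uniform = normalized_weighted_mse(predictions, targets, uniform)
regular_mse = nn.MSELoss()(predictions, targets)

print(f"Normalized weighted MSE (uniform weights): {loss_uniform.item():.4f}")
print(f"Regular MSE:                               {regular_mse.item():.4f}")
```

Which normalizer is preferable depends on whether the weights are meant to change the overall loss scale or only the relative importance of samples.
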
In [None]:
# Example 2: Combined loss (multiple objectives)
class CombinedLoss(nn.Module):
    """Combine classification and regression losses"""
    
    def __init__(self, alpha=0.5):
        super().__init__()
        self.alpha = alpha  # Balance between tasks
        self.ce_loss = nn.CrossEntropyLoss()
        self.mse_loss = nn.MSELoss()
    
    def forward(self, class_logits, class_targets, reg_pred, reg_targets):
        """
        Args:
            class_logits: (N, C) classification logits
            class_targets: (N,) class indices
            reg_pred: (N,) regression predictions
            reg_targets: (N,) regression targets
        """
        ce = self.ce_loss(class_logits, class_targets)
        mse = self.mse_loss(reg_pred, reg_targets)
        
        # Weighted combination
        total = self.alpha * ce + (1 - self.alpha) * mse
        return total, ce, mse

# Test it
class_logits = torch.randn(5, 3)  # 5 samples, 3 classes
class_targets = torch.randint(0, 3, (5,))
reg_pred = torch.randn(5)
reg_targets = torch.randn(5)

loss_fn = CombinedLoss(alpha=0.7)  # 70% classification, 30% regression
total_loss, ce_loss, mse_loss = loss_fn(class_logits, class_targets, reg_pred, reg_targets)

print(f"Classification Loss: {ce_loss.item():.4f}")
print(f"Regression Loss:     {mse_loss.item():.4f}")
print(f"Total Loss:          {total_loss.item():.4f}")
print(f"\nUse case: Multi-task learning (e.g., detect object + estimate size)")

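To show how a combined loss drives a multi-task model, the following sketch builds a toy network with a shared trunk and two heads, then backpropagates a weighted sum of the two losses. `TwoHeadNet`, its layer sizes, and the random data are all hypothetical and chosen only to keep the example small.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

class TwoHeadNet(nn.Module):
    """Hypothetical toy model: shared trunk, one head per task."""

    def __init__(self, in_features=4, hidden=8, num_classes=3):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())
        self.class_head = nn.Linear(hidden, num_classes)
        self.reg_head = nn.Linear(hidden, 1)

    def forward(self, x):
        h = self.trunk(x)
        # (N, C) logits and (N,) regression outputs
        return self.class_head(h), self.reg_head(h).squeeze(-1)

model = TwoHeadNet()
x = torch.randn(5, 4)
class_targets = torch.randint(0, 3, (5,))
reg_targets = torch.randn(5)

class_logits, reg_pred = model(x)
alpha = 0.7  # same weighting idea as CombinedLoss
total = (alpha * F.cross_entropy(class_logits, class_targets)
         + (1 - alpha) * F.mse_loss(reg_pred, reg_targets))
total.backward()

# Both heads and the shared trunk receive gradients from the one scalar
for name, param in model.named_parameters():
    print(f"{name:20s} grad norm = {param.grad.norm().item():.4f}")
```

The trunk's gradient is the weighted sum of the contributions from both tasks, which is why the choice of `alpha` matters when the two losses are on different scales.
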
---

## Mini Exercises

### Exercise 1: Choose the Right Loss

For each scenario, choose the appropriate loss function:

1. Predicting tomorrow's temperature (Celsius)
2. Classifying emails as spam or not spam
3. Recognizing handwritten digits (0-9)
4. Detecting all objects in an image (person, car, tree, etc.)
5. Predicting house prices where some houses are mansions (outliers)

In [None]:
# Your answers here


In [None]:
# Solutions
print("1. Temperature prediction:")
print("   → MSELoss (standard regression)")
print()

print("2. Spam classification:")
print("   → BCEWithLogitsLoss (binary classification)")
print()

print("3. Digit recognition (0-9):")
print("   → CrossEntropyLoss (multi-class, one correct digit)")
print()

print("4. Object detection (multiple objects):")
print("   → BCEWithLogitsLoss (multi-label, multiple objects present)")
print("   → PLUS SmoothL1Loss for bounding box coordinates")
print()

print("5. House prices with outliers:")
print("   → L1Loss or SmoothL1Loss (robust to outliers)")

### Exercise 2: Compute Loss Manually

Given predictions and targets, compute CrossEntropyLoss manually (without using PyTorch's function).

Predictions (logits): `[[1.0, 2.0, 0.5], [0.5, 0.5, 2.0]]`  
Targets: `[1, 2]`

In [None]:
# Your code here


In [None]:
# Solution
logits = torch.tensor([[1.0, 2.0, 0.5], [0.5, 0.5, 2.0]])
targets = torch.tensor([1, 2])

# Step 1: Compute softmax probabilities
probs = F.softmax(logits, dim=1)
print(f"Probabilities:\n{probs}")
print()

# Step 2: Take log of probabilities
log_probs = torch.log(probs)
print(f"Log probabilities:\n{log_probs}")
print()

# Step 3: Select log prob of correct class for each sample
# Sample 0: class 1 → log_probs[0, 1]
# Sample 1: class 2 → log_probs[1, 2]
correct_log_probs = log_probs[range(len(targets)), targets]
print(f"Log probs of correct classes: {correct_log_probs}")
print()

# Step 4: Negative mean
manual_loss = -correct_log_probs.mean()
print(f"Manual CrossEntropy: {manual_loss.item():.4f}")

# Verify with PyTorch
pytorch_loss = nn.CrossEntropyLoss()(logits, targets)
print(f"PyTorch CrossEntropy: {pytorch_loss.item():.4f}")
print(f"\nMatch: {torch.isclose(manual_loss, pytorch_loss)}")

### Exercise 3: Implement Focal Loss

Focal Loss is used for handling class imbalance (e.g., 99% negative, 1% positive).

Formula: $FL(p_t) = -(1-p_t)^\gamma \log(p_t)$

Where $p_t$ is the probability of the correct class, and $\gamma$ (typically 2) focuses on hard examples.

Implement FocalLoss as a custom nn.Module.

In [None]:
# Your code here


In [None]:
# Solution
class FocalLoss(nn.Module):
    """Focal Loss for handling class imbalance"""
    
    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma  # Larger gamma down-weights easy examples more
    
    def forward(self, logits, targets):
        """
        Args:
            logits: (N, C) raw predictions
            targets: (N,) class indices
        """
        # Log-probabilities via log_softmax (more stable than log(softmax(x)))
        log_probs = F.log_softmax(logits, dim=1)
        
        # Log-probability and probability of the correct class
        correct_log_probs = log_probs[range(len(targets)), targets]
        correct_probs = correct_log_probs.exp()
        
        # Compute focal loss: -(1 - p_t)^gamma * log(p_t)
        focal_weight = (1 - correct_probs) ** self.gamma
        ce_loss = -correct_log_probs
        focal_loss = focal_weight * ce_loss
        
        return focal_loss.mean()

# Test it
logits = torch.tensor([
    [3.0, 0.1, 0.1],  # Easy example (high confidence)
    [0.6, 0.5, 0.4],  # Hard example (low confidence)
])
targets = torch.tensor([0, 0])

focal = FocalLoss(gamma=2.0)
focal_loss = focal(logits, targets)

ce = nn.CrossEntropyLoss()(logits, targets)

print(f"Focal Loss: {focal_loss.item():.4f}")
print(f"CE Loss:    {ce.item():.4f}")
print(f"\nFocal loss focuses more on hard examples!")

# Show individual losses
probs = F.softmax(logits, dim=1)[:, 0]
print(f"\nSample 0: p={probs[0]:.4f} (easy)")
print(f"Sample 1: p={probs[1]:.4f} (hard)")
print(f"Focal loss weights hard examples more heavily!")

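A useful sanity check for any focal loss implementation is that setting gamma to 0 removes the modulating factor and recovers ordinary cross entropy. The sketch below uses a standalone function, `focal_loss` (a name introduced here), so it can run on its own:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    """Mean of -(1 - p_t)^gamma * log(p_t) over the batch."""
    log_probs = F.log_softmax(logits, dim=1)
    # Pick out the log-probability of the correct class for each sample
    log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    return (-((1 - pt) ** gamma) * log_pt).mean()

logits = torch.tensor([
    [3.0, 0.1, 0.1],  # easy example
    [0.6, 0.5, 0.4],  # hard example
])
targets = torch.tensor([0, 0])

ce = F.cross_entropy(logits, targets)
fl_gamma0 = focal_loss(logits, targets, gamma=0.0)
fl_gamma2 = focal_loss(logits, targets, gamma=2.0)

print(f"Cross entropy:        {ce.item():.4f}")
print(f"Focal loss (gamma=0): {fl_gamma0.item():.4f}")
print(f"Focal loss (gamma=2): {fl_gamma2.item():.4f}")
```

Because the factor `(1 - p_t)^gamma` is at most 1, the focal loss with gamma > 0 is never larger than cross entropy; what changes is the relative share contributed by easy and hard examples.
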
---

## Comprehensive Exercise: Loss Function Comparison

Generate synthetic regression data with outliers and compare:
1. MSELoss
2. L1Loss  
3. SmoothL1Loss

Fit a simple linear model to the data and visualize the results.

**Steps**:
1. Generate data: `y = 2x + 1 + noise`, with 10% outliers
2. Train 3 models (one per loss function)
3. Plot data and fitted lines
4. Compare which loss handles outliers best

In [None]:
# Your code here


In [None]:
# Solution
torch.manual_seed(42)

# 1. Generate data with outliers
n_samples = 100
x = torch.randn(n_samples, 1)
y = 2 * x + 1 + torch.randn(n_samples, 1) * 0.5  # y = 2x + 1 + noise

# Add outliers (10% of data)
n_outliers = n_samples // 10
outlier_indices = torch.randperm(n_samples)[:n_outliers]
y[outlier_indices] += torch.randn(n_outliers, 1) * 5  # Large noise

print(f"Generated {n_samples} samples with {n_outliers} outliers")

# 2. Train models with different losses
def train_model(loss_fn, n_iter=1000, lr=0.01):
    """Train simple linear model"""
    model = nn.Linear(1, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    
    for _ in range(n_iter):
        optimizer.zero_grad()
        pred = model(x)
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()
    
    return model

# Train with different losses
print("\nTraining models...")
model_mse = train_model(nn.MSELoss())
model_l1 = train_model(nn.L1Loss())
model_smooth = train_model(nn.SmoothL1Loss())

# Extract parameters
w_mse, b_mse = model_mse.weight.item(), model_mse.bias.item()
w_l1, b_l1 = model_l1.weight.item(), model_l1.bias.item()
w_smooth, b_smooth = model_smooth.weight.item(), model_smooth.bias.item()

print(f"\nTrue parameters: w=2.0, b=1.0")
print(f"MSE:      w={w_mse:.3f}, b={b_mse:.3f}")
print(f"L1:       w={w_l1:.3f}, b={b_l1:.3f}")
print(f"Smooth L1: w={w_smooth:.3f}, b={b_smooth:.3f}")

In [None]:
# 3. Visualize results
x_line = torch.linspace(x.min(), x.max(), 100).reshape(-1, 1)

with torch.no_grad():
    y_mse = model_mse(x_line)
    y_l1 = model_l1(x_line)
    y_smooth = model_smooth(x_line)

plt.figure(figsize=(12, 5))

# Plot data and fits
plt.subplot(1, 2, 1)
# Separate normal and outlier points
mask = torch.ones(n_samples, dtype=torch.bool)
mask[outlier_indices] = False

plt.scatter(x[mask].numpy(), y[mask].numpy(), alpha=0.6, label='Normal data', s=30)
plt.scatter(x[~mask].numpy(), y[~mask].numpy(), color='red', alpha=0.6, 
           label='Outliers', s=50, marker='x')

# Plot fitted lines
plt.plot(x_line.numpy(), y_mse.numpy(), 'b-', linewidth=2, label=f'MSE: y={w_mse:.2f}x+{b_mse:.2f}')
plt.plot(x_line.numpy(), y_l1.numpy(), 'g-', linewidth=2, label=f'L1: y={w_l1:.2f}x+{b_l1:.2f}')
plt.plot(x_line.numpy(), y_smooth.numpy(), 'm-', linewidth=2, label=f'Smooth L1: y={w_smooth:.2f}x+{b_smooth:.2f}')

# True line
y_true = 2 * x_line + 1
plt.plot(x_line.numpy(), y_true.numpy(), 'k--', linewidth=2, alpha=0.5, label='True: y=2x+1')

plt.xlabel('x', fontsize=12)
plt.ylabel('y', fontsize=12)
plt.title('Regression with Outliers', fontsize=14)
plt.legend(fontsize=9)
plt.grid(True, alpha=0.3)

# Plot error comparison
# Each model was trained with a different loss, so their training losses are in
# different units. Score all three with the same metric: MAE on non-outlier points.
plt.subplot(1, 2, 2)
mae = nn.L1Loss()
with torch.no_grad():
    loss_mse = mae(model_mse(x[mask]), y[mask]).item()
    loss_l1 = mae(model_l1(x[mask]), y[mask]).item()
    loss_smooth = mae(model_smooth(x[mask]), y[mask]).item()

losses = [loss_mse, loss_l1, loss_smooth]
labels = ['MSE', 'L1', 'Smooth L1']
colors = ['blue', 'green', 'magenta']

plt.bar(labels, losses, color=colors, alpha=0.7)
plt.ylabel('MAE on non-outlier points', fontsize=12)
plt.title('Error Comparison (same metric)', fontsize=14)
plt.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("\nConclusions:")
print("1. MSE is the most sensitive to outliers (here they are zero-mean, so the pull may be modest)")
print("2. L1 and Smooth L1 are more robust to outliers")
print("3. Smooth L1 is often a good compromise between the two")
print("4. Choose loss based on your data characteristics!")

---

## Key Takeaways

1. **Loss function = learning objective**: Tells the model what "good" means
2. **Choose based on task type**:
   - Regression → MSE, L1, or Smooth L1
   - Binary classification → BCEWithLogitsLoss
   - Multi-class → CrossEntropyLoss
   - Multi-label → BCEWithLogitsLoss
3. **Model output matters**:
   - CrossEntropyLoss expects logits (raw outputs)
   - BCELoss expects probabilities (after sigmoid)
   - BCEWithLogitsLoss expects logits (more stable)
4. **Outliers matter**: Use L1 or Smooth L1 when data has outliers
5. **Custom losses**: Easy to implement for special requirements
6. **Combined losses**: Can optimize multiple objectives simultaneously

### Common Mistakes to Avoid

1. **Applying softmax before CrossEntropyLoss** (it does it internally!)
2. **Using BCELoss with logits** (use BCEWithLogitsLoss instead)
3. **Wrong target format** (class indices vs one-hot vs binary)
4. **Ignoring data characteristics** (outliers, class imbalance)
5. **Combining regression and classification losses without weighting them** (losses on different scales let one task dominate)

---

## Next Steps

You now understand how to measure how wrong a model's predictions are! Next, we'll put everything together and build complete **training loops with optimization**.

Continue to: [Topic 5: Training Loop & Optimization](05_training_optimization.ipynb)

---

## Further Reading

- [PyTorch Loss Functions Documentation](https://pytorch.org/docs/stable/nn.html#loss-functions)
- [Cross Entropy Loss Explained](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html)
- [Focal Loss Paper](https://arxiv.org/abs/1708.02002)
- [Loss Functions for Classification](https://machinelearningmastery.com/loss-functions-for-classification/)