# Debug Drill: The Exploding Loss

**Scenario:**
A colleague is training a model using gradient descent. They're frustrated.

"My loss went from 2.5 to 50 to infinity!" they say. "The training is broken!"

**Your Task:**
1. Run the training and observe the problem
2. Diagnose why the loss is exploding
3. Fix the hyperparameters
4. Write a 3-bullet postmortem

---

In [None]:
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

In [None]:
# Simple quadratic loss surface for demonstration
def loss_fn(x, y):
    """Bowl-shaped loss surface. Minimum at (0, 0)."""
    return x**2 + y**2

def gradient_fn(x, y):
    """Gradient of the loss function."""
    return np.array([2*x, 2*y])

def gradient_descent(start, lr, n_steps=50):
    """Run vanilla gradient descent."""
    position = np.array(start, dtype=float)
    path = [position.copy()]
    losses = [loss_fn(position[0], position[1])]
    
    for _ in range(n_steps):
        grad = gradient_fn(position[0], position[1])
        position = position - lr * grad
        path.append(position.copy())
        losses.append(loss_fn(position[0], position[1]))
    
    return np.array(path), np.array(losses)

print("✓ Functions defined")

In [None]:
# ===== COLLEAGUE'S CODE (CONTAINS BUG) =====

# "I want to train faster, so I'll use a big learning rate!"

START_POSITION = [-2.0, 2.0]
LEARNING_RATE = 1.2  # <-- BUG: Way too high!

path_bad, losses_bad = gradient_descent(START_POSITION, LEARNING_RATE, n_steps=20)

print("=== Colleague's Training ===")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Starting position: {START_POSITION}")
print(f"\nLoss over training:")
for i in range(min(10, len(losses_bad))):
    if np.isfinite(losses_bad[i]):
        print(f"  Step {i}: {losses_bad[i]:.2f}")
    else:
        print(f"  Step {i}: EXPLODED!")
        break

In [None]:
# Visualize what's happening
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)
Z = loss_fn(X, Y)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Path on contour plot
ax1 = axes[0]
ax1.contour(X, Y, Z, levels=20, cmap='viridis', alpha=0.5)
path_clipped = np.clip(path_bad, -5, 5)
ax1.plot(path_clipped[:, 0], path_clipped[:, 1], 'r-o', markersize=4, linewidth=1, label='GD Path')
ax1.scatter([0], [0], color='green', s=100, marker='*', zorder=5, label='Minimum')
ax1.scatter([START_POSITION[0]], [START_POSITION[1]], color='blue', s=100, marker='s', zorder=5, label='Start')
ax1.set_xlabel('θ₁')
ax1.set_ylabel('θ₂')
ax1.set_title(f'Gradient Descent Path (lr={LEARNING_RATE})')
ax1.legend()
ax1.set_xlim(-5, 5)
ax1.set_ylim(-5, 5)

# Loss curve
ax2 = axes[1]
losses_clipped = np.clip(losses_bad, 0, 100)
ax2.plot(losses_clipped, 'r-', linewidth=2)
ax2.set_xlabel('Step')
ax2.set_ylabel('Loss')
ax2.set_title('Loss Over Time (EXPLODING!)')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\n❌ The optimizer is DIVERGING!")
print("   Instead of descending toward the minimum, it's overshooting and bouncing away.")

---

## Your Investigation

The loss grows at every step instead of shrinking, which is a classic sign of a learning rate that's too high.
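
One way to confirm that diagnosis is to test the recorded loss history directly. The sketch below assumes a hypothetical helper, `looks_divergent`, which is not part of the drill code. It flags a history as divergent if it contains a non-finite value or ends higher than it started. The two sample histories are illustrative values for this kind of quadratic bowl.

```python
import numpy as np

def looks_divergent(losses):
    """Hypothetical helper: True if the loss history suggests divergence."""
    losses = np.asarray(losses, dtype=float)
    # NaN or inf anywhere in the history is an unambiguous failure
    if not np.all(np.isfinite(losses)):
        return True
    # Otherwise, flag runs that end above where they started
    return bool(losses[-1] > losses[0])

# Illustrative loss histories (assumed values, not captured notebook output)
growing = [8.0, 15.68, 30.73, 60.24]
shrinking = [8.0, 5.12, 3.28, 2.10]

print("growing history divergent?  ", looks_divergent(growing))
print("shrinking history divergent?", looks_divergent(shrinking))
```

A check like this is deliberately crude: it only compares the endpoints, so a run that spikes and then recovers would not be flagged.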

### Step 1: Understand the math

In [None]:
# Why does this happen?
print("=== Why Learning Rate Matters ===")
print()
print("For a quadratic loss f(x) = x², the gradient is g(x) = 2x")
print()
print("Update rule: x_new = x - lr * gradient")
print("           = x - lr * 2x")
print("           = x * (1 - 2*lr)")
print()
print("For the update to move TOWARD zero:")
print("  |1 - 2*lr| < 1")
print("  → -1 < 1 - 2*lr < 1")
print("  → 0 < lr < 1")
print()
print(f"Colleague's lr = {LEARNING_RATE}")
print(f"1 - 2*lr = {1 - 2*LEARNING_RATE}")
print()
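factor = abs(1 - 2*LEARNING_RATE)
if factor > 1:
    print("❌ |1 - 2*lr| > 1, so updates AMPLIFY the error!")
elif factor == 1:
    print("⚠ |1 - 2*lr| = 1, so updates bounce back and forth without progress.")
else:
    print("✓ |1 - 2*lr| < 1, so updates reduce the error.")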

In [None]:
# Test different learning rates
test_lrs = [0.01, 0.1, 0.3, 0.5, 0.8, 0.95, 1.0, 1.2]
print("Learning Rate Behavior:")
print("-" * 50)

for lr in test_lrs:
    path, losses = gradient_descent(START_POSITION, lr, n_steps=50)
    final_loss = losses[-1] if np.isfinite(losses[-1]) else float('inf')

    # Compare against the starting loss: a finite loss that grew is still a divergence
    if final_loss < 0.01:
        status = "✓ Converged"
    elif final_loss < losses[0]:
        status = "~ Slow progress"
    elif np.isclose(final_loss, losses[0]):
        status = "⚠ Oscillating"
    else:
        status = "❌ DIVERGED"

    print(f"lr={lr:<4} → {status} (final loss: {final_loss:.4g})")

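The update rule `x_new = x * (1 - 2*lr)` implies that each coordinate is multiplied by `(1 - 2*lr)` at every step, so the loss is multiplied by `(1 - 2*lr)**2`. The sketch below checks that prediction against a direct simulation. `simulate_loss` is a hypothetical standalone re-implementation of the update on the same bowl-shaped loss, written so the block runs on its own.

```python
import numpy as np

def simulate_loss(start, lr, n_steps):
    """Hypothetical standalone GD on f(x, y) = x**2 + y**2; returns the final loss."""
    pos = np.array(start, dtype=float)
    for _ in range(n_steps):
        pos = pos - lr * 2 * pos  # gradient of x**2 + y**2 is (2x, 2y)
    return float(np.sum(pos**2))

start = [-2.0, 2.0]
start_loss = 8.0  # (-2)**2 + 2**2
n_steps = 10

for lr in [0.1, 0.5, 1.0, 1.2]:
    factor = 1 - 2 * lr                               # per-step multiplier on each coordinate
    predicted = start_loss * factor ** (2 * n_steps)  # loss scales by factor**2 per step
    simulated = simulate_loss(start, lr, n_steps)
    print(f"lr={lr}: factor={factor:+.1f}, predicted={predicted:.4g}, simulated={simulated:.4g}")
```

The closed form also explains the boundary cases: a factor of 0 reaches the minimum in one step, and a factor of -1 flips the sign of the position forever without changing the loss.
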
### Step 2: Find a good learning rate

In [None]:
# TODO: Find a learning rate that converges reliably

# Uncomment and complete:

# LEARNING_RATE_FIXED = ???  # Pick from the experiment above (any 0 < lr < 1 converges on this surface; 0.1 is a reasonable choice)
# 
# path_fixed, losses_fixed = gradient_descent(START_POSITION, LEARNING_RATE_FIXED, n_steps=50)
# 
# print(f"=== Fixed Training (lr={LEARNING_RATE_FIXED}) ===")
# print(f"Starting loss: {losses_fixed[0]:.4f}")
# print(f"Final loss: {losses_fixed[-1]:.6f}")
# print(f"Final position: ({path_fixed[-1, 0]:.4f}, {path_fixed[-1, 1]:.4f})")
# print(f"\nConverged: {losses_fixed[-1] < 0.01}")

In [None]:
# TODO: Visualize the comparison

# Uncomment:

# fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# 
# # Paths
# ax1 = axes[0]
# ax1.contour(X, Y, Z, levels=20, cmap='viridis', alpha=0.5)
# ax1.plot(np.clip(path_bad[:10], -5, 5)[:, 0], np.clip(path_bad[:10], -5, 5)[:, 1], 
#          'r-o', markersize=4, linewidth=1, label=f'Divergent (lr={LEARNING_RATE})', alpha=0.7)
# ax1.plot(path_fixed[:, 0], path_fixed[:, 1], 
#          'b-o', markersize=4, linewidth=1, label=f'Converged (lr={LEARNING_RATE_FIXED})')
# ax1.scatter([0], [0], color='green', s=100, marker='*', zorder=5)
# ax1.set_xlabel('θ₁')
# ax1.set_ylabel('θ₂')
# ax1.set_title('Gradient Descent Paths')
# ax1.legend()
# ax1.set_xlim(-5, 5)
# ax1.set_ylim(-5, 5)
# 
# # Loss curves
# ax2 = axes[1]
# ax2.plot(np.clip(losses_bad, 0, 50), 'r-', linewidth=2, label=f'lr={LEARNING_RATE} (diverges)')
# ax2.plot(losses_fixed, 'b-', linewidth=2, label=f'lr={LEARNING_RATE_FIXED} (converges)')
# ax2.set_xlabel('Step')
# ax2.set_ylabel('Loss')
# ax2.set_title('Loss Over Time')
# ax2.legend()
# ax2.grid(True, alpha=0.3)
# 
# plt.tight_layout()
# plt.show()

In [None]:
# ============================================
# SELF-CHECK: Did you fix the divergence?
# ============================================

# Uncomment:

# assert losses_fixed[-1] < losses_fixed[0], "Loss should decrease, not increase"
# assert losses_fixed[-1] < 0.1, "Should converge close to minimum"
# assert LEARNING_RATE_FIXED < LEARNING_RATE, "Fixed lr should be smaller"
# 
# print("✓ Divergence fixed!")
# print(f"✓ Changed lr from {LEARNING_RATE} to {LEARNING_RATE_FIXED}")
# print(f"✓ Loss now decreases from {losses_fixed[0]:.2f} to {losses_fixed[-1]:.6f}")

### Step 3: Write your postmortem

In [None]:
postmortem = """
## Postmortem: The Exploding Loss

### What happened:
- (Your answer: What symptom indicated divergence?)

### Root cause:
- (Your answer: Why was the learning rate too high?)

### How to prevent:
- (Your answer: What's a safe starting point for learning rate? How would you tune it?)

"""

print(postmortem)

---

## ✅ Drill Complete!

**Key lessons:**

1. **Too high learning rate → divergence.** The optimizer overshoots the minimum and bounces away.

2. **The symptom:** Loss increases instead of decreases, or goes to infinity/NaN.

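3. **The fix:** Reduce the learning rate. Values between 0.001 and 0.1 are common starting points, but the right range depends on the model and optimizer, so tune from there.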

4. **Rule of thumb:** If loss explodes, cut learning rate by 10x.
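
The 10x rule of thumb can be sketched as a simple retry loop. This is a minimal illustration on the drill's quadratic loss, not a general tuning procedure. Both `final_loss` and `back_off_lr` are hypothetical helpers introduced here, and the choice of 20 steps and 5 retries is an assumption.

```python
import numpy as np

def final_loss(lr, start=(-2.0, 2.0), n_steps=20):
    """Hypothetical helper: loss after n_steps of GD on f(x, y) = x**2 + y**2."""
    pos = np.array(start, dtype=float)
    for _ in range(n_steps):
        pos = pos - lr * 2 * pos  # gradient of x**2 + y**2 is (2x, 2y)
    return float(np.sum(pos**2))

def back_off_lr(lr, max_tries=5):
    """Hypothetical helper: divide lr by 10 until the loss ends below its start."""
    start_loss = final_loss(lr, n_steps=0)  # zero steps gives the starting loss
    for _ in range(max_tries):
        end_loss = final_loss(lr)
        if np.isfinite(end_loss) and end_loss < start_loss:
            return lr
        lr /= 10  # rule of thumb: cut by 10x after a blow-up
    raise RuntimeError("no stable learning rate found")

chosen = back_off_lr(1.2)
print(f"chosen lr: {chosen:.3g}, final loss: {final_loss(chosen):.3g}")
```

In a real training run each retry costs a partial training job, so a short trial run per candidate is usually enough to tell divergence from progress.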

---

## Learning Rate Troubleshooting

| Symptom | Likely Cause | Fix |
|---------|-------------|-----|
| Loss explodes/NaN | LR too high | Reduce by 10x |
| Loss oscillates wildly | LR slightly too high | Reduce by 2-5x |
| Loss decreases very slowly | LR too low | Increase by 2-10x |
| Loss plateaus early | Possibly LR too low (stuck) | Increase, or use scheduler |
| Loss bounces at end | LR too high for fine-tuning | Use LR decay/scheduler |
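
The "use LR decay/scheduler" fix from the table can be illustrated on the same quadratic loss. The sketch below assumes an exponential schedule `lr_t = 1.0 * 0.9**t`, chosen only for illustration, and a hypothetical helper `run_gd` that takes one learning rate per step. It compares that schedule with a constant `lr = 1.0`, which bounces between two points without making progress.

```python
import numpy as np

def run_gd(lr_schedule, start=(-2.0, 2.0)):
    """Hypothetical helper: GD on f(x, y) = x**2 + y**2 with one lr per step."""
    pos = np.array(start, dtype=float)
    for lr in lr_schedule:
        pos = pos - lr * 2 * pos  # gradient of x**2 + y**2 is (2x, 2y)
    return float(np.sum(pos**2))

n_steps = 50
constant = np.full(n_steps, 1.0)           # lr = 1.0 flips the sign of each coordinate forever
decayed = 1.0 * 0.9 ** np.arange(n_steps)  # exponential decay: lr_t = 1.0 * 0.9**t

loss_constant = run_gd(constant)
loss_decayed = run_gd(decayed)
print(f"constant lr: final loss = {loss_constant:.4g}")
print(f"decayed lr:  final loss = {loss_decayed:.4g}")
```

Starting high and decaying keeps the early steps large while letting later steps settle, which is the behavior the last table row is after.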