# 04: Automatic Differentiation

**Module 1.1: Calculus & Optimization**

## Learning Objectives

By the end of this notebook, you will:
1. Understand why numerical differentiation fails at scale
2. Learn how automatic differentiation (autodiff) works
3. Trace through computation graphs manually
4. Understand forward vs backward mode autodiff
5. See how PyTorch's autograd implements backpropagation

## Resources
- Solomon, *Numerical Algorithms*, §14.3.5
- Ananthaswamy, *Why Machines Learn*, Chapter 5

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(42)
plt.rcParams['figure.figsize'] = (10, 6)

---
## 1. Three Ways to Compute Derivatives

| Method | Accuracy | Speed | How |
|--------|----------|-------|-----|
| **Numerical** | Approximate (truncation + round-off) | $O(n)$ function evals | Finite differences |
| **Symbolic** | Exact | Can suffer expression explosion | Algebraic rules |
| **Autodiff** | Machine precision | One backward pass for a scalar output | Chain rule on graph |

---
## 2. Why Numerical Differentiation Fails

Numerical differentiation: $f'(x) \approx \frac{f(x+\epsilon) - f(x)}{\epsilon}$

### Two Problems:
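1. **Precision:** A large $\epsilon$ gives truncation error, while a very small $\epsilon$ amplifies floating-point round-off (cancellation in $f(x+\epsilon) - f(x)$)
2. **Speed:** A gradient over $n$ parameters needs $n+1$ function evaluations, i.e. $O(n)$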

In [None]:
# Problem 1: Precision issues
def f(x):
    return x**2

x = 1.0
true_derivative = 2.0  # f'(x) = 2x

print("Numerical differentiation precision:")
print(f"{'epsilon':<15} {'approx deriv':<15} {'error':<15}")
print("-" * 45)

for exp in range(1, 17):
    eps = 10**(-exp)
    approx = (f(x + eps) - f(x)) / eps
    error = abs(approx - true_derivative)
    print(f"1e-{exp:<13} {approx:<15.10f} {error:<15.2e}")

print("\n⚠️ Error increases for very small epsilon (floating point issues)!")
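
The same trade-off appears in other finite-difference schemes. As a minimal sketch (the helper names `forward_diff` and `central_diff` are illustrative, and `np.sin` is used as the test function because a central difference has no truncation error on a quadratic like `x**2`), a central difference lowers the truncation error but is still limited by round-off:

```python
import numpy as np

def forward_diff(func, x, eps):
    """One-sided difference: truncation error shrinks like O(eps)."""
    return (func(x + eps) - func(x)) / eps

def central_diff(func, x, eps):
    """Two-sided difference: truncation error shrinks like O(eps**2)."""
    return (func(x + eps) - func(x - eps)) / (2 * eps)

x_test = 1.0
exact = np.cos(x_test)  # d/dx sin(x) = cos(x)

print(f"{'epsilon':<10} {'forward error':<16} {'central error':<16}")
for eps in [1e-2, 1e-5, 1e-8, 1e-12]:
    err_fwd = abs(forward_diff(np.sin, x_test, eps) - exact)
    err_ctr = abs(central_diff(np.sin, x_test, eps) - exact)
    print(f"{eps:<10.0e} {err_fwd:<16.2e} {err_ctr:<16.2e}")
```

Neither variant removes the round-off floor at small $\epsilon$, and both still need extra function evaluations per parameter.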

In [None]:
# Problem 2: Speed - O(n) evaluations
def numerical_gradient(f, x, eps=1e-7):
    """Compute gradient numerically - requires n+1 function evals."""
    n = len(x)
    grad = np.zeros(n)
    f_x = f(x)
    
    for i in range(n):
        x_plus = x.copy()
        x_plus[i] += eps
        grad[i] = (f(x_plus) - f_x) / eps
    
    return grad

# For n=1,000,000 parameters, need 1,000,001 function evaluations!
print("Function evaluations needed for gradient:")
for n in [10, 100, 1000, 1000000]:
    print(f"  n = {n:>10,} params → {n+1:>10,} evaluations")
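
To make the cost concrete, the sketch below counts evaluations for a toy objective and compares the result with a single autograd backward pass. The objective `sum_of_squares`, the `counter` dictionary, and the choice of `n = 1000` are assumptions made for illustration only.

```python
import numpy as np
import torch

counter = {"calls": 0}  # counts evaluations of the NumPy objective

def sum_of_squares(v):
    """Toy objective f(v) = sum(v_i**2), whose gradient is 2*v."""
    counter["calls"] += 1
    return float(np.sum(v ** 2))

n = 1000
v = np.linspace(-1.0, 1.0, n)

# Finite differences: one baseline evaluation plus one per parameter
eps = 1e-6
base = sum_of_squares(v)
grad_fd = np.zeros(n)
for i in range(n):
    v_plus = v.copy()
    v_plus[i] += eps
    grad_fd[i] = (sum_of_squares(v_plus) - base) / eps

# Autograd: one forward pass and one backward pass
v_t = torch.tensor(v, requires_grad=True)
(v_t ** 2).sum().backward()

max_diff = np.max(np.abs(grad_fd - v_t.grad.numpy()))
print(f"finite-difference evaluations: {counter['calls']}")
print(f"max abs difference vs autograd: {max_diff:.2e}")
```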

---
## 3. Automatic Differentiation: The Key Idea

Autodiff builds a **computation graph** during the forward pass, then uses the **chain rule** to compute gradients in the backward pass.

### Example: $y = x_0^2 + x_0 \cdot x_1 + x_1^3$

```
 x0 = 2                 x1 = 3
  │  └───────┐   ┌───────┘  │
  ▼          ▼   ▼          ▼
 (²)          (×)          (³)
  │            │            │
  ▼            ▼            ▼
a = 4        c = 6       b = 27
  │            │            │
  └───────────(+)───────────┘
               │
               ▼
            y = 37
```

In [None]:
# Build computation manually
x0, x1 = 2.0, 3.0

# Forward pass - compute and store intermediate values
a = x0**2       # a = 4
b = x1**3       # b = 27
c = x0 * x1     # c = 6
y = a + c + b   # y = 37

print("Forward pass:")
print(f"  a = x0² = {a}")
print(f"  b = x1³ = {b}")
print(f"  c = x0·x1 = {c}")
print(f"  y = a + c + b = {y}")

---
## 4. Backward Pass: Chain Rule

Start with $\frac{\partial y}{\partial y} = 1$ and propagate backward.

At each node, multiply the **upstream gradient** by the **local derivative**; when a variable feeds several nodes, sum the contributions from each path.
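
For the running example, with $a = x_0^2$, $b = x_1^3$, $c = x_0 x_1$ and $y = a + c + b$, applying this rule along each path to an input and adding the paths together gives the following worked derivation at $(x_0, x_1) = (2, 3)$:

$$\frac{\partial y}{\partial x_0} = \frac{\partial y}{\partial a}\frac{\partial a}{\partial x_0} + \frac{\partial y}{\partial c}\frac{\partial c}{\partial x_0} = 1 \cdot 2x_0 + 1 \cdot x_1 = 4 + 3 = 7$$

$$\frac{\partial y}{\partial x_1} = \frac{\partial y}{\partial b}\frac{\partial b}{\partial x_1} + \frac{\partial y}{\partial c}\frac{\partial c}{\partial x_1} = 1 \cdot 3x_1^2 + 1 \cdot x_0 = 27 + 2 = 29$$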

In [None]:
# Backward pass - manual
x0, x1 = 2.0, 3.0

# Forward pass (storing intermediates)
a = x0**2
b = x1**3
c = x0 * x1
y = a + c + b

# Backward pass
dy_dy = 1.0  # Start here

# y = a + c + b
dy_da = dy_dy * 1  # ∂(a+c+b)/∂a = 1
dy_dc = dy_dy * 1  # ∂(a+c+b)/∂c = 1
dy_db = dy_dy * 1  # ∂(a+c+b)/∂b = 1

# a = x0², so ∂a/∂x0 = 2*x0
da_dx0 = 2 * x0  # = 4

# b = x1³, so ∂b/∂x1 = 3*x1²
db_dx1 = 3 * x1**2  # = 27

# c = x0*x1, so ∂c/∂x0 = x1, ∂c/∂x1 = x0
dc_dx0 = x1  # = 3
dc_dx1 = x0  # = 2

# Combine paths to x0 (from a and c)
dy_dx0 = dy_da * da_dx0 + dy_dc * dc_dx0
# = 1 * 4 + 1 * 3 = 7

# Combine paths to x1 (from b and c)
dy_dx1 = dy_db * db_dx1 + dy_dc * dc_dx1
# = 1 * 27 + 1 * 2 = 29

print("Backward pass (manual):")
print(f"  ∂y/∂x0 = {dy_dx0}")
print(f"  ∂y/∂x1 = {dy_dx1}")

# Verify with PyTorch
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x[0]**2 + x[0]*x[1] + x[1]**3
y.backward()

print(f"\nPyTorch autograd:")
print(f"  ∂y/∂x0 = {x.grad[0].item()}")
print(f"  ∂y/∂x1 = {x.grad[1].item()}")
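# Only report a match if the manual and autograd gradients actually agree
assert torch.allclose(x.grad, torch.tensor([dy_dx0, dy_dx1]))
print("\n✓ Match!")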
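
The manual bookkeeping above can be automated. The following is a minimal sketch of a scalar reverse-mode engine, intended only to illustrate the idea and not to mirror how PyTorch is implemented. The class name `Var` and its attributes are hypothetical, and only `+`, `*` and constant powers are supported.

```python
class Var:
    """Scalar graph node that records its parents and local derivatives."""

    def __init__(self, value, parents=()):
        self.value = value
        self.grad = 0.0
        self.parents = parents  # tuple of (parent_node, local_derivative)

    def __add__(self, other):
        return Var(self.value + other.value, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return Var(self.value * other.value,
                   ((self, other.value), (other, self.value)))

    def __pow__(self, n):
        return Var(self.value ** n, ((self, n * self.value ** (n - 1)),))

    def backward(self):
        # Order nodes so that each node comes after all of its parents
        order, seen = [], set()

        def visit(node):
            if id(node) in seen:
                return
            seen.add(id(node))
            for parent, _ in node.parents:
                visit(parent)
            order.append(node)

        visit(self)
        self.grad = 1.0  # dy/dy = 1
        for node in reversed(order):
            for parent, local in node.parents:
                # upstream gradient times local derivative, summed over paths
                parent.grad += node.grad * local


x0, x1 = Var(2.0), Var(3.0)
y = x0 ** 2 + x0 * x1 + x1 ** 3
y.backward()
print(f"y = {y.value}, dy/dx0 = {x0.grad}, dy/dx1 = {x1.grad}")
```

Processing nodes in reverse topological order ensures a node's gradient is complete before it is passed on to its parents.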

---
## 5. Forward vs Backward Mode

For a function with $n$ inputs and $m$ outputs:

| Mode | Direction | Cost (all derivatives) | Best For |
|------|-----------|------|----------|
| **Forward** | Input → Output | $O(n)$ passes | Few inputs, many outputs |
| **Backward** | Output → Input | $O(m)$ passes | Many inputs, few outputs |

**Neural networks:** Millions of inputs (parameters), ONE output (loss)

→ **Backward mode wins!** This is why it's called **back**propagation.

In [None]:
# Comparison: Forward mode would need n passes
# Backward mode needs 1 pass

n_params = [10, 100, 1000, 10000, 100000, 1000000]

print("Passes needed for full gradient:")
print(f"{'# params':<12} {'Forward mode':<15} {'Backward mode':<15}")
print("-" * 42)
for n in n_params:
    print(f"{n:<12,} {n:<15,} {1:<15}")

print("\n→ Backward mode needs a single pass for a scalar loss, regardless of parameter count!")
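
The pass counts can be seen on the two-input example from earlier. This sketch uses `torch.autograd.functional.jvp` (one call per input direction) and `torch.autograd.functional.vjp` (one call for the scalar output). The function name `y_fn` is illustrative. Note that the documentation for this `jvp` helper describes it as built on a double-backward trick, so the sketch illustrates how many passes each mode needs, not forward-mode performance.

```python
import torch
from torch.autograd.functional import jvp, vjp

def y_fn(x):
    """Same example as before: y = x0^2 + x0*x1 + x1^3."""
    return x[0] ** 2 + x[0] * x[1] + x[1] ** 3

x = torch.tensor([2.0, 3.0])

# Forward-mode style: one Jacobian-vector product per input direction
grad_forward = torch.stack([jvp(y_fn, x, v=e)[1] for e in torch.eye(2)])

# Backward-mode style: one vector-Jacobian product for the single output
_, grad_backward = vjp(y_fn, x, v=torch.tensor(1.0))

print(f"forward  (2 passes): {grad_forward}")
print(f"backward (1 pass):   {grad_backward}")
```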

---
## 6. Forward Mode: Dual Numbers

Forward-mode autodiff propagates "dual numbers" $[u, u']$, pairing each value $u$ with its derivative $u' = \frac{du}{dt}$ with respect to a chosen input $t$.

### Rules:
- $[u, u'] + [v, v'] = [u+v, u'+v']$
- $[u, u'] \cdot [v, v'] = [uv, uv' + u'v]$
- $\exp([u, u']) = [e^u, u'e^u]$

In [None]:
class DualNumber:
    """Simple dual number for forward-mode autodiff."""
    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv
    
    def __add__(self, other):
        if isinstance(other, (int, float)):
            return DualNumber(self.value + other, self.deriv)
        return DualNumber(self.value + other.value, self.deriv + other.deriv)
    
    def __radd__(self, other):
        return self.__add__(other)
    
    def __mul__(self, other):
        if isinstance(other, (int, float)):
            return DualNumber(self.value * other, self.deriv * other)
        return DualNumber(self.value * other.value, 
                         self.value * other.deriv + self.deriv * other.value)
    
    def __rmul__(self, other):
        return self.__mul__(other)
    
    def __pow__(self, n):
        return DualNumber(self.value**n, n * self.value**(n-1) * self.deriv)
    
    def __repr__(self):
        return f"Dual({self.value}, {self.deriv})"

# Example: f(x) = x³ + 2x, find f'(3)
x = DualNumber(3, 1)  # x = 3, dx/dx = 1
y = x**3 + 2*x

print(f"f(x) = x³ + 2x at x = 3")
print(f"Result: {y}")
print(f"f(3) = {y.value}, f'(3) = {y.deriv}")
print(f"\nVerify: f'(x) = 3x² + 2, f'(3) = 3(9) + 2 = 29 ✓")
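
The $\exp$ rule listed above can be sketched the same way. The block below is self-contained, so it redefines a stripped-down dual number under the hypothetical name `Dual` and adds a hypothetical helper `dual_exp`; it is an illustration of the rule, not an extension of any library.

```python
import math

class Dual:
    """Stripped-down dual number: just enough for multiplication and exp."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    def __mul__(self, other):
        # Product rule: (uv)' = u v' + u' v
        return Dual(self.value * other.value,
                    self.value * other.deriv + self.deriv * other.value)

def dual_exp(d):
    """exp([u, u']) = [e^u, u' e^u]"""
    e_u = math.exp(d.value)
    return Dual(e_u, d.deriv * e_u)

# g(t) = t * exp(t), so g'(t) = (1 + t) * exp(t)
t = Dual(1.0, 1.0)  # seed dt/dt = 1
g = t * dual_exp(t)

print(f"g(1)  = {g.value:.6f}")
print(f"g'(1) = {g.deriv:.6f}  (analytic: {2 * math.e:.6f})")
```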

---
## 7. PyTorch Autograd Internals

When a tensor has `requires_grad=True`, PyTorch:
1. Builds a computation graph during the forward pass
2. Stores `grad_fn` (the backward function) on each tensor produced by an operation; leaf tensors such as `x` have `grad_fn = None`
3. Traverses the graph in reverse when you call `.backward()`

In [None]:
# Peek inside PyTorch's computation graph
x = torch.tensor([2.0, 3.0], requires_grad=True)

# Build graph
a = x[0]**2
b = x[1]**3
c = x[0] * x[1]
y = a + b + c

print("Computation graph:")
print(f"  x.grad_fn = {x.grad_fn}")
print(f"  a.grad_fn = {a.grad_fn}")
print(f"  b.grad_fn = {b.grad_fn}")
print(f"  c.grad_fn = {c.grad_fn}")
print(f"  y.grad_fn = {y.grad_fn}")

print("\nBackward graph (y's inputs):")
print(f"  {y.grad_fn.next_functions}")
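
Following `next_functions` recursively shows the whole backward graph, not only the top node. A minimal sketch, assuming a hypothetical helper named `walk_graph`; the exact node names it prints (for example `AddBackward0`) depend on the PyTorch version:

```python
import torch

def walk_graph(fn, depth=0, lines=None):
    """Collect the names of backward nodes reachable from `fn`, indented by depth."""
    if lines is None:
        lines = []
    if fn is None:  # inputs that do not require grad have no backward node
        return lines
    lines.append("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        walk_graph(next_fn, depth + 1, lines)
    return lines

x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x[0] ** 2 + x[1] ** 3 + x[0] * x[1]

graph_lines = walk_graph(y.grad_fn)
print("\n".join(graph_lines))
```

The `AccumulateGrad` nodes at the bottom of the printout are where gradients are written into the leaf tensor's `.grad`.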

In [None]:
# Hook to see gradients flowing through
def print_grad(name):
    def hook(grad):
        print(f"  Gradient at {name}: {grad}")
    return hook

x = torch.tensor([2.0, 3.0], requires_grad=True)
a = x[0]**2
b = x[1]**3  
c = x[0] * x[1]

# Register hooks
a.register_hook(print_grad('a'))
b.register_hook(print_grad('b'))
c.register_hook(print_grad('c'))

y = a + b + c

print("Backward pass gradients:")
y.backward()
print(f"\nFinal gradient: x.grad = {x.grad}")
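
Hooks are one way to observe intermediate gradients. Another option, sketched here, is `Tensor.retain_grad()`, which asks autograd to keep `.grad` on a non-leaf tensor after the backward pass (by default only leaf tensors keep it):

```python
import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)
a = x[0] ** 2
c = x[0] * x[1]

# Ask autograd to store gradients on these intermediate tensors
a.retain_grad()
c.retain_grad()

y = a + c + x[1] ** 3
y.backward()

print(f"a.grad = {a.grad}")  # dy/da
print(f"c.grad = {c.grad}")  # dy/dc
print(f"x.grad = {x.grad}")
```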

---
## 8. Genomics Application: Gradient of VAE Loss

In scVI, the loss combines reconstruction + KL divergence:

$$L = -\mathbb{E}_{q(z|x)}[\log p(x|z)] + \text{KL}(q(z|x) \| p(z))$$

Autodiff handles this complex loss automatically!

In [None]:
# Simplified VAE loss
def vae_loss(x, x_recon, mu, logvar):
    """VAE loss = Reconstruction + KL divergence."""
    # Reconstruction loss (MSE for simplicity)
    recon_loss = torch.mean((x - x_recon)**2)
    
    # KL divergence to N(0, I): -0.5 * sum(1 + log(σ²) - μ² - σ²),
    # averaged over latent dimensions here (mean instead of sum) for simplicity
    kl_loss = -0.5 * torch.mean(1 + logvar - mu**2 - torch.exp(logvar))
    
    return recon_loss + kl_loss

# Simulated data
x = torch.randn(100)  # Gene expression
x_recon = torch.randn(100, requires_grad=True)
mu = torch.randn(10, requires_grad=True)
logvar = torch.randn(10, requires_grad=True)

loss = vae_loss(x, x_recon, mu, logvar)
loss.backward()

print(f"VAE Loss: {loss.item():.4f}")
print(f"\nGradients computed automatically:")
print(f"  ∂L/∂x_recon shape: {x_recon.grad.shape}")
print(f"  ∂L/∂mu shape: {mu.grad.shape}")
print(f"  ∂L/∂logvar shape: {logvar.grad.shape}")
print("\n✓ Complex loss, simple gradient computation!")
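
As a sanity check, the KL term is simple enough to differentiate by hand. With the mean over $d$ latent dimensions used in `vae_loss`, the gradients are $\partial \text{KL}/\partial \mu_i = \mu_i / d$ and $\partial \text{KL}/\partial \log\sigma_i^2 = \tfrac{1}{2}(e^{\log\sigma_i^2} - 1)/d$. The sketch below compares these formulas with autograd; the seed and tensor sizes are arbitrary choices.

```python
import torch

torch.manual_seed(0)
mu = torch.randn(10, requires_grad=True)
logvar = torch.randn(10, requires_grad=True)

# Same KL expression as in vae_loss (mean over latent dimensions)
kl = -0.5 * torch.mean(1 + logvar - mu ** 2 - torch.exp(logvar))
kl.backward()

d = mu.numel()
expected_mu = mu.detach() / d
expected_logvar = 0.5 * (torch.exp(logvar.detach()) - 1) / d

print(f"mu grad matches:     {torch.allclose(mu.grad, expected_mu)}")
print(f"logvar grad matches: {torch.allclose(logvar.grad, expected_logvar)}")
```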

---
## Exercises

### Exercise 1: Manual Backprop
For $y = (x_1 + x_2) \cdot x_3$ at $(1, 2, 3)$, trace through the backward pass manually.

### Exercise 2: Dual Numbers
Extend the DualNumber class to support `sin` and `cos`.

### Exercise 3: Graph Inspection
Build a 3-layer network and print the `grad_fn` chain.

### Exercise 4: Gradient Checkpointing
Research: What is gradient checkpointing and why is it useful for large models?

In [None]:
# Your solutions here


---
## Summary

| Concept | Key Point |
|---------|----------|
| **Numerical diff** | $O(n)$ evals, precision issues |
| **Autodiff** | Exact to machine precision; one backward pass for a scalar output |
| **Computation graph** | Stores values + local derivatives |
| **Backward pass** | upstream × local, sum at branches |
| **Forward mode** | Dual numbers $[u, u']$ |
| **Backward mode** | Better for many params → 1 loss |

## Next: 05_genomics_applications.ipynb