# Task 6.3: Autograd Deep Dive - Custom Automatic Differentiation

**Module:** 6 - Deep Learning with PyTorch  
**Time:** 2 hours  
**Difficulty:** ⭐⭐⭐ (Intermediate-Advanced)

---

## Learning Objectives

By the end of this notebook, you will:
- [ ] Understand how PyTorch's autograd system works
- [ ] Implement custom autograd functions with `torch.autograd.Function`
- [ ] Create novel activation functions (Swish, Mish)
- [ ] Verify gradients using `torch.autograd.gradcheck`
- [ ] Use hooks for model introspection

---

## Prerequisites

- Completed: Tasks 6.1, 6.2
- Knowledge of: Calculus (derivatives), backpropagation

---

## Real-World Context

Sometimes PyTorch's built-in operations aren't enough. You might need custom autograd functions for:

- **Novel activation functions**: Research new architectures
- **Memory-efficient backprop**: Gradient checkpointing
- **Custom loss functions**: Domain-specific objectives
- **Numerical stability**: Special handling for edge cases
- **Hardware optimization**: Custom CUDA kernels

Companies like OpenAI, DeepMind, and Meta regularly implement custom autograd functions for their research.

---

## ELI5: What is Automatic Differentiation?

> **Imagine you're tracking a recipe...** 📝
>
> You make a cake by:
> 1. Mix flour and eggs → batter
> 2. Add sugar → sweet batter
> 3. Bake → cake
>
> Now the cake came out wrong. You want to know: "How much did each ingredient affect the final result?"
>
> Autograd is like keeping a detailed journal of every step. When you taste the cake, you can trace back:
> - "The cake is too sweet → sweet batter contributed too much → I used too much sugar!"
>
> **In AI terms:** Each operation records how it transforms inputs. During backpropagation, we reverse the journal to compute gradients (how much each parameter affected the loss).
>
> The "computational graph" is your recipe journal!

---

## Part 1: Understanding Autograd Basics

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function, gradcheck
import numpy as np
import matplotlib.pyplot as plt
import time

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

In [None]:
# Basic autograd example
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x ** 2 + 2 * x + 1  # Polynomial: x² + 2x + 1
z = y.sum()  # Need scalar for backward

print(f"x = {x}")
print(f"y = x² + 2x + 1 = {y}")
print(f"z = sum(y) = {z}")

# Compute gradients
z.backward()

# Gradient should be: dy/dx = 2x + 2
# At x = [2, 3]: gradients = [6, 8]
print(f"\nGradient (dy/dx = 2x + 2): {x.grad}")
print(f"Expected: {2 * x.detach() + 2}")
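Two details of `.grad` are worth seeing in code before moving on: gradients *accumulate* across `backward()` calls, and `torch.autograd.grad` can return a gradient without touching `.grad` at all. A minimal sketch reusing the same polynomial:

```python
import torch

# Same polynomial as above: y = x^2 + 2x + 1, so dy/dx = 2x + 2
x = torch.tensor([2.0, 3.0], requires_grad=True)

# Two backward calls on two freshly built graphs ADD into x.grad
(x ** 2 + 2 * x + 1).sum().backward()
(x ** 2 + 2 * x + 1).sum().backward()
accumulated = x.grad.clone()  # 2 * [6., 8.] = [12., 16.]

# Reset before the next step, similar to what optimizer.zero_grad() does
x.grad = None

# torch.autograd.grad returns the gradient instead of accumulating into .grad
(grad,) = torch.autograd.grad((x ** 2 + 2 * x + 1).sum(), x)
print(f"accumulated: {accumulated}, autograd.grad: {grad}, x.grad: {x.grad}")
```

This accumulation is why training loops clear gradients before every backward pass.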

### The Computational Graph

Every tensor with `requires_grad=True` records operations in a DAG (Directed Acyclic Graph). Let's inspect one:

In [None]:
# Exploring the computational graph
a = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([2.0], requires_grad=True)

c = a * b  # Multiplication
d = c + a  # Addition
e = d.relu()  # ReLU activation

print("=== Computational Graph ===")
print(f"a: {a}, requires_grad={a.requires_grad}")
print(f"b: {b}, requires_grad={b.requires_grad}")
print(f"c = a * b = {c}, grad_fn={c.grad_fn}")
print(f"d = c + a = {d}, grad_fn={d.grad_fn}")
print(f"e = relu(d) = {e}, grad_fn={e.grad_fn}")

# Backpropagate
e.backward()

print("\nGradients:")
print(f"de/da = {a.grad}")
print(f"de/db = {b.grad}")

### What Just Happened?

The `grad_fn` attribute shows which operation created each tensor:
- `MulBackward0` - multiplication
- `AddBackward0` - addition
- `ReluBackward0` - ReLU

When we call `backward()`, PyTorch traverses this graph in reverse, applying the chain rule!
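To see the graph itself, you can follow each node's `next_functions` back to the leaves, where `AccumulateGrad` nodes write into `.grad`. The sketch below rebuilds the same `a`, `b`, `e` example and also checks the chain rule by hand; `walk_graph` is a hypothetical helper name, not a PyTorch API:

```python
import torch

def walk_graph(fn, depth=0, lines=None):
    """Hypothetical helper: collect backward-node names, indented by depth."""
    if lines is None:
        lines = []
    if fn is None:  # inputs that don't require grad contribute no node
        return lines
    lines.append("  " * depth + type(fn).__name__)
    for next_fn, _ in fn.next_functions:
        walk_graph(next_fn, depth + 1, lines)
    return lines

a = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([2.0], requires_grad=True)
e = (a * b + a).relu()

tree = walk_graph(e.grad_fn)
print("\n".join(tree))

# Chain rule by hand: a*b + a = 3 > 0, so relu'(3) = 1, giving
#   de/da = 1 * (b + 1) = 3   and   de/db = 1 * a = 1
e.backward()
print(f"de/da = {a.grad.item()}, de/db = {b.grad.item()}")
```

`a` appears twice in the tree because it feeds both the multiplication and the addition; autograd sums both contributions into `a.grad`.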

---

## Part 2: Custom Autograd Functions

To create a custom operation with gradients, we extend `torch.autograd.Function`:

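```python
import torch
from torch.autograd import Function

class MyFunction(Function):
    @staticmethod
    def forward(ctx, input):
        # Compute output (illustrative: output = input ** 2)
        output = input ** 2
        # Save tensors for backward using ctx.save_for_backward()
        ctx.save_for_backward(input)
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve saved tensors
        (input,) = ctx.saved_tensors
        # Gradient w.r.t. input, multiplied by the upstream gradient (chain rule)
        grad_input = grad_output * 2 * input
        return grad_input

# Call a Function through .apply, not by instantiating it
y = MyFunction.apply(torch.tensor([3.0], requires_grad=True))
```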

Let's implement some custom activation functions!

### Custom Activation 1: Swish

**Swish** (also called SiLU) was discovered by Google Brain through neural architecture search.

$$\text{Swish}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}$$

Derivative:
$$\frac{d}{dx}\text{Swish}(x) = \sigma(x) + x \cdot \sigma(x) \cdot (1 - \sigma(x))$$
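This derivative follows from the product rule plus the sigmoid identity $\sigma'(x) = \sigma(x)(1 - \sigma(x))$. Factoring out $\sigma(x)$ gives the compact form used in the `backward` method below:

$$
\begin{aligned}
\frac{d}{dx}\big[x \cdot \sigma(x)\big] &= \sigma(x) + x \cdot \sigma'(x) \\
&= \sigma(x) + x \cdot \sigma(x)\,(1 - \sigma(x)) \\
&= \sigma(x)\,\big(1 + x\,(1 - \sigma(x))\big)
\end{aligned}
$$

For example, at $x = 0$ this gives $\sigma(0) \cdot 1 = \tfrac{1}{2}$.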

In [None]:
class SwishFunction(Function):
    """
    Custom autograd function for Swish activation.
    
    Swish(x) = x * sigmoid(x)
    
    This is also known as SiLU (Sigmoid Linear Unit) and is used
    in modern architectures like EfficientNet and Transformers.
    """
    
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass: compute Swish(x) = x * sigmoid(x)
        
        Args:
            ctx: Context object to save tensors for backward
            x: Input tensor
        
        Returns:
            Swish activation applied to x
        """
        sigmoid_x = torch.sigmoid(x)
        result = x * sigmoid_x
        
        # Save sigmoid for backward (saves recomputing it)
        ctx.save_for_backward(x, sigmoid_x)
        
        return result
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """
        Backward pass: compute gradient of Swish.
        
        d/dx Swish(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
                      = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
        
        Args:
            ctx: Context with saved tensors
            grad_output: Gradient from downstream
        
        Returns:
            Gradient with respect to input
        """
        x, sigmoid_x = ctx.saved_tensors
        
        # Compute gradient
        # d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))
        grad_input = sigmoid_x * (1 + x * (1 - sigmoid_x))
        
        # Chain rule: multiply by upstream gradient
        return grad_output * grad_input


# Create a convenient wrapper
def swish_custom(x: torch.Tensor) -> torch.Tensor:
    """Apply custom Swish activation."""
    return SwishFunction.apply(x)


# Test our implementation
x = torch.randn(5, requires_grad=True)
y = swish_custom(x)
y.sum().backward()

print("=== Custom Swish Test ===")
print(f"Input: {x.data}")
print(f"Output: {y.data}")
print(f"Gradient: {x.grad}")
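As an extra sanity check that does not depend on our class, we can compare the analytic derivative against autograd applied to PyTorch's built-in `F.silu`, which computes the same $x \cdot \sigma(x)$. A sketch (the helper name `swish_grad_analytic` is ours):

```python
import torch
import torch.nn.functional as F

def swish_grad_analytic(x: torch.Tensor) -> torch.Tensor:
    """Same formula as SwishFunction.backward: sigmoid(x) * (1 + x * (1 - sigmoid(x)))."""
    s = torch.sigmoid(x)
    return s * (1 + x * (1 - s))

x = torch.linspace(-6, 6, 101, dtype=torch.float64, requires_grad=True)

# Let autograd differentiate the built-in SiLU (Swish with beta = 1)
(autograd_grad,) = torch.autograd.grad(F.silu(x).sum(), x)

max_err = (autograd_grad - swish_grad_analytic(x.detach())).abs().max().item()
print(f"max |autograd - analytic| = {max_err:.2e}")
```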

### Verifying Gradients with gradcheck

`gradcheck` computes numerical gradients (finite differences) and compares them to our analytical gradients.

In [None]:
# Verify gradients using numerical differentiation
x_test = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)

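# gradcheck needs float64 inputs: float32 finite differences are too imprecise.
# raise_exception=False makes it return False on a mismatch instead of raising.
test_passed = gradcheck(SwishFunction.apply, (x_test,), eps=1e-6, atol=1e-4, rtol=1e-3, raise_exception=False)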

print(f"Gradient check passed: {test_passed}")

if test_passed:
    print("Our custom gradient is correct!")
else:
    print("WARNING: Gradient mismatch detected!")
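Under the hood, the idea is central finite differences, $f'(x) \approx \frac{f(x+\epsilon) - f(x-\epsilon)}{2\epsilon}$. A simplified sketch for Swish follows. It assumes an elementwise function (so the Jacobian is diagonal and every element can be perturbed at once); `gradcheck` itself perturbs one input element at a time and compares full Jacobians.

```python
import torch

def swish(x: torch.Tensor) -> torch.Tensor:
    return x * torch.sigmoid(x)

def central_difference(f, x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Elementwise (f(x + eps) - f(x - eps)) / (2 * eps).

    Only valid for elementwise f, whose Jacobian is diagonal.
    """
    return (f(x + eps) - f(x - eps)) / (2 * eps)

torch.manual_seed(0)
x = torch.randn(3, 4, dtype=torch.float64)

s = torch.sigmoid(x)
analytic = s * (1 + x * (1 - s))  # same formula as SwishFunction.backward
numeric = central_difference(swish, x)

ok = torch.allclose(numeric, analytic, atol=1e-6, rtol=1e-5)
print(f"max abs error: {(numeric - analytic).abs().max():.2e}, close: {ok}")
```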

### Custom Activation 2: Mish

**Mish** is a self-regularized non-monotonic activation function.

$$\text{Mish}(x) = x \cdot \tanh(\text{softplus}(x)) = x \cdot \tanh(\ln(1 + e^x))$$

The derivative is more complex:
$$\frac{d}{dx}\text{Mish}(x) = \frac{e^x \cdot \omega}{\delta^2}$$

where:
- $\omega = 4(x+1) + 4e^{2x} + e^{3x} + e^x(4x+6)$
- $\delta = 2e^x + e^{2x} + 2$
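This closed form is algebraically the same as the chain-rule form $\tanh(\text{sp}) + x \cdot \text{sech}^2(\text{sp}) \cdot \sigma(x)$, with $\text{sp} = \text{softplus}(x)$, used in the backward pass below. A quick numerical check, assuming PyTorch's built-in `F.mish` as a third reference:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-5, 5, 201, dtype=torch.float64, requires_grad=True)
xd = x.detach()

# Closed form from the Mish paper: e^x * omega / delta^2
ex = torch.exp(xd)
omega = 4 * (xd + 1) + 4 * torch.exp(2 * xd) + torch.exp(3 * xd) + ex * (4 * xd + 6)
delta = 2 * ex + torch.exp(2 * xd) + 2
closed_form = ex * omega / delta ** 2

# Chain-rule form: tanh(sp) + x * sech^2(sp) * sigmoid(x)
tanh_sp = torch.tanh(F.softplus(xd))
chain_rule_form = tanh_sp + xd * (1 - tanh_sp ** 2) * torch.sigmoid(xd)

# Autograd through the built-in F.mish
(autograd_form,) = torch.autograd.grad(F.mish(x).sum(), x)

print(torch.allclose(closed_form, chain_rule_form), torch.allclose(closed_form, autograd_form))
```

At $x = 0$, $\omega = 15$ and $\delta = 5$, so the derivative is $15/25 = 0.6$, which matches $\tanh(\ln 2) = 3/5$ from the chain-rule form.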

In [None]:
class MishFunction(Function):
    """
    Custom autograd function for Mish activation.
    
    Mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x))
    
    Paper: "Mish: A Self Regularized Non-Monotonic Activation Function"
    https://arxiv.org/abs/1908.08681
    """
    
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        """Forward pass: compute Mish(x)."""
        # softplus(x) = ln(1 + e^x)
        softplus_x = F.softplus(x)
        tanh_softplus = torch.tanh(softplus_x)
        result = x * tanh_softplus
        
        # Save for backward
        ctx.save_for_backward(x, tanh_softplus)
        
        return result
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """
        Backward pass: compute gradient of Mish.
        
        Using the formula:
        d/dx Mish(x) = tanh(sp) + x * sech²(sp) * σ(x)
        where sp = softplus(x), σ = sigmoid
        """
        x, tanh_sp = ctx.saved_tensors
        
        # Recompute sigmoid(x), the derivative of softplus(x)
        sigmoid_x = torch.sigmoid(x)
        
        # sech²(u) = 1 - tanh²(u)
        sech2_sp = 1 - tanh_sp ** 2
        
        # Gradient: tanh(softplus(x)) + x * sech²(softplus(x)) * sigmoid(x)
        grad_input = tanh_sp + x * sech2_sp * sigmoid_x
        
        return grad_output * grad_input


def mish_custom(x: torch.Tensor) -> torch.Tensor:
    """Apply custom Mish activation."""
    return MishFunction.apply(x)


# Verify Mish gradients
x_test = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
test_passed = gradcheck(MishFunction.apply, (x_test,), eps=1e-6, atol=1e-4, rtol=1e-3)
print(f"Mish gradient check passed: {test_passed}")

### Visualizing Activation Functions

In [None]:
# Compare activation functions
x = torch.linspace(-5, 5, 200)

activations = {
    'ReLU': F.relu(x),
    'Swish (Custom)': swish_custom(x.clone().requires_grad_(True)).detach(),
    'Mish (Custom)': mish_custom(x.clone().requires_grad_(True)).detach(),
    'GELU': F.gelu(x),
    'Tanh': torch.tanh(x),
}

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Plot activations
ax1 = axes[0]
for name, y in activations.items():
    ax1.plot(x.numpy(), y.numpy(), label=name, linewidth=2)
ax1.axhline(y=0, color='k', linestyle='-', linewidth=0.5)
ax1.axvline(x=0, color='k', linestyle='-', linewidth=0.5)
ax1.set_xlabel('x')
ax1.set_ylabel('f(x)')
ax1.set_title('Activation Functions')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Plot derivatives
ax2 = axes[1]

# Compute derivatives with autograd (each function's backward) for visualization
for name, _ in activations.items():
    x_grad = x.clone().requires_grad_(True)
    if name == 'ReLU':
        y = F.relu(x_grad)
    elif 'Swish' in name:
        y = swish_custom(x_grad)
    elif 'Mish' in name:
        y = mish_custom(x_grad)
    elif name == 'GELU':
        y = F.gelu(x_grad)
    else:  # Tanh
        y = torch.tanh(x_grad)
    
    y.sum().backward()
    ax2.plot(x.numpy(), x_grad.grad.numpy(), label=name, linewidth=2)

ax2.axhline(y=0, color='k', linestyle='-', linewidth=0.5)
ax2.axhline(y=1, color='gray', linestyle='--', linewidth=0.5)
ax2.axvline(x=0, color='k', linestyle='-', linewidth=0.5)
ax2.set_xlabel('x')
ax2.set_ylabel("f'(x)")
ax2.set_title('Activation Derivatives')
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

### Key Observations:

1. **ReLU**: Derivative is 0 for x < 0 ("dead neurons" problem)
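2. **Swish/Mish**: Smooth and non-monotonic; they dip slightly below zero for negative inputs
3. **Tanh**: Derivatives saturate to 0 for large |x| (vanishing gradients)

Because Swish and Mish keep small but non-zero gradients for moderately negative inputs, their units are less prone to "dying" than ReLU units, and their derivatives are continuous everywhere.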
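These observations can be checked numerically. A small sketch that locates the negative dip of each function and compares gradients at $x = -2$:

```python
import torch
import torch.nn.functional as F

x = torch.linspace(-5, 5, 10001, dtype=torch.float64)
swish = x * torch.sigmoid(x)
mish = x * torch.tanh(F.softplus(x))

# Both dip slightly below zero for negative inputs (non-monotonic)
print(f"Swish min {swish.min():.4f} at x = {x[swish.argmin()]:.3f}")
print(f"Mish  min {mish.min():.4f} at x = {x[mish.argmin()]:.3f}")

# Gradients at x = -2: ReLU's is exactly 0, Swish's and Mish's are small but non-zero
x0 = torch.tensor(-2.0, dtype=torch.float64, requires_grad=True)
relu_g = torch.autograd.grad(F.relu(x0), x0)[0]
swish_g = torch.autograd.grad(x0 * torch.sigmoid(x0), x0)[0]
mish_g = torch.autograd.grad(F.mish(x0), x0)[0]
print(f"grad at -2: ReLU {relu_g.item():.4f}, Swish {swish_g.item():.4f}, Mish {mish_g.item():.4f}")
```

Swish bottoms out near -0.28 and Mish near -0.31, both a little past $x = -1$.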

---

## Part 3: Using Hooks for Introspection

Hooks let you inspect or modify:
- **Forward hooks**: Activations during forward pass
- **Backward hooks**: Gradients during backward pass

This is useful for debugging and for understanding model behavior, for example to spot layers whose activations or gradients vanish or explode.

In [None]:
class SimpleNet(nn.Module):
    """Simple network for hook demonstration."""
    
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 10)
        self.fc3 = nn.Linear(10, 2)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = SimpleNet()
print(model)

In [None]:
# Storage for activations and gradients
activations = {}
gradients = {}

def get_activation_hook(name):
    """Create a forward hook that saves activations."""
    def hook(module, input, output):
        activations[name] = output.detach()
    return hook

def get_gradient_hook(name):
    """Create a backward hook that saves the gradient w.r.t. the layer's output."""
    def hook(module, grad_input, grad_output):
        gradients[name] = grad_output[0].detach()
    return hook

# Register hooks
handles = []
for name, layer in model.named_modules():
    if isinstance(layer, nn.Linear):
        handles.append(layer.register_forward_hook(get_activation_hook(name)))
        handles.append(layer.register_full_backward_hook(get_gradient_hook(name)))

# Run forward and backward
x = torch.randn(1, 10)
y = model(x)
loss = y.sum()
loss.backward()

print("=== Activations ===")
for name, act in activations.items():
    print(f"{name}: shape={act.shape}, mean={act.mean():.4f}, std={act.std():.4f}")

print("\n=== Gradients ===")
for name, grad in gradients.items():
    print(f"{name}: shape={grad.shape}, mean={grad.mean():.4f}, std={grad.std():.4f}")

# Clean up hooks
for h in handles:
    h.remove()
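Module hooks only fire for `nn.Module` instances, so the functional `F.relu` calls inside `SimpleNet` cannot be hooked directly. With `nn.ReLU` modules, a forward hook can flag units that never activate. A sketch, where "dead" means zero for every sample in the batch (a heuristic assumption) and `make_relu_monitor` is a hypothetical helper:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(10, 64), nn.ReLU(), nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 2))

dead_fraction = {}

def make_relu_monitor(name):
    """Forward hook recording the fraction of units that are zero for the whole batch."""
    def hook(module, inputs, output):
        dead = (output <= 0).all(dim=0)  # one bool per unit
        dead_fraction[name] = dead.float().mean().item()
    return hook

handles = [m.register_forward_hook(make_relu_monitor(n))
           for n, m in net.named_modules() if isinstance(m, nn.ReLU)]

with torch.no_grad():
    net(torch.randn(256, 10))

for h in handles:
    h.remove()  # always detach hooks when done

print(dead_fraction)  # keys are the Sequential indices of the ReLU modules
```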

### Practical Application: Gradient Clipping Hook

In [None]:
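def gradient_clip_hook(max_norm: float):
    """
    Create a tensor hook that rescales a gradient whose L2 norm exceeds max_norm.
    
    Registered with `param.register_hook(...)`, the returned tensor replaces the
    gradient before it is accumulated into `param.grad`. This can help prevent
    exploding gradients!
    """
    def hook(grad):
        norm = grad.norm()
        if norm > max_norm:
            return grad * (max_norm / norm)
        return grad  # small gradients pass through unchanged
    return hook

# Apply to fc1's weight and bias. A module-level full backward hook would not
# work here: the tuple it returns replaces the gradient w.r.t. the module's
# *inputs*, not its parameters, so fc1.weight.grad would stay unclipped.
model2 = SimpleNet()
handles = [p.register_hook(gradient_clip_hook(1.0)) for p in model2.fc1.parameters()]

# Test
x = torch.randn(1, 10)
y = model2(x)
(y * 1000).sum().backward()  # Large gradient

print(f"fc1 weight grad norm (clipped): {model2.fc1.weight.grad.norm():.4f}")
print(f"fc2 weight grad norm (no hook): {model2.fc2.weight.grad.norm():.4f}")

for h in handles:
    h.remove()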
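In everyday training, clipping is usually applied after `backward()` with `torch.nn.utils.clip_grad_norm_`, which rescales all parameters' gradients together by their global L2 norm (rather than each tensor separately, as the hook above does). A minimal sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))

# Deliberately huge loss -> huge gradients
loss = (net(torch.randn(8, 10)) * 1000).sum()
loss.backward()

# Rescale all gradients together so their combined L2 norm is at most max_norm.
# The return value is the total norm measured *before* clipping.
total_norm_before = torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)

# Recompute the global norm to confirm the clip
total_norm_after = torch.norm(torch.stack([p.grad.norm() for p in net.parameters()]))
print(f"before: {total_norm_before:.2f}, after: {total_norm_after:.4f}")
```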

---

## Part 4: Benchmarking Custom vs Built-in

Let's compare our custom implementations against PyTorch's built-in versions.

In [None]:
def benchmark_activation(activation_fn, name, input_size=(1000, 1000), num_runs=100):
    """
    Benchmark an activation function.
    
    Returns:
        Tuple of (forward_time_ms, backward_time_ms)
    """
    x = torch.randn(*input_size, device=device, requires_grad=True)
    
    # Warmup
    for _ in range(10):
        y = activation_fn(x)
        y.sum().backward()
        x.grad = None
    
    if device.type == 'cuda':
        torch.cuda.synchronize()
    
    # Benchmark forward (perf_counter is a monotonic, high-resolution clock)
    start = time.perf_counter()
    for _ in range(num_runs):
        y = activation_fn(x)
    if device.type == 'cuda':
        torch.cuda.synchronize()
    forward_time = (time.perf_counter() - start) / num_runs * 1000
    
    # Benchmark backward: time forward + backward, then subtract the
    # forward time measured above (an approximation)
    start = time.perf_counter()
    for _ in range(num_runs):
        y = activation_fn(x)
        y.sum().backward()
        x.grad = None
    if device.type == 'cuda':
        torch.cuda.synchronize()
    backward_time = (time.perf_counter() - start) / num_runs * 1000 - forward_time
    
    return forward_time, backward_time


# Benchmark different activations
activations_to_benchmark = [
    ('ReLU (built-in)', F.relu),
    ('Swish (custom)', swish_custom),
    ('Swish (built-in F.silu)', F.silu),
    ('Mish (custom)', mish_custom),
    ('Mish (built-in)', F.mish),
    ('GELU (built-in)', F.gelu),
]

print(f"Benchmarking on {device}...")
print("="*60)

for name, fn in activations_to_benchmark:
    try:
        fwd, bwd = benchmark_activation(fn, name)
        print(f"{name:25s} | Forward: {fwd:.3f}ms | Backward: {bwd:.3f}ms")
    except Exception as e:
        print(f"{name:25s} | Error: {e}")

print("="*60)

### Observations

- Built-in functions are typically faster because they run as dedicated C++/CUDA kernels
- Our custom implementations pass `gradcheck`, but they are usually slower: each one launches several separate elementwise operations and saves extra tensors for backward
- For production, prefer built-in functions when available
- Custom functions are great for prototyping new ideas!
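For more reliable timings than a hand-rolled loop, `torch.utils.benchmark.Timer` handles warm-up and CUDA synchronization for you. A sketch comparing `F.silu` with a plain-PyTorch Swish on CPU; `swish_manual` and `fwd_bwd` are hypothetical helpers, and you would move `x` to `device` to time the GPU:

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

def swish_manual(t: torch.Tensor) -> torch.Tensor:
    """Swish from plain PyTorch ops (no custom Function)."""
    return t * torch.sigmoid(t)

def fwd_bwd(fn, t: torch.Tensor) -> None:
    """One forward + backward pass; resets .grad so runs don't accumulate."""
    fn(t).sum().backward()
    t.grad = None

x = torch.randn(512, 512, requires_grad=True)

results = {}
for label, fn in [("F.silu", F.silu), ("x * sigmoid(x)", swish_manual)]:
    timer = benchmark.Timer(
        stmt="fwd_bwd(fn, x)",
        globals={"fwd_bwd": fwd_bwd, "fn": fn, "x": x},
    )
    m = timer.timeit(50)  # returns a Measurement; .mean is seconds per run
    results[label] = m.mean
    print(f"{label:16s} fwd+bwd: {m.mean * 1e3:.3f} ms")
```

`Timer` defaults to a single thread (`num_threads=1`), so its numbers are not directly comparable to the loop above.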

---

## ✋ Try It Yourself: Exercise

Implement a custom autograd function for **Hard Swish**:

$$\text{HardSwish}(x) = x \cdot \frac{\text{ReLU6}(x + 3)}{6}$$

Where ReLU6(x) = min(max(0, x), 6)

This is a computationally cheaper approximation of Swish!

<details>
<summary>💡 Hint</summary>

The derivative is piecewise:
- For x < -3: 0
- For -3 <= x <= 3: (2x + 3) / 6
- For x > 3: 1

</details>

In [None]:
# YOUR CODE HERE: Implement HardSwishFunction
class HardSwishFunction(Function):
    """
    Custom autograd function for Hard Swish activation.
    
    HardSwish(x) = x * ReLU6(x + 3) / 6
    """
    
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # TODO: Implement forward pass
        pass
    
    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # TODO: Implement backward pass
        pass

# Test your implementation
# x_test = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
# test_passed = gradcheck(HardSwishFunction.apply, (x_test,), eps=1e-6)
# print(f"HardSwish gradient check passed: {test_passed}")

---

## Common Mistakes

### Mistake 1: Not using `ctx.save_for_backward`

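```python
from torch.autograd import Function

# ❌ Wrong - storing tensors as attributes
class BadFunction(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.x = x  # Skips in-place checks; storing an *output* this way also leaks memory
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2 * ctx.x

# ✅ Right - use save_for_backward
class GoodFunction(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # Autograd manages it and detects in-place edits
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x
```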
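The difference is observable. In this sketch (`SquareSaved` and `SquareAttr` are hypothetical names), the input is modified in-place after `forward`; the `save_for_backward` version raises a `RuntimeError` at backward time, while the attribute version silently returns a wrong gradient:

```python
import torch
from torch.autograd import Function

class SquareSaved(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors  # raises if x was modified in-place after forward
        return grad_output * 2 * x

class SquareAttr(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.x = x  # plain attribute: no version tracking
        return x ** 2

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2 * ctx.x

def run(fn):
    """Apply fn, modify its input in-place, then backpropagate."""
    leaf = torch.tensor([1.0, 2.0], requires_grad=True)
    h = leaf.clone()   # non-leaf, so in-place edits are allowed
    y = fn.apply(h)
    h.add_(10.0)       # modify the input AFTER forward used it
    y.sum().backward()
    return leaf.grad

try:
    run(SquareSaved)
    saved_error = None
except RuntimeError as err:
    saved_error = err

attr_grad = run(SquareAttr)  # no error, but uses the modified values
print("save_for_backward:", type(saved_error).__name__ if saved_error else "no error")
print("ctx attribute grad:", attr_grad)  # 2 * (x + 10); the true gradient is 2 * x = [2., 4.]
```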

### Mistake 2: Modifying tensors in-place

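```python
from torch.autograd import Function

# ❌ Wrong - in-place modification breaks autograd (unless declared via ctx.mark_dirty)
class BadDouble(Function):
    @staticmethod
    def forward(ctx, x):
        x.mul_(2)  # In-place! Also silently changes the caller's tensor
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2

# ✅ Right - create new tensor
class GoodDouble(Function):
    @staticmethod
    def forward(ctx, x):
        return x * 2  # New tensor

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * 2
```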

### Mistake 3: Forgetting to handle multiple inputs

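```python
from torch.autograd import Function

# ❌ Wrong - not returning gradient for each input
class BadMul(Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return x * y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        return grad_output * y  # Only one gradient, but forward had 2 inputs!

# ✅ Right - same forward, return gradient for each input
class GoodMul(BadMul):
    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        return grad_output * y, grad_output * x  # One per input
```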
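When some forward arguments are not tensors, `backward` still returns one value per argument, using `None` for the non-differentiable ones, and `ctx.needs_input_grad` lets you skip unneeded work. A sketch with a hypothetical `ScaledMul`:

```python
import torch
from torch.autograd import Function, gradcheck

class ScaledMul(Function):
    """Hypothetical example: out = scale * x * y, where `scale` is a plain Python float."""

    @staticmethod
    def forward(ctx, x, y, scale: float):
        ctx.save_for_backward(x, y)
        ctx.scale = scale  # non-tensors go directly on ctx
        return scale * x * y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        grad_x = grad_y = None
        # needs_input_grad has one bool per forward argument
        if ctx.needs_input_grad[0]:
            grad_x = grad_output * ctx.scale * y
        if ctx.needs_input_grad[1]:
            grad_y = grad_output * ctx.scale * x
        # One return value per forward input; None for the float `scale`
        return grad_x, grad_y, None

x = torch.randn(4, dtype=torch.float64, requires_grad=True)
y = torch.randn(4, dtype=torch.float64)  # no grad needed, so grad_y is skipped

ok = gradcheck(lambda a: ScaledMul.apply(a, y, 3.0), (x,))
(grad_x,) = torch.autograd.grad(ScaledMul.apply(x, y, 3.0).sum(), x)
print(ok, torch.allclose(grad_x, 3.0 * y))
```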

---

## Checkpoint

You've learned:
- ✅ How PyTorch's autograd builds computational graphs
- ✅ Creating custom `torch.autograd.Function` classes
- ✅ Implementing Swish and Mish activations from scratch
- ✅ Verifying gradients with `gradcheck`
- ✅ Using hooks for model introspection

---

## Challenge (Optional)

Implement **Gradient Checkpointing** as a custom autograd function!

Gradient checkpointing trades compute for memory by not saving intermediate activations. During backward, it recomputes the forward pass to get the activations.

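Activation checkpointing is one of the standard techniques, alongside model parallelism and mixed precision, for training very large models within limited GPU memory!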
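PyTorch already ships this technique as `torch.utils.checkpoint.checkpoint`, which makes a handy reference for testing your own version. A minimal sketch (the small `block` is an arbitrary example) confirming that gradients match a normal, non-checkpointed run:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
block = nn.Sequential(nn.Linear(32, 128), nn.GELU(), nn.Linear(128, 32))
x = torch.randn(16, 32, requires_grad=True)

# Regular forward: intermediate activations inside `block` are kept for backward
(g_plain,) = torch.autograd.grad(block(x).sum(), x)

# Checkpointed forward: activations inside `block` are discarded and
# recomputed during backward
out = checkpoint(block, x, use_reentrant=False)
(g_ckpt,) = torch.autograd.grad(out.sum(), x)

print(torch.allclose(g_plain, g_ckpt))
```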

---

## Further Reading

- [PyTorch Autograd Mechanics](https://pytorch.org/docs/stable/notes/autograd.html)
- [Extending PyTorch](https://pytorch.org/docs/stable/notes/extending.html)
- [Swish Paper](https://arxiv.org/abs/1710.05941)
- [Mish Paper](https://arxiv.org/abs/1908.08681)

In [None]:
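# Cleanup: collect unreferenced Python objects first, then release cached GPU memory
import gc

gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"GPU Memory: {torch.cuda.memory_allocated()/1e9:.2f} GB allocated")
print("Cleanup complete!")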