# Lab 2.1.3: Autograd Deep Dive - SOLUTIONS

This notebook contains complete solutions for the Autograd Deep Dive exercises.

---

In [None]:
import torch
import torch.nn.functional as F
from torch.autograd import Function, gradcheck

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

---

## Exercise Solution: Hard Swish

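Hard Swish is a computationally efficient approximation of Swish ($x \cdot \sigma(x)$) used in MobileNetV3. It replaces the sigmoid with the piecewise-linear gate $\mathrm{ReLU6}(x + 3)/6$, which avoids evaluating an exponential.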

In [None]:
class HardSwishFunction(Function):
    """
    Hand-written autograd Function implementing the Hard Swish activation.

    Definition:
        HardSwish(x) = x * ReLU6(x + 3) / 6,   with ReLU6(z) = clamp(z, 0, 6)

    Piecewise form:
        x < -3        ->  0
        -3 <= x <= 3  ->  x * (x + 3) / 6
        x > 3         ->  x

    Derivative used in backward. Both kinks (x = -3 and x = 3) take the
    middle-branch value, so the slope is -0.5 at x = -3 and 1.5 at x = 3:
        x < -3        ->  0
        -3 <= x <= 3  ->  (2x + 3) / 6
        x > 3         ->  1
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        """Return HardSwish(x), keeping x for the backward pass."""
        ctx.save_for_backward(x)
        relu6_shifted = torch.clamp(x + 3, min=0, max=6)
        return x * relu6_shifted / 6

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        """Chain rule: upstream gradient times the piecewise local slope."""
        (x,) = ctx.saved_tensors

        in_ramp = (x >= -3) & (x <= 3)
        above_ramp = x > 3

        # Outside the ramp the slope is 1 above it and 0 below it
        # (elements failing every comparison, e.g. NaN, also get 0).
        outer_slope = torch.where(above_ramp, torch.ones_like(x), torch.zeros_like(x))
        # Inside the ramp, boundaries included, use the quadratic's derivative.
        local_slope = torch.where(in_ramp, (2 * x + 3) / 6, outer_slope)

        return grad_output * local_slope


def hard_swish_custom(x: torch.Tensor) -> torch.Tensor:
    """
    Hard Swish computed through the custom autograd Function.

    Gradients flowing back through the result come from
    HardSwishFunction.backward rather than from a built-in kernel.
    """
    result = HardSwishFunction.apply(x)
    return result


# Test the implementation
print("=== Testing HardSwish ===")

# One probe per region, plus the two kinks at x = -3 and x = 3.
x_test = torch.tensor([-5.0, -3.0, 0.0, 3.0, 5.0], requires_grad=True)
y = hard_swish_custom(x_test)
y.sum().backward()  # d(sum y)/dx_i is the local slope at x_i

print(f"Input: {x_test.data}")
print(f"Output: {y.data}")
print(f"Gradient: {x_test.grad}")

# Expected outputs:
# x=-5: 0 (x < -3)
# x=-3: 0 (boundary, -3 * 0 / 6 = 0)
# x=0: 0 (0 * 3 / 6 = 0)
# x=3: 3 (3 * 6 / 6 = 3)
# x=5: 5 (x > 3, output = x)
print("\nExpected output: [0, 0, 0, 3, 5]")

# Expected gradients. HardSwish is not differentiable at x = -3 and x = 3.
# HardSwishFunction.backward uses an inclusive mask (-3 <= x <= 3), so both
# kinks take the middle-branch slope (2x + 3) / 6:
# x=-5: 0
# x=-3: (-6 + 3) / 6 = -0.5   (left-hand slope 0, right-hand slope -0.5)
# x=0: 3 / 6 = 0.5
# x=3: (6 + 3) / 6 = 1.5      (left-hand slope 1.5, right-hand slope 1)
# x=5: 1
print("Expected gradient: [0, -0.5, 0.5, 1.5, 1]")

In [None]:
# Verify gradients with gradcheck.
# torch.randn almost never yields |x| > 3, so random inputs would exercise only
# the middle branch of HardSwish. An evenly spaced float64 grid over [-6, 6]
# hits all three regions deterministically. Its 12 points stay well away from
# the kinks at x = -3 and x = 3, where finite differences are not meaningful.
x_test = torch.linspace(-6.0, 6.0, 12, dtype=torch.float64).reshape(3, 4).clone().requires_grad_(True)

test_passed = gradcheck(
    HardSwishFunction.apply,
    (x_test,),
    eps=1e-6,
    atol=1e-4,
    rtol=1e-3,
)

print(f"HardSwish gradient check passed: {test_passed}")

In [None]:
# Side-by-side check of the custom HardSwish against F.hardswish
import matplotlib.pyplot as plt

x = torch.linspace(-6, 6, 200)
x_np = x.numpy()

# Custom autograd path (the input is tracked, so a graph is built)
x_custom = x.clone().requires_grad_(True)
y_custom = hard_swish_custom(x_custom)
# Reference implementation shipped with PyTorch
y_builtin = F.hardswish(x)

# Pointwise absolute error between the two implementations
diff = (y_custom.detach() - y_builtin).abs()

fig, axes = plt.subplots(1, 2, figsize=(12, 4))
ax_curves, ax_diff = axes

# Left panel: both curves overlaid, built-in dashed so the custom line shows through
ax_curves.plot(x_np, y_custom.detach().numpy(), label='Custom', linewidth=2)
ax_curves.plot(x_np, y_builtin.numpy(), '--', label='Built-in', linewidth=2)
ax_curves.set_title('Hard Swish: Custom vs Built-in')
ax_curves.legend()
ax_curves.grid(True)

# Right panel: the error curve
ax_diff.plot(x_np, diff.numpy(), linewidth=2)
ax_diff.set_title('Absolute Difference')
ax_diff.set_ylabel('|Custom - Built-in|')
ax_diff.grid(True)

plt.tight_layout()
plt.show()

print(f"Max difference: {diff.max():.2e}")

---

## Challenge Solution: Gradient Checkpointing

Gradient checkpointing is a memory optimization technique that trades compute for memory. Instead of storing all intermediate activations, we only store "checkpoints" and recompute activations during the backward pass.

In [None]:
class CheckpointFunction(Function):
    """
    Gradient checkpointing as a custom autograd Function.

    ``forward`` runs ``run_function`` without recording a graph and keeps only
    its inputs. ``backward`` re-runs ``run_function`` with grad enabled on
    detached copies of those inputs, then backpropagates through that fresh
    graph. The activations inside ``run_function`` are therefore never held
    between the two passes.

    Cost: one extra forward of ``run_function`` per backward.
    Memory: the intermediates of ``run_function`` are traded for its inputs.
    The well-known O(sqrt(n)) activation memory comes from checkpointing
    about sqrt(n) segments of an n-layer network, not from a single wrapper.

    Non-tensor positional arguments are passed through unchanged, because
    ``save_for_backward`` only accepts tensors. They receive ``None`` as
    their gradient.

    Args (to ``apply``):
        run_function: Callable to checkpoint.
        preserve_rng_state: If True, the forward-time CPU RNG state is replayed
            during recomputation, plus the current CUDA device's state when
            CUDA is available. Random ops therefore draw the same values.
            The caller's RNG state is restored afterwards.
        *args: Positional arguments for ``run_function``.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        """Run ``run_function(*args)`` without a graph and stash what backward needs."""
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        ctx.had_cuda_state = False

        if preserve_rng_state:
            ctx.cpu_rng_state = torch.get_rng_state()
            if torch.cuda.is_available():
                ctx.had_cuda_state = True
                ctx.cuda_rng_device = torch.cuda.current_device()
                ctx.cuda_rng_state = torch.cuda.get_rng_state(ctx.cuda_rng_device)

        # save_for_backward only accepts tensors, so split the arguments.
        # Tensors go through autograd's saving machinery. Everything else is
        # kept on ctx, with a None placeholder at each tensor position.
        ctx.tensor_indices = []
        ctx.non_tensor_args = []
        tensor_args = []
        for idx, arg in enumerate(args):
            if torch.is_tensor(arg):
                ctx.tensor_indices.append(idx)
                tensor_args.append(arg)
                ctx.non_tensor_args.append(None)
            else:
                ctx.non_tensor_args.append(arg)
        ctx.save_for_backward(*tensor_args)

        # No intermediate activations are recorded during this pass.
        with torch.no_grad():
            outputs = run_function(*args)

        return outputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Recompute ``run_function`` with grad enabled and backprop through it."""
        # Rebuild the argument list. Each saved tensor becomes a detached leaf,
        # so the recomputed graph ends at these tensors.
        inputs = list(ctx.non_tensor_args)
        for idx, saved in zip(ctx.tensor_indices, ctx.saved_tensors):
            inputs[idx] = saved.detach().requires_grad_(saved.requires_grad)

        # fork_rng snapshots the caller's RNG state and restores it on exit.
        # Replaying the forward-time state therefore does not leak into code
        # that runs after backward.
        rng_devices = [ctx.cuda_rng_device] if ctx.had_cuda_state else []
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.cpu_rng_state)
                if ctx.had_cuda_state:
                    torch.cuda.set_rng_state(ctx.cuda_rng_state, ctx.cuda_rng_device)
            with torch.enable_grad():
                outputs = ctx.run_function(*inputs)

        # Handle single output
        if not isinstance(outputs, tuple):
            outputs = (outputs,)

        # Only recomputed outputs that require grad (and received a gradient)
        # can be backpropagated. torch.autograd.backward raises on the others.
        pairs = [
            (out, grad)
            for out, grad in zip(outputs, grad_outputs)
            if torch.is_tensor(out) and out.requires_grad and grad is not None
        ]
        if pairs:
            outs, grads = zip(*pairs)
            torch.autograd.backward(outs, grads)

        # One gradient per forward input. run_function, preserve_rng_state and
        # every non-tensor argument get None.
        input_grads = tuple(arg.grad if torch.is_tensor(arg) else None for arg in inputs)
        return (None, None) + input_grads


def checkpoint(function, *args, preserve_rng_state=True):
    """
    Run ``function(*args)`` under gradient checkpointing.

    The call returns the same value as ``function(*args)``, but the
    intermediate activations inside ``function`` are not kept for the
    backward pass. CheckpointFunction recomputes them when gradients are
    needed.

    Args:
        function: Callable to checkpoint.
        *args: Positional arguments forwarded to ``function``.
        preserve_rng_state: Keyword-only. When True, random ops inside
            ``function`` see the forward-time RNG state again during
            recomputation.

    Returns:
        Whatever ``function(*args)`` returns.

    Example:
        >>> def my_func(x):
        ...     return x.relu().sin().cos()
        >>> y = checkpoint(my_func, x)
    """
    forwarded = (function, preserve_rng_state) + args
    return CheckpointFunction.apply(*forwarded)


# Test checkpointing
print("=== Testing Gradient Checkpointing ===")

def expensive_computation(x):
    """Five rounds of relu -> sin -> cos; every op creates another intermediate."""
    for _ in range(5):
        x = x.relu().sin().cos()
    return x

# Baseline: ordinary autograd keeps every intermediate alive until backward.
x1 = torch.randn(100, 100, requires_grad=True)
y1 = expensive_computation(x1)
y1.sum().backward()
grad1 = x1.grad.clone()

# Same input values, this time routed through the checkpoint wrapper.
x2 = x1.detach().clone().requires_grad_(True)
y2 = checkpoint(expensive_computation, x2)
y2.sum().backward()
grad2 = x2.grad.clone()

# Both paths must agree on outputs and gradients.
outputs_match = torch.allclose(y1.detach(), y2.detach())
gradients_match = torch.allclose(grad1, grad2)
max_grad_gap = (grad1 - grad2).abs().max()
print(f"Outputs match: {outputs_match}")
print(f"Gradients match: {gradients_match}")
print(f"Max gradient difference: {max_grad_gap:.2e}")

### Memory Comparison

In practice, checkpointing pays off most in deep networks, where stored activations dominate memory. PyTorch ships a built-in implementation, `torch.utils.checkpoint.checkpoint`:

In [None]:
from torch.utils.checkpoint import checkpoint as torch_checkpoint
import torch.nn as nn

# Create a deep model
class DeepModel(nn.Module):
    """
    Stack of ``num_layers`` blocks, each Linear(hidden_dim, hidden_dim) followed by ReLU.

    Args:
        num_layers: Number of Linear+ReLU blocks.
        hidden_dim: Feature width of every layer, used as both input and
            output size. The default of 256 matches the previous hard-coded width.
    """

    def __init__(self, num_layers=20, hidden_dim=256):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers)
        ])

    def forward(self, x, use_checkpoint=False):
        """
        Apply every block in order.

        Args:
            x: Input whose last dimension is ``hidden_dim``.
            use_checkpoint: If True, wrap each block in
                ``torch.utils.checkpoint.checkpoint`` (non-reentrant). Its
                activations are then recomputed during backward instead of stored.
        """
        for layer in self.layers:
            if use_checkpoint:
                # Bind ``layer`` as a default argument. Checkpoint re-invokes this
                # callable during backward, after the loop has finished. A plain
                # closure looks ``layer`` up late and would recompute every block
                # with the *last* layer, silently corrupting the gradients.
                x = torch_checkpoint(
                    lambda y, layer=layer: F.relu(layer(y)), x, use_reentrant=False
                )
            else:
                x = F.relu(layer(x))
        return x

model = DeepModel(num_layers=20).to(device)


def _peak_memory_mb(use_checkpoint):
    """
    Return peak CUDA memory, in MB, for one forward+backward pass of ``model``.

    Parameter gradients from any previous run are released first, so every
    measurement starts from the same baseline. The input and output tensors
    are local and are freed when the function returns. The reported peak
    includes the model parameters, which every run shares.
    """
    model.zero_grad(set_to_none=True)
    torch.cuda.reset_peak_memory_stats()
    inp = torch.randn(64, 256, device=device, requires_grad=True)
    out = model(inp, use_checkpoint=use_checkpoint)
    out.sum().backward()
    return torch.cuda.max_memory_allocated() / 1e6


# Compare memory usage
if device.type == 'cuda':
    # Without checkpointing
    mem_no_ckpt = _peak_memory_mb(use_checkpoint=False)
    # With checkpointing
    mem_ckpt = _peak_memory_mb(use_checkpoint=True)

    print(f"Memory without checkpointing: {mem_no_ckpt:.2f} MB")
    print(f"Memory with checkpointing: {mem_ckpt:.2f} MB")
    print(f"Memory saved: {(1 - mem_ckpt/mem_no_ckpt)*100:.1f}%")
else:
    print("GPU not available, skipping memory comparison")

---

## Alternative Solution: `torch.autograd.Function` with `setup_context`

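PyTorch 2.0+ lets a `torch.autograd.Function` define a separate `setup_context` static method. `forward` then takes only the inputs (no `ctx`), and everything needed for the backward pass is saved in `setup_context`: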

In [None]:
class ModernSwishFunction(Function):
    """
    Swish / SiLU, x * sigmoid(x), written in the forward + setup_context style.

    ``forward`` takes no ``ctx`` and only computes the result. Saving tensors
    for the backward pass happens separately in ``setup_context``.

    Derivative: d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x))), where s = sigmoid.
    """

    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        """Compute x * sigmoid(x) with no autograd bookkeeping."""
        gate = torch.sigmoid(x)
        return x * gate

    @staticmethod
    def setup_context(ctx, inputs, output):
        """Stash x and sigmoid(x) after forward has run."""
        (x,) = inputs
        # forward's sigmoid is not handed over here, so it is evaluated again.
        ctx.save_for_backward(x, torch.sigmoid(x))

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """Multiply the upstream gradient by the local Swish slope."""
        x, gate = ctx.saved_tensors
        complement = 1 - gate
        local_slope = gate * (1 + x * complement)
        return grad_output * local_slope

# The forward/setup_context split requires PyTorch 2.0 or newer. On older
# releases, use the classic forward(ctx, x) style shown earlier.
# Exercise the function: check its gradient numerically and compare with F.silu.
x_swish = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
swish_grad_ok = gradcheck(ModernSwishFunction.apply, (x_swish,), eps=1e-6, atol=1e-4)
swish_matches_silu = torch.allclose(ModernSwishFunction.apply(x_swish), F.silu(x_swish))
print(f"ModernSwish gradient check passed: {swish_grad_ok}")
print(f"ModernSwish matches F.silu: {swish_matches_silu}")
print("Modern autograd function style demonstrated!")

In [None]:
# Cleanup
import gc

# Collect first: tensors kept alive only by reference cycles are freed here,
# so empty_cache can then release the cached CUDA blocks they occupied.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("Cleanup complete!")