# PyTorch Hooks Deep Dive: Forward and Backward Hooks on ResNet18

This notebook demonstrates how to use PyTorch's hook mechanism to intercept
forward and backward passes through a neural network. We'll register hooks
on all convolutional layers in a pre-trained ResNet18 model.

**Learning Objectives:**
1. Understand forward vs backward hooks
2. Navigate nested model architectures
3. Track layer execution order during forward/backward passes
4. Use closures to create reusable hook functions

In [None]:
import torch
from torch import nn
from torchvision import models


def get_device_type(verbose: bool = True) -> torch.device:
    """
    Returns the best available device: CUDA → MPS → CPU
    Explicitly separates 'is_built' and 'is_available' checks for both backends.
    """
    if verbose:
        print("Detecting best available device...")

    # ------------------- CUDA -------------------
    cuda_built = torch.version.cuda is not None
    cuda_available = torch.cuda.is_available()

    if cuda_built:
        if cuda_available:
            device = torch.device("cuda")
            if verbose:
                print(f"CUDA → Built: Yes | Available: Yes → Using {device}")
                print(
                    f"   GPU: {torch.cuda.get_device_name(0)} | Count: {torch.cuda.device_count()}"
                )
            return device
        elif verbose:
            print("CUDA → Built: Yes | Available: No (driver/GPU issue)")
    elif verbose:
        print("CUDA → Built: No (PyTorch compiled without CUDA support)")

    # ------------------- MPS (Apple Silicon) -------------------
    mps_built = torch.backends.mps.is_built()
    mps_available = torch.backends.mps.is_available()

    if mps_built:
        if mps_available:
            device = torch.device("mps")
            if verbose:
                print(
                    f"MPS → Built: Yes | Available: Yes → Using {device} (Apple Silicon GPU)"
                )
            return device
        elif verbose:
            print("MPS → Built: Yes | Available: No (macOS <12.3 or Intel Mac)")
    elif verbose:
        print("MPS → Built: No (PyTorch compiled without MPS support)")

    # ------------------- CPU Fallback -------------------
    device = torch.device("cpu")
    if verbose:
        print("Falling back to CPU")
    return device


device = get_device_type(True)

Detecting best available device...
CUDA → Built: No (PyTorch compiled without CUDA support)
MPS → Built: Yes | Available: Yes → Using mps (Apple Silicon GPU)


In [2]:
# Load pre-trained ResNet18 model
model = models.resnet18(weights='IMAGENET1K_V1')
model = model.to(device)
model.eval()  # Evaluation mode: BatchNorm layers use their running statistics (ResNet18 has no dropout)

print("ResNet18 loaded successfully!")
print(f"Model on device: {next(model.parameters()).device}")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /Users/tensor/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100.0%


ResNet18 loaded successfully!
Model on device: mps:0


## Understanding ResNet18's Nested Structure

ResNet18 has a hierarchical architecture:
- `conv1`: Initial 7x7 convolution
- `layer1` → `layer4`: Four residual layer groups
  - Each group contains 2 BasicBlocks
  - Each BasicBlock has 2 conv layers (conv1, conv2)
  - The first block of `layer2`, `layer3`, and `layer4` also has a downsample branch (1x1 conv for dimension matching)

Let's identify all Conv2d layers:

In [3]:
# Discover all Conv2d layers in ResNet18
conv_layers = []

for name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        conv_layers.append(name)

print(f"Total Conv2d layers: {len(conv_layers)}\n")
print("Layer names:")
for i, name in enumerate(conv_layers, 1):
    print(f"  {i}. {name}")

Total Conv2d layers: 20

Layer names:
  1. conv1
  2. layer1.0.conv1
  3. layer1.0.conv2
  4. layer1.1.conv1
  5. layer1.1.conv2
  6. layer2.0.conv1
  7. layer2.0.conv2
  8. layer2.0.downsample.0
  9. layer2.1.conv1
  10. layer2.1.conv2
  11. layer3.0.conv1
  12. layer3.0.conv2
  13. layer3.0.downsample.0
  14. layer3.1.conv1
  15. layer3.1.conv2
  16. layer4.0.conv1
  17. layer4.0.conv2
  18. layer4.0.downsample.0
  19. layer4.1.conv1
  20. layer4.1.conv2
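
The dotted names above can also be used to fetch a single nested layer. The sketch below is a minimal illustration on a hypothetical `ToyNet` (a small stand-in, not ResNet18, so it runs without downloading weights). It assumes `nn.Module.get_submodule` is available in your PyTorch version and shows that the dotted name and plain attribute/index access reach the same object.

```python
import torch
from torch import nn


class Block(nn.Module):
    """Hypothetical two-conv block, loosely shaped like a BasicBlock."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(4, 4, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(4, 4, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv2(torch.relu(self.conv1(x)))


class ToyNet(nn.Module):
    """Hypothetical nested model: a stem conv followed by a group of two blocks."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 4, kernel_size=7, padding=3)
        self.layer1 = nn.Sequential(Block(), Block())

    def forward(self, x):
        return self.layer1(self.conv1(x))


toy = ToyNet()

# named_modules() walks the hierarchy and yields dotted names
conv_names = [n for n, m in toy.named_modules() if isinstance(m, nn.Conv2d)]
print(conv_names)

# Look up one nested layer by its dotted name ...
target = toy.get_submodule("layer1.1.conv2")
# ... or by attribute and index access; both give the same module object
same_target = toy.layer1[1].conv2
print(target is same_target)
```

The same lookup style should work with names from the list above, such as `layer2.0.downsample.0`.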


## Hook Registration Strategy

We'll create two types of hooks:

1. **Forward Hooks**: Triggered during forward pass
   - Signature: `hook(module, input, output)`
   - We'll print: "Forward: Conv layer X"

2. **Backward Hooks**: Triggered during backward pass
   - Signature: `hook(module, grad_input, grad_output)`
   - We'll print: "Backward: Conv layer X"

**Key Design Decisions:**
- Use closures to capture layer number in hook function
- Store hook handles so the hooks can be removed later (the module keeps the hook registered even if the handle is discarded)
- Number layers consistently (1-20) for clarity
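
The closure decision is easiest to see in plain Python. The sketch below uses hypothetical names (`make_hook`, `naive_hooks`, `factory_hooks`) and no PyTorch. It contrasts a factory function with functions defined directly in a loop, which all read the loop variable when they are called rather than when they are created.

```python
def make_hook(layer_num):
    """Factory: each call creates a new scope that pins its own layer_num."""
    def hook():
        return f"Conv layer {layer_num}"
    return hook


# Naive approach: every lambda looks up the same variable `i` at call time,
# so after the loop they all see its final value.
naive_hooks = []
for i in range(1, 4):
    naive_hooks.append(lambda: f"Conv layer {i}")

# Factory approach: each hook keeps the value it was created with.
factory_hooks = [make_hook(n) for n in range(1, 4)]

naive_results = [h() for h in naive_hooks]
factory_results = [h() for h in factory_hooks]

print(naive_results)    # all three report layer 3
print(factory_results)  # layers 1, 2, 3
```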

**Important Note on Execution Order:**
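- Forward hooks execute in the order the layers are called. For ResNet18 this matches the `named_modules()` numbering (1 → 20)
- Backward hooks execute in reverse order (20 → 1) because autograd propagates gradients from the output back toward the input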
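
The ordering claim can be checked on a small scale before running it on ResNet18. The sketch below is an illustration on a hypothetical three-conv `nn.Sequential` (not the notebook's model). It records the layer index in lists rather than printing, so the two orders can be compared directly.

```python
import torch
from torch import nn

# Tiny stand-in model (an assumption for illustration, not ResNet18)
tiny = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 4, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(4, 2, kernel_size=3, padding=1),
)

forward_order, backward_order = [], []


def record_forward(idx):
    def hook(module, inputs, output):
        forward_order.append(idx)
    return hook


def record_backward(idx):
    def hook(module, grad_input, grad_output):
        backward_order.append(idx)
    return hook


handles = []
convs = [m for m in tiny.modules() if isinstance(m, nn.Conv2d)]
for idx, conv in enumerate(convs, 1):
    handles.append(conv.register_forward_hook(record_forward(idx)))
    handles.append(conv.register_full_backward_hook(record_backward(idx)))

# requires_grad=True on the input so gradients flow through every layer
x = torch.randn(1, 3, 8, 8, requires_grad=True)
tiny(x).sum().backward()

print("forward: ", forward_order)
print("backward:", backward_order)

# Remove the hooks once the check is done
for h in handles:
    h.remove()
```

In a purely sequential chain like this one, the backward list is the exact reverse of the forward list. In models with parallel branches, the order among branches is decided by autograd.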

In [None]:
def create_forward_hook(layer_num: int):
    """
    Create a forward hook that prints the layer number.

    Args:
        layer_num: Integer identifying the convolution layer

    Returns:
        Hook function with captured layer_num in closure
    """
    def hook(module, input, output):
        print(f"Forward: Conv layer {layer_num}")
    return hook


def create_backward_hook(layer_num: int):
    """
    Create a backward hook that prints the layer number.

    Args:
        layer_num: Integer identifying the convolution layer

    Returns:
        Hook function with captured layer_num in closure
    """
    def hook(module, grad_input, grad_output):
        print(f"Backward: Conv layer {layer_num}")
    return hook


# Storage for hook handles (needed to remove the hooks later)
forward_hook_handles = []
backward_hook_handles = []

# Register hooks on all Conv2d layers
layer_num = 0
for _name, module in model.named_modules():
    if isinstance(module, nn.Conv2d):
        layer_num += 1

        # Register both forward and backward hooks
        fwd_handle = module.register_forward_hook(create_forward_hook(layer_num))
        bwd_handle = module.register_full_backward_hook(create_backward_hook(layer_num))

        forward_hook_handles.append(fwd_handle)
        backward_hook_handles.append(bwd_handle)

print(f"Registered hooks on {layer_num} Conv2d layers")
print(f"  - {len(forward_hook_handles)} forward hooks")
print(f"  - {len(backward_hook_handles)} backward hooks")

Registered hooks on 20 Conv2d layers
  - 20 forward hooks
  - 20 backward hooks


In [5]:
# Create a dummy input tensor (batch_size=1, channels=3, height=224, width=224)
# 224x224 RGB is the standard ImageNet input size for ResNet18;
# requires_grad=True lets gradients flow all the way back to the input
dummy_input = torch.randn(1, 3, 224, 224, device=device, requires_grad=True)

print(f"Dummy input shape: {dummy_input.shape}")
print(f"Requires gradient: {dummy_input.requires_grad}")

Dummy input shape: torch.Size([1, 3, 224, 224])
Requires gradient: True


## Run Forward Pass

This will trigger all forward hooks in execution order:

In [6]:
print("=" * 60)
print("FORWARD PASS")
print("=" * 60)

output = model(dummy_input)

print("=" * 60)
print(f"Output shape: {output.shape}")  # Should be [1, 1000] for ImageNet classes

FORWARD PASS
Forward: Conv layer 1
Forward: Conv layer 2
Forward: Conv layer 3
Forward: Conv layer 4
Forward: Conv layer 5
Forward: Conv layer 6
Forward: Conv layer 7
Forward: Conv layer 8
Forward: Conv layer 9
Forward: Conv layer 10
Forward: Conv layer 11
Forward: Conv layer 12
Forward: Conv layer 13
Forward: Conv layer 14
Forward: Conv layer 15
Forward: Conv layer 16
Forward: Conv layer 17
Forward: Conv layer 18
Forward: Conv layer 19
Forward: Conv layer 20
Output shape: torch.Size([1, 1000])


## Compute Dummy Loss and Run Backward Pass

We'll create a simple loss by summing all outputs, then call `backward()`.
This will trigger all backward hooks in reverse order:

In [7]:
# Create a dummy loss (sum of all output logits)
# In real training, you'd use a proper loss function like CrossEntropyLoss
dummy_loss = output.sum()

print("=" * 60)
print("BACKWARD PASS")
print("=" * 60)

dummy_loss.backward()

print("=" * 60)
print(f"Dummy loss value: {dummy_loss.item():.4f}")

BACKWARD PASS
Backward: Conv layer 20
Backward: Conv layer 19
Backward: Conv layer 18
Backward: Conv layer 17
Backward: Conv layer 16
Backward: Conv layer 15
Backward: Conv layer 14
Backward: Conv layer 13
Backward: Conv layer 12
Backward: Conv layer 11
Backward: Conv layer 10
Backward: Conv layer 9
Backward: Conv layer 8
Backward: Conv layer 7
Backward: Conv layer 6
Backward: Conv layer 5
Backward: Conv layer 4
Backward: Conv layer 3
Backward: Conv layer 2
Backward: Conv layer 1
Dummy loss value: 0.0103


## Clean Up: Remove Hooks

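Good practice: remove hooks when you are done with them. Otherwise they stay registered on the model and keep firing on every later forward and backward pass.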

In [8]:
# Remove all hooks
for handle in forward_hook_handles:
    handle.remove()

for handle in backward_hook_handles:
    handle.remove()

print(f"Removed {len(forward_hook_handles)} forward hooks")
print(f"Removed {len(backward_hook_handles)} backward hooks")

# Clear the lists
forward_hook_handles.clear()
backward_hook_handles.clear()

print("Hook cleanup complete!")

Removed 20 forward hooks
Removed 20 backward hooks
Hook cleanup complete!
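
To convince yourself that removal works, the sketch below counts hook calls on a single hypothetical `nn.Conv2d` layer before and after `handle.remove()`. It also uses the handle as a context manager, which, as far as I know, removes the hook on exit. Treat that part as an assumption to verify against your PyTorch version.

```python
import torch
from torch import nn

layer = nn.Conv2d(3, 4, kernel_size=3, padding=1)
calls = []


def counting_hook(module, inputs, output):
    # Record the output shape each time the hook fires
    calls.append(tuple(output.shape))


x = torch.randn(1, 3, 8, 8)

# Explicit removal
handle = layer.register_forward_hook(counting_hook)
layer(x)          # hook fires
handle.remove()
layer(x)          # hook no longer fires
calls_after_remove = len(calls)

# Scoped removal: the hook is only active inside the `with` block
with layer.register_forward_hook(counting_hook):
    layer(x)      # hook fires
layer(x)          # hook no longer fires
calls_after_with = len(calls)

print(calls_after_remove, calls_after_with)
```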


## Extension Exercises

Now that you understand the basics, try these extensions:

1. **Capture Feature Maps**: Modify forward hooks to store a detached copy, `output.detach().clone()`, in a dictionary so the stored tensors do not keep the autograd graph alive
   ```python
   feature_maps = {}
   def create_forward_hook(layer_num):
       def hook(module, input, output):
           feature_maps[f'layer_{layer_num}'] = output.detach().clone()
       return hook
   ```

2. **Measure Gradient Magnitudes**: Modify backward hooks to compute `grad_output[0].norm()`
   ```python
   def create_backward_hook(layer_num):
       def hook(module, grad_input, grad_output):
           grad_norm = grad_output[0].norm().item()
           print(f"Backward: Conv layer {layer_num} | Gradient norm: {grad_norm:.4f}")
       return hook
   ```

3. **Filter by Layer Type**: Register hooks only on 3x3 convolutions (exclude 1x1 and 7x7)
   ```python
   for name, module in model.named_modules():
       if isinstance(module, nn.Conv2d) and module.kernel_size == (3, 3):
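           # Register hook only on 3x3 convs and keep the handle for cleanup
           forward_hook_handles.append(
               module.register_forward_hook(create_forward_hook(name))
           )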
   ```

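4. **Timing Analysis**: Pair a forward pre-hook (start time) with a forward hook (end time) and use `time.perf_counter()` to measure per-layer execution time. On a GPU, kernels run asynchronously, so synchronize the device (for example with `torch.cuda.synchronize()`) before reading the clock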

5. **Selective Hooks**: Only hook layers in `layer3` and `layer4`
   ```python
   for name, module in model.named_modules():
       if isinstance(module, nn.Conv2d) and ('layer3' in name or 'layer4' in name):
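           # Register hook and keep the handle for cleanup
           forward_hook_handles.append(
               module.register_forward_hook(create_forward_hook(name))
           )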
   ```
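
For exercise 4, one possible starting point is sketched below. It runs on CPU with a hypothetical two-conv `nn.Sequential` and uses illustrative names (`make_timing_hooks`, `start_times`, `elapsed`). A forward pre-hook stores the start time and a forward hook stores the elapsed time. On a GPU you would also need to synchronize the device before each clock read, which this sketch omits.

```python
import time

import torch
from torch import nn

# Tiny stand-in model (an assumption for illustration, not ResNet18)
tiny = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1),
)

start_times = {}
elapsed = {}


def make_timing_hooks(name):
    """Return a (pre_hook, post_hook) pair that times one module's forward call."""
    def pre_hook(module, inputs):
        start_times[name] = time.perf_counter()

    def post_hook(module, inputs, output):
        elapsed[name] = time.perf_counter() - start_times[name]

    return pre_hook, post_hook


handles = []
for name, module in tiny.named_modules():
    if isinstance(module, nn.Conv2d):
        pre, post = make_timing_hooks(name)
        handles.append(module.register_forward_pre_hook(pre))
        handles.append(module.register_forward_hook(post))

with torch.no_grad():
    tiny(torch.randn(1, 3, 32, 32))

for name, seconds in elapsed.items():
    print(f"{name}: {seconds * 1e3:.3f} ms")

for h in handles:
    h.remove()
```

A single pass is noisy. Averaging over several passes after a warm-up run gives more stable numbers.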

**Advanced Challenge**: Create a visualization showing which layers have the largest gradient flow during backward pass. This can help identify vanishing/exploding gradient problems in deep networks.
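
As a starting point for the challenge, the sketch below collects one gradient norm per conv layer and prints a rough text bar chart. It uses a hypothetical three-conv `nn.Sequential` and illustrative names (`make_norm_hook`, `grad_norms`). A real solution would register the hooks on ResNet18 and probably use a plotting library instead of text bars.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Tiny stand-in model (an assumption for illustration, not ResNet18)
tiny = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.Tanh(),
    nn.Conv2d(4, 4, kernel_size=3, padding=1),
    nn.Tanh(),
    nn.Conv2d(4, 2, kernel_size=3, padding=1),
)

grad_norms = {}


def make_norm_hook(name):
    def hook(module, grad_input, grad_output):
        # grad_output[0] is the gradient of the loss w.r.t. this layer's output
        grad_norms[name] = grad_output[0].norm().item()
    return hook


handles = [
    module.register_full_backward_hook(make_norm_hook(name))
    for name, module in tiny.named_modules()
    if isinstance(module, nn.Conv2d)
]

x = torch.randn(1, 3, 8, 8, requires_grad=True)
tiny(x).sum().backward()

# Scale each bar relative to the largest norm seen
largest = max(grad_norms.values())
for name, norm in grad_norms.items():
    bar = "#" * max(1, round(20 * norm / largest))
    print(f"{name:>4} | {bar} {norm:.4f}")

for h in handles:
    h.remove()
```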