<a href="https://colab.research.google.com/github/Ayesha-Imr/vision-mech-interp/blob/main/ayesha-imr__segment_1_intro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Segment 1: Building Intuition for CNNs

## Why This Matters

Imagine you're trying to understand how a human brain recognizes faces. You wouldn't just look at the final "yes, that's a face" decision. You'd want to understand:
- How the eyes detect edges and lines
- How the brain combines those edges into shapes
- How those shapes become "nose," "eyes," "mouth"
- How all of that becomes "face"

**CNNs work the same way.** Before we can "interpret" what a CNN is doing, we need to understand its basic mechanics.

## The Mental Model We're Building

By the end of this notebook, you'll understand:

1. **Images aren't images to a CNN**: they're arrays of numbers (tensors)
2. **Convolutions are pattern matchers**: each filter looks for one specific pattern
3. **Depth creates abstraction**: stacking simple operations creates complex understanding
4. **Space is preserved**: even deep in the network, there's spatial structure

Without this foundation, interpretability techniques will feel like magic. With it, they'll make perfect sense.

---
## Setup: Loading Our Tools

### What We Need

Think of this section as gathering your equipment before a lab experiment:
- **PyTorch**: the framework that lets us build and run neural networks
- **VGG16**: a pretrained CNN (someone already trained it on millions of images)
- **An image**: we'll use a cat photo from Wikimedia Commons
- **Visualization tools**: matplotlib to see what's happening

### Why VGG16?

VGG16 is like a 1990s Toyota Camry:
- Not the fanciest (ResNet, Vision Transformers are newer)
- But simple, reliable, and easy to understand
- Perfect for learning the fundamentals

It has 16 weight layers: 13 convolutional layers organized into 5 "blocks", followed by 3 fully connected layers. We'll explore the blocks.
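
To make the "16 layers, 5 blocks" count concrete, here is a minimal sketch in plain Python (no PyTorch needed). The list below is the standard VGG16 layer layout; the names `VGG16_CFG` and `summarize` are our own, chosen for illustration.

```python
# Layer layout of VGG16's convolutional part.
# Integers are the output channels of a 3x3 convolution; "M" is a 2x2 max-pool.
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512, "M"]

def summarize(cfg, input_size=224):
    """Group the config into blocks and track the spatial size after each pool."""
    blocks, current, size = [], [], input_size
    for item in cfg:
        if item == "M":
            size //= 2  # each max-pool halves height and width
            blocks.append({"channels": current, "output_size": size})
            current = []
        else:
            current.append(item)
    return blocks

blocks = summarize(VGG16_CFG)
num_conv = sum(len(b["channels"]) for b in blocks)
num_fc = 3  # the three fully connected layers in model.classifier

for i, b in enumerate(blocks, start=1):
    print(f"Block {i}: {len(b['channels'])} conv layers, "
          f"{b['channels'][-1]} channels, "
          f"{b['output_size']}x{b['output_size']} after pooling")
print(f"Weight layers: {num_conv} conv + {num_fc} fully connected = {num_conv + num_fc}")
```

The "16" counts only layers with learned weights; ReLU and pooling layers are not included.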

In [None]:
# ============================================
# IMPORTS: The Libraries We Need
# ============================================

import torch                              # PyTorch: the neural network framework
import torch.nn as nn                     # Neural network building blocks (layers, activations, etc.)
import torchvision.models as models       # Pre-built models like VGG16, ResNet, etc.
import torchvision.transforms as transforms  # Image preprocessing tools (resize, normalize, etc.)
from PIL import Image                     # Python Imaging Library: load and manipulate images
import matplotlib.pyplot as plt           # Plotting library: visualize images and graphs
import numpy as np                        # Numerical computing: array operations
import requests                           # Download images from URLs
from io import BytesIO                    # Handle image data in memory (not saving to disk)

print("✅ All libraries imported successfully!")

In [None]:
# ============================================
# STEP 1: Load the Pretrained Model
# ============================================

# What does "pretrained" mean?
# - Someone (the PyTorch team) already trained this network on ImageNet
# - ImageNet = 1.2 million images, 1000 categories (dogs, cats, cars, etc.)
# - The network learned to recognize patterns through weeks of training on GPUs
# - We're downloading those learned "weights" (the pattern detectors)

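# Note: torchvision >= 0.13 uses the `weights` argument; older versions used pretrained=True
model = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1)  # Download and load VGG16 with pretrained weights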

# What does .eval() mean?
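# - Some layers behave differently during training vs testing
# - During training: dropout randomly turns off neurons, and batch norm (in models that have it) uses per-batch statistics
# - During evaluation: dropout is disabled and batch norm uses its stored running statistics, so results are consistent and reproducible
# - This VGG16 has dropout in its classifier and no batch norm layers
# - .eval() puts the model in "evaluation mode"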

model.eval()

print("✅ VGG16 loaded and set to evaluation mode")
print(f"\nModel structure preview:")
num_conv = sum(isinstance(layer, nn.Conv2d) for layer in model.features)  # Count only Conv2d modules
print(f"  - model.features: {len(model.features)} modules, {num_conv} of them convolutional layers")
print(f"  - Input: 224×224×3 RGB image")
print(f"  - Output: 1000 class scores (logits; apply softmax to get probabilities)")

In [None]:
# ============================================
# STEP 2: Define Image Preprocessing
# ============================================

# Why do we need preprocessing?
# - Raw images come in different sizes (640×480, 1920×1080, etc.)
# - Pixel values are 0-255 (integers)
# - VGG16 expects: exactly 224×224 pixels, normalized float values

# Think of preprocessing as "translating" the image into the language the CNN speaks

preprocess = transforms.Compose([  # Compose = chain multiple transformations together
    
    # Step 1: Resize the shorter side to 224 pixels (keeps aspect ratio)
    # Example: 1200×800 image → 336×224 image
    transforms.Resize(224),
    
    # Step 2: Crop the center 224×224 square
    # Example: 336×224 → take center 224×224 (crops left/right edges)
    transforms.CenterCrop(224),
    
    # Step 3: Convert PIL Image to PyTorch tensor
    # - PIL stores as [height, width, channels] with values 0-255
    # - Tensor stores as [channels, height, width] with values 0.0-1.0
    # Example: [224, 224, 3] with 0-255 → [3, 224, 224] with 0.0-1.0
    transforms.ToTensor(),
    
    # Step 4: Normalize using ImageNet statistics
    # Why? VGG16 was trained on normalized images, so we must normalize too
    # mean=[0.485, 0.456, 0.406] = average RGB values across all ImageNet images
    # std=[0.229, 0.224, 0.225] = standard deviation of RGB values
    # Formula: pixel_normalized = (pixel - mean) / std
    # This centers the data around 0 and scales to similar ranges
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

print("✅ Preprocessing pipeline defined")
print("\nWhat happens to an image:")
print("  1. Any size (e.g., 1200×800) → 224×224 (resized and cropped)")
print("  2. [H, W, 3] uint8 0-255 → [3, H, W] float32 0.0-1.0")
print("  3. Each channel normalized: (pixel - mean) / std")
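
The arithmetic behind the preprocessing steps can be checked by hand. The sketch below is plain Python and only mimics the size and normalization calculations; the helper names (`resize_shorter_side`, `center_crop_box`, `normalize`) are our own, and the library's exact rounding of crop offsets may differ slightly for odd sizes.

```python
def resize_shorter_side(width, height, target=224):
    """Scale so the shorter side equals `target`, keeping the aspect ratio."""
    if width <= height:
        return target, int(target * height / width)
    return int(target * width / height), target

def center_crop_box(width, height, crop=224):
    """Return (left, top, right, bottom) of a centered crop-by-crop window."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop

def normalize(value, mean, std):
    """Apply (value - mean) / std to one channel value in [0, 1]."""
    return (value - mean) / std

w, h = resize_shorter_side(1200, 800)   # a landscape image
print(w, h)                             # 336 224
print(center_crop_box(w, h))            # (56, 0, 280, 224)

red_mean, red_std = 0.485, 0.229        # ImageNet statistics for the red channel
print(round(normalize(0.0, red_mean, red_std), 3))  # -2.118 (a black pixel)
print(round(normalize(1.0, red_mean, red_std), 3))  # 2.249 (a fully red pixel)
```

Note that after normalization a value of 0 no longer means "black": it means "equal to the dataset mean".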

In [None]:
# ============================================
# STEP 3: Load an Image from the Internet
# ============================================

def load_image(url):
    """
    Download an image from a URL and return it as a PIL Image.
    
    Why we need the User-Agent header:
    - Some websites block automated requests (bots)
    - By pretending to be a web browser, we avoid getting blocked
    """
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
    }
    
    # Send HTTP GET request to download the image
    response = requests.get(url, headers=headers)
    
    # Check if download succeeded (status code 200 = success)
    response.raise_for_status()  # Raises error if status is 4xx or 5xx
    
    # Convert the downloaded bytes into a PIL Image object
    # BytesIO creates a file-like object in memory (no disk I/O)
    # .convert('RGB') ensures we have 3 color channels (some images are grayscale)
    img = Image.open(BytesIO(response.content)).convert('RGB')
    
    return img

# Download a cat image from Wikimedia Commons
img_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"
img = load_image(img_url)

print(f"✅ Image downloaded successfully!")
print(f"   Original size: {img.size}  # (width, height) in pixels")

In [None]:
# ============================================
# STEP 4: Preprocess the Image
# ============================================

# Apply all the transformations we defined earlier
input_tensor = preprocess(img)  # Now: [3, 224, 224] normalized tensor

# Add a "batch dimension" at the front
# Why? PyTorch processes images in batches for efficiency
# - During training, you might process 32 images at once
# - Here we only have 1 image, but we still need the batch dimension
# - .unsqueeze(0) adds a dimension at position 0
# - [3, 224, 224] → [1, 3, 224, 224]
#    ^new dimension (batch size = 1)

input_tensor = input_tensor.unsqueeze(0)

print(f"✅ Image preprocessed and ready for the network")
print(f"\nOriginal image size: {img.size}  # (width, height) = ({img.size[0]}×{img.size[1]})")
print(f"Preprocessed tensor shape: {input_tensor.shape}")
print(f"\nBreaking down the shape [1, 3, 224, 224]:")
print(f"  [0] Batch size = 1       # We're processing 1 image")
print(f"  [1] Channels = 3         # Red, Green, Blue")
print(f"  [2] Height = 224         # Rows of pixels")
print(f"  [3] Width = 224          # Columns of pixels")

In [None]:
# Let's visualize the original image to see what we're working with
plt.figure(figsize=(6, 6))
plt.imshow(img)
plt.title(f"Original Image: {img.size[0]}×{img.size[1]} pixels", fontsize=14)
plt.axis('off')
plt.show()

print("\n📌 This is what the image looks like to us humans.")
print("   To the CNN, it's just a [1, 3, 224, 224] array of numbers!")
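
To see "just an array of numbers" without any library, here is a small sketch using a made-up 2×2 RGB image stored as nested Python lists. It mirrors what the preprocessing does to the layout: reorder from [height, width, channels] to [channels, height, width], scale to 0.0-1.0, then add a batch dimension. The image values and the helper name `to_chw_float` are hypothetical.

```python
# A hypothetical 2x2 RGB image in [height][width][channel] layout, values 0-255.
image_hwc = [
    [[255, 0, 0], [0, 255, 0]],        # top row: red, green
    [[0, 0, 255], [255, 255, 255]],    # bottom row: blue, white
]

def to_chw_float(img):
    """Reorder [H][W][C] to [C][H][W] and scale 0-255 integers to 0.0-1.0 floats."""
    height, width, channels = len(img), len(img[0]), len(img[0][0])
    return [[[img[r][c][ch] / 255.0 for c in range(width)]
             for r in range(height)]
            for ch in range(channels)]

chw = to_chw_float(image_hwc)
batch = [chw]  # wrapping in a list adds the batch dimension

shape = (len(batch), len(batch[0]), len(batch[0][0]), len(batch[0][0][0]))
print(shape)    # (1, 3, 2, 2)
print(chw[0])   # red channel: [[1.0, 0.0], [0.0, 1.0]]
```

The real tensor has the same structure, only with 224×224 positions per channel instead of 2×2.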

---
# Code Block 1: First Contact (Image → Tensors → Activations)

## The Big Question

**What happens when an image enters a CNN?**

## The Answer (Simplified)

1. **Input:** `[1, 3, 224, 224]`, one image with 3 color channels (RGB)
2. **After first convolutional layer:** `[1, 64, 224, 224]`, 64 "feature maps"
3. **Each feature map:** Shows where a specific pattern was detected

## The Analogy

Imagine you have 64 different colored markers (red, blue, green, etc.).

You look at the image and:
- With marker #1, you highlight all vertical edges
- With marker #2, you highlight all horizontal edges  
- With marker #3, you highlight all diagonal edges
- With marker #4, you highlight all circular shapes
- ... and so on for all 64 markers

Each marker gives you a **different highlighted version** of the same image. That's what the 64 feature maps are!

## The Truth

**The CNN doesn't "see" a cat.** It sees 64 different filtered versions, each responding to different local patterns (edges, textures, colors).

## What We'll Do

1. Extract the first convolutional layer from VGG16
2. Pass our cat image through just that one layer
3. Visualize the 64 resulting feature maps
4. See how the "cat" is decomposed into pattern responses
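
Why does the first layer keep the 224×224 size? The standard output-size formula for convolution and pooling answers that. Below is a minimal sketch; the function name `conv_output_size` is our own.

```python
def conv_output_size(size, kernel, stride=1, padding=0):
    """Output size along one spatial dimension for a conv or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1

size = 224
after_conv = conv_output_size(size, kernel=3, stride=1, padding=1)  # VGG16 conv settings
after_pool = conv_output_size(after_conv, kernel=2, stride=2)       # 2x2 max-pool
no_padding = conv_output_size(size, kernel=3)                       # what padding=0 would give

print(after_conv, after_pool, no_padding)  # 224 112 222

# Number of values in the first layer's output for a single image.
print(64 * after_conv * after_conv)        # 3211264
```

With a 3×3 kernel, `padding=1` exactly compensates for the border pixels the kernel cannot center on, so only pooling layers shrink the feature maps.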

In [None]:
# ============================================
# Extract the First Convolutional Layer
# ============================================

# VGG16 is organized as:
# model.features = all convolutional and pooling layers (the "feature extractor")
# model.classifier = fully connected layers at the end (the "decision maker")

# model.features[0] = the very first layer
# It's a Conv2d layer: Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
#   - Input: 3 channels (RGB)
#   - Output: 64 channels (64 different pattern detectors)
#   - Kernel size: 3×3 (each filter is a 3×3 grid)
#   - Padding: 1 (adds 1-pixel border so output size = input size)

first_conv = model.features[0]

print("First layer details:")
print(first_conv)
print(f"\nWhat this layer does:")
print(f"  - Takes: [batch, 3, 224, 224] (3-channel image)")
print(f"  - Returns: [batch, 64, 224, 224] (64 feature maps)")
print(f"  - How: Applies 64 different 3×3 filters across the image")

In [None]:
# ============================================
# Pass the Image Through the First Layer
# ============================================

# torch.no_grad() context:
# - During training, PyTorch tracks all operations to compute gradients (for backprop)
# - We're not training, just visualizing, so we don't need gradients
# - no_grad() turns off gradient tracking → saves memory, runs faster

with torch.no_grad():
    # Pass input through the first conv layer
    # input_tensor: [1, 3, 224, 224]
    # first_activations: [1, 64, 224, 224]
    first_activations = first_conv(input_tensor)

print(f"✅ Forward pass complete!")
print(f"\nInput shape:  {input_tensor.shape}  ← 3 color channels (R, G, B)")
print(f"Output shape: {first_activations.shape}  ← 64 feature maps (pattern responses)")
print(f"\nWhat just happened:")
print(f"  - 64 different 3×3 filters scanned the entire 224×224 image")
print(f"  - Each filter detected a specific pattern (edges, textures, etc.)")
print(f"  - Result: 64 activation maps showing WHERE each pattern was found")

In [None]:
# ============================================
# Visualize: Original Image vs Feature Maps
# ============================================

# Create a 2×5 grid of subplots
# - 1st subplot: original image
# - Next 9 subplots: first 9 feature maps (out of 64 total)
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

# ----------------------------
# Top-left: Original Image
# ----------------------------
axes[0, 0].imshow(img)  # Display the RGB image
axes[0, 0].set_title("Original Image\n(what we see)", fontsize=10, fontweight='bold')
axes[0, 0].axis('off')  # Hide x,y axis numbers

# ----------------------------
# Remaining 9 subplots: Feature Maps
# ----------------------------
for i in range(9):  # Show first 9 feature maps (out of 64)
    # Calculate which subplot this is
    # i=0: position 1 → row=0, col=1
    # i=1: position 2 → row=0, col=2
    # i=4: position 5 → row=1, col=0  (wraps to next row)
    row = (i + 1) // 5  # Integer division: which row?
    col = (i + 1) % 5   # Modulo: which column?
    
    # Extract the i-th feature map
    # first_activations shape: [1, 64, 224, 224]
    #   [0] = batch index (we only have 1 image)
    #   [i] = feature map index (0-63)
    # Result: [224, 224] array of activation values
    feature_map = first_activations[0, i].numpy()  # Convert to numpy for matplotlib
    
    # Display the feature map
    # cmap='viridis': colormap (yellow=high activation, purple=low activation)
    axes[row, col].imshow(feature_map, cmap='viridis')
    axes[row, col].set_title(f"Feature Map {i}\n(filter {i} response)", fontsize=9)
    axes[row, col].axis('off')

plt.suptitle("After First Conv Layer: One Image Becomes 64 Filtered Versions", fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n" + "="*60)
print("💡 WHAT YOU SHOULD SEE:")
print("="*60)
print("Each feature map looks different because each filter detects different patterns:")
print("  • Some highlight edges (bright lines where edges exist)")
print("  • Some highlight textures (fur, whiskers)")
print("  • Some respond to specific colors")
print("  • Some look almost random (that filter's pattern is weak or absent in this image)")
print("\n⚠️ The CNN doesn't 'see' a cat; it sees 64 different pattern responses!")

## 🎯 Key Takeaway from Block 1

**CNNs transform images immediately.**

- **Input:** A recognizable image (a cat)
- **After 1 layer:** 64 abstract "maps" showing where patterns exist
- **Each map:** Answers "where does pattern X appear?"

This is **not** semantic understanding. It's **pattern matching**. The "cat" concept emerges much later, from combining these simple responses.

---
# Code Block 2: Convolution as Local Pattern Detection

## The Big Question

**What ARE these "filters" and how do they work?**

## The Intuition

Imagine a stencil (like for painting letters on a wall):
1. You place the stencil over a spot on the image
2. You check: "Does the pattern under the stencil match my stencil pattern?"
3. If yes → high activation (bright spot on the feature map)
4. If no → low activation (dark spot)
5. Move the stencil 1 pixel over and repeat

That's exactly what a **convolutional filter** does!

## The Math (Simplified)

A 3×3 filter is just a 3×3 grid of numbers ("weights"):

```
Filter for detecting vertical edges:
[-1,  0,  1]
[-1,  0,  1]  
[-1,  0,  1]
```

This filter:
- Negative on the left (-1)
- Zero in the middle (0)
- Positive on the right (+1)

When you slide it over an image:
- **Vertical edge (dark→bright):** High positive response ✅
- **Flat region (all same color):** Zero response ❌
- **Horizontal edge:** Zero response ❌
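
These three responses are easy to verify by hand. The sketch below applies the vertical-edge filter to a few hypothetical 3×3 patches (0 = dark, 1 = bright), using an element-wise multiply-and-sum, which is the operation a convolutional layer performs at each position. The patch values are made up for illustration.

```python
VERTICAL_EDGE = [[-1, 0, 1],
                 [-1, 0, 1],
                 [-1, 0, 1]]

def filter_response(patch, kernel):
    """Sum of element-wise products between a 3x3 patch and a 3x3 kernel."""
    return sum(patch[r][c] * kernel[r][c] for r in range(3) for c in range(3))

dark_to_bright = [[0, 0, 1], [0, 0, 1], [0, 0, 1]]    # vertical edge
flat = [[1, 1, 1], [1, 1, 1], [1, 1, 1]]               # uniform region
horizontal_edge = [[0, 0, 0], [0, 0, 0], [1, 1, 1]]    # edge runs left to right
bright_to_dark = [[1, 0, 0], [1, 0, 0], [1, 0, 0]]     # opposite polarity

print(filter_response(dark_to_bright, VERTICAL_EDGE))   # 3
print(filter_response(flat, VERTICAL_EDGE))             # 0
print(filter_response(horizontal_edge, VERTICAL_EDGE))  # 0
print(filter_response(bright_to_dark, VERTICAL_EDGE))   # -3
```

The last case shows that the sign matters: the same filter gives a strongly negative response to an edge of the opposite polarity.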

## Weight Sharing

**Key insight:** The SAME filter is used everywhere in the image.

- Top-left corner: filter checks for vertical edges
- Bottom-right corner: SAME filter checks for vertical edges
- This is why it's called "weight sharing": one set of weights, used everywhere
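
Here is a small sketch of weight sharing in plain Python: one 3×3 kernel is applied at every position of a tiny made-up image, and the parameter count is compared with what a layer without sharing would need. The function name `slide_filter` and the toy image are our own; the dense-layer comparison is a back-of-the-envelope illustration, not something VGG16 contains.

```python
def slide_filter(image, kernel):
    """Apply one kernel at every valid position (stride 1, no padding)."""
    k = len(kernel)
    out_h, out_w = len(image) - k + 1, len(image[0]) - k + 1
    return [[sum(image[r + i][c + j] * kernel[i][j]
                 for i in range(k) for j in range(k))
             for c in range(out_w)]
            for r in range(out_h)]

kernel = [[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]]
# Toy 4x6 image: dark left half, bright right half.
image = [[0, 0, 0, 1, 1, 1] for _ in range(4)]

feature_map = slide_filter(image, kernel)
print(feature_map)  # [[0, 3, 3, 0], [0, 3, 3, 0]]

# Parameters of VGG16's first conv layer: 64 filters of 3x3x3 weights plus 64 biases.
conv_params = 64 * (3 * 3 * 3) + 64
print(conv_params)  # 1792

# A fully connected layer mapping the same input to the same output size
# would need one weight per (input value, output value) pair.
dense_weights = (3 * 224 * 224) * (64 * 224 * 224)
print(f"{dense_weights:.1e}")  # 4.8e+11
```

The same nine numbers produce every entry of `feature_map`, which is why the edge is detected wherever it occurs and why the layer stays small regardless of image size.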

## What We'll Do

1. Look at the learned filter weights (the 3√ó3 patterns)
2. See their corresponding activation maps
3. Overlay activations on the original image to see WHERE patterns were detected

In [None]:
# ============================================
# Extract the Learned Filters
# ============================================

# Every Conv2d layer stores its filters in .weight
# For first_conv (Conv2d(3, 64, kernel_size=3)):
#   .weight.shape = [64, 3, 3, 3]
#     [0] = 64 output channels (64 different filters)
#     [1] = 3 input channels (R, G, B)
#     [2] = 3 rows (3×3 kernel)
#     [3] = 3 columns (3×3 kernel)

filters = first_conv.weight.data.clone()  # .data = raw tensor, .clone() = make a copy

print(f"Filter tensor shape: {filters.shape}")
print(f"\nBreaking it down:")
print(f"  [0] 64 filters        # 64 different pattern detectors")
print(f"  [1] 3 input channels  # Each filter looks at R, G, B")
print(f"  [2] 3 rows            # Kernel height")
print(f"  [3] 3 columns         # Kernel width")
print(f"\nEach filter is a 3×3×3 cube of numbers.")
print(f"Total weights in this layer: {filters.numel()} = 64 × 3 × 3 × 3 (plus 64 biases)")

In [None]:
# ============================================
# Visualize Filters and Their Activations
# ============================================

# We'll show 4 filters (out of 64) for clarity
# For each filter, we show:
#   Column 0: The filter itself (3×3 RGB pattern)
#   Column 1: Arrow →
#   Column 2: Where that filter activates (the feature map)
#   Column 3: Activation overlaid on original image

fig, axes = plt.subplots(4, 6, figsize=(16, 10))

for i in range(4):  # Show first 4 filters
    
    # =============================
    # Column 0: The Filter Pattern
    # =============================
    
    # Get filter i: shape [3, 3, 3] = [RGB channels, kernel height, kernel width]
    filt = filters[i].permute(1, 2, 0).numpy()  # Permute to [height, width, channels] for imshow
    
    # Normalize to [0, 1] for display
    # Raw filter values can be negative or very large
    # Normalization: (x - min) / (max - min) → maps to [0, 1]
    filt = (filt - filt.min()) / (filt.max() - filt.min() + 1e-8)  # +epsilon to avoid division by 0
    
    axes[i, 0].imshow(filt)
    axes[i, 0].set_title(f"Filter {i}\n(3×3 RGB pattern)", fontsize=9, fontweight='bold')
    axes[i, 0].axis('off')
    
    # =============================
    # Column 1: Arrow (just visual)
    # =============================
    
    axes[i, 1].text(0.5, 0.5, "→", fontsize=30, ha='center', va='center')
    axes[i, 1].axis('off')
    
    # =============================
    # Column 2: Activation Map
    # =============================
    
    # Get the corresponding activation map
    # first_activations[0, i]: the i-th feature map, shape [224, 224]
    activation = first_activations[0, i].numpy()
    
    # Display with 'hot' colormap (black=low, red=medium, yellow/white=high)
    axes[i, 2].imshow(activation, cmap='hot')
    axes[i, 2].set_title(f"Where filter {i}\nactivates strongly", fontsize=9)
    axes[i, 2].axis('off')
    
    # =============================
    # Column 3: Overlay on Original
    # =============================
    
    # Show original image
    axes[i, 3].imshow(img)
    
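    # Resize activation map to match original image size (224×224 → 1200×1198)
    # Steps:
    # 1. Min-max scale the activation to [0, 1], then to 0-255 uint8
    #    (raw conv outputs can be negative or larger than 1, which would wrap around in uint8)
    # 2. Convert to PIL Image
    # 3. Resize to original image dimensions
    # 4. Convert back to numpy array
    act_scaled = (activation - activation.min()) / (activation.max() - activation.min() + 1e-8)
    activation_resized = np.array(
        Image.fromarray((act_scaled * 255).astype(np.uint8)).resize(img.size)
    )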
    
    # Overlay as semi-transparent heatmap (alpha=0.5 = 50% transparent)
    axes[i, 3].imshow(activation_resized, cmap='hot', alpha=0.5)
    axes[i, 3].set_title(f"Overlay:\nBright = high activation", fontsize=9)
    axes[i, 3].axis('off')
    
    # =============================
    # Columns 4-5: Empty (spacing)
    # =============================
    
    axes[i, 4].axis('off')
    axes[i, 5].axis('off')

plt.suptitle("Filters as Pattern Detectors: Each 3×3 filter looks for a specific pattern", 
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n" + "="*60)
print("💡 WHAT YOU SHOULD SEE:")
print("="*60)
print("Each row shows:")
print("  1. The filter pattern (3×3 RGB): this is what it's 'looking for'")
print("  2. The activation map: WHERE that pattern was found")
print("  3. Overlay: bright areas = strong match with the filter pattern")
print("\nNotice:")
print("  • SAME filter detects patterns in MULTIPLE locations (weight sharing!)")
print("  • Different filters detect different patterns (edges, textures, colors)")
print("  • A single 3×3 pattern only captures simple local structure; complex structures need stacked layers")

## 🎯 Key Takeaway from Block 2

**Convolutions are sliding pattern matchers, not magic.**

- Each filter = one 3×3 pattern to look for
- Filter slides across the entire image (weight sharing)
- High activation = "I found my pattern here!"
- Low activation = "My pattern is not here"

**This is local, not global.** Each filter only "sees" a 3×3 neighborhood at a time. Global understanding comes from stacking many layers.

---
# Code Block 3: Depth = Abstraction

## The Big Question

**If filters only see 3×3 patches, how does a CNN understand whole objects?**

## The Answer

**Layers are hierarchical.** Each layer builds on the previous:

```
Layer 1 (early):   Edges, colors, simple textures
         ↓
Layer 5 (middle):  Corners, curves, complex textures (combinations of edges)
         ↓  
Layer 10 (deep):   Parts (eyes, wheels, fur patterns), combinations of curves/textures
         ↓
Layer 16 (deeper): Whole objects (faces, cars), combinations of parts
```

## The Analogy

Building a house:
- **Layer 1:** Bricks (simple, local)
- **Layer 5:** Walls (made from bricks)
- **Layer 10:** Rooms (made from walls)
- **Layer 16:** House (made from rooms)

Each layer sees **more context** (larger receptive field) and learns **more abstract concepts**.

## What Changes with Depth

| Property | Early Layers | Middle Layers | Deep Layers |
|----------|-------------|---------------|-------------|
| **Spatial resolution** | High (224×224) | Medium (56×56) | Low (14×14) |
| **Number of channels** | Few (64) | Medium (256) | Many (512) |
| **Features detected** | Edges, colors | Textures, patterns | Object parts, concepts |
| **Receptive field** | 3×3 pixels | ~40×40 pixels | ~200×200 pixels |
| **Activation sparsity** | Dense (many active) | Medium | Sparse (few active) |
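
The receptive-field row can be reproduced with a short calculation. The sketch below assumes the standard VGG16 layout (3×3 convolutions with stride 1, 2×2 max-pools with stride 2) and uses the usual recurrence: each layer grows the receptive field by `(kernel - 1) * jump`, where `jump` is the distance in input pixels between neighbouring outputs. The layer names such as `conv3_3` are a common naming convention, not identifiers from torchvision.

```python
VGG16_CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
             512, 512, 512, "M", 512, 512, 512, "M"]

def receptive_fields(cfg):
    """Return a list of (layer_name, receptive_field, jump) for each layer."""
    rf, jump = 1, 1          # a single input pixel; neighbours are 1 pixel apart
    block, conv_idx, rows = 1, 0, []
    for item in cfg:
        if item == "M":
            kernel, stride, name = 2, 2, f"pool{block}"
        else:
            conv_idx += 1
            kernel, stride, name = 3, 1, f"conv{block}_{conv_idx}"
        rf += (kernel - 1) * jump    # grow by the kernel's extent in input pixels
        jump *= stride               # pooling doubles the spacing between outputs
        rows.append((name, rf, jump))
        if item == "M":
            block, conv_idx = block + 1, 0
    return rows

rows = receptive_fields(VGG16_CFG)
for name, rf, jump in rows:
    print(f"{name:8s} receptive field {rf:3d}x{rf:<3d}  jump {jump}")

rf_by_name = {name: rf for name, rf, _ in rows}
print(rf_by_name["conv1_1"], rf_by_name["conv3_3"], rf_by_name["conv5_3"])  # 3 40 196
```

These are theoretical maximum extents. In practice, pixels near the center of a receptive field tend to influence a neuron more than pixels at its border.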

## What We'll Do

1. Extract activations from three depths: early, middle, deep
2. Compare their shapes and visual appearance
3. See how abstraction increases with depth

In [None]:
# ============================================
# Set Up Hooks to Capture Activations
# ============================================

# Problem: We want to see activations from the MIDDLE of the network, not just the output
# Solution: "Hooks", callbacks that run during the forward pass

# Dictionary to store captured activations
activations = {}

# Function factory: creates a hook function that saves to a specific name
def get_activation(name):
    """Returns a hook function that captures activations."""
    def hook(model, input, output):
        # This function runs DURING the forward pass
        # model: the layer being executed
        # input: what went into the layer
        # output: what came out of the layer
        
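        # .detach() = disconnect from computation graph (we're not training)
        # Note: .detach() shares memory with the layer output. torchvision's VGG uses
        # in-place ReLU, so a tensor saved from a conv layer ends up holding the
        # post-ReLU values. Use output.detach().clone() to keep the pre-ReLU values.
        activations[name] = output.detach()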
    return hook

# VGG16 has 5 blocks of conv layers:
# Block 1: features[0-4]   (64 channels)
# Block 2: features[5-9]   (128 channels)
# Block 3: features[10-16] (256 channels)
# Block 4: features[17-23] (512 channels)
# Block 5: features[24-30] (512 channels)

# Register hooks at three different depths
model.features[2].register_forward_hook(get_activation('early'))    # Last conv of block 1 (224×224)
model.features[16].register_forward_hook(get_activation('middle'))  # Max-pool ending block 3 (28×28)
model.features[28].register_forward_hook(get_activation('deep'))    # Last conv of block 5 (14×14)

print("✅ Hooks registered at three depths:")
print(f"   • Layer 2 (early):   Last conv of the first block")
print(f"   • Layer 16 (middle): After third conv block")
print(f"   • Layer 28 (deep):   Near the end of conv layers")

In [None]:
# ============================================
# Run Forward Pass to Capture Activations
# ============================================

with torch.no_grad():  # No gradients needed
    _ = model(input_tensor)  # Full forward pass through VGG16
    # The hooks captured activations automatically!

# Check what we captured
print("\n" + "="*60)
print("ACTIVATION SHAPES AT DIFFERENT DEPTHS")
print("="*60)

for name, act in activations.items():
    print(f"\n{name.upper()} layer:")
    print(f"  Shape: {act.shape}")
    print(f"  Breakdown: [batch={act.shape[0]}, channels={act.shape[1]}, height={act.shape[2]}, width={act.shape[3]}]")

print("\n" + "="*60)
print("NOTICE THE PATTERN:")
print("="*60)
print("As we go deeper:")
print("  ✓ Spatial dimensions DECREASE (224 → 28 → 14 at our three hooks)")
print("    Why? Pooling layers downsample")
print("\n  ✓ Number of channels INCREASES (64 → 256 → 512)")
print("    Why? More features to detect more complex patterns")
print("\nThis is the trade-off: spatial resolution ↔ semantic abstraction")

In [None]:
# ============================================
# Visualize Feature Maps from Each Depth
# ============================================

# We'll show 4 feature maps from each depth
# Strategy: Pick the most "interesting" channels (highest variance = most informative)

fig, axes = plt.subplots(3, 5, figsize=(16, 10))

# Iterate through depths (row 0 = early, row 1 = middle, row 2 = deep)
for row, (name, act) in enumerate(activations.items()):
    
    # =============================
    # Column 0: Label
    # =============================
    
    axes[row, 0].text(0.5, 0.5, f"{name.upper()}\nLayer", 
                     fontsize=14, ha='center', va='center', fontweight='bold')
    axes[row, 0].axis('off')
    
    # =============================
    # Find Most Interesting Channels
    # =============================
    
    # Variance measures how much a feature map varies
    # High variance = informative (has structure)
    # Low variance = boring (mostly uniform)
    
    # act[0] = first (only) image in batch, shape [channels, H, W]
    # .var(dim=[1, 2]) = variance across spatial dimensions (H, W)
    # Result: [channels], one variance value per channel
    variances = act[0].var(dim=[1, 2])
    
    # Get indices of top 4 channels with highest variance
    top_channels = torch.argsort(variances, descending=True)[:4]
    
    # =============================
    # Columns 1-4: Feature Maps
    # =============================
    
    for col, ch in enumerate(top_channels):
        # Extract feature map for channel ch
        feature_map = act[0, ch].numpy()
        
        # Display with viridis colormap
        axes[row, col + 1].imshow(feature_map, cmap='viridis')
        axes[row, col + 1].set_title(f"Channel {ch.item()}\n({feature_map.shape[0]}×{feature_map.shape[1]})", 
                                     fontsize=9)
        axes[row, col + 1].axis('off')

# Add annotations explaining what we see
axes[0, 4].text(1.2, 0.5, 
                "← EARLY\n   • High resolution\n   • Edge-like\n   • Dense activations", 
                fontsize=10, transform=axes[0, 4].transAxes, va='center')

axes[1, 4].text(1.2, 0.5, 
                "← MIDDLE\n   • Medium resolution\n   • Texture patterns\n   • More structured", 
                fontsize=10, transform=axes[1, 4].transAxes, va='center')

axes[2, 4].text(1.2, 0.5, 
                "← DEEP\n   • Low resolution\n   • Abstract, sparse\n   • Object-level", 
                fontsize=10, transform=axes[2, 4].transAxes, va='center')

plt.suptitle("Depth = Abstraction: Features become more abstract deeper in the network", 
             fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n" + "="*60)
print("💡 WHAT YOU SHOULD SEE:")
print("="*60)
print("EARLY (row 1):")
print("  • High resolution (224×224)")
print("  • Looks like filtered versions of the image")
print("  • You can still recognize the cat's shape")
print("  • Detects: edges, simple textures")
print("\nMIDDLE (row 2):")
print("  • Medium resolution (28×28, captured after the block 3 pooling layer)")
print("  • More abstract patterns")
print("  • Harder to see the original image")
print("  • Detects: complex textures, motifs")
print("\nDEEP (row 3):")
print("  • Very low resolution (14×14)")
print("  • Mostly uniform or very sparse")
print("  • Only a few 'hot spots' of activation")
print("  • Detects: high-level concepts, object parts")
print("\n⚠️ This is WHY interpretability must be LAYER-AWARE!")
print("   Different layers show completely different information.")
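
The "pick the highest-variance channels" heuristic used in the cell above can be illustrated without PyTorch. The sketch below ranks three made-up 2×2 feature maps by their spatial variance; `statistics.variance` computes the sample variance (dividing by n - 1), which matches the default behaviour of `torch.var`. The helper name and the toy values are our own.

```python
from statistics import variance

def top_channels_by_variance(feature_maps, k=2):
    """Rank channels (2D lists) by the variance of their values, highest first."""
    scores = [variance([v for row in fmap for v in row]) for fmap in feature_maps]
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return ranked[:k], scores

maps = [
    [[0.0, 0.0], [0.0, 0.0]],   # channel 0: uniform, carries no spatial structure
    [[0.0, 1.0], [1.0, 0.0]],   # channel 1: small variation
    [[0.0, 5.0], [0.0, 5.0]],   # channel 2: strong variation
]

top, scores = top_channels_by_variance(maps, k=2)
print(top)                             # [2, 1]
print([round(s, 3) for s in scores])   # [0.0, 0.333, 8.333]
```

Variance is only a rough proxy for "interesting": a channel with a single extreme value also scores high, so it is worth inspecting the selected maps rather than trusting the ranking blindly.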

## 🎯 Key Takeaway from Block 3

**Abstraction emerges from depth, not from any single smart layer.**

- **Early layers:** See 3×3 patches → detect edges
- **Middle layers:** Combine edges → detect textures and patterns  
- **Deep layers:** Combine patterns → detect object parts and concepts

**The magic is in the composition.** Simple operations (convolution, pooling) stacked many times create complex understanding.

**Implication for interpretability:** If you want to understand "what the network sees," you must specify WHICH LAYER. Early layers show low-level features, deep layers show high-level concepts.

---
# Code Block 4: Spatial Locality & Receptive Fields

## The Big Question

**If deep layers have such low resolution (14×14), how do they still correspond to locations in the original image?**

## The Answer: Receptive Fields

Each neuron in a deep layer doesn't "see" one pixel. It sees a **receptive field**, a region of the original image.

```
Layer 1 neuron:  sees 3×3 pixels
Layer 5 neuron:  sees ~40×40 pixels (accumulated through layers)
Layer 15 neuron: sees ~200×200 pixels (almost the whole image!)
```

But even though receptive fields grow, **spatial correspondence is maintained**:
- Top-left neuron in layer 15 → corresponds to top-left region of image
- Bottom-right neuron in layer 15 → corresponds to bottom-right region

## The Experiment

To prove this, we'll:
1. Black out a region of the image (the cat's face)
2. Compare activations: original vs perturbed
3. See where the differences appear

**Hypothesis:** Changes should appear in corresponding spatial locations across all layers.

## Why This Matters

This is why techniques like **Grad-CAM** and **saliency maps** work!
- They rely on the fact that deep layers still have spatial structure
- We can trace activations back to specific image regions
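
As a rough sketch of this spatial correspondence, the snippet below maps a pixel range in the 224×224 input to the cells of a 14×14 deep feature map that sit on top of it. It assumes each deep cell nominally covers a 16×16 pixel tile (224 / 14 = 16) and uses the 100×100 region from the experiment (rows 50-149, columns 60-159). The helper name `affected_cells` is our own.

```python
def affected_cells(start, stop, stride):
    """Grid indices whose stride-aligned tile overlaps input pixels [start, stop)."""
    return list(range(start // stride, (stop - 1) // stride + 1))

stride = 224 // 14   # 16 input pixels per deep-layer cell

rows = affected_cells(50, 150, stride)   # masked rows 50-149
cols = affected_cells(60, 160, stride)   # masked columns 60-159

print(stride)  # 16
print(rows)    # [3, 4, 5, 6, 7, 8, 9]
print(cols)    # [3, 4, 5, 6, 7, 8, 9]
```

This nominal mapping ignores the receptive field's full extent. Because deep neurons see far more than one 16×16 tile, cells outside this range will change as well, but the largest differences should be centered on these rows and columns.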

In [None]:
# ============================================
# Create a Perturbed Version of the Image
# ============================================

# Make a copy of the original tensor (don't modify the original!)
perturbed_tensor = input_tensor.clone()

# Black out a rectangular region
# Tensor shape: [1, 3, 224, 224]
#   [:, :, 50:150, 60:160] means:
#     : = all batches (just 1)
#     : = all channels (R, G, B)
#     50:150 = rows 50-149 (height)
#     60:160 = columns 60-159 (width)
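# Setting to 0 does not give black here: after normalization, 0 corresponds to the
# ImageNet mean color (a mid-gray). Black would be (0 - mean) / std, roughly -2 per channel.
# Either way, the original content of the region is erased.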

perturbed_tensor[:, :, 50:150, 60:160] = 0

print("✅ Created perturbed image with a masked-out region")
print(f"   Region: rows 50-149, columns 60-159")
print(f"   Size of masked region: 100×100 pixels")

In [None]:
# ============================================
# Capture Activations for Both Images
# ============================================

# We need fresh storage dictionaries
activations_original = {}
activations_perturbed = {}

# Hook factory (same as before, but stores to different dict)
def get_activation_dict(storage, name):
    """Creates a hook that stores activations in a specific dictionary."""
    def hook(model, input, output):
        storage[name] = output.detach()
    return hook

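# Create a fresh model to avoid hook accumulation
# (registering hooks multiple times can cause issues)
# Load ImageNet weights via the `weights=` API (torchvision >= 0.13;
# older versions use pretrained=True instead)
model2 = models.vgg16(weights=models.VGG16_Weights.IMAGENET1K_V1).eval()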

# Register hooks for ORIGINAL image
handles = []  # Keep track of hooks so we can remove them
handles.append(model2.features[2].register_forward_hook(get_activation_dict(activations_original, 'early')))
handles.append(model2.features[28].register_forward_hook(get_activation_dict(activations_original, 'deep')))

# Forward pass with ORIGINAL image
with torch.no_grad():
    _ = model2(input_tensor)

print("✅ Captured activations for ORIGINAL image")

# Remove hooks (clean up)
for h in handles:
    h.remove()

# Register hooks for PERTURBED image
handles = []
handles.append(model2.features[2].register_forward_hook(get_activation_dict(activations_perturbed, 'early')))
handles.append(model2.features[28].register_forward_hook(get_activation_dict(activations_perturbed, 'deep')))

# Forward pass with PERTURBED image
with torch.no_grad():
    _ = model2(perturbed_tensor)

print("✅ Captured activations for PERTURBED image")

# Remove hooks (clean up)
for h in handles:
    h.remove()

In [None]:
# ============================================
# Compute Difference Maps
# ============================================

# For each layer, compute: |original - perturbed|
# This shows WHERE activations changed

# Early layer difference
# activations_original['early'] shape: [1, 64, 224, 224]
# Step 1: Subtract → [1, 64, 224, 224]
# Step 2: .abs() → absolute value (we care about magnitude of change, not direction)
# Step 3: .mean(dim=1) → average across all 64 channels → [1, 224, 224]
# Step 4: [0] → get first (only) batch element → [224, 224]

early_diff = (activations_original['early'] - activations_perturbed['early']).abs().mean(dim=1)[0]
deep_diff = (activations_original['deep'] - activations_perturbed['deep']).abs().mean(dim=1)[0]

print("✅ Computed difference maps")
print(f"   Early layer diff shape: {early_diff.shape}")
print(f"   Deep layer diff shape:  {deep_diff.shape}")
print(f"\n   Bright pixels in diff map = large activation change")
print(f"   Dark pixels = little/no change")
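
The same recipe can be tried on toy data where the answer is known in advance. This sketch uses random NumPy arrays as stand-ins for activations (8 channels on a 14×14 grid, both made-up sizes) and changes only one block of cells, then recovers that block from the difference map.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "activations": 8 channels on a 14x14 grid
original = rng.normal(size=(8, 14, 14))
perturbed = original.copy()
perturbed[:, 3:10, 3:10] += 1.0  # change only rows/cols 3-9, in every channel

# Same recipe as above: absolute difference, averaged over channels
diff = np.abs(original - perturbed).mean(axis=0)  # shape (14, 14)

# Bounding box of the cells that changed
changed = np.argwhere(diff > 1e-6)
top, left = changed.min(axis=0)
bottom, right = changed.max(axis=0)
print(f"changed rows {top}-{bottom}, cols {left}-{right}")
```

Averaging over channels collapses "which filter changed" and keeps only "where it changed", which is the property we want for a spatial map.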

In [None]:
# ============================================
# Visualize Original vs Perturbed vs Differences
# ============================================

fig, axes = plt.subplots(2, 3, figsize=(14, 9))

# =============================
# Row 1: The Images
# =============================

# Column 0: Original image
axes[0, 0].imshow(img)
axes[0, 0].set_title("Original Image", fontsize=12, fontweight='bold')
axes[0, 0].axis('off')

# Column 1: Perturbed image
# Need to denormalize for display (reverse the normalization we did)
# perturbed_tensor: [1, 3, 224, 224] normalized
# Step 1: [0] → [3, 224, 224] (remove batch)
# Step 2: .permute(1, 2, 0) → [224, 224, 3] (channels last for imshow)
# Step 3: Denormalize: pixel * std + mean

perturbed_display = perturbed_tensor[0].permute(1, 2, 0).numpy()
perturbed_display = perturbed_display * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
perturbed_display = np.clip(perturbed_display, 0, 1)  # Clip to valid range [0, 1]

axes[0, 1].imshow(perturbed_display)
axes[0, 1].set_title("Perturbed Image\n(face region blacked out)", fontsize=12, fontweight='bold')
axes[0, 1].axis('off')

# Column 2: Question
axes[0, 2].text(0.5, 0.5, "Which\nactivations\nchanged?", 
                fontsize=16, ha='center', va='center', fontweight='bold')
axes[0, 2].axis('off')

# =============================
# Row 2: Difference Maps
# =============================

# Column 0: Early layer difference
axes[1, 0].imshow(early_diff.numpy(), cmap='hot')
axes[1, 0].set_title(f"EARLY layer difference\n(shape: {tuple(early_diff.shape)})", 
                     fontsize=11, fontweight='bold')
axes[1, 0].axis('off')

# Column 1: Deep layer difference
axes[1, 1].imshow(deep_diff.numpy(), cmap='hot')
axes[1, 1].set_title(f"DEEP layer difference\n(shape: {tuple(deep_diff.shape)})", 
                     fontsize=11, fontweight='bold')
axes[1, 1].axis('off')

# Column 2: Observations
axes[1, 2].text(0.05, 0.9, "💡 Observations:", fontsize=12,
                transform=axes[1, 2].transAxes, fontweight='bold')
axes[1, 2].text(0.05, 0.7, "• EARLY: Localized\n  change at blackout", fontsize=11,
                transform=axes[1, 2].transAxes)
axes[1, 2].text(0.05, 0.45, "• DEEP: Broader but\n  still spatially\n  structured", fontsize=11,
                transform=axes[1, 2].transAxes)
axes[1, 2].text(0.05, 0.15, "→ Receptive fields\n   grow, but spatial\n   info persists!", fontsize=11,
                transform=axes[1, 2].transAxes, style='italic')
axes[1, 2].axis('off')

plt.suptitle("Spatial Locality: Changes in the image affect corresponding spatial locations in activations", 
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("\n" + "="*60)
print("💡 WHAT YOU SHOULD SEE:")
print("="*60)
print("EARLY layer difference (bottom-left):")
print("  • A bright rectangular region")
print("  • Corresponds closely to where we blacked out pixels")
print("  • Size: ~100×100 (the blackout plus a ~2-pixel border)")
print("  • This makes sense: early layers have small receptive fields")
print("\nDEEP layer difference (bottom-middle):")
print("  • Lower resolution (14×14 instead of 224×224)")
print("  • Bright region is BROADER (blurrier)")
print("  • But STILL in the corresponding location!")
print("  • This shows spatial structure is preserved")
print("\n⚠️ KEY INSIGHT:")
print("   Even though deep layers have huge receptive fields,")
print("   they still maintain a spatial map of the image.")
print("   This is WHY Grad-CAM and saliency maps work!")
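
Eyeballing the heatmaps can be backed up with a number: what share of the total change falls inside the cells that sit over the blackout? The sketch below is one possible check, shown on a toy map. `fraction_inside` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def fraction_inside(diff_map, rows, cols):
    """Share of the total change that falls inside the inclusive cell ranges."""
    total = diff_map.sum()
    if total == 0:
        return 0.0
    r0, r1 = rows
    c0, c1 = cols
    return float(diff_map[r0:r1 + 1, c0:c1 + 1].sum() / total)

# Toy 14x14 difference map: strong change in rows/cols 3-9, weak change elsewhere
toy = np.full((14, 14), 0.1)
toy[3:10, 3:10] = 1.0

share = fraction_inside(toy, rows=(3, 9), cols=(3, 9))
print(f"{share:.1%} of the change lies inside the region")
```

In the notebook this could be called as `fraction_inside(deep_diff.numpy(), (3, 9), (3, 9))`, since rows 50-149 and columns 60-159 divided by the stride of 16 give cells 3-9. That region covers 49 of 196 cells (25%), so a share well above 0.25 would support the claim that the change stays localized.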

## 🎯 Key Takeaway from Block 4

**CNNs trade spatial precision for semantic meaning gradually.**

- **Receptive fields grow** with depth (3×3 → 40×40 → ~200×200)
- **But spatial correspondence is maintained** (top-left stays top-left)
- **This is not obvious!** It's a design property of conv layers

**Why this matters:** Interpretability methods like Grad-CAM rely on this property to produce spatial heatmaps showing "which part of the image mattered."

---
# Summary: Your New Mental Model of CNNs

## What We Learned

### 1. Images → Tensors → Activations
- **CNNs don't "see" images**: they see arrays of numbers
- **First layer transforms 3 channels → 64 channels**: 64 different pattern responses
- **Feature maps ≠ semantic understanding**: they're local pattern detectors

### 2. Convolutions = Pattern Matchers
- **Each filter is a 3×3 pattern** (learned from data)
- **Filter slides across the image** (weight sharing = efficiency)
- **High activation = pattern found** in that location
- **Convolution is local, not global**: each position sees only 3×3 neighbors
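
The "pattern matcher" idea fits in a few lines of plain Python. This sketch slides a hand-written 3×3 vertical-edge filter over a tiny single-channel image. The filter values are chosen by hand for illustration (a real CNN learns them), and no padding is used, so the output is smaller than the input.

```python
def correlate2d_valid(image, kernel):
    """Slide `kernel` over `image` (no padding) and sum the element-wise products."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(image[r + i][c + j] * kernel[i][j]
                for i in range(kh) for j in range(kw))
            for c in range(out_w)
        ]
        for r in range(out_h)
    ]

# 5x6 image: dark (0) on the left, bright (1) on the right
image = [[0, 0, 0, 1, 1, 1] for _ in range(5)]

# Responds when the right side of the 3x3 window is brighter than the left
kernel = [[-1, 0, 1],
          [-1, 0, 1],
          [-1, 0, 1]]

response = correlate2d_valid(image, kernel)
for row in response:
    print(row)  # each row is [0, 3, 3, 0]
```

The response is high only where the window straddles the dark-to-bright edge and zero in the flat regions: high activation means the pattern was found at that location.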

### 3. Depth Creates Abstraction
- **Early layers:** Edges, colors (3×3 receptive field)
- **Middle layers:** Textures, patterns (40×40 receptive field)
- **Deep layers:** Object parts, concepts (~200×200 receptive field)
- **Abstraction emerges from composition**, not from any single layer being "smart"

### 4. Spatial Structure Persists
- **Receptive fields grow** but **spatial correspondence remains**
- **Top-left neuron → top-left region** of image, even in deep layers
- **This enables spatial interpretability** (Grad-CAM, saliency, etc.)

---

## Why This Matters for Interpretability

Now when you use interpretability techniques, you'll understand:

| Technique | What It Does | Why It Works (Based on This Notebook) |
|-----------|-------------|---------------------------------------|
| **Feature Visualization** | Generate images that activate a neuron | Uses the fact that neurons detect specific patterns |
| **Grad-CAM** | Heatmap showing important regions | Uses spatial correspondence in deep layers |
| **Saliency Maps** | Which pixels matter most | Uses gradient flow through the network |
| **Layer-wise Analysis** | Compare features at different depths | Uses the abstraction hierarchy (edge → texture → object) |

---

## The Foundation is Set

You now have a **concrete, grounded mental model** of CNNs:
- ✅ Not "AI magic": pattern matching + composition
- ✅ Not "semantic understanding": local detectors stacked hierarchically
- ✅ Not "global reasoning": receptive fields grow but start local
- ✅ Not "black box": we can visualize and understand each layer

**Next steps:** Now we can explore interpretability techniques with this solid foundation! 🚀