# Segment 03: Dataset Exemplars

In Segment 02, we used **activation maximization** to generate synthetic images showing what neurons "want" to see. Those are useful, but they're artificial — they don't tell us what real-world inputs actually trigger those neurons.

In this segment, we flip the approach: instead of synthesizing images, we **search through real ImageNet images** to find the ones that most strongly activate each neuron. These are called **dataset exemplars**.

**Why this matters:**
- Activation maximization shows the *ideal* input (a neuron's "dream image")
- Dataset exemplars show what the neuron *actually responds to* in practice
- Together, they give a much richer picture of what a neuron has learned

**What we'll do:**
- Pass ImageNet images through InceptionV1
- Capture activations at the `mixed4a` layer using a forward hook
- For each of the first 10 channels, find the 10 images that produce the highest activation

## Setup

In [None]:
!pip install -q torch-lucent

In [None]:
import torch
import heapq
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from lucent.modelzoo import inceptionv1

In [None]:
# Load pretrained InceptionV1 (the old TF "inception5h" model ported to PyTorch)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = inceptionv1(pretrained=True).to(device).eval()

print(f"Using device: {device}")

## Capturing Activations with a Forward Hook

To see what's happening *inside* the network, we use a PyTorch **forward hook**. This is a callback function that runs every time a specific layer processes data.

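We attach it to the `mixed4a` layer. On every forward pass, the hook runs as soon as that layer has produced its output and stores the activations so we can inspect them afterwards. Each pass overwrites the previously stored value, so we read it right after calling the model.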

In [None]:
# Dictionary to store the captured activation
activation = {}

def hook_fn(module, input, output):
    """This function runs automatically every time mixed4a produces output.
    It saves the output tensor so we can analyze it after the forward pass."""
    activation["mixed4a"] = output.detach()

# Register the hook on the mixed4a layer
# model.mixed4a is a CatLayer that concatenates the 4 branches of the Inception module
# Its output has 508 channels: 192 (1x1) + 204 (3x3) + 48 (5x5) + 64 (pool)
hook_handle = model.mixed4a.register_forward_hook(hook_fn)

print("Hook registered on mixed4a")
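The hook mechanics can be checked without loading InceptionV1. The sketch below uses a small stand-in network (`toy_model`, `toy_hook`, and `captured` are hypothetical names, not part of this notebook) and shows that the hook fires during a forward pass and stops firing once its handle is removed.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for InceptionV1: two conv layers, hook on the first.
toy_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 4, kernel_size=3, padding=1),
).eval()

captured = {}

def toy_hook(module, inputs, output):
    # Same pattern as hook_fn: store the layer output, detached from autograd.
    captured["conv0"] = output.detach()

handle = toy_model[0].register_forward_hook(toy_hook)

with torch.no_grad():
    toy_model(torch.zeros(2, 3, 16, 16))
shape_after_first_pass = tuple(captured["conv0"].shape)

# After remove(), further forward passes no longer update the dictionary.
handle.remove()
captured.clear()
with torch.no_grad():
    toy_model(torch.zeros(2, 3, 16, 16))
hook_fired_after_remove = "conv0" in captured

print(shape_after_first_pass, hook_fired_after_remove)  # (2, 8, 16, 16) False
```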

## Load ImageNet

We use the ImageNet validation set (50,000 images). Set the path below to point to your local copy.

**Expected directory structure:**
```
imagenet_val_path/
├── n01440764/
│   ├── ILSVRC2012_val_00000293.JPEG
│   └── ...
├── n01443537/
│   └── ...
└── ...
```

**Preprocessing:** Lucent's InceptionV1 is the old TensorFlow model converted to PyTorch. It expects inputs scaled to the range [-117, 138] (i.e. pixel values in [0, 255] minus 117). We do this in two steps:
1. Standard resize + crop + ToTensor → gives us [0, 1] range (for display)
2. Scale by `x * 255 - 117` before feeding to the model
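As a quick arithmetic check of step 2, the sketch below applies the same `x * 255 - 117` mapping to the endpoints of the [0, 1] range. The helper name `to_inception_range` is hypothetical; the notebook applies the expression inline.

```python
def to_inception_range(x: float) -> float:
    """Map a [0, 1] pixel value to the range the ported TF model expects."""
    return x * 255 - 117

low = to_inception_range(0.0)   # darkest possible pixel
high = to_inception_range(1.0)  # brightest possible pixel
print(low, high)  # -117.0 138.0
```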

In [None]:
# ========================
# SET YOUR IMAGENET PATH HERE
# ========================
IMAGENET_VAL_PATH = "/path/to/imagenet/val"  # <-- Change this!

# Transforms: resize + center crop to 224x224, then convert to [0, 1] tensor.
# We do NOT apply the InceptionV1 scaling here — we'll do that separately
# so we can keep the [0,1] version for displaying images.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # Converts PIL image to tensor in [0, 1] range
])

# Load the dataset
dataset = datasets.ImageFolder(IMAGENET_VAL_PATH, transform=preprocess)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2)

print(f"Loaded {len(dataset)} images")

## Find the Top-10 Activating Images for Each Channel

For each image, we:
1. Apply InceptionV1 preprocessing (`x * 255 - 117`)
2. Run a forward pass (the hook automatically captures `mixed4a` output)
3. For each of the first 10 channels, compute the **mean spatial activation** (average over the H×W spatial grid)
4. Track the top-10 highest-activating images per channel using a min-heap

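**Why mean spatial activation?** Each channel's output is a 2D feature map (like a heatmap). Averaging over spatial positions measures how strongly the image activates the channel *overall*, so images where the pattern covers much of the frame rank highest. Taking the spatial maximum instead would favor images with a single strong local match.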
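As a small illustration of averaging versus looking at a single location, the NumPy sketch below builds two made-up 4×4 feature maps (not real `mixed4a` activations): one with a single strong peak and one with a moderate response everywhere. Averaging ranks the diffuse map higher, while the spatial maximum would rank the peaked map higher.

```python
import numpy as np

# Made-up feature maps for one channel, for illustration only.
peaked = np.zeros((4, 4))
peaked[1, 2] = 8.0              # one strong local response
diffuse = np.full((4, 4), 2.0)  # moderate response at every position

peaked_mean, diffuse_mean = float(peaked.mean()), float(diffuse.mean())
peaked_max, diffuse_max = float(peaked.max()), float(diffuse.max())

print(peaked_mean, diffuse_mean)  # 0.5 2.0
print(peaked_max, diffuse_max)    # 8.0 2.0
```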

In [None]:
NUM_CHANNELS = 10  # First 10 channels of mixed4a
TOP_K = 10         # Top 10 images per channel

# For each channel, maintain a min-heap of (activation_value, unique_id, image_tensor).
# A min-heap keeps the smallest element on top, so we can efficiently check
# whether a new image beats the current weakest entry.
top_images = {ch: [] for ch in range(NUM_CHANNELS)}

# Counter to break ties in the heap (avoids comparing tensors directly)
counter = 0

with torch.no_grad():  # No gradients needed: we're only doing inference
    for images, _ in tqdm(dataloader, desc="Processing ImageNet"):
        # Apply InceptionV1 preprocessing: scale [0,1] → [-117, 138]
        model_input = (images * 255 - 117).to(device)

        # Forward pass: the hook captures mixed4a activations automatically
        model(model_input)
        acts = activation["mixed4a"]  # Shape: [batch_size, 508, H, W]

        # Mean over the spatial grid for the first NUM_CHANNELS channels,
        # moved to the CPU once per batch (avoids a GPU sync per .item() call)
        mean_acts = acts[:, :NUM_CHANNELS].mean(dim=(2, 3)).cpu()

        # For each image in the batch
        for i in range(images.size(0)):
            for ch in range(NUM_CHANNELS):
                act_val = mean_acts[i, ch].item()
                heap = top_images[ch]

                # We store the original [0,1] image (not the preprocessed one) for display.
                # clone() copies it, so the heap does not keep the whole batch tensor alive.
                if len(heap) < TOP_K:
                    # Heap not full yet: just add
                    heapq.heappush(heap, (act_val, counter, images[i].clone()))
                elif act_val > heap[0][0]:
                    # New image beats the weakest current top image: swap them
                    heapq.heapreplace(heap, (act_val, counter, images[i].clone()))

                counter += 1
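
The heap bookkeeping can be tried in isolation. The sketch below keeps the top 3 values from a made-up list of scores using the same push / replace pattern; `scores` and `top_k_min_heap` are hypothetical names, and the list index plays the role of the tie-breaking counter.

```python
import heapq

def top_k_min_heap(scores, k):
    """Return the k largest (score, index) pairs, highest first."""
    heap = []  # min-heap: heap[0] is always the weakest kept entry
    for idx, score in enumerate(scores):
        entry = (score, idx)  # idx breaks ties between equal scores
        if len(heap) < k:
            heapq.heappush(heap, entry)
        elif score > heap[0][0]:
            heapq.heapreplace(heap, entry)  # pop weakest, push new entry
    return sorted(heap, reverse=True)

scores = [0.2, 1.5, 0.7, 1.5, 0.1, 0.9]  # made-up activation values
top3 = top_k_min_heap(scores, k=3)
print(top3)  # [(1.5, 3), (1.5, 1), (0.9, 5)]
```

Each candidate costs at most O(log k) work, and only k entries per channel are ever held in memory, which is why this pattern scales to the full validation set.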

## Display Results

For each channel (0–9), we show the 10 ImageNet images that produced the highest mean activation. Images are sorted left-to-right from highest to lowest activation.

In [None]:
fig, axes = plt.subplots(NUM_CHANNELS, TOP_K, figsize=(20, 20))

for ch in range(NUM_CHANNELS):
    # Sort by activation (highest first)
    ranked = sorted(top_images[ch], key=lambda x: x[0], reverse=True)

    for rank, (act_val, _, img_tensor) in enumerate(ranked):
        ax = axes[ch][rank]
        # Convert from [C, H, W] tensor to [H, W, C] numpy for matplotlib
        img = img_tensor.permute(1, 2, 0).numpy()
        img = np.clip(img, 0, 1)  # Ensure valid range
        ax.imshow(img)
        # Hide ticks and frame but keep the axis itself:
        # ax.axis("off") would also hide the y-label set below
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Labels: channel label on the leftmost image, rank on the top row
        if rank == 0:
            ax.set_ylabel(f"Ch {ch}", fontsize=12, rotation=0, labelpad=40)
        if ch == 0:
            ax.set_title(f"#{rank+1}", fontsize=10)

plt.suptitle("Top-10 ImageNet Images per Channel (mixed4a, channels 0–9)", fontsize=16, y=1.01)
plt.tight_layout()
plt.show()

## Observations

Look at each row (channel) and ask yourself:

- **Do the top images share a common theme?** If a channel's top images all contain, say, furry textures or circular shapes, the neuron likely detects that pattern.

- **How does this compare to the activation maximization from Segment 02?** The synthetic visualization showed us the neuron's "ideal" input. Do these real images contain similar patterns, colors, or textures?

- **Are some channels more interpretable than others?** Some channels may have clearly coherent top images (all showing the same kind of thing), while others may look more scattered. Scattered results could indicate a **polysemantic** neuron — one that responds to multiple unrelated concepts.

- **What role does spatial structure play?** Some channels might respond to patterns regardless of where they appear; others might prefer certain spatial arrangements.

## Cleanup

Remove the forward hook so it doesn't interfere with future model use.

In [None]:
hook_handle.remove()
print("Hook removed.")