# Segment 03: Dataset Exemplars

## What Are Dataset Exemplars?

In **Segment 02**, we used *activation maximization* to generate synthetic images that maximally activate specific neurons. Those visualizations show us a neuron's "dream image" — the ideal input pattern it's looking for.

But synthetic images don't tell us what **real-world inputs** actually trigger those neurons.

**Dataset exemplars** flip the approach:
- Instead of *generating* images, we *search* through a large dataset of real images
- For each neuron, we find the images that produce the highest activation
- These show us what the neuron *actually responds to* in practice

## Why This Matters

| Approach | What It Shows | Limitation |
|----------|--------------|------------|
| Activation Maximization | The "ideal" input pattern | Synthetic, may not exist in real data |
| Dataset Exemplars | What the neuron responds to in practice | Limited to images in the dataset |

Together, they give a much richer picture of what each neuron has learned to detect.

## What We'll Do

1. Stream the full **ImageNet training set** (1.28 million images) from HuggingFace
2. Pass each image through **InceptionV1** and capture activations at the `mixed4a` layer
3. For each of the **first 10 channels**, track the **top 10 images** with highest activation
4. **Checkpoint progress** to HuggingFace so we can pause/resume at any time
5. Visualize and compare with our Segment 02 results

---
## 1. Setup & Dependencies

In [1]:
# Install required packages
# - torch-lucent: For loading InceptionV1 (the old TensorFlow model ported to PyTorch)
# - datasets: HuggingFace library for streaming ImageNet without downloading 150GB
# - huggingface_hub: For saving/loading checkpoints to HuggingFace

!pip install -q torch-lucent datasets huggingface_hub



In [2]:
import torch
import heapq
import json
import base64
import io
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import HfApi, hf_hub_download, upload_file, create_repo
from lucent.modelzoo import inceptionv1

### HuggingFace Authentication

You need to be logged in to HuggingFace because:
1. **ImageNet is gated** — you must accept the terms at https://huggingface.co/datasets/ILSVRC/imagenet-1k
2. **We save checkpoints** to your private HF repo

Run `huggingface-cli login` in your terminal, or use the cell below:

In [3]:
# Option 1: Interactive login (will prompt for token)
# from huggingface_hub import login
# login()

# Option 2: If already logged in via CLI, this will confirm
from huggingface_hub import whoami
try:
    user_info = whoami()
    print(f"Logged in as: {user_info['name']}")
except Exception as e:
    print("Not logged in! Run: huggingface-cli login")
    print(f"Error: {e}")

Logged in as: ayesha-imr02


---
## 2. Configuration

In [4]:
# =============================================================================
# CONFIGURATION - Edit these values as needed
# =============================================================================

# HuggingFace repository for checkpoints (EDIT THIS!)
# Format: "your-username/your-repo-name"
HF_REPO_ID = "ayesha-imr02/inceptionv1-imagenet-mixed4a-top10"  # <-- CHANGE THIS!

# Layer and channels to analyze
LAYER_NAME = "mixed4a"   # The layer we're studying (middle layer of InceptionV1)
NUM_CHANNELS = 10        # First 10 channels (matching Segment 02)
TOP_K = 10               # Keep top 10 images per channel

# Checkpointing
CHECKPOINT_EVERY = 1000  # Save progress every N images (allows safe interruption)
CHECKPOINT_FILE = "checkpoint.json"  # Filename in HF repo

# Processing
BATCH_SIZE = 1           # Process one image at a time (streaming mode)
TOTAL_IMAGES = 1_281_167 # ImageNet training set size

print("Configuration:")
print(f"  - HF Repository: {HF_REPO_ID}")
print(f"  - Layer: {LAYER_NAME}")
print(f"  - Channels: 0-{NUM_CHANNELS-1}")
print(f"  - Top K images per channel: {TOP_K}")
print(f"  - Checkpoint every: {CHECKPOINT_EVERY} images")
print(f"  - Total images to process: {TOTAL_IMAGES:,}")

Configuration:
  - HF Repository: ayesha-imr02/inceptionv1-imagenet-mixed4a-top10
  - Layer: mixed4a
  - Channels: 0-9
  - Top K images per channel: 10
  - Checkpoint every: 1000 images
  - Total images to process: 1,281,167


---
## 3. Model Setup

We load **InceptionV1** (also known as GoogLeNet or "inception5h") — the same model used in Segment 02.

This is the original TensorFlow model from 2015, converted to PyTorch. It's commonly used in interpretability research because:
- It has clear, well-studied features
- The Distill article "Feature Visualization" provides reference visualizations
- It's small enough to run quickly but deep enough to be interesting

In [5]:
# Detect available device (GPU is much faster)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load pretrained InceptionV1
# - pretrained=True downloads weights trained on ImageNet
# - .eval() puts the model in inference mode (disables dropout, etc.)
model = inceptionv1(pretrained=True).to(device).eval()

print("InceptionV1 loaded successfully!")

Using device: cpu
Downloading: "https://github.com/ProGamerGov/pytorch-old-tensorflow-models/raw/master/inception5h.pth" to /root/.cache/torch/hub/checkpoints/inception5h.pth


100%|██████████| 27.0M/27.0M [00:00<00:00, 179MB/s]


InceptionV1 loaded successfully!


### Forward Hook: Capturing Internal Activations

To see what's happening *inside* the network, we use a PyTorch **forward hook**.

**How hooks work:**
1. We register a callback function on a specific layer (`mixed4a`)
2. Every time data flows through that layer, our callback runs
3. The callback saves the layer's output (the "activations") for us to analyze

**About `mixed4a`:**
- It's an "Inception module" — a block that applies multiple filter sizes in parallel
- Output shape: `[batch, 508, H, W]` where 508 = 192 + 204 + 48 + 64 (from 4 branches)
- H and W depend on input image size (for 224x224 input: H=W=14)
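
The callback pattern described above can be sketched with a toy, pure-Python stand-in. This is not PyTorch: `ToyLayer`, its `register_forward_hook` method, and `save_output` are hypothetical names that only mimic the shape of the real hook API. The branch widths at the end are the ones quoted above.

```python
# Toy stand-in for a layer with forward hooks (NOT PyTorch; names are hypothetical).
class ToyLayer:
    def __init__(self, fn):
        self.fn = fn
        self.hooks = []

    def register_forward_hook(self, hook):
        """Store the callback and return a function that unregisters it."""
        self.hooks.append(hook)
        return lambda: self.hooks.remove(hook)

    def __call__(self, x):
        out = self.fn(x)
        # Every forward pass notifies each registered hook with (layer, input, output).
        for hook in list(self.hooks):
            hook(self, x, out)
        return out


captured = {}

def save_output(layer, layer_input, layer_output):
    # Mutating the dict (rather than rebinding a name) needs no global/nonlocal.
    captured["toy"] = layer_output

layer = ToyLayer(lambda x: [2 * v for v in x])
remove_hook = layer.register_forward_hook(save_output)

layer([1, 2, 3])
print(captured["toy"])   # [2, 4, 6]

remove_hook()
layer([10])
print(captured["toy"])   # still [2, 4, 6]; the hook no longer fires

# The mixed4a channel count is the sum of its four branch widths.
branch_channels = [192, 204, 48, 64]
print(sum(branch_channels))  # 508
```

The real hook in the next cell follows the same three-argument shape, and `hook_handle.remove()` plays the role of `remove_hook()` here.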

In [6]:
# Dictionary to store captured activations
# The hook mutates this dict in place, so it needs no `global` declaration
activation_storage = {}

def activation_hook(module, input_tensor, output_tensor):
    """
    Forward hook callback.

    This function is called automatically every time the mixed4a layer
    produces output during a forward pass.

    Args:
        module: The layer this hook is attached to (mixed4a)
        input_tensor: The input to this layer (we don't need it)
        output_tensor: The layer's output - this is what we want
    """
    # Detach from computation graph (we don't need gradients)
    # This saves memory and prevents gradient accumulation
    activation_storage[LAYER_NAME] = output_tensor.detach()

# Register the hook on the mixed4a layer
# model.mixed4a is the Inception module we want to study
hook_handle = model.mixed4a.register_forward_hook(activation_hook)

print(f"Hook registered on '{LAYER_NAME}'")
print("Now every forward pass will capture this layer's activations.")

Hook registered on 'mixed4a'
Now every forward pass will capture this layer's activations.


---
## 4. Checkpoint Utilities

Processing 1.28 million images takes several hours. If the runtime crashes or you need to stop, you don't want to start over.

**Our checkpointing strategy:**
1. Every 1000 images, save current progress to HuggingFace
2. The checkpoint contains:
   - `images_processed`: How many images we've seen
   - `top_images`: The current top-K images for each channel (with their actual pixel data)
3. On startup, we check if a checkpoint exists and resume from there
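
As a minimal sketch of this save/resume round trip, the snippet below uses only the standard library. The heap contents are made-up stand-ins (raw bytes play the role of encoded PNG images), and `to_checkpoint` / `from_checkpoint` are hypothetical helper names; the upload and download steps are left out.

```python
import base64
import heapq
import json

# Hypothetical stand-in data: raw bytes play the role of encoded PNG images.
heaps = {
    0: [(0.5, 0, b"img-a"), (1.5, 1, b"img-b")],
    1: [(2.0, 2, b"img-c")],
}
for heap in heaps.values():
    heapq.heapify(heap)

def to_checkpoint(images_processed, heaps):
    """Turn the in-memory state into a JSON string."""
    return json.dumps({
        "images_processed": images_processed,
        "top_images": {
            str(ch): [
                {"activation": act, "counter": cnt,
                 "image_b64": base64.b64encode(img).decode("utf-8")}
                for act, cnt, img in heap
            ]
            for ch, heap in heaps.items()
        },
    })

def from_checkpoint(text):
    """Rebuild (images_processed, heaps) from the JSON string."""
    data = json.loads(text)
    restored = {}
    for ch, entries in data["top_images"].items():
        heap = [(e["activation"], e["counter"], base64.b64decode(e["image_b64"]))
                for e in entries]
        heapq.heapify(heap)  # JSON has no notion of heap order, so restore it
        restored[int(ch)] = heap
    return data["images_processed"], restored

text = to_checkpoint(2000, heaps)
n_done, restored = from_checkpoint(text)
print(n_done)             # 2000
print(restored == heaps)  # True
```

Channel IDs are converted to strings on the way out and back to integers on the way in, because JSON object keys are always strings.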

In [7]:
def image_to_base64(pil_image):
    """
    Convert a PIL Image to a base64-encoded string.

    We store images as base64-encoded PNG in JSON because:
    - JSON is easy to save/load from HuggingFace
    - PNG is lossless, so exact pixel values are preserved
    - It's self-contained (no separate image files to track)

    Args:
        pil_image: A PIL Image object

    Returns:
        Base64-encoded string of the PNG image
    """
    buffer = io.BytesIO()
    pil_image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def base64_to_image(b64_string):
    """
    Convert a base64-encoded string back to a PIL Image.

    Args:
        b64_string: Base64-encoded PNG image data

    Returns:
        PIL Image object
    """
    image_data = base64.b64decode(b64_string)
    return Image.open(io.BytesIO(image_data))


In [8]:
def ensure_repo_exists(repo_id):
    """
    Create the HuggingFace repository if it doesn't exist.

    Args:
        repo_id: Repository ID in format "username/repo-name"
    """
    api = HfApi()
    try:
        # Try to get repo info — if it exists, this succeeds
        api.repo_info(repo_id=repo_id, repo_type="dataset")
        print(f"Repository '{repo_id}' already exists.")
    except Exception:
        # Repository doesn't exist — create it
        print(f"Creating new private repository: {repo_id}")
        create_repo(repo_id=repo_id, repo_type="dataset", private=True)
        print(f"Repository created!")


def save_checkpoint(images_processed, top_images_heaps, repo_id):
    """
    Save current progress to HuggingFace.

    This converts the in-memory heaps to a JSON-serializable format
    and uploads to the HuggingFace repository.

    Args:
        images_processed: Number of images we've processed so far
        top_images_heaps: Dict mapping channel_id -> heap of (activation, counter, PIL_image)
        repo_id: HuggingFace repository ID
    """
    # Convert heaps to serializable format
    # Each heap entry is (activation_value, tie_breaker_counter, pil_image)
    serializable_data = {
        "images_processed": images_processed,
        "top_images": {}
    }

    for channel_id, heap in top_images_heaps.items():
        serializable_data["top_images"][str(channel_id)] = [
            {
                "activation": entry[0],      # The activation value
                "counter": entry[1],          # Tie-breaker counter
                "image_b64": image_to_base64(entry[2])  # PIL image as base64
            }
            for entry in heap
        ]

    # Serialize to a JSON string (uploaded below as in-memory bytes, no temp file)
    checkpoint_json = json.dumps(serializable_data)

    # Upload to HuggingFace
    api = HfApi()
    api.upload_file(
        path_or_fileobj=checkpoint_json.encode("utf-8"),
        path_in_repo=CHECKPOINT_FILE,
        repo_id=repo_id,
        repo_type="dataset",
    )

    print(f"  [Checkpoint saved: {images_processed:,} images processed]")


def load_checkpoint(repo_id):
    """
    Load previous progress from HuggingFace (if exists).

    Args:
        repo_id: HuggingFace repository ID

    Returns:
        Tuple of (images_processed, top_images_heaps)
        If no checkpoint exists, returns (0, empty_heaps)
    """
    try:
        # Try to download the checkpoint file
        checkpoint_path = hf_hub_download(
            repo_id=repo_id,
            filename=CHECKPOINT_FILE,
            repo_type="dataset"
        )

        with open(checkpoint_path, "r") as f:
            data = json.load(f)

        # Reconstruct the heaps
        top_images_heaps = {}
        for channel_str, entries in data["top_images"].items():
            channel_id = int(channel_str)
            heap = []
            for entry in entries:
                heap_entry = (
                    entry["activation"],
                    entry["counter"],
                    base64_to_image(entry["image_b64"])
                )
                heap.append(heap_entry)
            # Heapify to restore heap property
            heapq.heapify(heap)
            top_images_heaps[channel_id] = heap

        images_processed = data["images_processed"]
        print(f"Checkpoint loaded! Resuming from image {images_processed:,}")
        return images_processed, top_images_heaps

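    except Exception as e:
        # No usable checkpoint: start fresh.
        # Caution: this branch also runs on network or auth errors. In that case the
        # next save would overwrite an existing checkpoint, so check the reason below.
        print("No checkpoint found. Starting from scratch.")
        print(f"  (Reason: {e})")
        return 0, {ch: [] for ch in range(NUM_CHANNELS)}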


---
## 5. ImageNet Streaming Setup

The ImageNet training set is ~150GB. Instead of downloading it all, we **stream** images one at a time using HuggingFace's `datasets` library.

**How streaming works:**
- Images are downloaded on-demand as we iterate
- Only a small part of the dataset is held in memory at a time
- Much faster to start (no waiting for full download)
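
The on-demand behaviour can be illustrated with a plain Python generator. This is only an analogy: `fake_image_stream` is a hypothetical stand-in, not the `datasets` API, and the real stream fetches data over the network rather than building strings.

```python
from itertools import islice

def fake_image_stream(total):
    """Hypothetical stand-in for a streamed dataset: yields samples lazily."""
    for i in range(total):
        # In the real stream, this is where a sample would be fetched on demand.
        yield {"image": f"image-{i}", "label": i % 10}

stream = fake_image_stream(1_000_000)   # returns immediately; nothing is produced yet

# Take the first three samples without touching the rest.
first_three = [s["image"] for s in islice(stream, 3)]
print(first_three)  # ['image-0', 'image-1', 'image-2']

# Resuming means skipping what was already processed; a plain generator still has to
# step through the skipped items one by one.
resume_from = 5
resumed = islice(fake_image_stream(1_000_000), resume_from, None)
print(next(resumed)["image"])  # image-5
```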

**Preprocessing for InceptionV1:**
- Resize smallest edge to 256px, then center crop to 224×224
- Scale pixel values: `pixel * 255 - 117` (the original TensorFlow model expects this range)
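
To make the arithmetic concrete, here is a small worked sketch in plain Python. The 500×375 input size is an arbitrary example, the helper names are hypothetical, and the exact rounding used by torchvision's transforms may differ by a pixel from this simplified version.

```python
def resized_size(width, height, short_edge=256):
    """Size after scaling so the shorter edge equals short_edge (aspect ratio kept)."""
    scale = short_edge / min(width, height)
    return round(width * scale), round(height * scale)

def center_crop_box(width, height, crop=224):
    """(left, top, right, bottom) of a centered crop x crop window."""
    left = (width - crop) // 2
    top = (height - crop) // 2
    return left, top, left + crop, top + crop

def scale_pixel(value_0_to_1):
    """Map a [0, 1] pixel value to the range the model expects."""
    return value_0_to_1 * 255 - 117

w, h = resized_size(500, 375)
print((w, h))                              # (341, 256)
print(center_crop_box(w, h))               # (58, 16, 282, 240)
print(scale_pixel(0.0), scale_pixel(1.0))  # -117.0 138.0
```

The last line shows where the `[-117, 138]` input range comes from: black maps to -117 and white to 138.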

In [9]:
# Image preprocessing pipeline
# This transforms PIL images into the format InceptionV1 expects

preprocess_for_model = transforms.Compose([
    transforms.Resize(256),           # Resize so smallest edge is 256px
    transforms.CenterCrop(224),       # Crop center 224x224 region
    transforms.ToTensor(),            # Convert to tensor, scales to [0, 1]
])

def preprocess_image(pil_image):
    """
    Prepare an image for InceptionV1.

    Args:
        pil_image: PIL Image (any size, RGB or other mode)

    Returns:
        Tensor of shape [1, 3, 224, 224] ready for the model
    """
    # Ensure RGB (some ImageNet images are grayscale)
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    # Apply transforms: resize, crop, convert to [0, 1] tensor
    tensor = preprocess_for_model(pil_image)  # Shape: [3, 224, 224]

    # Scale for InceptionV1: [0, 1] -> [-117, 138]
    # The original TF model was trained with this scaling
    tensor = tensor * 255 - 117

    # Add batch dimension: [3, 224, 224] -> [1, 3, 224, 224]
    return tensor.unsqueeze(0)


In [10]:
# Load ImageNet training set in streaming mode
# This requires you to have accepted the dataset terms on HuggingFace

imagenet_stream = load_dataset(
    "ILSVRC/imagenet-1k",
    split="train",
    streaming=True,  # Don't download everything; stream on demand
)

print(f"ImageNet stream ready. Will process {TOTAL_IMAGES:,} images.")


ImageNet stream ready. Will process 1,281,167 images.


---
## 6. Main Processing Loop

This is where the work happens. For each image:

1. **Preprocess** — resize/crop/scale for InceptionV1
2. **Forward pass** — run through the model (hook captures activations)
3. **Compute channel activations** — mean activation per channel
4. **Update heaps** — if this image is in the top-K for any channel, add it
5. **Checkpoint** — save progress every 1000 images

**About the heap data structure:**
- We use a min-heap (smallest element on top) for efficiency
- When full, we only keep an image if it beats the current minimum
- This is O(log K) per update, much faster than sorting
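
The top-K bookkeeping described above can be sketched in isolation with `heapq`. The scores and item names are made up for illustration, and `update_top_k` is a hypothetical helper, not a function used elsewhere in this notebook.

```python
import heapq

def update_top_k(heap, k, score, counter, item):
    """Keep the k highest-scoring items in a min-heap; heap[0] is the weakest kept."""
    entry = (score, counter, item)   # counter breaks ties so items are never compared
    if len(heap) < k:
        heapq.heappush(heap, entry)
    elif score > heap[0][0]:
        heapq.heapreplace(heap, entry)  # pop the minimum and push the new entry

heap = []
scores = [0.3, 1.2, 0.7, 2.5, 0.1, 1.9]
for counter, score in enumerate(scores):
    update_top_k(heap, k=3, score=score, counter=counter, item=f"img-{counter}")

ranked = sorted(heap, reverse=True)
print([(s, name) for s, _, name in ranked])
# [(2.5, 'img-3'), (1.9, 'img-5'), (1.2, 'img-1')]
```

Because the weakest kept item always sits at `heap[0]`, most new images are rejected with a single comparison and never touch the heap.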

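**You can safely interrupt this cell at any time!** Progress is saved every 1000 images. To resume, re-run the checkpoint-loading cell and then the processing cell, so the in-memory state matches the last saved checkpoint.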

In [11]:
# Ensure our HuggingFace repo exists (creates if needed)
ensure_repo_exists(HF_REPO_ID)

# Load checkpoint (or start fresh)
images_processed, top_images = load_checkpoint(HF_REPO_ID)

# Counter for heap tie-breaking
# (When two images have equal activation, we use this to decide order)
# Start from where we left off to maintain consistency
counter = images_processed * NUM_CHANNELS

Repository 'ayesha-imr02/inceptionv1-imagenet-mixed4a-top10' already exists.


checkpoint.json:   0%|          | 0.00/69.8M [00:00<?, ?B/s]

Checkpoint loaded! Resuming from image 229,000


In [None]:
# Main processing loop
print(f"\nStarting processing from image {images_processed:,}...")
print(f"Checkpoints will be saved every {CHECKPOINT_EVERY:,} images.")

# Progress bar
# When resuming, skip the already-processed samples and start the index there,
# so the bar (initial=...) and idx both reflect the true position in the dataset
pbar = tqdm(
    enumerate(imagenet_stream.skip(images_processed), start=images_processed),
    total=TOTAL_IMAGES,
    initial=images_processed,
    desc="Processing ImageNet"
)

with torch.no_grad():  # Disable gradient computation (we're only doing inference)
    for idx, sample in pbar:

        # Get the image (HuggingFace returns a dict with 'image' and 'label' keys)
        pil_image = sample["image"]

        # Keep a copy of the original image (for storing in results)
        # We store the original, not the preprocessed version
        original_image = pil_image.copy()

        # Preprocess for InceptionV1
        model_input = preprocess_image(pil_image).to(device)

        # Forward pass — the hook automatically captures mixed4a activations
        model(model_input)

        # Get the captured activations
        # Shape: [1, 508, H, W] where H=W=14 for 224x224 input
        acts = activation_storage[LAYER_NAME]

        # For each channel we're tracking
        for ch in range(NUM_CHANNELS):
            # Compute mean spatial activation for this channel
            # This tells us how strongly the whole image activates this channel
            activation_value = acts[0, ch].mean().item()

            # Create heap entry: (activation, counter, image)
            # The counter breaks ties when activations are equal
            entry = (activation_value, counter, original_image)
            counter += 1

            # Update the heap
            if len(top_images[ch]) < TOP_K:
                # Heap not full yet — just add the image
                heapq.heappush(top_images[ch], entry)
            elif activation_value > top_images[ch][0][0]:
                # Heap is full, but this image beats the current minimum
                # Replace the minimum with this new image
                heapq.heapreplace(top_images[ch], entry)

        # Update progress bar with current best activation
        if (idx + 1) % 100 == 0:
            best_act = max(top_images[0])[0] if top_images[0] else 0
            pbar.set_postfix({"ch0_best": f"{best_act:.2f}"})

        # Save checkpoint periodically
        if (idx + 1) % CHECKPOINT_EVERY == 0:
            save_checkpoint(idx + 1, top_images, HF_REPO_ID)

# Final checkpoint after processing all images
save_checkpoint(TOTAL_IMAGES, top_images, HF_REPO_ID)
print("\nProcessing complete!")


Starting processing from image 229,000...
Checkpoints will be saved every 1,000 images.


Processing ImageNet:  18%|#7        | 229000/1281167 [00:00<?, ?it/s]

'The read operation timed out' thrown while requesting GET https://huggingface.co/datasets/ILSVRC/imagenet-1k/resolve/49e2ee26f3810fb5a7536bbf732a7b07389a47b5/data/train-00023-of-00294.parquet
Retrying in 1s [Retry 1/5].


---
## 7. Results Visualization

Now let's see what we found! For each channel, we display the 10 ImageNet images that produced the highest activation.

**How to interpret these results:**
- Look for **common themes** across the top images for each channel
- Compare with the **activation maximization** images from Segment 02
- Some channels will be clearly interpretable; others may be more mysterious

In [None]:
# Create visualization grid
# Rows = channels (0-9), Columns = top images ranked by activation

fig, axes = plt.subplots(NUM_CHANNELS, TOP_K, figsize=(20, 22))

for ch in range(NUM_CHANNELS):
    # Sort heap by activation (highest first)
    # Heap entries are (activation, counter, pil_image)
    ranked = sorted(top_images[ch], key=lambda x: x[0], reverse=True)

    for rank, (activation_value, _, pil_image) in enumerate(ranked):
        ax = axes[ch][rank]

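        # Display the image; hide ticks but keep the axes visible,
        # because ax.axis("off") would also hide the labels set below
        ax.imshow(pil_image)
        ax.set_xticks([])
        ax.set_yticks([])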

        # Add labels
        if rank == 0:
            # Channel label on the left
            ax.set_ylabel(f"Ch {ch}", fontsize=14, rotation=0, labelpad=50, va="center")
            # Activation value below the image
            ax.set_xlabel(f"act={activation_value:.2f}", fontsize=9)
        else:
            ax.set_xlabel(f"{activation_value:.2f}", fontsize=9)

        if ch == 0:
            # Rank label on top
            ax.set_title(f"#{rank+1}", fontsize=11)

plt.suptitle(
    f"Top-{TOP_K} ImageNet Images per Channel ({LAYER_NAME}, channels 0–{NUM_CHANNELS-1})",
    fontsize=16,
    y=1.02
)
plt.tight_layout()
plt.show()

---
## 8. Cleanup

In [None]:
# Remove the forward hook to clean up
hook_handle.remove()
print("Hook removed.")

# Clear activation storage
activation_storage.clear()
print("Activation storage cleared.")

print(f"\nResults are saved in HuggingFace repo: {HF_REPO_ID}")