# Segment 03: Dataset Exemplars

## What Are Dataset Exemplars?

In **Segment 02**, we used *activation maximization* to generate synthetic images that maximally activate specific neurons. Those visualizations show us a neuron's "dream image" — the ideal input pattern it's looking for.

But synthetic images don't tell us what **real-world inputs** actually trigger those neurons.

**Dataset exemplars** flip the approach:
- Instead of *generating* images, we *search* through a large dataset of real images
- For each neuron, we find the images that produce the highest activation
- These show us what the neuron *actually responds to* in practice

## Why This Matters

| Approach | What It Shows | Limitation |
|----------|--------------|------------|
| Activation Maximization | The "ideal" input pattern | Synthetic, may not exist in real data |
| Dataset Exemplars | What the neuron responds to in practice | Limited to images in the dataset |

Together, they give a much richer picture of what each neuron has learned to detect.

## What We'll Do

1. Stream the full **ImageNet training set** (1.28 million images) from HuggingFace
2. Pass each image through **InceptionV1** and capture activations at the `mixed4a` layer
3. For each of the **first 10 channels**, track the **top 10 images** with highest activation
4. **Checkpoint progress** to HuggingFace so we can pause/resume at any time
5. Visualize and compare with our Segment 02 results

---
## 1. Setup & Dependencies

In [1]:
# Install required packages
# - torch-lucent: For loading InceptionV1 (the old TensorFlow model ported to PyTorch)
# - datasets: HuggingFace library for streaming ImageNet without downloading 150GB
# - huggingface_hub: For saving/loading checkpoints to HuggingFace

!pip install -q torch-lucent datasets huggingface_hub

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.1/46.1 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m225.1/225.1 kB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m254.1/254.1 kB[0m [31m30.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m58.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for docopt (setup.py) ... [?25l[?25hdone


In [2]:
import torch
import heapq
import json
import io
import time
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.auto import tqdm
from torchvision import transforms
from datasets import load_dataset
from huggingface_hub import HfApi, hf_hub_download, upload_file, create_repo, list_repo_files
from lucent.modelzoo import inceptionv1

print("All imports successful!")

All imports successful!


### HuggingFace Authentication

You need to be logged in to HuggingFace because:
1. **ImageNet is gated** — you must accept the terms at https://huggingface.co/datasets/ILSVRC/imagenet-1k
2. **We save checkpoints** to your private HF repo

Run `huggingface-cli login` in your terminal, or use the cell below:

In [3]:
# Option 1: Interactive login (will prompt for token)
# from huggingface_hub import login
# login()

# Option 2: If already logged in via CLI, this will confirm
from huggingface_hub import whoami
try:
    user_info = whoami()
    print(f"Logged in as: {user_info['name']}")
except Exception as e:
    print("Not logged in! Run: huggingface-cli login")
    print(f"Error: {e}")

Logged in as: ayesha-imr02


---
## 2. Configuration

All important settings in one place. **Edit the `HF_REPO_ID` to your own repository name.**

In [4]:
# =============================================================================
# CONFIGURATION - Edit these values as needed
# =============================================================================

# HuggingFace repository for checkpoints (EDIT THIS!)
# Format: "your-username/your-repo-name"
HF_REPO_ID = "ayesha-imr02/inceptionv1-imagenet-mixed4a-top10"
# Layer and channels to analyze
LAYER_NAME = "mixed4a"   # The layer we're studying (middle layer of InceptionV1)
NUM_CHANNELS = 10        # First 10 channels (matching Segment 02)
TOP_K = 10               # Keep top 10 images per channel

# Checkpointing
CHECKPOINT_EVERY = 5000   # Save progress every N images
CHECKPOINT_FILE = "checkpoint.json"  # Small metadata file (~2KB)

# Processing
TOTAL_IMAGES = 1_281_167  # ImageNet training set size

print("Configuration:")
print(f"  - HF Repository: {HF_REPO_ID}")
print(f"  - Layer: {LAYER_NAME}")
print(f"  - Channels: 0-{NUM_CHANNELS-1}")
print(f"  - Top K images per channel: {TOP_K}")
print(f"  - Checkpoint every: {CHECKPOINT_EVERY:,} images")
print(f"  - Total images to process: {TOTAL_IMAGES:,}")

Configuration:
  - HF Repository: ayesha-imr02/inceptionv1-imagenet-mixed4a-top10
  - Layer: mixed4a
  - Channels: 0-9
  - Top K images per channel: 10
  - Checkpoint every: 5,000 images
  - Total images to process: 1,281,167


---
## 3. Model Setup

We load **InceptionV1** (also known as GoogLeNet or "inception5h") — the same model used in Segment 02.

This is the original TensorFlow model from 2015, converted to PyTorch. It's commonly used in interpretability research because:
- It has clear, well-studied features
- The Distill article "Feature Visualization" provides reference visualizations
- It's small enough to run quickly but deep enough to be interesting

In [5]:
# Detect available device (GPU is much faster)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load pretrained InceptionV1
# - pretrained=True downloads weights trained on ImageNet
# - .eval() puts the model in inference mode (disables dropout, etc.)
model = inceptionv1(pretrained=True).to(device).eval()

print("InceptionV1 loaded successfully!")

Using device: cuda
Downloading: "https://github.com/ProGamerGov/pytorch-old-tensorflow-models/raw/master/inception5h.pth" to /root/.cache/torch/hub/checkpoints/inception5h.pth


100%|██████████| 27.0M/27.0M [00:00<00:00, 196MB/s]


InceptionV1 loaded successfully!


### Forward Hook: Capturing Internal Activations

To see what's happening *inside* the network, we use a PyTorch **forward hook**.

**How hooks work:**
1. We register a callback function on a specific layer (`mixed4a`)
2. Every time data flows through that layer, our callback runs
3. The callback saves the layer's output (the "activations") for us to analyze

**About `mixed4a`:**
- It's an "Inception module" — a block that applies multiple filter sizes in parallel
- Output shape: `[batch, 508, H, W]` where 508 = 192 + 204 + 48 + 64 (from 4 branches)
- H and W depend on input image size (for 224x224 input: H=W=14)

In [6]:
# Dictionary to store captured activations
# We use a dict so the hook function can modify it (closures can't reassign outer variables)
activation_storage = {}

def activation_hook(module, input_tensor, output_tensor):
    """
    Forward hook callback.

    This function is called automatically every time the mixed4a layer
    produces output during a forward pass.

    Args:
        module: The layer this hook is attached to (mixed4a)
        input_tensor: The input to this layer (we don't need it)
        output_tensor: The layer's output — this is what we want!
    """
    # Detach from computation graph (we don't need gradients)
    # This saves memory and prevents gradient accumulation
    activation_storage[LAYER_NAME] = output_tensor.detach()

# Register the hook on the mixed4a layer
# model.mixed4a is the Inception module we want to study
hook_handle = model.mixed4a.register_forward_hook(activation_hook)

print(f"Hook registered on '{LAYER_NAME}'")
print("Now every forward pass will capture this layer's activations.")

Hook registered on 'mixed4a'
Now every forward pass will capture this layer's activations.


---
## 4. Checkpoint Utilities (Robust Version)

Processing 1.28 million images takes several hours. We need robust checkpointing that won't fail.

**New approach — separate images from metadata:**

```
HuggingFace Repo:
├── checkpoint.json          (~2KB) - just numbers, uploads fast
│   {
│     "images_processed": 500000,
│     "top_images": {
│       "0": [{"activation": 5.2, "filename": "images/ch0_rank0.png"}, ...]
│     }
│   }
└── images/                  (separate image files)
    ├── ch0_rank0.png        (uploaded once, only changes if beaten)
    ├── ch0_rank1.png
    └── ...
```

**Why this is more robust:**
- Checkpoint file is tiny (~2KB) → fast, reliable uploads
- Images uploaded separately, only when they change
- If checkpoint upload fails, images already uploaded are safe
- Can browse results visually on HuggingFace anytime

In [7]:
def ensure_repo_exists(repo_id):
    """
    Create the HuggingFace repository if it doesn't exist.

    Args:
        repo_id: Repository ID in format "username/repo-name"
    """
    api = HfApi()
    try:
        api.repo_info(repo_id=repo_id, repo_type="dataset")
        print(f"Repository '{repo_id}' already exists.")
    except Exception:
        print(f"Creating new private repository: {repo_id}")
        create_repo(repo_id=repo_id, repo_type="dataset", private=True)
        print(f"Repository created!")


def upload_with_retry(api, content_bytes, path_in_repo, repo_id, max_retries=3):
    """
    Upload a file to HuggingFace with retry logic.

    Args:
        api: HfApi instance
        content_bytes: File content as bytes
        path_in_repo: Destination path in the repo
        repo_id: Repository ID
        max_retries: Number of retry attempts

    Returns:
        True if successful, False otherwise
    """
    for attempt in range(max_retries):
        try:
            api.upload_file(
                path_or_fileobj=content_bytes,
                path_in_repo=path_in_repo,
                repo_id=repo_id,
                repo_type="dataset",
            )
            return True
        except Exception as e:
            if attempt < max_retries - 1:
                wait_time = 2 ** attempt  # Exponential backoff: 1s, 2s, 4s
                print(f"    Upload failed, retrying in {wait_time}s... ({e})")
                time.sleep(wait_time)
            else:
                print(f"    Upload failed after {max_retries} attempts: {e}")
                return False
    return False

In [8]:
def get_image_filename(channel, rank):
    """
    Generate a consistent filename for an image.

    Args:
        channel: Channel number (0-9)
        rank: Rank in the top-K (0-9, where 0 is highest activation)

    Returns:
        Filename like "images/ch0_rank0.png"
    """
    return f"images/ch{channel}_rank{rank}.png"


def upload_image(api, pil_image, channel, rank, repo_id):
    """
    Upload a single image to HuggingFace.

    Args:
        api: HfApi instance
        pil_image: PIL Image to upload
        channel: Channel number
        rank: Rank in top-K
        repo_id: Repository ID

    Returns:
        Filename if successful, None otherwise
    """
    filename = get_image_filename(channel, rank)

    # Convert PIL image to PNG bytes
    buffer = io.BytesIO()
    pil_image.save(buffer, format="PNG")
    image_bytes = buffer.getvalue()

    if upload_with_retry(api, image_bytes, filename, repo_id):
        return filename
    return None


def save_checkpoint(images_processed, top_images_heaps, repo_id, uploaded_images):
    """
    Save current progress to HuggingFace.

    This saves:
    1. Any NEW top images that haven't been uploaded yet (as separate files)
    2. A small JSON checkpoint with metadata only

    Args:
        images_processed: Number of images we've processed so far
        top_images_heaps: Dict mapping channel_id -> heap of (activation, counter, PIL_image)
        repo_id: HuggingFace repository ID
        uploaded_images: Dict tracking which images have been uploaded {(channel, counter): filename}

    Returns:
        Updated uploaded_images dict
    """
    api = HfApi()

    # Build checkpoint data and upload any new images
    checkpoint_data = {
        "images_processed": images_processed,
        "top_images": {}
    }

    images_uploaded_this_round = 0

    for channel_id, heap in top_images_heaps.items():
        # Sort by activation (highest first) to assign ranks
        sorted_entries = sorted(heap, key=lambda x: x[0], reverse=True)

        channel_data = []
        for rank, (activation, counter, pil_image) in enumerate(sorted_entries):
            # Check if this specific image has been uploaded
            image_key = (channel_id, counter)

            if image_key not in uploaded_images:
                # New image — upload it
                filename = upload_image(api, pil_image, channel_id, rank, repo_id)
                if filename:
                    uploaded_images[image_key] = filename
                    images_uploaded_this_round += 1
            else:
                # Image already uploaded, but might need to update filename if rank changed
                expected_filename = get_image_filename(channel_id, rank)
                if uploaded_images[image_key] != expected_filename:
                    # Rank changed, re-upload to new filename
                    filename = upload_image(api, pil_image, channel_id, rank, repo_id)
                    if filename:
                        uploaded_images[image_key] = filename
                        images_uploaded_this_round += 1

            # Add to checkpoint data
            channel_data.append({
                "activation": activation,
                "counter": counter,
                "filename": uploaded_images.get(image_key, f"images/ch{channel_id}_rank{rank}.png")
            })

        checkpoint_data["top_images"][str(channel_id)] = channel_data

    # Save the small JSON checkpoint
    checkpoint_json = json.dumps(checkpoint_data, indent=2)

    if upload_with_retry(api, checkpoint_json.encode("utf-8"), CHECKPOINT_FILE, repo_id):
        print(f"  [Checkpoint saved: {images_processed:,} images | {images_uploaded_this_round} new images uploaded]")
    else:
        print(f"  [WARNING: Checkpoint JSON upload failed at {images_processed:,} images]")

    return uploaded_images

In [9]:
def load_checkpoint(repo_id):
    """
    Load previous progress from HuggingFace (if exists).

    Args:
        repo_id: HuggingFace repository ID

    Returns:
        Tuple of (images_processed, top_images_heaps, uploaded_images)
        If no checkpoint exists, returns (0, empty_heaps, empty_dict)
    """
    try:
        # Try to download the checkpoint file
        checkpoint_path = hf_hub_download(
            repo_id=repo_id,
            filename=CHECKPOINT_FILE,
            repo_type="dataset"
        )

        with open(checkpoint_path, "r") as f:
            data = json.load(f)

        images_processed = data["images_processed"]

        # Reconstruct the heaps and uploaded_images tracking
        top_images_heaps = {}
        uploaded_images = {}

        for channel_str, entries in data["top_images"].items():
            channel_id = int(channel_str)
            heap = []

            for entry in entries:
                # Download the image from HuggingFace
                try:
                    image_path = hf_hub_download(
                        repo_id=repo_id,
                        filename=entry["filename"],
                        repo_type="dataset"
                    )
                    pil_image = Image.open(image_path)
                    # Convert to RGB if needed and make a copy to avoid file handle issues
                    if pil_image.mode != "RGB":
                        pil_image = pil_image.convert("RGB")
                    pil_image = pil_image.copy()
                except Exception as e:
                    print(f"    Warning: Could not load {entry['filename']}: {e}")
                    continue

                heap_entry = (
                    entry["activation"],
                    entry["counter"],
                    pil_image
                )
                heap.append(heap_entry)

                # Track that this image is already uploaded
                uploaded_images[(channel_id, entry["counter"])] = entry["filename"]

            # Heapify to restore heap property
            heapq.heapify(heap)
            top_images_heaps[channel_id] = heap

        print(f"Checkpoint loaded! Resuming from image {images_processed:,}")
        print(f"  Loaded {sum(len(h) for h in top_images_heaps.values())} existing top images")
        return images_processed, top_images_heaps, uploaded_images

    except Exception as e:
        # No checkpoint found — start fresh
        print(f"No checkpoint found. Starting from scratch.")
        print(f"  (Reason: {e})")
        return 0, {ch: [] for ch in range(NUM_CHANNELS)}, {}

---
## 5. ImageNet Streaming Setup

The ImageNet training set is ~150GB. Instead of downloading it all, we **stream** images one at a time using HuggingFace's `datasets` library.

**How streaming works:**
- Images are downloaded on-demand as we iterate
- Only one image is in memory at a time
- Much faster to start (no waiting for full download)

**Preprocessing for InceptionV1:**
- Resize smallest edge to 256px, then center crop to 224×224
- Scale pixel values: `pixel * 255 - 117` (the original TensorFlow model expects this range)

In [10]:
# Image preprocessing pipeline
# This transforms PIL images into the format InceptionV1 expects

preprocess_for_model = transforms.Compose([
    transforms.Resize(256),           # Resize so smallest edge is 256px
    transforms.CenterCrop(224),       # Crop center 224x224 region
    transforms.ToTensor(),            # Convert to tensor, scales to [0, 1]
])

def preprocess_image(pil_image):
    """
    Prepare an image for InceptionV1.

    Args:
        pil_image: PIL Image (any size, RGB or other mode)

    Returns:
        Tensor of shape [1, 3, 224, 224] ready for the model
    """
    # Ensure RGB (some ImageNet images are grayscale)
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    # Apply transforms: resize, crop, convert to [0, 1] tensor
    tensor = preprocess_for_model(pil_image)  # Shape: [3, 224, 224]

    # Scale for InceptionV1: [0, 1] -> [-117, 138]
    # The original TF model was trained with this scaling
    tensor = tensor * 255 - 117

    # Add batch dimension: [3, 224, 224] -> [1, 3, 224, 224]
    return tensor.unsqueeze(0)

In [11]:
# Load ImageNet training set in streaming mode
# This requires you to have accepted the dataset terms on HuggingFace

print("Loading ImageNet dataset (streaming mode)...")
print("Note: You must have accepted the terms at:")
print("  https://huggingface.co/datasets/ILSVRC/imagenet-1k")

imagenet_stream = load_dataset(
    "ILSVRC/imagenet-1k",
    split="train",
    streaming=True,  # Don't download everything — stream on demand
)

print(f"ImageNet stream ready! Will process {TOTAL_IMAGES:,} images.")

Loading ImageNet dataset (streaming mode)...
Note: You must have accepted the terms at:
  https://huggingface.co/datasets/ILSVRC/imagenet-1k


README.md:   0%|          | 0.00/87.6k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/294 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/294 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/28 [00:00<?, ?it/s]

ImageNet stream ready! Will process 1,281,167 images.


---
## 6. Main Processing Loop

This is where the work happens. For each image:

1. **Preprocess** — resize/crop/scale for InceptionV1
2. **Forward pass** — run through the model (hook captures activations)
3. **Compute channel activations** — mean activation per channel
4. **Update heaps** — if this image is in the top-K for any channel, add it
5. **Checkpoint** — save progress every 5000 images

**About the heap data structure:**
- We use a min-heap (smallest element on top) for efficiency
- When full, we only keep an image if it beats the current minimum
- This is O(log K) per update, much faster than sorting

**You can safely interrupt this cell at any time!** Progress is saved every 5000 images. Just re-run the cell to resume.

In [12]:
# Ensure our HuggingFace repo exists (creates if needed)
ensure_repo_exists(HF_REPO_ID)

# Load checkpoint (or start fresh)
images_processed, top_images, uploaded_images = load_checkpoint(HF_REPO_ID)

# Counter for heap tie-breaking
# (When two images have equal activation, we use this to decide order)
# Start from where we left off to maintain consistency
counter = images_processed * NUM_CHANNELS

print(f"\nReady to process. Counter starting at: {counter:,}")

Repository 'ayesha-imr02/inceptionv1-imagenet-mixed4a-top10' already exists.


checkpoint.json:   0%|          | 0.00/13.0k [00:00<?, ?B/s]

images/ch0_rank0.png:   0%|          | 0.00/356k [00:00<?, ?B/s]

images/ch0_rank1.png:   0%|          | 0.00/389k [00:00<?, ?B/s]

images/ch0_rank2.png:   0%|          | 0.00/380k [00:00<?, ?B/s]

images/ch0_rank3.png:   0%|          | 0.00/418k [00:00<?, ?B/s]

images/ch0_rank4.png:   0%|          | 0.00/619k [00:00<?, ?B/s]

images/ch0_rank5.png:   0%|          | 0.00/461k [00:00<?, ?B/s]

images/ch0_rank6.png:   0%|          | 0.00/408k [00:00<?, ?B/s]

images/ch0_rank7.png:   0%|          | 0.00/269k [00:00<?, ?B/s]

images/ch0_rank8.png:   0%|          | 0.00/469k [00:00<?, ?B/s]

images/ch0_rank9.png:   0%|          | 0.00/356k [00:00<?, ?B/s]

images/ch1_rank1.png:   0%|          | 0.00/634k [00:00<?, ?B/s]

images/ch1_rank2.png:   0%|          | 0.00/331k [00:00<?, ?B/s]

images/ch1_rank3.png:   0%|          | 0.00/303k [00:00<?, ?B/s]

images/ch1_rank4.png:   0%|          | 0.00/198k [00:00<?, ?B/s]

images/ch1_rank5.png:   0%|          | 0.00/243k [00:00<?, ?B/s]

images/ch1_rank6.png:   0%|          | 0.00/400k [00:00<?, ?B/s]

images/ch1_rank7.png:   0%|          | 0.00/940k [00:00<?, ?B/s]

images/ch1_rank8.png:   0%|          | 0.00/348k [00:00<?, ?B/s]

images/ch1_rank9.png:   0%|          | 0.00/284k [00:00<?, ?B/s]

images/ch2_rank0.png:   0%|          | 0.00/340k [00:00<?, ?B/s]

images/ch2_rank1.png:   0%|          | 0.00/411k [00:00<?, ?B/s]

images/ch2_rank2.png:   0%|          | 0.00/450k [00:00<?, ?B/s]

images/ch2_rank3.png:   0%|          | 0.00/259k [00:00<?, ?B/s]

images/ch2_rank4.png:   0%|          | 0.00/548k [00:00<?, ?B/s]

images/ch2_rank5.png:   0%|          | 0.00/495k [00:00<?, ?B/s]

images/ch2_rank6.png:   0%|          | 0.00/431k [00:00<?, ?B/s]

images/ch2_rank7.png:   0%|          | 0.00/342k [00:00<?, ?B/s]

images/ch2_rank8.png:   0%|          | 0.00/7.26M [00:00<?, ?B/s]

images/ch2_rank9.png:   0%|          | 0.00/148k [00:00<?, ?B/s]

images/ch3_rank0.png:   0%|          | 0.00/401k [00:00<?, ?B/s]

images/ch3_rank1.png:   0%|          | 0.00/369k [00:00<?, ?B/s]

images/ch3_rank2.png:   0%|          | 0.00/441k [00:00<?, ?B/s]

images/ch3_rank3.png:   0%|          | 0.00/352k [00:00<?, ?B/s]

images/ch3_rank5.png:   0%|          | 0.00/429k [00:00<?, ?B/s]

images/ch3_rank6.png:   0%|          | 0.00/415k [00:00<?, ?B/s]

images/ch3_rank7.png:   0%|          | 0.00/414k [00:00<?, ?B/s]

images/ch3_rank8.png:   0%|          | 0.00/392k [00:00<?, ?B/s]

images/ch3_rank9.png:   0%|          | 0.00/616k [00:00<?, ?B/s]

images/ch4_rank0.png:   0%|          | 0.00/225k [00:00<?, ?B/s]

images/ch4_rank1.png:   0%|          | 0.00/381k [00:00<?, ?B/s]

images/ch4_rank2.png:   0%|          | 0.00/418k [00:00<?, ?B/s]

images/ch4_rank3.png:   0%|          | 0.00/253k [00:00<?, ?B/s]

images/ch4_rank4.png:   0%|          | 0.00/488k [00:00<?, ?B/s]

images/ch4_rank5.png:   0%|          | 0.00/437k [00:00<?, ?B/s]

images/ch4_rank6.png:   0%|          | 0.00/274k [00:00<?, ?B/s]

images/ch4_rank7.png:   0%|          | 0.00/377k [00:00<?, ?B/s]

images/ch4_rank8.png:   0%|          | 0.00/242k [00:00<?, ?B/s]

images/ch5_rank0.png:   0%|          | 0.00/371k [00:00<?, ?B/s]

images/ch5_rank1.png:   0%|          | 0.00/541k [00:00<?, ?B/s]

images/ch5_rank3.png:   0%|          | 0.00/788k [00:00<?, ?B/s]

images/ch5_rank4.png:   0%|          | 0.00/393k [00:00<?, ?B/s]

images/ch5_rank5.png:   0%|          | 0.00/496k [00:00<?, ?B/s]

images/ch5_rank6.png:   0%|          | 0.00/472k [00:00<?, ?B/s]

images/ch5_rank7.png:   0%|          | 0.00/383k [00:00<?, ?B/s]

images/ch5_rank8.png:   0%|          | 0.00/423k [00:00<?, ?B/s]

images/ch5_rank9.png:   0%|          | 0.00/443k [00:00<?, ?B/s]

images/ch6_rank0.png:   0%|          | 0.00/356k [00:00<?, ?B/s]

images/ch6_rank2.png:   0%|          | 0.00/3.56M [00:00<?, ?B/s]

images/ch6_rank3.png:   0%|          | 0.00/424k [00:00<?, ?B/s]

images/ch6_rank4.png:   0%|          | 0.00/691k [00:00<?, ?B/s]

images/ch6_rank5.png:   0%|          | 0.00/636k [00:00<?, ?B/s]

images/ch6_rank6.png:   0%|          | 0.00/366k [00:00<?, ?B/s]

images/ch6_rank7.png:   0%|          | 0.00/445k [00:00<?, ?B/s]

images/ch6_rank8.png:   0%|          | 0.00/1.94M [00:00<?, ?B/s]

images/ch6_rank9.png:   0%|          | 0.00/1.32M [00:00<?, ?B/s]

images/ch7_rank0.png:   0%|          | 0.00/439k [00:00<?, ?B/s]

images/ch7_rank1.png:   0%|          | 0.00/387k [00:00<?, ?B/s]

images/ch7_rank2.png:   0%|          | 0.00/919k [00:00<?, ?B/s]

images/ch7_rank3.png:   0%|          | 0.00/477k [00:00<?, ?B/s]

images/ch7_rank4.png:   0%|          | 0.00/349k [00:00<?, ?B/s]

images/ch7_rank5.png:   0%|          | 0.00/501k [00:00<?, ?B/s]

images/ch7_rank6.png:   0%|          | 0.00/370k [00:00<?, ?B/s]

images/ch7_rank7.png:   0%|          | 0.00/497k [00:00<?, ?B/s]

images/ch7_rank8.png:   0%|          | 0.00/535k [00:00<?, ?B/s]

images/ch7_rank9.png:   0%|          | 0.00/450k [00:00<?, ?B/s]

images/ch8_rank0.png:   0%|          | 0.00/249k [00:00<?, ?B/s]

images/ch8_rank1.png:   0%|          | 0.00/250k [00:00<?, ?B/s]

images/ch8_rank2.png:   0%|          | 0.00/342k [00:00<?, ?B/s]

images/ch8_rank3.png:   0%|          | 0.00/292k [00:00<?, ?B/s]

images/ch8_rank4.png:   0%|          | 0.00/107k [00:00<?, ?B/s]

images/ch8_rank5.png:   0%|          | 0.00/286k [00:00<?, ?B/s]

images/ch8_rank6.png:   0%|          | 0.00/111k [00:00<?, ?B/s]

images/ch8_rank7.png:   0%|          | 0.00/241k [00:00<?, ?B/s]

images/ch8_rank8.png:   0%|          | 0.00/276k [00:00<?, ?B/s]

images/ch8_rank9.png:   0%|          | 0.00/747k [00:00<?, ?B/s]

images/ch9_rank0.png:   0%|          | 0.00/116k [00:00<?, ?B/s]

images/ch9_rank1.png:   0%|          | 0.00/321k [00:00<?, ?B/s]

images/ch9_rank2.png:   0%|          | 0.00/148k [00:00<?, ?B/s]

images/ch9_rank3.png:   0%|          | 0.00/319k [00:00<?, ?B/s]

images/ch9_rank4.png:   0%|          | 0.00/117k [00:00<?, ?B/s]

images/ch9_rank5.png:   0%|          | 0.00/219k [00:00<?, ?B/s]

images/ch9_rank6.png:   0%|          | 0.00/150k [00:00<?, ?B/s]

images/ch9_rank7.png:   0%|          | 0.00/346k [00:00<?, ?B/s]

images/ch9_rank8.png:   0%|          | 0.00/288k [00:00<?, ?B/s]

images/ch9_rank9.png:   0%|          | 0.00/122k [00:00<?, ?B/s]

Checkpoint loaded! Resuming from image 330,000
  Loaded 100 existing top images

Ready to process. Counter starting at: 3,300,000


In [13]:
# Main processing loop
print(f"\nStarting processing from image {images_processed:,}...")
print(f"Checkpoints will be saved every {CHECKPOINT_EVERY:,} images.")

# Progress bar
pbar = tqdm(
    enumerate(imagenet_stream),
    total=TOTAL_IMAGES,
    initial=images_processed,
    desc="Processing ImageNet"
)

try:
    with torch.no_grad():  # Disable gradient computation (we're only doing inference)
        for idx, sample in pbar:
            # Skip images we've already processed (when resuming)
            if idx < images_processed:
                continue

            # Get the image (HuggingFace returns a dict with 'image' and 'label' keys)
            pil_image = sample["image"]

            # Keep a copy of the original image (for storing in results)
            # We store the original, not the preprocessed version
            original_image = pil_image.copy()

            # Ensure RGB
            if original_image.mode != "RGB":
                original_image = original_image.convert("RGB")

            # Preprocess for InceptionV1
            model_input = preprocess_image(pil_image).to(device)

            # Forward pass — the hook automatically captures mixed4a activations
            model(model_input)

            # Get the captured activations
            # Shape: [1, 508, H, W] where H=W=14 for 224x224 input
            acts = activation_storage[LAYER_NAME]

            # For each channel we're tracking
            for ch in range(NUM_CHANNELS):
                # Compute mean spatial activation for this channel
                # This tells us how strongly the whole image activates this channel
                activation_value = acts[0, ch].mean().item()

                # Create heap entry: (activation, counter, image)
                # The counter breaks ties when activations are equal
                entry = (activation_value, counter, original_image)
                counter += 1

                # Update the heap
                if len(top_images[ch]) < TOP_K:
                    # Heap not full yet — just add the image
                    heapq.heappush(top_images[ch], entry)
                elif activation_value > top_images[ch][0][0]:
                    # Heap is full, but this image beats the current minimum
                    # Replace the minimum with this new image
                    heapq.heapreplace(top_images[ch], entry)

            # Update progress bar with current best activation for channel 0
            if (idx + 1) % 500 == 0:
                best_act = max(top_images[0])[0] if top_images[0] else 0
                pbar.set_postfix({"ch0_best": f"{best_act:.2f}"})

            # Save checkpoint periodically
            if (idx + 1) % CHECKPOINT_EVERY == 0:
                uploaded_images = save_checkpoint(idx + 1, top_images, HF_REPO_ID, uploaded_images)

    # Final checkpoint after processing all images
    uploaded_images = save_checkpoint(TOTAL_IMAGES, top_images, HF_REPO_ID, uploaded_images)
    print("\nProcessing complete!")

except KeyboardInterrupt:
    print("\n\nInterrupted! Saving checkpoint...")
    uploaded_images = save_checkpoint(idx, top_images, HF_REPO_ID, uploaded_images)
    print(f"Checkpoint saved at image {idx:,}. You can resume by re-running this cell.")


Starting processing from image 330,000...
Checkpoints will be saved every 5,000 images.


Processing ImageNet:  26%|##5       | 330000/1281167 [00:00<?, ?it/s]



RemoteProtocolError: peer closed connection without sending complete message body (received 59906375 bytes, expected 100574983)

---
## 7. Results Visualization

Now let's see what we found! For each channel, we display the 10 ImageNet images that produced the highest activation.

**How to interpret these results:**
- Look for **common themes** across the top images for each channel
- Compare with the **activation maximization** images from Segment 02
- Some channels will be clearly interpretable; others may be more mysterious

In [None]:
# Create visualization grid
# Rows = channels (0-9), Columns = top images ranked by activation

fig, axes = plt.subplots(NUM_CHANNELS, TOP_K, figsize=(20, 22))

for ch in range(NUM_CHANNELS):
    # Sort heap by activation (highest first)
    # Heap entries are (activation, counter, pil_image)
    ranked = sorted(top_images[ch], key=lambda x: x[0], reverse=True)

    for rank, (activation_value, _, pil_image) in enumerate(ranked):
        ax = axes[ch][rank]

        # Display the image
        ax.imshow(pil_image)
        ax.axis("off")

        # Add labels
        if rank == 0:
            # Channel label on the left
            ax.set_ylabel(f"Ch {ch}", fontsize=14, rotation=0, labelpad=50, va="center")
            # Activation value below the image
            ax.set_xlabel(f"act={activation_value:.2f}", fontsize=9)
        else:
            ax.set_xlabel(f"{activation_value:.2f}", fontsize=9)

        if ch == 0:
            # Rank label on top
            ax.set_title(f"#{rank+1}", fontsize=11)

plt.suptitle(
    f"Top-{TOP_K} ImageNet Images per Channel ({LAYER_NAME}, channels 0-{NUM_CHANNELS-1})",
    fontsize=16,
    y=1.02
)
plt.tight_layout()
plt.show()

---
## 8. Observations & Analysis

Look at each row (channel) and ask yourself:

### Questions to Consider

1. **Do the top images share a common theme?**
   - If all top images for a channel contain similar content (e.g., furry textures, circular shapes, text), the neuron likely detects that pattern.

2. **How does this compare to Segment 02?**
   - The activation maximization showed the neuron's "ideal" input.
   - Do these real images contain similar patterns, colors, or textures?
   - Are there surprising differences?

3. **Are some channels more interpretable?**
   - Coherent top images → clear, monosemantic neuron
   - Scattered, unrelated images → possibly **polysemantic** (responds to multiple concepts)

4. **What specific features might each neuron detect?**
   - Textures? (fur, scales, fabric)
   - Shapes? (circles, curves, lines)
   - Colors? (specific hues or contrasts)
   - Objects? (eyes, wheels, faces)

### Recording Your Observations

Use the cell below to note what you see for each channel:

**Your observations:**

- **Channel 0**: _[What patterns do you see?]_
- **Channel 1**: _[...]_
- **Channel 2**: _[...]_
- **Channel 3**: _[...]_
- **Channel 4**: _[...]_
- **Channel 5**: _[...]_
- **Channel 6**: _[...]_
- **Channel 7**: _[...]_
- **Channel 8**: _[...]_
- **Channel 9**: _[...]_

---
## 9. Cleanup

In [None]:
# Remove the forward hook to clean up
hook_handle.remove()
print("Hook removed.")

# Clear activation storage
activation_storage.clear()
print("Activation storage cleared.")

print(f"\nResults are saved in your HuggingFace repo: {HF_REPO_ID}")
print(f"You can view the images directly at: https://huggingface.co/datasets/{HF_REPO_ID}")