# Task 2 — The Prober (Activation Maximization)

> Per instructions.md: load the trained model, evaluate train/test accuracy, and visualize internal features with activation maximization.

**Note:** This notebook plots many 10×10 grids of optimized images, so running it end-to-end can take a while.

In [4]:
# === Phase 0: Setup ===
import os
import random
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
from PIL import Image

SEED = 1337
# Seed every RNG the notebook draws from (Python, NumPy, PyTorch).
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Paths (relative to the current working directory)
DATA_ROOT = Path("../task0/outputs/colored-mnist")
WEIGHTS_PATH = Path("../task1/saved_models/cnn_weights_feb1_GODLYPULL.pth")

# Hyperparams for visualization (adjust if needed)
GRID_SIZE = 10  # each survey is a GRID_SIZE x GRID_SIZE grid of optimized images
GRID_STEPS, GRID_LR = 50, 0.05
POLY_STEPS, POLY_LR = 60, 0.05
SINGLE_NEURON_STEPS, SINGLE_NEURON_LR = 60, 0.05
BATCH_SIZE = 512

print("Paths:")
for _label, _path in (("DATA_ROOT", DATA_ROOT), ("WEIGHTS_PATH", WEIGHTS_PATH)):
    print(f" - {_label}:", _path)



Using device: cpu
Paths:
 - DATA_ROOT: ../task0/outputs/colored-mnist
 - WEIGHTS_PATH: ../task1/saved_models/cnn_weights_feb1_GODLYPULL.pth


In [5]:

# === Model (must match ../task1/cnn.ipynb) ===
# NOTE: the submodule attribute names below (conv1, relu1, pool1, ...) become the
# state_dict keys the saved weights are loaded by — do not rename them.
conv1_features = 8
conv2_features = 16

class ThreeLayerCNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by three fully connected layers.

    ``fc1`` expects ``conv2_features * 7 * 7`` inputs, i.e. a 28x28 image halved
    by each of the two 2x2 max-pools.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor ("same" padding keeps H and W per conv).
        self.conv1 = nn.Conv2d(3, conv1_features, kernel_size=5, padding="same")
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(conv1_features, conv2_features, kernel_size=5, padding="same")
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        # Classifier head.
        self.fc1 = nn.Linear(conv2_features * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.relu_fc = nn.ReLU()  # one activation module shared by fc1 and fc2

    def forward(self, x):
        """Return raw (un-normalized) logits of shape (batch, 10)."""
        features = self.pool1(self.relu1(self.conv1(x)))
        features = self.pool2(self.relu2(self.conv2(features)))
        hidden = features.view(features.size(0), -1)
        for dense in (self.fc1, self.fc2):
            hidden = self.relu_fc(dense(hidden))
        return self.fc3(hidden)

model = ThreeLayerCNN().to(device)

if not WEIGHTS_PATH.exists():
    raise FileNotFoundError(f"Weights not found: {WEIGHTS_PATH}")

# The checkpoint is a plain state_dict, so restrict unpickling to tensors and
# primitive containers (weights_only=True) instead of executing arbitrary pickled
# code from the file; this also silences PyTorch's FutureWarning about the default.
state = torch.load(WEIGHTS_PATH, map_location=device, weights_only=True)
model.load_state_dict(state)
model.eval()

# Freeze weights explicitly: activation maximization optimizes the input image,
# never the network parameters.
for p in model.parameters():
    p.requires_grad_(False)

print("Loaded weights.")


# === Dataset ===
BASE_TRANSFORM = transforms.ToTensor()



Loaded weights.


In [6]:
def load_meta(split: str) -> pd.DataFrame:
    """Read ``DATA_ROOT/<split>/labels.csv`` into a DataFrame.

    Raises:
        FileNotFoundError: if the CSV does not exist (dataset not generated yet).
    """
    csv_path = DATA_ROOT.joinpath(split, "labels.csv")
    if csv_path.exists():
        return pd.read_csv(csv_path)
    raise FileNotFoundError(f"{csv_path} not found. Run generation first.")

def _load_rgb(split: str, filename: str) -> torch.Tensor:
    """Load ``DATA_ROOT/<split>/images/<filename>`` as RGB and apply ``BASE_TRANSFORM``.

    The source file is opened in a context manager so its handle is released
    deterministically (also when decoding fails). ``convert`` returns a new,
    fully loaded image, so it stays valid after the file is closed.
    """
    path = DATA_ROOT / split / "images" / filename
    with Image.open(path) as pil_img:
        rgb_img = pil_img.convert("RGB")
    return BASE_TRANSFORM(rgb_img)

class ColoredMNISTDataset(Dataset):
    """Map-style dataset over one split ("train" or "test") of the colored-MNIST data.

    Items are ``(image_tensor, label)`` pairs: the image comes from ``_load_rgb``
    and the integer label from the split's ``labels.csv`` row.
    """

    def __init__(self, split: str):
        assert split in {"train", "test"}
        self.split = split
        self.meta = load_meta(split)

    def __len__(self):
        return self.meta.shape[0]

    def __getitem__(self, idx):
        record = self.meta.iloc[idx]
        return _load_rgb(self.split, record.filename), int(record.label)

train_dataset = ColoredMNISTDataset("train")
test_dataset = ColoredMNISTDataset("test")

# Pinned host memory only speeds up host->GPU copies. On a CPU-only run it does
# nothing useful and PyTorch warns about it, so enable it only for CUDA.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
    pin_memory=device.type == "cuda",
)
train_loader = DataLoader(train_dataset, **_loader_kwargs)
test_loader = DataLoader(test_dataset, **_loader_kwargs)


In [7]:
@torch.no_grad()
def evaluate(loader, name: str) -> float:
    """Print and return top-1 accuracy (in percent) of the global ``model`` on ``loader``.

    Runs under ``torch.no_grad()`` because evaluation never needs gradients.
    An empty loader yields 0.0 rather than a division by zero.
    """
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        total += labels.size(0)
        correct += (preds == labels).sum().item()
    acc = 100.0 * correct / max(total, 1)
    print(f"{name} Accuracy: {acc:.2f}%")
    return acc

print("Evaluating train/test accuracy...")
train_acc = evaluate(train_loader, "Train")
test_acc = evaluate(test_loader, "Test")


# === Helpers for activation maximization ===
def _apply_jitter(img: torch.Tensor, jitter: int) -> torch.Tensor:


_IncompleteInputError: incomplete input (808344353.py, line 20)

In [None]:
    if jitter <= 0:
        return img
    ox = random.randint(-jitter, jitter)
    oy = random.randint(-jitter, jitter)
    return torch.roll(img, shifts=(ox, oy), dims=(2, 3))

def _apply_blur(img: torch.Tensor, k: int) -> torch.Tensor:
    if k <= 1:
        return img
    if k % 2 == 0:
        k += 1
    return F.avg_pool2d(img, kernel_size=k, stride=1, padding=k // 2)

def maximize_channel(
    model: nn.Module,
    layer: nn.Module,
    channel: int,
    *,
    steps: int = 200,
    lr: float = 0.05,
    color_penalty: float = 0.0,
    jitter: int = 0,
    blur_k: int = 0,
    seed: int | None = None,
    clamp: bool = True,
    ) -> torch.Tensor:
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

    activations: Dict[str, torch.Tensor] = {}
    def hook(_, __, output):
        activations["feat"] = output

    handle = layer.register_forward_hook(hook)
    img = torch.randn(1, 3, 28, 28, device=device, requires_grad=True)
    opt = torch.optim.Adam([img], lr=lr)

    for _ in range(steps):
        opt.zero_grad()
        inp = _apply_jitter(img, jitter)
        inp = _apply_blur(inp, blur_k)
        model(inp)
        feat = activations["feat"]
        act = feat[:, channel].mean()
        color_var = img.std(dim=(2, 3)).mean()
        loss = -act + (color_penalty * color_var)
        loss.backward()
        opt.step()
        if clamp:
            img.data.clamp_(0, 1)

    handle.remove()
    return img.detach().cpu()

def maximize_fc_neuron(
    model: nn.Module,
    layer: nn.Module,
    neuron: int,
    *,
    steps: int = 200,
    lr: float = 0.05,
    jitter: int = 0,
    blur_k: int = 0,
    seed: int | None = None,
    clamp: bool = True,
    ) -> torch.Tensor:
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

    activations: Dict[str, torch.Tensor] = {}
    def hook(_, __, output):
        activations["feat"] = output

    handle = layer.register_forward_hook(hook)
    img = torch.randn(1, 3, 28, 28, device=device, requires_grad=True)
    opt = torch.optim.Adam([img], lr=lr)

    for _ in range(steps):
        opt.zero_grad()
        inp = _apply_jitter(img, jitter)
        inp = _apply_blur(inp, blur_k)
        model(inp)
        feat = activations["feat"]
        act = feat[:, neuron].mean()
        loss = -act
        loss.backward()
        opt.step()
        if clamp:
            img.data.clamp_(0, 1)

    handle.remove()
    return img.detach().cpu()

def maximize_spatial_neuron(
    model: nn.Module,
    layer: nn.Module,
    channel: int,
    y: int,
    x: int,
    *,
    steps: int = 200,
    lr: float = 0.05,
    seed: int | None = None,
    clamp: bool = True,
    ) -> torch.Tensor:
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)

    activations: Dict[str, torch.Tensor] = {}
    def hook(_, __, output):
        activations["feat"] = output

    handle = layer.register_forward_hook(hook)
    img = torch.randn(1, 3, 28, 28, device=device, requires_grad=True)
    opt = torch.optim.Adam([img], lr=lr)

    for _ in range(steps):
        opt.zero_grad()
        model(img)
        feat = activations["feat"]
        act = feat[0, channel, y, x]
        loss = -act
        loss.backward()
        opt.step()
        if clamp:
            img.data.clamp_(0, 1)

    handle.remove()
    return img.detach().cpu()

def plot_grid(images: List[torch.Tensor], title: str):
    """Tile exactly ``GRID_SIZE ** 2`` images into one square figure and display it.

    The tensors are concatenated along dim 0 (each is expected to be a single
    (1, C, H, W) CPU image, as the ``maximize_*`` helpers return) and laid out
    ``GRID_SIZE`` per row with 2-pixel padding.
    """
    expected = GRID_SIZE ** 2
    assert len(images) == expected
    batch = torch.cat(images, dim=0)
    tiled = make_grid(batch, nrow=GRID_SIZE, padding=2)
    fig, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(tiled.permute(1, 2, 0).numpy())
    ax.set_title(title)
    ax.axis("off")
    plt.show()

def log_template(layer: str, channels: str, objective: str, regularization: str):
    """Print a labelled experiment record ending with TODO Observation/Interpretation lines and a rule."""
    fields = (
        ("Layer:", layer),
        ("Channel/Neuron:", channels),
        ("Objective:", objective),
        ("Regularization:", regularization),
        ("Observation:", "TODO"),
        ("Interpretation:", "TODO"),
    )
    for label, value in fields:
        print(label, value)
    print("-" * 60)


# === Phase 2: Early Layer Exploration (conv1) ===
print("\n[Phase 2] Early layer exploration: conv1")
conv1 = model.conv1
conv2 = model.conv2

# Cycle through conv1's channels so each one appears several times in the grid,
# each tile starting from a different seed.
conv1_channels = [i % conv1_features for i in range(GRID_SIZE * GRID_SIZE)]
conv1_images = []
for idx, ch in enumerate(conv1_channels):
    img = maximize_channel(
        model, conv1, ch,
        steps=GRID_STEPS, lr=GRID_LR, seed=SEED + idx
    )
    conv1_images.append(img)
plot_grid(conv1_images, title="Conv1 (early) | channel-wise maximization | 10x10 grid")
log_template("conv1", "channels repeated 0-7", "mean activation", "none")


# === Phase 3: Channel Survey (conv1 + conv2) ===
print("\n[Phase 3] Channel survey: conv1 (100 images)")
conv1_survey_channels = [tile % conv1_features for tile in range(GRID_SIZE ** 2)]
# Seeds are offset by 100 so no tile reuses a Phase 2 seed.
conv1_survey_images = [
    maximize_channel(model, conv1, ch, steps=GRID_STEPS, lr=GRID_LR, seed=SEED + 100 + idx)
    for idx, ch in enumerate(conv1_survey_channels)
]
plot_grid(conv1_survey_images, title="Conv1 survey | 10x10 grid")
log_template("conv1", "channels repeated 0-7", "mean activation", "none")


In [None]:

print("\n[Phase 3] Channel survey: conv2 (100 images)")
conv2_survey_channels = [tile % conv2_features for tile in range(GRID_SIZE ** 2)]
# Seeds are offset by 200 to keep them disjoint from the earlier grids.
conv2_survey_images = [
    maximize_channel(model, conv2, ch, steps=GRID_STEPS, lr=GRID_LR, seed=SEED + 200 + idx)
    for idx, ch in enumerate(conv2_survey_channels)
]
plot_grid(conv2_survey_images, title="Conv2 survey | 10x10 grid")
log_template("conv2", "channels repeated 0-15", "mean activation", "none")


# === Phase 4: FC neuron probing (sampled neurons) ===
print("\n[Phase 4] FC neuron probing: fc1 (100 images)")
fc1 = model.fc1
fc2 = model.fc2
fc3 = model.fc3

# Derive the unit count from the layer. fc1 has more units than the grid has tiles,
# so only the first GRID_SIZE**2 neurons are shown (no repeats). The log label is
# built from the neurons actually plotted so it cannot misreport the range.
fc1_neurons = [i % fc1.out_features for i in range(GRID_SIZE * GRID_SIZE)]
fc1_images = []
for idx, n in enumerate(fc1_neurons):
    img = maximize_fc_neuron(
        model, fc1, n,
        steps=GRID_STEPS, lr=GRID_LR, seed=SEED + 300 + idx
    )
    fc1_images.append(img)
plot_grid(fc1_images, title="FC1 neuron survey | 10x10 grid")
fc1_repeats = len(fc1_neurons) > fc1.out_features
fc1_label = f"neurons {'repeated ' if fc1_repeats else ''}{min(fc1_neurons)}-{max(fc1_neurons)}"
log_template("fc1", fc1_label, "mean activation", "none")



In [None]:
print("\n[Phase 4] FC neuron probing: fc2 (100 images)")
# Unit counts come from the layers themselves. Labels are derived from the neurons
# actually plotted ("repeated" only when the grid cycles through all units).
fc2_neurons = [i % fc2.out_features for i in range(GRID_SIZE * GRID_SIZE)]
fc2_images = []
for idx, n in enumerate(fc2_neurons):
    img = maximize_fc_neuron(
        model, fc2, n,
        steps=GRID_STEPS, lr=GRID_LR, seed=SEED + 400 + idx
    )
    fc2_images.append(img)
plot_grid(fc2_images, title="FC2 neuron survey | 10x10 grid")
fc2_repeats = len(fc2_neurons) > fc2.out_features
fc2_label = f"neurons {'repeated ' if fc2_repeats else ''}{min(fc2_neurons)}-{max(fc2_neurons)}"
log_template("fc2", fc2_label, "mean activation", "none")

print("\n[Phase 4] FC neuron probing: fc3/logits (100 images)")
fc3_neurons = [i % fc3.out_features for i in range(GRID_SIZE * GRID_SIZE)]
fc3_images = []
for idx, n in enumerate(fc3_neurons):
    img = maximize_fc_neuron(
        model, fc3, n,
        steps=GRID_STEPS, lr=GRID_LR, seed=SEED + 500 + idx
    )
    fc3_images.append(img)
plot_grid(fc3_images, title="FC3 (logits) neuron survey | 10x10 grid")
fc3_repeats = len(fc3_neurons) > fc3.out_features
fc3_label = f"neurons {'repeated ' if fc3_repeats else ''}{min(fc3_neurons)}-{max(fc3_neurons)}"
log_template("fc3", fc3_label, "mean activation", "none")


# === Phase 5: Polysemanticity experiments (critical) ===
print("\n[Phase 5] Polysemanticity experiments on conv2 channel 0")
channel_target = 0
# Regularizer settings, shared by the optimization calls and the log lines below.
COLOR_PENALTY_LAMBDA = 1.0
JITTER_PX = 2
BLUR_K = 3

# The same channel under three objectives (plain, + color penalty, + jitter/blur),
# each with its own disjoint seed range.
poly_plain = []
poly_color = []
poly_jitter_blur = []
for idx in range(GRID_SIZE * GRID_SIZE):
    poly_plain.append(maximize_channel(
        model, conv2, channel_target,
        steps=POLY_STEPS, lr=POLY_LR, seed=SEED + 600 + idx
    ))
    poly_color.append(maximize_channel(
        model, conv2, channel_target,
        steps=POLY_STEPS, lr=POLY_LR, seed=SEED + 700 + idx, color_penalty=COLOR_PENALTY_LAMBDA
    ))
    poly_jitter_blur.append(maximize_channel(
        model, conv2, channel_target,
        steps=POLY_STEPS, lr=POLY_LR, seed=SEED + 800 + idx, jitter=JITTER_PX, blur_k=BLUR_K
    ))

plot_grid(poly_plain, title=f"Polysemanticity | conv2 ch{channel_target} | plain objective | 10x10")
log_template("conv2", f"channel {channel_target}", "mean activation", "none")

plot_grid(poly_color, title=f"Polysemanticity | conv2 ch{channel_target} | + color penalty | 10x10")
log_template(
    "conv2", f"channel {channel_target}", "mean activation",
    f"color variance penalty (lambda={COLOR_PENALTY_LAMBDA})",
)

plot_grid(poly_jitter_blur, title=f"Polysemanticity | conv2 ch{channel_target} | jitter+blur | 10x10")
log_template("conv2", f"channel {channel_target}", "mean activation", f"jitter={JITTER_PX}, blur_k={BLUR_K}")


# === Phase 8: Single-neuron probing (spatial specificity) ===
print("\n[Phase 8] Single-neuron probing on conv2 channel 0 across spatial positions")
# conv2 runs on pool1's output (28 -> 14 with "same" padding), so its activation map
# is 14x14; 7x7 is only reached after pool2. Measure the map instead of hard-coding
# it, then spread GRID_SIZE rows and GRID_SIZE columns evenly across it so grid tile
# (r, c) probes position (ys[r], xs[c]) and the whole map is covered.
with torch.no_grad():
    _probe = torch.zeros(1, 3, 28, 28, device=device)
    conv2_h, conv2_w = model.conv2(model.pool1(model.relu1(model.conv1(_probe)))).shape[-2:]
ys = np.linspace(0, conv2_h - 1, GRID_SIZE).round().astype(int)
xs = np.linspace(0, conv2_w - 1, GRID_SIZE).round().astype(int)
positions: List[Tuple[int, int]] = [(int(y), int(x)) for y in ys for x in xs]

spatial_images = []
for idx, (y, x) in enumerate(positions):
    img = maximize_spatial_neuron(
        model, conv2, channel=0, y=y, x=x,
        steps=SINGLE_NEURON_STEPS, lr=SINGLE_NEURON_LR, seed=SEED + 900 + idx
    )
    spatial_images.append(img)


In [None]:
plot_grid(spatial_images, title="Single-neuron probe | conv2 ch0 | varying (y,x) | 10x10")
log_template("conv2", "channel 0, spatial positions", "single neuron activation", "none")


# === Phase 6: Cross-layer comparison (same objective, different layers) ===
print("\n[Phase 6] Cross-layer comparison for channel 0")
cross_conv1 = []
cross_conv2 = []
# Per tile: conv1 first (seed offset 1000), then conv2 (seed offset 1100).
for idx in range(GRID_SIZE * GRID_SIZE):
    for target_layer, bucket, seed_offset in ((conv1, cross_conv1, 1000), (conv2, cross_conv2, 1100)):
        bucket.append(maximize_channel(
            model, target_layer, 0, steps=GRID_STEPS, lr=GRID_LR, seed=SEED + seed_offset + idx
        ))
plot_grid(cross_conv1, title="Cross-layer | conv1 ch0 | 10x10")
log_template("conv1", "channel 0", "mean activation", "none")


In [None]:

plot_grid(images=cross_conv2, title="Cross-layer | conv2 ch0 | 10x10")
log_template(layer="conv2", channels="channel 0", objective="mean activation", regularization="none")

print("\nDone. Please fill the Observation/Interpretation lines above for each grid.")