# Segment 1: Introduction to Vision Interpretability and CNN Basics


**What you'll learn:**
1. How images become numbers (tensors) that a network can process
2. What convolutional filters do — local pattern matching
3. How to build and train a simple CNN from scratch
4. How depth creates abstraction: edges → textures → objects
5. How spatial information survives through the network's layers

We'll keep things visual and hands-on. By the end, you'll have trained your own CNN and looked inside it.
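As a dependency-free sketch of item 1, the snippet below converts a tiny, made-up 2×2 RGB image from the height × width × channel layout (integers 0 to 255) into the channel-first layout scaled to [0, 1]. This mirrors what `transforms.ToTensor()` does for a PIL image; the helper name `hwc_to_chw` and the pixel values are illustrative assumptions.

```python
# A made-up 2x2 RGB image in height x width x channel (HWC) layout, values 0-255.
image_hwc = [
    [[255, 0, 0], [0, 255, 0]],      # row 0: red pixel, green pixel
    [[0, 0, 255], [255, 255, 255]],  # row 1: blue pixel, white pixel
]


def hwc_to_chw(image):
    """Reorder HWC -> CHW and scale 0-255 integers to floats in [0, 1]."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [
        [[image[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


image_chw = hwc_to_chw(image_hwc)
shape = [len(image_chw), len(image_chw[0]), len(image_chw[0][0])]
print("shape [C, H, W]:", shape)     # [3, 2, 2]
print("red channel:", image_chw[0])  # [[1.0, 0.0], [0.0, 1.0]]
```

Each channel is now its own 2D grid, which is the layout the convolution layers in this segment expect.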

In [None]:
import torch

import torch.nn as nn

import torch.nn.functional as F

import torch.optim as optim

from torchvision import transforms, datasets, models

from torch.utils.data import DataLoader

from PIL import Image

import matplotlib.pyplot as plt

import numpy as np

import requests, json, os, tarfile, urllib.request

from io import BytesIO



# Device setup

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")



# Download a sample image

url = "https://raw.githubusercontent.com/pytorch/hub/master/images/dog.jpg"

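response = requests.get(url, timeout=30)
response.raise_for_status()  # fail loudly if the download did not succeed
img = Image.open(BytesIO(response.content)).convert("RGB").resize((224, 224))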



# Convert to tensor (no normalization — just raw pixel values scaled to [0, 1])

img_tensor = transforms.ToTensor()(img)  # shape: [3, 224, 224]



print(f"Image tensor shape: {list(img_tensor.shape)}  →  [channels, height, width]")

print(f"Value range: [{img_tensor.min():.2f}, {img_tensor.max():.2f}]")



# --- Human view vs. Machine view ---

fig, axes = plt.subplots(1, 4, figsize=(14, 3.5))

axes[0].imshow(img)

axes[0].set_title("What we see")



for c, (name, cmap) in enumerate([("Red", "Reds"), ("Green", "Greens"), ("Blue", "Blues")]):

    axes[c + 1].imshow(img_tensor[c].numpy(), cmap=cmap)

    axes[c + 1].set_title(f"{name} channel")



for ax in axes:

    ax.axis("off")



plt.suptitle("To a CNN, an image is just a 3D grid of numbers — one layer per color channel",

             fontsize=11, y=1.02)

plt.tight_layout()

plt.show()

## What Does a Convolutional Filter Do?

A **filter** (or kernel) is a small grid of numbers, typically 3×3. It slides across the image one patch at a time and computes a weighted sum at each position:

- If the patch **matches** the filter's pattern → **high output** (bright)
- If it **doesn't match** → **low output** (dark)

This is the core mechanism of every CNN: **local pattern matching via dot products.**
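To make the weighted sum explicit, here is a minimal pure-Python sketch of sliding a 3×3 filter over an image without padding. The helper `correlate2d_valid` and the 4×4 test image are illustrative assumptions, not part of the notebook; like `F.conv2d`, the helper does not flip the kernel.

```python
def correlate2d_valid(image, kernel):
    """Slide kernel over image (no padding) and return the weighted sums."""
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    return [
        [
            sum(image[y + i][x + j] * kernel[i][j]
                for i in range(kh) for j in range(kw))
            for x in range(out_w)
        ]
        for y in range(out_h)
    ]


# 4x4 image with a vertical edge: left half dark (0), right half bright (1).
image = [[0, 0, 1, 1] for _ in range(4)]

vertical_detector = [[-1, 0, 1]] * 3
horizontal_detector = [[-1, -1, -1], [0, 0, 0], [1, 1, 1]]

vertical_response = correlate2d_valid(image, vertical_detector)
horizontal_response = correlate2d_valid(image, horizontal_detector)
print(vertical_response)    # [[3, 3], [3, 3]]
print(horizontal_response)  # [[0, 0], [0, 0]]
```

The vertical detector gives 3 wherever its window straddles the edge, while the horizontal detector gives 0 because the rows above and below each window are identical.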

Let's make this concrete with a toy experiment — two simple images, three hand-crafted filters.

In [None]:
# --- Two simple test images (10x10 pixels) ---

img_vert = torch.zeros(10, 10)

img_vert[:, 5:] = 1.0                  # vertical edge: left half dark, right half bright



img_diag = torch.zeros(10, 10)

for i in range(10):

    img_diag[i, min(i, 9)] = 1.0       # diagonal line



# --- Three hand-crafted 3x3 filters ---

filters = {

    "Vertical\ndetector": torch.tensor([[-1., 0., 1.],

                                         [-1., 0., 1.],

                                         [-1., 0., 1.]]),

    "Horizontal\ndetector": torch.tensor([[-1., -1., -1.],

                                           [ 0.,  0.,  0.],

                                           [ 1.,  1.,  1.]]),

    "Diagonal\ndetector": torch.tensor([[ 2., -1., -1.],

                                         [-1.,  2., -1.],

                                         [-1., -1.,  2.]]),

}



# --- Apply every filter to every image ---

test_images = {"Vertical edge": img_vert, "Diagonal line": img_diag}



fig, axes = plt.subplots(2, 4, figsize=(14, 6))

for r, (img_name, im) in enumerate(test_images.items()):

    axes[r, 0].imshow(im, cmap="gray")

    axes[r, 0].set_title(img_name, fontsize=11)



    x = im.unsqueeze(0).unsqueeze(0)  # reshape to [1, 1, 10, 10] for conv2d

    for c, (f_name, kernel) in enumerate(filters.items()):

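        # padding=1 keeps the output at 10x10; the zero border can add spurious responses at the image edges
        out = F.conv2d(x, kernel.reshape(1, 1, 3, 3), padding=1)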

        axes[r, c + 1].imshow(out.squeeze(), cmap="RdBu_r", vmin=-3, vmax=3)

        axes[r, c + 1].set_title(f_name, fontsize=10)



for ax in axes.flat:

    ax.set_xticks([]); ax.set_yticks([])



plt.suptitle("Filter selectivity: each filter responds strongest to its own pattern",

             fontsize=12, fontweight="bold")

plt.tight_layout()

plt.show()



# Key insight:

print("Each filter is a pattern template: it responds most strongly where the image matches its pattern.")

print("A CNN learns HUNDREDS of these filters automatically during training.")

## Let's Build and Train a Tiny CNN

Now that we understand what a single filter does, let's build a network with **many filters stacked in layers** and train it to recognize images.

**Dataset:** [ImageNette](https://github.com/fastai/imagenette), a beginner-friendly 10-class subset of ImageNet with easy-to-recognize categories (fish, dog, church, French horn, etc.)

**Our model — SimpleCNN:**
- `conv1`: 3 → 16 filters (learn basic patterns like edges)
- `conv2`: 16 → 32 filters (combine edges into textures)
- `conv3`: 32 → 64 filters (build higher-level features)
- Global average pool + fully connected layer → 10 class scores
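As a rough check on how small this model is, the sketch below counts learnable parameters layer by layer. It assumes 3×3 kernels with biases and a batch-norm layer after each convolution, as in the `SimpleCNN` class defined in this notebook; the helper function names are illustrative.

```python
def conv2d_params(in_ch, out_ch, kernel_size):
    """Weights (out_ch * in_ch * k * k) plus one bias per output channel."""
    return out_ch * in_ch * kernel_size * kernel_size + out_ch


def batchnorm_params(channels):
    """One learnable scale and one learnable shift per channel."""
    return 2 * channels


def linear_params(in_features, out_features):
    """Weight matrix plus one bias per output."""
    return in_features * out_features + out_features


layer_params = {
    "conv1": conv2d_params(3, 16, 3),
    "bn1": batchnorm_params(16),
    "conv2": conv2d_params(16, 32, 3),
    "bn2": batchnorm_params(32),
    "conv3": conv2d_params(32, 64, 3),
    "bn3": batchnorm_params(64),
    "fc": linear_params(64, 10),
}
total = sum(layer_params.values())

for name, count in layer_params.items():
    print(f"{name:6s} {count:>6,}")
print(f"total  {total:>6,}")  # 24,458
```

Most of the parameters sit in `conv3`, because the weight count of a convolution grows with the product of its input and output channels.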

In [None]:
# ===========================================================

# SimpleCNN — a minimal 3-layer convolutional network

# ===========================================================

class SimpleCNN(nn.Module):

    def __init__(self, num_classes=10):

        super().__init__()

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)

        self.bn1   = nn.BatchNorm2d(16)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)

        self.bn2   = nn.BatchNorm2d(32)

        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        self.bn3   = nn.BatchNorm2d(64)

        self.pool  = nn.MaxPool2d(2, 2)

        self.gap   = nn.AdaptiveAvgPool2d(1)   # global average pool

        self.fc    = nn.Linear(64, num_classes)



    def forward(self, x):

        x = self.pool(F.relu(self.bn1(self.conv1(x))))   # [B,16,H/2,W/2]

        x = self.pool(F.relu(self.bn2(self.conv2(x))))   # [B,32,H/4,W/4]

        x = self.pool(F.relu(self.bn3(self.conv3(x))))   # [B,64,H/8,W/8]

        x = self.gap(x).flatten(1)                        # [B,64]

        return self.fc(x)                                  # [B,10]



# ===========================================================

# Download and load ImageNette (160px version, ~98 MB)

# ===========================================================

DATA_DIR = "imagenette2-160"

if not os.path.exists(DATA_DIR):

    print("Downloading ImageNette (160px)... this may take a minute.")

    urllib.request.urlretrieve(

        "https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-160.tgz",

        "imagenette2-160.tgz"

    )

    with tarfile.open("imagenette2-160.tgz", "r:gz") as tar:

        tar.extractall()

    print("Done!")



# ImageNet normalization stats

MEAN = [0.485, 0.456, 0.406]

STD  = [0.229, 0.224, 0.225]



train_transform = transforms.Compose([

    transforms.RandomResizedCrop(128),

    transforms.RandomHorizontalFlip(),

    transforms.ToTensor(),

    transforms.Normalize(mean=MEAN, std=STD),

])

val_transform = transforms.Compose([

    transforms.Resize(160),

    transforms.CenterCrop(128),

    transforms.ToTensor(),

    transforms.Normalize(mean=MEAN, std=STD),

])



train_dataset = datasets.ImageFolder(f"{DATA_DIR}/train", transform=train_transform)

val_dataset   = datasets.ImageFolder(f"{DATA_DIR}/val",   transform=val_transform)



train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True,  num_workers=2)

val_loader   = DataLoader(val_dataset,   batch_size=64, shuffle=False, num_workers=2)



CLASS_NAMES = ["tench", "English springer", "cassette player", "chain saw",

               "church", "French horn", "garbage truck", "gas pump",

               "golf ball", "parachute"]



print(f"Training samples:   {len(train_dataset)}")

print(f"Validation samples: {len(val_dataset)}")

print(f"Classes: {CLASS_NAMES}")



# --- Show a sample batch ---

mean_t = torch.tensor(MEAN).view(3, 1, 1)

std_t  = torch.tensor(STD).view(3, 1, 1)



sample_imgs, sample_labels = next(iter(train_loader))

fig, axes = plt.subplots(2, 4, figsize=(14, 7))

for i, ax in enumerate(axes.flat):

    display = (sample_imgs[i] * std_t + mean_t).permute(1, 2, 0).clip(0, 1)

    ax.imshow(display)

    ax.set_title(CLASS_NAMES[sample_labels[i]], fontsize=10)

    ax.axis("off")

plt.suptitle("Sample training images from ImageNette", fontsize=12, fontweight="bold")

plt.tight_layout()

plt.show()



# --- Create model ---

model_simple = SimpleCNN(num_classes=10).to(device)

total_params = sum(p.numel() for p in model_simple.parameters() if p.requires_grad)

print(f"\nSimpleCNN created on {device}: {total_params:,} trainable parameters")

In [None]:
# ===========================================================

# Training loop

# ===========================================================

optimizer = optim.Adam(model_simple.parameters(), lr=1e-3)

criterion = nn.CrossEntropyLoss()

num_epochs = 8



history = {"train_loss": [], "val_loss": [], "val_acc": []}



for epoch in range(num_epochs):

    # --- Training pass ---

    model_simple.train()

    running_loss = 0.0

    for images, labels in train_loader:

        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        loss = criterion(model_simple(images), labels)

        loss.backward()

        optimizer.step()

        running_loss += loss.item() * images.size(0)



    train_loss = running_loss / len(train_dataset)



    # --- Validation pass ---

    model_simple.eval()

    val_loss, correct = 0.0, 0

    with torch.no_grad():

        for images, labels in val_loader:

            images, labels = images.to(device), labels.to(device)

            outputs = model_simple(images)

            val_loss += criterion(outputs, labels).item() * images.size(0)

            correct += (outputs.argmax(1) == labels).sum().item()



    val_loss /= len(val_dataset)

    val_acc = correct / len(val_dataset) * 100



    history["train_loss"].append(train_loss)

    history["val_loss"].append(val_loss)

    history["val_acc"].append(val_acc)



    print(f"Epoch {epoch+1}/{num_epochs}  |  "

          f"Train loss: {train_loss:.3f}  |  "

          f"Val loss: {val_loss:.3f}  |  Val acc: {val_acc:.1f}%")



# --- Plot training curves ---

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(history["train_loss"], label="Train", marker="o")

ax1.plot(history["val_loss"], label="Validation", marker="o")

ax1.set_xlabel("Epoch"); ax1.set_ylabel("Loss"); ax1.legend()

ax1.set_title("Loss")



ax2.plot(history["val_acc"], color="green", marker="o")

ax2.set_xlabel("Epoch"); ax2.set_ylabel("Accuracy (%)")

ax2.set_title("Validation Accuracy")



plt.tight_layout()

plt.show()

## What Did Our CNN Learn?

Our SimpleCNN now classifies 10 categories of images. But what patterns did it actually discover?

The first layer (`conv1`) has **16 filters**, each 3×3×3 (height × width × RGB channels). Before training, these were random noise. After training, they should have organized into more structured pattern detectors.
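Because each first-layer filter spans all three color channels, it can respond to color as well as to shape. The sketch below applies a hand-made "red minus green" filter to constant-color patches; the filter values and helper names are hypothetical and chosen only to illustrate the 27-number dot product.

```python
def filter_response(patch, kernel):
    """Dot product between a [3][3][3] (channel, row, col) patch and kernel."""
    return sum(
        patch[c][i][j] * kernel[c][i][j]
        for c in range(3) for i in range(3) for j in range(3)
    )


def constant_patch(r, g, b):
    """A 3x3 patch filled with a single RGB color, in channel-first layout."""
    return [[[value] * 3 for _ in range(3)] for value in (r, g, b)]


# Hypothetical filter: +1 on every red weight, -1 on every green weight, 0 on blue.
red_vs_green = constant_patch(1, -1, 0)

weights_per_filter = 3 * 3 * 3          # 27 weights in one conv1 filter
conv1_weights = 16 * weights_per_filter  # 432 weights across all 16 filters

red = filter_response(constant_patch(1, 0, 0), red_vs_green)
green = filter_response(constant_patch(0, 1, 0), red_vs_green)
gray = filter_response(constant_patch(0.5, 0.5, 0.5), red_vs_green)
print(red, green, gray)  # 9 -9 0.0
```

A filter like this ignores brightness (gray gives 0) and reacts only to the balance between red and green, which is the kind of color detector a trained first layer may contain.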

Let's compare before vs. after.

In [None]:
# Trained filters from our model

trained_filters = model_simple.conv1.weight.detach().cpu()   # [16, 3, 3, 3]

# Random filters from a fresh (untrained) model for comparison
random_filters = SimpleCNN().conv1.weight.detach().cpu()



def show_filter_row(filters, axes_row):

    """Display filters as tiny RGB images in a row of axes."""

    for idx, ax in enumerate(axes_row):

        if idx < filters.shape[0]:

            f = filters[idx].permute(1, 2, 0).numpy()        # [3,3,RGB]

            f = (f - f.min()) / (f.max() - f.min() + 1e-8)   # normalize to [0,1]

            ax.imshow(f, interpolation="nearest")

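        ax.set_xticks([]); ax.set_yticks([])   # hide ticks only; axis("off") would also hide the row labels set below
        for spine in ax.spines.values():
            spine.set_visible(False)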



fig, axes = plt.subplots(2, 16, figsize=(16, 2.5))

show_filter_row(random_filters, axes[0])

show_filter_row(trained_filters, axes[1])



axes[0][0].set_ylabel("Random\n(before)", fontsize=10, rotation=0, labelpad=50, va="center")

axes[1][0].set_ylabel("Learned\n(after)",  fontsize=10, rotation=0, labelpad=50, va="center")



plt.suptitle("conv1 filters: random noise → learned edge and color detectors",

             fontsize=12, fontweight="bold")

plt.tight_layout()

plt.show()



print("If training went well, conv1 now holds edge- and color-like detectors that were learned, not hand-programmed.")

In [None]:
# --- Hook activations at each convolutional layer ---

activations = {}



def get_hook(name):

    def hook_fn(module, input, output):

        activations[name] = output.detach().cpu()

    return hook_fn



hooks = [

    model_simple.bn1.register_forward_hook(get_hook("conv1 (16ch)")),

    model_simple.bn2.register_forward_hook(get_hook("conv2 (32ch)")),

    model_simple.bn3.register_forward_hook(get_hook("conv3 (64ch)")),

]



# Run our sample image through the trained SimpleCNN

sample_input = val_transform(img).unsqueeze(0).to(device)

model_simple.eval()

with torch.no_grad():

    _ = model_simple(sample_input)



for h in hooks:

    h.remove()



# --- Mean activation at each depth ---

fig, axes = plt.subplots(1, 4, figsize=(16, 4))

axes[0].imshow(img)

axes[0].set_title("Input image", fontsize=11)



for i, (name, act) in enumerate(activations.items()):

    # Apply ReLU (hooks captured pre-ReLU output from BN)

    mean_act = F.relu(act[0]).mean(dim=0)

    C, H, W = act.shape[1], act.shape[2], act.shape[3]

    axes[i + 1].imshow(mean_act, cmap="inferno")

    axes[i + 1].set_title(f"{name}\n{C} × {H}×{W}", fontsize=10)



for ax in axes:

    ax.axis("off")

plt.suptitle("Deeper layers → more abstract features, lower spatial resolution",

             fontsize=12, fontweight="bold")

plt.tight_layout()

plt.show()



# --- Individual feature maps from conv1 ---

conv1_act = F.relu(activations["conv1 (16ch)"][0])   # [16, H, W]

fig, axes = plt.subplots(2, 8, figsize=(16, 4))

for i, ax in enumerate(axes.flat):

    fmap = conv1_act[i]

    fmap = (fmap - fmap.min()) / (fmap.max() - fmap.min() + 1e-5)

    ax.imshow(fmap, cmap="viridis")

    ax.set_title(f"#{i}", fontsize=8)

    ax.axis("off")

plt.suptitle("All 16 feature maps from conv1 — each filter highlights a different pattern",

             fontsize=11, fontweight="bold")

plt.tight_layout()

plt.show()

## Does the Network Know *Where* Things Are?

Deeper layers have smaller spatial dimensions — the feature maps physically shrink. Does that mean the network loses track of where things are in the image?
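The shrinking can be traced with the standard output-size formula for convolution and pooling windows. The sketch below assumes a 128×128 input (the crop size produced by this notebook's transforms) and the 3×3 convolution plus 2×2 max-pool pattern of `SimpleCNN`; `conv_out` is an illustrative helper.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a conv/pool window along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


size = 128  # assumed input side length
trace = [("input", size)]
for name in ("conv1", "conv2", "conv3"):
    size = conv_out(size, kernel=3, stride=1, padding=1)  # 3x3 conv, padding 1
    trace.append((name, size))
    size = conv_out(size, kernel=2, stride=2)              # 2x2 max pool
    trace.append((name.replace("conv", "pool"), size))

for name, side in trace:
    print(f"{name:6s} {side}x{side}")
```

The padded convolutions keep the size unchanged; only the pooling steps halve it, giving 128 → 64 → 32 → 16.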

**Experiment:** black out a region of the input and see which activations change. If spatial information is preserved, the disruption should stay confined to the *corresponding region* of the feature maps, plus a small margin set by each layer's receptive field.
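How far a local change can spread is set by the receptive field. The following sketch estimates it for the path from the input to the `conv3` output, assuming the layer pattern of `SimpleCNN` and a patch covering input rows 30 to 89; the variable names are illustrative, and the mapping to feature-map rows is approximate.

```python
# (name, kernel, stride) for the path from the input to the conv3 output.
# Batch norm and ReLU act per pixel, so they do not change the receptive field.
layers = [("conv1", 3, 1), ("pool1", 2, 2), ("conv2", 3, 1),
          ("pool2", 2, 2), ("conv3", 3, 1)]

rf, jump = 1, 1  # receptive field size, and input pixels between neighboring outputs
for name, kernel, stride in layers:
    rf += (kernel - 1) * jump
    jump *= stride
    print(f"{name}: receptive field {rf}px, one output step = {jump}px of input")

# Input rows 30..89 land roughly on these rows of the conv3 feature map
# (floor for the start, ceiling for the end).
patch_rows = (30, 90)
map_rows = (patch_rows[0] // jump, -(-patch_rows[1] // jump))
print("approximate affected conv3 rows:", map_rows)
```

Under these assumptions each `conv3` unit sees an 18-pixel window of the input, so the disruption should cover roughly rows 7 to 23 of the feature map, widened by a couple of rows on each side.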

In [None]:
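# Create a perturbed copy with a blacked-out region.
# The input is normalized, so black (pixel value 0) is (0 - mean) / std per channel;
# writing a plain 0 would paint the patch with the dataset's mean color instead.
black = (-torch.tensor(MEAN) / torch.tensor(STD)).view(1, 3, 1, 1).to(device)
perturbed_input = sample_input.clone()
perturbed_input[:, :, 30:90, 40:110] = black   # black out a rectangular patch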



def collect_acts(x):

    """Run x through the model and return early + deep activations."""

    store = {}

    def hook_early(m, inp, out): store["early"] = out.detach().cpu()

    def hook_deep(m, inp, out):  store["deep"]  = out.detach().cpu()



    h1 = model_simple.bn1.register_forward_hook(hook_early)

    h2 = model_simple.bn3.register_forward_hook(hook_deep)

    model_simple.eval()

    with torch.no_grad():

        model_simple(x)

    h1.remove(); h2.remove()

    return store



acts_orig = collect_acts(sample_input)

acts_pert = collect_acts(perturbed_input)



# Absolute difference in activations (averaged across channels)

early_diff = (acts_orig["early"][0] - acts_pert["early"][0]).abs().mean(0)

deep_diff  = (acts_orig["deep"][0]  - acts_pert["deep"][0]).abs().mean(0)



# --- Visualize ---

mean_t = torch.tensor(MEAN).view(3, 1, 1)

std_t  = torch.tensor(STD).view(3, 1, 1)



fig, axes = plt.subplots(1, 4, figsize=(16, 4))



axes[0].imshow(img)

axes[0].set_title("Original")



pert_display = (perturbed_input[0].cpu() * std_t + mean_t).permute(1, 2, 0).clip(0, 1)

axes[1].imshow(pert_display)

axes[1].set_title("Perturbed (blacked out)")



axes[2].imshow(early_diff, cmap="hot")

axes[2].set_title(f"conv1 difference\n{list(early_diff.shape)}")



axes[3].imshow(deep_diff, cmap="hot")

axes[3].set_title(f"conv3 difference\n{list(deep_diff.shape)}")



for ax in axes:

    ax.axis("off")

plt.suptitle("Spatial information survives: the disruption appears in the right location, even in deep layers",

             fontsize=11, fontweight="bold")

plt.tight_layout()

plt.show()



print("Even at low resolution, deep layers preserve WHERE things are in the image.")

## Scaling Up: A Quick Look at Pretrained Models

Our SimpleCNN has 3 convolutional layers and learned from just 10 classes. Widely used pretrained models are **much deeper** (ResNet18, used below, has 18 weighted layers) and are trained on **about 1.3 million images** from ImageNet-1k (1,000 classes).

The same principles apply at scale — early layers detect edges, deeper layers detect objects. The features are just much richer.

Popular architectures include **ResNet**, **VGG**, and **GoogLeNet (InceptionV1)**; Segment 2 uses GoogLeNet.
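For ResNet18 specifically, the feature-map sizes at each stage can be worked out by hand. The sketch below assumes the standard torchvision layout (a 7×7 stride-2 stem, a 3×3 stride-2 max pool, then four stages where `layer2` to `layer4` each halve the resolution) and a 224×224 input; each stage is summarized by its first 3×3 convolution, since the remaining ones keep the size.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a conv/pool window along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


# (stage, kernel, stride, padding) of the layer that sets each stage's resolution.
stages = [
    ("conv1", 7, 2, 3),
    ("maxpool", 3, 2, 1),
    ("layer1", 3, 1, 1),
    ("layer2", 3, 2, 1),
    ("layer3", 3, 2, 1),
    ("layer4", 3, 2, 1),
]

size = 224  # assumed input side length
sizes = {}
for name, kernel, stride, padding in stages:
    size = conv_out(size, kernel, stride, padding)
    sizes[name] = size
    print(f"{name:8s} {size}x{size}")
```

Under these assumptions `layer1` works on a 56×56 grid and `layer4` on a 7×7 grid, so the deepest feature maps are far coarser than anything in SimpleCNN.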

Let's quickly load a pretrained ResNet18 and compare.

In [None]:
# --- Load pretrained ResNet18 ---

resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(device).eval()

resnet_preprocess = models.ResNet18_Weights.DEFAULT.transforms()



# --- Classify our sample image ---

resnet_input = resnet_preprocess(img).unsqueeze(0).to(device)

with torch.no_grad():

    logits = resnet(resnet_input)



# Load human-readable ImageNet labels

labels_url = "https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json"

imagenet_labels = json.loads(requests.get(labels_url).text)



probs = torch.softmax(logits, dim=1)[0]

top5 = probs.topk(5)



print("ResNet18 top-5 predictions:")

for i in range(5):

    idx = top5.indices[i].item()

    print(f"  {imagenet_labels[idx]:25s}  {top5.values[i]:.1%}")



# --- Feature maps at two depths ---

resnet_acts = {}



def resnet_hook(name):

    def fn(m, inp, out):

        resnet_acts[name] = out.detach().cpu()

    return fn



hooks = [

    resnet.layer1.register_forward_hook(resnet_hook("layer1 (early)")),

    resnet.layer4.register_forward_hook(resnet_hook("layer4 (deep)")),

]

with torch.no_grad():

    _ = resnet(resnet_input)

for h in hooks:

    h.remove()



fig, axes = plt.subplots(1, 3, figsize=(14, 4))

axes[0].imshow(img)

axes[0].set_title("Input image")



for i, (name, act) in enumerate(resnet_acts.items()):

    mean_act = act[0].mean(dim=0)

    C, H, W = act.shape[1], act.shape[2], act.shape[3]

    axes[i + 1].imshow(mean_act, cmap="inferno")

    axes[i + 1].set_title(f"ResNet18 {name}\n{C}ch × {H}×{W}")



for ax in axes:

    ax.axis("off")

plt.suptitle("Same principles at scale: early layers see edges, deep layers see objects",

             fontsize=12, fontweight="bold")

plt.tight_layout()

plt.show()



print("A pretrained model has much richer features — but the same hierarchy holds.")

## Summary

| Concept | Key Takeaway |
|---------|-------------|
| **Images as tensors** | A CNN sees a [3, H, W] grid of numbers, not a photograph |
| **Filter selectivity** | Each filter is a pattern template that responds most strongly to matching patterns |
| **Learned features** | Training discovers meaningful filters (edges, colors) automatically |
| **Feature hierarchy** | Early layers = edges, middle = textures, deep = whole objects |
| **Spatial information** | Coarse position is preserved even in deep, low-resolution feature maps |

---

### Next Up: Segment 2 — Activation Maximization

So far we've asked: *"What does each neuron do when it sees our image?"*

In Segment 2, we'll flip the question: **"What input image would make a given neuron fire the hardest?"**

That's the idea behind **activation maximization** — and we'll use the [Lucent](https://github.com/greentfrapp/lucent) library with **InceptionV1 (GoogLeNet)** to generate these visualizations for the first 10 neurons of the **Mixed4a** layer.
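As a toy preview of the idea, and not the Lucent workflow itself, the sketch below runs gradient ascent on the *input* of a single linear "neuron" with fixed weights. Everything here is an illustrative assumption: the weights are a flattened vertical-edge filter, the starting input is arbitrary, and the input is renormalized after every step so it cannot grow without bound.

```python
import math

# Hypothetical neuron: a linear unit whose weights are a flattened 3x3 vertical-edge filter.
weights = [-1.0, 0.0, 1.0, -1.0, 0.0, 1.0, -1.0, 0.0, 1.0]


def activation(x):
    """Response of the neuron to input x (a plain dot product)."""
    return sum(w * v for w, v in zip(weights, x))


def normalize(x):
    """Rescale x to unit length so the optimization cannot just inflate the input."""
    norm = math.sqrt(sum(v * v for v in x)) or 1.0
    return [v / norm for v in x]


x = normalize([0.3, -0.2, 0.1, 0.0, 0.4, -0.1, 0.2, 0.1, -0.3])  # arbitrary start
start_activation = activation(x)

step_size = 0.1
for _ in range(200):
    # For a linear unit, the gradient of the activation w.r.t. the input is the weight vector.
    x = normalize([v + step_size * w for v, w in zip(x, weights)])

final_activation = activation(x)
print(f"activation: {start_activation:.3f} -> {final_activation:.3f}")
print("optimized input:", [round(v, 2) for v in x])
```

The optimized input ends up looking like the filter itself: for this toy neuron, the image that excites it most is its own pattern. Real networks are nonlinear, so the gradient has to be computed by backpropagation, but the loop has the same shape.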

Stay tuned!