In [1]:
import random
from typing import Literal, cast

import torch as t
import torch.nn.functional as F
from tqdm import tqdm
from plotly import express as px, graph_objects as go

from src import Dataset, load_latest
from src.utils import seed

In [2]:
model, th = load_latest()
ds = Dataset.load()

In [8]:
# Disable model gradients
model.requires_grad_(False)
assert all(not p.requires_grad for p in model.parameters())

## Images

In [9]:
def show(inds: list[int] | int, category: Literal["train", "test"] = "train") -> None:
    if isinstance(inds, int):
        inds = [inds]
    assert category in ("train", "test"), f"{category=}"
    for i in inds:
        im = (ds.train_x if category == "train" else ds.test_x)[i]
        pred = model(im.reshape(-1, 1, *im.shape[-2:])).argmax().item()
        label = (ds.train_y if category == "train" else ds.test_y)[i].item()
        title = f"[{i}] {pred = } | {label = } | {category.upper()}"
        fig = px.imshow(im[0], title=title)
        fig.show()

In [10]:
show(0)

Find the ones on which the model mispredicts.

In [11]:
train_logits = model(ds.train_x)
train_preds = train_logits.argmax(-1)
train_mispreds: list[int] = [i for i, (pred, target) in enumerate(zip(train_preds, ds.train_y)) if pred != target]

print(f"Mispredicted {len(train_mispreds)} out of {len(ds.train_x)} training images. ({len(train_mispreds) / len(ds.train_x):.2%})")

Mispredicted 1815 out of 60000 training images. (3.02%)


In [12]:
seed(42)
inds = random.sample(train_mispreds, k=3)
show(inds)

## Max activation

In [13]:
# Start with random noise
seed(42)
im = t.randn(1, 1, 28, 28, requires_grad=True)
with t.no_grad():
    px.imshow(im.reshape(28, 28)).show()

In [24]:
def make_target_tensor(
    shape: tuple[int, ...], act_inds: tuple[int, ...], factor: float = 1000.0
) -> t.Tensor:
    assert len(shape) == len(act_inds), f"{shape}, {act_inds}"
    target = -factor * t.ones(shape, dtype=t.float32)
    target[act_inds] = factor
    return target

Visualize activations in `conv1`.

In [25]:
seed(42)
im = t.randn(1, 1, 28, 28, requires_grad=True)

lr = 1e-4
N = 1_0_000

act_inds = (0, 0, 20, 10)
target = make_target_tensor(model.conv1(im).shape, act_inds)

loss_fn = t.nn.CrossEntropyLoss()
optimizer = t.optim.AdamW([im], lr)

for i in tqdm((range(N))):
    optimizer.zero_grad()
    acts = model.conv1(im)
    loss = loss_fn(acts, target)
    loss.backward()
    optimizer.step()

with t.no_grad():
    fig = px.imshow(im.squeeze())
    fig.show()

100%|██████████| 10000/10000 [00:02<00:00, 3346.26it/s]


Visualize activations in `conv2`.

In [27]:
def forward(x: t.Tensor) -> t.Tensor:
    return model.conv2(model.pool1(F.relu(model.conv1(x))))

seed(42)
im = t.randn(1, 1, 28, 28, requires_grad=True)

lr = 1e-4
N = 1_0_000

act_inds = (0, 0, 10, 10)
target = make_target_tensor(forward(im).shape, act_inds)

loss_fn = t.nn.CrossEntropyLoss()
optimizer = t.optim.AdamW([im], lr)

for i in tqdm((range(N))):
    optimizer.zero_grad()
    acts = forward(im)
    loss = loss_fn(acts, target)
    loss.backward()
    optimizer.step()

with t.no_grad():
    fig = px.imshow(im.squeeze())
    fig.show()

100%|██████████| 10000/10000 [00:03<00:00, 2639.00it/s]


Visualize activations in `fc1`.

In [47]:
def forward(x: t.Tensor) -> t.Tensor:
    return model.fc1(model.pool2(F.relu(model.conv2(model.pool1(F.relu(model.conv1(x)))))).flatten(1))

seed(42)
im = t.randn(1, 1, 28, 28, requires_grad=True)

lr = 1e-3
N = 1_0_000

act_inds = (0, 0)
target = make_target_tensor(forward(im).shape, act_inds, factor=1e9)

loss_fn = t.nn.CrossEntropyLoss()
optimizer = t.optim.AdamW([im], lr)

for i in tqdm((range(N))):
    optimizer.zero_grad()
    acts = forward(im)
    loss = loss_fn(acts, target)
    loss.backward()
    optimizer.step()

with t.no_grad():
    fig = px.imshow(im.squeeze())
    fig.show()

100%|██████████| 10000/10000 [00:04<00:00, 2228.35it/s]


In [48]:
acts = forward(im)
acts[0, 0], acts.max()

(tensor(86.7340, grad_fn=<SelectBackward0>),
 tensor(220.5017, grad_fn=<MaxBackward1>))

Make the most 8-like image.

In [73]:
def forward(x: t.Tensor) -> t.Tensor:
    logits = model(x)
    probs = logits[0].softmax(0)
    return probs

seed(42)
im = t.randn(1, 1, 28, 28, requires_grad=True)

lr = 1e-4
N = 1_0_000

target_prob = 0.99
target = t.ones(10) * (1-target_prob) / 9
target[8] = target_prob
assert target.sum().isclose(t.tensor(1.0))

loss_fn = t.nn.CrossEntropyLoss()
optimizer = t.optim.AdamW([im], lr)

for i in tqdm((range(N))):
    optimizer.zero_grad()
    acts = forward(im)
    loss = loss_fn(acts, target)
    loss.backward()
    optimizer.step()

with t.no_grad():
    fig = px.imshow(im.squeeze())
    fig.show()

100%|██████████| 10000/10000 [00:04<00:00, 2044.32it/s]


In [64]:
def predict(x: t.Tensor) -> int:
    return model(x).flatten().softmax(0).argmax()

In [65]:
predict(im)

tensor(5)