In [158]:
import random
from typing import Literal, cast

import torch as t
import torch.nn.functional as F
from tqdm import tqdm
from plotly import express as px, graph_objects as go

from src import Dataset, load_latest

In [2]:
# Load the model (presumably the most recent saved checkpoint, per the helper's name) and the dataset.
# NOTE(review): `th` is the second value returned by load_latest(); its meaning is not visible here — TODO document.
model, th = load_latest()
ds = Dataset.load()

In [55]:
def seed(seed_val) -> None:
    """Seed both PyTorch's global RNG and Python's ``random`` module with one value.

    Call at the top of a cell to make that cell's random draws reproducible.
    """
    t.manual_seed(seed_val)
    random.seed(seed_val)

## Images

In [60]:
def show(inds: list[int] | int, category: Literal["train", "test"] = "train") -> None:
    """Plot dataset images with the model's prediction and the true label in the title.

    Args:
        inds: A single index or a list of indices into the chosen split.
        category: Which split to read from, ``"train"`` or ``"test"``.
    """
    if isinstance(inds, int):
        inds = [inds]
    assert category in ("train", "test"), f"{category=}"
    # Resolve the split once rather than twice per image.
    images, labels = (ds.train_x, ds.train_y) if category == "train" else (ds.test_x, ds.test_y)
    for i in inds:
        im = images[i]
        # Inference only: skip building an autograd graph for the forward pass.
        with t.no_grad():
            # Reshape to (N, 1, H, W) so the single image is fed as a batch.
            pred = model(im.reshape(-1, 1, *im.shape[-2:])).argmax().item()
        label = labels[i].item()
        title = f"[{i}] {pred = } | {label = } | {category.upper()}"
        # im[0] drops the leading (presumably channel) dim for plotting — TODO confirm im is (1, H, W).
        fig = px.imshow(im[0], title=title)
        fig.show()

In [62]:
show([0])  # first training image, with prediction and label

Find the training images that the model mispredicts.

In [63]:
# Inference only: without no_grad this builds an autograd graph over the entire training set.
with t.no_grad():
    train_logits = model(ds.train_x)
train_preds = train_logits.argmax(-1)
# Vectorised comparison. Reshape the labels to match the predictions so an (N, 1) label tensor
# cannot broadcast against (N,) predictions into an (N, N) result.
is_mispred = train_preds != ds.train_y.reshape(train_preds.shape)
train_mispreds: list[int] = is_mispred.nonzero().flatten().tolist()

print(f"Mispredicted {len(train_mispreds)} out of {len(ds.train_x)} training images. ({len(train_mispreds) / len(ds.train_x):.2%})")

Mispredicted 1815 out of 60000 training images. (3.02%)


In [65]:
# Show 3 randomly chosen mispredicted training images; seeding first makes the sample reproducible.
seed(42)
inds = random.sample(train_mispreds, k=3)
show(inds)

## Max activation

In [71]:
# Freeze all model parameters: the optimisation cells below only need gradients w.r.t. the
# input image. requires_grad_ on the module already applies to every parameter, so one call suffices.
model.requires_grad_(False)

In [114]:
# Baseline: the seeded Gaussian-noise image that the optimisation cells below also start from.
seed(42)
im = t.randn(1, 1, 28, 28, requires_grad=True)
px.imshow(im.detach().squeeze()).show()

In [97]:
# Names of the model's parameters, in the order named_parameters() yields them.
PARAM_NAMES = list(dict(model.named_parameters()).keys())
print(f"{PARAM_NAMES = }")

PARAM_NAMES = ['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']


Visualize a `conv1` feature: optimize the input image to maximize a single `conv1` activation.

In [189]:
# Optimise the input image to maximise one conv1 (pre-ReLU) activation.
seed(42)
im = t.randn(1, 1, 28, 28, requires_grad=True)

lr = 1e-3
N = 100_000
# Index into the conv1 output, presumably (batch, channel, row, col) — TODO confirm.
act_inds = (0, 0, 20, 10)
optimizer = t.optim.AdamW([im], lr)

for i in tqdm(range(N)):
    # backward() accumulates into im.grad, so clear it before every step.
    optimizer.zero_grad()
    acts_post_conv1 = model.conv1(im)
    act = acts_post_conv1[act_inds]
    # Optimisers minimise, so minimise the negated activation to maximise it.
    # (A CrossEntropyLoss over a single logit is constant — log_softmax of one value is 0 —
    # and therefore provides no gradient.)
    loss = -act
    loss.backward()
    optimizer.step()

with t.no_grad():
    fig = px.imshow(im.squeeze())
    fig.show()

100%|██████████| 100000/100000 [00:24<00:00, 4108.77it/s]


Visualize a `conv2` feature: optimize the input image to maximize a single `conv2` activation.

In [174]:
# Plain gradient ascent on the input image to maximise one conv2 (pre-ReLU) activation.
# Each step rebinds `im` to a fresh leaf tensor, so im.grad never accumulates across steps.
seed(42)
im = t.randn(1, 1, 28, 28, requires_grad=True)

lr, N = 1e-3, 100_000
# Index into the conv2 output, presumably (batch, channel, row, col) — TODO confirm.
act_inds = (0, 0, 5, 0)
for i in tqdm(range(N)):
    im.requires_grad_(True)
    # Assumes the model applies conv1 -> ReLU -> pool1 before conv2 — TODO confirm against its forward().
    acts_post_conv1 = model.pool1(F.relu(model.conv1(im)))
    acts_post_conv2 = model.conv2(acts_post_conv1)
    act = acts_post_conv2[act_inds]
    act.backward()
    step = lr * cast(t.Tensor, im.grad)
    im = (im + step).detach()
# Activation of the image as it was before the final update step.
print(act)
fig = px.imshow(im.squeeze())
fig.show()

100%|██████████| 100000/100000 [00:17<00:00, 5810.21it/s]


tensor(14.1829, grad_fn=<SelectBackward0>)


Visualize an `fc1` unit: optimize the input image to maximize a single `fc1` activation.

In [175]:
# Plain gradient ascent on the input image to maximise one fc1 unit.
# Optimise the pre-ReLU value: if the unit starts inactive on the noise image, ReLU's gradient
# is 0 and ascent on the post-ReLU value would never move the image. ReLU is monotone, so
# raising the pre-activation also raises the post-ReLU activation.
seed(42)
im = t.randn(1, 1, 28, 28, requires_grad=True)

lr = 1e-4
N = 100_000
# Index into the fc1 output, presumably (batch, unit) — TODO confirm.
act_inds = (0, 0)
for i in tqdm(range(N)):
    im.requires_grad_(True)
    acts_post_conv1 = model.pool1(F.relu(model.conv1(im)))
    acts_post_conv2 = model.pool2(F.relu(model.conv2(acts_post_conv1)))
    acts_pre_relu_fc1 = model.fc1(acts_post_conv2.flatten(1))
    act = acts_pre_relu_fc1[act_inds]
    act.backward()
    # Rebinding creates a fresh leaf each step, so gradients never accumulate.
    with t.no_grad():
        im = im + lr * cast(t.Tensor, im.grad)

with t.no_grad():
    fig = px.imshow(im.squeeze())
    fig.show()

100%|██████████| 100000/100000 [00:22<00:00, 4406.31it/s]


Make the most 8-like image: optimize the input so the model classifies it as an 8 as confidently as possible.

In [183]:
# Optimise an input image so the model classifies it as an 8 as confidently as possible.
seed(42)
im = t.randn(1, 1, 28, 28, requires_grad=True)
# Class-index target; for CrossEntropyLoss (no label smoothing) this is equivalent to a one-hot target.
target = t.tensor([8])
lr = 1e-4
loss_fn = t.nn.CrossEntropyLoss()
optimizer = t.optim.AdamW([im], lr)

N = 100_000

for i in tqdm(range(N)):
    # backward() accumulates into im.grad, so clear it before every step.
    optimizer.zero_grad()
    logits = model(im)
    loss = loss_fn(logits, target)
    loss.backward()
    optimizer.step()

with t.no_grad():
    fig = px.imshow(im.squeeze())
    fig.show()

100%|██████████| 100000/100000 [00:36<00:00, 2723.71it/s]
