In [None]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from model import WideModel
from tqdm.auto import tqdm

In [None]:
# Load the MNIST train and test splits (downloaded into ./data if missing).
# ToTensor converts each image to a tensor; the Normalize step uses mean 0
# and std 1, so it leaves the tensor values unchanged.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0,), (1,))]
)

# Same root, download flag and transform for both splits; only `train` differs.
mnist_trainset, mnist_testset = (
    datasets.MNIST(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)


In [None]:
# Wrap both splits in shuffled loaders that yield one sample per batch.
# Batch size must be 1.
batch_size = 1
train_dataloader, test_dataloader = (
    DataLoader(split, batch_size=batch_size, shuffle=True)
    for split in (mnist_trainset, mnist_testset)
)

In [None]:
# Pick the compute device: GPU when CUDA is available, CPU otherwise.
dev = "cpu"
if torch.cuda.is_available():
    dev = "cuda"

# Instantiate the network and place it on the chosen device.
model = WideModel(hidden_dim_scale=20).to(dev)

# Plain SGD over every parameter of the model.
lr = 1e-3
optimizer = torch.optim.SGD(list(model.parameters()), lr=lr)

In [None]:
# Get linearized models:
# f(x, w) is reduced to A w + B around the current parameters w, with a
# different (A, B) pair for every training input x.
# Results are collected as NumPy arrays in As, Bs and ys (one entry per sample).

As = []
Bs = []
ys = []

# The parameters are not updated anywhere in this loop, so the flattened
# parameter vector is computed once instead of once per sample.
# NOTE(review): assumes flatten_gradient does not modify the parameters -- confirm.
flat_params = model.flatten_parameters()

for x, y in tqdm(train_dataloader):
    x = x.to(dev)

    # A = gradient matrix of the logits (presumably w.r.t. the flattened
    # parameters -- confirm against WideModel.flatten_gradient).
    A = model.flatten_gradient(x)

    # B = f(x, w) - A w
    # Only the values are stored, so no autograd graph is needed here.
    # model(x) is used instead of model.forward(x) so that the module's
    # call machinery (e.g. registered hooks) runs.
    with torch.no_grad():
        B = model(x) - A @ flat_params

    As.append(A.detach().cpu().numpy())
    Bs.append(B.detach().cpu().numpy())
    ys.append(y.detach().cpu().numpy())
    