In [9]:
import torch as t
from torch import nn
import numpy as np
from tqdm.notebook import tqdm
import torchopt
import functorch
from functools import partial
from torch.utils.data import DataLoader, Dataset, RandomSampler


import os, sys
HOME = os.environ['HOME']  # change if necessary
sys.path.append(f'{HOME}/Finite-groups/src')
from model import MLP3, MLP4, InstancedModule
from utils import *
from group_data import *
from model_utils import *
from train import Parameters
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Quadratic toy example

In [57]:
class QuadraticModel(nn.Module):
    def __init__(self, A):
        super().__init__()
        self.A = A
        self.x = nn.Parameter(t.randn(A.shape[0]))

    def forward(self):
        return (self.x.T @ self.A @ self.x).squeeze()

In [71]:
ADAM_CFG = {
    'lr': 0.01,
}
B = t.randn(5, 5)
model = QuadraticModel(B.T @ B + 0.01 * t.eye(5))
opt = torchopt.FuncOptimizer(torchopt.adam(**ADAM_CFG))
model_f, init_params = functorch.make_functional(model)

def train(init_params, iters):
    params = init_params
    for _ in tqdm(range(iters)):
        loss = model_f(params)
        params = opt.step(loss, params)
    return params

  model_f, init_params = functorch.make_functional(model)


In [73]:
t.func.jacrev(partial(train, iters=10))(init_params)

  0%|          | 0/10 [00:00<?, ?it/s]

((tensor([[ 0.8728, -0.0568, -0.0088,  0.0194,  0.0596],
          [-0.0096,  0.9777, -0.0029,  0.0046,  0.0027],
          [-0.0029, -0.0059,  0.9853,  0.0128,  0.0014],
          [ 0.0054,  0.0075,  0.0104,  0.9740, -0.0010],
          [ 0.2598,  0.0741,  0.0190, -0.0174,  0.8710]],
         grad_fn=<ViewBackward0>),),)

### Groups

In [None]:
PARAMS = Parameters(
    instances=1,
    embed_dim=32,
    hidden_size=32,
    group_string='S(4)',
    model='MLP2',
    unembed_bias=True,
    weight_decay=2e-5,
    train_frac=0.4,
)

t.manual_seed(PARAMS.seed)
np.random.seed(PARAMS.seed)
random.seed(PARAMS.seed)
group_dataset = GroupData(params=PARAMS)
model = MODEL_DICT[PARAMS.model](params=PARAMS).to(device)

batch_size = len(group_dataset)
# sampler = RandomSampler(group_dataset, replacement=PARAMS.replacement)
train_loader = DataLoader(
    dataset=group_dataset,
    batch_size=batch_size,
    # shuffle=True,
    drop_last=True,
    # sampler=sampler,
)

# TODO: bias params should not get weight decay (to match with train.py)
# But probably doesn't matter much
opt = torchopt.FuncOptimizer(
    torchopt.adam(
        weight_decay=PARAMS.weight_decay,
        lr=PARAMS.lr,
        betas=[PARAMS.beta1, PARAMS.beta2],
    )
)
model_f, init_params = functorch.make_functional(model)
param_shapes = [p.shape for p in init_params]

def flatten(params):
    return t.cat([p.flatten() for p in params])

def unflatten(flat_params, shapes):
    params = []
    i = 0
    for shape in shapes:
        size = np.prod(shape)
        params.append(flat_params[i:i+size].reshape(shape))
        i += size
    return params

def train(flat_init_params, epochs):
    params = unflatten(flat_init_params)
    for epoch in tqdm(range(epochs)):
        for x, z in train_loader:
            x = x.to(device)
            z = z.to(device)
            output = model_f(params, x)
            loss = get_cross_entropy(output, z)
            params = opt.step(loss, params)
    return params


Intersection size: 576/576 (1.00)
Added 576 elements from intersection
Added 0 elements from group 0: S(4)
Taking random subset: 230/576 (0.40)
Train set size: 230/576 (0.40)


  model_f, init_params = functorch.make_functional(model)


In [26]:
train_jac = t.func.jacrev(partial(train, epochs=5))(init_params)
train_jac.shape

100%|██████████| 5/5 [00:00<00:00, 109.46it/s]


AttributeError: 'tuple' object has no attribute 'shape'

In [32]:
[t.shape for t in init_params]

[torch.Size([1, 24, 32]),
 torch.Size([1, 24, 32]),
 torch.Size([1, 32, 32]),
 torch.Size([1, 32, 32]),
 torch.Size([1, 32, 24]),
 torch.Size([1, 24])]

In [17]:
params = train(init_params, epochs=1000)
x, z = next(iter(train_loader))
x = x.to(device)
z = z.to(device)
output = model_f(params, x)
get_cross_entropy(output, z)

100%|██████████| 1000/1000 [00:07<00:00, 125.18it/s]


tensor([0.0008], grad_fn=<MeanBackward1>)

tensor([0.0008], grad_fn=<MeanBackward1>)