In [1]:
import torch as t
from torch import nn
import numpy as np
from tqdm.notebook import tqdm
# import torchopt
import functorch
from functools import partial
from torch.utils.data import DataLoader, Dataset, RandomSampler
from jaxtyping import Float, Int
from typing import Tuple, List, Dict, Any


import os, sys
# HOME = os.environ['HOME']  # change if necessary
HOME = '/workspace/wilson'
sys.path.append(f'{HOME}/Finite-groups/src')
from model import MLP3, MLP4, InstancedModule
from utils import *
from group_data import *
from model_utils import *
from train import Parameters
%load_ext autoreload
%autoreload 2

In [2]:
def flatten(
    params: Dict[str, t.Tensor]
) -> Tuple[t.Tensor, Dict[str, t.Size]]:
    """Pack a named collection of tensors into one 1-D tensor.

    Args:
        params: Mapping from name to tensor. The mapping's iteration order
            determines the order of the segments in the flat vector.

    Returns:
        ``(flat, shapes)`` where ``flat`` is the concatenation of every
        tensor's flattened entries and ``shapes`` maps each name to that
        tensor's original size, in the same order.
    """
    shapes: Dict[str, t.Size] = {}
    pieces = []
    for name, tensor in params.items():
        shapes[name] = tensor.size()
        pieces.append(tensor.flatten())
    return t.cat(pieces), shapes

def unflatten(
    flat_params: t.Tensor,
    shapes: Dict[str, t.Size],
) -> Dict[str, t.Tensor]:
    """Split a flat parameter vector back into named, shaped tensors.

    Args:
        flat_params: 1-D tensor holding the parameters back to back, in the
            iteration order of ``shapes``.
        shapes: Mapping from name to the shape that segment should take.

    Returns:
        Dict mapping each name to the corresponding slice of ``flat_params``
        reshaped to its shape. Entries beyond the total size described by
        ``shapes`` are ignored.
    """
    params = dict()
    offset = 0
    for name, shape in shapes.items():
        # np.prod of an empty shape is the *float* 1.0, which is not a valid
        # slice bound; force an integer count so 0-dim parameters work too.
        size = int(np.prod(shape, dtype=np.int64))
        params[name] = flat_params[offset:offset + size].reshape(shape)
        offset += size
    return params

In [3]:
def train_adam(flat_params: t.Tensor, loss_func: Callable, epochs: int, lr=0.01, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) -> t.Tensor:
    """Run ``epochs`` full-batch Adam steps written as pure tensor arithmetic.

    No optimizer object or in-place parameter update is involved: each step
    builds a new ``flat_params`` tensor from the previous one.

    Args:
        flat_params: Starting parameter vector.
        loss_func: Maps a parameter vector to a scalar loss; differentiated
            with ``t.func.grad`` at every step.
        epochs: Number of update steps.
        lr: Step size.
        betas: Decay rates for the first- and second-moment running averages.
        eps: Constant added to the denominator of the update.
        weight_decay: Coefficient of the ``weight_decay * flat_params`` term
            added to the gradient before the moment updates.

    Returns:
        The parameter vector after the final step.
    """
    beta1, beta2 = betas[0], betas[1]
    grad_of_loss = t.func.grad(loss_func)
    first_moment = 0
    second_moment = 0
    for step in tqdm(range(epochs)):
        grad = grad_of_loss(flat_params)
        grad = grad + weight_decay * flat_params
        first_moment = beta1 * first_moment + (1 - beta1) * grad
        second_moment = beta2 * second_moment + (1 - beta2) * grad**2
        # Bias correction for the zero-initialised moment estimates.
        first_hat = first_moment / (1 - beta1**(step + 1))
        second_hat = second_moment / (1 - beta2**(step + 1))
        update = lr * first_hat / (t.sqrt(second_hat) + eps)
        flat_params = flat_params - update
    return flat_params

### Quadratic toy example

In [4]:
class QuadraticModel(nn.Module):
    """Toy model whose forward pass evaluates the quadratic form ``x @ A @ x``.

    Args:
        A: Square matrix defining the quadratic form. Its first dimension
            sets the length of the learnable vector ``x``.

    Attributes:
        A: The matrix, held as a non-persistent buffer so it follows the
            module through ``.to(...)`` without appearing in ``state_dict``.
        x: Learnable 1-D parameter, initialised from a standard normal.
    """

    def __init__(self, A):
        super().__init__()
        # A buffer (rather than a plain attribute) is moved by Module.to();
        # persistent=False keeps state_dict limited to the parameter `x`.
        self.register_buffer("A", A, persistent=False)
        self.x = nn.Parameter(t.randn(A.shape[0]))

    def forward(self):
        """Return the scalar ``x @ A @ x``."""
        # `x` is 1-D, so no transpose is needed; `.T` on a non-2-D tensor is
        # deprecated and emits a UserWarning.
        return (self.x @ self.A @ self.x).squeeze()


# Toy problem: with A = I the loss is the squared norm of the parameter vector.
# (Random SPD alternative: B = t.randn(5, 5); A = B.T @ B + 0.01 * t.eye(5).)
model = QuadraticModel(t.eye(1000).to(device)).to(device)
flat_init_params, shapes = flatten(dict(model.named_parameters()))


def loss_func(flat_params, model=model, shapes=shapes):
    """Evaluate ``model`` with its parameters replaced by ``flat_params``.

    ``model`` and ``shapes`` are bound as defaults at definition time, so the
    function keeps using this cell's objects even if those names are rebound.
    """
    named_params = unflatten(flat_params, shapes)
    return t.func.functional_call(model, named_params, ())

In [7]:
# Gradient of the scalar loss with respect to the flat parameter vector (reverse mode).
t.func.grad(loss_func)(flat_init_params)

  return (self.x.T @ self.A @ self.x).squeeze()


tensor([-0.2592, -4.1862,  4.2671,  1.9261, -1.8756],
       grad_fn=<SliceBackwardBackward0>)

In [8]:
# Same derivative computed with forward-mode autodiff; for a scalar loss it should match t.func.grad.
t.func.jacfwd(loss_func)(flat_init_params)

tensor([-0.2592, -4.1862,  4.2671,  1.9261, -1.8756], grad_fn=<ViewBackward0>)

In [30]:
# Run 1000 Adam steps from the initial parameters, using train_adam's default hyperparameters.
train_adam(flat_init_params, loss_func, 1000)

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [00:00<00:00, 1586.72it/s]


tensor([-1.1915e-15,  3.9145e-08, -3.4294e-05,  9.1840e-17,  1.8637e-24],
       grad_fn=<SubBackward0>)

In [38]:
# Jacobian of the parameters after 10 Adam steps with respect to the initial parameters (forward mode).
t.func.jacfwd(partial(train_adam, loss_func=loss_func, epochs=10))(flat_init_params)

100%|██████████| 10/10 [00:00<00:00, 173.87it/s]


tensor([[0.9998, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.9999, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.9999, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.9998, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.9553]], grad_fn=<ViewBackward0>)

In [5]:
# Profile CUDA allocations while forward-differentiating 1000 Adam steps on the toy model.
t.cuda.memory._record_memory_history(max_entries=100000)
try:
    # flat_init_params is built from nn.Parameters and therefore requires grad.
    # Left attached, every op inside train_adam is also recorded on an autograd
    # graph that is never backpropagated, which keeps each step's intermediates
    # alive. Forward mode does not need that graph, so differentiate a detached copy.
    train_jac = t.func.jacfwd(partial(train_adam, loss_func=loss_func, epochs=1000))(flat_init_params.detach())
finally:
    # Always write the snapshot and stop recording, including when the
    # computation fails (an OOM is exactly when the snapshot is most useful).
    t.cuda.memory._dump_snapshot('/workspace/memory_snapshot_toy.pkl')
    t.cuda.memory._record_memory_history(enabled=None)

  return (self.x.T @ self.A @ self.x).squeeze()
100%|██████████| 1000/1000 [00:16<00:00, 61.98it/s]


### Groups

In [4]:
# Experiment configuration for the group task. Fields not listed here keep
# whatever defaults `Parameters` defines (including `seed`, read below).
# NOTE(review): `weight_decay` is set here, but the train_adam calls in this
# notebook do not pass it, so they run with train_adam's default of 0 — confirm
# that is intended.
PARAMS = Parameters(
    instances=1,
    embed_dim=32,
    hidden_size=32,
    group_string='S(4)',
    model='MLP2',
    unembed_bias=True,
    weight_decay=2e-5,
    train_frac=0.4,
)

# Seed every RNG before building the dataset and the model.
# NOTE(review): `random`, `device`, `GroupData` and `MODEL_DICT` are not imported
# explicitly in the visible import cell; presumably they arrive via the
# `from ... import *` lines — confirm.
t.manual_seed(PARAMS.seed)
np.random.seed(PARAMS.seed)
random.seed(PARAMS.seed)
group_dataset = GroupData(params=PARAMS)
# The model class is looked up by name from PARAMS.model.
model = MODEL_DICT[PARAMS.model](params=PARAMS).to(device)
# Training split as a single tensor on `device`.
dataset = t.tensor(group_dataset.train_data, device=device)

# Flat view of the freshly initialised parameters plus the shapes needed to undo it.
# This rebinds `model`, `flat_init_params` and `shapes` from the toy section above.
flat_init_params, shapes = flatten(dict(model.named_parameters()))
def loss_func(flat_params, model=model, shapes=shapes, dataset=dataset):
    """Loss of ``model`` on ``dataset`` when run with parameters ``flat_params``.

    The last column of ``dataset`` is passed to ``get_cross_entropy`` as the
    target; all preceding columns are passed to the model as its input.
    ``model``, ``shapes`` and ``dataset`` are bound as defaults at definition
    time.

    Args:
        flat_params: Flat parameter vector laid out according to ``shapes``.

    Returns:
        The squeezed result of ``get_cross_entropy`` on the model output.
    """
    inputs = dataset[:, :-1]
    targets = dataset[:, -1]
    named_params = unflatten(flat_params, shapes)
    model_out = t.func.functional_call(model, named_params, inputs)
    return get_cross_entropy(model_out, targets).squeeze()

Intersection size: 576/576 (1.00)
Added 576 elements from intersection
Added 0 elements from group 0: S(4)
Taking random subset: 230/576 (0.40)
Train set size: 230/576 (0.40)


In [None]:
# Short sanity run: 10 Adam steps, then the loss at the resulting parameters.
# NOTE(review): the saved progress bar below reads 0/1000, so the stored output
# appears to come from an earlier 1000-epoch version of this cell — re-run to refresh.
flat_params = train_adam(flat_init_params, loss_func, 10)
loss_func(flat_params)

  0%|          | 0/1000 [00:00<?, ?it/s]

 35%|███▌      | 354/1000 [00:05<00:08, 78.84it/s]

In [5]:
# Jacobian of the parameters after 1000 Adam steps with respect to the initial parameters.
# flat_init_params is a concatenation of nn.Parameters and therefore requires grad.
# Left attached, every op inside train_adam is also recorded on an autograd graph
# that is never backpropagated, so each step's intermediates (including the
# batched forward-mode tangents) stay alive and memory grows with the step count.
# Forward mode does not need that graph, so differentiate a detached copy.
# t.cuda.memory._record_memory_history(max_entries=100000)
train_jac = t.func.jacfwd(t.compile(partial(train_adam, loss_func=loss_func, epochs=1000)))(flat_init_params.detach())
# t.cuda.memory._dump_snapshot('/workspace/memory_snapshot.pkl')
# t.cuda.memory._record_memory_history(enabled=None)

  4%|▎         | 36/1000 [00:02<00:54, 17.85it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 94.00 MiB. GPU 0 has a total capacity of 47.50 GiB of which 80.31 MiB is free. Process 1374752 has 47.41 GiB memory in use. Of the allocated memory 41.26 GiB is allocated by PyTorch, and 5.66 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [12]:
# Allocator statistics for the current CUDA device; print() keeps the table's line breaks intact.
print(t.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  34271 MiB |  34662 MiB | 218501 MiB | 184230 MiB |
|       from large pool |  34259 MiB |  34649 MiB | 218436 MiB | 184176 MiB |
|       from small pool |     12 MiB |     12 MiB |     65 MiB |     53 MiB |
|---------------------------------------------------------------------------|
| Active memory         |  34271 MiB |  34662 MiB | 218501 MiB | 184230 MiB |
|       from large pool |  34259 MiB |  34649 MiB | 218436 MiB | 184176 MiB |
|       from small pool |     12 MiB |     12 MiB |     65 MiB |     53 MiB |
|---------------------------------------------------------------

In [11]:
import gc

# Print the type and size of every tensor the garbage collector still tracks,
# to see which large tensors are being kept alive.
for obj in gc.get_objects():
    try:
        if t.is_tensor(obj) or (hasattr(obj, 'data') and t.is_tensor(obj.data)):
            print(type(obj), obj.size())
    except Exception:
        # Some tracked objects raise on attribute access; skip them. Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit through.
        continue

<class 'torch.Tensor'> torch.Size([230, 3])
<class 'torch.Tensor'> torch.Size([4376])
<class 'torch.nn.parameter.Parameter'> torch.Size([1, 24, 32])
<class 'torch.nn.parameter.Parameter'> torch.Size([1, 24, 32])
<class 'torch.nn.parameter.Parameter'> torch.Size([1, 32, 32])
<class 'torch.nn.parameter.Parameter'> torch.Size([1, 32, 32])
<class 'torch.nn.parameter.Parameter'> torch.Size([1, 32, 24])
<class 'torch.nn.parameter.Parameter'> torch.Size([1, 24])
<class 'torch.Tensor'> torch.Size([24, 24])
<class 'torch.Tensor'> torch.Size([4376, 4376])
<class 'torch.Tensor'> torch.Size([4376, 4376])
<class 'torch.Tensor'> torch.Size([4376, 4376])
