In [1]:
import os, sys
import toml
import argparse
from munch import Munch, munchify

PROJ_DIR = os.path.expanduser("~/GitWS/Transmisstion-Phase")
DATA_DIR = os.path.join(PROJ_DIR, "data")
SRC_DIR = os.path.join(PROJ_DIR, "src")
LOGS_DIR = os.path.join(PROJ_DIR, "logs", "exp1")
SCRIPTS_DIR = os.path.join(PROJ_DIR, "scripts")
CHECKPOINTS_DIR = os.path.join(PROJ_DIR, "checkpoints")
RESULTS_DIR = os.path.join(PROJ_DIR, "results")

sys.path.append(PROJ_DIR)

In [2]:
import torch
import numpy as np
np.set_printoptions(formatter={'float': lambda x: "{0:0.2f}".format(x)})
from torchvision import datasets, transforms
from functorch import make_functional, vmap, vjp, jvp, jacrev
from models import LeNet, ResNet18, ResNet50
DATASET = 'MNIST'
INDEX = 100
MODEL_NAME = 'LeNet'
PT_FILE = os.path.join(CHECKPOINTS_DIR, DATASET, f'{MODEL_NAME}_best.pt')

In [3]:
dataset = eval('datasets.'+DATASET)(DATA_DIR, train=True, download=True, 
                transform=transforms.ToTensor()
              )

In [4]:
model = eval(MODEL_NAME)().to('cuda')
model.eval()
model.load_state_dict(torch.load(PT_FILE))

<All keys matched successfully>

In [5]:
fnet, params = make_functional(model)
def fnet_single(params, x):
    return fnet(params, x.unsqueeze(0)).squeeze(0)

In [6]:
jac_list = []
for i in range(50):
    x, y = dataset[i]
    x = x.to('cuda')
    jac = vmap(jacrev(fnet_single), (None, 0))(params, x.unsqueeze(0))
    jac = [j.flatten(2) for j in jac]
    jac = [torch.squeeze(j, dim=0) for j in jac]
    jac = [torch.flatten(j).detach() for j in jac]
    jac_list.append(torch.concat(jac).detach())

In [None]:
grad_matrix = torch.stack(jac_list)
del jac_list
del jac

In [9]:
grad_matrix.shape

torch.Size([50, 11998820])

In [10]:
u, s, vh = torch.svd(grad_matrix)

In [11]:
u.shape

torch.Size([50, 50])

In [12]:
s

tensor([103.0746,  68.9476,  64.4444,  62.3768,  55.6081,  52.9760,  48.6101,
         45.1375,  41.4832,  35.7700,  34.3765,  31.5545,  30.4044,  29.7460,
         29.4365,  28.5989,  28.1778,  27.5060,  25.8869,  25.2847,  25.1297,
         24.7309,  24.3326,  23.8508,  23.4553,  22.8916,  22.6801,  22.0632,
         21.7719,  21.4858,  21.0527,  20.5111,  20.2335,  19.9890,  19.6001,
         19.4965,  18.8566,  18.3693,  17.6164,  17.4597,  16.6905,  16.4218,
         15.9653,  15.4291,  14.9738,  14.5213,  14.0964,  14.0354,  12.2378,
          8.3188], device='cuda:0')

In [16]:
vh.shape

torch.Size([11998820, 6])

In [6]:
from hessian_eigenthings import compute_hessian_eigenthings

dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=False, num_workers=1)
loss = torch.nn.functional.cross_entropy
num_eigenthings = 20

eigenvals, eigenvecs = compute_hessian_eigenthings(model, dataloader,
                                                   loss, num_eigenthings)



In [8]:
eigenvecs.shape

(20, 1199882)