<a href="https://colab.research.google.com/github/Lorxus/SERI-MATS-Summer-2023/blob/main/colab_mnist_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import ExponentialLR

import numpy as np
import matplotlib.pyplot as plt
import scipy as sp

import jax
#!git clone https://github.com/google/spectral-density.git ./spectral-density/

import sys
sys.path.append("/content/spectral-density/jax")

!pip install --upgrade jax
import hessian_computation


In [None]:
# Select the compute backend in order of preference: CUDA GPU, then Apple MPS,
# and fall back to the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using {device} device")

#from torch.nn.modules.activation import Softmax
class NeuralNetwork(nn.Module):
    """Fully connected MNIST classifier: 784 -> 512 -> 512 -> 10 with ReLU between layers.

    Submodules are registered as `flatten` and `linear_relu_stack` (Sequential indices
    0-4), which is what `print(model)` and the `state_dict` keys show.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        # Layers are created in the same order as before, so the random initialisation
        # consumes the RNG identically.
        layers = [
            nn.Linear(28 * 28, 512), nn.ReLU(),
            nn.Linear(512, 512), nn.ReLU(),
            nn.Linear(512, 10),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten `x` from dim 1 onward and return the raw (unnormalised) class logits."""
        return self.linear_relu_stack(self.flatten(x))

# Build the network and move its parameters to the selected backend.
model = NeuralNetwork()
model.to(device)  # nn.Module.to moves the module in place (and returns it)
print(model)

Using cpu device
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)


In [None]:
my_model = model
download_loc = './data'  # single source of truth for the dataset directory
transform = transforms.ToTensor()

# Download training data from open datasets.
training_data = datasets.MNIST(
    root=download_loc,
    train=True,
    download=True,
    transform=transform,
)

# Download test data from open datasets.
test_data = datasets.MNIST(
    root=download_loc,
    train=False,
    download=True,
    transform=transform,
)

batch_size = 64

# Create data loaders. The training set is reshuffled every epoch so mini-batches
# are not presented in one fixed order; the test set is only evaluated, so its
# order is kept deterministic.
train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

# Peek at a single test batch to show tensor shapes.
X, y = next(iter(test_dataloader))
print(f"Shape of X [N, C, H, W]: {X.shape}")
print(f"Shape of y: {y.shape} {y.dtype}")

my_optim = optim.Adam(params=my_model.parameters(), lr=1e-3)
my_loss = nn.CrossEntropyLoss()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 158018864.05it/s]

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 13957333.08it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 45563434.02it/s]

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1736127.66it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64


In [None]:
def train(dataloader, model, loss_fn, optimizer):
    """Run one epoch of training over `dataloader`, updating `model` in place.

    Each mini-batch does forward pass -> loss -> backward -> optimizer step, then
    clears the gradients. Every 100th batch, the batch loss and the number of samples
    processed so far are printed. Inputs and targets are moved to the global `device`.
    """
    size = len(dataloader.dataset)
    model.train()
    for batch_idx, batch_data in enumerate(dataloader):
        inputs, targets = (tensor.to(device) for tensor in batch_data)

        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch_idx % 100:
            continue
        loss, current = loss.item(), (batch_idx + 1) * len(inputs)
        print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [None]:
# Evaluate before each epoch's training pass, so "Epoch 1" reports the untrained
# baseline; one final evaluation after the loop reports the fully trained model.
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}", "-------------------------------", sep="\n")
    test(test_dataloader, my_model, my_loss)
    train(train_dataloader, my_model, my_loss, my_optim)
print("Done!")

test(test_dataloader, my_model, my_loss)

Epoch 1
-------------------------------
Test Error: 
 Accuracy: 12.6%, Avg loss: 2.301974 

loss: 2.295896  [   64/60000]
loss: 0.269054  [ 6464/60000]
loss: 0.201609  [12864/60000]
loss: 0.247943  [19264/60000]
loss: 0.139091  [25664/60000]
loss: 0.345733  [32064/60000]
loss: 0.146595  [38464/60000]
loss: 0.296957  [44864/60000]
loss: 0.365678  [51264/60000]
loss: 0.154878  [57664/60000]
Epoch 2
-------------------------------
Test Error: 
 Accuracy: 95.6%, Avg loss: 0.140561 

loss: 0.066634  [   64/60000]
loss: 0.099852  [ 6464/60000]
loss: 0.097092  [12864/60000]
loss: 0.064814  [19264/60000]
loss: 0.051532  [25664/60000]
loss: 0.130961  [32064/60000]
loss: 0.063286  [38464/60000]
loss: 0.146073  [44864/60000]
loss: 0.133749  [51264/60000]
loss: 0.136715  [57664/60000]
Epoch 3
-------------------------------
Test Error: 
 Accuracy: 96.5%, Avg loss: 0.116395 

loss: 0.041139  [   64/60000]
loss: 0.065425  [ 6464/60000]
loss: 0.051794  [12864/60000]
loss: 0.175680  [19264/60000]
loss

In [None]:
#estimating the (non-naturalized) det[Hessian]
# Dimension of parameter space = total number of scalar trainable values, counting
# weights *and* biases. (The previous loop skipped every 1-D tensor, silently
# dropping the biases, which disagreed with the flattened weight vector built later.)
dim = sum(param.numel() for param in my_model.parameters())

print(f"Parameter space dimension: {dim}")



def my_test(dataloader, model, loss_fn):
  """Return the mean per-batch loss of `model` over `dataloader` as a tensor.

  Unlike `test`, this does NOT disable autograd: the returned scalar stays connected
  to the model's parameters (or to whatever tensors were loaded in their place), so
  it can be differentiated, e.g. by torch.autograd.functional.hessian / vhp.
  Inputs and targets are moved to the global `device`.

  Raises:
    ValueError: if `dataloader` yields no batches (the mean would be undefined).
  """
  num_batches = len(dataloader)
  if num_batches == 0:
    raise ValueError("my_test: dataloader has no batches")
  model.eval()
  # No torch.no_grad() here: wrapping the loop in it (as before) cut the graph, so no
  # derivative of the loss could reach the weights.
  total_loss = 0
  for X, y in dataloader:
    X, y = X.to(device), y.to(device)
    total_loss = total_loss + loss_fn(model(X), y)
  return total_loss / num_batches

my_weight_dict = my_model.named_parameters()
# NOTE: named_parameters() returns a one-shot iterator of (name, Parameter) pairs,
# not a dict, so it can only be iterated once.

def load_tensors_to_model(local_model, flat_weights):
  """Load a flat parameter vector into `local_model`, keeping it differentiable.

  `flat_weights` is split into consecutive chunks, one per parameter in
  `named_parameters()` order, and each chunk is reshaped to that parameter's shape.
  Each nn.Parameter is then *replaced* by its chunk (a view of `flat_weights`) rather
  than overwritten via `param.data = ...`: assigning `.data` detaches the values, so
  no gradient (and hence no Hessian) could flow back to `flat_weights`.

  The model is modified in place and stops owning those parameters afterwards
  (`named_parameters()` no longer yields them), so call this on a throw-away copy.
  NOTE: tied/shared parameters are only replaced in the first module that registers
  them, because named_parameters() de-duplicates them.

  Args:
    local_model: nn.Module to load into (modified in place).
    flat_weights: tensor with exactly as many elements as the model has parameters.

  Raises:
    ValueError: if the number of elements in `flat_weights` does not match.
  """
  named = list(local_model.named_parameters())  # snapshot: the model is mutated below
  expected = sum(param.numel() for _, param in named)
  if flat_weights.numel() != expected:
    raise ValueError(
        f"flat_weights has {flat_weights.numel()} elements, model expects {expected}")
  flat = flat_weights.reshape(-1)
  count = 0
  for name, param in named:
    num_els = param.numel()
    w = flat[count:count + num_els].reshape(param.shape)
    count += num_els
    # e.g. "linear_relu_stack.0.weight" -> owner "linear_relu_stack.0", attr "weight"
    owner_path, _, attr = name.rpartition(".")
    owner = local_model.get_submodule(owner_path) if owner_path else local_model
    # Unregister the Parameter first; otherwise nn.Module.__setattr__ refuses to store
    # a plain (non-Parameter) tensor under a registered parameter name.
    del owner._parameters[attr]
    setattr(owner, attr, w)

from copy import deepcopy
def my_func(flat_weights):
  """Return the test-set loss of `my_model` evaluated at the point `flat_weights`.

  Works on a deep copy, so the trained global `my_model` is never modified. Relies on
  the notebook globals `my_model`, `test_dataloader` and `my_loss`.
  """
  model_copy = deepcopy(my_model)
  load_tensors_to_model(model_copy, flat_weights)
  return my_test(test_dataloader, model_copy, my_loss)

# Evaluation point: every trained parameter (weights and biases), detached and
# concatenated in named_parameters() order. `.detach()` is the supported
# replacement for `.data`.
flat_weights = torch.cat([param.detach().flatten() for _, param in my_weight_dict])
#print(f"weight dict: {my_weight_dict}")
print(f"weights {flat_weights}")

# A dense Hessian has n_params**2 entries. For this MLP (~6.7e5 parameters) that is
# ~4.5e11 values (~1.8 TB in float32), which cannot be materialised. Only build it
# densely for small models; otherwise compute one Hessian-vector product with a
# seeded random probe, which Lanczos-style spectral estimators are built on.
MAX_DENSE_HESSIAN_PARAMS = 10_000
n_params = flat_weights.numel()
if n_params <= MAX_DENSE_HESSIAN_PARAMS:
  H_map = torch.autograd.functional.hessian(my_func, flat_weights)
else:
  H_map = None
  probe = torch.randn(flat_weights.shape, generator=torch.Generator().manual_seed(0),
                      dtype=flat_weights.dtype).to(flat_weights.device)
  # The Hessian is symmetric, so v^T H equals H v; PyTorch documents vhp as the
  # faster of vhp/hvp.
  _, Hv = torch.autograd.functional.vhp(my_func, flat_weights, probe)
  print(f"Dense Hessian skipped ({n_params} params); ||H v|| = {Hv.norm().item():.4e}")

Parameter space dimension: 668672
weights tensor([-0.0128,  0.0341,  0.0091,  ..., -0.0459,  0.1143,  0.0577])


NameError: ignored