### Import useful libraries

In [7]:
# Import useful libraries for computation
import numpy as np

# Import torch and libraries to deal with NN
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# pip install torch_optimizer
import torch_optimizer as optim
import copy
# Import useful library to visualize results
import matplotlib.pyplot as plt

# Importing the LeNet5 architecture we are going to use for our study and comparisons
from cnn_architectures import *

# Importing parameters to use with different optimizers before comparing them
import params

# Importing useful functions
from helpers import *

# Ignoring warnings to make the code more readable
import warnings
warnings.filterwarnings("ignore")

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Setting the parameters and additional variables

In [8]:
# Defining the neural network's parameters and the seed used for reproducibility
RANDOM_SEED = 42
IMG_SIZE = 32
N_CLASSES = 10
# Checking device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Apply the seed to the random number generators used in this notebook.
# Declaring RANDOM_SEED alone does not make anything reproducible: weight
# initialisation and DataLoader shuffling draw from torch's RNG, and the
# sub-sampling further down draws from NumPy's.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

### Loading, reshaping and plotting data (ADAHessian)

In [9]:
# Preprocessing applied to every MNIST image: resize to IMG_SIZE, then convert to a tensor.
# The pipeline is bound to its own name on purpose: assigning it to `transforms`
# would overwrite the imported torchvision `transforms` module, and re-running
# this cell would then fail because the Compose object has no `.Compose`/`.Resize`.
mnist_transform = transforms.Compose([transforms.Resize(IMG_SIZE),
                                      transforms.ToTensor()])

# Load the MNIST dataset (downloaded into ./data when not already present)
raw_mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=mnist_transform)
raw_mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=mnist_transform)

# Passing train data to the dataloader (shuffled at every epoch)
train_loader = DataLoader(dataset=raw_mnist_trainset,
                          batch_size=params.AH_BATCH_SIZE,
                          shuffle=True)

# Passing test data to the dataloader (fixed order, no shuffling)
test_loader = DataLoader(dataset=raw_mnist_testset,
                         batch_size=params.AH_BATCH_SIZE,
                         shuffle=False)

In [10]:
# Reshaping train data (from 28*28 to 32*32) for visualization purposes
train_data, train_target = reshape_train_data(raw_mnist_trainset, DEVICE)
# Reshaping test data (from 28*28 to 32*32) for visualization purposes.
# The test split must be passed here; passing the training set would make
# `test_data`/`test_target` describe training images instead of held-out ones.
# NOTE(review): reshape_test_data lives in helpers and is not visible here —
# confirm it has no hard-coded assumption about the training-set size.
test_data, test_target = reshape_test_data(raw_mnist_testset, DEVICE)

## Model training and Model Evaluation using ADAHessian

First, we train a LeNet5 model with the ADAHessian optimizer. Training uses mini-batches of size 128 for 10 epochs.

In [11]:
# LeNet5 network used throughout this study, with one output per class
model = LeNet5(num_classes=N_CLASSES)

# Cross-entropy is the loss function minimised during training
criterion = nn.CrossEntropyLoss()

# ADAHessian optimizer (torch_optimizer, imported as `optim`);
# every hyper-parameter is read from the `params` module
optimizer = optim.Adahessian(
    model.parameters(),
    lr=params.AH_LEARNING_RATE,
    betas=params.AH_BETAS,
    eps=params.AH_EPS,
    weight_decay=params.AH_WD,
    hessian_power=params.AH_power,
)

Let's train and test our first model

In [12]:
# Run the training loop and keep the trained model, the optimizer state, and
# the recorded losses and gradient norms for the analysis cells below.
# NOTE(review): the epoch count is read from params.ADAM_N_EPOCHS although all
# other hyper-parameters in this notebook use the AH_ prefix — confirm this is intended.
model, optimizer, losses, grad_norms = training_loop(
    model,
    criterion,
    optimizer,
    train_loader,
    test_loader,
    params.ADAM_N_EPOCHS,
    DEVICE,
    second_order_method=True,
)

16:17:47 --- Epoch: 0	Train loss: 191.2571	Valid loss: 345.9440	Train accuracy: 11.24	Valid accuracy: 11.35
16:18:22 --- Epoch: 1	Train loss: 204.9368	Valid loss: 136.9343	Train accuracy: 10.22	Valid accuracy: 10.10
16:18:56 --- Epoch: 2	Train loss: 46.5819	Valid loss: 37.0907	Train accuracy: 9.75	Valid accuracy: 9.74
16:19:31 --- Epoch: 3	Train loss: 25.8781	Valid loss: 18.8322	Train accuracy: 9.75	Valid accuracy: 9.74
16:20:05 --- Epoch: 4	Train loss: 18.8863	Valid loss: 17.4562	Train accuracy: 9.91	Valid accuracy: 10.09
16:20:40 --- Epoch: 5	Train loss: 11.6708	Valid loss: 9.8202	Train accuracy: 9.87	Valid accuracy: 9.80
16:21:14 --- Epoch: 6	Train loss: 8.2005	Valid loss: 4.7378	Train accuracy: 11.24	Valid accuracy: 11.35
16:21:49 --- Epoch: 7	Train loss: 2.9584	Valid loss: 2.6984	Train accuracy: 10.22	Valid accuracy: 10.10
16:22:24 --- Epoch: 8	Train loss: 2.5337	Valid loss: 2.4626	Train accuracy: 9.91	Valid accuracy: 10.09
16:22:58 --- Epoch: 9	Train loss: 2.5589	Valid loss: 2.58

In [None]:
# Only the last 30 recorded gradient norms are plotted
recent_grad_norms = grad_norms[-30:]
plot_gradient_norm(recent_grad_norms, method='ADAHessian')

In [None]:
# Confusion matrix of the trained model over the batches of the test loader
# (helper defined in helpers; N_CLASSES is the number of label classes)
compute_confusion_matrix(test_loader, model, N_CLASSES)

In [None]:
# Draw a random subset of the training data on which the Hessian of the loss
# (evaluated at the trained solution) is computed.
# A dedicated, seeded generator makes the subset reproducible, and sampling
# without replacement avoids duplicated examples in the Hessian batch.
HESSIAN_SAMPLE_SIZE = 1000
hessian_rng = np.random.default_rng(RANDOM_SEED)
indices = hessian_rng.choice(len(train_data),
                             size=min(HESSIAN_SAMPLE_SIZE, len(train_data)),
                             replace=False)
hessian_input, hessian_label = train_data[indices].to(DEVICE), train_target[indices].to(DEVICE)


# We now build the Hessian helper, to later retrieve the top eigenvalue and eigenvector.
# It works on a deep copy so the trained `model` itself is never touched.
device_flag = torch.cuda.is_available()
model_to_plot = copy.deepcopy(model)
# Re-created here so this cell does not depend on the training cell's criterion
criterion = torch.nn.CrossEntropyLoss()
hessian_comp = hessian(model_to_plot, criterion, data=(hessian_input, hessian_label), cuda=device_flag)

# Now let's compute the top eigenvalue and its eigenvector (top_n=1).
top_eigenvalues, top_eigenvector = hessian_comp.eigenvalues(top_n=1)

# Report the largest eigenvalue of the Hessian
print("The top eigenvalue of this model is: %.4f "% (top_eigenvalues[0]))

In [None]:
# lams are the small scalars used to perturb the model parameters along the
# top eigenvector: 21 evenly spaced values in [-0.5, 0.5]
lams = np.linspace(-0.5, 0.5, 21).astype(np.float32)

loss_list = []

# The perturbed model starts as a copy of the model obtained at the end of training
model_perb = copy.deepcopy(model)

# We now perturb the model in the direction given by the top eigenvector and
# record the loss at each step, to visualize the quality of the minimum.
# NOTE(review): get_params comes from helpers; presumably it sets model_perb's
# weights to model's weights + lam * direction — confirm.
for lam in lams:
    model_perb = get_params(model, model_perb, top_eigenvector[0], lam)
    # Only the scalar loss value is needed, so skip building an autograd graph
    with torch.no_grad():
        loss_list.append(criterion(model_perb(hessian_input), hessian_label).item())

plt.plot(lams, loss_list)
plt.ylabel('Loss')
plt.xlabel('Perturbation')
plt.title('Loss landscape perturbed based on top Hessian eigenvector')
# Render the figure explicitly instead of echoing the Text object returned by plt.title
plt.show()

In [None]:
from pyhessian.utils import normalization


# Scalars used to perturb the model: 21 evenly spaced values in [-0.5, 0.5]
lams = np.linspace(-0.5, 0.5, 21).astype(np.float32)

loss_list = []

# create a copy of the model so the trained `model` keeps its weights and gradients
model_perb = copy.deepcopy(model)

# generate the gradient vector that serves as the perturbation direction
loss = criterion(model_perb(hessian_input), hessian_label)
loss.backward()

# Detach and clone each gradient so the direction owns its own storage: the
# zero_grad() call below can then neither zero it in place nor drop it,
# whichever behaviour the installed torch version has.
v = [p.grad.detach().clone() for p in model_perb.parameters()]
v = normalization(v)
model_perb.zero_grad()


# Evaluate the loss at each perturbation step along the normalized gradient
for lam in lams:
    model_perb = get_params(model, model_perb, v, lam)
    # Only the scalar loss value is needed, so skip building an autograd graph
    with torch.no_grad():
        loss_list.append(criterion(model_perb(hessian_input), hessian_label).item())

plt.plot(lams, loss_list)
plt.ylabel('Loss')
plt.xlabel('Perturbation')
plt.title('Loss landscape perturbed based on gradient direction')
# Render the figure explicitly instead of echoing the Text object returned by plt.title
plt.show()