In [2]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

from src.model import LeNet5
from src.exp import run_experiment, run_experiment_all_optimizers
from torch.optim import LBFGS, Adam

from optim_adahessian import Adahessian
from apollo import Apollo

import src.utils as utils

import matplotlib.pyplot as plt
import logging

import pickle
import copy
import numpy as np

# Lower the root logger threshold to INFO so informational messages are emitted.
logging.getLogger().setLevel(logging.INFO)
# Print only the message text (no level name / logger name prefix).
logging.basicConfig(format="%(message)s")

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
def perturabate_model(model, direction, alpha):
    """Return a copy of ``model`` whose parameters are shifted along ``direction``.

    Every parameter ``p`` of the copy becomes ``p + alpha * d``, where ``d`` is
    the entry of ``direction`` at the same position as ``p`` in
    ``model.parameters()``. The ``model`` passed in is not modified.

    Args:
        model: ``torch.nn.Module`` to perturb; it is deep-copied first.
        direction: iterable of tensors paired, in order, with the model's
            parameters. Pairing stops at the shorter of the two sequences.
        alpha: scalar step size applied to every entry of ``direction``.

    Returns:
        The perturbed deep copy of ``model``.
    """
    model_per = copy.deepcopy(model)

    # Parameters are leaf tensors that require grad, so in-place writes are
    # only allowed with autograd recording switched off.
    with torch.no_grad():
        for param, d in zip(model_per.parameters(), direction):
            # Write the shifted values back INTO the parameter. Rebinding the
            # loop variable (``param = param + d * alpha``) would only change a
            # local name and leave the copied model unperturbed.
            param.copy_(param + d * alpha)

    return model_per

In [4]:
def compute_minimum_shape(results_dic, dataset_name, optimizer_name, max_amp_pert = 0.5, num_per = 25,
                          data_dir = '/home/app/datasets'):
    """Evaluate the training loss along the top Hessian eigenvector of a stored model.

    The model and Hessian object are read from
    ``results_dic[dataset_name][optimizer_name]``. The model is perturbed by
    ``alpha * v`` for ``num_per`` evenly spaced values of ``alpha`` in
    ``[-max_amp_pert, max_amp_pert]``, where ``v`` is the eigenvector returned
    by ``hessian.eigenvalues(top_n=1)``. For each ``alpha`` the cross-entropy
    loss is averaged over every sample yielded by the training loader.

    Args:
        results_dic: nested mapping ``dataset -> optimizer -> {'model', 'hessian'}``.
        dataset_name: first-level key into ``results_dic``.
        optimizer_name: second-level key into ``results_dic``.
        max_amp_pert: largest absolute perturbation amplitude.
        num_per: number of amplitudes to evaluate.
        data_dir: root directory handed to ``utils.load_data``.

    Returns:
        ``(alphas, losses)``: the NumPy array of amplitudes and a list of
        Python floats holding the mean per-sample loss at each amplitude.

    Raises:
        ValueError: if the training loader yields no samples.
    """
    # NOTE(review): the loader is always built for 'MNIST', independent of
    # ``dataset_name`` — confirm whether utils.load_data should receive
    # ``dataset_name`` instead.
    train_loader, _ = utils.load_data(data_dir, dataset = 'MNIST')

    entry = results_dic[dataset_name][optimizer_name]
    model = entry['model']
    hessian = entry['hessian']

    # Sum (not mean) inside each batch so that dividing by the number of
    # samples seen gives the exact per-sample mean, whatever the batch sizes.
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')

    _, top_eigenvectors = hessian.eigenvalues(top_n = 1)
    direction = top_eigenvectors[-1]

    alphas = np.linspace(-max_amp_pert, max_amp_pert, num_per)

    losses = []

    for alpha in alphas:
        model_per = perturabate_model(model, direction, alpha)

        total_loss = 0.0
        n_samples = 0
        # Pure evaluation: without no_grad every batch's autograd graph would
        # be kept alive by the running sum.
        with torch.no_grad():
            for inputs, targets in train_loader:
                total_loss += criterion(model_per(inputs), targets).item()
                n_samples += targets.size(0)

        if n_samples == 0:
            raise ValueError("training loader yielded no samples")

        # Plain floats, so the result can be plotted or pickled directly.
        losses.append(total_loss / n_samples)

    return alphas, losses

In [5]:
# Load the saved experiment results (accessed below as results[dataset][optimizer]).
# NOTE: unpickling can execute arbitrary code — only load files from a trusted source.
# The context manager closes the file handle, which a bare open() inside the call would not.
with open('MNIST_results.dic', 'rb') as results_file:
    results = pickle.load(results_file)

In [None]:
# Loss profile for the MNIST / Adam entry; `out` is the (alphas, losses) pair.
out = compute_minimum_shape(
    results_dic=results,
    dataset_name='MNIST',
    optimizer_name='Adam',
)

In [None]:
# Plot the loss as a function of the perturbation amplitude.
# The loss entries may be 0-dim tensors still attached to the autograd graph,
# which matplotlib cannot convert; float() handles tensors and plain numbers alike.
fig, ax = plt.subplots()
ax.plot(out[0], [float(loss) for loss in out[1]])
ax.set_xlabel("perturbation amplitude (alpha)")
ax.set_ylabel("training loss")
ax.set_title("Loss along the top Hessian eigenvector (MNIST, Adam)")
plt.show()

In [None]:
# Show the whole loaded results dictionary as the cell output (can be very long).
results

In [None]:
# Re-serialise `results` with pickle into test.dic (binary mode; the file is closed by the with-block).
with open("test.dic","wb") as file:
    pickle.dump(results,file)

In [None]:
# Build the two MNIST loaders returned by utils.load_data (named train/test here).
# NOTE(review): the dataset root is a hard-coded absolute path — confirm it exists on the machine running this.
train_loader,test_loader = utils.load_data('/home/app/datasets', dataset = 'MNIST')

In [6]:
# Model object stored for the MNIST / Adam entry of `results`.
model = results['MNIST']['Adam']['model']

In [7]:
# Hessian object stored alongside the MNIST / Adam model.
hesmx = results['MNIST']['Adam']['hessian']

In [None]:
# Ask the Hessian object for its top eigenpair (top_n = 1): `lam` receives the
# eigenvalue result and `vec` the eigenvector result.
lam, vec = hesmx.eigenvalues(top_n = 1)

In [None]:
# Call perturabate_model with the last entry of `vec` and alpha = 0.1;
# the returned model is shown as the cell output.
perturabate_model(model,vec[-1],0.1)

In [None]:
# `results` is nested dataset -> optimizer, so the dataset key must come first
# (results['Adam'] alone does not match the layout used by the other lookups).
model = results['MNIST']['Adam']['model']



In [None]:
# Display the stored model; `results` is nested dataset -> optimizer, so the
# dataset key must come first.
results['MNIST']['Adam']['model']

In [None]:
# Show the whole results dictionary as the cell output (can be very long).
results

In [None]:
# Perturb the model parameters along a direction and return the perturbed model.
def get_params(model, direction, alpha):
    """Return a perturbed copy of ``model``.

    Each parameter of the copy is set to ``p + alpha * d``, where ``p`` is the
    matching parameter of ``model`` and ``d`` the matching entry of
    ``direction``. ``model`` itself is left unchanged.

    Args:
        model: ``torch.nn.Module`` providing the unperturbed parameters.
        direction: iterable of tensors paired, in order, with
            ``model.parameters()``.
        alpha: scalar step size along ``direction``.

    Returns:
        A deep copy of ``model`` carrying the perturbed parameters.
    """
    # Work on a private copy derived from the argument instead of relying on
    # notebook-level `model_orig` / `model_perb` variables being defined.
    model_perb = copy.deepcopy(model)

    for m_orig, m_perb, d in zip(model.parameters(), model_perb.parameters(), direction):
        # Assigning through .data replaces the values without autograd tracking.
        m_perb.data = m_orig.data + alpha * d
    return model_perb