# Drive mounting and imports

In [1]:
# Mount Google Drive (Colab only) and extend the import path with the project folder.
from google.colab import drive
drive.mount('/content/drive')
import sys 
import os
# Presumably the project-local modules imported in later cells (sharpness, models,
# main, adashift, adabound) live in this Drive folder — verify.
sys.path.append('/content/drive/MyDrive/OptML_project')
# Prefix used to build every checkpoint path in this notebook. It is a relative
# path, so it resolves against the current working directory.
# NOTE(review): confirm the working directory actually contains `checkpoints/`.
checkpoint_folder = 'checkpoints/'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
from sharpness.MinimumSimple import effective as minimum_shaprness_eff

In [3]:
#libraries import
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data.dataloader as dataloader
import torch.optim as optim
from torch.utils.data import TensorDataset
from torch.autograd import Variable
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10, FashionMNIST
import matplotlib.pyplot as plt
from collections import namedtuple
from datetime import datetime

#files imports
from adashift import AdaShift
from adabound import AdaBound
#from sam import SAM
from models import *
from main import *

# Pick the compute device: the GPU when CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [11]:
# Name of the dataset used by the rest of the notebook. The loading cell only
# handles 'CIFAR10' and 'FashionMNIST'; the name is also a component of the
# checkpoint paths.
DATASET = 'FashionMNIST'

In [12]:
def plot_model(test_data_loader, model):
    """Evaluate `model` on `test_data_loader` and print accuracy and loss.

    The evaluation is delegated to the project-level `test` helper, run on the
    notebook-global `device`; its result is read through the 'accuracy' and
    'loss' keys. Despite the name, nothing is plotted and nothing is returned.
    """
    stats = test(model, test_data_loader, device=device)
    report = (
        "Performance on validation data:\n"
        f"accuracy : {stats['accuracy']:.2f}% | loss = {stats['loss']:.6f}"
    )
    print(report)

def get_model(architecture, dataset):
    """Build a freshly initialised network for the given architecture and dataset.

    Args:
        architecture: one of 'SimpleBatch', 'MiddleBatch' or 'ComplexBatch'.
        dataset: dataset name; 'CIFAR10' selects 3 input channels and size 32,
            any other value selects 1 input channel and size 28.

    Returns:
        An instance of the requested model class.

    Raises:
        ValueError: if `architecture` is not one of the supported names.
            (Previously the function fell through and returned None, which
            only surfaced later as an unrelated AttributeError.)
    """
    is_cifar = dataset == 'CIFAR10'
    input_channels = 3 if is_cifar else 1
    size = 32 if is_cifar else 28

    if architecture == 'SimpleBatch':
        return SimpleBatch(input_channels=input_channels, size=size)
    if architecture == 'MiddleBatch':
        return MiddleBatch(input_channels=input_channels, size=size)
    if architecture == 'ComplexBatch':
        return ComplexBatch(input_channels=input_channels, size=size)

    raise ValueError(
        f"Unknown architecture {architecture!r}; expected one of "
        "'SimpleBatch', 'MiddleBatch', 'ComplexBatch'"
    )


## Loading dataset and preprocessing data

In [13]:
TRAIN_BATCH_SIZE = 2**7
VAL_BATCH_SIZE = 1000

# Supported dataset names mapped to their torchvision dataset classes.
_DATASET_CLASSES = {'CIFAR10': CIFAR10, 'FashionMNIST': FashionMNIST}

# Fail loudly on an unsupported name instead of silently leaving
# `train_data` / the loaders undefined (or stale from an earlier run).
if DATASET not in _DATASET_CLASSES:
    raise ValueError(
        f"Unsupported DATASET {DATASET!r}; expected one of {sorted(_DATASET_CLASSES)}"
    )
dataset_class = _DATASET_CLASSES[DATASET]

# ToTensor converts each image to a float tensor scaled to [0, 1].
to_tensor = transforms.Compose([
    transforms.ToTensor(),
])

#loading datasets
train_data = dataset_class('./data', train=True, download=True, transform=to_tensor)
test_data = dataset_class('./data', train=False, download=True, transform=to_tensor)

#creating dataLoaders
train_loader = dataloader.DataLoader(train_data, shuffle=True, batch_size=TRAIN_BATCH_SIZE)
test_loader = dataloader.DataLoader(test_data, shuffle=False, batch_size=VAL_BATCH_SIZE)

In [14]:
print(f'Preporcessing dataset {DATASET}...')
begin = datetime.now()

# Materialise the whole training set as one tensor: iterating the dataset
# applies its transform, and element 0 of each item is taken as the input.
x = torch.stack([v[0] for v in train_data])
# `targets` may already be a tensor or a plain list depending on the dataset;
# as_tensor handles both without the copy-construct warning of torch.tensor.
y = torch.as_tensor(train_data.targets)

# Move to the device selected earlier rather than hard-coding CUDA, so the
# notebook also runs on a CPU-only machine.
x, y = x.to(device), y.to(device)

# Bundle inputs, labels and sample count for the sharpness routine.
data = namedtuple('_','x y n')(x=x, y=y,n=len(y))

print(f'Time needed {datetime.now() - begin}')

Preporcessing dataset FashionMNIST...
Time needed 0:00:04.697079


  """


In [15]:
epoch = 200
architecture = 'MiddleBatch'

# Build the path from `epoch` so the constant above and the folder name cannot drift apart.
PATH = f'{checkpoint_folder}{DATASET}/{architecture}/epoch{epoch}/SGD_0.1.pt'
print('Loading model...')

# Load onto the CPU first; the model itself is moved to `device` below.
checkpoint = torch.load(PATH, map_location=torch.device('cpu'))
# Rebuild the architecture and restore the trained weights. (The earlier
# `model = checkpoint['state_dict']` was a dead assignment and is dropped.)
model = get_model(architecture, DATASET).to(device)
model.load_state_dict(checkpoint['state_dict'])


# The checkpoint also stores per-epoch statistics; report the final entries.
print(f'Dataset: {DATASET} \t Architecture: {architecture} \t Optimizer: SGD lr = 0.1')
print(f'Training statistics: \t accuracy: {checkpoint["training_accuracy"][-1]} \t loss: {checkpoint["training_loss"][-1]}')
print(f'Test statistics: \t accuracy: {checkpoint["validation_accuracy"][-1]} \t loss: {checkpoint["validation_loss"][-1]}')


Loading model...
Dataset: FashionMNIST 	 Architecture: MiddleBatch 	 Optimizer: SGD lr = 0.1
Training statistics: 	 accuracy: 100.0 	 loss: 6.022807643707049e-06
Test statistics: 	 accuracy: 93.79999542236328 	 loss: 0.509210342168808


In [17]:
import matplotlib.pyplot as plt

epoch = 200  # NOTE(review): not used in this cell (it reads the 'converged' folder)

# Upper bound on how many times the stepsize is halved for one checkpoint.
# Every halving also doubles the number of epochs, so an unbounded retry loop
# could spin forever on an error that is unrelated to the stepsize.
MAX_LR_HALVINGS = 8

for architecture in ['SimpleBatch', 'MiddleBatch', 'ComplexBatch']:

    directory = f'{checkpoint_folder}{DATASET}/{architecture}/converged'
    files = os.listdir(directory)

    print('-'*100)
    print(architecture)
    print('-'*100)

    for filename in files:
        # Skip derived artifacts; only optimizer checkpoints are processed.
        if filename.endswith(('_sharpness.pt', '_hessian.pt')):
            continue

        print(f'Current file:{filename}')
        sharpness_filename = filename.replace('.pt', '_sharpness.pt')
        # Cache check: a saved result next to the checkpoint means nothing to do.
        if sharpness_filename in files:
            print(f'For {filename.replace(".pt", "")} shrapness is already computed\n')
            continue

        if 'adashift' in filename.lower():
            print(f'There is sth weird with AdaShift')
            continue

        # Rebuild the architecture and restore the trained weights.
        path = os.path.join(directory, filename)
        checkpoint = torch.load(path, map_location=torch.device('cpu'))
        model = get_model(architecture, DATASET).to(device)
        model.load_state_dict(checkpoint['state_dict'])

        lr = 0.1
        num_epochs = 100000

        for _ in range(MAX_LR_HALVINGS + 1):
            try:
                sharpnesses, losses = minimum_shaprness_eff(data, model, 128, lr, num_epochs=num_epochs, optimizer_file=path)
            except Exception as error:
                # `Exception` (not a bare except) lets KeyboardInterrupt stop the
                # run; the actual error is shown so non-stepsize failures are visible.
                print(f'Use smaller stepsize than {lr}')
                print(f'\t(reason: {error!r})')
                lr /= 2.0
                num_epochs *= 2
                continue

            # Saving happens outside the try so that an I/O failure is reported
            # as such instead of re-running the whole computation.
            sharpness_path = os.path.join(directory, sharpness_filename)
            sharpness_checkpoint = {'sharpnesses':sharpnesses, 'sharpness':sharpnesses[-1], 'losses': losses}
            torch.save(sharpness_checkpoint, sharpness_path)

            print(f'\t{filename} done \t sharpness: {sharpnesses[-1]}')
            break
        else:
            print(f'\tGiving up on {filename} after {MAX_LR_HALVINGS + 1} attempts')

        print()

    print('-'*100)



----------------------------------------------------------------------------------------------------
SimpleBatch
----------------------------------------------------------------------------------------------------
Current file:PHB_0.1_0.8.pt
	 Calculating Hessian




	Finished the diag calculation. Time needed: 0:01:10.100262. Computing sharpness...
 		 epoch:100000	 processed 100.0%	 loss:0.010402049198079463 	 minimum sharpness: 95233.87126268476 	 Time needed 0:02:01.232482
--------------------------------------------------
	PHB_0.1_0.8.pt done 	 sharpness: 95233.87126268476

Current file:SGD_0.1.pt
	 Calculating Hessian
	Finished the diag calculation. Time needed: 0:01:09.092269. Computing sharpness...
 		 epoch:100000	 processed 100.0%	 loss:0.012309919085253055 	 minimum sharpness: 54812.88849387634 	 Time needed 0:02:02.645844
--------------------------------------------------
	SGD_0.1.pt done 	 sharpness: 54812.88849387634

Current file:Adam.pt
	 Calculating Hessian
	Finished the diag calculation. Time needed: 0:01:09.083995. Computing sharpness...
 		 epoch:100000	 processed 100.0%	 loss:0.009570424038998286 	 minimum sharpness: 78887.24651684402 	 Time needed 0:02:00.312480
--------------------------------------------------
	Adam.pt done 

In [None]:
! ls drive/MyDrive/OptML_project/checkpoints/CIFAR10/SimpleBatch/epoch200/

AdaBound.pt  AdaShift.pt      SAM_Adagrad.pt		SAM_SGD_0.1.pt
Adagrad.pt   PHB_0.1_0.8.pt   SAM_Adagrad_sharpness.pt	SGD_0.1.pt
Adam.pt      SAM_AdaBound.pt  SAM_Adam.pt		SGD_0.1_sharpness.pt


In [None]:
#42853.32931748071
# NOTE(review): the number above is presumably a previously observed value — TODO confirm which run it belongs to.

# Load the saved sharpness result for SAM_Adagrad (CIFAR10 / SimpleBatch / epoch200)
# and display its 'sharpness' entry as the cell output.
sharpness_file = 'drive/MyDrive/OptML_project/checkpoints/CIFAR10/SimpleBatch/epoch200/SAM_Adagrad_sharpness.pt'
checkpoint = torch.load(sharpness_file)
checkpoint['sharpness']

15788.708123962846

In [None]:
# Overlay the SAM and SGD curves on the current axes.
# NOTE(review): `sharpnesses_SAM` and `sharpnesses_SGD` are not defined by any
# earlier cell shown here, so this cell relies on kernel state from elsewhere.
axes = plt.gca()
axes.plot(sharpnesses_SAM, label='SAM')
axes.plot(sharpnesses_SGD, label='SGD')
axes.legend()
plt.show()

In [None]:
from copy import deepcopy

# Keep independent copies of the most recent `sharpnesses` / `losses` results
# under SGD-specific names. The original cell crossed the two assignments, so
# `sharpnesses_SGD` held the losses and `losses_SGD` held the sharpnesses.
# NOTE(review): this assumes the last computation that set `sharpnesses` and
# `losses` was the SGD run — confirm before relying on these values.
sharpnesses_SGD = deepcopy(sharpnesses)
losses_SGD = deepcopy(losses)