# Training and calculating sharpness

In this notebook we provide two examples of training a model and calculating the sharpness of the minimum it reaches. Both examples use the FashionMNIST dataset and the *SimpleBatch* architecture (a neural network with 6 convolutional layers, 2 linear layers and batch normalization). The first example uses the SGD optimizer; the second uses Sharpness-Aware Minimization (SAM) with SGD as the base optimizer. The dataset, architecture and optimizer can all be changed easily in the notebook. The goal here is to demonstrate how the functions are used; the large-scale systematic training runs and sharpness calculations are in the *Training Systematic* notebook.
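SAM replaces the plain gradient step with a two-step update: it first moves the weights a distance `rho` along the normalized gradient (towards higher loss), computes the gradient at that perturbed point, and then lets the base optimizer apply this gradient to the original weights. The sketch below illustrates that update with plain gradient descent as the base optimizer on a toy two-dimensional quadratic loss. It uses NumPy rather than the `SAM` class imported later, and the function name `sam_step`, the loss and the values of `lr` and `rho` are illustrative assumptions, not the notebook's settings.

```python
import numpy as np


def sam_step(w, grad_fn, lr=0.1, rho=0.05):
    """One SAM update with plain gradient descent as the base optimizer (illustrative sketch)."""
    g = grad_fn(w)
    # First step: perturb the weights by rho along the normalized gradient (ascent direction)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)
    # Second step: take the gradient at the perturbed point, but apply it to the ORIGINAL weights
    g_perturbed = grad_fn(w + eps)
    return w - lr * g_perturbed


# Toy quadratic loss 0.5 * w^T A w with one sharp (3.0) and one flat (0.5) direction
A = np.diag([3.0, 0.5])


def loss(w):
    return 0.5 * w @ A @ w


def grad(w):
    return A @ w


w = np.array([1.0, 1.0])
print(f'initial loss: {loss(w):.4f}')
for _ in range(100):
    w = sam_step(w, grad)
print(f'final loss:   {loss(w):.6f}')
```

Because the gradient is taken at the perturbed point, the update approximately minimizes the worst-case loss within a ball of radius `rho` around the weights, which penalizes sharp regions of the loss surface.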

## Imports and device selection

In [7]:
# libraries import
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data.dataloader as dataloader
import torch.optim as optim
from torch.utils.data import TensorDataset
from torch.autograd import Variable
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.datasets import FashionMNIST

import matplotlib.pyplot as plt
import random
from itertools import product
import sys 
import os
from datetime import datetime
from collections import namedtuple

# setting the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [2]:
# import of models and other helpers
from models import *
from helpers import *

# optimizer imports 
from optimizers.adashift import AdaShift            # code taken from: https://github.com/MichaelKonobeev/adashift
from optimizers.adabound import AdaBound            # code taken from: https://github.com/Luolc/AdaBound
from optimizers.sam import SAM                      # code taken from: https://github.com/davda54/sam

# import functions for calculating sharpness 
from sharpness.Minimum import effective as minimum_shaprness_eff        # code taken from: https://github.com/ibayashi-hikaru/minimum-sharpness

# folder in which the training checkpoints and sharpness results are stored
checkpoint_folder = 'checkpoints_test/'

Here we specify the architecture, the dataset, the maximum number of epochs and the training batch size. To train a different architecture or on a different dataset, only this cell needs to be changed.

In [3]:
ARCHITECTURE = 'SimpleBatch'                                    # Other possibilities: 'MiddleBatch', 'ComplexBatch'
DATASET = 'FashionMNIST'                                        # Other possibilities: 'CIFAR10'
MAX_EPOCH = 50                                                  # The reported runs used 200 epochs; 50 are used here to keep training short
TRAIN_BATCH_SIZE = 2**7
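The number of optimizer steps per epoch follows from the training-set size and `TRAIN_BATCH_SIZE`. FashionMNIST has 60,000 training images and CIFAR10 has 50,000; since a PyTorch `DataLoader` keeps the last incomplete batch by default (`drop_last=False`), one epoch takes `ceil(n / batch_size)` steps. The helper name `steps_per_epoch` is hypothetical.

```python
import math

# Training-set sizes of the two datasets supported in this notebook
TRAIN_SIZES = {'FashionMNIST': 60_000, 'CIFAR10': 50_000}


def steps_per_epoch(n_samples, batch_size):
    # The last, incomplete batch is kept (DataLoader default drop_last=False), hence ceil
    return math.ceil(n_samples / batch_size)


for name, n in TRAIN_SIZES.items():
    steps = steps_per_epoch(n, 2**7)
    print(f'{name}: {steps} steps per epoch, {steps * 50} steps for 50 epochs')
```

With a batch size of 2**7 = 128 this gives 469 steps per epoch (23,450 for 50 epochs) on FashionMNIST and 391 steps per epoch (19,550 for 50 epochs) on CIFAR10.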

## Loading the dataset and creating dataloaders

In [4]:
VAL_BATCH_SIZE = 1000

# ToTensor converts images to float tensors and scales pixel values from [0, 255] to [0, 1]
transform = transforms.Compose([transforms.ToTensor()])

# loading datasets (an unsupported DATASET value raises a KeyError here)
dataset_class = {'CIFAR10': CIFAR10, 'FashionMNIST': FashionMNIST}[DATASET]
train_data = dataset_class('./data', train=True, download=True, transform=transform)
test_data = dataset_class('./data', train=False, download=True, transform=transform)

# creating dataloaders
train_loader = dataloader.DataLoader(train_data, shuffle=True, batch_size=TRAIN_BATCH_SIZE)
test_loader = dataloader.DataLoader(test_data, shuffle=False, batch_size=VAL_BATCH_SIZE)







In [8]:
# This cell preprocesses the training data for the sharpness calculation. Rerun it whenever the dataset is changed.
print(f'Preprocessing dataset {DATASET} in order to calculate sharpness...')
begin = datetime.now()

x = torch.stack([v[0] for v in train_data])
# as_tensor accepts both tensor targets (FashionMNIST) and list targets (CIFAR10) without the copy-construct warning
y = torch.as_tensor(train_data.targets)

x, y = x.to(device), y.to(device)
data = namedtuple('_', 'x y n')(x=x, y=y, n=len(y))

print(f'Time needed {datetime.now() - begin}')

Preprocessing dataset FashionMNIST in order to calculate sharpness...
Time needed 0:00:06.364751
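The preprocessing cell gathers the whole training set into a single object with fields `x` (all inputs), `y` (all labels) and `n` (number of samples), which is then passed as `data` to the sharpness function. The sketch below reproduces that container with plain Python lists in place of tensors, so its structure can be inspected without downloading a dataset; `SharpnessData` and `pack_dataset` are hypothetical names used only for illustration.

```python
from collections import namedtuple

# Same fields as the notebook's anonymous namedtuple: inputs, labels, sample count
SharpnessData = namedtuple('SharpnessData', 'x y n')


def pack_dataset(dataset):
    """Collect an iterable of (input, label) pairs into one x/y/n container."""
    xs = [sample for sample, _ in dataset]
    ys = [label for _, label in dataset]
    return SharpnessData(x=xs, y=ys, n=len(ys))


toy_dataset = [([0.0, 1.0], 3), ([1.0, 0.0], 7), ([0.5, 0.5], 3)]
toy_data = pack_dataset(toy_dataset)
print(toy_data.n, toy_data.y)
```

In the notebook itself, `torch.stack` turns the inputs into a single tensor and the result is moved to `device` once, so the whole training set stays in memory for the sharpness computation.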


## Training the model with the desired optimizer


In [9]:
# Getting the model based on the architecture and the dataset
model = get_model(ARCHITECTURE, DATASET).to(device)

# Specifying the optimizer. This can be changed to any optimizer that is supported by PyTorch.
# In order to use AdaBound:   optimizer = AdaBound(model.parameters())
# In order to use AdaShift:   optimizer = AdaShift(model.parameters())
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
path_sgd = 'sgd'

model = train(model, optimizer, train_loader=train_loader, device=device, epoch_num=MAX_EPOCH, max_nbr_epochs=MAX_EPOCH, path=path_sgd, val_dataloader=test_loader, sam=False)

new version 2
	 Test loss: 0.33952043056488035 	 Test accuracy: 87.4000015258789

	 Test loss: 0.2549797177314758 	 Test accuracy: 90.5999984741211

	 Test loss: 0.2669827178120613 	 Test accuracy: 90.61000061035156


KeyboardInterrupt: 

## Computing sharpness for the SGD model
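The sharpness used in this notebook comes from the external minimum-sharpness code imported at the top. For comparison, a simpler and widely used sharpness proxy is the largest eigenvalue of the loss Hessian at the minimum; this is a different quantity and is not what `minimum_shaprness_eff` returns. The sketch below estimates it by power iteration, approximating Hessian-vector products with central differences of the gradient. It is written in NumPy for a toy quadratic loss whose Hessian is `diag(3, 0.5)`, so the estimate should approach 3; the function name and its parameters are assumptions.

```python
import numpy as np


def top_hessian_eigenvalue(grad_fn, w, iters=100, h=1e-4, seed=0):
    """Estimate the largest Hessian eigenvalue of the loss at w by power iteration."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(w.shape)
    v /= np.linalg.norm(v)
    eig = 0.0
    for _ in range(iters):
        # Hessian-vector product via central differences of the gradient
        hv = (grad_fn(w + h * v) - grad_fn(w - h * v)) / (2 * h)
        eig = float(v @ hv)            # Rayleigh quotient of the current unit vector
        v = hv / np.linalg.norm(hv)    # power-iteration update
    return eig


# Toy quadratic loss 0.5 * w^T A w: its Hessian is A everywhere and the minimum is at w = 0
A = np.diag([3.0, 0.5])


def grad(w):
    return A @ w


print(top_hessian_eigenvalue(grad, np.zeros(2)))
```

Power iteration converges to the eigenvalue of largest magnitude, which equals the largest eigenvalue when the Hessian is positive semidefinite, as it is at a minimum.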

In [None]:
lr = 0.1 if DATASET == 'FashionMNIST' else 1
num_epochs = 100000
batch_size = 128

computed = False

# Load the SGD checkpoint (same file name as the one read at the end of this notebook)
path = os.path.join(checkpoint_folder, path_sgd + '.pt')
checkpoint = torch.load(path, map_location=torch.device('cpu'))
model = get_model(ARCHITECTURE, DATASET).to(device)
model.load_state_dict(checkpoint['state_dict'])


while not computed:
    try:
        # Calculating the sharpness. Raises an error if the learning rate is too large
        sharpnesses, losses = minimum_shaprness_eff(data, model, batch_size, lr, num_epochs=num_epochs, optimizer_file=path)

        # storing the sharpness
        sharpness_path = os.path.join(checkpoint_folder, path_sgd + '_sharpness.pt')
        checkpoint = {'sharpnesses': sharpnesses, 'sharpness': sharpnesses[-1], 'losses': losses}
        torch.save(checkpoint, sharpness_path)

        computed = True
        print(f'Sharpness: {sharpnesses[-1]}')
    except Exception:
        # An error means the learning rate is too large: halve it and double the number of epochs
        lr /= 2.0
        num_epochs *= 2
        print(f'Retrying with smaller step size {lr}')
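The retry loop above halves the learning rate and doubles the number of epochs whenever the sharpness computation fails, which keeps the product `lr * num_epochs` unchanged. The same logic can be factored into a reusable helper, sketched below with a fake computation that fails for step sizes above 0.03. The names `run_with_step_backoff` and `fake_compute` and the `max_retries` cap are hypothetical additions; the cap prevents an endless loop when the failure is not caused by the step size.

```python
def run_with_step_backoff(compute, lr, num_epochs, max_retries=10):
    """Call compute(lr, num_epochs), halving lr and doubling num_epochs after each failure.

    Returns (result, lr_used, num_epochs_used).
    """
    for _ in range(max_retries + 1):
        try:
            return compute(lr, num_epochs), lr, num_epochs
        except Exception as err:  # unlike a bare except, this lets KeyboardInterrupt through
            print(f'Failed with lr={lr} ({err}); retrying with lr={lr / 2}')
            lr /= 2.0
            num_epochs *= 2
    raise RuntimeError(f'no success after {max_retries} retries')


def fake_compute(lr, num_epochs):
    """Hypothetical stand-in for the sharpness routine: fails for step sizes above 0.03."""
    if lr > 0.03:
        raise ValueError('step size too large')
    return [1.0, 0.5, 0.25]


sharpnesses, lr_used, epochs_used = run_with_step_backoff(fake_compute, lr=0.1, num_epochs=100000)
print(sharpnesses[-1], lr_used, epochs_used)
```

Here the first two attempts (lr = 0.1 and 0.05) fail, and the third succeeds with lr = 0.025 and 400,000 epochs.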

## Training the model with SAM

In [None]:
model = get_model(ARCHITECTURE, DATASET).to(device)

# Specifying the base optimizer of SAM. This can be changed to any PyTorch-compatible optimizer. Examples:
# In order to use PHB (Polyak heavy ball): optimizer = SAM(model.parameters(), torch.optim.SGD, lr=0.1, momentum=0.8)
# In order to use Adam:       optimizer = SAM(model.parameters(), torch.optim.Adam)
# In order to use Adagrad:    optimizer = SAM(model.parameters(), torch.optim.Adagrad)
# In order to use AdaBound:   optimizer = SAM(model.parameters(), AdaBound)
# In order to use AdaShift:   optimizer = SAM(model.parameters(), AdaShift, lr=0.01)
optimizer = SAM(model.parameters(), torch.optim.SGD, lr=0.1)

path_sam = 'sam'
model = train(model, optimizer, train_loader=train_loader, device=device, epoch_num=MAX_EPOCH, max_nbr_epochs=MAX_EPOCH, path=path_sam, val_dataloader=test_loader, sam=True)

new version 2




	 Test loss: 0.3021207422018051 	 Test accuracy: 89.3499984741211

	 Test loss: 0.264253543317318 	 Test accuracy: 90.1199951171875

	 Test loss: 0.22525830268859864 	 Test accuracy: 91.66999816894531

	 Test loss: 0.21805322766304017 	 Test accuracy: 92.00999450683594

	 Test loss: 0.20426045954227448 	 Test accuracy: 92.64999389648438

	 Test loss: 0.20500749349594116 	 Test accuracy: 92.40999603271484

	 Test loss: 0.1929409235715866 	 Test accuracy: 93.1199951171875

	 Test loss: 0.18878853023052217 	 Test accuracy: 93.18999481201172

	 Test loss: 0.18565435111522674 	 Test accuracy: 93.30999755859375

	 Test loss: 0.19338732063770295 	 Test accuracy: 93.25999450683594

	 Test loss: 0.18967241495847703 	 Test accuracy: 93.66999816894531

	 Test loss: 0.1971086546778679 	 Test accuracy: 93.54999542236328

	 Test loss: 0.20515214800834655 	 Test accuracy: 93.43000030517578

	 Test loss: 0.20754780620336533 	 Test accuracy: 93.69999694824219

	 Test loss: 0.21484714448451997 	 Test ac

## Computing sharpness for the SAM model

In [None]:
lr = 0.1 if DATASET == 'FashionMNIST' else 1
num_epochs = 100000
computed = False
batch_size = 128

# Load the SAM checkpoint (same file name as the one read at the end of this notebook)
path = os.path.join(checkpoint_folder, path_sam + '.pt')
checkpoint = torch.load(path, map_location=torch.device('cpu'))
model = get_model(ARCHITECTURE, DATASET).to(device)
model.load_state_dict(checkpoint['state_dict'])


while not computed:
    try:
        # Calculating the sharpness. Raises an error if the learning rate is too large
        sharpnesses, losses = minimum_shaprness_eff(data, model, batch_size, lr, num_epochs=num_epochs, optimizer_file=path)

        # storing the sharpness
        sharpness_path = os.path.join(checkpoint_folder, path_sam + '_sharpness.pt')
        checkpoint = {'sharpnesses': sharpnesses, 'sharpness': sharpnesses[-1], 'losses': losses}
        torch.save(checkpoint, sharpness_path)

        # Plotting the sharpness objective as it is minimized over time
        plt.ylabel('Sharpness')
        plt.xlabel('Epoch [x100]')
        plt.plot(sharpnesses)
        plt.show()

        computed = True
        print(f'Sharpness: {sharpnesses[-1]}')
    except Exception:
        # An error means the learning rate is too large: halve it and double the number of epochs
        lr /= 2.0
        num_epochs *= 2
        print(f'Retrying with smaller step size {lr}')

## Loading the results

Here we demonstrate how to load the stored training results and compare SGD with SAM (using SGD as the base optimizer). The plots used in the report are produced in the *Data analysis* notebook.
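Each run produces two files in `checkpoint_folder`: the training checkpoint `<run>.pt` and the sharpness results `<run>_sharpness.pt` (for example `sgd.pt` and `sgd_sharpness.pt`, which the next cell loads). A small helper keeps these names consistent; `checkpoint_paths` is hypothetical and not part of the project's `helpers` module. The sketch also shows a common pitfall: passing the suffix to `os.path.join` as a separate component inserts a directory separator instead of appending the suffix.

```python
import os


def checkpoint_paths(folder, run_name):
    """Return (training checkpoint path, sharpness results path) for one run."""
    return (os.path.join(folder, run_name + '.pt'),
            os.path.join(folder, run_name + '_sharpness.pt'))


print(checkpoint_paths('checkpoints_test/', 'sgd'))
# Pitfall: a separate '.pt' component becomes a file inside a directory named after the run
print(os.path.join('checkpoints_test/', 'sgd', '.pt'))
```

On POSIX systems the first call yields `checkpoints_test/sgd.pt` and `checkpoints_test/sgd_sharpness.pt`, while the pitfall line yields `checkpoints_test/sgd/.pt`.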

In [13]:
# Loading information about SGD training (map_location allows loading GPU checkpoints on a CPU-only machine)
checkpoint_sgd = torch.load(checkpoint_folder + 'sgd.pt', map_location=device)
losses_sgd = checkpoint_sgd['training_losses']
acc_sgd = checkpoint_sgd['validation_accuracies']
sharpness_sgd = torch.load(checkpoint_folder + 'sgd_sharpness.pt', map_location=device)['sharpness']

# Loading information about SAM SGD training
checkpoint_sam = torch.load(checkpoint_folder + 'sam.pt', map_location=device)
losses_sam = checkpoint_sam['training_losses']
acc_sam = checkpoint_sam['validation_accuracies']
sharpness_sam = torch.load(checkpoint_folder + 'sam_sharpness.pt', map_location=device)['sharpness']

# Plotting both
fig, ax = plt.subplots(2, 1)

ax[0].loglog(losses_sgd, label='SGD')
ax[0].loglog(losses_sam, label='SAM')
ax[0].set_xlabel('Epoch')
ax[0].set_ylabel('Training loss')
ax[0].legend()

ax[1].loglog(acc_sgd, label='SGD')
ax[1].loglog(acc_sam, label='SAM')
ax[1].set_xlabel('Epoch')
ax[1].set_ylabel('Test accuracy')
ax[1].legend()

fig.tight_layout()   # keep the axis labels of the two panels from overlapping
plt.show()

# Writing sharpness
print(f'Minimum sharpness of SGD: {sharpness_sgd}')
print(f'Minimum sharpness of SAM: {sharpness_sam}')

KeyError: 'losses'