### Setup

In [1]:
from src.optimizers.SAMTrainer import SAMTrainer
from src.optimizers.volumes import VolumeFunction, LogVolume
%load_ext autoreload
%autoreload 2
import sys
import os

project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

In [76]:
import torch
from src.optimizers.LagrangianTrainer import LagrangianTrainer
from src.optimizers.SimpleTrainer import SimpleTrainer
from src.utils import dataset
from src.optimizers.HypercubeTrainer import HypercubeTrainer
from src.utils.evaluation import evaluate_accuracy
from src.cert import Safebox

from src.utils.dataset import reduce_dataset

from src.optimizers.volumes import LogVolume


In [3]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Tests on FashionMNIST

In [4]:
train_dataset, val_dataset = dataset.get_fashion_mnist_dataset()
train_dataset = reduce_dataset(train_dataset, num_samples=300)

In [5]:
def get_model(output_dim=10):
    """Returns a simple CNN model."""
    model = torch.nn.Sequential(
        torch.nn.Conv2d(1, 8, kernel_size=5, stride=1, padding=1),
        torch.nn.ReLU(),
        torch.nn.Conv2d(8, 1, kernel_size=5, stride=1, padding=1),
        torch.nn.ReLU(),
        torch.nn.Flatten(),
        torch.nn.Linear(576, output_dim),
    ).to(DEVICE)
    return model

# Tests on Ciphar 100



In [34]:
train_dataset, val_dataset = dataset.get_cifar100_dataset()

In [38]:
def get_model(output_dim=100):
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 12, kernel_size=5, stride=1, padding=1),
        torch.nn.ReLU(),
        torch.nn.Conv2d(12, 48, kernel_size=5, stride=1, padding=1),
        torch.nn.ReLU(),
        torch.nn.Conv2d(48, 12, kernel_size=3, stride=1, padding=1),
        torch.nn.ReLU(),
        torch.nn.Conv2d(12, 1, kernel_size=3, stride=1, padding=1),
        torch.nn.Flatten(),
        torch.nn.Linear(784,784),
        torch.nn.ReLU(),
        torch.nn.Linear(784, output_dim),
    ).to(DEVICE)
    return model

##### Baseline with Adam

In [39]:
base_model = get_model()

base_trainer = SimpleTrainer(base_model, device=DEVICE)
base_model = base_trainer.train(train_dataset, val_dataset, loss_obj=0.000000000000001, max_iters=10000, batch_size=64, lr=1e-3)

100%|██████████| 10000/10000 [03:32<00:00, 46.99it/s, loss=0.406, val_acc=0.203]


----------  Training completed with loss  0 ----------


In [40]:
print("Training base model accuracy", evaluate_accuracy(train_dataset, base_model, num_samples=len(val_dataset), device=DEVICE))
print("Validation base model accuracy", evaluate_accuracy(val_dataset, base_model, num_samples=len(val_dataset), device=DEVICE))

Training base model accuracy 0.9325999617576599
Validation base model accuracy 0.19429999589920044


##### Sam
-------------------------

rho = 0.05 (default)

In [41]:
sam_model = get_model()

sam_trainer = SAMTrainer(sam_model, device=DEVICE)
sam_model = sam_trainer.train(train_dataset, val_dataset, loss_obj=0.000000000000001, max_iters=10000, batch_size=64, lr=1e-3, rho=0.05)

100%|██████████| 10000/10000 [03:55<00:00, 42.45it/s, loss=0.987, val_acc=0.0781]

----------  Training completed with loss  1 ----------





In [43]:
print("Training SAM model accuracy", evaluate_accuracy(train_dataset, sam_model, num_samples=len(val_dataset), device=DEVICE))
print("Validation SAM model accuracy", evaluate_accuracy(val_dataset, sam_model, num_samples=len(val_dataset), device=DEVICE))

Training SAM model accuracy 0.9426999688148499
Validation SAM model accuracy 0.18490000069141388


rho = 0.01 (smaller)

In [69]:
sam_model_1 = get_model()

sam_trainer_1 = SAMTrainer(sam_model_1, device=DEVICE)
sam_model_1 = sam_trainer_1.train(train_dataset, val_dataset, loss_obj=0.000000000000001, max_iters=10000, batch_size=64, lr=1e-3, rho=0.01)

100%|██████████| 10000/10000 [03:56<00:00, 42.26it/s, loss=0.512, val_acc=0.188]

----------  Training completed with loss  1 ----------





In [70]:
print("Training SAM model accuracy (rho=0.01)", evaluate_accuracy(train_dataset, sam_model_1, num_samples=len(val_dataset), device=DEVICE))
print("Validation SAM model accuracy (rho=0.01)", evaluate_accuracy(val_dataset, sam_model_1, num_samples=len(val_dataset), device=DEVICE))

Training SAM model accuracy (rho=0.01) 0.9380999803543091
Validation SAM model accuracy (rho=0.01) 0.19029998779296875


rho = 0.001 (way smaller)

In [71]:
sam_model_2 = get_model()

sam_trainer_2 = SAMTrainer(sam_model_2, device=DEVICE)
sam_model_2 = sam_trainer_2.train(train_dataset, val_dataset, loss_obj=0.000000000000001, max_iters=10000, batch_size=64, lr=1e-3, rho=0.001)

100%|██████████| 10000/10000 [03:57<00:00, 42.18it/s, loss=0.373, val_acc=0.234]

----------  Training completed with loss  0 ----------





In [72]:
print("Training SAM model accuracy (rho=0.001)", evaluate_accuracy(train_dataset, sam_model_2, num_samples=len(val_dataset), device=DEVICE))
print("Validation SAM model accuracy (rho=0.001)", evaluate_accuracy(val_dataset, sam_model_2, num_samples=len(val_dataset), device=DEVICE))

Training SAM model accuracy (rho=0.001) 0.9375999569892883
Validation SAM model accuracy (rho=0.001) 0.18809999525547028


rho = 0.1 (bigger)

In [73]:
sam_model_3 = get_model()

sam_trainer_3 = SAMTrainer(sam_model_3, device=DEVICE)
sam_model_3 = sam_trainer_3.train(train_dataset, val_dataset, loss_obj=0.000000000000001, max_iters=10000, batch_size=64, lr=1e-3, rho=0.1)

100%|██████████| 10000/10000 [03:56<00:00, 42.22it/s, loss=2.16, val_acc=0.172]


----------  Training completed with loss  2 ----------


In [75]:
print("Training SAM model accuracy (rho=0.1)", evaluate_accuracy(train_dataset, sam_model_3, num_samples=len(val_dataset), device=DEVICE))
print("Validation SAM model accuracy (rho=0.1)", evaluate_accuracy(val_dataset, sam_model_3, num_samples=len(val_dataset), device=DEVICE))

Training SAM model accuracy (rho=0.1) 0.7524999976158142
Validation SAM model accuracy (rho=0.1) 0.18159998953342438


### Vanilla SGD vs SAM + SGD
-----------------------------

Not modularized yet (changed manually adam -> sgd in SamTrainer and SimpleTrainer :/ )

In [77]:
sam_sgd_model = get_model()

sam_sgd_trainer = SAMTrainer(sam_sgd_model, device=DEVICE)
sam_sgd_model = sam_sgd_trainer.train(train_dataset, val_dataset, loss_obj=0.000000000000001, max_iters=10000, batch_size=64, lr=1e-3, rho=0.01)

100%|██████████| 10000/10000 [03:54<00:00, 42.60it/s, loss=4.6, val_acc=0.0312]

----------  Training completed with loss  5 ----------





In [78]:
print("Training SGD SAM model accuracy (rho=0.01)", evaluate_accuracy(train_dataset, sam_sgd_model, num_samples=len(val_dataset), device=DEVICE))
print("Validation SGD SAM model accuracy (rho=0.01)", evaluate_accuracy(val_dataset, sam_sgd_model, num_samples=len(val_dataset), device=DEVICE))

Training SGD SAM model accuracy (rho=0.01) 0.012600000016391277
Validation SGD SAM model accuracy (rho=0.01) 0.013799999840557575


In [79]:
base_model = get_model()

base_trainer = SimpleTrainer(base_model, device=DEVICE)
base_model = base_trainer.train(train_dataset, val_dataset, loss_obj=0.000000000000001, max_iters=10000, batch_size=64, lr=1e-3)

 11%|█         | 1058/10000 [00:21<03:04, 48.43it/s, loss=4.61, val_acc=0.0156]


KeyboardInterrupt: 

##### Lagrangian
---------------------

In [57]:
lagrangian_model = get_model()
lagrangian_trainer = LagrangianTrainer(lagrangian_model, LogVolume(epsilon=1e-12), device=DEVICE)
lagrangian_trainer.set_volume_constrain(1e-4) # start with a small volume at first
print(lagrangian_trainer._volume_function(lagrangian_trainer._interval_model))
lagrangian_trainer.train(
    train_dataset, val_dataset, loss_obj=-0.000000000000001, max_iters=3000, batch_size=64, lr=1e-4
)

tensor(-8.5172, device='cuda:0', grad_fn=<DivBackward0>)


 35%|███▌      | 1060/3000 [00:54<01:40, 19.30it/s, loss=4.6, min_val_acc=0, current_volume=-6.9]     


KeyboardInterrupt: 

In [51]:
print("Training base model accuracy",
      evaluate_accuracy(train_dataset, lagrangian_model, num_samples=len(val_dataset), device=DEVICE))
print("Validation base model accuracy",
      evaluate_accuracy(val_dataset, lagrangian_model, num_samples=len(val_dataset), device=DEVICE))

Training base model accuracy 0.010300000198185444
Validation base model accuracy 0.009999999776482582
