Evaluate the models

# Stochastic Adversarial Training (StochAT)

### SoTA

vanila SGD: 
MNIST - 99%+ (most cnns), CIFAR10 - 93%+ (resnet18), 96%+ (wideresnet) 

MNIST:

adversarial attacks: 
l-inf @ eps = 80/255 @20 steps: TRADES - 96.07% - (4 layer cnn), MART 96.4%, MMA 95.5%, PGD - 96.01% - (4 layer cnn)

adversarial attacks:
l-2 @ eps = 32/255 (check): TRADES, MMA, PGD

CIFAR10:

adversarial attacks: 
l-inf @ eps = 8/255 @20 steps: 
TRADES 53-56% - (WRN-34-10), MART 57-58% (WRN-34-10), MMA 47%, PGD 48% - (WRN-32-10)// 49% - (WRN-34-10), Std - 0.03%
https://openreview.net/pdf?id=rklOg6EFwS (Table 4)

adversarial attacks: 
l-inf @ eps = 8/255 @20 steps: 
[ResNet10] TRADES 45.4%, MART 46.6%, MMA 37.26%, PGD 42.27%, Std 0.14%

Benign accuracies: TRADES 84.92%, MART 83.62%, MMA 84.36, PGD 87.14%, Std 95.8% [wideresnet]
https://openreview.net/pdf?id=Ms9zjhVB5R (Table 1)

adversarial attacks:
l-2 @ eps = 32/255 (check): TRADES, MART, MMA, PGD

TBD: CWinf attacks

## Pretrained models for comparison

download pretrained models and place in ../trainedmodels/MNIST or ../trainedmodels/CIFAR10 respectively

### TRADES :
https://github.com/yaodongyu/TRADES (MNIST: small cnn, CIFAR10: WideResNet34)
### MMA : 
https://github.com/BorealisAI/mma_training (MNIST: lenet5, CIFAR10: WideResNet28)
### MART :
 https://github.com/YisenWang/MART (CIFAR10: ResNet18 and WideResNet34)

## IMPORT LIBRARIES

In [1]:
import numpy as np
import pandas as pd
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms, datasets
from multiprocessing import cpu_count
from collections import OrderedDict
import matplotlib.pyplot as plt
import torch
import olympic
import sys
from typing import Union, Callable, Tuple
sys.path.append('../adversarial/')
sys.path.append('../architectures/')
from functional import boundary, iterated_fgsm, local_search, pgd, entropySmoothing
from ESGD_utils import *
import pickle
import time
import torch.backends.cudnn as cudnn
import argparse, math, random
import ESGD_optim
from trades import trades_loss

In [2]:
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [3]:
DEVICE

'cuda'

# LOAD DATA

In [4]:
#place data folders outside working directory

In [5]:
dataset = 'CIFAR10' # [MNIST, CIFAR10]
transform = transforms.Compose([
    transforms.ToTensor(),
])
bsz = 128
if dataset == 'MNIST':
    train = datasets.MNIST('../../data/MNIST', train=True, transform=transform, download=True)
    val = datasets.MNIST('../../data/MNIST', train=False, transform=transform, download=True)
elif dataset == 'CIFAR10':
    train = datasets.CIFAR10('../../data/CIFAR10', train=True, transform=transform, download=True)
    val = datasets.CIFAR10('../../data/CIFAR10', train=False, transform=transform, download=True)
    
train_loader = DataLoader(train, batch_size=128, num_workers=cpu_count(),drop_last=True)
val_loader = DataLoader(val, batch_size=128, num_workers=cpu_count(),drop_last=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ../../data/CIFAR10/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ../../data/CIFAR10/cifar-10-python.tar.gz to ../../data/CIFAR10
Files already downloaded and verified


# INITIALIZE NETWORK

In [6]:
if dataset=='MNIST':
    from net_mnist import Net, NetSoft, model_cnn
    Net = model_cnn
    NetName = 'model_cnn'

In [12]:
if dataset=='CIFAR10':
    #[ResNet18,ResNet34,ResNet50,WideResNet]
    from resnet import ResNet18,ResNet34,ResNet50
    from wideresnet import WideResNet
    #Net = WideResNet
    #NetName = 'WideResNet'
    Net = ResNet18
    NetName = 'ResNet18'

In [13]:
Net

<function resnet.ResNet18()>

# RANDOM SEED 

In [14]:
seed = 42
torch.set_num_threads(2)
if DEVICE=='cuda':
    torch.cuda.set_device(-1)
    torch.cuda.manual_seed(seed)
    cudnn.benchmark = True
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f823a215d50>

# LOAD PRETRAINED OR TRAIN NEW MODELS:

In [15]:
TrainSGD = True
TrainESGD = True
TrainL2 = True
TrainLInf = True
TrainSAT2 = True
TrainSATInf = True
TrainTRADES = True
TrainTRADESInf = True
# tbd: add training modules for MART(inf only), MMA and MMAInf - always keep False
TrainMART = False 
TrainMMA = False
TrainMMAInf = False

# TRAIN NAIVE MODEL USING SGD

In [18]:
if TrainSGD:
    ## initialize model
    model_SGD = Net().to(DEVICE)
    ## training params
    lr = 0.001
    optimiser = optim.SGD(model_SGD.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    epochs = 10
    ## train model
    history_natural = olympic.fit(
        model_SGD,
        optimiser,
        loss_fn,
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5)
        ]
    )
    ## verify validation accuracy
    print('final validation accuracy:')
    valscore = olympic.evaluate(model_SGD, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    print(valscore)
    ## save model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_SGD_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_SGD,modelname)


Epoch 1:   0%|          | 0/390 [00:00<?, ?it/s][A

Begin training...



Epoch 1:   0%|          | 1/390 [00:00<04:28,  1.45it/s][A
Epoch 1:   0%|          | 1/390 [00:00<04:28,  1.45it/s, loss=2.33, accuracy=0.0938][A
Epoch 1:   1%|          | 2/390 [00:00<04:28,  1.45it/s, loss=2.31, accuracy=0.117] [A
Epoch 1:   1%|          | 3/390 [00:00<03:14,  1.99it/s, loss=2.31, accuracy=0.117][A
Epoch 1:   1%|          | 3/390 [00:00<03:14,  1.99it/s, loss=2.35, accuracy=0.0703][A
Epoch 1:   1%|          | 4/390 [00:00<03:13,  1.99it/s, loss=2.33, accuracy=0.0859][A
Epoch 1:   1%|▏         | 5/390 [00:00<02:22,  2.70it/s, loss=2.33, accuracy=0.0859][A
Epoch 1:   1%|▏         | 5/390 [00:00<02:22,  2.70it/s, loss=2.31, accuracy=0.0703][A
Epoch 1:   2%|▏         | 6/390 [00:00<02:22,  2.70it/s, loss=2.31, accuracy=0.164] [A
Epoch 1:   2%|▏         | 7/390 [00:01<01:45,  3.62it/s, loss=2.31, accuracy=0.164][A
Epoch 1:   2%|▏         | 7/390 [00:01<01:45,  3.62it/s, loss=2.33, accuracy=0.117][A
Epoch 1:   2%|▏         | 8/390 [00:01<01:45,  3.62it/s, loss=

Epoch 1:  32%|███▏      | 124/390 [00:07<00:14, 18.10it/s, loss=2.03, accuracy=0.227][A
Epoch 1:  32%|███▏      | 125/390 [00:07<00:14, 18.10it/s, loss=2.03, accuracy=0.227][A
Epoch 1:  32%|███▏      | 125/390 [00:07<00:14, 18.10it/s, loss=2.06, accuracy=0.258][A
Epoch 1:  32%|███▏      | 126/390 [00:07<00:14, 18.10it/s, loss=2, accuracy=0.273]   [A
Epoch 1:  33%|███▎      | 127/390 [00:07<00:14, 18.10it/s, loss=2, accuracy=0.273][A
Epoch 1:  33%|███▎      | 127/390 [00:07<00:14, 18.10it/s, loss=2.07, accuracy=0.297][A
Epoch 1:  33%|███▎      | 128/390 [00:07<00:14, 18.10it/s, loss=2.07, accuracy=0.234][A
Epoch 1:  33%|███▎      | 129/390 [00:07<00:14, 18.16it/s, loss=2.07, accuracy=0.234][A
Epoch 1:  33%|███▎      | 129/390 [00:07<00:14, 18.16it/s, loss=2.04, accuracy=0.281][A
Epoch 1:  33%|███▎      | 130/390 [00:07<00:14, 18.16it/s, loss=2.13, accuracy=0.18] [A
Epoch 1:  34%|███▎      | 131/390 [00:07<00:14, 18.12it/s, loss=2.13, accuracy=0.18][A
Epoch 1:  34%|███▎      |

Epoch 1:  63%|██████▎   | 247/390 [00:14<00:07, 18.11it/s, loss=1.93, accuracy=0.281][A
Epoch 1:  63%|██████▎   | 247/390 [00:14<00:07, 18.11it/s, loss=1.88, accuracy=0.312][A
Epoch 1:  64%|██████▎   | 248/390 [00:14<00:07, 18.11it/s, loss=1.89, accuracy=0.344][A
Epoch 1:  64%|██████▍   | 249/390 [00:14<00:07, 18.07it/s, loss=1.89, accuracy=0.344][A
Epoch 1:  64%|██████▍   | 249/390 [00:14<00:07, 18.07it/s, loss=1.91, accuracy=0.336][A
Epoch 1:  64%|██████▍   | 250/390 [00:14<00:07, 18.07it/s, loss=1.91, accuracy=0.383][A
Epoch 1:  64%|██████▍   | 251/390 [00:14<00:07, 18.12it/s, loss=1.91, accuracy=0.383][A
Epoch 1:  64%|██████▍   | 251/390 [00:14<00:07, 18.12it/s, loss=1.89, accuracy=0.312][A
Epoch 1:  65%|██████▍   | 252/390 [00:14<00:07, 18.12it/s, loss=1.9, accuracy=0.328] [A
Epoch 1:  65%|██████▍   | 253/390 [00:14<00:07, 18.14it/s, loss=1.9, accuracy=0.328][A
Epoch 1:  65%|██████▍   | 253/390 [00:14<00:07, 18.14it/s, loss=1.86, accuracy=0.32][A
Epoch 1:  65%|██████▌  

Epoch 1:  95%|█████████▍| 369/390 [00:21<00:01, 18.07it/s, loss=1.68, accuracy=0.398][A
Epoch 1:  95%|█████████▍| 370/390 [00:21<00:01, 18.07it/s, loss=1.71, accuracy=0.344][A
Epoch 1:  95%|█████████▌| 371/390 [00:21<00:01, 18.04it/s, loss=1.71, accuracy=0.344][A
Epoch 1:  95%|█████████▌| 371/390 [00:21<00:01, 18.04it/s, loss=1.73, accuracy=0.367][A
Epoch 1:  95%|█████████▌| 372/390 [00:21<00:00, 18.04it/s, loss=1.79, accuracy=0.367][A
Epoch 1:  96%|█████████▌| 373/390 [00:21<00:00, 18.07it/s, loss=1.79, accuracy=0.367][A
Epoch 1:  96%|█████████▌| 373/390 [00:21<00:00, 18.07it/s, loss=1.71, accuracy=0.391][A
Epoch 1:  96%|█████████▌| 374/390 [00:21<00:00, 18.07it/s, loss=1.76, accuracy=0.367][A
Epoch 1:  96%|█████████▌| 375/390 [00:21<00:00, 18.05it/s, loss=1.76, accuracy=0.367][A
Epoch 1:  96%|█████████▌| 375/390 [00:21<00:00, 18.05it/s, loss=1.82, accuracy=0.383][A
Epoch 1:  96%|█████████▋| 376/390 [00:21<00:00, 18.05it/s, loss=1.96, accuracy=0.273][A
Epoch 1:  97%|███████

Epoch 2:  26%|██▋       | 103/390 [00:06<00:16, 17.80it/s, loss=1.74, accuracy=0.383][A
Epoch 2:  26%|██▋       | 103/390 [00:06<00:16, 17.80it/s, loss=1.73, accuracy=0.344][A
Epoch 2:  27%|██▋       | 104/390 [00:06<00:16, 17.80it/s, loss=1.67, accuracy=0.344][A
Epoch 2:  27%|██▋       | 105/390 [00:06<00:15, 17.89it/s, loss=1.67, accuracy=0.344][A
Epoch 2:  27%|██▋       | 105/390 [00:06<00:15, 17.89it/s, loss=1.74, accuracy=0.391][A
Epoch 2:  27%|██▋       | 106/390 [00:06<00:15, 17.89it/s, loss=1.67, accuracy=0.461][A
Epoch 2:  27%|██▋       | 107/390 [00:06<00:15, 17.98it/s, loss=1.67, accuracy=0.461][A
Epoch 2:  27%|██▋       | 107/390 [00:06<00:15, 17.98it/s, loss=1.73, accuracy=0.414][A
Epoch 2:  28%|██▊       | 108/390 [00:06<00:15, 17.98it/s, loss=1.67, accuracy=0.359][A
Epoch 2:  28%|██▊       | 109/390 [00:06<00:15, 17.99it/s, loss=1.67, accuracy=0.359][A
Epoch 2:  28%|██▊       | 109/390 [00:06<00:15, 17.99it/s, loss=1.63, accuracy=0.438][A
Epoch 2:  28%|██▊    

Epoch 2:  58%|█████▊    | 225/390 [00:13<00:09, 17.92it/s, loss=1.66, accuracy=0.438][A
Epoch 2:  58%|█████▊    | 226/390 [00:13<00:09, 17.92it/s, loss=1.48, accuracy=0.516][A
Epoch 2:  58%|█████▊    | 227/390 [00:13<00:09, 17.92it/s, loss=1.48, accuracy=0.516][A
Epoch 2:  58%|█████▊    | 227/390 [00:13<00:09, 17.92it/s, loss=1.62, accuracy=0.445][A
Epoch 2:  58%|█████▊    | 228/390 [00:13<00:09, 17.92it/s, loss=1.63, accuracy=0.414][A
Epoch 2:  59%|█████▊    | 229/390 [00:13<00:08, 17.91it/s, loss=1.63, accuracy=0.414][A
Epoch 2:  59%|█████▊    | 229/390 [00:13<00:08, 17.91it/s, loss=1.63, accuracy=0.453][A
Epoch 2:  59%|█████▉    | 230/390 [00:13<00:08, 17.91it/s, loss=1.56, accuracy=0.43] [A
Epoch 2:  59%|█████▉    | 231/390 [00:13<00:08, 17.93it/s, loss=1.56, accuracy=0.43][A
Epoch 2:  59%|█████▉    | 231/390 [00:13<00:08, 17.93it/s, loss=1.6, accuracy=0.398][A
Epoch 2:  59%|█████▉    | 232/390 [00:13<00:08, 17.93it/s, loss=1.59, accuracy=0.406][A
Epoch 2:  60%|█████▉   

Epoch 2:  89%|████████▉ | 348/390 [00:20<00:02, 17.88it/s, loss=1.66, accuracy=0.383][A
Epoch 2:  89%|████████▉ | 349/390 [00:20<00:02, 17.90it/s, loss=1.66, accuracy=0.383][A
Epoch 2:  89%|████████▉ | 349/390 [00:20<00:02, 17.90it/s, loss=1.53, accuracy=0.469][A
Epoch 2:  90%|████████▉ | 350/390 [00:20<00:02, 17.90it/s, loss=1.59, accuracy=0.469][A
Epoch 2:  90%|█████████ | 351/390 [00:20<00:02, 17.88it/s, loss=1.59, accuracy=0.469][A
Epoch 2:  90%|█████████ | 351/390 [00:20<00:02, 17.88it/s, loss=1.55, accuracy=0.445][A
Epoch 2:  90%|█████████ | 352/390 [00:20<00:02, 17.88it/s, loss=1.55, accuracy=0.406][A
Epoch 2:  91%|█████████ | 353/390 [00:20<00:02, 17.87it/s, loss=1.55, accuracy=0.406][A
Epoch 2:  91%|█████████ | 353/390 [00:20<00:02, 17.87it/s, loss=1.67, accuracy=0.43] [A
Epoch 2:  91%|█████████ | 354/390 [00:20<00:02, 17.87it/s, loss=1.63, accuracy=0.383][A
Epoch 2:  91%|█████████ | 355/390 [00:20<00:01, 17.90it/s, loss=1.63, accuracy=0.383][A
Epoch 2:  91%|███████

Epoch 3:  21%|██        | 81/390 [00:05<00:17, 17.91it/s, loss=1.7, accuracy=0.43][A
Epoch 3:  21%|██        | 81/390 [00:05<00:17, 17.91it/s, loss=1.62, accuracy=0.406][A
Epoch 3:  21%|██        | 82/390 [00:05<00:17, 17.91it/s, loss=1.55, accuracy=0.383][A
Epoch 3:  21%|██▏       | 83/390 [00:05<00:17, 17.84it/s, loss=1.55, accuracy=0.383][A
Epoch 3:  21%|██▏       | 83/390 [00:05<00:17, 17.84it/s, loss=1.63, accuracy=0.414][A
Epoch 3:  22%|██▏       | 84/390 [00:05<00:17, 17.84it/s, loss=1.47, accuracy=0.453][A
Epoch 3:  22%|██▏       | 85/390 [00:05<00:17, 17.79it/s, loss=1.47, accuracy=0.453][A
Epoch 3:  22%|██▏       | 85/390 [00:05<00:17, 17.79it/s, loss=1.44, accuracy=0.609][A
Epoch 3:  22%|██▏       | 86/390 [00:05<00:17, 17.79it/s, loss=1.52, accuracy=0.469][A
Epoch 3:  22%|██▏       | 87/390 [00:05<00:17, 17.82it/s, loss=1.52, accuracy=0.469][A
Epoch 3:  22%|██▏       | 87/390 [00:05<00:17, 17.82it/s, loss=1.73, accuracy=0.258][A
Epoch 3:  23%|██▎       | 88/390 [

Epoch 3:  52%|█████▏    | 203/390 [00:12<00:10, 17.37it/s, loss=1.54, accuracy=0.461][A
Epoch 3:  52%|█████▏    | 204/390 [00:12<00:10, 17.37it/s, loss=1.37, accuracy=0.422][A
Epoch 3:  53%|█████▎    | 205/390 [00:12<00:10, 17.37it/s, loss=1.37, accuracy=0.422][A
Epoch 3:  53%|█████▎    | 205/390 [00:12<00:10, 17.37it/s, loss=1.41, accuracy=0.477][A
Epoch 3:  53%|█████▎    | 206/390 [00:12<00:10, 17.37it/s, loss=1.55, accuracy=0.469][A
Epoch 3:  53%|█████▎    | 207/390 [00:12<00:10, 17.47it/s, loss=1.55, accuracy=0.469][A
Epoch 3:  53%|█████▎    | 207/390 [00:12<00:10, 17.47it/s, loss=1.53, accuracy=0.422][A
Epoch 3:  53%|█████▎    | 208/390 [00:12<00:10, 17.47it/s, loss=1.45, accuracy=0.445][A
Epoch 3:  54%|█████▎    | 209/390 [00:12<00:10, 17.43it/s, loss=1.45, accuracy=0.445][A
Epoch 3:  54%|█████▎    | 209/390 [00:12<00:10, 17.43it/s, loss=1.53, accuracy=0.438][A
Epoch 3:  54%|█████▍    | 210/390 [00:12<00:10, 17.43it/s, loss=1.54, accuracy=0.469][A
Epoch 3:  54%|█████▍ 

Epoch 3:  84%|████████▎ | 326/390 [00:19<00:03, 17.86it/s, loss=1.44, accuracy=0.5] [A
Epoch 3:  84%|████████▍ | 327/390 [00:19<00:03, 17.87it/s, loss=1.44, accuracy=0.5][A
Epoch 3:  84%|████████▍ | 327/390 [00:19<00:03, 17.87it/s, loss=1.37, accuracy=0.508][A
Epoch 3:  84%|████████▍ | 328/390 [00:19<00:03, 17.87it/s, loss=1.49, accuracy=0.445][A
Epoch 3:  84%|████████▍ | 329/390 [00:19<00:03, 17.86it/s, loss=1.49, accuracy=0.445][A
Epoch 3:  84%|████████▍ | 329/390 [00:19<00:03, 17.86it/s, loss=1.53, accuracy=0.445][A
Epoch 3:  85%|████████▍ | 330/390 [00:19<00:03, 17.86it/s, loss=1.48, accuracy=0.453][A
Epoch 3:  85%|████████▍ | 331/390 [00:19<00:03, 17.86it/s, loss=1.48, accuracy=0.453][A
Epoch 3:  85%|████████▍ | 331/390 [00:19<00:03, 17.86it/s, loss=1.35, accuracy=0.609][A
Epoch 3:  85%|████████▌ | 332/390 [00:19<00:03, 17.86it/s, loss=1.35, accuracy=0.539][A
Epoch 3:  85%|████████▌ | 333/390 [00:19<00:03, 17.77it/s, loss=1.35, accuracy=0.539][A
Epoch 3:  85%|████████▌ 

Epoch 4:  15%|█▌        | 59/390 [00:04<00:19, 16.73it/s, loss=1.45, accuracy=0.469][A
Epoch 4:  15%|█▌        | 59/390 [00:04<00:19, 16.73it/s, loss=1.41, accuracy=0.516][A
Epoch 4:  15%|█▌        | 60/390 [00:04<00:19, 16.73it/s, loss=1.5, accuracy=0.414] [A
Epoch 4:  16%|█▌        | 61/390 [00:04<00:19, 16.73it/s, loss=1.5, accuracy=0.414][A
Epoch 4:  16%|█▌        | 61/390 [00:04<00:19, 16.73it/s, loss=1.42, accuracy=0.469][A
Epoch 4:  16%|█▌        | 62/390 [00:04<00:19, 16.73it/s, loss=1.45, accuracy=0.484][A
Epoch 4:  16%|█▌        | 63/390 [00:04<00:19, 16.76it/s, loss=1.45, accuracy=0.484][A
Epoch 4:  16%|█▌        | 63/390 [00:04<00:19, 16.76it/s, loss=1.45, accuracy=0.453][A
Epoch 4:  16%|█▋        | 64/390 [00:04<00:19, 16.76it/s, loss=1.46, accuracy=0.438][A
Epoch 4:  17%|█▋        | 65/390 [00:04<00:19, 16.76it/s, loss=1.46, accuracy=0.438][A
Epoch 4:  17%|█▋        | 65/390 [00:04<00:19, 16.76it/s, loss=1.36, accuracy=0.516][A
Epoch 4:  17%|█▋        | 66/390 

Epoch 4:  46%|████▋     | 181/390 [00:11<00:11, 17.80it/s, loss=1.29, accuracy=0.516][A
Epoch 4:  47%|████▋     | 182/390 [00:11<00:11, 17.80it/s, loss=1.44, accuracy=0.461][A
Epoch 4:  47%|████▋     | 183/390 [00:11<00:11, 17.78it/s, loss=1.44, accuracy=0.461][A
Epoch 4:  47%|████▋     | 183/390 [00:11<00:11, 17.78it/s, loss=1.22, accuracy=0.602][A
Epoch 4:  47%|████▋     | 184/390 [00:11<00:11, 17.78it/s, loss=1.36, accuracy=0.5]  [A
Epoch 4:  47%|████▋     | 185/390 [00:11<00:11, 17.77it/s, loss=1.36, accuracy=0.5][A
Epoch 4:  47%|████▋     | 185/390 [00:11<00:11, 17.77it/s, loss=1.3, accuracy=0.508][A
Epoch 4:  48%|████▊     | 186/390 [00:11<00:11, 17.77it/s, loss=1.33, accuracy=0.547][A
Epoch 4:  48%|████▊     | 187/390 [00:11<00:11, 17.82it/s, loss=1.33, accuracy=0.547][A
Epoch 4:  48%|████▊     | 187/390 [00:11<00:11, 17.82it/s, loss=1.29, accuracy=0.547][A
Epoch 4:  48%|████▊     | 188/390 [00:11<00:11, 17.82it/s, loss=1.33, accuracy=0.453][A
Epoch 4:  48%|████▊     

Epoch 4:  78%|███████▊  | 304/390 [00:18<00:04, 17.93it/s, loss=1.19, accuracy=0.562][A
Epoch 4:  78%|███████▊  | 305/390 [00:18<00:04, 17.87it/s, loss=1.19, accuracy=0.562][A
Epoch 4:  78%|███████▊  | 305/390 [00:18<00:04, 17.87it/s, loss=1.36, accuracy=0.5]  [A
Epoch 4:  78%|███████▊  | 306/390 [00:18<00:04, 17.87it/s, loss=1.38, accuracy=0.516][A
Epoch 4:  79%|███████▊  | 307/390 [00:18<00:04, 17.78it/s, loss=1.38, accuracy=0.516][A
Epoch 4:  79%|███████▊  | 307/390 [00:18<00:04, 17.78it/s, loss=1.31, accuracy=0.539][A
Epoch 4:  79%|███████▉  | 308/390 [00:18<00:04, 17.78it/s, loss=1.32, accuracy=0.57] [A
Epoch 4:  79%|███████▉  | 309/390 [00:18<00:04, 17.78it/s, loss=1.32, accuracy=0.57][A
Epoch 4:  79%|███████▉  | 309/390 [00:18<00:04, 17.78it/s, loss=1.45, accuracy=0.484][A
Epoch 4:  79%|███████▉  | 310/390 [00:18<00:04, 17.78it/s, loss=1.42, accuracy=0.461][A
Epoch 4:  80%|███████▉  | 311/390 [00:18<00:04, 17.79it/s, loss=1.42, accuracy=0.461][A
Epoch 4:  80%|███████▉

Epoch 5:   9%|▉         | 37/390 [00:02<00:21, 16.22it/s, loss=1.34, accuracy=0.578][A
Epoch 5:   9%|▉         | 37/390 [00:02<00:21, 16.22it/s, loss=1.18, accuracy=0.578][A
Epoch 5:  10%|▉         | 38/390 [00:02<00:21, 16.22it/s, loss=1.41, accuracy=0.461][A
Epoch 5:  10%|█         | 39/390 [00:03<00:21, 16.25it/s, loss=1.41, accuracy=0.461][A
Epoch 5:  10%|█         | 39/390 [00:03<00:21, 16.25it/s, loss=1.48, accuracy=0.484][A
Epoch 5:  10%|█         | 40/390 [00:03<00:21, 16.25it/s, loss=1.46, accuracy=0.5]  [A
Epoch 5:  11%|█         | 41/390 [00:03<00:21, 16.42it/s, loss=1.46, accuracy=0.5][A
Epoch 5:  11%|█         | 41/390 [00:03<00:21, 16.42it/s, loss=1.42, accuracy=0.523][A
Epoch 5:  11%|█         | 42/390 [00:03<00:21, 16.42it/s, loss=1.35, accuracy=0.492][A
Epoch 5:  11%|█         | 43/390 [00:03<00:21, 16.50it/s, loss=1.35, accuracy=0.492][A
Epoch 5:  11%|█         | 43/390 [00:03<00:21, 16.50it/s, loss=1.38, accuracy=0.508][A
Epoch 5:  11%|█▏        | 44/390 [

Epoch 5:  41%|████      | 160/390 [00:09<00:13, 17.51it/s, loss=1.36, accuracy=0.484][A
Epoch 5:  41%|████▏     | 161/390 [00:09<00:13, 17.53it/s, loss=1.36, accuracy=0.484][A
Epoch 5:  41%|████▏     | 161/390 [00:09<00:13, 17.53it/s, loss=1.37, accuracy=0.484][A
Epoch 5:  42%|████▏     | 162/390 [00:09<00:13, 17.53it/s, loss=1.28, accuracy=0.508][A
Epoch 5:  42%|████▏     | 163/390 [00:10<00:12, 17.55it/s, loss=1.28, accuracy=0.508][A
Epoch 5:  42%|████▏     | 163/390 [00:10<00:12, 17.55it/s, loss=1.28, accuracy=0.547][A
Epoch 5:  42%|████▏     | 164/390 [00:10<00:12, 17.55it/s, loss=1.24, accuracy=0.555][A
Epoch 5:  42%|████▏     | 165/390 [00:10<00:12, 17.61it/s, loss=1.24, accuracy=0.555][A
Epoch 5:  42%|████▏     | 165/390 [00:10<00:12, 17.61it/s, loss=1.26, accuracy=0.484][A
Epoch 5:  43%|████▎     | 166/390 [00:10<00:12, 17.61it/s, loss=1.26, accuracy=0.547][A
Epoch 5:  43%|████▎     | 167/390 [00:10<00:12, 17.66it/s, loss=1.26, accuracy=0.547][A
Epoch 5:  43%|████▎  

Epoch 5:  73%|███████▎  | 283/390 [00:16<00:06, 17.73it/s, loss=1.27, accuracy=0.547][A
Epoch 5:  73%|███████▎  | 283/390 [00:16<00:06, 17.73it/s, loss=1.37, accuracy=0.523][A
Epoch 5:  73%|███████▎  | 284/390 [00:16<00:05, 17.73it/s, loss=1.33, accuracy=0.508][A
Epoch 5:  73%|███████▎  | 285/390 [00:16<00:05, 17.76it/s, loss=1.33, accuracy=0.508][A
Epoch 5:  73%|███████▎  | 285/390 [00:16<00:05, 17.76it/s, loss=1.21, accuracy=0.555][A
Epoch 5:  73%|███████▎  | 286/390 [00:16<00:05, 17.76it/s, loss=1.17, accuracy=0.586][A
Epoch 5:  74%|███████▎  | 287/390 [00:17<00:05, 17.78it/s, loss=1.17, accuracy=0.586][A
Epoch 5:  74%|███████▎  | 287/390 [00:17<00:05, 17.78it/s, loss=1.27, accuracy=0.562][A
Epoch 5:  74%|███████▍  | 288/390 [00:17<00:05, 17.78it/s, loss=1.23, accuracy=0.547][A
Epoch 5:  74%|███████▍  | 289/390 [00:17<00:05, 17.77it/s, loss=1.23, accuracy=0.547][A
Epoch 5:  74%|███████▍  | 289/390 [00:17<00:05, 17.77it/s, loss=1.42, accuracy=0.539][A
Epoch 5:  74%|███████

Epoch 6:   4%|▍         | 15/390 [00:01<00:41,  9.06it/s, loss=1.26, accuracy=0.523][A
Epoch 6:   4%|▍         | 15/390 [00:01<00:41,  9.06it/s, loss=1.19, accuracy=0.594][A
Epoch 6:   4%|▍         | 16/390 [00:01<00:41,  9.06it/s, loss=1.43, accuracy=0.453][A
Epoch 6:   4%|▍         | 17/390 [00:01<00:35, 10.58it/s, loss=1.43, accuracy=0.453][A
Epoch 6:   4%|▍         | 17/390 [00:01<00:35, 10.58it/s, loss=1.13, accuracy=0.602][A
Epoch 6:   5%|▍         | 18/390 [00:01<00:35, 10.58it/s, loss=1.27, accuracy=0.547][A
Epoch 6:   5%|▍         | 19/390 [00:01<00:30, 12.18it/s, loss=1.27, accuracy=0.547][A
Epoch 6:   5%|▍         | 19/390 [00:01<00:30, 12.18it/s, loss=1.23, accuracy=0.539][A
Epoch 6:   5%|▌         | 20/390 [00:01<00:30, 12.18it/s, loss=1.26, accuracy=0.523][A
Epoch 6:   5%|▌         | 21/390 [00:01<00:27, 13.64it/s, loss=1.26, accuracy=0.523][A
Epoch 6:   5%|▌         | 21/390 [00:01<00:27, 13.64it/s, loss=1.17, accuracy=0.594][A
Epoch 6:   6%|▌         | 22/390

Epoch 6:  35%|███▌      | 138/390 [00:08<00:14, 17.62it/s, loss=1.19, accuracy=0.586][A
Epoch 6:  36%|███▌      | 139/390 [00:08<00:14, 17.62it/s, loss=1.19, accuracy=0.586][A
Epoch 6:  36%|███▌      | 139/390 [00:08<00:14, 17.62it/s, loss=1.45, accuracy=0.508][A
Epoch 6:  36%|███▌      | 140/390 [00:08<00:14, 17.62it/s, loss=1.12, accuracy=0.641][A
Epoch 6:  36%|███▌      | 141/390 [00:08<00:14, 17.63it/s, loss=1.12, accuracy=0.641][A
Epoch 6:  36%|███▌      | 141/390 [00:08<00:14, 17.63it/s, loss=1.25, accuracy=0.547][A
Epoch 6:  36%|███▋      | 142/390 [00:08<00:14, 17.63it/s, loss=1.2, accuracy=0.57]  [A
Epoch 6:  37%|███▋      | 143/390 [00:08<00:13, 17.68it/s, loss=1.2, accuracy=0.57][A
Epoch 6:  37%|███▋      | 143/390 [00:08<00:13, 17.68it/s, loss=1.1, accuracy=0.617][A
Epoch 6:  37%|███▋      | 144/390 [00:08<00:13, 17.68it/s, loss=1.23, accuracy=0.594][A
Epoch 6:  37%|███▋      | 145/390 [00:08<00:13, 17.57it/s, loss=1.23, accuracy=0.594][A
Epoch 6:  37%|███▋      

Epoch 6:  67%|██████▋   | 261/390 [00:15<00:07, 17.79it/s, loss=1.13, accuracy=0.633][A
Epoch 6:  67%|██████▋   | 261/390 [00:15<00:07, 17.79it/s, loss=1.37, accuracy=0.508][A
Epoch 6:  67%|██████▋   | 262/390 [00:15<00:07, 17.79it/s, loss=1.31, accuracy=0.516][A
Epoch 6:  67%|██████▋   | 263/390 [00:15<00:07, 17.78it/s, loss=1.31, accuracy=0.516][A
Epoch 6:  67%|██████▋   | 263/390 [00:15<00:07, 17.78it/s, loss=1.29, accuracy=0.516][A
Epoch 6:  68%|██████▊   | 264/390 [00:15<00:07, 17.78it/s, loss=1.24, accuracy=0.484][A
Epoch 6:  68%|██████▊   | 265/390 [00:15<00:07, 17.73it/s, loss=1.24, accuracy=0.484][A
Epoch 6:  68%|██████▊   | 265/390 [00:15<00:07, 17.73it/s, loss=1.25, accuracy=0.586][A
Epoch 6:  68%|██████▊   | 266/390 [00:15<00:06, 17.73it/s, loss=1.27, accuracy=0.562][A
Epoch 6:  68%|██████▊   | 267/390 [00:15<00:06, 17.75it/s, loss=1.27, accuracy=0.562][A
Epoch 6:  68%|██████▊   | 267/390 [00:15<00:06, 17.75it/s, loss=1.08, accuracy=0.664][A
Epoch 6:  69%|██████▊

Epoch 6:  98%|█████████▊| 383/390 [00:21<00:00, 17.97it/s, loss=1.22, accuracy=0.562][A
Epoch 6:  98%|█████████▊| 384/390 [00:22<00:00, 17.97it/s, loss=1.21, accuracy=0.578][A
Epoch 6:  99%|█████████▊| 385/390 [00:22<00:00, 17.95it/s, loss=1.21, accuracy=0.578][A
Epoch 6:  99%|█████████▊| 385/390 [00:22<00:00, 17.95it/s, loss=1.04, accuracy=0.609][A
Epoch 6:  99%|█████████▉| 386/390 [00:22<00:00, 17.95it/s, loss=1.2, accuracy=0.539] [A
Epoch 6:  99%|█████████▉| 387/390 [00:22<00:00, 17.93it/s, loss=1.2, accuracy=0.539][A
Epoch 6:  99%|█████████▉| 387/390 [00:22<00:00, 17.93it/s, loss=1.14, accuracy=0.633][A
Epoch 6:  99%|█████████▉| 388/390 [00:22<00:00, 17.93it/s, loss=1.19, accuracy=0.578][A
Epoch 6: 100%|█████████▉| 389/390 [00:22<00:00, 17.90it/s, loss=1.19, accuracy=0.578][A
Epoch 6: 100%|█████████▉| 389/390 [00:22<00:00, 17.90it/s, loss=1.24, accuracy=0.523][A
Epoch 6: 100%|██████████| 390/390 [00:24<00:00, 15.91it/s, loss=1.22, accuracy=0.564, val_loss=1.25, val_accura

Epoch 7:  30%|██▉       | 116/390 [00:07<00:15, 17.90it/s, loss=1.23, accuracy=0.508][A
Epoch 7:  30%|███       | 117/390 [00:07<00:15, 17.91it/s, loss=1.23, accuracy=0.508][A
Epoch 7:  30%|███       | 117/390 [00:07<00:15, 17.91it/s, loss=1.24, accuracy=0.555][A
Epoch 7:  30%|███       | 118/390 [00:07<00:15, 17.91it/s, loss=1.24, accuracy=0.57] [A
Epoch 7:  31%|███       | 119/390 [00:07<00:15, 17.93it/s, loss=1.24, accuracy=0.57][A
Epoch 7:  31%|███       | 119/390 [00:07<00:15, 17.93it/s, loss=1.26, accuracy=0.508][A
Epoch 7:  31%|███       | 120/390 [00:07<00:15, 17.93it/s, loss=1.15, accuracy=0.586][A
Epoch 7:  31%|███       | 121/390 [00:07<00:15, 17.93it/s, loss=1.15, accuracy=0.586][A
Epoch 7:  31%|███       | 121/390 [00:07<00:15, 17.93it/s, loss=1.15, accuracy=0.617][A
Epoch 7:  31%|███▏      | 122/390 [00:07<00:14, 17.93it/s, loss=1.3, accuracy=0.594] [A
Epoch 7:  32%|███▏      | 123/390 [00:07<00:14, 18.03it/s, loss=1.3, accuracy=0.594][A
Epoch 7:  32%|███▏     

Epoch 7:  61%|██████▏   | 239/390 [00:13<00:08, 17.87it/s, loss=1.11, accuracy=0.633][A
Epoch 7:  61%|██████▏   | 239/390 [00:13<00:08, 17.87it/s, loss=1.1, accuracy=0.594] [A
Epoch 7:  62%|██████▏   | 240/390 [00:14<00:08, 17.87it/s, loss=1.32, accuracy=0.5] [A
Epoch 7:  62%|██████▏   | 241/390 [00:14<00:08, 17.90it/s, loss=1.32, accuracy=0.5][A
Epoch 7:  62%|██████▏   | 241/390 [00:14<00:08, 17.90it/s, loss=1.31, accuracy=0.562][A
Epoch 7:  62%|██████▏   | 242/390 [00:14<00:08, 17.90it/s, loss=1.12, accuracy=0.641][A
Epoch 7:  62%|██████▏   | 243/390 [00:14<00:08, 17.97it/s, loss=1.12, accuracy=0.641][A
Epoch 7:  62%|██████▏   | 243/390 [00:14<00:08, 17.97it/s, loss=1.03, accuracy=0.625][A
Epoch 7:  63%|██████▎   | 244/390 [00:14<00:08, 17.97it/s, loss=1.08, accuracy=0.641][A
Epoch 7:  63%|██████▎   | 245/390 [00:14<00:08, 17.97it/s, loss=1.08, accuracy=0.641][A
Epoch 7:  63%|██████▎   | 245/390 [00:14<00:08, 17.97it/s, loss=1.06, accuracy=0.648][A
Epoch 7:  63%|██████▎   

Epoch 7:  93%|█████████▎| 361/390 [00:20<00:01, 17.82it/s, loss=1.03, accuracy=0.656][A
Epoch 7:  93%|█████████▎| 362/390 [00:20<00:01, 17.82it/s, loss=1.17, accuracy=0.57] [A
Epoch 7:  93%|█████████▎| 363/390 [00:20<00:01, 17.85it/s, loss=1.17, accuracy=0.57][A
Epoch 7:  93%|█████████▎| 363/390 [00:20<00:01, 17.85it/s, loss=1.08, accuracy=0.609][A
Epoch 7:  93%|█████████▎| 364/390 [00:20<00:01, 17.85it/s, loss=1.46, accuracy=0.453][A
Epoch 7:  94%|█████████▎| 365/390 [00:21<00:01, 17.89it/s, loss=1.46, accuracy=0.453][A
Epoch 7:  94%|█████████▎| 365/390 [00:21<00:01, 17.89it/s, loss=1.1, accuracy=0.555] [A
Epoch 7:  94%|█████████▍| 366/390 [00:21<00:01, 17.89it/s, loss=1.06, accuracy=0.594][A
Epoch 7:  94%|█████████▍| 367/390 [00:21<00:01, 17.93it/s, loss=1.06, accuracy=0.594][A
Epoch 7:  94%|█████████▍| 367/390 [00:21<00:01, 17.93it/s, loss=1.14, accuracy=0.617][A
Epoch 7:  94%|█████████▍| 368/390 [00:21<00:01, 17.93it/s, loss=1.15, accuracy=0.555][A
Epoch 7:  95%|████████

Epoch 8:  24%|██▍       | 94/390 [00:05<00:16, 17.89it/s, loss=1.1, accuracy=0.594] [A
Epoch 8:  24%|██▍       | 95/390 [00:05<00:16, 17.91it/s, loss=1.1, accuracy=0.594][A
Epoch 8:  24%|██▍       | 95/390 [00:05<00:16, 17.91it/s, loss=1.13, accuracy=0.648][A
Epoch 8:  25%|██▍       | 96/390 [00:06<00:16, 17.91it/s, loss=1.06, accuracy=0.617][A
Epoch 8:  25%|██▍       | 97/390 [00:06<00:16, 17.90it/s, loss=1.06, accuracy=0.617][A
Epoch 8:  25%|██▍       | 97/390 [00:06<00:16, 17.90it/s, loss=1.15, accuracy=0.609][A
Epoch 8:  25%|██▌       | 98/390 [00:06<00:16, 17.90it/s, loss=1.13, accuracy=0.617][A
Epoch 8:  25%|██▌       | 99/390 [00:06<00:16, 17.91it/s, loss=1.13, accuracy=0.617][A
Epoch 8:  25%|██▌       | 99/390 [00:06<00:16, 17.91it/s, loss=1.08, accuracy=0.586][A
Epoch 8:  26%|██▌       | 100/390 [00:06<00:16, 17.91it/s, loss=1.16, accuracy=0.578][A
Epoch 8:  26%|██▌       | 101/390 [00:06<00:16, 17.91it/s, loss=1.16, accuracy=0.578][A
Epoch 8:  26%|██▌       | 101/3

Epoch 8:  55%|█████▌    | 216/390 [00:12<00:09, 17.62it/s, loss=1.03, accuracy=0.648][A
Epoch 8:  56%|█████▌    | 217/390 [00:12<00:09, 17.65it/s, loss=1.03, accuracy=0.648][A
Epoch 8:  56%|█████▌    | 217/390 [00:12<00:09, 17.65it/s, loss=1.09, accuracy=0.625][A
Epoch 8:  56%|█████▌    | 218/390 [00:12<00:09, 17.65it/s, loss=1.12, accuracy=0.594][A
Epoch 8:  56%|█████▌    | 219/390 [00:12<00:09, 17.62it/s, loss=1.12, accuracy=0.594][A
Epoch 8:  56%|█████▌    | 219/390 [00:12<00:09, 17.62it/s, loss=0.979, accuracy=0.664][A
Epoch 8:  56%|█████▋    | 220/390 [00:13<00:09, 17.62it/s, loss=1.08, accuracy=0.625] [A
Epoch 8:  57%|█████▋    | 221/390 [00:13<00:09, 17.71it/s, loss=1.08, accuracy=0.625][A
Epoch 8:  57%|█████▋    | 221/390 [00:13<00:09, 17.71it/s, loss=1.08, accuracy=0.594][A
Epoch 8:  57%|█████▋    | 222/390 [00:13<00:09, 17.71it/s, loss=1.04, accuracy=0.68] [A
Epoch 8:  57%|█████▋    | 223/390 [00:13<00:09, 17.73it/s, loss=1.04, accuracy=0.68][A
Epoch 8:  57%|█████▋

Epoch 8:  87%|████████▋ | 338/390 [00:19<00:02, 17.80it/s, loss=1.15, accuracy=0.547] [A
Epoch 8:  87%|████████▋ | 339/390 [00:19<00:02, 17.77it/s, loss=1.15, accuracy=0.547][A
Epoch 8:  87%|████████▋ | 339/390 [00:19<00:02, 17.77it/s, loss=1.09, accuracy=0.625][A
Epoch 8:  87%|████████▋ | 340/390 [00:19<00:02, 17.77it/s, loss=1.14, accuracy=0.602][A
Epoch 8:  87%|████████▋ | 341/390 [00:19<00:02, 17.82it/s, loss=1.14, accuracy=0.602][A
Epoch 8:  87%|████████▋ | 341/390 [00:19<00:02, 17.82it/s, loss=1.08, accuracy=0.602][A
Epoch 8:  88%|████████▊ | 342/390 [00:19<00:02, 17.82it/s, loss=1.1, accuracy=0.633] [A
Epoch 8:  88%|████████▊ | 343/390 [00:19<00:02, 17.85it/s, loss=1.1, accuracy=0.633][A
Epoch 8:  88%|████████▊ | 343/390 [00:19<00:02, 17.85it/s, loss=1.11, accuracy=0.648][A
Epoch 8:  88%|████████▊ | 344/390 [00:19<00:02, 17.85it/s, loss=1.11, accuracy=0.586][A
Epoch 8:  88%|████████▊ | 345/390 [00:20<00:02, 17.87it/s, loss=1.11, accuracy=0.586][A
Epoch 8:  88%|███████

Epoch 9:  18%|█▊        | 70/390 [00:04<00:17, 17.86it/s, loss=1.08, accuracy=0.648][A
Epoch 9:  18%|█▊        | 71/390 [00:04<00:17, 17.86it/s, loss=1.08, accuracy=0.648][A
Epoch 9:  18%|█▊        | 71/390 [00:04<00:17, 17.86it/s, loss=1.16, accuracy=0.609][A
Epoch 9:  18%|█▊        | 72/390 [00:04<00:17, 17.86it/s, loss=1.03, accuracy=0.641][A
Epoch 9:  19%|█▊        | 73/390 [00:04<00:17, 17.87it/s, loss=1.03, accuracy=0.641][A
Epoch 9:  19%|█▊        | 73/390 [00:04<00:17, 17.87it/s, loss=0.955, accuracy=0.711][A
Epoch 9:  19%|█▉        | 74/390 [00:04<00:17, 17.87it/s, loss=1.09, accuracy=0.602] [A
Epoch 9:  19%|█▉        | 75/390 [00:04<00:17, 17.89it/s, loss=1.09, accuracy=0.602][A
Epoch 9:  19%|█▉        | 75/390 [00:04<00:17, 17.89it/s, loss=0.914, accuracy=0.688][A
Epoch 9:  19%|█▉        | 76/390 [00:04<00:17, 17.89it/s, loss=1.04, accuracy=0.602] [A
Epoch 9:  20%|█▉        | 77/390 [00:05<00:17, 17.93it/s, loss=1.04, accuracy=0.602][A
Epoch 9:  20%|█▉        | 77

Epoch 9:  49%|████▉     | 192/390 [00:11<00:11, 17.89it/s, loss=0.952, accuracy=0.648][A
Epoch 9:  49%|████▉     | 193/390 [00:11<00:11, 17.90it/s, loss=0.952, accuracy=0.648][A
Epoch 9:  49%|████▉     | 193/390 [00:11<00:11, 17.90it/s, loss=0.885, accuracy=0.68] [A
Epoch 9:  50%|████▉     | 194/390 [00:11<00:10, 17.90it/s, loss=1.13, accuracy=0.57] [A
Epoch 9:  50%|█████     | 195/390 [00:11<00:10, 17.93it/s, loss=1.13, accuracy=0.57][A
Epoch 9:  50%|█████     | 195/390 [00:11<00:10, 17.93it/s, loss=1.05, accuracy=0.633][A
Epoch 9:  50%|█████     | 196/390 [00:11<00:10, 17.93it/s, loss=0.847, accuracy=0.719][A
Epoch 9:  51%|█████     | 197/390 [00:11<00:10, 18.00it/s, loss=0.847, accuracy=0.719][A
Epoch 9:  51%|█████     | 197/390 [00:11<00:10, 18.00it/s, loss=1.03, accuracy=0.648] [A
Epoch 9:  51%|█████     | 198/390 [00:11<00:10, 18.00it/s, loss=1.06, accuracy=0.664][A
Epoch 9:  51%|█████     | 199/390 [00:11<00:10, 17.95it/s, loss=1.06, accuracy=0.664][A
Epoch 9:  51%|██

Epoch 9:  80%|████████  | 313/390 [00:18<00:04, 17.69it/s, loss=0.904, accuracy=0.703][A
Epoch 9:  81%|████████  | 314/390 [00:18<00:04, 17.69it/s, loss=1.01, accuracy=0.648] [A
Epoch 9:  81%|████████  | 315/390 [00:18<00:04, 17.69it/s, loss=1.01, accuracy=0.648][A
Epoch 9:  81%|████████  | 315/390 [00:18<00:04, 17.69it/s, loss=0.971, accuracy=0.656][A
Epoch 9:  81%|████████  | 316/390 [00:18<00:04, 17.69it/s, loss=1.03, accuracy=0.648] [A
Epoch 9:  81%|████████▏ | 317/390 [00:18<00:04, 17.70it/s, loss=1.03, accuracy=0.648][A
Epoch 9:  81%|████████▏ | 317/390 [00:18<00:04, 17.70it/s, loss=0.935, accuracy=0.727][A
Epoch 9:  82%|████████▏ | 318/390 [00:18<00:04, 17.70it/s, loss=0.875, accuracy=0.695][A
Epoch 9:  82%|████████▏ | 319/390 [00:18<00:03, 17.80it/s, loss=0.875, accuracy=0.695][A
Epoch 9:  82%|████████▏ | 319/390 [00:18<00:03, 17.80it/s, loss=1.08, accuracy=0.617] [A
Epoch 9:  82%|████████▏ | 320/390 [00:18<00:03, 17.80it/s, loss=1.11, accuracy=0.609][A
Epoch 9:  82%

Epoch 10:  11%|█▏        | 44/390 [00:03<00:19, 17.63it/s, loss=0.926, accuracy=0.672][A
Epoch 10:  12%|█▏        | 45/390 [00:03<00:19, 17.77it/s, loss=0.926, accuracy=0.672][A
Epoch 10:  12%|█▏        | 45/390 [00:03<00:19, 17.77it/s, loss=0.933, accuracy=0.672][A
Epoch 10:  12%|█▏        | 46/390 [00:03<00:19, 17.77it/s, loss=0.984, accuracy=0.617][A
Epoch 10:  12%|█▏        | 47/390 [00:03<00:19, 17.80it/s, loss=0.984, accuracy=0.617][A
Epoch 10:  12%|█▏        | 47/390 [00:03<00:19, 17.80it/s, loss=0.861, accuracy=0.703][A
Epoch 10:  12%|█▏        | 48/390 [00:03<00:19, 17.80it/s, loss=0.986, accuracy=0.656][A
Epoch 10:  13%|█▎        | 49/390 [00:03<00:19, 17.83it/s, loss=0.986, accuracy=0.656][A
Epoch 10:  13%|█▎        | 49/390 [00:03<00:19, 17.83it/s, loss=0.899, accuracy=0.68] [A
Epoch 10:  13%|█▎        | 50/390 [00:03<00:19, 17.83it/s, loss=0.924, accuracy=0.711][A
Epoch 10:  13%|█▎        | 51/390 [00:03<00:19, 17.78it/s, loss=0.924, accuracy=0.711][A
Epoch 10: 

Epoch 10:  42%|████▏     | 165/390 [00:10<00:12, 18.04it/s, loss=0.94, accuracy=0.648][A
Epoch 10:  42%|████▏     | 165/390 [00:10<00:12, 18.04it/s, loss=0.899, accuracy=0.68][A
Epoch 10:  43%|████▎     | 166/390 [00:10<00:12, 18.04it/s, loss=0.907, accuracy=0.664][A
Epoch 10:  43%|████▎     | 167/390 [00:10<00:12, 18.06it/s, loss=0.907, accuracy=0.664][A
Epoch 10:  43%|████▎     | 167/390 [00:10<00:12, 18.06it/s, loss=1.01, accuracy=0.688] [A
Epoch 10:  43%|████▎     | 168/390 [00:10<00:12, 18.06it/s, loss=1, accuracy=0.648]   [A
Epoch 10:  43%|████▎     | 169/390 [00:10<00:12, 18.09it/s, loss=1, accuracy=0.648][A
Epoch 10:  43%|████▎     | 169/390 [00:10<00:12, 18.09it/s, loss=0.83, accuracy=0.672][A
Epoch 10:  44%|████▎     | 170/390 [00:10<00:12, 18.09it/s, loss=0.911, accuracy=0.688][A
Epoch 10:  44%|████▍     | 171/390 [00:10<00:12, 18.11it/s, loss=0.911, accuracy=0.688][A
Epoch 10:  44%|████▍     | 171/390 [00:10<00:12, 18.11it/s, loss=0.924, accuracy=0.68] [A
Epoch 1

Epoch 10:  73%|███████▎  | 285/390 [00:16<00:05, 17.63it/s, loss=0.94, accuracy=0.672][A
Epoch 10:  73%|███████▎  | 285/390 [00:16<00:05, 17.63it/s, loss=0.929, accuracy=0.688][A
Epoch 10:  73%|███████▎  | 286/390 [00:16<00:05, 17.63it/s, loss=0.848, accuracy=0.734][A
Epoch 10:  74%|███████▎  | 287/390 [00:16<00:05, 17.62it/s, loss=0.848, accuracy=0.734][A
Epoch 10:  74%|███████▎  | 287/390 [00:16<00:05, 17.62it/s, loss=0.892, accuracy=0.695][A
Epoch 10:  74%|███████▍  | 288/390 [00:16<00:05, 17.62it/s, loss=0.891, accuracy=0.688][A
Epoch 10:  74%|███████▍  | 289/390 [00:16<00:05, 17.65it/s, loss=0.891, accuracy=0.688][A
Epoch 10:  74%|███████▍  | 289/390 [00:16<00:05, 17.65it/s, loss=1.12, accuracy=0.586] [A
Epoch 10:  74%|███████▍  | 290/390 [00:17<00:05, 17.65it/s, loss=0.906, accuracy=0.727][A
Epoch 10:  75%|███████▍  | 291/390 [00:17<00:05, 17.68it/s, loss=0.906, accuracy=0.727][A
Epoch 10:  75%|███████▍  | 291/390 [00:17<00:05, 17.68it/s, loss=1.02, accuracy=0.617] [A


Finished.
final validation accuracy:





{'val_accuracy': 0.5990584935897436}


# TRAIN MODEL USING ENTROPY SGD (ESGD)

In [19]:
def entropy_training(model, optimiser, loss_fn, x, y, epoch):
    
    model.train()
    y_pred = model(x)
    
    def helper():
        def feval():
            #x, y = Variable(x), Variable(y.squeeze())
            bsz = x.size(0)
            optimiser.zero_grad()
            yh = model(x)
            f = loss_fn.forward(yh, y)
            f.backward()

            yp = yh.argmax(axis=1)
            prec1 = 100*torch.sum(yp == y)//bsz
            err = 100.-prec1.item()

            return (f.data.item(), err)
        return feval

    loss, err = optimiser.step(helper(), model, loss_fn)
    loss = torch.tensor(loss)
    return loss, y_pred

In [20]:
if TrainESGD:
    ## initialize model
    model_ESGD = Net().to(DEVICE)
    ## training parameters
    lr = 0.01 
    l2 = 0.0 #l2 regularization
    L = 0    #langevin iterations
    gamma = 1e-4 
    scoping = 1e-3
    noise = 1e-4
    loss_fn = nn.CrossEntropyLoss()
    epochs = 5
    optimiser = ESGD_optim.EntropySGD(model_ESGD.parameters(),
            config = dict(lr=lr, momentum=0.9, nesterov=True, weight_decay=l2,
            L=L, eps=noise, g0=gamma, g1=scoping))
    ## train model
    history_natural = olympic.fit(
        model_ESGD,
        optimiser,
        loss_fn,
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=entropy_training,
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5)
        ]
    )
    ## verify validation accuracy
    print('final validation accuracy:')
    valacc = olympic.evaluate(model_ESGD, val_loader, metrics=['accuracy'],
                         prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    print(valacc['val_accuracy'])
    ## save trained model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_ESGD_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_ESGD,modelname)


Epoch 1:   0%|          | 0/390 [00:00<?, ?it/s][A

Begin training...


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha)

Epoch 1:   0%|          | 1/390 [00:00<04:35,  1.41it/s][A
Epoch 1:   0%|          | 1/390 [00:00<04:35,  1.41it/s, loss=2.52, accuracy=0.0938][A
Epoch 1:   1%|          | 2/390 [00:00<04:34,  1.41it/s, loss=2.28, accuracy=0.164] [A
Epoch 1:   1%|          | 3/390 [00:00<03:21,  1.92it/s, loss=2.28, accuracy=0.164][A
Epoch 1:   1%|          | 3/390 [00:00<03:21,  1.92it/s, loss=2.28, accuracy=0.195][A
Epoch 1:   1%|          | 4/390 [00:00<03:20,  1.92it/s, loss=2.28, accuracy=0.172][A
Epoch 1:   1%|▏         | 5/390 [00:01<02:28,  2.59it/s, loss=2.28, accuracy=0.172][A
Epoch 1:   1%|▏         | 5/390 [00:01<02:28,  2.59it/s, loss=2.12, accuracy=0.203][A
Epoch 1:   2%|▏         | 6/390 [00:01<02:28,  2.59it/s, loss=2.16, accuracy=0.148][A
Epoch 1:   2%|▏         | 7/390 [00:01<01:51,  3.43it/s, loss=2.16, accuracy=0.148][A
Epoch 1:   2%|▏         | 7

Epoch 1:  32%|███▏      | 123/390 [00:09<00:19, 13.70it/s, loss=1.34, accuracy=0.531][A
Epoch 1:  32%|███▏      | 123/390 [00:09<00:19, 13.70it/s, loss=1.21, accuracy=0.609][A
Epoch 1:  32%|███▏      | 124/390 [00:09<00:19, 13.70it/s, loss=1.25, accuracy=0.5]  [A
Epoch 1:  32%|███▏      | 125/390 [00:09<00:19, 13.70it/s, loss=1.25, accuracy=0.5][A
Epoch 1:  32%|███▏      | 125/390 [00:09<00:19, 13.70it/s, loss=1.4, accuracy=0.5] [A
Epoch 1:  32%|███▏      | 126/390 [00:09<00:19, 13.70it/s, loss=1.24, accuracy=0.539][A
Epoch 1:  33%|███▎      | 127/390 [00:09<00:19, 13.72it/s, loss=1.24, accuracy=0.539][A
Epoch 1:  33%|███▎      | 127/390 [00:09<00:19, 13.72it/s, loss=1.24, accuracy=0.547][A
Epoch 1:  33%|███▎      | 128/390 [00:09<00:19, 13.72it/s, loss=1.24, accuracy=0.555][A
Epoch 1:  33%|███▎      | 129/390 [00:10<00:19, 13.74it/s, loss=1.24, accuracy=0.555][A
Epoch 1:  33%|███▎      | 129/390 [00:10<00:19, 13.74it/s, loss=1.29, accuracy=0.523][A
Epoch 1:  33%|███▎      |

Epoch 1:  63%|██████▎   | 245/390 [00:18<00:10, 13.60it/s, loss=1.04, accuracy=0.633][A
Epoch 1:  63%|██████▎   | 245/390 [00:18<00:10, 13.60it/s, loss=0.797, accuracy=0.727][A
Epoch 1:  63%|██████▎   | 246/390 [00:18<00:10, 13.60it/s, loss=1.06, accuracy=0.625] [A
Epoch 1:  63%|██████▎   | 247/390 [00:18<00:10, 13.63it/s, loss=1.06, accuracy=0.625][A
Epoch 1:  63%|██████▎   | 247/390 [00:18<00:10, 13.63it/s, loss=0.957, accuracy=0.625][A
Epoch 1:  64%|██████▎   | 248/390 [00:18<00:10, 13.63it/s, loss=0.963, accuracy=0.641][A
Epoch 1:  64%|██████▍   | 249/390 [00:18<00:10, 13.55it/s, loss=0.963, accuracy=0.641][A
Epoch 1:  64%|██████▍   | 249/390 [00:18<00:10, 13.55it/s, loss=0.973, accuracy=0.625][A
Epoch 1:  64%|██████▍   | 250/390 [00:18<00:10, 13.55it/s, loss=0.986, accuracy=0.641][A
Epoch 1:  64%|██████▍   | 251/390 [00:18<00:10, 13.53it/s, loss=0.986, accuracy=0.641][A
Epoch 1:  64%|██████▍   | 251/390 [00:18<00:10, 13.53it/s, loss=0.994, accuracy=0.648][A
Epoch 1:  65

Epoch 1:  94%|█████████▍| 366/390 [00:27<00:01, 13.76it/s, loss=0.699, accuracy=0.727][A
Epoch 1:  94%|█████████▍| 367/390 [00:27<00:01, 13.76it/s, loss=0.699, accuracy=0.727][A
Epoch 1:  94%|█████████▍| 367/390 [00:27<00:01, 13.76it/s, loss=0.838, accuracy=0.727][A
Epoch 1:  94%|█████████▍| 368/390 [00:27<00:01, 13.76it/s, loss=0.938, accuracy=0.633][A
Epoch 1:  95%|█████████▍| 369/390 [00:27<00:01, 13.76it/s, loss=0.938, accuracy=0.633][A
Epoch 1:  95%|█████████▍| 369/390 [00:27<00:01, 13.76it/s, loss=0.745, accuracy=0.758][A
Epoch 1:  95%|█████████▍| 370/390 [00:27<00:01, 13.76it/s, loss=0.744, accuracy=0.734][A
Epoch 1:  95%|█████████▌| 371/390 [00:27<00:01, 13.75it/s, loss=0.744, accuracy=0.734][A
Epoch 1:  95%|█████████▌| 371/390 [00:27<00:01, 13.75it/s, loss=0.854, accuracy=0.719][A
Epoch 1:  95%|█████████▌| 372/390 [00:27<00:01, 13.75it/s, loss=0.768, accuracy=0.688][A
Epoch 1:  96%|█████████▌| 373/390 [00:27<00:01, 13.75it/s, loss=0.768, accuracy=0.688][A
Epoch 1:  

Epoch 2:  25%|██▌       | 98/390 [00:07<00:21, 13.59it/s, loss=0.882, accuracy=0.688][A
Epoch 2:  25%|██▌       | 99/390 [00:07<00:21, 13.58it/s, loss=0.882, accuracy=0.688][A
Epoch 2:  25%|██▌       | 99/390 [00:07<00:21, 13.58it/s, loss=0.656, accuracy=0.797][A
Epoch 2:  26%|██▌       | 100/390 [00:07<00:21, 13.58it/s, loss=0.714, accuracy=0.734][A
Epoch 2:  26%|██▌       | 101/390 [00:07<00:21, 13.58it/s, loss=0.714, accuracy=0.734][A
Epoch 2:  26%|██▌       | 101/390 [00:07<00:21, 13.58it/s, loss=0.686, accuracy=0.758][A
Epoch 2:  26%|██▌       | 102/390 [00:08<00:21, 13.58it/s, loss=0.784, accuracy=0.758][A
Epoch 2:  26%|██▋       | 103/390 [00:08<00:21, 13.57it/s, loss=0.784, accuracy=0.758][A
Epoch 2:  26%|██▋       | 103/390 [00:08<00:21, 13.57it/s, loss=0.617, accuracy=0.766][A
Epoch 2:  27%|██▋       | 104/390 [00:08<00:21, 13.57it/s, loss=0.813, accuracy=0.773][A
Epoch 2:  27%|██▋       | 105/390 [00:08<00:21, 13.56it/s, loss=0.813, accuracy=0.773][A
Epoch 2:  27%

Epoch 2:  56%|█████▌    | 219/390 [00:16<00:12, 13.50it/s, loss=0.463, accuracy=0.852][A
Epoch 2:  56%|█████▋    | 220/390 [00:16<00:12, 13.50it/s, loss=0.628, accuracy=0.789][A
Epoch 2:  57%|█████▋    | 221/390 [00:16<00:12, 13.43it/s, loss=0.628, accuracy=0.789][A
Epoch 2:  57%|█████▋    | 221/390 [00:16<00:12, 13.43it/s, loss=0.593, accuracy=0.75] [A
Epoch 2:  57%|█████▋    | 222/390 [00:16<00:12, 13.43it/s, loss=0.681, accuracy=0.805][A
Epoch 2:  57%|█████▋    | 223/390 [00:17<00:12, 13.44it/s, loss=0.681, accuracy=0.805][A
Epoch 2:  57%|█████▋    | 223/390 [00:17<00:12, 13.44it/s, loss=0.706, accuracy=0.758][A
Epoch 2:  57%|█████▋    | 224/390 [00:17<00:12, 13.44it/s, loss=0.839, accuracy=0.703][A
Epoch 2:  58%|█████▊    | 225/390 [00:17<00:12, 13.42it/s, loss=0.839, accuracy=0.703][A
Epoch 2:  58%|█████▊    | 225/390 [00:17<00:12, 13.42it/s, loss=0.639, accuracy=0.734][A
Epoch 2:  58%|█████▊    | 226/390 [00:17<00:12, 13.42it/s, loss=0.551, accuracy=0.805][A
Epoch 2:  

Epoch 2:  87%|████████▋ | 341/390 [00:25<00:03, 13.53it/s, loss=0.546, accuracy=0.812][A
Epoch 2:  87%|████████▋ | 341/390 [00:25<00:03, 13.53it/s, loss=0.447, accuracy=0.867][A
Epoch 2:  88%|████████▊ | 342/390 [00:25<00:03, 13.53it/s, loss=0.555, accuracy=0.828][A
Epoch 2:  88%|████████▊ | 343/390 [00:25<00:03, 13.49it/s, loss=0.555, accuracy=0.828][A
Epoch 2:  88%|████████▊ | 343/390 [00:25<00:03, 13.49it/s, loss=0.626, accuracy=0.773][A
Epoch 2:  88%|████████▊ | 344/390 [00:25<00:03, 13.49it/s, loss=0.476, accuracy=0.797][A
Epoch 2:  88%|████████▊ | 345/390 [00:26<00:03, 13.47it/s, loss=0.476, accuracy=0.797][A
Epoch 2:  88%|████████▊ | 345/390 [00:26<00:03, 13.47it/s, loss=0.567, accuracy=0.805][A
Epoch 2:  89%|████████▊ | 346/390 [00:26<00:03, 13.47it/s, loss=0.63, accuracy=0.836] [A
Epoch 2:  89%|████████▉ | 347/390 [00:26<00:03, 13.43it/s, loss=0.63, accuracy=0.836][A
Epoch 2:  89%|████████▉ | 347/390 [00:26<00:03, 13.43it/s, loss=0.558, accuracy=0.773][A
Epoch 2:  8

Epoch 3:  18%|█▊        | 72/390 [00:05<00:23, 13.52it/s, loss=0.49, accuracy=0.828] [A
Epoch 3:  19%|█▊        | 73/390 [00:05<00:23, 13.54it/s, loss=0.49, accuracy=0.828][A
Epoch 3:  19%|█▊        | 73/390 [00:05<00:23, 13.54it/s, loss=0.516, accuracy=0.836][A
Epoch 3:  19%|█▉        | 74/390 [00:06<00:23, 13.54it/s, loss=0.512, accuracy=0.812][A
Epoch 3:  19%|█▉        | 75/390 [00:06<00:23, 13.55it/s, loss=0.512, accuracy=0.812][A
Epoch 3:  19%|█▉        | 75/390 [00:06<00:23, 13.55it/s, loss=0.481, accuracy=0.82] [A
Epoch 3:  19%|█▉        | 76/390 [00:06<00:23, 13.55it/s, loss=0.404, accuracy=0.836][A
Epoch 3:  20%|█▉        | 77/390 [00:06<00:23, 13.54it/s, loss=0.404, accuracy=0.836][A
Epoch 3:  20%|█▉        | 77/390 [00:06<00:23, 13.54it/s, loss=0.668, accuracy=0.805][A
Epoch 3:  20%|██        | 78/390 [00:06<00:23, 13.54it/s, loss=0.46, accuracy=0.789] [A
Epoch 3:  20%|██        | 79/390 [00:06<00:22, 13.56it/s, loss=0.46, accuracy=0.789][A
Epoch 3:  20%|██       

Epoch 3:  49%|████▉     | 193/390 [00:14<00:14, 13.49it/s, loss=0.355, accuracy=0.883][A
Epoch 3:  50%|████▉     | 194/390 [00:14<00:14, 13.49it/s, loss=0.516, accuracy=0.812][A
Epoch 3:  50%|█████     | 195/390 [00:14<00:14, 13.47it/s, loss=0.516, accuracy=0.812][A
Epoch 3:  50%|█████     | 195/390 [00:14<00:14, 13.47it/s, loss=0.45, accuracy=0.875] [A
Epoch 3:  50%|█████     | 196/390 [00:15<00:14, 13.47it/s, loss=0.28, accuracy=0.914][A
Epoch 3:  51%|█████     | 197/390 [00:15<00:14, 13.45it/s, loss=0.28, accuracy=0.914][A
Epoch 3:  51%|█████     | 197/390 [00:15<00:14, 13.45it/s, loss=0.42, accuracy=0.844][A
Epoch 3:  51%|█████     | 198/390 [00:15<00:14, 13.45it/s, loss=0.49, accuracy=0.82] [A
Epoch 3:  51%|█████     | 199/390 [00:15<00:14, 13.45it/s, loss=0.49, accuracy=0.82][A
Epoch 3:  51%|█████     | 199/390 [00:15<00:14, 13.45it/s, loss=0.51, accuracy=0.836][A
Epoch 3:  51%|█████▏    | 200/390 [00:15<00:14, 13.45it/s, loss=0.422, accuracy=0.844][A
Epoch 3:  52%|███

Epoch 3:  81%|████████  | 315/390 [00:23<00:05, 13.20it/s, loss=0.417, accuracy=0.852][A
Epoch 3:  81%|████████  | 315/390 [00:23<00:05, 13.20it/s, loss=0.217, accuracy=0.906][A
Epoch 3:  81%|████████  | 316/390 [00:24<00:05, 13.20it/s, loss=0.419, accuracy=0.859][A
Epoch 3:  81%|████████▏ | 317/390 [00:24<00:05, 13.24it/s, loss=0.419, accuracy=0.859][A
Epoch 3:  81%|████████▏ | 317/390 [00:24<00:05, 13.24it/s, loss=0.351, accuracy=0.867][A
Epoch 3:  82%|████████▏ | 318/390 [00:24<00:05, 13.24it/s, loss=0.359, accuracy=0.875][A
Epoch 3:  82%|████████▏ | 319/390 [00:24<00:05, 13.27it/s, loss=0.359, accuracy=0.875][A
Epoch 3:  82%|████████▏ | 319/390 [00:24<00:05, 13.27it/s, loss=0.402, accuracy=0.883][A
Epoch 3:  82%|████████▏ | 320/390 [00:24<00:05, 13.27it/s, loss=0.387, accuracy=0.891][A
Epoch 3:  82%|████████▏ | 321/390 [00:24<00:05, 13.28it/s, loss=0.387, accuracy=0.891][A
Epoch 3:  82%|████████▏ | 321/390 [00:24<00:05, 13.28it/s, loss=0.339, accuracy=0.875][A
Epoch 3:  

Epoch 4:  12%|█▏        | 46/390 [00:04<00:25, 13.70it/s, loss=0.267, accuracy=0.93] [A
Epoch 4:  12%|█▏        | 47/390 [00:04<00:25, 13.71it/s, loss=0.267, accuracy=0.93][A
Epoch 4:  12%|█▏        | 47/390 [00:04<00:25, 13.71it/s, loss=0.228, accuracy=0.922][A
Epoch 4:  12%|█▏        | 48/390 [00:04<00:24, 13.71it/s, loss=0.315, accuracy=0.891][A
Epoch 4:  13%|█▎        | 49/390 [00:04<00:25, 13.63it/s, loss=0.315, accuracy=0.891][A
Epoch 4:  13%|█▎        | 49/390 [00:04<00:25, 13.63it/s, loss=0.316, accuracy=0.875][A
Epoch 4:  13%|█▎        | 50/390 [00:04<00:24, 13.63it/s, loss=0.293, accuracy=0.891][A
Epoch 4:  13%|█▎        | 51/390 [00:04<00:25, 13.53it/s, loss=0.293, accuracy=0.891][A
Epoch 4:  13%|█▎        | 51/390 [00:04<00:25, 13.53it/s, loss=0.298, accuracy=0.906][A
Epoch 4:  13%|█▎        | 52/390 [00:04<00:24, 13.53it/s, loss=0.203, accuracy=0.945][A
Epoch 4:  14%|█▎        | 53/390 [00:04<00:24, 13.50it/s, loss=0.203, accuracy=0.945][A
Epoch 4:  14%|█▎      

Epoch 4:  43%|████▎     | 167/390 [00:13<00:16, 13.17it/s, loss=0.404, accuracy=0.828][A
Epoch 4:  43%|████▎     | 168/390 [00:13<00:16, 13.17it/s, loss=0.297, accuracy=0.891][A
Epoch 4:  43%|████▎     | 169/390 [00:13<00:16, 13.20it/s, loss=0.297, accuracy=0.891][A
Epoch 4:  43%|████▎     | 169/390 [00:13<00:16, 13.20it/s, loss=0.367, accuracy=0.883][A
Epoch 4:  44%|████▎     | 170/390 [00:13<00:16, 13.20it/s, loss=0.441, accuracy=0.812][A
Epoch 4:  44%|████▍     | 171/390 [00:13<00:16, 13.21it/s, loss=0.441, accuracy=0.812][A
Epoch 4:  44%|████▍     | 171/390 [00:13<00:16, 13.21it/s, loss=0.31, accuracy=0.875] [A
Epoch 4:  44%|████▍     | 172/390 [00:13<00:16, 13.21it/s, loss=0.39, accuracy=0.891][A
Epoch 4:  44%|████▍     | 173/390 [00:13<00:16, 13.16it/s, loss=0.39, accuracy=0.891][A
Epoch 4:  44%|████▍     | 173/390 [00:13<00:16, 13.16it/s, loss=0.311, accuracy=0.898][A
Epoch 4:  45%|████▍     | 174/390 [00:13<00:16, 13.16it/s, loss=0.261, accuracy=0.914][A
Epoch 4:  45

Epoch 4:  74%|███████▍  | 289/390 [00:22<00:07, 13.18it/s, loss=0.344, accuracy=0.922][A
Epoch 4:  74%|███████▍  | 289/390 [00:22<00:07, 13.18it/s, loss=0.316, accuracy=0.867][A
Epoch 4:  74%|███████▍  | 290/390 [00:22<00:07, 13.18it/s, loss=0.309, accuracy=0.914][A
Epoch 4:  75%|███████▍  | 291/390 [00:22<00:07, 13.22it/s, loss=0.309, accuracy=0.914][A
Epoch 4:  75%|███████▍  | 291/390 [00:22<00:07, 13.22it/s, loss=0.159, accuracy=0.945][A
Epoch 4:  75%|███████▍  | 292/390 [00:22<00:07, 13.22it/s, loss=0.266, accuracy=0.914][A
Epoch 4:  75%|███████▌  | 293/390 [00:22<00:07, 13.30it/s, loss=0.266, accuracy=0.914][A
Epoch 4:  75%|███████▌  | 293/390 [00:22<00:07, 13.30it/s, loss=0.219, accuracy=0.914][A
Epoch 4:  75%|███████▌  | 294/390 [00:22<00:07, 13.30it/s, loss=0.206, accuracy=0.922][A
Epoch 4:  76%|███████▌  | 295/390 [00:22<00:07, 13.35it/s, loss=0.206, accuracy=0.922][A
Epoch 4:  76%|███████▌  | 295/390 [00:22<00:07, 13.35it/s, loss=0.209, accuracy=0.945][A
Epoch 4:  

Epoch 5:   5%|▍         | 19/390 [00:02<00:36, 10.10it/s, loss=0.19, accuracy=0.93]  [A
Epoch 5:   5%|▌         | 20/390 [00:02<00:36, 10.10it/s, loss=0.216, accuracy=0.922][A
Epoch 5:   5%|▌         | 21/390 [00:02<00:33, 10.91it/s, loss=0.216, accuracy=0.922][A
Epoch 5:   5%|▌         | 21/390 [00:02<00:33, 10.91it/s, loss=0.157, accuracy=0.938][A
Epoch 5:   6%|▌         | 22/390 [00:02<00:33, 10.91it/s, loss=0.195, accuracy=0.922][A
Epoch 5:   6%|▌         | 23/390 [00:02<00:31, 11.55it/s, loss=0.195, accuracy=0.922][A
Epoch 5:   6%|▌         | 23/390 [00:02<00:31, 11.55it/s, loss=0.192, accuracy=0.945][A
Epoch 5:   6%|▌         | 24/390 [00:02<00:31, 11.55it/s, loss=0.37, accuracy=0.859] [A
Epoch 5:   6%|▋         | 25/390 [00:02<00:30, 12.05it/s, loss=0.37, accuracy=0.859][A
Epoch 5:   6%|▋         | 25/390 [00:02<00:30, 12.05it/s, loss=0.209, accuracy=0.914][A
Epoch 5:   7%|▋         | 26/390 [00:02<00:30, 12.05it/s, loss=0.293, accuracy=0.906][A
Epoch 5:   7%|▋       

Epoch 5:  36%|███▌      | 141/390 [00:11<00:18, 13.45it/s, loss=0.179, accuracy=0.93] [A
Epoch 5:  36%|███▋      | 142/390 [00:11<00:18, 13.45it/s, loss=0.192, accuracy=0.93][A
Epoch 5:  37%|███▋      | 143/390 [00:11<00:18, 13.41it/s, loss=0.192, accuracy=0.93][A
Epoch 5:  37%|███▋      | 143/390 [00:11<00:18, 13.41it/s, loss=0.245, accuracy=0.93][A
Epoch 5:  37%|███▋      | 144/390 [00:11<00:18, 13.41it/s, loss=0.252, accuracy=0.906][A
Epoch 5:  37%|███▋      | 145/390 [00:11<00:18, 13.44it/s, loss=0.252, accuracy=0.906][A
Epoch 5:  37%|███▋      | 145/390 [00:11<00:18, 13.44it/s, loss=0.306, accuracy=0.898][A
Epoch 5:  37%|███▋      | 146/390 [00:11<00:18, 13.44it/s, loss=0.257, accuracy=0.922][A
Epoch 5:  38%|███▊      | 147/390 [00:11<00:18, 13.45it/s, loss=0.257, accuracy=0.922][A
Epoch 5:  38%|███▊      | 147/390 [00:11<00:18, 13.45it/s, loss=0.204, accuracy=0.906][A
Epoch 5:  38%|███▊      | 148/390 [00:11<00:17, 13.45it/s, loss=0.254, accuracy=0.922][A
Epoch 5:  38%

Epoch 5:  67%|██████▋   | 263/390 [00:20<00:09, 13.45it/s, loss=0.263, accuracy=0.898][A
Epoch 5:  67%|██████▋   | 263/390 [00:20<00:09, 13.45it/s, loss=0.247, accuracy=0.906][A
Epoch 5:  68%|██████▊   | 264/390 [00:20<00:09, 13.45it/s, loss=0.149, accuracy=0.945][A
Epoch 5:  68%|██████▊   | 265/390 [00:20<00:09, 13.39it/s, loss=0.149, accuracy=0.945][A
Epoch 5:  68%|██████▊   | 265/390 [00:20<00:09, 13.39it/s, loss=0.236, accuracy=0.891][A
Epoch 5:  68%|██████▊   | 266/390 [00:20<00:09, 13.39it/s, loss=0.208, accuracy=0.938][A
Epoch 5:  68%|██████▊   | 267/390 [00:20<00:09, 13.28it/s, loss=0.208, accuracy=0.938][A
Epoch 5:  68%|██████▊   | 267/390 [00:20<00:09, 13.28it/s, loss=0.151, accuracy=0.945][A
Epoch 5:  69%|██████▊   | 268/390 [00:20<00:09, 13.28it/s, loss=0.149, accuracy=0.953][A
Epoch 5:  69%|██████▉   | 269/390 [00:20<00:09, 13.21it/s, loss=0.149, accuracy=0.953][A
Epoch 5:  69%|██████▉   | 269/390 [00:20<00:09, 13.21it/s, loss=0.23, accuracy=0.93]  [A
Epoch 5:  

Epoch 5:  98%|█████████▊| 384/390 [00:29<00:00, 13.42it/s, loss=0.197, accuracy=0.914][A
Epoch 5:  99%|█████████▊| 385/390 [00:29<00:00, 13.45it/s, loss=0.197, accuracy=0.914][A
Epoch 5:  99%|█████████▊| 385/390 [00:29<00:00, 13.45it/s, loss=0.133, accuracy=0.953][A
Epoch 5:  99%|█████████▉| 386/390 [00:29<00:00, 13.45it/s, loss=0.256, accuracy=0.914][A
Epoch 5:  99%|█████████▉| 387/390 [00:29<00:00, 13.45it/s, loss=0.256, accuracy=0.914][A
Epoch 5:  99%|█████████▉| 387/390 [00:29<00:00, 13.45it/s, loss=0.393, accuracy=0.844][A
Epoch 5:  99%|█████████▉| 388/390 [00:29<00:00, 13.45it/s, loss=0.232, accuracy=0.922][A
Epoch 5: 100%|█████████▉| 389/390 [00:29<00:00, 13.47it/s, loss=0.232, accuracy=0.922][A
Epoch 5: 100%|█████████▉| 389/390 [00:29<00:00, 13.47it/s, loss=0.252, accuracy=0.914][A
Epoch 5: 100%|██████████| 390/390 [00:31<00:00, 12.27it/s, loss=0.214, accuracy=0.925, val_loss=0.95, val_accuracy=0.753][A

Finished.
final validation accuracy:





0.7530048076923077


# TRAIN MODEL USING PGD SGD 

In [35]:
def infnorm(x):
    infn = torch.max(torch.abs(x.detach().cpu()))
    return infn

In [36]:
def adversarial_training(model, optimiser, loss_fn, x, y, epoch, adversary, k, step, eps, norm, random):
    """Performs a single update against a specified adversary"""
    model.train()

    # Adversial perturbation
    x_adv = adversary(model, x, y, loss_fn, k=k, step=step, eps=eps, norm=norm, random=True)
    #print('l2:',torch.norm(x_adv.detach().cpu()-x.detach().cpu())/np.sqrt(x.detach().cpu().size(0)))    
    #print('linf:',infnorm(x_adv.detach().cpu()-x.detach().cpu())/infnorm(x))    

    optimiser.zero_grad()
    y_pred = model(x_adv)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimiser.step()

    return loss, y_pred

## l2 ball

In [37]:
if TrainL2:
    ## initialize model
    adv_model_l2 = Net().to(DEVICE)
    lr = 0.1
    optimiser = optim.SGD(adv_model_l2.parameters(), lr=lr)
    epochs = 5
    ## train model
    training_history_l2 = olympic.fit(
        adv_model_l2,
        optimiser,
        nn.CrossEntropyLoss(),
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=adversarial_training,
        update_fn_kwargs={'adversary': pgd, 'k': 10, 'step': 0.01, 'eps': 3.0, 'norm': 2, 'random':True},
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5, factor=0.5, min_delta=0.005, monitor='val_accuracy')
        ]
    )
    ## verify validation accuracy
    print('final validation accuracy:')
    valacc = olympic.evaluate(adv_model_l2, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    print(valacc['val_accuracy'])
    ## save trained model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_AT2_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(adv_model_l2,modelname)


Epoch 1:   0%|          | 0/390 [00:00<?, ?it/s][A

Begin training...



Epoch 1:   0%|          | 1/390 [00:01<08:06,  1.25s/it][A
Epoch 1:   0%|          | 1/390 [00:01<08:06,  1.25s/it, loss=2.92, accuracy=0.0156][A
Epoch 1:   1%|          | 2/390 [00:01<06:43,  1.04s/it, loss=2.92, accuracy=0.0156][A
Epoch 1:   1%|          | 2/390 [00:01<06:43,  1.04s/it, loss=3.5, accuracy=0.117]  [A
Epoch 1:   1%|          | 3/390 [00:02<05:45,  1.12it/s, loss=3.5, accuracy=0.117][A
Epoch 1:   1%|          | 3/390 [00:02<05:45,  1.12it/s, loss=5.01, accuracy=0.102][A
Epoch 1:   1%|          | 4/390 [00:02<05:04,  1.27it/s, loss=5.01, accuracy=0.102][A
Epoch 1:   1%|          | 4/390 [00:02<05:04,  1.27it/s, loss=3.31, accuracy=0.133][A
Epoch 1:   1%|▏         | 5/390 [00:03<04:36,  1.39it/s, loss=3.31, accuracy=0.133][A
Epoch 1:   1%|▏         | 5/390 [00:03<04:36,  1.39it/s, loss=4.55, accuracy=0.211][A
Epoch 1:   2%|▏         | 6/390 [00:03<04:16,  1.50it/s, loss=4.55, accuracy=0.211][A
Epoch 1:   2%|▏         | 6/390 [00:04<04:16,  1.50it/s, loss=3.06,

Epoch 1:  24%|██▍       | 94/390 [00:52<02:46,  1.78it/s, loss=1.84, accuracy=0.328][A
Epoch 1:  24%|██▍       | 94/390 [00:52<02:46,  1.78it/s, loss=1.77, accuracy=0.336][A
Epoch 1:  24%|██▍       | 95/390 [00:53<02:45,  1.78it/s, loss=1.77, accuracy=0.336][A
Epoch 1:  24%|██▍       | 95/390 [00:53<02:45,  1.78it/s, loss=1.66, accuracy=0.312][A
Epoch 1:  25%|██▍       | 96/390 [00:53<02:44,  1.78it/s, loss=1.66, accuracy=0.312][A
Epoch 1:  25%|██▍       | 96/390 [00:53<02:44,  1.78it/s, loss=1.62, accuracy=0.406][A
Epoch 1:  25%|██▍       | 97/390 [00:54<02:44,  1.78it/s, loss=1.62, accuracy=0.406][A
Epoch 1:  25%|██▍       | 97/390 [00:54<02:44,  1.78it/s, loss=1.75, accuracy=0.414][A
Epoch 1:  25%|██▌       | 98/390 [00:55<02:43,  1.78it/s, loss=1.75, accuracy=0.414][A
Epoch 1:  25%|██▌       | 98/390 [00:55<02:43,  1.78it/s, loss=1.86, accuracy=0.352][A
Epoch 1:  25%|██▌       | 99/390 [00:55<02:43,  1.78it/s, loss=1.86, accuracy=0.352][A
Epoch 1:  25%|██▌       | 99/390

KeyboardInterrupt: 

## linf ball

In [24]:
if TrainLInf:
    ## initialize model
    adv_model_linf = Net().to(DEVICE)
    ## train params
    lr = 0.1
    optimiser = optim.SGD(adv_model_linf.parameters(), lr=lr)
    epochs = 5
    ## train model
    training_history_linf = olympic.fit(
        adv_model_linf,
        optimiser,
        nn.CrossEntropyLoss(),
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=adversarial_training,
        update_fn_kwargs={'adversary': iterated_fgsm,'k': 10, 'step': 0.07, 'eps': 0.032, 'norm': 'inf', 'random':True},
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5, factor=0.5, min_delta=0.005, monitor='val_accuracy')
        ]
    )
    ## verify validation
    print('final validation accuracy:')
    valacc = olympic.evaluate(adv_model_linf, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    
    ## save trained model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_ATInf_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(adv_model_linf,modelname)


Epoch 1:   0%|          | 0/390 [00:00<?, ?it/s][A

Begin training...



Epoch 1:   0%|          | 1/390 [00:01<07:40,  1.18s/it][A
Epoch 1:   0%|          | 1/390 [00:01<07:40,  1.18s/it, loss=3.02, accuracy=0][A
Epoch 1:   1%|          | 2/390 [00:01<06:25,  1.01it/s, loss=3.02, accuracy=0][A
Epoch 1:   1%|          | 2/390 [00:01<06:25,  1.01it/s, loss=4.53, accuracy=0.117][A
Epoch 1:   1%|          | 3/390 [00:02<05:31,  1.17it/s, loss=4.53, accuracy=0.117][A
Epoch 1:   1%|          | 3/390 [00:02<05:31,  1.17it/s, loss=5.95, accuracy=0.0625][A
Epoch 1:   1%|          | 4/390 [00:02<04:54,  1.31it/s, loss=5.95, accuracy=0.0625][A
Epoch 1:   1%|          | 4/390 [00:02<04:54,  1.31it/s, loss=4.05, accuracy=0.0703][A
Epoch 1:   1%|▏         | 5/390 [00:03<04:28,  1.43it/s, loss=4.05, accuracy=0.0703][A
Epoch 1:   1%|▏         | 5/390 [00:03<04:28,  1.43it/s, loss=2.92, accuracy=0.0547][A
Epoch 1:   2%|▏         | 6/390 [00:03<04:10,  1.53it/s, loss=2.92, accuracy=0.0547][A
Epoch 1:   2%|▏         | 6/390 [00:03<04:10,  1.53it/s, loss=2.52, acc

Epoch 1:  24%|██▍       | 94/390 [00:52<02:44,  1.80it/s, loss=2.27, accuracy=0.156][A
Epoch 1:  24%|██▍       | 94/390 [00:52<02:44,  1.80it/s, loss=2.14, accuracy=0.227][A
Epoch 1:  24%|██▍       | 95/390 [00:52<02:43,  1.80it/s, loss=2.14, accuracy=0.227][A
Epoch 1:  24%|██▍       | 95/390 [00:52<02:43,  1.80it/s, loss=2.07, accuracy=0.18] [A
Epoch 1:  25%|██▍       | 96/390 [00:53<02:43,  1.80it/s, loss=2.07, accuracy=0.18][A
Epoch 1:  25%|██▍       | 96/390 [00:53<02:43,  1.80it/s, loss=2.07, accuracy=0.227][A
Epoch 1:  25%|██▍       | 97/390 [00:54<02:43,  1.80it/s, loss=2.07, accuracy=0.227][A
Epoch 1:  25%|██▍       | 97/390 [00:54<02:43,  1.80it/s, loss=2.16, accuracy=0.219][A
Epoch 1:  25%|██▌       | 98/390 [00:54<02:42,  1.80it/s, loss=2.16, accuracy=0.219][A
Epoch 1:  25%|██▌       | 98/390 [00:54<02:42,  1.80it/s, loss=2.13, accuracy=0.203][A
Epoch 1:  25%|██▌       | 99/390 [00:55<02:41,  1.80it/s, loss=2.13, accuracy=0.203][A
Epoch 1:  25%|██▌       | 99/390 

Epoch 1:  48%|████▊     | 186/390 [01:46<02:04,  1.64it/s, loss=2, accuracy=0.281][A
Epoch 1:  48%|████▊     | 186/390 [01:46<02:04,  1.64it/s, loss=1.99, accuracy=0.258][A
Epoch 1:  48%|████▊     | 187/390 [01:47<02:03,  1.64it/s, loss=1.99, accuracy=0.258][A
Epoch 1:  48%|████▊     | 187/390 [01:47<02:03,  1.64it/s, loss=2.02, accuracy=0.305][A
Epoch 1:  48%|████▊     | 188/390 [01:47<02:03,  1.64it/s, loss=2.02, accuracy=0.305][A
Epoch 1:  48%|████▊     | 188/390 [01:47<02:03,  1.64it/s, loss=2.05, accuracy=0.219][A
Epoch 1:  48%|████▊     | 189/390 [01:48<02:02,  1.64it/s, loss=2.05, accuracy=0.219][A
Epoch 1:  48%|████▊     | 189/390 [01:48<02:02,  1.64it/s, loss=1.99, accuracy=0.188][A
Epoch 1:  49%|████▊     | 190/390 [01:48<02:01,  1.64it/s, loss=1.99, accuracy=0.188][A
Epoch 1:  49%|████▊     | 190/390 [01:48<02:01,  1.64it/s, loss=2.02, accuracy=0.156][A
Epoch 1:  49%|████▉     | 191/390 [01:49<02:01,  1.64it/s, loss=2.02, accuracy=0.156][A
Epoch 1:  49%|████▉     

Epoch 1:  71%|███████▏  | 278/390 [02:42<01:08,  1.64it/s, loss=1.97, accuracy=0.297][A
Epoch 1:  71%|███████▏  | 278/390 [02:42<01:08,  1.64it/s, loss=1.9, accuracy=0.242] [A
Epoch 1:  72%|███████▏  | 279/390 [02:43<01:07,  1.65it/s, loss=1.9, accuracy=0.242][A
Epoch 1:  72%|███████▏  | 279/390 [02:43<01:07,  1.65it/s, loss=2.06, accuracy=0.242][A
Epoch 1:  72%|███████▏  | 280/390 [02:43<01:06,  1.64it/s, loss=2.06, accuracy=0.242][A
Epoch 1:  72%|███████▏  | 280/390 [02:43<01:06,  1.64it/s, loss=2.01, accuracy=0.266][A
Epoch 1:  72%|███████▏  | 281/390 [02:44<01:06,  1.64it/s, loss=2.01, accuracy=0.266][A
Epoch 1:  72%|███████▏  | 281/390 [02:44<01:06,  1.64it/s, loss=1.89, accuracy=0.367][A
Epoch 1:  72%|███████▏  | 282/390 [02:44<01:05,  1.64it/s, loss=1.89, accuracy=0.367][A
Epoch 1:  72%|███████▏  | 282/390 [02:44<01:05,  1.64it/s, loss=2.07, accuracy=0.25] [A
Epoch 1:  73%|███████▎  | 283/390 [02:45<01:05,  1.64it/s, loss=2.07, accuracy=0.25][A
Epoch 1:  73%|███████▎ 

Epoch 1:  95%|█████████▍| 370/390 [03:38<00:12,  1.64it/s, loss=1.82, accuracy=0.297][A
Epoch 1:  95%|█████████▍| 370/390 [03:38<00:12,  1.64it/s, loss=1.86, accuracy=0.328][A
Epoch 1:  95%|█████████▌| 371/390 [03:39<00:11,  1.64it/s, loss=1.86, accuracy=0.328][A
Epoch 1:  95%|█████████▌| 371/390 [03:39<00:11,  1.64it/s, loss=1.93, accuracy=0.273][A
Epoch 1:  95%|█████████▌| 372/390 [03:39<00:10,  1.64it/s, loss=1.93, accuracy=0.273][A
Epoch 1:  95%|█████████▌| 372/390 [03:39<00:10,  1.64it/s, loss=1.93, accuracy=0.273][A
Epoch 1:  96%|█████████▌| 373/390 [03:40<00:10,  1.65it/s, loss=1.93, accuracy=0.273][A
Epoch 1:  96%|█████████▌| 373/390 [03:40<00:10,  1.65it/s, loss=1.89, accuracy=0.266][A
Epoch 1:  96%|█████████▌| 374/390 [03:40<00:09,  1.65it/s, loss=1.89, accuracy=0.266][A
Epoch 1:  96%|█████████▌| 374/390 [03:40<00:09,  1.65it/s, loss=1.99, accuracy=0.273][A
Epoch 1:  96%|█████████▌| 375/390 [03:41<00:09,  1.65it/s, loss=1.99, accuracy=0.273][A
Epoch 1:  96%|███████

Epoch 2:  18%|█▊        | 72/390 [00:47<03:46,  1.40it/s, loss=1.89, accuracy=0.258][A
Epoch 2:  19%|█▊        | 73/390 [00:47<03:44,  1.41it/s, loss=1.89, accuracy=0.258][A
Epoch 2:  19%|█▊        | 73/390 [00:47<03:44,  1.41it/s, loss=1.85, accuracy=0.273][A
Epoch 2:  19%|█▉        | 74/390 [00:48<03:42,  1.42it/s, loss=1.85, accuracy=0.273][A
Epoch 2:  19%|█▉        | 74/390 [00:48<03:42,  1.42it/s, loss=1.99, accuracy=0.242][A
Epoch 2:  19%|█▉        | 75/390 [00:49<03:40,  1.43it/s, loss=1.99, accuracy=0.242][A
Epoch 2:  19%|█▉        | 75/390 [00:49<03:40,  1.43it/s, loss=1.87, accuracy=0.266][A
Epoch 2:  19%|█▉        | 76/390 [00:50<03:41,  1.42it/s, loss=1.87, accuracy=0.266][A
Epoch 2:  19%|█▉        | 76/390 [00:50<03:41,  1.42it/s, loss=1.99, accuracy=0.234][A
Epoch 2:  20%|█▉        | 77/390 [00:50<03:40,  1.42it/s, loss=1.99, accuracy=0.234][A
Epoch 2:  20%|█▉        | 77/390 [00:50<03:40,  1.42it/s, loss=2.02, accuracy=0.234][A
Epoch 2:  20%|██        | 78/390

Epoch 2:  42%|████▏     | 164/390 [01:51<02:40,  1.41it/s, loss=1.89, accuracy=0.305][A
Epoch 2:  42%|████▏     | 165/390 [01:52<02:38,  1.42it/s, loss=1.89, accuracy=0.305][A
Epoch 2:  42%|████▏     | 165/390 [01:52<02:38,  1.42it/s, loss=1.92, accuracy=0.266][A
Epoch 2:  43%|████▎     | 166/390 [01:53<02:38,  1.42it/s, loss=1.92, accuracy=0.266][A
Epoch 2:  43%|████▎     | 166/390 [01:53<02:38,  1.42it/s, loss=1.95, accuracy=0.258][A
Epoch 2:  43%|████▎     | 167/390 [01:53<02:36,  1.42it/s, loss=1.95, accuracy=0.258][A
Epoch 2:  43%|████▎     | 167/390 [01:53<02:36,  1.42it/s, loss=1.99, accuracy=0.289][A
Epoch 2:  43%|████▎     | 168/390 [01:54<02:34,  1.44it/s, loss=1.99, accuracy=0.289][A
Epoch 2:  43%|████▎     | 168/390 [01:54<02:34,  1.44it/s, loss=1.97, accuracy=0.297][A
Epoch 2:  43%|████▎     | 169/390 [01:55<02:33,  1.44it/s, loss=1.97, accuracy=0.297][A
Epoch 2:  43%|████▎     | 169/390 [01:55<02:33,  1.44it/s, loss=1.85, accuracy=0.289][A
Epoch 2:  44%|████▎  

Epoch 2:  66%|██████▌   | 256/390 [02:56<01:35,  1.41it/s, loss=1.91, accuracy=0.234][A
Epoch 2:  66%|██████▌   | 257/390 [02:56<01:34,  1.41it/s, loss=1.91, accuracy=0.234][A
Epoch 2:  66%|██████▌   | 257/390 [02:56<01:34,  1.41it/s, loss=2.01, accuracy=0.188][A
Epoch 2:  66%|██████▌   | 258/390 [02:57<01:33,  1.41it/s, loss=2.01, accuracy=0.188][A
Epoch 2:  66%|██████▌   | 258/390 [02:57<01:33,  1.41it/s, loss=1.93, accuracy=0.273][A
Epoch 2:  66%|██████▋   | 259/390 [02:58<01:33,  1.39it/s, loss=1.93, accuracy=0.273][A
Epoch 2:  66%|██████▋   | 259/390 [02:58<01:33,  1.39it/s, loss=2, accuracy=0.258]   [A
Epoch 2:  67%|██████▋   | 260/390 [02:58<01:32,  1.41it/s, loss=2, accuracy=0.258][A
Epoch 2:  67%|██████▋   | 260/390 [02:58<01:32,  1.41it/s, loss=1.82, accuracy=0.367][A
Epoch 2:  67%|██████▋   | 261/390 [02:59<01:31,  1.41it/s, loss=1.82, accuracy=0.367][A
Epoch 2:  67%|██████▋   | 261/390 [02:59<01:31,  1.41it/s, loss=2, accuracy=0.297]   [A
Epoch 2:  67%|██████▋   

Epoch 2:  89%|████████▉ | 348/390 [04:01<00:30,  1.39it/s, loss=1.89, accuracy=0.336][A
Epoch 2:  89%|████████▉ | 349/390 [04:01<00:29,  1.39it/s, loss=1.89, accuracy=0.336][A
Epoch 2:  89%|████████▉ | 349/390 [04:01<00:29,  1.39it/s, loss=1.93, accuracy=0.297][A
Epoch 2:  90%|████████▉ | 350/390 [04:02<00:28,  1.40it/s, loss=1.93, accuracy=0.297][A
Epoch 2:  90%|████████▉ | 350/390 [04:02<00:28,  1.40it/s, loss=1.82, accuracy=0.297][A
Epoch 2:  90%|█████████ | 351/390 [04:03<00:27,  1.40it/s, loss=1.82, accuracy=0.297][A
Epoch 2:  90%|█████████ | 351/390 [04:03<00:27,  1.40it/s, loss=1.88, accuracy=0.312][A
Epoch 2:  90%|█████████ | 352/390 [04:04<00:27,  1.40it/s, loss=1.88, accuracy=0.312][A
Epoch 2:  90%|█████████ | 352/390 [04:04<00:27,  1.40it/s, loss=1.7, accuracy=0.414] [A
Epoch 2:  91%|█████████ | 353/390 [04:04<00:26,  1.40it/s, loss=1.7, accuracy=0.414][A
Epoch 2:  91%|█████████ | 353/390 [04:04<00:26,  1.40it/s, loss=1.92, accuracy=0.344][A
Epoch 2:  91%|████████

Epoch 3:  13%|█▎        | 50/390 [00:33<03:27,  1.64it/s, loss=1.76, accuracy=0.398][A
Epoch 3:  13%|█▎        | 51/390 [00:34<03:27,  1.63it/s, loss=1.76, accuracy=0.398][A
Epoch 3:  13%|█▎        | 51/390 [00:34<03:27,  1.63it/s, loss=1.84, accuracy=0.273][A
Epoch 3:  13%|█▎        | 52/390 [00:34<03:25,  1.65it/s, loss=1.84, accuracy=0.273][A
Epoch 3:  13%|█▎        | 52/390 [00:34<03:25,  1.65it/s, loss=1.86, accuracy=0.297][A
Epoch 3:  14%|█▎        | 53/390 [00:35<03:24,  1.65it/s, loss=1.86, accuracy=0.297][A
Epoch 3:  14%|█▎        | 53/390 [00:35<03:24,  1.65it/s, loss=1.76, accuracy=0.344][A
Epoch 3:  14%|█▍        | 54/390 [00:36<03:24,  1.65it/s, loss=1.76, accuracy=0.344][A
Epoch 3:  14%|█▍        | 54/390 [00:36<03:24,  1.65it/s, loss=1.97, accuracy=0.281][A
Epoch 3:  14%|█▍        | 55/390 [00:36<03:24,  1.64it/s, loss=1.97, accuracy=0.281][A
Epoch 3:  14%|█▍        | 55/390 [00:36<03:24,  1.64it/s, loss=1.76, accuracy=0.305][A
Epoch 3:  14%|█▍        | 56/390

Epoch 3:  37%|███▋      | 143/390 [01:38<02:58,  1.38it/s, loss=1.86, accuracy=0.305][A
Epoch 3:  37%|███▋      | 143/390 [01:38<02:58,  1.38it/s, loss=1.77, accuracy=0.336][A
Epoch 3:  37%|███▋      | 144/390 [01:38<02:58,  1.37it/s, loss=1.77, accuracy=0.336][A
Epoch 3:  37%|███▋      | 144/390 [01:38<02:58,  1.37it/s, loss=1.92, accuracy=0.25] [A
Epoch 3:  37%|███▋      | 145/390 [01:39<02:57,  1.38it/s, loss=1.92, accuracy=0.25][A
Epoch 3:  37%|███▋      | 145/390 [01:39<02:57,  1.38it/s, loss=1.89, accuracy=0.312][A
Epoch 3:  37%|███▋      | 146/390 [01:40<02:55,  1.39it/s, loss=1.89, accuracy=0.312][A
Epoch 3:  37%|███▋      | 146/390 [01:40<02:55,  1.39it/s, loss=1.82, accuracy=0.289][A
Epoch 3:  38%|███▊      | 147/390 [01:41<02:54,  1.39it/s, loss=1.82, accuracy=0.289][A
Epoch 3:  38%|███▊      | 147/390 [01:41<02:54,  1.39it/s, loss=1.86, accuracy=0.336][A
Epoch 3:  38%|███▊      | 148/390 [01:41<02:54,  1.39it/s, loss=1.86, accuracy=0.336][A
Epoch 3:  38%|███▊    

Epoch 3:  60%|██████    | 235/390 [02:43<01:51,  1.39it/s, loss=1.8, accuracy=0.312][A
Epoch 3:  60%|██████    | 235/390 [02:43<01:51,  1.39it/s, loss=1.85, accuracy=0.398][A
Epoch 3:  61%|██████    | 236/390 [02:44<01:49,  1.40it/s, loss=1.85, accuracy=0.398][A
Epoch 3:  61%|██████    | 236/390 [02:44<01:49,  1.40it/s, loss=1.86, accuracy=0.273][A
Epoch 3:  61%|██████    | 237/390 [02:44<01:48,  1.41it/s, loss=1.86, accuracy=0.273][A
Epoch 3:  61%|██████    | 237/390 [02:44<01:48,  1.41it/s, loss=1.83, accuracy=0.312][A
Epoch 3:  61%|██████    | 238/390 [02:45<01:47,  1.41it/s, loss=1.83, accuracy=0.312][A
Epoch 3:  61%|██████    | 238/390 [02:45<01:47,  1.41it/s, loss=1.8, accuracy=0.273] [A
Epoch 3:  61%|██████▏   | 239/390 [02:46<01:47,  1.40it/s, loss=1.8, accuracy=0.273][A
Epoch 3:  61%|██████▏   | 239/390 [02:46<01:47,  1.40it/s, loss=1.88, accuracy=0.305][A
Epoch 3:  62%|██████▏   | 240/390 [02:46<01:47,  1.40it/s, loss=1.88, accuracy=0.305][A
Epoch 3:  62%|██████▏  

Epoch 3:  84%|████████▍ | 327/390 [03:48<00:42,  1.47it/s, loss=1.76, accuracy=0.305][A
Epoch 3:  84%|████████▍ | 327/390 [03:48<00:42,  1.47it/s, loss=1.71, accuracy=0.359][A
Epoch 3:  84%|████████▍ | 328/390 [03:48<00:42,  1.46it/s, loss=1.71, accuracy=0.359][A
Epoch 3:  84%|████████▍ | 328/390 [03:48<00:42,  1.46it/s, loss=1.86, accuracy=0.289][A
Epoch 3:  84%|████████▍ | 329/390 [03:49<00:41,  1.45it/s, loss=1.86, accuracy=0.289][A
Epoch 3:  84%|████████▍ | 329/390 [03:49<00:41,  1.45it/s, loss=1.87, accuracy=0.305][A
Epoch 3:  85%|████████▍ | 330/390 [03:50<00:41,  1.46it/s, loss=1.87, accuracy=0.305][A
Epoch 3:  85%|████████▍ | 330/390 [03:50<00:41,  1.46it/s, loss=1.84, accuracy=0.344][A
Epoch 3:  85%|████████▍ | 331/390 [03:50<00:40,  1.46it/s, loss=1.84, accuracy=0.344][A
Epoch 3:  85%|████████▍ | 331/390 [03:50<00:40,  1.46it/s, loss=1.83, accuracy=0.297][A
Epoch 3:  85%|████████▌ | 332/390 [03:51<00:39,  1.46it/s, loss=1.83, accuracy=0.297][A
Epoch 3:  85%|███████

Epoch 4:   7%|▋         | 29/390 [00:21<04:18,  1.39it/s, loss=1.7, accuracy=0.352][A
Epoch 4:   7%|▋         | 29/390 [00:21<04:18,  1.39it/s, loss=1.77, accuracy=0.32][A
Epoch 4:   8%|▊         | 30/390 [00:21<04:15,  1.41it/s, loss=1.77, accuracy=0.32][A
Epoch 4:   8%|▊         | 30/390 [00:21<04:15,  1.41it/s, loss=1.69, accuracy=0.305][A
Epoch 4:   8%|▊         | 31/390 [00:22<04:12,  1.42it/s, loss=1.69, accuracy=0.305][A
Epoch 4:   8%|▊         | 31/390 [00:22<04:12,  1.42it/s, loss=1.85, accuracy=0.281][A
Epoch 4:   8%|▊         | 32/390 [00:23<04:10,  1.43it/s, loss=1.85, accuracy=0.281][A
Epoch 4:   8%|▊         | 32/390 [00:23<04:10,  1.43it/s, loss=1.76, accuracy=0.32] [A
Epoch 4:   8%|▊         | 33/390 [00:23<04:11,  1.42it/s, loss=1.76, accuracy=0.32][A
Epoch 4:   8%|▊         | 33/390 [00:23<04:11,  1.42it/s, loss=1.86, accuracy=0.258][A
Epoch 4:   9%|▊         | 34/390 [00:24<04:11,  1.41it/s, loss=1.86, accuracy=0.258][A
Epoch 4:   9%|▊         | 34/390 [00

Epoch 4:  31%|███       | 121/390 [01:25<03:13,  1.39it/s, loss=1.79, accuracy=0.312][A
Epoch 4:  31%|███▏      | 122/390 [01:25<03:13,  1.39it/s, loss=1.79, accuracy=0.312][A
Epoch 4:  31%|███▏      | 122/390 [01:25<03:13,  1.39it/s, loss=1.9, accuracy=0.328] [A
Epoch 4:  32%|███▏      | 123/390 [01:26<03:10,  1.40it/s, loss=1.9, accuracy=0.328][A
Epoch 4:  32%|███▏      | 123/390 [01:26<03:10,  1.40it/s, loss=1.79, accuracy=0.32][A
Epoch 4:  32%|███▏      | 124/390 [01:27<03:09,  1.40it/s, loss=1.79, accuracy=0.32][A
Epoch 4:  32%|███▏      | 124/390 [01:27<03:09,  1.40it/s, loss=1.77, accuracy=0.32][A
Epoch 4:  32%|███▏      | 125/390 [01:28<03:10,  1.39it/s, loss=1.77, accuracy=0.32][A
Epoch 4:  32%|███▏      | 125/390 [01:28<03:10,  1.39it/s, loss=1.92, accuracy=0.266][A
Epoch 4:  32%|███▏      | 126/390 [01:28<03:08,  1.40it/s, loss=1.92, accuracy=0.266][A
Epoch 4:  32%|███▏      | 126/390 [01:28<03:08,  1.40it/s, loss=1.76, accuracy=0.32] [A
Epoch 4:  33%|███▎      | 

Epoch 4:  55%|█████▍    | 213/390 [02:30<02:03,  1.43it/s, loss=1.8, accuracy=0.281] [A
Epoch 4:  55%|█████▍    | 214/390 [02:31<02:03,  1.42it/s, loss=1.8, accuracy=0.281][A
Epoch 4:  55%|█████▍    | 214/390 [02:31<02:03,  1.42it/s, loss=1.74, accuracy=0.312][A
Epoch 4:  55%|█████▌    | 215/390 [02:31<02:04,  1.41it/s, loss=1.74, accuracy=0.312][A
Epoch 4:  55%|█████▌    | 215/390 [02:31<02:04,  1.41it/s, loss=1.85, accuracy=0.289][A
Epoch 4:  55%|█████▌    | 216/390 [02:32<02:04,  1.39it/s, loss=1.85, accuracy=0.289][A
Epoch 4:  55%|█████▌    | 216/390 [02:32<02:04,  1.39it/s, loss=1.8, accuracy=0.32]  [A
Epoch 4:  56%|█████▌    | 217/390 [02:33<02:04,  1.39it/s, loss=1.8, accuracy=0.32][A
Epoch 4:  56%|█████▌    | 217/390 [02:33<02:04,  1.39it/s, loss=1.69, accuracy=0.32][A
Epoch 4:  56%|█████▌    | 218/390 [02:33<02:03,  1.39it/s, loss=1.69, accuracy=0.32][A
Epoch 4:  56%|█████▌    | 218/390 [02:33<02:03,  1.39it/s, loss=1.85, accuracy=0.383][A
Epoch 4:  56%|█████▌    | 

Epoch 4:  78%|███████▊  | 305/390 [03:35<01:00,  1.40it/s, loss=1.69, accuracy=0.352][A
Epoch 4:  78%|███████▊  | 306/390 [03:36<01:00,  1.40it/s, loss=1.69, accuracy=0.352][A
Epoch 4:  78%|███████▊  | 306/390 [03:36<01:00,  1.40it/s, loss=1.65, accuracy=0.43] [A
Epoch 4:  79%|███████▊  | 307/390 [03:37<00:59,  1.39it/s, loss=1.65, accuracy=0.43][A
Epoch 4:  79%|███████▊  | 307/390 [03:37<00:59,  1.39it/s, loss=1.7, accuracy=0.336][A
Epoch 4:  79%|███████▉  | 308/390 [03:38<00:58,  1.39it/s, loss=1.7, accuracy=0.336][A
Epoch 4:  79%|███████▉  | 308/390 [03:38<00:58,  1.39it/s, loss=1.75, accuracy=0.344][A
Epoch 4:  79%|███████▉  | 309/390 [03:38<00:57,  1.40it/s, loss=1.75, accuracy=0.344][A
Epoch 4:  79%|███████▉  | 309/390 [03:38<00:57,  1.40it/s, loss=1.82, accuracy=0.289][A
Epoch 4:  79%|███████▉  | 310/390 [03:39<00:57,  1.40it/s, loss=1.82, accuracy=0.289][A
Epoch 4:  79%|███████▉  | 310/390 [03:39<00:57,  1.40it/s, loss=1.86, accuracy=0.281][A
Epoch 4:  80%|███████▉  

Epoch 5:   2%|▏         | 7/390 [00:05<05:07,  1.25it/s, loss=1.77, accuracy=0.289][A
Epoch 5:   2%|▏         | 7/390 [00:05<05:07,  1.25it/s, loss=1.79, accuracy=0.336][A
Epoch 5:   2%|▏         | 8/390 [00:06<04:55,  1.29it/s, loss=1.79, accuracy=0.336][A
Epoch 5:   2%|▏         | 8/390 [00:06<04:55,  1.29it/s, loss=1.7, accuracy=0.328] [A
Epoch 5:   2%|▏         | 9/390 [00:07<04:46,  1.33it/s, loss=1.7, accuracy=0.328][A
Epoch 5:   2%|▏         | 9/390 [00:07<04:46,  1.33it/s, loss=1.59, accuracy=0.352][A
Epoch 5:   3%|▎         | 10/390 [00:07<04:39,  1.36it/s, loss=1.59, accuracy=0.352][A
Epoch 5:   3%|▎         | 10/390 [00:07<04:39,  1.36it/s, loss=1.74, accuracy=0.352][A
Epoch 5:   3%|▎         | 11/390 [00:08<04:33,  1.38it/s, loss=1.74, accuracy=0.352][A
Epoch 5:   3%|▎         | 11/390 [00:08<04:33,  1.38it/s, loss=1.77, accuracy=0.297][A
Epoch 5:   3%|▎         | 12/390 [00:09<04:32,  1.39it/s, loss=1.77, accuracy=0.297][A
Epoch 5:   3%|▎         | 12/390 [00:09

Epoch 5:  26%|██▌       | 100/390 [01:06<02:57,  1.64it/s, loss=1.81, accuracy=0.297][A
Epoch 5:  26%|██▌       | 100/390 [01:06<02:57,  1.64it/s, loss=1.89, accuracy=0.273][A
Epoch 5:  26%|██▌       | 101/390 [01:07<02:56,  1.64it/s, loss=1.89, accuracy=0.273][A
Epoch 5:  26%|██▌       | 101/390 [01:07<02:56,  1.64it/s, loss=1.66, accuracy=0.359][A
Epoch 5:  26%|██▌       | 102/390 [01:08<02:55,  1.64it/s, loss=1.66, accuracy=0.359][A
Epoch 5:  26%|██▌       | 102/390 [01:08<02:55,  1.64it/s, loss=1.81, accuracy=0.32] [A
Epoch 5:  26%|██▋       | 103/390 [01:08<02:55,  1.64it/s, loss=1.81, accuracy=0.32][A
Epoch 5:  26%|██▋       | 103/390 [01:08<02:55,  1.64it/s, loss=1.72, accuracy=0.297][A
Epoch 5:  27%|██▋       | 104/390 [01:09<02:54,  1.64it/s, loss=1.72, accuracy=0.297][A
Epoch 5:  27%|██▋       | 104/390 [01:09<02:54,  1.64it/s, loss=1.72, accuracy=0.32] [A
Epoch 5:  27%|██▋       | 105/390 [01:09<02:53,  1.64it/s, loss=1.72, accuracy=0.32][A
Epoch 5:  27%|██▋      

Epoch 5:  49%|████▉     | 192/390 [02:02<02:00,  1.64it/s, loss=1.64, accuracy=0.383][A
Epoch 5:  49%|████▉     | 192/390 [02:02<02:00,  1.64it/s, loss=1.57, accuracy=0.375][A
Epoch 5:  49%|████▉     | 193/390 [02:03<02:00,  1.64it/s, loss=1.57, accuracy=0.375][A
Epoch 5:  49%|████▉     | 193/390 [02:03<02:00,  1.64it/s, loss=1.56, accuracy=0.391][A
Epoch 5:  50%|████▉     | 194/390 [02:03<01:59,  1.64it/s, loss=1.56, accuracy=0.391][A
Epoch 5:  50%|████▉     | 194/390 [02:03<01:59,  1.64it/s, loss=1.8, accuracy=0.289] [A
Epoch 5:  50%|█████     | 195/390 [02:04<01:59,  1.64it/s, loss=1.8, accuracy=0.289][A
Epoch 5:  50%|█████     | 195/390 [02:04<01:59,  1.64it/s, loss=1.73, accuracy=0.32][A
Epoch 5:  50%|█████     | 196/390 [02:05<01:58,  1.64it/s, loss=1.73, accuracy=0.32][A
Epoch 5:  50%|█████     | 196/390 [02:05<01:58,  1.64it/s, loss=1.58, accuracy=0.391][A
Epoch 5:  51%|█████     | 197/390 [02:05<01:57,  1.64it/s, loss=1.58, accuracy=0.391][A
Epoch 5:  51%|█████     

Epoch 5:  73%|███████▎  | 284/390 [02:58<01:04,  1.65it/s, loss=1.69, accuracy=0.367][A
Epoch 5:  73%|███████▎  | 284/390 [02:58<01:04,  1.65it/s, loss=1.68, accuracy=0.32] [A
Epoch 5:  73%|███████▎  | 285/390 [02:58<01:03,  1.65it/s, loss=1.68, accuracy=0.32][A
Epoch 5:  73%|███████▎  | 285/390 [02:58<01:03,  1.65it/s, loss=1.59, accuracy=0.406][A
Epoch 5:  73%|███████▎  | 286/390 [02:59<01:02,  1.66it/s, loss=1.59, accuracy=0.406][A
Epoch 5:  73%|███████▎  | 286/390 [02:59<01:02,  1.66it/s, loss=1.64, accuracy=0.367][A
Epoch 5:  74%|███████▎  | 287/390 [03:00<01:02,  1.65it/s, loss=1.64, accuracy=0.367][A
Epoch 5:  74%|███████▎  | 287/390 [03:00<01:02,  1.65it/s, loss=1.65, accuracy=0.383][A
Epoch 5:  74%|███████▍  | 288/390 [03:00<01:01,  1.65it/s, loss=1.65, accuracy=0.383][A
Epoch 5:  74%|███████▍  | 288/390 [03:00<01:01,  1.65it/s, loss=1.61, accuracy=0.398][A
Epoch 5:  74%|███████▍  | 289/390 [03:01<01:01,  1.65it/s, loss=1.61, accuracy=0.398][A
Epoch 5:  74%|███████▍

Epoch 5:  96%|█████████▋| 376/390 [03:53<00:08,  1.66it/s, loss=1.86, accuracy=0.328][A
Epoch 5:  96%|█████████▋| 376/390 [03:53<00:08,  1.66it/s, loss=1.83, accuracy=0.344][A
Epoch 5:  97%|█████████▋| 377/390 [03:54<00:07,  1.67it/s, loss=1.83, accuracy=0.344][A
Epoch 5:  97%|█████████▋| 377/390 [03:54<00:07,  1.67it/s, loss=1.65, accuracy=0.305][A
Epoch 5:  97%|█████████▋| 378/390 [03:55<00:07,  1.66it/s, loss=1.65, accuracy=0.305][A
Epoch 5:  97%|█████████▋| 378/390 [03:55<00:07,  1.66it/s, loss=1.61, accuracy=0.344][A
Epoch 5:  97%|█████████▋| 379/390 [03:55<00:06,  1.66it/s, loss=1.61, accuracy=0.344][A
Epoch 5:  97%|█████████▋| 379/390 [03:55<00:06,  1.66it/s, loss=1.57, accuracy=0.383][A
Epoch 5:  97%|█████████▋| 380/390 [03:56<00:06,  1.65it/s, loss=1.57, accuracy=0.383][A
Epoch 5:  97%|█████████▋| 380/390 [03:56<00:06,  1.65it/s, loss=1.66, accuracy=0.445][A
Epoch 5:  98%|█████████▊| 381/390 [03:56<00:05,  1.65it/s, loss=1.66, accuracy=0.445][A
Epoch 5:  98%|███████

Finished.
final validation accuracy:





# TRAIN MODEL USING SAT

In [61]:
def adversarial_training_entropy(model, optimiser, loss_fn, x, y, epoch, adversary, k, step, eps, norm, gamma):
    """Performs a single update against a specified adversary"""
    model.train()
    
    # Adversial perturbation
    #alpha = 0.8
    N = 1
    loss = 0
    for l in range(N):
        x_adv = adversary(model, x, y, loss_fn, k=k, step=step, eps=eps, norm=norm, random=True, gamma=gamma)
        
        optimiser.zero_grad()
        y_pred = model(x_adv)
        loss = loss + loss_fn(y_pred,y)
        #loss = (1-alpha)*loss + alpha*loss_fn(y_pred, y)
    loss = loss/N
    loss.backward()
    optimiser.step()
    
    return loss, y_pred

In [27]:
if TrainSAT2:
    ## initialize model
    #model_SAT2 = model_cnn().to(DEVICE)
    model_SAT2 = Net().to(DEVICE)
    ## train params
    lr = 0.1
    optimiser = optim.SGD(model_SAT2.parameters(), lr=lr)
    epochs = 5
    ## train model
    training_history_entropySmoothing = olympic.fit(
        model_SAT2,
        optimiser,
        nn.CrossEntropyLoss(),
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=adversarial_training_entropy,
        update_fn_kwargs={'adversary': entropySmoothing, 'k': 10, 'step': 0.01, 'eps': 3.0, 'norm': 2, 'gamma':1e-5},
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5, factor=0.5, min_delta=0.005, monitor='val_accuracy')
        ]
    )
    ## verify validation
    print('final validation accuracy:')
    olympic.evaluate(model_SAT2, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    ## save model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_SAT2_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_SAT2,modelname)


Epoch 1:   0%|          | 0/390 [00:00<?, ?it/s][A

Begin training...



Epoch 1:   0%|          | 1/390 [00:01<08:10,  1.26s/it][A
Epoch 1:   0%|          | 1/390 [00:01<08:10,  1.26s/it, loss=2.85, accuracy=0][A
Epoch 1:   1%|          | 2/390 [00:01<06:46,  1.05s/it, loss=2.85, accuracy=0][A
Epoch 1:   1%|          | 2/390 [00:01<06:46,  1.05s/it, loss=3.57, accuracy=0.102][A
Epoch 1:   1%|          | 3/390 [00:02<05:48,  1.11it/s, loss=3.57, accuracy=0.102][A
Epoch 1:   1%|          | 3/390 [00:02<05:48,  1.11it/s, loss=4, accuracy=0.0859]  [A
Epoch 1:   1%|          | 4/390 [00:02<05:08,  1.25it/s, loss=4, accuracy=0.0859][A
Epoch 1:   1%|          | 4/390 [00:02<05:08,  1.25it/s, loss=4.12, accuracy=0.125][A
Epoch 1:   1%|▏         | 5/390 [00:03<04:41,  1.37it/s, loss=4.12, accuracy=0.125][A
Epoch 1:   1%|▏         | 5/390 [00:03<04:41,  1.37it/s, loss=3.69, accuracy=0.156][A
Epoch 1:   2%|▏         | 6/390 [00:04<04:20,  1.47it/s, loss=3.69, accuracy=0.156][A
Epoch 1:   2%|▏         | 6/390 [00:04<04:20,  1.47it/s, loss=2.78, accuracy=0.

Epoch 1:  24%|██▍       | 94/390 [00:53<02:48,  1.76it/s, loss=1.91, accuracy=0.312][A
Epoch 1:  24%|██▍       | 94/390 [00:53<02:48,  1.76it/s, loss=1.78, accuracy=0.352][A
Epoch 1:  24%|██▍       | 95/390 [00:54<02:47,  1.76it/s, loss=1.78, accuracy=0.352][A
Epoch 1:  24%|██▍       | 95/390 [00:54<02:47,  1.76it/s, loss=1.74, accuracy=0.359][A
Epoch 1:  25%|██▍       | 96/390 [00:54<02:47,  1.76it/s, loss=1.74, accuracy=0.359][A
Epoch 1:  25%|██▍       | 96/390 [00:54<02:47,  1.76it/s, loss=1.67, accuracy=0.391][A
Epoch 1:  25%|██▍       | 97/390 [00:55<02:46,  1.76it/s, loss=1.67, accuracy=0.391][A
Epoch 1:  25%|██▍       | 97/390 [00:55<02:46,  1.76it/s, loss=1.8, accuracy=0.375] [A
Epoch 1:  25%|██▌       | 98/390 [00:55<02:46,  1.76it/s, loss=1.8, accuracy=0.375][A
Epoch 1:  25%|██▌       | 98/390 [00:55<02:46,  1.76it/s, loss=1.94, accuracy=0.312][A
Epoch 1:  25%|██▌       | 99/390 [00:56<02:45,  1.76it/s, loss=1.94, accuracy=0.312][A
Epoch 1:  25%|██▌       | 99/390 

Epoch 1:  48%|████▊     | 186/390 [01:47<02:05,  1.63it/s, loss=1.49, accuracy=0.492][A
Epoch 1:  48%|████▊     | 186/390 [01:47<02:05,  1.63it/s, loss=1.49, accuracy=0.516][A
Epoch 1:  48%|████▊     | 187/390 [01:47<02:03,  1.64it/s, loss=1.49, accuracy=0.516][A
Epoch 1:  48%|████▊     | 187/390 [01:47<02:03,  1.64it/s, loss=1.61, accuracy=0.422][A
Epoch 1:  48%|████▊     | 188/390 [01:48<02:03,  1.63it/s, loss=1.61, accuracy=0.422][A
Epoch 1:  48%|████▊     | 188/390 [01:48<02:03,  1.63it/s, loss=1.65, accuracy=0.398][A
Epoch 1:  48%|████▊     | 189/390 [01:48<02:03,  1.63it/s, loss=1.65, accuracy=0.398][A
Epoch 1:  48%|████▊     | 189/390 [01:48<02:03,  1.63it/s, loss=1.48, accuracy=0.422][A
Epoch 1:  49%|████▊     | 190/390 [01:49<02:01,  1.64it/s, loss=1.48, accuracy=0.422][A
Epoch 1:  49%|████▊     | 190/390 [01:49<02:01,  1.64it/s, loss=1.68, accuracy=0.398][A
Epoch 1:  49%|████▉     | 191/390 [01:50<02:00,  1.65it/s, loss=1.68, accuracy=0.398][A
Epoch 1:  49%|████▉  

Epoch 1:  71%|███████▏  | 278/390 [02:44<01:05,  1.70it/s, loss=1.49, accuracy=0.391][A
Epoch 1:  71%|███████▏  | 278/390 [02:44<01:05,  1.70it/s, loss=1.37, accuracy=0.477][A
Epoch 1:  72%|███████▏  | 279/390 [02:44<01:05,  1.70it/s, loss=1.37, accuracy=0.477][A
Epoch 1:  72%|███████▏  | 279/390 [02:44<01:05,  1.70it/s, loss=1.63, accuracy=0.406][A
Epoch 1:  72%|███████▏  | 280/390 [02:45<01:05,  1.69it/s, loss=1.63, accuracy=0.406][A
Epoch 1:  72%|███████▏  | 280/390 [02:45<01:05,  1.69it/s, loss=1.47, accuracy=0.445][A
Epoch 1:  72%|███████▏  | 281/390 [02:46<01:05,  1.67it/s, loss=1.47, accuracy=0.445][A
Epoch 1:  72%|███████▏  | 281/390 [02:46<01:05,  1.67it/s, loss=1.34, accuracy=0.539][A
Epoch 1:  72%|███████▏  | 282/390 [02:46<01:05,  1.65it/s, loss=1.34, accuracy=0.539][A
Epoch 1:  72%|███████▏  | 282/390 [02:46<01:05,  1.65it/s, loss=1.61, accuracy=0.453][A
Epoch 1:  73%|███████▎  | 283/390 [02:47<01:05,  1.63it/s, loss=1.61, accuracy=0.453][A
Epoch 1:  73%|███████

Epoch 1:  95%|█████████▍| 370/390 [03:39<00:11,  1.69it/s, loss=1.43, accuracy=0.469][A
Epoch 1:  95%|█████████▍| 370/390 [03:39<00:11,  1.69it/s, loss=1.28, accuracy=0.562][A
Epoch 1:  95%|█████████▌| 371/390 [03:40<00:11,  1.69it/s, loss=1.28, accuracy=0.562][A
Epoch 1:  95%|█████████▌| 371/390 [03:40<00:11,  1.69it/s, loss=1.32, accuracy=0.5]  [A
Epoch 1:  95%|█████████▌| 372/390 [03:41<00:10,  1.69it/s, loss=1.32, accuracy=0.5][A
Epoch 1:  95%|█████████▌| 372/390 [03:41<00:10,  1.69it/s, loss=1.33, accuracy=0.508][A
Epoch 1:  96%|█████████▌| 373/390 [03:41<00:10,  1.69it/s, loss=1.33, accuracy=0.508][A
Epoch 1:  96%|█████████▌| 373/390 [03:41<00:10,  1.69it/s, loss=1.34, accuracy=0.484][A
Epoch 1:  96%|█████████▌| 374/390 [03:42<00:09,  1.69it/s, loss=1.34, accuracy=0.484][A
Epoch 1:  96%|█████████▌| 374/390 [03:42<00:09,  1.69it/s, loss=1.46, accuracy=0.438][A
Epoch 1:  96%|█████████▌| 375/390 [03:42<00:08,  1.69it/s, loss=1.46, accuracy=0.438][A
Epoch 1:  96%|█████████

Epoch 2:  18%|█▊        | 72/390 [00:43<03:11,  1.66it/s, loss=1.28, accuracy=0.523][A
Epoch 2:  19%|█▊        | 73/390 [00:44<03:11,  1.66it/s, loss=1.28, accuracy=0.523][A
Epoch 2:  19%|█▊        | 73/390 [00:44<03:11,  1.66it/s, loss=1.16, accuracy=0.594][A
Epoch 2:  19%|█▉        | 74/390 [00:44<03:10,  1.66it/s, loss=1.16, accuracy=0.594][A
Epoch 2:  19%|█▉        | 74/390 [00:44<03:10,  1.66it/s, loss=1.36, accuracy=0.523][A
Epoch 2:  19%|█▉        | 75/390 [00:45<03:09,  1.66it/s, loss=1.36, accuracy=0.523][A
Epoch 2:  19%|█▉        | 75/390 [00:45<03:09,  1.66it/s, loss=1.18, accuracy=0.555][A
Epoch 2:  19%|█▉        | 76/390 [00:46<03:08,  1.66it/s, loss=1.18, accuracy=0.555][A
Epoch 2:  19%|█▉        | 76/390 [00:46<03:08,  1.66it/s, loss=1.42, accuracy=0.469][A
Epoch 2:  20%|█▉        | 77/390 [00:46<03:08,  1.66it/s, loss=1.42, accuracy=0.469][A
Epoch 2:  20%|█▉        | 77/390 [00:46<03:08,  1.66it/s, loss=1.37, accuracy=0.516][A
Epoch 2:  20%|██        | 78/390

Epoch 2:  42%|████▏     | 164/390 [01:39<02:12,  1.71it/s, loss=1.16, accuracy=0.547][A
Epoch 2:  42%|████▏     | 165/390 [01:39<02:12,  1.69it/s, loss=1.16, accuracy=0.547][A
Epoch 2:  42%|████▏     | 165/390 [01:39<02:12,  1.69it/s, loss=1.41, accuracy=0.508][A
Epoch 2:  43%|████▎     | 166/390 [01:40<02:12,  1.69it/s, loss=1.41, accuracy=0.508][A
Epoch 2:  43%|████▎     | 166/390 [01:40<02:12,  1.69it/s, loss=1.32, accuracy=0.523][A
Epoch 2:  43%|████▎     | 167/390 [01:40<02:13,  1.68it/s, loss=1.32, accuracy=0.523][A
Epoch 2:  43%|████▎     | 167/390 [01:40<02:13,  1.68it/s, loss=1.42, accuracy=0.469][A
Epoch 2:  43%|████▎     | 168/390 [01:41<02:13,  1.67it/s, loss=1.42, accuracy=0.469][A
Epoch 2:  43%|████▎     | 168/390 [01:41<02:13,  1.67it/s, loss=1.34, accuracy=0.516][A
Epoch 2:  43%|████▎     | 169/390 [01:42<02:12,  1.66it/s, loss=1.34, accuracy=0.516][A
Epoch 2:  43%|████▎     | 169/390 [01:42<02:12,  1.66it/s, loss=1.22, accuracy=0.555][A
Epoch 2:  44%|████▎  

Epoch 2:  66%|██████▌   | 256/390 [02:35<01:22,  1.62it/s, loss=1.2, accuracy=0.523] [A
Epoch 2:  66%|██████▌   | 257/390 [02:36<01:22,  1.62it/s, loss=1.2, accuracy=0.523][A
Epoch 2:  66%|██████▌   | 257/390 [02:36<01:22,  1.62it/s, loss=1.33, accuracy=0.555][A
Epoch 2:  66%|██████▌   | 258/390 [02:36<01:21,  1.62it/s, loss=1.33, accuracy=0.555][A
Epoch 2:  66%|██████▌   | 258/390 [02:36<01:21,  1.62it/s, loss=1.29, accuracy=0.516][A
Epoch 2:  66%|██████▋   | 259/390 [02:37<01:20,  1.63it/s, loss=1.29, accuracy=0.516][A
Epoch 2:  66%|██████▋   | 259/390 [02:37<01:20,  1.63it/s, loss=1.21, accuracy=0.562][A
Epoch 2:  67%|██████▋   | 260/390 [02:37<01:19,  1.64it/s, loss=1.21, accuracy=0.562][A
Epoch 2:  67%|██████▋   | 260/390 [02:37<01:19,  1.64it/s, loss=1.06, accuracy=0.617][A
Epoch 2:  67%|██████▋   | 261/390 [02:38<01:18,  1.64it/s, loss=1.06, accuracy=0.617][A
Epoch 2:  67%|██████▋   | 261/390 [02:38<01:18,  1.64it/s, loss=1.3, accuracy=0.477] [A
Epoch 2:  67%|██████▋ 

Epoch 2:  89%|████████▉ | 348/390 [03:32<00:25,  1.64it/s, loss=1.06, accuracy=0.602][A
Epoch 2:  89%|████████▉ | 349/390 [03:32<00:25,  1.62it/s, loss=1.06, accuracy=0.602][A
Epoch 2:  89%|████████▉ | 349/390 [03:32<00:25,  1.62it/s, loss=1.03, accuracy=0.641][A
Epoch 2:  90%|████████▉ | 350/390 [03:33<00:24,  1.61it/s, loss=1.03, accuracy=0.641][A
Epoch 2:  90%|████████▉ | 350/390 [03:33<00:24,  1.61it/s, loss=1.12, accuracy=0.586][A
Epoch 2:  90%|█████████ | 351/390 [03:33<00:24,  1.61it/s, loss=1.12, accuracy=0.586][A
Epoch 2:  90%|█████████ | 351/390 [03:33<00:24,  1.61it/s, loss=1.09, accuracy=0.594][A
Epoch 2:  90%|█████████ | 352/390 [03:34<00:23,  1.61it/s, loss=1.09, accuracy=0.594][A
Epoch 2:  90%|█████████ | 352/390 [03:34<00:23,  1.61it/s, loss=0.897, accuracy=0.68][A
Epoch 2:  91%|█████████ | 353/390 [03:35<00:23,  1.60it/s, loss=0.897, accuracy=0.68][A
Epoch 2:  91%|█████████ | 353/390 [03:35<00:23,  1.60it/s, loss=1.16, accuracy=0.617][A
Epoch 2:  91%|███████

Epoch 3:  13%|█▎        | 50/390 [00:31<03:31,  1.61it/s, loss=1.02, accuracy=0.688][A
Epoch 3:  13%|█▎        | 51/390 [00:32<03:31,  1.61it/s, loss=1.02, accuracy=0.688][A
Epoch 3:  13%|█▎        | 51/390 [00:32<03:31,  1.61it/s, loss=1.06, accuracy=0.609][A
Epoch 3:  13%|█▎        | 52/390 [00:32<03:30,  1.60it/s, loss=1.06, accuracy=0.609][A
Epoch 3:  13%|█▎        | 52/390 [00:32<03:30,  1.60it/s, loss=1.05, accuracy=0.617][A
Epoch 3:  14%|█▎        | 53/390 [00:33<03:29,  1.61it/s, loss=1.05, accuracy=0.617][A
Epoch 3:  14%|█▎        | 53/390 [00:33<03:29,  1.61it/s, loss=0.999, accuracy=0.617][A
Epoch 3:  14%|█▍        | 54/390 [00:33<03:26,  1.62it/s, loss=0.999, accuracy=0.617][A
Epoch 3:  14%|█▍        | 54/390 [00:33<03:26,  1.62it/s, loss=1.27, accuracy=0.57]  [A
Epoch 3:  14%|█▍        | 55/390 [00:34<03:24,  1.63it/s, loss=1.27, accuracy=0.57][A
Epoch 3:  14%|█▍        | 55/390 [00:34<03:24,  1.63it/s, loss=0.87, accuracy=0.719][A
Epoch 3:  14%|█▍        | 56/3

Epoch 3:  36%|███▋      | 142/390 [01:28<02:34,  1.61it/s, loss=0.995, accuracy=0.648][A
Epoch 3:  36%|███▋      | 142/390 [01:28<02:34,  1.61it/s, loss=1.12, accuracy=0.602] [A
Epoch 3:  37%|███▋      | 143/390 [01:29<02:33,  1.60it/s, loss=1.12, accuracy=0.602][A
Epoch 3:  37%|███▋      | 143/390 [01:29<02:33,  1.60it/s, loss=1.06, accuracy=0.609][A
Epoch 3:  37%|███▋      | 144/390 [01:29<02:33,  1.61it/s, loss=1.06, accuracy=0.609][A
Epoch 3:  37%|███▋      | 144/390 [01:29<02:33,  1.61it/s, loss=0.965, accuracy=0.656][A
Epoch 3:  37%|███▋      | 145/390 [01:30<02:32,  1.60it/s, loss=0.965, accuracy=0.656][A
Epoch 3:  37%|███▋      | 145/390 [01:30<02:32,  1.60it/s, loss=1.01, accuracy=0.633] [A
Epoch 3:  37%|███▋      | 146/390 [01:30<02:31,  1.61it/s, loss=1.01, accuracy=0.633][A
Epoch 3:  37%|███▋      | 146/390 [01:30<02:31,  1.61it/s, loss=0.973, accuracy=0.695][A
Epoch 3:  38%|███▊      | 147/390 [01:31<02:30,  1.61it/s, loss=0.973, accuracy=0.695][A
Epoch 3:  38%|

Epoch 3:  60%|█████▉    | 233/390 [02:24<01:38,  1.59it/s, loss=0.952, accuracy=0.633][A
Epoch 3:  60%|█████▉    | 233/390 [02:24<01:38,  1.59it/s, loss=0.879, accuracy=0.656][A
Epoch 3:  60%|██████    | 234/390 [02:25<01:38,  1.58it/s, loss=0.879, accuracy=0.656][A
Epoch 3:  60%|██████    | 234/390 [02:25<01:38,  1.58it/s, loss=0.952, accuracy=0.625][A
Epoch 3:  60%|██████    | 235/390 [02:25<01:40,  1.54it/s, loss=0.952, accuracy=0.625][A
Epoch 3:  60%|██████    | 235/390 [02:25<01:40,  1.54it/s, loss=0.833, accuracy=0.695][A
Epoch 3:  61%|██████    | 236/390 [02:26<01:41,  1.52it/s, loss=0.833, accuracy=0.695][A
Epoch 3:  61%|██████    | 236/390 [02:26<01:41,  1.52it/s, loss=0.901, accuracy=0.641][A
Epoch 3:  61%|██████    | 237/390 [02:27<01:44,  1.47it/s, loss=0.901, accuracy=0.641][A
Epoch 3:  61%|██████    | 237/390 [02:27<01:44,  1.47it/s, loss=0.766, accuracy=0.703][A
Epoch 3:  61%|██████    | 238/390 [02:27<01:41,  1.50it/s, loss=0.766, accuracy=0.703][A
Epoch 3:  

Epoch 3:  83%|████████▎ | 324/390 [03:21<00:39,  1.67it/s, loss=0.912, accuracy=0.664][A
Epoch 3:  83%|████████▎ | 324/390 [03:21<00:39,  1.67it/s, loss=0.884, accuracy=0.664][A
Epoch 3:  83%|████████▎ | 325/390 [03:21<00:39,  1.67it/s, loss=0.884, accuracy=0.664][A
Epoch 3:  83%|████████▎ | 325/390 [03:21<00:39,  1.67it/s, loss=1.11, accuracy=0.602] [A
Epoch 3:  84%|████████▎ | 326/390 [03:22<00:38,  1.65it/s, loss=1.11, accuracy=0.602][A
Epoch 3:  84%|████████▎ | 326/390 [03:22<00:38,  1.65it/s, loss=0.761, accuracy=0.766][A
Epoch 3:  84%|████████▍ | 327/390 [03:23<00:38,  1.63it/s, loss=0.761, accuracy=0.766][A
Epoch 3:  84%|████████▍ | 327/390 [03:23<00:38,  1.63it/s, loss=0.888, accuracy=0.641][A
Epoch 3:  84%|████████▍ | 328/390 [03:23<00:38,  1.62it/s, loss=0.888, accuracy=0.641][A
Epoch 3:  84%|████████▍ | 328/390 [03:23<00:38,  1.62it/s, loss=1.05, accuracy=0.641] [A
Epoch 3:  84%|████████▍ | 329/390 [03:24<00:37,  1.61it/s, loss=1.05, accuracy=0.641][A
Epoch 3:  84

Epoch 4:   6%|▋         | 25/390 [00:15<03:41,  1.65it/s, loss=0.952, accuracy=0.617][A
Epoch 4:   6%|▋         | 25/390 [00:15<03:41,  1.65it/s, loss=0.73, accuracy=0.734] [A
Epoch 4:   7%|▋         | 26/390 [00:16<03:40,  1.65it/s, loss=0.73, accuracy=0.734][A
Epoch 4:   7%|▋         | 26/390 [00:16<03:40,  1.65it/s, loss=0.75, accuracy=0.719][A
Epoch 4:   7%|▋         | 27/390 [00:16<03:38,  1.66it/s, loss=0.75, accuracy=0.719][A
Epoch 4:   7%|▋         | 27/390 [00:16<03:38,  1.66it/s, loss=1.01, accuracy=0.688][A
Epoch 4:   7%|▋         | 28/390 [00:17<03:35,  1.68it/s, loss=1.01, accuracy=0.688][A
Epoch 4:   7%|▋         | 28/390 [00:17<03:35,  1.68it/s, loss=0.922, accuracy=0.641][A
Epoch 4:   7%|▋         | 29/390 [00:18<03:33,  1.69it/s, loss=0.922, accuracy=0.641][A
Epoch 4:   7%|▋         | 29/390 [00:18<03:33,  1.69it/s, loss=0.829, accuracy=0.703][A
Epoch 4:   8%|▊         | 30/390 [00:18<03:31,  1.70it/s, loss=0.829, accuracy=0.703][A
Epoch 4:   8%|▊         | 

Epoch 4:  30%|██▉       | 116/390 [01:11<02:47,  1.63it/s, loss=0.719, accuracy=0.727][A
Epoch 4:  30%|███       | 117/390 [01:12<02:48,  1.62it/s, loss=0.719, accuracy=0.727][A
Epoch 4:  30%|███       | 117/390 [01:12<02:48,  1.62it/s, loss=0.895, accuracy=0.695][A
Epoch 4:  30%|███       | 118/390 [01:13<02:48,  1.61it/s, loss=0.895, accuracy=0.695][A
Epoch 4:  30%|███       | 118/390 [01:13<02:48,  1.61it/s, loss=0.803, accuracy=0.734][A
Epoch 4:  31%|███       | 119/390 [01:13<02:48,  1.61it/s, loss=0.803, accuracy=0.734][A
Epoch 4:  31%|███       | 119/390 [01:13<02:48,  1.61it/s, loss=0.912, accuracy=0.664][A
Epoch 4:  31%|███       | 120/390 [01:14<02:47,  1.61it/s, loss=0.912, accuracy=0.664][A
Epoch 4:  31%|███       | 120/390 [01:14<02:47,  1.61it/s, loss=0.782, accuracy=0.641][A
Epoch 4:  31%|███       | 121/390 [01:14<02:46,  1.61it/s, loss=0.782, accuracy=0.641][A
Epoch 4:  31%|███       | 121/390 [01:14<02:46,  1.61it/s, loss=0.783, accuracy=0.75] [A
Epoch 4:  

Epoch 4:  53%|█████▎    | 207/390 [02:07<01:51,  1.64it/s, loss=0.804, accuracy=0.711][A
Epoch 4:  53%|█████▎    | 208/390 [02:08<01:51,  1.63it/s, loss=0.804, accuracy=0.711][A
Epoch 4:  53%|█████▎    | 208/390 [02:08<01:51,  1.63it/s, loss=0.792, accuracy=0.727][A
Epoch 4:  54%|█████▎    | 209/390 [02:09<01:51,  1.62it/s, loss=0.792, accuracy=0.727][A
Epoch 4:  54%|█████▎    | 209/390 [02:09<01:51,  1.62it/s, loss=0.7, accuracy=0.758]  [A
Epoch 4:  54%|█████▍    | 210/390 [02:09<01:51,  1.61it/s, loss=0.7, accuracy=0.758][A
Epoch 4:  54%|█████▍    | 210/390 [02:09<01:51,  1.61it/s, loss=0.856, accuracy=0.648][A
Epoch 4:  54%|█████▍    | 211/390 [02:10<01:51,  1.61it/s, loss=0.856, accuracy=0.648][A
Epoch 4:  54%|█████▍    | 211/390 [02:10<01:51,  1.61it/s, loss=0.846, accuracy=0.711][A
Epoch 4:  54%|█████▍    | 212/390 [02:10<01:50,  1.61it/s, loss=0.846, accuracy=0.711][A
Epoch 4:  54%|█████▍    | 212/390 [02:10<01:50,  1.61it/s, loss=0.902, accuracy=0.633][A
Epoch 4:  55

Epoch 4:  76%|███████▋  | 298/390 [03:04<00:56,  1.63it/s, loss=0.792, accuracy=0.688][A
Epoch 4:  77%|███████▋  | 299/390 [03:04<00:56,  1.62it/s, loss=0.792, accuracy=0.688][A
Epoch 4:  77%|███████▋  | 299/390 [03:04<00:56,  1.62it/s, loss=0.711, accuracy=0.766][A
Epoch 4:  77%|███████▋  | 300/390 [03:05<00:55,  1.61it/s, loss=0.711, accuracy=0.766][A
Epoch 4:  77%|███████▋  | 300/390 [03:05<00:55,  1.61it/s, loss=0.854, accuracy=0.688][A
Epoch 4:  77%|███████▋  | 301/390 [03:05<00:55,  1.61it/s, loss=0.854, accuracy=0.688][A
Epoch 4:  77%|███████▋  | 301/390 [03:05<00:55,  1.61it/s, loss=0.641, accuracy=0.758][A
Epoch 4:  77%|███████▋  | 302/390 [03:06<00:54,  1.60it/s, loss=0.641, accuracy=0.758][A
Epoch 4:  77%|███████▋  | 302/390 [03:06<00:54,  1.60it/s, loss=0.592, accuracy=0.789][A
Epoch 4:  78%|███████▊  | 303/390 [03:07<00:54,  1.60it/s, loss=0.592, accuracy=0.789][A
Epoch 4:  78%|███████▊  | 303/390 [03:07<00:54,  1.60it/s, loss=0.667, accuracy=0.75] [A
Epoch 4:  

Epoch 4: 100%|█████████▉| 389/390 [04:00<00:00,  1.64it/s, loss=0.663, accuracy=0.781][A
Epoch 4: 100%|██████████| 390/390 [04:02<00:00,  1.27s/it, loss=0.663, accuracy=0.781][A
Epoch 4: 100%|██████████| 390/390 [04:02<00:00,  1.61it/s, loss=0.803, accuracy=0.709, val_loss=0.778, val_accuracy=0.732][A

Epoch 5:   0%|          | 0/390 [00:00<?, ?it/s][A
Epoch 5:   0%|          | 1/390 [00:01<08:12,  1.27s/it][A
Epoch 5:   0%|          | 1/390 [00:01<08:12,  1.27s/it, loss=0.552, accuracy=0.789][A
Epoch 5:   1%|          | 2/390 [00:01<06:51,  1.06s/it, loss=0.552, accuracy=0.789][A
Epoch 5:   1%|          | 2/390 [00:01<06:51,  1.06s/it, loss=0.743, accuracy=0.734][A
Epoch 5:   1%|          | 3/390 [00:02<05:56,  1.09it/s, loss=0.743, accuracy=0.734][A
Epoch 5:   1%|          | 3/390 [00:02<05:56,  1.09it/s, loss=0.627, accuracy=0.758][A
Epoch 5:   1%|          | 4/390 [00:03<05:19,  1.21it/s, loss=0.627, accuracy=0.758][A
Epoch 5:   1%|          | 4/390 [00:03<05:19,  1.21it

Epoch 5:  23%|██▎       | 91/390 [00:56<03:05,  1.61it/s, loss=0.635, accuracy=0.773][A
Epoch 5:  23%|██▎       | 91/390 [00:56<03:05,  1.61it/s, loss=0.581, accuracy=0.758][A
Epoch 5:  24%|██▎       | 92/390 [00:57<03:04,  1.62it/s, loss=0.581, accuracy=0.758][A
Epoch 5:  24%|██▎       | 92/390 [00:57<03:04,  1.62it/s, loss=0.649, accuracy=0.797][A
Epoch 5:  24%|██▍       | 93/390 [00:58<03:03,  1.62it/s, loss=0.649, accuracy=0.797][A
Epoch 5:  24%|██▍       | 93/390 [00:58<03:03,  1.62it/s, loss=0.818, accuracy=0.719][A
Epoch 5:  24%|██▍       | 94/390 [00:58<03:03,  1.61it/s, loss=0.818, accuracy=0.719][A
Epoch 5:  24%|██▍       | 94/390 [00:58<03:03,  1.61it/s, loss=0.672, accuracy=0.703][A
Epoch 5:  24%|██▍       | 95/390 [00:59<03:03,  1.61it/s, loss=0.672, accuracy=0.703][A
Epoch 5:  24%|██▍       | 95/390 [00:59<03:03,  1.61it/s, loss=0.713, accuracy=0.734][A
Epoch 5:  25%|██▍       | 96/390 [01:00<03:03,  1.60it/s, loss=0.713, accuracy=0.734][A
Epoch 5:  25%|██▍    

Epoch 5:  47%|████▋     | 182/390 [01:52<02:04,  1.67it/s, loss=0.754, accuracy=0.742][A
Epoch 5:  47%|████▋     | 182/390 [01:52<02:04,  1.67it/s, loss=0.701, accuracy=0.734][A
Epoch 5:  47%|████▋     | 183/390 [01:53<02:03,  1.68it/s, loss=0.701, accuracy=0.734][A
Epoch 5:  47%|████▋     | 183/390 [01:53<02:03,  1.68it/s, loss=0.559, accuracy=0.844][A
Epoch 5:  47%|████▋     | 184/390 [01:54<02:03,  1.67it/s, loss=0.559, accuracy=0.844][A
Epoch 5:  47%|████▋     | 184/390 [01:54<02:03,  1.67it/s, loss=0.817, accuracy=0.695][A
Epoch 5:  47%|████▋     | 185/390 [01:54<02:03,  1.66it/s, loss=0.817, accuracy=0.695][A
Epoch 5:  47%|████▋     | 185/390 [01:54<02:03,  1.66it/s, loss=0.713, accuracy=0.758][A
Epoch 5:  48%|████▊     | 186/390 [01:55<02:03,  1.65it/s, loss=0.713, accuracy=0.758][A
Epoch 5:  48%|████▊     | 186/390 [01:55<02:03,  1.65it/s, loss=0.542, accuracy=0.844][A
Epoch 5:  48%|████▊     | 187/390 [01:55<02:04,  1.63it/s, loss=0.542, accuracy=0.844][A
Epoch 5:  

Epoch 5:  70%|███████   | 273/390 [02:48<01:13,  1.60it/s, loss=0.665, accuracy=0.797][A
Epoch 5:  70%|███████   | 273/390 [02:48<01:13,  1.60it/s, loss=0.584, accuracy=0.781][A
Epoch 5:  70%|███████   | 274/390 [02:49<01:12,  1.60it/s, loss=0.584, accuracy=0.781][A
Epoch 5:  70%|███████   | 274/390 [02:49<01:12,  1.60it/s, loss=0.702, accuracy=0.742][A
Epoch 5:  71%|███████   | 275/390 [02:50<01:12,  1.60it/s, loss=0.702, accuracy=0.742][A
Epoch 5:  71%|███████   | 275/390 [02:50<01:12,  1.60it/s, loss=0.671, accuracy=0.766][A
Epoch 5:  71%|███████   | 276/390 [02:50<01:11,  1.60it/s, loss=0.671, accuracy=0.766][A
Epoch 5:  71%|███████   | 276/390 [02:50<01:11,  1.60it/s, loss=0.571, accuracy=0.805][A
Epoch 5:  71%|███████   | 277/390 [02:51<01:09,  1.62it/s, loss=0.571, accuracy=0.805][A
Epoch 5:  71%|███████   | 277/390 [02:51<01:09,  1.62it/s, loss=0.62, accuracy=0.75]  [A
Epoch 5:  71%|███████▏  | 278/390 [02:51<01:08,  1.63it/s, loss=0.62, accuracy=0.75][A
Epoch 5:  71

Epoch 5:  93%|█████████▎| 364/390 [03:45<00:16,  1.62it/s, loss=0.627, accuracy=0.766][A
Epoch 5:  93%|█████████▎| 364/390 [03:45<00:16,  1.62it/s, loss=0.737, accuracy=0.734][A
Epoch 5:  94%|█████████▎| 365/390 [03:45<00:15,  1.62it/s, loss=0.737, accuracy=0.734][A
Epoch 5:  94%|█████████▎| 365/390 [03:45<00:15,  1.62it/s, loss=0.61, accuracy=0.773] [A
Epoch 5:  94%|█████████▍| 366/390 [03:46<00:14,  1.62it/s, loss=0.61, accuracy=0.773][A
Epoch 5:  94%|█████████▍| 366/390 [03:46<00:14,  1.62it/s, loss=0.499, accuracy=0.828][A
Epoch 5:  94%|█████████▍| 367/390 [03:46<00:14,  1.62it/s, loss=0.499, accuracy=0.828][A
Epoch 5:  94%|█████████▍| 367/390 [03:46<00:14,  1.62it/s, loss=0.617, accuracy=0.766][A
Epoch 5:  94%|█████████▍| 368/390 [03:47<00:13,  1.62it/s, loss=0.617, accuracy=0.766][A
Epoch 5:  94%|█████████▍| 368/390 [03:47<00:13,  1.62it/s, loss=0.652, accuracy=0.75] [A
Epoch 5:  95%|█████████▍| 369/390 [03:48<00:13,  1.61it/s, loss=0.652, accuracy=0.75][A
Epoch 5:  95

Finished.
final validation accuracy:





In [None]:
if TrainSATInf:
    ## initialize model
    model_SATInf = Net().to(DEVICE)
    ## train params
    #lr = 0.1
    lr = 0.05
    optimiser = optim.SGD(model_SATInf.parameters(), lr=lr)
    '''
    best model
    lr = 0.1
    epochs = 20
    k_satinf = 20
    step_satinf = 0.01
    eps_satinf = 0.2
    '''
    lr = 0.05
    epochs = 20
    k_satinf = 20
    step_satinf = 0.01
    eps_satinf = 0.2
    ## train model
    training_history_entropySmoothing = olympic.fit(
        model_SATInf,
        optimiser,
        nn.CrossEntropyLoss(),
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=adversarial_training_entropy,
        #update_fn_kwargs={'adversary': entropySmoothing, 'k': 10, 'step': 0.01, 'eps': 0.3, 'norm': 'inf', 'gamma':1e-5},
        #update_fn_kwargs={'adversary': entropySmoothing, 'k': 20, 'step': 0.015, 'eps': 0.3, 'norm': 'inf', 'gamma':1e-5},
        #update_fn_kwargs={'adversary': entropySmoothing, 'k': 20, 'step': 0.015, 'eps': 0.3, 'norm': 'inf', 'gamma':1e-5},
        update_fn_kwargs={'adversary': entropySmoothing, 'k': k_satinf, 'step': step_satinf, 'eps': eps_satinf, 'norm': 'inf', 'gamma':1e-5},
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5, factor=0.5, min_delta=0.005, monitor='val_accuracy')
        ]
    )
    ## verify validation
    print('final validation accuracy:')
    olympic.evaluate(model_SATInf, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    ## save model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_SATInf_ep'+str(epochs)+'_k_'+str(k_satinf)+'_step_'+str(step_satinf)+'_lr'+str(lr)+'.pt'
    torch.save(model_SATInf,modelname)



Epoch 1:   0%|          | 0/390 [00:00<?, ?it/s][A[A

Begin training...




Epoch 1:   0%|          | 1/390 [00:01<08:05,  1.25s/it][A[A

Epoch 1:   0%|          | 1/390 [00:01<08:05,  1.25s/it, loss=2.61, accuracy=0][A[A

Epoch 1:   1%|          | 2/390 [00:01<06:44,  1.04s/it, loss=2.61, accuracy=0][A[A

Epoch 1:   1%|          | 2/390 [00:01<06:44,  1.04s/it, loss=2.74, accuracy=0.0938][A[A

Epoch 1:   1%|          | 3/390 [00:02<05:47,  1.11it/s, loss=2.74, accuracy=0.0938][A[A

Epoch 1:   1%|          | 3/390 [00:02<05:47,  1.11it/s, loss=2.79, accuracy=0.0938][A[A

Epoch 1:   1%|          | 4/390 [00:02<05:07,  1.25it/s, loss=2.79, accuracy=0.0938][A[A

Epoch 1:   1%|          | 4/390 [00:02<05:07,  1.25it/s, loss=2.62, accuracy=0.109] [A[A

Epoch 1:   1%|▏         | 5/390 [00:03<04:40,  1.37it/s, loss=2.62, accuracy=0.109][A[A

Epoch 1:   1%|▏         | 5/390 [00:03<04:40,  1.37it/s, loss=2.56, accuracy=0.172][A[A

Epoch 1:   2%|▏         | 6/390 [00:04<04:20,  1.48it/s, loss=2.56, accuracy=0.172][A[A

Epoch 1:   2%|▏         | 6

Epoch 1:  23%|██▎       | 90/390 [00:51<02:51,  1.75it/s, loss=1.83, accuracy=0.266][A[A

Epoch 1:  23%|██▎       | 90/390 [00:51<02:51,  1.75it/s, loss=1.83, accuracy=0.32] [A[A

Epoch 1:  23%|██▎       | 91/390 [00:52<02:51,  1.75it/s, loss=1.83, accuracy=0.32][A[A

Epoch 1:  23%|██▎       | 91/390 [00:52<02:51,  1.75it/s, loss=1.73, accuracy=0.375][A[A

Epoch 1:  24%|██▎       | 92/390 [00:52<02:50,  1.74it/s, loss=1.73, accuracy=0.375][A[A

Epoch 1:  24%|██▎       | 92/390 [00:52<02:50,  1.74it/s, loss=1.59, accuracy=0.367][A[A

Epoch 1:  24%|██▍       | 93/390 [00:53<02:52,  1.73it/s, loss=1.59, accuracy=0.367][A[A

Epoch 1:  24%|██▍       | 93/390 [00:53<02:52,  1.73it/s, loss=1.91, accuracy=0.328][A[A

Epoch 1:  24%|██▍       | 94/390 [00:54<02:52,  1.72it/s, loss=1.91, accuracy=0.328][A[A

Epoch 1:  24%|██▍       | 94/390 [00:54<02:52,  1.72it/s, loss=1.75, accuracy=0.344][A[A

Epoch 1:  24%|██▍       | 95/390 [00:54<02:51,  1.72it/s, loss=1.75, accuracy=0.3

Epoch 1:  46%|████▌     | 178/390 [01:46<02:12,  1.60it/s, loss=1.66, accuracy=0.398][A[A

Epoch 1:  46%|████▌     | 178/390 [01:46<02:12,  1.60it/s, loss=1.75, accuracy=0.312][A[A

Epoch 1:  46%|████▌     | 179/390 [01:46<02:12,  1.60it/s, loss=1.75, accuracy=0.312][A[A

Epoch 1:  46%|████▌     | 179/390 [01:46<02:12,  1.60it/s, loss=1.76, accuracy=0.328][A[A

Epoch 1:  46%|████▌     | 180/390 [01:47<02:11,  1.60it/s, loss=1.76, accuracy=0.328][A[A

Epoch 1:  46%|████▌     | 180/390 [01:47<02:11,  1.60it/s, loss=1.71, accuracy=0.383][A[A

Epoch 1:  46%|████▋     | 181/390 [01:48<02:10,  1.60it/s, loss=1.71, accuracy=0.383][A[A

Epoch 1:  46%|████▋     | 181/390 [01:48<02:10,  1.60it/s, loss=1.6, accuracy=0.359] [A[A

Epoch 1:  47%|████▋     | 182/390 [01:48<02:10,  1.60it/s, loss=1.6, accuracy=0.359][A[A

Epoch 1:  47%|████▋     | 182/390 [01:48<02:10,  1.60it/s, loss=1.63, accuracy=0.391][A[A

Epoch 1:  47%|████▋     | 183/390 [01:49<02:09,  1.59it/s, loss=1.63, a

Epoch 1:  68%|██████▊   | 266/390 [02:40<01:17,  1.60it/s, loss=1.65, accuracy=0.383][A[A

Epoch 1:  68%|██████▊   | 266/390 [02:40<01:17,  1.60it/s, loss=1.68, accuracy=0.344][A[A

Epoch 1:  68%|██████▊   | 267/390 [02:41<01:15,  1.62it/s, loss=1.68, accuracy=0.344][A[A

Epoch 1:  68%|██████▊   | 267/390 [02:41<01:15,  1.62it/s, loss=1.57, accuracy=0.406][A[A

Epoch 1:  69%|██████▊   | 268/390 [02:42<01:14,  1.63it/s, loss=1.57, accuracy=0.406][A[A

Epoch 1:  69%|██████▊   | 268/390 [02:42<01:14,  1.63it/s, loss=1.46, accuracy=0.469][A[A

Epoch 1:  69%|██████▉   | 269/390 [02:42<01:13,  1.65it/s, loss=1.46, accuracy=0.469][A[A

Epoch 1:  69%|██████▉   | 269/390 [02:42<01:13,  1.65it/s, loss=1.49, accuracy=0.406][A[A

Epoch 1:  69%|██████▉   | 270/390 [02:43<01:12,  1.65it/s, loss=1.49, accuracy=0.406][A[A

Epoch 1:  69%|██████▉   | 270/390 [02:43<01:12,  1.65it/s, loss=1.65, accuracy=0.445][A[A

Epoch 1:  69%|██████▉   | 271/390 [02:43<01:11,  1.67it/s, loss=1.65, 

Epoch 1:  91%|█████████ | 354/390 [03:35<00:22,  1.61it/s, loss=1.64, accuracy=0.469][A[A

Epoch 1:  91%|█████████ | 354/390 [03:35<00:22,  1.61it/s, loss=1.68, accuracy=0.383][A[A

Epoch 1:  91%|█████████ | 355/390 [03:36<00:21,  1.60it/s, loss=1.68, accuracy=0.383][A[A

Epoch 1:  91%|█████████ | 355/390 [03:36<00:21,  1.60it/s, loss=1.32, accuracy=0.492][A[A

Epoch 1:  91%|█████████▏| 356/390 [03:36<00:21,  1.60it/s, loss=1.32, accuracy=0.492][A[A

Epoch 1:  91%|█████████▏| 356/390 [03:36<00:21,  1.60it/s, loss=1.33, accuracy=0.516][A[A

Epoch 1:  92%|█████████▏| 357/390 [03:37<00:20,  1.61it/s, loss=1.33, accuracy=0.516][A[A

Epoch 1:  92%|█████████▏| 357/390 [03:37<00:20,  1.61it/s, loss=1.52, accuracy=0.461][A[A

Epoch 1:  92%|█████████▏| 358/390 [03:37<00:19,  1.62it/s, loss=1.52, accuracy=0.461][A[A

Epoch 1:  92%|█████████▏| 358/390 [03:37<00:19,  1.62it/s, loss=1.59, accuracy=0.484][A[A

Epoch 1:  92%|█████████▏| 359/390 [03:38<00:19,  1.63it/s, loss=1.59, 

Epoch 2:  13%|█▎        | 52/390 [00:32<03:30,  1.61it/s, loss=1.43, accuracy=0.5][A[A

Epoch 2:  13%|█▎        | 52/390 [00:32<03:30,  1.61it/s, loss=1.41, accuracy=0.453][A[A

Epoch 2:  14%|█▎        | 53/390 [00:33<03:30,  1.60it/s, loss=1.41, accuracy=0.453][A[A

Epoch 2:  14%|█▎        | 53/390 [00:33<03:30,  1.60it/s, loss=1.32, accuracy=0.555][A[A

Epoch 2:  14%|█▍        | 54/390 [00:34<03:29,  1.60it/s, loss=1.32, accuracy=0.555][A[A

Epoch 2:  14%|█▍        | 54/390 [00:34<03:29,  1.60it/s, loss=1.56, accuracy=0.414][A[A

Epoch 2:  14%|█▍        | 55/390 [00:34<03:28,  1.60it/s, loss=1.56, accuracy=0.414][A[A

Epoch 2:  14%|█▍        | 55/390 [00:34<03:28,  1.60it/s, loss=1.43, accuracy=0.453][A[A

Epoch 2:  14%|█▍        | 56/390 [00:35<03:28,  1.60it/s, loss=1.43, accuracy=0.453][A[A

Epoch 2:  14%|█▍        | 56/390 [00:35<03:28,  1.60it/s, loss=1.48, accuracy=0.438][A[A

Epoch 2:  15%|█▍        | 57/390 [00:35<03:28,  1.60it/s, loss=1.48, accuracy=0.43

Epoch 2:  36%|███▌      | 140/390 [01:27<02:35,  1.61it/s, loss=1.54, accuracy=0.445][A[A

Epoch 2:  36%|███▌      | 141/390 [01:27<02:34,  1.61it/s, loss=1.54, accuracy=0.445][A[A

Epoch 2:  36%|███▌      | 141/390 [01:27<02:34,  1.61it/s, loss=1.55, accuracy=0.406][A[A

Epoch 2:  36%|███▋      | 142/390 [01:28<02:34,  1.61it/s, loss=1.55, accuracy=0.406][A[A

Epoch 2:  36%|███▋      | 142/390 [01:28<02:34,  1.61it/s, loss=1.49, accuracy=0.422][A[A

Epoch 2:  37%|███▋      | 143/390 [01:29<02:34,  1.60it/s, loss=1.49, accuracy=0.422][A[A

Epoch 2:  37%|███▋      | 143/390 [01:29<02:34,  1.60it/s, loss=1.41, accuracy=0.438][A[A

Epoch 2:  37%|███▋      | 144/390 [01:29<02:33,  1.60it/s, loss=1.41, accuracy=0.438][A[A

Epoch 2:  37%|███▋      | 144/390 [01:29<02:33,  1.60it/s, loss=1.39, accuracy=0.547][A[A

Epoch 2:  37%|███▋      | 145/390 [01:30<02:33,  1.60it/s, loss=1.39, accuracy=0.547][A[A

Epoch 2:  37%|███▋      | 145/390 [01:30<02:33,  1.60it/s, loss=1.43, 

Epoch 2:  58%|█████▊    | 228/390 [02:21<01:41,  1.60it/s, loss=1.58, accuracy=0.445][A[A

Epoch 2:  59%|█████▊    | 229/390 [02:22<01:40,  1.60it/s, loss=1.58, accuracy=0.445][A[A

Epoch 2:  59%|█████▊    | 229/390 [02:22<01:40,  1.60it/s, loss=1.3, accuracy=0.531] [A[A

Epoch 2:  59%|█████▉    | 230/390 [02:23<01:39,  1.60it/s, loss=1.3, accuracy=0.531][A[A

Epoch 2:  59%|█████▉    | 230/390 [02:23<01:39,  1.60it/s, loss=1.29, accuracy=0.477][A[A

Epoch 2:  59%|█████▉    | 231/390 [02:23<01:39,  1.60it/s, loss=1.29, accuracy=0.477][A[A

Epoch 2:  59%|█████▉    | 231/390 [02:23<01:39,  1.60it/s, loss=1.32, accuracy=0.516][A[A

Epoch 2:  59%|█████▉    | 232/390 [02:24<01:38,  1.60it/s, loss=1.32, accuracy=0.516][A[A

Epoch 2:  59%|█████▉    | 232/390 [02:24<01:38,  1.60it/s, loss=1.41, accuracy=0.469][A[A

Epoch 2:  60%|█████▉    | 233/390 [02:25<01:38,  1.60it/s, loss=1.41, accuracy=0.469][A[A

Epoch 2:  60%|█████▉    | 233/390 [02:25<01:38,  1.60it/s, loss=1.27, a

Epoch 2:  81%|████████  | 316/390 [03:16<00:46,  1.60it/s, loss=1.49, accuracy=0.484][A[A

Epoch 2:  81%|████████▏ | 317/390 [03:17<00:45,  1.60it/s, loss=1.49, accuracy=0.484][A[A

Epoch 2:  81%|████████▏ | 317/390 [03:17<00:45,  1.60it/s, loss=1.32, accuracy=0.508][A[A

Epoch 2:  82%|████████▏ | 318/390 [03:17<00:45,  1.60it/s, loss=1.32, accuracy=0.508][A[A

Epoch 2:  82%|████████▏ | 318/390 [03:17<00:45,  1.60it/s, loss=1.22, accuracy=0.578][A[A

Epoch 2:  82%|████████▏ | 319/390 [03:18<00:44,  1.60it/s, loss=1.22, accuracy=0.578][A[A

Epoch 2:  82%|████████▏ | 319/390 [03:18<00:44,  1.60it/s, loss=1.42, accuracy=0.484][A[A

Epoch 2:  82%|████████▏ | 320/390 [03:19<00:43,  1.60it/s, loss=1.42, accuracy=0.484][A[A

Epoch 2:  82%|████████▏ | 320/390 [03:19<00:43,  1.60it/s, loss=1.49, accuracy=0.492][A[A

Epoch 2:  82%|████████▏ | 321/390 [03:19<00:43,  1.60it/s, loss=1.49, accuracy=0.492][A[A

Epoch 2:  82%|████████▏ | 321/390 [03:19<00:43,  1.60it/s, loss=1.43, 

Epoch 3:   4%|▎         | 14/390 [00:09<03:56,  1.59it/s, loss=1.39, accuracy=0.562][A[A

Epoch 3:   4%|▎         | 14/390 [00:09<03:56,  1.59it/s, loss=1.47, accuracy=0.5]  [A[A

Epoch 3:   4%|▍         | 15/390 [00:09<03:55,  1.59it/s, loss=1.47, accuracy=0.5][A[A

Epoch 3:   4%|▍         | 15/390 [00:09<03:55,  1.59it/s, loss=1.17, accuracy=0.617][A[A

Epoch 3:   4%|▍         | 16/390 [00:10<03:53,  1.60it/s, loss=1.17, accuracy=0.617][A[A

Epoch 3:   4%|▍         | 16/390 [00:10<03:53,  1.60it/s, loss=1.45, accuracy=0.43] [A[A

Epoch 3:   4%|▍         | 17/390 [00:11<03:50,  1.62it/s, loss=1.45, accuracy=0.43][A[A

Epoch 3:   4%|▍         | 17/390 [00:11<03:50,  1.62it/s, loss=1.16, accuracy=0.562][A[A

Epoch 3:   5%|▍         | 18/390 [00:11<03:47,  1.64it/s, loss=1.16, accuracy=0.562][A[A

Epoch 3:   5%|▍         | 18/390 [00:11<03:47,  1.64it/s, loss=1.23, accuracy=0.562][A[A

Epoch 3:   5%|▍         | 19/390 [00:12<03:44,  1.65it/s, loss=1.23, accuracy=0.562

Epoch 3:  26%|██▋       | 103/390 [01:04<02:57,  1.62it/s, loss=1.31, accuracy=0.57][A[A

Epoch 3:  26%|██▋       | 103/390 [01:04<02:57,  1.62it/s, loss=1.3, accuracy=0.516][A[A

Epoch 3:  27%|██▋       | 104/390 [01:04<02:55,  1.63it/s, loss=1.3, accuracy=0.516][A[A

Epoch 3:  27%|██▋       | 104/390 [01:04<02:55,  1.63it/s, loss=1.32, accuracy=0.523][A[A

Epoch 3:  27%|██▋       | 105/390 [01:05<02:53,  1.64it/s, loss=1.32, accuracy=0.523][A[A

Epoch 3:  27%|██▋       | 105/390 [01:05<02:53,  1.64it/s, loss=1.35, accuracy=0.5]  [A[A

Epoch 3:  27%|██▋       | 106/390 [01:06<02:52,  1.64it/s, loss=1.35, accuracy=0.5][A[A

Epoch 3:  27%|██▋       | 106/390 [01:06<02:52,  1.64it/s, loss=1.13, accuracy=0.57][A[A

Epoch 3:  27%|██▋       | 107/390 [01:06<02:51,  1.65it/s, loss=1.13, accuracy=0.57][A[A

Epoch 3:  27%|██▋       | 107/390 [01:06<02:51,  1.65it/s, loss=1.34, accuracy=0.539][A[A

Epoch 3:  28%|██▊       | 108/390 [01:07<02:50,  1.65it/s, loss=1.34, accurac

Epoch 3:  49%|████▉     | 191/390 [01:58<02:04,  1.60it/s, loss=1.31, accuracy=0.516][A[A

Epoch 3:  49%|████▉     | 191/390 [01:58<02:04,  1.60it/s, loss=1.2, accuracy=0.531] [A[A

Epoch 3:  49%|████▉     | 192/390 [01:59<02:04,  1.60it/s, loss=1.2, accuracy=0.531][A[A

Epoch 3:  49%|████▉     | 192/390 [01:59<02:04,  1.60it/s, loss=1.13, accuracy=0.586][A[A

Epoch 3:  49%|████▉     | 193/390 [02:00<02:03,  1.60it/s, loss=1.13, accuracy=0.586][A[A

Epoch 3:  49%|████▉     | 193/390 [02:00<02:03,  1.60it/s, loss=1.09, accuracy=0.625][A[A

Epoch 3:  50%|████▉     | 194/390 [02:00<02:01,  1.61it/s, loss=1.09, accuracy=0.625][A[A

Epoch 3:  50%|████▉     | 194/390 [02:00<02:01,  1.61it/s, loss=1.28, accuracy=0.594][A[A

Epoch 3:  50%|█████     | 195/390 [02:01<02:00,  1.61it/s, loss=1.28, accuracy=0.594][A[A

Epoch 3:  50%|█████     | 195/390 [02:01<02:00,  1.61it/s, loss=1.34, accuracy=0.477][A[A

Epoch 3:  50%|█████     | 196/390 [02:02<02:00,  1.61it/s, loss=1.34, a

Epoch 3:  72%|███████▏  | 279/390 [02:53<01:09,  1.60it/s, loss=1.15, accuracy=0.562][A[A

Epoch 3:  72%|███████▏  | 279/390 [02:53<01:09,  1.60it/s, loss=1.25, accuracy=0.547][A[A

Epoch 3:  72%|███████▏  | 280/390 [02:54<01:08,  1.61it/s, loss=1.25, accuracy=0.547][A[A

Epoch 3:  72%|███████▏  | 280/390 [02:54<01:08,  1.61it/s, loss=1.2, accuracy=0.617] [A[A

Epoch 3:  72%|███████▏  | 281/390 [02:54<01:08,  1.60it/s, loss=1.2, accuracy=0.617][A[A

Epoch 3:  72%|███████▏  | 281/390 [02:54<01:08,  1.60it/s, loss=1.11, accuracy=0.555][A[A

Epoch 3:  72%|███████▏  | 282/390 [02:55<01:07,  1.60it/s, loss=1.11, accuracy=0.555][A[A

Epoch 3:  72%|███████▏  | 282/390 [02:55<01:07,  1.60it/s, loss=1.19, accuracy=0.594][A[A

Epoch 3:  73%|███████▎  | 283/390 [02:55<01:06,  1.60it/s, loss=1.19, accuracy=0.594][A[A

Epoch 3:  73%|███████▎  | 283/390 [02:55<01:06,  1.60it/s, loss=1.25, accuracy=0.555][A[A

Epoch 3:  73%|███████▎  | 284/390 [02:56<01:06,  1.60it/s, loss=1.25, a

Epoch 3:  94%|█████████▍| 367/390 [03:48<00:14,  1.60it/s, loss=1.02, accuracy=0.609][A[A

Epoch 3:  94%|█████████▍| 367/390 [03:48<00:14,  1.60it/s, loss=1.23, accuracy=0.594][A[A

Epoch 3:  94%|█████████▍| 368/390 [03:48<00:13,  1.59it/s, loss=1.23, accuracy=0.594][A[A

Epoch 3:  94%|█████████▍| 368/390 [03:48<00:13,  1.59it/s, loss=1.3, accuracy=0.57]  [A[A

Epoch 3:  95%|█████████▍| 369/390 [03:49<00:13,  1.59it/s, loss=1.3, accuracy=0.57][A[A

Epoch 3:  95%|█████████▍| 369/390 [03:49<00:13,  1.59it/s, loss=1.02, accuracy=0.625][A[A

Epoch 3:  95%|█████████▍| 370/390 [03:50<00:12,  1.59it/s, loss=1.02, accuracy=0.625][A[A

Epoch 3:  95%|█████████▍| 370/390 [03:50<00:12,  1.59it/s, loss=1.06, accuracy=0.625][A[A

Epoch 3:  95%|█████████▌| 371/390 [03:50<00:11,  1.59it/s, loss=1.06, accuracy=0.625][A[A

Epoch 3:  95%|█████████▌| 371/390 [03:50<00:11,  1.59it/s, loss=1.08, accuracy=0.625][A[A

Epoch 3:  95%|█████████▌| 372/390 [03:51<00:11,  1.59it/s, loss=1.08, ac

# TRAIN MODEL USING TRADES

In [30]:
args = {}
args['test_batch_size'] = 128
args['no_cuda'] = False
args['epsilon'] = 0.3
args['num_steps'] = 5
args['step_size'] = 0.01
args['random'] =True,
args['white_box_attack']=True
args['log_interval'] = 100
args['beta'] = 1.0
args['log_interval'] = 1

In [31]:
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        # calculate robust loss
        loss = trades_loss(model=model,
                           x_natural=data,
                           y=target,
                           optimizer=optimizer,
                           step_size=args['step_size'],
                           epsilon=args['epsilon'],
                           perturb_steps=args['num_steps'],
                           beta=args['beta'],
                           distance = 'l_2')
        

        #print('outloss pre step:',loss)
        loss.backward(retain_graph=True)
        
        optimizer.step()
        #print('outloss post step:',loss.item())

        # print progress
        if batch_idx % args['log_interval'] == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

In [32]:
def eval_train(model, device, train_loader):
    model.eval()
    train_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            train_loss += F.cross_entropy(output, target, size_average=False).item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    train_loss /= len(train_loader.dataset)
    print('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        train_loss, correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset)))
    training_accuracy = correct / len(train_loader.dataset)
    return train_loss, training_accuracy

In [33]:
def eval_test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, size_average=False).item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    test_accuracy = correct / len(test_loader.dataset)
    return test_loss, test_accuracy

In [34]:
if TrainTRADES:
    args['attack'] = 'l_2' #or 'l_inf'
    ## initialize model
    #model_TRADES = model_cnn().to(DEVICE)
    model_TRADES = Net().to(DEVICE)
    ## training params
    lr = 0.1
    optimizer = optim.SGD(model_TRADES.parameters(), lr=lr)
    epochs = 10
    ## train model

    for epoch in range(1, epochs + 1):
        # adjust learning rate for SGD
        #adjust_learning_rate(optimizer, epoch)

        # adversarial training
        train(args, model_TRADES, DEVICE, train_loader, optimizer, epoch)

        # evaluation on natural examples
        print('================================================================')
        eval_train(model_TRADES, DEVICE, train_loader)
        eval_test(model_TRADES, DEVICE, val_loader)
        print('================================================================')

    ## save model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_TRADES_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_TRADES,modelname)







Training: Average loss: 1.2212, Accuracy: 27988/50000 (56%)
Test: Average loss: 1.2542, Accuracy: 5517/10000 (55%)




Training: Average loss: 0.7253, Accuracy: 37379/50000 (75%)
Test: Average loss: 0.8181, Accuracy: 7056/10000 (71%)






Training: Average loss: 0.5824, Accuracy: 39685/50000 (79%)
Test: Average loss: 0.7569, Accuracy: 7301/10000 (73%)




Training: Average loss: 0.5419, Accuracy: 40599/50000 (81%)
Test: Average loss: 0.8326, Accuracy: 7183/10000 (72%)




Training: Average loss: 0.3923, Accuracy: 42992/50000 (86%)
Test: Average loss: 0.7940, Accuracy: 7400/10000 (74%)






Training: Average loss: 0.4638, Accuracy: 42207/50000 (84%)
Test: Average loss: 1.0209, Accuracy: 6991/10000 (70%)




Training: Average loss: 0.3718, Accuracy: 43821/50000 (88%)
Test: Average loss: 1.0820, Accuracy: 7177/10000 (72%)






Training: Average loss: 0.2510, Accuracy: 45631/50000 (91%)
Test: Average loss: 1.0003, Accuracy: 7333/10000 (73%)




Training: Average loss: 0.1324, Accuracy: 47622/50000 (95%)
Test: Average loss: 0.8499, Accuracy: 7614/10000 (76%)




Training: Average loss: 0.1680, Accuracy: 47051/50000 (94%)
Test: Average loss: 1.0444, Accuracy: 7456/10000 (75%)


In [38]:
if TrainTRADESInf:
    args['attack'] = 'l_inf' 
    ## initialize model
    model_TRADESInf = Net().to(DEVICE)
    ## training params
    lr = 0.1
    optimizer = optim.SGD(model_TRADESInf.parameters(), lr=lr)
    epochs = 10
    ## train model

    for epoch in range(1, epochs + 1):
        # adjust learning rate for SGD
        #adjust_learning_rate(optimizer, epoch)

        # adversarial training
        train(args, model_TRADESInf, DEVICE, train_loader, optimizer, epoch)

        # evaluation on natural examples
        print('================================================================')
        eval_train(model_TRADESInf, DEVICE, train_loader)
        eval_test(model_TRADESInf, DEVICE, val_loader)
        print('================================================================')

    ## save model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_TRADESInf_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_TRADESInf,modelname)





Training: Average loss: 1.1691, Accuracy: 28985/50000 (58%)
Test: Average loss: 1.2069, Accuracy: 5694/10000 (57%)




Training: Average loss: 0.7314, Accuracy: 37130/50000 (74%)
Test: Average loss: 0.8306, Accuracy: 7029/10000 (70%)






Training: Average loss: 0.6212, Accuracy: 38836/50000 (78%)
Test: Average loss: 0.7993, Accuracy: 7183/10000 (72%)




Training: Average loss: 0.4516, Accuracy: 41952/50000 (84%)
Test: Average loss: 0.7539, Accuracy: 7404/10000 (74%)




Training: Average loss: 0.3717, Accuracy: 43341/50000 (87%)
Test: Average loss: 0.8175, Accuracy: 7361/10000 (74%)






Training: Average loss: 0.2835, Accuracy: 45034/50000 (90%)
Test: Average loss: 0.8360, Accuracy: 7461/10000 (75%)




Training: Average loss: 0.2677, Accuracy: 45752/50000 (92%)
Test: Average loss: 0.9265, Accuracy: 7423/10000 (74%)






Training: Average loss: 0.1552, Accuracy: 47377/50000 (95%)
Test: Average loss: 0.8265, Accuracy: 7608/10000 (76%)




Training: Average loss: 0.1625, Accuracy: 47233/50000 (94%)
Test: Average loss: 0.9993, Accuracy: 7582/10000 (76%)




Training: Average loss: 0.1127, Accuracy: 47963/50000 (96%)
Test: Average loss: 0.9551, Accuracy: 7535/10000 (75%)


# TRAIN MODEL USING MART

In [39]:
#TBD

# TRAIN MODEL USING MMA

In [40]:
#TBD

# LOAD ALL PRE-TRAINED MODELS

In [41]:
TrainSGD = TrainSGD*True
TrainESGD = TrainESGD*True
TrainL2 = TrainL2*True
TrainLInf = TrainLInf*True
TrainSAT2 = TrainSAT2*True
TrainSATInf = TrainSATInf*True
TrainTRADES = TrainTRADES*True
TrainTRADESInf = TrainTRADESInf*True
TrainMART = TrainMART*False
TrainMMA = TrainMMA*False
TrainMMAInf = TrainMMAInf*False

In [42]:
TrainTRADESInf = True

In [44]:
# Load all the pre-trained models
if not TrainSGD:
    model_SGD = torch.load('../trainedmodels/'+dataset+'/SGD_ep10_lr0.1.pt').to(DEVICE)
if not TrainESGD:    
    model_ESGD = torch.load('../trainedmodels/'+dataset+'/ESGD_ep5_lr0.1.pt').to(DEVICE)
if not TrainL2:
    adv_model_linf = torch.load('../trainedmodels/'+dataset+'/AT2_ep2_lr0.1.pt').to(DEVICE)
if not TrainLInf:
    adv_model_l2 = torch.load('../trainedmodels/'+dataset+'/ATInf_ep2_lr0.1.pt').to(DEVICE)
if not TrainSAT2:
    model_SAT2 = torch.load('../trainedmodels/'+dataset+'/SAT2_ep2_lr0.1.pt').to(DEVICE)
if not TrainSATInf:
    model_SATInf = torch.load('../trainedmodels/'+dataset+'/SATInf_ep2_lr0.1.pt').to(DEVICE)
if not TrainTRADES:
    model_TRADES = torch.load('../trainedmodels/'+dataset+'/TRADES_ep2_lr0.1.pt').to(DEVICE)
if not TrainTRADESInf:
    model_TRADES = torch.load('../trainedmodels/'+dataset+'/TRADESInf_ep2_lr0.1.pt').to(DEVICE)    
'''
if not TrainMART:
    if dataset=='CIFAR10':
        model_MART = Net()
        if NetName == 'ResNet18':
            dics = torch.load('../trainedmodels/'+dataset+'/ResNet18_model_MART.pt')
        if NetName == 'WideResNet':
            dics = torch.load('../trainedmodels/'+dataset+'/rob_cifar_mart.pt')            
        model_MART.load_state_dict(dics)
        model_MART = model_MART.to(DEVICE)
if not TrainMMA:
    model_MMA = torch.load('../trainedmodels/'+dataset+'/'+dataset.lower()+'-L2-MMA-4.0-sd0/model_best.pt')
    model_MMAInf = torch.load('../trainedmodels/'+dataset+'/'+dataset.lower()+'-Linf-MMA-0.45-sd0/model_best.pt')
'''

"\nif not TrainMART:\n    if dataset=='CIFAR10':\n        model_MART = Net()\n        if NetName == 'ResNet18':\n            dics = torch.load('../trainedmodels/'+dataset+'/ResNet18_model_MART.pt')\n        if NetName == 'WideResNet':\n            dics = torch.load('../trainedmodels/'+dataset+'/rob_cifar_mart.pt')            \n        model_MART.load_state_dict(dics)\n        model_MART = model_MART.to(DEVICE)\nif not TrainMMA:\n    model_MMA = torch.load('../trainedmodels/'+dataset+'/'+dataset.lower()+'-L2-MMA-4.0-sd0/model_best.pt')\n    model_MMAInf = torch.load('../trainedmodels/'+dataset+'/'+dataset.lower()+'-Linf-MMA-0.45-sd0/model_best.pt')\n"

## Quantifying adversarial accuracy

In [45]:
def infnorm(x):
    infn = torch.max(torch.abs(x.detach().cpu()))
    return infn

In [46]:
def evaluate_against_adversary(model, k, eps, step, norm):
    total = 0
    acc = 0
    for x, y in val_loader:
        total += x.size(0)
        
        if norm == 2:
            x_adv = pgd(
                model, x.to(DEVICE), y.to(DEVICE), torch.nn.CrossEntropyLoss(), k=k, step=step, eps=eps, norm=2)
            print('rel. l2-norm of x_adv-x:',torch.norm(x_adv.detach().cpu()-x)/np.sqrt(x.size(0)))#/torch.norm(x))
        elif norm == 'inf':
            x_adv = iterated_fgsm(
                model, x.to(DEVICE), y.to(DEVICE), torch.nn.CrossEntropyLoss(), k=k, step=step, eps=eps, norm='inf')
            print('rel. linf-norm of x_adv-x:',infnorm(x_adv.detach().cpu()-x)/infnorm(x))
        y_pred = model(x_adv)

        acc += olympic.metrics.accuracy(y.to(DEVICE), y_pred) * x.size(0)

    return acc/total

## Evaluate robust models

In [47]:
loadResults = False

In [48]:
t1 = time.time()
if not loadResults:
    pgd_attack_range = [3.0] #np.arange(0, 6.1, 1./3)
    acc_SGD = []
    acc_ESGD = []
    acc_l2 = []
    acc_SAT2 = []
    acc_TRADES = []
    acc_MMA = []
    for eps in pgd_attack_range:
        print('eps:',eps)
        print('evaluating SGD network...')
        acc_SGD.append(evaluate_against_adversary(model_SGD, k=20, eps=eps, step=0.5, norm=2))
        print('evaluating ESGD network...')
        acc_ESGD.append(evaluate_against_adversary(model_ESGD, k=20, eps=eps, step=0.5, norm=2))
        print('evaluating SAT2 network...')
        acc_SAT2.append(evaluate_against_adversary(model_SAT2, k=30, eps=eps, step=0.3, norm=2))
        print('evaluating l2 network...')
        acc_l2.append(evaluate_against_adversary(adv_model_l2, k=30, eps=eps, step=1.5, norm=2))
        print('evaluating TRADES network...')
        acc_TRADES.append(evaluate_against_adversary(model_TRADES, k=30, eps=eps, step=1.5, norm=2))  
        #print('evaluating MART network...')
        #acc_MMA.append(evaluate_against_adversary(model_MMA, k=30, eps=eps, step=1.5, norm=2))          
print("time elapsed:",time.time()-t1)    

eps: 3.0
evaluating SGD network...
rel. l2-norm of x_adv-x: tensor(2.9203)
rel. l2-norm of x_adv-x: tensor(2.9299)
rel. l2-norm of x_adv-x: tensor(2.9110)
rel. l2-norm of x_adv-x: tensor(2.9079)
rel. l2-norm of x_adv-x: tensor(2.9165)
rel. l2-norm of x_adv-x: tensor(2.9116)
rel. l2-norm of x_adv-x: tensor(2.8897)
rel. l2-norm of x_adv-x: tensor(2.9204)
rel. l2-norm of x_adv-x: tensor(2.9440)
rel. l2-norm of x_adv-x: tensor(2.9237)
rel. l2-norm of x_adv-x: tensor(2.9232)
rel. l2-norm of x_adv-x: tensor(2.9266)
rel. l2-norm of x_adv-x: tensor(2.9064)
rel. l2-norm of x_adv-x: tensor(2.9315)
rel. l2-norm of x_adv-x: tensor(2.9236)
rel. l2-norm of x_adv-x: tensor(2.8902)
rel. l2-norm of x_adv-x: tensor(2.9205)
rel. l2-norm of x_adv-x: tensor(2.9212)
rel. l2-norm of x_adv-x: tensor(2.9150)
rel. l2-norm of x_adv-x: tensor(2.9059)
rel. l2-norm of x_adv-x: tensor(2.8753)
rel. l2-norm of x_adv-x: tensor(2.9225)
rel. l2-norm of x_adv-x: tensor(2.9158)
rel. l2-norm of x_adv-x: tensor(2.9100)
rel. 

rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)
rel. l2-norm of x_adv-x: tensor(2.9999)


In [49]:
if not loadResults:    
    accData = [acc_SGD,acc_ESGD,acc_l2,acc_SAT2,acc_TRADES,acc_MMA]
    np.save('../results/accData_l2.npy',accData)

In [None]:
if loadResults:
    pgd_attack_range = [3.0] #np.arange(0.0, 6.1, 1./3)
    accData2 = np.load('../results/accData_l2.npy')
    [acc_SGD,acc_ESGD,acc_l2,acc_SAT2,acc_TRADES,acc_MMA] = accData2

In [51]:
with plt.style.context('ggplot'):
    COLOURS = plt.rcParams['axes.prop_cycle'].by_key()['color']
    #MARKERS = plt.rcParams['axes.prop_cycle'].by_key()['marker']

In [55]:
if len(pgd_attack_range) > 1:
    with plt.style.context('ggplot'):
        plt.rcParams["font.weight"] = "bold"
        plt.rcParams["axes.labelweight"] = "bold"
        plt.rcParams["axes.labelcolor"] = "black"
        fig, axes = plt.subplots(1, 1, figsize=(15,10))
        plt.figure(figsize=(14,7))
        axes.set_title('$L^2$-bounded adversary')
        axes.plot(pgd_attack_range, acc_SGD, label='SGD')
        axes.plot(pgd_attack_range, acc_ESGD, label='Entropy-SGD')
        axes.plot(pgd_attack_range, acc_SAT2, label='Data-Entropy-SGD ($L_2$)')
        axes.plot(pgd_attack_range, acc_l2, label='$L2$ training')
        axes.plot(pgd_attack_range, acc_TRADES, label='TRADES training')
        axes.plot(pgd_attack_range, acc_MMA, label='MMA training')

        axes.vlines([3], 0, 1, colors=COLOURS[1], linestyle='--')
        axes.set_ylabel('Accuracy')
        axes.set_xlabel('Epsilon')
        axes.set_ylim((0,1))
        axes.legend()
else:
    print(accData)

[[0.0], [0.0], [0.031149839743589744], [0.0], [0.00010016025641025641], []]


In [56]:
loadResultsFGSM = False

In [63]:
t1 = time.time()
if not loadResultsFGSM:
    fgsm_attack_range = [8/255] #np.arange(0.0, 0.52, 0.025)
    fgsm_acc_linf = []
    fgsm_acc_SGD = []
    fgsm_acc_ESGD = []
    fgsm_acc_SATInf = []
    fgsm_acc_TRADESInf = []
    fgsm_acc_MART = []
    fgsm_acc_MMAInf = []
    for eps in fgsm_attack_range:
        print('eps:',eps)
        print('evaluating SGD network...')
        fgsm_acc_SGD.append(evaluate_against_adversary(model_SGD, k=20, eps=eps, step=0.1, norm='inf'))
        print('evaluating ESGD network...')
        fgsm_acc_ESGD.append(evaluate_against_adversary(model_ESGD, k=20, eps=eps, step=0.1, norm='inf'))
        print('evaluating SAT2 network...')
        fgsm_acc_SATInf.append(evaluate_against_adversary(model_SATInf, k=10, eps=eps, step=0.1, norm='inf'))        
        print('evaluating linf network...')
        fgsm_acc_linf.append(evaluate_against_adversary(adv_model_linf, k=50, eps=eps, step=0.02, norm='inf'))
        print('evaluating TRADES network...')
        fgsm_acc_TRADESInf.append(evaluate_against_adversary(model_TRADESInf, k=50, eps=eps, step=0.02, norm='inf'))
        #print('evaluating MART network...')
        #fgsm_acc_MART.append(evaluate_against_adversary(model_MART, k=50, eps=eps, step=0.02, norm='inf'))
        #print('evaluating MMA network...')
        #fgsm_acc_MMAInf.append(evaluate_against_adversary(model_MMAInf, k=50, eps=eps, step=0.02, norm='inf'))
         
print("time elapsed:",time.time()-t1)    

eps: 0.03137254901960784
evaluating SGD network...
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x

rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(

rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
time elapsed: 632.8202047348022


In [64]:
if not loadResultsFGSM:    
    fgsmaccData = [fgsm_acc_SGD,fgsm_acc_ESGD,fgsm_acc_linf,fgsm_acc_SATInf, fgsm_acc_TRADESInf, fgsm_acc_MART, fgsm_acc_MMAInf]
    np.save('../results/fgsmaccData.npy',fgsmaccData)

In [None]:
if loadResultsFGSM:
    fgsm_attack_range = [8/255] #np.arange(0.0, 0.52, 0.025)
    accData2 = np.load('../results/fgsmaccData.npy')
    [fgsm_acc_SGD,fgsm_acc_ESGD,fgsm_acc_linf,fgsm_acc_SATInf, fgsm_acc_TRADESInf, fgsm_acc_MART, fgsm_acc_MMAInf] = accData2

In [82]:
if len(fgsm_attack_range) > 1:
    with plt.style.context('ggplot'):
        plt.rcParams["font.weight"] = "bold"
        plt.rcParams["axes.labelweight"] = "bold"
        plt.rcParams["axes.labelcolor"] = "black"
        fig, axes = plt.subplots(1, 1, figsize=(15,10))
        plt.figure(figsize=(14,7))
        axes.set_title('$L^\infty$-bounded adversary')
        axes.plot(fgsm_attack_range, fgsm_acc_SGD, label='SGD')
        axes.plot(fgsm_attack_range, fgsm_acc_ESGD, label='Entropy-SGD')
        axes.plot(fgsm_attack_range, fgsm_acc_SATInf, label='Data-Entropy-SGD ($L_\infty$)')    
        axes.plot(fgsm_attack_range, fgsm_acc_linf, label='$L{\infty}$ training')
        axes.plot(fgsm_attack_range, fgsm_acc_TRADES, label='TRADES training')
        axes.vlines([0.25], 0, 1, colors=COLOURS[1], linestyle='--')
        axes.set_ylabel('Accuracy')
        axes.set_xlabel('Epsilon')
        axes.set_ylim((0,1))
        axes.legend()
else:
    print(fgsmaccData)

[[0.0026041666666666665], [0.039162660256410256], [0.34385016025641024], [0.17377804487179488], [0.055789262820512824], [], []]


In [96]:
tempFGSMSATInf_test1 = []
tempFGSMSATInf_test1.append(evaluate_against_adversary(model_SATInf_test1, k=10, eps=eps, step=0.1, norm='inf'))

rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(0.0314)
rel. linf-norm of x_adv-x: tensor(

In [97]:
tempFGSMSATInf_test1

[0.06300080128205128]