Evaluate the models

# Stochastic Adversarial Training (StochAT)

## IMPORT LIBRARIES

In [1]:
import numpy as np
import pandas as pd
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms, datasets
from multiprocessing import cpu_count
from collections import OrderedDict
import matplotlib.pyplot as plt
import torch
import olympic

In [2]:
import sys
from typing import Union, Callable, Tuple
sys.path.append('../adversarial/')
sys.path.append('../architectures/')
from functional import boundary, iterated_fgsm, local_search, pgd, entropySmoothing
from ESGD_utils import *

In [3]:
import pickle
import time

In [4]:
import torch.backends.cudnn as cudnn
import argparse, math, random
import ESGD_optim

In [5]:
import time

In [6]:
from trades import trades_loss

In [7]:
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [8]:
DEVICE

'cuda'

# LOAD DATA

In [9]:
dataset = 'MNIST' # [MNIST, CIFAR10]
transform = transforms.Compose([
    transforms.ToTensor(),
])
bsz = 128
if dataset == 'MNIST':
    train = datasets.MNIST('../../data/MNIST', train=True, transform=transform, download=True)
    val = datasets.MNIST('../../data/MNIST', train=False, transform=transform, download=True)
elif dataset == 'CIFAR10':
    train = datasets.CIFAR10('../../data/CIFAR10', train=True, transform=transform, download=True)
    val = datasets.CIFAR10('../../data/CIFAR10', train=False, transform=transform, download=True)
    
train_loader = DataLoader(train, batch_size=128, num_workers=cpu_count(),drop_last=True)
val_loader = DataLoader(val, batch_size=128, num_workers=cpu_count(),drop_last=True)

# INITIALIZE NETWORK

In [10]:
if dataset=='MNIST':
    from net_mnist import Net 

In [11]:
if dataset=='CIFAR10':
    #[ResNet18,ResNet34,ResNet50,WideResNet]
    from resnet import ResNet18,ResNet34,ResNet50
    from wideresnet import WideResNet
    Net = ResNet18

# RANDOM SEED 

In [None]:
seed = 42
torch.set_num_threads(2)
if DEVICE=='cuda':
    torch.cuda.set_device(-1)
    torch.cuda.manual_seed(seed)
    cudnn.benchmark = True
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# LOAD PRETRAINED OR TRAIN NEW MODELS:

In [12]:
TrainSGD = True
TrainESGD = True
TrainL2 = True
TrainLInf = True
TrainSAT2 = True
TrainSATInf = True
TrainTRADES = True
TrainMART = False
TrainMMA = False

# TRAIN NAIVE MODEL USING SGD

In [None]:
if TrainSGD:
    ## initialize model
    model_SGD = Net().to(DEVICE)
    ## training params
    lr = 0.1
    optimiser = optim.SGD(model_SGD.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    epochs = 10
    ## train model
    history_natural = olympic.fit(
        model_SGD,
        optimiser,
        loss_fn,
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5)
        ]
    )
    ## verify validation accuracy
    print('final validation accuracy:')
    valscore = olympic.evaluate(model_SGD, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    print(valscore)
    ## save model
    modelname = '../trainedmodels/'+dataset+'/SGD_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_SGD,modelname)

# TRAIN MODEL USING ENTROPY SGD (ESGD)

In [None]:
def entropy_training(model, optimiser, loss_fn, x, y, epoch):
    
    model.train()
    y_pred = model(x)
    
    def helper():
        def feval():
            #x, y = Variable(x), Variable(y.squeeze())
            bsz = x.size(0)
            optimiser.zero_grad()
            yh = model(x)
            f = loss_fn.forward(yh, y)
            f.backward()

            yp = yh.argmax(axis=1)
            prec1 = 100*torch.sum(yp == y)//bsz
            err = 100.-prec1.item()

            return (f.data.item(), err)
        return feval

    loss, err = optimiser.step(helper(), model, loss_fn)
    loss = torch.tensor(loss)
    return loss, y_pred

In [None]:
if TrainESGD:
    ## initialize model
    model_ESGD = Net().to(DEVICE)
    ## training parameters
    lr = 0.1 
    l2 = 0.0 #l2 regularization
    L = 0    #langevin iterations
    gamma = 1e-4 
    scoping = 1e-3
    noise = 1e-4
    loss_fn = nn.CrossEntropyLoss()
    epochs = 5
    optimiser = ESGD_optim.EntropySGD(model_ESGD.parameters(),
            config = dict(lr=lr, momentum=0.9, nesterov=True, weight_decay=l2,
            L=L, eps=noise, g0=gamma, g1=scoping))
    ## train model
    history_natural = olympic.fit(
        model_ESGD,
        optimiser,
        loss_fn,
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=entropy_training,
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5)
        ]
    )
    ## verify validation accuracy
    print('final validation accuracy:')
    valacc = olympic.evaluate(model_ESGD, val_loader, metrics=['accuracy'],
                         prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    print(valacc['val_accuracy'])
    ## save trained model
    modelname = '../trainedmodels/'+dataset+'/ESGD_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_ESGD,modelname)

# TRAIN MODEL USING PGD SGD 

In [13]:
def infnorm(x):
    infn = torch.max(torch.abs(x.detach().cpu()))
    return infn

In [14]:
def adversarial_training(model, optimiser, loss_fn, x, y, epoch, adversary, k, step, eps, norm, random):
    """Performs a single update against a specified adversary"""
    model.train()

    # Adversial perturbation
    x_adv = adversary(model, x, y, loss_fn, k=k, step=step, eps=eps, norm=norm, random=True)
    #print('l2:',torch.norm(x_adv.detach().cpu()-x.detach().cpu())/np.sqrt(x.detach().cpu().size(0)))    
    #print('linf:',infnorm(x_adv.detach().cpu()-x.detach().cpu())/infnorm(x))    

    optimiser.zero_grad()
    y_pred = model(x_adv)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimiser.step()

    return loss, y_pred

## l2 ball

In [16]:
if TrainL2:
    ## initialize model
    adv_model_l2 = Net().to(DEVICE)
    lr = 0.01
    optimiser = optim.SGD(adv_model_l2.parameters(), lr=lr)
    epochs = 5
    ## train model
    training_history_l2 = olympic.fit(
        adv_model_l2,
        optimiser,
        nn.CrossEntropyLoss(),
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=adversarial_training,
        update_fn_kwargs={'adversary': pgd, 'k': 2, 'step': 0.0005, 'eps': 1.0, 'norm': 2, 'random':True},
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5, factor=0.5, min_delta=0.005, monitor='val_accuracy')
        ]
    )
    ## verify validation accuracy
    print('final validation accuracy:')
    valacc = olympic.evaluate(adv_model_l2, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    print(valacc['val_accuracy'])
    ## save trained model
    modelname = '../trainedmodels/'+dataset+'/AT2_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(adv_model_l2,modelname)

Epoch 1:   0%|          | 0/468 [00:00<?, ?it/s]

Begin training...


Epoch 1: 100%|██████████| 468/468 [00:03<00:00, 122.97it/s, loss=2.23, accuracy=0.184, val_loss=1.83, val_accuracy=0.63]
Epoch 2: 100%|██████████| 468/468 [00:03<00:00, 125.73it/s, loss=1.31, accuracy=0.568, val_loss=0.594, val_accuracy=0.818]
Epoch 3: 100%|██████████| 468/468 [00:03<00:00, 126.84it/s, loss=0.843, accuracy=0.729, val_loss=0.435, val_accuracy=0.875]
Epoch 4: 100%|██████████| 468/468 [00:03<00:00, 125.42it/s, loss=0.708, accuracy=0.778, val_loss=0.351, val_accuracy=0.899]
Epoch 5: 100%|██████████| 468/468 [00:03<00:00, 125.02it/s, loss=0.616, accuracy=0.807, val_loss=0.301, val_accuracy=0.912]

Finished.
final validation accuracy:





0.9121594551282052


## linf ball

In [18]:
if TrainLInf:
    ## initialize model
    adv_model_linf = Net().to(DEVICE)
    ## train params
    lr = 0.01
    optimiser = optim.SGD(adv_model_linf.parameters(), lr=lr)
    epochs = 5
    ## train model
    training_history_linf = olympic.fit(
        adv_model_linf,
        optimiser,
        nn.CrossEntropyLoss(),
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=adversarial_training,
        update_fn_kwargs={'adversary': iterated_fgsm,'k': 2, 'step': 0.0005, 'eps': 0.1, 'norm': 'inf', 'random':True},
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5, factor=0.5, min_delta=0.005, monitor='val_accuracy')
        ]
    )
    ## verify validation
    print('final validation accuracy:')
    valacc = olympic.evaluate(adv_model_linf, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    
    ## save trained model
    modelname = '../trainedmodels/'+dataset+'/ATInf_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(adv_model_linf,modelname)


Epoch 1:   0%|          | 0/468 [00:00<?, ?it/s][A

Begin training...



Epoch 1:   0%|          | 1/468 [00:00<01:27,  5.34it/s][A
Epoch 1:   0%|          | 1/468 [00:00<01:27,  5.34it/s, loss=2.31, accuracy=0.0703][A
Epoch 1:   0%|          | 2/468 [00:00<01:27,  5.34it/s, loss=2.3, accuracy=0.117]  [A
Epoch 1:   1%|          | 3/468 [00:00<01:27,  5.34it/s, loss=2.3, accuracy=0.125][A
Epoch 1:   1%|          | 4/468 [00:00<01:26,  5.34it/s, loss=2.31, accuracy=0.109][A
Epoch 1:   1%|          | 5/468 [00:00<01:26,  5.34it/s, loss=2.31, accuracy=0.0938][A
Epoch 1:   1%|▏         | 6/468 [00:00<01:26,  5.34it/s, loss=2.31, accuracy=0.0781][A
Epoch 1:   1%|▏         | 7/468 [00:00<01:26,  5.34it/s, loss=2.31, accuracy=0.0938][A
Epoch 1:   2%|▏         | 8/468 [00:00<01:26,  5.34it/s, loss=2.3, accuracy=0.102]  [A
Epoch 1:   2%|▏         | 9/468 [00:00<01:25,  5.34it/s, loss=2.3, accuracy=0.117][A
Epoch 1:   2%|▏         | 10/468 [00:00<01:25,  5.34it/s, loss=2.31, accuracy=0.0938][A
Epoch 1:   2%|▏         | 11/468 [00:00<01:25,  5.34it/s, loss=

Epoch 1:  19%|█▊        | 87/468 [00:00<00:11, 33.95it/s, loss=2.3, accuracy=0.109]  [A
Epoch 1:  19%|█▉        | 88/468 [00:00<00:08, 43.35it/s, loss=2.3, accuracy=0.109][A
Epoch 1:  19%|█▉        | 88/468 [00:00<00:08, 43.35it/s, loss=2.3, accuracy=0.102][A
Epoch 1:  19%|█▉        | 89/468 [00:00<00:08, 43.35it/s, loss=2.3, accuracy=0.148][A
Epoch 1:  19%|█▉        | 90/468 [00:00<00:08, 43.35it/s, loss=2.3, accuracy=0.0938][A
Epoch 1:  19%|█▉        | 91/468 [00:00<00:08, 43.35it/s, loss=2.29, accuracy=0.0781][A
Epoch 1:  20%|█▉        | 92/468 [00:00<00:08, 43.35it/s, loss=2.31, accuracy=0.0703][A
Epoch 1:  20%|█▉        | 93/468 [00:00<00:08, 43.35it/s, loss=2.3, accuracy=0.141]  [A
Epoch 1:  20%|██        | 94/468 [00:00<00:08, 43.35it/s, loss=2.3, accuracy=0.125][A
Epoch 1:  20%|██        | 95/468 [00:00<00:08, 43.35it/s, loss=2.3, accuracy=0.148][A
Epoch 1:  21%|██        | 96/468 [00:00<00:08, 43.35it/s, loss=2.3, accuracy=0.141][A
Epoch 1:  21%|██        | 97/468 [

Epoch 1:  37%|███▋      | 172/468 [00:01<00:02, 107.98it/s, loss=2.29, accuracy=0.172][A
Epoch 1:  37%|███▋      | 173/468 [00:01<00:02, 107.98it/s, loss=2.28, accuracy=0.172][A
Epoch 1:  37%|███▋      | 174/468 [00:01<00:02, 107.98it/s, loss=2.29, accuracy=0.164][A
Epoch 1:  37%|███▋      | 175/468 [00:01<00:02, 107.98it/s, loss=2.3, accuracy=0.102] [A
Epoch 1:  38%|███▊      | 176/468 [00:01<00:02, 107.98it/s, loss=2.29, accuracy=0.156][A
Epoch 1:  38%|███▊      | 177/468 [00:01<00:02, 107.98it/s, loss=2.3, accuracy=0.109] [A
Epoch 1:  38%|███▊      | 178/468 [00:01<00:02, 107.98it/s, loss=2.3, accuracy=0.141][A
Epoch 1:  38%|███▊      | 179/468 [00:01<00:02, 107.98it/s, loss=2.29, accuracy=0.188][A
Epoch 1:  38%|███▊      | 180/468 [00:01<00:02, 107.98it/s, loss=2.3, accuracy=0.117] [A
Epoch 1:  39%|███▊      | 181/468 [00:01<00:02, 107.98it/s, loss=2.3, accuracy=0.133][A
Epoch 1:  39%|███▉      | 182/468 [00:01<00:02, 107.98it/s, loss=2.3, accuracy=0.172][A
Epoch 1:  39%

Epoch 1:  55%|█████▍    | 257/468 [00:02<00:01, 128.53it/s, loss=2.29, accuracy=0.188][A
Epoch 1:  55%|█████▌    | 258/468 [00:02<00:01, 128.53it/s, loss=2.28, accuracy=0.156][A
Epoch 1:  55%|█████▌    | 259/468 [00:02<00:01, 128.53it/s, loss=2.3, accuracy=0.141] [A
Epoch 1:  56%|█████▌    | 260/468 [00:02<00:01, 128.53it/s, loss=2.28, accuracy=0.18][A
Epoch 1:  56%|█████▌    | 261/468 [00:02<00:01, 128.53it/s, loss=2.29, accuracy=0.156][A
Epoch 1:  56%|█████▌    | 262/468 [00:02<00:01, 128.53it/s, loss=2.28, accuracy=0.195][A
Epoch 1:  56%|█████▌    | 263/468 [00:02<00:01, 128.53it/s, loss=2.29, accuracy=0.18] [A
Epoch 1:  56%|█████▋    | 264/468 [00:02<00:01, 128.53it/s, loss=2.29, accuracy=0.148][A
Epoch 1:  57%|█████▋    | 265/468 [00:02<00:01, 128.53it/s, loss=2.3, accuracy=0.133] [A
Epoch 1:  57%|█████▋    | 266/468 [00:02<00:01, 128.53it/s, loss=2.28, accuracy=0.156][A
Epoch 1:  57%|█████▋    | 267/468 [00:02<00:01, 128.53it/s, loss=2.28, accuracy=0.117][A
Epoch 1:  5

Epoch 1:  73%|███████▎  | 342/468 [00:02<00:00, 128.38it/s, loss=2.28, accuracy=0.188][A
Epoch 1:  73%|███████▎  | 343/468 [00:02<00:00, 128.38it/s, loss=2.27, accuracy=0.211][A
Epoch 1:  74%|███████▎  | 344/468 [00:02<00:00, 128.38it/s, loss=2.26, accuracy=0.219][A
Epoch 1:  74%|███████▎  | 345/468 [00:02<00:00, 128.38it/s, loss=2.28, accuracy=0.258][A
Epoch 1:  74%|███████▍  | 346/468 [00:02<00:00, 128.38it/s, loss=2.29, accuracy=0.133][A
Epoch 1:  74%|███████▍  | 347/468 [00:02<00:00, 128.38it/s, loss=2.28, accuracy=0.156][A
Epoch 1:  74%|███████▍  | 348/468 [00:02<00:00, 128.38it/s, loss=2.29, accuracy=0.125][A
Epoch 1:  75%|███████▍  | 349/468 [00:02<00:00, 128.38it/s, loss=2.27, accuracy=0.172][A
Epoch 1:  75%|███████▍  | 350/468 [00:02<00:00, 128.38it/s, loss=2.27, accuracy=0.141][A
Epoch 1:  75%|███████▌  | 351/468 [00:02<00:00, 132.25it/s, loss=2.27, accuracy=0.141][A
Epoch 1:  75%|███████▌  | 351/468 [00:02<00:00, 132.25it/s, loss=2.27, accuracy=0.164][A
Epoch 1:  

Epoch 1:  91%|█████████ | 427/468 [00:03<00:00, 135.87it/s, loss=2.27, accuracy=0.188][A
Epoch 1:  91%|█████████▏| 428/468 [00:03<00:00, 135.87it/s, loss=2.25, accuracy=0.188][A
Epoch 1:  92%|█████████▏| 429/468 [00:03<00:00, 135.87it/s, loss=2.28, accuracy=0.188][A
Epoch 1:  92%|█████████▏| 430/468 [00:03<00:00, 135.87it/s, loss=2.27, accuracy=0.195][A
Epoch 1:  92%|█████████▏| 431/468 [00:03<00:00, 135.87it/s, loss=2.26, accuracy=0.211][A
Epoch 1:  92%|█████████▏| 432/468 [00:03<00:00, 135.87it/s, loss=2.26, accuracy=0.289][A
Epoch 1:  93%|█████████▎| 433/468 [00:03<00:00, 135.87it/s, loss=2.26, accuracy=0.18] [A
Epoch 1:  93%|█████████▎| 434/468 [00:03<00:00, 135.87it/s, loss=2.26, accuracy=0.164][A
Epoch 1:  93%|█████████▎| 435/468 [00:03<00:00, 135.87it/s, loss=2.25, accuracy=0.219][A
Epoch 1:  93%|█████████▎| 436/468 [00:03<00:00, 135.87it/s, loss=2.25, accuracy=0.242][A
Epoch 1:  93%|█████████▎| 437/468 [00:03<00:00, 136.84it/s, loss=2.25, accuracy=0.242][A
Epoch 1:  

Epoch 2:   9%|▉         | 43/468 [00:00<00:26, 16.15it/s, loss=2.25, accuracy=0.172][A
Epoch 2:   9%|▉         | 44/468 [00:00<00:26, 16.15it/s, loss=2.23, accuracy=0.242][A
Epoch 2:  10%|▉         | 45/468 [00:00<00:26, 16.15it/s, loss=2.24, accuracy=0.18] [A
Epoch 2:  10%|▉         | 46/468 [00:00<00:26, 16.15it/s, loss=2.25, accuracy=0.227][A
Epoch 2:  10%|█         | 47/468 [00:00<00:26, 16.15it/s, loss=2.21, accuracy=0.242][A
Epoch 2:  10%|█         | 48/468 [00:00<00:26, 16.15it/s, loss=2.23, accuracy=0.258][A
Epoch 2:  10%|█         | 49/468 [00:00<00:25, 16.15it/s, loss=2.23, accuracy=0.219][A
Epoch 2:  11%|█         | 50/468 [00:00<00:25, 16.15it/s, loss=2.21, accuracy=0.328][A
Epoch 2:  11%|█         | 51/468 [00:00<00:25, 16.15it/s, loss=2.24, accuracy=0.195][A
Epoch 2:  11%|█         | 52/468 [00:00<00:25, 16.15it/s, loss=2.2, accuracy=0.234] [A
Epoch 2:  11%|█▏        | 53/468 [00:00<00:25, 16.15it/s, loss=2.24, accuracy=0.211][A
Epoch 2:  12%|█▏        | 54/468

Epoch 2:  28%|██▊       | 129/468 [00:01<00:04, 71.07it/s, loss=2.17, accuracy=0.258][A
Epoch 2:  28%|██▊       | 130/468 [00:01<00:04, 71.07it/s, loss=2.17, accuracy=0.305][A
Epoch 2:  28%|██▊       | 131/468 [00:01<00:04, 71.07it/s, loss=2.2, accuracy=0.305] [A
Epoch 2:  28%|██▊       | 132/468 [00:01<00:04, 71.07it/s, loss=2.2, accuracy=0.188][A
Epoch 2:  28%|██▊       | 133/468 [00:01<00:04, 71.07it/s, loss=2.19, accuracy=0.328][A
Epoch 2:  29%|██▊       | 134/468 [00:01<00:04, 71.07it/s, loss=2.17, accuracy=0.289][A
Epoch 2:  29%|██▉       | 135/468 [00:01<00:04, 71.07it/s, loss=2.2, accuracy=0.227] [A
Epoch 2:  29%|██▉       | 136/468 [00:01<00:04, 82.00it/s, loss=2.2, accuracy=0.227][A
Epoch 2:  29%|██▉       | 136/468 [00:01<00:04, 82.00it/s, loss=2.21, accuracy=0.188][A
Epoch 2:  29%|██▉       | 137/468 [00:01<00:04, 82.00it/s, loss=2.15, accuracy=0.312][A
Epoch 2:  29%|██▉       | 138/468 [00:01<00:04, 82.00it/s, loss=2.19, accuracy=0.188][A
Epoch 2:  30%|██▉      

Epoch 2:  46%|████▌     | 214/468 [00:01<00:02, 114.67it/s, loss=2.04, accuracy=0.359][A
Epoch 2:  46%|████▌     | 215/468 [00:01<00:02, 116.87it/s, loss=2.04, accuracy=0.359][A
Epoch 2:  46%|████▌     | 215/468 [00:01<00:02, 116.87it/s, loss=1.99, accuracy=0.336][A
Epoch 2:  46%|████▌     | 216/468 [00:01<00:02, 116.87it/s, loss=2.02, accuracy=0.352][A
Epoch 2:  46%|████▋     | 217/468 [00:01<00:02, 116.87it/s, loss=2.04, accuracy=0.312][A
Epoch 2:  47%|████▋     | 218/468 [00:01<00:02, 116.87it/s, loss=2.04, accuracy=0.32] [A
Epoch 2:  47%|████▋     | 219/468 [00:01<00:02, 116.87it/s, loss=2, accuracy=0.359]  [A
Epoch 2:  47%|████▋     | 220/468 [00:01<00:02, 116.87it/s, loss=1.93, accuracy=0.383][A
Epoch 2:  47%|████▋     | 221/468 [00:01<00:02, 116.87it/s, loss=2, accuracy=0.391]   [A
Epoch 2:  47%|████▋     | 222/468 [00:01<00:02, 116.87it/s, loss=2.02, accuracy=0.32][A
Epoch 2:  48%|████▊     | 223/468 [00:01<00:02, 116.87it/s, loss=1.99, accuracy=0.383][A
Epoch 2:  48

Epoch 2:  64%|██████▎   | 298/468 [00:02<00:01, 127.10it/s, loss=1.75, accuracy=0.414][A
Epoch 2:  64%|██████▍   | 299/468 [00:02<00:01, 127.10it/s, loss=1.72, accuracy=0.484][A
Epoch 2:  64%|██████▍   | 300/468 [00:02<00:01, 127.10it/s, loss=1.9, accuracy=0.367] [A
Epoch 2:  64%|██████▍   | 301/468 [00:02<00:01, 127.10it/s, loss=1.72, accuracy=0.453][A
Epoch 2:  65%|██████▍   | 302/468 [00:02<00:01, 127.10it/s, loss=1.79, accuracy=0.398][A
Epoch 2:  65%|██████▍   | 303/468 [00:02<00:01, 127.10it/s, loss=1.68, accuracy=0.5]  [A
Epoch 2:  65%|██████▍   | 304/468 [00:02<00:01, 127.10it/s, loss=1.68, accuracy=0.469][A
Epoch 2:  65%|██████▌   | 305/468 [00:02<00:01, 127.10it/s, loss=1.6, accuracy=0.586] [A
Epoch 2:  65%|██████▌   | 306/468 [00:02<00:01, 127.10it/s, loss=1.73, accuracy=0.438][A
Epoch 2:  66%|██████▌   | 307/468 [00:02<00:01, 127.10it/s, loss=1.67, accuracy=0.477][A
Epoch 2:  66%|██████▌   | 308/468 [00:02<00:01, 125.06it/s, loss=1.67, accuracy=0.477][A
Epoch 1:  

Epoch 2:  82%|████████▏ | 383/468 [00:03<00:00, 127.95it/s, loss=1.64, accuracy=0.477][A
Epoch 2:  82%|████████▏ | 384/468 [00:03<00:00, 127.95it/s, loss=1.63, accuracy=0.438][A
Epoch 2:  82%|████████▏ | 385/468 [00:03<00:00, 127.95it/s, loss=1.53, accuracy=0.516][A
Epoch 2:  82%|████████▏ | 386/468 [00:03<00:00, 127.95it/s, loss=1.37, accuracy=0.523][A
Epoch 2:  83%|████████▎ | 387/468 [00:03<00:00, 126.85it/s, loss=1.37, accuracy=0.523][A
Epoch 2:  83%|████████▎ | 387/468 [00:03<00:00, 126.85it/s, loss=1.46, accuracy=0.531][A
Epoch 2:  83%|████████▎ | 388/468 [00:03<00:00, 126.85it/s, loss=1.66, accuracy=0.422][A
Epoch 2:  83%|████████▎ | 389/468 [00:03<00:00, 126.85it/s, loss=1.52, accuracy=0.516][A
Epoch 2:  83%|████████▎ | 390/468 [00:03<00:00, 126.85it/s, loss=1.48, accuracy=0.508][A
Epoch 2:  84%|████████▎ | 391/468 [00:03<00:00, 126.85it/s, loss=1.49, accuracy=0.492][A
Epoch 2:  84%|████████▍ | 392/468 [00:03<00:00, 126.85it/s, loss=1.44, accuracy=0.453][A
Epoch 2:  

Epoch 2: 100%|██████████| 468/468 [00:04<00:00, 111.73it/s, loss=1.88, accuracy=0.373, val_loss=0.868, val_accuracy=0.806][A

Epoch 3:   0%|          | 0/468 [00:00<?, ?it/s][A
Epoch 3:   0%|          | 1/468 [00:00<01:26,  5.38it/s][A
Epoch 3:   0%|          | 1/468 [00:00<01:26,  5.38it/s, loss=1.14, accuracy=0.664][A
Epoch 3:   0%|          | 2/468 [00:00<01:26,  5.38it/s, loss=1.27, accuracy=0.648][A
Epoch 3:   1%|          | 3/468 [00:00<01:26,  5.38it/s, loss=1.18, accuracy=0.641][A
Epoch 3:   1%|          | 4/468 [00:00<01:26,  5.38it/s, loss=1.19, accuracy=0.617][A
Epoch 3:   1%|          | 5/468 [00:00<01:26,  5.38it/s, loss=1.3, accuracy=0.57]  [A
Epoch 3:   1%|▏         | 6/468 [00:00<01:25,  5.38it/s, loss=1.19, accuracy=0.609][A
Epoch 3:   1%|▏         | 7/468 [00:00<01:25,  5.38it/s, loss=1.33, accuracy=0.484][A
Epoch 3:   2%|▏         | 8/468 [00:00<01:25,  5.38it/s, loss=1.32, accuracy=0.555][A
Epoch 3:   2%|▏         | 9/468 [00:00<01:25,  5.38it/s, loss=1.4

Epoch 3:  18%|█▊        | 85/468 [00:00<00:11, 34.78it/s, loss=1.09, accuracy=0.602][A
Epoch 3:  18%|█▊        | 86/468 [00:00<00:10, 34.78it/s, loss=1.07, accuracy=0.688][A
Epoch 3:  19%|█▊        | 87/468 [00:00<00:10, 34.78it/s, loss=0.928, accuracy=0.734][A
Epoch 3:  19%|█▉        | 88/468 [00:00<00:10, 34.78it/s, loss=1.1, accuracy=0.594]  [A
Epoch 3:  19%|█▉        | 89/468 [00:00<00:10, 34.78it/s, loss=0.958, accuracy=0.727][A
Epoch 3:  19%|█▉        | 90/468 [00:00<00:10, 34.78it/s, loss=0.951, accuracy=0.695][A
Epoch 3:  19%|█▉        | 91/468 [00:00<00:10, 34.78it/s, loss=1.26, accuracy=0.594] [A
Epoch 3:  20%|█▉        | 92/468 [00:00<00:10, 34.78it/s, loss=1.24, accuracy=0.609][A
Epoch 3:  20%|█▉        | 93/468 [00:00<00:10, 34.78it/s, loss=1.06, accuracy=0.648][A
Epoch 3:  20%|██        | 94/468 [00:00<00:08, 44.65it/s, loss=1.06, accuracy=0.648][A
Epoch 3:  20%|██        | 94/468 [00:00<00:08, 44.65it/s, loss=1.12, accuracy=0.641][A
Epoch 3:  20%|██        | 9

Epoch 3:  36%|███▋      | 170/468 [00:01<00:03, 98.37it/s, loss=0.948, accuracy=0.68][A
Epoch 3:  37%|███▋      | 171/468 [00:01<00:03, 98.37it/s, loss=0.803, accuracy=0.727][A
Epoch 3:  37%|███▋      | 172/468 [00:01<00:03, 98.37it/s, loss=1, accuracy=0.672]    [A
Epoch 3:  37%|███▋      | 173/468 [00:01<00:02, 98.37it/s, loss=1.08, accuracy=0.625][A
Epoch 3:  37%|███▋      | 174/468 [00:01<00:02, 105.91it/s, loss=1.08, accuracy=0.625][A
Epoch 3:  37%|███▋      | 174/468 [00:01<00:02, 105.91it/s, loss=0.999, accuracy=0.617][A
Epoch 3:  37%|███▋      | 175/468 [00:01<00:02, 105.91it/s, loss=0.889, accuracy=0.711][A
Epoch 3:  38%|███▊      | 176/468 [00:01<00:02, 105.91it/s, loss=1.09, accuracy=0.602] [A
Epoch 3:  38%|███▊      | 177/468 [00:01<00:02, 105.91it/s, loss=1.21, accuracy=0.578][A
Epoch 3:  38%|███▊      | 178/468 [00:01<00:02, 105.91it/s, loss=0.977, accuracy=0.703][A
Epoch 3:  38%|███▊      | 179/468 [00:01<00:02, 105.91it/s, loss=0.836, accuracy=0.742][A
Epoch 3

Epoch 3:  54%|█████▍    | 254/468 [00:02<00:01, 123.60it/s, loss=1.18, accuracy=0.602][A
Epoch 3:  54%|█████▍    | 255/468 [00:02<00:01, 123.60it/s, loss=0.792, accuracy=0.734][A
Epoch 3:  55%|█████▍    | 256/468 [00:02<00:01, 123.45it/s, loss=0.792, accuracy=0.734][A
Epoch 3:  55%|█████▍    | 256/468 [00:02<00:01, 123.45it/s, loss=0.828, accuracy=0.766][A
Epoch 3:  55%|█████▍    | 257/468 [00:02<00:01, 123.45it/s, loss=0.953, accuracy=0.758][A
Epoch 3:  55%|█████▌    | 258/468 [00:02<00:01, 123.45it/s, loss=0.822, accuracy=0.734][A
Epoch 3:  55%|█████▌    | 259/468 [00:02<00:01, 123.45it/s, loss=0.977, accuracy=0.672][A
Epoch 3:  56%|█████▌    | 260/468 [00:02<00:01, 123.45it/s, loss=0.827, accuracy=0.711][A
Epoch 3:  56%|█████▌    | 261/468 [00:02<00:01, 123.45it/s, loss=1.05, accuracy=0.648] [A
Epoch 3:  56%|█████▌    | 262/468 [00:02<00:01, 123.45it/s, loss=0.865, accuracy=0.719][A
Epoch 3:  56%|█████▌    | 263/468 [00:02<00:01, 123.45it/s, loss=0.851, accuracy=0.719][A


Epoch 3:  72%|███████▏  | 338/468 [00:02<00:01, 129.31it/s, loss=1.22, accuracy=0.547][A
Epoch 3:  72%|███████▏  | 338/468 [00:02<00:01, 129.31it/s, loss=0.812, accuracy=0.758][A
Epoch 3:  72%|███████▏  | 339/468 [00:02<00:00, 129.31it/s, loss=0.773, accuracy=0.797][A
Epoch 3:  73%|███████▎  | 340/468 [00:02<00:00, 129.31it/s, loss=0.61, accuracy=0.828] [A
Epoch 3:  73%|███████▎  | 341/468 [00:02<00:00, 129.31it/s, loss=0.731, accuracy=0.758][A
Epoch 3:  73%|███████▎  | 342/468 [00:02<00:00, 129.31it/s, loss=0.938, accuracy=0.711][A
Epoch 3:  73%|███████▎  | 343/468 [00:02<00:00, 129.31it/s, loss=0.916, accuracy=0.742][A
Epoch 3:  74%|███████▎  | 344/468 [00:02<00:00, 129.31it/s, loss=0.808, accuracy=0.719][A
Epoch 3:  74%|███████▎  | 345/468 [00:02<00:00, 129.31it/s, loss=0.88, accuracy=0.75]  [A
Epoch 3:  74%|███████▍  | 346/468 [00:02<00:00, 129.31it/s, loss=0.966, accuracy=0.672][A
Epoch 3:  74%|███████▍  | 347/468 [00:02<00:00, 129.31it/s, loss=0.909, accuracy=0.719][A


Epoch 3:  90%|████████▉ | 421/468 [00:03<00:00, 123.79it/s, loss=0.739, accuracy=0.758][A
Epoch 3:  90%|█████████ | 422/468 [00:03<00:00, 123.79it/s, loss=0.82, accuracy=0.68]  [A
Epoch 3:  90%|█████████ | 423/468 [00:03<00:00, 123.79it/s, loss=0.975, accuracy=0.648][A
Epoch 3:  91%|█████████ | 424/468 [00:03<00:00, 123.79it/s, loss=0.823, accuracy=0.742][A
Epoch 3:  91%|█████████ | 425/468 [00:03<00:00, 123.79it/s, loss=0.686, accuracy=0.82] [A
Epoch 3:  91%|█████████ | 426/468 [00:03<00:00, 123.79it/s, loss=0.689, accuracy=0.797][A
Epoch 3:  91%|█████████ | 427/468 [00:03<00:00, 123.79it/s, loss=0.798, accuracy=0.758][A
Epoch 3:  91%|█████████▏| 428/468 [00:03<00:00, 123.79it/s, loss=0.644, accuracy=0.812][A
Epoch 3:  92%|█████████▏| 429/468 [00:03<00:00, 123.79it/s, loss=0.855, accuracy=0.688][A
Epoch 3:  92%|█████████▏| 430/468 [00:03<00:00, 123.79it/s, loss=0.99, accuracy=0.633] [A
Epoch 3:  92%|█████████▏| 431/468 [00:03<00:00, 123.79it/s, loss=0.702, accuracy=0.805][A

Epoch 4:   8%|▊         | 37/468 [00:00<00:41, 10.40it/s, loss=0.7, accuracy=0.797]  [A
Epoch 4:   8%|▊         | 38/468 [00:00<00:41, 10.40it/s, loss=0.858, accuracy=0.727][A
Epoch 4:   8%|▊         | 39/468 [00:00<00:41, 10.40it/s, loss=0.704, accuracy=0.812][A
Epoch 4:   9%|▊         | 40/468 [00:00<00:41, 10.40it/s, loss=0.759, accuracy=0.773][A
Epoch 4:   9%|▉         | 41/468 [00:00<00:29, 14.36it/s, loss=0.759, accuracy=0.773][A
Epoch 4:   9%|▉         | 41/468 [00:00<00:29, 14.36it/s, loss=0.86, accuracy=0.758] [A
Epoch 4:   9%|▉         | 42/468 [00:00<00:29, 14.36it/s, loss=0.791, accuracy=0.758][A
Epoch 4:   9%|▉         | 43/468 [00:00<00:29, 14.36it/s, loss=0.648, accuracy=0.789][A
Epoch 4:   9%|▉         | 44/468 [00:00<00:29, 14.36it/s, loss=0.888, accuracy=0.656][A
Epoch 4:  10%|▉         | 45/468 [00:00<00:29, 14.36it/s, loss=0.741, accuracy=0.805][A
Epoch 4:  10%|▉         | 46/468 [00:00<00:29, 14.36it/s, loss=0.758, accuracy=0.766][A
Epoch 4:  10%|█      

Epoch 4:  26%|██▌       | 121/468 [00:01<00:05, 66.71it/s, loss=0.614, accuracy=0.82][A
Epoch 4:  26%|██▌       | 122/468 [00:01<00:05, 66.71it/s, loss=0.681, accuracy=0.812][A
Epoch 4:  26%|██▋       | 123/468 [00:01<00:05, 66.71it/s, loss=0.806, accuracy=0.734][A
Epoch 4:  26%|██▋       | 124/468 [00:01<00:05, 66.71it/s, loss=0.9, accuracy=0.703]  [A
Epoch 4:  27%|██▋       | 125/468 [00:01<00:05, 66.71it/s, loss=0.801, accuracy=0.75][A
Epoch 4:  27%|██▋       | 126/468 [00:01<00:05, 66.71it/s, loss=0.891, accuracy=0.742][A
Epoch 4:  27%|██▋       | 127/468 [00:01<00:05, 66.71it/s, loss=0.725, accuracy=0.766][A
Epoch 4:  27%|██▋       | 128/468 [00:01<00:05, 66.71it/s, loss=0.615, accuracy=0.812][A
Epoch 4:  28%|██▊       | 129/468 [00:01<00:05, 66.71it/s, loss=0.678, accuracy=0.781][A
Epoch 4:  28%|██▊       | 130/468 [00:01<00:05, 66.71it/s, loss=0.738, accuracy=0.812][A
Epoch 4:  28%|██▊       | 131/468 [00:01<00:05, 66.71it/s, loss=0.806, accuracy=0.766][A
Epoch 4:  28

Epoch 4:  44%|████▍     | 205/468 [00:01<00:02, 114.91it/s, loss=0.573, accuracy=0.82] [A
Epoch 4:  44%|████▍     | 206/468 [00:01<00:02, 114.91it/s, loss=0.713, accuracy=0.797][A
Epoch 4:  44%|████▍     | 207/468 [00:01<00:02, 114.91it/s, loss=0.951, accuracy=0.742][A
Epoch 4:  44%|████▍     | 208/468 [00:01<00:02, 114.91it/s, loss=0.9, accuracy=0.758]  [A
Epoch 4:  45%|████▍     | 209/468 [00:01<00:02, 114.91it/s, loss=0.956, accuracy=0.695][A
Epoch 4:  45%|████▍     | 210/468 [00:01<00:02, 114.91it/s, loss=0.677, accuracy=0.797][A
Epoch 4:  45%|████▌     | 211/468 [00:01<00:02, 114.91it/s, loss=0.707, accuracy=0.789][A
Epoch 4:  45%|████▌     | 212/468 [00:01<00:02, 114.91it/s, loss=0.662, accuracy=0.773][A
Epoch 4:  46%|████▌     | 213/468 [00:01<00:02, 118.37it/s, loss=0.662, accuracy=0.773][A
Epoch 4:  46%|████▌     | 213/468 [00:01<00:02, 118.37it/s, loss=0.84, accuracy=0.742] [A
Epoch 4:  46%|████▌     | 214/468 [00:01<00:02, 118.37it/s, loss=0.576, accuracy=0.836][A

Epoch 4:  62%|██████▏   | 289/468 [00:02<00:01, 130.16it/s, loss=0.6, accuracy=0.812]  [A
Epoch 4:  62%|██████▏   | 290/468 [00:02<00:01, 130.16it/s, loss=0.834, accuracy=0.734][A
Epoch 4:  62%|██████▏   | 291/468 [00:02<00:01, 130.16it/s, loss=0.607, accuracy=0.812][A
Epoch 4:  62%|██████▏   | 292/468 [00:02<00:01, 130.16it/s, loss=0.874, accuracy=0.758][A
Epoch 4:  63%|██████▎   | 293/468 [00:02<00:01, 130.16it/s, loss=0.978, accuracy=0.797][A
Epoch 4:  63%|██████▎   | 294/468 [00:02<00:01, 130.16it/s, loss=0.744, accuracy=0.719][A
Epoch 4:  63%|██████▎   | 295/468 [00:02<00:01, 130.16it/s, loss=0.676, accuracy=0.805][A
Epoch 4:  63%|██████▎   | 296/468 [00:02<00:01, 131.58it/s, loss=0.676, accuracy=0.805][A
Epoch 4:  63%|██████▎   | 296/468 [00:02<00:01, 131.58it/s, loss=0.774, accuracy=0.734][A
Epoch 4:  63%|██████▎   | 297/468 [00:02<00:01, 131.58it/s, loss=0.589, accuracy=0.859][A
Epoch 4:  64%|██████▎   | 298/468 [00:02<00:01, 131.58it/s, loss=0.589, accuracy=0.82] [A

Epoch 4:  80%|███████▉  | 373/468 [00:03<00:00, 129.18it/s, loss=0.624, accuracy=0.797][A
Epoch 4:  80%|███████▉  | 374/468 [00:03<00:00, 129.18it/s, loss=0.601, accuracy=0.844][A
Epoch 4:  80%|████████  | 375/468 [00:03<00:00, 129.18it/s, loss=0.905, accuracy=0.742][A
Epoch 4:  80%|████████  | 376/468 [00:03<00:00, 129.18it/s, loss=0.562, accuracy=0.852][A
Epoch 4:  81%|████████  | 377/468 [00:03<00:00, 129.18it/s, loss=0.516, accuracy=0.844][A
Epoch 4:  81%|████████  | 378/468 [00:03<00:00, 129.18it/s, loss=0.597, accuracy=0.789][A
Epoch 4:  81%|████████  | 379/468 [00:03<00:00, 126.54it/s, loss=0.597, accuracy=0.789][A
Epoch 4:  81%|████████  | 379/468 [00:03<00:00, 126.54it/s, loss=0.518, accuracy=0.812][A
Epoch 4:  81%|████████  | 380/468 [00:03<00:00, 126.54it/s, loss=0.558, accuracy=0.836][A
Epoch 4:  81%|████████▏ | 381/468 [00:03<00:00, 126.54it/s, loss=0.455, accuracy=0.891][A
Epoch 4:  82%|████████▏ | 382/468 [00:03<00:00, 126.54it/s, loss=0.426, accuracy=0.844][A

Epoch 4:  98%|█████████▊| 457/468 [00:03<00:00, 132.80it/s, loss=0.545, accuracy=0.812][A
Epoch 4:  98%|█████████▊| 458/468 [00:03<00:00, 132.80it/s, loss=0.439, accuracy=0.891][A
Epoch 4:  98%|█████████▊| 459/468 [00:03<00:00, 132.80it/s, loss=0.505, accuracy=0.859][A
Epoch 4:  98%|█████████▊| 460/468 [00:03<00:00, 132.80it/s, loss=0.617, accuracy=0.844][A
Epoch 4:  99%|█████████▊| 461/468 [00:03<00:00, 132.80it/s, loss=0.411, accuracy=0.883][A
Epoch 4:  99%|█████████▊| 462/468 [00:03<00:00, 132.80it/s, loss=0.301, accuracy=0.953][A
Epoch 4:  99%|█████████▉| 463/468 [00:03<00:00, 132.80it/s, loss=0.473, accuracy=0.836][A
Epoch 4:  99%|█████████▉| 464/468 [00:03<00:00, 137.46it/s, loss=0.473, accuracy=0.836][A
Epoch 4:  99%|█████████▉| 464/468 [00:03<00:00, 137.46it/s, loss=0.782, accuracy=0.727][A
Epoch 4:  99%|█████████▉| 465/468 [00:03<00:00, 137.46it/s, loss=0.445, accuracy=0.852][A
Epoch 4: 100%|█████████▉| 466/468 [00:03<00:00, 137.46it/s, loss=0.402, accuracy=0.891][A

Epoch 5:  16%|█▌        | 73/468 [00:00<00:14, 27.03it/s, loss=0.587, accuracy=0.859][A
Epoch 5:  16%|█▌        | 74/468 [00:00<00:14, 27.03it/s, loss=0.607, accuracy=0.836][A
Epoch 5:  16%|█▌        | 75/468 [00:00<00:14, 27.03it/s, loss=0.702, accuracy=0.797][A
Epoch 5:  16%|█▌        | 76/468 [00:00<00:14, 27.03it/s, loss=0.497, accuracy=0.844][A
Epoch 5:  16%|█▋        | 77/468 [00:00<00:14, 27.03it/s, loss=0.682, accuracy=0.844][A
Epoch 5:  17%|█▋        | 78/468 [00:00<00:14, 27.03it/s, loss=0.606, accuracy=0.82] [A
Epoch 5:  17%|█▋        | 79/468 [00:00<00:14, 27.03it/s, loss=0.627, accuracy=0.812][A
Epoch 5:  17%|█▋        | 80/468 [00:00<00:14, 27.03it/s, loss=0.598, accuracy=0.828][A
Epoch 5:  17%|█▋        | 81/468 [00:00<00:14, 27.03it/s, loss=0.64, accuracy=0.828] [A
Epoch 5:  18%|█▊        | 82/468 [00:00<00:14, 27.03it/s, loss=0.515, accuracy=0.844][A
Epoch 5:  18%|█▊        | 83/468 [00:00<00:14, 27.03it/s, loss=0.485, accuracy=0.82] [A
Epoch 5:  18%|█▊     

Epoch 5:  34%|███▍      | 158/468 [00:01<00:03, 90.21it/s, loss=0.751, accuracy=0.852][A
Epoch 5:  34%|███▍      | 159/468 [00:01<00:03, 90.21it/s, loss=0.643, accuracy=0.805][A
Epoch 5:  34%|███▍      | 160/468 [00:01<00:03, 90.21it/s, loss=0.522, accuracy=0.852][A
Epoch 5:  34%|███▍      | 161/468 [00:01<00:03, 90.21it/s, loss=0.644, accuracy=0.773][A
Epoch 5:  35%|███▍      | 162/468 [00:01<00:03, 90.21it/s, loss=0.607, accuracy=0.812][A
Epoch 5:  35%|███▍      | 163/468 [00:01<00:03, 90.21it/s, loss=0.74, accuracy=0.773] [A
Epoch 5:  35%|███▌      | 164/468 [00:01<00:03, 90.21it/s, loss=0.734, accuracy=0.789][A
Epoch 5:  35%|███▌      | 165/468 [00:01<00:03, 90.21it/s, loss=0.659, accuracy=0.812][A
Epoch 5:  35%|███▌      | 166/468 [00:01<00:03, 90.21it/s, loss=0.509, accuracy=0.859][A
Epoch 5:  36%|███▌      | 167/468 [00:01<00:03, 90.21it/s, loss=0.43, accuracy=0.828] [A
Epoch 5:  36%|███▌      | 168/468 [00:01<00:03, 90.21it/s, loss=0.501, accuracy=0.805][A
Epoch 5:  

Epoch 5:  52%|█████▏    | 242/468 [00:01<00:01, 127.70it/s, loss=0.688, accuracy=0.82][A
Epoch 5:  52%|█████▏    | 243/468 [00:01<00:01, 127.70it/s, loss=0.638, accuracy=0.805][A
Epoch 5:  52%|█████▏    | 244/468 [00:02<00:01, 127.70it/s, loss=0.721, accuracy=0.812][A
Epoch 5:  52%|█████▏    | 245/468 [00:02<00:01, 127.70it/s, loss=0.866, accuracy=0.703][A
Epoch 5:  53%|█████▎    | 246/468 [00:02<00:01, 127.70it/s, loss=0.7, accuracy=0.805]  [A
Epoch 5:  53%|█████▎    | 247/468 [00:02<00:01, 127.70it/s, loss=0.586, accuracy=0.805][A
Epoch 5:  53%|█████▎    | 248/468 [00:02<00:01, 127.70it/s, loss=0.804, accuracy=0.727][A
Epoch 5:  53%|█████▎    | 249/468 [00:02<00:01, 127.70it/s, loss=0.465, accuracy=0.875][A
Epoch 5:  53%|█████▎    | 250/468 [00:02<00:01, 127.70it/s, loss=0.561, accuracy=0.82] [A
Epoch 5:  54%|█████▎    | 251/468 [00:02<00:01, 127.70it/s, loss=0.581, accuracy=0.828][A
Epoch 5:  54%|█████▍    | 252/468 [00:02<00:01, 127.70it/s, loss=0.603, accuracy=0.797][A


Epoch 5:  70%|██████▉   | 326/468 [00:02<00:01, 133.68it/s, loss=0.516, accuracy=0.875][A
Epoch 5:  70%|██████▉   | 327/468 [00:02<00:01, 133.68it/s, loss=0.575, accuracy=0.844][A
Epoch 5:  70%|███████   | 328/468 [00:02<00:01, 133.68it/s, loss=0.636, accuracy=0.805][A
Epoch 5:  70%|███████   | 329/468 [00:02<00:01, 133.68it/s, loss=0.553, accuracy=0.844][A
Epoch 5:  71%|███████   | 330/468 [00:02<00:01, 133.68it/s, loss=0.783, accuracy=0.766][A
Epoch 5:  71%|███████   | 331/468 [00:02<00:01, 133.68it/s, loss=0.738, accuracy=0.719][A
Epoch 5:  71%|███████   | 332/468 [00:02<00:01, 133.68it/s, loss=0.702, accuracy=0.828][A
Epoch 5:  71%|███████   | 333/468 [00:02<00:01, 133.68it/s, loss=0.597, accuracy=0.82] [A
Epoch 5:  71%|███████▏  | 334/468 [00:02<00:01, 133.68it/s, loss=0.535, accuracy=0.867][A
Epoch 5:  72%|███████▏  | 335/468 [00:02<00:00, 133.68it/s, loss=0.713, accuracy=0.805][A
Epoch 5:  72%|███████▏  | 336/468 [00:02<00:00, 133.68it/s, loss=0.666, accuracy=0.812][A

Epoch 5:  88%|████████▊ | 410/468 [00:03<00:00, 134.75it/s, loss=0.554, accuracy=0.859][A
Epoch 5:  88%|████████▊ | 411/468 [00:03<00:00, 134.75it/s, loss=0.331, accuracy=0.922][A
Epoch 5:  88%|████████▊ | 412/468 [00:03<00:00, 134.75it/s, loss=0.47, accuracy=0.812] [A
Epoch 5:  88%|████████▊ | 413/468 [00:03<00:00, 134.75it/s, loss=0.692, accuracy=0.727][A
Epoch 5:  88%|████████▊ | 414/468 [00:03<00:00, 134.75it/s, loss=0.826, accuracy=0.758][A
Epoch 5:  89%|████████▊ | 415/468 [00:03<00:00, 134.75it/s, loss=0.589, accuracy=0.82] [A
Epoch 5:  89%|████████▉ | 416/468 [00:03<00:00, 134.75it/s, loss=0.727, accuracy=0.766][A
Epoch 5:  89%|████████▉ | 417/468 [00:03<00:00, 134.75it/s, loss=0.428, accuracy=0.875][A
Epoch 5:  89%|████████▉ | 418/468 [00:03<00:00, 134.75it/s, loss=0.449, accuracy=0.859][A
Epoch 5:  90%|████████▉ | 419/468 [00:03<00:00, 134.75it/s, loss=0.593, accuracy=0.812][A
Epoch 5:  90%|████████▉ | 420/468 [00:03<00:00, 134.75it/s, loss=0.518, accuracy=0.812][A

Finished.
final validation accuracy:





# TRAIN MODEL USING SAT

In [19]:
def adversarial_training_entropy(model, optimiser, loss_fn, x, y, epoch, adversary, k, step, eps, norm, gamma):
    """Performs a single update against a specified adversary"""
    model.train()
    
    # Adversial perturbation
    #alpha = 0.8
    N = 1
    loss = 0
    for l in range(N):
        x_adv = adversary(model, x, y, loss_fn, k=k, step=step, eps=eps, norm=norm, random=True, gamma=gamma)
        
        optimiser.zero_grad()
        y_pred = model(x_adv)
        loss = loss + loss_fn(y_pred,y)
        #loss = (1-alpha)*loss + alpha*loss_fn(y_pred, y)
    loss = loss/N
    loss.backward()
    optimiser.step()
    
    return loss, y_pred

In [24]:
if TrainSAT2:
    ## initialize model
    model_SAT2 = Net().to(DEVICE)
    ## train params
    lr = 0.01
    optimiser = optim.SGD(model_SAT2.parameters(), lr=lr)
    epochs = 5
    ## train model
    training_history_entropySmoothing = olympic.fit(
        model_SAT2,
        optimiser,
        nn.CrossEntropyLoss(),
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=adversarial_training_entropy,
        update_fn_kwargs={'adversary': entropySmoothing, 'k': 2, 'step': 0.0005, 'eps': 1.0, 'norm': 2, 'gamma':1e-5},
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5, factor=0.5, min_delta=0.005, monitor='val_accuracy')
        ]
    )
    ## verify validation
    print('final validation accuracy:')
    olympic.evaluate(model_SAT2, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    ## save model
    modelname = '../trainedmodels/'+dataset+'/SAT2_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_SAT2,modelname)


Epoch 1:   0%|          | 0/468 [00:00<?, ?it/s][A

Begin training...



Epoch 1:   0%|          | 1/468 [00:00<01:26,  5.39it/s][A
Epoch 1:   0%|          | 1/468 [00:00<01:26,  5.39it/s, loss=2.32, accuracy=0.0938][A
Epoch 1:   0%|          | 2/468 [00:00<01:26,  5.39it/s, loss=2.32, accuracy=0.0781][A
Epoch 1:   1%|          | 3/468 [00:00<01:26,  5.39it/s, loss=2.29, accuracy=0.148] [A
Epoch 1:   1%|          | 4/468 [00:00<01:26,  5.39it/s, loss=2.32, accuracy=0.0625][A
Epoch 1:   1%|          | 5/468 [00:00<01:25,  5.39it/s, loss=2.32, accuracy=0.0703][A
Epoch 1:   1%|▏         | 6/468 [00:00<01:25,  5.39it/s, loss=2.31, accuracy=0.109] [A
Epoch 1:   1%|▏         | 7/468 [00:00<01:25,  5.39it/s, loss=2.31, accuracy=0.0703][A
Epoch 1:   2%|▏         | 8/468 [00:00<01:25,  5.39it/s, loss=2.32, accuracy=0.109] [A
Epoch 1:   2%|▏         | 9/468 [00:00<01:25,  5.39it/s, loss=2.32, accuracy=0.0781][A
Epoch 1:   2%|▏         | 10/468 [00:00<01:24,  5.39it/s, loss=2.29, accuracy=0.18] [A
Epoch 1:   2%|▏         | 11/468 [00:00<01:00,  7.53it/s, l

Epoch 1:  18%|█▊        | 85/468 [00:00<00:09, 41.81it/s, loss=2.29, accuracy=0.125][A
Epoch 1:  18%|█▊        | 86/468 [00:00<00:09, 41.81it/s, loss=2.31, accuracy=0.117][A
Epoch 1:  19%|█▊        | 87/468 [00:01<00:09, 41.81it/s, loss=2.31, accuracy=0.109][A
Epoch 1:  19%|█▉        | 88/468 [00:01<00:07, 51.36it/s, loss=2.31, accuracy=0.109][A
Epoch 1:  19%|█▉        | 88/468 [00:01<00:07, 51.36it/s, loss=2.3, accuracy=0.125] [A
Epoch 1:  19%|█▉        | 89/468 [00:01<00:07, 51.36it/s, loss=2.29, accuracy=0.102][A
Epoch 1:  19%|█▉        | 90/468 [00:01<00:07, 51.36it/s, loss=2.3, accuracy=0.102] [A
Epoch 1:  19%|█▉        | 91/468 [00:01<00:07, 51.36it/s, loss=2.32, accuracy=0.0781][A
Epoch 1:  20%|█▉        | 92/468 [00:01<00:07, 51.36it/s, loss=2.3, accuracy=0.109]  [A
Epoch 1:  20%|█▉        | 93/468 [00:01<00:07, 51.36it/s, loss=2.3, accuracy=0.109][A
Epoch 1:  20%|██        | 94/468 [00:01<00:07, 51.36it/s, loss=2.28, accuracy=0.109][A
Epoch 1:  20%|██        | 95/46

Epoch 1:  36%|███▌      | 169/468 [00:01<00:02, 102.59it/s, loss=2.32, accuracy=0.102][A
Epoch 1:  36%|███▋      | 170/468 [00:01<00:02, 102.59it/s, loss=2.3, accuracy=0.102] [A
Epoch 1:  37%|███▋      | 171/468 [00:01<00:02, 102.59it/s, loss=2.29, accuracy=0.172][A
Epoch 1:  37%|███▋      | 172/468 [00:01<00:02, 102.59it/s, loss=2.29, accuracy=0.133][A
Epoch 1:  37%|███▋      | 173/468 [00:01<00:02, 106.03it/s, loss=2.29, accuracy=0.133][A
Epoch 1:  37%|███▋      | 173/468 [00:01<00:02, 106.03it/s, loss=2.3, accuracy=0.125] [A
Epoch 1:  37%|███▋      | 174/468 [00:01<00:02, 106.03it/s, loss=2.31, accuracy=0.0859][A
Epoch 1:  37%|███▋      | 175/468 [00:01<00:02, 106.03it/s, loss=2.29, accuracy=0.117] [A
Epoch 1:  38%|███▊      | 176/468 [00:01<00:02, 106.03it/s, loss=2.29, accuracy=0.117][A
Epoch 1:  38%|███▊      | 177/468 [00:01<00:02, 106.03it/s, loss=2.29, accuracy=0.133][A
Epoch 1:  38%|███▊      | 178/468 [00:01<00:02, 106.03it/s, loss=2.29, accuracy=0.148][A
Epoch 1:

Epoch 1:  54%|█████▍    | 252/468 [00:02<00:01, 116.86it/s, loss=2.29, accuracy=0.148][A
Epoch 1:  54%|█████▍    | 253/468 [00:02<00:01, 116.86it/s, loss=2.28, accuracy=0.156][A
Epoch 1:  54%|█████▍    | 254/468 [00:02<00:01, 116.86it/s, loss=2.29, accuracy=0.133][A
Epoch 1:  54%|█████▍    | 255/468 [00:02<00:01, 116.86it/s, loss=2.28, accuracy=0.109][A
Epoch 1:  55%|█████▍    | 256/468 [00:02<00:01, 116.86it/s, loss=2.29, accuracy=0.117][A
Epoch 1:  55%|█████▍    | 257/468 [00:02<00:01, 116.86it/s, loss=2.28, accuracy=0.172][A
Epoch 1:  55%|█████▌    | 258/468 [00:02<00:01, 116.86it/s, loss=2.28, accuracy=0.102][A
Epoch 1:  55%|█████▌    | 259/468 [00:02<00:01, 116.86it/s, loss=2.27, accuracy=0.195][A
Epoch 1:  56%|█████▌    | 260/468 [00:02<00:01, 117.72it/s, loss=2.27, accuracy=0.195][A
Epoch 1:  56%|█████▌    | 260/468 [00:02<00:01, 117.72it/s, loss=2.28, accuracy=0.203][A
Epoch 1:  56%|█████▌    | 261/468 [00:02<00:01, 117.72it/s, loss=2.28, accuracy=0.148][A
Epoch 1:  

Epoch 1:  72%|███████▏  | 337/468 [00:03<00:01, 119.84it/s, loss=2.27, accuracy=0.164][A
Epoch 1:  72%|███████▏  | 338/468 [00:03<00:01, 121.29it/s, loss=2.27, accuracy=0.164][A
Epoch 1:  72%|███████▏  | 338/468 [00:03<00:01, 121.29it/s, loss=2.27, accuracy=0.133][A
Epoch 1:  72%|███████▏  | 339/468 [00:03<00:01, 121.29it/s, loss=2.26, accuracy=0.211][A
Epoch 1:  73%|███████▎  | 340/468 [00:03<00:01, 121.29it/s, loss=2.26, accuracy=0.18] [A
Epoch 1:  73%|███████▎  | 341/468 [00:03<00:01, 121.29it/s, loss=2.25, accuracy=0.156][A
Epoch 1:  73%|███████▎  | 342/468 [00:03<00:01, 121.29it/s, loss=2.27, accuracy=0.133][A
Epoch 1:  73%|███████▎  | 343/468 [00:03<00:01, 121.29it/s, loss=2.27, accuracy=0.148][A
Epoch 1:  74%|███████▎  | 344/468 [00:03<00:01, 121.29it/s, loss=2.27, accuracy=0.18] [A
Epoch 1:  74%|███████▎  | 345/468 [00:03<00:01, 121.29it/s, loss=2.24, accuracy=0.258][A
Epoch 1:  74%|███████▍  | 346/468 [00:03<00:01, 121.29it/s, loss=2.26, accuracy=0.156][A
Epoch 1:  

Epoch 1:  90%|████████▉ | 421/468 [00:03<00:00, 118.62it/s, loss=2.21, accuracy=0.25] [A
Epoch 1:  90%|█████████ | 422/468 [00:03<00:00, 118.62it/s, loss=2.24, accuracy=0.203][A
Epoch 1:  90%|█████████ | 423/468 [00:03<00:00, 118.62it/s, loss=2.23, accuracy=0.211][A
Epoch 1:  91%|█████████ | 424/468 [00:03<00:00, 118.62it/s, loss=2.23, accuracy=0.188][A
Epoch 1:  91%|█████████ | 425/468 [00:03<00:00, 118.62it/s, loss=2.19, accuracy=0.305][A
Epoch 1:  91%|█████████ | 426/468 [00:03<00:00, 118.62it/s, loss=2.26, accuracy=0.195][A
Epoch 1:  91%|█████████ | 427/468 [00:03<00:00, 118.62it/s, loss=2.22, accuracy=0.188][A
Epoch 1:  91%|█████████▏| 428/468 [00:03<00:00, 118.62it/s, loss=2.22, accuracy=0.25] [A
Epoch 1:  92%|█████████▏| 429/468 [00:03<00:00, 119.59it/s, loss=2.22, accuracy=0.25][A
Epoch 1:  92%|█████████▏| 429/468 [00:03<00:00, 119.59it/s, loss=2.23, accuracy=0.188][A
Epoch 1:  92%|█████████▏| 430/468 [00:03<00:00, 119.59it/s, loss=2.23, accuracy=0.266][A
Epoch 1:  9

Epoch 2:   8%|▊         | 38/468 [00:00<00:41, 10.41it/s, loss=2.15, accuracy=0.281][A
Epoch 2:   8%|▊         | 39/468 [00:00<00:29, 14.33it/s, loss=2.15, accuracy=0.281][A
Epoch 2:   8%|▊         | 39/468 [00:00<00:29, 14.33it/s, loss=2.11, accuracy=0.305][A
Epoch 2:   9%|▊         | 40/468 [00:00<00:29, 14.33it/s, loss=2.12, accuracy=0.297][A
Epoch 2:   9%|▉         | 41/468 [00:00<00:29, 14.33it/s, loss=2.11, accuracy=0.344][A
Epoch 2:   9%|▉         | 42/468 [00:00<00:29, 14.33it/s, loss=2.07, accuracy=0.32] [A
Epoch 2:   9%|▉         | 43/468 [00:00<00:29, 14.33it/s, loss=2.14, accuracy=0.281][A
Epoch 2:   9%|▉         | 44/468 [00:00<00:29, 14.33it/s, loss=2.15, accuracy=0.227][A
Epoch 2:  10%|▉         | 45/468 [00:00<00:29, 14.33it/s, loss=2.16, accuracy=0.273][A
Epoch 2:  10%|▉         | 46/468 [00:00<00:29, 14.33it/s, loss=2.16, accuracy=0.242][A
Epoch 2:  10%|█         | 47/468 [00:00<00:29, 14.33it/s, loss=2.14, accuracy=0.242][A
Epoch 2:  10%|█         | 48/468

Epoch 2:  26%|██▋       | 123/468 [00:01<00:05, 64.47it/s, loss=1.99, accuracy=0.32] [A
Epoch 2:  26%|██▋       | 124/468 [00:01<00:05, 64.47it/s, loss=2.01, accuracy=0.312][A
Epoch 2:  27%|██▋       | 125/468 [00:01<00:05, 64.47it/s, loss=1.92, accuracy=0.328][A
Epoch 2:  27%|██▋       | 126/468 [00:01<00:05, 64.47it/s, loss=2.06, accuracy=0.258][A
Epoch 2:  27%|██▋       | 127/468 [00:01<00:05, 64.47it/s, loss=1.87, accuracy=0.383][A
Epoch 2:  27%|██▋       | 128/468 [00:01<00:05, 64.47it/s, loss=1.92, accuracy=0.328][A
Epoch 2:  28%|██▊       | 129/468 [00:01<00:04, 75.39it/s, loss=1.92, accuracy=0.328][A
Epoch 2:  28%|██▊       | 129/468 [00:01<00:04, 75.39it/s, loss=1.89, accuracy=0.398][A
Epoch 2:  28%|██▊       | 130/468 [00:01<00:04, 75.39it/s, loss=1.85, accuracy=0.375][A
Epoch 2:  28%|██▊       | 131/468 [00:01<00:04, 75.39it/s, loss=1.9, accuracy=0.352] [A
Epoch 2:  28%|██▊       | 132/468 [00:01<00:04, 75.39it/s, loss=1.93, accuracy=0.352][A
Epoch 2:  28%|██▊    

Epoch 2:  44%|████▍     | 207/468 [00:01<00:02, 112.41it/s, loss=1.79, accuracy=0.406][A
Epoch 2:  44%|████▍     | 208/468 [00:01<00:02, 112.41it/s, loss=1.71, accuracy=0.414][A
Epoch 2:  45%|████▍     | 209/468 [00:01<00:02, 112.41it/s, loss=1.66, accuracy=0.492][A
Epoch 2:  45%|████▍     | 210/468 [00:01<00:02, 112.41it/s, loss=1.75, accuracy=0.445][A
Epoch 2:  45%|████▌     | 211/468 [00:01<00:02, 112.41it/s, loss=1.74, accuracy=0.445][A
Epoch 2:  45%|████▌     | 212/468 [00:01<00:02, 112.41it/s, loss=1.72, accuracy=0.438][A
Epoch 2:  46%|████▌     | 213/468 [00:01<00:02, 112.41it/s, loss=1.8, accuracy=0.359] [A
Epoch 2:  46%|████▌     | 214/468 [00:01<00:02, 112.41it/s, loss=1.7, accuracy=0.406][A
Epoch 2:  46%|████▌     | 215/468 [00:01<00:02, 112.41it/s, loss=1.6, accuracy=0.461][A
Epoch 2:  46%|████▌     | 216/468 [00:01<00:02, 112.41it/s, loss=1.58, accuracy=0.508][A
Epoch 2:  46%|████▋     | 217/468 [00:01<00:02, 112.41it/s, loss=1.65, accuracy=0.508][A
Epoch 2:  47

Epoch 2:  62%|██████▏   | 292/468 [00:02<00:01, 121.02it/s, loss=1.44, accuracy=0.5]  [A
Epoch 2:  63%|██████▎   | 293/468 [00:02<00:01, 121.02it/s, loss=1.49, accuracy=0.508][A
Epoch 2:  63%|██████▎   | 294/468 [00:02<00:01, 121.02it/s, loss=1.44, accuracy=0.539][A
Epoch 2:  63%|██████▎   | 295/468 [00:02<00:01, 121.02it/s, loss=1.36, accuracy=0.547][A
Epoch 2:  63%|██████▎   | 296/468 [00:02<00:01, 119.89it/s, loss=1.36, accuracy=0.547][A
Epoch 2:  63%|██████▎   | 296/468 [00:02<00:01, 119.89it/s, loss=1.47, accuracy=0.492][A
Epoch 2:  63%|██████▎   | 297/468 [00:02<00:01, 119.89it/s, loss=1.31, accuracy=0.531][A
Epoch 2:  64%|██████▎   | 298/468 [00:02<00:01, 119.89it/s, loss=1.39, accuracy=0.562][A
Epoch 2:  64%|██████▍   | 299/468 [00:02<00:01, 119.89it/s, loss=1.35, accuracy=0.594][A
Epoch 2:  64%|██████▍   | 300/468 [00:02<00:01, 119.89it/s, loss=1.51, accuracy=0.508][A
Epoch 2:  64%|██████▍   | 301/468 [00:02<00:01, 119.89it/s, loss=1.4, accuracy=0.508] [A
Epoch 2:  

Epoch 2:  80%|████████  | 376/468 [00:03<00:00, 119.44it/s, loss=1.05, accuracy=0.688][A
Epoch 2:  81%|████████  | 377/468 [00:03<00:00, 119.44it/s, loss=1.09, accuracy=0.609][A
Epoch 2:  81%|████████  | 378/468 [00:03<00:00, 119.44it/s, loss=1.21, accuracy=0.594][A
Epoch 2:  81%|████████  | 379/468 [00:03<00:00, 119.44it/s, loss=1.22, accuracy=0.617][A
Epoch 2:  81%|████████  | 380/468 [00:03<00:00, 119.44it/s, loss=1.14, accuracy=0.617][A
Epoch 2:  81%|████████▏ | 381/468 [00:03<00:00, 119.44it/s, loss=1.13, accuracy=0.656][A
Epoch 2:  82%|████████▏ | 382/468 [00:03<00:00, 119.44it/s, loss=1.05, accuracy=0.609][A
Epoch 2:  82%|████████▏ | 383/468 [00:03<00:00, 119.44it/s, loss=1.34, accuracy=0.547][A
Epoch 2:  82%|████████▏ | 384/468 [00:03<00:00, 119.44it/s, loss=1.29, accuracy=0.594][A
Epoch 2:  82%|████████▏ | 385/468 [00:03<00:00, 119.44it/s, loss=1.34, accuracy=0.602][A
Epoch 2:  82%|████████▏ | 386/468 [00:03<00:00, 119.44it/s, loss=1.02, accuracy=0.75] [A
Epoch 2:  

Epoch 2:  98%|█████████▊| 460/468 [00:03<00:00, 122.16it/s, loss=0.922, accuracy=0.695][A
Epoch 2:  99%|█████████▊| 461/468 [00:03<00:00, 122.16it/s, loss=0.767, accuracy=0.758][A
Epoch 2:  99%|█████████▊| 462/468 [00:03<00:00, 122.16it/s, loss=0.824, accuracy=0.711][A
Epoch 2:  99%|█████████▉| 463/468 [00:03<00:00, 122.16it/s, loss=0.877, accuracy=0.727][A
Epoch 2:  99%|█████████▉| 464/468 [00:03<00:00, 122.16it/s, loss=1.19, accuracy=0.578] [A
Epoch 2:  99%|█████████▉| 465/468 [00:04<00:00, 123.10it/s, loss=1.19, accuracy=0.578][A
Epoch 2:  99%|█████████▉| 465/468 [00:04<00:00, 123.10it/s, loss=0.859, accuracy=0.656][A
Epoch 2: 100%|█████████▉| 466/468 [00:04<00:00, 123.10it/s, loss=0.806, accuracy=0.711][A
Epoch 2: 100%|█████████▉| 467/468 [00:04<00:00, 123.10it/s, loss=1.08, accuracy=0.68]  [A
Epoch 2: 100%|██████████| 468/468 [00:04<00:00, 106.35it/s, loss=1.61, accuracy=0.465, val_loss=0.677, val_accuracy=0.824][A

Epoch 3:   0%|          | 0/468 [00:00<?, ?it/s][A
Epo

Epoch 3:  16%|█▌        | 75/468 [00:00<00:11, 34.53it/s, loss=1.02, accuracy=0.664] [A
Epoch 3:  16%|█▌        | 76/468 [00:00<00:11, 34.53it/s, loss=0.802, accuracy=0.75][A
Epoch 3:  16%|█▋        | 77/468 [00:00<00:11, 34.53it/s, loss=0.895, accuracy=0.727][A
Epoch 3:  17%|█▋        | 78/468 [00:00<00:11, 34.53it/s, loss=0.809, accuracy=0.797][A
Epoch 3:  17%|█▋        | 79/468 [00:00<00:11, 34.53it/s, loss=0.972, accuracy=0.695][A
Epoch 3:  17%|█▋        | 80/468 [00:00<00:11, 34.53it/s, loss=1.03, accuracy=0.695] [A
Epoch 3:  17%|█▋        | 81/468 [00:00<00:11, 34.53it/s, loss=1, accuracy=0.688]   [A
Epoch 3:  18%|█▊        | 82/468 [00:00<00:11, 34.53it/s, loss=0.826, accuracy=0.727][A
Epoch 3:  18%|█▊        | 83/468 [00:00<00:11, 34.53it/s, loss=0.94, accuracy=0.664] [A
Epoch 3:  18%|█▊        | 84/468 [00:00<00:11, 34.53it/s, loss=0.902, accuracy=0.703][A
Epoch 3:  18%|█▊        | 85/468 [00:00<00:11, 34.53it/s, loss=0.923, accuracy=0.711][A
Epoch 3:  18%|█▊       

Epoch 3:  34%|███▍      | 160/468 [00:01<00:03, 92.82it/s, loss=0.757, accuracy=0.758][A
Epoch 3:  34%|███▍      | 161/468 [00:01<00:03, 98.05it/s, loss=0.757, accuracy=0.758][A
Epoch 3:  34%|███▍      | 161/468 [00:01<00:03, 98.05it/s, loss=0.784, accuracy=0.758][A
Epoch 3:  35%|███▍      | 162/468 [00:01<00:03, 98.05it/s, loss=0.952, accuracy=0.703][A
Epoch 3:  35%|███▍      | 163/468 [00:01<00:03, 98.05it/s, loss=1.03, accuracy=0.703] [A
Epoch 3:  35%|███▌      | 164/468 [00:01<00:03, 98.05it/s, loss=1.13, accuracy=0.656][A
Epoch 3:  35%|███▌      | 165/468 [00:01<00:03, 98.05it/s, loss=0.94, accuracy=0.75] [A
Epoch 3:  35%|███▌      | 166/468 [00:01<00:03, 98.05it/s, loss=0.835, accuracy=0.742][A
Epoch 3:  36%|███▌      | 167/468 [00:01<00:03, 98.05it/s, loss=0.883, accuracy=0.672][A
Epoch 3:  36%|███▌      | 168/468 [00:01<00:03, 98.05it/s, loss=0.983, accuracy=0.648][A
Epoch 3:  36%|███▌      | 169/468 [00:01<00:03, 98.05it/s, loss=0.937, accuracy=0.688][A
Epoch 3:  36

Epoch 3:  52%|█████▏    | 243/468 [00:02<00:01, 118.91it/s, loss=0.907, accuracy=0.695][A
Epoch 3:  52%|█████▏    | 244/468 [00:02<00:01, 118.91it/s, loss=0.854, accuracy=0.75] [A
Epoch 3:  52%|█████▏    | 245/468 [00:02<00:01, 118.91it/s, loss=1.19, accuracy=0.602][A
Epoch 3:  53%|█████▎    | 246/468 [00:02<00:01, 118.91it/s, loss=0.893, accuracy=0.711][A
Epoch 3:  53%|█████▎    | 247/468 [00:02<00:01, 118.91it/s, loss=0.868, accuracy=0.711][A
Epoch 3:  53%|█████▎    | 248/468 [00:02<00:01, 118.91it/s, loss=1.05, accuracy=0.617] [A
Epoch 3:  53%|█████▎    | 249/468 [00:02<00:01, 118.91it/s, loss=0.722, accuracy=0.75][A
Epoch 3:  53%|█████▎    | 250/468 [00:02<00:01, 118.91it/s, loss=0.904, accuracy=0.664][A
Epoch 3:  54%|█████▎    | 251/468 [00:02<00:01, 118.91it/s, loss=0.819, accuracy=0.727][A
Epoch 3:  54%|█████▍    | 252/468 [00:02<00:01, 119.53it/s, loss=0.819, accuracy=0.727][A
Epoch 3:  54%|█████▍    | 252/468 [00:02<00:01, 119.53it/s, loss=1.02, accuracy=0.672] [A
E

Epoch 3:  70%|██████▉   | 327/468 [00:02<00:01, 120.04it/s, loss=0.836, accuracy=0.688][A
Epoch 3:  70%|███████   | 328/468 [00:02<00:01, 120.04it/s, loss=0.869, accuracy=0.703][A
Epoch 3:  70%|███████   | 329/468 [00:02<00:01, 120.04it/s, loss=0.744, accuracy=0.773][A
Epoch 3:  71%|███████   | 330/468 [00:02<00:01, 118.78it/s, loss=0.744, accuracy=0.773][A
Epoch 3:  71%|███████   | 330/468 [00:02<00:01, 118.78it/s, loss=0.897, accuracy=0.75] [A
Epoch 3:  71%|███████   | 331/468 [00:02<00:01, 118.78it/s, loss=0.971, accuracy=0.656][A
Epoch 3:  71%|███████   | 332/468 [00:02<00:01, 118.78it/s, loss=0.899, accuracy=0.758][A
Epoch 3:  71%|███████   | 333/468 [00:02<00:01, 118.78it/s, loss=0.742, accuracy=0.781][A
Epoch 3:  71%|███████▏  | 334/468 [00:02<00:01, 118.78it/s, loss=0.813, accuracy=0.773][A
Epoch 3:  72%|███████▏  | 335/468 [00:02<00:01, 118.78it/s, loss=0.871, accuracy=0.75] [A
Epoch 3:  72%|███████▏  | 336/468 [00:02<00:01, 118.78it/s, loss=1.02, accuracy=0.672][A


Epoch 3:  88%|████████▊ | 410/468 [00:03<00:00, 120.53it/s, loss=0.805, accuracy=0.82] [A
Epoch 3:  88%|████████▊ | 411/468 [00:03<00:00, 120.53it/s, loss=0.504, accuracy=0.844][A
Epoch 3:  88%|████████▊ | 412/468 [00:03<00:00, 120.53it/s, loss=0.816, accuracy=0.781][A
Epoch 3:  88%|████████▊ | 413/468 [00:03<00:00, 120.53it/s, loss=0.851, accuracy=0.703][A
Epoch 3:  88%|████████▊ | 414/468 [00:03<00:00, 120.53it/s, loss=1.04, accuracy=0.68]  [A
Epoch 3:  89%|████████▊ | 415/468 [00:03<00:00, 120.53it/s, loss=0.68, accuracy=0.75][A
Epoch 3:  89%|████████▉ | 416/468 [00:03<00:00, 120.53it/s, loss=0.802, accuracy=0.766][A
Epoch 3:  89%|████████▉ | 417/468 [00:03<00:00, 120.53it/s, loss=0.525, accuracy=0.852][A
Epoch 3:  89%|████████▉ | 418/468 [00:03<00:00, 120.53it/s, loss=0.653, accuracy=0.773][A
Epoch 3:  90%|████████▉ | 419/468 [00:03<00:00, 120.53it/s, loss=0.794, accuracy=0.766][A
Epoch 3:  90%|████████▉ | 420/468 [00:03<00:00, 122.99it/s, loss=0.794, accuracy=0.766][A
E

Epoch 4:   6%|▌         | 26/468 [00:00<00:41, 10.56it/s, loss=0.674, accuracy=0.805][A
Epoch 4:   6%|▌         | 26/468 [00:00<00:41, 10.56it/s, loss=0.74, accuracy=0.75]  [A
Epoch 4:   6%|▌         | 27/468 [00:00<00:41, 10.56it/s, loss=0.687, accuracy=0.742][A
Epoch 4:   6%|▌         | 28/468 [00:00<00:41, 10.56it/s, loss=0.804, accuracy=0.688][A
Epoch 4:   6%|▌         | 29/468 [00:00<00:41, 10.56it/s, loss=0.756, accuracy=0.781][A
Epoch 4:   6%|▋         | 30/468 [00:00<00:41, 10.56it/s, loss=0.638, accuracy=0.797][A
Epoch 4:   7%|▋         | 31/468 [00:00<00:41, 10.56it/s, loss=0.564, accuracy=0.836][A
Epoch 4:   7%|▋         | 32/468 [00:00<00:41, 10.56it/s, loss=0.774, accuracy=0.766][A
Epoch 4:   7%|▋         | 33/468 [00:00<00:41, 10.56it/s, loss=0.681, accuracy=0.805][A
Epoch 4:   7%|▋         | 34/468 [00:00<00:41, 10.56it/s, loss=0.599, accuracy=0.812][A
Epoch 4:   7%|▋         | 35/468 [00:00<00:41, 10.56it/s, loss=0.638, accuracy=0.82] [A
Epoch 4:   8%|▊      

Epoch 4:  24%|██▎       | 110/468 [00:01<00:06, 54.36it/s, loss=0.893, accuracy=0.766][A
Epoch 4:  24%|██▎       | 111/468 [00:01<00:06, 54.36it/s, loss=0.693, accuracy=0.797][A
Epoch 4:  24%|██▍       | 112/468 [00:01<00:06, 54.36it/s, loss=0.905, accuracy=0.727][A
Epoch 4:  24%|██▍       | 113/468 [00:01<00:06, 54.36it/s, loss=0.766, accuracy=0.781][A
Epoch 4:  24%|██▍       | 114/468 [00:01<00:06, 54.36it/s, loss=0.875, accuracy=0.703][A
Epoch 4:  25%|██▍       | 115/468 [00:01<00:06, 54.36it/s, loss=0.945, accuracy=0.656][A
Epoch 4:  25%|██▍       | 116/468 [00:01<00:05, 65.35it/s, loss=0.945, accuracy=0.656][A
Epoch 4:  25%|██▍       | 116/468 [00:01<00:05, 65.35it/s, loss=0.881, accuracy=0.688][A
Epoch 4:  25%|██▌       | 117/468 [00:01<00:05, 65.35it/s, loss=0.606, accuracy=0.844][A
Epoch 4:  25%|██▌       | 118/468 [00:01<00:05, 65.35it/s, loss=0.651, accuracy=0.781][A
Epoch 4:  25%|██▌       | 119/468 [00:01<00:05, 65.35it/s, loss=0.647, accuracy=0.797][A
Epoch 4:  

Epoch 4:  41%|████      | 193/468 [00:01<00:02, 106.79it/s, loss=0.736, accuracy=0.781][A
Epoch 4:  41%|████▏     | 194/468 [00:01<00:02, 106.79it/s, loss=0.769, accuracy=0.797][A
Epoch 4:  42%|████▏     | 195/468 [00:01<00:02, 106.79it/s, loss=0.627, accuracy=0.828][A
Epoch 4:  42%|████▏     | 196/468 [00:01<00:02, 106.79it/s, loss=0.626, accuracy=0.797][A
Epoch 4:  42%|████▏     | 197/468 [00:01<00:02, 106.79it/s, loss=0.69, accuracy=0.82]  [A
Epoch 4:  42%|████▏     | 198/468 [00:01<00:02, 106.79it/s, loss=0.763, accuracy=0.805][A
Epoch 4:  43%|████▎     | 199/468 [00:01<00:02, 106.79it/s, loss=0.468, accuracy=0.812][A
Epoch 4:  43%|████▎     | 200/468 [00:01<00:02, 106.79it/s, loss=0.529, accuracy=0.859][A
Epoch 4:  43%|████▎     | 201/468 [00:01<00:02, 106.79it/s, loss=0.66, accuracy=0.789] [A
Epoch 4:  43%|████▎     | 202/468 [00:01<00:02, 108.68it/s, loss=0.66, accuracy=0.789][A
Epoch 4:  43%|████▎     | 202/468 [00:01<00:02, 108.68it/s, loss=0.78, accuracy=0.727][A
E

Epoch 4:  59%|█████▉    | 277/468 [00:02<00:01, 116.39it/s, loss=0.556, accuracy=0.852][A
Epoch 4:  59%|█████▉    | 278/468 [00:02<00:01, 118.74it/s, loss=0.556, accuracy=0.852][A
Epoch 4:  59%|█████▉    | 278/468 [00:02<00:01, 118.74it/s, loss=0.678, accuracy=0.742][A
Epoch 4:  60%|█████▉    | 279/468 [00:02<00:01, 118.74it/s, loss=0.724, accuracy=0.758][A
Epoch 4:  60%|█████▉    | 280/468 [00:02<00:01, 118.74it/s, loss=0.542, accuracy=0.836][A
Epoch 4:  60%|██████    | 281/468 [00:02<00:01, 118.74it/s, loss=0.582, accuracy=0.836][A
Epoch 4:  60%|██████    | 282/468 [00:02<00:01, 118.74it/s, loss=0.661, accuracy=0.789][A
Epoch 4:  60%|██████    | 283/468 [00:02<00:01, 118.74it/s, loss=0.537, accuracy=0.828][A
Epoch 4:  61%|██████    | 284/468 [00:02<00:01, 118.74it/s, loss=0.512, accuracy=0.859][A
Epoch 4:  61%|██████    | 285/468 [00:02<00:01, 118.74it/s, loss=0.624, accuracy=0.812][A
Epoch 4:  61%|██████    | 286/468 [00:02<00:01, 118.74it/s, loss=0.626, accuracy=0.828][A

Epoch 4:  77%|███████▋  | 360/468 [00:03<00:00, 113.53it/s, loss=0.672, accuracy=0.797][A
Epoch 4:  77%|███████▋  | 361/468 [00:03<00:00, 113.53it/s, loss=0.626, accuracy=0.82] [A
Epoch 4:  77%|███████▋  | 362/468 [00:03<00:00, 113.53it/s, loss=0.675, accuracy=0.797][A
Epoch 4:  78%|███████▊  | 363/468 [00:03<00:00, 113.31it/s, loss=0.675, accuracy=0.797][A
Epoch 4:  78%|███████▊  | 363/468 [00:03<00:00, 113.31it/s, loss=0.676, accuracy=0.836][A
Epoch 4:  78%|███████▊  | 364/468 [00:03<00:00, 113.31it/s, loss=0.452, accuracy=0.859][A
Epoch 4:  78%|███████▊  | 365/468 [00:03<00:00, 113.31it/s, loss=0.677, accuracy=0.75] [A
Epoch 4:  78%|███████▊  | 366/468 [00:03<00:00, 113.31it/s, loss=0.486, accuracy=0.836][A
Epoch 4:  78%|███████▊  | 367/468 [00:03<00:00, 113.31it/s, loss=0.473, accuracy=0.844][A
Epoch 4:  79%|███████▊  | 368/468 [00:03<00:00, 113.31it/s, loss=0.733, accuracy=0.82] [A
Epoch 4:  79%|███████▉  | 369/468 [00:03<00:00, 113.31it/s, loss=0.617, accuracy=0.797][A

Epoch 4:  95%|█████████▍| 443/468 [00:03<00:00, 111.91it/s, loss=0.558, accuracy=0.797][A
Epoch 4:  95%|█████████▍| 444/468 [00:03<00:00, 111.91it/s, loss=0.475, accuracy=0.875][A
Epoch 4:  95%|█████████▌| 445/468 [00:03<00:00, 111.91it/s, loss=0.455, accuracy=0.859][A
Epoch 4:  95%|█████████▌| 446/468 [00:03<00:00, 111.91it/s, loss=0.628, accuracy=0.852][A
Epoch 4:  96%|█████████▌| 447/468 [00:03<00:00, 113.02it/s, loss=0.628, accuracy=0.852][A
Epoch 4:  96%|█████████▌| 447/468 [00:03<00:00, 113.02it/s, loss=0.474, accuracy=0.844][A
Epoch 4:  96%|█████████▌| 448/468 [00:03<00:00, 113.02it/s, loss=0.708, accuracy=0.773][A
Epoch 4:  96%|█████████▌| 449/468 [00:04<00:00, 113.02it/s, loss=0.573, accuracy=0.82] [A
Epoch 4:  96%|█████████▌| 450/468 [00:04<00:00, 113.02it/s, loss=0.695, accuracy=0.766][A
Epoch 4:  96%|█████████▋| 451/468 [00:04<00:00, 113.02it/s, loss=0.67, accuracy=0.797] [A
Epoch 4:  97%|█████████▋| 452/468 [00:04<00:00, 113.02it/s, loss=0.555, accuracy=0.836][A

Epoch 5:  12%|█▏        | 58/468 [00:00<00:21, 19.51it/s, loss=0.563, accuracy=0.82] [A
Epoch 5:  13%|█▎        | 59/468 [00:00<00:20, 19.51it/s, loss=0.586, accuracy=0.828][A
Epoch 5:  13%|█▎        | 60/468 [00:00<00:20, 19.51it/s, loss=0.568, accuracy=0.82] [A
Epoch 5:  13%|█▎        | 61/468 [00:00<00:15, 25.96it/s, loss=0.568, accuracy=0.82][A
Epoch 5:  13%|█▎        | 61/468 [00:00<00:15, 25.96it/s, loss=0.566, accuracy=0.82][A
Epoch 5:  13%|█▎        | 62/468 [00:00<00:15, 25.96it/s, loss=0.809, accuracy=0.805][A
Epoch 5:  13%|█▎        | 63/468 [00:00<00:15, 25.96it/s, loss=0.614, accuracy=0.82] [A
Epoch 5:  14%|█▎        | 64/468 [00:00<00:15, 25.96it/s, loss=0.631, accuracy=0.797][A
Epoch 5:  14%|█▍        | 65/468 [00:00<00:15, 25.96it/s, loss=0.721, accuracy=0.773][A
Epoch 5:  14%|█▍        | 66/468 [00:00<00:15, 25.96it/s, loss=0.559, accuracy=0.828][A
Epoch 5:  14%|█▍        | 67/468 [00:00<00:15, 25.96it/s, loss=0.765, accuracy=0.781][A
Epoch 5:  15%|█▍       

Epoch 5:  30%|███       | 142/468 [00:01<00:03, 82.13it/s, loss=0.443, accuracy=0.883][A
Epoch 5:  31%|███       | 143/468 [00:01<00:03, 82.13it/s, loss=0.518, accuracy=0.844][A
Epoch 5:  31%|███       | 144/468 [00:01<00:03, 89.00it/s, loss=0.518, accuracy=0.844][A
Epoch 5:  31%|███       | 144/468 [00:01<00:03, 89.00it/s, loss=0.699, accuracy=0.781][A
Epoch 5:  31%|███       | 145/468 [00:01<00:03, 89.00it/s, loss=0.593, accuracy=0.812][A
Epoch 5:  31%|███       | 146/468 [00:01<00:03, 89.00it/s, loss=0.506, accuracy=0.875][A
Epoch 5:  31%|███▏      | 147/468 [00:01<00:03, 89.00it/s, loss=0.632, accuracy=0.852][A
Epoch 5:  32%|███▏      | 148/468 [00:01<00:03, 89.00it/s, loss=0.487, accuracy=0.859][A
Epoch 5:  32%|███▏      | 149/468 [00:01<00:03, 89.00it/s, loss=0.553, accuracy=0.789][A
Epoch 5:  32%|███▏      | 150/468 [00:01<00:03, 89.00it/s, loss=0.615, accuracy=0.773][A
Epoch 5:  32%|███▏      | 151/468 [00:01<00:03, 89.00it/s, loss=0.471, accuracy=0.867][A
Epoch 5:  

Epoch 5:  48%|████▊     | 225/468 [00:02<00:02, 110.88it/s, loss=0.476, accuracy=0.859][A
Epoch 5:  48%|████▊     | 226/468 [00:02<00:02, 110.88it/s, loss=0.497, accuracy=0.852][A
Epoch 5:  49%|████▊     | 227/468 [00:02<00:02, 110.88it/s, loss=0.51, accuracy=0.844] [A
Epoch 5:  49%|████▊     | 228/468 [00:02<00:02, 110.88it/s, loss=0.746, accuracy=0.75][A
Epoch 5:  49%|████▉     | 229/468 [00:02<00:02, 110.88it/s, loss=0.582, accuracy=0.828][A
Epoch 5:  49%|████▉     | 230/468 [00:02<00:02, 110.89it/s, loss=0.582, accuracy=0.828][A
Epoch 5:  49%|████▉     | 230/468 [00:02<00:02, 110.89it/s, loss=0.461, accuracy=0.859][A
Epoch 5:  49%|████▉     | 231/468 [00:02<00:02, 110.89it/s, loss=0.527, accuracy=0.836][A
Epoch 5:  50%|████▉     | 232/468 [00:02<00:02, 110.89it/s, loss=0.592, accuracy=0.797][A
Epoch 5:  50%|████▉     | 233/468 [00:02<00:02, 110.89it/s, loss=0.673, accuracy=0.812][A
Epoch 5:  50%|█████     | 234/468 [00:02<00:02, 110.89it/s, loss=0.804, accuracy=0.742][A


Epoch 5:  66%|██████▌   | 308/468 [00:02<00:01, 114.53it/s, loss=0.851, accuracy=0.734][A
Epoch 5:  66%|██████▌   | 309/468 [00:02<00:01, 114.53it/s, loss=0.783, accuracy=0.734][A
Epoch 5:  66%|██████▌   | 310/468 [00:02<00:01, 114.53it/s, loss=0.449, accuracy=0.867][A
Epoch 5:  66%|██████▋   | 311/468 [00:02<00:01, 114.53it/s, loss=0.536, accuracy=0.812][A
Epoch 5:  67%|██████▋   | 312/468 [00:02<00:01, 114.53it/s, loss=0.646, accuracy=0.797][A
Epoch 5:  67%|██████▋   | 313/468 [00:02<00:01, 114.53it/s, loss=0.574, accuracy=0.812][A
Epoch 5:  67%|██████▋   | 314/468 [00:02<00:01, 114.15it/s, loss=0.574, accuracy=0.812][A
Epoch 5:  67%|██████▋   | 314/468 [00:02<00:01, 114.15it/s, loss=0.563, accuracy=0.82] [A
Epoch 5:  67%|██████▋   | 315/468 [00:02<00:01, 114.15it/s, loss=0.502, accuracy=0.828][A
Epoch 5:  68%|██████▊   | 316/468 [00:02<00:01, 114.15it/s, loss=0.481, accuracy=0.867][A
Epoch 5:  68%|██████▊   | 317/468 [00:02<00:01, 114.15it/s, loss=0.568, accuracy=0.836][A

Epoch 5:  84%|████████▎ | 391/468 [00:03<00:00, 112.12it/s, loss=0.501, accuracy=0.859][A
Epoch 5:  84%|████████▍ | 392/468 [00:03<00:00, 112.12it/s, loss=0.548, accuracy=0.836][A
Epoch 5:  84%|████████▍ | 393/468 [00:03<00:00, 112.12it/s, loss=0.546, accuracy=0.828][A
Epoch 5:  84%|████████▍ | 394/468 [00:03<00:00, 112.12it/s, loss=0.957, accuracy=0.734][A
Epoch 5:  84%|████████▍ | 395/468 [00:03<00:00, 112.12it/s, loss=0.53, accuracy=0.812] [A
Epoch 5:  85%|████████▍ | 396/468 [00:03<00:00, 112.12it/s, loss=0.784, accuracy=0.766][A
Epoch 5:  85%|████████▍ | 397/468 [00:03<00:00, 112.12it/s, loss=0.54, accuracy=0.828] [A
Epoch 5:  85%|████████▌ | 398/468 [00:03<00:00, 113.08it/s, loss=0.54, accuracy=0.828][A
Epoch 5:  85%|████████▌ | 398/468 [00:03<00:00, 113.08it/s, loss=0.524, accuracy=0.812][A
Epoch 5:  85%|████████▌ | 399/468 [00:03<00:00, 113.08it/s, loss=0.42, accuracy=0.859] [A
Epoch 5:  85%|████████▌ | 400/468 [00:03<00:00, 113.08it/s, loss=0.567, accuracy=0.852][A


Finished.
final validation accuracy:





In [26]:
if TrainSATInf:
    ## initialize model
    model_SATInf = Net().to(DEVICE)
    ## train params
    lr = 0.01
    optimiser = optim.SGD(model_SATInf.parameters(), lr=lr)
    epochs = 5
    ## train model
    training_history_entropySmoothing = olympic.fit(
        model_SATInf,
        optimiser,
        nn.CrossEntropyLoss(),
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=adversarial_training_entropy,
        update_fn_kwargs={'adversary': entropySmoothing, 'k': 2, 'step': 0.0005, 'eps': 0.1, 'norm': 'inf', 'gamma':1e-5},
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5, factor=0.5, min_delta=0.005, monitor='val_accuracy')
        ]
    )
    ## verify validation
    print('final validation accuracy:')
    olympic.evaluate(model_SATInf, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    ## save model
    modelname = '../trainedmodels/'+dataset+'/SATInf_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_SATInf,modelname)


Epoch 1:   0%|          | 0/468 [00:00<?, ?it/s][A

Begin training...



Epoch 1:   0%|          | 1/468 [00:00<01:24,  5.55it/s][A
Epoch 1:   0%|          | 1/468 [00:00<01:24,  5.55it/s, loss=2.3, accuracy=0.0625][A
Epoch 1:   0%|          | 2/468 [00:00<01:23,  5.55it/s, loss=2.32, accuracy=0.0859][A
Epoch 1:   1%|          | 3/468 [00:00<01:23,  5.55it/s, loss=2.3, accuracy=0.0781] [A
Epoch 1:   1%|          | 4/468 [00:00<01:23,  5.55it/s, loss=2.32, accuracy=0.0938][A
Epoch 1:   1%|          | 5/468 [00:00<01:23,  5.55it/s, loss=2.32, accuracy=0.102] [A
Epoch 1:   1%|▏         | 6/468 [00:00<01:23,  5.55it/s, loss=2.31, accuracy=0.0938][A
Epoch 1:   1%|▏         | 7/468 [00:00<01:23,  5.55it/s, loss=2.3, accuracy=0.0938] [A
Epoch 1:   2%|▏         | 8/468 [00:00<01:22,  5.55it/s, loss=2.29, accuracy=0.156][A
Epoch 1:   2%|▏         | 9/468 [00:00<01:22,  5.55it/s, loss=2.31, accuracy=0.0938][A
Epoch 1:   2%|▏         | 10/468 [00:00<01:22,  5.55it/s, loss=2.29, accuracy=0.117][A
Epoch 1:   2%|▏         | 11/468 [00:00<01:22,  5.55it/s, los

Epoch 1:  18%|█▊        | 85/468 [00:00<00:08, 42.99it/s, loss=2.29, accuracy=0.125][A
Epoch 1:  18%|█▊        | 86/468 [00:00<00:08, 42.99it/s, loss=2.31, accuracy=0.102][A
Epoch 1:  19%|█▊        | 87/468 [00:00<00:08, 42.99it/s, loss=2.3, accuracy=0.109] [A
Epoch 1:  19%|█▉        | 88/468 [00:00<00:08, 42.99it/s, loss=2.28, accuracy=0.117][A
Epoch 1:  19%|█▉        | 89/468 [00:00<00:08, 42.99it/s, loss=2.3, accuracy=0.0938][A
Epoch 1:  19%|█▉        | 90/468 [00:01<00:08, 42.99it/s, loss=2.31, accuracy=0.109][A
Epoch 1:  19%|█▉        | 91/468 [00:01<00:07, 52.59it/s, loss=2.31, accuracy=0.109][A
Epoch 1:  19%|█▉        | 91/468 [00:01<00:07, 52.59it/s, loss=2.3, accuracy=0.0625][A
Epoch 1:  20%|█▉        | 92/468 [00:01<00:07, 52.59it/s, loss=2.3, accuracy=0.102] [A
Epoch 1:  20%|█▉        | 93/468 [00:01<00:07, 52.59it/s, loss=2.3, accuracy=0.0938][A
Epoch 1:  20%|██        | 94/468 [00:01<00:07, 52.59it/s, loss=2.29, accuracy=0.102][A
Epoch 1:  20%|██        | 95/468

Epoch 1:  36%|███▋      | 170/468 [00:01<00:02, 104.59it/s, loss=2.29, accuracy=0.148][A
Epoch 1:  37%|███▋      | 171/468 [00:01<00:02, 104.59it/s, loss=2.28, accuracy=0.133][A
Epoch 1:  37%|███▋      | 172/468 [00:01<00:02, 104.59it/s, loss=2.3, accuracy=0.0781][A
Epoch 1:  37%|███▋      | 173/468 [00:01<00:02, 104.59it/s, loss=2.3, accuracy=0.0781][A
Epoch 1:  37%|███▋      | 174/468 [00:01<00:02, 104.59it/s, loss=2.31, accuracy=0.0703][A
Epoch 1:  37%|███▋      | 175/468 [00:01<00:02, 104.59it/s, loss=2.29, accuracy=0.117] [A
Epoch 1:  38%|███▊      | 176/468 [00:01<00:02, 104.59it/s, loss=2.29, accuracy=0.156][A
Epoch 1:  38%|███▊      | 177/468 [00:01<00:02, 104.59it/s, loss=2.3, accuracy=0.117] [A
Epoch 1:  38%|███▊      | 178/468 [00:01<00:02, 109.26it/s, loss=2.3, accuracy=0.117][A
Epoch 1:  38%|███▊      | 178/468 [00:01<00:02, 109.26it/s, loss=2.28, accuracy=0.203][A
Epoch 1:  38%|███▊      | 179/468 [00:01<00:02, 109.26it/s, loss=2.3, accuracy=0.102] [A
Epoch 1: 

Epoch 1:  54%|█████▍    | 254/468 [00:02<00:01, 121.40it/s, loss=2.29, accuracy=0.172][A
Epoch 1:  54%|█████▍    | 254/468 [00:02<00:01, 121.40it/s, loss=2.29, accuracy=0.109][A
Epoch 1:  54%|█████▍    | 255/468 [00:02<00:01, 121.40it/s, loss=2.29, accuracy=0.125][A
Epoch 1:  55%|█████▍    | 256/468 [00:02<00:01, 121.40it/s, loss=2.3, accuracy=0.109] [A
Epoch 1:  55%|█████▍    | 257/468 [00:02<00:01, 121.40it/s, loss=2.28, accuracy=0.156][A
Epoch 1:  55%|█████▌    | 258/468 [00:02<00:01, 121.40it/s, loss=2.28, accuracy=0.141][A
Epoch 1:  55%|█████▌    | 259/468 [00:02<00:01, 121.40it/s, loss=2.29, accuracy=0.0781][A
Epoch 1:  56%|█████▌    | 260/468 [00:02<00:01, 121.40it/s, loss=2.29, accuracy=0.102] [A
Epoch 1:  56%|█████▌    | 261/468 [00:02<00:01, 121.40it/s, loss=2.29, accuracy=0.156][A
Epoch 1:  56%|█████▌    | 262/468 [00:02<00:01, 121.40it/s, loss=2.28, accuracy=0.141][A
Epoch 1:  56%|█████▌    | 263/468 [00:02<00:01, 121.40it/s, loss=2.29, accuracy=0.141][A
Epoch 1:

Epoch 1:  72%|███████▏  | 337/468 [00:03<00:01, 122.03it/s, loss=2.28, accuracy=0.133][A
Epoch 1:  72%|███████▏  | 338/468 [00:03<00:01, 122.03it/s, loss=2.27, accuracy=0.211][A
Epoch 1:  72%|███████▏  | 339/468 [00:03<00:01, 122.03it/s, loss=2.27, accuracy=0.148][A
Epoch 1:  73%|███████▎  | 340/468 [00:03<00:01, 122.03it/s, loss=2.28, accuracy=0.195][A
Epoch 1:  73%|███████▎  | 341/468 [00:03<00:01, 122.03it/s, loss=2.27, accuracy=0.188][A
Epoch 1:  73%|███████▎  | 342/468 [00:03<00:01, 122.03it/s, loss=2.27, accuracy=0.18] [A
Epoch 1:  73%|███████▎  | 343/468 [00:03<00:01, 122.03it/s, loss=2.28, accuracy=0.125][A
Epoch 1:  74%|███████▎  | 344/468 [00:03<00:01, 122.03it/s, loss=2.27, accuracy=0.148][A
Epoch 1:  74%|███████▎  | 345/468 [00:03<00:01, 122.19it/s, loss=2.27, accuracy=0.148][A
Epoch 1:  74%|███████▎  | 345/468 [00:03<00:01, 122.19it/s, loss=2.28, accuracy=0.125][A
Epoch 1:  74%|███████▍  | 346/468 [00:03<00:00, 122.19it/s, loss=2.27, accuracy=0.188][A
Epoch 1:  

Epoch 1:  90%|█████████ | 422/468 [00:03<00:00, 124.74it/s, loss=2.25, accuracy=0.164][A
Epoch 1:  90%|█████████ | 423/468 [00:03<00:00, 124.74it/s, loss=2.23, accuracy=0.227][A
Epoch 1:  91%|█████████ | 424/468 [00:03<00:00, 123.44it/s, loss=2.23, accuracy=0.227][A
Epoch 1:  91%|█████████ | 424/468 [00:03<00:00, 123.44it/s, loss=2.25, accuracy=0.203][A
Epoch 1:  91%|█████████ | 425/468 [00:03<00:00, 123.44it/s, loss=2.24, accuracy=0.273][A
Epoch 1:  91%|█████████ | 426/468 [00:03<00:00, 123.44it/s, loss=2.24, accuracy=0.188][A
Epoch 1:  91%|█████████ | 427/468 [00:03<00:00, 123.44it/s, loss=2.23, accuracy=0.18] [A
Epoch 1:  91%|█████████▏| 428/468 [00:03<00:00, 123.44it/s, loss=2.24, accuracy=0.219][A
Epoch 1:  92%|█████████▏| 429/468 [00:03<00:00, 123.44it/s, loss=2.24, accuracy=0.188][A
Epoch 1:  92%|█████████▏| 430/468 [00:03<00:00, 123.44it/s, loss=2.25, accuracy=0.219][A
Epoch 1:  92%|█████████▏| 431/468 [00:03<00:00, 123.44it/s, loss=2.24, accuracy=0.234][A
Epoch 1:  

Epoch 2:   8%|▊         | 38/468 [00:00<00:29, 14.61it/s, loss=2.18, accuracy=0.312][A
Epoch 2:   8%|▊         | 38/468 [00:00<00:29, 14.61it/s, loss=2.18, accuracy=0.211][A
Epoch 2:   8%|▊         | 39/468 [00:00<00:29, 14.61it/s, loss=2.16, accuracy=0.297][A
Epoch 2:   9%|▊         | 40/468 [00:00<00:29, 14.61it/s, loss=2.19, accuracy=0.258][A
Epoch 2:   9%|▉         | 41/468 [00:00<00:29, 14.61it/s, loss=2.14, accuracy=0.312][A
Epoch 2:   9%|▉         | 42/468 [00:00<00:29, 14.61it/s, loss=2.17, accuracy=0.234][A
Epoch 2:   9%|▉         | 43/468 [00:00<00:29, 14.61it/s, loss=2.18, accuracy=0.281][A
Epoch 2:   9%|▉         | 44/468 [00:00<00:29, 14.61it/s, loss=2.17, accuracy=0.273][A
Epoch 2:  10%|▉         | 45/468 [00:00<00:28, 14.61it/s, loss=2.14, accuracy=0.289][A
Epoch 2:  10%|▉         | 46/468 [00:00<00:28, 14.61it/s, loss=2.17, accuracy=0.312][A
Epoch 2:  10%|█         | 47/468 [00:00<00:28, 14.61it/s, loss=2.18, accuracy=0.211][A
Epoch 2:  10%|█         | 48/468

Epoch 2:  26%|██▋       | 123/468 [00:01<00:05, 64.11it/s, loss=2.05, accuracy=0.289][A
Epoch 2:  26%|██▋       | 124/468 [00:01<00:04, 73.22it/s, loss=2.05, accuracy=0.289][A
Epoch 2:  26%|██▋       | 124/468 [00:01<00:04, 73.22it/s, loss=2.02, accuracy=0.352][A
Epoch 2:  27%|██▋       | 125/468 [00:01<00:04, 73.22it/s, loss=2.05, accuracy=0.328][A
Epoch 2:  27%|██▋       | 126/468 [00:01<00:04, 73.22it/s, loss=2.09, accuracy=0.289][A
Epoch 2:  27%|██▋       | 127/468 [00:01<00:04, 73.22it/s, loss=2.01, accuracy=0.328][A
Epoch 2:  27%|██▋       | 128/468 [00:01<00:04, 73.22it/s, loss=2.03, accuracy=0.344][A
Epoch 2:  28%|██▊       | 129/468 [00:01<00:04, 73.22it/s, loss=2.02, accuracy=0.352][A
Epoch 2:  28%|██▊       | 130/468 [00:01<00:04, 73.22it/s, loss=1.98, accuracy=0.336][A
Epoch 2:  28%|██▊       | 131/468 [00:01<00:04, 73.22it/s, loss=1.96, accuracy=0.375][A
Epoch 2:  28%|██▊       | 132/468 [00:01<00:04, 73.22it/s, loss=2.03, accuracy=0.289][A
Epoch 2:  28%|██▊    

Epoch 2:  44%|████▍     | 207/468 [00:01<00:02, 110.34it/s, loss=1.79, accuracy=0.422][A
Epoch 2:  44%|████▍     | 208/468 [00:01<00:02, 110.34it/s, loss=1.74, accuracy=0.445][A
Epoch 2:  45%|████▍     | 209/468 [00:01<00:02, 110.34it/s, loss=1.85, accuracy=0.367][A
Epoch 2:  45%|████▍     | 210/468 [00:01<00:02, 110.34it/s, loss=1.7, accuracy=0.492] [A
Epoch 2:  45%|████▌     | 211/468 [00:01<00:02, 111.30it/s, loss=1.7, accuracy=0.492][A
Epoch 2:  45%|████▌     | 211/468 [00:01<00:02, 111.30it/s, loss=1.73, accuracy=0.406][A
Epoch 2:  45%|████▌     | 212/468 [00:01<00:02, 111.30it/s, loss=1.81, accuracy=0.391][A
Epoch 2:  46%|████▌     | 213/468 [00:01<00:02, 111.30it/s, loss=1.85, accuracy=0.367][A
Epoch 2:  46%|████▌     | 214/468 [00:01<00:02, 111.30it/s, loss=1.81, accuracy=0.383][A
Epoch 2:  46%|████▌     | 215/468 [00:02<00:02, 111.30it/s, loss=1.82, accuracy=0.383][A
Epoch 2:  46%|████▌     | 216/468 [00:02<00:02, 111.30it/s, loss=1.62, accuracy=0.492][A
Epoch 2:  4

Epoch 2:  62%|██████▏   | 291/468 [00:02<00:01, 115.60it/s, loss=1.54, accuracy=0.484][A
Epoch 2:  62%|██████▏   | 292/468 [00:02<00:01, 115.60it/s, loss=1.49, accuracy=0.508][A
Epoch 2:  63%|██████▎   | 293/468 [00:02<00:01, 115.60it/s, loss=1.59, accuracy=0.531][A
Epoch 2:  63%|██████▎   | 294/468 [00:02<00:01, 115.60it/s, loss=1.54, accuracy=0.5]  [A
Epoch 2:  63%|██████▎   | 295/468 [00:02<00:01, 115.60it/s, loss=1.47, accuracy=0.5][A
Epoch 2:  63%|██████▎   | 296/468 [00:02<00:01, 116.57it/s, loss=1.47, accuracy=0.5][A
Epoch 2:  63%|██████▎   | 296/468 [00:02<00:01, 116.57it/s, loss=1.54, accuracy=0.422][A
Epoch 2:  63%|██████▎   | 297/468 [00:02<00:01, 116.57it/s, loss=1.39, accuracy=0.516][A
Epoch 2:  64%|██████▎   | 298/468 [00:02<00:01, 116.57it/s, loss=1.52, accuracy=0.438][A
Epoch 2:  64%|██████▍   | 299/468 [00:02<00:01, 116.57it/s, loss=1.53, accuracy=0.5]  [A
Epoch 2:  64%|██████▍   | 300/468 [00:02<00:01, 116.57it/s, loss=1.59, accuracy=0.469][A
Epoch 2:  64%|

Epoch 2:  80%|████████  | 375/468 [00:03<00:00, 117.11it/s, loss=1.31, accuracy=0.586][A
Epoch 2:  80%|████████  | 376/468 [00:03<00:00, 117.11it/s, loss=1.33, accuracy=0.523][A
Epoch 2:  81%|████████  | 377/468 [00:03<00:00, 117.11it/s, loss=1.24, accuracy=0.555][A
Epoch 2:  81%|████████  | 378/468 [00:03<00:00, 117.11it/s, loss=1.14, accuracy=0.609][A
Epoch 2:  81%|████████  | 379/468 [00:03<00:00, 117.11it/s, loss=1.15, accuracy=0.609][A
Epoch 2:  81%|████████  | 380/468 [00:03<00:00, 117.11it/s, loss=1.19, accuracy=0.633][A
Epoch 2:  81%|████████▏ | 381/468 [00:03<00:00, 117.11it/s, loss=1.18, accuracy=0.633][A
Epoch 2:  82%|████████▏ | 382/468 [00:03<00:00, 117.11it/s, loss=1.21, accuracy=0.562][A
Epoch 2:  82%|████████▏ | 383/468 [00:03<00:00, 117.11it/s, loss=1.54, accuracy=0.523][A
Epoch 2:  82%|████████▏ | 384/468 [00:03<00:00, 118.95it/s, loss=1.54, accuracy=0.523][A
Epoch 2:  82%|████████▏ | 384/468 [00:03<00:00, 118.95it/s, loss=1.33, accuracy=0.578][A
Epoch 2:  

Epoch 2:  98%|█████████▊| 459/468 [00:04<00:00, 121.89it/s, loss=0.991, accuracy=0.719][A
Epoch 2:  98%|█████████▊| 460/468 [00:04<00:00, 121.89it/s, loss=1.07, accuracy=0.602] [A
Epoch 2:  99%|█████████▊| 461/468 [00:04<00:00, 121.89it/s, loss=0.825, accuracy=0.719][A
Epoch 2:  99%|█████████▊| 462/468 [00:04<00:00, 122.39it/s, loss=0.825, accuracy=0.719][A
Epoch 2:  99%|█████████▊| 462/468 [00:04<00:00, 122.39it/s, loss=0.906, accuracy=0.719][A
Epoch 2:  99%|█████████▉| 463/468 [00:04<00:00, 122.39it/s, loss=0.915, accuracy=0.656][A
Epoch 2:  99%|█████████▉| 464/468 [00:04<00:00, 122.39it/s, loss=1.34, accuracy=0.477] [A
Epoch 2:  99%|█████████▉| 465/468 [00:04<00:00, 122.39it/s, loss=0.99, accuracy=0.633][A
Epoch 2: 100%|█████████▉| 466/468 [00:04<00:00, 122.39it/s, loss=0.95, accuracy=0.719][A
Epoch 2: 100%|█████████▉| 467/468 [00:04<00:00, 122.39it/s, loss=1.1, accuracy=0.625] [A
Epoch 2: 100%|██████████| 468/468 [00:04<00:00, 104.45it/s, loss=1.68, accuracy=0.439, val_lo

Epoch 3:  16%|█▌        | 75/468 [00:00<00:11, 35.49it/s, loss=1.04, accuracy=0.656][A
Epoch 3:  16%|█▌        | 75/468 [00:00<00:11, 35.49it/s, loss=1.06, accuracy=0.68] [A
Epoch 3:  16%|█▌        | 76/468 [00:00<00:11, 35.49it/s, loss=0.913, accuracy=0.742][A
Epoch 3:  16%|█▋        | 77/468 [00:00<00:11, 35.49it/s, loss=1.04, accuracy=0.641] [A
Epoch 3:  17%|█▋        | 78/468 [00:00<00:10, 35.49it/s, loss=0.916, accuracy=0.742][A
Epoch 3:  17%|█▋        | 79/468 [00:00<00:10, 35.49it/s, loss=0.976, accuracy=0.68] [A
Epoch 3:  17%|█▋        | 80/468 [00:00<00:10, 35.49it/s, loss=1.06, accuracy=0.609][A
Epoch 3:  17%|█▋        | 81/468 [00:00<00:10, 35.49it/s, loss=0.843, accuracy=0.711][A
Epoch 3:  18%|█▊        | 82/468 [00:00<00:10, 35.49it/s, loss=0.947, accuracy=0.688][A
Epoch 3:  18%|█▊        | 83/468 [00:00<00:10, 35.49it/s, loss=0.916, accuracy=0.68] [A
Epoch 3:  18%|█▊        | 84/468 [00:00<00:10, 35.49it/s, loss=1.01, accuracy=0.648][A
Epoch 3:  18%|█▊        |

Epoch 3:  34%|███▍      | 159/468 [00:01<00:03, 90.84it/s, loss=0.984, accuracy=0.688][A
Epoch 3:  34%|███▍      | 160/468 [00:01<00:03, 90.84it/s, loss=0.84, accuracy=0.719] [A
Epoch 3:  34%|███▍      | 161/468 [00:01<00:03, 97.67it/s, loss=0.84, accuracy=0.719][A
Epoch 3:  34%|███▍      | 161/468 [00:01<00:03, 97.67it/s, loss=0.883, accuracy=0.734][A
Epoch 3:  35%|███▍      | 162/468 [00:01<00:03, 97.67it/s, loss=0.969, accuracy=0.695][A
Epoch 3:  35%|███▍      | 163/468 [00:01<00:03, 97.67it/s, loss=1.03, accuracy=0.68]  [A
Epoch 3:  35%|███▌      | 164/468 [00:01<00:03, 97.67it/s, loss=1.05, accuracy=0.68][A
Epoch 3:  35%|███▌      | 165/468 [00:01<00:03, 97.67it/s, loss=0.937, accuracy=0.719][A
Epoch 3:  35%|███▌      | 166/468 [00:01<00:03, 97.67it/s, loss=0.896, accuracy=0.734][A
Epoch 3:  36%|███▌      | 167/468 [00:01<00:03, 97.67it/s, loss=0.832, accuracy=0.734][A
Epoch 3:  36%|███▌      | 168/468 [00:01<00:03, 97.67it/s, loss=0.891, accuracy=0.727][A
Epoch 3:  36%

Epoch 3:  52%|█████▏    | 242/468 [00:02<00:01, 113.98it/s, loss=0.948, accuracy=0.656][A
Epoch 3:  52%|█████▏    | 243/468 [00:02<00:01, 113.98it/s, loss=0.835, accuracy=0.719][A
Epoch 3:  52%|█████▏    | 244/468 [00:02<00:01, 113.98it/s, loss=0.877, accuracy=0.781][A
Epoch 3:  52%|█████▏    | 245/468 [00:02<00:01, 113.98it/s, loss=1.11, accuracy=0.641] [A
Epoch 3:  53%|█████▎    | 246/468 [00:02<00:01, 113.98it/s, loss=0.969, accuracy=0.625][A
Epoch 3:  53%|█████▎    | 247/468 [00:02<00:01, 113.98it/s, loss=0.982, accuracy=0.695][A
Epoch 3:  53%|█████▎    | 248/468 [00:02<00:01, 116.07it/s, loss=0.982, accuracy=0.695][A
Epoch 3:  53%|█████▎    | 248/468 [00:02<00:01, 116.07it/s, loss=1.27, accuracy=0.508] [A
Epoch 3:  53%|█████▎    | 249/468 [00:02<00:01, 116.07it/s, loss=0.793, accuracy=0.812][A
Epoch 3:  53%|█████▎    | 250/468 [00:02<00:01, 116.07it/s, loss=0.884, accuracy=0.703][A
Epoch 3:  54%|█████▎    | 251/468 [00:02<00:01, 116.07it/s, loss=0.988, accuracy=0.727][A

Epoch 3:  69%|██████▉   | 325/468 [00:02<00:01, 117.25it/s, loss=0.859, accuracy=0.734][A
Epoch 3:  70%|██████▉   | 326/468 [00:02<00:01, 117.25it/s, loss=0.821, accuracy=0.766][A
Epoch 3:  70%|██████▉   | 327/468 [00:02<00:01, 117.25it/s, loss=0.859, accuracy=0.75] [A
Epoch 3:  70%|███████   | 328/468 [00:02<00:01, 117.25it/s, loss=0.844, accuracy=0.742][A
Epoch 3:  70%|███████   | 329/468 [00:02<00:01, 117.25it/s, loss=0.938, accuracy=0.68] [A
Epoch 3:  71%|███████   | 330/468 [00:02<00:01, 117.25it/s, loss=1.03, accuracy=0.57] [A
Epoch 3:  71%|███████   | 331/468 [00:02<00:01, 117.25it/s, loss=1.05, accuracy=0.625][A
Epoch 3:  71%|███████   | 332/468 [00:02<00:01, 117.25it/s, loss=0.944, accuracy=0.672][A
Epoch 3:  71%|███████   | 333/468 [00:03<00:01, 116.51it/s, loss=0.944, accuracy=0.672][A
Epoch 3:  71%|███████   | 333/468 [00:03<00:01, 116.51it/s, loss=0.885, accuracy=0.68] [A
Epoch 3:  71%|███████▏  | 334/468 [00:03<00:01, 116.51it/s, loss=0.86, accuracy=0.719][A
Ep

Epoch 3:  87%|████████▋ | 408/468 [00:03<00:00, 115.82it/s, loss=0.994, accuracy=0.648][A
Epoch 3:  87%|████████▋ | 409/468 [00:03<00:00, 115.82it/s, loss=0.885, accuracy=0.719][A
Epoch 3:  88%|████████▊ | 410/468 [00:03<00:00, 115.82it/s, loss=0.674, accuracy=0.781][A
Epoch 3:  88%|████████▊ | 411/468 [00:03<00:00, 115.82it/s, loss=0.623, accuracy=0.836][A
Epoch 3:  88%|████████▊ | 412/468 [00:03<00:00, 115.82it/s, loss=0.88, accuracy=0.75]  [A
Epoch 3:  88%|████████▊ | 413/468 [00:03<00:00, 115.82it/s, loss=1, accuracy=0.672]  [A
Epoch 3:  88%|████████▊ | 414/468 [00:03<00:00, 115.82it/s, loss=1.09, accuracy=0.703][A
Epoch 3:  89%|████████▊ | 415/468 [00:03<00:00, 115.82it/s, loss=0.789, accuracy=0.75][A
Epoch 3:  89%|████████▉ | 416/468 [00:03<00:00, 115.82it/s, loss=0.889, accuracy=0.727][A
Epoch 3:  89%|████████▉ | 417/468 [00:03<00:00, 115.82it/s, loss=0.617, accuracy=0.82] [A
Epoch 3:  89%|████████▉ | 418/468 [00:03<00:00, 115.82it/s, loss=0.662, accuracy=0.758][A
Epo

Epoch 4:   5%|▌         | 24/468 [00:00<01:00,  7.30it/s, loss=0.895, accuracy=0.742][A
Epoch 4:   5%|▌         | 25/468 [00:00<01:00,  7.30it/s, loss=0.681, accuracy=0.797][A
Epoch 4:   6%|▌         | 26/468 [00:00<01:00,  7.30it/s, loss=0.761, accuracy=0.766][A
Epoch 4:   6%|▌         | 27/468 [00:00<00:43, 10.15it/s, loss=0.761, accuracy=0.766][A
Epoch 4:   6%|▌         | 27/468 [00:00<00:43, 10.15it/s, loss=0.724, accuracy=0.766][A
Epoch 4:   6%|▌         | 28/468 [00:00<00:43, 10.15it/s, loss=0.727, accuracy=0.758][A
Epoch 4:   6%|▌         | 29/468 [00:00<00:43, 10.15it/s, loss=0.625, accuracy=0.75] [A
Epoch 4:   6%|▋         | 30/468 [00:00<00:43, 10.15it/s, loss=0.797, accuracy=0.75][A
Epoch 4:   7%|▋         | 31/468 [00:00<00:43, 10.15it/s, loss=0.676, accuracy=0.766][A
Epoch 4:   7%|▋         | 32/468 [00:00<00:42, 10.15it/s, loss=0.893, accuracy=0.703][A
Epoch 4:   7%|▋         | 33/468 [00:00<00:42, 10.15it/s, loss=0.73, accuracy=0.789] [A
Epoch 4:   7%|▋       

Epoch 4:  23%|██▎       | 109/468 [00:01<00:06, 52.32it/s, loss=0.75, accuracy=0.711] [A
Epoch 4:  24%|██▎       | 110/468 [00:01<00:06, 52.32it/s, loss=0.842, accuracy=0.75][A
Epoch 4:  24%|██▎       | 111/468 [00:01<00:06, 52.32it/s, loss=0.696, accuracy=0.781][A
Epoch 4:  24%|██▍       | 112/468 [00:01<00:05, 62.44it/s, loss=0.696, accuracy=0.781][A
Epoch 4:  24%|██▍       | 112/468 [00:01<00:05, 62.44it/s, loss=0.832, accuracy=0.734][A
Epoch 4:  24%|██▍       | 113/468 [00:01<00:05, 62.44it/s, loss=0.836, accuracy=0.75] [A
Epoch 4:  24%|██▍       | 114/468 [00:01<00:05, 62.44it/s, loss=0.926, accuracy=0.719][A
Epoch 4:  25%|██▍       | 115/468 [00:01<00:05, 62.44it/s, loss=1.01, accuracy=0.656] [A
Epoch 4:  25%|██▍       | 116/468 [00:01<00:05, 62.44it/s, loss=1.15, accuracy=0.703][A
Epoch 4:  25%|██▌       | 117/468 [00:01<00:05, 62.44it/s, loss=0.58, accuracy=0.82] [A
Epoch 4:  25%|██▌       | 118/468 [00:01<00:05, 62.44it/s, loss=0.591, accuracy=0.828][A
Epoch 4:  25%

Epoch 4:  41%|████      | 192/468 [00:01<00:02, 108.01it/s, loss=0.732, accuracy=0.75] [A
Epoch 4:  41%|████      | 193/468 [00:01<00:02, 108.01it/s, loss=0.832, accuracy=0.766][A
Epoch 4:  41%|████▏     | 194/468 [00:01<00:02, 108.01it/s, loss=0.859, accuracy=0.727][A
Epoch 4:  42%|████▏     | 195/468 [00:01<00:02, 108.01it/s, loss=0.623, accuracy=0.773][A
Epoch 4:  42%|████▏     | 196/468 [00:01<00:02, 108.01it/s, loss=0.676, accuracy=0.758][A
Epoch 4:  42%|████▏     | 197/468 [00:01<00:02, 108.01it/s, loss=0.656, accuracy=0.812][A
Epoch 4:  42%|████▏     | 198/468 [00:01<00:02, 108.01it/s, loss=0.81, accuracy=0.773] [A
Epoch 4:  43%|████▎     | 199/468 [00:01<00:02, 109.76it/s, loss=0.81, accuracy=0.773][A
Epoch 4:  43%|████▎     | 199/468 [00:01<00:02, 109.76it/s, loss=0.486, accuracy=0.789][A
Epoch 4:  43%|████▎     | 200/468 [00:01<00:02, 109.76it/s, loss=0.632, accuracy=0.867][A
Epoch 4:  43%|████▎     | 201/468 [00:01<00:02, 109.76it/s, loss=0.642, accuracy=0.812][A


Epoch 4:  59%|█████▉    | 275/468 [00:02<00:01, 115.39it/s, loss=0.626, accuracy=0.828][A
Epoch 4:  59%|█████▉    | 276/468 [00:02<00:01, 115.39it/s, loss=0.685, accuracy=0.805][A
Epoch 4:  59%|█████▉    | 277/468 [00:02<00:01, 115.39it/s, loss=0.567, accuracy=0.82] [A
Epoch 4:  59%|█████▉    | 278/468 [00:02<00:01, 115.39it/s, loss=0.741, accuracy=0.758][A
Epoch 4:  60%|█████▉    | 279/468 [00:02<00:01, 115.39it/s, loss=0.752, accuracy=0.742][A
Epoch 4:  60%|█████▉    | 280/468 [00:02<00:01, 115.39it/s, loss=0.522, accuracy=0.836][A
Epoch 4:  60%|██████    | 281/468 [00:02<00:01, 115.39it/s, loss=0.594, accuracy=0.82] [A
Epoch 4:  60%|██████    | 282/468 [00:02<00:01, 115.39it/s, loss=0.605, accuracy=0.805][A
Epoch 4:  60%|██████    | 283/468 [00:02<00:01, 115.39it/s, loss=0.691, accuracy=0.805][A
Epoch 4:  61%|██████    | 284/468 [00:02<00:01, 115.39it/s, loss=0.536, accuracy=0.797][A
Epoch 4:  61%|██████    | 285/468 [00:02<00:01, 117.73it/s, loss=0.536, accuracy=0.797][A

Epoch 4:  77%|███████▋  | 359/468 [00:03<00:00, 118.03it/s, loss=0.726, accuracy=0.766][A
Epoch 4:  77%|███████▋  | 360/468 [00:03<00:00, 117.99it/s, loss=0.726, accuracy=0.766][A
Epoch 4:  77%|███████▋  | 360/468 [00:03<00:00, 117.99it/s, loss=0.62, accuracy=0.836] [A
Epoch 4:  77%|███████▋  | 361/468 [00:03<00:00, 117.99it/s, loss=0.649, accuracy=0.844][A
Epoch 4:  77%|███████▋  | 362/468 [00:03<00:00, 117.99it/s, loss=0.738, accuracy=0.75] [A
Epoch 4:  78%|███████▊  | 363/468 [00:03<00:00, 117.99it/s, loss=0.75, accuracy=0.75] [A
Epoch 4:  78%|███████▊  | 364/468 [00:03<00:00, 117.99it/s, loss=0.497, accuracy=0.875][A
Epoch 4:  78%|███████▊  | 365/468 [00:03<00:00, 117.99it/s, loss=0.673, accuracy=0.781][A
Epoch 4:  78%|███████▊  | 366/468 [00:03<00:00, 117.99it/s, loss=0.529, accuracy=0.836][A
Epoch 4:  78%|███████▊  | 367/468 [00:03<00:00, 117.99it/s, loss=0.603, accuracy=0.789][A
Epoch 4:  79%|███████▊  | 368/468 [00:03<00:00, 117.99it/s, loss=0.571, accuracy=0.797][A


Epoch 4:  94%|█████████▍| 442/468 [00:03<00:00, 118.26it/s, loss=0.775, accuracy=0.789][A
Epoch 4:  95%|█████████▍| 443/468 [00:03<00:00, 118.26it/s, loss=0.676, accuracy=0.797][A
Epoch 4:  95%|█████████▍| 444/468 [00:03<00:00, 118.26it/s, loss=0.502, accuracy=0.836][A
Epoch 4:  95%|█████████▌| 445/468 [00:03<00:00, 118.26it/s, loss=0.621, accuracy=0.852][A
Epoch 4:  95%|█████████▌| 446/468 [00:03<00:00, 117.08it/s, loss=0.621, accuracy=0.852][A
Epoch 4:  95%|█████████▌| 446/468 [00:03<00:00, 117.08it/s, loss=0.625, accuracy=0.828][A
Epoch 4:  96%|█████████▌| 447/468 [00:03<00:00, 117.08it/s, loss=0.505, accuracy=0.828][A
Epoch 4:  96%|█████████▌| 448/468 [00:03<00:00, 117.08it/s, loss=0.796, accuracy=0.75] [A
Epoch 4:  96%|█████████▌| 449/468 [00:03<00:00, 117.08it/s, loss=0.714, accuracy=0.758][A
Epoch 4:  96%|█████████▌| 450/468 [00:03<00:00, 117.08it/s, loss=0.677, accuracy=0.758][A
Epoch 4:  96%|█████████▋| 451/468 [00:03<00:00, 117.08it/s, loss=0.736, accuracy=0.828][A

Epoch 5:  12%|█▏        | 57/468 [00:00<00:20, 20.01it/s, loss=0.882, accuracy=0.688][A
Epoch 5:  12%|█▏        | 58/468 [00:00<00:20, 20.01it/s, loss=0.69, accuracy=0.828] [A
Epoch 5:  13%|█▎        | 59/468 [00:00<00:20, 20.01it/s, loss=0.545, accuracy=0.82][A
Epoch 5:  13%|█▎        | 60/468 [00:00<00:20, 20.01it/s, loss=0.537, accuracy=0.836][A
Epoch 5:  13%|█▎        | 61/468 [00:00<00:20, 20.01it/s, loss=0.652, accuracy=0.781][A
Epoch 5:  13%|█▎        | 62/468 [00:00<00:20, 20.01it/s, loss=0.739, accuracy=0.734][A
Epoch 5:  13%|█▎        | 63/468 [00:00<00:15, 26.62it/s, loss=0.739, accuracy=0.734][A
Epoch 5:  13%|█▎        | 63/468 [00:00<00:15, 26.62it/s, loss=0.503, accuracy=0.836][A
Epoch 5:  14%|█▎        | 64/468 [00:00<00:15, 26.62it/s, loss=0.542, accuracy=0.836][A
Epoch 5:  14%|█▍        | 65/468 [00:00<00:15, 26.62it/s, loss=0.684, accuracy=0.781][A
Epoch 5:  14%|█▍        | 66/468 [00:00<00:15, 26.62it/s, loss=0.671, accuracy=0.789][A
Epoch 5:  14%|█▍      

Epoch 5:  30%|███       | 141/468 [00:01<00:03, 82.29it/s, loss=0.576, accuracy=0.805][A
Epoch 5:  30%|███       | 142/468 [00:01<00:03, 82.29it/s, loss=0.536, accuracy=0.836][A
Epoch 5:  31%|███       | 143/468 [00:01<00:03, 82.29it/s, loss=0.399, accuracy=0.875][A
Epoch 5:  31%|███       | 144/468 [00:01<00:03, 82.29it/s, loss=0.756, accuracy=0.711][A
Epoch 5:  31%|███       | 145/468 [00:01<00:03, 82.29it/s, loss=0.592, accuracy=0.812][A
Epoch 5:  31%|███       | 146/468 [00:01<00:03, 82.29it/s, loss=0.475, accuracy=0.883][A
Epoch 5:  31%|███▏      | 147/468 [00:01<00:03, 82.29it/s, loss=0.517, accuracy=0.875][A
Epoch 5:  32%|███▏      | 148/468 [00:01<00:03, 82.29it/s, loss=0.545, accuracy=0.859][A
Epoch 5:  32%|███▏      | 149/468 [00:01<00:03, 90.43it/s, loss=0.545, accuracy=0.859][A
Epoch 5:  32%|███▏      | 149/468 [00:01<00:03, 90.43it/s, loss=0.602, accuracy=0.789][A
Epoch 5:  32%|███▏      | 150/468 [00:01<00:03, 90.43it/s, loss=0.786, accuracy=0.742][A
Epoch 5:  

Epoch 5:  48%|████▊     | 224/468 [00:02<00:02, 114.81it/s, loss=0.75, accuracy=0.805][A
Epoch 5:  48%|████▊     | 225/468 [00:02<00:02, 114.81it/s, loss=0.634, accuracy=0.82][A
Epoch 5:  48%|████▊     | 226/468 [00:02<00:02, 114.81it/s, loss=0.586, accuracy=0.82][A
Epoch 5:  49%|████▊     | 227/468 [00:02<00:02, 114.81it/s, loss=0.682, accuracy=0.766][A
Epoch 5:  49%|████▊     | 228/468 [00:02<00:02, 114.81it/s, loss=0.646, accuracy=0.781][A
Epoch 5:  49%|████▉     | 229/468 [00:02<00:02, 114.81it/s, loss=0.62, accuracy=0.781] [A
Epoch 5:  49%|████▉     | 230/468 [00:02<00:02, 114.81it/s, loss=0.65, accuracy=0.82] [A
Epoch 5:  49%|████▉     | 231/468 [00:02<00:02, 114.81it/s, loss=0.475, accuracy=0.883][A
Epoch 5:  50%|████▉     | 232/468 [00:02<00:02, 114.81it/s, loss=0.662, accuracy=0.781][A
Epoch 5:  50%|████▉     | 233/468 [00:02<00:02, 114.81it/s, loss=0.599, accuracy=0.805][A
Epoch 5:  50%|█████     | 234/468 [00:02<00:02, 114.81it/s, loss=0.709, accuracy=0.773][A
Epo

Epoch 5:  66%|██████▌   | 308/468 [00:02<00:01, 118.39it/s, loss=1.02, accuracy=0.688] [A
Epoch 5:  66%|██████▌   | 309/468 [00:02<00:01, 118.39it/s, loss=0.586, accuracy=0.852][A
Epoch 5:  66%|██████▌   | 310/468 [00:02<00:01, 118.52it/s, loss=0.586, accuracy=0.852][A
Epoch 5:  66%|██████▌   | 310/468 [00:02<00:01, 118.52it/s, loss=0.517, accuracy=0.836][A
Epoch 5:  66%|██████▋   | 311/468 [00:02<00:01, 118.52it/s, loss=0.673, accuracy=0.734][A
Epoch 5:  67%|██████▋   | 312/468 [00:02<00:01, 118.52it/s, loss=0.583, accuracy=0.82] [A
Epoch 5:  67%|██████▋   | 313/468 [00:02<00:01, 118.52it/s, loss=0.582, accuracy=0.836][A
Epoch 5:  67%|██████▋   | 314/468 [00:02<00:01, 118.52it/s, loss=0.538, accuracy=0.828][A
Epoch 5:  67%|██████▋   | 315/468 [00:02<00:01, 118.52it/s, loss=0.643, accuracy=0.844][A
Epoch 5:  68%|██████▊   | 316/468 [00:02<00:01, 118.52it/s, loss=0.415, accuracy=0.875][A
Epoch 5:  68%|██████▊   | 317/468 [00:02<00:01, 118.52it/s, loss=0.523, accuracy=0.828][A

Epoch 5:  84%|████████▎ | 391/468 [00:03<00:00, 118.62it/s, loss=0.572, accuracy=0.82] [A
Epoch 5:  84%|████████▍ | 392/468 [00:03<00:00, 118.62it/s, loss=0.571, accuracy=0.805][A
Epoch 5:  84%|████████▍ | 393/468 [00:03<00:00, 118.62it/s, loss=0.593, accuracy=0.828][A
Epoch 5:  84%|████████▍ | 394/468 [00:03<00:00, 118.62it/s, loss=0.913, accuracy=0.711][A
Epoch 5:  84%|████████▍ | 395/468 [00:03<00:00, 118.62it/s, loss=0.516, accuracy=0.828][A
Epoch 5:  85%|████████▍ | 396/468 [00:03<00:00, 118.62it/s, loss=0.553, accuracy=0.852][A
Epoch 5:  85%|████████▍ | 397/468 [00:03<00:00, 118.62it/s, loss=0.538, accuracy=0.836][A
Epoch 5:  85%|████████▌ | 398/468 [00:03<00:00, 120.43it/s, loss=0.538, accuracy=0.836][A
Epoch 5:  85%|████████▌ | 398/468 [00:03<00:00, 120.43it/s, loss=0.481, accuracy=0.812][A
Epoch 5:  85%|████████▌ | 399/468 [00:03<00:00, 120.43it/s, loss=0.422, accuracy=0.852][A
Epoch 5:  85%|████████▌ | 400/468 [00:03<00:00, 120.43it/s, loss=0.559, accuracy=0.805][A

Finished.
final validation accuracy:





# TRAIN MODEL USING TRADES

In [27]:
args = {}
args['test_batch_size'] = 128
args['no_cuda'] = False
args['epsilon'] = 0.0
args['num_steps'] = 3
args['step_size'] = 0.01
args['random'] =True,
args['model_path']='./checkpoints/model_mnist_smallcnn.pt'
args['source_model_path'] ='./checkpoints/model_mnist_smallcnn.pt'
args['target_model_path'] = './checkpoints/model_mnist_smallcnn.pt'
args['white_box_attack']=True
args['log_interval'] = 3
args['beta'] = 1.0

In [28]:
def trades_obj(model, optimiser, loss_fn, x,y, epoch):
    data, target = x, y
    optimiser.zero_grad()
    # calculate robust loss
    print(args)
    loss = trades_loss(model=model,
                       x_natural=data,
                       y=target,
                       optimizer=optimiser,
                       step_size=args['step_size'],
                       epsilon=args['epsilon'],
                       perturb_steps=args['num_steps'],
                       beta=args['beta'])

    loss.backward()
    optimiser.step()
    ypred = model(data)
            
    return loss, ypred

In [30]:
if TrainTRADES:
    ## initialize model
    model_TRADES = Net().to(DEVICE)
    ## training params
    lr = 0.001
    optimiser = optim.SGD(model_TRADES.parameters(), lr=lr)
    epochs = 10
    ## train model
    history_natural = olympic.fit(
        model_TRADES,
        optimiser,
        nn.KLDivLoss(size_average=False),
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=trades_obj,
        #update_fn_kwargs={'adversary': entropySmoothing, 'k': 30, 'step': 0.03, 'eps': 0.3, 'norm': 'inf', 'gamma':1e-5},
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5)
        ]
    )
    ## verify validation accuracy
    print('final validation accuracy:')
    valscore = olympic.evaluate(model_TRADES, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    ## save model
    modelname = '../trainedmodels/'+dataset+'/TRADES_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_TRADES,modelname)


Epoch 1:   0%|          | 0/468 [00:00<?, ?it/s][A

Begin training...



Epoch 1:   0%|          | 1/468 [00:00<01:32,  5.06it/s][A
Epoch 1:   0%|          | 1/468 [00:00<01:32,  5.06it/s, loss=2.3, accuracy=0.0938][A
Epoch 1:   0%|          | 2/468 [00:00<01:32,  5.06it/s, loss=2.31, accuracy=0.109][A
Epoch 1:   1%|          | 3/468 [00:00<01:31,  5.06it/s, loss=2.3, accuracy=0.141] [A
Epoch 1:   1%|          | 4/468 [00:00<01:31,  5.06it/s, loss=2.32, accuracy=0.0938][A
Epoch 1:   1%|          | 5/468 [00:00<01:31,  5.06it/s, loss=2.32, accuracy=0.125] [A
Epoch 1:   1%|▏         | 6/468 [00:00<01:31,  5.06it/s, loss=2.3, accuracy=0.125] [A
Epoch 1:   1%|▏         | 7/468 [00:00<01:31,  5.06it/s, loss=2.3, accuracy=0.109][A
Epoch 1:   2%|▏         | 8/468 [00:00<01:05,  6.99it/s, loss=2.3, accuracy=0.109][A
Epoch 1:   2%|▏         | 8/468 [00:00<01:05,  6.99it/s, loss=2.28, accuracy=0.18][A
Epoch 1:   2%|▏         | 9/468 [00:00<01:05,  6.99it/s, loss=2.32, accuracy=0.117][A
Epoch 1:   2%|▏         | 10/468 [00:00<01:05,  6.99it/s, loss=2.29, a

Epoch 1:  18%|█▊        | 82/468 [00:01<00:07, 53.93it/s, loss=2.31, accuracy=0.0859][A
Epoch 1:  18%|█▊        | 83/468 [00:01<00:07, 53.93it/s, loss=2.31, accuracy=0.102] [A
Epoch 1:  18%|█▊        | 84/468 [00:01<00:07, 53.93it/s, loss=2.3, accuracy=0.141] [A
Epoch 1:  18%|█▊        | 85/468 [00:01<00:07, 53.93it/s, loss=2.3, accuracy=0.0625][A
Epoch 1:  18%|█▊        | 86/468 [00:01<00:07, 53.93it/s, loss=2.29, accuracy=0.109][A
Epoch 1:  19%|█▊        | 87/468 [00:01<00:06, 57.00it/s, loss=2.29, accuracy=0.109][A
Epoch 1:  19%|█▊        | 87/468 [00:01<00:06, 57.00it/s, loss=2.3, accuracy=0.102] [A
Epoch 1:  19%|█▉        | 88/468 [00:01<00:06, 57.00it/s, loss=2.3, accuracy=0.117][A
Epoch 1:  19%|█▉        | 89/468 [00:01<00:06, 57.00it/s, loss=2.31, accuracy=0.0938][A
Epoch 1:  19%|█▉        | 90/468 [00:01<00:06, 57.00it/s, loss=2.3, accuracy=0.109]  [A
Epoch 1:  19%|█▉        | 91/468 [00:01<00:06, 57.00it/s, loss=2.31, accuracy=0.133][A
Epoch 1:  20%|█▉        | 92/

Epoch 1:  35%|███▌      | 164/468 [00:02<00:04, 71.78it/s, loss=2.31, accuracy=0.0625][A
Epoch 1:  35%|███▌      | 165/468 [00:02<00:04, 71.78it/s, loss=2.31, accuracy=0.0938][A
Epoch 1:  35%|███▌      | 166/468 [00:02<00:04, 70.52it/s, loss=2.31, accuracy=0.0938][A
Epoch 1:  35%|███▌      | 166/468 [00:02<00:04, 70.52it/s, loss=2.3, accuracy=0.109]  [A
Epoch 1:  36%|███▌      | 167/468 [00:02<00:04, 70.52it/s, loss=2.3, accuracy=0.0859][A
Epoch 1:  36%|███▌      | 168/468 [00:02<00:04, 70.52it/s, loss=2.3, accuracy=0.141] [A
Epoch 1:  36%|███▌      | 169/468 [00:02<00:04, 70.52it/s, loss=2.32, accuracy=0.0703][A
Epoch 1:  36%|███▋      | 170/468 [00:02<00:04, 70.52it/s, loss=2.3, accuracy=0.133]  [A
Epoch 1:  37%|███▋      | 171/468 [00:02<00:04, 70.52it/s, loss=2.29, accuracy=0.133][A
Epoch 1:  37%|███▋      | 172/468 [00:02<00:04, 70.52it/s, loss=2.3, accuracy=0.0781][A
Epoch 1:  37%|███▋      | 173/468 [00:02<00:04, 70.52it/s, loss=2.3, accuracy=0.133] [A
Epoch 1:  37%|█

Epoch 1:  52%|█████▏    | 245/468 [00:03<00:03, 72.11it/s, loss=2.31, accuracy=0.0859][A
Epoch 1:  53%|█████▎    | 246/468 [00:03<00:03, 72.11it/s, loss=2.29, accuracy=0.148] [A
Epoch 1:  53%|█████▎    | 247/468 [00:03<00:03, 72.16it/s, loss=2.29, accuracy=0.148][A
Epoch 1:  53%|█████▎    | 247/468 [00:03<00:03, 72.16it/s, loss=2.29, accuracy=0.141][A
Epoch 1:  53%|█████▎    | 248/468 [00:03<00:03, 72.16it/s, loss=2.3, accuracy=0.109] [A
Epoch 1:  53%|█████▎    | 249/468 [00:03<00:03, 72.16it/s, loss=2.3, accuracy=0.0859][A
Epoch 1:  53%|█████▎    | 250/468 [00:03<00:03, 72.16it/s, loss=2.3, accuracy=0.148] [A
Epoch 1:  54%|█████▎    | 251/468 [00:03<00:03, 72.16it/s, loss=2.31, accuracy=0.0781][A
Epoch 1:  54%|█████▍    | 252/468 [00:03<00:02, 72.16it/s, loss=2.3, accuracy=0.141]  [A
Epoch 1:  54%|█████▍    | 253/468 [00:03<00:02, 72.16it/s, loss=2.3, accuracy=0.109][A
Epoch 1:  54%|█████▍    | 254/468 [00:03<00:02, 72.16it/s, loss=2.31, accuracy=0.0781][A
Epoch 1:  54%|███

Epoch 1:  70%|██████▉   | 326/468 [00:04<00:01, 73.40it/s, loss=2.29, accuracy=0.172][A
Epoch 1:  70%|██████▉   | 327/468 [00:04<00:01, 74.13it/s, loss=2.29, accuracy=0.172][A
Epoch 1:  70%|██████▉   | 327/468 [00:04<00:01, 74.13it/s, loss=2.31, accuracy=0.0703][A
Epoch 1:  70%|███████   | 328/468 [00:04<00:01, 74.13it/s, loss=2.29, accuracy=0.133] [A
Epoch 1:  70%|███████   | 329/468 [00:04<00:01, 74.13it/s, loss=2.29, accuracy=0.109][A
Epoch 1:  71%|███████   | 330/468 [00:04<00:01, 74.13it/s, loss=2.29, accuracy=0.133][A
Epoch 1:  71%|███████   | 331/468 [00:04<00:01, 74.13it/s, loss=2.3, accuracy=0.133] [A
Epoch 1:  71%|███████   | 332/468 [00:04<00:01, 74.13it/s, loss=2.3, accuracy=0.117][A
Epoch 1:  71%|███████   | 333/468 [00:04<00:01, 74.13it/s, loss=2.29, accuracy=0.164][A
Epoch 1:  71%|███████▏  | 334/468 [00:04<00:01, 74.13it/s, loss=2.3, accuracy=0.125] [A
Epoch 1:  72%|███████▏  | 335/468 [00:04<00:01, 73.77it/s, loss=2.3, accuracy=0.125][A
Epoch 1:  72%|███████

Epoch 1:  87%|████████▋ | 407/468 [00:05<00:00, 68.65it/s, loss=2.29, accuracy=0.125][A
Epoch 1:  87%|████████▋ | 408/468 [00:05<00:00, 68.65it/s, loss=2.29, accuracy=0.148][A
Epoch 1:  87%|████████▋ | 409/468 [00:05<00:00, 68.65it/s, loss=2.3, accuracy=0.117] [A
Epoch 1:  88%|████████▊ | 410/468 [00:05<00:00, 68.65it/s, loss=2.29, accuracy=0.125][A
Epoch 1:  88%|████████▊ | 411/468 [00:05<00:00, 68.65it/s, loss=2.29, accuracy=0.141][A
Epoch 1:  88%|████████▊ | 412/468 [00:06<00:00, 68.65it/s, loss=2.29, accuracy=0.141][A
Epoch 1:  88%|████████▊ | 413/468 [00:06<00:00, 68.65it/s, loss=2.3, accuracy=0.125] [A
Epoch 1:  88%|████████▊ | 414/468 [00:06<00:00, 69.47it/s, loss=2.3, accuracy=0.125][A
Epoch 1:  88%|████████▊ | 414/468 [00:06<00:00, 69.47it/s, loss=2.29, accuracy=0.125][A
Epoch 1:  89%|████████▊ | 415/468 [00:06<00:00, 69.47it/s, loss=2.3, accuracy=0.125] [A
Epoch 1:  89%|████████▉ | 416/468 [00:06<00:00, 69.47it/s, loss=2.29, accuracy=0.141][A
Epoch 1:  89%|████████

RuntimeError: "log_cuda" not implemented for 'Long'


Epoch 1: 100%|█████████▉| 467/468 [00:24<00:00, 72.31it/s, loss=2.28, accuracy=0.18][A

In [None]:
print('final validation accuracy:')
valscore = olympic.evaluate(model_TRADES, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
print(valscore['val_accuracy'])    

# TRAIN MODEL USING MART

# TRAIN MODEL USING MMA

# LOAD ALL PRE-TRAINED MODELS

In [None]:
TrainSGD = False
TrainESGD = False
TrainL2 = False
TrainLInf = False
TrainSAT2 = False
TrainSATInf = False

In [None]:
# Load all the pre-trained models
if dataset=='MNIST':
    if not TrainSGD:
        model_SGD = torch.load('../trainedmodels/MNIST/SGD_ep10_lr0.1.pt').to(DEVICE)
    if not TrainESGD:    
        model_ESGD = torch.load('../trainedmodels/MNIST/ESGD_ep5_lr0.1.pt').to(DEVICE)
    if not TrainLInf:
        adv_model_linf = torch.load('../trainedmodels/MNIST/AT2_ep2_lr0.1.pt').to(DEVICE)
    if not TrainL2:
        adv_model_l2 = torch.load('../trainedmodels/MNIST/ATInf_ep2_lr0.1.pt').to(DEVICE)
    if not TrainSAT2:
        model_SAT2 = torch.load('../trainedmodels/MNIST/SAT2_ep2_lr0.1.pt').to(DEVICE)
    if not TrainSATInf:
        model_SATInf = torch.load('../trainedmodels/MNIST/SATInf_ep2_lr0.1.pt').to(DEVICE)

# VISUALIZE NETWORK OUTPUT AT DIFFERENT LEVELS OF ATTACK

In [None]:
def visualise_adversarial_examples(model, x, y, l2_eps=5.0, linf_eps=0.2):
    x = x.unsqueeze(0).to(DEVICE)
    y =  torch.tensor([y]).to(DEVICE)
    
    ## l2 and linf attacks
    x_adv_l2 = pgd(model, x, y, torch.nn.CrossEntropyLoss(), k=30, step=1.0, eps=l2_eps, norm=2)
    x_adv_linf = iterated_fgsm(model, x, y, torch.nn.CrossEntropyLoss(), k=60, step=0.01, eps=linf_eps, norm='inf')
    
    y_pred = model(x)
    y_pred_l2 = model(x_adv_l2)
    y_pred_linf = model(x_adv_linf)
    
    fig, axes = plt.subplots(1, 3, figsize=(15,5))
    
    axes[0].imshow(x[0, 0].cpu().numpy(), cmap='gray')
    axes[0].set_title(
        f'Natural, '
        f'P({ y_pred.argmax(dim=1).item()}) = '
        f'{np.round(y_pred.softmax(dim=1)[0, y_pred.argmax(dim=1).item()].item(), 3)}')
    axes[0].set_xticks([])
    axes[0].set_yticks([])
    
    axes[1].imshow(x_adv_l2[0, 0].cpu().numpy(), cmap='gray')
    axes[1].set_title(
        f'$L^2$ adversary, '
        f'eps={l2_eps}, '
        f'P({y_pred_l2.argmax(dim=1).item()}) = '
        f'{np.round(y_pred_l2.softmax(dim=1)[0, y_pred_l2.argmax(dim=1).item()].item(), 3)}')
    axes[1].set_xticks([])
    axes[1].set_yticks([])
    
    axes[2].imshow(x_adv_linf[0, 0].cpu().numpy(), cmap='gray')
    axes[2].set_title(
        '$L^{\infty}$ adversary, '
        f'eps={linf_eps}, '
        f'P({y_pred_l2.argmax(dim=1).item()}) = '
        f'{np.round(y_pred_linf.softmax(dim=1)[0, y_pred_linf.argmax(dim=1).item()].item(), 3)}')
    axes[2].set_xticks([])
    axes[2].set_yticks([])
    
    
    plt.show()

Evaluate sgd model with PGD attack and FGSM attack

In [None]:
if True:
    visualise_adversarial_examples(model_SGD, *val[1])
    visualise_adversarial_examples(model_SGD, *val[2])
    visualise_adversarial_examples(model_SGD, *val[7])

Evaluate Adversarial L-2 model with PGD attack and FGSM attack

In [None]:
if True:
    visualise_adversarial_examples(adv_model_l2, *val[1])
    visualise_adversarial_examples(adv_model_l2, *val[6])
    visualise_adversarial_examples(adv_model_l2, *val[10])

## Quantifying adversarial accuracy

In [None]:
def infnorm(x):
    infn = torch.max(torch.abs(x.detach().cpu()))
    return infn

In [None]:
def evaluate_against_adversary(model, k, eps, step, norm):
    total = 0
    acc = 0
    for x, y in val_loader:
        total += x.size(0)
        
        if norm == 2:
            x_adv = pgd(
                model, x.to(DEVICE), y.to(DEVICE), torch.nn.CrossEntropyLoss(), k=k, step=step, eps=eps, norm=2)
            print('rel. l2-norm of x_adv-x:',torch.norm(x_adv.detach().cpu()-x)/np.sqrt(x.size(0)))#/torch.norm(x))
        elif norm == 'inf':
            x_adv = iterated_fgsm(
                model, x.to(DEVICE), y.to(DEVICE), torch.nn.CrossEntropyLoss(), k=k, step=step, eps=eps, norm='inf')
            print('rel. linf-norm of x_adv-x:',infnorm(x_adv.detach().cpu()-x)/infnorm(x))
        y_pred = model(x_adv)

        acc += olympic.metrics.accuracy(y.to(DEVICE), y_pred) * x.size(0)

    return acc/total

## Evaluate robust models

In [None]:
loadResults = False

In [None]:
t1 = time.time()
if not loadResults:
    pgd_attack_range = np.arange(0, 6.1, 1./3)
    acc_SGD = []
    acc_ESGD = []
    acc_l2 = []
    acc_linf = []
    acc_SAT2 = []
    acc_SATInf = []
    for eps in pgd_attack_range:
        print('eps:',eps)
        print('evaluating SGD network...')
        acc_SGD.append(evaluate_against_adversary(model_SGD, k=20, eps=eps, step=0.5, norm=2))
        print('evaluating ESGD network...')
        acc_ESGD.append(evaluate_against_adversary(model_ESGD, k=20, eps=eps, step=0.5, norm=2))
        print('evaluating SAT2 network...')
        acc_SAT2.append(evaluate_against_adversary(model_SAT2, k=30, eps=eps, step=0.3, norm=2))
        print('evaluating SATInf network...')
        acc_SATInf.append(evaluate_against_adversary(model_SATInf, k=30, eps=eps, step=0.25, norm=2))        
        print('evaluating linf network...')
        acc_linf.append(evaluate_against_adversary(adv_model_linf, k=30, eps=eps*1.2, step=1.5, norm=2))
        print('evaluating l2 network...')
        acc_l2.append(evaluate_against_adversary(adv_model_l2, k=30, eps=eps*1.2, step=1.5, norm=2))
print("time elapsed:",time.time()-t1)    

In [None]:
if not loadResults:    
    accData = [acc_SGD,acc_ESGD,acc_l2,acc_linf,acc_SAT2,acc_SATInf]
    np.save('../results/accData_l2.npy',accData)

In [None]:
if loadResults:
    pgd_attack_range = np.arange(0.0, 6.1, 1./3)
    accData2 = np.load('../results/accData_l2.npy')
    [acc_SGD,acc_ESGD,acc_l2,acc_linf,acc_SAT2,acc_SATInf] = accData2

In [None]:
with plt.style.context('ggplot'):
    plt.rcParams["font.weight"] = "bold"
    plt.rcParams["axes.labelweight"] = "bold"
    plt.rcParams["axes.labelcolor"] = "black"
    fig, axes = plt.subplots(1, 1, figsize=(15,10))
    plt.figure(figsize=(14,7))
    axes.set_title('$L^2$-bounded adversary')
    axes.plot(pgd_attack_range, acc_SGD, label='SGD')
    axes.plot(pgd_attack_range, acc_ESGD, label='Entropy-SGD')
    axes.plot(pgd_attack_range, acc_SAT2, label='Data-Entropy-SGD ($L_2$)')
    axes.plot(pgd_attack_range, acc_SATInf, label='Data-Entropy-SGD ($L_\infty$)')
    axes.plot(pgd_attack_range, acc_l2, label='$L2$ training')
    axes.plot(pgd_attack_range, acc_linf, label='$L{\infty}$ training')
    axes.vlines([3], 0, 1, colors=COLOURS[1], linestyle='--')
    axes.set_ylabel('Accuracy')
    axes.set_xlabel('Epsilon')
    axes.set_ylim((0,1))
    axes.legend()

In [None]:
loadResultsFGSM = False

In [None]:
t1 = time.time()
if not loadResultsFGSM:
    fgsm_attack_range = np.arange(0.0, 0.52, 0.025)
    fgsm_acc_linf = []
    fgsm_acc_l2 = []
    fgsm_acc_SGD = []
    fgsm_acc_ESGD = []
    fgsm_acc_SAT2 = []
    fgsm_acc_SATInf = []
    for eps in fgsm_attack_range:
        print('eps:',eps)
        print('evaluating SGD network...')
        fgsm_acc_SGD.append(evaluate_against_adversary(model_SGD, k=20, eps=eps, step=0.1, norm='inf'))
        print('evaluating ESGD network...')
        fgsm_acc_ESGD.append(evaluate_against_adversary(model_ESGD, k=20, eps=eps, step=0.1, norm='inf'))
        print('evaluating SAT2 network...')
        fgsm_acc_SAT2.append(evaluate_against_adversary(model_SAT2, k=10, eps=eps, step=0.1, norm='inf'))
        print('evaluating SATInf network...')
        fgsm_acc_SATInf.append(evaluate_against_adversary(model_SATInf, k=10, eps=eps, step=0.1, norm='inf'))        
        print('evaluating linf network...')
        fgsm_acc_linf.append(evaluate_against_adversary(adv_model_linf, k=50, eps=eps, step=0.02, norm='inf'))
        print('evaluating l2 network...')
        fgsm_acc_l2.append(evaluate_against_adversary(adv_model_l2, k=50, eps=eps, step=0.02, norm='inf'))
    
print("time elapsed:",time.time()-t1)    

In [None]:
if not loadResultsFGSM:    
    fgsmaccData = [fgsm_acc_SGD,fgsm_acc_ESGD,fgsm_acc_l2,fgsm_acc_linf,fgsm_acc_SAT2,fgsm_acc_SATInf]
    np.save('../results/fgsmaccData.npy',fgsmaccData)

In [None]:
if loadResultsFGSM:
    fgsm_attack_range = np.arange(0.0, 0.52, 0.025)
    accData2 = np.load('../results/fgsmaccData.npy')
    [fgsm_acc_SGD,fgsm_acc_ESGD,fgsm_acc_l2,fgsm_acc_linf,fgsm_acc_SAT2,fgsm_acc_SATInf] = accData2

In [None]:
with plt.style.context('ggplot'):
    plt.rcParams["font.weight"] = "bold"
    plt.rcParams["axes.labelweight"] = "bold"
    plt.rcParams["axes.labelcolor"] = "black"
    fig, axes = plt.subplots(1, 1, figsize=(15,10))
    plt.figure(figsize=(14,7))
    axes.set_title('$L^\infty$-bounded adversary')
    axes.plot(fgsm_attack_range, fgsm_acc_SGD, label='SGD')
    axes.plot(fgsm_attack_range, fgsm_acc_ESGD, label='Entropy-SGD')
    axes.plot(fgsm_attack_range, fgsm_acc_SAT2, label='Data-Entropy-SGD ($L_2$)')
    axes.plot(fgsm_attack_range, fgsm_acc_SATInf, label='Data-Entropy-SGD ($L_\infty$)')    
    axes.plot(fgsm_attack_range, fgsm_acc_l2, label='$L2$ training')
    axes.plot(fgsm_attack_range, fgsm_acc_linf, label='$L{\infty}$ training')
    axes.vlines([0.25], 0, 1, colors=COLOURS[1], linestyle='--')
    axes.set_ylabel('Accuracy')
    axes.set_xlabel('Epsilon')
    axes.set_ylim((0,1))
    axes.legend()