Evaluate the models

# Stochastic Adversarial Training (StochAT)

### SoTA

vanila SGD: 
MNIST - 99%+ (most cnns), CIFAR10 - 93%+ (resnet18), 96%+ (wideresnet) 

MNIST:

adversarial attacks: 
l-inf @ eps = 80/255 @20 steps: TRADES - 96.07% - (4 layer cnn), MART 96.4%, MMA 95.5%, PGD - 96.01% - (4 layer cnn)

adversarial attacks:
l-2 @ eps = 32/255 (check): TRADES, MMA, PGD

CIFAR10:

adversarial attacks: 
l-inf @ eps = 8/255 @20 steps: 
TRADES 53-56% - (WRN-34-10), MART 57-58% (WRN-34-10), MMA 47%, PGD 48% - (WRN-32-10)// 49% - (WRN-34-10), Std - 0.03%
https://openreview.net/pdf?id=rklOg6EFwS (Table 4)

adversarial attacks: 
l-inf @ eps = 8/255 @20 steps: 
[ResNet10] TRADES 45.4%, MART 46.6%, MMA 37.26%, PGD 42.27%, Std 0.14%

Benign accuracies: TRADES 84.92%, MART 83.62%, MMA 84.36, PGD 87.14%, Std 95.8% [wideresnet]
https://openreview.net/pdf?id=Ms9zjhVB5R (Table 1)

adversarial attacks:
l-2 @ eps = 32/255 (check): TRADES, MART, MMA, PGD

TBD: CWinf attacks

## Pretrained models for comparison

download pretrained models and place in ../trainedmodels/MNIST or ../trainedmodels/CIFAR10 respectively

### TRADES :
https://github.com/yaodongyu/TRADES (MNIST: small cnn, CIFAR10: WideResNet34)
### MMA : 
https://github.com/BorealisAI/mma_training (MNIST: lenet5, CIFAR10: WideResNet28)
### MART :
 https://github.com/YisenWang/MART (CIFAR10: ResNet18 and WideResNet34)

## IMPORT LIBRARIES

In [1]:
import numpy as np
import pandas as pd
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import transforms, datasets
from multiprocessing import cpu_count
from collections import OrderedDict
import matplotlib.pyplot as plt
import torch
import olympic
import sys
from typing import Union, Callable, Tuple
sys.path.append('../adversarial/')
sys.path.append('../architectures/')
from functional import boundary, iterated_fgsm, local_search, pgd, entropySmoothing
from ESGD_utils import *
import pickle
import time
import torch.backends.cudnn as cudnn
import argparse, math, random
import ESGD_optim
from trades import trades_loss

In [27]:
from torch.optim.lr_scheduler import StepLR

In [2]:
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [3]:
DEVICE

'cuda'

# LOAD DATA

In [4]:
#place data folders outside working directory

In [14]:
dataset = 'MNIST' # [MNIST, CIFAR10]
transform = transforms.Compose([
    transforms.ToTensor(),
])
bsz = 128
if dataset == 'MNIST':
    train = datasets.MNIST('../../data/MNIST', train=True, transform=transform, download=True)
    val = datasets.MNIST('../../data/MNIST', train=False, transform=transform, download=True)
elif dataset == 'CIFAR10':
    train = datasets.CIFAR10('../../data/CIFAR10', train=True, transform=transform, download=True)
    val = datasets.CIFAR10('../../data/CIFAR10', train=False, transform=transform, download=True)
    
train_loader = DataLoader(train, batch_size=128, num_workers=cpu_count(),drop_last=True)
val_loader = DataLoader(val, batch_size=128, num_workers=cpu_count(),drop_last=True)

# INITIALIZE NETWORK

In [15]:
if dataset=='MNIST':
    from net_mnist import Net, NetSoft, model_cnn
    Net = model_cnn
    NetName = 'model_cnn'

In [7]:
if dataset=='CIFAR10':
    #[ResNet18,ResNet34,ResNet50,WideResNet]
    from resnet import ResNet18,ResNet34,ResNet50
    from wideresnet import WideResNet
    Net = WideResNet
    NetName = 'WideResNet'

# RANDOM SEED 

In [8]:
seed = 50
torch.set_num_threads(2)
if DEVICE=='cuda':
    torch.cuda.set_device(-1)
    torch.cuda.manual_seed(seed)
    cudnn.benchmark = True
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f503f63f5e8>

# LOAD PRETRAINED OR TRAIN NEW MODELS:

In [9]:
TrainSGD = True
TrainESGD = True
TrainL2 = True
TrainLInf = True
TrainSAT2 = True
TrainSATInf = True
TrainTRADES = True
TrainTRADESInf = True
# tbd: add training modules for MART(inf only), MMA and MMAInf - always keep False
TrainMART = False 
TrainMMA = False
TrainMMAInf = False

# TRAIN NAIVE MODEL USING SGD

In [23]:
def eval_train(model, device, train_loader):
    model.eval()
    train_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            train_loss += F.cross_entropy(output, target, size_average=False).item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    train_loss /= len(train_loader.dataset)
    print('Training: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        train_loss, correct, len(train_loader.dataset),
        100. * correct / len(train_loader.dataset)))
    training_accuracy = correct / len(train_loader.dataset)
    return train_loss, training_accuracy

In [24]:
def eval_test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, size_average=False).item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('Test: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    test_accuracy = correct / len(test_loader.dataset)
    return test_loss, test_accuracy

In [39]:
def evaluate_against_adversary(model, k, eps, step, norm):
    total = 0
    acc = 0
    for x, y in val_loader:
        total += x.size(0)
        
        if norm == 2:
            x_adv = pgd(
                model, x.to(DEVICE), y.to(DEVICE), torch.nn.CrossEntropyLoss(), k=k, step=step, eps=eps, norm=2)
            #print('rel. l2-norm of x_adv-x:',torch.norm(x_adv.detach().cpu()-x)/np.sqrt(x.size(0)))#/torch.norm(x))
            if torch.norm(x_adv.detach().cpu()-x)/np.sqrt(x.size(0))<0.6*eps:
                print('reconfigure attack')
                break;
        elif norm == 'inf':
            x_adv = iterated_fgsm(
                model, x.to(DEVICE), y.to(DEVICE), torch.nn.CrossEntropyLoss(), k=k, step=step, eps=eps, norm='inf')
            print('rel. linf-norm of x_adv-x:',infnorm(x_adv.detach().cpu()-x)/infnorm(x))
        y_pred = model(x_adv)

        acc += olympic.metrics.accuracy(y.to(DEVICE), y_pred) * x.size(0)
    print('Adversarial Accuracy: {}/{} ({:.0f}%)'.format(acc,total,100. * acc / total))
    return acc/total

In [40]:
def train_sgd(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        ypred = model(data)
        
        optimizer.zero_grad()
        sgd_loss = nn.CrossEntropyLoss()
        # calculate robust loss
        loss = sgd_loss(ypred,target)
        
        #print('outloss pre step:',loss)
        loss.backward(retain_graph=True)
        
        optimizer.step()
        #print('outloss post step:',loss.item())

        # print progress
        print_progress = 0
        #if print_progress == 1 & batch_idx % 100 == 0:
        #    print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        #        epoch, batch_idx * len(data), len(train_loader.dataset),
        #        100. * batch_idx / len(train_loader), loss.item()))

In [43]:
5%5

0

In [44]:
if TrainSGD:
    ## initialize model
    model_SGD = Net().to(DEVICE)
    ## training params
    optimiser = optim.SGD(model_SGD.parameters(), lr=lr)
    epochs = 10
    
    ## train model
    model_SGD = Net().to(DEVICE)
    ## training params
    lr = 0.1
    optimizer = optim.SGD(model_SGD.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=20, gamma=0.5)    
    ## train model

    for epoch in range(1, epochs + 1):
        print('Epoch:',epoch)
        # adjust learning rate for SGD

        # adversarial training
        train_sgd(model_SGD, DEVICE, train_loader, optimizer, epoch)
        # evaluation on natural examples
        print('================================================================')
        eval_train(model_SGD, DEVICE, train_loader)
        eval_test(model_SGD, DEVICE, val_loader)
        if (epoch) % 5 == 0:
            evaluate_against_adversary(model_SGD, k=20, eps=3.0, step=0.5, norm=2)
        print('================================================================')
        scheduler.step()

    ## save model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_SGD_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_SGD,modelname)

Epoch: 1




Training: Average loss: 0.1914, Accuracy: 56229/60000 (94%)
Test: Average loss: 0.1770, Accuracy: 9408/10000 (94%)
Epoch: 2
Training: Average loss: 0.0744, Accuracy: 58477/60000 (97%)
Test: Average loss: 0.0779, Accuracy: 9736/10000 (97%)
Epoch: 3
Training: Average loss: 0.0564, Accuracy: 58807/60000 (98%)
Test: Average loss: 0.0688, Accuracy: 9763/10000 (98%)
Epoch: 4
Training: Average loss: 0.0394, Accuracy: 59130/60000 (99%)
Test: Average loss: 0.0579, Accuracy: 9805/10000 (98%)
Adversarial Accuracy: 209.0/9984 (2%)
Epoch: 5
Training: Average loss: 0.0347, Accuracy: 59195/60000 (99%)
Test: Average loss: 0.0583, Accuracy: 9811/10000 (98%)
Epoch: 6
Training: Average loss: 0.0283, Accuracy: 59334/60000 (99%)
Test: Average loss: 0.0590, Accuracy: 9818/10000 (98%)
Epoch: 7
Training: Average loss: 0.0165, Accuracy: 59571/60000 (99%)
Test: Average loss: 0.0547, Accuracy: 9834/10000 (98%)
Epoch: 8
Training: Average loss: 0.0136, Accuracy: 59636/60000 (99%)
Test: Average loss: 0.0546, Accura

# TRAIN MODEL USING ENTROPY SGD (ESGD)

In [62]:
def entropy_training(model, optimiser, loss_fn, x, y, epoch):   
    model.train()
    y_pred = model(x)
    
    def helper():
        def feval():
            #x, y = Variable(x), Variable(y.squeeze())
            bsz = x.size(0)
            optimiser.zero_grad()
            yh = model(x)
            f = loss_fn.forward(yh, y)
            f.backward()

            yp = yh.argmax(axis=1)
            prec1 = 100*torch.sum(yp == y)//bsz
            err = 100.-prec1.item()

            return (f.data.item(), err)
        return feval

    loss, err = optimiser.step(helper(), model, loss_fn)
    loss = torch.tensor(loss)
    return loss, y_pred

In [63]:
def train_entropy(model, device, train_loader, optimizer, epoch):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        ypred = model(data)
        
        sgd_loss = nn.CrossEntropyLoss()
        # calculate robust loss
        loss = entropy_training(model,optimizer,sgd_loss,data,target,epoch)

In [65]:
if TrainESGD:
    ## initialize model
    model_ESGD = Net().to(DEVICE)
    ## training parameters
    lr = 0.01 
    l2 = 0.0 #l2 regularization
    L = 0    #langevin iterations
    gamma = 1e-4 
    scoping = 1e-3
    noise = 1e-4
    loss_fn = nn.CrossEntropyLoss()
    epochs = 20
    optimiser = ESGD_optim.EntropySGD(model_ESGD.parameters(),config = dict(lr=lr, momentum=0.9, nesterov=True, weight_decay=l2,
            L=L, eps=noise, g0=gamma, g1=scoping))

    scheduler = StepLR(optimiser, step_size=20, gamma=0.5)    
    ## train model

    for epoch in range(1, epochs + 1):
        print('Epoch:',epoch)
        # adjust learning rate for SGD

        # adversarial training
        train_entropy(model_ESGD, DEVICE, train_loader, optimiser, epoch)
        # evaluation on natural examples
        print('================================================================')
        eval_train(model_ESGD, DEVICE, train_loader)
        eval_test(model_ESGD, DEVICE, val_loader)
        if (epoch) % 5 == 0:
            evaluate_against_adversary(model_ESGD, k=20, eps=3.0, step=0.5, norm=2)
        print('================================================================')
        scheduler.step()
    ## save trained model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_ESGD_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_ESGD,modelname)

Epoch: 1


	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  /pytorch/torch/csrc/utils/python_arg_parser.cpp:766.)
  mdw.mul_(mom).add_(1-damp, dw)






Training: Average loss: 0.1951, Accuracy: 56255/60000 (94%)
Test: Average loss: 0.1891, Accuracy: 9369/10000 (94%)
Epoch: 2
Training: Average loss: 0.0907, Accuracy: 58141/60000 (97%)
Test: Average loss: 0.0925, Accuracy: 9686/10000 (97%)
Epoch: 3
Training: Average loss: 0.0627, Accuracy: 58652/60000 (98%)
Test: Average loss: 0.0723, Accuracy: 9746/10000 (97%)
Epoch: 4
Training: Average loss: 0.0477, Accuracy: 58934/60000 (98%)
Test: Average loss: 0.0633, Accuracy: 9789/10000 (98%)
Epoch: 5
Training: Average loss: 0.0380, Accuracy: 59142/60000 (99%)
Test: Average loss: 0.0618, Accuracy: 9796/10000 (98%)
Adversarial Accuracy: 238.0/9984 (2%)
Epoch: 6
Training: Average loss: 0.0325, Accuracy: 59230/60000 (99%)
Test: Average loss: 0.0623, Accuracy: 9810/10000 (98%)
Epoch: 7
Training: Average loss: 0.0324, Accuracy: 59241/60000 (99%)
Test: Average loss: 0.0723, Accuracy: 9783/10000 (98%)
Epoch: 8
Training: Average loss: 0.0261, Accuracy: 59359/60000 (99%)
Test: Average loss: 0.0630, Accura

# TRAIN MODEL USING PGD SGD 

In [68]:
def infnorm(x):
    infn = torch.max(torch.abs(x.detach().cpu()))
    return infn

In [69]:
def adversarial_training(model, optimiser, loss_fn, x, y, epoch, adversary, k, step, eps, norm, random):
    """Performs a single update against a specified adversary"""
    model.train()

    # Adversial perturbation
    x_adv = adversary(model, x, y, loss_fn, k=k, step=step, eps=eps, norm=norm, random=True)
    #print('l2:',torch.norm(x_adv.detach().cpu()-x.detach().cpu())/np.sqrt(x.detach().cpu().size(0)))    
    #print('linf:',infnorm(x_adv.detach().cpu()-x.detach().cpu())/infnorm(x))    

    optimiser.zero_grad()
    y_pred = model(x_adv)
    loss = loss_fn(y_pred, y)
    loss.backward()
    optimiser.step()

    return loss, y_pred

In [70]:
def train_adversarial(model, device, train_loader, optimizer, epoch,adversary,k,step,eps,norm,random):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        ypred = model(data)
        
        sgd_loss = nn.CrossEntropyLoss()
        # calculate robust loss
        loss = adversarial_training(model,optimizer,sgd_loss,data,target,epoch,adversary,k,step,eps,norm,random)

## l2 ball

In [72]:
if TrainL2:
    ## initialize model
    model_l2 = Net().to(DEVICE)
    
    ## training params
    optimizer = optim.SGD(model_l2.parameters(), lr=lr)
    epochs = 10
    
    ## training params
    lr = 0.1
    optimizer = optim.SGD(model_l2.parameters(), lr=lr)
    scheduler = StepLR(optimizer, step_size=20, gamma=0.5)    
    ## train model

    for epoch in range(1, epochs + 1):
        print('Epoch:',epoch)
        # adjust learning rate for SGD

        # adversarial training
        train_adversarial(model_l2, DEVICE, train_loader, optimizer, epoch,adversary=pgd,k=10,step=0.1,eps=3.0,norm=2,random=True)
        # evaluation on natural examples
        print('================================================================')
        eval_train(model_l2, DEVICE, train_loader)
        eval_test(model_l2, DEVICE, val_loader)
        if (epoch) % 5 == 0:
            evaluate_against_adversary(model_l2, k=20, eps=3.0, step=0.5, norm=2)
        print('================================================================')
        scheduler.step()

        
    '''    
    lr = 0.1
    optimiser = optim.SGD(adv_model_l2.parameters(), lr=lr)
    epochs = 5
    ## train model
    training_history_l2 = olympic.fit(
        adv_model_l2,
        optimiser,
        nn.CrossEntropyLoss(),
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=adversarial_training,
        update_fn_kwargs={'adversary': pgd, 'k': 10, 'step': 0.01, 'eps': 3.0, 'norm': 2, 'random':True},
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5, factor=0.5, min_delta=0.005, monitor='val_accuracy')
        ]
    )
    ## verify validation accuracy
    print('final validation accuracy:')
    valacc = olympic.evaluate(adv_model_l2, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    print(valacc['val_accuracy'])
    '''
    ## save trained model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_AT2_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_l2,modelname)

Epoch: 1




Training: Average loss: 0.2274, Accuracy: 55610/60000 (93%)
Test: Average loss: 0.2118, Accuracy: 9321/10000 (93%)
Epoch: 2
Training: Average loss: 0.0746, Accuracy: 58541/60000 (98%)
Test: Average loss: 0.0691, Accuracy: 9778/10000 (98%)
Epoch: 3
Training: Average loss: 0.0443, Accuracy: 59114/60000 (99%)
Test: Average loss: 0.0459, Accuracy: 9843/10000 (98%)
Epoch: 4
Training: Average loss: 0.0347, Accuracy: 59290/60000 (99%)
Test: Average loss: 0.0392, Accuracy: 9858/10000 (99%)
Epoch: 5
Training: Average loss: 0.0279, Accuracy: 59412/60000 (99%)
Test: Average loss: 0.0361, Accuracy: 9868/10000 (99%)
Adversarial Accuracy: 2382.0/9984 (24%)
Epoch: 6
Training: Average loss: 0.0216, Accuracy: 59541/60000 (99%)
Test: Average loss: 0.0320, Accuracy: 9880/10000 (99%)
Epoch: 7
Training: Average loss: 0.0183, Accuracy: 59592/60000 (99%)
Test: Average loss: 0.0314, Accuracy: 9878/10000 (99%)
Epoch: 8
Training: Average loss: 0.0160, Accuracy: 59639/60000 (99%)
Test: Average loss: 0.0303, Accu

## linf ball

In [None]:
if TrainLInf:
    ## initialize model
    adv_model_linf = Net().to(DEVICE)
    ## train params
    lr = 0.1
    optimiser = optim.SGD(adv_model_linf.parameters(), lr=lr)
    epochs = 5
    ## train model
    training_history_linf = olympic.fit(
        adv_model_linf,
        optimiser,
        nn.CrossEntropyLoss(),
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=adversarial_training,
        update_fn_kwargs={'adversary': iterated_fgsm,'k': 10, 'step': 0.07, 'eps': 0.032, 'norm': 'inf', 'random':True},
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5, factor=0.5, min_delta=0.005, monitor='val_accuracy')
        ]
    )
    ## verify validation
    print('final validation accuracy:')
    valacc = olympic.evaluate(adv_model_linf, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    
    ## save trained model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_ATInf_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(adv_model_linf,modelname)

# TRAIN MODEL USING SAT

In [None]:
def adversarial_training_entropy(model, optimiser, loss_fn, x, y, epoch, adversary, k, step, eps, norm, gamma):
    """Performs a single update against a specified adversary"""
    model.train()
    
    # Adversial perturbation
    #alpha = 0.8
    N = 1
    loss = 0
    for l in range(N):
        x_adv = adversary(model, x, y, loss_fn, k=k, step=step, eps=eps, norm=norm, random=True, gamma=gamma)
        
        optimiser.zero_grad()
        y_pred = model(x_adv)
        loss = loss + loss_fn(y_pred,y)
        #loss = (1-alpha)*loss + alpha*loss_fn(y_pred, y)
    loss = loss/N
    loss.backward()
    optimiser.step()
    
    return loss, y_pred

In [None]:
if TrainSAT2:
    ## initialize model
    model_SAT2 = model_cnn().to(DEVICE)
    ## train params
    lr = 0.1
    optimiser = optim.SGD(model_SAT2.parameters(), lr=lr)
    epochs = 5
    ## train model
    training_history_entropySmoothing = olympic.fit(
        model_SAT2,
        optimiser,
        nn.CrossEntropyLoss(),
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=adversarial_training_entropy,
        update_fn_kwargs={'adversary': entropySmoothing, 'k': 10, 'step': 0.01, 'eps': 3.0, 'norm': 2, 'gamma':1e-5},
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5, factor=0.5, min_delta=0.005, monitor='val_accuracy')
        ]
    )
    ## verify validation
    print('final validation accuracy:')
    olympic.evaluate(model_SAT2, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    ## save model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_SAT2_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_SAT2,modelname)

In [None]:
if TrainSATInf:
    ## initialize model
    model_SATInf = model_cnn().to(DEVICE)
    ## train params
    lr = 0.1
    optimiser = optim.SGD(model_SATInf.parameters(), lr=lr)
    epochs = 5
    ## train model
    training_history_entropySmoothing = olympic.fit(
        model_SATInf,
        optimiser,
        nn.CrossEntropyLoss(),
        dataloader=train_loader,
        epochs=epochs,
        metrics=['accuracy'],
        prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)),
        update_fn=adversarial_training_entropy,
        update_fn_kwargs={'adversary': entropySmoothing, 'k': 10, 'step': 0.01, 'eps': 0.3, 'norm': 'inf', 'gamma':1e-5},
        callbacks=[
            olympic.callbacks.Evaluate(val_loader),
            olympic.callbacks.ReduceLROnPlateau(patience=5, factor=0.5, min_delta=0.005, monitor='val_accuracy')
        ]
    )
    ## verify validation
    print('final validation accuracy:')
    olympic.evaluate(model_SATInf, val_loader, metrics=['accuracy'],
                     prepare_batch = lambda batch: (batch[0].to(DEVICE), batch[1].to(DEVICE)))
    ## save model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_SATInf_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_SATInf,modelname)

# TRAIN MODEL USING TRADES

In [None]:
args = {}
args['test_batch_size'] = 128
args['no_cuda'] = False
args['epsilon'] = 0.3
args['num_steps'] = 5
args['step_size'] = 0.01
args['random'] =True,
args['white_box_attack']=True
args['log_interval'] = 100
args['beta'] = 1.0
args['log_interval'] = 1

In [None]:
if TrainTRADES:
    args['attack'] = 'l_2' #or 'l_inf'
    ## initialize model
    model_TRADES = model_cnn().to(DEVICE)
    ## training params
    lr = 0.1
    optimizer = optim.SGD(model_TRADES.parameters(), lr=lr)
    epochs = 10
    ## train model

    for epoch in range(1, epochs + 1):
        # adjust learning rate for SGD
        #adjust_learning_rate(optimizer, epoch)

        # adversarial training
        train(args, model_TRADES, DEVICE, train_loader, optimizer, epoch)

        # evaluation on natural examples
        print('================================================================')
        eval_train(model_TRADES, DEVICE, train_loader)
        eval_test(model_TRADES, DEVICE, val_loader)
        print('================================================================')

    ## save model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_TRADES_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_TRADES,modelname)

In [None]:
if TrainTRADESInf:
    args['attack'] = 'l_inf' 
    ## initialize model
    model_TRADESInf = Net().to(DEVICE)
    ## training params
    lr = 0.1
    optimizer = optim.SGD(model_TRADESInf.parameters(), lr=lr)
    epochs = 10
    ## train model

    for epoch in range(1, epochs + 1):
        # adjust learning rate for SGD
        #adjust_learning_rate(optimizer, epoch)

        # adversarial training
        train(args, model_TRADESInf, DEVICE, train_loader, optimizer, epoch)

        # evaluation on natural examples
        print('================================================================')
        eval_train(model_TRADESInf, DEVICE, train_loader)
        eval_test(model_TRADESInf, DEVICE, val_loader)
        print('================================================================')

    ## save model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_TRADESInf_ep'+str(epochs)+'_lr'+str(lr)+'.pt'
    torch.save(model_TRADESInf,modelname)

# TRAIN MODEL USING MART

In [None]:
#TBD

# TRAIN MODEL USING MMA

In [None]:
#TBD

# LOAD ALL PRE-TRAINED MODELS

In [None]:
TrainSGD = TrainSGD*True
TrainESGD = TrainESGD*True
TrainL2 = TrainL2*True
TrainLInf = TrainLInf*True
TrainSAT2 = TrainSAT2*True
TrainSATInf = TrainSATInf*True
TrainTRADES = TrainTRADES*True
TrainTRADESInf = TrainTRADESInf*True
TrainMART = TrainMART*False
TrainMMA = TrainMMA*False
TrainMMAInf = TrainMMAInf*False

In [None]:
TrainTRADESInf = True

In [None]:
# Load all the pre-trained models
if not TrainSGD:
    model_SGD = torch.load('../trainedmodels/'+dataset+'/SGD_ep10_lr0.1.pt').to(DEVICE)
if not TrainESGD:    
    model_ESGD = torch.load('../trainedmodels/'+dataset+'/ESGD_ep5_lr0.1.pt').to(DEVICE)
if not TrainLInf:
    adv_model_linf = torch.load('../trainedmodels/'+dataset+'/AT2_ep2_lr0.1.pt').to(DEVICE)
if not TrainL2:
    adv_model_l2 = torch.load('../trainedmodels/'+dataset+'/ATInf_ep2_lr0.1.pt').to(DEVICE)
if not TrainSAT2:
    model_SAT2 = torch.load('../trainedmodels/'+dataset+'/SAT2_ep2_lr0.1.pt').to(DEVICE)
if not TrainSATInf:
    model_SATInf = torch.load('../trainedmodels/'+dataset+'/SATInf_ep2_lr0.1.pt').to(DEVICE)
if not TrainTRADES:
    model_TRADES = torch.load('../trainedmodels/'+dataset+'/TRADES_ep2_lr0.1.pt').to(DEVICE)
if not TrainTRADESInf:
    model_TRADES = torch.load('../trainedmodels/'+dataset+'/TRADESInf_ep2_lr0.1.pt').to(DEVICE)    
if not TrainMART:
    if dataset=='CIFAR10':
        model_MART = Net()
        if NetName == 'ResNet18':
            dics = torch.load('../trainedmodels/'+dataset+'/ResNet18_model_MART.pt')
        if NetName == 'WideResNet':
            dics = torch.load('../trainedmodels/'+dataset+'/rob_cifar_mart.pt')            
        model_MART.load_state_dict(dics)
        model_MART = model_MART.to(DEVICE)
if not TrainMMA:
    model_MMA = torch.load('../trainedmodels/'+dataset+'/'+dataset.lower()+'-L2-MMA-4.0-sd0/model_best.pt')
    model_MMAInf = torch.load('../trainedmodels/'+dataset+'/'+dataset.lower()+'-Linf-MMA-0.45-sd0/model_best.pt')

## Quantifying adversarial accuracy

In [None]:
def infnorm(x):
    infn = torch.max(torch.abs(x.detach().cpu()))
    return infn

## Evaluate robust models

In [None]:
loadResults = False

In [None]:
t1 = time.time()
if not loadResults:
    pgd_attack_range = [3.0] #np.arange(0, 6.1, 1./3)
    acc_SGD = []
    acc_ESGD = []
    acc_l2 = []
    acc_SAT2 = []
    acc_TRADES = []
    acc_MMA = []
    for eps in pgd_attack_range:
        print('eps:',eps)
        print('evaluating SGD network...')
        acc_SGD.append(evaluate_against_adversary(model_SGD, k=20, eps=eps, step=0.5, norm=2))
        print('evaluating ESGD network...')
        acc_ESGD.append(evaluate_against_adversary(model_ESGD, k=20, eps=eps, step=0.5, norm=2))
        print('evaluating SAT2 network...')
        acc_SAT2.append(evaluate_against_adversary(model_SAT2, k=30, eps=eps, step=0.3, norm=2))
        print('evaluating l2 network...')
        acc_l2.append(evaluate_against_adversary(adv_model_l2, k=30, eps=eps, step=1.5, norm=2))
        print('evaluating TRADES network...')
        acc_TRADES.append(evaluate_against_adversary(model_TRADES, k=30, eps=eps, step=1.5, norm=2))  
        print('evaluating MART network...')
        acc_MMA.append(evaluate_against_adversary(model_MMA, k=30, eps=eps, step=1.5, norm=2))          
print("time elapsed:",time.time()-t1)    

In [None]:
if not loadResults:    
    accData = [acc_SGD,acc_ESGD,acc_l2,acc_SAT2,acc_TRADES,acc_MMA]
    np.save('../results/accData_l2.npy',accData)

In [None]:
if loadResults:
    pgd_attack_range = [3.0] #np.arange(0.0, 6.1, 1./3)
    accData2 = np.load('../results/accData_l2.npy')
    [acc_SGD,acc_ESGD,acc_l2,acc_SAT2,acc_TRADES,acc_MMA] = accData2

In [None]:
with plt.style.context('ggplot'):
    COLOURS = plt.rcParams['axes.prop_cycle'].by_key()['color']
    MARKERS = plt.rcParams['axes.prop_cycle'].by_key()['marker']

In [None]:
if len(pgd_attack_range) > 1:
    with plt.style.context('ggplot'):
        plt.rcParams["font.weight"] = "bold"
        plt.rcParams["axes.labelweight"] = "bold"
        plt.rcParams["axes.labelcolor"] = "black"
        fig, axes = plt.subplots(1, 1, figsize=(15,10))
        plt.figure(figsize=(14,7))
        axes.set_title('$L^2$-bounded adversary')
        axes.plot(pgd_attack_range, acc_SGD, label='SGD')
        axes.plot(pgd_attack_range, acc_ESGD, label='Entropy-SGD')
        axes.plot(pgd_attack_range, acc_SAT2, label='Data-Entropy-SGD ($L_2$)')
        axes.plot(pgd_attack_range, acc_l2, label='$L2$ training')
        axes.plot(pgd_attack_range, acc_TRADES, label='TRADES training')
        axes.plot(pgd_attack_range, acc_MMA, label='MMA training')

        axes.vlines([3], 0, 1, colors=COLOURS[1], linestyle='--')
        axes.set_ylabel('Accuracy')
        axes.set_xlabel('Epsilon')
        axes.set_ylim((0,1))
        axes.legend()
else:
    print(acc_Data)

In [None]:
loadResultsFGSM = False

In [None]:
t1 = time.time()
if not loadResultsFGSM:
    fgsm_attack_range = [8/255] #np.arange(0.0, 0.52, 0.025)
    fgsm_acc_linf = []
    fgsm_acc_SGD = []
    fgsm_acc_ESGD = []
    fgsm_acc_SATInf = []
    fgsm_acc_TRADESInf = []
    fgsm_acc_MART = []
    fgsm_acc_MMAInf = []
    for eps in fgsm_attack_range:
        print('eps:',eps)
        print('evaluating SGD network...')
        fgsm_acc_SGD.append(evaluate_against_adversary(model_SGD, k=20, eps=eps, step=0.1, norm='inf'))
        print('evaluating ESGD network...')
        fgsm_acc_ESGD.append(evaluate_against_adversary(model_ESGD, k=20, eps=eps, step=0.1, norm='inf'))
        print('evaluating SAT2 network...')
        fgsm_acc_SATInf.append(evaluate_against_adversary(model_SATInf, k=10, eps=eps, step=0.1, norm='inf'))        
        print('evaluating linf network...')
        fgsm_acc_linf.append(evaluate_against_adversary(adv_model_linf, k=50, eps=eps, step=0.02, norm='inf'))
        print('evaluating TRADES network...')
        fgsm_acc_TRADESInf.append(evaluate_against_adversary(model_TRADESInf, k=50, eps=eps, step=0.02, norm='inf'))
        print('evaluating MART network...')
        fgsm_acc_MART.append(evaluate_against_adversary(model_MART, k=50, eps=eps, step=0.02, norm='inf'))
        print('evaluating MMA network...')
        fgsm_acc_MMAInf.append(evaluate_against_adversary(model_MMAInf, k=50, eps=eps, step=0.02, norm='inf'))
         
print("time elapsed:",time.time()-t1)    

In [None]:
if not loadResultsFGSM:    
    fgsmaccData = [fgsm_acc_SGD,fgsm_acc_ESGD,fgsm_acc_linf,fgsm_acc_SATInf, fgsm_acc_TRADESInf, fgsm_acc_MART, fgsm_acc_MMAInf]
    np.save('../results/fgsmaccData.npy',fgsmaccData)

In [None]:
if loadResultsFGSM:
    fgsm_attack_range = [8/255] #np.arange(0.0, 0.52, 0.025)
    accData2 = np.load('../results/fgsmaccData.npy')
    [fgsm_acc_SGD,fgsm_acc_ESGD,fgsm_acc_linf,fgsm_acc_SATInf, fgsm_acc_TRADESInf, fgsm_acc_MART, fgsm_acc_MMAInf] = accData2

In [None]:
if len(fgsm_attack_range) > 1:
    with plt.style.context('ggplot'):
        plt.rcParams["font.weight"] = "bold"
        plt.rcParams["axes.labelweight"] = "bold"
        plt.rcParams["axes.labelcolor"] = "black"
        fig, axes = plt.subplots(1, 1, figsize=(15,10))
        plt.figure(figsize=(14,7))
        axes.set_title('$L^\infty$-bounded adversary')
        axes.plot(fgsm_attack_range, fgsm_acc_SGD, label='SGD')
        axes.plot(fgsm_attack_range, fgsm_acc_ESGD, label='Entropy-SGD')
        axes.plot(fgsm_attack_range, fgsm_acc_SATInf, label='Data-Entropy-SGD ($L_\infty$)')    
        axes.plot(fgsm_attack_range, fgsm_acc_linf, label='$L{\infty}$ training')
        axes.plot(fgsm_attack_range, fgsm_acc_TRADES, label='TRADES training')
        axes.vlines([0.25], 0, 1, colors=COLOURS[1], linestyle='--')
        axes.set_ylabel('Accuracy')
        axes.set_xlabel('Epsilon')
        axes.set_ylim((0,1))
        axes.legend()
else:
    print(acc_Data)