Main script for training a classifier for MNIST using l_inf ATENT [Table 2 of paper].

This notebook contains the printed results from evaluating the pretrained model.

# Adversarial Training via ENTropic regularization (l_inf ATENT)

### SoTA - collected from various papers

vanilla SGD: 
MNIST - 99%+ (most cnns), CIFAR10 - 93%+ (resnet18), 96%+ (wideresnet) 

MNIST:

adversarial attacks: 
l-inf @ eps = 80/255 @20 steps: TRADES - 96.07% - (4 layer cnn), MART 96.4%, MMA 95.5%, PGD - 96.01% - (4 layer cnn)

Reference repos for baselines: TRADES : https://github.com/yaodongyu/TRADES (MNIST: small cnn, CIFAR10: WideResNet34) MMA : https://github.com/BorealisAI/mma_training (MNIST: lenet5, CIFAR10: WideResNet28) MART : https://github.com/YisenWang/MART (CIFAR10: ResNet18 and WideResNet34) PGD: (CIFAR10: ResNet50) https://github.com/MadryLab/robustness


### IMPORT LIBRARIES

In [1]:
import sys,os
sys.path.append('../adversarial/')
sys.path.append('../architectures/')
import random

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.autograd import Variable
import torch.backends.cudnn as cudnn

from torchvision import transforms, datasets

import numpy as np
import matplotlib.pyplot as plt

### IMPORT UTILITIES

In [2]:
from functional import entropySmoothing
from utils import eval_train, eval_test, infnorm, train_adversarial

### SET TRAINING PARAMETERS

In [17]:
# Every hyper-parameter of the notebook lives in this one dictionary,
# grouped by the stage that consumes it.
args = {
    # data loading
    'seed': 1,
    'test_batch_size': 128,
    'train_batch_size': 128,
    'no_cuda': False,

    # params for SGLD (inner loop)
    'attack': 'l_inf',
    'norm': 'inf',
    'epsilon': 0.3,
    'num_steps': 40,
    'step_size': 0.01,
    'random': True,

    # params for SGD (outer loop)
    'lr': 0.01,
    'momentum': 0.9,
    'weight_decay': 1e-4,
    'epochs': 30,
    'save_freq': 1,

    # load model
    'pretrained': True,
}

# Extra keyword arguments forwarded to the DataLoader constructors.
kwargs = {'num_workers': 4, 'pin_memory': True}

# Use the GPU only when it is both allowed by the config and actually present.
_cuda_allowed = not args['no_cuda']
DEVICE = 'cuda' if (_cuda_allowed and torch.cuda.is_available()) else 'cpu'

### LOAD DATA

In [10]:
dataset = 'MNIST' # [MNIST, CIFAR10]

# The only preprocessing step is the conversion of each image to a tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Options shared by the training and the held-out split.
_mnist_opts = {'transform': transform, 'download': True}
train = datasets.MNIST('../../data/', train=True, **_mnist_opts)
val = datasets.MNIST('../../data/', train=False, **_mnist_opts)

# Training batches are shuffled; the evaluation loader keeps a fixed order.
train_loader = DataLoader(
    train,
    batch_size=args['train_batch_size'],
    shuffle=True,
    **kwargs
)
val_loader = DataLoader(
    val,
    batch_size=args['test_batch_size'],
    shuffle=False,
    **kwargs
)

### INITIALIZE NETWORK

In [11]:
# Pick the network class for the selected dataset.  `Net` is instantiated
# further down and `NetName` is used when composing the saved-model name.
if dataset == 'MNIST':
    from small_cnn import SmallCNN
    Net = SmallCNN
    NetName = 'SmallCNN'
else:
    # Without this branch `Net` / `NetName` would stay undefined and the
    # notebook would fail much later with an unrelated-looking NameError.
    raise NotImplementedError(
        'No architecture is configured for dataset {!r}; '
        'only MNIST is handled here.'.format(dataset))

### SET RANDOM SEED 

In [12]:
seed = args['seed']
torch.set_num_threads(2)

# Seed the Python and NumPy generators with the configured seed.
random.seed(seed)
np.random.seed(seed)

if DEVICE == 'cuda':
    # NOTE(review): a negative index is documented as turning set_device into
    # a no-op, so this presumably leaves the current GPU unchanged -- confirm.
    torch.cuda.set_device(-1)
    torch.cuda.manual_seed(seed)
    # NOTE(review): cudnn.benchmark enables cuDNN auto-tuning, which may make
    # runs not bit-for-bit reproducible despite the seeds -- confirm.
    cudnn.benchmark = True

# Seed PyTorch last so the cell still evaluates to the returned generator.
torch.manual_seed(seed)

<torch._C.Generator at 0x7fe04c046f10>

### WHITEBOX L-INF ATTACK

In [13]:
def _project_l2_ball(X, X_adv, epsilon):
    """Scale X_adv - X per example so that its L2 norm is at most epsilon."""
    delta = X_adv - X
    # clamp_min guards the division when an example is not perturbed at all
    norms = delta.view(delta.shape[0], -1).norm(2, dim=-1).clamp_min(1e-12)
    factor = (epsilon / norms).clamp(max=1.0)
    return X + delta * factor.view(-1, *([1] * (delta.dim() - 1)))


def _pgd_whitebox(model,
                  X,
                  y,
                  epsilon=args['epsilon'],
                  norm=args['norm'],
                  num_steps=args['num_steps'],
                  step_size=args['step_size']):
    """Run a white-box PGD attack on one batch and count the errors.

    Args:
        model: classifier mapping a batch of inputs to class scores.
        X: batch of clean inputs; the first dimension indexes examples.
        y: tensor of integer class labels for X.
        epsilon: radius of the perturbation ball around X.
        norm: 'inf' for a signed-gradient step with an l_inf projection,
            or 2 for a normalised-gradient step with an L2 projection.
        num_steps: number of attack iterations.
        step_size: step length of each attack iteration.

    Returns:
        (err, err_pgd): the number of misclassified examples on the clean
        batch and on the attacked batch, each as a 0-dim float tensor.

    Raises:
        ValueError: if `norm` is neither 'inf' nor 2.  Previously such a value
            silently skipped the attack, so the "robust" error was really the
            clean error.

    The random start is controlled by the module-level args['random'].
    """
    if norm not in ('inf', 2):
        raise ValueError("norm must be 'inf' or 2, got {!r}".format(norm))

    X_nat = X.detach()

    # Error on the clean batch; no graph is needed for a pure prediction.
    with torch.no_grad():
        err = (model(X_nat).max(1)[1] != y).float().sum()

    X_pgd = X_nat.clone()
    if args['random']:
        # Sampled on the CPU (as before) and moved to wherever the input lives.
        random_noise = torch.FloatTensor(*X_nat.shape).uniform_(-epsilon, epsilon).to(X_nat.device)
        X_pgd = X_pgd + random_noise

    for _ in range(num_steps):
        X_pgd = X_pgd.detach().requires_grad_(True)
        with torch.enable_grad():
            loss = nn.CrossEntropyLoss()(model(X_pgd), y)
        # Differentiate w.r.t. the input only, so that no gradients are
        # accumulated in the model parameters during evaluation.
        grad, = torch.autograd.grad(loss, X_pgd)
        X_pgd = X_pgd.detach()

        if norm == 'inf':
            X_pgd = X_pgd + step_size * grad.sign()
            # projection onto the l_inf ball of radius epsilon around X
            X_pgd = X_nat + torch.clamp(X_pgd - X_nat, -epsilon, epsilon)
        else:
            grad_norm = grad.view(grad.shape[0], -1).norm(2, dim=-1).clamp_min(1e-12)
            X_pgd = X_pgd + step_size * grad / grad_norm.view(-1, *([1] * (grad.dim() - 1)))
            X_pgd = _project_l2_ball(X_nat, X_pgd, epsilon)

        # keep the adversarial input inside the [0, 1] range
        X_pgd = torch.clamp(X_pgd, 0, 1.0)

    with torch.no_grad():
        err_pgd = (model(X_pgd).max(1)[1] != y).float().sum()
    return err, err_pgd

def eval_adv_test_whitebox(model, device, test_loader):
    """Evaluate a model on clean data and under the white-box PGD attack.

    Args:
        model: classifier to evaluate; it is switched to eval mode.
        device: device every batch is moved to before the attack.
        test_loader: iterable yielding (data, target) batches.

    Returns:
        The robust accuracy in percent, as a Python float.  The natural and
        the robust accuracy are also printed.

    Raises:
        ValueError: if `test_loader` yields no examples.
    """
    model.eval()
    robust_err_total = 0.0
    natural_err_total = 0.0
    num_examples = 0

    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        # pgd attack
        err_natural, err_robust = _pgd_whitebox(model, data, target)
        robust_err_total += float(err_robust)
        natural_err_total += float(err_natural)
        num_examples += target.size(0)

    if num_examples == 0:
        raise ValueError('test_loader yielded no examples to evaluate')

    # Normalise by the number of examples actually seen rather than assuming
    # a test set of exactly 10,000 images.
    nat = 100 - 100 * natural_err_total / num_examples
    rob = 100 - 100 * robust_err_total / num_examples
    print('natural_acc_total: ', nat)
    print('robust_acc_total: ', rob)
    return rob

### L-INF ATENT MODULE

In [14]:
def adversarial_training_entropy(model, optimiser, loss_fn, x, y, epoch, adversary, L, step, eps, norm):
    """Perform a single parameter update against the specified adversary.

    The adversary is called L times.  The first call starts from a random
    perturbation of `x` (xp=None, random=True); every later call is seeded
    with the adversarial batch returned by the previous call.  The losses on
    the L adversarial batches are combined by an exponential moving average
    with weight alpha = 0.9 on the newest term, and one optimiser step is
    taken on the combined loss.

    Args:
        model: network being trained; it is switched to train mode.
        optimiser: optimiser holding the parameters of `model`.
        loss_fn: callable computing the loss from (predictions, y).
        x: batch of clean inputs.
        y: batch of labels.
        epoch: current epoch; accepted for interface compatibility, unused.
        adversary: callable returning a pair whose first element is the
            adversarial batch; the second element is ignored.
        L: number of adversary calls; must be at least 1.
        step: step size forwarded to the adversary.
        eps: perturbation budget forwarded to the adversary.
        norm: norm identifier forwarded to the adversary.

    Returns:
        (loss, correct): the combined loss tensor and the number of examples
        of the last adversarial batch that were classified correctly.

    Raises:
        ValueError: if L is smaller than 1.
    """
    if L < 1:
        raise ValueError('L must be at least 1, got {}'.format(L))

    model.train()

    alpha = 0.9
    gamma = 0
    projector = True  # (project after each inner epoch)
    loss = 0
    x_adv = None  # no previous iterate exists before the first adversary call

    for l in range(L):
        # Only the very first call uses a random start; afterwards the
        # previous adversarial batch is passed back in through `xp`.
        x_adv, _ = adversary(model, x, y, loss_fn, xp=x_adv, step=step, eps=eps, norm=norm,
                             random=(l == 0), ep=1e-4, projector=projector, gamma=gamma)

        optimiser.zero_grad()
        y_pred = model(x_adv)
        loss = (1 - alpha) * loss + alpha * loss_fn(y_pred, y)

    # Only the last iterate's accuracy is reported, so it is computed once
    # here instead of forcing an .item() call in every inner iteration.
    pred = y_pred.max(1, keepdim=True)[1]
    correct = pred.eq(y.view_as(pred)).sum().item()

    loss.backward()
    optimiser.step()
    return loss, correct

### INITIALIZE NET OR LOAD FROM CHECKPOINT

In [22]:
# Build the network on the selected device and, when requested, restore the
# pretrained weights and report their clean / robust accuracy.
model_ATENT = Net().to(DEVICE)
#model_ATENT = nn.DataParallel(model_ATENT)
#load pretrained state dict here
if args['pretrained']:
    pathstr = '../trainedmodels/MNIST/BEST_model-nn-epoch23-robacc96.pt' 
    # map_location remaps the stored tensors onto DEVICE, so a checkpoint
    # written on a GPU can still be loaded when DEVICE is 'cpu'.
    model_ATENT.load_state_dict(torch.load(pathstr, map_location=DEVICE))
    rob = eval_adv_test_whitebox(model_ATENT, DEVICE, val_loader)   

natural_acc_total:  99.45
robust_acc_total:  96.4


In [21]:
# Train from scratch when no pretrained checkpoint was requested.
if not args['pretrained']:
    ## training params
    epochs = args['epochs']  
    lr_init = args['lr']
    # NOTE(review): args['weight_decay'] is defined but was never passed to
    # the optimiser; that is left unchanged so training behaves as before.
    optimizer = optim.SGD(model_ATENT.parameters(), lr=lr_init, momentum=args['momentum'])

    # Checkpoints go next to the pretrained models; create the folder up front
    # so that the first torch.save cannot fail on a missing directory.
    model_dir = os.path.join('..', 'trainedmodels', dataset)
    os.makedirs(model_dir, exist_ok=True)

    ## train model
    # Epochs are counted from 1.  The previous start value of 41 exceeded
    # args['epochs'] (30), so the loop body never ran.
    for epoch in range(1, epochs+1):
        print('Epoch:',epoch)

        # adversarial training
        train_adversarial(adversarial_training_entropy,model_ATENT, DEVICE, train_loader, optimizer, epoch,adversary=entropySmoothing,L=2*args['num_steps'],step=args['step_size'],eps=args['epsilon'],norm=args['norm'])

        # evaluation on natural and adversarial examples
        print('================================================================')
        eval_train(model_ATENT, DEVICE, train_loader)
        rob = eval_adv_test_whitebox(model_ATENT, DEVICE, val_loader)            
        print('================================================================')

        # save checkpoint every args['save_freq'] epochs, starting with the first
        if (epoch-1) % args['save_freq'] == 0:
            torch.save(model_ATENT.state_dict(),
                       os.path.join(model_dir, 'model-nn-epoch{}-robacc{}.pt'.format(epoch,int(np.round(rob)))))

    ## save model
    modelname = '../trainedmodels/'+dataset+'/'+NetName+'_ATENT_'+args['attack']+'_ep'+str(epochs)+'_lr'+str(lr_init)+'.pt'
    #torch.save(model_ATENT,modelname)