# In [None]:
"""
Adversarially train LeNet-5
"""

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn.functional as F

from adversarialbox.attacks import FGSMAttack, LinfPGDAttack, CertAttack
from adversarialbox.train import adv_train, FGSM_train_rnd
from adversarialbox.utils import to_var, pred_batch, test
from IPython.core.debugger import Tracer

from models import LeNet5, LeNet5ELU


# Hyper-parameters for the run.
# NOTE(review): in this script only 'batch_size', 'test_batch_size' and
# 'num_epochs' are read by active code; 'delay', 'learning_rate' and
# 'weight_decay' appear only in commented-out lines (the Adam optimizer
# below hard-codes its own learning rate).
param = dict(
    batch_size=64,
    test_batch_size=100,
    num_epochs=150,  # an earlier configuration used 16
    delay=10,
    learning_rate=1e-3,
    weight_decay=5e-4,
)


# Data loaders: MNIST train/test splits stored under ../data/ (downloaded if
# missing), each using transforms.ToTensor() as its only transform.
train_dataset = datasets.MNIST(
    root='../data/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
loader_train = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=param['batch_size'],
    shuffle=True,
)

# NOTE(review): shuffling the test split is harmless but not required for
# evaluation.
test_dataset = datasets.MNIST(
    root='../data/',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
loader_test = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=param['test_batch_size'],
    shuffle=True,
)


# Set up the model: LeNet5ELU from the local `models` module (the plain
# LeNet5 is kept below as a commented-out alternative).
# net = LeNet5()
net = LeNet5ELU()

# Move the network to the GPU when CUDA is available, then switch it to
# training mode before the training loop runs.
if torch.cuda.is_available():
    print('CUDA enabled.')
    net.cuda()
net.train()

# Adversary that crafts the perturbed training inputs (passed to adv_train in
# the training loop). Previously tried alternatives:
#   adversary = FGSMAttack(epsilon=0.3)
#   adversary = LinfPGDAttack()
# NOTE(review): the meaning of CertAttack's `a` and `gamma` is defined in
# adversarialbox.attacks -- confirm there before tuning.
adversary = CertAttack(gamma=0.04, a=1000.0)

from torch.optim.lr_scheduler import StepLR

# Training objective: cross-entropy between the network outputs and the
# integer class labels.
criterion = nn.CrossEntropyLoss()

# Optimizer: Adam with a hard-coded learning rate of 0.01.
# NOTE(review): this does not read param['learning_rate'] (1e-3) or
# param['weight_decay']; those were only used by the RMSprop alternative:
#optimizer = torch.optim.RMSprop(net.parameters(), lr=param['learning_rate'],
#    weight_decay=param['weight_decay'])
optimizer = torch.optim.Adam(params=net.parameters(), lr=0.01)

# Multiply the learning rate by 0.1 every 100 scheduler steps.
# NOTE(review): scheduler.step() is commented out in the training loop, so as
# written this scheduler never changes the learning rate.
scheduler = StepLR(optimizer, gamma=0.1, step_size=100)

# Adversarial training loop.
# NOTE(review): `iters > numiters` is checked only after an epoch finishes,
# so training can run up to one extra epoch past `numiters` optimizer steps
# (or stop earlier at param['num_epochs']).
numiters = 10000
iters = 0
for epoch in range(param['num_epochs']):

    print('Starting epoch %d / %d' % (epoch + 1, param['num_epochs']))
    #scheduler.step()
    for t, (x, y) in enumerate(loader_train):
        iters = iters + 1
        x_var, y_var = to_var(x), to_var(y.long())
        # Loss on the clean batch. NOTE(review): because the adversarial
        # branch below is always taken, this value is overwritten before it is
        # used. The forward pass is kept so that any train-mode side effects
        # in the model (e.g. normalization statistics, if present) are
        # unchanged.
        loss = criterion(net(x_var), y_var)

        # Adversarial training starts from the first epoch. A delayed start
        # would use: if epoch+1 > param['delay']:
        if epoch+1 > 0:
            # True labels are passed to the attack. To avoid label leaking,
            # predicted labels could be used instead:
            #y_pred = pred_batch(x, net)
            x_adv = adv_train(x, y.cpu().long(), net, criterion, adversary)
            x_adv_var = to_var(x_adv)
            loss_adv = criterion(net(x_adv_var), y_var)
            # Train on the adversarial loss only (a mixed objective would be
            # loss = (loss + loss_adv) / 2).
            loss = loss_adv

        if (t + 1) % 100 == 0:
            # .item() returns the Python number held by the 0-dim loss
            # tensor; the old `loss.data[0]` indexing raises on PyTorch >= 0.4.
            print('t = %d, loss = %.8f' % (t + 1, loss.item()))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if iters > numiters:
        break


# Evaluate the trained network on the MNIST test loader, then write its
# state_dict (not the module object) to disk.
test(net, loader_test)

trained_weights = net.state_dict()
torch.save(trained_weights, 'models/adv_trained_lenet5.pkl')
