In [1]:
# linear.svg
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import LightSource

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

# Configuration for the data pipeline.
DATA_DIR = "../data"   # relative to the notebook's working directory
BATCH_SIZE = 100

# MNIST train/test splits as tensors (ToTensor()); downloaded on first use.
mnist_train = datasets.MNIST(DATA_DIR, train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST(DATA_DIR, train=False, download=True, transform=transforms.ToTensor())

# Shuffle only the training set; keep the test order fixed so evaluation is repeatable.
train_loader = DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=BATCH_SIZE, shuffle=False)

# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Seed PyTorch's global RNG so the weight initialisations that follow are reproducible.
torch.manual_seed(0)


class Flatten(nn.Module):
    """Flatten every dimension except the leading (batch) one.

    Maps a tensor of shape ``(N, d1, d2, ...)`` to ``(N, d1 * d2 * ...)`` so
    image or convolutional feature maps can be fed to ``nn.Linear``.
    """

    def forward(self, x):
        # reshape (rather than view) also accepts non-contiguous inputs, e.g.
        # after a transpose/permute; it only copies when a view is impossible.
        return x.reshape(x.shape[0], -1)

# NOTE(review): despite its name this is a single linear layer (no convolutions),
# and no cell shown here uses it. Flatten() is required so that (N, 1, 28, 28)
# MNIST batches can reach the 784-input Linear layer; without it the forward
# pass fails with a shape mismatch.
model_cnn = nn.Sequential(
    Flatten(),
    nn.Linear(784, 10)).to(device)

# Adversarial attacks: single-step FGSM and multi-step $\ell_\infty$ PGD

In [4]:
def fgsm(model, X, y, epsilon=0.1):
    """Build a single-step (FGSM) L-infinity perturbation for the batch X.

    The perturbation is ``epsilon * sign(grad_X loss)``, where the loss is the
    cross-entropy of ``model(X)`` against the labels ``y``. The returned tensor
    is the perturbation itself (add it to ``X`` to get the adversarial
    examples) and is detached from autograd.

    Note: the backward pass also accumulates gradients into the model's
    parameters; callers that train should zero them before their own step.
    """
    perturbation = torch.zeros_like(X, requires_grad=True)
    criterion = nn.CrossEntropyLoss()
    logits = model(X + perturbation)
    criterion(logits, y).backward()
    step_direction = perturbation.grad.detach().sign()
    return epsilon * step_direction

def pgd_linf(model, X, y, epsilon=0.1, alpha=0.01, num_iter=5, randomize=False):
    """Construct an L-infinity PGD perturbation for the examples X.

    Runs ``num_iter`` steps of signed-gradient ascent on the cross-entropy
    loss and projects the perturbation back onto the box [-epsilon, epsilon]
    after every step.

    Args:
        model: classifier whose output on ``X + delta`` is passed to
            ``nn.CrossEntropyLoss`` together with ``y``.
        X: clean input batch.
        y: target class indices for ``X``.
        epsilon: L-infinity radius of the allowed perturbation.
        alpha: size of each signed-gradient step.
        num_iter: number of ascent steps. Starting from zero, each coordinate
            moves by at most ``alpha`` per step, so with the defaults
            (5 * 0.01 = 0.05 < 0.1) the epsilon bound is never reached.
        randomize: if True, start from a uniform sample in
            [-epsilon, epsilon] instead of from zero.

    Returns:
        The perturbation ``delta`` (same shape as ``X``), detached from
        autograd. Add it to ``X`` to obtain the adversarial examples.

    Note:
        ``loss.backward()`` also accumulates gradients into the model's
        parameters; callers that train must zero them before their own step.
    """
    if randomize:
        # Uniform start in [-epsilon, epsilon]; draws the same random numbers as
        # torch.rand_like(X), so seeded runs are unaffected.
        delta = (torch.rand_like(X) * 2 * epsilon - epsilon).requires_grad_(True)
    else:
        delta = torch.zeros_like(X, requires_grad=True)

    for _ in range(num_iter):
        loss = nn.CrossEntropyLoss()(model(X + delta), y)
        loss.backward()
        # Ascent step + projection, applied in place on the leaf tensor
        # without being recorded by autograd.
        with torch.no_grad():
            delta += alpha * delta.grad.sign()
            delta.clamp_(-epsilon, epsilon)
        delta.grad.zero_()
    return delta.detach()

In [5]:
def epoch(loader, model, opt=None):
    """Standard (clean) training/evaluation epoch over the dataset.

    Args:
        loader: DataLoader yielding ``(X, y)`` batches.
        model: classifier returning logits for ``X``.
        opt: optimizer. If given, the model is trained on every batch;
            if None, the pass only evaluates.

    Returns:
        ``(error_rate, mean_loss)`` averaged over the examples actually
        yielded by ``loader``. Counting the examples (instead of using
        ``len(loader.dataset)``) keeps the averages correct when the loader
        uses a sampler such as ``SubsetRandomSampler`` or ``drop_last=True``.

    Raises:
        ValueError: if ``loader`` yields no examples.
    """
    total_loss, total_err, n_seen = 0., 0., 0
    # No autograd graph is needed when backward() will not be called.
    with torch.set_grad_enabled(opt is not None):
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            yp = model(X)
            loss = nn.CrossEntropyLoss()(yp, y)
            if opt:
                opt.zero_grad()
                loss.backward()
                opt.step()

            total_err += (yp.argmax(dim=1) != y).sum().item()
            total_loss += loss.item() * X.shape[0]
            n_seen += X.shape[0]
    if n_seen == 0:
        raise ValueError("loader yielded no examples")
    return total_err / n_seen, total_loss / n_seen


In [10]:

def epoch_adversarial(loader, model, attack=pgd_linf, opt=None, **kwargs):
    """Adversarial training/evaluation epoch over the dataset.

    For every batch, ``attack(model, X, y, **kwargs)`` produces a perturbation
    ``delta``; the model is evaluated (and, if ``opt`` is given, trained) on
    ``X + delta``.

    Args:
        loader: DataLoader yielding ``(X, y)`` batches.
        model: classifier returning logits.
        attack: callable ``attack(model, X, y, **kwargs) -> delta``, e.g.
            ``pgd_linf`` (default) or ``fgsm``.
        opt: optimizer; if given, the model is trained on the perturbed
            batches, otherwise the pass only evaluates.
        **kwargs: forwarded to ``attack`` (e.g. ``epsilon``, ``num_iter``).

    Returns:
        ``(error_rate, mean_loss)`` on the perturbed inputs, averaged over
        the examples actually yielded by ``loader``.

    Raises:
        ValueError: if ``loader`` yields no examples.
    """
    # Gradients stay enabled even when evaluating: the attack differentiates
    # the loss with respect to the input.
    total_loss, total_err, n_seen = 0., 0., 0
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        delta = attack(model, X, y, **kwargs)
        yp = model(X + delta)
        loss = nn.CrossEntropyLoss()(yp, y)
        if opt:
            # zero_grad also discards the parameter gradients accumulated by
            # the attack's own backward passes.
            opt.zero_grad()
            loss.backward()
            opt.step()

        total_err += (yp.argmax(dim=1) != y).sum().item()
        total_loss += loss.item() * X.shape[0]
        n_seen += X.shape[0]
    if n_seen == 0:
        raise ValueError("loader yielded no examples")
    return total_err / n_seen, total_loss / n_seen


def epoch_fast_adversarial(loader, model, opt=None, epsilon=0.1):
    """Adversarial training/evaluation epoch using single-step FGSM perturbations.

    Cheaper than a multi-step PGD epoch: each batch needs only one extra
    forward/backward pass (inside ``fgsm``) to build its perturbation.

    Args:
        loader: DataLoader yielding ``(X, y)`` batches.
        model: classifier returning logits.
        opt: optimizer; if given, the model is trained on ``X + delta``,
            otherwise the pass only evaluates.
        epsilon: L-infinity size of the FGSM perturbation (default 0.1).

    Returns:
        ``(error_rate, mean_loss)`` on the perturbed inputs, averaged over
        the examples actually yielded by ``loader``.

    Raises:
        ValueError: if ``loader`` yields no examples.
    """
    total_loss, total_err, n_seen = 0., 0., 0
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        delta = fgsm(model, X, y, epsilon=epsilon)
        yp = model(X + delta)
        loss = nn.CrossEntropyLoss()(yp, y)
        if opt:
            # zero_grad also discards the parameter gradients that fgsm's
            # backward pass accumulated.
            opt.zero_grad()
            loss.backward()
            opt.step()

        total_err += (yp.argmax(dim=1) != y).sum().item()
        total_loss += loss.item() * X.shape[0]
        n_seen += X.shape[0]
    if n_seen == 0:
        raise ValueError("loader yielded no examples")
    return total_err / n_seen, total_loss / n_seen

# Standard (i.e. non-robust) training of the logistic regression model

In [7]:
# This cell should run in less than 3 minutes
# Baseline: train the logistic-regression model on clean batches only. After
# every epoch, report clean train error, clean test error, and test error under
# the PGD attack (epoch_adversarial without an optimizer, so no weight updates).

from tqdm.notebook import tqdm

model_logreg = nn.Sequential(
     Flatten(),
     nn.Linear(784, 10)).to(device)
opt = optim.SGD(model_logreg.parameters(), lr=1e-1)

# Header row of the tab-separated per-epoch table printed below.
print(*("{}".format(i) for i in ("Train Err", "Test Err", "Adv Err")), sep="\t")

for t in tqdm(range(10)):
    train_err, train_loss = epoch(train_loader, model_logreg, opt)
    test_err, test_loss = epoch(test_loader, model_logreg)
    adv_err, adv_loss = epoch_adversarial(test_loader, model_logreg, pgd_linf)
    # Step schedule: after the 5th epoch (t == 4) drop the learning rate from 1e-1 to 1e-2.
    if t == 4:
        for param_group in opt.param_groups:
            param_group["lr"] = 1e-2
    print(*("{:.6f}".format(i) for i in (train_err, test_err, adv_err)), sep="\t")

Train Err	Test Err	Adv Err


  0%|          | 0/10 [00:00<?, ?it/s]

0.133783	0.097800	0.389200
0.099067	0.089300	0.405500
0.092850	0.084800	0.430700
0.088050	0.081300	0.450600
0.085333	0.081400	0.458700
0.082850	0.081400	0.463600
0.082700	0.081300	0.464100
0.082717	0.081600	0.466000
0.082800	0.081300	0.467400
0.082567	0.081300	0.468800


# Adversarial (PGD) training of the logistic regression model

In [8]:
# This cell should run in less than 3 minutes
# Same logistic-regression architecture, but trained on PGD-perturbed training
# batches (epoch_adversarial with an optimizer) instead of clean ones.
model_logreg_robust = nn.Sequential(
     Flatten(),
     nn.Linear(784, 10)).to(device)

opt = optim.SGD(model_logreg_robust.parameters(), lr=1e-1)
print(*("{}".format(i) for i in ("Train Err", "Test Err", "Adv Err")), sep="\t")

for t in tqdm( range(10) ):
    # "Train Err" is measured on the adversarial training batches, using the
    # predictions made before each batch's optimizer step.
    train_err, train_loss = epoch_adversarial(train_loader, model_logreg_robust, pgd_linf, opt)
    test_err, test_loss = epoch(test_loader, model_logreg_robust)
    adv_err, adv_loss = epoch_adversarial(test_loader, model_logreg_robust, pgd_linf)
    # Step schedule: drop the learning rate from 1e-1 to 1e-2 after the 5th epoch.
    if t == 4:
        for param_group in opt.param_groups:
            param_group["lr"] = 1e-2
    print(*("{:.6f}".format(i) for i in (train_err, test_err, adv_err)), sep="\t")

Train Err	Test Err	Adv Err


  0%|          | 0/10 [00:00<?, ?it/s]

0.298583	0.113600	0.254900
0.260250	0.110300	0.248600
0.255050	0.111700	0.248300
0.254083	0.109800	0.245500
0.251383	0.111000	0.247900
0.247200	0.108700	0.243200
0.246917	0.107900	0.243500
0.246867	0.108100	0.243800
0.247283	0.108700	0.241500
0.246800	0.108100	0.243600


# Adversarial training of a simple CNN (PGD-based, then fast FGSM-based)

In [9]:
# This cell should run in less than 15 minutes
# PGD adversarial training of a one-conv-layer CNN.
# 25088 = 32 channels * 28 * 28: the 3x3 conv with padding=1 keeps the 28x28
# MNIST resolution, so Flatten() yields 25088 features per example.
model_simple_cnn_robust = nn.Sequential(
                                 nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
                                 Flatten(),
                                 nn.Linear(25088, 10)).to(device)

opt = optim.SGD(model_simple_cnn_robust.parameters(), lr=0.1)
print(*("{}".format(i) for i in ("Train Err", "Test Err", "Adv Err")), sep="\t")

for t in range(10):
    
    # Train on PGD-perturbed batches, then evaluate on clean and PGD-perturbed test data.
    train_err, train_loss = epoch_adversarial(train_loader, model_simple_cnn_robust, pgd_linf, opt)
    test_err, test_loss = epoch(test_loader, model_simple_cnn_robust)
    adv_err, adv_loss = epoch_adversarial(test_loader, model_simple_cnn_robust, pgd_linf)
    
    # Step schedule: drop the learning rate from 0.1 to 1e-2 after the 5th epoch.
    if t == 4:
        for param_group in opt.param_groups:
            param_group["lr"] = 1e-2
    print(*("{:.6f}".format(i) for i in (train_err, test_err, adv_err)), sep="\t")

Train Err	Test Err	Adv Err
0.155133	0.038200	0.080500
0.081883	0.028000	0.065600
0.069933	0.025500	0.060600
0.062767	0.022500	0.055100
0.059250	0.021100	0.054700
0.049383	0.019500	0.049800
0.048267	0.019900	0.050400
0.047517	0.019400	0.050100
0.047283	0.019300	0.049300
0.046967	0.018700	0.050500


In [12]:
# This cell should run in less than 15 minutes
# Fast adversarial training: train the same simple CNN on single-step FGSM
# perturbations (epoch_fast_adversarial), but still evaluate against PGD.
# NOTE(review): this rebinds model_simple_cnn_robust, discarding the
# PGD-trained model from the previous cell; rename it if both are needed later.
model_simple_cnn_robust = nn.Sequential(
    nn.Conv2d(1, 32, 3, padding=1), nn.ReLU(),
    Flatten(),
    nn.Linear(25088, 10)).to(device)

# Constant learning rate of 1e-2. Unlike the cells above, there is no step
# decay here: they drop 1e-1 -> 1e-2 at t == 4, but this run already starts
# at 1e-2, so that "decay" would be a no-op.
opt = optim.SGD(model_simple_cnn_robust.parameters(), lr=1e-2)
print(*("{}".format(i) for i in ("Train Err", "Test Err", "Adv Err")), sep="\t")

for t in range(10):
    train_err, train_loss = epoch_fast_adversarial(train_loader, model_simple_cnn_robust, opt)
    test_err, test_loss = epoch(test_loader, model_simple_cnn_robust)
    adv_err, adv_loss = epoch_adversarial(test_loader, model_simple_cnn_robust, pgd_linf)
    print(*("{:.6f}".format(i) for i in (train_err, test_err, adv_err)), sep="\t")

Train Err	Test Err	Adv Err
0.347633	0.101900	0.170100
0.245967	0.090400	0.152900
0.206833	0.077400	0.135700
0.179033	0.067600	0.121800
0.157450	0.060100	0.111400
0.140300	0.053600	0.108900
0.126483	0.050300	0.103900
0.114867	0.048100	0.104100
0.107100	0.041900	0.101400
0.100400	0.039700	0.094000
