In [6]:
# linear.svg
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import LightSource

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

mnist_train = datasets.MNIST("../data", train=True, download=True, transform=transforms.ToTensor())
mnist_test = datasets.MNIST("../data", train=False, download=True, transform=transforms.ToTensor())

train_loader = DataLoader(mnist_train, batch_size = 100, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size = 100, shuffle=False)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") #

In [8]:
torch.manual_seed(0)

class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.shape[0], -1)    

model_cnn = nn.Sequential(
      nn.Linear(784, 10)).to(device)

In [10]:
def epoch(loader, model, opt=None):
    """Standard training/evaluation epoch over the dataset"""
    total_loss, total_err = 0.,0.
    for X,y in loader:
        X,y = X.to(device), y.to(device)
        yp = model(X)
        loss = nn.CrossEntropyLoss()(yp,y)
        if opt:
            opt.zero_grad()
            loss.backward()
            opt.step()
        
        total_err += (yp.max(dim=1)[1] != y).sum().item()
        total_loss += loss.item() * X.shape[0]
    return total_err / len(loader.dataset), total_loss / len(loader.dataset)


# Standard (i.e. non-robust) training on the logistic regression

In [7]:
# This cell should run in less than 3 minutes

from tqdm.notebook import tqdm

model_logreg = nn.Sequential(
     Flatten(),
     nn.Linear(784, 10)).to(device)
opt = optim.SGD(model_logreg.parameters(), lr=1e-1)

print(*("{}".format(i) for i in ("Train Err", "Test Err", "Adv Err")), sep="\t")

for t in tqdm(range(10)):
    train_err, train_loss = epoch(train_loader, model_logreg, opt)
    test_err, test_loss = epoch(test_loader, model_logreg)
    if t == 4:
        for param_group in opt.param_groups:
            param_group["lr"] = 1e-2
    print(*("{:.6f}".format(i) for i in (train_err, test_err, adv_err)), sep="\t")

Train Err	Test Err	Adv Err


  0%|          | 0/10 [00:00<?, ?it/s]

0.133783	0.097800	0.389200
0.099067	0.089300	0.405500
0.092850	0.084800	0.430700
0.088050	0.081300	0.450600
0.085333	0.081400	0.458700
0.082850	0.081400	0.463600
0.082700	0.081300	0.464100
0.082717	0.081600	0.466000
0.082800	0.081300	0.467400
0.082567	0.081300	0.468800
