In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import time

from torchvision import datasets, transforms
from tensorboardX import SummaryWriter

# Notebook-wide configuration: all tensors live on the CPU and every
# mini-batch holds 64 samples.
device = torch.device("cpu")
batch_size = 64

# Seed NumPy and PyTorch before any dataset/loader/model is created so the
# shuffling order and weight initialisation are repeatable.
np.random.seed(42)
torch.manual_seed(42)

# MNIST train / test splits, downloaded into ``mnist_data/`` when missing.
# The only preprocessing step is ``ToTensor``.
train_dataset = datasets.MNIST(
    'mnist_data/',
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)
test_dataset = datasets.MNIST(
    'mnist_data/',
    train=False,
    download=True,
    transform=transforms.Compose([transforms.ToTensor()]),
)

# Training batches are reshuffled every epoch; the test order stays fixed.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [2]:
class Net(nn.Module):
    """Fully connected classifier: 784 -> 50 -> 50 -> 50 -> 10.

    A ReLU follows each of the first three linear layers; the last layer
    returns raw scores (no softmax).
    """

    def __init__(self):
        super().__init__()
        # Keep the attribute names and creation order: both determine the
        # order of ``parameters()`` and how the RNG is consumed at init time.
        self.fc1 = nn.Linear(in_features=28 * 28, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=50)
        self.fc3 = nn.Linear(in_features=50, out_features=50)
        self.fc4 = nn.Linear(in_features=50, out_features=10)

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return the 10 class scores."""
        hidden = x.view(-1, 28 * 28)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.relu(layer(hidden))
        return self.fc4(hidden)

In [3]:
def train_model(model, num_epochs):
    """Train ``model`` with Adam on cross-entropy and print test accuracy per epoch.

    Uses the notebook-level ``train_loader``, ``test_loader`` and ``device``.

    Args:
        model: module mapping an image batch to per-class scores; updated in place.
        num_epochs: number of full passes over ``train_loader``.

    Returns:
        None. One line ``Epoch <n>: Accuracy <acc> [<t> seconds]`` is printed
        per epoch, where the time covers both training and evaluation.
    """
    learning_rate = 0.0001
    opt = optim.Adam(params=model.parameters(), lr=learning_rate)
    ce_loss = torch.nn.CrossEntropyLoss()

    for epoch in range(1, num_epochs + 1):
        t1 = time.time()

        # --- optimisation pass over the training set ---
        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            opt.zero_grad()
            out = model(x_batch)
            batch_loss = ce_loss(out, y_batch)
            batch_loss.backward()
            opt.step()

        # --- accuracy on the test set ---
        # No gradients are needed here; disabling autograd avoids building a
        # graph for every test batch.
        tot_test, tot_acc = 0.0, 0.0
        with torch.no_grad():
            for x_batch, y_batch in test_loader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                out = model(x_batch)
                pred = torch.max(out, dim=1)[1]  # index of the highest score
                tot_acc += pred.eq(y_batch).sum().item()
                tot_test += x_batch.size(0)
        t2 = time.time()

        print('Epoch %d: Accuracy %.5lf [%.2lf seconds]' % (epoch, tot_acc/tot_test, t2-t1))

In [4]:
# Baseline: a plain (non-robust) network trained for 15 epochs.
base_model = nn.Sequential(Net()).to(device)
base_model.train()
train_model(base_model, num_epochs=15)

  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch 1: Accuracy 0.86280 [7.95 seconds]
Epoch 2: Accuracy 0.90360 [7.85 seconds]
Epoch 3: Accuracy 0.91270 [7.85 seconds]
Epoch 4: Accuracy 0.91880 [8.25 seconds]
Epoch 5: Accuracy 0.92410 [9.48 seconds]
Epoch 6: Accuracy 0.92990 [9.80 seconds]
Epoch 7: Accuracy 0.93310 [11.12 seconds]
Epoch 8: Accuracy 0.93580 [10.65 seconds]
Epoch 9: Accuracy 0.93970 [9.94 seconds]
Epoch 10: Accuracy 0.94150 [10.49 seconds]
Epoch 11: Accuracy 0.94430 [10.56 seconds]
Epoch 12: Accuracy 0.94610 [10.35 seconds]
Epoch 13: Accuracy 0.94980 [11.20 seconds]
Epoch 14: Accuracy 0.95010 [11.33 seconds]
Epoch 15: Accuracy 0.95350 [11.32 seconds]


In [5]:
def get_interval_input(img, delta):
    """Return the element-wise box ``[img - delta, img + delta]`` clipped to ``[0, 1]``.

    Args:
        img: tensor of values to perturb.
        delta: perturbation radius subtracted from / added to every element.

    Returns:
        ``(lower, upper)``: the lower bound is floored at 0 and the upper
        bound is capped at 1.
    """
    floor, ceiling = torch.tensor(0), torch.tensor(1)
    lower = torch.maximum(img - delta, floor)
    upper = torch.minimum(img + delta, ceiling)
    return lower, upper

In [6]:
def _interval_bounds(model, x_batch, delta):
    """Propagate the input box around ``x_batch`` through ``model`` layer by layer.

    The parameters of ``model`` are read as alternating ``(weight, bias)``
    pairs of fully connected layers, with a ReLU after every layer except the
    last one (NOTE: this mirrors the network defined in this notebook; confirm
    before using it with a different architecture).

    Returns:
        ``(lower, upper)``, each of shape ``(batch, n_outputs)``: element-wise
        bounds on the final layer's output over the whole input box.

    Raises:
        ValueError: if the parameters cannot be paired as weight/bias.
    """
    params = list(model.parameters())
    weights, biases = params[0::2], params[1::2]
    if not weights or len(weights) != len(biases):
        raise ValueError(
            'expected alternating weight/bias parameters, got %d tensors' % len(params)
        )

    zero = torch.tensor(0, device=weights[0].device)
    lower, upper = get_interval_input(x_batch, delta)
    # One column vector per sample; the width comes from the first layer
    # (784 for the MNIST network).
    n_inputs = weights[0].shape[1]
    lower, upper = lower.view(-1, n_inputs, 1), upper.view(-1, n_inputs, 1)

    last = len(weights) - 1
    for depth, (weight, bias) in enumerate(zip(weights, biases)):
        # Positive weights carry lower->lower and upper->upper; negative
        # weights swap the two ends of the interval.
        w_pos, w_neg = torch.maximum(weight, zero), torch.minimum(weight, zero)
        bias = bias.reshape(-1, 1)
        next_lower = torch.matmul(w_pos, lower) + torch.matmul(w_neg, upper) + bias
        next_upper = torch.matmul(w_pos, upper) + torch.matmul(w_neg, lower) + bias
        if depth < last:
            # ReLU is monotone, so it can be applied to each bound directly.
            next_lower = torch.maximum(next_lower, zero)
            next_upper = torch.maximum(next_upper, zero)
        lower, upper = next_lower, next_upper

    return lower[:, :, 0], upper[:, :, 0]


def get_robust_accuracy(model, delta):
    """Print the fraction of test samples whose interval bounds prove robustness.

    A sample counts as robust when the lower bound of its true-class output is
    strictly greater than the upper bound of every other class output, for the
    input box of radius ``delta`` built by ``get_interval_input``.

    Uses the notebook-level ``test_loader`` and ``device``.

    Args:
        model: network whose parameters are alternating weight/bias tensors.
        delta: perturbation radius.

    Returns:
        None. Prints ``Epsilon <delta>: Robust accuracy <value>``.
    """
    tot_test, tot_acc = 0.0, 0.0
    # Pure evaluation: disabling autograd avoids building a graph per batch.
    with torch.no_grad():
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            ol, ou = _interval_bounds(model, x_batch, delta)

            labels = y_batch.view(-1, 1)
            true_lower = ol.gather(1, labels).squeeze(1)
            # Mask out the true class, then take the strongest rival bound.
            rival_upper = ou.scatter(1, labels, float('-inf')).max(dim=1)[0]

            tot_acc += (true_lower > rival_upper).sum().item()
            tot_test += x_batch.size(0)
    print('Epsilon %.3lf: Robust accuracy %.5lf' % (delta, tot_acc/tot_test))

In [7]:
# Certified accuracy of the baseline model for radii 0.01, 0.02, ..., 0.10.
for i in range(1, 11):
    epsilon = 0.01 * i
    get_robust_accuracy(base_model, epsilon)

Epsilon 0.010: Robust accuracy 0.00000
Epsilon 0.020: Robust accuracy 0.00000
Epsilon 0.030: Robust accuracy 0.00000
Epsilon 0.040: Robust accuracy 0.00000
Epsilon 0.050: Robust accuracy 0.00000
Epsilon 0.060: Robust accuracy 0.00000
Epsilon 0.070: Robust accuracy 0.00000
Epsilon 0.080: Robust accuracy 0.00000
Epsilon 0.090: Robust accuracy 0.00000
Epsilon 0.100: Robust accuracy 0.00000


In [8]:
def _ibp_logit_bounds(model, x_batch, delta):
    """Return interval bounds on the model output for the box around ``x_batch``.

    The parameters of ``model`` are read as alternating ``(weight, bias)``
    pairs of fully connected layers, with a ReLU after every layer except the
    last one (NOTE: this mirrors the network defined in this notebook; confirm
    before using it with a different architecture).

    Returns:
        ``(lower, upper)``, each of shape ``(batch, n_outputs)``. Both stay
        attached to the autograd graph so they can be used in a loss.

    Raises:
        ValueError: if the parameters cannot be paired as weight/bias.
    """
    # Parameters are re-read (and biases re-shaped) on every call so each
    # training step gets a fresh graph over the current weights.
    params = list(model.parameters())
    weights, biases = params[0::2], params[1::2]
    if not weights or len(weights) != len(biases):
        raise ValueError(
            'expected alternating weight/bias parameters, got %d tensors' % len(params)
        )

    zero = torch.tensor(0, device=weights[0].device)
    lower, upper = get_interval_input(x_batch, delta)
    # One column vector per sample; the width comes from the first layer
    # (784 for the MNIST network).
    n_inputs = weights[0].shape[1]
    lower, upper = lower.view(-1, n_inputs, 1), upper.view(-1, n_inputs, 1)

    last = len(weights) - 1
    for depth, (weight, bias) in enumerate(zip(weights, biases)):
        # Positive weights carry lower->lower and upper->upper; negative
        # weights swap the two ends of the interval.
        w_pos, w_neg = torch.maximum(weight, zero), torch.minimum(weight, zero)
        bias = bias.reshape(-1, 1)
        next_lower = torch.matmul(w_pos, lower) + torch.matmul(w_neg, upper) + bias
        next_upper = torch.matmul(w_pos, upper) + torch.matmul(w_neg, lower) + bias
        if depth < last:
            # ReLU is monotone, so it can be applied to each bound directly.
            next_lower = torch.maximum(next_lower, zero)
            next_upper = torch.maximum(next_upper, zero)
        lower, upper = next_lower, next_upper

    return lower[:, :, 0], upper[:, :, 0]


def robust_train_model(model, num_epochs, delta, kappa_min, kappa_max):
    """Train ``model`` on a mix of the ordinary loss and an interval-bound loss.

    Per batch the loss is
    ``kappa * CE(model(x), y) + (1 - kappa) * CE(worst, y)`` where ``worst``
    takes the lower output bound for the true class and the upper output bound
    for every other class. ``kappa`` moves linearly from ``kappa_max`` in the
    first epoch to ``kappa_min`` in the last one.

    Uses the notebook-level ``train_loader``, ``test_loader`` and ``device``.

    Args:
        model: network whose parameters are alternating weight/bias tensors;
            updated in place.
        num_epochs: number of passes over ``train_loader``.
        delta: perturbation radius used to build the input box.
        kappa_min: weight of the ordinary loss in the final epoch.
        kappa_max: weight of the ordinary loss in the first epoch.

    Returns:
        None. Prints ``Epoch <n>: Accuracy <acc> [<t> seconds]`` per epoch
        (plain test accuracy, not robust accuracy).
    """
    learning_rate = 0.001
    opt = optim.Adam(params=model.parameters(), lr=learning_rate)
    fit_loss = torch.nn.CrossEntropyLoss()
    spec_loss = torch.nn.CrossEntropyLoss()

    for epoch in range(1, num_epochs + 1):
        t1 = time.time()

        # Linear schedule. With a single epoch there is nothing to interpolate
        # (and the formula would divide by zero), so the start value is used.
        if num_epochs > 1:
            kappa = kappa_max - (kappa_max-kappa_min)*(epoch-1)/(num_epochs-1)
        else:
            kappa = kappa_max

        for x_batch, y_batch in train_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            opt.zero_grad()

            ol, ou = _ibp_logit_bounds(model, x_batch, delta)

            # Worst case for the true label: its own score as low as possible,
            # every competing score as high as possible.
            true_mask = F.one_hot(y_batch, num_classes=ol.shape[1])
            false_mask = 1 - true_mask
            worst = ou * false_mask + ol * true_mask

            out = model(x_batch)
            batch_loss = kappa*fit_loss(out, y_batch) + (1-kappa)*spec_loss(worst, y_batch)
            batch_loss.backward()
            opt.step()

        # Plain test accuracy; no gradients are needed, so autograd is
        # disabled to avoid building a graph for every test batch.
        tot_test, tot_acc = 0.0, 0.0
        with torch.no_grad():
            for x_batch, y_batch in test_loader:
                x_batch, y_batch = x_batch.to(device), y_batch.to(device)
                out = model(x_batch)
                pred = torch.max(out, dim=1)[1]  # index of the highest score
                tot_acc += pred.eq(y_batch).sum().item()
                tot_test += x_batch.size(0)
        t2 = time.time()

        print('Epoch %d: Accuracy %.5lf [%.2lf seconds]' % (epoch, tot_acc/tot_test, t2-t1))

In [9]:
# Robustly trained network: 30 epochs at radius 0.1, with the weight of the
# ordinary loss going from 1 in the first epoch down to 0.5 in the last.
robust_model = nn.Sequential(Net()).to(device)
robust_model.train()
robust_train_model(robust_model, num_epochs=30, delta=0.1, kappa_min=0.5, kappa_max=1)

Epoch 1: Accuracy 0.93190 [16.00 seconds]
Epoch 2: Accuracy 0.93390 [16.28 seconds]
Epoch 3: Accuracy 0.94390 [17.57 seconds]
Epoch 4: Accuracy 0.95040 [16.25 seconds]
Epoch 5: Accuracy 0.95080 [17.82 seconds]
Epoch 6: Accuracy 0.95120 [16.81 seconds]
Epoch 7: Accuracy 0.95160 [16.59 seconds]
Epoch 8: Accuracy 0.95330 [16.99 seconds]
Epoch 9: Accuracy 0.95310 [16.15 seconds]
Epoch 10: Accuracy 0.95430 [17.24 seconds]
Epoch 11: Accuracy 0.95440 [17.55 seconds]
Epoch 12: Accuracy 0.95420 [17.26 seconds]
Epoch 13: Accuracy 0.95410 [17.47 seconds]
Epoch 14: Accuracy 0.95170 [18.52 seconds]
Epoch 15: Accuracy 0.95120 [17.42 seconds]
Epoch 16: Accuracy 0.95210 [16.01 seconds]
Epoch 17: Accuracy 0.95140 [16.40 seconds]
Epoch 18: Accuracy 0.95160 [16.58 seconds]
Epoch 19: Accuracy 0.95160 [16.44 seconds]
Epoch 20: Accuracy 0.94900 [16.79 seconds]
Epoch 21: Accuracy 0.95130 [15.54 seconds]
Epoch 22: Accuracy 0.95110 [15.63 seconds]
Epoch 23: Accuracy 0.95080 [16.07 seconds]
Epoch 24: Accuracy 0

In [10]:
# Certified accuracy of the robustly trained model for radii 0.01, ..., 0.10.
for i in range(1, 11):
    epsilon = 0.01 * i
    get_robust_accuracy(robust_model, epsilon)

Epsilon 0.010: Robust accuracy 0.92820
Epsilon 0.020: Robust accuracy 0.91320
Epsilon 0.030: Robust accuracy 0.90110
Epsilon 0.040: Robust accuracy 0.89040
Epsilon 0.050: Robust accuracy 0.87580
Epsilon 0.060: Robust accuracy 0.86090
Epsilon 0.070: Robust accuracy 0.84380
Epsilon 0.080: Robust accuracy 0.82270
Epsilon 0.090: Robust accuracy 0.79810
Epsilon 0.100: Robust accuracy 0.76940
