In [1]:
import numpy as np                # import numpy
import matplotlib.pyplot as plt   # import matplotlib, a python 2d plotting library
from tqdm import tqdm
import pandas as pd
import seaborn as sns

#import torch packages
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

# Pick the compute device: the first CUDA GPU if PyTorch can see one, else the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print('Running on Graphics' if device.type == 'cuda' else 'Running on Processor')

Running on Graphics


In [41]:
class bottleneck(nn.Module):
    """Two-layer bottleneck: in_size -> bn_size -> in_size, activation after each layer.

    Args:
        in_size: width of the input (and of the output).
        bn_size: width of the bottleneck code.
        act: activation module applied after both linear layers. When omitted,
            a fresh nn.ReLU is created for this instance.
    """

    def __init__(self, in_size, bn_size, act=None):
        super().__init__()
        self.L1 = nn.Linear(in_size, bn_size)
        self.L2 = nn.Linear(bn_size, in_size)
        # A module used as a default argument is built once and shared by every
        # instance (hooks, prints and .apply() would all hit the same object);
        # build one per instance instead.
        self.act = nn.ReLU() if act is None else act

    def forward(self, x):
        # The bottleneck code is kept on the module so it can be inspected
        # after a forward pass.
        self.bn = self.act(self.L1(x))
        return self.act(self.L2(self.bn))

class Encoder(nn.Module):
    """Flatten each sample and map it to `in_size` features.

    Layers: Linear(28*28 -> 100), ReLU, Linear(100 -> in_size), ReLU, so each
    sample must flatten (from dim 1 on) to 28*28 values. `bn_size` is unused
    here; the constructor mirrors Decoder's signature.
    """

    def __init__(self, in_size, bn_size):
        super().__init__()
        hidden = 100
        self.enc = nn.Sequential(
            nn.Linear(28 * 28, hidden),
            nn.ReLU(),
            nn.Linear(hidden, in_size),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.enc(x.flatten(start_dim=1))

class Decoder(nn.Module):
    """Map `in_size` features back to a (N, 1, 28, 28) image batch.

    Layers: Linear(in_size -> 100), ReLU, Linear(100 -> 28*28), Sigmoid.
    `bn_size` is unused here; the constructor mirrors Encoder's signature.
    """

    def __init__(self, in_size, bn_size):
        super().__init__()
        layers = [nn.Linear(in_size, 100), nn.ReLU(), nn.Linear(100, 28 * 28), nn.Sigmoid()]
        self.dec = nn.Sequential(*layers)

    def forward(self, x):
        flat_pixels = self.dec(x)
        return flat_pixels.reshape(-1, 1, 28, 28)

class DNA(nn.Module):
    """Autoencoder with a shared Encoder/Decoder pair and two parallel bottlenecks.

    forward(x) encodes once, passes the code through bn1 and bn2 independently,
    decodes each, and returns the tuple (r1, r2) of reconstructions.
    """

    def __init__(self, in_size, bn_size):
        super().__init__()
        # Registration order (enc, dec, bn1, bn2) fixes the state_dict layout
        # used by saved checkpoints.
        self.enc = Encoder(in_size, bn_size)
        self.dec = Decoder(in_size, bn_size)
        self.bn1 = bottleneck(in_size, bn_size)
        self.bn2 = bottleneck(in_size, bn_size)

    def forward(self, x):
        code = self.enc(x)
        return tuple(self.dec(branch(code)) for branch in (self.bn1, self.bn2))

In [None]:
class AE(nn.Module):
    """Single-branch autoencoder: Encoder -> bottleneck -> Decoder.

    forward(x) returns one reconstruction.
    """

    def __init__(self, in_size, bn_size):
        super().__init__()
        # Registration order (enc, dec, bn) fixes the state_dict layout used by
        # saved checkpoints.
        self.enc = Encoder(in_size, bn_size)
        self.dec = Decoder(in_size, bn_size)
        self.bn = bottleneck(in_size, bn_size)

    def forward(self, x):
        code = self.bn(self.enc(x))
        return self.dec(code)

In [None]:
class Classifier(nn.Module):
    """Small CNN producing 10 logits from single-channel 28x28 images.

    Three [Conv3x3 -> ReLU -> MaxPool2] stages (1 -> 8 -> 16 -> 32 channels)
    take 28x28 down to 3x3, so the final Linear sees 32 * 3 * 3 = 288 features.
    The most recent logits are also kept in `self.out`.
    """

    def __init__(self):
        super().__init__()
        layers = []
        for c_in, c_out in ((1, 8), (8, 16), (16, 32)):
            layers += [nn.Conv2d(c_in, c_out, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2)]
        layers += [nn.Flatten(start_dim=1), nn.Linear(288, 10)]
        # Same layer order/indices as a hand-written Sequential, so state_dict
        # keys (c.0, c.3, c.6, c.10) match saved checkpoints.
        self.c = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.c(x)
        self.out = logits
        return logits

In [4]:
# MNIST train/test splits, converted to tensors with ToTensor. download=True lets
# torchvision fetch the files into MNIST_ROOT if they are not already there.
MNIST_ROOT = '../../mnist_digits/'
train_data, test_data = (
    MNIST(MNIST_ROOT, train=is_train, download=True, transform=torchvision.transforms.ToTensor())
    for is_train in (True, False)
)

In [26]:
def gen_PGD(x_nat, y, epsilon, loss_f, models, k=40, a=3/255):
    """L-infinity projected-gradient attack with a uniform random start.

    Args:
        x_nat: clean input batch (moved to the global `device`).
        y: labels passed to `loss_f` (moved to the global `device`).
        epsilon: radius of the allowed perturbation around `x_nat`.
        loss_f: callable (x, y) -> scalar loss that the attack increases.
        models: modules switched to eval mode; their grads are cleared each step.
        k: number of gradient-sign steps.
        a: size of each step.

    Returns:
        Adversarial batch within `epsilon` of `x_nat`, clamped to [0, 1].
    """
    for net in models:
        net.eval()
    y = y.to(device)
    x_nat = x_nat.to(device)
    # Uniform start in [-epsilon, epsilon] around the clean input.
    noise = (2 * epsilon * torch.rand(x_nat.shape) - epsilon).to(device)
    x_adv = torch.clamp(x_nat + noise, 0, 1)
    for _ in range(k):
        x_adv.requires_grad_(True)
        for net in models:
            net.zero_grad()
        loss_f(x_adv, y).backward()
        stepped = torch.clamp(x_adv + a * x_adv.grad.sign(), 0, 1)
        # Project back onto the epsilon ball around x_nat.
        delta = torch.clamp(stepped - x_nat, -epsilon, epsilon)
        x_adv = (x_nat + delta).detach()
    return x_adv

In [27]:
def gen_FGSM(x, y, epsilon, loss_f, models):
    """Fast Gradient Sign Method: one `epsilon`-sized step along sign(d loss / d x).

    Args:
        x: clean input batch. It is not modified: the gradient is taken with
            respect to a detached copy, so the caller's tensor keeps its
            requires_grad flag and `.grad` untouched.
        y: labels passed to `loss_f`.
        epsilon: step size (L-infinity size of the perturbation).
        loss_f: callable (x, y) -> scalar loss that the attack increases.
        models: modules switched to eval mode before the attack (as gen_PGD does).

    Returns:
        Detached adversarial batch clamped to [0, 1].
    """
    for model in models:
        model.eval()
    # Differentiate a private leaf copy. Setting requires_grad on the caller's
    # tensor leaked into the caller, and because `.grad` accumulates, a second
    # attack on the same batch used the SUM of both losses' gradients.
    x_adv = x.detach().clone().requires_grad_(True)
    loss = loss_f(x_adv, y)
    # autograd.grad returns d loss / d x_adv without writing into any .grad
    # (model parameters included), so no zero_grad bookkeeping is needed.
    (grad,) = torch.autograd.grad(loss, x_adv)
    return torch.clamp(x_adv.detach() + epsilon * grad.sign(), min=0, max=1.0)

In [28]:
loss_ce = nn.CrossEntropyLoss()


def _count_correct(auto, cls, x, y, dual):
    """Number of samples in batch (x, y) that `cls` labels correctly after `auto`.

    With `dual`, `auto` returns two reconstructions and a sample counts as
    correct when EITHER of them is classified as its label.
    """
    # Pure inference: no autograd graph is needed here (attacks take their own grads).
    with torch.no_grad():
        xhat = auto(x)
        if dual:
            hit = (torch.argmax(cls(xhat[0]), dim=1) == y) | (torch.argmax(cls(xhat[1]), dim=1) == y)
        else:
            hit = torch.argmax(cls(xhat), dim=1) == y
    return int(hit.sum().item())


def eval_AE(auto, cls, epsilons, attack_function, label, DN, dataset=None, batch_size=500):
    """Accuracy of `cls` on reconstructions from `auto` under an adversarial attack.

    Args:
        auto: autoencoder; returns one reconstruction, or a pair when DN is True.
        cls: classifier applied to the reconstructions.
        epsilons: attack magnitudes to evaluate; 0 means clean inputs (no attack).
        attack_function: callable (x, y, epsilon, loss_f, models) -> adversarial x,
            e.g. gen_FGSM or gen_PGD.
        label: attack name written into each result row.
        DN: True when `auto` returns two reconstructions.
        dataset: dataset to evaluate on; defaults to the global `test_data`.
        batch_size: DataLoader batch size.

    Returns:
        List of [model type, label, epsilon, accuracy] rows. With DN, a sample is
        correct when either branch classifies it correctly, and two row sets are
        produced: 'DNA (dual)' (one attack on the sum of both branches' squared
        losses) and 'DNA (single)' (a separate attack per branch, averaged).
    """
    if dataset is None:
        dataset = test_data
    n_samples = len(dataset)
    # Freshly constructed modules start in training mode; evaluate every epsilon,
    # including the clean pass where no attack runs, in eval mode.
    auto.eval()
    cls.eval()
    results = []

    if DN:
        def loss_f(x, y):
            # Run the autoencoder once and attack both branches jointly.
            xhat = auto(x)
            return loss_ce(cls(xhat[0]), y)**2 + loss_ce(cls(xhat[1]), y)**2
        a_type = 'DNA (dual)'
    else:
        def loss_f(x, y):
            return loss_ce(cls(auto(x)), y)
        a_type = 'AE'

    for epsilon in epsilons:
        total_correct = 0
        for x, y in DataLoader(dataset, batch_size=batch_size):
            x = x.to(device)
            y = y.to(device)
            if epsilon != 0.0:
                # Hand the attack a copy: an attack may set requires_grad or
                # accumulate .grad on its input.
                x = attack_function(x.clone(), y, epsilon, loss_f, [auto, cls])
            total_correct += _count_correct(auto, cls, x, y, DN)
        results.append([a_type, label, epsilon, total_correct / n_samples])

    if DN:
        # Attack each branch on its own loss; both attacks reuse the same clean
        # batch, so each gets its own copy of it.
        branch_losses = [
            lambda x, y: loss_ce(cls(auto(x)[0]), y),
            lambda x, y: loss_ce(cls(auto(x)[1]), y),
        ]
        a_type = 'DNA (single)'
        for epsilon in epsilons:
            total_correct = 0
            for x, y in DataLoader(dataset, batch_size=batch_size):
                x = x.to(device)
                y = y.to(device)
                for branch_loss in branch_losses:
                    if epsilon != 0:
                        x_adv = attack_function(x.clone(), y, epsilon, branch_loss, [auto, cls])
                    else:
                        x_adv = x
                    total_correct += _count_correct(auto, cls, x_adv, y, True)
            results.append([a_type, label, epsilon, total_correct / (2 * n_samples)])
    return results

In [None]:
# Run the adversarial tests.
epsilons = [0.0, 0.05, 0.1, 0.15]
# PGD starts from a random point inside the epsilon ball; seed so reruns reproduce the table.
torch.manual_seed(0)

# `in_size` / `bn_size` are not defined in any cell of this notebook, so a fresh
# kernel stopped here with a NameError. Read them from each checkpoint instead: a
# bottleneck's L1 is nn.Linear(in_size, bn_size), whose weight has shape (bn_size, in_size).
# map_location lets checkpoints saved on a GPU load on a CPU-only machine.
ae_state = torch.load('models/ae', map_location=device)
bn_size, in_size = ae_state['bn.L1.weight'].shape
ae = AE(in_size, bn_size).to(device)
cls = Classifier().to(device)
ae.load_state_dict(ae_state)
cls.load_state_dict(torch.load('models/cls_ae', map_location=device))

results = eval_AE(ae, cls, epsilons, gen_FGSM, 'FGSM', DN=False)
results.extend(eval_AE(ae, cls, epsilons, gen_PGD, 'PGD', DN=False))

dna_state = torch.load('models/dna', map_location=device)
bn_size, in_size = dna_state['bn1.L1.weight'].shape
dna = DNA(in_size, bn_size).to(device)
cls = Classifier().to(device)
dna.load_state_dict(dna_state)
cls.load_state_dict(torch.load('models/cls_dna', map_location=device))

results.extend(eval_AE(dna, cls, epsilons, gen_FGSM, 'FGSM', DN=True))
results.extend(eval_AE(dna, cls, epsilons, gen_PGD, 'PGD', DN=True))

In [None]:
# Collect the result rows into a table (accuracy converted to percent) and cache it to disk.
results_table = (
    pd.DataFrame(results, columns=['Model', 'Attack', r'$\epsilon$', 'Accuracy'])
    .assign(Accuracy=lambda df: df['Accuracy'] * 100)
)
results_table.to_pickle("results.pkl")

In [None]:
# Plot accuracy against attack magnitude, one panel per attack. The 'DNA (single)'
# rows are left out of the figure and 'DNA (dual)' is labelled simply 'DNA'.
# A separate frame is built rather than overwriting `results_table`: assigning a
# column on a boolean-filtered slice raised pandas' SettingWithCopyWarning.
plot_table = (
    results_table[results_table.Model != 'DNA (single)']
    .assign(Model=lambda df: df['Model'].replace('DNA (dual)', 'DNA'))
)
sns.set_theme()
fig, ax = plt.subplots(1, 2, figsize=(15, 5))
for panel, attack in zip(ax, ['FGSM', 'PGD']):
    sns.lineplot(ax=panel, data=plot_table[plot_table.Attack == attack],
                 x=r'$\epsilon$', y='Accuracy', hue='Model', palette='Blues_d')
    panel.set_ylim(0, 100)
    panel.set_xlabel(rf'Attack Magnitude ($\epsilon$; {attack})')
    panel.set_ylabel('Accuracy (%)')

In [None]:
# Export the accuracy figure as a 600-dpi JPEG, cropped tightly with no padding.
fig.savefig(
    'line_plot.jpg',
    format='jpg',
    dpi=600,
    bbox_inches='tight',
    pad_inches=0,
)