In [None]:
import os 
import sys 
import numpy as np
import pandas as pd 

npa = np.array

import torch
import foolbox as fb

import matplotlib as mpl
import matplotlib.pyplot as plt 
%matplotlib inline
import seaborn as sns
# Notebook-wide plot look: seaborn "ticks" style with long major ticks,
# "paper" context, and thick axis spines / tick marks.
sns.set_style("white")
sns.set_style("ticks", {"xtick.major.size": 14, "ytick.major.size": 14})
sns.set_context("paper")
mpl.rcParams.update({
    'axes.linewidth': 2.5,
    'ytick.major.width': 2.5,
    'xtick.major.width': 2.5,
})

%load_ext autoreload
%autoreload 2
sys.path.insert(0, '../models')
sys.path.insert(0, '../')
import resnet_cifar
import resnet_mnist
import mnist_models
import cifar_models
from utils import get_mnist_test_loader
from utils import get_cifar10_test_loader
from utils import get_cifar100_test_loader


In [None]:
seed = 0
ROOT_PATH = '../'

torch.manual_seed(seed)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


def _collect_loader(loader):
    """Drain every (images, labels) batch of `loader` into two tensors.

    Returns ``(images, labels)``, each concatenated along dim 0 in the order
    the loader yields its batches.
    """
    image_batches, label_batches = [], []
    for batch_images, batch_labels in loader:
        image_batches.append(batch_images)
        label_batches.append(batch_labels)
    return torch.cat(image_batches, dim=0), torch.cat(label_batches, dim=0)


# MNIST test set
MNIST_DATA_PATH = os.path.join(ROOT_PATH, "data", 'mnist')
mnist = get_mnist_test_loader(batch_size=1000, data_path=MNIST_DATA_PATH)
mnist_images, mnist_labels = _collect_loader(mnist)

# CIFAR-10 test set (norm=False: the loader is asked not to normalize the images)
C10_DATA_PATH = os.path.join(ROOT_PATH, "data", 'cifar10')
cifar10 = get_cifar10_test_loader(batch_size=100, data_path=C10_DATA_PATH, norm=False)
c10_images, c10_labels = _collect_loader(cifar10)

# CIFAR-100 test set (norm=False, as for CIFAR-10)
C100_DATA_PATH = os.path.join(ROOT_PATH, "data", 'cifar100')
cifar100 = get_cifar100_test_loader(batch_size=100, data_path=C100_DATA_PATH, norm=False)
c100_images, c100_labels = _collect_loader(cifar100)

# MNIST

In [None]:
# Restore the MNIST encoder (E) and classifier head (Dc) from one checkpoint and
# chain them into a single end-to-end model for the attacks below.
save_path = '../chkpts/mnist/weights.pt'
# map_location lets a checkpoint saved from GPU be loaded on a CPU-only machine.
state_dict = torch.load(save_path, map_location=device)

afd_E = resnet_mnist.ResNet18FeatsNorm()
afd_Dc = mnist_models.LeNetDecoder(10)

afd_E.load_state_dict(state_dict['E_state_dict'])
afd_Dc.load_state_dict(state_dict['Dc_state_dict'])
afd_EDc = torch.nn.Sequential(afd_E, afd_Dc)

afd_EDc.to(device)
afd_EDc.eval();

In [None]:
# Select the foolbox attack to run on MNIST and the perturbation budgets
# (epsilons) to sweep it over. Each list starts at 0.0 (presumably the
# unperturbed baseline — confirm against foolbox semantics).
attack_type = 'linfpgd'


attack_menu = {'linfpgd':{'attack': fb.attacks.LinfPGD(steps=40), 'eps': [0.0, 0.1, 0.3, 0.35, 0.4, 0.45, 0.5]},
               'l1pgd':{'attack': fb.attacks.L1PGD(), 'eps':[0.0, 10.0, 50.0, 100., 200, 400]},
               'l2pgd':{'attack': fb.attacks.L2PGD(), 'eps':[0.0, 2.0, 5.0, 10.0, 20.]},
               'fgsm':{'attack': fb.attacks.FGSM(random_start=True), 'eps':[0.0, 0.1, 0.3, 0.35, 0.4, 0.45, 0.5]},
               'deepfool':{'attack': fb.attacks.LinfDeepFoolAttack(steps=50), 'eps':[0.0, 0.01, 0.1, 0.3, 0.5, 1.0]},
               'cw':{'attack': fb.attacks.L2CarliniWagnerAttack(steps=50, stepsize=0.05), 'eps':[0.0, 1., 2., 5.]},
               'ddn':{'attack': fb.attacks.DDNAttack(steps=100), 'eps':[0.0, 1., 2., 5.0, 10.0]},
              }

# Validate against the menu itself (no duplicated key list to drift), and use an
# explicit raise so the check survives `python -O`.
if attack_type not in attack_menu:
    raise ValueError(f"Unknown attack_type {attack_type!r}; expected one of {sorted(attack_menu)}")
attack, epsilons = attack_menu[attack_type]['attack'], attack_menu[attack_type]['eps']


In [None]:
# Run the selected attack over every MNIST test batch. Per batch, `current_advs`
# holds one clipped adversarial batch per epsilon, and `current_success` is later
# averaged over its last axis to get one success rate per epsilon.
# NOTE(review): bounds=(0,1) assumes the loader yields pixels in [0, 1] — confirm.
fmodel = fb.PyTorchModel(afd_EDc, bounds=(0,1), device=device)

advs, success = [], []
failed_batches = []
for batch_idx, (images, labels) in enumerate(mnist):
    try:
        _, current_advs, current_success = attack(fmodel, images.to(device), labels.to(device), epsilons=epsilons)
    except Exception as exc:
        # Best-effort: skip the failing batch, but record it so the reported
        # accuracy is not silently computed on a subset of the test set.
        failed_batches.append((batch_idx, repr(exc)))
        continue
    # Stack per-epsilon results and move them off the device right away.
    advs.append(torch.stack(current_advs).cpu())
    success.append(current_success.cpu())

if failed_batches:
    print(f"WARNING: skipped {len(failed_batches)} batch(es) after attack errors: {failed_batches}")
if not advs:
    raise RuntimeError("The attack failed on every batch; there are no results to aggregate.")

# Each entry of `advs` is (n_eps, batch, ...): concatenate along the example axis.
afd_adv_images = torch.cat(advs, dim=1).numpy()
success = torch.cat(success, dim=-1)
afd_attack_success = success.numpy()

In [None]:
# Plot robust accuracy, 100 * (1 - mean attack success), against the attack budget.
MARKER_SIZE = 15

sns.set_style("ticks", {"xtick.major.size": 14, "ytick.major.size": 14})
sns.set_context("poster")
mpl.rcParams['axes.linewidth']=2.5
mpl.rcParams['ytick.major.width']=2.5
mpl.rcParams['xtick.major.width']=2.5

# No rounding before scaling, so the curve is not quantized to whole percent.
robust_accuracy = 100 * (1 - afd_attack_success.mean(-1))

fig, ax = plt.subplots(figsize=(7, 5))
ax.plot(epsilons, robust_accuracy, marker='.', color='g', markersize=MARKER_SIZE)
ax.set_xlabel('eps')
ax.set_ylabel('Accuracy (%)')
ax.set_title(f'Robustness - {attack_type}')

sns.despine(ax=ax)
fig.tight_layout()
plt.show()

## AutoAttack

In [None]:
from autoattack import AutoAttack

# Standard Linf AutoAttack on the full MNIST test set.
# NOTE(review): relies on `afd_EDc` still being the MNIST model loaded above
# (later cells rebind it to the CIFAR models).
# apgd-ce, apgd-t, fab-t, square
epsilon = 0.3
# AutoAttack's constructor defaults to device='cuda'; pass the notebook's device
# so this cell also runs on CPU-only machines.
adversary = AutoAttack(afd_EDc, norm='Linf', eps=epsilon, version='standard', device=device)
x_adv = adversary.run_standard_evaluation(mnist_images, mnist_labels, bs=1000)

# CIFAR-10

In [None]:
# Select the foolbox attack to run on CIFAR-10 and the perturbation budgets
# (epsilons) to sweep it over. The Linf budgets are multiples of 1/255.
attack_type = 'deepfool'

attack_menu = {'linfpgd':{'attack': fb.attacks.LinfPGD(steps=20, abs_stepsize=2./255), 'eps': [0.0, 2./255, 4./255, 8./255, 16./255, 32./255]},
               'l1pgd':{'attack': fb.attacks.L1PGD(), 'eps': [0.0, 10.0, 50.0, 100., 200, 400]},
               'l2pgd':{'attack': fb.attacks.L2PGD(), 'eps': [0.0, 1.0, 2.0, 5.0, 10.0, 20.0]},
               'fgsm':{'attack': fb.attacks.FGSM(), 'eps': [0.0, 2./255, 4./255, 8./255, 16./255, 32./255, 64./255]},
               'deepfool':{'attack': fb.attacks.LinfDeepFoolAttack(), 'eps': [0.0, 2./255, 4./255, 8./255, 16./255, 32./255, 64./255]},
               'cw':{'attack': fb.attacks.L2CarliniWagnerAttack(steps=20), 'eps': [1.]}, # alternative sweep: [0., 1., 2., 5.]
               'ddn':{'attack': fb.attacks.DDNAttack(steps=100), 'eps':[0.0, 2.0, 5.0, 10.0, 15.]},
              }

# Validate against the menu itself (no duplicated key list to drift), and use an
# explicit raise so the check survives `python -O`.
if attack_type not in attack_menu:
    raise ValueError(f"Unknown attack_type {attack_type!r}; expected one of {sorted(attack_menu)}")
attack, epsilons = attack_menu[attack_type]['attack'], attack_menu[attack_type]['eps']



In [None]:
# Restore the CIFAR-10 encoder (E) and classifier head (Dc) from one checkpoint
# and chain them into a single end-to-end model for the attacks below.
save_path = '../chkpts/cifar10/weights.pt'
# map_location lets a checkpoint saved from GPU be loaded on a CPU-only machine.
state_dict = torch.load(save_path, map_location=device)

afd_E = resnet_cifar.ResNet18Feats()
afd_Dc = resnet_cifar.ResNetDecoder()
afd_E.load_state_dict(state_dict['E_state_dict'])
afd_Dc.load_state_dict(state_dict['Dc_state_dict'])

afd_EDc = torch.nn.Sequential(afd_E, afd_Dc)

afd_EDc.to(device)
afd_EDc.eval();

In [None]:
# Run the selected attack over every CIFAR-10 test batch. Per batch, `current_advs`
# holds one clipped adversarial batch per epsilon, and `current_success` is later
# averaged over its last axis to get one success rate per epsilon.
# NOTE(review): bounds=(0,1) assumes the norm=False loader yields pixels in [0, 1].
fmodel = fb.PyTorchModel(afd_EDc, bounds=(0,1), device=device)

advs, success = [], []
for images, labels in cifar10:
    _, current_advs, current_success = attack(fmodel, images.to(device), labels.to(device), epsilons=epsilons)
    # Stack per-epsilon results and move them to the CPU immediately, so the
    # adversarial images for the whole test set never pile up in device memory.
    advs.append(torch.stack(current_advs).cpu())
    success.append(current_success.cpu())
# Each entry of `advs` is (n_eps, batch, ...): concatenate along the example axis.
afd_adv_images = torch.cat(advs, dim=1).numpy()
success = torch.cat(success, dim=-1)
afd_attack_success = success.numpy()



In [None]:
# Plot CIFAR-10 robust accuracy, 100 * (1 - mean attack success), against the budget.
sns.set_style("ticks", {"xtick.major.size": 14, "ytick.major.size": 14})
sns.set_context("poster")
mpl.rcParams['axes.linewidth']=2.5
mpl.rcParams['ytick.major.width']=2.5
mpl.rcParams['xtick.major.width']=2.5

MARKER_SIZE=15

# No rounding before scaling, so the curve is not quantized to whole percent.
robust_accuracy = 100 * (1 - afd_attack_success.mean(-1))

fig, ax = plt.subplots(figsize=(9, 5))
ax.plot(epsilons, robust_accuracy, marker='.', markersize=MARKER_SIZE, color='g')

ax.set_xlabel('eps')
ax.set_ylabel('Accuracy (%)')
ax.set_title(f'Robustness - {attack_type}')

sns.despine(ax=ax)
fig.tight_layout()
plt.show()

# CIFAR-100

In [None]:
# Select the foolbox attack to run on CIFAR-100 and the perturbation budgets
# (epsilons) to sweep it over. The Linf budgets are multiples of 1/255.
attack_type = 'deepfool'

attack_menu = {'linfpgd':{'attack': fb.attacks.LinfPGD(steps=20, abs_stepsize=2./255), 'eps': [0.0, 2./255, 4./255, 8./255, 16./255, 32./255]},
               'l1pgd':{'attack': fb.attacks.L1PGD(), 'eps': [0.0, 10.0, 50.0, 100., 200, 400]},
               'l2pgd':{'attack': fb.attacks.L2PGD(), 'eps': [0.0, 1.0, 2.0, 5.0, 10.0, 20.0]},
               'fgsm':{'attack': fb.attacks.FGSM(), 'eps': [0.0, 2./255, 4./255, 8./255, 16./255, 32./255, 64./255]},
               'deepfool':{'attack': fb.attacks.LinfDeepFoolAttack(), 'eps': [0.0, 2./255, 4./255, 8./255, 16./255, 32./255, 64./255]},
               'cw':{'attack': fb.attacks.L2CarliniWagnerAttack(steps=20), 'eps': [1.]}, # alternative sweep: [0., 1., 2., 5.]
               'ddn':{'attack': fb.attacks.DDNAttack(steps=100), 'eps':[0.0, 2.0, 5.0, 10.0, 15.]},
              }

# Validate against the menu itself (no duplicated key list to drift), and use an
# explicit raise so the check survives `python -O`.
if attack_type not in attack_menu:
    raise ValueError(f"Unknown attack_type {attack_type!r}; expected one of {sorted(attack_menu)}")
attack, epsilons = attack_menu[attack_type]['attack'], attack_menu[attack_type]['eps']



In [None]:
# Restore the CIFAR-100 encoder (E) and 100-way classifier head (Dc) from one
# checkpoint and chain them into a single end-to-end model for the attacks below.
save_path = '../chkpts/cifar100/weights.pt'
# map_location lets a checkpoint saved from GPU be loaded on a CPU-only machine.
state_dict = torch.load(save_path, map_location=device)

afd_E = resnet_cifar.ResNet18Feats()
afd_Dc = resnet_cifar.ResNetDecoder(num_classes=100)
afd_E.load_state_dict(state_dict['E_state_dict'])
afd_Dc.load_state_dict(state_dict['Dc_state_dict'])

afd_EDc = torch.nn.Sequential(afd_E, afd_Dc)

afd_EDc.to(device)
afd_EDc.eval();

In [None]:
# Run the selected attack over every CIFAR-100 test batch. Per batch, `current_advs`
# holds one clipped adversarial batch per epsilon, and `current_success` is later
# averaged over its last axis to get one success rate per epsilon.
# NOTE(review): bounds=(0,1) assumes the norm=False loader yields pixels in [0, 1].
fmodel = fb.PyTorchModel(afd_EDc, bounds=(0,1), device=device)

advs, success = [], []
for images, labels in cifar100:
    _, current_advs, current_success = attack(fmodel, images.to(device), labels.to(device), epsilons=epsilons)
    # Stack per-epsilon results and move them to the CPU immediately, so the
    # adversarial images for the whole test set never pile up in device memory.
    advs.append(torch.stack(current_advs).cpu())
    success.append(current_success.cpu())
# Each entry of `advs` is (n_eps, batch, ...): concatenate along the example axis.
afd_adv_images = torch.cat(advs, dim=1).numpy()
success = torch.cat(success, dim=-1)
afd_attack_success = success.numpy()



In [None]:
# Plot CIFAR-100 robust accuracy, 100 * (1 - mean attack success), against the budget.
sns.set_style("ticks", {"xtick.major.size": 14, "ytick.major.size": 14})
sns.set_context("poster")
mpl.rcParams['axes.linewidth']=2.5
mpl.rcParams['ytick.major.width']=2.5
mpl.rcParams['xtick.major.width']=2.5

MARKER_SIZE=15

# No rounding before scaling, so the curve is not quantized to whole percent.
robust_accuracy = 100 * (1 - afd_attack_success.mean(-1))

fig, ax = plt.subplots(figsize=(9, 5))
ax.plot(epsilons, robust_accuracy, marker='.', markersize=MARKER_SIZE, color='g')

ax.set_xlabel('eps')
ax.set_ylabel('Accuracy (%)')
ax.set_title(f'Robustness - {attack_type}')

sns.despine(ax=ax)
fig.tight_layout()
plt.show()