In [105]:
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
from foolbox import PyTorchModel, accuracy, samples
from foolbox.attacks import LinfPGD, LinfFastGradientAttack
from trainers import Trainer, FGSMTrainer
from robustbench.model_zoo.models import Carmon2019UnlabeledNet
from utils import adversarial_accuracy, fgsm_
import eagerpy as ep
from Nets import CIFAR_Wide_Res_Net
%load_ext autoreload
%autoreload 2
%aimport Nets, trainers

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [98]:
# setup
device = torch.device("cuda")
batch_size = 128
# remove the normalize
transform = transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
        
normalized_min = (0 - 0.5) / 0.5
normalized_max = (1 - 0.5) / 0.5
train_dataset = datasets.CIFAR10(root='./data', train=True,
                                download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                  shuffle=True, num_workers=2)
test_dataset = datasets.CIFAR10(root='./data', train=False,
                               download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                 shuffle=False, num_workers=2)
classes = classes = ('plane', 'car', 'bird', 'cat',
   'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [99]:
model = CIFAR_Wide_Res_Net(device).eval()
model.load_state_dict(torch.load("models/cifar_10_wide_res_net_22e.model"))

<All keys matched successfully>

In [100]:
fgsm_model = CIFAR_Wide_Res_Net(device).eval()
fgsm_model.load_state_dict(torch.load("models/fgsm_eps8.model"))

<All keys matched successfully>

In [101]:
def get_accuracy(model, attack=None, epsilon=0.03):
    fmodel = PyTorchModel(model, bounds=(-1, 1))
    correct = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.type(torch.cuda.LongTensor)
        if attack is None:
            correct += accuracy(fmodel, images, labels) * images.shape[0]
        else:
            raw_advs, clipped_advs, success = attack(fmodel, images, labels, epsilons=epsilon)
            correct += (~success).sum().item()
    return correct / len(test_loader.dataset)

In [26]:
print("Model accuracy: {}%, FGSM model accuracy: {}%".format(get_accuracy(model)*100, get_accuracy(fgsm_model)*100))

Model accuracy: 85.0%, FGSM model accuracy: 80.38%


In [102]:
eps = 0.03
attack = FGSM()
print("Model accuracy FGSM attack for eps={}: {}%, FGSM model accuracy: {}%".format(eps, get_accuracy(model, attack=attack, epsilon=eps)*100, get_accuracy(fgsm_model, attack=attack, epsilon=eps)*100))

Model accuracy FGSM attack for eps=0.03: 9.98%, FGSM model accuracy: 57.74%


In [104]:
adversarial_accuracy(fgsm_model, test_loader, attack=fgsm_, eps=0.03, device=device, normalized_min=normalized_min, normalized_max=normalized_max, random_step=False)

0 / 10000
1280 / 10000
2560 / 10000
3840 / 10000
5120 / 10000
6400 / 10000
7680 / 10000
8960 / 10000


37.75

In [66]:
# apply the attack
attack = LinfPGD()
epsilons = [
    0.0,
    0.0002,
    0.0005,
    0.0008,
    0.001,
    0.0015,
    0.002,
    0.003,
    0.01,
    0.1,
    0.3,
    0.5,
    1.0,
]
raw_advs, clipped_advs, success = attack(fmodel, images, labels, epsilons=epsilons)

# # calculate and report the robust accuracy (the accuracy of the model when
# # it is attacked)
robust_accuracy = 1 - success.type(torch.cuda.FloatTensor).mean(axis=-1)
print("robust accuracy for perturbations with")
for eps, acc in zip(epsilons, robust_accuracy):
    print(f"  Linf norm ≤ {eps:<6}: {acc.item() * 100:4.1f} %")

robust accuracy for perturbations with
  Linf norm ≤ 0.0   :  0.0 %
