In [1]:
!pip install foolbox -q
import foolbox as fb
import torch
import torchvision
from torchvision.models import ResNet18_Weights
from helpers import get_model
import numpy as np
import matplotlib.pyplot as plt
import eagerpy as ep
import copy
from torch.quantization import quantize_fx
import torchvision.transforms as transforms
import warnings
warnings.filterwarnings("ignore")


# Notebook-wide switches; the quantisation cell checks the "quantise" flag.
parameters = dict(quantise=True)


# Human-readable CIFAR-10 class names, presumably in label-index order
# (index i <-> dataset label i); not referenced by the visible cells.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
def load_test_data(data_dir, transform=None):
    """Return the CIFAR-10 test split, downloading it into *data_dir* if absent.

    Args:
        data_dir: Directory torchvision uses as the CIFAR-10 root.
        transform: Optional per-image transform. Defaults to
            ``transforms.ToTensor()``, which yields float tensors in [0, 1]
            (the same range as the ``bounds=(0, 1)`` handed to foolbox later).

    Returns:
        A ``torchvision.datasets.CIFAR10`` dataset for the test split.
    """
    # No RandomErasing here: it is a training-time augmentation that randomly
    # blanks out a rectangle of the image. On the test split, it would make
    # evaluation (and the quantisation calibration that draws from this
    # dataset) nondeterministic and corrupted.
    if transform is None:
        transform = transforms.ToTensor()
    return torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform
    )

# --- Data --------------------------------------------------------------------
data_dir = '../../data'
tes = load_test_data(data_dir)
testloader = torch.utils.data.DataLoader(tes, batch_size=20, shuffle=False, num_workers=2)
dataloader = {"test": testloader}

# --- Model -------------------------------------------------------------------
# Pretrained ResNet-18. torchvision returns it in train mode; switch to eval
# right away so the forward passes below never update the pretrained
# BatchNorm running statistics (the non-quantised path used to do exactly that).
model = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()
dataiter = iter(dataloader['test'])
img, lab = next(dataiter)

# Images arrive in [0, 1]. The foolbox wrapper in the evaluation cell
# normalises with these same ImageNet statistics before calling the network.
# Calibration and the smoke test must therefore see normalised inputs too,
# otherwise the int8 observers record the wrong activation ranges.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

if parameters['quantise']:
    # Post-training static quantisation in FX graph mode, fbgemm (CPU) backend.
    m = copy.deepcopy(model)
    m.to("cpu")
    m.eval()
    qconfig_dict = {"": torch.quantization.get_default_qconfig("fbgemm")}
    # example_inputs is a tuple of the positional arguments to forward().
    model_prepared = quantize_fx.prepare_fx(m, qconfig_dict, example_inputs=(normalize(img),))

    # Calibrate the inserted observers on 10 further test batches.
    with torch.inference_mode():
        for _ in range(10):
            img, lab = next(dataiter)
            model_prepared(normalize(img))
    model = quantize_fx.convert_fx(model_prepared)

# Smoke test: one forward pass on the most recently drawn batch.
with torch.inference_mode():
    test_out = model(normalize(img))
print("model loading success")

Files already downloaded and verified
model loading success


In [None]:
model = model.eval()
# ImageNet normalisation, applied by foolbox inside the wrapper so that attacks
# operate in raw [0, 1] pixel space.
preprocessing = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], axis=-3)
bounds = (0, 1)
# Pin to CPU: foolbox picks cuda:0 by default when a GPU is available, but
# fbgemm-quantised kernels only run on CPU.
fmodel = fb.PyTorchModel(model, bounds=bounds, preprocessing=preprocessing, device="cpu")
print("successfully converted to foolbox model")

# The network is an ImageNet-1k ResNet-18, so score it on ImageNet samples.
# CIFAR-10 labels (0-9) would be compared against unrelated ImageNet classes.
images, labels = fb.utils.samples(fmodel, dataset='imagenet', batchsize=16)
images = ep.astensor(images)
labels = ep.astensor(labels)
accuracy = fb.utils.accuracy(fmodel, images, labels)
print(f"clean accuracy: {accuracy:.3f}")

attack = fb.attacks.LinfDeepFoolAttack()
# Alternative attacks to swap in:
#attack = fb.attacks.LinfFastGradientAttack()
#attack = fb.attacks.LinfProjectedGradientDescentAttack()
#attack = fb.attacks.SaltAndPepperNoiseAttack()

# Gradient-based attacks back-propagate through the model, and a model produced
# by quantize_fx.convert_fx has no autograd support. When quantised, craft the
# perturbations on the float ResNet-18 and transfer them (black-box transfer).
if parameters['quantise']:
    surrogate = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT).eval()
    attack_model = fb.PyTorchModel(surrogate, bounds=bounds, preprocessing=preprocessing, device="cpu")
else:
    attack_model = fmodel

raw, clipped, is_adv = attack(attack_model, images, labels, epsilons=0.03)
# is_adv is judged on attack_model and its mean is an attack *success rate*,
# not an accuracy. Score the clipped adversarials on the evaluated model.
robust_accuracy = fb.utils.accuracy(fmodel, clipped, labels)
attack_success_rate = 1 - robust_accuracy
print(f"robust accuracy (eps=0.03): {robust_accuracy:.3f}")
print(f"attack success rate: {attack_success_rate:.3f}")