In [None]:
import os
import urllib.request
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision
import numpy as np
import foolbox as fb

# This code was tested with Foolbox 3.0.0b; you may need to install the
# latest master version from GitHub with:
#
# pip3 install git+https://github.com/bethgelab/foolbox.git
#
assert int(fb.__version__.split('.')[0]) >= 3

import resnet

import logging
# Debug logger for the attack loop; writes timestamped lines to kwta_debug.log.
logger = logging.getLogger('kwtalogger')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(message)s')

# Re-running this cell in the same kernel must not attach a second FileHandler
# for the same file -- otherwise every debug line would be written once per re-run.
_log_path = os.path.abspath('kwta_debug.log')
fh = next(
    (h for h in logger.handlers
     if isinstance(h, logging.FileHandler) and h.baseFilename == _log_path),
    None,
)
if fh is None:
    fh = logging.FileHandler(_log_path)
    logger.addHandler(fh)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)

### load pretrained weights

In [None]:
filename = 'kwta_spresnet18_0.1_cifar_adv.pth'
url = f'https://github.com/wielandbrendel/robustness_workshop/releases/download/v0.0.1/{filename}'

# Download the checkpoint only if it is not already present locally.
# The data is written to a temporary ".part" file and renamed into place only
# once the download has completed. An interrupted download therefore never
# leaves a truncated checkpoint that the isfile() check would mistake for a
# valid cached copy on the next run.
if not os.path.isfile(filename):
    print('Downloading pretrained weights.')
    partial_path = filename + '.part'
    try:
        urllib.request.urlretrieve(url, partial_path)
        os.replace(partial_path, filename)
    finally:
        # Clean up leftovers from a failed/interrupted download.
        if os.path.exists(partial_path):
            os.remove(partial_path)

### load data

In [None]:
# CIFAR-10 test split. Normalize with mean 0 and std 1 leaves the ToTensor
# output unchanged, so pixel values stay in [0, 1].
norm_mean = 0
norm_var = 1  # NOTE: passed to Normalize as the per-channel *std*, not a variance

channel_means = (norm_mean,) * 3
channel_stds = (norm_var,) * 3
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_means, channel_stds),
])

# NOTE(review): shuffle=True without a seeded generator -> batch order differs between runs.
cifar_test = datasets.CIFAR10("./data", train=False, download=True, transform=transform_test)
test_loader = DataLoader(cifar_test, batch_size=200, shuffle=True)

### load model

In [None]:
# gamma: sparsity value used for every entry of `sparsities`; also part of the checkpoint name.
# epsilon: L-inf radius used by the attack below (0.031 ~= 8/255).
gamma = 0.1
epsilon = 0.031
filepath = f'kwta_spresnet18_{gamma}_cifar_adv.pth'

use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print('Using device:', device)
print()

model = resnet.SparseResNet18(sparsities=[gamma] * 4, sparse_func='vol')
model = model.to(device)
# NOTE(review): torch.load unpickles the file -- only load checkpoints from trusted sources.
state_dict = torch.load(filepath, map_location=device)
model.load_state_dict(state_dict)
model.eval();

### clean accuracy

In [None]:
# Clean (unattacked) accuracy over the whole test set.
# This is pure inference, so run it under no_grad. That avoids building and
# holding an autograd graph for every batch.
acc = 0
total_number = 0

with torch.no_grad():
    for images, labels in test_loader:
        logits = model(images.to(device))
        predictions = logits.argmax(1).cpu()
        acc += int((predictions == labels).sum())
        total_number += images.shape[0]

# The clean accuracy is much lower than what is reported in the paper,
# but the authors claimed that this checkpoint is more robust.
print(f'Clean accuracy is {acc / total_number:.3f}')

### final attack

In [None]:
# Wrap the PyTorch network as a Foolbox model so it can be queried with
# EagerPy tensors. Inputs are declared to lie in [0, 1], which matches the
# identity Normalize (mean 0, std 1) applied to the test data above.
fmodel = fb.models.PyTorchModel(model=model, bounds=(0, 1))

In [None]:
from tqdm.notebook import tqdm
import eagerpy as ep

def best_other_classes(logits: ep.Tensor, exclude: ep.Tensor) -> ep.Tensor:
    """Return, for each row of ``logits``, the argmax over all classes except ``exclude``.

    The excluded class in each row has ``+inf`` subtracted from it, via a
    one-hot tensor whose hot value is ``inf``. It therefore becomes ``-inf``
    and can never win the argmax.
    """
    exclusion_penalty = ep.onehot_like(logits, exclude, value=ep.inf)
    masked_logits = logits - exclusion_penalty
    return masked_logits.argmax(axis=-1)

def loss_fn(x, classes):
    """Logit margin between the given class and the strongest competing class.

    Args:
        x: batch of inputs (EagerPy tensor), evaluated with the global ``fmodel``.
        classes: one class index per input (the true labels in this notebook).

    Returns:
        Tensor of shape (N,) holding ``logit[classes] - max_{c != classes} logit[c]``.
        A non-negative value means the given class is (one of) the top-scoring
        classes. The attack below tries to push this below zero.
    """
    logits = fmodel(x)
    runner_up = best_other_classes(logits, classes)

    batch_size = len(x)
    row_idx = range(batch_size)
    margins = logits[row_idx, classes] - logits[row_idx, runner_up]
    assert margins.shape == (batch_size,)
    return margins

def es_gradient_estimator(x, y, samples, sigma, clip=False):
    """Estimate the gradient of ``loss_fn(x, y)`` w.r.t. ``x`` by antithetic Gaussian sampling.

    ``samples // 2`` antithetic pairs are drawn. For each noise tensor ``u`` the
    loss is evaluated at ``x + sigma*u`` and at ``x - sigma*u``, and

        grad ~= 1 / (2 * sigma * n_pairs) * sum_k (L(x + sigma*u_k) - L(x - sigma*u_k)) * u_k

    Args:
        x: batch of inputs (EagerPy tensor). It must be rank 4 for the
            ``[:, None, None, None]`` broadcast below -- presumably (N, C, H, W).
        y: labels for ``x``, passed through to ``loss_fn``.
        samples: total number of loss evaluations per input (two per pair).
        sigma: standard deviation of the Gaussian perturbation.
        clip: if True, clip perturbed inputs to the bounds declared on ``fmodel``.

    Returns:
        Tensor like ``x`` with the gradient estimate (all zeros if ``samples < 2``).
    """
    gradient = ep.zeros_like(x)
    n_pairs = samples // 2
    if n_pairs == 0:
        # No pair can be formed; avoid dividing by zero below.
        return gradient

    for _ in range(n_pairs):
        noise = ep.normal(x, shape=x.shape)

        pos_theta = x + sigma * noise
        neg_theta = x - sigma * noise

        if clip:
            # Keep queries inside the model's declared input domain.
            pos_theta = pos_theta.clip(*fmodel.bounds)
            neg_theta = neg_theta.clip(*fmodel.bounds)

        pos_loss = loss_fn(pos_theta, y)
        neg_loss = loss_fn(neg_theta, y)

        gradient += (pos_loss - neg_loss)[:, None, None, None] * noise

    # Each pair is a central difference over a step of 2*sigma; average over pairs.
    gradient /= 2 * sigma * n_pairs

    return gradient

def _es_samples_for_iteration(it):
    """ES sample budget for PGD iteration ``it``: cheap early on, expensive later."""
    if it < 20:
        return 100
    if it < 40:
        return 1000
    return 20000


def gradient_estimator_pgd(images, labels, steps=100, lr=0.01, sigma=8/255):
    """L-inf PGD against ``fmodel`` driven by ES gradient estimates (sign steps).

    Only inputs that are still classified correctly (margin >= 0) are updated
    and queried in each iteration. The loop stops early once every input is
    misclassified.

    Args:
        images: batch of images (torch tensor); moved to the global ``device``.
        labels: true labels for ``images`` (torch tensor).
        steps: maximum number of PGD iterations.
        lr: step size of each sign-gradient step.
        sigma: noise scale passed to ``es_gradient_estimator``.

    Returns:
        (adversarials, mask):
            adversarials -- the final perturbed images, ``images + deltas``
                clipped to [0, 1], with ``|deltas| <= epsilon`` (global);
            mask -- True where the logit margin is still >= 0, i.e. the
                attack has not (yet) succeeded for that input.
    """
    ep_images = ep.astensor(images.to(device))
    ep_labels = ep.astensor(labels.to(device))

    deltas = ep.zeros_like(ep_images)
    adversarials = ep_images
    mask = loss_fn(ep_images, ep_labels) >= 0

    for it in range(steps):
        # Stop once nothing is left to attack. This also avoids calling the
        # estimator on an empty batch when no input was correct to begin with.
        if not mask.any().item():
            break

        samples = _es_samples_for_iteration(it)
        pert_images = (ep_images + deltas).clip(0, 1)
        grads = es_gradient_estimator(pert_images[mask], ep_labels[mask], samples, sigma)

        # Sign-gradient *descent* on the margin, applied only to the masked rows.
        # The update goes through numpy to scatter into the masked subset.
        _deltas = np.array(deltas.numpy())
        _deltas[mask.numpy()] = (deltas[mask] - lr * grads.sign()).numpy()
        deltas = ep.from_numpy(deltas, _deltas).clip(-epsilon, epsilon)

        adversarials = (ep_images + deltas).clip(0, 1)
        new_logit_diffs = loss_fn(adversarials, ep_labels)
        mask = new_logit_diffs >= 0

        values = new_logit_diffs.numpy()
        message = f'({it} / {mask.sum().item()}) {float(new_logit_diffs.mean().raw):.3f}: {np.array2string(values[mask.numpy()], precision=2, separator=",")}'
        logger.debug(message)

    return adversarials, mask

In [None]:
# Robust accuracy: share of test images whose logit margin is still >= 0
# after the ES-PGD attack. A running value is printed after every batch.
acc = 0
total_images = 0

for k, (images, labels) in enumerate(test_loader):
    perturbations, correct = gradient_estimator_pgd(images, labels)

    acc += float(correct.sum().item())
    total_images += len(images)
    running_accuracy = 100 * acc / total_images
    print(f'({k}) model accuracy on perturbed images is {running_accuracy:.1f}')