<a href="https://colab.research.google.com/github/wielandbrendel/robustness_workshop/blob/master/02_mixup/mixup_attack.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# this cell contains all the commands necessary to run this notebook in colab
# if you cloned the repository and run this notebook locally you do not need to run these commands
!wget https://raw.githubusercontent.com/wielandbrendel/robustness_workshop/master/02_mixup/resnet_3layer.py
!wget https://raw.githubusercontent.com/wielandbrendel/robustness_workshop/master/02_mixup/transforms.py

In [0]:
# run this cell the first time you execute this notebook to download the pretrained weights
!wget https://github.com/wielandbrendel/robustness_workshop/releases/download/v0.0.1/mixup_model_IAT.ckpt

In [0]:
# install the latest master version of Foolbox 3.0
!pip3 install git+https://github.com/bethgelab/foolbox.git

In [0]:
!pip3 install --upgrade typing_extensions

In [0]:
import os
import torch
import torchvision
import numpy as np
import foolbox as fb
import eagerpy as ep

import transforms
import resnet_3layer as resnet

In [0]:
# Mixup-inference (MI-OL) hyperparameters.
num_sample_MIOL = 15  # number of random mixup draws averaged per forward pass
lamdaOL = 0.6         # weight on the input image; the sampled pool image gets 1 - lamdaOL

# Fix the RNG seeds: MI-OL samples mixup partners with np.random and the pool
# loader shuffles with torch, so unseeded runs report different accuracies.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

### Load backbone model

In [0]:
# Build the backbone from `resnet_3layer.model_dict` for the 10 CIFAR-10 classes.
CLASSIFIER = resnet.model_dict['resnet50']
classifier = CLASSIFIER(num_classes=10)

# Fall back to CPU when no GPU is available (slow, but the notebook still runs).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
classifier = classifier.to(device)

# map_location lets a checkpoint saved on a GPU load onto whichever device was chosen.
classifier.load_state_dict(torch.load('mixup_model_IAT.ckpt', map_location=device))
classifier.eval();

### Construct image pools

In [0]:
def onehot(ind, num_classes=10):
    """Return a float32 one-hot vector of length ``num_classes`` with a 1 at ``ind``.

    Used as the ``target_transform`` of the CIFAR-10 datasets, so every label
    arrives as a one-hot vector; consumers recover the index with max/argmax.
    """
    vector = np.zeros(num_classes, dtype=np.float32)
    vector[ind] = 1
    return vector

train_trans, test_trans = transforms.cifar_transform()

# The mixup pool must come from the *training* split: with train=False the pool
# was drawn from the very test images that are evaluated below (data leakage).
trainset = torchvision.datasets.CIFAR10(root='~/cifar/',
                                        train=True,
                                        download=True,
                                        transform=train_trans,
                                        target_transform=onehot)
testset = torchvision.datasets.CIFAR10(root='~/cifar/',
                                       train=False,
                                       download=True,
                                       transform=test_trans,
                                       target_transform=onehot)

# We reduce the test set for this workshop. Truncate images *and* labels together
# so the two stay aligned.
num_test_samples = 200
testset.data = testset.data[:num_test_samples]
testset.targets = testset.targets[:num_test_samples]

# One image per step; the pool cell keeps only the first num_pool shuffled images.
dataloader_train = torch.utils.data.DataLoader(
    trainset,
    batch_size=1,
    shuffle=True,
    num_workers=2)

dataloader_test = torch.utils.data.DataLoader(
    testset,
    batch_size=10,
    shuffle=False,
    num_workers=5)

In [0]:
from itertools import islice

from tqdm.auto import tqdm

num_pool = 10000
# One list of images per CIFAR-10 class; MI-OL draws its mixup partners from here.
mixup_pool_OL = {label: [] for label in range(10)}

# The loader yields one image per step (batch_size=1). Collect at most num_pool of
# them; the progress total also stays correct when the dataset is shorter.
pool_size = min(num_pool, len(dataloader_train))
for img_batch, label_batch in tqdm(islice(dataloader_train, pool_size), total=pool_size):
    label = label_batch.argmax(dim=1).item()  # labels arrive one-hot encoded
    mixup_pool_OL[label].append(img_batch.to(device))

# Every class needs at least one image, otherwise MI-OL cannot sample a partner from it.
empty_classes = [label for label, images in mixup_pool_OL.items() if not images]
assert not empty_classes, f'mixup pool has no images for classes {empty_classes}'

print('Finish constructing mixup_pool_OL')

### Construct a model that wraps MI-OL (mixup inference with images of *other labels*) around the classifier

In [0]:
import torch.nn as nn
import torch.nn.functional as F

# Softmax over the last dimension, applied to classifier outputs.
soft_max = nn.Softmax(-1)

class CombinedModel(nn.Module):
    """Wrap a classifier with mixup inference using *other labels* (MI-OL).

    For every input image, ``forward`` draws a random pool image whose class
    differs from the classifier's clean prediction. It mixes that image in as
    ``lam * x + (1 - lam) * partner`` and averages the softmax outputs over
    ``num_samples`` such draws.

    Args:
        classifier: the wrapped ``nn.Module``, used for both the clean pass and
            every mixed forward pass.
        pool: dict mapping class index ``0..len(pool)-1`` to a list of pool
            tensors; ``entry[0]`` must have the shape of one input image.
            Defaults to the notebook-level ``mixup_pool_OL``.
        num_samples: number of mixup draws averaged. Defaults to ``num_sample_MIOL``.
        lam: weight on the input image. Defaults to ``lamdaOL``.

    ``None`` defaults are looked up on every call, so later changes to those
    notebook globals still take effect.
    """

    def __init__(self, classifier, pool=None, num_samples=None, lam=None):
        super(CombinedModel, self).__init__()
        self.classifier = classifier
        self.pool = pool
        self.num_samples = num_samples
        self.lam = lam

    @staticmethod
    def _sample_partner(pool, exclude_label):
        """Return a random (detached) pool image whose class is not ``exclude_label``."""
        num_classes = len(pool)
        if num_classes < 2:
            raise ValueError('MI-OL needs images from at least two classes in the pool')
        # Rejection sampling over class indices 0..num_classes-1.
        label = np.random.randint(num_classes)
        while label == exclude_label:
            label = np.random.randint(num_classes)
        candidates = pool[label]
        return candidates[np.random.randint(len(candidates))][0].detach()

    def forward(self, img_batch):
        """Return MI-OL averaged class probabilities for ``img_batch``."""
        pool = mixup_pool_OL if self.pool is None else self.pool
        num_samples = num_sample_MIOL if self.num_samples is None else self.num_samples
        lam = lamdaOL if self.lam is None else self.lam

        # The clean prediction only decides which class is excluded as a partner,
        # so no gradient needs to flow through it.
        with torch.no_grad():
            predicted_cle = self.classifier(img_batch).argmax(dim=1).cpu().numpy()

        pred_sum = 0
        for _ in range(num_samples):
            # Build the partner batch directly on img_batch's device/dtype, avoiding a
            # per-image GPU -> NumPy -> GPU round trip.
            partners = torch.empty_like(img_batch)
            for b, label in enumerate(predicted_cle):
                partners[b] = self._sample_partner(pool, label)
            mixup_img_batch = (1 - lam) * partners + lam * img_batch
            # Always use the wrapped classifier (not a notebook global), so this
            # module works for whatever classifier it was constructed with.
            pred_sum = pred_sum + torch.softmax(self.classifier(mixup_img_batch), dim=-1)

        return pred_sum / num_samples

In [0]:
# Wrap the pretrained classifier with MI-OL; `.eval()` returns the module itself.
combined_classifier = CombinedModel(classifier).eval()

In [0]:
# Foolbox views of the bare classifier and of the MI-OL wrapper, in that order.
# bounds=(-1, 1) assumes transforms.cifar_transform maps pixels into [-1, 1] — TODO confirm.
iAT_model, iAT_OL_model = (
    fb.models.PyTorchModel(net, bounds=(-1, 1), device=device)
    for net in (classifier, combined_classifier)
)

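### Oblivious attack

The attacker is unaware of the defense: PGD adversarials are crafted against the bare classifier (`iAT_model`) and then evaluated on the MI-OL wrapped model (`iAT_OL_model`).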

In [0]:
# Clean (unattacked) accuracy of the MI-OL wrapped model on the reduced test set.
# Each batch accuracy is weighted by its size, so a smaller last batch counts correctly.
acc = 0
total_samples = 0

for x_batch, y_batch in dataloader_test:
    x_batch, y_batch = x_batch.to(device), y_batch.argmax(1).to(device)
    batch_size = x_batch.shape[0]
    acc += batch_size * fb.utils.accuracy(iAT_OL_model, x_batch, y_batch)
    total_samples += batch_size

print(f'Clean accuracy: {acc / total_samples:.3f}')

In [0]:
acc = 0
total_samples = 0
epsilon = 8 / 255

attack = fb.attacks.LinfPGD()

for images, labels in dataloader_test:
    images, labels = images.to(device), labels.argmax(1).to(device)
    N = len(images)

    # Oblivious setting: craft PGD against the bare classifier (iAT_model), then
    # evaluate the MI-OL wrapped model on the resulting images.
    # The attack returns (1) raw adversarials, (2) adversarials clipped to the
    # epsilon region and (3) a boolean mask of which perturbations are adversarial.
    # The budget is doubled; the model bounds are (-1, 1), twice the width of a
    # [0, 1] range — presumably why epsilon is scaled by 2.
    criterion = fb.criteria.Misclassification(labels)
    adv, adv_clipped, adv_mask = attack(iAT_model, images, criterion=criterion, epsilons=2 * epsilon)

    acc += fb.utils.accuracy(iAT_OL_model, adv_clipped, labels) * N
    total_samples += N

print()
print(f'Oblivious adversarial accuracy: {acc / total_samples:.3f}')