### Install robustness

In [1]:
!pip install robustness

Collecting robustness
[?25l  Downloading https://files.pythonhosted.org/packages/14/a7/08c95c2adaaac0e12db226f407cd2f1b87626be89ecf8fcf57af334920b4/robustness-1.1.post2-py3-none-any.whl (73kB)
[K     |████████████████████████████████| 81kB 2.2MB/s 
Collecting gitpython
[?25l  Downloading https://files.pythonhosted.org/packages/8c/f9/c315aa88e51fabdc08e91b333cfefb255aff04a2ee96d632c32cb19180c9/GitPython-3.1.3-py3-none-any.whl (451kB)
[K     |████████████████████████████████| 460kB 8.3MB/s 
Collecting GPUtil
  Downloading https://files.pythonhosted.org/packages/ed/0e/5c61eedde9f6c87713e89d794f01e378cfd9565847d4576fa627d758c554/GPUtil-1.4.0.tar.gz
Collecting cox
  Downloading https://files.pythonhosted.org/packages/3a/42/7602dbce4bfc3740c678117c90625d552985b7f6d33423d116837a1a0777/cox-0.1.post2-py3-none-any.whl
Collecting py3nvml
[?25l  Downloading https://files.pythonhosted.org/packages/53/b3/cb30dd8cc1198ae3fdb5a320ca7986d7ca76e23d16415067eafebff8685f/py3nvml-0.2.6-py3-none-any.whl

### Imports

In [0]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from robustness.model_utils import make_and_restore_model
from robustness.datasets import CIFAR
from tqdm import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

### Load and prepare CIFAR-10

In [3]:
# Per-channel statistics handed to Normalize: with mean 0.5 and std 0.5 the
# [0, 1] tensors produced by ToTensor are rescaled to [-1, 1].
_norm_mean = (0.5, 0.5, 0.5)
_norm_std = (0.5, 0.5, 0.5)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Both splits share the same root directory, download flag and transform.
_cifar_kwargs = dict(root='./data', download=True, transform=transform)
trainset = torchvision.datasets.CIFAR10(train=True, **_cifar_kwargs)
testset = torchvision.datasets.CIFAR10(train=False, **_cifar_kwargs)

# Human-readable names for the ten CIFAR-10 class indices.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [0]:
# Dataset object from the `robustness` library, pointed at the 'data' directory.
# It is passed to make_and_restore_model and used to build the data loaders below.
ds = CIFAR('data') 

In [0]:
# All tensors and the model in this notebook are placed on the first CUDA GPU.
device = torch.device("cuda", 0)

### Load model

In [0]:
# Build a ResNet-50 for the CIFAR dataset object and restore its weights from
# the local file 'cifar_l2_0_25.pt'. The second return value is not used in
# this notebook; it is bound to a named throwaway so IPython's `_`
# (last cell output) is not shadowed.
model, _unused = make_and_restore_model(arch='resnet50', dataset=ds, resume_path='cifar_l2_0_25.pt')

# Inference mode on the target device. Assigning the result keeps the cell
# from printing the entire network repr as its output.
model = model.to(device).eval()

In [0]:
# Two loaders over the same data, one image per batch: one in fixed order and
# one shuffled. Only the second loader returned by make_loaders is kept.
_loader_opts = dict(workers=0, batch_size=1)
_, test_loader = ds.make_loaders(shuffle_train=False, shuffle_val=False, **_loader_opts)
_, test_loader_shuffled = ds.make_loaders(shuffle_train=True, shuffle_val=True, **_loader_opts)

In [0]:
def compute_accuracy(logits, y_true, device=None):
    """Return the fraction of rows of `logits` whose argmax equals `y_true`.

    Args:
        logits: 2-D tensor of per-class scores; the prediction for each row is
            the index of its largest entry (argmax over dim 1).
        y_true: tensor of integer class labels, one per row of `logits`.
        device: device the labels are moved to before the comparison. When
            None (default) the device of `logits` is used, so the function
            works for CPU and GPU inputs alike instead of assuming 'cuda:0'.

    Returns:
        0-dim float tensor with the mean accuracy.
    """
    target_device = logits.device if device is None else device
    y_pred = torch.argmax(logits, dim=1)
    y_true_on_device = y_true.to(target_device)
    accuracy = (y_pred == y_true_on_device).float().mean()
    return accuracy

In [0]:
# Mean-squared error between two model outputs; used as the objective of the
# image optimisation in the generation loop.
criterion = nn.MSELoss(reduction='mean')

In [0]:
# Step size of the gradient-descent update applied to the image pixels.
lr = 1e-2

### Generation of robust dataset

In [0]:
# Generation of the robust dataset.
#
# For every image of the ordered loader (the target) a randomly drawn image is
# optimised by gradient descent on its pixels until the model output for it
# matches the model output for the target. The optimised images are collected
# in `robust_ds`, one tensor per target, in loader order.
N_STEPS = 1000  # gradient-descent steps per generated image

robust_ds = []

# zip() advances both loaders together. Calling next(iter(loader)) inside the
# loop would build a fresh iterator each time and, for the unshuffled loader,
# return the very first image on every pass.
for (X_batch, y_batch), (X_rand_batch, y_rand_batch) in tqdm(
        zip(test_loader, test_loader_shuffled), total=len(test_loader)):
    X_batch_gpu = X_batch.to(device)
    X_rand_batch_gpu = X_rand_batch.to(device)

    # The target output does not change during the optimisation, so it is
    # computed once and without an autograd graph.
    with torch.no_grad():
        logits = model(X_batch_gpu)[0]

    for _step in range(N_STEPS):
        # Project back to the valid pixel range and start a fresh leaf tensor
        # on the same device as the model (no round-trip through the CPU).
        X_rand_batch_gpu = X_rand_batch_gpu.detach().clamp(0, 1).requires_grad_(True)
        logits_rand = model(X_rand_batch_gpu)[0]
        loss = criterion(logits_rand, logits)
        grad, = torch.autograd.grad(loss, X_rand_batch_gpu)
        X_rand_batch_gpu = X_rand_batch_gpu - lr * grad

    # Store a clamped, graph-free CPU copy: keeps the saved file free of
    # autograd history and does not pin one GPU tensor per image.
    robust_ds.append(X_rand_batch_gpu.detach().clamp(0, 1).cpu())


In [0]:
# Write the list of generated image tensors to 'robust_dataset.pt' so the
# results of the long generation loop above are kept on disk.
torch.save(robust_ds, 'robust_dataset.pt')

### Explore robust dataset and compare it with original CIFAR-10

In [0]:
# Visual comparison: the top row shows generated images, the bottom row the
# CIFAR-10 test image stored at the same index.
n_pics_to_show = 10
fig, ax = plt.subplots(2, n_pics_to_show, figsize=(20, 10))

for i in range(n_pics_to_show):
    rand_idx = np.random.randint(len(robust_ds))

    # Each robust_ds entry is a batch of one image; [0] drops the batch axis.
    pic = robust_ds[rand_idx][0]
    pic2, label2 = testset[rand_idx]

    # `testset` images went through Normalize(mean=0.5, std=0.5), i.e. they
    # live in [-1, 1]; map them back to [0, 1] so imshow does not clip them.
    pic2 = pic2 * 0.5 + 0.5

    # imshow expects float images in [0, 1] laid out as HxWx3.
    pic_np = pic.detach().cpu().clamp(0, 1).numpy().transpose(1, 2, 0)
    pic_np2 = pic2.clamp(0, 1).numpy().transpose(1, 2, 0)

    ax[0, i].imshow(pic_np)
    ax[0, i].set_title('robust #{}'.format(rand_idx))
    ax[0, i].axis('off')

    ax[1, i].imshow(pic_np2)
    ax[1, i].set_title('CIFAR-10: {}'.format(classes[label2]))
    ax[1, i].axis('off')

fig.tight_layout()

