In [None]:
!pip install gdown

In [None]:
!gdown https://drive.google.com/uc?id=1-7z0lFddFDcy97On7vO5MlucFUiFJYi1

In [None]:
import numpy as np
import json

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as Data

import torchvision.utils
from torchvision import models
import torchvision.datasets as dsets
import torchvision.transforms as transforms

In [None]:
# load data
# NOTE(review): torch.load unpickles the file, and this one was downloaded from
# the internet -- only run this against a file you trust.
# map_location='cpu' keeps the load working on a CPU-only host even if the
# tensors were saved from a GPU; the cells below move each tensor to the
# compute device themselves before using it.
data = torch.load('lab_img.pt', map_location='cpu')

In [None]:
# set device
# Flip use_cuda to False to force CPU execution even when a GPU is present.
use_cuda = True
if use_cuda and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('device', device)

In [None]:
# load resnet101
# NOTE: newer torchvision releases deprecate `pretrained=True` in favour of the
# `weights=` argument; the old spelling is kept so older installs still work.
model = models.resnet101(pretrained=True)

# Put the model on the device chosen in the "set device" cell, so the
# use_cuda switch is honoured and the model and inputs share one device.
model = model.to(device)

# Inference mode: fixes batch-norm statistics and disables dropout.
model = model.eval()

# The network is only ever evaluated / attacked in this notebook, never trained.
# Freezing the weights stops every attack backward pass from also computing
# (unused) gradients for all model parameters.
for param in model.parameters():
    param.requires_grad_(False)

print('use cuda', device.type == 'cuda')

In [None]:
# test clean accuracy
total = 0
correct = 0
# Pure inference: no_grad avoids building an autograd graph for every image.
with torch.no_grad():
    for lab, img in data:
        # Follow the configured device instead of hard-coding .cuda(), which
        # would crash on a CPU-only host.
        lab, img = lab.to(device), img.to(device)
        output = model(img)
        # Predicted class = index of the largest logit along dim 1.
        _, pre = torch.max(output, 1)
        # Count samples, not loop iterations, so entries holding more than
        # one image are scored correctly as well.
        total += lab.numel()
        correct += (pre == lab).sum().item()
print('clean accuracy', correct/total)

In [None]:
# CW-L2 Attack
# Based on the paper, i.e. not the exact same version as the code on https://github.com/carlini/nn_robust_attacks
# NOT IN THIS CODE: (1) binary search for c, (2) keeping the best (smallest L2) adversarial example found.
# The candidate image IS parameterised in tanh space, but w starts at zero rather than at arctanh of the input.
def cw_l2_attack(model, images, labels, targeted=False, c=1e-4, kappa=0, max_iter=1000, learning_rate=0.01) :
    """Craft adversarial images with a Carlini-Wagner style L2 attack.

    Minimises  sum((a - images)^2) + c * f(a)  with Adam, where
    a = 0.5 * (tanh(w) + 1) and w is the optimised variable.

    Args:
        model: callable mapping an image batch to per-class scores of shape
            (batch, num_classes).
        images: clean image batch. Presumably scaled to [0, 1], since the
            adversarial candidate is confined to that range -- confirm
            against the data pipeline.
        labels: integer class indices, one per image. The true class for an
            untargeted attack, the desired class for a targeted one.
        targeted: if True push the prediction towards `labels`, otherwise
            away from it.
        c: weight of the misclassification term relative to the L2 term.
        kappa: confidence margin; the misclassification term is clamped
            below at -kappa.
        max_iter: number of Adam steps.
        learning_rate: Adam learning rate.

    Returns:
        Tensor with the shape of `images` and values in [0, 1], detached from
        the autograd graph and located on the model's device.
    """

    # Run on whatever device the model lives on, so the attack does not depend
    # on a notebook-level global that may disagree with the model's placement.
    try :
        run_device = next(model.parameters()).device
    except (AttributeError, StopIteration) :
        # Parameter-free callable: leave the inputs where they already are.
        run_device = images.device

    images = images.to(run_device)
    labels = labels.to(run_device)

    # Define f-function
    def f(x) :

        outputs = model(x)

        # Boolean mask marking each row's label class, built on the same
        # device as the logits.
        class_ids = torch.arange(outputs.size(1), device=outputs.device)
        is_label = class_ids.unsqueeze(0) == labels.reshape(-1, 1)

        # Largest logit among the NON-label classes. The label slot is filled
        # with -inf; zeroing it by multiplication would let that 0 win the max
        # whenever every other logit is negative.
        i, _ = torch.max(outputs.masked_fill(is_label, float('-inf')), dim=1)
        # Logit of the label class, one value per row.
        j = torch.masked_select(outputs, is_label)

        # If targeted, optimize for making the target class most likely
        if targeted :
            return torch.clamp(i-j, min=-kappa)

        # If untargeted, optimize for making some other class most likely
        return torch.clamp(j-i, min=-kappa)

    # w is created directly on the images' device, so it is a leaf tensor the
    # optimizer can update.
    w = torch.zeros_like(images, requires_grad=True)

    optimizer = optim.Adam([w], lr=learning_rate)

    # Check for divergence ten times per run; max(..., 1) keeps the modulo
    # below valid when max_iter < 10.
    check_every = max(max_iter//10, 1)
    prev = 1e10

    for step in range(max_iter) :

        # Map the unconstrained w into the valid pixel range.
        a = 1/2*(torch.tanh(w) + 1)

        loss1 = nn.MSELoss(reduction='sum')(a, images)
        loss2 = torch.sum(c*f(a))

        cost = loss1 + loss2

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        # Early stop when the loss got worse since the previous check.
        if step % check_every == 0 :
            # Compare plain floats so the previous step's graph is not kept alive.
            cost_value = cost.item()
            if cost_value > prev :
                print('Attack Stopped due to CONVERGENCE....')
                # `a` is the candidate evaluated in this step (before the update).
                return a.detach()
            prev = cost_value

        print('- Learning Progress : %2.2f %%        ' %((step+1)/max_iter*100), end='\r')

    attack_images = 1/2*(torch.tanh(w) + 1)

    return attack_images.detach()

In [None]:
# Attack every image and measure how often the model still predicts the label.
correct = 0
total = 0
advs = []

for lab, img in data:
    # detach(): keep only the pixel values. Without it each stored adversarial
    # example would drag its autograd graph along, growing memory per image.
    adv = cw_l2_attack(model, img, lab, targeted=False, c=0.1).detach()
    lab = lab.to(device)
    # Evaluation only -- no gradients needed here.
    with torch.no_grad():
        outputs = model(adv)
    _, pre = torch.max(outputs, 1)

    # Count samples (not loop iterations) and keep the counters as Python ints.
    total += lab.numel()
    correct += (pre == lab).sum().item()

    advs.append([lab, adv])

print('robust accuracy', correct / total)

In [None]:
# show some clean examples
import matplotlib.pyplot as plt

# 4x4 grid of the first 16 entries of `data`, with axis ticks hidden.
fig, axes = plt.subplots(4, 4, figsize=(8, 10))
for idx, ax in enumerate(axes.flat):
    ax.set_xticks([])
    ax.set_yticks([])
    # Entry [1] is the image part of the pair and [0] takes its first element;
    # permute(1,2,0) moves the leading axis last for imshow
    # (assumes a channel-first image -- TODO confirm).
    clean_img = data[idx][1][0].cpu().permute(1, 2, 0)
    ax.imshow(clean_img, cmap="gray")
plt.tight_layout()
plt.show()

In [None]:
# show some adv examples
import matplotlib.pyplot as plt

# 4x4 grid of the first 16 adversarial images, with axis ticks hidden.
fig, axes = plt.subplots(4, 4, figsize=(8, 10))
for idx, ax in enumerate(axes.flat):
    ax.set_xticks([])
    ax.set_yticks([])
    # Entry [1] is the adversarial image of the [label, image] pair and [0]
    # takes its first element; permute(1,2,0) moves the leading axis last for
    # imshow (assumes a channel-first image -- TODO confirm).
    adv_img = advs[idx][1][0].cpu().permute(1, 2, 0)
    # detach() before numpy(): the tensor may still be attached to a graph.
    ax.imshow(adv_img.detach().numpy(), cmap="gray")
plt.tight_layout()
plt.show()

In [None]:
# Persist the [label, adversarial image] pairs.
# Tensors are detached first so the file holds plain values with no autograd
# state. They stay on the device they were created on, so loading the file
# on a CPU-only host needs torch.load(..., map_location='cpu').
advs_to_save = [[lab.detach(), adv.detach()] for lab, adv in advs]
torch.save(advs_to_save, 'cw-resnet101-advs.pt')