# White-box Attack on CIFAR10

In [25]:
import sys

import torch
import torch.nn as nn

#sys.path.insert(0, '..')
import torchattacks

In [26]:
# imports
import torch
import torch.nn as nn
import torchvision.models as models
from torch.nn import CrossEntropyLoss
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from utils.metrics import topk_acc, real_acc, AverageMeter
from utils.parsers import get_training_parser
from utils.optimizer import get_optimizer, get_scheduler, OPTIMIZERS_DICT, SCHEDULERS

from models.networks import get_model
from data_utils.data_stats import *

import matplotlib.pyplot as plt
import argparse
import time

In [27]:
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

torch.manual_seed(0)

<torch._C.Generator at 0x1150742f0>

In [28]:
from torchattacks import PGD
from utils_attack import *

# LOAD MLP

In [29]:
dataset = 'cifar10'                 # One of cifar10, cifar100, stl10, imagenet or imagenet21
architecture = 'B_12-Wi_1024'
data_resolution = 32                # Resolution of data as it is stored
crop_resolution = 64                # Resolution of fine-tuned model (64 for all models we provide)
num_classes = CLASS_DICT[dataset]
eval_batch_size = 1024
checkpoint = 'in21k_cifar10'

In [30]:
torch.backends.cuda.matmul.allow_tf32 = True

In [69]:
class MLP_Wrapper(nn.Module):
    def __init__(self, architecture, resolution, num_classes, checkpoint):
        super(MLP_Wrapper, self).__init__()
        self.model = get_model(architecture=architecture, resolution=crop_resolution, num_classes=num_classes, checkpoint=checkpoint)
        self.resize = transforms.Resize((resolution, resolution))

    def forward(self, x):
        x = self.resize(x)
        x = x.flatten(1)
        x = self.model(x)
        return x

In [70]:
model_mlp = MLP_Wrapper(architecture, crop_resolution, num_classes, checkpoint)
model_mlp.eval()

Weights already downloaded
Load_state output <All keys matched successfully>


MLP_Wrapper(
  (model): BottleneckMLP(
    (linear_in): Linear(in_features=12288, out_features=1024, bias=True)
    (linear_out): Linear(in_features=1024, out_features=10, bias=True)
    (blocks): ModuleList(
      (0-11): 12 x BottleneckBlock(
        (block): Sequential(
          (0): Linear(in_features=1024, out_features=4096, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=4096, out_features=1024, bias=True)
        )
      )
    )
    (layernorms): ModuleList(
      (0-11): 12 x LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
  )
  (resize): Resize(size=(64, 64), interpolation=bilinear, max_size=None, antialias=warn)
)

# LOAD CNN

In [71]:
from cifar10_models.resnet import *
model_resnet = resnet18()

# Use the weights of the resnet18 model trained on CIFAR10
weights = torch.load('state_dicts/resnet18.pt')
model_resnet.load_state_dict(weights)

model_resnet.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

# LOAD ViT

In [None]:
import timm
model_ViT_tmp = timm.create_model("vit_base_patch16_384", pretrained=True)
model_ViT_tmp.head = nn.Linear(model_ViT_tmp.head.in_features, 10) # CIFAR10
checkpoint = torch.load("finetuned_models/ViT_CIFAR10.t7", map_location=torch.device("cpu"))
model_ViT_tmp = torch.nn.DataParallel(model_ViT_tmp)
model_ViT_tmp.load_state_dict(checkpoint['model'])
print("model loaded")
# ViT uses 384x384 pixel image
input_size = 384

In [None]:
class ViT_Wrapper(nn.Module):
    def __init__(self, model, input_size):
        super(ViT_Wrapper, self).__init__()
        self.model = model
        self.resize = transforms.Resize(input_size)

    def forward(self, x):
        x = self.resize(x)
        x = self.model(x)
        return x

In [None]:
model_ViT = ViT_Wrapper(model_ViT_tmp, input_size)
model_ViT.eval()

# LOAD DATA

In [105]:
mean = (0.4914, 0.4822, 0.4465)
std = (0.2471, 0.2435, 0.2616)

transform = transforms.Compose([
    #transforms.Resize((crop_resolution, crop_resolution)),
    transforms.ToTensor(),
    #transforms.Normalize(mean, std),
])

In [94]:
#sys.path.insert(0, '..')
import robustbench
from robustbench.data import load_cifar10
from robustbench.utils import load_model, clean_accuracy

images, labels = load_cifar10(n_examples=500, transforms_test=transform)
print('[Data loaded]')

device = "cpu"
acc = clean_accuracy(model_mlp, images.to(device), labels.to(device))
print('[MLP loaded]')
#print('MLP Acc: %2.2f %%'%(acc*100))

acc = clean_accuracy(model_resnet, images.to(device), labels.to(device))
print('[CNN loaded]')
#print('CNN Acc: %2.2f %%'%(acc*100))

Files already downloaded and verified
[Data loaded]
[MLP loaded]
[CNN loaded]


In [95]:
def compare_adv_attcks(model1, model2, atk_model1, images, labels, iterations=10):

    model1.eval()
    model2.eval()

    successful_adv_attacks_model1 = []
    successful_adv_attacks_model2 = []

    for it in range(iterations):
        print("Iteration:", it)
        adv_images = atk_model1(images, labels)
        predictions = get_pred(model1, adv_images, "cpu")

        seccessful_adv_images = adv_images[predictions != labels]
        true_labels_of_succ_adv_images = labels[predictions != labels]
        successful_adv_attacks_model1.append(seccessful_adv_images.shape[0])

        predictions = get_pred(model2, seccessful_adv_images, "cpu")
        seccessful_adv_images = seccessful_adv_images[predictions != true_labels_of_succ_adv_images]
        successful_adv_attacks_model2.append(seccessful_adv_images.shape[0])

    return successful_adv_attacks_model1, successful_adv_attacks_model2




In [96]:
atk = PGD(model_mlp, eps=8/255, alpha=2/225, steps=10, random_start=True)
atk.set_normalization_used(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
l1, l2 = compare_adv_attcks(model_mlp, model_resnet, atk, images, labels, 2)


Iteration: 0
Iteration: 1


In [99]:
atk = PGD(model_resnet, eps=8/255, alpha=2/225, steps=10, random_start=True)
atk.set_normalization_used(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
l11, l22 = compare_adv_attcks(model_resnet, model_mlp, atk, images, labels, 2)

Iteration: 0
Iteration: 1


In [98]:
print(l1, l2)

[500, 499] [410, 406]


In [100]:
print(l11, l22)

[500, 500] [312, 306]


In [None]:
idx = 0
pre = get_pred(model_resnet, adv_images[idx:idx+1], device)
imshow(adv_images[idx:idx+1], title="True:%d, Pre:%d"%(labels[idx], pre))

In [None]:
idx = 0
pre = get_pred(model, adv_images1[idx:idx+1].flatten(1), device)
imshow(adv_images1[idx:idx+1], title="True:%d, Pre:%d"%(labels[idx], pre))