In [24]:
from __future__ import print_function

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np

import torchvision
import torchvision.transforms as transforms

import os
import argparse
import pandas as pd
import csv
import time

from models import *
from utils import progress_bar
from randomaug import RandAugment
from models.vit import ViT
from models.convmixer import ConvMixer

In [25]:
size = 32
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Add RandAugment with N, M(hyperparameter)
aug = True
if aug:  
    N = 2; M = 14;
    transform_train.transforms.insert(0, RandAugment(N, M))

# Prepare dataset
trainset = torchvision.datasets.CIFAR10(root='../../data', train=True, download=False, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=100, shuffle=True, num_workers=8)

testset = torchvision.datasets.CIFAR10(root='../../data', train=False, download=False, transform=transform_test)
test_dataloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=8)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [26]:
test_model = 'flipover'
# test_model = 'vanilla'
if test_model == 'vanilla':
    from models.vit_small import ViT
    dropout_scale = 0.1
else: 
    from models.vit_small_flip_att import ViT
    dropout_scale = 0.2
network = ViT(
    image_size = size,
    patch_size = 4,
    num_classes = 10,
    dim = int(512),
    depth = 6,
    heads = 8,
    mlp_dim = 512,
    dropout = 0.1,
    emb_dropout = 0.1)
network = nn.DataParallel(network)


In [27]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# checkpoint = torch.load('./checkpoint/vit_small-4-vanilla-ckpt.t7')
checkpoint = torch.load('./checkpoint/vit_small_flip_att-4-vanilla-ckpt.t7')  # Adjust the name accordingly
network.load_state_dict(checkpoint['model'])
network.to(device)
positive = 0
negative = 0
network.eval()
# for epsilon in eps:
epsilon = 0.03
loss_fn = nn.CrossEntropyLoss() 
for X, y in test_dataloader:
    X, y = X.to(device), y.to(device)
    X.requires_grad = True
    pred = network(X)
    network.zero_grad()
    loss = loss_fn(pred, y)
    loss.backward()
    X = X + epsilon * X.grad.sign()
    X_adv = torch.clamp(X, 0, 1)
    pred = network(X_adv)
    pred = F.softmax(pred, dim = 1)
    for item in zip(pred, y):
        if torch.argmax(item[0]) == item[1]:
            positive += 1
        else:
            negative += 1

acc_adv = positive / (positive + negative)
print(f"epsilon={epsilon} acc_adv: {acc_adv * 100}% ")

epsilon=0.03 acc_adv: 37.7% 


: 