In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import os
import sys

from tqdm import tqdm

from resnet import ResNet18

In [2]:
# Pick the first CUDA device when one is available, otherwise fall back to CPU.
# (Test the boolean directly rather than comparing it with `== True`.)
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print(torch.cuda.get_device_name())
else:
    device = torch.device('cpu')

GeForce RTX 2070 SUPER


In [3]:
batch_size = 128

# Per-channel normalization constants shared by the train and test pipelines
# (imshow undoes the same values). NOTE(review): this std triple is the widely
# copied (0.2023, 0.1994, 0.2010); the commonly cited CIFAR-10 per-channel std is
# roughly (0.2470, 0.2435, 0.2616). Keep these as long as the saved checkpoints
# were trained with them — TODO confirm before changing.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: padded random crop + horizontal flip for augmentation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])
# Test pipeline: no augmentation, same normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

# tqdm layout used by the evaluation loops; the description goes at the end.
bar_format = '{bar:30} {n_fmt}/{total_fmt} [{elapsed}<{remaining} {rate_fmt}] {desc}'
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [4]:
def imshow(img):
    """Display one image tensor that was normalized like ``transform_test``.

    The per-channel normalization (mean 0.4914/0.4822/0.4465, std
    0.2023/0.1994/0.2010) is undone before plotting. The caller's tensor is
    NOT modified: de-normalization is done on a detached copy, so this is
    also safe for tensors that require grad.

    Args:
        img: channel-first image tensor; presumably shape (3, H, W) — TODO
            confirm for other callers (the transpose below fixes rank 3).
    """
    # Broadcastable (3, 1, 1) statistics on the same dtype/device as the input.
    mean = torch.tensor([0.4914, 0.4822, 0.4465], dtype=img.dtype, device=img.device).view(3, 1, 1)
    std = torch.tensor([0.2023, 0.1994, 0.2010], dtype=img.dtype, device=img.device).view(3, 1, 1)
    # Out-of-place arithmetic: the previous version wrote into img[0..2] and
    # silently un-normalized the caller's tensor.
    restored = img.detach() * std + mean
    # matplotlib clips float RGB to [0, 1] anyway (with a warning); clamp explicitly.
    restored = restored.clamp(0.0, 1.0)
    npimg = restored.cpu().numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [5]:
def save_model(epoch, acc, optimizer, model=None):
    """Checkpoint the model to ``models/model_{epoch}.pth`` when ``acc`` beats the best so far.

    The running best accuracy lives in the notebook-global ``best_acc``. If that
    global has not been defined yet, any accuracy counts as an improvement.
    Previously this raised NameError.

    Args:
        epoch: epoch number; used in the file name and stored in the checkpoint.
        acc: accuracy to compare against ``best_acc``; stored in the checkpoint.
        optimizer: optimizer whose ``state_dict()`` is stored under ``'optimizer'``.
        model: module to save. Defaults to the notebook-global ``model`` so
            existing ``save_model(epoch, acc, optimizer)`` calls keep working.
    """
    global best_acc
    if model is None:
        model = globals()['model']  # fall back to the notebook-global model
    current_best = globals().get('best_acc', float('-inf'))
    if acc > current_best:
        state = {
            'model': model.state_dict(),
            'acc': acc,
            'epoch': epoch,
            'optimizer': optimizer.state_dict()
        }
        # exist_ok avoids the check-then-create race of isdir() + mkdir().
        os.makedirs('models', exist_ok=True)
        torch.save(state, os.path.join('models', f'model_{epoch}.pth'))
        best_acc = acc
        print('Saving Model...')

def load_model(name, lr=1e-1, momentum=0.9, weight_decay=1e-4, load_optimizer_state=False):
    """Rebuild a ResNet18 from ``models/{name}.pth`` and create an SGD optimizer for it.

    Args:
        name: checkpoint file stem (without ``.pth``) inside ``models/``.
        lr, momentum, weight_decay: hyperparameters for the new SGD optimizer
            (defaults match the previously hard-coded values).
        load_optimizer_state: when True, also restore the checkpoint's
            ``'optimizer'`` entry. Per ``Optimizer.load_state_dict``, the saved
            param-group settings (e.g. lr) then replace the ones passed here.
            Defaults to False, matching the previous behavior.

    Returns:
        ``(model, optimizer)`` with the model's weights loaded and moved to ``device``.
    """
    # NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
    checkpoint = torch.load(os.path.join('models', f'{name}.pth'), map_location=device)
    model = ResNet18()
    model.to(device)
    model.load_state_dict(checkpoint['model'])
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    if load_optimizer_state:
        optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer

In [6]:
def fgsm_attack(image, epsilon, data_grad, clamp_min=None, clamp_max=None):
    """One Fast Gradient Sign Method step: ``image + epsilon * sign(data_grad)``.

    Works elementwise, so ``image`` may be a single sample or a whole batch as
    long as ``data_grad`` has the same shape. The input tensor is not modified.

    Args:
        image: input tensor to perturb.
        epsilon: step size, in the same units as ``image``.
        data_grad: gradient of the loss w.r.t. ``image``.
        clamp_min, clamp_max: optional bounds (scalars or broadcastable tensors)
            for the result. Default None means no clamping, which was the
            previous behavior. Inputs in this notebook are normalized, so the
            valid pixel range [0, 1] maps to ``(0 - mean) / std`` ..
            ``(1 - mean) / std`` per channel, not to [0, 1]. That is why a plain
            ``clamp(0, 1)`` would be wrong here.

    Returns:
        The perturbed tensor.
    """
    perturbed_image = image + epsilon * data_grad.sign()
    if clamp_min is not None or clamp_max is not None:
        perturbed_image = torch.clamp(perturbed_image, min=clamp_min, max=clamp_max)
    return perturbed_image

In [46]:
def attack(model, epsilon):
    """Evaluate ``model`` on ``test_loader`` under a one-step FGSM attack.

    For each batch, the gradient of ``loss_function`` w.r.t. the input is
    computed. Samples the model already classifies correctly are perturbed with
    ``fgsm_attack``; misclassified samples are left as-is. The model is then
    re-evaluated on the resulting batch. ``epsilon`` is applied in the
    normalized input space produced by the test transform, not in [0, 1]
    pixel space.

    The progress bar shows running accuracy on the attacked inputs
    (correct/total) followed by the number of samples whose prediction changed.

    Relies on notebook globals: ``test_loader``, ``loss_function``, ``device``,
    ``batch_size``, ``bar_format``, ``tqdm`` and ``fgsm_attack``.

    Args:
        model: classifier to attack; switched to eval mode.
        epsilon: FGSM step size.

    Returns:
        Accuracy (percent) on the attacked test set; 0.0 if the loader is empty.
    """
    model.eval()
    correct = 0
    success = 0
    total = 0
    acc = 0.0  # defined even when the loader yields nothing
    test_iter = tqdm(test_loader, total=len(test_loader), unit_scale=batch_size, bar_format=bar_format)
    for batch, label in test_iter:
        batch, label = batch.to(device), label.to(device)
        batch.requires_grad_(True)
        output = model(batch)
        loss = loss_function(output, label)
        _, predicted = output.max(1)

        model.zero_grad()
        loss.backward()
        batch_grad = batch.grad.detach()

        with torch.no_grad():
            clean = batch.detach()
            # Per-sample mask reshaped to (B, 1, 1, ...) so it broadcasts over
            # the non-batch dims. Only correctly classified samples get
            # perturbed. Building a new tensor avoids in-place writes into the
            # grad-requiring leaf `batch`.
            hit = predicted.eq(label).view(-1, *([1] * (clean.dim() - 1)))
            adversarial = torch.where(hit, fgsm_attack(clean, epsilon, batch_grad), clean)
            # No graph is needed for the re-evaluation pass.
            new_output = model(adversarial)
            _, new_predicted = new_output.max(1)

        total += label.size(0)
        correct += new_predicted.eq(label).sum().item()
        success += new_predicted.ne(predicted).sum().item()

        acc = 100. * correct / total
        test_iter.set_description(f'[{acc:.2f}%({correct}/{total}) {success}]', False)
    return acc

In [47]:
# FGSM Attack: evaluate the saved 'baseline' checkpoint under a one-step FGSM perturbation.
loss_function = nn.CrossEntropyLoss()  # read as a global inside attack()
epsilon = 0.01  # step size, applied to the normalized test inputs
model, optimizer = load_model('baseline')
attack(model, epsilon)

██████████████████████████████ 10112/10112 [00:11<00:00 870.34it/s] [75.08%(7508/10000) 1945]: 


75.08