In [1]:
import torch 
import torchvision 
import torchvision.transforms as transforms

In [3]:
tensor_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

test_data = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=tensor_transform)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)

Files already downloaded and verified


In [5]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.bn1(self.conv1(x))))
        x = self.pool(F.relu(self.bn2(self.conv2(x))))
        x = x.view(-1, 64 * 8 * 8)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

net = Net()

In [7]:
net.load_state_dict(torch.load('cifar_net.pth'))
net.eval()

Net(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=4096, out_features=512, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=512, out_features=10, bias=True)
)

In [9]:
def fgsm_attack(image, epsilon, data_grad):
    sign_data_grad = data_grad.sign()
    perturbed_image = image + epsilon*sign_data_grad
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    return perturbed_image

def denorm(batch, mean=[0.1307], std=[0.3081]):
    if isinstance(mean, list):
        mean = torch.tensor(mean)
    if isinstance(std, list):
        std = torch.tensor(std)

    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)



In [11]:
def test( model, test_dataloader, epsilon ):
    correct = 0
    adv_examples = []
    for data, target in test_dataloader:
        data.requires_grad = True

        output = model(data)
        init_pred = output.max(1, keepdim=True)[1]
        for i in range(len(target)):
            if init_pred[i].item() != target[i].item():
                continue
                
            loss = F.nll_loss(output, target)
            model.zero_grad()
            loss.backward()
            data_grad = data.grad.data
            data_denorm = denorm(data)
            perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)
            perturbed_data_normalized = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(perturbed_data)
            output = model(perturbed_data_normalized)
    
            final_pred = output.max(1, keepdim=True)[1]
            if final_pred[i].item() == target[i].item():
                correct += 1
                if epsilon == 0 and len(adv_examples) < 5:
                    adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                    adv_examples.append( (init_pred[i].item(), final_pred[i].item(), adv_ex) )
                else:
                    if len(adv_examples) < 5:
                        adv_ex = perturbed_data.squeeze().detach().cpu().numpy()
                        adv_examples.append( (init_pred[i].item(), final_pred[i].item(), adv_ex) )

    final_acc = correct/float(len(test_dataloader))
    print(f"Epsilon: {epsilon}\tTest Accuracy = {correct} / {len(test_dataloader)} = {final_acc}")

    return final_acc, adv_examples

In [13]:
accuracies = []
examples = []

epsilons = [0, .05, .1, .15, .2, .25, .3]

for eps in epsilons:
    acc, ex = test(net, test_dataloader, eps)
    accuracies.append(acc)
    examples.append(ex)

Epsilon: 0	Test Accuracy = 4268 / 10000 = 0.4268
Epsilon: 0.05	Test Accuracy = 487 / 10000 = 0.0487
Epsilon: 0.1	Test Accuracy = 472 / 10000 = 0.0472
Epsilon: 0.15	Test Accuracy = 625 / 10000 = 0.0625
Epsilon: 0.2	Test Accuracy = 714 / 10000 = 0.0714
Epsilon: 0.25	Test Accuracy = 767 / 10000 = 0.0767
Epsilon: 0.3	Test Accuracy = 790 / 10000 = 0.079
