In [None]:
def fgsm_attack(image, epsilon, data_grad):
    """Apply one Fast Gradient Sign Method step to ``image``.

    Every element is moved by ``epsilon`` in the direction of the sign of the
    matching entry of ``data_grad``; entries whose gradient is exactly zero
    are not moved. The result is then clipped to the interval [0, 1].

    Args:
        image: tensor to perturb.
        epsilon: scalar step size.
        data_grad: gradient of the loss w.r.t. ``image`` (presumably the same
            shape as ``image``; standard torch broadcasting applies otherwise).

    Returns:
        A new tensor holding the perturbed, clipped values.
    """
    # NOTE(review): clipping to [0, 1] assumes un-normalized pixel values;
    # confirm the dataset `transform` does not apply Normalize.
    step = epsilon * torch.sign(data_grad)
    return (image + step).clamp(min=0, max=1)

def adversarial_test(model, test_dataloader_, epsilon):
    """Measure ``model`` accuracy on FGSM-perturbed inputs.

    For every batch, the gradient of the cross-entropy loss with respect to
    the input images is computed. Each image is perturbed with
    ``fgsm_attack`` using budget ``epsilon``, and the model is re-evaluated
    on the perturbed images.

    Args:
        model: classifier whose output is indexed as (batch, class); assumed
            to return logits of shape (batch, num_classes). TODO confirm.
        test_dataloader_: iterable of ``(image, label)`` batches that supports
            ``len()`` (used for the progress-bar total). Any batch size works.
        epsilon: FGSM step size forwarded to ``fgsm_attack``.

    Returns:
        ``(adv_acc, adv_examples)``:
            adv_acc: fraction of iterated samples still classified correctly
                after the attack (0.0 if the loader yields no samples).
            adv_examples: one ``(initial_pred, adversarial_pred,
                perturbed_image)`` tuple per sample misclassified after the
                attack. ``perturbed_image`` is a squeezed NumPy array.
    """
    loss_func = nn.CrossEntropyLoss()
    model.eval()
    adv_correct = 0
    n_samples = 0
    adv_examples = []
    for image, label in tqdm(test_dataloader_, total=len(test_dataloader_)):
        image = image.to(device)
        label = label.to(device)

        # Track gradients w.r.t. the input so the attack direction can be read.
        image.requires_grad = True
        output = model(image)
        init_pred = output.argmax(dim=1)
        loss = loss_func(output, label)
        model.zero_grad()
        loss.backward()
        data_grad = image.grad.detach()
        # perturb the image with the calculated gradient
        perturbed_image = fgsm_attack(image, epsilon, data_grad)

        # Evaluation only: no autograd graph is needed for the second pass.
        with torch.no_grad():
            final_pred = model(perturbed_image).argmax(dim=1)

        # Count each correct sample exactly once. Comparing (B,) with (B,)
        # avoids the (B,1) vs (B,) broadcast, so batches larger than 1 work.
        correct = final_pred == label
        adv_correct += correct.sum().item()
        n_samples += label.size(0)

        for i in (~correct).nonzero(as_tuple=True)[0].tolist():
            adv_ex = perturbed_image[i].squeeze().detach().cpu().numpy()
            adv_examples.append((init_pred[i].item(), final_pred[i].item(), adv_ex))

    # Divide by the samples actually seen (robust to samplers / drop_last).
    adv_acc = adv_correct / n_samples if n_samples else 0.0
    return adv_acc, adv_examples

# MNIST test split; `download_data` and `transform` come from elsewhere in the
# notebook (not shown here).
test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    download=download_data,
    transform=transform,
)
# One sample per batch, in dataset order.
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False
)

# FGSM perturbation budgets to evaluate.
epsilons = [0.1, 0.2, 0.5]
accuracies, examples = [], []

# Run the attack once per epsilon, collecting accuracy and adversarial examples.
for eps in epsilons:
    acc, ex = adversarial_test(model, test_dataloader, eps)
    accuracies.append(acc)
    examples.append(ex)
print(accuracies)