# Mount your own Google Drive

Allow Google Drive for desktop full access to your Google Account. (You can remove it later again in your Google account settings.)

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

# Clone the repository to your Google Drive

Next, we change the directory to your Google Drive content drive, clone the public repository that our group prepared.

In [None]:
%cd gdrive/MyDrive/
! git clone https://github.com/UzL-PrivSec/summer-school-2024.git
%cd summer-school-2024/
%cd adversarial_examples/

In [1]:
import numpy as np
import torch.nn.functional as F
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

from data.MnistNet import Net

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Net().to(device)
model.load_state_dict(torch.load("./data/mnist_cnn.pt"))  # Load pre-trained weights
model.eval()  # Avoid any backpropagation

dataset = datasets.MNIST('./data', train=False, download=True, transform=transforms.ToTensor())
data_loader = torch.utils.data.DataLoader(dataset, **{"batch_size": 1, "shuffle": True})

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [None]:
eps = 1  # Upper bound for the noise
eps_step = 1/20  # Increase of the noise

# In case of a targeted attack we need to specify the targeted class
target_label = None
#target_label = torch.Tensor([0]).type(torch.LongTensor).to(device)
targeted = target_label != None  # True if <target_label> != None else False

In [None]:
def get_grad_signs(x, y):
    """
    Feed the given input into the model and get the gradient wrt the desired label.
    The FGSM only requires the signs of the gradient.

    :param data: A single image
    :param label: The desired output label

    :return: The sign of the created gradient wrt the desired label
    """

    # Forward propagate input
    output = model(x)

    # Create desired gradients
    # Note: Use <F.nll_loss> as loss funtion
    loss = # TODO a) #
    model.zero_grad()
    loss.backward()

    # Get signs of gradient [[0.5, 0, -0.7, 3], ...] -> [[1, 0, -1, 1], ...]
    # Note: Use <x.grad.data> to access the gradient
    sign = # TODO a) #

    # Invert every sign in case of a targeted attack
    if targeted:
        # TODO b) #
        pass

    return sign


def apply_perturbation(x, gradient_signs, eps):
    """
    Apply a perturbation onto the given image.
    The strength of the applied perturbation is controlled by <eps>.

    :param x: A single image
    :param gradient_signs: The signs of the gradient with the same shape as the input image
    :param eps: A skalar which controlles the strength of the applied perturbation

    :return: A modified version of the given image
    """

    # Calculate perturbation -> eps * sign(gradient)
    perturbation = # TODO a) #

    # Add perturbation to the image
    x = x + perturbation

    # Ensure that every pixel value is still in the interval [0,1]
    x = # TODO a) #

    return x

In [None]:
def minimal_perturbation(x, y):
    """
    Try to find a perturbation based on a minimal epsilon such that...
    a) Untargeted attack: the model classifies the given image into an arbitrary class
    b) Targeted attack: the model classifies the given image into the class <target_label>

    :param x: A single image
    :param y: The original class label of the given image

    :return: (adversarial example, new label) as tuple or None if no adversarial example could be generated
    """

    # In case of a targeted attack set the desired label
    if targeted:
        y = target_label

    # Get signs of the desired gradient
    gradient_signs = get_grad_signs(x, y)

    current_eps = eps_step
    partial_stop_condition = current_eps <= eps

    while partial_stop_condition:

        current_adv_x = apply_perturbation(x, gradient_signs, current_eps)

        # Predict new label
        adv_preds = # TODO a) #
        new_label = adv_preds.argmax(dim=1, keepdim=True)[0].cpu().detach()

        # Untargeted attack: Check if we get another arbitrary label
        if not targeted:
            flipped = # TODO a) #
        # Targeted attack: Check if we get the desired label
        else:
            flipped = # TODO b) #

        # Update current eps and check the stop condition
        current_eps += eps_step
        partial_stop_condition = current_eps <= eps

        # If we successfully generated an adversarial example then save it in combination with its new label
        if flipped:
            return current_adv_x.detach().cpu(), new_label

    return None

In [None]:
def generate(data_loader, device):
    """
    Iterates the given dataset and tries to find an adversarial example for every image.

    :param data_loader: Dataset packed in an iterable data loader
    :param device: PyTorch device

    :return: A list of adversarial examples with the following structure
             [(original_img, original_label, adversarial_img, adversarial_label), ...]
    """

    x_advs = list()

    for i, (data, label) in enumerate(data_loader):

        data, label = data.to(device), label.to(device)
        data.requires_grad = True  # To generate gradients...

        # This does not make much sense...
        if targeted:
            if label == target_label:
                continue

        # Search for an adversarial example
        x_adv = minimal_perturbation(data, label)
        if x_adv != None:
            x_advs.append( (data.detach().cpu(), label.detach().cpu(), *x_adv) )

        # You don't have to iterate all 10,000 images...
        if i >= 10:
            break

    return x_advs

In [None]:
adv_examples = generate(data_loader, device)

print(f"Found {len(adv_examples)} adversarial examples...")

# Use this to plot some images and their corresponding adversarial versions
for org, org_label, adv, adv_label in adv_examples:

    plt.imshow(org[0][0])
    plt.title(f"Label: {org_label}")
    plt.show()

    plt.imshow(adv[0][0])
    plt.title(f"Label: {adv_label}")
    plt.show()