<a href="https://colab.research.google.com/github/awsdevguru/PearsonMLFoundations/blob/main/3_6_03_Hands_on_Lab_Adversarial_Detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Adversarial Attacks on CIFAR-10 with FGSM (Fast Gradient Sign Method)

This notebook demonstrates how to craft adversarial examples on the CIFAR-10 dataset using the Fast Gradient Sign Method (FGSM) against a simple convolutional neural network (CNN) trained with PyTorch.

**Goals**

- Load CIFAR-10 and build a simple CNN classifier
- Train (or load) the model
- Evaluate on clean test images
- Implement FGSM and generate adversarial examples
- Measure adversarial accuracy vs. perturbation strength (epsilon)
- Visualize original vs. adversarial images

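FGSM is a white-box adversarial attack: it perturbs the input image by a step of size $\epsilon$ in the
direction of the sign of the gradient of the loss with respect to the input, i.e. $x_{adv} = x + \epsilon \cdot \operatorname{sign}\left(\nabla_x \mathcal{L}(\theta, x, y)\right)$. Even small perturbations (often imperceptible) can drastically change the model's prediction.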

## 1. Imports and Configuration

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

import torchvision
import torchvision.transforms as transforms

import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path

# Pick the accelerator when one is present; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Per-channel CIFAR-10 mean / std, written directly in (3, 1, 1) shape so they
# broadcast over (C, H, W) image tensors.
cifar10_mean = torch.tensor([[[0.4914]], [[0.4822]], [[0.4465]]])
cifar10_std = torch.tensor([[[0.2470]], [[0.2435]], [[0.2616]]])


## 2. Load the dataset

We will:

- Apply standard CIFAR-10 normalization
- Use basic data augmentation (random crop + horizontal flip) for training
- Build PyTorch DataLoaders for train and test sets

For faster experimentation, you can optionally subsample the training set.

In [None]:
# Transforms for training and test sets.
# Reuse the shared normalization statistics (cifar10_mean / cifar10_std) instead of
# re-typing the numbers, so the data pipeline and the unnormalize/visualization code
# cannot drift apart.
norm_mean = cifar10_mean.flatten().tolist()
norm_std = cifar10_std.flatten().tolist()

transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Download CIFAR-10 into ./data
trainset = torchvision.datasets.CIFAR10(
    root="./data", train=True, download=True, transform=transform_train
)
testset = torchvision.datasets.CIFAR10(
    root="./data", train=False, download=True, transform=transform_test
)

# Optional: subsample training set for quicker runs
train_subset_size = 10_000   # set to None to use full dataset
subset_seed = 0              # fixed seed -> the same subset is drawn on every run

if train_subset_size is not None:
    # A dedicated, seeded generator keeps the subset reproducible without touching
    # NumPy's global RNG state.
    rng = np.random.default_rng(subset_seed)
    indices = rng.permutation(len(trainset))[:train_subset_size].tolist()
    trainset = Subset(trainset, indices)

batch_size = 128

trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

print("Train size:", len(trainset))
print("Test size:", len(testset))


## 3. Define a simple CNN classifier for CIFAR-10

This is a small convolutional network, not meant to be state-of-the-art,
but good enough to illustrate adversarial attacks.


In [None]:
class SimpleCIFAR10CNN(nn.Module):
    """Small CNN: three conv -> ReLU -> 2x2 max-pool stages, then a two-layer MLP head.

    Expects (B, 3, 32, 32) inputs (the 128 * 4 * 4 flatten size assumes three 2x
    downsamplings of a 32x32 image) and returns (B, 10) unnormalized class scores.
    """

    def __init__(self):
        super().__init__()
        # Layer creation order is significant: it fixes the state_dict layout used by
        # saved checkpoints and the order in which parameters are initialized.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)  # parameter-free, shared by every stage
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # Classifier head over the flattened 128-channel 4x4 feature map.
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        # Each stage halves the spatial size: 32 -> 16 -> 8 -> 4.
        for stage_conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(stage_conv(x)))
        flat = x.view(x.shape[0], -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

# Instantiate the network and place it on the selected device.
model = SimpleCIFAR10CNN()
model = model.to(device)
print(model)

# Loss and optimizer used for training.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Where trained weights are saved to / loaded from.
checkpoint_path = Path("cifar10_simplecnn_fgsm_demo.pth")


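## 4. Training and evaluation helpers

The next cell defines `train` (a standard supervised training loop) and `evaluate_accuracy` (accuracy on clean images). Section 5 then uses them, and to keep the notebook flexible:
- If a checkpoint exists, we load it.
- Otherwise, we train for a few epochs and save the model.

Increase the number of epochs for better accuracy (and stronger demonstrations).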


In [None]:
def train(model, trainloader, optimizer, criterion, device, epochs=5):
    """Train ``model`` in place for ``epochs`` passes over ``trainloader``.

    After every epoch, prints the mean per-batch loss and the running training
    accuracy (in percent). When done, prints "Finished Training". The model is left
    in train mode.

    Args:
        model: network to optimize (moved batches must match its device).
        trainloader: iterable of ``(inputs, labels)`` batches supporting ``len()``.
        optimizer: optimizer over ``model``'s parameters.
        criterion: callable ``criterion(logits, labels) -> scalar loss``.
        device: device each batch is moved to.
        epochs: number of passes over the data.
    """
    model.train()
    for epoch_no in range(1, epochs + 1):
        loss_sum, n_correct, n_seen = 0.0, 0, 0

        for batch_x, batch_y in trainloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            optimizer.zero_grad()
            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)
            batch_loss.backward()
            optimizer.step()

            loss_sum += batch_loss.item()
            predicted = logits.max(dim=1).indices
            n_seen += batch_y.size(0)
            n_correct += (predicted == batch_y).sum().item()

        mean_loss = loss_sum / len(trainloader)   # averaged over batches, not samples
        accuracy = 100.0 * n_correct / n_seen
        print(f"Epoch {epoch_no}/{epochs} - loss: {mean_loss:.4f} - acc: {accuracy:.2f}%")
    print("Finished Training")


def evaluate_accuracy(model, dataloader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over ``dataloader``.

    Puts the model in eval mode and runs without gradient tracking. Leaves the model
    in eval mode on return.

    Args:
        model: classifier producing ``(B, num_classes)`` scores.
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: device each batch is moved to.

    Returns:
        float: ``100 * correct / total``.
    """
    model.eval()
    hits = 0
    seen = 0

    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            predicted = model(batch_x).max(dim=1).indices
            hits += (predicted == batch_y).sum().item()
            seen += batch_y.size(0)

    return 100.0 * hits / seen


## 5. Train or Load

In [None]:
TRAIN_MODEL = True  # set False if you only want to load an existing checkpoint

# Prefer an existing checkpoint; otherwise train from scratch (if allowed) and save it.
if checkpoint_path.exists():
    print(f"Loading existing checkpoint from {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state)
elif TRAIN_MODEL:
    print("No checkpoint found. Training model...")
    train(model, trainloader, optimizer, criterion, device, epochs=5)
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Saved checkpoint to {checkpoint_path}")
else:
    raise FileNotFoundError("Checkpoint not found and TRAIN_MODEL is False.")

# Baseline accuracy on unperturbed test images.
clean_acc = evaluate_accuracy(model, testloader, device)
print(f"Clean test accuracy: {clean_acc:.2f}%")


## 6. Implementing FGSM

For each input image:

1. Enable gradient tracking on the input.
2. Compute the loss for the correct labels.
3. Backpropagate to get the gradient of the loss with respect to the input.
4. Take the sign of this gradient.
5. Create an adversarial example by adding a small step in the direction of the sign.

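We clamp the perturbed image back to a bounded range to avoid numerical issues.
Because we are working in *normalized* pixel space, the bound is not [0, 1]. A pixel value $p \in [0, 1]$ in channel $c$ becomes $(p - \mu_c)/\sigma_c$, so valid values lie in $[-\mu_c/\sigma_c,\ (1-\mu_c)/\sigma_c]$. Across the CIFAR-10 channels this is roughly within $[-1.99, 2.13]$.
By default we clamp to the broader range $[-3, 3]$. This covers every channel but does not guarantee that the result is a valid image; clamping each channel to its exact bounds is the strict alternative.
Note also that $\epsilon$ is measured in normalized units. It therefore corresponds to a step of $\epsilon \cdot \sigma_c \approx \epsilon / 4$ in [0, 1] pixel units.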


In [None]:
def fgsm_attack(model, images, labels, epsilon, device, clamp_min=-3.0, clamp_max=3.0):
    """
    Perform FGSM attack on a batch of images.

    Args:
        model: trained classifier
        images: normalized input images (B, C, H, W)
        labels: ground-truth labels
        epsilon: size of perturbation
        device: torch device
        clamp_min, clamp_max: bounds in normalized space

    Returns:
        adv_images: adversarial images
    """
    images = images.clone().detach().to(device)
    labels = labels.to(device)

    images.requires_grad = True

    outputs = model(images)
    loss = criterion(outputs, labels)

    model.zero_grad()
    loss.backward()

    # Sign of the gradient
    grad_sign = images.grad.data.sign()

    # Create perturbed image
    adv_images = images + epsilon * grad_sign
    adv_images = torch.clamp(adv_images, clamp_min, clamp_max)

    return adv_images.detach()


## 7. Evaluating adversarial robustness vs. epsilon

We now:
- Loop over a set of epsilon values
- For each epsilon, craft adversarial examples on the fly for the entire test set
- Measure the resulting adversarial accuracy

This shows how accuracy degrades as the attack strength increases.


In [None]:
def evaluate_fgsm(model, dataloader, device, epsilons):
    """Measure accuracy on FGSM-perturbed inputs for each attack strength.

    For every epsilon, each batch is attacked with ``fgsm_attack``, and the perturbed
    batch is then classified without gradient tracking. The model is put in eval
    mode first.

    Args:
        model: classifier under attack.
        dataloader: iterable of ``(inputs, labels)`` batches (normalized inputs).
        device: device used for both the attack and the evaluation.
        epsilons: iterable of perturbation sizes.

    Returns:
        dict mapping each epsilon to its adversarial accuracy in percent (also printed).
    """
    model.eval()
    results = {}

    for eps in epsilons:
        n_correct, n_total = 0, 0

        for clean_x, y in dataloader:
            clean_x, y = clean_x.to(device), y.to(device)

            # Attack needs gradients; only the scoring pass runs under no_grad.
            adv_x = fgsm_attack(model, clean_x, y, epsilon=eps, device=device)

            with torch.no_grad():
                predicted = model(adv_x).max(1)[1]
                n_total += y.size(0)
                n_correct += (predicted == y).sum().item()

        results[eps] = 100.0 * n_correct / n_total
        print(f"Epsilon {eps:.3f} -> adversarial accuracy: {results[eps]:.2f}%")

    return results


# Sweep of attack strengths (0.0 reproduces clean accuracy as a sanity check).
epsilons = [0.0, 0.01, 0.03, 0.05, 0.1]
epsilon_to_acc = evaluate_fgsm(model, testloader, device, epsilons)


## 8. Plot adversarial accuracy vs. epsilon

Let's visualize how the model's accuracy changes as we increase the perturbation size.


In [None]:
# Order the sweep by epsilon so the line is drawn left to right.
eps_list = sorted(epsilon_to_acc)
acc_list = [epsilon_to_acc[e] for e in eps_list]

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(eps_list, acc_list, marker="o")
ax.set_xlabel("Epsilon (FGSM perturbation size)")
ax.set_ylabel("Accuracy on adversarial examples (%)")
ax.set_title("CIFAR-10 FGSM Attack: Accuracy vs. Epsilon")
ax.grid(True)
plt.show()


## 9. Visualizing original vs. adversarial images

To build intuition, we'll:
- Take a small batch from the test set,
- Generate adversarial examples,
- Undo the normalization,
- Show clean vs. adversarial images side by side with predicted labels.


In [None]:
# Helpers to unnormalize and display images

# Per-channel statistics shaped (3, 1, 1) so they broadcast over (C, H, W) tensors.
# clone().detach() copies the shared constants; torch.tensor(<existing tensor>) would
# do the same but emits a copy-construction UserWarning.
mean = cifar10_mean.clone().detach().view(3, 1, 1)
std = cifar10_std.clone().detach().view(3, 1, 1)

def unnormalize(img_tensor):
    """
    Map an image from normalized space back to displayable pixel values.

    img_tensor: (C, H, W) in normalized space with C == 3. A batched (B, C, H, W)
        tensor also works because the statistics broadcast.
    returns: tensor of the same shape, on the same device as the input, in [0, 1]
    """
    # Move the statistics to the input's device so GPU tensors work without a prior
    # .cpu() call (a no-op for CPU inputs).
    m = mean.to(img_tensor.device)
    s = std.to(img_tensor.device)
    return torch.clamp(img_tensor * s + m, 0.0, 1.0)


def show_images(clean_imgs, adv_imgs, clean_labels, adv_preds, classes, n=5):
    """
    Plot up to ``n`` clean images (top row) above their adversarial versions (bottom row).

    Clean images are titled with their ground-truth class; adversarial images are
    titled with the model's predicted class. Images are unnormalized for display.
    """
    n = min(n, clean_imgs.size(0))
    plt.figure(figsize=(10, 4))

    # (images, labels shown as titles, subplot offset) for each of the two rows.
    rows = (
        (clean_imgs, clean_labels, 0),
        (adv_imgs, adv_preds, n),
    )

    for col in range(n):
        for row_imgs, row_labels, offset in rows:
            plt.subplot(2, n, offset + col + 1)
            chw = unnormalize(row_imgs[col].cpu())
            plt.imshow(chw.numpy().transpose(1, 2, 0))  # (C, H, W) -> (H, W, C)
            plt.axis("off")
            plt.title(classes[row_labels[col].item()])

    plt.tight_layout()
    plt.show()


In [None]:
# Pick one batch from the test set
dataiter = iter(testloader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)

epsilon_demo = 0.03  # choose a moderate epsilon

# Switch to eval mode *before* crafting the attack. fgsm_attack runs the model in
# whatever mode it is in, so the attack gradient is then computed with the same layer
# behaviour that scores the adversarial batch, and this cell no longer depends on
# the mode that earlier cells left the model in.
model.eval()

# Generate adversarial batch
adv_images = fgsm_attack(model, images, labels, epsilon=epsilon_demo, device=device)

# Get predictions for adversarial images
with torch.no_grad():
    outputs_adv = model(adv_images)
    _, preds_adv = outputs_adv.max(1)

print(f"Showing first few images for epsilon = {epsilon_demo}")

show_images(images, adv_images, labels, preds_adv, classes, n=6)
