In [None]:
import matplotlib.pyplot as plt
import numpy as np

def show_images(original, adversarial_fgsm, adversarial_pgd, label):
    """Show the clean image side by side with its FGSM and PGD adversarial versions.

    Args:
        original: Tensor holding the unperturbed image.
        adversarial_fgsm: Tensor holding the FGSM-perturbed image.
        adversarial_pgd: Tensor holding the PGD-perturbed image.
        label: Class label of the image, shown as the figure title. May be a
            one-element tensor, a plain Python value, or None (no title).
    """

    def _to_displayable(tensor):
        # Drop singleton dimensions (batch / single channel) and move to host memory.
        array = tensor.squeeze().detach().cpu().numpy()
        # imshow expects colour images as (H, W, C). Only transpose when the
        # leading axis looks like channels and the trailing one does not, so
        # arrays that imshow already accepted are passed through untouched.
        if array.ndim == 3 and array.shape[0] in (3, 4) and array.shape[-1] not in (3, 4):
            array = np.transpose(array, (1, 2, 0))
        return array

    def _label_text(value):
        # One-element tensors/arrays expose .item(); fall back to str() otherwise.
        try:
            return str(value.item())
        except (AttributeError, ValueError, RuntimeError):
            return str(value)

    panels = [
        ("Original", _to_displayable(original)),
        ("FGSM Perturbed", _to_displayable(adversarial_fgsm)),
        ("PGD Perturbed", _to_displayable(adversarial_pgd)),
    ]

    # Plot images
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (title, img) in zip(axes, panels):
        # cmap only affects 2-D (grayscale) arrays; matplotlib ignores it for RGB(A) data.
        ax.imshow(img, cmap="gray")
        ax.set_title(title)
        ax.axis("off")

    if label is not None:
        fig.suptitle(f"Label: {_label_text(label)}")

    plt.show()



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader

# Load a sample dataset (MNIST).
# MNIST images are single-channel; the gray channel is replicated three times
# because MobileNetV2's first convolution expects 3-channel input.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),  # scales pixels to [0, 1], the range the attack clamps to
])
dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)
# A seeded generator makes the shuffled sample order reproducible across kernel restarts.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

# Load a pretrained MobileNetV2 from torch hub.
# NOTE(review): these weights are presumably ImageNet-trained (1000 classes), not an
# MNIST classifier, so MNIST labels 0-9 do not correspond to the model's classes.
# Swap in a model trained on MNIST if attack success against true labels matters.
model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
model.eval()

# FGSM Attack Function
def fgsm_attack(model, image, label, epsilon):
    """Craft an adversarial example with the Fast Gradient Sign Method.

    Each pixel is moved by ``epsilon`` in the direction that increases the
    cross-entropy loss of ``model`` on ``label``; the result is clamped to [0, 1].

    Args:
        model: Callable mapping an image batch to class logits.
        image: Input tensor with values in [0, 1]. It is not modified.
        label: Tensor of target class indices accepted by ``nn.CrossEntropyLoss``.
        epsilon: Per-pixel perturbation magnitude.

    Returns:
        A tensor with the same shape as ``image``, detached from the autograd graph.
    """
    # Work on a detached copy so the caller's tensor is neither flagged as
    # requiring grad nor left with a .grad that accumulates across calls.
    image = image.clone().detach().requires_grad_(True)

    output = model(image)
    loss = nn.CrossEntropyLoss()(output, label)
    model.zero_grad()
    loss.backward()

    perturbation = epsilon * image.grad.sign()  # Compute perturbation
    adv_image = torch.clamp(image + perturbation, 0, 1)  # Ensure valid pixel range
    # Detach so the result can be plotted or fed to another attack without
    # keeping this forward pass's graph alive.
    return adv_image.detach()

# Run the attack on one image
image, label = next(iter(dataloader))
# The DataLoader already collates labels into a 1-D tensor of shape (batch,),
# which is what CrossEntropyLoss expects. Re-wrapping it with torch.tensor([...])
# is redundant and raises for batches larger than one, so only flatten it.
label = label.view(-1)
epsilon = 0.1  # Per-pixel perturbation magnitude (pixels are in [0, 1])
adv_image = fgsm_attack(model, image, label, epsilon)

print("FGSM Attack Completed!")


Using cache found in /home/andrew/.cache/torch/hub/pytorch_vision_v0.10.0
