**NOTE: Each section starts with a short theory explanation that I wrote for my own study; feel free to skip it.**

## 1. FGSM Attack
The Fast Gradient Sign Method (FGSM) is one of the simplest methods for generating adversarial examples. The idea is to perturb an input image $x$ by a single step along the *sign* of the gradient of the loss $J$ with respect to $x$, so that (to first order) the loss increases, and the model's prediction is disturbed, as much as possible for a given per-pixel budget. Mathematically, the adversarial example $x_{\text{adv}}$ is generated by:

$$
\begin{equation*}
x_{\text{adv}} = x + \epsilon \cdot \text{sign} \left( \nabla_x J(x,y) \right)
\end{equation*}
$$

* $\epsilon$ controls the magnitude of the perturbation.
* The gradient $\nabla_x J(x,y)$ indicates the direction in pixel space that locally increases the loss the fastest.
* The sign operation discards the gradient's magnitude: every pixel is moved by exactly $\epsilon$, either up (+1) or down (−1), so $\epsilon$ is the $L_\infty$ size of the perturbation.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from functools import partial

In [None]:
batch_size = 64
# Per-split preprocessing: light augmentation (flip + padded random crop) for
# training, plain tensor conversion for evaluation.
transform_dict = dict(
    train=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
    ]),
    test=transforms.Compose([transforms.ToTensor()]),
)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# CIFAR-10 train/test splits (downloaded into ./data if missing) and their loaders.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_dict['train'],
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_dict['test'],
)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

In [None]:
# Define the CNN
class CNN(nn.Module):
    """Small two-convolution classifier producing 10 logits.

    The layer sizes fix the input to 3x32x32: conv k=5 (32 -> 28), max-pool
    (28 -> 14), conv k=3 pad=1 (14 -> 14), avg-pool (14 -> 7), which is what
    the 128*7*7 linear head expects.

    Every forward pass stores the output of ``self.conv`` (the last conv layer,
    before its ReLU) on ``self.last_feature_map``. When that tensor is part of
    an autograd graph, ``retain_grad()`` is called so its gradient is
    available after ``backward()``.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout(p=0.2),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
        ]
        self.conv = nn.Sequential(*feature_layers)
        self.pool = nn.Sequential(nn.ReLU(), nn.AvgPool2d(kernel_size=2), nn.Flatten())
        self.head = nn.Linear(128 * 7 * 7, 10)

    def forward(self, x):
        feats = self.conv(x)
        self.last_feature_map = feats
        # retain_grad() is only legal on tensors that are tracked by autograd.
        if feats.requires_grad:
            feats.retain_grad()
        return self.head(self.pool(feats))

In [None]:
criterion = nn.CrossEntropyLoss()
model = CNN().to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Utility functions to get accuracy
def get_accuracy(model, dataloader, device):
    """Return the top-1 accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and inference runs without gradient
    tracking. The hit count is divided by ``len(dataloader.dataset)``, i.e.
    the number of samples, not batches.
    """
    model.eval()
    hits = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            predictions = model(inputs).argmax(dim=1)
            hits += int((predictions == targets).sum())
    return hits / len(dataloader.dataset)

In [None]:
# ----------------------------
# Class-based Adversarial Attacks and GradCAM
# ----------------------------

# FGSM Attack (Class-based implementation)
class FGSMAttack:
    """Single-step Fast Gradient Sign Method attack.

    Computes ``clamp(x + epsilon * sign(grad_x loss_fn(model(x), y)), 0, 1)``.

    Args:
        model: Network under attack, called as ``model(x)``; its output is
            passed to ``loss_fn`` together with ``y``.
        epsilon: Per-element step size, i.e. the L-infinity size of the
            perturbation (before clamping to [0, 1]).
        loss_fn: Loss to increase; defaults to ``nn.CrossEntropyLoss()``.
    """

    def __init__(self, model, epsilon, loss_fn=None):
        self.model = model
        self.epsilon = epsilon
        self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss()

    def perturb(self, x, y):
        """Return a detached adversarial copy of ``x``.

        Only the gradient with respect to the input is computed (via
        ``torch.autograd.grad``). The model's parameter ``.grad`` buffers are
        neither zeroed nor written, so calling this in the middle of a
        training step cannot clobber the optimizer's gradients.

        Args:
            x: Input batch; the result is clamped to the [0, 1] pixel range.
            y: Labels passed to ``loss_fn``.

        Returns:
            Tensor with the same shape as ``x`` and ``requires_grad=False``.
        """
        x_adv = x.clone().detach().requires_grad_(True)
        loss = self.loss_fn(self.model(x_adv), y)
        (grad,) = torch.autograd.grad(loss, x_adv)
        # Step from a detached copy so the returned tensor carries no graph.
        x_adv = x_adv.detach() + self.epsilon * grad.sign()
        return torch.clamp(x_adv, 0, 1)

## 2. BIM Attack

The Basic Iterative Method (BIM), also known as Iterative FGSM (I-FGSM), extends FGSM by applying the perturbation several times with a smaller step size $\alpha$, starting from $x_{\text{adv}}^{0} = x$. After each iteration, the perturbation is clipped so that the total perturbation remains within the $\epsilon$-ball around the original input. Mathematically, for iteration $i$:

$$
\begin{equation*}
x_{\text{adv}}^{i+1} = \text{clip}_{x,\epsilon} \left( x_{\text{adv}}^{i} + \alpha \cdot \text{sign} \left( \nabla_{x_{\text{adv}}^{i}} J(x_{\text{adv}}^{i}, y) \right) \right)
\end{equation*}
$$

* $\alpha$ is the step size (often set to a small value such as $\epsilon/N$ for $N$ iterations).
* Clipping ensures that $\lVert x_{\text{adv}} - x \rVert_{\infty} \leq \epsilon$ (the code below additionally clamps pixel values to the valid range $[0, 1]$).

In [None]:
# BIM Attack (Iterative FGSM)
class BIMAttack:
    """Basic Iterative Method (iterative FGSM) under an L-infinity budget.

    Runs ``iters`` signed-gradient steps of size ``alpha``. After every step
    the total perturbation is clipped to ``[-epsilon, epsilon]`` around the
    original input, and pixels are clamped to [0, 1].

    Args:
        model: Network under attack, called as ``model(x)``.
        epsilon: Maximum L-infinity distance from the original input.
        alpha: Step size of each iteration.
        iters: Number of iterations; ``0`` returns an unmodified copy.
        loss_fn: Loss to increase; defaults to ``nn.CrossEntropyLoss()``.
    """

    def __init__(self, model, epsilon, alpha, iters, loss_fn=None):
        self.model = model
        self.epsilon = epsilon
        self.alpha = alpha
        self.iters = iters
        self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss()

    def perturb(self, x, y):
        """Return a detached adversarial copy of ``x``.

        Gradients are taken only with respect to the input
        (``torch.autograd.grad``). The model's parameter ``.grad`` buffers are
        left untouched, so the attack is safe to call inside a training loop.

        Args:
            x: Input batch (pixel values expected in [0, 1]).
            y: Labels passed to ``loss_fn``.

        Returns:
            Tensor with the same shape as ``x`` and ``requires_grad=False``.
        """
        # Anchor the epsilon-ball on a detached copy so no graph leaks
        # through x_orig even if the caller's x is tracked by autograd.
        x_orig = x.detach()
        x_adv = x_orig.clone()
        for _ in range(self.iters):
            x_adv.requires_grad_(True)
            loss = self.loss_fn(self.model(x_adv), y)
            (grad,) = torch.autograd.grad(loss, x_adv)
            x_adv = x_adv.detach() + self.alpha * grad.sign()
            # Clip to ensure the overall perturbation stays within [-epsilon, epsilon]
            perturbation = torch.clamp(x_adv - x_orig, min=-self.epsilon, max=self.epsilon)
            x_adv = torch.clamp(x_orig + perturbation, 0, 1)
        return x_adv

In [None]:
# GradCAM Visualization
class GradCAM:
    """Grad-CAM heatmaps for one convolutional layer of ``model``.

    A forward hook caches the layer's output and a full backward hook caches
    the gradient flowing into that output. ``generate`` proceeds as follows:
      1. Weight each activation channel by its spatially averaged gradient.
      2. Sum over channels and apply ReLU.
      3. Upsample to the input's spatial size.
      4. Min-max normalise to [0, 1].

    The hooked layer must produce a 4-D (N, C, H, W) output, because of the
    ``mean(dim=[2, 3])`` and the bilinear interpolation.

    Call ``remove()`` when finished so the hooks stop running on every later
    forward/backward pass of ``model``.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None  # gradient w.r.t. the layer output, from the latest backward pass
        self.activations = None  # layer output, from the latest forward pass
        self._handles = []  # hook handles, kept so the hooks can be removed
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, input, output):
            self.activations = output.detach()

        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0].detach()

        self._handles.append(self.target_layer.register_forward_hook(forward_hook))
        # register_backward_hook is deprecated (its grad_out can be wrong for
        # modules built from several autograd ops); the "full" variant is the
        # supported replacement.
        self._handles.append(self.target_layer.register_full_backward_hook(backward_hook))

    def remove(self):
        """Unregister both hooks from ``target_layer``; safe to call repeatedly."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def generate(self, input_image, target_class):
        """Return a (N, 1, H, W) heatmap with values in [0, 1].

        Args:
            input_image: Batch fed to ``model``; the heatmap is resized to
                ``input_image.shape[2:]``. Only the score of sample 0 is
                back-propagated, so pass a single image (e.g. ``x[i:i+1]``).
            target_class: Index of the logit to explain.
        """
        self.model.zero_grad()
        output = self.model(input_image)
        score = output[0, target_class]
        # Each call runs its own forward pass, so the graph need not be retained.
        score.backward()
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        grad_cam_map = torch.sum(weights * self.activations, dim=1, keepdim=True)
        grad_cam_map = F.relu(grad_cam_map)
        grad_cam_map = F.interpolate(grad_cam_map, size=input_image.shape[2:], mode='bilinear', align_corners=False)
        grad_cam_map = (grad_cam_map - grad_cam_map.min()) / (grad_cam_map.max() - grad_cam_map.min() + 1e-8)
        return grad_cam_map

In [None]:
# ----------------------------
# Integration Examples
# ----------------------------

# 1. FGSM evaluation with an 8/255 L-infinity budget
#    (this FGSMAttack instance replaces the old fgsm_attack function).
fgsm_attack_instance = FGSMAttack(model=model, epsilon=8 / 255)
def get_adversarial_accuracy(model, dataloader, attack, device):
    """Return top-1 accuracy of ``model`` on inputs perturbed by ``attack``.

    ``attack`` must expose ``perturb(x, y)``. It is called outside
    ``torch.no_grad()`` because gradient-based attacks need autograd. Only
    the final classification of the perturbed batch runs without gradients.
    The hit count is divided by the number of samples in
    ``dataloader.dataset``.
    """
    model.eval()
    hits = 0
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        perturbed = attack.perturb(inputs, targets)
        with torch.no_grad():
            predictions = torch.argmax(model(perturbed), dim=1)
            hits += (predictions == targets).sum().item()
    return hits / len(dataloader.dataset)

# NOTE(review): no training step for this `model` is visible before this
# point in the notebook. Confirm one runs first; otherwise the accuracies
# below describe a randomly initialised network.
adv_acc = get_adversarial_accuracy(
    model=model, dataloader=testloader, attack=fgsm_attack_instance, device=device,
)
print("Adversarial Accuracy (FGSM):", adv_acc)

# 2. BIM: 5 signed-gradient steps of size 2/255 inside an 8/255 L-inf ball,
#    used for evaluation here and for adversarial training further down.
bim_attack_instance = BIMAttack(model, epsilon=8 / 255, alpha=2 / 255, iters=5)

adv_acc_bim = get_adversarial_accuracy(
    model=model, dataloader=testloader, attack=bim_attack_instance, device=device,
)
print("Adversarial Accuracy (BIM):", adv_acc_bim)

# 3. GradCAM visualization on a clean image and an adversarial image
# For GradCAM, we choose the last convolutional layer in your CNN.
# In your CNN, the last conv layer is at index 4 in model.conv.
gradcam = GradCAM(model, target_layer=model.conv[4])
x_batch, y_batch = next(iter(testloader))
x_batch, y_batch = x_batch.to(device), y_batch.to(device)

# Generate adversarial image using FGSM for visualization example.
# Detached so later forward/backward passes start from a fresh graph.
adv_image = fgsm_attack_instance.perturb(x_batch, y_batch).detach()
# Use the true label for the clean image and the model's prediction for the adversarial image
clean_label = y_batch[0].item()
with torch.no_grad():  # prediction only; no graph needed
    adv_label = model(adv_image).argmax(dim=1)[0].item()

# Generate heatmaps
heatmap_clean = gradcam.generate(x_batch[0:1], clean_label)
heatmap_adv = gradcam.generate(adv_image[0:1], adv_label)

# Visualize: top row = clean image + its GradCAM, bottom row = adversarial image + its GradCAM
fig, axes = plt.subplots(2, 2, figsize=(10, 10))
axes[0, 0].imshow(x_batch[0].detach().cpu().permute(1, 2, 0))
axes[0, 0].set_title(f"Clean Image: {testset.classes[clean_label]}")
axes[0, 1].imshow(heatmap_clean[0, 0].detach().cpu(), cmap='jet')
axes[0, 1].set_title("GradCAM on Clean Image")
axes[1, 0].imshow(adv_image[0].cpu().permute(1, 2, 0))
axes[1, 0].set_title(f"Adversarial Image: {testset.classes[adv_label]}")
axes[1, 1].imshow(heatmap_adv[0, 0].detach().cpu(), cmap='jet')
axes[1, 1].set_title("GradCAM on Adv Image")
plt.tight_layout()
plt.show()

# 4. Adversarial Training with BIM
def adversarial_training(model, train_loader, optimizer, criterion, device, attack, epochs):
    """Train ``model`` on the sum of its clean-input and adversarial-input losses.

    For every batch:
      1. ``attack.perturb`` builds adversarial inputs (with the model in train mode).
      2. Gradients are zeroed.
      3. ``criterion(clean) + criterion(adv)`` is back-propagated.
      4. ``optimizer.step()`` is taken.

    After each epoch, one line is printed with the mean batch loss and the
    clean/adversarial training accuracy.

    Args:
        model: Network to train; moved to ``device``.
        train_loader: Iterable of ``(inputs, labels)`` batches with ``len()``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking ``(logits, labels)``.
        device: Device for model and batches.
        attack: Object exposing ``perturb(x, y)``. It should wrap this same
            ``model`` for white-box adversarial training.
        epochs: Number of passes over ``train_loader``.
    """
    model.to(device)
    for epoch_idx in range(1, epochs + 1):
        model.train()
        loss_sum = 0.0
        clean_hits = 0
        adv_hits = 0
        seen = 0

        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            # Built before optimizer.zero_grad(), so any backward pass the
            # attack runs cannot leak into this step's parameter update.
            batch_x_adv = attack.perturb(batch_x, batch_y)

            optimizer.zero_grad()
            logits_clean = model(batch_x)
            logits_adv = model(batch_x_adv)
            total_loss = criterion(logits_clean, batch_y) + criterion(logits_adv, batch_y)
            total_loss.backward()
            optimizer.step()

            loss_sum += total_loss.item()
            clean_hits += (logits_clean.max(dim=1).indices == batch_y).sum().item()
            adv_hits += (logits_adv.max(dim=1).indices == batch_y).sum().item()
            seen += batch_y.size(0)

        print(f"Epoch {epoch_idx}/{epochs} Loss: {loss_sum/len(train_loader):.4f} "
              f"Clean Acc: {clean_hits/seen:.4f} Adv Acc: {adv_hits/seen:.4f}")

# Reinitialize or load your model and optimizer as needed before adversarial training
model = CNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# The attack objects were constructed around the PREVIOUS `model` object.
# Without rebinding, BIM would craft adversarial examples against that stale
# network during training, and the final FGSM score would be a transfer
# attack instead of a white-box attack on the model being trained.
# Hyperparameters (epsilon, alpha, iters) are kept as configured above.
bim_attack_instance.model = model
fgsm_attack_instance.model = model
epochs = 10
print("Adversarial Training with BIM Attack")
adversarial_training(model, trainloader, optimizer, criterion, device, bim_attack_instance, epochs)

print("Post-training Clean Accuracy:", get_accuracy(model, testloader, device))
print("Post-training Adv Accuracy (using FGSM):", get_adversarial_accuracy(model, testloader, fgsm_attack_instance, device))