In [23]:
import os

import numpy as np
import torch
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
from matplotlib import cm
from sklearn.model_selection import StratifiedShuffleSplit
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.transforms import InterpolationMode

import foolbox as fb
from foolbox import PyTorchModel

from robustness_helpers.config import DATA_ROOT, class_names
from robustness_helpers.data_preprocessing import (
    get_cifar10_loaders,
    get_cifar10_full_dataset_and_indices,
)
from robustness_helpers.model_utils import load_cifar10_model, get_device


In [24]:
# Build the train / validation / test DataLoaders with the project helper's default settings.
train_loader, val_loader, test_loader = get_cifar10_loaders()

In [25]:
# Compute device chosen by the project helper; later cells compare it against "mps" and "cuda".
device = get_device()

In [26]:
# Load the model from the named checkpoint, passing the selected device to the loader helper.
# NOTE(review): the file name suggests val loss 0.1992 / val accuracy 93.86% — confirm against the training run.
model = load_cifar10_model("rdnet_tiny_transfer_learn__valLoss0.1992_valAcc93.86.pth", device)
# model

### Foolbox Evaluation Setup (No Normalization)

This section prepares a test DataLoader for evaluating adversarial robustness using **Foolbox**. 

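Normalization is intentionally **omitted** from the transform pipeline because the Foolbox model is created with `bounds=(0, 1)`, so image tensors must stay in the `[0, 1]` range. Normalized inputs (e.g., with ImageNet mean/std) would fall outside these bounds and invalidate the epsilon-based perturbation calculations. If the model expects normalized inputs, the normalization can instead be supplied through the `preprocessing` argument of `fb.PyTorchModel`.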

#### Transform Pipeline
- `Resize(248)` with bicubic interpolation (for scaling consistency).
- `CenterCrop(224)` to match input size expectations for ImageNet-trained models.
- `ToTensor()` converts PIL images to tensors **without normalization**.

#### Dataset
- Uses the **CIFAR-10** test set (`train=False`).
- Downloaded automatically if not present in the specified `DATA_ROOT`.
- Applies the defined no-normalization transform.

#### DataLoader
- `batch_size=16`
- `shuffle=False` (important for reproducibility in adversarial settings)
- `num_workers=0` loads the data in the main process, without worker subprocesses (can be adjusted).


In [27]:
def _build_no_norm_transform():
    """Build the evaluation transform: bicubic resize to 248, 224x224 center crop, ToTensor.

    No normalization step is included, so the attack cells below can use
    Foolbox bounds of (0, 1).
    """
    resize = transforms.Resize(
        size=248, interpolation=InterpolationMode.BICUBIC, antialias=True
    )
    center_crop = transforms.CenterCrop(size=(224, 224))
    return transforms.Compose([resize, center_crop, transforms.ToTensor()])


# Loader settings shared by the test and validation loaders.
# shuffle=False keeps the batch order fixed between runs.
_NO_NORM_LOADER_KWARGS = dict(batch_size=16, shuffle=False, num_workers=0)

full_dataset, train_idx, val_idx = get_cifar10_full_dataset_and_indices()

# --- Test split -------------------------------------------------------------
test_transform_no_norm = _build_no_norm_transform()
test_dataset_no_norm = CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=test_transform_no_norm
)
test_loader_no_norm = DataLoader(test_dataset_no_norm, **_NO_NORM_LOADER_KWARGS)

# --- Validation split -------------------------------------------------------
val_transform_no_norm = _build_no_norm_transform()
# A Subset only wraps the dataset it is given, so the transform is replaced on
# `full_dataset` itself; anything else reading from `full_dataset` sees it too.
full_dataset.transform = val_transform_no_norm
val_dataset_no_norm = Subset(full_dataset, val_idx)
val_loader_no_norm = DataLoader(val_dataset_no_norm, **_NO_NORM_LOADER_KWARGS)

In [28]:
def evaluate_foolbox_attack(model, loader, device, attack, epsilon=0.03, num_batches=5):
    """Measure how often a Foolbox attack succeeds on the first batches of a loader.

    Args:
        model: ``torch.nn.Module`` to attack. It is put in eval mode and moved
            in place to the device the attack runs on.
        loader: Iterable yielding ``(images, labels)`` tensor batches. The
            Foolbox wrapper is built with ``bounds=(0, 1)``, so images are
            expected to lie in that range.
        device: Device name or ``torch.device``. An ``mps`` device is redirected
            to the CPU; any other device is used as given.
        attack: Callable invoked as
            ``attack(fmodel, images, labels, epsilons=epsilon)`` and returning
            ``(raw, clipped, success)``.
        epsilon: Single perturbation budget forwarded to the attack.
        num_batches: Maximum number of batches to evaluate.

    Returns:
        float: Fraction of evaluated samples flagged as successful by the
        attack, or ``float("nan")`` when no sample was evaluated.
    """
    # Normalising through torch.device accepts both strings ("cuda:0") and
    # torch.device objects, and lets us branch on the device *type*.
    device_type = torch.device(device).type
    # Previously only the "mps" branch assigned device_to_use, so any other
    # device raised UnboundLocalError.
    device_to_use = "cpu" if device_type == "mps" else device

    model.eval()
    model.to(device_to_use)

    # NOTE(review): no `preprocessing` is passed, so the model receives raw
    # [0, 1] inputs. If it was trained on normalized inputs, supply
    # preprocessing=dict(mean=..., std=..., axis=-3) here — confirm.
    fmodel = fb.PyTorchModel(model, bounds=(0, 1), device=device_to_use)

    total = 0
    successful = 0

    for batch_index, (images, labels) in enumerate(loader):
        if batch_index >= num_batches:
            break

        images, labels = images.to(device_to_use), labels.to(device_to_use)
        raw_adv, clipped_adv, success = attack(fmodel, images, labels, epsilons=epsilon)

        total += len(labels)
        successful += success.sum().item()

        # Drop references early so cached device memory can be released below.
        del images, labels, raw_adv, clipped_adv, success

    if device_type == "mps":
        torch.mps.empty_cache()
    elif device_type == "cuda":
        torch.cuda.empty_cache()

    if total == 0:
        # Empty loader or num_batches <= 0: the rate is undefined, not zero.
        print("\n[Foolbox] No samples were evaluated; attack success rate is undefined.\n")
        return float("nan")

    success_rate = successful / total
    print(
        f"\n[Foolbox] Attack Success Rate: {success_rate * 100:.4f}% (epsilon={epsilon}) Attack Type: {attack}\n"
    )
    return success_rate

In [29]:
# Attack objects reused by the evaluation and plotting cells below. The
# evaluation calls are left commented out; their recorded results are
# summarised in the Markdown cell that follows.
# NOTE(review): epsilon is interpreted in each attack's own norm, so 0.03 for
# the L2 attacks is presumably a much smaller budget than 0.03 for the Linf
# attacks — confirm the comparison is intended.

# Single-step fast gradient sign attack.
fgsm_attack = fb.attacks.FGSM()
# evaluate_foolbox_attack(model, test_loader_no_norm, device, fgsm_attack, epsilon=0.03)

# Projected gradient descent under the Linf norm, 20 steps.
linfpgd_attack = fb.attacks.LinfPGD(steps=20)
# evaluate_foolbox_attack(
#     model, test_loader_no_norm, device, linfpgd_attack, epsilon=0.03
# )

# DeepFool under the L2 norm with Foolbox's default settings.
deepfool_attack = fb.attacks.L2DeepFoolAttack()
# evaluate_foolbox_attack(
#     model, test_loader_no_norm, device, deepfool_attack, epsilon=0.03
# )

# Projected gradient descent under the L2 norm, 20 steps.
l2pgd_attack = fb.attacks.L2PGD(steps=20)
# evaluate_foolbox_attack(model, test_loader_no_norm, device, l2pgd_attack, epsilon=0.03)

**[Foolbox] Attack Results:**

---

**Attack 1:**
- **Attack Success Rate**: **87.5000%** (epsilon=0.03)  

- **Attack Type**: `LinfFastGradientAttack(rel_stepsize=1.0, abs_stepsize=None, steps=1, random_start=False)`

---

**Attack 2:**
- **Attack Success Rate**: **100.0000%** (epsilon=0.03)  

- **Attack Type**: `LinfProjectedGradientDescentAttack(rel_stepsize=0.03333333333333333, abs_stepsize=None, steps=20, random_start=True)`

---

**Attack 3:**
- **Attack Success Rate**: **56.2500%** (epsilon=0.03)  

- **Attack Type**: `L2DeepFoolAttack(steps=50, candidates=10, overshoot=0.02, loss=logits)`

---

**Attack 4:**
- **Attack Success Rate**: **56.2500%** (epsilon=0.03)  

- **Attack Type**: `L2ProjectedGradientDescentAttack(rel_stepsize=0.025, abs_stepsize=None, steps=20, random_start=True)`

---


### Foolbox Adversarial Attack Image Evaluation

This function evaluates the performance of different **Foolbox attacks** on a given model. It generates **adversarial images** and compares them to the original images by showing the differences and intensity distributions.

#### Steps:
1. **Prepare the model and data**:
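   - The model is set to evaluation mode and moved to the device used for the attack (the requested device, or the CPU when the requested device is `mps`).
   - The first batch of images and labels is taken from the `data_loader`, and its second image (index 1) is the one that gets attacked.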

2. **Foolbox attack**:
   - The function applies the given attack (e.g., FGSM, PGD, DeepFool) on the model using the `Foolbox` library.

3. **Generate images**:
   - It generates:
     - The **original image**.
     - The **adversarial image** (after the attack).
     - A **difference heatmap** to visualize how the adversarial image deviates from the original.

4. **Intensity histograms**:
   - It also plots the intensity distributions of:
     - The **original image**.
     - The **adversarial image**.
     - The **difference** between the original and adversarial images.

5. **Save output**:
   - The generated images and histograms are saved to a folder named `attack_images` with the attack type in the filename.

#### Example of attacks evaluated:
- **FGSM Attack** (Fast Gradient Sign Method)
- **LinfPGD Attack** (Projected Gradient Descent with Linf norm)
- **DeepFool Attack**
- **L2PGD Attack** (Projected Gradient Descent with L2 norm)

Each evaluation generates a PNG image showing:
- The **original** vs **adversarial image**.
- A **difference heatmap**.
- **Intensity histograms** for the original, adversarial, and difference images.


In [30]:
# Default output folder for the attack figures; created up front (no error if it already exists).
save_path = "attack_images"
os.makedirs(save_path, exist_ok=True)

def evaluate_foolbox_attack_image(
    model,
    data_loader,
    device,
    attack,
    epsilon=0.03,
    save_path="attack_images",
    image_index=1,
):
    """Attack one image and save a figure comparing it with its adversarial version.

    The figure has two rows: original image, adversarial image and a
    difference heatmap on top; intensity histograms of the same three arrays
    below. It is written to
    ``<save_path>/<AttackClassName>_attack_image_with_intensity.png``.

    Args:
        model: ``torch.nn.Module`` to attack. It is put in eval mode and moved
            in place to the device the attack runs on.
        data_loader: Iterable yielding ``(images, labels)`` tensor batches; only
            the first batch is read. Images are expected to be channel-first
            and within the Foolbox bounds ``(0, 1)``.
        device: Device name or ``torch.device``. An ``mps`` device is redirected
            to the CPU; any other device is used as given.
        attack: Callable invoked as
            ``attack(fmodel, image, label, epsilons=epsilon)`` and returning
            ``(raw, clipped, is_adv)``.
        epsilon: Single perturbation budget forwarded to the attack.
        save_path: Directory for the figure; created if missing.
        image_index: Position inside the first batch of the image to attack.
            Defaults to 1, the image that was previously hard-coded.

    Raises:
        IndexError: If ``image_index`` is outside the first batch.
    """
    device_to_use = "cpu" if torch.device(device).type == "mps" else device

    model.eval()
    model.to(device_to_use)
    os.makedirs(save_path, exist_ok=True)

    images, labels = next(iter(data_loader))
    if not 0 <= image_index < len(images):
        raise IndexError(
            f"image_index {image_index} is out of range for a batch of {len(images)} images"
        )
    images, labels = images.to(device_to_use), labels.to(device_to_use)
    # Slice instead of indexing so the batch dimension expected by the attack is kept.
    image = images[image_index : image_index + 1]
    label = labels[image_index : image_index + 1]

    fmodel = fb.PyTorchModel(model, bounds=(0, 1), device=device_to_use)
    _, clipped_advs, _ = attack(fmodel, image, label, epsilons=epsilon)

    # Channel-first tensors -> height x width x channel arrays for matplotlib.
    original = image[0].detach().cpu().numpy().transpose(1, 2, 0)
    adversarial = clipped_advs[0].detach().cpu().numpy().transpose(1, 2, 0)

    diff_gray = np.mean(np.abs(original - adversarial), axis=2)
    # Scale by the largest difference for display. An all-zero perturbation
    # used to divide by zero and produce a NaN heatmap.
    peak = diff_gray.max()
    diff_scaled = diff_gray / peak if peak > 0 else np.zeros_like(diff_gray)
    diff_heatmap = cm.inferno(diff_scaled)

    ground_truth_label = class_names[label.item()]
    with torch.no_grad():
        logits = model(clipped_advs[:1].float())
        pred_label = class_names[torch.argmax(logits, dim=1).item()]

    fig, axs = plt.subplots(2, 3, figsize=(12, 6))

    axs[0, 0].imshow(np.clip(original, 0, 1))
    axs[0, 0].set_title("Original Image", pad=10)
    axs[0, 0].axis("off")

    axs[0, 1].imshow(np.clip(adversarial, 0, 1))
    axs[0, 1].set_title("Adversarial Image", pad=10)
    axs[0, 1].axis("off")

    axs[0, 2].imshow(diff_heatmap)
    # The heatmap is scaled by its own maximum, not by a fixed factor of 10.
    axs[0, 2].set_title("Difference (max-normalized)", pad=10)
    axs[0, 2].axis("off")

    axs[1, 0].hist(original.ravel(), bins=50, color="gray")
    axs[1, 0].set_title("Original Intensities")

    axs[1, 1].hist(adversarial.ravel(), bins=50, color="red")
    axs[1, 1].set_title("Adversarial Intensities")

    axs[1, 2].hist(diff_gray.ravel(), bins=50, color="orange")
    axs[1, 2].set_title("Difference Intensities")

    fig.text(0.17, 0.02, f"Label: {ground_truth_label}", ha="center", fontsize=10)
    fig.text(0.51, 0.02, f"Predicted: {pred_label}", ha="center", fontsize=10)

    fig.tight_layout(rect=[0, 0.04, 1, 0.95])
    filename = f"{attack.__class__.__name__}_attack_image_with_intensity.png"
    fig.savefig(os.path.join(save_path, filename), dpi=600, bbox_inches="tight")
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)


# evaluate_foolbox_attack_image(
#     model,
#     val_loader_no_norm,
#     device,
#     fgsm_attack,
#     epsilon=0.03,
#     save_path="attack_images",
# )
# evaluate_foolbox_attack_image(
#     model,
#     val_loader_no_norm,
#     device,
#     linfpgd_attack,
#     epsilon=0.03,
#     save_path="attack_images",
# )
# evaluate_foolbox_attack_image(
#     model,
#     val_loader_no_norm,
#     device,
#     deepfool_attack,
#     epsilon=0.03,
#     save_path="attack_images",
# )
# evaluate_foolbox_attack_image(
#     model,
#     val_loader_no_norm,
#     device,
#     l2pgd_attack,
#     epsilon=0.03,
#     save_path="attack_images",
# )

### Plotting Perturbation Distributions for Adversarial Attacks

This function generates histograms showing the distribution of **perturbations** caused by different **Foolbox attacks** applied to a model. It helps visualize how much each attack perturbs the images during the adversarial attack process.

### Important Note

The histogram images showing the perturbation distributions were generated using the **Validation Loader (No Norm)**. This approach was chosen to **minimize memory consumption** during the evaluation process. 


#### Steps:
1. **Prepare the model and data**:
   - The model is set to evaluation mode.
   - Up to `num_batches` batches of images and labels are read from the provided data loader.

2. **Adversarial attacks**:
   - For each attack in the provided list, the function applies the attack to the batches read from the data loader.

3. **Perturbation calculation**:
   - The **perturbation** for each image is calculated as the difference between the adversarial image and the original image.
   - The absolute values of these perturbations are stored.

4. **Plot the distributions**:
   - The absolute perturbation values are **flattened** and combined across all batches.
   - A histogram is plotted showing the **distribution of perturbations** (i.e., how much the attack perturbs the images).
   - The histogram is saved as a PNG file with the attack name in the filename.

5. **Save the output**:
   - The generated histograms are saved in the `attack_images` folder.

The result is a series of histograms showing how much perturbation each attack causes to the images, providing insights into the attack's impact on the model's input images.

#### Notes:
- The function processes only a set number of batches (`num_batches`).
- For each attack, the corresponding histogram is saved in the specified `save_path`.

In [31]:
def plot_all_attack_perturbation_distributions(
    model,
    loader,
    device,
    attacks,
    epsilon=0.03,
    num_batches=5,
    save_path="attack_images",
):
    """Save one histogram of absolute perturbation values per attack.

    For every attack, up to ``num_batches`` batches are attacked, the absolute
    element-wise difference between the clipped adversarial images and the
    originals is collected, and its histogram is written to
    ``<save_path>/perturbation_dist_<AttackClassName>.png``.

    Args:
        model: ``torch.nn.Module`` to attack. It is put in eval mode and moved
            in place to the device the attacks run on.
        loader: Re-iterable of ``(images, labels)`` tensor batches; it is
            iterated once per attack. Images are expected to lie within the
            Foolbox bounds ``(0, 1)``.
        device: Device name or ``torch.device``. An ``mps`` device is redirected
            to the CPU; any other device is used as given.
        attacks: Iterable of callables invoked as
            ``attack(fmodel, images, labels, epsilons=epsilon)``.
        epsilon: Single perturbation budget forwarded to every attack.
        num_batches: Maximum number of batches to attack per attack.
        save_path: Directory for the figures; created if missing.
    """
    device_type = torch.device(device).type
    # Previously only the "mps" branch assigned device_to_use, so any other
    # device raised UnboundLocalError.
    device_to_use = "cpu" if device_type == "mps" else device

    model.eval()
    # Keep the model on the same device as the inputs, as the sibling
    # evaluation functions do.
    model.to(device_to_use)
    os.makedirs(save_path, exist_ok=True)

    # The wrapper does not depend on the attack, so build it once.
    fmodel = fb.PyTorchModel(model, bounds=(0, 1), device=device_to_use)

    for attack in attacks:
        attack_name = attack.__class__.__name__
        perturbations = []

        for batch_index, (images, labels) in enumerate(loader):
            if batch_index >= num_batches:
                break

            images, labels = images.to(device_to_use), labels.to(device_to_use)
            _, clipped_adv, _ = attack(fmodel, images, labels, epsilons=epsilon)
            magnitudes = (clipped_adv - images).abs()
            perturbations.append(magnitudes.detach().cpu().numpy().ravel())

            # Release per-batch tensors before asking the backend to free its cache.
            del images, labels, clipped_adv, magnitudes
            if device_type == "mps":
                torch.mps.empty_cache()
            elif device_type == "cuda":
                torch.cuda.empty_cache()

        if not perturbations:
            # np.concatenate raises on an empty list; nothing to plot here.
            print(f"[Foolbox] No batches evaluated for {attack_name}; skipping histogram.")
            continue

        values = np.concatenate(perturbations)

        fig, ax = plt.subplots(figsize=(8, 6))
        ax.hist(values, bins=50, color="skyblue", edgecolor="black")
        ax.set_title(f"Distribution of Perturbations\n{attack_name} (ε={epsilon})")
        ax.set_xlabel("Perturbation Magnitude")
        ax.set_ylabel("Frequency")
        fig.tight_layout()
        fig.savefig(
            os.path.join(save_path, f"perturbation_dist_{attack_name}.png"),
            dpi=600,
        )
        plt.close(fig)


# plot_all_attack_perturbation_distributions(
#     model,
#     val_loader_no_norm,
#     device,
#     attacks=[fgsm_attack, linfpgd_attack, deepfool_attack, l2pgd_attack],
#     epsilon=0.03,
#     num_batches=5
# )