# Adversarial Robustness of MLP, ViT and CNN on CIFAR-10 and CIFAR-100

### Imports

In [None]:
import time

import torch
import timm
from tqdm import tqdm
from torchvision import transforms

from data_utils.data_stats import *
from data_utils.dataloader import get_loader
from models.networks import get_model
from model_utils import get_test_data_and_model
from utils.metrics import topk_acc, AverageMeter

### Evaluating baseline model accuracy

In [None]:
# Define a test function that evaluates test accuracy
@torch.no_grad()
def test(model, loader):
    """
    Evaluate a model on every batch of a loader with gradients disabled.

    Parameters:
    model (torch.nn.Module): The model to evaluate; it is switched to eval mode.
    loader (iterable): Yields (images, targets) batches.

    Returns:
    tuple: Running averages of the two values returned by
        ``topk_acc(..., k=5, avg=True)``, each read back with
        ``get_avg(percentage=True)`` (presumably top-1 and top-5 accuracy
        in percent; confirm against utils.metrics).
    """
    model.eval()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    for ims, targs in tqdm(loader, desc="Evaluation"):
        batch_size = ims.shape[0]
        top1, top5 = topk_acc(model(ims), targs, k=5, avg=True)
        # The batch size is handed to the meter with each value
        # (presumably as its weight in the running average).
        top1_meter.update(top1, batch_size)
        top5_meter.update(top5, batch_size)

    top1_avg = top1_meter.get_avg(percentage=True)
    top5_avg = top5_meter.get_avg(percentage=True)
    return top1_avg, top5_avg

In [None]:
# Choose which dataset / architecture pair to evaluate on clean data
dataset_name = 'cifar10'
model_name = 'mlp'

data_loader, model = get_test_data_and_model(
    dataset=dataset_name,
    model=model_name,
    data_path='/scratch/data/ffcv/',
)
test_acc, test_top5 = test(model, data_loader)

# Report both metrics with four decimal places
print("Test Accuracy        ", f"{test_acc:.4f}")
print("Top 5 Test Accuracy          ", f"{test_top5:.4f}")

### Evaluate adversarial accuracy

In [None]:
def denormalize(tensor, mean, std):
    """
    Undo a ``transforms.Normalize(mean, std)`` step on a tensor.

    The inverse is expressed as another Normalize with mean ``-mean/std`` and
    std ``1/std``, because ``(x - (-mean/std)) / (1/std) == x * std + mean``.

    Parameters:
    tensor (torch.Tensor): The tensor to denormalize.
    mean (float or sequence): The mean used for normalization.
    std (float or sequence): The standard deviation used for normalization.

    Returns:
    torch.Tensor: The denormalized tensor.
    """
    # Plain lists/tuples do not support the arithmetic below (``-mean`` raises
    # TypeError), so promote them to tensors first. Floats, numpy arrays and
    # tensors are passed through untouched so existing callers see no change.
    if isinstance(mean, (list, tuple)):
        mean = torch.as_tensor(mean, dtype=torch.float64)
    if isinstance(std, (list, tuple)):
        std = torch.as_tensor(std, dtype=torch.float64)

    inverse_mean = -mean / std
    inverse_std = 1 / std
    return transforms.Normalize(inverse_mean, inverse_std)(tensor)

def normalize(tensor, mean, std):
    """
    Normalize a tensor with torchvision's ``transforms.Normalize``.

    Parameters:
    tensor (torch.Tensor): The tensor to normalize.
    mean (float or sequence): The mean passed to ``transforms.Normalize``.
    std (float or sequence): The standard deviation passed to
        ``transforms.Normalize``.

    Returns:
    torch.Tensor: The result of applying the transform to ``tensor``.
    """
    standardize = transforms.Normalize(mean, std)
    return standardize(tensor)

def pgd(model, dataset, x_batch, label, eps, k, eps_step):
    """
    Performs the Projected Gradient Descent (PGD) for adversarial attacks.

    The input batch is first denormalized, the perturbation is built and
    projected in that denormalized space (values clamped to [0, 1], offset
    from the clean input clamped to [-eps, eps]), and the result is normalized
    again before being returned.

    Parameters:
    model (torch.nn.Module): The model to attack.
    dataset (str): The name of the dataset used (can be cifar10, cifar100 or imagenet).
        Used as the key into MEAN_DICT / STD_DICT.
    x_batch (torch.Tensor): The input tensor (it is denormalized on entry, so
        it is expected to be normalized with the dataset statistics).
    label (torch.Tensor): The true labels for the input tensor.
    eps (float): The maximum perturbation for PGD.
    k (int): The number of steps for PGD.
    eps_step (float): The step size for each iteration.

    Returns:
    torch.Tensor: The adversarially perturbed input tensor, normalized and
        detached from the autograd graph.
    """
    # The stored statistics are divided by 255 (presumably they are kept on a
    # 0-255 scale; confirm against data_utils.data_stats).
    mean, std = MEAN_DICT[dataset]/255, STD_DICT[dataset]/255

    # The loss module is stateless, so build it once instead of once per step.
    criterion = torch.nn.CrossEntropyLoss()

    x = x_batch.clone().detach_()
    x = denormalize(x, mean, std)

    # Random start inside the eps-ball around the clean input.
    x_adv = x + eps * (2*torch.rand_like(x) - 1)
    x_adv.clamp_(min=0., max=1.)

    for _ in range(int(k)):
        # The model consumes normalized inputs; make x_adv a fresh leaf so
        # backward() populates x_adv.grad.
        x_adv = normalize(x_adv, mean, std).detach_()
        x_adv.requires_grad_()
        model.zero_grad()
        loss = criterion(model(x_adv), label)
        loss.backward()

        # The update/projection is pure bookkeeping: keep it out of autograd so
        # no graph is built on top of the leaf that still requires grad.
        with torch.no_grad():
            perturbation = eps_step * x_adv.grad.sign()

            x_adv = denormalize(x_adv, mean, std)
            # Project back into the eps-ball around x, then into the valid range.
            x_adv = x + (x_adv + perturbation - x).clamp_(min=-eps, max=eps)
            x_adv.clamp_(min=0, max=1)

    return normalize(x_adv.detach(), mean, std)

def fgsm_untargeted(model, dataset, x_batch, label, eps):
    """
    Performs the Fast Gradient Sign Method (FGSM) for untargeted adversarial attacks.

    The gradient of the cross-entropy loss is taken with respect to the
    (normalized) input; the signed step is then added in denormalized space,
    where the result is clamped to [0, 1] before being normalized again.

    Parameters:
    model (torch.nn.Module): The model to attack.
    dataset (str): The name of the dataset used (can be cifar10, cifar100 or imagenet).
        Used as the key into MEAN_DICT / STD_DICT.
    x_batch (torch.Tensor): The input tensor.
    label (torch.Tensor): The true labels for the input tensor.
    eps (float): The step size for the FGSM attack.

    Returns:
    torch.Tensor: The adversarially perturbed input tensor, normalized and
        not attached to the autograd graph.
    """
    # The stored statistics are divided by 255 (presumably they are kept on a
    # 0-255 scale; confirm against data_utils.data_stats).
    mean, std = MEAN_DICT[dataset]/255, STD_DICT[dataset]/255

    x = x_batch.clone().detach_()
    x.requires_grad_()
    model.zero_grad()
    # Single forward pass: the earlier extra `model(x)` call discarded its
    # result and only doubled the cost.
    loss = torch.nn.CrossEntropyLoss()(model(x), label)
    loss.backward()

    # Build the adversarial example outside autograd so the returned tensor
    # does not drag the attack's graph along.
    with torch.no_grad():
        perturbation = eps * x.grad.sign()
        out = denormalize(x, mean, std) + perturbation
        out.clamp_(min=0, max=1)
        return normalize(out, mean, std)

In [None]:
def test_adversarial(model, dataset, loader, eps, mode, model_name=None):
    model.eval()
    total_adv_acc, total_adv_top5 = AverageMeter(), AverageMeter()

    for ims, targs in tqdm(loader, desc="Evaluation"):
        if mode =="fgsm":
            adv_ims = fgsm_untargeted(model, dataset, ims, targs, eps)
        elif mode == "pgd":
            adv_ims = pgd(model, dataset, ims, targs, eps=eps, k=5, eps_step=eps/2)

        adv_preds = model(adv_ims)
        adv_acc, adv_top5 = topk_acc(adv_preds, targs, k=5, avg=True)
        total_adv_acc.update(adv_acc, ims.shape[0])
        total_adv_top5.update(adv_top5, ims.shape[0])

    return (
        total_adv_acc.get_avg(percentage=True),
        total_adv_top5.get_avg(percentage=True),
    )

In [None]:
# Attack configuration: dataset, architecture and attack type
dataset_name = 'cifar10'
model_name = 'mlp'
mode = 'fgsm'

# Collected results, one entry per eps in sweep order
adv_acc = []
adv_top5 = []

data_loader, model = get_test_data_and_model(
    dataset=dataset_name,
    model=model_name,
    data_path='/scratch/data/ffcv/',
)

# Perturbation budgets to sweep: np.arange grid starting at 0 with step 0.0125
all_eps = np.arange(0, 0.26, 0.0125)

for eps in tqdm(all_eps, desc="Evaluating"):
    test_adv_acc, test_adv_top5 = test_adversarial(
        model=model,
        dataset=dataset_name,
        loader=data_loader,
        eps=eps,
        mode=mode,
        model_name=model_name,
    )

    adv_acc.append(test_adv_acc)
    adv_top5.append(test_adv_top5)

In [None]:
import os

# Convert the per-eps result lists into arrays before saving
adv_acc = np.array(adv_acc)
adv_top5 = np.array(adv_top5)

# np.save does not create missing directories, so make sure the target exists.
os.makedirs('adv_robustness', exist_ok=True)

# Compare the grids as whole arrays: `all_eps == np.arange(...)` is an
# element-wise comparison whose result cannot be used as an `if` condition
# (ValueError: ambiguous truth value).
# NOTE(review): the original `elif` repeated exactly the same grid, so the
# "zoomedin" branch could never run. The intended zoomed-in grid is not
# visible here; any non-default grid is now saved with the suffix. Confirm.
if np.array_equal(all_eps, np.arange(0, 0.26, 0.0125)):
    suffix = ''
else:
    suffix = '_zoomedin_'

np.save(f'adv_robustness/accuracy_{mode}_{model_name}_{dataset_name}{suffix}', adv_acc)
np.save(f'adv_robustness/top5_{mode}_{model_name}_{dataset_name}{suffix}', adv_top5)