# Adversarial Robustness of MLP, ViT and CNN on CIFAR-10 and CIFAR-100

### Imports

In [None]:
import time

import torch
import timm
from tqdm import tqdm
from torchvision import transforms

from data_utils.data_stats import *
from data_utils.dataloader import get_loader
from models.networks import get_model
from utils.metrics import topk_acc, AverageMeter

### Fetching data loader and model architecture

In [None]:
def get_data_and_model(dataset, model, data_path='/scratch/ffcv/'):
    """
    Build the test data loader and the pretrained model for an evaluation run.

    Parameters:
    dataset (str): The name of the dataset to retrieve (can be cifar10, cifar100 or imagenet).
    model (str): The name of the model to retrieve (can be mlp, cnn or vit; only mlp is supported for dataset imagenet).
    data_path (str): The path to the data.

    Returns (as a tuple):
    data_loader (DataLoader): The test-mode loader returned by `get_loader` (no augmentation, no mixup).
    model (Model): The retrieved model.

    Raises:
    AssertionError: If the dataset or model is not supported, or if a non-mlp
        model is requested for imagenet.
    """

    # Raised explicitly instead of via `assert` so the validation is not
    # stripped when Python runs with -O. All checks happen up front, before
    # any lookup or model construction.
    if dataset not in ('cifar10', 'cifar100', 'imagenet'):
        raise AssertionError(f'dataset {dataset} is currently not supported by this function')
    if model not in ('mlp', 'cnn', 'vit'):
        raise AssertionError(f'model {model} is currently not supported by this function')
    if dataset == 'imagenet' and model != 'mlp':
        raise AssertionError('imagenet dataset is only supported by mlp model')

    num_classes = CLASS_DICT[dataset]
    eval_batch_size = 1024

    # imagenet is loaded at 64x64, the CIFAR datasets at 32x32; no cropping
    # is applied, so the crop resolution equals the data resolution.
    data_resolution = 64 if dataset == 'imagenet' else 32
    crop_resolution = data_resolution

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Compare the device *type*: a torch.device never equals the string
    # 'cuda', so the previous check could not enable TF32.
    if device.type == 'cuda':
        torch.backends.cuda.matmul.allow_tf32 = True

    # Keep the requested name separate from the constructed network so the
    # string comparisons below never run against a module object.
    model_name = model
    if model_name == 'mlp':
        # The MLP is always built for 64x64 inputs, independent of the
        # resolution the loader delivers.
        network = get_model(
            architecture='B_12-Wi_1024',
            resolution=64,
            num_classes=num_classes,
            checkpoint='in21k_' + dataset,
        )
    elif model_name == 'cnn':
        network = timm.create_model('resnet18_' + dataset, pretrained=True)
    else:  # 'vit' (guaranteed by the validation above)
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be placed at this relative path.
        network = torch.load('vit_small_patch16_224_' + dataset + '_v7.pth')

    data_loader = get_loader(
        dataset,
        bs=eval_batch_size,
        mode="test",
        augment=False,
        dev=device,
        mixup=0.0,
        data_path=data_path,
        data_resolution=data_resolution,
        crop_resolution=crop_resolution,
    )

    return data_loader, network

In [None]:
class Reshape(torch.nn.Module):
    """
    Resize a batch of images to a square resolution and optionally flatten it.

    Parameters:
    shape (int): Target height and width passed to the resize call.
    flatten (bool or None): Whether to flatten each resized sample into a
        vector. None (the default) keeps the historical rule: flatten only
        when `shape` is 64, which is the configuration used for the MLP.
    """

    def __init__(self, shape=224, flatten=None):
        super().__init__()
        self.shape = shape
        self.flatten = flatten

    def _should_flatten(self):
        """Return True when forward() must flatten the resized batch."""
        if self.flatten is None:
            return self.shape == 64
        return bool(self.flatten)

    def forward(self, x):
        """
        Resize `x` to (shape, shape) and flatten per sample if configured.

        The first dimension of `x` is treated as the batch dimension and is
        preserved when flattening.
        """
        side = self.shape
        x = transforms.functional.resize(x, size=(side, side))
        if self._should_flatten():
            batch_size = x.shape[0]
            x = torch.reshape(x, shape=(batch_size, -1))
        return x

### Evaluating baseline model accuracy

In [None]:
# Define a test function that evaluates test accuracy
@torch.no_grad()
def test(model, loader, model_name=None):
    """
    Evaluate `model` on every batch of `loader` and report top-1/top-5 accuracy.

    Parameters:
    model (Model): The network to evaluate.
    loader (DataLoader): Iterable yielding (images, targets) batches.
    model_name (str): 'mlp' or 'vit' prepends a Reshape(64) / Reshape(224)
        stage to the model; any other value leaves the model unwrapped.

    Returns (as a tuple):
    The running averages of top-1 and top-5 accuracy, as percentages.
    """
    # Input resolution expected by the architectures that need resizing.
    input_size = {'mlp': 64, 'vit': 224}.get(model_name)
    if input_size is not None:
        model = torch.nn.Sequential(Reshape(input_size), model)
    model.eval()

    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    for images, targets in tqdm(loader, desc="Evaluation"):
        batch_size = images.shape[0]
        top1, top5 = topk_acc(model(images), targets, k=5, avg=True)
        # Weight each batch by its size so the averages are per-sample.
        top1_meter.update(top1, batch_size)
        top5_meter.update(top5, batch_size)

    top1_avg = top1_meter.get_avg(percentage=True)
    top5_avg = top5_meter.get_avg(percentage=True)
    return top1_avg, top5_avg

In [None]:
# Configuration of the baseline accuracy run.
dataset_name, model_name = 'cifar10', 'mlp'

data_loader, model = get_data_and_model(
    dataset=dataset_name,
    model=model_name,
    data_path='/scratch/data/ffcv/',
)
test_acc, test_top5 = test(model, data_loader, model_name)

# Print all the stats
for label, value in (
    ("Test Accuracy        ", test_acc),
    ("Top 5 Test Accuracy          ", test_top5),
):
    print(label, "{:.4f}".format(value))

### Evaluate inference time vs accuracy

In [None]:
# Define a test function that evaluates inference time
@torch.no_grad()
def get_inference_time(model, loader, model_name=None):
    """
    Measure the wall-clock time of one full inference pass over `loader`.

    Parameters:
    model (Model): The network to time.
    loader (DataLoader): Iterable yielding (images, targets) batches; the
        targets are ignored.
    model_name (str): 'mlp' or 'vit' prepends a Reshape(64) / Reshape(224)
        stage to the model; any other value leaves the model unwrapped.

    Returns:
    float: Elapsed seconds for the whole loop. The measurement covers
        iterating the loader as well as the forward passes.
    """
    if model_name == 'mlp':
        model = torch.nn.Sequential(Reshape(64), model)
    elif model_name == 'vit':
        model = torch.nn.Sequential(Reshape(224), model)

    model.eval()

    # CUDA kernels are launched asynchronously, so without synchronisation
    # the timer can stop before the queued GPU work has finished. Drain any
    # pending work before starting and wait for completion before stopping.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for ims, _ in loader:
        model(ims)
    if use_cuda:
        torch.cuda.synchronize()

    return time.perf_counter() - start

In [None]:
# Configuration of the inference-time run.
dataset_name, model_name = 'cifar10', 'mlp'

data_loader, model = get_data_and_model(
    dataset=dataset_name,
    model=model_name,
    data_path='/scratch/data/ffcv/',
)

# Time one full pass over the test loader, then report it.
elapsed = get_inference_time(model, data_loader, model_name)
print(f"Inference time for {model_name} on {dataset_name}: {elapsed:.4f}")