# Adversarial Robustness of MLP, ViT and CNN on CIFAR-10 and CIFAR-100

### Imports

In [None]:
import time

import torch
from tqdm import tqdm
from torchsummary import summary

from model_utils import get_test_data_and_model
from utils.metrics import topk_acc, AverageMeter

### Evaluating baseline model accuracy

In [None]:
# Evaluation helper: runs a model over a loader and reports its accuracy
@torch.no_grad()
def test(model, loader, model_name=None):
    """Run ``model`` over every batch of ``loader`` and return its accuracy.

    Gradients are disabled for the whole call and the model is put into
    eval mode before the first batch.

    Args:
        model: Network to evaluate; called directly on each image batch.
        loader: Iterable yielding ``(images, targets)`` batches.
        model_name: Not used by this function; accepted so existing callers
            that pass it keep working.

    Returns:
        A pair ``(accuracy, top5_accuracy)`` as produced by
        ``AverageMeter.get_avg(percentage=True)`` for the two values that
        ``topk_acc(..., k=5, avg=True)`` returns per batch.
    """
    model.eval()

    acc_meter = AverageMeter()
    top5_meter = AverageMeter()

    for images, targets in tqdm(loader, desc="Evaluation"):
        outputs = model(images)
        batch_acc, batch_top5 = topk_acc(outputs, targets, k=5, avg=True)

        # Each update is tagged with the number of samples in the batch.
        batch_size = images.shape[0]
        acc_meter.update(batch_acc, batch_size)
        top5_meter.update(batch_top5, batch_size)

    mean_acc = acc_meter.get_avg(percentage=True)
    mean_top5 = top5_meter.get_avg(percentage=True)
    return mean_acc, mean_top5

In [None]:
# Baseline run: pick the dataset and architecture, load both, then score the model
dataset_name = 'cifar10'
model_name = 'cnn'

data_loader, model = get_test_data_and_model(
    dataset=dataset_name,
    model=model_name,
    data_path='/scratch/data/ffcv/',
)
test_acc, test_top5 = test(model, data_loader, model_name)

# Report both metrics with four decimal places
print("Test Accuracy        ", f"{test_acc:.4f}")
print("Top 5 Test Accuracy          ", f"{test_top5:.4f}")

### Evaluate inference time vs accuracy

In [None]:
# Define a function that measures the wall-clock time of one pass over a loader
@torch.no_grad()
def get_inference_time(model, loader, model_name=None):
    """Return the wall-clock seconds needed to run ``model`` over ``loader``.

    The measured interval covers the whole loop, i.e. fetching the batches
    from ``loader`` as well as the forward passes. Gradients are disabled
    and the model is put into eval mode before timing starts.

    Args:
        model: Network to time; called directly on each image batch.
        loader: Iterable yielding ``(images, targets)`` batches; the targets
            are ignored.
        model_name: Not used by this function; accepted so existing callers
            that pass it keep working.

    Returns:
        Elapsed time in seconds as a float.
    """

    def _wait_for_gpu():
        # CUDA kernels are launched asynchronously: without a synchronisation
        # point the clock could be read while forward passes are still queued
        # on the device. Only synchronise when CUDA is already initialised so
        # that CPU-only runs do not pay for creating a CUDA context.
        if torch.cuda.is_available() and torch.cuda.is_initialized():
            torch.cuda.synchronize()

    model.eval()

    # Drain any work queued before this call so it is not billed to the loop.
    _wait_for_gpu()
    # perf_counter is monotonic and high-resolution, unlike time.time(), which
    # can jump when the system clock is adjusted.
    start = time.perf_counter()
    for ims, _ in loader:
        model(ims)
    # Make sure the last forward passes have actually finished.
    _wait_for_gpu()
    end = time.perf_counter()

    return end - start

In [None]:
# Time one full pass of the selected model over the selected test set
dataset_name = 'cifar10'
model_name = 'mlp'

data_loader, model = get_test_data_and_model(
    dataset=dataset_name,
    model=model_name,
    data_path='/scratch/data/ffcv/',
)
elapsed = get_inference_time(model, data_loader, model_name)
print(f"Inference time for {model_name} on {dataset_name}: {elapsed:.4f}")

In [None]:
# Summarise the most recently loaded model for one 3-channel 32x32 input
input_shape = (3, 32, 32)
summary(model, input_shape)