# Adversarial Robustness of MLP, ViT and CNN on CIFAR-10 and CIFAR-100

### Imports

In [2]:
import time

import torch
import timm
from tqdm import tqdm
from torchvision import transforms
import torch.nn as nn
import torchvision.transforms as T

from data_utils.data_stats import *
from data_utils.dataloader import get_loader
from models.networks import get_model
from utils.metrics import topk_acc, AverageMeter

### Fetching data loader and model architecture

In [9]:
def get_data_and_model(dataset, model, data_path='/scratch/ffcv/', split='test', batch_size=100):
    """
    Build the data loader and the pretrained network for a dataset/architecture pair.

    Parameters:
    dataset (str): The name of the dataset to retrieve (can be cifar10, cifar100 or imagenet).
    model (str): The name of the model to retrieve (can be mlp, cnn or vit; only mlp is supported for dataset imagenet).
    data_path (str): The path to the data, forwarded to get_loader as `data_path`.
    split (str): Forwarded to get_loader as `mode`; augmentation is requested only when it equals 'train'.
    batch_size (int): Forwarded to get_loader as `bs`.

    Returns (as a tuple):
    data_loader: The object returned by get_loader.
    model: The retrieved network (from get_model, timm.create_model or torch.load, depending on `model`).

    Raises:
    AssertionError: If the dataset or model is not supported, or if a model other than mlp is requested for imagenet.
    """

    assert dataset in ('cifar10', 'cifar100', 'imagenet'), f'dataset {dataset} is currently not supported by this function'
    assert model in ('mlp', 'cnn', 'vit'), f'model {model} is currently not supported by this function'

    # Keep the requested name separate from the network object built below, so the
    # dispatch never compares an nn.Module against a string.
    model_name = model
    num_classes = CLASS_DICT[dataset]

    if dataset == 'imagenet':
        data_resolution = 64
        assert model_name == 'mlp', 'imagenet dataset is only supported by mlp model'
    else:
        data_resolution = 32

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # `device` is a torch.device, not a string: check its `.type` so that the TF32
    # switch is actually applied when running on a GPU.
    if device.type == 'cuda':
        torch.backends.cuda.matmul.allow_tf32 = True

    if model_name == 'mlp':
        # The MLP is always built for resolution 64, independently of data_resolution.
        # NOTE(review): callers presumably resize the inputs themselves — confirm.
        network = get_model(
            architecture='B_12-Wi_1024',
            resolution=64,
            num_classes=num_classes,
            checkpoint='in21k_' + dataset,
        )
    elif model_name == 'cnn':
        network = timm.create_model('resnet18_' + dataset, pretrained=True)
    else:  # 'vit' — guaranteed by the assertion above
        # NOTE(review): torch.load unpickles the file; only use checkpoints from a trusted source.
        network = torch.load('vit_small_patch16_224_' + dataset + '_v7.pth')

    data_loader = get_loader(
        dataset,
        bs=batch_size,
        mode=split,
        augment=split == 'train',
        dev=device,
        mixup=0.0,
        data_path=data_path,
        data_resolution=data_resolution,
        # No cropping: the crop resolution equals the data resolution.
        crop_resolution=data_resolution,
    )

    return data_loader, network

In [4]:
class Reshape(torch.nn.Module):
    """
    Resize a batch of images to a square resolution and optionally flatten it.

    Parameters:
    shape (int): Side length of the square the input is resized to.
    flatten (bool or None): Whether to collapse every dimension after the first
        into one after resizing. When None (the default) the historical rule is
        kept: flatten only when `shape` equals 64.
    """

    def __init__(self, shape=224, flatten=None):
        super().__init__()
        self.shape = shape
        self.flatten = flatten

    def forward(self, x):
        side = self.shape
        x = transforms.functional.resize(x, size=(side, side))

        # Resolve the default lazily so that changing `self.shape` after
        # construction keeps behaving as it did before `flatten` existed.
        do_flatten = (side == 64) if self.flatten is None else self.flatten
        if do_flatten:
            # Keep the first dimension and merge the rest; unlike reshape(bs, -1)
            # this does not fail when the first dimension is empty.
            x = torch.flatten(x, start_dim=1)
        return x

### Evaluating baseline model accuracy

In [5]:
dataset_name, model_name = 'cifar100', 'mlp'

# Training loader: built for `dataset_name`; the network returned by this call is discarded.
train_data_loader, _ = get_data_and_model(
    dataset=dataset_name,
    model=model_name,
    data_path='/scratch/data/ffcv/',
    split='train',
    batch_size=30000,
)

# Network: requested for 'cifar10'; the loader returned by this call is discarded.
_, model = get_data_and_model(
    dataset='cifar10',
    model=model_name,
    data_path='/scratch/data/ffcv/',
    split='train',
    batch_size=30000,
)

# Prepend the resizing step (Reshape(64) also flattens) so the network can be fed loader batches directly.
model = nn.Sequential(Reshape(64), model)

Weights already downloaded
Load_state output <All keys matched successfully>
Loading /scratch/data/ffcv/cifar100/train_32.beton
Weights already downloaded
Load_state output <All keys matched successfully>
Loading /scratch/data/ffcv/cifar10/train_32.beton


In [6]:
# NOTE(review): model.parameters() hands every parameter of the wrapped network to
# the optimizer, so all of them are updated (unless some were frozen elsewhere) —
# not only a final linear layer. Confirm this is the intended fine-tuning scope.
optimizer = torch.optim.AdamW(model.parameters())
loss_function = nn.CrossEntropyLoss()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Fine-tune for 5 epochs on the samples labelled 0-9 only; the model resizes the input images itself.
num_epochs = 5

# test() switches the model to eval mode; set train mode explicitly so that
# re-running this cell after an evaluation behaves the same as on a fresh kernel.
model.train()

for _ in tqdm(range(num_epochs)):
    for ims, targs in train_data_loader:
        # Keep only the samples whose label is one of 0..9 (assumes integer class labels).
        idx = (targs >= 0) & (targs < 10)
        if not idx.any():
            # Nothing to learn from in this batch; the mean loss over zero samples would be NaN.
            continue
        ims, targs = ims[idx], targs[idx]
        optimizer.zero_grad()
        outputs = model(ims)
        loss = loss_function(outputs, targs)
        loss.backward()
        optimizer.step()

100%|██████████| 5/5 [01:52<00:00, 22.59s/it]


In [7]:
# Define a test function that evaluates test accuracy
@torch.no_grad()
def test(model, loader):
    total_acc, total_top5 = AverageMeter(), AverageMeter()
    
    model.eval()

    for ims, targs in tqdm(loader, desc="Evaluation"):
        idx = (targs==0) | (targs==1) | (targs==2) | (targs==3) | (targs==4) | (targs==5) | (targs==6) | (targs==7) | (targs==8) | (targs==9)
        ims, targs = ims[idx], targs[idx]
        preds = model(ims)
        acc, top5 = topk_acc(preds, targs, k=5, avg=True)
        total_acc.update(acc, ims.shape[0])
        total_top5.update(top5, ims.shape[0])

    return (
        total_acc.get_avg(percentage=True),
        total_top5.get_avg(percentage=True),
    )

In [10]:
# Build the evaluation loader for the same dataset/model names; the network returned here is discarded.
test_data_loader, _ = get_data_and_model(
    dataset=dataset_name,
    model=model_name,
    data_path='/scratch/data/ffcv/',
    split='test',
)

# Evaluate the fine-tuned model defined in the earlier cells.
test_acc, test_top5 = test(model, test_data_loader)

# Report both metrics with four decimals
print(
    "Test Accuracy        ",
    "{:.4f}".format(test_acc),
)
print(
    "Top 5 Test Accuracy          ",
    "{:.4f}".format(test_top5),
)

Weights already downloaded
Load_state output <All keys matched successfully>
Loading /scratch/data/ffcv/cifar100/test_32.beton


Evaluation:   0%|          | 0/100 [00:00<?, ?it/s]Exception ignored in: <finalize object at 0x7fd044f6ffc0; dead>
Traceback (most recent call last):
  File "/home/apouget/miniconda3/envs/ffcv/lib/python3.9/weakref.py", line 591, in __call__
    return info.func(*info.args, **(info.kwargs or {}))
  File "/home/apouget/miniconda3/envs/ffcv/lib/python3.9/site-packages/numba/core/dispatcher.py", line 312, in finalizer
    for cres in overloads.values():
KeyError: (Array(uint8, 1, 'C', True, aligned=True), Array(uint8, 1, 'C', True, aligned=True), uint32, uint32, uint32, uint32, Literal[int](0), Literal[int](0), Literal[int](1), Literal[int](1), Literal[bool](False), Literal[bool](False))
Evaluation:  44%|████▍     | 44/100 [00:06<00:06,  8.55it/s]