# Deep Learning Fall 2023 Course Project - Zooming in on MLPs

### Imports

In [1]:
import os
import torch
import timm
import detectors

from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from ffcv.fields import BytesField, IntField, RGBImageField
from ffcv.writer import DatasetWriter
from transformers import AutoImageProcessor, AutoModelForImageClassification, ViTFeatureExtractor, ViTForImageClassification

from data_utils.data_stats import *
from data_utils.dataloader import get_loader
from data_utils.dataset_to_beton import get_dataset
from models.networks import get_model
from utils.metrics import topk_acc, real_acc, AverageMeter

### Fetching data loader and model architecture

In [2]:
def get_data_and_model(dataset='cifar10', model='mlp', data_path='/scratch/ffcv/'):
    """
    Build the test data loader and load a pretrained network for a dataset / architecture pair.

    Parameters:
    dataset (str): The name of the dataset to retrieve (can be cifar10, cifar100 or imagenet).
    model (str): The name of the model to retrieve (can be mlp, cnn or vit; only mlp is supported for dataset imagenet).
    data_path (str): The path to the data; forwarded unchanged to get_loader.

    Returns:
    data_loader (DataLoader): The loader built by get_loader in "test" mode (batch size 1024, no augmentation, no mixup).
    model (Model): The loaded network (get_model for mlp, timm for cnn, Hugging Face transformers for vit).
    feature_extractor (Model): The Hugging Face feature extractor / image processor for model vit, None otherwise.

    Raises:
    AssertionError: If the dataset or model is not supported, or if imagenet is requested with a model other than mlp.
    """

    assert dataset in ('cifar10', 'cifar100', 'imagenet'), f'dataset {dataset} is currently not supported by this function'
    assert model in ('mlp', 'cnn', 'vit'), f'model {model} is currently not supported by this function'

    # Keep the requested architecture name apart from the network object that is
    # built below, so the branch conditions never compare a network to a string.
    model_name = model

    num_classes = CLASS_DICT[dataset]
    eval_batch_size = 1024
    feature_extractor = None

    if dataset == 'imagenet':
        data_resolution = 64
        assert model_name == 'mlp', 'imagenet dataset is only supported by mlp model'
    else:
        data_resolution = 32

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # A torch.device does not compare equal to a plain string, so the device
    # kind has to be checked through its .type attribute.
    if device.type == 'cuda':
        torch.backends.cuda.matmul.allow_tf32 = True

    if model_name == 'mlp':
        crop_resolution = 64
        architecture = 'B_12-Wi_1024'
        checkpoint = 'in21k_' + dataset

        network = get_model(architecture=architecture, resolution=crop_resolution, num_classes=num_classes, checkpoint=checkpoint)

    elif model_name == 'cnn':
        crop_resolution = 32
        architecture = 'resnet18_' + dataset

        network = timm.create_model(architecture, pretrained=True)

    else:  # 'vit' -- the assertions above leave only cifar10 / cifar100 here
        crop_resolution = 32

        if dataset == 'cifar10':
            architecture = 'nateraw/vit-base-patch16-224-cifar10'
            feature_extractor = ViTFeatureExtractor.from_pretrained(architecture)
            network = ViTForImageClassification.from_pretrained(architecture)

        elif dataset == 'cifar100':
            architecture = 'Ahmed9275/Vit-Cifar100'
            feature_extractor = AutoImageProcessor.from_pretrained(architecture)
            network = AutoModelForImageClassification.from_pretrained(architecture)

    data_loader = get_loader(
        dataset,
        bs=eval_batch_size,
        mode="test",
        augment=False,
        dev=device,
        mixup=0.0,
        data_path=data_path,
        data_resolution=data_resolution,
        crop_resolution=crop_resolution,
    )

    return data_loader, network, feature_extractor

### Evaluating baseline model accuracy

In [6]:
# Define a test function that evaluates test accuracy
@torch.no_grad()
def test(model, loader, extractor = None):
    """
    Run the model over every batch of the loader and report accuracy.

    Parameters:
    model: The network to evaluate; it is switched to eval mode first.
    loader: An iterable yielding (images, targets) batches.
    extractor: Optional preprocessing callable. When given, each image of the batch is
        converted to a PIL image, passed through the extractor, and the model is called
        with the extractor's output as keyword arguments; predictions are read from
        the `logits` attribute of the model output.

    Returns:
    tuple: (top-1, top-5) running averages as returned by AverageMeter.get_avg(percentage=True).
    """
    model.eval()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    for images, labels in tqdm(loader, desc="Evaluation"):
        if extractor is None:
            logits = model(images)
        else:
            to_pil = transforms.ToPILImage()
            pil_images = [to_pil(image) for image in images]
            encoded = extractor(images=pil_images, return_tensors="pt")
            logits = model(**encoded).logits

        batch_top1, batch_top5 = topk_acc(logits, labels, k=5, avg=True)

        # Weight each batch by its number of samples.
        batch_size = images.shape[0]
        top1_meter.update(batch_top1, batch_size)
        top5_meter.update(batch_top5, batch_size)

    return (
        top1_meter.get_avg(percentage=True),
        top5_meter.get_avg(percentage=True),
    )

In [7]:
# Baseline run: load the CIFAR-10 test loader together with the CNN and evaluate it.
data_loader, model, feature_extractor = get_data_and_model(
    dataset='cifar10',
    model='cnn',
    data_path='/scratch/ffcv/',
)
test_acc, test_top5 = test(model, data_loader, feature_extractor)

# Report both metrics with four decimal places
print("Test Accuracy        ", f"{test_acc:.4f}")
print("Top 5 Test Accuracy          ", f"{test_top5:.4f}")

Loading /scratch/ffcv/cifar10/val_32.beton


Evaluation: 100%|██████████| 10/10 [04:29<00:00, 27.00s/it]

Test Accuracy         94.4600
Top 5 Test Accuracy           99.8300





### Evaluate adversarial accuracy

In [3]:
def pgd(model, x_batch, target, k, eps, eps_step):
    """
    Craft adversarial inputs with k signed-gradient steps on the summed cross-entropy loss.

    The search starts from x_batch plus uniform noise in [-eps, eps]. After every step
    the perturbation is clipped back to [-eps, eps] around x_batch and the result is
    clipped to the [0, 1] value range.

    Parameters:
    model: Callable mapping a batch of inputs to logits.
    x_batch (Tensor): The clean inputs.
    target (Tensor): The labels whose loss is increased.
    k (int): Number of gradient steps.
    eps (float): Maximum absolute perturbation per element.
    eps_step (float): Step size of each signed-gradient update.

    Returns:
    Tensor: The adversarial inputs, detached from the autograd graph.
    """
    criterion = torch.nn.CrossEntropyLoss(reduction='sum')

    # Random start inside the eps-box, then back into the valid value range.
    noise = 2*torch.rand_like(x_batch) - 1
    x_adv = torch.clamp(x_batch + eps * noise, min=0., max=1.)

    for _ in range(k):
        # Fresh leaf tensor so that .grad is populated by backward().
        x_adv = x_adv.detach().requires_grad_(True)

        model.zero_grad()
        loss = criterion(model(x_adv), target)
        loss.backward()

        ascent = eps_step * x_adv.grad.sign()
        # Project the total perturbation onto the eps-box around the clean input.
        delta = torch.clamp(x_adv + ascent - x_batch, min=-eps, max=eps)
        x_adv = torch.clamp(x_batch + delta, min=0, max=1)

    return x_adv.detach()

def fgsm_untargeted(model, x, label, eps, clip_min=None, clip_max=None):
    """
    Craft adversarial inputs with a single signed-gradient step (FGSM) on the cross-entropy loss.

    Parameters:
    model: Callable mapping a batch of inputs to logits.
    x (Tensor): The clean inputs; they are copied, not modified.
    label (Tensor): The labels whose loss is increased.
    eps (float): Step size applied to the sign of the input gradient.
    clip_min (float, optional): Lower bound for the result; no lower clipping when None.
    clip_max (float, optional): Upper bound for the result; no upper clipping when None.

    Returns:
    Tensor: x + eps * sign(gradient of the loss w.r.t. x), optionally clipped, detached
        from the autograd graph (as the result of pgd is).
    """
    input_ = x.clone().detach().requires_grad_(True)

    # Gradients have to be recorded even if the caller is inside torch.no_grad().
    with torch.enable_grad():
        logits = model(input_)
        model.zero_grad()
        loss = torch.nn.CrossEntropyLoss()(logits, label)
        loss.backward()

    # Combine detached tensors so the result does not keep the forward graph alive.
    out = input_.detach() + eps * input_.grad.sign()

    if (clip_min is not None) or (clip_max is not None):
        out.clamp_(min=clip_min, max=clip_max)

    return out


In [6]:
def test_adversarial(model, loader, eps, mode, epsStep = None):
    model.eval()
    total_acc, total_top5 = AverageMeter(), AverageMeter()
    total_accFGSM, total_top5FGSM = AverageMeter(), AverageMeter()

    for ims, targs in tqdm(loader, desc="Evaluation"):
        ims = torch.reshape(ims, (ims.shape[0], -1))
        targs = targs
        if mode =="fgsm":
            imsFGSM = fgsm_untargeted(model, ims, targs, eps, clip_min=None, clip_max=None)
        if mode == "pgd":
            imsFGSM = pgd(model, ims, targs, 5, eps, epsStep)

        preds = model(ims)
        predsFGSM = model(imsFGSM)
   
        acc, top5 = topk_acc(preds, targs, k=5, avg=True)
        accFGSM, top5FGSM = topk_acc(predsFGSM, targs, k=5, avg=True)

        total_acc.update(acc, ims.shape[0])
        total_top5.update(top5, ims.shape[0])

        total_accFGSM.update(accFGSM, ims.shape[0])
        total_top5FGSM.update(top5FGSM, ims.shape[0])

    return (
        total_acc.get_avg(percentage=True),
        total_top5.get_avg(percentage=True),
        total_accFGSM.get_avg(percentage=True),
        total_top5FGSM.get_avg(percentage=True),
    )

In [7]:
# FGSM evaluation (eps = 0.05) of the MLP on the CIFAR-10 test loader.
data_loader, model, feature_extractor = get_data_and_model(
    dataset='cifar10',
    model='mlp',
    data_path='/scratch/ffcv/',
)
test_adversarial(model, data_loader, eps=0.05, mode='fgsm')

In [None]:
# FGSM evaluation (eps = 0.05) of the CNN on the CIFAR-10 test loader.
# NOTE(review): test_adversarial reshapes every batch to (batch, -1) before calling
# the model -- confirm the CNN accepts flattened inputs.
data_loader, model, feature_extractor = get_data_and_model(
    dataset='cifar10',
    model='cnn',
    data_path='/scratch/ffcv/',
)
test_adversarial(model, data_loader, eps=0.05, mode='fgsm')
