In [9]:
import os
import torch
import timm
import detectors

from PIL import Image
from tqdm import tqdm
from torchvision import transforms
from ffcv.fields import BytesField, IntField, RGBImageField
from ffcv.writer import DatasetWriter
from transformers import AutoImageProcessor, AutoModelForImageClassification, ViTFeatureExtractor, ViTForImageClassification

from data_utils.data_stats import *
from data_utils.dataloader import get_loader
from data_utils.dataset_to_beton import get_dataset
from models.networks import get_model
from utils.metrics import topk_acc, real_acc, AverageMeter

from torchsummary import summary
import matplotlib.pyplot as plt
import time

Create adversarial examples for trainset

In [51]:
def get_data_and_model(dataset, model, data_path='/scratch/ffcv/', partition = 'test', augment = True):
    """
    This function retrieves the data, model and feature extractor (if needed) based on the provided information.

    Parameters:
    dataset (str): The name of the dataset to retrieve (can be cifar10, cifar100 or imagenet).
    model (str): The name of the model to retrieve (can be mlp, cnn or vit; only mlp is supported for dataset imagenet).
    data_path (str): The path to the data.

    Returns (as a tuple):
    data_loader (DataLoader): The retrieved data loader.
    model (Model): The retrieved model.

    Raises:
    AssertionError: If the dataset or model is not supported.
    """

    assert dataset in ('cifar10', 'cifar100', 'imagenet'), f'dataset {dataset} is currently not supported by this function'
    assert model in ('mlp', 'cnn', 'vit'), f'model {model} is currently not supported by this function'

    num_classes = CLASS_DICT[dataset]
    eval_batch_size = 100

    if dataset == 'imagenet':
        data_resolution = 64
        assert model == 'mlp', f'imagenet dataset is only supported by mlp model'
    else:
        data_resolution = 32

    crop_resolution = data_resolution

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if device == 'cuda':
        torch.backends.cuda.matmul.allow_tf32 = True

    if model == 'mlp':
        architecture = 'B_12-Wi_1024'
        checkpoint = 'in21k_' + dataset
        model = get_model(architecture=architecture, resolution=64, num_classes=num_classes, checkpoint=checkpoint)

    if model == 'cnn':
        architecture = 'resnet18_' + dataset
        model = timm.create_model(architecture, pretrained=True)

    if model == 'vit':
        architecture = 'vit_small_patch16_224_' + dataset + '_v7.pth'
        model = torch.load(architecture)

    data_loader = get_loader(
        dataset,
        bs=eval_batch_size,
        mode=partition,
        augment=augment,
        dev=device,
        mixup=0.0,
        data_path=data_path,
        data_resolution=data_resolution,
        crop_resolution=crop_resolution,
    )
    model.cuda()
    return data_loader, model

In [3]:
class Reshape(torch.nn.Module): 
    def __init__(self, shape=224): 
        super(Reshape, self).__init__()
        self.shape = shape 
        
    def forward(self, x): 
        shape = self.shape
        x = transforms.functional.resize(x, size=(shape, shape))
        
        #if shape == 64, its an mlp
        if shape == 64:
            #x = torch.reshape(x, shape=(-1,))
            x = torch.reshape(x, shape=(x.shape[0],-1))
        return x

In [4]:
def denormalize(tensor, mean, std):
    """
    Denormalize a tensor.

    Parameters:
    tensor (torch.Tensor): The tensor to denormalize.
    mean (float or sequence): The mean used for normalization.
    std (float or sequence): The standard deviation used for normalization.

    Returns:
    torch.Tensor: The denormalized tensor.
    """
    return tensor*std[1]+mean[1]

def normalize(tensor, mean, std):
    """
    Normalize a tensor.

    Parameters:
    tensor (torch.Tensor): The tensor to normalize.
    mean (float or sequence): The mean used for normalization.
    std (float or sequence): The standard deviation used for normalization.

    Returns:
    torch.Tensor: The normalized tensor.
    """
    return (tensor-mean[1])/std[1]

def pgd(model, dataset, x_batch, label, eps, k, eps_step):
    """
    Performs the Projected Gradient Descent (PGD) for adversarial attacks.

    Parameters:
    model (torch.nn.Module): The model to attack.
    dataset (str): The name of the dataset used (can be cifar10, cifar100 or imagenet).
    x_batch (torch.Tensor): The input tensor.
    label (torch.Tensor): The true labels for the input tensor.
    eps (float): The maximum perturbation for PGD.
    k (int): The number of steps for PGD.
    eps_step (float): The step size for each iteration.

    Returns:
    torch.Tensor: The adversarially perturbed input tensor.
    """   
    mean, std = MEAN_DICT[dataset]/255, STD_DICT[dataset]/255

    x = x_batch.clone().detach_()
    x = denormalize(x, mean, std)
    x_adv = x + eps * (2*torch.rand_like(x) - 1)
    x_adv.clamp_(min=0., max=1.)
    
    for _ in range(int(k)):
        x_adv = normalize(x_adv, mean, std).detach_()
        x_adv.requires_grad_()
        model.zero_grad()
        loss = torch.nn.CrossEntropyLoss()(model(x_adv), label)
        loss.backward()
        perturbation = eps_step * x_adv.grad.sign()

        x_adv = denormalize(x_adv, mean, std)
        x_adv = x + (x_adv + perturbation - x).clamp_(min=-eps, max=eps)
        x_adv.clamp_(min=0, max=1)


    return normalize(x_adv.detach(), mean, std)

def fgsm_untargeted(model, dataset, x_batch, label, eps):
    """
    Performs the Fast Gradient Sign Method (FGSM) for untargeted adversarial attacks.

    Parameters:
    model (torch.nn.Module): The model to attack.
    dataset (str): The name of the dataset used (can be cifar10, cifar100 or imagenet).
    x_batch (torch.Tensor): The input tensor.
    label (torch.Tensor): The true labels for the input tensor.
    eps (float): The step size for the FGSM attack.

    Returns:
    torch.Tensor: The adversarially perturbed input tensor.
    """
    mean, std = MEAN_DICT[dataset]/255, STD_DICT[dataset]/255

    x = x_batch.clone().detach_()
    x.requires_grad_()
    model.zero_grad()
    loss = torch.nn.CrossEntropyLoss()(model(x), label)
    loss.backward()
    perturbation = eps * x.grad.sign()

    out = denormalize(x, mean, std) + perturbation
    out = out.clamp_(min=0, max=1)
        
    return normalize(out, mean, std)

In [31]:
def test_adversarial(model, dataset, loader, eps, mode, is_mlp = False, is_vit = False, modelname = 'MLP', datasetname = 'cifar10'):
    model.eval()
    total_adv_acc, total_adv_top5 = AverageMeter(), AverageMeter()
    batchnumber = 0
    if is_mlp:
        model = torch.nn.Sequential(Reshape(64), model)
    if is_vit:
        model = torch.nn.Sequential(Reshape(224), model)
        return
    for ims, targs in tqdm(loader, desc="Evaluation"):            
        if mode =="fgsm":
            adv_ims = fgsm_untargeted(model, dataset, ims, targs, eps)
            path = './adv_examples_train/' + modelname + '/' + datasetname + '/' + mode + '/' + str(eps) + 'Batch' + str(batchnumber)
            batchnumber += 1
            torch.save(adv_ims, path)
        if mode == "pgd":
            adv_ims = pgd(model, dataset, ims, targs, eps=eps, k=5, eps_step=eps/2)
            path = './adv_examples_train/' + modelname + '/' + datasetname + '/' + mode + '/' + str(eps) + 'Batch' + str(batchnumber)
            batchnumber += 1
            torch.save(adv_ims, path)
        adv_preds = model(adv_ims)
        adv_acc, adv_top5 = topk_acc(adv_preds, targs, k=5, avg=True)
        total_adv_acc.update(adv_acc, ims.shape[0])
        total_adv_top5.update(adv_top5, ims.shape[0])

    return (
        total_adv_acc.get_avg(percentage=True),
        total_adv_top5.get_avg(percentage=True),
    )

In [54]:
#USE BATCH SIZE 100

#Make adversarial examples for IMAGENET maybe.... Or others, like cifar testset

eps_range = 0.025
steps = 12

torch.cuda.empty_cache()
for modelname in ['mlp']:
    is_mlp = modelname == 'mlp'
    is_vit = modelname == 'vit'
    for datasetname in ['cifar100']:
        for methodname in ['fgsm']:
            adv_acc = []
            adv_top5 = []
            print('Now starting: ' + methodname +' on ' + modelname + ' with ' + datasetname)
            data_loader, model = get_data_and_model(dataset=datasetname, model=modelname, data_path='./beton/', partition='train')
            #all_eps = np.arange(0,0.26,0.0125)
            all_eps = [0.05]
            for eps in tqdm(all_eps, desc="Evaluating"):
                test_adv_acc, test_adv_top5 = test_adversarial(model, datasetname, data_loader, eps, methodname, is_vit = is_vit, is_mlp= is_mlp, modelname= modelname, datasetname= datasetname)
                adv_acc.append(test_adv_acc)
                adv_top5.append(test_adv_top5)

            #name = methodname + '_' + modelname + '_' + datasetname + '_zoomedin_'
            #np.save('accuracy_' + name, adv_acc)
            #np.save('top5_' + name, adv_top5)

Now starting: fgsm on mlp with cifar100
Weights already downloaded
Load_state output <All keys matched successfully>
Loading ./beton/cifar100\ffcv\train\train_32.beton


Evaluation: 100%|██████████| 500/500 [00:17<00:00, 28.62it/s]
Evaluating: 100%|██████████| 1/1 [00:17<00:00, 17.47s/it]


In [43]:
def train(model, opt, scheduler, loss_fn, epoch, train_loader, mode, args):
    start = time.time()
    model.train()

    total_acc, total_top5 = AverageMeter(), AverageMeter()
    total_loss = AverageMeter()

    for step, (ims, targs) in enumerate(tqdm(train_loader, desc="Training epoch: " + str(epoch))):
        #ims = torch.reshape(ims, (ims.shape[0], -1))

        #load adversarial examples
        path ='./adv_examples_train/' + modelname + '/' + datasetname + '/' + mode + '/' + str(eps) + 'Batch' + str(step)
        ims_adversarial = torch.load(path).cuda()
        #ims_adversarial = torch.reshape(ims_adversarial, (ims.shape[0], -1))

        ims = ims.cuda()
        #concat the two tensors so we can backprop in one go
        ims = torch.concat((ims,ims_adversarial), 0)
        targs = torch.concat((targs,targs), 0)
        preds = model(ims).cuda()

        loss = loss_fn(preds, targs)
        targs_perm = None

        acc, top5 = topk_acc(preds, targs, targs_perm, k=5, avg=True)
        total_acc.update(acc, ims.shape[0])
        total_top5.update(top5, ims.shape[0])
      
        loss.backward()
        
        opt.step()
        opt.zero_grad()

        total_loss.update(loss.item(), ims.shape[0])

    end = time.time()

    scheduler.step()
# Define a test function that evaluates test accuracy
@torch.no_grad()
def test(model, loader):

    model.eval()
    total_acc, total_top5 = AverageMeter(), AverageMeter()

    for ims, targs in tqdm(loader, desc="Evaluation"):
        
        ims = ims.cuda()
        
        preds = model(ims).cuda()
        targs = targs.to("cuda")
        acc = real_acc(preds, targs, k=5, avg=True)
        top5 = 0

        total_acc.update(acc, ims.shape[0])
        total_top5.update(top5, ims.shape[0])


    return (
        total_acc.get_avg(percentage=True),
        total_top5.get_avg(percentage=True),
    )





In [58]:
#train the network on the adversarial examples.
modelname = 'mlp'
optimizer = torch.optim.AdamW(model.parameters())
#hani afach n random gnoh wells eh ned viel epochs sind
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 10,0.05)
loss_function = torch.nn.CrossEntropyLoss()
datasetname = 'imagenet'
epochs = 10
data_loader, model = get_data_and_model(dataset=datasetname, model=modelname, data_path='./beton/', partition='train')
model = torch.nn.Sequential(Reshape(64), model)
mode = 'fgsm'
for x in range(epochs): 
    train(model, optimizer, scheduler, loss_function, x, data_loader, mode, args = None)

Weights already downloaded
Load_state output <All keys matched successfully>
Loading ./beton/imagenet\ffcv\train\train_64.beton


Training epoch: 0: 100%|██████████| 12811/12811 [10:40<00:00, 19.99it/s]


In [59]:
data_loader, m = get_data_and_model(dataset=datasetname, model=modelname, data_path='./beton/', partition='test', augment = False)
test_acc, test_top5 = test(model, data_loader)

# Print all the stats
print("Test Accuracy        ", "{:.4f}".format(test_acc))
print("Top 5 Test Accuracy          ", "{:.4f}".format(test_top5))

Weights already downloaded
Load_state output <All keys matched successfully>
Loading ./beton/imagenet\ffcv\val\val_64.beton


Evaluation: 100%|██████████| 500/500 [00:05<00:00, 87.73it/s] 

Test Accuracy         44.1060
Top 5 Test Accuracy           0.0000





In [57]:
#save the model
name = 'adversarialTrained_' + str(epochs) +'_epochs_' + modelname +'on_' + datasetname
torch.save(model, name +'.pth')
