In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim

from src.models import CifarResNet, MNIST_CNN, CIFAR_CNN
from src.helpers import evaluate_rob_accuracy, evaluate_clean_accuracy, load_model, safe_model,_evaluate_model
from src.data_loader import load_torchvision_dataset, load_imagenette
#from src.pruning import identify_layers, _evaluate_sparsity

import time

# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
# 32-bit float dtype constant kept alongside the device selection.
dtype = torch.float32

In [None]:
import pandas as pd

In [None]:
# Display a previously pickled results table (presumably the 'fast_double' run,
# judging by the file name -- confirm).
# NOTE(review): unpickling can execute arbitrary code; only load files you trust.
pd.read_pickle('./results/preliminary-double.pkl')

In [None]:

# Clean accuracy of the global `model` on `test_loader`.
# NOTE(review): `model` and `test_loader` are first assigned in the
# "Initialization" cells further down this notebook, so on a fresh kernel this
# cell raises NameError unless those cells have been run first.
evaluate_clean_accuracy(model.to(device), test_loader, device)

In [None]:
from foolbox import PyTorchModel, accuracy, samples
from foolbox.attacks import LinfPGD, FGSM, L0BrendelBethgeAttack, L2CarliniWagnerAttack

# Epoch count that run() passes to every training call (together with patience=5).
epochs = 500


def run(training_method):
    """Prune, retrain and attack a ``CIFAR_CNN`` at several compression ratios.

    A single fresh ``CIFAR_CNN`` is created and then reused for every ratio in
    (1, 2, 4, 8, 16), so each pruning step is applied to the model left behind
    by the previous iteration. For each ratio the model is pruned with
    ``prune_magnitude_global_unstruct(1 - 1/ratio, device)`` (an amount of 0 for
    ratio 1), trained with the method selected by ``training_method``
    (``eps=8/255``, ``patience=5``), and attacked on the first batch of
    ``test_loader`` only.

    Reads the notebook globals ``device``, ``train_loader``, ``test_loader``
    and ``epochs``. The L0 Brendel-Bethge attack is not part of the sweep.

    Args:
        training_method: key understood by ``get_train_method``.

    Returns:
        dict keyed by ``str(ratio)``; each value is a dict with
        ``'l2_robustness'`` and ``'linf_robustness'`` (the ``.item()`` of the
        ``cw_attack`` / ``pgd_attack`` results) and ``'clean_accuracy'``
        (``'val_accuracy'`` from the training call's return value).
    """
    net = CIFAR_CNN().to(device)
    results = {}
    for ratio in (1, 2, 4, 8, 16):
        prune_amount = 1-1/ratio
        print('compression rate: ', prune_amount)
        train_fn = get_train_method(net, training_method)
        net.prune_magnitude_global_unstruct(prune_amount, device)
        history = train_fn(train_loader, test_loader, epochs, device, eps=8/255, patience=5)

        # Robustness is measured on one test batch, moved to the training device.
        batch_images, batch_labels = next(iter(test_loader))
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        entry = {}
        results[f'{ratio}'] = entry
        entry['l2_robustness'] = cw_attack(net, batch_images, batch_labels).item()
        print('cw done')
        entry['linf_robustness'] = pgd_attack(net, batch_images, batch_labels).item()
        print('pgd done')
        entry['clean_accuracy'] = history['val_accuracy']

    return results
        
        
        

def get_train_method(model, method):
    """Return the bound training routine of ``model`` selected by ``method``.

    Supported keys and the attribute each one resolves to:

    * ``'standard'``    -> ``model.fit``
    * ``'free'``        -> ``model.fit_free``
    * ``'fast'``        -> ``model.fit_fast``
    * ``'fast_double'`` -> ``model.fit_fast_with_double_update``

    Args:
        model: object exposing the training methods listed above.
        method: one of the supported keys.

    Returns:
        The corresponding attribute of ``model``.

    Raises:
        ValueError: if ``method`` is not a supported key. An unknown key
            previously fell through and returned ``None``, which only surfaced
            later as "'NoneType' object is not callable" at the call site.
    """
    attr_by_method = {
        'standard': 'fit',
        'free': 'fit_free',
        'fast': 'fit_fast',
        'fast_double': 'fit_fast_with_double_update',
    }
    if method not in attr_by_method:
        raise ValueError(
            f"unknown training method {method!r}; expected one of {sorted(attr_by_method)}"
        )
    # Resolve lazily so that only the requested attribute has to exist on the model.
    return getattr(model, attr_by_method[method])

def bb_attack(model, images, labels, eps=8/255):
    """Score ``model`` on one batch under foolbox's ``L0BrendelBethgeAttack``.

    The model is put in eval mode for the attack and unconditionally switched
    to train mode afterwards.

    Args:
        model: network wrapped in a foolbox ``PyTorchModel`` with bounds (0, 1).
        images: batch of inputs handed to the attack.
        labels: labels for ``images``.
        eps: value passed to the attack as ``epsilons``.

    Returns:
        ``(1 - sum(success) / len(success)) / 100`` -- note the extra division
        by 100 applied to the non-fooled fraction.
    """
    # NOTE(review): the default eps=8/255 is the same value used for the L-inf
    # attack; confirm it is the intended budget for an L0 attack.
    model.eval()
    wrapped = PyTorchModel(model, bounds=(0, 1))
    brendel_bethge = L0BrendelBethgeAttack()
    _, _, fooled = brendel_bethge(wrapped, images, labels, epsilons=eps)
    model.train()

    fooled_fraction = torch.sum(fooled)/len(fooled)
    return (1 - fooled_fraction) / 100

def cw_attack(model, images, labels, eps=8/255):
    """Score ``model`` on one batch under foolbox's ``L2CarliniWagnerAttack``.

    The model is put in eval mode for the attack and unconditionally switched
    to train mode afterwards.

    Args:
        model: network wrapped in a foolbox ``PyTorchModel`` with bounds (0, 1).
        images: batch of inputs handed to the attack.
        labels: labels for ``images``.
        eps: value passed to the attack as ``epsilons``.

    Returns:
        ``(1 - sum(success) / len(success)) / 100`` -- note the extra division
        by 100 applied to the non-fooled fraction.
    """
    # NOTE(review): the default eps=8/255 is the same value used for the L-inf
    # attack; confirm it is the intended budget for an L2 attack.
    model.eval()
    wrapped = PyTorchModel(model, bounds=(0, 1))
    carlini_wagner = L2CarliniWagnerAttack()
    _, _, fooled = carlini_wagner(wrapped, images, labels, epsilons=eps)
    model.train()

    fooled_fraction = torch.sum(fooled)/len(fooled)
    return (1 - fooled_fraction) / 100

def pgd_attack(model, images, labels, eps=8/255):
    """Score ``model`` on one batch under foolbox's ``LinfPGD`` attack.

    The model is put in eval mode for the attack and unconditionally switched
    to train mode afterwards.

    Args:
        model: network wrapped in a foolbox ``PyTorchModel`` with bounds (0, 1).
        images: batch of inputs handed to the attack.
        labels: labels for ``images``.
        eps: value passed to the attack as ``epsilons``.

    Returns:
        ``(1 - sum(success) / len(success)) / 100`` -- note the extra division
        by 100 applied to the non-fooled fraction.
    """
    model.eval()
    wrapped = PyTorchModel(model, bounds=(0, 1))
    pgd = LinfPGD()
    _, _, fooled = pgd(wrapped, images, labels, epsilons=eps)
    model.train()

    fooled_fraction = torch.sum(fooled)/len(fooled)
    return (1 - fooled_fraction) / 100

# Double Update vs. Single Update Fast Adversarial Training
Specs:

CIFAR CNN: 4 Conv (16,16,32,32) with batchnorm, 2 FC (128,10)

Data: CIFAR-10 (images of shape 32×32×3)

30 Epochs with eps=8/255



Standard Fast Adv. Training:
Clean: 63.05%
Robust: 59.34%

Fast Adv. Training with Double Update:
Clean: 63.99%
Robust: 61.35%




# Initialization

In [None]:
# Global CIFAR_CNN instance used by the standalone evaluation and attack cells.
model = CIFAR_CNN()

In [None]:
# CIFAR10 train/test loaders; run() and the evaluation/attack cells read them as globals.
train_loader, test_loader = load_torchvision_dataset('CIFAR10')

In [None]:
# Load the saved baseline checkpoint into the global `model` (load_model is a
# project helper; presumably it restores the weights stored at PATH -- confirm).
# NOTE(review): run() constructs its own CIFAR_CNN() and never reads this global
# `model`, so the loaded baseline only affects the standalone clean-accuracy and
# attack cells, not the pruning experiment.
PATH = './saved-models/CIFAR-baseline-150-epochs.pth'
model = load_model(model, PATH)

# Experiment


1. Prune
2. Train
3. Measure robust accuracy


In [None]:
# Prune -> train -> attack sweep using standard training (model.fit).
standard_stats = run('standard')

In [None]:
# Carlini-Wagner (L2) result for ratio 1, i.e. pruning amount 0; still carries
# the extra /100 applied inside cw_attack.
standard_stats['1']['l2_robustness']

In [None]:
import pandas as pd

# One row per compression ratio. cw_attack / pgd_attack return the non-fooled
# fraction divided by 100, so multiplying by 10000 expresses it as a percentage.
df = (
    pd.DataFrame(data=standard_stats)
    .T
    .assign(
        l2_robustness=lambda frame: frame['l2_robustness']*10000,
        linf_robustness=lambda frame: frame['linf_robustness']*10000,
    )
)
df.to_pickle('./results/preliminary-standard-no-augmentation.pkl')
df

In [None]:
# Same sweep using fast adversarial training (model.fit_fast).
fast_stats = run('fast')

In [None]:
# Same sweep using free adversarial training (model.fit_free).
free_stats = run('free')

In [None]:
# Same sweep using fast training with double update (model.fit_fast_with_double_update).
double_stats = run('fast_double')

# Attacks

In [None]:
# Take the first batch from test_loader and move it to the selected device;
# the attack cells in this section all reuse this single batch.
images, labels = next(iter(test_loader))
images, labels = images.to(device), labels.to(device)


In [None]:
# L-inf PGD on the global `model`; the result is the non-fooled fraction divided by 100.
pgd_attack(model, images, labels, eps=8/255)

In [None]:
# L2 Carlini-Wagner on the global `model`; the result is the non-fooled fraction divided by 100.
cw_attack(model, images, labels, eps=8/255)

In [None]:
# L0 Brendel-Bethge on the global `model`, using bb_attack's default eps (8/255).
bb_attack(model, images, labels)