In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim

from src.models import ResNet, MNIST_CNN, CIFAR_CNN
from src.helpers import evaluate_rob_accuracy, evaluate_clean_accuracy, load_model, safe_model,_evaluate_model
from src.data_loader import load_torchvision_dataset, load_imagenette
#from src.pruning import identify_layers, _evaluate_sparsity

import time

# Compute device handed to the model's fit* calls in later cells.
# Swap the two lines below to use the first CUDA GPU instead of the CPU.
#device = torch.device("cuda:0")
device = torch.device("cpu")

# Floating-point dtype constant; no cell visible in this notebook references it.
dtype = torch.float32

# Double Update vs. Single Update Fast Adversarial Training
Specs:

CIFAR CNN: 4 Conv (16,16,32,32) with batchnorm, 2 FC (128,10)

Data: CIFAR-10 (32,32,3)

30 Epochs with eps=8/255



Standard Fast Adv. Training:
Clean: 63.05%
Robust: 59.34%

Fast Adv. Training with Double Update:
Clean: 63.99%
Robust: 61.35%




# Initialization

In [2]:
# Choose the architecture under test; the alternatives stay commented out
# so they can be switched in quickly.
#model = ResNet()
#model = MNIST_CNN()
model = CIFAR_CNN()

identifying layers


In [7]:
# Build the CIFAR10 train/test loaders consumed by run() and the fit* cells.
train_loader, test_loader = load_torchvision_dataset('CIFAR10')

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Checkpoint location; judging by the file name this is a CIFAR baseline
# trained for 150 epochs (TODO confirm).
PATH = './saved-models/CIFAR-baseline-150-epochs.pth'
# Rebind `model` to whatever load_model returns; presumably the same network
# with the checkpoint weights restored (see src.helpers.load_model).
model = load_model(model, PATH)

# Experiment


1. Prune
2. Train
3. Measure robust accuracy


In [None]:
# NOTE(review): run() and the *_attack helpers it calls are defined in later
# cells, so this cell fails on a fresh kernel unless those cells are executed
# first.
stats = run()

In [9]:
from foolbox import PyTorchModel, accuracy, samples
from foolbox.attacks import LinfPGD, FGSM, L0BrendelBethgeAttack, L2CarliniWagnerAttack

def run():
    """Prune, retrain and attack a fresh CIFAR_CNN at several compression ratios.

    For every ratio in ``compression_rates`` the model is pruned with
    ``prune_magnitude_global_unstruct(1 - 1/ratio)``, trained through
    ``model.fit(train_loader, test_loader, 20, device)`` and then attacked
    on the first batch drawn from ``test_loader``.

    Relies on the notebook-level names ``train_loader``, ``test_loader``,
    ``device``, ``bb_attack``, ``cw_attack`` and ``pgd_attack``.

    Returns:
        dict: maps ``str(ratio)`` to a dict with the keys
        ``'l0_robustness'``, ``'l2_robustness'``, ``'linf_robustness'``
        and ``'clean_accuracy'``.
    """
    model = CIFAR_CNN()
    compression_rates = [1, 2, 4, 8, 16]
    stats = {}
    for ratio in compression_rates:
        # NOTE(review): the same model instance is pruned again on every
        # iteration; whether the pruning amounts compound depends on the
        # implementation of prune_magnitude_global_unstruct -- confirm.
        model.prune_magnitude_global_unstruct(1 - 1 / ratio)
        train_data = model.fit(train_loader, test_loader, 20, device)

        # The fit_* outputs recorded in this notebook expose 'val_accuracy'
        # rather than 'accuracy'; accept either so a long training run is
        # not lost to a KeyError.
        try:
            clean_accuracy = train_data['accuracy']
        except KeyError:
            clean_accuracy = train_data['val_accuracy']

        # Only the first test batch is attacked; keep it on the same device
        # the model was trained on.
        images, labels = next(iter(test_loader))
        images, labels = images.to(device), labels.to(device)

        # The original wrote to stats[f'ratio'][...]: the f-string had no
        # placeholder and the inner dict was never created, which raised
        # KeyError on the first iteration. Build one entry per ratio instead.
        # NOTE(review): the *_attack helpers return the fraction of samples
        # for which the attack succeeded, so despite the key names a higher
        # value means a LESS robust model.
        stats[f'{ratio}'] = {
            'l0_robustness': bb_attack(model, images, labels),
            'l2_robustness': cw_attack(model, images, labels),
            'linf_robustness': pgd_attack(model, images, labels),
            'clean_accuracy': clean_accuracy,
        }

    return stats
        
        
        
    

In [4]:
def _attack_success_rate(attack, model, images, labels, eps):
    """Run a foolbox ``attack`` on ``model`` and return its success fraction.

    The model is switched to eval mode for the attack and is always put back
    into training mode afterwards, even if the attack raises.

    Args:
        attack: an instantiated foolbox attack object.
        model: the torch module to wrap in ``PyTorchModel`` with bounds (0, 1).
        images: input batch handed to the attack.
        labels: labels for ``images``.
        eps: value forwarded as ``epsilons``.

    Returns:
        ``torch.sum(success) / len(success)`` where ``success`` is the third
        value returned by the attack (per the foolbox docs, a per-sample flag
        telling whether an adversarial example was found).
    """
    model.eval()
    try:
        fmodel = PyTorchModel(model, bounds=(0, 1))
        _raw_advs, _clipped_advs, success = attack(
            fmodel, images, labels, epsilons=eps
        )
    finally:
        # Same post-condition as the original code (model left in training
        # mode), now also guaranteed when the attack fails.
        model.train()

    return torch.sum(success) / len(success)


def bb_attack(model, images, labels, eps=8/255):
    """Success fraction of foolbox's ``L0BrendelBethgeAttack`` on one batch.

    NOTE(review): ``eps`` defaults to 8/255 like the Linf attack; confirm that
    this budget is meaningful for an L0 attack.
    """
    return _attack_success_rate(L0BrendelBethgeAttack(), model, images, labels, eps)


def cw_attack(model, images, labels, eps=8/255):
    """Success fraction of foolbox's ``L2CarliniWagnerAttack`` on one batch.

    NOTE(review): ``eps`` defaults to 8/255 like the Linf attack; confirm that
    this budget is meaningful for an L2 attack.
    """
    return _attack_success_rate(L2CarliniWagnerAttack(), model, images, labels, eps)


def pgd_attack(model, images, labels, eps=8/255):
    """Success fraction of foolbox's ``LinfPGD`` attack on one batch."""
    return _attack_success_rate(LinfPGD(), model, images, labels, eps)

In [8]:
# Scratch tensor of shape (5, 2) filled by torch.rand; used to try out torch.sum.
a = torch.rand((5,2))

In [9]:
# Scratch check: torch.sum without a dim argument reduces over all elements.
torch.sum(a)

tensor(5.4796)

In [18]:
# NOTE(review): in the visible cells `images` is only assigned inside run(),
# so this cell depends on leftover kernel state and fails on a fresh kernel.
len(images)

512

# Test

In [4]:
# Train via model.fit_free with 3 replays and eps = 8/255. The third
# positional argument (1) is presumably the epoch count -- confirm in src.models.
model.fit_free(train_loader, test_loader, 1, device, number_of_replays=3, eps = 8/255)

[1,     1] loss: 4.13015, train_accuracy: 4.17
[1,     2] loss: 3.66484, train_accuracy: 4.88
[1,     3] loss: 3.12940, train_accuracy: 6.97
[1,     4] loss: 2.90560, train_accuracy: 8.98
[1,     5] loss: 2.72787, train_accuracy: 9.70
[1,     6] loss: 2.53809, train_accuracy: 10.22
[1,     7] loss: 2.40960, train_accuracy: 15.62
[1,     8] loss: 2.43231, train_accuracy: 14.13
[1,     9] loss: 2.44054, train_accuracy: 14.78
[1,    10] loss: 2.36526, train_accuracy: 17.90
[1,    11] loss: 2.37196, train_accuracy: 19.01
[1,    12] loss: 2.28873, train_accuracy: 18.49
[1,    13] loss: 2.31310, train_accuracy: 14.58
[1,    14] loss: 2.30206, train_accuracy: 18.55
[1,    15] loss: 2.27567, train_accuracy: 15.89
[1,    16] loss: 2.21592, train_accuracy: 18.03
[1,    17] loss: 2.19046, train_accuracy: 20.05
[1,    18] loss: 2.15967, train_accuracy: 23.11
[1,    19] loss: 2.23406, train_accuracy: 18.68
[1,    20] loss: 2.25862, train_accuracy: 17.45
[1,    21] loss: 2.12274, train_accuracy: 23.

{'criterion': CrossEntropyLoss(),
 'optimizer': Adam (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     eps: 1e-08
     lr: 0.001
     weight_decay: 0
 ),
 'hist': 'Not implemented',
 'val_accuracy': 34.82}

In [5]:
# Continue training the same `model` object via model.fit_fast with eps = 8/255.
# The third positional argument (1) is presumably the epoch count -- confirm in src.models.
model.fit_fast(train_loader, test_loader, 1, device,eps = 8/255)

[1,     1] loss: 1.96565, train_accuracy: 28.91
[1,    11] loss: 1.97713, train_accuracy: 27.93
[1,    21] loss: 1.93998, train_accuracy: 27.34
[1,    31] loss: 1.85866, train_accuracy: 31.05
[1,    41] loss: 1.93476, train_accuracy: 29.88
[1,    51] loss: 1.81361, train_accuracy: 33.79
[1,    61] loss: 1.79394, train_accuracy: 34.96
[1,    71] loss: 1.87157, train_accuracy: 32.42
[1,    81] loss: 1.81604, train_accuracy: 33.98
[1,    91] loss: 1.78834, train_accuracy: 31.84
duration: 255 s - train loss: 1.84690 - train accuracy: 32.50 - validation loss: 1.53757 - validation accuracy: 44.68 
Finished Training


{'criterion': CrossEntropyLoss(),
 'optimizer': Adam (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     eps: 1e-08
     lr: 0.001
     weight_decay: 0
 ),
 'hist': 'Not implemented',
 'val_accuracy': 44.68}

In [6]:
# Continue training the same `model` object via the double-update variant
# with eps = 8/255. The third positional argument (1) is presumably the epoch
# count -- confirm in src.models.
model.fit_fast_with_double_update(train_loader, test_loader, 1, device,eps = 8/255)

[1,     1] loss: 1.70809, train_accuracy: 39.65
[1,    11] loss: 1.82264, train_accuracy: 33.59
[1,    21] loss: 1.73971, train_accuracy: 32.81
[1,    31] loss: 1.81496, train_accuracy: 35.94
[1,    41] loss: 1.78697, train_accuracy: 31.45
[1,    51] loss: 1.78453, train_accuracy: 33.20
[1,    61] loss: 1.72529, train_accuracy: 35.55
[1,    71] loss: 1.67665, train_accuracy: 38.67
[1,    81] loss: 1.74809, train_accuracy: 35.55
[1,    91] loss: 1.71799, train_accuracy: 38.67
duration: 320 s - train loss: 1.76333 - train accuracy: 35.67 - validation loss: 1.45301 - validation accuracy: 48.06 
Finished Training


{'criterion': CrossEntropyLoss(),
 'optimizer': Adam (
 Parameter Group 0
     amsgrad: False
     betas: (0.9, 0.999)
     eps: 1e-08
     lr: 0.001
     weight_decay: 0
 ),
 'hist': 'Not implemented',
 'val_accuracy': 48.06}