In [1]:
import time
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim
import torchvision.models as models
import eagerpy as ep
import foolbox.attacks as fb_attacks
from foolbox import PyTorchModel, accuracy, samples
# NOTE: the notebook's own FGSM/PGD functions shadow these two names once defined;
# use fb_attacks.FGSM / fb_attacks.PGD to reach the foolbox implementations.
from foolbox.attacks import PGD, FGSM
import matplotlib.pyplot as plt


from src.models  import CifarResNet, MNIST_CNN, CIFAR_CNN
from src.helpers import evaluate_rob_accuracy, evaluate_clean_accuracy, load_model, safe_model,_evaluate_model
from src.data_loader import load_torchvision_dataset, load_imagenette
#from src.pruning import identify_layers, _evaluate_sparsity

# Prefer the first GPU when one is available; models and batches below are moved to `device`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
dtype = torch.float32

cuda:0


In [2]:
def PGD(model, data_loader, criterion, steps, max_stepsize, eps, device, max_batches=None):
    """Untargeted L-inf PGD (iterated FGSM, no random start) over a data loader.

    Each step moves the adversarial batch by ``max_stepsize * sign(grad)``,
    clips it to [0, 1], projects the total perturbation back into the
    ``[-eps, eps]`` box around the clean batch and clips to [0, 1] again.

    Args:
        model: classifier mapping an input batch to logits; switched to eval mode.
        data_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss ``criterion(logits, labels)`` that the attack ascends.
        steps: number of signed-gradient steps per batch.
        max_stepsize: size of each signed step.
        eps: L-inf radius of the allowed perturbation.
        device: device the batches are moved to.
        max_batches: optional cap on the number of batches attacked
            (``None`` attacks the whole loader, as before).

    Returns:
        ``(advs, success_rate)``: list of detached adversarial batches and the
        fraction of attacked samples the model misclassifies (attack success
        rate, i.e. 1 - robust accuracy).
    """
    model.eval()
    advs = []
    fooled = 0
    total = 0
    for batch_idx, (inputs, labels) in enumerate(data_loader):
        if max_batches is not None and batch_idx >= max_batches:
            break
        inputs, labels = inputs.to(device), labels.to(device)
        # Work on a detached copy: the caller's tensor is never mutated and no
        # autograd history accumulates across steps or batches.
        clean = inputs.detach()
        adv_examples = clean.clone()
        for _ in range(steps):
            adv_examples.requires_grad_(True)
            loss = criterion(model(adv_examples), labels)
            # autograd.grad differentiates only w.r.t. the input, so the model's
            # parameter .grad buffers are left untouched.
            grad, = torch.autograd.grad(loss, adv_examples)
            with torch.no_grad():
                adv_examples = (adv_examples + max_stepsize * grad.sign()).clamp(0, 1)
                pert = (adv_examples - clean).clamp(-eps, eps)
                adv_examples = (clean + pert).clamp(0, 1)
        advs.append(adv_examples)
        with torch.no_grad():
            predicted = model(adv_examples).argmax(dim=1)
        total += predicted.numel()
        fooled += (predicted != labels).sum().item()
    return advs, fooled / total
        

def FGSM_step(model, inputs, labels, criterion, eps, device):
    """Single fast-gradient-sign step: ``adv = clip(inputs + eps * sign(grad), 0, 1)``.

    Args:
        model: classifier returning logits.
        inputs: input batch; it may or may not require grad and is not modified.
        labels: labels passed to ``criterion``.
        criterion: loss ``criterion(logits, labels)`` to ascend.
        eps: magnitude of the L-inf step.
        device: kept for backward compatibility; the gradient is computed on
            whatever device ``inputs`` already lives on.

    Returns:
        ``(adv_examples, perturbation)``, both detached from the autograd graph.
        ``perturbation`` is the step before the [0, 1] clip.
    """
    x = inputs.detach().requires_grad_(True)
    loss = criterion(model(x), labels)
    # Gradient w.r.t. the input only; model parameter .grad buffers stay untouched.
    grad, = torch.autograd.grad(loss, x)
    perturbation = eps * grad.sign()
    adv_examples = (x.detach() + perturbation).clamp(0, 1)
    return adv_examples, perturbation
    

def FGSM(model, data_loader, criterion, eps, device, max_batches=10):
    """Untargeted single-step FGSM attack over the first ``max_batches`` batches.

    Args:
        model: classifier returning logits; switched to eval mode.
        data_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss ``criterion(logits, labels)`` that the attack ascends.
        eps: L-inf size of the FGSM step.
        device: device the batches are moved to.
        max_batches: number of leading batches to attack (default 10, the
            previously hard-coded value); ``None`` attacks the whole loader.

    Returns:
        ``(advs, success_rate)``: list of detached adversarial batches and the
        fraction of attacked samples the model misclassifies (attack success
        rate, i.e. 1 - robust accuracy).
    """
    model.eval()
    advs = []
    fooled = 0
    total = 0
    for batch_idx, (inputs, labels) in enumerate(data_loader):
        if max_batches is not None and batch_idx >= max_batches:
            break  # stop instead of loading (and ignoring) the remaining batches
        inputs, labels = inputs.to(device), labels.to(device)
        # Fresh grad-enabled leaf so the caller's tensor is never mutated.
        inputs = inputs.detach().clone().requires_grad_(True)
        adv_examples, _ = FGSM_step(model, inputs, labels, criterion, eps, device)
        adv_examples = adv_examples.detach()
        advs.append(adv_examples)
        with torch.no_grad():
            predicted = model(adv_examples).argmax(dim=1)
        total += predicted.numel()
        fooled += (predicted != labels).sum().item()
    return advs, fooled / total

In [3]:
# CIFAR-10 ResNet on the selected device (nn.Module.to returns the module itself),
# plus augmented CIFAR-10 train/test loaders with batches of 128.
model = CifarResNet().to(device)
train_loader, test_loader = load_torchvision_dataset(
    'CIFAR10',
    data_augmentation=True,
    batchsize=128,
)

identifying layers
Files already downloaded and verified
Files already downloaded and verified


In [None]:
# Fast adversarial training via the project's CifarResNet.fit_fast; the third argument is
# presumably the epoch count (the checkpoint saved next is labelled "100-epochs") — confirm.
# The captured log shows per-batch adv/clean accuracy and periodic FGSM/PGD robustness.
train_stats = model.fit_fast(train_loader, test_loader, 100, device)

fast adversarial training
[1,     1] loss: 6.02972, adv_train_accuracy: 0.00, clean_train_accuracy : 12.50
[1,     6] loss: 23.97994, adv_train_accuracy: 10.16, clean_train_accuracy : 14.06
[1,    11] loss: 12.16122, adv_train_accuracy: 7.81, clean_train_accuracy : 10.16
[1,    16] loss: 5.40884, adv_train_accuracy: 3.91, clean_train_accuracy : 17.19
[1,    21] loss: 4.10593, adv_train_accuracy: 12.50, clean_train_accuracy : 7.03
[1,    26] loss: 3.54417, adv_train_accuracy: 7.03, clean_train_accuracy : 14.84
[1,    31] loss: 3.18705, adv_train_accuracy: 8.59, clean_train_accuracy : 17.97
[1,    36] loss: 2.91934, adv_train_accuracy: 8.59, clean_train_accuracy : 17.19
[1,    41] loss: 2.45873, adv_train_accuracy: 12.50, clean_train_accuracy : 26.56
[1,    46] loss: 2.44347, adv_train_accuracy: 13.28, clean_train_accuracy : 26.56
[1,    51] loss: 2.43216, adv_train_accuracy: 6.25, clean_train_accuracy : 25.78
[1,    56] loss: 2.55790, adv_train_accuracy: 13.28, clean_train_accuracy : 26

[2,    96] loss: 1.93943, adv_train_accuracy: 25.78, clean_train_accuracy : 35.94
[2,   101] loss: 2.04456, adv_train_accuracy: 23.44, clean_train_accuracy : 34.38
[2,   106] loss: 2.00830, adv_train_accuracy: 25.00, clean_train_accuracy : 34.38
[2,   111] loss: 2.18131, adv_train_accuracy: 19.53, clean_train_accuracy : 32.03
[2,   116] loss: 2.21146, adv_train_accuracy: 18.75, clean_train_accuracy : 35.94
[2,   121] loss: 2.20801, adv_train_accuracy: 21.09, clean_train_accuracy : 29.69
[2,   126] loss: 2.09440, adv_train_accuracy: 29.69, clean_train_accuracy : 39.84
[2,   131] loss: 1.97813, adv_train_accuracy: 32.03, clean_train_accuracy : 42.97
[2,   136] loss: 2.00226, adv_train_accuracy: 30.47, clean_train_accuracy : 40.62
[2,   141] loss: 1.99764, adv_train_accuracy: 18.75, clean_train_accuracy : 32.03
[2,   146] loss: 2.08147, adv_train_accuracy: 25.78, clean_train_accuracy : 34.38
[2,   151] loss: 2.13385, adv_train_accuracy: 22.66, clean_train_accuracy : 38.28
[2,   156] loss:

[3,   191] loss: 1.87925, adv_train_accuracy: 25.78, clean_train_accuracy : 39.84
[3,   196] loss: 2.11075, adv_train_accuracy: 24.22, clean_train_accuracy : 40.62
[3,   201] loss: 2.00470, adv_train_accuracy: 22.66, clean_train_accuracy : 34.38
[3,   206] loss: 1.88838, adv_train_accuracy: 25.78, clean_train_accuracy : 41.41
[3,   211] loss: 1.93359, adv_train_accuracy: 25.78, clean_train_accuracy : 47.66
[3,   216] loss: 1.90549, adv_train_accuracy: 22.66, clean_train_accuracy : 39.84
[3,   221] loss: 1.88701, adv_train_accuracy: 27.34, clean_train_accuracy : 43.75
[3,   226] loss: 1.87451, adv_train_accuracy: 25.00, clean_train_accuracy : 46.88
[3,   231] loss: 1.79554, adv_train_accuracy: 31.25, clean_train_accuracy : 54.69
[3,   236] loss: 1.95877, adv_train_accuracy: 28.91, clean_train_accuracy : 39.84
[3,   241] loss: 1.74557, adv_train_accuracy: 32.81, clean_train_accuracy : 51.56
[3,   246] loss: 1.80214, adv_train_accuracy: 35.16, clean_train_accuracy : 50.78
[3,   251] loss:

[4,   286] loss: 1.89068, adv_train_accuracy: 28.91, clean_train_accuracy : 45.31
[4,   291] loss: 1.84280, adv_train_accuracy: 33.59, clean_train_accuracy : 52.34
[4,   296] loss: 1.68853, adv_train_accuracy: 31.25, clean_train_accuracy : 53.91
[4,   301] loss: 1.88098, adv_train_accuracy: 28.91, clean_train_accuracy : 42.19
[4,   306] loss: 1.82418, adv_train_accuracy: 33.59, clean_train_accuracy : 47.66
[4,   311] loss: 1.82862, adv_train_accuracy: 28.91, clean_train_accuracy : 46.88
[4,   316] loss: 1.84265, adv_train_accuracy: 32.81, clean_train_accuracy : 51.56
[4,   321] loss: 1.75573, adv_train_accuracy: 30.47, clean_train_accuracy : 50.78
[4,   326] loss: 1.66568, adv_train_accuracy: 27.34, clean_train_accuracy : 55.47
[4,   331] loss: 1.76398, adv_train_accuracy: 31.25, clean_train_accuracy : 53.91
[4,   336] loss: 1.81587, adv_train_accuracy: 31.25, clean_train_accuracy : 47.66
[4,   341] loss: 1.90723, adv_train_accuracy: 32.03, clean_train_accuracy : 50.00
[4,   346] loss:

[5,   381] loss: 1.64703, adv_train_accuracy: 35.16, clean_train_accuracy : 51.56
[5,   386] loss: 1.83796, adv_train_accuracy: 34.38, clean_train_accuracy : 51.56
[5,   391] loss: 1.78908, adv_train_accuracy: 33.75, clean_train_accuracy : 52.50
fgsm robustness: 0.2919921875
pgd robustness: 0.28515625
duration: 66 s - train loss: 1.83108 - train accuracy: 31.08 - validation loss: 1.41064 - validation accuracy: 52.02 
[6,     1] loss: 1.74909, adv_train_accuracy: 31.25, clean_train_accuracy : 57.03
[6,     6] loss: 1.75517, adv_train_accuracy: 44.53, clean_train_accuracy : 56.25
[6,    11] loss: 1.76173, adv_train_accuracy: 29.69, clean_train_accuracy : 47.66
[6,    16] loss: 1.74355, adv_train_accuracy: 35.16, clean_train_accuracy : 53.12
[6,    21] loss: 1.60483, adv_train_accuracy: 39.06, clean_train_accuracy : 55.47
[6,    26] loss: 1.86983, adv_train_accuracy: 32.81, clean_train_accuracy : 50.00
[6,    31] loss: 1.73228, adv_train_accuracy: 40.62, clean_train_accuracy : 56.25
[6,  

[7,    71] loss: 1.65634, adv_train_accuracy: 34.38, clean_train_accuracy : 51.56
[7,    76] loss: 1.73583, adv_train_accuracy: 33.59, clean_train_accuracy : 60.16
[7,    81] loss: 1.64544, adv_train_accuracy: 35.94, clean_train_accuracy : 65.62
[7,    86] loss: 1.61416, adv_train_accuracy: 35.16, clean_train_accuracy : 62.50
[7,    91] loss: 1.69386, adv_train_accuracy: 37.50, clean_train_accuracy : 60.94
[7,    96] loss: 1.60838, adv_train_accuracy: 40.62, clean_train_accuracy : 60.16
[7,   101] loss: 1.65319, adv_train_accuracy: 34.38, clean_train_accuracy : 55.47
[7,   106] loss: 1.66574, adv_train_accuracy: 35.94, clean_train_accuracy : 59.38
[7,   111] loss: 1.68862, adv_train_accuracy: 35.16, clean_train_accuracy : 53.12
[7,   116] loss: 1.67765, adv_train_accuracy: 36.72, clean_train_accuracy : 53.12
[7,   121] loss: 1.79227, adv_train_accuracy: 30.47, clean_train_accuracy : 59.38
[7,   126] loss: 1.64225, adv_train_accuracy: 33.59, clean_train_accuracy : 58.59
[7,   131] loss:

[8,   166] loss: 1.61440, adv_train_accuracy: 40.62, clean_train_accuracy : 63.28
[8,   171] loss: 1.61175, adv_train_accuracy: 37.50, clean_train_accuracy : 60.94
[8,   176] loss: 1.65167, adv_train_accuracy: 40.62, clean_train_accuracy : 60.94
[8,   181] loss: 1.86702, adv_train_accuracy: 33.59, clean_train_accuracy : 53.12
[8,   186] loss: 1.59302, adv_train_accuracy: 38.28, clean_train_accuracy : 54.69
[8,   191] loss: 1.65623, adv_train_accuracy: 35.16, clean_train_accuracy : 61.72
[8,   196] loss: 1.59973, adv_train_accuracy: 44.53, clean_train_accuracy : 60.94
[8,   201] loss: 1.67164, adv_train_accuracy: 36.72, clean_train_accuracy : 57.81
[8,   206] loss: 1.65647, adv_train_accuracy: 37.50, clean_train_accuracy : 60.94
[8,   211] loss: 1.61483, adv_train_accuracy: 42.97, clean_train_accuracy : 61.72
[8,   216] loss: 1.73761, adv_train_accuracy: 34.38, clean_train_accuracy : 60.16
[8,   221] loss: 1.58694, adv_train_accuracy: 41.41, clean_train_accuracy : 65.62
[8,   226] loss:

[9,   261] loss: 1.55008, adv_train_accuracy: 40.62, clean_train_accuracy : 64.06
[9,   266] loss: 1.62982, adv_train_accuracy: 38.28, clean_train_accuracy : 62.50
[9,   271] loss: 1.63142, adv_train_accuracy: 40.62, clean_train_accuracy : 61.72
[9,   276] loss: 1.45562, adv_train_accuracy: 44.53, clean_train_accuracy : 64.84
[9,   281] loss: 1.51778, adv_train_accuracy: 39.06, clean_train_accuracy : 61.72
[9,   286] loss: 1.53470, adv_train_accuracy: 49.22, clean_train_accuracy : 67.19
[9,   291] loss: 1.56022, adv_train_accuracy: 42.19, clean_train_accuracy : 60.94
[9,   296] loss: 1.71897, adv_train_accuracy: 38.28, clean_train_accuracy : 59.38
[9,   301] loss: 1.75672, adv_train_accuracy: 35.16, clean_train_accuracy : 53.12
[9,   306] loss: 1.73727, adv_train_accuracy: 37.50, clean_train_accuracy : 56.25
[9,   311] loss: 1.80370, adv_train_accuracy: 30.47, clean_train_accuracy : 48.44
[9,   316] loss: 1.51787, adv_train_accuracy: 39.06, clean_train_accuracy : 64.06
[9,   321] loss:

[10,   351] loss: 1.53168, adv_train_accuracy: 41.41, clean_train_accuracy : 71.09
[10,   356] loss: 1.54066, adv_train_accuracy: 37.50, clean_train_accuracy : 65.62
[10,   361] loss: 1.65942, adv_train_accuracy: 42.19, clean_train_accuracy : 61.72
[10,   366] loss: 1.47596, adv_train_accuracy: 43.75, clean_train_accuracy : 66.41
[10,   371] loss: 1.41473, adv_train_accuracy: 40.62, clean_train_accuracy : 67.97
[10,   376] loss: 1.52782, adv_train_accuracy: 35.94, clean_train_accuracy : 68.75
[10,   381] loss: 1.54243, adv_train_accuracy: 39.06, clean_train_accuracy : 57.81
[10,   386] loss: 1.62695, adv_train_accuracy: 36.72, clean_train_accuracy : 57.81
[10,   391] loss: 1.63712, adv_train_accuracy: 32.50, clean_train_accuracy : 58.75
fgsm robustness: 0.3349609375
pgd robustness: 0.306640625
duration: 66 s - train loss: 1.56950 - train accuracy: 40.26 - validation loss: 1.08637 - validation accuracy: 62.60 
[11,     1] loss: 1.63439, adv_train_accuracy: 36.72, clean_train_accuracy : 

[12,    36] loss: 1.47895, adv_train_accuracy: 45.31, clean_train_accuracy : 68.75
[12,    41] loss: 1.56241, adv_train_accuracy: 38.28, clean_train_accuracy : 65.62
[12,    46] loss: 1.61782, adv_train_accuracy: 35.16, clean_train_accuracy : 69.53
[12,    51] loss: 1.48918, adv_train_accuracy: 39.84, clean_train_accuracy : 67.97
[12,    56] loss: 1.43232, adv_train_accuracy: 41.41, clean_train_accuracy : 68.75
[12,    61] loss: 1.59325, adv_train_accuracy: 40.62, clean_train_accuracy : 66.41
[12,    66] loss: 1.38067, adv_train_accuracy: 50.00, clean_train_accuracy : 68.75
[12,    71] loss: 1.52036, adv_train_accuracy: 46.88, clean_train_accuracy : 72.66
[12,    76] loss: 1.53542, adv_train_accuracy: 39.84, clean_train_accuracy : 69.53
[12,    81] loss: 1.62356, adv_train_accuracy: 44.53, clean_train_accuracy : 67.97
[12,    86] loss: 1.42203, adv_train_accuracy: 51.56, clean_train_accuracy : 68.75
[12,    91] loss: 1.56241, adv_train_accuracy: 39.84, clean_train_accuracy : 69.53
[12,

[13,   126] loss: 1.43151, adv_train_accuracy: 39.06, clean_train_accuracy : 68.75
[13,   131] loss: 1.43780, adv_train_accuracy: 43.75, clean_train_accuracy : 70.31
[13,   136] loss: 1.43585, adv_train_accuracy: 42.97, clean_train_accuracy : 73.44
[13,   141] loss: 1.58169, adv_train_accuracy: 39.06, clean_train_accuracy : 64.84
[13,   146] loss: 1.50875, adv_train_accuracy: 43.75, clean_train_accuracy : 76.56
[13,   151] loss: 1.41247, adv_train_accuracy: 46.88, clean_train_accuracy : 71.88
[13,   156] loss: 1.39940, adv_train_accuracy: 44.53, clean_train_accuracy : 73.44
[13,   161] loss: 1.61090, adv_train_accuracy: 39.06, clean_train_accuracy : 66.41
[13,   166] loss: 1.58837, adv_train_accuracy: 42.19, clean_train_accuracy : 67.97
[13,   171] loss: 1.57530, adv_train_accuracy: 40.62, clean_train_accuracy : 69.53
[13,   176] loss: 1.55639, adv_train_accuracy: 42.19, clean_train_accuracy : 71.09
[13,   181] loss: 1.57504, adv_train_accuracy: 41.41, clean_train_accuracy : 65.62
[13,

[14,   216] loss: 1.50912, adv_train_accuracy: 39.84, clean_train_accuracy : 62.50
[14,   221] loss: 1.40475, adv_train_accuracy: 51.56, clean_train_accuracy : 71.09
[14,   226] loss: 1.47505, adv_train_accuracy: 41.41, clean_train_accuracy : 67.19
[14,   231] loss: 1.33665, adv_train_accuracy: 38.28, clean_train_accuracy : 76.56
[14,   236] loss: 1.36819, adv_train_accuracy: 42.19, clean_train_accuracy : 71.09
[14,   241] loss: 1.49759, adv_train_accuracy: 43.75, clean_train_accuracy : 70.31
[14,   246] loss: 1.40536, adv_train_accuracy: 41.41, clean_train_accuracy : 79.69
[14,   251] loss: 1.41514, adv_train_accuracy: 40.62, clean_train_accuracy : 71.88
[14,   256] loss: 1.47551, adv_train_accuracy: 42.97, clean_train_accuracy : 67.97
[14,   261] loss: 1.44425, adv_train_accuracy: 41.41, clean_train_accuracy : 76.56
[14,   266] loss: 1.51098, adv_train_accuracy: 45.31, clean_train_accuracy : 64.84
[14,   271] loss: 1.32064, adv_train_accuracy: 51.56, clean_train_accuracy : 79.69
[14,

[15,   306] loss: 1.51015, adv_train_accuracy: 44.53, clean_train_accuracy : 70.31
[15,   311] loss: 1.52793, adv_train_accuracy: 48.44, clean_train_accuracy : 71.09
[15,   316] loss: 1.35136, adv_train_accuracy: 55.47, clean_train_accuracy : 69.53
[15,   321] loss: 1.35061, adv_train_accuracy: 46.88, clean_train_accuracy : 69.53
[15,   326] loss: 1.44624, adv_train_accuracy: 46.09, clean_train_accuracy : 71.88
[15,   331] loss: 1.39382, adv_train_accuracy: 49.22, clean_train_accuracy : 71.88
[15,   336] loss: 1.31800, adv_train_accuracy: 53.12, clean_train_accuracy : 73.44
[15,   341] loss: 1.29543, adv_train_accuracy: 46.88, clean_train_accuracy : 80.47
[15,   346] loss: 1.34911, adv_train_accuracy: 51.56, clean_train_accuracy : 77.34
[15,   351] loss: 1.36224, adv_train_accuracy: 46.88, clean_train_accuracy : 73.44
[15,   356] loss: 1.43852, adv_train_accuracy: 45.31, clean_train_accuracy : 76.56
[15,   361] loss: 1.49538, adv_train_accuracy: 43.75, clean_train_accuracy : 66.41
[15,

In [None]:
PATH = './saved-models/cifar-resnet-fast-100-epochs.pth'
torch.save({
        'epoch': 100,
        'model_state_dict': model.state_dict(),
        }, PATH)

In [None]:
train_stats1 = model.fit_fast(train_loader, test_loader, 100, device)

In [None]:
PATH = './saved-models/cifar-resnet-fast-200-epochs.pth'
torch.save({
        'epoch': 200,
        'model_state_dict': model.state_dict(),
        }, PATH)

In [None]:
train_stats2 = model.fit_fast(train_loader, test_loader, 100, device)

In [None]:
PATH = './saved-models/cifar-resnet-fast-300-epochs.pth'
torch.save({
        'epoch': 300,
        'model_state_dict': model.state_dict(),
        }, PATH)

In [None]:
train_stats3 = model.fit_fast(train_loader, test_loader, 100, device)

In [None]:
PATH = './saved-models/cifar-resnet-fast-400-epochs.pth'
torch.save({
        'epoch': 400,
        'model_state_dict': model.state_dict(),
        }, PATH)

In [None]:
train_stats4 = model.fit_fast(train_loader, test_loader, 100, device)

In [None]:
PATH = './saved-models/cifar-resnet-fast-500-epochs.pth'
torch.save({
        'epoch': 500,
        'model_state_dict': model.state_dict(),
        }, PATH)

In [None]:
train_stats5 = model.fit_fast(train_loader, test_loader, 100, device)

In [None]:
PATH = './saved-models/cifar-resnet-fast-600-epochs.pth'
torch.save({
        'epoch': 600,
        'model_state_dict': model.state_dict(),
        }, PATH)

In [None]:
train_stats3 = model.fit_fast(train_loader, test_loader, 100, device)

In [None]:
PATH = './saved-models/cifar-resnet-fast-700-epochs.pth'
torch.save({
        'epoch': 700,
        'model_state_dict': model.state_dict(),
        }, PATH)

In [None]:
checkpoint = torch.load('./saved-models/cifar-resnet-fast-100-epochs.pth')
checkpoint

In [None]:
model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
evaluate_clean_accuracy(model,test_loader, device)

In [None]:
x = next(iter(test_loader))[0][1]
plt.imshow(x.T)

In [None]:
f = f_adv[0][1].cpu().detach()
plt.imshow(f.T)

In [None]:
p = p_adv[0][1].cpu().detach()
plt.imshow(p.T)

In [None]:
torch.max(x - f)

In [None]:
torch.min(x - p)

In [None]:
7*2/255

In [None]:
f_adv, success = FGSM(model, test_loader, torch.nn.CrossEntropyLoss(), 8/255, device)
success

In [None]:
p_adv, success = PGD(model, test_loader, torch.nn.CrossEntropyLoss(), 7, 2/255, 8/255, device)
success

In [None]:
fmodel = PyTorchModel(model, bounds=(0, 1))
images, labels = next(iter(test_loader))
images, labels = images.to(device), labels.to(device)


In [None]:
clean_acc = accuracy(fmodel, images, labels)
print(f"clean accuracy:  {clean_acc * 100:.1f} %")

In [None]:
attack = FGSM()
epsilon = 8/255
raw_advs, clipped_advs, success = attack(fmodel, images, labels, epsilons=epsilon)

In [None]:
robust_accuracy = 1 - success.float().mean(axis=-1)

In [None]:
robust_accuracy

In [None]:
# Four further rounds of fast adversarial training (10 epochs each, patience=10), measuring
# FGSM and PGD attack success rates on the test set after every round. Replaces four
# copy-pasted train / FGSM / PGD cell triples.
robustness_log = []
for round_idx in range(1, 5):
    train_stats = model.fit_fast(train_loader, test_loader, 10, device, patience=10)
    _, fgsm_success = FGSM(model, test_loader, torch.nn.CrossEntropyLoss(), 8/255, device)
    _, pgd_success = PGD(model, test_loader, torch.nn.CrossEntropyLoss(), 7, 2/255, 8/255, device)
    robustness_log.append((round_idx, fgsm_success, pgd_success))
    print(f"round {round_idx}: FGSM success {fgsm_success:.4f}, PGD success {pgd_success:.4f}")
success = pgd_success  # same name/value the last original cell left behind
robustness_log

In [None]:

import torch.nn.functional as F
from torch import clamp

# Loss / optimizer used by loc_train; `opt` is bound to the parameters of the current `model`.
criterion = torch.nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters())

# Step size of the single FGSM step in loc_train. NOTE(review): 10 is far larger than
# epsilon = 8/255, so whenever the gradient is non-zero the step saturates at ±epsilon and
# the random start has no effect. Fast-AT (Wong et al., 2020) uses 10/255 — confirm intent.
alpha=10

cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2471, 0.2435, 0.2616)

# Per-channel normalisation constants, created on `device` rather than with a hard-coded
# .cuda() so this cell also runs on CPU-only machines.
mu = torch.tensor(cifar10_mean).view(3,1,1).to(device)
std = torch.tensor(cifar10_std).view(3,1,1).to(device)

# Valid input range in normalised space. Not used by loc_train, whose inputs are presumably
# raw [0, 1] pixels (the attacks above clamp to [0, 1] and foolbox uses bounds (0, 1)).
upper_limit = ((1 - mu)/ std)
lower_limit = ((0 - mu)/ std)

def loc_train(model, epochs, epsilon, loader=None, optimizer=None, loss_fn=None, step_size=None):
    """Fast adversarial training: random start plus one FGSM step per batch.

    For every batch: draw delta ~ U(-epsilon, epsilon), take one signed-gradient step
    of size ``step_size`` on delta, project it to ``[-epsilon, epsilon]`` and so that
    ``X + delta`` stays in [0, 1], then take an optimiser step on the loss at X + delta.

    Args:
        model: network trained in place (also returned).
        epochs: number of passes over ``loader``.
        epsilon: L-inf radius of the training perturbation.
        loader: ``(X, y)`` batches; defaults to the global ``train_loader``.
        optimizer: defaults to the global ``opt``.
        loss_fn: training loss; defaults to the global ``criterion``.
        step_size: FGSM step on delta; defaults to the global ``alpha``.

    Returns:
        The trained ``model``.
    """
    loader = train_loader if loader is None else loader
    optimizer = opt if optimizer is None else optimizer
    loss_fn = criterion if loss_fn is None else loss_fn
    step_size = alpha if step_size is None else step_size
    device = next(model.parameters()).device  # instead of hard-coded .cuda()

    # The attack helpers switch the model to eval(); training needs BatchNorm/Dropout in train mode.
    model.train()
    for epoch in range(epochs):
        start = time.time()
        train_loss, train_acc, train_n = 0.0, 0, 0
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            # Bounds keeping X + delta inside [0, 1] (assumes raw pixel inputs in [0, 1]).
            # The old clamp(delta, 0, 1) instead forbade negative perturbations.
            lo, hi = -X, 1 - X

            delta = torch.empty_like(X).uniform_(-epsilon, epsilon)
            delta = torch.max(torch.min(delta, hi), lo).requires_grad_(True)
            attack_loss = F.cross_entropy(model(X + delta), y)
            # Gradient w.r.t. delta only, so no stale gradients reach the model parameters.
            grad, = torch.autograd.grad(attack_loss, delta)
            with torch.no_grad():
                delta = (delta + step_size * grad.sign()).clamp(-epsilon, epsilon)
                delta = torch.max(torch.min(delta, hi), lo)

            output = model(X + delta)
            loss = loss_fn(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * y.size(0)
            train_acc += (output.argmax(dim=1) == y).sum().item()
            train_n += y.size(0)
        n = max(train_n, 1)
        print(f"epoch {epoch + 1}/{epochs}: adv loss {train_loss / n:.4f}, "
              f"adv acc {train_acc / n:.4f}, {time.time() - start:.0f} s")
    return model


In [None]:
model = loc_train(model, 10, 8/255)

In [None]:
evaluate_clean_accuracy(model, test_loader, device)

In [None]:
_, success = FGSM(model, test_loader, torch.nn.CrossEntropyLoss(), 8/255, device)
success

In [None]:
_, success = PGD(model, test_loader, torch.nn.CrossEntropyLoss(), 7, 2/255, 8/255, device)
success

In [None]:
# Debug: run PGD by hand on the first test batch and inspect the perturbation.
# Differences from the earlier scratch version:
# - steps use max_stepsize (not eps);
# - the *perturbation* is projected to [-eps, eps] (the old code clamped the image itself);
# - each step takes a fresh gradient (the old chained graph would likely fail on a later backward);
# - only the first batch is loaded instead of iterating the whole loader.
model, data_loader, criterion, steps, max_stepsize, eps, device = model, test_loader, torch.nn.CrossEntropyLoss(), 7, 2/255, 8/255, device

model.eval()
advs = []
inputs, labels = next(iter(data_loader))
inputs, labels = inputs.to(device), labels.to(device)
adv_examples = inputs.clone()
for step in range(steps):
    adv_examples.requires_grad_(True)
    loss = criterion(model(adv_examples), labels)
    grad, = torch.autograd.grad(loss, adv_examples)
    with torch.no_grad():
        adv_examples = (adv_examples + max_stepsize * grad.sign()).clamp(0, 1)
        adv_examples = (inputs + (adv_examples - inputs).clamp(-eps, eps)).clamp(0, 1)
advs.append(adv_examples)

# Sanity check: the largest per-pixel change should not exceed eps.
(advs[0] - inputs).abs().max()


