In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.optim as optim

from src.models import ResNet, MNIST_CNN, CIFAR_CNN
from src.helpers import evaluate_rob_accuracy, evaluate_clean_accuracy, load_model, safe_model,_evaluate_model
from src.data_loader import load_torchvision_dataset, load_imagenette
#from src.pruning import identify_layers, _evaluate_sparsity

import time

#device = torch.device("cuda:0")
device = torch.device("cpu")

dtype = torch.float32

# Initialization

In [16]:
#model = ResNet()
#model = MNIST_CNN()
model = CIFAR_CNN()

identifying layers


In [3]:
train_loader, test_loader = load_torchvision_dataset('CIFAR10')

Files already downloaded and verified
Files already downloaded and verified


In [None]:
PATH = './saved-models/CIFAR-baseline-150-epochs.pth'
model = load_model(model, PATH)

# Before

In [None]:
evaluate_clean_accuracy(model, test_loader, device)

In [None]:
evaluate_rob_accuracy(model, test_loader, device, epsilon=8/255, attack='FGSM')

# After

In [None]:
evaluate_clean_accuracy(model, test_loader, device)

In [None]:
evaluate_rob_accuracy(model, test_loader, device, epsilon=8/255, attack='FGSM')

In [None]:
#from torch.autograd import Variable
import time 


def fit_free(model, train_loader, val_loader , epochs, device, number_of_replays=3, eps = 16/255):
    mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    mean = torch.tensor(mean).view(3,1,1).expand(3,32,32)
    std = torch.tensor(std).view(3,1,1).expand(3,32,32)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters())
    
    pert_storage = torch.zeros((512,3,32,32))
    
    for epoch in range(epochs):  # loop over the dataset multiple times
        t0 = time.time()
        running_loss, acc_epoch_loss, avg_epoch_loss, epoch_accuracy, acc_epoch_accuracy = 0.0, 0.0, 0.0, 0.0, 0.0
        #pert_storage = torch.zeros([512, 3, 32,32])
        for i, data in enumerate(train_loader, 0):
            mini_batch_loss = 0.0
            mini_batch_acc = 0.0
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            
            no_of_samples_in_batch = inputs.shape[0]
            

            # Mini Batch Replays
            for j in range(number_of_replays):
                noise_batch = pert_storage[:no_of_samples_in_batch].detach().clone().requires_grad_(True)
                adv_input = inputs+noise_batch[:no_of_samples_in_batch]
                
                adv_input.clamp_(0, 1.0)
                adv_input.sub_(mean).div_(std)
                
                #print(adv_input[0])

                # forward + backward + optimize
                outputs = model(adv_input)
                loss = criterion(outputs, labels)
                
                # zero the gradients
                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                
                #craft adv pert
                pert = fgsm(noise_batch.grad)
                
                pert_storage[0:no_of_samples_in_batch] += pert
                pert_storage.clamp_(-eps, eps)
                #update weights
                optimizer.step()

                    # print statistics
                accuracy = get_accuracy(labels, outputs)
                mini_batch_loss += loss.item()
                mini_batch_acc += accuracy
                acc_loss += mini_batch_loss
                acc_accuracy += mini_batch_acc
                
        avg_mini_batch_accuracy = mini_batch_acc / number_of_replays    
        avg_mini_batch_loss = mini_batch_loss / number_of_replays
        
            if i%1 == 0:
                print('[%d, %5d] loss: %.5f, train_accuracy: %.2f' %(epoch + 1, i + 1, avg_mini_batch_loss, avg_mini_batch_accuracy))
            running_loss = 0.0
        avg_epoch_accuracy = acc_epoch_accuracy / (i+1)*number_of_replays    
        avg_epoch_loss = acc_epoch_loss / (i+1)*number_of_replays
        t1 = time.time()
        accuracy, loss = _evaluate_model(model, val_loader, device, criterion)
        print('duration:', t1-t0,'- train loss: ',avg_epoch_loss,' - train accuracy: ',avg_epoch_accuracy,' - validation accuracy: ', accuracy,' - validation loss: ', loss)
        print('duration: %d s - train loss: %.5f - train accuracy: %.2f - validation loss: %.2f - validation accuracy: %.2f ' %(t1-t0, avg_epoch_loss, avg_epoch_accuracy, loss, accuracy))
        
    print('Finished Training')
    return {
        'criterion': criterion,
        'optimizer': optimizer,
        'hist': 'Not implemented',
        'val_accuracy': accuracy
    }

def get_accuracy(labels, outputs):
    _, predicted = torch.max(outputs.data, 1)

    total = labels.size(0)
    correct = (predicted == labels).sum().item()
    return 100 * correct / total

def fgsm(gradients, step_size=.05):
    return step_size*torch.sign(gradients)

In [None]:
fit_free(model, train_loader, test_loader, 5, device, number_of_replays=3, eps = 16/255)

In [9]:
def fit_fast(model, train_loader, val_loader , epochs, device, eps = 8/255):
    mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    mean = torch.tensor(mean).view(3,1,1).expand(3,32,32)
    std = torch.tensor(std).view(3,1,1).expand(3,32,32)

    optimizer = optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss().to(device)
    for epoch in range(epochs):
        t0 = time.time()
        running_loss, acc_epoch_loss, avg_epoch_loss, epoch_accuracy, acc_epoch_accuracy = 0.0, 0.0, 0.0, 0.0, 0.0
        
        for i, data in enumerate(train_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            pert = torch.rand_like(inputs, requires_grad=True)
            adv_inputs = inputs + pert
            adv_inputs.clamp_(0, 1.0)
            adv_inputs.sub_(mean).div_(std)
            #clip 0,1
            
            # first backwards pass to perform fgsm
            outputs = model(adv_inputs)
            loss = criterion(outputs, labels)
            
            optimizer.zero_grad()
            loss.backward()
            pert = pert + (eps * pert.grad)
            pert.clamp_(-eps, eps)
            adv_inputs = inputs + pert
            
            # second backwards pass to update weights on adv.
            optimizer.zero_grad()
            outputs = model(adv_inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            accuracy = get_accuracy(labels, outputs)
            acc_epoch_loss += loss.item() 
            avg_epoch_loss = acc_epoch_loss / (i+1)
            acc_epoch_accuracy += accuracy
            avg_epoch_accuracy = acc_epoch_accuracy / (i+1)
            if i%1 == 0:
                print('[%d, %5d] loss: %.5f, train_accuracy: %.2f' %(epoch + 1, i + 1, loss.item(), accuracy))

        t1 = time.time()
        accuracy, loss = _evaluate_model(model, val_loader, device, criterion)

        print('duration: %d s - train loss: %.5f - train accuracy: %.2f - validation loss: %.5f - validation accuracy: %.2f ' %(t1-t0, avg_epoch_loss, avg_epoch_accuracy, loss, accuracy))

def get_accuracy(labels, outputs):
    _, predicted = torch.max(outputs.data, 1)

    total = labels.size(0)
    correct = (predicted == labels).sum().item()
    return 100 * correct / total
            
            

In [10]:
fit_fast(model, train_loader, test_loader, 30, device)

[1,     1] loss: 2.71994, train_accuracy: 10.94
[1,     2] loss: 2.61876, train_accuracy: 11.52
[1,     3] loss: 2.53445, train_accuracy: 14.45
[1,     4] loss: 2.59530, train_accuracy: 14.65
[1,     5] loss: 2.36077, train_accuracy: 16.21
[1,     6] loss: 2.42841, train_accuracy: 18.36
[1,     7] loss: 2.32086, train_accuracy: 18.95
[1,     8] loss: 2.23608, train_accuracy: 19.14
[1,     9] loss: 2.29884, train_accuracy: 20.51
[1,    10] loss: 2.16112, train_accuracy: 22.46
[1,    11] loss: 2.22307, train_accuracy: 22.27
[1,    12] loss: 2.22280, train_accuracy: 21.88
[1,    13] loss: 2.20852, train_accuracy: 23.63
[1,    14] loss: 2.12756, train_accuracy: 23.05
[1,    15] loss: 2.12157, train_accuracy: 22.27
[1,    16] loss: 2.08616, train_accuracy: 23.05
[1,    17] loss: 2.03543, train_accuracy: 26.56
[1,    18] loss: 2.10694, train_accuracy: 23.63
[1,    19] loss: 2.14892, train_accuracy: 22.07
[1,    20] loss: 2.06685, train_accuracy: 26.76
[1,    21] loss: 2.14502, train_accuracy

[2,    72] loss: 1.76227, train_accuracy: 38.09
[2,    73] loss: 1.86249, train_accuracy: 30.86
[2,    74] loss: 1.81873, train_accuracy: 32.62
[2,    75] loss: 1.85440, train_accuracy: 33.98
[2,    76] loss: 1.83829, train_accuracy: 30.66
[2,    77] loss: 1.81324, train_accuracy: 34.18
[2,    78] loss: 1.81008, train_accuracy: 31.64
[2,    79] loss: 1.84453, train_accuracy: 33.59
[2,    80] loss: 1.80842, train_accuracy: 33.20
[2,    81] loss: 1.85633, train_accuracy: 33.59
[2,    82] loss: 1.81781, train_accuracy: 31.25
[2,    83] loss: 1.81516, train_accuracy: 36.52
[2,    84] loss: 1.78259, train_accuracy: 39.06
[2,    85] loss: 1.76345, train_accuracy: 34.38
[2,    86] loss: 1.74368, train_accuracy: 35.55
[2,    87] loss: 1.81582, train_accuracy: 33.98
[2,    88] loss: 1.70904, train_accuracy: 37.11
[2,    89] loss: 1.81688, train_accuracy: 33.20
[2,    90] loss: 1.83271, train_accuracy: 33.20
[2,    91] loss: 1.73621, train_accuracy: 37.11
[2,    92] loss: 1.87143, train_accuracy

[4,    42] loss: 1.73618, train_accuracy: 39.26
[4,    43] loss: 1.76022, train_accuracy: 35.55
[4,    44] loss: 1.68306, train_accuracy: 38.87
[4,    45] loss: 1.71588, train_accuracy: 36.91
[4,    46] loss: 1.74166, train_accuracy: 36.33
[4,    47] loss: 1.71446, train_accuracy: 36.72
[4,    48] loss: 1.74213, train_accuracy: 38.87
[4,    49] loss: 1.69143, train_accuracy: 36.52
[4,    50] loss: 1.68440, train_accuracy: 38.48
[4,    51] loss: 1.69154, train_accuracy: 35.94
[4,    52] loss: 1.69785, train_accuracy: 38.09
[4,    53] loss: 1.74708, train_accuracy: 39.06
[4,    54] loss: 1.75130, train_accuracy: 36.91
[4,    55] loss: 1.66772, train_accuracy: 40.43
[4,    56] loss: 1.66232, train_accuracy: 38.28
[4,    57] loss: 1.68866, train_accuracy: 38.87
[4,    58] loss: 1.69049, train_accuracy: 38.48
[4,    59] loss: 1.68906, train_accuracy: 41.02
[4,    60] loss: 1.67436, train_accuracy: 39.65
[4,    61] loss: 1.73356, train_accuracy: 39.26
[4,    62] loss: 1.69483, train_accuracy

[6,    12] loss: 1.58701, train_accuracy: 44.14
[6,    13] loss: 1.61550, train_accuracy: 41.02
[6,    14] loss: 1.64899, train_accuracy: 43.55
[6,    15] loss: 1.63424, train_accuracy: 39.84
[6,    16] loss: 1.58944, train_accuracy: 44.14
[6,    17] loss: 1.63284, train_accuracy: 42.58
[6,    18] loss: 1.68672, train_accuracy: 39.26
[6,    19] loss: 1.70584, train_accuracy: 37.11
[6,    20] loss: 1.59851, train_accuracy: 42.97
[6,    21] loss: 1.62118, train_accuracy: 40.23
[6,    22] loss: 1.69037, train_accuracy: 36.91
[6,    23] loss: 1.60955, train_accuracy: 41.41
[6,    24] loss: 1.64143, train_accuracy: 41.02
[6,    25] loss: 1.69225, train_accuracy: 39.26
[6,    26] loss: 1.63175, train_accuracy: 40.43
[6,    27] loss: 1.68015, train_accuracy: 40.43
[6,    28] loss: 1.58959, train_accuracy: 41.41
[6,    29] loss: 1.64923, train_accuracy: 42.38
[6,    30] loss: 1.58470, train_accuracy: 44.14
[6,    31] loss: 1.66978, train_accuracy: 39.45
[6,    32] loss: 1.58536, train_accuracy

[7,    83] loss: 1.53868, train_accuracy: 44.92
[7,    84] loss: 1.56656, train_accuracy: 44.34
[7,    85] loss: 1.59706, train_accuracy: 42.38
[7,    86] loss: 1.57147, train_accuracy: 41.41
[7,    87] loss: 1.67428, train_accuracy: 40.43
[7,    88] loss: 1.55674, train_accuracy: 44.34
[7,    89] loss: 1.59070, train_accuracy: 41.80
[7,    90] loss: 1.55652, train_accuracy: 45.12
[7,    91] loss: 1.55550, train_accuracy: 40.82
[7,    92] loss: 1.56896, train_accuracy: 42.19
[7,    93] loss: 1.55221, train_accuracy: 44.92
[7,    94] loss: 1.65357, train_accuracy: 42.97
[7,    95] loss: 1.65610, train_accuracy: 41.99
[7,    96] loss: 1.56854, train_accuracy: 45.90
[7,    97] loss: 1.54567, train_accuracy: 44.34
[7,    98] loss: 1.58721, train_accuracy: 43.15
duration: 311 s - train loss: 1.59302 - train accuracy: 42.80 - validation loss: 1.30843 - validation accuracy: 53.80 
[8,     1] loss: 1.62711, train_accuracy: 43.55
[8,     2] loss: 1.55907, train_accuracy: 45.12
[8,     3] loss: 

[9,    53] loss: 1.52709, train_accuracy: 47.07
[9,    54] loss: 1.53556, train_accuracy: 42.97
[9,    55] loss: 1.45234, train_accuracy: 49.02
[9,    56] loss: 1.65688, train_accuracy: 40.62
[9,    57] loss: 1.47487, train_accuracy: 45.90
[9,    58] loss: 1.50177, train_accuracy: 44.92
[9,    59] loss: 1.56171, train_accuracy: 45.31
[9,    60] loss: 1.61408, train_accuracy: 40.04
[9,    61] loss: 1.53746, train_accuracy: 45.12
[9,    62] loss: 1.52420, train_accuracy: 45.51
[9,    63] loss: 1.56056, train_accuracy: 47.07
[9,    64] loss: 1.56901, train_accuracy: 42.58
[9,    65] loss: 1.45738, train_accuracy: 45.90
[9,    66] loss: 1.52738, train_accuracy: 46.29
[9,    67] loss: 1.53548, train_accuracy: 44.92
[9,    68] loss: 1.57493, train_accuracy: 43.55
[9,    69] loss: 1.59838, train_accuracy: 43.55
[9,    70] loss: 1.55435, train_accuracy: 44.92
[9,    71] loss: 1.55522, train_accuracy: 45.51
[9,    72] loss: 1.52632, train_accuracy: 47.27
[9,    73] loss: 1.42662, train_accuracy

[11,    21] loss: 1.47770, train_accuracy: 44.73
[11,    22] loss: 1.55279, train_accuracy: 42.38
[11,    23] loss: 1.52089, train_accuracy: 47.85
[11,    24] loss: 1.53443, train_accuracy: 42.38
[11,    25] loss: 1.53428, train_accuracy: 42.19
[11,    26] loss: 1.51575, train_accuracy: 45.31
[11,    27] loss: 1.50365, train_accuracy: 49.22
[11,    28] loss: 1.52067, train_accuracy: 44.53
[11,    29] loss: 1.43697, train_accuracy: 46.68
[11,    30] loss: 1.48955, train_accuracy: 47.85
[11,    31] loss: 1.52549, train_accuracy: 45.51
[11,    32] loss: 1.46419, train_accuracy: 48.05
[11,    33] loss: 1.48341, train_accuracy: 48.44
[11,    34] loss: 1.50194, train_accuracy: 46.29
[11,    35] loss: 1.50634, train_accuracy: 44.92
[11,    36] loss: 1.49921, train_accuracy: 43.75
[11,    37] loss: 1.52623, train_accuracy: 46.48
[11,    38] loss: 1.51475, train_accuracy: 42.58
[11,    39] loss: 1.68965, train_accuracy: 39.26
[11,    40] loss: 1.57328, train_accuracy: 41.80
[11,    41] loss: 1.

[12,    88] loss: 1.48327, train_accuracy: 49.02
[12,    89] loss: 1.53877, train_accuracy: 43.75
[12,    90] loss: 1.42880, train_accuracy: 47.27
[12,    91] loss: 1.47767, train_accuracy: 45.70
[12,    92] loss: 1.53381, train_accuracy: 44.73
[12,    93] loss: 1.46201, train_accuracy: 48.44
[12,    94] loss: 1.42248, train_accuracy: 49.61
[12,    95] loss: 1.54301, train_accuracy: 46.68
[12,    96] loss: 1.55282, train_accuracy: 43.95
[12,    97] loss: 1.45085, train_accuracy: 48.63
[12,    98] loss: 1.54800, train_accuracy: 40.48
duration: 248 s - train loss: 1.49050 - train accuracy: 46.36 - validation loss: 1.21115 - validation accuracy: 57.74 
[13,     1] loss: 1.50183, train_accuracy: 44.92
[13,     2] loss: 1.45920, train_accuracy: 45.51
[13,     3] loss: 1.53477, train_accuracy: 43.55
[13,     4] loss: 1.53524, train_accuracy: 43.36
[13,     5] loss: 1.52964, train_accuracy: 44.34
[13,     6] loss: 1.55263, train_accuracy: 45.31
[13,     7] loss: 1.44774, train_accuracy: 48.24

[14,    55] loss: 1.44709, train_accuracy: 48.24
[14,    56] loss: 1.50530, train_accuracy: 48.24
[14,    57] loss: 1.44458, train_accuracy: 50.00
[14,    58] loss: 1.47619, train_accuracy: 48.24
[14,    59] loss: 1.49019, train_accuracy: 47.07
[14,    60] loss: 1.48553, train_accuracy: 45.70
[14,    61] loss: 1.44604, train_accuracy: 48.44
[14,    62] loss: 1.46936, train_accuracy: 49.22
[14,    63] loss: 1.44202, train_accuracy: 49.61
[14,    64] loss: 1.53113, train_accuracy: 44.73
[14,    65] loss: 1.44732, train_accuracy: 49.80
[14,    66] loss: 1.47255, train_accuracy: 46.48
[14,    67] loss: 1.42806, train_accuracy: 48.83
[14,    68] loss: 1.46718, train_accuracy: 47.85
[14,    69] loss: 1.44826, train_accuracy: 48.24
[14,    70] loss: 1.47909, train_accuracy: 44.14
[14,    71] loss: 1.49136, train_accuracy: 47.46
[14,    72] loss: 1.42486, train_accuracy: 48.83
[14,    73] loss: 1.45955, train_accuracy: 48.05
[14,    74] loss: 1.55220, train_accuracy: 43.75
[14,    75] loss: 1.

[16,    22] loss: 1.51048, train_accuracy: 45.70
[16,    23] loss: 1.43940, train_accuracy: 49.22
[16,    24] loss: 1.47919, train_accuracy: 49.02
[16,    25] loss: 1.40866, train_accuracy: 47.85
[16,    26] loss: 1.40640, train_accuracy: 50.98
[16,    27] loss: 1.40383, train_accuracy: 49.61
[16,    28] loss: 1.45240, train_accuracy: 49.22
[16,    29] loss: 1.46592, train_accuracy: 48.05
[16,    30] loss: 1.48962, train_accuracy: 44.53
[16,    31] loss: 1.35672, train_accuracy: 49.41
[16,    32] loss: 1.43077, train_accuracy: 48.24
[16,    33] loss: 1.35557, train_accuracy: 53.52
[16,    34] loss: 1.38205, train_accuracy: 49.61
[16,    35] loss: 1.36384, train_accuracy: 50.78
[16,    36] loss: 1.40021, train_accuracy: 48.63
[16,    37] loss: 1.41831, train_accuracy: 49.02
[16,    38] loss: 1.38891, train_accuracy: 49.02
[16,    39] loss: 1.49110, train_accuracy: 46.48
[16,    40] loss: 1.50119, train_accuracy: 45.12
[16,    41] loss: 1.40189, train_accuracy: 49.80
[16,    42] loss: 1.

[17,    89] loss: 1.43328, train_accuracy: 49.22
[17,    90] loss: 1.48703, train_accuracy: 45.51
[17,    91] loss: 1.42229, train_accuracy: 45.70
[17,    92] loss: 1.47844, train_accuracy: 44.53
[17,    93] loss: 1.38381, train_accuracy: 49.80
[17,    94] loss: 1.46234, train_accuracy: 48.63
[17,    95] loss: 1.48034, train_accuracy: 46.29
[17,    96] loss: 1.40007, train_accuracy: 47.27
[17,    97] loss: 1.35713, train_accuracy: 51.56
[17,    98] loss: 1.36535, train_accuracy: 50.60
duration: 206 s - train loss: 1.42897 - train accuracy: 48.74 - validation loss: 1.16103 - validation accuracy: 59.60 
[18,     1] loss: 1.49609, train_accuracy: 44.14
[18,     2] loss: 1.41737, train_accuracy: 50.59
[18,     3] loss: 1.37069, train_accuracy: 50.20
[18,     4] loss: 1.40700, train_accuracy: 50.00
[18,     5] loss: 1.46402, train_accuracy: 48.05
[18,     6] loss: 1.44220, train_accuracy: 47.85
[18,     7] loss: 1.49557, train_accuracy: 46.88
[18,     8] loss: 1.42554, train_accuracy: 46.68

[19,    56] loss: 1.41390, train_accuracy: 49.22
[19,    57] loss: 1.32802, train_accuracy: 53.71
[19,    58] loss: 1.51096, train_accuracy: 43.95
[19,    59] loss: 1.42181, train_accuracy: 48.05
[19,    60] loss: 1.38823, train_accuracy: 50.78
[19,    61] loss: 1.42689, train_accuracy: 49.80
[19,    62] loss: 1.33178, train_accuracy: 53.12
[19,    63] loss: 1.47310, train_accuracy: 46.88
[19,    64] loss: 1.38703, train_accuracy: 52.54
[19,    65] loss: 1.40486, train_accuracy: 51.76
[19,    66] loss: 1.51451, train_accuracy: 44.53
[19,    67] loss: 1.38968, train_accuracy: 52.34
[19,    68] loss: 1.44401, train_accuracy: 49.41
[19,    69] loss: 1.37541, train_accuracy: 50.98
[19,    70] loss: 1.40527, train_accuracy: 50.00
[19,    71] loss: 1.38859, train_accuracy: 48.44
[19,    72] loss: 1.47928, train_accuracy: 49.41
[19,    73] loss: 1.40956, train_accuracy: 50.98
[19,    74] loss: 1.50131, train_accuracy: 48.44
[19,    75] loss: 1.37557, train_accuracy: 52.34
[19,    76] loss: 1.

[21,    23] loss: 1.39318, train_accuracy: 51.56
[21,    24] loss: 1.34784, train_accuracy: 51.17
[21,    25] loss: 1.36889, train_accuracy: 49.80
[21,    26] loss: 1.31355, train_accuracy: 53.32
[21,    27] loss: 1.50495, train_accuracy: 44.53
[21,    28] loss: 1.37583, train_accuracy: 50.78
[21,    29] loss: 1.42740, train_accuracy: 50.59
[21,    30] loss: 1.40718, train_accuracy: 50.59
[21,    31] loss: 1.38395, train_accuracy: 49.80
[21,    32] loss: 1.39858, train_accuracy: 50.78
[21,    33] loss: 1.36348, train_accuracy: 51.37
[21,    34] loss: 1.41605, train_accuracy: 49.41
[21,    35] loss: 1.37295, train_accuracy: 52.73
[21,    36] loss: 1.53956, train_accuracy: 44.73
[21,    37] loss: 1.35611, train_accuracy: 49.02
[21,    38] loss: 1.39824, train_accuracy: 49.41
[21,    39] loss: 1.37087, train_accuracy: 50.39
[21,    40] loss: 1.44813, train_accuracy: 48.83
[21,    41] loss: 1.37327, train_accuracy: 49.80
[21,    42] loss: 1.37024, train_accuracy: 48.24
[21,    43] loss: 1.

[22,    90] loss: 1.37420, train_accuracy: 51.56
[22,    91] loss: 1.43241, train_accuracy: 47.27
[22,    92] loss: 1.37250, train_accuracy: 51.56
[22,    93] loss: 1.41650, train_accuracy: 48.05
[22,    94] loss: 1.35650, train_accuracy: 52.34
[22,    95] loss: 1.34330, train_accuracy: 54.88
[22,    96] loss: 1.37999, train_accuracy: 49.02
[22,    97] loss: 1.34616, train_accuracy: 49.61
[22,    98] loss: 1.53462, train_accuracy: 46.13
duration: 247 s - train loss: 1.38492 - train accuracy: 50.37 - validation loss: 1.09739 - validation accuracy: 61.75 
[23,     1] loss: 1.37841, train_accuracy: 50.20
[23,     2] loss: 1.36544, train_accuracy: 52.15
[23,     3] loss: 1.34147, train_accuracy: 53.91
[23,     4] loss: 1.37322, train_accuracy: 51.76
[23,     5] loss: 1.39601, train_accuracy: 52.54
[23,     6] loss: 1.35780, train_accuracy: 52.54
[23,     7] loss: 1.46345, train_accuracy: 48.63
[23,     8] loss: 1.40890, train_accuracy: 49.41
[23,     9] loss: 1.30874, train_accuracy: 53.52

[24,    57] loss: 1.36182, train_accuracy: 48.05
[24,    58] loss: 1.39581, train_accuracy: 51.56
[24,    59] loss: 1.44826, train_accuracy: 48.63
[24,    60] loss: 1.47228, train_accuracy: 49.41
[24,    61] loss: 1.37938, train_accuracy: 50.78
[24,    62] loss: 1.47112, train_accuracy: 46.09
[24,    63] loss: 1.38379, train_accuracy: 52.93
[24,    64] loss: 1.38299, train_accuracy: 49.02
[24,    65] loss: 1.43811, train_accuracy: 49.02
[24,    66] loss: 1.35885, train_accuracy: 52.34
[24,    67] loss: 1.39705, train_accuracy: 49.61
[24,    68] loss: 1.45891, train_accuracy: 47.07
[24,    69] loss: 1.42565, train_accuracy: 50.00
[24,    70] loss: 1.36252, train_accuracy: 48.05
[24,    71] loss: 1.37556, train_accuracy: 50.59
[24,    72] loss: 1.35249, train_accuracy: 51.37
[24,    73] loss: 1.37189, train_accuracy: 50.00
[24,    74] loss: 1.40233, train_accuracy: 53.52
[24,    75] loss: 1.33488, train_accuracy: 54.10
[24,    76] loss: 1.39068, train_accuracy: 50.20
[24,    77] loss: 1.

[26,    24] loss: 1.36204, train_accuracy: 51.37
[26,    25] loss: 1.41128, train_accuracy: 48.83
[26,    26] loss: 1.27991, train_accuracy: 53.71
[26,    27] loss: 1.39686, train_accuracy: 48.83
[26,    28] loss: 1.25301, train_accuracy: 57.62
[26,    29] loss: 1.36480, train_accuracy: 52.54
[26,    30] loss: 1.38644, train_accuracy: 52.73
[26,    31] loss: 1.30292, train_accuracy: 55.86
[26,    32] loss: 1.33502, train_accuracy: 51.76
[26,    33] loss: 1.39540, train_accuracy: 50.39
[26,    34] loss: 1.37006, train_accuracy: 49.02
[26,    35] loss: 1.41497, train_accuracy: 50.98
[26,    36] loss: 1.27069, train_accuracy: 52.93
[26,    37] loss: 1.43153, train_accuracy: 49.80
[26,    38] loss: 1.32576, train_accuracy: 55.27
[26,    39] loss: 1.35966, train_accuracy: 52.73
[26,    40] loss: 1.43597, train_accuracy: 50.39
[26,    41] loss: 1.32622, train_accuracy: 54.69
[26,    42] loss: 1.40067, train_accuracy: 50.20
[26,    43] loss: 1.38353, train_accuracy: 49.41
[26,    44] loss: 1.

[27,    91] loss: 1.26440, train_accuracy: 55.66
[27,    92] loss: 1.39946, train_accuracy: 51.56
[27,    93] loss: 1.37328, train_accuracy: 51.37
[27,    94] loss: 1.36164, train_accuracy: 54.10
[27,    95] loss: 1.41173, train_accuracy: 50.59
[27,    96] loss: 1.47326, train_accuracy: 47.46
[27,    97] loss: 1.39184, train_accuracy: 50.78
[27,    98] loss: 1.44490, train_accuracy: 44.94
duration: 436 s - train loss: 1.35926 - train accuracy: 51.61 - validation loss: 1.06289 - validation accuracy: 63.40 
[28,     1] loss: 1.35943, train_accuracy: 50.78
[28,     2] loss: 1.37766, train_accuracy: 52.93
[28,     3] loss: 1.37782, train_accuracy: 50.59
[28,     4] loss: 1.39367, train_accuracy: 49.80
[28,     5] loss: 1.37050, train_accuracy: 50.59
[28,     6] loss: 1.31929, train_accuracy: 50.59
[28,     7] loss: 1.34752, train_accuracy: 49.80
[28,     8] loss: 1.30646, train_accuracy: 52.73
[28,     9] loss: 1.31796, train_accuracy: 55.08
[28,    10] loss: 1.25061, train_accuracy: 52.93

[29,    58] loss: 1.32113, train_accuracy: 52.73
[29,    59] loss: 1.33966, train_accuracy: 54.30
[29,    60] loss: 1.33541, train_accuracy: 53.12
[29,    61] loss: 1.36433, train_accuracy: 49.61
[29,    62] loss: 1.34924, train_accuracy: 54.10
[29,    63] loss: 1.32799, train_accuracy: 52.34
[29,    64] loss: 1.33186, train_accuracy: 53.32
[29,    65] loss: 1.32990, train_accuracy: 51.37
[29,    66] loss: 1.34630, train_accuracy: 53.12
[29,    67] loss: 1.37193, train_accuracy: 49.41
[29,    68] loss: 1.38159, train_accuracy: 48.24
[29,    69] loss: 1.34957, train_accuracy: 52.34
[29,    70] loss: 1.31529, train_accuracy: 50.00
[29,    71] loss: 1.43074, train_accuracy: 49.80
[29,    72] loss: 1.33828, train_accuracy: 51.56
[29,    73] loss: 1.34556, train_accuracy: 53.12
[29,    74] loss: 1.39359, train_accuracy: 50.98
[29,    75] loss: 1.36077, train_accuracy: 50.00
[29,    76] loss: 1.40180, train_accuracy: 50.39
[29,    77] loss: 1.34148, train_accuracy: 52.73
[29,    78] loss: 1.

In [11]:
evaluate_clean_accuracy(model, test_loader, device)

(63.87, 0.0)

In [15]:
evaluate_rob_accuracy(model, test_loader, device, epsilon=8/255, attack='PGD')

58.5

In [17]:
def fit_fast_with_double_update(model, train_loader, val_loader , epochs, device, number_of_replays=3, eps = 16/255):
    mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    mean = torch.tensor(mean).view(3,1,1).expand(3,32,32)
    std = torch.tensor(std).view(3,1,1).expand(3,32,32)

    optimizer = optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss().to(device)
    for epoch in range(epochs):
        t0 = time.time()
        running_loss, acc_epoch_loss, avg_epoch_loss, epoch_accuracy, acc_epoch_accuracy = 0.0, 0.0, 0.0, 0.0, 0.0
        
        for i, data in enumerate(train_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            pert = torch.rand_like(inputs, requires_grad=True)
            adv_inputs = inputs + pert
            adv_inputs.clamp_(0, 1.0)
            adv_inputs.sub_(mean).div_(std)
            #clip 0,1
            
            # first backwards pass to perform fgsm
            outputs = model(adv_inputs)
            loss = criterion(outputs, labels)
            
            optimizer.zero_grad()
            loss.backward()
            pert = pert + (eps * pert.grad)
            pert.clamp_(-eps, eps)
            adv_inputs = inputs + pert
            optimizer.step()
            
            # second backwards pass to update weights on adv.
            optimizer.zero_grad()
            outputs = model(adv_inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            accuracy = get_accuracy(labels, outputs)
            acc_epoch_loss += loss.item() 
            avg_epoch_loss = acc_epoch_loss / (i+1)
            acc_epoch_accuracy += accuracy
            avg_epoch_accuracy = acc_epoch_accuracy / (i+1)
            if i%1 == 0:
                print('[%d, %5d] loss: %.5f, train_accuracy: %.2f' %(epoch + 1, i + 1, loss.item(), accuracy))

        t1 = time.time()
        accuracy, loss = _evaluate_model(model, val_loader, device, criterion)

        print('duration: %d s - train loss: %.5f - train accuracy: %.2f - validation loss: %.5f - validation accuracy: %.2f ' %(t1-t0, avg_epoch_loss, avg_epoch_accuracy, loss, accuracy))

            
            

In [None]:
fit_fast_with_double_update(model, train_loader, test_loader, 30, device, number_of_replays=3, eps = 16/255)

[1,     1] loss: 2.66116, train_accuracy: 12.11
[1,     2] loss: 2.55891, train_accuracy: 19.14
[1,     3] loss: 2.44707, train_accuracy: 18.75
[1,     4] loss: 2.29412, train_accuracy: 21.88
[1,     5] loss: 2.28044, train_accuracy: 21.29
[1,     6] loss: 2.23886, train_accuracy: 20.31
[1,     7] loss: 2.13059, train_accuracy: 24.61
[1,     8] loss: 2.27068, train_accuracy: 23.63
[1,     9] loss: 2.13736, train_accuracy: 26.17
[1,    10] loss: 2.14453, train_accuracy: 24.41
[1,    11] loss: 2.19443, train_accuracy: 23.63
[1,    12] loss: 2.08948, train_accuracy: 27.34
[1,    13] loss: 2.10452, train_accuracy: 23.05
[1,    14] loss: 2.00129, train_accuracy: 30.08
[1,    15] loss: 2.02556, train_accuracy: 25.98
[1,    16] loss: 2.09792, train_accuracy: 26.37
[1,    17] loss: 2.06948, train_accuracy: 23.83
[1,    18] loss: 1.96625, train_accuracy: 29.49
[1,    19] loss: 1.99748, train_accuracy: 26.17
[1,    20] loss: 2.05507, train_accuracy: 26.37
[1,    21] loss: 2.07647, train_accuracy

[2,    72] loss: 1.67464, train_accuracy: 38.87
[2,    73] loss: 1.67981, train_accuracy: 40.04
[2,    74] loss: 1.66716, train_accuracy: 37.30
[2,    75] loss: 1.77758, train_accuracy: 35.55
[2,    76] loss: 1.74583, train_accuracy: 37.50
[2,    77] loss: 1.67587, train_accuracy: 37.89
[2,    78] loss: 1.74530, train_accuracy: 37.70
[2,    79] loss: 1.74576, train_accuracy: 36.13
[2,    80] loss: 1.67584, train_accuracy: 35.94
[2,    81] loss: 1.74719, train_accuracy: 40.23
[2,    82] loss: 1.70323, train_accuracy: 39.84
[2,    83] loss: 1.67273, train_accuracy: 38.09
[2,    84] loss: 1.77330, train_accuracy: 34.57
[2,    85] loss: 1.71300, train_accuracy: 38.87
[2,    86] loss: 1.71906, train_accuracy: 37.11
[2,    87] loss: 1.68709, train_accuracy: 40.43
[2,    88] loss: 1.71389, train_accuracy: 37.50
[2,    89] loss: 1.71178, train_accuracy: 34.77
[2,    90] loss: 1.74080, train_accuracy: 38.09
[2,    91] loss: 1.71875, train_accuracy: 38.28
[2,    92] loss: 1.71261, train_accuracy

[4,    42] loss: 1.70110, train_accuracy: 38.09
[4,    43] loss: 1.63171, train_accuracy: 43.55
[4,    44] loss: 1.68387, train_accuracy: 38.28
[4,    45] loss: 1.60345, train_accuracy: 41.99
[4,    46] loss: 1.61998, train_accuracy: 37.89
[4,    47] loss: 1.62594, train_accuracy: 40.43
[4,    48] loss: 1.72208, train_accuracy: 38.67
[4,    49] loss: 1.56524, train_accuracy: 42.19
[4,    50] loss: 1.61532, train_accuracy: 42.77
[4,    51] loss: 1.65125, train_accuracy: 39.84
[4,    52] loss: 1.67315, train_accuracy: 38.87
[4,    53] loss: 1.70330, train_accuracy: 38.09
[4,    54] loss: 1.69604, train_accuracy: 39.06
[4,    55] loss: 1.57728, train_accuracy: 41.99
[4,    56] loss: 1.64115, train_accuracy: 38.87
[4,    57] loss: 1.53323, train_accuracy: 45.12
[4,    58] loss: 1.63504, train_accuracy: 40.23
[4,    59] loss: 1.61329, train_accuracy: 41.02
[4,    60] loss: 1.62497, train_accuracy: 40.43
[4,    61] loss: 1.63985, train_accuracy: 41.41
[4,    62] loss: 1.66023, train_accuracy

[6,    12] loss: 1.57309, train_accuracy: 43.36
[6,    13] loss: 1.58494, train_accuracy: 40.43
[6,    14] loss: 1.59600, train_accuracy: 39.65
[6,    15] loss: 1.58176, train_accuracy: 43.55
[6,    16] loss: 1.57007, train_accuracy: 39.65
[6,    17] loss: 1.53633, train_accuracy: 45.51
[6,    18] loss: 1.58099, train_accuracy: 42.77
[6,    19] loss: 1.59369, train_accuracy: 44.14
[6,    20] loss: 1.56602, train_accuracy: 42.38
[6,    21] loss: 1.58460, train_accuracy: 44.34
[6,    22] loss: 1.56109, train_accuracy: 44.34
[6,    23] loss: 1.57402, train_accuracy: 46.48
[6,    24] loss: 1.59129, train_accuracy: 42.58
[6,    25] loss: 1.61645, train_accuracy: 42.58
[6,    26] loss: 1.59020, train_accuracy: 42.38
[6,    27] loss: 1.57079, train_accuracy: 43.75
[6,    28] loss: 1.57358, train_accuracy: 43.55
[6,    29] loss: 1.58544, train_accuracy: 42.97
[6,    30] loss: 1.58424, train_accuracy: 40.04
[6,    31] loss: 1.60268, train_accuracy: 41.02
[6,    32] loss: 1.66935, train_accuracy

[7,    83] loss: 1.49397, train_accuracy: 44.34
[7,    84] loss: 1.61261, train_accuracy: 44.14
[7,    85] loss: 1.50340, train_accuracy: 44.53
[7,    86] loss: 1.52804, train_accuracy: 43.36
[7,    87] loss: 1.52560, train_accuracy: 44.73
[7,    88] loss: 1.51971, train_accuracy: 44.53
[7,    89] loss: 1.53020, train_accuracy: 45.51
[7,    90] loss: 1.56188, train_accuracy: 45.31
[7,    91] loss: 1.49510, train_accuracy: 47.07
[7,    92] loss: 1.55821, train_accuracy: 43.75
[7,    93] loss: 1.53873, train_accuracy: 44.34
[7,    94] loss: 1.56778, train_accuracy: 45.51
[7,    95] loss: 1.49413, train_accuracy: 45.51
[7,    96] loss: 1.61250, train_accuracy: 40.43
[7,    97] loss: 1.51507, train_accuracy: 45.12
[7,    98] loss: 1.51362, train_accuracy: 46.13
duration: 207 s - train loss: 1.55413 - train accuracy: 44.01 - validation loss: 1.26015 - validation accuracy: 55.56 
[8,     1] loss: 1.49023, train_accuracy: 47.46
[8,     2] loss: 1.49632, train_accuracy: 46.09
[8,     3] loss: 

[9,    53] loss: 1.57607, train_accuracy: 42.38
[9,    54] loss: 1.47938, train_accuracy: 47.85
[9,    55] loss: 1.57416, train_accuracy: 43.95
[9,    56] loss: 1.48715, train_accuracy: 44.14
[9,    57] loss: 1.51339, train_accuracy: 44.14
[9,    58] loss: 1.56348, train_accuracy: 43.36
[9,    59] loss: 1.57255, train_accuracy: 43.75
[9,    60] loss: 1.47212, train_accuracy: 47.07
[9,    61] loss: 1.57112, train_accuracy: 43.16
[9,    62] loss: 1.53337, train_accuracy: 48.44
[9,    63] loss: 1.52251, train_accuracy: 44.73
[9,    64] loss: 1.50502, train_accuracy: 47.07
[9,    65] loss: 1.47540, train_accuracy: 47.27
[9,    66] loss: 1.51071, train_accuracy: 46.88
[9,    67] loss: 1.52500, train_accuracy: 46.88
[9,    68] loss: 1.53299, train_accuracy: 42.77
[9,    69] loss: 1.53197, train_accuracy: 44.34
[9,    70] loss: 1.53303, train_accuracy: 47.27
[9,    71] loss: 1.55473, train_accuracy: 45.90
[9,    72] loss: 1.50594, train_accuracy: 46.88
[9,    73] loss: 1.55380, train_accuracy

[11,    21] loss: 1.48082, train_accuracy: 48.44
[11,    22] loss: 1.43388, train_accuracy: 51.37
[11,    23] loss: 1.42999, train_accuracy: 49.80
[11,    24] loss: 1.45445, train_accuracy: 48.44
[11,    25] loss: 1.46694, train_accuracy: 43.75
[11,    26] loss: 1.49944, train_accuracy: 47.27
[11,    27] loss: 1.49643, train_accuracy: 46.88
[11,    28] loss: 1.46874, train_accuracy: 46.88
[11,    29] loss: 1.51918, train_accuracy: 43.75
[11,    30] loss: 1.58006, train_accuracy: 42.58
[11,    31] loss: 1.38901, train_accuracy: 46.88
[11,    32] loss: 1.45450, train_accuracy: 48.05
[11,    33] loss: 1.40389, train_accuracy: 49.80
[11,    34] loss: 1.49022, train_accuracy: 48.44
[11,    35] loss: 1.47318, train_accuracy: 46.09
[11,    36] loss: 1.51791, train_accuracy: 43.95
[11,    37] loss: 1.46330, train_accuracy: 47.46
[11,    38] loss: 1.45250, train_accuracy: 47.07
[11,    39] loss: 1.50420, train_accuracy: 45.90
[11,    40] loss: 1.43366, train_accuracy: 51.76
[11,    41] loss: 1.

[12,    88] loss: 1.41276, train_accuracy: 47.27
[12,    89] loss: 1.46816, train_accuracy: 46.68
[12,    90] loss: 1.44488, train_accuracy: 50.00
[12,    91] loss: 1.51187, train_accuracy: 43.36
[12,    92] loss: 1.52992, train_accuracy: 44.73
[12,    93] loss: 1.47650, train_accuracy: 47.66
[12,    94] loss: 1.47650, train_accuracy: 46.68
[12,    95] loss: 1.40658, train_accuracy: 50.20
[12,    96] loss: 1.51699, train_accuracy: 46.88
[12,    97] loss: 1.48515, train_accuracy: 46.09
[12,    98] loss: 1.39417, train_accuracy: 49.40
duration: 220 s - train loss: 1.46611 - train accuracy: 47.40 - validation loss: 1.17969 - validation accuracy: 58.91 
[13,     1] loss: 1.49008, train_accuracy: 45.12
[13,     2] loss: 1.46007, train_accuracy: 47.46
[13,     3] loss: 1.44605, train_accuracy: 47.46
[13,     4] loss: 1.57125, train_accuracy: 42.58
[13,     5] loss: 1.43919, train_accuracy: 48.63
[13,     6] loss: 1.34628, train_accuracy: 53.52
[13,     7] loss: 1.47974, train_accuracy: 43.75

[14,    55] loss: 1.37196, train_accuracy: 48.24
[14,    56] loss: 1.40112, train_accuracy: 49.41
[14,    57] loss: 1.41886, train_accuracy: 50.59
[14,    58] loss: 1.42205, train_accuracy: 50.59
[14,    59] loss: 1.47696, train_accuracy: 50.39
[14,    60] loss: 1.32737, train_accuracy: 52.93
[14,    61] loss: 1.43409, train_accuracy: 47.85
[14,    62] loss: 1.39120, train_accuracy: 50.39
[14,    63] loss: 1.45461, train_accuracy: 46.88
[14,    64] loss: 1.42387, train_accuracy: 49.80
[14,    65] loss: 1.43996, train_accuracy: 48.05
[14,    66] loss: 1.33173, train_accuracy: 53.91
[14,    67] loss: 1.48375, train_accuracy: 43.16
[14,    68] loss: 1.41654, train_accuracy: 45.90
[14,    69] loss: 1.50818, train_accuracy: 46.88
[14,    70] loss: 1.47642, train_accuracy: 47.66
[14,    71] loss: 1.42162, train_accuracy: 48.05
[14,    72] loss: 1.45431, train_accuracy: 43.75
[14,    73] loss: 1.42585, train_accuracy: 51.37
[14,    74] loss: 1.43663, train_accuracy: 50.39
[14,    75] loss: 1.

[16,    22] loss: 1.42223, train_accuracy: 50.39
[16,    23] loss: 1.34812, train_accuracy: 51.37
[16,    24] loss: 1.48425, train_accuracy: 46.68
[16,    25] loss: 1.44628, train_accuracy: 50.00
[16,    26] loss: 1.37801, train_accuracy: 50.59
[16,    27] loss: 1.47073, train_accuracy: 47.85
[16,    28] loss: 1.41481, train_accuracy: 48.44
[16,    29] loss: 1.50524, train_accuracy: 43.95
[16,    30] loss: 1.42064, train_accuracy: 50.59
[16,    31] loss: 1.44001, train_accuracy: 49.80
[16,    32] loss: 1.46592, train_accuracy: 48.44
[16,    33] loss: 1.46732, train_accuracy: 45.70
[16,    34] loss: 1.44285, train_accuracy: 51.76
[16,    35] loss: 1.37192, train_accuracy: 48.63
[16,    36] loss: 1.45139, train_accuracy: 48.83
[16,    37] loss: 1.41038, train_accuracy: 50.98
[16,    38] loss: 1.32279, train_accuracy: 51.17
[16,    39] loss: 1.42855, train_accuracy: 48.44
[16,    40] loss: 1.42207, train_accuracy: 49.80
[16,    41] loss: 1.35796, train_accuracy: 49.80
[16,    42] loss: 1.

[17,    89] loss: 1.40331, train_accuracy: 48.24
[17,    90] loss: 1.32057, train_accuracy: 53.12
[17,    91] loss: 1.39010, train_accuracy: 52.15
[17,    92] loss: 1.36745, train_accuracy: 49.61
[17,    93] loss: 1.41566, train_accuracy: 53.32
[17,    94] loss: 1.42137, train_accuracy: 47.46
[17,    95] loss: 1.39317, train_accuracy: 49.22
[17,    96] loss: 1.43371, train_accuracy: 50.78
[17,    97] loss: 1.41755, train_accuracy: 49.41
[17,    98] loss: 1.32652, train_accuracy: 53.87
duration: 211 s - train loss: 1.41436 - train accuracy: 49.46 - validation loss: 1.11880 - validation accuracy: 60.92 
[18,     1] loss: 1.39509, train_accuracy: 52.93
[18,     2] loss: 1.47158, train_accuracy: 47.07
[18,     3] loss: 1.35489, train_accuracy: 51.76
[18,     4] loss: 1.38033, train_accuracy: 50.39
[18,     5] loss: 1.43131, train_accuracy: 49.61
[18,     6] loss: 1.34349, train_accuracy: 53.71
[18,     7] loss: 1.33559, train_accuracy: 52.15
[18,     8] loss: 1.48054, train_accuracy: 47.66

[19,    56] loss: 1.34660, train_accuracy: 51.95
[19,    57] loss: 1.40347, train_accuracy: 51.76
[19,    58] loss: 1.32601, train_accuracy: 52.34
[19,    59] loss: 1.37485, train_accuracy: 51.37
[19,    60] loss: 1.44617, train_accuracy: 46.68
[19,    61] loss: 1.40552, train_accuracy: 50.98
[19,    62] loss: 1.44793, train_accuracy: 49.02
[19,    63] loss: 1.31319, train_accuracy: 52.93
[19,    64] loss: 1.38942, train_accuracy: 48.24
[19,    65] loss: 1.36290, train_accuracy: 50.00
[19,    66] loss: 1.39194, train_accuracy: 50.39
[19,    67] loss: 1.44764, train_accuracy: 49.02
[19,    68] loss: 1.41066, train_accuracy: 50.39
[19,    69] loss: 1.40127, train_accuracy: 50.59
[19,    70] loss: 1.37176, train_accuracy: 50.98
[19,    71] loss: 1.43154, train_accuracy: 46.48
[19,    72] loss: 1.34742, train_accuracy: 52.54
[19,    73] loss: 1.38663, train_accuracy: 46.88
[19,    74] loss: 1.44003, train_accuracy: 48.83
[19,    75] loss: 1.38022, train_accuracy: 50.20
[19,    76] loss: 1.

[21,    23] loss: 1.43796, train_accuracy: 48.83
[21,    24] loss: 1.42799, train_accuracy: 50.00
[21,    25] loss: 1.35709, train_accuracy: 50.78
[21,    26] loss: 1.37438, train_accuracy: 51.56
[21,    27] loss: 1.31944, train_accuracy: 55.47
[21,    28] loss: 1.37542, train_accuracy: 52.54
[21,    29] loss: 1.38820, train_accuracy: 49.22
[21,    30] loss: 1.30404, train_accuracy: 54.10
[21,    31] loss: 1.30364, train_accuracy: 55.47
[21,    32] loss: 1.45445, train_accuracy: 49.22
[21,    33] loss: 1.38508, train_accuracy: 47.46
[21,    34] loss: 1.36044, train_accuracy: 52.34
[21,    35] loss: 1.44392, train_accuracy: 47.27
[21,    36] loss: 1.41624, train_accuracy: 48.05
[21,    37] loss: 1.40537, train_accuracy: 49.80
[21,    38] loss: 1.27171, train_accuracy: 54.10
[21,    39] loss: 1.40040, train_accuracy: 50.78
[21,    40] loss: 1.35041, train_accuracy: 49.80
[21,    41] loss: 1.38968, train_accuracy: 52.15
[21,    42] loss: 1.41117, train_accuracy: 46.88
[21,    43] loss: 1.

In [None]:
evaluate_clean_accuracy(model, test_loader, device)

In [None]:
evaluate_rob_accuracy(model, test_loader, device, epsilon=8/255, attack='PGD')

In [None]:
global global_noise_data
global_noise_data = torch.zeros([512, 3, 32,32])
def fit_free( model,train_loader, criterion, optimizer, epoch, eps=8/255):
    global global_noise_data
    mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    mean = torch.tensor(mean).view(3,1,1).expand(3,32,32)
    mean = torch.tensor(std).view(3,1,1).expand(3,32,32)
    # Initialize the meters
    # switch to train mode
    model.train()
    for i, (inputs, target) in enumerate(train_loader):
        running_loss, acc_epoch_loss, avg_epoch_loss, epoch_accuracy, acc_epoch_accuracy = 0.0, 0.0, 0.0, 0.0, 0.0
        inputs = inputs
        target = target
        for j in range(5):
            # Ascend on the global noise
            noise_batch = Variable(global_noise_data[0:inputs.size(0)], requires_grad=True)
            in1 = inputs + noise_batch
            in1.clamp_(0, 1.0)
            in1.sub_(mean).div_(std)
            output = model(in1)
            loss = criterion(output, target)
            
            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            
            # Update the noise for the next iteration
            pert = fgsm(noise_batch.grad, .1)
            global_noise_data[0:inputs.size(0)] += pert.data
            global_noise_data.clamp_(eps, eps)

            optimizer.step()
            # measure elapsed time
            accuracy = get_accuracy(target, output)
            acc_epoch_loss += running_loss 
            avg_epoch_loss = acc_epoch_loss / (i+1)
            acc_epoch_accuracy += accuracy
            avg_epoch_accuracy = acc_epoch_accuracy / (i+1)
            if i%2 == 0:
                print('[%d, %5d] loss: %.5f, train_accuracy: %.2f' %(epoch + 1, i + 1, running_loss, accuracy))
            running_loss = 0.0

In [None]:
criterion = nn.CrossEntropyLoss().cuda()
    
    # Optimizer:
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(10):
    fit_free(model, train_loader,criterion, optimizer, epoch)