In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from collections import OrderedDict, defaultdict
import torch.optim as optim
import time
from src import *
import math
import pickle
import pandas as pd

In [2]:
# Route all figure text through LaTeX. With this on, TeX special characters
# (for example %) in labels must be escaped by the code that builds them.
plt.rc('text', usetex=True)
# Ask the inline backend for 'retina' figure output.
%config InlineBackend.figure_format = 'retina'
# Create the output directories written to by later cells (-p: no error if
# they already exist).
!mkdir -p figures
!mkdir -p snapshots

In [3]:
# ---------------------------------------------------------------------------
# Experiment parameters
# ---------------------------------------------------------------------------
# PGD: perturbation budget handed to every attack / adversarial layer below.
ϵ = 8 / 256
# NOTE(review): ϵ_s, val_K, retrain and pre_train are not referenced by any
# cell visible in this notebook - confirm whether they are still needed.
ϵ_s = 2 / 256

weight_decay = 1e-4

val_K = 10
retrain = 10
EPOCHS = 300        # epoch budget; replay schemes run EPOCHS / K outer epochs
TEST_EVERY = 30     # evaluate each time this many epochs have elapsed
batch_size = 128
pre_train = False

# Set to True for a quick smoke-test configuration (overrides at the bottom).
small = False

# Replay counts K for natural training with replay and for "free" training.
training_with_replay_Ks = [1, 4, 10, 30]
free_Ks = [1, 2, 4, 10, 20]

# Values of K passed to PGD(K, ...) for PGD adversarial training.
PGD_Ks = [1, 2, 7]

# Evaluation attacks. The two lists are parallel: attack_names[i] labels
# attacks[i]. Each PGD gets K, ϵ and 2.5 * ϵ / K (presumably the per-step
# size - confirm against src).
attack_names = ['FSM', 'PGD-20', 'PGD-100', 'CW-100']
attacks = [PGD(K, ϵ, 2.5 * ϵ/K) for K in (1, 20, 100)]
attacks.append(CW(100, 1e4, ϵ, 2.5 * ϵ/ 100))

if small:
    # Shrink everything so the whole notebook runs in a few minutes.
    EPOCHS = 5
    TEST_EVERY = 5
    training_with_replay_Ks = [1, 5]
    free_Ks = [1, 5]
    attack_names = ['FSM', 'PGD-2', 'CW-2']
    attacks = [PGD(K, ϵ, 2.5 * ϵ/K) for K in (1, 2)]
    attacks.append(CW(2, 1e4, ϵ, 2.5 * ϵ/ 2))

In [4]:
# Validate the schedule before any expensive training starts.
#
# Replay schemes run EPOCHS // K outer epochs of K replays each, so every K
# must divide EPOCHS exactly; otherwise the schemes would not consume the same
# epoch budget. Integer modulo is used instead of a float division round-trip.
bad_replay_Ks = [K for K in training_with_replay_Ks if EPOCHS % K != 0]
assert not bad_replay_Ks, (
    f'EPOCHS={EPOCHS} is not divisible by training_with_replay_Ks {bad_replay_Ks}')

bad_free_Ks = [K for K in free_Ks if EPOCHS % K != 0]
assert not bad_free_Ks, (
    f'EPOCHS={EPOCHS} is not divisible by free_Ks {bad_free_Ks}')

# The result cells read the last entry of every 'test' log and treat it as the
# final model. An evaluation only happens when the elapsed epoch count is a
# multiple of TEST_EVERY, so the last epoch is evaluated only if TEST_EVERY
# divides EPOCHS.
assert EPOCHS % TEST_EVERY == 0, (
    f'EPOCHS={EPOCHS} is not divisible by TEST_EVERY={TEST_EVERY}; '
    'the final epoch would not be evaluated')

In [5]:
# CIFAR INPUT
# Training augmentation: random 32x32 crop (after 4 px padding) followed by a
# random horizontal flip. Both run before ToTensor so they operate on the
# loaded image directly; the order of the two random draws is unchanged.
transform = transforms.Compose(
    [transforms.RandomCrop(32, padding=4),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.ToTensor(),
    ])

# Test images are only converted to tensors - no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor()
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
if small:
    # Smoke-test mode: keep only the first `batch_size` samples (one batch).
    trainset = torch.utils.data.Subset(trainset, range(batch_size))

# drop_last=True discards a trailing incomplete batch, so every training
# batch has exactly `batch_size` samples.
trainloader = torch.utils.data.DataLoader(trainset,
        batch_size=batch_size,
        shuffle=True, num_workers=4,
        pin_memory=True, drop_last=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)

if small:
    testset = torch.utils.data.Subset(testset, range(batch_size))

# Evaluation uses 4x larger batches and a fixed (unshuffled) order.
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size * 4,
                                         shuffle=False, num_workers=4)


# Pull one training batch; a later cell reads its shape.
# Use the built-in next(): DataLoader iterators implement __next__, and the
# Python-2 style `.next()` method is not available in current PyTorch.
dataiter = iter(trainloader)
images, labels = next(dataiter)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# The layer receives a zero-argument factory; each call yields element 0 (the
# image tensor) of every batch in trainloader.
# Presumably used to estimate normalisation statistics - confirm in src.
norm = StandardScalerLayer(lambda: (batch[0] for batch in trainloader))

# Classification loss object.
criterion = nn.CrossEntropyLoss()

def build_model(ϵ=ϵ):
    """Assemble the network as a named ``nn.Sequential`` and move it to CUDA.

    Stages, in order:
      * ``'adv'``        - ``AdversarialForFree(ϵ, 0, 1)``; included only when
                           ``ϵ`` is neither ``0`` nor ``False``.
      * ``'normalizer'`` - the module-level ``norm`` layer (the same object is
                           reused by every model built here).
      * ``'resnet'``     - ``WideResNet(28, 10, 10, 0.1)``.

    ``ϵ`` defaults to the module-level value captured when this function was
    defined. Pass ``False`` (or ``0``) to build a model without the ``'adv'``
    stage.
    """
    resnet = WideResNet(28, 10, 10, 0.1)
    # Constructed unconditionally (as before), even if it ends up unused.
    free_adv = AdversarialForFree(ϵ, 0, 1)

    stages = OrderedDict()
    if ϵ not in [0, False]:
        stages['adv'] = free_adv
    stages['normalizer'] = norm
    stages['resnet'] = resnet
    return nn.Sequential(stages).cuda()

# Shape of a single sample: the sampled batch's shape without its first
# (batch) dimension. The bare name on the last line displays it.
imgsize = images.shape[1:]
imgsize

torch.Size([3, 32, 32])

In [7]:
# Bare expression: only displays the repr of torch.nn.functional.relu and
# assigns nothing.
# NOTE(review): looks like a leftover inspection cell - consider deleting it.
F.relu

<function torch.nn.functional.relu(input: torch.Tensor, inplace: bool = False) -> torch.Tensor>

In [None]:
# Standard (natural) training with replay.
#
# srl ("standard replay logs") maps replay count K -> log name -> list of log
# entries. 'train' and 'test' are filled here; run_attacks is handed srl[K]
# and presumably appends the 'adv_test/<attack name>' entries read by the
# results table - confirm in src.
srl = defaultdict(lambda: defaultdict(list))

for K in training_with_replay_Ks:
    print(f'\n\n\n\n\n------------------------\n\n\n\n training with {K} replays')

    # ϵ=False -> the model is built without the 'adv' stage.
    model = build_model(False)
    optimizer = optim.Adam(model.parameters(), weight_decay=weight_decay)

    # K replays per outer epoch, so EPOCHS // K outer epochs spend the same
    # total budget for every K (divisibility is asserted in an earlier cell).
    for epoch in range(EPOCHS // K):

        logs = train_with_replay(K, model, trainloader, optimizer, epoch)
        srl[K]['train'].append(logs)
        # (epoch + 1) * K epochs have elapsed after this outer epoch.
        if (epoch * K + K) % TEST_EVERY == 0:
            # validation loss
            logs = run_val(model, testloader, epoch)
            srl[K]['test'].append(logs)
            run_attacks(srl[K], attacks,
                        attack_names, model, testloader, epoch)
    print('Finished Training')
    # Checkpoints go under snapshots/, like those of the other training cells
    # (previously this one was written to the working directory).
    torch.save(model.state_dict(), f"snapshots/wresnet-cifar-10-normal-{K}.pch")
    # Release the model before building the next one.
    del model
    torch.cuda.empty_cache()

# holder_to_dict converts the nested log holder before pickling.
with open('snapshots/srl.pickle', 'wb') as fd:
    pickle.dump(holder_to_dict(srl), fd)






------------------------



 training with 1 replays
train 	 1: 1.5880 40.0% 71.2s
train 	 2: 1.1595 57.8% 71.2s
train 	 3: 0.9540 65.9% 71.5s
train 	 4: 0.8208 71.3% 71.4s
train 	 5: 0.7153 74.9% 71.3s
train 	 6: 0.6392 77.7% 71.2s
train 	 7: 0.5713 80.4% 71.1s
train 	 8: 0.5228 81.9% 71.0s
train 	 9: 0.4808 83.5% 70.9s
train 	 10: 0.4478 84.7% 70.8s
train 	 11: 0.4202 85.6% 70.8s


In [None]:
# Bar charts of the final validation metrics of natural training, one bar per
# replay count m. ax2 (loss) is the left panel, ax1 (accuracy) the right one.
fig, (ax2, ax1) = plt.subplots(ncols=2, figsize=(15,7))

tick_labels = [f'$m={K}$' for K in training_with_replay_Ks]
# Last recorded validation entry for every replay count.
final_val = [srl[K]["test"][-1] for K in training_with_replay_Ks]

# (axis, y label, bar heights, label template)
# Figure text is rendered by LaTeX (text.usetex is enabled), where a bare '%'
# starts a comment, so the percent sign must be written as '\%'. Raw strings
# keep the backslash without relying on an invalid Python escape sequence.
panels = [
    (ax1, r'validation accuracy ($\%$)',
     [entry.acc * 100 for entry in final_val], r'{:0.1f}\%'),
    (ax2, 'validation loss (KL)',
     [entry.loss for entry in final_val], '{:0.1f}'),
]

for ax, ylabel, heights, template in panels:
    ax.set_xlabel('number of replay steps $m$')
    ax.set_ylabel(ylabel)
    bars = ax.bar(tick_labels, heights)
    # Write each value just above its bar, roughly centred horizontally.
    for height, bar in zip(heights, bars):
        ax.text(bar.get_x() + bar.get_width() / 2 - 0.07,
                bar.get_height() + 0.10,
                template.format(height))
def savefig(fig, name, f=('svg', 'pdf', 'png')):
    """Save ``fig`` once per file extension under the ``figures/`` directory.

    Parameters
    ----------
    fig : object with a ``savefig(path)`` method (e.g. a matplotlib Figure).
    name : str
        File name without directory or extension.
    f : iterable of str
        Extensions to write; each produces ``figures/<name>.<ext>``. The
        default is an immutable tuple so it cannot be altered between calls.
    """
    for ext in f:
        fig.savefig('figures/' + name + '.' + ext)
# Write the figure as figures/cost_of_replay.<ext> for each of savefig's
# default extensions (svg, pdf, png).
savefig(fig, 'cost_of_replay')


In [None]:
# "Free" adversarial training: the model keeps its 'adv' stage, and
# train_with_replay is given an after_func hook that calls model.adv.step().
#
# free_logs maps replay count K -> log name -> list of log entries.
# NOTE(review): Adam is created here without weight_decay, whereas the
# natural-training cell passes weight_decay=weight_decay - confirm that this
# difference is intentional.
free_logs = defaultdict(lambda: defaultdict(list))

for K in free_Ks:
    model = build_model()
    optimizer = optim.Adam(model.parameters())

    # K replays per outer epoch -> EPOCHS // K outer epochs in total.
    outer_epochs = EPOCHS // K
    for epoch in range(outer_epochs):
        train_log = train_with_replay(K, model, trainloader, optimizer, epoch,
                                      after_func=lambda model: model.adv.step())
        free_logs[K]['train'].append(train_log)

        epochs_elapsed = epoch * K + K
        if epochs_elapsed % TEST_EVERY != 0:
            continue

        # Clean validation metrics first ...
        val_log = run_val(model, testloader, epoch)
        free_logs[K]['test'].append(val_log)

        # ... then the evaluation attacks (adv loss).
        run_attacks(free_logs[K], attacks, attack_names, model, testloader, epoch)

    print('Finished Training')
    torch.save(model.state_dict(), f"snapshots/wresnet-cifar-10-free-{K}.pch")
    # Release the model before the next configuration is built.
    del model
    torch.cuda.empty_cache()

# holder_to_dict converts the nested log holder before pickling.
with open('snapshots/free_logs.pickle', 'wb') as fd:
    pickle.dump(holder_to_dict(free_logs), fd)

In [None]:
# PGD adversarial training: every training batch is passed through a PGD
# attack against the current model (via train_with_replay's input_func hook).
#
# pgd_logs maps K (first argument of PGD) -> log name -> list of log entries.
# NOTE(review): Adam is created here without weight_decay, whereas the
# natural-training cell passes weight_decay=weight_decay - confirm that this
# difference is intentional.
pgd_logs = defaultdict(lambda: defaultdict(list))

for K in PGD_Ks:
    # ϵ=False -> no 'adv' stage; perturbations come from `attack` instead.
    model = build_model(False)
    optimizer = optim.Adam(model.parameters())

    attack = PGD(K, ϵ, 2.5 * ϵ / K)

    # No replay here (replay count 1), so the full EPOCHS epochs are run.
    for epoch in range(EPOCHS):
        train_log = train_with_replay(
            1,
            model,
            trainloader,
            optimizer,
            epoch,
            input_func=lambda inputs, labels: attack(model, inputs, labels))
        pgd_logs[K]['train'].append(train_log)

        if (epoch + 1) % TEST_EVERY != 0:
            continue

        val_log = run_val(model, testloader, epoch)
        pgd_logs[K]['test'].append(val_log)
        run_attacks(pgd_logs[K], attacks,
                    attack_names, model, testloader, epoch)

    print('Finished Training')
    torch.save(model.state_dict(), f"snapshots/wresnet-cifar-10-pgk-{K}.pch")
    # Release the model before the next configuration is built.
    del model
    torch.cuda.empty_cache()

# holder_to_dict converts the nested log holder before pickling.
with open('snapshots/pgd_logs.pickle', 'wb') as fd:
    pickle.dump(holder_to_dict(pgd_logs), fd)

In [None]:
# Results table: one row per training scheme, one column per metric.

def fmt_percent(fraction):
    """Format a fraction (0-1) as a percentage with two decimals, wrapped in $$.

    A raw f-string keeps the backslash of the escaped percent sign as-is
    instead of relying on an invalid Python escape sequence.
    """
    return rf'$${fraction * 100:.2f}\%$$'


def fmt_minutes(duration):
    """Format ``duration / 60`` rounded up to a whole number, wrapped in $$."""
    return f'$${math.ceil(duration / 60)}$$'


def total_time(train_entries):
    """Sum the ``.time`` attribute over a list of training log entries."""
    return sum(entry.time for entry in train_entries)


# One log holder per table row; the order must match d['Training'] below.
# The natural baseline is the run with replay count 1.
row_logs = [srl[1],
            *[free_logs[K] for K in free_Ks],
            *[pgd_logs[K] for K in PGD_Ks]]

d = {}
d['Training'] = ['Natural',
                 *[f'Free $m={K}$' for K in free_Ks],
                 *[f'{K}-PGD' for K in PGD_Ks]]

# Accuracy on clean test images: last recorded validation entry per row.
d['Natural Images'] = [fmt_percent(logs['test'][-1].acc) for logs in row_logs]

# Accuracy under each evaluation attack: last recorded entry per row.
for name in attack_names:
    n = f'adv_test/{name}'
    d[name] = [fmt_percent(logs[n][-1].acc) for logs in row_logs]

# Total training time per row, summed over all training log entries.
d['Training Time(M)'] = [fmt_minutes(total_time(logs['train'])) for logs in row_logs]

df = pd.DataFrame(d)

df

In [None]:
# Export the results table as CSV into the figures/ directory.
df.to_csv('figures/grid.csv')

In [None]:
# Export the results table as a LaTeX tabular.
# The cells already contain hand-written LaTeX ($...$ math, \%), so escaping
# is switched off explicitly; pandas versions that escape by default would
# otherwise rewrite '$' and '\' and break that markup.
df.to_latex('figures/grid.tex', escape=False)