In [1]:
import gc
import math
import time
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm

from src import *

In [2]:
# parameters
# L∞ perturbation budget on the [0, 1] pixel scale produced by ToTensor (8-bit
# values divided by 255): 8 intensity levels is 8/255, the usual CIFAR-10 budget.
ϵ = 8 / 255
# Number of times each minibatch is replayed in "free" adversarial training; the
# same value is also passed to PGK in the adversarial evaluation.
K = 7
retrain = 10  # NOTE(review): not used in the visible cells — TODO confirm or remove
# Total budget in passes over the data. The training loop runs ceil(epoch_count / K)
# outer epochs because every batch is replayed K times.
epoch_count = 300
batch_size = 128
pre_train = False  # NOTE(review): not used in the visible cells — TODO confirm or remove

In [3]:
# CIFAR INPUT
# Training augmentation, applied to the PIL image before conversion:
#   1. random 32x32 crop from the image padded by 4 pixels on each side;
#   2. random horizontal flip with probability 0.5;
#   3. ToTensor, which converts to a float tensor with values in [0, 1].
# Flipping before ToTensor is the canonical order, and it also works on older
# torchvision releases whose RandomHorizontalFlip accepts only PIL images.
transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
])

# Test images are only converted to tensors (no augmentation).
transform_test = transforms.Compose([
    transforms.ToTensor()
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

def trainloadicator(b):
    """Build a shuffled training DataLoader over ``trainset`` with batch size ``b``.

    The incomplete final batch is dropped, so every yielded batch has exactly ``b``
    samples. Loading uses 4 worker processes and pinned host memory.
    """
    loader_kwargs = dict(
        batch_size=b,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=True,
    )
    return torch.utils.data.DataLoader(trainset, **loader_kwargs)
trainloader = trainloadicator(batch_size)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform_test)
# Fixed evaluation order (no shuffle); the final partial batch is kept.
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=4)

# Grab one training batch, e.g. to inspect the input shape.
# Use the built-in next(): DataLoader iterators no longer provide a `.next()`
# method in recent PyTorch releases, so `dataiter.next()` raises AttributeError.
dataiter = iter(trainloader)
images, labels = next(dataiter)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Input normaliser shared by every model that mkmodel builds. It receives a factory
# that yields only the image tensors of each training batch (labels dropped).
# Presumably it fits per-channel statistics from them — TODO confirm in src.
norm = StandardScalerLayer(lambda: (batch[0] for batch in trainloader))

criterion = nn.CrossEntropyLoss()


def mkmodel(ϵ=ϵ):
    """Return a CUDA ``nn.Sequential`` of the form adversary -> normaliser -> WideResNet.

    ``ϵ`` is handed to ``AdversarialForFree`` together with the bounds 0 and 1
    (presumably the valid pixel range — TODO confirm). Its default is the
    notebook-level ``ϵ`` at the time this cell runs. Every call reuses the same
    module-level ``norm`` instance.
    """
    # Build the backbone before the adversary layer, so that the random-number
    # consumption order during initialisation stays the same.
    # WideResNet(34, 10, 10, 0.3): presumably depth, widen factor, number of classes
    # and dropout rate — TODO confirm the argument order in src.
    backbone = WideResNet(34, 10, 10, 0.3)
    stages = OrderedDict()
    stages['adv'] = AdversarialForFree(ϵ, 0, 1)
    stages['normalizer'] = norm
    stages['resnet'] = backbone
    return nn.Sequential(stages).cuda()


# Per-sample input shape (batch dimension stripped); displayed as the cell output.
imgsize = images.shape[1:]
imgsize

torch.Size([3, 32, 32])

In [None]:
train_loader = None
train_batch = None

model = mkmodel()
optimizer = optim.Adam(model.parameters(), weight_decay=5e-4)
# NOTE(review): AdversarialForFree may register its perturbation as an nn.Parameter.
# If so, it is part of model.parameters(), and Adam (with weight decay) would update
# it on every step in addition to model.adv.step() — TODO confirm in src.


def _evaluate(model, loader, perturb=None):
    """Return a Logisticator holding ``model``'s loss and accuracy over ``loader``.

    If ``perturb`` is given, it is called as ``perturb(model, inputs, labels)``
    with gradients enabled (so an attack can differentiate through the model).
    Its result is added to the inputs before the gradient-free forward pass.
    """
    logs = Logisticator()
    for data in loader:
        inputs, labels = (x.cuda() for x in data)
        if perturb is not None:
            inputs = inputs + perturb(model, inputs, labels)
        with torch.no_grad():
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)
            logs.add(accuracy(outputs, labels), loss.item(), inputs.size(0))
    return logs


def _pgk_noise(model, inputs, labels):
    """Perturbation from PGK with budget ϵ and K iterations, using cross-entropy on ``labels``."""
    return PGK(model, lambda out: F.cross_entropy(out, labels), inputs, ϵ, K)


# Free adversarial training: every minibatch is replayed K times, so the outer loop
# runs ceil(epoch_count / K) epochs to keep the total number of passes near epoch_count.
for epoch in range(math.ceil(epoch_count / K)):
    # Rebuild the loader only when the batch size changes. The batch size is taken
    # from the configured `batch_size` (constant for now; a per-epoch schedule
    # could be plugged in here).
    b = batch_size
    if train_batch != b:
        print(f'setting batch size to {b}')
        train_loader = trainloadicator(b)
        train_batch = b

    logs = Logisticator()
    model.train()

    for data in train_loader:
        inputs, labels = (x.cuda() for x in data)
        for _ in range(K):
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()

            # The optimizer updates the network weights. The adversary then updates
            # its own state, presumably ascending on the input perturbation using
            # the same backward pass — TODO confirm in src.
            optimizer.step()
            model.adv.step()

            logs.add(accuracy(outputs, labels), loss.item(), inputs.size(0))

    print(f'train \t {epoch + 1}: {logs}')

    model.eval()

    # validation loss on clean test images
    logs = _evaluate(model, testloader)
    print(f'val \t {epoch + 1}: {logs}')

    # adversarial loss on PGK-perturbed test images
    logs = _evaluate(model, testloader, _pgk_noise)
    print(f'adv \t {epoch + 1}: {logs}')

print('Finished Training')


setting batch size to 128
train 	 1: 2.0808 22.0% 617.4s
val 	 1: 2.1781 21.5% 5.9s
adv 	 1: 2.6041 16.4% 120.6s
train 	 2: 1.9606 26.5% 606.1s
val 	 2: 1.7089 33.5% 5.8s
adv 	 2: 2.0658 21.4% 119.5s
train 	 3: 1.8441 30.7% 608.6s
val 	 3: 1.8288 34.4% 5.8s
adv 	 3: 2.4063 19.5% 119.9s
train 	 4: 1.8152 31.6% 608.4s
val 	 4: 1.5631 42.1% 5.8s
adv 	 4: 1.9891 25.1% 119.6s
train 	 5: 1.7979 32.4% 606.1s
val 	 5: 1.6276 36.5% 5.8s
adv 	 5: 2.2220 19.2% 119.3s


In [None]:
# Release the GPU memory held by the training cell. The optimizer keeps references
# to every model parameter (plus Adam's per-parameter state), so deleting `model`
# alone leaves those tensors alive, and empty_cache() could not return their memory.
del model, optimizer
gc.collect()  # break any remaining reference cycles that still pin CUDA tensors
torch.cuda.empty_cache()