## Import Libraries

In [None]:
import os, argparse, time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.optim import SGD, Adam, lr_scheduler

from models.cifar10.resnet import ResNet34
from models.svhn.wide_resnet import WRN16_8
from models.stl10.wide_resnet import WRN40_2

from dataloaders.cifar10 import cifar10_dataloaders
from dataloaders.svhn import svhn_dataloaders
from dataloaders.stl10 import stl10_dataloaders

from utils.utils import *
from utils.context import ctx_noparamgrad_and_eval
from attacks.pgd import PGD

# Parameters

In [None]:
# ---- Hardware ----
gpu = 0   # GPU index written to CUDA_VISIBLE_DEVICES
cpus = 8  # forwarded to the dataloaders as num_workers

# ---- Data ----
dataset = 'cifar10'  # choices: 'cifar10', 'svhn', 'stl10'
batch_size = 128

# ---- Optimisation ----
epochs = 100
decay_epochs = [50, 150]  # milestones, only read when decay == 'multisteps'
opt = 'sgd'               # choices: 'sgd', 'adam'
decay = 'cos'             # choices: 'cos', 'multisteps'
lr = 0.1
momentum = 0.9            # only read when opt == 'sgd'
wd = 5e-4                 # weight decay

# ---- Attack ----
targeted = True  # if True, targeted attack
eps = 31         # divided by 1000 when the PGD attacker is built
steps = 7

# ---- Loss / run control ----
Lambda = 0.5      # weight of the adversarial loss; the clean loss gets (1 - Lambda)
resume = True     # if True, resume from the early-stopped checkpoint
efficient = True  # NOTE(review): not read by any of the visible cells

# Set Environment

In [None]:
# Restrict CUDA to the selected device. os.environ only accepts string
# values, so the integer GPU index has to be converted explicitly
# (assigning the bare int raises TypeError).
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
# Let cuDNN benchmark and pick convolution algorithms for the input shapes it sees.
torch.backends.cudnn.benchmark = True

# Load Data

In [None]:
# Build the training and validation loaders for the selected dataset.
# The cifar10 and svhn helpers return a third value that is not used here;
# the stl10 helper is unpacked as a pair.
if dataset == 'cifar10':
    train_loader, val_loader, _ = cifar10_dataloaders(train_batch_size=batch_size, num_workers=cpus)
elif dataset == 'svhn':
    train_loader, val_loader, _ = svhn_dataloaders(train_batch_size=batch_size, num_workers=cpus)
elif dataset == 'stl10':
    train_loader, val_loader = stl10_dataloaders(train_batch_size=batch_size, num_workers=cpus)
else:
    # Fail here with a clear message instead of a NameError on
    # train_loader much later in the notebook.
    raise ValueError("Unsupported dataset %r; expected 'cifar10', 'svhn' or 'stl10'" % (dataset,))


# Initialize model

In [None]:
# Pick the architecture paired with each dataset.
if dataset == 'cifar10':
    model_fn = ResNet34
elif dataset == 'svhn':
    model_fn = WRN16_8
elif dataset == 'stl10':
    model_fn = WRN40_2
else:
    # Without this branch an unknown dataset leaves model_fn undefined
    # and the next line fails with an unhelpful NameError.
    raise ValueError("Unsupported dataset %r; expected 'cifar10', 'svhn' or 'stl10'" % (dataset,))

# Move the network to the GPU and wrap it in DataParallel.
model = model_fn().cuda()
model = torch.nn.DataParallel(model)

# Compose a results directory whose name encodes the full experiment
# configuration (attack, optimizer, lr schedule, loss weighting).
model_str = model_fn.__name__

if opt == 'sgd':
    opt_str = 'e%d-b%d_sgd-lr%s-m%s-wd%s' % (epochs, batch_size, lr, momentum, wd)
elif opt == 'adam':
    opt_str = 'e%d-b%d_adam-lr%s-wd%s' % (epochs, batch_size, lr, wd)
else:
    raise ValueError("Unsupported optimizer %r; expected 'sgd' or 'adam'" % (opt,))

if decay == 'cos':
    decay_str = 'cos'
elif decay == 'multisteps':
    decay_str = 'multisteps-%s' % decay_epochs
else:
    raise ValueError("Unsupported decay %r; expected 'cos' or 'multisteps'" % (decay,))

# The parentheses matter: a conditional expression binds looser than '+',
# so without them the '-pgd-<eps>-<steps>' suffix was only appended in the
# untargeted case and targeted runs with different eps/steps collided in
# one folder.
attack_str = ('targeted' if targeted else 'untargeted') + '-pgd-%d-%d' % (eps, steps)
loss_str = 'lambda%s' % (Lambda)

save_folder = os.path.join(
    os.getcwd(), 'PGD_results', dataset, model_str,
    '%s_%s_%s_%s' % (attack_str, opt_str, decay_str, loss_str))
create_dir(save_folder)

# Optimizer: SGD additionally takes momentum; learning rate and weight
# decay are shared by both choices.
if opt == 'sgd':
    optimizer = SGD(
        model.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=wd,
    )
elif opt == 'adam':
    optimizer = Adam(
        model.parameters(),
        lr=lr,
        weight_decay=wd,
    )

# Learning-rate schedule, stepped once per epoch by the training loop.
if decay == 'cos':
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
elif decay == 'multisteps':
    scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=decay_epochs, gamma=0.1)

# PGD adversary; the integer eps setting is rescaled by 1/1000.
# NOTE(review): the `targeted` setting is not passed to PGD here - confirm
# whether the attacker should receive it.
attacker = PGD(steps=steps, eps=eps/1000)

# Either restore the training state from the latest checkpoint or start
# from scratch. The checkpoint is only loaded when the file actually
# exists, so a first run with resume=True falls back to a fresh start
# instead of failing on a missing 'latest.pth'.
ckpt_path = os.path.join(save_folder, 'latest.pth')
if resume and os.path.isfile(ckpt_path):
    last_epoch, best_TA, best_ATA, training_loss, val_TA, val_ATA \
         = load_ckpt(model, optimizer, scheduler, ckpt_path)
    start_epoch = last_epoch + 1
else:
    if resume:
        print('No checkpoint found at %s, starting from scratch.' % ckpt_path)
    start_epoch = 0
    best_TA, best_ATA = 0, 0
    # training curve lists:
    training_loss, val_TA, val_ATA = [], [], []


# Training

In [None]:
# Main loop: one epoch of (adversarial) training, optional validation,
# curve plotting and checkpointing.

# Epoch index from which validation runs every epoch (before that it runs
# every 10th epoch). Only cifar10 and svhn had a schedule originally; any
# other dataset (e.g. 'stl10') previously hit a NameError on
# eval_this_epoch and now falls back to the cifar10 schedule.
# NOTE(review): confirm the intended schedule for stl10.
if dataset == 'svhn':
    dense_eval_start = int(0.25*epochs)
else:
    dense_eval_start = int(0.7*epochs)

log_path = os.path.join(save_folder, 'train_log.txt')

for epoch in range(start_epoch, epochs):
    start_time = time.time()

    # ---------------- training ----------------
    model.train()
    requires_grad_(model, True)
    accs, accs_adv, losses = AverageMeter(), AverageMeter(), AverageMeter()

    # The scheduler is stepped once per epoch, so the learning rate is
    # constant inside the epoch. Read it from the optimizer: calling
    # scheduler.get_lr() outside of step() is deprecated and can report a
    # wrong value.
    current_lr = optimizer.param_groups[0]['lr']

    for i, (imgs, labels) in enumerate(train_loader):
        imgs, labels = imgs.cuda(), labels.cuda()

        if Lambda != 0:
            # Craft adversarial examples inside the no-param-grad/eval context,
            # then run them through the model again in training mode.
            with ctx_noparamgrad_and_eval(model):
                imgs_adv = attacker.attack(model, imgs, labels)
            logits_adv = model(imgs_adv.detach())

        logits = model(imgs)

        # Convex combination of the clean and adversarial cross-entropy.
        loss = F.cross_entropy(logits, labels)
        if Lambda != 0:
            loss = (1-Lambda) * loss + Lambda * F.cross_entropy(logits_adv, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        accs.append((logits.argmax(1) == labels).float().mean().item())
        if Lambda != 0:
            accs_adv.append((logits_adv.argmax(1) == labels).float().mean().item())
        losses.append(loss.item())

        if i % 100 == 0:
            train_str = 'Epoch %d-%d | Train | Loss: %.4f, SA: %.4f' % (epoch, i, losses.avg, accs.avg)
            if Lambda != 0:
                train_str += ', RA: %.4f' % (accs_adv.avg)
            print(train_str)

    scheduler.step()

    # ---------------- validation ----------------
    model.eval()
    requires_grad_(model, False)
    print(model.training)  # prints the mode flag after switching to eval()

    eval_this_epoch = (epoch % 10 == 0) or (epoch >= dense_eval_start)

    if eval_this_epoch:
        val_accs, val_accs_adv = AverageMeter(), AverageMeter()
        for imgs, labels in val_loader:
            imgs, labels = imgs.cuda(), labels.cuda()

            with ctx_noparamgrad_and_eval(model):
                imgs_adv = attacker.attack(model, imgs, labels)
            # Per-sample L-inf size of the perturbation. float('inf') is
            # used because np.Inf no longer exists in NumPy >= 2.0.
            linf_norms = (imgs_adv - imgs).view(imgs.size(0), -1).norm(p=float('inf'), dim=1)

            # No gradients are needed to score the predictions.
            with torch.no_grad():
                logits_adv = model(imgs_adv.detach())
                logits = model(imgs)

            val_accs.append((logits.argmax(1) == labels).float().mean().item())
            val_accs_adv.append((logits_adv.argmax(1) == labels).float().mean().item())

        # The linf range reported here is that of the last validation batch.
        val_str = 'Epoch %d | Validation | Time: %.4f | lr: %s | SA: %.4f, RA: %.4f, linf: %.4f - %.4f' % (
            epoch, (time.time()-start_time), current_lr, val_accs.avg, val_accs_adv.avg,
            torch.min(linf_norms).item(), torch.max(linf_norms).item())
        print(val_str)
        # Open the log only while writing so the handle is always closed
        # (it used to be opened once per epoch and never closed).
        with open(log_path, 'a') as fp:
            fp.write(val_str + '\n')

    # ---------------- curves ----------------
    training_loss.append(losses.avg)
    plt.plot(training_loss)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.grid(True)
    plt.savefig(os.path.join(save_folder, 'training_loss.png'))
    plt.close()

    # Keep one entry per epoch: on epochs without validation the previous
    # values are carried forward.
    if eval_this_epoch:
        val_TA.append(val_accs.avg)
        val_ATA.append(val_accs_adv.avg)
    else:
        val_TA.append(val_TA[-1])
        val_ATA.append(val_ATA[-1])
    plt.plot(val_TA, 'r')
    plt.plot(val_ATA, 'g')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Validation Accuracy')
    plt.legend(["SA", "RA"])
    plt.grid(True)
    plt.savefig(os.path.join(save_folder, 'val_acc.png'))
    plt.close()

    # ---------------- checkpoints ----------------
    if eval_this_epoch:
        if val_accs.avg >= best_TA:
            best_TA = val_accs.avg
            torch.save(model.state_dict(), os.path.join(save_folder, 'best_SA.pth'))
        if val_accs_adv.avg >= best_ATA:
            best_ATA = val_accs_adv.avg
            torch.save(model.state_dict(), os.path.join(save_folder, 'best_RA.pth'))
    save_ckpt(epoch, model, optimizer, scheduler, best_TA, best_ATA, training_loss, val_TA, val_ATA,
        os.path.join(save_folder, 'latest.pth'))


