## Import Libraries

In [None]:
import os, argparse, time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.optim import SGD, Adam, lr_scheduler

from models.cifar10.resnet_OAT import ResNet34OAT
from models.svhn.wide_resnet_OAT import WRN16_8OAT
from models.stl10.wide_resnet_OAT import WRN40_2OAT

from dataloaders.cifar10 import cifar10_dataloaders
from dataloaders.svhn import svhn_dataloaders
from dataloaders.stl10 import stl10_dataloaders

from utils.utils import *
from utils.context import ctx_noparamgrad_and_eval
from utils.sample_lambda import element_wise_sample_lambda, batch_wise_sample_lambda
from attacks.pgd import PGD

# Parameters

In [None]:
# Hardware / data loading.
gpu, cpus = '1', 8  # CUDA_VISIBLE_DEVICES value; num_workers for the dataloaders
dataset = 'cifar10'  # choices=['cifar10', 'svhn', 'stl10']
batch_size = 128

# Optimisation schedule.
epochs = 100
opt, decay = 'sgd', 'cos'  # opt choices=['sgd', 'adam']; decay choices=['cos', 'multisteps']
decay_epochs = [50, 150]  # milestones, only read when decay == 'multisteps'
lr, momentum, wd = 0.1, 0.9, 5e-4  # wd = weight decay; momentum is only passed to SGD

# PGD attack.
targeted = True  # if true, targeted attack
eps, steps = 31, 7  # eps is divided by 1000 when the PGD attacker is built

# Lambda sampling and encoding.
distribution = 'disc'
sampling = 'ew'  # sampling scheme for Lambda, choices=['ew', 'bw']
lambda_choices = [0.0, 0.1, 0.2, 0.3, 0.4, 1.0]
probs = -1  # only appended to the results folder name when > 0
encoding, dim = 'rand', 128  # encoding choices=['none', 'onehot', 'dct', 'rand']

# Model / run control.
use2BN = True
resume = False  # If true, resume from early stopped ckpt
efficient = True

# Set Environment

In [None]:
# Turn on cuDNN benchmark mode and expose only the configured GPU id(s) to CUDA.
torch.backends.cudnn.benchmark = True
os.environ.update({'CUDA_VISIBLE_DEVICES': gpu})

# Load Data

In [None]:
# Build the training and validation loaders for the selected dataset.
if dataset == 'cifar10':
    train_loader, val_loader, _ = cifar10_dataloaders(train_batch_size=batch_size, num_workers=cpus)
elif dataset == 'svhn':
    train_loader, val_loader, _ = svhn_dataloaders(train_batch_size=batch_size, num_workers=cpus)
elif dataset == 'stl10':
    # Unlike the other two helpers, this one is unpacked into two loaders only.
    train_loader, val_loader = stl10_dataloaders(train_batch_size=batch_size, num_workers=cpus)
else:
    # Fail loudly instead of leaving the loaders undefined (or, in a notebook,
    # silently reusing loaders left over from a previous run).
    raise ValueError("Unsupported dataset %r; expected 'cifar10', 'svhn' or 'stl10'" % (dataset,))

# Initialize model

In [None]:
# Width of the lambda vector fed to the model: `dim` when lambda is encoded,
# otherwise the raw scalar.
if encoding in ['onehot', 'dct', 'rand']:
    FiLM_in_channels = dim
else: # non encoding
    FiLM_in_channels = 1

# Pick the architecture that goes with the dataset.
if dataset == 'cifar10':
    model_fn = ResNet34OAT
elif dataset == 'svhn':
    model_fn = WRN16_8OAT
elif dataset == 'stl10':
    model_fn = WRN40_2OAT
else:
    raise ValueError("Unsupported dataset %r; expected 'cifar10', 'svhn' or 'stl10'" % (dataset,))
model = model_fn(use2BN=use2BN, FiLM_in_channels=FiLM_in_channels).cuda()
model = torch.nn.DataParallel(model)

# Assemble the results folder name from the run configuration.
model_str = model_fn.__name__
if use2BN:
    model_str += '-2BN'
if opt == 'sgd':
    opt_str = 'e%d-b%d_sgd-lr%s-m%s-wd%s' % (epochs, batch_size, lr, momentum, wd)
elif opt == 'adam':
    opt_str = 'e%d-b%d_adam-lr%s-wd%s' % (epochs, batch_size, lr, wd)
else:
    raise ValueError("Unsupported optimizer %r; expected 'sgd' or 'adam'" % (opt,))
if decay == 'cos':
    decay_str = 'cos'
elif decay == 'multisteps':
    decay_str = 'multisteps-%s' % decay_epochs
else:
    raise ValueError("Unsupported lr decay %r; expected 'cos' or 'multisteps'" % (decay,))
# The parentheses matter: without them the conditional expression swallows the
# concatenation and targeted runs lose the '-pgd-<eps>-<steps>' suffix.
# NOTE: this changes the folder name for targeted runs, so a run started with
# the old naming must be moved/renamed before it can be resumed.
attack_str = ('targeted' if targeted else 'untargeted') + '-pgd-%s-%d' % (eps, steps)
lambda_str = '%s-%s-%s' % (distribution, sampling, lambda_choices)
if probs > 0:
    lambda_str += '-%s' % probs
if encoding in ['onehot', 'dct', 'rand']:
    lambda_str += '-%s-d%s' % (encoding, dim)
save_folder = os.path.join(os.getcwd(), 'OAT_results', dataset, model_str, '%s_%s_%s_%s' % (attack_str, opt_str, decay_str, lambda_str))
print(save_folder)
create_dir(save_folder)

# encoding matrix:
if encoding == 'onehot':
    I_mat = np.eye(dim)
    encoding_mat = I_mat
elif encoding == 'dct':
    from scipy.fftpack import dct
    dct_mat = dct(np.eye(dim), axis=0)
    encoding_mat = dct_mat
elif encoding == 'rand':
    rand_otho_path = os.path.join(save_folder, 'rand_otho_mat.npy')
    if resume and os.path.isfile(rand_otho_path):
        # A resumed checkpoint was trained against the matrix saved by the
        # original run; drawing a new one would change every lambda encoding.
        rand_otho_mat = np.load(rand_otho_path)
    else:
        rand_mat = np.random.randn(dim, dim)
        np.save(os.path.join(save_folder, 'rand_mat.npy'), rand_mat)
        # Q factor of the QR decomposition of the random matrix.
        rand_otho_mat, _ = np.linalg.qr(rand_mat)
        np.save(rand_otho_path, rand_otho_mat)
    encoding_mat = rand_otho_mat
elif encoding == 'none':
    encoding_mat = None
else:
    raise ValueError("Unsupported encoding %r; expected 'none', 'onehot', 'dct' or 'rand'" % (encoding,))

# Lambdas evaluated at validation time.
if distribution == 'disc':
    val_lambdas = lambda_choices
else:
    val_lambdas = [0,0.2,0.5,1]

if opt == 'sgd':
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=wd)
elif opt == 'adam':
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=wd)
if decay == 'cos':
    scheduler = lr_scheduler.CosineAnnealingLR(optimizer, epochs)
elif decay == 'multisteps':
    scheduler = lr_scheduler.MultiStepLR(optimizer, decay_epochs, gamma=0.1)

if resume:
    # Restore the training state and the metric history from the latest checkpoint.
    last_epoch, best_TA, best_ATA, training_loss, val_TA, val_ATA \
         = load_ckpt(model, optimizer, scheduler, os.path.join(save_folder, 'latest.pth'))
    start_epoch = last_epoch + 1
else:
    start_epoch = 0
    # Per-lambda history lists and best-so-far accuracies.
    training_loss, val_TA, val_ATA, best_TA, best_ATA = [], {}, {}, {}, {}
    for val_lambda in val_lambdas:
        val_TA[val_lambda], val_ATA[val_lambda], best_TA[val_lambda], best_ATA[val_lambda] = [], [], 0, 0


# eps is configured in thousandths (e.g. 31 -> 0.031).
attacker = PGD(eps=eps/1000, steps=steps, use_FiLM=True)

# Training

In [None]:
# Each epoch: adversarial training pass, (periodic) validation, curve plotting
# and checkpointing.
train_log_path = os.path.join(save_folder, 'train_log.txt')
val_log_path = os.path.join(save_folder, 'val_log.txt')
# idx2BN value used for batches that contain adversarial samples only; falls
# back to None when dual BN is off, matching the non-efficient code path.
adv_idx2BN = 0 if use2BN else None

for epoch in range(start_epoch, epochs):
    start_time = time.time()
    ## training:
    model.train()
    requires_grad_(model, True)
    accs, accs_adv, losses = AverageMeter(), AverageMeter(), AverageMeter()
    # The context manager guarantees the training log is closed every epoch.
    with open(train_log_path, 'a+') as train_fp:
        for i, (imgs, labels) in enumerate(train_loader):
            imgs, labels = imgs.cuda(), labels.cuda()
            # sample _lambda (one per image, sized to the actual batch so a
            # short final batch still lines up with imgs/labels):
            if sampling == 'ew':
                _lambda_flat, _lambda, num_zeros = element_wise_sample_lambda(
                    distribution, lambda_choices, encoding_mat,
                    batch_size=imgs.size(0), probs=probs)
            else:
                raise NotImplementedError(
                    "sampling=%r is not handled by this training loop; only 'ew' is" % (sampling,))
            # NOTE(review): the slicing below assumes the first `num_zeros`
            # samples are the lambda == 0 ones -- confirm against
            # element_wise_sample_lambda.
            idx2BN = num_zeros if use2BN else None

            # logits for clean imgs:
            logits = model(imgs, _lambda, idx2BN)
            # clean loss:
            lc = F.cross_entropy(logits, labels, reduction='none')

            if efficient:
                # Only attack the samples from index num_zeros onwards.
                imgs_sub = imgs[num_zeros:]
                labels_adv = labels[num_zeros:]
                lambda_adv = _lambda[num_zeros:]
                # generate adversarial images:
                with ctx_noparamgrad_and_eval(model):
                    imgs_adv = attacker.attack(model, imgs_sub, labels=labels_adv, _lambda=lambda_adv, idx2BN=adv_idx2BN)
                # logits for adv imgs:
                logits_adv = model(imgs_adv.detach(), lambda_adv, idx2BN=adv_idx2BN)

                # adversarial loss, zero-padded for the skipped samples so it
                # lines up with the per-sample clean loss:
                la = F.cross_entropy(logits_adv, labels_adv, reduction='none')
                la = torch.cat([la.new_zeros((num_zeros,)), la], dim=0)
            else:
                labels_adv = labels
                # generate adversarial images:
                with ctx_noparamgrad_and_eval(model):
                    imgs_adv = attacker.attack(model, imgs, labels=labels, _lambda=_lambda, idx2BN=idx2BN)
                # logits for adv imgs:
                logits_adv = model(imgs_adv.detach(), _lambda, idx2BN=idx2BN)

                # adversarial loss:
                la = F.cross_entropy(logits_adv, labels, reduction='none')

            # Per-sample convex combination of clean and adversarial loss.
            wc = (1-_lambda_flat)
            wa = _lambda_flat
            loss = torch.mean(wc * lc + wa * la)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # metrics:
            accs.append((logits.argmax(1) == labels).float().mean().item())
            accs_adv.append((logits_adv.argmax(1) == labels_adv).float().mean().item())
            losses.append(loss.item())

            if i % 50 == 0:
                train_str = 'Epoch %d-%d | Train | Loss: %.4f, SA: %.4f, RA: %.4f' % (
                    epoch, i, losses.avg, accs.avg, accs_adv.avg)
                print(train_str)
                train_fp.write(train_str + '\n')

    # Learning rate used during this epoch, read straight from the optimizer
    # (scheduler.get_lr() is not meant to be called outside scheduler.step()).
    current_lr = optimizer.param_groups[0]['lr']

    # lr scheduler update at the end of each epoch:
    scheduler.step()


    ## validation:
    model.eval()
    requires_grad_(model, False)
    print(model.training)

    # Validate every 10th epoch and on every epoch of the last quarter of training.
    eval_this_epoch = (epoch % 10 == 0) or (epoch>=int(0.75*epochs)) # boolean

    if eval_this_epoch:
        val_accs = {val_lambda: AverageMeter() for val_lambda in val_lambdas}
        val_accs_adv = {val_lambda: AverageMeter() for val_lambda in val_lambdas}

        for imgs, labels in val_loader:
            imgs, labels = imgs.cuda(), labels.cuda()
            n_samples = labels.size()[0]

            for j, val_lambda in enumerate(val_lambdas):
                # build _lambda: the same value for every sample in the batch
                if distribution == 'disc' and encoding_mat is not None:
                    # look up the encoding row for the j-th lambda choice
                    _lambda = np.expand_dims( np.repeat(j, n_samples), axis=1 ).astype(np.uint8)
                    _lambda = encoding_mat[_lambda,:]
                else:
                    _lambda = np.expand_dims( np.repeat(val_lambda, n_samples), axis=1 )
                _lambda = torch.from_numpy(_lambda).float().cuda()
                if use2BN:
                    idx2BN = int(n_samples) if val_lambda==0 else 0
                else:
                    idx2BN = None
                # TA:
                logits = model(imgs, _lambda, idx2BN)
                val_accs[val_lambda].append((logits.argmax(1) == labels).float().mean().item())
                # ATA:
                # generate adversarial images:
                with ctx_noparamgrad_and_eval(model):
                    imgs_adv = attacker.attack(model, imgs, labels=labels, _lambda=_lambda, idx2BN=idx2BN)
                # np.inf instead of the np.Inf alias, which NumPy 2.0 removed.
                linf_norms = (imgs_adv - imgs).view(imgs.size()[0], -1).norm(p=np.inf, dim=1)
                logits_adv = model(imgs_adv.detach(), _lambda, idx2BN)
                val_accs_adv[val_lambda].append((logits_adv.argmax(1) == labels).float().mean().item())

    val_str = 'Epoch %d | Validation | Time: %.4f | lr: %s' % (epoch, (time.time()-start_time), current_lr)
    if eval_this_epoch:
        # linf_norms holds the perturbation norms of the last evaluated batch only.
        val_str += ' | linf: %.4f - %.4f\n' % (torch.min(linf_norms).item(), torch.max(linf_norms).item())
        for val_lambda in val_lambdas:
            val_str += 'val_lambda%s: SA: %.4f, RA: %.4f\n' % (val_lambda, val_accs[val_lambda].avg, val_accs_adv[val_lambda].avg)
    print(val_str)
    with open(val_log_path, 'a+') as val_fp:
        val_fp.write(val_str + '\n')

    # save loss curve:
    training_loss.append(losses.avg)
    plt.plot(training_loss)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss')
    plt.grid(True)
    plt.savefig(os.path.join(save_folder, 'training_loss.png'))
    plt.close()

    # save accuracy curves, one figure per lambda:
    for val_lambda in val_lambdas:
        if eval_this_epoch:
            epoch_TA = val_accs[val_lambda].avg
            epoch_ATA = val_accs_adv[val_lambda].avg
        else:
            # No validation this epoch: carry the previous value forward so the
            # curves keep one point per epoch.
            epoch_TA = val_TA[val_lambda][-1]
            epoch_ATA = val_ATA[val_lambda][-1]
        val_TA[val_lambda].append(epoch_TA)
        val_ATA[val_lambda].append(epoch_ATA)
        plt.plot(val_TA[val_lambda], 'r')
        plt.plot(val_ATA[val_lambda], 'g')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.title('Validation Accuracy')
        plt.legend(["SA", "RA"])
        plt.grid(True)
        plt.savefig(os.path.join(save_folder, 'val_acc%s.png' % val_lambda))
        plt.close()

    # save pth:
    if eval_this_epoch:
        for val_lambda in val_lambdas:
            if val_accs[val_lambda].avg >= best_TA[val_lambda]:
                best_TA[val_lambda] = val_accs[val_lambda].avg # update best TA
                torch.save(model.state_dict(), os.path.join(save_folder, 'best_SA%s.pth' % val_lambda))
            if val_accs_adv[val_lambda].avg >= best_ATA[val_lambda]:
                best_ATA[val_lambda] = val_accs_adv[val_lambda].avg # update best ATA
                torch.save(model.state_dict(), os.path.join(save_folder, 'best_RA%s.pth' % val_lambda))
    # Always refresh the resumable checkpoint, whether or not validation ran.
    save_ckpt(epoch, model, optimizer, scheduler, best_TA, best_ATA, training_loss, val_TA, val_ATA,
        os.path.join(save_folder, 'latest.pth'))
        
