In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision
import numpy as np
import json
import time
import sys
import copy
from tqdm import tqdm_notebook as tqdm
from matplotlib import pyplot as plt

from DBA import models
from DBA import activation
from DBA import attack
from DBA import training
from DBA import utilities
from DBA import densenet
from DBA import resnet
from DBA import wideresnet
from DBA import optdefense
from DBA import vgg
# --- Data pipeline and device -------------------------------------------
# Per-channel normalisation constants. With mean 0 and scale 1 the Normalize
# step leaves the ToTensor output unchanged.
# NOTE(review): transforms.Normalize takes (mean, std); `norm_var` is passed
# in the std slot. Harmless while it is 1, but a real variance would need a
# square root first.
norm_mean = 0
norm_var = 1
channel_means = (norm_mean,) * 3
channel_scales = (norm_var,) * 3

# Training images are augmented (padded random 32x32 crop + horizontal flip)
# before conversion to tensors.
train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(channel_means, channel_scales),
]
transform_train = transforms.Compose(train_steps)

# Test images get no augmentation, only tensor conversion + normalisation.
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(channel_means, channel_scales),
]
transform_test = transforms.Compose(test_steps)

# CIFAR-10 is downloaded into ./data on first use.
cifar_train = datasets.CIFAR10(
    "./data", train=True, download=True, transform=transform_train
)
cifar_test = datasets.CIFAR10(
    "./data", train=False, download=True, transform=transform_test
)

# Both loaders shuffle, including the test loader.
# NOTE(review): with a shuffled test loader, any evaluation that only consumes
# part of the test set will see a different subset on each run -- confirm this
# is intended.
train_loader = DataLoader(cifar_train, batch_size=128, shuffle=True)
test_loader = DataLoader(cifar_test, batch_size=50, shuffle=True)

# Everything below is placed on the first CUDA device.
device = torch.device('cuda:0')

# --- Train the OptNet-wrapped ResNet-18 at epsilon = 0.1 -----------------
# basenet is a freshly initialised ResNet-18; advnet is a VGG whose weights
# are restored from 'models/vgg0_adv.pth'. OptNet combines the two; how it
# uses advnet, epsilon, stepsize and iters is defined in DBA.optdefense and
# is not visible from this notebook.
basenet = resnet.ResNet18()
epsilon = 0.1
advnet = vgg.vgg0_bn()
advnet.load_state_dict(torch.load('models/vgg0_adv.pth'))
basenet = basenet.to(device)
advnet = advnet.to(device)
optnet = optdefense.OptNet(basenet, advnet, epsilon=epsilon, stepsize=epsilon/10.,
                           iters=13, randomize=True, adv_mode='adv')

# Only the wrapped base model's parameters are handed to SGD.
opt = optim.SGD(optnet.basemodel.parameters(), lr=0.1, momentum=0.9)

for ep in range(80):
    # Manual step schedule: lr 0.1 for epochs 0-49, then 0.01 for epochs 50-79.
    if ep == 50:
        for param_group in opt.param_groups:
            param_group['lr'] = 0.01

    # Run the epoch through `optnet`, not the bare `basenet`. Passing `basenet`
    # here bypassed the OptNet wrapper entirely, so epsilon had no effect on
    # training even though the checkpoint is named after epsilon 0.1 and is
    # evaluated through `optnet` afterwards. This now matches the
    # 0.2 / 0.3 / 0.4 cells.
    train_err, train_loss = training.epoch(train_loader, optnet, opt, device=device, use_tqdm=True)
    # Called without an optimizer -- presumably evaluation only; confirm in
    # DBA.training.epoch.
    test_err, test_loss = training.epoch(test_loader, optnet, device=device, use_tqdm=True)
    print('epoch', ep, 'train acc', 1-train_err, 'test acc', 1-test_err)

    # Overwrite the checkpoint after every epoch so an interrupted run still
    # leaves the most recent weights on disk.
    torch.save(optnet.basemodel.state_dict(), 'models/resnet18_vgg0_adv_0.1.pth')

# eval
# `evalmethods` is a list of config dicts consumed by training.eval_robustness;
# the meaning of each key is defined there, not in this notebook. The list is
# reused unchanged by the later epsilon cells.
evalmethods = [
    {'name': 'std'},
    {'name': 'DA', 'n_test': 5000},
]

# Two separately trained networks used by the 'transfer' entries. Their
# checkpoints are loaded onto the CPU; they are not moved to `device` here.
attacknet1 = vgg.vgg19_bn()
attacknet1.load_state_dict(torch.load('./models/vgg19_bn2.pth', map_location='cpu'))
attacknet2 = resnet.ResNet18()
attacknet2.load_state_dict(torch.load('./models/resnet18_cifar.pth', map_location='cpu'))

for surrogate_net, surrogate_label in ((attacknet1, 'vgg'), (attacknet2, 'resnet18')):
    evalmethods.append({
        'name': 'transfer',
        'net': surrogate_net,
        'model_name': surrogate_label,
        'n_test': 5000,
    })

# Two BPDA configurations that differ only in 'iters' (20 vs 0).
evalmethods.extend([
    {'name': 'BPDA', 'samples': 1, 'iters': 20, 'epoch_test': 20},
    {'name': 'BPDA', 'iters': 0, 'samples': 1, 'epoch_test': 20},
])

training.eval_robustness(optnet, test_loader, device, use_tqdm=True, evalmethods=evalmethods)

In [None]:
# --- Train and evaluate the OptNet-wrapped ResNet-18 at epsilon = 0.2 ----
# Uses names created in the first cell: device, train_loader, test_loader
# and evalmethods.
basenet = resnet.ResNet18().to(device)
epsilon = 0.2

# VGG network restored from disk and handed to OptNet as its second model.
advnet = vgg.vgg0_bn()
advnet.load_state_dict(torch.load('models/vgg0_adv.pth'))
advnet = advnet.to(device)

# Step size is one tenth of epsilon; the remaining arguments are interpreted
# by DBA.optdefense.OptNet.
optnet = optdefense.OptNet(
    basenet,
    advnet,
    epsilon=epsilon,
    stepsize=epsilon/10.,
    iters=13,
    randomize=True,
    adv_mode='adv',
)

# Only the wrapped base model's parameters are optimised.
opt = optim.SGD(optnet.basemodel.parameters(), lr=0.1, momentum=0.9)

total_epochs = 80
lr_drop_epoch = 50
for ep in range(total_epochs):
    # Single manual learning-rate drop from 0.1 to 0.01.
    if ep == lr_drop_epoch:
        for group in opt.param_groups:
            group['lr'] = 0.01

    train_err, train_loss = training.epoch(
        train_loader, optnet, opt, device=device, use_tqdm=True
    )
    # No optimizer passed -- presumably an evaluation pass; confirm in
    # DBA.training.epoch.
    test_err, test_loss = training.epoch(
        test_loader, optnet, device=device, use_tqdm=True
    )
    print('epoch', ep, 'train acc', 1-train_err, 'test acc', 1-test_err)

    # Checkpoint is rewritten after every epoch.
    torch.save(optnet.basemodel.state_dict(), 'models/resnet18_vgg0_adv_0.2.pth')

# eval
training.eval_robustness(optnet, test_loader, device, use_tqdm=True, evalmethods=evalmethods)

In [None]:
# --- Train and evaluate the OptNet-wrapped ResNet-18 at epsilon = 0.3 ----
# Uses names created in the first cell: device, train_loader, test_loader
# and evalmethods.
basenet = resnet.ResNet18().to(device)
epsilon = 0.3

# VGG network restored from disk and handed to OptNet as its second model.
advnet = vgg.vgg0_bn()
advnet.load_state_dict(torch.load('models/vgg0_adv.pth'))
advnet = advnet.to(device)

# Step size is one tenth of epsilon; the remaining arguments are interpreted
# by DBA.optdefense.OptNet.
optnet = optdefense.OptNet(
    basenet,
    advnet,
    epsilon=epsilon,
    stepsize=epsilon/10.,
    iters=13,
    randomize=True,
    adv_mode='adv',
)

# Only the wrapped base model's parameters are optimised.
opt = optim.SGD(optnet.basemodel.parameters(), lr=0.1, momentum=0.9)

total_epochs = 80
lr_drop_epoch = 50
for ep in range(total_epochs):
    # Single manual learning-rate drop from 0.1 to 0.01.
    if ep == lr_drop_epoch:
        for group in opt.param_groups:
            group['lr'] = 0.01

    train_err, train_loss = training.epoch(
        train_loader, optnet, opt, device=device, use_tqdm=True
    )
    # No optimizer passed -- presumably an evaluation pass; confirm in
    # DBA.training.epoch.
    test_err, test_loss = training.epoch(
        test_loader, optnet, device=device, use_tqdm=True
    )
    print('epoch', ep, 'train acc', 1-train_err, 'test acc', 1-test_err)

    # Checkpoint is rewritten after every epoch.
    torch.save(optnet.basemodel.state_dict(), 'models/resnet18_vgg0_adv_0.3.pth')

# eval
training.eval_robustness(optnet, test_loader, device, use_tqdm=True, evalmethods=evalmethods)

In [None]:
# --- Train and evaluate the OptNet-wrapped ResNet-18 at epsilon = 0.4 ----
# Uses names created in the first cell: device, train_loader, test_loader
# and evalmethods.
basenet = resnet.ResNet18().to(device)
epsilon = 0.4

# VGG network restored from disk and handed to OptNet as its second model.
advnet = vgg.vgg0_bn()
advnet.load_state_dict(torch.load('models/vgg0_adv.pth'))
advnet = advnet.to(device)

# Step size is one tenth of epsilon; the remaining arguments are interpreted
# by DBA.optdefense.OptNet.
optnet = optdefense.OptNet(
    basenet,
    advnet,
    epsilon=epsilon,
    stepsize=epsilon/10.,
    iters=13,
    randomize=True,
    adv_mode='adv',
)

# Only the wrapped base model's parameters are optimised.
opt = optim.SGD(optnet.basemodel.parameters(), lr=0.1, momentum=0.9)

total_epochs = 80
lr_drop_epoch = 50
for ep in range(total_epochs):
    # Single manual learning-rate drop from 0.1 to 0.01.
    if ep == lr_drop_epoch:
        for group in opt.param_groups:
            group['lr'] = 0.01

    train_err, train_loss = training.epoch(
        train_loader, optnet, opt, device=device, use_tqdm=True
    )
    # No optimizer passed -- presumably an evaluation pass; confirm in
    # DBA.training.epoch.
    test_err, test_loss = training.epoch(
        test_loader, optnet, device=device, use_tqdm=True
    )
    print('epoch', ep, 'train acc', 1-train_err, 'test acc', 1-test_err)

    # Checkpoint is rewritten after every epoch.
    torch.save(optnet.basemodel.state_dict(), 'models/resnet18_vgg0_adv_0.4.pth')

# eval
training.eval_robustness(optnet, test_loader, device, use_tqdm=True, evalmethods=evalmethods)