In [41]:
# import some neccesary module
import copy
import logging
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from apex import amp

from tqdm import tqdm

In [42]:
from preact_resnet import PreActResNet18
from utils import (upper_limit, lower_limit, std, clamp, get_loaders,
    attack_pgd, evaluate_pgd, evaluate_standard)

In [43]:
# Hyperparameter Setting
out_dir = "/home/ubuntu/zhc/adversarial_attack/fast_adversarial/CIFAR10/output"
seed = 7
data_dir = '../../cifar-data'
batch_size = 128
epsilon = 8
alpha = 10
device = "cuda" if torch.cuda.is_available() else "cpu"
lr_max = 0.2
momentum = 0.9
weight_decay = 5e-4
opt_level = 'O2'
loss_scale = '1.0'
delta_init = 'random'
epochs = 15
lr_schedule = 'cyclic'
lr_min = 0.
early_stop = True

In [44]:
# if not exists the out_dir, we create it.
if not os.path.exists(out_dir):
    os.mkdir(out_dir)

# and we create a logfile path
logfile = os.path.join(out_dir, 'output.log')

# if logfile exists, we remove it 
if os.path.exists(logfile):
    os.remove(logfile)

logging.basicConfig(
    format='[%(asctime)s] - %(message)s',
    datefmt='%Y/%m/%d %H:%M:%S',
    level=logging.INFO,
    filename=os.path.join(out_dir, 'output.log'))


In [45]:
# set random seed 

np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


In [46]:
# data preprocess and load 
train_loader, test_loader = get_loaders(data_dir, batch_size)

Files already downloaded and verified
Files already downloaded and verified


In [47]:
epsilon = (epsilon / 255.) / std
alpha = (alpha / 255.) / std
pgd_alpha = (2 / 255.) / std

In [49]:
model = PreActResNet18().to(device)
model.train()

PreActResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (layer1): Sequential(
    (0): PreActBlock(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (1): PreActBlock(
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
  )
  (layer2): Sequential(
    (0): PreActBlock(
      (bn1): BatchNorm2d

In [50]:
# optimizor: SGD
opt = torch.optim.SGD(model.parameters(), lr = lr_max, momentum=momentum, weight_decay=weight_decay)

In [51]:
amp_args = dict(opt_level=opt_level, loss_scale=loss_scale, verbosity=False)


In [52]:
if opt_level == 'O2':
    amp_args['master_weights'] = True
model, opt = amp.initialize(model, opt, **amp_args)



In [53]:
# criterion function: CrossEntropy
criterion = nn.CrossEntropyLoss()

In [54]:
if delta_init == 'previous':
    delta = torch.zeros(batch_size, 3, 32, 32).cuda()

lr_steps = epochs * len(train_loader)

if lr_schedule == 'cyclic':
    scheduler = torch.optim.lr_scheduler.CyclicLR(opt, base_lr=lr_min, max_lr=lr_max,
        step_size_up=lr_steps / 2, step_size_down=lr_steps / 2)
elif lr_schedule == 'multistep':
    scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[lr_steps / 2, lr_steps * 3 / 4], gamma=0.1)


In [55]:
# Training
prev_robust_acc = 0.
start_train_time = time.time()

In [57]:
total_batches = len(train_loader)
for epoch in tqdm(range(epochs)):
    start_epoch_time = time.time()
    train_loss = 0
    train_acc = 0
    train_n = 0
    for i, (X, y) in enumerate(train_loader):
        X, y = X.to(device), y.to(device)
        if i == 0:
            first_batch = (X, y)
        if delta_init != 'previous':
            delta = torch.zeros_like(X).cuda()
        if delta_init == 'random':
            for j in range(len(epsilon)):
                delta[:, j, :, :].uniform_(-epsilon[j][0][0].item(), epsilon[j][0][0].item())
            delta.data = clamp(delta, lower_limit - X, upper_limit - X)
            
        delta.requires_grad = True
        output = model(X + delta[:X.size(0)])
        loss = F.cross_entropy(output, y)
        with amp.scale_loss(loss, opt) as scaled_loss:
            scaled_loss.backward()
        grad = delta.grad.detach()
        delta.data = clamp(delta + alpha * torch.sign(grad), -epsilon, epsilon)
        delta.data[:X.size(0)] = clamp(delta[:X.size(0)], lower_limit - X, upper_limit - X)
        delta = delta.detach()
        output = model(X + delta[:X.size(0)])
        loss = criterion(output, y)
        opt.zero_grad()
        with amp.scale_loss(loss, opt) as scaled_loss:
            scaled_loss.backward()
        opt.step()
        train_loss += loss.item() * y.size(0)
        train_acc += (output.max(1)[1] == y).sum().item()
        train_n += y.size(0)
        scheduler.step()
    if early_stop:
        # Check current PGD robustness of model using random minibatch
        X, y = first_batch
        pgd_delta = attack_pgd(model, X, y, epsilon, pgd_alpha, 5, 1, opt)
        with torch.no_grad():
            output = model(clamp(X + pgd_delta[:X.size(0)], lower_limit, upper_limit))
        robust_acc = (output.max(1)[1] == y).sum().item() / y.size(0)
        if robust_acc - prev_robust_acc < -0.2:
            break
        prev_robust_acc = robust_acc
        best_state_dict = copy.deepcopy(model.state_dict())
    epoch_time = time.time()
    lr = scheduler.get_lr()[0]
    print('%d \t %.1f \t \t %.4f \t %.4f \t %.4f',
            epoch, epoch_time - start_epoch_time, lr, train_loss/train_n, train_acc/train_n)

train_time = time.time()

if not early_stop:
    best_state_dict = model.state_dict()
torch.save(best_state_dict, os.path.join(out_dir, 'model.pth'))
print((train_time - start_train_time)/60)

# Evaluation
model_test = PreActResNet18().to(device)
model_test.load_state_dict(best_state_dict)
model_test.float()
model_test.eval()

pgd_loss, pgd_acc = evaluate_pgd(test_loader, model_test, 50, 10)
test_loss, test_acc = evaluate_standard(test_loader, model_test)

print('Test Loss \t Test Acc \t PGD Loss \t PGD Acc')
print('%.4f \t \t %.4f \t %.4f \t %.4f', test_loss, test_acc, pgd_loss, pgd_acc)

  7%|▋         | 1/15 [00:18<04:25, 18.97s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 0 18.965568780899048 0.02666666666666666 2.1306632760620117 0.19348


 13%|█▎        | 2/15 [00:36<03:57, 18.30s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 1 17.83906054496765 0.05333333333333332 1.918559842529297 0.27552


 20%|██        | 3/15 [00:54<03:37, 18.09s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 2 17.83459162712097 0.07999999999999999 1.825787401046753 0.30774


 27%|██▋       | 4/15 [01:12<03:17, 17.98s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 3 17.802282333374023 0.10666666666666665 1.7358898900604247 0.33736


 33%|███▎      | 5/15 [01:30<02:59, 17.91s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 4 17.799811601638794 0.1333333333333333 1.6520530545043945 0.36958


 40%|████      | 6/15 [01:48<02:40, 17.86s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 5 17.759487867355347 0.15999999999999998 1.5971660720825196 0.3933


 47%|████▋     | 7/15 [02:05<02:22, 17.83s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 6 17.76990270614624 0.18666666666666673 1.5517124658584596 0.4078


 53%|█████▎    | 8/15 [02:23<02:04, 17.82s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 7 17.7959885597229 0.18666666666666673 1.5124848135375977 0.4247


 60%|██████    | 9/15 [02:41<01:47, 17.84s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 8 17.868568181991577 0.15999999999999998 1.4708707464599609 0.43812


 67%|██████▋   | 10/15 [02:59<01:29, 17.84s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 9 17.86358618736267 0.1333333333333334 1.4272908445739747 0.45386


 73%|███████▎  | 11/15 [03:17<01:11, 17.83s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 10 17.78630518913269 0.10666666666666665 1.3926101739501953 0.46768


 80%|████████  | 12/15 [03:34<00:53, 17.81s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 11 17.784669160842896 0.07999999999999999 1.3521190425491334 0.48286


 87%|████████▋ | 13/15 [03:52<00:35, 17.82s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 12 17.825743198394775 0.05333333333333332 1.296216632080078 0.50012


 93%|█████████▎| 14/15 [04:10<00:17, 17.82s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 13 17.831653594970703 0.02666666666666666 1.230075249633789 0.5264


100%|██████████| 15/15 [04:28<00:00, 17.89s/it]

%d 	 %.1f 	 	 %.4f 	 %.4f 	 %.4f 14 17.78027582168579 0.0 1.129176103439331 0.55982
4.65578502813975





Test Loss 	 Test Acc 	 PGD Loss 	 PGD Acc
%.4f 	 	 %.4f 	 %.4f 	 %.4f 0.664918485736847 0.7917 1.0259269359588623 0.445
