In [1]:
import json
import time
import argparse
import shutil

import os
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from copy import deepcopy

from data import get_data_info
from data import load_data

from util import format_time
from util import Logger
from util import Trainer
from util import seed

In [2]:
class Args:
    """Configuration for one training run of this notebook.

    Defaults are declared as class attributes. ``__init__`` copies them onto
    the instance so that ``vars(args)`` / ``args.__dict__`` contains every
    setting; an instance of an attribute-only class has an empty
    ``__dict__``, which makes ``json.dump(args.__dict__, ...)`` write ``{}``.

    Keyword arguments override individual defaults, e.g. ``Args(debug=True)``.

    Raises:
        TypeError: if an override names a setting that is not declared here.
    """

    # --- data loading (forwarded to load_data) ---
    augment = 'base'

    num_workers = 1
    batch_size = 256
    batch_size_validation = 256

    # --- locations: data is read from data_dir/data, logs go to log_dir/desc ---
    data_dir = './data/database/'
    log_dir = 'trained_models'

    data = 'cifar10'
    desc = 'resnet4_3semble_test_1'

    # --- model ---
    model = "resnet4" #'preact-resnet18-swish'
    normalize = False
    pretrained_file = None

    # --- schedule ---
    num_adv_epochs = 100 # 400
    adv_eval_freq = 25  # adversarial test-set evaluation runs every this many epochs

    # --- loss weights (consumed by Trainer; semantics not visible in this notebook) ---
    beta = 5.0
    lamda = 2.0
    log_det_lamda = 0.2 #0.5
    tau = 0.995

    # --- optimiser ---
    lr = 0.2
    weight_decay = 5e-4
    scheduler = 'cosinew'
    nesterov = True
    clip_grad = None

    # --- attack ---
    attack = 'linf-pgd'
    attack_eps = 8/255
    attack_step = 2/255
    attack_iter = 10

    # --- run control ---
    debug = False  # when True the run is cut down to a single epoch
    unsup_fraction = 0.7
    seed = 1
    resume_path = ''  # directory holding 'state-last.pt'; empty means start from epoch 1
    num_models = 3

    #label smoothing
    ls = 0.1

    def __init__(self, **overrides):
        # Materialise the class-level defaults on the instance (base classes
        # first, so subclasses win) to make the configuration serialisable
        # through vars(self).
        for klass in reversed(type(self).__mro__):
            for name, value in vars(klass).items():
                if name.startswith('_') or callable(value):
                    continue
                setattr(self, name, value)

        for name, value in overrides.items():
            if name not in self.__dict__:
                raise TypeError('Unknown setting: {!r}'.format(name))
            setattr(self, name, value)

args = Args()

In [3]:
# Resolve the locations used by this run.
DATA_DIR = os.path.join(args.data_dir, args.data)
LOG_DIR = os.path.join(args.log_dir, args.desc)
WEIGHTS = os.path.join(LOG_DIR, 'weights-best.pt')

# Start from an empty log directory, except when the checkpoint we are about
# to resume from lives in that very directory: wiping it would delete
# 'state-last.pt' before it can be loaded.
resuming_in_place = bool(args.resume_path) and (
    os.path.abspath(args.resume_path) == os.path.abspath(LOG_DIR)
)
if os.path.exists(LOG_DIR) and not resuming_in_place:
    shutil.rmtree(LOG_DIR)
os.makedirs(LOG_DIR, exist_ok=True)
logger = Logger(os.path.join(LOG_DIR, 'log-train.log'))

# Snapshot the configuration next to the logs. `args.__dict__` alone is empty
# when the settings are class attributes, so merge class-level and
# instance-level values (instance values win) and drop private/callable names.
args_snapshot = {
    name: value
    for name, value in {**vars(type(args)), **vars(args)}.items()
    if not name.startswith('_') and not callable(value)
}
with open(os.path.join(LOG_DIR, 'args.txt'), 'w') as f:
    json.dump(args_snapshot, f, indent=4)


# Dataset info; handed to the Trainer further down.
info = get_data_info(DATA_DIR)

# Batch sizes are taken straight from the run configuration.
BATCH_SIZE, BATCH_SIZE_VALIDATION = args.batch_size, args.batch_size_validation

# A debug run is limited to a single epoch.
NUM_ADV_EPOCHS = 1 if args.debug else args.num_adv_epochs

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
logger.log('Using device: {}'.format(device))

# To speed up training
torch.backends.cudnn.benchmark = True


# Load data: seed first, then build the train / test / eval loaders.

seed(args.seed)
loader_options = dict(
    data_dir=DATA_DIR,
    batch_size=BATCH_SIZE,
    batch_size_test=BATCH_SIZE_VALIDATION,
    augmentation=args.augment,
    unsup_fraction=args.unsup_fraction,
    num_workers=args.num_workers,
)
loaders = load_data(**loader_options)
train_dataloader, test_dataloader, eval_dataloader = loaders

# Adversarial Training

seed(args.seed)
trainer = Trainer(info, args)

# Learning rate reported in the log; refreshed from the scheduler each epoch
# when a scheduler is configured.
last_lr = args.lr

# Bound unconditionally so the statements after the epoch loop can still read
# them when no epoch runs (e.g. NUM_ADV_EPOCHS == 0).
res = {}
metrics = pd.DataFrame()
old_score = [0.0, 0.0]  # [clean test acc, adversarial eval acc] of the best checkpoint so far
test_adv_acc = 0.0

if NUM_ADV_EPOCHS > 0:
    logger.log('\n\n')
    # '{:.2f}' (two decimals); the previous '{:2f}' was a width spec and
    # printed six decimals.
    logger.log('Standard Accuracy-\tTest: {:.2f}%.'.format(trainer.eval(test_dataloader)*100))

    logger.log('RST Adversarial training for {} epochs'.format(NUM_ADV_EPOCHS))
    # NOTE(review): this uses args.num_adv_epochs, not NUM_ADV_EPOCHS, so a
    # debug run still sets the optimizer up for the full epoch count; confirm
    # that this is intended.
    trainer.init_optimizer(args.num_adv_epochs)

# Continue from the epoch after the stored one, or start at epoch 1.
if args.resume_path:
    start_epoch = trainer.load_model_resume(os.path.join(args.resume_path, 'state-last.pt')) + 1
    logger.log(f'Resuming at epoch {start_epoch}')
else:
    start_epoch = 1

# One metrics dict per completed epoch. `metrics` is rebuilt from this list
# every epoch; DataFrame.append no longer exists in pandas >= 2.0.
epoch_records = []

for epoch in range(start_epoch, NUM_ADV_EPOCHS+1):
    start = time.time()
    logger.log('======= Epoch {} ======='.format(epoch))

    # Read the learning rate before training so the log shows the value this
    # epoch starts with.
    if args.scheduler:
        last_lr = trainer.scheduler.get_last_lr()[0]

    res = trainer.train(train_dataloader, epoch=epoch)
    test_acc = trainer.eval(test_dataloader)

    logger.log('Loss: {:.4f}.\tLR: {:.4f}'.format(res['loss'], last_lr))
    if 'clean_acc' in res:
        logger.log('Standard Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(res['clean_acc']*100, test_acc*100))
    else:
        logger.log('Standard Accuracy-\tTest: {:.2f}%.'.format(test_acc*100))

    # Training metrics are prefixed with 'train_'; 'test_adversarial_acc'
    # stays an empty string on epochs where it is not measured.
    epoch_metrics = {'train_'+k: v for k, v in res.items()}
    epoch_metrics.update({'epoch': epoch, 'lr': last_lr, 'test_clean_acc': test_acc, 'test_adversarial_acc': ''})

    # Adversarial evaluation on the test set only runs every
    # `adv_eval_freq` epochs and on the final epoch.
    if epoch % args.adv_eval_freq == 0 or epoch == NUM_ADV_EPOCHS:
        test_adv_acc = trainer.eval(test_dataloader, adversarial=True)
        logger.log('Adversarial Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(res['adversarial_acc']*100,
                                                                                test_adv_acc*100))
        epoch_metrics.update({'test_adversarial_acc': test_adv_acc})
    else:
        logger.log('Adversarial Accuracy-\tTrain: {:.2f}%.'.format(res['adversarial_acc']*100))

    # The eval loader is attacked every epoch; its score selects the best model.
    eval_adv_acc = trainer.eval(eval_dataloader, adversarial=True)
    logger.log('Adversarial Accuracy-\tEval: {:.2f}%.'.format(eval_adv_acc*100))
    epoch_metrics['eval_adversarial_acc'] = eval_adv_acc

    # Keep the weights with the best (ties included) adversarial eval accuracy.
    if eval_adv_acc >= old_score[1]:
        old_score[0], old_score[1] = test_acc, eval_adv_acc
        trainer.save_model(WEIGHTS)
    # trainer.save_model(os.path.join(LOG_DIR, 'weights-last.pt'))
    if epoch % 10 == 0:
        # Resumable state, written every 10 epochs.
        trainer.save_model_resume(os.path.join(LOG_DIR, 'state-last.pt'), epoch)
    if epoch % 400 == 0:
        # Archive a copy of the current best weights every 400 epochs.
        shutil.copyfile(WEIGHTS, os.path.join(LOG_DIR, f'weights-best-epoch{epoch}.pt'))

    logger.log('Time taken: {}'.format(format_time(time.time()-start)))

    # Rewrite the CSV after every epoch so partial runs still leave stats behind.
    epoch_records.append(epoch_metrics)
    metrics = pd.DataFrame(epoch_records)
    metrics.to_csv(os.path.join(LOG_DIR, 'stats_adv.csv'), index=False)

    
    
# Final summary: last-epoch training accuracy next to the scores stored for
# the best checkpoint.

if 'clean_acc' in res:
    train_acc = res['clean_acc']
else:
    # The last training result carries no clean accuracy; measure it now.
    train_acc = trainer.eval(train_dataloader)

best_test_acc, best_eval_adv_acc = old_score

logger.log('\nTraining completed.')
logger.log('Standard Accuracy-\tTrain: {:.2f}%.\tTest: {:.2f}%.'.format(train_acc*100, best_test_acc*100))
if NUM_ADV_EPOCHS > 0:
    last_train_adv_acc = res['adversarial_acc']
    logger.log('Adversarial Accuracy-\tTrain: {:.2f}%.\tEval: {:.2f}%.'.format(last_train_adv_acc*100, best_eval_adv_acc*100))

logger.log('Script Completed.')

Logging to file:  trained_models/resnet4_3semble_test_1/log-train.log
Using device: cuda
Files already downloaded and verified
Files already downloaded and verified



Standard Accuracy-	Test: 10.058594%.
RST Adversarial training for 100 epochs
Loss: 4.0852.	LR: 0.0080
Standard Accuracy-	Train: 42.08%.	Test: 47.08%.
Adversarial Accuracy-	Train: 40.19%.
Adversarial Accuracy-	Eval: 4.30%.
Time taken: 0:01:14
Loss: 7.0961.	LR: 0.1520
Standard Accuracy-	Train: 16.03%.	Test: 21.01%.
Adversarial Accuracy-	Train: 15.78%.
Adversarial Accuracy-	Eval: 10.94%.
Time taken: 0:01:14
Loss: 5.0001.	LR: 0.2000
Standard Accuracy-	Train: 29.74%.	Test: 26.67%.
Adversarial Accuracy-	Train: 29.18%.
Adversarial Accuracy-	Eval: 8.50%.
Time taken: 0:01:14
Loss: 4.4896.	LR: 0.1999
Standard Accuracy-	Train: 40.03%.	Test: 46.24%.
Adversarial Accuracy-	Train: 37.25%.
Adversarial Accuracy-	Eval: 10.06%.
Time taken: 0:01:13
Loss: 4.0495.	LR: 0.1997
Standard Accuracy-	Train: 47.44%.	Test: 52.38%.
Adversarial Accuracy

NameError: name 'arg' is not defined

In [None]:
# Scratch cell: IPython `?` help lookup for torch.autocast. It is not used by
# the training code above; consider removing it before sharing the notebook.
torch.autocast?