# Training script

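This notebook trains a classifier on the HAM10000 dataset with the `robustness` library, optionally with adversarial training and input ablation, and can then evaluate the trained model.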

### Imports

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

# To enable importing robustness directory
import sys, os
sys.path.insert(1, os.path.join(sys.path[0], '..'))

import torch as ch
import numpy as np

import cox.store
from cox.utils import Parameters
from datetime import datetime
import time

from robustness import model_utils, datasets, train, defaults
from robustness.datasets import CIFAR, HAM10000, HAM10000_dataset, HAM10000_3cls, HAM10000_dataset_3cls_balanced, freeze, unfreeze
from robustness.tools.utils import fix_random_seed
from robustness.evaluation import plot_curves_from_log, evaluate_model

## Training config

In [None]:
# Training
ADV_TRAIN = False
ADV_EVAL = False
lr = 1e-4
BATCH_SIZE = 32
EPOCHS = 10
step_lr = None
custom_schedule = None
lr_patience = 5  # presumably epochs without improvement before the LR is reduced — confirm
es_patience = 10  # presumably epochs without improvement before early stopping — confirm

# Data
# Split file names read by the data loader, train_kwargs and the evaluation cell.
# They must be defined here; otherwise those cells raise a NameError.
# NOTE(review): presumably file names relative to ds_path — confirm.
train_file_name = None  # TODO-USER: training split file
test_file_name = None  # TODO-USER: test split file (only used when do_eval_model is True)

# Model
base_model_expid = None  # expid of an earlier run to resume/fine-tune from; None = start fresh
use_dropout_head = False
dropout_perc = 0
arch = 'resnet18'
pytorch_pretrained = True
unfreeze_to_layer = 0  # only used when resuming from base_model_expid

# Other settings
do_eval_model = False
eval_checkpoint_type = 'latest'
TRAIN_COLAB = True
NUM_WORKERS = 16
expid = datetime.now().strftime("%Y-%m-%d---%H:%M:%S")  # timestamp-based run id
seed = 42

# Ablation
apply_ablation = False
saliency_dir = None
perc_ablation = 0

# Adversary
EPS = 0.5
ITERATIONS = 7
constraint = '2'  # norm constraint passed to the attack ('2' = L2 ball in robustness)

In [None]:
# Data/log locations depend on where the notebook runs. Replace the TODO-USER
# placeholders before running. They default to None so the cell is valid Python
# and fails with an explicit message rather than a SyntaxError.
if TRAIN_COLAB:
    ds_path = None  # TODO-USER (e.g., "/content/data")
    OUT_DIR = None  # TODO-USER (e.g., "/content/drive/My Drive/logs")
    device = 'cuda'
else:
    ds_path = None  # TODO-USER (e.g., "/Users/andrei/Google Drive/data/HAM10000")
    OUT_DIR = None  # TODO-USER (e.g., "/Users/andrei/Google Drive/logs")
    device = 'cpu'

if ds_path is None or OUT_DIR is None:
    raise ValueError(
        "Set ds_path and OUT_DIR (see the TODO-USER placeholders) before running."
    )

# Keyword arguments for robustness' training loop, built from the config cell.
train_kwargs = {
    # NOTE(review): hard-coded and independent of OUT_DIR; the run's store is
    # passed to train_model separately — confirm this key is still needed.
    'out_dir': "train_out",
    'adv_train': ADV_TRAIN,
    'adv_eval': ADV_EVAL,
    'epochs': EPOCHS,
    'lr': lr,
    'optimizer': 'Adam',
    'device': device,
    'batch_size': BATCH_SIZE,
    'arch': arch,
    'pytorch_pretrained': pytorch_pretrained,
    'dataset_file_name': train_file_name,
    'step_lr': step_lr,
    'custom_schedule': custom_schedule,
    'lr_patience': lr_patience,
    'es_patience': es_patience,
    'log_iters': 1,
    'use_adv_prec': True,
    'apply_ablation': apply_ablation,
    'saliency_dir': saliency_dir,
    'perc_ablation': perc_ablation,
    'dropout_perc': dropout_perc,
    'use_dropout_head': use_dropout_head
}

# PGD attack settings; the attack step size is a fixed fraction of the budget.
attack_kwargs = {
    'constraint': constraint,
    'eps': EPS,
    'attack_lr': EPS/5,
    'attack_steps': ITERATIONS,
    'random_start': True
}

# merge train_kwargs with attack_kwargs (attack_kwargs win on any duplicate key)
train_kwargs_merged = {**train_kwargs, **attack_kwargs}

In [None]:
# Display the timestamp-based experiment id; the store below creates this run's
# folder under OUT_DIR with this name.
expid

In [None]:
# Fix the random seed (presumably Python/NumPy/torch RNGs — confirm in
# robustness.tools.utils) and open the cox store for this run. Resumed runs
# later look for their checkpoint under OUT_DIR/<expid>.
fix_random_seed(seed)
out_store = cox.store.Store(OUT_DIR, expid)

In [None]:
# Confirm the experiment id the store was created with.
print(out_store.exp_id)

### Resume path

In [None]:
# Record the base run (if any) in the training args, and point resume_path at
# that run's latest checkpoint. With no base run, training starts from scratch.
train_kwargs_merged['base_model_expid'] = base_model_expid
resume_path = (
    os.path.join(OUT_DIR, base_model_expid, "checkpoint.pt.latest")
    if base_model_expid
    else None
)

## Train

Fill in any training parameters that were not set explicitly with the library defaults.

In [None]:
# Wrap the explicit kwargs, then fill any TRAINING_ARGS / PGD_ARGS entries that
# are still missing with robustness' defaults for the given dataset class.
# NOTE(review): PGD defaults are looked up with CIFAR while training defaults use
# HAM10000 — confirm this is intentional (e.g. no HAM10000 entry for PGD defaults).
train_args = Parameters(train_kwargs_merged)
train_args = defaults.check_and_fill_args(train_args,
                        defaults.TRAINING_ARGS, HAM10000)
train_args = defaults.check_and_fill_args(train_args,
                        defaults.PGD_ARGS, CIFAR)

# Display the final argument set.
train_args

### Data Loader

In [None]:
# Build the HAM10000_3cls dataset with the configured ablation and dropout-head
# options, then create its train/validation loaders.
dataset = HAM10000_3cls(
    ds_path,
    file_name=train_file_name,
    apply_ablation=apply_ablation,
    saliency_dir=saliency_dir,
    perc_ablation=perc_ablation,
    use_dropout_head=use_dropout_head,
    dropout_perc=dropout_perc,
)

train_loader, val_loader = dataset.make_loaders(batch_size=BATCH_SIZE, workers=NUM_WORKERS)

In [None]:
# Build the model for `arch` (optionally with pretrained weights) and, when
# resume_path is set, restore it from that checkpoint. The second return value
# is not used here.
model, _ = model_utils.make_and_restore_model(
    arch=arch,
    pytorch_pretrained=pytorch_pretrained,
    dataset=dataset, 
    resume_path=resume_path,
    device=device
)

In [None]:
# Choose which parameters are trainable.
# Truthiness (rather than `== None`) matches the resume-path cell: any falsy
# base_model_expid (None or "") means no checkpoint was resumed.
if not base_model_expid:
    # No base model: freeze the whole network, then train only the last layers.
    # NOTE(review): meaning of the `5` argument depends on `unfreeze` — confirm.
    freeze(model.model)
    unfreeze(model.model.fc, 5)
else:
    # Base model resumed: unfreeze up to `unfreeze_to_layer` to fine-tune the network.
    # NOTE(review): `.module` presumably unwraps a wrapper added when restoring
    # from a checkpoint — confirm against make_and_restore_model.
    model = model.module
    unfreeze(model.model, unfreeze_to_layer)

In [None]:
# Display the underlying network to check its layers before training.
model.model

### Train model

In [None]:
# Train and report the elapsed time. time.perf_counter is monotonic, so the
# duration is not distorted by system clock adjustments (unlike time.time).
start = time.perf_counter()

model_finetuned = train.train_model(train_args, model, (train_loader, val_loader), store=out_store)

end = time.perf_counter()
print(f"Training took {end - start:.2f} sec")

In [None]:
# Plot the training curves from the store and display its 'logs' table.
# NOTE(review): assumes plot_curves_from_log returns an object indexable by
# 'logs' (presumably the store itself) — confirm.
plot_curves_from_log(out_store)['logs'].df

In [None]:
# Print the experiment id so the run can be located (e.g. to use as base_model_expid later).
print(out_store.exp_id)

In [None]:
# Close the store; afterwards only out_store.exp_id is read (by the evaluation cell).
out_store.close()

### Evaluate model

Evaluate the model on standard data from the whole training set and the test set, applying the same ablation settings as in training.

In [None]:
# Optional evaluation of this run's `eval_checkpoint_type` checkpoint on the
# train and test splits. Both splits use the test-time transform and the same
# ablation settings as training.
# NOTE(review): use_dropout_head / dropout_perc are passed to the training
# dataset but not to these eval datasets — confirm that is intended.
if do_eval_model:
    # training dataset (full split, test-time transform)
    train_dataset = HAM10000_dataset_3cls_balanced(ds_path, train_file_name, train=True, 
                                                   transform = dataset.transform_test, 
                                                   apply_ablation=apply_ablation, saliency_dir=saliency_dir, 
                                                   perc_ablation=perc_ablation)

    # test dataset
    test_dataset = HAM10000_dataset_3cls_balanced(ds_path, test_file_name, test=True,
                                                  transform = dataset.transform_test,
                                                  apply_ablation=apply_ablation, saliency_dir=saliency_dir, 
                                                  perc_ablation=perc_ablation)

    # Presumably loads the checkpoint from OUT_DIR/<exp_id> and returns per-split accuracies — confirm.
    accs = evaluate_model(out_store.exp_id, dataset, train_dataset, test_dataset, OUT_DIR, device, arch, checkpoint_type=eval_checkpoint_type)
    print(accs)