# Setup

### Import dependencies

In [1]:
import os
os.environ["NOTEBOOK_MODE"] = "1" # set before the robustness import below; presumably read by the library at import time -- TODO confirm

from robustness import model_utils, datasets, train, defaults
from robustness.datasets import CIFAR

# We use cox (https://github.com/MadryLab/cox) to log, store and analyze
# results. Read more at https://cox.readthedocs.io.
from cox.utils import Parameters
import cox.store

import torch
# Last expression of the cell: displays the name of CUDA device 0, which
# doubles as a quick check that a GPU is visible to this kernel.
torch.cuda.get_device_name(0)



'TITAN RTX'

### Download models
Pretrained checkpoints are taken from https://github.com/MadryLab/robustness.
Two models are downloaded:
- CIFAR10 Linf-norm (ResNet50), ε = 0     (natural training)
- CIFAR10 Linf-norm (ResNet50), ε = 8/255

In [2]:
# Create the checkpoint directory (-p: no error if it already exists).
!mkdir -p data/models
# Fetch the two checkpoints quietly (-q) into fixed local file names (-O).
# NOTE(review): the "?dl=1" query presumably asks Dropbox for the raw file
# rather than the HTML preview page -- verify if a download looks corrupt.
# Adversarially trained model (L-inf, eps = 8/255):
!wget -q -O data/models/cifar_linf_8.pt "https://www.dropbox.com/s/c9qlt1lbdnu9tlo/cifar_linf_8.pt?dl=1"
# Naturally trained model (eps = 0):
!wget -q -O data/models/cifar_nat.pt    "https://www.dropbox.com/s/yhpp4yws7sgi6lj/cifar_nat.pt?dl=1"

### Load dataset and models

In [3]:
# CIFAR dataset wrapper rooted at ./data, shared by every model below.
DS = CIFAR("data")

# make_loaders yields the (train, validation) loader pair.
TRAIN_LOADER, VAL_LOADER = DS.make_loaders(workers=12, batch_size=128)

==> Preparing dataset cifar..
Files already downloaded and verified
Files already downloaded and verified


In [4]:
def _restore_cifar_resnet50(checkpoint_path):
    """Build a ResNet50 for ``DS`` and restore its weights from ``checkpoint_path``.

    Args:
        checkpoint_path: Path to a checkpoint file, passed through as
            ``resume_path`` to ``model_utils.make_and_restore_model``.

    Returns:
        The restored model. The second value returned by
        ``make_and_restore_model`` is discarded.
    """
    model, _unused = model_utils.make_and_restore_model(
        arch="resnet50",
        resume_path=checkpoint_path,
        dataset=DS,
    )
    return model


# Reference models, switched to eval mode because they are only evaluated.
M_NAT = _restore_cifar_resnet50("data/models/cifar_nat.pt")
M_NAT.eval()

M_ADV = _restore_cifar_resnet50("data/models/cifar_linf_8.pt")
M_ADV.eval()

# Independent copy of the naturally trained model; this is the one that gets
# fine-tuned, so M_NAT stays untouched as a baseline.
m_finetune = _restore_cifar_resnet50("data/models/cifar_nat.pt")

# Only the final linear layer is handed to the optimizer. Materialize the
# parameters as a list: .parameters() returns a one-shot generator, which
# would be exhausted after the first training run and make any re-run of the
# training cell fail with an empty parameter list.
m_finetune_params = list(m_finetune.model.linear.parameters())

=> loading checkpoint 'data/models/cifar_nat.pt'
=> loaded checkpoint 'data/models/cifar_nat.pt' (epoch 190)
=> loading checkpoint 'data/models/cifar_linf_8.pt'
=> loaded checkpoint 'data/models/cifar_linf_8.pt' (epoch 153)
=> loading checkpoint 'data/models/cifar_nat.pt'
=> loaded checkpoint 'data/models/cifar_nat.pt' (epoch 190)


### Set up Cox logging

In [5]:
# Open a cox store under ./cox; each instantiation logs into a fresh,
# uniquely named sub-directory (see the "Logging in: ..." line it prints).
# The same store is passed to every eval/train call below.
COX_STORE = cox.store.Store("cox")

Logging in: /fs/data/ttw/code/adversarial-ntks/notebooks/discard-hypothesis/cox/c38c4c9e-7238-4b7f-b1f1-df320a66a88f


# Baseline accuracies

In [6]:
# Evaluation settings: report natural and adversarial accuracy under an
# L-inf PGD attack. Anything not listed here is filled in from the
# library's PGD defaults for CIFAR.
EVAL_ARGS = defaults.check_and_fill_args(
    Parameters(dict(
        adv_eval=1,
        out_dir="eval_out",
        constraint='inf',                # L-inf PGD
        eps=8.0 / 255.0,                 # perturbation budget (L-inf norm)
        attack_lr=2.5 * 8 / 255 / 20,    # step size = 2.5 * eps / attack_steps
        attack_steps=20,
    )),
    defaults.PGD_ARGS,
    CIFAR,
)

In [7]:
# Baseline: naturally trained model on the validation set
# (natural + PGD accuracy, per EVAL_ARGS).
train.eval_model(args=EVAL_ARGS, model=M_NAT, loader=VAL_LOADER, store=COX_STORE)

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

{'epoch': 0,
 'nat_prec1': tensor(95.2500, device='cuda:0'),
 'adv_prec1': tensor(0., device='cuda:0'),
 'nat_loss': 0.19557516660168767,
 'adv_loss': 26.358350354003907,
 'train_prec1': nan,
 'train_loss': nan,
 'time': 158.10954117774963}

In [8]:
# Baseline: adversarially trained model on the validation set
# (natural + PGD accuracy, per EVAL_ARGS).
train.eval_model(args=EVAL_ARGS, model=M_ADV, loader=VAL_LOADER, store=COX_STORE)

  0%|          | 0/79 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

{'epoch': 0,
 'nat_prec1': tensor(87.0300, device='cuda:0'),
 'adv_prec1': tensor(53.5000, device='cuda:0'),
 'nat_loss': 0.43732129278182985,
 'adv_loss': 1.3028398469924927,
 'train_prec1': nan,
 'train_loss': nan,
 'time': 164.59000420570374}

# Adversarial finetuning 

In [9]:
# Explicit overrides for adversarial fine-tuning; the attack settings match
# the ones used for evaluation.
TRAIN_ARGS = Parameters(dict(
    out_dir="train_out",
    adv_train=1,                     # use adversarial training
    constraint='inf',                # L-inf PGD
    eps=8.0 / 255.0,                 # perturbation budget (L-inf norm)
    attack_lr=2.5 * 8 / 255 / 20,    # step size = 2.5 * eps / attack_steps
    attack_steps=20,
))

# Everything not set above comes from the library defaults for CIFAR:
# first the generic training arguments, then the PGD attack arguments.
TRAIN_ARGS = defaults.check_and_fill_args(
    args=TRAIN_ARGS, arg_list=defaults.TRAINING_ARGS, ds_class=CIFAR
)
TRAIN_ARGS = defaults.check_and_fill_args(
    args=TRAIN_ARGS, arg_list=defaults.PGD_ARGS, ds_class=CIFAR
)

In [None]:
# Train a model
train.train_model(
    model=m_finetune,
    update_params=m_finetune_params,
    args=TRAIN_ARGS,
    loaders=(TRAIN_LOADER, VAL_LOADER),
    store=COX_STORE
)