In [1]:
OUT_DIR = '/home/ubuntu/logs_russian_roulette_adversarial_training'
NUM_WORKERS = 16
BATCH_SIZE = 512

# Barebones starter example

### Imports

In [2]:
from robustness import model_utils, datasets, train, defaults
from robustness.datasets import CIFAR
import torch as ch

# We use cox (http://github.com/MadryLab/cox) to log, store and analyze
# results. Read more at https//cox.readthedocs.io.
from cox.utils import Parameters
import cox.store

### Make dataset and loaders

In [3]:
# Hard-coded dataset, architecture, batch size, workers
ds = CIFAR('/tmp/')
m, _ = model_utils.make_and_restore_model(arch='resnet18', dataset=ds)
train_loader, val_loader = ds.make_loaders(batch_size=BATCH_SIZE, workers=NUM_WORKERS)

==> Preparing dataset cifar..
Files already downloaded and verified
Files already downloaded and verified


### Make a cox store for logging

In [4]:
# Create a cox store for logging
store = store.Store(OUT_DIR, "CIFAR10 -- Russian Roulette Adversarial Training")
args_dict = args.as_dict() if isinstance(args, utils.Parameters) else vars(args)
schema = store.schema_from_dict(args_dict)
store.add_table('metadata', schema)
store['metadata'].append_row(args_dict)

Logging in: /home/ubuntu/logs_russian_roulette_adversarial_training/808a675c-0185-4c9f-b685-90a21b3eecb0


### Set up training arguments

In [5]:
# Hard-coded base parameters
train_kwargs = {
    'out_dir': "train_out",
    'adv_train': 1,
    'constraint': '2',
    'eps': 0.5,
    'attack_lr': 0.1,
    'attack_steps': 7,
    'epochs': 5,
    'log_iters'  : 10,
    'stop_probability' : 0.1
}
train_args = Parameters(train_kwargs)

# Fill whatever parameters are missing from the defaults
train_args = defaults.check_and_fill_args(train_args,
                        defaults.TRAINING_ARGS, CIFAR)
train_args = defaults.check_and_fill_args(train_args,
                        defaults.PGD_ARGS, CIFAR)

### Train Model

In [6]:
# Train a model
train.train_model(train_args, m, (train_loader, val_loader), store=out_store)
pass

Train Epoch:0 | Loss 1.9917 | AdvPrec1 29.470 | AdvPrec5 81.114 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:32<00:00,  3.03it/s]
Val Epoch:0 | Loss 1.6042 | NatPrec1 39.170 | NatPrec5 89.770 | Reg term: 0.0 ||: 100%|██████████| 20/20 [00:00<00:00, 20.51it/s]
Val Epoch:0 | Loss 2.1360 | AdvPrec1 20.560 | AdvPrec5 79.560 | Reg term: 0.0 ||: 100%|██████████| 20/20 [00:22<00:00,  1.14s/it]
Train Epoch:1 | Loss 1.5309 | AdvPrec1 43.442 | AdvPrec5 90.932 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:28<00:00,  3.49it/s]
Train Epoch:2 | Loss 1.2701 | AdvPrec1 53.878 | AdvPrec5 94.206 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:26<00:00,  3.72it/s]
Train Epoch:3 | Loss 1.0978 | AdvPrec1 60.698 | AdvPrec5 95.862 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:27<00:00,  3.60it/s]
Train Epoch:4 | Loss 0.9638 | AdvPrec1 65.784 | AdvPrec5 96.856 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:27<00:00,  3.61it/s]
Val Epoch:4 | Loss 0.8916 | NatPrec1 67.810 | NatPrec5 97.870 | Reg term: 0.0 ||

# Customizations

## Custom loss

In [7]:
train_crit = ch.nn.CrossEntropyLoss()
def custom_train_loss(logits, targ):
    probs = ch.ones_like(logits) * 0.5
    logits_to_multiply = ch.bernoulli(probs) * 9 + 1
    return train_crit(logits_to_multiply * logits, targ)

adv_crit = ch.nn.CrossEntropyLoss(reduction='none').cuda()
def custom_adv_loss(model, inp, targ):
    logits = model(inp)
    probs = ch.ones_like(logits) * 0.5
    logits_to_multiply = ch.bernoulli(probs) * 9 + 1
    new_logits = logits_to_multiply * logits
    return adv_crit(new_logits, targ), new_logits

train_args.custom_train_loss = custom_train_loss
train_args.custom_adv_loss = custom_adv_loss

In [8]:
train.train_model(train_args, m, (train_loader, val_loader), store=out_store)

Train Epoch:0 | Loss 2.9150 | AdvPrec1 12.864 | AdvPrec5 55.432 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:31<00:00,  3.13it/s]
Val Epoch:0 | Loss 2.2236 | NatPrec1 18.270 | NatPrec5 65.570 | Reg term: 0.0 ||: 100%|██████████| 20/20 [00:01<00:00, 19.45it/s]
Val Epoch:0 | Loss 2.2883 | AdvPrec1 15.960 | AdvPrec5 62.130 | Reg term: 0.0 ||: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it]
Train Epoch:1 | Loss 2.1842 | AdvPrec1 18.872 | AdvPrec5 73.456 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:30<00:00,  3.18it/s]
Train Epoch:2 | Loss 2.1177 | AdvPrec1 21.404 | AdvPrec5 78.918 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:29<00:00,  3.32it/s]
Train Epoch:3 | Loss 2.0731 | AdvPrec1 25.254 | AdvPrec5 81.684 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:33<00:00,  2.94it/s]
Train Epoch:4 | Loss 2.0079 | AdvPrec1 29.734 | AdvPrec5 84.780 | Reg term: 0.0 ||: 100%|██████████| 98/98 [00:30<00:00,  3.17it/s]
Val Epoch:4 | Loss 1.9488 | NatPrec1 33.730 | NatPrec5 86.850 | Reg term: 0.0 ||

DataParallel(
  (module): AttackerModel(
    (normalizer): InputNormalize()
    (model): ResNet(
      (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (layer1): SequentialWithArgs(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (shortcut): Sequential()
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     

## Custom per-iteration logging

In [None]:
CUSTOM_SCHEMA = {'iteration': int, 'weight_norm': float }
out_store.add_table('custom', CUSTOM_SCHEMA)

In [None]:
from torch.nn.utils import parameters_to_vector as flatten

def log_norm(mod, it, loop_type, inp, targ):
    if loop_type == 'train':
        curr_params = flatten(mod.parameters())
        log_info_custom = { 'iteration': it,
                    'weight_norm': ch.norm(curr_params).detach().cpu().numpy() }
        out_store['custom'].append_row(log_info_custom)
    
train_args.iteration_hook = log_norm

In [None]:
train.train_model(train_args, m, (train_loader, val_loader), store=out_store)
pass

## Custom architecture

In [None]:
from torch import nn
from robustness.model_utils import make_and_restore_model

class MLP(nn.Module):
    # Must implement the num_classes argument
    def __init__(self, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(32*32*3, 1000)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(1000, num_classes)

    def forward(self, x, *args, **kwargs):
        out = x.view(x.shape[0], -1)
        out = self.fc1(out)
        out = self.relu1(out)
        return self.fc2(out)

new_model = MLP(num_classes=10)

In [None]:
new_model, _ = make_and_restore_model(arch=new_model, dataset=ds)

In [None]:
train.train_model(train_args, new_model, (train_loader, val_loader), store=out_store)
pass