In [4]:
import os

# Root folder for the cox store; per the logging output below, each run is
# written to its own uniquely named sub-folder (e.g. /tmp/<uuid>).
OUT_DIR = '/tmp/'
# DataLoader worker processes: up to 16, but never more than the CPUs actually
# available -- extra workers only oversubscribe the machine and make PyTorch
# warn about an excessive worker count.
NUM_WORKERS = min(16, os.cpu_count() or 1)
# Mini-batch size shared by the training and validation loaders.
BATCH_SIZE = 128

# The following notebook is adapted from
# https://github.com/MadryLab/robustness/blob/master/notebooks/Using%20robustness%20as%20a%20library.ipynb


from torchvision.models.utils import load_state_dict_from_url


### Adversarial training

In [1]:
from robustness import model_utils, datasets, train, defaults
from robustness.datasets import FashionMnist#CIFAR#FashionMnist
from robustness import data_augmentation as da
import torch 
import torchvision.datasets

#import sys
#sys.path.append('/home/u21010246/mlpr/venv/lib/python3.8/site-packages/robustness')

# We use cox (http://github.com/MadryLab/cox) to log, store and analyze
# results. Read more at https://cox.readthedocs.io.
from cox.utils import Parameters
import cox.store




### Make dataset and loaders

In [6]:
# Download Fashion-MNIST into /tmp (download=True fetches it if missing) and
# apply the robustness default train/test transforms built for 32x32 inputs.
fm_train = torchvision.datasets.FashionMNIST(
    '/tmp', download=True, train=True,
    transform=da.TRAIN_TRANSFORMS_DEFAULT(32),
)
fm_val = torchvision.datasets.FashionMNIST(
    '/tmp', download=True, train=False,
    transform=da.TEST_TRANSFORMS_DEFAULT(32),
)

# Only the training loader is shuffled. Validation order does not affect the
# reported accuracy, and a fixed order keeps evaluation passes reproducible.
train_loader = torch.utils.data.DataLoader(
    fm_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = torch.utils.data.DataLoader(
    fm_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# robustness-side dataset descriptor used to build the model.
# NOTE(review): FashionMnist is not a stock robustness dataset -- presumably a
# local addition that reads the raw files downloaded above; TODO confirm.
ds = FashionMnist('/tmp/FashionMNIST/raw')  # alternative: CIFAR('/tmp')
# Second return value of make_and_restore_model is unused here.
m, _ = model_utils.make_and_restore_model(arch='resnet18', dataset=ds)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to /tmp/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting /tmp/FashionMNIST/raw/train-images-idx3-ubyte.gz to /tmp/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to /tmp/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting /tmp/FashionMNIST/raw/train-labels-idx1-ubyte.gz to /tmp/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to /tmp/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting /tmp/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to /tmp/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting /tmp/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/FashionMNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


### Make a cox store for logging

In [7]:
# Create a cox store for logging training metrics; it is passed to
# train.train_model below. As the cell output shows, cox opens a new,
# uniquely named experiment folder under OUT_DIR (e.g. /tmp/<uuid>).
out_store = cox.store.Store(OUT_DIR)

Logging in: /tmp/ea313207-b970-4a7e-877b-cd77e5e79a24


### Set up training arguments

In [8]:
# Explicit hyper-parameters for adversarial training. constraint='2' is,
# presumably, the L2 threat model in robustness (eps = ball radius, attack_lr
# and attack_steps configure the attack). Everything else comes from defaults.
train_kwargs = dict(
    out_dir="train_out",
    adv_train=1,
    constraint='2',
    eps=0.5,
    attack_lr=0.1,
    attack_steps=7,
    epochs=120,
)

# Wrap the dict in a cox Parameters object, then complete any missing
# settings: generic training defaults first, PGD attack defaults second.
train_args = Parameters(train_kwargs)
train_args = defaults.check_and_fill_args(
    train_args, defaults.TRAINING_ARGS, FashionMnist
)
train_args = defaults.check_and_fill_args(
    train_args, defaults.PGD_ARGS, FashionMnist
)

### Train Model

In [9]:
# Run adversarial training on (train, val) loaders, logging to out_store.
# The trailing semicolon keeps Jupyter from echoing the call's return value.
train.train_model(
    train_args,
    m,
    (train_loader, val_loader),
    store=out_store,
);

Train Epoch:0 | Loss 1.3718 | AdvPrec1 56.040 | AdvPrec5 94.698 | Reg term: 0.0 
Val Epoch:0 | Loss 0.7390 | NatPrec1 69.720 | NatPrec5 99.390 | Reg term: 0.0 ||
Val Epoch:0 | Loss 0.9120 | AdvPrec1 63.160 | AdvPrec5 99.090 | Reg term: 0.0 ||
Train Epoch:1 | Loss 0.7063 | AdvPrec1 71.453 | AdvPrec5 99.390 | Reg term: 0.0 
Train Epoch:2 | Loss 0.6158 | AdvPrec1 75.137 | AdvPrec5 99.553 | Reg term: 0.0 
Train Epoch:3 | Loss 0.5807 | AdvPrec1 76.500 | AdvPrec5 99.620 | Reg term: 0.0 
Train Epoch:4 | Loss 0.5471 | AdvPrec1 77.778 | AdvPrec5 99.685 | Reg term: 0.0 
Train Epoch:5 | Loss 0.5234 | AdvPrec1 78.810 | AdvPrec5 99.748 | Reg term: 0.0 
Val Epoch:5 | Loss 0.4089 | NatPrec1 84.060 | NatPrec5 99.830 | Reg term: 0.0 ||
Val Epoch:5 | Loss 0.6014 | AdvPrec1 76.110 | AdvPrec5 99.680 | Reg term: 0.0 ||
Train Epoch:6 | Loss 0.5045 | AdvPrec1 79.522 | AdvPrec5 99.752 | Reg term: 0.0 
Train Epoch:7 | Loss 0.4925 | AdvPrec1 79.947 | AdvPrec5 99.788 | Reg term: 0.0 
Train Epoch:8 | Loss 0.4797 