# Training Robust Models

This notebook will demonstrate how you can include adversarial transforms during training to train more robust and potentially more accurate models.

The basic process is very simple - we create our desired training loop as normal, and then add an adversarial transform to transform a batch of data before we train on it. This is very similar to standard data augmentation, but the adversarial optimisation finds more challenging transformations, resulting in greater improvements in robustness.

Install the toolbox

In [1]:
!pip install reetoolbox

Collecting reetoolbox
  Downloading reetoolbox-0.1.0.tar.gz (13 kB)
Collecting torchvision==0.2.1
  Downloading torchvision-0.2.1-py2.py3-none-any.whl (54 kB)
[K     |████████████████████████████████| 54 kB 1.5 MB/s 
Building wheels for collected packages: reetoolbox
  Building wheel for reetoolbox (setup.py) ... [?25l[?25hdone
  Created wheel for reetoolbox: filename=reetoolbox-0.1.0-py3-none-any.whl size=15761 sha256=4d31c56adbda9931acb7bd02800e037ee869a4602120f4c5572c0a7aadacd22c
  Stored in directory: /root/.cache/pip/wheels/7b/67/32/58fb18b077c661c8037f07368b00421439401f770d64b2ad81
Successfully built reetoolbox
Installing collected packages: torchvision, reetoolbox
  Attempting uninstall: torchvision
    Found existing installation: torchvision 0.11.1+cu111
    Uninstalling torchvision-0.11.1+cu111:
      Successfully uninstalled torchvision-0.11.1+cu111
Successfully installed reetoolbox-0.1.0 torchvision-0.2.1


If you're using Google Colab, you need to mount your drive using the cell below. Otherwise you can skip this.

In [3]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


You must set the variable PATH to the directory containing this file. It will probably look like the below path if you cloned the https://github.com/alexjfoote/reetoolbox-tutorials repo to your Google Drive.

In [4]:
PATH = "/content/drive/My Drive/reetoolbox-tutorials"

Import some useful functions

In [7]:
from reetoolbox.utils import load_resnet, load_pannuke, get_dataloader

Define the device we're using and the class names

In [8]:
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
classes = ["Negative", "Positive"]

Load the PanNuke dataset (see https://jgamper.github.io/PanNukeDataset/,  J. Gamper, N. A. Koohbanani, K. Benet, A. Khuram, and N. Rajpoot,
“PanNuke: An Open Pan-Cancer Histology Dataset for Nuclei Instance
Segmentation and Classification,” in Digital Pathology, pp. 11–19,
Springer, Cham, Apr. 2019.)

In [12]:
import os

data_path = os.path.join(PATH, "Data/breast_folds.npz")
Xtr, ytr, Xts, yts = load_pannuke(data_path)

Next we split our training data into a train and validation set and create a dictionary containing a training and validation data loader. We also create a dataloader from our test set.

In [13]:
# Create dataloaders for the training loop
batch_size = 16
val_batch_size = batch_size

# Number of epochs to train for
num_epochs = 15

length = int(0.7 * len(Xtr))

train_Xtr = Xtr[:length]
val_Xtr = Xtr[length:]
train_ytr = ytr[:length]
val_ytr = ytr[length:]

train_data = torch.utils.data.TensorDataset(train_Xtr, train_ytr)
val_data = torch.utils.data.TensorDataset(val_Xtr, val_ytr)

test_data = torch.utils.data.TensorDataset(Xts, yts)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, shuffle = False)

print("Initializing Datasets and Dataloaders...")

num_workers = 2

dataloaders_dict = {
    "train": torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers),
    "val": torch.utils.data.DataLoader(val_data, batch_size=val_batch_size, shuffle=True, num_workers=num_workers)
}

Initializing Datasets and Dataloaders...


We load a pretrained model that we will fine-tune on our dataset and replace the final layer to fit our application (2 output neurons). We then gather all the parameters to update and create an optimiser for use during training.

In [23]:
from torchvision import models
from torch import optim

def load_model(n_classes=2):
    model = models.resnet18(pretrained=True)  
    model.fc = nn.Sequential(nn.Linear(512, n_classes),        
                            nn.LogSoftmax(dim=1))
    model = model.to(device)
    model.train()  

    params_to_update = model.parameters()
    optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
    return model, optimizer_ft

Create a standard training loop, which also takes a function that takes a batch of inputs and returns a transformed batch, which we can use to wrap our adversarial transforms.

In [35]:
import copy
import time

def train_loop(model, dataloaders, criterion, optimizer, epochs, device="cuda:0", transform_func=None, **kwargs):
    acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    t_full = time.time()

    train_loader = dataloaders["train"]
    val_loader = dataloaders["val"]

    for epoch in range(epochs): 
        print(f'Epoch {epoch + 1}/{epochs}')
        print('-' * 10)   
        t_epoch = time.time() 

        num_examples = 0
        running_loss = 0.0
        running_corrects = 0

        model.train()

        for stage in dataloaders:            
            for batch_no, (inputs, labels) in enumerate(train_loader):
                inputs = inputs.to(device)
                labels = labels.to(device)  

                with torch.set_grad_enabled(True):
                    optimizer.zero_grad()    

                    if stage == "train":
                        model.train()

                        if transform_func is not None:
                            inputs = transform_func(model, inputs, labels, **kwargs)  

                        outputs = model(inputs)  
                        _, preds = torch.max(outputs, 1)

                        loss = criterion(outputs, labels)
                                            
                        loss.backward()
                        optimizer.step()
                    else:
                        model.eval()
                        outputs = model(inputs)  
                        _, preds = torch.max(outputs, 1)

                    num_examples += len(preds)
                    running_corrects += torch.sum(preds == labels.data)
                    running_loss += loss.item()

            epoch_loss = running_loss / len(train_loader)
            epoch_acc = running_corrects.double() / num_examples
            print(f'{stage}: Loss: {round(epoch_loss, 3)} Acc: {round(epoch_acc.item(), 3)}')

            if stage == "val" and epoch_acc >= best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print(f"{round(time.time() - t_epoch, 2)}s")  

        acc_history.append(epoch_acc)

    print(f"Took: {round(time.time() - t_full, 2)}s")
    print(f'Best val Acc: {best_acc}')
    model.load_state_dict(best_model_wts)
    return model, np.array(acc_history)

Next we import the desired transforms and corresponding optimisers and parameters, as well as the evaluator, and create a function that will measure the accuracy and robustness of a model to the stain transform.

In [37]:
from reetoolbox.transforms import StainTransform
from reetoolbox.optimisers import PGD
from reetoolbox.constants import (eval_stain_transform_params,
                            eval_stain_optimiser_params)
from reetoolbox.evaluator import Evaluator
from reetoolbox.metrics import get_metrics

def evaluate(model, test_dataset, test_dataloader):
    stain_evaluator = Evaluator(model, test_dataset, test_dataloader, PGD, 
                                StainTransform, eval_stain_optimiser_params, 
                                eval_stain_transform_params, device=device)
    results = stain_evaluator.predict(adversarial=True)
    get_metrics(results)

To start we train a model without any adversarial transforms to get a baseline.

In [38]:
from torch import nn

std_model, optimizer_ft = load_model(n_classes=2)
std_model, _ = train_loop(std_model, dataloaders_dict, nn.CrossEntropyLoss(), optimizer_ft, 
           10, device=device, transform_func=None)
evaluate(std_model, test_data, test_loader)

Epoch 1/10
----------
train: Loss: 0.517 Acc: 0.733
val: Loss: 0.675 Acc: 0.835
3.79s
Epoch 2/10
----------
train: Loss: 0.205 Acc: 0.918
val: Loss: 0.514 Acc: 0.918
3.74s
Epoch 3/10
----------
train: Loss: 0.073 Acc: 0.988
val: Loss: 0.268 Acc: 0.992
3.75s
Epoch 4/10
----------
train: Loss: 0.047 Acc: 0.989
val: Loss: 0.286 Acc: 0.98
3.73s
Epoch 5/10
----------
train: Loss: 0.045 Acc: 0.987
val: Loss: 0.271 Acc: 0.993
3.77s
Epoch 6/10
----------
train: Loss: 0.054 Acc: 0.986
val: Loss: 1.076 Acc: 0.949
3.76s
Epoch 7/10
----------
train: Loss: 0.038 Acc: 0.989
val: Loss: 0.053 Acc: 0.995
3.77s
Epoch 8/10
----------
train: Loss: 0.03 Acc: 0.991
val: Loss: 0.036 Acc: 0.996
3.77s
Epoch 9/10
----------
train: Loss: 0.013 Acc: 0.999
val: Loss: 0.078 Acc: 0.999
3.78s
Epoch 10/10
----------
train: Loss: 0.016 Acc: 0.997
val: Loss: 0.094 Acc: 0.998
3.76s
Took: 37.62s
Best val Acc: 0.999455930359086
Accuracy: 0.928, robust accuracy: 0.268, fooling ratio: 0.712


Next we import `apply_transforms` and the default parameters for creating a stain adversarial optimiser. 

We create a list of tuples containing the optimisers and their adversarial optimiser parameters. We then iterate over this list and use the parameters to set up each optimiser. 

We can give this list of adversarial optimisers to `apply_transforms` and it will sample `k` optimisers and sequentially transform the data using the sampled adversarial optimisers. 

We pass `apply_transforms`, `k`, and the list of adversarial optimisers to the training loop, which will use them to transform each batch of data.

In [42]:
from reetoolbox.trainer import apply_transforms
from reetoolbox.constants import stain_adv_opt_params

adv_model, optimizer_ft = load_model(n_classes=2)

optimizers_and_params = [(PGD, stain_adv_opt_params)]

all_adv_opts = []
for TransformOptimiser, adv_opt_params in optimizers_and_params:
    all_adv_opts.append(TransformOptimiser(adv_model, **adv_opt_params, device=device))

adv_model, _ = train_loop(adv_model, dataloaders_dict, nn.CrossEntropyLoss(), optimizer_ft, 
           epochs=10, device=device, transform_func=apply_transforms, k=1, 
           adv_optimisers=all_adv_opts)

evaluate(adv_model, test_data, test_loader)

Epoch 1/10
----------
train: Loss: 0.687 Acc: 0.606
val: Loss: 1.769 Acc: 0.739
9.41s
Epoch 2/10
----------
train: Loss: 0.489 Acc: 0.764
val: Loss: 0.872 Acc: 0.847
9.32s
Epoch 3/10
----------
train: Loss: 0.424 Acc: 0.81
val: Loss: 0.842 Acc: 0.86
9.34s
Epoch 4/10
----------
train: Loss: 0.331 Acc: 0.857
val: Loss: 0.49 Acc: 0.904
9.33s
Epoch 5/10
----------
train: Loss: 0.237 Acc: 0.911
val: Loss: 0.888 Acc: 0.943
9.36s
Epoch 6/10
----------
train: Loss: 0.215 Acc: 0.924
val: Loss: 1.177 Acc: 0.958
9.34s
Epoch 7/10
----------
train: Loss: 0.192 Acc: 0.923
val: Loss: 0.428 Acc: 0.958
9.4s
Epoch 8/10
----------
train: Loss: 0.117 Acc: 0.954
val: Loss: 0.591 Acc: 0.977
9.34s
Epoch 9/10
----------
train: Loss: 0.128 Acc: 0.946
val: Loss: 0.195 Acc: 0.971
9.34s
Epoch 10/10
----------
train: Loss: 0.073 Acc: 0.982
val: Loss: 0.268 Acc: 0.991
9.37s
Took: 93.55s
Best val Acc: 0.9907508161044614
Accuracy: 0.921, robust accuracy: 0.890, fooling ratio: 0.033


As you can see, including the adversarial stain transform improved the model's accuracy, and significantly improved it's robustness to the stain transform, compared to the model trained without the transform.

You can also use the built in implementation of adversarial training for free (see https://arxiv.org/abs/1904.12843), adapted for use with multiple adversarial transforms. This reorders the data, repeating each `m` times in a row and performing the optimisation over the repetitions for more efficient adversarial training. This is beneficial when you want to a more expensive adversarial optimisation, such as multi-step PGD - you can do one step per repetition of the batch and reach comparable robustness with lower training times.

You can also use the built in evaluation function, which will take a list of dictionaries that contain all the info needed to set up and run an evaluation, and perform all the evaluations.

In [47]:
from reetoolbox.trainer import train, evaluation

model_params = {
    "path": None,
    "load_saved": False,
    "pretrained": True,
    "n_classes": 2
}

train_params = {
    "dataloaders": dataloaders_dict,
    "criterion": nn.CrossEntropyLoss(),
    "initial_epochs": 0,
    "adv_epochs": 10,
    "m": 1,
    "last_n": 10,
    "k": 1
}

adv_free_model, optimizer_ft = load_model()
adv_free_model = train(adv_free_model, optimizer_ft, optimizers_and_params, 
                          train_params, device=device)

stain_evaluator_params = {
    "dataset": test_data,
    "dataloader": test_loader,
    "TransformOptimiser": PGD,
    "Transform": StainTransform,
    "optimiser_params": eval_stain_optimiser_params,
    "trans_params": eval_stain_transform_params
}

evaluator_params = [stain_evaluator_params]

evaluation(adv_free_model, evaluator_params, device=device)

Initial Epochs... 
Epoch 1/0
----------
Took: 0.0s
Best val Acc: 0.0
Adversarial Epochs...
Epoch 1/10
----------
Train: Loss: 0.717 Acc: 0.578
Val: Loss: 0.498 Acc: 0.777
8.6s
Epoch 2/10
----------
Train: Loss: 0.535 Acc: 0.754
Val: Loss: 0.455 Acc: 0.787
8.74s
Epoch 3/10
----------
Train: Loss: 0.385 Acc: 0.827
Val: Loss: 0.332 Acc: 0.838
8.7s
Epoch 4/10
----------
Train: Loss: 0.266 Acc: 0.881
Val: Loss: 0.376 Acc: 0.851
8.73s
Epoch 5/10
----------
Train: Loss: 0.243 Acc: 0.904
Val: Loss: 0.398 Acc: 0.835
8.73s
Epoch 6/10
----------
Train: Loss: 0.186 Acc: 0.928
Val: Loss: 0.437 Acc: 0.825
8.7s
Epoch 7/10
----------
Train: Loss: 0.14 Acc: 0.945
Val: Loss: 0.334 Acc: 0.873
8.69s
Epoch 8/10
----------
Train: Loss: 0.119 Acc: 0.952
Val: Loss: 0.425 Acc: 0.846
8.71s
Epoch 9/10
----------
Train: Loss: 0.106 Acc: 0.961
Val: Loss: 0.456 Acc: 0.835
8.73s
Epoch 10/10
----------
Train: Loss: 0.076 Acc: 0.976
Val: Loss: 0.476 Acc: 0.853
8.72s
Took: 87.19s
Best val Acc: 0.8734177215189873
<class