# Training Robust Models

This notebook will demonstrate how you can include adversarial transforms during training to train more robust and potentially more accurate models.

The basic process is simple: we create our training loop as normal, then apply an adversarial transform to each batch of data before we train on it. This is similar to standard data augmentation, but the adversarial optimisation searches for more challenging transformations, which results in greater improvements in robustness.
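
To make "adversarial optimisation" concrete, here is a minimal, dependency-free sketch of projected gradient ascent over a single transform parameter. It is a toy under stated assumptions and not the implementation used by `reet`: the names `pgd_maximise` and `toy_loss`, the finite-difference gradient, and the one-dimensional parameter are all made up for illustration.

```python
def pgd_maximise(loss_fn, theta0, epsilon, step_size, n_steps, h=1e-4):
    """Toy projected gradient ascent on a single scalar parameter.

    Hypothetical helper for illustration only: it searches for the value of
    theta within [theta0 - epsilon, theta0 + epsilon] that maximises loss_fn.
    """
    theta = theta0
    for _ in range(n_steps):
        # Central finite-difference estimate of d(loss)/d(theta)
        grad = (loss_fn(theta + h) - loss_fn(theta - h)) / (2 * h)
        # Ascent step in the direction of the gradient's sign
        if grad > 0:
            theta += step_size
        elif grad < 0:
            theta -= step_size
        # Projection: clip back into the allowed range around theta0
        theta = max(theta0 - epsilon, min(theta0 + epsilon, theta))
    return theta


def toy_loss(theta):
    # Stand-in for "model loss as a function of a transform parameter";
    # it grows as theta moves away from 0.5
    return (theta - 0.5) ** 2


# Starting at 0.0 the loss increases as theta decreases, so the search
# ends at the lower edge of the allowed range.
worst_theta = pgd_maximise(toy_loss, theta0=0.0, epsilon=0.2,
                           step_size=0.05, n_steps=10)
```

In adversarial training, the batch transformed with the parameter found this way is what the model is then trained on.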

If you haven't already, make a copy of this Tutorials directory and put it in the directory you want to work in.

You must set the variable PATH to the directory containing this file.

In [None]:
PATH = ""

Import some useful functions

In [None]:
import os

import torch

from reet.utils import load_resnet, load_pannuke, get_dataloader

Define the device we're using and the class names

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
classes = ["Negative", "Positive"]

Load the PanNuke dataset (see https://jgamper.github.io/PanNukeDataset/ and J. Gamper, N. A. Koohbanani, K. Benet, A. Khuram, and N. Rajpoot, “PanNuke: An Open Pan-Cancer Histology Dataset for Nuclei Instance Segmentation and Classification,” in Digital Pathology, pp. 11–19, Springer, Cham, Apr. 2019).

In [None]:
data_path = os.path.join(PATH, "Data/breast_folds.npz")
Xtr, ytr, Xts, yts = load_pannuke(data_path)

Next we split our training data into a train and validation set and create a dictionary containing a training and validation data loader. We also create a dataloader from our test set.

In [None]:
# Create dataloaders for the training loop
batch_size = 16
val_batch_size = batch_size

# Number of epochs to train for
num_epochs = 15

length = int(0.7 * len(Xtr))

train_Xtr = Xtr[:length]
val_Xtr = Xtr[length:]
train_ytr = ytr[:length]
val_ytr = ytr[length:]

train_data = torch.utils.data.TensorDataset(train_Xtr, train_ytr)
val_data = torch.utils.data.TensorDataset(val_Xtr, val_ytr)

test_data = torch.utils.data.TensorDataset(Xts, yts)
test_loader = torch.utils.data.DataLoader(test_data, batch_size = batch_size, shuffle = False)

print("Initializing Datasets and Dataloaders...")

num_workers = 2

dataloaders_dict = {
    "train": torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers),
    "val": torch.utils.data.DataLoader(val_data, batch_size=val_batch_size, shuffle=True, num_workers=num_workers)
}
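
The split above takes the first 70% of the samples in their stored order. If that order is not already random (an assumption, since it depends on how the folds were saved), shuffling the indices first is a common alternative. The sketch below uses only the standard library; `split_indices` is a hypothetical helper, not part of `reet`.

```python
import random


def split_indices(n_samples, train_fraction=0.7, seed=0):
    """Return shuffled (train, val) index lists; hypothetical helper."""
    indices = list(range(n_samples))
    # A seeded generator keeps the split reproducible between runs
    random.Random(seed).shuffle(indices)
    cut = int(train_fraction * n_samples)
    return indices[:cut], indices[cut:]


train_idx, val_idx = split_indices(10)
```

With tensors such as `Xtr` and `ytr`, indexing with these lists (for example `Xtr[train_idx]`) would select the corresponding samples.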

We load a pretrained model that we will fine-tune on our dataset and replace the final layer to fit our application (2 output neurons). We then gather all the parameters to update and create an optimiser for use during training.

In [None]:
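from torch import nn, optim
from torchvision import models

def load_model():
    # Start from an ImageNet-pretrained ResNet-18
    model = models.resnet18(pretrained=True)
    # Replace the final layer with one output per class
    model.fc = nn.Sequential(nn.Linear(512, len(classes)),
                             nn.LogSoftmax(dim=1))
    model = model.to(device)
    model.train()

    # Fine-tune every parameter in the network
    params_to_update = model.parameters()
    optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
    return model, optimizer_ft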

Create a standard training loop that also accepts a function which takes a batch of inputs and returns a transformed batch. We can use this function to wrap our adversarial transforms.

In [None]:
import copy
import time

import numpy as np

def train_loop(model, dataloaders, criterion, optimizer, device="cuda:0",
               transform_func=None, epochs=num_epochs, **kwargs):
    # Extra keyword arguments (**kwargs) are forwarded to transform_func
    acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    t_full = time.time()

    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')
        print('-' * 10)
        t_epoch = time.time()

        epoch_accs = {}

        for stage in ("train", "val"):
            # Reset the running statistics for each stage
            num_examples = 0
            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[stage]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if stage == "train":
                    model.train()

                    if transform_func is not None:
                        # Transform the batch before training on it
                        inputs = transform_func(model, inputs, labels, **kwargs)

                    # Clear any gradients left over from optimising the transform
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    optimizer.step()
                else:
                    model.eval()
                    # No gradients are needed for validation
                    with torch.no_grad():
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                _, preds = torch.max(outputs, 1)
                num_examples += len(preds)
                running_corrects += torch.sum(preds == labels).item()
                # Weight the mean batch loss by the batch size
                running_loss += loss.item() * len(preds)

            # Report the statistics once per stage, not once per batch
            epoch_loss = running_loss / num_examples
            epoch_accs[stage] = running_corrects / num_examples
            print(f"{stage} loss: {epoch_loss:.4f}, acc: {epoch_accs[stage]:.4f}")

        print(f"{round(time.time() - t_epoch, 2)}s")

        # Keep a copy of the weights with the best validation accuracy so far
        if epoch_accs["val"] >= best_acc:
            best_acc = epoch_accs["val"]
            best_model_wts = copy.deepcopy(model.state_dict())

        acc_history.append([epoch_accs["train"], epoch_accs["val"]])

    print(f"Took: {round(time.time() - t_full, 2)}s")
    print(f'Best val Acc: {best_acc}')
    model.load_state_dict(best_model_wts)
    return model, np.array(acc_history)

Next we import the desired transforms and corresponding optimisers and parameters, as well as the evaluator, and create a function that will measure the accuracy and robustness of a model to the stain transform.

In [None]:
from reet.transforms import StainTransform
from reet.optimisers import PGD
from reet.constants import (eval_stain_transform_params,
                            eval_stain_optimiser_params)
from reet.evaluator import Evaluator

def evaluate(model, test_dataset, test_dataloader):
    stain_evaluator = Evaluator(model, test_dataset, test_dataloader, PGD, 
                                StainTransform, eval_stain_optimiser_params, 
                                eval_stain_transform_params, device=device)
    results = stain_evaluator.predict(adversarial=True)
    get_metrics(results)
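
As a rough illustration of what "accuracy and robustness" can mean here, the sketch below compares predictions on clean inputs with predictions on adversarially transformed inputs. It is an assumption about the kind of summary involved, not a description of what `Evaluator` or `get_metrics` actually compute; `summarise_robustness` and its output keys are hypothetical.

```python
def summarise_robustness(labels, clean_preds, adv_preds):
    """Hypothetical summary of clean vs. adversarial performance."""
    n = len(labels)
    clean_correct = [p == y for p, y in zip(clean_preds, labels)]
    adv_correct = [p == y for p, y in zip(adv_preds, labels)]
    # Samples classified correctly on clean inputs but not after the attack
    broken = sum(c and not a for c, a in zip(clean_correct, adv_correct))
    return {
        "clean_accuracy": sum(clean_correct) / n,
        "adversarial_accuracy": sum(adv_correct) / n,
        "fraction_broken": broken / n,
    }


summary = summarise_robustness(
    labels=[0, 1, 1, 0],
    clean_preds=[0, 1, 0, 0],
    adv_preds=[0, 0, 0, 0],
)
```

A robust model keeps its adversarial accuracy close to its clean accuracy, so the fraction of broken samples stays small.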

To start we train a model without any adversarial transforms to get a baseline.

In [None]:
from torch import nn

std_model, optimizer_ft = load_model()
std_model, _ = train_loop(std_model, dataloaders_dict, nn.CrossEntropyLoss(), optimizer_ft, 
           device=device, transform_func=None)
evaluate(std_model, test_data, test_loader)

Next we import `apply_transforms` and the default parameters for creating a stain adversarial optimiser. 

We create a list of tuples containing the optimisers and their adversarial optimiser parameters. We then iterate over this list and use the parameters to set up each optimiser. 

We can give this list of adversarial optimisers to `apply_transforms` and it will sample `k` optimisers and sequentially transform the data using the sampled adversarial optimisers. 
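
The "sample `k`, then apply in sequence" idea can be sketched without any dependencies. This is only an illustration: `apply_sampled` and the toy transforms are hypothetical, and sampling without replacement is an assumption rather than a statement about how `apply_transforms` works internally.

```python
import random


def apply_sampled(transforms, inputs, k, rng=random):
    """Sample k transforms without replacement and apply them in turn."""
    chosen = rng.sample(transforms, k)
    for transform in chosen:
        inputs = transform(inputs)
    return inputs


# Two toy "transforms" acting on a list of numbers
def brighten(batch):
    return [x + 1 for x in batch]


def scale(batch):
    return [x * 2 for x in batch]


out = apply_sampled([brighten, scale], [1, 2, 3], k=1, rng=random.Random(0))
```

With `k=1` exactly one of the two transforms is applied to the batch; with `k=2` both are applied, in a randomly chosen order.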

We pass `apply_transforms`, `k`, and the list of adversarial optimisers to the training loop, which will use them to transform each batch of data.

In [None]:
from reet.trainer import apply_transforms
from reet.constants import stain_adv_opt_params

adv_model, optimizer_ft = load_model()

optimizers_and_params = [(PGD, stain_adv_opt_params)]

all_adv_opts = []
for TransformOptimiser, adv_opt_params in optimizers_and_params:
    all_adv_opts.append(TransformOptimiser(adv_model, **adv_opt_params, device=device))

adv_model, _ = train_loop(adv_model, dataloaders_dict, nn.CrossEntropyLoss(), optimizer_ft, 
           device=device, transform_func=apply_transforms, k=1, 
           adv_optimisers=all_adv_opts)

evaluate(adv_model, test_data, test_loader)

As you can see, including the adversarial stain transform improved the model's accuracy and significantly improved its robustness to the stain transform, compared to the model trained without the transform.

You can also use the built-in implementation of adversarial training for free (see https://arxiv.org/abs/1904.12843), adapted for use with multiple adversarial transforms. This reorders the data, repeating each batch `m` times in a row and performing the optimisation over the repetitions for more efficient adversarial training. This is beneficial when you want to use a more expensive adversarial optimisation, such as multi-step PGD: you can do one step per repetition of the batch and reach comparable robustness with lower training times.
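
The batch reordering can be pictured with a small generator. This sketch shows only the replay schedule, under the assumption that one optimisation step is taken per repetition; `replay_batches` is a hypothetical name and not the `reet` implementation.

```python
def replay_batches(batches, m):
    """Yield each batch m times in a row, with its repetition index."""
    for batch in batches:
        for repeat in range(m):
            yield repeat, batch


schedule = list(replay_batches(["batch_a", "batch_b"], m=3))
```

Each batch is visited three times before the next one is loaded, so an adversarial transform can be refined a little on every visit while the model is updated at the same time.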

You can also use the built-in evaluation function, which takes a list of dictionaries, each containing the information needed to set up and run one evaluation, and performs all of the evaluations.

In [None]:
from reet.trainer import train, evaluation

model_params = {
    "path": None,
    "load_saved": False,
    "pretrained": True,
    "n_classes": 2
}

train_params = {
    "dataloaders": dataloaders_dict,
    "criterion": nn.CrossEntropyLoss(),
    "initial_epochs": 0,
    "adv_epochs": 30,
    "m": 5,
    "last_n": 30,
    "k": 1
}

adv_free_model, optimizer_ft = load_model()
adv_free_model, _ = train(adv_free_model, optimizer_ft, optimizers_and_params, 
                          train_params, device=device)

stain_evaluator_params = {
    "dataset": test_data,
    "dataloader": test_loader,
    "TransformOptimiser": PGD,
    "Transform": StainTransform,
    "optimiser_params": eval_stain_optimiser_params,
    "trans_params": eval_stain_transform_params
}

evaluator_params = [stain_evaluator_params]

evaluation(adv_free_model, evaluator_params, device=device)