### This is the code reproducing the main experiments found in Section 6 of the submission.

#### PyTorch Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms


#### External Libraries

In [None]:
import torchattacks # Library implementing popular adversarial attack methods, e.g. PGD (Madry et al.)
import pandas as pd
import numpy as np
import cvxpy as cp
from pathlib import Path # For creating directories to save results
import os # For setting environment variables (MOSEK license)

#### Functions from within this repository

In [None]:
from pt_models.resnet import *
from helper_functions import *
from HR import * # The main file implementing HR and returning a PyTorch loss function

#### What you will need to practically run the code

In [None]:
# Location of the MOSEK license file, read by MOSEK through this environment variable
os.environ['MOSEKLM_LICENSE_FILE']="mosek.lic" # Easily obtained via https://www.mosek.com/products/academic-licenses/
# NOTE(review): the next line is a bare comparison whose result is discarded, so it neither sets nor
# checks the CUDA version. Presumably it only documents the CUDA version the experiments used - confirm.
torch.version.cuda == '11.3' # Change according to your PyTorch version (Cuda 11.3 is compatible with the latest PyTorch)
device = 'cuda' if torch.cuda.is_available() else 'cpu' # Use a GPU if one is available, otherwise a CPU

#### CIFAR-10 Pre-processing

In [None]:
# Three-element mean/std lists passed to transforms.Normalize in the train and test pipelines
normalisation_mean = [0.4914, 0.4822, 0.4465] # Commonly used normalization for CIFAR-10
normalisation_std = [0.2023, 0.1994, 0.2010] # Normalize subtracts the mean and divides by this std; the [0, 1] scaling itself is done by ToTensor

In [None]:
# Data
print('==> Preparing data..')

# Training pipeline: random augmentation first, then tensor conversion and normalisation
train_augmentations = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
]
train_tensor_steps = [
    transforms.ToTensor(),
    transforms.Normalize(normalisation_mean, normalisation_std),
]
transform_train = transforms.Compose(train_augmentations + train_tensor_steps)

In [None]:
# Test pipeline: no augmentation, only tensor conversion and normalisation
test_steps = [
    transforms.ToTensor(),
    transforms.Normalize(normalisation_mean, normalisation_std),
]
transform_test = transforms.Compose(test_steps)

In [None]:
train_batch_size = 128 # Commonly used CIFAR-10 training batch size
test_batch_size = 100 # Testing batch size. Again commonly used.
# The number of complete training batches: floor(50000/128) = 390
train_batches = 50000 // train_batch_size

In [None]:
# Training split of CIFAR-10 stored under ./data (download requested), with augmentation applied
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)

# shuffle=False keeps the batch order fixed: the training and validation loops select
# batches by their index in this loader, so an index must always refer to the same images
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=train_batch_size, shuffle=False, num_workers=20)

# Test split; download=False, so it presumably relies on the files fetched for the training split
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=False, transform=transform_test)

# Test batches are served in a fixed order as well
testloader = torch.utils.data.DataLoader(
    testset, batch_size=test_batch_size, shuffle=False, num_workers=1)

In [None]:
# Label names for the ten CIFAR-10 classes
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Number of classes, derived from the label tuple
num_classes = len(classes)

#### HR Specifications

In [2]:
# Robustness parameters handed to HR_Neural_Networks.
# Trialled values for each of the three were [0, 0.05, 0.1, 0.2],
# which gives 4x4x4 = 64 models in total
α_choice, r_choice, ϵ_choice = 0.1, 0.1, 0.1

In [None]:
# Robustness specifications
# Tag describing the chosen (alpha, r, eps) triple; it is embedded in the model and result directory paths
model_name = f"alpha = {α_choice}, r = {r_choice}, eps = {ϵ_choice}"

#### Poisoning/Corruption sources

In [None]:
# Error sources in the data
frac = 0.25 # What fraction of the training data should we consider as a finite random sample?
mis_specification = 0.1 # What fraction of the training labels are misspecified?
noise = 0.1 # How much adversarial noise is in the data? Used as the eps budget of the PGDL2 attack at validation/test time
gaussian_noise = 0.5 # How much gaussian noise is in the data? Multiplies standard-normal samples added to the inputs (Note that noise and gaussian noise are considered separately, not applied together)

# Final training 

In [None]:
for iter in np.arange(0, 10): # Run 10 iterations using different samples

    # Seed will be the iteration number throughout, so each run can be repeated exactly
    torch.manual_seed(iter)
    torch.cuda.manual_seed_all(iter)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True # Restrict cuDNN to deterministic algorithms
    # Benchmark mode auto-tunes the convolution algorithm and may select a different one from
    # run to run, which works against reproducibility, so it is switched off.
    torch.backends.cudnn.benchmark = False

    # Data & Training Specifications
    epochs = 300
    lr = 0.01
    resume = False

    start_epoch = 0 # start from epoch 0 or last checkpoint epoch

    # Implementing the error sources.
    # Statistical error - sample less than the full dataset

    # Function to corrupt the data. That is, to sample less than the full data size and corrupt some of the labels
    # See the paper for more details.
    sampled_batches, data_points_to_corrupt, unique_labels = return_batches_to_corrupt(iter,
                                                                                       train_batches,
                                                                                       train_batch_size,
                                                                                       frac,
                                                                                       mis_specification,
                                                                                       num_classes)

    # Splitting the training batches into training and validation (70/30 split)
    # NumPy is seeded separately because np.random.choice does not use the torch generator
    np.random.seed(iter)
    training_batches = np.random.choice(sampled_batches, size = int(0.7*len(sampled_batches)), replace = False)
    # Every sampled batch that was not drawn for training is held out for validation
    validation_batches = [i for i in sampled_batches if i not in training_batches]

    # Initialising the NN
    net = ResNet18(iter)
    net = net.to(device)

    if device == 'cuda':
        net = torch.nn.DataParallel(net)

    ########### TRAINING ##################

    # Loss function and optimizer
    # reduction="none" keeps one loss value per example; the sums/averages are taken explicitly later
    criterion = nn.CrossEntropyLoss(reduction="none")
    optimizer = optim.Adam(net.parameters(), lr=lr)

    # Initialising Holistic Robustness
    HR = HR_Neural_Networks(NN_model = net,
                        learning_approach = "HD",
                        train_batch_size = train_batch_size,
                        loss_fn = criterion,
                        normalisation_used = [normalisation_mean, normalisation_std],
                        α_choice = α_choice,
                        r_choice = r_choice,
                        ϵ_choice = ϵ_choice,
                        adversarial_steps=10,
                        adversarial_step_size=0.2
                        )

    def train(epoch):
        """Run one training epoch over the sub-sampled training batches.

        Only batches of ``trainloader`` whose index is in ``training_batches`` are used.
        Their labels are passed through ``corrupt_targets`` before the optimisation step.

        Args:
            epoch: Index of the current epoch; only used for the progress message.

        Returns:
            A tuple ``(mean unweighted loss, training accuracy, mean HR loss)``
            computed over the batches used in this epoch.
        """
        print('\nEpoch: %d' % epoch)
        net.train()
        train_loss, HR_losses = [], []
        correct, total = 0, 0

        for batch_idx, (inputs, targets) in enumerate(trainloader):

            if batch_idx not in training_batches: # Implements subsampling
                continue

            targets = corrupt_targets(batch_idx,
                                      targets,
                                      training_batches,
                                      train_batch_size,
                                      train_batches,
                                      data_points_to_corrupt,
                                      unique_labels) # Corrupting labels

            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad() # Clearing the gradient from the previous step

            if (α_choice != 0 or r_choice != 0 or ϵ_choice != 0):
                HR_loss = HR.HR_criterion(inputs, targets, device) # Resolving to find the new HR loss function
            else:
                # If no robustness is specified, just train as normal.
                outputs = net(inputs)
                HR_loss = torch.sum(criterion(outputs, targets))/train_batch_size

            # Backprop w.r.t HR loss
            HR_loss.backward()
            HR_losses.append(HR_loss.cpu().detach().numpy()) # Now saving it

            # We also collect regular, unweighted loss. This forward pass is for reporting only
            # and is never back-propagated, so no autograd graph is built for it.
            with torch.no_grad():
                outputs = net(inputs)
                loss = torch.sum(criterion(outputs, targets))/train_batch_size
            train_loss.append(loss.item())

            optimizer.step()

            _, predicted = outputs.max(1) # \hat{y} = \argmax{\hat{p}}
            total += targets.size(0) # Increment the total
            correct += predicted.eq(targets).sum().item() # Increment correct tally if \hat{y} is correct

        training_accuracy = correct/total

        # Collect unweighted loss, training accuracy and weighted HR loss.
        return np.mean(train_loss), training_accuracy, np.mean(HR_losses)


    ########### VALIDATION ##################

    # Initialise validation/testing adversary
    # PGDL2 attack from torchattacks against the current network: budget eps=noise, step size 0.2,
    # 10 steps, random start. The same attack object is used by validate() and test().
    adversarial_attack_test = torchattacks.PGDL2(net,
                                                 eps=noise,
                                                 alpha=0.2,
                                                 steps=10,
                                                 random_start=True,
                                                 eps_for_division=1e-10) 

    # Important to set the same normalization as the original images
    adversarial_attack_test.set_normalization_used(
        mean=normalisation_mean, std=normalisation_std)
    
    def validate(epoch):
        """Evaluate the network on the held-out validation batches.

        Validation batches are those batches of ``trainloader`` whose index is in
        ``validation_batches``; their labels go through ``corrupt_targets`` as in training.
        Each batch is scored on natural inputs and, when enabled, on adversarial
        (``noise > 0``) and Gaussian-noise (``gaussian_noise > 0``) inputs.

        Args:
            epoch: Index of the current epoch (not used inside the function).

        Returns:
            Dict with the keys "Adv Val Loss", "Adv Val Accuracy", "GN Val Loss",
            "GN Val Accuracy", "Nat Val Loss" and "Nat Val Accuracy". Metrics of a
            disabled noise type keep the placeholder value -1.
        """
        net.eval()

        # Metrics - loss. Adversarial data, natural data and (gaussian) noisy data
        adv_test_loss, nat_test_loss, gn_test_loss = [], [], []

        # Metrics - accuracy. Adversarial data, natural data and (gaussian) noisy data
        adv_correct, nat_correct, gn_correct, total = 0, 0, 0, 0

        for batch_idx, (inputs, targets) in enumerate(trainloader):

            if batch_idx not in validation_batches:
                continue

            targets = corrupt_targets(batch_idx,
                                      targets,
                                      validation_batches,
                                      train_batch_size,
                                      train_batches,
                                      data_points_to_corrupt,
                                      unique_labels) # Corrupting labels

            inputs = inputs.to(device)
            targets = targets.to(device)

            # These batches come from the training loader, so the per-example average must use
            # their own size rather than the test batch size.
            batch_size = targets.size(0)

            if noise > 0:
                # validation-time adv. attack; needs gradients, so it runs outside no_grad
                adv = adversarial_attack_test(inputs, targets)

            if gaussian_noise > 0:
                # Adding gaussian noise to 3 channels of the CIFAR-10 image
                gn = gaussian_noise * torch.randn(*inputs.shape)
                gn = gn.to(device)
                gn_inputs = inputs + gn

            with torch.no_grad():

                if gaussian_noise > 0:
                    # Evaluating on gaussian noise validation images
                    gn_outputs = net(gn_inputs)
                    gn_loss = torch.sum(criterion(gn_outputs, targets))
                    gn_test_loss.append(gn_loss.item()/batch_size)
                    gn_predictions = gn_outputs.max(1)[1]
                    gn_correct += gn_predictions.eq(targets).sum().item()

                if noise > 0:
                    # Evaluating on adversarial validation images
                    adv_outputs = net(adv)
                    adv_loss = torch.sum(criterion(adv_outputs, targets))
                    adv_test_loss.append(adv_loss.item()/batch_size)
                    adv_predictions = adv_outputs.max(1)[1]
                    adv_correct += adv_predictions.eq(targets).sum().item()

                # Evaluating on natural validation images
                nat_outputs = net(inputs)
                nat_loss = torch.sum(criterion(nat_outputs, targets))
                nat_test_loss.append(nat_loss.item()/batch_size)
                nat_predictions = nat_outputs.max(1)[1]
                nat_correct += nat_predictions.eq(targets).sum().item()

                total += batch_size

        # Initialise saving of metrics
        outputs = {"Adv Val Loss": -1, "Adv Val Accuracy": -1,
                   "GN Val Loss": -1, "GN Val Accuracy": -1,
                   "Nat Val Loss": -1, "Nat Val Accuracy": -1}

        # Only save adversarial noise loss/accuracy if adv. noise is present
        if noise > 0:
            outputs["Adv Val Loss"] = np.mean(adv_test_loss)
            outputs["Adv Val Accuracy"] = adv_correct/total

        # Only save gaussian noise loss/accuracy if gaussian noise is present
        if gaussian_noise > 0:
            outputs["GN Val Loss"] = np.mean(gn_test_loss)
            outputs["GN Val Accuracy"] = gn_correct/total

        # Always collect natural loss/accuracy
        outputs["Nat Val Loss"] = np.mean(nat_test_loss)
        outputs["Nat Val Accuracy"] = nat_correct/total

        return outputs

    ########### TESTING ##################

    def test(epoch):
        """Evaluate the network on every batch of ``testloader``.

        Each batch is scored on natural inputs and, when enabled, on adversarial
        (``noise > 0``) and Gaussian-noise (``gaussian_noise > 0``) inputs.

        Args:
            epoch: Index of the current epoch (not used inside the function).

        Returns:
            Dict with the keys "Adv Test Loss", "Adv Test Accuracy", "GN Test Loss",
            "GN Test Accuracy", "Nat Test Loss" and "Nat Test Accuracy". Metrics of a
            disabled noise type keep the placeholder value -1.
        """
        net.eval()

        # Metrics - loss. Adversarial data, natural data and (gaussian) noisy data
        adv_test_loss, nat_test_loss, gn_test_loss = [], [], []

        # Metrics - accuracy. Adversarial data, natural data and (gaussian) noisy data
        adv_correct, nat_correct, gn_correct, total = 0, 0, 0, 0

        for batch_idx, (inputs, targets) in enumerate(testloader):

            inputs = inputs.to(device)
            targets = targets.to(device)

            # Per-example average uses the actual batch size; this equals test_batch_size
            # for every full batch.
            batch_size = targets.size(0)

            if noise > 0:
                # test-time adv. attack; needs gradients, so it runs outside no_grad
                adv = adversarial_attack_test(inputs, targets)

            if gaussian_noise > 0:
                # Adding gaussian noise to 3 channels of the CIFAR-10 image
                gn = gaussian_noise * torch.randn(*inputs.shape)
                gn = gn.to(device)
                gn_inputs = inputs + gn

            with torch.no_grad():

                if gaussian_noise > 0:
                    # Evaluating on gaussian noise testing images
                    gn_outputs = net(gn_inputs)
                    gn_loss = torch.sum(criterion(gn_outputs, targets))
                    gn_test_loss.append(gn_loss.item()/batch_size)
                    gn_predictions = gn_outputs.max(1)[1]
                    gn_correct += gn_predictions.eq(targets).sum().item()

                if noise > 0:
                    # Evaluating on adversarial testing images
                    adv_outputs = net(adv)
                    adv_loss = torch.sum(criterion(adv_outputs, targets))
                    adv_test_loss.append(adv_loss.item()/batch_size)
                    adv_predictions = adv_outputs.max(1)[1]
                    adv_correct += adv_predictions.eq(targets).sum().item()

                # Evaluating on natural testing images
                nat_outputs = net(inputs)
                nat_loss = torch.sum(criterion(nat_outputs, targets))
                nat_test_loss.append(nat_loss.item()/batch_size)
                nat_predictions = nat_outputs.max(1)[1]
                nat_correct += nat_predictions.eq(targets).sum().item()

                total += batch_size

        # Initialise saving of metrics
        outputs = {"Adv Test Loss": -1, "Adv Test Accuracy": -1,
                   "GN Test Loss": -1, "GN Test Accuracy": -1,
                   "Nat Test Loss": -1, "Nat Test Accuracy": -1}

        # Only save adversarial noise loss/accuracy if adv. noise is present
        if noise > 0:
            outputs["Adv Test Loss"] = np.mean(adv_test_loss)
            outputs["Adv Test Accuracy"] = adv_correct/total

        # Only save gaussian noise loss/accuracy if gaussian noise is present
        if gaussian_noise > 0:
            outputs["GN Test Loss"] = np.mean(gn_test_loss)
            outputs["GN Test Accuracy"] = gn_correct/total

        # Always collect natural loss/accuracy
        outputs["Nat Test Loss"] = np.mean(nat_test_loss)
        outputs["Nat Test Accuracy"] = nat_correct/total

        return outputs

    # RUNNING TRAINING, VALIDATION AND TESTING

    # Initialising loss metrics (one entry is appended per recorded epoch)
    train_losses, HR_losses = [], []
    nat_val_losses, adv_val_losses, gn_val_losses = [], [], []
    nat_test_losses, adv_test_losses, gn_test_losses = [], [], []

    # Initialising accuracy metrics (one list per name, so the unpacking always matches)
    train_accuracies = []
    nat_val_accuracies, adv_val_accuracies, gn_val_accuracies = [], [], []
    nat_test_accuracies, adv_test_accuracies, gn_test_accuracies = [], [], []

    # Iterator for stopping criterion - we will run until training loss has converged.
    # It counts consecutive epochs whose training loss is above the best recorded one.
    stopping_criterion = 0

    min_epochs = 220
    end_epoch = start_epoch+epochs

    # Final training loop
    for epoch in range(start_epoch, end_epoch):

        # Stop once the counter has reached 6 and at least min_epochs have been run,
        # or when we've reached the final epoch.
        stop = (stopping_criterion >= 6 and epoch >= min_epochs) or epoch == end_epoch-1

        train_loss, train_accuracy, HR_loss = train(epoch) # Training at every epoch

        if epoch % 20 == 0 or stop: # Collect metrics every 20 epochs or when the training is over

            # Training metrics
            train_losses.append(train_loss)
            HR_losses.append(HR_loss)
            train_accuracies.append(train_accuracy)

            # Running validation. Note this is to save computation time as this step will run only every 20 epochs
            val_metrics = validate(epoch)

            # Validation metrics
            adv_val_losses.append(val_metrics["Adv Val Loss"])
            nat_val_losses.append(val_metrics["Nat Val Loss"])
            gn_val_losses.append(val_metrics["GN Val Loss"])
            adv_val_accuracies.append(val_metrics["Adv Val Accuracy"])
            nat_val_accuracies.append(val_metrics["Nat Val Accuracy"])
            gn_val_accuracies.append(val_metrics["GN Val Accuracy"])

            # Running testing. Note this is to save computation time as this step will run only every 20 epochs.
            test_metrics = test(epoch)

            # Testing metrics
            adv_test_losses.append(test_metrics["Adv Test Loss"])
            nat_test_losses.append(test_metrics["Nat Test Loss"])
            gn_test_losses.append(test_metrics["GN Test Loss"])
            adv_test_accuracies.append(test_metrics["Adv Test Accuracy"])
            nat_test_accuracies.append(test_metrics["Nat Test Accuracy"])
            gn_test_accuracies.append(test_metrics["GN Test Accuracy"])

        # Stopping criterion - is the current training loss higher than the minimum achieved so far? (We are aiming to train to convergence)
        # NOTE(review): train_losses only holds the recorded epochs (every 20th and the last one),
        # so this minimum ignores the epochs in between - confirm that this is intended.
        best_train_loss = min(train_losses)

        if train_loss > best_train_loss:
            stopping_criterion += 1
        else:
            stopping_criterion = 0

        # If stopping criterion is reached, then save models and metrics.
        if stop:

            print('Saving..')

            models_path = f"HD_models/{model_name}/frac_misspecified_{mis_specification}/frac_data_{frac}/iter_{iter}/"
            Path(models_path).mkdir(parents=True, exist_ok=True)
            torch.save(net.state_dict(), models_path + "checkpoint.pt")

            # One row of the list per metric; transposed so that each metric becomes a column.
            # The order here must match the order of the column names below.
            losses = pd.DataFrame([train_losses, HR_losses,
                                   nat_val_losses, adv_val_losses, gn_val_losses,
                                   nat_test_losses, adv_test_losses, gn_test_losses,
                                   train_accuracies,
                                   nat_val_accuracies, adv_val_accuracies, gn_val_accuracies,
                                   nat_test_accuracies, adv_test_accuracies, gn_test_accuracies]).T

            losses.columns = ["Training Loss", "Inflated Loss",
                              "Natural Validation Loss", "Adversarial Validation Loss", "Gaussian Noise Validation Loss",
                              "Natural Testing Loss", "Adversarial Testing Loss", "Gaussian Noise Testing Loss",
                              "Training Accuracy",
                              "Natural Validation Accuracy", "Adversarial Validation Accuracy", "Gaussian Noise Validation Accuracy",
                              "Natural Testing Accuracy", "Adversarial Testing Accuracy", "Gaussian Noise Testing Accuracy"]

            results_path = f"HD_results/{model_name}/frac_misspecified_{mis_specification}/frac_data_{frac}/iter_{iter}/"
            Path(results_path).mkdir(parents=True, exist_ok=True)
            losses.to_csv(results_path + "metrics.csv")

            break