# Configuration

In [None]:
# Configuration #
# Every value below starts as None and must be filled in before running; the trailing comments list
# the settings the rest of the notebook recognises.
config = {
    'validation_times': None,  # NOTE(review): not read by any code visible in this notebook
    'data_name': None,  # {'breast_cancer', 'cifar10', 'MNIST', 'FashionMNIST', 'spambase', 'abalone', 'iris', 'wine'}
    'perturb_type': None,  # {None, 'mixup', 'mixup sc', 'mixup nb', 'gauss VRM'}
    'augWPL': None,  # weight of the perturbation loss vs. the clean loss; None = train on the perturbed batch only. {None, 0.25, 0.5, 0.75, 0.9, 0.99} or others
    'geometric_param': None,  # 'mixup nb': Geometric probability used to pick the neighbour rank. {0.25, 0.5, 0.75, 1.} or others
    'gauss_vicinal_std': None,  # 'gauss VRM': std of the added Gaussian noise. {0.25, 0.5, 0.75, 1.} or others
    'batch_size': None,
    'step_size': None,  # SGD learning rate. {0.1, 0.01, 0.001, 0.0005} or others
    'epochs': None,
    'L2_decay': None,  # SGD weight_decay. {0., 1e-4} or others
    'alpha': None,  # mixup weight is drawn from Beta(alpha, alpha). {0.25, 0.5, 0.75, 1.} or others
    'breast_spam_model': None  # {'complex_tanh_sigmoid', 'simple_relu_sigmoid', 'simple_relu_softmax'}; ONLY VALID for legacy 'breast_cancer' and 'spambase'
}

# Loss type: fixed to cross-entropy for the multi-class datasets; for the legacy binary datasets it
# follows the chosen head (single-logit heads -> BCE, two-logit head -> CE). Left unset for any
# unrecognised combination.
_MULTICLASS_DATASETS = ('cifar10', 'MNIST', 'FashionMNIST', 'abalone', 'iris', 'wine')
_LEGACY_BINARY_HEAD_LOSS = {
    'complex_tanh_sigmoid': 'BCE',  # single-logit head
    'simple_relu_sigmoid': 'BCE',  # single-logit head
    'simple_relu_softmax': 'CE',  # two-logit head
}
if config['data_name'] in ('breast_cancer', 'spambase'):
    if config['breast_spam_model'] in _LEGACY_BINARY_HEAD_LOSS:
        config['criterion_type'] = _LEGACY_BINARY_HEAD_LOSS[config['breast_spam_model']]
elif config['data_name'] in _MULTICLASS_DATASETS:
    config['criterion_type'] = 'CE'

# Folders and Files #
# Results go to '<colab folder>/<data_name>-<perturb_type>/' and every file name of a run starts
# with '<augWPL>-<alpha>-<geometric_param>-<gauss_vicinal_std>-'.
gdrive_dir = '/content/gdrive'
experiment_folder_dir = '{}/My Drive/colab'.format(gdrive_dir)
save_folder_dir = '{}/{!s}-{!s}'.format(experiment_folder_dir, config['data_name'], config['perturb_type'])
file_prefix = ''.join('{!s}-'.format(config[key]) for key in ('augWPL', 'alpha', 'geometric_param', 'gauss_vicinal_std'))

# Libraries

In [None]:
import torch
from torchvision import transforms, datasets, models
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from google.colab import drive
drive.mount(gdrive_dir)
import sys
sys.path.append(experiment_folder_dir)
import time
import datetime
import os
import json
import resnet
import load_external_data
import pandas as pd

# Create this run's output folder together with any missing parent folders. exist_ok=True makes it a
# no-op when the folder is already there (and avoids the check-then-create race of exists + mkdir).
os.makedirs(save_folder_dir, exist_ok=True)

# Data augmentation / perturbation functions

In [None]:
def mixup(inputs, labels, alpha):
    """Mix every sample with a randomly chosen partner from the same batch (mixup).

    One mixing weight lmbda ~ Beta(alpha, alpha) is shared by the whole batch; the partner of
    sample i is sample idx[i] for a random permutation idx of the batch.

    Args:
        inputs: batch tensor whose first dimension indexes samples.
        labels: labels aligned with the first dimension of ``inputs``.
        alpha: concentration used for both Beta shape parameters.

    Returns:
        (mixed_inputs, labels, partner_labels, lmbda). ``mixed_inputs`` and ``lmbda`` live on
        ``inputs.device``; pass the last three values to ``mixup_criterion``.
    """
    device = inputs.device
    # Sampling stays on the CPU generator (as before) and is then moved, so seeded runs keep the
    # same random stream regardless of the device the batch lives on.
    lmbda = torch.distributions.beta.Beta(alpha, alpha).sample().to(device)
    batch_size = labels.size(0)
    idx = torch.randperm(batch_size).to(device)
    mixup_inputs = lmbda * inputs + (1 - lmbda) * inputs[idx]
    labels_b = labels[idx]
    return mixup_inputs, labels, labels_b, lmbda

In [None]:
def mixup_sc(inputs, labels, alpha):
    """Same-class mixup: mix every sample only with a random partner that has the same label.

    One mixing weight lmbda ~ Beta(alpha, alpha) is shared by the whole batch. Because partners
    share the label, the mixed samples keep their original label and the plain criterion applies.

    Args:
        inputs: batch tensor whose first dimension indexes samples.
        labels: labels aligned with ``inputs``; either a vector or an (N, 1) column (BCE layout).
        alpha: concentration used for both Beta shape parameters.

    Returns:
        (mixed_inputs, mixed_labels, lmbda). Samples come back grouped by class, in the ascending
        order produced by ``torch.unique``; ``mixed_labels`` is aligned with ``mixed_inputs`` and
        keeps the layout (vector or column) of ``labels``.
    """
    device = inputs.device
    # CPU sampling then move, as before, to keep the seeded random stream unchanged.
    lmbda = torch.distributions.beta.Beta(alpha, alpha).sample().to(device)
    mixup_inputs_uc_list = list()
    labels_uc_list = list()
    for uc in torch.unique(labels):
        mask_uc = (labels == uc).flatten()  # flatten so column-vector labels give a 1-D row mask
        inputs_uc = inputs[mask_uc]
        labels_uc = labels[mask_uc]
        idx = torch.randperm(labels_uc.size(0)).to(device)
        mixup_inputs_uc_list.append(lmbda * inputs_uc + (1 - lmbda) * inputs_uc[idx])
        labels_uc_list.append(labels_uc)
    # Concatenate along the batch dimension for both inputs and labels; cat (not vstack/hstack)
    # keeps the per-sample layout for any input rank and for column-vector labels.
    mixup_inputs_sc = torch.cat(mixup_inputs_uc_list, dim=0)
    mixup_labels_sc = torch.cat(labels_uc_list, dim=0)
    return mixup_inputs_sc, mixup_labels_sc, lmbda

In [None]:
def mixup_nb(inputs, labels, geometric_param, alpha):
    """Neighbourhood mixup: mix every sample with one of its nearest neighbours in the batch.

    Neighbours are ranked by Euclidean distance between flattened inputs, the sample itself
    excluded. The rank of the chosen neighbour (0 = nearest) is drawn from
    Geometric(geometric_param) and clipped to the farthest available neighbour, so a larger
    ``geometric_param`` favours closer neighbours. One mixing weight lmbda ~ Beta(alpha, alpha)
    is shared by the whole batch.

    Args:
        inputs: batch tensor whose first dimension indexes samples.
        labels: labels aligned with ``inputs``.
        geometric_param: success probability of the Geometric rank distribution, in (0, 1].
        alpha: concentration used for both Beta shape parameters.

    Returns:
        (mixed_inputs, labels, neighbour_labels, lmbda), suitable for ``mixup_criterion``.
        A single-sample batch has no neighbour and is mixed with itself (returned unchanged up
        to rounding, with neighbour_labels == labels).
    """
    inner_batch_size = labels.size(0)
    device = inputs.device
    inputs_flatten = inputs.reshape(inner_batch_size, -1)

    if inner_batch_size > 1:
        # Pair-wise distances with each sample's distance to itself pushed to +inf, so it always
        # sorts last. (Assuming it sorts first fails for duplicate points, and cdist's matmul-based
        # Euclidean path does not guarantee an exact 0 on the diagonal.)
        dists = torch.cdist(inputs_flatten, inputs_flatten)
        dists.fill_diagonal_(float('inf'))
        neighbour_order = torch.argsort(dists, dim=1)[:, :-1]

        # Geometric ranks are sampled on the CPU generator and truncated to integers, then clipped
        # to the last of the inner_batch_size - 1 neighbours.
        ranks = torch.distributions.geometric.Geometric(geometric_param).sample((inner_batch_size,)).long()
        ranks = torch.clamp(ranks, max=inner_batch_size - 2).to(device)
        nb_idx = neighbour_order[torch.arange(inner_batch_size, device=device), ranks]
    else:
        # No other sample to pair with (e.g. a last DataLoader batch of size 1): mix with itself.
        nb_idx = torch.zeros(inner_batch_size, dtype=torch.long, device=device)

    # mixup with neighbours #
    inputs_nb = inputs[nb_idx]
    labels_nb = labels[nb_idx]
    lmbda = torch.distributions.beta.Beta(alpha, alpha).sample().to(device)
    mixup_inputs_nb = lmbda * inputs + (1 - lmbda) * inputs_nb
    return mixup_inputs_nb, labels, labels_nb, lmbda

In [None]:
def mixup_criterion(criterion, predicts, labels, labels_b, lmbda):
    """Loss of a mixed batch: the lmbda-weighted combination of the losses on both label sets.

    Computes criterion(predicts, labels) first, then criterion(predicts, labels_b), and returns
    lmbda * first + (1 - lmbda) * second, matching inputs mixed as lmbda * x + (1 - lmbda) * x_b.
    """
    loss_primary = criterion(predicts, labels)
    loss_partner = criterion(predicts, labels_b)
    return lmbda * loss_primary + (1 - lmbda) * loss_partner

In [None]:
def gauss_vicinal(inputs, gauss_vicinal_std):
    """Gaussian vicinal perturbation: draw each element from N(input value, gauss_vicinal_std**2).

    ``inputs`` supplies the element-wise means and ``gauss_vicinal_std`` is a single scalar standard
    deviation shared by every element; the result has the shape of ``inputs``.
    """
    return torch.normal(mean=inputs, std=gauss_vicinal_std)

# Data Loader

In [None]:
def _standardise_and_wrap(train_data, train_labels, test_data, test_labels, batch_size):
    """Wrap numpy train/test splits in DataLoaders, z-scoring features with training statistics.

    Features become float32 tensors and labels keep their numpy dtype. The training loader
    shuffles; the test loader does not.
    """
    train_x = torch.from_numpy(train_data).type(torch.FloatTensor)
    train_y = torch.from_numpy(train_labels)
    test_x = torch.from_numpy(test_data).type(torch.FloatTensor)
    test_y = torch.from_numpy(test_labels)
    train_mean = torch.mean(train_x, 0)
    train_std = torch.std(train_x, 0)
    # A feature that is constant on the training split has std 0 (dividing would give NaN/inf);
    # such features are only centred.
    train_std = torch.where(train_std > 0, train_std, torch.ones_like(train_std))
    train_x = (train_x - train_mean) / train_std
    test_x = (test_x - train_mean) / train_std
    train_set = torch.utils.data.TensorDataset(train_x, train_y)
    test_set = torch.utils.data.TensorDataset(test_x, test_y)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
    return train_loader, test_loader


def _load_skl_split(name):
    """Load a dataset through load_external_data.load_skl_data and fold its validation split into the test split."""
    train_data, train_labels, val_data, val_labels, test_data, test_labels = load_external_data.load_skl_data(name)
    test_data = np.vstack((val_data, test_data))
    test_labels = np.hstack((val_labels, test_labels))
    return train_data, train_labels, test_data, test_labels


def _shuffle_split(data, labels, train_fraction=0.6):
    """Shuffle rows with the global NumPy RNG and split them into (train_data, train_labels, test_data, test_labels)."""
    num_data = labels.shape[0]
    idx = np.random.permutation(num_data)
    data = data[idx]
    labels = labels[idx]
    splitpoint = int(num_data * train_fraction)
    return data[:splitpoint], labels[:splitpoint], data[splitpoint:], labels[splitpoint:]


def _load_abalone():
    """Read abalone.data, one-hot encode 'sex' (M/F/I) and turn ring counts into class indices."""
    column_names = ["sex", "length", "diameter", "height", "whole weight", "shucked weight", "viscera weight", "shell weight", "rings"]
    abalone = pd.read_csv(experiment_folder_dir + '/' + 'abalone.data', names=column_names)
    for label in "MFI":
        abalone[label] = abalone["sex"] == label
    del abalone["sex"]
    labels = abalone.rings.values
    del abalone["rings"]
    data = abalone.values.astype(np.float32)
    # Merge ring count 29 into 28 and shift to 0-based ids (presumably rings start at 1, giving the
    # 28 classes of the abalone model). np.where builds a new array instead of writing into the
    # DataFrame's buffer.
    labels = np.where(labels == 29, 28, labels) - 1
    return data, labels


def _image_loaders(dataset_cls, transform_train, transform_test, batch_size):
    """Build train/test DataLoaders for a torchvision dataset stored under experiment_folder_dir (download=True)."""
    train_set = dataset_cls(root=experiment_folder_dir, train=True, download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    test_set = dataset_cls(root=experiment_folder_dir, train=False, download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
    return train_loader, test_loader


def get_data_loader(config):
    """Build (train_loader, test_loader) for config['data_name'] with batch size config['batch_size'].

    Tabular datasets are z-scored with training-split statistics:
      * 'breast_cancer', 'iris', 'wine': split by load_external_data; validation is merged into test;
      * 'spambase', 'abalone': shuffled with the global NumPy RNG, then split 60 % / 40 %.
    Image datasets ('cifar10', 'MNIST', 'FashionMNIST') use torchvision with per-channel normalisation
    (plus random crop / flip for CIFAR-10 training).

    Raises:
        ValueError: if config['data_name'] is not one of the supported names.
    """
    data_name = config['data_name']
    batch_size = config['batch_size']

    if data_name in ('breast_cancer', 'iris', 'wine'):
        return _standardise_and_wrap(*_load_skl_split(data_name), batch_size)

    if data_name == 'cifar10':
        cifar_mean, cifar_std = (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),  # can omit
            transforms.RandomHorizontalFlip(),  # can omit
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean, cifar_std)
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean, cifar_std)
        ])
        return _image_loaders(datasets.CIFAR10, transform_train, transform_test, batch_size)

    if data_name == 'MNIST':
        # NOTE(review): the commonly quoted MNIST pixel std is 0.3081; 0.3015 is kept as-is — confirm.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3015,))
        ])
        return _image_loaders(datasets.MNIST, transform, transform, batch_size)

    if data_name == 'FashionMNIST':
        # NOTE(review): the commonly quoted FashionMNIST pixel std is 0.3530; 0.3205 is kept as-is — confirm.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.2860,), (0.3205,))
        ])
        return _image_loaders(datasets.FashionMNIST, transform, transform, batch_size)

    if data_name == 'spambase':
        spambase = pd.read_csv(experiment_folder_dir + '/dataset_44_spambase.csv').to_numpy()
        data, labels = spambase[:, :-1], spambase[:, -1]  # last column is the label
        return _standardise_and_wrap(*_shuffle_split(data, labels), batch_size)

    if data_name == 'abalone':
        data, labels = _load_abalone()
        return _standardise_and_wrap(*_shuffle_split(data, labels), batch_size)

    raise ValueError('Unknown data_name: {!r}'.format(data_name))

# Models

In [None]:
class _MLP(nn.Module):
    """Fully-connected network used for the tabular datasets.

    Layers are created in order and stored as attributes fc1 ... fcN (so parameter names are
    'fc1.weight', ...). ``activation`` is applied after every layer except the last, which returns
    raw logits (the losses in training() apply the sigmoid / softmax themselves).
    """

    def __init__(self, layer_sizes, activation):
        super(_MLP, self).__init__()
        self.activation = activation
        self.num_layers = len(layer_sizes) - 1
        for i in range(self.num_layers):
            setattr(self, 'fc{}'.format(i + 1), nn.Linear(layer_sizes[i], layer_sizes[i + 1]))

    def forward(self, inputs):
        out = inputs
        for i in range(1, self.num_layers + 1):
            out = getattr(self, 'fc{}'.format(i))(out)
            if i < self.num_layers:
                out = self.activation(out)
        return out


def _binary_mlp(data_name, head):
    """Build the fully-connected model for 'breast_cancer' (30 features) or 'spambase' (57 features).

    'complex_tanh_sigmoid' and 'simple_relu_sigmoid' end in one logit, 'simple_relu_softmax' in two.

    Raises:
        ValueError: if ``head`` is not one of those three names.
    """
    in_features = 30 if data_name == 'breast_cancer' else 57
    if head == 'complex_tanh_sigmoid':
        hidden = (128, 64, 32) if data_name == 'breast_cancer' else (256, 128, 64, 32)
        return _MLP((in_features,) + hidden + (1,), torch.tanh)
    if head == 'simple_relu_sigmoid':
        return _MLP((in_features, 128, 128, 1), F.relu)
    if head == 'simple_relu_softmax':
        return _MLP((in_features, 128, 128, 2), F.relu)
    raise ValueError('Unknown breast_spam_model for {!r}: {!r}'.format(data_name, head))


def get_model(config):
    """Build the network for config['data_name'] and move it to the GPU with .cuda().

    * 'breast_cancer' / 'spambase': MLP selected by config['breast_spam_model'];
    * 'cifar10': resnet.ResNet18();
    * 'MNIST' / 'FashionMNIST': torchvision ResNet-18 (no pretrained weights) with a 1-channel stem
      and a 10-way output layer;
    * 'abalone' / 'iris' / 'wine': two-hidden-layer ReLU MLPs with 28 / 3 / 3 outputs.

    Raises:
        ValueError: for an unknown data_name or breast_spam_model.
    """
    data_name = config['data_name']
    if data_name in ('breast_cancer', 'spambase'):
        model = _binary_mlp(data_name, config['breast_spam_model'])
    elif data_name == 'cifar10':
        model = resnet.ResNet18()
    elif data_name in ('MNIST', 'FashionMNIST'):
        model = models.resnet18(pretrained=False)
        # Replace the stem for single-channel inputs and the head for 10 classes.
        model.conv1 = torch.nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
        model.fc = torch.nn.Linear(512, 10)
    elif data_name == 'abalone':
        model = _MLP((10, 128, 128, 28), F.relu)
    elif data_name == 'iris':
        model = _MLP((4, 128, 128, 3), F.relu)
    elif data_name == 'wine':
        model = _MLP((13, 128, 128, 3), F.relu)
    else:
        raise ValueError('Unknown data_name: {!r}'.format(data_name))
    model.cuda()
    return model

# Testing

In [None]:
def testing(data_loader, criterion, model, config):
    model.eval()
    correct = 0
    total = 0
    loss = 0.

    # Store perturbation loss if required #
    if config['perturb_type'] is not None:
        perturb_loss = 0
    
    # Start testing #
    with torch.no_grad():
        for data in data_loader:

            # Get the loss of the batch #
            inputs, labels = data
            inputs = inputs.to('cuda')
            if config['criterion_type'] == 'BCE':
                labels = labels.type(torch.FloatTensor).reshape(-1, 1).to('cuda')
            elif config['criterion_type'] == 'CE':
                labels = labels.type(torch.LongTensor).to('cuda')
            outputs = model(inputs)
            batch_loss = criterion(outputs, labels)
            loss += batch_loss.item()

            # Get the loss of the perturb batch if required #
            if config['perturb_type'] == 'mixup':
                mixup_inputs, mixup_labels_a, mixup_labels_b, lmbda = mixup(inputs, labels, config['alpha'])
                mixup_outputs = model(mixup_inputs)
                batch_mixup_loss = mixup_criterion(criterion, mixup_outputs, mixup_labels_a, mixup_labels_b, lmbda)
                perturb_loss += batch_mixup_loss.item()
            elif config['perturb_type'] == 'mixup sc':
                mixup_inputs_sc, mixup_labels_sc, lmbda = mixup_sc(inputs, labels, config['alpha'])
                mixup_outputs_sc = model(mixup_inputs_sc)
                batch_mixup_loss_sc = criterion(mixup_outputs_sc, mixup_labels_sc)
                perturb_loss += batch_mixup_loss_sc.item()
            elif config['perturb_type'] == 'mixup nb':
                mixup_inputs_nb, mixup_labels_nb_a, mixup_labels_nb_b, lmbda = mixup_nb(inputs, labels, config['geometric_param'], config['alpha'])
                mixup_outputs_nb = model(mixup_inputs_nb)
                batch_mixup_loss_nb = mixup_criterion(criterion, mixup_outputs_nb, mixup_labels_nb_a, mixup_labels_nb_b, lmbda)
                perturb_loss += batch_mixup_loss_nb.item()
            elif config['perturb_type'] == 'gauss VRM':
                gauss_inputs = gauss_vicinal(inputs, config['gauss_vicinal_std'])
                gauss_outputs = model(gauss_inputs)
                batch_gauss_loss = criterion(gauss_outputs, labels)
                perturb_loss += batch_gauss_loss.item()
            
            # Compute predictions #
            if config['criterion_type'] == 'BCE':
                predicts = (torch.sign(outputs) + 1) / 2
            elif config['criterion_type'] == 'CE':
                _, predicts = torch.max(outputs, 1)

            # Accumulation #
            total += labels.size(0)
            correct += (predicts == labels).sum().item()
    
    # Compute accuracy #
    accuracy = correct / total

    # Compute mean losses #
    mean_loss = loss / total
    if config['perturb_type'] is not None:
        mean_perturb_loss = perturb_loss / total

    # Return according to required #
    model.train()
    if config['perturb_type'] is not None:
        return mean_loss, accuracy, mean_perturb_loss
    else:
        return mean_loss, accuracy

# Plot

In [None]:
def plot_lines(history, config, save_folder_dir, file_prefix, timestamp):
    """Plot the loss and accuracy curves of one run, save each figure as a PNG and show it.

    Both figures share a title with the dataset, the perturbation and (when set) augWPL, plus the
    final test accuracy. Files are written to '<save_folder_dir>/<file_prefix><timestamp> loss.png'
    and '<save_folder_dir>/<file_prefix><timestamp> accuracy.png'.
    """
    final_test_acc = history['epoch_test_accuracy'][-1]
    if config['augWPL'] is not None:
        title = '{} {} WPL{}; final test acc: {:.7f}'.format(config['data_name'], config['perturb_type'], config['augWPL'], final_test_acc)
    elif config['perturb_type'] is not None:
        title = '{} {}; final test acc: {:.7f}'.format(config['data_name'], config['perturb_type'], final_test_acc)
    else:
        title = '{}; final test acc: {:.7f}'.format(config['data_name'], final_test_acc)

    path_stem = save_folder_dir + '/' + file_prefix + timestamp

    def draw(curves, ylabel):
        # One figure per metric: all curves, grid, legend, shared title; saved as '<stem> <ylabel>.png'.
        fig, ax = plt.subplots(figsize=(10, 7))
        for series, curve_label in curves:
            ax.plot(series, label=curve_label)
        ax.grid()
        ax.legend()
        ax.set_title(title)
        ax.set_xlabel('epoch')
        ax.set_ylabel(ylabel)
        fig.savefig('{} {}.png'.format(path_stem, ylabel))
        plt.show()

    loss_curves = [(history['epoch_mean_train_loss'], 'train'), (history['epoch_mean_test_loss'], 'test')]
    if config['perturb_type'] is not None:
        loss_curves.append((history['epoch_mean_perturb_loss'], config['perturb_type']))
    draw(loss_curves, 'loss')
    draw([(history['epoch_train_accuracy'], 'train'), (history['epoch_test_accuracy'], 'test')], 'accuracy')

# Save history

In [None]:
def save_history(history, save_folder_dir, file_prefix, timestamp):
    """Write ``history`` as JSON to '<save_folder_dir>/<file_prefix><timestamp>.json' (overwriting)."""
    target_path = ''.join((save_folder_dir, '/', file_prefix, timestamp, '.json'))
    with open(target_path, 'w') as json_file:
        json.dump(history, json_file)

# Training

In [None]:
_PERTURB_TYPES = (None, 'mixup', 'mixup sc', 'mixup nb', 'gauss VRM')


def _check_training_config(config):
    """Reject config combinations training() cannot run before any epoch starts."""
    criterion_type = config.get('criterion_type')
    if criterion_type not in ('BCE', 'CE'):
        raise ValueError("config['criterion_type'] must be 'BCE' or 'CE', got {!r} "
                         "(it is derived from data_name / breast_spam_model)".format(criterion_type))
    if config['perturb_type'] not in _PERTURB_TYPES:
        raise ValueError('Unknown perturb_type: {!r}'.format(config['perturb_type']))
    if config['augWPL'] is not None and config['perturb_type'] is None:
        raise ValueError('augWPL weights a perturbation loss, so it requires a perturb_type')


def _perturb_batch(inputs, labels, config):
    """Apply config['perturb_type'] to one batch.

    Returns (perturb_inputs, perturb_loss_fn); perturb_loss_fn(criterion, perturb_outputs) gives the
    loss of the model's outputs on perturb_inputs against the matching targets.
    """
    perturb_type = config['perturb_type']
    if perturb_type == 'mixup':
        perturb_inputs, labels_a, labels_b, lmbda = mixup(inputs, labels, config['alpha'])
        return perturb_inputs, lambda criterion, outputs: mixup_criterion(criterion, outputs, labels_a, labels_b, lmbda)
    if perturb_type == 'mixup sc':
        # mixup_sc reorders samples by class, so the loss uses its reordered labels.
        perturb_inputs, perturb_labels, _ = mixup_sc(inputs, labels, config['alpha'])
        return perturb_inputs, lambda criterion, outputs: criterion(outputs, perturb_labels)
    if perturb_type == 'mixup nb':
        perturb_inputs, labels_a, labels_b, lmbda = mixup_nb(inputs, labels, config['geometric_param'], config['alpha'])
        return perturb_inputs, lambda criterion, outputs: mixup_criterion(criterion, outputs, labels_a, labels_b, lmbda)
    if perturb_type == 'gauss VRM':
        perturb_inputs = gauss_vicinal(inputs, config['gauss_vicinal_std'])
        return perturb_inputs, lambda criterion, outputs: criterion(outputs, labels)
    raise ValueError('Unknown perturb_type: {!r}'.format(perturb_type))


def training(train_loader, test_loader, model, optimizer, step_size_scheduler, config):
    """Train ``model`` for config['epochs'] epochs, evaluating it on both loaders after each epoch.

    Per batch the optimised loss is:
      * perturb_type None: the criterion on the clean batch;
      * perturb_type set, augWPL None: the perturbation loss alone (only perturbed inputs are seen);
      * augWPL set: augWPL * perturbation loss + (1 - augWPL) * clean loss, with clean and perturbed
        inputs stacked into a single forward pass.
    ``step_size_scheduler.step()`` runs once per epoch. Batches are moved to the device of the
    model's parameters.

    Returns:
        history dict of per-epoch lists: 'epoch_mean_train_loss', 'epoch_train_accuracy',
        'epoch_mean_test_loss', 'epoch_test_accuracy' and, when perturb_type is set,
        'epoch_mean_perturb_loss' (measured on the training loader).

    Raises:
        ValueError: for an unsupported criterion_type / perturb_type, or augWPL without perturb_type.
    """
    _check_training_config(config)
    perturb_type = config['perturb_type']
    aug_weight = config['augWPL']
    device = next(model.parameters()).device

    # History #
    history = {
        'epoch_mean_train_loss': list(),
        'epoch_train_accuracy': list(),
        'epoch_mean_test_loss': list(),
        'epoch_test_accuracy': list(),
    }
    if perturb_type is not None:
        history['epoch_mean_perturb_loss'] = list()

    # Define criterion (both operate on raw logits) #
    if config['criterion_type'] == 'BCE':
        criterion = torch.nn.BCEWithLogitsLoss()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # Start training #
    for epoch in range(config['epochs']):
        start = time.time()
        for inputs, labels in train_loader:
            model.train()
            optimizer.zero_grad()

            # Get inputs and labels #
            inputs = inputs.to(device)
            if config['criterion_type'] == 'BCE':
                # BCEWithLogitsLoss needs float targets shaped like the (N, 1) logits.
                labels = labels.type(torch.FloatTensor).reshape(-1, 1).to(device)
            else:
                labels = labels.type(torch.LongTensor).to(device)

            # Compute the loss to optimise #
            if perturb_type is None:
                ultimate_loss = criterion(model(inputs), labels)
            else:
                perturb_inputs, perturb_loss_fn = _perturb_batch(inputs, labels, config)
                if aug_weight is None:
                    ultimate_loss = perturb_loss_fn(criterion, model(perturb_inputs))
                else:
                    # One forward pass over clean + perturbed samples, then split the outputs back.
                    current_batch_size = labels.size(0)
                    ultimate_outputs = model(torch.vstack((inputs, perturb_inputs)))
                    loss = criterion(ultimate_outputs[:current_batch_size], labels)
                    perturb_loss = perturb_loss_fn(criterion, ultimate_outputs[current_batch_size:])
                    ultimate_loss = aug_weight * perturb_loss + (1 - aug_weight) * loss

            # Gradient calculation and optimisation #
            ultimate_loss.backward()
            optimizer.step()

        # Step size scheduler #
        step_size_scheduler.step()

        # Testing on Train and Test data #
        if perturb_type is None:
            epoch_mean_train_loss, epoch_train_accuracy = testing(train_loader, criterion, model, config)
            epoch_mean_test_loss, epoch_test_accuracy = testing(test_loader, criterion, model, config)
        else:
            epoch_mean_train_loss, epoch_train_accuracy, epoch_mean_perturb_loss = testing(train_loader, criterion, model, config)
            epoch_mean_test_loss, epoch_test_accuracy, _ = testing(test_loader, criterion, model, config)
        history['epoch_mean_train_loss'].append(epoch_mean_train_loss)
        history['epoch_train_accuracy'].append(epoch_train_accuracy)
        history['epoch_mean_test_loss'].append(epoch_mean_test_loss)
        history['epoch_test_accuracy'].append(epoch_test_accuracy)
        if perturb_type is not None:
            history['epoch_mean_perturb_loss'].append(epoch_mean_perturb_loss)

        # Print losses and accuracies (elapsed time includes the evaluation) #
        end = time.time()
        if perturb_type is None:
            print('epoch: {}, train loss: {:.10f}, train acc: {:.5f}, test loss: {:.10f}, test acc: {:.5f}, {:.2f}s'.format(epoch + 1, epoch_mean_train_loss, epoch_train_accuracy, epoch_mean_test_loss, epoch_test_accuracy, end - start))
        else:
            print('epoch: {}, train loss: {:.10f}, train acc: {:.5f}, perturb loss: {:.10f}, test loss: {:.10f}, test acc: {:.5f}, {:.2f}s'.format(epoch + 1, epoch_mean_train_loss, epoch_train_accuracy, epoch_mean_perturb_loss, epoch_mean_test_loss, epoch_test_accuracy, end - start))
    return history

In [None]:
# Optional reproducibility: add config['seed'] (not in the default config) to seed the global torch
# RNG (all devices) and the NumPy RNG. These drive weight initialisation, DataLoader shuffling, the
# perturbations and the spambase/abalone train/test split. Without a seed every run differs.
if config.get('seed') is not None:
    torch.manual_seed(config['seed'])
    np.random.seed(config['seed'])

train_loader, test_loader = get_data_loader(config)
model = get_model(config)
optimizer = torch.optim.SGD(model.parameters(), lr=config['step_size'], momentum=0.9, weight_decay=config['L2_decay'])
# Step size is divided by 10 at 50 % and again at 75 % of the epochs.
step_size_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(config['epochs'] * 0.5), int(config['epochs'] * 0.75)], gamma=0.1)
history = training(train_loader, test_loader, model, optimizer, step_size_scheduler, config)
timestamp = datetime.datetime.now().strftime("%d-%m-%y-%H-%M-%S")
save_history(history, save_folder_dir, file_prefix, timestamp)
plot_lines(history, config, save_folder_dir, file_prefix, timestamp)