# Import Libraries

In [1]:
import torch
from torchvision import transforms, datasets, models

from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


# Import outside code

In [2]:
import numpy as np
from sklearn.datasets import load_iris, load_wine, load_breast_cancer, make_circles, make_classification, make_regression


def train_val_test_split(data, labels, split=(0.6, 0.2, 0.2)):
    # Split data #
    num_data = data.shape[0]
    num_train_data = int(num_data * split[0])
    num_val_data = int(num_data * split[1])
    train_data = data[:num_train_data]
    train_labels = labels[:num_train_data]
    val_data = data[num_train_data:num_train_data + num_val_data]
    val_labels = labels[num_train_data:num_train_data + num_val_data]
    test_data = data[num_train_data + num_val_data:]
    test_labels = labels[num_train_data + num_val_data:]
    train_val_test = (train_data, train_labels, val_data, val_labels, test_data, test_labels)
    return train_val_test


def load_skl_data(data_name, need_num=None, split=(0.6, 0.2, 0.2)):
    # Load and unpack data from sklearn & randomise #
    if data_name == 'iris':
        skl_data = load_iris()
    elif data_name == 'wine':
        skl_data = load_wine()
    elif data_name == 'breast_cancer':
        skl_data = load_breast_cancer()
    num_data = skl_data['data'].shape[0]
    random_idx = np.random.permutation(num_data)
    data = skl_data['data'][random_idx]
    labels = skl_data['target'][random_idx]

    # Require number of data #
    if need_num is not None:
        data = data[:need_num]
        labels = data[:need_num]

    # Split data #
    train_val_test = train_val_test_split(data, labels, split=split)
    return train_val_test


def load_circular_data(need_num, noise=0.1, factor=0.5, split=(0.6, 0.2, 0.2)):
    # Load circular data #
    data, labels = make_circles(n_samples=need_num, noise=noise, factor=factor)
    labels[labels == 0] = -1

    # Split data #
    train_val_test = train_val_test_split(data, labels, split=split)
    return train_val_test


def load_two_spirals(need_num, noise=0.5, split=(0.6, 0.2, 0.2)):
    # Create two spirals data #
    n = np.sqrt(np.random.rand(need_num, 1)) * 780 * (2 * np.pi) / 360
    d1x = -np.cos(n) * n + np.random.rand(need_num, 1) * noise
    d1y = np.sin(n) * n + np.random.rand(need_num, 1) * noise
    data_extended = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y))))
    labels_extended = np.hstack((np.ones(need_num) * -1, np.ones(need_num)))
    idx = np.random.permutation(need_num * 2)
    data_extended = data_extended[idx]
    labels_extended = labels_extended[idx]
    data = data_extended[:need_num]
    labels = labels_extended[:need_num]

    # Split data #
    train_val_test = train_val_test_split(data, labels, split=split)
    return train_val_test


def load_random_classification_dataset(need_num, need_features, need_classes=2, need_flip=0.01, class_sep=1.0, random_state=None, split=(0.6, 0.2, 0.2)):
    # Create data for classification #
    n_informative = need_classes
    n_redundant = 0
    n_repeated = 0
    n_cluster_per_class = 2
    data, labels = make_classification(n_samples=need_num, n_features=need_features, n_informative=n_informative, n_redundant=n_redundant, n_repeated=n_repeated, n_classes=need_classes, n_clusters_per_class=n_cluster_per_class, flip_y=need_flip, class_sep=class_sep, random_state=random_state)

    # Change labels to +1/-1 if it is binary classification #
    if need_classes == 2:
        labels[labels == 0] = -1

    # Split data #
    train_val_test = train_val_test_split(data, labels, split=split)
    return train_val_test


def load_random_regression_dataset(need_num, need_features, bias, noise=1, random_state=None, split=(0.6, 0.2, 0.2)):
    # Create data for regression #
    n_informative = need_features
    n_targets = 1
    data, labels = make_regression(n_samples=need_num, n_features=need_features, n_informative=n_informative, n_targets=n_targets, bias=bias, noise=noise, random_state=random_state)

    # Split data #
    train_val_test = train_val_test_split(data, labels, split=split)
    return train_val_test

In [3]:
'''ResNet in PyTorch.

BasicBlock and Bottleneck module is from the original ResNet paper:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385

PreActBlock and PreActBottleneck module is from the later paper:
[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Identity Mappings in Deep Residual Networks. arXiv:1603.05027
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.autograd import Variable


def conv3x3(in_planes, out_planes, stride=1):
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.'''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out)
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += shortcut
        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(self.expansion*planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.'''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out)
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        out += shortcut
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = conv3x3(3,64)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, lin=0, lout=5):
        out = x
        if lin < 1 and lout > -1:
            out = self.conv1(out)
            out = self.bn1(out)
            out = F.relu(out)
        if lin < 2 and lout > 0:
            out = self.layer1(out)
        if lin < 3 and lout > 1:
            out = self.layer2(out)
        if lin < 4 and lout > 2:
            out = self.layer3(out)
        if lin < 5 and lout > 3:
            out = self.layer4(out)
        if lout > 4:
            out = F.avg_pool2d(out, 4)
            out = out.view(out.size(0), -1)
            out = self.linear(out)
        return out


def ResNet18():
    return ResNet(PreActBlock, [2,2,2,2])

def ResNet34():
    return ResNet(BasicBlock, [3,4,6,3])

def ResNet50():
    return ResNet(Bottleneck, [3,4,6,3])

def ResNet101():
    return ResNet(Bottleneck, [3,4,23,3])

def ResNet152():
    return ResNet(Bottleneck, [3,8,36,3])


def test():
    net = ResNet18()
    y = net(Variable(torch.randn(1,3,32,32)))
    print(y.size())

# test()


# Configuration

In [4]:
"""
Configuration and Hyperparameters
"""
#torch.set_default_tensor_type(torch.cuda.FloatTensor)  # default all in GPU, in pytorch 1.9 even need dataloader to be in GPU

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

batch_size = 128
step_size = 0.01
random_seed = 0
epochs = 100
L2_decay = 1e-4
alpha = 1.
geometric_param = 0.5
perturb_loss_weight = 0.99

torch.manual_seed(random_seed)

<torch._C.Generator at 0x7f9367f07af0>

# Data

In [5]:
"""
Data
"""
train_set = datasets.MNIST(root='/content/gdrive/My Drive/colab', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
test_set = datasets.MNIST(root='/content/gdrive/My Drive/colab', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


# Models, Loss, Optimiser

In [6]:
model = models.resnet18(pretrained=False)
for param in model.parameters():
    param.requires_grad = True
model.conv1 = torch.nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
model.fc = torch.nn.Linear(512, 10)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=step_size, momentum=0.9, weight_decay=L2_decay)
step_size_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[int(epochs / 2)], gamma=0.1)
model.cuda()

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

# Data Augmentation / Perturbation AND corresponding loss

In [7]:
def mixup_MNIST_nb(inputs, labels, geometric_param, alpha):
    inner_batch_size = labels.size(0)
    inputs_flatten = inputs.reshape(inner_batch_size, -1)

    # Compute pair-wise distances & sort the distances #
    dists = torch.cdist(inputs_flatten, inputs_flatten)
    sort_idx = torch.argsort(dists, dim=1)
    sort_idx_no_itself = sort_idx[:, 1:]

    # Generate geometric random variables for selecting neighbours & get the index of selected neighbour data #
    select_idx = torch.distributions.geometric.Geometric(geometric_param).sample_n(inner_batch_size).type(torch.LongTensor).to('cuda')
    select_idx_clipped = torch.clamp(select_idx, max=inner_batch_size - 2)
    nb_idx = sort_idx_no_itself[torch.arange(inner_batch_size), select_idx_clipped]

    # mixup with neighbours #
    inputs_nb = inputs[nb_idx]
    labels_nb = labels[nb_idx]
    lmbda = torch.distributions.beta.Beta(alpha, alpha).sample().to('cuda')
    mixup_inputs_nb = lmbda * inputs + (1 - lmbda) * inputs_nb
    return mixup_inputs_nb, labels, labels_nb, lmbda

In [8]:
def mixup_criterion(criterion, predicts, labels, labels_b, lmbda):
    mixup_loss = lmbda * criterion(predicts, labels) + (1 - lmbda) * criterion(predicts, labels_b)
    return mixup_loss

# Training

In [9]:
"""
Training
"""
model.train()
for epoch in range(epochs):
    epoch_mixup_loss = 0.
    epoch_loss = 0.
    epoch_augment_loss = 0.
    for i, data in enumerate(train_loader, 0):
        optimizer.zero_grad()
        inputs, labels = data
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')

        # Mixup with random neighbour perturbation #
        mixup_inputs_nb, mixup_labels_nb_a, mixup_labels_nb_b, lmbda = mixup_MNIST_nb(inputs, labels, geometric_param, alpha)
        
        # Concatenate perturbation and original data, to do augmentation and loss computation #
        original_num = inputs.size(0)
        augment_inputs = torch.vstack((inputs, mixup_inputs_nb))
        augment_outputs = model(augment_inputs)
        outputs = augment_outputs[:original_num]
        mixup_outputs_nb = augment_outputs[original_num:]
        mixup_loss_nb = mixup_criterion(criterion, mixup_outputs_nb, mixup_labels_nb_a, mixup_labels_nb_b, lmbda)
        loss = criterion(outputs, labels)
        weighted_augment_loss = perturb_loss_weight * mixup_loss_nb + (1 - perturb_loss_weight) * loss

        # Record #
        epoch_mixup_loss += mixup_loss_nb.item()
        epoch_loss += loss.item()
        epoch_augment_loss += (mixup_loss_nb.item() + loss.item())

        # Gradient Calculation & Optimisation #
        weighted_augment_loss.backward()
        optimizer.step()
    
    # Step size scheduler #
    step_size_scheduler.step()
    
    # Print decomposed losses #
    print('{}: {} {} {}'.format(epoch, epoch_mixup_loss, epoch_loss, epoch_augment_loss))

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


0: 196.74645580351353 80.49308724328876 277.2395430468023
1: 152.72748199105263 33.022639349102974 185.7501213401556
2: 145.24596729129553 25.388802502304316 170.63476979359984
3: 138.43847758695483 20.887351790443063 159.3258293773979
4: 134.4693555533886 18.643028813414276 153.11238436680287
5: 133.50754065439105 16.90895447693765 150.4164951313287
6: 130.76817281357944 15.114069229923189 145.88224204350263
7: 127.47022967785597 14.530367815867066 142.00059749372303
8: 126.40801223367453 12.839939460158348 139.24795169383287
9: 125.3135385196656 12.004016015678644 137.31755453534424
10: 123.32910076156259 11.492640004027635 134.82174076559022
11: 123.10804510302842 10.928039074409753 134.03608417743817
12: 119.35355248302221 10.344256262294948 129.69780874531716
13: 122.36590484529734 9.720181399025023 132.08608624432236
14: 119.91151796281338 9.707067292183638 129.61858525499701
15: 118.85314879007638 8.937463898211718 127.79061268828809
16: 115.90684498660266 8.596773135475814 124.

# Save model

In [10]:
# torch.save(model.state_dict(), './mixup_model_pytorch_mnist')
# model = models.resnet18(pretrained=False)
# model.conv1 = torch.nn.Conv2d(1, 64, 7, stride=2, padding=3, bias=False)
# model.fc = torch.nn.Linear(512, 10)
# model.load_state_dict(torch.load('./mixup_model_pytorch_mnist'))

# Test on Test Data

In [11]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in test_loader:
        inputs, labels = data
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        outputs = model(inputs)
        _, predicts = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9932


# Test on Train Data

In [12]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data in train_loader:
        inputs, labels = data
        inputs = inputs.to('cuda')
        labels = labels.to('cuda')
        outputs = model(inputs)
        _, predicts = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicts == labels).sum().item()
print(correct / total)

0.9999666666666667
