# Import required modules

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Mounting drive
## Uncomment if you want to save the models to your drive

In [2]:
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
# PATH = "/path/to/directory"
# def save_model(model_name, file_path, model):
#   model_save_name = model_name
#   file_path += f"{model_name}.pt"
#   torch.save(model.state_dict(), file_path)

# def load_model(model_name, file_path, model):
#   file_path += f"{model_name}.pt"
#   model.load_state_dict(torch.load(file_path))

In [None]:
# Prefer the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(dev)
print(device)

# Mixer class
This serves as an interface to get the training and testing datasets as well as mixing them.

In [5]:
import random
import math

class Mixer:
    """Interface to the CIFAR-10 train/test sets and to between-class (BC / BC+) mixing.

    Constructing a Mixer downloads CIFAR-10 into ``./data`` (if it is not there yet) and
    prepares two pipelines:

    * ``transform``    - normalise, flip, pad by 4 and crop back to 32x32 (augmentation);
    * ``no_transform`` - normalise only.
    """

    # CIFAR-10 always has 10 classes. Passing this to one_hot keeps the label width fixed
    # instead of letting it be inferred from the largest label that happens to be present.
    NUM_CLASSES = 10

    def __init__(self):
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

        #This is the data augmentation as described in the original paper
        # NOTE(review): RandomHorizontalFlip(1) flips *every* image (p=1). Presumably this is
        # deliberate so the augmented copy always differs from the original it is stored
        # next to - confirm against the intended augmentation.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
            transforms.RandomHorizontalFlip(1),
            transforms.Pad(4),
            transforms.RandomCrop(32),
        ])

        # Un-augmented images share the normalisation of the augmented ones, so that training
        # data, test data and the `/ 2 + 0.5` de-normalisation used by the preview plots all
        # live on the same [-1, 1] scale.
        self.no_transform = transforms.Compose([transforms.ToTensor(), normalize])
        self.cifar_train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
        self.cifar_test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True)

    @staticmethod
    def _sample_ratio(rng=None):
        """Draw a mixing ratio in (0.5, 1] from ``rng`` (the global ``random`` module if None)."""
        source = random if rng is None else rng
        ratio = source.uniform(0.5, 1)
        # Redraw on exactly 0.5: the first image must dominate so it alone defines the label.
        while ratio == 0.5:
            ratio = source.uniform(0.5, 1)
        return ratio

    #Naive between-class mixing method (BC)
    def mix_images_bc(self, img1, img2, rng=None):
        """Return ``(ratio * img1 + (1 - ratio) * img2, ratio)`` for a random ratio in (0.5, 1].

        ``rng`` is an optional ``random.Random``-like object; by default the global
        ``random`` module is used.
        """
        ratio = self._sample_ratio(rng)
        mixed_image = img1 * ratio + img2 * (1 - ratio)

        return mixed_image, ratio

    # Advanced between-class mixing method (BC+)
    def mix_images_bc_plus(self, img1, img2, rng=None):
        """Mix two images after removing their means, weighting by their standard deviations.

        Returns ``(mixed_image, ratio)`` with ratio drawn in (0.5, 1]. ``rng`` is an optional
        ``random.Random``-like object; by default the global ``random`` module is used.
        """
        ratio = self._sample_ratio(rng)

        i1_mean, i2_mean = torch.mean(img1), torch.mean(img2)
        # Clamp the standard deviations away from zero: two constant images would otherwise
        # give 0 / 0 = NaN for p and poison the whole mixed image.
        eps = 1e-8
        i1_std = torch.std(img1).clamp_min(eps)
        i2_std = torch.std(img2).clamp_min(eps)
        p = 1 / (1 + (i1_std / i2_std) * ((1 - ratio) / ratio))

        # Dividing by sqrt(p^2 + (1 - p)^2) rescales the weighted sum of the two
        # zero-mean images.
        norm = math.sqrt(float(p ** 2 + (1 - p) ** 2))
        mixed_image = (p * (img1 - i1_mean) + (1 - p) * (img2 - i2_mean)) / norm

        return mixed_image, ratio

    #Returns test dataset to users. Default transformation is False as we want to test accuracy on the original CIFAR-10 dataset
    def get_test_dataset(self, transform = False):
        """Return ``(images, one_hot_labels)`` for the CIFAR-10 test split.

        ``images`` has shape (N, 3, 32, 32); augmentation is applied only when
        ``transform`` is true.
        """
        pipeline = self.transform if transform else self.no_transform
        testing_size = len(self.cifar_test_set.data)
        # float64 storage is kept for backward compatibility with existing callers.
        original_test_set = torch.empty((testing_size, 3, 32, 32), dtype=torch.float64)

        for i, data in enumerate(self.cifar_test_set.data):
            original_test_set[i] = pipeline(Image.fromarray(data))

        labels = torch.tensor(list(self.cifar_test_set.targets))
        return original_test_set, F.one_hot(labels, num_classes=self.NUM_CLASSES)

    #Returns train dataset to users. Default transformation is True as we want to train with augmented data
    def get_train_dataset(self, transform = True):
        """Return ``(images, one_hot_labels)`` for the CIFAR-10 training split.

        With ``transform=True`` every training image contributes two consecutive entries:
        its augmented copy followed by the un-augmented image (so the result is twice as
        long). With ``transform=False`` only the un-augmented images are returned.
        """
        original_train_set = []
        original_train_set_label = []

        for data, target in zip(self.cifar_train_set.data, self.cifar_train_set.targets):
            image = Image.fromarray(data)

            if transform:
                original_train_set.append(self.transform(image))
                original_train_set_label.append(target)

            original_train_set.append(self.no_transform(image))
            original_train_set_label.append(target)

        labels = torch.tensor(original_train_set_label)
        return torch.stack(original_train_set), F.one_hot(labels, num_classes=self.NUM_CLASSES)

    #Core logic for mixing image_dataset, input images and labels returned from get_train_dataset
    #Default algo is set to bc mixing, input "bc+" for advanced mixing
    def mix_image_dataset(self, images, labels, algo="bc", max_offset=100, seed=10):
        """Build a mixed dataset with (at most) as many samples as ``images``.

        Image ``i`` is paired with image ``i + offset`` for offset = 1, 2, ...,
        ``max_offset - 1``; pairs whose labels are identical are skipped. Each mixed sample
        keeps the label of the first (dominant) image of its pair.

        Args:
            images: tensor whose first dimension indexes the samples.
            labels: tensor of per-sample labels, indexed like ``images``.
            algo: ``"bc+"`` selects BC+ mixing; any other value selects plain BC mixing.
            max_offset: exclusive upper bound on the pairing offset.
            seed: seed of the private RNG used for the mixing ratios. A private
                ``random.Random`` is used so the global ``random`` state is left untouched
                while the ratio sequence stays reproducible.

        Returns:
            ``(mixed_images, mixed_labels, ratios)`` as stacked tensors.

        Raises:
            ValueError: if no pair with differing labels exists.
        """
        mixing_algo = self.mix_images_bc_plus if algo == "bc+" else self.mix_images_bc
        rng = random.Random(seed)
        n_images = images.shape[0]

        mixed_images = []
        mix_labels = []
        mix_ratios = []

        for offset in range(1, max_offset):
            if len(mixed_images) >= n_images:
                break
            # i + offset must stay inside the dataset, hence the shortened range.
            for i in range(n_images - offset):
                if len(mixed_images) >= n_images:
                    break
                partner = i + offset
                if torch.equal(labels[i], labels[partner]):
                    continue
                m_i, ratio = mixing_algo(images[i], images[partner], rng=rng)
                mixed_images.append(m_i)
                mix_labels.append(labels[i])
                mix_ratios.append(ratio)

        if not mixed_images:
            raise ValueError("mix_image_dataset found no pair of images with different labels")

        return torch.stack(mixed_images), torch.stack(mix_labels), torch.tensor(mix_ratios)
            
                

In [6]:
# CIFAR-10 class names; the notebook indexes this tuple with the position of the
# non-zero entry of a one-hot label.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
#Setting up the datasets that will be used
# Constructing the Mixer downloads CIFAR-10 into ./data when it is not already present.
mixer = Mixer()

# Un-augmented test and training sets (transform=False), used for evaluation and as the baseline.
original_test_set, original_test_set_label = mixer.get_test_dataset(transform=False)
original_train_set, original_train_set_label = mixer.get_train_dataset(transform=False)

# Default transform=True: every training image contributes an augmented copy followed by
# the un-augmented image, so this set is twice as long as original_train_set.
mixing_train_set, mixing_train_set_label = mixer.get_train_dataset()

# Both mixed sets are built from the same source images. mix_image_dataset draws its ratios
# from the fixed seed 10 on every call, so the BC and BC+ variants use the same ratio sequence.
mixed_train_images, mixed_train_labels, mix_train_ratios = mixer.mix_image_dataset(mixing_train_set, mixing_train_set_label, algo="bc")
plus_mixed_train_images, plus_mixed_train_labels, plus_mixed_train_ratios = mixer.mix_image_dataset(mixing_train_set, mixing_train_set_label, algo="bc+") 

# Preview of the images from BC mixing, BC+ mixing and original images from CIFAR-10

In [None]:
#Preview of the images within our mixed images

def show_image_grid(images, labels, title, ratios=None, rows=3, cols=6, row_stride=8):
    """Show a ``rows`` x ``cols`` grid of images with their class name as caption.

    Args:
        images: tensor of CHW images indexed along the first dimension.
        labels: one-hot labels aligned with ``images``.
        title: figure title.
        ratios: optional per-image mixing ratios; when given they are added to the caption.
        rows, cols: grid size.
        row_stride: the image shown at grid cell (i, j) is ``images[i * row_stride + j]``.
    """
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    fig.set_dpi(160)
    fig.suptitle(title)
    for i in range(rows):
        for j in range(cols):
            r_index = i * row_stride + j
            # Maps [-1, 1] back to [0, 1]; assumes the images were normalised with
            # mean 0.5 / std 0.5 - NOTE(review): confirm this holds for every set passed in.
            # Clamping reproduces what imshow would clip anyway, without the warning.
            img = (images[r_index] / 2 + 0.5).clamp(0, 1)
            # imshow expects HWC NumPy data, the tensors are CHW.
            axs[i, j].imshow(img.permute(1, 2, 0).detach().cpu().numpy())
            axs[i, j].axis('off')
            class_index = int(torch.nonzero(labels[r_index])[0])
            if ratios is None:
                caption = "{} \n".format(classes[class_index])
            else:
                caption = "{} \n ratio: {:.4f}".format(classes[class_index], float(ratios[r_index]))
            axs[i, j].set_title(caption, fontdict={'fontsize': 6})

    plt.show()


#BC naive mixing method plot
show_image_grid(mixed_train_images, mixed_train_labels, "BC naive mixing method", ratios=mix_train_ratios)

#BC+ mixing method plot
show_image_grid(plus_mixed_train_images, plus_mixed_train_labels, "BC+ mixing method", ratios=plus_mixed_train_ratios)

#Original CIFAR-10 images plot
show_image_grid(original_train_set, original_train_set_label, "Original training dataset")

# DatasetIterator class
This class wraps the image and label tensors so they can be iterated over; an instance is passed as the dataset to the PyTorch `DataLoader` class.

In [9]:
from torch.utils.data import Dataset

class DatasetIterator(Dataset):
    """Map-style dataset that pairs ``images[index]`` with ``labels[index]``.

    Both tensors are kept as given; each requested item is converted to a NumPy array
    on access.
    """

    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        # The number of samples is the size of the leading image dimension.
        return self.images.shape[0]

    def __getitem__(self, index):
        image, label = self.images[index], self.labels[index]
        # Detach and move to host memory so that .numpy() is valid for any input tensor.
        image_array = image.detach().cpu().numpy()
        label_array = label.detach().cpu().numpy()
        return image_array, label_array

# ModelManager class
This class is used for managing the models trained in this experiment.

Inputs:
* **classifier**: The model that will be trained and tested
* **optimizer**: The optimizer that will be used during training
* **l_function**: The loss function that will be used during training. Defaults to the built-in `kl_divergence`.

***Highly recommended to input a loss function as kl_divergence is very unstable and could result in the model returning NaN***



In [10]:
class ModelManager:
    """Trains and evaluates one classifier and records its training-loss curve.

    Args:
        classifier: the ``torch.nn.Module`` to train and test.
        optimizer: optimizer already bound to the classifier's parameters.
        l_function: callable ``(output, targets) -> loss``; defaults to ``kl_divergence``.

    Attributes:
        train_losses: loss of every 100th batch, in the order it was logged.
        train_counter: number of samples processed when the matching loss was logged.
        last_test_accuracy: accuracy (in percent) of the latest ``test_model`` call,
            or None before the first call.
    """

    def __init__(self, classifier, optimizer, l_function = None):
        self.classifier = classifier
        self.optimizer = optimizer
        self.l_function = l_function if l_function is not None else self.kl_divergence
        self.train_losses = []
        self.train_counter = []
        self.last_test_accuracy = None

    def _model_device(self):
        """Device of the classifier's parameters (CPU for a parameter-free model)."""
        # Following the model avoids relying on a notebook-level `device` global and
        # guarantees inputs land where the weights are.
        try:
            return next(self.classifier.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    #KL-divergence that was implemented in the original paper
    def kl_divergence(self, pred, true):
        """Batch-mean KL divergence between target distributions and softmax(pred).

        ``pred`` holds raw logits of shape (batch, classes); ``true`` holds the target
        probabilities with the same shape.
        """
        # Only non-zero targets contribute to the entropy; skipping zeros avoids 0 * log(0).
        support = true[torch.nonzero(true, as_tuple=True)]
        entropy = -1 * torch.sum(support * torch.log(support))
        # dim=1 is the class axis; leaving it implicit is deprecated.
        crossEntropy = -1 * torch.sum(true * torch.nn.functional.log_softmax(pred, dim=1))
        return (crossEntropy - entropy) / pred.shape[0]

    def train_model(self, train_loader, epoch):
        """Run one training epoch over ``train_loader``.

        ``epoch`` is 1-based; it is used for logging and to offset ``train_counter``.
        """
        device = self._model_device()
        dataset_size = len(train_loader.dataset)
        weighted_loss_sum = 0.0
        samples_seen = 0

        self.classifier.train()
        for batch_idx, (images, targets) in enumerate(train_loader):
            images = images.float().to(device)
            targets = targets.float().to(device)

            self.optimizer.zero_grad()
            output = self.classifier(images)
            loss = self.l_function(output, targets)
            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item()  # item() is to get the value of the tensor directly
            if batch_idx % 100 == 0:
                self.train_losses.append(batch_loss)
                # Count the samples actually processed instead of assuming a batch size.
                self.train_counter.append(samples_seen + (epoch - 1) * dataset_size)
                print(f'Epoch {epoch}: [{samples_seen}/{dataset_size}] Loss: {batch_loss}')

            # Each loss is a batch mean, so weight it by the batch size to obtain a
            # true per-sample average at the end of the epoch.
            weighted_loss_sum += batch_loss * len(images)
            samples_seen += len(images)

        average_loss = weighted_loss_sum / samples_seen if samples_seen else float("nan")
        print(f"Average loss for Epoch {epoch}: {average_loss}")

    def test_model(self, test_loader):
        """Print the top-1 accuracy on ``test_loader`` and store it in ``last_test_accuracy``.

        Labels are expected to be one-hot; the position of their maximum is the true class.

        Raises:
            ValueError: if ``test_loader`` yields no samples.
        """
        device = self._model_device()
        correct, total = 0, 0

        self.classifier.eval()
        with torch.no_grad():
            for images, labels in test_loader:
                images = images.float().to(device)
                labels = labels.to(device)
                predicted = self.classifier(images).argmax(dim=1)
                expected = labels.argmax(dim=1)
                total += expected.shape[0]
                correct += int((predicted == expected).sum())

        if total == 0:
            raise ValueError("test_loader yielded no samples")

        self.last_test_accuracy = 100 * correct / total
        print('Accuracy of the network on the test images: %d %%' % self.last_test_accuracy)

# Pretrained Resnet-18 Model

In [11]:
import torchvision.models as models
import torch.optim as optim
import torch.nn as nn

In [12]:
#Training a model now
def get_resnet18_model(pretrained=True, num_classes=None):
  """Build a ResNet-18 whose final layer is replaced by a fresh classification head.

  Args:
      pretrained: forwarded to ``torchvision.models.resnet18``; True loads the
          pretrained weights, False gives a randomly initialised network.
      num_classes: size of the new output layer. When None, the width of the one-hot
          labels in the notebook-level ``mixed_train_labels`` is used (original behaviour).

  Returns:
      The float32 model with every listed stage marked trainable.
  """
  if num_classes is None:
    num_classes = mixed_train_labels[0].shape[0]

  # NOTE(review): newer torchvision releases prefer the `weights=` argument over
  # `pretrained=`; kept as-is because the installed version is not visible here.
  resnet18 = models.resnet18(pretrained = pretrained)
  resnet18.fc = nn.Linear(resnet18.fc.in_features, num_classes)

  # Assigning `module.requires_grad = True` only creates a plain attribute on the module;
  # requires_grad_() is what actually flags the parameters as trainable.
  for stage in (resnet18.conv1, resnet18.layer1, resnet18.layer2, resnet18.layer3, resnet18.layer4):
    stage.requires_grad_(True)

  return resnet18.float()

# Convnet Model from Paper

In [13]:
class ConvBNReLU(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block."""

    def __init__(self, in_channels, out_channels, ksize, stride = 1, pad = 0, bias=False):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, ksize, stride, pad, bias=bias)
        self.bn = torch.nn.BatchNorm2d(out_channels, eps=1e-5)

    def forward(self, x):
        return torch.relu(self.bn(self.conv(x)))


class ConvNet(nn.Module):
    """Convolutional classifier: three conv stages (64, 128, 256 channels) and three FC layers.

    Each stage ends with a 2x2 max-pool; ``fc4`` expects 256 * 4 * 4 features, i.e.
    32x32 inputs. ``forward`` returns raw logits of shape (batch, n_classes).
    """

    def __init__(self, n_classes):
        super().__init__()
        self.conv11 = ConvBNReLU(3, 64, 3, pad = 1)
        self.conv12 = ConvBNReLU(64, 64, 3, pad = 1)
        self.conv21 = ConvBNReLU(64, 128, 3, pad = 1)
        self.conv22 = ConvBNReLU(128, 128, 3, pad = 1)
        self.conv31 = ConvBNReLU(128, 256, 3, pad = 1)
        self.conv32 = ConvBNReLU(256, 256, 3, pad = 1)
        self.conv33 = ConvBNReLU(256, 256, 3, pad = 1)
        self.conv34 = ConvBNReLU(256, 256, 3, pad = 1)
        self.fc4 = torch.nn.Linear(256*4*4, 1024)
        self.fc5 = torch.nn.Linear(1024, 1024)
        self.fc6 = torch.nn.Linear(1024, n_classes)

    def forward(self, x):
        output = self.conv11(x)
        output = self.conv12(output)
        output = nn.functional.max_pool2d(output, 2)

        output = self.conv21(output)
        output = self.conv22(output)
        output = nn.functional.max_pool2d(output, 2)

        output = self.conv31(output)
        output = self.conv32(output)
        output = self.conv33(output)
        output = self.conv34(output)
        output = nn.functional.max_pool2d(output, 2)
        output = output.reshape(-1, 256*4*4)

        # functional.dropout defaults to training=True, which would keep dropping units
        # after model.eval(); tie it to the module's mode so evaluation is deterministic.
        output = nn.functional.dropout(nn.functional.relu(self.fc4(output)), training=self.training)
        output = nn.functional.dropout(nn.functional.relu(self.fc5(output)), training=self.training)

        return self.fc6(output)

# Shake Shake Regularization 
(https://notebook.community/t-vi/pytorch-tvmisc/misc/cifar10-shake-shake)

In [14]:
from torch.autograd import Variable
# shakeshake net leaning heavily on the original torch implementation https://github.com/xgastaldi/shake-shake/
class ShakeShakeBlock2d(torch.nn.Module):
    """Residual block with two parallel branches combined by random (shake-shake) weights.

    In training mode the branch outputs are blended with a random weight ``alpha`` in the
    forward pass and an independent random weight ``beta`` in the backward pass; in eval
    mode the two branches are simply averaged. The result is added to a skip connection.
    """
    def __init__(self, in_channels, out_channels, stride, per_image=True, rand_forward=True, rand_backward=True):
        """Create the two branches and, if the width changes, the skip projection.

        Args:
            in_channels: channels of the input tensor.
            out_channels: channels produced by both branches and by the skip path.
            stride: stride of the first convolution of each branch and of the skip subsampling.
            per_image: draw one random weight per sample (True) or one per batch (False).
            rand_forward: use a random ``alpha`` in the forward pass (otherwise 0.5).
            rand_backward: use a random ``beta`` in the backward pass (otherwise 0.5).
        """
        super().__init__()
        # The identity skip is only used when the channel count is unchanged.
        # NOTE(review): with same_width and stride != 1 the identity skip would not match
        # the branch output size; the network in this notebook never builds that combination.
        self.same_width = (in_channels==out_channels)
        self.per_image = per_image
        self.rand_forward = rand_forward
        self.rand_backward = rand_backward
        self.stride = stride
        # Two identically structured, independently initialised branches:
        # ReLU -> 3x3 conv (strided) -> BN -> ReLU -> 3x3 conv -> BN.
        self.net1, self.net2 = [torch.nn.Sequential(
                        torch.nn.ReLU(),
                        torch.nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1),
                        torch.nn.BatchNorm2d(out_channels),
                        torch.nn.ReLU(),
                        torch.nn.Conv2d(out_channels, out_channels, 3, padding=1),
                        torch.nn.BatchNorm2d(out_channels)) for i in range(2)]
        if not self.same_width:
            # Each 1x1 conv yields half of the output channels; the halves are concatenated in forward().
            self.skip_conv1 = torch.nn.Conv2d(in_channels, out_channels//2, 1)
            self.skip_conv2 = torch.nn.Conv2d(in_channels, out_channels//2, 1)
            self.skip_bn = torch.nn.BatchNorm2d(out_channels)
    def forward(self, inp):
        """Return ``skip(inp)`` plus the shake-shake combination of the two branch outputs."""
        if self.same_width:
            skip = inp
        else:
            # double check, this seems to be a fancy way to throw away the top-right and bottom-left of each 2x2 patch (with stride=2)
            # A 1x1 average pool with a stride is plain subsampling of the input.
            x1 = torch.nn.functional.avg_pool2d(inp, 1, stride=self.stride)
            x1 = self.skip_conv1(x1)
            # Pad one pixel on the left/top and crop one on the right/bottom: the input is
            # shifted by one pixel before being subsampled.
            x2 = torch.nn.functional.pad(inp, (1,-1,1,-1))            # this makes the top and leftmost row 0. one could use -1,1
            x2 = torch.nn.functional.avg_pool2d(x2, 1, stride=self.stride)
            x2 = self.skip_conv2(x2)
            skip = torch.cat((x1,x2), dim=1)
            skip = self.skip_bn(skip)
        x1 = self.net1(inp)
        x2 = self.net2(inp)

        if self.training:
            # alpha weights the branches in the forward pass.
            if self.rand_forward:
                if self.per_image:
                    alpha = Variable(inp.data.new(inp.size(0),1,1,1).uniform_())
                else:
                    alpha = Variable(inp.data.new(1,1,1,1).uniform_())
            else:
                alpha = 0.5
            # beta weights the branches in the backward pass.
            if self.rand_backward:
                if self.per_image:
                    beta = Variable(inp.data.new(inp.size(0),1,1,1).uniform_())
                else:
                    beta = Variable(inp.data.new(1,1,1,1).uniform_())
            else:
                beta = 0.5
            # this is the trick to get beta in the backward (because it does not see the detached)
            # and alpha in the forward (when it sees the detached with the alpha and the beta cancel)
            # Forward value: skip + alpha*x1 + (1-alpha)*x2; gradients flow only through the beta terms.
            x = skip+beta*x1+(1-beta)*x2+((alpha-beta)*x1).detach()+((beta-alpha)*x2).detach()
        else:
            # Evaluation: plain average of the two branches.
            x = skip+0.5*(x1+x2)
        return x

            
class ShakeShakeBlocks2d(torch.nn.Sequential):
    """A stage of ``depth`` ShakeShakeBlock2d modules applied one after another.

    Only the first block changes the channel count (in_channels -> out_channels) and uses
    ``stride``; every following block maps out_channels -> out_channels with stride 1.
    """
    def __init__(self, in_channels, out_channels, depth, stride, per_image=True, rand_forward=True, rand_backward=True):
        stage = []
        for position in range(depth):
            leads_stage = position == 0
            stage.append(ShakeShakeBlock2d(
                in_channels if leads_stage else out_channels,
                out_channels,
                stride if leads_stage else 1,
                per_image, rand_forward, rand_backward))
        super().__init__(*stage)

class ShakeShakeNet(torch.nn.Module):
    """Shake-shake network: input conv + BN, three shake-shake stages, pooling and a linear head.

    ``depth`` must be of the form 6*n + 2; each of the three stages then holds n blocks.
    The stages use ``basewidth``, 2*``basewidth`` and 4*``basewidth`` channels, and the second
    and third stage are built with stride 2.
    """
    def __init__(self, depth=20, basewidth=32, per_image=True, rand_forward=True, rand_backward=True, num_classes=16):
        super().__init__()
        blocks_per_stage, leftover = divmod(depth - 2, 6)
        assert leftover == 0, "depth should be n*6+2"
        shake_options = (per_image, rand_forward, rand_backward)

        self.inconv = torch.nn.Conv2d(3, 16, 3, padding=1)
        self.bn1 = torch.nn.BatchNorm2d(16)
        self.s1 = ShakeShakeBlocks2d(16, basewidth, blocks_per_stage, 1, *shake_options)
        self.s2 = ShakeShakeBlocks2d(basewidth, 2*basewidth, blocks_per_stage, 2, *shake_options)
        self.s3 = ShakeShakeBlocks2d(2*basewidth, 4*basewidth, blocks_per_stage, 2, *shake_options)
        self.fc = torch.nn.Linear(4*basewidth, num_classes)
    def forward(self, x):
        features = self.bn1(self.inconv(x))
        for stage in (self.s1, self.s2, self.s3):
            features = stage(features)
        features = torch.nn.functional.relu(features)
        # Average over all spatial positions: (batch, channels, H, W) -> (batch, channels).
        pooled = features.view(features.size(0), features.size(1), -1).mean(2)
        return self.fc(pooled)

# BC Mixing algorithm training

In [15]:
# Training hyper-parameters for the BC runs: number of epochs and mini-batch size.
EPOCH, BATCH_SIZE = 20, 128

In [16]:
# Aliases for the BC-mixed images and labels produced by mixer.mix_image_dataset(algo="bc").
bc_mixed_train_images, bc_mixed_train_labels = mixed_train_images, mixed_train_labels

#Preparing train-loader of BC mixed dataset for pytorch DataLoader class
# shuffle=True: the mixed samples are visited in a new random order every epoch.
bc_mixed_cifar10_dataset = DatasetIterator(bc_mixed_train_images, bc_mixed_train_labels)
bc_mixed_train_loader = torch.utils.data.DataLoader(bc_mixed_cifar10_dataset, batch_size=BATCH_SIZE, shuffle=True)

#Preparing test-loader for pytorch DataLoader class
# The test loader wraps the un-augmented CIFAR-10 test set and keeps its order (shuffle=False).
cifar10_test_dataset = DatasetIterator(original_test_set, original_test_set_label)
cifar10_test_loader = torch.utils.data.DataLoader(cifar10_test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## Training resnet-18 with BC mixed CIFAR-10 dataset

In [None]:
# TRAINING RESNET18 WITH BC MIXING ALGORITHM
epochs = EPOCH

# train_model casts the one-hot targets to float before calling the loss.
# NOTE(review): CrossEntropyLoss only accepts such probability-style targets in recent
# PyTorch releases - confirm the installed version supports it.
l_function = torch.nn.CrossEntropyLoss()
# l_function = None #Uncomment to use KL Divergence Loss Function

resnet18_bc = get_resnet18_model()
# The optimizer is bound to the parameters before the model is moved; .to(device) below
# moves those same parameter objects, so the binding stays valid.
optimizer = optim.SGD(resnet18_bc.parameters(), lr=0.0001, momentum=0.8)
# optimizer = optim.Adam(resnet18_bc.parameters(), lr = 0.0001) #Uncomment for Adam optimizer

resnet18_bc = resnet18_bc.to(device)

resnet18_bc_model_manager = ModelManager(resnet18_bc, optimizer, l_function)

# Epochs are numbered from 1; train_model uses the number for logging and the loss counter.
for e in range(1, epochs + 1):
  resnet18_bc_model_manager.train_model(bc_mixed_train_loader, e)

# Uncomment if you want to save the model
# resnet18_bc = resnet18_bc.to('cpu')
# save_model('resnet18_bc', PATH, resnet18_bc)
# resnet18_bc = resnet18_bc.to(device)

## Training convnet with BC mixed CIFAR-10 dataset

In [None]:
# TRAINING convnet WITH BC MIXING ALGORITHM
epochs = EPOCH

# ConvNet(10): one output logit per CIFAR-10 class.
convnet_bc = ConvNet(10)
# The optimizer is bound before the model is moved; .to(device) below moves the same
# parameter objects, so the binding stays valid.
optimizer = optim.SGD(convnet_bc.parameters(), lr=0.01, momentum=0.8)
# optimizer = optim.Adam(convnet_bc.parameters(), lr = 0.01) #Uncomment for Adam optimizer

l_function = torch.nn.CrossEntropyLoss()
# l_function = None #Uncomment to use KL Divergence Loss Function

convnet_bc = convnet_bc.to(device)

convnet_bc_model_manager = ModelManager(convnet_bc, optimizer, l_function)

# Epochs are numbered from 1; train_model uses the number for logging and the loss counter.
for e in range(1, epochs + 1):
  convnet_bc_model_manager.train_model(bc_mixed_train_loader, e)

# Uncomment if you want to save the model
# convnet_bc = convnet_bc.to('cpu')
# save_model('convnet_bc', PATH, convnet_bc)
# convnet_bc = convnet_bc.to(device)


## Training Shake-Shake Regularization with BC mixed CIFAR-10 dataset

In [None]:
# TRAINING Shake-Shake Regularization WITH BC MIXING ALGORITHM
epochs = EPOCH

l_function = torch.nn.CrossEntropyLoss()
# l_function = None #Uncomment to use KL Divergence Loss Function

# depth=8 satisfies depth = 6*n + 2 with n = 1, i.e. one shake-shake block per stage.
shake_bc = ShakeShakeNet(num_classes=10, depth=8)
# The optimizer is bound before the model is moved; .to(device) below moves the same
# parameter objects, so the binding stays valid.
optimizer = optim.SGD(shake_bc.parameters(), lr=0.01, momentum=0.8)
# optimizer = optim.Adam(shake_bc.parameters(), lr = 0.01) #Uncomment for Adam optimizer

shake_bc = shake_bc.to(device)

shake_bc_model_manager = ModelManager(shake_bc, optimizer, l_function)

# Epochs are numbered from 1; train_model uses the number for logging and the loss counter.
for e in range(1, epochs + 1):
  shake_bc_model_manager.train_model(bc_mixed_train_loader, e)

# Uncomment if you want to save the model
# shake_bc = shake_bc.to('cpu')
# save_model('shake_bc', PATH, shake_bc)
# shake_bc = shake_bc.to(device)

## Testing accuracy of the models trained on the BC mixed CIFAR-10 dataset

In [None]:
# Evaluate every BC-trained model on the same un-augmented test loader, one heading per model.
for heading, manager in (
    ("Testing of BC Pretrained Resnet18 with dataset without transformation:", resnet18_bc_model_manager),
    ("\nTesting of BC convnet with dataset without transformation:", convnet_bc_model_manager),
    ("\nTesting of BC shake-shake with dataset without transformation:", shake_bc_model_manager),
):
    print(heading)
    manager.test_model(cifar10_test_loader)

## Visualization of Training Loss of BC algorithm between Resnet18, Convnet and Shake-Shake Regularization

In [None]:
# Scatter plot of the training losses each ModelManager logged for the BC runs.
fig = plt.figure(figsize=(8,6))
ax1 = fig.add_subplot(111)

# The ResNet-18 counter is used as the x-axis for all three loss series.
bc_train_counter = resnet18_bc_model_manager.train_counter
resnet18_bc_train_loss = resnet18_bc_model_manager.train_losses
convnet_bc_train_loss = convnet_bc_model_manager.train_losses
shake_bc_train_loss = shake_bc_model_manager.train_losses

for series, look in (
    (resnet18_bc_train_loss, dict(c='b', marker="o", label='resnet18')),
    (convnet_bc_train_loss, dict(s=10, c='r', marker="x", label='convnet')),
    (shake_bc_train_loss, dict(s=10, c='g', marker="s", label='shake-shake')),
):
    ax1.scatter(bc_train_counter, series, **look)

ax1.legend(loc='upper right')
ax1.set_title('Loss of model trained with BC algorithm')
ax1.set_xlabel('Training Counter')
ax1.set_ylabel('Training Loss value')
plt.show()

# BC+ Mixing algorithm training

In [22]:
# Training hyper-parameters for the BC+ runs: number of epochs and mini-batch size.
EPOCH, BATCH_SIZE = 20, 128

In [23]:
# Aliases for the BC+-mixed images and labels produced by mixer.mix_image_dataset(algo="bc+").
bc_plus_mixed_train_images, bc_plus_mixed_train_labels = plus_mixed_train_images, plus_mixed_train_labels

#Preparing train-loader of BC+ mixed dataset for pytorch DataLoader class
# shuffle=True: the mixed samples are visited in a new random order every epoch.
bc_plus_mixed_cifar10_dataset = DatasetIterator(bc_plus_mixed_train_images, bc_plus_mixed_train_labels)
bc_plus_mixed_train_loader = torch.utils.data.DataLoader(bc_plus_mixed_cifar10_dataset, batch_size=BATCH_SIZE, shuffle=True)

#Preparing test-loader of original test dataset for pytorch DataLoader class
# This rebinds cifar10_test_dataset / cifar10_test_loader to equivalent objects over the same test tensors.
cifar10_test_dataset = DatasetIterator(original_test_set, original_test_set_label)
cifar10_test_loader = torch.utils.data.DataLoader(cifar10_test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## Training resnet-18 with BC+ mixed CIFAR-10 dataset

In [None]:
# TRAINING RESNET18 WITH BC+ MIXING ALGORITHM
epochs = EPOCH

# train_model casts the one-hot targets to float before calling the loss.
l_function = torch.nn.CrossEntropyLoss()
# l_function = None #Uncomment to use KL Divergence Loss Function

resnet18_bc_plus = get_resnet18_model()
# The optimizer is bound before the model is moved; .to(device) below moves the same
# parameter objects, so the binding stays valid.
optimizer = optim.SGD(resnet18_bc_plus.parameters(), lr=0.0001, momentum=0.8)
# optimizer = optim.Adam(resnet18_bc_plus.parameters(), lr = 0.0001) #Uncomment for Adam optimizer

resnet18_bc_plus = resnet18_bc_plus.to(device)

resnet18_bc_plus_model_manager = ModelManager(resnet18_bc_plus, optimizer, l_function)

# Epochs are numbered from 1; train_model uses the number for logging and the loss counter.
for e in range(1, epochs + 1):
  resnet18_bc_plus_model_manager.train_model(bc_plus_mixed_train_loader, e)

# Uncomment if you want to save the model
# resnet18_bc_plus = resnet18_bc_plus.to('cpu')
# save_model('resnet18_bc_plus', PATH, resnet18_bc_plus)
# resnet18_bc_plus = resnet18_bc_plus.to(device)

## Training Convnet with BC+ mixed CIFAR-10 dataset

In [None]:
# TRAINING convnet WITH BC+ MIXING ALGORITHM
epochs = EPOCH

# ConvNet(10): one output logit per CIFAR-10 class.
convnet_bc_plus = ConvNet(10)
# The optimizer is bound before the model is moved; .to(device) below moves the same
# parameter objects, so the binding stays valid.
optimizer = optim.SGD(convnet_bc_plus.parameters(), lr=0.01, momentum=0.8)
# optimizer = optim.Adam(convnet_bc_plus.parameters(), lr = 0.01) #Uncomment for Adam optimizer

l_function = torch.nn.CrossEntropyLoss()
# l_function = None #Uncomment to use KL Divergence Loss Function

convnet_bc_plus = convnet_bc_plus.to(device)

convnet_bc_plus_model_manager = ModelManager(convnet_bc_plus, optimizer, l_function)

# Epochs are numbered from 1; train_model uses the number for logging and the loss counter.
for e in range(1, epochs + 1):
  convnet_bc_plus_model_manager.train_model(bc_plus_mixed_train_loader, e)

# Uncomment if you want to save the model
# convnet_bc_plus = convnet_bc_plus.to('cpu')
# save_model('convnet_bc_plus', PATH, convnet_bc_plus)
# convnet_bc_plus = convnet_bc_plus.to(device)

## Training Shake-Shake Regularization with BC+ mixed CIFAR-10 dataset

In [None]:
# TRAINING shake-shake WITH BC+ MIXING ALGORITHM
epochs = EPOCH

l_function = torch.nn.CrossEntropyLoss()
# l_function = None #Uncomment to use KL Divergence Loss Function

# depth=8 satisfies depth = 6*n + 2 with n = 1, i.e. one shake-shake block per stage.
shake_bc_plus = ShakeShakeNet(num_classes=10, depth=8)
# The optimizer is bound before the model is moved; .to(device) below moves the same
# parameter objects, so the binding stays valid.
optimizer = optim.SGD(shake_bc_plus.parameters(), lr=0.01, momentum=0.8)
# optimizer = optim.Adam(shake_bc_plus.parameters(), lr = 0.01) #Uncomment for Adam optimizer

shake_bc_plus = shake_bc_plus.to(device)

shake_bc_plus_model_manager = ModelManager(shake_bc_plus, optimizer, l_function)

# Epochs are numbered from 1; train_model uses the number for logging and the loss counter.
for e in range(1, epochs + 1):
  shake_bc_plus_model_manager.train_model(bc_plus_mixed_train_loader, e)

# Uncomment if you want to save the model
# shake_bc_plus = shake_bc_plus.to('cpu')
# save_model('shake_bc_plus', PATH, shake_bc_plus)
# shake_bc_plus = shake_bc_plus.to(device)

## Testing accuracy of the models trained on the BC+ mixed CIFAR-10 dataset

In [None]:
# Evaluate every BC+-trained model on the same un-augmented test loader, one heading per model.
for heading, manager in (
    ("Testing of BC+ Pretrained Resnet18 with original dataset:", resnet18_bc_plus_model_manager),
    ("\nTesting of BC+ convnet with original dataset:", convnet_bc_plus_model_manager),
    ("\nTesting of BC+ shake-shake with original dataset:", shake_bc_plus_model_manager),
):
    print(heading)
    manager.test_model(cifar10_test_loader)

## Visualization of Training Loss of BC+ algorithm between Resnet18, Convnet and Shake-Shake Regularization

In [None]:
# Scatter plot of the training losses each ModelManager logged for the BC+ runs.
fig = plt.figure(figsize=(8,6))
ax1 = fig.add_subplot(111)

# The ResNet-18 counter is used as the x-axis for all three loss series.
bc_plus_train_counter = resnet18_bc_plus_model_manager.train_counter
resnet18_bc_plus_train_loss = resnet18_bc_plus_model_manager.train_losses
convnet_bc_plus_train_loss = convnet_bc_plus_model_manager.train_losses
shake_bc_plus_train_loss = shake_bc_plus_model_manager.train_losses

for series, look in (
    (resnet18_bc_plus_train_loss, dict(c='b', marker="o", label='resnet18')),
    (convnet_bc_plus_train_loss, dict(s=10, c='r', marker="x", label='convnet')),
    (shake_bc_plus_train_loss, dict(s=10, c='g', marker="s", label='shake-shake')),
):
    ax1.scatter(bc_plus_train_counter, series, **look)

ax1.legend(loc='upper right')
ax1.set_title('Loss of model trained with BC+ algorithm')
ax1.set_xlabel('Training Counter')
ax1.set_ylabel('Training Loss value')
plt.show()

# Training Models on non-augmented CIFAR-10 Dataset

In [29]:
# Training hyper-parameters for the baseline runs: number of epochs and mini-batch size.
EPOCH, BATCH_SIZE = 20, 128

In [30]:
#Preparing train-loader of non-augmented dataset for pytorch DataLoader class
# shuffle=True: the training samples are visited in a new random order every epoch.
original_cifar10_dataset = DatasetIterator(original_train_set, original_train_set_label)
original_train_loader = torch.utils.data.DataLoader(original_cifar10_dataset, batch_size=BATCH_SIZE, shuffle=True)

cifar10_test_dataset = DatasetIterator(original_test_set, original_test_set_label)
# Test loader with original CIFAR-10 test dataset without augmentation
# NOTE(review): this loader uses a literal batch size of 64 while the BC / BC+ sections
# use BATCH_SIZE for the test loader.
cifar10_test_loader = torch.utils.data.DataLoader(cifar10_test_dataset, batch_size=64, shuffle=False)

## Training of all 3 models

In [None]:
# TRAINING convnet, shake-shake, resnet18 WITH original dataset
epochs = EPOCH
# One loss object is shared by all three managers.
l_function = torch.nn.CrossEntropyLoss()

convnet_original = ConvNet(10)
# NOTE(review): no depth is passed here, so ShakeShakeNet uses its default depth=20, whereas
# the BC / BC+ runs use depth=8 - confirm the baseline is meant to be a deeper network.
shake_original = ShakeShakeNet(num_classes=10)
resnet18_original = get_resnet18_model()

optimizer_convnet = optim.SGD(convnet_original.parameters(), lr=0.01, momentum=0.8)
optimizer_shake = optim.SGD(shake_original.parameters(), lr=0.01, momentum=0.8)
# NOTE(review): lr=0.001 here versus lr=0.0001 for ResNet-18 in the BC / BC+ runs - confirm.
optimizer_resnet18 = optim.SGD(resnet18_original.parameters(), lr=0.001, momentum=0.8)


convnet_original = convnet_original.to(device)
shake_original = shake_original.to(device)
resnet18_original = resnet18_original.to(device)

convnet_original_model_manager = ModelManager(convnet_original, optimizer_convnet, l_function)
shake_original_model_manager = ModelManager(shake_original, optimizer_shake, l_function)
resnet18_original_model_manager = ModelManager(resnet18_original, optimizer_resnet18, l_function)

# The three models are trained interleaved: within each epoch every model makes one full
# pass over the loader before the next epoch starts.
for e in range(1, epochs + 1):
  convnet_original_model_manager.train_model(original_train_loader, e)
  shake_original_model_manager.train_model(original_train_loader, e)
  resnet18_original_model_manager.train_model(original_train_loader, e)

## Testing accuracy of the models trained on the non-augmented CIFAR-10 dataset

In [None]:
# Evaluate every baseline model on the same un-augmented test loader, one heading per model.
for heading, manager in (
    ("Testing of Pretrained Resnet18 with original dataset:", resnet18_original_model_manager),
    ("\nTesting of convnet with original dataset:", convnet_original_model_manager),
    ("\nTesting of shake-shake with original dataset:", shake_original_model_manager),
):
    print(heading)
    manager.test_model(cifar10_test_loader)

## Visualization of Training Loss without mixing between Resnet18, Convnet and Shake-Shake Regularization

In [None]:
# Scatter plot of the training losses each ModelManager logged for the baseline runs.
fig = plt.figure(figsize=(8,6))
ax1 = fig.add_subplot(111)

# The ResNet-18 counter is used as the x-axis for all three loss series.
original_train_counter = resnet18_original_model_manager.train_counter
resnet18_original_train_loss = resnet18_original_model_manager.train_losses
convnet_original_train_loss = convnet_original_model_manager.train_losses
shake_original_train_loss = shake_original_model_manager.train_losses

for series, look in (
    (resnet18_original_train_loss, dict(c='b', marker="o", label='resnet18')),
    (convnet_original_train_loss, dict(s=10, c='r', marker="x", label='convnet')),
    (shake_original_train_loss, dict(s=10, c='g', marker="s", label='shake-shake')),
):
    ax1.scatter(original_train_counter, series, **look)

ax1.legend(loc='upper right')
ax1.set_title('Loss of model trained with original dataset')
ax1.set_xlabel('Training Counter')
ax1.set_ylabel('Training Loss value')
plt.show()