**Installing TPU Client for Pytorch**

This notebook is used to train the Models and is optimised for using PyTorch on Clout TPUs. For more information on how to use TPU on google colab, please see https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb    

In [None]:
!pip uninstall -y torch

!pip install torch==1.8.2+cpu -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
# !pip install torch==1.10.0 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html

!pip install -q cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl


Found existing installation: torch 1.8.2+cpu
Uninstalling torch-1.8.2+cpu:
  Successfully uninstalled torch-1.8.2+cpu
Looking in links: https://download.pytorch.org/whl/lts/1.8/torch_lts.html
Collecting torch==1.8.2+cpu
  Using cached https://download.pytorch.org/whl/lts/1.8/cpu/torch-1.8.2%2Bcpu-cp37-cp37m-linux_x86_64.whl (169.1 MB)
Installing collected packages: torch
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchvision 0.11.1+cu111 requires torch==1.10.0, but you have torch 1.8.2+cpu which is incompatible.
torchtext 0.11.0 requires torch==1.10.0, but you have torch 1.8.2+cpu which is incompatible.
torchaudio 0.10.0+cu111 requires torch==1.10.0, but you have torch 1.8.2+cpu which is incompatible.[0m
Successfully installed torch-1.8.2+cpu


**Imports**

In [None]:
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, FormatStrFormatter, AutoMinorLocator
from tqdm.notebook import tqdm
import copy
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pll
import torch_xla.utils.serialization as xser

import torchvision.transforms as transforms
import torchvision.datasets

import random
import os
import pickle

random_seed = 1
torch.backends.cudnn.enabled = False

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = xm.xla_device()

torch_xla.core.xla_model.set_rng_state(1, device=device)
np.random.seed(0)
torch.manual_seed(0)




<torch._C.Generator at 0x7f16d18acab0>

**Making 'Robustness_analysis' directory the root directory**

In [None]:
%cd Robustness_analysis/

/content/Robustness_analysis


**For loading MNIST and CIFAR10 datsets**

The loadDatset method takes the following input variables:
>- dataset_type (str): the dataset you want to load, accepts a case-insensitive string 'mnist' for the MNIST dataset and 'cifar10' for the CIFAR10 dataset
>- download_ (bool): option to download datatset from torchvision.datasets
>- params (int, int): consists of two parameters in a tuple (batch size of dataset, number of workers) 


In [None]:
def loadDataset(dataset_type, download_, params):
  # Configures training (and evaluation) parameters
  dataset_type = dataset_type.upper()
  flags = {}
  flags['batch_size'] = params[0]
  flags['num_workers'] = params[1]

  if dataset_type == 'CIFAR10':
    # Transformations for CIFAR10
    trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train_dataset = torchvision.datasets.CIFAR10(root=r"Dataset", train=True,
                                                transform=trans, download=download_)
    test_dataset = torchvision.datasets.CIFAR10(root="Dataset", train=False,
                                                transform=trans, download=download_)
  elif dataset_type == 'MNIST':
    # Transformations for MNIST
    trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 
    train_dataset = torchvision.datasets.MNIST(root=r"Dataset", train=True,
                                                transform=trans, download=download_)
    test_dataset = torchvision.datasets.MNIST(root="Dataset", train=False,
                                                transform=trans, download=download_)
  else:
    # Transformations for CIFAR10 (default)
    trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    train_dataset = torchvision.datasets.CIFAR10(root=r"Dataset", train=True,
                                                transform=trans, download=download_)
    test_dataset = torchvision.datasets.CIFAR10(root="Dataset", train=False,
                                                transform=trans, download=download_)
    
  # for train set
  train_sampler = torch.utils.data.distributed.DistributedSampler(
    train_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)

  train_loader = torch.utils.data.DataLoader(
      train_dataset,
      batch_size=flags['batch_size'],
      sampler=train_sampler,
      num_workers=flags['num_workers'],
      drop_last=True)

  # for test set 
  test_sampler = torch.utils.data.distributed.DistributedSampler(
    test_dataset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)

  test_loader = torch.utils.data.DataLoader(
      test_dataset,
      batch_size=flags['batch_size'],
      sampler=test_sampler,
      num_workers=flags['num_workers'],
      drop_last=True)
  
  return train_loader, test_loader

trainLoader, testLoader = loadDataset(dataset_type='mnist', download_=False, params=(100, 1))

**Loading ImageNet Tiny dataset**

This dataset is constructed using the Tiny ImageNet dataset from https://www.kaggle.com/c/tiny-imagenet.

The default dataset of 90,000 training examples and 10,000 test examples. To split the dataset to 80,000 training and 20,000 test examples, load the current dataset and select 50 examples of each class from the training set and to add to the test set.

loadimageNet accepts two parameters:
>- path (str): the path of the dataset folader 'tiny-imagenet-200/complete_dataset'
>- params (int, int): consists of two parameters in a tuple (batch size of dataset, number of workers)


In [None]:
class ImageNetTiny(torch.utils.data.Dataset):
  # Creating ImageNet class to orangise dataset
  def __init__(self, data_dict, transform=None, target_transform=None):
    self.targets = data_dict['labels']
    self.data = data_dict['images']
    self.classes = data_dict['Classes_dict']

    self.label_keys = self.classes.keys()
    self.label_values = self.classes.values()
    
    self.transform = transform
    self.target_transform = target_transform

  def __len__(self):
    return len(self.targets)

  def __getitem__(self, idx):
    image = self.data[idx]
    label = self.targets[idx]
    if self.transform:
      image = self.transform(image)
    if self.target_transform:
      label = self.target_transform(label)
    return image, label


def save_obj(obj, name):
  with open(name + '.pkl', 'wb') as f:
    pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)

def load_imagenet(dir_path):
  train_file_name = "train_set"
  test_file_name = "test_set"

  with (open(dir_path+train_file_name+ '.pkl', "rb")) as openfile:
    while True:
        try:
            train_dict = (pickle.load(openfile))
        except EOFError:
            break

  with (open(dir_path+test_file_name+ '.pkl', "rb")) as openfile:
    while True:
        try:
            test_dict = (pickle.load(openfile))
        except EOFError:
            break
  
  return train_dict, test_dict

def loadImageNetTiny(path, params):
  flags = {}
  flags['batch_size'] = params[0]
  flags['num_workers'] = params[1]

  train_dict, test_dict = load_imagenet(path)

  trainset = ImageNetTiny(train_dict)
  testset = ImageNetTiny(test_dict)

  train_sampler = torch.utils.data.distributed.DistributedSampler(
    trainset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)

  train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=flags['batch_size'],
    sampler=train_sampler,
    num_workers=flags['num_workers'],
    drop_last=True)

  test_sampler = torch.utils.data.distributed.DistributedSampler(
    testset,
    num_replicas=xm.xrt_world_size(),
    rank=xm.get_ordinal(),
    shuffle=True)

  test_loader = torch.utils.data.DataLoader(
      testset,
      batch_size=flags['batch_size'],
      sampler=test_sampler,
      num_workers=flags['num_workers'],
      drop_last=True)

  return train_loader, test_loader

trainLoader, testLoader = loadImageNetTiny(path="Dataset/tiny-imagenet-200/complete_dataset/", params=(100, 1))

**Loading Networks**

This cell is for loading the model for training. To load the correct model with input and output shapes configured to suit the dataset, the 'Model' class accepts two parameters:
>- mode_type (str): The type of model you want to load, options:
>>* ResNet-18 (default)
>>* ResNet-50
>>* SqueezeNet-v1.1
>>* ShuffleNet V2 x1.0 

>- data_type (str): The dataset you are using, options:
>>* MNIST
>>* CIFAR10
>>* ImageNet Tiny

In [None]:
def getModel(model_type, dataset_type):
  model_type = model_type.lower()
  dataset_type = dataset_type.lower()

  if model_type  == 'resnet50':
    if dataset_type == 'mnist':
      # for loading ResNet-50 model for MNIST dataset 
      temp_model = torchvision.models.resnet50(num_classes=10)
      temp_model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    elif dataset_type == 'cifar10':
      # for loading ResNet-50 model for CIFAR10 dataset
      temp_model = torchvision.models.resnet50(num_classes=10)
    elif dataset_type == 'imagenet': 
      # for loading ResNet-50 model for TimageNet Tiny dataset
      temp_model = torchvision.models.resnet50(num_classes=200)
    else:
      # (default) loading ResNet-50 model for CIFAR10 dataset
      temp_model = torchvision.models.resnet50(num_classes=10)

  elif model_type == 'squeezenet':
    if dataset_type == 'mnist':
      # for loading SqueezeNet-v1.1 model for MNIST dataset 
      temp_model = torchvision.models.squeezenet1_1(num_classes=10)
      temp_model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2))
    elif dataset_type == 'cifar10':
      # for loading SqueezeNet-v1.1 model for CIFAR10 dataset
      temp_model = torchvision.models.squeezenet1_1(num_classes=10)
    elif dataset_type == 'imagenet': 
      # for loading SqueezeNet-v1.1 model for TimageNet Tiny dataset
      temp_model = torchvision.models.squeezenet1_1(num_classes=200)
    else:
      # (default) loading SqueezeNet-v1.1 model for CIFAR10 dataset
      temp_model = torchvision.models.squeezenet1_1(num_classes=10)

  elif model_type == 'shufflenet':
    if dataset_type == 'mnist':
      # for loading ShuffleNet V2 x1.0 model for MNIST dataset 
      temp_model = torchvision.models.shufflenet_v2_x0_5(num_classes=10)
      temp_model.conv1[0] = nn.Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    elif dataset_type == 'cifar10':
      # for loading ShuffleNet V2 x1.0 model for CIFAR10 dataset
      temp_model = torchvision.models.shufflenet_v2_x0_5(num_classes=10)
    elif dataset_type == 'imagenet': 
      # for loading ShuffleNet V2 x1.0 model for TimageNet Tiny dataset
      temp_model = torchvision.models.shufflenet_v2_x0_5(num_classes=200)
    else:
      # (default) loading ShuffleNet V2 x1.0 model for CIFAR10 dataset
      temp_model = torchvision.models.shufflenet_v2_x0_5(num_classes=10)

  else:
    # (Default) loading ResNet-18 models
    if dataset_type == 'mnist':
      # for loading ResNet-18 model for MNIST dataset 
      temp_model = torchvision.models.resnet18(num_classes=10)
      temp_model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    elif dataset_type == 'cifar10':
      # for loading ResNet-18 model for CIFAR10 dataset
      temp_model = torchvision.models.resnet18(num_classes=10)
    elif dataset_type == 'imagenet': 
      # for loading ResNet-18 model for TimageNet Tiny dataset
      temp_model = torchvision.models.resnet18(num_classes=200)
    else:
      # (default) loading ResNet-18 model for CIFAR10 dataset
      temp_model = torchvision.models.resnet18(num_classes=10)

  return temp_model

class Model(nn.Module):
  def __init__(self, mod_type, data_type):
    super().__init__()
    self.model = getModel(mod_type, data_type)
    self.loss = nn.CrossEntropyLoss()

##  @auto_move_data
  def forward(self, x):
    return self.model(x)

  def training_step(self, batch, batch_no):
    x, y = batch
    logits = self(x)
    loss = self.loss(logits, y)
    return loss

  def configure_optimizers(self):
    return torch.optim.Adam(self.model.parameters(), lr=0.001)

WRAPPED_MODEL = xmp.MpModelWrapper(Model('squeezenet', 'imagenet'))
model = WRAPPED_MODEL.to(device)


**For Training Network**

The directory 'Saved_models / <em>model_type</em> /<em>dataset_type</em> / <em>Init_x</em>' will automatically be created if you do not have it already. The network training initialising parameters as:

>- init_list_1: the number of random network initialisations you want, each initialisation will be trained and saved seperately in the directory 'Saved_models / <em>model_type</em> /<em>dataset_type</em> / <em>Init_x</em>', where x is the network initialisation number.

>- model_name (str): the rame of the model you are saving the network to. 

>- dataset_name (str):  the rame of the dataset you are saving the network to

>- max_epoch (int): the maximum number of epochs you want to train the network for (default=10)

>- save_epoch: how often, in epochs, you want to save the network (default=5)

>- reload_old (bool): True if you want to resume a previous training session.

>- reload_init (str): The initialisation you want to reload and resume.

>- reload_epoch (int): 

>- save_model (bool): If you want to save the trained models.

The training hyperparameters, such as optimiser and loss function can be changed using variables <em>opt</em> and <em>loss_fn</em>.

In [58]:
def ifExists(path, dir=False):
  if dir == False:
    return os.path.isfile(path)
  else:
    return os.path.isdir(path)

def save_obj(obj, name):
    with open(name + '.pkl', 'wb') as f:
        pickle.dump(obj, f, pickle.HIGHEST_PROTOCOL)


init_list_1 = ['Init_1', 'Init_2', 'Init_3']
init_list_1 = ['Init_1']

model_name = "ResNet18"
dataset_name = "MNIST"

max_epochs = 10
save_epoch = 5
reload_old, reload_init, reload_epoch = False, 'Init_1', 0

loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.001)
save_model = False

for init in init_list_1:
  torch_xla.core.xla_model.set_rng_state(1, device=device)
  np.random.seed(0)
  torch.manual_seed(0)

  # Instantiating the network
  WRAPPED_MODEL = xmp.MpModelWrapper(Model(model_name, dataset_name))
  model = WRAPPED_MODEL.to(device)
  model.model.train()

  optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
  epoch_num = max_epochs

  if reload_old == True:
    if init == reload_init:
      epoch_num = max_epochs - reload_epoch
      path = "Saved_models/"+model_name+"/"+dataset_name+"/" + reload_init + '/'+ model_name+"_"+dataset_name+"_" + str(reload_epoch) + "epochs.ckpt"
      # checkpoint = torch.load(path, map_location=device)
      checkpoint = xser.load(path)
      model.load_state_dict(checkpoint)
      print("Using previous network: ", path)
    else:
      epoch_num = max_epochs
  else:
    file_path = "Saved_models/"+model_name+"/"+dataset_name+"/" + init + '/'
    check = ifExists(file_path, dir=True)

    if check == False:
      try: 
        os.makedirs(file_path) 
      except OSError as error: 
          print(error)  

  criterion = nn.CrossEntropyLoss()
  #Optimizer(SGD)

  for epoch in tqdm(range(epoch_num)):  # loop over the dataset multiple times
    para_train_loader = pll.ParallelLoader(trainLoader, [device]).per_device_loader(device)

    if reload_old == True:
      if reload_epoch != 0:
        if init == reload_init:
          save_name = file_path+model_name+"_"+dataset_name+"_training_"+reload_init
          with (open(save_name + '.pkl', "rb")) as openfile:
            while True:
                try:
                    objects = (pickle.load(openfile))
                except EOFError:
                    break

          losses = objects['training_loss']
          training_acc = objects['training_accuracy']
          epoch = epoch + add_epoch

      else:
        losses = []
        training_acc = []
    else:
      losses = []
      training_acc = []

    running_loss = 0
    correct = 0
    total = 0
    counter = 0
    
    # Saving network every 'save_epoch' epochs (should be an int dtype)
    if save_model == True:
      if not epoch%save_epoch:
        print("Saving model at epoch {}".format(epoch), "\n", flush=True)
        file_path_final = file_path+model_name+"_"+dataset_name+"_" + str(epoch) + "epochs.ckpt"

        # Save the network
        # torch.save(model.state_dict(), file_path_final)
        xm.save(model.state_dict(), file_path_final)

    for data, targets in para_train_loader:
      # get the inputs; data is a list of [inputs, labels] 
      inputs, labels = data.to(device), targets.to(device)

      # zero the parameter gradients
      optimizer.zero_grad()

      # for normal use
      output = model(inputs)
      loss = criterion(output, labels)

      loss.backward()

      # optimizer.step()
      xm.optimizer_step(optimizer)

      # getting prediction
      running_loss += loss.item() * data[0].size(0)
      _, predicted = output.max(1)
      total += labels.size(0)
      correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    acc = correct/total

    losses.append(epoch_loss)
    training_acc.append(acc)

    train_dict = {'training_loss': losses, 'training_accuracy':training_acc}
    save_name = file_path+model_name+'_'+dataset_name+'_training_'+init
    save_obj(train_dict, save_name)

  # del WRAPPED_MODEL, model
  print('Finished Training init {}'.format(init))
  

  0%|          | 0/10 [00:00<?, ?it/s]

Finished Training init Init_1


**Testing the trained network**

The modek_tester method accepts two arguments:

>- model_ (Obj): The model to be tested

>- test_set_ (Obj): The dataset to test the model on (test set)

In [59]:
def model_tester(model_, test_set_):
  model_.eval()
  model_.to(device)
  criterion = nn.CrossEntropyLoss()

  correct, total = 0, 0
  with torch.no_grad():
    for data, target in tqdm(test_set_):
      data, target = data.to(device), target.to(device)
      output = model_(data)

      _, predicted = torch.max(output.data, 1)
      test_loss = criterion(output, target)

      total += target.size(0)
      correct += (predicted == target).sum().item()

  model_acc = correct / total

  # print("Model accuracy: {}".format(model_acc))
  return model_acc, running_loss

print(model_tester(model, testLoader))

  0%|          | 0/100 [00:00<?, ?it/s]

(0.99, 9.54792369176721)
