**Installing TPU Client for Pytorch**

This notebook is used to train the models and is optimised for using PyTorch on Cloud TPUs. For more information on how to use TPUs on Google Colab, please see https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb    

In [1]:
# Remove whatever torch build is currently installed so the pinned version below replaces it.
!pip uninstall -y torch

# Install the CPU build of the torch 1.8.2 LTS release from the LTS wheel index.
!pip install torch==1.8.2+cpu -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html
# !pip install torch==1.10.0 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html

# Install the Cloud TPU client and the torch_xla 1.8 wheel (built for CPython 3.7, linux x86_64).
!pip install -q cloud-tpu-client==0.10 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.8-cp37-cp37m-linux_x86_64.whl
# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl


Found existing installation: torch 1.8.2+cpu
Uninstalling torch-1.8.2+cpu:
  Successfully uninstalled torch-1.8.2+cpu
Looking in links: https://download.pytorch.org/whl/lts/1.8/torch_lts.html
Collecting torch==1.8.2+cpu
  Using cached https://download.pytorch.org/whl/lts/1.8/cpu/torch-1.8.2%2Bcpu-cp37-cp37m-linux_x86_64.whl (169.1 MB)
Installing collected packages: torch
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
torchvision 0.11.1+cu111 requires torch==1.10.0, but you have torch 1.8.2+cpu which is incompatible.
torchtext 0.11.0 requires torch==1.10.0, but you have torch 1.8.2+cpu which is incompatible.
torchaudio 0.10.0+cu111 requires torch==1.10.0, but you have torch 1.8.2+cpu which is incompatible.[0m
Successfully installed torch-1.8.2+cpu


**Imports**

In [2]:
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator, FormatStrFormatter, AutoMinorLocator
from tqdm.notebook import tqdm
import copy
from mpl_toolkits.mplot3d import Axes3D
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pll
import torch_xla.utils.serialization as xser

import torchvision.transforms as transforms
import torchvision.datasets

import random
import os
import pickle

# NOTE(review): random_seed is not referenced anywhere else in the visible notebook.
random_seed = 1
torch.backends.cudnn.enabled = False

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Every later cell places its tensors and models on this XLA device.
device = xm.xla_device()

# Seed the XLA, NumPy and PyTorch random number generators.
# NOTE(review): Python's own `random` module is imported above but is not seeded here.
torch_xla.core.xla_model.set_rng_state(1, device=device)
np.random.seed(0)
torch.manual_seed(0)




<torch._C.Generator at 0x7f9c570878f0>

**Making 'Robustness_analysis' directory the root directory**

In [3]:
# All dataset and checkpoint paths used below are relative to this directory.
# NOTE(review): the path is relative, so re-running this cell from the new location will fail.
%cd drive/MyDrive/Robustness_analysis/

/content/drive/MyDrive/Robustness_analysis


**For loading MNIST and CIFAR10 datasets**

The loadDataset method takes the following input variables:
>- dataset_type (str): the dataset you want to load; accepts a case-insensitive string, 'mnist' for the MNIST dataset or 'cifar10' for the CIFAR10 dataset (any other value falls back to CIFAR10)
>- download_ (bool): whether to download the dataset through torchvision.datasets
>- params (int, int): a tuple of two parameters (batch size of dataset, number of workers)


In [12]:
def loadDataset(dataset_type, download_, params):
  """Build distributed train and test DataLoaders for MNIST or CIFAR10.

  Args:
    dataset_type (str): case-insensitive dataset name. 'mnist' selects MNIST;
      'cifar10' and every other value select CIFAR10.
    download_ (bool): forwarded as `download` to the torchvision dataset.
    params (tuple): (batch size, number of DataLoader workers).

  Returns:
    tuple: (train_loader, test_loader). Both loaders use a shuffling
    DistributedSampler and drop the last incomplete batch.
  """
  batch_size, num_workers = params[0], params[1]

  # Pick the dataset class and its normalisation; CIFAR10 is the fallback.
  if dataset_type.upper() == 'MNIST':
    dataset_cls = torchvision.datasets.MNIST
    normalise = transforms.Normalize((0.1307,), (0.3081,))
  else:
    dataset_cls = torchvision.datasets.CIFAR10
    normalise = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
  trans = transforms.Compose([transforms.ToTensor(), normalise])

  # Train split first, then test split, both stored under "Dataset".
  splits = [
      dataset_cls(root="Dataset", train=is_train, transform=trans, download=download_)
      for is_train in (True, False)
  ]

  loaders = []
  for split in splits:
    # Each replica reads its own shard of the split.
    sampler = torch.utils.data.distributed.DistributedSampler(
        split,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    loaders.append(torch.utils.data.DataLoader(
        split,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        drop_last=True))

  train_loader, test_loader = loaders
  return train_loader, test_loader

# Build the MNIST loaders with batch size 100 and one worker. download_=False is
# passed through to torchvision, so the data presumably already exists under "Dataset/".
trainLoader, testLoader = loadDataset(dataset_type='mnist', download_=False, params=(100, 1))

**Loading ImageNet Tiny dataset**

This dataset is constructed using the Tiny ImageNet dataset from https://www.kaggle.com/c/tiny-imagenet.

The default dataset has 90,000 training examples and 10,000 test examples. To split it into 80,000 training and 20,000 test examples, load the current dataset, select 50 examples of each class from the training set, and add them to the test set.

loadImageNetTiny accepts two parameters:
>- path (str): the path of the dataset folder, including the trailing slash, e.g. 'Dataset/tiny-imagenet-200/complete_dataset/'
>- params (int, int): consists of two parameters in a tuple (batch size of dataset, number of workers)


In [5]:
class ImageNetTiny(torch.utils.data.Dataset):
  """Map-style dataset over a pre-loaded Tiny ImageNet dictionary.

  `data_dict` must provide the keys 'labels', 'images' and 'Classes_dict'.
  Optional `transform` / `target_transform` callables are applied to the
  image and the label respectively when an item is fetched.
  """

  def __init__(self, data_dict, transform=None, target_transform=None):
    self.targets = data_dict['labels']
    self.data = data_dict['images']
    self.classes = data_dict['Classes_dict']

    # Views over the class dictionary, kept for convenience.
    self.label_keys, self.label_values = self.classes.keys(), self.classes.values()

    self.transform, self.target_transform = transform, target_transform

  def __len__(self):
    # The dataset size is the number of labels.
    return len(self.targets)

  def __getitem__(self, idx):
    sample, target = self.data[idx], self.targets[idx]
    sample = self.transform(sample) if self.transform else sample
    target = self.target_transform(target) if self.target_transform else target
    return sample, target


def save_obj(obj, name):
  """Pickle `obj` into the file `name + '.pkl'` with the highest pickle protocol."""
  with open(name + '.pkl', mode='wb') as handle:
    pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

def _load_last_pickle(file_path):
  """Return the last object pickled in `file_path`; raise EOFError if there is none."""
  found, loaded = False, None
  with open(file_path, "rb") as openfile:
    # A file may hold several consecutive pickles; keep reading until EOF
    # and keep only the most recent object.
    while True:
      try:
        loaded = pickle.load(openfile)
        found = True
      except EOFError:
        break
  if not found:
    raise EOFError("no pickled object found in " + file_path)
  return loaded


def load_imagenet(dir_path):
  """Load the pickled Tiny ImageNet train and test dictionaries.

  Args:
    dir_path (str): directory that contains 'train_set.pkl' and
      'test_set.pkl'. A trailing path separator is optional.

  Returns:
    tuple: (train_dict, test_dict), the objects stored in the two files.

  Raises:
    EOFError: if either file contains no pickled object.
  """
  # NOTE(security): unpickling can execute arbitrary code; only load trusted files.
  train_dict = _load_last_pickle(os.path.join(dir_path, "train_set" + '.pkl'))
  test_dict = _load_last_pickle(os.path.join(dir_path, "test_set" + '.pkl'))

  return train_dict, test_dict

def loadImageNetTiny(path, params):
  """Build distributed train and test DataLoaders for the pickled Tiny ImageNet data.

  Args:
    path (str): location handed to `load_imagenet`.
    params (tuple): (batch size, number of DataLoader workers).

  Returns:
    tuple: (train_loader, test_loader). Both loaders use a shuffling
    DistributedSampler and drop the last incomplete batch.
  """
  batch_size, num_workers = params[0], params[1]

  def _distributed_loader(dataset):
    # Shard the dataset across replicas and wrap it in a DataLoader.
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        drop_last=True)

  train_dict, test_dict = load_imagenet(path)
  # No transforms are attached; samples are returned exactly as stored.
  trainset, testset = ImageNetTiny(train_dict), ImageNetTiny(test_dict)

  train_loader = _distributed_loader(trainset)
  test_loader = _distributed_loader(testset)
  return train_loader, test_loader

# Build the Tiny ImageNet loaders (batch size 100, one worker).
# NOTE: this rebinds the same trainLoader/testLoader names used for MNIST/CIFAR10;
# whichever loader cell ran last determines what the later cells evaluate on.
trainLoader, testLoader = loadImageNetTiny(path="Dataset/tiny-imagenet-200/complete_dataset/", params=(100, 1))

**Loading Networks**

This cell is for loading the model for training. To load the correct model with input and output shapes configured to suit the dataset, the 'Model' class accepts two parameters:
>- mod_type (str): The type of model you want to load (case-insensitive), options:
>>* ResNet-18 (default; used for any unrecognised value)
>>* ResNet-50 ('resnet50')
>>* SqueezeNet-v1.1 ('squeezenet')
>>* ShuffleNet V2 x0.5 ('shufflenet')

>- data_type (str): The dataset you are using (case-insensitive), options:
>>* MNIST ('mnist')
>>* CIFAR10 ('cifar10'; also the default for any unrecognised value)
>>* ImageNet Tiny ('imagenet')

In [13]:
def getModel(model_type, dataset_type):
  """Create an untrained torchvision network sized for the chosen dataset.

  Args:
    model_type (str): case-insensitive; 'resnet50', 'squeezenet'
      (SqueezeNet v1.1) or 'shufflenet' (ShuffleNet V2 x0.5). Any other
      value yields ResNet-18.
    dataset_type (str): case-insensitive; 'mnist', 'cifar10' or 'imagenet'.
      Any other value is treated like 'cifar10'.

  Returns:
    The torchvision model. It has 200 outputs for 'imagenet' and 10
    otherwise; for 'mnist' the first convolution is replaced by one that
    takes a single input channel.
  """
  model_type = model_type.lower()
  dataset_type = dataset_type.lower()

  # Only the Tiny ImageNet setting uses 200 classes; everything else uses 10.
  num_classes = 200 if dataset_type == 'imagenet' else 10
  single_channel = dataset_type == 'mnist'

  if model_type == 'resnet50':
    temp_model = torchvision.models.resnet50(num_classes=num_classes)
    if single_channel:
      temp_model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

  elif model_type == 'squeezenet':
    temp_model = torchvision.models.squeezenet1_1(num_classes=num_classes)
    if single_channel:
      temp_model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2))

  elif model_type == 'shufflenet':
    temp_model = torchvision.models.shufflenet_v2_x0_5(num_classes=num_classes)
    if single_channel:
      temp_model.conv1[0] = nn.Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

  else:
    # Fallback architecture: ResNet-18.
    temp_model = torchvision.models.resnet18(num_classes=num_classes)
    if single_channel:
      temp_model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

  return temp_model

class Model(nn.Module):
  """Wraps the network built by `getModel` together with a cross-entropy loss."""

  def __init__(self, mod_type, data_type):
    super(Model, self).__init__()
    self.model = getModel(mod_type, data_type)
    self.loss = nn.CrossEntropyLoss()

  def forward(self, x):
    """Return the wrapped network's output for the batch `x`."""
    logits = self.model(x)
    return logits

  def training_step(self, batch, batch_no):
    """Return the cross-entropy loss for one (inputs, labels) batch."""
    inputs, labels = batch
    return self.loss(self(inputs), labels)

  def configure_optimizers(self):
    """Return an Adam optimizer (lr=0.001) over the wrapped network's parameters."""
    trainable = self.model.parameters()
    return torch.optim.Adam(trainable, lr=0.001)

# Wrap a freshly initialised model with xmp.MpModelWrapper, then move it to the XLA device.
# 'resnet18' matches no explicit branch in getModel, so its ResNet-18 fallback is used.
WRAPPED_MODEL = xmp.MpModelWrapper(Model('resnet18', 'mnist'))
model = WRAPPED_MODEL.to(device)


**Loading a pretrained network**

To load a pretrained network, change the variables:

>- model_name (str): name of the model you want to load (must be the same as the directory and file names of the saved model. Furthermore, the model must be the same as the one instantiated in the cell above.)
>- dataset_name (str): the name of the dataset 
>- reload_init (str): The directory name of the initialisation you want to load ('Init_' followed by initialisation number)
>- reload_epoch (int): The epoch of the pretrained model you want to load. 

In [19]:
# Which saved checkpoint to reload: model/dataset names, initialisation folder and epoch.
model_name, dataset_name = 'ResNet18', 'MNIST'
reload_init, reload_epoch = 'Init_1', 15

# Resolves to Saved_models/<model>/<dataset>/<init>/<model>_<dataset>_<epoch>epochs.ckpt
path = "Saved_models/"+model_name+"/"+dataset_name+"/" + reload_init + '/'+ model_name+"_"+dataset_name+"_" + str(reload_epoch) + "epochs.ckpt"
# checkpoint = torch.load(path, map_location=device)

# NOTE(review): the architecture here is hard-coded as ('resnet18', 'mnist') rather than
# derived from model_name/dataset_name; keep the two in sync when changing either.
WRAPPED_MODEL = xmp.MpModelWrapper(Model('resnet18', 'mnist'))
load_model = WRAPPED_MODEL.to(device)

# Read the checkpoint onto the CPU, then copy its weights into the model.
checkpoint = torch.load(path, map_location=torch.device('cpu'))
load_model.load_state_dict(checkpoint)

load_model = load_model.to(device)

**To run the FGSM attack on the model**

In [22]:
def fgsm_attack(image, epsilon, data_grad):
  """Return the FGSM-perturbed input `image + epsilon * sign(data_grad)`.

  Args:
    image (Tensor): input batch to perturb.
    epsilon (float): step size applied to every element.
    data_grad (Tensor): gradient of the loss with respect to `image`.

  Returns:
    Tensor: the perturbed input. It is NOT clipped back to the value range
    of `image`; clamp the result at the call site if that is required.
  """
  # Each element moves by exactly +/- epsilon (or 0 where the gradient is 0).
  return image + epsilon * data_grad.sign()

# ---- Which checkpoint to attack ----
dataset_name = 'MNIST'
model_name= 'ResNet18'
reload_init = 'Init_1'
# filter_type = "maxmin"
epoch_str = str(20)

# Build the network and restore the saved weights (loaded on CPU, then moved to the XLA device).
model_test = Model(model_name, dataset_name)
path_dir = 'Saved_models/'+model_name+'/'+dataset_name+'/'
path = path_dir+reload_init+'/'+model_name+'_'+dataset_name+'_'+epoch_str+'epochs.ckpt'

checkpoint = torch.load(path, map_location=torch.device('cpu'))
model_test.load_state_dict(checkpoint)
model_test = model_test.to(device)
# print(run_fgsm_attack(model_test, test_loader, 0.011))

# FGSM step size. It is applied to the normalised inputs produced by the loader.
eps = 1.5

# NOTE(review): adv_examples is never filled in this cell.
adv_examples = []
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

# Running totals over the whole test loader.
correct_adv, correct_og, total, running_loss = 0, 0, 0, 0

# testLoader comes from whichever dataset-loading cell was run last.
para_test_loader = pll.ParallelLoader(testLoader, [device]).per_device_loader(device)
# Loop over all batches in the test set
for data, target in tqdm(para_test_loader):
  # Send the data and label to the device
  data, target = data.double().to(device), target.to(device)

  # Set requires_grad attribute of tensor. Important for Attack
  data.requires_grad = True

  # Forward pass the data through the model
  output = model_test(data)# .to(device)
  # Predicted class = index of the largest logit (clean input)
  _, init_pred = torch.max(output.data, 1)

  # If the initial prediction is wrong, dont bother attacking, just move on
  # if init_pred.item() != target.item():
  #   continue

  # Calculate the loss
  loss = criterion(output, target)
  loss = loss.to(device)

  # Zero all existing gradients
  model_test.zero_grad()

  # Calculate gradients of model in backward pass
  loss.backward()

  # Collect datagrad
  data_grad = data.grad.data

  # Call FGSM Attack
  perturbed_data = fgsm_attack(data, eps, data_grad)

  # Re-classify the perturbed image
  output = model_test(perturbed_data).to(device)
  test_loss = criterion(output, target).to(device)

  # Check for success
  _, final_pred = torch.max(output.data, 1)

  # Accumulate the adversarial loss (weighted by batch size) and the hit counts
  running_loss += test_loss.item() * data.size(0)
  total += target.size(0)
  correct_adv += (final_pred == target).sum().item()
  correct_og += (init_pred == target).sum().item()

# Calculate final accuracy for this epsilon
final_acc_adv = correct_adv / total
final_acc_og = correct_og / total
# Mean adversarial loss per sample (computed but not printed)
epoch_loss = running_loss / total

print("Clean network accuracy: ", final_acc_og)
print("Adversarial network accuracy: ", final_acc_adv, "for epsilon: ", eps)

  0%|          | 0/100 [00:00<?, ?it/s]

Clean network accuracy:  0.9912
Adversarial network accuracy:  0.5566 for epsilon:  1.5


**Local analysis**

In [25]:
def fgsm_attack(image, epsilon, data_grad):
  """Return the FGSM-perturbed input `image + epsilon * sign(data_grad)`.

  Args:
    image (Tensor): input batch to perturb.
    epsilon (float): step size applied to every element.
    data_grad (Tensor): gradient of the loss with respect to `image`.

  Returns:
    Tensor: the perturbed input. It is NOT clipped back to the value range
    of `image`; clamp the result at the call site if that is required.
  """
  # Each element moves by exactly +/- epsilon (or 0 where the gradient is 0).
  return image + epsilon * data_grad.sign()


def find_minmax(_obj, _names):
  """Return the smallest and largest value found across the selected tensors.

  Args:
    _obj (dict): maps names to tensors (e.g. a state dict).
    _names (iterable): keys of `_obj` to inspect.

  Returns:
    tuple: (global_min, global_max) as 0-dim tensors, or (0, 0) when
    `_names` is empty.
  """
  g_min, g_max = None, None

  for name in _names:
    temp_max = torch.max(_obj[name])
    temp_min = torch.min(_obj[name])

    # Start from the first tensor's extremes instead of 0, so an all-positive
    # (or all-negative) set of tensors still reports its true min (or max).
    if g_max is None or temp_max >= g_max:
      g_max = temp_max

    if g_min is None or temp_min <= g_min:
      g_min = temp_min

  if g_min is None:
    # Nothing to inspect: keep the historical (0, 0) result.
    return 0, 0

  return g_min, g_max


def get_test_names(ex_model_og):
  """Collect the weight names of all Conv2d/Linear layers plus a copy of the state dict.

  Args:
    ex_model_og: object whose `.model` attribute is the `nn.Module` to inspect.

  Returns:
    tuple: (new_names, model_obj). `new_names` lists '<module path>.weight'
    for every Conv2d and Linear sub-module in registration order.
    `model_obj` is a deep copy of `.model.state_dict()`, so editing it
    leaves the model untouched.
  """
  network = ex_model_og.model

  # Deep-copying the state dict already isolates the tensors from the live
  # model, so the model itself does not need to be copied as well.
  model_obj = copy.deepcopy(network.state_dict())

  # [1:] skips the root module, whose name is the empty string.
  new_names = [
      p_name + ".weight"
      for p_name, module in list(network.named_modules())[1:]
      if isinstance(module, (nn.Conv2d, nn.Linear))
  ]

  return new_names, model_obj


def step_filter_func(input_mat, offset, fil_type):
  """Zero the entries of `input_mat` that fall on one side of `offset`.

  Args:
    input_mat (Tensor): values to filter; it is not modified.
    offset: threshold compared against every entry.
    fil_type (str): "maxmin" zeroes entries >= offset; "minmax" and any
      other value zero entries <= offset.

  Returns:
    Tensor: a new tensor with the selected entries replaced by zero.
  """
  zeros = torch.zeros_like(input_mat)
  with torch.no_grad():
    if fil_type == "maxmin":
      drop_mask = input_mat >= offset
    else:
      drop_mask = input_mat <= offset
    return torch.where(drop_mask, zeros, input_mat)

def count_zero(x):
  """Count the entries of `x` that are equal to zero.

  Tensors are counted with torch (result is a 0-dim tensor); anything else
  is counted with NumPy.
  """
  is_zero = x == 0
  counter = torch.count_nonzero if torch.is_tensor(x) else np.count_nonzero
  return counter(is_zero)

def save_obj(obj, name):
  """Pickle `obj` into the file `name + '.pkl'` with the highest pickle protocol."""
  with open(name + '.pkl', mode='wb') as handle:
    pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)


# ---- Experiment configuration ----
# Checkpoint epochs evaluated for every initialisation.
epochs = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100]

epsilon = 0.03  # FGSM step size
samples_no = 25  # number of threshold values (alpha) swept per checkpoint
dataset_name = 'MNIST'
model_name= 'ResNet18'
filter_type = "minmax"
# One output tag per initialisation; the i-th tag is paired with folder 'Init_<i+1>'.
rep = ["advog", "advog2", "advog3"]
# Fixed threshold range shared by all layers.
glob_min, glob_max = -1, 1

# Results keyed by epoch. The same dict is reused for every initialisation and
# each epoch key is overwritten before the dict is saved again.
final_dict = {}

for r, rep_count in enumerate(rep):
  r += 1
  for e in tqdm(epochs):
    epoch_str = str(e)
    WRAPPED_MODEL = xmp.MpModelWrapper(Model(model_name, dataset_name))
    model_test = WRAPPED_MODEL.to(device)

    path_dir = 'Saved_models/'+model_name+'/'+dataset_name+'/'
    path = path_dir+'Init_'+str(r)+'/'+model_name+'_'+dataset_name+'_'+epoch_str+'epochs.ckpt'

    # Restore the checkpoint (read on CPU, then moved to the XLA device).
    checkpoint = torch.load(path, map_location=torch.device('cpu'))
    model_test.load_state_dict(checkpoint)
    model_test = model_test.to(device)
    # Same loss for every threshold, so build it once per checkpoint.
    criterion = nn.CrossEntropyLoss().to(device)

    acc_og_list, acc_adv_list, lss_list, zero_count = [], [], [], []

    # Conv2d/Linear weight names and an independent copy of the state dict.
    names_, obj_ = get_test_names(model_test)
    # g_min, g_max = find_minmax(obj_, names_)

    # "maxmin" sweeps from the top of the range downwards; every other filter
    # type sweeps from the bottom upwards.
    if filter_type == "maxmin":
      g_min, g_max = glob_max, glob_min
    else:
      g_min, g_max = glob_min, glob_max

    alpha_vals = torch.linspace(g_min, g_max, samples_no, device=device)
    steps = float((g_max - g_min) / samples_no)

    for alpha in alpha_vals:
      temp_obj = copy.deepcopy(obj_)
      temp_model = copy.deepcopy(model_test)
      z_count = 0

      # Edges used only by the "pulseminmax" filter.
      pulse_hi = (alpha + steps).to(device)
      pulse_lo = (alpha - steps).to(device)

      for name in names_:
        mat = copy.deepcopy(temp_obj[name])
        mat = mat.to(device)

        if filter_type == "pulseminmax":
          # NOTE(review): when steps > 0 the two masks below together cover
          # every entry (mat < pulse_hi OR mat >= pulse_lo), so this branch
          # copies all weights back unchanged. Confirm the intended pulse.
          filtered_weight = torch.zeros_like(mat)
          maska = mat >= pulse_hi
          maskb = mat < pulse_lo

          filtered_weight[~maska] = mat[~maska]
          filtered_weight[~maskb] = mat[~maskb]
        else:
          filtered_weight = step_filter_func(mat, alpha, filter_type)

        z_count += count_zero(filtered_weight)
        # Store the filtered weights for THIS layer. Doing this inside the loop
        # keeps the evaluated model consistent with z_count, which sums the
        # zeros of every layer.
        temp_obj[name] = filtered_weight

      zero_count.append(int(z_count))
      temp_model.model.load_state_dict(temp_obj)

      # file_path_final = path_dir+"temp_model.ckpt"
      # # torch.save(model.state_dict(), file_path_final)
      # xm.save(temp_model.state_dict(), file_path_final)

      # checkpoint = torch.load(path_dir+"temp_model.ckpt", map_location=torch.device('cpu'))

      correct_adv, correct_og, total, running_loss = 0, 0, 0, 0
      # testLoader comes from whichever dataset-loading cell was run last.
      para_test_loader = pll.ParallelLoader(testLoader, [device]).per_device_loader(device)

      for data, target in para_test_loader:
        data, target = data.double().to(device), target.to(device)
        # The attack needs the gradient with respect to the input.
        data.requires_grad = True

        # Forward pass on the clean batch
        output = temp_model(data)
        _, init_pred = torch.max(output.data, 1)

        # If the initial prediction is wrong, dont bother attacking, just move on
        # if init_pred.item() != target.item():
        #   continue

        # Loss and input gradient for the clean batch
        loss = criterion(output, target)
        temp_model.zero_grad()
        loss.backward()
        data_grad = data.grad.data

        # Call FGSM Attack
        perturbed_data = fgsm_attack(data, epsilon, data_grad)

        # Re-classify the perturbed batch; the reported loss is the adversarial loss
        output = temp_model(perturbed_data)
        test_loss = criterion(output, target)
        _, final_pred = torch.max(output.data, 1)

        running_loss += test_loss.item() * data.size(0)
        total += target.size(0)
        correct_adv += (final_pred == target).sum().item()
        correct_og += (init_pred == target).sum().item()

      acc_og = correct_og/total
      acc_adv = correct_adv/total
      lss_ = running_loss/total

      acc_og_list.append(acc_og)
      acc_adv_list.append(acc_adv)
      lss_list.append(lss_)

      # Release the per-threshold copy before building the next one.
      del temp_model

    final_dict[e] = [acc_og_list, acc_adv_list, lss_list, zero_count]

  # One pickle per initialisation, holding the results for every epoch.
  print(final_dict.keys())
  dict_save_name = path_dir+"/global_data/"+model_name+"_"+rep_count+"_global_"+filter_type+"_"+dataset_name+"_"+"allepochs"
  save_obj(final_dict, dict_save_name)


**Global analysis**

In [None]:
def fgsm_attack(image, epsilon, data_grad):
  """Return the FGSM-perturbed input `image + epsilon * sign(data_grad)`.

  Args:
    image (Tensor): input batch to perturb.
    epsilon (float): step size applied to every element.
    data_grad (Tensor): gradient of the loss with respect to `image`.

  Returns:
    Tensor: the perturbed input. It is NOT clipped back to the value range
    of `image`; clamp the result at the call site if that is required.
  """
  # Each element moves by exactly +/- epsilon (or 0 where the gradient is 0).
  return image + epsilon * data_grad.sign()


def get_test_names(ex_model_og):
  """Collect the weight names of all Conv2d/Linear layers plus a copy of the state dict.

  Args:
    ex_model_og: object whose `.model` attribute is the `nn.Module` to inspect.

  Returns:
    tuple: (new_names, model_obj). `new_names` lists '<module path>.weight'
    for every Conv2d and Linear sub-module in registration order.
    `model_obj` is a deep copy of `.model.state_dict()`, so editing it
    leaves the model untouched.
  """
  network = ex_model_og.model

  # Deep-copying the state dict already isolates the tensors from the live
  # model, so the model itself does not need to be copied as well.
  model_obj = copy.deepcopy(network.state_dict())

  # [1:] skips the root module, whose name is the empty string.
  new_names = [
      p_name + ".weight"
      for p_name, module in list(network.named_modules())[1:]
      if isinstance(module, (nn.Conv2d, nn.Linear))
  ]

  return new_names, model_obj


def step_filter_func(input_mat, offset, fil_type):
  """Zero the entries of `input_mat` that fall on one side of `offset`.

  Args:
    input_mat (Tensor): values to filter; it is not modified.
    offset: threshold compared against every entry.
    fil_type (str): "maxmin" zeroes entries >= offset; "minmax" and any
      other value zero entries <= offset.

  Returns:
    Tensor: a new tensor with the selected entries replaced by zero.
  """
  zeros = torch.zeros_like(input_mat)
  with torch.no_grad():
    if fil_type == "maxmin":
      drop_mask = input_mat >= offset
    else:
      drop_mask = input_mat <= offset
    return torch.where(drop_mask, zeros, input_mat)


def save_obj(obj, name):
  """Pickle `obj` into the file `name + '.pkl'` with the highest pickle protocol."""
  with open(name + '.pkl', mode='wb') as handle:
    pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)


def count_zero(x):
  """Count the entries of `x` that are equal to zero.

  Tensors are counted with torch (result is a 0-dim tensor); anything else
  is counted with NumPy.
  """
  is_zero = x == 0
  counter = torch.count_nonzero if torch.is_tensor(x) else np.count_nonzero
  return counter(is_zero)


device = xm.xla_device()

# ---- Experiment configuration ----
epochs = [60, 70, 80, 90, 100]
dataset_name = "ImageNet"
# epsilon = 0.77 # (0.38) MNIST ResNet18, (0.75) MNIST ResNet50 
epsilon = 0.5 # 0.035 # CIFAR10 ResNet18 (0.015), ResNet50 (0.02), MNIST_SqueezeNet (0.8)
filter_type = "minmax"
filter_folder = 'Step_'+filter_type+'/'
model_name = "ShuffleNet"

for e_no in epochs:
  print("\n")
  print("running experiment for epoch: ", e_no)
  path = 'Saved_models/'+model_name+'/' + dataset_name + '/'
  file_name = 'Init_3/'+model_name+'_' + dataset_name + "_" + str(e_no) + 'epochs.ckpt'

  checkpoint = torch.load(path+file_name, map_location=torch.device("cpu"))
  # checkpoint = xser.load(path+file_name)
  # Model needs the architecture and dataset names to build a network that
  # matches the checkpoint.
  model_ = Model(model_name, dataset_name)
  model_.load_state_dict(checkpoint)
  model_.to(device)
  # Same loss for every layer and threshold, so build it once per checkpoint.
  criterion = nn.CrossEntropyLoss().to(device)

  # Conv2d/Linear weight names and an independent copy of the state dict.
  mod_names, mod_obj = get_test_names(model_)
  samples_no = 25
  results_dict = {}

  # Filter one layer at a time; every other layer keeps its trained weights.
  for name in tqdm(mod_names):
    weight = copy.deepcopy(mod_obj[name])
    # NOTE(review): (min - max) makes `steps` negative; it only feeds the
    # "pulseminmax" branch below. Confirm the intended sign.
    steps = ((torch.min(weight) - torch.max(weight)) / samples_no) / 2

    # Thresholds span this layer's own weight range; "maxmin" sweeps downwards.
    if filter_type == "maxmin":
      alpha_vals = torch.linspace(torch.max(weight), torch.min(weight), samples_no, device=device)
    else:
      alpha_vals = torch.linspace(torch.min(weight), torch.max(weight), samples_no, device=device)

    acc_og_list, acc_adv_list, lss_list, zero_count = [], [], [], []

    for alpha in alpha_vals:
      model_clone = copy.deepcopy(model_)
      mat = copy.deepcopy(weight)
      mat = mat.to(device)
      new_obj = copy.deepcopy(mod_obj)

      # Edges used only by the "pulseminmax" filter.
      pulse_hi = (alpha + steps).to(device)
      pulse_lo = (alpha - steps).to(device)

      if filter_type == "pulseminmax":
        # Keeps entries with mat < pulse_hi or mat >= pulse_lo; all others stay zero.
        filtered_weight = torch.zeros_like(mat)
        maska = mat >= pulse_hi
        maskb = mat < pulse_lo

        filtered_weight[~maska] = mat[~maska]
        filtered_weight[~maskb] = mat[~maskb]
      else:
        filtered_weight = step_filter_func(mat, alpha, filter_type)

      # NOTE(review): this stores a 0-dim device tensor, not a Python int.
      zero_count.append(count_zero(filtered_weight))
      # Only the current layer is replaced in the cloned state dict.
      new_obj[name] = filtered_weight
      model_clone.model.load_state_dict(new_obj)
      model_clone = model_clone.to(device)

      correct_adv, correct_og, total, running_loss = 0, 0, 0, 0
      # testLoader is the notebook-level loader from the dataset cells; it must
      # have been built for `dataset_name` before this cell runs.
      para_test_loader = pll.ParallelLoader(testLoader, [device]).per_device_loader(device)

      for data, target in para_test_loader:
        # Send the data and label to the device
        data, target = data.double().to(device), target.to(device)

        # The attack needs the gradient with respect to the input.
        data.requires_grad = True

        # Forward pass on the clean batch
        output = model_clone(data)
        _, init_pred = torch.max(output.data, 1)

        # If the initial prediction is wrong, dont bother attacking, just move on
        # if init_pred.item() != target.item():
        #   continue

        # Calculate the loss
        loss = criterion(output, target)

        # Zero the gradients of the model that is actually being evaluated
        model_clone.zero_grad()

        # Calculate gradients of model in backward pass
        loss.backward()

        # Collect datagrad
        data_grad = data.grad.data

        # Call FGSM Attack
        perturbed_data = fgsm_attack(data, epsilon, data_grad)

        # Re-classify the perturbed batch and score the ADVERSARIAL output
        output_adv = model_clone(perturbed_data)
        test_loss = criterion(output_adv, target)

        # Check for success
        _, final_pred = torch.max(output_adv.data, 1)

        running_loss += test_loss.item() * data.size(0)
        total += target.size(0)
        correct_adv += (final_pred == target).sum().item()
        correct_og += (init_pred == target).sum().item()

        del test_loss, target, data

      acc_og = correct_og/total
      acc_adv = correct_adv/total
      lss_ = running_loss/total

      acc_og_list.append(acc_og)
      acc_adv_list.append(acc_adv)
      lss_list.append(lss_)

      # Free the per-threshold copies before the next iteration.
      del model_clone, filtered_weight, new_obj, para_test_loader

    # Results are keyed by the layer's module path (without '.weight').
    # NOTE(review): lss_list is collected above but is not part of the saved tuple.
    new_name = name.replace(".weight", "")
    results_dict[new_name] = (acc_og_list, acc_adv_list, zero_count, alpha_vals)

  # One pickle per checkpoint epoch.
  dict_save_name = path+filter_folder+model_name+"_step_"+filter_type+"_advog3_" + dataset_name + "_" + str(e_no) + "epochs"
  save_obj(results_dict, dict_save_name)
