# Maxim Distiller torch 1.13

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
%cd /distiller-pytorch-1.13_mvcnn
!pip3 install -e .

In [None]:
# Import fpzip for floating point compression
!pip install fpzip

In [None]:
# !pip install pandas --trusted-host pypi.org --trusted-host files.pythonhosted.org
# !pip install --upgrade pip
!pip install pyzfp

In [None]:
!pip install --upgrade --no-cache-dir gdown

# Extract Dataset

In [None]:
!tar xf modelnet_test.tgz

In [None]:
!mkdir data

# Importing Libraries

Importing libraries and parameters

In [None]:
!pip --version

In [None]:
import numpy as np
from numpy import array

import torch
import torch.optim as optim
import torch.nn as nn
import os,shutil,json
import argparse

import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import pickle
import os
from torch.utils.tensorboard import SummaryWriter # Tensorboard
import time

import glob
import torch.utils.data
import math
from skimage import io, transform
from PIL import Image
import torchvision as vision
from torchvision import transforms, datasets
# import random

import torchvision.models as models

from torchvision.transforms.functional import InterpolationMode
import torchvision.transforms.functional as TF

# For plotting
import matplotlib.pyplot as plt

# Added for ResNet partitioning
from collections import OrderedDict

# For logfile
import csv

# For timing
from datetime import datetime
from pytz import timezone

# Import fpzip for communication approximation
import fpzip
import pyzfp # import compress, decompress

# Import random for random indices generation
import random

# Distiller import for compute approximation
from distiller import models as dis_models
import distiller
from distiller.apputils import *

# Select the compute device: first CUDA GPU when available, otherwise the CPU
train_on_gpu = torch.cuda.is_available()
if train_on_gpu:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

In [None]:
# Proper seed setting for reproducibility of the results
# Taken from Slide 73 (DL)
seed = 0
random.seed(seed)                 # Python's built-in RNG
np.random.seed(seed)              # NumPy RNG
torch.manual_seed(seed)           # PyTorch CPU RNG
torch.cuda.manual_seed(seed)      # PyTorch RNG of the current GPU
torch.cuda.manual_seed_all(seed)  # PyTorch RNG of every visible GPU
torch.backends.cudnn.deterministic = True
# The flag is `benchmark` (singular); the misspelled `benchmarks` silently
# created an unused attribute and left cuDNN autotuning untouched.
torch.backends.cudnn.benchmark = False
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so this only affects
# child processes launched from here, not the already-running kernel.
os.environ['PYTHONHASHSEED'] = str(seed)

In [None]:
print(torch.__version__)

In [None]:
!nvidia-smi -L
!cat /proc/meminfo
# !/usr/local/cuda/bin/nvcc --version
# !nvidia-smi

In [None]:
print(os.cpu_count())

# Memory Approx Functions and classes

In [None]:
class MemoryApprox(object):
  """
  Image transform that injects DRAM errors by XOR-ing a pre-generated error
  mask onto the incoming image.

  A refresh interval of 1 disables the transform: no mask is opened and
  images are returned untouched.

  Args:
  :param refresh_interval: DRAM refresh interval, selects which error mask is used
  :param mask_dir: root directory that holds the error-mask folders
  :param img_size: image size (int, or a tuple whose first entry is used since
                   width and height are assumed equal); selects the mask file
  :return: when called, the error-injected image in PIL format
  """

  def __init__(self, refresh_interval: int, mask_dir: str, img_size):
    self.refresh_interval = refresh_interval
    # A tuple is collapsed to its first entry (square images assumed)
    side = img_size[0] if isinstance(img_size, tuple) else img_size

    if refresh_interval == 1:
      # Nothing to load for this interval
      return

    # <mask_dir>/mask_dram1_refint_<RI>_fm_2/error_mask_ri<RI>_<size>
    mask_folder = 'mask_dram1_refint_{}_fm_2'.format(str(refresh_interval))
    mask_file = 'error_mask_ri{}_{}'.format(refresh_interval, side)
    # The mask is opened once and reused for every image of the dataset
    self.error_mask = Image.open(os.path.join(mask_dir, mask_folder, mask_file))

  def __call__(self, image):
    if self.refresh_interval == 1:
      return image

    assert image.size == self.error_mask.size, "Image size and Error mask size differs"
    # Bitwise XOR of the (already resized/cropped) image with the mask
    # TODO: VERIFY IF CROP CAN BE DONE SEPARATELY AFTER DRAM ERROR
    corrupted = np.bitwise_xor(np.asarray(image), np.asarray(self.error_mask))
    # Back from numpy.ndarray to a PIL Image
    return TF.to_pil_image(corrupted)

# Model

In [None]:
# Inherit Model from "torch.nn.Module"
class Model(nn.Module):
    """Base ``nn.Module`` with helpers to save and load its ``state_dict``.

    Args:
        name: identifier of the model; used as the sub-directory name by
            :meth:`save` and in error messages.
    """

    def __init__(self, name):
        super(Model, self).__init__()
        self.name = name

    def save(self, path, epoch=0):
        """Write the state dict to ``<path>/<name>/model-<epoch:05d>.pth``.

        Args:
            path: root directory; ``<path>/<name>`` is created if missing.
            epoch: epoch number encoded (zero-padded to 5 digits) in the file name.
        """
        complete_path = os.path.join(path, self.name)
        # exist_ok avoids the check-then-create race of exists() + makedirs()
        os.makedirs(complete_path, exist_ok=True)
        checkpoint_name = "model-{}.pth".format(str(epoch).zfill(5))
        torch.save(self.state_dict(), os.path.join(complete_path, checkpoint_name))

    def save_results(self, path, data):
        """Hook for subclasses; the base class does not implement it.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Model subclass must implement this method.")

    def load(self, path, modelfile=None):
        """Load a state dict from ``path`` into this module.

        Args:
            path: directory that contains the checkpoint file(s). Unlike
                :meth:`save`, the model name is NOT appended to it.
            modelfile: file name inside ``path``. When ``None`` the entry of
                ``path`` with the lexicographically greatest name is used
                (the newest one when names are zero-padded as in :meth:`save`).

        Raises:
            IOError: if ``path`` does not exist, or if no ``modelfile`` is given
                and ``path`` is empty.
        """
        complete_path = path
        if not os.path.exists(complete_path):
            raise IOError("{} directory does not exist in {}".format(self.name, path))

        if modelfile is None:
            model_files = glob.glob(complete_path + "/*")
            if not model_files:
                # max() on an empty list would raise an unhelpful ValueError
                raise IOError("No model files found in {}".format(complete_path))
            mf = max(model_files)
        else:
            mf = os.path.join(complete_path, modelfile)

        # Always remap storages onto the active device, so checkpoints saved on
        # a GPU that is absent here (or on any GPU when running on CPU) load.
        self.load_state_dict(torch.load(mf, map_location=device))

# MVCNN and SVCNN

In [None]:
# from enum import unique
# Modified MVCNN code for feature map compression

def compress_tensor(input, method='zfp', compr_knob=2):
  """
  Round-trip a tensor through a floating-point compressor (compress, then
  decompress) to mimic approximate communication.

  Args:
  :param input: tensor to compress; it is moved to the CPU for the round trip
  :param method: 'fpzip' (compr_knob is passed as `precision`) or
                 'zfp' (compr_knob is passed as `tolerance`)
  :param compr_knob: compression setting forwarded to the selected library
  :return: tuple (decompressed tensor with the shape of `input`, moved to the
           global `device`; size of `input` in bytes; size of the compressed
           stream in bytes; sum of absolute differences between `input` and
           the decompressed tensor)
  :raises ValueError: if `method` is neither 'fpzip' nor 'zfp'
  """
  # Fail early with a clear message instead of an UnboundLocalError later on
  if method not in ('fpzip', 'zfp'):
    raise ValueError("Unsupported compression method: {!r} (expected 'fpzip' or 'zfp')".format(method))

  input_shape = input.shape

  # Original size in bytes, independent of the tensor's dtype
  original_size = input.element_size() * input.nelement()

  # Compression and decompression work on numpy arrays
  input = input.cpu() # Move tensor to CPU for computation
  # Callers pass strided views (slices of transposed tensors); a contiguous
  # buffer keeps the layout handed to the compressors well-defined.
  # NOTE(review): whether fpzip/pyzfp accept strided input is not confirmed here.
  data = input.detach().contiguous().numpy()

  # Compress and decompress (lossless or lossy)
  if method == 'fpzip':
    compressed_bytes = fpzip.compress(data, precision=compr_knob, order='C')
    data_again = fpzip.decompress(compressed_bytes, order='C')
  else:  # 'zfp'
    parallel = True   # Not passed through function, local parameter, not frequently changed
    compressed_bytes = pyzfp.compress(data, tolerance=compr_knob, parallel=parallel)
    data_again = pyzfp.decompress(compressed_bytes, data.shape, data.dtype, tolerance=compr_knob)

  # Back to a tensor with the caller's shape
  decompressed_input = torch.tensor(data_again).view(input_shape)

  # Sum of Absolute Differences
  SAD = torch.sum(torch.abs(input - decompressed_input)).item()

  # Send back to CUDA (if available)
  decompressed_input = decompressed_input.to(device)

  return decompressed_input, original_size, len(compressed_bytes), SAD

def _get_non_dom_knob (inp_knob_list, knob_dom):
  knob_list = inp_knob_list.copy()
  if knob_dom in knob_list:
    knob_list.remove(knob_dom)
    if (len(knob_list) == 0):
      return knob_dom
    else:
      return knob_list[0]
  else:
    print('Invalid dominant node knob')
    return

class MVCNN_approx_opt(Model):
    """Multi-view CNN that mimics approximate memory, compute and communication.

    The network is used as three consecutive parts, ``net_1`` -> ``net_2`` ->
    ``net_3``, taken from already-built models. Views are max-pooled either
    before ``net_2`` (ResNets, or ``part_point < feature_extractors``) or after
    ``net_2`` (otherwise). Which weights process a view is chosen from its
    ``refresh_interval`` / ``sparsity`` knobs:

    * all views share one knob pair and it is accurate (refresh interval 0 or 1,
      sparsity 0): the parts of ``model`` are used and the accurate checkpoint
      is loaded from ``accurate_model_dir``;
    * all views share one approximate knob pair: the parts of ``model_dom``;
    * at most two distinct values per knob: views matching
      (``ri_dom``, ``sp_dom``) use ``model_dom``, the others ``model_nondom``;
    * otherwise: a checkpoint is loaded from ``approx_model_dir`` for every view
      on every forward pass.

    Optionally the per-view features are round-tripped through
    ``compress_tensor`` before pooling (``compr_method`` 'fpzip' or 'zfp').

    Args:
        name: model name forwarded to ``Model``.
        model: model providing ``feature_extractors`` and, in the accurate
            case, ``net_1`` / ``net_2`` / ``net_3``.
        accurate_models_dict: maps ``'<cnn_name>_PP_<part_point>'`` to the
            accurate checkpoint file name.
        nclasses: number of classes (stored only).
        cnn_name: backbone name; names starting with ``'resnet'`` always pool
            before ``net_2``.
        num_views: number of views per 3D model; the input batch is expected to
            hold ``num_views`` consecutive images per model.
        part_point: partition point; compared with ``model.feature_extractors``
            and used in checkpoint file names.
        compr_method: ``'None'`` (no compression), ``'fpzip'`` or ``'zfp'``.
        compr_knob: per-view compression knobs (indexable by view).
        refresh_interval: per-view DRAM refresh intervals; defaults to 12 zeros.
        approx_model_dir: directory with the approximate checkpoints.
        accurate_model_dir: directory with the accurate checkpoints.
        sparsity: per-view sparsities; defaults to 12 zeros.
        ri_dom: refresh interval of the dominant group.
        sp_dom: sparsity of the dominant group.
        model_dom: pre-loaded model of the dominant group.
        model_nondom: pre-loaded model of the non-dominant group.

    Raises:
        ValueError: if a pre-loaded model needed by the selected knob
            configuration is ``None``.
    """

    def __init__(self, name, model, accurate_models_dict, nclasses=40, cnn_name='alexnet',
                 num_views=12, part_point=5, compr_method = 'None', compr_knob=None,
                 refresh_interval=None,
                 approx_model_dir='/content/drive/MyDrive/Arghadip/MVCNN/jongchyisu/Trained_Models/Approx_Mem',
                 accurate_model_dir='/content/drive/MyDrive/Arghadip/MVCNN/jongchyisu/Trained_Models/Sensor_Subsampling_Baselines/w_gen_pool/epoch_15',
                 sparsity=None, ri_dom=0, sp_dom=0.0,
                 model_dom=None, model_nondom=None):
        super(MVCNN_approx_opt, self).__init__(name)

        # None sentinels replace the former mutable list defaults (same values)
        if refresh_interval is None:
            refresh_interval = [0] * 12
        if sparsity is None:
            sparsity = [0.0] * 12

        self.classnames=['airplane','bathtub','bed','bench','bookshelf','bottle','bowl','car','chair',
                         'cone','cup','curtain','desk','door','dresser','flower_pot','glass_box',
                         'guitar','keyboard','lamp','laptop','mantel','monitor','night_stand',
                         'person','piano','plant','radio','range_hood','sink','sofa','stairs',
                         'stool','table','tent','toilet','tv_stand','vase','wardrobe','xbox']

        self.nclasses = nclasses
        self.num_views = num_views
        self.part_point = part_point
        self.compr_knob = compr_knob
        self.compr_method = compr_method
        self.feature_extractors = model.feature_extractors
        # Kept so forward() can build checkpoint names without relying on a
        # notebook-level global called `cnn_name`
        self.cnn_name = cnn_name

        # Per-view knobs and the distinct values among them
        self.refresh_interval = refresh_interval
        self.approx_model_dir = approx_model_dir
        self.accurate_model_dir = accurate_model_dir
        self.accurate_models_dict = accurate_models_dict
        self.unique_intervals_list = list(set(self.refresh_interval))

        self.sparsity = sparsity
        self.unique_sparsities_list = list(set(self.sparsity))

        # Mean and standard deviation of the dataset used (NOT used anywhere)
        self.mean = torch.FloatTensor([0.485, 0.456, 0.406]).to(device)
        self.std = torch.FloatTensor([0.229, 0.224, 0.225]).to(device)

        self.use_resnet = cnn_name.startswith('resnet')

        # Knob settings of the dominant group
        self.ri_dom = ri_dom
        self.sp_dom = sp_dom

        self.file_name = None

        if self._all_views_share_knobs():
            shared_ri = self.unique_intervals_list[0]
            shared_sp = self.unique_sparsities_list[0]
            if ((shared_ri == 0) or (shared_ri == 1)) and (shared_sp == 0):
                # Accurate operation: take the parts of `model`, then load the
                # accurate weights into this module
                accurate_file = self.accurate_models_dict[cnn_name + '_PP_' + str(self.part_point)]
                self.net_1 = model.net_1
                self.net_2 = model.net_2
                self.net_3 = model.net_3
                self.load(self.accurate_model_dir, accurate_file)
                self.eval()
                self.file_name = accurate_file
            else:
                # Approximate but identical for all views: one pre-loaded model
                self._require_model(model_dom, 'model_dom')
                if (shared_ri == 0) and (not (shared_sp == 0)):
                    # RI = 0 and 1 has no errors, so the dram1 checkpoint name is used
                    named_ri = 1
                else:
                    named_ri = shared_ri
                self.file_name = self._approx_checkpoint_name(cnn_name, part_point, named_ri, shared_sp)
                self.net_1 = model_dom.net_1
                self.net_2 = model_dom.net_2
                self.net_3 = model_dom.net_3
        elif self._views_form_two_groups():
            # Group-wise identical knobs: a dominant and a non-dominant model
            self._require_model(model_dom, 'model_dom')
            self._require_model(model_nondom, 'model_nondom')
            ri_nondom = _get_non_dom_knob (self.unique_intervals_list, self.ri_dom)
            sp_nondom = _get_non_dom_knob (self.unique_sparsities_list, self.sp_dom)
            modelfile1 = self._approx_checkpoint_name(cnn_name, part_point, self.ri_dom, self.sp_dom)
            modelfile2 = self._approx_checkpoint_name(cnn_name, part_point, ri_nondom, sp_nondom)
            self.file_name = modelfile1 + '_&_' + modelfile2

            if self._pools_before_flatten():
                # net_2 runs after pooling and is shared by both groups; the
                # non-dominant model's net_2 ends up being the one used
                self.net_1_dom = model_dom.net_1
                self.net_2 = model_dom.net_2
                self.net_1_nondom = model_nondom.net_1
                self.net_2 = model_nondom.net_2
            else:
                self.net_1_dom = model_dom.net_1
                self.net_2_dom = model_dom.net_2
                self.net_1_nondom = model_nondom.net_1
                self.net_2_nondom = model_nondom.net_2

            # no approximation is applied to net_3 (hence it does not matter
            # from where it is loaded, it's same for all)
            self.net_3 = model_nondom.net_3
        else:
            # More than two groups: create the basic backbone now; the per-view
            # weights are loaded inside forward()
            backbone, modelfile_mvcnn = self._load_backbone(self.unique_intervals_list[0],
                                                            self.unique_sparsities_list[0])
            self.file_name = modelfile_mvcnn
            self.net_1 = backbone.net_1
            self.net_2 = backbone.net_2
            self.net_3 = backbone.net_3

    @staticmethod
    def _require_model(candidate, arg_name):
        """Raise ValueError when a pre-loaded model that is needed was not given."""
        if candidate is None:
            raise ValueError("'{}' must be provided for this knob configuration".format(arg_name))

    @staticmethod
    def _approx_checkpoint_name(cnn_name, part_point, refresh_interval, sparsity):
        """Build the file name of an approximate checkpoint for one knob pair."""
        return ('approx_' + cnn_name + '_PP_' + str(part_point) + '_dram' + str(refresh_interval)
                + '_sparsity_' + str(sparsity) + '.pth.tar')

    def _all_views_share_knobs(self):
        """True when every view has the same refresh interval and sparsity."""
        return (len(self.unique_intervals_list) == 1) and (len(self.unique_sparsities_list) == 1)

    def _views_form_two_groups(self):
        """True when each knob takes at most two distinct values over the views."""
        return (len(self.unique_intervals_list) <= 2) and (len(self.unique_sparsities_list) <= 2)

    def _pools_before_flatten(self):
        """True when views are pooled on net_1's output (before net_2)."""
        return (self.use_resnet) or (self.part_point < self.feature_extractors)

    def _is_dominant_view(self, view_idx):
        """True when the view's knobs equal the dominant group's knobs."""
        return (self.refresh_interval[view_idx] == self.ri_dom) and (self.sparsity[view_idx] == self.sp_dom)

    def _load_backbone(self, refresh_interval, sparsity):
        """Create a distiller model, load the checkpoint of the given knob pair
        from ``approx_model_dir`` and return ``(model in eval mode, file name)``."""
        modelfile = self._approx_checkpoint_name(self.cnn_name, self.part_point, refresh_interval, sparsity)
        backbone = distiller.models.create_model(pretrained=True, dataset='modelnet',
                                                 arch= self.cnn_name + '_mvcnn', parallel=False, device_ids=-1)
        load_checkpoint(backbone, os.path.join(self.approx_model_dir, modelfile))
        backbone.eval()
        return backbone, modelfile

    def _forward_shared(self, x, pool_first):
        """Run all views through the single shared net_1 (and net_2 when the
        features are flattened before pooling); returns features grouped per
        model as (N, V, ...)."""
        num_models = x.shape[0] // self.num_views
        y1 = self.net_1(x)
        if pool_first:
            return y1.view((num_models, self.num_views, y1.shape[-3], y1.shape[-2], y1.shape[-1]))
        y2 = self.net_2(y1.view(y1.shape[0], -1))
        return y2.view((num_models, self.num_views, y2.shape[-1]))

    def _forward_per_view(self, x, pool_first):
        """Run every view through the sub-networks selected by its knobs;
        returns features grouped per model as (N, V, ...)."""
        # (N*V,C,H,W) -> (N,V,C,H,W) -> (V,N,C,H,W)
        x = x.view((x.shape[0] // self.num_views, self.num_views, x.shape[-3], x.shape[-2], x.shape[-1]))
        x = torch.transpose(x, 0, 1)
        two_groups = self._views_form_two_groups()

        per_view = []
        for i in range(self.num_views):
            if two_groups:
                if self._is_dominant_view(i):
                    net_1 = self.net_1_dom
                    net_2 = None if pool_first else self.net_2_dom
                else:
                    net_1 = self.net_1_nondom
                    net_2 = None if pool_first else self.net_2_nondom
            else:
                # with approx computing: load this view's checkpoint
                backbone, _ = self._load_backbone(self.refresh_interval[i], self.sparsity[i])
                self.net_1 = backbone.net_1
                net_1 = self.net_1
                if pool_first:
                    net_2 = None
                else:
                    self.net_2 = backbone.net_2
                    net_2 = self.net_2

            out = net_1(x[i])
            if net_2 is not None:
                out = net_2(out.view(out.shape[0], -1))  # Flattened and fed to 'net_2'
            per_view.append(out)

        # list of V tensors -> (V,N,...) -> (N,V,...)
        return torch.transpose(torch.stack(per_view), 0, 1)

    def _compress_views(self, feats, zfp_per_sample):
        """Round-trip every view's features through ``compress_tensor``.

        ``feats`` is (N, V, ...). With ``zfp_per_sample`` the 'zfp' method is
        applied to each entry along a view's first dimension separately (zfp
        works with 3D tensors); otherwise a view is compressed in one call.
        Returns ``(features as (N, V, ...), original bytes, compressed bytes,
        SAD)`` summed over all views.
        """
        if self.compr_method not in ('fpzip', 'zfp'):
            raise ValueError("Unsupported compr_method: {!r} (expected 'None', 'fpzip' or 'zfp')".format(self.compr_method))

        feats = torch.transpose(feats, 0, 1)  # (V,N,...): views can have different knobs
        compressed_views = []
        total_orig = 0
        total_compressed = 0
        total_sad = 0
        for i in range(self.num_views):
            knob = self.compr_knob[i]
            if self.compr_method == 'zfp' and zfp_per_sample:
                pieces = []
                original_size = 0
                compressed_size = 0
                SAD = 0
                for j in range(feats[i].shape[0]):
                    piece, orig_piece, compr_piece, sad_piece = compress_tensor(feats[i][j], 'zfp', knob)
                    pieces.append(piece)
                    original_size += orig_piece
                    compressed_size += compr_piece
                    SAD += sad_piece
                view_compressed = torch.stack(pieces)
            else:
                view_compressed, original_size, compressed_size, SAD = compress_tensor(feats[i], self.compr_method, knob)
            compressed_views.append(view_compressed)
            total_orig += original_size
            total_compressed += compressed_size
            total_sad += SAD

        # Back to the original (N,V,...) layout
        return torch.transpose(torch.stack(compressed_views), 0, 1), total_orig, total_compressed, total_sad

    # Forward function defines how the inputs will pass through the network
    def forward(self, x):
        """Classify a batch of multi-view images.

        Args:
            x: batch whose first dimension holds ``num_views`` consecutive
               images per 3D model.

        Returns:
            tuple ``(net_3 output, original feature bytes, compressed feature
            bytes, SAD between original and decompressed features)``; the last
            three are 0 when ``compr_method`` is ``'None'``.

        Raises:
            ValueError: if ``compr_method`` is not 'None', 'fpzip' or 'zfp'.
        """
        pool_first = self._pools_before_flatten()

        ################## Approximate Memory Mimic ##########################
        if self._all_views_share_knobs():
            feats = self._forward_shared(x, pool_first)
        else:
            feats = self._forward_per_view(x, pool_first)

        ################## Approximate communication mimic ####################
        batch_mem_orig = 0
        batch_mem_compressed = 0
        SAD_batch = 0
        if self.compr_method != 'None':
            feats, batch_mem_orig, batch_mem_compressed, SAD_batch = self._compress_views(feats, zfp_per_sample=pool_first)

        pooled = torch.max(feats, 1)[0]   # MAX pooling over the views
        if pool_first:
            y2 = self.net_2(pooled)
            logits = self.net_3(y2.view(y2.shape[0], -1))
        else:
            logits = self.net_3(pooled)
        return logits, batch_mem_orig, batch_mem_compressed, SAD_batch

In [None]:
# Per-channel mean and standard deviation, kept on the active device
# NOTE(review): the earlier comment says these were derived from ModelNet40;
# the values coincide with the commonly used ImageNet statistics - confirm.
mean = torch.FloatTensor([0.485, 0.456, 0.406]).to(device)
std = torch.FloatTensor([0.229, 0.224, 0.225]).to(device)

# flip: return the tensor with the order of its elements reversed along
# the given dimension (negative dimensions count from the end)
def flip(x, dim):
    """Return a copy of ``x`` with the element order reversed along ``dim``.

    Args:
        x: input tensor (any device; need not be contiguous).
        dim: dimension to reverse; negative values count from the end.

    Returns:
        A new tensor with the same shape, dtype and device as ``x``.
    """
    # torch.flip replaces the hand-rolled view + index-select: it also works on
    # non-contiguous tensors and keeps everything on x's own device.
    return torch.flip(x, dims=[dim])


class SVCNN(Model):
    """Single-view CNN whose torchvision backbone is split into three stages.

    The backbone is cut into ``net_1`` and ``net_2`` according to ``part_point``;
    ``net_3`` is the remaining part and ends in the ``nclasses``-way linear layer.
    ``forward`` runs the three stages back to back and flattens the activations
    where the convolutional part hands over to the fully connected part.

    Args:
        name: Forwarded to the ``Model`` base class (dummy parameter here).
        nclasses: Number of outputs of the final linear layer.
        pretraining: Passed as ``pretrained=`` to the torchvision constructor.
        cnn_name: 'alexnet', 'vgg11', 'vgg16', 'resnet18', 'resnet34' or 'resnet50'.
        part_point: Partition index. While it is smaller than
            ``feature_extractors`` it is the index of the last backbone layer
            placed in ``net_1``; otherwise ``net_1`` holds the whole feature
            extractor and (AlexNet/VGG only) ``net_2`` takes the first
            ``min(part_point - feature_extractors, 6)`` classifier layers.

    Raises:
        ValueError: If ``cnn_name`` is not one of the supported architectures.
    """

    # HARDCODED boundary between feature extractor and classifier for every
    # AlexNet/VGG backbone. Check the values and update manually when adding one.
    _FEATURE_EXTRACTORS = {'alexnet': 12, 'vgg11': 20, 'vgg16': 30}
    # Input width of the final fully connected layer of every ResNet variant.
    _RESNET_FC_INPUTS = {'resnet18': 512, 'resnet34': 512, 'resnet50': 2048}
    # Same boundary for ResNets, expressed as an index into ``resnet_modules``.
    _RESNET_FEATURE_EXTRACTORS = 8

    def __init__(self, name, nclasses=40, pretraining=True, cnn_name='vgg11', part_point=5):
        super(SVCNN, self).__init__(name)
        # "name" is a dummy parameter

        # All the classes in ModelNet40
        self.classnames=['airplane','bathtub','bed','bench','bookshelf','bottle','bowl','car','chair',
                         'cone','cup','curtain','desk','door','dresser','flower_pot','glass_box',
                         'guitar','keyboard','lamp','laptop','mantel','monitor','night_stand',
                         'person','piano','plant','radio','range_hood','sink','sofa','stairs',
                         'stool','table','tent','toilet','tv_stand','vase','wardrobe','xbox']

        self.nclasses = nclasses
        self.pretraining = pretraining
        self.cnn_name = cnn_name
        self.part_point = part_point

        # use_resnet = True if "cnn_name" starts with "resnet"
        self.use_resnet = cnn_name.startswith('resnet')

        # Reject unsupported architectures up front instead of failing later
        # with an AttributeError on a half-initialised module.
        supported = self._RESNET_FC_INPUTS if self.use_resnet else self._FEATURE_EXTRACTORS
        if cnn_name not in supported:
            raise ValueError("Unsupported cnn_name: {!r}".format(cnn_name))

        if self.use_resnet:
            self.feature_extractors = self._RESNET_FEATURE_EXTRACTORS
        else:
            self.feature_extractors = self._FEATURE_EXTRACTORS[cnn_name]

        # Mean and standard deviation of the dataset used (NOT used anywhere)
        self.mean = Variable(torch.FloatTensor([0.485, 0.456, 0.406]), requires_grad=False).to(device)
        self.std = Variable(torch.FloatTensor([0.229, 0.224, 0.225]), requires_grad=False).to(device)

        if self.use_resnet:
            self._build_resnet()
        else:
            self._build_plain_cnn()

    def _build_resnet(self):
        """Create ``net``, ``resnet_modules`` and the three stages for a ResNet."""
        # Load the model from the torchvision model zoo; "pretrained" follows
        # the constructor argument.
        self.net = getattr(models, self.cnn_name)(pretrained=self.pretraining)
        # Replace the final classifier (fc) layer so that it has one output per class
        self.net.fc = nn.Linear(self._RESNET_FC_INPUTS[self.cnn_name], self.nclasses)

        self.resnet_modules = [
            ('conv1', self.net.conv1),
            ('bn1', self.net.bn1),
            ('relu', self.net.relu),
            ('maxpool', self.net.maxpool),
            ('layer1', self.net.layer1),
            ('layer2', self.net.layer2),
            ('layer3', self.net.layer3),
            ('layer4', self.net.layer4),
            ('avgpool', self.net.avgpool)
        ]

        # Partitioning: everything up to and including "part_point" goes to
        # net_1. Once part_point reaches the feature-extractor boundary net_1
        # holds all nine modules and net_2 is an empty (identity) Sequential.
        split = min(self.part_point, self.feature_extractors) + 1
        self.net_1 = nn.Sequential(OrderedDict(self.resnet_modules[:split]))
        self.net_2 = nn.Sequential(OrderedDict(self.resnet_modules[split:]))
        self.net_3 = self.net.fc

    def _build_plain_cnn(self):
        """Create the three stages for AlexNet/VGG (``features`` + ``classifier``)."""
        # Instantiate the backbone once and slice it, rather than building (and
        # possibly downloading) one full model per stage. It is kept in a local
        # variable so that only net_1/net_2/net_3 are registered as sub-modules.
        backbone = getattr(models, self.cnn_name)(pretrained=self.pretraining)

        if self.part_point < self.feature_extractors:
            # Cut inside the feature extractor; the classifier stays in one piece
            self.net_1 = backbone.features[0:self.part_point+1]
            self.net_2 = backbone.features[self.part_point+1:self.feature_extractors+1]
            self.net_3 = backbone.classifier
        else:
            # Cut inside the classifier; the last layer always stays in net_3
            split = min((self.part_point-self.feature_extractors), 6)
            self.net_1 = backbone.features
            self.net_2 = backbone.classifier[0:split]
            self.net_3 = backbone.classifier[split:]

        # The last classifier layer is addressed by its key '6' rather than by
        # position, as the original code did; this assumes a sliced Sequential
        # keeps the keys of its parent (torch.nn.Sequential slicing).
        self.net_3._modules['6'] = nn.Linear(4096, self.nclasses)

    # Forward function for the inputs
    def forward(self, x):
        """Run ``x`` through ``net_1``, ``net_2`` and ``net_3`` and return the output of ``net_3``."""
        y1 = self.net_1(x)
        if (self.use_resnet) or (self.part_point < self.feature_extractors):
            # net_2 is still convolutional: flatten its output before net_3.
            # The first (batch) dimension is kept and (-1) merges all the others,
            # e.g. [N, 256, 6, 6] becomes [N, 9216].
            y2 = self.net_2(y1)
            return self.net_3(y2.view(y2.shape[0],-1))
        # net_2 already belongs to the classifier: flatten before net_2 instead
        y2 = self.net_2(y1.view(y1.shape[0],-1))
        return self.net_3(y2)


# Batchwise MultiImgDataset

In [None]:
# Load multiview dataset
class Multiview_Dataset_Batch(torch.utils.data.Dataset):
    """Multi-view dataset in which one item bundles all the views of one 3D model.

    Image files are collected class by class from
    ``<parent>/<classname>/<set>/*.png``, where ``<set>`` is the last component
    of ``root_dir`` and ``<parent>`` is ``root_dir`` without its last two
    components. In test mode every view is down-sampled and up-sampled again
    (sensor subsampling) and passed through ``MemoryApprox`` with per-view settings.
    """

    def __init__(self, root_dir, scale_aug=False, rot_aug=False, test_mode=True, \
                 num_models=0, num_views=12, shuffle=False, subsampling_factor = [1,1,1,1,1,1,1,1,1,1,1,1], refresh_interval=[1,1,1,1,1,1,1,1,1,1,1,1], \
                 mask_dir='/content/drive/MyDrive/Arghadip/MVCNN/AxIS_ref/Approx_DRAM/error_mask',interpolation=Image.NEAREST, \
                 start_idx=0, end_idx=2468):
        # Available classes in ModelNet40 dataset
        self.classnames=['airplane','bathtub','bed','bench','bookshelf','bottle','bowl','car','chair',
                         'cone','cup','curtain','desk','door','dresser','flower_pot','glass_box',
                         'guitar','keyboard','lamp','laptop','mantel','monitor','night_stand',
                         'person','piano','plant','radio','range_hood','sink','sofa','stairs',
                         'stool','table','tent','toilet','tv_stand','vase','wardrobe','xbox']

        # Plain copies of the constructor arguments
        self.root_dir = root_dir
        self.scale_aug = scale_aug
        self.rot_aug = rot_aug
        self.test_mode = test_mode
        self.num_views = num_views
        self.interpolation = interpolation

        # Per-view edge length (in pixels) of the down-sampled image
        self.image_size = [round(224/factor) for factor in subsampling_factor]

        # Approximate-memory settings (one refresh interval per view)
        self.refresh_interval = refresh_interval
        self.mask_dir = mask_dir

        # Last path component selects the split ("train" or "test")
        split_name = root_dir.split('/')[-1]
        # Directory that holds one sub-directory per class
        dataset_root = root_dir.rsplit('/',2)[0]

        # Keep every "view_stride"-th file to select a subset of the views
        # (assumes 12 rendered views per model - TODO confirm)
        view_stride = int(12/self.num_views) # 12 6 4 3 2 1

        self.filepaths = []
        self.num_models_per_category = []
        for class_name in self.classnames:
            # Sorted list of all the images of this class in the chosen split
            class_files = sorted(glob.glob(dataset_root+'/'+class_name+'/'+split_name+'/*.png'))
            class_files = class_files[::view_stride]

            # num_models == 0 means "use everything"; otherwise keep at most
            # "num_models" files of this class
            if num_models != 0:
                class_files = class_files[:min(num_models,len(class_files))]

            self.filepaths.extend(class_files)
            self.num_models_per_category.append(int(len(class_files)/num_views))

        if shuffle==True:
            # Permute whole models, keeping the views of one model together
            model_order = np.random.permutation(int(len(self.filepaths)/num_views))
            shuffled_paths = []
            for model_idx in model_order:
                shuffled_paths.extend(self.filepaths[model_idx*num_views:(model_idx+1)*num_views])
            self.filepaths = shuffled_paths

        # Restrict the dataset to the models in [start_idx, end_idx)
        self.filepaths = self.filepaths[start_idx*num_views:end_idx*num_views]

        # Outside test mode a single fixed transform is used for every view:
        # random horizontal flip, conversion to tensor, normalisation
        if not self.test_mode:
            self.transform = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])

    def _build_test_transform(self, view):
        """Build the test-mode pre-processing pipeline for view number ``view``."""
        return transforms.Compose([
            # Downsample, then upsample back to 224: mimics the effect of sensor
            # subsampling. Both resizes must stay before ToTensor.
            transforms.Resize(self.image_size[view], interpolation=self.interpolation),
            transforms.Resize(224, interpolation=self.interpolation),
            # Approx Memory on image
            MemoryApprox(self.refresh_interval[view], self.mask_dir, 224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                std=[0.229, 0.224, 0.225])
        ])

    # Number of models (not images) in the dataset
    def __len__(self):
        return int(len(self.filepaths)/self.num_views)

    # Return all the views of the "idx"-th model
    def __getitem__(self, idx):
        first = idx*self.num_views
        # The class name is the third path component from the end of the first view
        class_id = self.classnames.index(self.filepaths[first].split('/')[-3])

        views = []
        for view in range(self.num_views):
            # Load the 2D image of this view and convert it to RGB
            image = Image.open(self.filepaths[first+view]).convert('RGB')
            if self.test_mode:
                # The pipeline depends on the view, so it is rebuilt every time
                self.transform = self._build_test_transform(view)

            # Apply the transform on the image (Pre-processing)
            if self.transform:
                image = self.transform(image)
            views.append(image)

        # Class id, stacked view tensors and the file paths of the views
        return (class_id, torch.stack(views), self.filepaths[first:(idx+1)*self.num_views])

### Sampler to load a batch with models from a particular class only
class DRAX_Classwise_Batch_Sampler(torch.utils.data.BatchSampler):
    """Batch sampler that yields consecutive index ranges sized class by class.

    Every class contributes as many full batches of ``batch_size`` as fit into
    its count, followed by one smaller batch for the remainder (if any).
    A batch holds a single class only if the dataset indices are grouped by
    class in the order of ``class_counts`` - presumably the case; verify
    against the dataset in use.
    """

    def __init__(self, class_counts, batch_size):
        self.class_counts = class_counts
        self.batch_size = batch_size
        self.indices = list(range(sum(self.class_counts)))
        self.batch_sizes = self._genrate_batch_sizes()

    def _genrate_batch_sizes(self):
        """Return the list of (non-zero) batch sizes, in iteration order."""
        sizes = []
        for count in self.class_counts:
            full_batches = int(count / self.batch_size)
            leftover = int(count % self.batch_size)
            sizes += [self.batch_size] * full_batches
            # Classes that divide evenly (or are empty) get no remainder batch
            if leftover != 0:
                sizes.append(leftover)
        return sizes

    def __iter__(self):
        start = 0
        for size in self.batch_sizes:
            stop = start + size
            yield self.indices[start:stop]
            start = stop

    def __len__(self):
        return len(self.batch_sizes)

# SingleImgDataset

In [None]:
# Load singleview dataset
# Treats all the images as individual, i.e. different views of a single 3D object is treated as independent images
class SingleImgDataset(torch.utils.data.Dataset):
    """Single-view dataset: every rendered image is an independent sample.

    Files matching ``<parent>/<classname>/<set>/*shaded*.png`` are collected,
    where ``<set>`` is the last component of ``root_dir`` and ``<parent>`` is
    ``root_dir`` without its last two components.
    """

    def __init__(self, root_dir, scale_aug=False, rot_aug=False, test_mode=False, \
                 num_models=0, num_views=12):
        # "num_views" is NOT used
        self.classnames=['airplane','bathtub','bed','bench','bookshelf','bottle','bowl','car','chair',
                         'cone','cup','curtain','desk','door','dresser','flower_pot','glass_box',
                         'guitar','keyboard','lamp','laptop','mantel','monitor','night_stand',
                         'person','piano','plant','radio','range_hood','sink','sofa','stairs',
                         'stool','table','tent','toilet','tv_stand','vase','wardrobe','xbox']
        self.root_dir = root_dir
        self.scale_aug = scale_aug
        self.rot_aug = rot_aug
        self.test_mode = test_mode  # NOT USED, dummy argument

        split_name = root_dir.split('/')[-1]
        dataset_root = root_dir.rsplit('/',2)[0]

        self.filepaths = []
        for class_name in self.classnames:
            matches = sorted(glob.glob(dataset_root+'/'+class_name+'/'+split_name+'/*shaded*.png'))
            # num_models == 0 keeps every file; otherwise keep at most
            # "num_models" files of this class
            if num_models != 0:
                matches = matches[:min(num_models,len(matches))]
            self.filepaths.extend(matches)

        # Random horizontal flip, conversion to tensor, normalisation
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    # Number of images in the dataset
    def __len__(self):
        return len(self.filepaths)

    # Return (class id, image, file path) of the "idx"-th image
    def __getitem__(self, idx):
        path = self.filepaths[idx]
        # The class name is the third path component from the end
        class_id = self.classnames.index(path.split('/')[-3])

        image = Image.open(path).convert('RGB')
        if self.transform:
            image = self.transform(image)

        return (class_id, image, path)

# Logging Helpers

In [None]:
# Generate the group-wise knobs based significance
def generate_groupwise_knobs (num_views, d, knob_dom, knob_non_dom, significance, mode='best', rng_seed=None):
  """Build per-view knob lists in which ``d`` dominant views get ``knob_dom``.

  Args:
    num_views: Number of views; every generated knob list has this length.
    d: Number of dominant views. 0 means none, ``num_views`` means all; values
      outside ``[0, num_views]`` are treated like 0.
    knob_dom: Knob value assigned to the dominant views.
    knob_non_dom: Iterable of knob values for the non-dominant views; one knob
      list is generated per element.
    significance: List of per-view ranks (1 = least significant). Only used by
      the 'best' and 'worst' modes, which look ranks up with ``.index``.
    mode: How the dominant views are chosen when ``0 < d < num_views``:
      'best' (highest ranks), 'worst' (lowest ranks), 'alternate' (views
      0, 2, 4, ... clamped to the last view), 'random' (seeded shuffle).
      Any other value makes all views dominant.
    rng_seed: Seed for the 'random' mode. ``None`` (default) falls back to the
      module-level ``seed`` variable, as before.

  Returns:
    ``(knobs, dom_nodes)`` where ``knobs`` is a list of knob lists and
    ``dom_nodes`` is the list of 1-based dominant view numbers, or the string
    'NA' when there is no dominant view.

  Raises:
    ValueError: In 'best'/'worst' mode, if a required rank is not present in
      ``significance``.
  """
  all_views = list(range(num_views))
  dom_indices = []   # NO dominant (also covers d == 0)

  if 0 < d < num_views:
    if (mode == 'best'):
      # Start from the highest rank. This used to be hard-coded to 12, which
      # is the same value whenever "significance" holds 12 ranks.
      top_rank = len(significance)
      dom_indices = [significance.index(top_rank - offset) for offset in range(d)]
    elif (mode == 'worst'):
      dom_indices = [significance.index(offset + 1) for offset in range(d)]
    elif (mode == 'alternate'):
      dom_indices = [min(offset*2, num_views-1) for offset in range(d)]
    elif (mode == 'random'):
      # Re-seeds the global RNG, exactly like the previous implementation
      random.seed(seed if rng_seed is None else rng_seed)
      shuffled = list(all_views)
      random.shuffle(shuffled)
      # The previous loop removed items while indexing by position, so it
      # effectively picked shuffled[0], shuffled[2], ... and raised IndexError
      # once d exceeded half of the views. Taking the even positions first and
      # the odd positions afterwards reproduces the old picks and works for
      # every d.
      pick_order = shuffled[0::2] + shuffled[1::2]
      dom_indices = pick_order[:d]
    else:
      dom_indices = all_views   # All dominant
  elif d == num_views:
    dom_indices = all_views   # All dominant

  dominant = set(dom_indices)
  knobs = [[knob_dom if view in dominant else element for view in all_views]
           for element in knob_non_dom]

  if not dom_indices:
    return knobs, 'NA'
  # Report the dominant views as 1-based node numbers
  return knobs, [dom_index+1 for dom_index in dom_indices]

# Function to get the knob for the non_dominant group
def get_knob_non_dom(knobs, knob_dom):
  """Return a knob value from ``knobs`` that differs from ``knob_dom``.

  ``knob_dom`` itself is returned when it does not occur in ``knobs`` or when
  it is the only distinct value. With more than one other distinct value, the
  one returned is whichever the set iteration order puts first.
  """
  candidates = list(set(knobs))
  if knob_dom not in candidates:
    return knob_dom
  candidates.remove(knob_dom)
  return candidates[0] if len(candidates) != 0 else knob_dom

def write_run_summary(test_id, run_dir, name, cnn_name, num_views, n_models_test, part_point, batchSize, cur_batch_size, \
                      num_workers, \
                      load_svcnn, load_mvcnn, modelfile_svcnn, modelfile_mvcnn, subsampling_factor, interpolation, \
                      val_overall_acc, ss_mode, num_dom_nodes, dominant_nodes, compr_method, compr_knob, \
                      loader_mem_orig, loader_mem_compressed, SAD_loader, sf_dom, cf_dom, refresh_interval, ri_dom, \
                      sparsity, sp_dom,\
                      significance_category, significance, current_batch, running_test, total_iterations,\
                      correctly_classified_models):
  """Append one row describing a test run to ``<run_dir>/Raw_log.csv``.

  The row contains the run identification, the per-view approximation knobs
  (``SSF<i>``, ``CCF<i>``, ``RI<i>``, ``SP<i>``), the dominant/non-dominant
  knob values, the feature-map compression statistics and the accuracy.

  Args:
    test_id, running_test, total_iterations, current_batch: Run bookkeeping,
      logged as given. ``test_id == 'A1'`` with ``current_batch == 0`` forces a
      header row.
    run_dir: Directory of the log file (must already exist).
    name, cnn_name, num_views, part_point, batchSize, cur_batch_size,
      num_workers, load_svcnn, load_mvcnn, modelfile_svcnn, modelfile_mvcnn,
      interpolation, ss_mode, num_dom_nodes, dominant_nodes, compr_method,
      significance_category, significance, correctly_classified_models:
      Logged as given.
    n_models_test: Number of tested models; 0 is logged as 29616.
    subsampling_factor, compr_knob, refresh_interval, sparsity: Per-view knob
      sequences (one column per element).
    sf_dom, cf_dom, ri_dom, sp_dom: Knob values of the dominant group; the
      non-dominant values are derived with ``get_knob_non_dom``.
    loader_mem_orig, loader_mem_compressed, SAD_loader: Original size,
      compressed size and SAD of the feature maps.
    val_overall_acc: Accuracy; also logged relative to the baseline accuracy of
      ``cnn_name`` ('NA' when no baseline is known for that architecture).
  """
  # n_models_test == 0 is logged as this fixed count
  log_num_models_t = 29616 if (n_models_test == 0) else n_models_test

  image_size = [round(224/ssf) for ssf in subsampling_factor]

  sf_non_dom = get_knob_non_dom(subsampling_factor, sf_dom)
  cf_non_dom = get_knob_non_dom(compr_knob, cf_dom)
  ri_non_dom = get_knob_non_dom(refresh_interval, ri_dom)
  sp_non_dom = get_knob_non_dom(sparsity, sp_dom)

  if (compr_method == 'None') or (loader_mem_compressed == 0) or (loader_mem_orig == 0):
    SAD_norm = 0
    compr_ratio = 0
  else:
    SAD_norm = SAD_loader*2/loader_mem_orig
    compr_ratio = loader_mem_orig/loader_mem_compressed

  # Baseline accuracies dictionary
  baseline_acc = {
      'alexnet': 0.91572124,
      'vgg11': 0.9226094,
      'resnet34': 0.95097244
  }

  # Dictionary to log the parameters
  log_dict = {
              'test_name': name,
              'Serial_no.': running_test,
              'Total rows': total_iterations,
              'test id': test_id,
              'pooling point': part_point,

              'number_of_models_tested': log_num_models_t,
              'correctly_classified_models' : correctly_classified_models,

              'Batch Size (MVCNN)': batchSize,
              'Curr. batch size' : cur_batch_size,
              'Number of workers': num_workers,

              'Current_batch': current_batch,
              'Sig_category': significance_category,
              'significance_order': significance,

              'load_svcnn': load_svcnn,
              'load_mvcnn': load_mvcnn,
              'SVCNN model name': modelfile_svcnn,
              'MVCNN model name': modelfile_mvcnn,
              'image size': image_size,

              # Useful
              'cnn': cnn_name,
              'number of views': num_views,
  }

  # One column per view and knob type: SSF = sensor subsampling factor,
  # CCF = communication compression factor, RI = refresh interval (approx.
  # memory), SP = sparsity (approx. compute)
  per_view_knobs = (
      ('SSF', subsampling_factor),
      ('CCF', compr_knob),
      ('RI', refresh_interval),
      ('SP', sparsity),
  )
  for prefix, values in per_view_knobs:
    for position, value in enumerate(values, start=1):
      log_dict[prefix+str(position)] = value

  # All approximation knobs
  log_dict['interpolation'] = interpolation
  log_dict['Compression Method'] = compr_method
  log_dict['Original IFM (bytes)'] = loader_mem_orig
  log_dict['Compressed IFM (bytes)'] = loader_mem_compressed
  log_dict['SAD'] = SAD_loader
  log_dict['SAD (Normalized)'] = SAD_norm

  log_dict['Choice'] = ss_mode
  log_dict['dominant_nodes'] = dominant_nodes
  log_dict['d'] = num_dom_nodes
  log_dict['SFd'] = sf_dom
  log_dict['SFn'] = sf_non_dom
  log_dict['CFd'] = cf_dom
  log_dict['CFn'] = cf_non_dom
  log_dict['RId'] = ri_dom
  log_dict['RIn'] = ri_non_dom
  log_dict['SPd'] = sp_dom
  log_dict['SPn'] = sp_non_dom
  log_dict['Compression Ratio'] = compr_ratio
  log_dict['Accuracy'] = val_overall_acc
  # Not every supported architecture has a recorded baseline; log 'NA' instead
  # of failing with a KeyError after the (expensive) test has already run.
  baseline = baseline_acc.get(cnn_name)
  if baseline is None:
    log_dict['Norm. Accuracy'] = 'NA'
  else:
    log_dict['Norm. Accuracy'] = float(val_overall_acc) / baseline

  fields = log_dict.keys()
  logfile = run_dir + '/' + 'Raw_log.csv'

  # Keep the historical rule (header for the first batch of test 'A1') and
  # additionally write a header whenever the log file is new or empty, so that
  # runs starting with another test id do not produce a headerless file.
  write_header = ((test_id == 'A1') and (current_batch == 0)) \
                 or (not os.path.isfile(logfile)) or (os.path.getsize(logfile) == 0)

  # newline='' is what the csv module expects for files it writes to; the
  # "with" block closes the file, no explicit close() is needed.
  with open(logfile, 'a', newline='') as csvfile:
    csvwriter = csv.DictWriter(csvfile, fieldnames=fields)
    if write_header:
      csvwriter.writeheader()
    csvwriter.writerow(log_dict)

# New helper functions (class-wise)

Loading significance (class-wise)

In [None]:
import pandas as pd
from scipy.stats import rankdata

class drax_significance:
  """Per-class view significance read from an Excel workbook.

  Args:
    sig_file_path: Path of the workbook.
    sig_methods: List of sheet names, one per significance method. Every sheet
      is read into a DataFrame; the 'class' column of the first sheet provides
      the list of known categories.
  """

  def __init__(self, sig_file_path, sig_methods):
    self.sig_file_path = sig_file_path
    self.sig_methods = sig_methods
    self.df_significance = pd.read_excel(sig_file_path, sheet_name=sig_methods, index_col=None, keep_default_na=False)
    self.categories = list(self.df_significance[sig_methods[0]]['class'])

  def get_class_significance(self, sig_method, category):
    """Return the ranks (1 = smallest value) of the views of ``category``.

    The last 12 entries of the row of ``category`` in the sheet ``sig_method``
    are ranked.

    Returns:
      A list of 12 distinct integer ranks. Equal values receive consecutive
      ranks in order of appearance, so every rank from 1 upwards occurs exactly
      once and rank look-ups such as ``ranks.index(12)`` cannot fail on ties.

    Raises:
      ValueError: If ``sig_method`` or ``category`` is unknown.
    """
    if (sig_method not in self.sig_methods) or (category not in self.categories):
      raise ValueError("Invalid significance method OR category!")

    row = self.df_significance[sig_method].iloc[self.categories.index(category)]
    values = list(row)[-12:]
    # 'ordinal' instead of the default 'average' ranking: averaged tie ranks
    # (e.g. 7.5) were truncated to int, producing duplicate and missing ranks.
    return rankdata(values, method='ordinal').astype(int).tolist()
# class_sig = significances.get_class_significance('entropy', 'Average')
# print(class_sig)

Functions to predict the class of a single image (not used) and to test one batch

In [None]:
# Function to update validation accuracy
def predict_image_class(cnet, data, view):
  """Classify one view of the first item of a batch with a single-view network.

  Args:
    cnet: Network called on a batch that contains the single selected image.
    data: Batch as produced by the loader; ``data[0]`` holds the correct
      categories and ``data[1]`` the views of every item in the batch.
    view: 1-based number of the view to classify.

  Returns:
    ``(predicted category, correct category)`` as Python ints.
  """
  labels, views = data[0], data[1]
  # Only the first item of the batch is looked at
  target = labels[0].to(device)
  selected = views[0][view-1]
  # Add a leading batch dimension of size 1 before calling the network
  single_image_batch = torch.unsqueeze(selected, dim=0).to(device)
  scores = cnet(single_image_batch)
  _, predicted = torch.max(scores, 1)

  return int(predicted), int(target)

# Function to get accuracy for one single batch
def test_one_batch(model, loader, model_name, num_views=12):
    """Evaluate ``model`` on the batch provided by ``loader``.

    The loader is expected to yield a single batch; if it yields several, the
    statistics of the last one are returned (unchanged behaviour).

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        loader: Iterable of batches with the targets in ``batch[0]`` and the
            input images in ``batch[1]``.
        model_name: 'mvcnn' (views are folded into the batch dimension and the
            model also returns compression statistics) or anything else for a
            single-view model.
        num_views: Unused; kept for interface compatibility.

    Returns:
        ``(correct, batch_size, batch_acc, batch_mem_orig,
        batch_mem_compressed, SAD_batch)``; the last three stay 0 for models
        other than 'mvcnn'.

    Raises:
        ValueError: If ``loader`` yields no batch.
    """
    # Set the model in "Evaluation" mode
    model.eval()
    batch_mem_orig, batch_mem_compressed, SAD_batch = 0, 0, 0
    correct_points = results = batch_acc = None

    # Pure inference: no autograd graph is needed
    with torch.no_grad():
        for data_batch in loader:
            # Send the data to the device based on MVCNN or SVCNN
            if model_name == 'mvcnn':
                # Fold the views into the batch dimension: (N,V,C,H,W) -> (N*V,C,H,W)
                N,V,C,H,W = data_batch[1].size()
                in_data = data_batch[1].view(-1,C,H,W).to(device)
            else:#'svcnn'
                in_data = data_batch[1].to(device)
            target = data_batch[0].to(device)

            # Get the output (the MVCNN also reports feature-map compression statistics)
            if model_name == 'mvcnn':
                out_data, batch_mem_orig, batch_mem_compressed, SAD_batch = model(in_data)
            else:
                out_data = model(in_data)
            pred = torch.max(out_data, 1)[1]
            results = pred == target

            # The number of correct predictions in the batch
            correct_points = torch.sum(results.long())

            # Accuracy of this batch
            acc = correct_points.float() / results.size()[0]
            batch_acc = acc.cpu().numpy()

    if results is None:
        raise ValueError("test_one_batch: the loader did not yield any batch")

    # Return the logs
    return int(correct_points), results.size()[0], batch_acc, batch_mem_orig, batch_mem_compressed, SAD_batch

# from torch.utils.data import DataLoader
# root_dir = "/content/modelnet_test/*/test"
# mask_dir = '/content/error_mask'
# val_dataset = Multiview_Dataset_Batch(root_dir=root_dir, mask_dir=mask_dir, start_idx=0, end_idx=1)
# print(val_dataset[0][2])
# print(len(val_dataset))
# dataloader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=2)
# # # Iterate over the dataloader to access the data in batches
# # for i, batch in enumerate(dataloader):
# #     print(batch)

# ## Load the pretrained Single View CNN (SVCNN)
# cnn_name = name = 'resnet34'
# part_point = 5
# accurate_model_dir = '/content/drive/MyDrive/Arghadip/MVCNN/jongchyisu/mvcnn_pytorch_analysis/0_Journal/Logs/Training_aligned_modelnet/Version_2023_May_24/resnet34/resnet34_stage_1/resnet34'
# modelfile_svcnn = 'model-00010.pth'
# view_for_prediction = 1

# cnet = SVCNN(name, nclasses=40, pretraining=False, cnn_name=cnn_name, part_point=part_point)
# cnet.load(accurate_model_dir, modelfile_svcnn) # Load model parameters <FIXME>
# cnet = cnet.to(device)

# # for view in range(1,13):
# for i, batch in enumerate(dataloader):
#     pred_category, correct_category = predict_image_class(cnet, batch, view_for_prediction)
#     print(pred_category, correct_category)

In [None]:
# batchSize_SV = 1
# num_workers = 2
# optimizer = optim.Adam(cnet.parameters(), lr=5e-5, weight_decay=0.001)
# val_dataset = SingleImgDataset(root_dir, scale_aug=False, rot_aug=False, num_models=0, test_mode=True)
# val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batchSize_SV, shuffle=False, num_workers=num_workers)
# log_dir = '/content/drive/MyDrive/Arghadip/MVCNN/jongchyisu/mvcnn_pytorch_analysis/0_Journal/Logs/Debug'
# trainer = ModelNetTrainer(cnet, val_loader, val_loader, optimizer, nn.CrossEntropyLoss(), 'svcnn', log_dir, num_views=1)
# trainer.update_validation_accuracy(1)

Load models

In [None]:
def get_model_names (ri_dom, sp_dom, refresh_interval, sparsity):
  """Return the checkpoint file names of the dominant and non-dominant groups.

  The names are built from the module-level ``cnn_name`` and ``part_point``
  plus a refresh interval and a sparsity.

  Returns:
    ``(file_name_dom, file_name_nondom)``. Both are ``None`` when more than two
    distinct refresh intervals or sparsities are in use.
  """
  def checkpoint_name(ri, sp):
    # File name pattern of the approximate model checkpoints
    return 'approx_' + cnn_name + '_PP_' + str(part_point) + '_dram' + str(ri) + '_sparsity_' + str(sp) + '.pth.tar'

  ## Get the unique list of sparsities and refresh intervals
  distinct_intervals = list(set(refresh_interval))
  distinct_sparsities = list(set(sparsity))

  ## Same knobs for every view: one model serves both groups (load once).
  ## The fully accurate case is NOT covered here, it is loaded every time.
  if (len(distinct_intervals) == 1) and (len(distinct_sparsities) == 1):
    shared_ri, shared_sp = distinct_intervals[0], distinct_sparsities[0]
    if (shared_ri == 0) and (not (shared_sp == 0)):
      shared_ri = 1    # RI = 0 and 1 has no errors
    shared_name = checkpoint_name(shared_ri, shared_sp)
    return shared_name, shared_name

  ## Group-wise knobs: at most two distinct values per knob
  if (len(distinct_intervals) <= 2) and (len(distinct_sparsities) <= 2):
    # NOTE(review): `_get_non_dom_knob` is not defined in this part of the file;
    # confirm it exists elsewhere (the helper `get_knob_non_dom` takes the same
    # (knobs, knob_dom) arguments).
    ri_nondom = _get_non_dom_knob (distinct_intervals, ri_dom)
    sp_nondom = _get_non_dom_knob (distinct_sparsities, sp_dom)
    return checkpoint_name(ri_dom, sp_dom), checkpoint_name(ri_nondom, sp_nondom)

  ## Anything else is not supported
  return None, None

def load_relevant_models (cnn_name, part_point, sps, ris, approx_model_dir):
  """Load one approximate checkpoint per (sparsity, refresh interval) pair.

  Args:
    cnn_name: CNN name used in the checkpoint file name and as the
      ``<cnn_name>_mvcnn`` architecture passed to distiller.
    part_point: value embedded in the file name after ``_PP_``.
    sps: iterable of sparsities.
    ris: iterable of refresh intervals.
    approx_model_dir: directory holding the ``.pth.tar`` checkpoints.

  Returns:
    Dict mapping each checkpoint file name to its loaded model, put in
    ``eval()`` mode.

  Raises:
    FileNotFoundError: if an expected checkpoint is not a file in
      ``approx_model_dir``.
  """
  approx_model_dict = {}
  for sp in sps:
    for ri in ris:
      filename = 'approx_' + cnn_name + '_PP_' + str(part_point) + '_dram' + str(ri) + '_sparsity_' + str(sp) + '.pth.tar'
      checkpoint_path = os.path.join(approx_model_dir, filename)
      ## Check the one file we need instead of listing the whole directory each time
      if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError("File does not exist: {}".format(filename))

      model = distiller.models.create_model(pretrained=True, dataset='modelnet', arch= cnn_name + '_mvcnn', parallel=False, device_ids=-1)
      load_checkpoint(model, checkpoint_path)
      print('Loaded model from ', filename)
      model.eval()
      ## Keep the loaded model, keyed by its checkpoint file name
      approx_model_dict[filename] = model
  return approx_model_dict

In [None]:
# cnn_name = 'resnet34'
# part_point = 5
# sps = [1.0, 2.5]
# ris = [1,10]
# approx_model_dir = '/content/Approx_models'
# approx_models = load_relevant_models (cnn_name, part_point, sps, ris, approx_model_dir)

New logging helpers

In [None]:
def generate_timing_log(run_dir, test_id, cnn_name, num_views, timer0_0, timer0_1, timer1_0, timer1_1, timer1_2, \
                        timer2_0, timer2_1, timer2_2, timer2_3, timer2_4):
  """Append one row of stage timings to ``<run_dir>/Timing_log.csv``.

  Each ``timer*`` argument is a ``datetime``; consecutive pairs are subtracted
  and stored in seconds. The header row is written only when the file is
  missing or empty, so the function does not depend on any global state.

  Args:
    run_dir: directory that holds the log file (must already exist).
    test_id: identifier stored in the ``TID`` column.
    cnn_name: value stored in the ``CNN`` column.
    num_views: value stored in the ``Num_views`` column.
    timer0_0, timer0_1: bounds of ``t_dataset_init``.
    timer1_0, timer1_1, timer1_2: bounds of ``t_sig_det`` and ``t_knob_gen``.
    timer2_0 ... timer2_4: bounds of ``t_CNN``, ``t_dataset_batch``,
      ``t_testing`` and ``t_logging``.
  """
  ## Form a dictionary to log the timings
  timing_dict = {
      'TID': test_id,
      'CNN': cnn_name,
      'Num_views': num_views,
      't_dataset_init': (timer0_1 - timer0_0).total_seconds(),
      't_sig_det': (timer1_1 - timer1_0).total_seconds(),
      't_knob_gen': (timer1_2 - timer1_1).total_seconds(),
      't_CNN': (timer2_1 - timer2_0).total_seconds(),
      't_dataset_batch': (timer2_2 - timer2_1).total_seconds(),
      't_testing': (timer2_3 - timer2_2).total_seconds(),
      't_logging': (timer2_4 - timer2_3).total_seconds(),
  }
  logfile = os.path.join(run_dir, 'Timing_log.csv')

  ## Header goes in exactly once: when the log is new or still empty
  write_header = (not os.path.exists(logfile)) or os.path.getsize(logfile) == 0

  ## newline='' is what the csv module expects for files it writes to
  with open(logfile, 'a', newline='') as csvfile:
    csvwriter = csv.DictWriter(csvfile, fieldnames=list(timing_dict))
    if write_header:
      csvwriter.writeheader()
    csvwriter.writerow(timing_dict)


def generate_summary_log(run_dir):
  """Aggregate ``<run_dir>/Raw_log.csv`` per test id into ``Run_summary.csv``.

  For every distinct ``test id`` the first raw row (columns ``cnn`` through
  ``Norm. Accuracy``) is used as a template; the byte counts, SAD and accuracy
  fields are then replaced by values aggregated over all rows of that test id.
  The rows are appended to ``<run_dir>/Run_summary.csv``; the header is
  written only when that file is missing or empty.

  Args:
    run_dir: directory containing ``Raw_log.csv``; also receives the summary.

  Raises:
    KeyError: if a row's ``cnn`` has no entry in the baseline accuracy table.
  """
  ### Forming the summary log file
  csv_path = os.path.join(run_dir,'Raw_log.csv')
  data = pd.read_csv(csv_path)
  # Baseline accuracies dictionary
  baseline_acc = {
      'alexnet': 0.91572124,
      'vgg11': 0.9226094,
      'resnet34': 0.95097244
  }
  ## Get the all possible test ids
  TIDs = data['test id'].unique()
  print(TIDs)

  summary_rows = []
  for TID in TIDs:
    # Get only the rows where the TID matches
    data_tid = data[data['test id'] == TID]
    # Form a dictionary with the first row
    log_dict = data_tid.iloc[0]['cnn':'Norm. Accuracy'].to_dict()
    # Replace the per-batch values with totals over the whole test id
    log_dict['Original IFM (bytes)'] = data_tid['Original IFM (bytes)'].sum()
    log_dict['Compressed IFM (bytes)'] = data_tid['Compressed IFM (bytes)'].sum()
    log_dict['SAD'] = data_tid['SAD'].sum()

    # pandas may read the literal text 'None' back as NaN (it is listed among
    # read_csv's default NA markers), so accept both spellings of "no compression"
    method = log_dict['Compression Method']
    no_compression = (method == 'None') or bool(pd.isna(method))

    if no_compression or (log_dict['Original IFM (bytes)'] == 0) or (log_dict['Compressed IFM (bytes)'] == 0):
      log_dict['SAD (Normalized)'] = 0
      log_dict['Compression Ratio'] = 0
    else:
      log_dict['SAD (Normalized)'] = log_dict['SAD']*2/log_dict['Original IFM (bytes)']
      log_dict['Compression Ratio'] = log_dict['Original IFM (bytes)']/log_dict['Compressed IFM (bytes)']

    log_dict['dominant_nodes'] = 'Batchwise'

    total_models = data_tid['number_of_models_tested'].sum()
    correctly_classified = data_tid['correctly_classified_models'].sum()
    log_dict['Accuracy'] = correctly_classified/total_models
    log_dict['Norm. Accuracy'] = log_dict['Accuracy']/baseline_acc[log_dict['cnn']]

    summary_rows.append(log_dict)

  # Nothing to summarise: do not create an empty file
  if not summary_rows:
    return

  logfile = os.path.join(run_dir, 'Run_summary.csv')
  # Header goes in exactly once: when the summary is new or still empty
  write_header = (not os.path.exists(logfile)) or os.path.getsize(logfile) == 0

  # newline='' is what the csv module expects for files it writes to
  with open(logfile, 'a', newline='') as csvfile:
    csvwriter = csv.DictWriter(csvfile, fieldnames=list(summary_rows[0]))
    if write_header:
      csvwriter.writeheader()
    csvwriter.writerows(summary_rows)

# Category-wise significance

User inputs

In [None]:
########################## Primary Inputs ######################################
# CNN to evaluate; the rest of the notebook has entries for "alexnet", "vgg11" and "resnet34"
cnn_name = "resnet34" #"alexnet"   # "vgg11"
# Experiment tag, becomes part of the run name
expt_name = 'cat_en_corr'
# Batching
batchSize = 50

########################## CHANGE THIS ######################################
## Method to determine significance
significance_method = 'entropy'
# View for significance (None: the labelled class of the batch is used instead of an SVCNN prediction)
view_to_determine_sig = None

## Grouping strategy
## "best" = SAGA
## "worst" = Anti-SAGA
## "random" = In between SAGA and Anti-SAGA
mode_list = ['best'] #['best', 'random', 'worst']

## Log timing or not
log_timing = False

# d
num_dom_nodes = [1, 2, 4, 6]
# SPd
sps_dom = [2.5] # 5.0
# RId (the comma after the first entry was missing, which made this cell a syntax error)
ris_dom = [1, 10, 20]

# SFd
sfs_dom = [1, 1.15, 1.2]
# CFd
cfs_dom = [5, 10, 20]

# Compression back-ends to sweep
compr_method_list = ['zfp'] #'fpzip'

Code body

In [None]:
###################### High level parameters ##################################
###############################################################################
# Wall-clock reference (US/Eastern) used for the run name and all timers
tz = timezone('US/Eastern')
now = datetime.now(tz)

# part_point value to use for each supported CNN
part_points_dict = {'alexnet': 5, 'vgg11': 10, 'resnet34': 5}

# Accurate MVCNN checkpoint file, keyed by '<cnn>_PP_<part_point>'
accurate_models_dict = {
    'alexnet_PP_5': 'Alexnet_MVCNN_PP_5_pkl.pt',
    'vgg11_PP_10': 'Vgg11_MVCNN_PP_10_pkl.pt',
    'resnet34_PP_5': 'Resnet34_MVCNN_PP_5_pkl.pt',
}

#############################################################################
# Non-dominant knob candidates. SFn, CFn and SPn reuse the dominant lists
# (same list objects, not copies); RIn has its own list.
sf_non_dom = sfs_dom    # SFn, alternative: [1.1,1.15]
cf_non_dom = cfs_dom    # CFn, alternative: [5,10]
ri_non_dom = [1, 10, 20]    # RIn
sp_non_dom = sps_dom    # SPn, must be the same

################################# Secondary Inputs #############################
## Significance
sig_file_path = '/category_wise_significance.xlsx'
sig_methods = ['entropy']
significances = drax_significance(sig_file_path, sig_methods)

num_workers = 2
part_point = part_points_dict[cnn_name]
num_views = 12
name = cnn_name

# Unique run name: CNN, experiment tag, part point and start time
run_name = f"{cnn_name}_{expt_name}_PP_{part_point}_{now.strftime('%Y-%m-%d_%H:%M:%S')}"

# run_dir = '/content/drive/MyDrive/Arghadip/MVCNN/jongchyisu/mvcnn_pytorch_analysis/0_Journal/Logs/Classwise_significance/' + run_name
run_dir = f'/data/{run_name}'

# Entire test list
# significance = significances_dict[cnn_name + '_PP_' + str(part_point)]
interpolations = ['nearest'] #, 'bilinear', 'bicubic']
covered_drax_space = []
mask_dir = '/error_mask'
approx_model_dir = "/Approx_models"
# Pre-load every approximate model needed by the sweep
approx_models = load_relevant_models (cnn_name, part_point, sps_dom, ri_non_dom, approx_model_dir)

# Number of 3D models to test, 0 means the whole dataset
n_models_test = 0
# n_models_test = X * num_views

accurate_model_dir= '/acc_models'

svcnn_load_path = 'dummy'
modelfile_svcnn = 'model-00010.pth'
modelfile_mvcnn = accurate_models_dict[f'{cnn_name}_PP_{part_point}']

###################### Infrequent parameters ##################################
###############################################################################
# Parameters which are not frequently changed
lr = 5e-5
weight_decay = 0.001
no_pretraining = True
val_path = "/modelnet_test/*/test"

###################### Helper functions #######################################
###############################################################################
# Defined function to create log directory
def create_folder(log_dir):
    """Create ``log_dir`` as an empty directory.

    If the path already exists, a warning is printed and the existing tree
    is deleted first. Missing parent directories are created as well.

    Args:
        log_dir: path of the directory to (re)create.
    """
    if os.path.exists(log_dir):
        print('WARNING: summary folder already exists!! It will be overwritten!!')
        shutil.rmtree(log_dir)
    # makedirs (not mkdir) so a missing parent directory does not abort the run
    os.makedirs(log_dir)

## Main sweep: for every combination of the dominant-group knobs, walk the
## validation set batch by batch, derive the per-view knob settings from the
## class-wise significance and test every resulting knob combination.
create_folder(run_dir)

# Counts the outer knob combinations (gives the letter prefix of the test id)
test_id_1 = 0
# Counts every experiment that is actually executed
running_test = 0
## Load the pretrained Single View CNN (SVCNN)
cnet = SVCNN(name, nclasses=40, pretraining=False, cnn_name=cnn_name, part_point=part_point)
if view_to_determine_sig is not None:
  cnet.load(svcnn_load_path, modelfile_svcnn) # Load model parameters
  cnet = cnet.to(device)

for ss_mode in mode_list:
  for num_dom_node in num_dom_nodes:
    for compr_method in compr_method_list:
      for interpolation in interpolations:
        ###################### Derived parameters #####################################
        ###############################################################################
        # Derived parameters
        # n_models_test =  n_models_test * num_views
        if (interpolation == 'nearest'):
          interpolation_mode = InterpolationMode.NEAREST
        elif (interpolation == 'bilinear'):
          interpolation_mode = InterpolationMode.BILINEAR
        elif (interpolation == 'bicubic'):
          # Was mapped to BILINEAR before, so 'bicubic' runs silently repeated the bilinear ones
          interpolation_mode = InterpolationMode.BICUBIC
        else:
          # Fail loudly instead of reusing a stale (or undefined) interpolation_mode
          raise ValueError(f"Unsupported interpolation: {interpolation}")
        for sp_dom in sps_dom:
          for ri_dom in ris_dom:
            for cf_dom in cfs_dom:
              for sf_dom in sfs_dom:
                test_id_1 += 1
                # NOTE(review): chr(64 + n) only yields 'A'..'Z' for n <= 26; larger sweeps get other characters
                test_id_1_chr = chr(test_id_1 + 64)   # 64 is added to start with 'A'

                ## TO-DO1: TRIGGER THE SIGNIFICANCE DETERMINATION STEP IN EVERY FEW BATCHES
                ## TO-DO2: PUT A BYPASS LOOP FOR THE BATCH-WISE SIGNIFICANCE DETERMINATION (LOW PRIORITY)

                ## Timer
                if log_timing:
                  timer0_0 = datetime.now(tz)

                ## Load the whole ModelNet40 dataset (NO shuffling)
                val_dataset = Multiview_Dataset_Batch(val_path, scale_aug=False, rot_aug=False, test_mode=True,
                                            num_models=n_models_test, num_views=num_views, shuffle=False,
                                            mask_dir=mask_dir, interpolation=interpolation_mode)
                print(f"Whole val dataset content is {len(val_dataset)}")
                # Create a sampler
                batch_sampler = DRAX_Classwise_Batch_Sampler(class_counts=val_dataset.num_models_per_category, batch_size=batchSize)
                ## Create a dataloader (NO shuffling)
                val_loader = torch.utils.data.DataLoader(val_dataset, batch_sampler=batch_sampler, shuffle=False, num_workers=num_workers)
                total_batches = len(val_loader)

                # dd/mm/YY H:M:S
                # timer0_1 is taken unconditionally: it is also the start mark for the duration report below
                timer0_1 = datetime.now(tz)
                dt_string = timer0_1.strftime("%d/%m/%Y %H:%M:%S")
                print("date and time =", dt_string)

                ## Start looping through each batch
                start_index = 0
                for current_batch, data in enumerate(val_loader):
                  ## For each batch, in stage-1, we predict the class of the first model in the batch
                  # print("################ STARTING STAGE 1 ###################")

                  ## Get the end index of the dataset
                  cur_batch_size = len(data[0])
                  end_index = start_index + cur_batch_size

                  if log_timing:
                    timer1_0 = datetime.now(tz)

                  ## Pass the 2D image and cnet to the test function, it returns the PREDICTED class of the image
                  ## Select a particular view (the one which is the overall most significant)
                  if view_to_determine_sig is not None:
                    pred_category, correct_category = predict_image_class(cnet, data, view_to_determine_sig)
                    significance_category = val_dataset.classnames[pred_category]
                  else:
                    # No view selected: use the label of the first item of the batch
                    correct_category = int(data[0][0].to(device))
                    significance_category = val_dataset.classnames[correct_category]

                  ## Get the class-wise significance using the predicted class
                  significance = significances.get_class_significance(significance_method, significance_category)

                  if log_timing:
                    timer1_1 = datetime.now(tz)

                  ## Generate the combination of approximation knobs using the given significance of views
                  sparsities, dominant_nodes = generate_groupwise_knobs (num_views, num_dom_node, sp_dom, sp_non_dom, significance, ss_mode)
                  refresh_intervals, dominant_nodes = generate_groupwise_knobs (num_views, num_dom_node, ri_dom, ri_non_dom, significance, ss_mode)
                  compr_knob_list, dominant_nodes = generate_groupwise_knobs (num_views, num_dom_node, cf_dom, cf_non_dom, significance, ss_mode)
                  subsampling_factors, dominant_nodes = generate_groupwise_knobs(num_views, num_dom_node, sf_dom, sf_non_dom, significance, ss_mode)

                  if log_timing:
                    timer1_2 = datetime.now(tz)

                  experiments_per_batch = int(len(refresh_intervals)*len(compr_knob_list)*len(subsampling_factors)*len(sparsities)/len(sps_dom))

                  ## Expected number of experiments over the whole sweep. Computed once per batch
                  ## (it does not depend on the inner loop variables) so that the progress print
                  ## below always has a value, even if every combination is skipped.
                  total_iterations = int(len(compr_method_list)*len(mode_list)*len(ris_dom)*len(cfs_dom)*len(sfs_dom)\
                                    *len(sps_dom)*len(num_dom_nodes)*len(refresh_intervals)*len(compr_knob_list)\
                                    *len(interpolations)*len(subsampling_factors)*len(sparsities)*total_batches/len(sps_dom))

                  ## Start exploring all combination of approx knob settings for this particular batch
                  test_id_2 = 0
                  for sparsity in sparsities:
                    for refresh_interval in refresh_intervals:
                      for compr_knob in compr_knob_list:
                        for subsampling_factor in subsampling_factors:
                          ## Skip mixed-sparsity settings and anything already recorded in covered_drax_space
                          ## (nothing in this cell appends to covered_drax_space)
                          if (len(set(sparsity)) != 1) or ([ss_mode, num_dom_node, compr_method, interpolation, sp_dom, ri_dom, cf_dom, sf_dom, sparsity, refresh_interval, compr_knob, subsampling_factor] in covered_drax_space):
                            continue
                          # print(subsampling_factor)
                          # Test id
                          test_id_2 = test_id_2 + 1
                          test_id = test_id_1_chr + str(test_id_2)

                          running_test += 1

                          # STAGE 2
                          # print("################ STARTING STAGE 2 ###################")

                          if log_timing:
                            timer2_0 = datetime.now(tz)

                          # Load CNN model using MVCNN model
                          modelname_dom, modelname_nondom = get_model_names (ri_dom, sp_dom, refresh_interval, sparsity)
                          cnet_2 = MVCNN_approx_opt(name, cnet, nclasses=40, accurate_models_dict=accurate_models_dict,\
                                                    cnn_name=cnn_name, num_views=num_views, part_point=part_point, \
                                                    compr_method=compr_method, compr_knob=compr_knob, \
                                                    refresh_interval=refresh_interval, \
                                                    approx_model_dir=approx_model_dir, \
                                                    accurate_model_dir=accurate_model_dir, sparsity=sparsity, \
                                                    ri_dom=ri_dom, sp_dom=sp_dom,\
                                                    model_dom=approx_models[modelname_dom], model_nondom=approx_models[modelname_nondom]).to(device)

                          if log_timing:
                            timer2_1 = datetime.now(tz)

                          modelfile_mvcnn = cnet_2.file_name

                          ## Reload the batch images (all views) with proper sensor and memory approx
                          val_dataset_batch = Multiview_Dataset_Batch(val_path, scale_aug=False, rot_aug=False, test_mode=True,
                                              num_models=n_models_test, num_views=num_views, subsampling_factor=subsampling_factor,
                                              refresh_interval=refresh_interval, mask_dir=mask_dir, interpolation=interpolation_mode,
                                              shuffle=False, start_idx=start_index, end_idx=end_index)
                          # print(f"val dataset content is {len(val_dataset_batch)}")
                          ## Validation loader (no sampler needed, taken care of by using start and end index)
                          val_loader_batch = torch.utils.data.DataLoader(val_dataset_batch, batch_size=cur_batch_size, shuffle=False, num_workers=num_workers)
                          # print(val_loader_batch.batch_size, val_dataset_batch[0][0], val_dataset_batch[cur_batch_size-1][0])
                          # print('num_val_files: '+str(len(val_dataset_batch.filepaths)))

                          if log_timing:
                            timer2_2 = datetime.now(tz)

                          ## Test the models, returns category-wise total and correctly classified models
                          correctly_classified_models , num_test_models, val_overall_acc, loader_mem_orig, loader_mem_compressed, SAD_loader = \
                          test_one_batch(cnet_2, val_loader_batch, 'mvcnn', num_views=num_views)

                          ## Update class-wise running total and correctly classified models

                          # print('Itr ', running_test, '/', total_iterations, ' | TID: ', test_id, \
                          #       ' | Batch: ', current_batch, ' | Sig_Cat: ', significance_category, ' | Accuracy: ', val_overall_acc)

                          if log_timing:
                            timer2_3 = datetime.now(tz)

                          # Dump to run summary
                          write_run_summary(test_id, run_dir, expt_name, cnn_name, num_views, num_test_models,
                                            part_point, batchSize, val_loader_batch.batch_size, num_workers, True,
                                            True, modelfile_svcnn, modelfile_mvcnn,
                                            subsampling_factor, interpolation,
                                            val_overall_acc, ss_mode, num_dom_node, dominant_nodes, compr_method, compr_knob,
                                            loader_mem_orig, loader_mem_compressed, SAD_loader, sf_dom, cf_dom,
                                            refresh_interval, ri_dom, sparsity, sp_dom,
                                            significance_category, significance, current_batch, running_test, total_iterations,
                                            correctly_classified_models)

                          if log_timing:
                            timer2_4 = datetime.now(tz)
                            generate_timing_log(run_dir, test_id, cnn_name, num_views, timer0_0, timer0_1, timer1_0, timer1_1, timer1_2, \
                            timer2_0, timer2_1, timer2_2, timer2_3, timer2_4)

                  print('Itr ', running_test, '/', total_iterations, ' | Batch ', current_batch+1,  '/', total_batches, ' | Sig_Cat: ', significance_category)
                  ## The next batch starts where this one ended
                  start_index = end_index

                final_now = datetime.now(tz)
                dt_string = final_now.strftime("%d/%m/%Y %H:%M:%S")
                print(f'Finished test {test_id_1} at {dt_string}')
                duration = final_now - timer0_1

                ## Duration per test
                duration_in_s = duration.total_seconds()/experiments_per_batch
                days    = divmod(duration_in_s, 86400)        # Get days (without [0]!)
                hours   = divmod(days[1], 3600)               # Use remainder of days to calc hours
                minutes = divmod(hours[1], 60)                # Use remainder of hours to calc minutes
                seconds = divmod(minutes[1], 1)               # Use remainder of minutes to calc seconds
                print("Time taken (per test): %d days, %d hours, %d minutes and %d seconds" % (days[0], hours[0], minutes[0], seconds[0]))

                # del val_dataset
                # del val_loader

## Update the overall accuracy and compression details for a particular combination of knob settings
generate_summary_log(run_dir)
print('\n\n\n\nExperiments are completed!')

In [None]:
# Colab-only cell: the import needs the google.colab package, so it fails outside Colab.
from google.colab import runtime
# presumably releases the Colab runtime once the experiments are done — TODO confirm against the Colab docs
runtime.unassign()