In [1]:
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import time
import matplotlib.pyplot as plt
from torch import Tensor 
import torchvision.transforms as transforms
import os
import sys
from collections import Counter
from functools import partial
import pandas as pd
import itertools
from itertools import combinations

sys.path.append('ins_folder/codes')
from inference import load_dataset
from helper_functions_last import get_lookup_tables, get_topology_info, add_estimated_exitrate_validation,load_partitioning_list,\
    add_estimated_accuracy_validation, add_estimated_computation, add_estimated_pickle_size, add_estimated_communication, save_df,\
        add_estimated_e2e_latency


In [2]:
### Experiment configuration: dataset/model metadata and the sweep grid.

# Static per-dataset metadata (class count, input geometry, split sizes).
dataset_info_dict = {
    'CIFAR10': {'num_class': 10, 'input_size': 32, 'num_channels': 3,
                'validation_size': 5000, 'test_size': 5000},
    'CIFAR100': {'num_class': 100, 'input_size': 32, 'num_channels': 3,
                 'validation_size': 5000, 'test_size': 5000},
    'ImageNet': {'num_class': 1000, 'input_size_B0': 224, 'input_size_B7': 600,
                 'crop_size_B0': 256, 'crop_size_B7': 600, 'num_channels': 3,
                 'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225],
                 'validation_size': 25000, 'test_size': 25000},
    'AudioSet': {'num_class': 527, 'input_size': [96, 64], 'num_channels': 1},
}

# Per-model settings. 'branch_number' is the upper bound for the partition cut
# points enumerated in the case-generation cell.
model_info_dict = {
    'ResNet20': {'model_n': 3, 'branch_number': 10},
    'ResNet110': {'model_n': 18, 'branch_number': 55},
    'EfficientNetB0': {'branch_number': 8, 'dropout': 0.2, 'width_mult': 1.0,
                       'depth_mult': 1.0, 'norm_layer': nn.BatchNorm2d},
    'EfficientNetB7': {'branch_number': 8, 'dropout': 0.5, 'width_mult': 2.0,
                       'depth_mult': 3.1,
                       'norm_layer': partial(nn.BatchNorm2d, eps=0.001, momentum=0.01)},
    'VGGish': {'branch_number': 4},
}

# Network scenario(s): source/destination tier plus one latency list and one
# bandwidth list per scenario (units not stated here — TODO confirm).
source_tier_ind = 0
destination_tier_ind = 0
latency_between_tiers_list = [[20, 20, 20]]
bw_between_tiers_list = [[10, 500, 500]]

# Workload under study.
dataset = 'CIFAR10'
model = 'ResNet110'
model_mode = 'full'

# Threshold sweep.
# NOTE(review): 0.5 is absent from this list — confirm the gap is intentional.
threshold_list = [0, 0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9, 1]

# Tier placements: every ordered choice of 1, 2 or 3 distinct tiers out of
# tiers 0-3 (lexicographic order), followed by two full four-tier chains.
# For a three-tier topology, use range(3) and drop the four-tier chains.
placement_list = [list(tier_order)
                  for chain_len in (1, 2, 3)
                  for tier_order in itertools.permutations(range(4), chain_len)] \
                 + [[0, 1, 2, 3], [3, 2, 1, 0]]

# When False, the case-generation cell keeps only partitionings whose last
# segment ends at the model's final branch.
partial_exec = True

In [3]:
# Enumerate every (threshold, placement, network scenario, partitioning) case.
# A partitioning with len(placement) segments is chosen as strictly increasing
# segment end points from 1..branch_number; segment i spans
# [previous end + 1, end_i], with the first segment starting at branch 1.
branch_number = model_info_dict[model]['branch_number']  # loop-invariant, hoisted
total_list = []

all_list = list(itertools.product(threshold_list, placement_list, latency_between_tiers_list, bw_between_tiers_list))
for threshold, placement, latency_between_tiers, bw_between_tiers in all_list:

    # Single-segment cases are kept only for threshold 0 (presumably the
    # threshold has no effect without a second segment — TODO confirm).
    # len(partition) == len(placement), so this is decided once per placement
    # instead of once per partition.
    if len(placement) == 1 and threshold != 0:
        continue

    for partition in combinations(range(1, branch_number + 1), len(placement)):

        # combinations() yields increasing tuples, so the last segment reaches
        # the final branch iff partition[-1] == branch_number.
        partial_flag = partition[-1] != branch_number

        # Without partial execution, drop cases that stop before the final branch.
        if partial_flag and not partial_exec:
            continue

        # (0, e1, e2, ...) paired with (e1, e2, ...) -> [[1, e1], [e1+1, e2], ...]
        partitioning_list = [[prev_end + 1, end] for prev_end, end in zip((0,) + partition, partition)]

        total_list.append([threshold, bw_between_tiers, latency_between_tiers,
                           source_tier_ind, destination_tier_ind, dataset, model, model_mode,
                           placement, partitioning_list, len(placement), partial_flag])

print('number of cases ', len(total_list))

number of cases  13295920


In [4]:

# Build one row per enumerated case, then enrich it with the project's
# estimators. The order is kept as-is: later helpers may read columns that
# earlier ones add (not visible from here — TODO confirm in helper_functions_last).
total_df = (
    pd.DataFrame(
        total_list,
        columns=['threshold', 'bw_list', 'latency_list',
                 'source_tier', 'destination_tier', 'dataset', 'model', 'model_mode',
                 'placement', 'partitioning', 'num_partitions', 'partial_flag'],
    )
    .pipe(add_estimated_exitrate_validation)
    .pipe(add_estimated_accuracy_validation)
    .pipe(add_estimated_computation)
    .pipe(add_estimated_pickle_size)
    .pipe(add_estimated_communication)
    .pipe(add_estimated_e2e_latency)
)


KeyboardInterrupt: 

In [5]:
# Keep configurations whose estimated accuracy exceeds 0.87 and order them by
# latency (fastest first); ties on latency put the more accurate case first.
accuracy_mask = total_df['estimated_accuracy'] > 0.87
rows = total_df[accuracy_mask].sort_values(
    by=['estimated_e2e_latency', 'estimated_accuracy'],
    ascending=[True, False],
)

# Lift pandas' row/column truncation so every matching case is rendered.
# NOTE(review): with millions of cases this can be a very large table;
# consider rows.head(n) if the output becomes unwieldy.
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    display(rows[['estimated_e2e_latency', 'estimated_accuracy', 'placement', 'partitioning',
                  'threshold', 'bw_list', 'estimated_pickle_size', 'estimated_exit_rate']])

Unnamed: 0,estimated_e2e_latency,estimated_accuracy,placement,partitioning,threshold,bw_list,estimated_pickle_size,estimated_exit_rate
8,15.893975,0.8768,[0],"[[1, 9]]",0.0,"[10, 500]","[13144, 244]",[1]
7994,15.893975,0.8768,"[0, 1]","[[1, 9], [10, 10]]",0.9,"[10, 500]","[13155, 17301, 297]","[1.0, 0.0]"
8039,15.893975,0.8768,"[0, 2]","[[1, 9], [10, 10]]",0.9,"[10, 500]","[13155, 17301, 297]","[1.0, 0.0]"
8984,15.893975,0.8768,"[0, 1]","[[1, 9], [10, 10]]",1.0,"[10, 500]","[13155, 17301, 297]","[1.0, 0.0]"
9029,15.893975,0.8768,"[0, 2]","[[1, 9], [10, 10]]",1.0,"[10, 500]","[13155, 17301, 297]","[1.0, 0.0]"
7004,15.905032,0.876799,"[0, 1]","[[1, 9], [10, 10]]",0.8,"[10, 500]","[13153, 17299, 297]","[0.9998, 0.0002]"
7049,15.913002,0.876799,"[0, 2]","[[1, 9], [10, 10]]",0.8,"[10, 500]","[13155, 17301, 297]","[0.9998, 0.0002]"
6014,15.949261,0.876796,"[0, 1]","[[1, 9], [10, 10]]",0.7,"[10, 500]","[13153, 17299, 297]","[0.999, 0.001]"
6059,15.989108,0.876796,"[0, 2]","[[1, 9], [10, 10]]",0.7,"[10, 500]","[13155, 17301, 297]","[0.999, 0.001]"
5024,16.070895,0.876786,"[0, 1]","[[1, 9], [10, 10]]",0.6,"[10, 500]","[13155, 17301, 297]","[0.9968, 0.0032]"
