In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import sys
import cv2
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import LearningRateScheduler
from keras.applications import imagenet_utils
from keras import backend
from keras.layers import VersionAwareLayers
from itertools import permutations, combinations
layers = VersionAwareLayers()



In [2]:
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True  # dynamically grow the memory used on the GPU
config.log_device_placement = True  # to log device placement (on which device the operation ran)
# config.gpu_options.per_process_gpu_memory_fraction = 0.60
sess = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(sess)

Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: GeForce RTX 3080, pci bus id: 0000:41:00.0, compute capability: 8.6



In [4]:
#### Using CIFAR10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Preprocess the data (these are NumPy arrays)
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255


y_train = y_train.astype("float32")
y_test = y_test.astype("float32")

y_train = keras.utils.to_categorical(y_train, num_classes=10)
y_test = keras.utils.to_categorical(y_test, num_classes=10) 

global_batch_size = 32
image_resize = 224

########## Train
#### removing data augmentation only here for profiling
train_datagen = ImageDataGenerator()

# train_datagen = ImageDataGenerator(featurewise_center = True,
#                                    rotation_range = 20, horizontal_flip = True, height_shift_range = 0.2,
#                                    width_shift_range = 0.2, zoom_range = 0.2, channel_shift_range = 0.2)


train_it_128 = train_datagen.flow_from_directory(
        'cifar10/train',
        class_mode='categorical',
        target_size=(128, 128),
        batch_size=global_batch_size)

train_it_224 = train_datagen.flow_from_directory(
        'cifar10/train',
        class_mode='categorical',
        target_size=(224, 224),
        batch_size=global_batch_size)


############ Test
test_datagen = ImageDataGenerator( validation_split=0.5)

validation_it = test_datagen.flow_from_directory(
        'cifar10/test',
        class_mode='categorical',
        target_size=(image_resize, image_resize),
        batch_size=global_batch_size,
        subset = "training",seed = 545)

test_it = test_datagen.flow_from_directory(
        'cifar10/test',
        class_mode='categorical',
        target_size=(image_resize, image_resize),
        batch_size=global_batch_size,
        subset = "validation",
        seed = 545)


# for data_batch, labels_batch in train_it:
#     print('data batch shape:', data_batch.shape)
#     print('labels batch shape:', labels_batch.shape)
#     break



Found 50000 images belonging to 10 classes.
Found 50000 images belonging to 10 classes.
Found 5000 images belonging to 10 classes.
Found 5000 images belonging to 10 classes.


In [11]:
class CustomLoss(keras.losses.Loss):
    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def call(self, y_true, y_pred):
        ce = tf.keras.losses.categorical_crossentropy(y_true, y_pred)
        return ce * self.factor


# activation functions
def relu(x):
    return layers.ReLU()(x)

def hard_sigmoid(x):
    return layers.ReLU(6.)(x + 3.) * (1. / 6.)

def hard_swish(x):
    return layers.Multiply()([x, hard_sigmoid(x)])

# this function ensures that all layers have a channel number that is divisible by 8
def _depth(v, divisor=8, min_value=None):
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v




class _se_block(tf.keras.layers.Layer):
    def __init__(self, filters, se_ratio, prefix, **kwargs):
        super(_se_block, self).__init__(**kwargs)
        self.prefix = prefix
        
        self.se_conv = layers.Conv2D(_depth(filters * se_ratio), kernel_size=1, padding='same', name=prefix + 'squeeze_excite/Conv')
        self.se_conv1 = layers.Conv2D(filters, kernel_size=1, padding='same', name=prefix + 'squeeze_excite/Conv_1')
        
        
    def call(self, inputs, training):
        
        x = layers.GlobalAveragePooling2D(name=self.prefix + 'squeeze_excite/AvgPool')(inputs)
        
        ### not sure, it's instead of keep_dim, which is not supported in tf 2.5
        x = layers.Reshape((1, 1, inputs.shape[-1]))(x)
        ###
        
        x = self.se_conv(x)
        x = layers.ReLU(name=self.prefix + 'squeeze_excite/Relu')(x)
        x = self.se_conv1(x)
        x = hard_sigmoid(x)
        x = layers.Multiply(name=self.prefix + 'squeeze_excite/Mul')([inputs, x])
        
        return x

class _inverted_res_block(tf.keras.layers.Layer):
    def __init__(self, expansion, filters, kernel_size, stride, se_ratio, activation, block_id, infilters, **kwargs):
        super(_inverted_res_block, self).__init__(**kwargs)
        

        self.filters = filters
        self.kernel_size = kernel_size
        self.stride = stride
        self.se_ratio = se_ratio
        self.block_id = block_id
        self.activation = activation
        self.infilters = infilters
        prefix = 'expanded_conv/'
        channel_axis = -1
        
        if self.block_id:
            self.expand_conv1 = layers.Conv2D(_depth(infilters * expansion), kernel_size=1, padding='same', use_bias=False, name=prefix + 'expand')
            self.expand_bn1 = layers.BatchNormalization(axis=channel_axis, epsilon=1e-3, momentum=0.999, name=prefix + 'expand/BatchNorm')
            

        self.expand_depthwise_conv = layers.DepthwiseConv2D(kernel_size, strides=stride, padding='same' if stride == 1 else 'valid', use_bias=False, name=prefix + 'depthwise')
        self.expand_depthwise_bn = layers.BatchNormalization(axis=channel_axis, epsilon=1e-3, momentum=0.999, name=prefix + 'depthwise/BatchNorm')
        
        if self.se_ratio:
            self._se_block0 = _se_block(_depth(infilters * expansion), se_ratio, prefix)
        
        self.project_conv = layers.Conv2D(filters, kernel_size=1, padding='same', use_bias=False, name=prefix + 'project')
        self.project_bn = layers.BatchNormalization(axis=channel_axis, epsilon=1e-3, momentum=0.999, name=prefix + 'project/BatchNorm')
        
    def call(self, x, training):
        
        channel_axis = -1
        shortcut = x
        
        prefix = 'expanded_conv/'
        
        if self.block_id:
            # Expand
            prefix = 'expanded_conv_{}/'.format(self.block_id)
            x = self.expand_conv1(x)
            x = self.expand_bn1(x, training = training)
            x = self.activation(x)

        if self.stride == 2:
            x = layers.ZeroPadding2D(padding=imagenet_utils.correct_pad(x, self.kernel_size), name=prefix + 'depthwise/pad')(x)
            
        x = self.expand_depthwise_conv(x)
        x = self.expand_depthwise_bn(x, training = training)
        x = self.activation(x)

        if self.se_ratio:
            x = self._se_block0(x)

        x = self.project_conv(x)
        x = self.project_bn(x)

        if self.stride == 1 and self.infilters == self.filters:
            x = layers.Add(name=prefix + 'Add')([shortcut, x])
            
        return x








    

class MobileNetV3_Small(tf.keras.Model):
    def __init__(self, branch_number):
        super(MobileNetV3_Small, self).__init__()

        self.branch_number = branch_number
        self.se_ratio = 0.25
        self.channel_axis = -1
        self.alpha = 1
        self.minimalistic = False
        
        
        ################ Begining layers ############################
        #### changed this stride from two to 1
        self.conv1 = layers.Conv2D(16, kernel_size=3, strides=(2, 2), padding='same', use_bias=False, name='Conv', input_shape = (image_resize, image_resize, 3))
        self.bn1 = layers.BatchNormalization(epsilon=1e-3, momentum=0.999, name='Conv/BatchNorm')
        

        ################ MobileNt blocks #############################
        ### expansion, filters, kernel_size, stride, se_ratio, activation, block_id, infilters
        self._inverted_res_block0 = _inverted_res_block(1, 16, 3, 2, self.se_ratio, relu, 0, infilters=16)
        self._inverted_res_block1 = _inverted_res_block(72. / 16, 24, 3, 2, None, relu, 1, infilters=16)
        self._inverted_res_block2 = _inverted_res_block(88. / 24, 24, 3, 1, None, relu, 2, infilters=24)
        self._inverted_res_block3 = _inverted_res_block(4, 40, 5, 2, self.se_ratio, hard_swish, 3, infilters=24)
        self._inverted_res_block4 = _inverted_res_block(6, 40, 5, 1, self.se_ratio, hard_swish, 4, infilters=40)
        self._inverted_res_block5 = _inverted_res_block(6, 40, 5, 1, self.se_ratio, hard_swish, 5, infilters=40)
        self._inverted_res_block6 = _inverted_res_block(3, 48, 5, 1, self.se_ratio, hard_swish, 6, infilters=40)
        self._inverted_res_block7 = _inverted_res_block(3, 48, 5, 1, self.se_ratio, hard_swish, 7, infilters=48)
        self._inverted_res_block8 = _inverted_res_block(6, 96, 5, 2, self.se_ratio, hard_swish, 8, infilters=48)
        self._inverted_res_block9 = _inverted_res_block(6, 96, 5, 1, self.se_ratio, hard_swish, 9, infilters=96)
        self._inverted_res_block10 = _inverted_res_block(6, 96, 5, 1, self.se_ratio, hard_swish, 10, infilters=96)
        
        
        ################ after blocks #######################3#######
        self.after_conv1 = layers.Conv2D(_depth(96*6), kernel_size=1, padding='same', use_bias=False, name='Conv_1')
        self.after_bn = layers.BatchNormalization(axis=self.channel_axis, epsilon=1e-3, momentum=0.999, name='Conv_1/BatchNorm')
        self.after_conv2 = layers.Conv2D(_depth(1024), kernel_size=1, padding='same', use_bias=True, name='Conv_2')
        
        ################ Final Layers (all branches) ################################
        self.final_conv = layers.Conv2D(10, kernel_size=1, padding='same', name='Logits')
        
        self.final_conv0 = layers.Conv2D(10, kernel_size=1, padding='same', name='Logits0')
        self.final_conv1 = layers.Conv2D(10, kernel_size=1, padding='same', name='Logits1')
        self.final_conv2 = layers.Conv2D(10, kernel_size=1, padding='same', name='Logits2')
        self.final_conv3 = layers.Conv2D(10, kernel_size=1, padding='same', name='Logits3')
        self.final_conv4 = layers.Conv2D(10, kernel_size=1, padding='same', name='Logits4')
        self.final_conv5 = layers.Conv2D(10, kernel_size=1, padding='same', name='Logits5')
        self.final_conv6 = layers.Conv2D(10, kernel_size=1, padding='same', name='Logits6')
        self.final_conv7 = layers.Conv2D(10, kernel_size=1, padding='same', name='Logits7')
        self.final_conv8 = layers.Conv2D(10, kernel_size=1, padding='same', name='Logits8')
        self.final_conv9 = layers.Conv2D(10, kernel_size=1, padding='same', name='Logits9')
        
        
    def call(self, inputs, training):
        inference_flag = -1000
        comp_latency_list_backbone = [0 for i in range(self.branch_number)]
        comp_latency_list_exitbranch = [0 for i in range(self.branch_number)]
        out_vector_list = [[] for i in range(self.branch_number)]
        
        if (training != False and training != True):
            inference_flag = 1000
            training = False
            
            

        ################ HERE ########################
        start_time = tf.timestamp()
        x = layers.Rescaling(scale=1./ 127.5, offset=-1.)(inputs)
        x = self.conv1(x)
        x = self.bn1(x, training=training)
        x = hard_swish(x)
        
        x = self._inverted_res_block0(x)
        comp_latency_list_backbone[0] = tf.timestamp() - start_time
        
        ########branch 0
#         print("bl0 ", x.shape)
        start_time = tf.timestamp()
        temp = layers.GlobalAveragePooling2D()(x)
        temp = layers.Reshape((1, 1, x.shape[-1]))(temp)
        temp = layers.Dropout(0.5)(temp)
        temp = self.final_conv0(temp)
        temp = layers.Flatten()(temp)
        temp = layers.Activation(activation='softmax', name='Predictions0')(temp)
        out_vector_list[0] = temp
        comp_latency_list_exitbranch[0] = tf.timestamp() - start_time
        ################
    
        start_time = tf.timestamp()
        x = self._inverted_res_block1(x)
        comp_latency_list_backbone[1] = tf.timestamp() - start_time
        
        ########branch 1
#         print("bl1 ", x.shape)
        start_time = tf.timestamp()
        temp = layers.GlobalAveragePooling2D()(x)
        temp = layers.Reshape((1, 1, x.shape[-1]))(temp)
        temp = layers.Dropout(0.5)(temp)
        temp = self.final_conv1(temp)
        temp = layers.Flatten()(temp)
        temp = layers.Activation(activation='softmax', name='Predictions1')(temp)
        out_vector_list[1] = temp
        comp_latency_list_exitbranch[1] = tf.timestamp() - start_time
        ################
        
        start_time = tf.timestamp()
        x = self._inverted_res_block2(x)
        comp_latency_list_backbone[2] = tf.timestamp() - start_time
        
        
        ########branch 2
#         print("bl2 ", x.shape)
        start_time = tf.timestamp()
        temp = layers.GlobalAveragePooling2D()(x)
        temp = layers.Reshape((1, 1, x.shape[-1]))(temp)
        temp = layers.Dropout(0.5)(temp)
        temp = self.final_conv2(temp)
        temp = layers.Flatten()(temp)
        temp = layers.Activation(activation='softmax', name='Predictions2')(temp)
        out_vector_list[2] = temp
        comp_latency_list_exitbranch[2] = tf.timestamp() - start_time
        ################
        
        start_time = tf.timestamp()
        x = self._inverted_res_block3(x)
        comp_latency_list_backbone[3] = tf.timestamp() - start_time
        
        
        ########branch 3
#         print("bl3 ", x.shape)
        start_time = tf.timestamp()
        temp = layers.GlobalAveragePooling2D()(x)
        temp = layers.Reshape((1, 1, x.shape[-1]))(temp)
        temp = layers.Dropout(0.5)(temp)
        temp = self.final_conv3(temp)
        temp = layers.Flatten()(temp)
        temp = layers.Activation(activation='softmax', name='Predictions3')(temp)
        out_vector_list[3] = temp
        comp_latency_list_exitbranch[3] = tf.timestamp() - start_time
        ################
        
        start_time = tf.timestamp()
        x = self._inverted_res_block4(x)
        comp_latency_list_backbone[4] = tf.timestamp() - start_time
        
        
        ########branch 4
#         print("bl4 ", x.shape)
        start_time = tf.timestamp()
        temp = layers.GlobalAveragePooling2D()(x)
        temp = layers.Reshape((1, 1, x.shape[-1]))(temp)
        temp = layers.Dropout(0.5)(temp)
        temp = self.final_conv4(temp)
        temp = layers.Flatten()(temp)
        temp = layers.Activation(activation='softmax', name='Predictions4')(temp)
        out_vector_list[4] = temp
        comp_latency_list_exitbranch[4] = tf.timestamp() - start_time
        ################
        
        start_time = tf.timestamp()
        x = self._inverted_res_block5(x)
        comp_latency_list_backbone[5] = tf.timestamp() - start_time
        
        
        ########branch 5
#         print("bl5 ", x.shape)
        start_time = tf.timestamp()
        temp = layers.GlobalAveragePooling2D()(x)
        temp = layers.Reshape((1, 1, x.shape[-1]))(temp)
        temp = layers.Dropout(0.5)(temp)
        temp = self.final_conv5(temp)
        temp = layers.Flatten()(temp)
        temp = layers.Activation(activation='softmax', name='Predictions5')(temp)
        out_vector_list[5] = temp
        comp_latency_list_exitbranch[5] = tf.timestamp() - start_time
        ################
        
        start_time = tf.timestamp()
        x = self._inverted_res_block6(x)
        comp_latency_list_backbone[6] = tf.timestamp() - start_time
        
        
        ########branch 6
#         print("bl6 ", x.shape)
        start_time = tf.timestamp()
        temp = layers.GlobalAveragePooling2D()(x)
        temp = layers.Reshape((1, 1, x.shape[-1]))(temp)
        temp = layers.Dropout(0.5)(temp)
        temp = self.final_conv6(temp)
        temp = layers.Flatten()(temp)
        temp = layers.Activation(activation='softmax', name='Predictions6')(temp)
        out_vector_list[6] = temp
        comp_latency_list_exitbranch[6] = tf.timestamp() - start_time
        ################
        
        start_time = tf.timestamp()
        x = self._inverted_res_block7(x)
        comp_latency_list_backbone[7] = tf.timestamp() - start_time
        
        
        ########branch 7
#         print("bl7 ", x.shape)
        start_time = tf.timestamp()
        temp = layers.GlobalAveragePooling2D()(x)
        temp = layers.Reshape((1, 1, x.shape[-1]))(temp)
        temp = layers.Dropout(0.5)(temp)
        temp = self.final_conv7(temp)
        temp = layers.Flatten()(temp)
        temp = layers.Activation(activation='softmax', name='Predictions7')(temp)
        out_vector_list[7] = temp
        comp_latency_list_exitbranch[7] = tf.timestamp() - start_time
        ################
        
        start_time = tf.timestamp()
        x = self._inverted_res_block8(x)
        comp_latency_list_backbone[8] = tf.timestamp() - start_time
        
        ########branch 8
#         print("bl8 ", x.shape)
        start_time = tf.timestamp()
        temp = layers.GlobalAveragePooling2D()(x)
        temp = layers.Reshape((1, 1, x.shape[-1]))(temp)
        temp = layers.Dropout(0.5)(temp)
        temp = self.final_conv8(temp)
        temp = layers.Flatten()(temp)
        temp = layers.Activation(activation='softmax', name='Predictions8')(temp)
        out_vector_list[8] = temp
        comp_latency_list_exitbranch[8] = tf.timestamp() - start_time
        ################
        
        start_time = tf.timestamp()
        x = self._inverted_res_block9(x)
        comp_latency_list_backbone[9] = tf.timestamp() - start_time
        
        ########branch 9
#         print("bl9 ", x.shape)
        start_time = tf.timestamp()
        temp = layers.GlobalAveragePooling2D()(x)
        temp = layers.Reshape((1, 1, x.shape[-1]))(temp)
        temp = layers.Dropout(0.5)(temp)
        temp = self.final_conv9(temp)
        temp = layers.Flatten()(temp)
        temp = layers.Activation(activation='softmax', name='Predictions9')(temp)
        out_vector_list[9] = temp
        comp_latency_list_exitbranch[9] = tf.timestamp() - start_time
        ################
        
        start_time = tf.timestamp()
        x = self._inverted_res_block10(x)
        
        x = self.after_conv1(x)
        x = self.after_bn(x, training = training)
        x = hard_swish(x)
        comp_latency_list_backbone[10] = tf.timestamp() - start_time
        

        #########branch 10(last)
#         print("bl10 ", x.shape)
        start_time = tf.timestamp()
        temp = x
        x = layers.GlobalAveragePooling2D()(x)
        x = layers.Reshape((1, 1, temp.shape[-1]))(x)
        x = self.after_conv2(x)
        x = hard_swish(x)

        x = layers.Dropout(0.5)(x)
        x = self.final_conv(x)
        x = layers.Flatten()(x)
        x = layers.Activation(activation='softmax', name='Predictions')(x)
        out_vector_list[10] = x
        comp_latency_list_exitbranch[10] = tf.timestamp() - start_time
        ####################
        
        if (inference_flag == 1000):
            return(out_vector_list, comp_latency_list_backbone, comp_latency_list_exitbranch)
        

        return out_vector_list
    

In [9]:
#####MobileNetV3Small + CIFAR10 (128)
#### don't forget that this is trained with first stride 1

opt=optimizers.Adam()

branch_number = 11
ls = ['categorical_crossentropy' for i in range(branch_number)]

model_MobileNetV3_Small_128 = MobileNetV3_Small(branch_number)
model_MobileNetV3_Small_128.compile(optimizer=opt, loss=ls, metrics=['accuracy'])

model_MobileNetV3_Small_128.load_weights('profiling_models/MobileNetV3_11out_bs32_epoch200_lr0.01_CIFAR(128)')
model_MobileNetV3_Small_128.evaluate(test_it);














In [12]:
#####MobileNetV3Small + CIFAR10 (224)
#### don't forget that this is trained with first stride 2
opt=optimizers.Adam()

branch_number = 11
ls = ['categorical_crossentropy' for i in range(branch_number)]

model_MobileNetV3_Small_224 = MobileNetV3_Small(branch_number)
model_MobileNetV3_Small_224.compile(optimizer=opt, loss=ls, metrics=['accuracy'])

model_MobileNetV3_Small_224.load_weights('profiling_models/MobileNetV3_11out_bs32_epoch200_lr0.01_CIFAR(224)')
model_MobileNetV3_Small_224.evaluate(test_it);




In [15]:
##### passing TRAIN samples through model
##### this is to save computation time and exit rates for all the possible branches
print("%%%%%%%%%%%%%%%%%%%%%%% MobileNet 128 %%%%%%%%%%%%%%%%%%%%%%%%%%%%")
image_resize = 128
per_sample_label_list = []
per_sample_out_vector_list = []
per_sample_comp_latency_backbone_list = []
per_sample_comp_latency_exitbranch_list = []


#### saving intermediate data once
for i in range (len(train_it_128)):
    temp_batch = train_it_128[i]

    for j in range (len(temp_batch[0])):
        pic = temp_batch[0][j]
        label = temp_batch[1][j]
        per_sample_label_list.append(np.array(label).reshape(1,10))

        res = model_MobileNetV3_Small_128(np.array(pic.reshape(1,image_resize,image_resize,3)), training = 1000)
   
        per_sample_out_vector_list.append(res[0])
        per_sample_comp_latency_backbone_list.append(res[1])
        per_sample_comp_latency_exitbranch_list.append(res[2])

print ("computation latencies backbone(train data)   ", np.mean(per_sample_comp_latency_backbone_list, axis=0)*1000)
print ("computation latencies exitbranch(train data)   ", np.mean(per_sample_comp_latency_exitbranch_list, axis=0)*1000)




##### passing TRAIN samples through model
##### this is to save computation time and exit rates for all the possible branches

entropy_threshold_list = np.linspace(0.000000001, 0.999999, num=20)

### all possible branches, and chosen number
chosen_number = branch_number
I = list(range(1, branch_number+1))

#### number of train samples for simulation
sample_number = train_it_128.samples



#### per placement, in this case it is only 1, all the possible branches
for item in combinations(I, chosen_number):
    placement = list(np.array(item))
    print ("&&&&& selected exit ", placement , "  &&&&&&")
        
    ent_exitrate_list = []

    ##### per threshold 
    for thresh in entropy_threshold_list:
        print ("------------------ ", thresh, "------------------------")

        threshold_exit = []
        threshold_exitrate = [[] for i in range(branch_number)]

        ##### per sample
        for sample in range (sample_number):

            #### determining the exit branch based on entropy of the output and thresh
            for exit in range(branch_number):
                out = per_sample_out_vector_list[sample][exit]
                entropy = -1 * tf.math.reduce_sum((tf.math.log(out) * out)/ np.log(out.shape[1]))

                if (entropy < thresh or exit+1==branch_number):
                    threshold_exit.append(exit+1)
                    break


        # handling exit percentage part
        unique, counts = np.unique(threshold_exit, return_counts=True)
        exitper_list_dict = dict(zip(unique, counts))
        print ("exit rate per thresh ", exitper_list_dict)

        for i in range(branch_number):
            if (exitper_list_dict.get(placement[i]) is None):
                threshold_exitrate[i].append (0)
            else:
                threshold_exitrate[i].append(exitper_list_dict.get(placement[i]))


        ent_exitrate_list.append(np.mean(threshold_exitrate, axis=1)/sample_number)

    print("------------------------------------------------------------------")
    print("exit rate average ", np.mean(ent_exitrate_list, axis=0))
    break
    
    
    
    
    
    
##### passing TRAIN samples through model
##### this is to save computation time and exit rates for all the possible branches
print("%%%%%%%%%%%%%%%%%%%%%%% MobileNet 224 %%%%%%%%%%%%%%%%%%%%%%%%%%%%")
image_resize = 224
per_sample_label_list = []
per_sample_out_vector_list = []
per_sample_comp_latency_backbone_list = []
per_sample_comp_latency_exitbranch_list = []


#### saving intermediate data once
for i in range (len(train_it_224)):
    temp_batch = train_it_224[i]

    for j in range (len(temp_batch[0])):
        pic = temp_batch[0][j]
        label = temp_batch[1][j]
        per_sample_label_list.append(np.array(label).reshape(1,10))

        res = model_MobileNetV3_Small_224(np.array(pic.reshape(1,image_resize,image_resize,3)), training = 1000)
   
        per_sample_out_vector_list.append(res[0])
        per_sample_comp_latency_backbone_list.append(res[1])
        per_sample_comp_latency_exitbranch_list.append(res[2])

print ("computation latencies backbone(train data)   ", np.mean(per_sample_comp_latency_backbone_list, axis=0)*1000)
print ("computation latencies exitbranch(train data)   ", np.mean(per_sample_comp_latency_exitbranch_list, axis=0)*1000)




##### passing TRAIN samples through model
##### this is to save computation time and exit rates for all the possible branches

entropy_threshold_list = np.linspace(0.000000001, 0.999999, num=20)

### all possible branches, and chosen number
chosen_number = branch_number
I = list(range(1, branch_number+1))

#### number of train samples for simulation
sample_number = train_it_224.samples



#### per placement, in this case it is only 1, all the possible branches
for item in combinations(I, chosen_number):
    placement = list(np.array(item))
    print ("&&&&& selected exit ", placement , "  &&&&&&")
        
    ent_exitrate_list = []

    ##### per threshold 
    for thresh in entropy_threshold_list:
        print ("------------------ ", thresh, "------------------------")

        threshold_exit = []
        threshold_exitrate = [[] for i in range(branch_number)]

        ##### per sample
        for sample in range (sample_number):

            #### determining the exit branch based on entropy of the output and thresh
            for exit in range(branch_number):
                out = per_sample_out_vector_list[sample][exit]
                entropy = -1 * tf.math.reduce_sum((tf.math.log(out) * out)/ np.log(out.shape[1]))

                if (entropy < thresh or exit+1==branch_number):
                    threshold_exit.append(exit+1)
                    break


        # handling exit percentage part
        unique, counts = np.unique(threshold_exit, return_counts=True)
        exitper_list_dict = dict(zip(unique, counts))
        print ("exit rate per thresh ", exitper_list_dict)

        for i in range(branch_number):
            if (exitper_list_dict.get(placement[i]) is None):
                threshold_exitrate[i].append (0)
            else:
                threshold_exitrate[i].append(exitper_list_dict.get(placement[i]))


        ent_exitrate_list.append(np.mean(threshold_exitrate, axis=1)/sample_number)

    print("------------------------------------------------------------------")
    print("exit rate average ", np.mean(ent_exitrate_list, axis=0))
    break

%%%%%%%%%%%%%%%%%%%%%%% MobileNet 128 %%%%%%%%%%%%%%%%%%%%%%%%%%%%










computation latencies backbone(train data)    [9.05846252 3.70445464 3.67722706 8.27668501 8.23240369 8.30851388
 7.72183672 8.2740774  8.28845204 8.25237971 9.89978122]
computation latencies exitbranch(train data)    [2.87825449 2.88313701 2.88174295 2.85760005 2.86198883 2.86371267
 2.85832611 2.85649922 2.86270422 2.86827981 4.23583151]
&&&&& selected exit  [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]   &&&&&&
------------------  1e-09 ------------------------
exit rate per thresh  {11: 50000}
------------------  0.052631527263157896 ------------------------
exit rate per thresh  {11: 50000}
------------------  0.10526305352631579 ------------------------
exit rate per thresh  {2: 1, 11: 49999}
------------------  0.1578945797894737 ------------------------
exit rate per thresh  {2: 1, 3: 2, 11: 49997}
------------------  0.21052610605263158 ------------------------
exit rate per thresh  {2: 5, 3: 4, 11: 49991}
------------------  0.26315763231578954 ------------------------
exit rate per th