In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import optimizers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import add, Input, Dense, Conv2D, MaxPooling2D, AveragePooling2D, ZeroPadding2D, Dropout, Flatten, Concatenate, Reshape, Activation, BatchNormalization, GlobalAveragePooling2D, ZeroPadding2D
from tensorflow.nn import local_response_normalization
from tensorflow.python.keras.layers.merge import concatenate
from tensorflow.keras.activations import relu
import sys
import matplotlib.pyplot as plt
import math
import matplotlib
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.regularizers import l2, l1, l1_l2
from itertools import permutations, combinations
import cv2

In [2]:
# GPU setup.
# The subclassed model below is called eagerly (TF2). The eager runtime is not
# guaranteed to pick up the GPU options of a tf.compat.v1 Session, so memory growth
# is enabled through tf.config first. This must happen before the GPU is
# initialised; on a re-run of this cell TF raises RuntimeError and the setting
# chosen on the first run stays in effect.
for _gpu in tf.config.list_physical_devices('GPU'):
    try:
        tf.config.experimental.set_memory_growth(_gpu, True)
    except RuntimeError as _err:
        # Memory growth must be set before GPUs have been initialized.
        print(_err)

# Legacy TF1-style session configuration, kept for anything that still goes
# through the tf.compat.v1 Keras backend session.
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True  # dynamically grow the memory used on the GPU
config.log_device_placement = True  # to log device placement (on which device the operation ran)
# config.gpu_options.per_process_gpu_memory_fraction = 0.60
sess = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(sess)

Device mapping:
/job:localhost/replica:0/task:0/device:GPU:0 -> device: 0, name: GeForce RTX 3080, pci bus id: 0000:41:00.0, compute capability: 8.6



In [3]:
#### Using CIFAR10 dataset
# In-memory CIFAR-10 from keras.datasets: pixels scaled to [0, 1], labels one-hot.
# NOTE(review): nothing visible in this notebook reads these arrays; the model is
# fed by the directory iterators below.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

x_train, x_test = (images.astype("float32") / 255 for images in (x_train, x_test))
y_train, y_test = (
    keras.utils.to_categorical(labels.astype("float32"), num_classes=10)
    for labels in (y_train, y_test)
)

global_batch_size = 32
image_resize = 299

########## Train
# Data augmentation is deliberately disabled for profiling; only rescaling remains.
# Augmented variant used elsewhere:
# train_datagen = ImageDataGenerator(rescale=1./255, featurewise_center = True,
#                                    rotation_range = 20, horizontal_flip = True, height_shift_range = 0.2,
#                                    width_shift_range = 0.2, zoom_range = 0.2, channel_shift_range = 0.2)
train_datagen = ImageDataGenerator(rescale=1./255)

train_it = train_datagen.flow_from_directory(
    'cifar10/train',
    target_size=(image_resize, image_resize),
    batch_size=global_batch_size,
    class_mode='categorical',
)

############ Test
# The test directory is split 50/50: subset "training" feeds validation_it and
# subset "validation" feeds test_it (created in that order).
test_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.5)

validation_it, test_it = (
    test_datagen.flow_from_directory(
        'cifar10/test',
        target_size=(image_resize, image_resize),
        batch_size=global_batch_size,
        class_mode='categorical',
        subset=split_name,
        seed=545,
    )
    for split_name in ("training", "validation")
)

# Sanity check: pull one batch from the train iterator and show its shapes.
data_batch, labels_batch = next(iter(train_it))
print('data batch shape:', data_batch.shape)
print('labels batch shape:', labels_batch.shape)


Found 50000 images belonging to 10 classes.
Found 5000 images belonging to 10 classes.
Found 5000 images belonging to 10 classes.
data batch shape: (32, 299, 299, 3)
labels batch shape: (32, 10)


In [4]:
class CustomLoss(keras.losses.Loss):
    """Categorical cross-entropy scaled by a constant factor.

    Args:
        factor: multiplier applied to the per-sample cross-entropy.
    """

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def call(self, y_true, y_pred):
        # categorical_crossentropy is used with its default from_logits=False,
        # i.e. y_pred is expected to hold probabilities.
        scaled = tf.keras.losses.categorical_crossentropy(y_true, y_pred) * self.factor
        return scaled
    
 

class conv2d_bn(tf.keras.layers.Layer):
    """Conv2D (no bias) -> BatchNormalization (no scale) -> ReLU.

    Args:
        filters: number of output channels.
        num_row, num_col: kernel height and width.
        padding: Conv2D padding mode ('same' by default).
        strides: Conv2D strides.
    """

    def __init__(self, filters, num_row, num_col, padding='same',strides=(1, 1)):
        super(conv2d_bn, self).__init__()
        # NOTE(review): keep the attribute names conv/bn/act; TF-format checkpoints
        # are presumably keyed by them.
        self.conv = Conv2D(filters, (num_row, num_col), strides=strides, padding=padding, use_bias=False)
        self.bn = BatchNormalization(axis=3, scale=False)
        self.act = Activation('relu')

    def call(self, x, training):
        # Anything that is not True/False (None, or the 1000 profiling sentinel the
        # parent model uses) runs batch norm in inference mode.
        bn_mode = training if training in (True, False) else False

        features = self.conv(x)
        features = self.bn(features, training = bn_mode)
        return self.act(features)

    

class CIFAR_Inception_V3(tf.keras.Model):
    """Inception-V3 backbone with 13 early-exit classifiers (CIFAR-10 resized to 299x299).

    Every exit is the shared GlobalAveragePooling2D followed by its own
    Dense(10, softmax). Exit 0 follows the first stem stage, exit 1 the second
    stem stage, exits 2-11 follow inception blocks "mixed 0" .. "mixed 9", and
    exit 12 (`dense_last`) follows "mixed 10".

    Args:
        branch_number: length of the per-exit lists built in `call`. `call`
            writes indices 0..12 unconditionally, so values below 13 raise
            IndexError and values above 13 leave placeholder entries
            (`[]` outputs / `0` latencies) at the extra indices.

    NOTE(review): the notebook restores a TF-format checkpoint via `load_weights`;
    such checkpoints are presumably keyed by these attribute names (including the
    `branch_poolou4_bl4` typo), so renaming attributes would likely break weight
    loading. Confirm before refactoring.
    """
    def __init__(self, branch_number):
        """Create every backbone layer and the 13 exit heads."""
        super(CIFAR_Inception_V3, self).__init__()

        self.branch_number = branch_number

        
        ### Beginning (stem) layers
        # Spatial size for a 299x299 input: 299 -> 149 -> 147 -> 147 -> 73 (after pool1_bg)
        # -> 73 -> 71 -> 35 (after pool2_bg).
        # The first conv is a raw Conv2D + BatchNormalization; its ReLU is applied in `call`.
        # NOTE(review): input_shape is likely ignored for a layer created inside a subclassed model.
        self.conv1_bg = Conv2D(32, (3, 3), strides=(2,2) , padding='valid', use_bias=False, input_shape=(image_resize, image_resize, 3))
        self.bn1_bg = BatchNormalization(axis=3, scale=False)
        
        
        self.conv2_bg = conv2d_bn(32, 3, 3, padding='valid')
        self.conv3_bg = conv2d_bn(64, 3, 3)
        self.pool1_bg = MaxPooling2D((3, 3), strides=(2, 2))

        self.conv4_bg = conv2d_bn(80, 1, 1, padding='valid')
        self.conv5_bg = conv2d_bn(192, 3, 3, padding='valid')
        self.pool2_bg = MaxPooling2D((3, 3), strides=(2, 2))

        
        ### Inception blocks
        # mixed 0: 35 x 35 x 256 (bl0)  -> 64 + 64 + 96 + 32 channels
        self.branch1x1_bl0 = conv2d_bn(64, 1, 1)

        self.branch5x5_bl0_0 = conv2d_bn(48, 1, 1)
        self.branch5x5_bl0_1 = conv2d_bn(64, 5, 5)

        self.branch3x3dbl_bl0_0 = conv2d_bn(64, 1, 1)
        self.branch3x3dbl_bl0_1 = conv2d_bn(96, 3, 3)
        self.branch3x3dbl_bl0_2 = conv2d_bn(96, 3, 3)
        
        self.branch_pool_bl0 = AveragePooling2D((3, 3), strides=(1, 1), padding='same')
        self.branch_poolout_bl0 = conv2d_bn(32, 1, 1)
        
        
        
        # mixed 1: 35 x 35 x 288 (bl1)  -> 64 + 64 + 96 + 64 channels
        self.branch1x1_bl1 = conv2d_bn(64, 1, 1)

        self.branch5x5_bl1_0 = conv2d_bn(48, 1, 1)
        self.branch5x5_bl1_1 = conv2d_bn(64, 5, 5)

        self.branch3x3dbl_bl1_0 = conv2d_bn(64, 1, 1)
        self.branch3x3dbl_bl1_1 = conv2d_bn(96, 3, 3)
        self.branch3x3dbl_bl1_2 = conv2d_bn(96, 3, 3)

        self.branch_pool_bl1 = AveragePooling2D((3, 3), strides=(1, 1), padding='same')
        self.branch_poolout_bl1 = conv2d_bn(64, 1, 1)
        

        # mixed 2: 35 x 35 x 288 (bl2)
        self.branch1x1_bl2 = conv2d_bn(64, 1, 1)

        self.branch5x5_bl2_0 = conv2d_bn(48, 1, 1)
        self.branch5x5_bl2_1 = conv2d_bn(64, 5, 5)

        self.branch3x3dbl_bl2_0 = conv2d_bn(64, 1, 1)
        self.branch3x3dbl_bl2_1 = conv2d_bn(96, 3, 3)
        self.branch3x3dbl_bl2_2 = conv2d_bn(96, 3, 3)

        self.branch_pool_bl2 = AveragePooling2D((3, 3), strides=(1, 1), padding='same')
        self.branch_poolout_bl2 = conv2d_bn(64, 1, 1)
        
        

        # mixed 3: 17 x 17 x 768 (bl3)  -> reduction block: 384 + 96 + 288 (max-pooled input)
        self.branch3x3_bl3 = conv2d_bn(384, 3, 3, strides=(2, 2), padding='valid')

        self.branch3x3dbl_bl3_0 = conv2d_bn(64, 1, 1)
        self.branch3x3dbl_bl3_1 = conv2d_bn(96, 3, 3)
        self.branch3x3dbl_bl3_2 = conv2d_bn(96, 3, 3, strides=(2, 2), padding='valid')

        self.branch_pool_bl3 = MaxPooling2D((3, 3), strides=(2, 2))

        # mixed 4: 17 x 17 x 768 (bl4)  -> 4 x 192 channels, factorized 1x7 / 7x1 convs
        self.branch1x1_bl4 = conv2d_bn(192, 1, 1)

        self.branch7x7_bl4_0 = conv2d_bn(128, 1, 1)
        self.branch7x7_bl4_1 = conv2d_bn(128, 1, 7)
        self.branch7x7_bl4_2 = conv2d_bn(192, 7, 1)

        self.branch7x7dbl_bl4_0 = conv2d_bn(128, 1, 1)
        self.branch7x7dbl_bl4_1 = conv2d_bn(128, 7, 1)
        self.branch7x7dbl_bl4_2 = conv2d_bn(128, 1, 7)
        self.branch7x7dbl_bl4_3 = conv2d_bn(128, 7, 1)
        self.branch7x7dbl_bl4_4 = conv2d_bn(192, 1, 7)

        self.branch_pool_bl4 = AveragePooling2D((3, 3), strides=(1, 1), padding='same')
        # (sic) "poolou4" is a typo for "poolout"; kept because the name likely keys saved weights.
        self.branch_poolou4_bl4 = conv2d_bn(192, 1, 1)
        
    
        # mixed 5: 17 x 17 x 768 (bl5)
        self.branch1x1_bl5 = conv2d_bn(192, 1, 1)

        self.branch7x7_bl5_0 = conv2d_bn(160, 1, 1)
        self.branch7x7_bl5_1 = conv2d_bn(160, 1, 7)
        self.branch7x7_bl5_2 = conv2d_bn(192, 7, 1)

        self.branch7x7dbl_bl5_0 = conv2d_bn(160, 1, 1)
        self.branch7x7dbl_bl5_1 = conv2d_bn(160, 7, 1)
        self.branch7x7dbl_bl5_2 = conv2d_bn(160, 1, 7)
        self.branch7x7dbl_bl5_3 = conv2d_bn(160, 7, 1)
        self.branch7x7dbl_bl5_4 = conv2d_bn(192, 1, 7)

        self.branch_pool_bl5 = AveragePooling2D((3, 3), strides=(1, 1), padding='same')
        self.branch_poolout_bl5 = conv2d_bn(192, 1, 1)
        
        
        # mixed 6: 17 x 17 x 768 (bl6)
        self.branch1x1_bl6 = conv2d_bn(192, 1, 1)

        self.branch7x7_bl6_0 = conv2d_bn(160, 1, 1)
        self.branch7x7_bl6_1 = conv2d_bn(160, 1, 7)
        self.branch7x7_bl6_2 = conv2d_bn(192, 7, 1)

        self.branch7x7dbl_bl6_0 = conv2d_bn(160, 1, 1)
        self.branch7x7dbl_bl6_1 = conv2d_bn(160, 7, 1)
        self.branch7x7dbl_bl6_2 = conv2d_bn(160, 1, 7)
        self.branch7x7dbl_bl6_3 = conv2d_bn(160, 7, 1)
        self.branch7x7dbl_bl6_4 = conv2d_bn(192, 1, 7)

        self.branch_pool_bl6 = AveragePooling2D((3, 3), strides=(1, 1), padding='same')
        self.branch_poolout_bl6 = conv2d_bn(192, 1, 1)
        
        
        
        # mixed 7: 17 x 17 x 768 (bl7)
        self.branch1x1_bl7 = conv2d_bn(192, 1, 1)

        self.branch7x7_bl7_0 = conv2d_bn(192, 1, 1)
        self.branch7x7_bl7_1 = conv2d_bn(192, 1, 7)
        self.branch7x7_bl7_2 = conv2d_bn(192, 7, 1)

        self.branch7x7dbl_bl7_0 = conv2d_bn(192, 1, 1)
        self.branch7x7dbl_bl7_1 = conv2d_bn(192, 7, 1)
        self.branch7x7dbl_bl7_2 = conv2d_bn(192, 1, 7)
        self.branch7x7dbl_bl7_3 = conv2d_bn(192, 7, 1)
        self.branch7x7dbl_bl7_4 = conv2d_bn(192, 1, 7)

        self.branch_pool_bl7 = AveragePooling2D((3, 3), strides=(1, 1), padding='same')
        self.branch_poolout_bl7 = conv2d_bn(192, 1, 1)
        

        # mixed 8: 8 x 8 x 1280 (bl8)  -> reduction block: 320 + 192 + 768 (max-pooled input)
        self.branch3x3_bl8 = conv2d_bn(192, 1, 1)
        self.branch3x3_bl8_0 = conv2d_bn(320, 3, 3, strides=(2, 2), padding='valid')

        self.branch7x7x3_bl8_0 = conv2d_bn(192, 1, 1)
        self.branch7x7x3_bl8_1 = conv2d_bn(192, 1, 7)
        self.branch7x7x3_bl8_2 = conv2d_bn(192, 7, 1)
        self.branch7x7x3_bl8_3 = conv2d_bn(192, 3, 3, strides=(2, 2), padding='valid')
        
        self.branch_pool_bl8 = MaxPooling2D((3, 3), strides=(2, 2))

        # mixed 9: 8 x 8 x 2048 (bl9)  -> 320 + (384 + 384) + (384 + 384) + 192 channels
        self.branch1x1_bl9 = conv2d_bn(320, 1, 1)

        self.branch3x3_bl9 = conv2d_bn(384, 1, 1)
        self.branch3x3_1_bl9 = conv2d_bn(384, 1, 3)
        self.branch3x3_2_bl9 = conv2d_bn(384, 3, 1)

        self.branch3x3dbl_bl9 = conv2d_bn(448, 1, 1)
        self.branch3x3dbl_bl9_0 = conv2d_bn(384, 3, 3)
        self.branch3x3dbl_1_bl9 = conv2d_bn(384, 1, 3)
        self.branch3x3dbl_2_bl9 = conv2d_bn(384, 3, 1)

        self.branch_pool_bl9 = AveragePooling2D((3, 3), strides=(1, 1), padding='same')
        self.branch_poolout_bl9 = conv2d_bn(192, 1, 1)
        

        # mixed 10: 8 x 8 x 2048 (bl10)
        self.branch1x1_bl10 = conv2d_bn(320, 1, 1)

        self.branch3x3_bl10 = conv2d_bn(384, 1, 1)
        self.branch3x3_1_bl10 = conv2d_bn(384, 1, 3)
        self.branch3x3_2_bl10 = conv2d_bn(384, 3, 1)

        self.branch3x3dbl_bl10_0 = conv2d_bn(448, 1, 1)
        self.branch3x3dbl_bl10_1 = conv2d_bn(384, 3, 3)
        self.branch3x3dbl_1_bl10 = conv2d_bn(384, 1, 3)
        self.branch3x3dbl_2_bl10 = conv2d_bn(384, 3, 1)

        self.branch_pool_bl10 = AveragePooling2D((3, 3), strides=(1, 1), padding='same')
        self.branch_poolout_bl10 = conv2d_bn(192, 1, 1)
        
        

        
        ### last exit branch
        # Exit heads: one GlobalAveragePooling2D (no weights) shared by every exit,
        # plus a separate Dense(10, softmax) per exit. dense_0 .. dense_11 serve
        # exits 0 .. 11.
        self.global_pool = GlobalAveragePooling2D()
        self.dense_0 = Dense(10, activation='softmax')
        self.dense_1 = Dense(10, activation='softmax')
        self.dense_2 = Dense(10, activation='softmax')
        self.dense_3 = Dense(10, activation='softmax')
        self.dense_4 = Dense(10, activation='softmax')
        self.dense_5 = Dense(10, activation='softmax')
        self.dense_6 = Dense(10, activation='softmax')
        self.dense_7 = Dense(10, activation='softmax')
        self.dense_8 = Dense(10, activation='softmax')
        self.dense_9 = Dense(10, activation='softmax')
        self.dense_10 = Dense(10, activation='softmax')
        self.dense_11 = Dense(10, activation='softmax')
        
        # Final exit (exit 12), after mixed 10.
        self.dense_last = Dense(10, activation='softmax')
        
        
        
    def call(self, inputs, training):
        """Run the backbone and all 13 exit heads.

        Args:
            inputs: image batch. The shape comments in this class assume
                (batch, 299, 299, 3); TODO confirm for other input sizes.
            training: True/False for normal Keras use (fit / evaluate). Any other
                value (None, or the 1000 the profiling cells pass) selects
                profiling mode and is treated as inference (training=False).

        Returns:
            Normal mode: list of 13 softmax tensors, each (batch, 10), ordered
            from exit 0 to exit 12.
            Profiling mode: tuple (outputs, backbone_latencies, exit_latencies).
            backbone_latencies[k] is the time spent in the backbone stage ending
            at exit k. exit_latencies[k] is the time spent in exit head k. Both
            are tf.timestamp() differences, in seconds.

        NOTE(review): tf.timestamp() reads host wall-clock time. GPU kernels may
        be launched asynchronously, so these deltas may reflect dispatch time
        rather than device compute time. Confirm before reading them as GPU
        latency.
        """
        inference_flag = -1000  # switched to 1000 below when profiling mode is requested
        channel_axis = 3  # channels-last: inception branches are concatenated on axis 3
        # Per-exit result slots; indices 0..12 are filled below.
        comp_latency_list_backbone = [0 for i in range(self.branch_number)]
        comp_latency_list_exitbranch = [0 for i in range(self.branch_number)]
        out_vector_list = [[] for i in range(self.branch_number)]
        
        # Non-bool `training` => profiling mode (extra latency outputs), run as inference.
        if (training != False and training != True):
            inference_flag = 1000
            training = False
            
        # The conv2d_bn sub-layers below are called without `training`. Keras
        # presumably forwards the outer call's value (e.g. the raw 1000 sentinel),
        # which conv2d_bn itself maps to False. TODO confirm.
        ################ HERE ########################
        ### beginning (stem) layers
        start_time = tf.timestamp()
        x = self.conv1_bg(inputs)
        x = self.bn1_bg(x, training = training)
        x = relu(x)

        x = self.conv2_bg(x)
        x = self.conv3_bg(x)
        x = self.pool1_bg(x)
        comp_latency_list_backbone[0] = tf.timestamp() - start_time
        
#         print("branch 0 ", x.shape)
        ### *** exit branch 0
        # Each exit: shared global average pool -> its own Dense(10, softmax), timed separately.
        start_time = tf.timestamp()
        temp = self.global_pool(x)
        temp = self.dense_0(temp)
        out_vector_list[0] = temp
        comp_latency_list_exitbranch[0] = tf.timestamp() - start_time
        
        
        
        start_time = tf.timestamp()
        x = self.conv4_bg(x)
        x = self.conv5_bg(x)
        x = self.pool2_bg(x)
        comp_latency_list_backbone[1] = tf.timestamp() - start_time
        
#         print("branch 1 ", x.shape)
        ### *** exit branch 1
        start_time = tf.timestamp()
        temp = self.global_pool(x)
        temp = self.dense_1(temp)
        out_vector_list[1] = temp
        comp_latency_list_exitbranch[1] = tf.timestamp() - start_time

        ### inception blocks
        # mixed 0: 35 x 35 x 256 (bl0)
        start_time = tf.timestamp()
        branch1x1 = self.branch1x1_bl0(x)

        branch5x5 = self.branch5x5_bl0_0(x)
        branch5x5 = self.branch5x5_bl0_1(branch5x5)

        branch3x3dbl = self.branch3x3dbl_bl0_0(x)
        branch3x3dbl = self.branch3x3dbl_bl0_1(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_bl0_2(branch3x3dbl)

        branch_pool = self.branch_pool_bl0(x)
        branch_pool = self.branch_poolout_bl0(branch_pool)
        x = concatenate([branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=channel_axis)
        comp_latency_list_backbone[2] = tf.timestamp() - start_time
        
#         print("branch 2 ", x.shape)
        ### *** exit branch 2
        start_time = tf.timestamp()
        temp = self.global_pool(x)
        temp = self.dense_2(temp)
        out_vector_list[2] = temp
        comp_latency_list_exitbranch[2] = tf.timestamp() - start_time
        
        # mixed 1: 35 x 35 x 288 (bl1)
        start_time = tf.timestamp()
        branch1x1 = self.branch1x1_bl1(x)

        branch5x5 = self.branch5x5_bl1_0(x)
        branch5x5 = self.branch5x5_bl1_1(branch5x5)

        branch3x3dbl = self.branch3x3dbl_bl1_0(x)
        branch3x3dbl = self.branch3x3dbl_bl1_1(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_bl1_2(branch3x3dbl)

        branch_pool = self.branch_pool_bl1(x)
        branch_pool = self.branch_poolout_bl1(branch_pool)
        x = concatenate([branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=channel_axis)
        comp_latency_list_backbone[3] = tf.timestamp() - start_time
        
        
#         print("branch 3 ", x.shape)
        ### *** exit branch 3
        start_time = tf.timestamp()
        temp = self.global_pool(x)
        temp = self.dense_3(temp)
        out_vector_list[3] = temp
        comp_latency_list_exitbranch[3] = tf.timestamp() - start_time
        

        # mixed 2: 35 x 35 x 288 (bl2)
        start_time = tf.timestamp()
        branch1x1 = self.branch1x1_bl2(x)

        branch5x5 = self.branch5x5_bl2_0(x)
        branch5x5 = self.branch5x5_bl2_1(branch5x5)

        branch3x3dbl = self.branch3x3dbl_bl2_0(x)
        branch3x3dbl = self.branch3x3dbl_bl2_1(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_bl2_2(branch3x3dbl)

        branch_pool = self.branch_pool_bl2(x)
        branch_pool = self.branch_poolout_bl2(branch_pool)
        x = concatenate([branch1x1, branch5x5, branch3x3dbl, branch_pool], axis=channel_axis)
        comp_latency_list_backbone[4] = tf.timestamp() - start_time
        

#         print("branch 4 ", x.shape)
        ### *** exit branch 4
        start_time = tf.timestamp()
        temp = self.global_pool(x)
        temp = self.dense_4(temp)
        out_vector_list[4] = temp
        comp_latency_list_exitbranch[4] = tf.timestamp() - start_time
        
        # mixed 3: 17 x 17 x 768 (bl3)
        start_time = tf.timestamp()
        branch3x3 = self.branch3x3_bl3(x)

        branch3x3dbl = self.branch3x3dbl_bl3_0(x)
        branch3x3dbl = self.branch3x3dbl_bl3_1(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_bl3_2(branch3x3dbl)
        
        branch_pool = self.branch_pool_bl3(x)
        x = concatenate([branch3x3, branch3x3dbl, branch_pool], axis=channel_axis)
        comp_latency_list_backbone[5] = tf.timestamp() - start_time
        
#         print("branch 5 ", x.shape)
        ### *** exit branch 5
        start_time = tf.timestamp()
        temp = self.global_pool(x)
        temp = self.dense_5(temp)
        out_vector_list[5] = temp
        comp_latency_list_exitbranch[5] = tf.timestamp() - start_time
        

        # mixed 4: 17 x 17 x 768 (bl4)
        start_time = tf.timestamp()
        branch1x1 = self.branch1x1_bl4(x)

        branch7x7 = self.branch7x7_bl4_0(x)
        branch7x7 = self.branch7x7_bl4_1(branch7x7)
        branch7x7 = self.branch7x7_bl4_2(branch7x7)

        branch7x7dbl = self.branch7x7dbl_bl4_0(x)
        branch7x7dbl = self.branch7x7dbl_bl4_1(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_bl4_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_bl4_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_bl4_4(branch7x7dbl)

        branch_pool = self.branch_pool_bl4(x)
        branch_pool = self.branch_poolou4_bl4(branch_pool)
        x = concatenate([branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=channel_axis)
        comp_latency_list_backbone[6] = tf.timestamp() - start_time

#         print("branch 6 ", x.shape)
        ### *** exit branch 6
        start_time = tf.timestamp()
        temp = self.global_pool(x)
        temp = self.dense_6(temp)
        out_vector_list[6] = temp
        comp_latency_list_exitbranch[6] = tf.timestamp() - start_time
        
        
        # mixed 5: 17 x 17 x 768 (bl5)
        start_time = tf.timestamp()
        branch1x1 = self.branch1x1_bl5(x)

        branch7x7 = self.branch7x7_bl5_0(x)
        branch7x7 = self.branch7x7_bl5_1(branch7x7)
        branch7x7 = self.branch7x7_bl5_2(branch7x7)

        branch7x7dbl = self.branch7x7dbl_bl5_0(x)
        branch7x7dbl = self.branch7x7dbl_bl5_1(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_bl5_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_bl5_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_bl5_4(branch7x7dbl)

        branch_pool = self.branch_pool_bl5(x)
        branch_pool = self.branch_poolout_bl5(branch_pool)
        x = concatenate([branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=channel_axis)
        comp_latency_list_backbone[7] = tf.timestamp() - start_time
        
#         print("branch 7 ", x.shape)
        ### *** exit branch 7
        start_time = tf.timestamp()
        temp = self.global_pool(x)
        temp = self.dense_7(temp)
        out_vector_list[7] = temp
        comp_latency_list_exitbranch[7] = tf.timestamp() - start_time
        
        
        # mixed 6: 17 x 17 x 768 (bl6)
        start_time = tf.timestamp()
        branch1x1 = self.branch1x1_bl6(x)

        branch7x7 = self.branch7x7_bl6_0(x)
        branch7x7 = self.branch7x7_bl6_1(branch7x7)
        branch7x7 = self.branch7x7_bl6_2(branch7x7)

        branch7x7dbl = self.branch7x7dbl_bl6_0(x)
        branch7x7dbl = self.branch7x7dbl_bl6_1(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_bl6_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_bl6_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_bl6_4(branch7x7dbl)

        branch_pool = self.branch_pool_bl6(x)
        branch_pool = self.branch_poolout_bl6(branch_pool)
        x = concatenate([branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=channel_axis)
        comp_latency_list_backbone[8] = tf.timestamp() - start_time

#         print("branch 8 ", x.shape)
        ### *** exit branch 8
        start_time = tf.timestamp()
        temp = self.global_pool(x)
        temp = self.dense_8(temp)
        out_vector_list[8] = temp
        comp_latency_list_exitbranch[8] = tf.timestamp() - start_time
        
            
        # mixed 7: 17 x 17 x 768 (bl7)
        start_time = tf.timestamp()
        branch1x1 = self.branch1x1_bl7(x)

        branch7x7 = self.branch7x7_bl7_0(x)
        branch7x7 = self.branch7x7_bl7_1 (branch7x7)
        branch7x7 = self.branch7x7_bl7_2(branch7x7)

        branch7x7dbl = self.branch7x7dbl_bl7_0(x)
        branch7x7dbl = self.branch7x7dbl_bl7_1(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_bl7_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_bl7_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_bl7_4(branch7x7dbl)

        branch_pool = self.branch_pool_bl7(x)
        branch_pool = self.branch_poolout_bl7(branch_pool)
        x = concatenate([branch1x1, branch7x7, branch7x7dbl, branch_pool], axis=channel_axis)
        comp_latency_list_backbone[9] = tf.timestamp() - start_time
        
#         print("branch 9 ", x.shape)
        ### *** exit branch 9
        start_time = tf.timestamp()
        temp = self.global_pool(x)
        temp = self.dense_9(temp)
        out_vector_list[9] = temp
        comp_latency_list_exitbranch[9] = tf.timestamp() - start_time
        
        # mixed 8: 8 x 8 x 1280 (bl8)
        start_time = tf.timestamp()
        branch3x3 = self.branch3x3_bl8(x)
        branch3x3 = self.branch3x3_bl8_0(branch3x3)

        branch7x7x3 = self.branch7x7x3_bl8_0(x)
        branch7x7x3 = self.branch7x7x3_bl8_1(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_bl8_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_bl8_3(branch7x7x3)

        branch_pool = self.branch_pool_bl8(x)
        x = concatenate([branch3x3, branch7x7x3, branch_pool], axis=channel_axis)
        comp_latency_list_backbone[10] = tf.timestamp() - start_time

#         print("branch 10 ", x.shape)
        ### *** exit branch 10
        start_time = tf.timestamp()
        temp = self.global_pool(x)
        temp = self.dense_10(temp)
        out_vector_list[10] = temp
        comp_latency_list_exitbranch[10] = tf.timestamp() - start_time
        
        
        # mixed 9: 8 x 8 x 2048 (bl9)
        start_time = tf.timestamp()
        branch1x1 =  self.branch1x1_bl9(x)

        # 1x3 and 3x1 convs applied in parallel to the same input, then concatenated.
        branch3x3 = self.branch3x3_bl9(x)
        branch3x3_1 = self.branch3x3_1_bl9(branch3x3)
        branch3x3_2 = self.branch3x3_2_bl9(branch3x3)
        branch3x3 = concatenate([branch3x3_1, branch3x3_2], axis=channel_axis)

        branch3x3dbl = self.branch3x3dbl_bl9(x)
        branch3x3dbl = self.branch3x3dbl_bl9_0(branch3x3dbl)
        branch3x3dbl_1 = self.branch3x3dbl_1_bl9(branch3x3dbl)
        branch3x3dbl_2 = self.branch3x3dbl_2_bl9(branch3x3dbl)
        branch3x3dbl = concatenate([branch3x3dbl_1, branch3x3dbl_2], axis=channel_axis)

        branch_pool = self.branch_pool_bl9(x)
        branch_pool = self.branch_poolout_bl9(branch_pool)
        x = concatenate([branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=channel_axis)
        comp_latency_list_backbone[11] = tf.timestamp() - start_time
        
#         print("branch 11 ", x.shape)
        ### *** exit branch 11
        start_time = tf.timestamp()
        temp = self.global_pool(x)
        temp = self.dense_11(temp)
        out_vector_list[11] = temp
        comp_latency_list_exitbranch[11] = tf.timestamp() - start_time
        
        
        # mixed 10: 8 x 8 x 2048 (bl10)
        start_time = tf.timestamp()
        branch1x1 = self.branch1x1_bl10(x)

        branch3x3 = self.branch3x3_bl10(x)
        branch3x3_1 = self.branch3x3_1_bl10(branch3x3)
        branch3x3_2 = self.branch3x3_2_bl10(branch3x3)
        branch3x3 = concatenate([branch3x3_1, branch3x3_2], axis=channel_axis)

        branch3x3dbl = self.branch3x3dbl_bl10_0(x)
        branch3x3dbl = self.branch3x3dbl_bl10_1(branch3x3dbl)
        branch3x3dbl_1 = self.branch3x3dbl_1_bl10(branch3x3dbl)
        branch3x3dbl_2 = self.branch3x3dbl_2_bl10(branch3x3dbl)
        branch3x3dbl = concatenate([branch3x3dbl_1, branch3x3dbl_2], axis=channel_axis)

        branch_pool = self.branch_pool_bl10(x)
        branch_pool = self.branch_poolout_bl10(branch_pool)
        x = concatenate([branch1x1, branch3x3, branch3x3dbl, branch_pool], axis=channel_axis)
        comp_latency_list_backbone[12] = tf.timestamp() - start_time
        
#         print("branch 12 ", x.shape)
        ### *** exit branch 12 last
        start_time = tf.timestamp()
        temp = self.global_pool(x)
        temp = self.dense_last(temp)
        out_vector_list[12] = temp
        comp_latency_list_exitbranch[12] = tf.timestamp() - start_time
        ##############################################
            
        # Profiling mode: also hand back the per-stage latencies.
        if (inference_flag == 1000):
            return(out_vector_list, comp_latency_list_backbone, comp_latency_list_exitbranch)
        
        
        return out_vector_list

In [5]:
#### Inception V3 + CIFAR10 (299x299)
##### load weight
# Number of exits built into CIFAR_Inception_V3 (13 classifier heads).
branch_number = 13

# One CustomLoss with factor 1 per exit, so each exit's cross-entropy is weighted equally.
ls = [CustomLoss(1) for _ in range(branch_number)]
opt = tf.keras.optimizers.RMSprop(momentum=0.9)

model_CIFAR_Inception_V3 = CIFAR_Inception_V3(branch_number)
model_CIFAR_Inception_V3.compile(optimizer=opt, loss=ls, metrics=['accuracy'])

# Restore the trained weights, then report per-exit loss/accuracy on the test half.
model_CIFAR_Inception_V3.load_weights('profiling_models/Inception_V3_13out_bs32_epoch100_lr0.01')
model_CIFAR_Inception_V3.evaluate(test_it);



In [6]:
###### now the profiling part
###### TRAIN PART

In [6]:
##### passing TRAIN samples through model
##### this is to save computation time and exit rates for all the possible branches
# Every train image goes through the model on its own, in profiling mode
# (training=1000). Each call yields (13 exit outputs, backbone stage latencies,
# exit-head latencies) for that single sample; all three are kept, with the label.

per_sample_label_list = []
per_sample_out_vector_list = []
per_sample_comp_latency_backbone_list = []
per_sample_comp_latency_exitbranch_list = []


#### saving intermediate data once
for batch_index in range(len(train_it)):
    batch_images, batch_labels = train_it[batch_index]

    for pic, label in zip(batch_images, batch_labels):
        per_sample_label_list.append(np.array(label).reshape(1, 10))

        single_image = np.array(pic.reshape(1, image_resize, image_resize, 3))
        res = model_CIFAR_Inception_V3(single_image, training = 1000)

        per_sample_out_vector_list.append(res[0])
        per_sample_comp_latency_backbone_list.append(res[1])
        per_sample_comp_latency_exitbranch_list.append(res[2])

# Mean latency per stage over all samples, scaled by 1000 (seconds -> ms).
print ("computation latencies backbone(train data)   ", np.mean(per_sample_comp_latency_backbone_list, axis=0)*1000)
print ("computation latencies exitbranch(train data)   ", np.mean(per_sample_comp_latency_exitbranch_list, axis=0)*1000)




##### Early-exit simulation on the TRAIN samples collected above.
##### For each normalized-entropy threshold, a sample leaves at the first exit whose
##### output entropy is below the threshold (the last exit always accepts). The
##### fraction of samples leaving at each exit is recorded.

entropy_threshold_list = np.linspace(0.000000001, 0.999999, num=20)

### all possible branches, and chosen number
chosen_number = branch_number
I = list(range(1, branch_number+1))

#### number of train samples for simulation
sample_number = train_it.samples
assert len(per_sample_out_vector_list) >= sample_number, (
    "run the per-sample collection cell over the whole train split first")


def _normalized_entropy(probs):
    """Entropy along the last axis divided by log(n_classes), computed in float32.

    Zero probabilities contribute 0 (the limit of p*log(p)). Computing
    0 * log(0) directly would yield NaN, and a NaN entropy would never pass a
    threshold test, so the most confident outputs could never exit early.
    """
    probs = np.asarray(probs, dtype=np.float32)
    positive = probs > 0
    # log() only sees positive entries, so no divide-by-zero warnings are emitted.
    plogp = np.where(positive, probs * np.log(np.where(positive, probs, 1)), 0).astype(np.float32)
    return -(plogp / np.float32(np.log(probs.shape[-1]))).sum(axis=-1)


def _first_exit(entropy, thresh):
    """1-based index of the first column with entropy < thresh, per row.

    The last column always accepts, so every row gets an exit.
    """
    # Compare in the entropy dtype (float32), as the original TF comparison did.
    accept = entropy < entropy.dtype.type(thresh)
    accept[:, -1] = True
    return accept.argmax(axis=1) + 1  # argmax returns the first True


# (sample_number, branch_number, n_classes). np.asarray also accepts eager tensors.
exit_probs = np.stack([
    np.stack([np.asarray(out, dtype=np.float32).reshape(-1) for out in sample_outs[:branch_number]])
    for sample_outs in per_sample_out_vector_list[:sample_number]
])
# Entropy does not depend on the threshold, so compute it once: (sample_number, branch_number).
exit_entropy = _normalized_entropy(exit_probs)


#### per placement, in this case it is only 1, all the possible branches
for item in combinations(I, chosen_number):
    placement = list(np.array(item))
    print ("&&&&& selected exit ", placement , "  &&&&&&")

    ent_exitrate_list = []

    ##### per threshold
    for thresh in entropy_threshold_list:
        print ("------------------ ", thresh, "------------------------")

        # 1-based exit taken by each sample at this threshold.
        threshold_exit = _first_exit(exit_entropy, thresh).tolist()

        # handling exit percentage part
        unique, counts = np.unique(threshold_exit, return_counts=True)
        exitper_list_dict = dict(zip(unique, counts))
        print ("exit rate per thresh ", exitper_list_dict)

        # Number of samples leaving at each placed exit (0 when none did).
        threshold_exitrate = [[exitper_list_dict.get(placement[i], 0)] for i in range(branch_number)]
        ent_exitrate_list.append(np.mean(threshold_exitrate, axis=1)/sample_number)

    print("------------------------------------------------------------------")
    print("exit rate average ", np.mean(ent_exitrate_list, axis=0))
    break

computation latencies backbone(train data)    [2.75459888 1.32624567 5.12389028 5.02911969 5.01359507 4.0302824
 7.54120809 7.10608516 7.08174134 7.31302761 5.15766109 7.87290069
 7.76894309]
computation latencies exitbranch(train data)    [0.54501376 0.4557634  0.49166145 0.49352409 0.49595615 0.50750211
 0.51163939 0.50188097 0.50008883 0.50284579 0.52006973 0.50709085
 0.50691931]
&&&&& selected exit  [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]   &&&&&&
------------------  1e-09 ------------------------
exit rate per thresh  {4: 163, 5: 873, 6: 2933, 7: 7202, 8: 8524, 9: 1080, 10: 81, 11: 114, 12: 45, 13: 28985}
------------------  0.052631527263157896 ------------------------
exit rate per thresh  {1: 1277, 2: 5170, 3: 14852, 4: 11114, 5: 6877, 6: 4405, 7: 4232, 8: 1318, 9: 140, 10: 17, 11: 73, 12: 41, 13: 484}
------------------  0.10526305352631579 ------------------------
exit rate per thresh  {1: 2726, 2: 7960, 3: 16387, 4: 9975, 5: 5684, 6: 3320, 7: 2851, 8: 726, 9: 70, 10: 17

In [7]:
##### TEST PART, FOR SIMULATION ONLY

In [None]:
##### passing TEST samples through model once and save the intermediate data
#### this is for doing the simulation
# Same per-sample collection as for the train split: each test image is pushed
# through the model alone in profiling mode (training=1000). Its label, its 13 exit
# outputs and the per-stage latencies returned with them are stored.

per_sample_label_list = []
per_sample_out_vector_list = []
per_sample_comp_latency_backbone_list = []
per_sample_comp_latency_exitbranch_list = []

for batch_no in range(len(test_it)):
    images, one_hot_labels = test_it[batch_no]

    for pic, label in zip(images, one_hot_labels):
        per_sample_label_list.append(np.array(label).reshape(1, 10))

        res = model_CIFAR_Inception_V3(
            np.array(pic.reshape(1, image_resize, image_resize, 3)), training = 1000)
        outputs, backbone_latency, exit_latency = res[0], res[1], res[2]

        per_sample_out_vector_list.append(outputs)
        per_sample_comp_latency_backbone_list.append(backbone_latency)
        per_sample_comp_latency_exitbranch_list.append(exit_latency)

# Mean latency per stage over all test samples, scaled by 1000 (seconds -> ms).
print ("computation latencies backbone(test data)   ", np.mean(per_sample_comp_latency_backbone_list, axis=0)*1000)
print ("computation latencies exitbranch(test data)   ", np.mean(per_sample_comp_latency_exitbranch_list, axis=0)*1000)
