In [17]:
# Imports
import tensorflow as tf
from tensorflow.keras.layers import (
    Input, SeparableConv2D, Conv2D, BatchNormalization, LeakyReLU,
    Flatten, Dense, MaxPooling2D
)
from tensorflow.keras.models import Model
import numpy as np

In [18]:
# Record which TensorFlow build executed this notebook (results can differ across versions).
print(tf.version.VERSION)

2.10.0


In [19]:
# Dataset container -- placeholder only, nothing is loaded yet.
# TODO: load the training / validation / test data here.
data = list()

In [24]:
# Set up CNN using tensorflow
def _bn_leaky_relu(x):
    """Apply BatchNormalization followed by LeakyReLU (Keras default settings) to ``x``."""
    x = BatchNormalization()(x)
    return LeakyReLU()(x)


def build_full_cnn_with_early_exit(input_shape=(32, 32, 3), num_classes=3):
    """Build the "EEFullCNN" Keras model.

    Architecture: four stages, each consisting of
      * SeparableConv2D, 3x3 kernel, stride 2, "same" padding -> BN -> LeakyReLU
        (spatial downsampling by 2: ceil(H/2) x ceil(W/2)), then
      * Conv2D, 1x1 kernel, stride 1, "same" padding -> BN -> LeakyReLU
        (pointwise channel mixing),
    followed by Flatten and a softmax Dense classifier. With the default
    32x32 input the tensor reaching Flatten is 2x2x16.

    Despite the name, the early-exit branch is currently disabled: the model
    has a single output, "FinalOutput".

    Args:
        input_shape: Shape of one input sample, excluding the batch dimension.
            Defaults to (32, 32, 3).
        num_classes: Number of units in the softmax output layer. Defaults to 3.

    Returns:
        An uncompiled ``tf.keras.Model`` named "EEFullCNN".
    """
    input_layer = Input(shape=input_shape, name="Input")

    # One row per stage: (separable filters, depth_multiplier, pointwise filters).
    # 1x1 Conv2D parameter count is (1*1*C_in + 1) * filters -- e.g. the first
    # pointwise layer has (1*1*8 + 1) * 4 = 36 params (the kernel is 1x1, not 3x3).
    stages = (
        (8, 1, 4),
        (16, 2, 8),
        (20, 4, 12),
        (32, 8, 16),
    )

    x = input_layer
    for sep_filters, depth_multiplier, pointwise_filters in stages:
        x = SeparableConv2D(
            sep_filters,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding="same",
            depth_multiplier=depth_multiplier,
        )(x)
        x = _bn_leaky_relu(x)

        x = Conv2D(pointwise_filters, kernel_size=(1, 1), strides=(1, 1), padding="same")(x)
        x = _bn_leaky_relu(x)

    # Early-exit branch (disabled). The intended design taps the activations after
    # the first SeparableConv2D -> BN -> LeakyReLU block:
    #   MaxPooling2D((2, 2)) -> Flatten -> Dense(1, "sigmoid", name="EarlyExitOutput")
    # and returns outputs=[final_output, ee_output].
    # NOTE(review): a 1-unit sigmoid early-exit head does not match the
    # multi-class softmax main head -- confirm intent before re-enabling.

    x = Flatten()(x)
    final_output = Dense(num_classes, activation="softmax", name="FinalOutput")(x)

    return Model(inputs=input_layer, outputs=[final_output], name="EEFullCNN")

# Build the model with the function's default configuration and print its layer table
# (layer names, output shapes, parameter counts).
ee_full_cnn = build_full_cnn_with_early_exit()
ee_full_cnn.summary()

Model: "EEFullCNN"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 Input (InputLayer)          [(None, 32, 32, 3)]       0         
                                                                 
 separable_conv2d_28 (Separa  (None, 16, 16, 8)        59        
 bleConv2D)                                                      
                                                                 
 batch_normalization_56 (Bat  (None, 16, 16, 8)        32        
 chNormalization)                                                
                                                                 
 leaky_re_lu_56 (LeakyReLU)  (None, 16, 16, 8)         0         
                                                                 
 conv2d_28 (Conv2D)          (None, 16, 16, 4)         36        
                                                                 
 batch_normalization_57 (Bat  (None, 16, 16, 4)        16

In [11]:
# Training + Validation

In [12]:
# Testing