
# Models

This notebook serves as a central place to obtain some of the models used in the rest of the repository.


## Early Exit for Image Classification

Among those models is our implementation of BranchyNet, which we refer to as Early Exit. The technique lets a model terminate inference at an intermediate layer, which can save computation time and resources when a high-confidence prediction can be made early on.

The notebook demonstrates how to:

- Implement Early Exit mechanisms within image classification models.
- Convert TensorFlow models to TFLite format for deployment on CPU and TPU devices.
- Convert models to ONNX format for use with Jetson devices.
- Evaluate the performance of Early Exit models in terms of accuracy and computational efficiency.
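
As a rough illustration of the mechanism (not the notebook's actual implementation), the sketch below walks a list of exits from the shallowest to the deepest and stops at the first one whose top softmax probability exceeds a threshold. The names `softmax` and `early_exit_predict` and the toy logits are hypothetical. The exits are passed as callables so that deeper ones are evaluated only when needed.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def early_exit_predict(exits, threshold=0.9):
    """Return (predicted_class, exit_index) for the first confident exit.

    `exits` is a list of zero-argument callables ordered from the shallowest
    exit to the final one; each returns the logits of that exit. The final
    exit is always accepted, whatever its confidence.
    """
    for index, compute_logits in enumerate(exits):
        probs = softmax(compute_logits())
        confidence = max(probs)
        if confidence > threshold or index == len(exits) - 1:
            return probs.index(confidence), index


# Toy logits (hypothetical): the first exit is confident for the "easy" sample only.
easy = [lambda: [5.0, 0.0, 0.0], lambda: [0.0, 0.0, 9.0]]
hard = [lambda: [1.0, 0.5, 0.0], lambda: [0.0, 0.0, 9.0]]

print(early_exit_predict(easy))  # stops at the first exit
print(early_exit_predict(hard))  # falls through to the last exit
```

The cost saving comes from the second callable never being invoked for the easy sample.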

Because the ImageNet models are large, this notebook is also where you obtain the base models for the Jetson Orin Nano.



## Required packages



In [None]:
!pip install torch torchvision

In [None]:
!pip install -U tf2onnx onnx2pytorch  onnxruntime

# Early Exit models for CIFAR-10


---



In [None]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.utils import plot_model
from tensorflow.keras.models import Model
import time
import keras

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
assert x_train.shape == (50000, 32, 32, 3)
assert x_test.shape == (10000, 32, 32, 3)
assert y_train.shape == (50000, 1)
assert y_test.shape == (10000, 1)

### EE_resnet8

Build the models in TensorFlow, then convert them to TFLite for CPU and TPU devices and to ONNX for Jetson devices.

In [None]:
## TFLite

model = tf.keras.models.load_model(r"EE_threeLyaersResnet_500Epochs.h5")
model.trainable = False

# Define common model
common = Model(inputs=model.input, outputs=model.layers[18].output)

# Define branch models
branch1 = Model(inputs=model.layers[18].output, outputs=model.layers[-2].output)
branch2 = Model(inputs=model.layers[18].output, outputs=model.layers[-1].output)


# Define custom layer to choose between branches
class ChooseBranchLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(ChooseBranchLayer, self).__init__()
        self.branch1 = branch1
        self.branch2 = branch2

    def call(self, inputs):
        common_output = inputs
        output1 = self.branch1(common_output)
        condition = tf.reduce_max(output1) > 0.90
        return tf.cond(condition, lambda: output1, lambda: self.branch2(common_output))


# Create input layer
inputs = tf.keras.layers.Input(shape=(32, 32, 3))

# Get common output
common_output = common(inputs)

# Use the custom layer to choose output based on condition
final_output = ChooseBranchLayer()(common_output)


# Define the new model
model_EE= tf.keras.Model(inputs=inputs, outputs=final_output)


spec = (tf.TensorSpec((None, 32, 32, 3), tf.float32, name="input"),)
#model_proto, _ = tf2onnx.convert.from_keras(model_EE, input_signature=spec, opset=13, output_path='EE_resnet8.onnx')



converter = tf.lite.TFLiteConverter.from_keras_model(model_EE)
tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


with open(r"EE_resnet8.tflite", 'wb') as f:
    f.write(tflite_model)

For the conversion to ONNX we go through PyTorch. Converting from TFLite with tf2onnx is possible, but in our experience that pipeline produces a model that is not graph-capturable and performs very poorly in TensorRT. A model built in PyTorch and then exported to ONNX works correctly.

There are two implementations. The first returns only the class prediction and is the model to use in production. The second also returns a flag that records which branch was taken and is intended for testing. If you start from TFLite, TensorRT complains when the conditional block returns two outputs, so you need to append the flag to the output tensor, which gives a vector of N+1 elements whose last element is the branch indicator. As stated above, that model produces a TensorRT engine that is not graph-capturable.
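
A minimal sketch of the N+1 layout described above, using plain Python lists. The helper names `pack_with_flag` and `unpack_with_flag` are hypothetical and only illustrate the convention; they are not part of the exported models.

```python
def pack_with_flag(scores, branch):
    """Append the branch index to the class scores as one extra element."""
    return list(scores) + [float(branch)]


def unpack_with_flag(vector):
    """Split an N+1 vector back into (class scores, branch index)."""
    return vector[:-1], int(vector[-1])


packed = pack_with_flag([0.05, 0.92, 0.03], branch=0)
scores, branch = unpack_with_flag(packed)

# Take the argmax over the first N elements only; the flag is not a class score.
predicted = scores.index(max(scores))
print(len(packed), predicted, branch)
```

Splitting before the argmax matters because a flag value such as 1 or 2 would otherwise win over any probability.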

When using PyTorch, use the model with the branch flag for testing and the model without it for production, because we have observed a relatively high performance cost with the flag model.

The thresholds here are arbitrary. Refer to the paper to weigh the trade-off between latency and accuracy and adjust them to your needs. For testing, you can also set the threshold to 0 so that the early exit is always taken, or to 1 so that it is never taken.
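
One way to explore that trade-off is to record, for each test sample, the confidence of the early exit and whether each exit would have been correct, and then sweep the threshold offline. The sketch below assumes such records exist. The function `sweep_thresholds` and the sample values are hypothetical and are not measurements from this notebook.

```python
def sweep_thresholds(records, thresholds):
    """Return (threshold, early_exit_rate, accuracy) for each threshold.

    Each record is (early_confidence, early_is_correct, final_is_correct)
    for one test sample.
    """
    rows = []
    for threshold in thresholds:
        early_exits = 0
        correct = 0
        for confidence, early_ok, final_ok in records:
            if confidence > threshold:
                early_exits += 1
                correct += int(early_ok)
            else:
                correct += int(final_ok)
        rows.append((threshold, early_exits / len(records), correct / len(records)))
    return rows


# Hypothetical per-sample records, not measurements from this notebook.
records = [
    (0.99, True, True),
    (0.95, True, True),
    (0.85, False, True),
    (0.60, False, True),
    (0.55, True, False),
]

for threshold, exit_rate, accuracy in sweep_thresholds(records, [0.0, 0.8, 0.9, 1.0]):
    print(f"threshold={threshold:.1f}  early exits={exit_rate:.0%}  accuracy={accuracy:.0%}")
```

With a threshold of 0 every sample leaves through the early exit, and with a threshold of 1 none does, which matches the two testing extremes mentioned above.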

In [None]:
## ONNX model without branch indicator
import tf2onnx
import torch.onnx
import torch
import torch.nn as nn
import numpy as np
import onnx
from onnx2pytorch import ConvertModel

common_model, _ = tf2onnx.convert.from_keras(common, opset=17)
branch1_model, _ = tf2onnx.convert.from_keras(branch1, opset=17)
branch2_model, _ = tf2onnx.convert.from_keras(branch2, opset=17)



class ConditionalModel(nn.Module):
    def __init__(self, common_model, branch1_model, branch2_model):
        super().__init__()

        # Integrate ONNX models as submodules (adjust naming if needed)
        self.common = ConvertModel(common_model)
        self.branch1 = ConvertModel(branch1_model)
        self.branch2 = ConvertModel(branch2_model)

    def forward(self, input_tensor):
        output = self.common(input_tensor)
        branch1_output = self.branch1(output)

        # Find the maximum value in branch1_output
        max_value = torch.max(branch1_output)

        if max_value.item() > 0.9:
            return branch1_output
        else:
            return self.branch2(output)

model = ConditionalModel(common_model, branch1_model, branch2_model)

dummy_input = torch.randn(1, 32, 32, 3)  # Batch of 1, 32x32 image, 3 channels (NHWC, as in the Keras model)
torch.onnx.export(model, dummy_input, "../edge/Jetson_Nano/inference/models/EE_resnet8.onnx", input_names=["input"], output_names=["output"])

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset

x_t = torch.tensor(x_test, dtype=torch.float32)
y_t = torch.tensor(y_test, dtype=torch.long)

dataset = TensorDataset(x_t, y_t)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

model.eval()

correct, total = 0, 0

with torch.no_grad():
    for inputs, labels in data_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 84.34%


In [None]:
# Model with flag to track which branch was taken


model = tf.keras.models.load_model(r"EE_threeLyaersResnet_500Epochs.h5")
model.trainable = False

common = Model(inputs=model.input, outputs=model.layers[18].output)
branch1 = Model(inputs=model.layers[18].output, outputs=model.layers[-2].output)
branch2 = Model(inputs=model.layers[18].output, outputs=model.layers[-1].output)

common_model, _ = tf2onnx.convert.from_keras(common, opset=17)
branch1_model, _ = tf2onnx.convert.from_keras(branch1, opset=17)
branch2_model, _ = tf2onnx.convert.from_keras(branch2, opset=17)


class ConditionalModel(nn.Module):
    def __init__(self, common_model, branch1_model, branch2_model):
        super().__init__()

        # Integrate ONNX models as submodules (adjust naming if needed)
        self.common = ConvertModel(common_model)
        self.branch1 = ConvertModel(branch1_model)
        self.branch2 = ConvertModel(branch2_model)

    def forward(self, input_tensor):
        output = self.common(input_tensor)
        branch1_output = self.branch1(output)

        # Find the maximum value in branch1_output
        max_value = torch.max(branch1_output)

        if max_value.item() > 0.9:
            return branch1_output, torch.tensor(0)
        else:
            return self.branch2(output), torch.tensor(1)


model = ConditionalModel(common_model, branch1_model, branch2_model)

dummy_input = torch.randn(1, 32, 32, 3)  # Batch of 1, 32x32 image, 3 channels (NHWC, as in the Keras model)
torch.onnx.export(model, dummy_input, "../edge/Jetson_Nano/inference/models/EE_resnet8_flag.onnx", input_names=["input"], output_names=["output"])

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset

x_t = torch.tensor(x_test, dtype=torch.float32)
y_t = torch.tensor(y_test, dtype=torch.long)

dataset = TensorDataset(x_t, y_t)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

model.eval()

correct, total = 0, 0
branch = [0, 0]

with torch.no_grad():
    for inputs, labels in data_loader:
        outputs, eexit = model(inputs)
        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        branch[eexit.item()] += 1

accuracy = correct / total
print(f'Accuracy: {accuracy * 100:.2f}%')
print(f"Branch 0 taken: {branch[0]} times")
print(f"Branch 1 taken: {branch[1]} times")

Accuracy: 84.34%
Branch 0 taken: 5350 times
Branch 1 taken: 4650 times
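
The branch counts can be turned into a rough estimate of the average inference cost once the cost of each path is known. The sketch below uses the counts printed above, but the per-branch costs are hypothetical placeholders that would have to be measured on the target device. `expected_relative_cost` is an illustrative helper, not part of the repository.

```python
def expected_relative_cost(branch_counts, branch_costs):
    """Average per-sample cost, given how often each exit was taken.

    `branch_costs[i]` is the cost of a sample that leaves through exit i,
    relative to running the network without early exits (cost 1.0).
    """
    total = sum(branch_counts)
    return sum(n * c for n, c in zip(branch_counts, branch_costs)) / total


branch_counts = [5350, 4650]  # counts printed by the cell above
branch_costs = [0.6, 1.1]     # hypothetical values: measure these on your device

early_fraction = branch_counts[0] / sum(branch_counts)
print(f"Early-exit fraction: {early_fraction:.1%}")
print(f"Expected relative cost: {expected_relative_cost(branch_counts, branch_costs):.3f}")
```

A cost above 1.0 for the last exit reflects that samples which do not exit early still pay for evaluating the early branch.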


### Resnet56

The pipeline is the same as for Resnet8.

In [None]:
model = tf.keras.models.load_model(r"EE_resnet56.h5")

# Preprocessing for Resnet56
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]

def prep(x, y):
    x = x / 255.
    x = (x - mean) / std
    return x, y


x_train_p, y_train_p = prep(x_train, y_train)
x_test_p, y_test_p = prep(x_test, y_test)

target_layer_name = 're_lu_21'

# Find the number of the layer
layer_number = None
for i, layer in enumerate(model.layers):
    if layer.name == target_layer_name:
        layer_number = i
        break

print(layer_number)

66
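
The lookup above prints `None` if the layer name is not found. A small variant that raises instead is sketched below, so that a typo or a renamed layer is reported at the lookup rather than later. `find_layer_index` is a hypothetical helper, and the stand-in objects only mimic the `.name` attribute of Keras layers.

```python
from types import SimpleNamespace


def find_layer_index(layers, name):
    """Return the position of the layer called `name`, or raise if it is absent."""
    for index, layer in enumerate(layers):
        if layer.name == name:
            return index
    raise ValueError(f"No layer named {name!r}")


# Stand-in objects with a `.name` attribute, used here instead of a real Keras model.
toy_layers = [SimpleNamespace(name=n) for n in ("input_1", "conv2d", "re_lu_21")]
print(find_layer_index(toy_layers, "re_lu_21"))
```

With a loaded model, the same function would be called as `find_layer_index(model.layers, target_layer_name)`.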


In [None]:
## TFLite
model = tf.keras.models.load_model(r"EE_resnet56.h5")
common = Model(inputs=model.input, outputs=model.layers[3].output)
branch1 = Model(inputs=model.layers[3].output, outputs=model.layers[-3].output)
bb1 = Model(inputs=model.layers[3].output, outputs=model.layers[66].output)
branch2 = Model(inputs=model.layers[66].output, outputs=model.layers[-2].output)
bb2 = Model(inputs=model.layers[66].output, outputs=model.layers[-1].output)


class ChooseBranchLayer1(tf.keras.layers.Layer):
    def __init__(self):
        super(ChooseBranchLayer1, self).__init__()
        self.branch1 = branch1
        self.bb1 = bb1
        self.branch2 = branch2
        self.bb2 = bb2
    def call(self, inputs):
        common_output = inputs
        output1 = self.branch1(common_output)
        output1 = tf.nn.softmax(output1)
        condition = tf.reduce_max(output1) > 0.95
        output = tf.cond(condition, lambda:[output1,0] , lambda: self.continuation(common_output))
        return output
        #return output1
    def continuation(self,inputs):
        common_output2 = self.bb1(inputs)
        output2 = tf.nn.softmax(self.branch2(common_output2))
        condition = tf.reduce_max(output2) > 0.95
        return tf.cond(condition, lambda:[output2,1] , lambda: [tf.nn.softmax(self.bb2(common_output2)),2] )

# Create input layer
inputs = tf.keras.layers.Input(shape=(32, 32, 3))
# Get common output
common_output = common(inputs)
# Use the custom layer to choose output based on condition
final_output, eexit = ChooseBranchLayer1()(common_output)
# Define the new model
EE_resnet56 = tf.keras.Model(inputs=inputs, outputs=[final_output, eexit])


converter = tf.lite.TFLiteConverter.from_keras_model(EE_resnet56)
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

with open(r"EE_resnet56.tflite", 'wb') as f:
    f.write(tflite_model)


In [None]:
correct = 0
total = len(y_test_p)

# Preprocessing for Resnet56
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]

def prep(x, y):
    x = x / 255.
    x = (x - mean) / std
    return x, y


x_test_p, y_test_p = prep(x_test, y_test)

x_test_p = x_test_p.astype('float32')

interpreter.set_tensor(input_details[0]['index'], np.expand_dims(x_test_p[0], axis=0))
interpreter.invoke()
a = interpreter.get_tensor(output_details[0]['index'])
b = interpreter.get_tensor(output_details[1]['index'])

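# get_tensor always returns an ndarray, so tell the outputs apart by size:
# the class scores have several elements, the branch flag has one.
if a.size > 1:
    index = 0
else:
    index = 1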


for image, ground_truth in zip(x_test_p, y_test_p):
    # Set the tensor to point to the input data to be inferred
    interpreter.set_tensor(input_details[0]['index'], np.expand_dims(image, axis=0))

    # Run inference
    interpreter.invoke()

    # Get the output tensor
    output_data = interpreter.get_tensor(output_details[index]['index'])

    # Convert output data to predicted labels
    predicted = np.argmax(output_data[0])

    #print(f'\noutput_data: {output_data}\npredicted: {predicted}\nground_truth: {ground_truth}\n\n')

    if predicted == ground_truth[0]:
      correct += 1

# Calculate accuracy
accuracy = correct / total
print(f'Accuracy: {accuracy * 100:.2f}%')

In [None]:
# ONNX without branch indicator

model = tf.keras.models.load_model(r"EE_resnet56.h5")
common = Model(inputs=model.input, outputs=model.layers[3].output)
branch1 = Model(inputs=model.layers[3].output, outputs=model.layers[-3].output)
bb1 = Model(inputs=model.layers[3].output, outputs=model.layers[66].output)
branch2 = Model(inputs=model.layers[66].output, outputs=model.layers[-2].output)
bb2 = Model(inputs=model.layers[66].output, outputs=model.layers[-1].output)

common_model, _ = tf2onnx.convert.from_keras(common, opset=17)
branch1_model, _ = tf2onnx.convert.from_keras(branch1, opset=17)
bb1_model, _ = tf2onnx.convert.from_keras(bb1, opset=17)
branch2_model, _ = tf2onnx.convert.from_keras(branch2, opset=17)
bb2_model, _ = tf2onnx.convert.from_keras(bb2, opset=17)



class ConditionalModel(nn.Module):
    def __init__(self, common_model, branch1_model, bb1_model, branch2_model, bb2_model):
        super().__init__()

        # Integrate ONNX models as submodules (adjust naming if needed)
        self.common = ConvertModel(common_model)
        self.branch1 = ConvertModel(branch1_model)
        self.bb1 = ConvertModel(bb1_model)
        self.branch2 = ConvertModel(branch2_model)
        self.bb2 = ConvertModel(bb2_model)

    def forward(self, input_tensor):
        common_output = self.common(input_tensor)
        output1 = self.branch1(common_output)
        output1 = torch.nn.functional.softmax(output1, dim=1)

        max_value = torch.max(output1)

        if max_value.item() > 0.95:
            return output1  # First branch exit
        else:
            common_output2 = self.bb1(common_output)
            output2 = torch.nn.functional.softmax(self.branch2(common_output2), dim=1)

            max_value = torch.max(output2)

            if max_value.item() > 0.95:
                return output2  # Second branch exit
            else:
                return torch.nn.functional.softmax(self.bb2(common_output2), dim=1)  # Last exit through the main branch



EE_model = ConditionalModel(common_model, branch1_model, bb1_model, branch2_model, bb2_model)

dummy_input = torch.randn(1, 32, 32, 3)  # Batch of 1, 32x32 image, 3 channels (NHWC, as in the Keras model)
torch.onnx.export(EE_model, dummy_input, "../edge/Jetson_Nano/inference/models/EE_resnet56.onnx", input_names=["input"], output_names=["output"])

Run the following cells if the TensorRT ONNX parser rejects the previous model with an error message like this one:



```bash
[E] [TRT] ModelImporter.cpp:768: While parsing node number 183 [Squeeze -> "/bb2/Squeeze_model_14/global_average_pooling2d/Mean_Squeeze__1373:0/Squeeze_2_output_0"]:
[E] [TRT] ModelImporter.cpp:769: --- Begin node ---
[E] [TRT] ModelImporter.cpp:770: input: "/bb2/GlobalAveragePool_model_14/global_average_pooling2d/Mean:0/ReduceMean_output_0"
input: "/bb2/Squeeze_model_14/global_average_pooling2d/Mean_Squeeze__1373:0/Unsqueeze_output_0"
output: "/bb2/Squeeze_model_14/global_average_pooling2d/Mean_Squeeze__1373:0/Squeeze_2_output_0"
name: "/bb2/Squeeze_model_14/global_average_pooling2d/Mean_Squeeze__1373:0/Squeeze_2"
op_type: "Squeeze"

[E] [TRT] ModelImporter.cpp:771: --- End node ---
[E] [TRT] ModelImporter.cpp:773: ERROR: ModelImporter.cpp:178 In function parseGraph:
[6] Invalid Node - /bb2/Squeeze_model_14/global_average_pooling2d/Mean_Squeeze__1373:0/Squeeze_2
Squeeze axes input must be an initializer! Try applying constant folding on the model using Polygraphy: https://github.com/NVIDIA/TensorRT/tree/master/tools/Polygraphy/examples/cli/surgeon/02_folding_constants

```


In [None]:
%env POLYGRAPHY_AUTOINSTALL_DEPS=1
!pip install polygraphy
!python3 -m pip install onnx_graphsurgeon --index-url https://pypi.ngc.nvidia.com
!python3 -m pip install colored

In [None]:
!polygraphy surgeon sanitize ../edge/Jetson_Nano/inference/models/EE_resnet56.onnx --fold-constants -o ../edge/Jetson_Nano/inference/models/EE_resnet56.onnx

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset

x_t = torch.tensor(x_test_p, dtype=torch.float32)
y_t = torch.tensor(y_test_p, dtype=torch.long)

dataset = TensorDataset(x_t, y_t)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

EE_model.eval()

correct, total = 0, 0

with torch.no_grad():
    for inputs, labels in data_loader:
        outputs = EE_model(inputs)
        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 90.77%


In [None]:
# ONNX with branch indicator

model = tf.keras.models.load_model(r"EE_resnet56.h5")
common = Model(inputs=model.input, outputs=model.layers[3].output)
branch1 = Model(inputs=model.layers[3].output, outputs=model.layers[-3].output)
bb1 = Model(inputs=model.layers[3].output, outputs=model.layers[66].output)
branch2 = Model(inputs=model.layers[66].output, outputs=model.layers[-2].output)
bb2 = Model(inputs=model.layers[66].output, outputs=model.layers[-1].output)

common_model, _ = tf2onnx.convert.from_keras(common, opset=17)
branch1_model, _ = tf2onnx.convert.from_keras(branch1, opset=17)
bb1_model, _ = tf2onnx.convert.from_keras(bb1, opset=17)
branch2_model, _ = tf2onnx.convert.from_keras(branch2, opset=17)
bb2_model, _ = tf2onnx.convert.from_keras(bb2, opset=17)

class ConditionalModel(nn.Module):
    def __init__(self, common_model, branch1_model, bb1_model, branch2_model, bb2_model):
        super().__init__()

        # Integrate ONNX models as submodules (adjust naming if needed)
        self.common = ConvertModel(common_model)
        self.branch1 = ConvertModel(branch1_model)
        self.bb1 = ConvertModel(bb1_model)
        self.branch2 = ConvertModel(branch2_model)
        self.bb2 = ConvertModel(bb2_model)

    def forward(self, input_tensor):
        common_output = self.common(input_tensor)
        output1 = self.branch1(common_output)
        output1 = torch.nn.functional.softmax(output1, dim=1)

        max_value = torch.max(output1)

        if max_value.item() > 0.95:
            return output1, torch.tensor(0)  # First branch exit
        else:
            common_output2 = self.bb1(common_output)
            output2 = torch.nn.functional.softmax(self.branch2(common_output2), dim=1)

            max_value = torch.max(output2)

            if max_value.item() > 0.95:
                return output2, torch.tensor(1)  # Second branch exit
            else:
                output3 = torch.nn.functional.softmax(self.bb2(common_output2), dim=1)
                return output3, torch.tensor(2)  # Last exit through the main branch



EE_model = ConditionalModel(common_model, branch1_model, bb1_model, branch2_model, bb2_model)

dummy_input = torch.randn(1, 32, 32, 3)  # Batch of 1, 32x32 image, 3 channels (NHWC, as in the Keras model)
torch.onnx.export(EE_model, dummy_input, "../edge/Jetson_Nano/inference/models/EE_resnet56_flag.onnx", input_names=["input"], output_names=["output"])

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset

x_t = torch.tensor(x_test_p, dtype=torch.float32)
y_t = torch.tensor(y_test_p, dtype=torch.long)

dataset = TensorDataset(x_t, y_t)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

EE_model.eval()

correct, total = 0, 0
branch = [0, 0, 0]

with torch.no_grad():
    for inputs, labels in data_loader:
        outputs, eexit = EE_model(inputs)
        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        branch[eexit.item()] += 1

accuracy = correct / total
print(f'Accuracy: {accuracy * 100:.2f}%')
print(f"Branch 0 taken: {branch[0]} times")
print(f"Branch 1 taken: {branch[1]} times")
print(f"Branch 2 taken: {branch[2]} times")

Accuracy: 90.77%
Branch 0 taken: 3381 times
Branch 1 taken: 3486 times
Branch 2 taken: 3133 times


### AlexNet

In [None]:

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()


# Preprocessing for AlexNet
mean = [125.307, 122.95, 113.865]
std = [62.9932, 62.0887, 66.7048]

def prep(x, y):
    x = (x - mean) / std
    return x, y


x_train_p, y_train_p = prep(x_train, y_train)
x_test_p, y_test_p = prep(x_test, y_test)
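
The AlexNet constants are expressed on the raw 0-255 pixel scale, while the Resnet56 cell divides by 255 first and uses constants on the [0, 1] scale. The sketch below checks with plain arithmetic that the two forms are equivalent once the constants are rescaled, and converts the AlexNet red-channel constants to the [0, 1] scale for comparison. The function names are hypothetical.

```python
import math


def normalize_unit_scale(pixel, mean, std):
    """Resnet56-style: scale to [0, 1] first, then standardize."""
    return (pixel / 255.0 - mean) / std


def normalize_byte_scale(pixel, mean, std):
    """AlexNet-style: standardize the raw 0-255 value directly."""
    return (pixel - mean) / std


# Red-channel constants taken from the two preprocessing cells.
resnet_mean, resnet_std = 0.4914, 0.2023
alexnet_mean, alexnet_std = 125.307, 62.9932

pixel = 200
a = normalize_unit_scale(pixel, resnet_mean, resnet_std)
b = normalize_byte_scale(pixel, resnet_mean * 255.0, resnet_std * 255.0)
print(math.isclose(a, b))              # the two forms agree once constants are rescaled

print(round(alexnet_mean / 255.0, 4))  # AlexNet mean expressed on the [0, 1] scale
print(round(alexnet_std / 255.0, 4))   # AlexNet std expressed on the [0, 1] scale
```

On the [0, 1] scale the red-channel means of the two cells coincide, but the standard deviations do not (about 0.247 for AlexNet against 0.2023 for Resnet56), so each model has to be fed with its own `prep` function.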

In [None]:
## TFLite

model = tf.keras.models.load_model(r"EE2_Alexnet_30Epochs.h5")

# Define common model
common = Model(inputs=model.input, outputs=model.layers[6].output)

# Define branch models
branch1 = Model(inputs=model.layers[6].output, outputs=model.layers[-2].output)
branch2 = Model(inputs=model.layers[6].output, outputs=model.layers[-1].output)

# Define custom layer to choose between branches
class ChooseBranchLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(ChooseBranchLayer, self).__init__()
        self.branch1 = branch1
        self.branch2 = branch2

    def call(self, inputs):
        common_output = inputs
        output1 = self.branch1(common_output)
        condition = tf.reduce_max(output1) > 0.80
        return tf.cond(condition, lambda: [output1,0], lambda: [self.branch2(common_output),1])

# Create input layer
inputs = tf.keras.layers.Input(shape=(32, 32, 3))

# Get common output
common_output = common(inputs)

# Use the custom layer to choose output based on condition
final_output, eexit = ChooseBranchLayer()(common_output)

# Define the new model
model_EE = tf.keras.Model(inputs=inputs, outputs=[final_output, eexit])


converter = tf.lite.TFLiteConverter.from_keras_model(model_EE)
tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


with open(r"EE_alexnet.tflite", 'wb') as f:
    f.write(tflite_model)

In [None]:
mean = [125.307, 122.95, 113.865]
std = [62.9932, 62.0887, 66.7048]

def prep(x, y):
    x = (x - mean) / std
    return x, y


x_train_p, y_train_p = prep(x_train, y_train)
x_test_p, y_test_p = prep(x_test, y_test)

x_test_p = x_test_p.astype('float32')

interpreter.set_tensor(input_details[0]['index'], np.expand_dims(x_test_p[0], axis=0))
interpreter.invoke()
a = interpreter.get_tensor(output_details[0]['index'])
b = interpreter.get_tensor(output_details[1]['index'])

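# get_tensor always returns an ndarray, so tell the outputs apart by size:
# the class scores have several elements, the branch flag has one.
if a.size > 1:
    index = 0
else:
    index = 1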

correct = 0
total = len(y_test_p)


for image, ground_truth in zip(x_test_p, y_test_p):
    # Set the tensor to point to the input data to be inferred
    interpreter.set_tensor(input_details[0]['index'], np.expand_dims(image, axis=0))

    # Run inference
    interpreter.invoke()

    # Get the output tensor
    output_data = interpreter.get_tensor(output_details[index]['index'])

    # Convert output data to predicted labels
    predicted = np.argmax(output_data[0])

    #print(f'\noutput_data: {output_data}\npredicted: {predicted}\nground_truth: {ground_truth}\n\n')

    if predicted == ground_truth[0]:
      correct += 1

# Calculate accuracy
accuracy = correct / total
print(f'Accuracy: {accuracy * 100:.2f}%')


Accuracy: 76.41%


In [None]:
## ONNX model without branch indicator

model = tf.keras.models.load_model(r"EE2_Alexnet_30Epochs.h5")
common = Model(inputs=model.input, outputs=model.layers[6].output)
branch1 = Model(inputs=model.layers[6].output, outputs=model.layers[-2].output)
branch2 = Model(inputs=model.layers[6].output, outputs=model.layers[-1].output)

common_model, _ = tf2onnx.convert.from_keras(common, opset=17)
branch1_model, _ = tf2onnx.convert.from_keras(branch1, opset=17)
branch2_model, _ = tf2onnx.convert.from_keras(branch2, opset=17)



class ConditionalModel(nn.Module):
    def __init__(self, common_model, branch1_model, branch2_model):
        super().__init__()

        # Integrate ONNX models as submodules (adjust naming if needed)
        self.common = ConvertModel(common_model)
        self.branch1 = ConvertModel(branch1_model)
        self.branch2 = ConvertModel(branch2_model)

    def forward(self, input_tensor):
        output = self.common(input_tensor)
        branch1_output = self.branch1(output)

        # Find the maximum value in branch1_output
        max_value = torch.max(branch1_output)

        if max_value.item() > .8:
            return branch1_output
        else:
            return self.branch2(output)

model_EE = ConditionalModel(common_model, branch1_model, branch2_model)

dummy_input = torch.randn(1, 32, 32, 3)  # Batch of 1, 32x32 image, 3 channels (NHWC, as in the Keras model)
torch.onnx.export(model_EE, dummy_input, "../edge/Jetson_Nano/inference/models/EE_alexnet.onnx", input_names=["input"], output_names=["output"])

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset

# Preprocessing for AlexNet
mean = [125.307, 122.95, 113.865]
std = [62.9932, 62.0887, 66.7048]

def prep(x, y):
    x = (x - mean) / std
    return x, y


x_train_p, y_train_p = prep(x_train, y_train)
x_test_p, y_test_p = prep(x_test, y_test)

x_t = torch.tensor(x_test_p, dtype=torch.float32)
y_t = torch.tensor(y_test_p, dtype=torch.long)

dataset = TensorDataset(x_t, y_t)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

model_EE.eval()

correct, total = 0, 0

with torch.no_grad():
    for inputs, labels in data_loader:
        outputs = model_EE(inputs)
        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = correct / total
print(f'Accuracy: {accuracy * 100:.2f}%')

Accuracy: 76.41%


In [None]:
# Model with flag to track which branch was taken

model = tf.keras.models.load_model(r"EE2_Alexnet_30Epochs.h5")
common = Model(inputs=model.input, outputs=model.layers[6].output)
branch1 = Model(inputs=model.layers[6].output, outputs=model.layers[-2].output)
branch2 = Model(inputs=model.layers[6].output, outputs=model.layers[-1].output)

common_model, _ = tf2onnx.convert.from_keras(common, opset=17)
branch1_model, _ = tf2onnx.convert.from_keras(branch1, opset=17)
branch2_model, _ = tf2onnx.convert.from_keras(branch2, opset=17)


class ConditionalModel(nn.Module):
    def __init__(self, common_model, branch1_model, branch2_model):
        super().__init__()

        # Integrate ONNX models as submodules (adjust naming if needed)
        self.common = ConvertModel(common_model)
        self.branch1 = ConvertModel(branch1_model)
        self.branch2 = ConvertModel(branch2_model)

    def forward(self, input_tensor):
        output = self.common(input_tensor)
        branch1_output = self.branch1(output)

        # Find the maximum value in branch1_output
        max_value = torch.max(branch1_output)

        if max_value.item() > 0.8:
            return branch1_output, torch.tensor(0)
        else:
            return self.branch2(output), torch.tensor(1)

model_EE = ConditionalModel(common_model, branch1_model, branch2_model)

dummy_input = torch.randn(1, 32, 32, 3)  # Batch of 1, 32x32 image, 3 channels (NHWC, as in the Keras model)
torch.onnx.export(model_EE, dummy_input, "../edge/Jetson_Nano/inference/models/EE_alexnet_flag.onnx", input_names=["input"], output_names=["output"])

In [None]:
import torch
from torch.utils.data import DataLoader, TensorDataset

# Preprocessing for AlexNet
mean = [125.307, 122.95, 113.865]
std = [62.9932, 62.0887, 66.7048]

def prep(x, y):
    x = (x - mean) / std
    return x, y


x_train_p, y_train_p = prep(x_train, y_train)
x_test_p, y_test_p = prep(x_test, y_test)

x_t = torch.tensor(x_test_p, dtype=torch.float32)
y_t = torch.tensor(y_test_p, dtype=torch.long)

dataset = TensorDataset(x_t, y_t)
data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

model_EE.eval()

correct, total = 0, 0
branch = [0, 0]

with torch.no_grad():
    for inputs, labels in data_loader:
        outputs, eexit = model_EE(inputs)
        _, predicted = torch.max(outputs, 1)

        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        branch[eexit.item()] += 1

accuracy = correct / total
print(f'Accuracy: {accuracy * 100:.2f}%')
print(f"Branch 0 taken: {branch[0]} times")
print(f"Branch 1 taken: {branch[1]} times")

Accuracy: 76.41%
Branch 0 taken: 6138 times
Branch 1 taken: 3862 times


In [None]:
model = tf.keras.models.load_model(r"EE2_Alexnet_30Epochs.h5")

# Define common model
common = Model(inputs=model.input, outputs=model.layers[6].output)

# Define branch models
branch1 = Model(inputs=model.layers[6].output, outputs=model.layers[-2].output)
branch2 = Model(inputs=model.layers[6].output, outputs=model.layers[-1].output)

# Define custom layer to choose between branches
class ChooseBranchLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(ChooseBranchLayer, self).__init__()
        self.branch1 = branch1
        self.branch2 = branch2

    def call(self, inputs):
        common_output = inputs
        output1 = self.branch1(common_output)
        condition = tf.reduce_max(output1) > 0.80
        return tf.cond(condition, lambda: [output1,0], lambda: [self.branch2(common_output),1])

# Create input layer
inputs = tf.keras.layers.Input(shape=(32, 32, 3))

# Get common output
common_output = common(inputs)

# Use the custom layer to choose output based on condition
final_output = ChooseBranchLayer()(common_output)

# Define the new model
model_EE = tf.keras.Model(inputs=inputs, outputs=final_output)


converter = tf.lite.TFLiteConverter.from_keras_model(model_EE)
tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


with open(r"EE_alexnet.tflite", 'wb') as f:
    f.write(tflite_model)

In [None]:
import numpy as np

# Prepare the interpreter
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

# Get input and output tensor indices
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

def run_inference(image):
    interpreter.set_tensor(input_details[0]['index'], image)
    interpreter.invoke()
    first = interpreter.get_tensor(output_details[0]['index'])
    second = interpreter.get_tensor(output_details[1]['index'])
    # The order of the two outputs is not guaranteed after conversion: the class
    # scores are the tensor with several elements, the branch flag has one.
    if first.size > 1:
        return first, second
    return second, first

# Evaluate the model on the test dataset
correct_predictions = 0
total_predictions = 0
branches = [0, 0]

for i in range(len(x_test_p)):
    # Prepare input data
    input_data = np.expand_dims(x_test_p[i], axis=0).astype(np.float32)

    # Run inference
    output, eexit = run_inference(input_data)

    # Convert output to class prediction
    predicted_class = np.argmax(output[0])
    # CIFAR-10 labels are integer class ids of shape (N, 1), not one-hot vectors
    true_class = int(y_test_p[i][0])

    branches[int(np.squeeze(eexit))] += 1

    # Update the accuracy metrics
    if predicted_class == true_class:
        correct_predictions += 1
    total_predictions += 1

# Calculate accuracy
accuracy = correct_predictions / total_predictions
print(f'Accuracy: {accuracy * 100:.2f}%')
print(f"Branch 0 taken: {branches[0]} times")
print(f"Branch 1 taken: {branches[1]} times")


Accuracy: 12.84%
Branch 0 taken: 9892 times
Branch 1 taken: 108 times
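
The branch counts above depend on how the exit condition is thresholded. As a minimal sketch (not the notebook's `ChooseBranchLayer`), the decision rule of a two-exit model can be written in NumPy. The function `early_exit_predict`, the threshold and the toy logits are assumptions made for illustration:

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def early_exit_predict(early_logits, final_logits, threshold):
    """Return (predicted_class, branch) for one sample.

    branch is 0 when the early exit is confident enough, 1 otherwise.
    """
    early_probs = softmax(early_logits)
    if early_probs.max() > threshold:
        return int(early_probs.argmax()), 0
    return int(softmax(final_logits).argmax()), 1

# Toy logits (hypothetical values): the first sample is confident at the
# early exit, the second is not.
early = np.array([[8.0, 0.0, 0.0], [0.2, 0.1, 0.0]])
final = np.array([[5.0, 1.0, 0.0], [0.0, 4.0, 0.0]])

branches = [0, 0]
predictions = []
for e, f in zip(early, final):
    cls, branch = early_exit_predict(e, f, threshold=0.9)
    predictions.append(cls)
    branches[branch] += 1

print(predictions, branches)
```

Because a softmax output never exceeds 1.0, a threshold of 1.0 with a strict comparison sends every sample to the final branch; lowering it trades accuracy for more early exits.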


# ImageNet

---



## Get base models

### TensorFlow

The models are available with pretrained weights in `torchvision`. However, they are PyTorch models, and for the Raspberry Pi and Coral Micro we need them in TensorFlow. Surprisingly, although these are currently the two most widely used frameworks, there are few reliable libraries for converting directly from one to the other. The most reliable approach we have found is to recreate the architecture in TensorFlow and copy the weights from the PyTorch model manually.
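
To make the weight copy concrete, the sketch below checks the layout conversion used in the following cells with plain NumPy. PyTorch stores convolution kernels as `(out, in, kH, kW)` and feeds `NCHW` tensors, while Keras expects `(kH, kW, in, out)` and `NHWC`. The helpers `conv2d_nchw` and `conv2d_nhwc` are hypothetical and implement a naive, unpadded, stride-1 convolution only for this check:

```python
import numpy as np

def conv2d_nchw(x, w):
    """Naive convolution in PyTorch layout.

    x: (C_in, H, W), w: (C_out, C_in, kH, kW) -> (C_out, H', W')
    """
    c_out, _, kh, kw = w.shape
    _, h, wd = x.shape
    out = np.zeros((c_out, h - kh + 1, wd - kw + 1))
    for i in range(out.shape[1]):
        for j in range(out.shape[2]):
            patch = x[:, i:i + kh, j:j + kw]          # (C_in, kH, kW)
            out[:, i, j] = np.tensordot(w, patch, axes=3)
    return out

def conv2d_nhwc(x, w):
    """Naive convolution in Keras layout.

    x: (H, W, C_in), w: (kH, kW, C_in, C_out) -> (H', W', C_out)
    """
    kh, kw, _, c_out = w.shape
    h, wd, _ = x.shape
    out = np.zeros((h - kh + 1, wd - kw + 1, c_out))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            patch = x[i:i + kh, j:j + kw, :]          # (kH, kW, C_in)
            out[i, j, :] = np.tensordot(patch, w, axes=3)
    return out

rng = np.random.default_rng(0)
x_torch = rng.standard_normal((3, 8, 8))      # one CHW image
w_torch = rng.standard_normal((4, 3, 3, 3))   # (out, in, kH, kW)

# Same permutations as in the weight-loading cells of this notebook.
w_keras = w_torch.transpose(2, 3, 1, 0)       # -> (kH, kW, in, out)
x_keras = x_torch.transpose(1, 2, 0)          # -> HWC

y_torch = conv2d_nchw(x_torch, w_torch)
y_keras = conv2d_nhwc(x_keras, w_keras)
max_diff = np.abs(y_torch.transpose(1, 2, 0) - y_keras).max()
print(w_keras.shape, max_diff)
```

Dense layers need only a plain transpose, since PyTorch stores `Linear` weights as `(out, in)` and Keras as `(in, out)`.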

In [None]:
import math

import numpy as np
import torch
import torchvision.models as models
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications.resnet50 import ResNet50
from keras.layers import Input, Conv2D, GlobalAveragePooling2D, Dense, BatchNormalization, Activation, MaxPooling2D
from keras.layers import concatenate, Dropout, Flatten
from keras.models import Model
from torchsummary import summary

#### ResNet18

In [None]:
kaiming_normal = keras.initializers.VarianceScaling(scale=2.0, mode='fan_out', distribution='untruncated_normal')

def conv3x3(x, out_planes, stride=1, name=None):
    x = layers.ZeroPadding2D(padding=1, name=f'{name}_pad')(x)
    return layers.Conv2D(filters=out_planes, kernel_size=3, strides=stride, use_bias=False, kernel_initializer=kaiming_normal, name=name)(x)

def basic_block(x, planes, stride=1, downsample=None, name=None):
    identity = x

    out = conv3x3(x, planes, stride=stride, name=f'{name}.conv1')
    out = layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name=f'{name}.bn1')(out)
    out = layers.ReLU(name=f'{name}.relu1')(out)

    out = conv3x3(out, planes, name=f'{name}.conv2')
    out = layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name=f'{name}.bn2')(out)

    if downsample is not None:
        for layer in downsample:
            identity = layer(identity)

    out = layers.Add(name=f'{name}.add')([identity, out])
    out = layers.ReLU(name=f'{name}.relu2')(out)

    return out

def make_layer(x, planes, blocks, stride=1, name=None):
    downsample = None
    inplanes = x.shape[3]
    if stride != 1 or inplanes != planes:
        downsample = [
            layers.Conv2D(filters=planes, kernel_size=1, strides=stride, use_bias=False, kernel_initializer=kaiming_normal, name=f'{name}.0.downsample.0'),
            layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name=f'{name}.0.downsample.1'),
        ]

    x = basic_block(x, planes, stride, downsample, name=f'{name}.0')
    for i in range(1, blocks):
        x = basic_block(x, planes, name=f'{name}.{i}')

    return x

def resnet(x, blocks_per_layer, num_classes=1000):
    x = layers.ZeroPadding2D(padding=3, name='conv1_pad')(x)
    x = layers.Conv2D(filters=64, kernel_size=7, strides=2, use_bias=False, kernel_initializer=kaiming_normal, name='conv1')(x)
    x = layers.BatchNormalization(momentum=0.9, epsilon=1e-5, name='bn1')(x)
    x = layers.ReLU(name='relu1')(x)
    x = layers.ZeroPadding2D(padding=1, name='maxpool_pad')(x)
    x = layers.MaxPool2D(pool_size=3, strides=2, name='maxpool')(x)

    x = make_layer(x, 64, blocks_per_layer[0], name='layer1')
    x = make_layer(x, 128, blocks_per_layer[1], stride=2, name='layer2')
    x = make_layer(x, 256, blocks_per_layer[2], stride=2, name='layer3')
    x = make_layer(x, 512, blocks_per_layer[3], stride=2, name='layer4')

    x = layers.GlobalAveragePooling2D(name='avgpool')(x)
    initializer = keras.initializers.RandomUniform(-1.0 / math.sqrt(512), 1.0 / math.sqrt(512))
    x = layers.Dense(units=num_classes, kernel_initializer=initializer, bias_initializer=initializer, name='fc')(x)

    return x

def resnet18(x, **kwargs):
    return resnet(x, [2, 2, 2, 2], **kwargs)


torch_model = models.resnet18(pretrained=True)
torch_model.eval()


inputs = keras.Input(shape=(224, 224, 3))
outputs = resnet18(inputs)
model = keras.Model(inputs, outputs)

In [None]:
# Load pytorch weights
state_dict = torch_model.state_dict()
for layer in model.layers:
    if isinstance(layer, layers.Conv2D):
        layer.set_weights([state_dict[f'{layer.name}.weight'].numpy().transpose((2, 3, 1, 0))])
    elif isinstance(layer, layers.Dense):
        layer.set_weights([
            state_dict[f'{layer.name}.weight'].numpy().transpose(),
            state_dict[f'{layer.name}.bias'].numpy()
        ])
    elif isinstance(layer, layers.BatchNormalization):
        keys = ['weight', 'bias', 'running_mean', 'running_var']
        layer.set_weights([state_dict[f'{layer.name}.{key}'].numpy() for key in keys])
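
The `BatchNormalization` branch relies on two conventions. Keras orders the layer weights as `[gamma, beta, moving_mean, moving_variance]`, which lines up with PyTorch's `weight`, `bias`, `running_mean` and `running_var`. Keras `momentum=0.9` corresponds to PyTorch's default `momentum=0.1`, because the two frameworks define the term as complements of each other. A small NumPy sketch of the inference-time formula and of the running-statistics update (the function name `batchnorm_inference` and all values are illustrative):

```python
import numpy as np

def batchnorm_inference(x, gamma, beta, mean, var, eps=1e-5):
    """Batch norm at inference time; parameters broadcast against x."""
    return gamma * (x - mean) / np.sqrt(var + eps) + beta

# One 2x2 feature map with 3 channels (HWC, no batch axis).
x = np.arange(12, dtype=np.float64).reshape(2, 2, 3)
gamma = np.array([1.0, 2.0, 0.5])
beta = np.array([0.0, 1.0, -1.0])
mean = np.array([4.5, 5.5, 6.5])
var = np.array([11.25, 11.25, 11.25])

y_hwc = batchnorm_inference(x, gamma, beta, mean, var)

# The same computation in channel-first layout needs the parameters
# reshaped to (C, 1, 1); the result is identical after transposing back.
shape = (3, 1, 1)
y_chw = batchnorm_inference(
    x.transpose(2, 0, 1),
    gamma.reshape(shape), beta.reshape(shape),
    mean.reshape(shape), var.reshape(shape),
)

# Running-mean update as each framework defines it.
keras_momentum, torch_momentum = 0.9, 0.1
running, batch = 2.0, 4.0
keras_update = running * keras_momentum + batch * (1 - keras_momentum)
torch_update = (1 - torch_momentum) * running + torch_momentum * batch
print(keras_update, torch_update)
```

The momentum only matters if the ported model is trained further; for inference, matching `epsilon=1e-5` is what keeps the outputs aligned.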


In [None]:
# Compare outputs
input_batch = np.random.rand(1, 224, 224, 3).astype(model.dtype)
output = model(input_batch).numpy()
with torch.no_grad():
    torch_output = torch_model(torch.tensor(input_batch.transpose((0, 3, 1, 2)))).numpy()
print(np.abs(output - torch_output).max())

3.5762787e-06
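
A difference of this size is consistent with float32 rounding. When porting several models, it can help to wrap the check in a small helper that also confirms the predicted class is unchanged. This is only a sketch: `compare_outputs` and the `1e-4` tolerance are assumptions, not values taken from the notebook:

```python
import numpy as np

def compare_outputs(reference, candidate, atol=1e-4):
    """Return (max_abs_diff, same_top1, within_tolerance) for two logit arrays."""
    reference = np.asarray(reference)
    candidate = np.asarray(candidate)
    max_abs_diff = float(np.abs(reference - candidate).max())
    same_top1 = bool(
        (reference.argmax(axis=-1) == candidate.argmax(axis=-1)).all()
    )
    return max_abs_diff, same_top1, max_abs_diff <= atol

# Made-up logits standing in for the PyTorch and TensorFlow outputs.
ref = np.array([[0.1, 2.0, -1.0]])
cand = ref + np.array([[1e-6, -2e-6, 3e-6]])

diff, same_top1, ok = compare_outputs(ref, cand)
print(diff, same_top1, ok)
```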


In [None]:
model.save('resnet18_imagenet.h5')

converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

with open('resnet18_tflite.tflite', 'wb') as f:
    f.write(tflite_model)


# Pytorch -> ONNX
torch.onnx.export(torch_model, torch.tensor(input_batch.transpose((0, 3, 1, 2))), "../edge/Jetson_Nano/inference/models/resnet18_imagenet.onnx", input_names=["input"], output_names=["output"])

#### AlexNet

In [None]:
alexnet_pytorch = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=True)
alexnet_pytorch.eval()

In [None]:
state_dict = alexnet_pytorch.state_dict()
state_dict

In [None]:
def alexnet(img_input, classes=1000):
  # 1st conv layer
  x = layers.ZeroPadding2D(padding=2)(img_input)
  x = Conv2D(64, (11, 11), strides=(4, 4), padding='valid',
             activation='relu', kernel_initializer='uniform', use_bias = True)(x)  # valid
  x = MaxPooling2D(pool_size=(3, 3), strides=(
      2, 2), padding='valid', data_format='channels_last')(x)
  #x = BatchNormalization()(x)

  # 2nd conv layer
  x = layers.ZeroPadding2D(padding=2)(x)
  x = Conv2D(192, (5, 5), strides=(1, 1), padding='valid',
             activation='relu', kernel_initializer='uniform', use_bias = True)(x)
  x = MaxPooling2D(pool_size=(3, 3), strides=(
      2, 2), padding='valid', data_format='channels_last')(x)
  #x = BatchNormalization()(x)

  # 3rd conv layer
  x = layers.ZeroPadding2D(padding=1)(x)
  x = Conv2D(384, (3, 3), strides=(1, 1), padding='valid',
             activation='relu', kernel_initializer='uniform', use_bias = True)(x)
  #x = BatchNormalization()(x)

  # 4th conv layer
  x = layers.ZeroPadding2D(padding=1)(x)
  x = Conv2D(256, (3, 3), strides=(1, 1), padding='valid',
             activation='relu', kernel_initializer='uniform', use_bias = True)(x)
  #x = BatchNormalization()(x)

  # 5th conv layer
  x = layers.ZeroPadding2D(padding=1)(x)
  x = Conv2D(256, (3, 3), strides=(1, 1), padding='valid',
             activation='relu', kernel_initializer='uniform', use_bias = True)(x)
  x = MaxPooling2D(pool_size=(3, 3), strides=(
      2, 2), padding='valid', data_format='channels_last')(x)
  #x = BatchNormalization()(x)

  #x = tf.keras.layers.AveragePooling2D(
    #1,
    #strides=1,
    #padding='valid',
    #data_format='channels_last'
    #)(x)
  #x = tf.reshape(x,[1,256,6,6])
  # flattening before sending to fully connected layers
  x = tf.keras.layers.Permute((3,1,2))(x)
  #x  = tf.transpose(x, perm=[0,3,1,2])
  print(x.shape)
  x = Flatten(data_format = 'channels_last')(x)
  print(x.shape)
  # fully connected layers
  x = Dense(4096, activation='relu', use_bias = True)(x)
  #x = Dropout(0.5)(x)
  #x = BatchNormalization()(x)
  x = Dense(4096, activation='relu', use_bias = True)(x)
  #x = Dropout(0.5)(x)
  #x = BatchNormalization()(x)

  # output layer
  out = Dense(1000, use_bias = True)(x)#, activation='softmax')(x)
  return out

In [None]:
inputs = keras.Input(shape=(224, 224, 3))
outputs = alexnet(inputs)
alexnet_tf = keras.Model(inputs, outputs)

(None, 256, 6, 6)
(None, 9216)
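
The `Permute((3, 1, 2))` before `Flatten` is what produces the `(None, 256, 6, 6)` shape above. PyTorch flattens features in channel-first `(C, H, W)` order, and the first fully connected layer was trained against that order. A NumPy sketch of why flattening the `NHWC` tensor directly would misalign the features (toy sizes, chosen for illustration):

```python
import numpy as np

# A toy feature map in PyTorch layout: 2 channels on a 2x3 spatial grid.
feat_chw = np.arange(12).reshape(2, 2, 3)
torch_flat = feat_chw.reshape(-1)        # order the Dense weights expect

# The same feature map as Keras holds it (H, W, C).
feat_hwc = feat_chw.transpose(1, 2, 0)

# Flattening directly interleaves the channels: wrong order.
naive_flat = feat_hwc.reshape(-1)

# Moving channels first, as Permute((3, 1, 2)) does, restores the order.
permuted_flat = feat_hwc.transpose(2, 0, 1).reshape(-1)

print(torch_flat)
print(naive_flat)
print(permuted_flat)
```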


Weight loading

In [None]:
alexnet_tf.layers[2].set_weights([state_dict['features.0.weight'].numpy().transpose((2, 3, 1, 0)),
                                  state_dict['features.0.bias'].numpy()])
alexnet_tf.layers[5].set_weights([state_dict['features.3.weight'].numpy().transpose((2, 3, 1, 0)),
                                  state_dict['features.3.bias'].numpy()])
alexnet_tf.layers[8].set_weights([state_dict['features.6.weight'].numpy().transpose((2, 3, 1, 0)),
                                  state_dict['features.6.bias'].numpy()])
alexnet_tf.layers[10].set_weights([state_dict['features.8.weight'].numpy().transpose((2, 3, 1, 0)),
                                   state_dict['features.8.bias'].numpy()])
alexnet_tf.layers[12].set_weights([state_dict['features.10.weight'].numpy().transpose((2, 3, 1, 0)),
                                   state_dict['features.10.bias'].numpy()])

In [None]:
alexnet_tf.layers[16].set_weights([
            state_dict['classifier.1.weight'].numpy().transpose(),
            state_dict['classifier.1.bias'].numpy()])
alexnet_tf.layers[17].set_weights([
            state_dict['classifier.4.weight'].numpy().transpose(),
            state_dict['classifier.4.bias'].numpy()])
alexnet_tf.layers[18].set_weights([
            state_dict['classifier.6.weight'].numpy().transpose(),
            state_dict['classifier.6.bias'].numpy()])

Testing

In [None]:
input_batch = np.random.rand(1, 224, 224, 3).astype(alexnet_tf.dtype)


intermediate_outputs = []
def hook(module, input, output):
    intermediate_outputs.append(output)

# Register hook to the desired intermediate layer

#desired_layer = alexnet_pytorch.features.0
hook_handle = alexnet_pytorch.classifier[6].register_forward_hook(hook)


# Forward pass
with torch.no_grad():
    alexnet_pytorch(torch.tensor(input_batch.transpose((0, 3, 1, 2))))

# Access intermediate outputs
hook_handle.remove()
print(len(intermediate_outputs))

1


In [None]:
print(intermediate_outputs[0].shape)

outputs1 = alexnet_tf.layers[18].output
parcial = Model(inputs, outputs1)
output1 = parcial(input_batch).numpy()
print(output1.shape)

print(np.abs(output1- intermediate_outputs[0].numpy()).max())#.transpose(0,3,1,2)

torch.Size([1, 1000])
(1, 1000)
1.8775463e-06


Saving

In [None]:
import torch.nn as nn

alexnet_tf.save('alexnet_imagenet.h5')

converter = tf.lite.TFLiteConverter.from_keras_model(alexnet_tf)
tflite_model = converter.convert()

with open('alexnet_tflite_def.tflite', 'wb') as f:
    f.write(tflite_model)

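# PyTorch -> ONNX
# torchvision's AlexNet keeps its fully connected layers in `classifier`
# (it has no `fc` attribute), so the softmax is appended there.
alexnet_pytorch.classifier = nn.Sequential(
    *alexnet_pytorch.classifier,
    nn.Softmax(dim=1),
)
torch.onnx.export(alexnet_pytorch, torch.tensor(input_batch.transpose((0, 3, 1, 2))), "../edge/Jetson_Nano/inference/models/alexnet_imagenet.onnx", input_names=["input"], output_names=["output"])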

#### ResNet50


In [None]:
torch_model = models.resnet50(pretrained=True)
torch_model.eval()

summary(torch_model, (3, 224,224))

In [None]:
def residual_block_v1(
    x, filters, kernel_size=3, stride=1, conv_shortcut=True, name=None):

    bn_axis = 3
    if conv_shortcut:
        shortcut = layers.Conv2D(4 * filters, 1, strides=stride, name=name + ".downsample.0", use_bias = False)(x)
        shortcut = layers.BatchNormalization(axis=bn_axis, epsilon=1.00e-5, momentum = 0.9, name=name + ".downsample.1")(shortcut)
    else:
        shortcut = x

    x = layers.Conv2D(filters, 1, strides=1, name=name + ".conv1", use_bias = False)(x) # stride was = stride
    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.00e-5, momentum= 0.9, name=name + ".bn1")(x)
    x = layers.Activation("relu", name=name + "_1_relu")(x)

    x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name=name+".pad2")(x)
    x = layers.Conv2D(filters, kernel_size, strides = stride, padding="valid", name=name + ".conv2", use_bias = False)(x)
    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.00e-5, momentum = 0.9, name=name + ".bn2")(x)
    x = layers.Activation("relu", name=name + "_2_relu")(x)

    x = layers.Conv2D(4 * filters, 1, strides = 1, name=name + ".conv3", use_bias = False)(x)
    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.00e-5,momentum = 0.9, name=name + ".bn3")(x)

    x = layers.Add(name=name + "_add")([shortcut, x])
    x = layers.Activation("relu", name=name + ".relu")(x)
    return x

def stack_residual_blocks_v1(x, filters, blocks, stride1=2, name=None):
    x = residual_block_v1(x, filters, stride=stride1, name=name + ".0")
    for i in range(1, blocks):
        x = residual_block_v1(x, filters, conv_shortcut=False, name=name + "." + str(i))
    return x

In [None]:
img_input = layers.Input(shape = (224,224,3))
def custom_model(input):
  x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)), name="conv1_pad")(input)
  x = layers.Conv2D(64, 7, strides=2, use_bias=False, name="conv1")(x)
  x = layers.BatchNormalization(axis=3, epsilon=1.00e-5, momentum = 0.9, name="bn1")(x)
  x = layers.Activation("relu", name="conv1_relu")(x)
  x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name="pool1_pad")(x)
  x = layers.MaxPooling2D(3, strides=2, name="maxpool")(x)
  x = stack_residual_blocks_v1(x, 64, 3, stride1=1, name="layer1")
  x = stack_residual_blocks_v1(x, 128, 4, name="layer2")
  x = stack_residual_blocks_v1(x, 256, 6, name="layer3")
  x = stack_residual_blocks_v1(x, 512, 3, name="layer4")
  x = layers.GlobalAveragePooling2D(name="avgpool")(x)
  x = layers.Dense(1000, activation='linear', name="fc")(x)
  return x


In [None]:
for name, layer in torch_model.named_modules():
    print(name)


conv1
bn1
relu
maxpool
layer1
layer1.0
layer1.0.conv1
layer1.0.bn1
layer1.0.conv2
layer1.0.bn2
layer1.0.conv3
layer1.0.bn3
layer1.0.relu
layer1.0.downsample
layer1.0.downsample.0
layer1.0.downsample.1
layer1.1
layer1.1.conv1
layer1.1.bn1
layer1.1.conv2
layer1.1.bn2
layer1.1.conv3
layer1.1.bn3
layer1.1.relu
layer1.2
layer1.2.conv1
layer1.2.bn1
layer1.2.conv2
layer1.2.bn2
layer1.2.conv3
layer1.2.bn3
layer1.2.relu
layer2
layer2.0
layer2.0.conv1
layer2.0.bn1
layer2.0.conv2
layer2.0.bn2
layer2.0.conv3
layer2.0.bn3
layer2.0.relu
layer2.0.downsample
layer2.0.downsample.0
layer2.0.downsample.1
layer2.1
layer2.1.conv1
layer2.1.bn1
layer2.1.conv2
layer2.1.bn2
layer2.1.conv3
layer2.1.bn3
layer2.1.relu
layer2.2
layer2.2.conv1
layer2.2.bn1
layer2.2.conv2
layer2.2.bn2
layer2.2.conv3
layer2.2.bn3
layer2.2.relu
layer2.3
layer2.3.conv1
layer2.3.bn1
layer2.3.conv2
layer2.3.bn2
layer2.3.conv3
layer2.3.bn3
layer2.3.relu
layer3
layer3.0
layer3.0.conv1
layer3.0.bn1
layer3.0.conv2
layer3.0.bn2
layer3.0.conv

In [None]:
custom_tf_model = Model(img_input, custom_model(img_input))

custom_tf_model.summary()

state_dict = torch_model.state_dict()
state_dict

In [None]:
for layer in custom_tf_model.layers:
    if isinstance(layer, layers.Conv2D):
        layer.set_weights([state_dict[f'{layer.name}.weight'].numpy().transpose((2, 3, 1, 0))])
    elif isinstance(layer, layers.Dense):
        layer.set_weights([
            state_dict[f'{layer.name}.weight'].numpy().transpose(),
            state_dict[f'{layer.name}.bias'].numpy()
        ])
    elif isinstance(layer, layers.BatchNormalization):
        keys = ['weight', 'bias', 'running_mean', 'running_var']
        layer.set_weights([state_dict[f'{layer.name}.{key}'].numpy() for key in keys])

In [None]:
# Compare outputs
input_batch = np.random.rand(1, 224, 224, 3).astype(custom_tf_model.dtype)
output = custom_tf_model(input_batch).numpy()
with torch.no_grad():
    torch_output = torch_model(torch.tensor(input_batch.transpose((0, 3, 1, 2)))).numpy()
print(np.abs(output - torch_output).max())

4.053116e-06


In [None]:
custom_tf_model.save('resnet50_imagenet.h5')

converter = tf.lite.TFLiteConverter.from_keras_model(custom_tf_model)
resnet50_tflite = converter.convert()

with open('resnet50_tflite.tflite', 'wb') as f:
    f.write(resnet50_tflite)


torch.onnx.export(torch_model, torch.tensor(input_batch.transpose((0, 3, 1, 2))), "../edge/Jetson_Nano/inference/models/resnet50_imagenet.onnx", input_names=["input"], output_names=["output"])

From the previous models, add a softmax layer and convert to ONNX

In [None]:
import numpy as np
import tensorflow as tf
import tf2onnx
import onnxruntime as rt

# Named `model_names` so it does not shadow `torchvision.models`
model_names = ['resnet18_imagenet', 'resnet50_imagenet', 'alexnet_imagenet']

# Random input, used only to check that the model with softmax runs
x = np.random.rand(1, 224, 224, 3).astype(np.float32)

for mod in model_names:
  model = tf.keras.models.load_model(
      mod + '.h5', custom_objects=None, compile=True, safe_mode=True
  )

  spec = (tf.TensorSpec((None, 224, 224, 3), tf.float32, name="input"),)

  # Add a softmax layer to the model
  softmax_layer = tf.keras.layers.Softmax()
  output = softmax_layer(model.output)

  # Create a new model with the softmax layer
  model_with_softmax = tf.keras.Model(inputs=model.input, outputs=output)
  preds = model_with_softmax.predict(x)

  output_path = mod + ".onnx"

  model_proto, _ = tf2onnx.convert.from_keras(model_with_softmax, input_signature=spec, opset=17, output_path=output_path)
  output_names = [n.name for n in model_proto.graph.output]
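
Appending a softmax only rescales the logits into probabilities; it does not change which class is predicted. A quick NumPy check of that property (the logits are made-up values):

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

logits = np.array([[2.0, -1.0, 0.5], [0.0, 3.0, 3.5]])
probs = softmax(logits)

# Each row sums to 1 and keeps the same top-1 class as the raw logits.
row_sums = probs.sum(axis=-1)
same_top1 = (probs.argmax(axis=-1) == logits.argmax(axis=-1)).all()
print(row_sums, same_top1)
```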

### Big model RegNetY-32GF for Jetson Orin

In [None]:
from torchvision.models import regnet_y_32gf, RegNet_Y_32GF_Weights
import torch


weights = RegNet_Y_32GF_Weights.IMAGENET1K_V2
model = regnet_y_32gf(weights=weights)
model.eval()

# export to ONNX
dummy_input = torch.randn(1, 3, 224, 224)  # Batch size of 1, 3 channels, 224x224 image
torch.onnx.export(model, dummy_input, "../edge/Jetson_Nano/inference/models/regnet.onnx", input_names=["input"], output_names=["output"])

## Early Exit for ImageNet models



> This was not included in the paper because the basic BranchyNet implementation lost too much accuracy on complex datasets like ImageNet. This part is kept as a future improvement, should techniques like self-distillation (https://ieeexplore.ieee.org/document/9381661) be used to train the models.



## ResNet18

In [None]:
import tensorflow as tf
import numpy as np
from PIL import Image  # used by preprocess_image below
from tensorflow.keras.models import Model

In [None]:
model = tf.keras.models.load_model(r"EE_resnet18_two_newweights_4.h5")

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

def preprocess_image(image):
    image = Image.open(image).convert('RGB')
    image = tf.convert_to_tensor(image)
    image = tf.cast(image, tf.float32) / 255.0
    image = (image -mean)/std
    image = tf.reshape(image, (1, 224, 224, 3))
    return image


In [None]:
# Model for inference using TFLite

# Define common model
common = Model(inputs=model.input, outputs=model.layers[44].output)

# Define branch models
branch1 = Model(inputs=model.layers[44].output, outputs=tf.nn.softmax(model.layers[-2].output))
branch2 = Model(inputs=model.layers[44].output, outputs=tf.nn.softmax(model.layers[-1].output))


# Define custom layer to choose between branches
class ChooseBranchLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(ChooseBranchLayer, self).__init__()
        self.branch1 = branch1
        self.branch2 = branch2

    def call(self, inputs):
        common_output = inputs
        output1 = self.branch1(common_output)
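        # A softmax maximum can never exceed 1.0, so with this threshold the early
        # exit is never taken and every sample goes through branch2.
        condition = ((tf.reduce_max(output1) > 1.0)) #| (tf.reduce_max(output1) < 0.5))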
        return tf.cond(condition, lambda: [output1,0], lambda: [self.branch2(common_output),1])

# Create input layer
inputs = tf.keras.layers.Input(shape=(224, 224, 3))

# Get common output
common_output = common(inputs)

# Use the custom layer to choose output based on condition
final_output, eexit = ChooseBranchLayer()(common_output)


# Define the new model
model_for_inference = tf.keras.Model(inputs=inputs, outputs=[final_output, eexit])

### Quantization

In [None]:
def representative_data_generator():
    # `imagenes` is expected to hold the preprocessed calibration images as an
    # array of shape (N, 224, 224, 3); it must be defined before this cell runs.
    data = tf.data.Dataset.from_tensor_slices(imagenes.astype(np.float32)).batch(1)
    for input_value in data:
        yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model_for_inference)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_generator
# Restrict conversion to the built-in TFLite ops. To make the converter raise an
# error when an op cannot be quantized, use tf.lite.OpsSet.TFLITE_BUILTINS_INT8.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
# Set the input and output tensors to uint8 (APIs added in r2.3)
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8
tflite_model = converter.convert()


interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

input_scale, input_zero_point = input_details[0]['quantization']
output_scale, output_zero_point = output_details[1]['quantization']

with open(r"EE_resnet18_quant_mainbranch.tflite", 'wb') as f:
    f.write(tflite_model)
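
The scale and zero point read from `input_details` and `output_details` are needed at inference time, because the quantized model consumes and produces `uint8` tensors. TFLite's affine scheme maps them as `real = scale * (q - zero_point)`. A sketch of the two conversions follows; the helper names and the example scale and zero point are hypothetical, not values from this model:

```python
import numpy as np

def quantize(real, scale, zero_point, dtype=np.uint8):
    """Map float values to the quantized integer range, saturating at the limits."""
    info = np.iinfo(dtype)
    q = np.round(real / scale) + zero_point
    return np.clip(q, info.min, info.max).astype(dtype)

def dequantize(q, scale, zero_point):
    """Map quantized integers back to float values."""
    return scale * (q.astype(np.float32) - zero_point)

scale, zero_point = 0.02, 128   # example values only
x = np.array([-1.0, 0.0, 0.5, 10.0], dtype=np.float32)

q = quantize(x, scale, zero_point)
x_back = dequantize(q, scale, zero_point)
print(q, x_back)
```

With the interpreter above, the input image would go through `quantize` with `input_scale` and `input_zero_point` before `set_tensor`, and the output tensor through `dequantize` with `output_scale` and `output_zero_point`. Values outside the representable range saturate, as the last element shows.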


# Server Models

Here we show how to obtain the models that will run on the server

## CIFAR-10 :  ViT-H/14

We could not find this model pretrained for CIFAR-10 (it is pretrained on ImageNet-21k), so we had to fine-tune it on CIFAR-10. For this, you can use the official repository, which uses JAX:

https://github.com/google-research/vision_transformer

You could also use the following if you prefer PyTorch:

https://github.com/bwconrad/vit-finetune


If you choose the latter, you can do the following to prepare the model:

In [None]:
from src.model import ClassificationModel
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os

home_dir = os.path.expanduser("~")

ckpt_path = "output/default/version_0/checkpoints/best-step-step=97000-val_acc=0.9847.ckpt" # Or whatever name your best checkpoint has
model = ClassificationModel.load_from_checkpoint(ckpt_path)
model = model.to("cpu")
model.eval()
dummy_input = torch.randn(1, 3, 224, 224, device="cpu")

onnx_model_path = os.path.join(home_dir, "model/big/vith14.onnx") # Change this to wherever you want to save the ONNX model
torch.onnx.export(model.net,
                  dummy_input,
                  onnx_model_path,
                  export_params=True,
                  opset_version=17,
                  do_constant_folding=True,
                  input_names=['input'],
                  output_names=['output'],
)

In [None]:
!python3 -m onnxruntime.transformers.optimizer --input ${HOME}/model/big/vith14.onnx --output ${HOME}/model/opt/vith14.onnx --hidden_size 1280 --num_heads 16 --opt_level 0 --float16 --use_gpu

This will produce a model that behaves well with the ONNX Runtime CUDA execution provider (EP).

## ImageNet-1k :  ConvNeXT (xlarge-sized model)

This one is available in the `transformers` library.

In [None]:
from transformers import ConvNextForImageClassification
import torch

torch_model = ConvNextForImageClassification.from_pretrained("facebook/convnext-xlarge-224-22k-1k")

torch_input = torch.randn(1, 3, 224, 224)

# Export the model
torch.onnx.export(torch_model,                   # our model to save
                  torch_input,                   # model input
                  "../server/convnext.onnx",     # path to save model
                  export_params=True,            # store the trained parameter weights inside the model file
                  opset_version=17,              # the ONNX opset version to export the model to
                  do_constant_folding=True,      # whether to execute constant folding for optimization
                  )