In [14]:
import numpy as np 
import tensorflow as tf 
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Dense, Conv1D, Flatten, ReLU, Softmax 
from tensorflow.keras.models import Sequential
import os
MODELS_DIR = 'models/'
if not os.path.exists(MODELS_DIR):
    os.mkdir(MODELS_DIR)
MODEL_TF = MODELS_DIR + 'model'
MODEL_NO_QUANT_TFLITE = MODELS_DIR + 'model_no_quant.tflite'
MODEL_TFLITE = MODELS_DIR + 'model.tflite'
MODEL_TFLITE_MICRO = MODELS_DIR + 'model.cc'

In [15]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()

y_train = to_categorical(y_train)
y_test = to_categorical(y_test)
X_train = X_train.reshape(-1,784,1)
X_test = X_test.reshape(-1,784,1)


In [16]:
model = Sequential()

model.add(Conv1D(4, kernel_size=(3), input_shape=(784,1),activation='relu'))
model.add(Conv1D(2, kernel_size=(3), activation='relu'))

model.add(Flatten())
model.add(Dense(10))

model.add(Softmax())

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit(X_train,y_train,validation_data=(X_test,y_test),epochs=5)


model.save(MODEL_TF)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
INFO:tensorflow:Assets written to: models/model/assets


In [17]:
model.predict(X_test)

array([[2.8925133e-12, 1.2626830e-10, 5.8483089e-09, ..., 9.9997520e-01,
        2.7979166e-08, 1.1449767e-06],
       [1.0839684e-03, 9.1008187e-06, 9.9061626e-01, ..., 7.2358986e-13,
        3.7775229e-05, 9.4273565e-12],
       [1.4993768e-11, 9.9970168e-01, 2.8160727e-04, ..., 1.7151183e-06,
        1.2704593e-05, 1.4351219e-10],
       ...,
       [2.8050435e-11, 1.2745832e-12, 4.9266919e-09, ..., 1.4752751e-06,
        1.1376881e-05, 6.4100292e-05],
       [2.0542021e-07, 2.0725177e-09, 6.5189670e-10, ..., 7.1838460e-09,
        1.0661478e-01, 5.0935265e-09],
       [7.7583088e-08, 6.3182288e-14, 7.2918997e-06, ..., 1.2284613e-11,
        1.7329421e-07, 3.5140504e-06]], dtype=float32)

In [18]:

converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_TF)
model_no_quant_tflite = converter.convert()

# Save the model to disk
open(MODEL_NO_QUANT_TFLITE, "wb").write(model_no_quant_tflite)
# Convert the model to the TensorFlow Lite format with quantization
def representative_dataset():
  for i in range(500):
    yield([(X_train[i].reshape(-1,784,1)).astype(np.float32)])
# Set the optimization flag.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Enforce integer only quantization
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
# Provide a representative dataset to ensure we quantize correctly.
converter.representative_dataset = representative_dataset
# print((X_train[1][0]))

model_tflite = converter.convert()

# Save the model to disk
open(MODEL_TFLITE, "wb").write(model_tflite)

y_test_pred_tf = model.predict(X_test)


2021-08-21 12:14:04.163682: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:345] Ignored output_format.
2021-08-21 12:14:04.163749: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:348] Ignored drop_control_dependency.
2021-08-21 12:14:04.163760: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:354] Ignored change_concat_input_ranges.
2021-08-21 12:14:04.164052: I tensorflow/cc/saved_model/reader.cc:38] Reading SavedModel from: models/model
2021-08-21 12:14:04.166646: I tensorflow/cc/saved_model/reader.cc:90] Reading meta graph with tags { serve }
2021-08-21 12:14:04.166686: I tensorflow/cc/saved_model/reader.cc:132] Reading SavedModel debug info (if present) from: models/model
2021-08-21 12:14:04.175951: I tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle.
2021-08-21 12:14:04.233642: I tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: models/model
2021-08-21 12:1

In [19]:
def predict_tflite(tflite_model, x_test):
  # Prepare the test data
  x_test_ = x_test.copy()
  x_test_ = x_test_.reshape((-1,784,1))
  x_test_ = x_test_.astype(np.float32)

  # Initialize the TFLite interpreter
  interpreter = tf.lite.Interpreter(model_content=tflite_model)
  interpreter.allocate_tensors()

  input_details = interpreter.get_input_details()[0]
  output_details = interpreter.get_output_details()[0]

  # If required, quantize the input layer (from float to integer)
  input_scale, input_zero_point = input_details["quantization"]
  if (input_scale, input_zero_point) != (0.0, 0):
    x_test_ = x_test_ / input_scale + input_zero_point
    x_test_ = x_test_.astype(input_details["dtype"])
  
  # Invoke the interpreter
  y_pred = np.empty((x_test_.shape[0],10), dtype=output_details["dtype"])
  for i in range(len(x_test_)):
    interpreter.set_tensor(input_details["index"], [x_test_[i]])
    interpreter.invoke()
    y_pred[i] = interpreter.get_tensor(output_details["index"])[0]
  
  # If required, dequantized the output layer (from integer to float)
  output_scale, output_zero_point = output_details["quantization"]
  if (output_scale, output_zero_point) != (0.0, 0):
    y_pred = y_pred.astype(np.float32)
    y_pred = (y_pred - output_zero_point) * output_scale

  return y_pred

def evaluate_tflite(tflite_model, x_test, y_true):
  global model
  y_pred = predict_tflite(tflite_model, x_test)
  cce = tf.keras.losses.CategoricalCrossentropy()  
  loss = cce(y_true, y_pred).numpy()
  return loss


In [20]:
y_test_pred_tf = model.predict(X_test)
y_test_pred_no_quant_tflite = predict_tflite(model_no_quant_tflite, X_test)
y_test_pred_tflite = predict_tflite(model_tflite, X_test)

In [21]:
loss_tf, _ = model.evaluate(X_test, y_test, verbose=0)
loss_no_quant_tflite = evaluate_tflite(model_no_quant_tflite, X_test, y_test)
loss_tflite = evaluate_tflite(model_tflite, X_test, y_test)

In [22]:
import pandas as pd
df = pd.DataFrame.from_records(
    [["TensorFlow", loss_tf],
     ["TensorFlow Lite", loss_no_quant_tflite],
     ["TensorFlow Lite Quantized", loss_tflite]],
     columns = ["Model", "Loss/MSE"], index="Model").round(4)
df

Unnamed: 0_level_0,Loss/MSE
Model,Unnamed: 1_level_1
TensorFlow,0.1886
TensorFlow Lite,0.1886
TensorFlow Lite Quantized,0.6976


In [23]:
size_tf = os.path.getsize(MODEL_TF)
size_no_quant_tflite = os.path.getsize(MODEL_NO_QUANT_TFLITE)
size_tflite = os.path.getsize(MODEL_TFLITE)
pd.DataFrame.from_records(
    [["TensorFlow", f"{size_tf} bytes", ""],
     ["TensorFlow Lite", f"{size_no_quant_tflite} bytes ", f"(reduced by {size_tf - size_no_quant_tflite} bytes)"],
     ["TensorFlow Lite Quantized", f"{size_tflite} bytes", f"(reduced by {size_no_quant_tflite - size_tflite} bytes)"]],
     columns = ["Model", "Size", ""], index="Model")

Unnamed: 0_level_0,Size,Unnamed: 2_level_0
Model,Unnamed: 1_level_1,Unnamed: 2_level_1
TensorFlow,4096 bytes,
TensorFlow Lite,65972 bytes,(reduced by -61876 bytes)
TensorFlow Lite Quantized,19568 bytes,(reduced by 46404 bytes)


In [24]:
!xxd -i {MODEL_TFLITE} > {MODEL_TFLITE_MICRO}
# Update variable names
REPLACE_TEXT = MODEL_TFLITE.replace('/', '_').replace('.', '_')
!sed -i 's/'{REPLACE_TEXT}'/g_model/g' {MODEL_TFLITE_MICRO}

In [25]:
!cat {MODEL_TFLITE_MICRO}

unsigned char g_model[] = {
  0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x14, 0x00, 0x20, 0x00,
  0x1c, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00, 0x0c, 0x00, 0x00, 0x00,
  0x08, 0x00, 0x04, 0x00, 0x14, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00,
  0x1c, 0x00, 0x00, 0x00, 0x4c, 0x00, 0x00, 0x00, 0xb8, 0x3e, 0x00, 0x00,
  0xc8, 0x3e, 0x00, 0x00, 0xc4, 0x4b, 0x00, 0x00, 0x03, 0x00, 0x00, 0x00,
  0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x0c, 0x00, 0x00, 0x00,
  0x08, 0x00, 0x0c, 0x00, 0x08, 0x00, 0x04, 0x00, 0x08, 0x00, 0x00, 0x00,
  0x13, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x13, 0x00, 0x00, 0x00,
  0x6d, 0x69, 0x6e, 0x5f, 0x72, 0x75, 0x6e, 0x74, 0x69, 0x6d, 0x65, 0x5f,
  0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x00, 0x14, 0x00, 0x00, 0x00,
  0x68, 0x3e, 0x00, 0x00, 0x60, 0x3e, 0x00, 0x00, 0x4c, 0x3e, 0x00, 0x00,
  0x30, 0x3e, 0x00, 0x00, 0x18, 0x3e, 0x00, 0x00, 0xfc, 0x3d, 0x00, 0x00,
  0xdc, 0x3d, 0x00, 0x00, 0xb4, 0x3d, 0x00, 0x00, 0x9c, 0x3d, 0x00, 0x0