# import libs

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
import numpy as np



In [2]:
tfds.disable_progress_bar()

Converts 2d image to 1d array. It has been easier to work with tensorflow micro with 1d inputs and the reshape to 2d in the model for layers that require other dimensions

In [3]:
def flatten_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  image=tf.dtypes.cast(image, tf.float32)/255
  return tf.reshape(image,[28*28]), label

## Load in mnist data

In [4]:

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /root/tensorflow_datasets/mnist/3.0.1...
Dataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.1. Subsequent calls will reuse this data.


## split data set

In [5]:
ds_train = ds_train.map(
    flatten_img)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(1)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

ds_test = ds_test.map(
    flatten_img)
ds_test = ds_test.batch(1)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.experimental.AUTOTUNE)

A simple fully connected model

In [6]:
simple_model = tf.keras.models.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28*28)),
  tf.keras.layers.Dense(64),
  tf.keras.layers.Dense(10)
])
# Train the digit classification model
simple_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])



A more interesting model with convultional layers

In [7]:
# conv model
model = tf.keras.models.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28*28)),
  tf.keras.layers.Reshape((28,28,1)),
  tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dropout(0.5),
  tf.keras.layers.Dense(10, activation="softmax"),
])
# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(),
              metrics=['accuracy'])

In [8]:
model.fit(
    ds_train,
    epochs=6,
    validation_data=ds_test,
)

Epoch 1/6
Epoch 2/6
Epoch 3/6
Epoch 4/6
Epoch 5/6
Epoch 6/6


<keras.src.callbacks.History at 0x7cab79e2a230>

## Save full tensorflow model

In [9]:
model.save('Models/Full/mnistModel')

# Convert model to tf micro

## load saved full size model

In [10]:
model = keras.models.load_model('Models/Full/mnistModel')


## prep sample data that will be used  for quantization

In [11]:
mnist_train, _ = tf.keras.datasets.mnist.load_data()
images = tf.cast(mnist_train[0], tf.float32) / 255.0
images=tf.reshape(images,[images.shape[0],28*28])
mnist_ds = tf.data.Dataset.from_tensor_slices((images)).batch(1)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


## create representative_data_gen providing function

In [12]:
def representative_data_gen():
  for input_value in c.take(100):
    # Model has only one input so each data point has one element.
    yield [input_value]

## configure converter for tf micro int8 models

In [13]:
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_types = [tf.int8]
converter.representative_dataset = representative_data_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8
converter.inference_type = tf.int8
converter.inference_output_type =  tf.int8

## convert full model to tflite

In [14]:
tflite_model = converter.convert()




## save lite model

In [15]:
open("./mnist.tflite", "wb").write(tflite_model)


42032

## export tflite model as cc file for use in tensoflow micro

In [22]:
%%shell
xxd -i ./mnist.tflite > ./mnist.cc
cat ./mnist.cc

unsigned char __mnist_tflite[] = {
  0x20, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c, 0x33, 0x00, 0x00, 0x00, 0x00,
  0x14, 0x00, 0x20, 0x00, 0x1c, 0x00, 0x18, 0x00, 0x14, 0x00, 0x10, 0x00,
  0x0c, 0x00, 0x00, 0x00, 0x08, 0x00, 0x04, 0x00, 0x14, 0x00, 0x00, 0x00,
  0x1c, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0xd8, 0x00, 0x00, 0x00,
  0xcc, 0x8b, 0x00, 0x00, 0xdc, 0x8b, 0x00, 0x00, 0x4c, 0xa3, 0x00, 0x00,
  0x03, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00,
  0x1e, 0x72, 0xff, 0xff, 0x0c, 0x00, 0x00, 0x00, 0x1c, 0x00, 0x00, 0x00,
  0x38, 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x73, 0x65, 0x72, 0x76,
  0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x66, 0x61, 0x75, 0x6c, 0x74, 0x00,
  0x01, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x9c, 0xff, 0xff, 0xff,
  0x16, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x07, 0x00, 0x00, 0x00,
  0x64, 0x65, 0x6e, 0x73, 0x65, 0x5f, 0x32, 0x00, 0x01, 0x00, 0x00, 0x00,
  0x04, 0x00, 0x00, 0x00, 0x22, 0x72, 0xff, 0xff, 0x04, 0x00, 0x00, 0x00,
  0



# Test tflite model

It is possible to test the quantized model in python.

In [16]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

In [17]:
def flatten_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.reshape(image,[28*28]), label

In [18]:
ds_train = ds_train.map(
    flatten_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)

In [19]:
# Load TFLite model and allocate tensors.
interpreter = tf.lite.Interpreter(model_path="./mnist.tflite")
interpreter.allocate_tensors()

In [20]:
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()


print(input_details)
print(output_details)
# Test model on random input data.
input_shape = input_details[0]['shape']
print(input_shape)

[{'name': 'serving_default_input_2:0', 'index': 0, 'shape': array([  1, 784], dtype=int32), 'shape_signature': array([ -1, 784], dtype=int32), 'dtype': <class 'numpy.int8'>, 'quantization': (0.003921568859368563, -128), 'quantization_parameters': {'scales': array([0.00392157], dtype=float32), 'zero_points': array([-128], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
[{'name': 'StatefulPartitionedCall:0', 'index': 22, 'shape': array([ 1, 10], dtype=int32), 'shape_signature': array([-1, 10], dtype=int32), 'dtype': <class 'numpy.int8'>, 'quantization': (0.00390625, -128), 'quantization_parameters': {'scales': array([0.00390625], dtype=float32), 'zero_points': array([-128], dtype=int32), 'quantized_dimension': 0}, 'sparsity_parameters': {}}]
[  1 784]


This block will not stop by itself

In [21]:
count=0
rightount=0
for elem in ds_train:
    #print(elem[0])
    inputData=tf.reshape(tf.cast((tf.cast(elem[0],tf.int16)-128),tf.int8),[1,28*28])
    interpreter.set_tensor(input_details[0]['index'],inputData )
    for i in range(0,28*28):
        value=inputData[0][i]
        #display the image
        #print( f"{value:4} ,", end='')
        #if (i+1) %28==0:
        #    print("")
    interpreter.invoke()
    #print(interpreter.arena_used_bytes()) not implemented yet?
    output_data = interpreter.get_tensor(output_details[0]['index'])
    count+=1
    answer=np.argmax(np.array(output_data))
    truth=elem[1]
    if answer==truth:
        rightount+=1
    #print(elem[1])
    print(rightount/count)

1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
0.9943502824858758
0.9943820224719101
0.994413407821229
0.9944444444444445
0.994475138121547
0.9945054945054945
0.994535519125683
0.9945652173913043
0.9945945945945946
0.9946236559139785
0.9946524064171123
0.9946808510638298
0.9947089947089947
0.9947368421052631
0.9947643979057592
0.994791666666

KeyboardInterrupt: 