In [None]:
'''
We actually have several options as to how much we want to quantize a model. 
In this tutorial, we'll perform "full integer quantization," 
which converts all weights and activation outputs into 8-bit integer data, 
whereas other strategies may leave some amount of data in floating-point.
'''

In [2]:
import logging
# Raise the "tensorflow" logger to DEBUG verbosity.
# NOTE(review): presumably this is why the "INFO:tensorflow:Assets written to"
# lines appear in the conversion cells' output; confirm before lowering it.
logging.getLogger("tensorflow").setLevel(logging.DEBUG)

import tensorflow as tf
import numpy as np
# Record the TensorFlow version that produced the outputs saved in this notebook.
print("TensorFlow version: ", tf.__version__)

TensorFlow version:  2.14.0


In [4]:
# Fetch MNIST through Keras; `load_data()` hands back a (train, test) pair of
# (images, labels) tuples.
mnist = tf.keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Cast both image sets to float32 and divide by 255 so that every pixel value
# lies between 0 and 1. The same expression is applied to each split.
train_images, test_images = (
    images.astype(np.float32) / 255.0
    for images in (train_images, test_images)
)

# Seed Python's `random`, NumPy and TensorFlow in one call so that weight
# initialisation here and the shuffling done later by `model.fit` repeat on
# every "Restart Kernel -> Run All". Without a seed each run trains a slightly
# different model, which makes the quantization results below hard to compare.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Define the model architecture: a small CNN for 28x28 grayscale digits.
model = tf.keras.Sequential([
  tf.keras.layers.InputLayer(input_shape=(28, 28)),
  # Add an explicit channel axis, since Conv2D works on (height, width, channels).
  tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
  tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
  tf.keras.layers.Flatten(),
  # No activation: the layer emits 10 raw logits, matching the
  # `from_logits=True` loss used when the model is compiled.
  tf.keras.layers.Dense(10)
])

# Train the digit classification model.
# The model's last Dense layer has no activation, so the loss is configured to
# interpret its outputs as logits.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss_fn, metrics=['accuracy'])

# The test split is passed as validation data, so the per-epoch validation
# metrics are measured on the test images.
model.fit(
  train_images,
  train_labels,
  epochs=5,
  validation_data=(test_images, test_labels),
)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.src.callbacks.History at 0x175669c8c90>

In [5]:
# Baseline conversion: turn the trained Keras model into a TensorFlow Lite
# model with no quantization. No optimizations are set on the converter, so the
# result still uses 32-bit float values for all parameter data.

converter = tf.lite.TFLiteConverter.from_keras_model(model)
# The converted model is kept in memory as `tflite_model`; nothing is written
# to a file in this cell.
tflite_model = converter.convert()

INFO:tensorflow:Assets written to: C:\Users\abhishek.sri\AppData\Local\Temp\tmpmegpxaz4\assets


INFO:tensorflow:Assets written to: C:\Users\abhishek.sri\AppData\Local\Temp\tmpmegpxaz4\assets


In [6]:
# Dynamic range quantization.
# Setting the default optimizations flag quantizes all fixed parameters (such
# as weights). The model is now a bit smaller thanks to the quantized weights,
# but other variable data is still in float format.
# NOTE(review): `tflite_model_quant` is assigned again by the integer-only
# conversion further down, so this result is replaced once that cell runs.

converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_model_quant = converter.convert()

INFO:tensorflow:Assets written to: C:\Users\abhishek.sri\AppData\Local\Temp\tmpi30xkfiw\assets


INFO:tensorflow:Assets written to: C:\Users\abhishek.sri\AppData\Local\Temp\tmpi30xkfiw\assets


In [7]:
# Integer-only quantization: the input and output tensors become integer too.
# To quantize the input and output tensors, and to make the converter throw an
# error if it encounters an operation it cannot quantize, the model is converted
# again with some additional parameters and a representative dataset.
# This helps when running on embedded devices such as the Edge TPU.

def representative_data_gen(num_samples=100):
  """Yield sample inputs for the converter's representative dataset.

  Args:
    num_samples: Number of training images to yield. Defaults to 100, the
      value that used to be hard-coded, so assigning this function directly
      to `converter.representative_dataset` (which invokes it with no
      arguments) behaves exactly as before. Wrap it in a lambda or
      `functools.partial` to use a different count.

  Yields:
    A one-element list containing a single batch of size 1 taken, in order,
    from the start of `train_images`.
  """
  for input_value in tf.data.Dataset.from_tensor_slices(train_images).batch(1).take(num_samples):
    yield [input_value]

converter = tf.lite.TFLiteConverter.from_keras_model(model)

# Make the model's input and output tensors uint8 (APIs added in r2.3).
converter.inference_input_type = tf.uint8
converter.inference_output_type = tf.uint8

# Only allow the int8 builtin op set, so the converter raises an error if it
# meets an op that cannot be quantized.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]

# Enable the default optimizations and supply the representative dataset
# generator defined above.
converter.representative_dataset = representative_data_gen
converter.optimizations = [tf.lite.Optimize.DEFAULT]

tflite_model_quant = converter.convert()

INFO:tensorflow:Assets written to: C:\Users\abhishek.sri\AppData\Local\Temp\tmpgwfo_3lw\assets


INFO:tensorflow:Assets written to: C:\Users\abhishek.sri\AppData\Local\Temp\tmpgwfo_3lw\assets


In [8]:
# Load the current `tflite_model_quant` into a TFLite interpreter and report
# the dtype of its first input tensor and first output tensor.
interpreter = tf.lite.Interpreter(model_content=tflite_model_quant)

input_type = interpreter.get_input_details()[0]['dtype']
output_type = interpreter.get_output_details()[0]['dtype']

print('input: ', input_type)
print('output: ', output_type)

input:  <class 'numpy.uint8'>
output:  <class 'numpy.uint8'>
