<a href="https://colab.research.google.com/github/aravindchakravarti/OptimizeNetworks/blob/main/Quantization_Aware_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#### Reference Code Available at: 

- Quantization Aware Training: https://www.tensorflow.org/model_optimization/guide/quantization/training_example

- Post-training Quantization: https://www.tensorflow.org/lite/performance/post_training_quant

# Download Dependencies and Import Libraries


In [1]:
! pip install -q tensorflow
! pip install -q tensorflow-model-optimization

[?25l[K     |█▍                              | 10 kB 28.4 MB/s eta 0:00:01[K     |██▊                             | 20 kB 18.6 MB/s eta 0:00:01[K     |████▏                           | 30 kB 18.1 MB/s eta 0:00:01[K     |█████▌                          | 40 kB 19.0 MB/s eta 0:00:01[K     |██████▉                         | 51 kB 21.8 MB/s eta 0:00:01[K     |████████▎                       | 61 kB 16.6 MB/s eta 0:00:01[K     |█████████▋                      | 71 kB 7.8 MB/s eta 0:00:01[K     |███████████                     | 81 kB 8.7 MB/s eta 0:00:01[K     |████████████▍                   | 92 kB 9.6 MB/s eta 0:00:01[K     |█████████████▊                  | 102 kB 10.5 MB/s eta 0:00:01[K     |███████████████                 | 112 kB 10.5 MB/s eta 0:00:01[K     |████████████████▌               | 122 kB 10.5 MB/s eta 0:00:01[K     |█████████████████▉              | 133 kB 10.5 MB/s eta 0:00:01[K     |███████████████████▏            | 143 kB 10.5 MB/s eta 0:00:

In [2]:
import tempfile
import os
import tensorflow as tf
from tensorflow import keras
from time import perf_counter
from statistics import mean
import pathlib

# Build a MNIST Classifier

## Model Design and Training

In [3]:
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


<keras.callbacks.History at 0x7f50800720d0>

## Inference of Model

In [4]:
inference_time = []
for i in range(10):
  start = perf_counter()
  model.evaluate(test_images, test_labels)
  stop = perf_counter()
  inference_time.append(stop-start)
  
for i in range(10):
  print("Inference Time Diff = ", inference_time[i])

print("Mean Time Diff = ", mean(inference_time))

Inference Time Diff =  1.334142503999999
Inference Time Diff =  0.9253799219999905
Inference Time Diff =  0.798293238000042
Inference Time Diff =  1.3292515269999967
Inference Time Diff =  1.0805777310000053
Inference Time Diff =  1.3261469970000235
Inference Time Diff =  1.3384186769999928
Inference Time Diff =  1.3260544179999556
Inference Time Diff =  1.3274303059999966
Inference Time Diff =  0.7841694370000027
Mean Time Diff =  1.1569864757000006
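The same ten-run timing loop is repeated in every benchmark cell below. The sketch that follows factors it into a small helper. `time_repeated` and `report` are hypothetical names of our own, not part of TensorFlow. The helper can also make one untimed warm-up call first, which keeps one-time setup costs out of the measurements. It also reports the standard deviation, because single runs vary noticeably (0.78 s to 1.34 s above).

```python
from statistics import mean, stdev
from time import perf_counter


def time_repeated(fn, repeats=10, warmup=1):
    """Call fn() `warmup` times untimed, then `repeats` times timed.

    Returns (times, result): the wall-clock durations in seconds and the
    return value of the last timed call.
    """
    for _ in range(warmup):
        fn()  # warm-up calls are excluded from the measurements
    times = []
    result = None
    for _ in range(repeats):
        start = perf_counter()
        result = fn()
        times.append(perf_counter() - start)
    return times, result


def report(times):
    # Print every run, then the mean and standard deviation (needs >= 2 runs).
    for t in times:
        print("Inference Time Diff = ", t)
    print("Mean Time Diff = ", mean(times))
    print("Std Dev = ", stdev(times))


# Stand-in workload; in this notebook fn would be, e.g.,
# lambda: evaluate_model(interpreter)
times, value = time_repeated(lambda: sum(range(100_000)), repeats=5)
report(times)
print("Result of last call:", value)
```

In the TFLite cells, this would read `times, test_accuracy = time_repeated(lambda: evaluate_model(interpreter))` followed by `report(times)`.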


# Quantization of Model

## Make the Model Quantization Aware

After applying quantization aware training to the whole model, the model summary shows that every layer is now wrapped and prefixed with "quant", and a new `quantize_layer` has been added at the input.

Note that the resulting model is quantization aware but not yet quantized (e.g. the weights are still float32 rather than int8). The following sections show how to create a quantized model from the quantization aware one.

The [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md) covers further options, such as quantizing only some of the layers.
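Quantization aware training works by inserting "fake quantization" into the model. Each value is rounded to one of 2^8 = 256 levels and immediately mapped back to float. Training therefore sees the rounding error, while the weights themselves stay float32. The sketch below is a simplified, pure-Python illustration of that quantize-then-dequantize step over a fixed range of [-1, 1]. The name `fake_quantize` and the range are our own choices. TensorFlow's actual fake-quant ops differ in details, such as how the range is learned and adjusted.

```python
def fake_quantize(x, x_min, x_max, num_bits=8):
    """Quantize-then-dequantize one float (simplified illustration).

    The result is still a float, but it can only take 2**num_bits evenly
    spaced values between x_min and x_max.
    """
    levels = 2 ** num_bits - 1             # 255 steps for 8 bits
    scale = (x_max - x_min) / levels       # width of one quantization step
    clamped = min(max(x, x_min), x_max)    # values outside the range saturate
    q = round((clamped - x_min) / scale)   # integer code in [0, levels]
    return x_min + q * scale               # map the code back to a float


for w in [-0.731, -0.02, 0.4137, 1.9]:
    print(f"{w:+.4f} -> {fake_quantize(w, -1.0, 1.0):+.6f}")
```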

In [5]:
import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model

# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

q_aware_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer (QuantizeLay  (None, 28, 28)           3         
 er)                                                             
                                                                 
 quant_reshape (QuantizeWrap  (None, 28, 28, 1)        1         
 perV2)                                                          
                                                                 
 quant_conv2d (QuantizeWrapp  (None, 26, 26, 12)       147       
 erV2)                                                           
                                                                 
 quant_max_pooling2d (Quanti  (None, 13, 13, 12)       1         
 zeWrapperV2)                                                    
                                                                 
 quant_flatten (QuantizeWrap  (None, 2028)             1

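## Pseudo-Transfer-Learning
This is a kind of transfer learning. The quantization aware model starts from the already trained weights and is re-trained briefly so that it adapts to the simulated quantization. The baseline was trained for just one epoch. To demonstrate fine-tuning on top of it, we run quantization aware training for one epoch on a subset of the training data: the first 1,000 of the 60,000 images.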

In [6]:
train_images_subset = train_images[0:1000] # out of 60000
train_labels_subset = train_labels[0:1000]

q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)



<keras.callbacks.History at 0x7f5032149b50>
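The sketch below works out how small this fine-tuning step is. It relies on one Keras behavior: `validation_split` holds out the last fraction of the samples, before any shuffling. With 1,000 samples, `validation_split=0.1` and `batch_size=500`, it computes how many samples and gradient steps the single epoch involves. The variable names are our own.

```python
import math

subset_size = 1000        # train_images[0:1000]
validation_split = 0.1    # as passed to q_aware_model.fit
batch_size = 500

# Keras trains on the first (1 - validation_split) fraction of the samples.
num_train = math.floor(subset_size * (1.0 - validation_split))
num_val = subset_size - num_train
steps_per_epoch = math.ceil(num_train / batch_size)

print("training samples:", num_train)                 # 900
print("validation samples:", num_val)                 # 100
print("gradient steps per epoch:", steps_per_epoch)   # 2
```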

For this example, there is minimal to no loss in test accuracy after quantization aware training, compared to the baseline.

In [7]:
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)

Baseline test accuracy: 0.9570000171661377
Quant test accuracy: 0.9563999772071838
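The sketch below puts a number on "minimal". It converts the two accuracies printed above into a difference in percentage points and into a net count of test images. It assumes the standard 10,000-image MNIST test set.

```python
baseline_acc = 0.9570000171661377   # "Baseline test accuracy" above
qat_acc = 0.9563999772071838        # "Quant test accuracy" above
num_test_images = 10_000            # size of the MNIST test set

diff = baseline_acc - qat_acc
drop_pp = diff * 100                          # percentage points
drop_images = round(diff * num_test_images)   # net change in correct predictions

print(f"Accuracy drop: {drop_pp:.2f} percentage points")        # 0.06
print("Net fewer correctly classified images:", drop_images)    # 6
```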


## Interpreter

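We do not have an edge device available. Instead, we run the TensorFlow Lite models on the host CPU with the Python `tf.lite.Interpreter`. The helper below feeds every test image through the interpreter, one image per call, and returns the classification accuracy.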

In [8]:
import numpy as np

def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for test_image in test_images:
    # Pre-processing: add batch dimension and convert to float32 to match
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with the
    # highest score (the model outputs logits).
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy
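The final `Dense(10)` layer has no activation, and the loss is built with `from_logits=True`. The interpreter output is therefore a vector of logits, not probabilities. Taking `argmax` directly is still correct, because softmax is strictly increasing and keeps the order of the scores. The pure-Python sketch below checks this on made-up logits. The `softmax` and `argmax` helpers are our own, not TensorFlow functions.

```python
import math


def softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    # Index of the largest value.
    return max(range(len(values)), key=lambda i: values[i])


logits = [-1.2, 0.3, 4.1, 2.7, -0.5, 0.0, 1.1, -3.0, 2.2, 0.9]  # made-up scores
probs = softmax(logits)
print("argmax(logits) =", argmax(logits))        # 2
print("argmax(probs)  =", argmax(probs))         # 2
print("sum(probs)     =", round(sum(probs), 6))  # 1.0
```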

## FLOAT-32 BIT Conversion (No Quantization)

In [9]:
converter_f32 = tf.lite.TFLiteConverter.from_keras_model(model)
# No optimizations and no supported_types are set, so the weights stay float32.
tflite_model_f32 = converter_f32.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model_f32)
interpreter.allocate_tensors()

inference_time = []
for i in range(10):
  start = perf_counter()
  test_accuracy = evaluate_model(interpreter)
  stop = perf_counter()
  inference_time.append(stop-start)

for i in range(10):
  print("Inference Time Diff = ", inference_time[i])

print("Mean Time Diff = ", mean(inference_time))

# This TFLite model was converted from the baseline (non-QAT) model,
# so compare it with the baseline Keras accuracy.
print('Float32 TFLite test accuracy:', test_accuracy)
print('Baseline TF test accuracy:', baseline_model_accuracy)

Inference Time Diff =  1.3139581130000124
Inference Time Diff =  1.3142963470000382
Inference Time Diff =  1.3004703240000026
Inference Time Diff =  1.312723329999983
Inference Time Diff =  1.3041038720000415
Inference Time Diff =  1.3025300639999955
Inference Time Diff =  1.2839879800000062
Inference Time Diff =  1.3104669129999706
Inference Time Diff =  1.3276079019999543
Inference Time Diff =  1.2687385360000007
Mean Time Diff =  1.3038883381000006
Float32 TFLite test accuracy: 0.957
Baseline TF test accuracy: 0.9570000171661377


## FLOAT-16 BIT Quantizing

In [10]:
converter_fl16 = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter_fl16.optimizations = [tf.lite.Optimize.DEFAULT]
converter_fl16.target_spec.supported_types = [tf.float16]
quantized_tflite_model_f16 = converter_fl16.convert()

interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model_f16)
interpreter.allocate_tensors()

inference_time = []
for i in range(10):
  start = perf_counter()
  test_accuracy = evaluate_model(interpreter)
  stop = perf_counter()
  inference_time.append(stop-start)
  
for i in range(10):
  print("Inference Time Diff = ", inference_time[i])

print("Mean Time Diff = ", mean(inference_time))

print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)



Inference Time Diff =  0.6119103119999636
Inference Time Diff =  0.5931340540000178
Inference Time Diff =  0.5934894220000047
Inference Time Diff =  0.6123927529999946
Inference Time Diff =  0.5740142430000219
Inference Time Diff =  0.5876860340000007
Inference Time Diff =  0.5845103400000085
Inference Time Diff =  0.6354716249999797
Inference Time Diff =  0.6081682299999898
Inference Time Diff =  0.6215319369999861
Mean Time Diff =  0.6022308949999967
Quant TFLite test_accuracy: 0.9564
Quant TF test accuracy: 0.9563999772071838
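Float16 quantization stores values in IEEE 754 half precision, which takes 2 bytes per value instead of 4. The stdlib sketch below uses Python's `struct` format code `'e'` to show the rounding error this introduces for a few arbitrary sample values. It only illustrates the number format. It does not reproduce what the TFLite converter does internally. The helper name `to_float16` is our own.

```python
import struct


def to_float16(x):
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack('<e', struct.pack('<e', x))[0]


samples = [0.1, 0.0123456, 1.0 / 3.0, 2028.7]
for w in samples:
    h = to_float16(w)
    print(f"{w!r:>22} -> {h!r:<22} (abs error {abs(w - h):.2e})")

# Storage per value: half precision versus single precision.
print(len(struct.pack('<e', 0.1)), "bytes vs", len(struct.pack('<f', 0.1)), "bytes")
```

Half precision keeps about three significant decimal digits. For normal values, the relative rounding error is at most 2^-11.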


## INT-8 BIT Quantizing
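For int8, TensorFlow Lite describes each quantized tensor with an affine mapping: real_value = (int8_value - zero_point) * scale. The pure-Python sketch below shows one way to choose a scale and zero point for an observed range, and how values round-trip through int8. It uses a single per-tensor range, assumed here to be [0, 6] as for ReLU activations. The helper names are our own. As we understand the TFLite 8-bit quantization specification, the real converter differs in details; for example, weights are quantized symmetrically (zero point 0) per channel.

```python
def choose_params(x_min, x_max, qmin=-128, qmax=127):
    """Pick (scale, zero_point) that map [x_min, x_max] onto [qmin, qmax].

    The range is first widened to include 0.0 so that zero is exactly
    representable (useful for ReLU outputs and zero padding).
    """
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    if x_max == x_min:
        raise ValueError("range must not be empty")
    scale = (x_max - x_min) / (qmax - qmin)
    zero_point = round(qmin - x_min / scale)
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


def quantize(x, scale, zero_point, qmin=-128, qmax=127):
    # Python's round() rounds halves to even; a real runtime may differ.
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))  # saturate to the int8 range


def dequantize(q, scale, zero_point):
    # real_value = (int8_value - zero_point) * scale
    return (q - zero_point) * scale


# Example: activations of a ReLU layer observed in the range [0, 6].
scale, zero_point = choose_params(0.0, 6.0)
print("scale =", scale, "zero_point =", zero_point)
for x in [0.0, 0.5, 1.7, 6.0, 9.0]:
    q = quantize(x, scale, zero_point)
    print(f"{x:4.1f} -> int8 {q:4d} -> {dequantize(q, scale, zero_point):.4f}")
```

Values outside the observed range (9.0 here) saturate to the largest code. This is one reason a representative range matters.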

In [11]:
converter_t8 = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter_t8.optimizations = [tf.lite.Optimize.DEFAULT]

quantized_tflite_model_t8 = converter_t8.convert()

interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model_t8)
interpreter.allocate_tensors()

inference_time = []
for i in range(10):
  start = perf_counter()
  test_accuracy = evaluate_model(interpreter)
  stop = perf_counter()
  inference_time.append(stop-start)
  
for i in range(10):
  print("Inference Time Diff = ", inference_time[i])

print("Mean Time Diff = ", mean(inference_time))

print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)



Inference Time Diff =  0.6110938000000488
Inference Time Diff =  0.5803180810000299
Inference Time Diff =  0.5870856820000085
Inference Time Diff =  0.597313865999979
Inference Time Diff =  0.6075698110000189
Inference Time Diff =  0.5936725999999908
Inference Time Diff =  0.598558799999978
Inference Time Diff =  0.5820988860000398
Inference Time Diff =  0.5880661200000077
Inference Time Diff =  0.5829093030000081
Mean Time Diff =  0.5928686949000109
Quant TFLite test_accuracy: 0.9564
Quant TF test accuracy: 0.9563999772071838
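The sketch below lines up the timings by dividing the mean TFLite float32 time by each mean time recorded above. These are single-session measurements on the Colab host CPU, so treat the ratios as rough. `evaluate_model` calls `invoke()` once per image, while `model.evaluate` runs in batches, so the Keras row is not directly comparable. Speed-ups on a real edge device may also differ.

```python
# Mean wall-clock time (seconds) for one pass over the 10,000 test images,
# copied from the "Mean Time Diff" outputs above.
mean_times = {
    "Keras model.evaluate": 1.1569864757000006,
    "TFLite float32": 1.3038883381000006,
    "TFLite float16 (from QAT model)": 0.6022308949999967,
    "TFLite int8 (from QAT model)": 0.5928686949000109,
}

reference = mean_times["TFLite float32"]
for name, t in mean_times.items():
    print(f"{name:<32} {t:.3f} s   speed-up vs TFLite float32: {reference / t:.2f}x")
```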


## Analyzing Model Size

In [12]:
tflite_models_dir = pathlib.Path("./mnist_tflite_models/")
tflite_models_dir.mkdir(exist_ok=True, parents=True)

tflite_model_file = tflite_models_dir/"mnist_model_f32.tflite"
tflite_model_file.write_bytes(tflite_model_f32)

tflite_model_file = tflite_models_dir/"mnist_model_quant_f16.tflite"
tflite_model_file.write_bytes(quantized_tflite_model_f16)

tflite_model_file = tflite_models_dir/"mnist_model_quant_t8.tflite"
tflite_model_file.write_bytes(quantized_tflite_model_t8)

!ls -lh {tflite_models_dir}

total 140K
-rw-r--r-- 1 root root 83K Dec  6 11:52 mnist_model_f32.tflite
-rw-r--r-- 1 root root 25K Dec  6 11:52 mnist_model_quant_f16.tflite
-rw-r--r-- 1 root root 25K Dec  6 11:52 mnist_model_quant_t8.tflite
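The file sizes can be sanity-checked against the parameter count. The sketch below recomputes the baseline model's layer shapes and parameter counts from the architecture defined earlier. About 80 KiB of float32 weights lines up with the 83K float32 file, and about 20 KiB of int8 weights with the 25K int8 file. We assume the remainder is model structure and metadata; we did not measure it. A float16 model would need about 40 KiB for its weights, so the 25K size of `mnist_model_quant_f16.tflite` suggests its weights were stored as int8 as well. Our unverified reading is that this converter also starts from the quantization aware `q_aware_model`, whose fake-quantization information leads to int8 weights regardless of `supported_types`. The float16 model's accuracy matching the int8 model's (0.9564) is consistent with this.

```python
# Shapes of the baseline model defined earlier (28x28 grayscale input).
image_size, kernel, in_channels, filters, classes = 28, 3, 1, 12, 10

conv_out = image_size - kernel + 1   # 'valid' padding (Conv2D default): 26
pooled = conv_out // 2               # MaxPooling2D(pool_size=(2, 2)): 13
flat = pooled * pooled * filters     # Flatten: 2028

conv_params = kernel * kernel * in_channels * filters + filters   # 120
dense_params = flat * classes + classes                            # 20290
total_params = conv_params + dense_params                          # 20410

# Approximate weight storage per element size (ignores file overhead).
for dtype, bytes_per_value in [("float32", 4), ("float16", 2), ("int8", 1)]:
    size = total_params * bytes_per_value
    print(f"{dtype:>7}: {size:6d} bytes (~{size / 1024:.0f} KiB)")
```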


In [13]:
# Measure the size of each TFLite model in MB by writing it to a temporary file.
# Reuses the models converted above; the int8 model is quantized_tflite_model_t8.
_, float_file = tempfile.mkstemp('.tflite')
_, quant_file = tempfile.mkstemp('.tflite')
_, quant_file_f16 = tempfile.mkstemp('.tflite')

with open(quant_file, 'wb') as f:
  f.write(quantized_tflite_model_t8)

with open(float_file, 'wb') as f:
  f.write(tflite_model_f32)

with open(quant_file_f16, 'wb') as f:
  f.write(quantized_tflite_model_f16)

print("Float model in MB:", os.path.getsize(float_file) / float(2**20))
print("Quantized int8 model in MB:", os.path.getsize(quant_file) / float(2**20))
print("Float f16 model in MB:", os.path.getsize(quant_file_f16) / float(2**20))