##### Copyright 2020 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Quantization aware training in Keras example

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/model_optimization/guide/quantization/training_example"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/quantization/training_example.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/quantization/training_example.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/model-optimization/tensorflow_model_optimization/g3doc/guide/quantization/training_example.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

## Overview

Welcome to an end-to-end example for *quantization aware training*.

### Other pages
For an introduction to what quantization aware training is and to determine if you should use it (including what's supported), see the [overview page](https://www.tensorflow.org/model_optimization/guide/quantization/training.md).

To quickly find the APIs you need for your use case (beyond fully-quantizing a model with 8-bits), see the
[comprehensive guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md).

### Summary

In this tutorial, you will:

1.   Train a `tf.keras` model for MNIST from scratch.
2.   Fine tune the model by applying the quantization aware training API, see the accuracy, and
     export a quantization aware model.
3.   Use the model to create an actually quantized model for the TFLite
     backend.
4.   See the persistence of accuracy in
     TFLite and a 4x smaller model. To see the latency benefits on mobile, try out the TFLite examples [in the TFLite app repository](https://www.tensorflow.org/lite/models).

## Setup

In [24]:
import subprocess
import pkg_resources

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_datasets as tfds

# Desired versions
tf_version = "2.10.0"
hub_version = "0.12.0"
datasets_version = "4.6.0"
tfmot_opt_version = "0.7.3"

# Function to install or upgrade a package to the desired version
def install_package(package_name, version):
    """Install ``package_name==version`` into the environment running this kernel.

    pip is invoked as ``sys.executable -m pip`` rather than through a bare
    ``pip`` executable, so the package is installed for the same interpreter
    the notebook uses, even when several Python installations are on PATH.

    Args:
        package_name: Distribution name as understood by pip.
        version: Exact version string to pin.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    import sys  # local import keeps this helper self-contained

    print(f"Installing {package_name}=={version}...")
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", f"{package_name}=={version}"]
    )

# Check whether TensorFlow, TensorFlow Hub and TensorFlow Datasets are at the
# pinned versions. (tensorflow-model-optimization is checked separately below.)
import sys
from importlib import metadata

if (tf.__version__ != tf_version or
    hub.__version__ != hub_version or
    tfds.__version__ != datasets_version):

    print(f"Current TensorFlow version: {tf.__version__}, switching to {tf_version}")
    print(f"Current TensorFlow Hub version: {hub.__version__}, switching to {hub_version}")
    print(f"Current TensorFlow Datasets version: {tfds.__version__}, switching to {datasets_version}")

    # Uninstall current versions of TensorFlow, TensorFlow Hub, and TensorFlow
    # Datasets. Run pip through this kernel's own interpreter (instead of the
    # `!pip` shell escape) so the same environment that is later reinstalled
    # into is the one being modified.
    subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y",
                           "tensorflow", "tensorflow_hub", "tensorflow_datasets"])

    # Install desired versions
    install_package("tensorflow", tf_version)
    install_package("tensorflow_hub", hub_version)
    install_package("tensorflow_datasets", datasets_version)

    print("Specified versions of TensorFlow, TensorFlow Hub, and TensorFlow Datasets installed.")
    # Already-imported modules keep their old versions until the kernel restarts.
    print("Please click on the Runtime > Restart session and run all.")

else:
    print("All packages (TensorFlow, TensorFlow Hub, and TensorFlow Datasets) are already at the specified versions.")

# Check for tensorflow-model-optimization. importlib.metadata replaces the
# deprecated pkg_resources API for reading an installed distribution's version.
# Only the lookup is inside `try`, so errors raised while installing are not
# mistaken for "package not found".
try:
    package_version = metadata.version("tensorflow_model_optimization")
except metadata.PackageNotFoundError:
    print(f"tensorflow-model-optimization not found, installing version {tfmot_opt_version}.")
    install_package("tensorflow-model-optimization", tfmot_opt_version)
else:
    if package_version != tfmot_opt_version:
        print(f"Current tensorflow-model-optimization version: {package_version}, upgrading to {tfmot_opt_version}")
        # Uninstall and install the correct version
        subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y",
                               "tensorflow-model-optimization"])
        install_package("tensorflow-model-optimization", tfmot_opt_version)
    else:
        print(f"tensorflow-model-optimization {tfmot_opt_version} is already installed.")


All packages (TensorFlow, TensorFlow Hub, and TensorFlow Datasets) are already at the specified versions.
tensorflow-model-optimization 0.7.3 is already installed.


In [25]:
import tempfile
import os

import tensorflow as tf

from tensorflow import keras

## Train a model for MNIST without quantization aware training

In [32]:
# Seed Python, NumPy and TensorFlow so weight initialization and the per-epoch
# shuffling in `fit` repeat across "Restart & Run All". (Some GPU kernels can
# still be nondeterministic, so tiny differences remain possible.)
SEED = 42
tf.keras.utils.set_random_seed(SEED)

# Load MNIST dataset
mnist = keras.datasets.mnist

# The MNIST database of handwritten digits has a training set of 60,000
# examples and a test set of 10,000 examples.
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  # No softmax: the model outputs logits, matching `from_logits=True` below.
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

# Train on the training set; the last 10% of it is held out for validation.
model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)

model.summary()

Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 reshape_4 (Reshape)         (None, 28, 28, 1)         0         
                                                                 
 conv2d_4 (Conv2D)           (None, 26, 26, 12)        120       
                                                                 
 max_pooling2d_4 (MaxPooling  (None, 13, 13, 12)       0         
 2D)                                                             
                                                                 
 flatten_4 (Flatten)         (None, 2028)              0         
                                                                 
 dense_4 (Dense)             (None, 10)                20290     
                                                                 
Total params: 20,410
Trainable params: 20,410
Non-trainable params: 0
__________________________________________________

## Clone and fine-tune pre-trained model with quantization aware training


### Define the model

You will apply quantization aware training to the whole model and see this in the model summary. All layers are now prefixed by "quant".

Note that the resulting model is quantization aware but not quantized (e.g. the weights are float32 instead of int8). The sections after show how to create a quantized model from the quantization aware one.

In the [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/quantization/training_comprehensive_guide.md), you can see how to quantize some layers for model accuracy improvements.

In [30]:
# Quantization aware training is provided by the TensorFlow Model Optimization
# Toolkit (tfmot), which the setup cell installs.
import tensorflow_model_optimization as tfmot

# `quantize_model` applies quantization aware training to the whole Keras
# model, producing a new quantization aware ("q_aware") model.
quantize_model = tfmot.quantization.keras.quantize_model
q_aware_model = quantize_model(model)

# The model returned by `quantize_model` must be compiled again before
# training; use the same optimizer, loss and metric as the baseline.
q_aware_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Every layer name in the summary now carries a "quant" prefix.
q_aware_model.summary()

Model: "sequential_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 quantize_layer_5 (QuantizeL  (None, 28, 28)           3         
 ayer)                                                           
                                                                 
 quant_reshape_3 (QuantizeWr  (None, 28, 28, 1)        1         
 apperV2)                                                        
                                                                 
 quant_conv2d_3 (QuantizeWra  (None, 26, 26, 12)       147       
 pperV2)                                                         
                                                                 
 quant_max_pooling2d_3 (Quan  (None, 13, 13, 12)       1         
 tizeWrapperV2)                                                  
                                                                 
 quant_flatten_3 (QuantizeWr  (None, 2028)            

### Train and evaluate the model against baseline

To demonstrate fine-tuning after training the model for just one epoch, fine-tune it with quantization aware training on a subset of the training data (the first 1,000 of the 60,000 training images).

In [33]:
# Fine-tune the quantization aware model on a small slice of the 60,000
# training examples to keep the demonstration fast.
QAT_SUBSET_SIZE = 1000
train_images_subset = train_images[:QAT_SUBSET_SIZE]
train_labels_subset = train_labels[:QAT_SUBSET_SIZE]

# With validation_split=0.1, the last 100 examples are held out for validation
# and the remaining 900 are trained on in batches of up to 500.
q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)



<keras.callbacks.History at 0x798dce0162c0>

For this example, there is minimal to no loss in test accuracy after quantization aware training, compared to the baseline.

In [35]:
# Evaluate both models on the same full 10,000-image test set so their
# accuracies are directly comparable. Unpack into named variables rather than
# `_`: assigning `_` shadows IPython's last-output variable for the session.
baseline_model_loss, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

q_aware_model_loss, q_aware_model_accuracy = q_aware_model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)

# Expect the two accuracies to be close; the exact values depend on training.

Baseline test accuracy: 0.9577000141143799
Quant test accuracy: 0.9629999995231628


## Create quantized model for TFLite backend

After this, you have an actually quantized model with int8 weights and uint8 activations.

In [36]:
# Convert the quantization aware Keras model to TFLite. Enabling the DEFAULT
# optimization is what makes the converter emit an actually quantized model
# rather than a float one.
converter = tf.lite.TFLiteConverter.from_keras_model(q_aware_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]

# `convert()` returns the serialized TFLite model, used below both in memory
# and written to disk.
quantized_tflite_model = converter.convert()



## See persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TF Lite model on the test dataset.

In [38]:
import numpy as np

def evaluate_model(interpreter, images=None, labels=None, log_every=1000):
  """Compute the classification accuracy of a TFLite model.

  Images are fed through the interpreter one at a time (batch size 1), and the
  predicted digit is the arg-max of the model's output vector.

  Args:
    interpreter: A `tf.lite.Interpreter` whose tensors are already allocated
      (`allocate_tensors()`); its first input is fed a float32 batch holding
      one image.
    images: Images to classify. Defaults to the notebook's global
      `test_images`.
    labels: Ground-truth labels aligned with `images`. Defaults to the
      notebook's global `test_labels`.
    log_every: Print a progress line every `log_every` images; 0 or None
      disables progress output.

  Returns:
    The fraction of images whose predicted digit equals its label, as a float.

  Raises:
    ValueError: If `images` and `labels` differ in length, or are empty.
  """
  if images is None:
    images = test_images
  if labels is None:
    labels = test_labels
  if len(images) != len(labels):
    raise ValueError(
        'images and labels must have the same length, got {} and {}.'.format(
            len(images), len(labels)))
  if len(images) == 0:
    raise ValueError('Cannot evaluate on an empty set of images.')

  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image.
  prediction_digits = np.empty(len(images), dtype=np.int64)
  for i, image in enumerate(images):
    if log_every and i % log_every == 0:
      print('Evaluated on {n} results so far.'.format(n=i))
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    batch = np.expand_dims(image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, batch)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove the batch dimension and take the digit with the
    # highest score. The models in this notebook output logits; argmax over
    # logits equals argmax over the softmax probabilities. `get_tensor`
    # returns a copy, so no view into the interpreter's buffers is retained.
    prediction_digits[i] = np.argmax(interpreter.get_tensor(output_index)[0])

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  return float(np.mean(prediction_digits == np.asarray(labels)))

You evaluate the quantized model and see that the accuracy from TensorFlow persists to the TFLite backend.

In [40]:
# tf.lite.Interpreter executes a TFLite model; `model_content` loads the
# quantized model directly from the in-memory bytes produced by the converter.
interpreter = tf.lite.Interpreter(model_content=quantized_tflite_model)
# Tensors must be allocated before inputs can be set or inference can run.
interpreter.allocate_tensors()

test_accuracy = evaluate_model(interpreter)

print('Quant TFLite test_accuracy:', test_accuracy)
print('Quant TF test accuracy:', q_aware_model_accuracy)

# The TFLite accuracy should closely match the TensorFlow accuracy of the
# quantization aware model.

Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Quant TFLite test_accuracy: 0.963
Quant TF test accuracy: 0.9629999995231628


## See 4x smaller model from quantization

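Create a float TFLite model and compare its file size with the quantized TFLite model. Because int8 weights take a quarter of the space of float32 weights, the quantized model approaches a 4x size reduction. Fixed per-file overhead makes the measured ratio somewhat lower for a model this small.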

In [41]:
# Create a float TFLite model from the baseline (non-quantized) Keras model.
float_converter = tf.lite.TFLiteConverter.from_keras_model(model)
float_tflite_model = float_converter.convert()


def _write_temp_tflite(model_bytes):
  """Write serialized TFLite model bytes to a new temp file and return its path.

  `tempfile.mkstemp` returns an already-open OS file descriptor. Wrapping it
  with `os.fdopen` lets the `with` block close it; reopening the path with
  `open()` would leak one descriptor per call. The file is left on disk.
  """
  fd, path = tempfile.mkstemp('.tflite')
  with os.fdopen(fd, 'wb') as f:
    f.write(model_bytes)
  return path


# Measure sizes of models.
float_file = _write_temp_tflite(float_tflite_model)
quant_file = _write_temp_tflite(quantized_tflite_model)

# Sizes are reported in units of 2**20 bytes.
print("Float model in Mb:", os.path.getsize(float_file) / float(2**20))
print("Quantized model in Mb:", os.path.getsize(quant_file) / float(2**20))



Float model in Mb: 0.080963134765625
Quantized model in Mb: 0.023895263671875


## Conclusion

In this tutorial, you saw how to create quantization aware models with the TensorFlow Model Optimization Toolkit API and then quantized models for the TFLite backend.

You saw a substantial model size reduction for an MNIST model, with minimal accuracy difference. The reduction was about 3.4x in the run above, approaching the 4x ratio between float32 and int8 weights. To see the latency benefits on mobile, try out the TFLite examples [in the TFLite app repository](https://www.tensorflow.org/lite/models).

We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.
