# Post-Training Quantization in Keras using the Model Compression Toolkit (MCT)
[Run this tutorial in Google Colab](https://colab.research.google.com/github/sony/model_optimization/blob/main/tutorials/notebooks/mct_features_notebooks/keras/example_keras_post_training_quantization.ipynb)

## Overview
This quick-start guide explains how to use the **Model Compression Toolkit (MCT)** to quantize a Keras model. We will load a pre-trained model and quantize it using MCT with **Post-Training Quantization (PTQ)**. Finally, we will evaluate the quantized model and export it to Keras and TFLite files.

## Summary
In this tutorial, we will cover:

1. Loading and preprocessing the Imagenette dataset using the TensorFlow Datasets package.
2. Constructing an unlabeled representative dataset.
3. Hardware-Friendly Post-Training Quantization using MCT.
4. Accuracy evaluation of the floating-point and the quantized models.
5. Exporting the model to Keras and TFLite files.

## Setup
Install the relevant packages:

In [None]:
TF_VER = '2.14'
!pip install -q tensorflow[and-cuda]~={TF_VER} tensorflow-datasets

In [None]:
import importlib
if not importlib.util.find_spec('model_compression_toolkit'):
    !pip install model_compression_toolkit

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
from tqdm import tqdm

Load a pre-trained MobileNetV2 model from Keras in 32-bit floating-point precision.

In [None]:
from keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input

float_model = MobileNetV2()

## Dataset preparation
### Download Imagenette validation set
For this demonstration, we will use the Imagenette dataset, a subset of 10 easily classified classes from the larger ImageNet dataset.

**Note** that for demonstration purposes we use the validation set for the model quantization routines. Typically, a subset of the training dataset is used, but loading it is a heavy procedure that is unnecessary for this example.

Load the Imagenette validation dataset using tensorflow-datasets:

In [None]:
imagenette_class_indices = [0, 217, 482, 491, 497, 566, 569, 571, 574, 701]  # ImageNet indices for Imagenette classes

# Load the Imagenette validation split
imagenet_val_ds, info = tfds.load('imagenette', split='validation', data_dir='./imagenette', with_info=True, as_supervised=True)

# Preprocess the dataset
img_size = 224  # Model's expected input size
batch_size = 50

def preprocess_image(image, label, img_size):
    image = tf.image.resize(image, (img_size, img_size))
    image = preprocess_input(image)  # Preprocess using MobileNetV2's preprocessing
    return image, label

val_ds = imagenet_val_ds.map(lambda img, lbl: preprocess_image(img, lbl, img_size), num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.batch(batch_size).prefetch(buffer_size=tf.data.AUTOTUNE)
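For reference, Keras's MobileNetV2 `preprocess_input` uses the `"tf"` preprocessing mode, which scales pixel values from [0, 255] to [-1, 1]. The plain-Python sketch below mirrors that arithmetic on a few pixel values. The helper name `scale_to_unit_range` is hypothetical; it only illustrates the transformation and is not a replacement for `preprocess_input`.

```python
def scale_to_unit_range(pixels):
    """Map pixel values in [0, 255] to [-1, 1] (x / 127.5 - 1), as MobileNetV2 preprocessing does."""
    return [p / 127.5 - 1.0 for p in pixels]

print(scale_to_unit_range([0, 127.5, 255]))  # [-1.0, 0.0, 1.0]
```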

## Representative Dataset
MCT's PTQ algorithm needs a representative dataset to calibrate the quantization parameters (for example, activation ranges). It is provided as a generator function that, on each iteration, yields a list containing one batch of input images:

In [None]:
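n_iter = 5
representative_batch_size = 16

# val_ds is already batched (50 images per batch), so re-batch it for calibration
representative_ds = val_ds.unbatch().batch(representative_batch_size)
numpy_dataset = tfds.as_numpy(representative_ds)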

def representative_dataset_gen():
    dataloader_iter = iter(numpy_dataset)
    for _ in range(n_iter):
        yield [next(dataloader_iter)[0]]
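The generator is invoked by MCT during calibration. Each call yields `n_iter` items, each a list holding one batch of model inputs (MobileNetV2 has a single input; the labels are dropped). The sketch below swaps `numpy_dataset` for a hypothetical in-memory list of `(images, labels)` tuples, using strings as placeholders so it runs without TensorFlow. The `toy_` names are illustrative and chosen so they do not overwrite the notebook's variables. It also shows that every call restarts from the first batch, because the generator creates a fresh iterator each time.

```python
# Hypothetical stand-in for numpy_dataset: 8 batches of (images, labels),
# with string tags instead of arrays so the example stays dependency-free.
toy_numpy_dataset = [(f"images_{i}", f"labels_{i}") for i in range(8)]
toy_n_iter = 5

def toy_representative_gen():
    # A fresh iterator on every call, so each calibration pass starts at batch 0
    dataloader_iter = iter(toy_numpy_dataset)
    for _ in range(toy_n_iter):
        # Keep only the inputs (element 0), wrapped in a list: one entry per model input
        yield [next(dataloader_iter)[0]]

first_pass = list(toy_representative_gen())
second_pass = list(toy_representative_gen())
print(first_pass)  # [['images_0'], ['images_1'], ['images_2'], ['images_3'], ['images_4']]
```

Note that `n_iter` must not exceed the number of available batches. Otherwise, `next` raises `StopIteration` inside the generator, which Python turns into a `RuntimeError`.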


## Target Platform Capabilities
MCT optimizes the model for dedicated hardware using Target Platform Capabilities (TPC), which describe the quantization capabilities of the target device (for more details, please visit our [documentation](https://sony.github.io/model_optimization/docs/api/api_docs/modules/target_platform.html)). Here, we use the default TensorFlow TPC:

In [None]:
import model_compression_toolkit as mct

# Get a TargetPlatformCapabilities object that models the hardware on which the quantized model will run. Here we use the default platform for Keras (TensorFlow) models.
target_platform_cap = mct.get_target_platform_capabilities('tensorflow', 'default')

## Hardware-Friendly Post-Training Quantization using MCT
Now for the exciting part! Let’s run hardware-friendly PTQ on the model. 
**Hardware-friendly** means symmetric quantization with power-of-2 thresholds.
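To make this definition concrete, here is a simplified sketch of symmetric, signed 8-bit quantize-dequantize with a power-of-two threshold. It assumes the threshold is the smallest power of two covering the largest absolute value (no clipping). MCT's own threshold selection is not reproduced here, and the helper names `pot_threshold` and `fake_quantize` are illustrative only. Because both the threshold and the integer range are powers of two, the step size is also a power of two. This is one common reason such schemes are considered hardware-friendly: rescaling an integer by the step size becomes a bit shift.

```python
import math

def pot_threshold(values):
    """Smallest power of two >= max |value| (a simple no-clipping choice)."""
    max_abs = max(abs(v) for v in values)
    if max_abs == 0:
        return 1.0  # avoid log2(0) for an all-zero tensor
    return 2.0 ** math.ceil(math.log2(max_abs))

def fake_quantize(values, n_bits=8):
    """Symmetric signed quantize-dequantize with a power-of-two threshold."""
    threshold = pot_threshold(values)
    scale = threshold / 2 ** (n_bits - 1)  # step size, itself a power of two
    q_min, q_max = -2 ** (n_bits - 1), 2 ** (n_bits - 1) - 1
    ints = [min(max(round(v / scale), q_min), q_max) for v in values]
    return ints, [q * scale for q in ints], scale

weights = [0.3, -1.2, 0.05, 1.7]
ints, dequantized, scale = fake_quantize(weights)
print(scale)  # 0.015625 (threshold 2.0 / 128)
print(ints)   # [19, -77, 3, 109]

# With a power-of-two step, multiplying an integer by 2**-k is a right shift
# (Python's >> floors, so it matches math.floor for negative values too).
print(-77 >> 3, math.floor(-77 * 2 ** -3))  # -10 -10
```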

In [None]:
quantized_model, quantization_info = mct.ptq.keras_post_training_quantization(
        in_model=float_model,
        representative_data_gen=representative_dataset_gen,
        target_platform_capabilities=target_platform_cap
)

Our model is now quantized. MCT has created a simulated quantized model within the original Keras framework by inserting [quantization representation modules](https://github.com/sony/mct_quantizers). These modules, such as `KerasQuantizationWrapper` and `KerasActivationQuantizationHolder`, wrap Keras layers to simulate the quantization of weights and activations, respectively. The simulated model still stores its values in floating point, so the size of the saved model remains unchanged; however, all the quantization parameters are stored within these modules, ready for deployment on the target hardware. In this example, we used the default MCT settings, which quantize the model from 32 bits to 8 bits, a compression ratio of 4x.
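The 4x figure follows directly from the bit widths: each weight shrinks from 32 to 8 bits. The sketch below does this arithmetic for an assumed count of 3.5 million parameters, roughly MobileNetV2's size as listed in the Keras Applications table. In the notebook, `float_model.count_params()` gives the exact count. The sketch ignores the small overhead of storing quantization parameters, and it describes the footprint on the target hardware, not the size of the simulated Keras model.

```python
def weight_memory_mb(num_params, bits_per_weight):
    """Weight storage in megabytes (10**6 bytes)."""
    return num_params * bits_per_weight / 8 / 1e6

num_params = 3_500_000  # approximate MobileNetV2 parameter count (assumption)
float_mb = weight_memory_mb(num_params, 32)
int8_mb = weight_memory_mb(num_params, 8)
print(f"{float_mb:.1f} MB -> {int8_mb:.1f} MB, ratio {float_mb / int8_mb:.0f}x")
# 14.0 MB -> 3.5 MB, ratio 4x
```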

## Model Evaluation
To evaluate our models, we first define a function that measures the accuracy of a Keras model trained on ImageNet on the Imagenette dataset.

In [None]:
def evaluate_imagenette_model(model, dataset, class_indices):

    # Initialize variables to store predictions and ground truth labels
    correct_predictions = 0
    total_samples = 0

    for images, true_labels in tqdm(dataset, total=len(dataset), desc="Evaluating"):
        # Make predictions with the pre-trained model
        preds = model.predict(images, verbose=0)

        # Extract predictions only for the Imagenette classes
        imagenette_preds = preds[:, class_indices]

        # Map predictions to the highest scoring Imagenette class
        predicted_classes = np.argmax(imagenette_preds, axis=1)

        # Compare predictions with true labels
        correct_predictions += np.sum(predicted_classes == true_labels.numpy())
        total_samples += true_labels.shape[0]

    # Calculate accuracy
    accuracy = correct_predictions / total_samples
    return accuracy
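The key step in `evaluate_imagenette_model` restricts the 1000-way ImageNet prediction to the 10 Imagenette columns. The position of the best-scoring column within `class_indices` becomes the predicted Imagenette label, so the order of `imagenette_class_indices` must match the dataset's label order. The toy sketch below reproduces that logic in plain Python with a hypothetical 6-class "ImageNet" and 3 selected classes. The function and variable names are illustrative.

```python
def subset_accuracy(preds, true_labels, class_indices):
    """Accuracy after restricting each prediction to the columns in class_indices."""
    correct = 0
    for scores, label in zip(preds, true_labels):
        subset = [scores[i] for i in class_indices]  # like preds[:, class_indices]
        # Index of the first maximum, like np.argmax
        predicted = max(range(len(subset)), key=subset.__getitem__)
        correct += int(predicted == label)
    return correct / len(true_labels)

toy_class_indices = [1, 3, 4]  # hypothetical subset of a 6-class label space
toy_preds = [
    [0.0, 0.7, 0.1, 0.1, 0.1, 0.0],  # best selected column is 1 -> label 0 (correct)
    [0.5, 0.1, 0.0, 0.3, 0.1, 0.0],  # class 0 is ignored; best is 3 -> label 1 (correct)
    [0.0, 0.2, 0.0, 0.1, 0.1, 0.6],  # class 5 is ignored; best is 1 -> label 0 (wrong)
]
toy_labels = [0, 1, 2]
print(subset_accuracy(toy_preds, toy_labels, toy_class_indices))  # 0.6666666666666666
```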

Let's start with the floating-point model evaluation.

In [None]:
float_accuracy = evaluate_imagenette_model(float_model, val_ds, imagenette_class_indices)
print(f"Float model's accuracy on Imagenette: {(float_accuracy * 100):.2f}%")

Next, let's evaluate the quantized model:

In [None]:
quant_accuracy = evaluate_imagenette_model(quantized_model, val_ds, imagenette_class_indices)
print(f"Quantized model's accuracy on Imagenette: {(quant_accuracy * 100):.2f}%")

The quantized model shows only a small accuracy degradation compared to the floating-point model, while achieving a 4x compression ratio.
Now, we can export the quantized model to Keras and TFLite:

In [None]:
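# Export to TFLite; FAKELY_QUANT keeps the quantized weights and activations in a float data type
mct.exporter.keras_export_model(
    model=quantized_model,
    save_model_path='qmodel.tflite',
    serialization_format=mct.exporter.KerasExportSerializationFormat.TFLITE,
    quantization_format=mct.exporter.QuantizationFormat.FAKELY_QUANT)

# Export to a .keras file using the default export settings
mct.exporter.keras_export_model(model=quantized_model, save_model_path='qmodel.keras')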

## Conclusion

In this tutorial, we demonstrated how to quantize a classification model in a hardware-friendly manner using MCT. We observed that a 4x compression ratio was achieved with minimal performance degradation.

The key advantage of hardware-friendly quantization is that the model can run more efficiently in terms of runtime, power consumption, and memory usage on designated hardware.

While this was a simple model and task, MCT can deliver competitive results across a wide range of tasks and network architectures. For more details, check out the [paper](https://arxiv.org/abs/2109.09113).

## Copyrights

Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
