# Model Compression Toolkit (MCT) Wrapper API: Comprehensive Quantization Comparison (TensorFlow)

[Run this tutorial in Google Colab](https://colab.research.google.com/github/SonySemiconductorSolutions/mct-model-optimization/blob/main/tutorials/notebooks/mct_features_notebooks/keras/example_keras_wrapper.ipynb)

## Overview
This notebook demonstrates the MCT (Model Compression Toolkit) Wrapper API on a Keras MobileNetV2 model using five quantization methods: PTQ (Post-Training Quantization), PTQ with Mixed Precision, GPTQ (Gradient-based PTQ), GPTQ with Mixed Precision, and LQ-PTQ (Low-bit Quantizer PTQ). For each method it shows the configuration used, then compares the Top-1 accuracy of the quantized model with the floating-point baseline on the ImageNet validation set. All methods go through the same `MCTWrapper.quantize_and_export` call, so they differ only in its arguments and parameter list. The LQ-PTQ run is commented out by default.

## Summary
1. **Environment Setup**: Import required libraries and configure MCT with MobileNetV2 model
2. **Dataset Preparation**: Load and prepare ImageNet validation dataset with representative data generation
3. **PTQ Implementation**: Execute basic Post-Training Quantization with 8-bit precision and bias correction
4. **PTQ + Mixed Precision**: Allocate per-layer bit-widths based on layer sensitivity, with `weights_compression_ratio=0.75` (a weight-memory target of 75% of the original size)
5. **GPTQ Implementation**: Perform gradient-based optimization with 5-epoch fine-tuning for higher accuracy
6. **GPTQ + Mixed Precision**: Combine gradient-based optimization with mixed precision to balance accuracy and compression
7. **LQ-PTQ Implementation**: Run ultra-low-bit quantization (2-4 bits), which requires a specific converter version (disabled by default)
8. **Performance Evaluation**: Measure Top-1 accuracy on the ImageNet validation set for the float model and each quantized model
9. **Results Analysis**: Compare each method's accuracy against the float baseline in light of its compression settings
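
As a quick reference for the steps above, the five quantization functions defined later differ only in three arguments passed to `MCTWrapper.quantize_and_export` (the framework is always `'tensorflow'`) plus their `param_items` lists. The plain-Python snippet below restates those three arguments from the notebook's own code so the differences are visible at a glance; it does not import or call MCT.

```python
# (method, use_MCT_TPC, use_MixP) as passed to quantize_and_export by each function in this notebook.
METHOD_CONFIGS = {
    "PTQ_Keras": ("PTQ", True, False),
    "PTQ_Keras_MixP": ("PTQ", True, True),
    "GPTQ_Keras": ("GPTQ", False, False),
    "GPTQ_Keras_MixP": ("GPTQ", False, True),
    "LQPTQ_Keras": ("LQPTQ", False, False),
}

for name, (method, use_mct_tpc, use_mixp) in METHOD_CONFIGS.items():
    tpc = "MCT built-in TPC" if use_mct_tpc else "external TPC"
    precision = "mixed precision" if use_mixp else "fixed bit-width"
    print(f"{name:16s} method={method:6s} {tpc:17s} {precision}")
```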

## Setup

In [None]:
# Install TensorFlow if needed (IPython shell command; harmless if it is already installed)
!pip install -q tensorflow

# Import required libraries for deep learning and file handling
import os
import tensorflow as tf
import keras
from keras.applications.mobilenet_v2 import MobileNetV2
from typing import Tuple

In [None]:
# Install MCT if it is not already available
import importlib.util
if not importlib.util.find_spec('model_compression_toolkit'):
    !pip install model_compression_toolkit

# To use a local checkout of MCT instead of the installed package, add its path first, e.g.:
# import sys
# sys.path.append('/path/to/mct-model-optimization')

# Import MCT core
import model_compression_toolkit as mct
from model_compression_toolkit.core import QuantizationErrorMethod

## Dataset preparation
Download only the validation split of the ImageNet dataset.

**Note** that for demonstration purposes we use the validation set for the model quantization routines. Usually, a subset of the training dataset is used, but loading it is a heavy procedure that is unnecessary for the sake of this demonstration.

This step may take several minutes...

In [None]:
# Download the ImageNet validation set and devkit if not already present
if not os.path.isdir('imagenet'):
    os.makedirs('imagenet')

    # wget -P saves both archives directly into imagenet/, so no separate move step is needed
    os.system('wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz')
    os.system('wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar')

In [None]:
# Set up the ImageNet validation directory structure if it does not exist yet.
# This creates the one-subdirectory-per-class layout expected by TensorFlow's image_dataset_from_directory.
if not os.path.isdir('imagenet/val'):
    import subprocess

    # Location of the preparation script when this notebook runs from its place inside the MCT repository
    script = '../../../resources/scripts/prepare_imagenet.sh'

    if not os.path.isfile(script):
        # Otherwise (e.g. in Colab), clone the MCT repository temporarily to access the setup scripts.
        # Assumed script location inside the clone, derived from the notebook's path in the repository.
        subprocess.run(['git', 'clone', 'https://github.com/sony/model_optimization.git', 'temp_mct'], check=True)
        script = 'temp_mct/tutorials/resources/scripts/prepare_imagenet.sh'

    # Make the ImageNet preparation script executable
    os.chmod(script, 0o755)

    # Run the script to extract the archives and organize the images into class subdirectories
    subprocess.run([script], check=True)

def imagenet_preprocess_input(images: tf.Tensor, labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """
    Apply MobileNetV2-specific preprocessing to input images.
    
    This function normalizes pixel values according to MobileNetV2 requirements,
    ensuring consistent input format for the model.
    
    Args:
        images: Input image tensor
        labels: Corresponding label tensor
        
    Returns:
        Tuple of preprocessed images and unchanged labels
    """
    return tf.keras.applications.mobilenet_v2.preprocess_input(images), labels

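def get_dataset(batch_size: int, shuffle: bool) -> tf.data.Dataset:
    """
    Build a batched, preprocessed ImageNet validation dataset.

    Images are read from ./imagenet/val (one subdirectory per class, integer labels),
    resized to 224x224 without distortion (cropping to the target aspect ratio),
    and normalized with MobileNetV2 preprocessing.

    Args:
        batch_size: Number of images per batch
        shuffle: Whether to shuffle the file order

    Returns:
        A prefetched tf.data.Dataset of (images, labels) batches
    """
    dataset = tf.keras.utils.image_dataset_from_directory(
        directory='./imagenet/val',
        batch_size=batch_size,
        image_size=[224, 224],
        shuffle=shuffle,
        crop_to_aspect_ratio=True,
        interpolation='bilinear'
    )
    dataset = dataset.map(imagenet_preprocess_input, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset.prefetch(buffer_size=tf.data.AUTOTUNE)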
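
The normalization applied by `imagenet_preprocess_input` comes from Keras' `mobilenet_v2.preprocess_input`, which scales pixel values from [0, 255] to [-1, 1] (`x / 127.5 - 1`). The scalar sketch below reproduces that arithmetic in plain Python for a few pixel values; the helper name `mobilenet_v2_scale` is hypothetical and exists only for this illustration.

```python
def mobilenet_v2_scale(pixel: float) -> float:
    """Map one pixel value from [0, 255] to [-1, 1], i.e. pixel / 127.5 - 1 (illustrative helper)."""
    return pixel / 127.5 - 1.0

# Black, mid-gray and white pixels map to the two ends and the middle of [-1, 1]
for p in (0, 127.5, 255):
    print(p, "->", mobilenet_v2_scale(p))
```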


In [None]:
# Configuration parameters for representative dataset generation
# These parameters control the calibration dataset used for quantization
batch_size = 5  # Number of images per batch for quantization calibration
n_iter = 2      # Number of iterations to generate representative data
                # Total calibration samples = batch_size * n_iter = 10 images

# Create dataset instance for representative data generation
# Use shuffled data to ensure diverse representative samples
dataset = get_dataset(batch_size, shuffle=True)

# Generator function for representative dataset used in quantization calibration
def representative_dataset_gen():
    """
    Generator function for representative dataset used in quantization calibration.

    This function provides a small subset of data that MCT uses for:
    - Calibrating quantization parameters across all model layers
    - Determining optimal activation value ranges for each layer
    - Computing quantization thresholds based on actual data distribution
    - Minimizing quantization error through data-driven parameter selection
    
    Yields:
        List containing numpy arrays of image batches in MCT-expected format
    """
    for _ in range(n_iter):
        # Extract one batch from the dataset and convert it to a NumPy array
        yield [dataset.take(1).get_single_element()[0].numpy()]
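
The calibration contract described above is: MCT calls `representative_dataset_gen()` and iterates over it, and every `yield` is a list holding one batch per model input. The sketch below mimics that shape with hypothetical placeholder batches (no TensorFlow, no real images) so that the sample count `batch_size * n_iter` can be checked directly.

```python
batch_size, n_iter = 5, 2       # same values as the calibration configuration above
image_shape = (224, 224, 3)     # MobileNetV2 input resolution used by get_dataset

def fake_batch(size):
    """Hypothetical placeholder batch that only records its shape (no pixel data)."""
    return {"shape": (size,) + image_shape}

def fake_representative_dataset_gen():
    for _ in range(n_iter):
        # One list entry per model input; MobileNetV2 has a single input
        yield [fake_batch(batch_size)]

batches = list(fake_representative_dataset_gen())
total_images = sum(inputs[0]["shape"][0] for inputs in batches)
print(len(batches), "batches,", total_images, "calibration images")
```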

## Model Post-Training quantization using MCTWrapper

In [None]:
# Decorator to provide consistent logging and error handling for quantization functions
def decorator(func):
    """
    Wrapper decorator that provides standardized execution logging and error handling.
    
    This decorator enhances quantization functions by:
    - Providing clear start/end execution markers for debugging
    - Handling success/failure status from quantization operations
    - Implementing fail-fast behavior on quantization errors
    - Ensuring consistent logging format across all quantization methods
    
    Usage:
        @decorator
        def quantization_function(model):
            # quantization implementation
            return flag, quantized_model
    
    Args:
        func: Function to be decorated (typically a quantization function)
    
    Returns:
        Wrapped function with enhanced logging and error handling capabilities
    """
    def wrapper(*args, **kwargs):
        # Log function execution start with clear delimiter
        print(f"----------------- {func.__name__} Start ---------------")
        
        # Execute the quantization function and capture return values
        # Expected return format: (success_flag, quantized_model)
        flag, result = func(*args, **kwargs)
        
        # Log function execution completion
        print(f"----------------- {func.__name__} End -----------------")
        
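        # Fail fast: stop with a clear error as soon as a quantization method reports failure.
        # Raising an exception keeps the notebook kernel alive, unlike exit().
        if not flag:
            raise RuntimeError(f"{func.__name__} failed: MCTWrapper returned flag=False")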
        
        # Return original function results if successful
        return flag, result
    
    return wrapper
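
To see the decorator's contract in isolation, here is a stand-alone sketch of the same fail-fast pattern applied to two hypothetical stub functions instead of real MCT calls: one returns `(True, result)` and passes through, the other returns `(False, None)` and stops execution with an exception. `log_and_check`, `stub_success`, and `stub_failure` are illustrative names, not part of MCT.

```python
import functools

def log_and_check(func):
    """Log start/end of `func` and raise if it returns flag=False (illustrative decorator)."""
    @functools.wraps(func)  # keep the wrapped function's __name__ and docstring
    def wrapper(*args, **kwargs):
        print(f"--- {func.__name__} Start ---")
        flag, result = func(*args, **kwargs)
        print(f"--- {func.__name__} End ---")
        if not flag:
            raise RuntimeError(f"{func.__name__} reported failure")
        return flag, result
    return wrapper

@log_and_check
def stub_success(model):
    # Stands in for a quantization function that succeeded
    return True, f"quantized({model})"

@log_and_check
def stub_failure(model):
    # Stands in for a quantization function that returned flag=False
    return False, None

print(stub_success("float_model"))
try:
    stub_failure("float_model")
except RuntimeError as err:
    print("Stopped:", err)
```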

Run PTQ (Post-Training Quantization) with Keras

In [None]:
@decorator
def PTQ_Keras(float_model):
    """
    Perform Post-Training Quantization (PTQ) using MCT on Keras model.
    
    PTQ is a quantization method that:
    - Does not require model retraining
    - Uses representative data for calibration
    - Provides good accuracy with minimal computational overhead
    
    Args:
        float_model: Original floating-point Keras model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for basic PTQ quantization
    method = 'PTQ'                    # Post-Training Quantization method
    framework = 'tensorflow'          # Target framework (Keras/TensorFlow)
    use_MCT_TPC = True                # Use MCT's built-in Target Platform Capabilities
    use_MixP = False                  # Disable mixed-precision quantization

    # Parameter configuration for PTQ
    param_items = [
        ['tpc_version', '1.0', 'The version of the TPC to use.'],
        
        # Quantization configuration parameters
        ['activation_error_method', QuantizationErrorMethod.MSE, 'Error metric for activation quantization'],
        ['weights_bias_correction', True, 'Enable bias correction for weights'],
        ['z_threshold', float('inf'), 'Z-score threshold for outlier removal (inf disables it)'],
        ['linear_collapsing', True, 'Enable linear layer collapsing optimization'],
        ['residual_collapsing', True, 'Enable residual connection collapsing'],
        
        # Output configuration
        ['save_model_path', './qmodel_PTQ_Keras.tflite', 'Path to save the quantized model']
    ]

    # Execute quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model, method, framework, use_MCT_TPC, use_MixP, 
        representative_dataset_gen, param_items)
    return flag, quantized_model
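
In MCT's quantization configuration, `z_threshold` is the z-score threshold used for outlier removal; `float('inf')`, as set in `PTQ_Keras`, effectively disables it. The plain-Python sketch below shows the underlying idea on a hypothetical list of activation values. It only illustrates z-score filtering; MCT's own statistics collection may apply the threshold differently.

```python
import math
import statistics

def remove_outliers(values, z_threshold):
    """Drop values whose z-score |x - mean| / std exceeds z_threshold (math.inf keeps everything)."""
    mean = statistics.fmean(values)
    std = statistics.pstdev(values)
    if std == 0 or math.isinf(z_threshold):
        return list(values)
    return [v for v in values if abs(v - mean) / std <= z_threshold]

# Hypothetical activations with one large outlier
activations = [0.1, 0.2, 0.15, 0.3, 0.25, 0.2, 0.1, 0.3, 0.2, 8.0]
print(remove_outliers(activations, float("inf")))  # all 10 values kept
print(remove_outliers(activations, 2.0))           # the 8.0 outlier is dropped
```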

Run PTQ + Mixed Precision Quantization (MixP) with Keras

In [None]:
@decorator
def PTQ_Keras_MixP(float_model):
    """
    Perform Post-Training Quantization with Mixed Precision (PTQ + MixP) on Keras model.
    
    Mixed Precision Quantization:
    - Uses different bit-widths for different layers
    - Optimizes model size while maintaining accuracy
    - Automatically selects optimal precision for each layer
    - Uses resource constraints to guide precision allocation
    
    Args:
        float_model: Original floating-point Keras model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for PTQ with mixed precision
    method = 'PTQ'                    # Post-Training Quantization method
    framework = 'tensorflow'          # Target framework (Keras/TensorFlow)
    use_MCT_TPC = True                # Use MCT's built-in Target Platform Capabilities
    use_MixP = True                   # Enable mixed-precision quantization

    # Parameter configuration for PTQ with Mixed Precision
    param_items = [
        ['tpc_version', '1.0', 'The version of the TPC to use.'],
        
        # Mixed precision configuration
        ['num_of_images', 5, 'Number of images for mixed precision analysis'],
        ['use_hessian_based_scores', False, 'Use Hessian-based sensitivity scores for layer importance'],
        
        # Resource constraint configuration
        ['weights_compression_ratio', 0.75, 'Target compression ratio for model weights (75% of original size)'],
        
        # Output configuration
        ['save_model_path', './qmodel_PTQ_Keras_MixP.tflite', 'Path to save the mixed precision quantized model']
    ]

    # Execute mixed precision quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model, method, framework, use_MCT_TPC, use_MixP, 
        representative_dataset_gen, param_items)
    return flag, quantized_model
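
`weights_compression_ratio=0.75` constrains the total weight memory that the mixed-precision search may use. The toy below assumes the ratio scales the weight memory of an all-8-bit model, and uses three hypothetical layers with made-up sensitivity scores. It brute-forces every 8/4/2-bit assignment and keeps the one with the lowest total sensitivity that fits the budget. MCT's actual sensitivity metric and search are internal to the library and are not reproduced here.

```python
from itertools import product

# Hypothetical toy model: parameter counts and made-up sensitivity scores
# (higher = more expected accuracy loss) for each candidate bit-width.
layers = {
    "conv1": {"params": 1_000, "sensitivity": {8: 0.0, 4: 0.9, 2: 3.0}},
    "conv2": {"params": 10_000, "sensitivity": {8: 0.0, 4: 0.2, 2: 1.5}},
    "fc": {"params": 5_000, "sensitivity": {8: 0.0, 4: 0.1, 2: 0.8}},
}
weights_compression_ratio = 0.75  # same value as in PTQ_Keras_MixP

# Assumption: the ratio scales the weight memory (in bits) of an all-8-bit model
baseline_bits = sum(layer["params"] * 8 for layer in layers.values())
budget_bits = weights_compression_ratio * baseline_bits

names = list(layers)
best = None  # (total_sensitivity, allocation, memory_bits)
for bits in product((8, 4, 2), repeat=len(names)):
    memory = sum(layers[n]["params"] * b for n, b in zip(names, bits))
    if memory > budget_bits:
        continue  # this assignment exceeds the weight-memory budget
    cost = sum(layers[n]["sensitivity"][b] for n, b in zip(names, bits))
    if best is None or cost < best[0]:
        best = (cost, dict(zip(names, bits)), memory)

cost, allocation, memory = best
print(allocation, f"uses {memory / baseline_bits:.2%} of the 8-bit weight memory")
```

Exhaustive search is only feasible for a handful of layers; a real network needs a smarter solver, which is why the search strategy is left to MCT.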

Run GPTQ (Gradient-based PTQ) with Keras

In [None]:
@decorator
def GPTQ_Keras(float_model):
    """
    Perform Gradient-based Post-Training Quantization (GPTQ) on Keras model.
    
    GPTQ is an advanced quantization method that:
    - Uses gradient information to optimize quantization parameters
    - Fine-tunes the model during the quantization process
    - Generally provides better accuracy than standard PTQ
    - Takes longer than PTQ, since it runs n_epochs of gradient-based optimization
    
    Args:
        float_model: Original floating-point Keras model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for GPTQ quantization
    method = 'GPTQ'                   # Gradient-based Post-Training Quantization
    framework = 'tensorflow'          # Target framework (Keras/TensorFlow)
    use_MCT_TPC = False               # Use external EdgeMDT Target Platform Capabilities
    use_MixP = False                  # Disable mixed-precision quantization

    # Parameter configuration for GPTQ
    param_items = [
        # Platform configuration
        ['target_platform_version', 'v1', 'Target platform capabilities version'],
        
        # GPTQ-specific training parameters
        ['n_epochs', 5, 'Number of epochs for gradient-based fine-tuning'],
        ['optimizer', None, 'Optimizer for fine-tuning (None = use default)'],
        
        # Output configuration
        ['save_model_path', './qmodel_GPTQ_Keras.tflite', 'Path to save the GPTQ quantized model']
    ]

    # Execute GPTQ quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model, method, framework, use_MCT_TPC, use_MixP, 
        representative_dataset_gen, param_items)
    return flag, quantized_model
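
GPTQ tunes quantization using gradients computed on the representative data, so that the quantized layer outputs stay close to the float outputs rather than each weight independently staying close to its float value. The toy below illustrates only that objective on a single hypothetical linear layer; to stay small and exact, it replaces gradient descent with an exhaustive search over rounding each weight down or up. It is not MCT's GPTQ algorithm.

```python
import math
from itertools import product

# Hypothetical toy layer y = w . x, with made-up weights and calibration inputs
weights = [0.37, -0.62, 0.55, 0.12]
calibration_inputs = [
    [1.0, 0.5, -0.3, 2.0],
    [0.2, -1.0, 0.8, 0.4],
    [-0.7, 0.3, 1.5, -1.2],
]
step = 0.25  # quantization step size, assumed fixed

def output_error(q_weights):
    """Sum of squared differences between float and quantized outputs on the calibration inputs."""
    err = 0.0
    for x in calibration_inputs:
        y_float = sum(w * xi for w, xi in zip(weights, x))
        y_quant = sum(q * xi for q, xi in zip(q_weights, x))
        err += (y_float - y_quant) ** 2
    return err

# Baseline: round every weight to its nearest grid point independently
nearest = [round(w / step) * step for w in weights]

# Output-aware rounding: try rounding each weight down or up and keep the
# combination that minimizes the output error instead of the per-weight error
candidates = [(math.floor(w / step) * step, math.ceil(w / step) * step) for w in weights]
best = min(product(*candidates), key=output_error)

print("nearest rounding error:  ", output_error(nearest))
print("optimized rounding error:", output_error(list(best)))
```

Because round-to-nearest is one of the searched combinations, the output-aware choice can never do worse on the calibration inputs; gradient-based optimization is what makes this kind of objective tractable at the scale of a full network.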

Run GPTQ + Mixed Precision Quantization (MixP) with Keras

In [None]:
@decorator
def GPTQ_Keras_MixP(float_model):
    """
    Perform Gradient-based Post-Training Quantization with Mixed Precision (GPTQ + MixP).
    
    This combines the benefits of both techniques:
    - GPTQ: Gradient-based optimization for better quantization accuracy
    - Mixed Precision: Optimal bit-width allocation for size/accuracy trade-off
    
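    Among the mixed-precision options in this notebook, this one is expected to retain
    the most accuracy for a given weight-memory budget, at the cost of the longest run time:
    - Gradient-based tuning of quantization parameters
    - Weight-size reduction set by the compression ratio
    - Automatic precision selection per layer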
    
    Args:
        float_model: Original floating-point Keras model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for GPTQ with mixed precision
    method = 'GPTQ'                   # Gradient-based Post-Training Quantization
    framework = 'tensorflow'          # Target framework (Keras/TensorFlow)
    use_MCT_TPC = False               # Use external EdgeMDT Target Platform Capabilities
    use_MixP = True                   # Enable mixed-precision quantization

    # Parameter configuration for GPTQ with Mixed Precision
    param_items = [
        # Platform configuration
        ['target_platform_version', 'v1', 'Target platform capabilities version'],
        
        # GPTQ-specific training parameters
        ['n_epochs', 5, 'Number of epochs for gradient-based fine-tuning'],
        ['optimizer', None, 'Optimizer for fine-tuning (None = use default)'],
        
        # Mixed precision configuration
        ['num_of_images', 5, 'Number of images for mixed precision sensitivity analysis'],
        ['use_hessian_based_scores', False, 'Use Hessian-based scores for layer importance ranking'],
        
        # Resource constraint configuration
        ['weights_compression_ratio', 0.75, 'Target compression ratio for model weights (75% of original size)'],
        
        # Output configuration
        ['save_model_path', './qmodel_GPTQ_Keras_MixP.tflite', 'Path to save the GPTQ+MixP quantized model']
    ]

    # Execute advanced GPTQ+MixP quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model, method, framework, use_MCT_TPC, use_MixP, 
        representative_dataset_gen, param_items)
    return flag, quantized_model
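
Comparing the two GPTQ parameter lists shows exactly what mixed precision adds. The snippet below copies the name/value pairs from `GPTQ_Keras` and `GPTQ_Keras_MixP` and diffs them with a small hypothetical helper, `to_dict`; it does not reflect how MCTWrapper parses `param_items` internally.

```python
# Name/value pairs copied from GPTQ_Keras and GPTQ_Keras_MixP (descriptions omitted)
gptq_items = [
    ['target_platform_version', 'v1'],
    ['n_epochs', 5],
    ['optimizer', None],
    ['save_model_path', './qmodel_GPTQ_Keras.tflite'],
]
gptq_mixp_items = [
    ['target_platform_version', 'v1'],
    ['n_epochs', 5],
    ['optimizer', None],
    ['num_of_images', 5],
    ['use_hessian_based_scores', False],
    ['weights_compression_ratio', 0.75],
    ['save_model_path', './qmodel_GPTQ_Keras_MixP.tflite'],
]

def to_dict(items):
    """Convert [name, value, (description)] entries to {name: value}, rejecting duplicate names."""
    result = {}
    for name, value, *_ in items:
        if name in result:
            raise ValueError(f"duplicate parameter: {name}")
        result[name] = value
    return result

base, mixp = to_dict(gptq_items), to_dict(gptq_mixp_items)
added = sorted(set(mixp) - set(base))
changed = sorted(k for k in set(base) & set(mixp) if base[k] != mixp[k])
print("added by MixP:", added)
print("changed:", changed)
```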

Run LQPTQ (Low-bit Quantizer PTQ) with Keras

In [None]:
@decorator
def LQPTQ_Keras(float_model):
    """
    Perform Low-bit Quantizer Post-Training Quantization (LQ-PTQ) on Keras model.
    
    LQ-PTQ is a specialized quantization method that:
    - Targets very low bit-width quantization (e.g., 2-4 bits)
    - Uses advanced techniques for ultra-low precision
    - Requires specific converter versions for deployment
    - Currently only supports TensorFlow/Keras framework
    
    Args:
        float_model: Original floating-point Keras model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for LQ-PTQ quantization
    method = 'LQPTQ'                  # Low-bit Quantizer Post-Training Quantization
    framework = 'tensorflow'          # Target framework (TensorFlow only for LQ-PTQ)
    use_MCT_TPC = False               # Use external Target Platform Capabilities
    use_MixP = False                  # Mixed precision not applicable for LQ-PTQ

    # Parameter configuration for LQ-PTQ
    param_items = [
        # LQ-PTQ specific training parameters
        ['learning_rate', 0.0001, 'Learning rate for low-bit quantization optimization'],
        ['converter_ver', 'v3.14', 'Converter version for deployment compatibility'],
        
        # Output configuration
        ['save_model_path', './qmodel_LQPTQ_Keras.tflite', 'Path to save the LQ-PTQ quantized model']
    ]

    # LQ-PTQ requires a different representative dataset format (single batch, not generator)
    representative_dataset = dataset.take(1).get_single_element()[0].numpy()
    
    # Execute LQ-PTQ quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model, method, framework, use_MCT_TPC, use_MixP, 
        representative_dataset, param_items)
    return flag, quantized_model
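
LQ-PTQ targets 2-4-bit weights, where each bit removed halves the number of quantization levels. As an illustration only, assuming a uniform symmetric quantizer with a fixed threshold of 1.0 (the quantizer LQ-PTQ actually uses is not described in this notebook), the sketch below measures the mean absolute rounding error on evenly spaced values in [-1, 1] at 8, 4, 3, and 2 bits.

```python
def quantize_symmetric(x, n_bits, threshold=1.0):
    """Uniform symmetric quantizer with 2**n_bits levels covering [-threshold, threshold)."""
    step = threshold / 2 ** (n_bits - 1)
    q_min, q_max = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
    q = min(max(round(x / step), q_min), q_max)  # round to the nearest level, then clip
    return q * step

values = [i / 100 for i in range(-100, 101)]  # evenly spaced test values in [-1, 1]
errors = {}
for n_bits in (8, 4, 3, 2):
    errors[n_bits] = sum(abs(v - quantize_symmetric(v, n_bits)) for v in values) / len(values)
    print(f"{n_bits}-bit: {2 ** n_bits:3d} levels, mean |error| = {errors[n_bits]:.4f}")
```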

### Run model Post-Training Quantization
Finally, we quantize the float model with each method through the MCTWrapper API.

In [None]:
# Load pre-trained MobileNetV2 model as the base model for quantization experiments
# This model serves as the reference floating-point model for all quantization methods
float_model = MobileNetV2()

# Execute comprehensive quantization method comparison using MCT Wrapper functionality
# Each method represents different trade-offs between accuracy, model size, and computation time
print("Starting quantization experiments with different methods...")

# Method 1: Basic Post-Training Quantization (PTQ)
# - Standard 8-bit quantization without advanced optimization techniques
flag, quantized_model = PTQ_Keras(float_model)

# Method 2: PTQ with Mixed Precision Quantization
# - Uses different bit-widths for different layers based on sensitivity analysis
flag, quantized_model2 = PTQ_Keras_MixP(float_model)

# Method 3: Gradient-based Post-Training Quantization (GPTQ)
# - Uses gradient information to fine-tune quantization parameters during conversion
flag, quantized_model3 = GPTQ_Keras(float_model)

# Method 4: GPTQ with Mixed Precision Quantization
# - Combines gradient-based optimization with mixed precision techniques
flag, quantized_model4 = GPTQ_Keras_MixP(float_model)

# Method 5: Low-bit Quantizer Post-Training Quantization (LQ-PTQ)
# - Experimental ultra-low precision quantization (2-4 bits per weight)
# - Disabled by default; uncomment the next line to run it
#flag, quantized_model5 = LQPTQ_Keras(float_model)

print("All enabled quantization methods completed successfully!")

## Model evaluation
In order to evaluate our models, we first need to load the validation dataset. As before, please ensure that the dataset path has been set correctly.
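
The evaluation cells report Top-1 accuracy. With integer labels and a `SparseCategoricalCrossentropy` loss, the `"accuracy"` metric counts a sample as correct when the class with the highest predicted probability equals its label. A plain-Python sketch of that computation, using made-up probabilities rather than model outputs:

```python
def top1_accuracy(probabilities, labels):
    """Fraction of samples whose highest-probability class equals the integer label."""
    correct = 0
    for probs, label in zip(probabilities, labels):
        predicted = max(range(len(probs)), key=probs.__getitem__)  # argmax over classes
        correct += int(predicted == label)
    return correct / len(labels)

# Three samples over four classes (made-up probabilities)
probs = [
    [0.1, 0.7, 0.1, 0.1],  # predicts class 1
    [0.6, 0.2, 0.1, 0.1],  # predicts class 0
    [0.2, 0.2, 0.5, 0.1],  # predicts class 2
]
labels = [1, 3, 2]
print(f"Top-1 accuracy: {top1_accuracy(probs, labels):.2%}")
```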

In [None]:
# Model Evaluation and Accuracy Comparison
print("Starting model evaluation phase...")

# Prepare validation dataset for accuracy assessment
val_dataset = get_dataset(batch_size=50, shuffle=False)

# Evaluate original floating-point model accuracy
print("\n=== Original Model Evaluation ===")
float_model.compile(loss=keras.losses.SparseCategoricalCrossentropy(), metrics=["accuracy"])
float_accuracy = float_model.evaluate(val_dataset)
print(f"Float model's Top-1 accuracy on the ImageNet validation set: {(float_accuracy[1] * 100):.2f}%")

# Evaluate PTQ quantized model accuracy
print("\n=== PTQ Model Evaluation ===")
quantized_model.compile(loss=keras.losses.SparseCategoricalCrossentropy(), metrics=["accuracy"])
quantized_accuracy = quantized_model.evaluate(val_dataset)
print(f"PTQ_Keras quantized model's Top-1 accuracy on the ImageNet validation set: {(quantized_accuracy[1] * 100):.2f}%")

# Evaluate PTQ + Mixed Precision model accuracy
print("\n=== PTQ + Mixed Precision Model Evaluation ===")
quantized_model2.compile(loss=keras.losses.SparseCategoricalCrossentropy(), metrics=["accuracy"])
quantized_accuracy = quantized_model2.evaluate(val_dataset)
print(f"PTQ_Keras_MixP quantized model's Top-1 accuracy on the ImageNet validation set: {(quantized_accuracy[1] * 100):.2f}%")

# Evaluate GPTQ quantized model accuracy
print("\n=== GPTQ Model Evaluation ===")
quantized_model3.compile(loss=keras.losses.SparseCategoricalCrossentropy(), metrics=["accuracy"])
quantized_accuracy = quantized_model3.evaluate(val_dataset)
print(f"GPTQ_Keras quantized model's Top-1 accuracy on the ImageNet validation set: {(quantized_accuracy[1] * 100):.2f}%")

# Evaluate GPTQ + Mixed Precision model accuracy
print("\n=== GPTQ + Mixed Precision Model Evaluation ===")
quantized_model4.compile(loss=keras.losses.SparseCategoricalCrossentropy(), metrics=["accuracy"])
quantized_accuracy = quantized_model4.evaluate(val_dataset)
print(f"GPTQ_Keras_MixP quantized model's Top-1 accuracy on the ImageNet validation set: {(quantized_accuracy[1] * 100):.2f}%")

# LQ-PTQ model evaluation (disabled by default, like the LQ-PTQ run above)
#print("\n=== LQ-PTQ Model Evaluation ===")
#quantized_model5.compile(loss=keras.losses.SparseCategoricalCrossentropy(), metrics=["accuracy"])
#quantized_accuracy = quantized_model5.evaluate(val_dataset)
#print(f"LQPTQ_Keras quantized model's Top-1 accuracy on the ImageNet validation set: {(quantized_accuracy[1] * 100):.2f}%")

print("Finished")
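
The evaluation cell reuses `quantized_accuracy` for every model, so only the last result remains afterwards. A hypothetical helper such as `accuracy_report` below can summarize results side by side, assuming each model's `evaluate(...)[1]` value has been stored under its method name in a dictionary. It reports each method's accuracy drop in percentage points relative to the float model.

```python
def accuracy_report(float_top1, quantized_top1):
    """Format Top-1 accuracies (fractions in [0, 1]) and each method's drop versus the float model.

    quantized_top1 maps a method name (e.g. "PTQ_Keras") to its Top-1 accuracy.
    """
    lines = [f"{'Model':18s} {'Top-1':>8s} {'Drop':>8s}",
             f"{'Float':18s} {float_top1 * 100:7.2f}% {'-':>8s}"]
    for name, acc in quantized_top1.items():
        drop_pts = (float_top1 - acc) * 100  # accuracy drop in percentage points
        lines.append(f"{name:18s} {acc * 100:7.2f}% {drop_pts:6.2f}pp")
    return "\n".join(lines)
```

For example, `print(accuracy_report(float_accuracy[1], {"PTQ_Keras": ptq_top1, "GPTQ_Keras": gptq_top1}))`, where `ptq_top1` and `gptq_top1` are hypothetical variables you would introduce to keep each model's result.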

## Conclusion

In this tutorial, we demonstrated how to quantize a pre-trained model using MCTWrapper with a few lines of code.

MCT can deliver competitive results across a wide range of tasks and network architectures. For more details, [check out the paper](https://arxiv.org/abs/2109.09113).

## Copyrights

Copyright 2024 Sony Semiconductor Solutions, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
