# Model Compression Toolkit (MCT) Wrapper API: Comprehensive Quantization Comparison (PyTorch)

[Run this tutorial in Google Colab](https://colab.research.google.com/github/SonySemiconductorSolutions/mct-model-optimization/blob/main/tutorials/notebooks/mct_features_notebooks/pytorch/example_pytorch_wrapper.ipynb)

## Overview 
This notebook demonstrates the MCT (Model Compression Toolkit) Wrapper API by applying four quantization methods to a MobileNetV2 model. The tutorial compares the implementation, performance characteristics, and accuracy trade-offs of each approach: PTQ (Post-Training Quantization), PTQ with Mixed Precision, GPTQ (Gradient-based PTQ), and GPTQ with Mixed Precision. Each method uses the unified MCTWrapper interface, so the four configurations can be compared directly.

## Summary
1. **Environment Setup**: Import required libraries and configure MCT with MobileNetV2 model
2. **Dataset Preparation**: Load and prepare ImageNet validation dataset with representative data generation
3. **PTQ Implementation**: Execute basic Post-Training Quantization with 8-bit precision and bias correction
4. **PTQ + Mixed Precision**: Apply automatic per-layer bit-width allocation based on layer sensitivity analysis (`weights_compression_ratio` of 0.5 in this example)
5. **GPTQ Implementation**: Perform gradient-based optimization with 5-epoch fine-tuning for enhanced accuracy
6. **GPTQ + Mixed Precision**: Combine gradient optimization with mixed precision for optimal accuracy-compression trade-off
7. **Performance Evaluation**: Comprehensive accuracy assessment and comparison across all quantization methods
8. **Results Analysis**: Compare model sizes, inference accuracy, and quantization trade-offs

## Setup

In [None]:
# Import required libraries for PyTorch deep learning and data handling
import os
import torch
from torch.utils.data import DataLoader
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
from torchvision.datasets import ImageNet
from tqdm import tqdm
from typing import List, Tuple, Generator, Any, Callable

In [None]:
# Configure system path to include a local MCT checkout (adjust the path below to your environment)
import sys
sys.path.append('/home/ubuntu/wrapper/sonyfork/mct-model-optimization')

#pip install -q tensorflow
#import importlib
#if not importlib.util.find_spec('model_compression_toolkit'):
#   !pip install model_compression_toolkit

# Import Model Compression Toolkit (MCT) core functionality for PyTorch
import model_compression_toolkit as mct
from model_compression_toolkit.core import QuantizationErrorMethod

## Dataset preparation
Download ImageNet dataset with only the validation split.

**Note** that for demonstration purposes we use the validation set for the model quantization routines. Usually, a subset of the training dataset is used, but loading it is a heavy procedure that is unnecessary for the sake of this demonstration.

This step may take several minutes...

In [None]:
# Download and setup ImageNet validation dataset if not already present
if not os.path.isdir('imagenet'):
    # Create directory and download required ImageNet files
    os.system('mkdir imagenet')
    os.system('wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz')
    os.system('wget -P imagenet https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar')
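
Because `os.system` does not raise when `wget` fails, and the `imagenet` directory exists after the first attempt, a failed download is not retried on the next run. The following is a minimal sketch of a check for the two archives named above; the helper `missing_imagenet_files` is hypothetical and not part of MCT or torchvision.

```python
import os
from typing import List, Sequence

# Archive names taken from the download commands above.
REQUIRED_FILES = ("ILSVRC2012_devkit_t12.tar.gz", "ILSVRC2012_img_val.tar")


def missing_imagenet_files(root: str,
                           required: Sequence[str] = REQUIRED_FILES) -> List[str]:
    """Return the required archive names that are absent or empty under `root`."""
    missing = []
    for name in required:
        path = os.path.join(root, name)
        # A zero-byte file usually indicates an interrupted download.
        if not os.path.isfile(path) or os.path.getsize(path) == 0:
            missing.append(name)
    return missing


missing = missing_imagenet_files("imagenet")
if missing:
    print("Missing or empty archives:", ", ".join(missing))
else:
    print("Both ImageNet archives are present.")
```

If the list is non-empty, deleting the `imagenet` directory and rerunning the download cell is one way to retry.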


### Representative dataset construction
We show how to create a generator for the representative dataset, which is required for post-training quantization.

The representative dataset is used for collecting statistics on the inference outputs of all layers in the model.
 
In order to decide on the size of the representative dataset, we configure the batch size and the number of calibration iterations.
This gives us the total number of samples that will be used during PTQ (batch_size x n_iter).
In this example we set `default_batch_size = 10` and `n_iter = 5`, resulting in a total of 50 representative images.

Please ensure that the dataset path has been set correctly.
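
As a quick sanity check of the sample count, the sketch below mimics the shape of a representative dataset generator using plain Python lists in place of tensors. The function `toy_representative_gen` is an illustrative stand-in, not part of MCT; the values 10 and 5 mirror `default_batch_size` and `n_iter` from the code cell that follows.

```python
from typing import Iterator, List


def toy_representative_gen(batch_size: int, n_iter: int) -> Iterator[List[List[int]]]:
    """Yield `n_iter` single-element lists, each wrapping one batch of stand-in samples."""
    for i in range(n_iter):
        # Integers stand in for images; a real generator would yield a tensor batch.
        batch = list(range(i * batch_size, (i + 1) * batch_size))
        yield [batch]


batch_size, n_iter = 10, 5
total_samples = sum(len(item[0]) for item in toy_representative_gen(batch_size, n_iter))
print(total_samples)  # 50, i.e. batch_size * n_iter
```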

In [None]:
# Load pre-trained MobileNetV2 weights and configure dataset transforms
weights = MobileNet_V2_Weights.IMAGENET1K_V2  # torchvision's "V2" weights, trained on ImageNet-1K
# Create ImageNet validation dataset with automatic preprocessing transforms
dataset = ImageNet(root='./imagenet', split='val', transform=weights.transforms())

# Configuration parameters for representative dataset generation
default_batch_size: int = 10  # Batch size for quantization calibration data
n_iter: int = 5               # Number of iterations to generate representative batches
# Create DataLoader with shuffling for representative data diversity
dataloader = DataLoader(dataset, batch_size=default_batch_size, shuffle=True)

def representative_dataset_gen() -> Generator[List[torch.Tensor], None, None]:
    """
    Generator function for representative dataset used in PyTorch quantization.
    
    This function provides calibration data that MCT uses to:
    - Determine optimal quantization parameters for PyTorch models
    - Calibrate activation ranges and thresholds
    - Configure layer-specific quantization settings
    
    Yields:
        List containing PyTorch tensors for model calibration
    """
    dataloader_iter = iter(dataloader)
    for _ in range(n_iter):
        # Extract image batch (ignore labels) and yield as list for MCT compatibility
        yield [next(dataloader_iter)[0]]

## Model Evaluation Function
Define an evaluation function that computes the Top-1 accuracy of a PyTorch model on the validation dataset, running on a GPU when one is available.

In [None]:
def evaluate(model: torch.nn.Module, testloader: DataLoader, mode: str) -> float:
    """
    Evaluate PyTorch model accuracy using a DataLoader with GPU acceleration.
    
    This function performs complete accuracy evaluation by:
    - Moving model and data to available device (GPU/CPU)
    - Running inference in evaluation mode (no gradient computation)
    - Computing Top-1 accuracy across the entire validation set
    - Providing progress tracking during evaluation
    
    Args:
        model: PyTorch model to evaluate (float or quantized)
        testloader: DataLoader containing validation dataset
        mode: String identifier for logging (e.g., 'Float', 'PTQ_Pytorch')
    
    Returns:
        float: Top-1 accuracy percentage
    """
    # Determine best available device for inference
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # Set model to evaluation mode
    correct = 0
    total = 0
    
    # Perform inference without gradient computation for efficiency
    with torch.no_grad():
        for data in tqdm(testloader):
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            
            # Forward pass to get predictions
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    # Calculate and display accuracy
    val_acc = (100 * correct / total)
    print(mode + ' Accuracy: %.2f%%' % val_acc)
    return val_acc
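
The Top-1 computation inside `evaluate` can be illustrated on a small hand-made example. This is a sketch in plain Python with made-up logits and labels; `top1_accuracy` is a hypothetical helper that mirrors the argmax-and-compare logic above without requiring torch.

```python
from typing import Sequence


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Return the percentage of rows whose highest-scoring class equals the label."""
    correct = 0
    for row, label in zip(logits, labels):
        # Index of the largest logit, equivalent to torch.max(outputs, 1)[1] for one row.
        predicted = max(range(len(row)), key=row.__getitem__)
        if predicted == label:
            correct += 1
    return 100.0 * correct / len(labels)


# Four samples, three classes (values are made up for illustration).
logits = [
    [0.1, 2.0, -1.0],  # predicts class 1
    [1.5, 0.2, 0.3],   # predicts class 0
    [0.0, 0.1, 0.9],   # predicts class 2
    [0.4, 0.3, 0.2],   # predicts class 0
]
labels = [1, 0, 1, 0]
print(f"Toy Accuracy: {top1_accuracy(logits, labels):.2f}%")  # Toy Accuracy: 75.00%
```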

## Model Post-Training quantization using MCTWrapper

In [None]:
# Decorator to provide consistent logging and error handling for quantization functions
def decorator(func: Callable[[torch.nn.Module], Tuple[bool, torch.nn.Module]]) -> Callable[[torch.nn.Module], Tuple[bool, torch.nn.Module]]:
    """
    Wrapper decorator that provides:
    - Consistent start/end logging for quantization operations
    - Automatic error handling and program termination on failure
    - Success/failure status tracking for all quantization methods
    
    Args:
        func: Quantization function to be decorated
    
    Returns:
        Wrapped function with enhanced logging and error handling
    """
    def wrapper(*args: Any, **kwargs: Any) -> Tuple[bool, torch.nn.Module]:
        print(f"----------------- {func.__name__} Start ---------------")
        flag, result = func(*args, **kwargs)
        print(f"----------------- {func.__name__} End -----------------")
        if not flag:
            exit()
        return flag, result
    return wrapper
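
Calling `exit()` inside a notebook stops the kernel, which can be inconvenient while experimenting. The sketch below shows one possible variant of the same decorator pattern that raises an exception instead and uses `functools.wraps` to preserve the wrapped function's name and docstring. The names `logged`, `always_ok`, and `always_fails` are hypothetical and exist only for this illustration.

```python
import functools
from typing import Any, Callable, Tuple


def logged(func: Callable[..., Tuple[bool, Any]]) -> Callable[..., Tuple[bool, Any]]:
    """Log start/end of `func` and raise if it reports failure."""
    @functools.wraps(func)  # keep func.__name__ and func.__doc__ on the wrapper
    def wrapper(*args: Any, **kwargs: Any) -> Tuple[bool, Any]:
        print(f"--- {func.__name__} start ---")
        flag, result = func(*args, **kwargs)
        print(f"--- {func.__name__} end ---")
        if not flag:
            raise RuntimeError(f"{func.__name__} reported failure")
        return flag, result
    return wrapper


@logged
def always_ok(x: int) -> Tuple[bool, int]:
    """Stand-in for a quantization function that succeeds."""
    return True, x * 2


@logged
def always_fails(x: int) -> Tuple[bool, None]:
    """Stand-in for a quantization function that fails."""
    return False, None


print(always_ok(3))
try:
    always_fails(3)
except RuntimeError as err:
    print("caught:", err)
```

Raising lets the caller decide whether to stop or to continue with the remaining methods.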

Run PTQ (Post-Training Quantization) with PyTorch

In [None]:
@decorator
def PTQ_Pytorch(float_model: torch.nn.Module) -> Tuple[bool, torch.nn.Module]:
    """
    Perform Post-Training Quantization (PTQ) on PyTorch model.
    
    PTQ for PyTorch provides:
    - Fast quantization without model retraining
    - Standard 8-bit integer quantization
    - Efficient calibration using representative data
    - Direct ONNX export for deployment
    
    Args:
        float_model: Original floating-point PyTorch model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for PyTorch PTQ quantization
    method = 'PTQ'                    # Post-Training Quantization method
    framework = 'pytorch'             # Target framework (PyTorch)
    use_MCT_TPC = False               # Use external EdgeMDT Target Platform Capabilities
    use_MixP = False                  # Disable mixed-precision quantization

    # Parameter configuration for PyTorch PTQ
    param_items = [
        # Platform configuration
        ['target_platform_version', 'v1', 'Target platform capabilities version'],
        
        # Quantization configuration parameters
        ['activation_error_method', QuantizationErrorMethod.MSE, 'Error metric for activation quantization'],
        ['weights_bias_correction', True, 'Enable bias correction for weights'],
        ['z_threshold', float('inf'), 'Z-score threshold for activation outlier removal (inf disables it)'],
        ['linear_collapsing', True, 'Enable linear layer collapsing optimization'],
        ['residual_collapsing', True, 'Enable residual connection collapsing'],
        
        # Output configuration
        ['save_model_path', './qmodel_PTQ_Pytorch.onnx', 'Path to save quantized model as ONNX']
    ]

    # Execute PyTorch PTQ quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model, method, framework, use_MCT_TPC, use_MixP, 
        representative_dataset_gen, param_items)
    return flag, quantized_model
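
Each entry of `param_items` is a `[name, value, description]` triple. For inspection or logging, such a list can be turned into a plain dictionary. The helper below is a hypothetical sketch and not part of the MCTWrapper API; how MCTWrapper consumes the list internally is not shown in this tutorial.

```python
from typing import Any, Dict, List


def params_to_dict(param_items: List[List[Any]]) -> Dict[str, Any]:
    """Map parameter names to values, rejecting duplicate names."""
    result: Dict[str, Any] = {}
    for name, value, _description in param_items:
        if name in result:
            raise ValueError(f"duplicate parameter: {name}")
        result[name] = value
    return result


# A shortened copy of the PTQ parameter list used above.
example_items = [
    ['target_platform_version', 'v1', 'Target platform capabilities version'],
    ['weights_bias_correction', True, 'Enable bias correction for weights'],
    ['save_model_path', './qmodel_PTQ_Pytorch.onnx', 'Path to save quantized model as ONNX'],
]

for name, value in params_to_dict(example_items).items():
    print(f"{name} = {value!r}")
```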

Run PTQ + Mixed Precision Quantization (MixP) with PyTorch

In [None]:
@decorator
def PTQ_Pytorch_MixP(float_model: torch.nn.Module) -> Tuple[bool, torch.nn.Module]:
    """
    Perform Post-Training Quantization with Mixed Precision (PTQ + MixP) on PyTorch model.
    
    Mixed Precision PTQ for PyTorch offers:
    - Automatic bit-width selection per layer
    - Optimal size/accuracy trade-off
    - Resource-constrained quantization
    - Advanced sensitivity analysis for PyTorch models
    
    Args:
        float_model: Original floating-point PyTorch model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for PyTorch PTQ with mixed precision
    method = 'PTQ'                    # Post-Training Quantization method
    framework = 'pytorch'             # Target framework (PyTorch)
    use_MCT_TPC = False               # Use external EdgeMDT Target Platform Capabilities
    use_MixP = True                   # Enable mixed-precision quantization

    # Parameter configuration for PyTorch PTQ with Mixed Precision
    param_items = [
        # Platform configuration
        ['target_platform_version', 'v1', 'Target platform capabilities version'],
        
        # Mixed precision configuration
        ['num_of_images', 5, 'Number of images for mixed precision sensitivity analysis'],
        ['use_hessian_based_scores', False, 'Use Hessian-based scores for layer importance ranking'],
        
        # Resource constraint configuration (more aggressive for PyTorch)
        ['weights_compression_ratio', 0.5, 'Target compression ratio for model weights (50% reduction)'],
        
        # Output configuration
        ['save_model_path', './qmodel_PTQ_Pytorch_MixP.onnx', 'Path to save mixed precision quantized model']
    ]

    # Execute PyTorch mixed precision PTQ using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model, method, framework, use_MCT_TPC, use_MixP, 
        representative_dataset_gen, param_items)
    return flag, quantized_model
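
To make the effect of `weights_compression_ratio` concrete, the sketch below computes a weights memory budget from a baseline size. It assumes that the ratio scales the weights memory of the 8-bit quantized model, which is how MCT's mixed-precision tutorials use a similar setting; whether MCTWrapper applies it in exactly this way is an assumption here. The baseline size is a made-up figure, and `target_weights_bytes` is a hypothetical helper.

```python
def target_weights_bytes(baseline_bytes: int, compression_ratio: float) -> int:
    """Return the weights memory budget implied by `compression_ratio`.

    Assumption: the budget is the baseline weights memory times the ratio.
    """
    if not 0.0 < compression_ratio <= 1.0:
        raise ValueError("compression_ratio must be in the interval (0, 1]")
    return int(baseline_bytes * compression_ratio)


baseline = 3_400_000  # hypothetical 8-bit weights size in bytes, for illustration only
budget = target_weights_bytes(baseline, 0.5)
print(f"Budget: {budget} bytes ({budget / baseline:.0%} of the baseline)")
```

Under this assumption, a smaller ratio leaves less memory for weights, so the mixed-precision search has to assign lower bit-widths to more layers.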

Run GPTQ (Gradient-based PTQ) with PyTorch

In [None]:
@decorator
def GPTQ_Pytorch(float_model: torch.nn.Module) -> Tuple[bool, torch.nn.Module]:
    """
    Perform Gradient-based Post-Training Quantization (GPTQ) on PyTorch model.
    
    GPTQ for PyTorch provides:
    - Advanced gradient-based quantization optimization
    - Fine-tuning during quantization process
    - Superior accuracy preservation compared to PTQ
    - Optimized parameter updates using representative data
    
    Args:
        float_model: Original floating-point PyTorch model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for PyTorch GPTQ quantization
    method = 'GPTQ'                   # Gradient-based Post-Training Quantization
    framework = 'pytorch'             # Target framework (PyTorch)
    use_MCT_TPC = False               # Use external EdgeMDT Target Platform Capabilities
    use_MixP = False                  # Disable mixed-precision quantization

    # Parameter configuration for PyTorch GPTQ
    param_items = [
        # Platform configuration
        ['target_platform_version', 'v1', 'Target platform capabilities version'],
        
        # GPTQ-specific training parameters
        ['n_epochs', 5, 'Number of epochs for gradient-based fine-tuning'],
        ['optimizer', None, 'Optimizer for fine-tuning (None = use default Adam)'],
        
        # Output configuration
        ['save_model_path', './qmodel_GPTQ_Pytorch.onnx', 'Path to save GPTQ quantized model']
    ]

    # Execute PyTorch GPTQ quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model, method, framework, use_MCT_TPC, use_MixP, 
        representative_dataset_gen, param_items)
    return flag, quantized_model

Run GPTQ + Mixed Precision Quantization (MixP) with PyTorch

In [None]:
@decorator
def GPTQ_Pytorch_MixP(float_model: torch.nn.Module) -> Tuple[bool, torch.nn.Module]:
    """
    Perform Gradient-based Post-Training Quantization with Mixed Precision (GPTQ + MixP).
    
    This advanced method combines:
    - GPTQ: Gradient-based optimization for optimal quantization parameters
    - Mixed Precision: Automatic bit-width selection for each layer
    
    Provides the best quantization results for PyTorch models with:
    - Maximum accuracy preservation
    - Optimal model size reduction
    - Layer-wise precision optimization
    - Advanced gradient-based calibration
    
    Args:
        float_model: Original floating-point PyTorch model
    
    Returns:
        tuple: (success_flag, quantized_model)
    """
    # Configuration for PyTorch GPTQ with mixed precision
    method = 'GPTQ'                   # Gradient-based Post-Training Quantization
    framework = 'pytorch'             # Target framework (PyTorch)
    use_MCT_TPC = False               # Use external EdgeMDT Target Platform Capabilities
    use_MixP = True                   # Enable mixed-precision quantization

    # Parameter configuration for PyTorch GPTQ with Mixed Precision
    param_items = [
        # Platform configuration
        ['target_platform_version', 'v1', 'Target platform capabilities version'],
        
        # GPTQ-specific training parameters
        ['n_epochs', 5, 'Number of epochs for gradient-based fine-tuning'],
        ['optimizer', None, 'Optimizer for fine-tuning (None = use default Adam)'],
        
        # Mixed precision configuration
        ['num_of_images', 5, 'Number of images for mixed precision sensitivity analysis'],
        ['use_hessian_based_scores', False, 'Use Hessian-based scores for layer importance'],
        
        # Resource constraint configuration
        ['weights_compression_ratio', 0.5, 'Target compression ratio for model weights (50% reduction)'],
        
        # Output configuration
        ['save_model_path', './qmodel_GPTQ_Pytorch_MixP.onnx', 'Path to save GPTQ+MixP quantized model']
    ]

    # Execute advanced PyTorch GPTQ+MixP quantization using MCTWrapper
    wrapper = mct.wrapper.mct_wrapper.MCTWrapper()
    flag, quantized_model = wrapper.quantize_and_export(
        float_model, method, framework, use_MCT_TPC, use_MixP, 
        representative_dataset_gen, param_items)
    return flag, quantized_model

### Run model Post-Training Quantization
Lastly, we quantize our model using the MCTWrapper API.

In [None]:
# Load pre-trained MobileNetV2 model with ImageNet weights for quantization experiments
float_model = mobilenet_v2(weights=weights)

# Create DataLoader for validation/evaluation with larger batch size for efficiency
val_dataloader = DataLoader(dataset, batch_size=50, shuffle=False)

In [None]:
# Execute all PyTorch quantization methods on the same base model for comparison
print("Starting PyTorch quantization experiments with different methods...")

# 1. Basic Post-Training Quantization for PyTorch
flag, quantized_model = PTQ_Pytorch(float_model)

# 2. PTQ with Mixed Precision (optimized size/accuracy trade-off for PyTorch)
flag, quantized_model2 = PTQ_Pytorch_MixP(float_model)

# 3. Gradient-based PTQ (improved accuracy through fine-tuning for PyTorch)
flag, quantized_model3 = GPTQ_Pytorch(float_model)

# 4. GPTQ with Mixed Precision (best accuracy with optimal compression for PyTorch)
flag, quantized_model4 = GPTQ_Pytorch_MixP(float_model)

print("All PyTorch quantization methods completed successfully!")
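
As an alternative to numbered variables such as `quantized_model2`, the results can be collected in a dictionary keyed by method name. The following is a sketch of that pattern using stand-in callables so that it runs without MCT; `run_all`, `fake_ptq`, and `fake_gptq` are hypothetical names, and in the notebook the dictionary values would be the decorated functions defined above.

```python
from typing import Any, Callable, Dict, Tuple

QuantizeFn = Callable[[Any], Tuple[bool, Any]]


def run_all(methods: Dict[str, QuantizeFn], model: Any) -> Dict[str, Any]:
    """Run every quantization callable on `model` and collect the results by name."""
    results: Dict[str, Any] = {}
    for name, quantize in methods.items():
        flag, quantized = quantize(model)
        if not flag:
            raise RuntimeError(f"{name} failed")
        results[name] = quantized
    return results


# Stand-ins that only tag the input; real functions would return quantized models.
def fake_ptq(model: Any) -> Tuple[bool, Any]:
    return True, f"{model}-ptq"


def fake_gptq(model: Any) -> Tuple[bool, Any]:
    return True, f"{model}-gptq"


results = run_all({"PTQ": fake_ptq, "GPTQ": fake_gptq}, "mobilenet")
print(results)
```

The same dictionary can then drive the evaluation loop, which keeps each model paired with its label.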

## Models evaluation
To evaluate our models, we reuse the validation `DataLoader` (`val_dataloader`) created above. As before, please ensure that the dataset path has been set correctly.

In [None]:
# PyTorch Model Evaluation and Accuracy Comparison
print("Starting PyTorch model evaluation phase...")
print("This evaluation will test all quantized models against the validation dataset")

# Evaluate original floating-point PyTorch model accuracy
print("\n=== Original PyTorch Model Evaluation ===")
evaluate(float_model, val_dataloader, 'Float')

# Evaluate PTQ quantized PyTorch model accuracy
print("\n=== PyTorch PTQ Model Evaluation ===")
evaluate(quantized_model, val_dataloader, 'PTQ_Pytorch')

# Evaluate PTQ + Mixed Precision PyTorch model accuracy
print("\n=== PyTorch PTQ + Mixed Precision Model Evaluation ===")
evaluate(quantized_model2, val_dataloader, 'PTQ_Pytorch_MixP')

# Evaluate GPTQ quantized PyTorch model accuracy
print("\n=== PyTorch GPTQ Model Evaluation ===")
evaluate(quantized_model3, val_dataloader, 'GPTQ_Pytorch')

# Evaluate GPTQ + Mixed Precision PyTorch model accuracy
print("\n=== PyTorch GPTQ + Mixed Precision Model Evaluation ===")
evaluate(quantized_model4, val_dataloader, 'GPTQ_Pytorch_MixP')
print("Finish")
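
For the model size comparison, one rough starting point is the size on disk of the exported ONNX files. The sketch below reports the size of each path used in the `save_model_path` settings above, or notes that the file is missing. `exported_sizes_mb` is a hypothetical helper. File size is only a proxy: depending on the export format, quantized weights may still be stored as floating-point values, so the files need not shrink in proportion to the bit-width.

```python
import os
from typing import Dict, Iterable, Optional

# Paths taken from the save_model_path entries in the quantization functions above.
EXPORTED_MODELS = [
    "./qmodel_PTQ_Pytorch.onnx",
    "./qmodel_PTQ_Pytorch_MixP.onnx",
    "./qmodel_GPTQ_Pytorch.onnx",
    "./qmodel_GPTQ_Pytorch_MixP.onnx",
]


def exported_sizes_mb(paths: Iterable[str]) -> Dict[str, Optional[float]]:
    """Return the file size in MiB for each path, or None if the file does not exist."""
    sizes: Dict[str, Optional[float]] = {}
    for path in paths:
        if os.path.isfile(path):
            sizes[path] = os.path.getsize(path) / (1024 * 1024)
        else:
            sizes[path] = None
    return sizes


for path, size in exported_sizes_mb(EXPORTED_MODELS).items():
    if size is None:
        print(f"{path}: not found")
    else:
        print(f"{path}: {size:.2f} MiB")
```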

## Conclusion

In this tutorial, we demonstrated how to quantize a pre-trained model using MCTWrapper with a few lines of code.

MCT can deliver competitive results across a wide range of tasks and network architectures. For more details, check out the [paper](https://arxiv.org/abs/2109.09113).

## Copyrights

Copyright 2024 Sony Semiconductor Solutions, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
