# Post-Training Quantization in PyTorch using the Model Compression Toolkit (MCT)


## Overview
This quick-start guide explains how to use the **Model Compression Toolkit (MCT)** to quantize a PyTorch model. We will load a pre-trained model and quantize it using MCT with **Post-Training Quantization (PTQ)**. Finally, we will evaluate the quantized model and export it to an ONNX file.

## Summary
In this tutorial, we will cover:

1. Loading and preprocessing ImageNet’s validation dataset.
2. Constructing an unlabeled representative dataset.
3. Post-Training Quantization using MCT.
4. Accuracy evaluation of the floating-point and the quantized models.

## Setup
Install the relevant packages:

In [1]:
import os
from importlib import util

IMX500_AI_TOOLCHAIN_VER = os.getenv('IMX500_AI_TOOLCHAIN_VERSION', '0.0.0.dev0')

if not util.find_spec('imx500_ai_toolchain') or not util.find_spec("uni.pytorch"):
    print(f"Installing imx500_ai_toolchain {IMX500_AI_TOOLCHAIN_VER}")
    !pip install imx500_ai_toolchain[pt]

if not util.find_spec('torch') or not util.find_spec("torchvision"):
    !pip install -q torch torchvision

In [2]:
from torch.utils.data import DataLoader
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
from torchvision.datasets import ImageNet


Load a pre-trained MobileNetV2 model from torchvision in 32-bit floating-point precision.

In [3]:
weights = MobileNet_V2_Weights.IMAGENET1K_V2

float_model = mobilenet_v2(weights=weights)

## Dataset preparation
### Download ImageNet validation set
Download the ImageNet dataset, validation split only.

**Note** that for demonstration purposes we use the validation set for the model quantization routines. Usually, a subset of the training dataset is used, but loading it is a heavy procedure that is unnecessary for the sake of this demonstration.

This step may take several minutes...

In [4]:
import os

IMAGENET_DIR = os.getenv('IMAGENET_DIR', './imagenet')
print(f"using imagenet from: {IMAGENET_DIR}")

if not os.path.isdir(IMAGENET_DIR):
    # Download into IMAGENET_DIR so the files land where the dataset is loaded from.
    !mkdir -p {IMAGENET_DIR}
    !wget -P {IMAGENET_DIR} https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz
    !wget -P {IMAGENET_DIR} https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar

Extract the ImageNet validation dataset using the torchvision `datasets` module.

In [5]:
dataset = ImageNet(root=IMAGENET_DIR, split='val', transform=weights.transforms())

## Representative Dataset
PTQ with MCT requires a representative dataset, which is used to collect statistics from the model. It is supplied as a generator function, and each iteration yields a list containing one batch of images:

In [6]:
batch_size = 16
n_iter = 10

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

def representative_dataset_gen():
    dataloader_iter = iter(dataloader)
    for _ in range(n_iter):
        yield [next(dataloader_iter)[0]]
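
The contract of this generator can be illustrated without PyTorch. The sketch below is a minimal stand-in, assuming plain Python lists in place of the `DataLoader`; `make_representative_gen` and `fake_loader` are hypothetical names used only for illustration. It also restarts the iterator when the data source holds fewer than `n_iter` batches, because a bare `next()` that runs out inside a generator surfaces as a `RuntimeError`.

```python
from typing import Callable, Iterable, Iterator, List


def make_representative_gen(batches: Iterable, n_iter: int) -> Callable[[], Iterator[List]]:
    """Return a zero-argument generator function that yields `n_iter` one-element lists."""
    def gen() -> Iterator[List]:
        it = iter(batches)  # a fresh iterator on every call of gen()
        for _ in range(n_iter):
            try:
                images, _labels = next(it)
            except StopIteration:
                # Fewer than n_iter batches available: start over from the first batch.
                it = iter(batches)
                images, _labels = next(it)
            yield [images]  # labels are dropped; only the images are yielded
    return gen


# Hypothetical stand-in for a DataLoader: three (images, labels) batches.
fake_loader = [([[0.1, 0.2]], [3]), ([[0.3, 0.4]], [7]), ([[0.5, 0.6]], [1])]
rep_gen = make_representative_gen(fake_loader, n_iter=5)
batches_seen = list(rep_gen())
```

Because `gen` builds its iterator inside the function body, every call starts from the beginning of the data, which matters when the same generator function is passed to more than one step (here, both quantization and export).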


## Target Platform Capabilities (TPC)
In addition, MCT optimizes the model for dedicated hardware platforms. This is done using TPC (for more details, please visit our [documentation](https://sony.github.io/model_optimization/api/api_docs/modules/target_platform.html)). Here, we use the default PyTorch TPC:

In [7]:
import model_compression_toolkit as mct

# Get a FrameworkQuantizationCapabilities object that models the hardware platform
# for the quantized model inference. Here, for example, we use the default platform
# that is attached to a PyTorch layers representation.
target_platform_cap = mct.get_target_platform_capabilities('pytorch', 'default')

## Post-Training Quantization using MCT
Now for the exciting part! Let’s run PTQ on the model. 

In [8]:
quantized_model, quantization_info = mct.ptq.pytorch_post_training_quantization(
        in_module=float_model,
        representative_data_gen=representative_dataset_gen,
        target_platform_capabilities=target_platform_cap
)

Statistics Collection: 10it [00:21,  2.13s/it]



Running quantization parameters search. This process might take some time, depending on the model size and the selected quantization methods.



Calculating quantization parameters: 100%|██████████| 102/102 [00:09<00:00, 11.13it/s]


Please run your accuracy evaluation on the exported quantized model to verify it's accuracy.
Checkout the FAQ and Troubleshooting pages for resolving common issues and improving the quantized model accuracy:
FAQ: https://github.com/sony/model_optimization/tree/main/FAQ.md
Quantization Troubleshooting: https://github.com/sony/model_optimization/tree/main/quantization_troubleshooting.md





Our model is now quantized. MCT has created a simulated quantized model within the original PyTorch framework by inserting [quantization representation modules](https://github.com/sony/mct_quantizers). These modules, such as `PytorchQuantizationWrapper` and `PytorchActivationQuantizationHolder`, wrap PyTorch layers to simulate the quantization of weights and activations, respectively. The size of the saved model remains unchanged, but all the quantization parameters are stored within these modules and are ready for deployment on the target hardware. In this example, we used the default MCT settings, which quantize the model from 32 bits to 8 bits, a compression ratio of 4x. Next, we export the quantized model to an ONNX file:

In [9]:
save_folder = './mobilenet_pt'
os.makedirs(save_folder, exist_ok=True)
onnx_path = os.path.join(save_folder, 'qmodel.onnx')
mct.exporter.pytorch_export_model(quantized_model, save_model_path=onnx_path, repr_dataset=representative_dataset_gen)

Exporting onnx model with MCTQ quantizers: ./mobilenet_pt/qmodel.onnx
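
To make the idea of simulated 8-bit quantization concrete, here is a framework-free sketch of a quantize-then-dequantize step on a signed symmetric uniform grid. It is a generic illustration and not MCT's implementation; the function name `fake_quantize`, the sample values, and the threshold of `1.0` are all assumptions chosen for the example.

```python
def fake_quantize(values, threshold, n_bits=8):
    """Snap each value to a signed symmetric uniform grid and map it back to float."""
    levels = 2 ** (n_bits - 1)   # 128 integer steps per side for 8 bits
    scale = threshold / levels   # width of one quantization step
    simulated = []
    for v in values:
        q = round(v / scale)                  # nearest integer level
        q = max(-levels, min(levels - 1, q))  # clip to the range [-128, 127]
        simulated.append(q * scale)           # back to a float on the grid
    return simulated


weights = [0.5, -0.26, 0.013, 1.7]
simulated = fake_quantize(weights, threshold=1.0)

# Storing 8-bit integers instead of 32-bit floats gives the 4x ratio quoted above.
compression_ratio = 32 / 8
```

The values remain floats, which is why a simulated model does not shrink on disk; the saving appears only once the integer levels and the per-tensor parameters are what gets stored on the target. Values beyond the threshold, such as `1.7` here, are clipped to the largest representable level.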




In [10]:
import subprocess
from IPython.display import display, HTML

try:
    # Check if Java is installed
    result = subprocess.run(["java", "-version"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    if result.returncode == 0:
        print("Java is installed:")
        print(result.stderr.strip())  # Java version details are typically in stderr
    else:
        raise FileNotFoundError
except FileNotFoundError:
    # Display an error message and halt further execution
    display(
        HTML("<p style='color: red; font-weight: bold;'>Java is not installed. Please install Java 17 to proceed.</p>"))
    raise SystemExit("Stopping execution: Java is not installed.")

Java is installed:
openjdk version "17.0.13" 2024-10-15
OpenJDK Runtime Environment Homebrew (build 17.0.13+0)
OpenJDK 64-Bit Server VM Homebrew (build 17.0.13+0, mixed mode, sharing)
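
The same kind of pre-flight check can be made for any executable without launching it. The following is a minimal sketch using only the standard library; `find_tool` is a hypothetical helper and not part of the toolchain, and searching the interpreter's directory first is an assumption based on where console scripts are typically installed in a virtual environment.

```python
import os
import shutil
import sys


def find_tool(name, extra_dirs=()):
    """Return the full path of `name`, searching `extra_dirs` before PATH, or None."""
    search_path = os.pathsep.join([*extra_dirs, os.environ.get("PATH", "")])
    return shutil.which(name, path=search_path)


# Look next to the running interpreter first, then fall back to the regular PATH.
env_bin_dir = os.path.dirname(sys.executable)
for tool in ("java", "imxconv-pt"):
    location = find_tool(tool, extra_dirs=[env_bin_dir])
    print(f"{tool}: {location or 'not found'}")
```

Using `os.pathsep` rather than a hard-coded `:` keeps the search path valid on platforms that use a different separator.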


In [11]:
import subprocess
import sys
cmd = ["imxconv-pt", "-i", onnx_path,  "-o", save_folder, "--overwrite-output"]

# Prepend the active environment's bin directory so the converter is found on PATH.
env_bin_path = os.path.dirname(sys.executable)
os.environ["PATH"] = f"{env_bin_path}{os.pathsep}{os.environ['PATH']}"
env = os.environ.copy()

subprocess.run(cmd, env=env, check=True)

2025-01-14 19:41:52,786 INFO : Running version 1.10.0 [/Users/chizkiyahu/envs/wr_pt/lib/python3.11/site-packages/uni/common/logger.py:148]
2025-01-14 19:41:52,786 INFO : Converting mobilenet_pt/qmodel.onnx [/Users/chizkiyahu/envs/wr_pt/lib/python3.11/site-packages/uni/common/logger.py:148]
2025-01-14 19:41:54,838 INFO : Wrote outputs to /var/folders/pp/xsvkwn4n6dz94h8y0_nrw_fw0000gn/T/tmpvwiazter/qmodel.uni-pytorch.um.pb [/Users/chizkiyahu/envs/wr_pt/lib/python3.11/site-packages/uni/common/logger.py:148]
2025-01-14 19:41:54,851 INFO : Converted successfully [/Users/chizkiyahu/envs/wr_pt/lib/python3.11/site-packages/uni/common/logger.py:148]
2025-01-14 19:41:55,232 INFO : CODE: [START] Starting SDSPconv
2025-01-14 19:42:03,696 INFO : ConvFe conversion finished successfully
2025-01-14 19:42:04,084 INFO : CBE component - DspConvParser has started conversion.
2025-01-14 19:42:04,217 INFO : Dsp-Dnn-Parser finished successfully !
2025-01-14 19:42:05,047 INFO : LogicModel generated successful

CompletedProcess(args=['imxconv-pt', '-i', './mobilenet_pt/qmodel.onnx', '-o', './mobilenet_pt', '--overwrite-output'], returncode=0)
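
The MCT log above recommends running an accuracy evaluation on the quantized model. The sketch below shows one framework-agnostic way to compute top-1 accuracy; it is an illustration under stated assumptions, not code from the original notebook. `top1_accuracy`, `toy_batches`, and `identity_model` are hypothetical names, and the toy data stands in for real images and model outputs.

```python
from typing import Callable, Iterable, List, Sequence, Tuple

Batch = Tuple[Sequence, Sequence[int]]


def top1_accuracy(predict: Callable[[Sequence], List[List[float]]],
                  batches: Iterable[Batch]) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    correct = 0
    total = 0
    for inputs, labels in batches:
        scores = predict(inputs)  # one row of class scores per sample
        for row, label in zip(scores, labels):
            predicted = max(range(len(row)), key=row.__getitem__)  # argmax
            correct += int(predicted == label)
            total += 1
    return correct / total if total else 0.0


# Toy stand-ins: each "image" is already a row of three class scores.
toy_batches = [([[0.1, 0.7, 0.2], [0.9, 0.05, 0.05]], [1, 0]),
               ([[0.2, 0.3, 0.5], [0.6, 0.3, 0.1]], [2, 1])]
identity_model = lambda rows: rows
accuracy = top1_accuracy(identity_model, toy_batches)
```

To compare the two models in this tutorial, `predict` could wrap `float_model` and then `quantized_model` (for instance, by running the model in evaluation mode without gradients and converting the output to nested lists), with `batches` drawn from a `DataLoader` over `dataset`. That wiring is an assumption and would need to be adapted to the actual setup.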

## Conclusion

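In this tutorial, we demonstrated how to quantize a pre-trained MobileNetV2 ImageNet classification model in a hardware-friendly manner using MCT, export it to ONNX, and convert it with `imxconv-pt`. The default 8-bit settings give a 4x compression ratio relative to 32-bit floating point; the effect on accuracy should be verified on the exported model, as the MCT log recommends.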

The key advantage of hardware-friendly quantization is that the model can run more efficiently in terms of runtime, power consumption, and memory usage on designated hardware.

MCT can deliver competitive results across a wide range of tasks and network architectures. For more details, check out [the paper](https://arxiv.org/abs/2109.09113).

## Copyrights

Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
