# TensorRT Quantization Tutorial

This notebook demonstrates the features of the TensorRT passes integrated into MASE as part of the MASERT framework. The following demonstrations were run on an NVIDIA RTX A2000 GPU with an Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz.

## Section 1. Show Configuration
First, we will show how to perform an int8 quantization of a simple model, `jsc-toy`, and compare the quantized model to the original using the `Machop API`. The quantization process is split into the following stages, each implemented by its own pass and explained in depth in its own subsection:

1. [Fake quantization](#section-11-fake-quantization): `tensorrt_fake_quantize_transform_pass`
2. [Calibration](#section-12-calibration): `tensorrt_calibrate_transform_pass`
3. [Quantization-Aware Training](#section-13-quantized-aware-training-qat): `tensorrt_fine_tune_transform_pass`
4. [Quantization](#section-14-tensorrt-quantization): `tensorrt_engine_interface_pass`
5. [Analysis](#section-15-performance-analysis): `tensorrt_analysis_pass`
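To make the first two stages more concrete, the following is a minimal numerical sketch of the idea behind them: calibration picks a dynamic range (`amax`) from observed values, and fake quantization rounds values onto the int8 grid for that range and immediately maps them back to floating point. This is an illustration under simplifying assumptions, not the implementation of the MASE passes listed above; the function names `percentile_amax` and `fake_quantize_int8` are hypothetical, and the percentile is computed directly on raw values rather than on a histogram.

```python
import math


def percentile_amax(samples, percentile):
    """Return the magnitude below which `percentile` % of |samples| fall.

    Simplified nearest-rank percentile over raw values (an assumption of
    this sketch); histogram-based calibrators estimate a similar quantity
    from binned counts.
    """
    if not samples:
        raise ValueError("samples must not be empty")
    if not 0 < percentile <= 100:
        raise ValueError("percentile must be in (0, 100]")
    magnitudes = sorted(abs(s) for s in samples)
    rank = math.ceil(percentile / 100.0 * len(magnitudes))  # 1-based rank
    return magnitudes[rank - 1]


def fake_quantize_int8(values, amax):
    """Quantize-dequantize floats on a symmetric int8 grid of range amax."""
    if amax <= 0:
        raise ValueError("amax must be positive")
    scale = amax / 127.0  # width of one int8 step
    out = []
    for v in values:
        q = round(v / scale)        # nearest integer level
        q = max(-127, min(127, q))  # clamp to the symmetric int8 range
        out.append(q * scale)       # map back to floating point
    return out


if __name__ == "__main__":
    # Hypothetical activations: evenly spread values plus one large outlier.
    activations = [0.01 * i for i in range(-100, 101)] + [25.0]
    amax = percentile_amax(activations, 99.0)
    print("calibrated amax:", amax)
    print(fake_quantize_int8([-0.5, 0.013, 0.5, 25.0], amax))
```

Choosing a percentile below 100 deliberately clips outliers: values larger than `amax` saturate at the edge of the int8 range, in exchange for a finer step size for the remaining values.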

We start by loading the libraries and passes required for the notebook and ensuring that the correct path is set so that machop can be used.

In [11]:
import sys
import os
from pathlib import Path
import toml

# Figure out the correct path
machop_path = Path(".").resolve().parent.parent.parent / "src"
assert machop_path.exists(), "Failed to find machop at: {}".format(machop_path)
sys.path.append(str(machop_path))

# Add directory to the PATH so that chop can be called
new_path = "../../../machop"
full_path = os.path.abspath(new_path)
os.environ['PATH'] += os.pathsep + full_path

from chop.tools.utils import to_numpy_if_tensor
from chop.tools.logger import set_logging_verbosity
from chop.tools import get_cf_args, get_dummy_input
from chop.passes.graph.utils import deepcopy_mase_graph
from chop.tools.get_input import InputGenerator
from chop.tools.checkpoint_load import load_model
from chop.ir import MaseGraph
from chop.models import get_model_info, get_model, get_tokenizer
from chop.dataset import MaseDataModule, get_dataset_info
from chop.passes.graph.transforms import metadata_value_type_cast_transform_pass
from chop.passes.graph import (
    summarize_quantization_analysis_pass,
    add_common_metadata_analysis_pass,
    init_metadata_analysis_pass,
    add_software_metadata_analysis_pass,
    tensorrt_calibrate_transform_pass,
    tensorrt_fake_quantize_transform_pass,
    tensorrt_fine_tune_transform_pass,
    tensorrt_engine_interface_pass,
    runtime_analysis_pass,
    )

set_logging_verbosity("info")

INFO     Set logging level to info


Next, we check that the dependencies of the TensorRT passes are installed. Note that the `cuda` dependency refers to the `cuda-python` package.

In [12]:
from chop.tools.check_dependency import check_deps_tensorRT_pass
check_deps_tensorRT_pass(silent=False)

INFO     Extension: All dependencies for TensorRT pass are available.


True

Next, we load the TOML file used for quantization. To view the configuration, click [here](../../../machop/configs/tensorrt/jsc_toy_INT8_quantization_by_type.toml).

In [13]:
import toml
# Path to your TOML file
# JSC_TOML_PATH = 'toy_INT8_quantization_by_type.toml'
JSC_TOML_PATH = 'resnet50_INT8_quant.toml'

# Reading TOML file and converting it into a Python dictionary
with open(JSC_TOML_PATH, 'r') as toml_file:
    pass_args = toml.load(toml_file)

# Extract the 'passes.tensorrt' section and its children
tensorrt_config = pass_args.get('passes', {}).get('tensorrt', {})
print(tensorrt_config)
# Extract the 'passes.tensorrt.runtime_analysis' section and its children
runtime_analysis_config = pass_args.get('passes', {}).get('tensorrt', {}).get('runtime_analysis', {})
print(runtime_analysis_config)

{'by': 'type', 'num_calibration_batches': 10, 'post_calibration_analysis': True, 'default': {'config': {'quantize': True, 'calibrators': ['percentile', 'mse', 'entropy'], 'percentiles': [99.0, 99.9, 99.99], 'precision': 'int8'}, 'input': {'calibrator': 'histogram', 'quantize_axis': False}, 'weight': {'calibrator': 'histogram', 'quantize_axis': False}}, 'fine_tune': {'fine_tune': True}, 'runtime_analysis': {'num_batches': 500, 'num_GPU_warmup_batches': 5, 'test': True}}
{'num_batches': 500, 'num_GPU_warmup_batches': 5, 'test': True}
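The `num_batches` and `num_GPU_warmup_batches` settings printed above suggest that timing runs are preceded by a few untimed warm-up batches. The sketch below shows that general measurement pattern in plain Python. It is not the implementation of `runtime_analysis_pass`; the function `measure_latency_ms` and the `run_batch` callable are hypothetical stand-ins.

```python
import time


def measure_latency_ms(run_batch, num_batches, num_warmup_batches):
    """Return the mean wall-clock latency per batch in milliseconds.

    `run_batch` is a hypothetical zero-argument callable that executes one
    inference batch. Warm-up calls are executed but not timed.
    """
    if num_batches <= 0:
        raise ValueError("num_batches must be positive")
    for _ in range(num_warmup_batches):
        run_batch()  # untimed: lets caches and lazy initialisation settle
    total = 0.0
    for _ in range(num_batches):
        start = time.perf_counter()
        run_batch()
        total += time.perf_counter() - start
    return total / num_batches * 1000.0


if __name__ == "__main__":
    # Stand-in workload; a real run would call the model or engine instead.
    def dummy_batch():
        return sum(i * i for i in range(10_000))

    print(measure_latency_ms(dummy_batch, num_batches=500, num_warmup_batches=5))
```

When timing GPU work from PyTorch, the device should also be synchronized (for example with `torch.cuda.synchronize()`) before each timestamp is read, since kernel launches are asynchronous; the sketch omits this to stay dependency-free.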


We then load the model and dataset specified in the TOML configuration and wrap the model in a `MaseGraph`. The model is then trained using the same configuration.

In [14]:
from chop.dataset import MaseDataModule
from chop.models import get_model_info
from chop.models import get_model
from chop.tools.get_input import InputGenerator

# Load the basics in
model_name = pass_args['model']
dataset_name = pass_args['dataset']
max_epochs = pass_args['max_epochs']
batch_size = pass_args['batch_size']
learning_rate = pass_args['learning_rate']
accelerator = pass_args['accelerator']

data_module = MaseDataModule(
    name=dataset_name,
    batch_size=batch_size,
    model_name=model_name,
    num_workers=0,
)
data_module.prepare_data()
data_module.setup()

# Add the data_module and other necessary information to the configs
configs = [tensorrt_config, runtime_analysis_config]
for config in configs:
    config['task'] = pass_args['task']
    config['dataset'] = pass_args['dataset']
    config['batch_size'] = pass_args['batch_size']
    config['model'] = pass_args['model']
    config['data_module'] = data_module
    config['accelerator'] = 'cuda' if pass_args['accelerator'] == 'gpu' else pass_args['accelerator']
    # 'gpu' has just been mapped to 'cuda', so test for 'cuda' here
    if config['accelerator'] == 'cuda':
        os.environ['CUDA_MODULE_LOADING'] = 'LAZY'

model_info = get_model_info(model_name)
# quant_modules.initialize()
model = get_model(
    model_name,
    # task="cls",
    dataset_info=data_module.dataset_info,
    pretrained=False)


input_generator = InputGenerator(
    data_module=data_module,
    model_info=model_info,
    task="cls",
    which_dataloader="train",
)

# generate the mase graph and initialize node metadata
mg = MaseGraph(model=model)

model_info is MaseModelInfo(name='resnet', model_source=<ModelSource.TORCHVISION: 'torchvision'>, task_type=<ModelTaskType.VISION: 'vision'>, image_classification=True, physical_data_point_classification=False, sequence_classification=False, seq2seqLM=False, causal_LM=False, is_quantized=False, is_lora=False, is_sparse=False, is_fx_traceable=True)


In [16]:
!python3 ./ch train --config /workspace/ADLS_Proj/docs/tutorials/proj/resnet50_INT8_quant.toml

INFO: Seed set to 0
I0315 18:20:15.102040 140174726280256 seed.py:57] Seed set to 0
+-------------------------+--------------------------+--------------+-----------------+--------------------------+
| Name                    |         Default          | Config. File | Manual Override |        Effective         |
+-------------------------+--------------------------+--------------+-----------------+--------------------------+
| task                    |      classification      |     cls      |                 |           cls            |
| load_name               |           None           |              |                 |           None           |
| load_type               |            mz            |              |                 |            mz            |
| batch_size              |           128            |      64      |                 |            64            |
| to_debug                |          False           |              |                

Then we load in the checkpoint. You will have to adjust this according to where it has been stored in the mase_output directory.

In [None]:
# Load in the trained checkpoint - change this accordingly
RES_CHECKPOINT_PATH = "/workspace/ADLS_Proj/mase_output/resnet18_cls_cifar10_2025-03-08/software/training_ckpts/best.ckpt"


model = load_model(load_name=RES_CHECKPOINT_PATH, load_type="pl", model=model)
print("load model done!")

# Initiate metadata
dummy_in = next(iter(input_generator))
print("dummy in done")

_ = model(**dummy_in)
print("_ done")

mg, _ = init_metadata_analysis_pass(mg, None)
print("init_metadata_analysis_pass done")

mg, _ = add_common_metadata_analysis_pass(mg, {"dummy_in": dummy_in})
print("add_common_metadata_analysis_pass done")

mg, _ = add_software_metadata_analysis_pass(mg, None)
print("add_software_metadata_analysis_pass done")

mg, _ = metadata_value_type_cast_transform_pass(mg, pass_args={"fn": to_numpy_if_tensor})
print("metadata_value_type_cast_transform_pass done")

# Before we begin, we will copy the original MaseGraph model to use for comparison during quantization analysis
mg_original = deepcopy_mase_graph(mg)
print("deep copy done")

INFO     Loaded pytorch lightning checkpoint from /workspace/ADLS_Proj/mase_output/resnet18_cls_cifar10_2025-03-08/software/training_ckpts/best.ckpt


load model done!
dummy in done
_ done
init_metadata_analysis_pass done
add_common_metadata_analysis_pass done
add_software_metadata_analysis_pass done
metadata_value_type_cast_transform_pass done
using safe deepcopy
deep copy done


## Section 2. ResNet: INT8/FP16/FP32 Quantization Comparison

We will now load a new TOML configuration that uses fp16 instead of int8, keeping all other settings the same for a fair comparison. This time, however, we will run chop from the terminal, which runs all the passes showcased in [Section 1](#section-1---int8-quantization).

Since float quantization does not require calibration and is not supported by `pytorch-quantization`, the model will not undergo fake quantization. For the time being, this unfortunately means that QAT is unavailable and the model only undergoes Post-Training Quantization (PTQ).
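One way to see why fp16 needs no calibration step is that converting to half precision is a direct rounding to the nearest representable value, with no data-dependent range to choose. The sketch below illustrates only the number format, using Python's standard `struct` module; it says nothing about how TensorRT builds an fp16 engine, and the helper name `to_fp16` is hypothetical.

```python
import struct


def to_fp16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value.

    Raises OverflowError if the value is too large for half precision.
    """
    return struct.unpack("<e", struct.pack("<e", value))[0]


if __name__ == "__main__":
    for x in (0.1, 1.0, 3.14159, 1000.123):
        y = to_fp16(x)
        print(f"{x!r} -> {y!r} (abs error {abs(x - y):.3e})")
```

The rounding error grows with the magnitude of the value, because half precision keeps a fixed number of significand bits rather than a fixed step size as int8 does.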

In [2]:
RES_INT8_BY_TYPE_TOML = "/workspace/ADLS_Proj/docs/tutorials/proj/resnet50_INT8_quant_debug.toml"
RES_CHECKPOINT_PATH = "/workspace/ADLS_Proj/mase_output/resnet50_cls_cifar10_2025-03-15/software/training_ckpts/best.ckpt"
!python ch transform --config {RES_INT8_BY_TYPE_TOML} --load {RES_CHECKPOINT_PATH} --load-type pl

INFO: Seed set to 0
I0316 23:05:43.177328 139704359384128 seed.py:57] Seed set to 0
+-------------------------+--------------------------+--------------+--------------------------+--------------------------+
| Name                    |         Default          | Config. File |     Manual Override      |        Effective         |
+-------------------------+--------------------------+--------------+--------------------------+--------------------------+
| task                    |      classification      |     cls      |                          |           cls            |
| load_name               |           None           |              | /workspace/ADLS_Proj/mas | /workspace/ADLS_Proj/mas |
|                         |                          |              | e_output/resnet50_cls_ci | e_output/resnet50_cls_ci |
|                         |                          |              | far10_2025-03-15/softwar | far10_2025-03-15/softwar |
|                     

In [3]:
RES_INT8_BY_TYPE_TOML = "/workspace/ADLS_Proj/docs/tutorials/proj/resnet50_INT8_quant.toml"
RES_CHECKPOINT_PATH = "/workspace/ADLS_Proj/mase_output/resnet50_cls_cifar10_2025-03-15/software/training_ckpts/best.ckpt"
!python ch transform --config {RES_INT8_BY_TYPE_TOML} --load {RES_CHECKPOINT_PATH} --load-type pl

INFO: Seed set to 0
I0316 00:18:52.977772 140198266307648 seed.py:57] Seed set to 0
+-------------------------+--------------------------+--------------+--------------------------+--------------------------+
| Name                    |         Default          | Config. File |     Manual Override      |        Effective         |
+-------------------------+--------------------------+--------------+--------------------------+--------------------------+
| task                    |      classification      |     cls      |                          |           cls            |
| load_name               |           None           |              | /workspace/ADLS_Proj/mas | /workspace/ADLS_Proj/mas |
|                         |                          |              | e_output/resnet50_cls_ci | e_output/resnet50_cls_ci |
|                         |                          |              | far10_2025-03-15/softwar | far10_2025-03-15/softwar |
|                     

In [4]:
RES_FP16_BY_TYPE_TOML = "/workspace/ADLS_Proj/docs/tutorials/proj/resnet50_FP16_quant.toml"
RES_CHECKPOINT_PATH = "/workspace/ADLS_Proj/mase_output/resnet50_cls_cifar10_2025-03-15/software/training_ckpts/best.ckpt"
!python ch transform --config {RES_FP16_BY_TYPE_TOML} --load {RES_CHECKPOINT_PATH} --load-type pl

INFO: Seed set to 0
I0316 01:04:00.709789 140315264521280 seed.py:57] Seed set to 0
+-------------------------+--------------------------+--------------+--------------------------+--------------------------+
| Name                    |         Default          | Config. File |     Manual Override      |        Effective         |
+-------------------------+--------------------------+--------------+--------------------------+--------------------------+
| task                    |      classification      |     cls      |                          |           cls            |
| load_name               |           None           |              | /workspace/ADLS_Proj/mas | /workspace/ADLS_Proj/mas |
|                         |                          |              | e_output/resnet50_cls_ci | e_output/resnet50_cls_ci |
|                         |                          |              | far10_2025-03-15/softwar | far10_2025-03-15/softwar |
|                     

In [5]:
RES_FP32_BY_TYPE_TOML = "/workspace/ADLS_Proj/docs/tutorials/proj/resnet50_FP32_quant.toml"
RES_CHECKPOINT_PATH = "/workspace/ADLS_Proj/mase_output/resnet50_cls_cifar10_2025-03-15/software/training_ckpts/best.ckpt"
!python ch transform --config {RES_FP32_BY_TYPE_TOML} --load {RES_CHECKPOINT_PATH} --load-type pl

INFO: Seed set to 0
I0316 01:09:12.493298 140488866358336 seed.py:57] Seed set to 0
+-------------------------+--------------------------+--------------+--------------------------+--------------------------+
| Name                    |         Default          | Config. File |     Manual Override      |        Effective         |
+-------------------------+--------------------------+--------------+--------------------------+--------------------------+
| task                    |      classification      |     cls      |                          |           cls            |
| load_name               |           None           |              | /workspace/ADLS_Proj/mas | /workspace/ADLS_Proj/mas |
|                         |                          |              | e_output/resnet50_cls_ci | e_output/resnet50_cls_ci |
|                         |                          |              | far10_2025-03-15/softwar | far10_2025-03-15/softwar |
|                     

As you can see, `fp16` achieves a slightly higher test accuracy but a slightly lower latency (~30%) compared with int8 quantization; it is still ~2.5x faster than the unquantized model. Now let's apply quantization to a more complicated model.
