# 3. Quantization

This notebook demonstrates how to quantize a model to INT8 using `torch.quantization`. INT8 quantization converts a 32-bit floating-point (FP32) model into an 8-bit integer representation, which makes the model up to 4x smaller and, on hardware with efficient integer kernels, speeds up inference through integer arithmetic.

This is achieved by replacing the floating-point operations in the model with their integer counterparts, which requires a scale ($S$) and a zero-point ($Z$) for each quantized tensor; for activations, these parameters are determined in a calibration step. The relationship between a real floating-point value ($r$) and its quantized integer representation ($q$) is given by:

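$q = \text{clamp}\left(\text{round}(r / S) + Z,\ q_{\text{min}},\ q_{\text{max}}\right)$

where $[q_{\text{min}}, q_{\text{max}}]$ is the target integer range (e.g., $[0, 255]$ for UINT8), so values outside the representable range saturate at its bounds.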

Conversely, the dequantized value ($r_{\text{approx}}$) can be obtained by:

$r_{\text{approx}} = S \cdot (q - Z)$
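As a minimal sketch of the quantize and dequantize formulas (plain Python, no PyTorch), the two functions below map a float to an integer, clamping to $[q_{\text{min}}, q_{\text{max}}]$ as real quantizers do, and map it back. The values of $S$ and $Z$ are made-up UINT8 parameters chosen for illustration, not ones produced by `torch.quantization`.

```python
def quantize(r: float, scale: float, zero_point: int, qmin: int, qmax: int) -> int:
    """q = clamp(round(r / S) + Z, qmin, qmax)."""
    # Python's round() rounds halves to even, like torch.round
    q = round(r / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    """r_approx = S * (q - Z)."""
    return scale * (q - zero_point)


# Illustrative UINT8 parameters: S = 0.1, Z = 128 cover roughly [-12.8, 12.7]
S, Z, QMIN, QMAX = 0.1, 128, 0, 255
for r in (0.0, 1.234, -5.0, 100.0):
    q = quantize(r, S, Z, QMIN, QMAX)
    r_approx = dequantize(q, S, Z)
    print(f"r={r:8.3f} -> q={q:3d} -> r_approx={r_approx:8.3f} (error {abs(r - r_approx):.3f})")
```

In-range values come back within $S/2$ of the original, while `100.0` saturates at $q = 255$ and returns as roughly 12.7, which is why the choice of range matters so much.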

The calibration process runs inference on a representative dataset (the calibration dataset) to observe the range of activation values; weight ranges can be read directly from the weight tensors and need no data. The goal is to choose $S$ and $Z$ for each layer's weights and activations so that the accuracy loss due to quantization is minimized.
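The calibration step can be sketched with a hypothetical `RunningMinMax` class (not a PyTorch API) that tracks the smallest and largest activation seen across calibration batches and then applies the asymmetric min-max formula. Like PyTorch's `MinMaxObserver`, it widens the range to include 0.0 so that zero stays exactly representable. The batch values are made up.

```python
class RunningMinMax:
    """Hypothetical minimal observer: tracks the min/max seen across calibration batches."""

    def __init__(self):
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, batch):
        # Only statistics are recorded; the activations themselves are not changed
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def qparams(self, qmin=0, qmax=255):
        # Widen the range to include 0.0 so that zero stays exactly representable
        r_min, r_max = min(self.min_val, 0.0), max(self.max_val, 0.0)
        scale = max((r_max - r_min) / (qmax - qmin), 1e-8)  # guard against a zero-width range
        zero_point = qmin - round(r_min / scale)             # asymmetric min-max formula
        return scale, max(qmin, min(qmax, zero_point))


# Made-up activations from three calibration batches
calibration_batches = [
    [0.1, 0.9, 1.7, 0.0],
    [2.3, 0.4, -0.2, 1.1],
    [0.8, 3.1, 0.5, 0.2],
]
observer = RunningMinMax()
for batch in calibration_batches:
    observer.observe(batch)

scale, zero_point = observer.qparams()
print(f"observed range [{observer.min_val}, {observer.max_val}] -> S={scale:.5f}, Z={zero_point}")
```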

Quantization is applied to both the weights and the activations of the model. The weights are quantized to INT8, while the activations can be quantized to either INT8 or UINT8, depending on the target backend. The quantization scheme (symmetric or asymmetric) also determines how $S$ and $Z$ are calculated.
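To see how integer arithmetic replaces float arithmetic, here is a minimal sketch of one dot product with symmetric INT8 weights ($Z_w = 0$) and asymmetric UINT8 activations. Substituting $r \approx S(q - Z)$ for each operand gives $\sum_k w_k x_k \approx S_w S_x \sum_k q_{w,k}(q_{x,k} - Z_x)$, so the sum is accumulated in integers and rescaled once at the end. The weights and activations are made up, and optimized kernels typically fold the zero-point terms into precomputed offsets rather than subtracting $Z_x$ element by element.

```python
def quantize(values, scale, zero_point, qmin, qmax):
    """Element-wise q = clamp(round(r / S) + Z, qmin, qmax)."""
    return [max(qmin, min(qmax, round(v / scale) + zero_point)) for v in values]


# Made-up float weights and activations for one output neuron
w = [0.5, -1.2, 0.8, 0.05]
x = [0.3, -0.7, 2.0, 1.1]

# Weights: symmetric INT8, Z_w = 0, S_w = max|w| / 127
s_w = max(abs(v) for v in w) / 127
q_w = quantize(w, s_w, 0, -128, 127)

# Activations: asymmetric UINT8 over [min(x), max(x)] (this range already contains 0.0)
s_x = (max(x) - min(x)) / 255
z_x = 0 - round(min(x) / s_x)
q_x = quantize(x, s_x, z_x, 0, 255)

# Integer-only multiply-accumulate (typically an int32 accumulator in real kernels) ...
acc = sum(qw * (qx - z_x) for qw, qx in zip(q_w, q_x))
# ... followed by a single floating-point rescale: y ~= S_w * S_x * acc
y_int = s_w * s_x * acc
y_float = sum(a * b for a, b in zip(w, x))
print(f"float dot product: {y_float:.4f}, integer dot product: {y_int:.4f}")
```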

There are two main approaches to quantizing a model:

* **Post-Training Quantization (PTQ):** This is the simpler approach, where a pre-trained FP32 model is quantized without any retraining. It involves:
    * Calibrating the model by feeding it a small, representative dataset (calibration data) to collect statistics (e.g., min/max ranges) of the weights and activations.
    * Using these statistics to calculate the quantization parameters ($S$ and $Z$) for each layer.

  PTQ is fast and easy to implement, but it can cause a larger accuracy drop, especially for models that are sensitive to quantization.

* **Quantization-Aware Training (QAT):** QAT simulates the quantization effects (rounding noise and clamping) during training or fine-tuning. It inserts "fake quantization" modules into the model, which mimic the quantize-dequantize step in the forward pass, while the backward pass uses a straight-through estimator (STE) so that gradients can flow through the non-differentiable rounding.

  The model learns to become robust to quantization noise, often reaching better accuracy than PTQ, but QAT requires access to the training pipeline and more compute.
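The fake-quantization step used by QAT can be sketched in plain Python. The forward pass quantizes and immediately dequantizes, so the value stays a float but carries the rounding and clamping error. The backward pass uses a straight-through estimator: the gradient passes unchanged for in-range values and is zeroed for clamped ones. This mirrors the masking behavior of PyTorch's fake-quantize ops, but the functions and parameters here are illustrative, not PyTorch APIs.

```python
def fake_quantize(x, scale, zero_point, qmin, qmax):
    """Forward pass: quantize, then immediately dequantize (output stays a float)."""
    q = max(qmin, min(qmax, round(x / scale) + zero_point))
    return scale * (q - zero_point)


def fake_quantize_grad(x, scale, zero_point, qmin, qmax, upstream_grad):
    """Backward pass with a straight-through estimator (STE).

    round() has zero gradient almost everywhere, so the STE treats it as the
    identity: the upstream gradient passes through unchanged for values inside
    the representable range and is zeroed for values that were clamped.
    """
    q = round(x / scale) + zero_point
    return upstream_grad if qmin <= q <= qmax else 0.0


# Illustrative symmetric INT8 parameters: representable range is about [-6.4, 6.35]
S, Z, QMIN, QMAX = 0.05, 0, -128, 127
for x in (0.123, 3.0, 10.0):
    y = fake_quantize(x, S, Z, QMIN, QMAX)
    g = fake_quantize_grad(x, S, Z, QMIN, QMAX, upstream_grad=1.0)
    print(f"x={x:6.3f} -> fake-quantized {y:6.3f}, dL/dx = {g}")
```

Zeroing the gradient for clamped values is what lets training push activations back into the representable range instead of treating saturation as free.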

The choice of $S$ and $Z$ is crucial for minimizing the quantization error. Common methods include:

* **Min-Max:** The simplest method. $S$ and $Z$ are determined by the minimum ($r_{\text{min}}$) and maximum ($r_{\text{max}}$) values observed in the tensor during calibration. Depending on how 0.0 maps to integers, the formulas are:
    * For asymmetric quantization (where 0.0 maps to the integer $Z$, which need not be 0):

      $$S = \frac{r_{\text{max}} - r_{\text{min}}}{q_{\text{max}} - q_{\text{min}}}, \qquad Z = q_{\text{min}} - \text{round}\left(\frac{r_{\text{min}}}{S}\right)$$

      where $q_{\text{min}}$ and $q_{\text{max}}$ are the bounds of the target integer range, e.g., 0 and 255 for uint8, or -128 and 127 for int8.
    * For symmetric quantization (where the range is centered on 0.0, so $Z = 0$ for signed types and $Z$ is fixed at the midpoint of the range, e.g., 128, for unsigned types):

      $$S = \frac{\max(|r_{\text{min}}|, |r_{\text{max}}|)}{Q}$$

      where $Q$ is the largest positive offset from $Z$ in the integer range, e.g., $Q = 127$ for int8.

  This method is sensitive to outliers, as a single extreme value can significantly stretch the range and thus inflate the scale.

* **Percentile:** Instead of using the absolute min/max, this method uses percentiles of the observed values (e.g., the 1st and 99th, or the 0.1st and 99.9th percentiles) as $r_{\text{min}}$ and $r_{\text{max}}$. Values outside this range are clamped, which limits the impact of extreme outliers on the scale.

* **Histogram:** This method builds a histogram of the observed values and uses it to choose the clipping range, for example the one that minimizes the estimated quantization error over the histogram. It can be more robust against outliers than min-max, as it considers the distribution of values rather than just the extremes.

* **Entropy (or KL-Divergence):** This method aims to find quantization parameters that minimize the information loss, often measured by the Kullback-Leibler (KL) divergence between the distribution of the original FP32 values and the dequantized INT8 values. It involves iterating through different clipping thresholds (which define the effective $r_{\text{min}}$ and $r_{\text{max}}$) and selecting the one that minimizes the KL divergence. This often provides a better trade-off between range coverage and quantization precision, especially when outliers are present.
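A minimal sketch comparing the min-max and percentile methods on synthetic activations: 10,000 standard-normal samples plus one extreme value of 40.0. Both use the asymmetric uint8 formulas from the Min-Max item; only the clipping range differs. The `percentile` helper interpolates linearly between ranks (the same default as NumPy's `percentile`), and the data and the 0.1/99.9 cut-offs are assumptions for illustration.

```python
import random


def asymmetric_qparams(r_min, r_max, qmin=0, qmax=255):
    """Asymmetric (min-max style) S and Z for a given clipping range [r_min, r_max]."""
    r_min, r_max = min(r_min, 0.0), max(r_max, 0.0)     # keep 0.0 exactly representable
    scale = max((r_max - r_min) / (qmax - qmin), 1e-8)  # guard against a zero-width range
    zero_point = qmin - round(r_min / scale)
    return scale, zero_point


def percentile(values, p):
    """p-th percentile (0-100), linearly interpolated between the closest ranks."""
    xs = sorted(values)
    k = (len(xs) - 1) * p / 100.0
    lo = int(k)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (k - lo)


def mean_abs_error(values, scale, zero_point, qmin=0, qmax=255):
    """Average |r - dequant(quant(r))|, including the error from clamping."""
    total = 0.0
    for r in values:
        q = max(qmin, min(qmax, round(r / scale) + zero_point))
        total += abs(r - scale * (q - zero_point))
    return total / len(values)


random.seed(0)
acts = [random.gauss(0.0, 1.0) for _ in range(10_000)] + [40.0]  # one extreme outlier

s_mm, z_mm = asymmetric_qparams(min(acts), max(acts))                              # full range
s_pct, z_pct = asymmetric_qparams(percentile(acts, 0.1), percentile(acts, 99.9))  # clipped range
err_mm = mean_abs_error(acts, s_mm, z_mm)
err_pct = mean_abs_error(acts, s_pct, z_pct)

print(f"min-max:    S={s_mm:.4f}  Z={z_mm:3d}  mean |error|={err_mm:.4f}")
print(f"percentile: S={s_pct:.4f}  Z={z_pct:3d}  mean |error|={err_pct:.4f}")
```

The outlier itself is badly represented once it is clamped, but every other value gets a much finer scale, so the average error drops.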

In this notebook, we will apply PTQ and QAT to a MobileNetV2 model adapted for the CIFAR-10 dataset, using the `torch.quantization` default modules and methods (histogram for activations, min-max for weights). We will compare the accuracy of the quantized model with the original FP32 model and observe the impact of quantization on model size and inference speed on CPU.
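Since the paragraph above names a histogram-based method for activations, the sketch below contrasts two histogram-driven ways of picking a symmetric INT8 clipping threshold $T$ (so $S = T/127$) from a histogram of $|x|$. The first is a squared-error (L2) search, similar in spirit to PyTorch's `HistogramObserver` but not its actual algorithm. The second is the KL-divergence (entropy) search popularized by NVIDIA TensorRT. The bin count (2048), the 128 quantization levels, the candidate stride of 16 bins, and the `eps` smoothing are simplifying assumptions; the data are 10,000 synthetic standard-normal samples plus one outlier at 40.0.

```python
import math
import random


def abs_histogram(values, num_bins=2048):
    """Histogram of |x| over [0, max|x|]; returns the counts and the bin width."""
    max_abs = max(abs(v) for v in values)
    width = max_abs / num_bins
    counts = [0] * num_bins
    for v in values:
        counts[min(int(abs(v) / width), num_bins - 1)] += 1
    return counts, width


def l2_threshold(counts, width, bin_ends, q_max=127):
    """Clipping threshold minimizing squared quantize-dequantize error, estimated from bin centers."""
    def sq_error(i):
        scale = i * width / q_max
        err = 0.0
        for b, n in enumerate(counts):
            if n:
                c = (b + 0.5) * width               # the bin center stands in for its values
                q = min(q_max, round(c / scale))    # c >= 0, so only the upper clamp applies
                err += n * (c - scale * q) ** 2
        return err
    return min(bin_ends, key=sq_error) * width


def kl_threshold(counts, width, bin_ends, num_levels=128, eps=1e-10):
    """Clipping threshold whose quantized histogram Q is closest to the reference P in KL."""
    def kl(i):
        p = counts[:i]
        p[-1] += sum(counts[i:])                    # fold the clipped tail into the last kept bin
        q = [0.0] * i
        for j in range(num_levels):                 # merge the i bins into num_levels groups,
            start, stop = j * i // num_levels, (j + 1) * i // num_levels
            nonzero = [k for k in range(start, stop) if p[k] > 0]
            mass = sum(counts[start:stop])
            for k in nonzero:                       # then spread each group's mass back out
                q[k] = mass / len(nonzero)
        sp, sq = sum(p), sum(q)
        return sum((pk / sp) * math.log((pk / sp) / max(qk / sq, eps))
                   for pk, qk in zip(p, q) if pk > 0)
    return min(bin_ends, key=kl) * width


random.seed(0)
acts = [random.gauss(0.0, 1.0) for _ in range(10_000)] + [40.0]  # one extreme outlier
counts, width = abs_histogram(acts)
bin_ends = range(128, 2049, 16)  # candidate thresholds in bins, strided to keep this fast

t_l2 = l2_threshold(counts, width, bin_ends)
t_kl = kl_threshold(counts, width, bin_ends)
print(f"max |x| = {max(abs(v) for v in acts):.2f}, L2 threshold = {t_l2:.2f}, KL threshold = {t_kl:.2f}")
```

On data like this, the two criteria tend to disagree: clipping the single outlier costs more squared error than the extra rounding error on the remaining values, so the L2 search keeps most of the range, whereas the KL search, which compares distribution shapes weighted by probability mass, clips the outlier away.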

# Setup

In [1]:
from typing import Any

import torch

from matplotlib import pyplot as plt

from nnopt.model.eval import eval_model
from nnopt.model.quant import post_training_quantization, quantization_aware_training
from nnopt.model.const import DEVICE, DTYPE

from nnopt.recipes.mobilenetv2_cifar10 import get_mobilenetv2_cifar10_model, get_cifar10_datasets, save_mobilenetv2_cifar10_model

2025-06-11 11:22:49,097 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Using device: cuda, dtype: torch.bfloat16


In [2]:
# CIFAR-10 dataset
cifar10_train_dataset, cifar10_val_dataset, cifar10_test_dataset = get_cifar10_datasets()

2025-06-11 11:22:49,103 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading existing training and validation datasets...
2025-06-11 11:22:50,591 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading existing test dataset...


In [3]:
def ptq_mobilenetv2_cifar10(
    version: str = "mobilenetv2_cifar10/baseline",
    unstruct_prune: bool = False,
) -> torch.nn.Module:
    """
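    Post-training quantization of MobileNetV2 on CIFAR-10.

    `unstruct_prune` is currently unused: any unstructured sparsity config is
    passed on from the loaded model's metadata instead.
    """
    # Load the model to quantize (baseline or pruned, depending on `version`)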
    mobilenetv2_cifar10_quant, mobilenetv2_cifar10_metadata = get_mobilenetv2_cifar10_model(version=version, quantized=True)

    # Post-training quantization
    mobilenetv2_cifar10_ptq_int8 = post_training_quantization(
        model=mobilenetv2_cifar10_quant,
        val_dataset=cifar10_val_dataset,
        num_calibration_batches=10,
        batch_size=32,
    )

    # Evaluate the quantized model
    mobilenetv2_cifar10_ptq_int8_val_metrics = eval_model(
        model=mobilenetv2_cifar10_ptq_int8,
        test_dataset=cifar10_val_dataset,
        device="cpu",
        use_amp=False,
        dtype=torch.float32,
    )
    mobilenetv2_cifar10_ptq_int8_test_metrics = eval_model(
        model=mobilenetv2_cifar10_ptq_int8,
        test_dataset=cifar10_test_dataset,
        device="cpu",
        use_amp=False,
        dtype=torch.float32,
    )

    mobilenetv2_cifar10_val_accuracy_quantized = mobilenetv2_cifar10_ptq_int8_val_metrics["accuracy"]
    mobilenetv2_cifar10_test_accuracy_quantized = mobilenetv2_cifar10_ptq_int8_test_metrics["accuracy"]

    print(f"Validation accuracy of MobileNetV2 on CIFAR-10 (quantized): {mobilenetv2_cifar10_val_accuracy_quantized:.2f}")
    print(f"Test accuracy of MobileNetV2 on CIFAR-10 (quantized): {mobilenetv2_cifar10_test_accuracy_quantized:.2f}")

    # Save the quantized model
    save_mobilenetv2_cifar10_model(
        model=mobilenetv2_cifar10_ptq_int8,
        version=f"mobilenetv2_cifar10/quants/ptq/{version}",
        metrics_values=mobilenetv2_cifar10_ptq_int8_val_metrics,
        # The loaded metadata stores this under "unstructured_sparse_config"
        unstruct_sparse_config=mobilenetv2_cifar10_metadata.get("unstructured_sparse_config", None),
    )

    return mobilenetv2_cifar10_ptq_int8

In [4]:
def qat_mobilenetv2_cifar10(
    version: str = "mobilenetv2_cifar10/baseline",
) -> torch.nn.Module:
    """
    Quantization-aware training of MobileNetV2 on CIFAR-10.
    """
    # Load the model to quantize (baseline or pruned, depending on `version`)
    mobilenetv2_cifar10_quant, mobilenetv2_cifar10_metadata = get_mobilenetv2_cifar10_model(version=version, quantized=True)

    # Quantization-aware training
    mobilenetv2_cifar10_qat_int8 = quantization_aware_training(
        model=mobilenetv2_cifar10_quant,
        train_dataset=cifar10_train_dataset,
        val_dataset=cifar10_val_dataset,
        epochs=1,
        batch_size=64,
        training_device=DEVICE
    )

    # Evaluate the quantized model
    mobilenetv2_cifar10_qat_int8_val_metrics = eval_model(
        model=mobilenetv2_cifar10_qat_int8,
        test_dataset=cifar10_val_dataset,
        device="cpu",
        use_amp=False,
        dtype=torch.float32,
    )
    mobilenetv2_cifar10_qat_int8_test_metrics = eval_model(
        model=mobilenetv2_cifar10_qat_int8,
        test_dataset=cifar10_test_dataset,
        device="cpu",
        use_amp=False,
        dtype=torch.float32,
    )

    mobilenetv2_cifar10_val_accuracy_quantized = mobilenetv2_cifar10_qat_int8_val_metrics["accuracy"]
    mobilenetv2_cifar10_test_accuracy_quantized = mobilenetv2_cifar10_qat_int8_test_metrics["accuracy"]

    print(f"Validation accuracy of MobileNetV2 on CIFAR-10 (quantized): {mobilenetv2_cifar10_val_accuracy_quantized:.2f}")
    print(f"Test accuracy of MobileNetV2 on CIFAR-10 (quantized): {mobilenetv2_cifar10_test_accuracy_quantized:.2f}")

    # Save the quantized model
    save_mobilenetv2_cifar10_model(
        model=mobilenetv2_cifar10_qat_int8,
        version=f"mobilenetv2_cifar10/quants/qat/{version}",
        metrics_values=mobilenetv2_cifar10_qat_int8_val_metrics,
        # The loaded metadata stores this under "unstructured_sparse_config"
        unstruct_sparse_config=mobilenetv2_cifar10_metadata.get("unstructured_sparse_config", None),
    )
    return mobilenetv2_cifar10_qat_int8

# Post-Training Quantization (PTQ)

## Baseline Model

In [5]:
mobilenetv2_cifar10_ptq_int8 = ptq_mobilenetv2_cifar10(
    version="mobilenetv2_cifar10/baseline",
)

2025-06-11 11:08:31,637 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading MobileNetV2 model for CIFAR-10 from version: mobilenetv2_cifar10/baseline at /home/pbeuran/repos/nnopt/models
2025-06-11 11:08:31,638 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading base MobileNetV2 model, quantized: True
2025-06-11 11:08:31,929 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loaded metadata: {'metrics_values': {'val_metrics': {'accuracy': 0.9144, 'avg_loss': 0.2497599508523941, 'samples_per_second': 8953.201494033201, 'avg_time_per_batch': 0.007069106987337018, 'avg_time_per_sample': 0.00011169189039992489, 'n_params': 2236682, 'n_nonzero_params': 2236682}, 'test_metrics': {'accuracy': 0.9151, 'avg_loss': 0.2575160602211952, 'samples_per_second': 9123.233553351622, 'avg_time_per_batch': 0.006981545210198425, 'avg_time_per_sample': 0.00010961025980011527, 'n_params': 2236682, 'n_nonzero_params': 2236682}}}
2025-06-11 11:08:31,930 - nnopt.model.quant - INFO - Starting in-place post-training q

KeyboardInterrupt: 

## L1-unstructured Pruning 0.7

In [6]:
mobilenetv2_cifar10_ptq_int8_l1_unstruct_prune = ptq_mobilenetv2_cifar10(
    version="mobilenetv2_cifar10/unstruct_prune/l1_unstruct_prune_0.7",
    unstruct_prune=True
)

2025-06-11 11:08:35,855 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading MobileNetV2 model for CIFAR-10 from version: mobilenetv2_cifar10/unstruct_prune/l1_unstruct_prune_0.7 at /home/pbeuran/repos/nnopt/models
2025-06-11 11:08:35,856 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading base MobileNetV2 model, quantized: True
2025-06-11 11:08:35,961 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loaded metadata: {'unstructured_sparse_config': {'pruning_amount': 0.7}, 'metrics_values': {'val_metrics': {'accuracy': 0.8452, 'avg_loss': 0.44096147298812866, 'samples_per_second': 6141.987929246691, 'avg_time_per_batch': 0.010304666822793468, 'avg_time_per_sample': 0.0001628137358001368, 'n_params': 2236682, 'n_nonzero_params': 694890}, 'test_metrics': {'accuracy': 0.8401, 'avg_loss': 0.45495227823257445, 'samples_per_second': 6181.925517121272, 'avg_time_per_batch': 0.010303305554154263, 'avg_time_per_sample': 0.00016176189720022192, 'n_params': 2236682, 'n_nonzero_params': 694890}}}
2025

KeyboardInterrupt: 

# Quantization-Aware Training (QAT)

## Baseline Model

In [7]:
mobilenetv2_cifar10_qat_int8 = qat_mobilenetv2_cifar10(
    version="mobilenetv2_cifar10/baseline",
)

2025-06-11 11:08:39,057 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading MobileNetV2 model for CIFAR-10 from version: mobilenetv2_cifar10/baseline at /home/pbeuran/repos/nnopt/models
2025-06-11 11:08:39,057 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading base MobileNetV2 model, quantized: True
2025-06-11 11:08:39,166 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loaded metadata: {'metrics_values': {'val_metrics': {'accuracy': 0.9144, 'avg_loss': 0.2497599508523941, 'samples_per_second': 8953.201494033201, 'avg_time_per_batch': 0.007069106987337018, 'avg_time_per_sample': 0.00011169189039992489, 'n_params': 2236682, 'n_nonzero_params': 2236682}, 'test_metrics': {'accuracy': 0.9151, 'avg_loss': 0.2575160602211952, 'samples_per_second': 9123.233553351622, 'avg_time_per_batch': 0.006981545210198425, 'avg_time_per_sample': 0.00010961025980011527, 'n_params': 2236682, 'n_nonzero_params': 2236682}}}
2025-06-11 11:08:39,167 - nnopt.model.quant - INFO - Preparing model for quantization-

KeyboardInterrupt: 

## L1-unstructured Pruning 0.7

In [5]:
mobilenetv2_cifar10_qat_int8_l1_unstruct_prune = qat_mobilenetv2_cifar10(
    version="mobilenetv2_cifar10/unstruct_prune/l1_unstruct_prune_0.7",
)

2025-06-11 11:11:00,789 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading MobileNetV2 model for CIFAR-10 from version: mobilenetv2_cifar10/unstruct_prune/l1_unstruct_prune_0.7 at /home/pbeuran/repos/nnopt/models
2025-06-11 11:11:00,790 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loading base MobileNetV2 model, quantized: True
2025-06-11 11:11:01,042 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Loaded metadata: {'unstructured_sparse_config': {'pruning_amount': 0.7}, 'metrics_values': {'val_metrics': {'accuracy': 0.8452, 'avg_loss': 0.44096147298812866, 'samples_per_second': 6141.987929246691, 'avg_time_per_batch': 0.010304666822793468, 'avg_time_per_sample': 0.0001628137358001368, 'n_params': 2236682, 'n_nonzero_params': 694890}, 'test_metrics': {'accuracy': 0.8401, 'avg_loss': 0.45495227823257445, 'samples_per_second': 6181.925517121272, 'avg_time_per_batch': 0.010303305554154263, 'avg_time_per_sample': 0.00016176189720022192, 'n_params': 2236682, 'n_nonzero_params': 694890}}}
2025

Epoch 1/1, Train Loss: 1.2495, Train Acc: 0.5589, Train Throughput: 805.83 samples/s | Val Loss: 0.7120, Val Acc: 0.7572, Val Throughput: 1558.02 samples/s | CPU Usage: 10.20% | RAM Usage: 7.9/30.9GB (32.1%) | GPU 0 Util: 50.00% | GPU 0 Mem: 11.3/24.0GB (47.2%)


[Warmup]: 100%|██████████| 5/5 [00:01<00:00,  4.37it/s]
2025-06-11 11:12:07,674 - nnopt.model.eval - INFO - Warmup complete.
[Evaluation]: 100%|██████████| 79/79 [00:14<00:00,  5.45it/s, acc=0.7362, cpu=45.8%, loss=1.0766, ram=8.2/30.9GB (33.4%), samples/s=322.4]
2025-06-11 11:12:22,197 - nnopt.model.quant - INFO - Quantization-aware training metrics: {'accuracy': 0.7362, 'avg_loss': 0.7653968486785888, 'samples_per_second': 372.2142159552869, 'avg_time_per_batch': 0.17003955391136732, 'avg_time_per_sample': 0.002686624951799604, 'params_stats': {'int_weight_params': 2202560, 'float_weight_params': 0, 'float_bias_params': 10, 'bn_param_params': 34112, 'other_float_params': 0, 'total_params': 2236682, 'approx_memory_mb_for_params': 2.2306900024414062}}
2025-06-11 11:12:22,198 - nnopt.model.eval - INFO - Starting evaluation on device: cpu, dtype: torch.float32, batch size: 32
2025-06-11 11:12:22,201 - nnopt.model.eval - INFO - Starting warmup for 5 batches...


Evaluation Complete: Avg Loss: 0.7654, Accuracy: 0.7362
Throughput: 372.21 samples/sec | Avg Batch Time: 170.04 ms | Avg Sample Time: 2.69 ms
System Stats: CPU Usage: 13.00% | RAM Usage: 7.9/30.9GB (32.5%)


[Warmup]: 100%|██████████| 5/5 [00:00<00:00,  7.55it/s]
2025-06-11 11:12:22,945 - nnopt.model.eval - INFO - Warmup complete.
[Evaluation]: 100%|██████████| 157/157 [00:11<00:00, 13.40it/s, acc=0.7362, cpu=51.0%, loss=1.0766, ram=8.3/30.9GB (34.6%), samples/s=559.8]
2025-06-11 11:12:34,674 - nnopt.model.eval - INFO - Starting evaluation on device: cpu, dtype: torch.float32, batch size: 32
2025-06-11 11:12:34,678 - nnopt.model.eval - INFO - Starting warmup for 5 batches...


Evaluation Complete: Avg Loss: 0.7654, Accuracy: 0.7362
Throughput: 446.13 samples/sec | Avg Batch Time: 71.38 ms | Avg Sample Time: 2.24 ms
System Stats: CPU Usage: 11.50% | RAM Usage: 8.1/30.9GB (34.0%)


[Warmup]: 100%|██████████| 5/5 [00:00<00:00,  9.48it/s]
2025-06-11 11:12:35,287 - nnopt.model.eval - INFO - Warmup complete.
[Evaluation]: 100%|██████████| 313/313 [00:23<00:00, 13.09it/s, acc=0.7291, cpu=44.1%, loss=0.5235, ram=8.2/30.9GB (34.6%), samples/s=457.2]
2025-06-11 11:12:59,233 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Metadata saved to /home/pbeuran/repos/nnopt/models/mobilenetv2_cifar10/quants/qat/mobilenetv2_cifar10/unstruct_prune/l1_unstruct_prune_0.7/metadata.json
2025-06-11 11:12:59,234 - nnopt.recipes.mobilenetv2_cifar10 - INFO - Model saved to /home/pbeuran/repos/nnopt/models/mobilenetv2_cifar10/quants/qat/mobilenetv2_cifar10/unstruct_prune/l1_unstruct_prune_0.7/model.pt


Evaluation Complete: Avg Loss: 0.7975, Accuracy: 0.7291
Throughput: 445.00 samples/sec | Avg Batch Time: 71.80 ms | Avg Sample Time: 2.25 ms
System Stats: CPU Usage: 12.00% | RAM Usage: 8.0/30.9GB (34.0%)
Validation accuracy of MobileNetV2 on CIFAR-10 (quantized): 0.74
Test accuracy of MobileNetV2 on CIFAR-10 (quantized): 0.73


In [8]:
for _, params in mobilenetv2_cifar10_qat_int8_l1_unstruct_prune.named_parameters():
    print(params.numel() - params.count_nonzero())

tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
tensor(0)
