## Automatic quantization

This notebook demonstrates a simple end-to-end pipeline for MobileNetV2 quantization.

Our quantization process consists of calibrating the quantized model, adjusting quantization thresholds, and fine-tuning weights using distillation. Finally, we demonstrate inference of the quantized model with the ONNX Runtime framework.
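
A quantized model stores values on an integer grid defined by a threshold (the clipping range). The sketch below shows symmetric per-tensor int8 "fake" quantization (quantize, then immediately dequantize) in plain Python. It assumes a symmetric scheme with zero point 0 and integer range [-127, 127], similar to what TensorRT uses for int8; `fake_quantize` is a hypothetical helper, not part of the ENOT API.

```python
def fake_quantize(x: float, threshold: float, num_bits: int = 8) -> float:
    """Quantize x onto a symmetric signed-integer grid, then dequantize it."""
    qmax = 2 ** (num_bits - 1) - 1    # 127 for 8 bits
    scale = threshold / qmax           # float value of one integer step
    q = round(x / scale)               # nearest integer level (Python rounds ties to even)
    q = max(-qmax, min(qmax, q))       # clip to [-qmax, qmax]; values beyond threshold saturate
    return q * scale                   # back to float, carrying the rounding/clipping error


for value in [0.0, 0.5, -1.3, 2.0, 10.0]:
    print(value, '->', fake_quantize(value, threshold=2.0))
```

Values inside the threshold are off by at most half a step (`threshold / 254` here), while `10.0` saturates to the threshold. Roughly speaking, choosing the threshold trades rounding error against clipping error.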

### Main chapters of this notebook:
1. Set up the environment
1. Prepare dataset and create dataloaders
1. Evaluate pretrained MobileNetV2 from torchvision
1. End-to-end quantization with our framework
1. Inference using ONNX Runtime with TensorRT Execution Provider

Before running this example, make sure that TensorRT supports int8 inference on your GPU (CUDA compute capability 6.1 or higher, as described [here](https://docs.nvidia.com/deeplearning/tensorrt/support-matrix/index.html#hardware-precision-matrix)).
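
The GPU check reduces to comparing the device's `(major, minor)` compute capability against a minimum. A minimal sketch, assuming 6.1 as the lowest qualifying version (verify this against the support matrix for your TensorRT release, since newer releases drop older architectures); `supports_int8` is a hypothetical helper.

```python
from typing import Tuple

# Assumption: int8 needs compute capability 6.1 or newer; confirm in the
# TensorRT support matrix for the TensorRT version you actually install.
MIN_INT8_CAPABILITY: Tuple[int, int] = (6, 1)


def supports_int8(capability: Tuple[int, int]) -> bool:
    """Compare a (major, minor) compute capability against the minimum."""
    # Tuples compare element by element, so (7, 0) > (6, 1) and (6, 0) < (6, 1).
    return tuple(capability) >= MIN_INT8_CAPABILITY


for cap in [(5, 2), (6, 0), (6, 1), (7, 5), (8, 6)]:
    print(cap, supports_int8(cap))
```

In the notebook, `supports_int8(torch.cuda.get_device_capability())` would check the current device, since `torch.cuda.get_device_capability` returns a `(major, minor)` tuple.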

## Set up the environment

First, let's set up the environment and make some common imports.

In [None]:
!pip install -r requirements.txt

In [None]:
import os

os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# You may need to uncomment this line and set it to the index of a free GPU
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
# Common:
import itertools
import numpy as np
import torch
from pathlib import Path
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim import RAdam
from tqdm.auto import tqdm
from tutorial_utils.dataset import create_imagenet10k_dataloaders
from tutorial_utils.train import accuracy

# Quantization:
from enot.quantization import RMSELoss
from enot.quantization import TensorRTFakeQuantizedModel
from enot.quantization import calibrate
from enot.quantization import distill

# ONNX Runtime inference:
from tutorial_utils.inference import create_onnxruntime_session
import onnxsim

Define a model evaluation function:

In [None]:
# This function can evaluate both nn.Module instances and plain callables.
def eval_model(model_fn, dataloader):
    if isinstance(model_fn, nn.Module):
        model_fn.eval()

    total = 0
    total_loss = 0.0
    total_correct = 0.0

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader):
            n = inputs.shape[0]

            pred_labels = model_fn(inputs)
            batch_loss = criterion(pred_labels, labels)
            batch_accuracy = accuracy(pred_labels, labels)

            total += n
            total_loss += batch_loss.item() * n
            total_correct += batch_accuracy.item() * n

    return total_loss / total, total_correct / total
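
`accuracy` is imported from `tutorial_utils.train` and is not shown here. Since `eval_model` multiplies its result by the batch size, it presumably returns the batch's top-1 accuracy as a fraction. The plain-Python sketch below (hypothetical helper names) shows that computation and why each batch mean is weighted by `n`: if the last batch is smaller, a plain average of batch means would give it too much weight.

```python
from typing import Sequence


def top1_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    correct = sum(
        max(range(len(row)), key=lambda i: row[i]) == label
        for row, label in zip(logits, labels)
    )
    return correct / len(labels)


def weighted_mean(batch_means: Sequence[float], batch_sizes: Sequence[int]) -> float:
    """Turn per-batch means into a per-sample mean, as eval_model does."""
    return sum(m * n for m, n in zip(batch_means, batch_sizes)) / sum(batch_sizes)


print(top1_accuracy([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1]], [1, 2]))  # one of two correct
print(weighted_mean([0.8, 0.6, 1.0], [25, 25, 10]))                # 45 correct / 60 samples
print(sum([0.8, 0.6, 1.0]) / 3)                                    # naive mean overweights the small batch
```

With batch sizes 25, 25 and 10, the weighted mean is 0.75 while the naive mean is 0.8. In this notebook all batches have 25 images, so both agree, but the weighting keeps `eval_model` correct for any batch size.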

### In the following cell we set up all necessary directories

* `HOME_DIR` - experiments home directory
* `DATASETS_DIR` - root directory for datasets (imagenette2, ...)
* `PROJECT_DIR` - project directory to save training logs, checkpoints, ...
* `ONNX_MODEL_PATH` - path for the exported ONNX model

In [None]:
HOME_DIR = Path.home() / '.optimization_experiments'
DATASETS_DIR = HOME_DIR / 'datasets'
PROJECT_DIR = HOME_DIR / 'enot-lite_quantization'
ONNX_MODEL_PATH = PROJECT_DIR / 'mobilenetv2.onnx'

HOME_DIR.mkdir(exist_ok=True)
DATASETS_DIR.mkdir(exist_ok=True)
PROJECT_DIR.mkdir(exist_ok=True)

## Prepare dataset and create dataloaders

We will use the Imagenet-10k dataset in this example.

Imagenet-10k is a subsample of the [Imagenet](https://image-net.org/challenges/LSVRC/index.php) dataset. It contains 5000 training images and 5000 validation images. Training images are sampled uniformly from the original training set, and validation images are taken from the original validation set, 5 per class.
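
For the validation split, 5000 images correspond to 1000 ImageNet classes × 5 images per class. Class-balanced subsampling of this kind can be sketched as below; `sample_per_class` is a hypothetical helper and not necessarily how the Imagenet-10k archive was built.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def sample_per_class(
    items: Sequence[Tuple[str, int]], per_class: int, seed: int = 0
) -> List[Tuple[str, int]]:
    """Pick `per_class` random (path, label) pairs from every class."""
    by_class: Dict[int, List[Tuple[str, int]]] = defaultdict(list)
    for path, label in items:
        by_class[label].append((path, label))

    rng = random.Random(seed)  # fixed seed -> reproducible subset
    subset: List[Tuple[str, int]] = []
    for label in sorted(by_class):
        subset.extend(rng.sample(by_class[label], per_class))
    return subset


# Toy example: 3 classes with 10 images each -> 15 sampled images.
toy_items = [(f'img_{c}_{i}.jpg', c) for c in range(3) for i in range(10)]
print(len(sample_per_class(toy_items, per_class=5)))
```

With 1000 classes and `per_class=5`, the same function would return 5000 items.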

The `create_imagenet10k_dataloaders` function prepares the dataset for you; specifically, it:
1. downloads and unpacks the dataset into `DATASETS_DIR`;
1. creates and returns train and validation dataloaders.

The dataset consists of two parts:
* train: for quantization procedure (`DATASETS_DIR`/imagenet10k/train/)
* validation: for model validation (`DATASETS_DIR`/imagenet10k/val/)

In [None]:
train_dataloader, validation_dataloader = create_imagenet10k_dataloaders(
    dataset_root_dir=DATASETS_DIR,
    input_size=224,
    batch_size=25,
    num_workers=4,
)

## Evaluate pretrained MobileNetV2 from torchvision

In [None]:
from torchvision.models.mobilenetv2 import mobilenet_v2

regular_model = mobilenet_v2(pretrained=True).cuda()

# Turn off the dropout in front of the fully connected layer.
# This makes the model deterministic and stabilizes the fine-tuning procedure.
regular_model.classifier[0].p = 0.0
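
Setting `p = 0.0` turns `nn.Dropout` into an identity. In training mode, dropout zeroes each element with probability `p` and scales the survivors by `1 / (1 - p)`; in eval mode it does nothing. A plain-Python sketch of that behavior (the `dropout` helper is hypothetical):

```python
import random
from typing import List


def dropout(x: List[float], p: float, training: bool, rng: random.Random) -> List[float]:
    """Inverted dropout: zero each element with probability p, rescale the rest by 1/(1-p)."""
    if not training or p == 0.0:
        return list(x)  # identity: no randomness at all
    keep = 1.0 - p
    return [v / keep if rng.random() < keep else 0.0 for v in x]


rng = random.Random(0)
print(dropout([1.0, 2.0, 3.0], p=0.2, training=True, rng=rng))  # random zeros, rest scaled by 1.25
print(dropout([1.0, 2.0, 3.0], p=0.0, training=True, rng=rng))  # unchanged
```

If the student runs in training mode during fine-tuning, a nonzero `p` would make its outputs, and therefore the distillation loss, noisy from step to step; with `p = 0.0` they are deterministic.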

In [None]:
val_loss, val_accuracy = eval_model(regular_model, validation_dataloader)
print(f'Regular (non-quantized) model: accuracy={val_accuracy:.3f}, loss={val_loss:.3f}')

## End-to-end quantization

- wrap `regular_model` in `TensorRTFakeQuantizedModel`
- calibrate quantization thresholds using the `calibrate` context manager
- fine-tune quantization thresholds and scale factors by distillation using the `distill` context manager
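
During distillation, the fake-quantized model (student) is trained to reproduce the outputs of a reference model (teacher), presumably the full-precision one. Judging by the training loop below, `qdistill_model(batch)` yields (student, teacher) output pairs and the per-pair `RMSELoss` values are summed. Assuming `RMSELoss` is the standard root-mean-square error, the objective can be sketched on flat lists of floats as follows (hypothetical helpers, not ENOT code):

```python
import math
from typing import Iterable, Sequence, Tuple


def rmse(student: Sequence[float], teacher: Sequence[float]) -> float:
    """Root-mean-square error between two equally sized flat outputs."""
    if len(student) != len(teacher):
        raise ValueError('outputs must have the same size')
    mse = sum((s - t) ** 2 for s, t in zip(student, teacher)) / len(student)
    return math.sqrt(mse)


def distillation_loss(pairs: Iterable[Tuple[Sequence[float], Sequence[float]]]) -> float:
    """Sum the RMSE over all (student, teacher) pairs, like the training loop."""
    return sum(rmse(s, t) for s, t in pairs)


print(rmse([0.0, 0.0], [3.0, 4.0]))                           # sqrt(12.5)
print(distillation_loss([([0.0], [1.0]), ([2.0], [0.0])]))    # 1.0 + 2.0
```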

In [None]:
fake_quantized_model = TensorRTFakeQuantizedModel(regular_model).cuda()

In [None]:
# Calibrate quantization thresholds using 10 batches.
with torch.no_grad(), calibrate(fake_quantized_model):
    for batch in itertools.islice(train_dataloader, 10):
        batch = batch[0].cuda()
        fake_quantized_model(batch)
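
This notebook does not state which statistic ENOT's `calibrate` collects. A common baseline is max calibration: record the largest absolute activation seen over the calibration batches and use it as the threshold, so that `scale = threshold / 127` for int8. The plain-Python sketch below illustrates that baseline under this assumption; `max_calibrate` is a hypothetical helper.

```python
import itertools
from typing import Iterable, Sequence, Tuple


def max_calibrate(
    batches: Iterable[Sequence[float]], num_batches: int = 10, num_bits: int = 8
) -> Tuple[float, float]:
    """Return (threshold, scale) from the largest |value| seen in the first batches."""
    threshold = 0.0
    for batch in itertools.islice(batches, num_batches):  # same 10-batch budget as above
        threshold = max(threshold, max(abs(v) for v in batch))
    qmax = 2 ** (num_bits - 1) - 1  # 127 for int8
    return threshold, threshold / qmax


activations = [[0.5, -1.0], [2.54, 0.1], [-0.3]]
print(max_calibrate(activations))                 # threshold 2.54, scale ~0.02
print(max_calibrate(activations, num_batches=1))  # only the first batch is seen
```

Because a single outlier sets the threshold under max calibration, other statistics (percentiles, entropy- or MSE-based search) are also widely used.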

In [None]:
# Distill quantization thresholds and scale-factors using RMSE loss for 5 epochs.
n_epochs = 5

with distill(fq_model=fake_quantized_model, tune_weight_scale_factors=True) as (qdistill_model, params):
    optimizer = RAdam(params=params, lr=0.005, betas=(0.9, 0.95))
    scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=len(train_dataloader) * n_epochs)
    distillation_criterion = RMSELoss()

    for _ in range(n_epochs):
        for batch in (tqdm_it := tqdm(train_dataloader)):
            batch = batch[0].cuda()

            optimizer.zero_grad()
            loss: torch.Tensor = torch.tensor(0.0).cuda()
            for student_output, teacher_output in qdistill_model(batch):
                loss += distillation_criterion(student_output, teacher_output)

            loss.backward()
            optimizer.step()
            scheduler.step()

            tqdm_it.set_description(f'loss: {loss.item():.3f}')

In [None]:
fake_quantized_model.enable_quantization_mode(True)
val_loss, val_accuracy = eval_model(fake_quantized_model, validation_dataloader)
print(f'Optimized quantized model: accuracy={val_accuracy:.3f}, loss={val_loss:.3f}')

## Inference using ONNX Runtime with TensorRT Execution Provider

In [None]:
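torch.onnx.export(
    model=fake_quantized_model.cpu(),
    args=torch.zeros(25, 3, 224, 224),
    f=str(ONNX_MODEL_PATH),
    opset_version=13,
    input_names=['input'],
    output_names=['output'],
)

# Simplify the exported graph; `check` reports whether the simplified model was validated.
proto, check = onnxsim.simplify(str(ONNX_MODEL_PATH))
assert check, 'Simplified ONNX model could not be validated'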

Initialize an ONNX Runtime inference session with the TensorRT Execution Provider:

In [None]:
torch.cuda.empty_cache()  # Empty PyTorch CUDA cache before running ONNX Runtime.

sess = create_onnxruntime_session(
    proto=proto,
    input_sample=torch.zeros(25, 3, 224, 224, device='cuda'),
    output_shape=(25, 1000),
)

Evaluate the quantized model with the TensorRT Execution Provider:

In [None]:
def model_fn(inputs):
    return sess(inputs)


val_loss, val_accuracy = eval_model(model_fn, validation_dataloader)
print(f'Quantized model with fine-tuned weights with TRT: accuracy={val_accuracy:.3f}, loss={val_loss:.3f}')