![logo](../../picture/license_header_logo.png)
**Copyright (c) 2020-2021 CertifAI Sdn. Bhd.**

This program is part of OSRFramework. You can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.

You should have received a copy of the GNU Affero General Public License
along with this program. If not, see http://www.gnu.org/licenses/.

Authored by: [Jacklyn Lim](mailto:jacklyn.lim@certifai.ai)

### Motivation for Graph Mode Quantization

By default, PyTorch uses eager mode computation, so when we quantize our model we also use an eager mode approach. However, eager mode quantization is highly manual, especially in the following aspects:
- Explicitly adding `QuantStub` and `DeQuantStub` when defining the model architecture
- Explicitly specifying the layers to be fused during layer fusion
- Lack of support for `torch.nn.functional` (e.g. `functional.conv2d` and `functional.linear` are not quantized)
- Replacing operations that require special handling, such as `torch.add` and `torch.cat`, with `nn.quantized.FloatFunctional`

This is where TorchScript comes to the rescue. TorchScript records the model's definition in an Intermediate Representation (IR), commonly referred to in deep learning as a graph. PyTorch also provides quantization methods for TorchScript models.

In graph mode, quantization is achieved through module and graph manipulations. The tooling can automatically work out which modules to fuse and where to insert observer calls and quantize/dequantize functions, so the whole quantization process can be automated.

Advantages of Graph Mode Quantization are:

- Simple quantization flow, minimal manual steps
- Unlocks the possibility of doing higher level optimizations like automatic precision selection

Note:
Graph Mode Quantization is still a very new feature (available since PyTorch 1.8.0), so expect changes in how it is implemented.

### Import Libraries

```python
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import download_model, download_dataset, load_model_state_dict, load_dataset, load_image, compare_performance

# for Graph Mode Quantization
from torch.quantization import quantize_dynamic_jit, per_channel_dynamic_qconfig, get_default_qconfig, quantize_jit
from torch.quantization.quantize_fx import prepare_fx, convert_fx
```

### Download Model and Dataset

```python
# model download
MODEL_DOWNLOAD_PATH = 'https://s3.eu-central-1.wasabisys.com/certifai/deployment-training-labs/models/fruit_classifier_state_dict.pt'
MODEL_STATE_DICT_PATH = '../../resources/model/'
MODEL_FILENAME = 'fruits_image_classification.zip'

# data download
DATA_DOWNLOAD_PATH = "https://s3.eu-central-1.wasabisys.com/certifai/deployment-training-labs/fruits_image_classification-20210604T123547Z-001.zip"
DATA_SAVE_PATH = "../../resources/data/"
DATA_ZIP_FILENAME = "fruits_image_classification.zip"

# download model
download_model(MODEL_DOWNLOAD_PATH, MODEL_STATE_DICT_PATH, MODEL_FILENAME)

# download dataset
download_dataset(DATA_DOWNLOAD_PATH, DATA_SAVE_PATH, DATA_ZIP_FILENAME)
```

### Load Original Model

```python
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Note: the input size of fc1 depends on the input image size
        self.fc1 = nn.Linear(18496, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 3)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```
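The `18496` in `fc1` follows from the layer arithmetic. The helper below is a sketch (`fc1_in_features` and its defaults are illustrative, not part of the lab's `utils`) that reproduces it: a square 148×148 input gives 16 × 34 × 34 = 18496 features. The actual image size is set by the preprocessing inside `load_image`/`load_dataset`, which is not shown here, so 148×148 is inferred from the arithmetic rather than stated by the source.

```python
# Sketch: number of features reaching fc1 for a square input image.
# Assumes Net's settings: 5x5 convolutions (stride 1, no padding),
# each followed by 2x2 max pooling with stride 2.


def conv_out(size, kernel=5, stride=1, padding=0):
    """Spatial output size of nn.Conv2d along one dimension (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Spatial output size of nn.MaxPool2d along one dimension (no padding)."""
    return (size - kernel) // stride + 1


def fc1_in_features(image_size, out_channels=16):
    s = pool_out(conv_out(image_size))  # conv1 -> pool
    s = pool_out(conv_out(s))           # conv2 -> pool
    return out_channels * s * s         # what torch.flatten(x, 1) produces per sample


print(fc1_in_features(148))
```

If you change the input resolution, recompute this value and update `nn.Linear(18496, 120)` accordingly.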

```python
# load original model
model_fp32 = Net()
model_fp32 = load_model_state_dict(model_fp32, MODEL_STATE_DICT_PATH + MODEL_FILENAME)
model_fp32.eval()

# Print original model
print("\033[1mFP32 Model: \033[0m")
print(model_fp32)
print("\n")
```

### Dynamic Quantization

Reference: https://pytorch.org/tutorials/prototype/graph_mode_dynamic_bert_tutorial.html#quantizing-bert-model-with-graph-mode-quantization

#### Convert the Original Model to a TorchScript Model

We have to convert the original PyTorch model to a TorchScript model because the JIT-based Graph Mode Quantization API used here (`quantize_dynamic_jit`) takes a TorchScript model as input.

```python
# script model (convert to torchscript model)
script_model = torch.jit.script(model_fp32).eval()
```

#### Quantize Model

We use per-channel quantization, which quantizes each output channel of a weight with its own scale and usually improves the final accuracy.
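Why per-channel helps can be seen on a tiny made-up weight matrix. The sketch below assumes symmetric int8 quantization with round-to-nearest (Python's `round`), a simplification of the per-channel weight observer PyTorch uses; the helper names are illustrative. With a single per-tensor scale, the channel with small weights is crushed onto a handful of integer levels, while a per-channel scale keeps its round-trip error within half of its own, much smaller, scale.

```python
# Symmetric int8: q = clamp(round(w / scale), -127, 127), scale = max|w| / 127.


def symmetric_scale(values, qmax=127):
    """Scale that maps the largest magnitude in `values` onto +/-qmax."""
    return max(abs(v) for v in values) / qmax


def fake_quantize(values, scale, qmax=127):
    """Quantize to symmetric int8 codes and immediately dequantize back to floats."""
    return [max(-qmax, min(qmax, round(v / scale))) * scale for v in values]


def max_abs_error(values, scale):
    """Largest absolute round-trip error over `values` for a given scale."""
    return max(abs(v - r) for v, r in zip(values, fake_quantize(values, scale)))


# Made-up 2x4 weight matrix: output channel 0 has a far smaller range than channel 1.
toy_weights = [
    [0.010, -0.020, 0.015, 0.004],
    [1.500, -2.000, 0.700, 1.200],
]

# Per-tensor: a single scale shared by all output channels.
tensor_scale = symmetric_scale([w for row in toy_weights for w in row])
per_tensor_errors = [max_abs_error(row, tensor_scale) for row in toy_weights]

# Per-channel: one scale per output channel (row of the weight matrix).
per_channel_errors = [max_abs_error(row, symmetric_scale(row)) for row in toy_weights]

for ch, (t_err, c_err) in enumerate(zip(per_tensor_errors, per_channel_errors)):
    print(f"channel {ch}: per-tensor max error {t_err:.6f}, per-channel max error {c_err:.6f}")
```

The large-range channel is unaffected (its scale is the same either way); the gain comes entirely from the channels whose range is much smaller than the tensor-wide maximum.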

```python
"""
Task: Perform dynamic quantization using the quantize_dynamic_jit API
"""
############ Enter your code here ############

############ Enter your code here ############
```

#### Compare Model Performance

```python
INFERENCE_IMAGE_PATH = "../../resources/data/fruits_image_classification/test/apple/image1.jpg"
TEST_DATASET_ROOTDIR = "../../resources/data/fruits_image_classification/test"

# load image
inference_image = load_image(INFERENCE_IMAGE_PATH)

# load test dataset
test_dataloader = load_dataset(TEST_DATASET_ROOTDIR)

compare_performance(model_fp32, quantized_model, "model_fp32", "quantized_model", inference_image, test_dataloader)
```

### Static Quantization

#### Convert the Original Model to a TorchScript Model

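As in the dynamic quantization section, we script the model first. Note, however, that the FX workflow used below (`prepare_fx`/`convert_fx`) takes the original eager-mode `nn.Module` (`model_fp32`) and traces it with `torch.fx` itself, so `script_model` is not passed to it; the TorchScript-based counterpart for static quantization is `quantize_jit`, which is imported above.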

```python
# script model (convert to torchscript model)
script_model = torch.jit.script(model_fp32).eval()
```

#### Specify How to Quantize the Model With `qconfig_dict` (Including Specifying the Backend) 

Please refer to [Pytorch guide on setting qconfig](https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html#specify-how-to-quantize-the-model-with-qconfig-dict).

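```python
def set_qconfig_dict(backend, config):
    """
    Build a qconfig_dict for FX Graph Mode Quantization.

    Parameters:
    backend (str): quantized backend, e.g. "fbgemm" (x86) or "qnnpack" (ARM)
    config (str): qconfig_dict key; "" applies the qconfig globally

    Returns:
    dict: qconfig_dict mapping `config` to the default qconfig for `backend`
    """
    # Note: this does not set torch.backends.quantized.engine; if your build
    # supports it, set the engine to the same backend for inference.
    qconfig = get_default_qconfig(backend)
    qconfig_dict = {config: qconfig}
    return qconfig_dict
```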

```python
# specify how to quantize the model with qconfig_dict (including specifying the backend)
qconfig_dict = set_qconfig_dict("qnnpack", "")
```
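In `qconfig_dict`, the empty-string key `""` sets a global qconfig that applies to every module, which is what `set_qconfig_dict("qnnpack", "")` produces. The FX API also accepts more specific keys such as `"object_type"` and `"module_name"` that override the global entry, and a qconfig of `None` leaves a module in float. The resolver below is a hypothetical, simplified model of that precedence (module name over object type over global): strings stand in for module types and `QConfig` objects, and `toy_qconfig_dict` is named so it does not overwrite the real `qconfig_dict`. See the PyTorch guide linked above for the full set of keys and exact rules.

```python
# Hypothetical, simplified model of how a qconfig_dict entry is chosen per module.
# Strings stand in for module types and QConfig objects; None means "keep in float".
toy_qconfig_dict = {
    "": "global_qconfig",                           # default for every module
    "object_type": [("Linear", "linear_qconfig")],  # override by module type
    "module_name": [("fc3", None)],                 # override one module by name
}


def resolve_qconfig(qconfig_dict, module_name, module_type):
    """Return the qconfig for one module: module_name > object_type > global."""
    for name, qconfig in qconfig_dict.get("module_name", []):
        if name == module_name:
            return qconfig
    for obj_type, qconfig in qconfig_dict.get("object_type", []):
        if obj_type == module_type:
            return qconfig
    return qconfig_dict.get("")


for name, mtype in [("conv1", "Conv2d"), ("fc1", "Linear"), ("fc3", "Linear")]:
    print(name, "->", resolve_qconfig(toy_qconfig_dict, name, mtype))
```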

#### Calibration

```python
def calibration(model, dataloader, device="cpu"):
    """Run calibration data through a prepared model so its observers record
    activation ranges. Returns the calibrated model."""

    pass
    ############ Enter your code here ############

    ############ Enter your code here ############
```

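For Graph Mode Quantization, instead of the `torch.quantization.prepare` function we use the `prepare_fx` function from the `torch.quantization.quantize_fx` module. It fuses modules and inserts observers automatically, so no `QuantStub`/`DeQuantStub` or explicit fusion list is needed.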

```python
CALIBRATION_DATASET_ROOTDIR = "../../resources/data/fruits_image_classification/train"

# load calibration dataset
calibration_dataloader = load_dataset(CALIBRATION_DATASET_ROOTDIR)

# prepare model for calibration - including fuse modules and insert observers
prepared_model = prepare_fx(model_fp32, qconfig_dict)

# calibrate model
calibrated_model = calibration(prepared_model, calibration_dataloader)
```
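Calibration simply runs representative data through the prepared model so that each observer inserted by `prepare_fx` can record the range of values it sees. The pure-Python sketch below shows the min/max idea: a hypothetical `ToyMinMaxObserver` (not PyTorch's class) tracks the running range over several made-up batches and turns it into an affine `quint8` scale and zero point. The observer actually chosen by `get_default_qconfig` may be more sophisticated than plain min/max.

```python
class ToyMinMaxObserver:
    """Tracks the running min/max of observed values (toy, not PyTorch's class)."""

    def __init__(self, qmin=0, qmax=255):
        self.qmin, self.qmax = qmin, qmax
        self.min_val = float("inf")
        self.max_val = float("-inf")

    def observe(self, batch):
        """Called once per calibration batch; only the running range is kept."""
        self.min_val = min(self.min_val, min(batch))
        self.max_val = max(self.max_val, max(batch))

    def calculate_qparams(self):
        """Turn the observed range into affine quint8 (scale, zero_point)."""
        # Widen the range to include 0 so that 0.0 is exactly representable.
        lo = min(self.min_val, 0.0)
        hi = max(self.max_val, 0.0)
        scale = (hi - lo) / (self.qmax - self.qmin)
        if scale == 0.0:  # every observed value was 0
            scale = 1.0
        zero_point = self.qmin - round(lo / scale)
        return scale, max(self.qmin, min(self.qmax, zero_point))


toy_observer = ToyMinMaxObserver()
# Made-up activation values from three calibration batches.
for batch in ([-0.5, 0.2, 1.0], [0.3, 2.5, -1.0], [0.0, 1.7]):
    toy_observer.observe(batch)

toy_scale, toy_zero_point = toy_observer.calculate_qparams()
print(f"observed range [{toy_observer.min_val}, {toy_observer.max_val}]")
print(f"scale={toy_scale:.6f}, zero_point={toy_zero_point}")
```

This is why the calibration data should resemble real inputs: the observed ranges fix the scales used after conversion.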

#### Convert Calibrated Model to a Quantized Model

Hint: For Graph Mode Quantization, instead of the `torch.quantization.convert` function we use the `convert_fx` function from the `torch.quantization.quantize_fx` module.
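Conceptually, `convert_fx` swaps the observed modules for quantized counterparts that work on 8-bit integers, with quantize/dequantize steps at the float boundaries, using the scale and zero point each observer computed. The affine mapping involved is q = clamp(round(x / scale) + zero_point, 0, 255) and x ≈ (q − zero_point) × scale. The stdlib-only sketch below shows that round trip; the example parameters correspond to a calibrated range of [−1.0, 2.5], and the function names are illustrative rather than PyTorch APIs.

```python
# Affine (asymmetric) quint8 quantization with example parameters:
# the range [-1.0, 2.5] spread over the 256 levels of quint8.
EXAMPLE_SCALE = 3.5 / 255
EXAMPLE_ZERO_POINT = 73


def quantize(x, scale, zero_point, qmin=0, qmax=255):
    """float -> uint8 code (values outside the range saturate)."""
    return max(qmin, min(qmax, round(x / scale) + zero_point))


def dequantize(q, scale, zero_point):
    """uint8 code -> approximate float."""
    return (q - zero_point) * scale


for x in (-1.0, 0.0, 0.42, 2.5, 3.0):  # 3.0 lies outside the calibrated range
    q = quantize(x, EXAMPLE_SCALE, EXAMPLE_ZERO_POINT)
    x_hat = dequantize(q, EXAMPLE_SCALE, EXAMPLE_ZERO_POINT)
    print(f"{x:5.2f} -> q={q:3d} -> {x_hat:.4f}")
```

Inside the calibrated range the error is at most half a quantization step; values outside it are clipped, which is one source of accuracy loss after static quantization.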

```python
def convert_to_quantized_model(calibrated_model):
    """ Returns a quantized int8 model """
    pass
    ############ Enter your code here ############

    ############ Enter your code here ############
```

```python
# Convert to a quantized model
model_int8 = convert_to_quantized_model(calibrated_model)
model_int8.eval()

# Print quantized model
print("\033[1mINT8 Model: \033[0m")
print(model_int8)
print("\n")
```

#### Compare Model Performance

```python
INFERENCE_IMAGE_PATH = "../../resources/data/fruits_image_classification/test/apple/image1.jpg"
TEST_DATASET_ROOTDIR = "../../resources/data/fruits_image_classification/test"

# COMPARING PERFORMANCE
print("\033[1mCOMPARING PERFORMANCE... \033[0m")

# load image
inference_image = load_image(INFERENCE_IMAGE_PATH)

# load test dataset
test_dataloader = load_dataset(TEST_DATASET_ROOTDIR)

compare_performance(model_fp32, model_int8, "model_fp32", "quantized_model", inference_image, test_dataloader)
```

#### Save Torchscript Model

```python
def save_torchscript_model(model, model_dir, model_filename):
    """Script `model` and save it to model_dir/model_filename."""
    print("\nSaving model...")
    os.makedirs(model_dir, exist_ok=True)  # create the directory if needed

    model_filepath = os.path.join(model_dir, model_filename)
    torch.jit.save(torch.jit.script(model), model_filepath)
    print("Model saved in {}".format(model_filepath))
```

```python
MODEL_SAVE_PATH = "../generated_model"
TORCHSCRIPT_MODEL_FILENAME = "graphmode_static_quantized_model.pt"

save_torchscript_model(model_int8, MODEL_SAVE_PATH, TORCHSCRIPT_MODEL_FILENAME)
```

### Additional Notes

Since graph mode quantization automatically figures out things like which modules to fuse, it tries to quantize as many layers as possible, in contrast with manually specifying the layers to quantize. This is why the quantized model size is significantly smaller here.
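A back-of-the-envelope estimate from the `Net` definition shows the scale of the savings. The sketch below counts parameters per layer and compares 4-byte float32 storage with 1-byte int8 weights. It assumes biases stay in float32 and ignores scale/zero-point and serialization overhead, so it approximates, rather than reproduces, the size that `compare_performance` reports. It also shows that `fc1` holds most of the parameters, so whether that layer is quantized dominates the final size.

```python
# Parameter counts of Net, layer by layer: (weight elements, bias elements).
LAYERS = {
    "conv1": (6 * 3 * 5 * 5, 6),
    "conv2": (16 * 6 * 5 * 5, 16),
    "fc1": (120 * 18496, 120),
    "fc2": (84 * 120, 84),
    "fc3": (3 * 84, 3),
}

FP32_BYTES, INT8_BYTES = 4, 1

weight_count = sum(w for w, _ in LAYERS.values())
bias_count = sum(b for _, b in LAYERS.values())
total_params = weight_count + bias_count

fp32_size = total_params * FP32_BYTES
# Assumption: int8 weights, float32 biases, quantization parameters ignored.
int8_size = weight_count * INT8_BYTES + bias_count * FP32_BYTES

print(f"parameters: {total_params:,}")
print(f"fp32 ~ {fp32_size / 1e6:.2f} MB, int8 ~ {int8_size / 1e6:.2f} MB "
      f"({fp32_size / int8_size:.1f}x smaller)")
print(f"fc1 holds {sum(LAYERS['fc1']) / total_params:.1%} of all parameters")
```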

As mentioned previously, the FX Graph Mode Quantization used in the static quantization section is still a very new feature (available since PyTorch 1.8.0), so expect changes in how it is implemented.

Nonetheless, given how easy it is to use and the performance improvements it brings, especially in model size, it is a very exciting feature to look forward to.

### Extra Reading Materials
- https://pytorch.org/docs/stable/quantization-support.html#torch-quantization
- https://pytorch.org/tutorials/prototype/fx_graph_mode_ptq_static.html
- https://www.kaggle.com/okeaditya/what-s-new-in-pytorch-1-6
- https://zhuanlan.zhihu.com/p/349019936