### PTSQ (Post-Training Static Quantization)

In [None]:
import torch

# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):
    """Small float model prepared for eager-mode post-training static quantization.

    Data flow: QuantStub -> Conv2d(1 -> 1, kernel 1) -> ReLU -> DeQuantStub.
    Before conversion both stubs are pass-throughs; after conversion they mark
    where tensors switch between floating point and quantized form.
    """

    def __init__(self):
        super().__init__()
        # boundary where float input becomes quantized in the converted model
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1)
        self.relu = torch.nn.ReLU()
        # boundary where quantized activations become float again
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # quant -> conv -> relu -> dequant, written as one composed call
        return self.dequant(self.relu(self.conv(self.quant(x))))
    
# create a model instance
model_fp32 = M()

# model must be set to eval mode for static quantization logic to work
model_fp32.eval()

# attach a global qconfig describing which observers to insert.
# Use 'x86' for server inference and 'qnnpack' for mobile inference.
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')

# fuse conv + relu so they are quantized as a single module; this has to be
# spelled out manually per model architecture
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])

# insert observers that record activation statistics during calibration
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)

# calibrate: feed data through the prepared model so the observers can
# compute scales/zero-points (random data is only a placeholder here;
# real calibration should use representative inputs)
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)

# convert the observed model into a quantized model
# (was `torch.ao.quantizaiton.convert`, which raised AttributeError)
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

# run the model; relevant calculations will happen in int8
res = model_int8(input_fp32)

## QAT (Quantization-Aware Training)

In [None]:
import torch

# define a floating point model where some layers could benefit from QAT
class M(torch.nn.Module):
    """Small float model intended for quantization-aware training (QAT).

    Data flow: QuantStub -> Conv2d(1 -> 1, kernel 1) -> BatchNorm2d(1) -> ReLU
    -> DeQuantStub. The stubs are pass-throughs until the model is converted,
    at which point they mark the float <-> quantized boundaries.
    """

    def __init__(self):
        super().__init__()
        # float -> quantized boundary in the converted model
        self.quant = torch.ao.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1)
        self.bn = torch.nn.BatchNorm2d(num_features=1)
        self.relu = torch.nn.ReLU()
        # quantized -> float boundary in the converted model
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        # apply each stage in sequence
        for layer in (self.quant, self.conv, self.bn, self.relu, self.dequant):
            x = layer(x)
        return x

# create a model instance
model_fp32 = M()

# put conv, bn and relu into the same (eval) mode before fusion;
# the fuser requires all fused modules to share one mode
model_fp32.eval()

# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'x86' for server inference and 'qnnpack'
# for mobile inference. Other quantization configurations such as selecting
# symmetric or asymmetric quantization and MinMax or L2Norm calibration techniques
# can be specified here.
# (was `get_defalut_qat_qconfig`, which raised AttributeError)
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

# fuse conv + bn + relu; this needs to be done manually depending on the
# model architecture. fuse_modules_qat keeps BatchNorm inside the fused
# module so it still takes part in training, whereas plain fuse_modules on
# an eval-mode model folds BN into the conv weights using its untrained
# running statistics.
model_fp32_fused = torch.ao.quantization.fuse_modules_qat(model_fp32, [['conv', 'bn', 'relu']])

# Prepare the model for QAT: this inserts observers and fake-quantize
# modules that simulate quantization during training. The model must be in
# train mode for prepare_qat.
# (was `torch.ao.quantization.preparation.prepare_qat`, which does not exist)
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32_fused.train())

# run the training loop
# training_loop(model_fp32_prepared)

# switch back to eval mode and convert the fake-quantized model to int8
model_fp32_prepared.eval()
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)

# run the model, relevant calculations will happen in int8
# res = model_int8(input_fp32)

#### FX Graph Mode Quantization

In [None]:
import torch
from torch.ao.quantization import (
    get_default_qconfig_mapping,
    get_default_qat_qconfig_mapping,
    QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx
import copy

# NOTE(review): UserModel is not defined anywhere in the visible notebook; it
# presumably stands for the user's own float torch.nn.Module. Define or import
# it before running this cell, otherwise this line raises NameError.
model_fp = UserModel()

# for specific details, see the following link:
# https://pytorch.org/docs/stable/quantization.html
