In [31]:
import torch
from torch.quantization.observer import MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver

In [32]:
# Standard normal (mean 0, std 1) used to draw the synthetic tensors fed to the observers.
normal = torch.distributions.Normal(loc=0, scale=1)

In [33]:
# Shape of every synthetic sample: (C, L).
C = 3  # size of axis 0 (the axis later used as the channel axis)
L = 4  # size of axis 1 (values per channel)

In [34]:
# Two independent (C, L) draws; they act as two successive "batches" for the observers.
inputs = [normal.sample((C, L)) for _ in range(2)]

In [35]:
# One default-constructed instance of each per-tensor observer to compare.
observers = [
    observer_cls()
    for observer_cls in (MinMaxObserver, MovingAverageMinMaxObserver, HistogramObserver)
]

In [36]:
# Feed both sample tensors to each observer, then report the quantization
# parameters (scale, zero_point) that observer derives from what it saw.
# NOTE: `obs` stays bound to the last observer after the loop; later cells use it.
for obs in observers:
    for x in inputs:
        obs(x)
    print(type(obs).__name__, obs.calculate_qparams())

MinMaxObserver (tensor([0.0186]), tensor([151], dtype=torch.int32))
MovingAverageMinMaxObserver (tensor([0.0118]), tensor([118], dtype=torch.int32))
HistogramObserver (tensor([0.0186]), tensor([151], dtype=torch.int32))


In [37]:
# IPython introspection (`?`): show signature/docstring/source file of calculate_qparams.
# `obs` is whatever the preceding loop left bound, i.e. the last entry of `observers`.
obs.calculate_qparams?

[0;31mSignature:[0m [0mobs[0m[0;34m.[0m[0mcalculate_qparams[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mFile:[0m      ~/miniconda3/envs/torch/lib/python3.11/site-packages/torch/ao/quantization/observer.py
[0;31mType:[0m      method

In [39]:
# Largest value in the first sample tensor, for comparison with the observer scales above.
inputs[0].max()

tensor(1.6133)

In [42]:
# Pareto distribution with scale 1 and shape (alpha) 10.
pareto = torch.distributions.Pareto(scale=1, alpha=10)

In [45]:
# Draw 10 values from the Pareto distribution to eyeball their range.
pareto.sample((10,))

tensor([1.0037, 1.0171, 1.2260, 1.3984, 1.0909, 1.0961, 1.0097, 1.0023, 1.0179,
        1.0807])

In [46]:
# NOTE(review): unexplained arithmetic — presumably the number of weights in a
# 3-in / 64-out channel 7x7 convolution kernel; confirm and name the factors.
3 * 64 * 7 * 7

9408

In [47]:
from torch.quantization.observer import MovingAveragePerChannelMinMaxObserver

In [52]:
# Rebinds `obs` to a fresh per-channel moving-average observer; axis 0 is the channel axis,
# so one (scale, zero_point) pair is produced per row of the (C, L) inputs.
obs = MovingAveragePerChannelMinMaxObserver(ch_axis=0)

In [55]:
# Update the per-channel observer one tensor at a time and print the qparams after
# each update, so the effect of the second tensor on the running statistics is visible.
for x in inputs:
    obs(x)
    print(obs.calculate_qparams())

(tensor([0.0063, 0.0105, 0.0113]), tensor([220, 126, 113], dtype=torch.int32))
(tensor([0.0063, 0.0105, 0.0113]), tensor([217, 126, 114], dtype=torch.int32))


In [73]:
# Quantized backend selection. The same backend string is used both to look up the
# default qconfig and to choose the engine that executes quantized kernels.
backend = 'qnnpack'
qconfig = torch.quantization.get_default_qconfig(backend)
torch.backends.quantized.engine = backend

In [74]:
# Display the default qconfig: which observer is used for activations and which for weights.
qconfig

QConfig(activation=functools.partial(<class 'torch.ao.quantization.observer.HistogramObserver'>, reduce_range=False){}, weight=functools.partial(<class 'torch.ao.quantization.observer.MinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})

## Post-training dynamic quantization

In [65]:
import torch.nn as nn
from torch.quantization import quantize_dynamic

In [77]:
# Toy network for the dynamic-quantization demo: one convolution followed by a
# two-layer MLP. It is only inspected here, never run on data.
m = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3),
    nn.Linear(in_features=16, out_features=32),
    nn.ReLU(),
    nn.Linear(in_features=32, out_features=16),
)

In [78]:
# Switch to inference mode before quantizing; eval() returns the module, so it is displayed.
m.eval()

Sequential(
  (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (1): Linear(in_features=16, out_features=32, bias=True)
  (2): ReLU()
  (3): Linear(in_features=32, out_features=16, bias=True)
)

In [79]:
# Post-training dynamic quantization of the toy model (returns a converted copy,
# `m` itself is left untouched because inplace=False).
# Only nn.Linear is requested: quantize_dynamic does not convert nn.Conv2d, so listing
# it in qconfig_spec was silently ignored and the conv layer stayed a float Conv2d.
model_quantized = quantize_dynamic(model=m, qconfig_spec={nn.Linear}, dtype=torch.qint8, inplace=False)

In [80]:
# Inspect the converted model: which layers were actually replaced by quantized versions.
model_quantized

Sequential(
  (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (1): DynamicQuantizedLinear(in_features=16, out_features=32, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
  (2): ReLU()
  (3): DynamicQuantizedLinear(in_features=32, out_features=16, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)

In [87]:
# Bytes per weight element of the dynamically quantized Linear layer.
# On the quantized module `weight` is a method, hence the call.
model_quantized[1].weight().element_size()

1

In [92]:
# Bytes per weight element of the Conv2d layer, which dynamic quantization left unconverted
# (here `weight` is a plain attribute, not a method).
model_quantized[0].weight.element_size()

4

## Post-training static quantization

In [97]:
import copy

In [94]:
# Float reference model for the static-quantization walkthrough:
# two Conv2d -> BatchNorm2d -> ReLU stages (3 -> 16 -> 64 channels, 3x3 kernels).
model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3),
    nn.BatchNorm2d(num_features=16),
    nn.ReLU(),
    nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3),
    nn.BatchNorm2d(num_features=64),
    nn.ReLU(),
)

In [98]:
# Work on a deep copy so the float `model` stays intact; the in-place steps below mutate `m`.
# Note this rebinds `m`, which previously held the dynamic-quantization toy model.
m = copy.deepcopy(model)

In [99]:
# Put the copy in inference mode before fusing; eval() returns the module, so it is displayed.
m.eval()

Sequential(
  (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1))
  (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU()
)

In [96]:
# Step 1 - Fuse modules: merge each Conv2d + BatchNorm2d + ReLU triple into one module.

In [101]:
# Fuse the two Conv2d/BatchNorm2d/ReLU triples, addressed by their names inside the
# Sequential. Fusion happens in place; the fused positions collapse into the first
# slot of each triple and the remaining slots become Identity placeholders.
torch.quantization.fuse_modules(
    m,
    modules_to_fuse=[['0', '1', '2'], ['3', '4', '5']],
    inplace=True,
)

Sequential(
  (0): ConvReLU2d(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
  )
  (1): Identity()
  (2): Identity()
  (3): ConvReLU2d(
    (0): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
  )
  (4): Identity()
  (5): Identity()
)

In [103]:
# Step 2 - Insert stubs: wrap the model with QuantStub / DeQuantStub at its boundaries.

In [105]:
# Rebuild `m` with a QuantStub in front of, and a DeQuantStub behind, the fused layers.
# NOTE: this cell is not idempotent; running it twice wraps the model in stubs twice.
m = nn.Sequential(
    torch.quantization.QuantStub(),
    *m,
    torch.quantization.DeQuantStub(),
)

In [109]:
# Step 3 - Prepare: attach a qconfig and insert observers into the model.

In [110]:
# Attach the default qconfig for `backend` (set in an earlier cell) to the root module;
# prepare() reads this attribute to decide which observers to insert.
m.qconfig = torch.quantization.get_default_qconfig(backend)

In [112]:
# Insert observers in place (see the `activation_post_process` entries in the output).
# They start with min_val=inf / max_val=-inf because no data has been observed yet.
torch.quantization.prepare(m, inplace=True)

Sequential(
  (0): QuantStub(
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (1): ConvReLU2d(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (2): Identity()
  (3): Identity()
  (4): ConvReLU2d(
    (0): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (activation_post_process): HistogramObserver(min_val=inf, max_val=-inf)
  )
  (5): Identity()
  (6): Identity()
  (7): DeQuantStub()
)

In [113]:
# Step 4 - Calibrate: push 10 random (1, 3, 224, 224) batches through the prepared model
# so the inserted observers can record activation ranges. Model outputs are discarded.
# NOTE(review): random noise is a stand-in; real calibration should use representative data.
with torch.inference_mode():
    for _ in range(10):
        x = torch.randn(1, 3, 224, 224)
        m(x)

In [114]:
# Step 5 - Convert: replace the observed modules with their quantized counterparts, in place,
# using the scale / zero_point values derived from the calibration pass.
torch.quantization.convert(m, inplace=True)

Sequential(
  (0): Quantize(scale=tensor([0.0316]), zero_point=tensor([129]), dtype=torch.quint8)
  (1): QuantizedConvReLU2d(3, 16, kernel_size=(3, 3), stride=(1, 1), scale=0.008907907642424107, zero_point=0)
  (2): Identity()
  (3): Identity()
  (4): QuantizedConvReLU2d(16, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.003892553737387061, zero_point=0)
  (5): Identity()
  (6): Identity()
  (7): DeQuantize()
)

In [121]:
# Bytes per weight element of the first quantized conv block (index 1, right after the
# Quantize stub). On quantized modules `weight` is a method, hence the call.
m[1].weight().element_size()

1