In [1]:
import torch

## 自定义模型

In [2]:
from torch import nn

class MyModule(torch.nn.Module):
    def __init__(self, ic, oc, kernel_size):
        super().__init__()
        self.conv = nn.Conv2d(ic, oc, kernel_size)

    def forward(self, x):
        return self.conv(x)

model = MyModule(1, 1, 1)
model = torch.quantization.QuantWrapper(model)
model.eval()

QuantWrapper(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (module): MyModule(
    (conv): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)

In [4]:
from torch.fx import symbolic_trace

symbolic_traced: torch.fx.GraphModule = symbolic_trace(model)
print(symbolic_traced)

qx: Proxy(x)
MyModule(
  (fc): Linear(in_features=4, out_features=5, bias=True)
)



def forward(self, x):
    fc = self.fc(x);  x = None
    return fc
    


In [3]:
def quantize_model(model, inp):
    model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
    torch.quantization.prepare(model, inplace=True)
    # Calibration
    model(inp)
    torch.quantization.convert(model, inplace=True)

In [5]:
from copy import deepcopy

inp = torch.randn((1, 1, 3, 3))

qmodel = deepcopy(model).eval()
quantize_model(qmodel, inp)

  src_bin_begin // dst_bin_width, 0, self.dst_nbins - 1
  src_bin_end // dst_bin_width, 0, self.dst_nbins - 1


In [7]:
symbolic_traced: torch.fx.GraphModule = symbolic_trace(qmodel)
print(symbolic_traced)

qx: Proxy(quant)
MyModule(
  (quant): Quantize(scale=tensor([8.6572]), zero_point=tensor([12]), dtype=torch.quint8)
  (fc): QuantizedLinear(in_features=4, out_features=5, scale=5.336332321166992, zero_point=58, qscheme=torch.per_channel_affine)
  (dequant): DeQuantize()
)



def forward(self, x):
    quant = self.quant(x);  x = None
    fc = self.fc(quant);  quant = None
    dequant = self.dequant(fc);  fc = None
    return dequant
    


获取模型的所有的nodes：

In [8]:
from torchvision.models.feature_extraction import get_graph_node_names

train_nodes, eval_nodes = get_graph_node_names(qmodel)
eval_nodes

qx: Proxy(quant)
qx: Proxy(quant)


['x', 'quant', 'fc', 'dequant']

结果显示多了"quant"和"dequant"两个node。

In [9]:
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

# 定义输出node
return_nodes = {
    'x': 'x',
    'fc': 'fc',
}

# 进行重建
n_model = create_feature_extractor(qmodel, return_nodes)

out = n_model(inp)
for k, v in out.items():
    print(k, v) 

qx: Proxy(quant)
qx: Proxy(quant)
x tensor([[-1.0000e+02,  0.0000e+00,  1.0000e-01,  1.0000e+03]])
fc tensor([[ 373.5433, -309.5073, -309.5073,  304.1709,  101.3903]], size=(1, 5),
       dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine,
       scale=5.336332321166992, zero_point=58)


model中的qx没有任何改变，但是qmodel中的qx被quantize了:

In [10]:
with torch.no_grad():
    print('model:')
    print(model(inp), '\n')

print('qmodel:')
print(qmodel(inp))

model:
qx: tensor([[-1.0000e+02,  0.0000e+00,  1.0000e-01,  1.0000e+03]])
tensor([[ 369.9384, -420.0211, -307.6655,  301.7999,  105.8759]]) 

qmodel:
qx: tensor([[-103.8863,    0.0000,    0.0000, 1004.2339]], size=(1, 4),
       dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine,
       scale=8.657188415527344, zero_point=12)
tensor([[ 373.5433, -309.5073, -309.5073,  304.1709,  101.3903]])


## 基本知识
$$
Q(x, \text{scale}, \text{zero\_point}) = \text{round}(\frac{x}{\text{scale}} + \text{zero\_point})
$$

下面操作等价于：
$\text{round}(a / 1.6) \times 1.6 = 3.2$

In [13]:
a = torch.tensor(3.0)
qa = torch.quantize_per_tensor(a, 1.6, 0, torch.qint8)
qa

tensor(3.2000, size=(), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=1.6, zero_point=0)

In [14]:
qa.int_repr()

tensor(2, dtype=torch.int8)

In [None]:
from torch.nn.quantized import QFunctional

q_add = QFunctional()
qa = torch.quantize_per_tensor(torch.tensor(3.0), 1.0, 0, torch.qint8)
qb = torch.quantize_per_tensor(torch.tensor(4.0), 1.0, 0, torch.qint8)
q_add.add(qa, qb)  # Equivalent to ``torch.ops.quantized.add(a, b, 1.0, 0)

tensor(7., size=(), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=1.0, zero_point=0)

## QConv2D

In [6]:
from torch import nn

m = nn.quantized.Conv2d(1, 1, 1)
m

QuantizedConv2d(1, 1, kernel_size=(1, 1), stride=(1, 1), scale=1.0, zero_point=0)