## 自定义模型

In [1]:
import torch

In [2]:
class MyModule(torch.nn.Module):
    """Tiny float model wrapped in quantization stubs for eager-mode quantization.

    ``quant`` marks where the input enters the quantized region and ``dequant``
    marks where it leaves it. Before ``torch.quantization.convert`` both stubs
    pass their input through unchanged.
    """

    def __init__(self):
        super().__init__()
        # Attribute names/order are relied upon below: they appear in the module
        # repr and become the FX node names ('quant', 'fc', 'dequant').
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.fc = torch.nn.Linear(4, 5)

    def forward(self, x):
        """Run quant -> fc -> dequant and return the float result.

        The parameter name ``x`` is also the FX placeholder node name.
        """
        x = self.quant(x)
        # Intentional demo print: shows a real tensor when run eagerly and a
        # Proxy object while the module is being traced by torch.fx.
        print('qx:', x)
        x = self.fc(x)
        x = self.dequant(x)
        return x


# Seed before constructing the model so the randomly initialised Linear
# weights (and therefore every numeric output below) are reproducible.
SEED = 0
torch.manual_seed(SEED)

model = MyModule()
model.eval()

MyModule(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (fc): Linear(in_features=4, out_features=5, bias=True)
)

In [3]:
from torch.fx import GraphModule, symbolic_trace

# Trace the float model with torch.fx. The unconverted QuantStub/DeQuantStub
# return their input unchanged, so only `fc` remains in the generated code.
# Printing a GraphModule shows the module tree followed by the generated forward.
symbolic_traced: GraphModule = symbolic_trace(model)
print(symbolic_traced)

qx: Proxy(x)
MyModule(
  (fc): Linear(in_features=4, out_features=5, bias=True)
)



def forward(self, x):
    fc = self.fc(x);  x = None
    return fc
    


In [4]:

def quantize_model(model, inp, backend="fbgemm"):
    """Apply eager-mode post-training static quantization to ``model`` in place.

    Steps: switch to eval mode, attach the default qconfig for ``backend``,
    insert observers (``prepare``), calibrate with one forward pass on ``inp``,
    then swap float modules for quantized ones (``convert``).

    Args:
        model: ``torch.nn.Module`` whose quantized region is delimited by
            ``QuantStub``/``DeQuantStub``. Modified in place.
        inp: calibration input fed once through the prepared model; the
            observers record its activation ranges.
        backend: backend name passed to ``torch.quantization.get_default_qconfig``.
            Defaults to ``"fbgemm"`` (the previously hard-coded value).
            NOTE: it should match the quantized engine used to run the model.

    Returns:
        None. ``model`` is updated in place.
    """
    # Post-training static quantization expects inference-mode behaviour during
    # calibration, so do not rely on the caller having called eval().
    model.eval()
    model.qconfig = torch.quantization.get_default_qconfig(backend)
    torch.quantization.prepare(model, inplace=True)
    # Calibration: observers only need to see activations; skip autograd bookkeeping.
    with torch.no_grad():
        model(inp)
    torch.quantization.convert(model, inplace=True)

In [5]:
from copy import deepcopy

# Demo / calibration input. The values span a wide range on purpose, so small
# values such as 0.1 lose precision after 8-bit quantization.
inp = torch.tensor([[-100, 0, 0.1, 1000]], dtype=torch.float)

# Quantize a copy (nn.Module.eval() returns the module itself) so the float
# `model` stays available for comparison.
qmodel = deepcopy(model).eval()
quantize_model(qmodel, inp)

  src_bin_begin // dst_bin_width, 0, self.dst_nbins - 1
  src_bin_end // dst_bin_width, 0, self.dst_nbins - 1


qx: tensor([[-1.0000e+02,  0.0000e+00,  1.0000e-01,  1.0000e+03]])


In [6]:
# Trace the converted model. The Quantize/DeQuantize modules are now called
# as submodules in the generated forward, so `quant` and `dequant` show up.
symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(qmodel)
print(symbolic_traced)

qx: Proxy(quant)
MyModule(
  (quant): Quantize(scale=tensor([8.6572]), zero_point=tensor([12]), dtype=torch.quint8)
  (fc): QuantizedLinear(in_features=4, out_features=5, scale=3.0581657886505127, zero_point=0, qscheme=torch.per_channel_affine)
  (dequant): DeQuantize()
)



def forward(self, x):
    quant = self.quant(x);  x = None
    fc = self.fc(quant);  quant = None
    dequant = self.dequant(fc);  fc = None
    return dequant
    


使用 `get_graph_node_names` 获取量化模型 `qmodel` 的所有 node 名称（它会分别在 train 和 eval 模式下各 trace 一次，因此下面的 `print` 输出了两次）：

In [7]:
from torchvision.models.feature_extraction import get_graph_node_names

# get_graph_node_names returns a pair of lists: node names traced in train
# mode and in eval mode. Tracing happens once per mode, which is why the
# demo print inside forward fires twice.
node_names = get_graph_node_names(qmodel)
train_nodes, eval_nodes = node_names
eval_nodes

qx: Proxy(quant)
qx: Proxy(quant)


['x', 'quant', 'fc', 'dequant']

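与 In [3] 中浮点模型的 trace 结果（只有 `x` 和 `fc`）相比，量化后的模型多了 `quant` 和 `dequant` 两个 node：转换前的 `QuantStub`/`DeQuantStub` 直接返回输入，被 FX 展开后消失；转换后的 `Quantize`/`DeQuantize` 则作为子模块调用保留在图中。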

In [8]:
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

# Graph node name -> key in the returned dict. 'x' is the input placeholder;
# 'fc' is the output of the quantized Linear layer, taken before dequantization.
return_nodes = {name: name for name in ('x', 'fc')}

# Rebuild qmodel as an FX GraphModule whose forward returns the requested
# nodes as a dict.
n_model = create_feature_extractor(qmodel, return_nodes)

out = n_model(inp)
for node_name, value in out.items():
    print(node_name, value)

qx: Proxy(quant)
qx: Proxy(quant)
x tensor([[-1.0000e+02,  0.0000e+00,  1.0000e-01,  1.0000e+03]])
fc tensor([[  9.1745, 391.4452, 223.2461, 232.4206, 461.7830]], size=(1, 5),
       dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine,
       scale=3.0581657886505127, zero_point=0)


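对比两个模型的前向输出：`model` 中的 `qx` 与输入完全相同（未转换的 `QuantStub` 不做任何处理），而 `qmodel` 中的 `qx` 已被量化为 `quint8`（`scale≈8.657`，`zero_point=12`）。例如 `-100` 变成了 `-103.8863`，`0.1` 被量化为 `0`，`1000` 变成了 `1004.2339`，因此最终输出也与浮点模型略有差异：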

In [9]:
# Compare the float and the quantized model on the same input. Both forward
# passes run under no_grad so they use the same inference context.
with torch.no_grad():
    print('model:')
    print(model(inp), '\n')

    print('qmodel:')
    print(qmodel(inp))

model:
qx: tensor([[-1.0000e+02,  0.0000e+00,  1.0000e-01,  1.0000e+03]])
tensor([[  9.3630, 387.9530, 220.9602, 228.7354, 461.4546]]) 

qmodel:
qx: tensor([[-103.8863,    0.0000,    0.0000, 1004.2339]], size=(1, 4),
       dtype=torch.quint8, quantization_scheme=torch.per_tensor_affine,
       scale=8.657188415527344, zero_point=12)
tensor([[  9.1745, 391.4452, 223.2461, 232.4206, 461.7830]])
