ONNX export of per-channel fake quantize functions #42835

Closed
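For context, this change lets models that call torch.fake_quantize_per_channel_affine be exported through torch.onnx.export at opset 13, where the op is lowered to a QuantizeLinear/DequantizeLinear pair. A minimal sketch of that workflow (module name, tensor shapes, and output path are illustrative, not taken from the PR):

import torch

class PerChannelFakeQuant(torch.nn.Module):
    def forward(self, x):
        # One scale/zero_point entry per channel along axis 1.
        scale = torch.full((4,), 1.0 / 127.0)
        zero_point = torch.zeros(4, dtype=torch.long)
        return torch.fake_quantize_per_channel_affine(x, scale, zero_point, 1, 0, 255)

x = torch.randn(2, 4, 8, 8)
# Per-channel fake quantize is only handled by the exporter at opset >= 13.
torch.onnx.export(PerChannelFakeQuant(), x, "per_channel_fq.onnx", opset_version=13)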
23 changes: 22 additions & 1 deletion test/onnx/test_models.py
@@ -182,7 +182,7 @@ def test_fake_quant(self):
         self.exportTest(toC(FakeQuantNet()), toC(x))
 
     @skipIfUnsupportedMinOpsetVersion(10)
-    def test_qat_resnet(self):
+    def test_qat_resnet_pertensor(self):
         # Quantize ResNet50 model
         x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
         qat_resnet50 = resnet50()
@@ -202,6 +202,27 @@ def test_qat_resnet(self):
 
         self.exportTest(toC(qat_resnet50), toC(x))
 
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_qat_resnet_per_channel(self):
+        # Quantize ResNet50 model
+        x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
+        qat_resnet50 = resnet50()
+
+        qat_resnet50.qconfig = quantization.QConfig(
+            activation=quantization.default_fake_quant,
+            weight=quantization.default_per_channel_weight_fake_quant)
+        quantization.prepare_qat(qat_resnet50, inplace=True)
+        qat_resnet50.apply(torch.quantization.enable_observer)
+        qat_resnet50.apply(torch.quantization.enable_fake_quant)
+
+        _ = qat_resnet50(x)
+        for module in qat_resnet50.modules():
+            if isinstance(module, quantization.FakeQuantize):
+                module.calculate_qparams()
+        qat_resnet50.apply(torch.quantization.disable_observer)
+
+        self.exportTest(toC(qat_resnet50), toC(x))
+
     @disableScriptTest()  # None type in outputs
     def test_googlenet(self):
         x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
14 changes: 14 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -5723,6 +5723,20 @@ def forward(self, input):
         x = torch.randn(6, 4, 3, 3)
         self.run_test(FakeQuantizePerTensorModel(), (x))
 
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_fake_quantize_per_channel(self):
+        class FakeQuantizePerChannelModel(torch.nn.Module):
+            def forward(self, input):
+                amax = torch.ones(4)
+                scale = amax / 127.
+                zero_point = torch.zeros_like(amax, dtype=torch.long)
+                # Quantize twice to test different branches
+                y = torch.fake_quantize_per_channel_affine(input, scale, zero_point, 1, 0, 255)
+                return torch.fake_quantize_per_channel_affine(y, scale, zero_point, 1, -128, 127)
+
+        x = torch.randn(6, 4, 3, 3)
+        self.run_test(FakeQuantizePerChannelModel(), (x))
+
     def test_batchnorm_training(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
15 changes: 15 additions & 0 deletions torch/onnx/symbolic_opset13.py
@@ -110,6 +110,21 @@ def glu(g, input, dim):
     first, second = g.op('Split', input, dim, outputs=2)
     return g.op('Mul', first, g.op('Sigmoid', second))
 
+@parse_args('v', 'v', 'v', 'i', 'i', 'i')
+def fake_quantize_per_channel_affine(g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127):
+    if quant_min not in [0, -128] or quant_max not in [127, 255]:
+        raise RuntimeError(
+            "ONNX defines [0, 255] for quint8 and [-128, 127] for qint8, got [{}, {}]".format(quant_min, quant_max))
+
+    # ONNX defines zero_point to be int8 or uint8
+    if quant_min == 0:
+        zero_point = g.op("Cast", zero_point, to_i=sym_help.cast_pytorch_to_onnx['Byte'])
+    else:
+        zero_point = g.op("Cast", zero_point, to_i=sym_help.cast_pytorch_to_onnx['Char'])
+    return g.op(
+        "DequantizeLinear",
+        g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis),
+        scale, zero_point, axis_i=axis)
 
 def _reduce_op_symbolic(onnx_op_name):
     def symbolic(g, self, dim=None, keepdim=None):
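For reference, the symbolic above lowers fake_quantize_per_channel_affine to QuantizeLinear followed by DequantizeLinear along the given axis. A rough sketch of what that pair computes per channel, assuming the standard ONNX semantics (illustrative only, not part of the PR; rounding-tie behavior may differ slightly from the native kernel):

import torch

def per_channel_quant_dequant(x, scale, zero_point, axis, quant_min, quant_max):
    # Reshape the per-channel parameters so they broadcast along `axis`.
    shape = [1] * x.dim()
    shape[axis] = -1
    s = scale.reshape(shape)
    z = zero_point.to(x.dtype).reshape(shape)
    # QuantizeLinear: divide by scale, round, add zero_point, saturate to the integer range.
    q = torch.clamp(torch.round(x / s) + z, quant_min, quant_max)
    # DequantizeLinear: undo the shift and rescale.
    return (q - z) * s

x = torch.randn(6, 4, 3, 3)
scale = torch.full((4,), 1.0 / 127.0)
zero_point = torch.zeros(4, dtype=torch.long)
ref = per_channel_quant_dequant(x, scale, zero_point, 1, 0, 255)
out = torch.fake_quantize_per_channel_affine(x, scale, zero_point, 1, 0, 255)
print((ref - out).abs().max())  # expected to be (near) zero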