Skip to content

Commit

Permalink
onnx export of per channel fake quantize functions (#42835)
Browse files Browse the repository at this point in the history
Summary:
Fixes #39502

This PR adds support for exporting **fake_quantize_per_channel_affine** to a pair of QuantizeLinear and DequantizeLinear. Per tensor support was added by PR #39738.

`axis` attribute of QuantizeLinear and DequantizeLinear, which is required for per channel support, is added in opset13 added by onnx/onnx#2772.

[update 1/20/2021]: opset13 is being supported on master, the added function is now properly tested. Code also rebased to new master.

The function is also tested offline with the following code
```python
import torch
from torch import quantization

from torchvision import models
qat_resnet18 = models.resnet18(pretrained=True).eval().cuda()

qat_resnet18.qconfig = quantization.QConfig(
    activation=quantization.default_fake_quant, weight=quantization.default_per_channel_weight_fake_quant)
quantization.prepare_qat(qat_resnet18, inplace=True)
qat_resnet18.apply(quantization.enable_observer)
qat_resnet18.apply(quantization.enable_fake_quant)

dummy_input = torch.randn(16, 3, 224, 224).cuda()
_ = qat_resnet18(dummy_input)
for module in qat_resnet18.modules():
    if isinstance(module, quantization.FakeQuantize):
        module.calculate_qparams()
qat_resnet18.apply(quantization.disable_observer)

qat_resnet18.cuda()

input_names = [ "actual_input_1" ]
output_names = [ "output1" ]

torch.onnx.export(qat_resnet18, dummy_input, "quant_model.onnx", verbose=True, opset_version=13)
```
It can generate the desired graph.

Pull Request resolved: #42835

Reviewed By: houseroad

Differential Revision: D26293823

Pulled By: SplitInfinity

fbshipit-source-id: 300498a2e24b7731b12fa2fbdea4e73dde80e7ea
  • Loading branch information
skyw authored and Meghan Lele committed Feb 18, 2021
1 parent f7c4afc commit 1cf3ff6
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 1 deletion.
23 changes: 22 additions & 1 deletion test/onnx/test_models.py
Expand Up @@ -182,7 +182,7 @@ def test_fake_quant(self):
self.exportTest(toC(FakeQuantNet()), toC(x))

@skipIfUnsupportedMinOpsetVersion(10)
def test_qat_resnet(self):
def test_qat_resnet_pertensor(self):
# Quantize ResNet50 model
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
qat_resnet50 = resnet50()
Expand All @@ -202,6 +202,27 @@ def test_qat_resnet(self):

self.exportTest(toC(qat_resnet50), toC(x))

@skipIfUnsupportedMinOpsetVersion(13)
def test_qat_resnet_per_channel(self):
# Quantize ResNet50 model
x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
qat_resnet50 = resnet50()

qat_resnet50.qconfig = quantization.QConfig(
activation=quantization.default_fake_quant,
weight=quantization.default_per_channel_weight_fake_quant)
quantization.prepare_qat(qat_resnet50, inplace=True)
qat_resnet50.apply(torch.quantization.enable_observer)
qat_resnet50.apply(torch.quantization.enable_fake_quant)

_ = qat_resnet50(x)
for module in qat_resnet50.modules():
if isinstance(module, quantization.FakeQuantize):
module.calculate_qparams()
qat_resnet50.apply(torch.quantization.disable_observer)

self.exportTest(toC(qat_resnet50), toC(x))

@disableScriptTest() # None type in outputs
def test_googlenet(self):
x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
Expand Down
14 changes: 14 additions & 0 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Expand Up @@ -5998,6 +5998,20 @@ def forward(self, input):
x = torch.randn(6, 4, 3, 3)
self.run_test(FakeQuantizePerTensorModel(), (x))

@skipIfUnsupportedMinOpsetVersion(13)
def test_fake_quantize_per_channel(self):
class FakeQuantizePerChannelModel(torch.nn.Module):
def forward(self, input):
amax = torch.ones(4)
scale = amax / 127.
zero_point = torch.zeros_like(amax, dtype=torch.long)
# Quantize twice to test differnet branches
y = torch.fake_quantize_per_channel_affine(input, scale, zero_point, 1, 0, 255)
return torch.fake_quantize_per_channel_affine(y, scale, zero_point, 1, -128, 127)

x = torch.randn(6, 4, 3, 3)
self.run_test(FakeQuantizePerChannelModel(), (x))

def test_batchnorm_training(self):
class MyModule(torch.nn.Module):
def __init__(self):
Expand Down
15 changes: 15 additions & 0 deletions torch/onnx/symbolic_opset13.py
Expand Up @@ -121,6 +121,21 @@ def where(g, condition, self=None, other=None, _outputs=None):
return sym_help._unbind_helper(g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs)
return g.op("Where", condition, self, other)

@parse_args('v', 'v', 'v', 'i', 'i', 'i')
def fake_quantize_per_channel_affine(g, inputs, scale, zero_point, axis, quant_min=-128, quant_max=127):
if quant_min not in [0, -128] or quant_max not in [127, 255]:
raise RuntimeError(
"ONNX defines [0, 255] for quint8 and [-128, 127] for qint8, got [{}, {}]".format(quant_min, quant_max))

# ONNX defines zero_point to be int8 or uint8
if quant_min == 0:
zero_point = g.op("Cast", zero_point, to_i=sym_help.cast_pytorch_to_onnx['Byte'])
else:
zero_point = g.op("Cast", zero_point, to_i=sym_help.cast_pytorch_to_onnx['Char'])
return g.op(
"DequantizeLinear",
g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis),
scale, zero_point, axis_i=axis)

def _reduce_op_symbolic(onnx_op_name):
def symbolic(g, self, dim=None, keepdim=None):
Expand Down

0 comments on commit 1cf3ff6

Please sign in to comment.