
onnx export of per channel fake quantize functions #42835

Closed

Conversation

Contributor

@skyw skyw commented Aug 10, 2020

Fixes #39502

This PR adds support for exporting fake_quantize_per_channel_affine to a pair of QuantizeLinear and DequantizeLinear ops. Per-tensor support was added by PR #39738.
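
For context, per-channel fake quantization applies an independent (scale, zero_point) pair along one axis of the tensor. A minimal illustration (the shapes and values below are made up):

```python
import torch

# Fake-quantize a (2, 3, 4) tensor along axis=1, with one (scale, zero_point)
# pair per channel; values are arbitrary, for illustration only.
x = torch.randn(2, 3, 4)
scale = torch.tensor([0.1, 0.2, 0.05])
zero_point = torch.zeros(3, dtype=torch.int32)  # integer zero points
y = torch.fake_quantize_per_channel_affine(
    x, scale, zero_point, axis=1, quant_min=-128, quant_max=127)
```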

The axis attribute of QuantizeLinear and DequantizeLinear, which is required for per-channel support, was added in opset 13 by onnx/onnx#2772.
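
For readers unfamiliar with the exporter, a symbolic function of roughly this shape lowers the op to a Q/DQ pair. This is a simplified sketch, not the exact code in this PR; the range check and the parse_args pattern here are assumptions:

```python
from torch.onnx.symbolic_helper import parse_args

@parse_args("v", "v", "v", "i", "i", "i")
def fake_quantize_per_channel_affine(g, inputs, scale, zero_point, axis,
                                     quant_min, quant_max):
    # ONNX QuantizeLinear/DequantizeLinear only model uint8/int8 ranges
    if (quant_min, quant_max) not in [(0, 255), (-128, 127)]:
        raise RuntimeError(
            "ONNX export only supports (0, 255) and (-128, 127) ranges")
    # opset 13 added the axis attribute, enabling per-channel scale/zero_point
    quantized = g.op("QuantizeLinear", inputs, scale, zero_point, axis_i=axis)
    return g.op("DequantizeLinear", quantized, scale, zero_point, axis_i=axis)
```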

[update 1/20/2021]: opset 13 is now supported on master, so the added function is properly tested. The code has also been rebased onto the new master.

The function was also tested offline with the following code:

```python
import torch
from torch import quantization
from torchvision import models

# Load a pretrained ResNet-18 and prepare it for quantization-aware training
qat_resnet18 = models.resnet18(pretrained=True).eval().cuda()

# Per-tensor fake quantization for activations, per-channel for weights
qat_resnet18.qconfig = quantization.QConfig(
    activation=quantization.default_fake_quant,
    weight=quantization.default_per_channel_weight_fake_quant)
quantization.prepare_qat(qat_resnet18, inplace=True)
qat_resnet18.apply(quantization.enable_observer)
qat_resnet18.apply(quantization.enable_fake_quant)

# Run one batch so the observers record ranges, then freeze the qparams
dummy_input = torch.randn(16, 3, 224, 224).cuda()
_ = qat_resnet18(dummy_input)
for module in qat_resnet18.modules():
    if isinstance(module, quantization.FakeQuantize):
        module.calculate_qparams()
qat_resnet18.apply(quantization.disable_observer)

qat_resnet18.cuda()

input_names = ["actual_input_1"]
output_names = ["output1"]

# opset 13 is required for the axis attribute on QuantizeLinear/DequantizeLinear
torch.onnx.export(qat_resnet18, dummy_input, "quant_model.onnx",
                  verbose=True, opset_version=13)
```

It generates the desired graph, with each fake-quantize node exported as a QuantizeLinear/DequantizeLinear pair.
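
To confirm the structure, the exported model can be inspected with the onnx package (an assumed verification step, not part of this PR):

```python
import onnx

model = onnx.load("quant_model.onnx")
onnx.checker.check_model(model)
# Each fake-quantize node should now appear as a Q/DQ pair carrying axis
for node in model.graph.node:
    if node.op_type in ("QuantizeLinear", "DequantizeLinear"):
        axis = next((a.i for a in node.attribute if a.name == "axis"), None)
        print(node.op_type, "axis =", axis)
```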


dr-ci bot commented Aug 10, 2020

💊 CI failures summary and remediations

As of commit 0e35b9a (more details on the Dr. CI page):


  • 2/2 failures possibly* introduced in this PR
    • 2/2 non-CircleCI failure(s)


Contributor Author

skyw commented Aug 11, 2020

@skipIfUnsupportedMinOpsetVersion(13)
def test_qat_resnet_per_channel(self):
    # Quantize ResNet50 model
    x = Variable(torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0))
Contributor

Is explicit Variable still useful?

Contributor Author

@skyw skyw Aug 11, 2020

I guess not; I copied it from other tests.

It seems all the other tests use Variable. Should I just keep it consistent with the rest?

Collaborator

I would suggest removing the deprecated usage of Variable. The other tests could be updated in a separate PR.
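
For reference, Tensor and Variable were merged in PyTorch 0.4, so the wrapper is a no-op and the test input can be a plain tensor (BATCH_SIZE below is a placeholder for the test suite's own constant):

```python
import torch

BATCH_SIZE = 4  # placeholder; the test suite defines its own value

# Equivalent to Variable(torch.randn(...)) without the deprecated wrapper
x = torch.randn(BATCH_SIZE, 3, 224, 224).fill_(1.0)
```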

Contributor Author

Removed Variable in all fake quantization ONNX export tests.

@zou3519 zou3519 requested a review from houseroad August 11, 2020 14:04
@zou3519 zou3519 added the triaged label (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module) Aug 11, 2020
Contributor Author

skyw commented Aug 19, 2020

@houseroad, any more comments? Thanks.

@skyw skyw force-pushed the skyw/fake_quant_per_channel_onnx_export branch from 650f2e6 to 0a53c2d on August 20, 2020 16:10
@skyw skyw force-pushed the skyw/fake_quant_per_channel_onnx_export branch from cb9015d to 583b5ce on August 27, 2020 03:48

codecov bot commented Aug 27, 2020

Codecov Report

Merging #42835 (0e35b9a) into master (f68e5f1) will decrease coverage by 0.34%.
The diff coverage is 41.93%.

@@            Coverage Diff             @@
##           master   #42835      +/-   ##
==========================================
- Coverage   80.85%   80.50%   -0.35%     
==========================================
  Files        1931     1931              
  Lines      210934   210965      +31     
==========================================
- Hits       170544   169836     -708     
- Misses      40390    41129     +739     

Contributor Author

skyw commented Aug 27, 2020

ONNX Runtime now supports the opset 13 versions of QuantizeLinear and DequantizeLinear on master (microsoft/onnxruntime#4759). Verified that the Q/DQ nodes with axis generated by this change are valid ONNX ops.
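
A quick end-to-end check (an assumed workflow, not taken from this PR) is to run the exported model through ONNX Runtime:

```python
import numpy as np
import onnxruntime as ort

# Session creation fails if the opset 13 Q/DQ nodes were invalid
sess = ort.InferenceSession("quant_model.onnx")
inp = np.random.randn(16, 3, 224, 224).astype(np.float32)
out = sess.run(None, {sess.get_inputs()[0].name: inp})[0]
print(out.shape)  # expect (16, 1000) for ResNet-18
```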

Contributor Author

skyw commented Sep 4, 2020

TF2ONNX now supports this as well; see onnx/tensorflow-onnx#1081.

Pull in upstream master
Contributor

daquexian commented Nov 17, 2020

@houseroad @raghuramank100 Could you please review this PR? It is an important step toward deploying PyTorch QAT models on various backends like TensorRT and many mobile frameworks.

@skyw skyw force-pushed the skyw/fake_quant_per_channel_onnx_export branch from 44ac44b to da9c59b on January 19, 2021 20:09
@facebook-github-bot
Contributor

Hi @skyw!

Thank you for your pull request and welcome to our community. We require contributors to sign our Contributor License Agreement, and we don't seem to have you on file.

In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@skyw skyw force-pushed the skyw/fake_quant_per_channel_onnx_export branch from da9c59b to 1899521 on January 20, 2021 00:31
@skyw skyw force-pushed the skyw/fake_quant_per_channel_onnx_export branch from 1899521 to 154d556 on January 20, 2021 17:26
@facebook-github-bot
Contributor

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Facebook open source project. Thanks!

@skyw skyw force-pushed the skyw/fake_quant_per_channel_onnx_export branch from cdb0de3 to b37a9e6 on January 28, 2021 21:11
Contributor Author

skyw commented Feb 2, 2021

@spandantiwari, rebased onto the new master.

Contributor

@neginraoof neginraoof left a comment

Thanks for adding this op.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@SplitInfinity has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.


Member

@houseroad houseroad left a comment

Looks good.

@facebook-github-bot
Contributor

@SplitInfinity merged this pull request in 7363da7.

SplitInfinity pushed a commit that referenced this pull request Feb 18, 2021
Pull Request resolved: #42835

Reviewed By: houseroad

Differential Revision: D26293823

Pulled By: SplitInfinity

fbshipit-source-id: 300498a2e24b7731b12fa2fbdea4e73dde80e7ea
malfet pushed a commit that referenced this pull request Feb 18, 2021

Co-authored-by: Hao Wu <skyw@users.noreply.github.com>
Labels
cla signed · Merged · open source · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Development

Successfully merging this pull request may close these issues.

Export fake quantization function to ONNX