[Torch] Support returning quantized weights and bias for BYOC use cases #9135

Merged
merged 9 commits into apache:main on Sep 28, 2021

Conversation

@masahi (Member) commented Sep 27, 2021

This addresses the issue discussed in https://discuss.tvm.apache.org/t/qnn-pytorch-byoc-full-integer-qnn-support/11127

PyTorch stores quantized weights in a custom packed format, so we cannot directly access the 8-bit weights as NumPy arrays. We use a PyTorch function to unpack the quantized weights into float32 arrays plus their quantization parameters.
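For reference, a minimal sketch (not the frontend's actual code) of how PyTorch's packed quantized weights can be unpacked into plain arrays and quantization parameters; the module construction here is purely illustrative:

```python
import torch

# Hypothetical example: a quantized Linear stores its weight in a packed,
# backend-specific format, but exposes accessors that unpack it.
qlinear = torch.nn.quantized.Linear(16, 8)

qweight = qlinear.weight()                 # unpacked quantized tensor
w_fp32 = qweight.dequantize().numpy()      # float32 values
w_int8 = qweight.int_repr().numpy()        # raw int8 values
scale = qweight.q_scale()                  # per-tensor scale
zero_point = qweight.q_zero_point()        # per-tensor zero point
```

Per-channel quantized weights expose `q_per_channel_scales()` / `q_per_channel_zero_points()` instead of the per-tensor accessors.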

By default, we use qnn.op.quantize(...) to recover int8 weights in a QNN graph, return float32 weights to users, and rely on the QNN lowering and the Relay constant folding pass to quantize weights at compile time. In BYOC use cases, however, we cannot apply the constant folding pass on a QNN graph.
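As a sketch of the default flow (`script_module` and the input shape below are placeholders), the quantize on the weights only folds away once the QNN ops are lowered to plain Relay, which a BYOC backend that pattern-matches `qnn.*` ops cannot do:

```python
from tvm import relay

# Default behavior: weights appear in the graph as qnn.quantize(fp32_const, ...).
mod, params = relay.frontend.from_pytorch(script_module, [("input", (1, 3, 224, 224))])

# In a normal compilation flow, lowering QNN ops (e.g. via CanonicalizeOps or
# inside relay.build) turns the quantize into plain Relay ops that FoldConstant
# can evaluate at compile time:
mod = relay.qnn.transform.CanonicalizeOps()(mod)
mod = relay.transform.FoldConstant()(mod)
# A BYOC backend that matches qnn.* patterns must keep the QNN ops intact,
# so it sees float32 weight constants instead of folded int8 ones.
```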

I added a new option, keep_quantized_weight, to quantize weights in the frontend using a function equivalent to qnn.op.quantize(...) operating on NumPy arrays. In hindsight, we should have taken this approach from the beginning. The old behavior is kept as the default for backward compatibility.
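A minimal sketch of the new option; the helper `quantize_np` below is a hypothetical stand-in for the frontend function, and the model and input shape are placeholders:

```python
import numpy as np
from tvm import relay

# keep_quantized_weight is the option added here (renamed from
# return_int8_weight during review).
mod, params = relay.frontend.from_pytorch(
    script_module, [("input", (1, 3, 224, 224))], keep_quantized_weight=True
)
# params now contain int8 weights and int32 biases, quantized in the frontend.

def quantize_np(weight_fp32, scale, zero_point, dtype="int8"):
    # Roughly what qnn.op.quantize does, applied directly to a NumPy array
    # (per-tensor case).
    info = np.iinfo(dtype)
    q = np.round(weight_fp32 / scale) + zero_point
    return np.clip(q, info.min, info.max).astype(dtype)
```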

cc @comaniac

@comaniac (Contributor) left a comment

LGTM. Just a nit.

Review comment on python/tvm/relay/frontend/pytorch.py (resolved)
@masahi merged commit 4905a8c into apache:main on Sep 28, 2021
AndrewZhaoLuo added a commit to AndrewZhaoLuo/tvm that referenced this pull request Sep 29, 2021
* main:
  Fix flaky NMS test by making sure scores are unique (apache#9140)
  [Relay] Merge analysis/context_analysis.cc and transforms/device_annotation.cc (apache#9038)
  [LLVM] Make changes needed for opaque pointers (apache#9138)
  Arm(R) Ethos(TM)-U NPU codegen integration (apache#8849)
  [CI] Split Integration tests out of first phase of pipeline (apache#9128)
  [Meta Schedule][M3b] Runner (apache#9111)
  Fix Google Mock differences between Ubuntu 18.04 and 16.04 (apache#9141)
  [TIR] add loop partition hint pragma (apache#9121)
  fix things (apache#9146)
  [Meta Schedule][M3a] SearchStrategy (apache#9132)
  [Frontend][PyTorch] support for quantized conv_transpose2d op (apache#9133)
  [UnitTest] Parametrized test_conv2d_int8_intrinsics (apache#9143)
  [OpenCL] Remove redundant visit statement in CodeGen. (apache#9144)
  [BYOC] support arbitrary input dims for add/mul/relu of dnnl c_src codegen (apache#9127)
  [Relay][ConvertLayout] Support for qnn.conv2d_transpose (apache#9139)
  add nn.global_avgpool to fq2i (apache#9137)
  [UnitTests] Enable minimum testing on Vulkan target in CI (apache#9093)
  [Torch] Support returning quantized weights and bias for BYOC use cases (apache#9135)
  [Relay] Prepare for new plan_devices.cc (part II) (apache#9130)
  [microTVM][Zephyr] Add MIMXRT1050 board support (apache#9068)
ylc pushed a commit to ylc/tvm that referenced this pull request Sep 29, 2021
[Torch] Support returning quantized weights and bias for BYOC use cases (apache#9135)

* [Torch] refactored the way bias quantization is done

* support returning 8bit weight

* add test

* add doc

* pylint

* return_int8_weight -> keep_quantized_weight

* fixed for dynamic linear case

* remove test function call

* simplifying
ylc pushed a commit to ylc/tvm that referenced this pull request Jan 13, 2022
[Torch] Support returning quantized weights and bias for BYOC use cases (apache#9135)

(same commit messages as listed above)