
[RFC] Unified quantization backend for x86 CPU platforms #83888

Closed
Xia-Weiwen opened this issue Aug 23, 2022 · 20 comments
Labels: oncall: quantization, triaged

Xia-Weiwen commented Aug 23, 2022

🚀 The feature, motivation and pitch

Description

Add a unified quantization backend, 'X86', for x86 CPU platforms and make it the default PyTorch quantization backend for x86 in place of FBGEMM.
It is implemented by automatically selecting between FBGEMM and ONEDNN kernels during weight prepacking.

Motivation

The ONEDNN quantization backend takes advantage of features of the latest Intel® CPU products and supports more fused ops. It has shown better performance than FBGEMM in many (but not all) cases.
From an API design point of view, exposing both the FBGEMM and ONEDNN backends to end users would not be user-friendly. We therefore propose a unified quantization backend named 'X86' that combines the strengths of both backends while keeping the API simple.
In the frontend, users use the x86 backend by default on x86 platforms. In the backend, we decide which kernel to run and hide the details from users. The selection between kernels is done automatically during weight prepacking, using static information.
Thus, the X86 backend can replace FBGEMM and offer better performance.
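
For illustration, below is a minimal sketch of the intended user-facing flow with FX graph mode quantization, assuming the PyTorch 2.0+ Python APIs (get_default_qconfig_mapping, prepare_fx, convert_fx). The kernel choice described above is invisible at this level:

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

torch.backends.quantized.engine = "x86"               # 'x86' replaces 'fbgemm' as the default qengine
qconfig_mapping = get_default_qconfig_mapping("x86")

prepared = prepare_fx(model, qconfig_mapping, example_inputs)
prepared(*example_inputs)                              # calibration
quantized = convert_fx(prepared)                       # FBGEMM/ONEDNN chosen at weight-prepack time
out = quantized(*example_inputs)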

Design philosophy

Auto kernel selection between FBGEMM and ONEDNN by simple heuristics without runtime info.

  • For non-fused ops, choose ONEDNN if it is always better; otherwise, use simple heuristics to make the selection.
  • For fused ops, in FX quantization the x86 QEngine declares the supported quant fusion patterns statically, just as the current fbgemm and onednn backends do. The patterns may include ones from the onednn backend that the fbgemm backend does not support. If a fused op (e.g., conv-add-relu) is always better on ONEDNN than non-fused conv + add-relu on FBGEMM, the fused op is exposed; otherwise, we expose conv + add-relu.
  • At runtime, the x86 QEngine implements the fused ops by choosing the right kernels, fbgemm or onednn. The decision can be made statically (e.g., if conv+add+relu is only available on onednn, then the onednn kernel is used).
  • For implementation, the X86 backend will follow the QEngine and backend_config API (see the sketch below).
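
A hedged sketch of how a backend_config could be passed into FX quantization is shown below. get_native_backend_config() is the existing config covering the fbgemm/x86 fusion patterns; whether the final x86 qengine ships its own dedicated config is an implementation detail of this RFC, so treat this as illustrative only:

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.backend_config import get_native_backend_config
from torch.ao.quantization.quantize_fx import prepare_fx

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
example_inputs = (torch.randn(1, 3, 32, 32),)

prepared = prepare_fx(
    model,
    get_default_qconfig_mapping("x86"),
    example_inputs,
    backend_config=get_native_backend_config(),  # declares the supported fusion patterns
)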

Proposed Heuristics

Rules for auto selection between FBGEMM and ONEDNN:

  • On platforms without VNNI (e.g., Skylake), FBGEMM is always used.
  • On platforms with VNNI (e.g., Cascade Lake, Ice Lake, and future platforms):
    • For linear, FBGEMM is always used.
    • For convolution, FBGEMM is used for depth-wise convolution whose groups > 100; otherwise, ONEDNN is used.
  • Currently, X86 supports the same fusion patterns as FBGEMM.

For the unified backend, selection occurs during weight prepacking, when ‘quantized::conv/linear_prepack’ is called. The prepack function checks hardware info (whether VNNI is available) and op parameters, and returns the proper prepacked weight object accordingly.
For example, for linear ops, or on platforms without VNNI, FBGEMM’s prepacked weight is always returned; for convolution with groups=1, ONEDNN’s prepacked weight is returned. At runtime, the corresponding kernels are then called automatically.
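
The selection logic can be paraphrased in Python as follows; this is for illustration only, since the real dispatch happens in C++ inside the prepack functions (see 'Additional context' below):

def choose_kernel(op: str, has_vnni: bool, groups: int = 1) -> str:
    """Return which backend's prepacked weight to create (illustrative only)."""
    if not has_vnni:                      # e.g., Skylake
        return "fbgemm"
    if op == "linear":                    # linear always stays on FBGEMM
        return "fbgemm"
    if op == "conv" and groups > 100:     # large-group depth-wise convolution
        return "fbgemm"
    return "onednn"

# Example: convolution with groups=1 on a VNNI-capable CPU -> "onednn"
print(choose_kernel("conv", has_vnni=True, groups=1))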

We have finished the implementation and run common models for benchmarking. The following table lists the throughput speedup ratios of the unified X86 backend vs. pure FBGEMM:

Device \ Ratio (Geomean) | 1 core/instance | 2 cores/instance | 4 cores/instance | 1 socket/instance
-- | -- | -- | -- | --
Intel(R) Xeon(R) Cascade Lake | 1.701 | 1.821 | 1.921 | 1.513
Intel(R) Xeon(R) Ice Lake | 1.767 | 1.810 | 2.042 | 1.346

(Table updated on Feb 21, 2023)

Note:

Performance data updated on Feb 5, 2023:
int8_benchmark_x86_vs_fbgemm_20230205.xlsx
(Using PyTorch nightly build on Feb 4, 2023, installed by pip)

Performance data updated on Feb 21, 2023:
int8_benchmark_x86_vs_fbgemm_20230221.xlsx
(Using PyTorch nightly build on Feb 20, 2023, installed by pip)

About qconfig
For compatibility, the new backend will use reduce_range=True to align with FBGEMM.
However, for accuracy, we hope to change it to reduce_range=False in the future.
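
As a quick sanity check, the default qconfig of the unified backend can be inspected from Python. This is a minimal sketch, assuming get_default_qconfig accepts "x86" as in PyTorch 2.0+:

import torch
from torch.ao.quantization import get_default_qconfig

qconfig = get_default_qconfig("x86")
activation_observer = qconfig.activation()   # instantiate the activation observer
print(type(activation_observer).__name__)    # expected: HistogramObserver
print(activation_observer.reduce_range)      # expected: True, matching FBGEMM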

Accuracy
We have run torchvision models to compare accuracy. Results show that FBGEMM and X86 backends give the same accuracy. For details, please see the worksheet:
torchvision_accuracy_comparison_fbgemm_vs_x86.xlsx

Plans

Original plans

The implementation is still pending on some optimizations of the ONEDNN backend that are not yet available in stock PyTorch, so the numbers shown above cannot currently be reproduced with stock PyTorch.
We therefore plan to take the steps below to finally unify the x86 qengines:

  1. Update ideep in stock PyTorch. Many optimizations are based on this ideep update.
  2. Optimize the performance of the ONEDNN backend. PR(s) will be submitted after the ideep update.
  3. Prepare a PR for the unified qengine.
  4. Publicize it to end users.

We hope all these changes will land before the 1.13 release.

Current status

  • Implementation is finished and the PRs have landed.
  • This feature is expected to be publicized in the PyTorch 2.0 release.
  • We are continuing to improve the onednn backend. The dispatching heuristics and supported fusion patterns might change in the future.

Alternatives

N/A

Additional context

Example of implementing conv_prepack for the unified X86 quantization backend:

template <int kSpatialDim = 2>
class QConvPackWeightInt8 final {
 public:
    // Public API to do conv prepack
 private:
  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> _run(
      Tensor weight,
      c10::optional<Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
      bool transpose) {
    auto& ctx = at::globalContext();
    if (ctx.qEngine() == at::QEngine::X86) {
#ifdef USE_FBGEMM
      // 'no_vnni' is a placeholder here for a runtime check that the CPU
      // lacks VNNI support (e.g., via cpuinfo). FBGEMM is chosen when VNNI
      // is absent or for depth-wise convolution with groups > 100.
      if (no_vnni || groups > 100) {
        return PackedConvWeight<kSpatialDim>::prepack(
            weight, bias, stride, padding, output_padding, dilation, groups, transpose);
      }
#endif
#if AT_MKLDNN_ENABLED()
      // Otherwise (VNNI available and not a large-group depth-wise conv),
      // return ONEDNN's prepacked weight.
      return PackedConvWeightsOnednn<kSpatialDim>::prepack(
          weight, bias, stride, padding, output_padding, dilation, groups, transpose);
#endif
    }


#ifdef USE_PYTORCH_QNNPACK
    if (ctx.qEngine() == at::QEngine::QNNPACK) {
      return PackedConvWeightsQnnp<kSpatialDim>::prepack(
          weight, bias, stride, padding, output_padding, dilation, groups,
          transpose);
    }
#endif

    TORCH_CHECK(
        false,
        "Didn't find engine for operation quantized::conv2d_prepack ",
        toString(ctx.qEngine()));
  }
};

The prepacking function returns a prepacked weight object belonging to either FBGEMM or ONEDNN. Calling prepacked_weight->run then automatically dispatches to the correct kernel.

The original code (class QConvPackWeightInt8) can be found in the PyTorch source for reference.

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @leslie-fang-intel

@Xia-Weiwen (Collaborator, Author)

Hi @jerryzh168, please review. Thanks!

jerryzh168 (Contributor) commented Aug 23, 2022

looks good to me, will post in internal groups to have more people review the doc.

@supriyar (Contributor)

@Xia-Weiwen are there any results comparing the numerics of the model with the unified backend vs fbgemm only backend? This would be useful to assess any internal impact on numerics for customers who are currently using fbgemm in production.

@supriyar (Contributor)

Another question is around the reduce_range setting https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/qconfig.py#L238. This is set to true for fbgemm backend before the observers are run. If the decision to use fbgemm vs onednn is made during the prepack function, what is the plan to ensure we set the reduce_range field correctly before calibration?

vkuzo (Contributor) commented Aug 23, 2022

The selection between different kernels is automatically done at quantization time with static information.

The machine where quantization is done is not always the same machine which runs inference. Thoughts on how this could be handled? Are we sure we have enough data at quantization time to do the backend selection? Do we need to relax this constraint?

@jerryzh168 (Contributor)

The selection between different kernels is automatically done at quantization time with static information.

The machine where quantization is done is not always the same machine which runs inference. Thoughts on how this could be handled? Are we sure we have enough data at quantization time to do the backend selection? Do we need to relax this constraint?

I think the selection is done every time we prepack the weight, so we will select the backend again when we load the model from disk before inference. Does that work in real use cases?
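
As an illustration of the point above, here is a rough Python sketch, assuming the FX graph mode quantization APIs of PyTorch 2.0+. Packed parameters are serialized in unpacked form and re-packed when the state dict is loaded, so the FBGEMM/ONEDNN choice is re-made on the machine that runs inference:

import torch
from torch.ao.quantization import get_default_qconfig_mapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

def build_quantized_model():
    # Tiny model quantized with the default x86 settings.
    m = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).eval()
    example_inputs = (torch.randn(1, 3, 32, 32),)
    prepared = prepare_fx(m, get_default_qconfig_mapping("x86"), example_inputs)
    prepared(*example_inputs)  # calibration
    return convert_fx(prepared)

torch.backends.quantized.engine = "x86"
qmodel = build_quantized_model()
torch.save(qmodel.state_dict(), "qmodel.pt")

# Possibly on a different x86 machine: rebuild the quantized module structure,
# then load the saved state dict. Deserializing the packed parameters calls
# quantized::conv_prepack again, so the kernel choice reflects this machine.
qmodel2 = build_quantized_model()
qmodel2.load_state_dict(torch.load("qmodel.pt"))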

@Xia-Weiwen (Collaborator, Author)

@Xia-Weiwen are there any results comparing the numerics of the model with the unified backend vs fbgemm only backend? This would be useful to assess any internal impact on numerics for customers who are currently using fbgemm in production.

Thanks for the question. We will provide some results later.

Xia-Weiwen (Collaborator, Author) commented Aug 24, 2022

Another question is around the reduce_range setting https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/qconfig.py#L238. This is set to true for fbgemm backend before the observers are run. If the decision to use fbgemm vs onednn is made during the prepack function, what is the plan to ensure we set the reduce_range field correctly before calibration?

Good question. We will figure out a solution.


Hi @supriyar. We have decided to use reduce_range=True. I have updated the RFC above; please see the 'About qconfig' part following the 'Proposed Heuristics' section. Thanks.

jgong5 (Collaborator) commented Aug 24, 2022

@Xia-Weiwen are there any results comparing the numerics of the model with the unified backend vs fbgemm only backend? This would be useful to assess any internal impact on numerics for customers who are currently using fbgemm in production.

Thanks for the question. We will provide some results later.

@supriyar FBGEMM and oneDNN backends should have the same numerics by design. In PyTorch UTs, the oneDNN backend also applies the same result-checking logic as FBGEMM. Meanwhile, we can provide more end-to-end model accuracy comparisons, as @Xia-Weiwen mentioned.

Another question is around the reduce_range setting https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/qconfig.py#L238. This is set to true for fbgemm backend before the observers are run. If the decision to use fbgemm vs onednn is made during the prepack function, what is the plan to ensure we set the reduce_range field correctly before calibration?

@supriyar The reduce_range setting is independent of whether the fbgemm or onednn kernel is selected by the proposed heuristics. Note that the heuristics only choose the kernel, not the qengine configuration. Given a reduce_range setting, the fbgemm and onednn kernels should produce the same result on a given CPU.

The selection between different kernels is automatically done at quantization time with static information.

The machine where quantization is done is not always the same machine which runs inference. Thoughts on how this could be handled? Are we sure we have enough data at quantization time to do the backend selection? Do we need to relax this constraint?

I think the selection is done every time we prepack the weight, so we will select the backend again when we load the model from disk before inference. Does that work in real use cases?

@jerryzh168 @vkuzo @supriyar A related question is what is the default reduce_range to use. reduce_range should be True for CPUs without VNNI support. The current FBGEMM qengine has reduce_range=True by default. I guess that is a conservative choice for the out-of-the-box experience, considering that not all CPUs support VNNI and that different CPUs might be used for calibration and deployment? With that, and also considering backward compatibility, we plan to set True by default for the unified backend too. Does that sound good to you?
The tradeoff is that users cannot get the best accuracy by default, especially considering that CPUs with VNNI are becoming mainstream. An alternative solution is to set reduce_range to False by default in this new unified backend, while recommending that users explicitly set it to True, or use the FBGEMM qengine, on CPUs without VNNI. Comments?
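
To make the alternative concrete, below is a hedged Python sketch of how a user could explicitly opt out of reduce_range on a VNNI-capable CPU. This is user-side configuration built from existing observer APIs, not a proposed change to the backend defaults:

import torch
from torch.ao.quantization import (
    QConfig,
    QConfigMapping,
    HistogramObserver,
    default_per_channel_weight_observer,
)

# Full-range activation observer; only appropriate when the deployment CPU
# is known to support VNNI.
full_range_qconfig = QConfig(
    activation=HistogramObserver.with_args(reduce_range=False),
    weight=default_per_channel_weight_observer,
)
qconfig_mapping = QConfigMapping().set_global(full_range_qconfig)
# qconfig_mapping can then be passed to prepare_fx in place of the default
# 'x86' mapping.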

vkuzo (Contributor) commented Aug 24, 2022

The selection between different kernels is automatically done at quantization time with static information.

The machine where quantization is done is not always the same machine which runs inference. Thoughts on how this could be handled? Are we sure we have enough data at quantization time to do the backend selection? Do we need to relax this constraint?

I think the selection is done every time we prepack the weight, so we will select the backend again when we load the model from disk before inference. Does that work in real use cases?

Should we modify the language of this RFC to make this clearer? "done at quantization time" sounds like "done during prepare/convert calls", which is different from "done during weight prepacking".

vkuzo (Contributor) commented Aug 24, 2022

A related question is what is the default reduce_range to use.

IMO, it's better for the default to be numerically correct and potentially slow compared to potentially numerically incorrect and fast. This would point to keeping reduce_range=True as default, and documenting the behavior clearly. Thoughts?

jerryzh168 (Contributor) commented Aug 24, 2022

yeah agree, I think the default should be True so that it is always correct, and we can provide documentation for users who know that they are using CPUs with VNNI support and want slightly better accuracy.

@ngimel added the 'oncall: quantization' label Aug 24, 2022
jgong5 (Collaborator) commented Aug 24, 2022

A related question is what is the default reduce_range to use.

IMO, it's better for the default to be numerically correct and potentially slow compared to potentially numerically incorrect and fast. This would point to keeping reduce_range=True as default, and documenting the behavior clearly. Thoughts?
yeah agree, I think the default should be True so that it is always correct, and we can provide documentation for users who know that they are using CPUs with VNNI support and want slightly better accuracy.

Thanks for the suggestions. Makes sense to me. BTW, I guess @vkuzo wanted to say "potentially worse accuracy", right? reduce_range doesn't impact the inference speed. :)

@Xia-Weiwen (Collaborator, Author)

The selection between different kernels is automatically done at quantization time with static information.

The machine where quantization is done is not always the same machine which runs inference. Thoughts on how this could be handled? Are we sure we have enough data at quantization time to do the backend selection? Do we need to relax this constraint?

I think the selection is done every time we prepack the weight, so we will select the backend again when we load the model from disk before inference. Does that work in real use cases?

Should we modify the language of this RFC to make this clearer? "done at quantization time" sounds like "done during prepare/convert calls", which is different from "done during weight prepacking".

Thanks for the suggestion. Updated.

vkuzo (Contributor) commented Aug 25, 2022

A related question is what is the default reduce_range to use.

IMO, it's better for the default to be numerically correct and potentially slow compared to potentially numerically incorrect and fast. This would point to keeping reduce_range=True as default, and documenting the behavior clearly. Thoughts?
yeah agree, I think the default should be True so that it is always correct, and we can provide documentation for users who know that they are using CPUs with VNNI support and want slightly better accuracy.

Thanks for the suggestions. Makes sense to me. BTW, I guess @vkuzo wanted to say "potentially worse accuracy", right? reduce_range doesn't impact the inference speed. :)

Yes, you are right, my mistake.

@Xia-Weiwen (Collaborator, Author)

@Xia-Weiwen are there any results comparing the numerics of the model with the unified backend vs fbgemm only backend? This would be useful to assess any internal impact on numerics for customers who are currently using fbgemm in production.

Thanks for the question. We will provide some results later.

Hi @supriyar We have run torchvision models with FBGEMM and X86 for accuracy comparison. Results show that they have the same accuracy. For details, please refer to the attached worksheet:
torchvision_accuracy_comparison_fbgemm_vs_x86.xlsx

pytorchmergebot pushed a commit that referenced this issue Sep 29, 2022
## Description

Implement unified quantization backend 'X86' for x86 platforms. It combines the advantages of FBGEMM and ONEDNN. It selects kernels during weight prepacking and hides the details from end users. It will be the default backend in place of FBGEMM.

For details, please refer to this RFC: [[RFC] Unified quantization backend for x86 CPU platforms](#83888)

## Validation
**Correctness**
Covered by UT

**Accuracy**
By running torchvision models on imagenet, no accuracy difference is found between FBGEMM and the unified X86 backend:
[torchvision_accuracy_comparison_fbgemm_vs_x86.xlsx](https://github.com/pytorch/pytorch/files/9598114/torchvision_accuracy_comparison_fbgemm_vs_x86.xlsx)

**Performance**
Depends on #84470 which improves performance.
For early PoC results, please refer to https://github.com/pytorch/pytorch/files/9399202/unified_qengine_poc_performance_bechmark.xlsx

With the two PRs combined, we collected some data on an Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz.
Method: run multiple instances with 4 cores per instance on a whole socket, using JeMalloc and Intel OMP.
Models/throughput | fbgemm | x86 | improvement
-- | -- | -- | --
wide_resnet101_2 | 173.5675 | 241.815 | 39.32%
resnext101_32x8d | 174.365 | 339.8175 | 94.89%
resnet50 | 573.155 | 1174.14 | 104.86%
vgg19_bn | 260.335 | 337.92 | 29.80%
vgg19 | 257.935 | 333.265 | 29.21%
inception_v3 | 601.1175 | 1309.33 | 117.82%
densenet161 | 296.645 | 435.5625 | 46.83%
mnasnet1_0 | 1216.7 | 4057.515 | 233.49%
squeezenet1_0 | 1220.085 | 5153.3875 | 322.38%
alexnet | 2294.91 | 2624.6375 | 14.37%
fbnetc_100 | 976.2825 | 3110.1825 | 218.57%
shufflenet_v2_x0_5 | 1555.76 | 3026.125 | 94.51%
spnasnet_100 | 1059.065 | 3502.0975 | 230.68%
pytorch-unet | 192.76 | 246.77 | 28.02%
acgan | 257.32 | 333.7325 | 29.70%
cgan | 7790.6925 | 7803.1025 | 0.16%
sgan | 257.565 | 338.8875 | 31.57%
se_resnet50 | 492.3725 | 916.5175 | 86.14%
vggm | 300.2875 | 316.2075 | 5.30%

Environment:
- PyTorch version: 1.13.0a0+gitcdd625b
- Is debug build: False
- CUDA used to build PyTorch: None
- ROCM used to build PyTorch: N/A
- OS: Ubuntu 20.04.3 LTS (x86_64)
- GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
- Clang version: Could not collect
- CMake version: version 3.22.5
- Libc version: glibc-2.31
- Python version: 3.9.12 (main, Jun  1 2022, 11:38:51)  [GCC 7.5.0] (64-bit runtime)
- Python platform: Linux-5.11.0-27-generic-x86_64-with-glibc2.31
- Is CUDA available: False
- CUDA runtime version: No CUDA
- GPU models and configuration: No CUDA
- Nvidia driver version: No CUDA
- cuDNN version: No CUDA
- HIP runtime version: N/A
- MIOpen runtime version: N/A
- Is XNNPACK available: True

Versions of relevant libraries:
- [pip3] intel-extension-for-pytorch==1.13.0+cpu
- [pip3] numpy==1.23.3
- [pip3] pytorch-widedeep==0.3.7
- [pip3] torch==1.13.0a0+git48b423b
- [pip3] torchvision==0.14.0a0+ebb68f3
- [conda] blas                      1.0                         mkl
- [conda] intel-extension-for-pytorch 1.13.0+cpu               pypi_0    pypi
- [conda] mkl                       2021.4.0           h06a4308_640
- [conda] mkl-include               2022.1.0                 pypi_0    pypi
- [conda] mkl-service               2.4.0            py39h7f8727e_0
- [conda] mkl-static                2022.1.0                 pypi_0    pypi
- [conda] mkl_fft                   1.3.1            py39hd3c417c_0
- [conda] mkl_random                1.2.2            py39h51133e4_0
- [conda] numpy                     1.23.3                   pypi_0    pypi
- [conda] numpy-base                1.22.3           py39hf524024_0
- [conda] torch                     1.13.0a0+git48b423b          pypi_0    pypi
- [conda] torchvision               0.14.0a0+ebb68f3          pypi_0    pypi

Pull Request resolved: #84329
Approved by: https://github.com/jerryzh168
mehtanirav pushed a commit that referenced this issue Oct 4, 2022
alvgaona pushed a commit to alvgaona/pytorch that referenced this issue Oct 11, 2022
pytorchmergebot pushed a commit that referenced this issue Jan 12, 2023
**Summary**
Make x86 the default quantization backend (qengine) for x86 CPU platforms.
X86 is a unified quantization backend combining the strengths of fbgemm and onednn. For more details, please see #83888

**Test plan**
python test/test_quantization.py

Pull Request resolved: #91235
Approved by: https://github.com/jgong5, https://github.com/XiaobingSuper, https://github.com/malfet
@andrewor14 added the 'triaged' label Aug 10, 2023
@andrewor14 (Contributor)

@jerryzh168 @Xia-Weiwen What is the status of this issue? Is it done?

jgong5 (Collaborator) commented Aug 14, 2023

@jerryzh168 @Xia-Weiwen What is the status of this issue? Is it done?

Yes, it is completed as a feature of 2.0.

@andrewor14 (Contributor)

Ok, I'm closing this then. Feel free to reopen if you think this is a mistake. Thanks.
