From 6e7d35ecd8d0ac23b2fc7e78bebe6ac49f6c4a0b Mon Sep 17 00:00:00 2001
From: zhangqi3
Date: Wed, 27 Apr 2022 17:08:31 +0800
Subject: [PATCH] [Fix] Fix data type bug in fake_quantize_per_channel_affine.

---
 mqbench/fake_quantize/adaround_quantizer.py | 6 +++++-
 mqbench/fake_quantize/dorefa.py             | 6 +++++-
 mqbench/fake_quantize/fixed.py              | 5 ++++-
 3 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/mqbench/fake_quantize/adaround_quantizer.py b/mqbench/fake_quantize/adaround_quantizer.py
index c5fbae3..bd22634 100644
--- a/mqbench/fake_quantize/adaround_quantizer.py
+++ b/mqbench/fake_quantize/adaround_quantizer.py
@@ -1,8 +1,11 @@
 import torch
 from torch.nn.parameter import Parameter
+
 from mqbench.fake_quantize.quantize_base import QuantizeBase
 
+_version_under_1100 = int(torch.__version__.split('.')[1]) < 10
+
 
 def _rectified_sigmoid(alpha, zeta, gamma):
     """Function to generate rounding mask.
 
@@ -105,7 +108,8 @@ def forward(self, X):
         if not self.adaround:
             if self.is_per_channel:
                 X = torch.fake_quantize_per_channel_affine(
-                    X, self.scale.data, self.zero_point.data.long(),
+                    X, self.scale,
+                    self.zero_point.long() if _version_under_1100 else self.zero_point,
                     self.ch_axis, self.quant_min, self.quant_max)
             else:
                 X = torch.fake_quantize_per_tensor_affine(
diff --git a/mqbench/fake_quantize/dorefa.py b/mqbench/fake_quantize/dorefa.py
index daabdd1..170b7fa 100644
--- a/mqbench/fake_quantize/dorefa.py
+++ b/mqbench/fake_quantize/dorefa.py
@@ -3,6 +3,8 @@
 from mqbench.fake_quantize.quantize_base import QuantizeBase
 
+_version_under_1100 = int(torch.__version__.split('.')[1]) < 10
+
 
 class DoReFaFakeQuantize(QuantizeBase):
     def __init__(self, observer, **observer_kwargs):
         super(DoReFaFakeQuantize, self).__init__(observer, **observer_kwargs)
@@ -26,7 +28,9 @@ def forward(self, X):
         if self.fake_quant_enabled[0] == 1:
             if self.is_per_channel:
                 X = torch.fake_quantize_per_channel_affine(
-                    X, self.scale, self.zero_point.long(), self.ch_axis, self.quant_min, self.quant_max)
+                    X, self.scale,
+                    self.zero_point.long() if _version_under_1100 else self.zero_point,
+                    self.ch_axis, self.quant_min, self.quant_max)
             else:
                 X = torch.fake_quantize_per_tensor_affine(
                     X, self.scale.item(), self.zero_point.item(), self.quant_min, self.quant_max)
diff --git a/mqbench/fake_quantize/fixed.py b/mqbench/fake_quantize/fixed.py
index 3f4ce0b..2557691 100644
--- a/mqbench/fake_quantize/fixed.py
+++ b/mqbench/fake_quantize/fixed.py
@@ -3,6 +3,8 @@
 from mqbench.fake_quantize.quantize_base import QuantizeBase
 
+_version_under_1100 = int(torch.__version__.split('.')[1]) < 10
+
 
 class FixedFakeQuantize(QuantizeBase):
     """This is actually torch.quantization.FakeQuantize.
     """
@@ -25,7 +27,8 @@ def forward(self, X):
         if self.fake_quant_enabled[0] == 1:
             if self.is_per_channel:
                 X = torch.fake_quantize_per_channel_affine(
-                    X, self.scale.data, self.zero_point.data,
+                    X, self.scale,
+                    self.zero_point.long() if _version_under_1100 else self.zero_point,
                     self.ch_axis, self.quant_min, self.quant_max)
             else:
                 X = torch.fake_quantize_per_tensor_affine(
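
Reviewer note: every hunk above applies the same version gate, so here is a minimal
standalone sketch for reproducing the dtype issue outside MQBench. The shapes, bounds,
and variable names below are illustrative assumptions; only
torch.fake_quantize_per_channel_affine and the _version_under_1100 check come from the
patch. The gate exists because PyTorch releases before 1.10 expect an int64 (long)
zero_point for this op, while 1.10 reworked the dtype check (float zero points became
acceptable there), hence the conditional .long().

import torch

# Same minor-version gate as the patch; note it assumes a "1.x" release
# string such as "1.9.1+cu111" when parsing torch.__version__.
_version_under_1100 = int(torch.__version__.split('.')[1]) < 10

x = torch.randn(4, 8)          # toy input with 4 channels along axis 0 (assumed shape)
scale = torch.full((4,), 0.1)  # per-channel scale, one entry per channel
zero_point = torch.zeros(4)    # float zero_point, as the patched code now passes through

# Pre-1.10: convert to int64; 1.10+: pass the float tensor unchanged.
zp = zero_point.long() if _version_under_1100 else zero_point
x_fq = torch.fake_quantize_per_channel_affine(x, scale, zp, 0, -128, 127)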