
Commit

[Fix] Fix data type bug in fake_quantize_per_channel_affine.
zhangqi3 authored and Tracin committed Apr 27, 2022
1 parent cf41d91 commit 6e7d35e
Showing 3 changed files with 14 additions and 3 deletions.
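For context, the three diffs below apply the same change: the unconditional cast of zero_point is replaced by a cast that is only performed on PyTorch versions before 1.10, which is where the data type handling of torch.fake_quantize_per_channel_affine differs. A minimal sketch of that pattern, assuming the helper name _fake_quant_per_channel (illustrative, not part of the commit):

import torch

# Version guard taken from the commit: True for 1.x releases older than 1.10.
_version_under_1100 = int(torch.__version__.split('.')[1]) < 10

def _fake_quant_per_channel(x, scale, zero_point, ch_axis, quant_min, quant_max):
    # Cast zero_point to int64 only on older PyTorch; newer releases get the tensor as-is,
    # mirroring the fix applied in adaround_quantizer.py, dorefa.py and fixed.py below.
    zp = zero_point.long() if _version_under_1100 else zero_point
    return torch.fake_quantize_per_channel_affine(x, scale, zp, ch_axis, quant_min, quant_max)

Note that the check parses only the minor component of torch.__version__, so it assumes a 1.x release; that assumption comes from the commit itself, not from this sketch.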
6 changes: 5 additions & 1 deletion mqbench/fake_quantize/adaround_quantizer.py
@@ -1,8 +1,11 @@
 import torch
 from torch.nn.parameter import Parameter
 
 from mqbench.fake_quantize.quantize_base import QuantizeBase
 
 
+_version_under_1100 = int(torch.__version__.split('.')[1]) < 10
+
 def _rectified_sigmoid(alpha, zeta, gamma):
     """Function to generate rounding mask.
@@ -105,7 +108,8 @@ def forward(self, X):
         if not self.adaround:
             if self.is_per_channel:
                 X = torch.fake_quantize_per_channel_affine(
-                    X, self.scale.data, self.zero_point.data.long(),
+                    X, self.scale,
+                    self.zero_point.long() if _version_under_1100 else self.zero_point,
                     self.ch_axis, self.quant_min, self.quant_max)
             else:
                 X = torch.fake_quantize_per_tensor_affine(
6 changes: 5 additions & 1 deletion mqbench/fake_quantize/dorefa.py
@@ -3,6 +3,8 @@
 from mqbench.fake_quantize.quantize_base import QuantizeBase
 
 
+_version_under_1100 = int(torch.__version__.split('.')[1]) < 10
+
 class DoReFaFakeQuantize(QuantizeBase):
     def __init__(self, observer, **observer_kwargs):
         super(DoReFaFakeQuantize, self).__init__(observer, **observer_kwargs)
@@ -26,7 +28,9 @@ def forward(self, X):
         if self.fake_quant_enabled[0] == 1:
             if self.is_per_channel:
                 X = torch.fake_quantize_per_channel_affine(
-                    X, self.scale, self.zero_point.long(), self.ch_axis, self.quant_min, self.quant_max)
+                    X, self.scale,
+                    self.zero_point.long() if _version_under_1100 else self.zero_point,
+                    self.ch_axis, self.quant_min, self.quant_max)
             else:
                 X = torch.fake_quantize_per_tensor_affine(
                     X, self.scale.item(), self.zero_point.item(), self.quant_min, self.quant_max)
5 changes: 4 additions & 1 deletion mqbench/fake_quantize/fixed.py
@@ -3,6 +3,8 @@
 from mqbench.fake_quantize.quantize_base import QuantizeBase
 
 
+_version_under_1100 = int(torch.__version__.split('.')[1]) < 10
+
 class FixedFakeQuantize(QuantizeBase):
     """This is actually torch.quantization.FakeQuantize.
     """
@@ -25,7 +27,8 @@ def forward(self, X):
         if self.fake_quant_enabled[0] == 1:
             if self.is_per_channel:
                 X = torch.fake_quantize_per_channel_affine(
-                    X, self.scale.data, self.zero_point.data,
+                    X, self.scale,
+                    self.zero_point.long() if _version_under_1100 else self.zero_point,
                     self.ch_axis, self.quant_min, self.quant_max)
             else:
                 X = torch.fake_quantize_per_tensor_affine(
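As an illustration of the failure mode the version guard avoids, a hedged reproduction sketch (tensor shapes, axis, and quantization range are made up for the example; on PyTorch builds that reject an int64 zero_point for this op, the first call raises a RuntimeError about the dtype while the second goes through):

import torch

x = torch.randn(4, 8)
scale = torch.full((4,), 0.1)                 # one scale per channel along axis 0
zp_long = torch.zeros(4, dtype=torch.long)    # the dtype the old code always produced
zp_int32 = torch.zeros(4, dtype=torch.int32)  # a dtype accepted on newer builds

try:
    torch.fake_quantize_per_channel_affine(x, scale, zp_long, 0, -128, 127)
except RuntimeError as err:
    print('int64 zero_point rejected:', err)

y = torch.fake_quantize_per_channel_affine(x, scale, zp_int32, 0, -128, 127)
print(y.shape)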
