forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_fake_quant.py
314 lines (278 loc) · 14.6 KB
/
test_fake_quant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
import torch
import torch.cuda
import torch.jit
import numpy as np
from hypothesis import given
from hypothesis import strategies as st
import torch.testing._internal.hypothesis_utils as hu
hu.assert_deadline_disabled()
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.quantization import FakeQuantize
from torch.quantization import default_observer, default_per_channel_weight_observer
import io
import unittest
# Reference method for fake quantize
def _fake_quantize_per_tensor_affine_reference(X, scale, zero_point, quant_min, quant_max):
res = (torch.clamp(torch.round(X.cpu() * (1.0 / scale) + zero_point), quant_min, quant_max) - zero_point) * scale
return res
# Reference method for the gradient of the fake quantize operator
def _fake_quantize_per_tensor_affine_grad_reference(dY, X, scale, zero_point, quant_min, quant_max):
Xq = torch.round(X.cpu() * (1.0 / scale) + zero_point)
mask = (Xq >= quant_min) * (Xq <= quant_max)
res = torch.zeros_like(dY.cpu())
res[mask] = dY.cpu()[mask]
return res
# Helper function used to simulate per-channel fake-quant against any axis
def _permute_to_axis_zero(X, axis):
new_axis_list = list(range(X.dim()))
new_axis_list[axis] = 0
new_axis_list[0] = axis
y = X.permute(tuple(new_axis_list))
return y, new_axis_list
# Reference method for fake quantize
def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    r"""Reference implementation of per-channel affine fake quantization.

    Slice ``i`` of ``X`` along ``axis`` is fake-quantized with
    ``per_channel_scale[i]`` and ``per_channel_zero_point[i]`` using
    ``(clamp(round(x * (1 / s) + zp), quant_min, quant_max) - zp) * s``.
    The result has the same shape and layout as ``X`` and is allocated with
    ``zeros_like`` on the channel-first view of ``X``.
    """
    channel_first, order = _permute_to_axis_zero(X, axis)
    out = torch.zeros_like(channel_first)
    num_channels = channel_first.size(0)
    for ch in range(num_channels):
        s = per_channel_scale[ch]
        zp = per_channel_zero_point[ch]
        rounded = torch.round(channel_first[ch] * (1.0 / s) + zp)
        out[ch] = (torch.clamp(rounded, quant_min, quant_max) - zp) * s
    # The swap permutation is self-inverse, so reuse it to restore the layout.
    return out.permute(tuple(order))
# Reference method for the gradient of the fake quantize operator
def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    r"""Reference gradient of per-channel affine fake quantization w.r.t. ``X``.

    For slice ``i`` along ``axis``, the rounded value
    ``round(x * (1 / per_channel_scale[i]) + per_channel_zero_point[i])`` is
    computed. ``dY`` is passed through where that value lies inside
    ``[quant_min, quant_max]`` and is zero elsewhere. The result is built with
    ``zeros_like(dY)``.
    """
    channel_first, order = _permute_to_axis_zero(X, axis)
    rounded = torch.zeros_like(channel_first)
    for ch in range(channel_first.size(0)):
        rounded[ch] = torch.round(channel_first[ch] * (1.0 / per_channel_scale[ch]) + per_channel_zero_point[ch])
    # Undo the axis swap so the mask lines up with dY.
    rounded = rounded.permute(tuple(order))
    in_range = (rounded >= quant_min) & (rounded <= quant_max)
    grad = torch.zeros_like(dY)
    grad[in_range] = dY[in_range]
    return grad
def to_tensor(X, device):
    r"""Return a new float32 tensor built from array-like ``X`` on ``device``.

    ``device`` is anything accepted by ``torch.device`` (e.g. ``'cpu'``).
    """
    target_device = torch.device(device)
    return torch.tensor(X, dtype=torch.float32, device=target_device)
# Seed handed to np.random.seed at the start of each hypothesis-driven test.
NP_RANDOM_SEED: int = 19
# Used as both rtol and atol in the np.testing.assert_allclose comparisons.
tolerance: float = 1e-6
class TestFakeQuantizePerTensor(TestCase):
    r"""Tests for per-tensor affine fake quantization: the
    ``torch.fake_quantize_per_tensor_affine`` op and the ``FakeQuantize`` module.
    """

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    def test_forward_per_tensor(self, device, X):
        r"""Tests the forward path of the FakeQuantizePerTensorAffine op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        Y = _fake_quantize_per_tensor_affine_reference(X.cpu(), scale, zero_point, quant_min, quant_max)
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skip("temporarily disable the test")
    def test_backward_per_tensor(self, device, X):
        r"""Tests the backward method.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        dout = torch.rand(X.shape, dtype=torch.float).to(device)
        dX = _fake_quantize_per_tensor_affine_grad_reference(
            dout, X, scale, zero_point, quant_min, quant_max)
        Y_prime.backward(dout)
        np.testing.assert_allclose(dX.cpu(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       qparams=hu.qparams(dtypes=torch.quint8)))
    # https://github.com/pytorch/pytorch/issues/30604
    @unittest.skip("temporarily disable the test")
    def test_numerical_consistency_per_tensor(self, device, X):
        r"""Comparing numerical consistency between CPU quantize/dequantize op and the CPU fake quantize op
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        # quantize_per_tensor and dequantize are only implemented in CPU
        Y = torch.dequantize(torch.quantize_per_tensor(X.cpu(), scale, zero_point, torch_type))
        Y_prime = torch.fake_quantize_per_tensor_affine(
            X, scale, zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.tensor(shapes=hu.array_shapes(1, 5,),
                       qparams=hu.qparams(dtypes=[torch.quint8])),
           )
    def test_fq_module(self, device, X):
        r"""Checks the default ``FakeQuantize`` module's forward and backward
        against the per-tensor reference, using the qparams the module computed.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = torch.quantization.default_fake_quant().to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_tensor_affine_reference(X, fq_module.scale, fq_module.zero_point, quant_min, quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(), Y_prime.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

        # Test backward
        dout = torch.rand(X.shape, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_tensor_affine_grad_reference(dout, X, fq_module.scale, fq_module.zero_point, quant_min, quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    def test_fq_serializable(self):
        r"""Round-trips a ``FakeQuantize`` state_dict through ``torch.save`` /
        ``torch.load`` and checks a freshly loaded module matches the original.
        """
        observer = default_observer
        quant_min = 0
        quant_max = 255
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor([-5, -3.5, -2, 0, 3, 5, 7], dtype=torch.float32)
        # Run one forward pass so the observer populates scale / zero_point.
        fq_module(X)
        state_dict = fq_module.state_dict()
        self.assertEqual(state_dict['scale'], 0.094488)
        self.assertEqual(state_dict['zero_point'], 53)
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        loaded_fq_module = FakeQuantize(observer, quant_min, quant_max)
        loaded_fq_module.load_state_dict(loaded_dict)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_fq_module.state_dict()[key])

        self.assertEqual(loaded_fq_module.calculate_qparams(), fq_module.calculate_qparams())

    def test_fake_quant_control(self):
        r"""Checks that ``disable/enable_fake_quant`` toggle whether the output
        is quantized, and ``disable/enable_observer`` toggle whether scale and
        zero_point are updated.
        """
        torch.manual_seed(42)
        X = torch.rand(20, 10, dtype=torch.float32)
        fq_module = torch.quantization.default_fake_quant()
        # Output of fake quant is not identical to input
        Y = fq_module(X)
        self.assertNotEqual(Y, X)
        torch.quantization.disable_fake_quant(fq_module)
        X = torch.rand(20, 10, dtype=torch.float32)
        Y = fq_module(X)
        # Fake quant is disabled, output is identical to input
        self.assertEqual(Y, X)
        scale = fq_module.scale
        zero_point = fq_module.zero_point
        torch.quantization.disable_observer(fq_module)
        torch.quantization.enable_fake_quant(fq_module)
        X = 10.0 * torch.rand(20, 10, dtype=torch.float32) - 5.0
        Y = fq_module(X)
        self.assertNotEqual(Y, X)
        # Observer is disabled, scale and zero-point do not change
        self.assertEqual(fq_module.scale, scale)
        self.assertEqual(fq_module.zero_point, zero_point)
        torch.quantization.enable_observer(fq_module)
        Y = fq_module(X)
        self.assertNotEqual(Y, X)
        # Observer is enabled, scale and zero-point are different
        self.assertNotEqual(fq_module.scale, scale)
        self.assertNotEqual(fq_module.zero_point, zero_point)
class TestFakeQuantizePerChannel(TestCase):
    r"""Tests for per-channel affine fake quantization: the
    ``torch.fake_quantize_per_channel_affine`` op and ``FakeQuantize`` with a
    per-channel observer.
    """

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_forward_per_channel(self, device, X):
        r"""Tests the forward path of the FakeQuantizePerChannelAffine op.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64, device=device)
        Y = _fake_quantize_per_channel_affine_reference(X.cpu(), scale.cpu(), zero_point.cpu(), axis, quant_min, quant_max)
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    def test_backward_per_channel(self, device, X):
        r"""Tests the backward method.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64, device=device)
        X.requires_grad_()
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        dout = torch.rand(X.shape, dtype=torch.float).to(device)
        dX = _fake_quantize_per_channel_affine_grad_reference(
            dout, X, scale, zero_point, axis, quant_min, quant_max)
        Y_prime.backward(dout)
        np.testing.assert_allclose(dX.cpu().detach().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(1, 5,),
                                   qparams=hu.qparams(dtypes=torch.quint8)))
    @unittest.skip("temporarily disable the test")
    def test_numerical_consistency_per_channel(self, device, X):
        r"""Comparing numerical consistency between CPU quantize/dequantize op and the CPU fake quantize op
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        scale = to_tensor(scale, device)
        zero_point = torch.tensor(zero_point).to(dtype=torch.int64, device=device)
        # quantize_per_channel and dequantize are only implemented in CPU
        Y = torch.dequantize(torch.quantize_per_channel(X.cpu(), scale.cpu(), zero_point.cpu(), axis, torch_type))
        Y_prime = torch.fake_quantize_per_channel_affine(
            X, scale, zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y, Y_prime.cpu(), rtol=tolerance, atol=tolerance)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']),
           X=hu.per_channel_tensor(shapes=hu.array_shapes(2, 5,),
                                   qparams=hu.qparams(dtypes=torch.qint8)))
    def test_fq_module(self, device, X):
        r"""Checks a per-channel ``FakeQuantize`` module's forward and backward
        against the per-channel reference, using the qparams the module computed.
        """
        np.random.seed(NP_RANDOM_SEED)
        X, (scale, zero_point, axis, torch_type) = X
        quant_min = torch.iinfo(torch_type).min
        quant_max = torch.iinfo(torch_type).max

        X = to_tensor(X, device)
        X.requires_grad_()
        fq_module = FakeQuantize(default_per_channel_weight_observer, quant_min, quant_max, ch_axis=axis).to(device)
        Y_prime = fq_module(X)
        assert fq_module.scale is not None
        assert fq_module.zero_point is not None
        Y = _fake_quantize_per_channel_affine_reference(X, fq_module.scale,
                                                        fq_module.zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(Y.cpu().detach().numpy(), Y_prime.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

        # Test backward
        dout = torch.rand(X.shape, dtype=torch.float, device=device)
        Y_prime.backward(dout)
        dX = _fake_quantize_per_channel_affine_grad_reference(dout, X, fq_module.scale,
                                                              fq_module.zero_point, axis, quant_min, quant_max)
        np.testing.assert_allclose(dX.cpu().numpy(), X.grad.cpu().detach().numpy(), rtol=tolerance, atol=tolerance)

    def test_fq_serializable(self):
        r"""Checks a per-channel ``FakeQuantize`` state_dict survives a
        ``torch.save`` / ``torch.load`` round trip unchanged.
        """
        observer = default_per_channel_weight_observer
        quant_min = -128
        quant_max = 127
        fq_module = FakeQuantize(observer, quant_min, quant_max)
        X = torch.tensor([[-5, -3.5, -2, 0, 3, 5, 7], [1, 3, 2, 5, 6.5, 8, 10]], dtype=torch.float32)
        # Run one forward pass so the observer populates scale / zero_point.
        fq_module(X)
        state_dict = fq_module.state_dict()
        self.assertEqual(state_dict['scale'], [0.054902, 0.078431])
        self.assertEqual(state_dict['zero_point'], [0, 0])
        b = io.BytesIO()
        torch.save(state_dict, b)
        b.seek(0)
        loaded_dict = torch.load(b)
        for key in state_dict:
            self.assertEqual(state_dict[key], loaded_dict[key])
# Run this module's test cases through run_tests() when executed as a script.
if __name__ == '__main__':
    run_tests()