# test_quantized.py (forked from pytorch/pytorch)
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.jit
import numpy as np
import unittest
from common_utils import TestCase, run_tests, skipIfNotRegistered
def canonical(graph):
    """Return the string form of ``graph`` after the JIT canonicalize pass.

    Args:
        graph: the graph object handed to
            ``torch._C._jit_pass_canonicalize``.

    Returns:
        ``str(...)`` of whatever the canonicalize pass returns.
    """
    canonical_graph = torch._C._jit_pass_canonicalize(graph)
    return str(canonical_graph)
def _quantize(x, scale, zero_point, qmin=0, qmax=255):
"""Quantizes a numpy array."""
qx = np.round(x / scale + zero_point)
qx = np.clip(qx, qmin, qmax).astype(np.uint8)
return qx
def _dequantize(qx, scale, zero_point):
"""Dequantizes a numpy array."""
x = (qx.astype(np.float) - zero_point) * scale
return x
@skipIfNotRegistered("Relu_ENGINE_DNNLOWP",
                     "fbgemm-based Caffe2 ops are not linked")
class TestQuantized(TestCase):
    """Tests for the ``torch.ops.c10`` quantized ops.

    Each op here is fed, or returns, a ``(tensor, scale, zero_point)`` tuple.
    """

    def test_relu(self):
        """quantized_relu raises values below the zero point up to it."""
        quantized_input = (
            torch.tensor([4, 6, 1, 10], dtype=torch.uint8),
            0.01,
            5,
        )
        result = torch.ops.c10.quantized_relu(quantized_input)

        expected_values = torch.tensor([5, 6, 5, 10], dtype=torch.uint8)
        np.testing.assert_equal(result[0].numpy(), expected_values.numpy())
        # Scale and zero point are expected to come back unchanged.
        np.testing.assert_almost_equal(0.01, result[1])
        self.assertEqual(5, result[2])

    def test_quantize(self):
        """dequantize followed by quantize reproduces the original tuple."""
        quantized_input = (
            torch.tensor([4, 6, 1, 10], dtype=torch.uint8),
            0.01,
            5,
        )
        dequantized = torch.ops.c10.dequantize(quantized_input)
        np.testing.assert_almost_equal(
            dequantized.numpy(), [-0.01, 0.01, -0.04, 0.05])

        # Default arguments: only exercised, the result is not inspected.
        torch.ops.c10.quantize(dequantized)

        # Explicit scale / zero point must round-trip to the input.
        requantized = torch.ops.c10.quantize(
            dequantized, scale=0.01, zero_point=5)
        np.testing.assert_equal(
            requantized[0].numpy(), quantized_input[0].numpy())
        np.testing.assert_almost_equal(requantized[1], quantized_input[1])
        self.assertEqual(requantized[2], quantized_input[2])

    def test_script(self):
        """quantized_relu can be scripted and shows up in the graph."""
        @torch.jit.script
        def foo(x):
            # type: (Tuple[Tensor, float, int]) -> Tuple[Tensor, float, int]
            return torch.ops.c10.quantized_relu(x)

        # NOTE(review): the expected text is reproduced exactly as it appears
        # in this file. If the printed graph indents its node lines, this
        # literal needs to be regenerated -- confirm against a real run.
        self.assertExpectedInline(canonical(foo.graph), '''\
graph(%x : (Tensor, float, int)):
%1 : (Tensor, float, int) = c10::quantized_relu(%x)
return (%1)
''')
class TestQuantizedOps(unittest.TestCase):
    """Correctness tests for the quantized::relu and quantized::sum_relu ops.

    Each test builds a reference result with numpy (via ``_quantize``) and
    compares it with the integer representation returned by the op.
    """

    def test_qrelu(self):
        """Tests the correctness of the quantized::relu op."""
        relu = torch.ops.quantized.relu

        X = torch.arange(-5, 5, dtype=torch.float)
        scale = 2.0
        zero_point = 1
        qX = X.quantize_linear(scale=scale, zero_point=zero_point)

        # Reference: apply ReLU to the float values, then quantize.
        Y = X.numpy().copy()
        Y[Y < 0] = 0
        qY = _quantize(Y, scale, zero_point)

        qY_hat = relu(qX)
        np.testing.assert_equal(qY, qY_hat.int_repr())

    def test_qsumrelu_same_qparams(self):
        """Tests quantized::sum_relu when inputs and output share qparams."""
        sum_relu = torch.ops.quantized.sum_relu

        A = torch.arange(-25, 25, dtype=torch.float)
        B = torch.arange(-25, 25, dtype=torch.float)
        scale = 2.0
        zero_point = 127
        qA = A.quantize_linear(scale=scale, zero_point=zero_point)
        # Quantize B itself (the values equal A's, so results are unchanged).
        qB = B.quantize_linear(scale=scale, zero_point=zero_point)

        # Sum + ReLU ground truth
        C = (qA.dequantize() + qB.dequantize()).numpy()
        C[C < 0] = 0
        qC = _quantize(C, scale, zero_point)

        qC_hat = sum_relu(qA, qB, scale=scale, zero_point=zero_point)
        np.testing.assert_equal(qC, qC_hat.int_repr())

    def test_qsumrelu_different_qparams(self):
        """Tests quantized::sum_relu when A, B and the output use different qparams."""
        sum_relu = torch.ops.quantized.sum_relu

        A = torch.arange(-25, 25, dtype=torch.float)
        B = torch.arange(-25, 25, dtype=torch.float)
        scale_A = 3.0
        zero_point_A = 7
        scale_B = 5.0
        zero_point_B = 127
        scale_C = 0.5
        zero_point_C = 5
        qA = A.quantize_linear(scale=scale_A, zero_point=zero_point_A)
        # Quantize B itself (the values equal A's, so results are unchanged).
        qB = B.quantize_linear(scale=scale_B, zero_point=zero_point_B)

        # Sum + ReLU ground truth
        C = (qA.dequantize() + qB.dequantize()).numpy()
        C[C < 0] = 0
        qC = _quantize(C, scale_C, zero_point_C)

        qC_hat = sum_relu(qA, qB, scale=scale_C, zero_point=zero_point_C)
        np.testing.assert_equal(qC, qC_hat.int_repr())
# Script entry point: hand control to the shared test runner.
if __name__ == "__main__":
    run_tests()