forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_quantized.py
1701 lines (1526 loc) · 79.7 KB
/
test_quantized.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import numpy as np
import unittest
import torch
import torch.jit
import torch.nn.functional as F
from torch.nn.modules.utils import _pair
from hypothesis import settings, HealthCheck
from hypothesis import assume, given
from hypothesis import strategies as st
import hypothesis_utils as hu
from hypothesis_utils import no_deadline
from common_utils import TEST_WITH_UBSAN, TestCase, run_tests, IS_PPC
from common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
override_quantized_engine
# Make sure we won't have overflows from vpmaddubsw instruction used in FBGEMM.
# On the current Intel x86 architecture, we need to utilize vpmaddubsw instruction
# for the 8-bit int multiplication. This instruction vertically multiplies each
# unsigned 8-bit integer from a with the corresponding signed 8-bit integer from
# b, producing intermediate signed 16-bit integers. This function modifies the
# weights to eliminate the overflow on the signed 16-bit integers.
def avoid_vpmaddubsw_overflow_linear(
    batch_size, input_channels, output_channels, X, X_min, X_max, W, W_min, W_max
):
    """Adjust ``W`` in place so no adjacent-channel pair sum overflows int16.

    For every row ``i`` of ``X`` and row ``j`` of ``W``, input channels are
    taken in pairs ``(k, k + 1)``. With ``x = X - X_min`` and
    ``w = W - 128 - W_min``, the pair sum ``x0 * w0 + x1 * w1`` must lie in
    ``[-2**15, 2**15 - 1]``. When it does not, ``W[j, k + 1]`` is moved to the
    closest integer that brings the sum back into range. If ``x1`` is zero,
    ``W[j, k]`` is moved instead, because ``w1`` then has no effect. A trailing
    odd channel is not paired and is never modified.

    Args:
        batch_size, input_channels, output_channels: loop extents used to
            index ``X[i, k]`` and ``W[j, k]``.
        X: 2-D array indexed ``X[batch, input_channel]``.
        X_min: offset subtracted from ``X``.
        X_max: unused.
        W: 2-D array indexed ``W[output_channel, input_channel]``; modified
            in place.
        W_min: offset applied to ``W`` (together with 128).
        W_max: unused.

    Raises:
        AssertionError: if an overflow remains after adjustment. This can
            happen because an adjustment made for one row of ``X`` may be
            overwritten while processing a later row.
    """
    int16_min = -(1 << 15)
    int16_max = (1 << 15) - 1

    def _nearest_feasible(rhs, coeff, need_at_least):
        # Return the integer w closest to rhs / coeff such that coeff * w >= rhs
        # (need_at_least) or coeff * w <= rhs (otherwise). Plain int()
        # truncates toward zero, which lands on the infeasible side whenever
        # rhs / coeff has the "wrong" sign, so round in the feasible direction.
        target = float(rhs) / coeff
        round_up = need_at_least == (coeff > 0)
        return int(np.ceil(target) if round_up else np.floor(target))

    for i, j in np.ndindex((batch_size, output_channels)):
        for k in range(0, input_channels // 2 * 2, 2):
            x0 = X[i, k] - X_min
            x1 = X[i, k + 1] - X_min
            w0 = W[j, k] - 128 - W_min
            w1 = W[j, k + 1] - 128 - W_min
            pair_sum = x0 * w0 + x1 * w1
            if int16_min <= pair_sum <= int16_max:
                continue
            too_small = pair_sum < int16_min
            bound = int16_min if too_small else int16_max
            if x1 != 0:
                # Solve x0 * w0 + x1 * w1 == bound for w1.
                new_w1 = _nearest_feasible(bound - float(x0) * w0, x1, too_small)
                W[j, k + 1] = new_w1 + 128 + W_min
            else:
                # x1 == 0 makes w1 irrelevant (and x0 != 0, since the sum is
                # non-zero), so move w0 instead of dividing by zero.
                new_w0 = _nearest_feasible(bound, x0, too_small)
                W[j, k] = new_w0 + 128 + W_min
    # Go through the same loop again to double check we don't have any overflow
    for i, j in np.ndindex((batch_size, output_channels)):
        for k in range(0, input_channels // 2 * 2, 2):
            x0 = X[i, k] - X_min
            x1 = X[i, k + 1] - X_min
            w0 = W[j, k] - 128 - W_min
            w1 = W[j, k + 1] - 128 - W_min
            assert int16_min <= x0 * w0 + x1 * w1 <= int16_max
def qlinear_ref(X_q, X_scale, X_zp, W_q, W_scale, W_zp, b_q, Y_scale, Y_zp):
    """Reference quantized Linear operator.

    All leading dimensions of ``X_q`` are flattened into one batch dimension.
    The int32 accumulator is computed via the zero-point expansion of
    ``(X_q - X_zp) @ (W_q - W_zp).T``:

        X_q @ W_q.T - W_zp * rowsum(X_q) - X_zp * rowsum(W_q) + K * X_zp * W_zp

    where ``K`` is the number of input channels. ``b_q`` is added when given.
    The accumulator is then handed to ``_quantize`` with scale
    ``Y_scale / (X_scale * W_scale)`` and zero point ``Y_zp``.
    """
    X_q = np.reshape(X_q, (-1, X_q.shape[-1]))
    assert X_q.ndim == 2
    _, input_channels = X_q.shape

    # Zero-point correction terms: per-row sums of X and of W.
    x_row_sums = X_q.sum(axis=1).astype(np.int32).reshape((-1, 1))
    w_row_sums = W_q.sum(axis=1).astype(np.int32).reshape((1, -1))

    accumulator = np.matmul(X_q.astype(np.int32), W_q.astype(np.int32).T)
    accumulator = (accumulator
                   - W_zp * x_row_sums
                   - X_zp * w_row_sums
                   + input_channels * X_zp * W_zp)
    if b_q is not None:
        accumulator += b_q
    return _quantize(accumulator, Y_scale / (X_scale * W_scale), Y_zp)
# Computes the output shape given pooling parameters.
def pool_output_shape(input_size, kernel_size, padding, stride,
                      dilation, ceiling_mode=False):
    """Computes the output shape given pooling parameters.

    Follows the output-size rule of PyTorch's 2-D pooling ops:

        floor_or_ceil((in + 2*pad - dilation*(kernel - 1) - 1) / stride) + 1

    In ceil mode, a final window that would start in the right-hand padding
    is dropped.

    Args:
        input_size: spatial extent of the input along one dimension.
        kernel_size: pooling window size.
        padding: implicit padding added on each side.
        stride: window stride; ``None`` means ``kernel_size``.
        dilation: spacing between window elements (1 = no dilation).
        ceiling_mode: use ceil instead of floor when computing the size.

    Returns:
        The number of output elements along that dimension.
    """
    if stride is None:
        stride = kernel_size
    output_size = (
        (input_size + 2 * padding - dilation * (kernel_size - 1) - 1
         + (stride - 1 if ceiling_mode else 0)) // stride + 1)
    # In ceil mode the last window must start inside the input (or the left
    # padding). Otherwise it lies entirely in the right padding and is removed.
    if ceiling_mode and (output_size - 1) * stride >= input_size + padding:
        output_size -= 1
    return output_size
class TestQuantizedOps(TestCase):
"""Tests the correctness of the quantized::relu op."""
@given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                   qparams=hu.qparams()))
def test_qrelu(self, X):
    """Tests the correctness of the quantized::relu op.

    The out-of-place and in-place relu entry points are compared against a
    reference built by zeroing negatives in float and quantizing the result
    with the input's qparams.
    """
    X, (scale, zero_point, torch_type) = X
    qparams = dict(scale=scale, zero_point=zero_point, dtype=torch_type)

    Y = X.copy()
    Y[Y < 0] = 0
    qY = torch.quantize_per_tensor(torch.from_numpy(Y), **qparams)
    qX = torch.quantize_per_tensor(torch.from_numpy(X), **qparams)

    for name, relu in (('native', torch.relu),
                       ('nn.functional', torch.nn.functional.relu)):
        self.assertEqual(qY, relu(qX), message="{} relu failed".format(name))

    for name, relu_ in (('inplace native', torch.relu_),
                        ('inplace nn.functional', torch.nn.functional.relu_)):
        qY_hat = qX.clone()
        relu_(qY_hat)
        self.assertEqual(qY, qY_hat, message="{} relu failed".format(name))
# Tests the correctness of the quantized::relu6 op.
@given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                   qparams=hu.qparams()))
def test_qrelu6(self, X):
    """Tests the correctness of the quantized::relu6 op.

    Both the raw op and the ``nn.quantized.ReLU6`` module are compared against
    a reference that clamps the float input to [0, 6] and quantizes it with
    the input's qparams.
    """
    X, (scale, zero_point, torch_type) = X
    qparams = dict(scale=scale, zero_point=zero_point, dtype=torch_type)

    Y = X.copy()
    Y[Y < 0] = 0
    Y[Y > 6.0] = 6.0
    qY = torch.quantize_per_tensor(torch.from_numpy(Y), **qparams)
    qX = torch.quantize_per_tensor(torch.from_numpy(X), **qparams)

    candidates = (
        ('ops.quantized', torch.ops.quantized.relu6),
        ('module', torch.nn.quantized.ReLU6()),
    )
    for name, relu6 in candidates:
        self.assertEqual(qY, relu6(qX), message="{} relu failed".format(name))
# Tests the correctness of the scalar addition.
@no_deadline
@given(A=hu.tensor(shapes=hu.array_shapes(1, 4, 1, 5),
                   elements=st.floats(-1e6, 1e6, allow_nan=False),
                   qparams=hu.qparams()),
       b=st.floats(-1e6, 1e6, allow_nan=False, allow_infinity=False))
def test_qadd_scalar_relu(self, A, b):
    """Tests the correctness of the scalar addition.

    Builds the float reference by adding ``b`` (rounded to a multiple of the
    input scale) to the dequantized input. That reference is re-quantized with
    the qparams each op chose for its own output, so only values are compared.
    The same check is applied to ``add_scalar_relu`` after zeroing negatives.
    """
    A, (scale, zero_point, dtype) = A
    qA = torch.quantize_per_tensor(torch.from_numpy(A.astype(np.float32)),
                                   scale, zero_point, dtype)

    C = qA.dequantize() + round(b / scale) * scale
    C_relu = C.clone()
    C_relu[C_relu < 0] = 0

    C_hat = torch.ops.quantized.add_scalar(qA, b)
    C_relu_hat = torch.ops.quantized.add_scalar_relu(qA, b)

    C_ref = torch.quantize_per_tensor(
        C, C_hat.q_scale(), C_hat.q_zero_point(), dtype)
    C_relu_ref = torch.quantize_per_tensor(
        C_relu, C_relu_hat.q_scale(), C_relu_hat.q_zero_point(), dtype)

    self.assertEqual(C_ref.dequantize(), C_hat.dequantize(),
                     message="Scalar add results don't match:{} vs {}".format(
                         C_ref.dequantize(), C_hat.dequantize()))
    self.assertEqual(C_relu_ref.dequantize(), C_relu_hat.dequantize(),
                     message="Scalar add relu results don't match:{} vs {}".format(
                         C_relu_ref.dequantize(), C_relu_hat.dequantize()))
# Tests the correctness of the add and add_relu op.
def test_qadd_relu_same_qparams(self):
    """Tests the correctness of the add and add_relu op.

    Both inputs and the output share scale 2.0 / zero point 127. For each
    dtype, ``add``/``add_relu`` and their ``out=`` variants are compared
    against the float sum of the dequantized inputs passed through
    ``_quantize``. For add_relu, negatives are zeroed first.
    """
    add = torch.ops.quantized.add
    add_out = torch.ops.quantized.add_out
    add_relu = torch.ops.quantized.add_relu
    add_relu_out = torch.ops.quantized.add_relu_out
    np_dtype = {
        torch.quint8: np.uint8,
        torch.qint8: np.int8,
        torch.qint32: np.int32,
    }
    scale = 2.0
    zero_point = 127
    # NB: This is a strange size so that we exercise both the vectorized
    # implementation (64-element chunks at a time) as well as the scalar
    # implementation
    A = torch.arange(-128, 130, dtype=torch.float)
    B = torch.arange(-128, 130, dtype=torch.float)

    for dtype in (torch.quint8, torch.qint8, torch.qint32):
        qA = torch.quantize_per_tensor(A, scale=scale, zero_point=zero_point,
                                       dtype=dtype)
        qB = torch.quantize_per_tensor(B, scale=scale, zero_point=zero_point,
                                       dtype=dtype)
        C = (qA.dequantize() + qB.dequantize()).numpy()
        Crelu = C.copy()
        Crelu[C < 0] = 0

        checks = (
            (C, add, add_out,
             "Quantized addition failed.", "Add.out failed"),
            (Crelu, add_relu, add_relu_out,
             "Quantized addition with ReLU failed.", "AddReLU.out failed"),
        )
        for expected, op, op_out, op_msg, out_msg in checks:
            q_expected = _quantize(expected, scale, zero_point,
                                   dtype=np_dtype[dtype])
            q_hat = op(qA, qB, scale=scale, zero_point=zero_point)
            np.testing.assert_equal(q_expected, q_hat.int_repr(), op_msg)
            q_out_hat = torch._empty_affine_quantized(q_expected.shape,
                                                      scale=scale,
                                                      zero_point=zero_point,
                                                      dtype=dtype)
            op_out(qA, qB, out=q_out_hat)
            self.assertEqual(q_hat, q_out_hat, message=out_msg)
# Tests the correctness of the add and add_relu op.
def test_qadd_relu_different_qparams(self):
    """Tests the correctness of the add and add_relu op.

    Inputs use different qparams (A: 3.0 / 7, B: 5.0 / 127) and the output
    uses 0.5 / 5. For each dtype, ``add``/``add_relu`` and their ``out=``
    variants are compared against the float sum of the dequantized inputs
    passed through ``_quantize``. For add_relu, negatives are zeroed first.
    """
    add = torch.ops.quantized.add
    add_out = torch.ops.quantized.add_out
    add_relu = torch.ops.quantized.add_relu
    add_relu_out = torch.ops.quantized.add_relu_out
    np_dtype = {
        torch.quint8: np.uint8,
        torch.qint8: np.int8,
        torch.qint32: np.int32,
    }
    scale_A, zero_point_A = 3.0, 7
    scale_B, zero_point_B = 5.0, 127
    scale_C, zero_point_C = 0.5, 5
    # NB: This is a strange size so that we exercise both the vectorized
    # implementation (64-element chunks at a time) as well as the scalar
    # implementation
    A = torch.arange(-128, 130, dtype=torch.float)
    B = torch.arange(-128, 130, dtype=torch.float)

    for dtype in (torch.quint8, torch.qint8, torch.qint32):
        qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point_A,
                                       dtype=dtype)
        qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point_B,
                                       dtype=dtype)
        C = (qA.dequantize() + qB.dequantize()).numpy()
        Crelu = C.copy()
        Crelu[C < 0] = 0

        checks = (
            (C, add, add_out,
             "Quantized addition failed.", "Add.out failed"),
            (Crelu, add_relu, add_relu_out,
             "Quantized addition with ReLU failed.", "AddReLU.out failed"),
        )
        for expected, op, op_out, op_msg, out_msg in checks:
            q_expected = _quantize(expected, scale_C, zero_point_C,
                                   dtype=np_dtype[dtype])
            q_hat = op(qA, qB, scale=scale_C, zero_point=zero_point_C)
            np.testing.assert_equal(q_expected, q_hat.int_repr(), op_msg)
            q_out_hat = torch._empty_affine_quantized(q_expected.shape,
                                                      scale=scale_C,
                                                      zero_point=zero_point_C,
                                                      dtype=dtype)
            op_out(qA, qB, out=q_out_hat)
            self.assertEqual(q_hat, q_out_hat, message=out_msg)
# Tests the correctness of the mul and mul_relu op.
def test_qmul_relu_same_qparams(self):
    """Tests the correctness of the mul and mul_relu op.

    Inputs and output share scale 2.0 / zero point 127 (quint8). ``mul`` and
    ``mul_relu`` and their ``out=`` variants are compared against the float
    product of the dequantized inputs passed through ``_quantize``.
    ``mul_scalar`` and ``mul_scalar_relu`` are checked in dequantized space,
    using every element of ``B`` as the scalar.
    """
    mul_relu = torch.ops.quantized.mul_relu
    mul = torch.ops.quantized.mul
    mul_out = torch.ops.quantized.mul_out
    mul_relu_out = torch.ops.quantized.mul_relu_out

    A = torch.arange(-25, 25, dtype=torch.float)
    B = torch.arange(-25, 25, dtype=torch.float)
    scale = 2.0
    zero_point = 127
    qA = torch.quantize_per_tensor(A, scale=scale, zero_point=zero_point,
                                   dtype=torch.quint8)
    qB = torch.quantize_per_tensor(B, scale=scale, zero_point=zero_point,
                                   dtype=torch.quint8)

    # mul ground truth
    C = (qA.dequantize() * qB.dequantize()).numpy()
    qC = _quantize(C, scale, zero_point)
    qC_hat = mul(qA, qB, scale=scale, zero_point=zero_point)
    np.testing.assert_equal(qC, qC_hat.int_repr(),
                            "Quantized multiplication failed.")
    qC_out_hat = torch._empty_affine_quantized(qC.shape,
                                               scale=scale,
                                               zero_point=zero_point,
                                               dtype=torch.quint8)
    mul_out(qA, qB, out=qC_out_hat)
    self.assertEqual(qC_hat, qC_out_hat, message="mul.out failed")

    # mul + ReLU ground truth
    Crelu = C.copy()
    Crelu[C < 0] = 0
    qCrelu = _quantize(Crelu, scale, zero_point)
    qCrelu_hat = mul_relu(qA, qB, scale=scale, zero_point=zero_point)
    np.testing.assert_equal(qCrelu, qCrelu_hat.int_repr(),
                            "Quantized multiplication with ReLU failed.")
    qCrelu_out_hat = torch._empty_affine_quantized(qCrelu.shape,
                                                   scale=scale,
                                                   zero_point=zero_point,
                                                   dtype=torch.quint8)
    mul_relu_out(qA, qB, out=qCrelu_out_hat)
    self.assertEqual(qCrelu_hat, qCrelu_out_hat,
                     message="mulReLU.out failed")

    # Scalar multiplication (reference computed in dequantized float space)
    for b in B:
        C_ref = qA.dequantize().numpy() * b.item()
        qC_hat = torch.ops.quantized.mul_scalar(qA, b.item())
        self.assertEqual(C_ref, qC_hat.dequantize())

    # Scalar multiplication + relu
    for b in B:
        C_ref = qA.dequantize().numpy() * b.item()
        C_ref[C_ref < 0] = 0
        qC_hat = torch.ops.quantized.mul_scalar_relu(qA, b.item())
        self.assertEqual(C_ref, qC_hat.dequantize())
# Tests the correctness of the mul and mul_relu op.
def test_qmul_relu_different_qparams(self):
    """Tests the correctness of the mul and mul_relu op.

    Inputs use different qparams (A: 3.0 / 7, B: 5.0 / 127, quint8) and the
    output uses 0.5 / 5. ``mul``/``mul_relu`` and their ``out=`` variants are
    compared against the float product of the dequantized inputs passed
    through ``_quantize``. For mul_relu, negatives are zeroed first.
    """
    scale_A, zero_point_A = 3.0, 7
    scale_B, zero_point_B = 5.0, 127
    scale_C, zero_point_C = 0.5, 5

    A = torch.arange(-25, 25, dtype=torch.float)
    B = torch.arange(-25, 25, dtype=torch.float)
    qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point_A,
                                   dtype=torch.quint8)
    qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point_B,
                                   dtype=torch.quint8)

    C = (qA.dequantize() * qB.dequantize()).numpy()
    Crelu = C.copy()
    Crelu[C < 0] = 0

    checks = (
        (C, torch.ops.quantized.mul, torch.ops.quantized.mul_out,
         "Quantized multiplication failed.", "mul.out failed"),
        (Crelu, torch.ops.quantized.mul_relu, torch.ops.quantized.mul_relu_out,
         "Quantized multiplication with ReLU failed.", "mulReLU.out failed"),
    )
    for expected, op, op_out, op_msg, out_msg in checks:
        q_expected = _quantize(expected, scale_C, zero_point_C)
        q_hat = op(qA, qB, scale=scale_C, zero_point=zero_point_C)
        np.testing.assert_equal(q_expected, q_hat.int_repr(), op_msg)
        q_out_hat = torch._empty_affine_quantized(q_expected.shape,
                                                  scale=scale_C,
                                                  zero_point=zero_point_C,
                                                  dtype=torch.quint8)
        op_out(qA, qB, out=q_out_hat)
        self.assertEqual(q_hat, q_out_hat, message=out_msg)
# Tests max pool operation on quantized tensors.
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4,
                                          min_side=1, max_side=10),
                   qparams=hu.qparams()),
       kernel=st.sampled_from((3, 5, 7)),
       stride=st.sampled_from((None, 1, 2)),
       dilation=st.integers(1, 2),
       padding=st.integers(0, 2),
       ceil_mode=st.booleans())
def test_max_pool2d(self, X, kernel, stride, dilation, padding, ceil_mode):
    """Tests max pool operation on quantized tensors.

    The reference runs float max_pool2d on the unquantized input, then
    quantizes and dequantizes the result with the input's qparams. The
    torch, nn.functional and nn.quantized.functional entry points, plus the
    raw ``ops.quantized`` op, must all match it.
    """
    X, (scale, zero_point, torch_type) = X
    # Check constraints
    assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
    for side in X.shape[-2:]:
        assume(pool_output_shape(side, kernel, padding, stride,
                                 dilation, ceil_mode) > 0)

    pool_args = dict(kernel_size=kernel, stride=stride, padding=padding,
                     dilation=dilation, ceil_mode=ceil_mode)
    a = torch.from_numpy(X)
    a_ref = torch.quantize_per_tensor(
        torch.nn.functional.max_pool2d(a, **pool_args),
        scale=scale, zero_point=zero_point, dtype=torch_type).dequantize()
    qa = torch.quantize_per_tensor(a, scale=scale, zero_point=zero_point,
                                   dtype=torch_type)

    entry_points = (
        ("torch", torch.max_pool2d),
        ("nn.functional", torch.nn.functional.max_pool2d),
        ("nn.quantized.functional", torch.nn.quantized.functional.max_pool2d),
    )
    for name, op in entry_points:
        a_hat = op(qa, **pool_args)
        self.assertEqual(a_ref, a_hat.dequantize(),
                         message="{} results are off".format(name))

    # Test the ops.quantized separately, because None is not treated.
    a_hat = torch.ops.quantized.max_pool2d(
        qa, kernel_size=_pair(kernel),
        stride=_pair(kernel if stride is None else stride),
        padding=_pair(padding), dilation=_pair(dilation), ceil_mode=ceil_mode)
    self.assertEqual(a_ref, a_hat.dequantize(),
                     message="ops.quantized.max_pool2d results are off")
# Tests max pool operation on NHWC quantized tensors.
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
                                          min_side=1, max_side=10),
                   qparams=hu.qparams()),
       kernel=st.sampled_from((3, 5, 7)),
       stride=st.sampled_from((None, 1, 2)),
       dilation=st.integers(1, 2),
       padding=st.integers(0, 2),
       ceil_mode=st.booleans())
def test_max_pool2d_nhwc(self, X, kernel, stride, dilation, padding, ceil_mode):
    """Tests max pool operation on NHWC quantized tensors.

    The quantized input is an NCHW-shaped view over NHWC (channels-last)
    memory, obtained by quantizing an NHWC-contiguous buffer and permuting it.
    Results are compared against float max_pool2d followed by
    quantize/dequantize with the input's qparams.
    """
    X, (scale, zero_point, torch_type) = X
    # Ensure we hit the vectorized paths
    # 176 = 128 + 32 + 16
    # 128 hits the interleaved path
    # 32 hits the non-interleaved path
    # 16 hits the scalar path
    if X.shape[1] < 176:
        # np.repeat needs an integer count; `176 / C` is a float in Python 3.
        X = np.repeat(X, 176 // X.shape[1], 1)
    # Check constraints
    assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
    iH, iW = X.shape[-2:]
    oH = pool_output_shape(iH, kernel, padding, stride, dilation, ceil_mode)
    assume(oH > 0)
    oW = pool_output_shape(iW, kernel, padding, stride, dilation, ceil_mode)
    assume(oW > 0)
    X_nhwc = np.ascontiguousarray(X.transpose([0, 2, 3, 1]))
    a = torch.from_numpy(X_nhwc).permute([0, 3, 1, 2])
    a_pool = torch.nn.functional.max_pool2d(a, kernel_size=kernel,
                                            stride=stride,
                                            padding=padding, dilation=dilation,
                                            ceil_mode=ceil_mode)
    a_ref = torch.quantize_per_tensor(a_pool, scale=scale,
                                      zero_point=zero_point, dtype=torch_type)
    a_ref = a_ref.dequantize()
    qa = torch.quantize_per_tensor(torch.from_numpy(X_nhwc), scale=scale, zero_point=zero_point,
                                   dtype=torch_type).permute([0, 3, 1, 2])
    # A default-contiguous tensor has non-increasing strides. Compare as lists:
    # a tuple never equals a list, so `stride() != sorted(...)` is always true.
    self.assertTrue(list(qa.stride()) != sorted(qa.stride(), reverse=True))
    ops_under_test = {
        "torch": torch.max_pool2d,
        "nn.functional": torch.nn.functional.max_pool2d,
        "nn.quantized.functional": torch.nn.quantized.functional.max_pool2d
    }
    for name, op in ops_under_test.items():
        a_hat = op(qa, kernel_size=kernel, stride=stride, padding=padding,
                   dilation=dilation, ceil_mode=ceil_mode)
        # NOTE(review): this compares a tuple with a list and is therefore
        # always true. Tightening it like the input check above requires
        # confirming that the kernel returns channels-last output.
        self.assertTrue(a_hat.stride() != sorted(a_hat.stride()))
        self.assertEqual(a_ref, a_hat.dequantize(),
                         message="{} results are off".format(name))
    # Test the ops.quantized separately, because None is not treated.
    a_hat = torch.ops.quantized.max_pool2d(
        qa, kernel_size=_pair(kernel),
        stride=_pair(kernel if stride is None else stride),
        padding=_pair(padding), dilation=_pair(dilation), ceil_mode=ceil_mode)
    self.assertEqual(a_ref, a_hat.dequantize(),
                     message="ops.quantized.max_pool2d results are off")
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4,
                                          min_side=5, max_side=10),
                   qparams=hu.qparams(dtypes=torch.quint8)),
       kernel=st.sampled_from((3, 5)),
       stride=st.sampled_from((None, 1, 2)),
       padding=st.integers(0, 2),
       ceil_mode=st.sampled_from((True, False)),
       count_include_pad=st.sampled_from((True, False)),
       divisor_override=st.sampled_from((None, None)))
def test_avg_pool2d(self, X, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override):
    """Tests avg pool operation on quantized tensors.

    The reference runs float avg_pool2d on the input's integer representation
    and rounds it, to avoid double rounding. Outputs must be within 1 of it
    and keep the input's scale and zero point.

    Note: we currently cannot test the divisor_override, because quantized op will clamp the result
    within range. However, the float op will not.
    """
    X, (scale, zero_point, torch_type) = X
    assume(kernel // 2 >= padding)  # Kernel cannot be overhanging!
    iH, iW = X.shape[-2:]
    # avg_pool2d has no dilation argument, which is equivalent to dilation=1.
    # Passing 0 would overestimate the output size.
    oH = pool_output_shape(iH, kernel, padding, stride, 1, ceil_mode)
    assume(oH > 0)
    oW = pool_output_shape(iW, kernel, padding, stride, 1, ceil_mode)
    assume(oW > 0)
    X = torch.from_numpy(X)
    qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                   dtype=torch_type)
    # Run reference on int_repr + round to avoid double rounding error.
    X_ref = torch.nn.functional.avg_pool2d(
        qX.int_repr().to(torch.float), kernel_size=kernel, stride=stride, padding=padding,
        ceil_mode=ceil_mode, count_include_pad=count_include_pad, divisor_override=divisor_override).round()
    ops_under_test = {
        "nn.functional": torch.nn.functional.avg_pool2d,
        "nn.quantized.functional": torch.nn.quantized.functional.avg_pool2d
    }
    error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
    for name, op in ops_under_test.items():
        qX_hat = op(qX, kernel_size=kernel, stride=stride, padding=padding, ceil_mode=ceil_mode,
                    count_include_pad=count_include_pad, divisor_override=divisor_override)
        self.assertEqual(X_ref, qX_hat.int_repr(), prec=1.0,
                         message=error_message.format(name, X_ref, qX_hat.int_repr()))
        self.assertEqual(scale, qX_hat.q_scale(),
                         message=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
        self.assertEqual(zero_point, qX_hat.q_zero_point(),
                         message=error_message.format(name + '.zero_point', zero_point,
                                                      qX_hat.q_zero_point()))
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
                                          min_side=5, max_side=10),
                   qparams=hu.qparams(dtypes=torch.qint8)),
       kernel=st.sampled_from((4, 5)),
       stride=st.sampled_from((None, 1, 2)),
       padding=st.integers(0, 2),
       ceil_mode=st.sampled_from((True, False)),
       count_include_pad=st.sampled_from((True, False)),
       divisor_override=st.sampled_from((None, None)))
def test_avg_pool2d_nhwc(self, X, kernel, stride, padding, ceil_mode, count_include_pad, divisor_override):
    """Tests avg pool operation on NHWC quantized tensors.

    The quantized input is an NCHW-shaped view over NHWC memory. The reference
    is float avg_pool2d on its integer representation (in double), rounded.

    Note: 1) we currently cannot test the divisor_override, because quantized op will clamp the result
    within range. However, the float op will not.
    2) we cannot test the qint32, since the float point precision is much lower than int32 for big number,
    which will make the test be very flaky.
    """
    X, (scale, zero_point, torch_type) = X
    if X.shape[1] < 176:
        # np.repeat needs an integer count; `176 / C` is a float in Python 3.
        X = np.repeat(X, 176 // X.shape[1], 1)
    X_nhwc = np.ascontiguousarray(X.transpose([0, 2, 3, 1]))
    qX = torch.quantize_per_tensor(torch.from_numpy(X_nhwc), scale=scale,
                                   zero_point=zero_point, dtype=torch_type).permute([0, 3, 1, 2])
    # Run reference on int_repr + round to avoid double rounding error.
    X_ref = torch.nn.functional.avg_pool2d(
        qX.int_repr().to(torch.double), kernel_size=kernel, stride=stride, padding=padding,
        ceil_mode=ceil_mode, count_include_pad=count_include_pad, divisor_override=divisor_override).round()
    # A default-contiguous tensor has non-increasing strides. Compare as lists:
    # a tuple never equals a list, so `stride() != sorted(...)` is always true.
    self.assertTrue(list(qX.stride()) != sorted(qX.stride(), reverse=True))
    ops_under_test = {
        "nn.functional": torch.nn.functional.avg_pool2d,
        "nn.quantized.functional": torch.nn.quantized.functional.avg_pool2d
    }
    error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
    for name, op in ops_under_test.items():
        X_hat = op(qX, kernel_size=kernel, stride=stride, padding=padding, ceil_mode=ceil_mode,
                   count_include_pad=count_include_pad, divisor_override=divisor_override)
        # NOTE(review): this compares a tuple with a list and is therefore
        # always true. Tightening it like the input check above requires
        # confirming that the kernel returns channels-last output.
        self.assertTrue(X_hat.stride() != sorted(X_hat.stride()))
        self.assertEqual(X_ref, X_hat.int_repr().to(torch.double), prec=1.0,
                         message="{} results are off".format(name))
        self.assertEqual(scale, X_hat.q_scale(),
                         message=error_message.format(name + '.scale', scale, X_hat.q_scale()))
        self.assertEqual(zero_point, X_hat.q_zero_point(),
                         message=error_message.format(name + '.zero_point', zero_point,
                                                      X_hat.q_zero_point()))
@no_deadline
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
                                          min_side=1, max_side=10),
                   qparams=hu.qparams(dtypes=torch.quint8)),
       output_size_h=st.integers(1, 10),
       output_size_w=st.integers(1, 10))
def test_adaptive_avg_pool2d(self, X, output_size_h, output_size_w):
    """Tests adaptive average pool operation on quantized tensors.

    The reference is float adaptive_avg_pool2d on the input's integer
    representation, rounded. Outputs must be within 1 of it and keep the
    input's scale and zero point. A square output size is passed as a single
    int to also exercise that overload.
    """
    X, (scale, zero_point, torch_type) = X
    H, W = X.shape[-2:]
    assume(output_size_h <= H)
    assume(output_size_w <= W)
    if output_size_h == output_size_w:
        output_size = output_size_h
    else:
        output_size = (output_size_h, output_size_w)
    X = torch.from_numpy(X)
    qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                   dtype=torch_type)
    # Run reference on int_repr + round to avoid double rounding error.
    X_ref = torch.nn.functional.adaptive_avg_pool2d(
        qX.int_repr().to(torch.float), output_size).round()
    ops_under_test = {
        "nn.functional": torch.nn.functional.adaptive_avg_pool2d,
        "nn.quantized.functional":
            torch.nn.quantized.functional.adaptive_avg_pool2d
    }
    error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
    for name, op in ops_under_test.items():
        qX_hat = op(qX, output_size=output_size)
        self.assertEqual(X_ref, qX_hat.int_repr(), prec=1.0,
                         message=error_message.format(name, X_ref, qX_hat))
        self.assertEqual(scale, qX_hat.q_scale(),
                         message=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
        self.assertEqual(zero_point, qX_hat.q_zero_point(),
                         message=error_message.format(name + '.zero_point', zero_point,
                                                      qX_hat.q_zero_point()))
# Tests adaptive average pool operation on NHWC quantized tensors.
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
                                          min_side=1, max_side=10),
                   qparams=hu.qparams(dtypes=torch.qint8)),
       output_size_h=st.integers(1, 10),
       output_size_w=st.integers(1, 10))
def test_adaptive_avg_pool2d_nhwc(self, X, output_size_h, output_size_w):
    """Tests adaptive average pool operation on NHWC quantized tensors.

    The quantized input is an NCHW-shaped view over NHWC memory. The reference
    is float adaptive_avg_pool2d on its integer representation (in double),
    rounded.
    """
    X, (scale, zero_point, torch_type) = X
    H, W = X.shape[-2:]
    assume(output_size_h <= H)
    assume(output_size_w <= W)
    if output_size_h == output_size_w:
        output_size = output_size_h
    else:
        output_size = (output_size_h, output_size_w)
    if X.shape[1] < 176:
        # np.repeat needs an integer count; `176 / C` is a float in Python 3.
        X = np.repeat(X, 176 // X.shape[1], 1)
    X_nhwc = np.ascontiguousarray(X.transpose([0, 2, 3, 1]))
    qX = torch.quantize_per_tensor(torch.from_numpy(X_nhwc), scale=scale,
                                   zero_point=zero_point, dtype=torch_type).permute([0, 3, 1, 2])
    # Run reference on int_repr + round to avoid double rounding error.
    X_ref = torch.nn.functional.adaptive_avg_pool2d(qX.int_repr().to(torch.double), output_size).round()
    # A default-contiguous tensor has non-increasing strides. Compare as lists:
    # a tuple never equals a list, so `stride() != sorted(...)` is always true.
    self.assertTrue(list(qX.stride()) != sorted(qX.stride(), reverse=True))
    ops_under_test = {
        "nn.functional": torch.nn.functional.adaptive_avg_pool2d,
        "nn.quantized.functional":
            torch.nn.quantized.functional.adaptive_avg_pool2d
    }
    error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
    for name, op in ops_under_test.items():
        X_hat = op(qX, output_size=output_size)
        # NOTE(review): this compares a tuple with a list and is therefore
        # always true. Tightening it like the input check above requires
        # confirming that the kernel returns channels-last output.
        self.assertTrue(X_hat.stride() != sorted(X_hat.stride()))
        self.assertEqual(X_ref, X_hat.int_repr(), prec=1.0,
                         message="{} results are off".format(name))
        self.assertEqual(scale, X_hat.q_scale(),
                         message=error_message.format(name + '.scale', scale, X_hat.q_scale()))
        self.assertEqual(zero_point, X_hat.q_zero_point(),
                         message=error_message.format(name + '.zero_point', zero_point,
                                                      X_hat.q_zero_point()))
@no_deadline
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4,
                                          min_side=1, max_side=10),
                   qparams=hu.qparams()),
       k=st.integers(1, 10),
       dim=st.integers(1, 4),
       largest=st.booleans(),
       sorted=st.booleans())
def test_qtopk(self, X, k, dim, largest, sorted):
    """Check torch.topk on a quantized tensor against topk on its dequantized copy.

    The values (compared after dequantization) and the indices must both match.
    """
    X, (scale, zero_point, torch_type) = X
    assume(dim < X.ndim)
    assume(k < X.shape[dim])
    qX = torch.quantize_per_tensor(torch.from_numpy(X), scale, zero_point, torch_type)
    unquantized_out = torch.topk(qX.dequantize(), k, dim=dim, largest=largest, sorted=sorted)
    quantized_out = torch.topk(qX, k, dim=dim, largest=largest, sorted=sorted)
    self.assertEqual(len(unquantized_out), len(quantized_out))
    torch.testing.assert_allclose(quantized_out[0].dequantize(), unquantized_out[0])
    torch.testing.assert_allclose(quantized_out[1], unquantized_out[1])
@no_deadline
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
                                          min_side=1, max_side=10),
                   qparams=hu.qparams()),
       k=st.integers(1, 10),
       dim=st.integers(1, 4),
       largest=st.booleans(),
       sorted=st.booleans())
def test_qtopk_nhwc(self, X, k, dim, largest, sorted):
    """Check torch.topk on an NCHW view of NHWC quantized memory.

    The generated array is treated as NHWC. It is quantized and then permuted
    to an NCHW-shaped view, and the result is compared against topk on the
    dequantized view (values and indices).
    """
    # X is NHWC, we permute to view as NCHW but keep NHWC in memory
    X, (scale, zero_point, torch_type) = X
    qX = torch.quantize_per_tensor(torch.from_numpy(X), scale, zero_point, torch_type).permute([0, 3, 1, 2])
    X = np.transpose(X, [0, 3, 1, 2])
    assume(dim < X.ndim)
    assume(k < X.shape[dim])
    unquantized_out = torch.topk(qX.dequantize(), k, dim=dim, largest=largest, sorted=sorted)
    quantized_out = torch.topk(qX, k, dim=dim, largest=largest, sorted=sorted)
    self.assertEqual(len(unquantized_out), len(quantized_out))
    torch.testing.assert_allclose(quantized_out[0].dequantize(), unquantized_out[0])
    torch.testing.assert_allclose(quantized_out[1], unquantized_out[1])
# Tests quantize concatenation (both fused and not).
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4,
                                          min_side=1, max_side=10),
                   qparams=hu.qparams()),
       num=st.integers(1, 4),
       dim=st.integers(1, 4),
       relu=st.booleans())
def test_cat(self, X, num, dim, relu):
    """Tests quantize concatenation (both fused and not).

    ``num`` quantized copies of the input are concatenated along ``dim`` with
    the functional op and the ``out=`` op. Both are compared (dequantized)
    against a float cat that is quantized and dequantized with the same
    qparams, with ReLU applied afterwards when ``relu`` is set. Finally, a
    per-channel quantized input must be rejected with a RuntimeError.
    """
    X, (scale, zero_point, torch_type) = X
    assume(dim < X.ndim)
    X = torch.from_numpy(X)

    tensors_ref = [X] * num
    tensors_q = [torch.quantize_per_tensor(X, scale, zero_point, torch_type)
                 for _ in range(num)]
    new_shape = np.array(X.shape)
    new_shape[dim] = X.shape[dim] * num

    cat_ref = torch.quantize_per_tensor(torch.cat(tensors_ref, dim=dim),
                                        scale, zero_point, torch_type).dequantize()
    if relu:
        cat_ref = F.relu(cat_ref)
        q_cat_op = torch.ops.quantized.cat_relu
        q_cat_out_op = torch.ops.quantized.cat_relu_out
    else:
        q_cat_op = torch.ops.quantized.cat
        q_cat_out_op = torch.ops.quantized.cat_out

    cat_q = q_cat_op(tensors_q, dim=dim, scale=scale,
                     zero_point=zero_point).dequantize()
    np.testing.assert_equal(cat_ref.numpy(), cat_q.numpy())

    cat_q_out = torch._empty_affine_quantized(
        list(new_shape), scale=scale,
        zero_point=zero_point, dtype=torch_type)
    q_cat_out_op(tensors_q, dim=dim, out=cat_q_out)
    np.testing.assert_equal(cat_ref.numpy(), cat_q_out.dequantize().numpy())

    # Test the cat on per-channel quantized tensor.
    ch_axis = 1
    num_channels = X.shape[ch_axis]
    scales = torch.ones(num_channels, dtype=torch.float64)
    zero_points = torch.zeros(num_channels, dtype=torch.long)
    tensors_q[0] = torch.quantize_per_channel(
        X, scales, zero_points, axis=ch_axis, dtype=torch_type)
    with self.assertRaisesRegex(RuntimeError, "supported.*cat"):
        q_cat_op(tensors_q, dim=ch_axis, scale=scale, zero_point=zero_point)
@no_deadline
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
                                          min_side=5, max_side=10),
                   qparams=hu.qparams()),
       size=st.sampled_from((1, 3, 5, 10)),
       mode=st.sampled_from(("bilinear", "nearest")),
       scale_factor=st.sampled_from((None, 1.5, 2.0)),
       align_corners=st.sampled_from((True, False)),
       nhwc_layout=st.sampled_from((True, False)))
def test_interpolate(self, X, size, mode, scale_factor, align_corners, nhwc_layout):
    """Tests quantized upsample_nearest2d and upsample_bilinear2d.

    Both ``nn.functional.interpolate`` and
    ``nn.quantized.functional.interpolate`` are run on a quantized input.
    The input is either plain NCHW or NCHW-shaped with NHWC memory layout.
    The integer representation of each output is compared (within 1) with
    float interpolation of the input's integer representation, and the
    output must keep the input's scale and zero_point.
    """
    X, (scale, zero_point, torch_type) = X
    if scale_factor is not None:
        size = None
    if mode == "nearest":
        align_corners = None

    if nhwc_layout:
        # Tile the channel dimension up to at least 176 channels. np.repeat
        # needs an integer count (recent NumPy rejects a float such as
        # 176 / 3), and ceil-division guarantees the target is reached.
        channels = X.shape[1]
        if channels < 176:
            X = np.repeat(X, -(-176 // channels), 1)
        # Make NHWC the physical layout, quantize it, then permute the view
        # back so qX has NCHW shape with NHWC memory. Permuting before
        # quantizing and again afterwards would give an (N, W, C, H)-shaped
        # tensor instead.
        X_nhwc = torch.from_numpy(np.ascontiguousarray(X.transpose([0, 2, 3, 1])))
        qX = torch.quantize_per_tensor(X_nhwc, scale=scale, zero_point=zero_point,
                                       dtype=torch_type).permute([0, 3, 1, 2])
    else:
        X = torch.from_numpy(X)
        qX = torch.quantize_per_tensor(X, scale=scale, zero_point=zero_point,
                                       dtype=torch_type)

    # Reference: float interpolation of the quantized integer values.
    X_ref = torch.nn.functional.interpolate(
        qX.int_repr().to(torch.float), size=size, scale_factor=scale_factor,
        mode=mode, align_corners=align_corners)

    ops_under_test = {
        "nn.functional": torch.nn.functional.interpolate,
        "nn.quantized.functional": torch.nn.quantized.functional.interpolate
    }

    error_message = r"Results are off for {}:\n\tExpected:\n{}\n\tGot:\n{}"
    for name, op in ops_under_test.items():
        qX_hat = op(qX, size=size, scale_factor=scale_factor,
                    mode=mode, align_corners=align_corners)
        self.assertEqual(X_ref, qX_hat.int_repr(), prec=1.0,
                         message=error_message.format(name, X_ref, qX_hat.int_repr()))
        self.assertEqual(scale, qX_hat.q_scale(),
                         message=error_message.format(name + '.scale', scale, qX_hat.q_scale()))
        self.assertEqual(zero_point, qX_hat.q_zero_point(),
                         message=error_message.format(name + '.zero_point', zero_point,
                                                      qX_hat.q_zero_point()))
"""Tests quantize concatenation (both fused and not)."""
@no_deadline
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=4, max_dims=4,
                                          min_side=1, max_side=10),
                   qparams=hu.qparams()),
       relu=st.booleans())
def test_cat_nhwc(self, X, relu):
    """Tests quantized concatenation (fused with relu and not) of NHWC inputs.

    The inputs have NCHW shape but NHWC memory layout. The result is
    compared with quantize(cat(dequantized inputs)), and its strides must
    not be those of a plain contiguous NCHW tensor.
    """
    # X is interpreted as NHWC.
    X, (scale, zero_point, torch_type) = X
    # Tile out X so the number of channels is > 64. np.repeat needs an
    # integer count (recent NumPy rejects a float such as 70 / 9), and
    # ceil-division is required: flooring would give e.g. 7 * 9 = 63 channels.
    channels = X.shape[3]
    X = np.repeat(X, -(-70 // channels), 3)
    X = torch.from_numpy(np.ascontiguousarray(X))
    Y = X.clone()
    # Quantize the NHWC-contiguous tensors, then permute so the tensors look
    # like NCHW but are still laid out in memory as NHWC.
    qX = torch.quantize_per_tensor(X, scale, zero_point, torch_type).permute([0, 3, 1, 2])
    qY = torch.quantize_per_tensor(Y, scale, zero_point, torch_type).permute([0, 3, 1, 2])

    ref = torch.cat([qX.dequantize(), qY.dequantize()], dim=1)
    if relu:
        ref[ref < 0] = 0.0
    ref = torch.quantize_per_tensor(ref, scale=scale, zero_point=zero_point, dtype=torch_type)

    if relu:
        out = torch.ops.quantized.cat_relu(
            [qX, qY], dim=1, scale=scale, zero_point=zero_point)
    else:
        out = torch.ops.quantized.cat([qX, qY], dim=1, scale=scale, zero_point=zero_point)

    torch.testing.assert_allclose(out.dequantize(), ref.dequantize())
    # A contiguous NCHW tensor has its strides in descending order; the NHWC
    # output must not. stride() returns a tuple, so convert it to a list:
    # a tuple never compares equal to the list returned by sorted(), which
    # would make this check vacuous.
    self.assertNotEqual(list(out.stride()), sorted(out.stride(), reverse=True))
@given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=3,
                                          min_side=1, max_side=2),
                   qparams=hu.qparams()),
       dim=st.integers(1, 2))
def test_mean(self, X, dim):
    """Tests torch.mean on a quantized tensor along one dimension.

    The expected value is the float mean of the dequantized input,
    requantized with the input's qparams and then dequantized again.
    """
    X, (scale, zero_point, torch_type) = X
    qX = torch.quantize_per_tensor(torch.tensor(X).float(), scale, zero_point, torch_type)

    float_mean = torch.mean(qX.dequantize(), dim)
    expected = torch.quantize_per_tensor(float_mean, scale, zero_point, torch_type).dequantize()

    actual = torch.mean(qX, dim).dequantize()
    self.assertEqual(expected, actual)
"""Tests the correctness of the quantized equal op."""
@unittest.skip("temporarily disable until failures are fixed. " +
               "See https://github.com/pytorch/pytorch/issues/26279")
@given(X=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                   qparams=hu.qparams()),
       X2=hu.tensor(shapes=hu.array_shapes(1, 5, 1, 5),
                    qparams=hu.qparams()),
       X_per_channel=st.booleans(),
       X2_per_channel=st.booleans())
def test_equal(self, X, X2, X_per_channel, X2_per_channel):
    """Tests the correctness of the quantized equal op.

    Each input is quantized either per tensor or per channel (along its last
    axis). ``Tensor.equal`` is then compared with the pure-Python reference
    ``equal_ref``, both for a tensor against itself and against the other
    tensor.
    """
    def _quantize(data, qparams, per_channel):
        # Quantize a numpy array per tensor, or per channel along the last
        # axis with the same scale/zero_point repeated for every channel.
        scale, zero_point, torch_type = qparams
        data = torch.from_numpy(data)
        if not per_channel:
            return torch.quantize_per_tensor(data, scale=scale, zero_point=zero_point,
                                             dtype=torch_type)
        channels = data.shape[-1]
        return torch.quantize_per_channel(
            data,
            scales=torch.tensor([scale] * channels),
            zero_points=torch.tensor([zero_point] * channels),
            dtype=torch_type,
            axis=data.ndim - 1)

    def equal_ref(qX, qX2):
        # Reference equality: same qscheme, same shape, same quantization
        # parameters and same integer representation.
        # NOTE(review): the dtypes of the two tensors are not compared here.
        # Confirm whether Tensor.equal treats e.g. quint8 vs qint8 with equal
        # int values as unequal.
        if qX.qscheme() != qX2.qscheme():
            return False
        if qX.shape != qX2.shape:
            return False
        if qX.qscheme() == torch.per_tensor_affine:
            if qX.q_scale() != qX2.q_scale():
                return False
            if qX.q_zero_point() != qX2.q_zero_point():
                return False
        elif qX.qscheme() == torch.per_channel_affine:
            if (qX.q_per_channel_scales() !=
               qX2.q_per_channel_scales()).any():
                return False
            if (qX.q_per_channel_zero_points() !=
               qX2.q_per_channel_zero_points()).any():
                return False
        else:
            raise NotImplementedError("Don't know what to do with",
                                      qX.qscheme())
        if (qX.int_repr().to(float) != qX2.int_repr().to(float)).any():
            return False
        return True

    X_data, X_params = X
    X2_data, X2_params = X2
    qX = _quantize(X_data, X_params, X_per_channel)
    qX2 = _quantize(X2_data, X2_params, X2_per_channel)

    self.assertEqual(qX.equal(qX), equal_ref(qX, qX))
    self.assertEqual(qX.equal(qX2), equal_ref(qX, qX2))
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
" with instruction set support avx2 or newer.")
class TestDynamicQuantizedLinear(TestCase):
"""Tests the correctness of the dynamic quantized linear and linear_relu op."""
@no_deadline
@given(
batch_size=st.integers(1, 4),
input_channels=st.integers(16, 32),
output_channels=st.integers(4, 8),
use_bias=st.booleans(),
use_relu=st.booleans(),
use_multi_dim_input=st.booleans(),
use_channelwise=st.booleans())
def test_qlinear(self, batch_size, input_channels, output_channels,
use_bias, use_relu, use_multi_dim_input, use_channelwise):
qlinear_prepack = torch.ops.quantized.linear_prepack
if use_relu:
qlinear_dynamic = torch.ops.quantized.linear_relu_dynamic
else:
qlinear_dynamic = torch.ops.quantized.linear_dynamic
if use_multi_dim_input:
batch_size *= 3 # Test the multi-dim input tensor
X_scale = 1.0