# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, too-many-lines, len-as-condition, no-else-return, unused-variable, too-many-nested-blocks
# pylint: disable=consider-iterating-dictionary, invalid-name, unused-argument, unused-variable, broad-except
# pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
# pylint: disable=missing-function-docstring
"""PT: PyTorch frontend."""
import itertools
import logging
import math
import sys
import numpy as np
import tvm
from tvm.ir import IRModule
from tvm.topi.utils import get_const_tuple
from .. import analysis as _analysis
from .. import expr as _expr
from .. import function as _function
from .. import op as _op
from .. import qnn, transform
from ..expr_functor import ExprMutator
from ..loops import while_loop
from ..prelude import Prelude, StaticTensorArrayOps
from ..ty import Any, TensorType, TupleType
from . import qnn_torch
from .common import AttrCvt, get_relay_op
from .common import infer_value as _infer_value
from .common import infer_value_simulated as _infer_value_simulated
from .common import try_infer_value
from .pytorch_utils import is_version_greater_than
__all__ = ["from_pytorch"]
# This returns a "subgraph" which substitutes variables wherever
# the type is known. It also records a mapping from the input
# nodes to the extracted graph's nodes.
# As Python objects are not round-trippable through C++, and
# our type annotations only live in Python, we need to map
# the nodes we get while visiting back to the nodes
# we used to construct the graph (they are the same in C++,
# match each other in dictionary lookups, but are not the same
# in Python) by using the hint dictionary filled as
# {node: node for node in nodes} to get the type annotations.
# https://discuss.tvm.apache.org/t/round-tripping-objects-through-the-ffi/8440
class _TypeFinder(ExprMutator):
def __init__(self, types):
super().__init__()
self.counter = 0
self.vars = {}
self.types = types
self.leave = set() # some variables are not inputs
def visit_let(self, let):
self.leave.add(let.var)
return super().visit_let(let)
def visit_function(self, fn):
self.leave.update(fn.params)
return super().visit_function(fn)
def visit(self, expr):
if expr in self.leave:
return super().visit(expr)
if expr in self.vars:
return self.vars[expr]
if isinstance(expr, tvm.relay.Var):
self.vars[expr] = expr
return expr
if expr in self.types:
ty = self.types[expr]
v = tvm.relay.var(f"_{self.counter}", type_annotation=ty)
self.counter += 1
self.vars[expr] = v
return v
v = super().visit(expr)
return v
def _should_construct_dynamic_list(list_construct_node):
# if this list is element-accessed or modified at runtime, generate List ADT
def inplace_add_to_add(op_name):
if op_name == "aten::add_":
return "aten::add"
else:
return op_name
uses = _get_uses(list_construct_node)
for loop_use in filter(lambda use: use.user.kind() == "prim::Loop", uses):
block_input_index = loop_use.offset - 1
block = list(loop_use.user.blocks())[0]
list_loop_var = list(block.inputs())[block_input_index]
uses += _get_uses(list_loop_var.node())
op_names = map(inplace_add_to_add, set(use.user.kind() for use in uses))
list_ops = set(["aten::add", "aten::__getitem__"])
intersect = list_ops.intersection(op_names)
if len(intersect) > 0 and intersect != set(["aten::add"]):
return True
# if add op outputs list, it is dynamic so we need to construct List ADT
for use in filter(lambda use: use.user.kind() in ["aten::add", "aten::add_"], uses):
output_type = _get_node_type(use.user)
if output_type == "ListType":
return True
return False
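# A hedged illustration of the check above (not from the original source):
# in TorchScript such as
#   xs = []            # prim::ListConstruct
#   for i in range(n):
#       xs += [f(i)]   # aten::add_ whose output is a ListType
#   y = xs[0]          # aten::__getitem__
# the list is grown and indexed at runtime, so this function returns True
# and the list is lowered to a Relay List ADT rather than a static tuple.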
def _is_int_seq(seq):
# TODO (t-vi): handle non-int constants? (like numpy.intXX)
return len(seq) > 0 and all([isinstance(i, int) for i in seq])
# operator implementation
class PyTorchOpConverter:
"""A helper class for holding PyTorch op converters."""
def __init__(self, prelude, default_dtype):
self.prelude = prelude
self.default_dtype = default_dtype
self.create_convert_map()
self.types = {} # map from nodes to (Relay) type annotations
# this incrementally infers the type, see the comments on the type visitor
# above.
def infer_type(self, node, mod=None):
"""An incremental method to infer the type of a node in the relay graph."""
if node in self.types:
return self.types[node]
if isinstance(node, tvm.relay.Var):
return node.type_annotation
tf = _TypeFinder(types=self.types)
new_node = tf.visit(node)
fn = _function.Function(list(tf.vars.values()), new_node)
new_mod = IRModule({"main": fn})
if mod is not None:
new_mod.update(mod)
new_mod = transform.RemoveUnusedFunctions()(new_mod)
new_mod = transform.InferType()(new_mod)
entry = new_mod["main"]
ty = entry.body.checked_type
self.types[node] = ty
return self.types[node]
def infer_type_with_prelude(self, val):
body = self.infer_type(val, self.prelude.mod)
return body
# list ADT utilities
def convert_to_list_adt(self, py_lst):
elem_tys = [self.infer_type_with_prelude(elem) for elem in py_lst]
msg = "List elements should have identical types"
assert all(map(lambda ty: ty == elem_tys[0], elem_tys)), msg
        # get_type returns type_name, ctor1, ..., ctorN;
        # for List the constructors are cons and nil
_, cons, nil = self.prelude.mod.get_type("List")
adt_lst = nil()
for elem in reversed(py_lst):
adt_lst = cons(elem, adt_lst)
return adt_lst
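    # Worked example for convert_to_list_adt (illustrative): for
    # py_lst = [a, b, c] the reversed loop builds
    # cons(a, cons(b, cons(c, nil()))), preserving element order.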
def map_tensor_array_constructor(self, adt_lst, shape):
static_tensor_array_ops = StaticTensorArrayOps(self.prelude, "float32", shape)
static_tensor_array_ops.register()
tensor_create = self.prelude.get_tensor_ctor_static("tensor_constructor", "float32", shape)
return self.prelude.map(tensor_create, adt_lst)
def convert_to_tensor_array(self, adt_lst):
_, cons, nil = self.prelude.mod.get_type("List")
if self.prelude.length(adt_lst) == 0:
return nil()
checked_type = self.infer_type_with_prelude(self.prelude.hd(adt_lst))
shape = checked_type.shape
tensor_array = self.map_tensor_array_constructor(adt_lst, shape)
return tensor_array, tuple(shape)
def infer_shape(self, inputs, mod=None):
"""A method to get the output type of an intermediate node in the graph."""
typ = self.infer_type(inputs, mod=mod)
if hasattr(typ, "shape"):
# Regular operator that outputs tensors
return get_const_tuple(typ.shape)
# The return type is not a tensor, for example List
return typ
def infer_shape_with_prelude(self, inputs):
return self.infer_shape(inputs, mod=self.prelude.mod)
def record_output_type(self, output):
if isinstance(output, tuple):
cleaned_output = [o for o in output if o is not None]
types = self.infer_type_with_prelude(_expr.Tuple(cleaned_output))
for o, t in zip(cleaned_output, types.fields):
self.types[o] = t
elif isinstance(output, _expr.Expr):
self.infer_type_with_prelude(output)
        # outputs can also be plain Python values (e.g. int), which need no type recording
def pytorch_promote_types(self, inputs, dtypes):
"""This promotes TVM inputs with TVM dtypes passed like PyTorch would"""
actual_dtypes = []
for i, inp in enumerate(inputs):
if isinstance(inp, _expr.Expr):
idt = self.infer_type(inp).dtype
actual_dtypes.append(idt)
else:
actual_dtypes.append(dtypes[i])
dtypes = actual_dtypes
tensor_dtypes = [dt for inp, dt in zip(inputs, dtypes) if not np.isscalar(inp)]
non_tensor_inputs = [inp for inp in inputs if np.isscalar(inp)]
result_type = _pytorch_result_type(tensor_dtypes, non_tensor_inputs)
results = []
for inp, dt in zip(inputs, dtypes):
if np.isscalar(inp):
results.append(_expr.const(inp, dtype=result_type))
elif dt == result_type:
results.append(inp)
else:
results.append(_op.cast(inp, result_type))
return results
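    # Hedged example of the promotion above: for inputs = [t, 2.5] with
    # t an "int32" tensor, the Python float scalar drives PyTorch-style
    # promotion, so _pytorch_result_type would typically yield a float
    # dtype; t is cast to it and 2.5 becomes _expr.const(2.5, dtype=result_type).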
def is_quantized_tensor(self, data):
# If a quantized Torch module is saved and loaded back, dtype will be dropped
# Since dtypes from Torch tensors are not reliable in such cases, we use
# Relay's type inference result to decide if an input tensor is quantized
ty = self.infer_type_with_prelude(data)
return ty.dtype == "uint8"
# Operator implementations
def make_elemwise(self, name):
def elemwise(inputs, input_types):
data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2])
return get_relay_op(name)(data0, data1)
return elemwise
def min_max_common(self, name_elemwise, name_reduce, inputs, input_types):
if len(inputs) == 1:
data = self.pytorch_promote_types(inputs[:1], input_types[:1])
return get_relay_op(name_reduce)(data[0])
elif len(inputs) >= 2 and isinstance(inputs[1], int):
data = self.pytorch_promote_types(inputs[:1], input_types[:1])
dim = inputs[1]
keepdims = inputs[2] if len(inputs) > 2 else False
# also return dummy indices
return get_relay_op(name_reduce)(data[0], axis=dim, keepdims=keepdims), None
else:
data0, data1 = self.pytorch_promote_types(inputs[:2], input_types[:2])
return get_relay_op(name_elemwise)(data0, data1)
def max(self, inputs, input_types):
return self.min_max_common("maximum", "max", inputs, input_types)
def min(self, inputs, input_types):
return self.min_max_common("minimum", "min", inputs, input_types)
def make_unary(self, name):
def unary(inputs, input_types):
# this is just to ensure tensor input
(data,) = self.pytorch_promote_types(inputs[:1], input_types[:1])
return get_relay_op(name)(data)
return unary
def log1p(self, inputs, input_types):
        # log1p(x) = log(x + 1)
(dtype,) = input_types
one = _expr.const(1, dtype=dtype)
return _op.log(inputs[0] + one)
def arange(self, inputs, input_types):
def _get_value(val, dtype):
# dtype is a tvm dtype
if isinstance(val, _expr.Expr):
inp = _op.cast(val, dtype)
ret, _ = try_infer_value(inp, lambda ret: _expr.const(ret, dtype))
else:
ret = _create_typed_const(val, dtype)
return ret
def _get_type(val, inp_type):
if isinstance(val, _expr.Expr):
dtype = str(self.infer_type(val))
return dtype
return inp_type
# PyTorch arange uses the following type semantics:
# - if a dtype is given, start, stop, step are converted to that dtype
# - if no dtype is given and all args are integral, dtype is int64
# - if no dtype is given and there is a float arg, dtype is float32
if len(inputs) == 5:
dtype0 = _get_type(inputs[0], input_types[0])
if inputs[1] is not None:
dtype = _convert_dtype_value(inputs[1])
elif dtype0.startswith("float"):
dtype = "float32"
else:
dtype = "int64"
start = _expr.const(0, dtype)
stop = _get_value(inputs[0], dtype)
step = _expr.const(1, dtype)
elif len(inputs) == 7:
types = [_get_type(inputs[i], input_types[i]) for i in range(3)]
if inputs[3] is not None:
dtype = _convert_dtype_value(inputs[3])
elif any([t.startswith("float") for t in types]):
dtype = "float32"
else:
dtype = "int64"
start = _get_value(inputs[0], dtype)
stop = _get_value(inputs[1], dtype)
step = _get_value(inputs[2], dtype)
else:
msg = "Unknown number of arguments (%d) to parse." % (len(inputs))
raise AssertionError(msg)
return _op.transform.arange(start=start, stop=stop, step=step, dtype=dtype)
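    # Note on the two arities handled above (as observed in TorchScript):
    #   aten::arange(end, dtype, layout, device, pin_memory)               -> 5 inputs
    #   aten::arange(start, end, step, dtype, layout, device, pin_memory)  -> 7 inputs
    # e.g. torch.arange(5) maps to start=0, stop=5, step=1 with dtype "int64".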
def squeeze(self, inputs, input_types):
data = inputs[0]
if len(inputs) == 1:
axis = None
else:
# TODO (t-vi): why is the cast to int needed? similarly elsewhere
axis = [int(inputs[1])]
return _op.transform.squeeze(data, axis)
def unsqueeze(self, inputs, input_types):
data = inputs[0]
axis = inputs[1]
return _op.transform.expand_dims(data, int(axis), 1)
def concatenate(self, inputs, input_types):
def tensor_array_concat(lst, axis):
assert axis == 0, "Tensor array concat supported only for axis 0"
tensor_array, shape = self.convert_to_tensor_array(lst)
concat_shape = (Any(),) + shape[1:]
concat = self.prelude.get_global_var_static("tensor_array_concat", "float32", shape)
concatenated = concat(tensor_array)
static_tensor_array_ops = StaticTensorArrayOps(self.prelude, "float32", concat_shape)
static_tensor_array_ops.register()
get_tensor = self.prelude.get_global_var_static(
"tensor_get_data", "float32", concat_shape
)
return get_tensor(concatenated)
data = inputs[0]
axis = inputs[1]
if not isinstance(data, list):
return tensor_array_concat(data, axis)
if isinstance(data, _expr.Expr):
data = [data]
return _op.tensor.concatenate(data, int(axis))
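    # Hedged note: a Python list input takes the plain relay concatenate
    # path; anything else is assumed to be a List ADT (from a dynamic
    # TorchScript list) and goes through the static tensor array ops.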
def slice(self, inputs, input_types):
axis_dtype = "int64"
index_size_limit = sys.maxsize
data = inputs[0]
dshape = self.infer_shape(data)
ndim = len(dshape)
dim = int(inputs[1])
stride = inputs[4]
        target_begin, is_begin_const = try_infer_value(
            inputs[2], lambda ret: ret.astype(np.int64).item()
        )
        target_end, is_end_const = try_infer_value(
            inputs[3], lambda ret: ret.astype(np.int64).item()
        )
        # A fast path when slicing is a no-op.
if (
isinstance(target_begin, int)
and isinstance(target_end, int)
and target_begin == 0
and target_end >= index_size_limit
and stride == 1
):
return data
# Process begin
begin = [0] * ndim
begin[dim] = target_begin
if not isinstance(begin[dim], int):
tmp = []
for b in begin:
if isinstance(b, int):
tmp.append(_op.expand_dims(_expr.const(b, axis_dtype), axis=0))
else:
tmp.append(_op.cast(_op.expand_dims(b, axis=0), axis_dtype))
begin = _op.concatenate(tmp, axis=0)
btype = self.infer_type(begin).dtype
if str(btype) != axis_dtype:
begin = _op.cast(begin, axis_dtype)
# Process end
if isinstance(target_end, int) and target_end >= index_size_limit:
target_end = dshape[dim]
if any([isinstance(d, tvm.tir.Any) for d in dshape]):
end = _op.shape_of(data)
else:
end = dshape
if isinstance(target_end, int):
if isinstance(end, list):
end[dim] = target_end
else:
all_static = True
for i, shape_dim in enumerate(dshape):
if i != dim and isinstance(shape_dim, tvm.tir.Any):
all_static = False
if all_static:
end = list(get_const_tuple(dshape))
end[dim] = target_end
else:
target_end = _expr.const(target_end)
end = _op.scatter(
end,
_op.expand_dims(_expr.const(dim), axis=0),
_op.expand_dims(target_end, axis=0),
axis=0,
)
else:
end = _op.cast(_op.shape_of(data), axis_dtype)
if not isinstance(target_end, tvm.tir.Any):
ttype = self.infer_type(target_end).dtype
if str(ttype) != axis_dtype:
target_end = _op.cast(target_end, axis_dtype)
end = _op.scatter(
end,
_op.expand_dims(_expr.const(dim), axis=0),
_op.expand_dims(target_end, axis=0),
axis=0,
)
if not isinstance(end, list):
etype = self.infer_type(end).dtype
if str(etype) != axis_dtype:
end = _op.cast(end, axis_dtype)
strides = [1] * ndim
strides[dim] = stride
return _op.transform.strided_slice(
data, begin=begin, end=end, strides=strides, slice_mode="end"
)
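    # Hedged example: aten::slice(x, dim=1, start=2, end=sys.maxsize, step=1)
    # on a statically shaped (4, 10) tensor becomes strided_slice with
    # begin=[0, 2], end=[4, 10], strides=[1, 1]; the sys.maxsize sentinel
    # meaning "to the end" is clamped to the dimension size above.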
def narrow(self, inputs, input_types):
        # Inputs are:
        # 0 - the tensor to narrow
        # 1 - the dimension along which to narrow
        # 2 - the starting index
        # 3 - the length of the slice
        # Compute the ending index
end = self.add(inputs[2:4], input_types[2:4])
stride = 1
slice_input = inputs[:3] + [end, stride]
slice_types = input_types + ["int32"]
return self.slice(slice_input, slice_types)
def split(self, inputs, input_types):
data = inputs[0]
split_size = int(inputs[1])
dim = int(inputs[2])
split_index = split_size
indices = []
while split_index < self.infer_shape(data)[dim]:
indices.append(split_index)
split_index += split_size
return _op.split(data, indices, dim)
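    # Worked example: split_size=3 on a dimension of size 10 yields
    # indices [3, 6, 9], so _op.split returns chunks of sizes 3, 3, 3
    # and 1, matching torch.split's smaller trailing chunk.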
def split_with_sizes(self, inputs, input_types):
data = inputs[0]
sections = inputs[1]
dim = int(inputs[2])
if len(sections) == 1:
# a special case used in torchvision detection models
return _expr.TupleWrapper(_expr.Tuple([data]), 1)
split_index = 0
indices = []
for i in range(len(sections) - 1):
index, _ = try_infer_value(sections[i], lambda ret: int(ret))
split_index += index
indices.append(split_index)
return _op.split(data, indices, dim)
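    # Worked example: sections=[2, 3, 5] on a dimension of size 10 gives
    # cumulative indices [2, 5]; _op.split then produces chunks of sizes
    # 2, 3 and 5, the last section being implied by the tensor shape.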
def select(self, inputs, input_types):
data = inputs[0]
dim = int(inputs[1])
index = _wrap_const(inputs[2])
return _op.transform.take(data, index, axis=dim, mode="wrap")
def take(self, inputs, input_types):
data = inputs[0]
indices = _op.cast(inputs[1], "int32")
return _op.transform.take(data, indices=indices, mode="wrap")
def topk(self, inputs, input_types):
data = inputs[0]
axis = int(inputs[2])
is_ascend = not bool(inputs[3])
sort = bool(inputs[4])
if isinstance(inputs[1], _expr.Expr):
k, _ = try_infer_value(inputs[1], lambda ret: ret.tolist())
else:
k = inputs[1]
if not sort:
msg = "Currently supports only sorted output for topk operator."
raise AssertionError(msg)
outs = _op.topk(data, k=k, axis=axis, is_ascend=is_ascend, ret_type="both", dtype="int64")
return outs[0], outs[1]
def reciprocal(self, inputs, input_types):
data = inputs[0]
return _expr.const(1.0, dtype=input_types[0]) / data
def repeat(self, inputs, input_types):
data = inputs[0]
reps = []
for r in inputs[1]:
if isinstance(r, int):
reps.append(r)
else:
reps.append(int(_infer_value(r, {}).asnumpy()))
return _op.transform.tile(data, reps=reps)
def repeat_interleave(self, inputs, input_types):
data = inputs[0]
if isinstance(inputs[1], int):
repeats = inputs[1]
axis = inputs[2]
else:
msg = "Only repeat with one value as repeat is currently supported."
raise AssertionError(msg)
if axis is None: # Flatten the data if no axis is given from torch
data = _op.transform.reshape(data, [-1])
axis = 0
return _op.transform.repeat(data, repeats=repeats, axis=axis)
def addcdiv(self, inputs, input_types):
data, t1, t2, c = self.pytorch_promote_types(inputs[:4], input_types[:4])
return data + (c * (t1 / t2))
def addcmul(self, inputs, input_types):
data, t1, t2, c = self.pytorch_promote_types(inputs[:4], input_types[:4])
return data + (c * (t1 * t2))
def where(self, inputs, input_types):
if len(inputs) == 1:
return self.nonzero([inputs[0], True], input_types)
cond = inputs[0]
x, y = self.pytorch_promote_types(inputs[1:3], input_types[1:3])
return _op.where(cond, x, y)
def full_impl(self, data, fill_value, dtype):
size = []
need_reshape = False
new_shape = []
for dim in data:
if isinstance(dim, _expr.Expr):
if isinstance(dim, _expr.Constant):
dim = int(dim.data.asnumpy())
if isinstance(size, list):
size.append(dim)
new_shape.append(dim)
else:
dim, success = try_infer_value(dim, lambda ret: int(ret), lambda: 0)
new_shape.append(dim)
if success:
if isinstance(size, list):
size.append(dim)
else:
size = None
need_reshape = True
else:
if isinstance(size, list):
size.append(dim)
new_shape.append(dim)
if size is None:
tmp = []
for dim in data:
tmp.append(_op.cast(_op.expand_dims(dim, axis=0), "int64"))
size = _op.concatenate(tmp, axis=0)
out = _op.full(_expr.const(fill_value), size, dtype=dtype)
if need_reshape:
out = _op.reshape(out, new_shape)
return out
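    # Hedged note on full_impl: when every entry of `data` folds to a
    # constant, `size` stays a Python list and _op.full gets a static
    # shape; otherwise `size` is rebuilt as a concatenated int64 tensor
    # and the result is reshaped to `new_shape` (where a 0 entry tells
    # Relay's reshape to keep that dimension of the input as-is).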
def ones(self, inputs, input_types):
data = inputs[0]
import torch
if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)):
msg = "Data type %s could not be parsed in ones op" % (type(data))
raise AssertionError(msg)
if inputs[1] is not None:
dtype = _convert_dtype_value(inputs[1])
else:
dtype = self.default_dtype
return self.full_impl(data, 1, dtype)
def ones_like(self, inputs, input_types):
data = inputs[0]
out = _op.ones_like(data)
        # If the input and output datatypes differ, cast
if inputs[1] is not None:
dtype = _convert_dtype_value(inputs[1])
else:
dtype = self.default_dtype
if input_types[0] != dtype:
out = _op.cast(out, dtype)
return out
def zeros(self, inputs, input_types):
data = inputs[0]
import torch
if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)):
msg = "Data type %s could not be parsed in zeros op" % (type(data))
raise AssertionError(msg)
if inputs[1] is not None:
dtype = _convert_dtype_value(inputs[1])
else:
dtype = self.default_dtype
return self.full_impl(data, 0, dtype)
def zeros_like(self, inputs, input_types):
data = inputs[0]
out = _op.zeros_like(data)
        # If the input and output datatypes differ, cast
        if inputs[1] is not None:
            dtype = _convert_dtype_value(inputs[1])
        else:
            dtype = self.default_dtype
        if input_types[0] != dtype:
out = _op.cast(out, dtype)
return out
def full(self, inputs, input_types):
data = inputs[0]
fill_value = inputs[1]
import torch
if not isinstance(data, (_expr.Expr, list, torch.Tensor, np.ndarray)):
msg = "Data type %s could not be parsed in full op" % (type(data))
raise AssertionError(msg)
if inputs[2] is not None: # dtype given
dtype = _convert_dtype_value(inputs[2])
else:
# if dtype is None, torch uses a global default set by torch.set_default_tensor_type()
dtype = self.default_dtype
return self.full_impl(data, fill_value, dtype)
def full_like(self, inputs, input_types):
data = inputs[0]
fill_value = inputs[1]
out = _op.full_like(data, _expr.const(fill_value))
        # If the input and output datatypes differ, cast
        if inputs[2] is not None:  # dtype given
            dtype = _convert_dtype_value(inputs[2])
        else:
            # if dtype is None, torch uses a global default set by torch.set_default_tensor_type()
            dtype = self.default_dtype
        if input_types[0] != dtype:
out = _op.cast(out, dtype)
return out
def linspace(self, inputs, input_types):
start = inputs[0]
stop = inputs[1]
step = inputs[2]
# Find the spacing between values as step
if step != 1:
step = (stop - start) / (step - 1)
stop = stop + step
else:
stop = start + step
dtype = "float32" if inputs[3] is not None else _convert_dtype_value(inputs[3])
start = _create_typed_const(start, dtype)
stop = _create_typed_const(stop, dtype)
step = _create_typed_const(step, dtype)
return _op.transform.arange(start=start, stop=stop, step=step, dtype=dtype)
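    # Worked example: torch.linspace(0, 10, steps=5) arrives as start=0,
    # stop=10, step=5; the spacing becomes (10 - 0) / (5 - 1) = 2.5 and
    # stop is padded to 12.5, so arange(0, 12.5, 2.5) yields exactly
    # [0, 2.5, 5, 7.5, 10].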
def relu(self, inputs, input_types):
data = inputs[0]
if self.is_quantized_tensor(data):
assert len(inputs) == 3, "Input quant param not found in op inputs"
input_zero_point = _expr.const(inputs[2], dtype="int32")
return qnn_torch.quantized_relu(data, input_zero_point)
return _op.nn.relu(data)
def prelu(self, inputs, input_types):
data = inputs[0]
alpha = inputs[1]
return _op.nn.prelu(data, alpha)
def leaky_relu(self, inputs, input_types):
data = inputs[0]
alpha = float(inputs[1])
return _op.nn.leaky_relu(data, alpha)
def elu(self, inputs, input_types):
data = inputs[0]
dtype = input_types[0]
        alpha = _expr.const(-float(inputs[1]), dtype=dtype)
return alpha * _op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data)
def celu(self, inputs, input_types):
data = inputs[0]
dtype = input_types[0]
alpha = _expr.const(float(inputs[1]), dtype=dtype)
        # CELU(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1)),
        # i.e. relu(x) - alpha * relu(1 - exp(x / alpha))
        return _op.nn.relu(data) - alpha * _op.nn.relu(
            _expr.const(1, dtype=dtype) - _op.exp(data / alpha)
        )
def gelu(self, inputs, input_types):
data = inputs[0]
dtype = input_types[0]
# gelu is data * normcdf(data)
# normcdf expressed as erf because we don't currently have that intrinsic
# note that there is also a fastgelu variant approximating normcdf
# with tanh and third order polynomials, but this is "true" gelu
return data * (
_expr.const(0.5, dtype=dtype)
+ _op.erf(data * _expr.const(0.5 ** 0.5, dtype=dtype)) * _expr.const(0.5, dtype=dtype)
)
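    # Worked check of the formula above:
    # GELU(x) = x * Phi(x) = x * (0.5 + 0.5 * erf(x / sqrt(2))), with
    # 0.5 ** 0.5 = 1 / sqrt(2); e.g. GELU(1.0) ~= 1.0 * (0.5 + 0.5 * 0.6827)
    # ~= 0.8413.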
def selu(self, inputs, input_types):
data = inputs[0]
# https://pytorch.org/docs/stable/nn.html#selu
dtype = input_types[0]
alpha = _expr.const(-1.6732632423543772848170429916717, dtype=dtype)
gamma = _expr.const(1.0507009873554804934193349852946, dtype=dtype)
return gamma * (
alpha * _op.nn.relu(_expr.const(1.0, dtype=dtype) - _op.exp(data)) + _op.nn.relu(data)
)
def log_sigmoid(self, inputs, input_types):
data = inputs[0]
return _op.log(_op.tensor.sigmoid(data))
def hard_sigmoid(self, inputs, input_types):
def _relu6(x):
return _op.tensor.clip(x, 0.0, 6.0)
def func(x):
return _relu6(x + _expr.const(3.0)) / _expr.const(6.0)
if self.is_quantized_tensor(inputs[0]):
input_scale = _expr.const(inputs[1])
input_zero_point = _expr.const(inputs[2])
# PyTorch seems to use the following output qparams, but accuracy
# is broken if we use this.
# TODO(masahi): Revisit this parameter choice
#
# Taken from src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
# output_scale = _expr.const(0.00390625) # 1.0 / 2^8
# output_zero_point = _expr.const(-128)
output_scale = input_scale
output_zero_point = input_zero_point
data = qnn.op.dequantize(inputs[0], input_scale, input_zero_point, axis=1)
out = func(data)
return qnn.op.quantize(out, output_scale, output_zero_point, out_dtype="uint8")
return func(inputs[0])
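    # Worked example: hardsigmoid(x) = relu6(x + 3) / 6, so x = -3 -> 0,
    # x = 0 -> 0.5, x = 3 -> 1; quantized inputs are dequantized, passed
    # through the same float formula, then requantized as above.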
def hard_swish(self, inputs, input_types):
data = inputs[0]
return data * self.hard_sigmoid(inputs, input_types)
def adaptive_avg_pool_2d(self, inputs, input_types):
data = inputs[0]
output_size = inputs[1]
def func(x):
return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
if self.is_quantized_tensor(data):
return qnn_torch.apply_with_upcast(data, func)
return func(data)
def adaptive_max_pool_2d(self, inputs, input_types):
data = inputs[0]
output_size = inputs[1]
# returns dummy indices too
return _op.nn.adaptive_max_pool2d(data, output_size=output_size), None
def adaptive_max_pool_3d(self, inputs, input_types):
data = inputs[0]
output_size = inputs[1]
# returns dummy indices too
return _op.nn.adaptive_max_pool3d(data, output_size=output_size), None
def adaptive_avg_pool_3d(self, inputs, input_types):
data = inputs[0]
output_size = inputs[1]
return _op.nn.adaptive_avg_pool3d(data, output_size=output_size)
@staticmethod
def convert_const_list(data):
if isinstance(data, list):
for i, _ in enumerate(data):
if isinstance(data[i], _expr.Expr):
data[i] = int(_infer_value_simulated(data[i], {}).asnumpy())
return data
def maxpool_2d(self, inputs, input_types):
data = inputs[0]
pool_size = self.convert_const_list(inputs[1])
strides = self.convert_const_list(inputs[2] if inputs[2] else pool_size)
padding = inputs[3]
dilation = inputs[4]
ceil_mode = int(inputs[5])
return _op.nn.max_pool2d(
data,
pool_size=pool_size,
strides=strides,
dilation=dilation,
padding=padding,
layout="NCHW",
ceil_mode=ceil_mode,
)
def maxpool_2d_with_indices(self, inputs, input_types):
# returns dummy indices too
return self.maxpool_2d(inputs, input_types), None
def maxpool_1d(self, inputs, input_types):
data = inputs[0]
pool_size = inputs[1]
strides = inputs[2] if inputs[2] else pool_size
padding = inputs[3]
dilation = inputs[4]
ceil_mode = int(inputs[5])
return _op.nn.max_pool1d(
data,
pool_size=pool_size,
strides=strides,
dilation=dilation,
padding=padding,
layout="NCW",
ceil_mode=ceil_mode,
)
def maxpool_3d(self, inputs, input_types):
data = inputs[0]
pool_size = inputs[1]
strides = inputs[2] if inputs[2] else pool_size
padding = inputs[3]
dilation = inputs[4]
ceil_mode = int(inputs[5])
return _op.nn.max_pool3d(
data,
pool_size=pool_size,
strides=strides,
dilation=dilation,
padding=padding,
ceil_mode=ceil_mode,
)
def hardtanh(self, inputs, input_types):
a = inputs[0]
tanh_min = float(inputs[1])
tanh_max = float(inputs[2])
return _op.tensor.clip(a, tanh_min, tanh_max)
def convolution(self, inputs, input_types):
        # Use transposed convolution when the Torch op requests it
        use_transpose = inputs[6] == 1
data = inputs[0]
weight = inputs[1]
bias = inputs[2]
strides = tuple(inputs[3])
padding = tuple(inputs[4])
dilation = tuple(inputs[5])
if isinstance(weight, _expr.Expr):
inferred_shape = self.infer_shape(weight)
weight_shape = []
for infer in inferred_shape:
weight_shape.append(infer)
else:
msg = "Data type %s could not be parsed in conv op" % (type(weight))
raise AssertionError(msg)
# Transposed convolutions have IOHW layout.
if use_transpose:
weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]
channels = weight_shape[0]
groups = int(inputs[8])
# Check if this is depth wise convolution
# We need to reshape weight so that Relay could recognize this is depth wise
# weight_shape[1] is always in_channels // groups
# For depthwise, in_channels == groups, so weight_shape[1] == 1
# If groups > 1 but weight_shape[1] != 1, this is group convolution
if groups > 1 and weight_shape[1] == 1:
channel_multiplier = channels // groups
new_weight_shape = (groups, channel_multiplier) + tuple(weight_shape[2:])
weight = _op.transform.reshape(weight, new_weight_shape)
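            # Hedged example of the reshape above: a depthwise conv with
            # in_channels = groups = 32 and channel multiplier 2 has a Torch
            # weight of shape (64, 1, kH, kW); channels = 64, so it becomes
            # (32, 2, kH, kW), the layout Relay expects for depthwise conv.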
kernel_size = weight_shape[2:]
use_bias = isinstance(bias, _expr.Expr)
if len(kernel_size) == 1:
strides = (1,) + strides
padding = (0,) + padding
dilation = (1,) + dilation
if use_transpose:
if len(kernel_size) == 3:
conv_op = _op.nn.conv3d_transpose
else:
conv_op = _op.nn.conv2d_transpose
else:
if len(kernel_size) == 3:
conv_op = _op.nn.conv3d
else:
conv_op = _op.nn.conv2d
if len(kernel_size) == 3:
data_layout = "NCDHW"
kernel_layout = "OIDHW"
else:
data_layout = "NCHW"