forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
symbolic_opset9.py
2294 lines (1811 loc) · 91.3 KB
/
symbolic_opset9.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import torch
from torch._C import ListType, OptionalType
from torch.nn.modules.utils import _single, _pair, _triple
import torch.onnx
# This import monkey-patches graph manipulation methods on Graph, used for the
# ONNX symbolics
import torch.onnx.utils
from functools import partial
from functools import wraps
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _parse_arg, _unimplemented
import numpy
import math
import warnings
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
# This file exports ONNX ops for opset 9
# Opset 9 is supported by ONNX release 1.4.1
# release on 01/23/19
# Note [Pointwise by scalar]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
# What happens if you add a tensor with a constant (e.g., x + 2)? There are
# some moving parts to implementing the ONNX translation in this case:
#
# - By the time we get the scalar in a symbolic function here, it is no longer
# a Python long/float, but a PyTorch tensor with numel == 1 (eventually, we
# want it to be a zero dim tensor but this change has not happened yet.)
# However, the type of this scalar is *exactly* what the user wrote in
# Python, which may not match the tensor it is being added to. PyTorch
# will do implicit conversions on scalars; however, ONNX will not, so
# we must do the conversion ourselves. This is what _if_scalar_type_as
# does.
#
# - Dispatch to these functions takes advantage of an outrageous coincidence
# between the tensor and scalar name. When we add two tensors together,
# you get the dispatch:
#
# add(*[self, other], **{"alpha": alpha})
#
# When you add a tensor and a scalar, you get the dispatch:
#
# add(*[self], **{"other": other, "alpha": alpha})
#
# By having the argument name line up with the name of the scalar attribute
# if it exists, we can write a single function for both overloads.
#
# used to represent "missing" optional inputs
def unused(g):
    """Return a ``prim::Constant`` node whose type is set to an optional tensor."""
    placeholder = g.op("prim::Constant")
    placeholder.setType(OptionalType.ofTensor())
    return placeholder
def _shape_as_tensor(g, input):
    """Emit a ``Shape`` node for ``input``."""
    shape_node = g.op('Shape', input)
    return shape_node
def _reshape_from_tensor(g, input, shape):
    """Emit a ``Reshape`` node taking ``input`` and the ``shape`` value as operands."""
    reshaped = g.op('Reshape', input, shape)
    return reshaped
def reshape(g, self, shape):
    """Export ``reshape`` by delegating to the ``view`` symbolic."""
    result = view(g, self, shape)
    return result
def reshape_as(g, self, other):
    """Reshape ``self`` to the shape of ``other``, read through a ``Shape`` node."""
    return reshape(g, self, g.op('Shape', other))
def add(g, self, other, alpha=None):
    """Export ``add`` as an ONNX ``Add`` node.

    ``alpha`` defaults to ``None`` so that the add overload without an alpha
    argument can dispatch here as well. A truthy ``alpha`` whose scalar value
    is not 1 is reported through ``_unimplemented``.
    """
    if alpha:
        alpha_value = sym_help._scalar(sym_help._maybe_get_scalar(alpha))
        if alpha_value != 1:
            return _unimplemented("add", "alpha != 1")
    return g.op("Add", self, other)
def sub(g, self, other, alpha=None):
    """Export ``sub`` as an ONNX ``Sub`` node.

    ``alpha`` defaults to ``None`` so that the sub overload without an alpha
    argument can dispatch here as well. A truthy ``alpha`` whose scalar value
    is not 1 is reported through ``_unimplemented``.
    """
    if alpha:
        alpha_value = sym_help._scalar(sym_help._maybe_get_scalar(alpha))
        if alpha_value != 1:
            return _unimplemented("sub", "alpha != 1")
    return g.op("Sub", self, other)
def rsub(g, self, other, alpha=None):
    """Export reversed subtraction: ``sub`` called with the operands swapped."""
    minuend, subtrahend = other, self
    return sub(g, minuend, subtrahend, alpha=alpha)
def mul(g, self, other):
    """Emit an ONNX ``Mul`` node for ``self`` and ``other``."""
    product = g.op("Mul", self, other)
    return product
def div(g, self, other):
    """Emit an ONNX ``Div`` node for ``self`` and ``other``."""
    quotient = g.op("Div", self, other)
    return quotient
def floor_divide(g, self, other):
    """Export ``floor_divide`` as Div -> Cast(Long) -> Cast(result type).

    The required rounding is truncation; ``Floor`` is not used because it
    behaves differently for negative numbers (e.g. -0.1 should become -0).
    The quotient is therefore cast to Long and then cast to the output type.
    NOTE(review): this relies on the Long cast truncating toward zero.

    Output type selection, matching PyTorch:
      - ``self`` is floating point: ``self``'s type
      - ``self`` is not floating point and ``other`` is: 'Float'
      - neither is floating point: ``self``'s type
      - ``self``'s scalar type is unknown: 'Float'
    """
    quotient = div(g, self, other)
    truncated = g.op("Cast", quotient, to_i=sym_help.cast_pytorch_to_onnx['Long'])

    self_type = self.type().scalarType()
    if self_type is None:
        target_type = 'Float'
    elif (not sym_help._is_fp(self)
            and other.type().scalarType() is not None
            and sym_help._is_fp(other)):
        target_type = 'Float'
    else:
        target_type = self_type
    return g.op("Cast", truncated, to_i=sym_help.cast_pytorch_to_onnx[target_type])
# Division where both inputs are cast to floating types
# If both inputs are floating, performs div as usual
# If only one input is a floating type, the other input is cast to its type
# If neither input is a floating type, both inputs are cast to the default scalar type
def true_divide(g, self, other):
    """Export ``true_divide``: division with both operands in a floating type.

    - Both operands floating point: plain ``Div``.
    - Only ``self`` floating point: ``other`` is cast to ``self``'s type.
    - Only ``other`` floating point: ``self`` is cast to ``other``'s type.
    - Neither floating point: both are cast to the default scalar type
      (Float or Double, taken from ``torch.get_default_dtype()``).

    The value returned by each ``Cast`` node is what gets fed to ``Div``;
    creating the node without using its output would leave the operand
    uncast in the exported graph.
    """
    # Case 1: both values are floating, perform div as usual
    if sym_help._is_fp(self) and sym_help._is_fp(other):
        return div(g, self, other)

    # Case 2: self is floating, other is not -> cast other to self's dtype
    if sym_help._is_fp(self):
        other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx[self.type().scalarType()])
        return div(g, self, other)

    # Case 3: other is floating, self is not -> cast self to other's dtype
    if sym_help._is_fp(other):
        self = g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[other.type().scalarType()])
        return div(g, self, other)

    # Case 4: neither is floating -> cast both to the default scalar type
    scalar_type = torch.get_default_dtype()
    onnx_scalar_type = sym_help.cast_pytorch_to_onnx['Float']
    assert scalar_type is torch.float or scalar_type is torch.double
    if scalar_type is torch.double:
        onnx_scalar_type = sym_help.cast_pytorch_to_onnx['Double']

    self = g.op("Cast", self, to_i=onnx_scalar_type)
    other = g.op("Cast", other, to_i=onnx_scalar_type)
    return div(g, self, other)
def reciprocal(g, self):
    """Export ``reciprocal`` as a ``Div`` of ``torch.ones(1)`` by ``self``."""
    numerator = torch.ones(1)
    return g.op("Div", numerator, self)
@parse_args('v', 'i')
def cat(g, tensor_list, dim):
    """Export ``cat`` as a ``Concat`` of the unpacked list along axis ``dim``."""
    return g.op("Concat", *sym_help._unpack_list(tensor_list), axis_i=dim)
@parse_args('v', 'i')
def stack(g, tensor_list, dim):
    """Export ``stack``: ``Unsqueeze`` every element at ``dim``, then ``Concat`` along ``dim``."""
    expanded = []
    for tensor in sym_help._unpack_list(tensor_list):
        expanded.append(g.op("Unsqueeze", tensor, axes_i=[dim]))
    return g.op("Concat", *expanded, axis_i=dim)
def mm(g, self, other):
    """Export ``mm`` as a ``Gemm`` node with ``alpha=1`` and ``beta=0``."""
    # Gemm takes a third operand C. A dummy constant is supplied only to
    # satisfy the API; its value is ignored since beta = 0.
    dummy_c = g.op("Constant", value_t=torch.tensor([1]))
    return g.op("Gemm", self, other, dummy_c, beta_f=0.0, alpha_f=1.0)
def bmm(g, self, other):
    """Export ``bmm`` as an ONNX ``MatMul`` node."""
    product = g.op("MatMul", self, other)
    return product
def matmul(g, self, other):
    """Export ``matmul`` as an ONNX ``MatMul`` node."""
    product = g.op("MatMul", self, other)
    return product
@parse_args('v', 'v', 'v', 't', 't')
def addmm(g, self, mat1, mat2, beta, alpha):
    """Export ``addmm`` as ``Gemm(mat1, mat2, self)`` with scalar ``beta``/``alpha`` attributes."""
    beta_value = sym_help._scalar(beta)
    alpha_value = sym_help._scalar(alpha)
    return g.op("Gemm", mat1, mat2, self, beta_f=beta_value, alpha_f=alpha_value)
def neg(g, self):
    """Emit an ONNX ``Neg`` node for ``self``."""
    negated = g.op("Neg", self)
    return negated
def sqrt(g, self):
    """Emit an ONNX ``Sqrt`` node for ``self``."""
    root = g.op("Sqrt", self)
    return root
def rsqrt(g, self):
    """Export ``rsqrt`` as one divided by ``sqrt(self)``.

    The numerator ``torch.ones(1)`` is passed through
    ``sym_help._if_scalar_type_as`` together with ``self`` before dividing.
    """
    numerator = sym_help._if_scalar_type_as(g, torch.ones(1), self)
    denominator = sqrt(g, self)
    return div(g, numerator, denominator)
def tanh(g, self):
    """Emit an ONNX ``Tanh`` node for ``self``."""
    node = g.op("Tanh", self)
    return node
def sin(g, self):
    """Emit an ONNX ``Sin`` node for ``self``."""
    node = g.op("Sin", self)
    return node
def cos(g, self):
    """Emit an ONNX ``Cos`` node for ``self``."""
    node = g.op("Cos", self)
    return node
def tan(g, self):
    """Emit an ONNX ``Tan`` node for ``self``."""
    node = g.op("Tan", self)
    return node
def asin(g, self):
    """Emit an ONNX ``Asin`` node for ``self``."""
    node = g.op("Asin", self)
    return node
def acos(g, self):
    """Emit an ONNX ``Acos`` node for ``self``."""
    node = g.op("Acos", self)
    return node
def atan(g, self):
    """Emit an ONNX ``Atan`` node for ``self``."""
    node = g.op("Atan", self)
    return node
def sigmoid(g, self):
    """Emit an ONNX ``Sigmoid`` node for ``self``."""
    node = g.op("Sigmoid", self)
    return node
def sign(g, self):
    """Emit an ONNX ``Sign`` node for ``self``."""
    node = g.op("Sign", self)
    return node
def _slice(g, input, axes, starts, ends):
assert len(starts) == len(ends)
if len(starts) == 1 and starts[0] == 0 and ends[0] == 9223372036854775807:
return input
return g.op("Slice", input, axes_i=axes, starts_i=starts, ends_i=ends)
def _reduce_op_symbolic(onnx_op_name, allow_multi_dim_support=True):
    """Build a symbolic function that emits the reduction ``onnx_op_name``.

    The returned function reduces over everything (``keepdims_i=0``) when
    ``dim`` is None. Otherwise ``dim`` and ``keepdim`` are resolved to
    constants; ``dim`` is read as an int list when
    ``allow_multi_dim_support`` is true and as a single int otherwise.
    """
    def symbolic(g, self, dim=None, keepdim=None):
        # all-reduce path
        if dim is None:
            return g.op(onnx_op_name, self, keepdims_i=0)

        # dim-reduce path
        dim_desc = 'is' if allow_multi_dim_support else 'i'
        dim = sym_help._get_const(dim, dim_desc, 'dim')
        keepdim = sym_help._get_const(keepdim, 'i', 'keepdim')
        axes = dim if allow_multi_dim_support else [dim]
        return g.op(onnx_op_name, self, axes_i=axes, keepdims_i=keepdim)
    return symbolic
def overload_by_arg_count(fn):
    """Decorator that dispatches to an overload by positional-argument count.

    ``fn(g, *args)`` must return an iterable of candidate callables, each
    carrying an ``_arg_descriptors`` attribute. The first candidate whose
    ``_arg_descriptors`` has the same length as ``args`` is called with
    ``(g, *args)`` and its result is returned.

    Raises:
        NotImplementedError: if no candidate matches the argument count.
    """
    @wraps(fn)
    def wrapper(g, *args):
        for overload in fn(g, *args):
            if len(overload._arg_descriptors) == len(args):
                return overload(g, *args)
        raise NotImplementedError("Unknown aten::{} signature".format(fn.__name__))
    return wrapper
def _reduce_with_dtype(onnx_op, name, allow_multi_dim_support=True):
    """Build a reduction symbolic that also accepts a trailing ``dtype`` argument.

    Two overloads are offered and selected by argument count through
    ``overload_by_arg_count``: ``(self, dtype)`` and
    ``(self, dim, keepdim, dtype)``. In both, a ``dtype`` whose node is not a
    ``prim::Constant`` is reported through ``_unimplemented``.
    """
    symbolic = _reduce_op_symbolic(onnx_op, allow_multi_dim_support=allow_multi_dim_support)
    dim_desc = 'is' if allow_multi_dim_support else 'i'

    @overload_by_arg_count
    def reduce(g, *args, **kwargs):
        @parse_args('v', 'none')
        def reduce_nodim(g, self, dtype):
            if dtype.node().kind() == 'prim::Constant':
                return symbolic(g, self)
            return _unimplemented(name, "dtype")

        @parse_args('v', dim_desc, 'i', 'none')
        def reduce_dim(g, self, dim, keepdim, dtype):
            if dtype.node().kind() == 'prim::Constant':
                return symbolic(g, self, dim, keepdim)
            return _unimplemented(name, "dtype")

        return reduce_nodim, reduce_dim
    return reduce
# Reduction symbolics generated by _reduce_with_dtype, one per ONNX reduce op.
sum = _reduce_with_dtype('ReduceSum', 'sum')
mean = _reduce_with_dtype('ReduceMean', 'mean')
prod = _reduce_with_dtype('ReduceProd', 'prod', allow_multi_dim_support=False) # torch.prod does not support multidimensional 'dim'
@parse_args('v', 'i', 'none')
def cumsum(g, input, dim, dtype):
    """Export ``cumsum`` as an ``ATen`` fallback node with ``operator_s="cumsum"``.

    A ``dtype`` whose node is not a ``prim::Constant`` is reported through
    ``_unimplemented``.
    """
    if dtype.node().kind() != 'prim::Constant':
        # The op name is passed literally; there is no ``name`` variable in
        # this scope.
        return _unimplemented("cumsum", "dtype")
    return g.op("ATen", input, operator_s="cumsum", dim_i=dim)
def _sample_dirichlet(g, self, generator):
    """Export ``_sample_dirichlet`` as an ``ATen`` node; a non-None generator is unsupported."""
    if sym_help._is_none(generator):
        return g.op("ATen", self, operator_s="_sample_dirichlet")
    return _unimplemented('_sample_dirichlet',
                          'We are not able to export generator')
def _standard_gamma(g, self, generator):
    """Export ``_standard_gamma`` as an ``ATen`` node; a non-None generator is unsupported."""
    if sym_help._is_none(generator):
        return g.op("ATen", self, operator_s="_standard_gamma")
    return _unimplemented('_standard_gamma',
                          'We are not able to export generator')
def t(g, self):
    """Export ``t`` as a ``Transpose`` with permutation ``(1, 0)``."""
    swapped_axes = (1, 0)
    return g.op("Transpose", self, perm_i=swapped_axes)
def expand(g, self, size, implicit):
    """Export ``expand`` as an ONNX ``Expand`` node.

    A constant ``size`` becomes a ``Constant`` int64 tensor. A packed list of
    values is stacked into a 1-D tensor and every -1 entry is replaced by 1.
    Any other value is passed to ``Expand`` as is.
    """
    size = sym_help._maybe_get_const(size, 'is')
    if not sym_help._is_value(size):
        size = g.op("Constant", value_t=torch.LongTensor(size))
    elif sym_help._is_packed_list(size):
        # Expand with -1 dim value means dim is unchanged.
        # Since onnx::expand supports two-way broadcasting,
        # -1 dim value can be exported to onnx as 1
        size = view(g, stack(g, size, 0), [-1])
        int64_dtype = 4  # dim type is int64
        ones = ones_like(g, size, int64_dtype)
        minus_one = g.op("Constant", value_t=torch.tensor(-1))
        neg_ones = mul(g, ones, minus_one)
        is_minus_one = g.op("Equal", size, neg_ones)
        size = where(g, is_minus_one, ones, size)
    return g.op("Expand", self, size)
def expand_as(g, self, other):
    """Expand ``self`` to the shape of ``other``, read through a ``Shape`` node."""
    return g.op("Expand", self, g.op("Shape", other))
def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse):
    """Export ``embedding`` as ``Gather(weight, indices)``.

    ``padding_idx``, ``scale_grad_by_freq`` and ``sparse`` are not used.
    """
    gathered = g.op("Gather", weight, indices)
    return gathered
@parse_args('v', 'v', 'v', 'i', 'i', 'i', 'v', 'i')
def embedding_bag(g,
                  embedding_matrix,
                  indices,
                  offsets,
                  scale_grad_by_freq,
                  mode,
                  sparse,
                  per_sample_weights,
                  include_last_offset):
    """Export ``embedding_bag`` as an ``ATen`` fallback node with four outputs.

    Raises:
        RuntimeError: if ``per_sample_weights`` is not None.
    """
    if not sym_help._is_none(per_sample_weights):
        raise RuntimeError('Unsupported: ONNX export of embedding_bag '
                           'with per_sample_weights')

    attributes = {
        'operator_s': "embedding_bag",
        'outputs': 4,
        'scale_grad_by_freq_i': scale_grad_by_freq,
        'mode_i': mode,
        'sparse_i': sparse,
        'include_last_offset_i': include_last_offset,
    }
    return g.op("ATen", embedding_matrix, indices, offsets, **attributes)
def size(g, self, dim):
    """Export ``size(dim)`` through ``sym_help._size_helper``.

    When ``dim`` is negative and the rank of ``self`` is known, ``dim`` is
    replaced by a ``Constant`` holding ``dim + rank``.
    """
    dim_value = sym_help._maybe_get_const(dim, 'i')
    if dim_value < 0:
        rank = self.type().dim()
        if rank:
            normalized = sym_help._maybe_get_const(dim, 'i') + rank
            dim = g.op("Constant", value_t=torch.tensor(normalized))
    return sym_help._size_helper(g, self, dim)
@parse_args('v', 'i', 'i')
def transpose(g, self, dim0, dim1):
    """Export ``transpose`` as an ONNX ``Transpose`` (which is a permute).

    Returns ``self`` untouched when both dims are equal. When ``self`` is not
    a complete tensor the permutation cannot be built, so an ``ATen`` node is
    emitted instead.
    """
    if dim0 == dim1:  # micro-optimization
        return self

    if not self.isCompleteTensor():
        # if we don't have dim information we cannot
        # output a permute so use ATen instead
        return g.op("ATen", self, operator_s="transpose", dim0_i=dim0, dim1_i=dim1)

    # NB: Transpose in ONNX is actually a Permute
    perm = list(range(self.type().dim()))
    perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
    return g.op("Transpose", self, perm_i=perm)
@parse_args('v', 'is')
def permute(g, self, dims):
    """Export ``permute`` as ``Transpose``; the identity permutation returns ``self``."""
    identity = list(range(0, len(dims)))
    if dims == identity:
        return self
    return g.op("Transpose", self, perm_i=dims)
def view(g, self, size):
    """Export ``view`` as ``Reshape``, or ``Flatten`` for one 2-D special case.

    A non-constant ``size`` is passed to ``Reshape`` directly. For a constant
    two-element ``size`` whose first entry equals the first dimension of a
    complete-tensor ``self``, ``Flatten(axis=1)`` is emitted. Otherwise the
    constant size becomes an int64 ``Constant`` fed to ``Reshape``.
    """
    size = sym_help._maybe_get_const(size, 'is')
    if sym_help._is_value(size):
        return g.op("Reshape", self, size)

    if self.isCompleteTensor():
        self_sizes = self.type().sizes()
        if self_sizes and len(size) == 2 and self_sizes[0] == size[0]:
            return g.op("Flatten", self, axis_i=1)

    shape = g.op("Constant", value_t=torch.LongTensor(size))
    return g.op("Reshape", self, shape)
def prim_ConstantSplit(g, self, split_size, dim):
    """Export a constant split as ``Split`` with explicit chunk lengths.

    The length of ``self`` along ``dim`` is divided into chunks of
    ``split_size``, plus one shorter trailing chunk for any remainder.
    """
    extent = self.type().sizes()[dim]
    full_chunks, remainder = divmod(extent, split_size)
    splits = [split_size] * full_chunks
    if remainder:
        splits.append(remainder)
    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=len(splits))
# TODO: It would be better to export this as a chunk directly, as this is
# less sensitive to changes in input size.
# TODO: Once we have proper scoping, stop reimplementing chunk, delete this
# method, and use the desugared version
def prim_ConstantChunk(g, self, chunks, dim):
    """Export a constant chunk by converting ``chunks`` to a split size.

    The split size is the length along ``dim`` divided by ``chunks``, rounded
    up; the work is then delegated to ``prim_ConstantSplit``.
    """
    extent = self.type().sizes()[dim]
    split_size = (extent + chunks - 1) // chunks
    return prim_ConstantSplit(g, self, split_size, dim)
def split(g, self, split_size_or_sizes, dim):
    """Export ``split`` as an ONNX ``Split`` with explicit chunk lengths.

    A ``split_size_or_sizes`` constant with one or more dimensions is handed
    to ``split_with_sizes``. A zero-dim constant is treated as a chunk size,
    with a shorter trailing chunk for any remainder.

    Raises:
        RuntimeError: if ``split_size_or_sizes`` is a value that is not an
            ``onnx::Constant``.
    """
    is_non_constant = (sym_help._is_value(split_size_or_sizes)
                       and split_size_or_sizes.node().kind() != 'onnx::Constant')
    if is_non_constant:
        raise RuntimeError("ONNX symbolic expected a constant value of the {} argument, got `{}`"
                           .format('split_size_or_sizes', split_size_or_sizes))

    split_val = split_size_or_sizes.node()['value']
    if split_val.dim() > 0:
        return split_with_sizes(g, self, split_size_or_sizes, dim)

    split_size = sym_help._get_const(split_size_or_sizes, 'i', 'split_size')
    dim = sym_help._get_const(dim, 'i', 'dim')

    extent = self.type().sizes()[dim]
    full_chunks, remainder = divmod(extent, split_size)
    splits = [split_size] * full_chunks
    if remainder:
        splits.append(remainder)
    return g.op("Split", self, split_i=splits, axis_i=dim, outputs=1)
@parse_args('v', 'is', 'i')
def split_with_sizes(g, self, split_sizes, dim):
    """Export ``split_with_sizes`` as ``Split`` using ``split_sizes`` along ``dim``."""
    split_node = g.op("Split", self, split_i=split_sizes, axis_i=dim, outputs=1)
    return split_node
@parse_args('v', 'i')
def unbind(g, self, dim=0):
    """Emit an ``aten::unbind`` node carrying ``dim`` as its ``axis`` attribute."""
    # NOTE: The conversion of this node is handled in the onnx peephole pass,
    # because an additional Squeeze node needs to be inserted for each output
    # from unbind.
    unbind_node = g.op("aten::unbind", self, axis_i=dim)
    return unbind_node
@parse_args('v', 'i', 'v')
def select(g, self, dim, index):
    """Export ``select`` as ``Gather``, or Slice + Squeeze for a negative constant index.

    A constant negative ``index`` is sliced out as a one-element range along
    ``dim`` (ending at 9223372036854775807 when the index is -1) and the
    sliced axis is then squeezed away.
    """
    index = sym_help._maybe_get_scalar(index)
    is_negative_constant = (not sym_help._is_value(index)) and (index < 0)
    if not is_negative_constant:
        return g.op("Gather", self, index, axis_i=dim)

    end_index = 9223372036854775807 if index == -1 else index + 1
    sliced = sym_help._slice_helper(g, self, axes=[dim], starts=[index], ends=[end_index])
    return g.op("Squeeze", sliced, axes_i=[dim])
def squeeze(g, self, dim=None):
    """Export ``squeeze`` as an ONNX ``Squeeze`` with explicit axes.

    Without ``dim``, every axis whose size is 1 is squeezed. With ``dim``,
    only that axis is squeezed. A negative ``dim`` is converted using the
    rank known at export time (with a warning), or reported as unimplemented
    when the rank is unknown.
    """
    if dim is None:
        unit_axes = [axis for axis, extent in enumerate(self.type().sizes()) if extent == 1]
        return g.op("Squeeze", self, axes_i=unit_axes)

    axis = sym_help._get_const(dim, 'i', 'dim')
    # Handle negative dims
    if axis < 0:
        rank = self.type().dim()
        if not rank:
            return _unimplemented('squeeze', 'negative axis with unknown input rank')
        warnings.warn("ONNX export squeeze with negative axis " + str(axis) +
                      " might cause the onnx model to be incorrect. " +
                      "Negative axis is not supported in ONNX. " +
                      "Axis is converted to " + str(axis + rank) +
                      " based on input shape at export time. " +
                      "Passing an tensor of different rank in execution will be incorrect.")
        axis += rank
    return g.op("Squeeze", self, axes_i=[axis])
def prelu(g, self, weight):
    """Export ``prelu`` as an ONNX ``PRelu`` node.

    When ``self`` is a complete tensor with more than two dimensions,
    ``weight`` is first unsqueezed on axes ``1 .. rank - 2``.
    """
    if self.isCompleteTensor():
        self_sizes = self.type().sizes()
        if self_sizes and len(self_sizes) > 2:
            extra_axes = list(range(1, len(self_sizes) - 1))
            weight = g.op("Unsqueeze", weight, axes_i=extra_axes)
    return g.op("PRelu", self, weight)
def relu(g, input):
    """Emit an ONNX ``Relu`` node for ``input``."""
    node = g.op("Relu", input)
    return node
def ceil(g, input):
    """Emit an ONNX ``Ceil`` node for ``input``."""
    node = g.op("Ceil", input)
    return node
def floor(g, input):
    """Emit an ONNX ``Floor`` node for ``input``."""
    node = g.op("Floor", input)
    return node
@parse_args('v', 't', 't')
def threshold(g, self, threshold, value):
    """Export ``threshold`` as ``Relu``; only threshold 0 with value 0 is supported."""
    # See Note [Export inplace]
    checks = (
        (threshold, "non-zero threshold"),
        (value, "non-zero value"),
    )
    for scalar_tensor, reason in checks:
        if sym_help._scalar(scalar_tensor) != 0:
            return _unimplemented("threshold", reason)
    return g.op("Relu", self)
def leaky_relu(g, input, negative_slope, inplace=False):
    """Export ``leaky_relu`` as ``LeakyRelu`` with ``negative_slope`` as ``alpha``.

    ``inplace`` is not used.
    """
    slope_tensor = sym_help._get_const(negative_slope, 't', 'negative_slope')
    # See Note [Export inplace]
    # TODO: Talk to ONNX about unconditional cast of scalar to float
    alpha = sym_help._scalar(slope_tensor)
    return g.op("LeakyRelu", input, alpha_f=alpha)
@parse_args('v', 'i')
def glu(g, input, dim):
    """Export ``glu``: split ``input`` in two along ``dim`` and return ``first * sigmoid(second)``.

    Asserts that the size of ``input`` along ``dim`` is even.
    """
    assert input.type().sizes()[dim] % 2 == 0

    value_half, gate_half = g.op('Split', input, axis_i=dim, outputs=2)
    gate = g.op('Sigmoid', gate_half)
    return g.op('Mul', value_half, gate)
@parse_args('v', 'i', 'none')
def softmax(g, input, dim, dtype=None):
    """Export ``softmax`` as ONNX ``Softmax`` or as an Exp / ReduceSum / Div subgraph.

    ``Softmax`` is used only when ``dim`` is the last dimension of an input of
    known rank. The result is cast when ``dtype`` is given and its node is not
    a ``prim::Constant``.
    """
    # Softmax does normalization at vector level.
    # PyTorch and ONNX use different strategies to split the input tensor into vectors.
    # Thus dim and axis have different meanings.
    # PyTorch slices the input tensor into vectors along the `dim`-th dimension.
    # ONNX reshapes the input into a 2-D tensor, and `axis` indicates where the input is coerced.
    # If input is a 2 x 3 tensor:
    # input = [[1.0, 1.0, 1.0],
    #          [1.0, 1.0, 1.0]]
    # with dim = 0, the result is:
    # result = [[0.5, 0.5, 0.5],
    #           [0.5, 0.5, 0.5]]
    # with axis = 0, the result is:
    # result = [[0.167, 0.167, 0.167],
    #           [0.167, 0.167, 0.167]]
    # So only when dim and axis both equal to ndim - 1 (the last dimension),
    # their semantics are equivalent.
    # So use softmax when dim and axis both equal to ndim - 1
    # otherwise compute softmax using a subgraph with other operators
    def cast_if_requested(result):
        if dtype and dtype.node().kind() != 'prim::Constant':
            parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
            return g.op("Cast", result, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
        return result

    input_dim = input.type().dim()
    if input_dim:
        # TODO: remove this as onnx opset 11 spec allows negative axes
        if dim < 0:
            dim = input_dim + dim
        if input_dim == dim + 1:
            return cast_if_requested(g.op('Softmax', input, axis_i=dim))

    exponentials = g.op('Exp', input)
    total = g.op('ReduceSum', exponentials, axes_i=[dim])
    return cast_if_requested(g.op('Div', exponentials, total))
@parse_args('v', 't', 'v')
def softplus(g, self, beta, threshold):
    """Export ``softplus`` as an ONNX ``Softplus`` node.

    Only ``beta == 1`` is supported; any other value is reported through
    ``_unimplemented``. ``threshold`` is not used.
    """
    if beta != 1:
        # The first argument of _unimplemented is the operator name, so the
        # op is named here rather than the offending argument.
        return _unimplemented("softplus", "beta != 1")
    return g.op('Softplus', self)
def get_pool_ceil_padding(input, kernel_size, stride, padding):
    """Compute the extra end padding per pooled dimension for ceil-mode pooling.

    The trailing ``len(padding)`` sizes of ``input`` are treated as the pooled
    dimensions. Returns a list with one int per pooled dimension.
    """
    spatial_sizes = input.type().sizes()[-len(padding):]
    padding_ceil = []
    for i in range(len(padding)):
        extent, pad = spatial_sizes[i], padding[i]
        kernel, step = kernel_size[i], stride[i]

        out_size = int(math.ceil((extent + 2 * pad - kernel) / float(step))) + 1
        # ensure last pooling starts inside
        if (out_size - 1) * step >= extent + pad:
            out_size -= 1

        if step == 1:
            extra = 0
        else:
            extra = kernel - (extent + 2 * pad - ((out_size - 1) * step + 1))

        # ensure padding is not > kernel_size
        if extra + 2 * pad >= kernel and not extra < kernel - 1:
            extra = kernel - 1
        padding_ceil.append(int(extra))
    return padding_ceil
def _max_pool(name, tuple_fn, ndims, return_indices):
    """Build the symbolic function for an ``ndims``-dimensional max pool.

    ``tuple_fn`` is applied to the dilation, padding, kernel size and stride
    arguments before they are used. With ``return_indices`` the generated
    function returns ``(output, indices)``; otherwise only the output.
    Dilation other than 1, and ceil mode on an input that is not a complete
    tensor, are reported through ``_unimplemented``.
    """
    @parse_args('v', 'is', 'is', 'is', 'is', 'i')
    def symbolic_fn(g, input, kernel_size, stride, padding, dilation, ceil_mode):
        if ceil_mode and not input.isCompleteTensor():
            return _unimplemented(name, "input size not accessible")
        if set(tuple_fn(dilation)) != {1}:
            return _unimplemented(name, "dilation")
        if not stride:
            stride = kernel_size

        begin_pads = tuple(tuple_fn(padding))
        if ceil_mode:
            extra_end = get_pool_ceil_padding(input, kernel_size, stride, begin_pads)
            pads = begin_pads + tuple(numpy.add(extra_end, begin_pads))
        else:
            pads = begin_pads * 2

        pool_attrs = {
            'kernel_shape_i': tuple_fn(kernel_size),
            'pads_i': pads,
            'strides_i': tuple_fn(stride),
        }

        if not return_indices:
            return g.op("MaxPool", input, outputs=1, **pool_attrs)

        # easy but hacky way to get flattened indices values
        # to be used to convert the indices values to non-flattened.
        # In ONNX the indices are computed as a flatten 1-D tensor,
        # so the values in indices are in [0, N x C x D1 x ... x Dn).
        # To convert the indices to the same format used by Pytorch,
        # we first execute a maxpool with a kernel and stride of 1 on the same input.
        # This will result in a tensor of indices in which each index will have its own value.
        # Using this tensor as a reference, we extract the first index of each axis and subtract
        # it from each index of this axis in the indices to convert.
        # This step will result in a tensor where each dimension has values of indices within
        # the dimension it is in.
        # For more information :
        # https://github.com/pytorch/pytorch/pull/16455#issuecomment-460776407
        pooled, indices = g.op("MaxPool", input, outputs=2, **pool_attrs)
        _, flattened_indices = g.op("MaxPool", input, outputs=2,
                                    kernel_shape_i=[1] * ndims,
                                    strides_i=[1] * ndims)
        # convert indices to have non-flattened indices values
        spatial_axes = [2 + i for i in range(ndims)]
        first_indices = sym_help._slice_helper(g, flattened_indices, axes=spatial_axes,
                                               starts=tuple_fn(0), ends=tuple_fn(1))
        indices = sub(g, indices, first_indices)
        return pooled, indices

    return symbolic_fn
# Max-pool symbolics for 1, 2 and 3 pooled dimensions, generated by _max_pool.
# The *_with_indices variants are built with return_indices=True.
max_pool1d = _max_pool("max_pool1d", _single, 1, return_indices=False)
max_pool2d = _max_pool("max_pool2d", _pair, 2, return_indices=False)
max_pool3d = _max_pool("max_pool3d", _triple, 3, return_indices=False)
max_pool1d_with_indices = _max_pool("max_pool1d_with_indices", _single, 1, return_indices=True)
max_pool2d_with_indices = _max_pool("max_pool2d_with_indices", _pair, 2, return_indices=True)
max_pool3d_with_indices = _max_pool("max_pool3d_with_indices", _triple, 3, return_indices=True)
def _avg_pool(name, tuple_fn):
    """Build the symbolic function for an average pool.

    With ``count_include_pad`` the input is explicitly zero-padded through a
    ``Pad`` node and the pool's own begin padding is reset to zero. In ceil
    mode the end padding comes from ``get_pool_ceil_padding``; ceil mode on an
    input that is not a complete tensor is reported through
    ``_unimplemented``.
    """
    @parse_args('v', 'is', 'is', 'is', 'i', 'i', 'none')
    def symbolic_fn(g, input, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override=None):
        if ceil_mode and not input.isCompleteTensor():
            return _unimplemented(name, "input size not accessible")
        if not stride:
            stride = kernel_size
        padding = sym_help._avgpool_helper(tuple_fn, padding, kernel_size, stride, divisor_override, name)

        padding_ceil = None
        if ceil_mode:
            padding_ceil = get_pool_ceil_padding(input, kernel_size, stride, padding)

        if count_include_pad:
            explicit_pads = ((0,) * 2 + padding) * 2
            input = g.op("Pad", input,
                         pads_i=explicit_pads,
                         mode_s='constant',
                         value_f=0.)
            padding = (0,) * len(padding)

        if ceil_mode:
            pads = padding + tuple(numpy.add(padding_ceil, padding))
        else:
            pads = padding * 2

        return g.op("AveragePool", input,
                    kernel_shape_i=tuple_fn(kernel_size),
                    strides_i=tuple_fn(stride),
                    pads_i=pads)
    return symbolic_fn
# Average-pool symbolics generated by _avg_pool, one per tuple helper.
avg_pool1d = _avg_pool('avg_pool1d', _single)
avg_pool2d = _avg_pool('avg_pool2d', _pair)
avg_pool3d = _avg_pool('avg_pool3d', _triple)
def _adaptive_pool(name, type, tuple_fn, fn=None):
    """Build the symbolic function for an adaptive average or max pool.

    ``type`` is the ONNX pool op name ("AveragePool" or "MaxPool") and ``fn``
    is the max-pool-with-indices symbolic used for the "MaxPool" case.
    """
    @parse_args('v', 'is')
    def symbolic_fn(g, input, output_size):
        # _adaptive_pool is supported for cases where output_size is 1 for all dimensions,
        # by executing a GlobalPool.
        # It is also supported for cases where the output size is a factor of the input size.
        # For these cases the stride and kernel size are uniform along all the indices of
        # the same dimension, which makes it possible to export it to ONNX.
        # for MaxPool, GlobalMaxPool does not return indices,
        # so we try using max_poolxd_with_indices, and if it is not possible
        # (input is not a complete tensor or output size not factor of input size)
        # then we call GlobalMaxPool and return None for the indices
        wants_global = output_size == [1] * len(output_size)
        if wants_global and type == "AveragePool":
            return g.op("GlobalAveragePool", input)

        if not input.isCompleteTensor():
            if wants_global:
                return g.op("GlobalMaxPool", input), None
            return _unimplemented(name, 'input size not accessible')

        dim = input.type().sizes()[2:]
        # verify that input size % output size == 0 for every dim
        remainders = [dim[i] % output_size[i] for i in range(0, len(dim))]
        if any(remainders):
            if wants_global:
                return g.op("GlobalMaxPool", input), None
            return _unimplemented(name, 'output size that are not factor of input size')

        kernel = [int(dim[i] / output_size[i]) for i in range(0, len(dim))]
        # call max_poolxd_with_indices to get indices in the output
        if type == "MaxPool":
            return fn(g, input, kernel, kernel, (0,) * len(dim), (1,) * len(dim), False)
        return g.op(type, input,
                    kernel_shape_i=tuple_fn(kernel),
                    strides_i=tuple_fn(kernel))
    return symbolic_fn
# Adaptive pooling symbolics generated by _adaptive_pool. The max variants
# pass the matching max_poolNd_with_indices symbolic as ``fn``.
adaptive_avg_pool1d = _adaptive_pool('adaptive_avg_pool1d', "AveragePool", _single)
adaptive_avg_pool2d = _adaptive_pool('adaptive_avg_pool2d', "AveragePool", _pair)
adaptive_avg_pool3d = _adaptive_pool('adaptive_avg_pool3d', "AveragePool", _triple)
adaptive_max_pool1d = _adaptive_pool('adaptive_max_pool1d', "MaxPool", _single, max_pool1d_with_indices)
adaptive_max_pool2d = _adaptive_pool('adaptive_max_pool2d', "MaxPool", _pair, max_pool2d_with_indices)
adaptive_max_pool3d = _adaptive_pool('adaptive_max_pool3d', "MaxPool", _triple, max_pool3d_with_indices)
# Generate paddings in ONNX order based on pad in pytorch.
# Arguments:
# dim: the dimension of the tensor.
# pad: the paddings in pytorch.
# The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...
def _prepare_onnx_paddings(dim, pad):
assert isinstance(dim, int)
# The desired order of paddings is
# dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
# n is the dimension of input.
# assume zero-dimensions in the beginning
paddings = list(pad[:]) + [0] * (dim * 2 - len(pad))
# reverse order and collate first beginnings and then ends
paddings = paddings[-2::-2] + paddings[-1::-2]
return paddings
@parse_args('v', 'is', 'f')
def constant_pad_nd(g, input, padding, value):
    """Export constant padding as an ONNX ``Pad`` in "constant" mode with fill ``value``."""
    rank = input.type().dim()
    onnx_pads = _prepare_onnx_paddings(rank, padding)
    return g.op("Pad", input, pads_i=onnx_pads, mode_s="constant", value_f=value)
@parse_args('v', 'is')
def reflection_pad(g, input, padding):
    """Export reflection padding as an ONNX ``Pad`` in "reflect" mode."""
    rank = input.type().dim()
    onnx_pads = _prepare_onnx_paddings(rank, padding)
    return g.op("Pad", input, pads_i=onnx_pads, mode_s="reflect")
@parse_args('v', 'is')
def replication_pad(g, input, padding):
    """Export replication padding as an ONNX ``Pad`` in "edge" mode."""
    rank = input.type().dim()
    onnx_pads = _prepare_onnx_paddings(rank, padding)
    return g.op("Pad", input, pads_i=onnx_pads, mode_s="edge")
# The 1d/2d/3d padding ops all share one symbolic per padding mode.
reflection_pad1d = reflection_pad
reflection_pad2d = reflection_pad
reflection_pad3d = reflection_pad
replication_pad1d = replication_pad
replication_pad2d = replication_pad
replication_pad3d = replication_pad
def _interpolate(name, dim, interpolate_mode):
    """Build an upsample symbolic that emits ONNX ``Upsample`` in ``interpolate_mode``.

    A truthy ``align_corners`` is reported through ``_unimplemented``. When no
    scales are supplied they are derived from ``output_size`` via
    ``sym_help._interpolate_size_to_scales``.
    """
    def symbolic_fn(g, input, output_size, *args):
        scales, align_corners = sym_help._get_interpolate_attributes(g, interpolate_mode, args)
        sym_help._interpolate_warning(interpolate_mode)
        if sym_help._maybe_get_scalar(align_corners):
            return _unimplemented(name, "align_corners == True")
        if scales is None:
            scales = sym_help._interpolate_size_to_scales(g, input, output_size, dim)
        return g.op("Upsample", input, scales, mode_s=interpolate_mode)
    return symbolic_fn
# Upsample symbolics generated by _interpolate. The second argument is the
# ``dim`` value forwarded to sym_help._interpolate_size_to_scales, and the
# linear/bilinear/trilinear variants all use the "linear" mode.
upsample_nearest1d = _interpolate('upsample_nearest1d', 3, "nearest")
upsample_nearest2d = _interpolate('upsample_nearest2d', 4, "nearest")
upsample_nearest3d = _interpolate('upsample_nearest3d', 5, "nearest")
upsample_linear1d = _interpolate('upsample_linear1d', 3, "linear")
upsample_bilinear2d = _interpolate('upsample_bilinear2d', 4, "linear")
upsample_trilinear3d = _interpolate('upsample_trilinear3d', 5, "linear")
def __interpolate(g, input, size, scale_factor, mode, align_corners, recompute_scale_factor):
    """Export ``__interpolate`` as ONNX ``Upsample``.

    Scales and mode are resolved by
    ``sym_help._interpolate_get_scales_and_mode``;
    ``recompute_scale_factor`` is not used.
    """
    resolved = sym_help._interpolate_get_scales_and_mode(
        g, input, size, scale_factor, mode, align_corners)
    scales, mode = resolved
    return g.op("Upsample", input, scales, mode_s=mode)
@parse_args('v')
def bitwise_not(g, inp):
    """Export ``bitwise_not`` as ONNX ``Not``; only 'Bool' inputs are supported."""
    if inp.type().scalarType() == 'Bool':
        return g.op("Not", inp)
    return _unimplemented("bitwise_not", "non-bool tensor")
def wrap_logical_op_with_cast_to(to_type):
    """Decorator factory: cast the result of a binary symbolic to ``to_type``.

    ``to_type`` is a key of ``sym_help.cast_pytorch_to_onnx``.
    """
    def decorator(fn):
        def wrap_with_cast(g, input, other):
            result = fn(g, input, other)
            onnx_type = sym_help.cast_pytorch_to_onnx[to_type]
            return g.op("Cast", result, to_i=onnx_type)
        return wrap_with_cast
    return decorator
def wrap_logical_op_with_cast_to_and_from(to_type):
    """Return a decorator that runs a binary symbolic in ``to_type``.

    Both operands are passed through the module-level ``_cast_<to_type>``
    function before the wrapped symbolic runs. Its result is then cast back
    to the scalar type of the original ``input``.

    Args:
        to_type: scalar type name; selects the ``_cast_<to_type>`` function
            looked up in this module's globals.

    Returns:
        A decorator for functions with the ``(g, input, other)`` call shape.
    """
    import functools

    def decorator(fn):
        # Preserve the wrapped symbolic's name/docstring for introspection
        # and debugging.
        @functools.wraps(fn)
        def wrap_with_cast(g, input, other):
            # Resolved at call time so the cast function may be defined
            # anywhere in the module.
            to_cast_func = globals()['_cast_{}'.format(to_type)]
            # The result is cast back to the scalar type of the first operand.
            from_cast_func = wrap_logical_op_with_cast_to(input.type().scalarType())(fn)
            return from_cast_func(g, to_cast_func(g, input, False), to_cast_func(g, other, False))
        return wrap_with_cast
    return decorator
def wrap_logical_op_with_negation(func):
    """Decorate a binary symbolic so that its result is logically negated.

    Args:
        func: symbolic function called as ``func(g, input, other)``.

    Returns:
        A function with the same call shape whose output is wrapped in an
        ONNX ``Not`` node.
    """
    import functools

    # Preserve the wrapped symbolic's name/docstring for introspection and
    # debugging.
    @functools.wraps(func)
    def wrap_with_not(g, input, other):
        return g.op("Not", func(g, input, other))
    return wrap_with_not
def eq(g, self, other):
    """Emit an ONNX ``Equal`` node comparing ``self`` with ``other``."""
    equal_node = g.op("Equal", self, other)
    return equal_node
@wrap_logical_op_with_negation
def ne(g, self, other):
    """Emit ``Equal``; the decorator wraps the result in ``Not`` to form ``!=``."""
    equality = g.op("Equal", self, other)
    return equality
def gt(g, input, other):
    """Export ``input > other`` by delegating to ``gt_impl``."""
    greater = gt_impl(g, input, other)
    return greater
def gt_impl(g, input, other):
    """Emit an ONNX ``Greater`` node for ``input`` and ``other``."""
    greater_node = g.op("Greater", input, other)
    return greater_node
def lt(g, input, other):
    """Export ``input < other`` by delegating to ``lt_impl``."""
    less = lt_impl(g, input, other)
    return less
def lt_impl(g, input, other):
    """Emit an ONNX ``Less`` node for ``input`` and ``other``."""
    less_node = g.op("Less", input, other)
    return less_node
@wrap_logical_op_with_negation
def ge(g, input, other):
    """Export ``input >= other`` as the negation of ``input < other``.

    This body produces the ``Less`` comparison via ``lt_impl``; the decorator
    wraps that result in ``Not``.
    """
    less = lt_impl(g, input, other)
    return less
@wrap_logical_op_with_negation
def le(g, input, other):
    """Export ``input <= other`` as the negation of ``input > other``.

    This body produces the ``Greater`` comparison via ``gt_impl``; the
    decorator wraps that result in ``Not``.
    """
    greater = gt_impl(g, input, other)
    return greater
@wrap_logical_op_with_cast_to_and_from('Bool')
def __and_(g, input, other):
    """Emit an ONNX ``And`` node.

    The decorator routes both operands through ``_cast_Bool`` first and casts
    the result back to the scalar type of the original ``input``.
    """
    and_node = g.op('And', input, other)
    return and_node
@wrap_logical_op_with_cast_to_and_from('Bool')
def __or_(g, input, other):
    """Emit an ONNX ``Or`` node.

    The decorator routes both operands through ``_cast_Bool`` first and casts
    the result back to the scalar type of the original ``input``.
    """
    or_node = g.op('Or', input, other)
    return or_node
def __rshift_(g, self, other):
    """Export ``self >> other`` as ``Div(self, Pow(2, other))``.

    Args:
        g: graph object; nodes are created through ``g.op``.
        self: value being shifted.
        other: shift amount.

    Returns:
        The value produced by the final ``Div`` node.
    """
    self_type = self.type().scalarType()

    # make sure to cast other to self's type
    # (when self is long, make sure that other is not float)
    if other.type().scalarType() != self_type:
        other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx[self_type])

    two = g.op('Constant', value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not sym_help._is_fp(self):
        other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx['Float'])
    two_pow = g.op('Pow', two, other)

    # The base of Pow is a float32 constant, while ONNX Div requires both of
    # its inputs to share one type. Cast the power back to self's type so the
    # division is well-typed (e.g. for integer tensors). The cast is skipped
    # when it would be a no-op or when self's type is unknown.
    if self_type is not None and self_type != 'Float':
        two_pow = g.op('Cast', two_pow, to_i=sym_help.cast_pytorch_to_onnx[self_type])

    rshift = g.op('Div', self, two_pow)
    return rshift
def __lshift_(g, self, other):
    """Export ``self << other`` as ``Mul(self, Pow(2, other))``.

    Args:
        g: graph object; nodes are created through ``g.op``.
        self: value being shifted.
        other: shift amount.

    Returns:
        The value produced by the final ``Mul`` node.
    """
    self_type = self.type().scalarType()

    # make sure to cast other to self's type
    # (when self is long, make sure that other is not float)
    if other.type().scalarType() != self_type:
        other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx[self_type])

    two = g.op('Constant', value_t=torch.tensor(2, dtype=torch.float32))
    # exponent (same type as self) has to be float or double in onnx::Pow
    if not sym_help._is_fp(self):
        other = g.op("Cast", other, to_i=sym_help.cast_pytorch_to_onnx['Float'])
    two_pow = g.op('Pow', two, other)

    # The base of Pow is a float32 constant, while ONNX Mul requires both of
    # its inputs to share one type. Cast the power back to self's type so the
    # multiplication is well-typed (e.g. for integer tensors). The cast is
    # skipped when it would be a no-op or when self's type is unknown.
    if self_type is not None and self_type != 'Float':
        two_pow = g.op('Cast', two_pow, to_i=sym_help.cast_pytorch_to_onnx[self_type])

    lshift = g.op('Mul', self, two_pow)
    return lshift
def where(g, condition, self, other):
    """Emit an ONNX ``Where`` node with inputs ``condition``, ``self``, ``other``."""
    where_node = g.op("Where", condition, self, other)
    return where_node
@parse_args('v', 'i', 'none')
def log_softmax(g, input, dim, dtype=None):
    """Export ``log_softmax`` as ONNX ``LogSoftmax`` over the last axis.

    If ``dim`` is not the last axis, the input is transposed so that ``dim``
    becomes the last axis, and the result is transposed back afterwards.

    Args:
        g: graph object; nodes are created through ``g.op``.
        input: value to normalise; its rank must be known at export time.
        dim: axis to reduce over; negative values count from the end.
        dtype: optional dtype value; when given and not a ``prim::Constant``
            node, the result is wrapped in a ``Cast``.

    Returns:
        The resulting graph value, or the result of ``_unimplemented`` when
        the input rank is unknown.
    """
    # PyTorch dim and ONNX axis have different meanings.
    # See Softmax comment for details.
    # TODO: remove this as onnx opset 11 spec allows negative axes
    rank = input.type().dim()
    if rank is None:
        return _unimplemented("dim",
                              "ONNX and PyTorch use different strategies to split the input. "
                              "Input rank must be known at export time.")

    axis = rank + dim if dim < 0 else dim
    last_axis = rank - 1

    # ONNX only supports log_softmax with dim = -1. Transpose must be added
    # before and after log_softmax to support other cases.
    needs_transpose = axis != last_axis
    if needs_transpose:
        # Swapping two axes gives a permutation that is its own inverse, so
        # the same ``perm`` is reused to restore the layout below.
        perm = list(range(rank))
        perm[axis], perm[-1] = perm[-1], perm[axis]
        input = g.op("Transpose", input, perm_i=perm)
        axis = last_axis

    result = g.op("LogSoftmax", input, axis_i=axis)

    # NOTE(review): ``dtype`` arrives unparsed ('none') and is used directly
    # as the index into ``scalar_type_to_onnx`` - confirm this lookup is valid.
    if dtype and dtype.node().kind() != 'prim::Constant':
        result = g.op("Cast", result, to_i=sym_help.scalar_type_to_onnx[dtype])

    if needs_transpose:
        result = g.op("Transpose", result, perm_i=perm)
    return result
@parse_args('v', 'v', 'v', 'is', 'is', 'is', 'i', 'is', 'i', 'i', 'i', 'i')
def _convolution(g, input, weight, bias, stride, padding, dilation,
transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled):
weight_size = weight.type().sizes()
args = [input, weight]
# ONNX only supports 1D bias
if not sym_help._is_none(bias) and bias.type().dim() == 1:
args.append(bias)
kwargs = {"kernel_shape_i": weight_size[2:],
"strides_i": stride,
# NB: ONNX supports asymmetric padding, whereas PyTorch supports only
# symmetric padding
"pads_i": padding + padding,
"dilations_i": dilation,
"group_i": groups}