forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
symbolic_opset11.py
119 lines (90 loc) · 4.87 KB
/
symbolic_opset11.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _unimplemented
from torch.onnx.symbolic_helper import _black_list_in_opset
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in symbolic_helper.py
# This file exports ONNX ops for opset 11
black_listed_operators = [
"eq", "ne", "sort", "topk", "hardtanh"
]
for black_listed_op in black_listed_operators:
vars()[black_listed_op] = _black_list_in_opset(black_listed_op)
def clamp(g, self, min, max):
    """Export aten::clamp as ONNX Clip (opset 11 takes min/max as inputs).

    When the input's scalar type is known, both bounds are cast to it so
    Clip sees matching dtypes; a None bound is passed through unchanged.
    """
    scalar_type = self.type().scalarType()

    def _maybe_cast(bound):
        # An absent bound (Python None or an ONNX None constant) stays as-is.
        if bound is None or bound.node().mustBeNone():
            return bound
        return g.op("Cast", bound, to_i=sym_help.cast_pytorch_to_onnx[scalar_type])

    if scalar_type is not None:
        min = _maybe_cast(min)
        max = _maybe_cast(max)
    return g.op("Clip", self, min, max)
@parse_args('v', 'i')
def pixel_shuffle(g, self, upscale_factor):
    """Export aten::pixel_shuffle as ONNX DepthToSpace in CRD mode."""
    input_sizes = self.type().sizes()
    # DepthToSpace is defined only for NCHW tensors; reject any other rank.
    if len(input_sizes) != 4:
        return _unimplemented("pixel_shuffle", "only support 4d input")
    return g.op("DepthToSpace", self, blocksize_i=upscale_factor, mode_s="CRD")
def _interpolate(name, dim, interpolate_mode):
def symbolic_fn(g, input, output_size, align_corners=None):
align_corners = sym_help._maybe_get_scalar(align_corners)
output_size = sym_help._maybe_get_const(output_size, 'is')
if sym_help._is_value(output_size):
offsets = g.op("Constant", value_t=torch.ones(offset, dtype=torch.int64))
output_size = g.op("Concat", offsets, output_size, axis_i=0)
else:
output_size = [1 if i < 2 else output_size[-(dim - i)] for i in range(0, dim)]
output_size = g.op("Constant", value_t=torch.tensor(output_size))
coordinate_transformation_mode = "asymmetric" if interpolate_mode == "nearest" \
else "align_corners" if align_corners else "pytorch_half_pixel"
empty_tensor = g.op("Constant", value_t=torch.tensor([], dtype=torch.float32))
return g.op("Resize",
input,
empty_tensor, # roi only takes effect whith coordinate_transformation_mode="tf_crop_and_resize"
empty_tensor, # scales is not needed since we are sending out_size
output_size,
coordinate_transformation_mode_s=coordinate_transformation_mode,
cubic_coeff_a_f=-0.75, # only valid when mode="cubic"
mode_s=interpolate_mode, # nearest, linear, or cubic
nearest_mode_s="floor") # only valid when mode="nearest"
return symbolic_fn
# Concrete exporters built from the _interpolate factory:
# (aten op name, expected input rank including N and C, ONNX Resize mode).
upsample_nearest1d = _interpolate('upsample_nearest1d', 3, "nearest")
upsample_nearest2d = _interpolate('upsample_nearest2d', 4, "nearest")
upsample_nearest3d = _interpolate('upsample_nearest3d', 5, "nearest")
upsample_linear1d = _interpolate('upsample_linear1d', 3, "linear")
upsample_bilinear2d = _interpolate('upsample_bilinear2d', 4, "linear")
upsample_trilinear3d = _interpolate('upsample_trilinear3d', 5, "linear")
upsample_bicubic2d = _interpolate('upsample_bicubic2d', 4, "cubic")
@parse_args('v', 'i', 'v', 'v')
def gather(g, self, dim, index, sparse_grad=False):
    """Export aten::gather as ONNX GatherElements (opset 11)."""
    # ONNX has no notion of sparse gradients, so that flag cannot be honored.
    if sym_help._maybe_get_const(sparse_grad, 'i'):
        return _unimplemented("gather", "sparse_grad == True")
    aten_fallback = (sym_help._operator_export_type ==
                     torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
    if aten_fallback:
        return g.op("ATen", self, dim, index, sparse_grad, operator_s="gather")
    return g.op("GatherElements", self, index, axis_i=dim)
@parse_args('v', 'i', 'v', 'v')
def scatter(g, self, dim, index, src):
    """Export aten::scatter as ONNX ScatterElements (opset 11)."""
    aten_fallback = (sym_help._operator_export_type ==
                     torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
    if aten_fallback:
        return g.op("ATen", self, dim, index, src, operator_s="scatter")
    return g.op("ScatterElements", self, index, src, axis_i=dim)
@parse_args('v', 'i', 'none')
def cumsum(g, self, dim, dtype=None):
    """Export aten::cumsum as ONNX CumSum, optionally casting the result."""
    axis = g.op("Constant", value_t=torch.tensor(dim))
    result = g.op("CumSum", self, axis)
    # A non-constant dtype argument requests casting the accumulated output.
    if dtype and dtype.node().kind() != 'prim::Constant':
        parsed_dtype = sym_help._get_const(dtype, 'i', 'dtype')
        result = g.op("Cast", result, to_i=sym_help.scalar_type_to_onnx[parsed_dtype])
    return result
@parse_args('v', 'i', 'i', 'i')
def _unique2(g, self, sorted, return_inverse, return_counts):
    """Export aten::_unique2 via ONNX Unique; inverse/counts are always produced."""
    # Unique emits 4 outputs; the raw `indices` output has no aten counterpart.
    unique_vals, _indices, inverse, counts = g.op(
        "Unique", self, sorted_i=sorted, outputs=4)
    return unique_vals, inverse, counts
@parse_args('v', 'i', 'i', 'i', 'i')
def unique_dim(g, self, dim, sorted, return_inverse, return_counts):
    """Export aten::unique_dim via ONNX Unique along a given axis."""
    # Unique emits 4 outputs; the raw `indices` output has no aten counterpart.
    unique_vals, _indices, inverse, counts = g.op(
        "Unique", self, axis_i=dim, sorted_i=sorted, outputs=4)
    return unique_vals, inverse, counts
def round(g, self):
    """Map aten::round directly onto the ONNX Round operator (new in opset 11)."""
    return g.op("Round", self)