Skip to content

Commit

Permalink
[ONNX] Enable remaining failed tests in opset13 (pytorch#50806)
Browse files Browse the repository at this point in the history
* enable remaining test in opset13

* add comments for error version test info

* fix comments:opset12 unbind problem

* add ignore[no-redef]

* fix format

Co-authored-by: hwangdeyu <deyhuang@qq.com>

ghstack-source-id: 4ea81e90c66678d953b5219b9672444eedde62f3
Pull Request resolved: pytorch#51518
  • Loading branch information
BowenBao committed Feb 3, 2021
1 parent 287482e commit 01a3287
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 21 deletions.
18 changes: 2 additions & 16 deletions test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,6 @@ def get_test_images(self):

return [image], [image2]

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest() # Faster RCNN model is not scriptable
def test_faster_rcnn(self):
Expand Down Expand Up @@ -490,7 +489,6 @@ def test_paste_mask_in_image(self):

assert torch.all(out2.eq(out_trace2))

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_mask_rcnn(self):
Expand Down Expand Up @@ -535,7 +533,6 @@ def test_heatmaps_to_keypoints(self):
assert torch.all(out2[0].eq(out_trace2[0]))
assert torch.all(out2[1].eq(out_trace2[1]))

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_keypoint_rcnn(self):
Expand All @@ -557,7 +554,6 @@ def test_keypoint_rcnn(self):
dynamic_axes={"images_tensors": [0, 1, 2]},
rtol=5e-3, atol=1e-5)

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_shufflenet_v2_dynamic_axes(self):
Expand Down Expand Up @@ -1248,8 +1244,8 @@ def forward(self, x):
x = torch.randn(2, 3, 4)
self.run_test(FloatingPoint(), x)

@unittest.skip("If operator rank mismatch between outputs of two branches.")
@skipIfUnsupportedMinOpsetVersion(9)
# Operator rank mismatch between outputs of two branches for opsets below 11.
@skipIfUnsupportedMinOpsetVersion(11)
@skipIfONNXShapeInference(False)
def test_floating_point_infer_dtype(self):
class FloatingPoint(torch.jit.ScriptModule):
Expand Down Expand Up @@ -1802,7 +1798,6 @@ def forward(self, hidden_states):
dynamic_axes={'x': {0: 'seq_length', 1: 'batch_size'}}, test_with_inputs=[y])

@skipIfUnsupportedMinOpsetVersion(11)
@skipIfUnsupportedOpsetVersion([13])
def test_copy_(self):
class CopyModel(torch.nn.Module):
def forward(self, x, data):
Expand Down Expand Up @@ -3612,7 +3607,6 @@ def forward(self, x, y, z, ind):
ind = torch.tensor(-2, dtype=torch.long)
self.run_test(GetItemModel(), (x, y, z, ind))

@skipIfUnsupportedOpsetVersion([13])
@disableScriptTest() # torch.nonzero(x, as_tuple=True) is not scriptable.
@skipIfUnsupportedMinOpsetVersion(9)
def test_nonzero(self):
Expand Down Expand Up @@ -4416,7 +4410,6 @@ def forward(self, input):
x = torch.randint(10, (2, 3))
self.run_test(FModModel(), x)

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(9)
def test_glu(self):
class GluModel(torch.nn.Module):
Expand Down Expand Up @@ -5510,7 +5503,6 @@ def forward(self, cond, input, other):
z = torch.ones(2, 3, 1)
self.run_test(Model(), (x, y, z))

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(9)
@disableScriptTest() # scripting tests run for opsets > 11. See: test_where_condition_script
def test_where_condition(self):
Expand Down Expand Up @@ -6127,7 +6119,6 @@ def forward(self, x, y):
"ScriptModel - Initializers' sequence is not as same as named_parameters(). Expected: (" \
+ ', '.join(named_params_list) + "). Actual:(" + ', '.join(actual_list) + ")."

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
def test_nms(self):
boxes = torch.rand(5, 4)
Expand Down Expand Up @@ -6158,15 +6149,13 @@ def forward(self, boxes, size):
dynamic_axes={"size": [0, 1]},
test_with_inputs=[(boxes, size), (boxes, size_2)])

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
def test_roi_align(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
single_roi = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.float32)
model = ops.RoIAlign((5, 5), 1., 2)
self.run_test(model, (x, single_roi))

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
def test_roi_align_aligned(self):
x = torch.rand(1, 1, 10, 10, dtype=torch.float32)
Expand Down Expand Up @@ -6242,7 +6231,6 @@ def get_features(self, images):
features = OrderedDict(features)
return features

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_rpn(self):
Expand Down Expand Up @@ -6272,7 +6260,6 @@ def forward(self, images, features):
test_with_inputs=[(images, features), (images2, test_features)],
dict_check=False)

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_multi_scale_roi_align(self):
Expand Down Expand Up @@ -6300,7 +6287,6 @@ def forward(self, input, boxes):

self.run_test(TransformModule(), (i, [boxes],), test_with_inputs=[(i, [boxes],), (i1, [boxes1],)])

@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(11)
@disableScriptTest()
def test_roi_heads(self):
Expand Down
4 changes: 3 additions & 1 deletion torch/onnx/symbolic_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,10 @@ def __interpolate_helper(g, input, size, scale_factor, mode, align_corners, reco
def _unbind_helper(g, self, dim, _outputs):
if _export_onnx_opset_version < 11:
from torch.onnx.symbolic_opset9 import unbind
else:
elif _export_onnx_opset_version <= 12:
from torch.onnx.symbolic_opset11 import unbind # type: ignore[no-redef]
else:
from torch.onnx.symbolic_opset13 import unbind # type: ignore[no-redef]
return unbind(g, self, dim, _outputs)


Expand Down
19 changes: 15 additions & 4 deletions torch/onnx/symbolic_opset13.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
import torch.onnx.symbolic_helper as sym_help
from torch.onnx.symbolic_helper import parse_args, _unimplemented
from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input
from torch.onnx.symbolic_opset9 import overload_by_arg_count, _maybe_cast_reduce_op_input, nonzero


# EDITING THIS FILE? READ THIS FIRST!
Expand Down Expand Up @@ -106,9 +106,20 @@ def unbind(g, self, dim=0, _outputs=None):
return squeezed_outputs


def glu(g, input, dim):
first, second = g.op('Split', input, dim, outputs=2)
return g.op('Mul', first, g.op('Sigmoid', second))
# Emitted from `torch.nonzero(x, as_tuple=True)`
def nonzero_numpy(g, input, _outputs=None):
return unbind(g, nonzero(g, input), 1, _outputs=_outputs)


@parse_args('v', 'v', 'v', 'i')
def where(g, condition, self=None, other=None, _outputs=None):
# Assumes that torch.where's first argument takes only Bool and Byte tensors.
if condition.type().scalarType() != 'Bool':
condition = g.op("Cast", condition, to_i=sym_help.cast_pytorch_to_onnx['Bool'])
if self is None:
condition = nonzero(g, condition)
return sym_help._unbind_helper(g, condition, g.op("Constant", value_t=torch.tensor(1)), _outputs)
return g.op("Where", condition, self, other)


def _reduce_op_symbolic(onnx_op_name):
Expand Down

0 comments on commit 01a3287

Please sign in to comment.