Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ONNX export: Slice op - Handle None value for ends (#14942)
Browse files Browse the repository at this point in the history
* ONNX export: Slice op - Handle None value for ends
  • Loading branch information
vandanavk committed Feb 20, 2020
1 parent ab48a43 commit 9dcf71d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
11 changes: 7 additions & 4 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,17 +1495,20 @@ def convert_slice_axis(node, **kwargs):

axes = int(attrs.get("axis"))
starts = int(attrs.get("begin"))
ends = int(attrs.get("end", None))
if not ends:
raise ValueError("Slice: ONNX doesnt't support 'None' in 'end' attribute")
ends = attrs.get("end", None)
if not ends or ends == 'None':
# ONNX doesn't support None for ends. Since ends=None depicts
# length of dimension, passing dimension in this case.
in_shape = kwargs['in_shape'][0]
ends = in_shape[axes]

node = onnx.helper.make_node(
"Slice",
input_nodes,
[name],
axes=[axes],
starts=[starts],
ends=[ends],
ends=[int(ends)],
name=name,
)
return [node]
Expand Down
6 changes: 5 additions & 1 deletion python/mxnet/contrib/onnx/onnx2mx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,15 +499,19 @@ def split(attrs, inputs, proto_obj):

def _slice(attrs, inputs, proto_obj):
"""Returns a slice of the input tensor along multiple axes."""
input_tensor_data = proto_obj.model_metadata.get('input_tensor_data')[0]
input_shape = input_tensor_data[1]
new_attrs = translation_utils._fix_attribute_names(attrs,
{'axes' : 'axis',
'ends' : 'end',
'starts' : 'begin'})
# onnx slice provides slicing on multiple axis. Adding multiple slice_axis operator
# for multiple axes from mxnet
begin = new_attrs.get('begin')
end = new_attrs.get('end')
end = list(new_attrs.get('end'))
axes = new_attrs.get('axis', tuple(range(len(begin))))
for i, axis in enumerate(axes):
end[i] = None if end[i] >= input_shape[axis] else end[i]
slice_op = symbol.slice_axis(inputs[0], axis=axes[0], begin=begin[0], end=end[0])
if len(axes) > 1:
for i, axis in enumerate(axes):
Expand Down
1 change: 1 addition & 0 deletions tests/python-pytest/onnx/test_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
'test_globalaveragepool',
'test_slice_cpu',
'test_slice_neg',
'test_slice_end',
'test_reciprocal',
'test_sqrt',
'test_pow',
Expand Down

0 comments on commit 9dcf71d

Please sign in to comment.