Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[v1.x] Onnx fix slice_axis and embedding and reshape #19677

Merged
merged 9 commits into from
Dec 17, 2020
39 changes: 23 additions & 16 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1645,30 +1645,37 @@ def convert_cast(node, **kwargs):

@mx_op.register("slice_axis")
def convert_slice_axis(node, **kwargs):
from onnx.helper import make_node
"""Map MXNet's slice_axis operator attributes to onnx's Slice operator
and return the created node.
"""
name, input_nodes, attrs = get_inputs(node, kwargs)

axes = int(attrs.get("axis"))
starts = int(attrs.get("begin"))
ends = attrs.get("end", None)
if not ends or ends == 'None':
axis = int(attrs.get("axis"))
begin = int(attrs.get("begin"))
end = attrs.get("end", None)

nodes = []
create_tensor([axis], name+'_axis',kwargs["initializer"])
create_tensor([begin], name+'_begin',kwargs["initializer"])
if not end or end == 'None':
# ONNX doesn't support None for ends. Since ends=None depicts
# length of dimension, passing dimension in this case.
in_shape = kwargs['in_shape'][0]
ends = in_shape[axes]
create_tensor([axis+1], name+"_axis_plus_1", kwargs["initializer"])
nodes += [
make_node('Shape', [input_nodes[0]], [name+"_data_shape"]),
make_node('Slice', [name+'_data_shape', name+'_axis', name+'_axis_plus_1'],
Zha0q1 marked this conversation as resolved.
Show resolved Hide resolved
[name+"_end"]),
]
else:
create_tensor([int(end)], name+'_end',kwargs["initializer"])

node = onnx.helper.make_node(
"Slice",
input_nodes,
[name],
axes=[axes],
starts=[starts],
ends=[int(ends)],
name=name,
)
return [node]
nodes += [
make_node('Slice', [input_nodes[0], name+'_begin', name+'_end', name+'_axis'],
[name], name=name)
]

return nodes


@mx_op.register("SliceChannel")
Expand Down
27 changes: 24 additions & 3 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,11 +133,32 @@ def test_onnx_export_SequenceMask(tmp_path, dtype):
op_export_test('SequenceMask_2', M2, [x, seq_len2], tmp_path)


@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32'])
@pytest.mark.parametrize('dtype', ['float32'])
def test_onnx_export_contrib_interleaved_matmul_selfatt_qk(tmp_path, dtype):
M1 = def_model('contrib.interleaved_matmul_selfatt_qk', heads=3)
x1 = mx.nd.random.uniform(0, 1, (3, 3, 3*3*3))
x1 = mx.nd.random.uniform(0, 1, (3, 3, 3*3*3), dtype=dtype)
op_export_test('contrib_interleaved_matmul_selfatt_qk_1', M1, [x1], tmp_path)
M2 = def_model('contrib.interleaved_matmul_selfatt_qk', heads=5)
x2 = mx.nd.random.uniform(0, 1, (7, 5, 4*5*6))
x2 = mx.nd.random.uniform(0, 1, (7, 5, 4*5*6), dtype=dtype)
op_export_test('contrib_interleaved_matmul_selfatt_qk_2', M2, [x2], tmp_path)


@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32'])
def test_onnx_export_slice_axis(tmp_path, dtype):
x = mx.nd.array([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]], dtype=dtype)
M1 = def_model('slice_axis', axis=0, begin=1, end=3)
M2 = def_model('slice_axis', axis=0, begin=1, end=None)
M3 = def_model('slice_axis', axis=1, begin=-3, end=-1)
op_export_test('slice_axis_1', M1, [x], tmp_path)
op_export_test('slice_axis_2', M2, [x], tmp_path)
op_export_test('slice_axis_3', M3, [x], tmp_path)

@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32'])
def test_onnx_export_reshape(tmp_path, dtype):
x = mx.nd.ones((2, 3, 4, 5, 6), dtype=dtype)
M1 = def_model('reshape', shape=(2, 1, 1, -1, 0, 1, 0), reverse=True)
op_export_test('reshape_1', M1, [x], tmp_path)
M2 = def_model('reshape', shape=(6, 1, 0, -1))
op_export_test('reshape_2', M2, [x], tmp_path)