
[v1.x] Onnx fix slice_axis and embedding and reshape #19677

Merged
merged 9 commits on Dec 17, 2020
176 changes: 94 additions & 82 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
@@ -178,9 +178,12 @@ def create_const_node(input_name, value, kwargs):
return value_node

def create_tensor(shape_list, shape_name, initializer, dtype='int64'):
"""Helper function to create a tensor value node and a
initializer tensor node with constant value."""
shape_np = np.array(shape_list, dtype=dtype)
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[shape_np.dtype]
dims = np.shape(shape_np)
tensor_node = onnx.helper.make_tensor_value_info(shape_name, data_type, dims)
initializer.append(
onnx.helper.make_tensor(
name=shape_name,
@@ -190,6 +193,7 @@ def create_tensor(shape_list, shape_name, initializer, dtype='int64'):
raw=False
)
)
return tensor_node
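
For reference, a minimal sketch of how this helper ends up being used by the converters later in this diff (the variable names here are illustrative, not part of the change; only create_tensor and the initializer list come from the code above):

# Illustrative only: append int64/float32 constants to the graph initializers
# and keep the returned value-info nodes alongside the op nodes.
initializer = []                    # stands in for kwargs["initializer"]
nodes = [
    create_tensor([0], 'foo_0', initializer),                       # int64 by default
    create_tensor([1.0], 'foo_1_f', initializer, dtype='float32'),  # float constant
]
# initializer now holds two TensorProto constants that later ONNX nodes can
# reference by the names 'foo_0' and 'foo_1_f'.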

@mx_op.register("null")
def convert_weights_and_inputs(node, **kwargs):
@@ -1543,53 +1547,58 @@ def convert_reshape(node, **kwargs):
Converts output shape attribute to output shape tensor
and returns multiple created nodes.
"""
from onnx.helper import make_node

name, input_nodes, attrs = get_inputs(node, kwargs)

reverse = attrs.get('reverse', 'False')
output_shape_list = convert_string_to_list(attrs["shape"])
data_shape = list(kwargs['in_shape'][0])
if reverse == 'True':
output_shape_list.reverse()
data_shape.reverse()
for i, dim in enumerate(output_shape_list):
if dim == 0:
output_shape_list[i] = data_shape[i]
output_shape_list.reverse()

initializer = kwargs["initializer"]
output_shape_np = np.array(output_shape_list, dtype='int64')
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[output_shape_np.dtype]
dims = np.shape(output_shape_np)

output_shape_name = "reshape_attr_tensor" + str(kwargs["idx"])
tensor_node = onnx.helper.make_tensor_value_info(output_shape_name, data_type, dims)

initializer.append(
onnx.helper.make_tensor(
name=output_shape_name,
data_type=data_type,
dims=dims,
vals=output_shape_list,
raw=False,
)
)

input_nodes.append(output_shape_name)
targ_shape = convert_string_to_list(attrs["shape"])

not_supported_shape = [-2, -3, -4]

for val in output_shape_list:
for val in targ_shape:
if val in not_supported_shape:
raise AttributeError("Reshape: Shape value not supported in ONNX", val)

reshape_node = onnx.helper.make_node(
"Reshape",
input_nodes,
[name],
name=name
)
nodes = [
create_tensor(targ_shape, name+'_targ_shape', kwargs['initializer'])
]

if reverse == 'False':
nodes += [
make_node('Reshape', [input_nodes[0], name+'_targ_shape'], [name], name=name)
]
else:
nodes += [
create_tensor([0], name+'_0', kwargs['initializer']),
create_tensor([1], name+'_1', kwargs['initializer']),
make_node('Shape', [name+'_targ_shape'], [name+'_targ_dim']),
make_node('Shape', [input_nodes[0]], [name+'_orig_shape']),
make_node('Shape', [name+'_orig_shape'], [name+'_orig_dim']),
make_node('Sub', [name+'_targ_dim', name+'_orig_dim'], [name+'_dim_diff']),
make_node('Abs', [name+'_dim_diff'], [name+'_pad_len']),
make_node('Less', [name+'_targ_dim', name+'_orig_dim'], [name+'_targ_less_orig']),
make_node('Less', [name+'_orig_dim', name+'_targ_dim'], [name+'_orig_less_targ']),
make_node('Where', [name+'_targ_less_orig', name+'_pad_len', name+'_0'],
[name+'_targ_pad_len']),
make_node('Where', [name+'_orig_less_targ', name+'_pad_len', name+'_0'],
[name+'_orig_pad_len']),
make_node('Concat', [name+'_targ_pad_len', name+'_0'], [name+'_targ_pads'], axis=0),
make_node('Concat', [name+'_orig_pad_len', name+'_0'], [name+'_orig_pads'], axis=0),
make_node('Pad', [name+'_targ_shape', name+'_targ_pads', name+'_1'],
[name+'_targ_shape_padded'], mode='constant'),
make_node('Pad', [name+'_orig_shape', name+'_orig_pads', name+'_1'],
[name+'_orig_shape_padded'], mode='constant'),
make_node('Equal', [name+'_targ_shape_padded', name+'_0'],
[name+'_targ_shape_0_mask']),
make_node('Where', [name+'_targ_shape_0_mask', name+'_orig_shape_padded',
name+'_targ_shape_padded'], [name+'_targ_shape_new']),
make_node('Shape', [name+'_targ_shape_new'], [name+'_targ_new_dim']),
make_node('Slice', [name+'_targ_shape_new', name+'_targ_pad_len',
name+'_targ_new_dim'], [name+'_targ_shape_final']),
make_node('Reshape', [input_nodes[0], name+'_targ_shape_final'], [name], name=name)
]

return [tensor_node, reshape_node]
return nodes
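
The reverse branch has to emulate MXNet's reshape semantics at export time, when the input shape is only available as a runtime tensor: 0 copies the corresponding input dimension, -1 is inferred, and reverse=True applies those rules right-aligned, which is what the Pad/Where/Slice chain above implements. A hedged numpy sketch of the intended behaviour (not the exporter code itself; it assumes every 0 lines up with an existing input dimension):

import numpy as np

def mx_reshape_sketch(data, targ_shape, reverse=False):
    # Resolve MXNet's special value 0 (copy input dim); numpy resolves -1.
    data_shape = list(data.shape)
    targ = list(targ_shape)
    if reverse:                      # align the rules from the right
        data_shape.reverse()
        targ.reverse()
    out = [data_shape[i] if d == 0 else d for i, d in enumerate(targ)]
    if reverse:
        out.reverse()
    return data.reshape(out)

x = np.ones((2, 3, 4, 5, 6))
print(mx_reshape_sketch(x, (6, 1, 0, -1)).shape)                # (6, 1, 4, 30)
print(mx_reshape_sketch(x, (3, -1, 0, 0), reverse=True).shape)  # (3, 8, 5, 6)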

@mx_op.register("Cast")
def convert_cast(node, **kwargs):
@@ -1623,27 +1632,34 @@ def convert_slice_axis(node, **kwargs):
"""Map MXNet's slice_axis operator attributes to onnx's Slice operator
and return the created node.
"""
from onnx.helper import make_node
name, input_nodes, attrs = get_inputs(node, kwargs)

axes = int(attrs.get("axis"))
starts = int(attrs.get("begin"))
ends = attrs.get("end", None)
if not ends or ends == 'None':
axis = int(attrs.get("axis"))
begin = int(attrs.get("begin"))
end = attrs.get("end", None)

nodes = []
create_tensor([axis], name+'_axis', kwargs["initializer"])
create_tensor([begin], name+'_begin', kwargs["initializer"])
if not end or end == 'None':
# ONNX doesn't support None for end. Since end=None means slicing
# to the end of the axis, pass the dimension size as the end instead.
in_shape = kwargs['in_shape'][0]
ends = in_shape[axes]
create_tensor([axis+1], name+"_axis_plus_1", kwargs["initializer"])
nodes += [
make_node('Shape', [input_nodes[0]], [name+"_data_shape"]),
make_node('Slice', [name+'_data_shape', name+'_axis', name+'_axis_plus_1'],
[name+"_end"])
]
else:
create_tensor([int(end)], name+'_end', kwargs["initializer"])

node = onnx.helper.make_node(
"Slice",
input_nodes,
[name],
axes=[axes],
starts=[starts],
ends=[int(ends)],
name=name,
)
return [node]
nodes += [
make_node('Slice', [input_nodes[0], name+'_begin', name+'_end', name+'_axis'],
[name], name=name)
]

return nodes
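
With opset 10+ Slice, begin/end/axes are passed as input tensors rather than attributes, which is why the missing end has to be materialized from the data shape above. A small numpy sketch of the slice_axis behaviour being exported (the cases mirror the new tests below):

import numpy as np

def slice_axis_sketch(data, axis, begin, end):
    # end=None means "to the end of the axis", matching the Shape/Slice
    # fallback in the converter above; negative indices are allowed.
    idx = [slice(None)] * data.ndim
    idx[axis] = slice(begin, end)
    return data[tuple(idx)]

x = np.arange(12, dtype=np.float32).reshape(3, 4)
print(slice_axis_sketch(x, axis=0, begin=1, end=None).shape)   # (2, 4)
print(slice_axis_sketch(x, axis=1, begin=-3, end=-1).shape)    # (3, 2)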


@mx_op.register("SliceChannel")
@@ -2289,15 +2305,14 @@ def convert_matmul_selfatt_qk(node, **kwargs):
heads = int(attrs.get('heads'))

# a, b, c, d, e are seq_len, batch_size, num_heads, 3, head_dim respectively
create_tensor([0], name+"_0", kwargs["initializer"])
create_tensor([1], name+"_1", kwargs["initializer"])
create_tensor([1], name+"_1_f", kwargs["initializer"], dtype='float32')
create_tensor([2], name+"_2", kwargs["initializer"])
create_tensor([3], name+"_3", kwargs["initializer"])
create_tensor([heads], name+"_c", kwargs["initializer"])
create_tensor([3], name+"_d", kwargs["initializer"])

nodes = [
create_tensor([0], name+"_0", kwargs["initializer"]),
create_tensor([1], name+"_1", kwargs["initializer"]),
create_tensor([1], name+"_1_f", kwargs["initializer"], dtype='float32'),
create_tensor([2], name+"_2", kwargs["initializer"]),
create_tensor([3], name+"_3", kwargs["initializer"]),
create_tensor([heads], name+"_c", kwargs["initializer"]),
create_tensor([3], name+"_d", kwargs["initializer"]),
make_node('Shape', [input_nodes[0]], [name+"_data_shape"]),
make_node('Slice', [name+'_data_shape', name+'_0', name+'_1'], [name+"_a"]),
make_node('Slice', [name+'_data_shape', name+'_1', name+'_2'], [name+"_b"]),
@@ -2358,14 +2373,13 @@ def convert_broadcast_axis(node, **kwargs):
size = convert_string_to_list(attrs.get('size', '()'))
assert len(axis) == len(size)

create_tensor([0], name+'_0', kwargs["initializer"])
create_tensor([1], name+'_1', kwargs["initializer"])
create_tensor([], name+'_void', kwargs["initializer"])
create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)

shape_name = name+'_shape_0'
nodes = [
create_tensor([0], name+'_0', kwargs["initializer"]),
create_tensor([1], name+'_1', kwargs["initializer"]),
create_tensor([], name+'_void', kwargs["initializer"]),
create_const_scalar_node(name+'_0_s', np.int64(0), kwargs),
create_const_scalar_node(name+'_1_s', np.int64(1), kwargs),
make_node('Shape', [input_nodes[0]], [shape_name]),
make_node('Shape', [shape_name], [name+'_in_dim']),
make_node('Reshape', [name+'_in_dim', name+'_void'], [name+'_in_dim_s']),
@@ -2374,17 +2388,16 @@

for i, axis in enumerate(axis):
if axis not in (0, 1):
create_tensor([axis], name+'_'+str(axis), kwargs["initializer"])
create_tensor([size[i]-1], name+'_size_'+str(i), kwargs["initializer"])
_ = [
nodes += [create_tensor([axis], name+'_'+str(axis), kwargs["initializer"])]
nodes += [
create_tensor([size[i]-1], name+'_size_'+str(i), kwargs["initializer"]),
make_node('Equal', [name+'_range', name+'_'+str(axis)], [name+'_equal_'+str(i)]),
make_node('Cast', [name+'_equal_'+str(i)], [name+'_cast_'+str(i)], to=int(TensorProto.INT64)),
make_node('Mul', [name+'_size_'+str(i), name+'_cast_'+str(i)], [name+'_mul_'+str(i)]),
make_node('Add', [name+'_mul_'+str(i), name+'_1'], [name+'_add_'+str(i)]),
make_node('Mul', [name+'_add_'+str(i), shape_name], [name+'_shape_'+str(i+1)])
]
shape_name = name+'_shape_'+str(i+1)
nodes += _

nodes += [make_node('Expand', [input_nodes[0], shape_name], [name], name=name)]
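
The loop above builds the Expand target shape with integer arithmetic, since the input shape is only a runtime tensor: for each broadcast axis the shape is multiplied by (size-1) * one_hot(axis) + 1. A hedged numpy sketch of that shape computation (assuming the broadcast axes already have size 1, as broadcast_axis requires):

import numpy as np

def broadcast_axis_shape_sketch(in_shape, axis, size):
    shape = np.array(in_shape, dtype=np.int64)
    for ax, sz in zip(axis, size):
        mask = (np.arange(len(shape)) == ax).astype(np.int64)  # Equal + Cast
        shape = ((sz - 1) * mask + 1) * shape                   # Mul, Add, Mul
    return shape

print(broadcast_axis_shape_sketch((1, 4, 1), axis=(0, 2), size=(3, 5)))  # [3 4 5]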

Expand All @@ -2407,16 +2420,15 @@ def convert_sequencemask(node, **kwargs):
if(use_sequence_length == 'False'):
return [make_node('Identity', [input_nodes[0]], [name], name=name)]

create_tensor([], name+'_void', kwargs["initializer"])
create_tensor([0], name+'_0', kwargs["initializer"])
create_tensor([1], name+'_1', kwargs["initializer"])
create_tensor([2], name+'_2', kwargs["initializer"])
create_const_scalar_node(name+'_0_s', np.int64(0), kwargs)
create_const_scalar_node(name+'_1_s', np.int64(1), kwargs)
create_const_scalar_node(name+'_2_s', np.int64(2), kwargs)
create_tensor([mask_val], name+'_mask_val', kwargs["initializer"], dtype='float32')

nodes = [
create_tensor([], name+'_void', kwargs["initializer"]),
create_tensor([0], name+'_0', kwargs["initializer"]),
create_tensor([1], name+'_1', kwargs["initializer"]),
create_tensor([2], name+'_2', kwargs["initializer"]),
create_const_scalar_node(name+'_0_s', np.int64(0), kwargs),
create_const_scalar_node(name+'_1_s', np.int64(1), kwargs),
create_const_scalar_node(name+'_2_s', np.int64(2), kwargs),
create_tensor([mask_val], name+'_mask_val', kwargs["initializer"], dtype='float32'),
make_node('Shape', [input_nodes[0]], [name+'_in_shape']),
make_node('Slice', [name+'_in_shape', name+'_0', name+'_1'], [name+'_slice_0']),
make_node('Slice', [name+'_in_shape', name+'_1', name+'_2'], [name+'_slice_1']),
@@ -2459,7 +2471,7 @@ def convert_embedding(node, **kwargs):
axis = int(attrs.get('axis', 0))
node = onnx.helper.make_node(
"Gather",
input_nodes,
[input_nodes[1], input_nodes[0]],
[name],
axis=axis,
name=name
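
The input swap above appears to be the Embedding fix referenced in the PR title: MXNet's Embedding takes (indices, weight) while ONNX Gather takes (data, indices), so the weight matrix has to be passed first. A small numpy sketch of the equivalent lookup (shapes mirror the new test below):

import numpy as np

indices = np.array([[1, 3], [0, 2]])                        # MXNet input (data)
weight = np.arange(20, dtype=np.float32).reshape(4, 5)      # input_dim=4, output_dim=5
out = np.take(weight, indices, axis=0)                      # Gather(weight, indices)
print(out.shape)                                            # (2, 2, 5)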
44 changes: 39 additions & 5 deletions tests/python-pytest/onnx/test_operators.py
@@ -137,16 +137,52 @@ def test_onnx_export_SequenceMask(tmp_path, dtype):
op_export_test('SequenceMask_2', M2, [x, seq_len2], tmp_path)


@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32'])
@pytest.mark.parametrize('dtype', ['float32'])
def test_onnx_export_contrib_interleaved_matmul_selfatt_qk(tmp_path, dtype):
M1 = def_model('contrib.interleaved_matmul_selfatt_qk', heads=3)
x1 = mx.nd.random.uniform(0, 1, (3, 3, 3*3*3))
x1 = mx.nd.random.uniform(0, 1, (3, 3, 3*3*3), dtype=dtype)
op_export_test('contrib_interleaved_matmul_selfatt_qk_1', M1, [x1], tmp_path)
M2 = def_model('contrib.interleaved_matmul_selfatt_qk', heads=5)
x2 = mx.nd.random.uniform(0, 1, (7, 5, 4*5*6))
x2 = mx.nd.random.uniform(0, 1, (7, 5, 4*5*6), dtype=dtype)
op_export_test('contrib_interleaved_matmul_selfatt_qk_2', M2, [x2], tmp_path)


@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32'])
def test_onnx_export_slice_axis(tmp_path, dtype):
x = mx.nd.array([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]], dtype=dtype)
M1 = def_model('slice_axis', axis=0, begin=1, end=3)
M2 = def_model('slice_axis', axis=0, begin=1, end=None)
M3 = def_model('slice_axis', axis=1, begin=-3, end=-1)
op_export_test('slice_axis_1', M1, [x], tmp_path)
op_export_test('slice_axis_2', M2, [x], tmp_path)
op_export_test('slice_axis_3', M3, [x], tmp_path)


@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
def test_onnx_export_reshape(tmp_path, dtype):
x = mx.nd.ones((2, 3, 4, 5, 6), dtype=dtype)
M1 = def_model('reshape', shape=(6, 1, 0, -1))
op_export_test('reshape_1', M1, [x], tmp_path)
M2 = def_model('reshape', shape=(3, -1, 0, 0), reverse=True)
op_export_test('reshape_2', M2, [x], tmp_path)
M3 = def_model('reshape', shape=(5, 1, 1, 1, 1, 0, -1, 0), reverse=True)
op_export_test('reshape_3', M3, [x], tmp_path)


@pytest.mark.parametrize('dtype', ['int32', 'int64'])
def test_onnx_export_embedding(tmp_path, dtype):
x = mx.nd.array([[ 1., 3.],
[ 0., 2.]], dtype=dtype)
y = mx.nd.array([[ 0., 1., 2., 3., 4.],
[ 5., 6., 7., 8., 9.],
[ 10., 11., 12., 13., 14.],
[ 15., 16., 17., 18., 19.]], dtype=dtype)
M = def_model('Embedding', input_dim=4, output_dim=5)
op_export_test('Embedding', M, [x, y], tmp_path)


@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int32', 'int64'])
@pytest.mark.parametrize('num_hidden', [1, 5, 10, 20])
@pytest.mark.parametrize('no_bias', [False, True])
@@ -159,5 +195,3 @@ def test_onnx_export_fully_connected(tmp_path, dtype, num_hidden, no_bias, flatt
if not no_bias:
args.append(mx.nd.random.uniform(0,1,(num_hidden,)))
op_export_test('FullyConnected', M, args, tmp_path)