Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] Add onnx export support and unit tests for zeros and ones. (#1…
Browse files Browse the repository at this point in the history
…9951)

* Add onnx export support and unit tests for zeros and ones.

* Fix lint

Co-authored-by: Joe Evans <joeev@amazon.com>
  • Loading branch information
josephevans and Joe Evans committed Feb 24, 2021
1 parent e1d1105 commit 01a6f3d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
34 changes: 34 additions & 0 deletions python/mxnet/contrib/onnx/mx2onnx/_op_translations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2832,6 +2832,40 @@ def convert_slice(node, **kwargs):
return nodes


@mx_op.register("_zeros")
def convert_zeros(node, **kwargs):
"""Map MXNet's zeros operator attributes to onnx's ConstantOfShape operator.
"""
from onnx.helper import make_node, make_tensor
name, _, attrs = get_inputs(node, kwargs)
dtype = attrs.get('dtype')
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
shape = convert_string_to_list(attrs.get('shape'))
create_tensor(shape, name+'_shape', kwargs['initializer'])
tensor_value = make_tensor(name+'_zero', data_type, [1], [0])
nodes = [
make_node('ConstantOfShape', [name+'_shape'], [name], name=name, value=tensor_value)
]
return nodes


@mx_op.register("_ones")
def convert_ones(node, **kwargs):
"""Map MXNet's ones operator attributes to onnx's ConstantOfShape operator.
"""
from onnx.helper import make_node, make_tensor
name, _, attrs = get_inputs(node, kwargs)
dtype = attrs.get('dtype')
data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)]
shape = convert_string_to_list(attrs.get('shape'))
create_tensor(shape, name+'_shape', kwargs['initializer'])
tensor_value = make_tensor(name+'_one', data_type, [1], [1])
nodes = [
make_node('ConstantOfShape', [name+'_shape'], [name], name=name, value=tensor_value)
]
return nodes


@mx_op.register("zeros_like")
def convert_zeros_like(node, **kwargs):
"""Map MXNet's zeros_like operator attributes to onnx's ConstantOfShape operator.
Expand Down
15 changes: 14 additions & 1 deletion tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,19 @@ def test_onnx_export_stack(tmp_path):
y = mx.nd.array([3, 4], dtype='float32')
op_export_test('stack', M, [x, y], tmp_path)

@pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64"])
@pytest.mark.parametrize("shape", [(1), (1,2), (2,3,4), (5,6,7)])
def test_onnx_export_zeros(tmp_path, dtype, shape):
M = def_model('zeros', shape=shape, dtype=dtype, dummy_input=True)
x = mx.nd.array([1])
op_export_test('zeros', M, [x], tmp_path, dummy_input=True)

@pytest.mark.parametrize("dtype", [None, "float32", "float64", "int32", "int64"])
@pytest.mark.parametrize("shape", [(1), (1,2), (2,3,4), (5,6,7)])
def test_onnx_export_ones(tmp_path, dtype, shape):
M = def_model('ones', shape=shape, dtype=dtype, dummy_input=True)
x = mx.nd.array([0])
op_export_test('ones', M, [x], tmp_path, dummy_input=True)

def test_onnx_export_zeros_like(tmp_path):
M = def_model('zeros_like')
Expand Down Expand Up @@ -1167,4 +1180,4 @@ def test_onnx_export_take_raise(tmp_path, dtype, axis):
x = mx.nd.random.normal(0, 10, (3, 4, 5)).astype(dtype)
y = mx.random.randint(0, 3, (6, 7)).astype(dtype)
M = def_model('take', axis=axis, mode='raise')
op_export_test('take', M, [x, y], tmp_path)
op_export_test('take', M, [x, y], tmp_path)

0 comments on commit 01a6f3d

Please sign in to comment.