Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[v1.x] ONNX export support for multiple input data types (#19796)
Browse files Browse the repository at this point in the history
* add test

* support multiple input nodes

* fix sanity

* update input dtype

* fix typo

* update export_onnx

* fix sanity

* fix space

* update import

* fix sanity

* remove float64 from test_where

* update test

* fix bert test input type

* enable defalut input_type

* more default fix

* fix typo

* fix empty lines

Co-authored-by: Wei Chu <weichu@amazon.com>
  • Loading branch information
waytrue17 and Wei Chu committed Mar 7, 2021
1 parent 9a2a502 commit b14aae2
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 22 deletions.
12 changes: 7 additions & 5 deletions python/mxnet/contrib/onnx/mx2onnx/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ def export_model(sym, params, input_shape, input_type=np.float32,
Path to the params file or params dictionary. (Including both arg_params and aux_params)
input_shape : List of tuple
Input shape of the model e.g [(1,3,224,224)]
input_type : data type
Input data type e.g. np.float32
input_type : data type or list of data types
Input data type e.g. np.float32, or [np.float32, np.int32]
onnx_file_path : str
Path where to save the generated onnx file
verbose : Boolean
Expand Down Expand Up @@ -73,17 +73,19 @@ def export_model(sym, params, input_shape, input_type=np.float32,
# default is to use latest opset version the onnx package supports
opset_version = onnx_opset_version()

data_format = np.dtype(input_type)
if not isinstance(input_type, list):
input_type = [input_type for _ in range(len(input_shape))]
input_dtype = [mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(inp_type)] for inp_type in input_type]
# if input parameters are strings(file paths), load files and create symbol parameter objects
if isinstance(sym, string_types) and isinstance(params, string_types):
logging.info("Converting json and weight file to sym and params")
sym_obj, params_obj = load_module(sym, params)
onnx_graph = converter.create_onnx_graph_proto(sym_obj, params_obj, input_shape,
mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
input_dtype,
verbose=verbose, opset_version=opset_version)
elif isinstance(sym, symbol.Symbol) and isinstance(params, dict):
onnx_graph = converter.create_onnx_graph_proto(sym, params, input_shape,
mapping.NP_TYPE_TO_TENSOR_TYPE[data_format],
input_dtype,
verbose=verbose, opset_version=opset_version)
else:
raise ValueError("Input sym and params should either be files or objects")
Expand Down
22 changes: 16 additions & 6 deletions python/mxnet/contrib/onnx/mx2onnx/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,19 @@ def get_outputs(sym, params, in_shape, in_label, in_type):

assert len(out_shapes) == len(out_names)

# infer output types
args = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[in_type] for n in sym.list_inputs()}
_, out_type, _ = sym.infer_type(**args)
## Infer output types
# Remove any input listed in params from sym.list_inputs() and bind them to the input types provided
# by user. Also remove in_label
in_dtype = {n: mapping.TENSOR_TYPE_TO_NP_TYPE[t]
for n, t in zip([n for n in sym.list_inputs() if n not in params and n != in_label], in_type)}
# Add params and their types to list of inputs
in_dtype.update({n: v.dtype for n, v in params.items() if n in sym.list_inputs()})
_, out_type, _ = sym.infer_type(**in_dtype)
out_types = [mapping.NP_TYPE_TO_TENSOR_TYPE[o(0).dtype] for o in out_type]

assert len(out_types) == len(out_names)

# bind output shapes with output names
# bind output shapes/types with output names
graph_outputs = {n: {'shape': s, 'dtype': d} for n, s, d in zip(out_names, out_shapes, out_types)}

return graph_outputs
Expand Down Expand Up @@ -256,21 +261,26 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False,
mx_graph=mx_graph,
weights=weights,
in_shape=in_shape[graph_input_idx],
in_type=in_type,
in_type=in_type[graph_input_idx],
proc_nodes=all_processed_nodes,
initializer=initializer,
outputs_lookup=outputs_lookup)
graph_input_idx += 1

else:
# Handle no input case
intype = 1 # Float32 in tensor type
if len(in_type) > 0:
intype = in_type[0]

# Handling graph layers
converted = MXNetGraph.convert_layer(
node,
is_input=False,
mx_graph=mx_graph,
weights=weights,
in_shape=in_shape,
in_type=in_type,
in_type=intype,
proc_nodes=all_processed_nodes,
initializer=initializer,
outputs_lookup=outputs_lookup,
Expand Down
3 changes: 2 additions & 1 deletion python/mxnet/contrib/onnx/onnx2mx/import_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,8 @@ def get_graph_metadata(self, graph):
for graph_input in graph.input:
if graph_input.name not in _params:
shape = [val.dim_value for val in graph_input.type.tensor_type.shape.dim]
input_data.append((graph_input.name, tuple(shape)))
dtype = graph_input.type.tensor_type.elem_type
input_data.append((graph_input.name, tuple(shape), dtype))

output_data = []
for graph_out in graph.output:
Expand Down
7 changes: 4 additions & 3 deletions tests/python-pytest/onnx/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def set_params(cls, backend, operation):
cls.operation = operation

@staticmethod
def perform_import_export(sym, arg_params, aux_params, input_shape):
def perform_import_export(sym, arg_params, aux_params, input_shape, input_dtype):
""" Import ONNX model to mxnet model and then export to ONNX model
and then import it back to mxnet for verifying the result"""
graph = GraphProto()
Expand All @@ -63,7 +63,7 @@ def perform_import_export(sym, arg_params, aux_params, input_shape):
# exporting to onnx graph proto format
converter = MXNetGraph()
graph_proto = converter.create_onnx_graph_proto(sym, params, in_shape=input_shape,
in_type=mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype('float32')],
in_type=input_dtype,
opset_version=opset_version)

# importing back to MXNET for verifying result.
Expand Down Expand Up @@ -108,8 +108,9 @@ def prepare(cls, model, device='CPU', **kwargs):
metadata = graph.get_graph_metadata(model.graph)
input_data = metadata['input_tensor_data']
input_shape = [data[1] for data in input_data]
input_dtype = [data[2] for data in input_data]
sym, arg_params, aux_params = MXNetBackend.perform_import_export(sym, arg_params, aux_params,
input_shape)
input_shape, input_dtype)

return MXNetBackendRep(sym, arg_params, aux_params, device)
elif backend == 'gluon':
Expand Down
1 change: 1 addition & 0 deletions tests/python-pytest/onnx/mxnet_export_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def _check_onnx_export(net, group_outputs=False, shape_type=tuple, extra_params=
sym=net_sym,
params=net_params,
input_shape=[shape_type(data.shape)],
input_type=[data.dtype],
onnx_file_path=onnx_file_path)
assert export_path == onnx_file_path
# Try importing the model to symbol
Expand Down
6 changes: 3 additions & 3 deletions tests/python-pytest/onnx/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def test_import_export(self):

if mxnet_specific:
onnxmodelfile = onnx_mxnet.export_model(test_op, {}, [np.shape(ip) for ip in inputs],
np.float32,
[ip.dtype for ip in inputs],
onnx_name + ".onnx")
onnxmodel = load_model(onnxmodelfile)
else:
Expand Down Expand Up @@ -190,9 +190,9 @@ def test_import_export(self):
onnx_file_path=outsym.name + ".onnx")

sym, arg_params, aux_params = onnx_mxnet.import_model(converted_model)
result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)
result = forward_pass(sym, arg_params, aux_params, ['input1'], input1)

npt.assert_almost_equal(result, forward_op)
npt.assert_almost_equal(result, forward_op)

def test_imports(self):
for test in import_test_cases:
Expand Down
55 changes: 55 additions & 0 deletions tests/python-pytest/onnx/test_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,3 +555,58 @@ def load_video(filepath):
finally:
shutil.rmtree(tmp_path)

@with_seed()
@pytest.mark.parametrize('model', ['bert_12_768_12'])
def test_bert_inference_onnxruntime(tmp_path, model):
tmp_path = str(tmp_path)
try:
import gluonnlp as nlp
dataset = 'book_corpus_wiki_en_uncased'
ctx = mx.cpu(0)
model, vocab = nlp.model.get_model(
name=model,
ctx=ctx,
dataset_name=dataset,
pretrained=False,
use_pooler=True,
use_decoder=False,
use_classifier=False)
model.initialize(ctx=ctx)
model.hybridize(static_alloc=True)

batch = 5
seq_length = 16
# create synthetic test data
inputs = mx.nd.random.uniform(0, 30522, shape=(batch, seq_length), dtype='float32')
token_types = mx.nd.random.uniform(0, 2, shape=(batch, seq_length), dtype='float32')
valid_length = mx.nd.array([seq_length] * batch, dtype='float32')

seq_encoding, cls_encoding = model(inputs, token_types, valid_length)

prefix = "%s/bert" % tmp_path
model.export(prefix)
sym_file = "%s-symbol.json" % prefix
params_file = "%s-0000.params" % prefix
onnx_file = "%s.onnx" % prefix


input_shapes = [(batch, seq_length), (batch, seq_length), (batch,)]
input_types = [np.float32, np.float32, np.float32]
converted_model_path = mx.contrib.onnx.export_model(sym_file, params_file, input_shapes, input_types, onnx_file)


# create onnxruntime session using the generated onnx file
ses_opt = onnxruntime.SessionOptions()
ses_opt.log_severity_level = 3
session = onnxruntime.InferenceSession(onnx_file, ses_opt)
onnx_inputs = [inputs, token_types, valid_length]
input_dict = dict((session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) for i in range(len(onnx_inputs)))
pred_onx, cls_onx = session.run(None, input_dict)

assert_almost_equal(seq_encoding, pred_onx)
assert_almost_equal(cls_encoding, cls_onx)

finally:
shutil.rmtree(tmp_path)


7 changes: 3 additions & 4 deletions tests/python-pytest/onnx/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,16 +45,15 @@ def export_to_onnx(model, model_name, inputs):
model.export(model_path, epoch=0)
sym_file = '{}-symbol.json'.format(model_path)
params_file = '{}-0000.params'.format(model_path)
dtype = inputs[0].dtype
onnx_file = '{}/{}.onnx'.format(tmp_path, model_name)
mx.contrib.onnx.export_model(sym_file, params_file, [inp.shape for inp in inputs],
dtype, onnx_file)
[inp.dtype for inp in inputs], onnx_file)
return onnx_file

def onnx_rt(onnx_file, inputs):
sess = rt.InferenceSession(onnx_file)
dtype_0 = inputs[0].asnumpy().dtype
input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy().astype(dtype_0)) for i in range(len(inputs)))
input_dict = dict((sess.get_inputs()[i].name, inputs[i].asnumpy()) for i in range(len(inputs)))
pred = sess.run(None, input_dict)
return pred

Expand Down Expand Up @@ -560,7 +559,7 @@ def test_onnx_export_equal_scalar(tmp_path, dtype, scalar):
op_export_test('_internal._equal_scalar', M, [x], tmp_path)


@pytest.mark.parametrize("dtype", ["float16", "float32", "float64", "int32", "int64"])
@pytest.mark.parametrize("dtype", ["float16", "float32", "int32", "int64"])
@pytest.mark.parametrize("shape", [(1,1), (3,3), (10,2), (20,30,40)])
def test_onnx_export_where(tmp_path, dtype, shape):
M = def_model('where')
Expand Down

0 comments on commit b14aae2

Please sign in to comment.