Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix embedding and output order (#20305)
Browse files Browse the repository at this point in the history
Co-authored-by: Wei Chu <weichu@amazon.com>
  • Loading branch information
waytrue17 and Wei Chu committed May 26, 2021
1 parent 7ce251e commit c2ab102
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
13 changes: 11 additions & 2 deletions python/mxnet/onnx/mx2onnx/_export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,17 @@ def __init__(self, name, dtype):
# if node_output_names is empty then we use the last returned node as output
if not node_output_names:
node_output_names = [converted[-1].name]
# process node outputs (sort by alphabetical order)
node_output_names.sort()
# process node outputs (sort by output index)
def str2int(s):
import re
i = re.search(r'\d{0,2}$', s).group()
if i == '':
return 0
else:
return int(i)

sorted(node_output_names, key=str2int)

# match the output names to output dtypes
if dtypes is not None:
assert len(node_output_names) == len(dtypes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3094,13 +3094,14 @@ def convert_embedding(node, **kwargs):

name, input_nodes, attrs = get_inputs(node, kwargs)
axis = int(attrs.get('axis', 0))
dtype = str(attrs.get('dtype', 'float32'))

nodes = [
make_node('Cast', [input_nodes[0]], [name+'_indices_casted'], to=int(TensorProto.INT64)),
make_node('Gather', [input_nodes[1], name+'_indices_casted'], [name], axis=axis, name=name)
]

return nodes
return nodes, (dtype, )


@mx_op.register("stack")
Expand Down

0 comments on commit c2ab102

Please sign in to comment.