From 579a631c44c97de838379b138858277999754b94 Mon Sep 17 00:00:00 2001 From: waytrue17 <52505574+waytrue17@users.noreply.github.com> Date: Thu, 3 Jun 2021 00:23:27 -0700 Subject: [PATCH] [v1.9.x] ONNX fix node output sort (#20327) * fix output sort * fix sanity * fix sanity * fix sanity * Update _export_onnx.py Co-authored-by: Wei Chu Co-authored-by: Zhaoqi Zhu --- python/mxnet/onnx/mx2onnx/_export_onnx.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/mxnet/onnx/mx2onnx/_export_onnx.py b/python/mxnet/onnx/mx2onnx/_export_onnx.py index 20bf2fe7f980..7c96b2896a49 100644 --- a/python/mxnet/onnx/mx2onnx/_export_onnx.py +++ b/python/mxnet/onnx/mx2onnx/_export_onnx.py @@ -46,6 +46,7 @@ # coding: utf-8 # pylint: disable=invalid-name,too-many-locals,no-self-use,too-many-arguments, # pylint: disable=maybe-no-member,too-many-nested-blocks,logging-not-lazy +# pylint: disable=cell-var-from-loop """MXNet to ONNX graph converter functions""" import logging import json @@ -393,15 +394,14 @@ def __init__(self, name, dtype): if not node_output_names: node_output_names = [converted[-1].name] # process node outputs (sort by output index) - def str2int(s): - import re - i = re.search(r'\d{0,2}$', s).group() - if i == '': - return 0 + def str2int(s, name): + l = len(name) + if len(s) == l: + return -1 else: - return int(i) + return int(s[l:]) - sorted(node_output_names, key=str2int) + node_output_names = sorted(node_output_names, key=lambda x: str2int(x, name)) # match the output names to output dtypes if dtypes is not None: