From 54aba8dd820996ac1d99145806e80ffa6dfa3224 Mon Sep 17 00:00:00 2001 From: pinto0309 Date: Wed, 31 May 2023 22:45:36 +0900 Subject: [PATCH] Improved stability of transformations when `axes` is None and slices are for multiple axes --- README.md | 4 ++-- onnx2tf/__init__.py | 2 +- onnx2tf/ops/Slice.py | 34 +++++++++++++++++++--------------- 3 files changed, 22 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 889beec9..3e435eea 100644 --- a/README.md +++ b/README.md @@ -249,11 +249,11 @@ Video speed is adjusted approximately 50 times slower than actual speed. Username (xxxx): {Enter} Password: {Personal Access Token} Login Succeeded - + $ docker run --rm -it \ -v `pwd`:/workdir \ -w /workdir \ - ghcr.io/pinto0309/onnx2tf:1.13.3 + ghcr.io/pinto0309/onnx2tf:1.13.4 or diff --git a/onnx2tf/__init__.py b/onnx2tf/__init__.py index 60fec8b1..9802e96d 100644 --- a/onnx2tf/__init__.py +++ b/onnx2tf/__init__.py @@ -1,3 +1,3 @@ from onnx2tf.onnx2tf import convert, main -__version__ = '1.13.3' +__version__ = '1.13.4' diff --git a/onnx2tf/ops/Slice.py b/onnx2tf/ops/Slice.py index b610bfac..b9314d11 100644 --- a/onnx2tf/ops/Slice.py +++ b/onnx2tf/ops/Slice.py @@ -400,21 +400,25 @@ def make_node( else: onnx_slice_dims_count = len(starts) - tf_layers_dict[graph_node_output.name]['tf_node'] = \ - stridedslice_with_flexing_deterrence( - input_tensor=input_tensor, - begin=begin_, - end=end_, - strides=strides_, - begin_mask=begin_mask_, - end_mask=end_mask_, - ignore_axes=axes, - compression_defult_value=COMPRESSION_DEFAULT_VALUE, - onnx_slice_dims_count=onnx_slice_dims_count, - output_shape=tf_layers_dict[graph_node_output.name]['tf_node'].shape, - name=graph_node.name, - **kwargs, - ) + if onnx_slice_dims_count > COMPRESSION_DEFAULT_VALUE: + ignore_axes = axes + if axes is None: + ignore_axes = [idx for idx in range(input_tensor_rank)] + tf_layers_dict[graph_node_output.name]['tf_node'] = \ + stridedslice_with_flexing_deterrence( + input_tensor=input_tensor, + begin=begin_, + end=end_, + strides=strides_, + begin_mask=begin_mask_, + end_mask=end_mask_, + ignore_axes=ignore_axes, + compression_defult_value=COMPRESSION_DEFAULT_VALUE, + onnx_slice_dims_count=onnx_slice_dims_count, + output_shape=tf_layers_dict[graph_node_output.name]['tf_node'].shape, + name=graph_node.name, + **kwargs, + ) else: # OP replacement tf_layers_dict[graph_node_output.name]['tf_node'] = \