Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved stability of transformations when axes is None and slices are for multiple axes #377

Merged
merged 1 commit into from May 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Expand Up @@ -249,11 +249,11 @@ Video speed is adjusted approximately 50 times slower than actual speed.
Username (xxxx): {Enter}
Password: {Personal Access Token}
Login Succeeded

$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.13.3
ghcr.io/pinto0309/onnx2tf:1.13.4

or

Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.13.3'
__version__ = '1.13.4'
34 changes: 19 additions & 15 deletions onnx2tf/ops/Slice.py
Expand Up @@ -400,21 +400,25 @@ def make_node(
else:
onnx_slice_dims_count = len(starts)

tf_layers_dict[graph_node_output.name]['tf_node'] = \
stridedslice_with_flexing_deterrence(
input_tensor=input_tensor,
begin=begin_,
end=end_,
strides=strides_,
begin_mask=begin_mask_,
end_mask=end_mask_,
ignore_axes=axes,
compression_defult_value=COMPRESSION_DEFAULT_VALUE,
onnx_slice_dims_count=onnx_slice_dims_count,
output_shape=tf_layers_dict[graph_node_output.name]['tf_node'].shape,
name=graph_node.name,
**kwargs,
)
if onnx_slice_dims_count > COMPRESSION_DEFAULT_VALUE:
ignore_axes = axes
if axes is None:
ignore_axes = [idx for idx in range(input_tensor_rank)]
tf_layers_dict[graph_node_output.name]['tf_node'] = \
stridedslice_with_flexing_deterrence(
input_tensor=input_tensor,
begin=begin_,
end=end_,
strides=strides_,
begin_mask=begin_mask_,
end_mask=end_mask_,
ignore_axes=ignore_axes,
compression_defult_value=COMPRESSION_DEFAULT_VALUE,
onnx_slice_dims_count=onnx_slice_dims_count,
output_shape=tf_layers_dict[graph_node_output.name]['tf_node'].shape,
name=graph_node.name,
**kwargs,
)
else:
# OP replacement
tf_layers_dict[graph_node_output.name]['tf_node'] = \
Expand Down