Skip to content

Commit

Permalink
Improved conversion stability for tensors with None geometry
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Sep 26, 2023
1 parent 6975a1b commit adafca4
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 7 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,15 +255,15 @@ Video speed is adjusted approximately 50 times slower than actual speed.
$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.17.8
ghcr.io/pinto0309/onnx2tf:1.17.9

or

# Authentication is not required for pulls from Docker Hub.
$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
docker.io/pinto0309/onnx2tf:1.17.8
docker.io/pinto0309/onnx2tf:1.17.9

or

Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.17.8'
__version__ = '1.17.9'
3 changes: 2 additions & 1 deletion onnx2tf/ops/Concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ def make_node(

# TensorFlow does not support Concat for scalar values, so convert to tensor
values = [
value if len(value.shape) > 0 else tf.reshape(value, [1]) for value in values
value if value.shape != tf.TensorShape(None) \
and len(value.shape) > 0 else tf.reshape(value, [1]) for value in values
]

# Generation of TF OP
Expand Down
3 changes: 3 additions & 0 deletions onnx2tf/ops/Gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,9 @@ def make_node(
and input_tensor.shape[axis] is not None:
maximum_number_of_elements = input_tensor.shape[axis]
indices_values = indices_values + maximum_number_of_elements
elif tf.keras.backend.is_keras_tensor(indices_values) \
and indices_values.shape == tf.TensorShape(None):
indices_values = tf.reshape(indices_values, [-1])

tf_layers_dict[graph_node_output.name]['tf_node'] = \
tf.gather(
Expand Down
10 changes: 7 additions & 3 deletions onnx2tf/ops/Slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,9 +410,13 @@ def make_node(
name=graph_node.name,
)

check_input_shape = list(input_tensor_shape)
check_output_shape = list(tf_layers_dict[graph_node_output.name]['tf_node'].shape)
if None not in check_input_shape \
check_input_shape = list(input_tensor_shape) \
if input_tensor_shape != tf.TensorShape(None) else None
check_output_shape = list(tf_layers_dict[graph_node_output.name]['tf_node'].shape) \
if tf_layers_dict[graph_node_output.name]['tf_node'].shape != tf.TensorShape(None) else None
if check_input_shape is not None \
and check_output_shape is not None \
and None not in check_input_shape \
and None not in check_output_shape \
and check_input_shape == check_output_shape:
# Disable useless slice
Expand Down

0 comments on commit adafca4

Please sign in to comment.