Skip to content

Commit

Permalink
Fixed bugs in parameter replacement logic and Flatten attribute repla…
Browse files Browse the repository at this point in the history
…cement timing
  • Loading branch information
PINTO0309 committed Oct 27, 2022
1 parent ebece01 commit fca5c2a
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.0.27
ghcr.io/pinto0309/onnx2tf:1.0.28
or
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.0.27'
__version__ = '1.0.28'
14 changes: 8 additions & 6 deletions onnx2tf/ops/Flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ def make_node(
'dtype': dtype,
}

# Param replacement
axis = replace_parameter(
value_before_replacement=axis,
param_target='attributes',
param_name='axis',
**kwargs,
)

# Generation of TF OP
cal_shape = None
if axis == 0:
Expand All @@ -85,12 +93,6 @@ def make_node(
param_name=graph_node.inputs[0].name,
**kwargs,
)
cal_shape = replace_parameter(
value_before_replacement=cal_shape,
param_target='attributes',
param_name='axis',
**kwargs,
)

# Pre-process transpose
input_tensor = pre_process_transpose(
Expand Down
3 changes: 2 additions & 1 deletion onnx2tf/utils/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def replace_parameter(
op_rep_params = kwargs.get('op_rep_params', [])
for op_rep_param in op_rep_params:
if op_rep_param['param_target'] == param_target \
and op_rep_param['param_name'] == param_name:
and op_rep_param['param_name'] == param_name \
and 'values' in op_rep_param:
replace_value = op_rep_param.get('values', value_before_replacement)
if isinstance(value_before_replacement, np.ndarray):
replace_value = np.asarray(
Expand Down

0 comments on commit fca5c2a

Please sign in to comment.