Skip to content

Commit

Permalink
Supports some undefined dimensions (UNK, None)
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Nov 20, 2022
1 parent 61e8cce commit 86cc1a0
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.1.26
ghcr.io/pinto0309/onnx2tf:1.1.27
or
Expand Down Expand Up @@ -139,8 +139,8 @@ $ onnx2tf -i emotion-ferplus-8.onnx -oiqt
$ onnx2tf -i emotion-ferplus-8.onnx -oiqt -qt per-tensor

# Parameter replacement (Resize,Transpose,Softmax)
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/1.1.12/human_segmentation_pphumanseg_2021oct.onnx
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/1.1.12/replace.json
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/1.1.27/human_segmentation_pphumanseg_2021oct.onnx
$ wget https://github.com/PINTO0309/onnx2tf/releases/download/1.1.27/replace.json
$ onnx2tf -i human_segmentation_pphumanseg_2021oct.onnx -prf replace.json
```
## CLI Parameter
Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.1.26'
__version__ = '1.1.27'
7 changes: 5 additions & 2 deletions onnx2tf/utils/common_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
from typing import Any, List, Optional
from functools import wraps
from collections import namedtuple

from onnx2tf.utils.enums import (
TF_DTYPES_TO_NUMPY_DTYPES,
)

INF_INDEX_VALUE: int = 4294967296

def get_replacement_parameter(func):
@wraps(func)
Expand Down Expand Up @@ -644,7 +644,10 @@ def explicit_broadcast(
graph_node_input_name1, graph_node_input_name2 = graph_node_input_name2, graph_node_input_name1

# If const_or_var_2.shape is all 1's, do not broadcast and return as is
if np.prod(const_or_var_2.shape) == 1:
shape_for_judging_skip_processing = [
i if i is not None else INF_INDEX_VALUE for i in const_or_var_2.shape
]
if np.prod(shape_for_judging_skip_processing) == 1:
return const_or_var_1, const_or_var_2

const_or_var_1_shape = const_or_var_1.shape
Expand Down

0 comments on commit 86cc1a0

Please sign in to comment.