Skip to content

Commit

Permalink
Pad bug fixes and added Upsample support
Browse files Browse the repository at this point in the history
  • Loading branch information
PINTO0309 committed Oct 27, 2022
1 parent 2393c56 commit ebece01
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 22 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ Video speed is adjusted approximately 50 times slower than actual speed.
$ docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.0.26
ghcr.io/pinto0309/onnx2tf:1.0.27
or
Expand Down Expand Up @@ -592,6 +592,7 @@ Please don't post such low level questions as issues.
|Trilu|:heavy_check_mark:|
|Unique|**Help wanted**|
|Unsqueeze|:heavy_check_mark:|
|Upsample|:heavy_check_mark:|
|Where|:heavy_check_mark:|
|Xor|:heavy_check_mark:|

Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.0.26'
__version__ = '1.0.27'
48 changes: 28 additions & 20 deletions onnx2tf/ops/Pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,7 @@ def _symmetric_pad(i, x):

# tf requires int32 paddings
paddings = tf.cast(
x=tf.transpose(
a=tf.reshape(
tensor=paddings,
shape=[2, tensor_rank],
)
),
x=paddings,
dtype=tf.int32,
)

Expand Down Expand Up @@ -143,22 +138,26 @@ def make_node(
tf_layers_dict: dict
optype, shape, dtype, tensorflow graph
"""
before_op_output_shape_trans_1 = \
tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
before_op_output_shape_trans_2 = \
tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
before_op_output_shape_trans = \
before_op_output_shape_trans_1 \
and before_op_output_shape_trans_2
before_op_output_shape_trans = True
if len(graph_node.inputs) == 1:
before_op_output_shape_trans_1 = \
tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
before_op_output_shape_trans = \
before_op_output_shape_trans_1
elif len(graph_node.inputs) >= 2:
before_op_output_shape_trans_1 = \
tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)
before_op_output_shape_trans_2 = \
tf_layers_dict.get(graph_node.inputs[1].name, {}).get('before_op_output_shape_trans', True)
before_op_output_shape_trans = \
before_op_output_shape_trans_1 \
and before_op_output_shape_trans_2

input_tensor = get_constant_or_variable(
graph_node.inputs[0],
before_op_output_shape_trans,
)
paddings = get_constant_or_variable(
graph_node.inputs[1],
before_op_output_shape_trans,
)

constant_value = 0
if len(graph_node.inputs) >= 3 and graph_node.inputs[2].name != '':
constant_value = get_constant_or_variable(
Expand All @@ -172,16 +171,25 @@ def make_node(
input_tensor = tf_layers_dict[input_tensor.name]['tf_node'] \
if isinstance(input_tensor, gs.Variable) else input_tensor
tensor_rank = len(input_tensor.shape)
paddings = tf_layers_dict[paddings.name]['tf_node'] \
if isinstance(paddings, gs.Variable) else paddings

constant_value = tf_layers_dict[constant_value.name]['tf_node'] \
if isinstance(constant_value, gs.Variable) else constant_value

# Transpose pads values
paddings = graph_node.inputs[1]
paddings = None
if len(graph_node.inputs) >= 2:
paddings = graph_node.inputs[1]
paddings = graph_node.attrs.get('pads', paddings)
if isinstance(paddings, list):
paddings = np.asarray(paddings)

values = None
if hasattr(paddings, 'values'):
values = paddings.values
elif isinstance(paddings, np.ndarray):
values = paddings

if values is not None:
paddings = values.reshape([2, tensor_rank]).transpose()
paddings_rank = paddings.shape[0]
if paddings_rank > 2:
Expand Down
130 changes: 130 additions & 0 deletions onnx2tf/ops/Upsample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import random
random.seed(0)
import numpy as np
np.random.seed(0)
import tensorflow as tf
import onnx_graphsurgeon as gs
from onnx2tf.utils.common_functions import (
get_constant_or_variable,
print_node_info,
inverted_operation_enable_disable,
make_tf_node_info,
)


@print_node_info
@inverted_operation_enable_disable
def make_node(
*,
graph_node: gs.Node,
tf_layers_dict: dict,
**kwargs: dict,
):
"""Upsample
Parameters
----------
graph_node: gs.Node
graph_surgeon Node
tf_layers_dict: dict
optype, shape, dtype, tensorflow graph
"""
before_op_output_shape_trans_1 = \
tf_layers_dict.get(graph_node.inputs[0].name, {}).get('before_op_output_shape_trans', True)

before_op_output_shape_trans = \
before_op_output_shape_trans_1

input_tensor = get_constant_or_variable(
graph_node.inputs[0],
before_op_output_shape_trans,
)
scales = None
if len(graph_node.inputs) >= 2:
scales = get_constant_or_variable(
graph_node.inputs[1],
before_op_output_shape_trans,
)
else:
scales = get_constant_or_variable(
graph_node.attrs.get('scales', scales),
before_op_output_shape_trans,
)

graph_node_output: gs.Variable = graph_node.outputs[0]
shape = graph_node_output.shape
dtype = graph_node_output.dtype

input_tensor = tf_layers_dict[input_tensor.name]['tf_node'] \
if isinstance(input_tensor, gs.Variable) else input_tensor
input_tensor_shape = input_tensor.shape
scales = tf_layers_dict[scales.name]['tf_node'] \
if isinstance(scales, gs.Variable) else scales

mode = graph_node.attrs.get('mode', 'nearest')

# Preserving Graph Structure (Dict)
tf_layers_dict[graph_node_output.name] = {
'optype': graph_node.op,
'shape': shape,
'dtype': dtype,
}

# Generation of TF OP
new_size = None
if hasattr(graph_node.outputs[0], 'shape') \
and graph_node.outputs[0].shape is not None \
and isinstance(graph_node.outputs[0].shape[-2], int) \
and isinstance(graph_node.outputs[0].shape[-1], int):
new_size = graph_node.outputs[0].shape[-2:len(graph_node.outputs[0].shape)] # Estimated from ONNX output shape
else:
h_w_scale = scales[1:3]
h_w_shape = input_tensor_shape[1:3]
new_size = tf.cast(h_w_scale * tf.cast(h_w_shape, scales.dtype), tf.int32)

if hasattr(new_size, 'set_shape'):
new_size.set_shape([2])

if hasattr(new_size, '_inferred_value'):
new_size_values = new_size._inferred_value
if new_size_values.count(None) == len(new_size_values):
tensor_rank = len(graph_node_output.shape)
convertion_table = [0] + [i for i in range(2, tensor_rank)] + [1]
new_values = [0] * tensor_rank
for new_idx, idx in enumerate(convertion_table):
new_values[new_idx] = graph_node_output.shape[idx]
new_size = new_values[-3:-1]

resized_tensor = None
tf_op_type = None
if mode.lower() == "bilinear" or mode.lower() == "linear":
mode = tf.image.ResizeMethod.BILINEAR
else:
mode = tf.image.ResizeMethod.NEAREST_NEIGHBOR

resized_tensor = tf.image.resize(
images=input_tensor,
size=new_size,
method=mode,
name=graph_node.name,
)
tf_op_type = tf.image.resize

tf_layers_dict[graph_node_output.name]['tf_node'] = resized_tensor

# Generation of Debug Info
tf_layers_dict[graph_node_output.name]['tf_node_info'] = \
make_tf_node_info(
node_info={
'tf_op_type': tf_op_type,
'tf_inputs': {
'images': input_tensor,
'new_size/crop_size': new_size,
'method': mode,
},
'tf_outputs': {
'output': tf_layers_dict[graph_node_output.name]['tf_node'],
},
}
)

0 comments on commit ebece01

Please sign in to comment.