Skip to content

Commit

Permalink
Add max_unpool operation to backend
Browse files Browse the repository at this point in the history
MaxUnpool is not supported by tensorflow by default. Refer to
https://github.com/tensorflow/tensorflow/issues/2169 for more information.
The current solution uses proposed code from the above issue with modifications
to support for padding and strides.
  • Loading branch information
sdmonov committed Jul 5, 2019
1 parent afef14f commit 7d1efca
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 1 deletion.
11 changes: 11 additions & 0 deletions onnx_tf/handlers/backend/max_unpool.py
@@ -0,0 +1,11 @@
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from .unpool_mixin import UnpoolMixin


@onnx_op("MaxUnpool")
class MaxUnpool(UnpoolMixin, BackendHandler):

@classmethod
def version_9(cls, node, **kwargs):
return cls.max_unpool(node, kwargs["tensor_dict"])
76 changes: 76 additions & 0 deletions onnx_tf/handlers/backend/unpool_mixin.py
@@ -0,0 +1,76 @@
import tensorflow as tf

from onnx_tf.common import get_data_format
from onnx_tf.common import get_perm_from_formats

class UnpoolMixin(object):

@classmethod
def max_unpool(cls, node, input_dict):
x = input_dict[node.inputs[0]]
ind = input_dict[node.inputs[1]]

x_rank = len(x.get_shape())
storage_format, compute_format = get_data_format(x_rank)
spatial_size = x_rank - 2

kernel_shape = node.attrs["kernel_shape"]
# if strides are not provided default is same as the kernel
strides = node.attrs.get("strides", kernel_shape)
pads = node.attrs.get("pads", [0] * spatial_size)
output_shape = node.attrs.get("output_shape", None)

input_shape = x.get_shape()
# if output_shape is not provided, calculate it
if output_shape is None:
output_shape = []
for d in range(len(kernel_shape)):
output_shape.append((int(input_shape[d + 2]) - 1) * int(strides[d]) +
int(kernel_shape[d]) - 2 * int(pads[d]))
output_shape = [int(input_shape[0])] + output_shape + [int(input_shape[1])]

need_trans = storage_format != "NHWC"
if need_trans:
x = tf.transpose(x, perm=get_perm_from_formats(storage_format, "NHWC"))
ind = tf.transpose(ind, perm=get_perm_from_formats(storage_format, "NHWC"))

unpooled = cls.unpool(x, ind, output_shape)

if need_trans:
unpooled = tf.transpose(
unpooled, perm=get_perm_from_formats("NHWC", storage_format))

return [unpooled]

@classmethod
def unpool(cls, pool, ind, output_shape, scope='unpool'):
"""
Unpooling layer after max_pool_with_argmax.
Args:
pool: max pooled output tensor
ind: argmax indices
output_shape: the shape of the output
Return:
unpool: unpooling tensor
"""
with tf.variable_scope(scope):
input_shape = tf.shape(pool)

flat_input_size = tf.reduce_prod(input_shape)
flat_output_shape = [output_shape[0], output_shape[1] * output_shape[2] * output_shape[3]]

pool_ = tf.reshape(pool, [flat_input_size])
batch_range = tf.reshape(tf.range(tf.cast(output_shape[0], tf.int64), dtype=ind.dtype),
shape=[input_shape[0], 1, 1, 1])
b = tf.ones_like(ind) * batch_range
b1 = tf.reshape(b, [flat_input_size, 1])
ind_ = tf.reshape(ind, [flat_input_size, 1])
ind_ = tf.concat([b1, ind_], 1)

ret = tf.scatter_nd(ind_, pool_, shape=tf.cast(flat_output_shape, tf.int64))
ret = tf.reshape(ret, output_shape)

set_input_shape = pool.get_shape()
ret.set_shape(output_shape)
return ret
2 changes: 1 addition & 1 deletion onnx_tf/opset_version.py
Expand Up @@ -77,7 +77,7 @@
'Max': [1, 6, 8],
'MaxPool': [1, 8],
'MaxRoiPool': [],
'MaxUnpool': [],
'MaxUnpool': [9],
'Mean': [1, 6, 8],
'MeanVarianceNormalization': [1],
'Min': [1, 6, 8],
Expand Down
63 changes: 63 additions & 0 deletions test/backend/test_node.py
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import tensorflow as tf
from onnx_tf.backend import run_node
from onnx_tf.backend import prepare
from onnx_tf.common import supports_device
from onnx_tf.common.legacy import legacy_onnx_pre_ver, legacy_opset_pre_ver
from onnx import helper
Expand Down Expand Up @@ -718,6 +719,68 @@ def test_max_pool(self):
max(x[i1][i2][j1][2*j2], x[i1][i2][j1][2*j2 + 1])
np.testing.assert_almost_equal(output["Y"], test_output)

def test_max_unpool(self):
input_shape = [10,10,4,4]
x = self._get_rnd(input_shape)

""" Maxpool op version 10 is not implemented yet and that is why we use a workaround
to force the onnx ops set version to 9 by using a model instead of just running nodes"""
X = helper.make_tensor_value_info('X', TensorProto.FLOAT, input_shape)
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, input_shape)

maxpool_node_def = helper.make_node(
"MaxPool", ["X"], ["Pool", "Indices"],
kernel_shape=[2, 2],
strides=[2, 2])

maxunpool_node_def = helper.make_node(
"MaxUnpool", ["Pool", "Indices"], ["Y"],
kernel_shape=[2, 2],
strides=[2, 2])

graph_def = helper.make_graph(
[maxpool_node_def,maxunpool_node_def],
"MaxUnpool-model",
[X],
[Y],
)
version = helper.make_operatorsetid("",9)
model_def = helper.make_model(graph_def,
opset_imports=[version])
tf_rep = prepare(model_def) # run the loaded model
output_unpool = tf_rep.run(x)

""" This code is simpler way to test maxunpool but fails because
maxpool op version 10 is not supported yet
node_def = helper.make_node(
"MaxPool", ["X"], ["Pool", "Indices"],
kernel_shape=[2, 2],
strides=[2, 2])
output_pool = run_node(node_def, [x])
node_def = helper.make_node(
"MaxUnpool", ["Pool", "Indices"], ["Y"],
kernel_shape=[2, 2],
strides=[2, 2])
output_unpool = run_node(node_def, [output_pool["Pool"], output_pool["Indices"]])
"""

test_output = np.zeros(input_shape)
for i1 in range(0, input_shape[0]):
for i2 in range(0, input_shape[1]):
for i3 in range(0, input_shape[2], 2):
for i4 in range(0, input_shape[3], 2):
max_val = float('-inf')
for j1 in range(i3,i3+2):
for j2 in range(i4,i4+2):
if x[i1][i2][j1][j2] > max_val:
max_val = x[i1][i2][j1][j2]
max_ind = (j1, j2)
j1, j2 = max_ind
test_output[i1][i2][j1][j2] = max_val
np.testing.assert_almost_equal(output_unpool["Y"], test_output)

def test_min(self):
node_def = helper.make_node("Min", ["X1", "X2", "X3", "X4"], ["Z"])
x1 = self._get_rnd([10, 10])
Expand Down

0 comments on commit 7d1efca

Please sign in to comment.