[FRONTEND][TF] conv2d_transpose 'SAME' support kernel more than 1x1 #4484

Merged: 3 commits, Dec 28, 2019
Changes from 2 commits
47 changes: 30 additions & 17 deletions python/tvm/relay/frontend/tensorflow.py
@@ -205,6 +205,12 @@ def _impl(inputs, attr, params):
             attr['strides'][1], attr['strides'][2], attr['strides'][3] = \
                 attr['strides'][3], attr['strides'][1], attr['strides'][2]
             attr['data_format'] = 'NCHW'
+
+            if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
+                tmp_shape = attr['_output_shapes'][0]
+                tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
+                attr['_output_shapes'][0] = tmp_shape
+
             flip_layout = True

         inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]
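An aside on the (0, 3, 1, 2) shuffle in the hunk above: it reorders a shape from NHWC to NCHW, so the recorded output shape matches the layout flip applied to the data. A minimal check in plain Python (example shape taken from the new tests):

    # NHWC -> NCHW: move the channel dimension from last to second place.
    nhwc = [4, 8, 8, 32]                      # [batch, height, width, channels]
    nchw = [nhwc[ii] for ii in (0, 3, 1, 2)]  # same shuffle as in the patch
    assert nchw == [4, 32, 8, 8]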
@@ -281,12 +287,17 @@ def _impl(inputs, attr, params):
     elif attr['padding'] == 'SAME':
         stride_h, stride_w = attr['strides']
         kernel_h, kernel_w = attr['kernel_shape']
+
+        pdata_shape = input_shape
+        if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
+            pdata_shape = attr['_output_shapes'][0]
+
         if attr['data_format'] == 'NHWC':
-            in_h = input_shape[1]
-            in_w = input_shape[2]
+            in_h = pdata_shape[1]
+            in_w = pdata_shape[2]
         else:
-            in_h = input_shape[2]
-            in_w = input_shape[3]
+            in_h = pdata_shape[2]
+            in_w = pdata_shape[3]

         dilation_h = attr['dilations'][0]
         dilation_w = attr['dilations'][1]
@@ -295,21 +306,23 @@ def _impl(inputs, attr, params):
         pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
         pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)

-        if attr['data_format'] == 'NHWC':
-            inputs_data = _op.nn.pad(data=inputs_data,
-                                     pad_width=((0, 0),
-                                                (pad_v[0], pad_v[1]),
-                                                (pad_h[0], pad_h[1]),
-                                                (0, 0)))
-        else:
-            inputs_data = _op.nn.pad(data=inputs_data,
-                                     pad_width=((0, 0),
-                                                (0, 0),
-                                                (pad_v[0], pad_v[1]),
-                                                (pad_h[0], pad_h[1])))
-
-        attr['padding'] = [0, 0]
+        if opname != 'conv_transpose':
+            if attr['data_format'] == 'NHWC':
+                inputs_data = _op.nn.pad(data=inputs_data,
+                                         pad_width=((0, 0),
+                                                    (pad_v[0], pad_v[1]),
+                                                    (pad_h[0], pad_h[1]),
+                                                    (0, 0)))
+            else:
+                inputs_data = _op.nn.pad(data=inputs_data,
+                                         pad_width=((0, 0),
+                                                    (0, 0),
+                                                    (pad_v[0], pad_v[1]),
+                                                    (pad_h[0], pad_h[1])))
+
+            attr['padding'] = [0, 0]
+        else:
+            attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]

     else:
         msg = 'Value {} in attribute "padding" of operator Conv is not ' \
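For context, `_get_pad_pair` computes TensorFlow-style 'SAME' padding for one spatial dimension; with this patch it is fed the requested output size for conv_transpose. An illustrative re-implementation of the same arithmetic (a sketch, not the TVM source):

    def same_pad_pair(in_size, kernel, stride):
        # TensorFlow 'SAME' rule: total padding that lets ceil(in_size / stride)
        # windows of width `kernel` cover the input, with the odd pixel going
        # to the bottom/right side.
        if in_size % stride == 0:
            total = max(kernel - stride, 0)
        else:
            total = max(kernel - (in_size % stride), 0)
        return [total // 2, total - total // 2]

    # kernel 3, stride 1: one pixel of padding on each side, so an 8-wide
    # input stays 8 wide -- the case the new conv2d_transpose tests exercise.
    assert same_pad_pair(8, 3, 1) == [1, 1]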
16 changes: 14 additions & 2 deletions src/relay/op/nn/convolution.cc
@@ -248,10 +248,22 @@ bool Conv2DTransposeRel(const Array<Type>& types,
   }
   // dilation
   Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
+  auto pad_h = param->padding[0];
+  auto pad_w = param->padding[1];
+  if (param->padding.size() == 2) {
+    pad_h *= 2;
+    pad_w *= 2;
+  } else if (param->padding.size() == 4) {
+    pad_h += param->padding[2];
+    pad_w += param->padding[3];
+  } else {
+    CHECK_EQ(param->padding.size(), 4) << " Padding should be 2 or 4, but got "
+                                       << param->padding.size();
+  }
   oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
-                 2 * param->padding[0] + param->output_padding[0]));
+                 pad_h + param->output_padding[0]));
   oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
-                 2 * param->padding[1] + param->output_padding[1]));
+                 pad_w + param->output_padding[1]));

   DataType out_dtype = param->out_dtype;
   if (out_dtype.bits() == 0) {
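A Python transcription of the output-extent arithmetic this hunk implements, for sanity-checking against the new test expectations (the helper name is mine, not TVM API):

    def transpose_out_dim(in_dim, stride, dilated_ksize, padding, out_padding=0, axis=0):
        # padding is either [pad_h, pad_w] (symmetric) or
        # [top, left, bottom, right] (explicit), matching the frontend change.
        if len(padding) == 2:
            pad = 2 * padding[axis]
        elif len(padding) == 4:
            pad = padding[axis] + padding[axis + 2]
        else:
            raise ValueError("padding should have 2 or 4 elements")
        return stride * (in_dim - 1) + dilated_ksize - pad + out_padding

    # kernel 3x3, stride 1, SAME on an 8x8 input: pad (1, 1) keeps height at 8,
    # matching the new [4, 176, 8, 8] test expectations.
    assert transpose_out_dim(8, 1, 3, [1, 1, 1, 1]) == 8
    # kernel 3x3, stride 2: the dilated 8x8 input spans 15 pixels, hence
    # the [4, 176, 15, 15] expectations.
    assert transpose_out_dim(8, 2, 3, [1, 1, 1, 1]) == 15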
26 changes: 26 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
@@ -362,10 +362,23 @@ def test_forward_convolution():
     _test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW')
     _test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
                       'NCHW', [4, 176, 8, 8])
+    _test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME',
+                      'NCHW', [4, 176, 8, 8])
+    _test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [2, 2], 'SAME',
+                      'NCHW', [4, 176, 15, 15])
+    _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME',
+                      'NCHW', [4, 176, 8, 8])
+    _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
+                      'NCHW', [4, 176, 15, 15])
+    # cuda target not working
Member:

What is the reason it doesn't work for cuda?

Contributor Author:

The case failed due to the "Cannot prove" error when compiling the module. See #4470. It seems there is no fix for it yet.

+    #_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
+    #                  'NCHW', [4, 176, 16, 16])
     _test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
                       'NCHW', [4, 19, 17, 17])
     _test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
                       'NCHW', [4, 124, 17, 17])
+    _test_convolution('conv_transpose', [4, 19, 17, 17], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME',
Contributor:

Can you also add test cases for:

kernel 2x2, strides 2x2, SAME
kernel 3x3, strides 2x2, SAME

E.g. if the input is 5x5, then valid outputs are 9x9 or 10x10 (you can use either one in the output_shape tensor) regardless of the kernel size.

Contributor Author:

The strides in transpose convolution are used to dilate the input (see this). If the input is dilated, there would be no way to get a 'SAME'-size output by padding alone.

Contributor Author:

OK, I get your point. 'SAME' refers to the size of the input after it has been enlarged by dilation.
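To make the two valid output sizes concrete: TF accepts any output_shape whose forward 'SAME' convolution maps back to the input size. A quick sketch (the helper name is mine, not TF API):

    import math

    def valid_same_output_sizes(in_size, stride):
        # TF accepts any out with ceil(out / stride) == in_size,
        # independent of the kernel size.
        lo = (in_size - 1) * stride + 1
        hi = in_size * stride
        assert all(math.ceil(o / stride) == in_size for o in range(lo, hi + 1))
        return list(range(lo, hi + 1))

    print(valid_same_output_sizes(5, 2))  # [9, 10] -- the reviewer's example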

Contributor:

TF code for kernel 2x2, strides 2x2 and padding SAME is (imports and random data/weight filled in here so the snippet runs; they were not defined in the original comment):

import numpy as np
import tensorflow as tf

dshape = (1, 5, 5, 4)
# hwoi
kshape = (2, 2, 2, 4)
oshape = (1, 9, 9, 2)
# or
# oshape = (1, 10, 10, 2)
dtype = 'float32'
data = np.random.uniform(size=dshape).astype(dtype)    # added for completeness
weight = np.random.uniform(size=kshape).astype(dtype)  # added for completeness
with tf.Session() as sess:
    x = tf.placeholder(shape=dshape, dtype=dtype)
    w = tf.placeholder(shape=kshape, dtype=dtype)
    dc = tf.nn.conv2d_transpose(x, w, output_shape=oshape,
                                strides=(1, 2, 2, 1), padding='SAME')

    res_dc = sess.run(dc, feed_dict={x: data, w: weight})

+                      'NCHW', [4, 124, 17, 17])
     _test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
                       'NCHW', [4, 12, 17, 17])
     # kernel 2x2, strides (2,2)
@@ -388,10 +401,23 @@ def test_forward_convolution():
     _test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC')
     _test_convolution('conv_transpose', [4, 8, 8, 32], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
                       'NHWC', [4, 8, 8, 176])
+    _test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME',
+                      'NHWC', [4, 8, 8, 176])
+    _test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [2, 2], 'SAME',
+                      'NHWC', [4, 15, 15, 176])
+    _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME',
+                      'NHWC', [4, 8, 8, 176])
+    _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
+                      'NHWC', [4, 15, 15, 176])
+    # cuda target not working
+    #_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
Member:

How about this?

Contributor Author:

The case failed due to the "Cannot prove" error when compiling the module. See #4470. It seems there is no fix for it yet.

+    #                  'NHWC', [4, 16, 16, 176])
     _test_convolution('conv_transpose', [4, 8, 8, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
                       'NHWC', [4, 17, 17, 19])
     _test_convolution('conv_transpose', [4, 17, 17, 19], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
                       'NHWC', [4, 17, 17, 124])
+    _test_convolution('conv_transpose', [4, 17, 17, 19], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME',
+                      'NHWC', [4, 17, 17, 124])
     _test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
                       'NHWC', [4, 17, 17, 12])
     # kernel 2x2, strides (2,2)
9 changes: 7 additions & 2 deletions topi/python/topi/nn/util.py
@@ -103,8 +103,13 @@ def get_pad_tuple(padding, kernel):
     """
     # compute the padding size
     if isinstance(padding, (tuple, list)):
-        pad_h = padding[0] * 2
-        pad_w = padding[1] * 2
+        if len(padding) == 2:
+            pad_h = padding[0] * 2
+            pad_w = padding[1] * 2
+        elif len(padding) == 4:
+            return padding[0], padding[1], padding[2], padding[3]
+        else:
+            raise ValueError("Size of padding can only be 2 or 4")
     elif isinstance(padding, int):
         pad_h = pad_w = padding * 2
     elif padding == "VALID":
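A usage sketch of the two padding forms get_pad_tuple now accepts (assuming the repo's topi package is importable; values are illustrative):

    from topi.nn.util import get_pad_tuple

    # Symmetric form: [pad_h, pad_w] is doubled then split evenly, so [1, 2]
    # yields (pad_top, pad_left, pad_down, pad_right) = (1, 2, 1, 2).
    print(get_pad_tuple([1, 2], (3, 3)))
    # Explicit form: four values pass through unchanged -- this is what the
    # TF frontend now supplies for asymmetric 'SAME' conv2d_transpose padding.
    print(get_pad_tuple([1, 2, 3, 4], (3, 3)))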