Skip to content

Commit

Permalink
[FRONTEND][TF] conv2d_transpose 'SAME' support kernel more than 1x1 (#…
Browse files Browse the repository at this point in the history
…4484)

* [FRONTEND][TF] conv3d_transpose 'SAME' support kernel more than 1x1

* revised per as review comments

* add more fallback wolkaround to make all tests pass
  • Loading branch information
optima2005 authored and tqchen committed Dec 28, 2019
1 parent 3227614 commit 227c7af
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 21 deletions.
47 changes: 30 additions & 17 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,12 @@ def _impl(inputs, attr, params):
attr['strides'][1], attr['strides'][2], attr['strides'][3] = \
attr['strides'][3], attr['strides'][1], attr['strides'][2]
attr['data_format'] = 'NCHW'

if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
tmp_shape = attr['_output_shapes'][0]
tmp_shape = [tmp_shape[ii] for ii in (0, 3, 1, 2)]
attr['_output_shapes'][0] = tmp_shape

flip_layout = True

inputs_data = inputs[0] if opname != 'conv_transpose' else inputs[2]
Expand Down Expand Up @@ -345,12 +351,17 @@ def _impl(inputs, attr, params):
elif attr['padding'] == 'SAME':
stride_h, stride_w = attr['strides']
kernel_h, kernel_w = attr['kernel_shape']

pdata_shape = input_shape
if opname == 'conv_transpose' and len(attr['_output_shapes']) > 0:
pdata_shape = attr['_output_shapes'][0]

if attr['data_format'] == 'NHWC':
in_h = input_shape[1]
in_w = input_shape[2]
in_h = pdata_shape[1]
in_w = pdata_shape[2]
else:
in_h = input_shape[2]
in_w = input_shape[3]
in_h = pdata_shape[2]
in_w = pdata_shape[3]

dilation_h = attr['dilations'][0]
dilation_w = attr['dilations'][1]
Expand All @@ -359,21 +370,23 @@ def _impl(inputs, attr, params):
pad_v = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
pad_h = _get_pad_pair(in_w, dilated_kernel_w, stride_w)

if opname != 'conv_transpose':
if attr['data_format'] == 'NHWC':
inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
else:
inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))

if attr['data_format'] == 'NHWC':
inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1]),
(0, 0)))
attr['padding'] = [0, 0]
else:
inputs_data = _op.nn.pad(data=inputs_data,
pad_width=((0, 0),
(0, 0),
(pad_v[0], pad_v[1]),
(pad_h[0], pad_h[1])))

attr['padding'] = [0, 0]
attr['padding'] = [pad_v[0], pad_h[0], pad_v[1], pad_h[1]]

else:
msg = 'Value {} in attribute "padding" of operator Conv is not ' \
Expand Down
16 changes: 14 additions & 2 deletions src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,22 @@ bool Conv2DTransposeRel(const Array<Type>& types,
}
// dilation
Array<IndexExpr> oshape({dshape_nchw[0], channels, 0, 0});
auto pad_h = param->padding[0];
auto pad_w = param->padding[1];
if (param->padding.size() == 2) {
pad_h *= 2;
pad_w *= 2;
} else if (param->padding.size() == 4) {
pad_h += param->padding[2];
pad_w += param->padding[3];
} else {
CHECK_EQ(param->padding.size(), 4) << " Padding should be 2 or 4, but got "
<< param->padding.size();
}
oshape.Set(2, (param->strides[0] * (dshape_nchw[2] - 1) + dilated_ksize_y -
2 * param->padding[0] + param->output_padding[0]));
pad_h + param->output_padding[0]));
oshape.Set(3, (param->strides[1] * (dshape_nchw[3] - 1) + dilated_ksize_x -
2 * param->padding[1] + param->output_padding[1]));
pad_w + param->output_padding[1]));

DataType out_dtype = param->out_dtype;
if (out_dtype.bits() == 0) {
Expand Down
24 changes: 24 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,10 +403,22 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 12, 17, 17], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NCHW')
_test_convolution('conv_transpose', [4, 32, 8, 8], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 176, 8, 8])
_test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 176, 8, 8])
_test_convolution('conv_transpose', [4, 32, 8, 8], [2, 2, 176, 32], [1, 1], [2, 2], 'SAME',
'NCHW', [4, 176, 15, 15])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 176, 8, 8])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
'NCHW', [4, 176, 15, 15])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
'NCHW', [4, 176, 16, 16])
_test_convolution('conv_transpose', [4, 19, 8, 8], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 19, 17, 17])
_test_convolution('conv_transpose', [4, 19, 17, 17], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 124, 17, 17])
_test_convolution('conv_transpose', [4, 19, 17, 17], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME',
'NCHW', [4, 124, 17, 17])
_test_convolution('conv_transpose', [4, 32, 8, 8], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
'NCHW', [4, 12, 17, 17])
# kernel 2x2, strides (2,2)
Expand All @@ -429,10 +441,22 @@ def test_forward_convolution():
_test_convolution('depthwise', [4, 17, 17, 12], [3, 3, 12, 2], [1, 1], [2, 2], 'VALID', 'NHWC')
_test_convolution('conv_transpose', [4, 8, 8, 32], [1, 1, 176, 32], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 8, 8, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 8, 8, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [2, 2, 176, 32], [1, 1], [2, 2], 'SAME',
'NHWC', [4, 15, 15, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 8, 8, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
'NHWC', [4, 15, 15, 176])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 176, 32], [1, 1], [2, 2], 'SAME',
'NHWC', [4, 16, 16, 176])
_test_convolution('conv_transpose', [4, 8, 8, 19], [3, 3, 19, 19], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 17, 17, 19])
_test_convolution('conv_transpose', [4, 17, 17, 19], [1, 1, 124, 19], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 17, 17, 124])
_test_convolution('conv_transpose', [4, 17, 17, 19], [3, 3, 124, 19], [1, 1], [1, 1], 'SAME',
'NHWC', [4, 17, 17, 124])
_test_convolution('conv_transpose', [4, 8, 8, 32], [3, 3, 12, 32], [1, 1], [2, 2], 'VALID',
'NHWC', [4, 17, 17, 12])
# kernel 2x2, strides (2,2)
Expand Down
2 changes: 2 additions & 0 deletions topi/python/topi/cuda/conv2d_transpose_nchw.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,8 @@ def _callback(op):
do_fallback = False
elif (kh, kw) == (1, 1):
do_fallback = True
elif (stride_h, stride_w) == (2, 2):
do_fallback = False
elif (kh, kw) == (stride_h, stride_w):
do_fallback = False

Expand Down
9 changes: 7 additions & 2 deletions topi/python/topi/nn/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,13 @@ def get_pad_tuple(padding, kernel):
"""
# compute the padding size
if isinstance(padding, (tuple, list)):
pad_h = padding[0] * 2
pad_w = padding[1] * 2
if len(padding) == 2:
pad_h = padding[0] * 2
pad_w = padding[1] * 2
elif len(padding) == 4:
return padding[0], padding[1], padding[2], padding[3]
else:
raise ValueError("Size of padding can only be 2 or 4")
elif isinstance(padding, int):
pad_h = pad_w = padding * 2
elif padding == "VALID":
Expand Down

0 comments on commit 227c7af

Please sign in to comment.