/
conv2d_transpose.py
104 lines (89 loc) · 4.01 KB
/
conv2d_transpose.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-variable, unused-argument
"""Transposed 2D convolution operators (sometimes called Deconvolution)."""
from __future__ import absolute_import as _abs
import tvm
from .dilate import dilate
from .pad import pad
from .util import get_pad_tuple
from ..util import simplify
@tvm.target.generic_func
def conv2d_transpose_nchw(Input, Filter, strides, padding, out_dtype):
    """Transposed 2D convolution nchw forward operator.

    This body is the default (generic) implementation; because the function
    is wrapped by ``tvm.target.generic_func``, target-specific versions may be
    registered to replace it. The default simply forwards every argument to
    ``declaration_conv2d_transpose_impl``.

    Parameters
    ----------
    Input : tvm.Tensor
        4-D with shape [batch, in_channel, in_height, in_width]

    Filter : tvm.Tensor
        4-D with shape [in_channel, num_filter, filter_height, filter_width]

    strides : tuple of two ints
        The spatial stride along height and width

    padding : int or str
        Padding size, or ['VALID', 'SAME']

    out_dtype : str
        The output data type. This is used for mixed precision.

    Returns
    -------
    Output : tvm.Tensor
        4-D with shape [batch, out_channel, out_height, out_width]
    """
    return declaration_conv2d_transpose_impl(
        data=Input,
        kernel=Filter,
        strides=strides,
        padding=padding,
        out_dtype=out_dtype,
    )
def conv2d_transpose_nchw_preprocess(data, kernel, strides, padding, out_dtype):
    """Preprocess data and kernel so conv2d_transpose follows the conv2d pattern.

    The input's spatial axes are dilated by the strides and then padded by
    ``filter - 1 - forward_pad`` on each side. The kernel is re-laid-out from
    IOHW to OIHW and rotated by 180 degrees. A plain stride-1 convolution of
    the two returned tensors then matches the transposed convolution.

    Parameters
    ----------
    data : tvm.Tensor
        4-D input, unpacked as [batch, in_channel, in_height, in_width].
    kernel : tvm.Tensor
        4-D filter, unpacked as [in_channel, out_channel, filter_height,
        filter_width].
    strides : tuple of two ints
        Stride along height and width, used as the dilation factors.
    padding : int or str
        Forward-convolution padding, interpreted by ``get_pad_tuple``.
    out_dtype : str
        Not used by this function.

    Returns
    -------
    data_pad : tvm.Tensor
        The dilated and padded input.
    kernel_transform : tvm.Tensor
        The kernel in OIHW layout with both spatial axes flipped.
    """
    _, in_c, _, _ = data.shape
    _, out_c, filter_h, filter_w = kernel.shape
    stride_h, stride_w = strides

    # Dilate only the two spatial axes; batch and channel factors stay 1.
    data_dilate = dilate(data, [1, 1, stride_h, stride_w], name='data_dilate')

    # Convert the forward padding into the padding needed by the equivalent
    # conv2d: each side receives (filter_extent - 1 - forward_pad).
    # NOTE(review): if the forward padding exceeds filter - 1 this goes
    # negative; confirm callers never pass such padding.
    pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(
        padding, (filter_h, filter_w))
    pad_before = [0, 0, filter_h - 1 - pad_top, filter_w - 1 - pad_left]
    pad_after = [0, 0, filter_h - 1 - pad_bottom, filter_w - 1 - pad_right]
    data_pad = pad(data_dilate, pad_before, pad_after, name='data_pad')

    # Swap the first two kernel axes (IOHW -> OIHW) and read both spatial axes
    # back to front, i.e. rotate each 2-D filter by 180 degrees.
    def _flipped_oihw(o, i, h, w):
        return kernel[i, o, filter_h - 1 - h, filter_w - 1 - w]

    kernel_transform = tvm.compute((out_c, in_c, filter_h, filter_w),
                                   _flipped_oihw,
                                   name='kernel_transform')
    return data_pad, kernel_transform
def declaration_conv2d_transpose_impl(data, kernel, strides, padding, out_dtype):
    """Implementation of conv2d transpose.

    The transposed convolution is expressed as a stride-1, unpadded conv2d
    over the dilated-and-padded input and the flipped OIHW kernel produced by
    ``conv2d_transpose_nchw_preprocess``.

    Parameters
    ----------
    data : tvm.Tensor
        4-D input, [batch, in_channel, in_height, in_width].
    kernel : tvm.Tensor
        4-D filter, [in_channel, num_filter, filter_height, filter_width].
    strides : tuple of two ints
        Spatial stride along height and width. It is consumed entirely by the
        preprocessing step, which dilates the input by it.
    padding : int or str
        Forward-convolution padding, forwarded to the preprocessing step.
    out_dtype : str
        Both operands are cast to this dtype before the multiply-accumulate,
        so it is also the dtype of the summation and of the result.

    Returns
    -------
    Output : tvm.Tensor
        4-D tensor [batch, num_filter, out_height, out_width] with tag
        ``"conv2d_transpose_nchw"``.
    """
    data_pad, kernel_transform = conv2d_transpose_nchw_preprocess(
        data, kernel, strides, padding, out_dtype)

    batch, in_c, in_h, in_w = data_pad.shape
    out_c, _, filter_h, filter_w = kernel_transform.shape

    # The stride was already realised by dilating the input, so this stage is
    # a stride-1 "valid" convolution: only the kernel extent shrinks the output.
    out_c = simplify(out_c)
    out_h = simplify(in_h - filter_h + 1)
    out_w = simplify(in_w - filter_w + 1)

    dc = tvm.reduce_axis((0, in_c), name='dc')
    dh = tvm.reduce_axis((0, filter_h), name='dh')
    dw = tvm.reduce_axis((0, filter_w), name='dw')

    Output = tvm.compute(
        (batch, out_c, out_h, out_w),
        lambda b, c, h, w: tvm.sum(
            data_pad[b, dc, h + dh, w + dw].astype(out_dtype) *
            kernel_transform[c, dc, dh, dw].astype(out_dtype),
            axis=[dc, dh, dw]),
        tag="conv2d_transpose_nchw")
    return Output