conv2d_transpose_depthwise.dml
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------
/*
* 2D Depthwise Transpose Convolutional layer.
*
* Utilizes built-in convolution operators for higher performance.
*/
source("scripts/nn/util.dml") as util
forward = function(matrix[double] X, matrix[double] W, matrix[double] b,
int C, int Hin, int Win, int M, int Hf, int Wf,
int strideh, int stridew, int padh, int padw,
int out_padh, int out_padw)
return (matrix[double] out, int Hout, int Wout){
/*
* Computes the forward pass for a 2D depthwise spatial transpose
* convolutional layer with C/M filters of depth M. The input data
* has N examples, each represented as a 3D volume with C channels
* unrolled into a single vector. For each group of M input channels,
* a 2D transpose convolution is applied with 1 unique filter,
 * yielding 1 output channel per group of M input channels.
* The resulting C/M separate output channels are then concatenated
* together channel-wise into a single volume of C/M output channels.
*
* For clarity, if we were to use the same terminology as a regular
* depthwise convolution, a depthwise transpose convolution has the
* ability to contract each group of M input channels (from a total of
 * C*M input channels) back to a single output channel, yielding
 * C output channels. This is thus the "transpose" of the regular
 * depthwise convolution. To keep the convention of always referring
* to the number of input channels as C, in this depthwise transpose
* layer we can reformulate the above by dividing by M. With this
* reformulation, we can now state that there are C input channels,
* and for each group of M inputs we output a single output channel,
* for a total of C/M output channels. For this, we use 1 filter of
* depth M for each group of M input channels, and we store W as
* `(C/M, M*Hf*Wf)`.
*
* Inputs:
* - X: Inputs, of shape (N, C*Hin*Win).
* - W: Weights, of shape (C/M, M*Hf*Wf).
* - b: Biases, of shape (C/M, 1).
* - C: Number of input channels (dimensionality of depth).
* - Hin: Input height.
* - Win: Input width.
* - M: Depth of each filter (C must be divisible by M).
* - Hf: Filter height.
* - Wf: Filter width.
* - strideh: Stride over height.
* - stridew: Stride over width.
* - padh: Padding for top and bottom sides.
* - padw: Padding for left and right sides.
 * - out_padh: Extra padding for top side. This should
 * lie in [0, strideh-1].
 * - out_padw: Extra padding for right side. This should
 * lie in [0, stridew-1].
*
* Outputs:
* - out: Outputs, of shape (N, C/M*Hout*Wout).
* - Hout: Output height.
* - Wout: Output width.
*/
N = nrow(X)
F = nrow(W)
Hout = strideh*(Hin-1) - 2*padh + Hf + out_padh
Wout = stridew*(Win-1) - 2*padw + Wf + out_padw
# create output volume
# NOTE: We initialize to 1s vs. 0s to avoid conversions between sparse and dense formats.
# This is a complete hack until the engine is improved.
out = matrix(1, rows=N, cols=C/M*Hout*Wout)
# depthwise transpose convolution
# TODO: Explore usage of parfor loops more to determine if they can provide a performance
# benefit. Initial tests show that they are slower than the regular for loop, likely because
# they cause a reduction from a multithreaded conv2d op to a singlethreaded version. For a
# number of filters C/M >> the number of examples, it's possible that the parfor loop could be
# faster.
#parfor (f in 1:F, check=0) { # each channel
for (f in 1:F) {
# compute gradient wrt data of conv2d using 1 filter and M input channels
w = matrix(W[f,], rows=M, cols=Hf*Wf) # 1 filter, of shape (M, 1*Hf*Wf)
Xm = X[,((f-1)*M*Hin*Win + 1):f*M*Hin*Win] # M input channels, of shape (N, M*Hin*Win)
outm = conv2d_backward_data(w, Xm, stride=[strideh,stridew], padding=[padh,padw],
input_shape=[N,1,Hout,Wout], filter_shape=[M,1,Hf,Wf])
# store
out[,((f-1)*Hout*Wout + 1):f*Hout*Wout] = outm # outm has shape (N, 1*Hout*Wout)
}
# add bias term to each output filter
out = bias_add(out, b)
}
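The forward pass computes the transpose-convolution output size as `Hout = strideh*(Hin-1) - 2*padh + Hf + out_padh` (and analogously for `Wout`). As a quick sanity check on that arithmetic, the same formula can be mirrored in Python; the helper name below is hypothetical, but the computation matches the DML:

```python
def transpose_conv_output_shape(Hin, Win, Hf, Wf,
                                strideh, stridew, padh, padw,
                                out_padh=0, out_padw=0):
    """Output spatial dims of a 2D transpose convolution,
    mirroring the Hout/Wout formulas in forward()."""
    Hout = strideh * (Hin - 1) - 2 * padh + Hf + out_padh
    Wout = stridew * (Win - 1) - 2 * padw + Wf + out_padw
    return Hout, Wout

# Example: 2x upsampling with a 4x4 filter, stride 2, padding 1
print(transpose_conv_output_shape(8, 8, 4, 4, 2, 2, 1, 1))  # (16, 16)
```

With stride 2 and symmetric settings, the spatial dims roughly double, which is why transpose convolutions are a common upsampling layer.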
backward = function(matrix[double] dout, int Hout, int Wout,
matrix[double] X, matrix[double] W, matrix[double] b,
int C, int Hin, int Win, int M, int Hf, int Wf,
int strideh, int stridew, int padh, int padw)
return (matrix[double] dX, matrix[double] dW, matrix[double] db){
/*
 * Computes the backward pass for a 2D depthwise spatial transpose
 * convolutional layer with C/M filters of depth M.
*
* Inputs:
* - dout: Gradient wrt `out` from upstream, of
* shape (N, C/M*Hout*Wout).
* - Hout: Output height.
* - Wout: Output width.
* - X: Inputs, of shape (N, C*Hin*Win).
* - W: Weights, of shape (C/M, M*Hf*Wf).
* - b: Biases, of shape (C/M, 1).
* - C: Number of input channels (dimensionality of depth).
* - Hin: Input height.
* - Win: Input width.
* - M: Depth of each filter (C must be divisible by M).
* - Hf: Filter height.
* - Wf: Filter width.
* - strideh: Stride over height.
* - stridew: Stride over width.
* - padh: Padding for top and bottom sides.
* - padw: Padding for left and right sides.
*
* Outputs:
* - dX: Gradient wrt `X`, of shape (N, C*Hin*Win).
* - dW: Gradient wrt `W`, of shape (C/M, M*Hf*Wf).
* - db: Gradient wrt `b`, of shape (C/M, 1).
*/
N = nrow(X)
F = nrow(W)
# create gradient volumes
# NOTE: We initialize to 1s vs. 0s to avoid conversions between sparse and dense formats.
# This is a complete hack until the engine is improved.
dX = matrix(1, rows=N, cols=C*Hin*Win)
dW = matrix(1, rows=C/M, cols=M*Hf*Wf)
db = matrix(1, rows=C/M, cols=1)
# depthwise transpose convolution
for (f in 1:F) {
# extract 1 gradient channel, 1 depth-1 filter, and M input channels, since the forward pass
# maps M input channels to 1 output channel for each filter
doutf = dout[,((f-1)*Hout*Wout + 1):f*Hout*Wout] # shape (N, 1*Hout*Wout)
w = matrix(W[f,], rows=M, cols=Hf*Wf) # 1 filter, of shape (M, 1*Hf*Wf)
Xm = X[,((f-1)*M*Hin*Win + 1):f*M*Hin*Win] # M input channels, of shape (N, M*Hin*Win)
# compute gradients:
# conv2d_backward_filter takes the input and gradient wrt the output
# as first and second args, respectively. Given that we need to
 # compute the grad wrt the filter for transpose convolution, where
# the roles of the input and output are reversed, we reverse the
# order of the args (along with setting input_shape to the dout
# shape).
dw = conv2d_backward_filter(doutf, Xm, stride=[strideh,stridew], padding=[padh,padw],
input_shape=[N,1,Hout,Wout], filter_shape=[M,1,Hf,Wf])
# Since the forward for transpose convolution makes a call to
 # conv2d_backward_data, to compute its derivative wrt data
# we can run conv2d by applying the filter on the grad wrt the
# output (this makes sense because convolution transpose is the
# 'reverse' of convolution). It's easy to see that this will produce
# an output of the required size.
dXm = conv2d(doutf, w, input_shape=[N,1,Hout,Wout], filter_shape=[M,1,Hf,Wf],
stride=[strideh,stridew], padding=[padh,padw])
# store
dX[,((f-1)*M*Hin*Win + 1):f*M*Hin*Win] = dXm
dW[f,] = matrix(dw, rows=1, cols=M*Hf*Wf)
}
# partial derivatives for bias vector
db = util::channel_sums(dout, C/M, Hout, Wout)
}
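The comment in backward() asserts that running conv2d on the upstream gradient "will produce an output of the required size". This follows from the standard conv2d output-size formula being the inverse of the transpose-conv formula used in forward() (taking out_pad = 0). A minimal Python sketch of that shape argument, with `conv_output_dim` as a hypothetical helper:

```python
def conv_output_dim(H, Hf, stride, pad):
    # Standard conv2d output size (integer division, as in most frameworks)
    return (H + 2 * pad - Hf) // stride + 1

# Transpose-conv forward grows Hin to Hout...
Hin, Hf, stride, pad = 8, 4, 2, 1
Hout = stride * (Hin - 1) - 2 * pad + Hf   # = 16

# ...and conv2d over the upstream gradient (same filter/stride/pad)
# shrinks Hout back to Hin, so dXm has the input's spatial dims.
assert conv_output_dim(Hout, Hf, stride, pad) == Hin
```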
init = function(int C, int M, int Hf, int Wf)
return (matrix[double] W, matrix[double] b){
/*
* Utility function to initialize the parameters of this layer.
*
* We use the heuristic by He et al., which limits the magnification
* of inputs/gradients during forward/backward passes by scaling
* unit-Gaussian weights by a factor of sqrt(2/n), under the
* assumption of relu neurons.
* - http://arxiv.org/abs/1502.01852
*
* Inputs:
* - C: Number of input channels (dimensionality of depth).
* - M: Depth of each filter (C must be divisible by M).
* - Hf: Filter height.
* - Wf: Filter width.
*
* Outputs:
* - W: Weights, of shape (C/M, M*Hf*Wf).
* - b: Biases, of shape (C/M, 1).
*/
W = rand(rows=C/M, cols=M*Hf*Wf, pdf="normal") * sqrt(2/(M*Hf*Wf))
b = matrix(0, rows=C/M, cols=1)
}
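The init() function draws unit-Gaussian weights and scales them by sqrt(2/(M*Hf*Wf)), i.e. He initialization with fan-in equal to the per-filter size M*Hf*Wf. A NumPy sketch of the same scheme (`he_init` is a hypothetical name; the shapes mirror the DML `W` and `b`):

```python
import numpy as np

def he_init(C, M, Hf, Wf, seed=0):
    """He et al. initialization mirroring init(): unit-Gaussian
    weights scaled by sqrt(2/fan_in), with fan_in = M*Hf*Wf."""
    rng = np.random.default_rng(seed)
    scale = np.sqrt(2.0 / (M * Hf * Wf))
    W = rng.standard_normal((C // M, M * Hf * Wf)) * scale
    b = np.zeros((C // M, 1))  # biases start at zero, as in init()
    return W, b

W, b = he_init(C=32, M=4, Hf=3, Wf=3)
print(W.shape, b.shape)  # (8, 36) (8, 1)
```

The sqrt(2/n) factor keeps the variance of activations roughly constant across layers under the ReLU assumption stated in the docstring.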