-
Notifications
You must be signed in to change notification settings - Fork 145
/
Copy pathcommon.lua
253 lines (223 loc) · 7.65 KB
/
common.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
require 'nn'
--------------------------------------------------------------------------
-- Below is a list of common modules used in various architectures.
-- These are defined as global variables to keep the code that uses them uncluttered.
--------------------------------------------------------------------------
-- Containers
seq = nn.Sequential
concat = nn.ConcatTable
id = nn.Identity
-- Convolution, resampling, padding and normalisation layers
conv = nn.SpatialConvolution
deconv = nn.SpatialFullConvolution
shuffle = nn.PixelShuffle
avgpool = nn.SpatialAveragePooling
pad = nn.Padding
bnorm = nn.SpatialBatchNormalization
-- Activation constructors (selected through act())
relu = nn.ReLU
prelu = nn.PReLU
rrelu = nn.RReLU
elu = nn.ELU
leakyrelu = nn.LeakyReLU
-- Table / element-wise arithmetic
cadd = nn.CAddTable
join = nn.JoinTable
mulc = nn.MulConstant
--- Build the activation module selected by `actParams.actType`.
-- Supported types: 'relu', 'prelu', 'rrelu', 'elu', 'leakyrelu'. All but
-- 'prelu' are constructed with `true` as their in-place flag.
-- @tparam table actParams activation options: `actType` (required), `nFeat`
--   (PReLU plane count), `l`/`u` (RReLU bounds), `alpha` (ELU), `negval`
--   (LeakyReLU negative slope)
-- @tparam[opt] number nOutputPlane PReLU plane count; when given it takes
--   precedence over `actParams.nFeat`
-- @return an nn activation module
-- @raise error naming the value when `actType` is not recognised
function act(actParams, nOutputPlane)
   -- The other builders in this file overwrite actParams.nFeat as a side
   -- effect, so it may be stale; an explicit argument must win over it.
   local planes = nOutputPlane or actParams.nFeat
   -- Named actType (not `type`) so the builtin type() is not shadowed.
   local actType = actParams.actType
   if actType == 'relu' then
      return relu(true)
   elseif actType == 'prelu' then
      return prelu(planes)
   elseif actType == 'rrelu' then
      return rrelu(actParams.l, actParams.u, true)
   elseif actType == 'elu' then
      return elu(actParams.alpha, true)
   elseif actType == 'leakyrelu' then
      return leakyrelu(actParams.negval, true)
   end
   -- Level 2 attributes the error to the caller that passed the bad options.
   error('unknown activation function: ' .. tostring(actType), 2)
end
--- Wrap `model` in a residual connection so the result computes model(x) + x.
-- The wrapped module and an identity shortcut run side by side in an
-- nn.ConcatTable, and their outputs are summed by an nn.CAddTable created
-- with `true` as its in-place flag.
-- @param model module forming the residual branch
-- @param[opt=false] global stored on the CAddTable as `.global`; it tags a
--   global skip vs. the local skip of a residual block (whatever reads this
--   tag is not visible here)
-- @return nn.Sequential { ConcatTable{ model, Identity }, CAddTable }
function addSkip(model, global)
   local wrapped = seq()
   local branches = concat()
   branches:add(model)
   branches:add(id())
   wrapped:add(branches)
   local merge = cadd(true)
   merge.global = global or false
   wrapped:add(merge)
   return wrapped
end
--- Build an upsampling tail that enlarges feature maps by `scale`, with an
-- activation after every upsampling step.
-- Scale 2 and 3 use one step, and scale 4 uses two x2 steps. Scale 1 yields an
-- empty nn.Sequential, which passes its input through unchanged.
-- @tparam[opt=2] number scale overall upsampling factor: 1, 2, 3 or 4
-- @tparam[opt='espcnn'] string method 'deconv' (transposed convolution) or
--   'espcnn' (convolution + pixel shuffle; Shi et al., 'Real-Time Single Image
--   and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
--   Neural Network')
-- @tparam[opt=64] number nFeat feature planes in and out of the tail
-- @tparam table actParams activation options forwarded to act(); NOTE: its
--   `nFeat` field is overwritten with `nFeat` as a side effect
-- @return nn.Sequential
-- @raise error for an unsupported `method` or `scale` (instead of silently
--   returning an identity-like empty Sequential)
function upsample(scale, method, nFeat, actParams)
   local scale = scale or 2
   local method = method or 'espcnn'
   local nFeat = nFeat or 64
   -- Per-step factors for each supported overall scale.
   local stepsByScale = { [1] = {}, [2] = {2}, [3] = {3}, [4] = {2, 2} }
   if method ~= 'deconv' and method ~= 'espcnn' then
      error('upsample: unknown method ' .. tostring(method), 2)
   end
   local steps = stepsByScale[scale]
   if not steps then
      error('upsample: unsupported scale ' .. tostring(scale), 2)
   end
   actParams.nFeat = nFeat
   local model = seq()
   for _, f in ipairs(steps) do
      if method == 'deconv' then
         -- kernel 3f, stride f, padding f: (in-1)*f - 2f + 3f = f*in output size
         model:add(deconv(nFeat,nFeat, 3*f,3*f, f,f, f,f))
      else
         -- f*f*nFeat planes are rearranged by the pixel shuffle into nFeat
         -- planes at f times the spatial resolution.
         model:add(conv(nFeat,f*f*nFeat, 3,3, 1,1, 1,1))
         model:add(shuffle(f))
      end
      model:add(act(actParams))
   end
   return model
end
--- Build an upsampling tail like upsample(), but with no activation layers.
-- Return shapes are kept stable for callers:
--   * scale 1                      -> nn.Identity
--   * 'deconv' with scale 2 or 3   -> the bare deconvolution module
--   * everything else              -> nn.Sequential
-- @tparam[opt=2] number scale overall upsampling factor: 1, 2, 3 or 4
-- @tparam[opt='espcnn'] string method 'deconv' or 'espcnn' (conv + pixel
--   shuffle; Shi et al., 'Real-Time Single Image and Video Super-Resolution
--   Using an Efficient Sub-Pixel Convolutional Neural Network')
-- @tparam[opt=64] number nFeat feature planes in and out of the tail
-- @return nn module
-- @raise error for an unsupported `method` or `scale` rather than returning nil
function upsample_wo_act(scale, method, nFeat)
   local scale = scale or 2
   local method = method or 'espcnn'
   local nFeat = nFeat or 64
   -- Per-step factors for each supported overall scale (scale 1 handled below).
   local stepsByScale = { [2] = {2}, [3] = {3}, [4] = {2, 2} }
   if method ~= 'deconv' and method ~= 'espcnn' then
      error('upsample_wo_act: unknown method ' .. tostring(method), 2)
   end
   if scale == 1 then
      return id()
   end
   local steps = stepsByScale[scale]
   if not steps then
      error('upsample_wo_act: unsupported scale ' .. tostring(scale), 2)
   end
   if method == 'deconv' and #steps == 1 then
      -- A single deconvolution is returned unwrapped.
      local f = steps[1]
      return deconv(nFeat,nFeat, 3*f,3*f, f,f, f,f)
   end
   local model = seq()
   for _, f in ipairs(steps) do
      if method == 'deconv' then
         model:add(deconv(nFeat,nFeat, 3*f,3*f, f,f, f,f))
      else
         model:add(conv(nFeat,f*f*nFeat, 3,3, 1,1, 1,1))
         model:add(shuffle(f))
      end
   end
   return model
end
--- Residual block: conv -> act -> conv wrapped by addSkip().
-- With `addBN`, a batch-norm layer follows each conv and `scaleRes` is not
-- applied. Otherwise, a scaleRes other than nil/false/1 appends an nn.MulConstant
-- that scales the residual branch.
-- @tparam[opt=64] number nFeat feature planes
-- @param addBN insert bnorm after each conv when truthy
-- @tparam table actParams activation options; NOTE: its `nFeat` field is
--   overwritten with the block width
-- @tparam[opt] number scaleRes residual scaling factor (1 means "no scaling")
-- @param[opt=false] ipMulc in-place flag for the scaling module; asserting
--   error if given without an effective scaleRes
-- @return module produced by addSkip()
function resBlock(nFeat, addBN, actParams, scaleRes, ipMulc)
   local width = nFeat or 64
   local factor = false
   if scaleRes and scaleRes ~= 1 then
      factor = scaleRes
   end
   local inplaceScale = ipMulc or false
   assert(factor or not inplaceScale, 'Please specify -scaleRes option')
   actParams.nFeat = width

   -- Modules are created in the same order as the layers appear, which keeps
   -- weight initialisation order stable.
   local body = seq()
   body:add(conv(width,width, 3,3, 1,1, 1,1))
   if addBN then
      body:add(bnorm(width))
   end
   body:add(act(actParams))
   body:add(conv(width,width, 3,3, 1,1, 1,1))
   if addBN then
      body:add(bnorm(width))
   elseif factor then
      body:add(mulc(factor, inplaceScale))
   end
   return addSkip(body)
end
--- conv -> bnorm -> act -> conv -> bnorm stack (no skip connection).
-- @tparam[opt=64] number nFeat feature planes
-- @param addBN accepted for signature compatibility but unused: batch
--   normalisation is always included
-- @tparam table actParams activation options; NOTE: its `nFeat` field is
--   overwritten with the block width
-- @return nn.Sequential
function cbrcb(nFeat, addBN, actParams)
   local width = nFeat or 64
   actParams.nFeat = width
   local block = seq()
   for stage = 1, 2 do
      block:add(conv(width,width, 3,3, 1,1, 1,1))
      block:add(bnorm(width))
      -- The activation sits only between the two conv+bnorm pairs.
      if stage == 1 then
         block:add(act(actParams))
      end
   end
   return block
end
--- Plain conv -> act -> conv stack (no skip connection, no batch norm).
-- @tparam[opt=64] number nFeat feature planes
-- @tparam table actParams activation options; NOTE: its `nFeat` field is
--   overwritten with the block width
-- @return nn.Sequential
function crc(nFeat, actParams)
   local width = nFeat or 64
   local function conv3x3()
      return conv(width,width, 3,3, 1,1, 1,1)
   end
   actParams.nFeat = width
   local stack = seq()
   stack:add(conv3x3())
   stack:add(act(actParams))
   stack:add(conv3x3())
   return stack
end
--- Pre-activation stack: (bnorm -> act -> conv) applied twice.
-- Now consistent with the sibling builders: `nFeat` defaults to 64 and is
-- written into `actParams.nFeat`. Without that write, act() would size a PReLU
-- from whatever width an earlier builder left behind (or nil).
-- @tparam[opt=64] number nFeat feature planes
-- @tparam table actParams activation options; NOTE: its `nFeat` field is
--   overwritten with the block width
-- @return nn.Sequential
function brcbrc(nFeat, actParams)
   local nFeat = nFeat or 64
   actParams.nFeat = nFeat
   local model = seq()
   for _ = 1, 2 do
      model:add(bnorm(nFeat))
      model:add(act(actParams))
      model:add(conv(nFeat,nFeat, 3,3, 1,1, 1,1))
   end
   return model
end
--- nn.MultiSkipAdd: adds one shared skip tensor to every tensor of a list.
-- Forward:  {skip, {out1, out2, ...}}  ->  {out1 + skip, out2 + skip, ...}
-- Backward: the skip receives the sum of all output gradients, and each branch
--           receives its own output gradient.
local MultiSkipAdd, parent = torch.class('nn.MultiSkipAdd', 'nn.Module')

--- @param ip when truthy, work in place: the branch tensors of input[2] are
--   overwritten with the sums, and output gradients are passed back without
--   being copied.
function MultiSkipAdd:__init(ip)
   parent.__init(self)
   self.inplace = ip
end

function MultiSkipAdd:updateOutput(input)
   local skip, branches = input[1], input[2]
   self.output = {}
   for i = 1, #branches do
      local out = branches[i]
      if not self.inplace then
         out = out:clone()
      end
      out:add(skip)
      self.output[i] = out
   end
   return self.output
end

function MultiSkipAdd:updateGradInput(input, gradOutput)
   local skip, branches = input[1], input[2]
   -- Allocate the skip gradient from input[1] rather than gradOutput[1]. It
   -- then always has the skip tensor's shape, and an empty branch list yields
   -- a zero gradient instead of indexing a nil gradOutput[1].
   local gradSkip = skip.new():resizeAs(skip):zero()
   local gradBranches = {}
   for i = 1, #branches do
      gradSkip:add(gradOutput[i])
      if self.inplace then
         gradBranches[i] = gradOutput[i]
      else
         gradBranches[i] = gradOutput[i]:clone()
      end
   end
   self.gradInput = {gradSkip, gradBranches}
   return self.gradInput
end