-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
resnet_block.py
706 lines (673 loc) · 23.1 KB
/
resnet_block.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
from paddle import _legacy_C_ops, fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.param_attr import ParamAttr
from paddle.nn import Layer
from paddle.nn import initializer as I
# Public API of this module: the functional wrapper and the Layer class
# defined below.
__all__ = ['resnet_basic_block', 'ResNetBasicBlock']
def resnet_basic_block(
    x,
    filter1,
    scale1,
    bias1,
    mean1,
    var1,
    filter2,
    scale2,
    bias2,
    mean2,
    var2,
    filter3,
    scale3,
    bias3,
    mean3,
    var3,
    stride1,
    stride2,
    stride3,
    padding1,
    padding2,
    padding3,
    dilation1,
    dilation2,
    dilation3,
    groups,
    momentum,
    eps,
    data_format,
    has_shortcut,
    use_global_stats=None,
    training=False,
    trainable_statistics=False,
    find_conv_max=True,
):
    """Invoke the fused ``resnet_basic_block`` operator and return its main output.

    In dygraph mode the operator is executed immediately through
    ``_legacy_C_ops.resnet_basic_block`` and the first returned value is
    handed back. Otherwise a ``resnet_basic_block`` op is appended to the
    current program through a ``LayerHelper`` and the variable bound to the
    ``Y`` output is returned.

    Args:
        x: Input fed to the operator as ``X``.
        filter1, scale1, bias1, mean1, var1: Fed as ``Filter1``, ``Scale1``,
            ``Bias1``, ``Mean1`` and ``Var1``.
        filter2, scale2, bias2, mean2, var2: Same for the second branch.
        filter3, scale3, bias3, mean3, var3: Same for the third branch.
        stride1, stride2, stride3: Forwarded as attributes of the same name.
        padding1, padding2, padding3: Forwarded as attributes of the same name.
        dilation1, dilation2, dilation3: Forwarded as attributes of the same
            name.
        groups: Forwarded as the ``group`` attribute.
        momentum: Forwarded as the ``momentum`` attribute.
        eps: Forwarded as the ``epsilon`` attribute.
        data_format: Forwarded as the ``data_format`` attribute.
        has_shortcut: Forwarded as the ``has_shortcut`` attribute.
        use_global_stats: Forwarded as the ``use_global_stats`` attribute.
        training: The operator's ``is_test`` attribute is ``not training``.
        trainable_statistics: Forwarded as ``trainable_statistics``.
        find_conv_max: Forwarded as ``find_conv_input_max``.

    Returns:
        The operator's ``Y`` output.

    Note:
        ``act_type`` is always passed as ``"relu"``.
    """
    if fluid.framework.in_dygraph_mode():
        # The eager op takes attributes as a flat name, value, name, value...
        # sequence after the tensor arguments.
        attr_pairs = (
            ('stride1', stride1),
            ('stride2', stride2),
            ('stride3', stride3),
            ('padding1', padding1),
            ('padding2', padding2),
            ('padding3', padding3),
            ('dilation1', dilation1),
            ('dilation2', dilation2),
            ('dilation3', dilation3),
            ('group', groups),
            ('momentum', momentum),
            ('epsilon', eps),
            ('data_format', data_format),
            ('has_shortcut', has_shortcut),
            ('use_global_stats', use_global_stats),
            ("trainable_statistics", trainable_statistics),
            ('is_test', not training),
            ('act_type', "relu"),
            ('find_conv_input_max', find_conv_max),
        )
        flat_attrs = tuple(item for pair in attr_pairs for item in pair)

        op_inputs = (
            x,
            filter1, scale1, bias1, mean1, var1,
            filter2, scale2, bias2, mean2, var2,
            filter3, scale3, bias3, mean3, var3,
        )
        # The mean/var tensors are passed a second time after the inputs,
        # exactly as the op is invoked in eager mode.
        running_stats = (mean1, var1, mean2, var2, mean3, var3)

        # Only the first result is of interest; the rest are discarded.
        out, *_ = _legacy_C_ops.resnet_basic_block(
            *op_inputs, *running_stats, *flat_attrs
        )
        return out

    # Static graph path. The helper must be built before any new local is
    # bound so that ``locals()`` holds the function arguments only.
    helper = LayerHelper('resnet_basic_block', **locals())
    stat_dtype = fluid.core.VarDesc.VarType.FP32
    max_dtype = fluid.core.VarDesc.VarType.FP32

    def make_tmp(dtype):
        # Fresh non-differentiable temporary output variable.
        return helper.create_variable_for_type_inference(
            dtype=dtype, stop_gradient=True
        )

    def running_stat(existing):
        # Reuse the caller's mean/var variable as the op output when given,
        # otherwise allocate a temporary one.
        return make_tmp(stat_dtype) if existing is None else existing

    # NOTE: temporaries are created in a fixed order; keep it stable so the
    # automatically generated variable names do not change.
    y = make_tmp(x.dtype)

    conv1_out = make_tmp(x.dtype)
    mean1_saved = make_tmp(stat_dtype)
    invstd1_saved = make_tmp(stat_dtype)
    mean1_out = running_stat(mean1)
    var1_out = running_stat(var1)

    conv2_out = make_tmp(x.dtype)
    conv2_in = make_tmp(x.dtype)
    mean2_saved = make_tmp(stat_dtype)
    invstd2_saved = make_tmp(stat_dtype)
    mean2_out = running_stat(mean2)
    var2_out = running_stat(var2)

    conv3_out = make_tmp(x.dtype)
    mean3_saved = make_tmp(stat_dtype)
    invstd3_saved = make_tmp(stat_dtype)
    mean3_out = running_stat(mean3)
    var3_out = running_stat(var3)

    input1_max = make_tmp(max_dtype)
    filter1_max = make_tmp(max_dtype)
    input2_max = make_tmp(max_dtype)
    filter2_max = make_tmp(max_dtype)
    input3_max = make_tmp(max_dtype)
    filter3_max = make_tmp(max_dtype)

    op_inputs = {
        'X': x,
        'Filter1': filter1,
        'Scale1': scale1,
        'Bias1': bias1,
        'Mean1': mean1,
        'Var1': var1,
        'Filter2': filter2,
        'Scale2': scale2,
        'Bias2': bias2,
        'Mean2': mean2,
        'Var2': var2,
        'Filter3': filter3,
        'Scale3': scale3,
        'Bias3': bias3,
        'Mean3': mean3,
        'Var3': var3,
    }
    op_attrs = dict(
        stride1=stride1,
        stride2=stride2,
        stride3=stride3,
        padding1=padding1,
        padding2=padding2,
        padding3=padding3,
        dilation1=dilation1,
        dilation2=dilation2,
        dilation3=dilation3,
        group=groups,
        momentum=momentum,
        epsilon=eps,
        data_format=data_format,
        has_shortcut=has_shortcut,
        use_global_stats=use_global_stats,
        trainable_statistics=trainable_statistics,
        is_test=not training,
        act_type="relu",
        find_conv_input_max=find_conv_max,
    )
    op_outputs = {
        'Y': y,
        'Conv1': conv1_out,
        'SavedMean1': mean1_saved,
        'SavedInvstd1': invstd1_saved,
        'Mean1Out': mean1_out,
        'Var1Out': var1_out,
        'Conv2': conv2_out,
        'SavedMean2': mean2_saved,
        'SavedInvstd2': invstd2_saved,
        'Mean2Out': mean2_out,
        'Var2Out': var2_out,
        'Conv2Input': conv2_in,
        'Conv3': conv3_out,
        'SavedMean3': mean3_saved,
        'SavedInvstd3': invstd3_saved,
        'Mean3Out': mean3_out,
        'Var3Out': var3_out,
        'MaxInput1': input1_max,
        'MaxFilter1': filter1_max,
        'MaxInput2': input2_max,
        'MaxFilter2': filter2_max,
        'MaxInput3': input3_max,
        'MaxFilter3': filter3_max,
    }

    helper.append_op(
        type='resnet_basic_block',
        inputs=op_inputs,
        outputs=op_outputs,
        attrs=op_attrs,
    )
    return y
class ResNetBasicBlock(Layer):
    r"""
    ResNetBasicBlock is designed to optimize the performance of the basic unit of the ssd resnet block.
    If has_shortcut = True, it can calculate 3 Conv2D, 3 BatchNorm and 2 ReLU in one time.
    If has_shortcut = False, it can calculate 2 Conv2D, 2 BatchNorm and 2 ReLU in one time. In this
    case the shape of output is same with input.

    The layer owns three groups of parameters, one per convolution + batch-norm
    branch (suffixes ``1``, ``2`` and ``3``). The third group is only created
    when ``has_shortcut`` is True; otherwise ``filter_3``, ``scale_3``,
    ``bias_3``, ``mean_3`` and ``var_3`` are ``None``.

    Args:
        num_channels1 (int): The number of input channels of the first convolution.
        num_filter1 (int): The number of filters of the first convolution.
        filter1_size (int|list|tuple): The filter size of the first convolution. If it
            is a list or tuple, it must contain two integers, (filter_size_height,
            filter_size_width). Otherwise, filter_size_height = filter_size_width =
            filter1_size.
        num_channels2 (int): The number of input channels of the second convolution.
        num_filter2 (int): The number of filters of the second convolution.
        filter2_size (int|list|tuple): The filter size of the second convolution, in the
            same format as ``filter1_size``.
        num_channels3 (int): The number of input channels of the third (shortcut)
            convolution. Only used when ``has_shortcut`` is True.
        num_filter3 (int): The number of filters of the third (shortcut) convolution.
            Only used when ``has_shortcut`` is True.
        filter3_size (int|list|tuple): The filter size of the third (shortcut) convolution,
            in the same format as ``filter1_size``. Only used when ``has_shortcut`` is True.
        stride1 (int, optional): The stride of the first convolution. Default: 1.
        stride2 (int, optional): The stride of the second convolution. Default: 1.
        stride3 (int, optional): The stride of the third convolution. Ignored (1 is
            used) when ``has_shortcut`` is False. Default: 1.
        act (str, optional): Activation type. The value is stored on the layer but is not
            forwarded by ``forward``; ``resnet_basic_block`` always runs the operator
            with ``act_type="relu"``. Default: 'relu'.
        momentum (float, optional): The value used for the moving_mean and
            moving_var computation. This should be a float number or a Tensor with
            shape [1] and data type as float32. The updated formula is:
            :math:`moving\_mean = moving\_mean * momentum + new\_mean * (1. - momentum)`
            :math:`moving\_var = moving\_var * momentum + new\_var * (1. - momentum)`
            Default is 0.9.
        eps (float, optional): A value added to the denominator for
            numerical stability. Default is 1e-5.
        data_format (str, optional): Specify the data format of the input, and the data format of the output
            will be consistent with that of the input. Only `"NCHW"` is supported, the data is stored in
            the order of: `[batch_size, input_channels, input_height, input_width]`. Any other value
            raises ``ValueError``. Default: 'NCHW'.
        has_shortcut (bool, optional): Whether to calculate CONV3 and BN3. Default: False.
        use_global_stats (bool, optional): Whether to use global mean and
            variance. In inference or test mode, set use_global_stats to true
            or is_test to true, and the behavior is equivalent.
            In train mode, when setting use_global_stats True, the global mean
            and variance are also used during train period. Default: False.
        is_test (bool, optional): A flag indicating whether it is in
            test phase or not. The value is stored on the layer; ``forward`` passes
            the layer's ``training`` flag to the operator instead. Default: False.
        filter1_attr (ParamAttr, optional): The parameter attribute for the filter of the
            first convolution. When None, the filter is initialized from a normal
            distribution with mean 0 and std ``sqrt(2 / (kernel_h * kernel_w * in_channels))``.
            Default: None.
        scale1_attr (ParamAttr, optional): The parameter attribute for the batch-norm
            `scale` of the first branch. When the initializer is not set, the scale is
            initialized to 1.0. Default: None.
        bias1_attr (ParamAttr, optional): The parameter attribute for the batch-norm bias
            of the first branch. If the Initializer of the bias attribute is not set, the
            bias is initialized zero. Default: None.
        moving_mean1_name (str, optional): The name of the moving mean parameter of the
            first branch. It is passed as the ``name`` of the parameter's ParamAttr.
            Default: None.
        moving_var1_name (str, optional): The name of the moving variance parameter of the
            first branch. It is passed as the ``name`` of the parameter's ParamAttr.
            Default: None.
        filter2_attr, scale2_attr, bias2_attr, moving_mean2_name, moving_var2_name:
            Same as the ``*1*`` arguments, for the second branch.
        filter3_attr, scale3_attr, bias3_attr, moving_mean3_name, moving_var3_name:
            Same as the ``*1*`` arguments, for the third branch. Only used when
            ``has_shortcut`` is True.
        padding1 (int, optional): The padding size of the first convolution. It only
            supports padding_height = padding_width = padding. Default: 0.
        padding2 (int, optional): The padding size of the second convolution. Default: 0.
        padding3 (int, optional): The padding size of the third convolution. Ignored
            (1 is used) when ``has_shortcut`` is False. Default: 0.
        dilation1 (int, optional): The dilation size of the first convolution. It means the
            spacing between the kernel points. It only supports
            dilation_height = dilation_width = dilation. Default: 1.
        dilation2 (int, optional): The dilation size of the second convolution. Default: 1.
        dilation3 (int, optional): The dilation size of the third convolution. Ignored
            (1 is used) when ``has_shortcut`` is False. Default: 1.
        trainable_statistics (bool, optional): Whether to calculate mean and var in eval mode. In eval mode, when
            setting trainable_statistics True, mean and variance will be calculated by current batch statistics.
            Default: False.
        find_conv_max (bool, optional): Whether to calculate max value of each conv2d. Default: True.

    Returns:
        A Tensor representing the ResNetBasicBlock, whose data type is the same with input.

    Raises:
        ValueError: If ``data_format`` is not ``'NCHW'``.

    Examples:
        .. code-block:: python

            # required: xpu
            import paddle
            from paddle.incubate.xpu.resnet_block import ResNetBasicBlock

            ch_in = 4
            ch_out = 8
            x = paddle.uniform((2, ch_in, 16, 16), dtype='float32', min=-1., max=1.)
            resnet_basic_block = ResNetBasicBlock(num_channels1=ch_in,
                                                num_filter1=ch_out,
                                                filter1_size=3,
                                                num_channels2=ch_out,
                                                num_filter2=ch_out,
                                                filter2_size=3,
                                                num_channels3=ch_in,
                                                num_filter3=ch_out,
                                                filter3_size=1,
                                                stride1=1,
                                                stride2=1,
                                                stride3=1,
                                                act='relu',
                                                padding1=1,
                                                padding2=1,
                                                padding3=0,
                                                has_shortcut=True)
            out = resnet_basic_block.forward(x)

            print(out.shape)  # [2, 8, 16, 16]

    """

    def __init__(
        self,
        num_channels1,
        num_filter1,
        filter1_size,
        num_channels2,
        num_filter2,
        filter2_size,
        num_channels3,
        num_filter3,
        filter3_size,
        stride1=1,
        stride2=1,
        stride3=1,
        act='relu',
        momentum=0.9,
        eps=1e-5,
        data_format='NCHW',
        has_shortcut=False,
        use_global_stats=False,
        is_test=False,
        filter1_attr=None,
        scale1_attr=None,
        bias1_attr=None,
        moving_mean1_name=None,
        moving_var1_name=None,
        filter2_attr=None,
        scale2_attr=None,
        bias2_attr=None,
        moving_mean2_name=None,
        moving_var2_name=None,
        filter3_attr=None,
        scale3_attr=None,
        bias3_attr=None,
        moving_mean3_name=None,
        moving_var3_name=None,
        padding1=0,
        padding2=0,
        padding3=0,
        dilation1=1,
        dilation2=1,
        dilation3=1,
        trainable_statistics=False,
        find_conv_max=True,
    ):
        super().__init__()
        self._stride1 = stride1
        self._stride2 = stride2
        # Normalize kernel sizes to [height, width] lists.
        self._kernel1_size = paddle.utils.convert_to_list(
            filter1_size, 2, 'filter1_size'
        )
        self._kernel2_size = paddle.utils.convert_to_list(
            filter2_size, 2, 'filter2_size'
        )
        self._dilation1 = dilation1
        self._dilation2 = dilation2
        self._padding1 = padding1
        self._padding2 = padding2
        self._groups = 1
        self._momentum = momentum
        self._eps = eps
        self._data_format = data_format
        # NOTE: _act and _is_test are recorded but not read by forward().
        self._act = act
        self._has_shortcut = has_shortcut
        self._use_global_stats = use_global_stats
        self._is_test = is_test
        self._trainable_statistics = trainable_statistics
        self._find_conv_max = find_conv_max

        if has_shortcut:
            self._kernel3_size = paddle.utils.convert_to_list(
                filter3_size, 2, 'filter3_size'
            )
            self._padding3 = padding3
            self._stride3 = stride3
            self._dilation3 = dilation3
        else:
            # No third branch: fixed placeholder attributes are passed to
            # the operator.
            self._kernel3_size = None
            self._padding3 = 1
            self._stride3 = 1
            self._dilation3 = 1

        # check format
        valid_format = {'NCHW'}
        if data_format not in valid_format:
            raise ValueError(
                f"data_format must be one of {valid_format}, "
                f"but got data_format={data_format}"
            )

        # Parameters are created branch by branch, in the order
        # filter, scale, bias, mean, var, so auto-generated names are stable.
        (
            self.filter_1,
            self.scale_1,
            self.bias_1,
            self.mean_1,
            self.var_1,
        ) = self._create_conv_bn_params(
            num_filter1,
            num_channels1,
            self._kernel1_size,
            filter1_attr,
            scale1_attr,
            bias1_attr,
            moving_mean1_name,
            moving_var1_name,
        )
        (
            self.filter_2,
            self.scale_2,
            self.bias_2,
            self.mean_2,
            self.var_2,
        ) = self._create_conv_bn_params(
            num_filter2,
            num_channels2,
            self._kernel2_size,
            filter2_attr,
            scale2_attr,
            bias2_attr,
            moving_mean2_name,
            moving_var2_name,
        )
        if has_shortcut:
            (
                self.filter_3,
                self.scale_3,
                self.bias_3,
                self.mean_3,
                self.var_3,
            ) = self._create_conv_bn_params(
                num_filter3,
                num_channels3,
                self._kernel3_size,
                filter3_attr,
                scale3_attr,
                bias3_attr,
                moving_mean3_name,
                moving_var3_name,
            )
        else:
            self.filter_3 = None
            self.scale_3 = None
            self.bias_3 = None
            self.mean_3 = None
            self.var_3 = None

    def _create_conv_bn_params(
        self,
        num_filters,
        num_channels,
        kernel_size,
        filter_attr,
        scale_attr,
        bias_attr,
        moving_mean_name,
        moving_var_name,
    ):
        """Create the filter and batch-norm parameters of one conv + BN branch.

        Args:
            num_filters (int): Number of convolution filters.
            num_channels (int): Number of convolution input channels.
            kernel_size (list): Normalized ``[height, width]`` kernel size.
            filter_attr: Attribute passed to the filter parameter.
            scale_attr: Attribute passed to the batch-norm scale parameter.
            bias_attr: Attribute passed to the batch-norm bias parameter.
            moving_mean_name: Name given to the moving mean parameter.
            moving_var_name: Name given to the moving variance parameter.

        Returns:
            tuple: ``(filter, scale, bias, mean, var)`` parameters, created
            in that order.
        """
        bn_param_dtype = fluid.core.VarDesc.VarType.FP32
        bn_param_shape = [1, 1, num_filters]
        # Use the normalized kernel size so that list/tuple filter sizes
        # yield a valid 4-D shape (the raw argument may not be an int).
        filter_shape = [
            num_filters,
            num_channels,
            kernel_size[0],
            kernel_size[1],
        ]

        # Default filter init: N(0, sqrt(2 / (kernel_h * kernel_w * channels))).
        filter_elem_num = np.prod(kernel_size) * num_channels
        std = (2.0 / filter_elem_num) ** 0.5

        conv_filter = self.create_parameter(
            shape=filter_shape,
            attr=filter_attr,
            default_initializer=I.Normal(0.0, std),
        )
        scale = self.create_parameter(
            shape=bn_param_shape,
            attr=scale_attr,
            dtype=bn_param_dtype,
            default_initializer=I.Constant(1.0),
        )
        bias = self.create_parameter(
            shape=bn_param_shape,
            attr=bias_attr,
            dtype=bn_param_dtype,
            is_bias=True,
        )
        # Moving statistics are stored as non-trainable parameters.
        mean = self.create_parameter(
            attr=ParamAttr(
                name=moving_mean_name,
                initializer=I.Constant(0.0),
                trainable=False,
            ),
            shape=bn_param_shape,
            dtype=bn_param_dtype,
        )
        mean.stop_gradient = True
        var = self.create_parameter(
            attr=ParamAttr(
                name=moving_var_name,
                initializer=I.Constant(1.0),
                trainable=False,
            ),
            shape=bn_param_shape,
            dtype=bn_param_dtype,
        )
        var.stop_gradient = True
        return conv_filter, scale, bias, mean, var

    def forward(self, x):
        """Apply the fused block to ``x`` and return the operator output.

        The operator's train/test behavior follows ``self.training``.
        """
        out = resnet_basic_block(
            x,
            self.filter_1,
            self.scale_1,
            self.bias_1,
            self.mean_1,
            self.var_1,
            self.filter_2,
            self.scale_2,
            self.bias_2,
            self.mean_2,
            self.var_2,
            self.filter_3,
            self.scale_3,
            self.bias_3,
            self.mean_3,
            self.var_3,
            self._stride1,
            self._stride2,
            self._stride3,
            self._padding1,
            self._padding2,
            self._padding3,
            self._dilation1,
            self._dilation2,
            self._dilation3,
            self._groups,
            self._momentum,
            self._eps,
            self._data_format,
            self._has_shortcut,
            use_global_stats=self._use_global_stats,
            training=self.training,
            trainable_statistics=self._trainable_statistics,
            find_conv_max=self._find_conv_max,
        )
        return out