# -*- coding: utf-8 -*-
"""
Feature Pyramid Network (FPN).

Bottom-up inputs (left column) are projected by 1x1 lateral convs and
merged top-down by upsample-and-add:

      000    -> 000 -> feature map
       ^         |
       |         v
     00000   -> 000 -> feature map
       ^         |
       |         v
    0000000  -> 000 -> feature map
"""
import torch.nn as nn
import torch.nn.functional as F

__all__ = ['FPN']


class FPN(nn.Module):
    def __init__(self,
                 num_input_channels_list,
                 num_input_strides_list,
                 num_output_channels,
                 num_outputs,
                 extra_on_input=False,
                 extra_type='conv',
                 norm_on_lateral=False,
                 relu_on_lateral=False,
                 relu_before_extra=False,
                 norm_cfg=None,
                 ):
        super(FPN, self).__init__()
        assert num_outputs >= 1
        assert extra_type in ['conv', 'pooling']
        if norm_on_lateral:
            assert norm_cfg is not None
        if norm_cfg is not None:
            assert 'type' in norm_cfg
            assert norm_cfg['type'] in ['BatchNorm2d', 'GroupNorm']
            if norm_cfg['type'] == 'GroupNorm':
                assert 'num_groups' in norm_cfg
        assert len(num_input_channels_list) == len(num_input_strides_list), \
            'num_input_channels_list and num_input_strides_list must have the same length!'
        self._num_input_channels_list = num_input_channels_list
        self._num_input_strides_list = num_input_strides_list
        self._num_inputs = len(self._num_input_channels_list)
        self._num_output_channels = num_output_channels
        self._num_outputs = num_outputs
        self._extra_on_input = extra_on_input
        self._extra_type = extra_type
        self._norm_on_lateral = norm_on_lateral
        self._relu_on_lateral = relu_on_lateral
        self._relu_before_extra = relu_before_extra
        self._norm_cfg = norm_cfg
        # lateral convs: 1x1 convs that project each input to num_output_channels;
        # the conv bias is dropped when a norm layer directly follows it
        for i in range(self._num_inputs):
            lateral = []
            if self._norm_on_lateral:
                lateral.append(nn.Conv2d(self._num_input_channels_list[i], self._num_output_channels, kernel_size=1, stride=1, padding=0, bias=False))
                lateral.append(nn.BatchNorm2d(num_features=self._num_output_channels) if self._norm_cfg['type'] == 'BatchNorm2d' else
                               nn.GroupNorm(num_groups=self._norm_cfg['num_groups'], num_channels=self._num_output_channels))
            else:
                lateral.append(nn.Conv2d(self._num_input_channels_list[i], self._num_output_channels, kernel_size=1, stride=1, padding=0, bias=True))
            if self._relu_on_lateral:
                lateral.append(nn.ReLU(inplace=False))
            setattr(self, 'lateral%d' % i, nn.Sequential(*lateral))
        # output convs: a 3x3 smoothing conv per pyramid level, plus optional
        # extra downsampling levels when num_outputs > num_inputs
        for i in range(self._num_outputs):
            fpn_out = []
            if i < self._num_inputs:
                # smooth the merged (lateral + upsampled) feature map
                fpn_out.append(nn.Conv2d(self._num_output_channels, self._num_output_channels, kernel_size=3, stride=1, padding=1, bias=True))
            else:
                # extra level: downsample by a stride-2 conv or by max pooling
                if self._relu_before_extra:
                    fpn_out.append(nn.ReLU(inplace=True))
                if self._extra_type == 'conv':
                    # the first extra level may read directly from the last input,
                    # whose channel count can differ from num_output_channels
                    if i == self._num_inputs and self._extra_on_input:
                        in_channels = self._num_input_channels_list[-1]
                    else:
                        in_channels = self._num_output_channels
                    fpn_out.append(nn.Conv2d(in_channels, self._num_output_channels, kernel_size=3, stride=2, padding=1, bias=True))
                else:
                    fpn_out.append(nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
            setattr(self, 'fpn_out%d' % i, nn.Sequential(*fpn_out))
        self.__init_weights()
        # compute the output strides for downstream consumers (e.g. the point
        # strides in FCOS); each extra level halves the resolution again
        if self._num_outputs <= self._num_inputs:
            self._num_output_strides_list = self._num_input_strides_list[:self._num_outputs]
        else:
            # copy so that appending extra strides does not mutate the caller's list
            self._num_output_strides_list = list(self._num_input_strides_list)
            for i in range(self._num_outputs - self._num_inputs):
                self._num_output_strides_list.append(self._num_input_strides_list[-1] * 2**(i + 1))
    @property
    def num_output_strides_list(self):
        return self._num_output_strides_list

    def __init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight, gain=1)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    def forward(self, inputs):
        assert len(inputs) == self._num_inputs
        lateral_outputs = []
        for i in range(self._num_inputs):
            lateral_outputs.append(getattr(self, 'lateral%d' % i)(inputs[i]))
        # top-down pathway: upsample each coarser level and add it to the
        # lateral output one level below
        for i in range(self._num_inputs - 1, 0, -1):
            target_shape = lateral_outputs[i - 1].shape[2:]
            lateral_outputs[i - 1] += F.interpolate(lateral_outputs[i], size=target_shape, mode='nearest')
        # fpn outputs: smoothed pyramid levels, then any extra downsampled levels
        fpn_outputs = []
        for i in range(self._num_outputs):
            if i < self._num_inputs:
                fpn_outputs.append(getattr(self, 'fpn_out%d' % i)(lateral_outputs[i]))
            elif i == self._num_inputs and self._extra_on_input:
                # the first extra level reads from the last (coarsest) input
                fpn_outputs.append(getattr(self, 'fpn_out%d' % i)(inputs[-1]))
            else:
                # further extra levels chain off the previous output
                fpn_outputs.append(getattr(self, 'fpn_out%d' % i)(fpn_outputs[-1]))
        return tuple(fpn_outputs)
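

# A minimal usage sketch; the channel counts, strides, and input sizes below
# are illustrative assumptions (e.g. ResNet C3-C5 features), not values
# required by this module.
if __name__ == '__main__':
    import torch

    fpn = FPN(num_input_channels_list=[512, 1024, 2048],
              num_input_strides_list=[8, 16, 32],
              num_output_channels=256,
              num_outputs=5,
              extra_on_input=True,
              extra_type='conv')
    dummy_inputs = [torch.randn(1, 512, 64, 64),
                    torch.randn(1, 1024, 32, 32),
                    torch.randn(1, 2048, 16, 16)]
    outputs = fpn(dummy_inputs)
    print(fpn.num_output_strides_list)  # expect [8, 16, 32, 64, 128]
    for out in outputs:
        print(tuple(out.shape))  # 5 levels, 256 channels, halving resolution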