-
Notifications
You must be signed in to change notification settings - Fork 19
/
fan.py
265 lines (228 loc) · 11.4 KB
/
fan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import math
import fvcore.nn.weight_init as weight_init
import torch.nn.functional as F
import torch
from torch import nn
import os
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
import cv2
import PIL
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.cm as mpl_color_map
from detectron2.layers import Conv2d, ShapeSpec, get_norm
from .backbone import Backbone
from .build import BACKBONE_REGISTRY
from .resnet import build_resnet_backbone
__all__ = ["build_resnet_fan_backbone", "build_retinanet_resnet_fan_backbone", "FAN"]
# from dcn_v2 import DCN, DCNPooling, DCNv2, DCNv2Pooling, dcn_v2_conv, dcn_v2_pooling
from dcn_v2 import DCN as dcn_v2
from detectron2.layers import (CNNBlockBase, Conv2d, DeformConv, ModulatedDeformConv, ShapeSpec, get_norm, )
class FeatureSelectionModule(nn.Module):
    """Re-weight the channels of a feature map, then project it with a 1x1 conv.

    A per-channel gate is computed from the spatially averaged input, the
    gated features are added back onto the input, and the sum is projected
    from ``in_chan`` to ``out_chan`` channels.

    Args:
        in_chan (int): number of channels of the incoming feature map.
        out_chan (int): number of channels produced by the final projection.
        norm (str): normalization name handed to ``get_norm`` for the gate
            conv only; the projection conv always uses ``get_norm('', ...)``.
    """

    def __init__(self, in_chan, out_chan, norm="GN"):
        super().__init__()
        # Gate branch: 1x1 conv that keeps the channel count.
        gate_norm = get_norm(norm, in_chan)
        self.conv_atten = Conv2d(in_chan, in_chan, kernel_size=1, bias=False, norm=gate_norm)
        self.sigmoid = nn.Sigmoid()
        # Projection branch: 1x1 conv changing in_chan -> out_chan.
        projection_norm = get_norm('', out_chan)
        self.conv = Conv2d(in_chan, out_chan, kernel_size=1, bias=False, norm=projection_norm)
        # Initialise in creation order (gate first, projection second).
        for layer in (self.conv_atten, self.conv):
            weight_init.c2_xavier_fill(layer)

    def forward(self, x):
        """Return ``conv(x + x * gate)``.

        ``x.size()[2:]`` is used as the pooling kernel, so the average runs
        over everything after the first two dimensions
        (presumably ``x`` is ``(N, C, H, W)`` -- confirm against callers).
        """
        full_extent = x.size()[2:]
        pooled = F.avg_pool2d(x, full_extent)
        gate = self.sigmoid(self.conv_atten(pooled))
        # Residual form: the original features plus their gated copy.
        enhanced = x + torch.mul(x, gate)
        return self.conv(enhanced)
class FeatureAlign_V2(nn.Module):  # FaPN full version
    """Align a coarser feature map to a finer one and fuse the two.

    Args:
        in_nc (int): channels of the finer (lateral) input ``feat_l``.
        out_nc (int): channels of the coarser input ``feat_s`` and of the output.
        norm: normalization layer (or None) attached to the offset conv.
    """

    def __init__(self, in_nc=128, out_nc=128, norm=None):
        super().__init__()
        # Lateral branch: channel re-weighting + projection to out_nc channels.
        self.lateral_conv = FeatureSelectionModule(in_nc, out_nc, norm="")
        # 1x1 conv over the concatenated (lateral, upsampled) features.
        self.offset = Conv2d(
            out_nc * 2,
            out_nc,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
            norm=norm,
        )
        # NOTE(review): exact semantics of `extra_offset_mask` are defined by
        # the external dcn_v2 package -- confirm there.
        self.dcpack_L2 = dcn_v2(
            out_nc,
            out_nc,
            3,
            stride=1,
            padding=1,
            dilation=1,
            deformable_groups=8,
            extra_offset_mask=True,
        )
        self.relu = nn.ReLU(inplace=True)
        weight_init.c2_xavier_fill(self.offset)

    def forward(self, feat_l, feat_s, main_path=None):
        """Fuse ``feat_s`` into the resolution of ``feat_l``.

        Args:
            feat_l: finer-resolution feature map (lateral input).
            feat_s: coarser-resolution feature map from the previous level.
            main_path: forwarded unchanged as the second argument of the
                deformable conv.

        Returns:
            ``relu(dcn([feat_up, offset], main_path)) + feat_arm``.
        """
        target_hw = feat_l.size()[2:]
        # Only resample when the spatial sizes actually differ.
        if target_hw == feat_s.size()[2:]:
            feat_up = feat_s
        else:
            feat_up = F.interpolate(feat_s, target_hw, mode='bilinear', align_corners=False)

        feat_arm = self.lateral_conv(feat_l)
        # The offset conv sees both branches; the upsampled one is scaled by 2.
        offset_input = torch.cat([feat_arm, feat_up * 2], dim=1)
        offset = self.offset(offset_input)

        aligned = self.dcpack_L2([feat_up, offset], main_path)
        feat_align = self.relu(aligned)
        return feat_align + feat_arm
class FAN(Backbone):
    """
    Top-down feature pyramid whose lateral connections are alignment modules.

    The lowest-resolution input feature goes through a plain 1x1 ``Conv2d``;
    every finer input goes through a :class:`FeatureAlign_V2` that also
    receives the previously computed (coarser) pyramid feature, followed by a
    3x3 output conv. The lowest-resolution level is returned without an
    output conv.
    """

    def __init__(self, bottom_up, in_features, out_channels, norm="", top_block=None, fuse_type="sum"):
        """
        Args:
            bottom_up (Backbone): the bottom-up subnetwork; must be a
                :class:`Backbone` instance. Its ``output_shape()`` supplies the
                stride and channel count of every name in `in_features`.
            in_features (list[str]): names of the bottom-up feature maps to
                consume, ordered from high to low resolution. Consecutive
                strides must double (checked at construction).
            out_channels (int): channel count of every produced feature map.
            norm (str): normalization name passed to ``get_norm``. Convs get a
                bias only when this is the empty string.
            top_block (nn.Module or None): optional module producing extra,
                further-downsampled levels. It must expose ``num_levels`` (how
                many levels it adds) and ``in_feature`` (name of the feature it
                consumes, e.g. "p5" or "res5"); its outputs are appended to the
                result list.
            fuse_type (str): must be "sum" or "avg". It is validated and stored
                as ``_fuse_type`` but not otherwise read inside this class.
        """
        super().__init__()
        assert isinstance(bottom_up, Backbone)

        # Strides / channel counts of the selected bottom-up features.
        input_shapes = bottom_up.output_shape()
        strides = [input_shapes[name].stride for name in in_features]
        in_channels_per_feature = [input_shapes[name].channels for name in in_features]
        _assert_strides_are_log2_contiguous(strides)

        use_bias = norm == ""
        align_modules = []
        output_convs = []

        # Every level except the last (lowest-resolution) one gets an
        # alignment module plus a 3x3 output conv.
        for stride, in_channels in zip(strides, in_channels_per_feature[:-1]):
            stage = int(math.log2(stride))

            lateral_norm = get_norm(norm, out_channels)
            align_module = FeatureAlign_V2(in_channels, out_channels, norm=lateral_norm)  # proposed fapn
            self.add_module("fan_align{}".format(stage), align_module)
            align_modules.append(align_module)

            output_conv = Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=use_bias,
                norm=get_norm(norm, out_channels),
            )
            weight_init.c2_xavier_fill(output_conv)
            self.add_module("fpn_output{}".format(stage), output_conv)
            output_convs.append(output_conv)

        # The lowest-resolution level has nothing coarser to align against, so
        # it only gets a 1x1 lateral conv.
        # NOTE(review): unlike the convs above, this one is not passed through
        # c2_xavier_fill -- confirm that is intended.
        stage = int(math.log2(strides[-1]))
        lateral_conv = Conv2d(
            in_channels_per_feature[-1],
            out_channels,
            kernel_size=1,
            bias=use_bias,
            norm=get_norm(norm, out_channels),
        )
        align_modules.append(lateral_conv)
        self.add_module("fan_align{}".format(stage), lateral_conv)

        # Keep both lists in top-down order (low to high resolution) so that
        # forward() can walk them front to back.
        self.align_modules = list(reversed(align_modules))
        self.output_convs = list(reversed(output_convs))
        self.top_block = top_block
        self.in_features = in_features
        self.bottom_up = bottom_up

        # Output names are "p<stage>" with stride == 2 ** stage.
        self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides}
        # Extra levels contributed by the top block continue after `stage`.
        if self.top_block is not None:
            first_extra = stage + 1
            for level in range(first_extra, first_extra + self.top_block.num_levels):
                self._out_feature_strides["p{}".format(level)] = 2 ** level

        self._out_features = list(self._out_feature_strides)
        self._out_feature_channels = dict.fromkeys(self._out_features, out_channels)
        self._size_divisibility = strides[-1]
        assert fuse_type in {"avg", "sum"}
        self._fuse_type = fuse_type

    @property
    def size_divisibility(self):
        """Largest stride among the consumed bottom-up features."""
        return self._size_divisibility

    def forward(self, x):
        """
        Args:
            x: input handed unchanged to ``self.bottom_up``; the bottom-up
                network must return a mapping from feature name to tensor
                that contains every name in ``self.in_features``.

        Returns:
            dict[str->Tensor]: mapping from output name ("p<stage>", where the
            stride is 2 ** stage) to feature map, ordered from high to low
            resolution, with any top-block outputs last.
        """
        bottom_up_features = self.bottom_up(x)
        # Walk the selected features top-down (low to high resolution).
        top_down_inputs = [bottom_up_features[name] for name in reversed(self.in_features)]

        # The coarsest level only goes through its lateral conv.
        prev_features = self.align_modules[0](top_down_inputs[0])
        results = [prev_features]

        finer_levels = zip(top_down_inputs[1:], self.align_modules[1:], self.output_convs)
        for features, align_module, output_conv in finer_levels:
            prev_features = align_module(features, prev_features)
            # Prepend so that `results` ends up ordered high -> low resolution.
            results.insert(0, output_conv(prev_features))

        if self.top_block is not None:
            source_name = self.top_block.in_feature
            # Prefer a bottom-up feature of that name; otherwise use the
            # pyramid output with that name.
            top_block_in_feature = bottom_up_features.get(source_name)
            if top_block_in_feature is None:
                top_block_in_feature = results[self._out_features.index(source_name)]
            results.extend(self.top_block(top_block_in_feature))

        assert len(self._out_features) == len(results)
        return {name: feature for name, feature in zip(self._out_features, results)}

    def output_shape(self):
        """Return a ``ShapeSpec`` (channels and stride) for every output name."""
        shapes = {}
        for name in self._out_features:
            shapes[name] = ShapeSpec(
                channels=self._out_feature_channels[name],
                stride=self._out_feature_strides[name],
            )
        return shapes
def _assert_strides_are_log2_contiguous(strides):
    """
    Check that every stride is exactly twice the one before it.

    Raises ``AssertionError`` naming the first offending pair (current stride,
    then its predecessor). Empty and single-element sequences pass trivially.
    """
    for previous, current in zip(strides, strides[1:]):
        assert current == 2 * previous, "Strides {} {} are not log2 contiguous".format(current, previous)
class LastLevelMaxPool(nn.Module):
    """
    Top block that derives one extra, 2x-downsampled level from "p5".

    The downsampling is a max pool with kernel size 1 and stride 2, i.e. it
    keeps every second element along each spatial axis.
    """

    def __init__(self):
        super().__init__()
        # Name of the feature this block consumes, and how many levels it adds.
        self.in_feature = "p5"
        self.num_levels = 1

    def forward(self, x):
        """Return a one-element list holding the downsampled feature map."""
        downsampled = F.max_pool2d(x, kernel_size=1, stride=2, padding=0)
        return [downsampled]
class LastLevelP6P7(nn.Module):
    """
    Top block that produces two extra levels (P6 and P7) from one input
    feature, using two stacked stride-2 3x3 convolutions.

    Args:
        in_channels (int): channels of the consumed feature map.
        out_channels (int): channels of both produced feature maps.
        in_feature (str): name of the feature map this block consumes.
    """

    def __init__(self, in_channels, out_channels, in_feature="res5"):
        super().__init__()
        self.num_levels = 2
        self.in_feature = in_feature
        self.p6 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.p7 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)
        weight_init.c2_xavier_fill(self.p6)
        weight_init.c2_xavier_fill(self.p7)

    def forward(self, c5):
        """Return ``[p6, p7]`` where ``p7`` is computed from ``relu(p6)``."""
        p6 = self.p6(c5)
        # The ReLU is applied only on the way into p7; the returned p6 is the
        # raw conv output.
        return [p6, self.p7(F.relu(p6))]
@BACKBONE_REGISTRY.register()
def build_resnet_fan_backbone(cfg, input_shape: ShapeSpec):
    """
    Build a ResNet bottom-up network wrapped in a :class:`FAN` pyramid, with a
    :class:`LastLevelMaxPool` top block.

    Args:
        cfg: a detectron2 CfgNode; ``cfg.MODEL.FPN`` supplies IN_FEATURES,
            OUT_CHANNELS, NORM and FUSE_TYPE.
        input_shape (ShapeSpec): forwarded to ``build_resnet_backbone``.

    Returns:
        backbone (Backbone): the constructed :class:`FAN` instance.
    """
    # Build the bottom-up network first, then read the pyramid settings.
    bottom_up = build_resnet_backbone(cfg, input_shape)
    fpn_cfg = cfg.MODEL.FPN
    return FAN(
        bottom_up=bottom_up,
        in_features=fpn_cfg.IN_FEATURES,
        out_channels=fpn_cfg.OUT_CHANNELS,
        norm=fpn_cfg.NORM,
        top_block=LastLevelMaxPool(),
        fuse_type=fpn_cfg.FUSE_TYPE,
    )
@BACKBONE_REGISTRY.register()
def build_retinanet_resnet_fan_backbone(cfg, input_shape: ShapeSpec):
    """
    Build a ResNet bottom-up network wrapped in a :class:`FAN` pyramid, with a
    :class:`LastLevelP6P7` top block fed from the bottom-up "res5" feature.

    Args:
        cfg: a detectron2 CfgNode; ``cfg.MODEL.FPN`` supplies IN_FEATURES,
            OUT_CHANNELS, NORM and FUSE_TYPE.
        input_shape (ShapeSpec): forwarded to ``build_resnet_backbone``.

    Returns:
        backbone (Backbone): the constructed :class:`FAN` instance.
    """
    bottom_up = build_resnet_backbone(cfg, input_shape)
    fpn_cfg = cfg.MODEL.FPN
    in_features = fpn_cfg.IN_FEATURES
    out_channels = fpn_cfg.OUT_CHANNELS
    # The extra P6/P7 levels are computed from "res5", so the top block's
    # input width is that feature's channel count.
    res5_channels = bottom_up.output_shape()["res5"].channels
    return FAN(
        bottom_up=bottom_up,
        in_features=in_features,
        out_channels=out_channels,
        norm=fpn_cfg.NORM,
        top_block=LastLevelP6P7(res5_channels, out_channels),
        fuse_type=fpn_cfg.FUSE_TYPE,
    )