"""Pyramid Scene Parsing Network"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .segbase import SegBaseModel
from .fcn import _FCNHead

__all__ = ['DeepLabV3', 'get_deeplabv3', 'get_deeplabv3_resnet50_voc', 'get_deeplabv3_resnet101_voc',
           'get_deeplabv3_resnet152_voc', 'get_deeplabv3_resnet50_ade', 'get_deeplabv3_resnet101_ade',
           'get_deeplabv3_resnet152_ade']


class DeepLabV3(SegBaseModel):
    r"""DeepLabV3

    Parameters
    ----------
    nclass : int
        Number of categories for the training dataset.
    backbone : string
        Pre-trained dilated backbone network type (default: 'resnet50'; one of 'resnet50',
        'resnet101' or 'resnet152').
    norm_layer : object
        Normalization layer used in the backbone network (default: :class:`nn.BatchNorm2d`;
        can be replaced by Synchronized Cross-GPU BatchNormalization for multi-GPU training).
    aux : bool
        Whether to use the auxiliary loss branch.
    pretrained_base : bool
        Whether to load ImageNet-pretrained weights for the backbone.

    Reference:
        Chen, Liang-Chieh, et al. "Rethinking atrous convolution for semantic image segmentation."
        arXiv preprint arXiv:1706.05587 (2017).
    """

    def __init__(self, nclass, backbone='resnet50', aux=False, pretrained_base=True, **kwargs):
        super(DeepLabV3, self).__init__(nclass, aux, backbone, pretrained_base=pretrained_base, **kwargs)
        self.head = _DeepLabHead(nclass, **kwargs)
        if self.aux:
            self.auxlayer = _FCNHead(1024, nclass, **kwargs)

        self.__setattr__('exclusive', ['head', 'auxlayer'] if aux else ['head'])

    def forward(self, x):
        size = x.size()[2:]
        _, _, c3, c4 = self.base_forward(x)
        outputs = []
        x = self.head(c4)
        x = F.interpolate(x, size, mode='bilinear', align_corners=True)
        outputs.append(x)

        if self.aux:
            auxout = self.auxlayer(c3)
            auxout = F.interpolate(auxout, size, mode='bilinear', align_corners=True)
            outputs.append(auxout)
        return tuple(outputs)
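
# Shape note: for an (N, 3, H, W) input, forward() returns a tuple whose first element is
# the main prediction of shape (N, nclass, H, W); with aux=True, a second tensor of the
# same shape, produced by the auxiliary FCN head on the stage-3 features, is appended.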


class _DeepLabHead(nn.Module):
    def __init__(self, nclass, norm_layer=nn.BatchNorm2d, norm_kwargs=None, **kwargs):
        super(_DeepLabHead, self).__init__()
        self.aspp = _ASPP(2048, [12, 24, 36], norm_layer=norm_layer, norm_kwargs=norm_kwargs, **kwargs)
        self.block = nn.Sequential(
            nn.Conv2d(256, 256, 3, padding=1, bias=False),
            norm_layer(256, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True),
            nn.Dropout(0.1),
            nn.Conv2d(256, nclass, 1)
        )

    def forward(self, x):
        x = self.aspp(x)
        return self.block(x)
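
# Note: _DeepLabHead produces nclass-channel logits at the backbone's output resolution;
# the bilinear upsampling back to the input size happens in DeepLabV3.forward.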


class _ASPPConv(nn.Module):
    def __init__(self, in_channels, out_channels, atrous_rate, norm_layer, norm_kwargs):
        super(_ASPPConv, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rate, dilation=atrous_rate, bias=False),
            norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )

    def forward(self, x):
        return self.block(x)
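
# Note: with padding equal to the dilation rate, the 3x3 atrous convolution above keeps
# the spatial resolution of its input unchanged while enlarging the receptive field.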


class _AsppPooling(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer, norm_kwargs, **kwargs):
        super(_AsppPooling, self).__init__()
        self.gap = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )

    def forward(self, x):
        size = x.size()[2:]
        pool = self.gap(x)
        out = F.interpolate(pool, size, mode='bilinear', align_corners=True)
        return out
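
# Note: the image-level pooling branch collapses the feature map to 1x1, projects it to
# out_channels, and bilinearly upsamples it back, so every location receives the same
# global context vector.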


class _ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates, norm_layer, norm_kwargs, **kwargs):
        super(_ASPP, self).__init__()
        out_channels = 256
        self.b0 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True)
        )

        rate1, rate2, rate3 = tuple(atrous_rates)
        self.b1 = _ASPPConv(in_channels, out_channels, rate1, norm_layer, norm_kwargs)
        self.b2 = _ASPPConv(in_channels, out_channels, rate2, norm_layer, norm_kwargs)
        self.b3 = _ASPPConv(in_channels, out_channels, rate3, norm_layer, norm_kwargs)
        self.b4 = _AsppPooling(in_channels, out_channels, norm_layer=norm_layer, norm_kwargs=norm_kwargs)

        self.project = nn.Sequential(
            nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
            norm_layer(out_channels, **({} if norm_kwargs is None else norm_kwargs)),
            nn.ReLU(True),
            nn.Dropout(0.5)
        )

    def forward(self, x):
        feat1 = self.b0(x)
        feat2 = self.b1(x)
        feat3 = self.b2(x)
        feat4 = self.b3(x)
        feat5 = self.b4(x)
        x = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
        x = self.project(x)
        return x
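
# Shape sketch (illustrative, assuming the dilated backbone's output stride of 8): for a
# 480x480 input, c4 is roughly (N, 2048, 60, 60); each of the five ASPP branches yields
# (N, 256, 60, 60), the concatenation is (N, 1280, 60, 60), and the 1x1 projection returns
# (N, 256, 60, 60) before the classifier block in _DeepLabHead.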


def get_deeplabv3(dataset='pascal_voc', backbone='resnet50', pretrained=False, root='~/.torch/models',
                  pretrained_base=True, **kwargs):
    acronyms = {
        'pascal_voc': 'pascal_voc',
        'pascal_aug': 'pascal_aug',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data.dataloader import datasets
    model = DeepLabV3(datasets[dataset].NUM_CLASS, backbone=backbone, pretrained_base=pretrained_base, **kwargs)
    if pretrained:
        from .model_store import get_model_file
        # Map the checkpoint onto the caller's device; default to CPU when no 'local_rank' is given.
        device = torch.device(kwargs.get('local_rank', 'cpu'))
        model.load_state_dict(torch.load(get_model_file('deeplabv3_%s_%s' % (backbone, acronyms[dataset]), root=root),
                                         map_location=device))
    return model
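
# Usage sketch (illustrative; assumes the dataset registry is importable and, for
# pretrained=True, that a matching checkpoint exists under `root`):
#   model = get_deeplabv3(dataset='pascal_voc', backbone='resnet101', pretrained=False)
#   model.eval()
#   with torch.no_grad():
#       pred = model(torch.randn(1, 3, 480, 480))[0]  # (1, 21, 480, 480) for Pascal VOC (21 classes)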


def get_deeplabv3_resnet50_voc(**kwargs):
    return get_deeplabv3('pascal_voc', 'resnet50', **kwargs)


def get_deeplabv3_resnet101_voc(**kwargs):
    return get_deeplabv3('pascal_voc', 'resnet101', **kwargs)


def get_deeplabv3_resnet152_voc(**kwargs):
    return get_deeplabv3('pascal_voc', 'resnet152', **kwargs)


def get_deeplabv3_resnet50_ade(**kwargs):
    return get_deeplabv3('ade20k', 'resnet50', **kwargs)


def get_deeplabv3_resnet101_ade(**kwargs):
    return get_deeplabv3('ade20k', 'resnet101', **kwargs)


def get_deeplabv3_resnet152_ade(**kwargs):
    return get_deeplabv3('ade20k', 'resnet152', **kwargs)


if __name__ == '__main__':
    model = get_deeplabv3_resnet50_voc()
    img = torch.randn(2, 3, 480, 480)
    output = model(img)
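    # Illustrative check: the upsampled predictions should match the 480x480 input resolution.
    for i, out in enumerate(output):
        print('output[%d] shape:' % i, tuple(out.shape))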