-
Notifications
You must be signed in to change notification settings - Fork 82
/
nat.py
390 lines (346 loc) · 10.9 KB
/
nat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
"""
Neighborhood Attention Transformer.
To appear in CVPR 2023.
https://arxiv.org/abs/2204.07143
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_, DropPath
from timm.models.registry import register_model
import natten
from natten import NeighborhoodAttention2D as NeighborhoodAttention
# Feature flag selecting which NATTEN constructor keyword is used in NATLayer:
# when True the attention module gets ``rel_pos_bias=True``, otherwise
# ``bias=True``. Per the flag's name, the presence of a top-level
# ``natten.context`` attribute is taken as the marker for NATTEN releases
# after 0.17 (NOTE(review): confirm against the NATTEN changelog).
is_natten_post_017 = hasattr(natten, "context")
# Pretrained classification checkpoints for each registered variant, keyed as
# "<variant>_1k" (presumably ImageNet-1k weights, per the suffix and the
# ``/CLS/`` path). The ``nat_*`` factory functions in this file download the
# matching entry with ``torch.hub.load_state_dict_from_url`` when called with
# ``pretrained=True``.
model_urls = {
    "nat_mini_1k": "https://shi-labs.com/projects/nat/checkpoints/CLS/nat_mini.pth",
    "nat_tiny_1k": "https://shi-labs.com/projects/nat/checkpoints/CLS/nat_tiny.pth",
    "nat_small_1k": "https://shi-labs.com/projects/nat/checkpoints/CLS/nat_small.pth",
    "nat_base_1k": "https://shi-labs.com/projects/nat/checkpoints/CLS/nat_base.pth",
}
class ConvTokenizer(nn.Module):
    """Convolutional patch-embedding stem.

    Two 3x3 convolutions with stride 2 and padding 1 map ``in_chans`` to
    ``embed_dim // 2`` and then to ``embed_dim`` channels. The result is
    permuted from channels-first ``(B, C, H, W)`` to channels-last
    ``(B, H, W, C)``, and ``norm_layer`` (if given) is applied in that layout.

    Args:
        in_chans: Number of input image channels.
        embed_dim: Number of output channels (token width).
        norm_layer: Optional callable invoked as ``norm_layer(embed_dim)``;
            ``None`` means no normalization.
    """

    def __init__(self, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        half_dim = embed_dim // 2
        channel_plan = ((in_chans, half_dim), (half_dim, embed_dim))
        # Both convolutions live in one nn.Sequential so parameter names stay
        # ``proj.0.*`` and ``proj.1.*``, exactly as in the original layout.
        self.proj = nn.Sequential(
            *[
                nn.Conv2d(
                    c_in,
                    c_out,
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
                for c_in, c_out in channel_plan
            ]
        )
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        # Convolve in NCHW, then move channels to the last axis.
        tokens = self.proj(x).permute(0, 2, 3, 1)
        if self.norm is None:
            return tokens
        return self.norm(tokens)
class ConvDownsampler(nn.Module):
    """Stride-2 convolutional downsampler for channels-last feature maps.

    A bias-free 3x3 convolution (stride 2, padding 1) maps ``dim`` to
    ``2 * dim`` channels, and the result is normalized with
    ``norm_layer(2 * dim)``. Input and output are both channels-last
    ``(B, H, W, C)``. The tensor is permuted to channels-first only around the
    convolution.

    Args:
        dim: Number of input channels.
        norm_layer: Normalization class, called as ``norm_layer(2 * dim)``.
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        widened = dim * 2
        self.reduction = nn.Conv2d(
            dim,
            widened,
            kernel_size=(3, 3),
            stride=(2, 2),
            padding=(1, 1),
            bias=False,
        )
        self.norm = norm_layer(widened)

    def forward(self, x):
        channels_first = x.permute(0, 3, 1, 2)
        reduced = self.reduction(channels_first)
        return self.norm(reduced.permute(0, 2, 3, 1))
class Mlp(nn.Module):
    """Two-layer feed-forward network.

    Computes ``drop(fc2(drop(act(fc1(x)))))``. A single ``nn.Dropout`` module
    is reused after the activation and after the output projection.

    Args:
        in_features: Input feature size.
        hidden_features: Hidden size; any falsy value falls back to
            ``in_features``.
        out_features: Output size; any falsy value falls back to
            ``in_features``.
        act_layer: Activation class, instantiated with no arguments.
        drop: Dropout probability.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features if out_features else in_features
        hidden_features = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class NATLayer(nn.Module):
    """One pre-norm Neighborhood Attention Transformer layer.

    Computes ``x = x + drop_path(attn(norm1(x)))`` and then
    ``x = x + drop_path(mlp(norm2(x)))``. When layer scale is enabled, each
    residual branch is first multiplied by a learnable per-channel vector
    (``gamma1`` for attention, ``gamma2`` for the MLP), initialized to
    ``layer_scale``.

    Args:
        dim: Number of channels; used to size the norms, attention and MLP.
        num_heads: Number of attention heads.
        kernel_size: Neighborhood size forwarded to ``NeighborhoodAttention``.
        dilation: Forwarded unchanged to ``NeighborhoodAttention``.
        mlp_ratio: MLP hidden size as a multiple of ``dim``.
        qkv_bias, qk_scale, attn_drop: Forwarded to ``NeighborhoodAttention``.
        drop: Dropout for the attention output projection and the MLP.
        drop_path: Stochastic-depth rate; values ``<= 0`` use ``nn.Identity``.
        act_layer: Activation class used inside the MLP.
        norm_layer: Normalization class, called as ``norm_layer(dim)``.
        layer_scale: Initial layer-scale value. Any real number (including
            subclasses such as ``numpy.float64``) enables layer scale.
            ``None``, booleans and non-numeric values disable it.
    """

    def __init__(
        self,
        dim,
        num_heads,
        kernel_size=7,
        dilation=None,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_scale=None,
    ):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.norm1 = norm_layer(dim)
        # The positional-bias switch is spelled differently depending on the
        # module-level ``is_natten_post_017`` flag: ``rel_pos_bias`` when
        # True, ``bias`` otherwise.
        extra_args = {"rel_pos_bias": True} if is_natten_post_017 else {"bias": True}
        self.attn = NeighborhoodAttention(
            dim,
            kernel_size=kernel_size,
            dilation=dilation,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            **extra_args,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )
        # isinstance (not an exact ``type(...) in [int, float]`` check) so that
        # float subclasses such as numpy.float64 are not silently ignored.
        # bool is excluded explicitly: True/False kept meaning "disabled".
        self.layer_scale = isinstance(layer_scale, (int, float)) and not isinstance(
            layer_scale, bool
        )
        if self.layer_scale:
            init_value = float(layer_scale)
            self.gamma1 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)
            self.gamma2 = nn.Parameter(init_value * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        shortcut = x
        x = self.attn(self.norm1(x))
        if self.layer_scale:
            x = shortcut + self.drop_path(self.gamma1 * x)
            return x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
        x = shortcut + self.drop_path(x)
        return x + self.drop_path(self.mlp(self.norm2(x)))
class NATBlock(nn.Module):
    """One NAT stage: ``depth`` NATLayers, optionally followed by a downsampler.

    Args:
        dim: Channel count for every layer in the stage.
        depth: Number of NATLayers.
        num_heads, kernel_size, mlp_ratio, qkv_bias, qk_scale, drop,
        attn_drop, norm_layer, layer_scale: Forwarded to every NATLayer.
        dilations: Optional per-layer dilations indexed by layer position;
            ``None`` passes ``dilation=None`` to every layer.
        downsample: When true, append ``ConvDownsampler(dim)``, which doubles
            the channel count and applies a stride-2 convolution.
        drop_path: Either a single stochastic-depth rate shared by all layers
            or a per-layer list/tuple indexed by layer position.
    """

    def __init__(
        self,
        dim,
        depth,
        num_heads,
        kernel_size,
        dilations=None,
        downsample=True,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        layer_scale=None,
    ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        # Treat both lists and tuples as per-layer rates. Previously a tuple
        # was forwarded whole to each layer and failed on ``drop_path > 0.0``.
        per_layer_drop_path = isinstance(drop_path, (list, tuple))
        self.blocks = nn.ModuleList(
            [
                NATLayer(
                    dim=dim,
                    num_heads=num_heads,
                    kernel_size=kernel_size,
                    dilation=None if dilations is None else dilations[i],
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if per_layer_drop_path else drop_path,
                    norm_layer=norm_layer,
                    layer_scale=layer_scale,
                )
                for i in range(depth)
            ]
        )
        self.downsample = (
            ConvDownsampler(dim=dim, norm_layer=norm_layer) if downsample else None
        )

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        if self.downsample is None:
            return x
        return self.downsample(x)
class NAT(nn.Module):
    """Neighborhood Attention Transformer image classifier.

    Pipeline:
      1. ConvTokenizer stem, then dropout (``pos_drop``).
      2. ``len(depths)`` NATBlock stages; every stage except the last ends in a
         ConvDownsampler, which doubles the channel count.
      3. ``norm``, then global average pooling over all spatial positions.
      4. Linear head.

    Args:
        embed_dim: Channel width of the first stage; stage ``i`` uses
            ``embed_dim * 2**i``.
        mlp_ratio: MLP hidden-size multiplier for every layer.
        depths: Number of layers in each stage.
        num_heads: Attention heads for each stage.
        drop_path_rate: Final stochastic-depth rate; per-layer rates increase
            linearly from 0 to this value across all layers of all stages.
        in_chans: Number of input image channels.
        kernel_size: Neighborhood size for every attention layer.
        dilations: Optional per-stage sequences of per-layer dilations.
        num_classes: Number of output classes; values ``<= 0`` replace the
            head with ``nn.Identity``.
        qkv_bias, qk_scale, attn_drop_rate, layer_scale: Forwarded to every
            layer.
        drop_rate: Dropout after the stem and inside every layer.
        norm_layer: Normalization class used throughout the network.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        embed_dim,
        mlp_ratio,
        depths,
        num_heads,
        drop_path_rate=0.2,
        in_chans=3,
        kernel_size=7,
        dilations=None,
        num_classes=1000,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        norm_layer=nn.LayerNorm,
        layer_scale=None,
        **kwargs
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_levels = len(depths)
        self.embed_dim = embed_dim
        # Width of the last stage: each earlier stage's downsampler doubles it.
        self.num_features = int(embed_dim * 2 ** (self.num_levels - 1))
        self.mlp_ratio = mlp_ratio
        self.patch_embed = ConvTokenizer(
            in_chans=in_chans, embed_dim=embed_dim, norm_layer=norm_layer
        )
        self.pos_drop = nn.Dropout(p=drop_rate)
        # One stochastic-depth rate per layer, linearly spaced over the whole
        # network and sliced per stage below.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        self.levels = nn.ModuleList()
        for i in range(self.num_levels):
            level = NATBlock(
                dim=int(embed_dim * 2**i),
                depth=depths[i],
                num_heads=num_heads[i],
                kernel_size=kernel_size,
                dilations=None if dilations is None else dilations[i],
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[sum(depths[:i]) : sum(depths[: i + 1])],
                norm_layer=norm_layer,
                downsample=(i < self.num_levels - 1),
                layer_scale=layer_scale,
            )
            self.levels.append(level)
        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = (
            nn.Linear(self.num_features, num_classes)
            if num_classes > 0
            else nn.Identity()
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear and LayerNorm parameters in place.

        Linear: truncated-normal weight (std 0.02) and zero bias.
        LayerNorm: unit weight and zero bias. A LayerNorm may have no affine
        parameters (``elementwise_affine=False``) or no bias (``bias=False``),
        so each parameter is checked for ``None`` before being initialized.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        # Parameter-name substrings to exempt from weight decay; "rpb" is
        # presumably NATTEN's relative positional bias (verify in natten).
        return {"rpb"}

    def forward_features(self, x):
        """Return pooled features of shape ``(B, num_features)``."""
        # ConvTokenizer permutes its output to channels-last (B, H, W, C).
        x = self.patch_embed(x)
        x = self.pos_drop(x)
        for level in self.levels:
            x = level(x)
        # (B, H, W, C) -> (B, H*W, C) -> (B, C, H*W) -> pool -> (B, C, 1).
        x = self.norm(x).flatten(1, 2)
        x = self.avgpool(x.transpose(1, 2))
        return torch.flatten(x, 1)

    def forward(self, x):
        return self.head(self.forward_features(x))
@register_model
def nat_mini(pretrained=False, **kwargs):
    """Build NAT-Mini (depths 3-4-6-5, heads 2-4-8-16, embed_dim 64, mlp_ratio 3).

    Keyword arguments are forwarded to ``NAT`` and override the defaults
    listed here. For example, ``nat_mini(drop_path_rate=0.1)`` now works
    instead of raising ``TypeError`` for a duplicate keyword argument.

    Args:
        pretrained: If true, download the ``"nat_mini_1k"`` checkpoint from
            ``model_urls`` (mapped to CPU) and load it with
            ``load_state_dict``.
        **kwargs: Additional or overriding ``NAT`` constructor arguments.
            NOTE: overriding architectural arguments will generally make the
            pretrained checkpoint fail to load.

    Returns:
        The constructed ``NAT`` model.
    """
    config = dict(
        depths=[3, 4, 6, 5],
        num_heads=[2, 4, 8, 16],
        embed_dim=64,
        mlp_ratio=3,
        drop_path_rate=0.2,
        kernel_size=7,
    )
    # Caller-supplied values win over the defaults above.
    config.update(kwargs)
    model = NAT(**config)
    if pretrained:
        url = model_urls["nat_mini_1k"]
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint)
    return model
@register_model
def nat_tiny(pretrained=False, **kwargs):
    """Build NAT-Tiny (depths 3-4-18-5, heads 2-4-8-16, embed_dim 64, mlp_ratio 3).

    Keyword arguments are forwarded to ``NAT`` and override the defaults
    listed here. For example, ``nat_tiny(drop_path_rate=0.1)`` now works
    instead of raising ``TypeError`` for a duplicate keyword argument.

    Args:
        pretrained: If true, download the ``"nat_tiny_1k"`` checkpoint from
            ``model_urls`` (mapped to CPU) and load it with
            ``load_state_dict``.
        **kwargs: Additional or overriding ``NAT`` constructor arguments.
            NOTE: overriding architectural arguments will generally make the
            pretrained checkpoint fail to load.

    Returns:
        The constructed ``NAT`` model.
    """
    config = dict(
        depths=[3, 4, 18, 5],
        num_heads=[2, 4, 8, 16],
        embed_dim=64,
        mlp_ratio=3,
        drop_path_rate=0.2,
        kernel_size=7,
    )
    # Caller-supplied values win over the defaults above.
    config.update(kwargs)
    model = NAT(**config)
    if pretrained:
        url = model_urls["nat_tiny_1k"]
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint)
    return model
@register_model
def nat_small(pretrained=False, **kwargs):
    """Build NAT-Small (depths 3-4-18-5, heads 3-6-12-24, embed_dim 96, mlp_ratio 2).

    Layer scale is enabled with an initial value of 1e-5. Keyword arguments
    are forwarded to ``NAT`` and override the defaults listed here. For
    example, ``nat_small(drop_path_rate=0.1)`` now works instead of raising
    ``TypeError`` for a duplicate keyword argument.

    Args:
        pretrained: If true, download the ``"nat_small_1k"`` checkpoint from
            ``model_urls`` (mapped to CPU) and load it with
            ``load_state_dict``.
        **kwargs: Additional or overriding ``NAT`` constructor arguments.
            NOTE: overriding architectural arguments will generally make the
            pretrained checkpoint fail to load.

    Returns:
        The constructed ``NAT`` model.
    """
    config = dict(
        depths=[3, 4, 18, 5],
        num_heads=[3, 6, 12, 24],
        embed_dim=96,
        mlp_ratio=2,
        drop_path_rate=0.3,
        layer_scale=1e-5,
        kernel_size=7,
    )
    # Caller-supplied values win over the defaults above.
    config.update(kwargs)
    model = NAT(**config)
    if pretrained:
        url = model_urls["nat_small_1k"]
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint)
    return model
@register_model
def nat_base(pretrained=False, **kwargs):
    """Build NAT-Base (depths 3-4-18-5, heads 4-8-16-32, embed_dim 128, mlp_ratio 2).

    Layer scale is enabled with an initial value of 1e-5. Keyword arguments
    are forwarded to ``NAT`` and override the defaults listed here. For
    example, ``nat_base(drop_path_rate=0.1)`` now works instead of raising
    ``TypeError`` for a duplicate keyword argument.

    Args:
        pretrained: If true, download the ``"nat_base_1k"`` checkpoint from
            ``model_urls`` (mapped to CPU) and load it with
            ``load_state_dict``.
        **kwargs: Additional or overriding ``NAT`` constructor arguments.
            NOTE: overriding architectural arguments will generally make the
            pretrained checkpoint fail to load.

    Returns:
        The constructed ``NAT`` model.
    """
    config = dict(
        depths=[3, 4, 18, 5],
        num_heads=[4, 8, 16, 32],
        embed_dim=128,
        mlp_ratio=2,
        drop_path_rate=0.5,
        layer_scale=1e-5,
        kernel_size=7,
    )
    # Caller-supplied values win over the defaults above.
    config.update(kwargs)
    model = NAT(**config)
    if pretrained:
        url = model_urls["nat_base_1k"]
        checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu")
        model.load_state_dict(checkpoint)
    return model