Skip to content

Commit

Permalink
update swin model (#281)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ldpe2G committed May 13, 2022
1 parent e473bd3 commit f02edc0
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions libai/models/swin_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
def window_partition(x, window_size):
B, H, W, C = x.shape
x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).view(-1, window_size, window_size, C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows


def window_reverse(windows, window_size, H, W):
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).view(B, H, W, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x


Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(
coords = flow.stack(flow.meshgrid(*[coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = flow.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
Expand Down Expand Up @@ -133,7 +133,9 @@ def forward(self, x, mask):
self.window_size[0] * self.window_size[1],
-1,
) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1) # nH, Wh*Ww, Wh*Ww
relative_position_bias = relative_position_bias.permute(
2, 0, 1
).contiguous() # nH, Wh*Ww, Wh*Ww
unsqueeze_relative_position_bias = relative_position_bias.unsqueeze(0)
attn = attn + unsqueeze_relative_position_bias

Expand Down

0 comments on commit f02edc0

Please sign in to comment.