# Import Package

In [1]:
import torch
from torch import nn, einsum
import torch.nn.functional as f
import numpy as np
from einops import rearrange, repeat

In [2]:
torch.cuda.is_available()

True

# Swin Transformer Class

In [29]:
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        
        self.fn = fn
    
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) +x

In [30]:
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
        
    def forward(self, x, **kwargs):
        # return self.fn(self.norm(x), **kwargs) # swin T v1
        return self.norm(self.fn(x), **kwargs) # swin T v2

In [31]:
torch.manual_seed(0)
B,H,W,C=1,2,2,3
input = torch.randn(B,H,W,C)*100
print("input: ", input)
layer_norm = nn.LayerNorm(C)
output = layer_norm(input)
print("output: ",output)

input:  tensor([[[[ 154.0996,  -29.3429, -217.8789],
          [  56.8431, -108.4522, -139.8595]],

         [[  40.3347,   83.8026,  -71.9258],
          [ -40.3344,  -59.6635,   18.2036]]]])
output:  tensor([[[[ 1.2191,  0.0112, -1.2303],
          [ 1.3985, -0.5173, -0.8813]],

         [[ 0.3495,  1.0120, -1.3615],
          [-0.3948, -0.9787,  1.3735]]]], grad_fn=<NativeLayerNormBackward0>)


In [32]:
x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]).view(3, 3)
print(x)
torch.roll(x, shifts=(1, 1), dims=(0, 1))

tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])


tensor([[9, 7, 8],
        [3, 1, 2],
        [6, 4, 5]])

In [33]:
x = torch.linspace(1,81,81).view(9,9)
print(x)
torch.roll(x, shifts=(-1, -1), dims=(0, 1))

tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18.],
        [19., 20., 21., 22., 23., 24., 25., 26., 27.],
        [28., 29., 30., 31., 32., 33., 34., 35., 36.],
        [37., 38., 39., 40., 41., 42., 43., 44., 45.],
        [46., 47., 48., 49., 50., 51., 52., 53., 54.],
        [55., 56., 57., 58., 59., 60., 61., 62., 63.],
        [64., 65., 66., 67., 68., 69., 70., 71., 72.],
        [73., 74., 75., 76., 77., 78., 79., 80., 81.]])


tensor([[11., 12., 13., 14., 15., 16., 17., 18., 10.],
        [20., 21., 22., 23., 24., 25., 26., 27., 19.],
        [29., 30., 31., 32., 33., 34., 35., 36., 28.],
        [38., 39., 40., 41., 42., 43., 44., 45., 37.],
        [47., 48., 49., 50., 51., 52., 53., 54., 46.],
        [56., 57., 58., 59., 60., 61., 62., 63., 55.],
        [65., 66., 67., 68., 69., 70., 71., 72., 64.],
        [74., 75., 76., 77., 78., 79., 80., 81., 73.],
        [ 2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.,  1.]])

In [34]:
class CyclicShift(nn.Module):
    def __init__(self, displacement):
        super().__init__()
        
        self.displacement = displacement
    
    def forward(self, x):
        return torch.roll(x, shifts=(self.displacement, self.displacement), dims=(1,2))

In [35]:
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )
    
    def forward(self, x):
        return self.net(x)

In [36]:
def create_mask(window_size, displacement, upper_lower, left_right):
    mask = torch.zeros(window_size ** 2, window_size ** 2) # (49,49)
    # print("Original mask: ", mask)
    
    if upper_lower:
#          down lef section
        mask[-displacement * window_size:, :-displacement*window_size] = float('-inf')
#         up right section
        mask[:-displacement * window_size, -displacement*window_size:] = float('-inf')
    
    if left_right:
                                                                # to handle last vertical patches
        mask = rearrange(mask, '(h1 w1) (h2 w2) -> h1 w1 h2 w2', h1=window_size,h2=window_size)
        
        mask[:, -displacement:, :, :-displacement] = float('inf')
        mask[:, :-displacement, :, -displacement:] = float('inf')
        mask = rearrange(mask, 'h1 w1 h2 w2 -> (h1 w1) (h2 w2)')
        
    return mask

In [37]:
create_mask(window_size=3, displacement=1, upper_lower=False, left_right=True)

tensor([[0., 0., inf, 0., 0., inf, 0., 0., inf],
        [0., 0., inf, 0., 0., inf, 0., 0., inf],
        [inf, inf, 0., inf, inf, 0., inf, inf, 0.],
        [0., 0., inf, 0., 0., inf, 0., 0., inf],
        [0., 0., inf, 0., 0., inf, 0., 0., inf],
        [inf, inf, 0., inf, inf, 0., inf, inf, 0.],
        [0., 0., inf, 0., 0., inf, 0., 0., inf],
        [0., 0., inf, 0., 0., inf, 0., 0., inf],
        [inf, inf, 0., inf, inf, 0., inf, inf, 0.]])

In [38]:
def get_relative_distances(window_size):
    indices = torch.tensor(np.array([[x, y] for x in range(window_size) for y in range(window_size)]))
    distances = indices[None, :, :] - indices[:, None, :]
    return distances

In [39]:
class WindowAttention(nn.Module):
    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        inner_dim = head_dim * heads
        self.heads = heads
        self.scale = head_dim ** -0.5          # scaling dot product inside softmax
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted
        
        self.tau = nn.Parameter(torch.tensor(0.01), requires_grad=True)
        
        if self.shifted:
            displacement = window_size // 2
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)
            
            
            # (49,49): masks are NOT learnable parameters; requires_grad=False
            self.upper_lower_mask = nn.Parameter(create_mask(window_size=window_size,
                                                            displacement=displacement,
                                                            upper_lower=True, left_right=False), 
                                                 requires_grad=False)
            
            self.left_right_mask = nn.Parameter(create_mask(window_size=window_size,
                                                            displacement=displacement,
                                                            upper_lower=False, left_right=True), 
                                                 requires_grad=False)
            
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        # dim = (96, 192, 384,384) and inner_dim=head_dim * heads, can also use C*3
        
        # self.pos_embedding= nn.Parameter(torch.randn(window_size**2, window_size**2)) #(49, 49)
        
        if self.relative_pos_embedding:
            self.relative_indicies = get_relative_distances(window_size) + window_size - 1
            self.pos_embedding = nn.Parameter(torch.randn(2*window_size -1, 2*window_size -1))
        else:
            self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))
        
        self.to_out = nn.Linear(inner_dim, dim)
        # inner_dim = head_dim * heads = C, dim=hidden_dim=(96,192,384,768)
        
    def forward(self, x):
        
        if self.shifted:
            x = self.cyclic_shift(x)
            
        b, n_h, n_w, _, h = *x.shape, self.heads
        
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        
        nw_h = n_h // self.window_size
        nw_w = n_w // self.window_size
        
        q,k,v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
            h = h, w_h = self.window_size, w_w = self.window_size), qkv
        )
        
        # dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale
        
        q = f.normalize(q, p=2, dim=-1)
        k = f.normalize(k, p=2, dim=-1)
        
        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) / self.tau
        
        
        if self.relative_pos_embedding:
            dots += self.pos_embedding[self.relative_indicies[:,:,0], self.relative_indicies[:,:,1]]
        else:
            dots += self.pos_embedding
            
        if self.shifted:
            dots[:,:, -nw_w:] += self.upper_lower_mask
            dots[:,:, nw_w-1:: nw_w] += self.left_right_mask
        
        attn = dots.softmax(dim=-1)
        
        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=self.window_size, w_w=self.window_size, nw_h=nw_h, nw_w=nw_w)
        out = self.to_out(out)

        if self.shifted:
            out = self.cyclic_back_shift(out)
            
        return out

In [40]:
class SwinBlock(nn.Module):
    def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size,
                relative_pos_embedding):
        super().__init__()
        self.attention_block = Residual(PreNorm(dim, WindowAttention(dim=dim,
                                                                    heads=heads,
                                                                    head_dim=head_dim,
                                                                    shifted=shifted,
                                                                    window_size=window_size,
                                                                    relative_pos_embedding=relative_pos_embedding)))
        self.mlp_block = Residual(PreNorm(dim, FeedForward(dim=dim, hidden_dim=mlp_dim)))
        
    def forward(self, x):
        x = self.attention_block(x)
        x = self.mlp_block(x)
        return x

In [41]:
class PatchMerging_Conv(nn.Module):
    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        
        self.downscaling_factor = downscaling_factor
        
        self.patch_merge = nn.Conv2d(in_channels,
                                    out_channels,
                                    kernel_size=downscaling_factor,
                                    stride=downscaling_factor,
                                    padding=0)
        
    def forward(self, x):
        x = self.patch_merge(x).permute(0, 2, 3, 1)
        return x
    
class PatchMerging(nn.Module):
    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        
        self.downscaling_factor = downscaling_factor
        
        self.patch_merge = nn.Unfold(
                                    kernel_size=downscaling_factor,
                                    stride=downscaling_factor,
                                    padding=0)
        
        self.linear = nn.Linear(in_channels * downscaling_factor ** 2, out_channels)
        
    def forward(self, x):
        b,c,h,w = x.shape
        new_h, new_w = h // self.downscaling_factor, w // self.downscaling_factor
        x = self.patch_merge(x).view(b, -1, new_h, new_w).permute(0, 2, 3, 1)
        x = self.linear(x)
        return x

In [42]:
class StageModule(nn.Module):
    def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads,
                head_dim, window_size, relative_pos_embedding):
        super().__init__()
        assert layers %2==0,'Stage layers need to be divisible by 2 for regular and shifted blocks'
        # self.patch_partition = PatchMerging_Conv(in_channels=in_channels, out_channels=hidden_dimension,
        #                                         downscaling_factor=downscaling_factor)
        
        self.patch_partition = PatchMerging(in_channels=in_channels, out_channels=hidden_dimension,
                                        downscaling_factor=downscaling_factor)
        
        self.layers = nn.ModuleList([])
        for _ in range(layers // 2):
            self.layers.append(nn.ModuleList([
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, 
                         mlp_dim=hidden_dimension*4, shifted=False, window_size=window_size,
                         relative_pos_embedding=relative_pos_embedding),
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim, 
                         mlp_dim=hidden_dimension*4, shifted=True, window_size=window_size,
                         relative_pos_embedding=relative_pos_embedding)
            ]))
    
    def forward(self, x):
        # print('before patching merge: ', x.size()) (1,(3,96,192,384),(224,56,28,14),(224,56,28,14))
        x = self.patch_partition(x)
        # print('after patching merge: ', x.size()) (1, (56,28,14,7), (56,28,14,7),(96,192,384,768))
        for regular_block, shifted_block in self.layers:
            x=regular_block(x) #(1, (56,28,14,7), (56,28,14,7),(96,192,384,768))
            
            x=shifted_block(x) #(1, (56,28,14,7), (56,28,14,7),(96,192,384,768))
        return x.permute(0,3,1,2) #(1,768,7,7)

In [43]:
class SwinTransformer(nn.Module):
    def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=1000, head_dim=32,
            window_size=7, downscaling_factors=(4,2,2,2), relative_pos_embedding=True):
        super().__init__()
        
        self.stage1 = StageModule(in_channels=channels,hidden_dimension=hidden_dim,layers=layers[0],
                                 downscaling_factor=downscaling_factors[0], num_heads=heads[0],
                                 head_dim=head_dim, window_size=window_size, 
                                 relative_pos_embedding=relative_pos_embedding)
        
        self.stage2 = StageModule(in_channels=hidden_dim,hidden_dimension=hidden_dim*2, layers=layers[1],
                                 downscaling_factor=downscaling_factors[1], num_heads=heads[1],
                                 head_dim=head_dim, window_size=window_size, 
                                 relative_pos_embedding=relative_pos_embedding)
        
        self.stage3 = StageModule(in_channels=hidden_dim*2, hidden_dimension=hidden_dim*4, layers=layers[2],
                                 downscaling_factor=downscaling_factors[2], num_heads=heads[2],
                                 head_dim=head_dim, window_size=window_size, 
                                 relative_pos_embedding=relative_pos_embedding)
        
        self.stage4 = StageModule(in_channels=hidden_dim*4, hidden_dimension=hidden_dim*8, layers=layers[3],
                                 downscaling_factor=downscaling_factors[3], num_heads=heads[3],
                                 head_dim=head_dim, window_size=window_size, 
                                 relative_pos_embedding=relative_pos_embedding)
        
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim*8),
            nn.Linear(hidden_dim*8, num_classes)
        )
        
    def forward(self, img):
        x = self.stage1(img)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = x.mean(dim=[2,3])
        return self.mlp_head(x)

In [44]:
def swin_t(hidden_dim=96, layers=(2,2,6,2), heads=(3,6,12,24), **kwargs):
    return SwinTransformer(hidden_dim=hidden_dim, layers=layers, heads=heads, **kwargs)

In [47]:
net = swin_t(
    hidden_dim=96,
    layers=(2,2,6,2),
    heads=(3,6,12,24),
    channels=3,
    num_classes=3,
    head_dim=32,
    window_size=7,
    downscaling_factors=(4,2,2,2),
    relative_pos_embedding=True
)

dummy_x = torch.randn(1,3,224,224)
logits = net(dummy_x) # (1, 3)
# print("network: ", net)
print(logits)

tensor([[nan, nan, nan]], grad_fn=<AddmmBackward0>)
