In [3]:
import torch
from torch import nn, einsum
from einops import rearrange
import torch.nn.functional as f


In [None]:
def create_mask(window_size, displacement, upper_lower, left_right):  # handles the cyclically shifted patches
    """Build an additive attention mask for a window that straddles a cyclic-shift seam.

    The mask has one row/column per token of a ``window_size x window_size``
    window (tokens flattened row-major). Entries are ``0`` where attention is
    allowed and ``-inf`` where the two tokens lie on opposite sides of the
    seam, so adding the mask to the attention logits removes those pairs
    after softmax.

    Args:
        window_size: Side length of the square attention window.
        displacement: Number of rows/columns that were wrapped around by the shift.
        upper_lower: If true, block attention between the last ``displacement``
            rows of the window and the rows above them.
        left_right: If true, block attention between the last ``displacement``
            columns of the window and the columns to their left.

    Returns:
        A ``(window_size**2, window_size**2)`` float tensor of ``0`` / ``-inf``.
    """
    tokens = window_size ** 2
    mask = torch.zeros(tokens, tokens)
    blocked = float('-inf')

    if upper_lower:
        # Row-major flattening puts the last `displacement` window rows in the
        # final `displacement * window_size` token indices.
        split = displacement * window_size
        mask[-split:, :-split] = blocked  # lower-left section
        mask[:-split, -split:] = blocked  # upper-right section

    if left_right:
        # Expose (row, col) of the query token and (row, col) of the key token
        # so the column-wise seam can be addressed directly.
        mask = mask.view(window_size, window_size, window_size, window_size)
        mask[:, -displacement:, :, :-displacement] = blocked
        mask[:, :-displacement, :, -displacement:] = blocked
        mask = mask.view(tokens, tokens)

    return mask


In [None]:
class WindowAttention(nn.Module):
    """Multi-head self-attention restricted to non-overlapping square windows.

    The input is a channels-last feature map ``(b, n_h, n_w, dim)``. It is cut
    into ``window_size x window_size`` windows and attention is computed
    independently inside each window. When ``shifted`` is true the map is
    cyclically rolled by ``window_size // 2`` before windowing, the windows
    that straddle the wrap-around seam are masked, and the roll is undone on
    the output.

    Args:
        dim: Channel count of the input and output feature map.
        heads: Number of attention heads.
        head_dim: Channels per head; the q/k/v projections have
            ``heads * head_dim`` channels each.
        shifted: Whether to use the cyclically shifted window layout.
        window_size: Side length of each attention window.
        relative_pos_embedding: Stored on the module. This implementation
            always adds the dense ``(window_size**2, window_size**2)`` learned
            bias ``pos_embedding`` regardless of the flag.
    """

    def __init__(self, dim, heads, head_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        inner_dim = head_dim * heads
        self.heads = heads
        # Standard scaled dot-product factor 1 / sqrt(head_dim).
        self.scale = head_dim ** -0.5
        self.window_size = window_size
        self.relative_pos_embedding = relative_pos_embedding
        self.shifted = shifted

        if self.shifted:
            displacement = window_size // 2
            self.cyclic_shift = CyclicShift(-displacement)
            self.cyclic_back_shift = CyclicShift(displacement)

            # Masks are constants, not learnable: requires_grad=False.
            self.upper_lower_mask = nn.Parameter(
                create_mask(window_size=window_size, displacement=displacement,
                            upper_lower=True, left_right=False),
                requires_grad=False)
            self.left_right_mask = nn.Parameter(
                create_mask(window_size=window_size, displacement=displacement,
                            upper_lower=False, left_right=True),
                requires_grad=False)

        # One projection produces q, k and v concatenated on the channel axis.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        # Learned additive bias for every (query token, key token) pair of a window.
        self.pos_embedding = nn.Parameter(torch.randn(window_size ** 2, window_size ** 2))
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x):
        """Apply windowed attention to ``x`` of shape ``(b, n_h, n_w, dim)``.

        ``n_h`` and ``n_w`` must be multiples of ``window_size`` for the
        window rearrangement to succeed. Returns a tensor of the same shape
        as ``x``.
        """
        if self.shifted:
            x = self.cyclic_shift(x)

        _, n_h, n_w, _ = x.shape
        h = self.heads
        ws = self.window_size
        nw_h = n_h // ws  # number of windows along the height
        nw_w = n_w // ws  # number of windows along the width

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # Each of q, k, v becomes (b, heads, windows, tokens-per-window, head_dim).
        q, k, v = map(
            lambda t: rearrange(t, 'b (nw_h w_h) (nw_w w_w) (h d) -> b h (nw_h nw_w) (w_h w_w) d',
                                h=h, w_h=ws, w_w=ws),
            qkv)

        dots = einsum('b h w i d, b h w j d -> b h w i j', q, k) * self.scale
        dots += self.pos_embedding

        if self.shifted:
            # Windows are flattened row-major, so the last nw_w entries are the
            # bottom row of windows and every nw_w-th entry (starting at
            # nw_w - 1) is the right-most window of a row.
            dots[:, :, -nw_w:] += self.upper_lower_mask
            dots[:, :, nw_w - 1::nw_w] += self.left_right_mask

        attn = dots.softmax(dim=-1)

        out = einsum('b h w i j, b h w j d -> b h w i d', attn, v)
        # Stitch the windows back into a (b, n_h, n_w, heads * head_dim) map.
        out = rearrange(out, 'b h (nw_h nw_w) (w_h w_w) d -> b (nw_h w_h) (nw_w w_w) (h d)',
                        h=h, w_h=ws, w_w=ws, nw_h=nw_h, nw_w=nw_w)
        out = self.to_out(out)

        if self.shifted:
            out = self.cyclic_back_shift(out)

        return out

In [None]:
class CyclicShift(nn.Module):
    """Roll a tensor by the same offset along dimensions 1 and 2, wrapping around."""

    def __init__(self, displacement):
        super().__init__()
        # Signed offset applied to both rolled dimensions.
        self.displacement = displacement

    def forward(self, x):
        """Return ``x`` rolled by ``displacement`` along dims 1 and 2."""
        offset = self.displacement
        return x.roll(shifts=(offset, offset), dims=(1, 2))
    

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> hidden_dim), GELU, Linear(hidden_dim -> dim)."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        project = nn.Linear(hidden_dim, dim)
        self.net = nn.Sequential(expand, activation, project)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        transformed = self.net(x)
        return transformed
    
    

In [None]:
class Residual(nn.Module):
    """Wrap a callable so its output is added to its own input (skip connection)."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``."""
        branch = self.fn(x, **kwargs)
        return branch + x
    
    

In [None]:
class PreNorm(nn.Module):
    """Apply LayerNorm over the last dimension before calling the wrapped callable.

    Args:
        dim: Size of the last dimension, passed to ``nn.LayerNorm``.
        fn: Callable invoked on the normalized input.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``."""
        return self.fn(self.norm(x), **kwargs)
    

In [6]:
class SwinBlock(nn.Module):
    """One transformer block: pre-norm window attention then pre-norm MLP, each with a residual.

    Args:
        dim: Channel count of the feature map (input and output).
        heads: Number of attention heads.
        head_dim: Channels per attention head.
        mlp_dim: Hidden width of the feed-forward network.
        shifted: Forwarded to ``WindowAttention`` to select the shifted layout.
        window_size: Side length of the attention windows.
        relative_pos_embedding: Forwarded unchanged to ``WindowAttention``.
    """

    def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        attention = WindowAttention(
            dim=dim,
            heads=heads,
            head_dim=head_dim,
            shifted=shifted,
            window_size=window_size,
            relative_pos_embedding=relative_pos_embedding,
        )
        self.attention_block = Residual(PreNorm(dim, attention))

        feed_forward = FeedForward(dim=dim, hidden_dim=mlp_dim)
        self.mlp_block = Residual(PreNorm(dim, feed_forward))

    def forward(self, x):
        """Run the attention sub-block followed by the MLP sub-block."""
        attended = self.attention_block(x)
        return self.mlp_block(attended)
    
    
    

In [5]:
class PatchMerging_Conv(nn.Module):
    """Downsample a channels-first image/feature map with a strided convolution.

    A ``Conv2d`` whose kernel size and stride both equal ``downscaling_factor``
    turns each non-overlapping ``factor x factor`` patch into one output
    position, and the result is returned channels-last.

    Args:
        in_channels: Channels of the ``(b, c, h, w)`` input.
        out_channels: Channels of the output.
        downscaling_factor: Patch side length; height and width are divided by it.
    """

    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()

        self.patch_merge = nn.Conv2d(in_channels,
                                     out_channels,
                                     kernel_size=downscaling_factor,
                                     stride=downscaling_factor,
                                     padding=0)

    def forward(self, x):
        """Map ``(b, in_channels, h, w)`` to ``(b, h // factor, w // factor, out_channels)``."""
        # Conv output is (b, out_channels, h', w'); move channels to the last axis.
        x = self.patch_merge(x).permute(0, 2, 3, 1)
        return x
    

In [None]:
class StageModule(nn.Module):
    """One stage of the network: patch merging followed by pairs of Swin blocks.

    Args:
        in_channels: Channels of the channels-first input ``(b, c, h, w)``.
        hidden_dimension: Channels produced by the patch merge and used by the blocks.
        layers: Total number of Swin blocks; must be even because blocks are
            created as (regular, shifted) pairs.
        downscaling_factor: Spatial reduction applied by the patch merge.
        num_heads: Attention heads per block.
        head_dim: Channels per attention head.
        window_size: Side length of the attention windows.
        relative_pose_embedding: Forwarded to every ``SwinBlock`` as its
            ``relative_pos_embedding`` argument (parameter spelling kept so
            existing keyword callers continue to work).
    """

    def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor, num_heads, head_dim,
                 window_size, relative_pose_embedding):
        super().__init__()
        assert layers % 2 == 0, 'Stage layers need to be divisible by 2 for regular and shifted block.'

        self.patch_partition = PatchMerging_Conv(in_channels=in_channels, out_channels=hidden_dimension,
                                                 downscaling_factor=downscaling_factor)
        self.layers = nn.ModuleList([])
        for _ in range(layers // 2):
            self.layers.append(nn.ModuleList([
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim,
                          mlp_dim=hidden_dimension * 4, shifted=False, window_size=window_size,
                          relative_pos_embedding=relative_pose_embedding),
                SwinBlock(dim=hidden_dimension, heads=num_heads, head_dim=head_dim,
                          mlp_dim=hidden_dimension * 4, shifted=True, window_size=window_size,
                          relative_pos_embedding=relative_pose_embedding),
            ]))

    def forward(self, x):
        """Map ``(b, in_channels, h, w)`` to ``(b, hidden_dimension, h', w')``."""
        # Patch merge returns channels-last (b, h', w', hidden_dimension).
        x = self.patch_partition(x)
        for regular_block, shifted_block in self.layers:
            x = regular_block(x)
            x = shifted_block(x)

        # Back to channels-first so the next stage's convolution can consume it.
        return x.permute(0, 3, 1, 2)
    
    
    
            
            

In [4]:
class SwinTransformer(nn.Module):
    """Four-stage Swin Transformer image classifier.

    Each stage downsamples spatially and doubles the channel width
    (``hidden_dim``, ``2x``, ``4x``, ``8x``). The final feature map is
    average-pooled over its spatial dimensions and passed through a
    LayerNorm + Linear head.

    Args:
        hidden_dim: Channel width of the first stage.
        layers: Four block counts, one per stage (each must be even).
        heads: Four attention-head counts, one per stage.
        channels: Channels of the input image.
        num_classes: Size of the classification output.
        head_dim: Channels per attention head.
        window_size: Side length of the attention windows.
        downscaling_factors: Four spatial reduction factors, one per stage.
        relative_pos_embedding: Forwarded to every stage.
    """

    def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        super().__init__()

        def build_stage(index, in_channels, out_channels):
            # `relative_pose_embedding` is the keyword spelling StageModule declares.
            return StageModule(in_channels=in_channels, hidden_dimension=out_channels,
                               layers=layers[index], downscaling_factor=downscaling_factors[index],
                               num_heads=heads[index], head_dim=head_dim, window_size=window_size,
                               relative_pose_embedding=relative_pos_embedding)

        self.stage1 = build_stage(0, channels, hidden_dim)
        self.stage2 = build_stage(1, hidden_dim, hidden_dim * 2)
        self.stage3 = build_stage(2, hidden_dim * 2, hidden_dim * 4)
        self.stage4 = build_stage(3, hidden_dim * 4, hidden_dim * 8)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden_dim * 8),
            nn.Linear(hidden_dim * 8, num_classes)
        )

    def forward(self, img):
        """Classify a batch of images ``(b, channels, h, w)``; returns ``(b, num_classes)`` logits."""
        x = self.stage1(img)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        # Global average pool over the two spatial axes of the channels-first map.
        x = x.mean(dim=[2, 3])
        return self.mlp_head(x)
    
    

In [None]:
def swin_t(hidden_dim=96, layers=(2,2,6,2), heads=(3,6,12,24), **kwargs):
    """Construct a ``SwinTransformer`` with these default width, depth and head settings.

    Any extra keyword arguments are passed through to ``SwinTransformer``.
    """
    config = dict(kwargs, hidden_dim=hidden_dim, layers=layers, heads=heads)
    return SwinTransformer(**config)

