In [1]:
!pip install torch numpy einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[K     |████████████████████████████████| 44 kB 1.4 MB/s eta 0:00:01
Installing collected packages: einops
Successfully installed einops-0.7.0


In [2]:
# PyTorch (modules, einsum) and NumPy
import torch
from torch import nn, einsum
import numpy as np

# einops: readable tensor reshaping / axis rearrangement
from einops import rearrange

# functional (stateless) counterparts of torch.nn layers
import torch.nn.functional as f

In [None]:
class SwinBlock(nn.Module):
    """A single Swin Transformer block: window attention, then an MLP.

    Each sub-layer is wrapped as ``Residual(PreNorm(dim, ...))``. ``Residual`` and
    ``PreNorm`` are defined elsewhere in the notebook; presumably they compute
    ``x + fn(norm(x))`` -- TODO confirm against their definitions.

    Args:
        dim: feature (channel) size of the tokens.
        heads: number of attention heads.
        head_dim: size of each attention head.
        mlp_dim: hidden size of the feed-forward sub-layer.
        shifted: whether this block uses shifted windows (forwarded to WindowAttention).
        window_size: attention window size (forwarded to WindowAttention).
        relative_pos_embedding: flag forwarded to WindowAttention.

    Within ``StageModule`` this block receives channels-last feature maps
    (the output of ``PatchMerging_conv``); assumed shape (B, H, W, dim).
    """

    def __init__(self, dim, heads, head_dim, mlp_dim, shifted, window_size, relative_pos_embedding):
        super().__init__()
        # Build each sub-layer first, then wrap it; construction order matches
        # attention -> MLP so parameter initialisation order is unchanged.
        window_attention = WindowAttention(
            dim=dim,
            heads=heads,
            head_dim=head_dim,
            shifted=shifted,
            window_size=window_size,
            relative_pos_embedding=relative_pos_embedding,
        )
        self.attention_block = Residual(PreNorm(dim, window_attention))

        feed_forward = FeedForward(dim=dim, hidden_dim=mlp_dim)
        self.mlp_block = Residual(PreNorm(dim, feed_forward))

    def forward(self, x):
        # Attention sub-layer first, then the MLP sub-layer.
        return self.mlp_block(self.attention_block(x))

In [None]:
class PatchMerging_conv(nn.Module):
    """Downsample a feature map with a strided convolution and return it channels-last.

    A square kernel of size ``downscaling_factor`` with an equal stride and no
    padding maps every non-overlapping ``f x f`` patch to one vector of size
    ``out_channels``.

    Args:
        in_channels: channels of the incoming (channels-first) feature map.
        out_channels: channels of the merged output.
        downscaling_factor: kernel size and stride ``f`` of the convolution.

    Input:  (B, in_channels, H, W)
    Output: (B, H // f, W // f, out_channels)  -- floor division because padding is 0.
    """

    def __init__(self, in_channels, out_channels, downscaling_factor):
        super().__init__()
        factor = downscaling_factor
        self.patch_merge = nn.Conv2d(in_channels, out_channels,
                                     kernel_size=factor, stride=factor, padding=0)

    def forward(self, x):
        merged = self.patch_merge(x)        # (B, out_channels, H', W')
        return merged.permute(0, 2, 3, 1)   # -> (B, H', W', out_channels)


In [None]:
class StageModule(nn.Module):
    """One hierarchical stage of the Swin Transformer.

    The stage first downsamples its input with ``PatchMerging_conv``. That is a
    strided convolution which also changes the channel count to
    ``hidden_dimension`` and returns a channels-last tensor. The stage then
    applies ``layers`` Swin blocks that alternate between regular and
    shifted-window attention.

    Args:
        in_channels: channel count of the incoming feature map.
        hidden_dimension: channel count used inside the stage.
        layers: total number of Swin blocks. Must be even because blocks are
            created in (regular, shifted) pairs.
        downscaling_factor: kernel size and stride of the patch-merging conv.
        num_heads: attention heads per block.
        head_dim: size of each attention head.
        window_size: attention window size passed to every SwinBlock.
        relative_pos_embedding: flag passed to every SwinBlock.

    Input:  (B, in_channels, H, W)
    Output: (B, C, H // downscaling_factor, W // downscaling_factor), with
        C == hidden_dimension assuming SwinBlock preserves its input shape.

    Raises:
        AssertionError: if ``layers`` is odd.
    """

    def __init__(self, in_channels, hidden_dimension, layers, downscaling_factor,
                 num_heads, head_dim, window_size, relative_pos_embedding):
        super().__init__()

        # Blocks come in (regular, shifted) pairs, so an odd count cannot be honoured.
        # Without this check, ``layers // 2`` below would silently drop the last block.
        assert layers % 2 == 0, (
            f"Stage layers must be even (regular + shifted block pairs), got {layers}"
        )

        self.patch_partition = PatchMerging_conv(in_channels=in_channels, out_channels=hidden_dimension,
                                                 downscaling_factor=downscaling_factor)

        # Settings shared by both blocks of every pair; only ``shifted`` differs.
        block_kwargs = dict(dim=hidden_dimension, heads=num_heads, head_dim=head_dim,
                            mlp_dim=hidden_dimension * 4, window_size=window_size,
                            relative_pos_embedding=relative_pos_embedding)

        self.layers = nn.ModuleList([])
        for _ in range(layers // 2):
            self.layers.append(nn.ModuleList([
                SwinBlock(shifted=False, **block_kwargs),
                SwinBlock(shifted=True, **block_kwargs),
            ]))

    def forward(self, x):
        # Patch merging downsamples the map (a strided conv, not a crop) and switches to
        # channels-last: (B, C_in, H, W) -> (B, H', W', hidden_dimension).
        x = self.patch_partition(x)
        for regular_block, shifted_block in self.layers:
            x = regular_block(x)
            x = shifted_block(x)
        # Back to channels-first so the next stage's Conv2d can consume it.
        return x.permute(0, 3, 1, 2)


In [None]:
class SwinTransformer(nn.Module):
    """Hierarchical Swin Transformer image classifier.

    The network is a sequence of ``StageModule``s. Stage ``i`` (0-based):
      - works with ``hidden_dim * 2**i`` channels;
      - downsamples its input by ``downscaling_factors[i]``.
    With the defaults this gives four stages with channels C, 2C, 4C, 8C and a
    total downsampling of 4 * 2 * 2 * 2 = 32. The final feature map is averaged
    over its spatial dimensions and fed to a LayerNorm + Linear head.

    Stages are registered as attributes ``stage1``, ``stage2``, ... in order.
    ``state_dict`` keys and parameter-initialisation order therefore match the
    original hand-written four-stage layout.

    Args (keyword-only):
        hidden_dim: channel count C of the first stage. It doubles at every stage.
        layers: number of Swin blocks per stage. Each entry must be even.
        heads: number of attention heads per stage.
        channels: number of input image channels (3 for RGB).
        num_classes: output size of the final Linear layer (default 1000).
        head_dim: size of each attention head.
        window_size: attention window size passed to every stage. Presumably the
            window side length (7 -> 7 x 7 = 49 tokens per window).
        downscaling_factors: per-stage patch-merging factor. The first stage
            divides each side by 4 and later stages by 2.
        relative_pos_embedding: flag passed through to every stage.

    Input:  (B, channels, H, W). H and W presumably must be divisible by the
        total downsampling times ``window_size`` -- TODO confirm with WindowAttention.
    Output: (B, num_classes) logits.

    Raises:
        ValueError: if ``layers``, ``heads`` and ``downscaling_factors`` are empty
            or have different lengths.
    """

    def __init__(self, *, hidden_dim, layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7,
                 downscaling_factors=(4, 2, 2, 2), relative_pos_embedding=True):
        super().__init__()

        num_stages = len(layers)
        # Previously a short config raised IndexError and extra entries were silently ignored.
        if num_stages == 0 or not (num_stages == len(heads) == len(downscaling_factors)):
            raise ValueError(
                "layers, heads and downscaling_factors must be non-empty and of equal length, got "
                f"{len(layers)}, {len(heads)} and {len(downscaling_factors)}"
            )
        self._num_stages = num_stages

        in_channels = channels
        for i in range(num_stages):
            out_channels = hidden_dim * 2 ** i
            stage = StageModule(in_channels=in_channels, hidden_dimension=out_channels, layers=layers[i],
                                downscaling_factor=downscaling_factors[i], num_heads=heads[i], head_dim=head_dim,
                                window_size=window_size, relative_pos_embedding=relative_pos_embedding)
            # Named attributes (stage1, stage2, ...) keep state_dict keys compatible.
            setattr(self, f"stage{i + 1}", stage)
            in_channels = out_channels

        # in_channels now holds the last stage's width (hidden_dim * 8 with 4 stages).
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(in_channels),
            nn.Linear(in_channels, num_classes),
        )

    def forward(self, img):
        x = img
        for i in range(1, self._num_stages + 1):
            x = getattr(self, f"stage{i}")(x)

        # Global average pool over the spatial dims: (B, C, H, W) -> (B, C).
        x = x.mean(dim=[2, 3])

        return self.mlp_head(x)

In [None]:
def swin_t(hidden_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a Swin-T (tiny) model.

    Swin variants (Swin-T, Swin-S, Swin-B, ...) share one architecture and
    differ only in hyperparameters:
      - the embedding width ``hidden_dim`` (C);
      - the number of blocks in the third stage;
      - the attention heads per stage.
    The defaults here are the Swin-T configuration: C = 96, blocks
    (2, 2, 6, 2) and heads (3, 6, 12, 24).

    Any extra keyword arguments (e.g. ``num_classes``, ``window_size``) are
    forwarded unchanged to ``SwinTransformer``.
    """
    swin_t_config = {"hidden_dim": hidden_dim, "layers": layers, "heads": heads}
    return SwinTransformer(**swin_t_config, **kwargs)