In [2]:
import torch
import numpy as np
from torch import nn, einsum
import torch.nn.functional as f

import einops
from einops import rearrange

In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP: ``dim -> hidden_dim -> dim`` with a GELU in between."""

    def __init__(self, hidden_dim, dim):
        """
        Args:
            hidden_dim: width of the intermediate projection.
            dim: size of the last axis of both the input and the output.
        """
        super().__init__()
        # Layers are created in the same order as they are applied.
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.GELU()
        project = nn.Linear(hidden_dim, dim)
        # Stored as a Sequential called `network`, so parameter keys are
        # `network.0.*` and `network.2.*`.
        self.network = nn.Sequential(expand, activation, project)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        out = self.network(x)
        return out

In [None]:
class WindowAttention(nn.Module):
    """Placeholder for the windowed self-attention layer.

    The attention computation has not been written yet. This class only
    records its configuration so that enclosing modules can be constructed;
    calling it raises ``NotImplementedError``.
    """

    def __init__(self, dim, fn=None, *, num_heads=None, head_dim=None,
                 shifted=None, window_size=None, rel_pos_emb=None):
        """
        Args:
            dim: feature dimension handed over by the caller.
            fn: optional wrapped callable (kept for backward compatibility
                with the original ``(dim, fn)`` signature).
            num_heads, head_dim, shifted, window_size, rel_pos_emb:
                keyword-only settings that ``SwinBlock`` passes in; they are
                stored unchanged for the future implementation.
        """
        # Without this call nn.Module cannot register attributes/submodules.
        super().__init__()
        self.dim = dim
        self.fn = fn
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.shifted = shifted
        self.window_size = window_size
        self.rel_pos_emb = rel_pos_emb

    def forward(self, x, **kwargs):
        """Not implemented yet.

        Raises:
            NotImplementedError: always. Failing loudly here is clearer than
                returning ``None`` and breaking later inside ``Residual``.
        """
        # TODO: implement windowed attention.
        raise NotImplementedError("WindowAttention.forward is not implemented yet")

In [None]:
class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input, then call ``fn`` on the result."""

    def __init__(self, fn, dim):
        """
        Args:
            fn: callable invoked on the normalised input.
            dim: normalised shape passed to ``nn.LayerNorm``.
        """
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``.

        The input is normalised exactly once; running LayerNorm a second time
        would apply its learned scale and shift twice.
        """
        normed = self.norm(x)
        return self.fn(normed, **kwargs)

In [None]:
class Residual(nn.Module):
    """Skip-connection wrapper: returns ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        """
        Args:
            fn: callable whose output is added to its own input.
        """
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Evaluate the wrapped callable and add the untouched input to it."""
        branch = self.fn(x, **kwargs)
        return branch + x

In [6]:
class SwinBlock(nn.Module):
    """One transformer block: an attention sub-block followed by an MLP sub-block.

    Each sub-block is built as ``Residual(PreNorm(<layer>, dim))``.
    """

    def __init__(self, dim, num_heads, head_dim, mlp_dim, shifted, window_size, rel_pos_emb):
        """
        Args:
            dim: feature dimension used by both LayerNorms and both layers.
            num_heads, head_dim, shifted, window_size, rel_pos_emb:
                forwarded unchanged to ``WindowAttention``.
            mlp_dim: hidden width of the ``FeedForward`` MLP.
        """
        super().__init__()
        attention = WindowAttention(
            dim=dim,
            num_heads=num_heads,
            head_dim=head_dim,
            shifted=shifted,
            window_size=window_size,
            rel_pos_emb=rel_pos_emb,
        )
        self.attention_block = Residual(PreNorm(attention, dim))
        # FeedForward's parameter is `hidden_dim`; passing `hid_dim=` raises TypeError.
        self.mlp_block = Residual(PreNorm(FeedForward(hidden_dim=mlp_dim, dim=dim), dim))

    def forward(self, x):
        """Apply the attention sub-block, then the MLP sub-block."""
        x = self.attention_block(x)
        x = self.mlp_block(x)
        return x
        

In [8]:
class PatchMerging_Conv(nn.Module):
    """Downsample with a strided convolution and move axis 1 to the last position."""

    def __init__(self, in_channels, out_channels, down_scaling_fact):
        """
        Args:
            in_channels: input channel count for the convolution.
            out_channels: output channel count for the convolution.
            down_scaling_fact: used as both kernel size and stride, so the
                convolution windows do not overlap.
        """
        # Must be spelled `__init__` (double underscores); with `_init_` Python
        # never runs this method and `patch_merge` is never created.
        super().__init__()
        self.patch_merge = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=down_scaling_fact,
            stride=down_scaling_fact,
            padding=0,
        )

    def forward(self, x):
        """Convolve ``x`` and permute the result with ``(0, 2, 3, 1)``."""
        x = self.patch_merge(x).permute(0, 2, 3, 1)
        return x

In [5]:
class StageModule(nn.Module):
    """One stage: a ``PatchMerging_Conv`` followed by pairs of ``SwinBlock``s.

    Each pair is one block built with ``shifted=False`` and one built with
    ``shifted=True``.
    """

    def __init__(self, in_channel, hid_dim, layers, down_scaling_factor, num_heads, head_dim, window_size, rel_pos_emb):
        """
        Args:
            in_channel: channel count fed to the patch-merging convolution.
            hid_dim: channel count produced by patch merging and used by every block.
            layers: total number of blocks in the stage; must be even.
            down_scaling_factor: downscaling factor handed to ``PatchMerging_Conv``.
            num_heads, head_dim, window_size, rel_pos_emb: forwarded to each ``SwinBlock``.
        """
        # Inheriting from nn.Module is required so the stage is callable and
        # its parameters are registered with the parent model.
        super().__init__()
        assert layers % 2 == 0, 'number of layers should be even'
        # PatchMerging_Conv names this argument `down_scaling_fact`.
        self.patch_partition = PatchMerging_Conv(
            in_channels=in_channel,
            out_channels=hid_dim,
            down_scaling_fact=down_scaling_factor,
        )
        self.layers = nn.ModuleList([])
        for _ in range(layers // 2):
            self.layers.append(nn.ModuleList([
                SwinBlock(dim=hid_dim, num_heads=num_heads, head_dim=head_dim, mlp_dim=hid_dim * 4,
                          shifted=False, window_size=window_size, rel_pos_emb=rel_pos_emb),
                SwinBlock(dim=hid_dim, num_heads=num_heads, head_dim=head_dim, mlp_dim=hid_dim * 4,
                          shifted=True, window_size=window_size, rel_pos_emb=rel_pos_emb),
            ]))

    def forward(self, x):
        """Run patch merging, then every (regular, shifted) block pair in order."""
        x = self.patch_partition(x)
        for regular, shifted in self.layers:
            x = regular(x)
            x = shifted(x)

        # Inverse of the (0, 2, 3, 1) permutation applied inside PatchMerging_Conv.
        return x.permute(0, 3, 1, 2)

In [3]:

class SwinTransformer(nn.Module):
    """Four ``StageModule``s followed by a LayerNorm + Linear classification head.

    Stage widths are ``hid_dim``, ``hid_dim*2``, ``hid_dim*4`` and ``hid_dim*8``.
    """

    def __init__(self, *, hid_dim, layers, heads, channels=3, num_classes=1000, head_dim=32, window_size=7, down_scaling_fact=(4,2,2,2), rel_pos_emb = True):
        """
        Args (all keyword-only):
            hid_dim: channel width of the first stage.
            layers: indexable with four entries; number of blocks per stage.
            heads: indexable with four entries; attention heads per stage.
            channels: channel count of the input passed to the first stage.
            num_classes: output size of the final linear layer.
            head_dim: forwarded to every stage.
            window_size: forwarded to every stage.
            down_scaling_fact: indexable with four entries; downscaling factor per stage.
            rel_pos_emb: forwarded to every stage.
        """
        super().__init__()
        # Settings shared by all four stages.
        shared = dict(head_dim=head_dim, window_size=window_size, rel_pos_emb=rel_pos_emb)
        # StageModule's parameter is `down_scaling_factor`; passing
        # `down_scaling_fact=` raises TypeError.
        self.stage1 = StageModule(in_channel=channels, hid_dim=hid_dim, layers=layers[0],
                                  down_scaling_factor=down_scaling_fact[0], num_heads=heads[0], **shared)
        self.stage2 = StageModule(in_channel=hid_dim, hid_dim=hid_dim * 2, layers=layers[1],
                                  down_scaling_factor=down_scaling_fact[1], num_heads=heads[1], **shared)
        self.stage3 = StageModule(in_channel=hid_dim * 2, hid_dim=hid_dim * 4, layers=layers[2],
                                  down_scaling_factor=down_scaling_fact[2], num_heads=heads[2], **shared)
        self.stage4 = StageModule(in_channel=hid_dim * 4, hid_dim=hid_dim * 8, layers=layers[3],
                                  down_scaling_factor=down_scaling_fact[3], num_heads=heads[3], **shared)

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hid_dim * 8),
            nn.Linear(hid_dim * 8, num_classes),
        )

    def forward(self, x):
        """Run the four stages, average over axes 2 and 3, then apply the head."""
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        # Presumably axes 2 and 3 are the spatial axes at this point — TODO confirm.
        x = x.mean(dim=[2, 3])
        return self.mlp_head(x)

In [None]:
def swin_t(hid_dim=96, layers=(2, 2, 6, 2), heads=(3, 6, 12, 24), **kwargs):
    """Build a ``SwinTransformer`` with this function's default sizes.

    Args:
        hid_dim: first-stage width (default 96).
        layers: blocks per stage (default ``(2, 2, 6, 2)``).
        heads: attention heads per stage (default ``(3, 6, 12, 24)``).
        **kwargs: any further keyword arguments, passed through to
            ``SwinTransformer`` unchanged.

    Returns:
        The constructed ``SwinTransformer``.
    """
    settings = {"hid_dim": hid_dim, "layers": layers, "heads": heads, **kwargs}
    return SwinTransformer(**settings)