In [1]:
####

In [16]:
import torch
from torch import nn

In [3]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [8]:
def window_partition(x, window_size):
    """Cut a channels-last feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, H, W, C).
        window_size: side length of each window; H and W are split into
            ``H // window_size`` and ``W // window_size`` tiles.

    Returns:
        Tensor of shape (num_windows * B, window_size, window_size, C), ordered
        batch-major and then row-major over the window grid.
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size

    # Expose the window grid and the in-window offsets as separate axes.
    tiled = x.view(batch, rows, window_size, cols, window_size, channels)
    # Put the two grid axes next to each other so every window is one contiguous slab.
    tiled = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiled.view(-1, window_size, window_size, channels)

In [None]:
# Smoke test: split a random channels-last batch into 4x4 windows and show the result shape.
window_size = 4
x = torch.randn(2, 224, 224, 3).to(device)
z = window_partition(x, window_size)
z.shape

In [10]:
def window_reverse(windows, window_size, H, W):
    """Stitch windows back into channels-last feature maps.

    This undoes the layout produced by ``window_partition``.

    Args:
        windows: tensor whose leading dimension enumerates windows (batch-major,
            then row-major over the window grid); the remaining dimensions hold
            the ``window_size * window_size`` positions and the channels.
        window_size: side length of each window.
        H: height of the feature map to rebuild.
        W: width of the feature map to rebuild.

    Returns:
        Tensor of shape (B, H, W, C), where B is derived from the number of windows.
    """
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)

    # Recover the (grid_row, grid_col, in_row, in_col) axes ...
    grid = windows.view(batch, H // window_size, W // window_size, window_size, window_size, -1)
    # ... then interleave grid and in-window axes back into plain rows and columns.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(batch, H, W, -1)

In [None]:
# Smoke test: merge the windows from the previous cell back into (B, H, W, C) maps.
y = window_reverse(windows=z, window_size=4, H=224, W=224)
y.shape

In [14]:
class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> (GELU) -> Dropout -> Linear -> Dropout.

    Args:
        in_channels: size of the last dimension of the input.
        hidden_dim: width of the hidden layer; defaults to ``in_channels``.
        out_channels: size of the last dimension of the output; defaults to
            ``in_channels``.
        use_activation: when True a GELU follows the first linear layer; when
            False the hidden features are passed through unchanged.
        dropout: dropout probability applied after each linear layer.
    """

    def __init__(self,
                 in_channels,
                 hidden_dim=None,
                 out_channels=None,
                 use_activation=True,
                 dropout=0.0):
        super(MLP, self).__init__()

        self.use_activation = use_activation
        out_channels = out_channels or in_channels
        hidden_dim = hidden_dim or in_channels

        self.fc1 = nn.Linear(in_channels, hidden_dim)
        # Always define `activation`: forward() calls it unconditionally, so leaving it
        # undefined when use_activation=False raised AttributeError on the first call.
        self.activation = nn.GELU() if self.use_activation else nn.Identity()
        self.fc2 = nn.Linear(hidden_dim, out_channels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the network to the last dimension of ``x``."""
        x = self.fc1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

In [100]:
class Window_Attention(nn.Module):
    """Multi-head self-attention computed independently inside each window.

    This implementation has no relative position bias; ``window_size`` is only
    stored on the module.

    Args:
        dim: number of input (and output) channels.
        window_size: side length of the windows the tokens come from.
        num_heads: number of attention heads; ``dim`` is split evenly across them.
        qk_scale: scaling applied to the queries; defaults to ``head_dim ** -0.5``.
        attn_dropout: dropout probability applied to the attention weights.
        proj_drop: dropout probability applied after the output projection.
    """

    def __init__(self,
                 dim,
                 window_size,
                 num_heads,
                 qk_scale=None,
                 attn_dropout=0.0,
                 proj_drop=0.0):
        super(Window_Attention, self).__init__()

        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attn_dropout = nn.Dropout(attn_dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_dropout = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(-1)

    def forward(self, x, mask=None):
        """Attend over the tokens of each window.

        Args:
            x: tensor of shape (num_windows * B, N, C), N being the tokens per window.
            mask: optional additive mask of shape (num_windows, N, N); the same
                mask is applied to every head and every image in the batch.
                Large negative entries suppress attention between a token pair.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale

        attn = q @ k.transpose(-2, -1)  # (B_, heads, N, N)

        if mask is not None:
            nW = mask.shape[0]
            # The mask may have been built on another device/dtype than the activations.
            mask = mask.to(device=attn.device, dtype=attn.dtype)
            # Fix: the mask must actually be added to the logits. Broadcast it over the
            # batch (dim 0) and head (dim 2) axes; previously the logits were only
            # reshaped back and forth, so the mask had no effect.
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = self.softmax(attn)
        attn = self.attn_dropout(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_dropout(x)
        return x


In [115]:
class Swin_Transformer_Block(nn.Module):
    """One Swin block: (shifted) window attention followed by an MLP, both residual.

    Args:
        in_channels: channel dimension of the tokens.
        input_res: ``(H, W)`` resolution of the token grid.
        num_heads: number of attention heads.
        window_size: side length of the attention windows.
        shift_size: cyclic shift applied before windowing; 0 disables shifting.
        mlp_ratio: hidden width of the MLP as a multiple of ``in_channels``.
        qk_scale: optional override of the query scaling used by the attention.
        drop: dropout probability for the attention output projection and the MLP.
        attn_drop: dropout probability for the attention weights.
    """

    def __init__(self,
                 in_channels,
                 input_res,
                 num_heads,
                 window_size=7,
                 shift_size=0,
                 mlp_ratio=4.0,
                 qk_scale=None,
                 drop=0.0,
                 attn_drop=0.0):
        super(Swin_Transformer_Block, self).__init__()

        self.in_channels = in_channels
        self.input_res = input_res
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.qk_scale = qk_scale

        if min(self.input_res) <= self.window_size:
            # A single window already covers the whole map: shrink the window to the
            # map and disable shifting. (Fix: this used to assign to a misspelled
            # attribute, so the shift was never actually turned off.)
            self.shift_size = 0
            self.window_size = min(self.input_res)

        self.norm1 = nn.LayerNorm(in_channels)
        self.attn = Window_Attention(
            self.in_channels, self.window_size, self.num_heads, self.qk_scale, attn_drop, drop
        )

        # Placeholder on the residual branches; currently a no-op.
        self.dropout_path = nn.Identity()
        self.norm2 = nn.LayerNorm(in_channels)
        mlp_hidden_dim = int(in_channels * self.mlp_ratio)
        self.mlp = MLP(self.in_channels, mlp_hidden_dim, dropout=drop)

        if self.shift_size > 0:
            # Label each of the 3x3 regions created by the cyclic shift with its own id.
            H, W = self.input_res
            img_mask = torch.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # Token pairs from different regions get -100, pairs from the same region get 0.
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        # Register as a buffer so the mask follows the module across .to()/.cuda();
        # non-persistent so the state_dict keys stay exactly as before.
        self.register_buffer("attn_mask", attn_mask, persistent=False)

    def forward(self, x):
        """Apply the block.

        Args:
            x: tensor of shape (B, H * W, C) with ``(H, W) == input_res``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        H, W = self.input_res
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        # LayerNorm is not in-place, so the input itself can serve as the residual.
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Cyclic shift so that the windows straddle the previous block's window borders.
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        x_window = window_partition(shifted_x, self.window_size)
        x_window = x_window.view(-1, self.window_size * self.window_size, C)

        attn_window = self.attn(x_window, mask=self.attn_mask)
        shifted_x = window_reverse(attn_window, self.window_size, H, W)

        # Undo the cyclic shift.
        if self.shift_size > 0:
            x = torch.roll(shifted_x, (self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = x.view(B, H * W, C)
        x = shortcut + self.dropout_path(x)
        x = x + self.dropout_path(self.mlp(self.norm2(x)))
        return x


In [116]:
class Patch_Merging(nn.Module):
    """Downsample a token grid by 2x in each spatial direction.

    Each 2x2 neighbourhood is concatenated along the channel axis (4 * C),
    layer-normalized and linearly projected to 2 * C channels.

    Args:
        input_res: ``(H, W)`` resolution of the incoming token grid.
        in_channels: channel dimension C of the incoming tokens.
    """

    def __init__(self, input_res, in_channels):
        super().__init__()

        merged_channels = 4 * in_channels
        self.input_res = input_res
        self.in_channels = in_channels
        self.reduction = nn.Linear(merged_channels, 2 * in_channels, bias=False)
        self.norm = nn.LayerNorm(merged_channels)

    def forward(self, x):
        """Merge patches of ``x`` with shape (B, H * W, C) into (B, H/2 * W/2, 2 * C)."""
        H, W = self.input_res
        batch, num_tokens, channels = x.shape
        assert num_tokens == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f'x size {H} {W} are not even'

        grid = x.view(batch, H, W, channels)
        # Gather the four members of every 2x2 neighbourhood, in the order
        # (row 0, col 0), (row 1, col 0), (row 0, col 1), (row 1, col 1).
        neighbours = [grid[:, row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
        merged = torch.cat(neighbours, -1).view(batch, -1, 4 * channels)

        return self.reduction(self.norm(merged))

In [121]:
class Layer(nn.Module):
    """One stage of the network: a stack of Swin blocks plus optional downsampling.

    Blocks at even positions use regular windows (shift 0); blocks at odd
    positions use windows shifted by ``window_size // 2``.

    Args:
        in_channels: channel dimension of the tokens in this stage.
        input_res: ``(H, W)`` resolution of the token grid in this stage.
        depth: number of blocks.
        num_heads: attention heads per block.
        window_size: side length of the attention windows.
        mlp_ratio: MLP hidden width as a multiple of ``in_channels``.
        qk_scale: optional override of the query scaling.
        drop: dropout probability passed to every block.
        attn_drop: attention dropout probability passed to every block.
        downsample: optional callable ``downsample(input_res, in_channels)``
            returning the module applied after the blocks.
    """

    def __init__(self,
                 in_channels,
                 input_res,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.0,
                 qk_scale=None,
                 drop=0.0,
                 attn_drop=0.0,
                 downsample=None):
        super().__init__()

        self.in_channels = in_channels
        self.input_res = input_res
        self.depth = depth

        stack = []
        for index in range(depth):
            # Alternate between regular and shifted windows.
            shift = window_size // 2 if index % 2 else 0
            stack.append(
                Swin_Transformer_Block(
                    self.in_channels,
                    self.input_res,
                    num_heads,
                    window_size,
                    shift_size=shift,
                    mlp_ratio=mlp_ratio,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                )
            )
        self.blocks = nn.ModuleList(stack)

        self.downsample = downsample(self.input_res, in_channels) if downsample is not None else None

    def forward(self, x):
        """Run ``x`` through every block, then through the downsampler if there is one."""
        for blk in self.blocks:
            x = blk(x)
        if self.downsample is None:
            return x
        return self.downsample(x)


In [122]:
class Patch_Embedding(nn.Module):
    """Turn an image into a sequence of patch tokens with a strided convolution.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of the (square) patches.
        in_channels: number of image channels.
        embed_dim: channel dimension of the produced tokens.
        norm_layer: optional callable ``norm_layer(embed_dim)`` building a
            normalization module applied to the tokens; None disables it.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_channels=3,
                 embed_dim=96,
                 norm_layer=None):
        super().__init__()

        # Both sizes are stored as (height, width) pairs.
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)
        self.patch_res = [side // step for side, step in zip(self.img_size, self.patch_size)]
        self.num_patches = self.patch_res[0] * self.patch_res[1]

        self.in_channels = in_channels
        self.embed_dim = embed_dim

        # kernel == stride, so each patch is projected exactly once and patches never overlap.
        self.proj = nn.Conv2d(self.in_channels, self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size)

        self.norm = None if norm_layer is None else norm_layer(self.embed_dim)

    def forward(self, x):
        """Embed images of shape (B, C, H, W) into tokens of shape (B, num_patches, embed_dim)."""
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

        tokens = self.proj(x)                          # (B, embed_dim, H', W')
        tokens = tokens.flatten(2).transpose(1, 2)     # (B, H' * W', embed_dim)
        if self.norm is None:
            return tokens
        return self.norm(tokens)

In [125]:
class Swin_Transformer(nn.Module):
    """Swin Transformer image classifier.

    The image is embedded into patch tokens, processed by ``len(depths)``
    stages (each halving the resolution and doubling the channels, except the
    last), then normalized, average-pooled over tokens and classified.

    Stochastic depth (drop path) is not implemented; the blocks use an identity
    module in its place.

    Args:
        img_size: side length of the (square) input images.
        patch_size: side length of the patches.
        in_channels: number of image channels.
        num_classes: size of the classifier output.
        embed_dim: channel dimension after patch embedding.
        depths: number of blocks in each stage.
        num_heads: attention heads in each stage (same length as ``depths``).
        window_size: side length of the attention windows.
        mlp_ratio: MLP hidden width as a multiple of the stage's channels.
        qk_scale: optional override of the query scaling.
        drop_rate: dropout probability used inside the blocks.
        attn_drop_rate: dropout probability applied to the attention weights.
        norm_layer: callable building a normalization module from a channel count.
        patch_norm: whether to normalize the tokens right after patch embedding.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_channels=3,
                 num_classes=1000,
                 embed_dim=96,
                 depths=(2, 2, 6, 2),
                 num_heads=(3, 6, 12, 24),
                 window_size=7,
                 mlp_ratio=4.0,
                 qk_scale=None,
                 drop_rate=0.0,
                 attn_drop_rate=0.0,
                 norm_layer=nn.LayerNorm,
                 patch_norm=True):
        super(Swin_Transformer, self).__init__()

        self.num_classes = num_classes
        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        # Channels double at every stage transition.
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio

        self.patch_embed = Patch_Embedding(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
            # Fix: honour `patch_norm`; previously the embedding was always normalized.
            norm_layer=norm_layer if self.patch_norm else None,
        )

        patch_res = self.patch_embed.patch_res
        self.patch_res = patch_res

        self.layer = nn.ModuleList()

        for i in range(self.num_layers):
            layer = Layer(
                in_channels=int(embed_dim * 2 ** i),
                input_res=(patch_res[0] // (2 ** i),
                           patch_res[1] // (2 ** i)),
                depth=depths[i],
                num_heads=num_heads[i],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qk_scale=qk_scale,
                drop=drop_rate,
                # Fix: `attn_drop_rate` used to be accepted but never forwarded.
                attn_drop=attn_drop_rate,
                # Every stage but the last ends with patch merging.
                downsample=Patch_Merging if (i < self.num_layers - 1) else None,
            )
            self.layer.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes)

    def forward_features(self, x):
        """Map images (B, C, H, W) to pooled features of shape (B, num_features)."""
        x = self.patch_embed(x)
        for layer in self.layer:
            x = layer(x)

        x = self.norm(x)
        # Pool over the token axis: (B, L, C) -> (B, C, L) -> (B, C, 1) -> (B, C).
        x = self.avgpool(x.transpose(1, 2))
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for images ``x``."""
        x = self.forward_features(x)
        x = self.head(x)
        return x

In [None]:
# End-to-end smoke test: push a random image batch through the default model
# and show the shape of the resulting logits.
x = torch.randn(2, 3, 224, 224).to(device)
swin_transformer = Swin_Transformer()
swin_transformer = swin_transformer.to(device)
z = swin_transformer(x)
z.shape