<a href="https://colab.research.google.com/github/agercodes/Swin_Transformer/blob/main/Swin_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
! pip install timm

### Importing Dependencies

In [None]:
import torch
import torch.nn as nn
from timm.models.layers import PatchEmbed
from timm.models.swin_transformer import PatchMerging, SwinTransformerBlock

### Patch Partition/Embedding

In [None]:
# Random input tensor of shape (1, 3, 224, 224); presumably laid out as
# (batch, channels, height, width) — confirm against PatchEmbed's contract.
x = torch.randn(1, 3, 224, 224)
# timm patch-embedding layer configured for 224-pixel images, 4-pixel patches
# and a 96-dimensional embedding.
patch_embed = PatchEmbed(img_size=224, patch_size=4, embed_dim=96)
# Display the embedded shape. Presumably (1, 3136, 96), i.e. (224 / 4) ** 2
# tokens of 96 channels — confirm by running the cell.
patch_embed(x).shape

### Patch Merging

In [None]:
class PatchMerging(nn.Module):
    """Halve the spatial resolution of a token grid by merging 2x2 neighbours.

    Every non-overlapping 2x2 group of tokens is concatenated along the
    channel axis (C -> 4*C), normalised, and linearly projected to
    ``out_dim`` channels.

    NOTE: this local definition rebinds the ``PatchMerging`` name that the
    notebook also imports from ``timm.models.swin_transformer``.

    Args:
        input_resolution: ``(H, W)`` size of the token grid passed to
            ``forward``. Both values must be even.
        dim: Number of channels ``C`` of each input token.
        out_dim: Number of channels of each output token. Falls back to
            ``2 * dim`` when omitted (or falsy).
        norm_layer: Constructor of the normalisation layer; it is called
            with the concatenated channel count ``4 * dim``.
    """

    def __init__(self, input_resolution, dim, out_dim=None, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.out_dim = out_dim or 2 * dim
        self.norm = norm_layer(4 * dim)
        # No bias: the projection is a pure linear map of the merged channels.
        self.reduction = nn.Linear(4 * dim, self.out_dim, bias=False)

    def forward(self, x):
        """Merge each 2x2 neighbourhood of tokens into a single token.

        Args:
            x: Tensor of shape ``(B, H*W, C)`` where ``B`` is the batch size,
                ``(H, W)`` equals ``input_resolution`` and ``C`` equals ``dim``.

        Returns:
            Tensor of shape ``(B, (H/2)*(W/2), out_dim)``.

        Raises:
            ValueError: If ``x`` is not 3-dimensional, if its token or channel
                count disagrees with the configured ``input_resolution`` /
                ``dim``, or if ``H`` or ``W`` is odd.
        """
        H, W = self.input_resolution
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape (B, H*W, C), got {tuple(x.shape)}"
            )
        B, L, C = x.shape
        # Validate up front: otherwise a mismatch only surfaces as an opaque
        # shape error from view()/cat()/the norm layer further down.
        if L != H * W:
            raise ValueError(
                f"input has {L} tokens but input_resolution {H}x{W} "
                f"requires {H * W}"
            )
        if C != self.dim:
            raise ValueError(
                f"input has {C} channels but the layer was built with dim={self.dim}"
            )
        if H % 2 != 0 or W % 2 != 0:
            raise ValueError(f"input_resolution ({H}x{W}) must be even in both axes")

        x = x.view(B, H, W, C)

        # The four members of every 2x2 neighbourhood, picked by row/column parity.
        x0 = x[:, 0::2, 0::2, :]  # even rows, even cols -> B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # odd rows,  even cols -> B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # even rows, odd cols  -> B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # odd rows,  odd cols  -> B H/2 W/2 C

        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)
        return x

In [None]:
# Random token sequence: batch of 1, 56*56 tokens, 96 channels per token.
x = torch.randn(1, 56*56, 96)
# Patch-merging layer for a 56x56 token grid, projecting 96 -> 192 channels.
l = PatchMerging(input_resolution=(56, 56), dim=96, out_dim=192, norm_layer=nn.LayerNorm)
# Display the merged shape: (1, 784, 192), since 2x2 merging leaves
# 28*28 = 784 tokens and the reduction outputs out_dim=192 channels.
l(x).shape

### Swin Transformer Block

In [None]:
# Random token sequence: batch of 1, 56*56 tokens, 96 channels per token.
x = torch.randn(1, 56*56, 96)
# timm Swin block with default settings (presumably regular, non-shifted
# window attention — confirm against the installed timm version).
t_1 = SwinTransformerBlock(dim=96, input_resolution=(56, 56))
# Same configuration but with shift_size=3 (presumably the shifted-window
# variant — confirm against the installed timm version).
t_2 = SwinTransformerBlock(dim=96, input_resolution=(56, 56), shift_size=3)
# Display the output shape of t_1 alone and of t_1 followed by t_2.
# NOTE(review): t_1(x) is evaluated twice in this expression.
t_1(x).shape, t_2(t_1(x)).shape

### Window Partitioning

In [None]:
def window_partition(x, window_size: int):
    """Cut a feature map into non-overlapping square windows.

    Args:
        x: Tensor of shape (B, H, W, C).
        window_size (int): Side length of each square window.

    Returns:
        Tensor of shape (num_windows*B, window_size, window_size, C), where
        windows are ordered by batch, then window row, then window column.
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size

    # Split each spatial axis into (window index, offset inside the window).
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Put the two window-index axes next to each other so that every window
    # occupies one contiguous (window_size, window_size, C) slab.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, channels)