In [16]:
import torch
from torch import nn
from torch.nn import functional as F
from torchsummary import summary
from einops import rearrange
from torchvision.transforms import Compose, Resize, ToTensor
from dataclasses import dataclass

from PIL import Image

In [8]:
# Load the sample image and convert it into a (1, 3, 224, 224) float tensor.
transform = Compose([Resize((224, 224)), ToTensor()])
with Image.open("porsche918.jpg") as pil_img:
    # Force three channels: a grayscale / CMYK / RGBA file would otherwise yield
    # a channel count that does not match Config.in_channels (3).
    # The context manager closes the underlying file once the pixels are read.
    img = transform(pil_img.convert("RGB")).unsqueeze(0)  # add the batch axis

In [48]:
@dataclass
class Config:
    """Hyper-parameters shared by the Swin-Transformer building blocks below."""

    # --- patch embedding ---
    in_channels: int = 3  # channels of the input image fed to the patch convolution
    patch_size: int = 4  # kernel size and stride of the patch convolution
    embed_dim: int = 96  # channels produced by the patch embedding

    # --- windowed attention ---
    window_size: int = 7  # side length (in tokens) of one attention window
    n_heads: int = 6  # number of attention heads; embed_dim is split across them
    mask: bool = True  # when True the attention block shifts its windows and masks them

# Swin-Transformer

1. Patches + Embedding
2. Patch Merging
3. Shifted Window Attention
4. Relative Position Embedding
5. Encoder
6. Swin-Transformer


## Patches + Embedding

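- Similar to ViT: the image is split into non-overlapping `patch_size × patch_size` patches, and each patch is linearly projected to an `embed_dim`-dimensional token (implemented as a strided convolution).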


In [19]:
class SwinEmbeddings(nn.Module):
    """Turn an image into a sequence of patch tokens.

    A convolution whose kernel size equals its stride (``patch_size``) projects
    every non-overlapping patch to ``embed_dim`` channels. The resulting grid is
    flattened into a token sequence, layer-normalised and passed through a ReLU.

    Args:
        config: Object exposing ``in_channels``, ``embed_dim`` and ``patch_size``.
            Defaults to the ``Config`` class itself, whose class-level field
            defaults are read directly.
    """

    def __init__(self, config: dataclass = Config) -> None:
        super().__init__()
        patch = config.patch_size
        # kernel == stride, so each output position sees exactly one patch.
        self.linear_embedding = nn.Conv2d(
            in_channels=config.in_channels,
            out_channels=config.embed_dim,
            kernel_size=patch,
            stride=patch,
        )
        self.norm = nn.LayerNorm(config.embed_dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` of shape (batch, channels, height, width).

        Returns:
            Tensor of shape (batch, n_patches, embed_dim), where ``n_patches`` is
            the number of spatial positions produced by the patch convolution.
        """
        patch_grid = self.linear_embedding(x)
        # Flatten the 2-D patch grid (row-major) into a token axis, channels last.
        tokens = rearrange(patch_grid, "b c h w -> b (h w) c")
        normed = self.norm(tokens)
        return self.relu(normed)

In [20]:
SwinEmbeddings()(img).shape  # expect (1, 56 * 56, 96): 224 / patch_size 4 = 56 patches per side

torch.Size([1, 3136, 96])

## Patch Merging

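- These layers reduce the number of tokens as the network gets deeper: each 2×2 group of neighbouring tokens is concatenated into a single token, so the token count drops by 4× while the channel dimension doubles.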


In [25]:
class PatchMerging(nn.Module):
    """Merge every 2x2 neighbourhood of tokens into a single, wider token.

    The four tokens of a neighbourhood are concatenated along the channel axis
    (``4 * embed_dim``), linearly projected to ``2 * embed_dim`` and then
    layer-normalised.

    Args:
        config: Object exposing ``embed_dim``. Defaults to the ``Config`` class
            itself, whose class-level field defaults are read directly.
    """

    def __init__(self, config: dataclass = Config) -> None:
        super().__init__()
        concat_dim = 4 * config.embed_dim
        merged_dim = 2 * config.embed_dim
        self.linear = nn.Linear(concat_dim, merged_dim)
        self.ln = nn.LayerNorm(merged_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Merge tokens of ``x`` with shape (batch, n_tokens, embed_dim).

        The token axis is interpreted as a square grid whose side is derived from
        the square root of ``n_tokens``.

        Returns:
            Tensor of shape (batch, n_tokens / 4, 2 * embed_dim).
        """
        n_tokens = x.shape[1]
        # Side length of the merged grid: half the side of the incoming grid.
        half_side = int((n_tokens ** (1 / 2)) / 2)
        # (h s1) indexes rows and (w s2) indexes columns; s1/s2 select the
        # position inside each 2x2 neighbourhood and are folded into channels.
        merged = rearrange(
            x,
            "b (h s1 w s2) c -> b (h w) (s1 s2 c)",
            h=half_side,
            w=half_side,
            s1=2,
            s2=2,
        )
        projected = self.linear(merged)
        return self.ln(projected)

In [26]:
PatchMerging()(SwinEmbeddings()(img)).shape  # expect (1, 28 * 28, 192): 4x fewer tokens, 2x the channels

torch.Size([1, 784, 192])

## Shifted Window Attention


In [55]:
class ShiftedWindowMSA(nn.Module):
    """Multi-head self-attention restricted to local windows, optionally shifted.

    Tokens are laid out on a square grid and partitioned into non-overlapping
    ``window_size x window_size`` windows; attention is computed independently
    inside every window. When ``config.mask`` is true the grid is first cyclically
    shifted by ``window_size // 2`` tokens along both axes, attention between
    tokens that were not spatial neighbours before the shift is masked out, and
    the shift is undone before the output projection.

    Args:
        config: Object exposing ``embed_dim``, ``n_heads``, ``window_size`` and
            ``mask``. Defaults to the ``Config`` class itself, whose class-level
            field defaults are read directly.
    """

    def __init__(self, config: dataclass = Config) -> None:
        super().__init__()
        self.config = config
        # Fused query/key/value projection: 2 * embed_dim -> 3 * embed_dim.
        self.proj1 = nn.Linear(2 * self.config.embed_dim, self.config.embed_dim * 3)
        # Output projection back to the incoming width of 2 * embed_dim.
        self.proj2 = nn.Linear(self.config.embed_dim, self.config.embed_dim * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply (shifted) window attention.

        Args:
            x: Tensor of shape (batch, n_tokens, 2 * embed_dim). The token axis is
                interpreted as a square grid whose side must be divisible by
                ``window_size`` (otherwise the window rearrangement fails).

        Returns:
            Tensor of shape (batch, n_tokens, 2 * embed_dim), with tokens in the
            same spatial order as the input.
        """
        window = self.config.window_size
        # Parenthesised on purpose: ``-window // 2`` parses as ``(-window) // 2``
        # and floors to -4 for window == 7, which disagrees with the mask below
        # (built for a shift of ``window // 2`` == 3).
        shift = window // 2
        head_dim = self.config.embed_dim // self.config.n_heads
        h = w = int(x.shape[1] ** (1 / 2))

        x = self.proj1(x)
        # Restore the 2-D grid; the trailing axis of size 3 separates q, k and v.
        x = rearrange(x, "b (h w) (c k) -> b h w c k", h=h, w=w, k=3)

        if self.config.mask:
            # Cyclic shift towards the top-left; the wrapped-around tokens land in
            # the last `shift` rows / columns of the grid.
            x = torch.roll(x, (-shift, -shift), dims=(1, 2))

        # Partition into windows and split channels into heads:
        # (batch, heads, window_row, window_col, tokens_per_window, head_dim, 3).
        x = rearrange(
            x,
            "b (h m1) (w m2) (H e) k -> b H h w (m1 m2) e k",
            H=self.config.n_heads,
            h=h // window,
            m1=window,
            m2=window,
        )

        q, k, v = x.chunk(3, dim=-1)
        q, k, v = q.squeeze(-1), k.squeeze(-1), v.squeeze(-1)
        attn_scores = torch.matmul(q, k.transpose(-1, -2)) / (head_dim**0.5)

        if self.config.mask:
            # Number of tokens of a window that belong to the wrapped-around rows.
            wrapped = window * shift
            # Allocate on the scores' device/dtype so the module also works after
            # `.to(device)` / under reduced precision.
            row_mask = torch.zeros(
                (window**2, window**2), device=attn_scores.device, dtype=attn_scores.dtype
            )
            # Block attention between the wrapped rows and the original rows.
            row_mask[-wrapped:, 0:-wrapped] = float("-inf")
            row_mask[0:-wrapped, -wrapped:] = float("-inf")
            # Same pattern with the roles of rows and columns inside a window swapped.
            column_mask = rearrange(row_mask, "(r w1) (c w2) -> (w1 r) (w2 c)", w1=window, w2=window)
            # Only the last row / last column of windows contain wrapped tokens.
            attn_scores[:, :, -1, :] += row_mask
            attn_scores[:, :, :, -1] += column_mask

        attn = F.softmax(attn_scores, dim=-1) @ v
        # Merge heads and stitch the windows back into the 2-D grid.
        out = rearrange(attn, "b H h w (m1 m2) e -> b (h m1) (w m2) (H e)", m1=window, m2=window)

        if self.config.mask:
            # Undo the cyclic shift so every token returns to its original position.
            out = torch.roll(out, (shift, shift), dims=(1, 2))

        out = rearrange(out, "b h w c -> b (h w) c")
        return self.proj2(out)

In [56]:
ShiftedWindowMSA()(PatchMerging()(SwinEmbeddings()(img))).shape  # expect (1, 784, 192): token count and width are preserved

torch.Size([1, 784, 192])