# Attention likes Conv, Swin Transformer

In previous chapters, we revisited the strengths and weaknesses of convolutions and attention mechanisms, introduced CoAtNet—a model that combines these two paradigms—and explored ConvNeXt, a convolutional neural network inspired by the Vision Transformer (ViT) architecture. In this chapter, we will delve into Swin Transformer, a novel hierarchical vision transformer that has achieved impressive results across various computer vision tasks.

## Introduction to Swin Transformer

![Swin Transformer](../imgs/SwimTransformer.jpg)<br>
The Swin Transformer, introduced by Liu et al., is a groundbreaking architecture that applies hierarchical design principles to Vision Transformers. Swin stands for Shifted Window, referring to its unique approach to handling image patches. This model effectively captures both local and global features, addressing some of the limitations of previous transformer-based models.

## Key Innovations in Swin Transformer

### 1. Shifted Window Attention

Traditional Vision Transformers use global attention mechanisms, which can be computationally expensive and difficult to scale for high-resolution images. Swin Transformer introduces shifted window attention, where self-attention is computed within local windows, and the windows are shifted between successive layers. This design reduces computational complexity and enhances the model's ability to capture fine-grained details.

### 2. Hierarchical Feature Maps

Unlike ViTs, which maintain a fixed-size feature map throughout the network, Swin Transformer employs a hierarchical structure similar to convolutional neural networks. The model progressively reduces the spatial dimensions of the feature maps while increasing the number of channels. This design allows the Swin Transformer to handle high-resolution images more efficiently and capture features at multiple scales.

### 3. Efficient Computation

By computing self-attention within local windows and progressively merging feature maps, Swin Transformer achieves significant computational savings. This efficiency makes it feasible to apply the model to a wide range of vision tasks, from image classification to object detection and segmentation.

### 4. Swin Transformer Block

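The Swin Transformer block is the fundamental building block of the model. Here's a simplified implementation of a Swin Transformer block; for brevity it omits the relative position bias and the attention mask that the original model applies to shifted windows: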

In [1]:
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A convolution whose kernel size equals its stride projects every
    ``patch_size x patch_size`` patch to an ``embed_dim``-dimensional vector.
    """

    def __init__(self, in_channels=3, embed_dim=96, patch_size=4):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return the projected patches as a token sequence of shape (B, H*W, C)."""
        feature_map = self.proj(x)  # (B, C, H, W)
        tokens = feature_map.flatten(start_dim=2)  # (B, C, H*W)
        return tokens.permute(0, 2, 1)  # (B, H*W, C)

class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, in_features, hidden_features, out_features, drop=0.):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        # One dropout module serves both dropout positions in forward().
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class WindowAttention(nn.Module):
    """Multi-head self-attention computed independently inside each window.

    This is a simplified variant: it uses neither a relative position bias
    nor an attention mask, and ``window_size`` is stored but not otherwise
    used by the computation.

    Args:
        dim: Number of channels of each token.
        window_size: Side length of the attention window (stored only).
        num_heads: Number of attention heads; must evenly divide ``dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, window_size, num_heads):
        super(WindowAttention, self).__init__()
        # Fail early with a clear message: otherwise the per-head size is
        # silently floored here and forward() only fails later in reshape().
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        # Standard 1/sqrt(head_dim) scaling of the query-key dot products.
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.attn_drop = nn.Dropout(0.0)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(0.0)

    def forward(self, x):
        """Attend within each window.

        Args:
            x: Tensor of shape (B_, N, C), where the leading dimension indexes
                windows and N is the number of tokens per window.

        Returns:
            Tensor of shape (B_, N, C).
        """
        B_, N, C = x.shape
        # (B_, N, 3*C) -> (B_, N, 3, heads, head_dim)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        # -> three tensors of shape (B_, heads, N, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # (B_, heads, N, N)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge the heads back into the channel dimension.
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class SwinTransformerBlock(nn.Module):
    """Pre-norm transformer block with (optionally shifted) window attention.

    The block applies LayerNorm -> window attention -> residual add, followed
    by LayerNorm -> MLP -> residual add.

    Args:
        dim: Number of channels of each token.
        input_resolution: ``(H, W)`` of the token grid this block operates on.
        num_heads: Number of attention heads.
        window_size: Side length of the square attention window.
        shift_size: Cyclic shift applied before partitioning; 0 disables it.
        mlp_ratio: Hidden size of the MLP as a multiple of ``dim``.
        drop: Dropout probability used inside the MLP.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4., drop=0.):
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        # A feature map smaller than the window cannot be partitioned (the
        # window count would be zero). In that case use a single window that
        # spans the shorter side and disable the shift, which is meaningless
        # when one window already covers that axis.
        if min(self.input_resolution) < self.window_size:
            self.window_size = min(self.input_resolution)
            self.shift_size = 0

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size=self.window_size, num_heads=num_heads)
        # Placeholder for stochastic depth; currently a no-op.
        self.drop_path = nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), out_features=dim, drop=drop)

    def forward(self, x):
        """Run the block on a token sequence.

        Args:
            x: Tensor of shape (B, L, C) with ``L == H * W`` for the
                configured ``input_resolution``.

        Returns:
            Tensor of shape (B, H * W, C).
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # Cyclically shift the grid so that the next partition straddles the
        # window borders of the unshifted layout.
        # NOTE: no attention mask is applied, so tokens wrapped around from the
        # opposite border attend to their new window neighbours; the original
        # Swin Transformer masks those pairs.
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # (B, H, W, C) -> (num_windows * B, window_size * window_size, C)
        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        attn_windows = self.attn(x_windows)
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # Undo the cyclic shift so tokens return to their original positions.
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = x.view(B, H * W, C)
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

def window_partition(x, window_size):
    """Cut a (B, H, W, C) feature map into non-overlapping square windows.

    Returns a tensor of shape
    (B * (H // window_size) * (W // window_size), window_size, window_size, C),
    with windows ordered row by row within each batch element.
    """
    batch, height, width, channels = x.shape
    rows, cols = height // window_size, width // window_size
    grid = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-index axes together, ahead of the in-window axes.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, window_size, window_size, channels)

def window_reverse(windows, window_size, H, W):
    """Merge windows back into a feature map (inverse of ``window_partition``).

    Args:
        windows: Tensor whose leading dimension indexes windows, laid out as
            (num_windows * B, window_size, window_size, C).
        window_size: Side length of each square window.
        H: Height of the feature map to rebuild.
        W: Width of the feature map to rebuild.

    Returns:
        Tensor of shape (B, H, W, C).

    Raises:
        ValueError: If ``H`` or ``W`` is not a multiple of ``window_size``, or
            if the number of windows is not a multiple of the number of
            windows per image.
    """
    if H % window_size != 0 or W % window_size != 0:
        raise ValueError(
            f"H ({H}) and W ({W}) must be multiples of window_size ({window_size})"
        )
    rows, cols = H // window_size, W // window_size
    windows_per_image = rows * cols
    num_windows = windows.shape[0]
    # Integer arithmetic with an explicit check: a truncated float division
    # would let the trailing -1 below absorb the mismatch and silently return
    # a tensor with the wrong channel count.
    if num_windows % windows_per_image != 0:
        raise ValueError(
            f"number of windows ({num_windows}) is not a multiple of the "
            f"windows per image ({windows_per_image})"
        )
    B = num_windows // windows_per_image
    x = windows.view(B, rows, cols, window_size, window_size, -1)
    # Interleave window-index and in-window axes back into (H, W).
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

D:\Anaconda\envs\bdl\lib\site-packages\numpy\.libs\libopenblas64__v0.3.23-246-g3d31191b-gcc_10_3_0.dll
D:\Anaconda\envs\bdl\lib\site-packages\numpy\.libs\libopenblas64__v0.3.23-gcc_10_3_0.dll


### 5. Swin Transformer Macro Architecture
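The Swin Transformer consists of multiple stages, each containing several Swin Transformer blocks. The number of blocks and their configurations can be adjusted to create different variants of the model. The simplified implementation below uses a single stage of `depth` blocks at a fixed resolution and omits the patch-merging layers between stages.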

In [2]:
class SwinTransformer(nn.Module):
    """Simplified Swin-style classifier built from a single stack of blocks.

    The image is embedded into patch tokens, passed through ``depth``
    ``SwinTransformerBlock`` modules that alternate between a shift of 0 and
    ``window_size // 2``, layer-normalised, averaged over the token dimension
    and fed to a linear classification head.
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depth=6, num_heads=3, window_size=7, mlp_ratio=4., drop_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = embed_dim

        self.patch_embed = PatchEmbedding(in_channels=in_chans, embed_dim=embed_dim, patch_size=patch_size)
        self.pos_drop = nn.Dropout(p=drop_rate)

        # Every block works on the same square token grid.
        grid_side = img_size // patch_size
        blocks = []
        for index in range(depth):
            # Even-indexed blocks use regular windows, odd-indexed ones shifted windows.
            shift = window_size // 2 if index % 2 else 0
            blocks.append(
                SwinTransformerBlock(
                    dim=embed_dim,
                    input_resolution=(grid_side, grid_side),
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=shift,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                )
            )
        self.layers = nn.ModuleList(blocks)

        self.norm = nn.LayerNorm(embed_dim)
        if num_classes > 0:
            self.head = nn.Linear(embed_dim, num_classes)
        else:
            self.head = nn.Identity()

    def forward(self, x):
        """Map a batch of images to the output of the classification head."""
        tokens = self.pos_drop(self.patch_embed(x))
        for block in self.layers:
            tokens = block(tokens)
        # Average over the token dimension to get one vector per image.
        pooled = self.norm(tokens).mean(dim=1)
        return self.head(pooled)

# Smoke-test the simplified Swin Transformer on a random batch and inspect
# both the output shape and the module hierarchy.
input_tensor = torch.randn(8, 3, 224, 224)  # Batch of 8, 3x224x224 images
model = SwinTransformer()  # built with the constructor's default arguments
output = model(input_tensor)
print(output.shape)  # Expected output shape: (8, 1000)
print(model)  # prints the nested module structure

torch.Size([8, 1000])
SwinTransformer(
  (patch_embed): PatchEmbedding(
    (proj): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-5): 6 x SwinTransformerBlock(
      (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (attn): WindowAttention(
        (qkv): Linear(in_features=96, out_features=288, bias=False)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=96, out_features=96, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=96, out_features=384, bias=True)
        (act): GELU(approximate='none')
        (fc2): Linear(in_features=384, out_features=96, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
  )
  (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=Tru