# Swin Transformer Implementation Using PyTorch

This notebook showcases the implementation of a Swin Transformer from scratch using PyTorch. The Swin Transformer (Shifted Window Transformer) is a hierarchical vision transformer architecture that introduces window-based self-attention and shifted windowing to achieve computational efficiency and improved performance on vision tasks.

We use the MNIST dataset, resized to 32x32, for demonstration. The model is kept lightweight so that it is easy to follow and quick to train from scratch.

## Dataset

The dataset used is the MNIST dataset, resized to **32x32**. It consists of 70,000 grayscale images of handwritten digits (0-9), split into 60,000 training images and 10,000 test images. Each image has one channel and is labeled with its digit.

- **Input Size**: 32×32×1
- **Number of Classes**: 10

## Swin Transformer Components

The Swin Transformer is implemented with the following components:

- **PatchPartition**: Splits the input image into non-overlapping patches and reshapes the tensor.
- **LinearEmbedding**: Projects flattened patches to the desired embedding dimension.
- **Window-based Multi-Head Self Attention (W-MSA)**: Performs self-attention within local non-overlapping windows.
- **Shifted Window Multi-Head Self Attention (SW-MSA)**: Introduces overlapping context by shifting the windows.
- **SwinTransformerBlock**: Core building block; consecutive blocks alternate between W-MSA and SW-MSA.
- **PatchMerging**: Downsamples the feature map by merging adjacent patches and increasing the embedding dimension.
- **SwinTransformer**: Full model composed of stacked Swin Transformer blocks and hierarchical downsampling.

## Import libraries

In [43]:
from tqdm import tqdm

import torch
from torch import Tensor
import torch.nn as nn

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

## Load the MNIST dataset and wrap it in DataLoaders

In [44]:
dataset_train = datasets.MNIST(
    root='../../../datasets',
    download=True,
    train=True,
    transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
)

dataset_test = datasets.MNIST(
    root='../../../datasets',
    download=True,
    train=False,
    transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
)

dataloader_train = DataLoader(dataset_train, batch_size=64, shuffle=True)
dataloader_test = DataLoader(dataset_test, batch_size=64, shuffle=False)  # no need to shuffle for evaluation
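
As a quick sanity check on the loaders above, the number of batches per epoch follows from the split sizes (60,000 training and 10,000 test images) and the batch size of 64. The helper name `num_batches` is illustrative and not part of the notebook; the sketch assumes the last partial batch is kept, which is the `DataLoader` default.

```python
import math

def num_batches(num_samples: int, batch_size: int) -> int:
    """Batches per epoch when the final partial batch is kept (drop_last=False)."""
    return math.ceil(num_samples / batch_size)

train_batches = num_batches(60_000, 64)  # MNIST training split
test_batches = num_batches(10_000, 64)   # MNIST test split
print(train_batches, test_batches)
```

These counts correspond to the 938 and 157 iterations shown in the progress bars of the training log.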

# Swin Transformer

The **Swin Transformer** is a hierarchical vision transformer that computes self-attention within local windows and shifts them between blocks. This design combines benefits of convolutional neural networks (CNNs) and Transformers: attention cost that grows linearly with image size, and a locality inductive bias suited to vision tasks.

![Swin Transformer](swinT_architecture.png)

---

## 1. Patch Partition

The image is first partitioned into **non-overlapping patches** of size $P \times P$.

For an image $x \in \mathbb{R}^{B \times C \times H \times W}$, the patch partition operation reshapes it to:

$$
x' \in \mathbb{R}^{B \times \frac{H}{P} \times \frac{W}{P} \times (P^2 \cdot C)}
$$

Together with the linear embedding in the next step, this operation is equivalent to a convolution with:
- Kernel size = patch size
- Stride = patch size
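
To make the reshaping concrete, here is a minimal pure-Python sketch of patch partition for a single image (no batch dimension). The function name `patch_partition` and the flattening order (row, column, channel) are assumptions chosen for illustration; the notebook's own `PatchPartition` class uses a strided convolution instead.

```python
def patch_partition(image, patch_size):
    """Split a (C, H, W) nested-list image into an (H/P, W/P, P*P*C) grid of flattened patches."""
    C = len(image)
    H = len(image[0])
    W = len(image[0][0])
    P = patch_size
    assert H % P == 0 and W % P == 0, "H and W must be divisible by the patch size"
    grid = []
    for pi in range(H // P):
        row = []
        for pj in range(W // P):
            # Flatten one P x P patch across all channels into a single vector.
            patch = [
                image[c][pi * P + di][pj * P + dj]
                for di in range(P)
                for dj in range(P)
                for c in range(C)
            ]
            row.append(patch)
        grid.append(row)
    return grid

# Toy 1-channel 4x4 image with pixel values 0..15, split into 2x2 patches.
image = [[[r * 4 + c for c in range(4)] for r in range(4)]]
patches = patch_partition(image, patch_size=2)
print(len(patches), len(patches[0]), len(patches[0][0]))  # grid shape: 2 2 4
print(patches[0][0])  # top-left patch: [0, 1, 4, 5]
```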

---

## 2. Linear Embedding

Each patch (flattened vector of size $P^2 \cdot C$) is projected to a fixed dimension $D$ using a learnable linear layer:

$$
z = xW_e + b_e, \quad \text{where } W_e \in \mathbb{R}^{(P^2 \cdot C) \times D}
$$

The result is a feature map of shape:

$$
z \in \mathbb{R}^{B \times \frac{H}{P} \times \frac{W}{P} \times D}
$$
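
The projection of a single patch can be sketched in plain Python as follows. The weight and bias values are hypothetical toy numbers, and `linear_embed` is an illustrative name; in the notebook this step is an `nn.Linear` layer.

```python
def linear_embed(patch, weight, bias):
    """Compute z = x W_e + b_e for one flattened patch x of length P*P*C."""
    D = len(bias)
    return [
        sum(patch[k] * weight[k][d] for k in range(len(patch))) + bias[d]
        for d in range(D)
    ]

# Toy example: patch_dim = 4 (P = 2, C = 1) projected to D = 2.
patch = [1.0, 2.0, 3.0, 4.0]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]]  # hypothetical W_e, shape (4, 2)
bias = [0.5, -0.5]                                          # hypothetical b_e
z = linear_embed(patch, weight, bias)
print(z)  # [4.5, 5.5]

# Parameter count for the configuration used in this notebook (P = 4, C = 1, D = 32).
patch_dim, D = 4 * 4 * 1, 32
num_params = patch_dim * D + D  # weights plus biases
print(num_params)  # 544
```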

---

## 3. Swin Transformer Block

### 3.1 Window-based Self-Attention

Instead of global self-attention, the image is divided into **non-overlapping windows** of size $M \times M$, and attention is computed within each window independently.

Given input $x \in \mathbb{R}^{B \times H \times W \times C}$, the window partition splits it into:

$$
\text{windows} \in \mathbb{R}^{\left(B \cdot \frac{H}{M} \cdot \frac{W}{M}\right) \times M \times M \times C}
$$

Self-attention within each window is computed as:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V
$$

where:

- $Q = XW^Q, \quad K = XW^K, \quad V = XW^V$
- $W^Q, W^K, W^V \in \mathbb{R}^{C \times d}$
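
The attention formula for one window and one head can be sketched in pure Python as below. For simplicity the sketch uses identity projections ($Q = K = V = X$), which is an assumption for illustration only; the function names are likewise illustrative.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def window_attention(Q, K, V):
    """Scaled dot-product attention for one window and one head.

    Q, K: lists of N vectors of length d; V: list of N value vectors.
    Returns (output, weights), where weights is the N x N attention matrix.
    """
    d = len(Q[0])
    scores = [
        [sum(qt * kt for qt, kt in zip(q, k)) / math.sqrt(d) for k in K]
        for q in Q
    ]
    weights = [softmax(row) for row in scores]
    output = [
        [sum(w * v[c] for w, v in zip(row, V)) for c in range(len(V[0]))]
        for row in weights
    ]
    return output, weights

# A 2x2 window (N = 4 tokens) with d = 2; the values are arbitrary toy numbers.
X = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]
out, attn = window_attention(X, X, X)
print([round(sum(row), 6) for row in attn])  # each row of attention weights sums to 1
```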

### 3.2 Shifted Windows

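To enable cross-window interactions, **the windows are shifted** by $\lfloor M/2 \rfloor$ positions along each spatial axis in the next block. The shift is measured in patches (tokens of the feature map), not in image pixels.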

This operation:
- Introduces connections across windows.
- Preserves locality like CNNs.
- Avoids quadratic cost of full attention.
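
The effect of the shift can be seen on a toy 4x4 token grid with $M = 2$. The sketch below implements the shift as a cyclic roll, which is how the `SwinTransformerBlock` in this notebook does it with `torch.roll`; the helper names `roll2d` and `windows_of` are illustrative.

```python
def roll2d(grid, shift_rows, shift_cols):
    """Cyclically shift a 2D list; same convention as torch.roll over two dims."""
    H, W = len(grid), len(grid[0])
    return [
        [grid[(i - shift_rows) % H][(j - shift_cols) % W] for j in range(W)]
        for i in range(H)
    ]

def windows_of(grid, M):
    """List the token ids in each M x M window, row-major over the window grid."""
    H, W = len(grid), len(grid[0])
    return [
        [grid[wi * M + di][wj * M + dj] for di in range(M) for dj in range(M)]
        for wi in range(H // M)
        for wj in range(W // M)
    ]

M = 2
shift = M // 2
grid = [[r * 4 + c for c in range(4)] for r in range(4)]  # token ids 0..15 on a 4x4 map

regular = windows_of(grid, M)
shifted = windows_of(roll2d(grid, -shift, -shift), M)
print(regular[0])  # [0, 1, 4, 5]
print(shifted[0])  # [5, 6, 9, 10]: tokens drawn from four different regular windows
print(shifted[3])  # [15, 12, 3, 0]: wrapped-around tokens that are not true neighbours
```

The last window shows why a cyclic shift needs extra care: it groups tokens from opposite corners of the feature map.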

### 3.3 Attention Masking

When shifted windows are implemented with a cyclic shift, a window at the border can contain tokens that are not adjacent in the original feature map. To prevent these tokens from attending to each other, an **attention mask** is applied.

Let the attention score matrix be $A \in \mathbb{R}^{M^2 \times M^2}$. A mask $\text{Mask} \in \{0, -\infty\}^{M^2 \times M^2}$ (named to avoid confusion with the window size $M$) is added before the softmax:

$$
\text{MaskedAttention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}} + \text{Mask}\right)V
$$
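
One common way to build such a mask is to label every position of the shifted map with a region id and block attention between tokens whose ids differ. The sketch below follows that idea in pure Python; the function name, the use of `-inf` as the blocking value (implementations often use a large negative constant instead), and the region numbering are assumptions for illustration.

```python
def shifted_window_mask(H, W, M, shift, blocked=float("-inf")):
    """Build one (M*M) x (M*M) additive mask per window of a cyclically shifted H x W map."""

    def band(pos, size):
        # 0: untouched area, 1: part of the last window that did not wrap,
        # 2: strip that wrapped around from the opposite border
        if pos < size - M:
            return 0
        if pos < size - shift:
            return 1
        return 2

    # Region id of every position in the shifted map (3 row bands x 3 column bands).
    region = [[band(i, H) * 3 + band(j, W) for j in range(W)] for i in range(H)]

    masks = []
    for wi in range(H // M):
        for wj in range(W // M):
            ids = [region[wi * M + di][wj * M + dj] for di in range(M) for dj in range(M)]
            # Tokens may attend to each other only if they share a region id.
            masks.append([[0.0 if a == b else blocked for b in ids] for a in ids])
    return masks

masks = shifted_window_mask(H=4, W=4, M=2, shift=1)
print(masks[0])  # interior window: all zeros, nothing is blocked
print(masks[3])  # corner window: only the diagonal is zero
```

Because every token shares a region with itself, the diagonal is always zero and the softmax never sees a row made entirely of `-inf`.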

---

## 4. Patch Merging

To downsample feature maps (like pooling), Swin Transformer merges every $2 \times 2$ patch region:

1. Reshape each $2 \times 2$ neighborhood to a vector of size $4C$.
2. Apply a linear layer $W \in \mathbb{R}^{4C \times 2C}$.

This halves each spatial dimension and doubles the feature dimension:

$$
\begin{aligned}
\text{Input: } & B \times H \times W \times C \\
\text{Output: } & B \times \frac{H}{2} \times \frac{W}{2} \times 2C
\end{aligned}
$$
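
The concatenation step (before the linear layer) can be sketched in pure Python for a single feature map. The function name `merge_patches` is illustrative; the neighbour order (top-left, top-right, bottom-left, bottom-right) is chosen to match the reshape used by the `PatchMerging` class in this notebook.

```python
def merge_patches(x):
    """Concatenate each 2x2 neighbourhood of an (H, W, C) nested list into (H/2, W/2, 4C)."""
    H, W = len(x), len(x[0])
    assert H % 2 == 0 and W % 2 == 0, "H and W must be even"
    merged = []
    for i in range(0, H, 2):
        row = []
        for j in range(0, W, 2):
            # List concatenation stacks the four C-dimensional vectors into one 4C vector.
            row.append(x[i][j] + x[i][j + 1] + x[i + 1][j] + x[i + 1][j + 1])
        merged.append(row)
    return merged

# Toy 4x4 feature map with C = 1; each token holds its own id.
fmap = [[[r * 4 + c] for c in range(4)] for r in range(4)]
merged = merge_patches(fmap)
print(len(merged), len(merged[0]), len(merged[0][0]))  # 2 2 4
print(merged[0][0])  # [0, 1, 4, 5]
```

A linear layer mapping $4C \to 2C$ would then be applied to each merged vector.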

---

## 5. Final Classification Head

After all stages (each with multiple blocks and optional patch merging), the resulting feature map is normalized and pooled:

$$
\begin{aligned}
x &\in \mathbb{R}^{B \times H' \times W' \times C} \\
x &= \text{LayerNorm}(x) \\
x &= \text{AvgPool}(x) \in \mathbb{R}^{B \times C} \\
\hat{y} &= \text{Linear}(x) \in \mathbb{R}^{B \times \text{num\_classes}}
\end{aligned}
$$
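
The normalization and pooling steps can be sketched in pure Python for one image. This sketch assumes a LayerNorm without learnable scale and shift and omits the final linear layer; the function names are illustrative.

```python
import math

def layer_norm(vec, eps=1e-5):
    """Normalize one token vector to zero mean and (approximately) unit variance."""
    mean = sum(vec) / len(vec)
    var = sum((v - mean) ** 2 for v in vec) / len(vec)
    return [(v - mean) / math.sqrt(var + eps) for v in vec]

def global_avg_pool(tokens):
    """Average a list of token vectors into a single C-dimensional vector."""
    n = len(tokens)
    return [sum(t[c] for t in tokens) / n for c in range(len(tokens[0]))]

# Two toy tokens with C = 4 channels (a flattened H' x W' map with H' * W' = 2).
tokens = [[1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]]
normed = [layer_norm(t) for t in tokens]
pooled = global_avg_pool(normed)
print(len(pooled))  # 4: one value per channel, ready for the classification layer
```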

---

## 6. Summary Table

| Component             | Role                                                        | Key Idea |
|-----------------------|-------------------------------------------------------------|----------|
| Patch Partition       | Divide input into flat patches                              | CNN-like stride operation |
| Linear Embedding      | Map patches to embedding space                              | FC projection |
| Window Self-Attention | Attention within fixed-size windows                         | Reduces complexity |
| Shifted Windows       | Alternate blocks with shifted windows                       | Cross-window interaction |
| Patch Merging         | Spatial downsampling and channel upsampling                 | Hierarchical |
| LayerNorm + Head      | Normalize + aggregate features + predict class              | Standard |

---

## Mathematical Complexity

Let:
- Image size: $H \times W$
- Patch size: $P$
- Window size: $M$
- Number of channels: $C$

Then:
- Global attention: $O((HW)^2)$
- Swin window attention: $O\left(\frac{HW}{M^2} \cdot M^4\right) = O(HW \cdot M^2)$

Hence, for a fixed window size $M$, Swin attention scales **linearly** with the number of tokens $HW$ (like CNNs), compared to the quadratic scaling of global attention in ViT.
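
The two formulas can be compared numerically by counting query-key scores. In this sketch $H$ and $W$ are measured in tokens of the feature map; the 8x8 case with $M = 2$ corresponds to the first stage of the model in this notebook. The function names are illustrative.

```python
def global_attention_scores(H, W):
    """Number of query-key scores when every token attends to every token."""
    tokens = H * W
    return tokens * tokens

def window_attention_scores(H, W, M):
    """Number of query-key scores when attention is restricted to M x M windows."""
    num_windows = (H // M) * (W // M)
    tokens_per_window = M * M
    return num_windows * tokens_per_window * tokens_per_window

for side in (8, 16, 32):
    print(side, global_attention_scores(side, side), window_attention_scores(side, side, M=2))
```

Doubling the side length multiplies the window-attention count by 4 (linear in $HW$) but the global count by 16.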

---

## References

- [Swin Transformer Paper](https://arxiv.org/abs/2103.14030)

### PatchPartition

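The `PatchPartition` class splits the input image into non-overlapping patches of size `(patch_size × patch_size)` and returns a tensor of shape `(B, H//p, W//p, embed_dim)`. It is implemented as a strided convolution with kernel size and stride equal to `patch_size`. In this notebook `embed_dim` is set to `patch_dim = patch_size² × channels`, so each output vector has the same length as a flattened patch, although it is a learned linear mix of the patch pixels rather than a plain reshape.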

In [45]:
class PatchPartition(nn.Module):
    def __init__(self, in_channels, height, width, patch_size, embed_dim):
        super().__init__()
        self.num_patch = (height // patch_size) * (width // patch_size)

        self.partition = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size, bias=False)

    def forward(self, x: Tensor):
        x = self.partition(x)
        x = x.permute(0, 2, 3, 1).contiguous()
        return x

### LinearEmbedding

The `LinearEmbedding` class projects the flattened patch features into a higher-dimensional embedding space (e.g., 32 or 96) using a linear layer. This is analogous to word embedding in NLP transformers.

In [46]:
class LinearEmbedding(nn.Module):
    def __init__(self, from_dim, to_dim):
        super().__init__()
        self.projection = nn.Linear(from_dim, to_dim)

    def forward(self, x: Tensor):
        return self.projection(x)

## Window Partition & Reverse — Swin Transformer Helper Functions

In the Swin Transformer, self-attention is computed inside **local windows** rather than across the full image, significantly reducing computational complexity. Two crucial operations that enable this behavior are:

- `window_partition`: splits a feature map into non-overlapping windows.
- `window_reverse`: reconstructs the original spatial layout from the set of windows.

---

## 1. Window Partition

### Purpose

This function takes a feature map with shape:

$$
x \in \mathbb{R}^{B \times H \times W \times C}
$$

where:
- B is the batch size,
- H, W are the height and width,
- C is the number of channels,

and splits it into **non-overlapping windows** of shape $M \times M$, where $M$ is the window size.

### Mathematics Behind

1. **Reshaping**:  
   The feature map is reshaped into:

   $$
   x' \in \mathbb{R}^{B \times \frac{H}{M} \times M \times \frac{W}{M} \times M \times C}
   $$

2. **Permutation**:  
   Dimensions are permuted to group window blocks together:

   $$
   x'' \in \mathbb{R}^{B \times \frac{H}{M} \times \frac{W}{M} \times M \times M \times C}
   $$

3. **Flattening**:  
   The reshaped tensor is flattened so that all windows are treated as independent inputs:

   $$
   \text{windows} \in \mathbb{R}^{(B \cdot \frac{H}{M} \cdot \frac{W}{M}) \times M \times M \times C}
   $$

These windows are then used to apply **local self-attention** efficiently.
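
The same grouping can be written without tensor reshapes, which makes the index arithmetic explicit. This is a pure-Python sketch for a single-channel map with no batch dimension; `window_partition_2d` is an illustrative name.

```python
def window_partition_2d(x, M):
    """Split an (H, W) nested list into a list of M x M windows, row-major over the window grid."""
    H, W = len(x), len(x[0])
    assert H % M == 0 and W % M == 0, "H and W must be divisible by the window size"
    windows = []
    for wi in range(H // M):
        for wj in range(W // M):
            # Token (di, dj) of window (wi, wj) sits at position (wi*M + di, wj*M + dj).
            windows.append(
                [[x[wi * M + di][wj * M + dj] for dj in range(M)] for di in range(M)]
            )
    return windows

grid = [[r * 4 + c for c in range(4)] for r in range(4)]  # token ids 0..15
windows = window_partition_2d(grid, M=2)
print(len(windows))  # 4 windows
print(windows[0])    # [[0, 1], [4, 5]]
print(windows[3])    # [[10, 11], [14, 15]]
```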

---

## 2. Window Reverse

### Purpose

After computing attention inside windows, this function reverses the operation and reconstructs the original feature map of shape:

$$
x \in \mathbb{R}^{B \times H \times W \times C}
$$

from the windows of shape:

$$
\text{windows} \in \mathbb{R}^{(B \cdot \frac{H}{M} \cdot \frac{W}{M}) \times M \times M \times C}
$$

### Mathematics Behind

1. **Determine Batch Size**:  
   The batch size B is inferred from the number of windows:

   $$
   B = \frac{\text{num\_windows}}{\frac{H}{M} \cdot \frac{W}{M}}
   $$

2. **Reshape**:  
   The windows are reshaped into the grouped format:

   $$
   x' \in \mathbb{R}^{B \times \frac{H}{M} \times \frac{W}{M} \times M \times M \times C}
   $$

3. **Permute & Merge**:  
   The spatial layout is restored by permuting the axes and merging the dimensions:

   $$
   x \in \mathbb{R}^{B \times H \times W \times C}
   $$
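
The inverse mapping can be sketched in the same explicit style. The helper names are illustrative, and the windows are assumed to be ordered row-major over the window grid, as produced by the partition step.

```python
def infer_batch_size(num_windows, M, H, W):
    """Recover B from the total number of windows, as in step 1 above."""
    return num_windows // ((H // M) * (W // M))

def window_reverse_2d(windows, M, H, W):
    """Place a row-major list of M x M windows back into an (H, W) nested list."""
    x = [[None] * W for _ in range(H)]
    windows_per_row = W // M
    for idx, win in enumerate(windows):
        wi, wj = divmod(idx, windows_per_row)  # position of this window in the window grid
        for di in range(M):
            for dj in range(M):
                x[wi * M + di][wj * M + dj] = win[di][dj]
    return x

# Four 2x2 windows of a 4x4 map holding token ids 0..15.
windows = [
    [[0, 1], [4, 5]], [[2, 3], [6, 7]],
    [[8, 9], [12, 13]], [[10, 11], [14, 15]],
]
restored = window_reverse_2d(windows, M=2, H=4, W=4)
print(restored[0])  # [0, 1, 2, 3]
print(infer_batch_size(num_windows=64 * 16, M=2, H=8, W=8))  # 64
```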

---

## Why These Functions Matter

- These operations enable the Swin Transformer to perform **efficient local attention** while still maintaining the ability to aggregate global information through **shifted windows** in successive layers.
- The key benefit is the **linear scaling** of attention cost with image size:
  $$
  \text{Complexity} = O\left(\frac{HW}{M^2} \cdot M^4\right) = O(HW \cdot M^2)
  $$

Compared to standard global attention:
  $$
  O((HW)^2)
  $$

---

## Summary

| Function           | Input Shape | Output Shape | Purpose |
|--------------------|-------------|--------------|---------|
| `window_partition` | $B \times H \times W \times C$ | $(B \cdot \frac{H}{M} \cdot \frac{W}{M}) \times M \times M \times C$ | Split into windows |
| `window_reverse`   | $(B \cdot \frac{H}{M} \cdot \frac{W}{M}) \times M \times M \times C$ | $B \times H \times W \times C$ | Reconstruct the original feature map |

In [47]:
def window_partition(x: Tensor, window_size: int):
    """
    Args:
        x: (B, H, W, C)
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()  # (B, H//ws, W//ws, ws, ws, C)
    windows = x.view(-1, window_size, window_size, C)
    return windows

def window_reverse(windows: Tensor, window_size: int, H: int, W: int):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
    Returns:
        x: (B, H, W, C)
    """
    # Integer arithmetic avoids the float round-trip of int(a / (H * W / ws / ws)).
    B = windows.shape[0] // ((H // window_size) * (W // window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, H, W, -1)
    return x

# WindowAttention

The **WindowAttention** module implements **local self-attention** within fixed‐size windows of the feature map. It is a core building block of the Swin Transformer, enabling efficient computation and spatial inductive bias.

---

## 1. Purpose & Motivation

- **Locality**  
  By restricting attention to a window of size $M \times M$, we reduce complexity from  
  $$O\bigl((HW)^2\bigr)\quad\text{to}\quad O\!\Bigl(\tfrac{HW}{M^2}\,M^4\Bigr)=O(HW\,M^2).$$

- **Shifted Windows**  
  Alternating between regular and shifted windows enables cross‐window interactions without global attention.

- **Spatial Bias**  
  Injecting learnable **relative position biases** helps the model encode spatial relationships within each window.

---

## 2. Mathematical Details

### 2.1 Input & QKV Projection

Given input tokens  
$$X \in \mathbb{R}^{B N \,\times\, N \,\times\, D},$$  
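where the leading dimension $BN$ is the total number of windows in the batch (batch size $B$ times the number of windows per image, `B_` in the code), the second dimension $N=M^2$ is the number of tokens per window, and $D$ is the embedding dimension, we compute:  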
$$[Q,K,V] = X\,W_{qkv},\quad W_{qkv}\in\mathbb{R}^{D\times 3D},$$  
then reshape into $h$ heads of size $d=D/h$:  
$$Q,K,V\in\mathbb{R}^{B N\times h\times N\times d}.$$

### 2.2 Scaled Dot‐Product Attention

For each head:  
$$A = \frac{Q\,K^\top}{\sqrt{d}}\ \in\mathbb{R}^{B N\times N\times N},$$  
computing affinities between all token pairs within a window.

---

## 3. Relative Position Bias

### 3.1 Creation of the Bias Table

We introduce a learnable **bias table**  
$$B_r \in \mathbb{R}^{(2M-1)^2 \times h},$$  
where each of the $(2M-1)^2$ entries corresponds to a relative offset $(\Delta x,\Delta y)$ between two tokens in an $M\times M$ window.  
- Offsets range $\Delta x,\Delta y\in\{-M+1,\dots,M-1\}$.  
- We map each offset to a 1D index via  
  $$\text{idx}(\Delta x,\Delta y) = (\Delta x + M - 1)\,(2M-1) + (\Delta y + M - 1).$$  
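- The table is initialized with a **truncated normal** distribution:  
  $$B_r[i,k]\sim\mathcal{N}(0,\sigma^2)\text{ truncated to }[-2,2],\quad \sigma=0.02.$$
  The bounds $[-2, 2]$ are the absolute default limits of `nn.init.trunc_normal_` (not multiples of $\sigma$), so with $\sigma = 0.02$ the truncation almost never takes effect.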

### 3.2 Adding Bias to Attention

We precompute an index tensor  
$$\mathrm{idx}\in\{0,\dots,(2M-1)^2-1\}^{N\times N}$$  
that maps each token‐pair $(i,j)$ to its bias index. Then the adjusted logits become  
$$\widetilde{A}_{b,i,j}^{(k)} = A_{b,i,j}^{(k)} \;+\; B_r[\mathrm{idx}_{i,j},\,k].$$
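
For a small window the index tensor can be written out by hand. The pure-Python sketch below applies the $\text{idx}(\Delta x, \Delta y)$ formula to every token pair of an $M = 2$ window, taking the first coordinate as the row offset and the second as the column offset; the function name is illustrative.

```python
def relative_position_index(M):
    """Return the (M*M) x (M*M) table mapping each token pair (i, j) to a bias-table row."""
    coords = [(r, c) for r in range(M) for c in range(M)]  # row-major token positions
    return [
        [(ri - rj + M - 1) * (2 * M - 1) + (ci - cj + M - 1) for (rj, cj) in coords]
        for (ri, ci) in coords
    ]

index = relative_position_index(2)
for row in index:
    print(row)
# [4, 3, 1, 0]
# [5, 4, 2, 1]
# [7, 6, 4, 3]
# [8, 7, 5, 4]
```

Every pair with the same relative offset shares an index (the diagonal is always the centre entry, 4), and the largest index is $(2M-1)^2 - 1 = 8$.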

---

## 4. Softmax & Output

1. **Softmax** over the last dimension:  
   $$\alpha = \mathrm{softmax}\bigl(\widetilde{A}\bigr)\in\mathbb{R}^{B N\times N\times N}.$$
2. **Aggregate** values:  
   $$O = \alpha\,V \;\in\mathbb{R}^{B N\times h\times N\times d}.$$
3. **Merge heads** and apply final linear projection to restore shape  
   $$O'\in\mathbb{R}^{B N\times N\times D}.$$
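
The head-merging step in item 3 amounts to concatenating, for each token, the output vectors of all heads. A pure-Python sketch for one window follows; `merge_heads` is an illustrative name and the final linear projection is omitted.

```python
def merge_heads(per_head):
    """Turn an (h, N, d) nested list into (N, h*d) by concatenating each token's head outputs."""
    h, N = len(per_head), len(per_head[0])
    return [
        [value for head in range(h) for value in per_head[head][n]]
        for n in range(N)
    ]

# h = 2 heads, N = 2 tokens, d = 2 channels per head (toy values).
per_head = [
    [[1, 2], [3, 4]],  # head 0: outputs for token 0 and token 1
    [[5, 6], [7, 8]],  # head 1: outputs for token 0 and token 1
]
merged = merge_heads(per_head)
print(merged)  # [[1, 2, 5, 6], [3, 4, 7, 8]]
```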

---

## 5. Why It’s Necessary

- **Efficiency**: Local windows reduce compute on high‐res inputs.  
- **Cross‐window Links**: Shifted windows connect neighboring regions without full attention.  
- **Inductive Bias**: Relative biases encode spatial priors (nearby patches more related) while remaining learnable.

---

## 6. Summary

| Step                      | Operation & Formula                                                           |
|---------------------------|-------------------------------------------------------------------------------|
| **QKV Projection**        | $[Q,K,V]=XW_{qkv},\;Q,K,V\in\mathbb{R}^{BN\times h\times N\times d}$          |
| **Scaled Dot‐Product**    | $A=\tfrac{QK^\top}{\sqrt{d}}$                                                  |
| **Relative Bias**         | $\widetilde{A}_{i,j}^{(k)}=A_{i,j}^{(k)}+B_r[\mathrm{idx}_{i,j},k]$            |
| **Softmax & Aggregate**   | $\alpha=\mathrm{softmax}(\widetilde{A}),\;O=\alpha V$                          |
| **Merge & Project**       | Concatenate heads, linear map $hd\to D$                                        |

The **WindowAttention** module balances computational tractability with the ability to capture both **local** and, via shifts, **cross‐window** interactions.

In [48]:
class WindowAttention(nn.Module):
    def __init__(self, dim, window_size, num_heads, shift_size, input_resolution):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.input_resolution = input_resolution
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

        # Position bias table
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2*window_size-1)*(2*window_size-1), num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        # Relative position index for each token inside a window
        self.register_buffer("relative_position_index", self.get_relative_position_index())

    def get_relative_position_index(self):
        """Build relative position index tensor."""
        coords = torch.stack(torch.meshgrid(
            torch.arange(self.window_size),
            torch.arange(self.window_size), indexing="ij"
        ))  # (2, window_size, window_size)
        coords_flatten = coords.flatten(1)  # (2, M^2)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # (2, M^2, M^2)
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # (M^2, M^2, 2)
        relative_coords[:, :, 0] += self.window_size - 1  # shift to 0-based index
        relative_coords[:, :, 1] += self.window_size - 1
        relative_coords[:, :, 0] *= 2 * self.window_size - 1
        relative_position_index = relative_coords.sum(-1)  # (M^2, M^2)
        return relative_position_index

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: (num_windows*B, M^2, D)
        Returns:
            out: (num_windows*B, M^2, D)
        """
        B_, N, D = x.shape  # B_ = num_windows*B (written B*N in the comments below), N = M^2 tokens per window

        # Step 1: Project input to Q, K, V
        qkv = self.qkv(x)  # (B*N, N, 3*D)
        qkv = qkv.reshape(B_, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)  # (3, B*N, h, N, d)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each: (B*N, h, N, d)

        # Step 2: Scaled dot-product attention
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B*N, h, N, N)

        # Step 3: Add relative positional bias
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        relative_position_bias = relative_position_bias.view(N, N, -1).permute(2, 0, 1)  # (h, N, N)
        attn = attn + relative_position_bias.unsqueeze(0)  # (1, h, N, N)

        # Step 4: Softmax and apply attention
        attn = attn.softmax(dim=-1)  # (B*N, h, N, N)
        out = (attn @ v)  # (B*N, h, N, d)

        # Step 5: Merge heads and project
        out = out.transpose(1, 2).reshape(B_, N, D)  # (B*N, N, D)
        out = self.proj(out)  # final linear projection

        return out

### MLP (Multi-Layer Perceptron)

The **MLP** (Multi-Layer Perceptron) class implements a feed-forward network with two linear layers, a **GELU** activation between them, and **dropout** after the activation and after the second linear layer.

In [49]:
class MLP(nn.Module):
    def __init__(self, dim, hidden_dim=None, dropout=0.0):
        super().__init__()
        hidden_dim = hidden_dim or 4 * dim
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

### SwinTransformerBlock

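The `SwinTransformerBlock` is the main processing unit. Each block applies one of two attention variants, selected by `shift_size`:
- **W-MSA** (Window-based Multi-head Self-Attention) when `shift_size = 0`
- **SW-MSA** (Shifted Window Multi-head Self-Attention) when `shift_size > 0`

Consecutive blocks alternate between the two, so shifted windows let information flow between neighboring windows. Note that this simplified block performs the cyclic shift but does not apply the attention mask from Section 3.3, so tokens wrapped around the border can attend to each other.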

In [50]:
class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, mlp_ratio=4.0):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_heads = num_heads

        # Norm and attention
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(
            dim, window_size=window_size, num_heads=num_heads,
            shift_size=shift_size, input_resolution=input_resolution
        )

        # MLP part
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, hidden_dim=int(dim * mlp_ratio))

    def forward(self, x):
        """
        x: (B, H, W, C)
        """
        H, W = self.input_resolution
        B, _, _, C = x.shape
        shortcut = x

        # Step 1: cyclic shift if needed
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # Step 2: partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # (num_windows*B, ws, ws, C)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        # Step 3: attention
        attn_windows = self.attn(self.norm1(x_windows))  # (num_windows*B, ws*ws, C)

        # Step 4: merge windows back
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        x = window_reverse(attn_windows, self.window_size, H, W)  # (B, H, W, C)

        # Step 5: reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))

        # Step 6: Residual + MLP
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x

### PatchMerging

The `PatchMerging` class is used for **downsampling**. It concatenates the features of each 2×2 group of neighboring patches, normalizes them, and applies a linear layer, halving the spatial resolution while doubling the feature dimension.

In [51]:
class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim):
        super().__init__()
        self.input_resolution = input_resolution
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x):
        """
        x: (B, H, W, C)
        """
        H, W = self.input_resolution
        B, _, _, C = x.shape

        x = x.view(B, H // 2, 2, W // 2, 2, C)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H // 2, W // 2, 4 * C)
        x = self.norm(x)
        x = self.reduction(x)
        return x

### SwinTransformer

The `SwinTransformer` class combines:
- Patch Partition
- Embedding
- Two Stages:
  - Stage 1: W-MSA + SW-MSA + Patch Merging
  - Stage 2: W-MSA + SW-MSA
- Global pooling + classification head

Patch merging between the two stages halves the spatial resolution (8×8 → 4×4) and doubles the feature dimension (32 → 64).
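
The tensor shapes through the model (ignoring the batch dimension) can be traced with a few lines of arithmetic. The sketch below assumes the default arguments of the `SwinTransformer` class that follows; `swin_shape_trace` is an illustrative helper, not part of the notebook.

```python
def swin_shape_trace(img_size=32, patch_size=4, in_chans=1, embed_dim=32, num_classes=10):
    """Return (stage name, shape) pairs for one image passing through the two-stage model."""
    trace = [("input", (in_chans, img_size, img_size))]
    side = img_size // patch_size
    patch_dim = patch_size * patch_size * in_chans
    trace.append(("patch partition", (side, side, patch_dim)))
    trace.append(("linear embedding", (side, side, embed_dim)))
    trace.append(("stage 1 (W-MSA + SW-MSA)", (side, side, embed_dim)))
    side, embed_dim = side // 2, embed_dim * 2  # patch merging
    trace.append(("patch merging", (side, side, embed_dim)))
    trace.append(("stage 2 (W-MSA + SW-MSA)", (side, side, embed_dim)))
    trace.append(("norm + average pool", (embed_dim,)))
    trace.append(("classification head", (num_classes,)))
    return trace

trace = swin_shape_trace()
for name, shape in trace:
    print(f"{name:28s} {shape}")
```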

In [52]:
class SwinTransformer(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_chans=1, embed_dim=32, num_classes=10):
        super().__init__()
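        # Length of one flattened patch (4 * 4 * 1 = 16 with the defaults), derived from the arguments.
        patch_dim = patch_size * patch_size * in_chans
        self.patch_embed = PatchPartition(in_channels=in_chans, height=img_size, width=img_size,
                                          patch_size=patch_size, embed_dim=patch_dim)
        self.linear_embed = LinearEmbedding(from_dim=patch_dim, to_dim=embed_dim)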

        H = W = img_size // patch_size  # 32 / 4 = 8
        self.stage1_block1 = SwinTransformerBlock(
            dim=embed_dim, input_resolution=(H, W), num_heads=2, window_size=2, shift_size=0
        )
        self.stage1_block2 = SwinTransformerBlock(
            dim=embed_dim, input_resolution=(H, W), num_heads=2, window_size=2, shift_size=1
        )

        # Apply Patch Merging: 8×8 → 4×4
        self.patch_merging = PatchMerging(input_resolution=(H, W), dim=embed_dim)
        H, W = H // 2, W // 2
        embed_dim *= 2

        self.stage2_block1 = SwinTransformerBlock(
            dim=embed_dim, input_resolution=(H, W), num_heads=2, window_size=2, shift_size=0
        )
        self.stage2_block2 = SwinTransformerBlock(
            dim=embed_dim, input_resolution=(H, W), num_heads=2, window_size=2, shift_size=1
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)  # (B, H//p, W//p, C)
        x = self.linear_embed(x)

        x = self.stage1_block1(x)
        x = self.stage1_block2(x)

        x = self.patch_merging(x)  # downsampling
        x = self.stage2_block1(x)
        x = self.stage2_block2(x)

        B, H, W, C = x.shape
        x = self.norm(x.view(B, H * W, C))
        x = self.avgpool(x.transpose(1, 2)).squeeze(-1)
        x = self.head(x)
        return x

## Train the model on resized MNIST

In [53]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SwinTransformer().to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
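
With `step_size=5` and `gamma=0.5`, the learning rate is halved after every fifth epoch. The sketch below reproduces that schedule in plain Python, assuming `scheduler.step()` is called once at the end of each epoch; `step_lr` is an illustrative helper, not a PyTorch function.

```python
def step_lr(base_lr, epoch, step_size=5, gamma=0.5):
    """Learning rate in effect during the given 0-based epoch under a StepLR-style schedule."""
    return base_lr * gamma ** (epoch // step_size)

schedule = [step_lr(1e-3, epoch) for epoch in range(10)]
for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:.1e}")
```

Epochs 1 to 5 run at `1e-3` and epochs 6 to 10 at `5e-4`.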

### Train function

In [54]:
def train(model, loader, optimizer, criterion, device):
    model.train()
    total_loss, total_correct = 0, 0

    for images, labels in tqdm(loader, desc="Training"):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss: Tensor = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * images.size(0)
        total_correct += (outputs.argmax(1) == labels).sum().item()

    avg_loss = total_loss / len(loader.dataset)
    acc = total_correct / len(loader.dataset)
    return avg_loss, acc


### Evaluate function

In [55]:
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, total_correct = 0, 0

    with torch.no_grad():
        for images, labels in tqdm(loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * images.size(0)
            total_correct += (outputs.argmax(1) == labels).sum().item()

    avg_loss = total_loss / len(loader.dataset)
    acc = total_correct / len(loader.dataset)
    return avg_loss, acc


In [56]:
num_epochs = 10

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    
    train_loss, train_acc = train(model, dataloader_train, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, dataloader_test, criterion, device)
    
    scheduler.step()

    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc*100:.2f}%")
    print(f"Val   Loss: {test_loss:.4f} | Val   Acc: {test_acc*100:.2f}%")


Epoch 1/10


Training: 100%|██████████| 938/938 [00:16<00:00, 55.79it/s]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 92.22it/s]


Train Loss: 0.4814 | Train Acc: 84.34%
Val   Loss: 0.1425 | Val   Acc: 95.63%

Epoch 2/10


Training: 100%|██████████| 938/938 [00:16<00:00, 56.38it/s]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 95.23it/s]


Train Loss: 0.1294 | Train Acc: 96.12%
Val   Loss: 0.1263 | Val   Acc: 95.90%

Epoch 3/10


Training: 100%|██████████| 938/938 [00:16<00:00, 58.31it/s]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 94.61it/s]


Train Loss: 0.0993 | Train Acc: 96.91%
Val   Loss: 0.0929 | Val   Acc: 97.02%

Epoch 4/10


Training: 100%|██████████| 938/938 [00:16<00:00, 56.21it/s]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 96.35it/s]


Train Loss: 0.0793 | Train Acc: 97.52%
Val   Loss: 0.0710 | Val   Acc: 97.73%

Epoch 5/10


Training: 100%|██████████| 938/938 [00:16<00:00, 56.17it/s]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 93.31it/s]


Train Loss: 0.0695 | Train Acc: 97.85%
Val   Loss: 0.0818 | Val   Acc: 97.40%

Epoch 6/10


Training: 100%|██████████| 938/938 [00:16<00:00, 55.97it/s]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 97.91it/s]


Train Loss: 0.0409 | Train Acc: 98.73%
Val   Loss: 0.0482 | Val   Acc: 98.34%

Epoch 7/10


Training: 100%|██████████| 938/938 [00:16<00:00, 56.39it/s]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 93.72it/s]


Train Loss: 0.0355 | Train Acc: 98.88%
Val   Loss: 0.0560 | Val   Acc: 98.25%

Epoch 8/10


Training: 100%|██████████| 938/938 [00:16<00:00, 56.61it/s]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 97.97it/s]


Train Loss: 0.0315 | Train Acc: 99.00%
Val   Loss: 0.0475 | Val   Acc: 98.44%

Epoch 9/10


Training: 100%|██████████| 938/938 [00:16<00:00, 56.71it/s]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 92.07it/s]


Train Loss: 0.0292 | Train Acc: 99.08%
Val   Loss: 0.0444 | Val   Acc: 98.53%

Epoch 10/10


Training: 100%|██████████| 938/938 [00:16<00:00, 56.01it/s]
Evaluating: 100%|██████████| 157/157 [00:01<00:00, 94.56it/s]

Train Loss: 0.0254 | Train Acc: 99.23%
Val   Loss: 0.0497 | Val   Acc: 98.56%



