# A demonstration of the patch-merging (2x2 downsampling) layer from the Swin Transformer architecture

In [2]:
import torch
import torch.nn as nn

class PatchMerging(nn.Module):
    """Swin Transformer patch-merging (downsampling) layer.

    Every non-overlapping 2x2 window of tokens is concatenated along the
    channel axis (C -> 4C), layer-normalised, and linearly projected to 2C.
    A (B, H, W, C) feature map therefore becomes
    (B, ceil(H / 2), ceil(W / 2), 2C).

    Args:
        dim: Number of input channels C.
    """

    def __init__(self, dim):
        super().__init__()
        self.reduction = nn.Linear(4 * dim, 2 * dim)
        self.norm = nn.LayerNorm(4 * dim)

    def forward(self, x):
        """Merge each 2x2 patch neighbourhood into one token.

        Args:
            x: Tensor of shape (B, H, W, C) with C == dim.

        Returns:
            Tensor of shape (B, ceil(H / 2), ceil(W / 2), 2 * dim).
        """
        H, W = x.shape[1], x.shape[2]
        # With an odd H or W, the even-index and odd-index strided slices below
        # have different lengths (e.g. H=5 -> 3 even rows, 2 odd rows), so
        # torch.cat would fail. Zero-pad one extra bottom row / right column,
        # matching the official Swin Transformer implementation.
        if H % 2 or W % 2:
            # Pad tuple runs from the last dim backwards:
            # (C_before, C_after, W_before, W_after, H_before, H_after)
            x = nn.functional.pad(x, (0, 0, 0, W % 2, 0, H % 2))

        # Step 1: gather the four tokens of every 2x2 window, each (B, H/2, W/2, C)
        x0 = x[:, 0::2, 0::2, :]  # even row, even col (top-left)
        x1 = x[:, 1::2, 0::2, :]  # odd row, even col (bottom-left)
        x2 = x[:, 0::2, 1::2, :]  # even row, odd col (top-right)
        x3 = x[:, 1::2, 1::2, :]  # odd row, odd col (bottom-right)

        # Step 2: channel concatenation -> (B, H/2, W/2, 4C)
        x = torch.cat([x0, x1, x2, x3], -1)

        # Step 3: normalise over 4C, then project 4C -> 2C -> (B, H/2, W/2, 2C)
        x = self.norm(x)
        x = self.reduction(x)
        return x


# Demo input in channels-last layout: batch B, height H, width W, channels C.
B, H, W, C = 1, 4, 4, 8
x = torch.randn((B, H, W, C))

# Build a merging layer for C input channels and run it once; it should
# halve both spatial dimensions and double the channel count.
patch_merging = PatchMerging(dim=C)
output = patch_merging(x)

print("Output shape:", output.shape)  # Expected shape: (B, H/2, W/2, 2C)

Output shape: torch.Size([1, 2, 2, 16])
