In [50]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.transforms import RandAugment
from torch import nn, optim

image_size = 224  

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(mean=([0.4802, 0.4481, 0.3975]),   # normalize using mean & std
                         std=([0.2296, 0.2263, 0.2255])),
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random'),
])

val_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=([0.4802, 0.4481, 0.3975]),   # normalize using mean & std
                         std=([0.2296, 0.2263, 0.2255])),
])

data_dir = 'tiny-imagenet-200/train'
full_dataset = datasets.ImageFolder(root=data_dir, transform=train_transform)

train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

val_dataset.dataset.transform = val_transform

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)


In [51]:
img, label = train_dataset[0]
img.shape

torch.Size([3, 224, 224])

In [52]:
torch.zeros(64, 128, 8, 8).flatten(2).shape

torch.Size([64, 128, 64])

Swin-T: C = 96, layer numbers = {2, 2, 6, 2}

In [53]:
class Patchify(nn.Module):
    """
    Converts an input image into a sequence of flattened patch embeddings

    Args:
        in_channels: Number of input channels (3 RGB).
        embed_dim: Dimensionality of output patch embeddings.
        patch_size: Size of each square patch

    Returns:
        Tensor of shape (BS, N_patches, embed_dim)
    """
    def __init__(self, in_channels, embed_dim=96, patch_size=4):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        # x shape = (B, 3, H, W)
        x = self.conv(x)         # (BS, embed_dim, H//patch_size, W//patch_size)
        x = x.flatten(2)         # (BS, embed_dim, N_patches)
        return x.transpose(1, 2) # (BS, N_patches, embed_dim)


In [54]:
def split_into_windows(x, M):
    """
    Splits (BS, H, W, channels) into non overlapping MxM windows.
    Args:
        x: Tensor of shape (BS, H, W, channels)
        M: Window size

    Returns:
        Tensor of shape (BS * num_windows, M*M, channels)
    """
    BS, H, W, channels = x.shape
    x = x.reshape(BS, H//M, M, W//M, M, channels)
    # Permute fixes the order so the MxM pixels are together properly
    x = x.permute(0, 1, 3, 2, 4, 5)     # (BS, H//M, W//M, M, M, channels)
    return x.reshape(-1, M*M, channels) # (BS * num_windows, M*M, channels)

In [55]:
def reverse_windows(x, M, H, W, channels):
    """
    Reverses MxM window tokens back into the original layout.

    Args:
        x: Tensor of shape (BS * num_windows, M*M, channels)
        M: Window size
        H: Original image height
        W: Original image width
        channels: Number of channels

    Returns:
        Tensor of shape (BS, H, W, channels)
    """
    BS = x.shape[0] // (H//M * W//M) # Original BS
    x = x.reshape(BS, H//M, W//M, M, M, channels)
    x = x.permute(0, 1, 3, 2, 4, 5)      # (BS, H//M, M, W//M, M, channels)
    return x.reshape(BS, H, W, channels) # (BS, H, W, channels)

Relative Position Bias computes a learnable bias between every pair of tokens in a window, based on how fat apart they are. For each pair of tokens in an M x M window, compute relative position, use that to index into a learnable bias table, add this bias to the attention logits.

In [56]:
class RelativePositionBias(nn.Module):
    """
    Computes relative position bias for self-attention.

    Generates a table of learnable relative position biases between token pairs
    within an attention window of shape (M x M). The relative position between
    any two tokens is encoded as a bias vector per attention head, and these biases are
    added to the attention scores in self-attention.

    Args:
        M: The height/width of the attention window. The total number of tokens is M * M.
        nheads: Number of attention heads.

    Attributes:
        relative_table: Learnable parameter of shape ((2M - 1)^2, nheads),
            where each entry represents a bias value for a specific relative position and head.
        relative_index: Lookup table of shape (M*M, M*M), where each entry is an index
            into relative_table that maps the relative position between two tokens to a bias vector.

    Forward Output:
        Tensor of shape (nheads, M*M, M*M) containing the relative bias for each token pair
        and each attention head. This can be directly added to attention logits.

    Example:
        relative = RelativePositionBias(M=3, nheads=4)
        bias = relative()  # Output shape: (4, 9, 9), for 4 heads and 3x3 tokens
    """
    def __init__(self, M, nheads):
        super().__init__()
        self.M = M
        self.nheads = nheads
        # (2M-1)^2 because there are up to M-1 tokens above or below or left or right of each token.
        # 2M - 1 possibilities for above and below, same for left and right. So (2M -1)^2 total.
        # If M = 3 row and col would go between -2 and +2 when comparing two tokens.
        # That gives 5^2 possible combinations. len(-2, -1, 0, 1, 2)^2
        self.relative_table = nn.Parameter(torch.zeros(size=((2*M - 1) * (2*M -1), nheads)))

        # Coordinate grid of token positions lets us know where each token is in the window.
        # It gives every token a (row, col) coordinate.
        # If M = 3: coords[0] (rows): [[0, 0, 0], [1, 1, 1], [2, 2, 2]], 
        #           coords[1] (cols): [[0, 1, 2], [0, 1, 2], [0, 1, 2]]
        coords = torch.stack(torch.meshgrid( # Matrix style over Cartesian
            torch.arange(M), torch.arange(M), indexing='ij'))  # (2, M, M)
        
        # Flatten the coordinates for token indices, so coords[:, i] is the (row, col) of token i
        coords = coords.flatten(1) # (2, M*M)

        # Compute relative positions so we have the position of a token relative to another token
        # coords[:, :, None] shape: (2, M*M, 1), coords[:, None, :] shape: (2, 1, M*M)
        relative = coords[:, :, None] - coords[:, None, :] # (2, M*M, M*M)

        # Reformat so we can use each (row, col) as an index into a table but row/col values
        # range from -(M-1) to (M-1) so we shift them up so they are positive: [0, 2M -2]
        relative = relative.permute(1, 2, 0) # (M*M, M*M, 2)
        relative[:, :, 0] += M - 1
        relative[:, :, 1] += M - 1

        # Flatten 2D positions into 1D. To convert: row * num_cols + col
        self.register_buffer(   # Register buffer to move to GPU
            "relative_index",
            (relative[:, :, 0] * (2*M - 1) + relative[:, :, 1]).long() # (M*M, M*M)
        )
    def forward(self): 
        # Use index to get bias values, look up the bias vector for each token pair
        bias = self.relative_table[self.relative_index.view(-1)] # (M*M * M*M, nheads)
        bias = bias.reshape(self.relative_index.shape[0], self.relative_index.shape[1], self.nheads)
        return bias.permute(2, 0, 1) # (nheads, M*M, M*M)


In [57]:
class WindowAttention(nn.Module):
    """
    Window-based Multihead Self-Attention with relative position bias.

    Performs self-attention within non overlapping MxM windows of the input feature map.
    It incorporates relative positional encoding and attention masks for shifted windows

    Attention(Q, K, V) = Softmax((QKᵀ / √d) + B + mask) @ V

    Args:
        channels: Input channels
        M: Height and width of the attention window.
        nheads: Number of attention heads.

    Attributes:
        q, k, v: Linear layers for queries, keys, and values.
        out: Output linear layer after attention.
    """
    def __init__(self, channels, M, nheads):
        super().__init__()
        self.M = M
        self.nheads = nheads
        self.rootd = (channels // nheads) ** -0.5 # (1 / √d) == (1 / √dim_per_head)

        self.q = nn.Linear(channels, channels)
        self.k = nn.Linear(channels, channels)
        self.v = nn.Linear(channels, channels)
        self.out = nn.Linear(channels, channels)

        self.relative = RelativePositionBias(M, nheads)
    
    def forward(self, x, attn_mask=None):
        """
        Forward pass for window based self-attention.

        Args:
            x: shape (B * nW, M*M, channels) where nW is number of windows
            attn_mask: Attention mask used for shifted windows to prevent cross-window attention.

        Returns:
            Tensor of shape (B*nW, M*M, channels)
        """
        # x shape = (B * nW, M*M, channels) where nW is number of windows
        BnW, M_sq, channels = x.shape # M_sq is M*M

        # d stands for dim_per_head. Permute on k so no transpose when computing attn.
        # q: (BnW, M*M, channels) --> (BnW, M*M, nheads, d) --> transpose(1, 2) --> (BnW, heads, M*M, d)
        # k: (BnW, M*M, channels) --> (BnW, M*M, nheads, d) --> permute(0, 2, 3, 1) --> (BnW, heads, d, M*M)
        # v: (BnW, M*M, channels) --> (BnW, M*M, nheads, d) --> transpose(1, 2) --> (BnW, heads, M*M, d)
        q = self.q(x).reshape(BnW, M_sq, self.nheads, channels // self.nheads).transpose(1, 2)
        k = self.k(x).reshape(BnW, M_sq, self.nheads, channels // self.nheads).permute(0, 2, 3, 1)
        v = self.v(x).reshape(BnW, M_sq, self.nheads, channels // self.nheads).transpose(1, 2)

        # Attention: (Q @ Kᵀ) / √d + relative bias + optional mask
        # k is already transposed.
        attn = (q @ k) * self.rootd # (BnW, nheads, M*M, M*M)
        attn = attn + self.relative()

        if attn_mask is not None:
            # attn_mask: (M*M, M*M) --> unsqueeze(0) --> (1, M*M, M*M)
            # Broadcasted to (BnW, nheads, M*M, M*M)
            attn = attn + attn_mask.unsqueeze(0)
        
        attn = attn.softmax(dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(BnW, M_sq, channels)
        return self.out(out)
        

In [62]:
def create_attention_mask(H, W, M, shift):
    if shift == 0:
        return None
    
    img_mask = torch.zeros((1, H, W, 1))  # Mask

    count = 0
    H, W = img_mask.shape[1:3]

    # Split image into 9 regions
    h_ranges = [(0, H - M), (H - M, H - shift), (H - shift, H)]
    w_ranges = [(0, W - M), (W - M, W - shift), (W - shift, W)]

    # so if H = W = 12, M = 6, shift = 3
    # h_ranges = [(0, 6), (6, 9), (9, 12)]
    # w_ranges = [(0, 6), (6, 9), (9, 12)]

    # Fill each region with a unique integer
    for h_start, h_end in h_ranges:
        for w_start, w_end in w_ranges:
            img_mask[:, h_start:h_end, w_start:w_end, :] = count
            count += 1
            
    # Cyclic shift the mask
    img_mask = torch.roll(img_mask, shifts=(-shift, -shift), dims=(1,2))

    # Split into M*M windows
    mask_windows = split_into_windows(img_mask, M)  # (nW, M*M, 1)
    mask_windows = mask_windows.squeeze(-1)       # (nW, M*M)
    # Create attention mask
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # (nW, M*M, M*M)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float('-inf')).masked_fill(attn_mask == 0, 0.0)
    return attn_mask


In [64]:
create_attention_mask(12, 12, 6, 3)

tensor([[[0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         ...,
         [-inf, -inf, -inf,  ..., 0., 0., 0.],
         [-inf, -inf, -inf,  ..., 0., 0., 0.],
         [-inf, -inf, -inf,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         ...,
         [-inf, -inf, -inf,  ..., 0., 0., 0.],
         [-inf, -inf, -inf,  ..., 0., 0., 0.],
         [-inf, -inf, -inf,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         ...,
         [-inf, -inf, -inf,  ..., 0., 0., 0.],
         [-inf, -inf, -inf,  ..., 0., 0., 0.],
         [-inf, -inf, -inf,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0.

Input image: B × 3 × H × W

Patch size: P = 4

Window size: M = 7

Base embedding dim (stage 1): C = 96 (Swin-T)

Depths per stage (Swin-T): [2, 2, 6, 2]

Heads per stage: [3, 6, 12, 24] ⇒ per-head dim is ≈ 32

MLP expansion: α = 4

Shift size in SW-MSA: s = M // 2 = 3 (for M=7)

After patch embed, the feature map resolution is H' = H/P, W' = W/P.