# Data Preparation and Augmentation

Data augmentation techniques include RandomResizedCrop, RandomHorizontalFlip, RandAugment (applies random augmentations), and RandomErasing (randomly erases a rectangular region in an image). The pixel values are also normalized using the mean and standard deviation of the dataset. I found this in my ResNet-34 from scratch project.

Dataset is then split into a 90% training set and a 10% validation set.

In [1]:
import os, shutil, csv, pathlib

ROOT = "/kaggle/input/tiny-imagenet-200/tiny-imagenet-200"
VAL_DIR = os.path.join(ROOT, "val")
VAL_ANN = os.path.join(VAL_DIR, "val_annotations.txt")
VAL_IMAGES = os.path.join(VAL_DIR, "images")

# Where to build an ImageFolder-compatible val/ structure
OUT_VAL = "/kaggle/working/tiny-imagenet-200-val"

os.makedirs(OUT_VAL, exist_ok=True)

# Build mapping: filename -> wnid (class folder)
fname_to_wnid = {}
with open(VAL_ANN, "r") as f:
    reader = csv.reader(f, delimiter='\t')
    for row in reader:
        fname, wnid = row[0], row[1]
        fname_to_wnid[fname] = wnid

# Copy images into OUT_VAL/<wnid>/<filename>
for fname, wnid in fname_to_wnid.items():
    src = os.path.join(VAL_IMAGES, fname)
    dst_dir = os.path.join(OUT_VAL, wnid)
    os.makedirs(dst_dir, exist_ok=True)
    dst = os.path.join(dst_dir, fname)
    if not os.path.exists(dst):
        shutil.copyfile(src, dst)


In [2]:
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
from torchvision.transforms import RandAugment
from torch import nn, optim
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

image_size = 64

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(image_size, scale=(0.9, 1.0)),
    transforms.RandomHorizontalFlip(),
    RandAugment(num_ops=2, magnitude=5),
    transforms.ColorJitter(0.2, 0.2, 0.2, 0.05),
    transforms.ToTensor(),
    transforms.Normalize(mean=([0.4802, 0.4481, 0.3975]),   # normalize using mean & std
                         std=([0.2296, 0.2263, 0.2255])),
    transforms.RandomErasing(p=0.1, scale=(0.02, 0.2), value='random'),
])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=([0.4802, 0.4481, 0.3975]),   # normalize using mean & std
                         std=([0.2296, 0.2263, 0.2255])),
])

train_dir = os.path.join(ROOT, "train")
val_dir   = OUT_VAL

train_set = datasets.ImageFolder(train_dir, transform=train_transform)
val_set   = datasets.ImageFolder(val_dir,   transform=val_transform)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True,
                          num_workers=4, pin_memory=True, persistent_workers=True, prefetch_factor=2)
val_loader   = DataLoader(val_set, batch_size=64, shuffle=False,
                          num_workers=4, pin_memory=True, persistent_workers=True, prefetch_factor=2)

Using device: cuda


In [3]:
xb, yb = next(iter(val_loader))

print("Val batch images shape:", xb.shape)   # should be (B, 3, H, W)
print("Val batch labels shape:", yb.shape)   # should be (B,)
print("Sample labels:", yb[:10].tolist())

# Verify ranges and types
assert xb.ndim == 4 and xb.size(1) == 3, "Images should have shape (B, 3, H, W)"
assert torch.is_floating_point(xb), "Images should be float tensors"
assert yb.ndim == 1 and yb.dtype == torch.long, "Labels should be 1D LongTensor"
assert xb.min() >= -5 and xb.max() <= 5, "Values look off; check normalization"

Val batch images shape: torch.Size([64, 3, 64, 64])
Val batch labels shape: torch.Size([64])
Sample labels: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]


## Patching and Splitting Windows
Unlike traditional CNNs that use sliding convolutional filters, Vision Transformers break the image down into a sequence of patches, treating them similarly to words in a sentence.

Patchify Class: Takes an input image and converts it into patch embeddings. It uses a single convolutional layer where the kernel size and stride are equal to the patch_size. It divides the image into non overlapping patches and creating an initial vector embedding for each one.

split_into_windows: Takes the patches and splits them into smaller windows. Self-attention is calculated within these windows, which is far more computationally efficient than the original ViT's approach of global attention across all patches.

reverse_windows: Merges the split windows back into their original spatial layout.

In [4]:
class Patchify(nn.Module):
    """
    Convert an image into patch embeddings using a convolutional layer.

    Args:
        in_channels: Number of input channels (e.g., 3 for RGB).
        embed_dim: Output embedding dimension per patch.
        patch_size: Size of each square patch

    Returns:
        Tensor of shape (BS, H//patch_size, W//patch_size, embed_dim)
    """
    def __init__(self, in_channels, embed_dim, patch_size):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        # x shape = (B, 3, H, W)
        x = self.conv(x)         # (BS, embed_dim, H//patch_size, W//patch_size)
        return x.permute(0, 2, 3, 1) # (BS, H//patch_size, W//patch_size, embed_dim)


In [5]:
def split_into_windows(x, M):
    """
    Splits (BS, H, W, channels) into non overlapping MxM windows.
    Args:
        x: Tensor of shape (BS, H, W, channels)
        M: Window size

    Returns:
        Tensor of shape (BS * num_windows, M*M, channels)
    """
    BS, H, W, channels = x.shape
    x = x.reshape(BS, H//M, M, W//M, M, channels)
    # Permute fixes the order so the MxM pixels are together properly
    x = x.permute(0, 1, 3, 2, 4, 5)     # (BS, H//M, W//M, M, M, channels)
    return x.reshape(-1, M*M, channels) # (BS * num_windows, M*M, channels)


In [6]:
def reverse_windows(x, M, H, W, channels):
    """
    Reverses MxM window tokens back into the original layout.

    Args:
        x: Tensor of shape (BS * num_windows, M*M, channels)
        M: Window size
        H: Original image height
        W: Original image width
        channels: Number of channels

    Returns:
        Tensor of shape (BS, H, W, channels)
    """
    BS = x.shape[0] // (H//M * W//M) # Original BS
    x = x.reshape(BS, H//M, W//M, M, M, channels)
    x = x.permute(0, 1, 3, 2, 4, 5)      # (BS, H//M, M, W//M, M, channels)
    return x.reshape(BS, H, W, channels) # (BS, H, W, channels)

## Relative Position Bias

Relative Position Bias shows self-attention mechanism about the geometry of the image. It computes a learnable bias between every pair of tokens in a window, based on how far apart they are. For each pair of tokens in an M x M window, compute relative position, use that to index into a learnable bias table, add this bias to the attention logits.

In [7]:
class RelativePositionBias(nn.Module):
    """
    Computes relative position bias for self-attention.

    Generates a table of learnable relative position biases between token pairs
    within an attention window of shape (M x M). The relative position between
    any two tokens is encoded as a bias vector per attention head, and these biases are
    added to the attention scores in self-attention.

    Args:
        M: The height/width of the attention window. The total number of tokens is M * M.
        nheads: Number of attention heads.

    Attributes:
        relative_table: Learnable parameter of shape ((2M - 1)^2, nheads),
            where each entry represents a bias value for a specific relative position and head.
        relative_index: Lookup table of shape (M*M, M*M), where each entry is an index
            into relative_table that maps the relative position between two tokens to a bias vector.

    Forward Output:
        Tensor of shape (nheads, M*M, M*M) containing the relative bias for each token pair
        and each attention head. This can be directly added to attention logits.

    Example:
        relative = RelativePositionBias(M=3, nheads=4)
        bias = relative()  # Output shape: (4, 9, 9), for 4 heads and 3x3 tokens
    """
    def __init__(self, M, nheads):
        super().__init__()
        self.M = M
        self.nheads = nheads
        # (2M-1)^2 because there are up to M-1 tokens above or below or left or right of each token.
        # 2M - 1 possibilities for above and below, same for left and right. So (2M -1)^2 total.
        # If M = 3 row and col would go between -2 and +2 when comparing two tokens.
        # That gives 5^2 possible combinations. len(-2, -1, 0, 1, 2)^2
        self.relative_table = nn.Parameter(torch.zeros(size=((2*M - 1) * (2*M -1), nheads)))

        # Coordinate grid of token positions shows where each token is in the window.
        # It gives every token a (row, col) coordinate.
        # If M = 3: coords[0] (rows): [[0, 0, 0], [1, 1, 1], [2, 2, 2]], 
        #           coords[1] (cols): [[0, 1, 2], [0, 1, 2], [0, 1, 2]]
        coords = torch.stack(torch.meshgrid( # Matrix style over Cartesian
            torch.arange(M), torch.arange(M), indexing='ij'))  # (2, M, M)
        
        # Flatten the coordinates for token indices, so coords[:, i] is the (row, col) of token i
        coords = coords.flatten(1) # (2, M*M)

        # Compute relative positions so we have the position of a token relative to another token
        # coords[:, :, None] shape: (2, M*M, 1), coords[:, None, :] shape: (2, 1, M*M)
        relative = coords[:, :, None] - coords[:, None, :] # (2, M*M, M*M)

        # Reformat so we can use each (row, col) as an index into a table but row/col values
        # range from -(M-1) to (M-1) so we shift them up so they are positive: [0, 2M -2]
        relative = relative.permute(1, 2, 0) # (M*M, M*M, 2)
        relative[:, :, 0] += M - 1
        relative[:, :, 1] += M - 1

        # Flatten 2D positions into 1D. To convert: row * num_cols + col
        self.register_buffer(   # Register buffer to move to GPU
            "relative_index",
            (relative[:, :, 0] * (2*M - 1) + relative[:, :, 1]).long() # (M*M, M*M)
        )
    def forward(self): 
        # Use index to get bias values, look up the bias vector for each token pair
        bias = self.relative_table[self.relative_index.view(-1)] # (M*M * M*M, nheads)
        bias = bias.reshape(self.relative_index.shape[0], self.relative_index.shape[1], self.nheads)
        return bias.permute(2, 0, 1) # (nheads, M*M, M*M)


## Windowed Multi-Head Self-Attention (W-MSA)

Implements Windowed Multi-head Self-Attention, which is a more efficient version of the standard attention used in ViT. Instead of calculating attention across all patches in the entire image, W-MSA computes attention within M x M windows. Significantly reduces the number of calculations needed.

Attention(Q, K, V) = Softmax((QK.T / √d) + B + mask) @ V

In [8]:
class WindowAttention(nn.Module):
    """
    Window-based Multihead Self-Attention with relative position bias.

    Performs self-attention within non overlapping MxM windows of the input feature map.
    It incorporates relative positional encoding and attention masks for shifted windows

    Attention(Q, K, V) = Softmax((QK.T / √d) + B + mask) @ V
    B is relative position bias.
    Args:
        channels: Input channels
        M: Height and width of the attention window.
        nheads: Number of attention heads.

    Attributes:
        q, k, v: Linear layers for queries, keys, and values.
        out: Output linear layer after attention.
    """
    def __init__(self, channels, M, nheads):
        super().__init__()
        self.M = M
        self.nheads = nheads
        self.rootd = (channels // nheads) ** -0.5 # (1 / √d) == (1 / √dim_per_head)

        self.q = nn.Linear(channels, channels)
        self.k = nn.Linear(channels, channels)
        self.v = nn.Linear(channels, channels)
        self.out = nn.Linear(channels, channels)

        self.relative = RelativePositionBias(M, nheads)
    
    def forward(self, x, attn_mask=None):
        """
        Forward pass for window based self-attention.

        Args:
            x: shape (B * nW, M*M, channels) where nW is number of windows
            attn_mask: Attention mask used for shifted windows to prevent cross-window attention.

        Returns:
            Tensor of shape (B*nW, M*M, channels)
        """
        # x shape = (B * nW, M*M, channels) where nW is number of windows
        BnW, M_sq, channels = x.shape # M_sq is M*M

        # d stands for dim_per_head. Permute on k so no transpose when computing attn.
        # q: (BnW, M*M, channels) --> (BnW, M*M, nheads, d) --> transpose(1, 2) --> (BnW, heads, M*M, d)
        # k: (BnW, M*M, channels) --> (BnW, M*M, nheads, d) --> permute(0, 2, 3, 1) --> (BnW, heads, d, M*M)
        # v: (BnW, M*M, channels) --> (BnW, M*M, nheads, d) --> transpose(1, 2) --> (BnW, heads, M*M, d)
        q = self.q(x).reshape(BnW, M_sq, self.nheads, channels // self.nheads).transpose(1, 2)
        k = self.k(x).reshape(BnW, M_sq, self.nheads, channels // self.nheads).permute(0, 2, 3, 1)
        v = self.v(x).reshape(BnW, M_sq, self.nheads, channels // self.nheads).transpose(1, 2)

        # Attention: (Q @ K.T) / √d + relative bias + optional mask
        # k is already transposed.
        attn = (q @ k) * self.rootd # (BnW, nheads, M*M, M*M)
        attn = attn + self.relative()

        if attn_mask is not None:
            # attn_mask: (M*M, M*M) --> unsqueeze(0) --> (1, M*M, M*M)
            # Broadcasted to (BnW, nheads, M*M, M*M)
            nW = attn_mask.shape[0]
            BS = BnW // nW
            attn_mask = attn_mask.to(attn.device)
            attn = attn.view(BS, nW, self.nheads, M_sq, M_sq) + attn_mask.unsqueeze(0).unsqueeze(2)
            attn = attn.view(BnW, self.nheads, M_sq, M_sq)
        
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(BnW, M_sq, channels)
        return self.out(out)
        

In [9]:
def create_attention_mask(H, W, M, shift):
    """
    Creates an attention mask for shifted window self-attention (SW-MSA).

    This function generates a mask to prevent tokens from attending across windows when 
    performing SW-MSA. It divides the feature map into distinct regions, assigns unique labels
    to each, uses cyclic shifting, partitions it into non overlapping windows, and then
    builds an attention mask that blocks attention between different labeled regions.

    Args:
        H: Height of the feature map.
        W: Width of the feature map.
        M: Window size.
        shift: Number of pixels to cyclically shift the window. 
               If shift is 0, no mask is needed. In Swin it is M // 2

    Returns:
        Tensor of shape (nW, M*M, M*M) where nW is the number of windows.
        Or None if shift is 0.
    """
    if shift == 0:
        return None
    
    img_mask = torch.zeros((1, H, W, 1))  # Mask

    count = 0
    H, W = img_mask.shape[1:3]

    # Split image into 9 regions
    h_ranges = [(0, H - M), (H - M, H - shift), (H - shift, H)]
    w_ranges = [(0, W - M), (W - M, W - shift), (W - shift, W)]

    # so if H = W = 12, M = 6, shift = 3
    # h_ranges = [(0, 6), (6, 9), (9, 12)]
    # w_ranges = [(0, 6), (6, 9), (9, 12)]

    # Fill each region with a unique integer
    for h_start, h_end in h_ranges:
        for w_start, w_end in w_ranges:
            img_mask[:, h_start:h_end, w_start:w_end, :] = count
            count += 1
            
    # Cyclic shift the mask
    img_mask = torch.roll(img_mask, shifts=(-shift, -shift), dims=(1,2))

    # Split into M*M windows
    mask_windows = split_into_windows(img_mask, M)  # (nW, M*M, 1)
    mask_windows = mask_windows.squeeze(-1)       # (nW, M*M)
    # Create attention mask
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)  # (nW, M*M, M*M)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float('-inf')).masked_fill(attn_mask == 0, 0.0)
    return attn_mask


In [10]:
create_attention_mask(12, 12, 6, 3)

tensor([[[0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         ...,
         [-inf, -inf, -inf,  ..., 0., 0., 0.],
         [-inf, -inf, -inf,  ..., 0., 0., 0.],
         [-inf, -inf, -inf,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         ...,
         [-inf, -inf, -inf,  ..., 0., 0., 0.],
         [-inf, -inf, -inf,  ..., 0., 0., 0.],
         [-inf, -inf, -inf,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         ...,
         [-inf, -inf, -inf,  ..., 0., 0., 0.],
         [-inf, -inf, -inf,  ..., 0., 0., 0.],
         [-inf, -inf, -inf,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., -inf, -inf, -inf],
         [0., 0., 0.,  ..., -inf, -inf, -inf],
         [0.

## Stochastic Depth

Stochastic depth is a regularization technique used to improve generalization and reduce overfitting. Instead of dropping individual neurons, entire residual branches are skipped during training with a given probability (drop_prob).

In [11]:
def stochastic_depth(x, drop_prob, training):
    """
    Applies stochastic depth to the input.

    Args:
        x: Input tensor
        drop_prob: Probability of dropping the path.
        training: If True, stochastic depth is applied. If False, input is returned unchanged.

    Returns:
        Output tensor with some residual paths zeroed out.
    """
    if drop_prob == 0.0 or not training:
        return x

    keep_prob = 1.0 - drop_prob
    # Create mask with shape (BS, 1, 1, ..., 1) so it broadcasts over all non batch dims
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    # 1 with prob keep_prob, 0 with prob drop_prob
    mask = torch.rand(shape, dtype=x.dtype, device=x.device) < keep_prob
    # Scale
    x = x / keep_prob
    return x * mask

class StochasticDepth(nn.Module):
    """
    Module for stochastic depth.
    """
    def __init__(self, drop_prob):
        super().__init__()
        self.drop_prob = float(drop_prob)

    def forward(self, x):
        return stochastic_depth(x, self.drop_prob, self.training)


## SwinBlock (Shifted Window Transformer Block)

Applies windowed self-attention over non overlapping M×M windows and alternates between non-shifted and shifted windows across consecutive blocks to enable cross-window connections. There is also an MLP with GELU at the end. Residual connections are used.

![Block Architecture](figures/block_arch.png)

In [12]:
class SwinBlock(nn.Module):
    """
    Swin Transformer block with shifted/non-shifted window self-attention + MLP.

    Args:
        dim: Channel dimension of the input features.
        H: Feature map height.
        W: Feature map width.
        nheads: Number of attention heads in WindowAttention.
        M: Window size.
        shift: Cyclic shift size (0 for non-shifted windows, M//2 for shifted).
        ratio: Expansion ratio for the MLP hidden size (hidden dim = ratio * dim).
        stoch_depth: stochastic depth probability for dropping residual branches.

    Attributes:
        norm1: Pre attention normalization.
        attn (WindowAttention): Window-based multi-head self-attention.
        drop_path: Stochastic depth module or identity if stoch_depth == 0.
        norm2: Pre MLP normalization.
        mlp: Two-layer feed forward network with GELU activation.
        attn_mask: Mask for shifted attention, shape (nW, M*M, M*M) when shift > 0, else None.

    Input:
        x: Tensor of shape (BS, H, W, channels).

    Output:
        Tensor of shape (BS, H, W, Channels), same spatial shape and channels as input.

    """
    def __init__(self, dim, H, W, nheads, M, shift, ratio, stoch_depth):
        super().__init__()
        self.dim = dim
        self.M = M
        self.shift = shift
        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(channels=dim, M=M, nheads=nheads)
        if stoch_depth > 0:
            self.stoch = StochasticDepth(stoch_depth)
        else:
            self.stoch = nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, int(dim * ratio)), # Error I had before, forgot to wrap with int()
            nn.GELU(),
            nn.Linear(int(dim * ratio), dim)
        )
        self.attn_mask = create_attention_mask(H, W, M, shift)
    
    def forward(self, x):
        """
        Forward pass of SwinBlock.

        LayerNorm --> cyclic shift --> split windows --> WindowAttention (masked if shifted)
        --> reverse windows --> reverse shift --> residual + stochastic depth --> LayerNorm
        --> MLP → residual +stochastic depth

        Returns:
            Tensor of shape (BS, H, W, channels).
        """
        BS, H, W, channels = x.shape

        store = x # For Residual connection
        x = self.norm1(x)

        if self.shift > 0:
            x = torch.roll(x, shifts=(-self.shift, -self.shift), dims=(1, 2)) # Cyclic shift
        x_windows = split_into_windows(x, self.M)                  # (BS*nW, M*M, channels)
        x_windows = self.attn(x_windows, attn_mask=self.attn_mask) # (BS*nW, M*M, channels)
        x = reverse_windows(x_windows, self.M, H, W, channels)     # (BS, H, W, channels)
        if self.shift > 0:
            x = torch.roll(x, shifts=(self.shift, self.shift), dims=(1, 2))

        x = store + self.stoch(x)

        return x + self.stoch(self.mlp(self.norm2(x)))

## Patch Merging Layer

In [13]:
class PatchMerging(nn.Module):
    """
    Reduces H and W by 2x in each dimension. Doubles the channel dimension.
    - Extract non overlapping 2x2 patches from the feature map.
    - Concatenate features from each patch along the channel dimension.
    - Apply LayerNorm for normalization across channels.
    - Lower 4*channels down to 2*channels with a Linear layer.

    Input:
        x: shape (BS, H, W, channels)

    Output:
        shape (BS, H/2, W/2, 2*channels)

    H and W should be even.
    """
    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(4*dim, 2*dim)
        self.norm = nn.LayerNorm(4*dim) # 4 * channels --> 2 * channels
    
    def forward(self, x):
        # x shape = (BS, H, W, channels)
        x1 = x[:, 0::2, 0::2, :]
        x2 = x[:, 1::2, 0::2, :]
        x3 = x[:, 0::2, 1::2, :]
        x4 = x[:, 1::2, 1::2, :]
        x = torch.cat([x1, x2, x3, x4], dim=3) # (BS, H/2, W/2, 4*channels)
        x = self.norm(x)
        return self.lin(x) # (BS, H/2, W/2, 2*channels)


## Stage (Stack of Swin Blocks + Optional Patch Merging)

A Stage stacks SwinBlocks, alternating between:
- W-MSA (non-shifted windows, shift=0)
- SW-MSA (shifted windows, shift=M//2)

Then optionally applies PatchMerging to downsample and increase channels for the next stage.
After PatchMerging, the next stage should be constructed with dim = 2 * previous_dim and H and W halved.

In [14]:
class Stage(nn.Module):
    """
    Args:
        dim: Channel dimension.
        H: Feature map height.
        W: Feature map width.
        M: Window size.
        blocks: Number of SwinBlocks in this stage.
        nheads: Number of attention heads per block.
        stoch_depth_list: List of stochastic depth probabilities (len == blocks).
        patch_merging: If True, apply PatchMerging at the end of the stage.
        ratio: MLP expansion ratio (hidden dim = ratio * dim).

    Input:
        x: Tensor of shape (BS, H, W, channels or dim)

    Output:
        - If patch_merging is False: (BS, H, W, dim)
        - If patch_merging is True:  (BS, H/2, W/2, 2*dim)
    """
    def __init__(self, dim, H, W, M, blocks, nheads, stoch_depth_list, patch_merging, ratio):
        # stoch_depth_list is a list of the stochastic depth rates for each SwinBlock.

        super().__init__()
        self.blocks = nn.ModuleList()
        for i in range(blocks):
            if i % 2 == 0:
                shift = 0  # Alternate between W-MSA and SW-MSA, W-MSA has no shift.
            else:
                shift = M // 2
            self.blocks.append(
                SwinBlock(dim, H, W, nheads, M, shift, ratio, stoch_depth_list[i])
            )
        if patch_merging:
            self.patch = PatchMerging(dim)
        else:
            self.patch = nn.Identity()
    
    def forward(self, x):
        # x shape = (BS, H, W, channels)
        for block in self.blocks:
            x = block(x)
        return self.patch(x)

## Parameters

- Patch size: 4 × 4
- Base embed dim C: 96
- Depths: [2, 2, 6, 2]
- Num heads: [3, 6, 12, 24]
- Window size: 7 for all blocks
- Shift: 3 (7 // 2) for all shift blocks
- MLP expantion ratio: 4.0 for all blocks
- Drop path rate: 0.2 (linearly increased across all blocks)
- Patch Merging / Downsample at the end of Stage 1, 2 and 3.


| Stage         | Blocks | Heads | Stoch_dep | In Channels | Out Channels | Output Shape             |
|---------------|--------|-------|-----------|-------------|--------------|--------------------------|
| PatchEmbed    | None   | None  | None           | 3           | 96           | (BS, H/4, W/4, 96)       |
| Stage 1       | 2      | 3     | [0.0000, 0.0182]          | 96          | 192          | (BS, H/8, W/8, 192)      |
| Stage 2       | 2      | 6     | [0.0364, 0.0545]          | 192         | 384          | (BS, H/16, W/16, 384)    |
| Stage 3       | 6      | 12    | [0.0727, 0.0909, 0.1091, 0.1273, 0.1455, 0.1636]          | 384         | 768          | (BS, H/32, W/32, 768)    |
| Stage 4       | 2      | 24    | [0.1818, 0.2000]          | 768         | 768          | (BS, H/32, W/32, 768)    |
| Head          | None   | None  | None          | 768         | num_classes  | (BS, num_classes)        |

## Swin Transformer Architecture

![Architecture](figures/Architecture.png)

In [15]:
class SwinTransformer(nn.Module):
    """
    Swin Transformer

    Steps:
        - Patchify to get (BS, H/patch, W/patch, emb_dim)
        - 4 stages of Swin blocks with alternating W-MSA (shift=0) and SW-MSA (shift=M//2)
        - Patch Merging at the end of stages 1-3
        - Global average pooling and linear classifier head

    Args:
        img_size: Input image size
        patch_size: Patch size for Patchify.
        emb_dim: Base embedding dimension channel for stage 1.
        blocks (List[int]): Number of blocks per stage.
        nheads (List[int]): Number of attention heads per stage.
        M: Window size for all blocks.
        n_classes: Number of output classes for the classifier head.
        stochastic_endpoint: Stochastic depth ratio endpoint for linspace.

    Shapes:
        Input:  (BS, 3, 224, 224)
        After patchify: (BS, 56, 56, 96)
        After stage1:   (BS, 28, 28, 192)
        After stage2:   (BS, 14, 14, 384)
        After stage3:   (BS, 7, 7, 768)
        After stage4:   (BS, 7, 7, 768)
        Output logits:  (BS, n_classes)
    """
    def __init__(self, img_size, patch_size, emb_dim, blocks,
                 nheads, M, n_classes, stochastic_endpoint):
        super().__init__()

        H = img_size // patch_size
        W = img_size // patch_size
        dims = [emb_dim, 2*emb_dim, 4*emb_dim, 8*emb_dim]

        self.patchify = Patchify(3, emb_dim, patch_size)

        # Linearly increase across all blocks (inclusive endpoints)
        stoch_depth = list(np.linspace(0, stochastic_endpoint, sum(blocks)))
        
        # 4 Stages with stochastic depth aligned with the number of blocks in each stage.
        ind = 0
        self.stage1 = Stage(
            dims[0], H, W, M[0], 
            blocks[0], nheads[0], 
            stoch_depth[ind:ind+blocks[0]], 
            patch_merging=True, ratio=4.0)
        ind += blocks[0]
        H //= 2
        W //= 2
        self.stage2 = Stage(
            dims[1], H, W, M[1], 
            blocks[1], nheads[1], 
            stoch_depth[ind:ind+blocks[1]], 
            patch_merging=True, ratio=4.0)
        ind += blocks[1]
        H //= 2
        W //= 2
        self.stage3 = Stage(
            dims[2], H, W, M[2], 
            blocks[2], nheads[2], 
            stoch_depth[ind:ind+blocks[2]], 
            patch_merging=True, ratio=4.0)
        ind += blocks[2]
        H //= 2
        W //= 2
        self.stage4 = Stage(
            dims[3], H, W, M[3], 
            blocks[3], nheads[3], 
            stoch_depth[ind:ind+blocks[3]], 
            patch_merging=False, ratio=4.0)

        self.norm = nn.LayerNorm(dims[3])
        self.head = nn.Linear(dims[3], n_classes)
    
    def forward(self, x):
        """
        Forward pass.

        Args:
            x: Input tensor of shape (BS, 3, img_size, img_size).

        Returns:
            Class logits of shape (BS, n_classes).
        """
        # x shape = (BS, 3, 224, 224)
        x = self.patchify(x) # (BS, H//patch_size, W//patch_size, embed_dim)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.norm(x).mean(dim=(1, 2)) # Global average pooling over H, W.
        return self.head(x)

# Soft Cross-Entropy

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def one_hot(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    return F.one_hot(labels, num_classes=num_classes).float()

class SoftCrossEntropy(nn.Module):
    """
    If target is LongTensor -> falls back to nn.CrossEntropyLoss (hard labels).
    If target is FloatTensor (N,C) -> computes soft CE: -sum(p * log_softmax).
    """
    def __init__(self, label_smoothing: float = 0.0):
        super().__init__()
        # If you use Mixup/CutMix, set smoothing=0.0 to avoid double softening.
        self.ce = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if target.dtype == torch.long:
            return self.ce(logits, target)
        log_probs = F.log_softmax(logits, dim=1)
        loss = -(target * log_probs).sum(dim=1).mean()
        return loss

## Mixup + CutMix module

In [17]:
import random

def _rand_bbox(W: int, H: int, lam: float):
    # CutMix box size from area ratio lam
    cut_rat = (1.0 - lam) ** 0.5
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = random.randint(0, W - 1)
    cy = random.randint(0, H - 1)
    x1 = max(cx - cut_w // 2, 0)
    y1 = max(cy - cut_h // 2, 0)
    x2 = min(cx + cut_w // 2, W)
    y2 = min(cy + cut_h // 2, H)
    return x1, y1, x2, y2

class MixupCutmix:
    """
    On each batch, applies Mixup or CutMix with given probabilities.
    Returns:
      images: possibly mixed tensor
      targets: either Long (no mix) or Float one-hot (mixed)
    """
    def __init__(self, num_classes: int,
                 mixup_alpha: float = 0.8,
                 cutmix_alpha: float = 1.0,
                 p_mixup: float = 0.5,
                 p_cutmix: float = 0.5):
        self.num_classes = num_classes
        self.mixup_alpha = mixup_alpha
        self.cutmix_alpha = cutmix_alpha
        self.p_mixup = p_mixup
        self.p_cutmix = p_cutmix
        self.enabled = True

    def off(self):  self.enabled = False
    def on(self):   self.enabled = True

    @torch.no_grad()
    def __call__(self, images: torch.Tensor, targets: torch.Tensor):
        if (not self.enabled) or (self.p_mixup <= 0 and self.p_cutmix <= 0):
            return images, targets  # no change

        B, C, H, W = images.shape
        # decide op
        op = None
        r = random.random()
        if r < self.p_mixup:
            op = 'mixup'
        elif r < self.p_mixup + self.p_cutmix:
            op = 'cutmix'
        else:
            return images, targets  # no change

        # sample lambda from Beta
        from torch.distributions import Beta

        if op == 'mixup' and self.mixup_alpha > 0:
            lam = Beta(self.mixup_alpha, self.mixup_alpha).sample().item()
        elif op == 'cutmix' and self.cutmix_alpha > 0:
            lam = Beta(self.cutmix_alpha, self.cutmix_alpha).sample().item()
        else:
            return images, targets

        lam = max(min(lam, 0.999), 0.001)

        # shuffle
        index = torch.randperm(B, device=images.device)
        y1 = one_hot(targets, self.num_classes).to(images.dtype)
        y2 = one_hot(targets[index], self.num_classes).to(images.dtype)

        if op == 'mixup':
            mixed = lam * images + (1.0 - lam) * images[index]
            y = lam * y1 + (1.0 - lam) * y2
            return mixed, y

        # CutMix
        x1, y1b, x2, y2b = _rand_bbox(W, H, lam)
        mixed = images.clone()
        mixed[:, :, y1b:y2b, x1:x2] = images[index, :, y1b:y2b, x1:x2]

        # adjust lam to actual area
        box_area = (x2 - x1) * (y2b - y1b)
        lam_adj = 1.0 - float(box_area) / float(W * H)
        y = lam_adj * one_hot(targets, self.num_classes).to(images.dtype) + \
            (1.0 - lam_adj) * one_hot(targets[index], self.num_classes).to(images.dtype)
        return mixed, y

In [18]:
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
import json


model = SwinTransformer(
    img_size=64, patch_size=4, emb_dim=96,
    blocks=[2, 2, 6, 2], nheads=[2, 4, 8, 16],
    M=[8, 8, 4, 2], n_classes=200, stochastic_endpoint=0.2).to(device)

loss_fn = SoftCrossEntropy(label_smoothing=0.05)
mixer = MixupCutmix(num_classes=200, mixup_alpha=0.4, cutmix_alpha=0.8,
                    p_mixup=0.3, p_cutmix=0.3)
# Scheduler Epochs
warmup_epochs = 10  # Gradually increase lr for the first 5 epochs
cosine_epochs = 190 # Cosine-anneal lr for the remaining 95 epochs

def param_groups_weight_decay(model, weight_decay=0.05):
    decay, no_decay = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        name = n.lower()
        if (
            p.ndim == 1
            or n.endswith(".bias")
            or "relative" in name
            or "pos_embed" in name
        ):
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

start_lr = 2.5e-4          # 1e-3 * (64/256)
eta_min  = 5e-6            # ~2% floor

# AdamW optimizer, weight decay taken directly from the paper.
# fused=True uses a fused CUDA kernel
optimizer = torch.optim.AdamW(param_groups_weight_decay(model, 0.05), lr=start_lr, betas=(0.9, 0.999), fused=True)

# Linear warmup scheduler. Start at 1% of lr and gradually fo up to 100%
warmup_scheduler = LinearLR(optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_epochs)
# Cosine annealing scheduler. Decay lr following a cosine curve
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=cosine_epochs, eta_min=eta_min)
# Run warmup_scheduler first, then cosine_scheduler
scheduler = SequentialLR(optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_epochs])
# GradScaler for Automatic Mixed Precision (AMP). Saves VRAM.
scaler = torch.cuda.amp.GradScaler(enabled=(device.type=="cuda"))

best_top1 = 0.0
history = {
    "train_loss": [],
    "val_loss": [],
    "val_top1": [],
    "val_top5": [],
    "lr": []
}

for epoch in range(warmup_epochs + cosine_epochs):
    model.train()
    total_loss = 0.0
    n_samples = 0

    if epoch < 1 or epoch >= warmup_epochs + cosine_epochs - 10:
        mixer.off()
    else:
        mixer.on()
    
    for xb, yb in train_loader:
        xb = xb.to(device, non_blocking=True)
        yb = yb.to(device, non_blocking=True)
        xb, yb = mixer(xb, yb)  # labels now (N) or (N,C)
        optimizer.zero_grad(set_to_none=True) # set_to_none saves a bit of memory.
        # Forward and loss in mixed precision
        with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):
            yhat = model(xb)
            loss = loss_fn(yhat, yb)
        # Backprop with scaled loss
        scaler.scale(loss).backward()
        # unscale before clipping, then clip
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=3.0)
        # Step optimizer only goes if gradients are finite
        scaler.step(optimizer)
        scaler.update()

        bs = xb.size(0)
        total_loss += loss.item() * bs
        n_samples += bs

    # Safe averages
    avg_loss = total_loss / n_samples
    
    # Validation
    model.eval()
    val_loss = 0.0
    val_samples = 0
    val_correct_top1 = 0
    val_correct_top5 = 0
    
    with torch.no_grad():
        for xb, yb in val_loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):
                yhat = model(xb)
                loss = loss_fn(yhat, yb)
            bs = xb.size(0)
            val_loss += loss.item() * bs
            val_samples += bs
            
            topk = torch.topk(yhat, k=5, dim=1).indices  # (B,5)
            val_correct_top1 += (topk[:, 0] == yb).sum().item()
            val_correct_top5 += topk.eq(yb.view(-1, 1)).any(dim=1).sum().item()
            
    avg_val_loss = val_loss / val_samples
    val_top1 = val_correct_top1 / val_samples
    val_top5 = val_correct_top5 / val_samples
    
    # Get current learning rate. I had this wrong before, I was grabbing the past lr instead of current
    scheduler.step()
    current_lr = scheduler.get_last_lr()[0]

    history["train_loss"].append(avg_loss)
    history["val_loss"].append(avg_val_loss)
    history["val_top1"].append(val_top1)
    history["val_top5"].append(val_top5)
    history["lr"].append(current_lr)

    with open("training_history.json", "w") as f:
        json.dump(history, f)

    if val_top1 > best_top1:
        best_top1 = val_top1
        torch.save({
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "epoch": epoch,
            "best_val_top1": best_top1,
            "best_val_top5": val_top5
        }, "swin_t_best.pt")
        print("Saved new best model.")

    print(f"Epoch {epoch+1}, Train Loss: {avg_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Top1 Acc: {val_top1:.4f}, Top5 Acc: {val_top5:.4f}, LR {current_lr:.6f}")

  scaler = torch.cuda.amp.GradScaler(enabled=(device.type=="cuda"))
  with torch.cuda.amp.autocast(enabled=(device.type=="cuda")):
  with torch.cuda.amp.autocast(enabled=(device.type == "cuda")):


Saved new best model.
Epoch 1, Train Loss: 5.1915, Val Loss: 4.9821, Top1 Acc: 0.0336, Top5 Acc: 0.1240, LR 0.000027
Saved new best model.
Epoch 2, Train Loss: 4.9946, Val Loss: 4.5802, Top1 Acc: 0.0739, Top5 Acc: 0.2341, LR 0.000052
Saved new best model.
Epoch 3, Train Loss: 4.7811, Val Loss: 4.3221, Top1 Acc: 0.1142, Top5 Acc: 0.3072, LR 0.000077
Saved new best model.
Epoch 4, Train Loss: 4.6086, Val Loss: 4.1162, Top1 Acc: 0.1410, Top5 Acc: 0.3514, LR 0.000101
Saved new best model.
Epoch 5, Train Loss: 4.4472, Val Loss: 3.9678, Top1 Acc: 0.1593, Top5 Acc: 0.3931, LR 0.000126
Saved new best model.
Epoch 6, Train Loss: 4.3176, Val Loss: 3.7881, Top1 Acc: 0.1904, Top5 Acc: 0.4447, LR 0.000151
Saved new best model.
Epoch 7, Train Loss: 4.2444, Val Loss: 3.7120, Top1 Acc: 0.2052, Top5 Acc: 0.4634, LR 0.000176
Saved new best model.
Epoch 8, Train Loss: 4.1417, Val Loss: 3.6967, Top1 Acc: 0.2110, Top5 Acc: 0.4625, LR 0.000201
Saved new best model.
Epoch 9, Train Loss: 4.0691, Val Loss: 3.4



Saved new best model.
Epoch 10, Train Loss: 3.9755, Val Loss: 3.4789, Top1 Acc: 0.2510, Top5 Acc: 0.5200, LR 0.000250
Saved new best model.
Epoch 11, Train Loss: 3.9105, Val Loss: 3.4478, Top1 Acc: 0.2537, Top5 Acc: 0.5273, LR 0.000250
Saved new best model.
Epoch 12, Train Loss: 3.8741, Val Loss: 3.3205, Top1 Acc: 0.2829, Top5 Acc: 0.5594, LR 0.000250
Saved new best model.
Epoch 13, Train Loss: 3.7887, Val Loss: 3.1831, Top1 Acc: 0.3134, Top5 Acc: 0.5907, LR 0.000250
Epoch 14, Train Loss: 3.6939, Val Loss: 3.2082, Top1 Acc: 0.3073, Top5 Acc: 0.5875, LR 0.000250
Saved new best model.
Epoch 15, Train Loss: 3.6365, Val Loss: 3.1031, Top1 Acc: 0.3382, Top5 Acc: 0.6039, LR 0.000250
Epoch 16, Train Loss: 3.5900, Val Loss: 3.0840, Top1 Acc: 0.3380, Top5 Acc: 0.6162, LR 0.000249
Saved new best model.
Epoch 17, Train Loss: 3.5168, Val Loss: 2.9807, Top1 Acc: 0.3560, Top5 Acc: 0.6367, LR 0.000249
Epoch 18, Train Loss: 3.4848, Val Loss: 2.9828, Top1 Acc: 0.3541, Top5 Acc: 0.6366, LR 0.000249
Save

224 images by 224 was way too much for my GPU. So I went down to 64 x 64 and adjusted. I didn't want M to be 2 so I modified so each stage has a different M. Entire architecture was adapted and controlled.

Augmentations too strong at 64×64.
RandAugment(mag=9) + RandomResizedCrop(0.8,1.0) + RandomErasing(p=0.25) on tiny images strips too much signal. Model keeps fitting train (loss ↓) but can’t push val higher → plateau at ~37%.

Relative Position Bias (RPB) was decayed.
relative_table lived in the weight-decay group before your change. That nudges locality info toward zero over time, hurting window attention’s inductive bias → earlier “ok” gains, then flattening.

Mixup/CutMix on for ~90% of training.
Great regularizers, but at 64×64 + heavy RA they can over-regularize, making logits softer. You also validate with hard CE, so it’s common to see val_loss drift up a bit while acc stays flat (calibration mismatch).

Resolution + capacity trade-off.
At 64px, content is compressed. Your Swin-T config (C=64, windows [8,8,4,2]) is reasonable, but it simply has less separable info than 224px baselines. ~35–40% Top-1 on Tiny-IN@64 with strong regularization is plausible.

Validation hygiene.
The split is from train/ (not official val/) and not stratified in the “before” run. That adds noise and can slightly mute peak accuracy.

Minor dampeners (secondary):

Grad clip = 1.0 can be tight for attention; may slow learning a bit.

LR schedule is fine; after ~30–40 epochs, cosine has already lowered LR a lot, making it harder to escape the regularization-limited regime.