In [1]:
import sys
sys.path.append("..")
import torchvision.transforms as transforms
from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders

import os

# ---------------- Data configuration ----------------
# DATA_DIR is resolved relative to the notebook's current working directory.
DATA_DIR = '../datasets'
BATCH_SIZE = 64
image_size = 224
# Channel mean/std used by transforms.Normalize below.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Training pipeline: random flip / crop / rotation augmentation, then tensor + normalization.
tiny_transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize((image_size, image_size)),
    transforms.RandomCrop(image_size, padding=5),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Validation and test use one deterministic pipeline (no augmentation); both names are kept
# because get_tinyimagenet_dataloaders takes them as separate arguments.
tiny_transform_eval = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
tiny_transform_val = tiny_transform_eval
tiny_transform_test = tiny_transform_eval

# The loader looks for <DATA_DIR>/tiny-imagenet-200/train (see the FileNotFoundError this cell
# produced). Fail early with the absolute path and cwd, so a wrong working directory is obvious.
_tiny_train_dir = os.path.join(DATA_DIR, 'tiny-imagenet-200', 'train')
if not os.path.isdir(_tiny_train_dir):
    raise FileNotFoundError(
        f"Tiny ImageNet training folder not found at {os.path.abspath(_tiny_train_dir)} "
        f"(current working directory: {os.getcwd()}). Adjust DATA_DIR."
    )

train_loader, val_loader, test_loader = get_tinyimagenet_dataloaders(
    data_dir=DATA_DIR,
    transform_train=tiny_transform_train,
    transform_val=tiny_transform_val,
    transform_test=tiny_transform_test,
    batch_size=BATCH_SIZE,
    image_size=image_size,
)


FileNotFoundError: [WinError 3] The system cannot find the path specified: '../datasets\\tiny-imagenet-200/train'

In [4]:
import math
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

############################################
# 1. Dynamic Relative Positional Bias
############################################

class DynamicRelativePositionalBias(nn.Module):
    """
    Learned, continuous relative positional bias for a square attention window.

    The 2D offset between every pair of tokens in a window_size x window_size grid is fed through
    a small MLP (2 -> hidden_dim -> 1). The result is a single bias value per token pair; it is not
    split per head.
    """
    def __init__(self, hidden_dim=32):
        super().__init__()
        layers = [
            nn.Linear(2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, window_size):
        """
        Args:
            window_size: side length of the square window.
        Returns:
            (M, M) float tensor with M = window_size ** 2. Entry [i, j] is the MLP evaluated on
            pos[j] - pos[i], where tokens are enumerated row-major as (row, col).
        """
        device = self.mlp[0].weight.device
        axis = torch.arange(window_size, device=device)
        grid_rows, grid_cols = torch.meshgrid(axis, axis, indexing="ij")
        positions = torch.stack((grid_rows.flatten(), grid_cols.flatten()), dim=1)  # (M, 2)
        # Broadcast to all ordered pairs: offsets[i, j] = positions[j] - positions[i].
        offsets = (positions[None, :, :] - positions[:, None, :]).float()  # (M, M, 2)
        return self.mlp(offsets).squeeze(-1)

############################################
# 2. Dynamic Shifted Window Attention with Dynamic Relative Positional Encoding
############################################

class DynamicShiftedWindowAttention(nn.Module):
    """
    Window self-attention with a content-dependent cyclic shift.

    Steps:
      1. A gating network predicts a (dx, dy) offset in [-1, 1] per sample.
      2. The offset is averaged over the batch, scaled by ``shift_range`` and rounded to whole
         pixels.
      3. The feature map is rolled by the negated offset.
      4. The map is zero-padded to a multiple of ``window_size`` and partitioned into
         non-overlapping windows.
      5. Multi-head self-attention runs inside each window, with a dynamic relative positional
         bias (one (M, M) matrix shared by all heads) added to the logits.
      6. Windows are merged, the padding is cropped and the roll is undone.

    Notes:
      * The shift is turned into Python ints, so no gradient reaches ``shift_gate``. Its
        parameters are therefore not updated by gradient-based training.
      * No attention mask is used: after the roll, tokens wrapped in from the opposite border
        share windows with their new neighbours, and zero-padded tokens take part in attention.

    Input/output: (B, H, W, C) tensors with C == dim.
    """
    def __init__(self, dim, num_heads, window_size, shift_range=1, dropout=0.0, rp_hidden_dim=32):
        super().__init__()
        # forward splits the channel dim into num_heads equal heads.
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size  # side length of each square window, e.g. 7 or 14
        self.shift_range = shift_range  # maximum pixel shift in each direction
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(dropout)
        self.proj_drop = nn.Dropout(dropout)

        self.rel_pos_bias = DynamicRelativePositionalBias(hidden_dim=rp_hidden_dim)

        # A gating network that predicts a (dx, dy) shift (values in [-1, 1]).
        self.shift_gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # expects input in (B, C, H, W)
            nn.Flatten(),
            nn.Linear(dim, 2),
            nn.Tanh()  # output in range [-1, 1]
        )

    def forward(self, x):
        """
        Args:
            x: (B, H, W, C) feature map with C == dim.
        Returns:
            (B, H, W, C) tensor of attended features.
        """
        B, H, W, C = x.shape
        ws = self.window_size
        M = ws * ws  # tokens per window

        # ----- Dynamic shift offset -----
        # The gate pools over space, so it needs channels-first input.
        shift_offsets = self.shift_gate(x.permute(0, 3, 1, 2).contiguous())  # (B, 2)
        # One shift is shared by the whole batch. A single .tolist() moves both components to
        # the host in one sync (instead of two .item() calls).
        offset_h, offset_w = shift_offsets.mean(dim=0).tolist()
        shift_x = int(round(offset_h * self.shift_range))  # rolled along H (dim 1)
        shift_y = int(round(offset_w * self.shift_range))  # rolled along W (dim 2)
        x_shifted = torch.roll(x, shifts=(-shift_x, -shift_y), dims=(1, 2))

        # ----- Pad to a multiple of the window size -----
        pad_h = (ws - H % ws) % ws
        pad_w = (ws - W % ws) % ws
        if pad_h or pad_w:
            # F.pad pairs run from the last dim backwards: C untouched, W padded right, H bottom.
            x_shifted = F.pad(x_shifted, (0, 0, 0, pad_w, 0, pad_h))
        Hp, Wp = x_shifted.shape[1], x_shifted.shape[2]
        n_win_h, n_win_w = Hp // ws, Wp // ws
        num_windows = n_win_h * n_win_w

        # ----- Partition into windows: (B * num_windows, M, C) -----
        x_windows = x_shifted.view(B, n_win_h, ws, n_win_w, ws, C)
        x_windows = x_windows.permute(0, 1, 3, 2, 4, 5).contiguous()  # (B, nH, nW, ws, ws, C)
        x_windows = x_windows.view(B * num_windows, M, C)

        # ----- Multi-head self-attention within each window -----
        qkv = self.qkv(x_windows).reshape(B * num_windows, M, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B*num_windows, num_heads, M, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B*num_windows, num_heads, M, M)

        # Dynamic relative positional bias, broadcast over windows and heads.
        bias = self.rel_pos_bias(ws)  # (M, M)
        attn = attn + bias.unsqueeze(0).unsqueeze(0)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x_attn = (attn @ v).transpose(1, 2).reshape(B * num_windows, M, C)
        x_attn = self.proj_drop(self.proj(x_attn))

        # ----- Merge windows back to (B, Hp, Wp, C) -----
        x_attn = x_attn.view(B, n_win_h, n_win_w, ws, ws, C)
        x_attn = x_attn.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, C)

        # Crop the padding BEFORE reversing the roll. The forward roll acted on the unpadded
        # (H, W) map, so its inverse must act on that same size. Un-rolling the padded map would
        # move padded positions into the output and misalign every token.
        if pad_h or pad_w:
            x_attn = x_attn[:, :H, :W, :]
        return torch.roll(x_attn, shifts=(shift_x, shift_y), dims=(1, 2))

############################################
# 3. Multi-Scale Feature Aggregation
############################################

class MultiScaleDynamicAttention(nn.Module):
    """
    Two DynamicShiftedWindowAttention branches over the same input, one with a small window and
    one with a large window.

    Each branch maps (B, H, W, C) -> (B, H, W, C). The two results are concatenated along the
    channel axis (giving 2*C channels) and projected back to C channels by a linear layer.
    """
    def __init__(self, dim, num_heads, window_size_small, window_size_large,
                 shift_range=1, dropout=0.0, rp_hidden_dim=32):
        super().__init__()
        shared = dict(shift_range=shift_range, dropout=dropout, rp_hidden_dim=rp_hidden_dim)
        self.attn_small = DynamicShiftedWindowAttention(dim, num_heads, window_size_small, **shared)
        self.attn_large = DynamicShiftedWindowAttention(dim, num_heads, window_size_large, **shared)
        self.fuse = nn.Linear(2 * dim, dim)

    def forward(self, x):
        # Small branch runs first, then large (same order as concatenation).
        branch_outputs = (self.attn_small(x), self.attn_large(x))
        return self.fuse(torch.cat(branch_outputs, dim=-1))

############################################
# 4. Advanced Swin Transformer Block
############################################

class SwinTransformerBlockAdvanced(nn.Module):
    """
    Pre-norm transformer block built on multi-scale dynamic shifted window attention.

        y   = x + attn(norm1(x))
        out = y + mlp(norm2(y))

    mlp is Linear(dim, hidden) -> GELU -> Dropout -> Linear(hidden, dim) -> Dropout, with
    hidden = int(dim * mlp_ratio).

    Input and output are in spatial format: (B, H, W, C) with C == dim.
    """
    def __init__(self, dim, num_heads, window_size_small, window_size_large,
                 shift_range=1, dropout=0.0, rp_hidden_dim=32, mlp_ratio=4.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiScaleDynamicAttention(dim, num_heads, window_size_small, window_size_large,
                                               shift_range, dropout, rp_hidden_dim)
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # LayerNorm and Linear act on the last (channel) dim of inputs of any rank, so the
        # (B, H, W, C) tensor is used directly. No flattening .view() is needed, which also
        # means non-contiguous inputs (e.g. permuted tensors) are accepted.
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

############################################
# 5. Overall Advanced Swin Transformer
############################################

class SwinTransformerAdvanced(nn.Module):
    """
    Single-stage Swin-style classifier.

    Pipeline:
      1. A strided Conv2d patch embedding produces an (H/patch, W/patch) token grid with
         embed_dim channels.
      2. Dropout is applied to the tokens. No positional embedding is added; position enters
         only through each block's relative positional bias.
      3. `depth` SwinTransformerBlockAdvanced blocks run in channels-last layout.
      4. Global average pooling over the token grid, then LayerNorm and a linear head.

    Args:
      img_size: Accepted for interface compatibility; not used (the grid size follows the input).
      patch_size: Kernel size and stride of the patch embedding.
      in_chans: Number of input channels.
      num_classes: Number of output logits.
      embed_dim: Token channel dimension.
      depth: Number of transformer blocks.
      num_heads: Number of attention heads per block.
      window_size_small: Window side of the small attention branch (e.g. 7).
      window_size_large: Window side of the large attention branch (e.g. 14).
      shift_range, dropout, rp_hidden_dim, mlp_ratio: Forwarded to every block; dropout is
        also used for the token dropout.
    """
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=200,
                 embed_dim=96, depth=4, num_heads=4, window_size_small=7, window_size_large=14,
                 shift_range=1, dropout=0.0, rp_hidden_dim=32, mlp_ratio=4.0):
        super().__init__()
        self.patch_embed = nn.Conv2d(in_chans, embed_dim,
                                     kernel_size=patch_size, stride=patch_size)
        self.pos_drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList(
            SwinTransformerBlockAdvanced(
                embed_dim, num_heads, window_size_small, window_size_large,
                shift_range=shift_range, dropout=dropout,
                rp_hidden_dim=rp_hidden_dim, mlp_ratio=mlp_ratio,
            )
            for _ in range(depth)
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Map (B, in_chans, H_img, W_img) images to (B, num_classes) logits."""
        # (B, embed_dim, H, W) -> channels-last (B, H, W, embed_dim) for the blocks.
        tokens = self.patch_embed(x).permute(0, 2, 3, 1).contiguous()
        tokens = self.pos_drop(tokens)
        for block in self.blocks:
            tokens = block(tokens)
        pooled = tokens.mean(dim=(1, 2))  # average over the H, W token grid
        return self.head(self.norm(pooled))

############################################
# 6. Training and Validation Functions
############################################

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over `dataloader`.

    Each (inputs, targets) batch is moved to `device`, then the usual
    zero_grad -> forward -> loss -> backward -> step sequence runs.

    Returns:
        (mean per-sample loss, top-1 accuracy in percent) over all samples seen.
        An empty dataloader raises ZeroDivisionError.
    """
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)  # (B, num_classes)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # criterion returns a batch mean, so re-weight by batch size for a per-sample average.
        n = batch_x.size(0)
        loss_sum += batch_loss.item() * n
        correct += (logits.max(dim=1).indices == batch_y).sum().item()
        seen += n
    return loss_sum / seen, 100.0 * correct / seen

def validate(model, dataloader, criterion, device):
    """Evaluate `model` on `dataloader` in eval mode without gradient tracking.

    Returns:
        (mean per-sample loss, top-1 accuracy in percent) over all samples seen.
        An empty dataloader raises ZeroDivisionError.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            logits = model(batch_x)
            n = batch_x.size(0)
            # criterion returns a batch mean; re-weight by batch size for a per-sample average.
            loss_sum += criterion(logits, batch_y).item() * n
            correct += (logits.max(dim=1).indices == batch_y).sum().item()
            seen += n
    return loss_sum / seen, 100.0 * correct / seen

############################################
# 7. Main Training Script for Tiny ImageNet
############################################

def main(train_loader=None, val_loader=None, num_epochs=50, learning_rate=1e-3,
         save_path="swin_transformer_advanced_tiny_imagenet.pth"):
    """Train SwinTransformerAdvanced on Tiny ImageNet and save its weights.

    Args:
        train_loader: Iterable of (images, labels) training batches. If None, the notebook-level
            global ``train_loader`` is used (the original behaviour of a bare ``main()``).
        val_loader: Same for validation batches; falls back to the global ``val_loader``.
        num_epochs: Number of training epochs.
        learning_rate: Initial AdamW learning rate.
        save_path: File that receives ``model.state_dict()``.

    The state_dict is written to ``save_path`` after every completed epoch, so an interrupted
    run (e.g. KeyboardInterrupt) keeps the last finished epoch on disk. It is written once more
    at the end, so the file always holds the final model.

    Raises:
        RuntimeError: if a loader is neither passed in nor defined as a notebook global
            (e.g. the data-loading cell failed).
    """
    # ------------- Data -------------
    # Fall back to notebook globals for backward compatibility. If they are missing, fail with
    # an explicit message instead of a NameError deep inside the training loop.
    if train_loader is None:
        train_loader = globals().get("train_loader")
    if val_loader is None:
        val_loader = globals().get("val_loader")
    if train_loader is None or val_loader is None:
        raise RuntimeError(
            "train_loader/val_loader are not available: pass them to main() or run the "
            "data-loading cell successfully first."
        )

    # ------------- Configuration -------------
    num_classes = 200  # Tiny ImageNet has 200 classes.

    # ------------- Device -------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ------------- Model, Loss, Optimizer, Scheduler -------------
    model = SwinTransformerAdvanced(img_size=224, patch_size=4, in_chans=3, num_classes=num_classes,
                                    embed_dim=96, depth=4, num_heads=4,
                                    window_size_small=7, window_size_large=14,
                                    shift_range=1, dropout=0.1, rp_hidden_dim=32, mlp_ratio=4.0)
    model = model.to(device)

    # Print the number of trainable parameters.
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of trainable parameters: {num_params:,}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.05)
    # NOTE(review): step_size=10, gamma=0.1 divides the LR by 10 every 10 epochs. With 50
    # epochs the LR is <= 1e-5 for the last 30 epochs; consider CosineAnnealingLR.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # ------------- Training Loop -------------
    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        scheduler.step()
        print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")
        # Per-epoch checkpoint so an interrupted run does not lose all progress.
        torch.save(model.state_dict(), save_path)

    # ------------- Save the Model -------------
    torch.save(model.state_dict(), save_path)
    print("Training complete and model saved.")

# Inside a Jupyter kernel the notebook namespace is '__main__', so this guard is true and
# executing this cell starts the full training run immediately.
if __name__ == '__main__':
    main()


Number of trainable parameters: 697,344


KeyboardInterrupt: 