In [4]:
import sys
sys.path.append("..")
import torchvision.transforms as transforms
from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders

image_size = 224

# ImageNet channel statistics, shared by every split.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _build_eval_transform(size):
    """Deterministic preprocessing for validation/test: resize, convert to tensor, normalize."""
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
    ])


# Training adds light augmentation (flip, padded random crop, small rotation)
# around the same resize / tensor / normalize steps used for evaluation.
tiny_transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize((image_size, image_size)),
    transforms.RandomCrop(image_size, padding=5),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])
tiny_transform_val = _build_eval_transform(image_size)
tiny_transform_test = _build_eval_transform(image_size)

train_loader, val_loader, test_loader = get_tinyimagenet_dataloaders(
    data_dir='../datasets',
    transform_train=tiny_transform_train,
    transform_val=tiny_transform_val,
    transform_test=tiny_transform_test,
    batch_size=64,
    image_size=image_size,
)


In [5]:
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

############################################
# 1. Low-Rank Linear Projection Module
############################################

class LowRankLinear(nn.Module):
    """
    Rank-constrained stand-in for ``nn.Linear(in_features, out_features)``.

    The effective weight is ``B.weight @ A.weight``, i.e. an
    ``in_features -> rank -> out_features`` chain that stores
    ``rank * (in_features + out_features)`` weights instead of
    ``in_features * out_features``. Only the second factor carries a bias
    (and only when ``bias`` is True).
    """
    def __init__(self, in_features, out_features, rank=16, bias=True):
        super().__init__()
        # Down-projection into the rank-sized bottleneck (bias-free).
        self.A = nn.Linear(in_features, rank, bias=False)
        # Up-projection back to the output width; optionally biased.
        self.B = nn.Linear(rank, out_features, bias=bias)

    def forward(self, x):
        bottleneck = self.A(x)
        return self.B(bottleneck)

############################################
# 2. Low-Rank Self-Attention Module
############################################

class LowRankSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention whose Q/K/V projections are
    LowRankLinear layers (the output projection is a dense nn.Linear).

    Args:
        dim: token width C; must be divisible by ``num_heads``.
        num_heads: number of heads; each head has width ``dim // num_heads``.
        rank: bottleneck rank of the factorized Q/K/V projections.

    Shapes: (B, N, C) -> (B, N, C). No mask is applied, so every token attends
    to every token of its sequence.
    """
    def __init__(self, dim, num_heads=8, rank=16):
        super().__init__()
        assert dim % num_heads == 0, "dim must be divisible by num_heads."
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # 1/sqrt(head_dim) temperature applied to the attention logits.
        self.scale = self.head_dim ** -0.5

        self.q_proj = LowRankLinear(dim, dim, rank=rank)
        self.k_proj = LowRankLinear(dim, dim, rank=rank)
        self.v_proj = LowRankLinear(dim, dim, rank=rank)
        self.out_proj = nn.Linear(dim, dim)

    def forward(self, x):
        batch, tokens, channels = x.shape

        def to_heads(t):
            # (B, N, C) -> (B, heads, N, head_dim)
            return t.reshape(batch, tokens, self.num_heads, self.head_dim).transpose(1, 2)

        queries = to_heads(self.q_proj(x))
        keys = to_heads(self.k_proj(x))
        values = to_heads(self.v_proj(x))

        logits = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        weights = torch.softmax(logits, dim=-1)
        context = torch.matmul(weights, values)  # (B, heads, N, head_dim)
        merged = context.transpose(1, 2).reshape(batch, tokens, channels)
        return self.out_proj(merged)

############################################
# 3. Adaptive Window Attention with Token Pruning
############################################

class AdaptiveWindowAttention(nn.Module):
    """
    Window self-attention with a batch-level dynamic window size and top-k token pruning.

    What ``forward`` does:
      1. A gate (global average pool -> linear -> sigmoid) yields one value per
         sample; the batch mean of these values selects ONE integer window size
         in [min_window, max_window] used for the whole batch.
      2. The (B, H, W, C) map is zero-padded on the bottom/right so H and W are
         multiples of that window size, then split into non-overlapping windows.
      3. LowRankSelfAttention runs independently inside each window.
      4. ``token_score`` rates every token, and the top
         ``k = max(1, int(token_keep_ratio * window_area))`` tokens of each
         window are kept.

    Args:
        dim: channel width C of the input.
        num_heads, rank: forwarded to LowRankSelfAttention.
        min_window, max_window: bounds for the dynamic window size.
        token_keep_ratio: fraction of tokens kept per window.

    Input:  (B, H, W, C).
    Output: (B, num_windows * k, C). The spatial layout is not preserved:
            within each window the kept tokens are ordered by descending score.

    NOTE(review): the window size goes through ``.item()`` and the kept tokens
    are chosen only via ``torch.topk`` *indices* (the top-k values are
    discarded), so neither ``window_gate`` nor ``token_score`` has a gradient
    path from the loss -- training never updates them.
    NOTE(review): because the gate is averaged over the batch, the window size
    (and therefore the output for a given image) depends on the other samples
    in the batch, in eval mode as well.
    NOTE(review): zero-padded positions are not masked; they take part in the
    window attention and can be selected by the pruning step.
    """
    def __init__(self, dim, num_heads=8, rank=16,
                 min_window=4, max_window=8, token_keep_ratio=0.7):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.rank = rank
        self.min_window = min_window
        self.max_window = max_window
        self.token_keep_ratio = token_keep_ratio

        # Low-rank self-attention shared by all windows (windows are folded into the batch dim in forward).
        self.attn = LowRankSelfAttention(dim, num_heads=num_heads, rank=rank)

        # Gating module: pools the (B, C, H, W) map to one logit per sample;
        # forward turns the batch mean of sigmoid(logit) into the window size.
        self.window_gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # input: (B, C, H, W)
            nn.Flatten(),
            nn.Linear(dim, 1)
        )
        # Per-token scoring layer for pruning; only the ranking of its scores is used.
        self.token_score = nn.Linear(dim, 1)

    def forward(self, x):
        # x: (B, H, W, C)
        B, H, W, C = x.shape

        # ---- Dynamic Window Size ----
        # Compute a gating scalar from global features.
        x_perm = x.permute(0, 3, 1, 2).contiguous()  # (B, C, H, W)
        gate_val = torch.sigmoid(self.window_gate(x_perm))  # (B, 1) with values in (0,1)
        # Average the gate values over the batch. .item() yields a Python float,
        # which cuts the autograd graph (window_gate receives no gradient).
        gate_scalar = gate_val.mean().item()
        # Linear map of the gate into [min_window, max_window], rounded and clamped.
        dynamic_window = int(round(self.min_window + gate_scalar * (self.max_window - self.min_window)))
        dynamic_window = max(self.min_window, min(dynamic_window, self.max_window))
        # Pad H and W so they are divisible by dynamic_window.
        pad_h = (dynamic_window - H % dynamic_window) % dynamic_window
        pad_w = (dynamic_window - W % dynamic_window) % dynamic_window
        if pad_h or pad_w:
            # F.pad pairs start at the LAST dim: (C_left, C_right, W_left, W_right, H_top, H_bottom).
            # C is left untouched; zeros are appended on the right of W and the bottom of H.
            x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
        H_new, W_new = x.shape[1], x.shape[2]

        # ---- Partition into Windows ----
        # Reshape: (B, H_new/dynamic_window, dynamic_window, W_new/dynamic_window, dynamic_window, C)
        # NOTE(review): .view requires x to be contiguous when no padding was applied.
        x_windows = x.view(B, H_new // dynamic_window, dynamic_window,
                           W_new // dynamic_window, dynamic_window, C)
        # Permute to (B, num_windows_H, num_windows_W, dynamic_window, dynamic_window, C)
        x_windows = x_windows.permute(0, 1, 3, 2, 4, 5).contiguous()
        num_windows = (H_new // dynamic_window) * (W_new // dynamic_window)
        # Merge batch and window dims: (B*num_windows, window_area, C); rows are batch-major.
        x_windows = x_windows.view(B * num_windows, dynamic_window * dynamic_window, C)

        # ---- Local Self-Attention ----
        x_windows = self.attn(x_windows)  # (B*num_windows, window_area, C)

        # ---- Token Pruning ----
        scores = self.token_score(x_windows).squeeze(-1)  # (B*num_windows, window_area)
        N = x_windows.shape[1]
        k = max(1, int(self.token_keep_ratio * N))  # number of tokens to keep per window
        # Get top-k token indices per window (values are discarded, so scores carry no gradient).
        _, topk_indices = torch.topk(scores, k, dim=1)
        # Gather tokens
        topk_indices_expanded = topk_indices.unsqueeze(-1).expand(-1, -1, C)
        pruned_tokens = torch.gather(x_windows, dim=1, index=topk_indices_expanded)  # (B*num_windows, k, C)
        # Merge windows back: shape becomes (B, num_windows*k, C); batch-major rows keep each sample's windows together.
        pruned_tokens = pruned_tokens.view(B, num_windows * k, C)
        return pruned_tokens

############################################
# 4. Shared Swin Block with Weight Sharing
############################################

class SharedSwinBlock(nn.Module):
    """
    Pre-norm transformer block (adaptive window attention + MLP) meant to be
    applied repeatedly with the same weights.

    Accepted inputs:
      * spatial grid (B, H, W, C): LayerNorm -> AdaptiveWindowAttention, which
        prunes tokens and returns a sequence (B, N', C) -> residual MLP.
        There is no residual around the attention step because the pruned token
        set no longer lines up with the input grid.
      * token sequence (B, N, C), e.g. the output of a previous application.
        Window attention needs a spatial grid, so only the residual MLP
        sub-block runs: ``x + MLP(LayerNorm(x))``.

    Args:
        dim: channel width C.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        drop: dropout probability applied after the MLP.
        num_heads, rank, min_window, max_window, token_keep_ratio: forwarded to
            AdaptiveWindowAttention.

    Raises:
        ValueError: if the input is neither 3-D nor 4-D.
    """
    def __init__(self, dim, mlp_ratio=4.0, drop=0.0,
                 num_heads=8, rank=16, min_window=4, max_window=8, token_keep_ratio=0.7):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = AdaptiveWindowAttention(dim, num_heads=num_heads, rank=rank,
                                              min_window=min_window, max_window=max_window,
                                              token_keep_ratio=token_keep_ratio)
        self.norm2 = nn.LayerNorm(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(drop)
        )

    def forward(self, x):
        if x.dim() == 4:
            # LayerNorm normalizes over the last (channel) dim only, so it can be
            # applied to the (B, H, W, C) grid directly -- no flatten/unflatten
            # round trip (and no .view contiguity requirement) is needed.
            attn_out = self.attn(self.norm1(x))  # (B, N', C) pruned tokens
            return attn_out + self.mlp(self.norm2(attn_out))
        if x.dim() == 3:
            # Standard pre-norm residual: the identity path is added exactly once.
            return x + self.mlp(self.norm2(x))
        raise ValueError("Input tensor must be 3D or 4D.")

############################################
# 5. Patch Embedding Module
############################################

class PatchEmbed(nn.Module):
    """
    Non-overlapping patch embedding implemented as a strided convolution.

    A ``patch_size x patch_size`` conv with stride ``patch_size`` maps
    (B, in_chans, H_img, W_img) to (B, embed_dim, H_img // patch_size, W_img // patch_size),
    which is returned channels-last as a contiguous (B, H, W, embed_dim) tensor.
    ``img_size`` is stored for reference only; forward does not check it.
    """
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # Channels-first conv output -> channels-last, materialized contiguously.
        return self.proj(x).permute(0, 2, 3, 1).contiguous()

############################################
# 6. Overall SwinLite Model
############################################

class SwinLite(nn.Module):
    """
    A simplified Swin Transformer variant that:
      - embeds image patches,
      - applies one shared (weight-shared) transformer block ``depth`` times,
      - uses adaptive window attention with token pruning and low-rank projections,
      - pools the tokens and outputs class logits of shape (B, num_classes).

    With SharedSwinBlock as defined in this notebook, only the first application
    receives the spatial grid and performs window attention / pruning; later
    applications receive a token sequence.
    """
    def __init__(self, img_size=224, patch_size=4, in_chans=3,
                 num_classes=1000, embed_dim=96, depth=4,
                 mlp_ratio=4.0, drop=0.0, num_heads=8, rank=16,
                 min_window=4, max_window=8, token_keep_ratio=0.7):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                      in_chans=in_chans, embed_dim=embed_dim)
        # Create one shared block to be used repeatedly.
        self.shared_block = SharedSwinBlock(dim=embed_dim, mlp_ratio=mlp_ratio,
                                            drop=drop, num_heads=num_heads, rank=rank,
                                            min_window=min_window, max_window=max_window,
                                            token_keep_ratio=token_keep_ratio)
        self.depth = depth
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # x: (B, in_chans, img_size, img_size)
        x = self.patch_embed(x)  # (B, H, W, embed_dim)
        for _ in range(self.depth):
            x = self.shared_block(x)
        if x.dim() == 4:
            # Still a spatial grid (e.g. depth == 0): flatten to (B, H*W, C) so the
            # pooling below averages over every spatial position, not just H.
            x = x.flatten(1, 2)
        x = self.norm(x)
        # Global average pooling over tokens.
        x = x.mean(dim=1)
        return self.head(x)

############################################
# 7. Training Code for Tiny ImageNet
############################################

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """
    Run one optimization pass over ``dataloader``.

    Args:
        model: network to train; switched to train mode here.
        dataloader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to
            return the batch mean (the nn.CrossEntropyLoss default).
        device: device each batch is moved to.

    Returns:
        (avg_loss, accuracy): sample-weighted mean loss and top-1 accuracy in percent.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        # Re-weight the batch-mean loss by batch size so a short final batch is not over-counted.
        total_loss += loss.item() * batch_size
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        total_samples += batch_size
    if total_samples == 0:
        raise ValueError("train_one_epoch: dataloader yielded no samples")
    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples * 100
    return avg_loss, accuracy

def validate(model, dataloader, criterion, device):
    """
    Evaluate ``model`` on ``dataloader`` in eval mode without gradient tracking.

    Args:
        model: network to evaluate; switched to eval mode here (and left there).
        dataloader: iterable of ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``; assumed to
            return the batch mean (the nn.CrossEntropyLoss default).
        device: device each batch is moved to.

    Returns:
        (avg_loss, accuracy): sample-weighted mean loss and top-1 accuracy in percent.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            batch_size = images.size(0)
            # Re-weight the batch-mean loss by batch size so a short final batch is not over-counted.
            total_loss += loss.item() * batch_size
            total_correct += (outputs.argmax(dim=1) == labels).sum().item()
            total_samples += batch_size
    if total_samples == 0:
        raise ValueError("validate: dataloader yielded no samples")
    avg_loss = total_loss / total_samples
    accuracy = total_correct / total_samples * 100
    return avg_loss, accuracy

def main():
    """
    Train SwinLite on Tiny ImageNet, print per-epoch metrics and save the weights.

    Uses the module-level ``train_loader`` / ``val_loader`` built in the data
    cell above (run that cell first); ``test_loader`` is not used here.

    Files written to the working directory:
      * ``swinlite_tiny_imagenet_best.pth`` -- state_dict of the epoch with the
        highest validation accuracy so far (rewritten whenever it improves).
      * ``swinlite_tiny_imagenet.pth`` -- state_dict at the end of training, or
        at the moment training is interrupted (the KeyboardInterrupt is
        re-raised after saving).
    """
    # ----------- Configuration -----------
    num_classes = 200  # Tiny ImageNet has 200 classes.
    num_epochs = 50
    learning_rate = 1e-3
    final_checkpoint = "swinlite_tiny_imagenet.pth"
    best_checkpoint = "swinlite_tiny_imagenet_best.pth"

    # ----------- Device -----------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ----------- Model, Loss, Optimizer, Scheduler -----------
    model = SwinLite(img_size=224, patch_size=4, in_chans=3, num_classes=num_classes,
                     embed_dim=96, depth=4, mlp_ratio=4.0, drop=0.1,
                     num_heads=8, rank=16, min_window=4, max_window=8, token_keep_ratio=0.7)
    model = model.to(device)

    # Print the number of trainable parameters.
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of trainable parameters: {num_params:,}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=0.05)
    # LR is divided by 10 every 10 epochs (the scheduler is stepped once per epoch).
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    # ----------- Training Loop -----------
    best_val_acc = float("-inf")
    try:
        for epoch in range(num_epochs):
            train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
            val_loss, val_acc = validate(model, val_loader, criterion, device)
            scheduler.step()

            print(f"Epoch [{epoch+1}/{num_epochs}] | "
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}% | "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            # Keep the best weights on disk so a long run stopped early is not wasted.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save(model.state_dict(), best_checkpoint)
    except KeyboardInterrupt:
        # Persist the progress made so far, then let the interrupt propagate.
        torch.save(model.state_dict(), final_checkpoint)
        print(f"Training interrupted; current weights saved to {final_checkpoint}.")
        raise

    # Save the trained model.
    torch.save(model.state_dict(), final_checkpoint)
    print("Training complete and model saved.")

# Inside a Jupyter kernel __name__ is '__main__', so executing this cell starts training.
if __name__ == '__main__':
    main()


Number of trainable parameters: 117,898


KeyboardInterrupt: 