In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import os

sys.path.append("..")
import torchvision.transforms as transforms
from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders

# Enhanced Data Augmentation
image_size = 224
tiny_transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize((image_size + 20, image_size + 20)),
    transforms.RandomCrop(image_size, padding=8, padding_mode='reflect'),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.RandomGrayscale(p=0.1),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.1), value='random'),
])

tiny_transform_val = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
])

train_loader, val_loader, test_loader = get_tinyimagenet_dataloaders(
    data_dir='../datasets',
    transform_train=tiny_transform_train,
    transform_val=tiny_transform_val,
    transform_test=tiny_transform_val,
    batch_size=64,
    image_size=image_size
)





# Model components


class PatchEmbed(nn.Module):
    """Compact Patch Embedding with Depth-wise Convolution"""
    def __init__(self, img_size=64, patch_size=4, in_chans=3, embed_dim=48):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_chans, embed_dim // 2, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=patch_size, stride=patch_size),
            nn.GroupNorm(4, embed_dim)
        )
        self.grid_size = img_size // patch_size
        self.num_patches = (img_size // patch_size) ** 2

    def forward(self, x):
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x

class HybridAttention(nn.Module):
    """Depth-wise Conv Enhanced Window Attention with padding support"""
    def __init__(self, dim, num_heads=4, window_size=4, shift_size=0, groups=2):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.groups = groups

        # Depth-wise local feature enhancement
        self.dw_conv = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)
        self.norm = nn.GroupNorm(groups, dim)
        
        # Sparse attention parameters
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.scale = (dim // num_heads) ** -0.5

    def forward(self, x, H, W):
        B, L, C = x.shape
        # Reshape to spatial map: [B, H, W, C]
        x = x.view(B, H, W, C)
        
        # Pad H and W if they are not divisible by window_size
        pad_h = (self.window_size - H % self.window_size) % self.window_size
        pad_w = (self.window_size - W % self.window_size) % self.window_size
        if pad_h or pad_w:
            # For F.pad we need (pad_left, pad_right, pad_top, pad_bottom)
            # Here we pad width (last dim) and height (second-to-last dim).
            x = x.permute(0, 3, 1, 2)  # [B, C, H, W]
            x = F.pad(x, (0, pad_w, 0, pad_h))
            x = x.permute(0, 2, 3, 1)  # [B, H_pad, W_pad, C]
            H_pad, W_pad = H + pad_h, W + pad_w
        else:
            H_pad, W_pad = H, W

        # Depth-wise convolution enhancement
        x = x.permute(0, 3, 1, 2)  # [B, C, H_pad, W_pad]
        x = self.dw_conv(x) + x
        # Apply normalization: first permute to [B, H_pad, W_pad, C]
        x = self.norm(x).permute(0, 2, 3, 1).contiguous()

        # Optional shift
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        
        # Window partitioning
        x = x.view(B,
                   H_pad // self.window_size, self.window_size,
                   W_pad // self.window_size, self.window_size,
                   C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, self.window_size * self.window_size, C)

        # Grouped attention computation
        qkv = self.qkv(x).view(-1, self.window_size * self.window_size, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(-1, self.window_size * self.window_size, C)
        
        x = self.proj(x)
        
        # Window reconstruction
        x = x.view(-1,
                   H_pad // self.window_size, W_pad // self.window_size,
                   self.window_size, self.window_size, C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H_pad, W_pad, C)
        
        # Remove any padding if added
        if pad_h or pad_w:
            x = x[:, :H, :W, :].contiguous()
        
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        
        return x.view(B, L, C)
class GroupMLP(nn.Module):
    """Parameter-efficient Grouped MLP"""
    def __init__(self, in_features, hidden_features=None, groups=2):
        super().__init__()
        hidden_features = hidden_features or in_features * 2
        self.groups = groups
        self.fc1 = nn.Conv1d(in_features, hidden_features, 1, groups=groups)
        self.act = nn.GELU()
        self.fc2 = nn.Conv1d(hidden_features, in_features, 1, groups=groups)
        
    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x.transpose(1, 2)

class SlimSwinBlock(nn.Module):
    """Compact Transformer Block"""
    def __init__(self, dim, num_heads, window_size=4, shift_size=0, groups=2):
        super().__init__()
        self.attn_norm = nn.GroupNorm(groups, dim)
        self.attn = HybridAttention(dim, num_heads, window_size, shift_size, groups)
        self.mlp_norm = nn.GroupNorm(groups, dim)
        self.mlp = GroupMLP(dim, groups=groups)
        
    def forward(self, x, H, W):
        # Apply norm on [B, L, C]: permute to [B, C, L], normalize, then permute back.
        x = x + self.attn(self.attn_norm(x.permute(0, 2, 1)).permute(0, 2, 1), H, W)
        x = x + self.mlp(self.mlp_norm(x.permute(0, 2, 1)).permute(0, 2, 1))
        return x

class SlimSwin(nn.Module):
    """Slim Swin Transformer Architecture"""
    def __init__(self, img_size=64, patch_size=4, in_chans=3, num_classes=200,
                 depths=[2, 2, 2], dims=[48, 96, 192], num_heads=[4, 8, 16], 
                 window_size=4, groups=2):
        super().__init__()
        self.stages = nn.ModuleList()
        current_dim = dims[0]
        
        # Initial patch embedding
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, current_dim)
        self.pos_drop = nn.Dropout(0.1)
        
        # Build transformer and merging stages
        for i, depth in enumerate(depths):
            # Transformer blocks stage
            stage = nn.Sequential(
                *[SlimSwinBlock(
                    dim=dims[i],
                    num_heads=num_heads[i],
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    groups=groups
                ) for _ in range(depth)]
            )
            self.stages.append(stage)
            
            if i != len(depths) - 1:
                # Efficient patch merging stage
                merge = nn.Sequential(
                    nn.GroupNorm(groups, dims[i]),
                    nn.Conv2d(dims[i], dims[i+1], 3, stride=2, padding=1, groups=groups)
                )
                self.stages.append(merge)
                
        self.norm = nn.GroupNorm(groups, dims[-1])
        self.head = nn.Linear(dims[-1], num_classes)
        
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
                
    def forward(self, x):
        x = self.patch_embed(x)  # [B, L, C]
        B, L, C = x.shape
        H = W = int(L ** 0.5)  # e.g., 28x28 for input images of size 224
        x = self.pos_drop(x)
        
        for stage in self.stages:
            # Determine if stage is a transformer block or a merging stage.
            if hasattr(stage[0], '__class__') and stage[0].__class__.__name__ == "SlimSwinBlock":
                # For transformer blocks, iterate over each block with H and W.
                for block in stage:
                    x = block(x, H, W)
            else:
                # Patch merging stage: reshape, process, then flatten back.
                x = x.view(B, H, W, C).permute(0, 3, 1, 2)  # [B, C, H, W]
                x = stage(x)
                B, C, H, W = x.shape
                x = x.view(B, C, -1).permute(0, 2, 1)  # [B, L, C]
        x = self.norm(x.mean(dim=1))
        return self.head(x)

# Training and evaluation routines

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print(f"Epoch {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}")

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print(f"\nTest set: Avg loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} "
          f"({100. * correct / len(test_loader.dataset):.0f}%)\n")

# Main script

if __name__ == "__main__":
    # Use CUDA if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Instantiate the model
    model = SlimSwin(
        img_size=64,  # Note: this parameter is used for grid computation in patch embedding.
        patch_size=4,
        num_classes=200,
        depths=[2, 2, 2],
        dims=[48, 96, 192],
        num_heads=[4, 8, 16],
        groups=4
    ).to(device)
    
    optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.05)
    
    num_epochs = 200
    for epoch in range(1, num_epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
    
    # Optional: test a forward pass with a random input
    x = torch.randn(2, 3, 64, 64).to(device)
    output = model(x)
    print(f"Model Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")
    print(f"Output shape: {output.shape}")  # Expected: (2, 200)

Epoch 1 [0/100000] Loss: 5.338192


KeyboardInterrupt: 