# Notebook cell In [2]:
import torch
import sys
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
sys.path.append("..")
import torchvision.transforms as transforms
from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders

# Enhanced data augmentation for training, deterministic preprocessing for evaluation
image_size = 224
_enlarged_side = image_size + 20

# Per-channel mean/std handed to Normalize; identical for train and eval.
_channel_mean = (0.485, 0.456, 0.406)
_channel_std = (0.229, 0.224, 0.225)

# Training pipeline. Order matters: the image-level augmentations run before
# ToTensor/Normalize, and RandomErasing runs last on the normalized tensor.
_train_steps = [
    transforms.RandomHorizontalFlip(),
    # Upscale slightly beyond the target size so the random crop has room to move.
    transforms.Resize((_enlarged_side, _enlarged_side)),
    transforms.RandomCrop(image_size, padding=8, padding_mode='reflect'),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.RandomGrayscale(p=0.1),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.1), value='random'),
]
tiny_transform_train = transforms.Compose(_train_steps)

# Evaluation pipeline: resize straight to the model input size, no randomness.
_eval_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
]
tiny_transform_val = transforms.Compose(_eval_steps)

# The validation transform is reused for the test split.
_loader_kwargs = {
    'data_dir': '../datasets',
    'transform_train': tiny_transform_train,
    'transform_val': tiny_transform_val,
    'transform_test': tiny_transform_val,
    'batch_size': 64,
    'image_size': image_size,
}
train_loader, val_loader, test_loader = get_tinyimagenet_dataloaders(**_loader_kwargs)

# Helper function: window_reverse (used in SwinLiteBlock)
def window_reverse(windows, window_size, H, W):
    """Reassemble non-overlapping windows into full feature maps.

    Inverse of the window partition performed in SwinLiteBlock.forward.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    rows = H // window_size
    cols = W // window_size
    # Pure integer arithmetic: the batch size is a count, so there is no need
    # to go through float division and truncate the result with int().
    B = windows.shape[0] // (rows * cols)
    x = windows.view(B, rows, cols, window_size, window_size, -1)
    # Interleave (window row, in-window row) and (window col, in-window col)
    # so the two spatial axes can be merged back into H and W.
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel gate for token sequences.

    The input is averaged over the token axis, passed through a two-layer
    bottleneck (dim -> dim // reduction -> dim) and a sigmoid, and the
    resulting per-channel weights rescale every token.

    Args:
        dim (int): channel dimension of the input.
        reduction (int): bottleneck reduction factor.
    """

    def __init__(self, dim, reduction=4):
        super().__init__()
        bottleneck = dim // reduction
        self.fc1 = nn.Linear(dim, bottleneck)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(bottleneck, dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Gate `x` of shape (batch, tokens, channels); output has the same shape."""
        # The unpacking doubles as a check that the input is rank 3.
        _batch, _tokens, _channels = x.shape
        pooled = x.mean(dim=1)
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(pooled))))
        # Broadcast the (batch, channels) gate over the token axis.
        return x * gate.unsqueeze(1)

class DSConvMLP(nn.Module):
    """Feed-forward block with a depthwise 1-D convolution between the two linears.

    Pipeline: Linear(dim -> hidden) -> GELU -> dropout -> depthwise Conv1d
    (kernel 3, along the token axis) -> GELU -> dropout -> Linear(hidden -> dim)
    -> optional SEBlock gate.

    Args:
        dim (int): input/output channel dimension.
        mlp_ratio (float): hidden width as a multiple of `dim`.
        dropout (float): dropout probability used after both activations.
        use_se (bool): apply an SEBlock to the output when True, otherwise
            pass the output through unchanged.
    """

    def __init__(self, dim, mlp_ratio=4.0, dropout=0.2, use_se=True):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden_dim)
        # groups == channels makes the convolution depthwise (one filter per channel).
        self.depthwise = nn.Conv1d(
            hidden_dim,
            hidden_dim,
            kernel_size=3,
            padding=1,
            groups=hidden_dim,
        )
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(dropout)
        if use_se:
            self.se = SEBlock(dim)
        else:
            self.se = nn.Identity()

    def forward(self, x):
        """Transform `x` of shape (batch, tokens, dim); output has the same shape."""
        hidden = self.fc1(x)
        hidden = self.drop(self.act(hidden))
        # Conv1d expects (batch, channels, tokens); swap axes around the call.
        channels_first = hidden.permute(0, 2, 1)
        mixed = self.depthwise(channels_first).permute(0, 2, 1)
        mixed = self.drop(self.act(mixed))
        projected = self.fc2(mixed)
        return self.se(projected)

class WindowAttention(nn.Module):
    """Multi-head self-attention applied independently to each window.

    Every row of the leading axis is treated as one window of N tokens.
    No relative position bias and no attention mask are used.

    Args:
        dim (int): channel dimension of the tokens.
        window_size (int): window side length; stored on the module but not
            read by `forward`.
        num_heads (int): number of attention heads.
    """

    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        # Standard 1/sqrt(head_dim) scaling of the attention logits.
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.window_size = window_size

    def forward(self, x):
        """Attend within each window; `x` is (num_windows*B, N, C), same shape out."""
        B_, N, C = x.shape
        heads = self.num_heads
        # (B_, N, 3C) -> (3, B_, heads, N, head_dim), then split into q, k, v.
        packed = self.qkv(x).reshape(B_, N, 3, heads, C // heads)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = scores.softmax(dim=-1)
        # Merge the heads back into the channel axis.
        context = (weights @ v).transpose(1, 2).reshape(B_, N, C)
        return self.proj(context)

class SwinLiteBlock(nn.Module):
    """One transformer block: (shifted) window attention followed by DSConvMLP.

    Both sub-layers are pre-normalized with LayerNorm and wrapped in residual
    connections.

    NOTE(review): unlike the reference Swin Transformer, no attention mask is
    built for shifted windows, so tokens that the cyclic shift wraps around
    from opposite borders attend to each other inside a window.

    Args:
        dim (int): channel dimension of the tokens.
        input_resolution (tuple[int, int]): (H, W) of the token grid. Both must
            be multiples of the effective window size.
        num_heads (int): number of attention heads.
        window_size (int): requested window side length. Clamped to
            ``min(input_resolution)`` when the grid is not larger than it.
        shift_size (int): cyclic shift applied before windowing; forced to 0
            when the grid is not larger than the window.
        mlp_ratio (float): hidden width multiplier for the MLP.
        dropout (float): dropout probability inside the MLP.
        use_se (bool): whether the MLP ends with an SEBlock.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4.0, dropout=0.2, use_se=True):
        super().__init__()
        # A grid no larger than the window is covered by a single window, so
        # shifting is pointless. Shrinking the window to the grid also keeps
        # H // window_size from becoming 0, which would break the partition.
        if min(input_resolution) <= window_size:
            window_size = min(input_resolution)
            shift_size = 0

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = DSConvMLP(dim, mlp_ratio, dropout, use_se)
        self.shift_size = shift_size
        self.window_size = window_size
        self.input_resolution = input_resolution

    def forward(self, x):
        """Process tokens `x` of shape (B, H*W, C); output has the same shape."""
        H, W = self.input_resolution
        B, _, C = x.shape
        ws = self.window_size

        # --- Window attention branch ---
        shortcut = x
        x = self.norm1(x).view(B, H, W, C)

        # Cyclic shift so that this block's windows straddle the previous
        # block's window borders.
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))

        # Partition into (num_windows*B, ws*ws, C) and attend within each window.
        windows = x.view(B, H // ws, ws, W // ws, ws, C)
        windows = windows.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws * ws, C)
        attended = self.attn(windows).view(-1, ws, ws, C)
        x = window_reverse(attended, ws, H, W)

        # Undo the cyclic shift.
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))

        # Back to a token sequence, plus the first residual connection.
        x = x.view(B, H * W, C) + shortcut

        # --- MLP branch with the second residual connection ---
        return self.mlp(self.norm2(x)) + x

class SwinLite(nn.Module):
    """Hierarchical window-attention classifier built from SwinLiteBlock stages.

    The image is split into patches by a strided convolution, processed by
    `len(depths)` stages of SwinLiteBlocks, and downsampled between stages
    (resolution halved, channels doubled). The final tokens are mean-pooled,
    normalized and fed to a linear classification head.

    `self.layers` alternates stage (nn.Sequential of blocks) and downsample
    (nn.Sequential(LayerNorm, Conv2d) tagged with ``is_downsample = True``)
    entries; the last stage has no downsample after it.

    Args:
        image_size (int): side length of the (square) input image.
        patch_size (int): side length of one patch.
        in_chans (int): number of input image channels.
        num_classes (int): size of the classification output.
        embed_dim (int): channel dimension after patch embedding.
        depths (sequence of int): number of blocks in each stage.
        num_heads (sequence of int): attention heads in each stage.
        window_size (int): attention window side length.
        mlp_ratio (float): hidden width multiplier for the block MLPs.
        dropout (float): dropout probability (after patch embedding and
            inside the block MLPs).
        use_se (bool): whether block MLPs end with an SEBlock.
    """

    # Tuple defaults instead of lists: default values are shared across calls,
    # so they should be immutable.
    def __init__(self, image_size=224, patch_size=4, in_chans=3, num_classes=200,
                 embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
                 window_size=7, mlp_ratio=4.0, dropout=0.2, use_se=True):
        super().__init__()
        self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_drop = nn.Dropout(dropout)

        # Token-grid resolution of the stage currently being built. After
        # construction this holds the resolution of the last stage.
        self.curr_res = (image_size // patch_size, image_size // patch_size)
        self.layers = nn.ModuleList()
        dim = embed_dim
        last_stage = len(depths) - 1

        for i, depth in enumerate(depths):
            # Blocks alternate between regular (even index) and shifted
            # (odd index) windows.
            stage_blocks = [
                SwinLiteBlock(
                    dim=dim,
                    input_resolution=self.curr_res,
                    num_heads=num_heads[i],
                    window_size=window_size,
                    shift_size=0 if j % 2 == 0 else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout,
                    use_se=use_se,
                )
                for j in range(depth)
            ]
            self.layers.append(nn.Sequential(*stage_blocks))

            # Downsample between stages, but not after the last one.
            if i != last_stage:
                downsample = nn.Sequential(
                    nn.LayerNorm(dim),
                    nn.Conv2d(dim, dim * 2, kernel_size=2, stride=2),
                )
                # Tag read by forward() to tell downsample layers from stages.
                downsample.is_downsample = True
                self.layers.append(downsample)
                dim *= 2
                self.curr_res = (self.curr_res[0] // 2, self.curr_res[1] // 2)

        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize Linear, LayerNorm and Conv2d modules; others keep their defaults."""
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Classify a batch of images `x` (B, in_chans, H, W); returns (B, num_classes) logits."""
        # Patch embedding: (B, C, h, w) feature map -> (B, h*w, C) tokens.
        x = self.patch_embed(x)
        B, _, curr_h, curr_w = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.pos_drop(x)

        for layer in self.layers:
            if getattr(layer, 'is_downsample', False):
                norm, reduce = layer[0], layer[1]
                # Tokens are already channels-last, so LayerNorm can run on the
                # (B, h, w, C) grid directly; only the conv needs channels-first.
                x = norm(x.reshape(B, curr_h, curr_w, -1))
                x = reduce(x.permute(0, 3, 1, 2))
                curr_h, curr_w = x.shape[-2:]
                x = x.flatten(2).transpose(1, 2)
            else:
                x = layer(x)

        # Mean-pool over tokens, normalize, classify.
        x = self.norm(x.mean(1))
        return self.head(x)

# Training Configuration
def _evaluate(model, loader, device, criterion=None):
    """Run `model` over `loader` without gradients.

    Returns:
        (mean_loss, accuracy_percent). `mean_loss` is 0.0 when `criterion` is
        None; both values are 0.0 when the loader yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            if criterion is not None:
                loss_sum += criterion(outputs, labels).item() * images.size(0)
            correct += outputs.argmax(1).eq(labels).sum().item()
            total += labels.size(0)
    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, 100 * correct / total


def main(num_epochs=200, checkpoint_path="best_swinlite.pth"):
    """Train SwinLite, keep the best-validation checkpoint, and report test accuracy.

    Uses the module-level `train_loader`, `val_loader` and `test_loader`.

    Args:
        num_epochs (int): number of training epochs; also the cosine schedule length.
        checkpoint_path (str): where the best model weights are saved.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SwinLite().to(device)

    # Optimizer and loss with regularization
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.1)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Start below any reachable accuracy so the first epoch always writes a
    # checkpoint, even if validation accuracy is 0.
    best_val_acc = float("-inf")
    checkpoint_saved = False

    for epoch in range(num_epochs):
        model.train()
        train_loss = correct = total = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()

            # Gradient clipping
            nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            train_loss += loss.item() * images.size(0)
            correct += outputs.argmax(1).eq(labels).sum().item()
            total += labels.size(0)

        scheduler.step()

        # Validation
        val_loss, val_acc = _evaluate(model, val_loader, device, criterion)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

        seen = max(total, 1)  # avoid dividing by zero on an empty train loader
        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss/seen:.4f} | Acc: {100*correct/seen:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.2f}%")
        print(f"LR: {scheduler.get_last_lr()[0]:.2e}\n")

    # Final Test Evaluation on the best checkpoint. If nothing was saved
    # (num_epochs == 0), the current weights are evaluated instead.
    if checkpoint_saved:
        # map_location keeps the load working regardless of the device the
        # tensors were saved from.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    _, test_acc = _evaluate(model, test_loader, device)

    print(f"Final Test Accuracy: {test_acc:.2f}%")

# Entry point: run training only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()


# Notebook output: KeyboardInterrupt (presumably the cell run was interrupted manually)