# Notebook cell In [3]
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from einops import rearrange, reduce
import os
from timm.data import Mixup
from timm.loss import SoftTargetCrossEntropy
from timm.models.layers import DropPath
import numpy as np
import sys
sys.path.append("..")
from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders
import torch.nn.functional as F

# ------------------- Model Definition -------------------
class HybridStem(nn.Module):
    """Convolutional stem that turns an image into a coarse feature map.

    A strided 3x3 convolution is followed by a depthwise 3x3 / pointwise 1x1
    pair and a 2x2 average pool.  The output has ``2 * out_ch`` channels and,
    for even input sizes, a quarter of the input's height and width.

    Args:
        in_ch: Number of input image channels.
        out_ch: Width of the first convolution; the stem outputs twice this.
    """

    def __init__(self, in_ch=3, out_ch=48):
        super().__init__()
        wide_ch = out_ch * 2

        # Strided conv: halves the spatial resolution.
        self.conv1 = nn.Sequential(
            *self._conv_bn_gelu(in_ch, out_ch, kernel_size=3, stride=2, padding=1)
        )

        # Depthwise 3x3 followed by a pointwise 1x1 that doubles the channels.
        depthwise = self._conv_bn_gelu(
            out_ch, out_ch, kernel_size=3, padding=1, groups=out_ch
        )
        pointwise = self._conv_bn_gelu(out_ch, wide_ch, kernel_size=1)
        self.conv2 = nn.Sequential(*depthwise, *pointwise)

        self.pool = nn.AvgPool2d(2, stride=2)

    @staticmethod
    def _conv_bn_gelu(c_in, c_out, **conv_kwargs):
        """Return the layer triple ``[Conv2d, BatchNorm2d, GELU]``."""
        return [
            nn.Conv2d(c_in, c_out, **conv_kwargs),
            nn.BatchNorm2d(c_out),
            nn.GELU(),
        ]

    def forward(self, x):
        """Map ``(B, in_ch, H, W)`` images to ``(B, 2*out_ch, H/4, W/4)`` features."""
        # For a 64x64 input: 32x32 after conv1, 16x16 after the pool.
        features = self.conv2(self.conv1(x))
        return self.pool(features)

class AdaptiveWindowAttention(nn.Module):
    """Multi-head self-attention inside non-overlapping spatial windows.

    The window side is picked at run time.  A small MLP scores every size in
    ``range(*window_range)`` from the spatially averaged features; among the
    sizes that divide both the height and the width, the one closest to the
    prediction for the *first* sample of the batch is used for the whole
    batch.  If no candidate divides the feature map, a single window covering
    the entire map is used instead.

    NOTE: the size is selected through ``argmax`` and ``.item()``, so no
    gradient flows from the attention output back into ``ws_predictor``.

    Args:
        dim: Channel width of the input and output tokens.
        num_heads: Number of attention heads; ``dim`` must be divisible by it.
        window_range: ``(start, stop)`` handed to ``range`` to enumerate the
            candidate window sizes (``stop`` is exclusive).
    """

    def __init__(self, dim, num_heads, window_range=(4,8)):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_range = window_range

        # One logit per candidate window size.
        self.ws_predictor = nn.Sequential(
            nn.Linear(dim, dim//4),
            nn.GELU(),
            nn.Linear(dim//4, len(range(*window_range)))
        )

        self.qkv = nn.Linear(dim, dim*3)
        self.proj = nn.Linear(dim, dim)

    def get_window_size(self, x):
        """Return the predicted window size per sample for ``x`` of shape (B, H, W, C)."""
        pooled = x.mean(dim=(1, 2))
        logits = self.ws_predictor(pooled)
        return torch.argmax(logits, dim=1) + self.window_range[0]

    def forward(self, x):
        """Apply windowed self-attention.

        Args:
            x: Tensor of shape ``(B, H, W, C)``.

        Returns:
            Tensor of shape ``(B, H, W, C)`` in the same spatial layout as ``x``.
        """
        B, H, W, C = x.shape
        head_dim = C // self.num_heads

        # The prediction of the first sample decides for the whole batch.
        ws_pred = self.get_window_size(x)[0].item()
        usable = [
            ws for ws in range(*self.window_range)
            if H % ws == 0 and W % ws == 0
        ]
        if usable:
            win_h = win_w = min(usable, key=lambda ws: abs(ws - ws_pred))
        else:
            # No candidate tiles the map: attend over the whole map at once.
            # Using (H, W) rather than a square window keeps non-square
            # feature maps working.
            win_h, win_w = H, W
        n_h, n_w = H // win_h, W // win_w
        n_windows = B * n_h * n_w
        n_tokens = win_h * win_w

        # Window partition: (B, H, W, C) -> (B*n_h*n_w, win_h*win_w, C).
        # The two window-grid axes must be adjacent before flattening so that
        # every row of `windows` holds exactly one spatial window.
        windows = (
            x.reshape(B, n_h, win_h, n_w, win_w, C)
            .permute(0, 1, 3, 2, 4, 5)
            .reshape(n_windows, n_tokens, C)
        )

        q, k, v = (
            t.reshape(n_windows, n_tokens, self.num_heads, head_dim).transpose(1, 2)
            for t in self.qkv(windows).chunk(3, dim=-1)
        )

        # Scaled dot-product attention; the scale uses the per-head width.
        attn = (q @ k.transpose(-2, -1)) * (head_dim ** -0.5)
        attn = attn.softmax(dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(n_windows, n_tokens, C)
        out = self.proj(out)

        # Window reverse: exact inverse of the partition above.
        out = (
            out.reshape(B, n_h, n_w, win_h, win_w, C)
            .permute(0, 1, 3, 2, 4, 5)
            .reshape(B, H, W, C)
        )
        return out

class SparseCrossAttention(nn.Module):
    """Attention from a sparse subset of queries to all tokens plus a global token.

    A learned global token is prepended to the flattened feature map.  Per
    head, queries are ranked by L2 norm and only the top ``sparse_ratio``
    fraction attends to the full key/value set.  The first row of the
    projected result is returned.

    NOTE: ``topk`` returns indices sorted by descending score, so the returned
    vector belongs to the highest-norm query of each head; that query is not
    necessarily the global token.

    Args:
        dim: Channel width of the tokens.
        num_heads: Number of attention heads; ``dim`` must be divisible by it.
        sparse_ratio: Fraction of queries kept per head.
    """

    def __init__(self, dim, num_heads, sparse_ratio=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.sparse_ratio = sparse_ratio
        self.global_token = nn.Parameter(torch.randn(1, 1, dim))
        self.qkv = nn.Linear(dim, dim*3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Summarise ``x`` of shape ``(B, H, W, C)`` into a ``(B, C)`` vector."""
        B, H, W, C = x.shape
        head_dim = C // self.num_heads

        global_tok = self.global_token.expand(B, -1, -1)
        tokens = torch.cat([global_tok, x.reshape(B, H * W, C)], dim=1)
        n_tokens = tokens.size(1)

        q, k, v = (
            t.reshape(B, n_tokens, self.num_heads, head_dim).transpose(1, 2)
            for t in self.qkv(tokens).chunk(3, dim=-1)
        )

        # Rank queries by norm and keep the strongest ones.  Always keep at
        # least one (small maps would otherwise yield zero queries and make
        # the final indexing fail) and never more than exist.
        scores = torch.norm(q, dim=-1)
        keep = min(n_tokens, max(1, int(n_tokens * self.sparse_ratio)))
        _, idx = torch.topk(scores, k=keep, dim=-1)

        sparse_q = torch.gather(
            q, -2, idx.unsqueeze(-1).expand(-1, -1, -1, head_dim)
        )
        # Scaled dot-product attention; the scale uses the per-head width.
        attn = (sparse_q @ k.transpose(-2, -1)) * (head_dim ** -0.5)
        attn = attn.softmax(dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B, keep, C)
        out = self.proj(out)
        return out[:, 0]


class DWSSBlock(nn.Module):
    """One block: windowed attention, a sparse cross-attention summary, and an MLP.

    Args:
        dim: Channel width of the tokens.
        num_heads: Number of heads for both attention modules.
        window_range: ``(start, stop)`` candidate window sizes for the
            windowed attention.
        sparse_ratio: Fraction of queries kept by the sparse cross-attention.
        drop_path: Stochastic-depth rate applied to both residual branches;
            ``0`` disables it.
    """

    def __init__(self, dim, num_heads, window_range, sparse_ratio=0.1, drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.awa = AdaptiveWindowAttention(dim, num_heads, window_range)
        self.sca = SparseCrossAttention(dim, num_heads, sparse_ratio)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim*4),
            nn.GELU(),
            nn.Linear(dim*4, dim)
        )
        # Projection to 128 dimensions so that all cls tokens match distillation target.
        self.cls_proj = nn.Linear(dim, 128)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        """Process ``x`` of shape ``(B, H, W, dim)``.

        Returns:
            A tuple ``(x, cls_token)``: the updated feature map with the same
            shape as the input, and a ``(B, 128)`` summary token.
        """
        x = x + self.drop_path(self.awa(self.norm1(x)))

        # Both the summary branch and the MLP branch consume the same
        # normalised tensor, so normalise once and share it.
        normed = self.norm2(x)
        cls_token = self.cls_proj(self.sca(normed))
        x = x + self.drop_path(self.mlp(normed))
        return x, cls_token

class DWSS(nn.Module):
    """Three-stage hierarchical classifier built from ``DWSSBlock`` layers.

    The stem output is processed by three stages of 2, 4 and 2 blocks at
    widths 96, 128 and 256, with a strided 2x2 convolution between stages.
    Every block also emits a 128-d summary token; the tokens are pulled
    towards a learned distillation target through a KL term.

    Args:
        num_classes: Size of the classification head output.
    """

    def __init__(self, num_classes=200):
        super().__init__()
        self.stem = HybridStem()

        # Stage 1
        self.stage1 = nn.ModuleList([
            DWSSBlock(96, 4, (4,6), 0.1) for _ in range(2)
        ])

        # Stage 2
        self.down1 = nn.Sequential(
            nn.Conv2d(96, 128, 2, stride=2),
            nn.BatchNorm2d(128),
            nn.GELU()
        )

        self.stage2 = nn.ModuleList([
            DWSSBlock(128, 8, (6,8), 0.2) for _ in range(4)
        ])

        # Stage 3
        self.down2 = nn.Sequential(
            nn.Conv2d(128, 256, 2, stride=2),
            nn.BatchNorm2d(256),
            nn.GELU()
        )

        self.stage3 = nn.ModuleList([
            DWSSBlock(256, 16, (8,9), 0.3) for _ in range(2)
        ])

        # Head
        self.distill_token = nn.Parameter(torch.randn(1, 1, 256))
        self.distill_head = nn.Linear(256, 128)
        self.head = nn.Linear(256, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch of shape ``(B, 3, H, W)``.

        Returns:
            A tuple ``(logits, distill_loss)`` where ``logits`` has shape
            ``(B, num_classes)`` and ``distill_loss`` is a scalar tensor.
        """
        x = self.stem(x)  # [B, 96, 16, 16] for a 64x64 input
        x = rearrange(x, 'b c h w -> b h w c')

        cls_tokens = []
        pipeline = (
            (None, self.stage1),
            (self.down1, self.stage2),
            (self.down2, self.stage3),
        )
        for downsample, stage in pipeline:
            if downsample is not None:
                # The conv downsamplers expect channels-first input.
                x = rearrange(x, 'b h w c -> b c h w')
                x = downsample(x)
                x = rearrange(x, 'b c h w -> b h w c')
            for blk in stage:
                x, cls = blk(x)
                cls_tokens.append(cls)

        # Distillation.  The target is squeezed to (B, 128) so it lines up
        # with each cls token instead of broadcasting to (B, B, 128).
        distill_target = self.distill_head(
            self.distill_token.expand(x.size(0), -1, -1)
        ).squeeze(1)
        target_prob = F.softmax(distill_target, dim=-1)
        # 'batchmean' is the reduction that matches the mathematical KL
        # divergence (sum over classes, mean over the batch).
        distill_loss = torch.stack([
            F.kl_div(
                F.log_softmax(token, dim=-1),
                target_prob,
                reduction='batchmean'
            ) for token in cls_tokens
        ]).mean()

        # Classification
        x = reduce(x, 'b h w c -> b c', 'mean')
        return self.head(x), distill_loss

# ------------------- Data Loading -------------------
class TinyImageNet:
    """Thin wrapper around ``ImageFolder`` for one split directory.

    Images found under ``<root>/<split>`` are resized to a square of side
    ``img_size``, converted to tensors and normalised per channel.

    Args:
        root: Directory that contains the split sub-directories.
        split: Name of the sub-directory to load.
        img_size: Side length the images are resized to.
    """

    def __init__(self, root='./data', split='train', img_size=64):
        self.root = root
        self.split = split

        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        steps = [
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=channel_mean, std=channel_std),
        ]
        self.transform = transforms.Compose(steps)

        split_dir = os.path.join(root, split)
        self.data = torchvision.datasets.ImageFolder(
            split_dir, transform=self.transform)

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, index):
        """Return whatever the wrapped dataset yields at ``index``."""
        return self.data[index]

# ------------------- Training Utilities -------------------
class EMA:
    """Exponential moving average of a model's trainable parameters.

    Args:
        model: Module whose parameters with ``requires_grad=True`` are tracked.
        decay: Weight given to the previous average on every update.

    Attributes are plain dicts keyed by parameter name: ``shadow`` holds the
    running averages, ``original`` holds the weights saved by ``apply`` until
    ``restore`` puts them back.
    """

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {}
        self.original = {}

        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def apply(self, model):
        """Load the averaged weights into ``model``, saving the current ones."""
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.original[name] = param.data.clone()
                # Copy values rather than rebinding ``param.data``: rebinding
                # would make the parameter share storage with the shadow
                # tensor, so later in-place updates would corrupt the average.
                param.data.copy_(self.shadow[name])

    def restore(self, model):
        """Undo ``apply`` by writing the saved weights back into ``model``."""
        for name, param in model.named_parameters():
            if name in self.original:
                param.data.copy_(self.original[name])
        self.original = {}

    def update(self, model):
        """Blend the model's current weights into the running averages."""
        with torch.no_grad():
            for name, param in model.named_parameters():
                if param.requires_grad:
                    # shadow <- decay * shadow + (1 - decay) * param, in place.
                    self.shadow[name].mul_(self.decay).add_(
                        param.data, alpha=1.0 - self.decay
                    )

def train_epoch(model, loader, optimizer, criterion, ema, mixup_fn, device,
                distill_weight=0.3, max_grad_norm=1.0):
    """Run one training pass over ``loader``.

    Args:
        model: Module returning ``(logits, distill_loss)`` for a batch.
        loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss applied to ``(logits, targets)`` when ``mixup_fn`` is
            given; otherwise ``F.cross_entropy`` is used.
        ema: Object with an ``update(model)`` method called after every
            optimizer step, or ``None`` to skip averaging.
        mixup_fn: Optional callable mapping ``(inputs, targets)`` to mixed
            inputs and soft targets.
        device: Device the batches are moved to.
        distill_weight: Multiplier of the distillation term in the loss.
        max_grad_norm: Gradient-norm clipping threshold.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``; ``(0.0, 0.0)`` when the
        loader yields no batches.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    seen = 0
    num_batches = 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        if mixup_fn is not None:
            inputs, targets = mixup_fn(inputs, targets)

        optimizer.zero_grad()
        outputs, distill_loss = model(inputs)

        if mixup_fn is not None:
            loss = criterion(outputs, targets)
        else:
            loss = F.cross_entropy(outputs, targets)

        loss = loss + distill_weight * distill_loss
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        if ema is not None:
            ema.update(model)

        running_loss += loss.item()
        num_batches += 1

        predicted = outputs.argmax(dim=1)
        # Soft (mixed) targets are reduced to their dominant class so an
        # accuracy figure can still be reported.
        hard_targets = targets.argmax(dim=1) if targets.dim() > 1 else targets
        seen += hard_targets.size(0)
        correct += predicted.eq(hard_targets).sum().item()

    # An empty loader would otherwise divide by zero.
    if num_batches == 0 or seen == 0:
        return 0.0, 0.0
    return running_loss / num_batches, 100 * correct / seen


    

@torch.no_grad()
def validate(model, loader, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Args:
        model: Module returning ``(logits, aux)`` for a batch; ``aux`` is
            ignored.  The module is switched to eval mode.
        loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the batches are moved to.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``; ``(0.0, 0.0)`` when the
        loader yields no batches.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    seen = 0
    num_batches = 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs, _ = model(inputs)

        running_loss += F.cross_entropy(outputs, targets).item()
        num_batches += 1

        predicted = outputs.argmax(dim=1)
        # Convert soft targets to hard labels if necessary
        hard_targets = targets.argmax(dim=1) if targets.dim() > 1 else targets
        seen += hard_targets.size(0)
        correct += predicted.eq(hard_targets).sum().item()

    # An empty loader would otherwise divide by zero.
    if num_batches == 0 or seen == 0:
        return 0.0, 0.0
    return running_loss / num_batches, 100 * correct / seen

# ------------------- Main Training Loop -------------------
def main():
    """Train DWSS with fixed hyperparameters and keep the best checkpoint.

    Builds the augmentation pipelines and data loaders, trains with AdamW
    under a cosine learning-rate schedule, validates after every epoch and
    writes the model's ``state_dict`` to ``best_dwss.pth`` whenever the
    validation accuracy improves.
    """
    # Hard-coded hyperparameters and configurations
    data_dir = '../datasets'
    epochs = 300
    batch_size = 64  # Adjust as needed
    lr = 2e-4
    wd = 0.05
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    image_size = 224
    norm_mean = (0.485, 0.456, 0.406)
    norm_std = (0.229, 0.224, 0.225)

    tiny_transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Resize((image_size + 20, image_size + 20)),
        transforms.RandomCrop(image_size, padding=8, padding_mode='reflect'),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        transforms.RandomGrayscale(p=0.1),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
        transforms.RandomErasing(p=0.5, scale=(0.02, 0.1), value='random'),
    ])

    tiny_transform_val = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std)
    ])

    # The test loader is returned by the helper but not used in this script.
    train_loader, val_loader, _test_loader = get_tinyimagenet_dataloaders(
        data_dir=data_dir,
        transform_train=tiny_transform_train,
        transform_val=tiny_transform_val,
        transform_test=tiny_transform_val,
        batch_size=batch_size,
        image_size=image_size
    )

    mixup_fn = Mixup(mixup_alpha=0.2, cutmix_alpha=1.0, prob=0.5, num_classes=200)

    # Initialize model, optimizer, scheduler, loss function, and EMA
    model = DWSS(num_classes=200).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = SoftTargetCrossEntropy() if mixup_fn is not None else nn.CrossEntropyLoss()
    # NOTE: the EMA weights are updated during training, but validation and
    # the saved checkpoint below both use the raw model weights.
    ema = EMA(model)

    best_acc = 0
    for epoch in range(epochs):
        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, criterion, ema, mixup_fn, device
        )
        val_loss, val_acc = validate(model, val_loader, device)

        scheduler.step()

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.2f}%")

        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), 'best_dwss.pth')

    print(f"Best Validation Accuracy: {best_acc:.2f}%")

if __name__ == '__main__':
    main()


# (end of notebook cell In [3])