# In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import transforms, datasets
import numpy as np

# --- Enhanced Data Augmentation ---
# Per-channel normalization statistics shared by both pipelines
# (these are the widely used ImageNet values).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: image-space augmentations first, then tensor conversion
# and normalization, and finally random erasing on the normalized tensor.
_train_steps = [
    transforms.RandomResizedCrop(64, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(15),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    transforms.RandomErasing(p=0.5, scale=(0.02, 0.1)),
]
train_transform = transforms.Compose(_train_steps)

# Evaluation pipeline: deterministic resize to 72 followed by a 64x64 center
# crop, with the same tensor conversion and normalization as training.
_test_steps = [
    transforms.Resize(72),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
]
test_transform = transforms.Compose(_test_steps)

# --- Relative Position Attention ---
class WindowAttention(nn.Module):
    """Multi-head self-attention within a window, with relative position bias.

    Args:
        dim: Channel width of the input tokens.
        window_size: Side length of the square window; every window holds
            ``window_size ** 2`` tokens.
        num_heads: Number of attention heads; must divide ``dim``.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

        # One learnable bias per head for every possible (dy, dx) offset
        # between two tokens of the same window.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads))

        # Pre-compute, for every (query, key) token pair, the row of the bias
        # table that holds the bias for their relative offset.
        coords_h = torch.arange(window_size)
        coords_w = torch.arange(window_size)
        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij'))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        # Shift offsets to start at 0, then fold (dy, dx) into one flat index.
        relative_coords[:, :, 0] += window_size - 1
        relative_coords[:, :, 1] += window_size - 1
        relative_coords[:, :, 0] *= 2 * window_size - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        """Apply windowed self-attention.

        Args:
            x: Tensor of shape ``(num_windows * batch, N, C)`` where
                ``N == window_size ** 2``.
            mask: Optional additive mask of shape ``(num_windows, N, N)``
                (0 keeps a pair, a large negative value suppresses it). The
                leading dimension of ``x`` must be a multiple of
                ``num_windows``, ordered batch-major. ``None`` disables masking.

        Returns:
            Tensor with the same shape as ``x``.
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale

        # Add relative position bias (shared across the batch).
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size * self.window_size,
                self.window_size * self.window_size,
                -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Broadcast the per-window mask over the batch and head axes.
            num_windows = mask.shape[0]
            attn = attn.view(B_ // num_windows, num_windows, self.num_heads, N, N)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)

        attn = attn.softmax(dim=-1)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        return x

# --- Improved Patch Merging ---
class PatchMerging(nn.Module):
    """Halve the spatial resolution and double the channel width.

    Each 2x2 neighbourhood is concatenated along the channel axis (4 * dim),
    layer-normalized, and linearly reduced to 2 * dim channels.

    Args:
        dim: Channel width of the incoming feature map.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):
        """Merge patches of a ``(B, H, W, C)`` map into ``(B, H/2, W/2, 2C)``."""
        B, H, W, C = x.shape
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], -1)
        x = self.norm(x)
        x = self.reduction(x)
        return x

# --- SwinLite Block with Enhanced Components ---
class SwinLiteBlock(nn.Module):
    """Transformer block: (shifted) window attention followed by an MLP.

    Both sub-layers are pre-normalized and wrapped in residual connections.

    Args:
        dim: Channel width of the tokens.
        input_resolution: ``(H, W)`` of the token grid this block operates on.
        num_heads: Number of attention heads.
        window_size: Requested window side length. Clamped to the smaller grid
            side when the grid does not exceed the window, in which case a
            single unshifted window covers the grid.
        shift_size: Cyclic shift applied before partitioning (0 disables it).
        mlp_ratio: Forwarded to ``DSConvMLP``.
        dropout: Forwarded to ``DSConvMLP``.
        use_se: Forwarded to ``DSConvMLP``.
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4.0, dropout=0.0, use_se=True):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        if min(input_resolution) <= window_size:
            # The grid fits inside one window: shrink the window to the grid
            # and drop the shift, otherwise partitioning has nothing to split.
            window_size = min(input_resolution)
            shift_size = 0
        self.window_size = window_size
        self.shift_size = shift_size

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, self.window_size, num_heads)

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = DSConvMLP(dim=dim, mlp_ratio=mlp_ratio, dropout=dropout, use_se=use_se)

        # ``None`` for unshifted blocks; a (num_windows, N, N) tensor otherwise.
        self.register_buffer("attn_mask", self._build_attn_mask())

    def _build_attn_mask(self):
        """Build the additive mask that isolates regions joined by the shift."""
        if self.shift_size <= 0:
            return None

        H, W = self.input_resolution
        window_size = self.window_size
        # Label each region of the shifted grid with a distinct id; tokens
        # with different ids were not neighbours before the cyclic shift.
        img_mask = torch.zeros((1, H, W, 1))
        h_slices = (slice(0, -window_size),
                    slice(-window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -window_size),
                    slice(-window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1
        # NOTE(review): assumes window_partition returns windows batch-major
        # as (num_windows * B, ws, ws, C) -- confirm against its definition.
        mask_windows = window_partition(img_mask, window_size)
        mask_windows = mask_windows.view(-1, window_size * window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        # Pairs from different regions get a large negative logit offset.
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        """Run the block on tokens of shape ``(B, H * W, C)``; shape is preserved."""
        H, W = self.input_resolution
        B, L, C = x.shape

        shortcut = x
        x = self.norm1(x).view(B, H, W, C)

        # Shift window attention
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        x_windows = window_partition(shifted_x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)

        # The mask keeps tokens that only became neighbours through the
        # cyclic shift from attending to each other.
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # Undo the cyclic shift so tokens return to their original positions.
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        x = shortcut + x.reshape(B, H * W, C)
        x = x + self.mlp(self.norm2(x))
        return x

# --- Complete SwinLite Model ---
class SwinLite(nn.Module):
    """Hierarchical windowed-attention image classifier.

    The image is split into non-overlapping patches by a strided convolution,
    processed by a sequence of stages (each a stack of ``SwinLiteBlock``
    followed, except for the last stage, by ``PatchMerging``), then
    layer-normalized, mean-pooled over tokens, and classified linearly.

    Args:
        image_size: Side length of the (square) input image.
        patch_size: Side length of each patch; also the embedding stride.
        in_chans: Number of input image channels.
        num_classes: Size of the classification output.
        embed_dim: Channel width of the first stage; doubles after every
            downsampling step.
        depths: Number of blocks per stage.
        num_heads: Number of attention heads per stage; same length as
            ``depths``.
        window_size: Requested attention window side length. Per stage it is
            clamped to the stage's grid size.
        mlp_ratio: Forwarded to every block.
        dropout: Forwarded to every block.
        use_se: Forwarded to every block.

    Raises:
        ValueError: If ``depths`` and ``num_heads`` differ in length.
    """

    def __init__(self, image_size=64, patch_size=4, in_chans=3, num_classes=200,
                 embed_dim=48, depths=(2, 2, 6, 2), num_heads=(2, 4, 6, 8),
                 window_size=7, mlp_ratio=4.0, dropout=0.0, use_se=True):
        super().__init__()
        if len(depths) != len(num_heads):
            raise ValueError(
                f"depths and num_heads must have the same length, "
                f"got {len(depths)} and {len(num_heads)}")

        self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        patches_resolution = (image_size // patch_size, image_size // patch_size)

        self.layers = nn.ModuleList()
        dim = embed_dim
        last_stage = len(depths) - 1

        for i, (depth, heads) in enumerate(zip(depths, num_heads)):
            has_downsample = i < last_stage
            stage = self._make_stage(
                dim=dim,
                input_resolution=patches_resolution,
                depth=depth,
                num_heads=heads,
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                use_se=use_se,
                downsample=has_downsample
            )
            self.layers.append(stage)

            if has_downsample:
                # Patch merging halves each spatial side and doubles channels.
                patches_resolution = (patches_resolution[0] // 2, patches_resolution[1] // 2)
                dim *= 2

        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        self._init_weights()

    def _init_weights(self):
        """Initialize every sub-module's parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                # ``bias`` is a tensor or None; truth-testing a multi-element
                # tensor raises, so compare against None explicitly.
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
            else:
                # Matched by attribute rather than by class so any attention
                # module that exposes a relative position bias table is covered.
                table = getattr(m, "relative_position_bias_table", None)
                if table is not None:
                    nn.init.trunc_normal_(table, std=0.02)

    def _make_stage(self, dim, input_resolution, depth, num_heads, window_size,
                    mlp_ratio, dropout, use_se, downsample):
        """Build one stage: ``depth`` blocks plus an optional ``PatchMerging``.

        Odd-indexed blocks use a shift of half the window. The window is
        clamped to the stage's grid so deep, low-resolution stages still get a
        window that fits.
        NOTE(review): grids larger than the window are presumably expected to
        be divisible by it -- confirm against window_partition.

        Returns:
            ``nn.ModuleDict`` with keys ``"blocks"`` (``nn.ModuleList``) and
            ``"downsample"`` (``PatchMerging`` or ``None``).
        """
        stage_window = min(window_size, *input_resolution)
        blocks = [
            SwinLiteBlock(
                dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=stage_window,
                shift_size=stage_window // 2 if i % 2 == 1 else 0,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                use_se=use_se
            )
            for i in range(depth)
        ]
        down = PatchMerging(dim) if downsample else None
        return nn.ModuleDict({
            "blocks": nn.ModuleList(blocks),
            "downsample": down
        })

    def forward(self, x):
        """Classify a batch of images ``(B, in_chans, H, W)`` into logits ``(B, num_classes)``."""
        x = self.patch_embed(x)
        B, C, H, W = x.shape
        # (B, C, H, W) -> (B, H * W, C) token sequence.
        x = x.flatten(2).transpose(1, 2)

        curr_res = (H, W)
        for stage in self.layers:
            for blk in stage["blocks"]:
                x = blk(x)

            downsample = stage["downsample"]
            if downsample is not None:
                B, L, C = x.shape
                # Patch merging works on the 2-D grid, so restore it first.
                x = x.view(B, curr_res[0], curr_res[1], C)
                x = downsample(x)
                x = x.flatten(1, 2)
                curr_res = (curr_res[0] // 2, curr_res[1] // 2)

        x = self.norm(x)
        x = x.mean(dim=1)
        return self.head(x)

# --- Training Configuration ---
def main(train_dir='path/to/train', val_dir='path/to/val', epochs=100,
         batch_size=128, num_workers=4, checkpoint_path='best_model.pth'):
    """Train SwinLite with mixup and keep the best-validating checkpoint.

    Args:
        train_dir: Root of the training ``ImageFolder`` tree. The default is a
            placeholder and must be replaced with a real directory.
        val_dir: Root of the validation ``ImageFolder`` tree (placeholder
            default, as above).
        epochs: Number of training epochs; also the cosine schedule length.
        batch_size: Batch size for both loaders.
        num_workers: Worker processes per loader.
        checkpoint_path: Where the best model's ``state_dict`` is saved.

    Raises:
        FileNotFoundError: If either dataset directory does not exist.
    """
    from pathlib import Path

    # Fail early with an actionable message instead of deep inside the
    # dataset constructor.
    for split_name, root in (("train", train_dir), ("val", val_dir)):
        if not Path(root).is_dir():
            raise FileNotFoundError(
                f"{split_name} dataset directory not found: {str(root)!r}; "
                f"pass the real location via main(train_dir=..., val_dir=...)")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load datasets with enhanced transforms
    train_dataset = datasets.ImageFolder(train_dir, train_transform)
    val_dataset = datasets.ImageFolder(val_dir, test_transform)

    train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    model = SwinLite(
        image_size=64,
        patch_size=4,
        num_classes=200,
        embed_dim=64,
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 8, 16],
        window_size=8,
        mlp_ratio=4.0,
        dropout=0.1,
        use_se=True
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)
    # One cosine half-period spans the whole run.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

    # Training loop; the checkpoint with the best validation accuracy is kept
    # (training always runs for all epochs -- there is no early stopping).
    best_acc = 0
    for epoch in range(epochs):
        model.train()
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Mixup augmentation: blend each sample with a randomly chosen
            # partner and blend the two losses with the same coefficient.
            alpha = 1.0
            lam = np.random.beta(alpha, alpha)
            index = torch.randperm(images.size(0)).to(device)

            mixed_images = lam * images + (1 - lam) * images[index]
            outputs = model(mixed_images)
            loss = lam * criterion(outputs, labels) + (1 - lam) * criterion(outputs, labels[index])

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                outputs = model(images.to(device))
                _, predicted = outputs.max(1)
                correct += predicted.eq(labels.to(device)).sum().item()
                total += labels.size(0)

        # Guard against an empty validation loader.
        acc = 100 * correct / total if total else 0.0
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), checkpoint_path)

        scheduler.step()
        print(f'Epoch {epoch+1}: Val Acc {acc:.2f}%')

if __name__ == '__main__':
    main()

# Notebook output: FileNotFoundError: [WinError 3] The system cannot find the path specified: 'path/to/train'