In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
import torchvision.datasets as datasets

sys.path.append("..")
from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders

#############################################################################
# 1. TinyImageNet Dataloader (If you have your own loader, use that instead)
#############################################################################

# This notebook imports get_tinyimagenet_dataloaders from
# Utils/TinyImageNet_loader.py. If you don't have that helper, you can build
# equivalent loaders with torchvision's ImageFolder, assuming your data is
# structured like this:
#
# data_dir/
#   train/
#       class_1/
#           img1.jpeg
#           ...
#       class_2/
#           img1.jpeg
#           ...
#   val/
#       class_1/
#           img1.jpeg
#           ...
#       ...
#   test/   (if you have a separate test folder)
#       class_1/
#           ...
#
# If your folder structure is different, adjust accordingly.
#
# For the official TinyImageNet release, the val folder must first be
# reorganized by class: move each val image into a subfolder named after its
# label (the image-to-label mapping is listed in val/val_annotations.txt).
# Some redistributed copies already ship with this per-class layout.
#############################################################################





#############################################################################
# 2. Transforms
#############################################################################

image_size = 224

# ImageNet channel statistics used to normalize every split.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

tiny_transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    # scale=(0.08, 1.0) is torchvision's default area range.
    # NOTE(review): TinyImageNet images are presumably much smaller than
    # image_size (64x64 in the official release -- confirm for your copy),
    # so small crops get heavily upsampled.
    transforms.RandomResizedCrop(image_size, scale=(0.08, 1.0)),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# Evaluation transforms are deterministic: resize, convert, normalize.
tiny_transform_val = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

tiny_transform_test = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

#############################################################################
# 3. Loaders
#############################################################################

# Build the loaders exactly once; main() reads these module-level names.
data_dir = "../datasets"  # Adjust this path to your TinyImageNet folder
batch_size = 64
train_loader, val_loader, test_loader = get_tinyimagenet_dataloaders(
    data_dir=data_dir,
    transform_train=tiny_transform_train,
    transform_val=tiny_transform_val,
    transform_test=tiny_transform_test,
    batch_size=batch_size,
    image_size=image_size,
)

#############################################################################
# 4. Squeeze-and-Excitation (SE) Block
#############################################################################

class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation gate over the channel axis of a token sequence.

    The (B, N, C) input is averaged over its N tokens ("squeeze"). The
    average goes through a bottleneck C -> C // reduction -> C with ReLU and
    a sigmoid ("excitation"). The resulting per-channel gate in (0, 1)
    rescales every token.
    """
    def __init__(self, dim, reduction=4):
        super().__init__()
        bottleneck = dim // reduction
        self.fc1 = nn.Linear(dim, bottleneck)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(bottleneck, dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Unpacking enforces a 3-D (B, N, C) input.
        _batch, _tokens, _channels = x.shape
        pooled = x.mean(dim=1)                                   # (B, C)
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(pooled))))
        # (B, 1, C) broadcasts the gate across all tokens.
        return x * gate.unsqueeze(1)

#############################################################################
# 5. Depthwise-Separable MLP
#############################################################################

class DSConvMLP(nn.Module):
    """
    Token MLP with a depthwise 1-D convolution between its two projections,
    optionally followed by a Squeeze-and-Excitation gate.

    For an input of shape (B, N, C), with hidden = int(C * mlp_ratio):
        Linear C->hidden -> GELU -> Dropout
        -> depthwise Conv1d along N (kernel 3, length-preserving) -> GELU -> Dropout
        -> Linear hidden->C -> Dropout -> [SEBlock if use_se]

    NOTE(review): inside SwinLiteBlock the N axis is a row-major flattening of
    an H x W grid. This 1-D conv therefore also treats the last token of one
    row and the first token of the next row as neighbours.
    """
    def __init__(self, dim, mlp_ratio=4.0, dropout=0.0, use_se=True):
        super().__init__()
        width = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, width)

        # groups == channels makes the convolution depthwise; kernel_size=3
        # with padding=1 keeps the token count unchanged.
        self.depthwise = nn.Conv1d(
            in_channels=width, out_channels=width,
            kernel_size=3, padding=1, groups=width,
        )

        self.act = nn.GELU()
        self.fc2 = nn.Linear(width, dim)
        self.drop = nn.Dropout(dropout)
        self.use_se = use_se
        if use_se:
            self.se = SEBlock(dim)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        # Unpacking enforces a 3-D (B, N, hidden) activation.
        _batch, _tokens, _width = hidden.shape

        # Conv1d wants channels second: go to (B, hidden, N) and come back.
        mixed = self.depthwise(hidden.transpose(1, 2)).transpose(1, 2)
        mixed = self.drop(self.act(mixed))

        out = self.drop(self.fc2(mixed))
        return self.se(out) if self.use_se else out

#############################################################################
# 6. Window Partition / Reverse
#############################################################################

def window_partition(x, window_size):
    """
    Split a channels-last feature map into non-overlapping square windows.

    Args:
        x: tensor of shape (B, H, W, C); H and W must be multiples of
            window_size (otherwise the initial view fails).
        window_size: side length of each window.

    Returns:
        Tensor of shape (B * (H // ws) * (W // ws), ws, ws, C). Windows are
        ordered batch-major, then by window row, then by window column.
    """
    batch, height, width, channels = x.shape
    ws = window_size
    grid = x.view(batch, height // ws, ws, width // ws, ws, channels)
    # Put the two window-grid axes side by side, then fold them into the batch axis.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(-1, ws, ws, channels)

def window_reverse(windows, window_size, H, W):
    """
    Reassemble windows produced by window_partition into a feature map.

    Args:
        windows: tensor of shape (num_windows * B, window_size, window_size, C),
            ordered batch-major, then by window row, then by window column.
        window_size: side length of each window.
        H, W: spatial size of the reconstructed map; both are expected to be
            multiples of window_size.

    Returns:
        Tensor of shape (B, H, W, C).
    """
    windows_per_image = (H // window_size) * (W // window_size)
    # Integer arithmetic recovers the batch size exactly, without the float
    # round-trip of dividing by H * W / window_size / window_size.
    B_ = windows.shape[0] // windows_per_image
    x = windows.view(
        B_,
        H // window_size,
        W // window_size,
        window_size,
        window_size,
        -1
    )
    # Undo the partition's axis swap: (B, nH, ws, nW, ws, C) -> (B, H, W, C).
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B_, H, W, -1)
    return x

#############################################################################
# 7. WindowAttention
#############################################################################

class WindowAttention(nn.Module):
    """
    A simplified window-based multi-head self-attention (no relative position bias).

    Args:
        dim: channel dimension C of the input tokens.
        window_size: side length of the attention windows (stored only; the
            caller partitions tokens into windows).
        num_heads: number of attention heads; C is split evenly among them.
    """
    def __init__(self, dim, window_size, num_heads):
        super().__init__()
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        """
        Attend within each window independently.

        Args:
            x: (num_windows*B, N, C) tokens, with N = window_size * window_size.
            mask: optional additive attention mask of shape (num_windows, N, N),
                shared across batch and heads. Windows in ``x`` must be ordered
                batch-major (as window_partition produces them), so window
                ``b * num_windows + w`` uses ``mask[w]``. Large negative entries
                block attention between the corresponding token pair.

        Returns:
            (num_windows*B, N, C) attended tokens.
        """
        B_, N, C = x.shape
        # Project QKV and split heads
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B_, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B_, heads, N, N)
        if mask is not None:
            num_windows = mask.shape[0]
            attn = attn.view(B_ // num_windows, num_windows, self.num_heads, N, N)
            # (1, nW, 1, N, N) broadcasts over the batch and head axes.
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(B_, self.num_heads, N, N)
        attn = attn.softmax(dim=-1)

        out = attn @ v  # (B_, heads, N, head_dim)
        out = out.transpose(1, 2).reshape(B_, N, C)
        out = self.proj(out)
        return out

#############################################################################
# 8. SwinLite Block
#############################################################################

class SwinLiteBlock(nn.Module):
    """
    One Swin-like block: (shifted) window attention followed by a DSConvMLP,
    each wrapped in a pre-LayerNorm residual connection.

    When ``shift_size > 0`` the feature map is cyclically rolled before window
    partitioning. The roll puts spatially non-adjacent (wrapped-around) tokens
    into the same window. An additive attention mask, stored in the
    non-persistent buffer ``attn_mask``, stops those tokens from attending to
    each other, as in the Swin Transformer's shifted-window scheme.

    Args:
        dim: channel dimension C.
        input_resolution: (H, W) of the token grid; H and W are expected to be
            multiples of window_size.
        num_heads: attention heads.
        window_size: attention window side length.
        shift_size: cyclic shift; forced to 0 when min(H, W) <= window_size.
        mlp_ratio, dropout, use_se: forwarded to DSConvMLP.
    """
    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        dropout=0.0,
        use_se=True
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size

        if min(input_resolution) > window_size:
            self.shift_size = shift_size
        else:
            # Shifting is disabled when the smaller side fits in a single window.
            self.shift_size = 0

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, window_size=self.window_size, num_heads=num_heads)

        self.norm2 = nn.LayerNorm(dim)
        self.mlp = DSConvMLP(dim=dim, mlp_ratio=mlp_ratio, dropout=dropout, use_se=use_se)

        # The mask depends only on the constructor arguments, so it is kept
        # out of state_dict (persistent=False) but still follows .to(device).
        attn_mask = self._build_shift_mask() if self.shift_size > 0 else None
        self.register_buffer("attn_mask", attn_mask, persistent=False)

    def _build_shift_mask(self):
        """Return the (num_windows, ws*ws, ws*ws) additive mask for shifted windows."""
        H, W = self.input_resolution
        ws, ss = self.window_size, self.shift_size

        # Label each spatial region that stays contiguous after rolling by -ss.
        img_mask = torch.zeros((1, H, W, 1))
        region_slices = (slice(0, -ws), slice(-ws, -ss), slice(-ss, None))
        region_id = 0
        for h_slice in region_slices:
            for w_slice in region_slices:
                img_mask[:, h_slice, w_slice, :] = region_id
                region_id += 1

        mask_windows = window_partition(img_mask, ws).view(-1, ws * ws)
        # Token pairs with different region labels must not attend to each other.
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0))
        attn_mask = attn_mask.masked_fill(attn_mask == 0, float(0.0))
        return attn_mask

    def forward(self, x):
        """
        x shape: (B, H*W, C).
        We'll reshape -> (B, H, W, C), do window attn, then reshape back.
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, f"Input feature has wrong size {L} != {H*W}"

        # (1) Norm + Reshape
        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # (2) Shift if needed
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # (3) Window partition
        x_windows = window_partition(shifted_x, self.window_size)  # (nW*B, wsize, wsize, C)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # (nW*B, wsize^2, C)

        # (4) Window Attention (masked when shifted; attn_mask is None otherwise)
        attn_windows = self.attn(x_windows, mask=self.attn_mask)

        # (5) Merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)

        # (6) Reverse shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x

        # (7) Flatten back
        x = x.view(B, H * W, C)
        x = shortcut + x  # skip

        # (8) MLP
        shortcut2 = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = shortcut2 + x

        return x

#############################################################################
# 9. SwinLite Model (adjusted for better capacity)
#############################################################################

class SwinLite(nn.Module):
    """
    A compact Swin-style hierarchical vision transformer for Tiny ImageNet.

    Pipeline:
    1. A strided Conv2d embeds non-overlapping patches.
    2. ``len(depths)`` stages of SwinLiteBlocks process the tokens.
    3. Between consecutive stages, a LayerNorm + stride-2 Conv2d halves the
       spatial resolution and doubles the channel count.
    4. The mean of the final (normalized) tokens feeds a linear classifier.

    Args:
        image_size: side length of the square input image.
        patch_size: patch side length (kernel and stride of the embedding conv).
        in_chans: number of input image channels.
        num_classes: number of output logits.
        embed_dim: channel count of the first stage.
        depths: number of blocks per stage.
        num_heads: attention heads per stage, indexed like ``depths``.
        window_size: attention window side length used in every stage.
        mlp_ratio: hidden-width multiplier of each block's MLP.
        dropout: dropout after patch embedding and inside the MLPs.
        use_se: whether each MLP ends with a Squeeze-and-Excitation gate.

    NOTE(review): each stage's resolution (image_size // patch_size, halved
    per stage) must be divisible by window_size for window partitioning.
    """
    def __init__(
        self,
        image_size=224,
        patch_size=4,
        in_chans=3,
        num_classes=200,
        embed_dim=96,
        depths=(2, 2, 6, 2),          # tuples: immutable defaults shared across calls
        num_heads=(3, 6, 12, 24),     # Swin-T-like head ratios
        window_size=7,
        mlp_ratio=4.0,
        dropout=0.0,
        use_se=True
    ):
        super().__init__()

        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_classes = num_classes

        # Patch embedding
        self.patch_embed = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

        self.pos_drop = nn.Dropout(p=dropout)

        # resolution after patch embedding
        patches_resolution = (
            image_size // patch_size,
            image_size // patch_size
        )

        self.layers = nn.ModuleList()
        dim = embed_dim

        for i in range(len(depths)):
            stage = self._make_stage(
                dim=dim,
                input_resolution=patches_resolution,
                depth=depths[i],
                num_heads=num_heads[i],
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                dropout=dropout,
                use_se=use_se,
                downsample=(i < len(depths) - 1)
            )
            self.layers.append(stage)

            # If downsampling, the resolution is halved, channels doubled
            if i < len(depths) - 1:
                patches_resolution = (
                    patches_resolution[0] // 2,
                    patches_resolution[1] // 2
                )
                dim *= 2

        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Xavier-init Linear layers and Kaiming-init Conv2d layers; zero their biases."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_stage(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio,
        dropout,
        use_se,
        downsample
    ):
        """
        Build one stage as a ModuleDict with keys "blocks" and "downsample".

        Blocks alternate between regular (even index) and shifted (odd index)
        windows. "downsample" is None for the last stage. Otherwise it holds
        a LayerNorm ("ln") and a stride-2 Conv2d ("conv") mapping dim -> 2*dim.
        """
        blocks = []
        for d in range(depth):
            shift_size = 0 if (d % 2 == 0) else window_size // 2
            blocks.append(
                SwinLiteBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=shift_size,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout,
                    use_se=use_se
                )
            )

        down = None
        if downsample:
            down = nn.ModuleDict({
                "ln": nn.LayerNorm(dim),
                "conv": nn.Conv2d(dim, dim * 2, kernel_size=2, stride=2)
            })

        return nn.ModuleDict({
            "blocks": nn.ModuleList(blocks),
            "downsample": down
        })

    def forward(self, x):
        """
        Map a (B, in_chans, image_size, image_size) batch to (B, num_classes) logits.
        """
        x = self.patch_embed(x)  # => (B, embed_dim, H//patch, W//patch)
        B, C, H_, W_ = x.shape

        # Flatten + transpose => (B, H'*W', embed_dim)
        x = x.view(B, C, H_ * W_).transpose(1, 2)
        x = self.pos_drop(x)

        curr_resolution = (H_, W_)
        for stage in self.layers:
            for blk in stage["blocks"]:
                x = blk(x)

            if stage["downsample"] is not None:
                B_, N_, C_ = x.shape
                h_, w_ = curr_resolution

                x = stage["downsample"]["ln"](x)

                # Tokens -> (B, C, h, w) so the conv can merge 2x2 neighbourhoods.
                x = x.view(B_, h_, w_, C_).permute(0, 3, 1, 2)

                # Conv2d => (B, 2*C, h//2, w//2)
                x = stage["downsample"]["conv"](x)

                # Back to tokens => (B, (h//2)*(w//2), 2*C)
                h_, w_ = h_ // 2, w_ // 2
                x = x.permute(0, 2, 3, 1).flatten(1, 2)

                curr_resolution = (h_, w_)

        # Final norm + classification
        x = self.norm(x)   # (B, N, final_dim)
        x = x.mean(dim=1)  # => (B, final_dim)
        x = self.head(x)   # => (B, num_classes)
        return x

#############################################################################
# 10. Training / Validation / Test Routines
#############################################################################

def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """
    Run one optimization pass of ``model`` over ``dataloader``.

    Each (inputs, targets) batch is moved to ``device``, then the usual
    zero_grad / forward / backward / step sequence runs on it.

    Returns:
        (sample-weighted mean loss, top-1 accuracy in percent) for the epoch.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # criterion returns a per-batch mean; weight it by batch size.
        loss_sum += batch_loss.item() * inputs.size(0)
        preds = logits.max(1).indices
        n_correct += preds.eq(targets).sum().item()
        n_seen += targets.size(0)

    return loss_sum / n_seen, 100.0 * n_correct / n_seen


def validate(model, dataloader, criterion, device):
    """
    Evaluate ``model`` on ``dataloader`` in eval mode without gradient tracking.

    Returns:
        (sample-weighted mean loss, top-1 accuracy in percent).
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            loss_sum += batch_loss.item() * inputs.size(0)
            preds = logits.max(1).indices
            n_correct += preds.eq(targets).sum().item()
            n_seen += targets.size(0)

    mean_loss = loss_sum / n_seen
    accuracy = 100.0 * n_correct / n_seen
    return mean_loss, accuracy


def test(model, dataloader, device):
    """
    Return the top-1 accuracy (percent) of ``model`` on ``dataloader``,
    computed in eval mode without gradient tracking.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            guesses = model(batch_x).max(1).indices
            hits += guesses.eq(batch_y).sum().item()
            seen += batch_y.size(0)
    return 100.0 * hits / seen

#############################################################################
# 11. Main
#############################################################################

def main():
    """
    Train SwinLite on TinyImageNet, then evaluate the best checkpoint on the test split.

    Uses the module-level ``image_size``, ``train_loader``, ``val_loader`` and
    ``test_loader``. The weights with the highest validation accuracy are
    written to ``checkpoint_path`` and reloaded for the final test pass.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Hyperparameters
    epochs = 100  # You can set more epochs if you have the resources (e.g., 200)
    lr = 1e-3
    weight_decay = 0.05
    # Single path for both saving and reloading, so the test phase evaluates
    # exactly the checkpoint that training wrote.
    checkpoint_path = "best_swinlite_task_2.pth"

    # Build a slightly larger SwinLite
    model = SwinLite(
        image_size=image_size,
        patch_size=4,
        in_chans=3,
        num_classes=200,
        embed_dim=96,                # bigger embed
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],    # typical for Swin-T
        window_size=7,
        mlp_ratio=4.0,
        dropout=0.0,
        use_se=True
    ).to(device)

    print(f"Number of parameters: {sum(p.numel() for p in model.parameters())}")

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Cosine Annealing LR Scheduler
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    best_val_acc = 0.0
    checkpoint_saved = False

    for epoch in range(epochs):
        print(f"===== EPOCH {epoch+1} / {epochs} =====")

        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        # Step the scheduler after each epoch
        scheduler.step()

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%  "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%\n")

    print(f"Best validation accuracy: {best_val_acc:.2f}%")

    # Evaluate on test set if available
    if test_loader is not None:
        if checkpoint_saved:
            # map_location keeps the load working regardless of the device
            # the tensors were saved from.
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        else:
            # No epoch exceeded 0% validation accuracy in this run.
            print(f"No checkpoint saved to {checkpoint_path}; testing the final weights.")
        test_acc = test(model, test_loader, device)
        print(f"Test Accuracy: {test_acc:.2f}%")


if __name__ == "__main__":
    main()


Using device: cpu
Number of parameters: 28801400
===== EPOCH 1 / 100 =====


KeyboardInterrupt: 