In [1]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
import wandb
import numpy as np
import torch.nn.functional as F
from einops.layers.torch import Rearrange
import math
import random
import copy

####################
# Helper Functions #
####################

def set_seed(seed=42):
    """
    Seed every random number generator used by this script.

    The same ``seed`` is handed to PyTorch (CPU generator and all CUDA
    devices), NumPy and Python's ``random`` module. cuDNN is then put in
    deterministic mode with benchmarking switched off.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_rng in seeders:
        seed_rng(seed)

    # The two cuDNN flags are independent of each other.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def l2norm(t, dim=-1):
    """Return ``t`` scaled to unit L2 norm along ``dim`` using ``F.normalize``."""
    unit = F.normalize(t, p=2, dim=dim)
    return unit

####################
# Model Components #
####################

class NormLinear(nn.Module):
    """Bias-free linear projection followed by LayerNorm over the output features."""

    def __init__(self, dim, dim_out):
        super().__init__()
        # The attribute names are also the state_dict key prefixes.
        self.linear = nn.Linear(in_features=dim, out_features=dim_out, bias=False)
        self.norm = nn.LayerNorm(normalized_shape=dim_out)

    def forward(self, x):
        projected = self.linear(x)
        return self.norm(projected)

class Attention(nn.Module):
    """
    Multi-head self-attention with LayerNorm-ed projections.

    Queries, keys and values each come from their own ``NormLinear``
    projection of width ``dim_head * heads``. Queries are multiplied by
    ``dim_head ** -0.5``, the softmax-ed attention weights go through
    dropout, and the merged heads are mapped back to ``dim`` by a linear
    layer followed by LayerNorm.
    """

    def __init__(self, dim, dim_head=64, heads=8, dropout=0.):
        super().__init__()
        dim_inner = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        # Sub-modules are created in q, k, v, dropout, out order.
        self.to_q = NormLinear(dim, dim_inner)
        self.to_k = NormLinear(dim, dim_inner)
        self.to_v = NormLinear(dim, dim_inner)
        self.dropout = nn.Dropout(dropout)

        self.to_out = nn.Sequential(
            nn.Linear(dim_inner, dim),
            nn.LayerNorm(dim)
        )

    def _split_heads(self, t, batch, tokens):
        """Reshape ``(batch, tokens, heads * d)`` into ``(batch, heads, tokens, d)``."""
        return t.reshape(batch, tokens, self.heads, -1).permute(0, 2, 1, 3)

    def forward(self, x):
        batch, tokens, _ = x.shape

        q = self._split_heads(self.to_q(x), batch, tokens) * self.scale
        k = self._split_heads(self.to_k(x), batch, tokens)
        v = self._split_heads(self.to_v(x), batch, tokens)

        scores = q @ k.transpose(-2, -1)
        weights = self.dropout(scores.softmax(dim=-1))

        # Bring the head axis next to the feature axis, then flatten the two.
        merged = (weights @ v).transpose(1, 2).reshape(batch, tokens, -1)
        return self.to_out(merged)

class FeedForward(nn.Module):
    """
    Two-layer MLP: ``dim -> mlp_dim -> dim`` with GELU in between.

    Dropout with probability ``dropout`` is applied after the activation and
    again after the second linear layer.
    """

    def __init__(self, dim, mlp_dim, dropout=0.):
        super().__init__()
        # Positions in this list become the ``net.<index>`` state_dict keys.
        stages = [
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        transformed = self.net(x)
        return transformed

class nViT(nn.Module):
    """
    Patch-based vision transformer classifier.

    The image is cut into non-overlapping ``patch_size x patch_size`` patches,
    each patch is flattened and projected to ``dim`` by a ``NormLinear``, a
    learned positional embedding is added, ``depth`` attention/feed-forward
    pairs are applied, tokens are mean-pooled and a LayerNorm + linear head
    produces ``num_classes`` logits.

    Args:
        image_size: Side length of the (square) input image.
        patch_size: Side length of a patch; must divide ``image_size``.
        num_classes: Number of output logits.
        dim: Token embedding width.
        depth: Number of attention/feed-forward pairs.
        heads: Number of attention heads per layer.
        mlp_dim: Hidden width of each feed-forward block.
        dropout: Dropout probability passed to attention and feed-forward.
        channels: Number of input image channels (previously fixed at 3).
        dim_head: Width of each attention head (previously fixed at 64).

    Raises:
        AssertionError: If ``image_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,
                 dropout=0., channels=3, dim_head=64):
        super().__init__()
        assert image_size % patch_size == 0, "Image size must be divisible by patch size."
        num_patches = (image_size // patch_size) ** 2
        # Flattened length of one patch: every channel of a p x p window.
        patch_dim = channels * patch_size * patch_size
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (c p1 p2)', p1=patch_size, p2=patch_size),
            NormLinear(patch_dim, dim)
        )
        # One learned offset per patch position, broadcast over the batch.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, dim_head=dim_head, heads=heads, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout)
            ]))
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        """Map a ``(batch, channels, image_size, image_size)`` tensor to class logits."""
        x = self.to_patch_embedding(x) + self.pos_embedding
        for attn, ff in self.layers:
            # NOTE(review): both branches read the same input and their sum
            # replaces ``x`` -- there is no skip connection (``x + ...``) and
            # the feed-forward does not see the attention output. Confirm this
            # is intended; changing it would alter the experiment's results,
            # so the original behaviour is kept here.
            x = attn(x) + ff(x)
        # Mean-pool over the patch/token axis before classification.
        x = x.mean(dim=1)
        return self.mlp_head(x)

#####################
# Data Augmentations#
#####################

def get_transforms(image_size, augment_level):
    """
    Build the training and evaluation preprocessing pipelines.

    Both training variants start with a padded random crop to ``image_size``
    and a random horizontal flip. ``augment_level == 'advanced'`` adds colour
    jitter and a random rotation of up to 15 degrees; any other value yields
    the baseline pipeline. Every pipeline ends with tensor conversion and
    per-channel normalisation.

    Returns:
        A ``(train_transform, test_transform)`` tuple of ``transforms.Compose``.
    """
    # Per-channel statistics (presumably CIFAR-100's -- confirm if reused elsewhere).
    channel_mean = (0.5071, 0.4867, 0.4408)
    channel_std = (0.2675, 0.2565, 0.2761)

    def tensor_tail():
        # Fresh transform objects for each pipeline that needs this tail.
        return [
            transforms.ToTensor(),
            transforms.Normalize(channel_mean, channel_std),
        ]

    train_steps = [
        transforms.RandomCrop(image_size, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    if augment_level == 'advanced':
        train_steps.append(
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
        )
        train_steps.append(transforms.RandomRotation(15))
    train_steps.extend(tensor_tail())

    train_transform = transforms.Compose(train_steps)
    test_transform = transforms.Compose(tensor_tail())
    return train_transform, test_transform

###############################
# Experiment Runner Function  #
###############################

def _train_one_epoch(model, loader, criterion, optimizer, device, log_every=100):
    """Run one optimisation pass over ``loader``; return ``(mean_loss, accuracy_percent)``."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(loader):
        inputs, targets = inputs.to(device), targets.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient Clipping
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        # Running averages since the start of the epoch, logged periodically.
        if batch_idx % log_every == 0:
            wandb.log({
                'train_loss': running_loss / (batch_idx + 1),
                'train_acc': 100. * correct / total,
                'learning_rate': optimizer.param_groups[0]['lr']
            })

    return running_loss / len(loader), 100. * correct / total


def _evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without gradients; return ``(mean_loss, accuracy_percent)``."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, targets).item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    return total_loss / len(loader), 100. * correct / total


def run_experiment():
    """
    Runs the experiment with Advanced Augmentations without ConvStem and no MixUp.

    Trains an ``nViT`` on CIFAR-100 with AdamW, cosine learning-rate annealing
    and gradient clipping, logging to Weights & Biases. After every epoch the
    model is scored on the CIFAR-100 test split; the best-scoring weights are
    kept in memory and saved to ``best_nvit_<run_name>.pth``, and training
    stops early once ``patience`` consecutive epochs bring no improvement.

    The ``use_convstem`` and ``use_mixup`` config entries only label the run
    (name and logged config); this function contains no ConvStem or MixUp code.

    The wandb run is always closed: with exit code 0 on normal completion and
    a non-zero exit code if training raises or is interrupted.
    """
    # Initialize wandb
    config = {
        'epochs': 100,
        'batch_size': 128,
        'learning_rate': 3e-4,
        'weight_decay': 1e-4,
        'image_size': 32,
        'patch_size': 4,
        'dim': 384,
        'depth': 6,
        'heads': 6,
        'mlp_dim': 384*4,
        'dropout': 0.1,
        'num_classes': 100,
        'patience': 20,
        'augment_level': 'advanced',
        'use_convstem': False,
        'use_mixup': False
    }

    run_name = f"nViT_CIFAR100_convstem={config['use_convstem']}_mixup={config['use_mixup']}_aug={config['augment_level']}"
    wandb.init(project='nvit-cifar100-ablation', config=config, name=run_name)

    # Assume failure until the run reaches the end, so an exception or Ctrl-C
    # still closes the wandb run and marks it as unsuccessful.
    exit_code = 1
    try:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"\nRunning Experiment: {run_name} on device: {device}\n")

        # Set seeds for reproducibility
        set_seed(42)

        # Data augmentation
        train_transform, test_transform = get_transforms(config['image_size'], config['augment_level'])

        # Load CIFAR-100 dataset
        train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=train_transform)
        test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=test_transform)

        train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True,
                                  num_workers=4, pin_memory=True)
        test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False,
                                 num_workers=4, pin_memory=True)

        # Initialize model
        model = nViT(
            image_size=config['image_size'],
            patch_size=config['patch_size'],
            num_classes=config['num_classes'],
            dim=config['dim'],
            depth=config['depth'],
            heads=config['heads'],
            mlp_dim=config['mlp_dim'],
            dropout=config['dropout']
        ).to(device)

        # Define Loss Function and Optimizer
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=config['learning_rate'],
                                weight_decay=config['weight_decay'])
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config['epochs'])

        # Early Stopping Parameters
        best_acc = 0.0
        best_model_wts = copy.deepcopy(model.state_dict())
        patience = config['patience']
        trigger_times = 0

        # Training Loop
        for epoch in range(config['epochs']):
            train_loss, train_acc = _train_one_epoch(
                model, train_loader, criterion, optimizer, device)

            # Validation Phase (scored on the test split)
            avg_test_loss, acc = _evaluate(model, test_loader, criterion, device)
            wandb.log({
                'epoch': epoch,
                'test_loss': avg_test_loss,
                'test_acc': acc
            })

            print(f"Epoch {epoch + 1}/{config['epochs']} - "
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
                  f"Test Loss: {avg_test_loss:.4f}, Test Acc: {acc:.2f}%")

            # Early Stopping Logic
            if acc > best_acc:
                best_acc = acc
                best_model_wts = copy.deepcopy(model.state_dict())
                trigger_times = 0
                # Save the best model
                torch.save(model.state_dict(), f'best_nvit_{run_name}.pth')
            else:
                trigger_times += 1
                if trigger_times >= patience:
                    print("Early stopping triggered!")
                    break

            # Step the scheduler
            scheduler.step()

            # Log best accuracy so far
            wandb.log({'best_test_acc': best_acc})

        # Load Best Model Weights
        model.load_state_dict(best_model_wts)
        print(f"Training completed. Best Test Accuracy: {best_acc:.2f}%")
        exit_code = 0
    finally:
        wandb.finish(exit_code=exit_code)

####################
# Main Function    #
####################

# Start the experiment only when this file is executed directly, not when it
# is imported as a module.
if __name__ == '__main__':
    run_experiment()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc



Running Experiment: nViT_CIFAR100_convstem=False_mixup=False_aug=advanced on device: cuda

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169M/169M [00:13<00:00, 12.2MB/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified
Epoch 1/100 - Train Loss: 4.1983, Train Acc: 5.37%, Test Loss: 4.1039, Test Acc: 7.06%


KeyboardInterrupt: 