In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
import wandb
import os
import numpy as np
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
import torch.nn.utils.parametrize as parametrize
import math
import random
import copy

In [None]:
# Seed every RNG the notebook touches so reruns draw the same random numbers
SEED = 42

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Pin cuDNN to deterministic kernels and switch off its autotuner
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


In [None]:
# Helper functions
def exists(v):
    """Return True for any value other than None (falsy values such as 0 or '' still count)."""
    if v is None:
        return False
    return True

def default(v, d):
    """Hand back `v` unless it is None, in which case fall back to `d`."""
    if exists(v):
        return v
    return d

def pair(t):
    """Promote a non-tuple value to the 2-tuple `(t, t)`; tuples are returned unchanged."""
    if isinstance(t, tuple):
        return t
    return t, t

def divisible_by(numer, denom):
    """Report whether `numer` leaves no remainder when divided by `denom`."""
    remainder = numer % denom
    return remainder == 0

def l2norm(t, dim=-1):
    """Scale `t` to unit Euclidean (p=2) norm along `dim` using `F.normalize`."""
    unit = F.normalize(t, p=2, dim=dim)
    return unit

# For use with parametrize
class L2Norm(nn.Module):
    """Module form of `l2norm`, so the normalization can be registered as a parametrization."""

    def __init__(self, dim=-1):
        super().__init__()
        # Axis along which `forward` normalizes its input
        self.dim = dim

    def forward(self, t):
        """Return `t` l2-normalized along `self.dim`."""
        normalized = l2norm(t, dim=self.dim)
        return normalized

class NormLinear(nn.Module):
    """Bias-free linear layer whose weight is exposed through an `L2Norm` parametrization.

    Args:
        dim: number of input features.
        dim_out: number of output features.
        norm_dim_in: when True the weight is normalized along its last axis,
            otherwise along axis 0.
    """

    def __init__(self, dim, dim_out, norm_dim_in=True):
        super().__init__()
        self.linear = nn.Linear(dim, dim_out, bias=False)

        # Pick the weight axis to normalize, then attach the normalization to `linear.weight`
        norm_axis = -1 if norm_dim_in else 0
        parametrize.register_parametrization(self.linear, 'weight', L2Norm(dim=norm_axis))

    @property
    def weight(self):
        """The wrapped layer's weight, as seen through the parametrization."""
        return self.linear.weight

    def forward(self, x):
        """Apply the wrapped linear layer to `x`."""
        projected = self.linear(x)
        return projected

In [None]:
# Define MixUp Function
def mixup_data(x, y, alpha=0.2):
    """Blend every sample in a batch with a randomly chosen partner from the same batch.

    Args:
        x: input batch, indexed along its first axis.
        y: targets aligned with `x`.
        alpha: Beta(alpha, alpha) concentration for the mixing weight; a
            non-positive value disables mixing (weight fixed to 1).

    Returns:
        `(mixed_x, y_a, y_b, lam)`: the blended inputs, the original targets,
        the partners' targets and the mixing weight applied to `x`.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    # One random permutation decides which partner each sample is mixed with
    n_samples = x.size(0)
    perm = torch.randperm(n_samples).to(x.device)

    partner_x = x[perm, :]
    mixed_x = lam * x + (1 - lam) * partner_x
    return mixed_x, y, y[perm], lam

# Define MixUp Criterion
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Combine the loss against both MixUp target sets, weighted by `lam` and `1 - lam`."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

# Scaled dot product attention function
def scaled_dot_product_attention(q, k, v, dropout_p=0., training=True):
    """Softmax attention: `softmax(q @ k^T / sqrt(d)) @ v`, with `d = q.size(-1)`.

    Args:
        q, k, v: query, key and value tensors; the last two axes are used as
            (sequence, feature).
        dropout_p: dropout probability applied to the attention weights.
        training: dropout is applied only when this is truthy and `dropout_p > 0`.

    Returns:
        The attention-weighted combination of `v`.
    """
    head_dim = q.size(-1)
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(head_dim)
    probs = F.softmax(scores, dim=-1)

    use_dropout = training and dropout_p > 0.0
    if use_dropout:
        # F.dropout is left in its default (active) mode; the guard above is the only switch
        probs = F.dropout(probs, p=dropout_p)

    return probs @ v

In [None]:
# Attention and FeedForward classes
class Attention(nn.Module):
    """Multi-head self-attention with unit-norm queries/keys and learned per-head scales.

    Queries, keys and values come from `NormLinear` projections. Queries and keys are
    l2-normalized along the head dimension and multiplied by the learnable `q_scale` /
    `k_scale`, both initialized to `dim_head ** 0.25`.

    Args:
        dim: feature size of the input and output tokens.
        dim_head: feature size of each attention head.
        heads: number of attention heads.
        dropout: dropout probability applied to the attention weights while training.
    """

    def __init__(self, dim, *, dim_head=64, heads=8, dropout=0.):
        super().__init__()
        dim_inner = dim_head * heads
        self.to_q = NormLinear(dim, dim_inner)
        self.to_k = NormLinear(dim, dim_inner)
        self.to_v = NormLinear(dim, dim_inner)

        self.dropout = dropout

        # Each scale starts at dim_head ** 0.25, so their product starts at sqrt(dim_head)
        self.q_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))
        self.k_scale = nn.Parameter(torch.ones(heads, 1, dim_head) * (dim_head ** 0.25))

        self.split_heads = Rearrange('b n (h d) -> b h n d', h=heads)
        self.merge_heads = Rearrange('b h n d -> b n (h d)')

        self.to_out = NormLinear(dim_inner, dim, norm_dim_in=False)

    def forward(self, x):
        """Attend over the token axis of `x` (batch, tokens, dim) and return the same shape."""
        q, k, v = self.to_q(x), self.to_k(x), self.to_v(x)

        q, k, v = map(self.split_heads, (q, k, v))

        # Query/key l2 normalization along the head dimension
        q, k = map(l2norm, (q, k))

        q = q * self.q_scale
        k = k * self.k_scale

        # Softmax scale is 1 here. q_scale * k_scale already supplies the sqrt(dim_head)
        # temperature for the unit-norm q and k; dividing by sqrt(dim_head) again (as the
        # generic scaled-dot-product helper does) would cancel it and leave the logits as
        # plain cosine similarities in [-1, 1], i.e. nearly uniform attention.
        sim = torch.matmul(q, k.transpose(-2, -1))
        attn = F.softmax(sim, dim=-1)
        attn = F.dropout(attn, p=self.dropout, training=self.training)
        out = torch.matmul(attn, v)

        out = self.merge_heads(out)
        return self.to_out(out)

In [None]:
class FeedForward(nn.Module):
    """SiLU-gated feed-forward block built from `NormLinear` projections.

    Args:
        dim: feature size of the input and output.
        dim_inner: nominal hidden width; the width actually used is
            `int(dim_inner * 2 / 3)`.
        dropout: dropout probability applied to the gated hidden activations.
    """

    def __init__(self, dim, *, dim_inner, dropout=0.):
        super().__init__()
        # Shrink the hidden width to two thirds of the requested size
        dim_inner = int(dim_inner * 2 / 3)

        self.dim = dim
        self.dropout = nn.Dropout(dropout)

        # Value and gate branches, each followed by its own learnable per-feature scale
        self.to_hidden = NormLinear(dim, dim_inner)
        self.to_gate = NormLinear(dim, dim_inner)

        self.hidden_scale = nn.Parameter(torch.ones(dim_inner))
        self.gate_scale = nn.Parameter(torch.ones(dim_inner))

        self.to_out = NormLinear(dim_inner, dim, norm_dim_in=False)

    def forward(self, x):
        """Return `to_out(dropout(silu(gate) * hidden))` for input `x`."""
        hidden = self.to_hidden(x) * self.hidden_scale
        # The gate branch is additionally multiplied by sqrt(dim)
        gate = self.to_gate(x) * self.gate_scale * (self.dim ** 0.5)

        gated = F.silu(gate) * hidden
        return self.to_out(self.dropout(gated))

In [None]:
# ConvStem Module
class ConvStem(nn.Module):
    """Two-stage convolutional stem: (3x3 conv, stride 2) -> BatchNorm -> ReLU, twice.

    The first stage maps `in_channels` to `out_channels // 2`, the second maps that
    to `out_channels`. Each stride-2 stage roughly halves the height and width.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        mid_channels = out_channels // 2
        stages = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the image batch `x` through both stages and return the feature map."""
        features = self.conv(x)
        return features

In [None]:
# Updated nViT with ConvStem and ViT-like configurations
class nViT(nn.Module):
    """Normalized ViT classifier with a convolutional stem and a learnable class token.

    Images go through `ConvStem`, are flattened into a token sequence and projected by a
    `NormLinear`. A class token is prepended, positional embeddings are added, and the
    sequence passes through `depth` (Attention, FeedForward) pairs. Each sub-block output
    is l2-normalized and blended into the residual stream with a learned per-feature lerp
    weight, followed by another l2 normalization. The class token (position 0) feeds a
    LayerNorm + Linear head.

    Args:
        image_size: image side length, or a `(height, width)` tuple.
        patch_size: must divide both image sides. It only determines `num_patches`
            (and so the length of the positional-embedding table); the number of
            tokens actually produced depends on the conv stem's output resolution.
        num_classes: number of output logits.
        dim: token feature size.
        depth: number of (Attention, FeedForward) pairs.
        heads: attention heads per layer.
        mlp_dim: `dim_inner` passed to each `FeedForward`.
        dropout: dropout probability inside Attention and FeedForward.
        emb_dropout: dropout probability applied after adding positional embeddings.
        channels: number of image channels fed to the stem.
        dim_head: feature size per attention head.
        residual_lerp_scale_init: initial lerp weight; defaults to `1 / depth`.

    Raises:
        ValueError: in `forward`, when the stem yields more tokens than the
            positional-embedding table can cover.
    """

    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        dropout=0.,
        emb_dropout=0.,
        channels=3,
        dim_head=64,
        residual_lerp_scale_init=None
    ):
        super().__init__()
        image_height, image_width = pair(image_size)

        assert divisible_by(image_height, patch_size) and divisible_by(image_width, patch_size), 'Image dimensions must be divisible by the patch size.'

        patch_height_dim, patch_width_dim = (image_height // patch_size), (image_width // patch_size)
        num_patches = patch_height_dim * patch_width_dim

        self.channels = channels
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.image_size = image_size

        # ConvStem integration
        self.conv_stem = ConvStem(channels, dim)

        # Flatten the stem's (b, c, h, w) feature map into (b, h*w, c) tokens, then project
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c h w -> b (h w) c'),
            NormLinear(dim, dim, norm_dim_in=False),
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        # One row per patch plus one for the class token; forward() uses only the leading rows it needs
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.emb_dropout = nn.Dropout(emb_dropout)

        residual_lerp_scale_init = default(residual_lerp_scale_init, 1. / depth)

        self.dim = dim
        self.scale = dim ** 0.5

        self.layers = nn.ModuleList([])
        self.residual_lerp_scales = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, dim_head=dim_head, heads=heads, dropout=dropout),
                FeedForward(dim, dim_inner=mlp_dim, dropout=dropout),
            ]))

            # Stored divided by self.scale; forward() multiplies by self.scale again,
            # so the effective initial lerp weight is residual_lerp_scale_init
            self.residual_lerp_scales.append(nn.ParameterList([
                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
                nn.Parameter(torch.ones(dim) * residual_lerp_scale_init / self.scale),
            ]))

        # Classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        """Map an image batch `x` (batch, channels, height, width) to class logits."""
        x = self.conv_stem(x)
        x = self.to_patch_embedding(x)

        B, N, _ = x.shape

        # Fail with a clear message instead of an opaque broadcasting error further down
        max_tokens = self.pos_embedding.shape[1]
        if N + 1 > max_tokens:
            raise ValueError(
                f'conv stem produced {N} tokens (+1 class token) but the positional embedding '
                f'only covers {max_tokens} positions; use a smaller patch_size or a smaller input image'
            )

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embedding[:, :N + 1, :]
        x = self.emb_dropout(x)

        for (attn, ff), residual_scales in zip(self.layers, self.residual_lerp_scales):
            attn_alpha, ff_alpha = residual_scales

            # Normalize the branch output, interpolate towards it, then re-normalize
            attn_out = l2norm(attn(x))
            x = l2norm(x.lerp(attn_out, attn_alpha * self.scale))

            ff_out = l2norm(ff(x))
            x = l2norm(x.lerp(ff_out, ff_alpha * self.scale))

        # Classification token
        tokens = x[:, 0]
        logits = self.mlp_head(tokens)
        return logits

In [None]:
def _weight_decay_param_groups(model, weight_decay):
    """Split `model`'s parameters into AdamW groups with and without weight decay.

    A parameter is exempt from decay when its name ends in ``bias`` or when it belongs
    directly to a LayerNorm / BatchNorm / GroupNorm module. Normalization layers are
    detected by module type: inside containers their parameters are named like
    ``mlp_head.0.weight``, so matching on a ``'LayerNorm.weight'`` substring never fires.
    """
    norm_types = (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)
    norm_param_ids = {
        id(param)
        for module in model.modules()
        if isinstance(module, norm_types)
        for param in module.parameters(recurse=False)
    }

    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if name.endswith('bias') or id(param) in norm_param_ids:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    return [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]


def main():
    """Train nViT on CIFAR-100 and report the best accuracy reached.

    Pipeline: wandb run setup -> augmented CIFAR-100 loaders -> nViT -> AdamW with
    per-group weight decay -> linear warmup + cosine LR schedule stepped every batch ->
    MixUp training with label smoothing and gradient clipping -> per-epoch evaluation
    with early stopping. The best weights are written to 'best_nvit_cifar100.pth' and
    reloaded into the model before the run is closed.

    NOTE(review): early stopping and checkpoint selection are driven by the CIFAR-100
    test split (train=False), so the reported best accuracy is an optimistic estimate.
    """
    # Initialize wandb
    wandb.init(project='nvit-cifar100-v2-withConv', config={
        'model': 'nViT',
        'dataset': 'CIFAR-100',
        'epochs': 200,  # Increased to match ViT
        'batch_size': 128,
        'learning_rate': 3e-4,
        'weight_decay': 0.01,  # Updated to match ViT
        'image_size': 32,
        'patch_size': 1,  # Updated to match ViT
        'dim': 512,        # Updated to match ViT
        'depth': 8,
        'heads': 8,
        'mlp_dim': 512 * 4,  # Updated to match ViT
        'dropout': 0.1,
        'emb_dropout': 0.1,
        'num_classes': 100,
        'dim_head': 64,
        'mixup_alpha': 0.2,     # Added MixUp alpha
        'label_smoothing': 0.1  # Added label smoothing
    })
    config = wandb.config

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data transforms with advanced augmentations
    transform_train = transforms.Compose([
        transforms.RandomCrop(config.image_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])

    # Evaluation uses no augmentation, only the same normalization
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])

    # Load CIFAR-100 dataset
    train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # Initialize model
    model = nViT(
        image_size=config.image_size,
        patch_size=config.patch_size,
        num_classes=config.num_classes,
        dim=config.dim,
        depth=config.depth,
        heads=config.heads,
        mlp_dim=config.mlp_dim,
        dropout=config.dropout,
        emb_dropout=config.emb_dropout,
        channels=3,
        dim_head=config.dim_head
    ).to(device)

    # Define Loss Function with Label Smoothing
    criterion = nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)

    # Optimizer with parameter-wise weight decay (no decay for bias and norm parameters)
    optimizer_grouped_parameters = _weight_decay_param_groups(model, config.weight_decay)
    optimizer = optim.AdamW(optimizer_grouped_parameters, lr=config.learning_rate)

    # Learning rate scheduler with cosine annealing and warmup
    total_steps = config.epochs * len(train_loader)
    warmup_steps = int(0.1 * total_steps)  # 10% of total steps for warmup

    def lr_lambda(current_step):
        # Linear ramp from 0 to 1 during warmup, then half-cosine decay from 1 to 0
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return 0.5 * (1. + math.cos(math.pi * (current_step - warmup_steps) / (total_steps - warmup_steps)))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Training loop with Early Stopping
    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())
    patience = 20  # Number of epochs to wait for improvement
    trigger_times = 0

    for epoch in range(config.epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            # Apply MixUp
            inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, alpha=config.mixup_alpha)
            inputs, targets_a, targets_b = map(lambda x: x.to(device), (inputs, targets_a, targets_b))

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            # The schedule is defined in optimizer steps, so it advances once per batch
            scheduler.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            # Approximate correct predictions with MixUp
            correct += (lam * predicted.eq(targets_a).sum().item() + (1 - lam) * predicted.eq(targets_b).sum().item())

            if batch_idx % 100 == 0:
                wandb.log({
                    'train_loss': running_loss / (batch_idx + 1),
                    'train_acc': 100. * correct / total,
                    'learning_rate': optimizer.param_groups[0]['lr']
                })

        train_loss = running_loss / len(train_loader)
        train_acc = 100. * correct / total

        # Validation Phase
        model.eval()
        test_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        acc = 100. * correct / total
        avg_test_loss = test_loss / len(test_loader)
        wandb.log({
            'test_loss': avg_test_loss,
            'test_acc': acc,
            'epoch': epoch
        })

        # Early Stopping bookkeeping
        if acc > best_acc:
            best_acc = acc
            best_model_wts = copy.deepcopy(model.state_dict())
            trigger_times = 0
            # Save the best model
            torch.save(model.state_dict(), 'best_nvit_cifar100.pth')
        else:
            trigger_times += 1

        print(f"Epoch {epoch + 1}/{config.epochs} - "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {avg_test_loss:.4f}, Test Acc: {acc:.2f}%")

        # Log additional hyperparameters and metrics at the end of each epoch
        wandb.log({
            'epoch': epoch,
            'best_test_acc': best_acc
        })

        # Stop only after this epoch's summary has been printed and logged
        if trigger_times >= patience:
            print("Early stopping triggered!")
            break

    # Load best model weights
    model.load_state_dict(best_model_wts)
    print(f"Training completed. Best Test Accuracy: {best_acc:.2f}%")
    wandb.finish()

# Entry point: launch training when this module is the one being executed.
# NOTE(review): a notebook cell also runs with __name__ == '__main__', so executing
# this cell starts the full training run.
if __name__ == '__main__':
    main()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169M/169M [00:03<00:00, 42.9MB/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified
Epoch 1/200 - Train Loss: 4.5049, Train Acc: 3.02%, Test Loss: 4.1999, Test Acc: 7.30%
Epoch 2/200 - Train Loss: 4.1677, Train Acc: 8.03%, Test Loss: 3.8610, Test Acc: 13.30%
Epoch 3/200 - Train Loss: 3.9792, Train Acc: 11.77%, Test Loss: 3.6667, Test Acc: 17.00%
Epoch 4/200 - Train Loss: 3.8451, Train Acc: 14.54%, Test Loss: 3.4901, Test Acc: 20.69%
Epoch 5/200 - Train Loss: 3.7037, Train Acc: 17.67%, Test Loss: 3.3235, Test Acc: 24.80%
Epoch 6/200 - Train Loss: 3.6161, Train Acc: 19.91%, Test Loss: 3.1284, Test Acc: 29.16%
Epoch 7/200 - Train Loss: 3.4760, Train Acc: 23.23%, Test Loss: 3.0202, Test Acc: 32.62%
Epoch 8/200 - Train Loss: 3.3566, Train Acc: 25.96%, Test Loss: 2.9546, Test Acc: 33.42%
Epoch 9/200 - Train Loss: 3.2472, Train Acc: 28.80%, Test Loss: 2.8372, Test Acc: 37.15%
Epoch 10/200 - Train Loss: 3.1709, Train Acc: 31.15%, Test Loss: 2.6789, Test Acc: 40.89%
Epoch 11/200 - Train L

0,1
best_test_acc,▁▂▄▄▅▆▆▆▆▇▇▇▇▇██████████████████████████
epoch,▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▃▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██
learning_rate,▁▂▂▃▃▄▄▆▆▆████████████████████▇▇▇▇▇▇▇▇▇▇
test_acc,▁▂▂▃▃▄▅▅▆▆▇▇▇▇▇█████████████████████████
test_loss,█▇▆▅▅▄▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂
train_acc,▁▁▂▂▂▃▃▃▂▃▃▄▄▃▄▄▅▂▅▅▅▆▆▇▆▇▇▇█▇▇▇▇█▇▃▆▇▇▇
train_loss,█▇▇▆▆▆▆▆▅▅▄▃▄▄▄▂▂▃▆▃▃▆▃▁▂▃▂▂▂▁▂▂▃▂▂▂▆▂▂▂

0,1
best_test_acc,58.92
epoch,68.0
learning_rate,0.00025
test_acc,58.02
test_loss,2.43563
train_acc,83.91947
train_loss,1.47264
