In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math
import copy
import random
import numpy as np
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import wandb
from PIL import Image
from torchvision.transforms import RandAugment

In [13]:
# Define MixUp Function
def mixup_data(x, y, alpha=0.2):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# Define CutMix Function
def cutmix_data(x, y, alpha=1.0):
    '''Returns cutmixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size, C, H, W = x.size()
    index = torch.randperm(batch_size).to(x.device)

    # Bounding box coordinates
    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int64(W * cut_rat)
    cut_h = np.int64(H * cut_rat)

    # Uniformly sample the center of the bbox
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    # Apply CutMix
    x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
    # Adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (C * H * W))
    y_a, y_b = y, y[index]
    return x, y_a, y_b, lam

# Define MixUp and CutMix Criterion
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

In [3]:
class Attention(nn.Module):
    def __init__(self, dim, *, dim_head=64, heads=8, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.to_qkv(x)
        qkv = qkv.reshape(B, N, 3, self.heads, self.dim_head)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, dim_head)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = torch.matmul(q, k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).reshape(B, N, -1)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

class FeedForward(nn.Module):
    def __init__(self, dim, dim_inner, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim_inner),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_inner, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

In [4]:
def pair(t):
    return t if isinstance(t, tuple) else (t, t)

class StochasticDepth(nn.Module):
    def __init__(self, drop_prob):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if not self.training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        # Shape [batch, 1, 1]
        mask = torch.rand(x.shape[0], 1, 1, device=x.device) < keep_prob
        return x / keep_prob * mask

class ConvStem(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels // 2, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

In [14]:
class ViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        dropout=0.2,          # Increased dropout
        emb_dropout=0.2,      # Increased emb_dropout
        channels=3,
        dim_head=64,
        stochastic_depth_rate=0.1  # Added stochastic depth rate
    ):
        super().__init__()
        self.conv_stem = ConvStem(channels, dim)
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        # Updated patch_dim to use with convstem
        patch_dim = dim * patch_height * patch_width

        self.patch_size = patch_size
        self.dim = dim

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h ph) (w pw) -> b (h w) (ph pw c)', ph=patch_height, pw=patch_width),
            nn.Linear(patch_dim, dim)
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        # Transformer layers with Stochastic Depth
        self.transformer = nn.ModuleList([])
        for i in range(depth):
            sd_rate = stochastic_depth_rate * float(i) / depth  # Linearly increase drop rate
            self.transformer.append(nn.ModuleList([
                nn.LayerNorm(dim),
                Attention(dim, dim_head=dim_head, heads=heads, dropout=dropout),
                nn.Dropout(dropout),
                StochasticDepth(sd_rate),
                nn.LayerNorm(dim),
                FeedForward(dim, dim_inner=mlp_dim, dropout=dropout),
                nn.Dropout(dropout),
                StochasticDepth(sd_rate)
            ]))

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.conv_stem(img)
        x = self.to_patch_embedding(x)
        B, N, _ = x.shape

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embedding[:, :N + 1, :]
        x = self.dropout(x)

        for layer in self.transformer:
            norm1, attn, drop1, sd1, norm2, ff, drop2, sd2 = layer
            # Attention Block
            x_res = attn(norm1(x))
            x_res = drop1(x_res)
            x_res = sd1(x_res)
            x = x + x_res

            # FeedForward Block
            x_ff = ff(norm2(x))
            x_ff = drop2(x_ff)
            x_ff = sd2(x_ff)
            x = x + x_ff

        x = x[:, 0]
        x = self.mlp_head(x)
        return x

In [20]:
def main():
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    np.random.seed(42)
    random.seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Initialize wandb
    wandb.init(project='vit-cifar100', config={
        'model': 'ViT',
        'dataset': 'CIFAR-100',
        'epochs': 200,
        'batch_size': 128,
        'learning_rate': 3e-4,
        'weight_decay': 0.01,
        'image_size': 32,
        'patch_size': 2,
        'dim': 768,
        'depth': 8,                  # Increased depth
        'heads': 8,
        'mlp_dim': 768 * 4,
        'dropout': 0.2,               # Increased dropout
        'emb_dropout': 0.2,           # Increased emb_dropout
        'num_classes': 100,
        'mixup_alpha': 0.2,
        'cutmix_alpha': 1.0,          # Added CutMix alpha
        'label_smoothing': 0.1,
        'stochastic_depth_rate': 0.1   # Added stochastic depth
    })
    config = wandb.config

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data transforms for CIFAR-100 with additional augmentations
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        RandAugment(num_ops=2, magnitude=10),  # Added RandAugment
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])

    # Load CIFAR-100 dataset
    train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # Initialize model
    model = ViT(
        image_size=config.image_size,
        patch_size=config.patch_size,
        num_classes=config.num_classes,
        dim=config.dim,
        depth=config.depth,
        heads=config.heads,
        mlp_dim=config.mlp_dim,
        dropout=config.dropout,
        emb_dropout=config.emb_dropout,
        channels=3,
        dim_head=64,
        stochastic_depth_rate=config.stochastic_depth_rate
    ).to(device)

    # Define Loss Function with Label Smoothing
    criterion = nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)

    # Optimizer with parameter-wise weight decay (no decay for bias and norm parameters)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            'weight_decay': config.weight_decay
        },
        {
            'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }
    ]
    optimizer = optim.AdamW(optimizer_grouped_parameters, lr=config.learning_rate)

    # Learning rate scheduler with OneCycleLR
    #scheduler = optim.lr_scheduler.OneCycleLR(
    #    optimizer,
    #    max_lr=config.learning_rate,
    #    steps_per_epoch=len(train_loader),
    #    epochs=config.epochs,
    #    pct_start=0.1,
    #    anneal_strategy='cos',
    #    cycle_momentum=True,
    #    base_momentum=0.85,
    #    max_momentum=0.95,
    #    div_factor=25.0,
    #    final_div_factor=1e4
    #)

    # Optimizer with parameter-wise weight decay (no decay for bias and norm parameters)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            'weight_decay': config.weight_decay
        },
        {
            'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }
    ]
    optimizer = optim.AdamW(optimizer_grouped_parameters, lr=config.learning_rate)

    # Learning rate scheduler with cosine annealing and warmup
    total_steps = config.epochs * len(train_loader)
    warmup_steps = int(0.1 * total_steps)  # 10% of total steps for warmup

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return 0.5 * (1. + math.cos(math.pi * (current_step - warmup_steps) / (total_steps - warmup_steps)))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


    # Training loop with Early Stopping
    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())
    patience = 30  # Increased patience for more epochs
    trigger_times = 0

    for epoch in range(config.epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # Decide whether to apply MixUp or CutMix
            r = np.random.rand(1)
            if r < 0.5:
                # Apply MixUp
                inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, alpha=config.mixup_alpha)
                outputs = model(inputs)  # Store outputs
                loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            else:
                # Apply CutMix
                inputs, targets_a, targets_b, lam = cutmix_data(inputs, targets, alpha=config.cutmix_alpha)
                outputs = model(inputs)  # Store outputs
                loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)

            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            scheduler.step()

            running_loss += loss.item()

            # Calculate accuracy
            _, predicted = outputs.max(1)
            total += targets.size(0)
            if r < 0.5:
                # MixUp accuracy approximation
                correct += (lam * predicted.eq(targets_a).sum().item() + (1 - lam) * predicted.eq(targets_b).sum().item())
            else:
                # CutMix accuracy approximation
                correct += (lam * predicted.eq(targets_a).sum().item() + (1 - lam) * predicted.eq(targets_b).sum().item())

            if batch_idx % 100 == 0:
                wandb.log({
                    'train_loss': running_loss / (batch_idx + 1),
                    'train_acc': 100. * correct / total,
                    'learning_rate': scheduler.get_last_lr()[0]
                })

        train_loss = running_loss / len(train_loader)
        train_acc = 100. * correct / total

        # Validation Phase
        model.eval()
        test_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        acc = 100. * correct / total
        avg_test_loss = test_loss / len(test_loader)
        wandb.log({
            'test_loss': avg_test_loss,
            'test_acc': acc,
            'epoch': epoch
        })

        # Early Stopping Check
        if acc > best_acc:
            best_acc = acc
            best_model_wts = copy.deepcopy(model.state_dict())
            trigger_times = 0
            # Save the best model
            torch.save(model.state_dict(), 'best_cifar100_vit.pth')
        else:
            trigger_times += 1
            if trigger_times >= patience:
                print("Early stopping triggered!")
                break

        print(f"Epoch {epoch + 1}/{config.epochs} - "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {avg_test_loss:.4f}, Test Acc: {acc:.2f}%")

        # Log additional hyperparameters and metrics at the end of each epoch
        wandb.log({
            'epoch': epoch,
            'best_test_acc': best_acc
        })

    # Load best model weights
    model.load_state_dict(best_model_wts)
    print(f"Training completed. Best Test Accuracy: {best_acc:.2f}%")
    wandb.finish()

if __name__ == '__main__':
    main()


0,1
best_test_acc,▁▂▃▄▅▅▆▇▇████████████
epoch,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
learning_rate,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇███
test_acc,▁▂▃▄▅▅▆▇▇█████▇▇▇▇▆▇▇
test_loss,█▇▅▄▄▃▂▂▁▁▁▁▁▁▁▁▁▂▂▂▂
train_acc,▁▁▁▁▃▂▂▃▃▃▃▂▃▄▄▅▆▆▅▅▆█▆▆▆▃▆▆▆▆▆▅▅▆▆▅▆▆▅▅
train_loss,█▇▆▄▄▄▄▄▃▃▂▂▂▁▂▂▃▂▂▂▂▂▁▁▁▂▁▂▂▂▂▂▂▂▂▂▂▃▂▂

0,1
best_test_acc,30.02
epoch,20.0
learning_rate,0.0005
test_acc,25.15
test_loss,3.32945
train_acc,13.06596
train_loss,3.98994


Files already downloaded and verified
Files already downloaded and verified
Epoch 1/200 - Train Loss: 4.6408, Train Acc: 1.52%, Test Loss: 4.3702, Test Acc: 3.37%
Epoch 2/200 - Train Loss: 4.4767, Train Acc: 3.19%, Test Loss: 4.2319, Test Acc: 6.03%
Epoch 3/200 - Train Loss: 4.3574, Train Acc: 4.96%, Test Loss: 3.9533, Test Acc: 11.14%
Epoch 4/200 - Train Loss: 4.2809, Train Acc: 6.55%, Test Loss: 3.8191, Test Acc: 13.86%
Epoch 5/200 - Train Loss: 4.2154, Train Acc: 8.23%, Test Loss: 3.6639, Test Acc: 17.28%
Epoch 6/200 - Train Loss: 4.1606, Train Acc: 9.25%, Test Loss: 3.6180, Test Acc: 17.94%
Epoch 7/200 - Train Loss: 4.0677, Train Acc: 11.01%, Test Loss: 3.4404, Test Acc: 23.42%
Epoch 8/200 - Train Loss: 4.0113, Train Acc: 12.24%, Test Loss: 3.3382, Test Acc: 25.06%
Epoch 9/200 - Train Loss: 3.9895, Train Acc: 12.98%, Test Loss: 3.2585, Test Acc: 26.68%
Epoch 10/200 - Train Loss: 3.9873, Train Acc: 13.20%, Test Loss: 3.1957, Test Acc: 28.68%
Epoch 11/200 - Train Loss: 3.9209, Train 

KeyboardInterrupt: 