In [10]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import datasets, transforms
import wandb
import os
import numpy as np
from einops import rearrange
from einops.layers.torch import Rearrange
import math
import random


In [14]:
# Define MixUp Function
def mixup_data(x, y, alpha=0.2):
    '''Returns mixed inputs, pairs of targets, and lambda'''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

# Define MixUp Criterion
def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

In [11]:
class Attention(nn.Module):
    def __init__(self, dim, *, dim_head=64, heads=8, dropout=0.0):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.dim_head = dim_head
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.to_qkv(x)
        qkv = qkv.reshape(B, N, 3, self.heads, self.dim_head)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, N, dim_head)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q = q * self.scale
        attn = torch.matmul(q, k.transpose(-2, -1))
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).reshape(B, N, -1)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out

class FeedForward(nn.Module):
    def __init__(self, dim, dim_inner, dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim_inner),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_inner, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

In [12]:
# Modified ViT for CIFAR-100
class ViT(nn.Module):
    def __init__(
        self,
        *,
        image_size,
        patch_size,
        num_classes,
        dim,
        depth,
        heads,
        mlp_dim,
        dropout=0.0,
        emb_dropout=0.0,
        channels=3,
        dim_head=64
    ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, \
            'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        self.patch_size = patch_size
        self.dim = dim

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h ph) (w pw) -> b (h w) (ph pw c)', ph=patch_height, pw=patch_width),
            nn.Linear(patch_dim, dim)
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = nn.ModuleList([])
        for _ in range(depth):
            self.transformer.append(nn.ModuleList([
                nn.LayerNorm(dim),
                Attention(dim, dim_head=dim_head, heads=heads, dropout=dropout),
                nn.LayerNorm(dim),
                FeedForward(dim, dim_inner=mlp_dim, dropout=dropout)
            ]))

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        B, N, _ = x.shape

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embedding[:, :N + 1, :]
        x = self.dropout(x)

        for norm1, attn, norm2, ff in self.transformer:
            x = x + attn(norm1(x))
            x = x + ff(norm2(x))

        x = x[:, 0]
        x = self.mlp_head(x)
        return x


In [15]:
def main():
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    np.random.seed(42)
    random.seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Initialize wandb
    wandb.init(project='vit-cifar100', config={
        'model': 'ViT',
        'dataset': 'CIFAR-100',
        'epochs': 200,  # Increased epochs for better convergence
        'batch_size': 128,
        'learning_rate': 3e-4,
        'weight_decay': 5e-2,  # Increased weight decay
        'image_size': 32,
        'patch_size': 4,
        'dim': 384,
        'depth': 12,  # Increased depth for better representation
        'heads': 6,  # Adjusted number of heads
        'mlp_dim': 384 * 4,
        'dropout': 0.1,
        'emb_dropout': 0.1,
        'num_classes': 100,
        'mixup_alpha': 0.2,  # Added MixUp alpha
        'label_smoothing': 0.1  # Added label smoothing
    })
    config = wandb.config

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data transforms for CIFAR-100 with additional augmentations
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])

    # Load CIFAR-100 dataset
    train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # Initialize model
    model = ViT(
        image_size=config.image_size,
        patch_size=config.patch_size,
        num_classes=config.num_classes,
        dim=config.dim,
        depth=config.depth,
        heads=config.heads,
        mlp_dim=config.mlp_dim,
        dropout=config.dropout,
        emb_dropout=config.emb_dropout
    ).to(device)

    # Loss function with label smoothing
    criterion = nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)

    # Optimizer with parameter-wise weight decay (no decay for bias and norm parameters)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            'weight_decay': config.weight_decay
        },
        {
            'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }
    ]
    optimizer = optim.AdamW(optimizer_grouped_parameters, lr=config.learning_rate)

    # Learning rate scheduler with warmup
    total_steps = config.epochs * len(train_loader)
    warmup_steps = int(0.1 * total_steps)  # 10% of total steps for warmup

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return 0.5 * (1. + math.cos(math.pi * (current_step - warmup_steps) / (total_steps - warmup_steps)))

    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # Training loop
    best_acc = 0.0
    for epoch in range(config.epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # Apply MixUp
            inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, alpha=config.mixup_alpha)
            inputs, targets_a, targets_b = map(torch.autograd.Variable, (inputs, targets_a, targets_b))

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            scheduler.step()

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            # For MixUp, accuracy is not straightforward. Use separate metrics or skip accuracy during training.
            total += targets.size(0)
            # Approximate correct predictions
            correct += (lam * predicted.eq(targets_a).sum().item() + (1 - lam) * predicted.eq(targets_b).sum().item())

            if batch_idx % 100 == 0:
                wandb.log({
                    'train_loss': running_loss / (batch_idx + 1),
                    'train_acc': 100. * correct / total,
                    'learning_rate': optimizer.param_groups[0]['lr']
                })

        train_acc = 100. * correct / total

        # Validation
        model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        acc = 100. * correct / total
        wandb.log({
            'test_loss': test_loss / len(test_loader),
            'test_acc': acc,
            'epoch': epoch
        })

        # Save best model
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), 'best_cifar100_vit.pth')

        print(f"Epoch {epoch + 1}/{config.epochs} - "
              f"Train Loss: {running_loss / len(train_loader):.4f}, Train Acc: {train_acc:.2f}%, "
              f"Test Loss: {test_loss / len(test_loader):.4f}, Test Acc: {acc:.2f}%")

        # Log hyperparameters and metrics at the end of each epoch
        wandb.log({
            'epoch': epoch,
            'best_test_acc': best_acc
        })

    print(f"Training completed. Best Test Accuracy: {best_acc:.2f}%")
    wandb.finish()

if __name__ == '__main__':
    main()


VBox(children=(Label(value='0.012 MB of 0.012 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

Files already downloaded and verified
Files already downloaded and verified
Epoch 1/200 - Train Loss: 4.5775, Train Acc: 2.06%, Test Loss: 4.4002, Test Acc: 3.77%
Epoch 2/200 - Train Loss: 4.3902, Train Acc: 4.15%, Test Loss: 4.2550, Test Acc: 5.68%
Epoch 3/200 - Train Loss: 4.2412, Train Acc: 6.80%, Test Loss: 3.9857, Test Acc: 10.45%
Epoch 4/200 - Train Loss: 4.0786, Train Acc: 10.12%, Test Loss: 3.7612, Test Acc: 15.64%
Epoch 5/200 - Train Loss: 3.9211, Train Acc: 13.38%, Test Loss: 3.5925, Test Acc: 18.95%
Epoch 6/200 - Train Loss: 3.8285, Train Acc: 15.45%, Test Loss: 3.4326, Test Acc: 23.00%
Epoch 7/200 - Train Loss: 3.7285, Train Acc: 18.00%, Test Loss: 3.3353, Test Acc: 25.46%
Epoch 8/200 - Train Loss: 3.6466, Train Acc: 19.87%, Test Loss: 3.2844, Test Acc: 26.91%
Epoch 9/200 - Train Loss: 3.5785, Train Acc: 21.33%, Test Loss: 3.2138, Test Acc: 28.58%
Epoch 10/200 - Train Loss: 3.5342, Train Acc: 22.57%, Test Loss: 3.1294, Test Acc: 30.78%
Epoch 11/200 - Train Loss: 3.4854, Tra

VBox(children=(Label(value='0.030 MB of 0.030 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
best_test_acc,▁▂▂▂▃▄▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇▇▇█████████████████
epoch,▁▁▁▂▂▂▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇███
learning_rate,▄▄▄▆███████▇▇▇▇▆▆▆▅▅▅▄▄▄▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁
test_acc,▁▂▃▄▄▆▆▆▆▇▇▇▇▇██████████████████████████
test_loss,█▇▆▅▅▂▂▂▂▁▁▁▂▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁
train_acc,▁▂▃▃▄▄▅▅▅▆▆▆▆▇▇█▇▇▇▆█▆▃█▇▆▇▇▇▇█▇▃▇▇█▇▇▇▇
train_loss,██▇▇▅▅█▄▄▄▄▃▃▃▃▃▃▃▁▃▃▃▇▃▃▁▃▂▃▂▂▂▁▃▂▂▂▂▇▂

0,1
best_test_acc,58.77
epoch,199.0
learning_rate,0.0
test_acc,58.62
test_loss,2.34108
train_acc,87.77142
train_loss,1.33631
