In [1]:
# This config cell runs before the "Imports" section, so torch is imported here
# explicitly instead of via an inline __import__ hack.
import torch

# Architecture of a small ViT for 32x32 CIFAR-10 images
# (32 / 4 = 8x8 = 64 patches; embed_dim 192 split across 6 heads -> 32 dims per head).
MODEL_CFG = {
    'img_size': 32,
    'patch_size': 4,
    'in_chans': 3,
    'num_classes': 10,
    'embed_dim': 192,
    'depth': 10,
    'num_heads': 6,
    'mlp_ratio': 4.0,
    'drop_rate': 0.1,
    'attn_drop_rate': 0.0,
}

# Optimisation settings.
# NOTE(review): in the captured run, train accuracy stalls around 47-50% from
# roughly the end of warmup (5 epochs) onward, i.e. at the peak LR of 3e-3.
# A lower peak LR (e.g. 1e-3) and/or gradient clipping may be worth trying.
TRAIN_CFG = {
    'batch_size': 256,
    'epochs': 120,
    'lr': 3e-3,
    'weight_decay': 0.05,
    'warmup_epochs': 5,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}

# -------------------------
# Imports
# -------------------------
import math
import os
import random
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# -------------------------
# Utilities
# -------------------------

def set_seed(seed=42):
    """Seed Python's ``random`` and PyTorch's CPU and CUDA generators.

    ``PYTHONHASHSEED`` is exported as well. Python reads that variable only at
    interpreter start-up, so it affects subprocesses launched afterwards, not
    the hash randomisation of the current process.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = f'{seed}'
    torch.manual_seed(seed)
    # Safe to call without a GPU: CUDA seeding is deferred until CUDA initialises.
    torch.cuda.manual_seed_all(seed)


set_seed(42)

# -------------------------
# Small helper transforms: Cutout
# -------------------------
class Cutout:
    """Zero out ``n_holes`` random square patches of an image tensor.

    Each hole is centred on a uniformly random pixel and covers rows/columns
    ``[c - length // 2, c + length // 2)``, clipped at the image border. PIL
    images are first converted with ``to_tensor``. Indexing uses dims 1 and 2,
    so a ``(C, H, W)`` layout is assumed. Holes may overlap.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, img):
        if not isinstance(img, torch.Tensor):
            img = transforms.functional.to_tensor(img)
        height, width = img.shape[1], img.shape[2]
        half = self.length // 2
        keep = torch.ones((height, width), dtype=torch.float32)
        for _ in range(self.n_holes):
            # Row centre is drawn before column centre (order matters for reproducibility).
            cy = random.randrange(height)
            cx = random.randrange(width)
            rows = slice(max(0, cy - half), min(height, cy + half))
            cols = slice(max(0, cx - half), min(width, cx + half))
            keep[rows, cols] = 0.0
        return img * keep.expand_as(img)

# -------------------------
# Data
# -------------------------

def get_dataloaders(batch_size=128, use_cutout=True, root='./data', num_workers=2, pin_memory=None):
    """Build the CIFAR-10 train and test ``DataLoader``s (downloading the data if needed).

    Args:
        batch_size: batch size used by both loaders.
        use_cutout: if true, append a single 8x8 ``Cutout`` hole to the training pipeline.
        root: directory where CIFAR-10 is stored or downloaded to.
        num_workers: number of worker processes per loader.
        pin_memory: whether to pin host memory. ``None`` (default) enables it only
            when CUDA is available, because pinning brings no benefit without a GPU.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles and augments.
    """
    # Per-channel normalisation constants (standard CIFAR-10 mean/std values).
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)
    if pin_memory is None:
        pin_memory = torch.cuda.is_available()

    train_transforms = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    if use_cutout:
        # Applied after Normalize, so holes hold 0 in normalised space,
        # which corresponds to the per-channel mean colour rather than black.
        train_transforms.append(Cutout(n_holes=1, length=8))
    transform_train = transforms.Compose(train_transforms)

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])

    train_dataset = datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)
    test_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
    return train_loader, test_loader

# -------------------------
# ViT model
# -------------------------
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and linearly embed each one.

    A ``Conv2d`` whose kernel and stride both equal ``patch_size`` produces one
    ``embed_dim`` vector per patch. Its spatial grid is then flattened into a
    token sequence.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=192):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        side = img_size // patch_size
        self.grid_size = (side, side)
        self.num_patches = side * side
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        """Map ``[B, C, H, W]`` images to ``[B, (H/ps) * (W/ps), embed_dim]`` tokens."""
        feature_map = self.proj(x)  # [B, embed_dim, H/ps, W/ps]
        return feature_map.flatten(start_dim=2).permute(0, 2, 1)

class MLP(nn.Module):
    """Transformer feed-forward sub-layer: Linear -> GELU -> Dropout -> Linear -> Dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    omitted (or given any falsy value). The same dropout module is applied after
    both linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))

class Block(nn.Module):
    """Pre-norm transformer encoder block operating on ``[B, N, C]`` token tensors.

    Computes ``x + Attn(LN(x))`` followed by ``x + MLP(LN(x))``. The attention
    output itself is not passed through dropout; dropout lives inside the MLP.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=attn_drop, batch_first=True)
        # Placeholder for stochastic depth; currently an Identity and not used in forward().
        self.drop_path = nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)

    def forward(self, x):
        normed = self.norm1(x)
        # batch_first=True, so (B, N, C) goes straight in; attention weights are not needed.
        attended = self.attn(normed, normed, normed, need_weights=False)[0]
        x = x + attended
        return x + self.mlp(self.norm2(x))

class VisionTransformer(nn.Module):
    """ViT image classifier.

    Pipeline: patch embedding, a learned [CLS] token prepended to the patch
    tokens, learned positional embeddings plus dropout, ``depth`` pre-norm
    ``Block``s, a final LayerNorm, and a linear head applied to the [CLS] token.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=10, embed_dim=192, depth=10, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0.):
        super().__init__()
        # Submodules are created in a fixed order: default inits draw from the global
        # RNG, so this order determines the weights produced under a given seed.
        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        seq_len = self.patch_embed.num_patches + 1  # +1 for the [CLS] token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.blocks = nn.ModuleList(
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, attn_drop=attn_drop_rate)
            for _ in range(depth)
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        for learned_token in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(learned_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser used through ``self.apply``.

        Linear layers get N(0, 0.02) weights and zero bias. With the default
        absolute bounds (a=-2, b=2), truncation is effectively inactive at this std.
        LayerNorms are reset to identity (weight 1, bias 0).
        NOTE(review): MultiheadAttention's packed ``in_proj_weight`` is a raw
        Parameter rather than an nn.Linear, so presumably it keeps PyTorch's
        default init here. Confirm this is intended.
        """
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
            return
        if isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    def forward(self, x):
        """Return logits ``[B, num_classes]`` for images ``[B, C, H, W]``."""
        tokens = self.patch_embed(x)                       # [B, N, C]
        cls = self.cls_token.expand(x.shape[0], -1, -1)    # [B, 1, C]
        x = self.pos_drop(torch.cat((cls, tokens), dim=1) + self.pos_embed)
        for blk in self.blocks:
            x = blk(x)
        return self.head(self.norm(x)[:, 0])

# -------------------------
# Training helpers
# -------------------------

def build_model(cfg):
    """Instantiate a ``VisionTransformer`` from a config mapping.

    Exactly the keys listed below are read, in that order. A missing key raises
    ``KeyError`` and extra keys are ignored.
    """
    wanted = (
        'img_size', 'patch_size', 'in_chans', 'num_classes', 'embed_dim',
        'depth', 'num_heads', 'mlp_ratio', 'drop_rate', 'attn_drop_rate',
    )
    kwargs = {key: cfg[key] for key in wanted}
    return VisionTransformer(**kwargs)


def get_parameter_groups(model, weight_decay=0.05):
    """Split trainable parameters into a decayed and a non-decayed optimizer group.

    These parameters go into the ``weight_decay=0.0`` group:
    - names ending in ``.bias``;
    - every 1-D tensor (e.g. LayerNorm weights);
    - names starting with ``pos_embed`` or ``cls_token``.
    All other parameters with ``requires_grad`` go into the ``weight_decay``
    group. Frozen parameters are omitted. Returns ``[decay_group, no_decay_group]``.
    """
    exempt_prefixes = ('pos_embed', 'cls_token')
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        exempt = name.endswith('.bias') or param.ndim == 1 or name.startswith(exempt_prefixes)
        if exempt:
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]


def cosine_scheduler(base_lr, warmup_epochs, total_epochs, iters_per_epoch):
    """Build a per-iteration LR multiplier for ``LambdaLR``: linear warmup, then cosine decay.

    Args:
        base_lr: unused. ``LambdaLR`` multiplies each param group's own initial LR
            by the returned factor; the argument is kept for interface compatibility.
        warmup_epochs: epochs of linear warmup. The factor rises from 0 at step 0,
            so the very first optimizer step uses LR 0.
        total_epochs: length of the whole schedule; the factor reaches 0 at its end.
        iters_per_epoch: scheduler steps per epoch (the scheduler is stepped per batch).

    Returns:
        ``lr_lambda(step) -> float`` in ``[0, 1]``. Steps past the end of the
        schedule are clamped to 0. Without the clamp, the cosine would climb back
        toward 1 if the scheduler were stepped longer than planned.
    """
    total_iters = total_epochs * iters_per_epoch
    warmup_iters = warmup_epochs * iters_per_epoch
    decay_iters = max(1, total_iters - warmup_iters)

    def lr_lambda(current_step):
        if current_step < warmup_iters:
            return float(current_step) / float(max(1, warmup_iters))
        progress = min(1.0, float(current_step - warmup_iters) / float(decay_iters))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return lr_lambda

# -------------------------
# Train / Eval loops
# -------------------------

def train_one_epoch(model, optimizer, data_loader, device, epoch, scheduler=None,
                    max_grad_norm: Optional[float] = None):
    """Run one training epoch and return ``(mean_loss, accuracy_percent)``.

    Args:
        model: classifier producing logits of shape ``[B, num_classes]``.
        optimizer: optimizer over ``model``'s parameters.
        data_loader: iterable of ``(images, targets)`` batches.
        device: device each batch is moved to.
        epoch: current epoch index (unused; kept for interface compatibility).
        scheduler: optional LR scheduler, stepped once per batch (not per epoch).
        max_grad_norm: if given, clip the global gradient L2 norm to this value
            before every optimizer step. ``None`` (default) disables clipping.

    Returns:
        Sample-weighted mean cross-entropy and top-1 accuracy (percent), both
        measured on the training batches in train mode. An empty
        ``data_loader`` raises ``ZeroDivisionError``.
    """
    model.train()
    criterion = nn.CrossEntropyLoss()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, targets in data_loader:
        # non_blocking overlaps the copy with compute when the loader pins memory;
        # on CPU it has no effect.
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        outputs = model(images)
        loss = criterion(outputs, targets)
        optimizer.zero_grad()
        loss.backward()
        if max_grad_norm is not None:
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        _, preds = outputs.max(1)
        correct += preds.eq(targets).sum().item()
        total += batch_size
    return running_loss / total, 100.0 * correct / total


def evaluate(model, data_loader, device):
    """Score ``model`` on ``data_loader`` in eval mode, without tracking gradients.

    Returns ``(mean_loss, accuracy_percent)``. Per-batch mean cross-entropy is
    weighted by batch size to give a per-sample average; accuracy is top-1 in
    percent. An empty ``data_loader`` raises ``ZeroDivisionError``.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for images, targets in data_loader:
            images, targets = images.to(device), targets.to(device)
            logits = model(images)
            batch = images.size(0)
            loss_sum += F.cross_entropy(logits, targets).item() * batch
            n_correct += logits.max(dim=1).indices.eq(targets).sum().item()
            n_seen += batch
    return loss_sum / n_seen, 100.0 * n_correct / n_seen

# -------------------------
# Main runner
# -------------------------

def _write_readme(model_cfg, train_cfg, best_acc, path='README.md'):
    """Write a short Markdown summary (both configs and the best accuracy) to ``path``."""
    lines = ['# ViT on CIFAR-10 (Colab)', '\n## Final config', '']
    lines += [f'- {k}: {v}' for k, v in model_cfg.items()]
    lines.append('\n## Training config')
    lines += [f'- {k}: {v}' for k, v in train_cfg.items()]
    lines += [
        '\n## Results (tiny table)',
        '\n| Model | Test accuracy |',
        '|---|---|',
        f'| ViT | {best_acc:.2f}% |',
        '\n## Analysis',
        # NOTE(review): no such analysis is visible in this cell; confirm it exists
        # elsewhere in the notebook or update this sentence.
        '\nSee top of script for short analysis (patch sizes, depth/width trade-offs, augmentations, optimizer).',
    ]
    # Context manager guarantees the file is flushed and closed, even on error.
    with open(path, 'w', encoding='utf-8') as fh:
        fh.write('\n'.join(lines))


def main(model_cfg=MODEL_CFG, train_cfg=TRAIN_CFG):
    """Train a ViT on CIFAR-10, checkpoint the best epoch, and write README.md.

    After every epoch the model is evaluated on the test split. Whenever test
    accuracy improves, ``{'model_state_dict': ..., 'cfg': model_cfg}`` is saved
    to ``best_vit_cifar10.pth``.
    NOTE(review): because the "best" epoch is selected on the test set itself,
    the reported best accuracy is optimistically biased. A separate validation
    split would be needed for an unbiased estimate.

    Args:
        model_cfg: mapping passed to ``build_model``.
        train_cfg: mapping providing ``device``, ``batch_size``, ``epochs``,
            ``lr``, ``weight_decay`` and ``warmup_epochs``.

    Returns:
        The per-epoch ``history`` dict with ``train_loss``, ``train_acc``,
        ``test_loss`` and ``test_acc`` lists.
    """
    device = torch.device(train_cfg['device'])
    print('Using device:', device)
    train_loader, test_loader = get_dataloaders(batch_size=train_cfg['batch_size'])
    model = build_model(model_cfg).to(device)
    print('Model params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1e6))

    param_groups = get_parameter_groups(model, weight_decay=train_cfg['weight_decay'])
    optimizer = AdamW(param_groups, lr=train_cfg['lr'], betas=(0.9, 0.999))

    # train_one_epoch steps the scheduler once per batch, so the schedule is in iterations.
    iters_per_epoch = len(train_loader)
    lr_lambda = cosine_scheduler(train_cfg['lr'], train_cfg['warmup_epochs'], train_cfg['epochs'], iters_per_epoch)
    scheduler = LambdaLR(optimizer, lr_lambda)

    best_acc = 0.0
    history = {'train_loss': [], 'train_acc': [], 'test_loss': [], 'test_acc': []}

    for epoch in range(train_cfg['epochs']):
        train_loss, train_acc = train_one_epoch(model, optimizer, train_loader, device, epoch, scheduler)
        test_loss, test_acc = evaluate(model, test_loader, device)
        print(f'Epoch {epoch+1}/{train_cfg["epochs"]}  Train loss: {train_loss:.4f}  Train acc: {train_acc:.2f}%  Test acc: {test_acc:.2f}%')
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['test_loss'].append(test_loss)
        history['test_acc'].append(test_acc)
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save({'model_state_dict': model.state_dict(), 'cfg': model_cfg}, 'best_vit_cifar10.pth')
    print('Best test accuracy: %.2f%%' % best_acc)

    _write_readme(model_cfg, train_cfg, best_acc)
    print('Saved README.md and best_vit_cifar10.pth')
    return history

# Entry point. Inside a notebook kernel __name__ is also '__main__', so running
# this cell starts training (as the captured output below shows).
if __name__ == '__main__':
    main()


Using device: cuda


100%|██████████| 170M/170M [02:58<00:00, 954kB/s]


Model params: 4.47M
Epoch 1/120  Train loss: 1.8834  Train acc: 29.88%  Test acc: 40.27%
Epoch 2/120  Train loss: 1.5793  Train acc: 42.29%  Test acc: 45.93%
Epoch 3/120  Train loss: 1.4694  Train acc: 46.55%  Test acc: 48.60%
Epoch 4/120  Train loss: 1.4304  Train acc: 47.86%  Test acc: 49.00%
Epoch 5/120  Train loss: 1.4372  Train acc: 47.60%  Test acc: 50.75%
Epoch 6/120  Train loss: 1.4363  Train acc: 47.74%  Test acc: 50.34%
Epoch 7/120  Train loss: 1.4099  Train acc: 48.61%  Test acc: 51.70%
Epoch 8/120  Train loss: 1.4086  Train acc: 48.67%  Test acc: 52.48%
Epoch 9/120  Train loss: 1.3943  Train acc: 49.23%  Test acc: 52.71%
Epoch 10/120  Train loss: 1.3873  Train acc: 49.67%  Test acc: 49.42%
Epoch 11/120  Train loss: 1.4127  Train acc: 48.41%  Test acc: 52.87%
Epoch 12/120  Train loss: 1.4040  Train acc: 49.05%  Test acc: 51.61%
Epoch 13/120  Train loss: 1.3871  Train acc: 49.73%  Test acc: 49.33%
Epoch 14/120  Train loss: 1.4046  Train acc: 48.97%  Test acc: 51.86%
Epoch 15/