In [None]:
!pip install -q torch torchvision

import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def set_seed(seed=42):
    """Seed the Python and PyTorch random number generators.

    The same value is applied to the stdlib ``random`` module, to PyTorch's
    default generator, and to the generators of all CUDA devices.

    Args:
        seed: Integer seed handed to every generator. Defaults to 42.
    """
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)


In [None]:
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each patch.

    The embedding is a single convolution whose kernel size equals its stride,
    so every ``patch_size`` x ``patch_size`` window is projected independently
    to an ``embed_dim``-dimensional vector.

    Args:
        img_size: Side length of the input image; must be divisible by
            ``patch_size`` (checked with an ``assert``).
        patch_size: Side length of one patch.
        in_chans: Number of input image channels.
        embed_dim: Size of each patch embedding.

    Attributes:
        grid_size: Number of patches along one side (``img_size // patch_size``).
        num_patches: Total number of patches (``grid_size ** 2``).
        proj: The ``nn.Conv2d`` performing the patch projection.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        assert img_size % patch_size == 0
        patches_per_side = img_size // patch_size
        self.grid_size = patches_per_side
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch accepted by the convolution, i.e. (B, in_chans, H, W).

        Returns:
            Tensor of shape (B, N, embed_dim), where N is the number of patch
            positions produced by the convolution.
        """
        feature_map = self.proj(x)  # (B, embed_dim, H / patch, W / patch)
        # Collapse the two spatial axes into one token axis, then put tokens
        # before channels.
        tokens = feature_map.flatten(start_dim=2)
        return tokens.transpose(1, 2)


class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer; a falsy value (``None`` or
            0) falls back to ``in_features``.
        out_features: Size of the last dimension of the output; a falsy value
            falls back to ``in_features``.
        drop: Dropout probability, shared by both dropout applications.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        # `or` (not an `is None` test) is intentional here: it mirrors the
        # fallback for any falsy width.
        width_out = out_features or in_features
        width_hidden = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(width_hidden, width_out)
        # A single Dropout module is reused after the activation and after fc2.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the feed-forward network along the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    A single linear layer produces queries, keys and values for all heads;
    the per-head attention outputs are concatenated and passed through an
    output projection.

    Args:
        dim: Embedding size of each input token; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the joint query/key/value projection has a bias.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        AssertionError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        # Fail at construction time with a clear message; otherwise a bad
        # configuration only surfaces later as an opaque reshape error in forward().
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Scaling by 1/sqrt(head_dim) keeps the dot products in a stable range.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Tensor of shape (B, N, C) with ``C == dim``.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)  # normalize over the key axis
        attn = self.attn_drop(attn)
        # Merge the heads back into the channel dimension.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """Pre-norm Transformer encoder block.

    Two residual sub-layers are applied in sequence: self-attention and a
    feed-forward MLP. Each sub-layer normalizes its input first and adds its
    output back onto the unnormalized input.

    Args:
        dim: Embedding size of each token.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``
            (truncated to an integer).
        qkv_bias: Whether the attention's query/key/value projection has a bias.
        drop: Dropout probability for the attention output projection and the MLP.
        attn_drop: Dropout probability for the attention weights.
        norm_layer: Callable that builds a normalization layer from ``dim``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, qkv_bias=True, drop=0., attn_drop=0., norm_layer=nn.LayerNorm):
        super().__init__()
        hidden_width = int(dim * mlp_ratio)

        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm2 = norm_layer(dim)
        self.mlp = MLP(dim, hidden_width, drop=drop)

    def forward(self, x):
        """Run both residual sub-layers on ``x`` and return the result."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed


class VisionTransformer(nn.Module):
    """Vision Transformer classifier that predicts from a class token.

    The image is split into patch embeddings, a learnable class token is
    prepended, learnable position embeddings are added, and the sequence is
    passed through a stack of ``Block`` layers. The normalized class token is
    fed to a linear classification head.

    Args:
        img_size: Side length of the input image.
        patch_size: Side length of one patch.
        in_chans: Number of input image channels.
        num_classes: Number of output logits.
        embed_dim: Token embedding size.
        depth: Number of ``Block`` layers.
        num_heads: Number of attention heads per block.
        mlp_ratio: MLP hidden width as a multiple of ``embed_dim``.
        drop_rate: Dropout probability used after adding position embeddings
            and inside every block.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=10, embed_dim=256, depth=8, num_heads=8, mlp_ratio=4.0, drop_rate=0.):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        # One extra position for the class token.
        seq_len = self.patch_embed.num_patches + 1
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        encoder_layers = []
        for _ in range(depth):
            encoder_layers.append(Block(embed_dim, num_heads, mlp_ratio, drop=drop_rate))
        self.blocks = nn.ModuleList(encoder_layers)

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Replace the zero placeholders with small truncated-normal values.
        for param in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(param, std=0.02)

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for an image batch ``x``."""
        tokens = self.patch_embed(x)
        # Broadcast the single learned class token across the batch.
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)
        for layer in self.blocks:
            tokens = layer(tokens)
        tokens = self.norm(tokens)
        # Position 0 holds the class token; only it is classified.
        return self.head(tokens[:, 0])


In [None]:
def get_data_loaders(batch_size=128, num_workers=2):
    """Build the CIFAR-10 training and test data loaders.

    The dataset is downloaded to ``./data`` if necessary. Training images are
    augmented with a random 32x32 crop (after 4 pixels of padding) and a
    random horizontal flip; both splits are converted to tensors and
    normalized per channel.

    Args:
        batch_size: Number of samples per batch for both loaders.
        num_workers: Number of worker processes for both loaders.

    Returns:
        Tuple ``(train_loader, test_loader)``; only the training loader shuffles.
    """
    # Per-channel normalization constants (presumably CIFAR-10 training-set
    # statistics -- TODO confirm the source).
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2470, 0.2435, 0.2616)

    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    train_transform = transforms.Compose(
        augmentation + [transforms.ToTensor(), transforms.Normalize(mean, std)]
    )
    test_transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(mean, std)]
    )

    dataset_kwargs = dict(root='./data', download=True)
    trainset = datasets.CIFAR10(train=True, transform=train_transform, **dataset_kwargs)
    testset = datasets.CIFAR10(train=False, transform=test_transform, **dataset_kwargs)

    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
    return (
        DataLoader(trainset, shuffle=True, **loader_kwargs),
        DataLoader(testset, shuffle=False, **loader_kwargs),
    )


@torch.no_grad()
def evaluate(model, loader, device):
    """Compute the top-1 classification accuracy of ``model`` over ``loader``.

    The model is switched to evaluation mode and is left in that mode on
    return. Gradients are disabled for the whole call.

    Args:
        model: Module mapping an image batch to per-class scores along dim 1.
        loader: Iterable yielding ``(images, labels)`` tensor pairs.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Fraction of correctly classified samples as a float in [0, 1], or
        ``0.0`` if ``loader`` yields no samples.
    """
    model.eval()
    total, correct = 0, 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        preds = model(imgs).argmax(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)
    # Guard against an empty loader, which would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0
    return correct / total


In [None]:
def train_vit(epochs=200, batch_size=128, lr=3e-4, weight_decay=0.05, patch_size=4, embed_dim=256, depth=8, heads=8, device='cuda'):
    """Train a VisionTransformer on the loaders from ``get_data_loaders``.

    The run is seeded with 42, optimized with AdamW on a cross-entropy loss,
    and evaluated on the test loader after every epoch. Per-epoch and best
    test accuracy are printed.

    Args:
        epochs: Number of passes over the training loader.
        batch_size: Batch size for both data loaders.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        patch_size: Patch side length passed to the model.
        embed_dim: Token embedding size passed to the model.
        depth: Number of Transformer blocks.
        heads: Number of attention heads per block.
        device: Device to train on; used only when CUDA is available,
            otherwise training runs on the CPU.

    Returns:
        The model as it is after the final epoch (not the best-scoring epoch).
    """
    set_seed(42)
    target = device if torch.cuda.is_available() else 'cpu'
    device = torch.device(target)
    train_loader, test_loader = get_data_loaders(batch_size=batch_size)

    model = VisionTransformer(
        embed_dim=embed_dim,
        patch_size=patch_size,
        depth=depth,
        num_heads=heads,
    )
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    best_acc = 0
    for epoch in range(1, epochs + 1):
        # evaluate() leaves the model in eval mode, so re-enable training each epoch.
        model.train()
        for batch in train_loader:
            imgs, labels = batch
            imgs = imgs.to(device)
            labels = labels.to(device)
            loss = F.cross_entropy(model(imgs), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        acc = evaluate(model, test_loader, device)
        print(f"Epoch {epoch:03d}: Test Accuracy = {acc*100:.2f}%")
        if acc > best_acc:
            best_acc = acc

    print(f"Best Accuracy: {best_acc*100:.2f}%")
    return model


In [None]:
# Launch training with the configuration below; train_vit prints the test
# accuracy after every epoch and returns the model from the final epoch.
model = train_vit(
    epochs=200,
    batch_size=256,
    lr=3e-4,
    embed_dim=256,
    depth=8,
    heads=8,
)

100%|██████████| 170M/170M [00:14<00:00, 11.8MB/s]


Epoch 001: Test Accuracy = 38.03%
Epoch 002: Test Accuracy = 49.34%
Epoch 003: Test Accuracy = 53.88%
Epoch 004: Test Accuracy = 56.74%
Epoch 005: Test Accuracy = 60.80%
Epoch 006: Test Accuracy = 63.33%
Epoch 007: Test Accuracy = 66.29%
Epoch 008: Test Accuracy = 67.66%
Epoch 009: Test Accuracy = 68.92%
Epoch 010: Test Accuracy = 70.25%
Epoch 011: Test Accuracy = 70.68%
Epoch 012: Test Accuracy = 73.02%
Epoch 013: Test Accuracy = 73.02%
Epoch 014: Test Accuracy = 73.91%
Epoch 015: Test Accuracy = 75.71%
Epoch 016: Test Accuracy = 74.62%
Epoch 017: Test Accuracy = 75.85%
Epoch 018: Test Accuracy = 75.17%
Epoch 019: Test Accuracy = 75.90%
Epoch 020: Test Accuracy = 76.24%
Epoch 021: Test Accuracy = 77.00%
Epoch 022: Test Accuracy = 78.13%
Epoch 023: Test Accuracy = 77.94%
Epoch 024: Test Accuracy = 78.18%
Epoch 025: Test Accuracy = 78.72%
Epoch 026: Test Accuracy = 78.78%
Epoch 027: Test Accuracy = 78.52%
Epoch 028: Test Accuracy = 79.24%
Epoch 029: Test Accuracy = 78.99%
Epoch 030: Tes

##### Test accuracy by configuration
1. epochs=50, batch_size=128, lr=3e-4, embed_dim=256, depth=8, heads=8: 80.18%
2. epochs=100, batch_size=256, lr=3e-4, embed_dim=256, depth=8, heads=8: 80.35%
3. epochs=100, batch_size=128, lr=3e-4, embed_dim=256, depth=8, heads=8: 80.89%