### Imports

In [1]:
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import RandAugment

### Configurations

In [2]:
# Run on the GPU when PyTorch reports one as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Hyper-parameters shared by the data-loading, model and training cells below.
BATCH_SIZE = 128   # samples per batch for both the train and test DataLoader
EPOCHS = 150       # number of training epochs; also T_max of the cosine LR schedule
LR = 3e-4          # initial learning rate handed to AdamW
PATCH_SIZE = 2     # kernel size and stride of the patch-projection convolution
EMBED_DIM = 256    # width of every token embedding
DEPTH = 6          # number of stacked TransformerEncoder blocks
NUM_HEADS = 8      # attention heads per encoder block
MLP_DIM = 512      # hidden width of the MLP inside each encoder block
NUM_CLASSES = 10   # output size of the classification head
IMG_SIZE = 32      # image side length the model uses to derive its patch count

### Dataset

In [4]:
# CIFAR-10 images are RGB, so the normalisation statistics are written out per
# channel instead of relying on a single value being broadcast. With mean 0.5
# and std 0.5 this is numerically the same as the one-element form.
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)

# One normaliser shared by both pipelines so train and test can never drift apart.
normalize = transforms.Normalize(NORM_MEAN, NORM_STD)

# Training pipeline: random augmentation, then tensor conversion + normalisation.
transform_train = transforms.Compose([
    # Crop back to the size the model is configured for (IMG_SIZE) after padding.
    transforms.RandomCrop(IMG_SIZE, padding=4),
    transforms.RandomHorizontalFlip(),
    RandAugment(num_ops=2, magnitude=9),  # stronger augmentation
    transforms.ToTensor(),
    normalize,
])
# Evaluation pipeline: no augmentation, only tensor conversion + normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# Downloads CIFAR-10 into ./data on first use.
trainset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train)
testset = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test)

# Only the training stream is shuffled; the test order stays fixed.
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

100%|██████████| 170M/170M [00:04<00:00, 41.6MB/s]


### Vision Transformer

In [5]:
class PatchEmbedding(nn.Module):
    """Turn a batch of images into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size`` is
    applied to the input, so every output location corresponds to one
    non-overlapping patch projected to ``embed_dim`` channels.

    Args:
        img_size: Side length used to compute ``num_patches``.
        patch_size: Kernel size and stride of the projection.
        in_chans: Number of input channels.
        embed_dim: Number of output channels produced per patch.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Project ``x`` and return the patches as ``[B, N, embed_dim]``."""
        feature_map = self.proj(x)  # [B, embed_dim, H/P, W/P]
        # Collapse the spatial grid to N positions, then put channels last.
        return feature_map.flatten(start_dim=2).transpose(1, 2)

In [6]:
class TransformerEncoder(nn.Module):
    """Pre-norm transformer block: self-attention followed by an MLP.

    Each sub-layer is applied to a LayerNorm-ed copy of its input and the
    result is added back to the un-normalised input (residual connection).

    Args:
        embed_dim: Width of the token embeddings.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the MLP.
        dropout: Dropout probability used inside the attention module and
            after the MLP's second linear layer.
    """

    def __init__(self, embed_dim=256, num_heads=8, mlp_dim=512, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        # Normalise once and reuse the same tensor as query, key and value.
        # Calling norm1 three times tripled the LayerNorm work and handed the
        # attention module three distinct tensors for what is self-attention.
        normed = self.norm1(x)
        # The attention map was always discarded, so do not ask for it;
        # need_weights=False lets MultiheadAttention take its optimised path.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x

In [7]:
class VisionTransformer(nn.Module):
    """Vision Transformer classifier built from the notebook's components.

    The image is split into patch embeddings, a learnable class token is
    prepended, learnable position embeddings are added, and the sequence is
    passed through ``depth`` encoder blocks. The final-normalised class token
    is fed to a linear head.

    Args:
        img_size: Image side length, forwarded to ``PatchEmbedding``.
        patch_size: Patch side length, forwarded to ``PatchEmbedding``.
        num_classes: Number of logits produced by the head.
        embed_dim: Width of the token embeddings.
        depth: Number of ``TransformerEncoder`` blocks.
        num_heads: Attention heads per block.
        mlp_dim: Hidden width of each block's MLP.
        in_chans: Number of input image channels. Defaults to 3, which is the
            value that used to be hard-coded, so existing callers are unaffected.
    """

    def __init__(self, img_size=32, patch_size=4, num_classes=10,
                 embed_dim=256, depth=6, num_heads=8, mlp_dim=512, in_chans=3):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_chans, embed_dim)
        # Both learnable parameters start at zero.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # One position per patch plus one for the class token.
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.patch_embed.num_patches, embed_dim))
        self.blocks = nn.ModuleList([
            TransformerEncoder(embed_dim, num_heads, mlp_dim) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``[B, num_classes]`` for image batch ``x``."""
        B = x.size(0)
        x = self.patch_embed(x)
        # Broadcast the single class token across the batch and prepend it.
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        # Classify from the class token, which sits at sequence index 0.
        return self.head(x[:, 0])

In [8]:
# Training setup: build the model, optimiser, LR schedule and loss.
model = VisionTransformer(
    img_size=IMG_SIZE,
    patch_size=PATCH_SIZE,
    num_classes=NUM_CLASSES,
    embed_dim=EMBED_DIM,
    depth=DEPTH,
    num_heads=NUM_HEADS,
    mlp_dim=MLP_DIM,
)
model = model.to(device)

# AdamW with weight decay over every model parameter.
optimizer = optim.AdamW(params=model.parameters(), lr=LR, weight_decay=0.05)
# Cosine schedule spanning the whole run; it is stepped once per epoch.
scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=EPOCHS)
# Cross-entropy with label smoothing.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

In [9]:
def train_epoch(model, loader):
    """Run one optimisation pass over ``loader`` and report its metrics.

    Uses the notebook-level ``device``, ``optimizer`` and ``criterion``.

    Args:
        model: Network to train; it is switched to training mode.
        loader: Iterable yielding ``(imgs, labels)`` batches.

    Returns:
        ``(mean_loss, accuracy)``: the loss averaged per sample and the
        accuracy as a percentage. Both are ``0.0`` if the loader yields no
        samples.
    """
    model.train()
    total, correct, loss_sum = 0, 0, 0.0
    for imgs, labels in tqdm(loader):
        imgs, labels = imgs.to(device), labels.to(device)
        batch_size = labels.size(0)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # The criterion reports a per-batch mean (assuming its default mean
        # reduction, as configured in this notebook). Weight it by the batch
        # size so a smaller final batch does not count as a full one.
        loss_sum += loss.item() * batch_size
        _, preds = outputs.max(1)
        total += batch_size
        correct += preds.eq(labels).sum().item()
    if total == 0:
        # Nothing was processed; avoid dividing by zero.
        return 0.0, 0.0
    return loss_sum / total, 100. * correct / total

In [10]:
def test_epoch(model, loader):
    model.eval()
    total, correct, loss_sum = 0, 0, 0
    with torch.no_grad():
        for imgs, labels in tqdm(loader):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss_sum += loss.item()
            _, preds = outputs.max(1)
            total += labels.size(0)
            correct += preds.eq(labels).sum().item()
    return loss_sum / len(loader), 100. * correct / total

In [None]:
# Main loop: one training pass, one evaluation pass, then advance the LR schedule.
for epoch in range(EPOCHS):
    # The tuple is evaluated left to right, so training runs before evaluation.
    (train_loss, train_acc), (test_loss, test_acc) = (
        train_epoch(model, trainloader),
        test_epoch(model, testloader),
    )
    scheduler.step()
    # Epochs are reported 1-based.
    print(f"Epoch {epoch+1}: Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")

100%|██████████| 391/391 [03:08<00:00,  2.07it/s]
100%|██████████| 79/79 [00:13<00:00,  5.83it/s]


Epoch 1: Train Acc=26.40%, Test Acc=40.49%


100%|██████████| 391/391 [03:11<00:00,  2.04it/s]
100%|██████████| 79/79 [00:13<00:00,  5.89it/s]


Epoch 2: Train Acc=38.64%, Test Acc=48.57%


100%|██████████| 391/391 [03:11<00:00,  2.05it/s]
100%|██████████| 79/79 [00:13<00:00,  5.89it/s]


Epoch 3: Train Acc=45.43%, Test Acc=52.46%


100%|██████████| 391/391 [03:12<00:00,  2.04it/s]
100%|██████████| 79/79 [00:13<00:00,  5.83it/s]


Epoch 4: Train Acc=49.39%, Test Acc=56.59%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.84it/s]


Epoch 5: Train Acc=51.93%, Test Acc=58.96%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.85it/s]


Epoch 6: Train Acc=54.65%, Test Acc=60.67%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.83it/s]


Epoch 7: Train Acc=56.22%, Test Acc=63.05%


100%|██████████| 391/391 [03:11<00:00,  2.04it/s]
100%|██████████| 79/79 [00:13<00:00,  5.78it/s]


Epoch 8: Train Acc=57.78%, Test Acc=63.00%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.87it/s]


Epoch 9: Train Acc=59.22%, Test Acc=63.91%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.86it/s]


Epoch 10: Train Acc=60.12%, Test Acc=65.85%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.83it/s]


Epoch 11: Train Acc=61.35%, Test Acc=66.39%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.85it/s]


Epoch 12: Train Acc=62.04%, Test Acc=68.56%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.81it/s]


Epoch 13: Train Acc=62.82%, Test Acc=66.52%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.85it/s]


Epoch 14: Train Acc=63.49%, Test Acc=68.18%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.85it/s]


Epoch 15: Train Acc=64.28%, Test Acc=70.03%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.85it/s]


Epoch 16: Train Acc=64.90%, Test Acc=70.16%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.83it/s]


Epoch 17: Train Acc=65.98%, Test Acc=69.24%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.83it/s]


Epoch 18: Train Acc=66.18%, Test Acc=70.35%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.84it/s]


Epoch 19: Train Acc=67.06%, Test Acc=71.22%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.84it/s]


Epoch 20: Train Acc=67.47%, Test Acc=71.46%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.82it/s]


Epoch 21: Train Acc=67.78%, Test Acc=72.50%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.86it/s]


Epoch 22: Train Acc=68.47%, Test Acc=72.58%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.83it/s]


Epoch 23: Train Acc=68.92%, Test Acc=73.04%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.84it/s]


Epoch 24: Train Acc=69.27%, Test Acc=73.14%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.87it/s]


Epoch 25: Train Acc=69.73%, Test Acc=72.91%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.83it/s]


Epoch 26: Train Acc=70.22%, Test Acc=73.67%


100%|██████████| 391/391 [03:13<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.84it/s]


Epoch 27: Train Acc=70.56%, Test Acc=74.18%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.88it/s]


Epoch 28: Train Acc=70.58%, Test Acc=73.91%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.86it/s]


Epoch 29: Train Acc=71.35%, Test Acc=74.73%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.79it/s]


Epoch 30: Train Acc=71.91%, Test Acc=74.95%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.88it/s]


Epoch 31: Train Acc=72.33%, Test Acc=73.94%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.83it/s]


Epoch 32: Train Acc=72.71%, Test Acc=74.51%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.84it/s]


Epoch 33: Train Acc=72.83%, Test Acc=75.12%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.85it/s]


Epoch 34: Train Acc=73.70%, Test Acc=74.34%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.87it/s]


Epoch 35: Train Acc=73.46%, Test Acc=74.30%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.83it/s]


Epoch 36: Train Acc=74.02%, Test Acc=75.49%


100%|██████████| 391/391 [03:13<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.82it/s]


Epoch 37: Train Acc=73.97%, Test Acc=75.61%


100%|██████████| 391/391 [03:13<00:00,  2.02it/s]
100%|██████████| 79/79 [00:13<00:00,  5.83it/s]


Epoch 38: Train Acc=74.90%, Test Acc=75.71%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.84it/s]


Epoch 39: Train Acc=74.87%, Test Acc=75.43%


100%|██████████| 391/391 [03:12<00:00,  2.03it/s]
100%|██████████| 79/79 [00:13<00:00,  5.82it/s]


Epoch 40: Train Acc=75.52%, Test Acc=76.08%


 98%|█████████▊| 382/391 [03:08<00:04,  2.03it/s]