In [1]:
# ============================================
# Tiny Vision Transformer (efficient variant)
# Train on CIFAR-10 or CIFAR-100
# ============================================
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision, torchvision.transforms as T
from tqdm import tqdm

# ---------------------
# Model definition
# ---------------------
class ConvStem(nn.Module):
    def __init__(self, in_ch=3, out_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_dim//2, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_dim//2),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_dim//2, out_dim, 3, 2, 1, bias=False),
            nn.BatchNorm2d(out_dim),
            nn.ReLU(inplace=True),
        )
    def forward(self, x): return self.net(x)

class PatchEmbed(nn.Module):
    def __init__(self, in_ch, embed_dim, patch_size=1):
        super().__init__()
        self.proj = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)
    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x

class MLP(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        return self.drop(self.fc2(self.act(self.fc1(x))))

class EfficientAttention(nn.Module):
    def __init__(self, dim, heads=4, attn_drop=0.):
        super().__init__()
        self.heads = heads
        self.dim = dim
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(attn_drop)
    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.heads, self.head_dim).permute(2,0,3,1,4)
        q, k, v = qkv
        attn = (q @ k.transpose(-2,-1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1,2).reshape(B, N, C)
        return self.proj(x)

class TransformerBlock(nn.Module):
    def __init__(self, dim, heads, mlp_ratio=2., drop=0., attn_drop=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = EfficientAttention(dim, heads, attn_drop)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, int(dim*mlp_ratio), drop)
    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

class TinyViT(nn.Module):
    def __init__(self, img_size=32, num_classes=10, embed_dim=128,
                 depth=8, heads=4, mlp_ratio=2.0, conv_stem=True):
        super().__init__()
        self.conv_stem = ConvStem(3, embed_dim) if conv_stem else None
        self.patch = PatchEmbed(embed_dim if conv_stem else 3, embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim))
        self.pos_embed = None
        self.blocks = nn.ModuleList([TransformerBlock(embed_dim, heads, mlp_ratio) for _ in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        if self.conv_stem: x = self.conv_stem(x)
        x = self.patch(x)
        B, N, C = x.shape
        cls_token = self.cls_token.expand(B,-1,-1)
        if (self.pos_embed is None) or (self.pos_embed.shape[1] != N+1):
            self.pos_embed = nn.Parameter(torch.zeros(1,N+1,C,device=x.device))
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        x = torch.cat([cls_token,x],dim=1)
        x = x + self.pos_embed
        for blk in self.blocks: x = blk(x)
        x = self.norm(x)
        return self.head(x[:,0])

# ---------------------
# Data
# ---------------------
def get_dataloaders(dataset='cifar10', bs=128):
    t_train = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
    ])
    t_test = T.Compose([
        T.ToTensor(),
        T.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
    ])
    if dataset == 'cifar10':
        train = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=t_train)
        test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=t_test)
        num_classes = 10
    else:
        train = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=t_train)
        test = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=t_test)
        num_classes = 100
    return DataLoader(train, batch_size=bs, shuffle=True, num_workers=2), \
           DataLoader(test, batch_size=bs, shuffle=False, num_workers=2), num_classes

# ---------------------
# Training
# ---------------------
def train_one_epoch(model, loader, opt, device, loss_fn):
    model.train()
    total_loss = 0
    for x,y in tqdm(loader, leave=False):
        x,y = x.to(device), y.to(device)
        opt.zero_grad()
        out = model(x)
        loss = loss_fn(out,y)
        loss.backward()
        opt.step()
        total_loss += loss.item()
    return total_loss/len(loader)

def evaluate(model, loader, device):
    model.eval()
    correct,total = 0,0
    with torch.no_grad():
        for x,y in loader:
            x,y = x.to(device), y.to(device)
            pred = model(x).argmax(1)
            correct += (pred==y).sum().item()
            total += y.size(0)
    return correct/total

# ---------------------
# Main training loop
# ---------------------
dataset = "cifar10"  # change to "cifar100" for CIFAR-100
epochs = 100          # increase to 150-300 for better performance
bs = 128
lr = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

train_loader, test_loader, num_classes = get_dataloaders(dataset, bs)

model = TinyViT(embed_dim=128, depth=8, heads=4, mlp_ratio=2.0, num_classes=num_classes).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.05)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
loss_fn = nn.CrossEntropyLoss()

print(f"Model params: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

best_acc = 0
for ep in range(epochs):
    loss = train_one_epoch(model, train_loader, opt, device, loss_fn)
    acc = evaluate(model, test_loader, device)
    sched.step()
    print(f"Epoch {ep+1:03d}/{epochs} | loss {loss:.4f} | val_acc {acc*100:.2f}%")
    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), f"best_tinyvit_{dataset}.pth")

print(f"✅ Training done. Best val accuracy: {best_acc*100:.2f}%")


100%|██████████| 170M/170M [00:09<00:00, 18.4MB/s]


Model params: 1.15M


                                                 

Epoch 001/100 | loss 1.6157 | val_acc 50.32%


                                                 

Epoch 002/100 | loss 1.2563 | val_acc 56.92%


                                                 

Epoch 003/100 | loss 1.1322 | val_acc 61.52%


                                                 

Epoch 004/100 | loss 1.0444 | val_acc 63.06%


                                                 

Epoch 005/100 | loss 0.9751 | val_acc 64.99%


                                                 

Epoch 006/100 | loss 0.9170 | val_acc 65.24%


                                                 

Epoch 007/100 | loss 0.8750 | val_acc 66.79%


                                                 

Epoch 008/100 | loss 0.8227 | val_acc 69.61%


                                                 

Epoch 009/100 | loss 0.7862 | val_acc 71.67%


                                                 

Epoch 010/100 | loss 0.7553 | val_acc 70.11%


                                                 

Epoch 011/100 | loss 0.7223 | val_acc 72.34%


                                                 

Epoch 012/100 | loss 0.6936 | val_acc 72.28%


                                                 

Epoch 013/100 | loss 0.6656 | val_acc 73.73%


                                                 

Epoch 014/100 | loss 0.6394 | val_acc 74.10%


                                                 

Epoch 015/100 | loss 0.6115 | val_acc 73.58%


                                                 

Epoch 016/100 | loss 0.5854 | val_acc 74.82%


                                                 

Epoch 017/100 | loss 0.5617 | val_acc 75.10%


                                                 

Epoch 018/100 | loss 0.5402 | val_acc 74.81%


                                                 

Epoch 019/100 | loss 0.5131 | val_acc 75.04%


                                                 

Epoch 020/100 | loss 0.4850 | val_acc 76.81%


                                                 

Epoch 021/100 | loss 0.4697 | val_acc 77.47%


                                                 

Epoch 022/100 | loss 0.4383 | val_acc 76.37%


                                                 

Epoch 023/100 | loss 0.4173 | val_acc 75.49%


                                                 

Epoch 024/100 | loss 0.4057 | val_acc 77.03%


                                                 

Epoch 025/100 | loss 0.3754 | val_acc 77.06%


                                                 

Epoch 026/100 | loss 0.3592 | val_acc 76.89%


                                                 

Epoch 027/100 | loss 0.3436 | val_acc 76.57%


                                                 

Epoch 028/100 | loss 0.3199 | val_acc 77.56%


                                                 

Epoch 029/100 | loss 0.3030 | val_acc 77.24%


                                                 

Epoch 030/100 | loss 0.2848 | val_acc 76.68%


                                                 

Epoch 031/100 | loss 0.2721 | val_acc 77.07%


                                                 

Epoch 032/100 | loss 0.2549 | val_acc 77.72%


                                                 

Epoch 033/100 | loss 0.2393 | val_acc 77.20%


                                                 

Epoch 034/100 | loss 0.2168 | val_acc 77.28%


                                                 

Epoch 035/100 | loss 0.2159 | val_acc 78.05%


                                                 

Epoch 036/100 | loss 0.1943 | val_acc 77.25%


                                                 

Epoch 037/100 | loss 0.1854 | val_acc 78.18%


                                                 

Epoch 038/100 | loss 0.1801 | val_acc 77.64%


                                                 

Epoch 039/100 | loss 0.1612 | val_acc 77.78%


                                                 

Epoch 040/100 | loss 0.1597 | val_acc 76.92%


                                                 

Epoch 041/100 | loss 0.1401 | val_acc 78.06%


                                                 

Epoch 042/100 | loss 0.1353 | val_acc 78.13%


                                                 

Epoch 043/100 | loss 0.1259 | val_acc 77.65%


                                                 

Epoch 044/100 | loss 0.1198 | val_acc 78.11%


                                                 

Epoch 045/100 | loss 0.1104 | val_acc 77.63%


                                                 

Epoch 046/100 | loss 0.1148 | val_acc 78.03%


                                                 

Epoch 047/100 | loss 0.0987 | val_acc 78.16%


                                                 

Epoch 048/100 | loss 0.0951 | val_acc 77.75%


                                                 

Epoch 049/100 | loss 0.0889 | val_acc 77.69%


                                                 

Epoch 050/100 | loss 0.0841 | val_acc 77.84%


                                                 

Epoch 051/100 | loss 0.0746 | val_acc 78.23%


                                                 

Epoch 052/100 | loss 0.0732 | val_acc 78.20%


                                                 

Epoch 053/100 | loss 0.0748 | val_acc 77.76%


                                                 

Epoch 054/100 | loss 0.0661 | val_acc 77.74%


                                                 

Epoch 055/100 | loss 0.0638 | val_acc 78.08%


                                                 

Epoch 056/100 | loss 0.0570 | val_acc 78.32%


                                                 

Epoch 057/100 | loss 0.0519 | val_acc 77.64%


                                                 

Epoch 058/100 | loss 0.0507 | val_acc 78.08%


                                                 

Epoch 059/100 | loss 0.0471 | val_acc 78.44%


                                                 

Epoch 060/100 | loss 0.0429 | val_acc 78.23%


                                                 

Epoch 061/100 | loss 0.0413 | val_acc 78.66%


                                                 

Epoch 062/100 | loss 0.0411 | val_acc 78.33%


                                                 

Epoch 063/100 | loss 0.0375 | val_acc 78.34%


                                                 

Epoch 064/100 | loss 0.0322 | val_acc 78.32%


                                                 

Epoch 065/100 | loss 0.0269 | val_acc 78.69%


                                                 

Epoch 066/100 | loss 0.0309 | val_acc 78.67%


                                                 

Epoch 067/100 | loss 0.0261 | val_acc 78.66%


                                                 

Epoch 068/100 | loss 0.0256 | val_acc 78.29%


                                                 

Epoch 069/100 | loss 0.0210 | val_acc 79.07%


                                                 

Epoch 070/100 | loss 0.0197 | val_acc 78.85%


                                                 

Epoch 071/100 | loss 0.0193 | val_acc 78.38%


                                                 

Epoch 072/100 | loss 0.0179 | val_acc 78.53%


                                                 

Epoch 073/100 | loss 0.0161 | val_acc 78.79%


                                                 

Epoch 074/100 | loss 0.0131 | val_acc 78.61%


                                                 

Epoch 075/100 | loss 0.0116 | val_acc 78.62%


                                                 

Epoch 076/100 | loss 0.0122 | val_acc 78.91%


                                                 

Epoch 077/100 | loss 0.0120 | val_acc 79.09%


                                                 

Epoch 078/100 | loss 0.0099 | val_acc 79.14%


                                                 

Epoch 079/100 | loss 0.0072 | val_acc 78.93%


                                                 

Epoch 080/100 | loss 0.0078 | val_acc 79.02%


                                                 

Epoch 081/100 | loss 0.0073 | val_acc 79.21%


                                                 

Epoch 082/100 | loss 0.0068 | val_acc 79.26%


                                                 

Epoch 083/100 | loss 0.0062 | val_acc 79.04%


                                                 

Epoch 084/100 | loss 0.0051 | val_acc 79.17%


                                                 

Epoch 085/100 | loss 0.0044 | val_acc 79.37%


                                                 

Epoch 086/100 | loss 0.0038 | val_acc 79.44%


                                                 

Epoch 087/100 | loss 0.0034 | val_acc 79.40%


                                                 

Epoch 088/100 | loss 0.0032 | val_acc 79.36%


                                                 

Epoch 089/100 | loss 0.0040 | val_acc 79.43%


                                                 

Epoch 090/100 | loss 0.0032 | val_acc 79.29%


                                                 

Epoch 091/100 | loss 0.0027 | val_acc 79.30%


                                                 

Epoch 092/100 | loss 0.0026 | val_acc 79.17%


                                                 

Epoch 093/100 | loss 0.0026 | val_acc 79.37%


                                                 

Epoch 094/100 | loss 0.0025 | val_acc 79.38%


                                                 

Epoch 095/100 | loss 0.0020 | val_acc 79.27%


                                                 

Epoch 096/100 | loss 0.0019 | val_acc 79.41%


                                                 

Epoch 097/100 | loss 0.0019 | val_acc 79.48%


                                                 

Epoch 098/100 | loss 0.0019 | val_acc 79.45%


                                                 

Epoch 099/100 | loss 0.0016 | val_acc 79.39%


                                                 

Epoch 100/100 | loss 0.0017 | val_acc 79.39%
✅ Training done. Best val accuracy: 79.48%
