In [None]:
!pip install -q torch torchvision tqdm

In [1]:
# Imports
import math, time, os
from pathlib import Path
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, random_split

In [2]:
# Experiment configuration - every tunable value used by the later cells lives here
CFG = dict(
    seed=42,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    batch_size=128,
    epochs=50,              # increase for better accuracy
    lr=3e-4,
    weight_decay=0.05,
    patch_size=4,           # try 4, 8 (4 gives more patches on 32x32)
    embed_dim=256,          # model width
    depth=12,               # number of transformer blocks
    num_heads=8,
    mlp_ratio=4.0,          # MLP hidden dim = embed_dim * mlp_ratio
    dropout=0.1,
    num_workers=2,
    output_dir='/content/vit_ckpt',
)

# Create the checkpoint directory up front (no error if it already exists)
Path(CFG['output_dir']).mkdir(parents=True, exist_ok=True)

# Seed PyTorch's global random number generator with the configured seed
torch.manual_seed(CFG['seed'])

<torch._C.Generator at 0x79d1900534d0>

In [3]:
# Data pipeline: CIFAR-10 train/test datasets and their DataLoaders.
# Both splits are normalised with the same per-channel mean / std constants.
CIFAR10_MEAN = (0.4914,0.4822,0.4465)
CIFAR10_STD = (0.247,0.243,0.261)
normalize = transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)

# Training images are augmented (padded random crop, horizontal flip, colour jitter)
# before being converted to tensors and normalised.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.1,0.1,0.1,0.01),
    transforms.ToTensor(),
    normalize,
])

# Test images are only converted to tensors and normalised - no augmentation.
transform_test = transforms.Compose([transforms.ToTensor(), normalize])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train,
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test,
)

# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(
    batch_size=CFG['batch_size'],
    num_workers=CFG['num_workers'],
    pin_memory=True,
)
train_loader = DataLoader(trainset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(testset, shuffle=False, **loader_kwargs)


100%|██████████| 170M/170M [00:03<00:00, 49.5MB/s]


In [4]:
class PatchEmbed(nn.Module):
    """Cut an image batch into square patches and project each patch to a token.

    Patch extraction and the linear projection are fused into one ``nn.Conv2d``
    whose kernel size and stride both equal ``patch_size``.

    Args:
        img_size: Side length used to compute ``num_patches``; must be a
            multiple of ``patch_size``.
        patch_size: Side length of each patch.
        in_chans: Number of input channels.
        embed_dim: Size of the token produced for each patch.

    Attributes:
        num_patches: ``(img_size // patch_size) ** 2``.
        patch_size: The ``patch_size`` passed to the constructor.
        proj: The strided convolution that performs patchify + projection.
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=256):
        super().__init__()
        assert img_size % patch_size == 0, "img_size must be divisible by patch_size"
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2
        self.patch_size = patch_size
        # Non-overlapping windows: stride == kernel size, one output position per patch.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Turn ``(B, C, H, W)`` images into ``(B, num_patches, embed_dim)`` tokens."""
        feature_map = self.proj(x)  # (B, embed_dim, H/ps, W/ps)
        # Merge the two spatial axes into one sequence axis, then put channels last.
        return feature_map.flatten(2).transpose(1, 2)

In [5]:
class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Queries, keys and values come from one fused linear layer (``qkv``), are
    split into ``num_heads`` heads of width ``dim // num_heads``, and the
    per-head results are concatenated and passed through ``proj``.

    Args:
        dim: Token width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied both to the attention weights
            and to the projected output.
    """

    def __init__(self, dim, num_heads=8, dropout=0.0):
        super().__init__()
        head_dim, leftover = divmod(dim, num_heads)
        assert leftover == 0
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Scaling the logits by head_dim ** -0.5 before the softmax.
        self.scale = self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)
        self.attn_drop = nn.Dropout(dropout)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, N, dim)`` and return the same shape."""
        batch, tokens, width = x.shape

        # (B, N, 3*dim) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        queries, keys, values = packed.permute(2, 0, 3, 1, 4).unbind(0)

        # Pairwise similarity between tokens, per head: (B, heads, N, N)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(torch.softmax(scores, dim=-1))

        # Weighted sum of values: (B, heads, N, head_dim)
        context = torch.matmul(weights, values)
        # Bring heads next to head_dim and merge them back into the token width.
        merged = context.transpose(1, 2).reshape(batch, tokens, width)
        return self.proj_drop(self.proj(merged))

class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_features: Width of the input; the output has the same width.
        hidden_features: Width of the hidden layer. A falsy value (``None``
            or ``0``) falls back to ``in_features``.
        dropout: Probability of the single dropout module, which is applied
            after the activation and again after the second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, dropout=0.0):
        super().__init__()
        if not hidden_features:
            hidden_features = in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward network along the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


In [6]:
class TransformerEncoderBlock(nn.Module):
    """One encoder layer: self-attention then an MLP, each with a residual connection.

    LayerNorm is applied to the input of each sub-layer (``norm1`` before
    attention, ``norm2`` before the MLP); the un-normalised input is what gets
    added back.

    Args:
        dim: Token width.
        num_heads: Number of attention heads.
        mlp_ratio: The MLP hidden width is ``int(dim * mlp_ratio)``.
        dropout: Dropout probability passed to the MLP.
        attn_dropout: Dropout probability passed to the attention module.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.1, attn_dropout=0.1):
        super().__init__()
        mlp_hidden = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, num_heads=num_heads, dropout=attn_dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_hidden, dropout=dropout)

    def forward(self, x):
        """Run both residual sub-layers on ``x`` and return a tensor of the same shape."""
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))


In [7]:
class ViT(nn.Module):
    """Vision Transformer classifier that predicts from a learned class token.

    Pipeline: patch embedding -> prepend class token -> add positional
    embedding -> dropout -> ``depth`` encoder blocks -> LayerNorm -> linear
    head applied to the class-token position.

    Args:
        img_size: Input side length, forwarded to ``PatchEmbed``.
        patch_size: Patch side length, forwarded to ``PatchEmbed``.
        in_chans: Number of input channels.
        num_classes: Number of output logits.
        embed_dim: Token width used throughout the model.
        depth: Number of ``TransformerEncoderBlock`` layers.
        num_heads: Attention heads per block.
        mlp_ratio: Ratio of MLP hidden width to ``embed_dim`` in each block.
        dropout: Dropout probability used after the positional embedding and
            inside every block (for both its MLP and its attention).
    """

    def __init__(self, img_size=32, patch_size=4, in_chans=3, num_classes=10,
                 embed_dim=256, depth=8, num_heads=8, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        seq_len = self.patch_embed.num_patches + 1  # all patches plus the class token

        # Learned class token and positional embedding (one position per token)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=dropout)

        # Stack of identical encoder layers
        layers = [
            TransformerEncoderBlock(embed_dim, num_heads, mlp_ratio, dropout, attn_dropout=dropout)
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(layers)
        self.norm = nn.LayerNorm(embed_dim)

        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)

        # Initialisation: truncated normal for the free parameters, then the
        # per-module rule below for every submodule.
        for learned in (self.pos_embed, self.cls_token):
            nn.init.trunc_normal_(learned, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module init hook used with ``self.apply``.

        ``nn.Linear``: Xavier-uniform weight, zero bias (when a bias exists).
        ``nn.LayerNorm``: zero bias, unit weight. Other modules are left as-is.
        """
        if isinstance(m, nn.LayerNorm):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)
            return
        if not isinstance(m, nn.Linear):
            return
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x):
        """Classify a batch of images ``(B, C, H, W)``; returns logits ``(B, num_classes)``."""
        batch = x.size(0)
        tokens = self.patch_embed(x)  # (B, N, dim)

        # Prepend one copy of the class token per sample: (B, N+1, dim)
        cls = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1)
        tokens = self.pos_drop(tokens + self.pos_embed)

        for layer in self.blocks:
            tokens = layer(tokens)

        # Normalise, then classify from the class-token position (index 0).
        return self.head(self.norm(tokens)[:, 0])


In [8]:
# Build the ViT from the architecture entries of CFG and place it on the chosen device
device = torch.device(CFG['device'])
vit_kwargs = {
    key: CFG[key]
    for key in ('patch_size', 'embed_dim', 'depth', 'num_heads', 'mlp_ratio', 'dropout')
}
model = ViT(img_size=32, **vit_kwargs)
model = model.to(device)

# Optimizer & scheduler
# AdamW over every model parameter, using the learning rate and weight decay from CFG
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CFG['lr'],
    weight_decay=CFG['weight_decay'],
)

# cosine warmup scheduler
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    The returned scheduler multiplies each base learning rate by a factor that:
      * rises linearly from 0 to 1 over the first ``num_warmup_steps`` steps;
      * afterwards follows ``0.5 * (1 + cos(pi * 2 * num_cycles * progress))``,
        floored at 0, where ``progress`` is the fraction of the remaining
        ``num_training_steps - num_warmup_steps`` steps already taken.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        num_warmup_steps: Number of warmup steps.
        num_training_steps: Total number of scheduler steps planned.
        num_cycles: Number of cosine cycles over the decay phase; the default
            0.5 is a single half-wave from 1 down to 0.
        last_epoch: Forwarded to ``LambdaLR``.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` instance.
    """
    def _multiplier(step):
        if step < num_warmup_steps:
            # Linear ramp; max(1, ...) avoids dividing by zero.
            warmup_span = max(1, num_warmup_steps)
            return float(step) / float(warmup_span)
        decay_span = max(1, num_training_steps - num_warmup_steps)
        progress = float(step - num_warmup_steps) / float(decay_span)
        cosine = math.cos(math.pi * num_cycles * 2.0 * progress)
        return max(0.0, 0.5 * (1.0 + cosine))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, _multiplier, last_epoch=last_epoch)

# Scheduler horizon: number of batches per epoch times the number of epochs
steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * CFG['epochs']
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=500,
    num_training_steps=total_steps,
)

# Loss function used by the training and evaluation helpers
criterion = nn.CrossEntropyLoss()


In [9]:
def train_one_epoch(model, loader, optimizer, scheduler, device):
    """Run one optimisation pass over ``loader`` and report running metrics.

    Uses the notebook-level ``criterion`` as the loss function and shows a
    ``tqdm`` progress bar whose description carries the running loss/accuracy.

    Args:
        model: Module to train; switched to training mode.
        loader: Iterable of ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        scheduler: Optional LR scheduler stepped once per batch right after
            the optimizer; pass ``None`` to disable.
        device: Device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen, where the
        mean weights each batch loss by its batch size.
    """
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    pbar = tqdm(loader, leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        # Forward pass and loss
        logits = model(images)
        loss = criterion(logits, labels)

        # Backward pass and parameter / learning-rate update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        # Accumulate sample-weighted statistics for the running averages
        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        correct += (logits.argmax(dim=1) == labels).sum().item()
        total += batch_size
        pbar.set_description(f"loss: {running_loss/total:.4f} acc: {100*correct/total:.2f}%")

    return running_loss/total, 100*correct/total

def evaluate(model, loader, device):
    """Score ``model`` on ``loader`` without tracking gradients.

    Uses the notebook-level ``criterion`` as the loss function.

    Args:
        model: Module to evaluate; switched to eval mode (and left in it).
        loader: Iterable of ``(images, labels)`` batches.
        device: Device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples seen, where the
        mean weights each batch loss by its batch size.
    """
    model.eval()
    weighted_loss, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            batch_size = images.size(0)
            # criterion returns a batch mean, so re-weight by the batch size
            weighted_loss += criterion(logits, labels).item() * batch_size
            n_correct += (logits.argmax(dim=1) == labels).sum().item()
            n_seen += batch_size
    return weighted_loss/n_seen, 100*n_correct/n_seen


In [10]:
# Main loop: train for CFG['epochs'] epochs, score the model on test_loader after
# each one, and keep the weights with the highest accuracy seen so far.
# NOTE(review): the "val" metrics and best-checkpoint selection both use
# test_loader - confirm that selecting on this split is intended.
best_acc = 0.0
ckpt_path = os.path.join(CFG['output_dir'], 'best_vit.pth')
for epoch in range(1, CFG['epochs']+1):
    t0 = time.time()
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, scheduler, device)
    val_loss, val_acc = evaluate(model, test_loader, device)
    t1 = time.time()
    print(f"Epoch {epoch}/{CFG['epochs']}  train_loss={train_loss:.4f} train_acc={train_acc:.2f}%  val_loss={val_loss:.4f} val_acc={val_acc:.2f}%  time={t1-t0:.1f}s")
    # save best
    if val_acc > best_acc:
        best_acc = val_acc
        # Store the weights together with the config and epoch that produced them
        checkpoint = {'model_state': model.state_dict(), 'cfg': CFG, 'epoch': epoch}
        torch.save(checkpoint, ckpt_path)
print("Best val acc:", best_acc)


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 1/50  train_loss=1.9752 train_acc=27.07%  val_loss=1.8824 val_acc=31.79%  time=96.3s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 2/50  train_loss=1.6492 train_acc=38.79%  val_loss=1.4300 val_acc=47.19%  time=96.0s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 3/50  train_loss=1.4373 train_acc=47.19%  val_loss=1.3971 val_acc=49.14%  time=96.4s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 4/50  train_loss=1.3126 train_acc=52.16%  val_loss=1.2217 val_acc=55.98%  time=96.3s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 5/50  train_loss=1.2511 train_acc=54.73%  val_loss=1.1582 val_acc=57.65%  time=96.4s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 6/50  train_loss=1.1886 train_acc=56.92%  val_loss=1.0909 val_acc=60.96%  time=96.6s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 7/50  train_loss=1.1368 train_acc=58.76%  val_loss=1.0689 val_acc=61.73%  time=96.3s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 8/50  train_loss=1.0960 train_acc=60.41%  val_loss=0.9904 val_acc=64.01%  time=96.9s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 9/50  train_loss=1.0600 train_acc=61.72%  val_loss=1.0665 val_acc=62.16%  time=96.9s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 10/50  train_loss=1.0268 train_acc=62.84%  val_loss=0.9075 val_acc=67.17%  time=96.6s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 11/50  train_loss=0.9900 train_acc=64.43%  val_loss=0.9415 val_acc=66.07%  time=96.4s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 12/50  train_loss=0.9588 train_acc=65.72%  val_loss=0.9178 val_acc=67.43%  time=96.5s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 13/50  train_loss=0.9241 train_acc=66.88%  val_loss=0.8421 val_acc=69.82%  time=97.0s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 14/50  train_loss=0.8991 train_acc=67.84%  val_loss=0.8833 val_acc=68.54%  time=96.7s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 15/50  train_loss=0.8686 train_acc=68.89%  val_loss=0.8221 val_acc=71.08%  time=97.0s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 16/50  train_loss=0.8426 train_acc=69.84%  val_loss=0.7880 val_acc=71.68%  time=96.7s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 17/50  train_loss=0.8165 train_acc=71.00%  val_loss=0.7853 val_acc=72.49%  time=96.5s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 18/50  train_loss=0.7909 train_acc=71.83%  val_loss=0.7567 val_acc=73.40%  time=96.3s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 19/50  train_loss=0.7703 train_acc=72.59%  val_loss=0.7341 val_acc=74.38%  time=96.4s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 20/50  train_loss=0.7424 train_acc=73.27%  val_loss=0.7728 val_acc=73.27%  time=96.9s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 21/50  train_loss=0.7213 train_acc=74.51%  val_loss=0.6950 val_acc=75.27%  time=96.8s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 22/50  train_loss=0.6980 train_acc=75.17%  val_loss=0.6791 val_acc=75.92%  time=96.7s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 23/50  train_loss=0.6766 train_acc=75.93%  val_loss=0.6765 val_acc=76.12%  time=96.7s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 24/50  train_loss=0.6537 train_acc=76.89%  val_loss=0.6640 val_acc=77.22%  time=96.9s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 25/50  train_loss=0.6308 train_acc=77.59%  val_loss=0.6727 val_acc=76.57%  time=96.7s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 26/50  train_loss=0.6149 train_acc=78.04%  val_loss=0.6909 val_acc=76.29%  time=96.6s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 27/50  train_loss=0.5905 train_acc=78.87%  val_loss=0.6431 val_acc=77.76%  time=96.6s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 28/50  train_loss=0.5752 train_acc=79.36%  val_loss=0.6230 val_acc=78.65%  time=96.9s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 29/50  train_loss=0.5574 train_acc=80.14%  val_loss=0.6478 val_acc=77.87%  time=97.2s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 30/50  train_loss=0.5357 train_acc=80.78%  val_loss=0.6396 val_acc=78.30%  time=96.7s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 31/50  train_loss=0.5140 train_acc=81.73%  val_loss=0.6222 val_acc=78.76%  time=97.2s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 32/50  train_loss=0.4980 train_acc=82.12%  val_loss=0.6178 val_acc=79.03%  time=97.1s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 33/50  train_loss=0.4866 train_acc=82.53%  val_loss=0.6205 val_acc=79.55%  time=97.1s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 34/50  train_loss=0.4657 train_acc=83.28%  val_loss=0.6123 val_acc=79.87%  time=96.6s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 35/50  train_loss=0.4504 train_acc=84.15%  val_loss=0.6046 val_acc=80.12%  time=96.5s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 36/50  train_loss=0.4380 train_acc=84.38%  val_loss=0.6128 val_acc=79.79%  time=96.5s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 37/50  train_loss=0.4251 train_acc=84.80%  val_loss=0.6076 val_acc=80.40%  time=96.6s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 38/50  train_loss=0.4045 train_acc=85.49%  val_loss=0.5983 val_acc=80.52%  time=97.1s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 39/50  train_loss=0.3971 train_acc=85.79%  val_loss=0.6035 val_acc=80.64%  time=96.5s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 40/50  train_loss=0.3808 train_acc=86.35%  val_loss=0.6095 val_acc=80.72%  time=96.9s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 41/50  train_loss=0.3703 train_acc=86.57%  val_loss=0.5997 val_acc=81.16%  time=97.1s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 42/50  train_loss=0.3607 train_acc=87.08%  val_loss=0.5987 val_acc=81.13%  time=96.9s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 43/50  train_loss=0.3524 train_acc=87.31%  val_loss=0.6035 val_acc=81.25%  time=96.7s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 44/50  train_loss=0.3462 train_acc=87.43%  val_loss=0.6035 val_acc=81.41%  time=96.5s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 45/50  train_loss=0.3431 train_acc=87.62%  val_loss=0.6011 val_acc=81.15%  time=96.9s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 46/50  train_loss=0.3358 train_acc=87.77%  val_loss=0.6073 val_acc=81.32%  time=96.5s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 47/50  train_loss=0.3341 train_acc=87.95%  val_loss=0.6048 val_acc=81.41%  time=96.4s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 48/50  train_loss=0.3340 train_acc=87.99%  val_loss=0.6044 val_acc=81.34%  time=96.5s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 49/50  train_loss=0.3270 train_acc=88.25%  val_loss=0.6050 val_acc=81.39%  time=97.0s


  0%|          | 0/391 [00:00<?, ?it/s]

Epoch 50/50  train_loss=0.3317 train_acc=87.93%  val_loss=0.6051 val_acc=81.37%  time=96.6s
Best val acc: 81.41
