# An implementation of the Vision Transformer (ViT).
## References - An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale, Alexey Dosovitskiy et al.

In [299]:
import torch 
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torchvision.datasets import CIFAR10
import torch.nn.functional as F
import torchvision.transforms as transforms

In [300]:
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [301]:
# Convert PIL images to float tensors in [0, 1], then shift/scale every RGB
# channel with mean 0.5 and std 0.5 so pixel values land in [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)

In [302]:
# Download (if needed) and load both CIFAR-10 splits with identical preprocessing.
# The train split is built first, then the test split, as before.
train_dataset, test_dataset = (
    CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [303]:
# Mini-batches of 32; only the training split is reshuffled each epoch so the
# evaluation order stays fixed.
train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=32, shuffle=True),
    DataLoader(test_dataset, batch_size=32, shuffle=False),
)

In [304]:
# The dataset exposes its label names as `.classes`; their count sizes the
# classifier head.
num_classes = len(train_dataset.classes)
print(num_classes)

10


In [305]:
# Inspect the first training sample; dataset items are (image_tensor, label) pairs.
image = train_dataset[0]
shape = image[0].shape
C, H, W = shape  # channels, height, width of the transformed image tensor
P = 4  # patch side length in pixels
# A Conv2d with kernel=stride=P produces floor(H/P) * floor(W/P) patches, so
# compute the count per axis with integer division instead of the float
# H * W / P**2 (which is fractional / wrong when H or W is not divisible by P).
block_size = (H // P) * (W // P)
print(block_size)
print(C, H, W)

64.0
3 32 32


In [306]:
"""
Total Model Parameters: 33.67M
"""
n_heads = 12
block_size = H * W//(P**2)
batch_size = 32
n_embd = 252
max_iters = 30
patch_size = 4
dropout = 0.1

In [307]:
# Peek at a single mini-batch to confirm the tensor shapes the model will see,
# then stop iterating.
for images, labels in train_loader:
    print(images.shape, labels.shape, sep="\n")
    break

torch.Size([32, 3, 32, 32])
torch.Size([32])


In [308]:
class AttentionHead(nn.Module):
    """A single scaled dot-product self-attention head.

    Projects the input into query, key and value spaces of width
    ``head_size`` and returns ``dropout(softmax(q @ k^T / sqrt(head_size))) @ v``.
    No mask is applied, so every token attends to every token.

    Args:
        head_size: Width of the query/key/value projections and of the output.
        embd_dim: Width of the incoming token embeddings. When omitted, the
            notebook-level ``n_embd`` hyperparameter is used (original behavior).
        dropout_p: Dropout probability applied to the attention weights. When
            omitted, the notebook-level ``dropout`` hyperparameter is used.
    """

    def __init__(self, head_size, embd_dim=None, dropout_p=None):
        super(AttentionHead, self).__init__()
        if embd_dim is None:
            embd_dim = n_embd
        if dropout_p is None:
            dropout_p = dropout
        self.head_size = head_size
        # 1/sqrt(d_k) as a plain Python float, computed once; avoids allocating
        # a fresh CPU tensor and taking its sqrt on every forward pass.
        self.scale = head_size ** -0.5
        self.w_q = nn.Linear(embd_dim, head_size)
        self.w_k = nn.Linear(embd_dim, head_size)
        self.w_v = nn.Linear(embd_dim, head_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, x):
        """Attend over the second-to-last axis of ``x`` (the token axis).

        Expects ``x`` with last dimension ``embd_dim``; returns a tensor of the
        same leading shape with last dimension ``head_size``.
        """
        q = self.w_q(x)
        k = self.w_k(x)
        v = self.w_v(x)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        attn_w = F.softmax(scores, dim=-1)
        attn_w = self.dropout(attn_w)
        return attn_w @ v

In [309]:
class MultiHeadAttention(nn.Module):
    """
    Apply Self-Attention using multiple heads over the input x.

    Runs ``num_heads`` independent ``AttentionHead`` modules on the same input,
    concatenates their outputs along the last dimension (total width
    ``num_heads * head_size``), projects that back to the notebook-level
    ``n_embd`` and applies dropout (notebook-level ``dropout``).

    Args:
        num_heads: Number of parallel attention heads.
        head_size: Output width of each head.
    """
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([AttentionHead(head_size) for _ in range(num_heads)])
        # The concatenated head outputs are num_heads * head_size wide. That
        # equals n_embd only when n_embd is divisible by num_heads, so size the
        # projection from the real concat width instead of assuming n_embd.
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return the projected, concatenated outputs of all heads for ``x``."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.dropout(self.proj(out))

In [310]:
class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder layer.

    Computes ``x + MHSA(LN1(x))`` followed by ``y + MLP(LN2(y))``, where the
    MLP is Linear -> GELU -> Linear -> Dropout with a hidden width of
    ``n_embd * mlp_ratio``.

    Args:
        n_embd: Token embedding width.
        n_head: Number of attention heads; each head gets ``n_embd // n_head``
            dimensions. Note the heads size their *input* from the notebook-level
            ``n_embd`` global, so the value passed here should match it.
        mlp_ratio: Expansion factor of the MLP hidden layer.
    """

    def __init__(self, n_embd, n_head, mlp_ratio=2):
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.mlp = nn.Sequential(
            nn.Linear(n_embd, mlp_ratio * n_embd),
            nn.GELU(),
            nn.Linear(mlp_ratio * n_embd, n_embd),
            nn.Dropout(dropout),
        )
        self.ln1, self.ln2 = nn.LayerNorm(n_embd), nn.LayerNorm(n_embd)

    def forward(self, x):
        """Apply the attention sub-layer, then the MLP sub-layer, each with a residual."""
        attended = x + self.sa(self.ln1(x))
        return attended + self.mlp(self.ln2(attended))

In [311]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: non-overlapping ``patch_size`` x ``patch_size`` patch embedding
    (a Conv2d whose kernel equals its stride) -> prepend a learnable [CLS]
    token -> add learnable positional embeddings -> ``n_layer`` encoder blocks
    and a LayerNorm -> another LayerNorm -> linear head on the [CLS] token.

    Args:
        n_embd: Token embedding width.
        n_head: Number of attention heads per encoder block.
        block_size: Number of patches per image, i.e.
            ``(H // patch_size) * (W // patch_size)``; the positional table has
            ``block_size + 1`` rows to make room for the [CLS] token.
        n_class: Number of output classes.
        channel_size: Number of input image channels.
        patch_size: Side length of each square patch.
        n_layer: Number of encoder blocks. Defaults to 6, the previously
            hard-coded depth, so existing ``state_dict`` keys are unchanged.
    """
    def __init__(self, n_embd, n_head, block_size, n_class, channel_size, patch_size, n_layer=6):
        super(ViT, self).__init__()
        self.P = patch_size
        self.proj = nn.Conv2d(channel_size, n_embd, kernel_size=self.P, stride=self.P)
        self.pos_embd = nn.Parameter(torch.zeros(1, block_size + 1, n_embd))

        self.cls_token = nn.Parameter(torch.zeros(1, 1, n_embd))

        # n_layer encoder blocks followed by a LayerNorm; the indices
        # (enc_layers.0 .. enc_layers.{n_layer}) match the original layout.
        self.enc_layers = nn.Sequential(
            *[EncoderBlock(n_embd, n_head) for _ in range(n_layer)],
            nn.LayerNorm(n_embd),
        )
        self.mlp = nn.Linear(n_embd, n_class)
        # NOTE: enc_layers already ends in a LayerNorm, so this second norm is
        # redundant; it is kept so previously saved checkpoints (norm.*) load.
        self.norm = nn.LayerNorm(n_embd)
        nn.init.trunc_normal_(self.pos_embd, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        """Return class logits of shape (batch, n_class) for an image batch.

        Raises:
            ValueError: if the image produces a different number of patches
                than the positional-embedding table was built for.
        """
        B = x.shape[0]
        x = self.proj(x)                   # (B, n_embd, H // P, W // P)
        x = x.flatten(2).transpose(1, 2)   # (B, num_patches, n_embd)

        num_patches = x.shape[1]
        expected = self.pos_embd.shape[1] - 1
        if num_patches != expected:
            # Without this check a mismatch either errors obscurely in the
            # addition below or (for a single-token input) silently broadcasts.
            raise ValueError(
                f"input yields {num_patches} patches but the model was built "
                f"for block_size={expected}"
            )

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)  # prepend [CLS] at position 0
        x = x + self.pos_embd

        x = self.enc_layers(x)
        x = self.norm(x)
        logits = self.mlp(x[:, 0, :])  # classify from the [CLS] token only
        return logits

In [312]:
# Build the classifier from the hyperparameters above and move it to `device`.
model = ViT(
    n_embd,
    n_heads,
    block_size,
    num_classes,
    channel_size=C,
    patch_size=patch_size,
)
m = model.to(device)  # nn.Module.to moves the module in place and returns it

In [313]:
# Count only parameters the optimizer will actually update.
total_params = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
print(f"Total Trainable Parameters: {total_params:,}")

Total Trainable Parameters: 3,097,342


In [315]:
criterion = nn.CrossEntropyLoss(reduction='mean')  # batch-averaged loss on raw logits

In [316]:
# AdamW: Adam with decoupled weight decay.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=0.01)

In [317]:
# Train for `max_iters` epochs, reporting per-sample mean loss and accuracy.
log_every = 1  # print a progress line every `log_every` epochs (and on the last)

for epoch in range(max_iters):
    model.train()  # enable dropout for training
    total_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        batch_n = labels.size(0)
        # criterion returns the batch mean; weight it by the batch size so the
        # epoch average is exact even when the final batch is smaller.
        total_loss += loss.item() * batch_n
        preds = logits.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += batch_n

    train_loss = total_loss / total
    train_acc = correct / total
    if (epoch + 1) % log_every == 0 or (epoch + 1) == max_iters:
        print(f"Epoch {epoch+1}/{max_iters} | Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")

Epoch 1/30 | Loss: 1.6620 | Train Acc: 0.3841
Epoch 2/30 | Loss: 1.2975 | Train Acc: 0.5269


KeyboardInterrupt: 

In [None]:
# Measure top-1 accuracy on the held-out split with dropout disabled and
# gradient tracking turned off.
model.eval()
correct = total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        logits = model(images)
        hits = logits.argmax(dim=1) == labels
        correct += hits.sum().item()
        total += len(labels)

test_acc = correct / total
print(f"Test Accuracy: {test_acc:.4f}")

In [None]:
# Persist only the learned tensors (not the module object) so the weights can
# be reloaded into a freshly constructed ViT.
torch.save(obj=model.state_dict(), f="vit_small_weights.pth")

In [None]:
"""

"""