In [None]:
# Idea of ViT:
'''


1. Patch Embedding - this includes
    1.1. Conv2d (passing each patch as 16x16 pixels) -> reshaping the patches (flatten)
    1.2. Adding cls token [This is added to the no. of patches + cls token]
            (Learnable parameter for the class i.e., like trying to put the image in a multi-dimentional map so it would be easier to classify)
    1.2. adding positional embedding (This patch belongs to which/what position?)
            (Learnable parameter where the positional embedding dimension = same as patch embedding dimension D)
    Note: Each patch has its own positional embedding and 1 cls token is added to the front of the entire embedding along with a positional embedding
2. Transformer Encoder
    2.1. Multi-Head Self Attention (MHA)
        2.1.1. converting to Q, K, V (Query, Key, Value)
        2.1.2. Scaled Dot production Attention
                2.1.2.1. Q,K -> MatMul -> Scale -> Mask -> SoftMax -> Out
                2.1.2.2. Out, V -> MatMul ->FinOut
        2.1.3. Concatenate
        2.1.4. Linear Layer            
    2.2. Transformer Block
        2.2.1. Layer Norm
        2.2.2 MLP
            2.2.2.1. FC
            2.2.2.2. GeLu
            2.2.2.3. FC
        2.2.3. Layer Norm
3. MLP head
    3.1. Take the cls token from the last output -> Linear layer (-> GeLu -> Dropout -> Linear )

'''

In [12]:
import math
from torch import nn, optim
import torch
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torch.nn import functional as f
import matplotlib.pyplot as plt

import time

In [13]:
# Global Variables
# I am intially setting these for MNIST Dataset
IMG_SIZE = 28
PATCH_SIZE = 4
IN_CHANNELS = 1
NUM_CLASSES = 10
EMBED_DIM = 64
TRANSFORMER_DEPTH = 6
NUM_HEADS = 8
MLP_RATIO = 4.0
DROPOUT = 0.1
EPOCHS = 10
LR = 1e-3
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [14]:
# 1. Patch Embedding
class PatchEmbedding(nn.Module):
    def __init__(self, image_size=28, patch_size=4, in_channels=1, embed_dim=64 ):
        super().__init__()
        assert image_size % patch_size==0, "Image size must be divisible by patch size"
        self.num_patches = (image_size//patch_size)**2

        # using conv2d network to convert the images to patches with kernel_size=patch_size, stirde=patch_size for non-overlapping patches
        self.proj = nn.Conv2d(in_channels=in_channels, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size)
    
    def forward(self, x):
        x = self.proj(x) # 1.1.
        x = x.flatten(2) # 1.1.

        # To transform (batch, embed_dim, num_patches) ->  (batch, num_patches, embed_dim)
        x = x.transpose(1,2) 
        return x

In [15]:
# 2. Transformer Encoder
# 2.1. MultiHead self-ATTENTION
class MHSA(nn.Module):
    def __init__(self, embed_dim, num_heads, qkv_bias=True, attention_drop=0.0, proj_drop=0.0):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        head_dim = embed_dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim, embed_dim*3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attention_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)
    
    def forward(self, x, return_attention = False):
        B, N, C = x.shape # Batch_size, N_tokens, Embedding_dim
        qkv = self.qkv(x)
        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
        qkv = qkv.permute(2, 0, 3, 1, 4) # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # 2.1.2.1
        attn = (q @ k.transpose(-2,-1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # 2.1.2.2
        out = (attn @ v)
        out = out.transpose(1,2).reshape(B, N, C)
        out = self.proj(out)
        out = self.proj_drop(out)
        if return_attention:
            return out, attn
        return out

In [16]:
# 2.2. Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, qkv_bias=True, p=0.0, attn_p=0.0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim, eps=1e-6)
        self.attn = MHSA(embed_dim, num_heads, qkv_bias, attention_drop=attn_p, proj_drop=p)
        self.norm2 = nn.LayerNorm(embed_dim, eps=1e-6)

        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(p),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(p)
        )
    
    def forward(self, x, return_attention=False):
        if return_attention:
            attn_out, attn = self.attn(self.norm1(x), return_attention)
            x = x + attn_out
            x = x + self.mlp(self.norm2(x))
            return x, attn
        else:
            x = x + self.attn(self.norm1(x))
            x = x + self.mlp(self.norm2(x))
            return x


In [36]:
# ViT
class ViT(nn.Module):
    def __init__(self, img_size=28, patch_size=4, in_channels=1, num_classes=10, embed_dim=64, depth=6, num_heads=8, mlp_ratio=4.0, p=0.0):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # CLS token + Positional Embeddings
        self.cls_token = nn.Parameter(torch.zeros(1,1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches+1, embed_dim))
        self.pos_drop = nn.Dropout(p)

        # Transformer block
        self.blocks = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, p=p, attn_p=p) for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

        # Classification head
        self.head = nn.Linear(embed_dim, num_classes)

        self._init_weights()

    def _init_weights(self):
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=0.02)

        # Linear weights
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    
    def forward(self, x, return_all_attention=False):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        attentions = []
        if return_all_attention:
            for blk in self.blocks:
                x, attn = blk(x, return_attention=True)
                attentions.append(attn)
        else:
            for blk in self.blocks:
                x = blk(x)
        x = self.norm(x)
        cls_final = x[:, 0]
        logits = self.head(cls_final)
        if return_all_attention:
            return logits, attentions
        return logits

In [37]:
def get_dataLoaders(batch_size=64):
    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    # Downloading datasets
    train_dataset = datasets.MNIST(root="..", train=True, transform=transform, download=True)
    test_dataset = datasets.MNIST(root="..", train=False, transform=transform, download=True)

    # Dataloader for NN
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
    val_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

    return train_loader, val_loader

In [38]:
# Train & eval
def train_one_epoch(model, loader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        # print(outputs.shape, labels.shape)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * imgs.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += imgs.size(0)
    return running_loss / total, correct / total

def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * imgs.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += imgs.size(0)
    return running_loss / total, correct / total

In [None]:
def main():
    train_loader, test_loader = get_dataLoaders()
    model = ViT(img_size=IMG_SIZE, patch_size=PATCH_SIZE, in_channels=IN_CHANNELS,
                num_classes=NUM_CLASSES, embed_dim=EMBED_DIM, depth=TRANSFORMER_DEPTH, num_heads=NUM_HEADS, mlp_ratio=MLP_RATIO, p=DROPOUT).to(DEVICE)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=LR, weight_decay=0.05)

    best_acc = 0.0
    for epoch in range(1, EPOCHS + 1):
        t0 = time.time()
        train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE)
        val_loss, val_acc = evaluate(model, test_loader, criterion, DEVICE)
        t1 = time.time()
        print(f"Epoch {epoch:02d}  time {(t1-t0):.1f}s  train_loss {train_loss:.4f} train_acc {train_acc:.4f}  val_loss {val_loss:.4f} val_acc {val_acc:.4f}")
        if val_acc > best_acc:
            best_acc = val_acc
            # torch.save(model.state_dict(), "vit_mnist_best.pth")
        break
    print("Best val acc:", best_acc)
    
if __name__ == '__main__':
    main()

Epoch 01  time 18.9s  train_loss 0.7376 train_acc 0.7454  val_loss 0.3529 val_acc 0.8857
Best val acc: 0.8857
