<a href="https://colab.research.google.com/github/SumayT9/transformer_implementations/blob/main/ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive so that
# later cells can read from and write to it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Switch the working directory to the project folder on the mounted drive.
# The path is relative, so this only works from the directory that contains
# `drive/` (presumably /content on Colab - confirm if run elsewhere).
import os
os.chdir("drive/MyDrive/projects")

# Vision Transformer on MNIST

In [None]:
import numpy as np

from tqdm import tqdm, trange

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader

from torchvision.transforms import ToTensor
from torchvision.datasets.mnist import MNIST

In [None]:
# Build the train split first, then the test split. Both are stored under
# ./data (downloaded on first use) and both convert images with ToTensor().
mnist_trainset, mnist_testset = (
    MNIST(root='./data', train=is_train, download=True, transform=ToTensor())
    for is_train in (True, False)
)

In [None]:
# Peek at the first training sample. The trailing bare expression makes the
# notebook display the image tensor's shape.
img, label = mnist_trainset[0]
img.shape

torch.Size([1, 28, 28])

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values each get their own linear projection; the
    per-head results are concatenated and returned as-is (there is no
    output projection layer in this implementation).

    Args:
        embed_dim: Width of the input and output embeddings.
        n_heads: Number of attention heads; must divide ``embed_dim``.
    """

    def __init__(self, embed_dim, n_heads):
        super().__init__()
        assert embed_dim % n_heads == 0
        self.embed_dim, self.n_heads = embed_dim, n_heads
        self.head_dim = embed_dim // n_heads
        # Projections are created in q, k, v order so that seeded
        # initialisation stays reproducible.
        self.w_q = nn.Linear(embed_dim, embed_dim)
        self.w_k = nn.Linear(embed_dim, embed_dim)
        self.w_v = nn.Linear(embed_dim, embed_dim)

    def expand_mask(self, mask):
        """Return ``mask`` with at least four dimensions.

        A 2-D mask gains two leading singleton axes, a 3-D mask gains a
        singleton axis at position 1, and anything with four or more
        dimensions is returned unchanged.
        """
        assert mask.ndim >= 2, "Mask must be at least 2-dimensional with seq_length x seq_length"
        if mask.ndim == 3:
            return mask.unsqueeze(1)
        if mask.ndim == 2:
            return mask.unsqueeze(0).unsqueeze(0)
        return mask

    def split(self, vec):
        """Reshape ``(b, seq_len, embed_dim)`` into ``(b, n_heads, seq_len, head_dim)``."""
        batch, length, _ = vec.shape
        per_head = vec.reshape(batch, length, self.n_heads, self.head_dim)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, q=None, k=None, v=None, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Args:
            q, k, v: Tensors of shape ``(b, seq_len, embed_dim)``.
            mask: Optional mask; score positions where it equals 0 are
                filled with ``-1e9`` before the softmax.

        Returns:
            Tensor of shape ``(b, q_seq_len, embed_dim)``.
        """
        queries = self.split(self.w_q(q))
        keys = self.split(self.w_k(k))
        values = self.split(self.w_v(v))

        batch, _, q_len, head_dim = queries.shape
        scores = (queries @ keys.transpose(-2, -1)) / np.sqrt(head_dim)
        if mask is not None:
            scores = scores.masked_fill(self.expand_mask(mask) == 0, -1e9)

        weights = torch.softmax(scores, dim=-1)
        # (b, n_heads, q_len, head_dim) -> (b, q_len, n_heads, head_dim)
        context = (weights @ values).permute(0, 2, 1, 3)
        return context.reshape(batch, q_len, self.embed_dim)


class EncoderBlock(nn.Module):
    """One encoder layer: self-attention and a feed-forward net, each with a residual.

    Each sub-layer's output goes through dropout (p=0.1) and LayerNorm
    and is then added back onto its input.

    NOTE(review): the norm is applied to the sub-layer output before the
    residual add, which is neither the usual post-LN nor pre-LN layout -
    confirm this is intended.

    Args:
        embed_dim: Embedding width; the feed-forward hidden width is 4x this.
        num_heads: Number of attention heads.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.self_attn = MultiHeadAttention(embed_dim=embed_dim, n_heads=num_heads)
        hidden_dim = 4 * embed_dim
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, embed_dim),
        )
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)
        self.d1 = nn.Dropout(p=0.1)
        self.d2 = nn.Dropout(p=0.1)

    def forward(self, x, mask=None):
        """Apply the attention and feed-forward sub-layers to ``x``; the shape is preserved."""
        attended = self.self_attn(q=x, k=x, v=x, mask=mask)
        x = x + self.ln1(self.d1(attended))
        transformed = self.ff(x)
        return x + self.ln2(self.d2(transformed))


class Encoder(nn.Module):
    """A stack of ``num_layers`` identically configured ``EncoderBlock`` layers.

    Args:
        embed_dim: Embedding width passed to every block.
        num_heads: Number of attention heads passed to every block.
        num_layers: How many blocks to stack.
    """

    def __init__(self, embed_dim, num_heads, num_layers):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.layers = nn.ModuleList(
            EncoderBlock(embed_dim, num_heads) for _ in range(num_layers)
        )

    def forward(self, x, mask=None):
        """Run ``x`` through every block in order, handing the same ``mask`` to each."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask)
        return hidden

class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    The image is cut into non-overlapping square patches by a strided
    convolution, a learned class token is prepended, learned position
    embeddings are added, the sequence goes through the encoder, and the
    class-token output is mapped to class logits by a small MLP.

    Args:
        embed_dim: Token embedding width.
        num_heads: Attention heads per encoder block.
        num_layers: Number of encoder blocks.
        chunk_size: Patch side length (kernel size and stride of the
            patch convolution).
        img_size: Expected input side length, or a ``(height, width)``
            pair. Defaults to 28 (MNIST); with ``chunk_size=7`` this gives
            16 patches + 1 class token = the 17 positions that used to be
            hard-coded. Inputs of another size must pass this explicitly.
        in_channels: Number of input image channels. Defaults to 1.
        num_classes: Number of output logits. Defaults to 10.

    Raises:
        ValueError: If ``chunk_size`` is larger than the image, so that no
            patch fits.
    """

    def __init__(self, embed_dim, num_heads, num_layers, chunk_size,
                 img_size=28, in_channels=1, num_classes=10):
        super().__init__()
        img_h, img_w = (img_size, img_size) if isinstance(img_size, int) else img_size
        # A stride-`chunk_size` convolution yields floor(size / chunk_size)
        # patches along each axis.
        num_patches = (img_h // chunk_size) * (img_w // chunk_size)
        if num_patches < 1:
            raise ValueError(
                f"chunk_size={chunk_size} does not fit into an image of size {img_h}x{img_w}"
            )

        self.encoder = Encoder(embed_dim, num_heads, num_layers)
        self.conv_proj = nn.Conv2d(in_channels=in_channels, out_channels=embed_dim,
                                   kernel_size=chunk_size, stride=chunk_size)
        # One learned position vector per patch plus one for the class token.
        self.pos_enc = nn.Parameter(torch.randn((num_patches + 1, embed_dim)))
        self.embed_dim = embed_dim
        self.chunk_size = chunk_size
        self.cls_token = nn.Parameter(torch.randn((1, 1, embed_dim)))
        self.mlp = nn.Sequential(
            torch.nn.Linear(embed_dim, 2*embed_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(2*embed_dim, num_classes)
        )

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape ``(n, in_channels, height, width)``.

        Returns:
            Logits of shape ``(n, num_classes)``.

        Raises:
            ValueError: If ``x`` is not 4-D, or if it produces a different
                number of patches than the model was built for.
        """
        if x.ndim != 4:
            raise ValueError(f"expected input of shape (n, c, h, w), got {tuple(x.shape)}")
        n = x.shape[0]
        x = self.conv_proj(x)                      # (n, embed_dim, h', w')
        # Flatten the actual patch grid instead of recomputing its size from
        # h and w: the old `h//c * w//c` parsed as `((h//c) * w) // c`.
        x = x.flatten(2).transpose(1, 2)           # (n, h'*w', embed_dim)
        batch_cls_token = self.cls_token.expand(n, -1, -1)
        x = torch.cat([batch_cls_token, x], dim=1)
        if x.shape[1] != self.pos_enc.shape[0]:
            raise ValueError(
                f"input yields {x.shape[1] - 1} patches but the model was built for "
                f"{self.pos_enc.shape[0] - 1}; pass a matching img_size to the constructor"
            )
        x = x + self.pos_enc
        x = self.encoder(x)
        # Only the class-token position feeds the classification head.
        return self.mlp(x[:, 0])

In [None]:
class CustomScheduler():
    """Optimizer wrapper with linear warm-up followed by inverse-square-root decay.

    The learning rate after ``n`` calls to ``step()`` is::

        d_model ** -0.5 * min(n ** -0.5, n * n_warmup_steps ** -1.5)

    Whatever learning rate the wrapped optimizer was constructed with is
    overwritten on every ``step()``.

    Args:
        optimizer: Object exposing ``param_groups``, ``step()`` and
            ``zero_grad()``.
        d_model: Model width used to scale the learning rate.
        n_warmup_steps: Number of warm-up steps. ``0`` (or less) disables
            warm-up and gives pure inverse-square-root decay.
    """

    def __init__(self, optimizer, d_model, n_warmup_steps):
        self.optimizer = optimizer
        self.d_model = d_model
        self.n_warmup_steps = n_warmup_steps
        self.n_steps = 0

    def get_lr(self):
        """Return the learning rate for the current step count.

        Returns ``0.0`` before the first ``step()``; the warm-up ramp starts
        at zero, and evaluating ``0 ** -0.5`` would raise ZeroDivisionError.
        """
        n_steps = self.n_steps
        if n_steps <= 0:
            return 0.0
        scale = self.d_model ** -0.5
        decay = n_steps ** -0.5
        if self.n_warmup_steps <= 0:
            # No warm-up phase; also avoids evaluating `0 ** -1.5`.
            return scale * decay
        return scale * min(decay, n_steps * self.n_warmup_steps ** -1.5)

    def zero_grad(self):
        """Clear the gradients held by the wrapped optimizer."""
        self.optimizer.zero_grad()

    def step(self):
        """Advance the schedule, write the new rate to every param group, then step the optimizer."""
        self.n_steps += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        self.optimizer.step()

In [None]:
# Resolve the device before the model is moved onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataloader = DataLoader(mnist_trainset, batch_size=128, shuffle=True)
test_dataloader = DataLoader(mnist_testset, batch_size=128, shuffle=False)

model = VisionTransformer(embed_dim=128, num_heads=8, num_layers=4, chunk_size=7)
model.to(device)

# The lr given to Adam is only a placeholder: CustomScheduler overwrites the
# learning rate of every param group on each step.
optimizer = Adam(model.parameters(), lr=0.001, betas=(0.9, 0.98), eps=1e-9)
scheduler = CustomScheduler(optimizer, 128, 4000)

criterion = nn.CrossEntropyLoss()

num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for x, y in tqdm(train_dataloader):
        x, y = x.to(device), y.to(device)
        scheduler.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        train_loss += loss.item()
        loss.backward()
        scheduler.step()
    # Report the mean per-batch loss so it is comparable to the validation figure.
    train_loss /= len(train_dataloader)

    # Validate every fifth epoch.
    if (epoch + 1) % 5 == 0:
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for x, y in tqdm(test_dataloader):
                x, y = x.to(device), y.to(device)
                logits = model(x)
                # Accumulate over all batches; plain assignment here used to
                # leave only the last batch's loss in the report.
                val_loss += criterion(logits, y).item()
        val_loss /= len(test_dataloader)
        print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")

100%|██████████| 469/469 [00:12<00:00, 38.99it/s]
100%|██████████| 79/79 [00:01<00:00, 61.05it/s]


Epoch 1, Train Loss: 568.2353913635015, Val Loss: 0.16777215898036957


100%|██████████| 469/469 [00:10<00:00, 44.73it/s]
100%|██████████| 469/469 [00:10<00:00, 44.62it/s]
100%|██████████| 469/469 [00:10<00:00, 44.70it/s]
100%|██████████| 469/469 [00:10<00:00, 45.01it/s]
100%|██████████| 469/469 [00:10<00:00, 44.51it/s]
100%|██████████| 79/79 [00:01<00:00, 65.04it/s]


Epoch 6, Train Loss: 57.41693733818829, Val Loss: 0.0038599923718720675


100%|██████████| 469/469 [00:10<00:00, 44.46it/s]
100%|██████████| 469/469 [00:10<00:00, 44.57it/s]
100%|██████████| 469/469 [00:10<00:00, 44.24it/s]
100%|██████████| 469/469 [00:10<00:00, 44.57it/s]
100%|██████████| 469/469 [00:10<00:00, 44.83it/s]
100%|██████████| 79/79 [00:01<00:00, 64.22it/s]


Epoch 11, Train Loss: 39.60297906771302, Val Loss: 0.0003705897834151983


100%|██████████| 469/469 [00:10<00:00, 44.32it/s]
100%|██████████| 469/469 [00:10<00:00, 44.61it/s]
100%|██████████| 469/469 [00:10<00:00, 44.68it/s]
100%|██████████| 469/469 [00:10<00:00, 44.62it/s]
100%|██████████| 469/469 [00:10<00:00, 44.65it/s]
100%|██████████| 79/79 [00:01<00:00, 64.07it/s]


Epoch 16, Train Loss: 25.35137113649398, Val Loss: 0.00011502775305416435


100%|██████████| 469/469 [00:10<00:00, 44.63it/s]
100%|██████████| 469/469 [00:10<00:00, 44.41it/s]
100%|██████████| 469/469 [00:10<00:00, 44.26it/s]
100%|██████████| 469/469 [00:10<00:00, 44.85it/s]


In [None]:
# Put the model in evaluation mode explicitly instead of relying on whatever
# mode the training cell happened to leave it in.
model.eval()
with torch.no_grad():
    correct, total = 0, 0
    test_loss = 0.0
    for x, y in tqdm(test_dataloader, desc="Testing"):
        x, y = x.to(device), y.to(device)
        y_hat = model(x)
        # `criterion` returns the mean over the batch; weight it by the batch
        # size so that the smaller final batch is not over-represented.
        test_loss += criterion(y_hat, y).item() * len(x)

        correct += (y_hat.argmax(dim=1) == y).sum().item()
        total += len(x)
    test_loss /= total
    print(f"Test loss: {test_loss:.2f}")
    print(f"Test accuracy: {correct / total * 100:.2f}%")

Testing: 100%|██████████| 79/79 [00:01<00:00, 55.40it/s]

Test loss: 0.08
Test accuracy: 97.93%



