# Vision Transformer from Scratch

In [1]:
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

## Patch Embedding

* Split the image into patches based **patch_size**.

  `Number of patches (n_patches) = (img_h // patch_size) * (img_w // patch_size)`

* Flatten patches into 1D vector. size of each 1D vector will be `patch_size * patch_size * in_channels`.

  Ex : If patch_size is 16 and the image is a color image(RGB)
      `16 * 16 * 3(in_channels (RGB))`

* Project the 1D vectors to an Embed vector `embed_dim` to create patch embedding
`Projected vector = W⋅x+b`

In [2]:
class PatchEmbedding(nn.Module):
    def __init__(self, patch_size, in_channels, embed_dim):
        super(PatchEmbedding, self).__init__()

        self.patch_size = patch_size
        self.embed_dim = embed_dim

        self.projection = nn.Linear(in_channels * patch_size * patch_size, embed_dim)

    def forward(self, x):

        batch_size, in_channels, height, width = x.size()
        n_patches = (height // self.patch_size) * (width // self.patch_size)

        x = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        x = x.contiguous().view(batch_size, in_channels, -1, self.patch_size*self.patch_size)

        x = x.permute(0, 2, 1, 3)

        x = x.contiguous().view(batch_size, n_patches, -1)

        x = self.projection(x)

        return x

## Position Encoding

In [3]:
class Position_Encoding(nn.Module):
    def __init__(self, embed_dim, n_patches):
        super(Position_Encoding, self).__init__()

        self.position_encoding = nn.Parameter(torch.randn(1, n_patches + 1, embed_dim))

    def forward(self, x):
        return x + self.position_encoding


## Multi Head Self Attention

* Attention Scores are calculated as Scaled dot product between Querys and Keys generated from the same sequence

* Attention Score:
  `Q.inv(K)/sqrt(head dim)`

* The computed attention scores are then multiplied with the Value vector to generate the output

In [4]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(MultiHeadSelfAttention, self).__init__()

        assert embed_dim % num_heads == 0 , "Embedding dimesion should be divisible by num heads"

        self.num_heads = num_heads
        self.embed_dim = embed_dim

        self.head_dim = embed_dim // num_heads

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):

        batch_size, n_patches, embed_dim = x.size()

        Q = self.query(x)
        K = self.query(x)
        V = self.query(x)


        Q = Q.view(batch_size, n_patches, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        K = K.view(batch_size, n_patches, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        V = V.view(batch_size, n_patches, self.num_heads, self.head_dim).permute(0, 2, 1, 3)


        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))

        attention = torch.softmax(scores, dim=-1)


        out = torch.matmul(attention, V)

        # Reshape back to original dimensions: [batch_size, n_patches, embed_dim]
        out = out.permute(0, 2, 1, 3).contiguous().view(batch_size, n_patches, embed_dim)

        out = self.out(out)

        return out

## Transformer EncoderBlock

In [5]:
class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout=0.1):
        super(TransformerEncoder, self).__init__()

        self.self_attn = MultiHeadSelfAttention(embed_dim, num_heads)
        self.linear1 = nn.Linear(embed_dim, mlp_dim)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(mlp_dim, embed_dim)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)


    def forward(self, x):
        attn_out = self.self_attn(x)
        x = x + self.dropout1(attn_out)
        x = self.norm1(x)

        ff_out = self.linear2(self.dropout(F.gelu(self.linear1(x))))
        x = x + self.dropout2(ff_out)
        x = self.norm2(x)

        return x




## Vision Transformer Class

In [6]:
class VisionTransformer(nn.Module):
    def __init__(self, img_size, in_channels, patch_size, embed_dim, mlp_dim, num_heads, num_layers, num_classes):
        super(VisionTransformer, self).__init__()

        self.patch_size = patch_size

        # Number of Patches
        n_patches = (img_size // patch_size) * (img_size // patch_size)

        # Generate Patch Embeddings
        self.patch_embedding = PatchEmbedding(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)

        # Positional Encoding
        self.position_encoding = Position_Encoding(embed_dim, n_patches)

        # classification token
        self.classification_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # transformer encoder layers
        self.transformer_encoders = nn.ModuleList([
            TransformerEncoder(embed_dim, num_heads, mlp_dim) for _ in range(num_layers)
        ])

        self.classifier = nn.Linear(embed_dim, num_classes)


    def forward(self, x):
        x = self.patch_embedding(x)
        batch_size = x.size(0)
        cls_tokens = self.classification_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = self.position_encoding(x)

        for layer in self.transformer_encoders:
            x = layer(x)

        cls_token_out = x[:, 0]

        out = self.classifier(cls_token_out)

        return out






## Model Training

We will be using the CIFAR-10K dataset to train the Vision Transformer Model.

In [7]:
from torchvision.datasets import CIFAR10
from torchvision.transforms import transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Loading CIFAR-10 Dataset
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)



Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 43987480.21it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [None]:
# setup device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model Training

img_size = 224
in_channels = 3
patch_size = 16
embed_dim = 768
mlp_dim = 3072
num_heads = 12
num_layers = 12
num_classes = 10

model = VisionTransformer(img_size, in_channels, patch_size, embed_dim, mlp_dim, num_heads, num_layers, num_classes)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


for epoch in range(10):
    model.train()
    for images , labels in train_loader:
        images , labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch+1}/10], Loss: {loss.item():.4f}')

Epoch [1/10], Loss: 2.2908
