<a href="https://colab.research.google.com/github/Jainam051/Vision-Transformer-from-scratch/blob/main/ViT_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:


import math

import einops
import torch
import torch.optim as optim
import torchvision
from torch import nn
from torchsummary import summary
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, RandomHorizontalFlip, RandomCrop
from tqdm.notebook import tqdm



In [None]:


# IPython shell magic: enable the classic-notebook ipywidgets extension.
# Presumably needed so the tqdm.notebook progress bars imported above render
# as widgets (TODO confirm in the target Jupyter environment).
!jupyter nbextension enable --py widgetsnbextension



In [None]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

# --- Model hyper-parameters (ViT-Base/16 layout) ---
patch_size = 16      # side length, in pixels, of each square image patch
latent_size = 768    # width of every token embedding
n_channels = 3       # channels of the input images
num_heads = 12       # attention heads per encoder block
num_encoders = 12    # number of stacked encoder blocks
dropout = 0.1        # dropout probability used inside the encoders
num_classes = 10     # number of output logits
size = 224           # expected input image height/width, in pixels

# Fail fast on configurations the model cannot represent.
if size % patch_size != 0:
    raise ValueError("size must be divisible by patch_size")
if latent_size % num_heads != 0:
    raise ValueError("latent_size must be divisible by num_heads")

# --- Training hyper-parameters ---
epochs = 10
# NOTE(review): this was written as `10e-3`, which equals 0.01 (not 1e-3).
# The value is unchanged; confirm 0.01 is the intended base learning rate.
base_lr = 1e-2
weight_decay = 0.03
batch_size = 8

cpu


In [None]:
class InputEmbedding(nn.Module):
    """Turn a batch of images into a sequence of token embeddings.

    Each image is cut into non-overlapping ``patch_size x patch_size`` patches.
    Every flattened patch is linearly projected to ``latent_size``. A learnable
    class token is prepended, and a learnable positional embedding (one vector
    per sequence position) is added.

    Args:
        patch_size: side length of each square patch, in pixels.
        n_channels: number of channels of the input images.
        device: target device recorded on the module; the forward pass moves
            its input to the device of the module's weights.
        latent_size: width of every output token.
        batch_size: kept for backward compatibility only and no longer used.
            The class token is broadcast to the batch size of each input.
        image_size: height/width of the square input images. It fixes the
            number of patches and hence the length of the positional embedding.

    Raises:
        ValueError: if ``image_size`` is not a multiple of ``patch_size``, or
            if an input yields a different number of patches than expected.
    """

    def __init__(self, patch_size=patch_size, n_channels=n_channels, device=device, latent_size=latent_size, batch_size=batch_size, image_size=size):
        super(InputEmbedding, self).__init__()
        if image_size % patch_size != 0:
            raise ValueError("image_size must be divisible by patch_size")

        self.latent_size = latent_size
        self.patch_size = patch_size
        self.n_channels = n_channels
        self.device = device
        self.batch_size = batch_size  # unused; kept so existing attribute access keeps working
        self.image_size = image_size
        self.input_size = self.patch_size * self.patch_size * self.n_channels
        self.num_patches = (image_size // patch_size) ** 2

        # Linear projection of each flattened patch to the latent width.
        self.linearProjection = nn.Linear(self.input_size, self.latent_size)

        # Class token: a single learnable vector shared by every image.
        # Do NOT call `.to(device)` on an nn.Parameter here. When the device
        # changes, that returns a plain tensor that is not registered with the
        # module, so it would never be trained or saved. Move the whole module
        # with `module.to(device)` instead.
        self.class_token = nn.Parameter(torch.randn(1, 1, self.latent_size))

        # Positional embedding: one learnable vector per position
        # (class token + every patch), shared across the batch.
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, self.latent_size))

    def forward(self, input_data):
        """Embed ``input_data`` of shape (B, C, H, W) into (B, num_patches + 1, latent_size)."""
        # Put the batch on the same device as the module's weights.
        input_data = input_data.to(self.linearProjection.weight.device)

        # Patchify: (B, C, H, W) -> (B, num_patches, patch_size * patch_size * C)
        patches = einops.rearrange(
            input_data, 'b c (h h1) (w w1) -> b (h w) (h1 w1 c)', h1=self.patch_size, w1=self.patch_size)

        tokens = self.linearProjection(patches)  # (B, num_patches, latent_size)
        b, n, _ = tokens.shape
        if n != self.num_patches:
            raise ValueError(
                f"expected {self.num_patches} patches (image_size={self.image_size}), got {n}")

        # Broadcast the single class token over the batch (a view, no copy),
        # then prepend it at position 0.
        cls_tokens = self.class_token.expand(b, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)  # (B, num_patches + 1, latent_size)

        # Position-specific embeddings, broadcast over the batch dimension.
        return tokens + self.pos_embedding


In [None]:


# Smoke-test the patch embedding on a random batch shaped like the configured input
# (batch_size x n_channels x size x size, i.e. 8 x 3 x 224 x 224).
test_input = torch.randn(batch_size, n_channels, size, size)
test_class = InputEmbedding().to(device)
embed_test = test_class(test_input)



In [None]:


class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        latent_size: token width; must be divisible by ``num_heads``.
        num_heads: number of attention heads, each of width
            ``latent_size // num_heads``.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``latent_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, latent_size, num_heads, dropout=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if latent_size % num_heads != 0:
            raise ValueError("latent_size must be divisible by num_heads")

        self.latent_size = latent_size
        self.num_heads = num_heads
        self.head_dim = latent_size // num_heads

        # Learnable projections
        self.q_proj = nn.Linear(latent_size, latent_size)
        self.k_proj = nn.Linear(latent_size, latent_size)
        self.v_proj = nn.Linear(latent_size, latent_size)
        self.out_proj = nn.Linear(latent_size, latent_size)

        self.dropout = nn.Dropout(dropout)
        # Scores are divided by sqrt(head_dim) to keep the softmax inputs well scaled.
        # Requires `import math` at file level.
        self.scale = math.sqrt(self.head_dim)

    def _split_heads(self, t):
        """(B, N, latent_size) -> (B, num_heads, N, head_dim); heads are the outer factor of the last dim."""
        B, N, _ = t.shape
        return t.reshape(B, N, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Attend over the sequence of ``x`` (B, N, latent_size) and return (B, N, latent_size)."""
        B, N, _ = x.shape

        Q = self._split_heads(self.q_proj(x))
        K = self._split_heads(self.k_proj(x))
        V = self._split_heads(self.v_proj(x))

        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale  # (B, H, N, N)
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        context = torch.matmul(attn, V)  # (B, H, N, head_dim)
        # Concatenate heads back: (B, H, N, head_dim) -> (B, N, H * head_dim)
        context = context.transpose(1, 2).reshape(B, N, self.latent_size)

        return self.out_proj(context)  # (B, N, latent_size)

In [None]:
class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder layer.

    Computes ``x + MHSA(LN(x))`` and then adds ``MLP(LN(.))`` of that result
    as a second residual. The MLP widens the tokens 4x, applies GELU, and uses
    dropout after the activation and after the output projection.
    """

    def __init__(self, latent_size=latent_size, num_heads=num_heads, dropout=dropout):
        super().__init__()
        hidden_size = 4 * latent_size

        self.norm1 = nn.LayerNorm(latent_size)
        self.attn = MultiHeadSelfAttention(latent_size, num_heads, dropout)
        self.norm2 = nn.LayerNorm(latent_size)

        mlp_layers = [
            nn.Linear(latent_size, hidden_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, latent_size),
            nn.Dropout(dropout),
        ]
        self.feed_forward = nn.Sequential(*mlp_layers)

    def forward(self, x):
        # Attention sub-layer with its residual connection.
        attended = x + self.attn(self.norm1(x))
        # Feed-forward sub-layer with its residual connection.
        return attended + self.feed_forward(self.norm2(attended))

In [None]:
# Smoke-test a single encoder block on the embedded batch from the earlier cell.
# As the cell's last expression, the returned tensor is what the notebook displays.
test_encoder = EncoderBlock().to(device)
test_encoder(embed_test)

In [None]:


class Vit(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> stack of encoder blocks -> MLP head applied
    to the class token (position 0 of the embedded sequence).

    Args:
        num_encoders: number of stacked EncoderBlock layers.
        latent_size: token width, forwarded to the embedding and to every encoder.
        device: device forwarded to InputEmbedding.
        num_classes: number of output logits.
        dropout: dropout probability forwarded to every encoder block.
        num_heads: attention heads per encoder block (optional; defaults to
            the global config value).
    """

    def __init__(self, num_encoders=num_encoders, latent_size=latent_size, device=device, num_classes=num_classes, dropout=dropout, num_heads=num_heads):
        super(Vit, self).__init__()
        self.num_encoder = num_encoders
        self.latent_size = latent_size
        self.device = device
        self.num_classes = num_classes
        self.dropout = dropout
        self.num_heads = num_heads

        # Forward the constructor arguments to the sub-modules. Previously they
        # were built from the global defaults, so for example Vit(latent_size=512)
        # produced 768-wide tokens and then crashed in the 512-wide head.
        self.embedding = InputEmbedding(device=device, latent_size=latent_size)

        # Create the stack of encoders
        self.encStack = nn.ModuleList([
            EncoderBlock(latent_size=latent_size, num_heads=num_heads, dropout=dropout)
            for _ in range(self.num_encoder)
        ])

        # Classification head with one tanh hidden layer, as in the ViT
        # pre-training head. Without a non-linearity the two Linear layers
        # would collapse into a single linear map.
        self.MLP_head = nn.Sequential(
            nn.LayerNorm(self.latent_size),
            nn.Linear(self.latent_size, self.latent_size),
            nn.Tanh(),
            nn.Linear(self.latent_size, self.num_classes),
        )

    def forward(self, test_input):
        """Return logits of shape (batch, num_classes) for a batch of images ``test_input``."""
        enc_output = self.embedding(test_input)

        for enc_layer in self.encStack:
            enc_output = enc_layer(enc_output)

        # InputEmbedding prepends the class token at position 0; classify from it.
        cls_token_embed = enc_output[:, 0]

        return self.MLP_head(cls_token_embed)



In [None]:


# End-to-end smoke test: class logits for the random batch built earlier,
# followed by their shape.
model = Vit().to(device)
vit_output = model(test_input)
for report in (vit_output, vit_output.shape):
    print(report)

