<a href="https://colab.research.google.com/github/Vaibhavrathore1999/ML-building-blocks/blob/main/ViT_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import math
import torch
import torch.nn as nn

In [50]:
class GELU(nn.Module):
  def __init__(self):
    super(GELU,self).__init__()
    pass
  def forward(self,x):
    return torch.tensor(0.5*x*(1+torch.tanh(torch.tensor(2/math.pi)**0.5)*(x+0.044715*torch.pow(x,3.0))))


## Convert Image into Patches

In [7]:
class PatchEmbeddings(nn.Module):
  def __init__(self,img_size,patch_size,in_channels,emb_size):
    assert (img_size%patch_size==0), "image size must be divisible by patch size"
    super(PatchEmbeddings,self).__init__()
    self.img_size=img_size
    self.patch_size=patch_size
    self.n_patches=(img_size//patch_size)**2
    # Convert each patch into embedding
    self.proj=nn.Conv2d(in_channels,emb_size,kernel_size=(patch_size,patch_size),stride=patch_size)
  def forward(self,x):
    # Shape of x ---> (batch_size,num_channels,image_size,image_size). get converted to (batch_size,num_patches,emb_dim)
    x=self.proj(x)      # ----> (batch_size,emb_dim,img_size//patch_size,img_size//patch_size)
    x=x.flatten(2)      # ----> (batch_size,emb_dim,num_patches)
    x=x.transpose(1,2)  # ----> (batch_size,num_patches,emb_dim)
    return x

In [9]:
class PatchEmbeddings_v2(nn.Module):
  def __init__(self,img_size,patch_size,in_channels,emb_size):
    assert (img_size%patch_size==0), "image size must be divisible by patch size"
    super(PatchEmbeddings,self).__init__()
    self.img_size=img_size
    self.patch_size=patch_size
    self.n_patches=(img_size//patch_size)**2
    self.linear=nn.Linear(in_channels*patch_size*patch_size,emb_size)
  def forward(self,x):
    # Shape of x ---> (batch_size,num_channels,image_size,image_size). get converted to (batch_size,num_patches,emb_dim)
    patches=x.unfold(2,self.patch_size.self.patch_size).unfold(3,self.patch_size,self.patch_size)
    patches=patches.contiguous().view(x.shape[0,self.n_patches,-1])
    return self.linear(patches)


## Add Learnable Positional Embeddings

In [None]:
class Embedding(nn.Module):
  def __init__(self,img_size,patch_size,in_channels,emb_size,dropout):
    super(Embedding,self).__init__()
    self.patch_embeddings=PatchEmbeddings(img_size,patch_size,in_channels,emb_size)
    # Create a learnable [CLS] Token
    self.cls_token=nn.Parameter(torch.randn(1,1,emb_size))
    self.dropout=nn.Dropout(dropout)
    self.pos_embbedings=nn.Parameter(torch.randn(1,self.patch_embeddings.n_patches+1,emb_size))
  def forward(self,x):
    x=self.patch_embeddings(x)
    cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
    x=torch.cat((cls_tokens,x),dim=1)
    x=x+self.pos_embeddings
    x=self.dropout(x)

## Add Fixed sine and cosine pos embeddings

In [10]:
class Embedding_v2(nn.Module):
  def __init__(self,img_size,patch_size,in_channels,emb_size,dropout):
    super(Embedding,self).__init__()
    self.patch_embeddings=PatchEmbeddings(img_size,patch_size,in_channels,emb_size)
    # Create a learnable [CLS] Token
    self.cls_token=nn.Parameter(torch.randn(1,1,emb_size))
    self.dropout=nn.Dropout(dropout)
    self.pos_embbedings=self.get_sinusoidal_positional_embeddings(self.patch_embeddings.n_patches+1,emb_size)

  def get_sinusoidal_positional_embeddings(self, num_positions, emb_size):
    """
    Generate sinusoidal positional embeddings using both sine and cosine functions.
    """
    position = torch.arange(0, num_positions, dtype=torch.float).unsqueeze(1)  # (num_positions, 1)
    div_term = torch.exp(torch.arange(0, emb_size, 2).float() * -(math.log(10000.0) / emb_size))  # (emb_size / 2)

    # Apply sine to even indices (2i)
    pos_emb = torch.zeros(num_positions, emb_size)
    pos_emb[:, 0::2] = torch.sin(position * div_term)
    pos_emb[:, 1::2] = torch.cos(position * div_term)
    return pos_emb.unsqueeze(0)  # Shape: (1, num_positions, emb_size)

  def forward(self,x):
    x=self.patch_embeddings(x)
    cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
    x=torch.cat((cls_tokens,x),dim=1)           # ----> (batch_size,num_patches+1,emb_dim)
    x=x+self.pos_embeddings
    x=self.dropout(x)

In [11]:
class AttentionHead(nn.Module):
  def __init__(self,d_in,d_out,dropout):
    self.W_k=nn.Linear(d_in,d_out)        # Shape ---> (d_in,d_out)
    self.W_q=nn.Linear(d_in,d_out)
    self.W_v=nn.Linear(d_in,d_out)
    self.dropout=nn.Dropout(dropout)
  def forward(self,x):
    # Shape --> (batch, num_tokens, d_in)
    key=x@self.W_k          # Shape --> (batch, num_tokens, d_out)
    query=x@self.W_q        # Shape --> (batch, num_tokens, d_out)
    value=x@self.W_v        # Shape --> (batch, num_tokens, d_out)
    # Shape of attention matrix ----> (batch, num_tokens , num_tokens)
    attn_matrix= query@key.transpose(-2,-1)
    attn_matrix=attn_matrix/torch.sqrt(torch.tensor(key.shape[-1]))
    attn_matrix=torch.softmax(attn_matrix,dim=-1)
    attn_matrix=self.dropout(attn_matrix)
    context_vec=attn_matrix@value
    return context_vec            # # Shape --> (batch, num_tokens, d_out)


In [12]:
class MultiHeadWraper(nn.Module):
  def __init__(self,d_in,d_out,num_heads,dropout):
    super(MultiHeadWraper,self).__init__()
    self.num_heads=num_heads
    self.d_out=d_out
    self.d_in=d_in
    self.heads=nn.ModuleList([AttentionHead(d_in,d_out,dropout) for _ in range(num_heads)])

  def forward(self,x):
    # Shape --> (batch, num_tokens, d_in)
    x_=torch.cat([self.heads[i](x) for i in range(self.num_heads)],dim=-1)        # Shape ---> (batch,num_tokens,d_out*num_heads)
    return x_

In [14]:
class MultiHeadAttention(nn.Module):
  def __init__(self,d_in,d_out,num_heads,dropout):
    super(MultiHeadAttention,self).__init__()
    self.num_heads=num_heads
    self.d_out=d_out
    self.d_in=d_in
    assert (d_out%num_heads==0), "d_out should be divisible by num_haeds"
    self.W_k=nn.Linear(d_in,d_out)        # Shape ---> (d_in,d_out)
    # self.W_k=nn.Parameter(torch.randn(d_in,d_out))
    self.W_q=nn.Linear(d_in,d_out)
    self.W_v=nn.Linear(d_in,d_out)
    self.dim_head=d_out//num_heads
    self.dropout=nn.Dropout(dropout)

  def forward(self,x):
    # Shape ---> (batch, num_tokens, d_in)
    keys=self.W_k(x)        # Shape ----> (batch,num_tokens,d_out)
    # keys=x@self.W_k
    queries=self.W_q(x)     # Shape ----> (batch,num_tokens,d_out)
    values=self.W_v(x)      # Shape ----> (batch,num_tokens,d_out)
    keys=keys.view(keys.shape[0],keys.shape[1],self.num_heads,self.dim_head)                   # Shape ----> (batch,num_tokens,num_heads,dim_head)
    queries=queries.view(queries.shape[0],queries.shape[1],self.num_heads,self.dim_head)       # Shape ----> (batch,num_tokens,num_heads,dim_head)
    values=values.view(values.shape[0],values.shape[1],self.num_heads,self.dim_head)           # Shape ----> (batch,num_tokens,num_heads,dim_head)
    keys=keys.transpose(1,2)                    # Shape ----> (batch,num_heads,num_tokens,dim_head)
    queries=queries.transpose(1,2)               # Shape ----> (batch,num_heads,num_tokens,dim_head)
    values=values.transpose(1,2)               # Shape ----> (batch,num_heads,num_tokens,dim_head)
    attn_matrix=queries@keys.transpose(-2,-1)     # Shape ---> (batch,num_heads,num_tokens,num_tokens)
    attn_matrix=attn_matrix/torch.sqrt(torch.tensor(keys.shape[-1]))
    attn_matrix=torch.softmax(attn_matrix,dim=-1)
    attn_matrix=self.dropout(attn_matrix)
    context_vec=attn_matrix@values        # Shape ---> (batch, num_heads, num_tokens,dim_head)
    context_vec=context_vec.transpose(1,2)      # Shape ---> (batch, num_tokens, num_heads,dim_head)
    context_vec=context_vec.contiguous().view(context_vec.shape[0],context_vec.shape[1],-1)         # Shape ---> (batch, num_tokens,d_out)
    return context_vec


In [15]:
class MLP(nn.Module):
  def __init__(self,d_in,hidden_size,dropout):
    super().__init__()
    self.fc1=nn.Linear(d_in,hidden_size)
    self.gelu=GELU()
    self.fc2=nn.Linear(hidden_size,d_in)
    self.dropout=nn.Dropout(dropout)
  def forward(self,x):
    x=self.fc1(x)
    x=self.gelu(x)
    x=self.fc2(x)
    x=self.dropout(x)
    return x

In [113]:
class LayerNorm(nn.Module):
  def __init__(self,d_out,eps=1e-6):
    super().__init__()
    self.eps=eps
    self.d_out=d_out
    self.gamma=nn.Parameter(torch.ones(d_out))
    self.beta=nn.Parameter(torch.zeros(d_out))

  def forward(self,x):
    # Shape ---> (batch, num_tokens,d_out)
    mean=x.mean(dim=-1,keepdim=True)
    std=x.std(dim=-1,keepdim=True)
    return ((x-mean)/(std+self.eps))*self.gamma + self.beta


In [127]:
class Block(nn.Module):
    def __init__(self, d_in, d_out, num_heads, mlp_size, dropout):
        super().__init__()
        self.mha = MultiHeadAttention(d_in, d_out, num_heads, dropout)
        self.norm1 = torch.nn.LayerNorm(d_in)
        self.norm2 = torch.nn.LayerNorm(d_out)
        self.mlp   = MLP(d_out, mlp_size, dropout)

    def forward(self, x):
        # 1) Pre‐attention norm
        x_norm = self.norm1(x)                # (batch, n_tokens, d_in)
        x_attn = self.mha(x_norm)             # (batch, n_tokens, d_out)

        # 2) Residual (project x up only if needed)
        res = x             # (batch, n_tokens, d_out)
        x = x_attn + res                      # (batch, n_tokens, d_out)

        # 3) Pre‐FFN norm
        x_norm = self.norm2(x)                # (batch, n_tokens, d_out)
        x_ffn  = self.mlp(x_norm)             # (batch, n_tokens, d_out)

        # 4) FFN residual
        return x + x_ffn


In [128]:
class Encoder(nn.Module):
  def __init__(self,d_in,d_out,num_heads,mlp_size,dropout,num_blocks):
    super().__init__()
    self.blocks=nn.ModuleList([Block(d_in,d_out,num_heads,mlp_size,dropout) for _ in range(num_blocks)])
  def forward(self,x):
    for block in self.blocks:
      x=block(x)
    return x          # Shape ---> (batch, num_tokens+1,d_out)

In [129]:
batch=4
num_tokens=5
emb_dim=768
example=torch.randn(batch,num_tokens,emb_dim)
print(example.shape)

torch.Size([4, 5, 768])


In [130]:
encoder=Encoder(d_in=emb_dim,d_out=768,num_heads=16,mlp_size=3072,dropout=0.1,num_blocks=12)
print(encoder)

Encoder(
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (mha): MultiHeadAttention(
        (W_k): Linear(in_features=768, out_features=768, bias=True)
        (W_q): Linear(in_features=768, out_features=768, bias=True)
        (W_v): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (gelu): GELU()
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
)


In [131]:
output=encoder(example)
print(output.shape)

torch.Size([4, 5, 768])


  return torch.tensor(0.5*x*(1+torch.tanh(torch.tensor(2/math.pi)**0.5)*(x+0.044715*torch.pow(x,3.0))))
