In [1]:
import torch
from torch import nn
import torchvision
import numpy as np

In [2]:
class PositionalEmbedding(nn.Module):
  """Fixed sinusoidal positional encoding (Vaswani et al., 2017).

  For position ``pos`` and feature index ``i`` the encoding is
  ``sin(pos / 10000**(2*(i//2)/emb_size))`` for even ``i`` and the matching
  ``cos(...)`` for odd ``i``; features ``2k`` and ``2k+1`` share a frequency.

  Args:
    batch_size: default leading dimension of the tensor returned by ``forward``.
    num_patches: number of sequence positions to encode (callers include the
      class-token slot in this count).
    emb_size: embedding dimension.
  """

  def __init__(self,batch_size,num_patches,emb_size):
    super().__init__()

    self.batch_size = batch_size
    self.num_patches = num_patches
    self.emb_size = emb_size

    positions = np.arange(num_patches)[:, np.newaxis]   # (num_patches, 1)
    depth = np.arange(emb_size)[np.newaxis, :]           # (1, emb_size)
    # Pair (2k, 2k+1) shares one frequency. The parentheses matter:
    # ``2*depth//2`` parses as ``(2*depth)//2`` and is just ``depth``.
    depth = (2 * (depth // 2)) / emb_size

    angle_rates = 1 / (10000 ** depth)
    angle_rads = positions * angle_rates
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

    # Encoding table of shape (num_patches, emb_size). The position index is
    # already inside the sin/cos argument and must not be multiplied in again.
    self.positions = angle_rads

    # Non-persistent buffer: follows .to(device)/.half() with the module but
    # is not written to the state_dict (it is fully determined by the args).
    self.register_buffer(
        "encoding",
        torch.tensor(angle_rads, dtype=torch.float32),
        persistent=False,
    )

  def forward(self, batch_size=None):
    """Return the encoding broadcast to ``(batch, num_patches, emb_size)``.

    Args:
      batch_size: leading dimension of the result; defaults to the value
        passed to the constructor.

    Returns:
      A new (non-aliasing) float32 tensor on the module's device.
    """
    if batch_size is None:
      batch_size = self.batch_size
    expanded = self.encoding.unsqueeze(0).expand(batch_size, -1, -1)
    # Materialize a fresh copy so callers may wrap or modify it freely.
    return expanded.clone(memory_format=torch.contiguous_format)
     

In [3]:
class PatchEmbedding(nn.Module):
  """Embed image patches, prepend a learnable class token, add positional encodings.

  Input images go through ``nn.Conv2d`` so they are (B, in_channels, H, W);
  with H == W == img_size this yields ``num_patches`` patch tokens, and the
  output is (B, num_patches + 1, emb_size).

  Args:
    img_size: side length of the square input images.
    batch_size: only used to construct ``PositionalEmbedding``; the runtime
      batch size is read from the input.
    patch_size: side length of each square patch (conv kernel and stride).
    emb_size: embedding dimension per token.
    in_channels: number of image channels (default 3).
  """

  def __init__(self,
               img_size = 224,
               batch_size=32,
               patch_size=16,
               emb_size=768,
               in_channels=3,
               ):
    super().__init__()

    self.img_size = img_size
    self.batch_size = batch_size
    self.patch_size = patch_size
    self.emb_size = emb_size

    self.num_patches = (img_size * img_size) // patch_size**2

    # kernel == stride == patch_size: one linear projection per non-overlapping patch.
    self.cnn_layer = nn.Conv2d(in_channels=in_channels,out_channels=emb_size,kernel_size=self.patch_size,stride=self.patch_size)
    self.flatten = nn.Flatten(start_dim=2,end_dim=3)

    # Learnable class token. Registered here so the optimizer sees it and it
    # moves with the module; it is broadcast over the batch in forward().
    self.class_token = nn.Parameter(torch.ones(1, 1, emb_size))

    # +1 position for the class token.
    self.pos = PositionalEmbedding(batch_size,self.num_patches+1,emb_size)

  def forward(self,images):
    batch = images.shape[0]

    # (B, C, H, W) -> (B, E, H/p, W/p) -> (B, E, N) -> (B, N, E)
    patches = self.cnn_layer(images)
    patches = self.flatten(patches)
    patches = patches.permute(0, 2, 1)

    # Prepend the shared class token to every sample: (B, N+1, E)
    class_token = self.class_token.expand(batch, -1, -1)
    embedding = torch.cat((class_token, patches), dim=1)

    # Fixed sinusoidal encoding. Keep one batch row and let broadcasting cover
    # the real batch, so inputs need not match self.batch_size.
    pos_emb = self.pos()[:1].to(device=embedding.device, dtype=embedding.dtype)

    return embedding + pos_emb

# Multi-Head Attention

In [4]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries/keys/values, split into ``heads`` heads
    of size ``emb_size // heads``, attended independently per head, merged
    back and passed through an output projection.

    Args:
        emb_size: model dimension; must be divisible by ``heads``.
        batch_size: stored for backward compatibility; the batch dimension is
            taken from the input.
        heads: number of attention heads.

    Raises:
        ValueError: if ``emb_size`` is not divisible by ``heads``.
    """

    def __init__(self,
                 emb_size=768,
                 batch_size=32,
                 heads=12):
        super(MultiHeadAttention,self).__init__()

        if emb_size % heads != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by heads ({heads})")

        self.emb_size= emb_size
        self.heads=heads
        self.head_dim = emb_size//heads
        self.batch_size = batch_size

        # Queries, Keys and Values projections (all heads at once) + output projection
        self.queries = nn.Linear(self.emb_size,self.emb_size)
        self.keys = nn.Linear(self.emb_size,self.emb_size)
        self.values = nn.Linear(self.emb_size,self.emb_size)
        self.out_projection = nn.Linear(self.emb_size,self.emb_size)

        self.softmax = nn.Softmax(dim=-1)

    def self_attention(self,queries,keys,values,masked=False):
        """Scaled dot-product attention over the last two dimensions.

        Args:
            queries, keys, values: tensors shaped (..., seq, head_dim).
            masked: if True, apply a causal mask so query position i only
                attends to key positions <= i.

        Returns:
            Tensor shaped (..., seq, head_dim).
        """
        scores = torch.matmul(queries, keys.transpose(-2, -1))
        scores = scores / (self.head_dim ** 0.5)

        if masked:
            seq_q, seq_k = scores.shape[-2], scores.shape[-1]
            future = torch.triu(
                torch.ones(seq_q, seq_k, device=scores.device), diagonal=1).bool()
            scores = scores.masked_fill(future, float("-inf"))

        scores = self.softmax(scores)
        return torch.matmul(scores, values)

    def _split_heads(self, t):
        # (..., N, emb_size) -> (..., heads, N, head_dim)
        return t.reshape(*t.shape[:-1], self.heads, self.head_dim).transpose(-3, -2)

    def forward(self,x):
        """Self-attention over ``x`` shaped (..., N, emb_size); same output shape."""
        lead_shape = x.shape[:-1]

        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))

        # Per-head attention: (..., heads, N, head_dim)
        attention = self.self_attention(queries, keys, values)

        # Merge heads back: (..., N, heads, head_dim) -> (..., N, emb_size)
        attention = attention.transpose(-3, -2).reshape(*lead_shape, self.emb_size)

        return self.out_projection(attention)


# MLP Block

In [5]:
class MLPBlock(nn.Module):
  """Feed-forward block: LayerNorm -> Linear -> GELU -> Linear.

  Args:
    batch_size: kept for backward compatibility; not used.
    num_patches: kept for backward compatibility; not used.
    emb_size: input/output feature size.
    mlp_block: hidden size of the MLP.
  """

  def __init__(self,
               batch_size=32,
               num_patches=196,
               emb_size=768,
               mlp_block=3072):
    super().__init__()

    # Normalize each token over its feature dimension only. Normalizing over
    # [batch, seq, emb] would mix statistics across samples/tokens and only
    # accept one exact batch size.
    self.layer_norm = nn.LayerNorm(emb_size)

    self.mlp = nn.Sequential(
        nn.Linear(emb_size,mlp_block),
        nn.GELU(),
        nn.Linear(mlp_block,emb_size)
    )

  def forward(self,x):
    """Map ``x`` shaped (..., emb_size) to the same shape."""
    x = self.layer_norm(x)
    return self.mlp(x)

# Transformer Block

In [6]:
class Transformer(nn.Module):
  """Pre-norm Transformer encoder block.

  x -> x + MHA(LN1(x)) -> x + MLP(LN2(x))

  Args:
    num_patches: forwarded to ``MLPBlock``.
    emb_size: model dimension.
    mlp_block: hidden size of the MLP.
    heads: number of attention heads.
    batch_size: forwarded to the sub-modules.
  """

  def __init__(self,
               num_patches=196,
               emb_size=768,
               mlp_block=3072,
               heads=12,
               batch_size=32):
    super().__init__()

    # Per-token LayerNorm over the feature dimension only (not over batch/seq).
    self.ln1 = nn.LayerNorm(emb_size)
    self.ln2 = nn.LayerNorm(emb_size)

    self.mha = MultiHeadAttention(emb_size,batch_size,heads)
    # NOTE(review): MLPBlock applies its own LayerNorm as well, so the MLP
    # input is normalized twice (ln2 then MLPBlock.layer_norm).
    self.mlp = MLPBlock(batch_size,num_patches,emb_size,mlp_block)

  def forward(self,x):
    # Attention sub-layer with residual (skip) connection.
    x = x + self.mha(self.ln1(x))
    # Feed-forward sub-layer with residual (skip) connection.
    x = x + self.mlp(self.ln2(x))
    return x



# Vision Transformer

In [7]:
class ViT(nn.Module):
  """Vision Transformer classifier.

  Patch embedding -> ``encoder_layers`` Transformer blocks -> mean over all
  tokens (class token included) -> linear classifier. Returns logits of shape
  (B, num_classes).

  Args:
    img_size: side length of square input images.
    batch_size: forwarded to sub-modules; the runtime batch size is read from
      the input.
    patch_size: side length of square patches.
    emb_size: model dimension.
    heads: attention heads per block.
    mlp_block: hidden size of each block's MLP.
    encoder_layers: number of Transformer blocks.
    num_classes: number of output logits.
  """

  def __init__(self,
               img_size = 224,
               batch_size=32,
               patch_size=16,
               emb_size=768,
               heads=12,
               mlp_block=3072,
               encoder_layers = 12,
               num_classes=10):
    super().__init__()

    self.batch_size = batch_size
    self.emb_size = emb_size

    self.num_patches = (img_size * img_size) // patch_size**2

    # Patch Embedding
    self.patch_emb = PatchEmbedding(img_size,batch_size,patch_size,emb_size)

    # Transformer blocks. nn.ModuleList (not a plain list) registers them as
    # sub-modules so their parameters reach parameters()/the optimizer,
    # .to(device), state_dict() and train()/eval().
    self.encoder = nn.ModuleList(
        [Transformer(self.num_patches,emb_size,mlp_block,heads,batch_size) for _ in range(encoder_layers)]
    )

    # Classifier
    self.classifier = nn.Linear(emb_size,num_classes)

  def forward(self,images):
    x = self.patch_emb(images)   # (B, N+1, E)

    for enc_layer in self.encoder:
      x = enc_layer(x)

    # Global average pool over the token axis -> (B, E). Uses the runtime
    # batch size, so batches other than self.batch_size work.
    x = x.mean(dim=1)
    return self.classifier(x)



In [8]:
# Initialize the model and check the output shape on a random batch.
torch.manual_seed(0)  # reproducible weights and input

vit = ViT()

x = torch.rand((32, 3, 224, 224))
with torch.no_grad():  # shape check only; no need to build the autograd graph
  out = vit(x)
out.shape

torch.Size([32, 10])