In [1]:
import torch
import torch.nn as nn

## **1. Patch Embedding:**

Split images into patches and then embed them.

In [2]:
class PatchEmbed(nn.Module):
  def __init__(self, img_size, patch_size, in_chans = 3 , embed_dim = 768):     # in_chans is  no. of input channels (rgb image = 3) # embed_dim (constant) ...int                                                                                                       
    super().__init__()
    self.img_size = img_size # size of image (square) ....int
    self.patch_size = patch_size  # size of patch (square)...int
    self.n_patches = (img_size // patch_size) ** 2 # no. of patches inside image
    self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) # convolution layer that does both splitting into pathces and their embedding.

  def forward(self, x): # x is a param , input (torch.Tensor), Shape (n_samples, in_chans, img_size, img_size)
    x = self.proj(x) # (n_samples, embed_dim, n_patches ** 0.5, n_patches ** 0.5)
    x = x.flatten(2) # (n_samples, embed_dim, n_patches) # take last 2 dim and flatten them into single dim
    x = x.transpose(1,2) # (n_samples, n_patches, embed_dim) # output
    return x

## **2. Attention:**

In [3]:
class Attention(nn.Module):
  '''
  PARAMETER
  ----------------
   dim : (int) The input and out dim of per token feature.
   n_heads : (int) no. of attention heads.
   qkv_bias : (bool) if True, we include bias in linear proj to the query, key, value projections.
   attn_p : (float) dropout prob applied to query, key, value tensors.
   proj_p : (float) dropout prob applied to output tensors.
  
  ATTRIBUTES
  --------------------
  scale : (float) normalizing constant for the dot product.
  qkv : (nn.Linear) Linear projection for query, key, value.
  proj : (nn.Linear) Linear mapping that takes in the concatenated o/p of all heads and map it into a new space.
  attn_drop, proj_drop : (nn.Dropout) Dropout layers.
  
  '''

  def __init__(self, dim, n_heads = 12, qkv_bias = True , attn_p = 0., proj_p = 0.):
    super().__init__()
    self.n_heads = n_heads
    self.dim = dim
    self.head_dim = dim // n_heads
    self.scale = self.head_dim ** -0.5   # from "Attention is All you Need" paper # not to feed extremely large values to softmax , that could lead small gradients.
    self.qkv = nn.Linear(dim, dim * 3, bias = qkv_bias) # take token embedding, lineary project them and generate query, key, value.
    self.attn_drop = nn.Dropout(attn_p) 
    self.proj = nn.Linear(dim,dim)
    self.proj_drop = nn.Dropout(proj_p)

  def forward(self,x):
    '''
    PARAMETERS:
    ------------
    x : input, (torch.Tensor), Shape (n_samples, n_pathces + 1, dim)   

    RETURNS: (torch.Tensor), Shape (n_samples, n_pathces + 1, dim)
    --------
    **** input and output tensors have same shape . ***
    **** we take 2nd dim as n_patches +1 ; +1 because classtoken is taken as first token in input sequence. . ***
    '''

    n_samples, n_tokens, dim = x.shape
    if dim != self.dim:
      raise ValueError

    qkv = self.qkv(x) # (n_samples, n_pathces + 1, 3 * dim)   
    qkv = qkv.reshape(n_samples, n_tokens, 3, self.n_heads, self.head_dim)  # (n_samples, n_pathces + 1, 3, n_heads, head_dim) # apply LL to 3-d tensor
    qkv = qkv.permute(2,0,3,1,4)  # (3, n_samples, n_heads,n_pathces + 1,head_dim) # apply LL to 3-d tensor
    q, k, v = qkv[0], qkv[1], qkv[2]  # extract key, query and value
    k_t = k.transpose(-2,-1) # (n_samples, n_heads, head_dim, n_patches + 1)
    dp = (q @ k_t) * self.scale # (n_samples, n_heads, n_patches + 1, n_patches + 1)
    attn = dp.softmax(dim=-1) # (n_samples, n_heads, n_patches + 1, n_patches + 1) 
    attn = self.attn_drop(attn) 
    weighted_avg = attn @ v # (n_samples, n_heads,n_pathces + 1,head_dim) 
    weighted_avg = weighted_avg.transpose(1,2) # (n_samples,n_pathces + 1,n_heads,head_dim) 
    weighted_avg = weighted_avg.flatten(2) # (n_samples, n_patches + 1, dim)
    x = self.proj(weighted_avg) # (n_samples, n_patches + 1, dim)
    x = self.proj_drop(x) # (n_samples, n_patches + 1, dim)

    return x


## **3. Multi Layer Perceptron:**

In [4]:
class MLP(nn.Module):
  ''' Multilayer perceptron.

  PARAMETERS :
  ------------
  in_features : (int) no. of input features.
  hidden_features : (int) no. of hidden features.
  out_features : (int) no. of output features.
  p : (float) Dropout probability.

  ATTRIBUTES:
  -----------
  fc1 : (nn.Linear) first linear layer.
  act : (nn.GELU) GELU activation function. # Gaussian Error Linear Unit 
  fc2 : (nn.Linear) second linear layer.
  drop : (nn.Dropout) dropout layer.

  '''

  def __init__(self, in_features, hidden_features, out_features, p = 0.):
    super().__init__()
    self.fc1 = nn.Linear(in_features, hidden_features)
    self.act = nn.GELU()
    self.fc2 = nn.Linear(hidden_features,out_features)
    self.drop = nn.Dropout(p)


  def forward(self, x):
    '''
    PARAMETERS:
    ------------
    x : input, (torch.Tensor), Shape (n_samples, n_pathces + 1, in_features)   

    RETURNS: (torch.Tensor), Shape (n_samples, n_pathces + 1, out_features)
    --------'''

    x = self.fc1(x) # (n_samples, n_pathces + 1, hidden_features)
    x = self.act(x) # (n_samples, n_pathces + 1, hidden_features)
    x = self.drop(x) # (n_samples, n_pathces + 1, hidden_features)
    x = self.fc2(x) # (n_samples, n_pathces + 1, hidden_features)
    x = self.drop(x) # (n_samples, n_pathces + 1, hidden_features)

    return x

## **4. Transformer Block:**

In [5]:
class Block(nn.Module):
  """ Transformer block.
  
  PARAMETERS:
  -----------
  dim : (int) Embedding Dimension.
  n_heads : (int) No. of attention heads.
  mlp_ratio : (float) determines the hidden dm size of 'MLP' module wrt 'dim'.
  qkv_bias : (bool) if True, we include bias in linear proj to the query, key, value projections.
  attn_p : (float) dropout prob applied to query, key, value tensors.
   p : (float) Dropout probability.

  ATTRIBUTES:
  -----------
  norm1, norm2 : (LayerNorm) Layer Normalization.
  attn : (Attention)  Attention module.
  mlp : (MLP) MLP module.
  """ 
  def __init__(self, dim, n_heads, mlp_ratio = 4.0, qkv_bias = True, p = 0., attn_p = 0.):
    super().__init__()
    self.norm1 = nn.LayerNorm(dim, eps=1e-6)
    self.attn = Attention( dim, n_heads = n_heads, qkv_bias = qkv_bias , attn_p = attn_p, proj_p = p)
    self.norm2 = nn.LayerNorm(dim, eps=1e-6)
    hidden_features = int(dim * mlp_ratio)
    self.mlp = MLP(in_features=dim,hidden_features=hidden_features, out_features=dim )

  def forward(self, x):
    '''PARAMETERS:
    ------------
    x : input, (torch.Tensor), Shape (n_samples, n_pathces + 1, dim)   

    RETURNS: (torch.Tensor), Shape (n_samples, n_pathces + 1, dim)
    --------'''
    # both norm is having saperate parameters.
    x = x + self.attn(self.norm1(x))
    x = x + self.mlp(self.norm2(x)) 

    return x

## **5. Vision Transformer:**

In [6]:
class VisionTransformer(nn.Module):
  '''
  PARAMETERS:
  ------------
  img_size : (int) height and width of image (square).
  patch_size : (int) height and width of patch (square).
  in_chans : (int) no. of input channels.
  n_classes: (int) no. of classes.
  embed_dim : (int) Dimensionality of the token/patch embeddings.
  depth : (int) no. of blocks.
  n_heads : (int) no. of attention heads.
  mlp_ratio : (float) determines the hidden dimention of mlp module.
  qkv_bias : (bool) if True, we include bias in linear proj to the query, key, value projections.
  attn_p : (float) dropout prob applied to query, key, value tensors.
  p : (float) Dropout probability.

  ATTRIBUTES:
  -----------
  patch_embed : (PatchEmbed) Instance of 'PatchEmbed' layer.

  cls_token : (nn.Parameter) Learnable parameter that will represnt the first token in sequence.
  It has 'embed_dim' elements.

  pos_emb : (nn.Parameter) Positional embedding of the cls token + all the patches.
  It has '(n_patches + 1) * embed_dim' elements.

  pos_drop : (nn.Dropout) Dropout layer.
  blocks : (nn.ModuleList) List of 'Block' modules.
  norm : (nn.LayerNorm) Layer normalization.
  '''
  def __init__(self, 
               img_size=384,
               patch_size=16,
               in_chans=3,
               n_classes=1000,
               embed_dim=768,
               depth=12,
               n_heads=12,
               mlp_ratio=4,
               qkv_bias=True,
               p=0.,
               attn_p=0.):
      super().__init__()
      self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
      self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim)) # initialize with zeroes
      self.pos_embed = nn.Parameter(torch.zeros(1,1 + self.patch_embed.n_patches,embed_dim))
      self.pos_drop = nn.Dropout(p=p)
      self.blocks = nn.ModuleList(
          [
          Block(
              dim=embed_dim,
              n_heads=n_heads,
              mlp_ratio=mlp_ratio,
              qkv_bias=qkv_bias,
              p=p,
              attn_p=attn_p,
          )
          for _ in range(depth)
          ]
      )  

      self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
      self.head = nn.Linear(embed_dim, n_classes)

  def forward(self, x):
    '''PARAMETERS:
    ------------
    x : input, (torch.Tensor), Shape (n_samples, in_chans, img_size, img_size)   
    RETURNS: logits: (torch.Tensor), logits over all the classes, Shape (n_samples, n_classes)
    --------
    '''
    n_samples = x.shape[0]
    x = self.patch_embed(x)

    cls_token = self.cls_token.expand(n_samples, -1, -1)  # (n_sample, 1, embed_dim)
    x = torch.cat((cls_token, x), dim = 1) # (n_samples, 1 + n_patches, embed_dim) # append cls_token at front in x.
    x = x + self.pos_embed # (n_samples, 1 + n_patches, embed_dim)
    x = self.pos_drop(x)

    for block in self.blocks: # apply all the blocks of encoder on x
      x = block(x) 

    x = self.norm(x) # apply layer normalization

    cls_token_final = x[:,0] # out of all embedding we only select class embedding (CLS token)
    x = self.head(cls_token_final) # classifier # This embedding encode meaning of entire image

    return x


# **Testing part:**

In [7]:
!pip install timm



In [8]:
import numpy as np
import timm
import torch
#from custom import VisionTransformer

In [9]:
# Helper function 
def get_n_param(module):
  return sum(p.numel() for p in module.parameters() if p.requires_grad) # to count no. of elements

def assert_tensors_equal(t1,t2): # check the 2 tensors are equal or not.
  a1,a2 = t1.detach().numpy(), t2.detach().numpy()
  np.testing.assert_allclose(a1,a2)

# load pretrained vision transformer models from timm
model_name = "vit_base_patch16_384" 
model_official = timm.create_model(model_name, pretrained = True)
model_official.eval()
print(type(model_official))

# declare hyperparameters corruspoding to pretrained model
custom_config = {
    "img_size" : 384,
    "patch_size": 16,
    "in_chans": 3,
    # n_classes=1000,
    "embed_dim": 768,
    "depth":12,
    "n_heads": 12,
    "mlp_ratio": 4,
    "qkv_bias": True
    }

# instantiate the implemented custom model and set it to evaluation mode
model_custom = VisionTransformer(**custom_config)
model_custom.eval() 

# iterate through all the parameters of official and custom network
for (n_o,p_o), (n_c,p_c) in zip(model_official.named_parameters(),model_custom.named_parameters()):
  assert p_o.numel() == p_c.numel() # first we check for each parameters no. of elements are equal.
  #print("{n_o} | {n_c}")

  p_c.data[:] = p_o.data

  assert_tensors_equal(p_c.data, p_o.data)

inp = torch.rand(1,3,384,384) # create a random tensor with exact size
# run forward pass
res_c = model_custom(inp)
res_o = model_official(inp)

# Asserts
assert get_n_param(model_custom) == get_n_param(model_official) # check no. of parameter for both models
assert_tensors_equal(res_c, res_o) # take 2 tensors, make sure they are identical

# save custom model
torch.save(model_custom,"TransformerModel.pth")

<class 'timm.models.vision_transformer.VisionTransformer'>


In [10]:
#!git clone https://github.com/lucidrains/vit-pytorch

# **get prediction:**

In [11]:
import numpy as np
from PIL import Image
import torch
import cv2 
k = 10

#imagenet_labels = dict(enumerate(open("classes.txt")))

model = torch.load("TransformerModel.pth")
model.eval()

def prediction(image):
  # new image loading
  img = cv2.resize(cv2.imread(image), (384, 384)) 
  img = (np.array(img) / 128) - 1 # in the range -1, 1
  inp = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).to(torch.float32)
  logits = model(inp)
  probs = torch.nn.functional.softmax(logits, dim=1)
  #print(probs.shape)
  op = (torch.argmax(probs)).item()
  print("the predicted class index is:", op)
  t = []
  with open('classes.txt', encoding='utf8') as f:
      for line in f:
          t.append(line.strip())
  print("the predicted class is:", t[op])

In [12]:
prediction("cat.jpg")

the predicted class index is: 281
the predicted class is: 281: 'tabby, tabby cat',


In [13]:
prediction("dog.png")

the predicted class index is: 151
the predicted class is: 151: 'Chihuahua',


In [14]:
prediction("dog2.png") # adv dog

the predicted class index is: 850
the predicted class is: 850: 'teddy, teddy bear',


## **How linear layer behave, when we have 3-dim or more dim tensors!!**

In [15]:
import torch
module = torch.nn.Linear(10,20)
module

Linear(in_features=10, out_features=20, bias=True)

In [16]:
n_samples = 40 
in_2d = torch.rand(n_samples,10) # last dim must be same as declared in module's linear layer 'in_feature'
module(in_2d).shape

torch.Size([40, 20])

In [17]:
in_3d = torch.rand(n_samples,33,10) # last dim must be same as declared in module's linear layer 'in_feature'
module(in_3d).shape

torch.Size([40, 33, 20])

In [18]:
in_5d = torch.rand(n_samples,2,3,4,10) # last dim must be same as declared in module's linear layer 'in_feature'
module(in_5d).shape

torch.Size([40, 2, 3, 4, 20])

In [19]:
in_7d = torch.rand(n_samples,2,3,4,5,6,10) # last dim must be same as declared in module's linear layer 'in_feature'
module(in_7d).shape

torch.Size([40, 2, 3, 4, 5, 6, 20])

## **Basic Property of Layer Normalization:**
Layernorm normalize data for each sample.

In [20]:
import torch

In [21]:
inp = torch.tensor([[0,4.], [-1,7],[3,5]])

In [22]:
n_samples, n_features = inp.shape

In [23]:
module = torch.nn.LayerNorm(n_features,elementwise_affine=False)

In [24]:
sum(p.numel() for p in module.parameters() if p.requires_grad) # 0 lernable parameter

0

In [25]:
inp.mean(-1),inp.std(-1,unbiased= False)

(tensor([2., 3., 4.]), tensor([2., 4., 1.]))

In [26]:
module(inp).mean(-1),module(inp).std(-1,unbiased= False)

(tensor([0., 0., 0.]), tensor([1.0000, 1.0000, 1.0000]))

In [27]:
module = torch.nn.LayerNorm(n_features,elementwise_affine=True)

In [28]:
sum(p.numel() for p in module.parameters() if p.requires_grad) # 4 lernable parameter

4

4 learnable parameter , contains bias and weight parameter of module.

In [29]:
module.bias, module.weight

(Parameter containing:
 tensor([0., 0.], requires_grad=True), Parameter containing:
 tensor([1., 1.], requires_grad=True))

In [30]:
module(inp).mean(-1),module(inp).std(-1,unbiased= False)

(tensor([0., 0., 0.], grad_fn=<MeanBackward1>),
 tensor([1.0000, 1.0000, 1.0000], grad_fn=<StdBackward0>))

In [31]:
module(torch.rand(n_samples,2,3,4,5,6,n_features)).shape

torch.Size([3, 2, 3, 4, 5, 6, 2])

In [32]:
module(torch.rand(n_samples,2,3,4,5,6,n_features)).mean(-1) # -1 becoz always last dim being norm.

tensor([[[[[[ 0.0000e+00,  6.2585e-07,  5.9605e-08,  2.9802e-08,  1.4901e-07,
              0.0000e+00],
            [ 5.9605e-08, -2.9802e-08, -5.9605e-08, -2.9802e-08,  2.9802e-08,
             -5.9605e-08],
            [ 1.4901e-07, -2.9802e-08, -2.9802e-08, -1.1921e-07, -2.9802e-08,
             -8.9407e-08],
            [-1.4901e-07,  8.9407e-08, -2.0862e-07, -1.7881e-07,  5.9605e-08,
             -2.9802e-08],
            [-1.2517e-06, -2.0862e-07, -5.9605e-08,  1.1921e-07,  2.9802e-07,
              2.9802e-08]],

           [[-2.9802e-07, -6.0201e-06,  5.9605e-08, -2.3842e-07,  0.0000e+00,
             -8.9407e-08],
            [-5.9605e-08,  5.9605e-08, -2.3842e-07,  8.9407e-08, -1.1921e-07,
             -2.9802e-08],
            [ 2.9802e-07,  1.1921e-07, -2.9802e-07,  0.0000e+00, -1.1921e-07,
              1.1921e-07],
            [-2.9802e-08,  2.3842e-07, -2.9802e-08,  0.0000e+00,  4.4703e-07,
              2.9802e-08],
            [-2.0862e-07, -1.7881e-07,  5.9605e-08,  