<a href="https://colab.research.google.com/github/Noors-lab/VIT_components/blob/main/VIT_encoder_block.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import the libraries


In [2]:
import torch
import torch.nn as nn



VitInput: split images into patch tokens, prepend a CLS token and add positional embeddings


In [3]:
class VitInput(nn.Module):
    """Convert a batch of images into the token sequence fed to a ViT encoder.

    The image is cut into non-overlapping patches by a strided convolution
    (kernel size == stride == patch size). Each patch is projected to
    ``embed_dim`` channels. A learnable CLS token is prepended, and a learnable
    positional embedding is added to every token.

    Args:
        img_size: image size the positional embedding is built for. Either an
            int (square image) or an ``(H, W)`` pair.
        patch_size: patch size. Either an int (square patch) or a
            ``(ph, pw)`` pair. Pixels beyond the last full patch along an axis
            are not covered by the convolution.
        in_channels: number of channels of the input images.
        embed_dim: width of every output token.

    Shapes:
        input ``(B, in_channels, H, W)`` ->
        output ``(B, 1 + num_patches, embed_dim)``, where
        ``num_patches = (H // ph) * (W // pw)``.

    Raises:
        ValueError: from :meth:`forward`, if an input yields a different
            number of patches than the positional embedding was built for.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=128):
        super().__init__()
        img_h, img_w = self._pair(img_size)
        patch_h, patch_w = self._pair(patch_size)

        # Patch embedding: kernel == stride makes patches non-overlapping, so
        # each output position of the conv is one projected patch.
        self.patch_embed = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=(patch_h, patch_w),
            stride=(patch_h, patch_w),
        )

        # A stride-p conv over an unpadded input yields floor(size / p)
        # positions per axis, hence the floor division here.
        self.grid_size = (img_h // patch_h, img_w // patch_w)
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        # Learnable CLS token, shared across the batch (expanded in forward).
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))

        # One learnable positional vector per token (CLS + every patch).
        self.pos_embed = nn.Parameter(
            torch.randn(1, self.num_patches + 1, embed_dim)
        )

    @staticmethod
    def _pair(value):
        """Return *value* as an ``(h, w)`` tuple; an int means a square size."""
        if isinstance(value, int):
            return value, value
        h, w = value
        return int(h), int(w)

    def forward(self, x):
        """Embed images ``x`` of shape ``(B, C, H, W)`` into ViT tokens."""
        batch = x.shape[0]

        # (B, C, H, W) -> (B, D, H', W') -> (B, D, N) -> (B, N, D)
        patches = self.patch_embed(x).flatten(2).transpose(1, 2)

        # Fail with a readable message instead of a broadcasting error when
        # the input resolution does not match the positional embedding.
        if patches.shape[1] != self.num_patches:
            raise ValueError(
                f"input yields {patches.shape[1]} patches, but the positional "
                f"embedding was built for {self.num_patches} "
                f"(patch grid {self.grid_size}); check the input image size"
            )

        # Same CLS vector for every sample; expand avoids a copy.
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls_tokens, patches), dim=1)
        return tokens + self.pos_embed





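encoder_block: a pre-norm Transformer encoder block (multi-head self-attention + MLP, each with a residual connection)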

In [7]:
class encoder_block(nn.Module):
    """Pre-norm Transformer encoder block.

    Computes ``x = x + MHA(LN1(x))`` followed by ``x = x + MLP(LN2(x))``.

    Args:
        embed_dim: token width; ``nn.MultiheadAttention`` requires it to be
            divisible by ``num_heads``.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``embed_dim``.
        dropout: dropout probability passed to ``nn.MultiheadAttention``.
            The MLP in this block has no dropout.

    Shapes:
        ``(batch, seq_len, embed_dim)`` in and out. The attention layer is
        built with ``batch_first=True``, so the batch dimension comes first.
    """

    def __init__(self, embed_dim=128, num_heads=4, mlp_ratio=4.0, dropout=0):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        # batch_first=True: inputs are (batch, seq, feature), matching the
        # token layout produced by the patch-embedding stage.
        self.attn = nn.MultiheadAttention(
            embed_dim,
            num_heads,
            dropout=dropout,
            batch_first=True,
        )

        self.norm2 = nn.LayerNorm(embed_dim)
        hidden_dim = int(embed_dim * mlp_ratio)
        # Position-wise feed-forward network: expand, GELU, project back.
        self.MLP = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, x):
        # Self-attention: query, key and value are the same normalized tensor,
        # so LayerNorm is computed once and reused.
        normed = self.norm1(x)
        # The attention map is discarded; need_weights=False skips computing
        # it and lets PyTorch use its optimized attention kernel.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out

        # Feed-forward sub-layer with its own pre-norm and residual.
        x = x + self.MLP(self.norm2(x))
        return x

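Testing: pass a random batch of two 3x32x32 images through VitInput and then encoder_block, and check the output shape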

In [8]:
# Smoke test with the VitInput / encoder_block defaults: a batch of 2 RGB
# 32x32 images becomes 2 sequences of 1 CLS token + (32 // 4) ** 2 = 64
# patch tokens, each of width 128, and the encoder block preserves that shape.
x = torch.randn(2, 3, 32, 32)

vit_input = VitInput()
encoder = encoder_block()

tokens = vit_input(x)
out = encoder(tokens)

print(out.shape)
# Actually verify the shape instead of only printing it.
expected_shape = (2, 1 + (32 // 4) ** 2, 128)
assert out.shape == expected_shape, f"expected {expected_shape}, got {tuple(out.shape)}"

torch.Size([2, 65, 128])
