In [None]:
import torch
from torch import nn

In [None]:
class Patch_Embedding(nn.Module):
  """Turn an image into a sequence of patch embeddings.

  A single ``nn.Conv2d`` with ``kernel_size=patch_size`` and the given
  ``stride`` projects the image; the two spatial axes of the result are then
  flattened and moved in front of the channel axis.

  Args:
    in_channels: Number of channels of the input image.
    out_channels: Size of each patch embedding (conv output channels).
    stride: Step between consecutive patches. Patches do not overlap when
      this equals ``patch_size``.
    patch_size: Side length of the square convolution kernel.
  """

  def __init__(self,
               in_channels:int=3,
               out_channels:int=768,
               stride:int=16,
               patch_size:int=16):
    super().__init__()
    self.patcher = nn.Conv2d(in_channels=in_channels,
                             out_channels=out_channels,
                             kernel_size=patch_size,
                             stride=stride)
    # Collapse the (height, width) grid of patches into one sequence axis.
    self.flatten = nn.Flatten(start_dim=2,
                              end_dim=3)

  def forward(self,x):
    """Embed one image or a batch of images.

    Args:
      x: Tensor of shape ``(C, H, W)`` (single image) or ``(B, C, H, W)``.

    Returns:
      Tensor of shape ``(B, num_patches, out_channels)``; ``B`` is 1 for a
      single unbatched image.

    Raises:
      ValueError: If ``x`` is neither 3-D nor 4-D.
    """
    if x.dim() == 3:
      # Single image: add the batch axis the rest of the pipeline expects.
      x = x.unsqueeze(0)
    elif x.dim() != 4:
      raise ValueError(
          f"Expected input of shape (C, H, W) or (B, C, H, W), got {tuple(x.shape)}")

    x_patched = self.patcher(x)             # (B, out_channels, H', W')
    x_flattened = self.flatten(x_patched)   # (B, out_channels, H' * W')

    # Put the patch axis before the embedding axis.
    return x_flattened.permute(0,2,1)



In [None]:
class MultiHeadedSelfAttention(nn.Module):
  """Layer-normalised multi-head self-attention over a token sequence.

  Args:
    embedding_dim: Size of each token embedding.
    msa_heads: Number of attention heads.
    attn_dropout: Dropout probability passed to ``nn.MultiheadAttention``.
  """

  def __init__(self,
               embedding_dim:int=768,
               msa_heads:int=12,
               attn_dropout:float=0.2
               ):

    super().__init__()

    # NOTE: the attribute name is misspelled ("lyaer"); it is kept as-is so
    # that state_dict keys of previously saved checkpoints still match.
    self.lyaer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

    # batch_first=True: inputs are (batch, tokens, embedding_dim). Without it
    # nn.MultiheadAttention treats axis 0 as the sequence axis and would
    # attend across the batch instead of across tokens.
    self.msa = nn.MultiheadAttention(
                                    embed_dim=embedding_dim,
                                    num_heads=msa_heads,
                                    dropout=attn_dropout,
                                    batch_first=True
                                    )

  def forward(self,x):
    """Apply LayerNorm followed by self-attention.

    Args:
      x: Tensor of shape ``(batch, tokens, embedding_dim)``.

    Returns:
      Attention output with the same shape as ``x``.
    """
    x = self.lyaer_norm(x)
    # The attention weights are never used, so skip computing them.
    a,_ = self.msa(query=x,key=x,value=x,need_weights=False)

    return a

In [None]:
class MLP(nn.Module):
  """Layer-normalised two-layer feed-forward block with GELU activation.

  Args:
    embedding_dim: Size of the input and output features.
    mlp_heads: Width of the hidden layer (despite the name, this is the
      ``out_features`` of the first linear layer, not a head count).
    dropout: Probability used by both dropout layers. Defaults to 0.2, the
      value that was previously hard-coded.
  """

  def __init__(self,
               embedding_dim:int=768,
               mlp_heads:int=3072,
               dropout:float=0.2):

    super().__init__()
    self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)
    # Layer order and the attribute name `MlP` are kept unchanged so the
    # state_dict keys (MlP.1.*, MlP.3.*) stay compatible with saved weights.
    self.MlP = nn.Sequential(
        nn.Dropout(p=dropout),
        nn.Linear(in_features=embedding_dim,out_features=mlp_heads),
        nn.GELU(),
        nn.Linear(in_features=mlp_heads,out_features=embedding_dim),
        nn.Dropout(p=dropout)
    )

  def forward(self,x):
    """Normalise ``x`` and pass it through the feed-forward stack.

    Args:
      x: Tensor whose last dimension is ``embedding_dim``.

    Returns:
      Tensor with the same shape as ``x``.
    """
    x = self.layer_norm(x)
    x = self.MlP(x)

    return x

In [None]:
class TransformerEncoder(nn.Module):
  """One encoder layer: self-attention and MLP sub-blocks, each with a residual add.

  Args:
    mlp_heads: Hidden width forwarded to ``MLP``.
    msa_heads: Number of attention heads forwarded to
      ``MultiHeadedSelfAttention``.
    embedding_dim: Size of each token embedding.
    attn_dropout: Dropout probability forwarded to the attention sub-block.
    dropout: Dropout probability applied to the ``nn.Dropout`` layers of the
      MLP sub-block.
  """
  def __init__(self,
               mlp_heads:int=3072,
               msa_heads:int=12,
               embedding_dim:int=768,
               attn_dropout:float=0.2,
               dropout:float=0.1):
    super().__init__()

    self.msa = MultiHeadedSelfAttention(embedding_dim,msa_heads,attn_dropout)
    self.mlp = MLP(embedding_dim,mlp_heads)

    # `dropout` used to be accepted and silently ignored. Apply it to every
    # nn.Dropout inside the MLP sub-block so the argument takes effect.
    # NOTE(review): this makes the effective MLP dropout follow `dropout`
    # (default 0.1) instead of MLP's own default — confirm this is intended.
    for module in self.mlp.modules():
      if isinstance(module, nn.Dropout):
        module.p = dropout

  def forward(self,x):
    """Run both sub-blocks, adding each sub-block's input back to its output.

    Args:
      x: Token tensor accepted by the attention and MLP sub-blocks.

    Returns:
      Tensor with the same shape as ``x``.
    """
    x = self.msa(x) + x   # residual connection around attention
    x = self.mlp(x) + x   # residual connection around the MLP

    return x

In [None]:
class ViT(nn.Module):
  """Vision Transformer classifier.

  The image is split into patch embeddings, a learnable class token is
  prepended, learnable position embeddings are added, the sequence is passed
  through a stack of ``TransformerEncoder`` layers, and the class-token output
  is fed to a LayerNorm + Linear classification head.

  Args:
    img_size: Side length of the (square) input image.
    in_channels: Number of input image channels.
    embedding_dim: Size of each token embedding.
    patch_size: Side length of each square patch (also used as the stride).
    num_transformer_layers: Number of stacked encoder layers.
    mlp_heads: Hidden width of each encoder layer's MLP.
    msa_heads: Number of attention heads per encoder layer.
    attn_dropout: Dropout probability forwarded to each encoder layer's
      attention sub-block.
    dropout: Dropout value forwarded to each encoder layer.
    num_classes: Number of output logits. Defaults to 1000, the value that
      was previously hard-coded.

  Raises:
    ValueError: If ``patch_size`` is not positive or exceeds ``img_size``.
  """
  def __init__(self,
               img_size:int=224,
               in_channels:int=3,
               embedding_dim:int=768,
               patch_size:int=16,
               num_transformer_layers:int=12,
               mlp_heads:int=3072,
               msa_heads:int=12,
               attn_dropout:float=0.2,
               dropout:float=0.1,
               num_classes:int=1000):

    super().__init__()
    if patch_size <= 0 or img_size < patch_size:
      raise ValueError(
          f"patch_size must satisfy 0 < patch_size <= img_size, "
          f"got patch_size={patch_size}, img_size={img_size}")

    # Patches per side is floor(img_size / patch_size), which is what a conv
    # with kernel == stride == patch_size produces. The former expression
    # (img_size*img_size)//patch_size**2 over-counts when img_size is not a
    # multiple of patch_size (e.g. 225/16 -> 197 instead of 14*14 = 196).
    self.num_patches = (img_size // patch_size) ** 2
    self.patch_embedding = Patch_Embedding(in_channels=in_channels,
                                           out_channels=embedding_dim,
                                           stride=patch_size,
                                           patch_size=patch_size)
    # Learnable class token and position embeddings, initialised from
    # a uniform distribution on [0, 1) (torch.rand).
    self.cls_token = nn.Parameter(torch.rand(1,1,embedding_dim),requires_grad=True)
    self.pos_embedding = nn.Parameter(torch.rand(1,self.num_patches+1,embedding_dim),requires_grad=True)

    self.encoder = nn.Sequential(*[TransformerEncoder(
                                                      mlp_heads,
                                                      msa_heads,
                                                      embedding_dim,
                                                      attn_dropout,
                                                      dropout) for _ in range(num_transformer_layers)])

    self.classifier = nn.Sequential(
                                    nn.LayerNorm(normalized_shape=embedding_dim),
                                    nn.Linear(
                                              in_features=embedding_dim,
                                              out_features=num_classes)
                                    )

  def forward(self,x):
    """Compute class logits for the input image(s).

    Args:
      x: Image tensor in whatever layout ``Patch_Embedding`` accepts.

    Returns:
      Tensor of shape ``(batch, num_classes)``.
    """
    x = self.patch_embedding(x)                          # (batch, num_patches, dim)
    batch_size = x.shape[0]
    # Share the single learned class token across the batch (no copy).
    cls_token = self.cls_token.expand(batch_size,-1,-1)
    x = torch.cat((cls_token,x),dim=1)                   # (batch, num_patches + 1, dim)

    # Position embeddings broadcast over the batch axis.
    x = x + self.pos_embedding

    x = self.encoder(x)

    # Classify from the class-token position only.
    return self.classifier(x[:,0])




In [None]:
# Build the model instance used by the summary cell. Every value below
# equals the corresponding default of the ViT constructor.
a = ViT(
    img_size=224,
    in_channels=3,
    embedding_dim=768,
    patch_size=16,
    num_transformer_layers=12,
    mlp_heads=3072,
    msa_heads=12,
    attn_dropout=0.2,
    dropout=0.1,
)

In [None]:
!pip install -q torchinfo

In [None]:
# Imported here rather than at the top because torchinfo is installed by the
# preceding cell.
from torchinfo import summary

# Per-layer parameter counts; "var_names" adds each layer's attribute name.
summary(a, row_settings=["var_names"])

Layer (type (var_name))                                                Param #
ViT (ViT)                                                              152,064
├─Patch_Embedding (patch_embedding)                                    --
│    └─Conv2d (patcher)                                                590,592
│    └─Flatten (flatten)                                               --
├─Sequential (encoder)                                                 --
│    └─TransformerEncoder (0)                                          --
│    │    └─MultiHeadedSelfAttention (msa)                             2,363,904
│    │    └─MLP (mlp)                                                  4,723,968
│    └─TransformerEncoder (1)                                          --
│    │    └─MultiHeadedSelfAttention (msa)                             2,363,904
│    │    └─MLP (mlp)                                                  4,723,968
│    └─TransformerEncoder (2)                                        