In [241]:
import torch
from torch import nn

In [242]:
class PatchEmbedding(nn.Module):
  """Turn a batch of images into a sequence of patch-embedding tokens.

  A single Conv2d with ``out_channels=embedding_dim`` projects every
  ``kernel x kernel`` window (stepping by ``stride``) to one embedding vector.
  The two spatial axes of the conv output are then merged into one token axis.

  Args:
    in_channels: number of channels in the input images.
    stride: convolution stride; equal to ``kernel`` for non-overlapping patches.
    kernel: convolution kernel size, i.e. the patch side length.
    embedding_dim: width of each output token.

  Shapes:
    input  ``(B, in_channels, H, W)``
    output ``(B, H' * W', embedding_dim)`` where ``H'``/``W'`` are the conv
    output height/width (``H // kernel`` and ``W // kernel`` when
    ``stride == kernel``, since padding is 0).
  """

  def __init__(self,
               in_channels:int=3,
               stride:int=16,
               kernel:int=16,
               embedding_dim:int=768):
    super().__init__()

    # Patch projection: no padding, so each output position covers exactly one window.
    self.conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=embedding_dim,
        kernel_size=kernel,
        stride=stride,
        padding=0,
    )

    # Collapse the (H', W') grid into a single sequence axis.
    self.flatten = nn.Flatten(start_dim=2, end_dim=3)

  def forward(self, x):
    feature_map = self.conv(x)            # (B, D, H', W')
    tokens = self.flatten(feature_map)    # (B, D, H'*W')
    return tokens.transpose(1, 2)         # (B, H'*W', D): tokens first, features last

In [243]:
# Image / patch hyperparameters (224x224 RGB images cut into 16x16 patches).
patch_size = 16
c = 3
height = 224
width = 224
print(f"patch size :{patch_size}, C :{c} , height :{height} , width :{width}")

# A flattened patch holds patch_size * patch_size * c values; that count is used
# as the token embedding width.
embedding_dim = (patch_size**2)*c
print(f"embedding dimension :{embedding_dim}")

# Non-overlapping patches only tile the image exactly if the sides divide evenly;
# otherwise the stride-`patch_size` conv silently drops the leftover pixels.
assert height % patch_size == 0 and width % patch_size == 0, "image size must be divisible by patch_size"
num_patches = (height // patch_size) * (width // patch_size)

# Build the patch embedding from the variables above instead of repeating literals,
# so changing a hyperparameter here keeps everything consistent.
patch = PatchEmbedding(in_channels=c,
                       stride=patch_size,
                       kernel=patch_size,
                       embedding_dim=embedding_dim)

patch size :16, C :3 , height :224 , widht :224
embedding dimension :768


In [244]:
class MultiheadSelfAttentionBlock(nn.Module):
  """Pre-norm multi-head self-attention sub-layer.

  Applies LayerNorm, then ``nn.MultiheadAttention`` with the normalised tokens
  used as query, key and value. Only the attention output is returned; any
  residual connection is left to the caller.

  Args:
    embedding_dim: token width ``D``; ``nn.MultiheadAttention`` requires it to
      be divisible by ``num_heads``.
    num_heads: number of attention heads.
    attn_dropout: dropout probability applied to the attention weights.

  Shapes:
    input and output ``(B, S, D)`` (the attention layer uses ``batch_first=True``).
  """

  def __init__(self,
               embedding_dim:int=768,
               num_heads:int=12,
               attn_dropout:float=0.0):

    super().__init__()

    self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

    self.multi_head_attention = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                      num_heads=num_heads,
                                                      dropout=attn_dropout,
                                                      batch_first=True)

  def forward(self,x):
    x = self.layer_norm(x)
    # Self-attention: the same normalised tokens serve as query, key and value.
    # need_weights=False because the attention map is not used here.
    x, _ = self.multi_head_attention(query=x,
                                     key=x,
                                     value=x,
                                     need_weights=False)
    return x

In [245]:
class MLPBlock(nn.Module):
  """Pre-norm feed-forward sub-layer.

  Computes LayerNorm -> Linear(D, mlp_dim) -> GELU -> Dropout
  -> Linear(mlp_dim, D) -> Dropout. Only the MLP output is returned; any
  residual connection is added by the caller.

  Args:
    embedding_dim: token width ``D`` (both input and output size).
    mlp_dim: hidden width of the MLP.
    dropout: dropout probability after the activation and after the output projection.
  """

  def __init__(self,
               embedding_dim:int=768,
               mlp_dim:int=3072,
               dropout:float=0.1
               ):
    super().__init__()

    self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

    # Positions 0..4 in this list become the Sequential indices, which fix the
    # state_dict keys (mlp.0.* for the expansion, mlp.3.* for the projection).
    layers = [
        nn.Linear(embedding_dim, mlp_dim),
        nn.GELU(),
        nn.Dropout(dropout),
        nn.Linear(mlp_dim, embedding_dim),
        nn.Dropout(dropout),
    ]
    self.mlp = nn.Sequential(*layers)

  def forward(self, x):
    return self.mlp(self.layer_norm(x))

In [246]:
!pip install -q torchinfo
from torchinfo import summary

In [247]:
class TransformerEncoderBlock(nn.Module):
  """One pre-norm Transformer encoder layer.

  Computes::

      x' = MSA(LN(x)) + x
      y  = MLP(LN(x')) + x'

  The LayerNorms live inside ``MultiheadSelfAttentionBlock`` and ``MLPBlock``;
  this class wires them together and adds the two residual (skip) connections.

  Args:
    embedding_dim: token width ``D``.
    num_heads: number of attention heads.
    mlp_heads: hidden width of the MLP (despite the name, not a head count).
    attn_dropout: dropout on the attention weights.
    dropout: dropout used inside the MLP.

  Shapes:
    input and output ``(B, S, D)``.
  """

  def __init__(self,
               embedding_dim:int=768,
               num_heads:int=12,
               mlp_heads:int=3072,
               attn_dropout:float=0.0,
               dropout:float=0.1
               ):
    super().__init__()

    self.msa_block = MultiheadSelfAttentionBlock(embedding_dim=embedding_dim,
                                                 num_heads=num_heads,
                                                 attn_dropout=attn_dropout)
    self.mlp_block = MLPBlock(embedding_dim=embedding_dim,
                              mlp_dim=mlp_heads,
                              dropout=dropout)

  def forward(self,x):
    # Residual connections: each sub-layer adds a correction to its input rather
    # than replacing it, which keeps a direct signal/gradient path through the stack.
    x = self.msa_block(x) + x
    x = self.mlp_block(x) + x
    return x

In [248]:
class ViT(nn.Module):
  """Vision Transformer image classifier.

  Pipeline for a batch of ``B`` images:
    1. ``PatchEmbedding``: ``(B, in_channels, H, W) -> (B, num_patches, D)``.
    2. Prepend one learnable class token per image: ``(B, num_patches + 1, D)``.
    3. Add learnable position embeddings, then dropout.
    4. ``num_layers`` stacked ``TransformerEncoderBlock``s (shape preserved).
    5. LayerNorm + Linear on the class-token output only: ``(B, num_classes)``.

  Args:
    N: number of patches per image. It must equal
      ``(H // patch_size) * (W // patch_size)`` for the input size used,
      otherwise adding the position embeddings raises a shape error.
    embedding_dim: token width ``D``.
    num_heads: attention heads per encoder block.
    mlp_heads: hidden width of each encoder block's MLP.
    attn_dropout: attention-weight dropout inside each block.
    dropout: dropout after the embeddings and inside each MLP.
    patch_size: side length (and stride) of the square patches.
    in_channels: number of image channels.
    num_layers: number of encoder blocks.
    num_classes: size of the classifier output.

  Returns (from ``forward``):
    unnormalised class logits of shape ``(B, num_classes)``.
  """

  def __init__(self,
               N:int=196,
               embedding_dim:int=768,
               num_heads:int=12,
               mlp_heads:int=3072,
               attn_dropout:float=0.0,
               dropout:float=0.1,
               patch_size=16,
               in_channels:int=3,
               num_layers:int=12,
               num_classes:int=1000):

    super().__init__()

    self.patch_embedding = PatchEmbedding(in_channels=in_channels,
                                          embedding_dim=embedding_dim,
                                          stride=patch_size,
                                          kernel=patch_size)

    # A single learnable class token of shape (1, 1, D); forward expands it to
    # whatever batch size arrives, so the model is not tied to one batch size.
    self.class_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

    # One learnable position vector per token: N patches plus the class token.
    self.position_embedding = nn.Parameter(torch.randn(1, N + 1, embedding_dim))

    self.embedding_dropout = nn.Dropout(p=dropout)

    # num_layers independently-parameterised encoder blocks. Built in a loop,
    # but still an nn.Sequential indexed 0..num_layers-1, so state_dict keys
    # match the hand-written version.
    self.transformer_encoder = nn.Sequential(*[
        TransformerEncoderBlock(embedding_dim=embedding_dim,
                                num_heads=num_heads,
                                mlp_heads=mlp_heads,
                                attn_dropout=attn_dropout,
                                dropout=dropout)
        for _ in range(num_layers)
    ])

    # Classification head, applied to the class-token representation only.
    self.classifier = nn.Sequential(
        nn.LayerNorm(normalized_shape=embedding_dim),
        nn.Linear(in_features=embedding_dim, out_features=num_classes),
    )

  def forward(self, x):
    x = self.patch_embedding(x)                                   # (B, N, D)

    class_tokens = self.class_token.expand(x.shape[0], -1, -1)    # (B, 1, D)
    x = torch.cat((class_tokens, x), dim=1)                       # (B, N+1, D)

    x = x + self.position_embedding
    x = self.embedding_dropout(x)

    x = self.transformer_encoder(x)                               # (B, N+1, D)

    # Only the class token (index 0) summarises the image for classification.
    return self.classifier(x[:, 0])                               # (B, num_classes)

In [249]:
# Build the model from the hyperparameters defined earlier. N (the number of
# patches) is derived from the image and patch sizes instead of hard-coded, so it
# stays consistent with them (224 / 16 = 14 patches per side -> 196).
t = ViT(N=(height // patch_size) * (width // patch_size),
        embedding_dim=embedding_dim,
        num_heads=12,
        mlp_heads=3072,
        attn_dropout=0.0,
        dropout=0.1,
        patch_size=patch_size)

In [250]:
# Layer-by-layer parameter counts and trainability. No input_size is passed, so
# the table lists parameters only, without per-layer input/output shapes.
summary(
    model=t,
    col_names=["num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                                                Param #              Trainable
ViT (ViT)                                                              152,064              True
├─PatchEmbedding (patch_embedding)                                     --                   True
│    └─Conv2d (conv)                                                   590,592              True
│    └─Flatten (flatten)                                               --                   --
├─Dropout (embedding_dropout)                                          --                   --
├─Sequential (transformer_encoder)                                     --                   True
│    └─TransformerEncoderBlock (0)                                     --                   True
│    │    └─MultiheadSelfAttentionBlock (msa_block)                    2,363,904            True
│    │    └─MLPBlock (mlp_block)                                       4,723,968            True
│    └─TransformerEncoderBloc

In [251]:
# Optimiser hyperparameters: lr 3e-3, betas (0.9, 0.999), weight decay 0.3.
# AdamW applies weight decay directly to the weights (decoupled). Adam's
# `weight_decay` instead adds an L2 term to the gradient, which Adam then rescales
# per-parameter, so a large value like 0.3 behaves very differently there.
optimiser = torch.optim.AdamW(params=t.parameters(),
                              lr=3e-3,
                              betas=(0.9, 0.999),
                              weight_decay=0.3)

# Multi-class loss; ViT's classifier ends in a Linear layer, so it receives raw logits.
loss_fn = nn.CrossEntropyLoss()