In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

In [27]:
class PatchEmbeddings(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, embedding_dimension=12):
        super(PatchEmbeddings, self).__init__()
        self.in_channels= in_channels
        self.patch_size= patch_size
        self.embedding_dimension = embedding_dimension

        self.patch = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.embedding_dimension,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding=0,
            bias=True
        )

    def forward(self, x):
        x = self.patch(x)
        x = x.flatten(2).transpose(1,2)
        return x    

In [6]:
embedding = PatchEmbeddings()
summary(embedding, (3,224,224), device='cpu')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 12, 14, 14]           9,228
Total params: 9,228
Trainable params: 9,228
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 0.02
Params size (MB): 0.04
Estimated Total Size (MB): 0.63
----------------------------------------------------------------


In [10]:
x = torch.randn(3,3,224,224)
embedding(x).shape
## this is of the form (Batch, num_patches, embdedding_dimension)
## can be calculated

torch.Size([3, 196, 12])

In [16]:
class TransformerBlock(nn.Module):
    def __init__(self, embedding_dimension=12, num_heads=6, mlp_dimension=256, dropout=0.1, **kwargs):
        super(TransformerBlock, self).__init__()
        
        self.msa_norm = nn.LayerNorm(embedding_dimension)
        self.msa = nn.MultiheadAttention(
            embed_dim=embedding_dimension,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        self.mlp_norm = nn.LayerNorm(embedding_dimension)
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dimension, mlp_dimension),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dimension, embedding_dimension)
        )

    def forward(self,x):
        x = self.msa_norm(x)
        msa_res = x
        x,_ = self.msa(x,x,x)
        x = msa_res + F.dropout(x, p=self.dropout, training=self.training)

        x = self.mlp_norm(x)
        mlp_res = x        
        x = self.mlp(x)
        x = mlp_res + F.dropout(x, p=self.dropout, training=self.training)

        return x

In [31]:
block = TransformerBlock()
summary(block, (197,12), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         LayerNorm-1              [-1, 197, 12]              24
MultiheadAttention-2  [[-1, 197, 12], [-1, 197, 197]]               0
         LayerNorm-3              [-1, 197, 12]              24
            Linear-4             [-1, 197, 256]           3,328
              GELU-5             [-1, 197, 256]               0
           Dropout-6             [-1, 197, 256]               0
            Linear-7              [-1, 197, 12]           3,084
Total params: 6,460
Trainable params: 6,460
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.01
Forward/backward pass size (MB): 698.75
Params size (MB): 0.02
Estimated Total Size (MB): 698.78
----------------------------------------------------------------


In [18]:
x = torch.randn(3,196,12)
block(x).shape

torch.Size([3, 196, 12])

In [22]:
class Encoder(nn.Module):
    def __init__(self, num_layers=8, embedding_dimension=12, **kwargs):
        super(Encoder, self).__init__()
        
        self.layers = nn.ModuleList([TransformerBlock(embedding_dimension, **kwargs) for _ in range(num_layers)])
        
        self.final_norm = nn.LayerNorm(embedding_dimension)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)  
        x = self.final_norm(x)  
        return x

In [30]:
encoder = Encoder()
summary(encoder, (197,12), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         LayerNorm-1              [-1, 197, 12]              24
MultiheadAttention-2  [[-1, 197, 12], [-1, 197, 197]]               0
         LayerNorm-3              [-1, 197, 12]              24
            Linear-4             [-1, 197, 256]           3,328
              GELU-5             [-1, 197, 256]               0
           Dropout-6             [-1, 197, 256]               0
            Linear-7              [-1, 197, 12]           3,084
  TransformerBlock-8              [-1, 197, 12]               0
         LayerNorm-9              [-1, 197, 12]              24
MultiheadAttention-10  [[-1, 197, 12], [-1, 197, 197]]               0
        LayerNorm-11              [-1, 197, 12]              24
           Linear-12             [-1, 197, 256]           3,328
             GELU-13             [-1, 197, 256]               0
          Dropout-14      

In [25]:
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, hidden_dim=1024, num_classes=1000, 
                 num_layers=8, embedding_dim=768,
                 in_channels=3, patch_size=16,
                 **kwargs):
        super(VisionTransformer, self).__init__()

        self.num_patches = (img_size // patch_size) ** 2

        self.patch_embed = PatchEmbeddings(in_channels, patch_size, embedding_dim)

        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))
        
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embedding_dim))  

        self.transformer = Encoder(num_layers,embedding_dim, **kwargs)

        ## From the paper, it is mentioned for pre-training, we need an MLP with a hidden layer
        ## For finetuning, we use a single linear layer
        
        self.mlp_head = nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim), 
            nn.GELU(),  
            nn.Linear(hidden_dim, num_classes)  
        )

        # self.mlp_head = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        B = x.shape[0]

        x = self.patch_embed(x)  # (B, num_patches, D)

        ## -1 means no change along that dimension
        cls_token = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
        x = torch.cat((cls_token, x), dim=1)  # (B, num_patches+1, D)

        x = x + self.pos_embedding[:, :x.shape[1], :]

        x = self.transformer(x)

        # Use only CLS token for final classification
        x = x[:, 0]  # (B, D)
        return self.mlp_head(x)  # (B, num_classes)

In [29]:
model = VisionTransformer()
summary(model, (3,224,224), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 768, 14, 14]         590,592
   PatchEmbeddings-2             [-1, 196, 768]               0
         LayerNorm-3             [-1, 197, 768]           1,536
MultiheadAttention-4  [[-1, 197, 768], [-1, 197, 197]]               0
         LayerNorm-5             [-1, 197, 768]           1,536
            Linear-6             [-1, 197, 256]         196,864
              GELU-7             [-1, 197, 256]               0
           Dropout-8             [-1, 197, 256]               0
            Linear-9             [-1, 197, 768]         197,376
 TransformerBlock-10             [-1, 197, 768]               0
        LayerNorm-11             [-1, 197, 768]           1,536
MultiheadAttention-12  [[-1, 197, 768], [-1, 197, 197]]               0
        LayerNorm-13             [-1, 197, 768]           1,536
           Linear-14    

In [28]:

x = torch.randn(2, 3, 224, 224)  
output = model(x)
print(output.shape)  

torch.Size([2, 1000])
