# Building Vision Transformer

* **References**: 
1. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., & Houlsby, N. (2020). An image is worth 16x16 words: Transformers for image recognition at scale. In arXiv [cs.CV]. http://arxiv.org/abs/2010.11929

2.  Building a vision transformer from scratch in PyTorch. GeeksforGeeks. https://www.geeksforgeeks.org/building-a-vision-transformer-from-scratch-in-pytorch/

3. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, L., & Polosukhin, I. (2017). Attention is all you need. In arXiv [cs.CL]. http://arxiv.org/abs/1706.03762

In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torchinfo
from torchinfo import summary


import numpy as np
import math

# The 4 equations that defines the ViT architecture

* $\mathbf{z}_{0} = \left[\mathbf{x}_{\text{class}};\mathbf{x}_{p}^{(1)}\mathbf{E};\mathbf{x}_{p}^{(2)}\mathbf{E};\dots;\mathbf{x}_{p}^{(2)}\mathbf{E};\right] + \mathbf{E}_{\text{pos}}$, $\mathbf{E} \in \mathbb{R}^{({P^{2}\cdot C})\times D}, \mathbf{E}_{\text{pos} \in \mathbb{R}^{(N+1)\times D}}$

* $\mathbf{z}'_{l} = \text{MultiHeadAttention}(\text{LayerNorm}(\mathbf{z}_{l-1})) + \mathbf{z}_{l-1}$

* $\mathbf{z}_{l} = \text{MultilayerPerceptron}(\text{LayerNorm}(\mathbf{z}'_{l})) + \mathbf{z}'_{l}$

* $\mathbf{y} = \text{LayerNorm}(\mathbf{z}^{0}_{L})$

# ViT Transformer Design

* Model Class uses ``torch.nn.TransformerEncoderLayer`` for the encoder part and ``torch.nn.TransformerEncoder`` to stack multiple layers of transformer encoder

In [None]:
class PatchEmbed(torch.nn.Module):
    """
    Turns a 2D input image into a 1D sequence learnable embedding vector.
    """
    def __init__(self,
                 input_channels: int=3,
                 patch_dim: int=16,
                 embed_dim: int=768):
        super().__init__()

        self.project = nn.Conv2d(
            in_channels=input_channels,
            out_channels = embed_dim,
            kernel_size = patch_dim,
            stride=patch_dim,
            padding = 0
        )
    
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)
    
    def forward(self, x) -> torch.Tensor:
        x = self.project(x)
        x = self.flatten(x)
        x = x.permute(0,2,1)
        return x

    ### Alternative below from the above    
    # def forward(self, x) -> torch.Tensor:
    #     x = self.patch(x)
    #     x = x.flatten(2)
    #     x = x.transpose(1,2)
    #     return x

class VisionTransformer(torch.nn.Module):
    def __init__(self,
                 img_dim=224,
                 number_of_channels=3,
                 patch_size=16,
                 embed_dim=768,
                 dropout=0.2,
                 mlp_size=3072,
                 Lx=12,
                 number_of_heads=8,
                 number_of_classes=10):
        super().__init__()
        self.img_dim = img_dim
        self.patch_size = patch_size

        assert (
            img_dim % patch_size == 0 
        ), f"Minumulto ka na ng damdamamin mo"

        #### Create Patch Embedding
        self.patch_embed = PatchEmbed(
            input_channels=number_of_channels,
            patch_dim=patch_size,
            embed_dim=embed_dim
        )

        ### Create class token
        self.class_token = torch.nn.Parameter(torch.randn(1,1, embed_dim),
                                              requires_grad=True)

        ### Positional Embedding
        num_patches = (img_dim * img_dim) // patch_size**2 ### --> Section 3.1 ViT paper
        self.position_embed = torch.nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))

        ### Create patch + position embedding dropout
        self.embedding_dropout = torch.nn.Dropout(p=dropout)

        ### Create transformer Encoder Layer (single)
        # self.transformer_encoder_layer = torch.nn.TransformerEncoderLayer(
        #     d_model=embed_dim,
        #     nhead=number_of_heads,
        #     dim_feedforward=mlp_size,
        #     dropout=dropout,
        #     activation='gelu',
        #     batch_first=True,
        #     norm_first=True)
        
        ### Stack transformer encoder layers
        self.transformer_encoder = torch.nn.TransformerEncoder(
            encoder_layer = torch.nn.TransformerEncoderLayer(
                d_model=embed_dim,
                nhead=number_of_heads,
                dim_feedforward=mlp_size,
                dropout=dropout,
                activation='gelu',
                batch_first=True,
                norm_first=True),

                num_layers=Lx
                )


        ### Create MLP Head
        self.MLP_head = torch.nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim,
                      out_features=number_of_classes)
        )
    
    def forward(self, x) -> torch.Tensor:
        
        batch_size = x.shape[0]

        num_patches = (self.img_dim * self.img_dim) // self.patch_size**2
        if self.position_embed.shape[1] != (num_patches + 1):
            self.position_embed = torch.nn.Parameter(
                torch.randn(1, num_patches + 1, x.size(-1)),
                requires_grad=True
            )
        x = self.patch_embed(x)
        print(x.shape)
        class_token = self.class_token.expand(batch_size,-1,-1)
        x = torch.cat((class_token, x), dim=1)
        print(f"{x.shape}")

        x = x + self.position_embed[:, :x.size(1), :]

        x = self.embedding_dropout(x)
        x = self.transformer_encoder(x)
        x = self.MLP_head(x[:, 0])
        return x

## Dummy Test for the Patch Embedding



In [47]:
img = torch.randn(size=(1,128,128)).unsqueeze(0)
img3 = torch.randn(size=(2,3,128,128))
img4 = torch.randn(size=(3,3,224,224))
patchify = PatchEmbed(input_channels=3,
                      patch_dim=16,
                      embed_dim=768
                      )
embed = patchify(img3)
print(f"{img3.shape}")
print(f"{embed.shape}")
# print(f"{embed}")


torch.Size([2, 3, 128, 128])
torch.Size([2, 64, 768])


## Testing on Dummy Data

In [67]:
print(img3.shape)
ViT = VisionTransformer(number_of_classes=1000)
ViT(torch.randn(size=(4,3,32,32)))

torch.Size([2, 3, 128, 128])
torch.Size([4, 4, 768])
torch.Size([4, 5, 768])


tensor([[ 0.9096, -0.1652, -1.2042,  ...,  0.1095,  0.0373, -0.6444],
        [ 0.6283,  0.4312, -1.2815,  ..., -0.1587, -0.3894, -0.4380],
        [ 0.3004, -0.2053, -0.8396,  ..., -0.3644, -0.6926, -0.9440],
        [ 0.7554, -0.2503, -1.1799,  ..., -0.2957,  0.3130, -1.1866]],
       grad_fn=<AddmmBackward0>)

## Summary of the Model

In [68]:
summary(model=ViT,
        input=torch.randn(size=(1,3,224,224)))

Layer (type:depth-idx)                                            Param #
VisionTransformer                                                 152,064
├─PatchEmbed: 1-1                                                 --
│    └─Conv2d: 2-1                                                590,592
│    └─Flatten: 2-2                                               --
├─Dropout: 1-2                                                    --
├─TransformerEncoder: 1-3                                         --
│    └─ModuleList: 2-3                                            --
│    │    └─TransformerEncoderLayer: 3-1                          3,903,208
│    │    └─TransformerEncoderLayer: 3-2                          3,903,208
│    │    └─TransformerEncoderLayer: 3-3                          3,903,208
│    │    └─TransformerEncoderLayer: 3-4                          3,903,208
│    │    └─TransformerEncoderLayer: 3-5                          3,903,208
│    │    └─TransformerEncoderLayer: 3-6             

# ViT Transformer using ```nn.MultiheadAttention``` 
* Build a custom transformer encoder layer using ```nn.MultiheadAttention``` 

In [6]:
### Use this Code Snippet or modify if gusto natin gumamit ng Custom Transformer Encoder layer
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, nhead, dropout: float=0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        return self.dropout(self.attention(x, x, x)[0])

class CustomTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, nhead, dim_feedforward, dropout: float=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, nhead)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.Softplus(),
            nn.Dropout(p=dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(p=dropout)
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = x + self.attention(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


class PatchEmbed(torch.nn.Module):
    """
    Turns a 2D input image into a 1D sequence learnable embedding vector.
    """
    def __init__(self,
                 input_channels: int=3,
                 patch_dim: int=16,
                 embed_dim: int=768):
        super().__init__()

        self.project = nn.Conv2d(
            in_channels=input_channels,
            out_channels = embed_dim,
            kernel_size = patch_dim,
            stride=patch_dim,
            padding = 0
        )
    
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)
    
    def forward(self, x) -> torch.Tensor:
        x = self.project(x)
        x = self.flatten(x)
        x = x.permute(0,2,1)
        return x

class VisionTransformer(torch.nn.Module):
    def __init__(self,
                 img_dim=224,
                 number_of_channels=3,
                 patch_size=16,
                 embed_dim=768,
                 dropout=0.2,
                 mlp_size=3072,
                 Lx=12,
                 number_of_heads=8,
                 number_of_classes=10):
        super().__init__()
        self.img_dim = img_dim
        self.patch_size = patch_size

        assert (
            img_dim % patch_size == 0 
        ), f"Minumulto ka na ng damdamamin mo"


        self.patch_embed = PatchEmbed(
            input_channels=number_of_channels,
            patch_dim=patch_size,
            embed_dim=embed_dim
        )

        ### Class Token
        self.class_token = torch.nn.Parameter(torch.randn(1,1, embed_dim),
                                              requires_grad=True)

        ### Positional Embedding
        num_patches = (img_dim * img_dim) // patch_size**2 ### --> Section 3.1 ViT paper
        self.position_embed = torch.nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))

        ### Create patch + position embedding dropout
        self.embedding_dropout = torch.nn.Dropout(p=dropout)
        
        ### Using Custom Transformer Encoder Layer
        self.transformer_encoder = torch.nn.ModuleList([
                CustomTransformerEncoderLayer(
                d_model=embed_dim,
                nhead=number_of_heads,
                dim_feedforward=mlp_size,
                dropout=dropout)
                
                for _ in range(Lx)
        ])


        ### Create MLP Head
        self.MLP_head = torch.nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim, out_features=100),
            nn.Softplus(beta=1.0, threshold=20.0),
            nn.Linear(in_features=100, out_features=number_of_classes)
        )
    
    def forward(self, x) -> torch.Tensor:
        
        num_patches = (self.img_dim * self.img_dim) // self.patch_size**2
        if self.position_embed.shape[1] != (num_patches + 1):
            self.position_embed = torch.nn.Parameter(
                torch.randn(1, num_patches + 1, x.size(-1)),
                requires_grad=True
            )
        x = self.patch_embed(x)
        class_token = self.class_token.expand(x.shape[0],-1,-1)
        x = torch.cat((class_token, x), dim=1)

        x = x + self.position_embed[:, :x.size(1), :]

        x = self.embedding_dropout(x)
        # x = self.transformer_encoder(x)
        for layer in self.transformer_encoder:
            x = layer(x)
        x = self.MLP_head(x[:, 0])
        return x

ViT = VisionTransformer()

summary(model=ViT,
        input=(torch.randn(size=(1,3,224,224))))

Layer (type:depth-idx)                                            Param #
VisionTransformer                                                 152,064
├─PatchEmbed: 1-1                                                 --
│    └─Conv2d: 2-1                                                590,592
│    └─Flatten: 2-2                                               --
├─Dropout: 1-2                                                    --
├─ModuleList: 1-3                                                 --
│    └─CustomTransformerEncoderLayer: 2-3                         --
│    │    └─MultiHeadAttention: 3-1                               2,362,368
│    │    └─Sequential: 3-2                                       4,722,432
│    │    └─LayerNorm: 3-3                                        1,536
│    │    └─LayerNorm: 3-4                                        1,536
│    └─CustomTransformerEncoderLayer: 2-4                         --
│    │    └─MultiHeadAttention: 3-5                               2,

## Dummy Testing

In [20]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tensor1 = torch.randn(size=(3,224,224))
# tensor1 = torch.randn(size=(1,28,28))
print(tensor1.shape)
ViT = VisionTransformer(number_of_classes=10).to(device)
print(f"{ViT(tensor1.unsqueeze(dim=0).to(device))}")

torch.Size([3, 224, 224])
tensor([[-0.2285, -0.0016, -1.0394, -0.5673, -0.5573,  0.9051,  0.1334, -0.3697,
         -0.8526,  0.0453]], device='cuda:0', grad_fn=<AddmmBackward0>)
