# ``ViT`` : Vision Transformer

* Reference:  

In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torchinfo
from torchinfo import summary


import numpy as np
import math

# The 4 equations that defines the ViT architecture

* $\mathbf{z}_{0} = \left[\mathbf{x}_{\text{class}};\mathbf{x}_{p}^{(1)}\mathbf{E};\mathbf{x}_{p}^{(2)}\mathbf{E};\dots;\mathbf{x}_{p}^{(2)}\mathbf{E};\right] + \mathbf{E}_{\text{pos}}$, $\mathbf{E} \in \mathbb{R}^{({P^{2}\cdot C})\times D}, \mathbf{E}_{\text{pos} \in \mathbb{R}^{(N+1)\times D}}$

* $\mathbf{z}'_{l} = \text{MultiHeadAttention}(\text{LayerNorm}(\mathbf{z}_{l-1})) + \mathbf{z}_{l-1}$

* $\mathbf{z}_{l} = \text{MultilayerPerceptron}(\text{LayerNorm}(\mathbf{z}'_{l})) + \mathbf{z}'_{l}$

* $\mathbf{y} = \text{LayerNorm}(\mathbf{z}^{0}_{L})$

# ViT Transformer Design

* Model Class uses ``torch.nn.TransformerEncoderLayer`` for the encoder part and ``torch.nn.TransformerEncoder`` to stack multiple layers of transformer encoder

In [19]:
class PatchEmbed(torch.nn.Module):
    """
    Turns a 2D input image into a 1D sequence learnable embedding vector.
    """
    def __init__(self,
                 input_channels: int=3,
                 patch_dim: int=16,
                 embed_dim: int=768):
        super().__init__()

        self.project = nn.Conv2d(
            in_channels=input_channels,
            out_channels = embed_dim,
            kernel_size = patch_dim,
            stride=patch_dim,
            padding = 0
        )
    
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)
    
    def forward(self, x) -> torch.Tensor:
        x = self.project(x)
        x = self.flatten(x)
        x = x.permute(0,2,1)
        return x

    ### Alternative below from the above    
    # def forward(self, x) -> torch.Tensor:
    #     x = self.patch(x)
    #     x = x.flatten(2)
    #     x = x.transpose(1,2)
    #     return x

class VisionTransformer(torch.nn.Module):
    def __init__(self,
                 img_dim=224,
                 number_of_channels=3,
                 patch_size=16,
                 embed_dim=512,
                 dropout=0.1,
                 mlp_size=3072,
                 number_of_transformerlayers=12,
                 number_of_heads=8,
                 number_of_classes=10):
        super().__init__()
        self.img_dim = img_dim
        self.patch_size = patch_size

        assert img_dim % patch_size == 0

        #### Create Patch Embedding
        self.patch_embed = PatchEmbed(
            input_channels=number_of_channels,
            patch_dim=patch_size,
            embed_dim=embed_dim
        )

        ### Create class token
        self.class_token = torch.nn.Parameter(torch.randn(1,1, embed_dim),
                                              requires_grad=True)

        ### Positional Embedding
        num_patches = (img_dim * img_dim) // patch_size**2 ### --> Section 3.1 ViT paper
        self.position_embed = torch.nn.Parameter(torch.randn(1, num_patches + 1, embed_dim))

        ### Create patch + position embedding dropout
        self.embedding_dropout = torch.nn.Dropout(p=dropout)

        ### Create transformer Encoder Layer (single)
        self.transformer_encoder_layer = torch.nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=number_of_heads,
            dim_feedforward=mlp_size,
            dropout=dropout,
            activation='gelu',
            batch_first=True,
            norm_first=True)
        
        ### Stack transformer encoder layers
        self.transformer_encoder = torch.nn.TransformerEncoder(
            encoder_layer = self.transformer_encoder_layer,
            num_layers=number_of_transformerlayers
        )


        ### Create MLP Head
        self.MLP_head = torch.nn.Sequential(
            nn.LayerNorm(normalized_shape=embed_dim),
            nn.Linear(in_features=embed_dim,
                      out_features=number_of_classes)
        )
    
    def forward(self, x) -> torch.Tensor:
        
        batch_size = x.shape[0]

        num_patches = (self.img_dim * self.img_dim) // self.patch_size**2
        if self.position_embed.shape[1] != (num_patches + 1):
            self.position_embed = torch.nn.Parameter(
                torch.randn(1, num_patches + 1, x.size(-1)),
                requires_grad=True
            )
        x = self.patch_embed(x)
        print(x.shape)
        class_token = self.class_token.expand(batch_size,-1,-1)
        x = torch.cat((class_token, x), dim=1)
        print(f"{x.shape}")

        x = x + self.position_embed[:, :x.size(1), :]

        x = self.embedding_dropout(x)
        x = self.transformer_encoder(x)
        x = self.MLP_head(x[:, 0])
        return x

## Dummy Test for the Patch Embedding



In [34]:
img = torch.randn(size=(1,128,128)).unsqueeze(0)
img3 = torch.randn(size=(2,3,128,128))
img4 = torch.randn(size=(3,3,224,224))
patchify = PatchEmbed(input_channels=3,
                      patch_dim=16,
                      embed_dim=768
                      )
embed = patchify(img3)
print(f"{img3.shape}")
print(f"{embed.shape}")
# print(f"{embed}")


torch.Size([2, 3, 128, 128])
torch.Size([2, 64, 768])


## Testing on Dummy Data

In [31]:
print(img3.shape)
ViT = VisionTransformer(number_of_classes=3, embed_dim=256)
ViT(torch.randn(size=(4,3,50,50)))

torch.Size([2, 3, 128, 128])
torch.Size([4, 9, 256])
torch.Size([4, 10, 256])


tensor([[-0.2980,  0.4789, -0.2882],
        [ 0.0851, -0.1581, -0.6242],
        [ 0.0176, -0.2377, -0.4942],
        [-0.5732,  0.1085, -0.3220]], grad_fn=<AddmmBackward0>)