# **Vision Transformer**

What's the idea behind this?
An image can be split into patches. The sequence of patches can then be fed into multi-head self-attention, which treats it the same way a Transformer treats a sequence of words.

### Figure 1

![](https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/08-vit-paper-figure-1-architecture-overview.png)

* **Embedding** = a learnable representation (it starts as random numbers and improves during training)
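To make "starts with random numbers and improves over time" concrete, here is a minimal sketch. The all-ones target is a hypothetical toy objective chosen only for illustration. A single SGD step moves a randomly initialised `nn.Parameter` towards the target:

```python
import torch
from torch import nn

torch.manual_seed(0)
embedding = nn.Parameter(torch.randn(1, 768))     # starts as random numbers
before = embedding.detach().clone()

# Hypothetical toy objective (illustration only): pull the embedding towards all ones
target = torch.ones(1, 768)
optimizer = torch.optim.SGD([embedding], lr=0.5)

loss_before = ((embedding - target) ** 2).mean()
loss_before.backward()
optimizer.step()                                   # the "improve over time" part
loss_after = ((embedding - target) ** 2).mean()

print(embedding.requires_grad)                     # True
print(loss_after.item() < loss_before.item())      # True
```

In ViT the objective is the classification loss rather than this toy target, but the mechanism is the same.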


### Four equations
![](https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/08-vit-paper-four-equations.png)
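For reference, the four equations in the image can be transcribed as follows. Here N = 196 is the number of patches, P = 16 is the patch size, C = 3 is the number of channels, D = 768 is the hidden size and L = 12 is the number of layers for ViT-Base:

$$
\begin{aligned}
\mathbf{z}_0 &= [\,\mathbf{x}_\text{class};\ \mathbf{x}_p^1\mathbf{E};\ \mathbf{x}_p^2\mathbf{E};\ \cdots;\ \mathbf{x}_p^N\mathbf{E}\,] + \mathbf{E}_{pos}, && \mathbf{E}\in\mathbb{R}^{(P^2\cdot C)\times D},\ \mathbf{E}_{pos}\in\mathbb{R}^{(N+1)\times D} && (1)\\
\mathbf{z}'_\ell &= \mathrm{MSA}(\mathrm{LN}(\mathbf{z}_{\ell-1})) + \mathbf{z}_{\ell-1}, && \ell = 1\ldots L && (2)\\
\mathbf{z}_\ell &= \mathrm{MLP}(\mathrm{LN}(\mathbf{z}'_\ell)) + \mathbf{z}'_\ell, && \ell = 1\ldots L && (3)\\
\mathbf{y} &= \mathrm{LN}(\mathbf{z}_L^0) && && (4)
\end{aligned}
$$

The equations map onto the code as follows:
* (1) is the patch embedding together with the class token and the position embedding.
* (2) and (3) together form one encoder block.
* (4) is the LayerNorm applied to the class token before the classifier.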



### Table 1
![](https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/images/08-vit-paper-table-1.png)

Here are all the pipelines and formulas that we need. We will implement the ViT-Base model.
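The model sizes in Table 1 can be kept as a small config object. This is a sketch; the `ViTConfig` name is hypothetical and not from the paper. The loop prints each model's per-head dimension (`hidden_size // heads`) and its MLP expansion ratio:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class ViTConfig:            # hypothetical helper, for illustration
    layers: int             # number of stacked encoder blocks
    hidden_size: int        # D, the embedding size of every token
    mlp_size: int           # width of the hidden layer inside each MLP block
    heads: int              # number of attention heads

# Values as listed in Table 1 of the paper
VIT_BASE = ViTConfig(layers=12, hidden_size=768, mlp_size=3072, heads=12)
VIT_LARGE = ViTConfig(layers=24, hidden_size=1024, mlp_size=4096, heads=16)
VIT_HUGE = ViTConfig(layers=32, hidden_size=1280, mlp_size=5120, heads=16)

for name, cfg in [("Base", VIT_BASE), ("Large", VIT_LARGE), ("Huge", VIT_HUGE)]:
    # per-head dimension and MLP expansion ratio
    print(name, cfg.hidden_size // cfg.heads, cfg.mlp_size // cfg.hidden_size)
# Base 64 4
# Large 64 4
# Huge 80 4
```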

As usual, we need to know what the input is and what the output of each block looks like. The input is an image, so as a tensor it has shape (batch, 3, 224, 224). Here 3 is the number of RGB color channels, and (224, 224) is the input resolution used for ViT-Base.

Let's first split an input image into patches. When doing that, you'll find that the figure is somewhat misleading. We do not actually cut one image into 9 separate images as the pipeline shows; that picture only illustrates one step of the process.

We can simply use a conv2d layer to split the image into patches. With kernel_size = stride = patch_size, each kernel position covers exactly one non-overlapping patch. The patch size is 16 for ViT-Base and the image resolution is 224, so we get 224 * 224 / (16 * 16) = 14 * 14 = 196 patches.
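The patch-count arithmetic above generalises to other resolutions and patch sizes. In this small sketch, the helper `num_patches` is our own name and assumes square images whose side is divisible by the patch size:

```python
def num_patches(image_size, patch_size):
    # Assumes a square image whose side is a multiple of the patch size
    assert image_size % patch_size == 0, "image side must be a multiple of the patch size"
    per_side = image_size // patch_size
    return per_side * per_side

print(num_patches(224, 16))   # 196 = 14 * 14 (ViT-B/16)
print(num_patches(224, 32))   # 49  = 7 * 7   (ViT-B/32)
print(num_patches(384, 16))   # 576 = 24 * 24 (higher resolution)
print(3 * 16 * 16)            # 768 raw values in one RGB 16x16 patch
```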

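So after the input image (Batch, 3, 224, 224) passes through the conv2d, the output has shape (Batch, 768, 14, 14). Here 768 is the number of conv filters, chosen to equal the hidden size D of ViT-Base in Table 1. How should we understand this? Each patch has shape (3, 16, 16), i.e. 3 * 16 * 16 = 768 raw values. The conv applies the same learned linear map to every patch and turns those values into a 768-dimensional embedding vector. We get one such vector for each of the 14 * 14 patch positions. The match 3 * 16 * 16 = 768 = D is a coincidence of ViT-Base's settings; ViT-Large, for example, uses D = 1024 with the same 16 * 16 patches.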
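The claim that the conv applies one linear map to every flattened patch can be checked numerically. This sketch uses random weights and an arbitrary seed:
1. Cut the image into 16x16 patches by hand with `reshape`/`permute`.
2. Multiply the patches by the flattened conv kernel.
3. Compare the result with the conv output.

```python
import torch
from torch import nn

torch.manual_seed(0)
patch_size, channels, embed_size = 16, 3, 768
conv = nn.Conv2d(channels, embed_size, kernel_size=patch_size, stride=patch_size)

image = torch.randn(1, channels, 224, 224)
conv_out = conv(image).flatten(2).transpose(1, 2)           # (1, 196, 768)

# Cut the image into non-overlapping patches by hand:
# (1, 3, 224, 224) -> (1, 3, 14, 16, 14, 16) -> (1, 14, 14, 3, 16, 16) -> (1, 196, 768)
patches = image.reshape(1, channels, 14, patch_size, 14, patch_size)
patches = patches.permute(0, 2, 4, 1, 3, 5).reshape(1, 196, channels * patch_size * patch_size)

# The flattened conv kernel is the weight matrix of an equivalent linear layer
weight = conv.weight.reshape(embed_size, -1)                 # (768, 768)
linear_out = patches @ weight.T + conv.bias                  # (1, 196, 768)

print(torch.allclose(conv_out, linear_out, atol=1e-4))       # True
```

Using a conv is simply a convenient and fast way to perform "flatten each patch, then apply a linear layer".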

The next step is to flatten the last two dimensions (14, 14) of the output, so that it has shape (Batch, 768, 196).

Finally, we swap the last two dimensions (a transpose of dimensions 1 and 2) to get (Batch, 196, 768). This means we have 196 patches, each represented by a 768-dimensional vector.
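Here is a quick sanity check of the flatten-then-transpose step with the ViT-Base sizes. Token `k` of the output corresponds to the patch at row `k // 14`, column `k % 14` of the 14 x 14 grid, in row-major order. The batch size of 2 and the chosen grid cell are arbitrary:

```python
import torch
from torch import nn

torch.manual_seed(0)
conv = nn.Conv2d(3, 768, kernel_size=16, stride=16)
feature_map = conv(torch.randn(2, 3, 224, 224))   # (batch, 768, 14, 14)
flat = feature_map.flatten(start_dim=2)            # (batch, 768, 196)
tokens = flat.transpose(1, 2)                      # (batch, 196, 768)

# Token k is the patch at grid position (row, col) with k = row * 14 + col
row, col = 3, 5
k = row * 14 + col
print(tokens.shape)                                                # torch.Size([2, 196, 768])
print(torch.equal(tokens[:, k, :], feature_map[:, :, row, col]))   # True
print(torch.equal(tokens, flat.permute(0, 2, 1)))                  # True: same as a permute
```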

```python
import torch
from torch import nn
```

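```python
class Pachify(nn.Module):
  """Split an image into non-overlapping patches and embed each one (x_p E in Eq. 1)."""
  def __init__(self, patch_size = 16):
    super().__init__()
    self.patch_size = patch_size
    # kernel_size = stride = patch_size: each filter position sees exactly one patch
    self.conv2d = nn.Conv2d(in_channels=3,
                            out_channels=768,  # hidden size D from Table 1 (number of filters)
                            kernel_size=patch_size,
                            stride=patch_size,
                            padding=0)
  def forward(self, x):
    x = self.conv2d(x)           # (B, 3, 224, 224) -> (B, 768, 14, 14)
    x = x.flatten(start_dim=2)   # (B, 768, 14, 14) -> (B, 768, 196)
    x = x.transpose(1,2)         # (B, 768, 196)    -> (B, 196, 768)
    return x
```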

```python
image = torch.randn(1,3,224,224)
pachify = Pachify()
y = pachify(image)
y.shape
```

torch.Size([1, 196, 768])

The next step is to prepend the class token to the sequence of 196 patches and add positional information to every token. The original Transformer adds fixed sinusoidal positional encodings computed by a formula. In contrast, this paper's positional embeddings are simply learnable parameters.
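To make the contrast concrete, here is a sketch of both options. The first is the fixed sinusoidal encoding from the original Transformer, using the standard sin/cos formula with base 10000. The second is a learnable ViT-style table. The function name `sinusoidal_positions` is our own:

```python
import math
import torch
from torch import nn

def sinusoidal_positions(num_positions, dim):
    """Fixed encodings: PE[pos, 2i] = sin(pos / 10000^(2i/dim)), PE[pos, 2i+1] = cos(same angle)."""
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)   # (N, 1)
    inv_freq = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))  # (dim/2,)
    pe = torch.zeros(num_positions, dim)
    pe[:, 0::2] = torch.sin(position * inv_freq)
    pe[:, 1::2] = torch.cos(position * inv_freq)
    return pe

fixed = sinusoidal_positions(197, 768)               # computed by a formula, never trained
learnable = nn.Parameter(torch.randn(1, 197, 768))   # ViT style: trainable numbers

print(fixed.shape, fixed.requires_grad)              # torch.Size([197, 768]) False
print(learnable.shape, learnable.requires_grad)      # torch.Size([1, 197, 768]) True
```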

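```python
def Add_Positional_Embedding_and_Class_Token(pachified_input):
  # Note: parameters created inside a plain function are re-initialized on every call and are
  # not registered with any nn.Module, so an optimizer would never update them.
  # In a trainable model they must be created once, in a module's __init__.
  batch_size, num_patches, dimension = pachified_input.shape   # (B, 196, 768)
  class_token = nn.Parameter(torch.randn(1, 1, dimension))      # one [class] token ...
  class_token = class_token.expand(batch_size, -1, -1)          # ... shared by every image in the batch
  pachified_input_with_class_token = torch.cat((class_token, pachified_input), dim=1)  # (B, 197, 768)
  position_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dimension))          # broadcast over B
  return pachified_input_with_class_token + position_embedding
```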

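Figure 1 also shows a "Linear Projection of Flattened Patches" before the class token and positional embedding are added. Strictly speaking, the conv2d in `Pachify` already is that projection (the matrix E in Eq. 1). The extra `nn.Linear` in the next cell is therefore mathematically redundant, because two linear maps in a row collapse into one. It is kept only to mirror the figure.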
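The redundancy argument is easy to verify numerically. Applying W2(W1 x + b1) + b2 is the same as applying a single linear layer with weight W2 W1 and bias W2 b1 + b2. This sketch uses random weights; the `merged` layer is built by hand for illustration:

```python
import torch
from torch import nn

torch.manual_seed(0)
first = nn.Linear(768, 768)     # e.g. the patch projection
second = nn.Linear(768, 768)    # e.g. an extra projection stacked on top

# second(first(x)) = W2 (W1 x + b1) + b2 = (W2 W1) x + (W2 b1 + b2)
merged = nn.Linear(768, 768)
with torch.no_grad():
    merged.weight.copy_(second.weight @ first.weight)
    merged.bias.copy_(second.weight @ first.bias + second.bias)

x = torch.randn(2, 196, 768)
print(torch.allclose(second(first(x)), merged(x), atol=1e-4))   # True
```

The extra layer adds parameters but not expressive power, since there is no non-linearity between the two maps.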

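```python
class Embedded_Patches(nn.Module):
  def __init__(self, patch_size = 16, image_size = 224, embed_size = 768):
    super().__init__()
    num_patches = (image_size // patch_size) ** 2  # 196 for 224 x 224 images and 16 x 16 patches
    self.pachify = Pachify(patch_size)             # outputs 768 features per patch
    # Mirrors the "Linear Projection" box in Figure 1 (redundant: the conv already projects)
    self.linear_projection = nn.Linear(in_features = embed_size, out_features = embed_size)
    # Registered as parameters of this module, so they are saved and updated by the optimizer
    self.class_token = nn.Parameter(torch.randn(1, 1, embed_size))
    self.position_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embed_size))

  def forward(self, x):
    x = self.pachify(x)                                        # (B, 196, 768)
    x = self.linear_projection(x)                              # (B, 196, 768)
    class_token = self.class_token.expand(x.shape[0], -1, -1)  # (B, 1, 768), shared across the batch
    x = torch.cat((class_token, x), dim=1)                     # (B, 197, 768)
    return x + self.position_embedding                         # learnable positions, broadcast over B
```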

```python
image = torch.randn(1,3,224,224)
prepared_input = Embedded_Patches()
y = prepared_input(image)
y.shape
```

torch.Size([1, 197, 768])

### Now we have turned the image tensor into a patchified tensor with a class token and positional embeddings, of shape (batch, 197, 768)

We are ready to build the transformer encoder block.
Note that in this Jupyter notebook I will not implement the transformer block from scratch; PyTorch's `nn.MultiheadAttention` is used instead. If you want to know more about the Transformer, see the "Tranformer_Implementation.ipynb" notebook, also in this repo.
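One detail is worth checking before using `nn.MultiheadAttention`: by default it expects inputs laid out as (seq_len, batch, embed). Our tensors are (batch, 197, embed), so the layer needs `batch_first=True`. This sketch uses toy sizes chosen for speed. It changes only the second sample and checks whether the first sample's output moves:

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(2, 5, 16)            # (batch=2, tokens=5, embed=16): toy sizes
leaks_by_mode = {}

for batch_first in (True, False):
    mha = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=batch_first)
    out, _ = mha(x, x, x, need_weights=False)

    # Change only sample 1, then see whether sample 0's output changes
    x_changed = x.clone()
    x_changed[1] += 1.0
    out_changed, _ = mha(x_changed, x_changed, x_changed, need_weights=False)
    leaks_by_mode[batch_first] = not torch.allclose(out[0], out_changed[0])
    print(f"batch_first={batch_first}: samples influence each other -> {leaks_by_mode[batch_first]}")
```

With the default layout, dimension 0 is treated as the sequence. Attention then mixes different images in the batch instead of the 197 tokens of one image.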

```python
class Transformer_Encoder_Block(nn.Module):
  def __init__(self, embed_size, heads = 4, dropout_rate = 0.1, hidden_units = 3072):
    super().__init__()
    self.norm1 = nn.LayerNorm(embed_size)
    # batch_first=True: inputs are (batch, 197, embed_size). The default layout is
    # (seq_len, batch, embed), which would make attention mix images within the batch
    self.multi_head_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=heads,
                                                      batch_first=True)
    self.norm2 = nn.LayerNorm(embed_size)

    # MLP block: two linear layers with a GELU non-linearity, as in the paper
    self.MLP = nn.Sequential(
        nn.Linear(embed_size, hidden_units),
        nn.GELU(),
        nn.Linear(hidden_units, embed_size)
    )
    self.dropout = nn.Dropout(dropout_rate)

  def forward(self, x):
    # Eq. 2: z'_l = MSA(LN(z_{l-1})) + z_{l-1}; the residual adds the un-normalized input
    normed = self.norm1(x)
    attention_output, _ = self.multi_head_attention(normed, normed, normed, need_weights=False)
    x = x + self.dropout(attention_output)
    # Eq. 3: z_l = MLP(LN(z'_l)) + z'_l
    mlp_output = self.MLP(self.norm2(x))
    x = x + self.dropout(mlp_output)
    return x
```

### Now, it's time to combine everything together!

For ViT-Base, the paper stacks the same transformer encoder block 12 times (Layers = 12 in Table 1). Each copy has its own weights.
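A back-of-the-envelope parameter count shows what stacking 12 blocks costs. In this pure-Python sketch the helper names are hypothetical. It counts LayerNorms, the fused q/k/v and output projections of attention, the two MLP layers, the conv patch embedding, the class token, the position embedding and a 1000-class head. It leaves out the redundant extra `nn.Linear`, which would add another 590,592 parameters:

```python
def linear_params(fan_in, fan_out):
    return fan_in * fan_out + fan_out            # weight + bias

def layernorm_params(dim):
    return 2 * dim                               # scale + shift

def encoder_block_params(D, mlp_size):
    # q/k/v projection (D -> 3D) plus output projection; the head count does not change the total
    attention = linear_params(D, 3 * D) + linear_params(D, D)
    mlp = linear_params(D, mlp_size) + linear_params(mlp_size, D)
    return 2 * layernorm_params(D) + attention + mlp

def vit_params(D=768, mlp_size=3072, layers=12, patch_size=16, channels=3,
               num_patches=196, num_classes=1000):
    patch_embedding = linear_params(channels * patch_size ** 2, D)   # the conv2d projection
    cls_and_pos = D + (num_patches + 1) * D                          # class token + position table
    head = layernorm_params(D) + linear_params(D, num_classes)
    return patch_embedding + cls_and_pos + layers * encoder_block_params(D, mlp_size) + head

print(encoder_block_params(768, 3072))   # 7087872 per block
print(vit_params())                      # 86567656, i.e. about 86.6M
```

The encoder blocks account for almost all of the roughly 86M parameters listed for ViT-Base in Table 1.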

```python
class ViT(nn.Module):
  def __init__(self, image_size = 224,
               patch_size = 16,
               num_classes = 1000,
               embed_size = 768,
               num_heads = 12,
               dropout_rate = 0.1,
               hidden_units = 3072,
               num_transformer_blocks = 12):  ## ViT-Base hyperparameters from the paper; 1000 classes as in ImageNet
    super().__init__()
    self.image_size = image_size  # stored for reference; not used below
    self.patch_size = patch_size
    self.num_classes = num_classes
    self.embed_size = embed_size
    self.num_heads = num_heads
    self.dropout_rate = dropout_rate
    self.hidden_units = hidden_units
    self.num_transformer_blocks = num_transformer_blocks
    self.embedding_patches = Embedded_Patches(patch_size= self.patch_size)

    # Stack 12 transformer encoder blocks, each with its own weights
    self.transformer_encoder_blocks = nn.ModuleList([
        Transformer_Encoder_Block(embed_size=self.embed_size,
                                  heads=self.num_heads,
                                  dropout_rate=self.dropout_rate,
                                  hidden_units=self.hidden_units)
        for _ in range(self.num_transformer_blocks)
    ])

    # Classification head (Eq. 4): LayerNorm on the class token, then a linear layer to 1000 classes
    self.classifier = nn.Sequential(
        nn.LayerNorm(self.embed_size),
        nn.Linear(self.embed_size, self.num_classes),
    )

  def forward(self,x):
    # Prepare the input: (B, 3, 224, 224) -> (B, 197, 768)
    x = self.embedding_patches(x)

    # Pass through the 12 encoder blocks; the shape stays (B, 197, 768)
    for block in self.transformer_encoder_blocks:
      x = block(x)

    # Classify from the class token, which is token 0 of the 197
    y = self.classifier(x[:,0,:])  # logits, shape (B, 1000)

    # Softmax turns the logits into class probabilities. For training with nn.CrossEntropyLoss,
    # return the logits y instead, because that loss applies log-softmax itself
    probs = torch.softmax(y, dim=1)

    return probs
```

```python
vit = ViT()
x = torch.randn(1,3,224,224)
y = vit(x)
y.shape
```

torch.Size([1, 1000])

```python
# the probabilities of the 1000 classes sum to 1
y.sum()
```

tensor(1., grad_fn=<SumBackward0>)
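Returning probabilities is handy for inspection, but it matters for training. PyTorch's `F.cross_entropy` (and `nn.CrossEntropyLoss`) applies log-softmax internally, so it expects the raw logits `y`. In this sketch, random logits stand in for the classifier output and the labels are made up:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 1000)              # stand-in for y, the classifier output before softmax
targets = torch.tensor([3, 10, 999, 0])    # made-up labels for illustration

# cross_entropy = log_softmax followed by negative log-likelihood
loss_from_logits = F.cross_entropy(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(loss_from_logits, manual))   # True

# Passing probabilities applies softmax twice: every input is squeezed into [0, 1],
# so the loss stays near log(1000) however confident the model is
probs = torch.softmax(logits, dim=1)
double_softmax_loss = F.cross_entropy(probs, targets)
print(double_softmax_loss.item(), math.log(1000))
```

Because the loss is pinned near log(1000), its gradients are tiny. A training loop would therefore usually take `y` from the model and keep the softmax only for inference.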