# An Image is Worth 16x16 Words
This research paper applies the Transformer architecture, which is already popular in NLP, to computer vision (CV) tasks.

The paper shows that CNNs are not necessary and that a pure Transformer architecture can perform well on image classification.

The resulting model, ViT (Vision Transformer), attains state-of-the-art results.


The paper reports that when the model is trained on mid-sized datasets like ImageNet, its accuracy is lower than that of a ResNet of similar size. This is because Transformers lack some of the inductive biases inherent to CNNs, such as translation equivariance.

However, when pre-trained on a huge dataset like ImageNet-21k or JFT-300M, Transformers achieve state-of-the-art results.

The authors have tried to keep the architecture as close as possible to the one in the famous "Attention Is All You Need" paper, which was the first to describe the Transformer architecture.

We will implement this model in PyTorch.


In [1]:
import torch
from torch import nn 
import torch.nn.functional as F
from einops import rearrange,reduce,repeat
#einops is used to rearrange tensor dimensions, and to reduce and repeat along axes
#I found einops to be a very intuitive way to carry out these operations

## How to input an Image as a sequence
Transformers, which I will describe in a bit, take a sequence as input (the words of a sentence in the case of NLP). To turn an image into a sequence, we divide it into patches; in this notebook each patch has size 2x2. For an image with height H and width W, the number of patches is N = H*W/P^2. Each patch has shape [P, P, C], where P is the patch size and C is the number of channels.

These patches are then flattened into 1D tensors and linearly projected to the embedding size, and a position embedding is added to them. On their own, the patches carry no information about where they sit in the image, which is why the position embedding is needed.
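
To make the patch arithmetic concrete, here is a small sketch in plain Python. The helper name `patch_sequence_shape` is made up for this illustration and is not part of the paper or of PyTorch; it simply evaluates N = H*W/P^2 and the flattened patch length P*P*C.

```python
def patch_sequence_shape(height, width, channels, patch_size):
    """Return (number of patches N, flattened length of one patch)."""
    if height % patch_size or width % patch_size:
        raise ValueError("image size must be divisible by the patch size")
    # N = H*W / P^2
    num_patches = (height // patch_size) * (width // patch_size)
    # each patch holds P*P*C values once flattened
    patch_dim = patch_size * patch_size * channels
    return num_patches, patch_dim


# 224x224 RGB image with the 16x16 patches from the paper's title
paper_shape = patch_sequence_shape(224, 224, 3, 16)
# the same image with the 2x2 patches used in this notebook
notebook_shape = patch_sequence_shape(224, 224, 3, 2)
print(paper_shape, notebook_shape)
```

Smaller patches give a much longer sequence, and the cost of self-attention grows quadratically with the sequence length, so the patch size is an important trade-off.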

In [2]:
class Embedding(nn.Module):
    def __init__(self,height,width,in_channel,embed_size):
        super(Embedding,self).__init__()
        self.height=height
        self.width=width
        self.in_channel=in_channel
        self.embed_size=embed_size
        #Conv2d with kernel_size == stride == patch size (2 here) projects each non-overlapping patch
        #to embed_size; this is equivalent to flattening each patch and applying a shared Linear layer
        self.projection=nn.Conv2d(in_channel, embed_size, kernel_size=2, stride=2)
        #learnable position embedding: one vector per patch plus one for the cls token
        self.pos_embedding=nn.Parameter(torch.randn(1, (height//2)*(width//2) + 1, embed_size))
        #cls_token: learnable vector prepended to each sequence of projected patches
        self.cls_token=nn.Parameter(torch.randn(1, 1, embed_size))

    def forward(self,x):
        b=x.shape[0]
        x=self.projection(x)                   #(b, embed_size, h/2, w/2)
        x=rearrange(x,'b e h w -> b (h w) e')  #(b, num_patches, embed_size)
        cls_token=repeat(self.cls_token,'1 1 d -> b 1 d',b=b)
        x=torch.cat((cls_token,x),dim=1)       #(b, num_patches + 1, embed_size)
        x=x+self.pos_embedding                 #broadcast over the batch dimension
        return x
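
The `Conv2d` projection in `Embedding` uses a kernel size equal to its stride, so every patch is seen exactly once. The sketch below checks, on a tiny random input, that this gives the same result as flattening each patch and applying a `Linear` layer with the same weights. The sizes and variable names are chosen only for illustration.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
patch, in_ch, emb = 2, 3, 8

conv = nn.Conv2d(in_ch, emb, kernel_size=patch, stride=patch)
linear = nn.Linear(in_ch * patch * patch, emb)

# copy the conv weights into the linear layer so both compute the same map
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(emb, -1))
    linear.bias.copy_(conv.bias)

x = torch.randn(1, in_ch, 4, 6)  # 4x6 image -> 2x3 = 6 patches

# path 1: strided convolution, then flatten the spatial grid into a sequence
conv_out = conv(x).flatten(2).transpose(1, 2)  # (1, 6, emb)

# path 2: cut out the patches, flatten each one, apply the linear layer
patches = F.unfold(x, kernel_size=patch, stride=patch).transpose(1, 2)  # (1, 6, C*P*P)
lin_out = linear(patches)  # (1, 6, emb)

same = torch.allclose(conv_out, lin_out, atol=1e-5)
print(same)
```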

## Self Attention
As described in "Attention Is All You Need", we compute queries, keys, and values by passing the input through three linear layers. Attention is then defined as $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}(QK^T / \sqrt{d_k})\,V$, where $d_k$ is the dimension of the query and key vectors (the embedding size when there is a single head). Note that the scaling happens before the softmax.

This layer lets each patch (token) decide how much attention to give to every other patch.
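
As a minimal single-head sketch of the formula, assuming the queries, keys, and values have already been computed (here the same random tensor is reused for all three purely for illustration):

```python
import math
import torch


def scaled_dot_product_attention(q, k, v):
    """softmax(q k^T / sqrt(d_k)) v for tensors of shape (batch, tokens, d_k)."""
    d_k = q.shape[-1]
    # similarity of every query with every key, scaled before the softmax
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)
    # each row becomes a probability distribution over the keys
    weights = torch.softmax(scores, dim=-1)
    # weighted sum of the values
    return weights @ v, weights


torch.manual_seed(0)
tokens = torch.randn(1, 5, 8)  # 1 image, 5 tokens, d_k = 8
out, weights = scaled_dot_product_attention(tokens, tokens, tokens)
print(out.shape, weights.shape)
```

Each row of `weights` sums to 1, so every output token is a weighted average of the value vectors.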

## MultiHead Attention
Results improve when several attention operations (heads) run in parallel and their outputs are combined. To keep the computation cost roughly constant, the embedding dimension of the input is split among the heads; the head outputs are then concatenated and passed through a final linear projection. This is called multi-head attention.
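
The splitting and merging of heads is only a reshape. The sketch below shows it with plain `reshape` and `permute`; the helper names `split_heads` and `merge_heads` are made up for this illustration.

```python
import torch


def split_heads(x, num_heads):
    """(batch, tokens, emb) -> (batch, heads, tokens, emb // heads)."""
    b, n, e = x.shape
    head_dim = e // num_heads
    return x.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)


def merge_heads(x):
    """(batch, heads, tokens, head_dim) -> (batch, tokens, heads * head_dim)."""
    b, h, n, d = x.shape
    return x.permute(0, 2, 1, 3).reshape(b, n, h * d)


x = torch.randn(2, 10, 768)
heads = split_heads(x, num_heads=8)  # each head sees 768 / 8 = 96 features
restored = merge_heads(heads)        # concatenating the heads restores the layout
print(heads.shape, restored.shape)
```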

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self,emb_size,num_heads=8,dropout=0):
        super(MultiHeadAttention,self).__init__()
        assert emb_size % num_heads == 0, "emb_size must be divisible by num_heads"
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.head_dim = emb_size // num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self,x):
        #split the embedding dimension among the heads: (b, n, emb) -> (b, h, n, head_dim)
        queries=rearrange(self.queries(x),"b n (h d) -> b h n d", h=self.num_heads)
        keys=rearrange(self.keys(x),"b n (h d) -> b h n d", h=self.num_heads)
        values=rearrange(self.values(x),"b n (h d) -> b h n d", h=self.num_heads)
        #dot product of every query with every key: (b, h, n_queries, n_keys)
        a=torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        #scale BEFORE the softmax, by the square root of the per-head dimension
        att = F.softmax(a/(self.head_dim**0.5), dim=-1)
        att = self.att_drop(att)
        #weighted sum of the values over the key axis
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        #concatenate the heads again: (b, h, n, head_dim) -> (b, n, emb)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        out = out + x #Shortcut connection, added after the output projection
        return out

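The transformer block also contains a feed-forward block: two linear layers with a GELU activation in between, where the hidden layer is `expansion` times wider than the embedding.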

In [4]:
class FeedForwardBlock(nn.Module):
    def __init__(self, emb_size, expansion = 4, drop_p= 0):
        super(FeedForwardBlock,self).__init__()
        self.block=nn.Sequential(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size))

    def forward(self,x):
        out=self.block(x)
        out+=x #Shortcut connection used
        return out

## Transformer Block
The transformer block consists of the self-attention layer followed by the feed-forward layer, with layer normalization applied before each of them.

In [5]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self,emb_size=768,drop_p=0,forward_expansion=4,forward_drop_p=0):
        super(TransformerEncoderBlock,self).__init__()
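        #layout: LayerNorm -> attention -> dropout, then LayerNorm -> feed forward -> dropout
        #the shortcut connections live inside MultiHeadAttention and FeedForwardBlock,
        #so each one adds the normalized input rather than the pre-norm input used in the paper
        self.block=nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size),
                nn.Dropout(drop_p),
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )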
    def forward(self,x):
        out=self.block(x)
        return out
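
For comparison, here is a sketch of an encoder block in which the shortcut bypasses the layer normalization, so that each sub-layer computes `x + sublayer(norm(x))` as in the paper's equations. This is not the notebook's implementation: it assumes PyTorch's built-in `nn.MultiheadAttention` (with `batch_first=True`) in place of the custom attention class, and the name `PreNormBlock` and the small default sizes are chosen only for illustration.

```python
import torch
from torch import nn


class PreNormBlock(nn.Module):
    def __init__(self, emb_size=64, num_heads=4, expansion=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_size)
        self.attn = nn.MultiheadAttention(emb_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(emb_size)
        self.mlp = nn.Sequential(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Linear(expansion * emb_size, emb_size),
        )

    def forward(self, x):
        # attention sub-layer: normalize, attend, then add the un-normalized input
        normed = self.norm1(x)
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + attn_out
        # feed-forward sub-layer with the same pattern
        x = x + self.mlp(self.norm2(x))
        return x


torch.manual_seed(0)
x = torch.randn(2, 5, 64)  # (batch, tokens, emb_size)
y = PreNormBlock()(x)
print(y.shape)
```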

After the transformer blocks, a classification block at the end averages the token embeddings and passes the result through a layer normalization and a linear layer whose output size equals the number of classes.

In [6]:
class Classification(nn.Module):
    def __init__(self, emb_size= 768, n_classes= 1000):
        super(Classification,self).__init__()
        self.block=nn.Sequential(
            nn.LayerNorm(emb_size), 
            nn.Linear(emb_size, n_classes))
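    def forward(self,x):
        #average over the token dimension: (b, n, e) -> (b, e)
        out=reduce(x,'b n e -> b e', reduction='mean')
        #feed the pooled vector (not the raw token sequence) to the head
        out=self.block(out)
        return out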
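
The head above averages over all tokens. The paper instead reads the image representation from the output at the cls token position. Both options reduce the sequence to one vector per image, as this small sketch shows (the tensor sizes are only an example):

```python
import torch

torch.manual_seed(0)
# (batch, 1 cls token + 196 patch tokens, emb_size)
tokens = torch.randn(4, 197, 768)

# option 1: mean over the token dimension, as in the Classification class above
mean_pooled = tokens.mean(dim=1)

# option 2: keep only the first token, which is where the cls token was prepended
cls_pooled = tokens[:, 0]

print(mean_pooled.shape, cls_pooled.shape)
```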

## Model Architecture
Finally, we build our model by combining all these layers.

In [7]:
class ViT(nn.Module):
    def __init__(self,
                in_channels= 3,
                patch_size= 16,
                emb_size= 768,
                img_size= 224,
                depth= 12,
                n_classes= 1000):
        super(ViT,self).__init__()
        #NOTE: Embedding above uses a fixed 2x2 patch, so patch_size is not used yet
        self.embedding=Embedding(img_size, img_size, in_channels, emb_size)
        self.block=self.make_layer(depth, emb_size)
        self.classification=Classification(emb_size, n_classes)

    def make_layer(self,depth,emb_size):
        #stack `depth` encoder blocks that all use the model's embedding size
        layers=[TransformerEncoderBlock(emb_size=emb_size) for _ in range(depth)]
        return nn.Sequential(*layers)

    def forward(self,x):
        out=self.embedding(x)
        out=self.block(out)
        out=self.classification(out)
        return out