In [None]:
#Requirements
!pip install einops



# ViT implementation from scratch

In this exercise we will implement the Vision Transformer (ViT) from scratch.

## Patch embedding

In the previous exercise we saw the basics of patch embedding. To
finish the implementation we need to add:
1. Class token
2. Positional encodings

Complete the following code, reusing the `Rearrange` step from the previous exercise.


In [None]:
import torch
from torch import nn
from torch import Tensor
from einops.layers.torch import Rearrange

class PatchEmbedding(nn.Module):
    """Split images into patches, embed them, prepend a class token and add positional encodings.

    Args:
        in_channels: number of channels of the input images.
        patch_size: side length (in pixels) of each square patch.
        emb_size: dimensionality of each patch embedding.
        img_size: side length of the square input images; used to size the
            positional encodings, one per patch plus one for the class token.

    forward() takes images of shape (b, in_channels, H, W), with H and W divisible
    by patch_size, and returns (b, (H // patch_size) * (W // patch_size) + 1, emb_size).
    The positional-encoding add requires H == W == img_size (up to the floor division).
    """

    def __init__(self, in_channels: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        # nn.Module.__init__ must run before any attribute is assigned on the module.
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        self.projection = nn.Sequential(
            # (b, c, h*s1, w*s2) -> (b, h*w, s1*s2*c): one flattened vector per s1 x s2 patch
            Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
            # linear projection of every flattened patch to emb_size
            nn.Linear(patch_size * patch_size * in_channels, emb_size),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # the class token also gets its own positional encoding, hence the +1
        self.positional_enc = nn.Parameter(torch.randn(num_patches + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        """Embed a batch of images.

        Args:
            x: image batch of shape (b, in_channels, H, W).

        Returns:
            Token tensor of shape (b, num_patches + 1, emb_size); index 0 on dim 1 is the class token.
        """
        b, _, _, _ = x.shape
        x = self.projection(x)  # (b, num_patches, emb_size)

        # repeat the class token b times: (1, 1, emb_size) -> (b, 1, emb_size)
        cls_tokens = self.cls_token.repeat(b, 1, 1)

        # prepend the class token on the sequence dim -> (b, num_patches + 1, emb_size)
        x = torch.cat([cls_tokens, x], dim=1)

        # (num_patches + 1, emb_size) broadcasts over the batch dimension
        x = x + self.positional_enc

        return x


In [None]:
# Let's try: 4 RGB 224x224 images -> (224 // 16) ** 2 = 196 patches + 1 class token
img = torch.randn(4, 3, 224, 224)

pe = PatchEmbedding(in_channels=3, patch_size=16, emb_size=768, img_size=224)

patch_tokens = pe(img)
print(patch_tokens.shape)  # should be torch.Size([4, 197, 768])
assert patch_tokens.shape == (4, 197, 768), patch_tokens.shape


torch.Size([4, 197, 768])


## Multihead attention

In [None]:
from einops import rearrange, einsum
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        emb_size: embedding dimension of the input and output tokens; must be
            divisible by num_heads.
        num_heads: number of attention heads; each head works on
            emb_size // num_heads dimensions.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if emb_size is not divisible by num_heads.
    """

    def __init__(self, emb_size: int = 512, num_heads: int = 8, dropout: float = 0):
        super().__init__()
        if emb_size % num_heads != 0:
            raise ValueError(f"emb_size ({emb_size}) must be divisible by num_heads ({num_heads})")
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.head_dim = emb_size // num_heads
        # a single projection produces queries, keys and values at once
        self.qkv = nn.Linear(emb_size, 3 * emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to a token sequence.

        Args:
            x: tokens of shape (b, n, emb_size).
            mask: optional tensor broadcastable to (b, num_heads, n, n); nonzero/True
                marks the key positions each query may attend to.

        Returns:
            Tensor of shape (b, n, emb_size).
        """
        b, n, _ = x.shape
        # split keys, queries and values in num_heads:
        # (b, n, 3*e) -> (b, n, 3, heads, head_dim) -> (3, b, heads, n, head_dim)
        qkv = self.qkv(x).reshape(b, n, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        queries, keys, values = qkv.unbind(0)

        # (b, heads, n_query, n_key) similarities, scaled by 1/sqrt(head_dim)
        energy = (queries @ keys.transpose(-2, -1)) * self.head_dim ** -0.5
        if mask is not None:
            # the most negative finite value (not -inf) keeps fully-masked rows NaN-free
            energy = energy.masked_fill(mask.logical_not(), torch.finfo(energy.dtype).min)

        att = self.att_drop(F.softmax(energy, dim=-1))
        out = att @ values  # (b, heads, n, head_dim)
        # merge heads back: (b, heads, n, head_dim) -> (b, n, emb_size)
        out = out.transpose(1, 2).reshape(b, n, self.emb_size)
        return self.projection(out)

## Residual Add

Residual connections are essential in transformers. The following code wraps a module `fn` with a residual connection, so the block returns `fn(x) + x`.

We will use this block to add residual connections around the multihead attention and MLP sub-blocks.

In [None]:
class ResidualAdd(nn.Module):
    """Wrap a module ``fn`` with a residual connection: ``forward(x) == fn(x) + x``.

    Keyword arguments given to forward() are passed through to ``fn``.
    Neither the input tensor nor ``fn``'s output tensor is modified in place.
    """

    def __init__(self, fn: nn.Module):
        super().__init__()
        self.fn = fn  # torch nn module that is wrapped

    def forward(self, x: Tensor, **kwargs) -> Tensor:
        # Out-of-place add on purpose: an in-place `out += x` writes into fn's output,
        # which is the caller's own tensor when fn returns its input (e.g. nn.Identity),
        # and breaks autograd when fn saved that output for backward (e.g. sigmoid).
        return self.fn(x, **kwargs) + x

## MLP block

This block follows multihead attention, and is a sequential model with:

1. Linear layer (with expansion of dimensionality)
2. Non-linearity (GELU)
3. Dropout
4. Linear layer that returns to input dimension

This block can be very easily implemented by subclassing the nn.Sequential class. Remember that nn.Sequential
is initialized with its blocks passed as positional arguments (use `nn.Sequential(*blocks)` if you have a list).

This approach avoids reimplementing the forward method!



In [None]:
# MLP Block

class FeedForwardBlock(nn.Sequential):
    """MLP block that follows multihead attention in each encoder block.

    Layers, in order:
      1. Linear: emb_size -> expansion * emb_size
      2. GELU non-linearity
      3. Dropout(dropout)
      4. Linear: expansion * emb_size -> emb_size

    Being an nn.Sequential, it needs no custom forward(): the layers are applied in order.

    Args:
        emb_size: input and output feature dimension.
        expansion: multiplier for the hidden dimension.
        dropout: dropout probability after the non-linearity.
    """

    def __init__(self, emb_size: int, expansion: int = 4, dropout: float = 0.):
        hidden_size = expansion * emb_size
        super().__init__(
            nn.Linear(emb_size, hidden_size),  # 1. linear with expansion of dims
            nn.GELU(),                         # 2. gelu
            nn.Dropout(dropout),               # 3. dropout
            nn.Linear(hidden_size, emb_size),  # 4. linear that returns to emb_size
        )

## Transformer encoder block

Again we implement this block by subclassing nn.Sequential. The modules are:

1. A block with LayerNorm followed by multihead attention, wrapped in a residual connection
2. A block with LayerNorm followed by the feed-forward (MLP) block, wrapped in a residual connection




In [None]:
class TransformerEncoderBlock(nn.Sequential):
    """One transformer encoder block, built as a Sequential of two residual sub-blocks.

    1. ResidualAdd(LayerNorm -> MultiHeadAttention)
    2. ResidualAdd(LayerNorm -> FeedForwardBlock)

    The LayerNorm sits inside each residual branch, i.e. each sub-block computes
    x + f(norm(x)). Tokens are expected to have last dimension emb_size.

    Args:
        emb_size: token embedding dimension.
        forward_expansion: hidden-size multiplier of the feed-forward block.
        dropout: dropout probability passed to the attention and feed-forward sub-blocks.
        num_heads: number of attention heads.
    """

    def __init__(self,
                 emb_size: int = 768,
                 forward_expansion: int = 4,
                 dropout: float = 0.,
                 num_heads: int = 8,
                 ):

        block_msa = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size=emb_size, num_heads=num_heads, dropout=dropout),
        )

        block_ff = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, dropout=dropout),
        )

        # each branch is wrapped so its output is added back to its input
        super().__init__(
            ResidualAdd(block_msa),
            ResidualAdd(block_ff),
        )

## Transformer encoder

This is just a Sequential stack of `depth` TransformerEncoderBlocks.
All TransformerEncoderBlocks are initialized with the same parameters.

In [None]:
class TransformerEncoder(nn.Sequential):
    """A stack of ``depth`` TransformerEncoderBlocks applied one after another.

    Every block receives the same emb_size, forward_expansion, dropout and num_heads.
    """

    def __init__(self, depth: int = 12,
                 emb_size: int = 768,
                 forward_expansion: int = 4,
                 dropout: float = 0.,
                 num_heads: int = 8):
        # shared configuration for every block in the stack
        block_config = dict(
            emb_size=emb_size,
            forward_expansion=forward_expansion,
            dropout=dropout,
            num_heads=num_heads,
        )
        super().__init__(*(TransformerEncoderBlock(**block_config) for _ in range(depth)))

## Classification Head

Although the original ViT uses only the class token for classification, it is common practice to average all tokens (mean pooling) and then apply a linear layer that maps to the number of classes.

In [None]:
from einops.layers.torch import Reduce

class ClassificationHead(nn.Sequential):
    """Turn a token sequence into class logits.

    Pipeline: mean over the token axis ('b n e -> b e'), LayerNorm, then a
    Linear layer from emb_size to n_classes.
    """

    def __init__(self, emb_size: int = 768, n_classes: int = 1000):
        # average every token, the class token included
        pool = Reduce('b n e -> b e', reduction='mean')
        norm = nn.LayerNorm(emb_size)
        to_logits = nn.Linear(emb_size, n_classes)
        super().__init__(pool, norm, to_logits)



# 10 sequences of 14*14 patch tokens + 1 class token, each of size 768
x = torch.rand(10, 14 * 14 + 1, 768)

ch = ClassificationHead()

logits = ch(x)
print(logits.shape)  # 10 x 1000
assert logits.shape == (10, 1000), logits.shape

torch.Size([10, 1000])


## Putting it all together: the ViT model

In [None]:
class ViT(nn.Sequential):
    """Vision Transformer: PatchEmbedding -> TransformerEncoder -> ClassificationHead.

    Args:
        in_channels: number of channels of the input images.
        patch_size: side length of each square patch.
        emb_size: token embedding dimension.
        img_size: side length of the square input images.
        depth: number of transformer encoder blocks.
        num_heads: number of attention heads per block.
        n_classes: number of output logits.
        forward_expansion: hidden-size multiplier of each block's feed-forward MLP
            (default matches TransformerEncoder's default).
        dropout: dropout probability used inside the encoder blocks
            (default matches TransformerEncoder's default).
    """

    def __init__(self,
                in_channels: int = 3,
                patch_size: int = 16,
                emb_size: int = 768,
                img_size: int = 224,
                depth: int = 12,
                num_heads: int = 8,
                n_classes: int = 1000,
                forward_expansion: int = 4,
                dropout: float = 0.,
                ):
        super().__init__(
            PatchEmbedding(in_channels, patch_size, emb_size, img_size),
            TransformerEncoder(depth,
                               emb_size=emb_size,
                               forward_expansion=forward_expansion,
                               dropout=dropout,
                               num_heads=num_heads),
            ClassificationHead(emb_size, n_classes)
        )


In [None]:
vit = ViT()

x = torch.rand(10, 3, 224, 224)

# Inference-only shape check: eval() disables dropout and no_grad() avoids
# keeping activations around for a backward pass that never happens.
vit.eval()
with torch.no_grad():
    logits = vit(x)

print(logits.shape)
assert logits.shape == (10, 1000), logits.shape

torch.Size([10, 1000])
