# Vision Transformer

Transformers are a family of neural network architectures that came to computer vision from NLP. Since transformers assume little about the structure of their input, they can learn more general dependencies in data than convolutional neural networks. That is what makes vision transformers attractive. At the same time, vision transformers are known to be "data hungry", and training them is quite tricky.

In this homework we will go through the main components of a vision transformer and its training procedure.

In [7]:
import numpy as np
import torch
from torch import nn

## How to code your transformer

As mentioned above, the vision transformer came from NLP, where the typical network input is an ordered sequence of tokens (words or word parts). The main blocks of a vision transformer are therefore:
1. Tokenizer - a module that takes images and returns a set of tokens.
2. Transformer encoder - the main block of the network; it contains multi-head attention, normalization and an MLP applied to the tokens.
3. Positional embeddings - a way to provide information about token order.
4. Classification token - a special token whose features are used for the final class prediction.
5. Classification head - an MLP that predicts the final class from the classification token features.

### Tokenizer (2p)

The tokenizer should take an image, split it into non-overlapping patches, flatten the patches and apply a linear layer to the resulting vectors. There are many ways to implement this; we will use `Conv2d` with the stride equal to `kernel_size`.

In [49]:
class Tokenizer(nn.Module):
    def __init__(self, input_height, input_width, output_height, output_width,
                 n_input_channels,
                 embedding_dim):
        super(Tokenizer, self).__init__()

        assert input_height % output_height == 0, f"{input_height} should be divisible by {output_height}"
        assert input_width % output_width == 0, f"{input_width} should be divisible by {output_width}"
        
        kernel_h = input_height // output_height
        kernel_w = input_width // output_width
        assert kernel_h == kernel_w, "Only square kernels are supported"
        kernel_size = kernel_h
        
        self.patch_size = kernel_size
        
        self.conv = nn.Conv2d(n_input_channels, embedding_dim,  kernel_size, stride=kernel_size)

        self.flattener = nn.Flatten(2, 3)       # [B, C, H, W] -> [B, C, H*W]

    def forward(self, x):
        x = self.conv(x)                        # [B, emb_dim, H', W']
        x = self.flattener(x).transpose(-2,-1) # [B, H'*W', emb_dim]
        return x 
    
    
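    def patchify(self, img: np.ndarray, kernel: np.ndarray):
        # NumPy reference for a single channel; not used by forward().
        # Assumption: the kernel is applied to non-overlapping patches,
        # i.e. the stride equals the kernel size, as in self.conv.
        H, W = img.shape
        kh, kw = kernel.shape

        out_h = H // kh
        out_w = W // kw

        out = np.zeros((out_h, out_w))
        for i in range(out_h):
            for j in range(out_w):
                patch = img[i * kh:(i + 1) * kh, j * kw:(j + 1) * kw]
                out[i, j] = (patch * kernel).sum()
        return out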

In [50]:
tokenizer = Tokenizer(input_height=32, input_width=32, output_height=16, output_width=16, n_input_channels=1,
                      embedding_dim=64)
dummy_batch = torch.zeros((1, 1, 32, 32))
tokenizer_result = tokenizer.forward(dummy_batch)
assert tokenizer_result.shape == (1, 256, 64), tokenizer_result.shape
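
The claim that a strided convolution does the same job as "cut patches, flatten, apply a linear layer" can be checked directly. The sketch below is only an illustration: the sizes (`patch = 2`, `emb_dim = 8`, 6x6 images) are arbitrary assumptions, and `torch.nn.functional.unfold` is used here as one possible way to cut the patches explicitly.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
patch, emb_dim = 2, 8                      # illustrative sizes, not taken from the homework
conv = nn.Conv2d(1, emb_dim, kernel_size=patch, stride=patch)
images = torch.rand(3, 1, 6, 6)            # [B, C, H, W]

# Route 1: strided convolution, then flatten the spatial grid into a token axis.
tokens_conv = conv(images).flatten(2).transpose(1, 2)      # [B, 9, emb_dim]

# Route 2: cut patches explicitly, flatten each one,
# and apply the convolution weights as an ordinary linear map.
patches = F.unfold(images, kernel_size=patch, stride=patch)  # [B, C*patch*patch, 9]
patches = patches.transpose(1, 2)                            # [B, 9, C*patch*patch]
weight = conv.weight.reshape(emb_dim, -1)                    # [emb_dim, C*patch*patch]
tokens_linear = patches @ weight.T + conv.bias               # [B, 9, emb_dim]

same = torch.allclose(tokens_conv, tokens_linear, atol=1e-6)
print(tokens_conv.shape, same)
```

Both routes give one token per patch, so the convolution is just a compact way of writing the patch-wise linear layer.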

### Transformer encoder

The transformer encoder consists of two blocks, Multi-Head Attention and an MLP, each of which is preceded by layer norm. Let's walk through the separate modules first.

#### Multi-head attention

<img src="https://data-science-blog.com/wp-content/uploads/2022/01/mha_img_original.png" style="width:50%">

Attention implements a simple formula: $\text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V$.

Multi-head attention splits Q, K and V into several subvectors, applies Attention to each subvector independently and concatenates the results.

Multi-Head Attention is implemented in PyTorch as `torch.nn.MultiheadAttention`. Check the documentation and pay attention to the `dropout` and `batch_first` parameters.

[[paper]](https://arxiv.org/pdf/1706.03762.pdf)
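
To make the formula and the head splitting concrete, here is a minimal hand-written sketch. The helper names `attention` and `split_heads` are made up for this illustration, and the learned input and output projections that `torch.nn.MultiheadAttention` applies are deliberately left out, so this is not a drop-in replacement for it.

```python
import math
import torch

def attention(q, k, v):
    # softmax(Q K^T / sqrt(d_k)) V, computed over the last two axes.
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)   # [..., N, N]
    weights = scores.softmax(dim=-1)                    # each row sums to 1
    return weights @ v, weights

def split_heads(x, num_heads):
    # [B, N, D] -> [B, num_heads, N, D // num_heads]
    B, N, D = x.shape
    return x.reshape(B, N, num_heads, D // num_heads).transpose(1, 2)

torch.manual_seed(0)
B, N, D, num_heads = 2, 5, 8, 2                         # illustrative sizes
q, k, v = (torch.rand(B, N, D) for _ in range(3))

# Attention runs independently inside every head.
out, weights = attention(split_heads(q, num_heads),
                         split_heads(k, num_heads),
                         split_heads(v, num_heads))

# Concatenate the heads back into one vector per token.
out = out.transpose(1, 2).reshape(B, N, D)
print(out.shape, weights.shape)
```

Each head sees only `D // num_heads` features, which is why the embedding dimension has to be divisible by the number of heads.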

In [51]:
torch.nn.MultiheadAttention??

#### MLP for Transformer Encoder (1p)

The MLP for the transformer encoder is a simple two-layer perceptron with ReLU as the non-linearity. It also uses Dropout after each Linear layer to reduce overfitting. Importantly, the hidden size of the MLP is usually several times larger than its input size.

In [52]:
def create_mlp(embedding_dim, mlp_size, dropout_rate):
    return nn.Sequential(
        # YOUR CODE: Linear + RELU + Dropout + Linear + Dropout
        nn.Linear(embedding_dim, mlp_size),
        nn.ReLU(),  # nn.ReLU takes no size; its only argument is the `inplace` flag
        nn.Dropout(dropout_rate),
        nn.Linear(mlp_size, embedding_dim),
        nn.Dropout(dropout_rate)
    )

In [53]:
mlp = create_mlp(128, 128 * 2, 0.1)
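
One property worth seeing in action is that this MLP acts on every token separately: `nn.Linear` only touches the last axis. The sketch below rebuilds the same layer stack inline with small, arbitrary sizes (8 and 32 are assumptions for the illustration) and switches to `eval()` so that Dropout does not add randomness to the comparison.

```python
import torch
from torch import nn

torch.manual_seed(0)
embedding_dim, mlp_size = 8, 32            # illustrative sizes
mlp = nn.Sequential(
    nn.Linear(embedding_dim, mlp_size),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(mlp_size, embedding_dim),
    nn.Dropout(0.1),
)
mlp.eval()                                 # Dropout is a no-op in eval mode

tokens = torch.rand(2, 5, embedding_dim)   # [B, N, D]
out = mlp(tokens)                          # [B, N, D], shape is preserved

# Feeding one token on its own gives the same result as reading it from the full output.
single = mlp(tokens[:, 2])                 # [B, D]
print(out.shape, torch.allclose(out[:, 2], single, atol=1e-5))
```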

#### Layer norm

While Batch Normalization is a default normalization layer for convolutional neural networks, in transformers Layer Normalization is used instead.

Layer norm is implemented in PyTorch as `torch.nn.LayerNorm`.

In [54]:
nn.LayerNorm?
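
A short sketch of what layer norm does to a batch of tokens: each token vector is normalized on its own, over the feature axis, independently of the other tokens and of the batch. The tensor sizes are arbitrary assumptions, and the manual formula matches the module only because a freshly created `nn.LayerNorm` has weight 1 and bias 0.

```python
import torch
from torch import nn

torch.manual_seed(0)
tokens = torch.rand(2, 5, 8) * 10 + 3      # [B, N, D], deliberately not zero-mean
layer_norm = nn.LayerNorm(8)               # normalizes over the last axis (D)

normed = layer_norm(tokens)

# The same computation written out by hand.
mean = tokens.mean(dim=-1, keepdim=True)
var = tokens.var(dim=-1, unbiased=False, keepdim=True)
manual = (tokens - mean) / torch.sqrt(var + layer_norm.eps)

per_token_mean = normed.mean(dim=-1)                  # [B, N], close to 0
per_token_var = normed.var(dim=-1, unbiased=False)    # [B, N], close to 1
print(torch.allclose(normed, manual, atol=1e-5))
```

Because the statistics come from a single token, the result does not depend on the batch size, unlike batch normalization.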

#### TransformerEncoder: putting it all together (2p)

Now we are ready to define Transformer Encoder.
<img src="./transformer_encoder.png" style="width:20%">

In [55]:
class TransformerEncoder(nn.Module):
    def __init__(self, embedding_dim, num_heads, mlp_size, dropout=0.1, attention_dropout=0.1,
                 drop_path_rate=0.1):
        super().__init__()
        # YOUR CODE
        self.attention_pre_norm = nn.LayerNorm(embedding_dim)
        self.attention = torch.nn.MultiheadAttention(embed_dim=embedding_dim,
                                                     num_heads=num_heads,
                                                     dropout=attention_dropout,
                                                     batch_first=True)
        
        self.attention_output_dropout = nn.Dropout(dropout)

        self.mlp_pre_norm = nn.LayerNorm(embedding_dim)
        self.mlp = create_mlp(embedding_dim, mlp_size, dropout)

    def forward(self, x):
        # first block
        y = self.attention_pre_norm(x)
        attention = self.attention(y, y, y)[0]
        attention = self.attention_output_dropout(attention)
        x = x + attention   # Residual connection
            
        # second block
        y = self.mlp_pre_norm(x)
        y = self.mlp(y)
        x = x + y # Residual connection
        
        return x

Let's check that it actually works

In [56]:
tokenizer_result.shape

torch.Size([1, 256, 64])

In [57]:
e = TransformerEncoder(embedding_dim=64, num_heads=2, mlp_size=128)
encoder_result = e(tokenizer_result)
print(encoder_result.shape)
assert encoder_result.shape == tokenizer_result.shape

torch.Size([1, 256, 64])


### Positional embeddings (2p)

Positional embeddings give the transformer information about token order. You can either learn the embeddings by SGD or generate them with a fixed scheme. The most popular scheme is sinusoidal embeddings:

$$\text{emb}(p, 2i) = \sin(\frac{p}{10000^{2i/d}})$$
$$\text{emb}(p, 2i + 1) = \cos(\frac{p}{10000^{2i/d}})$$
where $p$ is the token position, $2i$ and $2i + 1$ are indices along the embedding dimension, and $d$ is the embedding dimension.
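
The two formulas above can be turned into a fixed embedding table as follows. This is a minimal sketch: the function name `sinusoidal_embeddings` is made up for the illustration, and it assumes an even embedding dimension so that sine and cosine columns pair up.

```python
import math
import torch

def sinusoidal_embeddings(n_tokens, embedding_dim):
    # Assumes an even embedding_dim.
    positions = torch.arange(n_tokens, dtype=torch.float32).unsqueeze(1)    # p, shape [n_tokens, 1]
    even_idx = torch.arange(0, embedding_dim, 2, dtype=torch.float32)       # the values 2i
    inv_freq = torch.exp(-math.log(10000.0) * even_idx / embedding_dim)     # 1 / 10000^(2i/d)

    emb = torch.zeros(n_tokens, embedding_dim)
    emb[:, 0::2] = torch.sin(positions * inv_freq)   # even columns: sin
    emb[:, 1::2] = torch.cos(positions * inv_freq)   # odd columns: cos
    return emb

pos_emb = sinusoidal_embeddings(n_tokens=256, embedding_dim=64)
print(pos_emb.shape)   # torch.Size([256, 64])
```

Such a table has no trainable parameters, so it would typically be stored as a buffer rather than an `nn.Parameter`.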

The traditional way of using embeddings in PyTorch is `torch.nn.Embedding`. In our case it is simpler to use the lower-level `torch.nn.Parameter`. Here is how one can define learnable embeddings.

In [58]:
n_tokens = 256
embedding_dim = 64

# YOUR CODE
# torch.empty creates an uninitialized tensor; it is filled with truncated normal values below
emb = torch.nn.Parameter(torch.empty(n_tokens, embedding_dim))

_ = torch.nn.init.trunc_normal_(emb, std=0.2)

In [59]:
print(emb.std(), emb.shape)

tensor(0.2001, grad_fn=<StdBackward0>) torch.Size([256, 64])


### Class token and classification head

The vanilla Vision Transformer uses a rather unusual way to get an embedding of the whole image for the final prediction. It adds one more token, called the class token, with its own positional embedding and takes its features as the final embedding of the image. An alternative approach that comes from CNNs is to obtain the image embedding by global average pooling. While simpler to implement, global average pooling introduces a shortcut through which different patches can communicate with each other (in vanilla ViT all inter-patch relations can be learned only through attention blocks).

In modern papers, however, you are about equally likely to meet either approach.

Adding a class token in PyTorch is simple. You can either add one more embedding to the `nn.Parameter` that holds the positional embeddings or create a separate `nn.Parameter` for the class token only.

In [60]:
embedding_dim = 64
class_emb = torch.nn.Parameter(torch.empty((1, embedding_dim)), requires_grad=True)
torch.nn.init.trunc_normal_(class_emb, std=0.2)

print(class_emb[0].shape, class_emb.shape)

torch.Size([64]) torch.Size([1, 64])
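
The two ways of getting an image embedding described above can be put side by side. In this sketch the encoder is skipped entirely (it keeps the token shape, so it does not matter for the illustration), the sizes are arbitrary assumptions, and a zero tensor stands in for the learned class token.

```python
import torch

torch.manual_seed(0)
B, N, D = 2, 4, 8                            # illustrative sizes
patch_tokens = torch.rand(B, N, D)           # tokenizer output
class_token = torch.zeros(1, 1, D)           # stand-in for a learned nn.Parameter

# Option 1: prepend the class token and, after the encoder, read the features at index 0.
with_cls = torch.cat([class_token.expand(B, -1, -1), patch_tokens], dim=1)   # [B, N + 1, D]
encoded = with_cls                           # encoder omitted in this sketch
cls_features = encoded[:, 0]                 # [B, D]

# Option 2: global average pooling over the patch tokens.
gap_features = patch_tokens.mean(dim=1)      # [B, D]

print(with_cls.shape, cls_features.shape, gap_features.shape)
```

Either way the classification head receives one `D`-dimensional vector per image; the difference is only in how that vector is produced.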


### Vision Transformer: putting it all together (3p)

In [67]:
class VisionTransformer(nn.Module):
    def __init__(self,
                 input_height, input_width,
                 n_input_channels,
                 embedding_dim,
                 num_layers,
                 num_heads,
                 num_classes=1000,
                 mlp_ratio=4.0,
                 dropout=0.1,
                 attention_dropout=0.1,
                 stochastic_depth=0.1):
        super().__init__()
        
        # YOUR CODE
        # 1. Tokenizer
        self.tokenizer = Tokenizer(
            input_height=input_height,
            input_width=input_width,
            output_height=input_height,   # one token per pixel (1x1 patches); use input_height // patch_size for larger patches
            output_width=input_width,     # same for width
            n_input_channels=n_input_channels,
            embedding_dim=embedding_dim
        )

        
        # Number of patch tokens the image is split into
        num_patches = (input_height // self.tokenizer.patch_size) * (input_width // self.tokenizer.patch_size)

        
        # 2. Positional embeddings
        self.positional_embeddings = nn.Parameter(torch.empty(1, num_patches + 1, embedding_dim))  # +1 for the class token
        torch.nn.init.trunc_normal_(self.positional_embeddings, std=0.2)
        
        # 3. Class token
        self.class_embedding = torch.nn.Parameter(torch.empty((1, embedding_dim)), requires_grad=True)
        torch.nn.init.trunc_normal_(self.class_embedding, std=0.2)

        # 4. TransformerEncoder 
        mlp_size = int(embedding_dim * mlp_ratio)
        self.blocks = nn.Sequential(*[
            TransformerEncoder(
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                mlp_size=mlp_size,
                dropout=dropout,
                attention_dropout=attention_dropout
                )for i in range(num_layers)
        ])
        
        # 5. we will need more dropout and normalization!
        self.dropout = nn.Dropout(p=dropout)
        self.norm = nn.LayerNorm(embedding_dim)

        # 6. layer for the final prediction
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        # 1. tokenizer
        patch_embeddings = self.tokenizer(x)  # [B, N, D]
        
        # 2. add class token
        B = patch_embeddings.size(0)
        cls_token = self.class_embedding.expand(B, 1, -1)  # [B, 1, D]
        x = torch.cat([cls_token, patch_embeddings], dim=1)  # [B, N + 1, D]
    
        # 3. position embeddings
        x = x + self.positional_embeddings  # [1, N + 1, D] broadcasts over the batch
    
        # 4. dropout
        x = self.dropout(x)
    
        # 5. transformer blocks
        for block in self.blocks:
            x = block(x)
    
        # 6. norm and final prediction
        x = self.norm(x)
        cls_out = x[:, 0]  # [B, D]
        return self.fc(cls_out)


In [68]:
input_height = 16
input_width = 16
n_input_channels = 1
vit = VisionTransformer(input_height, input_width,
                 #n_tokens=4,
                 n_input_channels=n_input_channels,
                 embedding_dim=32,
                 num_layers=2,
                 num_heads=2,
                 num_classes=10,
                 mlp_ratio=2.0,
                 dropout=0.1,
                 attention_dropout=0.1,
                 stochastic_depth=0.1)

In [69]:
fake_batch = torch.rand((1, n_input_channels, input_height, input_width))
print(fake_batch.shape)

result = vit(fake_batch)
print(result.shape)
print(result)

torch.Size([1, 1, 16, 16])
torch.Size([1, 10])
tensor([[ 0.1446, -0.0488, -0.1795,  1.1439, -0.3029, -0.0131,  0.0605, -0.0669,
         -0.3321, -0.0419]], grad_fn=<AddmmBackward0>)
