# Building the ViT

The goal of this tutorial is to set up a basic Vision Transformer (ViT) model from scratch using PyTorch in order to understand its basic principles and theoretical background. The tutorial combines concepts from https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c and https://medium.com/the-dl/transformers-from-scratch-in-pytorch-8777e346ca51. This second of three notebooks deals with setting up the ViT architecture.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
# Nice trick to import function and class definitions from other notebooks
from ipynb.fs.defs._1_Vit_Preprocessing import VitPreprocessor

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

from torch import Tensor
from torch import nn
from torch.nn.functional import softmax

In [None]:
# Download (if necessary) and load the MNIST train and test splits as tensors.
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# INFO: PyTorch uses a different tensor layout than TensorFlow does by default:
# TensorFlow (default, channels-last): [batch_size, height, width, channels]
# PyTorch (channels-first):            [batch_size, channels, height, width]
train_loader = DataLoader(training_data, batch_size=32)
# Grab one batch to inspect shapes and to feed the models defined below.
x_sample, y_sample = next(iter(train_loader))
image_shape = x_sample.shape
print(f"Image shape is: {image_shape} (batch size, channels, height, width)")
# Show the single channel of the first image in the batch.
plt.matshow(x_sample[0][0], cmap="coolwarm")
plt.show()

## Layer Normalization
Unlike Batch Normalization, Layer Normalization does not compute the mean and standard deviation across the batch but, as the name indicates, across all units of a layer, i.e. separately for each token over its feature dimension. Consider, for example, a layer for our 64-dimensional embedding space.

In [None]:
# Define a layer normalization in PyTorch over the last dimension (size 64).
ln = nn.LayerNorm(64)
# Random tensor of shape (batch size, number of patches, embedded dimension), as we
# expect it after preprocessing; scaled and shifted so its mean/std are clearly not 0/1.
batch_input = torch.rand(32, 16, 64) * 42 + 4
# Look at a single patch embedding (first sample, first patch) before and after LayerNorm.
# LayerNorm normalizes with the *biased* variance, so we report the biased standard
# deviation (unbiased=False); the default unbiased estimate would show ~1.008, not 1.
first_patch = batch_input[0, 0, :]
print(f"First patch embedding before LayerNorm: {first_patch.mean().item()} +- {first_patch.std(unbiased=False).item()}")

normalized_input = ln(batch_input)
first_patch_normalized = normalized_input[0, 0, :]
print(f"First patch embedding after LayerNorm: {first_patch_normalized.mean().item()} +- {first_patch_normalized.std(unbiased=False).item()}")

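## Scaled Dot-Product Attention
The attention output is computed as $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$, where $d_k$ is the dimension of the keys.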

In [None]:
def scaled_dot_product_attention(query, key, value):
    """Compute scaled dot-product attention ``softmax(Q K^T / sqrt(d_k)) V``.

    Args:
        query: Tensor of shape ``(..., n_queries, d_k)``.
        key: Tensor of shape ``(..., n_keys, d_k)``.
        value: Tensor of shape ``(..., n_keys, d_v)``.

    Returns:
        Tensor of shape ``(..., n_queries, d_v)``. Leading batch dimensions are
        handled (and broadcast) by ``torch.matmul``, so 3-D
        ``(batch, tokens, dim)`` inputs behave as before, and inputs with extra
        leading dimensions (e.g. ``(batch, heads, tokens, dim)``) also work.
    """
    # Similarity of every query with every key. Transposing only the last two
    # axes (instead of bmm + transpose(1, 2)) supports any number of batch dims.
    scores = torch.matmul(query, key.transpose(-2, -1))
    # Divide by sqrt(d_k) to keep the softmax out of its saturated regime.
    scale = key.size(-1) ** 0.5
    weights = softmax(scores / scale, dim=-1)
    return torch.matmul(weights, value)

## Attention Head Layer

In [None]:
class AttentionHead(nn.Module):
    """Single self-attention head with learned query/key/value projections.

    Each projection is a ``dimension -> dimension`` linear layer, so the head
    operates on the full token width rather than a slice of it.
    """

    def __init__(self, dimension):
        super().__init__()
        # Registered in q, k, v order; the attribute names become the
        # state_dict keys ("q.weight", "k.bias", ...).
        for projection_name in ("q", "k", "v"):
            setattr(self, projection_name, nn.Linear(dimension, dimension))

    def forward(self, tokens):
        query, key, value = [
            projection(tokens) for projection in (self.q, self.k, self.v)
        ]
        return scaled_dot_product_attention(query, key, value)

## Expanding to Multi-Head Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Run several independent attention heads and merge their outputs.

    Every head receives the full ``hidden_dimension``-wide tokens. The head
    outputs are concatenated along the last axis (width
    ``num_heads * hidden_dimension``) and projected back to
    ``hidden_dimension`` by a linear layer.
    """

    def __init__(self, num_heads, hidden_dimension):
        super().__init__()
        self.heads = nn.ModuleList(
            AttentionHead(hidden_dimension) for _ in range(num_heads)
        )
        concatenated_width = num_heads * hidden_dimension
        self.linear = nn.Linear(concatenated_width, hidden_dimension)

    def forward(self, tokens):
        head_outputs = [head(tokens) for head in self.heads]
        merged = torch.cat(head_outputs, dim=-1)
        return self.linear(merged)

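## Transformer Encoder Block with Residual Connections
Each block applies layer normalization before the multi-head attention and before an MLP, and adds each sub-layer's input back to its output (residual connection).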

In [None]:
class VitBlock(nn.Module):
    """Pre-norm transformer encoder block: attention and MLP, each with a residual.

    Computes ``h = x + MHA(LN1(x))`` and returns ``h + MLP(LN2(h))``. The
    residual additions require the output to have the same shape as ``x``.

    Args:
        hidden_dimension: Width of each token embedding (last axis of ``x``).
        n_heads: Number of attention heads in the multi-head attention.
        mlp_ratio: Expansion factor of the MLP's hidden layer.
    """

    def __init__(self, hidden_dimension, n_heads, mlp_ratio=4):
        super().__init__()
        self.hidden_dimension = hidden_dimension
        self.n_heads = n_heads
        expanded_width = mlp_ratio * hidden_dimension

        # Submodules are registered in the order they are applied in forward().
        self.norm1 = nn.LayerNorm(hidden_dimension)
        self.mha = MultiHeadAttention(n_heads, hidden_dimension)
        self.norm2 = nn.LayerNorm(hidden_dimension)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dimension, expanded_width),
            nn.ReLU(),
            nn.Linear(expanded_width, hidden_dimension),
        )

    def forward(self, x):
        # Attention sub-layer with residual connection.
        attended = x + self.mha(self.norm1(x))
        # Token-wise MLP sub-layer with residual connection.
        return attended + self.mlp(self.norm2(attended))

In [None]:
# Smoke test: a block must return a tensor with the same shape as its input.
model = VitBlock(n_heads=2, hidden_dimension=8)
x = torch.randn(32, 50, 8)
print(model(x).shape)

In [None]:
class MyVit(nn.Module):
    """Minimal Vision Transformer classifier.

    Pipeline: ``VitPreprocessor`` (imported from the preprocessing notebook)
    -> ``n_blocks`` stacked ``VitBlock``s -> a linear head applied to the token
    at position 0, followed by a softmax over ``classes`` outputs.

    Args:
        image_shape: Passed unchanged to ``VitPreprocessor``.
        classes: Number of output classes.
        p_size: Patch size, passed to ``VitPreprocessor``.
        embedded_dimension: Width of the token embeddings.
        n_heads: Attention heads per block.
        n_blocks: Number of transformer blocks.

    NOTE(review): the head ends in ``nn.Softmax``, so ``forward`` returns
    probabilities. If the model is trained with ``nn.CrossEntropyLoss`` (which
    applies log-softmax itself and expects raw logits), softmax would be
    applied twice. Confirm against the training code.
    """

    def __init__(self, image_shape, classes, p_size, embedded_dimension, n_heads=2, n_blocks=2):
        super().__init__()
        # Keep the hyper-parameters around as attributes for inspection.
        self.image_shape = image_shape
        self.classes = classes
        self.p_size = p_size
        self.embedded_dimension = embedded_dimension
        self.n_heads = n_heads
        self.n_blocks = n_blocks

        self.preprocessor = VitPreprocessor(image_shape, p_size, embedded_dimension)

        self.vit_blocks = nn.ModuleList(
            VitBlock(embedded_dimension, n_heads) for _ in range(n_blocks)
        )

        self.mlp = nn.Sequential(
            nn.Linear(embedded_dimension, classes),
            nn.Softmax(dim=-1),
        )

    def forward(self, images):
        tokens = self.preprocessor(images)
        for block in self.vit_blocks:
            tokens = block(tokens)
        # Classify from the token at position 0; presumably the class token
        # prepended by VitPreprocessor (TODO confirm in the preprocessing notebook).
        return self.mlp(tokens[:, 0])

In [None]:
# Build the full model for the sample MNIST batch and run one forward pass.
# The printed shape should be (batch size, classes), since the head maps the
# token at position 0 to `classes` outputs.
model = MyVit(
    image_shape=x_sample.shape,
    classes=10,
    p_size=7,
    embedded_dimension=32,
    n_heads=4,
    n_blocks=3,
)
print(model(x_sample).shape)

In [None]:
# Count trainable parameters. `numel()` returns a Python int for every tensor
# shape, whereas `np.prod(p.size())` returns the float 1.0 for 0-dim parameters.
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters {params}")
# Print the module itself to show its submodule tree; `model.modules` without a
# call is a bound method, so the original printed "<bound method ...>".
print(f"List of modules: \n {model}")