## Vision Transformer (ViT) Overview

According to the original paper, [An Image is Worth 16x16 Words](https://arxiv.org/abs/2010.11929), ViT works as follows:

1. **Image Patching**: The input image is divided into patches of a fixed size.
2. **Patch Embedding Calculation**: Each patch is converted into a corresponding embedding.
3. **Adding Position Embeddings and Class Token**: A learnable class token is prepended to the sequence of patch embeddings, and a position embedding is added to every element of the resulting sequence.
4. **Transformer Encoder**: The sequence of embeddings is then input into a Transformer encoder.
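5. **MLP Head for Classification**: Finally, the encoder output for the class token is passed through a Multi-Layer Perceptron (MLP) head to obtain the class predictions. The implementation in this notebook averages over all output tokens instead of reading out the class token.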
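
The five steps above can be followed as simple shape bookkeeping. The sketch below is only an illustration: the helper `vit_sequence_shapes` and its stage names are hypothetical and not part of the notebook, and the defaults (224 x 224 RGB input, 16 x 16 patches, 768-dimensional embeddings) are assumed from the values used later in the code.

```python
def vit_sequence_shapes(image_size=224, patch_size=16, channels=3, projection_dim=768):
    """Return the per-image tensor shape (no batch axis) after each stage.

    Hypothetical helper for illustration only; the defaults mirror the
    values used elsewhere in this notebook.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    flat_patch_dim = patch_size * patch_size * channels
    return {
        "patches": (num_patches, flat_patch_dim),               # step 1
        "patch_embeddings": (num_patches, projection_dim),      # step 2
        "with_class_token": (num_patches + 1, projection_dim),  # step 3
        "encoder_output": (num_patches + 1, projection_dim),    # step 4
    }


shapes = vit_sequence_shapes()
for stage, shape in shapes.items():
    print(f"{stage:18s} {shape}")
```

The classification head in step 5 then reduces the encoder output to one vector per image before predicting the classes.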


In [None]:
# MODULES

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import Model 
from tensorflow.keras.layers import Add, Dense, Dropout, Embedding, GlobalAveragePooling1D, Input, Layer, LayerNormalization, MultiHeadAttention

In [None]:
# PATCH EXTRACTION

class PatchExtractor(Layer):
    def __init__(self):
        super(PatchExtractor, self).__init__()

    def call(self, images):
        batch_size = tf.shape(images)[0]
        # Cut each image into non-overlapping 16 x 16 patches.
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, 16, 16, 1],
            strides=[1, 16, 16, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Flatten the patch grid: [batch, rows, cols, 768] -> [batch, num_patches, 768].
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches
        
image = plt.imread('cat.jpg')
image = tf.image.resize(tf.convert_to_tensor(image), size=(224, 224))
plt.imshow(image.numpy().astype("uint8"))
plt.axis("off")



In [None]:
# A 224 x 224 image yields (224 / 16)^2 = 196 patches of 16 x 16 pixels

batch = tf.expand_dims(image, axis=0)
patches = PatchExtractor()(batch)
patches.shape
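
For intuition, non-overlapping patch extraction can also be written as a reshape and transpose in plain NumPy. This is a minimal sketch under assumptions, not the notebook's method: `extract_patches_np` is a hypothetical helper, the input is a synthetic array rather than `cat.jpg`, and the patch ordering (row-major over the grid, each patch flattened as row, column, channel) is intended to mirror the `PatchExtractor` layer but has not been checked against it here.

```python
import numpy as np


def extract_patches_np(image, patch_size=16):
    """Split an (H, W, C) array into flattened, non-overlapping square patches."""
    h, w, c = image.shape
    if h % patch_size or w % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    rows, cols = h // patch_size, w // patch_size
    # (rows, patch, cols, patch, C) -> (rows, cols, patch, patch, C)
    grid = image.reshape(rows, patch_size, cols, patch_size, c)
    grid = grid.transpose(0, 2, 1, 3, 4)
    # One flattened vector per patch, in row-major grid order.
    return grid.reshape(rows * cols, patch_size * patch_size * c)


# Synthetic stand-in for a 224 x 224 RGB image.
demo_image = np.arange(224 * 224 * 3, dtype=np.float32).reshape(224, 224, 3)
demo_patches = extract_patches_np(demo_image)
print(demo_patches.shape)
```

Reshaping any row of `demo_patches` back to `(16, 16, 3)` recovers the corresponding image tile, which is the same trick the visualisation cell relies on.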

In [None]:
# VISUALISING THE PATCHES WE EXTRACTED

n = int(np.sqrt(patches.shape[1]))  # patches per side (14 for 196 patches)
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    # Undo the flattening to recover the 16 x 16 RGB patch.
    patch_img = tf.reshape(patch, (16, 16, 3))
    ax.imshow(patch_img.numpy().astype("uint8"))
    ax.axis("off")

# PATCH ENCODING

# Patch Encoder Process

## Overview
The patch encoder turns the flattened image patches into a sequence of embeddings that a Transformer encoder can consume.

## Process Steps

1. **Input**
   - The encoder receives image patches as input.
   - These patches are uniform, non-overlapping sections of the original image.

2. **Patch Embeddings**
   - Each patch is flattened into a 1D vector.
   - This vector passes through a linear projection (typically a fully connected layer).
   - The output is the patch embedding, representing the content of the patch.

3. **Positional Embeddings**
   - For each patch, a positional embedding is created.
   - This embedding encodes the patch's position in the original image.
   - It's usually a learned vector with the same dimensions as the patch embedding.

4. **Combination**
   - The patch embedding and its corresponding positional embedding are added together.
   - This combination preserves both content and spatial information.

5. **Output**
   - The result is a sequence of embedded patches.
   - Each embedded patch contains both content and position information.
   - This sequence is ready for input into a Transformer model.
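
The process steps above can be sketched in NumPy as one matrix multiplication followed by an addition. This is an illustrative sketch with randomly initialised stand-ins for the learned parameters. The names `flat_patches`, `projection_weights`, `projection_bias` and `position_table` are assumptions, the embedding width of 64 is chosen only to keep the example small, and the class token is left out to match the steps as listed.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, patch_dim, embed_dim = 196, 768, 64

# Step 1: flattened, non-overlapping patches (random stand-in for real pixels).
flat_patches = rng.normal(size=(num_patches, patch_dim))

# Step 2: linear projection (a fully connected layer without activation).
projection_weights = rng.normal(scale=0.02, size=(patch_dim, embed_dim))
projection_bias = np.zeros(embed_dim)
patch_embeddings = flat_patches @ projection_weights + projection_bias

# Step 3: one vector per position, same width as the patch embedding.
position_table = rng.normal(scale=0.02, size=(num_patches, embed_dim))

# Step 4: element-wise sum of content and position.
encoded = patch_embeddings + position_table

# Step 5: a sequence of embedded patches, ready for a Transformer.
print(encoded.shape)
```

Because the two embeddings are summed rather than concatenated, the sequence keeps the same width after positions are included.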

## Significance
This process allows the Transformer to treat the image as a sequence of tokens, similar to how it processes text. The inclusion of positional information helps the model understand spatial relationships between different parts of the image.
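
The role of positional information can be checked numerically. Without position embeddings, self-attention treats its input as an unordered set: shuffling the tokens merely shuffles the outputs. The sketch below assumes a single attention head with random weights, which is a simplification and not the notebook's `Block` layer. All names in it are illustrative.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def self_attention(tokens, w_q, w_k, w_v):
    """Single-head scaled dot-product self-attention over a (seq_len, dim) array."""
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    scores = q @ k.T / np.sqrt(k.shape[-1])
    return softmax(scores) @ v


rng = np.random.default_rng(0)
seq_len, dim = 6, 8
tokens = rng.normal(size=(seq_len, dim))
w_q, w_k, w_v = (rng.normal(size=(dim, dim)) for _ in range(3))
position_table = rng.normal(size=(seq_len, dim))
perm = np.roll(np.arange(seq_len), 1)  # a fixed, non-trivial reordering

# Without positions: permuting the input only permutes the output.
out = self_attention(tokens, w_q, w_k, w_v)
out_shuffled = self_attention(tokens[perm], w_q, w_k, w_v)
order_blind = np.allclose(out[perm], out_shuffled)

# With positions: the same content in a different slot gives a different result.
out_pos = self_attention(tokens + position_table, w_q, w_k, w_v)
out_pos_shuffled = self_attention(tokens[perm] + position_table, w_q, w_k, w_v)
order_aware = not np.allclose(out_pos[perm], out_pos_shuffled)

print(order_blind, order_aware)
```

Adding a position vector to each token breaks this symmetry, which is what lets the model distinguish a patch in the top-left corner from the same patch elsewhere in the image.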

In [None]:
# Add a trainable variable that will learn the [class] token embedding

class PatchEncoder(Layer):
    def __init__(self, num_patches=196, projection_dim=768):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection_dim = projection_dim

        w_init = tf.random_normal_initializer()
        class_token = w_init(shape=(1, projection_dim), dtype="float32")
        self.class_token = tf.Variable(initial_value=class_token, trainable=True)
        self.projection = Dense(units=projection_dim)
        self.position_embedding = Embedding(input_dim=num_patches+1, output_dim=projection_dim)

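    def call(self, patch):
        batch = tf.shape(patch)[0]
        # Repeat the single learned [class] token once per image in the batch.
        class_token = tf.tile(self.class_token, multiples=[batch, 1])
        class_token = tf.reshape(class_token, (batch, 1, self.projection_dim))

        # Project each flattened patch to projection_dim.
        patches_embedd = self.projection(patch)
        # Prepend the [class] token so that it sits at position 0, as in the paper.
        patches_embedd = tf.concat([class_token, patches_embedd], axis=1)

        # One learned position embedding per token, including the [class] token.
        positions = tf.range(start=0, limit=self.num_patches + 1, delta=1)
        positions_embedd = self.position_embedding(positions)

        encoded = patches_embedd + positions_embedd
        return encoded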

In [None]:
embeddings = PatchEncoder()(patches)
embeddings.shape

In [None]:
# MULTILAYER PERCEPTRON

# The MLP is used inside each Transformer encoder block and as the final classification head of the ViT model.

class MLP(Layer):
    def __init__(self, hidden_features, out_features, dropout_rate=0.1):
        super(MLP, self).__init__()
        self.dense1 = Dense(hidden_features, activation=tf.nn.gelu)
        self.dense2 = Dense(out_features)
        self.dropout = Dropout(dropout_rate)

    def call(self,x):
        x = self.dense1(x)
        x = self.dropout(x)
        x = self.dense2(x)
        y = self.dropout(x)
        return y

mlp = MLP(768 * 2, 768)
y = mlp(tf.zeros((1, 197, 768)))
y.shape


In [None]:
# IMPLEMENTING THE TRANSFORMER ENCODER

class Block(Layer):
    def __init__(self, projection_dim, num_heads=4, dropout_rate=0.1):
        super(Block, self).__init__()
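        self.norm1 = LayerNormalization(epsilon=1e-6)
        # key_dim is the size of each head. The paper typically sets it to
        # projection_dim // num_heads; this notebook keeps full-width heads.
        self.attn = MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim, dropout=dropout_rate)
        self.norm2 = LayerNormalization(epsilon=1e-6)
        self.mlp = MLP(projection_dim * 2, projection_dim, dropout_rate)

    def call(self, x):
        # Pre-norm self-attention with a residual connection.
        x1 = self.norm1(x)
        attn_output = self.attn(x1, x1)
        x2 = attn_output + x
        # Pre-norm MLP with a second residual connection.
        x3 = self.norm2(x2)
        x3 = self.mlp(x3)
        y = x3 + x2
        return y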

block = Block(768)
y = block(tf.zeros((1, 197, 768)))
y.shape

In [None]:
class TransformerEncoder(Layer):
    def __init__(self, projection_dim, num_heads=4, num_blocks=12, dropout_rate=0.1):
        super(TransformerEncoder, self).__init__()
        self.blocks = [Block(projection_dim, num_heads, dropout_rate) for _ in range(num_blocks)]
        self.norm = LayerNormalization(epsilon=1e-6)
        self.dropout = Dropout(0.5)

    def call(self, x):
        # Apply the blocks in sequence; the shape stays [batch_size, num_patches + 1, projection_dim].
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        y = self.dropout(x)
        return y

transformer = TransformerEncoder(768)
y = transformer(embeddings)
y.shape

In [None]:
def create_VisionTransformer(num_classes, num_patches=196, projection_dim=768, input_shape=(224, 224, 3)):
    inputs = Input(shape=input_shape)
    # Patch extractor
    patches = PatchExtractor()(inputs)
    # Patch encoder
    patches_embed = PatchEncoder(num_patches, projection_dim)(patches)
    # Transformer encoder
    representation = TransformerEncoder(projection_dim)(patches_embed)
    # Pool over all tokens; the paper's default instead reads out the [class] token.
    representation = GlobalAveragePooling1D()(representation)
    # MLP head that maps the pooled representation to class logits
    logits = MLP(projection_dim, num_classes, 0.5)(representation)
    # Create model
    model = Model(inputs=inputs, outputs=logits)
    return model

# Build a two-class model and print its layer summary.
model = create_VisionTransformer(2)
model.summary()