#### Define a Transformer block built around a MultiHeadAttention layer

In [48]:
import tensorflow as tf
from tensorflow.keras.layers import Layer,Dense,LayerNormalization, Dropout, Flatten

# Define Transformer block class
class TransformerBlock(Layer):
    """Transformer encoder block: self-attention followed by a feed-forward net.

    Each of the two sub-layers is wrapped as
    ``LayerNormalization(residual + Dropout(sublayer(residual)))``.

    Args:
        embed_dim: Output width of the feed-forward network. The residual
            additions in ``call`` require the last dimension of the inputs
            to match this value.
        num_heads: Number of attention heads.
        ff_dim: Width of the hidden (ReLU) layer of the feed-forward network.
        rate: Dropout rate applied to both sub-layer outputs.
        **kwargs: Standard Keras ``Layer`` arguments (e.g. ``name``, ``dtype``).
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)

        # NOTE(review): in Keras, `key_dim` is the size of *each* head, so
        # every head here projects to `embed_dim` units. The common choice is
        # `embed_dim // num_heads` - confirm this is intended before changing,
        # since it alters the number of weights.
        self.att = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=embed_dim
        )

        # Position-wise feed-forward network: expand to ff_dim, project back.
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation="relu"),
            Dense(embed_dim),
        ])

        # One normalization + dropout pair per sub-layer.
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None, mask=None):
        """Run the block on ``inputs``.

        Args:
            inputs: Tensor used as query, value and key of the attention layer.
            training: Forwarded to the dropout layers. Defaults to ``None`` so
                the block can be called as ``block(x)`` and Keras decides the
                mode.
            mask: Optional tensor passed to the attention layer as
                ``attention_mask``.

        Returns:
            The output of the second layer normalization.
        """
        # Self-attention: the same tensor is query, value and key.
        attn_output = self.att(inputs, inputs, inputs, attention_mask=mask)
        attn_output = self.dropout1(attn_output, training=training)
        # First residual connection + normalization.
        out1 = self.layernorm1(inputs + attn_output)

        # Feed-forward sub-layer on the normalized attention result.
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)

        # Second residual connection + normalization.
        return self.layernorm2(out1 + ffn_output)

#### Define the Patch Embedding layer (embeds image patches into the desired dimension)

In [49]:
class PatchEmbedding(Layer):
    """Project flattened image patches to a fixed embedding size.

    Args:
        num_patches: Number of patches per image. Stored on the instance;
            this layer does not otherwise use it.
        embedding_dim: Number of units of the Dense projection.
    """

    def __init__(self, num_patches, embedding_dim):
        super().__init__()
        # A single Dense layer performs the projection.
        self.projection = Dense(embedding_dim)
        self.embedding_dim = embedding_dim
        self.num_patches = num_patches

    def call(self, patches):
        """Return the Dense projection of ``patches``."""
        embedded = self.projection(patches)
        return embedded

#### Define the Vision Transformer model with patch extraction

In [52]:
class VisionTransformer(tf.keras.Model):
    """Image classifier: patch extraction -> embedding -> transformer stack -> softmax.

    Args:
        num_patches: Forwarded to ``PatchEmbedding``.
        embedding_dim: Embedding size of each patch token.
        num_heads: Number of attention heads in every transformer block.
        ff_dim: Hidden width of the feed-forward network in every block.
        num_layers: Number of ``TransformerBlock`` layers applied in sequence.
        num_classes: Number of units of the final softmax layer.
        patch_size: Side length, in pixels, of the square non-overlapping
            patches. Defaults to 16, the value previously hard-coded.
        **kwargs: Standard ``tf.keras.Model`` arguments (e.g. ``name``).
    """

    def __init__(self, num_patches, embedding_dim, num_heads, ff_dim,
                 num_layers, num_classes, patch_size=16, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

        # Linear projection of the raw patches.
        self.patch_embed = PatchEmbedding(num_patches, embedding_dim)
        # NOTE(review): no positional embedding is added to the patch tokens;
        # confirm whether that omission is intentional.
        self.transformer_layers = [
            TransformerBlock(embedding_dim, num_heads, ff_dim)
            for _ in range(num_layers)
        ]

        # Classification head: flatten all tokens, then softmax.
        self.flatten = Flatten()
        self.dense = Dense(num_classes, activation="softmax")

    def call(self, images, training=None):
        """Classify a batch of images.

        Args:
            images: 4-D image batch accepted by ``tf.image.extract_patches``.
            training: Forwarded to every transformer block. Defaults to
                ``None`` so the model can be called as ``model(images)``.

        Returns:
            The output of the softmax layer, with ``num_classes`` values per
            image.
        """
        patches = self.extract_patches(images)
        x = self.patch_embed(patches)

        # Run through EVERY transformer block. The classification head must
        # stay outside this loop: flattening/returning inside it stops the
        # model after the first block.
        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x, training=training)

        x = self.flatten(x)
        return self.dense(x)

    def extract_patches(self, images):
        """Cut ``images`` into non-overlapping square patches and flatten each.

        Returns:
            A 3-D tensor whose first axis is the batch, whose second axis
            enumerates the patches and whose last axis holds one flattened
            patch.
        """
        batch_size = tf.shape(images)[0]
        # Same window for size and stride -> patches do not overlap.
        window = [1, self.patch_size, self.patch_size, 1]
        patches = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Take the flattened-patch depth from the extracted tensor instead of
        # hard-coding 16*16*3, so other patch sizes and channel counts work.
        patch_dim = patches.shape[-1]
        if patch_dim is None:
            patch_dim = tf.shape(patches)[-1]
        return tf.reshape(patches, [batch_size, -1, patch_dim])
        

In [55]:
# Example usage: build a model and push one random batch through it.
num_patches = 196  # 14 * 14 patches for a 224x224 image cut into 16x16 tiles
embedding_dim = 128
num_heads = 4
ff_dim = 512
num_layers = 6
num_classes = 10  # e.g. the ten CIFAR-10 classes

vit = VisionTransformer(
    num_patches=num_patches,
    embedding_dim=embedding_dim,
    num_heads=num_heads,
    ff_dim=ff_dim,
    num_layers=num_layers,
    num_classes=num_classes,
)

# Random batch of 32 RGB images, 224x224 pixels each.
images = tf.random.uniform(shape=(32, 224, 224, 3))

output = vit(images, training=True)
print(output.shape)

(32, 10)
