# In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

class PatchEmbedding(tf.keras.layers.Layer):
    """Turn an image batch into a sequence of patch embeddings.

    A ``Conv2D`` projects each non-overlapping patch to an ``embedding_dim``
    vector. Its kernel size and stride both equal ``patch_size`` and it uses
    ``'valid'`` padding. The resulting spatial grid is then flattened into a
    single token axis.

    Args:
        patch_size: Side length of each square patch, in pixels.
        embedding_dim: Number of projection filters, i.e. the token width.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, patch_size=16, embedding_dim=768, **kwargs):
        super().__init__(**kwargs)
        self.proj = layers.Conv2D(
            filters=embedding_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding='valid',
        )
        # Collapse the (H/P, W/P) patch grid into one token axis; the -1 lets
        # the number of patches follow whatever grid the convolution produced.
        self.flatten = layers.Reshape((-1, embedding_dim))

    def call(self, x):
        """Embed the patches of ``x``.

        Assumes ``x`` is channels-last ``(B, H, W, C)``, which is the Keras
        default data format. TODO: confirm if ``image_data_format`` is
        changed globally.

        Returns:
            Tensor of shape ``(B, N, embedding_dim)``, where ``N`` is the
            number of patches.
        """
        x = self.proj(x)        # [B, H/P, W/P, D]
        x = self.flatten(x)     # [B, N, D]
        return x

class MultiHeadSelfAttentionBlock(tf.keras.layers.Layer):
    """Pre-norm multi-head self-attention sub-layer.

    Applies ``LayerNormalization`` and then ``MultiHeadAttention``, using the
    normalized sequence as both query and value. The residual connection is
    NOT added here; the caller is responsible for it.

    Args:
        embedding_dim: Token width. Each head gets
            ``embedding_dim // num_heads`` dimensions, floored when the
            division is not exact.
        num_heads: Number of attention heads.
        attn_dropout: Dropout rate passed to ``MultiHeadAttention``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, embedding_dim=768, num_heads=12, attn_dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        self.norm = layers.LayerNormalization(epsilon=1e-6)
        self.mha = layers.MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embedding_dim // num_heads,
            dropout=attn_dropout,
        )

    def call(self, x):
        """Return the self-attention output for ``x``.

        No ``output_shape`` is given, so ``MultiHeadAttention`` projects back
        to the query's feature size. The result therefore has the same shape
        as ``x`` and can be added to it as a residual.
        """
        x_norm = self.norm(x)
        # The key is omitted, so Keras uses the value (here also x_norm) as the key.
        attn_output = self.mha(x_norm, x_norm)
        return attn_output

class MLPBlock(tf.keras.layers.Layer):
    """Pre-norm feed-forward sub-layer.

    Pipeline: LayerNormalization -> Dense(mlp_size, GELU) -> Dropout ->
    Dense(embedding_dim) -> Dropout. ``Dense`` operates on the last axis, so
    each token is transformed independently. The residual connection is NOT
    added here; the caller is responsible for it.

    Args:
        embedding_dim: Output width; should equal the input token width so
            the caller can add a residual.
        mlp_size: Hidden width of the first ``Dense`` layer.
        dropout: Rate for both ``Dropout`` layers.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, embedding_dim=768, mlp_size=3072, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.norm = layers.LayerNormalization(epsilon=1e-6)
        self.mlp = keras.Sequential([
            layers.Dense(mlp_size, activation='gelu'),
            layers.Dropout(dropout),
            layers.Dense(embedding_dim),
            layers.Dropout(dropout),
        ])

    def call(self, x):
        """Normalize ``x`` and run it through the two-layer MLP."""
        x = self.norm(x)
        return self.mlp(x)

class TransformerEncoderBlock(tf.keras.layers.Layer):
    """One pre-norm Transformer encoder layer.

    Computes ``x = attn(x) + x`` and then ``x = mlp(x) + x``. The attention
    and MLP sub-layers normalize their own inputs, so this block only adds
    the residual connections.

    Args:
        embedding_dim: Token width.
        num_heads: Number of attention heads.
        mlp_size: Hidden width of the feed-forward sub-layer.
        mlp_dropout: Dropout rate inside the feed-forward sub-layer.
        attn_dropout: Dropout rate inside the attention sub-layer.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, embedding_dim=768, num_heads=12, mlp_size=3072,
                 mlp_dropout=0.1, attn_dropout=0.0, **kwargs):
        super().__init__(**kwargs)
        self.attn = MultiHeadSelfAttentionBlock(embedding_dim, num_heads, attn_dropout)
        self.mlp = MLPBlock(embedding_dim, mlp_size, mlp_dropout)

    def call(self, x):
        """Apply attention and MLP, each wrapped in a residual connection."""
        x = self.attn(x) + x
        x = self.mlp(x) + x
        return x

class VisionTransformer(tf.keras.Model):
    """Vision Transformer (ViT) image classifier.

    Processing steps:

    1. Split the image into non-overlapping ``patch_size`` x ``patch_size``
       patches and embed each one.
    2. Prepend a learnable [CLS] token.
    3. Add learnable position embeddings and apply dropout.
    4. Run a stack of pre-norm Transformer encoder blocks.
    5. Apply a final LayerNorm and a linear head to the [CLS] position.

    Args:
        img_size: Height and width of the (square) input image, in pixels.
        patch_size: Side length of each square patch; must divide ``img_size``.
        in_channels: Number of input channels. Not used to build any layer,
            because ``Conv2D`` infers it from the input. Kept for API
            compatibility.
        num_transformer_layers: Number of encoder blocks.
        embedding_dim: Token embedding width.
        mlp_size: Hidden width of each encoder block's MLP.
        num_heads: Number of attention heads per encoder block.
        attn_dropout: Dropout rate inside multi-head attention.
        mlp_dropout: Dropout rate inside each MLP.
        embedding_dropout: Dropout applied after adding position embeddings.
        num_classes: Number of output logits (the head has no activation).
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).

    Raises:
        ValueError: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(self,
                 img_size=224,
                 patch_size=16,
                 in_channels=3,
                 num_transformer_layers=12,
                 embedding_dim=768,
                 mlp_size=3072,
                 num_heads=12,
                 attn_dropout=0.0,
                 mlp_dropout=0.1,
                 embedding_dropout=0.1,
                 num_classes=1000,
                 **kwargs):
        super().__init__(**kwargs)
        # Use an explicit check rather than `assert`, because asserts are
        # stripped when Python runs with -O.
        if img_size % patch_size != 0:
            raise ValueError("Image size must be divisible by patch size.")

        self.num_patches = (img_size * img_size) // (patch_size ** 2)
        self.patch_embedding = PatchEmbedding(patch_size=patch_size, embedding_dim=embedding_dim)

        # Pass `name` as a keyword argument. In Keras 3 the first positional
        # parameter of add_weight is `shape`, so a positional name string
        # raises TypeError there. The keyword form works in Keras 2 and 3.
        self.class_token = self.add_weight(
            name="cls",
            shape=(1, 1, embedding_dim),
            initializer="random_normal",
            trainable=True,
        )
        # One position vector per patch plus one for the [CLS] token.
        self.position_embedding = self.add_weight(
            name="pos_embed",
            shape=(1, self.num_patches + 1, embedding_dim),
            initializer="random_normal",
            trainable=True,
        )
        self.embedding_dropout = layers.Dropout(embedding_dropout)

        self.encoder_blocks = [
            TransformerEncoderBlock(embedding_dim, num_heads, mlp_size, mlp_dropout, attn_dropout)
            for _ in range(num_transformer_layers)
        ]

        self.norm = layers.LayerNormalization(epsilon=1e-6)
        self.head = layers.Dense(num_classes)

    def call(self, x):
        """Classify a batch of images.

        Assumes ``x`` is channels-last ``(B, img_size, img_size, C)``. TODO:
        confirm the expected layout with callers. The position embedding has
        a fixed length of ``num_patches + 1``, so inputs at any other
        resolution will fail to broadcast in the addition below.

        Returns:
            Logits of shape ``(B, num_classes)``, taken from the [CLS] token.
        """
        batch_size = tf.shape(x)[0]
        x = self.patch_embedding(x)
        # Copy the single learned [CLS] token once per batch element.
        cls_tokens = tf.broadcast_to(self.class_token, [batch_size, 1, self.class_token.shape[-1]])
        x = tf.concat([cls_tokens, x], axis=1)
        x = x + self.position_embedding
        x = self.embedding_dropout(x)

        for blk in self.encoder_blocks:
            x = blk(x)

        x = self.norm(x)
        return self.head(x[:, 0])  # Use only the [CLS] token
