In [7]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense, LayerNormalization, Dropout, Flatten

In [8]:
class TransformerBlock(Layer):
    """Post-norm Transformer encoder block: self-attention followed by a feed-forward MLP.

    Each sub-layer is applied as ``LayerNorm(x + Dropout(sublayer(x)))``.

    Args:
        embed_dim: Size of the last axis of the inputs. The feed-forward network
            projects back to this size so the residual additions line up.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        rate: Dropout rate applied to both sub-layer outputs.
        **kwargs: Forwarded to ``Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config() can rebuild the layer.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate

        # key_dim is the per-head projection size; every head projects to embed_dim here.
        self.attn = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation='relu'),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(rate)
        self.dropout2 = Dropout(rate)

    def call(self, inputs, training=None):
        """Apply self-attention and the feed-forward network, each with a residual connection.

        Args:
            inputs: Tensor whose last axis has size ``embed_dim``
                (presumably ``(batch, tokens, embed_dim)`` — confirm against caller).
            training: Passed to the Dropout layers. ``None`` lets Keras resolve the
                mode from the calling context.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        # Self-attention: query and value are both the inputs.
        attn_output = self.attn(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized and re-created."""
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "rate": self.rate,
        })
        return config

In [9]:
# Define the PatchEmbedding layer
class PatchEmbedding(Layer):
    """Linearly project flattened patches and add a learned position embedding.

    Args:
        num_patches: Number of patches per image. It sizes the position-embedding
            table and must equal the patch axis of the inputs.
        embedding_dim: Output feature size of every patch token.
        **kwargs: Forwarded to ``Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, num_patches, embedding_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.embedding_dim = embedding_dim
        self.projection = Dense(embedding_dim)
        # One learned vector per patch index. Without it, self-attention is
        # permutation-equivariant and the encoder treats patches as an unordered set.
        self.position_embedding = tf.keras.layers.Embedding(
            input_dim=num_patches, output_dim=embedding_dim
        )

    def call(self, patches):
        """Embed ``patches``, assumed ``(batch, num_patches, patch_dims)`` — TODO confirm.

        Returns:
            Tensor of shape ``(batch, num_patches, embedding_dim)``.
        """
        positions = tf.range(self.num_patches)
        # (num_patches, embedding_dim) broadcasts across the batch axis.
        return self.projection(patches) + self.position_embedding(positions)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized and re-created."""
        config = super().get_config()
        config.update({
            "num_patches": self.num_patches,
            "embedding_dim": self.embedding_dim,
        })
        return config


In [10]:
# Define the VisionTransformer model
class VisionTransformer(tf.keras.Model):
    """Minimal Vision Transformer classifier.

    Pipeline:
    1. Images are cut into non-overlapping ``patch_size`` x ``patch_size`` patches.
    2. The patches are embedded with PatchEmbedding.
    3. They pass through ``num_layers`` TransformerBlocks.
    4. The tokens are flattened and classified by a softmax Dense layer.

    Args:
        num_patches: Patches per image, forwarded to PatchEmbedding. For square
            images of side ``S`` this is ``(S // patch_size) ** 2``.
        embedding_dim: Token width used throughout the encoder.
        num_heads: Attention heads per TransformerBlock.
        ff_dim: Hidden width of each block's feed-forward network.
        num_layers: Number of stacked TransformerBlocks.
        num_classes: Size of the softmax output.
        patch_size: Side length, in pixels, of each square patch. It defaults to 16,
            the value previously hard-coded.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, num_patches, embedding_dim, num_heads, ff_dim,
                 num_layers, num_classes, patch_size=16, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.patch_embed = PatchEmbedding(num_patches, embedding_dim)
        self.transformer_layers = [
            TransformerBlock(embedding_dim, num_heads, ff_dim)
            for _ in range(num_layers)
        ]
        self.flatten = Flatten()
        self.dense = Dense(num_classes, activation='softmax')

    def call(self, images, training=False):
        """Classify a batch of images.

        Args:
            images: Tensor of shape ``(batch, height, width, channels)``.
            training: Forwarded to each TransformerBlock, which passes it to its dropout.

        Returns:
            Softmax probabilities of shape ``(batch, num_classes)``.
        """
        patches = self.extract_patches(images)
        x = self.patch_embed(patches)
        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x, training=training)
        # Flatten keeps the per-patch features in order. The Dense head is therefore
        # built for one specific token count (image size / patch size).
        x = self.flatten(x)
        return self.dense(x)

    def extract_patches(self, images):
        """Split ``images`` into flattened, non-overlapping square patches.

        Pixels that do not fill a whole patch at the right/bottom edge are dropped
        (``padding='VALID'``).

        Returns:
            Tensor of shape ``(batch, rows * cols, patch_size * patch_size * channels)``.
        """
        batch_size = tf.shape(images)[0]
        window = [1, self.patch_size, self.patch_size, 1]
        patches = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding='VALID',
        )
        # patches is (batch, rows, cols, patch_dims). Prefer static sizes so
        # downstream Flatten/Dense see a defined feature size even when traced.
        rows, cols, patch_dims = patches.shape[1], patches.shape[2], patches.shape[-1]
        num_tokens = rows * cols if rows is not None and cols is not None else -1
        return tf.reshape(patches, [batch_size, num_tokens, patch_dims])


In [11]:
# Example usage: run a random batch through the model and check the output shape.

# Fix TensorFlow's global random seed so re-runs draw the same random inputs.
tf.random.set_seed(0)

image_size = 224
patch_size = 16  # must match the patch size VisionTransformer uses (16 by default)
num_patches = (image_size // patch_size) ** 2  # 14 x 14 grid of 16x16 patches = 196
embedding_dim = 128
num_heads = 4
ff_dim = 512
num_layers = 6
num_classes = 10  # e.g. the 10 classes of CIFAR-10

vit = VisionTransformer(num_patches, embedding_dim, num_heads, ff_dim,
                        num_layers, num_classes)
images = tf.random.uniform((32, image_size, image_size, 3))  # batch of 32 RGB 224x224 images
output = vit(images)
assert output.shape == (32, num_classes), output.shape
print(output.shape)  # expected: (32, 10)


(32, 10)
