In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Dense, Flatten, Dropout, LayerNormalization
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
import numpy as np


In [None]:
class PatchEmbedding(Layer):
    """Cut each image into non-overlapping square patches and embed them linearly.

    Args:
        patch_size: Side length (pixels) of each square patch; also the stride,
            so patches do not overlap. With ``padding='VALID'`` any border that
            does not fill a whole patch is discarded.
        embed_dim: Output width of the ``Dense`` projection applied to every
            flattened patch.

    Call returns a ``[batch, num_patches, embed_dim]`` tensor.
    """

    def __init__(self, patch_size, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.projection = Dense(embed_dim)

    def call(self, images):
        # Same window and stride -> a tiling of the image with no overlap.
        window = [1, self.patch_size, self.patch_size, 1]
        tiles = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding='VALID',
        )
        # Each tile is already flattened into the last axis; fold the spatial
        # grid of tiles into a single sequence axis. The batch size is read
        # dynamically so the layer works with an unknown batch dimension.
        flat_len = tiles.shape[-1]
        sequence = tf.reshape(tiles, [tf.shape(images)[0], -1, flat_len])
        return self.projection(sequence)


In [None]:
class PositionalEncoding(Layer):
    """Add a fixed sinusoidal position signal to a sequence of patch embeddings.

    The table is computed once in ``__init__`` as a plain tensor, so it is not a
    trainable weight. Layout: sine terms fill the first ``ceil(embed_dim / 2)``
    features and cosine terms follow. They are concatenated, not interleaved as
    in the original Transformer paper. The table is then truncated to
    ``embed_dim`` columns.

    Args:
        num_patches: Number of positions in the table; the input's second-to-last
            axis must match it for the add to broadcast.
        embed_dim: Feature width; must equal the input's last-axis size.
        **kwargs: Forwarded to ``Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, num_patches, embed_dim, **kwargs):
        super(PositionalEncoding, self).__init__(**kwargs)
        self.pos_encoding = self.positional_encoding(num_patches, embed_dim)

    def positional_encoding(self, num_patches, embed_dim):
        """Return a float32 ``[num_patches, embed_dim]`` sinusoidal table."""
        # Column vector of positions 0..num_patches-1, shape [num_patches, 1].
        positions = tf.range(num_patches, dtype=tf.float32)[:, tf.newaxis]
        # Frequencies 1 / 10000^(k / embed_dim) for k = 0, 2, 4, ... < embed_dim.
        div_term = tf.exp(tf.range(0, embed_dim, 2, dtype=tf.float32) * -(tf.math.log(10000.0) / embed_dim))

        # Angle grid [num_patches, ceil(embed_dim / 2)], one column per frequency.
        angles = positions * div_term
        # All sines first, then all cosines, concatenated on the feature axis.
        pos_encoding = tf.concat([tf.sin(angles), tf.cos(angles)], axis=1)

        # The concat has 2 * ceil(embed_dim / 2) columns; trimming drops the last
        # cosine column when embed_dim is odd.
        return pos_encoding[:, :embed_dim]

    def call(self, x):
        # Cast the float32 table to the input's dtype so the add also works when
        # activations are e.g. float16 under a mixed-precision policy (TF does not
        # auto-promote). For float32 inputs the cast is a no-op.
        return x + tf.cast(self.pos_encoding, x.dtype)


In [None]:
class TransformerEncoderBlock(Layer):
    """Post-norm Transformer encoder block: self-attention followed by an MLP.

    Each sub-layer's output goes through dropout, is added back onto that
    sub-layer's input (residual connection), and the sum is layer-normalised.

    Args:
        embed_dim: Token width. Also passed as ``key_dim`` (per-head query/key
            size) to ``MultiHeadAttention`` and as the MLP's output width.
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the two-layer ReLU feed-forward network.
        dropout_rate: Dropout probability applied to each sub-layer's output.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, dropout_rate=0.1):
        super().__init__()
        # Creation order is kept stable so auto-generated sub-layer names match.
        self.att = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = tf.keras.Sequential([
            Dense(ff_dim, activation='relu'),
            Dense(embed_dim),
        ])
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(dropout_rate)
        self.dropout2 = Dropout(dropout_rate)

    def call(self, inputs, training=None):
        # Sub-layer 1: self-attention (query and value are both `inputs`).
        attended = self.dropout1(self.att(inputs, inputs), training=training)
        hidden = self.layernorm1(inputs + attended)
        # Sub-layer 2: position-wise feed-forward network.
        transformed = self.dropout2(self.ffn(hidden), training=training)
        return self.layernorm2(hidden + transformed)


In [None]:
def create_vit_model(input_shape, patch_size, embed_dim, num_heads, ff_dim, num_layers, num_classes,
                     dropout_rate=0.1):
    """Build a small Vision Transformer classifier as a Keras functional model.

    Pipeline: patch embedding -> sinusoidal positional encoding -> ``num_layers``
    encoder blocks -> LayerNorm -> flatten all tokens -> Dense(ReLU) -> dropout
    -> softmax over ``num_classes``.

    Args:
        input_shape: ``(height, width, channels)`` of one image.
        patch_size: Side length of the square, non-overlapping patches.
        embed_dim: Token embedding width.
        num_heads: Attention heads per encoder block.
        ff_dim: Hidden width of each block's MLP and of the classification head.
        num_layers: Number of stacked encoder blocks.
        num_classes: Size of the softmax output.
        dropout_rate: Dropout used inside every encoder block and before the
            final Dense layer. The default 0.1 reproduces the previous behaviour.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping images to class probabilities.

    Raises:
        ValueError: If ``patch_size`` is not positive or no full patch fits in
            the image.
    """
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    # extract_patches uses padding='VALID', so partial patches at the border are
    # dropped; floor division gives the matching token count.
    num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
    if num_patches == 0:
        raise ValueError(
            f"patch_size={patch_size} is larger than the image {input_shape[:2]}; no patches fit"
        )

    inputs = tf.keras.Input(shape=input_shape)
    patch_tokens = PatchEmbedding(patch_size, embed_dim)(inputs)
    x = PositionalEncoding(num_patches, embed_dim)(patch_tokens)

    for _ in range(num_layers):
        x = TransformerEncoderBlock(embed_dim, num_heads, ff_dim, dropout_rate=dropout_rate)(x)

    # Classification head over the flattened sequence of all patch tokens.
    x = LayerNormalization(epsilon=1e-6)(x)
    x = Flatten()(x)
    x = Dense(ff_dim, activation='relu')(x)
    x = Dropout(dropout_rate)(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    return tf.keras.Model(inputs=inputs, outputs=outputs)


In [None]:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Scale 0-255 pixel values to [0, 1] directly as float32. The previous
# astype("int32") / 255.0 produced float64 arrays (twice the memory), which Keras
# would only cast back down to float32 for training.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# One-hot encode labels (10 CIFAR-10 classes) to match categorical_crossentropy.
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)

# Sanity checks so a silent preprocessing regression fails loudly here.
assert x_train.dtype == np.float32 and x_test.dtype == np.float32
assert 0.0 <= x_train.min() and x_train.max() <= 1.0
assert y_train.shape == (len(x_train), 10) and y_test.shape == (len(x_test), 10)


In [None]:
# Seed Python, NumPy and TensorFlow RNGs before any weights are created so that
# initialisation, dropout masks and per-epoch shuffling are repeatable across
# Restart & Run All (GPU kernels may still add some nondeterminism).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

vit_model = create_vit_model(
    input_shape=(32, 32, 3),  # CIFAR-10 images: 32x32 RGB
    patch_size=4,             # 32 / 4 = 8 patches per side -> 64 tokens
    embed_dim=64,
    num_heads=4,
    ff_dim=128,
    num_layers=8,
    num_classes=10
)

In [None]:
vit_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

# Treat 100 epochs as an upper bound. Stop once validation loss has not improved
# for 10 epochs, and restore the best weights seen rather than keeping whatever
# the final epoch produced.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=10, restore_best_weights=True
)

# Train Model. validation_split=0.2 holds out the *last* 20% of x_train/y_train
# (selected before shuffling), leaving x_test untouched for final evaluation.
history = vit_model.fit(
    x_train, y_train,
    batch_size=64,
    epochs=100,
    validation_split=0.2,
    callbacks=[early_stopping],
    verbose=1,
)

Epoch 1/100
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m75s[0m 43ms/step - accuracy: 0.1885 - loss: 2.1956 - val_accuracy: 0.3933 - val_loss: 1.7012
Epoch 2/100
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 26ms/step - accuracy: 0.3994 - loss: 1.6703 - val_accuracy: 0.4722 - val_loss: 1.4432
Epoch 3/100
[1m289/625[0m [32m━━━━━━━━━[0m[37m━━━━━━━━━━━[0m [1m8s[0m 25ms/step - accuracy: 0.4653 - loss: 1.4883