<a href="https://colab.research.google.com/github/amimulhasan/Deep-Learning/blob/main/vit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np


In [2]:
# Number of target classes in CIFAR-100 (used for one-hot encoding below).
num_classes = 100

# Download (first run only) and load the CIFAR-100 train/test splits.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

# Rescale raw pixel values from [0, 255] to [0, 1] as float32.
x_train, x_test = (split.astype("float32") / 255.0 for split in (x_train, x_test))

# Turn integer class labels into one-hot vectors for CategoricalCrossentropy.
y_train, y_test = (
    keras.utils.to_categorical(labels, num_classes) for labels in (y_train, y_test)
)


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
[1m169001437/169001437[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 0us/step


In [3]:
# Parameters
input_shape = (32, 32, 3)  # (height, width, channels) of each input image
patch_size = 4  # 4x4 patches
# Count patches per spatial axis so non-square inputs are also handled
# correctly; leftover pixels that do not fill a whole patch are dropped by the
# 'VALID' padding used in PatchExtractor.
num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
projection_dim = 64  # width of every patch token embedding

# Patch extraction layer
class PatchExtractor(layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        patch_size: side length in pixels of each square patch. It is used as
            both the window size and the stride, so patches do not overlap.
        **kwargs: standard Keras ``Layer`` arguments (e.g. ``name``, ``dtype``).

    Input: a rank-4 image batch, (batch, height, width, channels).
    Output: (batch, num_patches, patch_size * patch_size * channels).
    """

    def __init__(self, patch_size, **kwargs):
        # Forward Layer kwargs so the layer can be named and rebuilt from config.
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]  # dynamic: unknown at graph-build time
        window = [1, self.patch_size, self.patch_size, 1]
        patches = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding="VALID",  # drop trailing pixels that don't fill a patch
        )
        # Static last dim (patch_size * patch_size * channels) lets Keras infer shapes.
        patch_dims = patches.shape[-1]
        # Collapse the (rows, cols) patch grid into a single sequence axis.
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        """Return the config needed to re-create this layer on model load."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

# Patch encoding layer
class PatchEncoder(layers.Layer):
    """Linearly project each patch and add a learned position embedding.

    Args:
        num_patches: number of patches per image. It sizes the position
            embedding table and the range of position ids used in ``call``.
        projection_dim: width of the output token embeddings.
        **kwargs: standard Keras ``Layer`` arguments (e.g. ``name``, ``dtype``).
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        # Keep the patch count on the instance: ``call`` previously read the
        # module-level ``num_patches`` global instead of this argument.
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patches):
        # Assumes patches is (batch, num_patches, patch_dim), as produced by
        # PatchExtractor. The (num_patches, projection_dim) position embeddings
        # broadcast across the batch axis.
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patches) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the config needed to re-create this layer on model load."""
        config = super().get_config()
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config


In [4]:
def create_vit_classifier(
    num_transformer_layers=4,
    num_heads=4,
    attention_dropout=0.1,
    mlp_head_units=128,
    head_dropout=0.5,
):
    """Build an (uncompiled) Vision Transformer classifier.

    Reads the notebook-level settings ``input_shape``, ``patch_size``,
    ``num_patches``, ``projection_dim`` and ``num_classes``. The default
    arguments reproduce the original 4-block / 4-head architecture.

    Args:
        num_transformer_layers: number of Transformer encoder blocks.
        num_heads: attention heads per block (each with ``key_dim=projection_dim``).
        attention_dropout: dropout rate inside ``MultiHeadAttention``.
        mlp_head_units: width of the hidden Dense layer in the classification head.
        head_dropout: dropout rate applied before and after that hidden layer.

    Returns:
        A ``keras.Model`` mapping images to ``num_classes`` unnormalized logits
        (no softmax), so it should be trained with a ``from_logits=True`` loss.
    """
    inputs = layers.Input(shape=input_shape)
    # Images -> sequence of flattened patches -> projected + position-embedded tokens.
    patches = PatchExtractor(patch_size)(inputs)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    for _ in range(num_transformer_layers):
        # Attention sub-block: LayerNorm is applied *before* attention (pre-norm).
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=attention_dropout
        )(x1, x1)  # self-attention: query and value are the same tensor
        x2 = layers.Add()([attention_output, encoded_patches])  # residual 1

        # MLP sub-block: expand to 2x width, project back to projection_dim so
        # the residual addition below has matching shapes.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = layers.Dense(units=projection_dim * 2, activation=tf.nn.gelu)(x3)
        x3 = layers.Dense(units=projection_dim)(x3)
        encoded_patches = layers.Add()([x3, x2])  # residual 2

    # Classification head: flatten all patch tokens (no CLS token is used).
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(head_dropout)(representation)
    features = layers.Dense(units=mlp_head_units, activation=tf.nn.gelu)(representation)
    features = layers.Dropout(head_dropout)(features)
    logits = layers.Dense(units=num_classes)(features)

    return keras.Model(inputs=inputs, outputs=logits)


In [None]:
# Seed Python, NumPy and TensorFlow RNGs *before* building the model so that
# weight initialisation, dropout masks and per-epoch shuffling repeat across
# runs. (Some GPU kernels may still be nondeterministic.)
SEED = 42
keras.utils.set_random_seed(SEED)

# Create the model
model = create_vit_classifier()

# Compile the model. The last layer emits raw logits, hence from_logits=True.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

# Train the model; Keras holds out the last 10% of x_train/y_train for validation.
# NOTE(review): a recorded run plateaued at loss ~4.605 ~= ln(100), i.e. chance
# level for 100 classes. If that recurs, try a lower learning rate or warmup.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=64,
    epochs=10,
    validation_split=0.1,
)


Epoch 1/10
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m442s[0m 604ms/step - accuracy: 0.0102 - loss: 4.6476 - val_accuracy: 0.0078 - val_loss: 4.6063
Epoch 2/10
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m429s[0m 587ms/step - accuracy: 0.0082 - loss: 4.6054 - val_accuracy: 0.0078 - val_loss: 4.6069
Epoch 3/10
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m442s[0m 587ms/step - accuracy: 0.0094 - loss: 4.6056 - val_accuracy: 0.0078 - val_loss: 4.6072
Epoch 4/10
[1m704/704[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m417s[0m 592ms/step - accuracy: 0.0098 - loss: 4.6053 - val_accuracy: 0.0078 - val_loss: 4.6074
Epoch 5/10


In [None]:
# Evaluate on the test set
test_loss, test_accuracy = model.evaluate(x_test, y_test)
# Report accuracy as a percentage with two decimals. With 100 classes, chance is
# 1%, so the previous `:.2f` format (e.g. 0.0078 -> "0.01") hid real differences.
print(f"Test loss: {test_loss:.4f}")
print(f"Test accuracy: {test_accuracy:.2%}")
