# Module 3, Task 7: Vision Transformers in Keras

**Objective:** Implement, train, and evaluate a Vision Transformer (ViT) model for image classification using TensorFlow and Keras, and compare its performance to the CNN model.

In [None]:
# Install necessary libraries
!pip install tensorflow tensorflow-datasets matplotlib scikit-learn tensorflow-addons

### Introduction to Vision Transformers (ViT)

The Vision Transformer (ViT) is a model that applies the Transformer architecture, originally successful in Natural Language Processing (NLP), to computer vision tasks. It works as follows:

1.  **Image Patching:** The input image is split into a sequence of fixed-size, non-overlapping patches.
2.  **Linear Projection:** Each patch is flattened and linearly projected into an embedding vector.
3.  **Positional Embeddings:** Position embeddings are added to the patch embeddings to retain spatial information.
4.  **Transformer Encoder:** The resulting sequence of vectors is fed into a standard Transformer Encoder, which uses self-attention mechanisms to learn relationships between patches.
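5.  **Classification Head:** In the original ViT, the output corresponding to a special `[CLS]` token is passed through a classifier (MLP head) to produce the final prediction. The simplified model in this notebook has no `[CLS]` token; it instead flattens the encoded patches from the last Transformer block and feeds them to the MLP head.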

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report

# --- Load EuroSAT and define pipeline constants (reused from M2_04) ---
# The first 80% of the 'train' split is used for training, the last 20% for validation.
(ds_train, ds_validation), ds_info = tfds.load(
    'eurosat/rgb',
    split=['train[:80%]', 'train[80%:]'],
    shuffle_files=True,
    as_supervised=True,  # elements are (image, label) pairs
    with_info=True,      # also return the dataset metadata as ds_info
)

AUTOTUNE = tf.data.AUTOTUNE
BATCH_SIZE = 128  # ViTs are memory intensive; lower this if memory runs out
IMG_SIZE = 72     # side length, in pixels, that every image is resized to
NUM_CLASSES = ds_info.features['label'].num_classes

def process_image(image, label):
    """Resize an image to IMG_SIZE x IMG_SIZE and scale its values by 1/255.

    Args:
        image: Image tensor to preprocess.
        label: Label belonging to the image; passed through untouched.

    Returns:
        A `(image, label)` tuple where the image is float32.
    """
    resized = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
    scaled = tf.cast(resized, tf.float32) / 255.0
    return scaled, label

# Augmentation pipeline built from concrete layers (a bare `[...]` placeholder is a
# literal Ellipsis, which Keras does not accept as a layer).
# NOTE(review): the augmentation used in M2_04 is not visible here -- align these
# layers with it. The pipeline is defined for optional use; the dataset pipeline
# below does not apply it.
data_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal_and_vertical"),
        layers.RandomRotation(0.2),
    ]
)

def configure_dataset(ds, shuffle=False, shuffle_buffer_size=1000):
    """Build an input pipeline: preprocess, cache, optionally shuffle, batch, prefetch.

    Args:
        ds: A `tf.data.Dataset` whose elements are `(image, label)` pairs.
        shuffle: If True, shuffle individual examples (intended for training data).
        shuffle_buffer_size: Number of examples held in the shuffle buffer;
            only used when `shuffle` is True.

    Returns:
        The configured dataset, yielding batches of `BATCH_SIZE` examples.
    """
    ds = ds.map(process_image, num_parallel_calls=AUTOTUNE)
    # Cache the preprocessed examples first so the shuffle order is not frozen
    # into the cache.
    ds = ds.cache()
    if shuffle:
        # Shuffle before batching so single examples, not whole batches, are mixed.
        ds = ds.shuffle(shuffle_buffer_size)
    return ds.batch(BATCH_SIZE).prefetch(buffer_size=AUTOTUNE)

# Only the training pipeline requests shuffling; validation is built without it.
train_ds = configure_dataset(ds_train, shuffle=True)
validation_ds = configure_dataset(ds_validation, shuffle=False)
print("Data pipelines are ready for ViT.")

### Building the ViT Model
We will implement a simplified ViT model from scratch using Keras layers.

In [None]:
# --- Hyperparameters ---
PATCH_SIZE = 6                               # side length of one square patch, in pixels
NUM_PATCHES = (IMG_SIZE // PATCH_SIZE) ** 2  # patches per side, squared
PROJECTION_DIM = 64                          # size of each patch embedding
NUM_HEADS = 4                                # attention heads per Transformer block
TRANSFORMER_LAYERS = 4                       # number of stacked Transformer blocks
# Dense widths of the MLP inside each block; the last one equals PROJECTION_DIM
# so its output can be added to the residual branch.
TRANSFORMER_UNITS = [2 * PROJECTION_DIM, PROJECTION_DIM]
MLP_HEAD_UNITS = [2048, 1024]                # Dense widths of the classification head

# --- Keras Layers Implementation ---
class Patches(layers.Layer):
    """Layer that cuts a batch of images into flattened, non-overlapping square patches.

    Args:
        patch_size: Side length of each patch, in pixels.
        **kwargs: Standard `keras.layers.Layer` arguments (e.g. `name`, `dtype`).
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Return the patches of `images` reshaped to `(batch, num_patches, patch_dims)`."""
        batch_size = tf.shape(images)[0]
        # Stride equals the patch size, so patches tile the image without overlap;
        # "VALID" padding drops any remainder at the borders.
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # The last axis holds one flattened patch.
        patch_dims = patches.shape[-1]
        # Collapse everything between the batch axis and the patch axis into one
        # sequence axis.
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        """Return the constructor arguments so the layer can be saved and re-created."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

class PatchEncoder(layers.Layer):
    """Layer that linearly projects patches and adds a learned position embedding.

    Args:
        num_patches: Number of patches per image; also the size of the position
            embedding table.
        projection_dim: Size of the vector each patch is projected to.
        **kwargs: Standard `keras.layers.Layer` arguments (e.g. `name`, `dtype`).
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Kept so get_config() can report it.
        self.projection_dim = projection_dim
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        """Return the projected patches plus the embedding of each patch position."""
        # One index per patch position: 0 .. num_patches - 1.
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # The position embeddings broadcast over the batch axis of the projection.
        return self.projection(patch) + self.position_embedding(positions)

    def get_config(self):
        """Return the constructor arguments so the layer can be saved and re-created."""
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,
            }
        )
        return config

def mlp(x, hidden_units, dropout_rate):
    """Pass `x` through a stack of GELU Dense layers, each followed by Dropout.

    Args:
        x: Input tensor.
        hidden_units: Iterable of Dense layer widths, applied in order.
        dropout_rate: Dropout rate applied after every Dense layer.

    Returns:
        The output tensor of the last Dropout layer (or `x` itself when
        `hidden_units` is empty).
    """
    out = x
    for width in hidden_units:
        dense = layers.Dense(width, activation=tf.nn.gelu)
        drop = layers.Dropout(dropout_rate)
        out = drop(dense(out))
    return out

def create_vit_classifier():
    """Assemble the ViT classifier as a Keras functional model.

    Pipeline: image -> patches -> patch encoding -> TRANSFORMER_LAYERS pre-norm
    Transformer blocks -> layer norm -> flatten -> MLP head -> softmax.

    Returns:
        A `keras.Model` mapping `(IMG_SIZE, IMG_SIZE, 3)` images to a softmax
        distribution over `NUM_CLASSES` classes.
    """
    image_input = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))

    # Turn each image into a sequence of encoded patches.
    patch_sequence = Patches(PATCH_SIZE)(image_input)
    tokens = PatchEncoder(NUM_PATCHES, PROJECTION_DIM)(patch_sequence)

    # Stack of Transformer blocks; `tokens` carries the running representation.
    for _block in range(TRANSFORMER_LAYERS):
        # Self-attention sub-block: normalize, attend, add the residual.
        normed = layers.LayerNormalization(epsilon=1e-6)(tokens)
        attended = layers.MultiHeadAttention(
            num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
        )(normed, normed)
        after_attention = layers.Add()([attended, tokens])

        # Feed-forward sub-block: normalize, MLP, add the residual.
        hidden = layers.LayerNormalization(epsilon=1e-6)(after_attention)
        hidden = mlp(hidden, hidden_units=TRANSFORMER_UNITS, dropout_rate=0.1)
        tokens = layers.Add()([hidden, after_attention])

    # Classification head: all encoded patches are flattened into one vector.
    flat = layers.LayerNormalization(epsilon=1e-6)(tokens)
    flat = layers.Flatten()(flat)
    flat = layers.Dropout(0.5)(flat)
    head = mlp(flat, hidden_units=MLP_HEAD_UNITS, dropout_rate=0.5)

    # Softmax output: these are class probabilities, not raw logits.
    class_probs = layers.Dense(NUM_CLASSES, activation="softmax")(head)
    return keras.Model(inputs=image_input, outputs=class_probs)

# Build the ViT model and print its layer summary.
vit_model = create_vit_classifier()
vit_model.summary()

### Training the ViT Model
ViT models typically require longer training times and more data to converge compared to CNNs. We'll use the AdamW optimizer, which often works well for Transformers.

In [None]:
# Training configuration.
EPOCHS = 20  # ViT may need more epochs
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

optimizer = tfa.optimizers.AdamW(
    weight_decay=WEIGHT_DECAY,
    learning_rate=LEARNING_RATE,
)

# The model ends in a softmax, so the loss is used with its default
# (probability) inputs.
vit_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy'],
)

# `history` records the per-epoch metrics plotted in the evaluation section.
history = vit_model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=validation_ds,
    verbose=1,
)

### Evaluation
Let's evaluate our trained ViT model.

In [None]:
# Plot training history: accuracy (left) and loss (right), training vs. validation.
# Uses the explicit Axes API and labels both axes so each panel stands on its own.
fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

ax_acc.plot(history.history['accuracy'], label='Training Accuracy')
ax_acc.plot(history.history['val_accuracy'], label='Validation Accuracy')
ax_acc.set(title='ViT Accuracy', xlabel='Epoch', ylabel='Accuracy')
ax_acc.legend()

ax_loss.plot(history.history['loss'], label='Training Loss')
ax_loss.plot(history.history['val_loss'], label='Validation Loss')
ax_loss.set(title='ViT Loss', xlabel='Epoch', ylabel='Loss')
ax_loss.legend()

fig.tight_layout()
plt.show()

# Per-class precision / recall / F1 on the validation split.
class_names = ds_info.features['label'].names

# Ground-truth labels are gathered batch by batch, in the same order in which
# predict() walks the validation pipeline.
y_true = np.concatenate([labels for _, labels in validation_ds], axis=0)
y_pred_probs = vit_model.predict(validation_ds)
y_pred = y_pred_probs.argmax(axis=1)  # most probable class per example

print("\nViT Classification Report:\n")
print(classification_report(y_true, y_pred, target_names=class_names))

### Conclusion: ViT vs. CNN

- **Performance:** On smaller datasets like EuroSAT, a well-tuned CNN (from M2_04) might outperform a ViT trained from scratch. ViTs are data-hungry and truly shine when pre-trained on massive datasets (like ImageNet-21k) and then fine-tuned.
- **Architecture:** ViTs lack the built-in inductive biases of CNNs (like translation equivariance and locality), which makes them more flexible but harder to train on smaller datasets.
- **Complexity:** Implementing a ViT from scratch is more complex than a standard CNN, but high-level libraries are making them more accessible.

This exercise demonstrates the implementation of a ViT, a powerful alternative to CNNs, showcasing the flexibility of modern deep learning architectures.