# Image classification with Vision Transformer

**Author:** [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)<br>
**Date created:** 2021/01/18<br>
**Last modified:** 2021/01/18<br>
**Description:** Implementing the Vision Transformer (ViT) model for image classification.

## Introduction

This example implements the [Vision Transformer (ViT)](https://arxiv.org/abs/2010.11929)
model by Alexey Dosovitskiy et al. for image classification,
and demonstrates it on the CIFAR-100 dataset.
The ViT model applies the Transformer architecture with self-attention to sequences of
image patches, without using convolution layers.

## Setup

In [7]:
import os

os.environ["KERAS_BACKEND"] = "jax"  # @param ["tensorflow", "jax", "torch"]

import keras
from keras import layers
from keras import ops

import numpy as np
import matplotlib.pyplot as plt
!pip install tensorflow
!pip install tensorflow-io

Collecting tensorflow-io
  Downloading tensorflow_io-0.37.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (14 kB)
Downloading tensorflow_io-0.37.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (49.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.6/49.6 MB[0m [31m26.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: tensorflow-io
Successfully installed tensorflow-io-0.37.1


## Prepare the data

In [10]:
import tensorflow as tf
import os
import random
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers  # Import layers module
import tensorflow_io as tfio

# ... (Other code remains the same) ...

def load_and_preprocess_video(path, label):
    """Loads a video from the given path, preprocesses it, and returns the video tensor and label."""

    # Use tf.strings.regex_full_match for TensorFlow-compatible string matching
    def _load_video(path):
        raw = tf.io.read_file(path)
        if tf.strings.regex_full_match(path, '.*\\.mp4'):
            video = tfio.experimental.av.decode_mp4(raw)
        elif tf.strings.regex_full_match(path, '.*\\.avi'):
            video = tfio.experimental.av.decode_avi(raw)
        else:
            # If the video format is not supported, raise an exception.
            raise ValueError(f'Unsupported video format: {path}')
        return video

    # Use tf.py_function to wrap the loading and augmentation within a TensorFlow operation
    video = tf.py_function(_load_video, [path], tf.uint8)
    video = augment_video(video)

    return video, label

# Apply load_and_preprocess_video to the train_dataset
train_dataset = train_dataset.map(
    load_and_preprocess_video,  # Apply this function
    num_parallel_calls=tf.data.AUTOTUNE,
)

# ... (Rest of the code - model training, etc.) ...

AttributeError: in user code:

    File "<ipython-input-10-5fdc94bb7655>", line 17, in load_and_preprocess_video  *
        if path.endswith('.mp4'):

    AttributeError: 'SymbolicTensor' object has no attribute 'endswith'


In [9]:
import tensorflow as tf
import os
import random
from sklearn.model_selection import train_test_split

# Define the path to your dataset.
dataset_path = "/content/drive/MyDrive/UCF50"  # Replace with the actual path

# Define the desired train/test split ratio.
train_ratio = 0.8

# Create a list of video file paths and their corresponding labels.
video_paths = []
labels = []
class_names = []  # To store class names

for class_index, class_name in enumerate(os.listdir(dataset_path)):
    class_path = os.path.join(dataset_path, class_name)
    class_names.append(class_name)  # Store the class name
    for video_file in os.listdir(class_path):
        video_path = os.path.join(class_path, video_file)
        video_paths.append(video_path)
        labels.append(class_index)  # Assign label based on class index

# Split the data into train and test sets.
train_paths, test_paths, train_labels, test_labels = train_test_split(
    video_paths, labels, test_size=1 - train_ratio, random_state=42
)

# Create TensorFlow datasets for train and test sets.
train_dataset = tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
test_dataset = tf.data.Dataset.from_tensor_slices((test_paths, test_labels))

# ... (Rest of the code - preprocessing, model training, etc.) ...

## Configure the hyperparameters

In [4]:
learning_rate = 0.001  # You might need to adjust this during training
weight_decay = 0.0001  # Standard value for weight decay
batch_size = 32  # Adjust based on your GPU memory
num_epochs = 100  # Increase for better convergence
image_size = 224  # ViT Base 16 input size
patch_size = 16  # ViT Base 16 patch size
num_patches = (image_size // patch_size) ** 2
projection_dim = 768  # ViT Base 16 projection dimension
num_heads = 12  # ViT Base 16 number of heads
transformer_units = [
    projection_dim * 4,
    projection_dim,
]  # Standard transformer units
transformer_layers = 12  # ViT Base 16 number of transformer layers
mlp_head_units = [
    projection_dim * 4,
    projection_dim,
]  # Adjust based on your task

## Use data augmentation

In [5]:
def augment_video(video):
    """Applies data augmentation to a video tensor.

    Args:
        video: A tensor representing the video, with shape (num_frames, height, width, channels).

    Returns:
        The augmented video tensor.
    """

    # Create a list to store the augmented frames.
    augmented_frames = []

    # Apply augmentations to each frame.
    for frame in tf.unstack(video, axis=0):  # Iterate through frames
        frame = layers.Normalization()(frame)  # Normalize
        frame = layers.Resizing(image_size, image_size)(frame)  # Resize
        frame = layers.RandomFlip("horizontal")(frame)  # Random flip
        frame = layers.RandomRotation(factor=0.02)(frame)  # Random rotation
        frame = layers.RandomZoom(height_factor=0.2, width_factor=0.2)(frame)  # Random zoom
        augmented_frames.append(frame)

    # Stack the augmented frames back into a video tensor.
    augmented_video = tf.stack(augmented_frames, axis=0)

    return augmented_video

# Now, in your data loading pipeline, apply this function:
train_dataset = train_dataset.map(
    lambda video, label: (augment_video(video), label),
    num_parallel_calls=tf.data.AUTOTUNE,
)

ValueError: in user code:

    File "<ipython-input-5-4767a3257108>", line 30, in None  *
        lambda video, label: (augment_video(video), label)
    File "<ipython-input-5-4767a3257108>", line 15, in augment_video  *
        for frame in tf.unstack(video, axis=0):  # Iterate through frames

    ValueError: Argument `axis` = 0 not in range [0, 0)


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Implement multilayer perceptron (MLP)

In [None]:

def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=keras.activations.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x


## Implement patch creation as a layer

In [None]:

class Patches(layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        input_shape = ops.shape(images)
        batch_size = input_shape[0]
        height = input_shape[1]
        width = input_shape[2]
        channels = input_shape[3]
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size
        patches = keras.ops.image.extract_patches(images, size=self.patch_size)
        patches = ops.reshape(
            patches,
            (
                batch_size,
                num_patches_h * num_patches_w,
                self.patch_size * self.patch_size * channels,
            ),
        )
        return patches

    def get_config(self):
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config


Let's display patches for a sample image

In [None]:
plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))]
plt.imshow(image.astype("uint8"))
plt.axis("off")

resized_image = ops.image.resize(
    ops.convert_to_tensor([image]), size=(image_size, image_size)
)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")

n = int(np.sqrt(patches.shape[1]))
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = ops.reshape(patch, (patch_size, patch_size, 3))
    plt.imshow(ops.convert_to_numpy(patch_img).astype("uint8"))
    plt.axis("off")

## Implement the patch encoding layer

The `PatchEncoder` layer will linearly transform a patch by projecting it into a
vector of size `projection_dim`. In addition, it adds a learnable position
embedding to the projected vector.

In [None]:

class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = ops.expand_dims(
            ops.arange(start=0, stop=self.num_patches, step=1), axis=0
        )
        projected_patches = self.projection(patch)
        encoded = projected_patches + self.position_embedding(positions)
        return encoded

    def get_config(self):
        config = super().get_config()
        config.update({"num_patches": self.num_patches})
        return config


## Build the ViT model

The ViT model consists of multiple Transformer blocks,
which use the `layers.MultiHeadAttention` layer as a self-attention mechanism
applied to the sequence of patches. The Transformer blocks produce a
`[batch_size, num_patches, projection_dim]` tensor, which is processed via an
classifier head with softmax to produce the final class probabilities output.

Unlike the technique described in the [paper](https://arxiv.org/abs/2010.11929),
which prepends a learnable embedding to the sequence of encoded patches to serve
as the image representation, all the outputs of the final Transformer block are
reshaped with `layers.Flatten()` and used as the image
representation input to the classifier head.
Note that the `layers.GlobalAveragePooling1D` layer
could also be used instead to aggregate the outputs of the Transformer block,
especially when the number of patches and the projection dimensions are large.

In [None]:

def create_vit_classifier():
    inputs = keras.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches.
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Create multiple layers of the Transformer block.
    for _ in range(transformer_layers):
        # Layer normalization 1.
        x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        # Create a multi-head attention layer.
        attention_output = layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = layers.Add()([attention_output, encoded_patches])
        # Layer normalization 2.
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        # MLP.
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = layers.Add()([x3, x2])

    # Create a [batch_size, projection_dim] tensor.
    representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = layers.Flatten()(representation)
    representation = layers.Dropout(0.5)(representation)
    # Add MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Classify outputs.
    logits = layers.Dense(num_classes)(features)
    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=logits)
    return model


## Compile, train, and evaluate the mode

In [None]:

def run_experiment(model):
    optimizer = keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    return history


vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)


def plot_history(item):
    plt.plot(history.history[item], label=item)
    plt.plot(history.history["val_" + item], label="val_" + item)
    plt.xlabel("Epochs")
    plt.ylabel(item)
    plt.title("Train and Validation {} Over Epochs".format(item), fontsize=14)
    plt.legend()
    plt.grid()
    plt.show()


plot_history("loss")
plot_history("top-5-accuracy")


After 100 epochs, the ViT model achieves around 55% accuracy and
82% top-5 accuracy on the test data. These are not competitive results on the CIFAR-100 dataset,
as a ResNet50V2 trained from scratch on the same data can achieve 67% accuracy.

Note that the state of the art results reported in the
[paper](https://arxiv.org/abs/2010.11929) are achieved by pre-training the ViT model using
the JFT-300M dataset, then fine-tuning it on the target dataset. To improve the model quality
without pre-training, you can try to train the model for more epochs, use a larger number of
Transformer layers, resize the input images, change the patch size, or increase the projection dimensions.
Besides, as mentioned in the paper, the quality of the model is affected not only by architecture choices,
but also by parameters such as the learning rate schedule, optimizer, weight decay, etc.
In practice, it's recommended to fine-tune a ViT model
that was pre-trained using a large, high-resolution dataset.