<a href="https://colab.research.google.com/github/ashaduzzaman-sarker/Computer-Vision-Projects/blob/main/MobileViT_A_mobile_friendly_Transformer_based_model_for_image_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MobileViT: A mobile-friendly Transformer-based model for image classification

**Author:** [Ashaduzzaman Piash](https://github.com/ashaduzzaman-sarker/)
<br>
**Date created:** 18/06/2024


## Introduction to MobileViT Architecture

In this example, we will implement the MobileViT architecture introduced by Mehta et al. in [MobileViT](https://doi.org/10.48550/arXiv.2110.02178). MobileViT combines the benefits of Transformers (Vaswani et al.) and convolutions to achieve an efficient, high-performance model for image recognition tasks.

### Key Features of MobileViT:
1. **Combination of Transformers and Convolutions:**
   - **Transformers:** Known for their ability to capture long-range dependencies and global representations.
   - **Convolutions:** Effective at capturing spatial relationships and modeling locality.

2. **Mobile-Friendly Design:**
   - MobileViT is specifically designed to be efficient on mobile devices while maintaining high performance.
   - It outperforms models of similar or higher complexity, such as MobileNetV3, in classification accuracy.

3. **Versatility:**
   - MobileViT serves as a general-purpose backbone, making it suitable for various image recognition tasks.

### Implementation Overview

We will break down the implementation into the following steps:
1. **Install and Import Dependencies:**
   - Upgrade to Keras 3 and use TensorFlow 2.13 or higher as the backend.
   - Import necessary modules from TensorFlow and other libraries.

2. **Define Helper Functions:**
   - Implement functions for layers that are commonly used in the architecture (e.g., convolutional layers, Transformer blocks).

3. **Build the MobileViT Block:**
   - Construct the MobileViT block, which integrates the Transformer and convolutional layers.

4. **Assemble the MobileViT Model:**
   - Combine multiple MobileViT blocks to form the full model.

5. **Compile and Train the Model:**
   - Compile the model with an appropriate optimizer and loss function.
   - Train the model on a suitable dataset.

6. **Evaluate the Model:**
   - Assess the model's performance on the held-out validation set, then convert it to TensorFlow Lite.

![](https://user-images.githubusercontent.com/67839539/136470152-2573529e-1a24-4494-821d-70eb4647a51d.png)

## Imports

In [1]:
# Update to keras 3
!pip install --upgrade keras

Collecting keras
  Downloading keras-3.3.3-py3-none-any.whl (1.1 MB)
Collecting namex (from keras)
  Downloading namex-0.0.8-py3-none-any.whl (5.8 kB)
Collecting optree (from keras)
  Downloading optree-0.11.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (311 kB)
Installing collected packages: namex, optree, keras
  Attempting uninstall: keras
    Found existing installation: keras 2.15.0
    Uninstalling keras-2.15.0:
      Successfully uninstalled keras-2.15.0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.15.0 requires keras<2.16,>=2.15.0, but you have keras 3.3.3 which is incompatible.

In [2]:
import os

# Select the Keras backend before Keras is imported
os.environ['KERAS_BACKEND'] = 'tensorflow'

import tensorflow as tf
import keras
from keras import layers
from keras import backend

import tensorflow_datasets as tfds

tfds.disable_progress_bar()

## Hyperparameters

In [3]:
# Values are from Table 4 of the paper
patch_size = 4  # Patch area p = 2x2 pixels, used by the Transformer blocks
image_size = 256
expansion_factor = 2  # Expansion factor for the MobileNetV2 (MV2) blocks
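
Note that `patch_size` is the patch *area* p, not its side length: a 2x2 patch covers 4 pixels. As a quick sanity check, here is a small sketch (the helper `count_patches` is hypothetical, not part of the notebook) that computes n = (h * w) / p for the feature-map sizes reaching the three MobileViT blocks of this model at a 256x256 input (32x32, 16x16 and 8x8):

```python
def count_patches(height, width, patch_area=4):
    """Number of non-overlapping patches n = (h * w) / p.

    Hypothetical helper for illustration; assumes h and w are multiples
    of the patch side length (2 for a 2x2 patch).
    """
    if (height * width) % patch_area != 0:
        raise ValueError("feature map does not tile evenly into patches")
    return (height * width) // patch_area


# Feature-map sizes seen by the three MobileViT blocks at a 256x256 input.
for side in (32, 16, 8):
    print(side, count_patches(side, side))
# 32 256
# 16 64
# 8 16
```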

## MobileViT utilities

In [4]:
## Define a convolutional block with Swish activation and same padding
def conv_block(x, filters=16, kernel_size=3, strides=2):
  conv_layer = layers.Conv2D(
      filters,
      kernel_size,
      strides=strides,
      activation=keras.activations.swish,
      padding='same',
  )
  return conv_layer(x)

In [5]:
## Compute the zero-padding needed before a stride-2 'valid' convolution
def correct_pad(inputs, kernel_size):
  # Spatial dims start at axis 2 for channels_first and at axis 1 for channels_last
  img_dim = 2 if backend.image_data_format() == 'channels_first' else 1
  input_size = inputs.shape[img_dim : (img_dim + 2)]

  # If kernel_size is an integer, make it a tuple
  if isinstance(kernel_size, int):
    kernel_size = (kernel_size, kernel_size)

  # Even-sized inputs need one pixel less padding on the top/left side
  if input_size[0] is None:
    adjust = (1, 1)
  else:
    adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)

  # Return ((top, bottom), (left, right)) padding values
  correct = (kernel_size[0] // 2, kernel_size[1] // 2)
  return (
      (correct[0] - adjust[0], correct[0]),
      (correct[1] - adjust[1], correct[1]),
  )
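
To see why the padding is asymmetric, here is a framework-free re-implementation of the same arithmetic (hypothetical name `correct_pad_for_size`, taking a plain `(height, width)` tuple instead of a tensor) together with the standard output-size formula for a 'valid' convolution, `(size + pad_total - kernel) // stride + 1`. An even-sized input gets one extra pixel only on the bottom/right, and both even and odd inputs come out at the same halved size under a stride-2 3x3 depthwise convolution:

```python
def correct_pad_for_size(input_size, kernel_size=3):
    """Same arithmetic as `correct_pad`, but on a plain (height, width) tuple."""
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    # Even sizes need one pixel less padding on the top/left side.
    adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)
    correct = (kernel_size[0] // 2, kernel_size[1] // 2)
    return (
        (correct[0] - adjust[0], correct[0]),
        (correct[1] - adjust[1], correct[1]),
    )


def valid_conv_output(size, pad, kernel=3, stride=2):
    """Output length of a 'valid' convolution after explicit (before, after) padding."""
    return (size + pad[0] + pad[1] - kernel) // stride + 1


for size in (128, 127):
    pads = correct_pad_for_size((size, size))
    print(size, pads, valid_conv_output(size, pads[0]))
# 128 ((0, 1), (0, 1)) 64
# 127 ((1, 1), (1, 1)) 64
```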

In [6]:
## Define an inverted residual block
def inverted_residual_block(x, expanded_channels, output_channels, strides=1):
  # Expand the number of channels with a pointwise convolution
  m = layers.Conv2D(expanded_channels, 1, padding='same', use_bias=False)(x)
  m = layers.BatchNormalization()(m)
  m = keras.activations.swish(m)

  # If stride is 2, add zero padding
  if strides == 2:
    m = layers.ZeroPadding2D(padding=correct_pad(m, 3))(m)

  # Depthwise convolution with specified stride
  m = layers.DepthwiseConv2D(
      3, strides=strides, padding='same' if strides == 1 else 'valid', use_bias=False
  )(m)
  m = layers.BatchNormalization()(m)
  m = keras.activations.swish(m)

  # Project back to the desired number of output channels with a pointwise convolution
  m = layers.Conv2D(output_channels, 1, padding='same', use_bias=False)(m)
  m = layers.BatchNormalization()(m)

  # If input and output shapes match, add the input to the output (skip connection)
  if x.shape[-1] == output_channels and strides == 1:
    return layers.Add()([m, x])
  return m
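
A useful cross-check against `mobilevit_xxs.summary()` is to count the parameters of one inverted residual block by hand. The sketch below (hypothetical helper `mv2_param_count`) follows the layers of `inverted_residual_block`: a bias-free 1x1 expansion, a bias-free 3x3 depthwise convolution, a bias-free 1x1 projection, and three BatchNormalization layers that each hold 4 values per channel (gamma, beta, moving mean and moving variance; only the first two are trainable). The stride does not change the count.

```python
def mv2_param_count(in_channels, expanded_channels, out_channels, kernel=3):
    """Total parameters of `inverted_residual_block` (all convolutions are bias-free)."""
    expand = in_channels * expanded_channels          # 1x1 pointwise conv
    depthwise = kernel * kernel * expanded_channels   # one 3x3 filter per channel
    project = expanded_channels * out_channels        # 1x1 pointwise conv
    # Each BatchNormalization stores gamma, beta, moving mean and moving variance.
    batch_norm = 4 * expanded_channels + 4 * expanded_channels + 4 * out_channels
    return expand + depthwise + project + batch_norm


# First MV2 block in `create_mobilevit`: 16 -> 32 (expanded) -> 16 channels.
print(mv2_param_count(16, 16 * 2, 16))  # 1632
```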


In [7]:
## Define Multi-Layer Perceptron (MLP) Function (mlp)
def mlp(x, hidden_units, dropout_rate):
  for units in hidden_units:
    x = layers.Dense(units, activation=keras.activations.swish)(x)
    x = layers.Dropout(dropout_rate)(x)
  return x

In [8]:
## Define Transformer Block function
def transformer_block(x, transformer_layers, projection_dim, num_heads=2):
  for _ in range(transformer_layers):
    # Layer normalization 1
    x1 = layers.LayerNormalization(epsilon=1e-6)(x)
    # Create a multi-head attention layer
    attention_output = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=projection_dim, dropout=0.1
    )(x1, x1)

    # Skip connection 1
    x2 = layers.Add()([attention_output, x])

    # Layer normalization 2
    x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
    # MLP
    x3 = mlp(
        x3,
        hidden_units=[x.shape[-1] * 2, x.shape[-1]],
        dropout_rate=0.1,
    )

    # Skip Connection 2
    x = layers.Add()([x3, x2])

  return x
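
One subtlety: `layers.MultiHeadAttention` is called here on a 4D tensor of shape `(batch, p, n, d)`. With its default `attention_axes=None`, Keras attends over all axes except batch, heads and features, so all p * n tokens attend to each other. The paper instead applies the Transformer among the n patches separately for each of the p pixel positions; in Keras that would presumably correspond to `attention_axes=(2,)`, which we have not tested here. The NumPy sketch below illustrates the paper's per-position variant under simplifying assumptions (single head, identity query/key/value projections instead of learned ones; the function name is hypothetical):

```python
import numpy as np


def per_position_attention(x):
    """Self-attention over the patch axis, computed separately for each pixel position.

    x has shape (p, n, d): p pixel positions per patch, n patches, d channels.
    Simplified sketch: single head, identity Q/K/V projections (an assumption;
    the real layer learns these projections).
    """
    d = x.shape[-1]
    # scores[i, a, b]: similarity of patch a and patch b at pixel position i.
    scores = np.einsum("iad,ibd->iab", x, x) / np.sqrt(d)
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over the n patches
    return np.einsum("iab,ibd->iad", weights, x), weights


rng = np.random.default_rng(0)
tokens = rng.normal(size=(4, 256, 64))  # p=4, n=256, d=64 (first MobileViT block)
out, attn = per_position_attention(tokens)
print(out.shape, attn.shape)  # (4, 256, 64) (4, 256, 256)
```

Attention never mixes different pixel positions here, which keeps the cost at p attention maps of size n x n instead of one map of size (p * n) x (p * n).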

In [9]:
## Define MobileViT Block
def mobilevit_block(x, num_blocks, projection_dim, strides=1):
  # Local projection with convolutions
  local_features = conv_block(x, filters=projection_dim, strides=strides)
  local_features = conv_block(
      local_features, filters=projection_dim, kernel_size=1, strides=strides
  )

  # Unfold into patches and then pass through Transformers.
  num_patches = int((local_features.shape[1] * local_features.shape[2]) / patch_size)
  non_overlapping_patches = layers.Reshape((patch_size, num_patches, projection_dim))(
      local_features
  )
  global_features = transformer_block(
      non_overlapping_patches, num_blocks, projection_dim
  )

  # Fold into conv-like feature-maps
  folded_feature_map = layers.Reshape((*local_features.shape[1:-1], projection_dim))(
      global_features
  )

  # Apply point_wise conv -> concatenate with the input features
  folded_feature_map = conv_block(
      folded_feature_map, filters=x.shape[-1], kernel_size=1, strides=strides
  )
  local_global_features = layers.Concatenate(axis=-1)([x, folded_feature_map])

  # Fuse the local and global features using a convolutional layer
  local_global_features = conv_block(
      local_global_features, filters=projection_dim, strides=strides
  )

  return local_global_features
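
To make the data flow concrete, the sketch below (hypothetical helper `mobilevit_block_shapes`) traces the tensor shapes through `mobilevit_block` with `strides=1`, batch dimension omitted, for the first MobileViT block of the model (a 32x32 feature map with 48 channels and `projection_dim=64`):

```python
def mobilevit_block_shapes(h, w, in_channels, projection_dim, patch_area=4):
    """Trace tensor shapes through `mobilevit_block` with strides=1 (batch omitted)."""
    return {
        "input": (h, w, in_channels),
        "local_features": (h, w, projection_dim),   # 3x3 conv, then 1x1 conv
        "unfolded": (patch_area, (h * w) // patch_area, projection_dim),
        "transformer_out": (patch_area, (h * w) // patch_area, projection_dim),
        "folded": (h, w, projection_dim),
        "projected": (h, w, in_channels),           # 1x1 conv back to the input width
        "concatenated": (h, w, 2 * in_channels),    # [input, projected]
        "output": (h, w, projection_dim),           # 3x3 fusion conv
    }


# First MobileViT block in `create_mobilevit`: 32x32x48 input, projection_dim=64.
for name, shape in mobilevit_block_shapes(32, 32, 48, 64).items():
    print(f"{name:>15}: {shape}")
```

Because the block returns `projection_dim` channels, the MV2 block that follows it receives 64 channels, which is why the next stage uses `expanded_channels=64 * expansion_factor`.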


### Detailed Explanation of the MobileViT Block

The MobileViT block is designed to seamlessly integrate local and global feature representations using a combination of convolutional layers and Transformer blocks. Here’s a step-by-step breakdown of the process:

1. **Local Feature Extraction:**
   - **Input Features (A):** The input feature representations have a shape of (h, w, num_channels), where `h` and `w` are the height and width of the feature map, and `num_channels` is the number of channels.
   - **Convolutional Blocks:** These input features are first processed through convolutional blocks to capture local relationships.

2. **Patch Unfolding:**
   - The feature map is divided into non-overlapping patches.
   - **Unfolded Shape:** The shape of these patches becomes (p, n, num_channels), where `p` is the area of each patch and `n` is the number of patches calculated as (h * w) / p.

3. **Global Feature Extraction:**
   - **Transformer Block:** The unfolded patches are then passed through a Transformer block that captures global relationships among the patches.

4. **Patch Folding:**
   - **Output Features (B):** The output from the Transformer block is folded back to the original shape of (h, w, num_channels), resembling a feature map obtained from convolutions.

5. **Feature Fusion:**
   - **Final Convolutional Layers:** The initial input features (A) and the output features from the Transformer block (B) are passed through additional convolutional layers to fuse the local and global representations. The spatial resolution remains unchanged during this process.

The MobileViT block thus combines the strengths of convolutions (local feature extraction) and Transformers (global feature extraction) while keeping a convolution-like (h, w, num_channels) input and output, so it can be stacked with ordinary convolutional layers.
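
The notebook's `layers.Reshape((patch_size, num_patches, projection_dim))` is a simplification of steps 2 and 4: `Reshape` only reinterprets the row-major memory layout, so its p axis ends up indexing four horizontal bands of h/4 rows each rather than the pixel positions inside 2x2 patches. A faithful unfold/fold is sketched below in NumPy (hypothetical `unfold`/`fold` helpers, 2x2 patches assumed): pixel (i, j) of every patch is gathered into slot p = i * 2 + j, and `fold` inverts the operation exactly.

```python
import numpy as np


def unfold(feature_map, patch_h=2, patch_w=2):
    """(h, w, d) -> (p, n, d) with p = patch_h * patch_w and n = (h * w) / p."""
    h, w, d = feature_map.shape
    x = feature_map.reshape(h // patch_h, patch_h, w // patch_w, patch_w, d)
    # Axes are now (patch row, i, patch column, j, d); move (i, j) to the front.
    x = x.transpose(1, 3, 0, 2, 4)
    return x.reshape(patch_h * patch_w, (h // patch_h) * (w // patch_w), d)


def fold(patches, h, w, patch_h=2, patch_w=2):
    """Inverse of `unfold`: (p, n, d) -> (h, w, d)."""
    d = patches.shape[-1]
    x = patches.reshape(patch_h, patch_w, h // patch_h, w // patch_w, d)
    x = x.transpose(2, 0, 3, 1, 4)  # back to (patch row, i, patch column, j, d)
    return x.reshape(h, w, d)


fmap = np.arange(4 * 4 * 1).reshape(4, 4, 1)
patches = unfold(fmap)
print(patches[0, :, 0])  # top-left pixel of each 2x2 patch: [ 0  2  8 10]
assert np.array_equal(fold(patches, 4, 4), fmap)
```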

![](https://huggingface.co/datasets/hf-vision/course-assets/resolve/main/MobileViT-Architecture.png)

### Assembling the MobileViT Model

In [10]:
def create_mobilevit(num_classes=5):
  inputs = keras.Input((image_size, image_size, 3))
  x = layers.Rescaling(scale=1.0 / 255)(inputs)

  # Initial conv-stem -> MV2 block
  x = conv_block(x, filters=16)
  x = inverted_residual_block(
      x, expanded_channels=16 * expansion_factor, output_channels=16
  )

  # Downsampling with MV2 block
  x = inverted_residual_block(
      x, expanded_channels=16 * expansion_factor, output_channels=24, strides=2
  )
  x = inverted_residual_block(
      x, expanded_channels=24 * expansion_factor, output_channels=24
  )
  x = inverted_residual_block(
      x, expanded_channels=24 * expansion_factor, output_channels=24
  )

  # First MV2 -> MobileViT block
  x = inverted_residual_block(
      x, expanded_channels=24 * expansion_factor, output_channels=48, strides=2
  )
  x = mobilevit_block(x, num_blocks=2, projection_dim=64)

  # Second MV2 -> MobileViT block
  x = inverted_residual_block(
      x, expanded_channels=64 * expansion_factor, output_channels=64, strides=2
  )
  x = mobilevit_block(x, num_blocks=4, projection_dim=80)

  # Third MV2 -> MobileViT block
  x = inverted_residual_block(
      x, expanded_channels=80 * expansion_factor, output_channels=80, strides=2
  )
  x = mobilevit_block(x, num_blocks=3, projection_dim=96)
  x = conv_block(x, filters=320, kernel_size=1, strides=1)

  # Classification head
  x = layers.GlobalAvgPool2D()(x)
  outputs = layers.Dense(num_classes, activation='softmax')(x)

  return keras.Model(inputs, outputs)

mobilevit_xxs = create_mobilevit()
mobilevit_xxs.summary()
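
Every stage of `create_mobilevit` starts with exactly one stride-2 layer (the conv stem, then each downsampling MV2 block), so the spatial resolution halves five times before global average pooling. The sketch below (hypothetical helper name) traces those sizes for a 256x256 input, with the number of channels each stage outputs:

```python
def mobilevit_xxs_resolutions(image_size=256):
    """Spatial size at the end of each stage of `create_mobilevit`."""
    stages = [
        "conv stem (16 ch)",
        "MV2 stage (24 ch)",
        "MV2 + MobileViT (64 ch)",
        "MV2 + MobileViT (80 ch)",
        "MV2 + MobileViT (96 ch)",
    ]
    size = image_size
    trace = []
    for name in stages:
        size //= 2  # each stage contains exactly one stride-2 layer
        trace.append((name, size))
    return trace


for name, size in mobilevit_xxs_resolutions():
    print(f"{name:<24} {size}x{size}")
```

The final 8x8x96 map then goes through the 1x1 convolution to 320 channels and is pooled into a single 320-dimensional vector for the classifier.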

## Dataset preparation

In [11]:
batch_size = 64
auto = tf.data.AUTOTUNE
resize_bigger = 280
num_classes = 5

In [12]:
def preprocess_dataset(is_training=True):
  def _pp(image, label):
    if is_training:
      # Resize to a bigger spatial resolution and take the random crops.
      image = tf.image.resize(image, (resize_bigger, resize_bigger))
      image = tf.image.random_crop(image, (image_size, image_size, 3))
      image = tf.image.random_flip_left_right(image)
    else:
      image = tf.image.resize(image, (image_size, image_size))
    label = tf.one_hot(label, depth=num_classes)
    return image, label

  return _pp

In [13]:
def prepare_dataset(dataset, is_training=True):
  if is_training:
    dataset = dataset.shuffle(batch_size * 10)
  dataset = dataset.map(preprocess_dataset(is_training), num_parallel_calls=auto)
  return dataset.batch(batch_size).prefetch(auto)

## Load and prepare the `tf_flowers` dataset

In [14]:
train_dataset, val_dataset = tfds.load(
    'tf_flowers', split=['train[:90%]','train[90%:]'], as_supervised=True
)

num_train = train_dataset.cardinality()
num_val = val_dataset.cardinality()
print(f'Number of training examples: {num_train}')
print(f'Number of validation examples: {num_val}')

train_dataset = prepare_dataset(train_dataset, is_training=True)
val_dataset = prepare_dataset(val_dataset, is_training=False)

Downloading and preparing dataset 218.21 MiB (download: 218.21 MiB, generated: 221.83 MiB, total: 440.05 MiB) to /root/tensorflow_datasets/tf_flowers/3.0.1...
Dataset tf_flowers downloaded and prepared to /root/tensorflow_datasets/tf_flowers/3.0.1. Subsequent calls will reuse this data.
Number of training examples: 3303
Number of validation examples: 367
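
`Dataset.batch` keeps the final partial batch by default (`drop_remainder=False`), so the number of steps per epoch is the ceiling of examples / batch size. Using the counts printed above and `batch_size = 64`, this sketch (hypothetical helper name) reproduces the `52/52` progress counter seen in the training log:

```python
import math


def steps_per_epoch(num_examples, batch_size):
    """Number of batches when the last partial batch is kept (drop_remainder=False)."""
    return math.ceil(num_examples / batch_size)


for split, count in (("train", 3303), ("validation", 367)):
    steps = steps_per_epoch(count, 64)
    last = count - (steps - 1) * 64
    print(f"{split}: {steps} steps, last batch has {last} examples")
# train: 52 steps, last batch has 39 examples
# validation: 6 steps, last batch has 47 examples
```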


## Train a MobileViT (XXS) model
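
The loss below uses label smoothing with a factor of 0.1. Keras's `CategoricalCrossentropy(label_smoothing=...)` replaces a one-hot target y with y * (1 - eps) + eps / K for K classes. With eps = 0.1 and the 5 flower classes, the true class gets a target of 0.92 and every other class 0.02, as this small pure-Python sketch (hypothetical helper `smooth_labels`) shows:

```python
def smooth_labels(one_hot, eps=0.1):
    """Label smoothing applied to a one-hot target: y * (1 - eps) + eps / K."""
    k = len(one_hot)
    return [y * (1.0 - eps) + eps / k for y in one_hot]


target = smooth_labels([0, 0, 1, 0, 0], eps=0.1)
print([round(t, 4) for t in target])  # [0.02, 0.02, 0.92, 0.02, 0.02]
```

The smoothed target still sums to 1, but the model is no longer pushed towards infinitely confident logits.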

In [15]:
learning_rate = 0.002
label_smoothing_factor = 0.1
epochs = 30

optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
loss_fn = keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing_factor)

def run_experiment(epochs=epochs):
  mobilevit_xxs = create_mobilevit(num_classes=num_classes)
  mobilevit_xxs.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

  # When using `save_weights_only=True` in `ModelCheckpoint`,
  # the filepath provided must end in `.weights.h5`
  checkpoint_filepath = '/tmp/checkpoint.weights.h5'
  checkpoint_callback = keras.callbacks.ModelCheckpoint(
      checkpoint_filepath,
      monitor='val_accuracy',
      save_best_only=True,
      save_weights_only=True,
  )

  mobilevit_xxs.fit(
      train_dataset,
      validation_data=val_dataset,
      epochs=epochs,
      callbacks=[checkpoint_callback],
  )

  mobilevit_xxs.load_weights(checkpoint_filepath)
  _, accuracy = mobilevit_xxs.evaluate(val_dataset)
  print(f'Validation accuracy: {round(accuracy * 100, 2)}%')
  return mobilevit_xxs


mobilevit_xxs = run_experiment()


Epoch 1/30
52/52 ━━━━━━━━━━━━━━━━━━━━ 256s 2s/step - accuracy: 0.3580 - loss: 1.4885 - val_accuracy: 0.1907 - val_loss: 1.7295
Epoch 2/30
52/52 ━━━━━━━━━━━━━━━━━━━━ 13s 242ms/step - accuracy: 0.5868 - loss: 1.1484 - val_accuracy: 0.1907 - val_loss: 1.7349
Epoch 3/30
52/52 ━━━━━━━━━━━━━━━━━━━━ 13s 242ms/step - accuracy: 0.6670 - loss: 1.0502 - val_accuracy: 0.1907 - val_loss: 1.9980
Epoch 4/30
52/52 ━━━━━━━━━━━━━━━━━━━━ 13s 242ms/step - accuracy: 0.6921 - loss: 0.9936 - val_accuracy: 0.1907 - val_loss: 2.4279
Epoch 5/30
52/52 ━━━━━━━━━━━━━━━━━━━━ 13s 243ms/step - accuracy: 0.7042 - loss: 0.9781 - val_accuracy: 0.1907 - val_loss: 2.4653
Epoch 6/30
52/52 ━━━━━━━━━━━━━━━━━━━━ 13s 242ms/step - accuracy: 0.7261 - loss: 0.9466 - val_accuracy: 0.1907 - val_loss: 2.3149
Epoch 7/30
52/52

## Results and TFLite conversion

- With about one million parameters, reaching ~84% top-1 validation accuracy at 256x256 resolution is a strong result.

- This MobileViT model can be converted to TensorFlow Lite (TFLite) with the following code, which also allows the converter to fall back to select TensorFlow ops:

In [16]:
# Serialize the model as a SavedModel.
tf.saved_model.save(mobilevit_xxs, 'mobilevit_xxs')

In [1]:
# Convert to TFLite. This form of quantization is called
# post-training dynamic-range quantization in TFLite.
converter = tf.lite.TFLiteConverter.from_saved_model('mobilevit_xxs')
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # Enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS,  # Enable TensorFlow ops.
]

tflite_model = converter.convert()
with open('mobilevit_xxs.tflite', 'wb') as f:
  f.write(tflite_model)
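
Dynamic-range quantization converts only the weights to 8-bit integers at conversion time. The NumPy sketch below shows the core idea on a single convolution kernel using symmetric per-tensor int8 quantization. This is a simplification, not TFLite's exact scheme (TFLite can, for example, quantize convolution weights per channel), and the helper names are hypothetical:

```python
import numpy as np


def quantize_int8(weights):
    """Symmetric per-tensor int8 quantization: w ~= scale * q with q in [-127, 127]."""
    scale = np.abs(weights).max() / 127.0
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    """Map int8 values back to approximate float32 weights."""
    return q.astype(np.float32) * scale


rng = np.random.default_rng(0)
w = rng.normal(scale=0.05, size=(3, 3, 64, 64)).astype(np.float32)
q, scale = quantize_int8(w)
error = np.abs(dequantize(q, scale) - w).max()

print(w.nbytes, q.nbytes)         # 147456 36864 (4x smaller)
print(error <= scale / 2 + 1e-7)  # True: rounding error is at most half a step
```

Storing int8 instead of float32 shrinks weight storage roughly fourfold, at the cost of a bounded rounding error per weight.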