<a href="https://colab.research.google.com/github/LUUTHIENXUAN/Vision-Transformer-ViT-in-Tensorflow/blob/main/Image_classification_with_Vision_Transformer_ver1_0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Image classification with Vision Transformer

**Author:** [LUU THIEN XUAN](https://www.linkedin.com/in/thienxuanluu/)<br>

**Credit:** [Phil Wang](https://github.com/lucidrains)<br>
**Credit:** [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)<br>
**Date created:** 2021/01/18<br>
**Last modified:** 2021/01/18<br>
**Description:** Implementing the Vision Transformer (ViT) model for image classification.

## Introduction

This example implements the [Vision Transformer (ViT)](https://arxiv.org/abs/2010.11929)
model by Alexey Dosovitskiy et al. for image classification,
and demonstrates it on the CIFAR-100 dataset.
The ViT model applies the Transformer architecture with self-attention to sequences of
image patches, without using convolution layers.

This example requires TensorFlow 2.4 or higher, as well as
[TensorFlow Addons](https://www.tensorflow.org/addons/overview),
which can be installed using the following command:

```shell
pip install -U tensorflow-addons
```

## Setup

```python
!pip install -U tensorflow-addons
```



```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
```

## Prepare the data

```python
num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
```

```
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)
```


## Configure the hyperparameters

```python
AUTOTUNE = tf.data.AUTOTUNE

learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100

image_size = 72  # We'll resize input images to this size
patch_size = 6  # Size of the patches to be extracted from the input images
projection_dim = 64  # Not used below: vit_classifier passes dim=1024 to ViT
num_heads = 4
transformer_layers = 8
mlp_head_units = [2048, 1024]  # Not used below: ViT builds its own head from mlp_dim
```
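With these settings, each resized 72x72 image is cut into a grid of non-overlapping 6x6 patches. The short calculation below (plain Python, repeating the values above) shows the sequence length the Transformer processes; the extra position holds the learnable [CLS] token that the `ViT` layer prepends.

```python
image_size = 72
patch_size = 6
channels = 3

patches_per_side = image_size // patch_size          # 12
num_patches = patches_per_side ** 2                  # 144 patches per image
patch_dim = patch_size * patch_size * channels       # 108 values per flattened patch
sequence_length = num_patches + 1                    # +1 for the [CLS] token

print(f"Patches per image: {num_patches}")
print(f"Values per patch: {patch_dim}")
print(f"Sequence length (with [CLS]): {sequence_length}")
```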

## Use data augmentation

```python
data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.Normalization(),
        layers.experimental.preprocessing.Resizing(image_size, image_size),
        layers.experimental.preprocessing.RandomFlip("horizontal"),
        layers.experimental.preprocessing.RandomRotation(factor=0.02),
        layers.experimental.preprocessing.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
```
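`adapt` stores one mean and one variance per channel, and the `Normalization` layer then standardizes every input with them. The NumPy sketch below reproduces that computation on random stand-in images rather than CIFAR-100; it assumes the layer's default `axis=-1` (statistics per channel) and omits the small epsilon Keras uses to guard against division by zero.

```python
import numpy as np

rng = np.random.default_rng(0)
# Random uint8-range "images" standing in for x_train: (num_images, height, width, channels).
images = rng.integers(0, 256, size=(64, 32, 32, 3)).astype(np.float64)

# What adapt() computes: one mean and one variance per channel,
# reducing over every axis except the last one.
mean = images.mean(axis=(0, 1, 2))
var = images.var(axis=(0, 1, 2))

# What the layer then applies to each input.
normalized = (images - mean) / np.sqrt(var)

print("per-channel mean:", normalized.mean(axis=(0, 1, 2)).round(6))
print("per-channel std: ", normalized.std(axis=(0, 1, 2)).round(6))
```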

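```python
# Training pipeline: cache, shuffle with a 1,000-example buffer,
# batch with a fixed batch size, repeat indefinitely, and prefetch.
train_batches = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .cache()
    .shuffle(1000)
    .batch(batch_size, drop_remainder=True)
    .repeat()
    .prefetch(buffer_size=AUTOTUNE)
)

# Validation pipeline on the test split (no shuffling, no repetition).
val_batches = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .cache()
    .batch(batch_size, drop_remainder=True)
    .prefetch(buffer_size=AUTOTUNE)
)
```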
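Because both pipelines use `drop_remainder=True`, every batch holds exactly 256 examples and the leftover examples of each pass are discarded. The arithmetic below gives the number of full batches and the number of dropped examples; these step counts match the `steps_per_epoch` and `validation_steps` computed in `run_experiment`. Since the validation data is not shuffled, the same last 16 test images are never evaluated.

```python
num_train, num_test, batch_size = 50000, 10000, 256

train_steps = num_train // batch_size                 # full training batches per pass
train_dropped = num_train - train_steps * batch_size  # examples left out of each pass
val_steps = num_test // batch_size
val_dropped = num_test - val_steps * batch_size

print(train_steps, train_dropped)  # 195 80
print(val_steps, val_dropped)      # 39 16
```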

## Build the ViT model

```python
from tensorflow.keras import backend as K
```

```python
class PreNorm(layers.Layer):
  """Applies LayerNormalization to the input before the wrapped layer `fn`."""

  def __init__(self, fn):
    super(PreNorm, self).__init__()
    self.norm = layers.LayerNormalization(epsilon=1e-6)
    self.fn = fn

  def call(self, x, **kwargs):
    return self.fn(self.norm(x), **kwargs)
```
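`PreNorm` applies layer normalization before the wrapped layer, and `Transformer.call` adds the residual connection, so each sub-block computes `x + fn(LayerNorm(x))`. A NumPy sketch, assuming `LayerNormalization`'s learnable scale and offset are still at their initial values (1 and 0) and using a toy function in place of attention or the MLP:

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    """Normalize over the last (feature) axis; learnable scale/offset left at 1 and 0."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def pre_norm_residual(x, fn):
    """One pre-norm sub-block: x + fn(LayerNorm(x))."""
    return x + fn(layer_norm(x))

rng = np.random.default_rng(0)
tokens = rng.normal(loc=3.0, scale=5.0, size=(2, 145, 64))  # (batch, tokens, dim)

normed = layer_norm(tokens)
out = pre_norm_residual(tokens, fn=lambda h: 0.5 * h)  # toy stand-in for attention or the MLP
print(out.shape)
```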

```python
class FeedForward(layers.Layer):
  """Token-wise MLP: Dense(hidden_dim, GELU) -> Dropout -> Dense(dim) -> Dropout."""

  def __init__(self, dim, hidden_dim, dropout=0.1):
    super(FeedForward, self).__init__()
    self.net = keras.Sequential([
        layers.Dense(hidden_dim, activation=tf.nn.gelu),
        layers.Dropout(dropout),
        layers.Dense(dim),
        layers.Dropout(dropout),
    ])

  def call(self, x):
    return self.net(x)
```
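`tf.nn.gelu` uses the exact GELU by default, `x * Phi(x)` with `Phi` the standard normal CDF. The NumPy sketch below applies the same Dense, GELU, Dense structure as `FeedForward` with random weights and toy sizes (64 and 128 instead of 1024 and 2048); dropout is left out, as it would be at inference time.

```python
import math

import numpy as np

def gelu(x):
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    erf = np.vectorize(math.erf)
    return 0.5 * x * (1.0 + erf(x / math.sqrt(2.0)))

def feed_forward(x, w1, b1, w2, b2):
    """Dense(hidden_dim) + GELU, then Dense(dim); dropout omitted."""
    return gelu(x @ w1 + b1) @ w2 + b2

dim, hidden_dim = 64, 128  # toy sizes for illustration
rng = np.random.default_rng(0)
w1, b1 = rng.normal(size=(dim, hidden_dim)) * 0.02, np.zeros(hidden_dim)
w2, b2 = rng.normal(size=(hidden_dim, dim)) * 0.02, np.zeros(dim)

x = rng.normal(size=(2, 145, dim))  # (batch, tokens, dim)
y = feed_forward(x, w1, b1, w2, b2)
print(y.shape)  # (2, 145, 64)
print(gelu(np.array([-1.0, 0.0, 1.0])))
```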

```python
class Attention(layers.Layer):
  """Multi-head scaled dot-product self-attention."""

  def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
    super(Attention, self).__init__()
    self.heads = heads
    self.dim_head = dim_head
    self.inner_dim = self.dim_head * self.heads

    # Separate linear projections for queries, keys and values.
    self.to_q = layers.Dense(self.inner_dim)
    self.to_k = layers.Dense(self.inner_dim)
    self.to_v = layers.Dense(self.inner_dim)

    self.scale = 1 / K.sqrt(K.cast(dim_head, 'float32'))
    self.attend = layers.Activation('softmax')
    self.to_out = keras.Sequential([
        layers.Dense(dim),
        layers.Dropout(dropout)])

  def call(self, inputs):
    # Dynamic batch size, so the layer also works when the static
    # batch dimension is unknown (None).
    batch_size = tf.shape(inputs)[0]

    q = self.to_q(inputs)
    k = self.to_k(inputs)
    v = self.to_v(inputs)

    # (batch, tokens, inner_dim) -> (batch, heads, tokens, dim_head)
    q = tf.transpose(tf.reshape(q, (batch_size, -1, self.heads, self.dim_head)), (0, 2, 1, 3))
    k = tf.transpose(tf.reshape(k, (batch_size, -1, self.heads, self.dim_head)), (0, 2, 1, 3))
    v = tf.transpose(tf.reshape(v, (batch_size, -1, self.heads, self.dim_head)), (0, 2, 1, 3))

    # Attention weights: softmax(q k^T / sqrt(dim_head)), one matrix per head.
    dots = tf.matmul(q, k, transpose_b=True) * self.scale
    attn = self.attend(dots)

    # Weighted sum of values, then merge the heads back together.
    out = tf.matmul(attn, v)
    out = tf.transpose(out, (0, 2, 1, 3))
    out = tf.reshape(out, (batch_size, -1, self.inner_dim))

    return self.to_out(out)
```
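The layer projects the tokens to queries, keys and values, splits them into `heads` groups of `dim_head` features, and applies `softmax(q k^T / sqrt(dim_head)) v` per head. The NumPy sketch below mirrors `Attention.call` step by step, with biases and dropout omitted; it uses the `vit_classifier` head settings (4 heads of 16 features) but a toy token width of 64 instead of 1024.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention(x, wq, wk, wv, wo, heads, dim_head):
    """NumPy version of Attention.call (biases and dropout omitted)."""
    b, n, _ = x.shape
    q, k, v = x @ wq, x @ wk, x @ wv                          # (b, n, heads * dim_head)

    def split_heads(t):                                       # -> (b, heads, n, dim_head)
        return t.reshape(b, n, heads, dim_head).transpose(0, 2, 1, 3)

    q, k, v = split_heads(q), split_heads(k), split_heads(v)
    dots = q @ k.transpose(0, 1, 3, 2) / np.sqrt(dim_head)    # (b, heads, n, n)
    attn = softmax(dots, axis=-1)
    out = attn @ v                                            # (b, heads, n, dim_head)
    out = out.transpose(0, 2, 1, 3).reshape(b, n, heads * dim_head)
    return out @ wo, attn                                     # (b, n, dim)

dim, heads, dim_head = 64, 4, 16
inner = heads * dim_head
rng = np.random.default_rng(0)
wq, wk, wv = (rng.normal(size=(dim, inner)) * 0.1 for _ in range(3))
wo = rng.normal(size=(inner, dim)) * 0.1

x = rng.normal(size=(2, 145, dim))  # (batch, tokens, dim)
out, attn = multi_head_attention(x, wq, wk, wv, wo, heads, dim_head)
print(out.shape, attn.shape)  # (2, 145, 64) (2, 4, 145, 145)
```

Each row of `attn` is a probability distribution over the 145 tokens, which is why the cost of attention grows with the square of the sequence length.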

```python
class Transformer(layers.Layer):
  """`depth` pre-norm blocks, each an attention layer and an MLP with residual connections."""

  def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
    super(Transformer, self).__init__()
    self.layers = []
    for _ in range(depth):
      self.layers.append(
          [PreNorm(Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
           PreNorm(FeedForward(dim, mlp_dim, dropout=dropout))])

  def call(self, x):
    for attn, ff in self.layers:
      x = attn(x) + x  # residual connection around attention
      x = ff(x) + x    # residual connection around the MLP
    return x
```
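A rough parameter count helps when choosing `dim`, `heads`, `dim_head` and `mlp_dim`. The sketch below counts only the weights created by `Transformer` (not the patch projection, embeddings or classification head of `ViT`) and assumes the Keras defaults: `Dense` layers with biases and `LayerNormalization` with both scale and offset. With the `vit_classifier` settings, the MLPs dominate, because the attention layers project to an inner dimension of only 4 x 16 = 64.

```python
def dense_params(n_in, n_out):
    """Kernel plus bias of a Dense layer."""
    return n_in * n_out + n_out

def transformer_params(dim, depth, heads, dim_head, mlp_dim):
    inner = heads * dim_head
    layer_norm = 2 * dim  # scale (gamma) and offset (beta)
    attention = 3 * dense_params(dim, inner) + dense_params(inner, dim)  # q, k, v, out
    feed_forward = dense_params(dim, mlp_dim) + dense_params(mlp_dim, dim)
    per_block = layer_norm + attention + layer_norm + feed_forward
    return depth * per_block

# Settings used by vit_classifier: dim=1024, depth=8, heads=4, dim_head=16, mlp_dim=2048.
total = transformer_params(dim=1024, depth=8, heads=4, dim_head=16, mlp_dim=2048)
print(f"{total:,}")  # 35,718,656
```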

```python
def pair(t):
  return t if isinstance(t, tuple) else (t, t)


class ViT(layers.Layer):
  """Patch embedding, [CLS] token and position embedding, a Transformer
  encoder, and an MLP head that outputs class logits."""

  def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,
               pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
    super(ViT, self).__init__()
    image_height, image_width = pair(image_size)
    patch_height, patch_width = pair(patch_size)

    assert image_height % patch_height == 0 and image_width % patch_width == 0, \
        'Image dimensions must be divisible by the patch size.'
    # tf.nn.space_to_depth (used in call) only supports square patches.
    assert patch_height == patch_width, 'Patches must be square.'
    assert pool in {'cls', 'mean'}, \
        'pool type must be either cls (cls token) or mean (mean pooling)'

    self.num_patches = (image_height // patch_height) * (image_width // patch_width)
    self.patch_dim = channels * patch_height * patch_width

    self.patch_size = patch_height
    self.dim = dim
    # Linear projection of each flattened patch to `dim` features.
    self.dense = layers.Dense(self.dim)

    # Learnable position embedding (one per patch plus the [CLS] token) and [CLS] token.
    self.pos_embedding = self.add_weight(shape=[1, self.num_patches + 1, self.dim], dtype=tf.float32)
    self.cls_token = self.add_weight(shape=[1, 1, self.dim], dtype=tf.float32)
    self.dropout = layers.Dropout(emb_dropout)

    self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

    self.pool = pool
    self.to_latent1 = layers.Dropout(0.1)
    self.to_latent2 = layers.Dense(mlp_dim, activation=tfa.activations.gelu)

    self.mlp_head = keras.Sequential([
        layers.LayerNormalization(epsilon=1e-6),
        layers.Dense(num_classes)])  # no softmax: the model outputs logits

  def call(self, inputs):
    # (b, H, W, C) -> (b, H/p, W/p, p*p*C) -> (b, num_patches, patch_dim)
    x = tf.nn.space_to_depth(inputs, self.patch_size)
    x = tf.reshape(x, (-1, self.num_patches, self.patch_dim))
    x = self.dense(x)
    b = tf.shape(x)[0]

    # Prepend one copy of the [CLS] token to every example.
    cls_tokens = tf.repeat(self.cls_token, b, axis=0)
    x = tf.concat((cls_tokens, x), axis=1)

    # The (1, num_patches + 1, dim) position embedding broadcasts over the batch.
    x += self.pos_embedding
    x = self.dropout(x)

    x = self.transformer(x)

    # Pool: average over all tokens, or take the [CLS] token.
    x = tf.reduce_mean(x, axis=1) if self.pool == 'mean' else x[:, 0]

    x = self.to_latent1(self.to_latent2(x))
    return self.mlp_head(x)
```
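`tf.nn.space_to_depth` followed by the reshape turns each non-overlapping `p x p` patch into one flattened row, ordered row in patch, column in patch, then channel. The NumPy sketch below reproduces that patch extraction with a reshape and a transpose, then prepends a [CLS] token and adds a position embedding that broadcasts over the batch, as `ViT.call` does. The learned projection to `dim` is skipped for brevity, and the random arrays stand in for learned weights.

```python
import numpy as np

def extract_patches(images, p):
    """(b, H, W, C) -> (b, (H // p) * (W // p), p * p * C), like space_to_depth + reshape."""
    b, h, w, c = images.shape
    x = images.reshape(b, h // p, p, w // p, p, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (b, patch row, patch col, y in patch, x in patch, c)
    return x.reshape(b, (h // p) * (w // p), p * p * c)

rng = np.random.default_rng(0)
images = rng.random((2, 72, 72, 3))
patches = extract_patches(images, p=6)  # (2, 144, 108)

# Row 0 is the top-left 6x6 patch; row 1 is the patch to its right.
assert np.array_equal(patches[0, 0], images[0, :6, :6, :].reshape(-1))
assert np.array_equal(patches[0, 1], images[0, :6, 6:12, :].reshape(-1))

dim = patches.shape[-1]  # no Dense projection here, so tokens keep patch_dim features
cls_token = rng.normal(size=(1, 1, dim))
pos_embedding = rng.normal(size=(1, patches.shape[1] + 1, dim))

cls_tokens = np.repeat(cls_token, patches.shape[0], axis=0)             # one [CLS] per example
tokens = np.concatenate([cls_tokens, patches], axis=1) + pos_embedding  # broadcast over batch
print(patches.shape, tokens.shape)  # (2, 144, 108) (2, 145, 108)
```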

### Debug model

```python
# Sanity-check the output shape with a larger configuration
# (256x256 inputs, 32x32 patches, 1000 classes).
model = ViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=8,
    heads=16,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1
)
model.build(input_shape=(1, 256, 256, 3))

img = tf.random.uniform(shape=[1, 256, 256, 3])
preds = model(img)
print(preds.shape)  # (1, 1000)
```

```
(1, 1000)
```


## Compile, train, and evaluate the model

```python
def run_experiment(model):
  optimizer = tfa.optimizers.AdamW(
      learning_rate=learning_rate, weight_decay=weight_decay
  )

  model.compile(
      optimizer=optimizer,
      # The model outputs logits, hence from_logits=True.
      loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[
          keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
          keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
      ],
  )

  # Keeps the weights with the best validation accuracy
  # (enable it by passing it to `callbacks` below).
  checkpoint_filepath = "/tmp/checkpoint"
  checkpoint_callback = keras.callbacks.ModelCheckpoint(
      checkpoint_filepath,
      monitor="val_accuracy",
      save_best_only=True,
      save_weights_only=True,
  )

  # No `batch_size` argument: the datasets are already batched.
  history = model.fit(
      train_batches,
      steps_per_epoch=len(x_train) // batch_size,
      epochs=num_epochs,
      validation_data=val_batches,
      validation_steps=len(x_test) // batch_size,
      # callbacks=[checkpoint_callback],
  )

  return history
```
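Since the classification head outputs logits (no softmax), the loss is created with `from_logits=True`, which for each example computes `-log(softmax(logits)[label])`. The NumPy sketch below illustrates that loss and the top-5 accuracy metric on made-up logits; it is a simplified stand-in for the Keras implementations, not a copy of them.

```python
import numpy as np

def sparse_ce_from_logits(logits, labels):
    """Mean of -log(softmax(logits)[label]), using a stable log-sum-exp."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def top_k_accuracy(logits, labels, k=5):
    """Fraction of examples whose label is among the k largest logits."""
    top_k = np.argsort(logits, axis=1)[:, -k:]
    return np.mean([label in row for label, row in zip(labels, top_k)])

num_classes = 100
labels = np.array([3, 7])

uniform = np.zeros((2, num_classes))           # all logits equal, like an untrained model
print(sparse_ce_from_logits(uniform, labels))  # log(100), about 4.605

confident = np.zeros((2, num_classes))
confident[0, 3] = 10.0                         # example 0: the true class ranks first
confident[1, 50:56] = 5.0                      # example 1: six other classes outrank label 7
print(top_k_accuracy(confident, labels))       # 0.5
```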

```python
class vit_classifier(keras.Model):
  """Data augmentation followed by the ViT. The random augmentation layers
  are only active during training; normalization and resizing always apply."""

  def __init__(self):
    super(vit_classifier, self).__init__()
    self.aug = data_augmentation
    self.vit = ViT(image_size=image_size,
                   patch_size=patch_size,
                   num_classes=num_classes,
                   dim=1024,
                   depth=transformer_layers,
                   heads=num_heads,
                   dim_head=16,
                   mlp_dim=2048,
                   dropout=0.1,
                   emb_dropout=0.1)

  def call(self, inputs):
    x = self.aug(inputs)
    return self.vit(x)
```

```python
classifier = vit_classifier()
classifier.build(input_shape=(batch_size, 32, 32, 3))
history = run_experiment(classifier)
```

After 100 epochs, the ViT model achieves around 55% accuracy and
82% top-5 accuracy on the test data. These are not competitive results on the CIFAR-100 dataset,
as a ResNet50V2 trained from scratch on the same data can achieve 67% accuracy.

Note that the state-of-the-art results reported in the
[paper](https://arxiv.org/abs/2010.11929) are achieved by pre-training the ViT model on
the JFT-300M dataset, then fine-tuning it on the target dataset. To improve the model quality
without pre-training, you can try training the model for more epochs, using more
Transformer layers, resizing the input images, changing the patch size, or increasing the projection dimension.
In addition, as the paper notes, model quality depends not only on architecture choices
but also on training parameters such as the learning rate schedule, optimizer, and weight decay.
In practice, it is recommended to fine-tune a ViT model
that was pre-trained on a large, high-resolution dataset.
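One common schedule to try is a linear warmup followed by cosine decay. The plain-Python sketch below assumes 195 steps per epoch for 100 epochs (19,500 steps in total) and a hypothetical 10-epoch warmup; these step counts are illustrative choices, not values taken from the paper. A function like this could be wrapped in a `keras.optimizers.schedules.LearningRateSchedule` subclass and passed as the optimizer's `learning_rate`.

```python
import math

def warmup_cosine(step, base_lr=0.001, warmup_steps=1950, total_steps=19500):
    """Linear warmup from 0 to base_lr, then cosine decay to 0 (hypothetical step counts)."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = min((step - warmup_steps) / (total_steps - warmup_steps), 1.0)
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))

for step in (0, 975, 1950, 10725, 19500):
    print(step, round(warmup_cosine(step), 6))
```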