## cifar-vision-transformer

Implementation of a Vision Transformer from scratch in Keras, following the Keras code example "Image classification with Vision Transformer"

The original code can be found here:

https://keras.io/examples/vision/image_classification_with_vision_transformer/

## Download and Imports

In [None]:
pip install -U tensorflow-addons

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

## Transformer Constants

In [3]:
PROJECTION_DIM = 108
NUM_HEADS = 6
TRANSFORMER_UNITS = [
    PROJECTION_DIM * 2,
    PROJECTION_DIM,
]
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.0001

## Data Constants

In [4]:
NUM_CLASSES = 100
INPUT_SHAPE = (32, 32, 3)

## Download Cifar 100 Dataset

In [5]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)


## Patch Creation Layer

In [6]:
class Patches(tf.keras.layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

## Patch Encoding Layer

Embeds the individual Patches and adds positional encoding

In [7]:
class PatchEncoder(tf.keras.layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = tf.keras.layers.Dense(units=projection_dim)
        self.position_embedding = tf.keras.layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

## Multilayer Perceptron Layer

In [8]:
def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = tf.keras.layers.Dense(units, activation=tf.nn.gelu)(x)
        x = tf.keras.layers.Dropout(dropout_rate)(x)
    return x

## Building the Transformer

In [9]:
def create_vision_transformer():

  i = tf.keras.layers.Input(shape=INPUT_SHAPE)

  patches = Patches(6)(i)
  encoded = PatchEncoder(25, PROJECTION_DIM)(patches)

  for _ in range(2):

    norm_0 = tf.keras.layers.LayerNormalization()(encoded)
    attention = tf.keras.layers.MultiHeadAttention(num_heads=NUM_HEADS, key_dim = PROJECTION_DIM)(norm_0, norm_0)

    skip = tf.keras.layers.Add()([attention, encoded])
    norm_1 = tf.keras.layers.LayerNormalization()(skip)

    perceptron_layer = mlp(norm_1, hidden_units=TRANSFORMER_UNITS, dropout_rate=0.1)
    encoded = tf.keras.layers.Add()([norm_1, perceptron_layer])

  norm_2 = tf.keras.layers.LayerNormalization()(encoded)
  flat = tf.keras.layers.Flatten()(norm_2)
  drop = tf.keras.layers.Dropout(0.5)(flat)

  features = mlp(drop, hidden_units=TRANSFORMER_UNITS, dropout_rate=0.1)
  logits = tf.keras.layers.Dense(NUM_CLASSES)(features)

  return tf.keras.Model(inputs=i, outputs=logits)

In [10]:
transformer = create_vision_transformer()

In [11]:
transformer.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 32, 32, 3)]          0         []                            
                                                                                                  
 patches (Patches)           (None, None, 108)            0         ['input_1[0][0]']             
                                                                                                  
 patch_encoder (PatchEncode  (None, 25, 108)              14472     ['patches[0][0]']             
 r)                                                                                               
                                                                                                  
 layer_normalization (Layer  (None, 25, 108)              216       ['patch_encoder[0][0]']   

## Training the Transformer

In [13]:
def train(model, epochs=100):

  optimizer = tfa.optimizers.AdamW(learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
  model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            tf.keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )
  model.fit(x_train, y_train, epochs=epochs, validation_split=0.1)
  _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
  print("Model Accuracy: ", accuracy)
  print("Model top-5 Accuracy: ", top_5_accuracy)

In [14]:
train(transformer)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78