<a href="https://colab.research.google.com/github/AmandaJMendes/ViT_from_scatch/blob/main/vit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import libraries

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt

# Define hyperparameters

In [2]:
INPUT_SHAPE = (224, 224, 3) #(H, W, C)

B  = 64      # Batch size
P = 32       # Patch size  / Flattened patch length is P*P*C
D = 768      # Embedding dimension
K = 12       # Number of attention heads
D_mlp = 3072 # Hidden dimension of MLP in the transformer block
layers = 12  # Number of layers/blocks in the transformer encoder

# Tensorflow datasets

In [None]:
train = tfds.load('cifar10', split='train[:90%]', shuffle_files=True)
valid = tfds.load('cifar10', split='train[90%:]', shuffle_files=True)
test = tfds.load('cifar10', split='test', shuffle_files=False)

train = train.map(lambda sample: (sample['image']/255, tf.one_hot(sample['label'], depth=10)))
train = train.shuffle(1000).batch(B).prefetch(tf.data.AUTOTUNE)

valid = valid.map(lambda sample: (sample['image']/255, tf.one_hot(sample['label'], depth=10)))
valid = valid.batch(B).prefetch(tf.data.AUTOTUNE)

test = test.map(lambda sample: (sample['image']/255, tf.one_hot(sample['label'], depth=10)))
test = test.batch(B).prefetch(tf.data.AUTOTUNE)



Downloading and preparing dataset 162.17 MiB (download: 162.17 MiB, generated: 132.40 MiB, total: 294.58 MiB) to /root/tensorflow_datasets/cifar10/3.0.2...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/50000 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/cifar10/3.0.2.incompleteI3IPB1/cifar10-train.tfrecord*...:   0%|          …

Generating test examples...:   0%|          | 0/10000 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/cifar10/3.0.2.incompleteI3IPB1/cifar10-test.tfrecord*...:   0%|          |…

Dataset cifar10 downloaded and prepared to /root/tensorflow_datasets/cifar10/3.0.2. Subsequent calls will reuse this data.


In [None]:
for example in train.take(1):
  image, label = example

# Custom layers

In [None]:
class ExtractPatchesLayer(tf.keras.layers.Layer):
  def __init__(self, P = 32):
    super().__init__()
    self.P = P

  def call(self, inputs):
    patches = tf.image.extract_patches(images  = inputs,
                                       sizes   = [1, self.P, self.P, 1],
                                       strides = [1, self.P, self.P, 1],
                                       rates   = [1, 1, 1, 1],
                                       padding='VALID')

    #number of patches = N = v_patches*h_patches
    batch_size, v_patches, h_patches, size = patches.shape

    return tf.reshape(patches, [tf.shape(inputs)[0], v_patches*h_patches, size])

In [None]:
class PatchEmbeddingLayer(tf.keras.layers.Layer):
  def __init__(self, D=1024, patch_length=32*32):
    super().__init__()
    self.E = self.add_weight(shape=(patch_length, D),
                             initializer="random_normal",
                             trainable=True, name = 'E')

  def call(self, inputs):
    return tf.matmul(inputs, self.E)

In [None]:
class PrependCLSLayer(tf.keras.layers.Layer):
  def __init__(self, D=32):
    super().__init__()
    self.cls_token = self.add_weight(shape=(1, 1, D),
                                     initializer="random_normal",
                                     trainable=True, name = 'cls_E')

  def call(self, inputs):
    batch_size = tf.shape(inputs)[0]
    class_token = tf.repeat(self.cls_token, repeats = [batch_size], axis = 0)
    return tf.concat([class_token, inputs], axis = 1)

In [None]:
class AddPositionEmbeddingLayer(tf.keras.layers.Layer):
  def __init__(self, tokens, D=1024):
    super().__init__()
    self.E_pos = self.add_weight(shape=(1, tokens, D),
                                 initializer="random_normal",
                                 trainable=True, name = 'pos_E')

  def call(self, inputs):
    return tf.math.add(inputs, self.E_pos)

In [None]:
class EmbeddingLayer(tf.keras.layers.Layer):
  def __init__(self, patch_length, N, D=1024):
    super().__init__()
    self.patch_embedding   = PatchEmbeddingLayer(D, patch_length)
    self.prepend_cls_token = PrependCLSLayer(D)
    self.add_pos_embedding = AddPositionEmbeddingLayer(N+1, D)

  def call(self, inputs):
    patch_emb     = self.patch_embedding(inputs)          # (B, N, D)
    patch_cls_emb = self.prepend_cls_token(patch_emb)     # (B, N+1, D)
    final_emb     = self.add_pos_embedding(patch_cls_emb) # (B, N+1, D)

    return final_emb

In [None]:
class MSALayer(tf.keras.layers.Layer):
  def __init__(self, D = 1024, K = 16, N = 121):
    super().__init__()
    self.D   = D
    self.K   = K
    self.N   = N
    self.D_h = D//K

    self.Uqkv = self.add_weight(shape=(D, 3*D), initializer="random_normal",
                                trainable=True, name = "Uqkv")
    self.Umsa = self.add_weight(shape=(D, D),   initializer="random_normal",
                                trainable=True, name = "Umsa")

  def call(self, inputs):
                                                        #inputs => (B, N+1, D)
    batch_size = tf.shape(inputs)[0]

    qkv = tf.matmul(inputs, self.Uqkv)                  # (B, N+1, 3*D)
    q, k, v = tf.split(qkv, 3 , axis = -1)              # 3 x (B, N+1, D)

    q = tf.reshape(q, [-1, self.N+1, self.K, self.D_h]) # (B, N+1, K, D_h)
    q = tf.transpose(q, [0, 2, 1, 3])                   # (B, K, N+1, D_h)

    k = tf.reshape(k, [-1, self.N+1, self.K, self.D_h]) # (B, N+1, K, D_h)
    k = tf.transpose(k, [0, 2, 1, 3])                   # (B, K, N+1, D_h)

    v = tf.reshape(v, [-1, self.N+1, self.K, self.D_h]) # (B, N+1, K, D_h)
    v = tf.transpose(v, [0, 2, 1, 3])                   # (B, K, N+1, D_h)

    qk = tf.matmul(q, k, transpose_b = True)            # (B, K, N+1, N+1)
    A = tf.nn.softmax(qk/(self.D_h**0.5))               # (B, K, N+1, N+1)

    msa = tf.matmul(A, v)                               # (B, K, N+1, D_h)
    msa = tf.transpose(msa, [0, 2, 1, 3])               # (B, N+1, K, D_h)
    msa = tf.reshape(msa, [-1, self.N+1, self.D])       # (B, N+1, K*D_h) = (B, N+1, D)

    msa_out = tf.matmul(msa, self.Umsa)                 # (B, N+1, D)

    return msa_out

In [None]:
class MLP(tf.keras.layers.Layer):
    def __init__(self, D = 1024, D_mlp = 2048):
        super().__init__()
        self.layer1 = tf.keras.layers.Dense(D_mlp, activation = 'gelu')
        self.layer2 = tf.keras.layers.Dense(D,     activation = 'gelu')

    def call(self, inputs):
        return self.layer2(self.layer1(inputs))

In [None]:
class TransformerBlock(tf.keras.layers.Layer):
  def __init__(self, D = 1024, D_mlp = 2048, K = 16, N = 121):
    super().__init__()
    self.ln_1  = tf.keras.layers.LayerNormalization(axis = [1, 2])
    self.att   = MSALayer(D, K, N)
    self.ln_2  = tf.keras.layers.LayerNormalization(axis = [1, 2])
    self.mlp   = MLP(D, D_mlp)

  def call(self, inputs):
                                #inputs => (B, N+1, D)
    norm1   = self.ln_1(inputs) # (B, N+1, D)
    att_out = self.att(norm1)   # (B, N+1, D)
    resid1  = inputs + att_out  # (B, N+1, D)
    norm2   = self.ln_2(resid1) # (B, N+1, D)
    mlp_out = self.mlp(norm2)   # (B, N+1, D)
    resid2  = resid1 + mlp_out  # (B, N+1, D)
    return resid2

# ViT Model

In [None]:
class ViT(tf.keras.Model):
  def __init__(self, input_shape = (224, 224, 3), P = 32, D = 1024, K = 16,
               D_mlp = 2048, layers = 8, classes = 10):
    super().__init__()
    self.H, self.W, C = input_shape
    N = (self.H//(P))*(self.W//(P)) #Number of patches
    patch_length = C*P**2

    self.input_layer         = tf.keras.layers.InputLayer(input_shape, batch_size = 32)
    self.extract_patches     = ExtractPatchesLayer(P)
    self.embedding_layer     = EmbeddingLayer(patch_length, N, D)
    self.transformer_blocks  = [TransformerBlock(D, D_mlp, K, N) for i in range(layers)]
    self.classification_head = tf.keras.layers.Dense(classes, activation = 'softmax')

  def call(self, inputs):
    inputs  = self.input_layer(inputs)                         # (B, H, W, C)
    resized_inputs = tf.image.resize(inputs, (self.H, self.W)) # (B, self.H, self.W, C)
    patches = self.extract_patches(resized_inputs)             # (B, N)
    embeddings = self.embedding_layer(patches)                 # (B, N+1, D)
    z = embeddings
    for block in self.transformer_blocks:
     z = block(z)                                              # (B, N+1, D)
    y = self.classification_head(tf.gather(z, 0, axis = 1))    # (B, N_CLASSES)
    return y


# Train ViT Model

In [None]:
vit = ViT(INPUT_SHAPE, P, D, K, D_mlp, layers)

In [None]:
initial_learning_rate = 1e-3
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=len(train), decay_rate=0.9, staircase=True
)

In [None]:
early_callback = tf.keras.callbacks.EarlyStopping(monitor='val_categorical_accuracy',
                                                  patience=5, verbose = 1)
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(filepath="vit_cifar10",
                                                         save_weights_only=False,
                                                         monitor='val_categorical_accuracy',
                                                         mode='max',
                                                         save_best_only=True,
                                                         verbose = 1)


In [None]:
vit.compile(optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule),
            loss      = 'categorical_crossentropy',
            metrics   = [tf.keras.metrics.CategoricalAccuracy()])

In [None]:
vit.fit(train, steps_per_epoch=len(train), epochs = 50,
        validation_data = valid, validation_steps = len(valid),
        callbacks = [early_callback, checkpoint_callback])

Epoch 1/50
Epoch 1: val_categorical_accuracy improved from -inf to 0.23960, saving model to vit_cifar10
Epoch 2/50
Epoch 2: val_categorical_accuracy improved from 0.23960 to 0.31420, saving model to vit_cifar10
Epoch 3/50
Epoch 3: val_categorical_accuracy improved from 0.31420 to 0.35780, saving model to vit_cifar10
Epoch 4/50
Epoch 4: val_categorical_accuracy improved from 0.35780 to 0.39140, saving model to vit_cifar10
Epoch 5/50
Epoch 5: val_categorical_accuracy did not improve from 0.39140
Epoch 6/50
Epoch 6: val_categorical_accuracy improved from 0.39140 to 0.44760, saving model to vit_cifar10
Epoch 7/50
Epoch 7: val_categorical_accuracy improved from 0.44760 to 0.52060, saving model to vit_cifar10
Epoch 8/50
Epoch 8: val_categorical_accuracy did not improve from 0.52060
Epoch 9/50
Epoch 9: val_categorical_accuracy improved from 0.52060 to 0.53560, saving model to vit_cifar10
Epoch 10/50
Epoch 10: val_categorical_accuracy improved from 0.53560 to 0.55540, saving model to vit_cifar

<keras.src.callbacks.History at 0x799ed4b36920>

# Evaluate ViT Model


In [None]:
best_vit = tf.keras.models.load_model("vit_cifar10", custom_objects = {"ViT": ViT})

In [None]:
best_vit.evaluate(test)



[1.4199962615966797, 0.5981000065803528]