<a href="https://colab.research.google.com/github/Ifeeding99/tensorflow-scripts/blob/main/Vision_transformer_with_keras.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U tensorflow-addons

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorflow-addons
  Downloading tensorflow_addons-0.19.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m10.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tensorflow-addons
Successfully installed tensorflow-addons-0.19.0


In [None]:
!pip install --upgrade tensorflow

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorflow
  Downloading tensorflow-2.11.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (588.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m588.3/588.3 MB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Collecting tensorflow-estimator<2.12,>=2.11.0
  Downloading tensorflow_estimator-2.11.0-py2.py3-none-any.whl (439 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m439.2/439.2 KB[0m [31m40.9 MB/s[0m eta [36m0:00:00[0m
Collecting keras<2.12,>=2.11.0
  Downloading keras-2.11.0-py2.py3-none-any.whl (1.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m78.2 MB/s[0m eta [36m0:00:00[0m
Collecting flatbuffers>=2.0
  Downloading flatbuffers-23.1.21-py2.py3-none-any.whl (26 kB)
Collecting tensorboard<2.12,>=2.11
  Downloading tensorboard-2.11.2-py3-none-any.whl (6.0 MB)
[2K     [90m━━━━━━━━━━━━━

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt

In [None]:
# Keras returns ((train_images, train_labels), (test_images, test_labels)).
# The official CIFAR-100 test split doubles as the validation set in this notebook.
(X_train, y_train), (X_val, y_val) = (
    tf.keras.datasets.cifar100.load_data()
)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz


In [None]:
# --- Optimisation ---
learning_rate = 1e-3
lr_decay = 1e-4   # passed to AdamW as `weight_decay` (decoupled weight decay, not an LR schedule)
batch_size = 64
epochs = 20

# --- Image / patch geometry ---
img_size = 72     # the 32x32 CIFAR images are resized to img_size x img_size before patching
patch_size = 10
# Non-overlapping patches with 'VALID' padding: 72 // 10 = 7 per side -> 49 patches
# (the last 72 - 7 * 10 = 2 pixel rows/columns are dropped).
num_patches = (img_size // patch_size) ** 2

# --- Transformer ---
num_heads_ = 4
projection_dim = 64
transformers_layers = 8
# Hidden sizes of the MLP inside each transformer block; the last entry must equal
# projection_dim because the block's output is added back onto its input.
transformer_units = [
    projection_dim * 2,
    projection_dim,
]
mlp_head_units = [2048, 1024]   # hidden sizes of the classification head

# --- Reproducibility ---
# Seeds Python's `random`, NumPy and TensorFlow so augmentation, dropout and weight
# initialisation are repeatable across Restart & Run All.
seed = 42
tf.keras.utils.set_random_seed(seed)

In [None]:
def mlp(x, hidden_units, dropout_rate):
  """Applies a stack of Dense(GELU) -> Dropout layers to a Keras tensor.

  Args:
    x: Keras tensor to transform.
    hidden_units: iterable of layer widths; one Dense + Dropout pair is created per entry.
    dropout_rate: dropout rate used after every Dense layer.

  Returns:
    The output of the last Dropout layer, or `x` unchanged if `hidden_units` is empty.
  """
  for width in hidden_units:
    hidden = tf.keras.layers.Dense(units=width, activation=tf.nn.gelu)(x)
    x = tf.keras.layers.Dropout(rate=dropout_rate)(hidden)
  return x

In [None]:
# Preprocessing + augmentation pipeline placed at the start of the model.
# The Random* layers only alter inputs in training mode; at inference they pass data through.
data_augmentation = tf.keras.Sequential(
    [
        # Standardises pixels using the mean/variance learned by `adapt` below.
        tf.keras.layers.Normalization(),
        # Upsample the 32x32 CIFAR images so they can be cut into patch_size patches.
        tf.keras.layers.Resizing(img_size, img_size),
        # Horizontal flips only: the default mode ("horizontal_and_vertical") would also
        # create upside-down images, which are not realistic for natural-image data.
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.RandomRotation(factor=0.02),
        tf.keras.layers.RandomZoom(0.2),
    ],
    name='data_augmentation',
)
# Learn the Normalization layer's statistics from the training images.
data_augmentation.layers[0].adapt(X_train)

In [None]:
class Patches(tf.keras.layers.Layer):
  """Splits a batch of images into flattened, non-overlapping square patches.

  Input: images of shape (batch, height, width, channels).
  Output: tensor of shape (batch, num_patches, patch_size * patch_size * channels),
  patches in row-major order. Because padding is 'VALID', trailing rows/columns
  that do not fill a whole patch are dropped.
  """

  def __init__(self, patch_size, **kwargs):
    # Forward standard Layer kwargs (name, dtype, ...) so the layer can be named
    # and re-created from its config.
    super().__init__(**kwargs)
    self.patch_size = patch_size

  def call(self, images):
    n_images = tf.shape(images)[0]
    patches = tf.image.extract_patches(
        images=images,
        sizes=[1, self.patch_size, self.patch_size, 1],
        # stride == patch size -> the patches tile the image without overlap
        strides=[1, self.patch_size, self.patch_size, 1],
        rates=[1, 1, 1, 1],
        padding='VALID',
    )
    # Prefer the static depth so downstream Dense layers can build; fall back to the
    # dynamic value if the channel count is unknown at graph-construction time.
    patch_dims = patches.shape[-1]
    if patch_dims is None:
      patch_dims = tf.shape(patches)[-1]
    return tf.reshape(patches, [n_images, -1, patch_dims])

  def get_config(self):
    # Required so that models containing this layer can be saved and reloaded.
    config = super().get_config()
    config.update({'patch_size': self.patch_size})
    return config

In [None]:
class PatchEncoder(tf.keras.layers.Layer):
  """Linearly projects each patch and adds a learned positional embedding.

  Expects input of shape (batch, num_patches, patch_dims), e.g. the output of `Patches`,
  and returns a tensor of shape (batch, num_patches, projection_dim).
  """

  def __init__(self, num_patches, projection_dim, **kwargs):
    # Forward standard Layer kwargs (name, dtype, ...) to the base class.
    super().__init__(**kwargs)
    self.num_patches = num_patches
    # NOTE: despite its name this attribute holds the Dense projection layer, not an int.
    # The name is kept for backward compatibility; the width is `self.projection_dim.units`.
    self.projection_dim = tf.keras.layers.Dense(projection_dim)
    # One trainable vector per patch index 0 .. num_patches - 1.
    self.positional_embedding = tf.keras.layers.Embedding(
        input_dim=num_patches, output_dim=projection_dim)

  def call(self, patch):
    positions = tf.range(start=0, limit=self.num_patches, delta=1)
    # The (num_patches, projection_dim) embedding broadcasts over the batch axis, so the
    # input must contain exactly `num_patches` patches.
    return self.projection_dim(patch) + self.positional_embedding(positions)

  def get_config(self):
    # Required so that models containing this layer can be saved and reloaded.
    config = super().get_config()
    config.update({
        'num_patches': self.num_patches,
        'projection_dim': self.projection_dim.units,
    })
    return config

In [None]:
def create_ViT_classifier(input_shape=(32, 32, 3), num_classes=100):
  """Builds a Vision Transformer classifier as a functional Keras model.

  Uses the module-level hyper-parameters (patch_size, num_patches, projection_dim,
  num_heads_, transformers_layers, transformer_units, mlp_head_units) and the
  `data_augmentation` pipeline defined earlier in the notebook.

  Args:
    input_shape: shape of one input image, (height, width, channels).
      Defaults to CIFAR-100's (32, 32, 3).
    num_classes: number of output classes. Defaults to 100 (CIFAR-100).

  Returns:
    An uncompiled `tf.keras.Model` mapping images to unnormalised class logits
    (no softmax, so pair it with a `from_logits=True` loss).
  """
  inputs = tf.keras.Input(shape=input_shape)
  augmented = data_augmentation(inputs)
  patches = Patches(patch_size)(augmented)
  encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

  # Pre-norm transformer blocks: LN -> self-attention -> residual, LN -> MLP -> residual.
  for _ in range(transformers_layers):
    x1 = tf.keras.layers.LayerNormalization()(encoded_patches)
    attention_output = tf.keras.layers.MultiHeadAttention(
        num_heads=num_heads_, key_dim=projection_dim, dropout=0.1)(x1, x1)
    x2 = tf.keras.layers.Add()([attention_output, encoded_patches])
    x3 = tf.keras.layers.LayerNormalization()(x2)
    # transformer_units[-1] must equal projection_dim for the residual Add below.
    x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
    encoded_patches = tf.keras.layers.Add()([x3, x2])

  # Classification head on the flattened sequence of patch embeddings (no [CLS] token).
  representation = tf.keras.layers.LayerNormalization()(encoded_patches)
  representation = tf.keras.layers.Flatten()(representation)
  features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
  logits = tf.keras.layers.Dense(num_classes)(features)
  return tf.keras.Model(inputs=inputs, outputs=logits)

In [None]:
def run_experiment(model):
  """Compiles `model` with AdamW and trains it on the CIFAR-100 arrays.

  Trains on (X_train, y_train) and evaluates on (X_val, y_val) after every epoch,
  using the module-level `learning_rate`, `lr_decay`, `batch_size` and `epochs`.

  Args:
    model: a Keras model that outputs unnormalised logits
      (e.g. one built by create_ViT_classifier).

  Returns:
    The `tf.keras.callbacks.History` returned by `model.fit`, so the learning
    curves can be inspected or plotted after training.
  """
  optimizer = tfa.optimizers.AdamW(learning_rate=learning_rate, weight_decay=lr_decay)
  model.compile(
      optimizer=optimizer,
      # The model is expected to output raw logits, hence from_logits=True.
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[
          tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
          tf.keras.metrics.SparseTopKCategoricalAccuracy(5, name='top-5-accuracy'),
      ],
  )
  history = model.fit(
      X_train, y_train,
      validation_data=(X_val, y_val),
      batch_size=batch_size,
      # Honour the configured epoch count.
      epochs=epochs,
  )
  return history

In [None]:
# Build a fresh ViT and train it with the hyper-parameters from the configuration cell.
model = create_ViT_classifier()
run_experiment(model=model)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
