# An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

<img src='https://drive.google.com/uc?id=1WtqzFKgO6qSpBGNbbjGMYgWjtprREZfx'>

<img src='https://drive.google.com/uc?id=1wQxQ7BgaLvrRmlFIGU0ymGJ_AmGqYVFM'>

In [1]:
from tensorflow.keras import layers
import tensorflow as tf
import numpy as np

## Layers

### Image Preprocessing

Normalize and resize the images.

In [2]:
class Preprocessing(layers.Layer):
  """Standardize pixel values, then resize images to a square target size.

  The Normalization sub-layer has no statistics until it is adapted, e.g.
  ``prep.normalizing_layer.adapt(train_images)``, before the layer is used.

  Args:
    image_size: Side length (pixels) of the square output images.
  """

  def __init__(self,image_size):
    super().__init__()
    self.image_size = image_size
    # Created in this order so sub-layer auto-naming matches the original.
    self.normalizing_layer = layers.Normalization()
    self.resizing_layer = layers.Resizing(self.image_size,self.image_size)

  def call(self,x):
    # Normalize first (statistics were fitted at the source resolution),
    # then resample to image_size x image_size.
    standardized = self.normalizing_layer(x)
    return self.resizing_layer(standardized)

### Image Patching

Split each image into non-overlapping `patch_size` × `patch_size` patches and flatten each patch into a vector.

In [3]:
class Patching(layers.Layer):
  """Split a batch of images into flattened, non-overlapping square patches.

  Args:
    patch_size: Side length (pixels) of each square patch.

  The patch-grid size is read from the static shape of the extracted patches,
  so the input's height and width must be known when the graph is built.
  Output: (batch, grid_rows * grid_cols, flattened_patch_length).
  """

  def __init__(self,patch_size):
    super().__init__()
    self.patch_size = patch_size

  def call(self,x):
    window = [1, self.patch_size, self.patch_size, 1]
    # Stride == window size with VALID padding -> non-overlapping patches;
    # border pixels that do not fill a whole patch are dropped.
    grid = tf.image.extract_patches(
        images=x,
        sizes=window,
        strides=window,
        rates=[1, 1, 1, 1],
        padding="VALID",
    )
    # grid is (batch, rows, cols, patch_pixels); merge rows*cols into one axis.
    grid_rows, grid_cols = grid.shape[1], grid.shape[2]
    return tf.reshape(
        grid,
        shape=(tf.shape(x)[0], grid_rows * grid_cols, tf.shape(grid)[-1]),
    )

### Encoding Patches

Linearly project each flattened patch to `embedding_dim` and add a learned positional embedding for its patch index.

In [4]:
class PatchEmbedding(layers.Layer):
  """Linearly project each patch and add a learned position embedding.

  Args:
    embedding_dim: Width of each output token vector.
    num_patches: Number of patches per image; sets the size of the position
      table. The incoming sequence length is expected to equal this value.
  """

  def __init__(self,embedding_dim,num_patches):
    super().__init__()
    self.embedding_dim = embedding_dim
    self.num_patches = num_patches
    self.projection_layer = layers.Dense(embedding_dim)
    # One trainable vector per patch index 0 .. num_patches-1.
    self.embedding_layer = layers.Embedding(
        input_dim=self.num_patches,
        output_dim=self.embedding_dim,
    )

  def call(self,x):
    projected = self.projection_layer(x)
    # Position ids are identical for every image: look up a 1-D range once
    # and let the addition broadcast it over the batch axis.
    position_vectors = self.embedding_layer(tf.range(self.num_patches))
    return projected + position_vectors

### Transformer Encoder

Create the transformer encoder block, following the architecture shown below:



<img src='https://drive.google.com/uc?id=1tWq5cjpQUPGHt_MLQRoysmlaQzoEKJpz'>

In [5]:
class TransformerEncoder(layers.Layer):
  """One pre-norm ViT encoder block.

    x -> LayerNorm -> MultiHeadAttention -> (+ x)
      -> LayerNorm -> MLP -> (+ residual)

  MLP: Dense(hidden_neurons, GELU) -> Dropout -> Dense(embedding_dim) -> Dropout.

  Args:
    num_heads: Number of attention heads.
    key_dim: Per-head size of the query/key projections.
    value_dim: Per-head size of the value projection.
    hidden_neurons: Width of the MLP hidden layer.
    embedding_dim: Token width. The MLP projects back to this size so the
      residual add lines up; it must match the input's last dimension.
    drop_rate: Dropout rate for the attention weights and after each MLP
      dense layer.
    eps: Epsilon for both LayerNormalization layers.
  """

  def __init__(self,num_heads,key_dim,value_dim,hidden_neurons,embedding_dim,drop_rate=0.0,eps=1e-08):
    super().__init__()
    self.num_heads = num_heads
    self.key_dim = key_dim
    self.value_dim = value_dim
    self.hidden_neurons = hidden_neurons
    self.eps = eps
    self.drop_rate = drop_rate
    self.embedding_dim = embedding_dim
    self.norm1 = layers.LayerNormalization(epsilon=self.eps)
    self.norm2 = layers.LayerNormalization(epsilon=self.eps)
    self.mha = layers.MultiHeadAttention(self.num_heads,self.key_dim,self.value_dim,dropout=self.drop_rate)
    # The output projection is linear: a GELU there would bound the residual
    # update below by about -0.17, so the MLP could barely subtract from x.
    # Dropout follows each dense layer, as in the ViT reference setup.
    self.mlp = tf.keras.models.Sequential([
        layers.Dense(self.hidden_neurons, activation=tf.nn.gelu),
        layers.Dropout(self.drop_rate),
        layers.Dense(self.embedding_dim),
        layers.Dropout(self.drop_rate),
    ])

  def call(self,x,training=None):
    # Residual sums use plain `+`. The original built new layers.Add() objects
    # on every call; sub-layers belong in __init__.
    normed = self.norm1(x)
    x = x + self.mha(normed, normed, training=training)
    x = x + self.mlp(self.norm2(x), training=training)
    return x

## Main

### Load Data

Load the CIFAR-100 dataset (32×32 RGB images, 100 classes).

In [6]:
train_split, test_split = tf.keras.datasets.cifar100.load_data()
x_train, y_train = train_split
x_test, y_test = test_split
# Shape of a single image (leading sample axis dropped); sizes the Input layer.
input_shape = x_train.shape[1:]

### Build model

Create the ViT. Before building the model graph, fit the normalization layer's statistics to the `x_train` features by calling that layer's `adapt` method.

In [7]:
seed = 42
# Fixes weight initialisation and dropout masks so runs are reproducible.
tf.keras.utils.set_random_seed(seed)

image_size = 64
patch_size = 8
# Patching uses stride == patch_size with VALID padding, so this count is exact
# only when image_size is a multiple of patch_size.
assert image_size % patch_size == 0, "image_size must be divisible by patch_size"
num_patches = (image_size//patch_size)**2
embedding_dim = 100
hidden_neurons = 200
# NOTE(review): key_dim is the per-head query/key size. 1000 (x 8 heads) is far
# larger than embedding_dim and dominates each encoder's parameter count;
# embedding_dim // num_heads is the usual choice.
key_dim = 1000
value_dim = 100
num_heads = 8
drop_rate = 0.2
L = 10  # Num of stacked transformer encoders
num_classes = 100

input_layer = layers.Input(shape=input_shape)
prep_layer = Preprocessing(image_size)
# Fit the normalization statistics on the training images before wiring the graph.
prep_layer.normalizing_layer.adapt(x_train)
prep_images = prep_layer(input_layer)
patches = Patching(patch_size)(prep_images)
embeddings = PatchEmbedding(embedding_dim,num_patches)(patches)
for _ in range(L):
  embeddings = TransformerEncoder(num_heads,key_dim,value_dim,hidden_neurons,embedding_dim,drop_rate)(embeddings)
# The encoders are pre-norm, so their output is un-normalized; ViT applies a
# final LayerNorm before the head. The epsilon matches TransformerEncoder's default.
representation = layers.LayerNormalization(epsilon=1e-08)(embeddings)
flattened_embeddings = layers.Flatten()(representation)
# The hidden layer needs a non-linearity: two stacked linear Dense layers
# collapse into a single linear map.
hidden_output = layers.Dense(hidden_neurons, activation=tf.nn.gelu)(flattened_embeddings)
hidden_output = layers.Dropout(drop_rate)(hidden_output)
# Raw logits (no softmax); the loss must be configured with from_logits=True.
outputs = layers.Dense(num_classes)(hidden_output)
model = tf.keras.Model(inputs=input_layer,outputs=outputs)

In [8]:
# Per-layer output shapes and parameter counts.
model.summary(expand_nested=False)

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 32, 32, 3)]       0         
                                                                 
 preprocessing (Preprocessi  (None, 64, 64, 3)         7         
 ng)                                                             
                                                                 
 patching (Patching)         (None, 64, 192)           0         
                                                                 
 patch_embedding (PatchEmbe  (None, 64, 100)           25700     
 dding)                                                          
                                                                 
 transformer_encoder (Trans  (None, 64, 100)           1817600   
 formerEncoder)                                                  
                                                             

Compile the model.

In [9]:
# The model emits raw logits (the final Dense has no activation), and
# cifar100.load_data() returns integer class ids. Hence the sparse loss with
# from_logits=True. Accuracy is tracked because loss alone is hard to read for
# a 100-way classifier.
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])

### Train model

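Train the ViT on CIFAR-100. The test split is also passed as validation data, so validation metrics should not be treated as an independent hold-out estimate.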

In [10]:
# NOTE(review): the test split doubles as validation data. Validation metrics
# are therefore not an independent hold-out estimate if used for model selection.
model.fit(
    x=x_train,
    y=y_train,
    epochs=10,
    validation_data=(x_test, y_test),
)

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x7cc2852100a0>