In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [6]:
pool = tf.keras.layers.GlobalAveragePooling1D()


x= tf.random.normal((32,197,768))

out = pool(x)
out.shape

TensorShape([32, 768])

# Patch Embeddings

In [7]:
class PositionalEmbedding(tf.keras.layers.Layer):
    def __init__(self,batch_size,num_patches,emb_size):
        
        positions = np.arange(num_patches)[:,np.newaxis]
        depth = np.arange(emb_size)[np.newaxis, :]
        depth = (2*depth//2)/emb_size

        angle_rates = 1 / (10000**depth)

        angle_rads  = positions * angle_rates
       
        angle_rads[:,0::2] = np.sin(angle_rads[:,0::2])
        angle_rads[:,1::2] = np.cos(angle_rads[:,1::2])


        positions = positions * angle_rads
     
        # positions = positions.T
        self.pos = tf.constant(np.broadcast_to(positions,[batch_size,num_patches,emb_size]))
       
        
   
    def __call__(self):
        return self.pos

In [8]:
class PatchEmbedding(tf.keras.layers.Layer):

  def __init__(self,
               img_size = 224,
               batch_size=32,
               patch_size=16,
               emb_size=768,
               ):
    

    super().__init__()

    self.img_size = img_size
    self.batch_size = batch_size
    self.patch_size = patch_size
    self.emb_size = emb_size

    self.num_patches = (img_size * img_size) // patch_size**2

    self.cnn_layer = tf.keras.layers.Conv2D(filters=768,kernel_size=patch_size,strides=patch_size)

    self.pos = PositionalEmbedding(batch_size,self.num_patches+1,emb_size)


  def __call__(self,images):
    
    # patch embedding
    patches = self.cnn_layer(images)
    patches = tf.reshape(patches,(self.batch_size,-1,self.emb_size))

    # class learnable embedding
    class_token = tf.ones((self.batch_size,self.emb_size))
    class_token = tf.reshape(class_token,(self.batch_size,1,self.emb_size))

    # concat class token with patch embedding
    embedding = tf.concat([patches,class_token],axis=1)
  
    # Positional Embedding
    pos_emb = self.pos()
  
    
    # Add positional emb with embeddings

    embedding = tf.keras.layers.Add()([embedding,pos_emb])

    return embedding

# MultiHeaded Attention

In [9]:
class MultiHeadAttention(tf.keras.layers.Layer):
    
    def __init__(self,
                 emb_size=768,
                 batch_size=32,
                 heads=12):
      
        super(MultiHeadAttention,self).__init__()
        
        
      
        self.emb_size= emb_size
        self.heads = heads
        self.head_dim = emb_size//heads
        self.batch_size = batch_size
        
        # Queries, Keys and Values Matrices Layers
        self.queries = tf.keras.layers.Dense(self.emb_size)
        self.keys = tf.keras.layers.Dense(self.emb_size)
        self.values = tf.keras.layers.Dense(self.emb_size)

       
    
    def self_attention(self,queries,keys,values,masked=False):
   
        out = tf.matmul(queries,tf.transpose(keys,perm=[0, 2, 1]))
        out = out/np.sqrt(self.head_dim)
        
        out = tf.math.softmax(out)
        out = tf.matmul(out,values)
        
        return out 
    
     
    def __call__(self,x):
        
        
        # As mention in the paper first we multiply each word embedding in our case 512 with 512x512 Matrcis
        # We pass our data through the dense layer
        
        # For Multiheaded Attention
        queries = self.queries(x)
        keys = self.keys(x)
        values = self.values(x)
        
        # Self Attention
        attention = self.self_attention(queries,keys,values)
        
            
        # Last matrix 
        
        out = tf.keras.layers.Dense(self.emb_size)(attention)
                   
        return out

# MLP Block

In [10]:
class MLPBlock(tf.keras.layers.Layer):

  def __init__(self,emb_size=768,mlp_block=3072):
    self.mlp_block = mlp_block
    self.emb_size = emb_size

    self.layer_norm = tf.keras.layers.LayerNormalization()
    self.mlp = tf.keras.Sequential([
        tf.keras.layers.Dense(mlp_block),
        tf.keras.layers.Activation('gelu'),
        tf.keras.layers.Dense(self.emb_size)
    ])

  def __call__(self,x):

    x = self.layer_norm(x)
    x = self.mlp(x)

    return x

# Transformer Block

In [11]:
class Transformer(tf.keras.layers.Layer):

  def __init__(self,
               emb_size=768,
               mlp_block=3072,
               heads=12,
               batch_size=32):
    
    self.emb_size=emb_size
    self.mlp_block = mlp_block
    self.heads=heads
    self.batch_size = batch_size

    self.ln1 = tf.keras.layers.LayerNormalization()
    self.ln2 = tf.keras.layers.LayerNormalization()

    self.mha = MultiHeadAttention(self.emb_size,self.batch_size,self.heads)
    self.mlp = MLPBlock(self.emb_size,self.mlp_block)

  def __call__(self,x):

    norm = self.ln1(x)

    msa = self.mha(norm)

    # skip connection

    x = tf.keras.layers.Add()([x,msa])

    norm = self.ln2(x)

    mlp_layer = self.mlp(norm)

    # skip connection

    x = tf.keras.layers.Add()([mlp_layer,x])


    return x



# Vision Transformer

In [12]:
class ViT(tf.keras.layers.Layer):

  
  def __init__(self,
               img_size = 224,
               batch_size=32,
               patch_size=16,
               emb_size=768,
               heads=12,
               mlp_block=3072,
               encoder_layers = 12,
               num_classes=10):
    
    # Patch Embedding
    self.patch_emb = PatchEmbedding(img_size,batch_size,patch_size,emb_size)
  
    # Transformer Block
    self.encoder = [Transformer(emb_size,mlp_block,heads,batch_size) for _ in range(encoder_layers)]

    # Global Pool
    self.pool = tf.keras.layers.GlobalAveragePooling1D()

    # Classifier
    self.classifier = tf.keras.layers.Dense(num_classes,activation='softmax')
  
  
  def __call__(self,images):

    x = self.patch_emb(images)

    for enc_layer in self.encoder:

      x = enc_layer(x)

    x = self.pool(x)
    x = self.classifier(x)

    return x 



In [13]:
inp = tf.keras.layers.Input(shape=(224,224,3))
vit = ViT()

out = vit(inp)

model = tf.keras.Model(inputs=inp,outputs=out)

In [14]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 224, 224, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 14, 14, 768)  590592      ['input_1[0][0]']                
                                                                                                  
 tf.reshape (TFOpLambda)        (32, None, 768)      0           ['conv2d[0][0]']                 
                                                                                                  
 tf.concat (TFOpLambda)         (32, None, 768)      0           ['tf.reshape[0][0]']         