In [1]:
import tensorflow as tf

In [2]:
input_shape = (256, 256, 3)
patch_size = 16
num_patches = input_shape[0] * input_shape[1] // patch_size ** 2
num_heads = 8
projection_dim = 512
drop_rate = 0.3
num_classes = 100

In [3]:
class PatchEncoder(tf.keras.layers.Layer):
    def __init__(self, num_patchs, patch_size, projection_dim, drop_rate):
        super().__init__()
        self.num_patches = num_patches
        self.patch_size = patch_size
        self.projection_dim = projection_dim
        self.pos = self.add_weight(shape=[num_patches, projection_dim], initializer='normal', dtype=tf.float32, trainable=True)
        self.conv = tf.keras.layers.Conv2D(filters=patch_size**2 * 3, kernel_size=patch_size, strides=patch_size)
        self.linear_proj = tf.keras.layers.Dense(units=projection_dim)
        self.dropout = tf.keras.layers.Dropout(drop_rate)

    def call(self, x):
        batch_size = tf.shape(x)[0]
        x = self.conv(x)
        x = tf.reshape(x, shape=[batch_size, -1, self.patch_size**2 * 3])
        x = self.linear_proj(x)
        x = x + self.pos
        x = self.dropout(x)
        return x

In [4]:
class MultiHeadSelfAttention(tf.keras.Model):
    def __init__(self, num_heads, projection_dim, drop_rate):
        super().__init__()
        self.num_heads = num_heads
        self.projection_dim = projection_dim
        self.w = tf.keras.layers.Dense(units=projection_dim)
        self.softmax = tf.keras.layers.Softmax()
        self.dropout = tf.keras.layers.Dropout(drop_rate)
    
    def call(self, x):
        batch_size = tf.shape(x)[0]
        
        query = self.w(x)
        query = tf.reshape(query, shape=[batch_size, -1, self.num_heads, self.projection_dim // self.num_heads])
        query = tf.transpose(query, perm=[0, 2, 1, 3])
        
        key = self.w(x)
        key = tf.reshape(query, shape=[batch_size, -1, self.num_heads, self.projection_dim // self.num_heads])
        key = tf.transpose(key, perm=[0, 2, 1, 3])
        
        value = self.w(x)
        value = tf.reshape(query, shape=[batch_size, -1, self.num_heads, self.projection_dim // self.num_heads])
        value = tf.transpose(value, perm=[0, 2, 1, 3])
        
        attention = tf.matmul(query, key, transpose_b=True) / tf.sqrt(tf.cast(tf.shape(key)[-1], dtype=tf.float32))
        
        x = self.softmax(attention)
        x = self.dropout(x)
        x = tf.matmul(x, value)
        x = tf.transpose(x, perm=[0, 2, 1, 3])
        x = tf.reshape(x, shape=[batch_size, -1, self.projection_dim])
        return x

In [5]:
def model():
    inputs = tf.keras.layers.Input(shape=input_shape)
    x = PatchEncoder(num_patches, patch_size, projection_dim, drop_rate)(inputs)
    
    for _ in range(8):
        x1 = tf.keras.layers.LayerNormalization()(x)
        x1 = MultiHeadSelfAttention(num_heads, projection_dim, drop_rate)(x1)
        x2 = tf.keras.layers.Add()([x1, x])
        x3 = tf.keras.layers.LayerNormalization()(x2)
        x3 = tf.keras.layers.Dense(projection_dim * 2, activation=tf.nn.gelu)(x3)
        x3 = tf.keras.layers.Dropout(drop_rate)(x3)
        x3 = tf.keras.layers.Dense(projection_dim, activation=tf.nn.gelu)(x3)
        x3 = tf.keras.layers.Dropout(drop_rate)(x3)
        x = tf.keras.layers.Add()([x3, x2])
        
    x = tf.keras.layers.LayerNormalization()(x)
    x = tf.keras.layers.GlobalAveragePooling1D()(x)
    x = tf.keras.layers.Dropout(drop_rate)(x)
    x = tf.keras.layers.Dense(1024, activation=tf.nn.gelu)(x)
    x = tf.keras.layers.Dropout(drop_rate)(x)
    x = tf.keras.layers.Dense(num_classes)(x)
    return tf.keras.Model(inputs=inputs, outputs=x)

model = model()
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 patch_encoder (PatchEncoder)   (None, 256, 512)     1115392     ['input_1[0][0]']                
                                                                                                  
 layer_normalization (LayerNorm  (None, 256, 512)    1024        ['patch_encoder[0][0]']          
 alization)                                                                                       
                                                                                              

 multi_head_self_attention_3 (M  (None, None, 512)   262656      ['layer_normalization_6[0][0]']  
 ultiHeadSelfAttention)                                                                           
                                                                                                  
 add_6 (Add)                    (None, 256, 512)     0           ['multi_head_self_attention_3[0][
                                                                 0]',                             
                                                                  'add_5[0][0]']                  
                                                                                                  
 layer_normalization_7 (LayerNo  (None, 256, 512)    1024        ['add_6[0][0]']                  
 rmalization)                                                                                     
                                                                                                  
 dense_11 

                                                                                                  
 dropout_20 (Dropout)           (None, 256, 1024)    0           ['dense_20[0][0]']               
                                                                                                  
 dense_21 (Dense)               (None, 256, 512)     524800      ['dropout_20[0][0]']             
                                                                                                  
 dropout_21 (Dropout)           (None, 256, 512)     0           ['dense_21[0][0]']               
                                                                                                  
 add_13 (Add)                   (None, 256, 512)     0           ['dropout_21[0][0]',             
                                                                  'add_12[0][0]']                 
                                                                                                  
 layer_nor