In [None]:
### This notebook checks how many learnable parameters a transformer block has
### Notes:
### 1. Attention learnable parameters calculations
### key_dim is the dimension that each query, key & value is projected to, per head.
### So if each model input vector has 64 elements and the layer parameter key_dim = 64 ==>
### we need to project each sequence vector into 3 vectors of 64 elements (query, key, value) for each head
### ==> (64*64 weights + 64 biases per projection) * 3 (Q, K, V projections) * 4 (number of heads)
### = 4,160 * 12 = 49,920 learnable parameters
### This outputs 4 mini-attention vectors (one per attention head) of 64 elements each.
### The attention vectors from the heads (64 elements each, matching the projection size of the value vectors)
### are concatenated into one attention vector of 256 elements.
### This vector is projected back to 64 elements, adding (256 * 64 + 64) = 16,448 learnable parameters.
### This gives 49,920 + 16,448 = 66,368 learnable parameters for the multi-head attention layer.
###
### 2. The length of the input or output sequence (here 145) does not affect the number of learnable parameters

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, models, Sequential, utils
from tensorflow.keras import activations,applications, optimizers
# --- Configuration -----------------------------------------------------------
# seq_len does not change the parameter count; embed_dim, num_heads and
# key_dim do.
seq_len = 145
embed_dim = 64      # size of each input token vector (and of the block's output)
num_heads = 4
key_dim = 64        # per-head projection size for query / key / value
dropout_rate = 0.1

# --- Layers ------------------------------------------------------------------
layerNorm1 = layers.LayerNormalization(epsilon=1e-6)
layerNorm2 = layers.LayerNormalization(epsilon=1e-6)
multiHeadAttenLayer = layers.MultiHeadAttention(
    num_heads=num_heads, key_dim=key_dim, dropout=dropout_rate)
# MLP: expand to 2x the embedding size, then project back to embed_dim.
# dense2 must output embed_dim (not key_dim) so the residual Add() below
# matches the input's feature size for any choice of key_dim.
dense1 = layers.Dense(units=embed_dim * 2, activation=tf.nn.gelu)
dense2 = layers.Dense(units=embed_dim, activation=tf.nn.gelu)
dropout1 = layers.Dropout(dropout_rate)
dropout2 = layers.Dropout(dropout_rate)

# --- Wiring: pre-norm transformer block --------------------------------------
# NOTE(review): `input` shadows the Python builtin; name kept in case later
# cells refer to it.
input = tf.keras.Input(shape=(seq_len, embed_dim))
x = input
# Layer normalization 1
x1 = layerNorm1(x)
# Multi-head self-attention: the same tensor is used as query and value
x1 = multiHeadAttenLayer(x1, x1)
# Skip connection 1
x = layers.Add()([x, x1])

# MLP sub-block
x1 = layerNorm2(x)
x1 = dense1(x1)
x1 = dropout1(x1)
x1 = dense2(x1)
x1 = dropout2(x1)
# Skip connection 2
output = layers.Add()([x, x1])
# A fixed name keeps the summary stable when the cell is re-run.
model = tf.keras.Model(input, output, name="transformer_block")

# Verify the hand calculation from the notes at the top of this cell:
#   Q, K, V projections: 3 * num_heads * (embed_dim * key_dim + key_dim)
#   output projection:   num_heads * key_dim * embed_dim + embed_dim
# (value_dim defaults to key_dim and the output size defaults to embed_dim.)
expected_mha_params = (3 * num_heads * (embed_dim * key_dim + key_dim)
                       + num_heads * key_dim * embed_dim + embed_dim)
assert multiHeadAttenLayer.count_params() == expected_mha_params, (
    multiHeadAttenLayer.count_params(), expected_mha_params)

model.summary(expand_nested=True)




Model: "model_5"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_10 (InputLayer)          [(None, 145, 64)]    0           []                               
                                                                                                  
 layer_normalization_8 (LayerNo  (None, 145, 64)     128         ['input_10[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 multi_head_attention_6 (MultiH  (None, 145, 64)     66368       ['layer_normalization_8[0][0]',  
 eadAttention)                                                    'layer_normalization_8[0][0]']  
                                                                                            