In [219]:
import tensorflow as tf
from tensorflow import convert_to_tensor, string
from tensorflow.keras.layers import TextVectorization, Embedding, Layer,LayerNormalization,Dense,ReLU,Dropout
import numpy as np
from tensorflow import matmul,cast,float32,math,reshape,shape,transpose
from tensorflow.keras.backend import softmax

In [300]:
class PositionEmbeddingFixedWeights(Layer):
    def __init__(self, sequence_length, vocab_size, output_dim, **kwargs):
        super(PositionEmbeddingFixedWeights, self).__init__(**kwargs)
        self.word_embedding_matrix = self.get_position_encoding(vocab_size, output_dim)   
        self.position_embedding_matrix = self.get_position_encoding(sequence_length, output_dim)                                          
        self.word_embedding_layer = Embedding(
            input_dim=vocab_size, output_dim=output_dim,
            trainable=False,
            name="word",
        )
        ##fixed weights
        self.word_embedding_layer.add_weight(shape=(vocab_size, output_dim),trainable=False)
        self.word_embedding_layer.set_weights([self.word_embedding_matrix])
        
        
        self.position_embedding_layer = Embedding(
            input_dim=sequence_length, output_dim=output_dim,
            trainable=False,
            name="pos",
        )
        ##fixed weights
        self.position_embedding_layer.add_weight(shape=(sequence_length, output_dim),trainable=False)
        self.position_embedding_layer.set_weights([self.position_embedding_matrix])
    
        
    def get_position_encoding(self, seq_len, d, n=10000):
        P = np.zeros((seq_len, d))
        for k in range(seq_len):
            for i in np.arange(int(d/2)):
                denominator = np.power(n, 2*i/d)
                P[k, 2*i] = np.sin(k/denominator)
                P[k, 2*i+1] = np.cos(k/denominator)
        return P
    
 
    def call(self, inputs):
        position_indices = tf.range(tf.shape(inputs)[-1])
        embedded_words = self.word_embedding_layer(inputs)
        embedded_indices = self.position_embedding_layer(position_indices)
        return embedded_words + embedded_indices

In [274]:
class AddNormalization(Layer):
    def __init__(self,**kwargs):
        super(AddNormalization,self).__init__(**kwargs)
        self.layer_norm=LayerNormalization()
    
    def call(self,x,sublayer_x):
        #sublayer i/p and o/p need to be of the same shape to be summed
        add=x+sublayer_x
        
        #apply layer normalization to the sum
        return self.layer_norm(add)

In [275]:
class FeedForward(Layer):
    def __init__(self,d_ff,d_model,**kwargs):
        super(FeedForward,self).__init__(**kwargs)
        self.fully_connected1=Dense(d_ff) #first fully connected layer
        self.fully_connected2=Dense(d_model) #second fully connected layer
        self.activation=ReLU() #relu activation layer
    
    def call(self,x):
        #the i/p passed into the 2 fully connected layers,with a relu in b/w
        x_fc1=self.fully_connected1(x)
        
        return self.fully_connected2(self.activation(x_fc1))

In [276]:
class DotProductAttention(Layer):
    def __init__(self,**kwargs):
        super(DotProductAttention,self).__init__(**kwargs)
    
    def call(self,queries,keys,values,d_k,mask=None):
        scores=matmul(queries,keys,transpose_b=True)/math.sqrt(cast(d_k,float32))
    
    #applying mask so as to not base the occurence of a word on the basis of the words ahead
        if mask is not None:
            scores+= -1e9*mask
        weights=softmax(scores)
        
        return matmul(weights,values)

In [277]:
class MultiHeadAttention(Layer):
    def __init__(self,h,d_k,d_v,d_model,**kwargs):
        super(MultiHeadAttention,self).__init__(**kwargs)
        self.attention=DotProductAttention() #scaled dot product attention
        self.heads=h #number of attention heads to use
        self.d_k=d_k #dimentionality of linearly projected queries and keys
        self.d_v=d_v #dimantionality of linearly projected values
        self.d_model=d_model #dimentionality of the model
        self.W_q=Dense(d_k) #learned projection matrix for the queries
        self.W_k=Dense(d_k) #learned projection matrix for the keys
        self.W_v=Dense(d_v) #learned projection matrix for the values
        self.W_o=Dense(d_model) #leanred projection matrix for the multi head o/p
        
    def reshape_tensor(self,x,heads,flag):
        if flag:
            #tensor shape after reshaping and transposing: (batch_size,heads,seq_length,-1)
            x=reshape(x,shape=(shape(x)[0],shape(x)[1],heads,-1))
            x=transpose(x,perm=(0,2,1,3))
        else:
            #reverting the reshaping and transposing opertaions:(batch_size,seq,length,d_k)
            x=transpose(x,perm=(0,2,1,3))
            x=reshape(x,shape=(shape(x)[0],shape(x)[1],self.d_k))
        return x
    
    def call(self,queries,keys,values,mask=None):
        #rearrange the queries to be able to compute all heads in parallel
        q_reshaped=self.reshape_tensor(self.W_q(queries),self.heads,True)
        #resulting tensor shape: (batch_size,heads,input_seq_len,-1)
        
        #rearrange the keys to be able to compute all heads in parallel
        k_reshaped=self.reshape_tensor(self.W_k(keys),self.heads,True)
        
        #rearrange the values to be able to compute all heads in parallel
        v_reshaped=self.reshape_tensor(self.W_v(values),self.heads,True)
        
        #compute the multi head attention o/p using the reshaped q,k,v
        o_reshaped=self.attention(queries=q_reshaped,keys=k_reshaped,values=v_reshaped,d_k=self.d_k,mask=mask)
        #resulting tensor shape: (batch_size,input_seq_len,-1)
        
        #rearrange back the o/p into concatenated form
        output=self.reshape_tensor(o_reshaped,self.heads,False)
        #resulting tensor shape: (batch_size,heads,input_seq_len,d_k)
        
        #apply one final layer linear projection to the o/p to generate the mutlihead attention
        #resulting tensor shaoe:(batch_size,input_seq_len,d_model)
        return self.W_o(output)

In [278]:
class EncoderLayer(Layer):
    def __init__(self,h,d_k,d_v,d_model,d_ff,rate,**kwargs):
        super(EncoderLayer,self).__init__(**kwargs)
        self.multihead_attention=MultiHeadAttention(h,d_k,d_v,d_model)
        self.dropout1=Dropout(rate)
        self.add_norm1=AddNormalization()
        self.feed_forward=FeedForward(d_ff,d_model)
        self.dropout2=Dropout(rate)
        self.add_norm2=AddNormalization()
        
    def call(self,x,padding_mask,training):
        #multihead attention layer
        multihead_output=self.multihead_attention(queries=x,keys=x,values=x,mask=padding_mask)
        #expected o/p shape =(batch_size,seq_len,d_model)
        
        #dropout
        multihead_output=self.dropout1(multihead_output,training=training)
        
        #Add and Norm Layer
        addnorm_output=self.add_norm1(x,multihead_output)
        #expected o/p shape=(batch_size,seq_len,d_model)
        
        #fully connected layer
        feedforward_output=self.feed_forward(addnorm_output)
        #expected shape=(batch_size,seq_len,d_model)
        
        #dropout layer
        feedforward_output=self.dropout2(feedforward_output,training=training)
        
        #Add and Norm layer
        return self.add_norm2(addnorm_output,feedforward_output)

In [279]:
class Encoder(Layer):
    def __init__(self,vocab_size,seq_len,h,d_k,d_v,d_model,d_ff,n,rate,**kwargs):
        super(Encoder,self).__init__(**kwargs)
        self.pos_encoding=PositionEmbeddingFixedWeights(seq_len,vocab_size,d_model)
        self.dropout=Dropout(rate)
        self.encoder_layer=[EncoderLayer(h,d_k,d_v,d_model,d_ff,rate) for _ in range(n)]
        
    
    def call(self,input_seq,padding_mask,training):
        #generate positional encoding
        pos_encoding_output=self.pos_encoding(input_seq)
        
        #expected output shape=(batch_size,seq_len,d_model)
        
        #droupout layer
        x=self.dropout(pos_encoding_output,training=training)
        
        #passing positional encoded values to each encoder layer
        for i ,layer in enumerate(self.encoder_layer):
            x=layer(x,padding_mask=padding_mask,training=training)
            
        return x

### testing using dummy data

In [301]:
from numpy import random

enc_vocab_size=20 #vocabulary size for the encoder
input_seq_len=5 #max size of the i/p seq
h=8 #number of self-attention heads
d_k=64 #dimentionality of the linearly projected queris and keys
d_v=64 #dimentionality of the linearly projected values
d_ff=2048 #dimentionality of the inner fully connected layer
d_model=512 #dimentionality of the model sub-layers' o/p
n=6 #numer of layers in the encoder stack

batch_size=64 #batch size form the the training process
dropout_rate=0.1 #frequency of dropping the i/p units in the dropout layer
input_seq=random.random((batch_size,input_seq_len))

encoder=Encoder(enc_vocab_size,input_seq_len,h,d_k,d_v,d_model,d_ff,n,dropout_rate)
print(encoder(input_seq,padding_mask=None,training=True))

tf.Tensor(
[[[ 1.1147017   0.16441566  0.55740565 ... -0.37166512 -0.15431736
   -1.8471124 ]
  [ 0.7618469   0.8236067  -0.60659933 ... -1.2163182   0.3851881
   -2.982336  ]
  [ 1.3416651   0.04513863 -0.15659723 ... -0.5181147   0.32735035
   -1.9324467 ]
  [ 0.8564203   0.73370904 -0.33491606 ... -0.39158142  0.0924149
   -1.8935617 ]
  [ 0.80698574  0.12150733  0.06685011 ... -0.21014674 -0.41041982
   -0.76314354]]

 [[ 0.32850823  0.3916108  -0.11228144 ...  0.24602424  0.05485585
   -1.5224779 ]
  [ 1.565122    0.06600136 -0.19128309 ... -0.41524026  0.95913863
   -2.328098  ]
  [ 0.12286402 -0.95155865 -0.8306134  ... -0.47583726 -0.47932354
   -0.9740388 ]
  [ 0.4104212   0.43118575 -0.3968854  ... -0.34344634 -0.2612015
   -1.6174657 ]
  [ 0.9022037   0.0748307  -0.4029294  ... -0.01195032 -0.3776412
   -1.3287237 ]]

 [[ 1.195409    0.5425064   0.0361741  ... -0.6074726   0.91755617
   -0.74038535]
  [ 1.305232    0.8578337  -0.4308623  ... -0.66527903  0.90632313
   -2.743