In [None]:
from tensorflow.keras.layers import MultiHeadAttention,Dense,LayerNormalization,Embedding,Layer,Input
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [None]:
def positional_encoding(seq_len , model_size):
  """Build the sinusoidal positional-encoding table.

  For position ``pos`` and feature index ``i``:
    * even ``i``: ``sin(pos / 10000 ** (i / model_size))``
    * odd  ``i``: ``cos(pos / 10000 ** ((i - 1) / model_size))``

  Args:
    seq_len: number of positions (rows) to generate; 0 yields an empty table.
    model_size: number of features (columns) per position.

  Returns:
    A float64 NumPy array of shape ``(1, seq_len, model_size)``; the leading
    axis of size 1 lets the table broadcast over a batch dimension.
  """
  positions = np.arange(seq_len)[:, np.newaxis]   # (seq_len, 1)
  dims = np.arange(model_size)                    # (model_size,)

  # Each sin/cos pair shares one frequency, so odd indices reuse the exponent
  # of the even index just before them (i - 1).
  exponents = (dims - dims % 2) / model_size
  angles = positions / (10000 ** exponents)       # (seq_len, model_size)

  table = np.where(dims % 2 == 0, np.sin(angles), np.cos(angles))

  # Add the batch axis exactly once, after the whole table is built.
  return table[np.newaxis, ...]

class Embeddings(Layer):
  """Token embedding with a fixed sinusoidal positional encoding added on top.

  Args:
    vocab_size: size of the embedding table (``input_dim`` of ``Embedding``).
    seq_len: number of positions the positional table is pre-computed for;
      this is the longest sequence the layer can handle.
    model_size: width of each embedding vector (``output_dim`` of ``Embedding``).
  """
  def __init__(self , vocab_size , seq_len , model_size):
    super(Embeddings, self).__init__()
    self.emb = Embedding(input_dim = vocab_size , output_dim = model_size)
    # Computed once at construction; NumPy array of shape (1, seq_len, model_size).
    self.pos_encoding = positional_encoding(seq_len,model_size)

  def call(self,input):
    """Embed token ids and add the positional encoding.

    Args:
      input: integer token ids; indexed as (batch, time) by ``compute_masks``.

    Returns:
      Tensor of embeddings with the positional encoding added.
    """
    embs = self.emb(input)

    # Trim the pre-computed table to the actual sequence length so inputs
    # shorter than ``seq_len`` work, and match the embedding dtype so the
    # float64 NumPy table does not clash with the layer's compute dtype.
    length = tf.shape(embs)[1]
    pos = tf.cast(self.pos_encoding, embs.dtype)[:, :length, :]

    return (pos + embs)

  def compute_masks(self,input):
    """Build a square attention mask that hides positions whose id is 0.

    Args:
      input: rank-2 integer tensor of token ids, (batch, time). Id 0 is
        treated as "do not attend" (presumably the padding id).

    Returns:
      int32 tensor of shape (batch, time, time) where entry ``[b, i, j]`` is 1
      when ``input[b, j] != 0`` and 0 otherwise (every row ``i`` is identical).
    """
    mask = tf.math.not_equal(input , 0)
    # (batch, time) -> (batch, 1, time), as int32 so it can be combined with
    # other integer masks via tf.minimum.
    mask = tf.cast(mask[:,tf.newaxis,:],tf.int32)
    T = tf.shape(mask)[2]
    # Repeat the key-validity row for every query position.
    mask = tf.repeat(mask , T , axis = 1)

    return mask

In [None]:
class Encoderlayer(Layer):
  """Single encoder block: self-attention followed by a two-layer dense net.

  Each of the two sub-blocks is wrapped as ``LayerNorm(sublayer(x) + x)``.

  Args:
    num_heads: number of attention heads.
    emb_dim: width of the block's input/output; also passed as ``key_dim``.
    dense_dim: width of the hidden ReLU layer in the dense sub-block.
  """
  def __init__(self, num_heads , emb_dim , dense_dim):
    super().__init__()
    self.layernorm_1 = LayerNormalization()
    self.layernorm_2 = LayerNormalization()

    # Position-wise feed-forward: expand to dense_dim, project back to emb_dim.
    feed_forward_layers = [
        Dense(dense_dim, activation='relu'),
        Dense(emb_dim),
    ]
    self.dense = tf.keras.Sequential(feed_forward_layers)

    # NOTE(review): key_dim is the per-head size in Keras; emb_dim is used
    # here rather than emb_dim // num_heads -- confirm this is intended.
    self.attn = MultiHeadAttention(num_heads=num_heads, key_dim=emb_dim)

  def call(self, inputs, mask):
    """Run the block.

    Args:
      inputs: tensor used as query, key and value of the self-attention.
      mask: attention mask forwarded unchanged to ``MultiHeadAttention``.

    Returns:
      The layer-normalized output of the dense sub-block.
    """
    attended = self.attn(
        query=inputs,
        key=inputs,
        value=inputs,
        attention_mask=mask,
    )
    normed_attention = self.layernorm_1(attended + inputs)

    projected = self.dense(normed_attention)
    return self.layernorm_2(projected + normed_attention)


class Decoderlayer(Layer):
  """Single decoder block: causal self-attention, cross-attention, dense net.

  Each of the three sub-blocks is wrapped as ``LayerNorm(sublayer(x) + x)``.

  Args:
    num_heads: number of attention heads for both attention layers.
    emb_dim: width of the block's input/output; also passed as ``key_dim``.
    dense_dim: width of the hidden ReLU layer in the dense sub-block.
  """
  def __init__(self, num_heads , emb_dim , dense_dim):
    super(Decoderlayer,self).__init__()
    self.layernorm_1 = LayerNormalization()
    self.layernorm_2 = LayerNormalization()
    self.layernorm_3 = LayerNormalization()

    self.dense = tf.keras.Sequential([
        Dense(dense_dim,activation = 'relu'),
        Dense(emb_dim)
    ])
    # NOTE(review): key_dim is the per-head size in Keras; emb_dim is used
    # here rather than emb_dim // num_heads -- confirm this is intended.
    self.attn_1 = MultiHeadAttention(num_heads=num_heads,key_dim=emb_dim)
    self.attn_2 = MultiHeadAttention(num_heads=num_heads,key_dim=emb_dim)


  def call(self,inputs ,encoder_outputs, mask, encoder_mask = None):
    """Run the block.

    Args:
      inputs: decoder-side tensor, (batch, target_len, emb_dim).
      encoder_outputs: encoder result used as key and value of cross-attention.
      mask: (batch, target_len, target_len) mask for the decoder tokens
        (1/True = may attend); combined with a causal mask below.
      encoder_mask: optional mask forwarded as ``attention_mask`` to the
        cross-attention, shaped (batch, target_len, source_len). Defaults to
        ``None``, which leaves every encoder position visible (the previous
        behaviour).

    Returns:
      The layer-normalized output of the dense sub-block.
    """
    target_len = tf.shape(inputs)[1]

    # Lower-triangular ones: query i may only look at keys j <= i.
    causal_mask = tf.linalg.band_part(
        tf.ones([target_len, target_len], dtype = tf.int32), -1, 0)

    # A position is visible only if both masks allow it. The cast lets callers
    # pass a boolean mask; the causal mask broadcasts over the batch axis.
    attn_mask = tf.minimum(tf.cast(mask, tf.int32), causal_mask[tf.newaxis, :, :])

    attn_out = self.attn_1(query = inputs , key = inputs , value = inputs , attention_mask = attn_mask)
    out_1 = self.layernorm_1(attn_out + inputs)

    attn_out_2 = self.attn_2(query = out_1 ,key = encoder_outputs , value = encoder_outputs , attention_mask = encoder_mask )
    out_2 = self.layernorm_2(out_1 + attn_out_2)

    dense_out = self.dense(out_2)

    return self.layernorm_3(dense_out + out_2)

In [None]:
# Hyper-parameters of the toy Transformer.
num_heads = 8
emd_dim = 512
vocab_size = 20000
seq_len = 2          # length of the positional-encoding table for both stacks
dense_dim = 2048
num_layers = 1

# ---- Encoder ----
enc_inputs = Input(shape = (None,))
enc_emb = Embeddings(vocab_size,seq_len,emd_dim)
x = enc_emb(enc_inputs)
enc_mask = enc_emb.compute_masks(enc_inputs)

for _ in range(num_layers):
  x = Encoderlayer(num_heads,emd_dim,dense_dim)(x,enc_mask)

enc_output = x

# ---- Decoder ----
dec_inputs = Input(shape = (None,))
# Use the shared seq_len rather than a literal so encoder and decoder stay
# in sync when the sequence length is changed.
dec_emb = Embeddings(vocab_size,seq_len,emd_dim)
x = dec_emb(dec_inputs)
dec_mask = dec_emb.compute_masks(dec_inputs)


for _ in range(num_layers):
  x = Decoderlayer(num_heads,emd_dim,dense_dim)(x,enc_output,dec_mask)

# Per-position probability distribution over the vocabulary.
output = Dense(vocab_size , activation = 'softmax')(x)

model = tf.keras.Model([enc_inputs,dec_inputs],output)

model.summary()
# Softmax output + integer targets -> sparse categorical cross-entropy
# (default from_logits=False matches the softmax activation above).
model.compile(optimizer = Adam(),
              loss = SparseCategoricalCrossentropy())

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, None)]               0         []                            
                                                                                                  
 input_2 (InputLayer)        [(None, None)]               0         []                            
                                                                                                  
 tf.math.not_equal (TFOpLam  (None, None)                 0         ['input_1[0][0]']             
 bda)                                                                                             
                                                                                                  
 tf.__operators__.getitem (  (None, 1, None)              0         ['tf.math.not_equal[0][0]'

In [None]:
# One toy training example: encoder ids, decoder ids and target ids,
# each a (1, 2) integer array.
x1 = np.array([[1, 2]])
x2 = np.array([[3, 4]])
y = np.array([[4, 5]])

# Fit on the single example; `history` keeps the object returned by fit.
history = model.fit(x=[x1, x2], y=y, epochs=100)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

In [None]:
x1 = np.array([[1,2]])
x2 = np.array([[3,4]])

# The model emits one softmax over the vocabulary per position, so take the
# argmax along the last (vocabulary) axis to get a token id for every
# position. Without `axis`, np.argmax flattens the array and returns a single
# flat index.
print(np.argmax(model.predict([x1,x2]), axis=-1))

4
