In [120]:
import keras

In [121]:
import tensorflow as tf

In [122]:
import numpy as np

In [123]:
from keras.layers import Dense,LSTM,Input,Embedding

In [124]:
class LuongAttentionWithMonotonic(tf.keras.layers.Layer):
    """Luong local attention with monotonic alignment ("local-m").

    Alignment scores are computed as in global attention, but only encoder
    positions inside ``[pt - window_size, pt + window_size]`` can receive
    weight, where ``pt`` is the ``decoder_step`` supplied by the caller.

    Args:
        units: width of the ``w1`` projection used by "General"/"Concat".
        method: alignment function: "General", "Dot" or "Concat".
        window_size: half-width ``D`` of the attention window.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: if ``method`` is not a supported alignment.
    """

    _METHODS = ("General", "Dot", "Concat")

    def __init__(self, units, method="Dot", window_size=5, **kwargs):
        super().__init__(**kwargs)
        # Fail at construction time instead of printing and returning None
        # from call(), which only surfaced later as an opaque graph error.
        if method not in self._METHODS:
            raise ValueError(
                f"Try valid alignment: method must be one of {self._METHODS}, got {method!r}"
            )
        self.method = method
        self.window_size = window_size
        self.w1 = Dense(units, use_bias=False)
        if method == "Concat":
            # Scoring vector for the "Concat" alignment (zero-initialised as before).
            self.weight = tf.Variable(initial_value=tf.zeros((units, 1)), trainable=True, dtype=tf.float32)

    def call(self, inputs):
        """Return one window-restricted context vector per decoder step.

        Args:
            inputs: ``[encoder_op, decoder_op, decoder_step]``. Assumed shapes:
                ``encoder_op`` (batch, src_len, d), ``decoder_op``
                (batch, tgt_len, d) and ``decoder_step`` a scalar or one value
                per batch row -- TODO confirm with callers.

        Returns:
            Context tensor of shape (batch, tgt_len, d).
        """
        encoder_op, decoder_op, decoder_step = inputs
        score = self._score(encoder_op, decoder_op)  # (batch, src_len, tgt_len)

        # Window centre, reshaped to broadcast over (batch, src_len, tgt_len).
        pt = tf.reshape(tf.cast(decoder_step, tf.int32), (-1, 1, 1))
        positions = tf.range(tf.shape(encoder_op)[1])[tf.newaxis, :, tf.newaxis]
        in_window = tf.logical_and(
            positions >= pt - self.window_size,
            positions <= pt + self.window_size,
        )

        # Replace out-of-window scores with the dtype's most negative value so
        # softmax gives them ~0 weight. Multiplying scores by a 0/1 mask would
        # leave them at exp(0) and they would still attract attention.
        score = tf.where(in_window, score, tf.constant(score.dtype.min, dtype=score.dtype))

        # Normalise over encoder positions, then put decoder steps first so the
        # matmul with encoder_op yields (batch, tgt_len, d).
        attention_weights = tf.nn.softmax(score, axis=1)
        attention_weights = tf.transpose(attention_weights, perm=(0, 2, 1))
        return tf.matmul(attention_weights, encoder_op)

    def _score(self, encoder_op, decoder_op):
        """Return alignment scores of shape (batch, src_len, tgt_len)."""
        if self.method == "General":
            return tf.matmul(self.w1(encoder_op), decoder_op, transpose_b=True)
        if self.method == "Dot":
            return tf.matmul(encoder_op, decoder_op, transpose_b=True)
        if self.method == "Concat":
            # Pair every encoder step with every decoder step by broadcasting
            # (batch, src_len, 1, units) + (batch, 1, tgt_len, units).
            combined = tf.nn.tanh(
                self.w1(encoder_op)[:, :, tf.newaxis, :]
                + self.w1(decoder_op)[:, tf.newaxis, :, :]
            )
            return tf.tensordot(combined, self.weight[:, 0], axes=1)
        raise ValueError(f"Try valid alignment: unsupported method {self.method!r}")

In [125]:
class LuongAttentionWithLocalWindow(tf.keras.layers.Layer):
    """Luong local attention with predictive alignment ("local-p").

    For every decoder step the layer predicts an aligned source position
    ``pt = S * sigmoid(v^T tanh(W_p h_t))`` (``S`` = source length). It limits
    attention to ``[pt - window_size, pt + window_size]`` and scales the
    softmax weights by a Gaussian centred on ``pt`` with
    ``sigma = window_size / 2`` (Luong et al., 2015, Sec. 3.2).

    Args:
        units: width of the score projection ``w1`` and position projection ``wp``.
        method: alignment function: "General", "Dot" or "Concat".
        window_size: half-width ``D`` of the attention window (must be > 0).
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: if ``method`` is unsupported or ``window_size`` <= 0.
    """

    _METHODS = ("General", "Dot", "Concat")

    def __init__(self, units, method="Dot", window_size=5, **kwargs):
        super().__init__(**kwargs)
        if method not in self._METHODS:
            raise ValueError(
                f"Try valid alignment: method must be one of {self._METHODS}, got {method!r}"
            )
        if window_size <= 0:
            # sigma = window_size / 2 would be zero and the Gaussian would divide by 0.
            raise ValueError(f"window_size must be positive, got {window_size!r}")
        self.method = method
        self.window_size = window_size
        self.w1 = Dense(units, use_bias=False)  # alignment-score projection
        # Position predictor: a dedicated W_p (instead of reusing w1) and a
        # single-output v_p so pt has exactly one value per decoder step.
        self.wp = Dense(units, use_bias=False)
        self.v = Dense(1, use_bias=False)

        if method == "Concat":
            # Scoring vector for the "Concat" alignment (zero-initialised as before).
            self.weight = tf.Variable(initial_value=tf.zeros((units, 1)), trainable=True, dtype=tf.float32)

    def call(self, inputs):
        """Return one locally-focused context vector per decoder step.

        Args:
            inputs: ``[encoder_op, decoder_op, decoder_step]``. Assumed shapes:
                ``encoder_op`` (batch, src_len, d) and ``decoder_op``
                (batch, tgt_len, d) -- TODO confirm with callers.
                NOTE(review): ``decoder_step`` is accepted so the signature
                matches LuongAttentionWithMonotonic. Predictive alignment
                derives pt from the decoder state, so the value is unused.

        Returns:
            Context tensor of shape (batch, tgt_len, d).
        """
        encoder_op, decoder_op, _decoder_step = inputs
        score = self.score(encoder_op, decoder_op)  # (batch, src_len, tgt_len)
        dtype = score.dtype

        src_len = tf.shape(encoder_op)[1]
        # Predicted centre per decoder step: (batch, tgt_len, 1) -> (batch, 1, tgt_len).
        pt = tf.cast(src_len, dtype) * tf.sigmoid(self.v(tf.nn.tanh(self.wp(decoder_op))))
        pt = tf.transpose(pt, perm=(0, 2, 1))
        positions = tf.cast(tf.range(src_len), dtype)[tf.newaxis, :, tf.newaxis]
        distance = positions - pt  # (batch, src_len, tgt_len)

        # Hard window [pt - D, pt + D]: out-of-window scores get the most
        # negative representable value, so softmax assigns them ~0 weight.
        in_window = tf.abs(distance) <= self.window_size
        score = tf.where(in_window, score, tf.constant(dtype.min, dtype=dtype))
        attention_weights = tf.nn.softmax(score, axis=1)

        # Gaussian prior centred on pt, applied to the normalised weights.
        sigma = self.window_size / 2.0
        attention_weights *= tf.exp(-tf.square(distance) / (2.0 * sigma ** 2))

        attention_weights = tf.transpose(attention_weights, perm=(0, 2, 1))
        return tf.matmul(attention_weights, encoder_op)

    def score(self, encoder_op, decoder_op):
        """Return alignment scores of shape (batch, src_len, tgt_len).

        Raises:
            ValueError: if ``self.method`` is not a supported alignment.
        """
        if self.method == "General":
            return tf.matmul(self.w1(encoder_op), decoder_op, transpose_b=True)
        if self.method == "Dot":
            return tf.matmul(encoder_op, decoder_op, transpose_b=True)
        if self.method == "Concat":
            # Pair every encoder step with every decoder step by broadcasting
            # (batch, src_len, 1, units) + (batch, 1, tgt_len, units).
            combined = tf.nn.tanh(
                self.w1(encoder_op)[:, :, tf.newaxis, :]
                + self.w1(decoder_op)[:, tf.newaxis, :, :]
            )
            return tf.tensordot(combined, self.weight[:, 0], axes=1)
        raise ValueError(f"Try valid alignment: unsupported method {self.method!r}")

In [126]:
from keras import Model

In [127]:
from keras.layers import Bidirectional,Concatenate

In [141]:
class LuongGlobalAttention(tf.keras.layers.Layer):
    """Luong-style global attention over every encoder time step.

    Args:
        units: width of the ``w1`` projection used by "General"/"Concat".
        method: alignment function: "General", "Dot" or "Concat".
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: if ``method`` is not a supported alignment.
    """

    _METHODS = ("General", "Dot", "Concat")

    def __init__(self, units, method="Dot", **kwargs):
        super().__init__(**kwargs)
        # Fail at construction time instead of printing and returning None
        # from call(), which only surfaced later as an opaque graph error.
        if method not in self._METHODS:
            raise ValueError(
                f"Try valid alignment: method must be one of {self._METHODS}, got {method!r}"
            )
        self.method = method
        self.w1 = Dense(units, use_bias=False)
        if method == "Concat":
            # Scoring vector for the "Concat" alignment (zero-initialised as before).
            self.weight = tf.Variable(initial_value=tf.zeros((units, 1)), trainable=True, dtype=tf.float32)

    def call(self, inputs):
        """Return one context vector per decoder step.

        Args:
            inputs: ``[encoder_op, decoder_op]``. Assumed shapes:
                ``encoder_op`` (batch, src_len, d) and ``decoder_op``
                (batch, tgt_len, d) -- TODO confirm with callers.

        Returns:
            Context tensor of shape (batch, tgt_len, d).
        """
        encoder_op, decoder_op = inputs
        score = self._score(encoder_op, decoder_op)  # (batch, src_len, tgt_len)

        # Normalise over encoder positions, then put decoder steps first so the
        # matmul with encoder_op yields (batch, tgt_len, d).
        attention_weights = tf.nn.softmax(score, axis=1)
        attention_weights = tf.transpose(attention_weights, perm=(0, 2, 1))
        return tf.matmul(attention_weights, encoder_op)

    def _score(self, encoder_op, decoder_op):
        """Return alignment scores of shape (batch, src_len, tgt_len)."""
        if self.method == "General":
            return tf.matmul(self.w1(encoder_op), decoder_op, transpose_b=True)
        if self.method == "Dot":
            return tf.matmul(encoder_op, decoder_op, transpose_b=True)
        if self.method == "Concat":
            # Pair every encoder step with every decoder step by broadcasting
            # (batch, src_len, 1, units) + (batch, 1, tgt_len, units). The old
            # elementwise sum only worked when tgt_len was 1 or equal to src_len.
            combined = tf.nn.tanh(
                self.w1(encoder_op)[:, :, tf.newaxis, :]
                + self.w1(decoder_op)[:, tf.newaxis, :, :]
            )
            return tf.tensordot(combined, self.weight[:, 0], axes=1)
        raise ValueError(f"Try valid alignment: unsupported method {self.method!r}")

In [142]:
# Model / data dimensions used throughout the notebook.
src_len = 10        # encoder (source) sequence length
ip_vocab_size = 2   # source vocabulary size
tg_vocab_size = 5   # target vocabulary size
lstm_units = 15     # LSTM hidden size (per direction in the encoder) and projection width
embed_dim = 20      # embedding width for both source and target tokens

# Seed Python, NumPy and TensorFlow RNGs so the synthetic data and the
# weight initialisation are reproducible under Restart & Run All.
random_seed = 42
keras.utils.set_random_seed(random_seed)

In [220]:
# --- Encoder: embed source tokens, run a bidirectional LSTM that also returns
# its final forward/backward states. ---
encoder_input = Input(shape=(src_len,))
decoder_input = Input(shape=(None,))

encoder_embedding = Embedding(ip_vocab_size, embed_dim)
decoder_embedding = Embedding(tg_vocab_size, embed_dim)

encoder_embed = encoder_embedding(encoder_input)
decoder_embed = decoder_embedding(decoder_input)

encoder_lstm = Bidirectional(LSTM(lstm_units, return_sequences=True, return_state=True))
encoder_op, forward_h, forward_c, backward_h, backward_c = encoder_lstm(encoder_embed)

# The BiLSTM yields 2 * lstm_units features; project the outputs and the merged
# states back to lstm_units so they match the unidirectional decoder LSTM.
# NOTE(review): one Dense is shared by outputs, h and c; separate projections
# may be preferable, but the shared one is kept to preserve the architecture.
# Keras layers (Concatenate/Activation) are used instead of raw tf ops on
# symbolic tensors so the functional graph stays valid across Keras versions.
encoder_dense = Dense(lstm_units)
h = Concatenate(axis=-1)([forward_h, backward_h])
c = Concatenate(axis=-1)([forward_c, backward_c])
encoder_op = encoder_dense(encoder_op)
h = encoder_dense(h)
c = encoder_dense(c)

# --- Decoder: LSTM initialised with the projected encoder state, followed by
# Luong "General" attention over the encoder outputs. ---
decoder_lstm = LSTM(lstm_units, return_sequences=True, return_state=True)
decoder_op, h1, c1 = decoder_lstm(decoder_embed, initial_state=[h, c])
attention = LuongGlobalAttention(lstm_units, method="General")
context_vector = attention([encoder_op, decoder_op])

# Attentional hidden state tanh([context; decoder output]) feeds the softmax.
decoder_op = Concatenate(axis=-1)([context_vector, decoder_op])
decoder_op = keras.layers.Activation("tanh")(decoder_op)
decoder_dense = Dense(tg_vocab_size, activation='softmax')
decoder_op = decoder_dense(decoder_op)

model = Model([encoder_input, decoder_input], [decoder_op])

In [221]:
# Synthetic toy data made of integer token ids. Embedding layers expect indices,
# and tf.one_hot needs integer labels. Floats from np.random.random lie in
# [0, 1) and would all be cast to token id 0.
n_samples = 30
tgt_len = 7
inputs = np.random.randint(0, ip_vocab_size, size=(n_samples, src_len))
outputs = np.random.randint(0, tg_vocab_size, size=(n_samples, tgt_len))

In [222]:
from sklearn.model_selection import train_test_split

In [223]:
# 80/20 split; a fixed random_state keeps it identical across kernel restarts.
X_train, X_test, y_train, y_test = train_test_split(inputs, outputs, test_size=0.2, random_state=42)

In [224]:
# Expand integer target ids into one-hot vectors of length tg_vocab_size,
# the target format expected by the categorical_crossentropy loss.
y_train_onehot, y_test_onehot = (
    tf.one_hot(y_train, tg_vocab_size),
    tf.one_hot(y_test, tg_vocab_size),
)

In [225]:
# Sanity check: one-hot encoding appends an axis of size tg_vocab_size.
(y_train.shape, y_train_onehot.shape)

((24, 7), TensorShape([24, 7, 5]))

In [226]:
# Adam with categorical cross-entropy on the one-hot targets; also track accuracy.
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

In [227]:
# Teacher forcing: the decoder reads token t and is trained to predict token
# t + 1. Feeding the full target sequence as decoder input leaked every label
# into the input, so the model could simply learn to copy it.
# NOTE(review): a dedicated start token would normally be prepended so the
# first target token is also predicted.
decoder_train_input, decoder_train_target = y_train[:, :-1], y_train_onehot[:, 1:]
decoder_test_input, decoder_test_target = y_test[:, :-1], y_test_onehot[:, 1:]
model.fit(
    [X_train, decoder_train_input],
    decoder_train_target,
    epochs=3,
    batch_size=5,
    validation_data=([X_test, decoder_test_input], decoder_test_target),
)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.src.callbacks.History at 0x272924eb430>

In [228]:
# Layer-by-layer summary of the training model (output shapes and parameter counts).
model.summary()

Model: "model_21"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_77 (InputLayer)       [(None, 10)]                 0         []                            
                                                                                                  
 embedding_76 (Embedding)    (None, 10, 20)               40        ['input_77[0][0]']            
                                                                                                  
 bidirectional_12 (Bidirect  [(None, 10, 30),             4320      ['embedding_76[0][0]']        
 ional)                       (None, 15),                                                         
                              (None, 15),                                                         
                              (None, 15),                                                  

In [230]:
# Inference-time encoder: maps source tokens to the projected encoder outputs
# and the projected (h, c) state used to initialise the decoder LSTM.
encoder_model = Model(inputs=encoder_input, outputs=[encoder_op, h, c])

In [231]:
# Summary of the inference encoder built from the training graph's tensors.
encoder_model.summary()

Model: "model_22"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_77 (InputLayer)       [(None, 10)]                 0         []                            
                                                                                                  
 embedding_76 (Embedding)    (None, 10, 20)               40        ['input_77[0][0]']            
                                                                                                  
 bidirectional_12 (Bidirect  [(None, 10, 30),             4320      ['embedding_76[0][0]']        
 ional)                       (None, 15),                                                         
                              (None, 15),                                                         
                              (None, 15),                                                  

In [233]:
# Inference-time decoder. The encoder outputs and the previous LSTM state are
# declared as explicit Inputs, rather than reusing intermediate tensors of the
# training graph as model inputs. The trained decoder layers are re-applied,
# so the weights are shared with `model`. Feed the returned state back in as
# the next step's h/c.
decoder_encoder_op_input = Input(shape=(None, lstm_units))
decoder_state_h_input = Input(shape=(lstm_units,))
decoder_state_c_input = Input(shape=(lstm_units,))

inference_embed = decoder_embedding(decoder_input)
inference_lstm_op, inference_h, inference_c = decoder_lstm(
    inference_embed, initial_state=[decoder_state_h_input, decoder_state_c_input]
)
inference_context = attention([decoder_encoder_op_input, inference_lstm_op])

# Same head as training: tanh([context; decoder output]) -> softmax Dense.
inference_attentional = Concatenate(axis=-1)([inference_context, inference_lstm_op])
inference_attentional = keras.layers.Activation("tanh")(inference_attentional)
inference_probs = decoder_dense(inference_attentional)

decoder_model = Model(
    [decoder_input, decoder_encoder_op_input, decoder_state_h_input, decoder_state_c_input],
    [inference_probs, inference_h, inference_c],
)

In [234]:
# Summary of the inference decoder (shapes and parameter counts).
decoder_model.summary()

Model: "model_23"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_78 (InputLayer)       [(None, None)]               0         []                            
                                                                                                  
 embedding_77 (Embedding)    (None, None, 20)             100       ['input_78[0][0]']            
                                                                                                  
 input_80 (InputLayer)       [(None, 15)]                 0         []                            
                                                                                                  
 input_81 (InputLayer)       [(None, 15)]                 0         []                            
                                                                                           