# In [3]:  (Jupyter cell marker from the notebook export)
import numpy as np
import tensorflow as tf
from tensorflow.keras import initializers, regularizers, constraints
from tensorflow.keras import layers
from tensorflow.keras.layers import Embedding, LSTM, Dense, RNN, Flatten, Softmax

class Encoder(tf.keras.Model):
    '''
    Encoder: embeds an input token sequence and runs it through an LSTM.

    ``call`` returns the full LSTM output sequence together with the final
    hidden and cell states (``return_sequences=True, return_state=True``).
    '''
    def __init__(self, input_vocab_size, embedd_size, lstm_size, inp_len, perturbated_text_embed_matrix):
        """
        Args:
            input_vocab_size: number of rows in the embedding table.
            embedd_size: embedding dimension.
            lstm_size: number of LSTM units (also the width of the returned states).
            inp_len: passed to ``Embedding`` as ``input_length``.
            perturbated_text_embed_matrix: initial weights for the embedding table;
                presumably shaped (input_vocab_size, embedd_size) — confirm with caller.
        """
        super().__init__()

        self.input_vocab_size = input_vocab_size
        self.embedd_size = embedd_size
        self.lstm_size = lstm_size
        self.inp_len = inp_len
        self.perturbated_text_embed_matrix = perturbated_text_embed_matrix

        # mask_zero=True: token id 0 is treated as padding and masked downstream.
        self.embedd = Embedding(input_dim=self.input_vocab_size, output_dim=self.embedd_size,
                                input_length=self.inp_len,
                                weights=[self.perturbated_text_embed_matrix], mask_zero=True)
        self.encoder_lstm = LSTM(units=self.lstm_size, return_sequences=True, return_state=True,
                                 name="Encoder", kernel_regularizer=regularizers.l2(1e-5))

    def call(self, inp_seq, training=True):
        """Encode ``inp_seq``; returns (output_sequence, final_hidden, final_cell)."""
        embedds = self.embedd(inp_seq)
        encoder_output, encoder_hidden_state, encod_cell_state = self.encoder_lstm(embedds)
        return encoder_output, encoder_hidden_state, encod_cell_state

    def initialize_states(self, batch_size):
        """Return zero (hidden, cell) states of shape (batch_size, lstm_size).

        The width comes from ``self.lstm_size`` (the attribute set in ``__init__``);
        float32 matches the LSTM's default dtype so the arrays can be used as
        ``initial_state`` directly.
        """
        h_state = np.zeros((batch_size, self.lstm_size), dtype=np.float32)
        c_state = np.zeros((batch_size, self.lstm_size), dtype=np.float32)
        return h_state, c_state
    
############################################################################################
#https://github.com/UdiBhaskar/TfKeras-Custom-Layers/blob/master/Seq2Seq/clayers.py
class MonotonicBahadanauAttention(tf.keras.layers.Layer):
    """Additive (Bahdanau) attention scored per encoder position, with monotonic
    alignment computed in the 'parallel' mode.

    Each encoder position is scored against the current decoder hidden state,
    the scores are squashed with a sigmoid into per-position selection
    probabilities, and these are combined with the previous step's alignment
    via the monotonic-attention 'parallel' recurrence.

    Args:
        units: width of the Wa / Ua projections.
        return_aweights: stored on the layer but not consulted; ``call`` always
            returns ``(context_vector, attention_weights)``.
        scaling_factor: if not None, scores are divided by ``sqrt(scaling_factor)``.
        noise_std: std-dev of Gaussian noise added to the scores when ``training``
            is truthy; 0 disables it.
        weights_initializer: initializer for the Dense kernels.
        bias_initializer: initializer for the Dense biases (only Ua has a bias).
        seed: optional op-level seed for the score noise.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, units,
                 return_aweights=False,
                 scaling_factor=None,
                 noise_std=0,
                 weights_initializer='he_normal',
                 bias_initializer='zeros',
                 seed=None,
                 **kwargs):

        # An empty name lets Keras auto-generate one; sublayer names are built
        # by prefixing self.name.
        if 'name' not in kwargs:
            kwargs['name'] = ""

        super(MonotonicBahadanauAttention, self).__init__(**kwargs)
        self.units = units
        self.return_aweights = return_aweights
        self.scaling_factor = scaling_factor
        self.noise_std = noise_std
        # Previously read in call() but never assigned, which made the noise
        # path raise AttributeError.
        self.seed = seed
        self.weights_initializer = initializers.get(weights_initializer)
        self.bias_initializer = initializers.get(bias_initializer)
        self.weights_regularizer = regularizers.l2(1e-2)

    def build(self, input_shape):
        """Create the three Dense projections Wa (decoder), Ua (encoder) and Va (score)."""
        self._wa = layers.Dense(self.units, use_bias=False,
                                kernel_initializer=self.weights_initializer,
                                bias_initializer=self.bias_initializer,
                                kernel_regularizer=self.weights_regularizer,
                                name=self.name + "Wa")

        self._ua = layers.Dense(self.units,
                                kernel_initializer=self.weights_initializer,
                                bias_initializer=self.bias_initializer,
                                kernel_regularizer=self.weights_regularizer,
                                name=self.name + "Ua")

        self._va = layers.Dense(1, use_bias=False,
                                kernel_initializer=self.weights_initializer,
                                kernel_regularizer=self.weights_regularizer,
                                bias_initializer=self.bias_initializer,
                                name=self.name + "Va")

    def call(self, decoder_hidden_state, encoder_outputs, prev_attention, training=True):
        """Compute one attention step.

        Args:
            decoder_hidden_state: rank-2 (batch, dec_units) state (expanded to
                rank 3 below).
            encoder_outputs: rank-3 (batch, T, enc_units) encoder sequence.
            prev_attention: (batch, T) alignment from the previous step.
            training: enables score noise when ``noise_std > 0``.

        Returns:
            (context_vector of shape (batch, enc_units), attention_weights of shape (batch, T)).
        """
        encoder_outputs = tf.cast(encoder_outputs, tf.float32)
        decoder_hidden_state = tf.cast(decoder_hidden_state, tf.float32)
        # Callers may pass a NumPy array; keep everything in float32.
        prev_attention = tf.cast(prev_attention, tf.float32)

        # (batch, dec_units) -> (batch, 1, dec_units) so it broadcasts over T.
        dec_hidden_with_time_axis = tf.expand_dims(decoder_hidden_state, 1)

        # Bahdanau score: Va(tanh(Wa(h_dec) + Ua(h_enc))) -> (batch, T, 1) -> (batch, T).
        # Wa is evaluated before Ua, as before, so sublayer build order is unchanged.
        score = self._va(tf.nn.tanh(self._wa(dec_hidden_with_time_axis) + self._ua(encoder_outputs)))
        score = tf.squeeze(score, [2])

        if self.scaling_factor is not None:
            # Cast first: tf.sqrt rejects integer inputs such as scaling_factor=64.
            score = score / tf.sqrt(tf.cast(self.scaling_factor, score.dtype))

        if training:
            if self.noise_std > 0:
                random_noise = tf.random.normal(shape=tf.shape(input=score), mean=0.0,
                                                stddev=self.noise_std, dtype=score.dtype,
                                                seed=self.seed)
                score = score + random_noise

        probabilities = tf.sigmoid(score)

        # Monotonic attention, 'parallel' mode.
        # Exclusive cumulative product of (1 - p), computed as exp(cumsum(log(.)))
        # with clipping so log() never sees 0.
        cumprod_1mp_probabilities = tf.exp(tf.cumsum(
            tf.math.log(tf.clip_by_value(1 - probabilities, 1e-10, 1)), axis=1, exclusive=True))
        attention_weights = probabilities * cumprod_1mp_probabilities * tf.cumsum(
            prev_attention / tf.clip_by_value(cumprod_1mp_probabilities, 1e-10, 1.), axis=1)
        attention_weights = tf.expand_dims(attention_weights, 1)  # (batch, 1, T)

        # Weighted sum of encoder outputs: (batch, 1, T) @ (batch, T, D) -> (batch, D).
        context_vector = tf.matmul(attention_weights, encoder_outputs)
        context_vector = tf.squeeze(context_vector, 1, name="context_vector")

        return context_vector, tf.squeeze(attention_weights, 1, name='attention_weights')
    
############################################################################################

class OneStepDecoder(tf.keras.Model):
    """A single decoder time step: embed, attend, run the LSTM, project to the vocabulary.

    The attention layer is queried with the *incoming* hidden state. The
    resulting context vector is prepended (along the feature axis) to the
    token embedding, and that concatenation is fed to the LSTM, which starts
    from the incoming (hidden, cell) state.
    """

    def __init__(self, target_vocab_size, embedd_dim, inp_len, decoder_units, attention_units, text_embed_matrix):
        super().__init__()

        # Hyper-parameters, kept on the instance for inspection.
        self.target_vocab_size = target_vocab_size
        self.embedd_dim = embedd_dim
        self.inp_len = inp_len
        self.decoder_units = decoder_units
        self.attention_units = attention_units
        self.text_embed_matrix = text_embed_matrix

        # Target-side embedding seeded from text_embed_matrix; id 0 is masked.
        # NOTE(review): input_length is inp_len, but the Decoder in this file
        # feeds one-token slices — confirm this never triggers a shape check.
        self.decoder_embedding_layer = Embedding(
            input_dim=target_vocab_size,
            output_dim=embedd_dim,
            input_length=inp_len,
            weights=[text_embed_matrix],
            name="onestepdecoder_embedding_layer",
            mask_zero=True,
        )

        # Single-step LSTM: returns (last_output, hidden, cell).
        self.decoder_LSTM = LSTM(
            units=decoder_units,
            return_state=True,
            kernel_regularizer=regularizers.l2(1e-5),
        )

        # Attribute name kept as-is so checkpoint variable paths do not change.
        self.MonotonicBahadanauAttention = MonotonicBahadanauAttention(
            attention_units,
            return_aweights=False,
            scaling_factor=None,
            noise_std=0,
            weights_initializer='he_normal',
            bias_initializer='zeros',
        )

        # Linear projection to vocabulary scores (no activation).
        self.dense = Dense(units=target_vocab_size)

    def call(self, inp_to_dec, encoder_output, state_hidden, state_cell, att_weights):
        """Advance the decoder by one token.

        Returns:
            (vocab_scores, new_hidden, new_cell, new_attention_weights, context_vector)
        """
        token_embedding = self.decoder_embedding_layer(inp_to_dec)

        context, new_alignment = self.MonotonicBahadanauAttention(
            state_hidden, encoder_output, att_weights)

        # Give the context a length-1 time axis, then stack it in front of the embedding features.
        lstm_input = tf.concat([tf.expand_dims(context, 1), token_embedding], axis=-1)

        lstm_out, next_hidden, next_cell = self.decoder_LSTM(
            lstm_input, initial_state=[state_hidden, state_cell])

        vocab_scores = self.dense(lstm_out)
        return vocab_scores, next_hidden, next_cell, new_alignment, context


############################################################################################

class Decoder(tf.keras.Model):
    """Unrolls ``OneStepDecoder`` over every position of ``inp_to_dec``.

    At step t the one-step decoder receives ``inp_to_dec[:, t:t+1]`` plus the
    state produced by step t-1. The per-step outputs are stacked into a
    (batch, T, vocab) tensor.
    """

    def __init__(self, output_vocab_size, embedd_dim, output_length, decoder_units, att_units,
                 text_embed_matrix=None):
        """
        Args:
            output_vocab_size: target vocabulary size.
            embedd_dim: target embedding dimension.
            output_length: forwarded to OneStepDecoder as its ``inp_len``.
            decoder_units: LSTM width of the one-step decoder.
            att_units: attention projection width.
            text_embed_matrix: initial weights for the target embedding. It used
                to be read from an implicit module/notebook-level name; when
                omitted, that name is still consulted for backward compatibility.

        Raises:
            ValueError: if no embedding matrix is available.
        """
        super().__init__()

        if text_embed_matrix is None:
            # Legacy behaviour: the previous version referenced a global
            # `text_embed_matrix` instead of taking it as an argument.
            text_embed_matrix = globals().get("text_embed_matrix")
        if text_embed_matrix is None:
            raise ValueError("Decoder requires `text_embed_matrix` (pass it as the sixth argument).")

        self.output_vocab_size = output_vocab_size
        self.embedd_dim = embedd_dim
        self.output_length = output_length
        self.decoder_units = decoder_units
        self.att_units = att_units
        self.text_embed_matrix = text_embed_matrix

        self.onestep_decoder = OneStepDecoder(self.output_vocab_size, self.embedd_dim, self.output_length,
                                              self.decoder_units, self.att_units, self.text_embed_matrix)

    def call(self, inp_to_dec, enc_out, decoder_h, decoder_c, att_weights):
        """Run the decoder over all target positions.

        Returns:
            Tensor of stacked one-step outputs, transposed to (batch, T, vocab).
        """
        num_steps = tf.shape(inp_to_dec)[1]
        total_out = tf.TensorArray(tf.float32, size=num_steps, name='out_arrays')

        # att_weights is carried through the loop; callers may pass a NumPy
        # array, so make it a float32 tensor up front. That way the loop state
        # has a consistent type in eager mode and under tf.function/AutoGraph.
        att_weights = tf.cast(att_weights, tf.float32)

        for t_step in tf.range(num_steps):
            step_output, decoder_h, decoder_c, att_weights, _ = self.onestep_decoder(
                inp_to_dec[:, t_step:t_step + 1], enc_out, decoder_h, decoder_c, att_weights)
            total_out = total_out.write(t_step, step_output)

        # Stacked as (T, batch, vocab); reorder to (batch, T, vocab).
        return tf.transpose(total_out.stack(), [1, 0, 2])

############################################################################################

class bahadanau_attention_model(tf.keras.Model):
    """Encoder-decoder model with monotonic Bahdanau attention.

    ``call`` takes ``data = (encoder_inp, decoder_inp)``, encodes the source,
    seeds the decoder with the encoder's final (hidden, cell) state, and
    returns the decoder's stacked per-step outputs.
    """

    def __init__(self, encoder_vocab_size, decoder_vocab_size, encoder_embedd_dim, decoder_embedd_dim, input_len, output_len,
                 encoder_units, decoder_units, attention_units, perturbated_text_embed_matrix, text_embed_matrix):
        super().__init__()

        self.encoder_vocab_size = encoder_vocab_size
        self.decoder_vocab_size = decoder_vocab_size
        self.encoder_embedd_dim = encoder_embedd_dim
        self.decoder_embedd_dim = decoder_embedd_dim
        self.input_len = input_len
        self.output_len = output_len
        self.encoder_units = encoder_units
        self.decoder_units = decoder_units
        self.attention_units = attention_units
        self.perturbated_text_embed_matrix = perturbated_text_embed_matrix
        self.text_embed_matrix = text_embed_matrix

        self.encoder = Encoder(self.encoder_vocab_size, self.encoder_embedd_dim, self.encoder_units, self.input_len,
                               self.perturbated_text_embed_matrix)
        self.decoder = Decoder(self.decoder_vocab_size, self.decoder_embedd_dim, self.output_len, self.decoder_units,
                               self.attention_units, self.text_embed_matrix)

    def call(self, data, training=True):
        """Run encoder then decoder on ``data = (encoder_inp, decoder_inp)``."""
        encoder_inp, decoder_inp = data[0], data[1]

        encoder_out, encoder_h, encoder_c = self.encoder(encoder_inp, training=training)

        decoder_h = encoder_h  # initial decoder state is the final encoder state
        decoder_c = encoder_c

        # Initial monotonic alignment: all probability mass on the first encoder
        # position. It is sized from the actual batch and encoder length
        # (previously hard-coded to (512, 20), which broke smaller final
        # batches and any other input length).
        enc_shape = tf.shape(encoder_out)
        att_weights = tf.one_hot(tf.zeros([enc_shape[0]], dtype=tf.int32),
                                 depth=enc_shape[1], dtype=tf.float32)

        final_output = self.decoder(decoder_inp, encoder_out, decoder_h, decoder_c, att_weights)
        return final_output
