In [3]:
from tensorflow.keras.layers import Input, LSTM, Embedding, Dense
from tensorflow.keras.models import Model
import tensorflow as tf

In [5]:
def encoder(vocab_size, units, embedding_size, regularizer=None, name="encoder"):
    """Build the encoder sub-model: token ids -> final LSTM states.

    Args:
        vocab_size: Number of rows in the embedding table.
        units: Width of the LSTM layer.
        embedding_size: Dimension of each embedding vector.
        regularizer: Optional regularizer applied to the embedding matrix and
            to the LSTM kernel and recurrent weights.
        name: Name of the returned model; also used as the prefix of the
            embedding and LSTM layer names.

    Returns:
        A ``tf.keras.Model`` taking one input of shape ``(batch, None)`` and
        returning ``[state_h, state_c]``, the final LSTM states.
    """
    enc_inputs = tf.keras.layers.Input(shape=(None,), name="encoder inputs")

    # mask_zero=True: index 0 is treated as padding by the embedding layer.
    embedding_layer = tf.keras.layers.Embedding(
        vocab_size,
        embedding_size,
        embeddings_regularizer=regularizer,
        mask_zero=True,
        name=f"{name}-embedding",
    )
    embedded = embedding_layer(enc_inputs)

    # Scale the embeddings by sqrt(embedding_size).
    scale = tf.math.sqrt(tf.cast(embedding_size, tf.float32))
    scaled = embedded * scale

    lstm_layer = tf.keras.layers.LSTM(
        units=units,
        return_state=True,
        kernel_regularizer=regularizer,
        recurrent_regularizer=regularizer,
        input_shape=(None, embedding_size),
        name=f"{name}-lstm",
    )
    # Only the final states leave the encoder; the LSTM output is discarded.
    _, state_h, state_c = lstm_layer(scaled)

    return tf.keras.Model(
        inputs=[enc_inputs],
        outputs=[state_h, state_c],
        name=name,
    )

In [6]:
def decoder(vocab_size, units, embedding_size, regularizer=None, name="decoder"):
    """Build the decoder sub-model: target token ids + encoder states -> LSTM outputs.

    Args:
        vocab_size: Number of rows in the embedding table.
        units: Width of the LSTM layer; must equal the size of the state
            tensors fed in as the initial state.
        embedding_size: Dimension of each embedding vector.
        regularizer: Optional regularizer applied to the embedding matrix and
            to the LSTM kernel and recurrent weights (same treatment as in
            ``encoder``).
        name: Name of the returned model; also used as the prefix of the
            embedding layer name.

    Returns:
        A ``tf.keras.Model`` with inputs
        ``[dec_inputs, enc_state_h, enc_state_c]`` whose output is the LSTM
        output for every time step (``return_sequences=True``).
    """
    dec_inputs = tf.keras.layers.Input(shape=(None,), name="decoder inputs")
    # `shape` must be a real 1-tuple: `(units)` is just the int `units`, which
    # only works on Keras versions that coerce an int into a shape.
    enc_state_h = tf.keras.layers.Input(shape=(units,), name="encoder hidden state")
    enc_state_c = tf.keras.layers.Input(shape=(units,), name="encoder context state")

    # mask_zero=True: index 0 is treated as padding by the embedding layer.
    embeddings = tf.keras.layers.Embedding(vocab_size, embedding_size,
                    embeddings_regularizer=regularizer, mask_zero=True,
                    name='{}-embedding'.format(name))(dec_inputs)
    # Scale the embeddings by sqrt(embedding_size).
    embeddings *= tf.math.sqrt(tf.cast(embedding_size, tf.float32))

    # The LSTM starts from the supplied states instead of zeros. Its own final
    # states are requested (return_state=True) but not exposed by this model.
    decoder_outputs, _, _ = tf.keras.layers.LSTM(
        units=units,
        return_sequences=True,
        return_state=True,
        kernel_regularizer=regularizer,
        recurrent_regularizer=regularizer,
    )(embeddings, initial_state=[enc_state_h, enc_state_c])

    return tf.keras.Model(
      inputs=[dec_inputs, enc_state_h, enc_state_c], outputs=decoder_outputs, name=name)

In [7]:
def seq2seq(src_vocab_size, tar_vocab_size, embedding_size, units=256, regularizer=None):
    """Assemble the full encoder-decoder model.

    The encoder turns the source sequence into its final LSTM states, the
    decoder consumes the target-side sequence starting from those states, and
    a softmax ``Dense`` layer projects the decoder output onto the target
    vocabulary.

    Args:
        src_vocab_size: Vocabulary size passed to the encoder.
        tar_vocab_size: Vocabulary size passed to the decoder; also the width
            of the final softmax layer.
        embedding_size: Embedding dimension shared by encoder and decoder.
        units: LSTM width shared by encoder and decoder.
        regularizer: Optional regularizer forwarded to both sub-models.

    Returns:
        A ``tf.keras.Model`` named ``"sequence-to-sequence"`` with inputs
        ``[enc_inputs, dec_inputs]`` and the softmax output of the ``Dense``
        layer.
    """
    enc_inputs = tf.keras.layers.Input(shape=(None, ), name="encoder inputs")
    dec_inputs = tf.keras.layers.Input(shape=(None, ), name="decoder inputs")

    # Encoder: source tokens -> (hidden state, cell state).
    encoder_model = encoder(
        src_vocab_size,
        units,
        embedding_size=embedding_size,
        regularizer=regularizer,
        name="encoder",
    )
    enc_state_h, enc_state_c = encoder_model(inputs=[enc_inputs])

    # Decoder: target tokens, seeded with the encoder states.
    decoder_model = decoder(
        tar_vocab_size,
        units,
        embedding_size=embedding_size,
        regularizer=regularizer,
        name="decoder",
    )
    dec_outputs = decoder_model(inputs=[dec_inputs, enc_state_h, enc_state_c])

    # Project the decoder output onto the target vocabulary.
    projection = Dense(tar_vocab_size, activation='softmax', name="outputs")
    s2s_outputs = projection(dec_outputs)

    return tf.keras.Model(
        inputs=[enc_inputs, dec_inputs],
        outputs=s2s_outputs,
        name="sequence-to-sequence",
    )
