In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, LSTM, Dense, Embedding, Attention, Concatenate
from tensorflow.keras.models import Model

class Seq2SeqAttention(Model):
    """LSTM encoder-decoder with an attention step over the encoder outputs.

    The encoder embeds the source sequence and runs it through an LSTM that
    returns both the per-timestep outputs and its final states. The decoder
    embeds the target-side input, runs an LSTM initialised with the encoder's
    final states, attends over the encoder outputs (decoder outputs are the
    query), concatenates the attention result with the decoder outputs, and
    projects to a softmax over the target vocabulary.

    Args:
        vocab_size_source: Input dimension of the encoder embedding.
        vocab_size_target: Input dimension of the decoder embedding and the
            number of units of the final softmax layer.
        embedding_dim: Output dimension of both embedding layers.
        lstm_units: Number of units in the encoder and decoder LSTMs.
        **kwargs: Forwarded to ``tf.keras.Model`` (for example ``name``).
    """

    def __init__(self, vocab_size_source, vocab_size_target, embedding_dim,
                 lstm_units, **kwargs):
        super().__init__(**kwargs)

        # Keep the constructor arguments so get_config() can report them.
        self.vocab_size_source = vocab_size_source
        self.vocab_size_target = vocab_size_target
        self.embedding_dim = embedding_dim
        self.lstm_units = lstm_units

        # Encoder
        # NOTE: the symbolic Input below is never read by call(); it is kept
        # only so that existing code referencing the attribute keeps working.
        self.encoder_inputs = Input(shape=(None,))
        self.encoder_embeddings = Embedding(vocab_size_source, embedding_dim)
        self.encoder_lstm = LSTM(lstm_units, return_sequences=True, return_state=True)

        # Decoder
        # NOTE: unused by call(), kept for the same reason as encoder_inputs.
        self.decoder_inputs = Input(shape=(None,))
        self.decoder_embeddings = Embedding(vocab_size_target, embedding_dim)
        self.decoder_lstm = LSTM(lstm_units, return_sequences=True, return_state=True)

        # Attention mechanism
        self.attention_layer = Attention()

        # Concatenate layer (joins decoder outputs and attention output on the last axis)
        self.decoder_concat = Concatenate(axis=-1)

        # Dense layer for word prediction
        self.word_prediction_layer = Dense(vocab_size_target, activation='softmax')

    def call(self, inputs):
        """Run the encoder, decoder and attention, and return word probabilities.

        Args:
            inputs: A pair ``(encoder_input, decoder_input)``. Both are fed
                straight into embedding layers, so they are presumably
                integer token-id tensors of shape (batch, time) -- confirm
                against the data pipeline.

        Returns:
            The output of the softmax layer applied at every decoder
            timestep; its last axis has ``vocab_size_target`` entries.
        """
        encoder_input, decoder_input = inputs

        # Encoder: keep the full output sequence for attention and the final
        # hidden/cell states to initialise the decoder.
        encoder_embeddings = self.encoder_embeddings(encoder_input)
        encoder_outputs, state_h, state_c = self.encoder_lstm(encoder_embeddings)
        encoder_states = [state_h, state_c]

        # Decoder: the decoder's own final states are not needed here.
        decoder_embeddings = self.decoder_embeddings(decoder_input)
        decoder_outputs, _, _ = self.decoder_lstm(decoder_embeddings, initial_state=encoder_states)

        # Attention mechanism: [query, value] = [decoder outputs, encoder outputs].
        attention_output = self.attention_layer([decoder_outputs, encoder_outputs])
        decoder_concat_input = self.decoder_concat([decoder_outputs, attention_output])

        # Dense layer for word prediction
        decoder_output = self.word_prediction_layer(decoder_concat_input)

        return decoder_output

    def get_config(self):
        """Return the constructor arguments needed to rebuild this model.

        The dict is built directly (rather than extending the base class
        config) so that ``Seq2SeqAttention(**config)`` always works.
        """
        return {
            "name": self.name,
            "vocab_size_source": self.vocab_size_source,
            "vocab_size_target": self.vocab_size_target,
            "embedding_dim": self.embedding_dim,
            "lstm_units": self.lstm_units,
        }
