In [1]:
import tensorflow as tf
from tensorflow import keras
from keras import layers
from keras.losses import SparseCategoricalCrossentropy
from keras.optimizers import Adam

In [2]:
class Encoder(tf.keras.Model):
    """Embed a batch of source token ids and summarise it with an LSTM.

    Args:
        input_vocab_size: Size of the source vocabulary (``input_dim`` of the
            embedding layer).
        embedding_dim: Width of each embedding vector.
        hidden_dim: Number of LSTM units; also the width of the returned states.
    """

    def __init__(self, input_vocab_size, embedding_dim, hidden_dim):
        # The parent constructor must actually be *called* (note the
        # parentheses) before any layer is assigned as an attribute;
        # otherwise Keras cannot track the sub-layers of this model.
        super(Encoder, self).__init__()

        # mask_zero=True makes id 0 act as padding for downstream layers.
        self.embedding = layers.Embedding(
            input_dim=input_vocab_size, output_dim=embedding_dim, mask_zero=True
        )

        # return_state=True makes the layer return (output, state_h, state_c).
        self.lstm = layers.LSTM(units=hidden_dim, return_state=True)

    def call(self, inputs):
        """Return the final ``(state_h, state_c)`` of the LSTM for ``inputs``.

        Args:
            inputs: Batch of token ids fed to the embedding layer.

        Returns:
            Tuple ``(state_h, state_c)``; the per-sequence LSTM output is
            discarded because only the states are used by the decoder.
        """
        x = self.embedding(inputs)
        _, state_h, state_c = self.lstm(x)
        return (state_h, state_c)

In [3]:
# Toy hyper-parameters used throughout the notebook.
batch_size = 32
seq_len = 20
embedding_dim = 10
target_vocab_size = 50
input_vocab_size = 30
hidden_dim = 16

# Random token ids for smoke tests. dtype must be an integer type: without it
# tf.random.uniform produces float32 values, which are not valid vocabulary
# indices (the embedding layer would have to cast/truncate them).
X = tf.random.uniform(
    shape=(batch_size, seq_len), minval=0, maxval=input_vocab_size, dtype=tf.int32
)
y = tf.random.uniform(
    shape=(batch_size, seq_len), minval=0, maxval=target_vocab_size, dtype=tf.int32
)

encoder = Encoder(
    input_vocab_size=input_vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
)
# Shape of the final cell state: expected (batch_size, hidden_dim).
encoder(X)[1].shape

TensorShape([32, 16])

In [4]:
class Decoder(tf.keras.Model):
    """LSTM decoder that maps target token ids to per-step vocabulary probabilities.

    Args:
        target_vocab_size: Size of the target vocabulary; used both as the
            embedding ``input_dim`` and as the width of the softmax output.
        embedding_dim: Width of each embedding vector.
        hidden_dim: Number of LSTM units.
    """

    def __init__(self, target_vocab_size, embedding_dim, hidden_dim):
        super().__init__()

        # Id 0 is treated as padding (mask_zero=True).
        self.embedding = layers.Embedding(
            input_dim=target_vocab_size,
            output_dim=embedding_dim,
            mask_zero=True,
        )
        # Emit the full output sequence plus the final states.
        self.lstm = layers.LSTM(
            units=hidden_dim,
            return_sequences=True,
            return_state=True,
        )
        # Per-timestep distribution over the target vocabulary.
        self.dense = layers.Dense(units=target_vocab_size, activation="softmax")

    def call(self, inputs):
        """Decode ``inputs = (decoder_tokens, initial_state)``.

        Args:
            inputs: Pair of the decoder token ids and the state passed to the
                LSTM as ``initial_state``.

        Returns:
            The softmax layer applied to the LSTM output sequence.
        """
        tokens, initial_state = inputs
        embedded = self.embedding(tokens)
        lstm_results = self.lstm(embedded, initial_state=initial_state)
        # lstm_results is (sequence_output, state_h, state_c); only the
        # sequence output is projected to the vocabulary.
        return self.dense(lstm_results[0])

In [5]:
# Build the decoder with the notebook-level hyper-parameters.
decoder = Decoder(target_vocab_size, embedding_dim, hidden_dim)

# Smoke test: feed X as decoder tokens, seeded with the encoder states for X.
decoder((X, encoder(X)))

<tf.Tensor: shape=(32, 20, 50), dtype=float32, numpy=
array([[[0.02004724, 0.02001585, 0.01997572, ..., 0.02002917,
         0.02002766, 0.01985104],
        [0.01997081, 0.0200408 , 0.01991662, ..., 0.020055  ,
         0.02002386, 0.01993472],
        [0.0198773 , 0.02004897, 0.0198836 , ..., 0.02004115,
         0.02006204, 0.02008183],
        ...,
        [0.02005896, 0.02004887, 0.01997051, ..., 0.02008729,
         0.01996567, 0.01996697],
        [0.02009996, 0.02006258, 0.01994334, ..., 0.01999923,
         0.01998866, 0.0200236 ],
        [0.0200354 , 0.02011444, 0.01984122, ..., 0.01993513,
         0.02001637, 0.02015435]],

       [[0.01986672, 0.02007607, 0.01995894, ..., 0.01995959,
         0.01996633, 0.0200538 ],
        [0.01996388, 0.0200776 , 0.01994407, ..., 0.01988657,
         0.0199767 , 0.0200689 ],
        [0.01994152, 0.02012059, 0.01984973, ..., 0.0198383 ,
         0.01999843, 0.02017463],
        ...,
        [0.0199283 , 0.02008826, 0.01985196, ..., 0.01

In [6]:
class Seq2Seq(tf.keras.Model):
    """Chain an encoder and a decoder into a single trainable model.

    Args:
        encoder: Model called on the source tokens; its result is handed to
            the decoder as its second input.
        decoder: Model called on ``(decoder_tokens, encoder_result)``.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder

    def call(self, inputs):
        """Run the encoder, then the decoder.

        Args:
            inputs: Pair ``(encoder_input, decoder_input)``.

        Returns:
            Whatever the decoder returns for the given decoder input and the
            encoder's output.
        """
        source_tokens, target_tokens = inputs
        return self.decoder((target_tokens, self.encoder(source_tokens)))

In [7]:
# Wire the previously built encoder and decoder into one model.
model = Seq2Seq(encoder=encoder, decoder=decoder)

# Integer targets + softmax outputs -> sparse categorical cross-entropy.
model.compile(
    optimizer=Adam(),
    loss=SparseCategoricalCrossentropy(),
)

In [8]:
tf.argmax(model((X, X)), axis=1)

<tf.Tensor: shape=(32, 50), dtype=int64, numpy=
array([[18, 10, 16, ..., 15,  4, 19],
       [16,  2,  7, ..., 18,  8,  4],
       [ 3, 14,  9, ...,  8, 18, 14],
       ...,
       [16, 19,  0, ...,  7, 18, 19],
       [12,  5, 18, ...,  9,  8,  0],
       [ 3,  6, 11, ...,  8,  7,  9]], dtype=int64)>

In [10]:
# Synthetic integer token ids for an end-to-end training run.
# The second dimension is the sequence length, not the vocabulary size.
# The target gets one extra step so that both teacher-forcing slices below
# (decoder input and shifted labels) are exactly seq_len steps long.
input_data = tf.random.uniform(
    (batch_size, seq_len), minval=0, maxval=input_vocab_size, dtype=tf.int32
)
target_data = tf.random.uniform(
    (batch_size, seq_len + 1), minval=0, maxval=target_vocab_size, dtype=tf.int32
)

# Train the model with teacher forcing: the decoder sees tokens [0, T-1] and
# is trained to predict tokens [1, T]. The History object is kept in a
# variable so the cell does not end with a bare object repr.
history = model.fit(
    [input_data, target_data[:, :-1]],
    target_data[:, 1:],
    epochs=100,
    batch_size=batch_size,
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch 75/100
Epoch 76/100
Epoch 77/100
Epoch 78

<keras.callbacks.History at 0x1dc935cd410>