In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model

# Encoder LSTM
class Encoder(tf.keras.Model):
    """Embed a source token sequence and summarise it with an LSTM.

    Only the LSTM's final hidden and cell states are exposed; the LSTM's
    own output tensor is discarded.
    """

    def __init__(self, input_dim, emb_dim, hidden_dim):
        """Build the embedding table and the recurrent layer.

        Args:
            input_dim: Number of rows in the embedding table (source
                vocabulary size).
            emb_dim: Width of each embedding vector.
            hidden_dim: Number of units in the LSTM.
        """
        super().__init__()
        self.embedding = layers.Embedding(input_dim, emb_dim)
        # return_state=True makes the layer return (output, h, c).
        self.lstm = layers.LSTM(hidden_dim, return_state=True)

    def call(self, src):
        """Encode `src` and return the LSTM's final `(hidden, cell)` states.

        Args:
            src: Integer token ids, [batch, seq_len].

        Returns:
            A `(hidden, cell)` pair taken from the LSTM's returned states.
        """
        # [batch, seq_len] -> [batch, seq_len, emb_dim] -> LSTM
        _, final_hidden, final_cell = self.lstm(self.embedding(src))
        return final_hidden, final_cell


# Decoder LSTM
class Decoder(tf.keras.Model):
    """Single-step LSTM decoder: one token in, one row of vocabulary scores out."""

    def __init__(self, output_dim, emb_dim, hidden_dim):
        """Build the embedding table, the recurrent layer and the projection.

        Args:
            output_dim: Target vocabulary size; used both for the embedding
                table and for the width of the final Dense layer.
            emb_dim: Width of each embedding vector.
            hidden_dim: Number of units in the LSTM.
        """
        super().__init__()
        self.embedding = layers.Embedding(output_dim, emb_dim)
        # return_state=True makes the layer return (output, h, c).
        self.lstm = layers.LSTM(hidden_dim, return_state=True)
        # No activation is configured, so this layer emits unnormalised scores.
        self.fc_out = layers.Dense(output_dim)

    def call(self, input, hidden, cell):
        """Advance the decoder by one time step.

        Args:
            input: Token ids for the current step, [batch].
            hidden: LSTM hidden state to start this step from.
            cell: LSTM cell state to start this step from.

        Returns:
            `(prediction, hidden, cell)` where `prediction` is
            [batch, output_dim] and the two states are the LSTM states
            after this step.
        """
        # [batch] -> [batch, emb_dim] -> [batch, 1, emb_dim]; the LSTM expects
        # a time axis, so a length-1 one is inserted.
        step_sequence = tf.expand_dims(self.embedding(input), 1)
        lstm_out, next_hidden, next_cell = self.lstm(
            step_sequence, initial_state=[hidden, cell]
        )
        scores = self.fc_out(lstm_out)  # [batch, output_dim]
        return scores, next_hidden, next_cell


# Sequence-to-Sequence wrapper
class Seq2Seq(tf.keras.Model):
    """Encoder/decoder wrapper that decodes with full teacher forcing."""

    def __init__(self, encoder, decoder):
        """Store the two sub-models.

        Args:
            encoder: Callable as `encoder(src)` returning `(hidden, cell)`.
            decoder: Callable as `decoder(token, hidden, cell)` returning
                `(prediction, hidden, cell)`.
        """
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def call(self, src, trg):
        """Run the encoder once, then step the decoder over `trg`.

        The decoder is fed `trg[:, t - 1]` at step `t` (teacher forcing), so
        the first target column is only ever used as an input and no
        prediction is produced for it.

        Args:
            src: Source token ids, [batch, seq_len_src].
            trg: Target token ids, [batch, seq_len_trg]. `trg[:, 0]` is
                treated as the <SOS> token. The target length is read from
                the static shape and drives a Python-level loop, so it has
                to be known when this method runs.

        Returns:
            Decoder predictions stacked along axis 1, with shape
            [batch, seq_len_trg - 1, vocab_size].
        """
        trg_len = trg.shape[1]

        # Get encoder hidden state and cell state
        hidden, cell = self.encoder(src)

        # Start decoding with <SOS> token (trg[:, 0] is <SOS>)
        step_input = trg[:, 0]

        # One entry per decoding step, each [batch, vocab_size]
        outputs = []
        for t in range(1, trg_len):
            # Debug trace of the per-step shapes (kept from the original).
            print(f"Time Step {t}: input shape before decoding: {step_input.shape}")
            output, hidden, cell = self.decoder(step_input, hidden, cell)
            print(f"Time Step {t}: output shape from decoder: {output.shape}")

            outputs.append(output)

            # Teacher forcing: the ground-truth token at step t becomes the
            # next input, regardless of what the decoder just predicted.
            step_input = trg[:, t]

        # The loop starts at t=1, so there are trg_len - 1 steps to stack:
        # result is [batch_size, trg_len - 1, vocab_size].
        return tf.stack(outputs, axis=1)


# Example usage: push one random batch through an untrained model.

# Model hyperparameters
input_dim = 10000  # vocabulary size for source language
output_dim = 10000  # vocabulary size for target language
emb_dim = 256  # embedding dimension
hidden_dim = 512  # hidden dimension for LSTM

# Random token ids standing in for real data; ids are drawn below the
# matching vocabulary size so every id is a valid embedding index.
src = tf.random.uniform(shape=[32, 20], maxval=input_dim, dtype=tf.int32)  # 32 sequences of length 20
trg = tf.random.uniform(shape=[32, 15], maxval=output_dim, dtype=tf.int32)  # 32 sequences of length 15

# Build the encoder, the decoder and the wrapper around them
encoder = Encoder(input_dim=input_dim, emb_dim=emb_dim, hidden_dim=hidden_dim)
decoder = Decoder(output_dim=output_dim, emb_dim=emb_dim, hidden_dim=hidden_dim)
seq2seq_model = Seq2Seq(encoder=encoder, decoder=decoder)

# Forward pass
output = seq2seq_model(src, trg)
print(f"Output shape: {output.shape}")

Time Step 1: input shape before decoding: (32,)
Time Step 1: output shape from decoder: (32, 10000)
Time Step 2: input shape before decoding: (32,)
Time Step 2: output shape from decoder: (32, 10000)
Time Step 3: input shape before decoding: (32,)
Time Step 3: output shape from decoder: (32, 10000)
Time Step 4: input shape before decoding: (32,)
Time Step 4: output shape from decoder: (32, 10000)
Time Step 5: input shape before decoding: (32,)
Time Step 5: output shape from decoder: (32, 10000)
Time Step 6: input shape before decoding: (32,)
Time Step 6: output shape from decoder: (32, 10000)
Time Step 7: input shape before decoding: (32,)
Time Step 7: output shape from decoder: (32, 10000)
Time Step 8: input shape before decoding: (32,)
Time Step 8: output shape from decoder: (32, 10000)
Time Step 9: input shape before decoding: (32,)
Time Step 9: output shape from decoder: (32, 10000)
Time Step 10: input shape before decoding: (32,)
Time Step 10: output shape from decoder: (32, 10000