In [9]:
import tensorflow as tf
from tensorflow.keras.layers import Embedding, GRU, Dense, Concatenate

# Hyperparameters
embedding_dim = 256
units = 1024
vocab_size = 10000  # Example vocab size, adjust based on your data
batch_size = 64
max_seq_len = 16

# Initialize hidden state
def initialize_hidden_state(batch_sz, units):
    return tf.zeros((batch_sz, units))

# Encoder
def encoder(input_sequence, vocab_size, embedding_dim, units):
    embedding = Embedding(vocab_size, embedding_dim)(input_sequence)
    
    # The GRU layer returns multiple values when return_state=True
    gru_output = GRU(units, return_sequences=True, return_state=True,
                     recurrent_initializer='glorot_uniform')(embedding)

    # Extract only the first two values: sequence output and final hidden state
    gru_output, *gru_hidden = gru_output

    return gru_output, gru_hidden[0]  # Return the sequence and the first hidden state

# Fixed Bahdanau Attention Mechanism
def bahdanau_attention(query, values, units):
    # query: decoder hidden state, shape: (batch_size, hidden_size)
    # values: encoder output, shape: (batch_size, max_length, hidden_size)

    # Dense layers for attention
    W1 = Dense(units)
    W2 = Dense(units)
    V = Dense(1)

    # Add time axis to query for broadcasting
    query_with_time_axis = tf.expand_dims(query, 1)  # shape (batch_size, 1, hidden_size)

    # Compute W1_values (encoder output) and W2_query (decoder hidden state)
    W1_values = W1(values)  # shape (batch_size, max_length, units)
    W2_query = W2(query_with_time_axis)  # shape (batch_size, 1, units)

    # Ensure W2_query has shape (batch_size, max_length, units) for broadcasting
    W2_query = tf.broadcast_to(W2_query, W1_values.shape)  # shape (batch_size, max_length, units)

    # Compute the score by adding W1_values and W2_query, then applying tanh
    score = tf.nn.tanh(W1_values + W2_query)  # shape (batch_size, max_length, units)

    # Compute attention weights
    attention_weights = tf.nn.softmax(V(score), axis=1)  # shape (batch_size, max_length, 1)

    # Compute context vector as the weighted sum of the encoder's output
    context_vector = attention_weights * values  # shape (batch_size, max_length, hidden_size)
    context_vector = tf.reduce_sum(context_vector, axis=1)  # shape (batch_size, hidden_size)

    return context_vector, attention_weights

# Decoder
def decoder(dec_input, dec_hidden, enc_output, vocab_size, embedding_dim, units):
    # Perform attention mechanism
    context_vector, attention_weights = bahdanau_attention(dec_hidden, enc_output, units)

    # Embedding for decoder input
    dec_embedding = Embedding(vocab_size, embedding_dim)(dec_input)

    # Concatenate context vector with decoder input embedding
    dec_input_concat = Concatenate(axis=-1)([tf.expand_dims(context_vector, 1), dec_embedding])

    # Pass through GRU
    gru_output, gru_state = GRU(units, return_sequences=True, return_state=True,
                                recurrent_initializer='glorot_uniform')(dec_input_concat)

    # Dense layer to generate predictions
    output = Dense(vocab_size)(tf.reshape(gru_output, (-1, gru_output.shape[2])))

    return output, gru_state, attention_weights

# Example usage
# Input sample for encoder and decoder
sample_input = tf.random.uniform((batch_size, max_seq_len), dtype=tf.int32, minval=0, maxval=vocab_size)
sample_target = tf.random.uniform((batch_size, 1), dtype=tf.int32, minval=0, maxval=vocab_size)

# Initialize hidden state for encoder
enc_hidden = initialize_hidden_state(batch_size, units)

# Run encoder
enc_output, enc_hidden = encoder(sample_input, vocab_size, embedding_dim, units)

# Decoder's initial input (usually start token <start>)
dec_input = tf.expand_dims([0] * batch_size, 1)

# Run decoder
dec_output, dec_hidden, attention_weights = decoder(dec_input, enc_hidden, enc_output, vocab_size, embedding_dim, units)

# Output shapes
print("Encoder output shape:", enc_output.shape)  # (batch_size, seq_len, units)
print("Encoder hidden state shape:", enc_hidden.shape)  # (batch_size, units)
print("Decoder output shape:", dec_output.shape)  # (batch_size * seq_len, vocab_size)
print("Attention weights shape:", attention_weights.shape)  # (batch_size, seq_len, 1)

InvalidArgumentError: {{function_node __wrapped__BroadcastTo_device_/job:localhost/replica:0/task:0/device:GPU:0}} Incompatible shapes: [1024,1024] vs. [64,16,1024] [Op:BroadcastTo]