In [14]:
import tensorflow as tf
from tensorflow.keras.layers import Embedding, GRU, Dense, Concatenate

# Hyperparameters
embedding_dim = 256
units = 1024
vocab_size = 10000
batch_size = 64
max_seq_len = 16

# Initialize hidden state
def initialize_hidden_state(batch_sz, units):
    return tf.zeros((batch_sz, units))

# Encoder
def encoder(input_sequence, vocab_size, embedding_dim, units):
    embedding = Embedding(vocab_size, embedding_dim)(input_sequence)
    gru_output, gru_hidden = GRU(units, return_sequences=True, return_state=True)(embedding)
    return gru_output, gru_hidden

# Simplified Bahdanau Attention
def bahdanau_attention(query, values, units):
    W1 = Dense(units)
    W2 = Dense(units)
    V = Dense(1)

    # Apply attention layers
    query_with_time_axis = tf.expand_dims(query, 1)
    score = tf.nn.tanh(W1(values) + W2(query_with_time_axis))
    attention_weights = tf.nn.softmax(V(score), axis=1)

    # Calculate context vector
    context_vector = tf.reduce_sum(attention_weights * values, axis=1)
    return context_vector, attention_weights

# Decoder
def decoder(dec_input, dec_hidden, enc_output, vocab_size, embedding_dim, units):
    context_vector, attention_weights = bahdanau_attention(dec_hidden, enc_output, units)
    dec_embedding = Embedding(vocab_size, embedding_dim)(dec_input)
    dec_input_combined = tf.concat([tf.expand_dims(context_vector, 1), dec_embedding], axis=-1)
    gru_output, gru_state = GRU(units, return_sequences=True, return_state=True)(dec_input_combined)
    output = Dense(vocab_size)(tf.reshape(gru_output, (-1, gru_output.shape[2])))
    return output, gru_state, attention_weights

# Example usage
sample_input = tf.random.uniform((batch_size, max_seq_len), dtype=tf.int32, minval=0, maxval=vocab_size)
sample_target = tf.random.uniform((batch_size, 1), dtype=tf.int32, minval=0, maxval=vocab_size)
enc_hidden = initialize_hidden_state(batch_size, units)

# Run encoder
enc_output, enc_hidden = encoder(sample_input, vocab_size, embedding_dim, units)

# Decoder's initial input (usually <start>)
dec_input = tf.expand_dims([0] * batch_size, 1)

# Run decoder
dec_output, dec_hidden, attention_weights = decoder(dec_input, enc_hidden, enc_output, vocab_size, embedding_dim, units)

# Output shapes
print("Encoder output shape:", enc_output.shape)
print("Decoder output shape:", dec_output.shape)
print("Attention weights shape:", attention_weights.shape)

ValueError: too many values to unpack (expected 2)