In [4]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import (
    Input, Embedding, Dense, Dropout,
    LayerNormalization, Add, MultiHeadAttention, Lambda , Layer
)
from tensorflow.keras import Model # Moved Model import here


def positional_encoding(seq_len, model_size):
    """Return the sinusoidal positional-encoding table.

    Feature index ``2k`` holds ``sin(pos / 10000 ** (2k / model_size))`` and
    index ``2k + 1`` holds ``cos`` of the same angle, so each sin/cos pair
    shares a single frequency.

    Args:
        seq_len: Number of positions (rows) to encode.
        model_size: Width of each encoding vector.

    Returns:
        A float64 NumPy array of shape ``(1, seq_len, model_size)``. The leading
        axis of size 1 lets the table broadcast over a batch. The rank is 3 even
        when ``seq_len`` is 0.
    """
    positions = np.arange(seq_len, dtype=np.float64)[:, np.newaxis]  # (seq_len, 1)
    dims = np.arange(model_size)
    # Even index i uses exponent i / model_size; odd index i reuses (i - 1) / model_size.
    exponents = (dims - dims % 2) / model_size
    angles = positions / np.power(10000.0, exponents)  # (seq_len, model_size)

    table = np.empty_like(angles)
    table[:, 0::2] = np.sin(angles[:, 0::2])
    table[:, 1::2] = np.cos(angles[:, 1::2])
    return table[np.newaxis, :, :]

# Masks
class PaddingMaskLayer(Layer):
    """Turn token ids into a boolean attention mask that hides padding.

    Every position whose id equals 0 becomes False; all other positions become
    True. A length-1 axis is inserted at position 1. For the (batch, seq) id
    tensors this file builds with ``Input(shape=(None,))``, the result is a
    (batch, 1, seq) mask that broadcasts across query positions.
    """

    def call(self, input):
        # `input` is kept as the parameter name for call-signature compatibility.
        is_real_token = tf.not_equal(input, 0)
        return tf.expand_dims(is_real_token, 1)


class CausalMaskLayer(Layer):
    """Build a decoder self-attention mask: look-ahead AND key padding.

    Entry ``[b, i, j]`` of the result is True only when ``j <= i``, so a query
    never sees later positions, and token ``j`` of row ``b`` is not the padding
    id 0. The look-ahead part gets a leading batch axis of size 1, and the
    padding part gets a query axis of size 1 at position 1, so the two combine
    by broadcasting.
    """

    def call(self, inputs):
        length = tf.shape(inputs)[1]
        all_true = tf.ones((length, length), dtype=tf.bool)
        # band_part(..., -1, 0) keeps the lower triangle, diagonal included.
        lower_triangle = tf.linalg.band_part(all_true, -1, 0)
        key_is_token = tf.math.not_equal(inputs, 0)
        return tf.logical_and(
            lower_triangle[tf.newaxis, :, :],
            key_is_token[:, tf.newaxis, :],
        )

# --- Encoder Block (Modified to accept encoder_padding_mask) ---
def encoder_block(x, num_heads, d_model, d_ff, encoder_padding_mask, dropout=0.1):
    """Apply one post-norm Transformer encoder layer to ``x``.

    The layer has two sub-layers. Each is followed by dropout, a residual add
    and LayerNormalization(epsilon=1e-6):
      1. multi-head self-attention restricted by ``encoder_padding_mask``;
      2. a position-wise feed-forward network: Dense(d_ff, relu) then Dense(d_model).

    Args:
        x: Symbolic input sequence. The feed-forward residual requires its last
            dimension to equal ``d_model``.
        num_heads: Number of attention heads.
        d_model: Used as the attention ``key_dim`` and as the feed-forward
            output width.
        d_ff: Hidden width of the feed-forward network.
        encoder_padding_mask: Boolean ``attention_mask`` for the self-attention.
        dropout: Dropout rate applied to each sub-layer's output.

    Returns:
        The output of the final LayerNormalization.
    """
    def add_and_norm(residual, sublayer_out):
        # Dropout is built before LayerNormalization, matching the original order.
        sublayer_out = Dropout(dropout)(sublayer_out)
        return LayerNormalization(epsilon=1e-6)(residual + sublayer_out)

    self_attention = MultiHeadAttention(num_heads=num_heads, key_dim=d_model)
    attended = self_attention(x, x, attention_mask=encoder_padding_mask)
    x = add_and_norm(x, attended)

    hidden = Dense(d_ff, activation="relu")(x)
    projected = Dense(d_model)(hidden)
    return add_and_norm(x, projected)

# --- Decoder Block (Modified to accept decoder_self_attn_mask and encoder_padding_mask) ---
def decoder_block(x, enc_out, num_heads, d_model, d_ff,
                  decoder_self_attn_mask, dropout=0.1, encoder_padding_mask=None):
    """Apply one post-norm Transformer decoder layer to ``x``.

    The layer has three sub-layers. Each is followed by dropout, a residual add
    and LayerNormalization(epsilon=1e-6):
      1. masked self-attention over ``x`` using ``decoder_self_attn_mask``;
      2. cross-attention from ``x`` (queries) to ``enc_out`` (keys/values),
         restricted by ``encoder_padding_mask`` when one is given;
      3. a position-wise feed-forward network: Dense(d_ff, relu) then Dense(d_model).

    Args:
        x: Symbolic decoder sequence. The feed-forward residual requires its
            last dimension to equal ``d_model``.
        enc_out: Encoder output attended to by the cross-attention.
        num_heads: Number of attention heads.
        d_model: Used as each attention layer's ``key_dim`` and as the
            feed-forward output width.
        d_ff: Hidden width of the feed-forward network.
        decoder_self_attn_mask: Boolean mask for the self-attention (e.g. the
            output of ``CausalMaskLayer``).
        dropout: Dropout rate applied to each sub-layer's output.
        encoder_padding_mask: Optional boolean mask over encoder positions
            (e.g. ``PaddingMaskLayer`` applied to the encoder ids). ``None``
            keeps the old behaviour of attending to every encoder position,
            including padding.

    Returns:
        The output of the final LayerNormalization.
    """
    # Masked self-attention (combined look-ahead + padding mask).
    self_attn = MultiHeadAttention(
        num_heads=num_heads,
        key_dim=d_model
    )(x, x, attention_mask=decoder_self_attn_mask)

    self_attn = Dropout(dropout)(self_attn)
    x = LayerNormalization(epsilon=1e-6)(x + self_attn)

    # Cross-attention. The mask stops queries from attending to encoder padding.
    cross_attn = MultiHeadAttention(
        num_heads=num_heads,
        key_dim=d_model
    )(x, enc_out, attention_mask=encoder_padding_mask)

    cross_attn = Dropout(dropout)(cross_attn)
    x = LayerNormalization(epsilon=1e-6)(x + cross_attn)

    # Position-wise feed-forward network.
    ff = Dense(d_ff, activation="relu")(x)
    ff = Dense(d_model)(ff)
    ff = Dropout(dropout)(ff)
    x = LayerNormalization(epsilon=1e-6)(x + ff)

    return x


class PositionalEncodingLayer(Layer):
    """Add sinusoidal positional encodings trimmed to the runtime sequence length.

    The table from ``positional_encoding(max_len, d_model)`` is sliced along
    axis 1 to the input's current length before being added. Any length up to
    ``max_len`` therefore works. Longer inputs fail with a shape mismatch when
    the layer runs.
    """

    def __init__(self, max_len, d_model, **kwargs):
        super().__init__(**kwargs)
        self.max_len = max_len
        self.d_model = d_model
        # Kept as float32 NumPy; cast to the input dtype when the layer runs.
        self.encoding = positional_encoding(max_len, d_model).astype("float32")

    def call(self, inputs):
        seq_len = tf.shape(inputs)[1]
        encoding = tf.constant(self.encoding)[:, :seq_len, :]
        return inputs + tf.cast(encoding, inputs.dtype)

    def get_config(self):
        config = super().get_config()
        config.update({"max_len": self.max_len, "d_model": self.d_model})
        return config


# Hyperparameters
vocab_size = 10000
max_len = 100  # longest sequence the positional-encoding table covers
d_model = 256
num_heads = 8
d_ff = 512
num_layers = 4

# --- Model Construction ---

# Encoder inputs: token ids, where id 0 is treated as padding by the mask layers.
encoder_inputs = Input(shape=(None,), name="encoder_input", dtype=tf.int32)
# Boolean padding mask over encoder positions. It is shared by the encoder
# self-attention and every decoder cross-attention.
enc_padding_mask = PaddingMaskLayer()(encoder_inputs)

enc_embed = Embedding(
    vocab_size, d_model
)(encoder_inputs)
# Adding the fixed (1, max_len, d_model) table directly only lines up when the
# input length is exactly max_len. The layer slices it to the actual length.
enc_embed = PositionalEncodingLayer(max_len, d_model)(enc_embed)

x = enc_embed
for _ in range(num_layers):
    x = encoder_block(x, num_heads, d_model, d_ff, enc_padding_mask)

encoder_output = x

# Decoder inputs
decoder_inputs = Input(shape=(None,), name="decoder_input", dtype=tf.int32)
# Look-ahead mask combined with the decoder's own padding mask.
dec_self_attn_mask = CausalMaskLayer()(decoder_inputs)

dec_embed = Embedding(
    vocab_size, d_model
)(decoder_inputs)
dec_embed = PositionalEncodingLayer(max_len, d_model)(dec_embed)

y = dec_embed
for _ in range(num_layers):
    y = decoder_block(
        y, encoder_output,
        num_heads, d_model, d_ff,
        dec_self_attn_mask,
        encoder_padding_mask=enc_padding_mask,
    )

outputs = Dense(vocab_size, activation="softmax")(y)

model = Model(
    inputs=[encoder_inputs, decoder_inputs],
    outputs=outputs
)

model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"]
)

model.summary()