In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf

"""
Everything in TensorFlow is based on Tensor operations.
Tensors are (kind of) like np.arrays.
All tensors are immutable: you can never update the contents of a
tensor, only create a new one.

 - nd-arrays (1d, 2d, or even 3d and higher)
 - GPU support
 - Computational graph / Track gradients / Backpropagation
 - Immutable!
"""

### 1 create tensor

In [None]:
# 1. create tensors
# Scalar (rank-0) tensor built from the Python literal 4.
x = tf.constant(4)
print(x)

In [None]:
# Same value, but with an explicit (1, 1) shape and a float32 dtype.
x = tf.constant(4, shape=(1,1), dtype=tf.float32)
print(x)

In [None]:
# Vector (rank-1) tensor built from a flat Python list.
x = tf.constant([1,2,3])
print(x)

In [None]:
# Matrix (rank-2) tensor built from a nested list: 2 rows, 3 columns.
x = tf.constant([[1,2,3], [4,5,6]])
print(x)

### 2 zeros, ones

In [None]:
# 3x3 tensor filled with ones.
x = tf.ones((3,3))
print(x)

# 3x3 tensor filled with zeros (rebinding x replaces the tensor above).
x = tf.zeros((3,3))
print(x)

In [None]:
# 3x3 identity matrix.
x = tf.eye(3)
print(x)

### 3 random tensors (normal and uniform distributions), ranges and casting

In [None]:
# 3x3 random samples drawn from a normal distribution with mean 0 and
# standard deviation 1.
x = tf.random.normal((3,3), mean=0, stddev=1)
print(x)

In [None]:
# 3x3 random samples drawn from a uniform distribution bounded by
# minval=0 and maxval=1.
x = tf.random.uniform((3,3), minval=0, maxval=1)        ## values are between 0 and 1
print(x)

In [None]:
# Sequence 0, 1, ..., 9 as a rank-1 tensor.
x = tf.range(10)
print(x)

In [None]:
# Sequence starting at 1, stepping by 3 and stopping before 15, stored as float32.
x = tf.range(start=1, limit=15, delta=3, dtype=tf.float32)
print(x)

In [None]:
# Build the same float32 sequence, then convert it to float16 with tf.cast.
# tf.cast returns a new tensor; the name x is simply rebound to it.
x = tf.range(start=1, limit=15, delta=3, dtype=tf.float32)
x = tf.cast(x, dtype=tf.float16)
print(x)

### 4 mathematical operations

In [14]:
# Two rank-1 operands of equal length, reused by the arithmetic cells below.
x = tf.constant([1,2,3])
y = tf.constant([4,5,6])

In [None]:
# Elementwise addition; the operator form below is the equivalent shorthand.
z = tf.add(x,y)
# z = x + y
print(z)

In [None]:
# Elementwise subtraction; the function form is shown commented out.
# z = tf.subtract(x,y)
z = x - y
print(z)

In [None]:
# Elementwise division; the operator form below is the equivalent shorthand.
z = tf.divide(x,y)
# z = x / y
print(z)

In [None]:
# Elementwise multiplication; the operator form below is the equivalent shorthand.
z = tf.multiply(x,y)
# z = x * y
print(z)

### 5 tensordot

In [None]:
# Dot product of two vectors with tf.tensordot.
x = tf.constant([1,2,3])
y = tf.constant([4,5,6])

print(x)
print(y)
z = tf.tensordot(x,y, axes=1)       ## multiplies matching elements, then sums the products
print(z)                            ## (1*4) + (2*5) + (3*6)  => 32

In [None]:
# tensordot with axes=0: every element of x is multiplied with the whole of y.
# Both inputs are random (2, 3) matrices with values between 1 and 100.
x = tf.random.uniform(shape=(2,3), minval=1, maxval=100)
y = tf.random.uniform(shape=(2,3), minval=1, maxval=100)
print(x)
print()
print(y)
print()
z = tf.tensordot(x,y, axes=0)
print(z)
"""
x =>  [[ x11, x12, x13 ]                    y =>  [[ y11, y12, y13 ]
       [ x21, x22, x23 ]]                          [ y21, y22, y23 ]]

x and y both have shape (2, 3): 2 rows and 3 columns

with axes=0, tensordot computes the outer (tensor) product of the two matrices, so the result has shape (2, 3, 2, 3)

Example => 
a1 = [[ (x11 * y11), (x11 * y12), (x11 * y13) ]
      [ (x11 * y21), (x11 * y22), (x11 * y23) ]]

a2 = [[ (x12 * y11), (x12 * y12), (x12 * y13) ]
      [ (x12 * y21), (x12 * y22), (x12 * y23) ]]

a3 = [[ (x13 * y11), (x13 * y12), (x13 * y13) ]
      [ (x13 * y21), (x13 * y22), (x13 * y23) ]]

each of the 3 results above is a (2, 3) matrix ...
the first row of x has 3 values, so stacking a1, a2 and a3 gives a tensor of shape (3, 2, 3)
x has 2 rows, so the full result has shape (2, 3, 2, 3)
"""

### 6 reduce_sum , reduce_max , reduce_mean

In [None]:
# Reductions collapse an axis into a single value.
x = tf.constant([1,2,3])
y = tf.constant([4,5,6])

print(x)
print(y)
# Alternatives to try (uncomment one at a time):
# z = tf.reduce_sum(x*y, axis=0)       ## elementwise multiplication, then summation
# z = tf.reduce_max(x+y, axis=0)       ## largest value of the elementwise sum
z = tf.reduce_mean(x+y, axis=0)        # mean of the elementwise sum along axis 0
print(z)  

In [None]:
x = tf.constant([1,2,3])
# Raise every element to the third power.
z = x ** 3
print(z)

### 7 matrix multiplication
matrix multiplication (shapes must match: number of columns A = number of rows B)

In [None]:
# (2, 3) times (3, 4): the inner dimensions (3 and 3) match, so the product
# has shape (2, 4).
x = tf.random.normal((2,3))
y = tf.random.normal((3,4))

z = tf.matmul(x,y)
    ## or, equivalently, with the matrix-multiplication operator:
# z = x @ y

print(z)

### 8 indexing, slicing

In [None]:
# Indexing a (2, 4) matrix.
x = tf.constant([[1,2,3,4],[5,6,7,8]])
print(x[0])    # row 0
print(x[:, 0]) # all rows, column 0
print(x[1, :]) # row 1, all columns
print(x[1,1]) # element at row 1, column 1

In [None]:
x = tf.constant([1,2,3,4,5,6,7,8,9])

# print(x[::-1])      ## reversed order of the elements (x is 1-D)
# print(x[1:3])       ## elements at positions 1 and 2

## pick the values at specific indexes
list_of_indexes = tf.constant([0, 3, 5])
# Use the index tensor defined above instead of repeating the literal list.
print(tf.gather(x, indices=list_of_indexes))

In [None]:
# Random (4, 4) matrix to experiment with row slicing.
x = tf.random.uniform(shape=(4,4), minval=1, maxval=100)
print(x)
print()
# Alternatives to try (uncomment one at a time):
# print(x[::-1])      ## reverse order of the rows
# print(x[1:3, :])    ## rows 1 and 2, all columns

### 9 reshape

In [None]:
# Reshape a (2, 3) tensor into (3, 2); the 6 elements are kept, only the
# layout changes. The following reshape cells continue from this x.
x = tf.random.normal((2,3))
print(x.shape)
x = tf.reshape(x, (3,2))
print(x)

In [None]:
# -1 asks reshape to infer that dimension from the number of elements.
# Relies on x from the previous cell.
x = tf.reshape(x, (-1,2))
print(x)

In [None]:
# Flatten to a rank-1 tensor with 6 elements. The shape is written as the
# one-element tuple (6,) because (6) is just the integer 6.
# Relies on x from the previous cell.
x = tf.reshape(x, (6,))
print(x)

### 10 transpose

In [None]:
# Transpose a (4, 4) matrix: perm=[1,0] swaps axis 0 and axis 1.
x = tf.random.uniform(shape=(4,4), minval=1, maxval=100)
print(x)
print()
print(tf.transpose(x, perm=[1,0]))          ## rows become columns and columns become rows
"""
if the matrix shape is (3, 4) and you transpose it with perm=[1, 0],
the two axes are swapped, so the shape changes from (3, 4) to (4, 3)
"""

In [None]:
# Transpose a rank-3 tensor: perm lists, for each output axis, which input
# axis it comes from.
x = tf.random.uniform(shape=(2,5,4), minval=1, maxval=100)
print(x)
print()
print(tf.transpose(x, perm=[2,0,1]))          ## output axes = input axes (2, 0, 1): shape (2,5,4) -> (4,2,5)

"""
if the tensor shape is (2, 5, 4) and you transpose it with perm=[2, 0, 1],
the axes are reordered, so the shape changes from (2, 5, 4) to (4, 2, 5)

before the transpose it was a stack of two (5, 4) matrices ...
after the transpose it is a stack of four (2, 5) matrices
"""

### 11 numpy to tensor

In [None]:
# Convert the tensor left in x by an earlier cell via its .numpy() method.
x = x.numpy()
print(type(x))

In [None]:
# Convert the value in x (produced by the previous cell) back into a tensor.
x = tf.convert_to_tensor(x)
print(type(x))
# -> eager tensor = evaluates operations immediately
# without building graphs

### 12 string tensor

In [None]:
## string tensor
# A single string gives a scalar string tensor.
x = tf.constant("alpha")
print(x)

# A list of strings gives a rank-1 string tensor.
x = tf.constant(["alpha", "beta", "gamma"])
print(x)

### 13 variable

In [None]:
# Variable
# A tf.Variable represents a tensor whose value can be
# changed by running ops on it
# Used to represent shared, persistent state your program manipulates
# Higher level libraries like tf.keras use tf.Variable to store model parameters.
# Here: a (1, 3) variable initialised from a nested list of floats.
b = tf.Variable([[1.0, 2.0, 3.0]])
print(b)
print(type(b))

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np

# -------------------- Transformer Model Definitions -------------------- #

# Scaled Dot-Product Attention
class ScaledDotProductAttention(layers.Layer):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    When a mask is given it is multiplied by -1e9 and added to the scaled
    logits, so positions where the mask is 1 receive (almost) zero weight
    and positions where it is 0 are left untouched.
    """

    def __init__(self):
        super().__init__()

    def call(self, queries, keys, values, mask):
        """Return ``(output, attention_weights)``.

        queries, keys and values are used as
        (batch_size, num_heads, seq_len, depth) tensors by the
        MultiHeadAttention layer in this file; mask may be None.
        """
        # Similarity of every query with every key.
        scores = tf.matmul(queries, keys, transpose_b=True)

        # Scale by sqrt(d_k), d_k being the size of the keys' last axis.
        key_depth = tf.cast(tf.shape(keys)[-1], tf.float32)
        logits = scores / tf.math.sqrt(key_depth)

        # A large negative offset drives the softmax weight of masked keys to ~0.
        if mask is not None:
            logits = logits + (mask * -1e9)

        # Normalise over the key axis, then mix the values accordingly.
        weights = tf.nn.softmax(logits, axis=-1)
        return tf.matmul(weights, values), weights

# Multi-Head Attention
class MultiHeadAttention(layers.Layer):
    """Multi-head attention built on ScaledDotProductAttention.

    The inputs are projected with three Dense layers, split into
    ``num_heads`` heads of size ``embed_size // num_heads``, attended per
    head, merged again and passed through a final Dense layer.
    """

    def __init__(self, embed_size, num_heads):
        super().__init__()
        assert embed_size % num_heads == 0, "Embedding size must be divisible by num_heads"

        self.num_heads = num_heads
        self.embed_size = embed_size
        self.depth = embed_size // num_heads  # feature size of a single head

        # Input projections for queries, keys and values.
        self.wq = layers.Dense(embed_size)
        self.wk = layers.Dense(embed_size)
        self.wv = layers.Dense(embed_size)

        # Output projection applied after the heads are merged.
        self.dense = layers.Dense(embed_size)

        self.attention = ScaledDotProductAttention()

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq_len, embed_size) to (batch, num_heads, seq_len, depth)."""
        per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        """Attend from ``q`` over ``k``/``v``.

        Note the argument order: values first, then keys, then queries.
        Returns ``(output, attention_weights)`` where output has the
        embedding size on its last axis.
        """
        batch_size = tf.shape(q)[0]

        # Project, then split the embedding axis into heads.
        q_heads = self.split_heads(self.wq(q), batch_size)
        k_heads = self.split_heads(self.wk(k), batch_size)
        v_heads = self.split_heads(self.wv(v), batch_size)

        context, attention_weights = self.attention(q_heads, k_heads, v_heads, mask)

        # Move the head axis next to depth and merge both back into embed_size.
        context = tf.transpose(context, perm=[0, 2, 1, 3])
        merged = tf.reshape(context, (batch_size, -1, self.num_heads * self.depth))

        return self.dense(merged), attention_weights

# Position-wise Feed-Forward Network
class PositionWiseFeedForward(layers.Layer):
    """Two Dense layers: expand to forward_expansion * embed_size with ReLU, then project back."""

    def __init__(self, embed_size, forward_expansion):
        super().__init__()
        hidden_units = forward_expansion * embed_size
        self.fc1 = layers.Dense(hidden_units, activation='relu')
        self.fc2 = layers.Dense(embed_size)

    def call(self, x):
        """Apply the expanding layer followed by the projecting layer."""
        hidden = self.fc1(x)
        return self.fc2(hidden)

# Encoder Layer
class EncoderLayer(layers.Layer):
    """One encoder block: self-attention and a feed-forward network.

    Each sub-layer is followed by dropout, a residual connection and
    layer normalisation.
    """

    def __init__(self, embed_size, num_heads, forward_expansion, dropout):
        super().__init__()
        self.mha = MultiHeadAttention(embed_size, num_heads)
        self.ffn = PositionWiseFeedForward(embed_size, forward_expansion)

        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)

    def call(self, x, mask, training):
        """Return the block output; it is computed as residual sums with ``x``."""
        # Sub-layer 1: self-attention (x provides values, keys and queries).
        attended, _ = self.mha(x, x, x, mask)
        attended = self.dropout1(attended, training=training)
        after_attention = self.layernorm1(x + attended)

        # Sub-layer 2: position-wise feed-forward network.
        transformed = self.dropout2(self.ffn(after_attention), training=training)
        return self.layernorm2(after_attention + transformed)

# Decoder Layer
class DecoderLayer(layers.Layer):
    """One decoder block: self-attention, encoder-decoder attention, feed-forward.

    Each sub-layer is followed by dropout, a residual connection and
    layer normalisation.
    """

    def __init__(self, embed_size, num_heads, forward_expansion, dropout):
        super().__init__()
        self.mha1 = MultiHeadAttention(embed_size, num_heads)  # attention over the target itself
        self.mha2 = MultiHeadAttention(embed_size, num_heads)  # attention over the encoder output
        self.ffn = PositionWiseFeedForward(embed_size, forward_expansion)

        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = layers.Dropout(dropout)
        self.dropout2 = layers.Dropout(dropout)
        self.dropout3 = layers.Dropout(dropout)

    def call(self, x, enc_output, src_mask, trg_mask, training):
        """Return ``(output, self_attention_weights, cross_attention_weights)``."""
        # Sub-layer 1: self-attention restricted by the target mask.
        self_attended, self_weights = self.mha1(x, x, x, trg_mask)
        self_attended = self.dropout1(self_attended, training=training)
        after_self = self.layernorm1(x + self_attended)

        # Sub-layer 2: queries come from the decoder, keys/values from the
        # encoder (MultiHeadAttention takes its arguments as v, k, q).
        cross_attended, cross_weights = self.mha2(enc_output, enc_output, after_self, src_mask)
        cross_attended = self.dropout2(cross_attended, training=training)
        after_cross = self.layernorm2(after_self + cross_attended)

        # Sub-layer 3: position-wise feed-forward network.
        transformed = self.dropout3(self.ffn(after_cross), training=training)
        output = self.layernorm3(after_cross + transformed)

        return output, self_weights, cross_weights

# Encoder
class Encoder(layers.Layer):
    """Token embedding + sinusoidal positional encoding + a stack of EncoderLayer blocks."""

    def __init__(self, src_vocab_size, embed_size, num_layers, num_heads, forward_expansion, dropout, max_length):
        super().__init__()
        self.embed_size = embed_size
        self.embedding = layers.Embedding(src_vocab_size, embed_size)
        # Precomputed once; sliced to the actual sequence length in call().
        self.pos_encoding = self.positional_encoding(max_length, embed_size)
        self.layers = [EncoderLayer(embed_size, num_heads, forward_expansion, dropout) for _ in range(num_layers)]
        self.dropout = layers.Dropout(dropout)

    def get_angles(self, pos, i, embed_size):
        """Return pos / 10000^(2 * (i // 2) / embed_size) as a float32 tensor.

        ``pos`` and ``i`` normally come from ``tf.range`` and are integer
        tensors. TensorFlow does not mix integer and float operands
        implicitly, so both are cast to float32 before the division and
        the multiplication.
        """
        # Integer floor-division first (pairs dimensions 2k and 2k+1), then cast.
        pair_index = tf.cast(2 * (i // 2), tf.float32)
        exponent = pair_index / tf.cast(embed_size, tf.float32)
        angle_rates = 1.0 / tf.pow(10000.0, exponent)
        return tf.cast(pos, tf.float32) * angle_rates

    def positional_encoding(self, max_length, embed_size):
        """Build the (1, max_length, embed_size) sinusoidal position table."""
        angle_rads = self.get_angles(
            pos=tf.range(max_length)[:, tf.newaxis],
            i=tf.range(embed_size)[tf.newaxis, :],
            embed_size=embed_size
        )
        # Sine on even feature indices, cosine on odd ones.
        is_even = tf.math.mod(tf.range(embed_size), 2) == 0
        pos_encoding = tf.where(is_even, tf.sin(angle_rads), tf.cos(angle_rads))
        # Leading axis of size 1 so the table broadcasts over the batch.
        return tf.cast(pos_encoding[tf.newaxis, ...], dtype=tf.float32)

    def call(self, x, mask, training):
        """Encode token ids ``x``; returns (batch_size, seq_length, embed_size)."""
        seq_length = tf.shape(x)[1]
        x = self.embedding(x)  # (batch_size, seq_length, embed_size)
        # Scale the embeddings by sqrt(embed_size) before adding positions.
        x *= tf.math.sqrt(tf.cast(self.embed_size, tf.float32))
        x += self.pos_encoding[:, :seq_length, :]
        x = self.dropout(x, training=training)

        for layer in self.layers:
            x = layer(x, mask, training=training)

        return x

# Decoder
class Decoder(layers.Layer):
    """Token embedding + positional encoding + DecoderLayer stack + vocabulary projection."""

    def __init__(self, trg_vocab_size, embed_size, num_layers, num_heads, forward_expansion, dropout, max_length):
        super().__init__()
        self.embed_size = embed_size
        self.embedding = layers.Embedding(trg_vocab_size, embed_size)
        # Precomputed once; sliced to the actual sequence length in call().
        self.pos_encoding = self.positional_encoding(max_length, embed_size)
        self.layers = [DecoderLayer(embed_size, num_heads, forward_expansion, dropout) for _ in range(num_layers)]
        self.dropout = layers.Dropout(dropout)
        self.fc_out = layers.Dense(trg_vocab_size)

    def get_angles(self, pos, i, embed_size):
        """Return pos / 10000^(2 * (i // 2) / embed_size) as a float32 tensor.

        ``pos`` and ``i`` normally come from ``tf.range`` and are integer
        tensors. TensorFlow does not mix integer and float operands
        implicitly, so both are cast to float32 before the division and
        the multiplication.
        """
        # Integer floor-division first (pairs dimensions 2k and 2k+1), then cast.
        pair_index = tf.cast(2 * (i // 2), tf.float32)
        exponent = pair_index / tf.cast(embed_size, tf.float32)
        angle_rates = 1.0 / tf.pow(10000.0, exponent)
        return tf.cast(pos, tf.float32) * angle_rates

    def positional_encoding(self, max_length, embed_size):
        """Build the (1, max_length, embed_size) sinusoidal position table."""
        angle_rads = self.get_angles(
            pos=tf.range(max_length)[:, tf.newaxis],
            i=tf.range(embed_size)[tf.newaxis, :],
            embed_size=embed_size
        )
        # Sine on even feature indices, cosine on odd ones.
        is_even = tf.math.mod(tf.range(embed_size), 2) == 0
        pos_encoding = tf.where(is_even, tf.sin(angle_rads), tf.cos(angle_rads))
        # Leading axis of size 1 so the table broadcasts over the batch.
        return tf.cast(pos_encoding[tf.newaxis, ...], dtype=tf.float32)

    def call(self, x, enc_output, src_mask, trg_mask, training):
        """Decode target token ids ``x`` against ``enc_output``.

        Returns ``(logits, attention_weights)``: logits have the target
        vocabulary size on the last axis; attention_weights maps
        ``decoder_layer{n}_block1`` / ``decoder_layer{n}_block2`` to the
        self- and encoder-decoder attention weights of layer n.
        """
        seq_length = tf.shape(x)[1]
        attention_weights = {}

        x = self.embedding(x)  # (batch_size, target_seq_len, embed_size)
        # Scale the embeddings by sqrt(embed_size) before adding positions.
        x *= tf.math.sqrt(tf.cast(self.embed_size, tf.float32))
        x += self.pos_encoding[:, :seq_length, :]
        x = self.dropout(x, training=training)

        for i, layer in enumerate(self.layers):
            x, block1, block2 = layer(x, enc_output, src_mask, trg_mask, training=training)
            attention_weights[f'decoder_layer{i+1}_block1'] = block1
            attention_weights[f'decoder_layer{i+1}_block2'] = block2

        out = self.fc_out(x)  # (batch_size, target_seq_len, trg_vocab_size)

        return out, attention_weights

# Transformer Model
class Transformer(Model):
    """Encoder-decoder Transformer that builds its own padding and look-ahead masks.

    Mask convention (shared with the attention layer, which adds
    ``mask * -1e9`` to the logits): 1.0 marks a position that must NOT be
    attended to, 0.0 marks a visible position.
    """

    def __init__(self, src_vocab_size, trg_vocab_size, src_pad_idx, trg_pad_idx, embed_size=512, num_layers=6, 
                 num_heads=8, forward_expansion=4, dropout=0.1, max_length=100):
        super().__init__()
        self.encoder = Encoder(src_vocab_size, embed_size, num_layers, num_heads, forward_expansion, dropout, max_length)
        self.decoder = Decoder(trg_vocab_size, embed_size, num_layers, num_heads, forward_expansion, dropout, max_length)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx

    def make_src_mask(self, src):
        """Return a (batch_size, 1, 1, src_seq_len) mask that is 1.0 at padding tokens."""
        # src: (batch_size, src_seq_len)
        is_token = tf.cast(tf.math.not_equal(src, self.src_pad_idx), dtype=tf.float32)
        # Invert: the attention layer suppresses positions where the mask is 1.
        mask = 1.0 - is_token
        # Two singleton axes so the mask broadcasts over heads and query positions.
        return mask[:, tf.newaxis, tf.newaxis, :]

    def make_trg_mask(self, trg):
        """Return a (batch_size, 1, trg_seq_len, trg_seq_len) mask.

        An entry is 1.0 when the key position lies in the future of the
        query position or holds a padding token.
        """
        # trg: (batch_size, trg_seq_len)
        seq_len = tf.shape(trg)[1]
        # band_part(..., -1, 0) keeps the lower triangle; 1 - that marks the
        # strictly upper triangle, i.e. future positions.
        look_ahead_mask = 1.0 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)  # (trg_seq_len, trg_seq_len)
        is_token = tf.cast(tf.math.not_equal(trg, self.trg_pad_idx), dtype=tf.float32)  # (batch_size, trg_seq_len)
        padding_mask = (1.0 - is_token)[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, trg_seq_len)
        # Broadcasting already yields a 4-D mask; no extra leading axis is
        # added, so the attention output keeps rank 4.
        return tf.maximum(look_ahead_mask, padding_mask)

    def call(self, src, trg, training):
        """Return the decoder logits, (batch_size, trg_seq_len, trg_vocab_size)."""
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)

        enc_output = self.encoder(src, src_mask, training=training)  # (batch_size, src_seq_len, embed_size)
        # The per-layer attention weights are not part of the model output.
        dec_output, _ = self.decoder(trg, enc_output, src_mask, trg_mask, training=training)

        return dec_output  # logits



# Constants
# Hyperparameters and toy data for a minimal end-to-end training run.
src_pad_idx = 0  # token id treated as padding in source sequences
trg_pad_idx = 0  # token id treated as padding in target sequences
src_vocab_size = 10  # Example vocabulary size
trg_vocab_size = 10  # Example vocabulary size
embed_size = 128
num_layers = 2
num_heads = 2
forward_expansion = 4
dropout = 0.1
max_length = 9  # Length of your sequences
batch_size = 2
epochs = 20

# Sample training data
# Source sequences (e.g., tokenized sentences); trailing 0s are padding.
x = tf.constant([
    [1, 5, 6, 4, 3, 9, 5, 2, 0],
    [1, 8, 7, 3, 4, 5, 6, 7, 2]
], dtype=tf.int32)

# Target sequences (e.g., tokenized sentences); trailing 0s are padding.
trg = tf.constant([
    [1, 7, 4, 3, 5, 9, 2, 0, 0],
    [1, 5, 6, 2, 4, 7, 6, 2, 0]
], dtype=tf.int32)

# Create trg_input and trg_real
# The decoder input drops the last token and the expected output drops the
# first one, so position t of trg_input is paired with token t+1 of trg.
trg_input = trg[:, :-1]  # (batch_size, trg_seq_len - 1)
trg_real = trg[:, 1:]    # (batch_size, trg_seq_len - 1)

# Create training dataset
# Each element is ((source, decoder_input), expected_output).
train_dataset = tf.data.Dataset.from_tensor_slices(((x, trg_input), trg_real))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# -------------------- Loss Function -------------------- #

# Per-token cross-entropy on raw logits; reduction is left to loss_function
# so that padding can be excluded first.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

def loss_function(real, pred):
    """Mean cross-entropy over the non-padding target tokens.

    real: (batch_size, trg_seq_len - 1) token ids
    pred: (batch_size, trg_seq_len - 1, trg_vocab_size) logits
    Returns a scalar: summed token losses divided by the number of tokens
    whose id differs from trg_pad_idx.
    """
    per_token_loss = loss_object(real, pred)  # (batch_size, trg_seq_len - 1)

    # 1 for real tokens, 0 for padding, in the dtype of the loss.
    is_real_token = tf.math.logical_not(tf.math.equal(real, trg_pad_idx))
    weights = tf.cast(is_real_token, dtype=per_token_loss.dtype)

    masked_loss = per_token_loss * weights
    return tf.reduce_sum(masked_loss) / tf.reduce_sum(weights)

# -------------------- Optimizer with Learning Rate Scheduler -------------------- #

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warm-up learning-rate schedule.

    lr(step) = embed_size^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)

    The rate grows linearly for ``warmup_steps`` steps and then decays
    with the inverse square root of the step number.
    """

    def __init__(self, embed_size, warmup_steps=4000):
        super().__init__()
        # Stored as a float32 tensor so it can be fed to tf.math.rsqrt.
        self.embed_size = tf.cast(embed_size, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate for optimizer step ``step``."""
        # The optimizer passes its iteration counter, which is an integer;
        # rsqrt and the float multiplication below need a float tensor.
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.embed_size) * tf.math.minimum(arg1, arg2)

# Adam driven by the warm-up schedule defined above.
# NOTE(review): CustomSchedule keeps its default warmup_steps=4000 while this
# script runs only `epochs` = 20 single-batch epochs, so training stays at the
# very start of the warm-up ramp - consider a much smaller warmup_steps for
# this toy example.
learning_rate = CustomSchedule(embed_size)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

# -------------------- Metrics -------------------- #

# Running mean of the batch losses, reported once per epoch.
train_loss = tf.keras.metrics.Mean(name='train_loss')

# -------------------- Initialize the Transformer Model -------------------- #

# All hyperparameters come from the constants defined earlier in this cell.
transformer = Transformer(
    src_vocab_size=src_vocab_size, 
    trg_vocab_size=trg_vocab_size, 
    src_pad_idx=src_pad_idx, 
    trg_pad_idx=trg_pad_idx,
    embed_size=embed_size, 
    num_layers=num_layers, 
    num_heads=num_heads, 
    forward_expansion=forward_expansion, 
    dropout=dropout, 
    max_length=max_length
)

# -------------------- Training Step -------------------- #

@tf.function
def train_step(src, trg_input, trg_real):
    """Run one optimisation step on a batch and record its loss in train_loss.

    src:       source token ids
    trg_input: decoder input token ids (target without its last token)
    trg_real:  expected output token ids (target without its first token)
    """
    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
        logits = transformer(src, trg_input, training=True)  # (batch_size, trg_seq_len - 1, trg_vocab_size)
        loss = loss_function(trg_real, logits)

    # Back-propagate and let the optimizer update every trainable weight.
    variables = transformer.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))

    # Fold this batch's loss into the running epoch mean.
    train_loss(loss)

# -------------------- Training Loop -------------------- #

for epoch in range(epochs):
    # Start every epoch with an empty running mean. reset_state() is the
    # current Keras metric method; the older reset_states() spelling is
    # deprecated and no longer available in recent Keras releases.
    train_loss.reset_state()

    # Each dataset element is ((source, decoder_input), expected_output).
    for (src_batch, trg_input_batch), trg_real_batch in train_dataset:
        train_step(src_batch, trg_input_batch, trg_real_batch)

    print(f'Epoch {epoch + 1}, Loss: {train_loss.result():.4f}')

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Layer, Dense, Embedding, LayerNormalization, Dropout
import numpy as np

class MultiHeadAttention(Layer):
    """Multi-head attention with the scaled dot-product computed inline.

    The mask, when given, is multiplied by -1e9 and added to the raw
    scores, so positions where the mask is 1 end up with (almost) zero
    attention weight.
    """

    def __init__(self, embed_size, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.embed_size = embed_size
        self.head_dim = embed_size // num_heads  # feature size of a single head

        assert embed_size % num_heads == 0, "Embedding size must be divisible by number of heads"

        # Input projections and the output projection applied after merging heads.
        self.values = Dense(embed_size)
        self.keys = Dense(embed_size)
        self.queries = Dense(embed_size)
        self.fc_out = Dense(embed_size)

    def split_heads(self, x, batch_size):
        """Reshape (batch, seq_len, embed_size) to (batch, num_heads, seq_len, head_dim)."""
        per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.head_dim))
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(self, queries, keys, values, mask):
        """Attend from ``queries`` over ``keys``/``values``; returns the projected output only."""
        batch_size = tf.shape(queries)[0]

        q_heads = self.split_heads(self.queries(queries), batch_size)
        k_heads = self.split_heads(self.keys(keys), batch_size)
        v_heads = self.split_heads(self.values(values), batch_size)

        # Raw query-key similarities; the mask is added before scaling.
        scores = tf.matmul(q_heads, k_heads, transpose_b=True)
        if mask is not None:
            scores = scores + (mask * -1e9)

        scale = tf.math.sqrt(tf.cast(self.head_dim, tf.float32))
        weights = tf.nn.softmax(scores / scale, axis=-1)

        # Weighted values, then merge the heads back into embed_size.
        context = tf.transpose(tf.matmul(weights, v_heads), perm=[0, 2, 1, 3])
        merged = tf.reshape(context, (batch_size, -1, self.embed_size))
        return self.fc_out(merged)

class PositionwiseFeedForward(Layer):
    """Dense(ff_dim, ReLU) -> Dropout -> Dense(embed_size)."""

    def __init__(self, embed_size, ff_dim, dropout_rate=0.1):
        super().__init__()
        self.fc1 = Dense(ff_dim, activation='relu')
        self.fc2 = Dense(embed_size)
        self.dropout = Dropout(dropout_rate)

    def call(self, x):
        """Expand, apply dropout, then project back to embed_size."""
        hidden = self.dropout(self.fc1(x))
        return self.fc2(hidden)

class TransformerBlock(Layer):
    """Self-attention followed by a feed-forward network.

    Each sub-layer is followed by dropout, a residual connection and
    layer normalisation.
    """

    def __init__(self, embed_size, num_heads, ff_dim, dropout_rate=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(embed_size, num_heads)
        self.ffn = PositionwiseFeedForward(embed_size, ff_dim, dropout_rate)
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(dropout_rate)
        self.dropout2 = Dropout(dropout_rate)

    def call(self, x, mask):
        """Return the block output for input ``x`` under ``mask`` (may be None)."""
        # Sub-layer 1: self-attention (x is queries, keys and values).
        attended = self.dropout1(self.attention(x, x, x, mask))
        after_attention = self.layernorm1(x + attended)

        # Sub-layer 2: position-wise feed-forward network.
        transformed = self.dropout2(self.ffn(after_attention))
        return self.layernorm2(after_attention + transformed)

class Transformer(keras.Model):
    """Compact encoder-decoder Transformer.

    The encoder is a stack of TransformerBlock layers. Each decoder layer
    is a TransformerBlock (self-attention + feed-forward) followed by an
    encoder-decoder attention step, so the decoder output actually depends
    on the encoded source sequence.

    Masks follow the convention of MultiHeadAttention in this cell: a value
    of 1 suppresses a position, 0 keeps it.
    """

    def __init__(self, src_vocab_size, trg_vocab_size, embed_size, num_layers, num_heads, ff_dim, dropout_rate=0.1, max_length=500):
        super().__init__()
        self.embed_size = embed_size
        self.src_vocab_size = src_vocab_size
        self.trg_vocab_size = trg_vocab_size
        self.max_length = max_length

        self.src_embedding = Embedding(src_vocab_size, embed_size)
        self.trg_embedding = Embedding(trg_vocab_size, embed_size)
        self.positional_encoding = self.create_positional_encoding()

        self.encoder_layers = [TransformerBlock(embed_size, num_heads, ff_dim, dropout_rate) for _ in range(num_layers)]
        self.decoder_layers = [TransformerBlock(embed_size, num_heads, ff_dim, dropout_rate) for _ in range(num_layers)]

        # Encoder-decoder attention (one per decoder layer), each with its
        # own dropout and layer normalisation for the residual connection.
        self.cross_attention_layers = [MultiHeadAttention(embed_size, num_heads) for _ in range(num_layers)]
        self.cross_dropouts = [Dropout(dropout_rate) for _ in range(num_layers)]
        self.cross_layernorms = [LayerNormalization(epsilon=1e-6) for _ in range(num_layers)]

        self.fc_out = Dense(trg_vocab_size)
        self.dropout = Dropout(dropout_rate)

    def create_positional_encoding(self):
        """Build the (1, max_length, embed_size) sinusoidal position table."""
        pos = np.arange(self.max_length)[:, np.newaxis]
        i = np.arange(self.embed_size)[np.newaxis, :]
        angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(self.embed_size))
        angle_rads = pos * angle_rates
        # Sine on even feature indices, cosine on odd ones.
        angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
        angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
        return tf.constant(angle_rads[np.newaxis, :], dtype=tf.float32)

    def call(self, src, trg, src_mask=None, trg_mask=None):
        """Return logits with the target vocabulary size on the last axis.

        src_mask is used both for encoder self-attention and for
        encoder-decoder attention, so it must broadcast against
        (batch, heads, src_len, src_len) and (batch, heads, trg_len, src_len);
        a padding mask of shape (batch, 1, 1, src_len) satisfies both.
        trg_mask is used for decoder self-attention only.
        """
        src_len = tf.shape(src)[1]
        trg_len = tf.shape(trg)[1]

        # Embed, add positions, and apply the input dropout that __init__
        # creates for this purpose.
        src_embedded = self.dropout(self.src_embedding(src) + self.positional_encoding[:, :src_len, :])
        trg_embedded = self.dropout(self.trg_embedding(trg) + self.positional_encoding[:, :trg_len, :])

        enc_output = src_embedded
        for layer in self.encoder_layers:
            enc_output = layer(enc_output, src_mask)

        dec_output = trg_embedded
        decoder_stack = zip(self.decoder_layers, self.cross_attention_layers,
                            self.cross_dropouts, self.cross_layernorms)
        for self_block, cross_attention, cross_dropout, cross_norm in decoder_stack:
            dec_output = self_block(dec_output, trg_mask)
            # Queries come from the decoder, keys/values from the encoder.
            attended = cross_attention(dec_output, enc_output, enc_output, src_mask)
            dec_output = cross_norm(dec_output + cross_dropout(attended))

        final_output = self.fc_out(dec_output)
        return final_output
