# In [1]:
import tensorflow as tf
from tensorflow.keras import layers, Model

# Define the positional encoding function
def positional_encoding(position, d_model):
    """Build a sinusoidal positional-encoding table.

    For column ``c`` the angle is ``pos / 10000 ** (2 * (c // 2) / d_model)``.
    Even-indexed columns hold the sine of that angle and odd-indexed columns
    hold the cosine.

    Args:
        position: Number of positions (rows) to encode.
        d_model: Number of feature columns per position.

    Returns:
        A ``tf.float32`` tensor of shape ``(position, d_model)``.
    """
    positions = tf.cast(tf.range(position)[:, tf.newaxis], tf.float32)
    columns = tf.range(d_model)[tf.newaxis, :]

    # Cast the integer numerator before dividing: TensorFlow does not
    # implicitly promote int32 to float32.
    exponents = tf.cast(2 * (columns // 2), tf.float32) / tf.cast(d_model, tf.float32)
    angle_rads = positions / tf.pow(tf.constant(10000.0, dtype=tf.float32), exponents)

    # Tensors are immutable, so pick sin/cos per column with a broadcast
    # select instead of slice assignment.
    is_even_column = tf.equal(columns % 2, 0)
    return tf.where(is_even_column, tf.sin(angle_rads), tf.cos(angle_rads))

# Define a scaled dot-product attention function
def scaled_dot_product_attention(q, k, v, mask):
    """Compute softmax(q @ k^T / sqrt(depth_k)) @ v.

    Args:
        q: Query tensor.
        k: Key tensor; the size of its last axis is used as the scaling depth.
        v: Value tensor.
        mask: Optional tensor added to the scaled scores after being
            multiplied by -1e9, so non-zero entries are suppressed by the
            softmax. Pass ``None`` to skip masking.
            NOTE(review): presumably a float tensor broadcastable to the
            score shape -- confirm against callers.

    Returns:
        A tuple ``(output, attention_weights)`` where ``attention_weights`` is
        the softmax over the last axis of the scores and ``output`` is those
        weights multiplied with ``v``.
    """
    key_depth = tf.cast(tf.shape(k)[-1], tf.float32)
    scores = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(key_depth)

    if mask is not None:
        # Large negative offset drives masked positions to ~0 after softmax.
        scores = scores + (mask * -1e9)

    weights = tf.nn.softmax(scores, axis=-1)
    return tf.matmul(weights, v), weights

# Define the multi-head attention layer
class MultiHeadAttention(layers.Layer):
    """Multi-head attention built from dense projections.

    Queries, keys and values are each projected to ``d_model`` features,
    split into ``num_heads`` heads of ``d_model // num_heads`` features,
    attended with ``scaled_dot_product_attention``, merged back and passed
    through a final dense layer.

    Args:
        num_heads: Number of attention heads.
        d_model: Total feature width; must be divisible by ``num_heads``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``d_model`` evenly.
    """

    def __init__(self, num_heads, d_model):
        # Without this check the floor division below would silently drop
        # features and split_heads would reshape into a wrong sequence length.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )

        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        self.wq = layers.Dense(d_model)
        self.wk = layers.Dense(d_model)
        self.wv = layers.Dense(d_model)
        self.dense = layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape ``x`` to ``(batch_size, num_heads, -1, depth)``.

        The last axis is split into ``(num_heads, depth)`` and the head axis
        is moved in front of the inferred (sequence) axis.
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        """Apply multi-head attention.

        Args:
            v: Value input.
            k: Key input.
            q: Query input; its leading dimension is taken as the batch size.
            mask: Optional mask forwarded to ``scaled_dot_product_attention``.

        Returns:
            A tuple ``(output, attention_weights)``; ``output`` has a last
            axis of size ``d_model``.
        """
        batch_size = tf.shape(q)[0]

        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        output, attention_weights = scaled_dot_product_attention(q, k, v, mask)

        # Undo split_heads: move the head axis back next to depth, then merge
        # the two into a single d_model-wide axis.
        output = tf.transpose(output, perm=[0, 2, 1, 3])
        output = tf.reshape(output, (batch_size, -1, self.d_model))

        return self.dense(output), attention_weights

# Define the Transformer block
class TransformerBlock(layers.Layer):
    """Self-attention block followed by a two-layer feed-forward network.

    Each sub-layer output goes through dropout, is added to the sub-layer's
    input (residual connection) and is then layer-normalized.

    Args:
        num_heads: Number of attention heads.
        d_model: Feature width of the block's input and output.
        dff: Width of the hidden feed-forward layer.
        rate: Dropout rate used after both sub-layers.
    """

    def __init__(self, num_heads, d_model, dff, rate=0.1):
        super().__init__()
        self.mha = MultiHeadAttention(num_heads, d_model)

        hidden = layers.Dense(dff, activation='relu')
        projection = layers.Dense(d_model)
        self.ffn = tf.keras.Sequential([hidden, projection])

        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, x, training, mask):
        """Run the block on ``x``.

        Args:
            x: Input tensor, used as query, key and value.
            training: Forwarded to both dropout layers.
            mask: Forwarded to the attention layer.

        Returns:
            The layer-normalized output of the feed-forward sub-layer.
        """
        # Attention sub-layer; the attention weights are not needed here.
        attended, _ = self.mha(x, x, x, mask)
        attended = self.dropout1(attended, training=training)
        normed_attention = self.layernorm1(x + attended)

        # Feed-forward sub-layer.
        transformed = self.ffn(normed_attention)
        transformed = self.dropout2(transformed, training=training)
        return self.layernorm2(normed_attention + transformed)

# Define the full Transformer model
class Transformer(Model):
    """Stack of ``TransformerBlock`` layers followed by a dense output layer.

    Args:
        num_layers: Number of blocks in each of the two stacks.
        num_heads: Number of attention heads per block.
        d_model: Feature width of every block.
        dff: Hidden width of each block's feed-forward network.
        input_vocab_size: Accepted for interface compatibility; currently
            not used by this class.
        target_vocab_size: Number of units in the final dense layer.
        rate: Dropout rate passed to every block.

    NOTE(review): no embedding or positional encoding is applied here, and
    the "decoder" stack uses the same self-attention blocks as the encoder,
    with no mask and no attention over a separate encoder output. The input
    presumably has to be a float tensor of width ``d_model`` -- confirm the
    intended architecture.
    """

    def __init__(self, num_layers, num_heads, d_model, dff, input_vocab_size, target_vocab_size, rate=0.1):
        super(Transformer, self).__init__()
        self.encoder = [TransformerBlock(num_heads, d_model, dff, rate) for _ in range(num_layers)]
        self.decoder = [TransformerBlock(num_heads, d_model, dff, rate) for _ in range(num_layers)]
        self.final_layer = layers.Dense(target_vocab_size)

    def call(self, inputs, training=False):
        """Run the encoder stack, then the decoder stack, then the output layer.

        Args:
            inputs: Indexable container; only ``inputs[0]`` is used.
            training: Whether dropout is active. It is passed to every block
                rather than being forced on, so inference runs without
                dropout.

        Returns:
            The output of the final dense layer, with last axis
            ``target_vocab_size``.
        """
        x = inputs[0]  # Encoder input
        for enc in self.encoder:
            x = enc(x, training=training, mask=None)

        for dec in self.decoder:
            x = dec(x, training=training, mask=None)

        return self.final_layer(x)

# Example hyperparameters
num_layers, num_heads = 4, 8
d_model, dff = 128, 512
input_vocab_size = target_vocab_size = 10000  # Example vocab sizes

# Create the model from the hyperparameters above
transformer_model = Transformer(
    num_layers=num_layers,
    num_heads=num_heads,
    d_model=d_model,
    dff=dff,
    input_vocab_size=input_vocab_size,
    target_vocab_size=target_vocab_size,
)

# Build with batch size and sequence length left unspecified, then print
# the layer/parameter summary.
# NOTE(review): the model has no embedding layer, so a rank-2 input shape
# may not be buildable -- confirm the intended input shape.
transformer_model.build((None, None))
transformer_model.summary()


