# Importing libraries

# In [1]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np


# Positional Encoding

# In [9]:
def positional_encoding(max_len, dm):
    """Build the sinusoidal positional-encoding table.

    Column ``j`` of row ``pos`` holds the angle
    ``pos / 10000 ** (2 * (j // 2) / dm)``. Even columns take the sine of
    that angle and odd columns take the cosine.

    Args:
        max_len: number of positions (rows) to encode.
        dm: encoding width (number of columns).

    Returns:
        A float32 tensor of shape ``(1, max_len, dm)``.
    """
    positions = np.arange(max_len)[:, np.newaxis]   # (max_len, 1)
    columns = np.arange(dm)[np.newaxis, :]          # (1, dm)

    # Columns 2k and 2k+1 share the same frequency exponent 2k / dm.
    exponents = (2 * (columns // 2)) / np.float32(dm)
    raw = positions * (1 / np.power(10000, exponents))

    # Pick sine for even columns and cosine for odd columns.
    is_even = (columns % 2) == 0
    table = np.where(is_even, np.sin(raw), np.cos(raw))

    return tf.cast(table[np.newaxis, ...], dtype=tf.float32)


# Scaled Dot-Product Attention

# In [11]:
def scaled_dot_attention(q, k, v, mask=None):
    """Compute scaled dot-product attention.

    Computes ``softmax(q @ k^T / sqrt(d_k) + mask * penalty) @ v``, where
    ``d_k`` is the size of the last axis of ``k``.

    Args:
        q: queries, shaped ``(..., seq_q, depth)``.
        k: keys, shaped ``(..., seq_k, depth)``.
        v: values, shaped ``(..., seq_k, depth_v)``.
        mask: optional tensor broadcastable to ``(..., seq_q, seq_k)``.
            Positions where it is 1 receive a large negative logit, so they get
            ~0 attention weight. Numeric or boolean masks are accepted; they
            are cast to the logits' dtype.

    Returns:
        ``(output, weights)``: the attended values ``(..., seq_q, depth_v)``
        and the softmax attention weights ``(..., seq_q, seq_k)``.
    """
    matmul_qk = tf.matmul(q, k, transpose_b=True)
    # Build the scale in the logits' dtype. A hard-coded float32 here breaks
    # float16/bfloat16 (mixed-precision) inputs with a dtype mismatch.
    dk = tf.cast(tf.shape(k)[-1], matmul_qk.dtype)
    scaled = matmul_qk / tf.sqrt(dk)

    if mask is not None:
        # -1e9 overflows to -inf in float16, and 0 * -inf is NaN for unmasked
        # entries. Clamp the penalty to the dtype's most negative finite
        # value. For float32 this is still exactly -1e9.
        penalty = max(-1e9, scaled.dtype.min)
        scaled += tf.cast(mask, scaled.dtype) * penalty

    weights = tf.nn.softmax(scaled, axis=-1)
    output = tf.matmul(weights, v)
    return output, weights


# Multi-Head Attention Layer

# In [12]:
class MultiHeadAttention(layers.Layer):
    """Multi-head attention layer.

    The layer works in four steps:

    1. Project queries, keys and values with separate ``Dense(dm)`` layers.
    2. Split each projection into ``num_heads`` heads of width
       ``dm // num_heads``.
    3. Run ``scaled_dot_attention`` on every head.
    4. Concatenate the heads and apply a final ``Dense(dm)`` projection.

    Args:
        dm: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads; must be positive.
        **kwargs: forwarded to ``keras.layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``dm``.
    """

    def __init__(self, dm, num_heads, **kwargs):
        # Validate before building anything. If dm % num_heads != 0, the
        # reshape in split_heads (which uses -1 for the sequence axis) can
        # still succeed by silently changing the sequence length, mixing
        # features across positions instead of failing.
        if num_heads <= 0 or dm % num_heads != 0:
            raise ValueError(
                f"dm ({dm}) must be divisible by num_heads ({num_heads}), "
                "and num_heads must be positive"
            )
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.dm = dm
        self.depth = dm // num_heads

        self.wq = layers.Dense(dm)
        self.wk = layers.Dense(dm)
        self.wv = layers.Dense(dm)
        self.linear = layers.Dense(dm)

    def split_heads(self, x):
        """Reshape ``(batch, seq, dm)`` into ``(batch, num_heads, seq, depth)``."""
        batch = tf.shape(x)[0]
        x = tf.reshape(x, (batch, -1, self.num_heads, self.depth))
        return tf.transpose(x, [0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Note the argument order: values, keys, then queries.

        Args:
            v: values, assumed ``(batch, seq_k, features)``.
            k: keys, assumed ``(batch, seq_k, features)``.
            q: queries, assumed ``(batch, seq_q, features)``.
            mask: optional mask passed through to ``scaled_dot_attention``
                (1 = suppressed position).

        Returns:
            Tensor of shape ``(batch, seq_q, dm)``. The attention weights are
            discarded.
        """
        q = self.split_heads(self.wq(q))
        k = self.split_heads(self.wk(k))
        v = self.split_heads(self.wv(v))

        scaled, _ = scaled_dot_attention(q, k, v, mask)

        # (batch, heads, seq_q, depth) -> (batch, seq_q, heads, depth) -> (batch, seq_q, dm)
        scaled = tf.transpose(scaled, [0, 2, 1, 3])
        concat = tf.reshape(scaled, (tf.shape(scaled)[0], -1, self.dm))

        return self.linear(concat)


# Feed Forward Network

# In [13]:
def feed_forward(dm, dff):
    """Build the position-wise feed-forward sub-network.

    The network is a ReLU ``Dense(dff)`` expansion followed by a linear
    ``Dense(dm)`` projection back to the model width.

    Args:
        dm: output (model) width.
        dff: hidden width of the inner layer.

    Returns:
        A ``tf.keras.Sequential`` model with the two Dense layers.
    """
    expand = layers.Dense(dff, activation='relu')
    project = layers.Dense(dm)
    return tf.keras.Sequential([expand, project])


# Transformer Encoder Block

# In [14]:
class EncoderBlock(layers.Layer):
    """One Transformer encoder layer (post-norm).

    The layer applies self-attention, then a feed-forward network. Each
    sub-layer is followed by dropout, a residual connection and
    LayerNormalization.

    Args:
        dm: model width.
        num_heads: number of attention heads.
        dff: hidden width of the feed-forward network.
        rate: dropout rate applied to each sub-layer output.
    """

    def __init__(self, dm, num_heads, dff, rate=0.1):
        super().__init__()
        # Attribute names and creation order are kept stable so that existing
        # checkpoints keep loading.
        self.mha = MultiHeadAttention(dm, num_heads)
        self.ffn = feed_forward(dm, dff)

        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

        self.drop1 = layers.Dropout(rate)
        self.drop2 = layers.Dropout(rate)

    def call(self, x, training, mask):
        """Run self-attention and the FFN over ``x`` with the given mask."""
        # Sub-layer 1: self-attention, dropout, residual, norm.
        attended = self.drop1(self.mha(x, x, x, mask), training=training)
        hidden = self.norm1(x + attended)

        # Sub-layer 2: feed-forward, dropout, residual, norm.
        transformed = self.drop2(self.ffn(hidden), training=training)
        return self.norm2(hidden + transformed)


# Decoder Block

# In [15]:
class DecoderBlock(layers.Layer):
    """One Transformer decoder layer (post-norm).

    The layer has three sub-layers:

    1. Self-attention over the decoder input.
    2. Cross-attention, with queries from the decoder and keys/values from
       the encoder output.
    3. A feed-forward network.

    Each sub-layer is followed by dropout, a residual connection and
    LayerNormalization.

    Args:
        dm: model width.
        num_heads: number of attention heads.
        dff: hidden width of the feed-forward network.
        rate: dropout rate applied to each sub-layer output.
    """

    def __init__(self, dm, num_heads, dff, rate=0.1):
        super().__init__()
        # Attribute names and creation order are kept stable so that existing
        # checkpoints keep loading.
        self.mha1 = MultiHeadAttention(dm, num_heads)
        self.mha2 = MultiHeadAttention(dm, num_heads)

        self.ffn = feed_forward(dm, dff)

        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.norm3 = layers.LayerNormalization(epsilon=1e-6)

        self.drop1 = layers.Dropout(rate)
        self.drop2 = layers.Dropout(rate)
        self.drop3 = layers.Dropout(rate)

    def call(self, x, enc_out, training, look_ahead_mask, padding_mask):
        """Decode ``x`` while attending to ``enc_out``.

        ``look_ahead_mask`` is applied to the self-attention (presumably a
        causal mask built by the caller). ``padding_mask`` is applied to the
        cross-attention.
        """
        # 1) Self-attention over the decoder input.
        self_att = self.drop1(self.mha1(x, x, x, look_ahead_mask), training=training)
        h1 = self.norm1(self_att + x)

        # 2) Cross-attention: values/keys = encoder output, queries = h1.
        cross_att = self.drop2(
            self.mha2(enc_out, enc_out, h1, padding_mask), training=training
        )
        h2 = self.norm2(cross_att + h1)

        # 3) Position-wise feed-forward.
        ff = self.drop3(self.ffn(h2), training=training)
        return self.norm3(h2 + ff)


# Encoder

# In [16]:
class Encoder(layers.Layer):
    """Transformer encoder.

    Tokens are embedded, scaled by ``sqrt(dm)``, offset by sinusoidal
    positional encodings, passed through dropout, and then through
    ``num_layers`` EncoderBlocks.

    Args:
        num_layers: number of stacked EncoderBlocks.
        dm: model width.
        num_heads: attention heads per block.
        dff: feed-forward hidden width.
        vocab: input vocabulary size.
        max_len: longest sequence the positional table covers. Longer inputs
            will fail to broadcast against the sliced table.
        rate: dropout rate.
    """

    def __init__(self, num_layers, dm, num_heads, dff, vocab, max_len, rate=0.1):
        super().__init__()

        self.num_layers = num_layers
        self.embedding = layers.Embedding(vocab, dm)
        # Precomputed (1, max_len, dm) table, sliced to the input length in call().
        self.pos = positional_encoding(max_len, dm)

        self.enc_layers = [
            EncoderBlock(dm, num_heads, dff, rate) for _ in range(num_layers)
        ]
        self.drop = layers.Dropout(rate)

    def call(self, x, training, mask):
        """Encode token ids ``x``, assumed ``(batch, seq)``, into ``(batch, seq, dm)``."""
        length = tf.shape(x)[1]
        embedded = self.embedding(x)
        width = tf.cast(tf.shape(embedded)[-1], tf.float32)
        hidden = embedded * tf.math.sqrt(width) + self.pos[:, :length, :]

        hidden = self.drop(hidden, training=training)

        for block in self.enc_layers:
            hidden = block(hidden, training, mask)

        return hidden


# Decoder

# In [17]:
class Decoder(layers.Layer):
    """Transformer decoder.

    Target tokens are embedded, scaled by ``sqrt(dm)``, offset by sinusoidal
    positional encodings, passed through dropout, and then through
    ``num_layers`` DecoderBlocks that also attend to the encoder output.

    Args:
        num_layers: number of stacked DecoderBlocks.
        dm: model width.
        num_heads: attention heads per block.
        dff: feed-forward hidden width.
        vocab: target vocabulary size.
        max_len: longest sequence the positional table covers. Longer inputs
            will fail to broadcast against the sliced table.
        rate: dropout rate.
    """

    def __init__(self, num_layers, dm, num_heads, dff, vocab, max_len, rate=0.1):
        super().__init__()

        # Stored for parity with Encoder.
        self.num_layers = num_layers
        self.embedding = layers.Embedding(vocab, dm)
        # Precomputed (1, max_len, dm) table, sliced to the input length in call().
        self.pos = positional_encoding(max_len, dm)

        self.dec_layers = [
            DecoderBlock(dm, num_heads, dff, rate)
            for _ in range(num_layers)
        ]
        self.dropout = layers.Dropout(rate)

    def call(self, x, enc_out, training, look_ahead_mask, padding_mask):
        """Decode target ids ``x``, assumed ``(batch, seq)``, against ``enc_out``.

        Returns a tensor of shape ``(batch, seq, dm)``.
        """
        seq_len = tf.shape(x)[1]
        x = self.embedding(x)
        x *= tf.math.sqrt(tf.cast(tf.shape(x)[-1], tf.float32))
        x += self.pos[:, :seq_len, :]

        # Dropout on the embedding + positional sum, matching Encoder. The
        # layer was previously built but never applied.
        x = self.dropout(x, training=training)

        for layer in self.dec_layers:
            x = layer(x, enc_out, training, look_ahead_mask, padding_mask)

        return x


# Full Transformer Model (Encoder + Decoder)

# In [18]:
class Transformer(tf.keras.Model):
    """Encoder-decoder Transformer with a final Dense projection to target-vocab logits.

    Args:
        num_layers: number of blocks in both the encoder and the decoder.
        dm: model width.
        num_heads: attention heads per block.
        dff: feed-forward hidden width.
        input_vocab: source vocabulary size.
        target_vocab: target vocabulary size (also the output logit width).
        max_len: longest sequence covered by the positional tables.
        rate: dropout rate used throughout the encoder and decoder.
    """

    def __init__(self, num_layers, dm, num_heads, dff, input_vocab,
                 target_vocab, max_len, rate=0.1):
        super().__init__()
        # Forward `rate`. It was previously accepted but dropped, so the
        # sub-stacks always used their own 0.1 default.
        self.encoder = Encoder(num_layers, dm, num_heads, dff, input_vocab,
                               max_len, rate=rate)
        self.decoder = Decoder(num_layers, dm, num_heads, dff, target_vocab,
                               max_len, rate=rate)
        self.final_layer = layers.Dense(target_vocab)

    def call(self, inp, tar, training, enc_padding_mask,
             look_ahead_mask, dec_padding_mask):
        """Encode ``inp``, decode ``tar`` against it, and return target-vocab logits."""

        enc_out = self.encoder(inp, training, enc_padding_mask)
        dec_out = self.decoder(tar, enc_out, training,
                               look_ahead_mask, dec_padding_mask)

        return self.final_layer(dec_out)
