In [32]:
import tensorflow as tf
import numpy as np
import math
from tensorflow.keras.layers import Dense, Dropout

In [33]:
class InputEmbedding(tf.keras.layers.Layer):
    """Token-id embedding lookup, scaled by sqrt(d_model).

    Args:
        d_model: Width of each embedding vector.
        vocab_size: Number of distinct token ids the table covers.
    """

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)

    def call(self, x):
        scale = math.sqrt(self.d_model)
        embedded = self.embedding(x)
        return embedded * scale

In [34]:
class FFD(tf.keras.layers.Layer):
    """Position-wise feed-forward network: Dense(relu) -> Dropout -> Dense.

    Args:
        d_ffn: Width of the hidden (ReLU) layer.
        input_size: Width of the output layer (callers in this notebook pass d_model).
        dropout_rate: Dropout applied to the hidden activations.
    """

    def __init__(self, d_ffn, input_size, dropout_rate=0.1):
        super().__init__()
        # Creation order is kept as dense1, dropout, dense2 so weight tracking is unchanged.
        self.dense1 = Dense(d_ffn, activation='relu')
        self.dropout = Dropout(dropout_rate)
        self.dense2 = Dense(input_size)

    def call(self, x, training=False):
        return self.dense2(self.dropout(self.dense1(x), training=training))

In [35]:
class PositionalEncoding(tf.keras.layers.Layer):
    """Add fixed sinusoidal position encodings to the input, then apply dropout.

    Feature index 2i holds sin(pos / 10000^(2i/d_model)) and index 2i+1 holds
    the matching cos, i.e. sin and cos are interleaved along the last axis.

    Args:
        d_model: Feature size of the inputs; must be even so the sin/cos pairs
            fill the last axis exactly.
        seq_len: Longest sequence the encoding table is precomputed for; inputs
            longer than this cannot be encoded.
        dropout: Dropout rate applied after the encoding is added.
        scale_input: If True, multiply the inputs by sqrt(d_model) before adding
            the encoding. Off by default because ``InputEmbedding`` already
            applies that factor; doing it here as well scaled embeddings by d_model.

    Raises:
        ValueError: If ``d_model`` is odd.
    """

    def __init__(self, d_model, seq_len, dropout, scale_input=False):
        super().__init__()
        if d_model % 2 != 0:
            raise ValueError(f"d_model must be even for sin/cos interleaving, got {d_model}")
        self.d_model = d_model
        self.seq_len = seq_len
        self.scale_input = scale_input
        self.dropout = tf.keras.layers.Dropout(dropout)

        position = tf.cast(tf.range(seq_len)[:, tf.newaxis], tf.float32)  # (seq_len, 1)
        # 10000^(-2i/d_model) for 2i = 0, 2, 4, ..., computed in log space.
        div_term = tf.exp(tf.range(0, d_model, 2, dtype=tf.float32) * (-tf.math.log(10000.0) / d_model))
        angle_rads = position * div_term  # (seq_len, d_model / 2)

        # Stack on a new trailing axis, then flatten so sin/cos alternate: [sin0, cos0, sin1, cos1, ...].
        pe = tf.stack([tf.math.sin(angle_rads), tf.math.cos(angle_rads)], axis=-1)
        pe = tf.reshape(pe, (seq_len, d_model))

        self.pos_encoding = tf.expand_dims(pe, axis=0)  # (1, seq_len, d_model)

    def call(self, x):
        if self.scale_input:
            x = x * tf.math.sqrt(tf.cast(self.d_model, x.dtype))
        length = tf.shape(x)[1]
        # Slice the table to the input length; cast so non-float32 inputs also work.
        x = x + tf.cast(self.pos_encoding[:, :length, :], x.dtype)
        return self.dropout(x)

In [51]:
class MHA(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention.

    q, k and v are projected by separate Dense layers and split into ``heads``
    slices of width ``depth = d_model // heads``. Each head is attended
    independently, and the heads are then concatenated and projected back to
    ``d_model``.

    Args:
        d_model: Projection/output width; must be divisible by ``heads``.
        heads: Number of attention heads.
    """

    def __init__(self, d_model, heads):
        super().__init__()
        self.d_model = d_model
        self.heads = heads
        assert d_model % heads == 0, "d_model must be divisible by heads"
        self.depth = d_model // heads
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        self.wo = tf.keras.layers.Dense(d_model)

    def attention(self, q, k, v, mask=None):
        """Scaled dot-product attention over the last two axes.

        Args:
            q, k, v: Per-head tensors; ``k`` and ``v`` must share their sequence axis.
            mask: Optional tensor broadcastable to the logits (..., len_q, len_k).
                Entries equal to 1 are masked out.

        Returns:
            (output, weights): the attention-weighted values and the softmax weights.
        """
        logits = tf.matmul(q, k, transpose_b=True)
        dk = tf.cast(tf.shape(k)[-1], logits.dtype)
        logits = logits / tf.math.sqrt(dk)
        if mask is not None:
            # A -1e9 bias drives masked positions to ~0 weight after softmax.
            # Casting lets integer or boolean masks be passed too.
            logits += tf.cast(mask, logits.dtype) * -1e9
        weights = tf.nn.softmax(logits, axis=-1)
        return tf.matmul(weights, v), weights

    def split_heads(self, batch, x):
        """Reshape (batch, seq, d_model) -> (batch, heads, seq, depth)."""
        x = tf.reshape(x, (batch, -1, self.heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask=None):
        """Attend from ``q`` over ``k``/``v``.

        Note that the argument order is (v, k, q), not (q, k, v). Inputs are
        assumed to be (batch, seq, features).

        Returns:
            (output, attn_weights): output is (batch, len_q, d_model) and
            attn_weights is (batch, heads, len_q, len_k).
        """
        batch = tf.shape(q)[0]
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)
        # split_heads expects (batch, x) in that order.
        q = self.split_heads(batch, q)
        k = self.split_heads(batch, k)
        v = self.split_heads(batch, v)

        attention_output, attn_weights = self.attention(q, k, v, mask)

        # Undo the head split: (batch, heads, len_q, depth) -> (batch, len_q, d_model).
        attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(attention_output, (batch, -1, self.d_model))

        output = self.wo(concat_attention)

        return output, attn_weights
        

In [37]:
class normalize(tf.keras.layers.Layer):
    """Layer normalization over the last axis with a learnable scale and shift.

    Args:
        d_model: Size of the last axis (length of ``gamma`` and ``beta``).
        eps: Small constant added to the variance for numerical stability.
    """

    def __init__(self, d_model, eps: float = 10**-6):
        super().__init__()
        self.eps = eps
        # gamma starts at 1 and beta at 0, so the layer initially only normalizes.
        self.gamma = self.add_weight(name="gamma", shape=(d_model,), initializer="ones", trainable=True)
        self.beta = self.add_weight(name="beta", shape=(d_model,), initializer="zeros", trainable=True)

    def call(self, x):
        mu = tf.reduce_mean(x, axis=-1, keepdims=True)
        centered = x - mu
        var = tf.reduce_mean(tf.square(centered), axis=-1, keepdims=True)  # population variance
        return self.gamma * (centered / tf.sqrt(var + self.eps)) + self.beta

In [38]:
class residualconnections(tf.keras.layers.Layer):
    """Pre-norm residual wrapper computing ``x + dropout(sublayer(norm(x)))``.

    Args:
        d_model: Feature size passed to the layer norm.
        dropout: Dropout rate applied to the sublayer output.
        layer: Accepted for call-site compatibility but not used; the sublayer
            to run is supplied to ``call`` instead.
    """

    def __init__(self, d_model, dropout, layer):
        super().__init__()
        self.dropout = tf.keras.layers.Dropout(dropout)
        self.norm = normalize(d_model)

    def call(self, x, sublayer):
        normed = self.norm(x)
        branch = self.dropout(sublayer(normed))
        return x + branch
    

In [39]:
class EncoderBlock(tf.keras.layers.Layer):
    """One encoder layer: self-attention, then feed-forward, each wrapped in a residual.

    Args:
        d_model: Model width.
        num_heads: Number of attention heads.
        d_ffn: Hidden width of the feed-forward network.
        dropout: Dropout rate for the residual branches and the feed-forward network.
    """

    def __init__(self, d_model, num_heads, d_ffn, dropout):
        super().__init__()
        self.mha = MHA(d_model, num_heads)
        self.ffn = FFD(d_ffn, d_model, dropout)

        self.res1 = residualconnections(d_model, dropout, self.mha)
        self.res2 = residualconnections(d_model, dropout, self.ffn)

    def call(self, x, mask=None):
        def self_attention(h):
            # MHA returns (output, weights); only the output feeds the residual.
            output, _ = self.mha(h, h, h, mask)
            return output

        attended = self.res1(x, self_attention)
        return self.res2(attended, self.ffn)


In [40]:
class DecoderBlock(tf.keras.layers.Layer):
    """One decoder layer: masked self-attention, cross-attention over the encoder
    output, then feed-forward, each wrapped in a residual.

    Args:
        d_model: Model width.
        num_heads: Number of attention heads.
        d_ffn: Hidden width of the feed-forward network.
        dropout: Dropout rate for the residual branches and the feed-forward network.
    """

    def __init__(self, d_model, num_heads, d_ffn, dropout):
        super().__init__()
        self.mha1 = MHA(d_model, num_heads)
        self.mha2 = MHA(d_model, num_heads)
        self.ffn = FFD(d_ffn, d_model, dropout)

        self.res1 = residualconnections(d_model, dropout, self.mha1)
        self.res2 = residualconnections(d_model, dropout, self.mha2)
        self.res3 = residualconnections(d_model, dropout, self.ffn)

    def call(self, x, enc_output, look_ahead_mask=None, padding_mask=None):
        # MHA.call takes its inputs in (v, k, q) order.
        x = self.res1(x, lambda h: self.mha1(h, h, h, look_ahead_mask)[0])
        # Cross-attention: queries come from the decoder stream, keys and values
        # from the encoder output. Passing the decoder stream as v and enc_output
        # as q would pair source-length weights with target-length values.
        x = self.res2(x, lambda h: self.mha2(enc_output, enc_output, h, padding_mask)[0])
        x = self.res3(x, self.ffn)
        return x


In [45]:
class Encoder(tf.keras.layers.Layer):
    """Embed source tokens, add positional encodings, and run a stack of EncoderBlocks.

    The blocks use pre-norm residuals (``x + dropout(sublayer(norm(x)))``),
    which never normalize their own output. A final layer norm is therefore
    applied to the stack output before it is returned.

    Args:
        num_layers: Number of EncoderBlocks.
        d_model: Model width.
        num_heads: Number of attention heads per block.
        d_ffn: Hidden width of each feed-forward network.
        vocab_size: Source vocabulary size.
        seq_len: Longest source sequence supported by the positional encoding.
        dropout: Dropout rate used throughout.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ffn, vocab_size, seq_len, dropout):
        super().__init__()
        self.embedding = InputEmbedding(d_model, vocab_size)
        self.pos_encoding = PositionalEncoding(d_model, seq_len, dropout)

        self.enc_blocks = [
            EncoderBlock(d_model, num_heads, d_ffn, dropout)
            for _ in range(num_layers)
        ]
        # Final norm required by the pre-norm block layout.
        self.norm = normalize(d_model)

    def call(self, x, mask=None):
        x = self.embedding(x)
        x = self.pos_encoding(x)

        for block in self.enc_blocks:
            x = block(x, mask)
        return self.norm(x)


In [46]:
class Decoder(tf.keras.layers.Layer):
    """Embed target tokens, add positional encodings, and run a stack of DecoderBlocks.

    The blocks use pre-norm residuals (``x + dropout(sublayer(norm(x)))``),
    which never normalize their own output. A final layer norm is therefore
    applied before the result goes to the output projection.

    Args:
        num_layers: Number of DecoderBlocks.
        d_model: Model width.
        num_heads: Number of attention heads per block.
        d_ffn: Hidden width of each feed-forward network.
        vocab_size: Target vocabulary size.
        seq_len: Longest target sequence supported by the positional encoding.
        dropout: Dropout rate used throughout.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ffn, vocab_size, seq_len, dropout):
        super().__init__()
        self.embedding = InputEmbedding(d_model, vocab_size)
        self.pos_encoding = PositionalEncoding(d_model, seq_len, dropout)

        self.dec_blocks = [
            DecoderBlock(d_model, num_heads, d_ffn, dropout)
            for _ in range(num_layers)
        ]
        # Final norm required by the pre-norm block layout.
        self.norm = normalize(d_model)

    def call(self, x, enc_output, look_ahead_mask=None, padding_mask=None):
        x = self.embedding(x)
        x = self.pos_encoding(x)

        for block in self.dec_blocks:
            x = block(x, enc_output, look_ahead_mask, padding_mask)
        return self.norm(x)


In [47]:
class ProjectionLayer(tf.keras.layers.Layer):
    """Linear map from model features to per-position vocabulary scores.

    The Dense layer has no activation, so the outputs are unnormalized logits.

    Args:
        vocab_size: Number of output classes (target vocabulary size).
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, x):
        # Presumably (batch, seq_len, d_model) -> (batch, seq_len, vocab_size).
        logits = self.dense(x)
        return logits


In [44]:
class Transformer(tf.keras.layers.Layer):
    """Encoder-decoder stack that returns decoder features, not vocabulary logits.

    The final vocabulary projection is deliberately not part of this layer;
    it is handled separately by ``ProjectionLayer``.

    Args:
        num_layers: Number of encoder blocks and of decoder blocks.
        d_model: Model width.
        num_heads: Attention heads per block.
        d_ffn: Feed-forward hidden width.
        input_vocab_size: Source vocabulary size.
        target_vocab_size: Target vocabulary size.
        max_seq_len_input: Longest supported source sequence.
        max_seq_len_target: Longest supported target sequence.
        dropout: Dropout rate used throughout.
    """

    def __init__(self,
                 num_layers, d_model, num_heads, d_ffn,
                 input_vocab_size, target_vocab_size,
                 max_seq_len_input, max_seq_len_target,
                 dropout=0.1):
        super().__init__()
        self.encoder = Encoder(num_layers, d_model, num_heads, d_ffn, input_vocab_size, max_seq_len_input, dropout)
        self.decoder = Decoder(num_layers, d_model, num_heads, d_ffn, target_vocab_size, max_seq_len_target, dropout)

    def call(self, inputs, training=False):
        """Run the encoder on the source, then the decoder on the target.

        Args:
            inputs: A 5-sequence (inp, tar, enc_padding_mask, look_ahead_mask,
                dec_padding_mask).
            training: Not forwarded explicitly here.
                NOTE(review): nested dropout relies on Keras call-context propagation — confirm.
        """
        (source_ids, target_ids,
         enc_padding_mask, look_ahead_mask, dec_padding_mask) = inputs

        memory = self.encoder(source_ids, mask=enc_padding_mask)
        return self.decoder(target_ids, memory, look_ahead_mask, dec_padding_mask)


In [49]:
def build_transformer(N, d_model, num_heads, d_ffn,
                      input_vocab_size, target_vocab_size,
                      max_seq_len_input, max_seq_len_target,
                      dropout=0.1):
    """Create a Transformer body and its matching output projection.

    Args:
        N: Number of stacked encoder blocks and of decoder blocks.
        d_model: Model width.
        num_heads: Attention heads per block.
        d_ffn: Feed-forward hidden width.
        input_vocab_size: Source vocabulary size.
        target_vocab_size: Target vocabulary size (also the projection's output width).
        max_seq_len_input: Longest supported source sequence.
        max_seq_len_target: Longest supported target sequence.
        dropout: Dropout rate used throughout.

    Returns:
        ``(transformer, projection_layer)``. The projection is not wired into
        the transformer; presumably the caller applies it to the decoder output.
    """
    # Tuple elements are evaluated left to right: Transformer first, then the projection.
    return (
        Transformer(N, d_model, num_heads, d_ffn,
                    input_vocab_size, target_vocab_size,
                    max_seq_len_input, max_seq_len_target,
                    dropout=dropout),
        ProjectionLayer(target_vocab_size),
    )


In [53]:
N = 6  # number of stacked encoder blocks and of decoder blocks

transformer, projection = build_transformer(
    N=N,
    # "base" model sizes from Vaswani et al. (2017)
    d_model=512, num_heads=8, d_ffn=2048, dropout=0.1,
    # source / target vocabulary sizes and maximum sequence lengths
    input_vocab_size=8500, max_seq_len_input=100,
    target_vocab_size=8000, max_seq_len_target=100,
)
