In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Embedding, Dense, LayerNormalization, Dropout
from tensorflow.keras import Model

## Multi-Head Self-Attention Layer

In [None]:
class MultiHeadSelfAttention(Layer):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all projected from the same input with
    bias-free ``Dense`` layers, split into ``num_heads`` heads of width
    ``embed_dim // num_heads``, attended independently, merged back to
    ``embed_dim`` and passed through a bias-free output projection.

    Args:
        embed_dim: model width E; must be divisible by ``num_heads``.
        num_heads: number of attention heads H.
        **kwargs: forwarded to ``Layer.__init__`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, embed_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # projections
        self.wq = Dense(embed_dim, use_bias=False)
        self.wk = Dense(embed_dim, use_bias=False)
        self.wv = Dense(embed_dim, use_bias=False)
        self.out = Dense(embed_dim, use_bias=False)

    def split_heads(self, x):
        """Reshape [B, T, E] -> [B, H, T, D], where D = head_dim."""
        B = tf.shape(x)[0]
        T = tf.shape(x)[1]
        x = tf.reshape(x, [B, T, self.num_heads, self.head_dim])
        return tf.transpose(x, [0, 2, 1, 3])

    def combine_heads(self, x):
        """Inverse of ``split_heads``: [B, H, T, D] -> [B, T, E]."""
        B = tf.shape(x)[0]
        T = tf.shape(x)[2]
        x = tf.transpose(x, [0, 2, 1, 3])
        return tf.reshape(x, [B, T, self.embed_dim])

    def call(self, x, mask=None, training=False):
        """Attend over ``x`` and return a tensor of the same [B, T, E] shape.

        Args:
            x: input tensor; ``split_heads`` reshapes it as [B, T, E].
            mask: optional tensor broadcastable to [B, H, T, T]; non-zero
                entries mark (query, key) pairs that must not be attended to.
            training: accepted for Keras API compatibility; unused here
                (this layer applies no attention dropout).

        Returns:
            Output projection of the attended context, shape [B, T, E].
        """
        q = self.split_heads(self.wq(x))
        k = self.split_heads(self.wk(x))
        v = self.split_heads(self.wv(x))

        # scaled dot-product attention logits: [B, H, T, T]
        attn_logits = tf.matmul(q, k, transpose_b=True)
        attn_logits = attn_logits * (self.head_dim ** -0.5)

        if mask is not None:
            # Overwrite masked logits with the most negative finite value of
            # the logits' dtype rather than computing `logits - 1e9 * mask`.
            # 1e9 is not representable in float16 (max ~65504). It overflows
            # to inf, and inf * 0 = NaN then corrupts the *unmasked* positions
            # too. In float32 the softmax result is the same either way.
            neg_fill = tf.cast(attn_logits.dtype.min, attn_logits.dtype)
            attn_logits = tf.where(tf.cast(mask, tf.bool), neg_fill, attn_logits)

        attn_weights = tf.nn.softmax(attn_logits, axis=-1)
        context = tf.matmul(attn_weights, v)    # [B, H, T, D]
        context = self.combine_heads(context)   # [B, T, E]
        return self.out(context)

## Feed-Forward Network

In [None]:
class FeedForwardNetwork(Layer):
    """Position-wise feed-forward sub-layer.

    Two stacked ``Dense`` layers: an expansion to ``dff`` units with a GELU
    activation, followed by a linear projection back to ``embed_dim`` units.

    Args:
        embed_dim: output width (the model width).
        dff: hidden (expansion) width.
    """

    def __init__(self, embed_dim, dff):
        super().__init__()
        # expansion with GELU non-linearity
        self.dense1 = Dense(dff, activation=tf.nn.gelu)
        # linear projection back to the model width
        self.dense2 = Dense(embed_dim)

    def call(self, x, training=False):
        """Apply ``dense1`` then ``dense2`` to ``x``.

        ``training`` is accepted for API symmetry with the other layers; it has
        no effect because this sub-layer contains no dropout.
        """
        out = x
        for layer in (self.dense1, self.dense2):
            out = layer(out)
        return out

## Transformer Block

In [None]:
class TransformerBlock(Layer):
    """Pre-norm Transformer block (self-attention + feed-forward).

    Each sub-layer normalises its input first and adds its dropped-out output
    back onto the residual stream:

        h   = x + Dropout(MHA(LN1(x)))
        out = h + Dropout(FFN(LN2(h)))

    Args:
        embed_dim: model width.
        num_heads: number of attention heads.
        dff: hidden width of the feed-forward sub-layer.
        dropout_rate: dropout probability applied to both sub-layer outputs.
    """

    def __init__(self, embed_dim, num_heads, dff, dropout_rate=0.1):
        super().__init__()
        self.mha = MultiHeadSelfAttention(embed_dim, num_heads)
        self.ffn = FeedForwardNetwork(embed_dim, dff)
        # LayerNorms are applied *before* each sub-layer (pre-norm variant)
        self.norm1 = LayerNormalization(epsilon=1e-5)
        self.norm2 = LayerNormalization(epsilon=1e-5)
        self.drop1 = Dropout(dropout_rate)
        self.drop2 = Dropout(dropout_rate)

    def call(self, x, mask=None, training=False):
        """Run attention then feed-forward, each wrapped in a residual add.

        ``mask`` is passed through unchanged to the attention sub-layer;
        ``training`` toggles the dropout layers and is forwarded to the
        sub-layers.
        """
        # attention sub-layer: LN -> MHA -> dropout -> residual add
        attn_in = self.norm1(x)
        attn_out = self.mha(attn_in, mask=mask, training=training)
        residual = x + self.drop1(attn_out, training=training)

        # feed-forward sub-layer: LN -> FFN -> dropout -> residual add
        ffn_in = self.norm2(residual)
        ffn_out = self.ffn(ffn_in, training=training)
        return residual + self.drop2(ffn_out, training=training)

## GPT Architecture


In [None]:
class RZY_GPT(Model):
    """GPT-style decoder-only language model.

    Pipeline: token embedding + learned position embedding ->
    ``num_layers`` pre-norm ``TransformerBlock``s under a causal mask ->
    final LayerNorm -> bias-free projection to ``vocab_size`` logits.
    The output projection is a separate ``Dense`` layer (not tied to the
    token embedding).

    Args:
        vocab_size: number of token ids (rows of the token embedding and
            width of the output logits).
        max_length: longest supported sequence; size of the position table.
        embed_dim: model width (default 768).
        num_heads: attention heads per block (default 12).
        dff: feed-forward hidden width (default 3072).
        num_layers: number of Transformer blocks (default 12).
        dropout_rate: dropout probability inside each block (default 0.1).
        **kwargs: forwarded to ``Model.__init__`` (e.g. ``name``).
    """

    def __init__(self, vocab_size, max_length, embed_dim=768, num_heads=12, dff=3072, num_layers=12, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.max_length = max_length

        self.token_emb = Embedding(vocab_size, embed_dim)
        self.pos_emb = Embedding(max_length, embed_dim)

        self.blocks = [TransformerBlock(embed_dim, num_heads, dff, dropout_rate) for _ in range(num_layers)]
        self.norm = LayerNormalization(epsilon=1e-5)
        self.lm_head = Dense(vocab_size, use_bias=False)

    def causal_mask(self, seq_len):
        """Return an int32 [1, 1, T, T] mask with 1 where attention is forbidden.

        ``band_part(ones, -1, 0)`` keeps the lower triangle (including the
        diagonal), i.e. the allowed positions; subtracting from 1 leaves 1s
        strictly above the diagonal (future tokens).
        """
        m = 1 - tf.linalg.band_part(tf.ones([seq_len, seq_len], dtype=tf.int32), -1, 0)
        # [1,1,T,T] for broadcasting to [B,H,T,T]
        return m[tf.newaxis, tf.newaxis, :, :]

    def call(self, x, training=False):
        """Compute next-token logits.

        Args:
            x: token ids; assumed int32 of shape [B, T] (the notebook feeds an
                int32 ``Input``). T must not exceed ``max_length``.
            training: enables dropout inside the Transformer blocks.

        Returns:
            Logits of shape [B, T, vocab_size].
        """
        T = tf.shape(x)[1]

        # The position table has only `max_length` rows. An out-of-range id in
        # an Embedding lookup raises on CPU but silently yields zeros on GPU,
        # so reject over-long sequences explicitly.
        tf.debugging.assert_less_equal(
            T, self.max_length,
            message="sequence length exceeds max_length of the position embedding",
        )

        tok = self.token_emb(x)                    # [B, T, E]
        pos = self.pos_emb(tf.range(T))            # [T, E]; broadcasts over B
        h = tok + pos

        mask = self.causal_mask(T)                 # [1,1,T,T]
        for blk in self.blocks:
            h = blk(h, mask=mask, training=training)

        h = self.norm(h)
        logits = self.lm_head(h)                   # [B, T, V]
        return logits

In [None]:
# Model configuration. 50257 matches GPT-2's BPE vocabulary size and 1024 its
# context length — presumably the intended tokenizer; TODO confirm.
VOCAB_SIZE = 50257
MAX_LENGTH = 1024

# Wrap the subclassed model in a functional Model fed by a fixed-length int32
# token input, then print the layer/parameter summary.
inputs = tf.keras.Input(shape=(MAX_LENGTH,), dtype=tf.int32)
gpt = RZY_GPT(vocab_size=VOCAB_SIZE, max_length=MAX_LENGTH)
logits = gpt(inputs)
model = tf.keras.Model(inputs=inputs, outputs=logits)
model.summary()