In [29]:
import tensorflow as tf
from tensorflow.keras.layers import Layer,Embedding,Dense,LayerNormalization,Dropout
from tensorflow.keras import Model

In [30]:
class MultiHeadSelfAttention(Layer):
    """Multi-head scaled dot-product attention.

    q, k and v are each projected by their own Dense(d_model) layer and split
    into ``num_heads`` heads of width ``alt_head = d_model // num_heads``. Each
    head computes softmax(q @ k^T / sqrt(alt_head) + mask * -1e9) @ v. The heads
    are then concatenated back to width ``d_model`` and passed through a final
    Dense projection.

    Args:
        d_model: model width; must be a positive multiple of ``num_heads``.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Without this check, split_heads (which reshapes with -1) can silently
        # fold part of the feature axis into the sequence axis, or fail later
        # with an opaque reshape error inside the graph.
        if num_heads <= 0 or d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.alt_head = d_model // num_heads  # per-head depth
        self.wq = Dense(d_model)
        self.wk = Dense(d_model)
        self.wv = Dense(d_model)

        self.dense = Dense(d_model)  # output projection applied after concatenating heads

    def split_heads(self, x, batch_size):
        """Reshape x to (batch, num_heads, seq, alt_head).

        Assumes x is (batch, seq, d_model) — TODO confirm if other ranks are used.
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.alt_head))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, q, k, v, mask=None):
        """Attend from q over (k, v).

        Args:
            q, k, v: query/key/value inputs (assumed (batch, seq, features)).
            mask: optional float tensor broadcastable to the attention-logit
                shape (batch, num_heads, seq_q, seq_k). Positions where mask is
                1 get a -1e9 logit penalty, so they receive ~0 attention weight.
                Defaults to None (no masking).

        Returns:
            Tensor of shape (batch, seq_q, d_model).
        """
        batch_size = tf.shape(q)[0]
        q = self.split_heads(self.wq(q), batch_size)
        k = self.split_heads(self.wk(k), batch_size)
        v = self.split_heads(self.wv(v), batch_size)

        # (batch, heads, seq_q, seq_k), scaled by sqrt(depth) to keep softmax well-conditioned.
        mat_mul = tf.matmul(q, k, transpose_b=True)
        scaled_attention_logits = mat_mul / tf.math.sqrt(tf.cast(self.alt_head, tf.float32))

        if mask is not None:
            scaled_attention_logits += (mask * -1e9)
        attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)
        scaled_attention = tf.matmul(attention_weights, v)

        # Undo split_heads: (batch, heads, seq, depth) -> (batch, seq, d_model).
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])
        concat_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model))
        output = self.dense(concat_attention)
        return output

In [31]:
class FeedForward(Layer):
    """Position-wise feed-forward sub-layer: Dense(dff, GELU) then Dense(d_model).

    Args:
        d_model: output width (the model dimension).
        dff: hidden width of the GELU expansion layer.
        dropout_rate: rate given to ``self.dropout``.

    NOTE(review): ``self.dropout`` is constructed but never applied in ``call``.
    In this notebook, TransformBlock applies dropout to the FFN output instead.
    Confirm whether hidden-layer dropout was intended.
    """

    def __init__(self, d_model, dff, dropout_rate=0.1):
        super().__init__()
        # Expand to dff with GELU, then project back to d_model.
        self.dense1 = Dense(dff, activation='gelu')
        self.dense2 = Dense(d_model)
        self.dropout = Dropout(dropout_rate)

    def call(self, x):
        hidden = self.dense1(x)
        projected = self.dense2(hidden)
        return projected

In [32]:
class TransformBlock(Layer):
  """One Transformer block with two sub-layers.

  The sub-layers are multi-head self-attention and a position-wise
  feed-forward network. Each is followed by dropout, a residual add, and
  LayerNormalization (post-LN ordering).

  NOTE(review): GPT-2 applies LayerNorm at the *input* of each sub-layer
  (pre-LN). This block normalizes after the residual add instead, and the
  GPT2 model adds one more LayerNorm after the stack. Confirm this is intended.
  """

  def __init__(self,d_model,num_heads,dff,dropout_rate=0.1):
    super().__init__()
    # Sub-layer modules (construction order kept for stable auto-generated names).
    self.att = MultiHeadSelfAttention(d_model, num_heads)
    self.ffn = FeedForward(d_model, dff)
    self.norm1 = LayerNormalization(epsilon=1e-6)
    self.norm2 = LayerNormalization(epsilon=1e-6)
    self.dropout1 = Dropout(dropout_rate)
    self.dropout2 = Dropout(dropout_rate)

  def call(self,x,mask=None):
    # Sub-layer 1: self-attention -- queries, keys and values all come from x.
    attended = self.dropout1(self.att(x, x, x, mask))
    after_attention = self.norm1(x + attended)

    # Sub-layer 2: position-wise feed-forward network.
    transformed = self.dropout2(self.ffn(after_attention))
    return self.norm2(after_attention + transformed)

In [33]:
class GPT2(Model):
  """Decoder-only Transformer language model in the style of GPT-2.

  The forward pass:
    1. Embed token ids and add learned position embeddings.
    2. Apply dropout.
    3. Run ``num_layers`` TransformBlocks under a causal mask.
    4. Layer-normalize.
    5. Project to ``vocab_size`` logits.

  Args:
    vocab_size: number of distinct token ids (rows of the token embedding and
      width of the output projection).
    num_layers: number of stacked TransformBlocks.
    d_model: embedding / hidden width.
    num_heads: attention heads per block.
    dff: hidden width of each feed-forward sub-layer.
    dropout_rate: dropout rate for the embedding dropout and each block.
    max_length: number of learned position embeddings, i.e. the longest input
      sequence the model supports. Defaults to 1024 (GPT-2's context size).
  """

  def __init__(self,vocab_size,num_layers,d_model,num_heads,dff,dropout_rate=0.1,max_length=1024):
    super().__init__()
    self.max_length = max_length
    self.token_emb = Embedding(vocab_size, d_model)
    # One learned vector per *position*, so the table is sized by the maximum
    # sequence length, not by the vocabulary size.
    self.pos_emb = Embedding(max_length, d_model)

    # Attribute name kept unchanged for compatibility with code that accesses it.
    self.TransformBlock = [TransformBlock(d_model, num_heads, dff, dropout_rate) for _ in range(num_layers)]
    self.norm = LayerNormalization(epsilon=1e-6)
    self.dropout = Dropout(dropout_rate)
    self.out = Dense(vocab_size)

  def create_causal_mask(self, seq_len):
    """Return a (seq_len, seq_len) float mask.

    Entries strictly above the diagonal (future positions) are 1.0 and are
    masked out by the attention layer; entries on or below it are 0.0.
    """
    return 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)

  # Backward-compatible alias for the original, misspelled method name.
  create_casual_mask = create_causal_mask

  def call(self, x):
    """Map integer token ids (batch, seq_len) to logits (batch, seq_len, vocab_size).

    seq_len must not exceed ``max_length``.
    """
    seq_len = tf.shape(x)[1]
    mask = self.create_causal_mask(seq_len)

    token_embeddings = self.token_emb(x)  # (batch, seq_len, d_model)
    # Shape (seq_len, d_model), so it broadcasts across the batch axis of
    # token_embeddings. A trailing new axis would instead broadcast against
    # the batch axis and corrupt the output shape.
    pos_embeddings = self.pos_emb(tf.range(seq_len))
    x = token_embeddings + pos_embeddings
    x = self.dropout(x)

    for transformer in self.TransformBlock:
      x = transformer(x, mask)

    x = self.norm(x)
    return self.out(x)

In [34]:
# Hyper-parameters matching the published GPT-2 small (124M) configuration.
vocab_size = 50257   # size of the token vocabulary
max_length = 1024    # fixed sequence length of the model input below

d_model = 768        # hidden / embedding width
num_heads = 12       # attention heads per block
dff = 3072           # feed-forward hidden width (4 * d_model)
num_layers = 12      # number of TransformBlocks

# Wrap the subclassed GPT2 in a functional Model with a fixed-length int32 token input.
inputs = tf.keras.layers.Input(shape=(max_length,), dtype=tf.int32)
outputs = GPT2(
    vocab_size=vocab_size,
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    dff=dff,
)(inputs)
gpt2 = Model(inputs, outputs)

# NOTE(review): a functional Model is already built from its Input; this
# explicit build call is likely redundant.
gpt2.build(input_shape=(1, max_length))
gpt2.summary()