In [1]:
# !pip install transformers huggingface_hub
# !pip install tf-keras
import os
# Suppress specific TensorFlow warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 3 means to filter out all INFO and WARNING logs
import warnings
# Suppress specific warnings
warnings.filterwarnings('ignore', category=UserWarning, message='.*OUT_OF_RANGE.*')

import sys
sys.path.append('../src')
import random
import re
import pickle
import requests
import json
import pandas as pd
from lr_schedular import CustomSchedule
from transformer_encoder import TransformerEncoderV3  
from positional_encoding import encode_pos_sin_cosine
import seaborn as sns
import numpy as np
import nltk
from datasets import load_dataset
from transformers import BertTokenizer, BertTokenizerFast
import tensorflow as tf
from tensorflow.keras.layers import Embedding, Input, Dense, Dropout, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy, BinaryCrossentropy
from tensorflow.keras.preprocessing.sequence import pad_sequences


from transformers import TFPreTrainedModel, BertConfig
from transformers.utils import ModelOutput

In [2]:
class FlexibleGPTDecoderLayer(tf.keras.layers.Layer):
    """One post-norm GPT-style decoder block: causal self-attention + feed-forward.

    Each sub-layer is followed by dropout, a residual connection and
    LayerNormalization, i.e. ``LayerNorm(x + Dropout(SubLayer(x)))``.

    Args:
        num_heads: Number of attention heads.
        d_model: Model width; also the output width of the feed-forward network.
        dff: Hidden width of the position-wise feed-forward network.
        rate: Dropout rate applied after attention and after the FFN.
        use_bbmha: If True, use the project's ``MultiHeadAttentionV3``;
            otherwise use ``tf.keras.layers.MultiHeadAttention``.
        use_masked_softmax: Forwarded to ``MultiHeadAttentionV3`` only.
        row_mask_to_zero: Forwarded to ``MultiHeadAttentionV3`` only.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_heads, d_model, dff, rate=0.1, use_bbmha=True, use_masked_softmax=False, row_mask_to_zero=False, **kwargs):
        super(FlexibleGPTDecoderLayer, self).__init__(**kwargs)
        self.use_bbmha = use_bbmha
        if use_bbmha:
            # NOTE(review): MultiHeadAttentionV3 is not imported in the visible
            # import cell -- confirm where it is defined, otherwise Restart & Run All
            # fails with NameError.
            # NOTE(review): the keyword below is spelled `row_mask_to_sero`, not
            # `row_mask_to_zero`; verify against MultiHeadAttentionV3's signature --
            # if it is a typo this raises TypeError.
            self.mha = MultiHeadAttentionV3(num_heads=num_heads, d_model=d_model, use_masked_softmax=use_masked_softmax, row_mask_to_sero=row_mask_to_zero)
        else:
            # Keras `key_dim` is the per-head size, so key_dim=d_model gives every
            # head the full model width (d_model // num_heads is the more common choice).
            self.mha = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model)

        # Position-wise feed-forward network: d_model -> dff -> d_model.
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model)
        ])

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, x, training, mask=None):
        """Apply the decoder block.

        Args:
            x: Input tensor whose last dimension is ``d_model`` (required by the
                residual additions); presumably (batch, seq_len, d_model) -- TODO confirm.
            training: Whether dropout is active.
            mask: Optional attention mask forwarded to the attention layer.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if self.use_bbmha:
            attn_output = self.mha(x, x, x, mask=mask, use_causal_mask=True)
        else:
            # tf.keras MultiHeadAttention has no `use_mask` argument (passing it
            # raised TypeError); the mask belongs in `attention_mask`.
            # use_causal_mask=True keeps this branch causal, matching the bbmha branch.
            # NOTE(review): Keras expects attention_mask with 1/True = "may attend".
            # If the caller passes a "1 = blocked" look-ahead mask, invert it first.
            attn_output = self.mha(x, x, attention_mask=mask, use_causal_mask=True, training=training)

        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)  # residual connection after dropout

        # Feed-forward sub-layer with its own residual connection.
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = self.layernorm2(out1 + ffn_output)

        return out2

class FlexibleGPTDecoder(tf.keras.layers.Layer):
    """GPT-style decoder stack: token embedding + positional encoding + N decoder blocks.

    Args:
        num_layers: Number of ``FlexibleGPTDecoderLayer`` blocks to stack.
        d_model: Embedding / model width.
        num_heads: Attention heads per block.
        dff: Feed-forward hidden width per block.
        vocab_size: Vocabulary size of the token embedding table.
        max_position_encoding: Number of positions requested from the positional encoding.
        rate: Dropout rate, used after the embedding and inside every block.
        use_bbmha: Forwarded to every block.
        use_masked_softmax: Forwarded to every block.
        row_mask_to_zero: Forwarded to every block.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, num_layers, d_model, num_heads, dff, vocab_size, max_position_encoding, rate=0.1, use_bbmha=True, use_masked_softmax=False, row_mask_to_zero=False, **kwargs):
        super(FlexibleGPTDecoder, self).__init__(**kwargs)
        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        # NOTE(review): `positional_encoding` is not defined in the visible cells
        # (cell 1 imports `encode_pos_sin_cosine` instead) -- confirm its source.
        # It is sliced as [:, :seq_len, :] in call(), so it must be rank 3.
        self.pos_encoding = positional_encoding(max_position_encoding, d_model)

        # Keyword arguments guard against the differing parameter order between
        # this class and FlexibleGPTDecoderLayer.
        self.dec_layers = [
            FlexibleGPTDecoderLayer(
                num_heads,
                d_model,
                dff,
                rate=rate,
                use_bbmha=use_bbmha,
                use_masked_softmax=use_masked_softmax,
                row_mask_to_zero=row_mask_to_zero,
            )
            for _ in range(num_layers)
        ]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, training=None):
        """Embed token ids and run them through the decoder stack.

        Args:
            x: Token-id tensor; axis 1 is the sequence axis (presumably
                (batch, seq_len) -- TODO confirm).
            training: Whether dropout is active.

        Returns:
            Hidden states of shape (batch_size, target_seq_len, d_model).
        """
        seq_len = tf.shape(x)[1]

        x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
        # Scale embeddings by sqrt(d_model) before adding the positional signal.
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)

        # The mask depends only on seq_len, so build it once rather than once per layer.
        # NOTE(review): `look_ahead_mask` is not defined in the visible cells -- confirm.
        mask = look_ahead_mask(seq_len)
        for dec_layer in self.dec_layers:
            x = dec_layer(x, training=training, mask=mask)

        return x  # (batch_size, target_seq_len, d_model)
