In [None]:
# Import of all necessary libraries
!pip install datasets
!pip install transformers
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import LayerNormalization, Dropout, Dense, Embedding
from datasets import load_dataset, load_from_disk
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.callbacks import CSVLogger
from google.colab import files
import os
from keras.saving import register_keras_serializable
from transformers import AutoTokenizer
from tensorflow.keras.layers import GroupQueryAttention
import datasets

In [None]:
# Use TensorFlow mixed precision
# Build a 'mixed_float16' policy and install it as the global Keras policy.
from tensorflow.keras.mixed_precision import set_global_policy, Policy

policy = Policy('mixed_float16')
set_global_policy(policy)
# Compute dtype reported by the global policy that was just installed.
compute_dtype = tf.keras.mixed_precision.global_policy().compute_dtype

In [None]:
# Parameters:
# Model hyper-parameters (passed to DecoderOnlyTransformer further down).
d_model = 960          # width of the token embeddings and of each decoder layer's output
num_layers = 16        # number of DecoderLayer blocks stacked in the Decoder
num_heads = 16         # query heads given to GroupQueryAttention
d_ff = 3072            # hidden width of the feed-forward Dense inside each decoder layer
vocab_size = 30000     # embedding rows and size of the final logits layer
max_seq_len = 128      # padded sequence length produced by preprocessing; also the positional-encoding length
key_heads = 8          # key/value heads given to GroupQueryAttention
dropout_rate = 0.1
# Training hyper-parameters.
batch_size = 128       # batch size used when converting to tf.data datasets
epochs = 2             # passed to model.fit
# Google Drive locations.
# NOTE(review): train_dataset_path / val_dataset_path are not referenced in the visible cells.
train_dataset_path = "/content/drive/My Drive/MyModel/train_dataset"
val_dataset_path = "/content/drive/My Drive/MyModel/val_dataset"
drive_path = "/content/drive/My Drive/MyModel"  # tokenizer files, checkpoints and the final model

In [None]:
# Load a pretrained tokenizer from Drive
# `drive_path` is the directory the tokenizer-training cell writes to with `save_pretrained`.
tokenizer = AutoTokenizer.from_pretrained(drive_path)
# Sanity check against the `vocab_size` parameter (30000) used to size the model.
print(tokenizer.vocab_size)

In [None]:
# Load a subset of OpenWebText and save it in Drive
from datasets import Dataset, DatasetDict
import random

# Stream the corpus and cap the shuffled stream at 2M documents; `take` already
# enforces the cap, so no manual counter/break is needed while materialising it.
owt = load_dataset("openwebtext", split="train", streaming=True, trust_remote_code=True)
dataset = owt.shuffle(buffer_size=10_000, seed=42).take(2000000)

samples = list(dataset)

# Keep the 1M shortest documents (by character count of the "text" field).
samples.sort(key=lambda x: len(x["text"]))
samples = samples[:1000000]

# Seeded shuffle so the train/validation split is reproducible across runs
# (same seed as the streaming shuffle above).
random.Random(42).shuffle(samples)

# Split sizes are derived from the number of documents actually collected so the
# two splits can never overlap; with the full 1M documents this is 900k / 100k.
val_size = min(100000, len(samples) // 10)
train_size = min(900000, len(samples) - val_size)

tsamples = samples[:train_size]
vsamples = samples[len(samples) - val_size:]

train = Dataset.from_dict({
    "text": [sample["text"] for sample in tsamples]
})

validation = Dataset.from_dict({
    "text": [sample["text"] for sample in vsamples]
})

dataset_dict = DatasetDict({
    "train": train,
    "validation": validation
})

dataset_dict.save_to_disk('/content/drive/MyDrive/my_dataset')


In [None]:
# Code used to train the Tokenizer
# These classes are used below but were not imported anywhere in the notebook.
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.trainers import BpeTrainer
from transformers import PreTrainedTokenizerFast

# Training corpus: the raw text of every document kept by the subset cell.
texts = [sample["text"] for sample in samples]

# Raw byte-level BPE tokenizer. It gets its own name so that the global
# `tokenizer` (which preprocessing reads `bos_token_id` etc. from) is never
# left pointing at an object without those attributes.
# NOTE(review): no byte-level decoder is attached here; confirm that decoded
# text looks as expected before using this tokenizer for generation.
bpe_tokenizer = Tokenizer(BPE())
bpe_tokenizer.pre_tokenizer = ByteLevel()

# "<pad>" is listed first; the padding mask and the masked loss/accuracy compare
# token ids against 0, so they presumably rely on "<pad>" receiving id 0 —
# verify with `tokenizer.pad_token_id` after training.
special_tokens = ["<pad>", "<bos>", "<eos>", "<sep>", "<user>", "<assistant>", "<unk>", "<context>"]

# Use the shared `vocab_size` parameter so the tokenizer and the model's
# embedding/output sizes cannot drift apart.
trainer = BpeTrainer(
    vocab_size=vocab_size,
    special_tokens=special_tokens
)

bpe_tokenizer.train_from_iterator(texts, trainer)

bpe_tokenizer.save(drive_path + '/tokenizer.json')

# Wrap the trained tokenizer so it exposes the Hugging Face interface
# (`bos_token_id`, `eos_token_id`, `pad_token_id`, `encode(..., add_special_tokens=...)`).
hf_tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=bpe_tokenizer,
    bos_token="<bos>",
    eos_token="<eos>",
    pad_token="<pad>",
    unk_token="<unk>",
)

hf_tokenizer.save_pretrained(drive_path)

# Later cells use the global `tokenizer`; make it the Hugging Face wrapper.
tokenizer = hf_tokenizer

In [None]:
# Load the subset from Drive
# Same path the subset-building cell passes to `save_to_disk`.
dataset = load_from_disk('/content/drive/MyDrive/my_dataset')
train_dataset = dataset['train']
val_dataset = dataset['validation']
# Print both splits as a quick check of what was loaded.
print(train_dataset)
print(val_dataset)

In [None]:
# Preprocessing logic
import re

def truncate_tokens(tokens, max_len):
    """Return the leading ``max_len`` items of ``tokens`` using plain slice semantics."""
    head = tokens[0:max_len]
    return head

def split_and_truncate(text, tokenizer, max_seq_len):
    """Split ``text`` into sentence-aligned token chunks of at most ``max_seq_len`` items.

    The text is cut into sentences, each sentence is encoded with
    ``tokenizer.encode(..., add_special_tokens=False)``, and sentences are packed
    greedily into chunks. Every returned chunk starts with
    ``tokenizer.bos_token_id`` and ends with ``tokenizer.eos_token_id``.

    Args:
        text: Raw document text.
        tokenizer: Object exposing ``bos_token_id``, ``eos_token_id`` and
            ``encode(text, add_special_tokens=False)`` returning a list.
        max_seq_len: Maximum chunk length, including the BOS and EOS items.

    Returns:
        A list of chunks (lists of token ids); empty when ``text`` contains no
        non-whitespace sentence.
    """
    bos_token_id = tokenizer.bos_token_id
    eos_token_id = tokenizer.eos_token_id

    # A sentence boundary is a whitespace character that follows '.', '?' or '!',
    # unless the preceding characters look like "e.g." or "Mr.".
    pattern = r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s'

    # Turn line breaks (and the whitespace around them) into a single space.
    # Deleting them outright would glue the last word of one line to the first
    # word of the next and hide sentence boundaries that fall on a line break.
    flattened = re.sub(r'\s*\n\s*', ' ', text)
    sentences = re.split(pattern, flattened)

    chunks = []
    current_chunk = [bos_token_id]

    for sentence in sentences:
        if not sentence.strip():
            continue
        sentence_tokens = tokenizer.encode(sentence, add_special_tokens=False)
        # A single sentence may use at most max_seq_len - 2 slots, leaving room
        # for the BOS and EOS items of its chunk.
        sentence_tokens = sentence_tokens[:max_seq_len - 2]
        if len(current_chunk) + len(sentence_tokens) <= max_seq_len - 1:
            # Still fits (one slot is reserved for the closing EOS).
            current_chunk.extend(sentence_tokens)
        elif len(current_chunk) > 1:
            # Close the current chunk and start a new one with this sentence.
            current_chunk.append(eos_token_id)
            chunks.append(current_chunk)
            current_chunk = [bos_token_id] + sentence_tokens

    # Flush the trailing chunk unless it only holds the BOS item.
    if len(current_chunk) > 1:
        current_chunk.append(eos_token_id)
        chunks.append(current_chunk)

    return chunks


def preprocess_function(examples, tokenizer, max_seq_len):
    """Turn a batch of raw texts into padded ``input_ids`` / ``labels`` rows.

    Each text in ``examples["text"]`` is chunked with ``split_and_truncate``;
    every chunk yields one row. ``input_ids`` is the chunk right-padded with
    ``tokenizer.pad_token_id`` to ``max_seq_len``; ``labels`` is the chunk
    without its first item, padded to the same length.

    Returns:
        A dict with the keys ``"input_ids"`` and ``"labels"``, each a list of rows.
    """
    pad_id = tokenizer.pad_token_id
    all_inputs, all_labels = [], []

    for document in examples["text"]:
        for chunk in split_and_truncate(document, tokenizer, max_seq_len):
            n_pad = max_seq_len - len(chunk)
            all_inputs.append(chunk + [pad_id] * n_pad)
            # Labels drop the first item, so they need one extra pad to reach max_seq_len.
            all_labels.append(chunk[1:] + [pad_id] * (n_pad + 1))

    return {"input_ids": all_inputs, "labels": all_labels}


# Apply preprocess_function to both splits in batched mode. The "text" column is
# removed, leaving the "input_ids" and "labels" columns the function returns.
train_dataset = train_dataset.map(preprocess_function, batched=True, remove_columns=["text"], fn_kwargs={"tokenizer": tokenizer, "max_seq_len": max_seq_len})
val_dataset = val_dataset.map(preprocess_function, batched=True, remove_columns=["text"], fn_kwargs={"tokenizer": tokenizer, "max_seq_len": max_seq_len})


In [None]:
# Saving the preprocessed dataset in Drive
from datasets import Dataset, DatasetDict

# NOTE(review): `train_subset` / `val_subset` are not assigned in any visible cell;
# the preprocessing cell produces `train_dataset` / `val_dataset`. The loading cell
# refers to a 'seq_len' column that the visible preprocessing does not create, so
# presumably a removed cell derived these subsets — confirm and restore that step,
# otherwise this cell fails with a NameError on a fresh kernel.
tokenized_dataset_dict = DatasetDict({
    "train": train_subset,
    "validation": val_subset
})

# Persist both splits in one directory; the loading cell reads this same path.
tokenized_dataset_dict.save_to_disk('/content/drive/MyDrive/my_tokenized_dataset')


In [None]:
# Loading preprocessed dataset
dataset = load_from_disk('/content/drive/MyDrive/my_tokenized_dataset')
train_dataset = dataset['train']
val_dataset = dataset['validation']

# `remove_columns` returns a new dataset instead of modifying in place, so the
# result has to be rebound. The column is only dropped when it is present, which
# keeps the cell usable for datasets saved without a 'seq_len' column.
if 'seq_len' in train_dataset.column_names:
    train_dataset = train_dataset.remove_columns('seq_len')
if 'seq_len' in val_dataset.column_names:
    val_dataset = val_dataset.remove_columns('seq_len')

In [None]:
# Converting the Hugging Face dataset to a TensorFlow dataset
from transformers import DefaultDataCollator

# Collator shared by both splits.
data_collator = DefaultDataCollator(return_tensors="tf")

# Training split: "input_ids" as features, "labels" as targets, shuffled.
train_tf_dataset = train_dataset.to_tf_dataset(
    columns="input_ids",
    label_cols="labels",
    batch_size=batch_size,
    shuffle=True,
    collate_fn=data_collator
)

# Validation split: same columns and batch size, but not shuffled.
val_tf_dataset = val_dataset.to_tf_dataset(
    columns="input_ids",
    label_cols="labels",
    batch_size=batch_size,
    shuffle=False,
    collate_fn=data_collator
)

In [None]:
# Functions to create look-ahead and padding masks
@register_keras_serializable(package='CustomTransformer', name='CreateLookAheadMask')
def create_look_ahead_mask(seq_len, batch_size=1):
    """Build a boolean lower-triangular mask of shape (batch_size, seq_len, seq_len).

    Entry [b, i, j] is True when j <= i, i.e. position i may look at itself and
    at earlier positions only.
    """
    lower_triangle = tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
    # Add a leading batch axis and repeat the same mask for every batch element.
    batched = tf.tile(lower_triangle[tf.newaxis, ...], [batch_size, 1, 1])
    return tf.cast(batched, tf.bool)

@register_keras_serializable(package='CustomTransformer', name='PaddingMask')
def create_padding_mask(x):
    """Return a boolean mask that is True wherever ``x`` differs from 0.

    A new axis is inserted at position 1 so the result can broadcast against a
    per-position mask that has an extra axis there.
    """
    is_token = tf.cast(tf.math.not_equal(x, 0), tf.bool)
    return tf.expand_dims(is_token, axis=1)

In [None]:
# Creation of a custom Decoder Layer with keras
@register_keras_serializable(package='CustomTransformer', name='DecoderLayer')
class DecoderLayer(tf.keras.layers.Layer):
    """Pre-norm decoder block: grouped-query self-attention followed by a GELU feed-forward net.

    Both sub-blocks normalise their input first, apply dropout to their output
    and add it back to the un-normalised input (residual connection).

    Args:
        d_model: Width of the layer input/output.
        num_heads: Number of query heads.
        key_heads: Number of key/value heads.
        d_ff: Hidden width of the feed-forward network.
        dropout_rate: Dropout rate used inside attention and after each sub-block.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, d_model, num_heads, key_heads, d_ff, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.num_heads = num_heads
        self.key_heads = key_heads
        self.d_ff = d_ff
        self.dropout_rate = dropout_rate
        # Derived values; kept as attributes but intentionally not serialised.
        self.query_head_dim = d_model // num_heads
        self.key_value_head_dim = d_model // key_heads
        self.gqa = GroupQueryAttention(
            head_dim=self.query_head_dim,
            num_query_heads=num_heads,
            num_key_value_heads=key_heads,
            dropout=dropout_rate,
        )
        self.ffn = tf.keras.Sequential([
            tf.keras.layers.Dense(d_ff, activation="gelu", kernel_initializer="glorot_uniform"),
            tf.keras.layers.Dense(d_model, kernel_initializer="glorot_uniform"),
        ])
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x, training=False, look_ahead_mask=None):
        """Run the block on ``x``; ``look_ahead_mask`` is passed to attention as ``attention_mask``."""
        # Follow the layer's own dtype policy instead of assuming float16, so the
        # block also works when the global policy is not 'mixed_float16'.
        x = tf.cast(x, dtype=self.compute_dtype)
        x_norm = self.layernorm1(x)
        attn_output = self.gqa(query=x_norm, key=x_norm, value=x_norm, attention_mask=look_ahead_mask, training=training)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = x + attn_output
        ffn_output = self.ffn(self.layernorm2(out1))
        ffn_output = self.dropout2(ffn_output, training=training)
        return out1 + ffn_output

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = super().get_config()
        # Only constructor arguments belong here: any extra key is forwarded to
        # __init__ on deserialisation and rejected as an unknown argument.
        config.update({
            "d_model": self.d_model,
            "num_heads": self.num_heads,
            "d_ff": self.d_ff,
            "key_heads": self.key_heads,
            "dropout_rate": self.dropout_rate,
        })
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the layer, ignoring the derived keys written by older configs."""
        config = dict(config)
        config.pop("query_head_dim", None)
        config.pop("key_value_head_dim", None)
        return super().from_config(config)

In [None]:
# Declaring a stack of Decoder Layers along with the embedding layer and positional encoding
@register_keras_serializable(package='CustomTransformer', name='Decoder')
class Decoder(tf.keras.layers.Layer):
    """Token embedding + sinusoidal positional encoding + a stack of ``DecoderLayer`` blocks.

    Args:
        num_layers: Number of decoder blocks.
        d_model: Embedding width and width of every block.
        num_heads: Query heads per block.
        key_heads: Key/value heads per block.
        d_ff: Feed-forward hidden width per block.
        vocab_size: Number of rows in the embedding table.
        max_seq_len: Number of positions in the positional-encoding table.
        dropout_rate: Dropout rate after the embedding and inside every block.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, num_layers, d_model, num_heads, key_heads, d_ff, vocab_size, max_seq_len, dropout_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.key_heads = key_heads
        self.dropout_rate = dropout_rate
        self.embedding = tf.keras.layers.Embedding(vocab_size, d_model)
        self.pos_encoding = self.positional_encoding(max_seq_len, d_model)
        self.dec_layers = [DecoderLayer(d_model, num_heads, key_heads, d_ff, dropout_rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def positional_encoding(self, position, d_model):
        """Return a (1, position, d_model) table: sines of the even columns' angles, then cosines of the odd ones."""
        angle_rads = self.get_angles(
            tf.range(position, dtype=tf.float32)[:, tf.newaxis],
            tf.range(d_model, dtype=tf.float32)[tf.newaxis, :],
            d_model
        )
        sines = tf.sin(angle_rads[:, 0::2])
        cosines = tf.cos(angle_rads[:, 1::2])
        pos_encoding = tf.concat([sines, cosines], axis=-1)
        # Store the table in the layer's compute dtype (float16 under
        # 'mixed_float16') rather than a hard-coded float16.
        return tf.cast(pos_encoding[tf.newaxis, ...], dtype=self.compute_dtype)

    @staticmethod
    def get_angles(pos, i, d_model):
        """Return pos / 10000^(2 * (i // 2) / d_model) as float32."""
        angle_rates = 1 / tf.pow(10000, (2 * (i // 2)) / tf.cast(d_model, tf.float32))
        return tf.cast(pos, tf.float32) * angle_rates

    def call(self, x, training=None, look_ahead_mask=None):
        """Embed the token ids in ``x``, add positions and run every decoder block in order."""
        seq_len = tf.shape(x)[1]
        x = self.embedding(x)
        # Scale by sqrt(d_model) in whatever dtype the embedding produced, so the
        # multiplication never mixes dtypes regardless of the active policy.
        x *= tf.math.sqrt(tf.cast(self.d_model, x.dtype))
        x += tf.cast(self.pos_encoding[:, :seq_len, :], x.dtype)
        x = self.dropout(x, training=training)
        for layer in self.dec_layers:
            x = layer(x, training=training, look_ahead_mask=look_ahead_mask)
        return x

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = super().get_config()
        config.update({
            "num_layers": self.num_layers,
            "d_model": self.d_model,
            "num_heads": self.num_heads,
            "key_heads": self.key_heads,
            "d_ff": self.d_ff,
            "vocab_size": self.vocab_size,
            "max_seq_len": self.max_seq_len,
            "dropout_rate": self.dropout_rate,
        })
        return config


In [None]:
# Declaring the decoder along with the compute of masks and the last ffn
@register_keras_serializable(package='CustomTransformer', name='DecoderOnlyTransformer')
class DecoderOnlyTransformer(tf.keras.models.Model):
    """Decoder stack plus a final Dense projection to vocabulary logits.

    The attention mask is built inside ``call`` from the token ids: a look-ahead
    mask combined (logical AND) with a padding mask.
    """

    def __init__(self, num_layers, d_model, num_heads, d_ff, vocab_size, max_seq_len, key_heads, dropout_rate=0.1, **kwargs):
        super(DecoderOnlyTransformer, self).__init__(**kwargs)
        # Sub-layers (created in the same order as before: decoder, then output projection).
        self.decoder = Decoder(
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            key_heads=key_heads,
            d_ff=d_ff,
            vocab_size=vocab_size,
            max_seq_len=max_seq_len,
            dropout_rate=dropout_rate,
        )
        self.final_layer = tf.keras.layers.Dense(vocab_size)
        # Hyper-parameters kept for get_config.
        self.num_layers = num_layers
        self.d_model = d_model
        self.num_heads = num_heads
        self.key_heads = key_heads
        self.d_ff = d_ff
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.dropout_rate = dropout_rate

    def call(self, x, training=None):
        """Return float32 logits for the token ids in ``x``.

        ``x`` may be the id tensor itself, a tuple whose first item is the id
        tensor, or a dict holding it under ``"input_ids"``.
        """
        if isinstance(x, tuple):  # Check if inputs are tuples
            x = x[0]
        elif isinstance(x, dict):  # Check if inputs are dictionaries
            x = x["input_ids"]

        n_tokens = tf.shape(x)[1]
        n_rows = tf.shape(x)[0]
        causal_mask = create_look_ahead_mask(n_tokens, n_rows)
        pad_mask = create_padding_mask(x)
        # A position is attendable only when both masks allow it.
        attention_mask = tf.logical_and(causal_mask, pad_mask)

        hidden = self.decoder(x, training=training, look_ahead_mask=attention_mask)
        return tf.cast(self.final_layer(hidden), tf.float32)

    def get_config(self):
        """Return the base config extended with this model's constructor arguments."""
        hyper_params = {
            "num_layers": self.num_layers,
            "d_model": self.d_model,
            "num_heads": self.num_heads,
            "key_heads": self.key_heads,
            "d_ff": self.d_ff,
            "vocab_size": self.vocab_size,
            "max_seq_len": self.max_seq_len,
            "dropout_rate": self.dropout_rate,
        }
        config = super(DecoderOnlyTransformer, self).get_config()
        config.update(hyper_params)
        return config


In [None]:
# Declaration of the CallBacks
from tensorflow.keras.callbacks import TensorBoard
import datetime

# Load the TensorBoard notebook extension and point it at the logs/fit directory.
%load_ext tensorboard
%tensorboard --logdir logs/fit

# Each run logs into its own timestamped sub-directory of logs/fit.
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

tensorboard_callback = TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    write_graph=True,
    write_images=True,
    update_freq=200,
    profile_batch=0,
)

# Checkpoint file under drive_path; configured to monitor 'val_loss' with
# save_best_only=True and save_weights_only=False.
checkpoint_callback = ModelCheckpoint(
    filepath=drive_path + '/saved_model_checkpoints/modelcheckpoint.keras',
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=False,
    verbose=1
)

class DebugNanCallback(tf.keras.callbacks.Callback):
    """Stop training as soon as a batch reports a NaN loss."""

    def on_train_batch_end(self, batch, logs=None):
        """Check the batch's "loss" entry and request a stop when it is NaN."""
        # Look the loss up defensively: `logs` may be None and may lack a "loss"
        # key, and passing None to tf.math.is_nan raises.
        loss = None if logs is None else logs.get("loss")
        if loss is not None and tf.math.is_nan(loss):
            print(f"NaN loss encountered at batch {batch}.")
            self.model.stop_training = True
debugnancallback = DebugNanCallback()

callbacks = [
    checkpoint_callback,
    debugnancallback,
]

In [None]:
# Declaration of a custom learning rate, masked loss function, masked accuracy, model and compilation of the model.
from tensorflow.keras import losses
from tensorflow.keras.optimizers.schedules import LearningRateSchedule

@register_keras_serializable(package='CustomTransformer', name='warmupdecay')
class WarmupCosineDecay(LearningRateSchedule):
    """Linear warm-up to ``initial_lr`` followed by ``CosineDecayRestarts``.

    While ``step < warmup_steps`` the rate is ``initial_lr * step / warmup_steps``;
    afterwards the wrapped cosine schedule is evaluated at ``step - warmup_steps``.
    """

    def __init__(self, warmup_steps, initial_lr, first_decay_steps, t_mul=1.5, m_mul=1.0, alpha=0.0):
        super().__init__()
        self.warmup_steps = warmup_steps
        self.initial_lr = initial_lr
        self.first_decay_steps = first_decay_steps
        self.t_mul = t_mul
        self.m_mul = m_mul
        self.alpha = alpha

        # Schedule used once the warm-up phase is over.
        self.cosine_decay = tf.keras.optimizers.schedules.CosineDecayRestarts(
            initial_learning_rate=initial_lr,
            first_decay_steps=first_decay_steps,
            t_mul=t_mul,
            m_mul=m_mul,
            alpha=alpha,
        )

    def __call__(self, step):
        """Return the learning rate for optimizer step ``step``."""
        lr_per_step = self.initial_lr / tf.cast(self.warmup_steps, tf.float32)
        step_as_float = tf.cast(tf.convert_to_tensor(step), tf.float32)
        warmup_lr = lr_per_step * step_as_float

        def _after_warmup():
            # The cosine schedule counts steps from the end of the warm-up.
            return self.cosine_decay(step - self.warmup_steps)

        in_warmup = step < self.warmup_steps
        return tf.cond(in_warmup, lambda: warmup_lr, _after_warmup)

    def get_config(self):
        """Return the constructor arguments as a plain dict."""
        return {
            "warmup_steps": self.warmup_steps,
            "initial_lr": self.initial_lr,
            "first_decay_steps": self.first_decay_steps,
            "t_mul": self.t_mul,
            "m_mul": self.m_mul,
            "alpha": self.alpha,
        }

@register_keras_serializable(package='CustomTransformer', name='masked_crossentropy')
def masked_crossentropy(y_true, y_pred):
    """Sparse categorical cross-entropy (from logits) averaged over non-zero labels.

    Positions whose label equals 0 are excluded from both the sum and the count.
    When every label is 0 the result is 0 instead of NaN.
    """
    mask = tf.math.logical_not(tf.math.equal(y_true, 0))
    loss = losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)
    mask = tf.cast(mask, dtype=loss.dtype)
    loss *= mask
    # divide_no_nan returns 0 when the denominator is 0 (no unmasked labels),
    # so an all-padding batch cannot inject NaN into training.
    return tf.math.divide_no_nan(tf.reduce_sum(loss), tf.reduce_sum(mask))

@register_keras_serializable(package='CustomTransformer', name='maskedaccuracy')
def masked_accuracy(y_true, y_pred):
    """Sparse categorical accuracy averaged over non-zero labels.

    Positions whose label equals 0 are excluded from both the sum and the count.
    When every label is 0 the result is 0 instead of NaN.
    """
    mask = tf.math.logical_not(tf.math.equal(y_true, 0))
    acc = tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred)
    mask = tf.cast(mask, dtype=acc.dtype)
    acc *= mask
    # divide_no_nan returns 0 when the denominator is 0 (no unmasked labels).
    return tf.math.divide_no_nan(tf.reduce_sum(acc), tf.reduce_sum(mask))

# Build the model from the hyper-parameters defined in the parameters cell.
model = DecoderOnlyTransformer(
    num_layers=num_layers,
    d_model=d_model,
    num_heads=num_heads,
    d_ff=d_ff,
    key_heads=key_heads,
    vocab_size=vocab_size,
    max_seq_len=max_seq_len,
    dropout_rate=dropout_rate)

# Learning-rate schedule: 1000 warm-up steps up to 5e-5, then cosine decay with
# restarts whose first period is 8000 steps (t_mul keeps its default of 1.5).
lr_schedule = WarmupCosineDecay(
    warmup_steps=1000,
    initial_lr=5e-5,
    first_decay_steps=8000,
    m_mul=1.0,
    alpha=0.0
)

# AdamW with the schedule above, weight decay and gradient-norm clipping;
# loss and metric are the masked variants defined in this cell.
model.compile(
    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=lr_schedule,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        weight_decay=0.01,
        clipnorm=1.0
    ),
    loss=masked_crossentropy,
    metrics=[
        masked_accuracy
        ]
)

In [None]:
# Training, testing and saving logic
import os

tf.config.optimizer.set_jit(True)

# Create the output directory up front so a Drive/path problem surfaces before
# the training run rather than after it.
os.makedirs(drive_path, exist_ok=True)

history = model.fit(
    train_tf_dataset,
    validation_data=val_tf_dataset,
    epochs=epochs,
    # Use the checkpoint and NaN-guard callbacks collected in `callbacks` as well
    # as TensorBoard; previously that list was declared but never passed in.
    callbacks=[*callbacks, tensorboard_callback]
)

# Save the final model, then report loss/accuracy on the validation set.
model_path = f'{drive_path}/astera_01.keras'
model.save(model_path)
val_loss, val_acc = model.evaluate(val_tf_dataset)
print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")