# Imports

In [119]:
# !pip install transformers huggingface_hub
# !pip install tf-keras
import os
# Suppress specific TensorFlow warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 3 means to filter out all INFO and WARNING logs
import warnings
# Suppress specific warnings
warnings.filterwarnings('ignore', category=UserWarning, message='.*OUT_OF_RANGE.*')

import sys
sys.path.append('../src')
import random
import re
import pickle
import requests
import json
import pandas as pd
from lr_schedular import CustomSchedule
from transformer_encoder import TransformerEncoderV3  
from positional_encoding import encode_pos_sin_cosine
import seaborn as sns
import numpy as np
import nltk
from datasets import load_dataset
from transformers import BertTokenizer, BertTokenizerFast
import tensorflow as tf
from tensorflow.keras.layers import Embedding, Input, Dense, Dropout, Layer
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy, BinaryCrossentropy
from tensorflow.keras.preprocessing.sequence import pad_sequences


from transformers import TFPreTrainedModel, BertConfig
from transformers.utils import ModelOutput

# parameters 

In [127]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
vocab_size = tokenizer.vocab_size  # Smaller vocabulary size for simplicity
print(vocab_size)
num_layers = 4  # Fewer layers
d_model = 768  # Smaller dimensionality
num_heads = 4
dff = 3072
segment_size = 2
max_seq_length = 32
max_predictions_per_seq=5
max_position_embeddings=2048
hidden_dropout_prob=0.1
batch_size = 16
COLAB = False
if not COLAB:
      model_path= '/mnt/d/MyDev/attention/transformerlab/bert/models'
      model_weights = 'bert_pretrained_1.weights.h5'
model_full_name = os.path.join(model_path, model_weights)

30522


# Pretraining

## pretrianing data  

### load wiki data from hugging face

In [2]:
from datasets import load_dataset
# Load an example dataset, 'wikipedia' for English, 2020-03-01 version
dataset = load_dataset("wikipedia", "20220301.en", split="train")
print('how the dataset looks:', dataset[0].keys())
num_of_articles = 1000

how the dataset looks: dict_keys(['id', 'url', 'title', 'text'])


In [3]:
first_article = dataset[0]['text']
print('words in first_article:', len(first_article))
print('how the article look:\n',first_article[:1500])
print('..........................................\n........................')
print(first_article[-500:])

words in first_article: 43985
how the article look:
 Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy. Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful. As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.

Humans lived in societies without formal hierarchies long before the establishment of formal states, realms, or empires. With the rise of organised hierarchical bodies, scepticism toward authority also rose. Although traces of anarchist thought are found throughout history, modern anarchism emerged from the Enlightenment. During the latter half of the 19th and the first decades of th

Document and Sentence Segmentation:

Document Boundary: Each Wikipedia article can be treated as a single document. This aligns with the BERT requirement where each document is separated by blank lines.

Sentence Tokenization: Use a sentence tokenizer to convert each paragraph into distinct sentences. This is crucial because BERT's NSP task assumes that two consecutive sentences in the data might be used as training pairs.

### save the data with article separator

In [4]:
def save_articles_with_doc_boundary(dataset_name, config_name, split_name, output_file_path, num_of_articles=1000):
    nltk.download('punkt')
    dataset = load_dataset(dataset_name, config_name, split=split_name)
    articles_to_process = dataset.select(range(num_of_articles))
    with open(output_file_path, 'w', encoding='utf-8') as file:
        for i, article in enumerate(articles_to_process):
            title = article.get('title', f"No Title Available for Article {i}")
            art_id = article.get('id', "No ID")
            art_url = article.get('url', "No URL")
            sentences = nltk.sent_tokenize(article['text'])
            full_article_text = f"ARTICLE-{i}-{art_id}-{art_url}-{title}\n" + '\n'.join(sentences) + '\n\n'
            file.write(full_article_text)
# Call the function to process and save articles
output_file_path = 'wiki_articles_with_seperator.txt'
save_articles_with_doc_boundary('wikipedia', '20220301.en', 'train', output_file_path)

[nltk_data] Downloading package punkt to /home/bhujay/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [5]:
def display_article_lines(file_path, num_lines=5, num_articles=5):    
    with open(file_path, 'r', encoding='utf-8') as file:
        article_count = 0
        line_count = 0
        for line in file:
            if line.startswith('ARTICLE-'):  # New article detected
                if article_count >= num_articles:
                    break                          
                article_count += 1
                line_count = 0  # Reset line count for the new article
            if line_count < num_lines:
                print(line)
                line_count += 1
            else:
                continue  # Skip further lines until the next article starts
output_file_path = 'wiki_articles_with_seperator.txt'
display_article_lines(output_file_path, num_lines=5, num_articles=5)

ARTICLE-0-12-https://en.wikipedia.org/wiki/Anarchism-Anarchism

Anarchism is a political philosophy and movement that is sceptical of authority and rejects all involuntary, coercive forms of hierarchy.

Anarchism calls for the abolition of the state, which it holds to be unnecessary, undesirable, and harmful.

As a historically left-wing movement, placed on the farthest left of the political spectrum, it is usually described alongside communalism and libertarian Marxism as the libertarian wing (libertarian socialism) of the socialist movement, and has a strong historical association with anti-capitalism and socialism.

Humans lived in societies without formal hierarchies long before the establishment of formal states, realms, or empires.

ARTICLE-1-25-https://en.wikipedia.org/wiki/Autism-Autism

Autism is a neurodevelopmental disorder characterized by difficulties with social interaction and communication, and by restricted and repetitive behavior.

Parents often notice signs during th

### masking , truncating function and training instance class

In [6]:
class TrainingInstance:
    """A single training instance (sentence pair)."""
    def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels, is_random_next):
        self.tokens = tokens
        self.segment_ids = segment_ids
        self.masked_lm_positions = masked_lm_positions
        self.masked_lm_labels = masked_lm_labels
        self.is_random_next = is_random_next

    def __str__(self):
        tokens_str = " ".join([str(token) for token in self.tokens])
        segment_ids_str = " ".join(map(str, self.segment_ids))
        masked_lm_positions_str = " ".join(map(str, self.masked_lm_positions))
        masked_lm_labels_str = " ".join([str(label) for label in self.masked_lm_labels])
        return f"Tokens: {tokens_str}\nSegment IDs: {segment_ids_str}\n" \
               f"Is Random Next: {self.is_random_next}\n" \
               f"Masked LM Positions: {masked_lm_positions_str}\n" \
               f"Masked LM Labels: {masked_lm_labels_str}\n"

    def __repr__(self):
        return self.__str__()

def mask_tokens(tokens, tokenizer, max_predictions_per_seq, rng):
    """Masks tokens and returns masked tokens and corresponding labels."""
    output_tokens = tokens[:]
    output_labels = [-1] * len(tokens)  # Initialize labels with -1 (no change)

    # Determine which tokens can be masked
    candidate_indices = [
        i for i, token in enumerate(tokens) 
        if token not in [tokenizer.cls_token, tokenizer.sep_token]
    ]
    rng.shuffle(candidate_indices)
    num_masked = min(max_predictions_per_seq, len(candidate_indices) * 15 // 100)
    
    for index in candidate_indices[:num_masked]:
        random_choice = rng.random()
        # 80% replace with [MASK], 10% random token, 10% unchanged
        if random_choice < 0.8:
            output_tokens[index] = tokenizer.mask_token
        elif random_choice < 0.9:
            output_tokens[index] = random.choice(list(tokenizer.vocab.keys()))
        
        output_labels[index] = tokenizer.convert_tokens_to_ids(tokens[index])

    return output_tokens, output_labels

def truncate_and_process(tokens_a, tokens_b, max_seq_length, tokenizer, max_predictions_per_seq, instances, rng, is_random_next):
    # Truncate tokens_a and tokens_b if their combined length is too long
    while len(tokens_a) + len(tokens_b) + 3 > max_seq_length:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()

    tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)

    masked_tokens, masked_labels = mask_tokens(tokens, tokenizer, max_predictions_per_seq, rng)

    # Convert masked_tokens to IDs
    token_ids = tokenizer.convert_tokens_to_ids(masked_tokens)  # Ensure this returns integers
    instance = TrainingInstance(
        tokens=token_ids,
        segment_ids=segment_ids,
        masked_lm_positions=[i for i, label in enumerate(masked_labels) if label != -1],
        masked_lm_labels=[label for label in masked_labels if label != -1],
        is_random_next=int(is_random_next)
    )
    # instance = {
    #     'tokens': token_ids,
    #     'segment_ids': segment_ids,
    #     'masked_lm_positions': [i for i, label in enumerate(masked_labels) if label != -1],
    #     'masked_lm_labels': [label for label in masked_labels if label != -1],
    #     'is_random_next': int(is_random_next)
    # }
    instances.append(instance)



### create pre-training data with masking and nsp 

In [85]:
def create_bert_pretraining_instances_in_chunks(file_path, chunk_size=1048576, 
                          doc_boundary_pattern=r'ARTICLE-\d+-\d+-https:\/\/\S+',
                         test_print=10, max_seq_length=max_seq_length, 
                         max_predictions_per_seq=max_predictions_per_seq, 
                         dupe_factor=5, random_seed=12345, nsp_enabled=True,
                                               tokenizer=tokenizer):
    # tokenizer = BertTokenizer.from_pretrained(args.vocab_file, do_lower_case=args.do_lower_case)  ## is for custom vocab
    # tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') ### This loads the default vocab for bert pre-training
    rng = random.Random(random_seed)
    buffer = ''
    instances = []    
    with open(file_path, 'r', encoding='utf-8') as file:
        while True:
            chunk = file.read(chunk_size)
            if not chunk:
                break
            buffer += chunk
            documents = re.split(doc_boundary_pattern, buffer, flags=re.MULTILINE)            
            if documents and not re.match(doc_boundary_pattern, documents[-1]):
                buffer = documents.pop()
            else:
                buffer = ''            
            for doc in documents:
                if not doc.strip():
                    continue
                tokenized_doc = tokenizer.tokenize(doc)
                sequences = [tokenized_doc[i:i+max_seq_length] for i in range(0, len(tokenized_doc), max_seq_length)]                
                for j in range(len(sequences) - 1):
                    tokens_a = sequences[j]
                    if rng.random() > 0.5 or not nsp_enabled:
                        is_random_next = True
                        tokens_b = sequences[rng.randint(0, len(sequences) - 1)]
                    else:
                        is_random_next = False
                        tokens_b = sequences[j + 1]

                    truncate_and_process(tokens_a, tokens_b, max_seq_length, tokenizer, max_predictions_per_seq, instances, rng, is_random_next)

                    if test_print > 0:
                        print(f"Tokens A: {tokens_a}, len: {len(tokens_a)}")
                        print(f"Tokens B: {tokens_b[:10]}")
                        print(f"Is random next: {is_random_next}\n")
                        test_print -= 1
                        if test_print == 0:
                            return instances
    return instances  # Return all instances for further processing or training
file_path = 'wiki_articles_with_seperator.txt'
res = create_bert_pretraining_instances_in_chunks(file_path)

Tokens A: ['ana', '##rch', '##ism', 'is', 'a', 'political', 'philosophy', 'and', 'movement', 'that', 'is', 'sc', '##ept', '##ical', 'of'], len: 15
Tokens B: ['for', 'the', 'abolition', 'of', 'the', 'state', ',', 'which', 'it', 'holds']
Is random next: False

Tokens A: ['for', 'the', 'abolition', 'of', 'the', 'state', ',', 'which', 'it', 'holds', 'to', 'be', 'unnecessary', ','], len: 14
Tokens B: ['the', 'far', '##thest', 'left', 'of', 'the', 'political', 'spectrum', ',', 'it']
Is random next: False

Tokens A: ['the', 'far', '##thest', 'left', 'of', 'the', 'political', 'spectrum', ',', 'it', 'is', 'usually', 'described', 'alongside', 'communal'], len: 15
Tokens B: ['and', 'has', 'a', 'strong', 'historical', 'association', 'with', 'anti', '-', 'capitalism']
Is random next: False

Tokens A: ['and', 'has', 'a', 'strong', 'historical', 'association', 'with', 'anti', '-', 'capitalism', 'and', 'socialism', '.', 'humans'], len: 14
Tokens B: ['fiction', 'due', 'to', 'the', 'fact', 'that', 'the'

In [86]:
# print(res[0])
# # d = res[0]
# print(d.__dict__)

### save the mlm and nsp data on disk

In [87]:
def save_instances_as_parquet(instances, file_path, num_instances=1000, small=False):
    if small:
        instances = instances[:num_instances]
    data = {
        'input_ids': [],
        'segment_ids': [],
        'masked_lm_positions': [],
        'mlm_labels': [],
        'nsp_labels': []
    }
    for instance in instances:
        data['input_ids'].append(instance.tokens)
        data['segment_ids'].append(instance.segment_ids)
        data['masked_lm_positions'].append(instance.masked_lm_positions)
        data['mlm_labels'].append(instance.masked_lm_labels)
        data['nsp_labels'].append(instance.is_random_next)
    df = pd.DataFrame(data)
    df.to_parquet(file_path, engine='pyarrow')
instances = create_bert_pretraining_instances_in_chunks(file_path, test_print=0)
save_instances_as_parquet(instances, 'pretraining_bert_data.parquet')

### load pretraining as tenslorflow dataset 

In [5]:
def load_dataset(file_path, batch_size=32):
    df = pd.read_parquet(file_path)
    tensor_dict = {}
    for col in df.columns:
        if isinstance(df[col].values[0], list) or isinstance(df[col].values[0], np.ndarray):
            # Handle list or array: pad sequences and convert to tensor
            padded_array = tf.keras.preprocessing.sequence.pad_sequences(df[col].tolist(), padding='post', dtype='int32')
            tensor_dict[col] = tf.convert_to_tensor(padded_array, dtype=tf.int32)
        else:
            # Convert scalar values directly to tensor
            tensor_dict[col] = tf.convert_to_tensor(df[col].to_numpy(dtype=np.int32), dtype=tf.int32)

    # Split the tensors into inputs and labels
    inputs = {k: tensor_dict[k] for k in ['input_ids', 'segment_ids']}
    labels = {k: tensor_dict[k] for k in ['masked_lm_positions', 'mlm_labels', 'nsp_labels']}
    # Combine into a single dataset
    dataset = tf.data.Dataset.from_tensor_slices((inputs, labels))
    dataset = dataset.batch(batch_size)
    return dataset


train_dataset = load_dataset('pretraining_bert_data.parquet', batch_size=batch_size)

for inputs, labels in train_dataset.take(1):
    print("Inputs:", inputs)
    print("Labels:", labels)
    break


Inputs: {'input_ids': <tf.Tensor: shape=(16, 32), dtype=int32, numpy=
array([[  101,  9617,   103,  2964,  2003,  1037,  2576,  4695,  1998,
         2929,  2008,  2003,  8040,   103,  7476,   103,   102,  2005,
          103, 15766,  1997,  1996,  2110,  1010,  2029,  2009,  4324,
         2000,  2022, 14203,  1010,   102],
       [  101,  2005,  1996, 15766,  1997,   103,  2110,  1010,   103,
         2009,  4324,  2000,   103,   103,  1010,   102,  1996,  2521,
        20515,  2187,  1997,  1996,  2576,  8674,  1010,  2009,  2003,
         2788,  2649,  4077, 15029,   102],
       [  101,  1996,  2521, 20515,   103,  1997,  1996,  2576,  8674,
         1010,  2009,  2003,  2788,  2649,  4077, 15029,   102,  1998,
          103,  1037,  2844,  3439,   103,  2007,  3424,   103, 16498,
         1998, 14649,  1012,  4286,   102],
       [  101,  1998,  2038,  1037,  2844,  3439,  2523,  2007,  3424,
         1011, 16498,  1998, 14649,  1012,  4286,   102,  4349,  2349,
          103,  1

In [89]:
### Convert these articles to pretraining data 
# !python create_pretraining_data.py --vocab_file vocab.txt --input_text input_text.txt --output_tfrecord output.tfrecord --do_lower_case --nsp


In [8]:
single_test_instance = iter(train_dataset.take(1)).next()
single_input_tuple = single_test_instance[0]['input_ids'], single_test_instance[0]['segment_ids']
# print(single_test_instance)
print(single_test_instance[0]['input_ids'][0])
print(single_test_instance[1]['masked_lm_positions'][0])
print(single_test_instance[1]['mlm_labels'][0])

tf.Tensor(
[  101  9617   103  2964  2003  1037  2576  4695  1998  2929  2008  2003
  8040   103  7476   103   102  2005   103 15766  1997  1996  2110  1010
  2029  2009  4324  2000  2022 14203  1010   102], shape=(32,), dtype=int32)
tf.Tensor([ 2 13 15 18], shape=(4,), dtype=int32)
tf.Tensor([11140 23606  1997  1996], shape=(4,), dtype=int32)


## model architecture

### positional encoding with segment embedding 

In [6]:
class PositionalAndSegmentEmbedding(tf.keras.layers.Layer):
    def __init__(self, vocab_size, segment_size, d_model, max_pos=2048, pos_dropout=0.1, **kwargs):
        super().__init__(**kwargs)  # Initialize the superclass (Layer)
        self.d_model = d_model  # Store the dimensionality of the model embeddings
        self.token_embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
        self.segment_embedding = tf.keras.layers.Embedding(segment_size, d_model)
        self.pos_encoding = encode_pos_sin_cosine(max_pos, d_model, debug=False)
        self.dropout = tf.keras.layers.Dropout(pos_dropout)

    def compute_mask(self, inputs, *args, **kwargs):
        # Assuming the input structure is a tuple of (tokens, segments)
        token_inputs, _ = inputs
        return self.token_embedding.compute_mask(token_inputs, *args, **kwargs)

    def call(self, inputs, training=False):
        # Expect inputs to be a tuple (token_inputs, segment_inputs)
        token_inputs, segment_inputs = inputs
        tokens = self.token_embedding(token_inputs)  # Token embeddings
        segments = self.segment_embedding(segment_inputs)  # Segment embeddings       
        x = tokens + segments
        # Scale the embeddings by the square root of the embedding dimension size
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        # Add positional encoding to the combined embeddings, sliced to match the input length
        length = tf.shape(x)[1]
        # pos_encodings = tf.reshape(self.pos_encoding, (1, -1, self.d_model))[:, :length, :]
        pos_encodings = tf.cast(tf.reshape(self.pos_encoding, (1, -1, self.d_model))[:, :tf.shape(x)[1], :], tf.float32)
        x += pos_encodings
        x = self.dropout(x, training=training)
        return x

embedding_layer = PositionalAndSegmentEmbedding(vocab_size=vocab_size, segment_size=2, d_model=256)

# Extract a single batch from the dataset
for inputs, labels in train_dataset.take(1):
    # The inputs dictionary contains 'input_ids' and 'segment_ids'
    input_ids = inputs['input_ids']
    segment_ids = inputs['segment_ids']

    # Call the embedding layer
    embeddings = embedding_layer((input_ids, segment_ids))

    # Print the output shape
    print("Output shape:", embeddings.shape)
    print("Output shape:", embeddings._keras_mask)
    break

Output shape: (16, 32, 256)
Output shape: tf.Tensor(
[[ True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True  True  True  True  True
   True  True  True  True  True  True  True  True  True  True  True  True
   

### Transformer Encoder with PositionalEncoding with Segment ids

In [145]:
class TransformerEncoderV4(TransformerEncoderV3):
    def __init__(self, num_layers, d_model, num_heads, dff, vocab_size, segment_size, max_pos=2048, pos_dropout=0.1, **kwargs):
        super(TransformerEncoderV4, self).__init__(num_layers=num_layers, d_model=d_model, num_heads=num_heads, dff=dff, vocab_size=vocab_size, max_pos=max_pos, **kwargs)
        # Use the custom embedding layer that handles tokens, segments, and positional encodings
        self.embedding_layer = PositionalAndSegmentEmbedding(vocab_size, segment_size, d_model, max_pos, pos_dropout)

    def call(self, inputs, training=False):
        input_ids, segment_ids = inputs
        # The embedding layer now handles everything including token, segment, and positional embeddings
        x = self.embedding_layer((input_ids, segment_ids), training=training)
        x = self.enc_layers_0(x, training=training)
        for i in range(self.remaining_layers):
            x = self.enc_layers[i](x, training=training)
        return x
tren = TransformerEncoderV4(num_layers, d_model, num_heads, dff, vocab_size, segment_size, max_pos=max_seq_length)

encoder_out = tren(single_input_tuple)
print(encoder_out.shape)

(16, 32, 768)


### BERT model class 

In [13]:
class BERT(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff, vocab_size, segment_size, max_seq_length=128, rate=0.1):
        super(BERT, self).__init__()
        self.encoder = TransformerEncoderV4(num_layers=num_layers, d_model=d_model, num_heads=num_heads,
                                            dff=dff, vocab_size=vocab_size, segment_size=segment_size,
                                            max_pos=max_seq_length, pos_dropout=rate)
        self.mlm_dense_transform = tf.keras.layers.Dense(d_model, activation='gelu')  # Transform layer for MLM
        self.mlm_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12)  # Layer normalization for MLM
        self.mlm_dense = tf.keras.layers.Dense(vocab_size)  # Ensures output shape is [batch, seq_length, vocab_size]
        self.nsp_dense = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, inputs, training=False):
        x = self.encoder((inputs['input_ids'], inputs['segment_ids']), training=training)
        
        # Apply dense transformation and layer normalization for MLM
        mlm_output = self.mlm_dense_transform(x)
        mlm_output = self.mlm_layer_norm(mlm_output)
        mlm_output = self.mlm_dense(mlm_output)
        
        nsp_output = self.nsp_dense(x[:, 0, :])
        return {'mlm_output': mlm_output, 'nsp_output': nsp_output}

# Instantiate and compile the BERT model
bert_model = BERT(num_layers, d_model, num_heads, 
                  dff, vocab_size, segment_size)
bert_out = bert_model(single_test_instance[0])
print(bert_out['mlm_output'].shape  )
print(bert_out['nsp_output'].shape  )

(16, 32, 30522)
(16, 1)


### Loss functions and learning rate 

In [11]:
learning_rate = CustomSchedule(d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98,
                                     epsilon=1e-9)
# optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5)
loss_object_nsp = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def compute_mlm_loss(masked_positions, masked_labels, logits):
    # Gather the logits at the masked positions
    masked_logits = tf.gather(logits, masked_positions, batch_dims=1, axis=1) 
    # tf.print("masked_positions:", masked_positions[0], summarize=-1)
    # tf.print("masked_labels:", masked_labels[0], summarize=-1)
    # tf.print("masked_logits:", masked_logits[0].shape, summarize=-1)
    # Ensure that the masked_labels used here are the correct length and match the number of masked_positions
    mlm_loss = tf.keras.losses.sparse_categorical_crossentropy(masked_labels, masked_logits, from_logits=True)
    # tf.print("mlm_loss:", mlm_loss[0], summarize=-1)
    # Reduce mean across batches if needed or sum as appropriate
    return tf.reduce_mean(mlm_loss)
    
def compute_nsp_loss(labels, logits):
    return tf.keras.losses.binary_crossentropy(labels, logits, from_logits=True)

###  pre training step 

In [12]:
@tf.function
def train_step(inputs, labels, task='both'):
    with tf.GradientTape() as tape:        
        predictions = bert_model(inputs, training=True)  # Predictions will have 'mlm_output' and 'nsp_output'
        # Compute the MLM loss using the positions and labels
        loss_mlm = compute_mlm_loss(labels['masked_lm_positions'], 
                                    labels['mlm_labels'], predictions['mlm_output'])
        # NSP loss remains the same
        loss_nsp = loss_object_nsp(labels['nsp_labels'], predictions['nsp_output'])
        if task == 'nsp':
            total_loss = loss_nsp
        elif task == 'mlm':
            total_loss = loss_mlm
        else:
            total_loss = loss_mlm + loss_nsp
    gradients = tape.gradient(total_loss, bert_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, bert_model.trainable_variables))
    # tf.print("Total Loss:", total_loss, "MLM Loss:", loss_mlm, "NSP Loss:", loss_nsp)
    return total_loss, loss_mlm, loss_nsp

## Pre training 

###  pretraining only mlm 

In [147]:
train_dataset = train_dataset.repeat()
epochs = 100
for epoch in range(epochs):
    step = 0
    for inputs, labels in train_dataset.take(1):
        # tf.print("input_ids:", inputs['input_ids'][0], summarize=-1)
        loss_values = train_step(inputs, labels, task='mlm')
        if step % 10 == 0:
            print(f"Epoch {epoch + 1}, Step {step}, Total Loss: {loss_values[0].numpy():.4f}, MLM Loss: {loss_values[1].numpy():.4f}, NSP Loss: {loss_values[2].numpy():.4f}")
        step += 1
# for epoch in range(epochs):
#     step = 0
#     # Create a new iterator for each epoch
#     dataset_iter = iter(train_dataset.take(10))
#     while True:
#         try:
#             inputs, labels = next(dataset_iter)
#             loss_values = train_step(inputs, labels, task='mlm')
#             if step % 10 == 0:
#                 print(f"Epoch {epoch + 1}, Step {step}, Total Loss: {loss_values[0].numpy():.4f}, MLM Loss: {loss_values[1].numpy():.4f}, NSP Loss: {loss_values[2].numpy():.4f}")
#             step += 1
#             # Break after one batch to simulate take(1)
#             # break
#         except StopIteration:
#             break

Epoch 1, Step 0, Total Loss: 10.3523, MLM Loss: 10.3523, NSP Loss: 0.7588


2024-05-16 11:05:27.820836: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 2, Step 0, Total Loss: 10.3740, MLM Loss: 10.3740, NSP Loss: 0.8065


2024-05-16 11:05:29.159169: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 3, Step 0, Total Loss: 10.3871, MLM Loss: 10.3871, NSP Loss: 0.8478


2024-05-16 11:05:30.473423: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 4, Step 0, Total Loss: 10.3338, MLM Loss: 10.3338, NSP Loss: 0.8719


2024-05-16 11:05:31.768384: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 5, Step 0, Total Loss: 10.3394, MLM Loss: 10.3394, NSP Loss: 0.8040


2024-05-16 11:05:33.041288: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 6, Step 0, Total Loss: 10.3431, MLM Loss: 10.3431, NSP Loss: 0.7065


2024-05-16 11:05:34.386915: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 7, Step 0, Total Loss: 10.3434, MLM Loss: 10.3434, NSP Loss: 0.7179


2024-05-16 11:05:35.781340: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 8, Step 0, Total Loss: 10.3589, MLM Loss: 10.3589, NSP Loss: 0.8624


2024-05-16 11:05:37.093158: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 9, Step 0, Total Loss: 10.3311, MLM Loss: 10.3311, NSP Loss: 0.7160


2024-05-16 11:05:38.403152: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 10, Step 0, Total Loss: 10.3145, MLM Loss: 10.3145, NSP Loss: 0.7148


2024-05-16 11:05:39.710706: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 11, Step 0, Total Loss: 10.3115, MLM Loss: 10.3115, NSP Loss: 0.6065


2024-05-16 11:05:41.017094: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 12, Step 0, Total Loss: 10.2941, MLM Loss: 10.2941, NSP Loss: 0.8228


2024-05-16 11:05:42.326403: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 13, Step 0, Total Loss: 10.2844, MLM Loss: 10.2844, NSP Loss: 0.7140


2024-05-16 11:05:43.662081: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 14, Step 0, Total Loss: 10.2663, MLM Loss: 10.2663, NSP Loss: 0.7763


2024-05-16 11:05:45.000540: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 15, Step 0, Total Loss: 10.2994, MLM Loss: 10.2994, NSP Loss: 0.8444


2024-05-16 11:05:46.370642: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 16, Step 0, Total Loss: 10.2889, MLM Loss: 10.2889, NSP Loss: 0.7690


2024-05-16 11:05:47.786073: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 17, Step 0, Total Loss: 10.2409, MLM Loss: 10.2409, NSP Loss: 0.7088


2024-05-16 11:05:49.127251: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 18, Step 0, Total Loss: 10.2258, MLM Loss: 10.2258, NSP Loss: 0.8141


2024-05-16 11:05:50.400261: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 19, Step 0, Total Loss: 10.2558, MLM Loss: 10.2558, NSP Loss: 0.7559


2024-05-16 11:05:51.758108: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 20, Step 0, Total Loss: 10.1724, MLM Loss: 10.1724, NSP Loss: 0.7898


2024-05-16 11:05:53.203768: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 21, Step 0, Total Loss: 10.1910, MLM Loss: 10.1910, NSP Loss: 0.8164


2024-05-16 11:05:54.603510: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 22, Step 0, Total Loss: 10.1616, MLM Loss: 10.1616, NSP Loss: 0.7280


2024-05-16 11:05:55.943943: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 23, Step 0, Total Loss: 10.1623, MLM Loss: 10.1623, NSP Loss: 0.9647


2024-05-16 11:05:57.250937: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 24, Step 0, Total Loss: 10.1477, MLM Loss: 10.1477, NSP Loss: 0.8506


2024-05-16 11:05:58.543883: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 25, Step 0, Total Loss: 10.1425, MLM Loss: 10.1425, NSP Loss: 0.8144


2024-05-16 11:05:59.837754: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 26, Step 0, Total Loss: 10.1049, MLM Loss: 10.1049, NSP Loss: 0.4734


2024-05-16 11:06:01.123319: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 27, Step 0, Total Loss: 10.0800, MLM Loss: 10.0800, NSP Loss: 0.8051


2024-05-16 11:06:02.462482: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 28, Step 0, Total Loss: 10.0795, MLM Loss: 10.0795, NSP Loss: 0.8810


2024-05-16 11:06:03.781568: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 29, Step 0, Total Loss: 10.0402, MLM Loss: 10.0402, NSP Loss: 0.7338


2024-05-16 11:06:05.084095: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 30, Step 0, Total Loss: 10.0395, MLM Loss: 10.0395, NSP Loss: 0.7840


2024-05-16 11:06:06.370703: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 31, Step 0, Total Loss: 10.0148, MLM Loss: 10.0148, NSP Loss: 0.8483


2024-05-16 11:06:07.730416: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 32, Step 0, Total Loss: 9.9909, MLM Loss: 9.9909, NSP Loss: 0.7342


2024-05-16 11:06:09.144301: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 33, Step 0, Total Loss: 9.9647, MLM Loss: 9.9647, NSP Loss: 0.7196


2024-05-16 11:06:10.588495: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 34, Step 0, Total Loss: 9.9498, MLM Loss: 9.9498, NSP Loss: 0.7199


2024-05-16 11:06:11.978126: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 35, Step 0, Total Loss: 9.9367, MLM Loss: 9.9367, NSP Loss: 0.6957


2024-05-16 11:06:13.326427: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 36, Step 0, Total Loss: 9.9059, MLM Loss: 9.9059, NSP Loss: 0.8324


2024-05-16 11:06:14.662159: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 37, Step 0, Total Loss: 9.8886, MLM Loss: 9.8886, NSP Loss: 0.7913


2024-05-16 11:06:15.966932: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 38, Step 0, Total Loss: 9.9162, MLM Loss: 9.9162, NSP Loss: 0.8017


2024-05-16 11:06:17.294502: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 39, Step 0, Total Loss: 9.8396, MLM Loss: 9.8396, NSP Loss: 0.7087


2024-05-16 11:06:18.577463: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 40, Step 0, Total Loss: 9.8344, MLM Loss: 9.8344, NSP Loss: 0.8255


2024-05-16 11:06:19.867774: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 41, Step 0, Total Loss: 9.8138, MLM Loss: 9.8138, NSP Loss: 0.8037


2024-05-16 11:06:21.167325: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 42, Step 0, Total Loss: 9.8025, MLM Loss: 9.8025, NSP Loss: 0.6386


2024-05-16 11:06:22.520093: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 43, Step 0, Total Loss: 9.7586, MLM Loss: 9.7586, NSP Loss: 0.7449


2024-05-16 11:06:23.929167: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 44, Step 0, Total Loss: 9.7457, MLM Loss: 9.7457, NSP Loss: 0.5881


2024-05-16 11:06:25.287374: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 45, Step 0, Total Loss: 9.7618, MLM Loss: 9.7618, NSP Loss: 0.7520


2024-05-16 11:06:26.666982: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 46, Step 0, Total Loss: 9.7157, MLM Loss: 9.7157, NSP Loss: 0.7484


2024-05-16 11:06:28.173975: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 47, Step 0, Total Loss: 9.7121, MLM Loss: 9.7121, NSP Loss: 0.7630


2024-05-16 11:06:29.887557: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 48, Step 0, Total Loss: 9.6765, MLM Loss: 9.6765, NSP Loss: 0.6966


2024-05-16 11:06:31.580456: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 49, Step 0, Total Loss: 9.6340, MLM Loss: 9.6340, NSP Loss: 0.7328


2024-05-16 11:06:33.208940: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 50, Step 0, Total Loss: 9.6577, MLM Loss: 9.6577, NSP Loss: 0.7028


2024-05-16 11:06:34.781596: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 51, Step 0, Total Loss: 9.6225, MLM Loss: 9.6225, NSP Loss: 0.7240


2024-05-16 11:06:36.317621: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 52, Step 0, Total Loss: 9.5863, MLM Loss: 9.5863, NSP Loss: 0.7805


2024-05-16 11:06:37.851868: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 53, Step 0, Total Loss: 9.5837, MLM Loss: 9.5837, NSP Loss: 0.9090


2024-05-16 11:06:39.358865: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 54, Step 0, Total Loss: 9.5552, MLM Loss: 9.5552, NSP Loss: 0.9720


2024-05-16 11:06:40.806480: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 55, Step 0, Total Loss: 9.5236, MLM Loss: 9.5236, NSP Loss: 0.7683


2024-05-16 11:06:42.168718: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 56, Step 0, Total Loss: 9.4820, MLM Loss: 9.4820, NSP Loss: 0.9051


2024-05-16 11:06:43.553318: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 57, Step 0, Total Loss: 9.5084, MLM Loss: 9.5084, NSP Loss: 0.7577


2024-05-16 11:06:44.935489: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 58, Step 0, Total Loss: 9.4902, MLM Loss: 9.4902, NSP Loss: 0.7844


2024-05-16 11:06:46.355484: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 59, Step 0, Total Loss: 9.4766, MLM Loss: 9.4766, NSP Loss: 0.7764


2024-05-16 11:06:47.715680: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 60, Step 0, Total Loss: 9.4409, MLM Loss: 9.4409, NSP Loss: 0.7934


2024-05-16 11:06:49.051299: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 61, Step 0, Total Loss: 9.4356, MLM Loss: 9.4356, NSP Loss: 0.8621


2024-05-16 11:06:50.368268: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 62, Step 0, Total Loss: 9.3979, MLM Loss: 9.3979, NSP Loss: 0.6562


2024-05-16 11:06:51.747969: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 63, Step 0, Total Loss: 9.3804, MLM Loss: 9.3804, NSP Loss: 0.7982


2024-05-16 11:06:53.056440: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 64, Step 0, Total Loss: 9.3805, MLM Loss: 9.3805, NSP Loss: 0.7430


2024-05-16 11:06:54.423165: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 65, Step 0, Total Loss: 9.3652, MLM Loss: 9.3652, NSP Loss: 0.6402


2024-05-16 11:06:55.748900: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 66, Step 0, Total Loss: 9.3036, MLM Loss: 9.3036, NSP Loss: 0.8558


2024-05-16 11:06:57.055296: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 67, Step 0, Total Loss: 9.3270, MLM Loss: 9.3270, NSP Loss: 0.8807


2024-05-16 11:06:58.372897: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 68, Step 0, Total Loss: 9.2914, MLM Loss: 9.2914, NSP Loss: 0.8570


2024-05-16 11:06:59.695285: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 69, Step 0, Total Loss: 9.2588, MLM Loss: 9.2588, NSP Loss: 0.6241


2024-05-16 11:07:01.019280: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 70, Step 0, Total Loss: 9.2118, MLM Loss: 9.2118, NSP Loss: 0.7096


2024-05-16 11:07:02.338508: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 71, Step 0, Total Loss: 9.2204, MLM Loss: 9.2204, NSP Loss: 0.7138


2024-05-16 11:07:03.646075: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 72, Step 0, Total Loss: 9.2043, MLM Loss: 9.2043, NSP Loss: 0.6461


2024-05-16 11:07:04.953253: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 73, Step 0, Total Loss: 9.1764, MLM Loss: 9.1764, NSP Loss: 0.8302


2024-05-16 11:07:06.284394: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 74, Step 0, Total Loss: 9.1492, MLM Loss: 9.1492, NSP Loss: 0.7267


2024-05-16 11:07:07.596203: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 75, Step 0, Total Loss: 9.1446, MLM Loss: 9.1446, NSP Loss: 0.8544


2024-05-16 11:07:08.900148: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 76, Step 0, Total Loss: 9.1102, MLM Loss: 9.1102, NSP Loss: 0.7594


2024-05-16 11:07:10.218968: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 77, Step 0, Total Loss: 9.0945, MLM Loss: 9.0945, NSP Loss: 0.8876


2024-05-16 11:07:11.527941: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 78, Step 0, Total Loss: 9.0711, MLM Loss: 9.0711, NSP Loss: 0.7874


2024-05-16 11:07:12.854529: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 79, Step 0, Total Loss: 9.0417, MLM Loss: 9.0417, NSP Loss: 0.7179


2024-05-16 11:07:14.210091: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 80, Step 0, Total Loss: 8.9940, MLM Loss: 8.9940, NSP Loss: 1.0353


2024-05-16 11:07:15.682850: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 81, Step 0, Total Loss: 9.0018, MLM Loss: 9.0018, NSP Loss: 0.9501


2024-05-16 11:07:17.267686: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 82, Step 0, Total Loss: 9.0022, MLM Loss: 9.0022, NSP Loss: 0.7825


2024-05-16 11:07:18.753736: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 83, Step 0, Total Loss: 8.9456, MLM Loss: 8.9456, NSP Loss: 0.6976


2024-05-16 11:07:20.315559: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 84, Step 0, Total Loss: 8.9386, MLM Loss: 8.9386, NSP Loss: 0.6444


2024-05-16 11:07:21.829121: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 85, Step 0, Total Loss: 8.9409, MLM Loss: 8.9409, NSP Loss: 0.7415


2024-05-16 11:07:23.351961: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 86, Step 0, Total Loss: 8.8871, MLM Loss: 8.8871, NSP Loss: 0.8380


2024-05-16 11:07:24.804257: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 87, Step 0, Total Loss: 8.8399, MLM Loss: 8.8399, NSP Loss: 0.9388


2024-05-16 11:07:26.272296: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 88, Step 0, Total Loss: 8.8390, MLM Loss: 8.8390, NSP Loss: 0.7943


2024-05-16 11:07:27.688584: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 89, Step 0, Total Loss: 8.8104, MLM Loss: 8.8104, NSP Loss: 0.8526


2024-05-16 11:07:29.099037: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 90, Step 0, Total Loss: 8.7807, MLM Loss: 8.7807, NSP Loss: 0.8083


2024-05-16 11:07:30.482832: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 91, Step 0, Total Loss: 8.7379, MLM Loss: 8.7379, NSP Loss: 0.9789


2024-05-16 11:07:31.877849: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 92, Step 0, Total Loss: 8.7275, MLM Loss: 8.7275, NSP Loss: 0.8325


2024-05-16 11:07:33.358974: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 93, Step 0, Total Loss: 8.7004, MLM Loss: 8.7004, NSP Loss: 0.7062


2024-05-16 11:07:34.835102: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 94, Step 0, Total Loss: 8.6888, MLM Loss: 8.6888, NSP Loss: 0.8646


2024-05-16 11:07:36.354653: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 95, Step 0, Total Loss: 8.6691, MLM Loss: 8.6691, NSP Loss: 0.7720


2024-05-16 11:07:37.897035: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 96, Step 0, Total Loss: 8.6194, MLM Loss: 8.6194, NSP Loss: 0.7994


2024-05-16 11:07:39.474798: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 97, Step 0, Total Loss: 8.5966, MLM Loss: 8.5966, NSP Loss: 0.8384


2024-05-16 11:07:40.912835: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 98, Step 0, Total Loss: 8.5884, MLM Loss: 8.5884, NSP Loss: 0.7938


2024-05-16 11:07:42.547898: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 99, Step 0, Total Loss: 8.5649, MLM Loss: 8.5649, NSP Loss: 0.7803


2024-05-16 11:07:44.111065: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 100, Step 0, Total Loss: 8.5479, MLM Loss: 8.5479, NSP Loss: 0.7651


2024-05-16 11:07:45.505561: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


###  pre-training with mlm and nsp

In [148]:
epochs = 100
for epoch in range(epochs):
    step = 0
    for inputs, labels in train_dataset.take(1):
        # tf.print("input_ids:", inputs['input_ids'][0], summarize=-1)
        loss_values = train_step(inputs, labels, task='both')
        if step % 10 == 0:
            print(f"Epoch {epoch + 1}, Step {step}, Total Loss: {loss_values[0].numpy():.4f}, MLM Loss: {loss_values[1].numpy():.4f}, NSP Loss: {loss_values[2].numpy():.4f}")
        step += 1

Epoch 1, Step 0, Total Loss: 9.2853, MLM Loss: 8.5035, NSP Loss: 0.7818


2024-05-16 16:11:12.980846: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 2, Step 0, Total Loss: 9.2618, MLM Loss: 8.4925, NSP Loss: 0.7693


2024-05-16 16:11:14.137424: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 3, Step 0, Total Loss: 9.2957, MLM Loss: 8.4526, NSP Loss: 0.8431


2024-05-16 16:11:15.302822: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 4, Step 0, Total Loss: 9.0262, MLM Loss: 8.4133, NSP Loss: 0.6129


2024-05-16 16:11:16.454072: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 5, Step 0, Total Loss: 8.9619, MLM Loss: 8.4075, NSP Loss: 0.5544


2024-05-16 16:11:17.593652: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 6, Step 0, Total Loss: 8.9275, MLM Loss: 8.3956, NSP Loss: 0.5319


2024-05-16 16:11:18.753803: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 7, Step 0, Total Loss: 8.8233, MLM Loss: 8.3807, NSP Loss: 0.4426


2024-05-16 16:11:19.910653: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 8, Step 0, Total Loss: 8.7205, MLM Loss: 8.3708, NSP Loss: 0.3497


2024-05-16 16:11:21.074901: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 9, Step 0, Total Loss: 8.5759, MLM Loss: 8.3451, NSP Loss: 0.2309


2024-05-16 16:11:22.229045: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 10, Step 0, Total Loss: 8.4935, MLM Loss: 8.3028, NSP Loss: 0.1906


2024-05-16 16:11:23.402598: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 11, Step 0, Total Loss: 8.5185, MLM Loss: 8.3012, NSP Loss: 0.2173


2024-05-16 16:11:24.606374: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 12, Step 0, Total Loss: 8.4484, MLM Loss: 8.2722, NSP Loss: 0.1762


2024-05-16 16:11:25.780957: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 13, Step 0, Total Loss: 8.5268, MLM Loss: 8.2448, NSP Loss: 0.2820


2024-05-16 16:11:26.944361: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 14, Step 0, Total Loss: 8.3182, MLM Loss: 8.2263, NSP Loss: 0.0919


2024-05-16 16:11:28.116963: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 15, Step 0, Total Loss: 8.3541, MLM Loss: 8.1970, NSP Loss: 0.1571


2024-05-16 16:11:29.308536: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 16, Step 0, Total Loss: 8.3399, MLM Loss: 8.1733, NSP Loss: 0.1666


2024-05-16 16:11:30.501190: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 17, Step 0, Total Loss: 8.2542, MLM Loss: 8.1484, NSP Loss: 0.1059


2024-05-16 16:11:31.653330: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 18, Step 0, Total Loss: 8.1511, MLM Loss: 8.1025, NSP Loss: 0.0486


2024-05-16 16:11:32.852965: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 19, Step 0, Total Loss: 8.1972, MLM Loss: 8.1144, NSP Loss: 0.0828


2024-05-16 16:11:34.035532: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 20, Step 0, Total Loss: 8.0876, MLM Loss: 8.0674, NSP Loss: 0.0202


2024-05-16 16:11:35.192186: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 21, Step 0, Total Loss: 8.0851, MLM Loss: 8.0628, NSP Loss: 0.0223


2024-05-16 16:11:36.372153: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 22, Step 0, Total Loss: 8.0932, MLM Loss: 8.0649, NSP Loss: 0.0283


2024-05-16 16:11:37.515898: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 23, Step 0, Total Loss: 8.0189, MLM Loss: 8.0124, NSP Loss: 0.0065


2024-05-16 16:11:38.701365: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 24, Step 0, Total Loss: 7.9813, MLM Loss: 7.9757, NSP Loss: 0.0056


2024-05-16 16:11:39.888117: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 25, Step 0, Total Loss: 7.9736, MLM Loss: 7.9694, NSP Loss: 0.0043


2024-05-16 16:11:41.057321: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 26, Step 0, Total Loss: 7.9236, MLM Loss: 7.9174, NSP Loss: 0.0062


2024-05-16 16:11:42.226981: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 27, Step 0, Total Loss: 7.8993, MLM Loss: 7.8959, NSP Loss: 0.0034


2024-05-16 16:11:43.432859: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 28, Step 0, Total Loss: 7.8621, MLM Loss: 7.8600, NSP Loss: 0.0021


2024-05-16 16:11:44.588475: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 29, Step 0, Total Loss: 7.8340, MLM Loss: 7.8325, NSP Loss: 0.0015


2024-05-16 16:11:45.794805: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 30, Step 0, Total Loss: 7.8311, MLM Loss: 7.8291, NSP Loss: 0.0020


2024-05-16 16:11:46.971806: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 31, Step 0, Total Loss: 7.8021, MLM Loss: 7.7991, NSP Loss: 0.0030


2024-05-16 16:11:48.136634: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 32, Step 0, Total Loss: 7.7394, MLM Loss: 7.7336, NSP Loss: 0.0058


2024-05-16 16:11:49.287765: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 33, Step 0, Total Loss: 7.7271, MLM Loss: 7.7262, NSP Loss: 0.0009


2024-05-16 16:11:50.441692: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 34, Step 0, Total Loss: 7.6942, MLM Loss: 7.6927, NSP Loss: 0.0015


2024-05-16 16:11:51.642606: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 35, Step 0, Total Loss: 7.7131, MLM Loss: 7.6791, NSP Loss: 0.0340


2024-05-16 16:11:52.852646: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 36, Step 0, Total Loss: 7.6413, MLM Loss: 7.6401, NSP Loss: 0.0013


2024-05-16 16:11:54.050947: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 37, Step 0, Total Loss: 7.6061, MLM Loss: 7.6052, NSP Loss: 0.0009


2024-05-16 16:11:55.253534: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 38, Step 0, Total Loss: 7.5889, MLM Loss: 7.5881, NSP Loss: 0.0009


2024-05-16 16:11:56.402128: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 39, Step 0, Total Loss: 7.5438, MLM Loss: 7.5419, NSP Loss: 0.0019


2024-05-16 16:11:57.591754: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 40, Step 0, Total Loss: 7.5117, MLM Loss: 7.5102, NSP Loss: 0.0015


2024-05-16 16:11:58.759201: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 41, Step 0, Total Loss: 7.4845, MLM Loss: 7.4833, NSP Loss: 0.0011


2024-05-16 16:11:59.951458: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 42, Step 0, Total Loss: 7.4568, MLM Loss: 7.4559, NSP Loss: 0.0009


2024-05-16 16:12:01.132236: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 43, Step 0, Total Loss: 7.4298, MLM Loss: 7.4292, NSP Loss: 0.0006


2024-05-16 16:12:02.282991: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 44, Step 0, Total Loss: 7.3951, MLM Loss: 7.3931, NSP Loss: 0.0020


2024-05-16 16:12:03.449611: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 45, Step 0, Total Loss: 7.3449, MLM Loss: 7.3413, NSP Loss: 0.0036


2024-05-16 16:12:04.602249: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 46, Step 0, Total Loss: 7.3325, MLM Loss: 7.3296, NSP Loss: 0.0029


2024-05-16 16:12:05.751165: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 47, Step 0, Total Loss: 7.2830, MLM Loss: 7.2817, NSP Loss: 0.0012


2024-05-16 16:12:06.905779: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 48, Step 0, Total Loss: 7.2392, MLM Loss: 7.2386, NSP Loss: 0.0006


2024-05-16 16:12:08.121879: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 49, Step 0, Total Loss: 7.2220, MLM Loss: 7.2213, NSP Loss: 0.0007


2024-05-16 16:12:09.370543: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 50, Step 0, Total Loss: 7.1954, MLM Loss: 7.1942, NSP Loss: 0.0012


2024-05-16 16:12:10.633698: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 51, Step 0, Total Loss: 7.1724, MLM Loss: 7.1719, NSP Loss: 0.0006


2024-05-16 16:12:11.805321: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 52, Step 0, Total Loss: 7.1277, MLM Loss: 7.1271, NSP Loss: 0.0006


2024-05-16 16:12:12.965454: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 53, Step 0, Total Loss: 7.0559, MLM Loss: 7.0547, NSP Loss: 0.0012


2024-05-16 16:12:14.140692: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 54, Step 0, Total Loss: 7.0470, MLM Loss: 7.0464, NSP Loss: 0.0006


2024-05-16 16:12:15.312988: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 55, Step 0, Total Loss: 7.0078, MLM Loss: 7.0072, NSP Loss: 0.0005


2024-05-16 16:12:16.586701: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 56, Step 0, Total Loss: 6.9870, MLM Loss: 6.9864, NSP Loss: 0.0006


2024-05-16 16:12:17.838235: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 57, Step 0, Total Loss: 6.9405, MLM Loss: 6.9400, NSP Loss: 0.0005


2024-05-16 16:12:19.199454: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 58, Step 0, Total Loss: 6.8986, MLM Loss: 6.8983, NSP Loss: 0.0003


2024-05-16 16:12:20.441185: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 59, Step 0, Total Loss: 6.8702, MLM Loss: 6.8696, NSP Loss: 0.0006


2024-05-16 16:12:21.844552: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 60, Step 0, Total Loss: 6.8532, MLM Loss: 6.8527, NSP Loss: 0.0005


2024-05-16 16:12:23.204564: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 61, Step 0, Total Loss: 6.8043, MLM Loss: 6.8037, NSP Loss: 0.0006


2024-05-16 16:12:24.635559: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 62, Step 0, Total Loss: 6.8035, MLM Loss: 6.8026, NSP Loss: 0.0009


2024-05-16 16:12:25.973147: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 63, Step 0, Total Loss: 6.7236, MLM Loss: 6.7226, NSP Loss: 0.0010


2024-05-16 16:12:27.141438: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 64, Step 0, Total Loss: 6.7030, MLM Loss: 6.7020, NSP Loss: 0.0010


2024-05-16 16:12:28.299723: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 65, Step 0, Total Loss: 6.6809, MLM Loss: 6.6798, NSP Loss: 0.0011


2024-05-16 16:12:29.456955: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 66, Step 0, Total Loss: 6.6281, MLM Loss: 6.6273, NSP Loss: 0.0008


2024-05-16 16:12:30.601637: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 67, Step 0, Total Loss: 6.6066, MLM Loss: 6.6060, NSP Loss: 0.0005


2024-05-16 16:12:31.766418: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 68, Step 0, Total Loss: 6.5588, MLM Loss: 6.5580, NSP Loss: 0.0008


2024-05-16 16:12:32.912486: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 69, Step 0, Total Loss: 6.5292, MLM Loss: 6.5287, NSP Loss: 0.0005


2024-05-16 16:12:34.075935: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 70, Step 0, Total Loss: 6.4808, MLM Loss: 6.4791, NSP Loss: 0.0017


2024-05-16 16:12:35.253682: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 71, Step 0, Total Loss: 6.4537, MLM Loss: 6.4521, NSP Loss: 0.0015


2024-05-16 16:12:36.431789: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 72, Step 0, Total Loss: 6.3983, MLM Loss: 6.3974, NSP Loss: 0.0009


2024-05-16 16:12:37.608505: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 73, Step 0, Total Loss: 6.3520, MLM Loss: 6.3507, NSP Loss: 0.0013


2024-05-16 16:12:38.766885: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 74, Step 0, Total Loss: 6.3406, MLM Loss: 6.3396, NSP Loss: 0.0009


2024-05-16 16:12:39.922458: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 75, Step 0, Total Loss: 6.2964, MLM Loss: 6.2957, NSP Loss: 0.0007


2024-05-16 16:12:41.085710: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 76, Step 0, Total Loss: 6.2630, MLM Loss: 6.2622, NSP Loss: 0.0008


2024-05-16 16:12:42.315230: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 77, Step 0, Total Loss: 6.2088, MLM Loss: 6.2077, NSP Loss: 0.0011


2024-05-16 16:12:43.578006: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 78, Step 0, Total Loss: 6.1974, MLM Loss: 6.1965, NSP Loss: 0.0010


2024-05-16 16:12:44.841033: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 79, Step 0, Total Loss: 6.1371, MLM Loss: 6.1361, NSP Loss: 0.0009


2024-05-16 16:12:46.091703: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 80, Step 0, Total Loss: 6.1032, MLM Loss: 6.1024, NSP Loss: 0.0007


2024-05-16 16:12:47.251995: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 81, Step 0, Total Loss: 6.0621, MLM Loss: 6.0612, NSP Loss: 0.0009


2024-05-16 16:12:48.406732: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 82, Step 0, Total Loss: 6.0339, MLM Loss: 6.0335, NSP Loss: 0.0005


2024-05-16 16:12:49.610271: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 83, Step 0, Total Loss: 5.9980, MLM Loss: 5.9974, NSP Loss: 0.0006


2024-05-16 16:12:50.865437: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 84, Step 0, Total Loss: 5.9413, MLM Loss: 5.9408, NSP Loss: 0.0005


2024-05-16 16:12:52.027334: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 85, Step 0, Total Loss: 5.9035, MLM Loss: 5.9030, NSP Loss: 0.0006


2024-05-16 16:12:53.189949: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 86, Step 0, Total Loss: 5.8612, MLM Loss: 5.8606, NSP Loss: 0.0006


2024-05-16 16:12:54.371416: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 87, Step 0, Total Loss: 5.8312, MLM Loss: 5.8301, NSP Loss: 0.0011


2024-05-16 16:12:55.649641: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 88, Step 0, Total Loss: 5.7807, MLM Loss: 5.7798, NSP Loss: 0.0010


2024-05-16 16:12:56.921863: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 89, Step 0, Total Loss: 5.7541, MLM Loss: 5.7536, NSP Loss: 0.0005


2024-05-16 16:12:58.198763: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 90, Step 0, Total Loss: 5.7127, MLM Loss: 5.7117, NSP Loss: 0.0009


2024-05-16 16:12:59.446313: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 91, Step 0, Total Loss: 5.6657, MLM Loss: 5.6650, NSP Loss: 0.0006


2024-05-16 16:13:00.631946: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 92, Step 0, Total Loss: 5.6121, MLM Loss: 5.6112, NSP Loss: 0.0009


2024-05-16 16:13:01.874857: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 93, Step 0, Total Loss: 5.5821, MLM Loss: 5.5811, NSP Loss: 0.0011


2024-05-16 16:13:03.084627: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 94, Step 0, Total Loss: 5.5579, MLM Loss: 5.5563, NSP Loss: 0.0016


2024-05-16 16:13:04.330866: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 95, Step 0, Total Loss: 5.5086, MLM Loss: 5.5074, NSP Loss: 0.0011


2024-05-16 16:13:05.505525: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 96, Step 0, Total Loss: 5.4798, MLM Loss: 5.4789, NSP Loss: 0.0009


2024-05-16 16:13:06.639287: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 97, Step 0, Total Loss: 5.4147, MLM Loss: 5.4140, NSP Loss: 0.0008


2024-05-16 16:13:07.795475: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 98, Step 0, Total Loss: 5.3827, MLM Loss: 5.3818, NSP Loss: 0.0008


2024-05-16 16:13:08.996257: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 99, Step 0, Total Loss: 5.3376, MLM Loss: 5.3368, NSP Loss: 0.0009


2024-05-16 16:13:10.164782: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 100, Step 0, Total Loss: 5.2939, MLM Loss: 5.2930, NSP Loss: 0.0009


2024-05-16 16:13:11.312709: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


### saving the  pre trained model

In [None]:
fit_model = bert_model
fit_model.save_weights(model_full_name)
with open(model_path + '/bert_pt_1_config.pkl', 'wb') as config_file:
        pickle.dump(fit_model.get_config(), config_file)
with open(model_path + '/bert_pt_1_config.pkl', 'rb') as config_file:
        fit_model_config = pickle.load(config_file)
print(fit_model.summary())

### Loadning the saved pretrained model 

In [16]:
loaded_fit_model = BERT(num_layers, d_model, num_heads, 
                  dff, vocab_size, segment_size)

loaded_fit_model_out = loaded_fit_model(single_test_instance[0])
loaded_fit_model.load_weights(model_full_name)
print(loaded_fit_model.summary())

None


# Fine Tuning

###  SQuAD question - answer data  for Fine Tuning -  save in json

In [45]:
squad_train_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json"
squad_dev_url = "https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json"

# Download the training dataset
response = requests.get(squad_train_url)
squad_train_data = response.json()

# Save the training dataset to a file
with open('train-v2.0.json', 'w') as f:
    json.dump(squad_train_data, f)

# Download the validation dataset
response = requests.get(squad_dev_url)
squad_dev_data = response.json()

# Save the validation dataset to a file
with open('dev-v2.0.json', 'w') as f:
    json.dump(squad_dev_data, f)

In [54]:
print((squad_train_data.keys()))
print('data:', type(squad_train_data['data']))
print('1st data:', squad_train_data['data'][0].keys())
print('para:', type(squad_train_data['data'][0]['paragraphs']))
print('1st para:', type(squad_train_data['data'][0]['paragraphs'][0]))
print('keys in para:', squad_train_data['data'][0]['paragraphs'][0].keys())
print('qas:', type(squad_train_data['data'][0]['paragraphs'][0]['qas']))
print('qas len:', len(squad_train_data['data'][0]['paragraphs'][0]['qas']))
print('qa in qas', squad_train_data['data'][0]['paragraphs'][0]['qas'][0])
print('context in para:', squad_train_data['data'][0]['paragraphs'][0]['context'])

dict_keys(['version', 'data'])
data: <class 'list'>
1st data: dict_keys(['title', 'paragraphs'])
para: <class 'list'>
1st para: <class 'dict'>
keys in para: dict_keys(['qas', 'context'])
qas: <class 'list'>
qas len: 15
qa in qas {'question': 'When did Beyonce start becoming popular?', 'id': '56be85543aeaaa14008c9063', 'answers': [{'text': 'in the late 1990s', 'answer_start': 269}], 'is_impossible': False}
context in para: Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny's Child. Managed by her father, Mathew Knowles, the group became one of the world's best-selling girl groups of all time. Their hiatus saw the release of Beyoncé's debut album, Dangerously in Love (2003), which established her as a solo artist wo

###  format the data for fine tuning - preprocess

The main goal of a QA model is to find the answer to a given question within a provided context. For example, given a context (a passage of text) and a question, the model needs to identify the span of text within the context that answers the question.

Context and Question
Let's consider an example to illustrate this:

Context:


Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American singer, songwriter, record producer and actress. Born and raised in Houston, Texas, she performed in various singing and dancing competitions as a child, and rose to fame in the late 1990s as lead singer of R&B girl-group Destiny's Child.
Question:


When did Beyoncé start becoming popular?
Answer:

in the late 1990s
Tokenization
In tokenization, both the context and the question are tokenized together. The reason for including both is to provide the model with all the information it needs to understand the context of the question and where the answer might lie within the context.

Here's how it works step-by-step:

Tokenize the Question and Context Together:

tokenizer(question, context, truncation="only_second", max_length=max_seq_length, stride=doc_stride, return_overflowing_tokens=True, padding='max_length', return_offsets_mapping=True)
This step generates tokenized input that includes special tokens to separate the question and the context, and provides the offset mapping, which helps in locating the position of tokens within the original text.

Offset Mapping:

The offset mapping gives the start and end character positions of each token in the original context text. This is crucial for determining where the answer is located in the context.
Calculating Start and End Positions
The start and end positions are indices of the tokens that correspond to the answer within the tokenized context.

Here's the process:

Identify the Character Positions of the Answer:

In the example, the answer "in the late 1990s" starts at character position 269 and ends at 286 in the context.
Match Character Positions to Token Indices:

Use the offset mapping to find which tokens correspond to the start and end character positions of the answer. This involves checking each token's start and end character positions to see if they overlap with the answer's start and end character positions.
Assign Start and End Positions:

The start position is the index of the token that contains the start of the answer.
The end position is the index of the token that contains the end of the answer.
Example
Given the context and question above:

Answer Text: "in the late 1990s"
Answer Start Character Position: 269
Answer End Character Position: 286
The tokenizer will produce something like:

yaml
Copy code
{'input_ids': [101, 2043, 2106, 20773, 2707, ...], 
 'token_type_ids': [0, 0, 0, 0, 0, ..., 1, 1, 1, 1, 1, ...], 
 'attention_mask': [1, 1, 1, 1, 1, ...], 
 'offset_mapping': [(0, 0), (0, 4), (5, 8), (9, 16), ...]}
The offset_mapping might look like [(0, 0), (0, 4), (5, 8), (9, 16), (17, 22), ...].
In the offset mapping, you look for token positions where the answer text starts and ends:

If token at index 75 starts at character 269, it is the start token.
If token at index 78 ends at character 286, it is the end token.
So, for this example:

Start Position: 75
End Position: 78

In [77]:
with open('train-v2.0.json') as f:
    squad_train_data = json.load(f)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
def preprocess_squad(data, tokenizer, max_samples=1000, 
                     max_seq_length=128, 
                     doc_stride=16, test_one_example=False):
    input_ids_list = []
    attention_masks_list = []
    token_type_ids_list = []
    start_positions_list = []
    end_positions_list = []

    samples = 0
    for entry in data['data']:
        if samples >= max_samples:
            break
        for paragraph in entry['paragraphs']:
            context = paragraph['context']
            if test_one_example:
                print(f"Context: {context[:100]}...")  # Print a part of the context for debugging
            for qa in paragraph['qas']:
                question = qa['question']
                if len(qa['answers']) == 0:
                    continue
                answer = qa['answers'][0]
                start_char = answer['answer_start']
                end_char = start_char + len(answer['text'])
                if test_one_example:
                    print('Question:', question)
                    print('Answer:', answer)
                    print('Answer span:', (start_char, end_char))
                # Tokenize context and question
                tokenized_context = tokenizer(
                    question,
                    context,
                    truncation="only_second",
                    max_length=max_seq_length,
                    stride=doc_stride,
                    return_overflowing_tokens=False,
                    padding='max_length',
                    return_offsets_mapping=True
                )

                if test_one_example:
                    print("Tokenized context:", tokenized_context)  # Print tokenized output for debugging

                # Get start and end token positions manually
                offsets = tokenized_context["offset_mapping"]
                if test_one_example:
                    print("Offset mapping:", offsets)
                start_position = None
                end_position = None

                for idx, offset in enumerate(offsets):
                    if offset is None or len(offset) != 2:
                        continue
                    start, end = offset
                    if start <= start_char < end:
                        start_position = idx
                    if start < end_char <= end:
                        end_position = idx

                if test_one_example:
                    print("Start position:", start_position)  # Print start position for debugging
                    print("End position:", end_position)  # Print end position for debugging
                
                if start_position is None or end_position is None:
                    continue

                # Add data to lists
                input_ids_list.append(tokenized_context['input_ids'])
                attention_masks_list.append(tokenized_context['attention_mask'])
                token_type_ids_list.append(tokenized_context['token_type_ids'])
                start_positions_list.append(start_position)
                end_positions_list.append(end_position)
                
                samples += 1
                if samples >= max_samples or test_one_example:
                    break
            if samples >= max_samples or test_one_example:
                break
        if samples >= max_samples or test_one_example:
            break

    # Convert lists to tensors
    if input_ids_list:
        input_ids = tf.convert_to_tensor(input_ids_list, dtype=tf.int32)
        attention_masks = tf.convert_to_tensor(attention_masks_list, dtype=tf.int32)
        token_type_ids = tf.convert_to_tensor(token_type_ids_list, dtype=tf.int32)
        start_positions = tf.convert_to_tensor(start_positions_list, dtype=tf.int32)
        end_positions = tf.convert_to_tensor(end_positions_list, dtype=tf.int32)
    else:
        input_ids = tf.constant([])
        attention_masks = tf.constant([])
        token_type_ids = tf.constant([])
        start_positions = tf.constant([])
        end_positions = tf.constant([])

    return input_ids, attention_masks, token_type_ids, start_positions, end_positions

# Preprocess one example from the SQuAD data for testing
input_ids, attention_masks, token_type_ids, start_positions, end_positions = preprocess_squad(
    squad_train_data, tokenizer, max_samples=1, test_one_example=True)

print("input_ids:", input_ids)
print("attention_masks:", attention_masks)
print("token_type_ids:", token_type_ids)
print("start_positions:", start_positions)
print("end_positions:", end_positions)

Context: Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ bee-YON-say) (born September 4, 1981) is an American si...
Question: When did Beyonce start becoming popular?
Answer: {'text': 'in the late 1990s', 'answer_start': 269}
Answer span: (269, 286)
Tokenized context: {'input_ids': [101, 2043, 2106, 20773, 2707, 3352, 2759, 1029, 102, 20773, 21025, 19358, 22815, 1011, 5708, 1006, 1013, 12170, 23432, 29715, 3501, 29678, 12325, 29685, 1013, 10506, 1011, 10930, 2078, 1011, 2360, 1007, 1006, 2141, 2244, 1018, 1010, 3261, 1007, 2003, 2019, 2137, 3220, 1010, 6009, 1010, 2501, 3135, 1998, 3883, 1012, 2141, 1998, 2992, 1999, 5395, 1010, 3146, 1010, 2016, 2864, 1999, 2536, 4823, 1998, 5613, 6479, 2004, 1037, 2775, 1010, 1998, 3123, 2000, 4476, 1999, 1996, 2397, 4134, 2004, 2599, 3220, 1997, 1054, 1004, 1038, 2611, 1011, 2177, 10461, 1005, 1055, 2775, 1012, 3266, 2011, 2014, 2269, 1010, 25436, 22815, 1010, 1996, 2177, 2150, 2028, 1997, 1996, 2088, 1005, 1055, 2190, 1011, 4855, 2611, 2967, 1997, 203

In [81]:
input_ids, attention_masks, token_type_ids, start_positions, end_positions = preprocess_squad(
    squad_train_data, tokenizer, max_samples=1000, test_one_example=False)
print(input_ids)

tf.Tensor(
[[  101  2043  2106 ...  1997 20773   102]
 [  101  2054  2752 ...  2037 14221   102]
 [  101  1999  2054 ...  2387  1996   102]
 ...
 [  101  2040  2165 ...  2044  8792   102]
 [  101  2040  3092 ...  2000 25479   102]
 [  101  2129  2116 ...     0     0     0]], shape=(1000, 128), dtype=int32)


### Load the SQuAD dataset from the saved files

In [83]:
# Ensure buffer_size is greater than zero
buffer_size = max(len(input_ids), 1)

# Create the dataset
dataset = tf.data.Dataset.from_tensor_slices(({
    'input_ids': input_ids,
    # 'attention_mask': attention_masks,
    'token_type_ids': token_type_ids
}, {
    'start_positions': start_positions,
    'end_positions': end_positions
}))

# Batch the dataset
batch_size = 8
dataset = dataset.shuffle(buffer_size).batch(batch_size)

### Fine Tune model architecure for QA data

In [87]:
class BERTQA(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff, vocab_size, segment_size, max_seq_length=128, rate=0.1):
        super(BERTQA, self).__init__()
        self.encoder = TransformerEncoderV4(num_layers=num_layers, d_model=d_model, num_heads=num_heads,
                                            dff=dff, vocab_size=vocab_size, segment_size=segment_size,
                                            max_pos=max_seq_length, pos_dropout=rate)
        self.qa_outputs = tf.keras.layers.Dense(2)  # Start and end positions

    def call(self, inputs, training=False):
        input_ids = inputs['input_ids']
        token_type_ids = inputs['token_type_ids'] 
        x = self.encoder((input_ids, token_type_ids), training=training)
        logits = self.qa_outputs(x)
        ### the model learns to predict the start and end positions given the input tokens
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)
        return start_logits, end_logits
bert_qa_model = BERTQA(num_layers, d_model, num_heads, 
                  dff, vocab_size, segment_size)

### To ensure only the encoder weights are loaded, explicitly exclude the additional layers during the weight loading process.

The pre-training model includes mlm_dense_transform, mlm_layer_norm, mlm_dense, and nsp_dense layers which are specific to Masked Language Model (MLM) and Next Sentence Prediction (NSP) tasks.
The fine-tuning model includes qa_outputs layer specific to the QA task.

In [None]:
# Load pre-trained weights (ensure you have the path to your pre-trained weights)
bert_qa_model.load_weights(model_full_name, skip_mismatch=True)

### accuracy and loss functions for fine tuning

During training, the model learns to predict the start and end positions given the input tokens. 
The loss function calculates the difference between the predicted start/end positions and the true start/end positions. The model adjusts its weights to minimize this loss, improving its ability to locate the correct answer span in future contexts.

In [None]:
def start_accuracy(y_true, y_pred):
    return tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred)

def end_accuracy(y_true, y_pred):
    return tf.keras.metrics.sparse_categorical_accuracy(y_true, y_pred)

# Compile the model
optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

###  Fine Tune Traiining with skip mismatch of eights if any

In [93]:
bert_qa_model.compile(optimizer=optimizer, 
                      loss=loss_fn, 
                      metrics=[start_accuracy, end_accuracy])
# Train the model
bert_qa_model.fit(dataset, epochs=30)

Epoch 1/30
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m99s[0m 749ms/step - end_accuracy: 0.2933 - loss: 4.9262 - start_accuracy: 0.3005
Epoch 2/30
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m100s[0m 804ms/step - end_accuracy: 0.3263 - loss: 4.3945 - start_accuracy: 0.3620
Epoch 3/30
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m97s[0m 772ms/step - end_accuracy: 0.3466 - loss: 4.1039 - start_accuracy: 0.3588
Epoch 4/30
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m95s[0m 763ms/step - end_accuracy: 0.4128 - loss: 3.6042 - start_accuracy: 0.4044
Epoch 5/30
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m96s[0m 769ms/step - end_accuracy: 0.4858 - loss: 3.2439 - start_accuracy: 0.4814
Epoch 6/30
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m93s[0m 742ms/step - end_accuracy: 0.4773 - loss: 3.0950 - start_accuracy: 0.4637
Epoch 7/30
[1m125/125[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m93s[0m 746ms/ste

<keras.src.callbacks.history.History at 0x7f2a88595dc0>

# Adopt for Hugging Face

In [96]:
# Save the fine-tuned model
# model_path = '/mnt/d/MyDev/attention/transformerlab/bert/models/bert_fine_tuned_qa'
qa_model_full_name = os.path.join(model_path, 'finetuned_qa_bert.weights.h5')
bert_qa_model.save_weights(qa_model_full_name)

In [158]:
class TFCustomBertModel(TFPreTrainedModel):
    config_class = BertConfig
    base_model_prefix = "bert"

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.encoder = TransformerEncoderV4(num_layers=config.num_hidden_layers,
                                            d_model=config.hidden_size,
                                            num_heads=config.num_attention_heads,
                                            dff=config.intermediate_size,
                                            vocab_size=config.vocab_size,
                                            segment_size=config.segment_size,
                                            max_pos=config.max_position_embeddings,
                                            pos_dropout=config.hidden_dropout_prob)

        self.mlm_dense_transform = tf.keras.layers.Dense(config.hidden_size, activation='gelu')
        self.mlm_layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps)
        self.mlm_dense = tf.keras.layers.Dense(config.vocab_size)
        self.nsp_dense = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, inputs, training=False):
        x = self.encoder((inputs['input_ids'], inputs['segment_ids']), training=training)
        mlm_output = self.mlm_dense_transform(x)
        mlm_output = self.mlm_layer_norm(mlm_output)
        mlm_output = self.mlm_dense(mlm_output)
        nsp_output = self.nsp_dense(x[:, 0, :])
        return {'mlm_output': mlm_output, 'nsp_output': nsp_output}

    def save_pretrained(self, save_directory, model_name="tf_model.h5", *model_args, **kwargs):
        os.makedirs(save_directory, exist_ok=True)
        model_weights_path = os.path.join(save_directory, model_name)
        self.save_weights(model_weights_path)
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, model_name="tf_model.h5", *model_args, **kwargs):
        config = kwargs.pop("config", None)
        if config is None:
            config = BertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config, *model_args, **kwargs)

        # Create the model variables by calling the model with dummy inputs
        dummy_input_ids = tf.constant([[0] * config.max_position_embeddings])
        dummy_segment_ids = tf.constant([[0] * config.max_position_embeddings])
        model({'input_ids': dummy_input_ids, 'segment_ids': dummy_segment_ids})

        model_weights_path = os.path.join(pretrained_model_name_or_path, model_name)
        model.load_weights(model_weights_path)
        return model

# Define the config
config = BertConfig(
    vocab_size=vocab_size,  # Adjust as needed
    hidden_size=d_model,
    num_hidden_layers=num_layers,
    num_attention_heads=num_heads,
    intermediate_size=dff,
    hidden_act="gelu",
    initializer_range=0.02,
    layer_norm_eps=1e-12,
    segment_size=segment_size,
    max_position_embeddings=max_seq_length,
    hidden_dropout_prob=0.1
)

# Instantiate and save the model
model = TFCustomBertModel(config)
model.load_weights(model_full_name)
model_name = "huggingface_wiki_pretrained_tf_model.h5"
model.save_pretrained(model_path, model_name)

# Save the tokenizer
tokenizer.save_pretrained(model_path)

# Load the model
loaded_hf_model = TFCustomBertModel.from_pretrained(model_path, model_name)

### execute this step from CLI

!huggingface-cli login

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)

    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
    Setting a new token will erase the existing one.
    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): Traceback (most recent call last):

In [159]:
from huggingface_hub import HfApi, HfFolder, Repository
# Save your Hugging Face token
# hf_token = "your_huggingface_token"
# # Save the token in the HfFolder
# HfFolder.save_token(hf_token)
model_repo_name = f"Bhujay/{model_name}"
model.push_to_hub(model_repo_name)
# # Create the repository if it doesn't exist
# api = HfApi()
# api.create_repo(model_repo_name, exist_ok=True)
# # Initialize the repository
# repo = Repository(local_dir=model_path, clone_from=model_repo_name)
# # Commit and push the files
# repo.git_add()
# repo.git_commit("Initial commit")
# repo.git_push()

tf_model.h5: 100%|█████████████████████████████████████████████████| 6.47k/6.47k [00:01<00:00, 4.91kB/s]


# Pre-Train and Fine tune a Hugging face mode; 

In [None]:
from datasets import load_dataset
dataset = load_dataset("wikipedia", "20220301.en", split='train')

In [None]:
from transformers import BertTokenizer

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length")

tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [None]:
from transformers import DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15)


In [None]:
from transformers import BertForPreTraining, Trainer, TrainingArguments

model = BertForPreTraining.from_pretrained('bert-base-uncased')

In [None]:
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=8,
    save_steps=10_000,
    save_total_limit=2,
)

In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets
)
trainer.train()