<a href="https://colab.research.google.com/github/Doris-QZ/Transformers_from_Scratch--BERT_and_GPT2_in_PyTorch/blob/main/1_Data_For_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Introduction

This notebook prepares the **IMDB dataset** for BERT-style pre-training. The data is processed to align with **BERT’s training objectives**:

* **Masked Language Modeling (MLM)**: randomly masks tokens for prediction.
* **Next Sentence Prediction (NSP)**: creates sentence pairs with both positive and negative examples.  
>

At the end of this notebook, we'll have a pandas DataFrame with the following columns:
* **input_ids** - token IDs of paired sentences       
* **token_type_ids** - token type IDs (0 for the first sentence, 1 for the second)
* **attention_mask** - 1 for real tokens, 0 for padding
* **mlm_labels** - labels for masked language model prediction
* **nsp_labels** - labels for next sentence prediction.  
>
   
Although this is not the same corpus used in the original BERT paper (BooksCorpus + Wikipedia), the dataset is structured in a way that allows the model to be trained with the same objectives.  


In [None]:
from datasets import load_dataset
import transformers
from transformers import BertTokenizer

import torch
from torch.nn.utils.rnn import pad_sequence

import pandas as pd
import random

In [None]:
# Suppress the warning
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

# Load the data
imdb_data = load_dataset('imdb', split='train')

In [None]:
# Take a look at the data
print(f"{imdb_data}\n")
imdb_data[0]

BERT uses **WordPiece**, a subword tokenization algorithm, with a vocabulary size of 30,522 in the BERT-base model. We’ll load the BERT-base tokenizer from Hugging Face’s transformers package for text tokenization.
>
**Note:** Although the BERT tokenizer can directly return `input_ids`, `attention_mask`, and `token_type_ids`, in this project we only use its `encode` method to obtain token IDs. We will then **create `input_ids`, `attention_mask`, and `token_type_ids` manually**, since our goal is to **explore how data is prepared for masked language modeling (MLM) and next sentence prediction (NSP)**.
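
As a rough illustration of what building these fields by hand involves, the sketch below assembles them from two already-tokenized sentences. The helper name `build_pair` and the word token IDs are made up for illustration; the special-token IDs are the ones found in the `bert-base-uncased` vocabulary ([PAD]=0, [CLS]=101, [SEP]=102). It pads at the end of the pair, which is a common layout and not necessarily the one produced later in this notebook.

```python
# Special-token IDs as found in the bert-base-uncased vocabulary.
PAD, CLS, SEP = 0, 101, 102

def build_pair(sent_a, sent_b, max_len):
    """Hypothetical helper: assemble one BERT input from two lists of token IDs."""
    first = [CLS] + sent_a + [SEP]      # segment 0: [CLS] A [SEP]
    second = sent_b + [SEP]             # segment 1: B [SEP]
    n_real = len(first) + len(second)
    if n_real > max_len:
        raise ValueError("pair is longer than max_len")
    n_pad = max_len - n_real
    return {
        "input_ids": first + second + [PAD] * n_pad,
        "token_type_ids": [0] * len(first) + [1] * len(second) + [0] * n_pad,
        "attention_mask": [1] * n_real + [0] * n_pad,
    }

# Illustrative token IDs, not looked up in the real vocabulary.
example = build_pair([2026, 3899], [2009, 2003, 10140], max_len=10)
for name, values in example.items():
    print(f"{name:15s} {values}")
```

All three lists have the same length, so each position in `input_ids` has a matching segment ID and attention flag.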

In [None]:
# Load BERT-base tokenizer
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Get vocabulary and vocabulary size
vocab = bert_tokenizer.get_vocab()
vocab_size = len(vocab)
vocab_size

### Function for Preparing Data for Masked Language Modeling (MLM)
According to the BERT paper, **15% of tokens are randomly selected for masking**.  

For example,<br>

> Original text: "my dog is very cute."  
> Selected token: "cute"

For the selected tokens:
* 80% of the time ---> replaced with [MASK] ---> "my dog is very **[MASK]**."
* 10% of the time ---> left unchanged ---> "my dog is very **cute**."
* 10% of the time ---> replaced with a random token ---> "my dog is very **apple**."    
>
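
Combining the two stages gives the overall per-token probabilities: 0.15 × 0.8 = 12% become [MASK], 0.15 × 0.1 = 1.5% stay unchanged but are still predicted, 0.15 × 0.1 = 1.5% become a random token, and 85% are not selected at all. The simulation below is a small sanity check of that arithmetic using only the standard library; the function name `sample_outcome` and the sample size are arbitrary choices for this sketch.

```python
import random
from collections import Counter

def sample_outcome(rng):
    """Return which branch of the masking rule a single token falls into."""
    if rng.random() >= 0.15:
        return "not selected"
    r = rng.random()
    if r < 0.8:
        return "[MASK]"
    if r < 0.9:
        return "unchanged"
    return "random token"

rng = random.Random(0)  # fixed seed so the run is repeatable
n = 200_000
counts = Counter(sample_outcome(rng) for _ in range(n))

expected = {"not selected": 0.85, "[MASK]": 0.12,
            "unchanged": 0.015, "random token": 0.015}
for outcome, p in expected.items():
    print(f"{outcome:13s} expected {p:.3f}  observed {counts[outcome] / n:.3f}")
```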

Regardless of how the selected token is handled, the label remains the same: the original token at the selected position and [PAD] everywhere else.
>[PAD] [PAD] [PAD] [PAD] cute [PAD]  

During training, the model **ignores the positions labeled [PAD]** and predicts the original tokens only at the selected positions.
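
To make the role of the labels concrete, here is a minimal pure-Python sketch of a loss that only scores the selected positions. The function name `mlm_loss`, the four-token toy vocabulary, and the probabilities are invented for illustration. In PyTorch, `torch.nn.CrossEntropyLoss` has an `ignore_index` argument that can play the same role; whether the training notebook uses it is not shown here.

```python
import math

PAD_IDX = 0  # label meaning "position was not selected"

def mlm_loss(log_probs, labels, ignore_index=PAD_IDX):
    """Mean negative log-likelihood over positions whose label is not ignore_index.

    log_probs[i] maps each candidate token ID to its log-probability at position i.
    """
    scored = [(lp, y) for lp, y in zip(log_probs, labels) if y != ignore_index]
    if not scored:
        return 0.0
    return -sum(lp[y] for lp, y in scored) / len(scored)

# Toy vocabulary of 4 token IDs; only position 4 was selected, original token ID 3.
labels = [PAD_IDX, PAD_IDX, PAD_IDX, PAD_IDX, 3, PAD_IDX]
uniform = {t: math.log(0.25) for t in range(4)}
confident = {0: math.log(0.1), 1: math.log(0.1), 2: math.log(0.1), 3: math.log(0.7)}
log_probs = [uniform, uniform, uniform, uniform, confident, uniform]

loss = mlm_loss(log_probs, labels)
print(round(loss, 4))
```

Only the prediction at the selected position contributes, so the loss equals `-log(0.7)` no matter what the model outputs elsewhere.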

In [None]:
# Get special token indices
PAD_IDX = bert_tokenizer.pad_token_id
UNK_IDX = bert_tokenizer.unk_token_id
CLS_IDX = bert_tokenizer.cls_token_id
SEP_IDX = bert_tokenizer.sep_token_id
MASK_IDX = bert_tokenizer.mask_token_id

PAD_IDX, UNK_IDX, CLS_IDX, SEP_IDX, MASK_IDX

In [None]:
special_ids = {PAD_IDX, UNK_IDX, CLS_IDX, SEP_IDX, MASK_IDX}

def random_token(vocab_size=vocab_size, special_ids=special_ids):
    """
    Returns a random token ID that is not a special token.
    """
    while True:
        idx = random.randint(0, vocab_size-1)
        if idx not in special_ids:
            return idx

def masking(token_id, pad_idx, mask_idx):
    """
    Mask a single token for Masked Language Model (MLM) training.

    Args:
        token_id (int): The token ID to be processed.
        pad_idx (int): The index of the [PAD] token, used as the label for unselected tokens.
        mask_idx (int): The index of the [MASK] token.

    Returns:
        tuple of two ints---
            1. The processed token ID, which may be replaced with mask_idx, left unchanged,
            or replaced with a random token ID.
            2. The label for the token: the original token ID if selected, or pad_idx if not selected.

    """

    # Each token is selected for masking with a probability of 15%.
    if random.random() >= 0.15:
        return token_id, pad_idx

    # Generate a random float in [0, 1) to decide how the selected token is handled
    random_float = random.random()

    # 80% of the selected tokens are replaced by mask_idx
    if random_float < 0.8:
        return mask_idx, token_id

    # 10% of the selected tokens remain unchanged
    if random_float < 0.9:
        return token_id, token_id

    # 10% of the selected tokens are replaced by a random token ID
    return random_token(), token_id

In [None]:
PUNCT_IDX = [vocab['.'], vocab['?'], vocab['!']]

def data_for_MLM(dataset, pad_idx, mask_idx, punct_idx):
    """
    Prepare data for Masked Language model (MLM) training.

    Args:
        dataset (Dataset): The dataset to be processed.
        pad_idx (int): The index of the [PAD] token.
        mask_idx(int): The index of the [MASK] token.
        punct_idx(list): A list of indices of the ending punctuations

    Returns:
        tuple of two lists--
            1. List of tokenized sentences, where each token is either replaced with MASK_IDX,
            left unchanged, or replaced with a random token.
            2. List of labels for masked tokens corresponding to the tokenized sentences.
            Each label is either the original token at masked positions or PAD_IDX at unmasked positions.
    """

    masked_sentences = []
    mlm_labels = []
    cur_tokens = []
    cur_labels = []


    for data in dataset:
        # Encode without [CLS]/[SEP]; they are added later when sentences are paired
        tokens = bert_tokenizer.encode(data['text'], add_special_tokens=False)

        for token_id in tokens:
            token_, label_ = masking(token_id, pad_idx, mask_idx)
            cur_tokens.append(token_)
            cur_labels.append(label_)

            # An ending punctuation mark closes the sentence: store it and start a new one.
            if token_id in punct_idx:
                # Keep only sentences with more than two tokens
                if len(cur_tokens) > 2:
                    masked_sentences.append(cur_tokens)
                    mlm_labels.append(cur_labels)
                cur_tokens = []
                cur_labels = []

        # Append the remaining tokens that do not have an ending punctuation to the list,
        # then reset so they do not leak into the next review
        if cur_tokens:
            masked_sentences.append(cur_tokens)
            mlm_labels.append(cur_labels)
            cur_tokens = []
            cur_labels = []

    return masked_sentences, mlm_labels


### Function for Preparing Data for Next Sentence Prediction (NSP)

BERT is trained to predict whether a pair of sentences are consecutive in the original text (**Next Sentence Prediction**).

* **Positive examples**: the second sentence follows the first sentence in the dataset.

* **Negative examples**: the second sentence is randomly selected from the dataset.

For each sentence pair, a label is created:

* 1 → Next sentence is correct (positive example)

* 0 → Next sentence is incorrect (negative example)

This allows the model to learn relationships between sentences in addition to word-level predictions.
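
The labeling scheme can be sketched on a handful of plain-text sentences. This is a simplified illustration, not the pairing routine used in this notebook: the helper name `make_nsp_pairs` and the example sentences are made up, pairs are drawn independently (sentences can be reused), and negative examples explicitly avoid the true next sentence.

```python
import random

def make_nsp_pairs(sentences, n_pairs, rng):
    """Hypothetical helper: build (first, second, label) triples.

    label 1 = second really follows first, label 0 = second drawn at random.
    """
    pairs = []
    for _ in range(n_pairs):
        i = rng.randrange(len(sentences) - 1)  # any sentence that has a successor
        if rng.random() < 0.5:
            pairs.append((sentences[i], sentences[i + 1], 1))
        else:
            # Pick any sentence other than i itself and its true successor.
            candidates = [j for j in range(len(sentences)) if j not in (i, i + 1)]
            j = rng.choice(candidates)
            pairs.append((sentences[i], sentences[j], 0))
    return pairs

sentences = [
    "the film opens slowly .",
    "then the plot picks up .",
    "the ending is weak .",
    "i would still watch it again .",
]
rng = random.Random(1)  # fixed seed so the run is repeatable
pairs = make_nsp_pairs(sentences, 6, rng)
for first, second, label in pairs:
    print(label, "|", first, "|", second)
```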

In [None]:
def data_for_NSP(masked_sentences, mlm_labels, pad_idx, cls_idx, sep_idx):
    """
    Prepare data for Next Sentence Prediction (NSP).

    Args:
        masked_sentences (list): List of tokenized (and masked) sentences.
        mlm_labels (list): List of labels corresponding to masked_sentences.
        pad_idx (int): The index of the [PAD] token.
        cls_idx(int): The index of the [CLS] token.
        sep_idx(int): The index of the [SEP] token.

    Returns:
        tuple of three lists---
            1. List of paired sentences with special token indices added.
            2. List of paired labels for the masked tokens.
            3. List of 0/1 labels for Next Sentence Prediction (1: consecutive, 0: not consecutive).
    """

    # Make sure the lengths of the inputs are valid
    num_sentence = len(masked_sentences)
    if num_sentence < 2:
        raise ValueError("masked_sentences must contain at least two sentences.")

    if num_sentence != len(mlm_labels):
        raise ValueError("masked_sentences and mlm_labels must have the same length.")


    paired_sents = []
    paired_mlm_labels = []
    nsp_labels = []

    # Create the list of sentence indices
    sentence_idx = list(range(num_sentence))

    while len(sentence_idx) >= 2:
        if random.random() >= 0.5:
            # Randomly choose an index from sentence_idx, excluding the last (largest) one
            # so that idx + 1 is always a valid sentence index
            idx = random.choice(sentence_idx[:-1])

            # Pair two consecutive sentences (idx, idx+1) with special tokens as current inputs/labels.
            # Note: sentence idx + 1 may already have been used in an earlier pair.
            cur_input = [[cls_idx] + masked_sentences[idx] + [sep_idx],
                          masked_sentences[idx + 1] + [sep_idx]]
            cur_label = [[pad_idx] + mlm_labels[idx] + [pad_idx],
                          mlm_labels[idx + 1] + [pad_idx]]

            # Add current inputs/labels to paired_sents/paired_mlm_labels
            paired_sents.append(cur_input)
            paired_mlm_labels.append(cur_label)

            # Append 1 to nsp_labels, indicating the current inputs are consecutive sentences.
            nsp_labels.append(1)

            # Remove idx and idx+1 from sentence_idx
            sentence_idx.remove(idx)
            if idx + 1 in sentence_idx:
                sentence_idx.remove(idx+1)

        else:
            # Randomly sample two indices from the sentence_idx
            idx_1, idx_2 = random.sample(sentence_idx, 2)

            # Add two randomly selected sentences (idx_1, idx_2) with special tokens as current inputs/labels
            cur_input = [[cls_idx] + masked_sentences[idx_1] + [sep_idx],
                         masked_sentences[idx_2] + [sep_idx]]
            cur_label = [[pad_idx] + mlm_labels[idx_1] + [pad_idx],
                         mlm_labels[idx_2] + [pad_idx]]

            # Add current inputs/labels to paired_sents/paired_mlm_labels
            paired_sents.append(cur_input)
            paired_mlm_labels.append(cur_label)

            # Append 0 to nsp_labels, indicating the current inputs are not consecutive sentences.
            nsp_labels.append(0)

            # Remove idx_1 and idx_2 from sentence_idx
            sentence_idx.remove(idx_1)
            sentence_idx.remove(idx_2)

    return paired_sents, paired_mlm_labels, nsp_labels

### Function to Create Final BERT Training Data

The final function combines the **MLM** and **NSP** data preparation steps:

1. Randomly masks 15% of tokens and prepares the corresponding labels for **Masked Language Modeling (MLM)**.

2. Generates sentence pairs and labels for **Next Sentence Prediction (NSP)**.

3. Returns a **pandas DataFrame** containing all the processed data.
>

The output DataFrame can then be saved as a CSV file, which is later loaded in the `Reproducing BERT Model from Scratch using PyTorch.ipynb` notebook. In that notebook, the data is **converted to a PyTorch dataset** and loaded into a **PyTorch DataLoader** for training.
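
One thing to keep in mind with the CSV round trip: a cell that holds a Python list is written out as its string form, so it comes back as text such as `"[101, 7, 8]"` and has to be parsed before it can be turned into tensors. The sketch below shows the effect with the standard library's `csv` module standing in for pandas (an assumption made to keep the example dependency-free; the row contents are made up). With pandas, one option is to pass a parser such as `ast.literal_eval` through the `converters` argument of `pd.read_csv`.

```python
import ast
import csv
import io

# One made-up row with a list-valued column, as in the final DataFrame.
rows = [{"input_ids": [101, 7, 8, 102, 9, 102], "nsp_labels": 1}]

# Write to an in-memory CSV; non-string cells are converted with str().
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=["input_ids", "nsp_labels"])
writer.writeheader()
writer.writerows(rows)

# Read it back: the list is now plain text.
buffer.seek(0)
loaded = list(csv.DictReader(buffer))
raw = loaded[0]["input_ids"]
print(type(raw).__name__, repr(raw))

# Parse the text back into a list of ints.
restored = ast.literal_eval(raw)
print(type(restored).__name__, restored)
```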

In [None]:
def data_for_BERT(dataset, pad_idx, mask_idx, cls_idx, sep_idx, punct_idx):
    """
    Prepare data for BERT training--Masked Language Modeling and Next Sentence Prediction.

    Args:
        dataset(Dataset): The dataset to be processed.
        pad_idx (int): The index of the [PAD] token.
        mask_idx(int): The index of the [MASK] token.
        cls_idx(int): The index of the [CLS] token.
        sep_idx(int): The index of the [SEP] token.
        punct_idx(list): A list of indices of the ending punctuations

    Returns:
        A Pandas dataframe with the following columns:
            input_ids - token IDs of paired sentences
            token_type_ids - token type IDs (0 for the first sentence, 1 for the second)
            attention_mask - 1 for real tokens, 0 for padding
            mlm_labels - labels for masked language model prediction
            nsp_labels - labels for next sentence prediction

    """
    # Pad both halves of a pair to the length of the longer half, then flatten the nested list
    def pad_flatten(pairs, padding=pad_idx):
        max_len = max(len(pairs[0]), len(pairs[1]))
        pairs[0].extend([padding] * (max_len - len(pairs[0])))
        pairs[1].extend([padding] * (max_len - len(pairs[1])))
        flatten_pairs = [item for sublist in pairs for item in sublist]
        return flatten_pairs

    # Get masked_sentences and the corresponding labels from the dataset
    masked_sentences, mask_labels = data_for_MLM(dataset, pad_idx, mask_idx, punct_idx)

    # Get paired_sentences, labels, and nsp_labels list
    paired_sentences, paired_mask_labels, nsp_labels = data_for_NSP(masked_sentences, mask_labels, pad_idx, cls_idx, sep_idx)

    input_ids, token_type_ids, attention_mask, mlm_labels = [], [], [], []

    for sentences, labels in zip(paired_sentences, paired_mask_labels):
        # Create token types (0: first sentence, 1: second sentence); padding positions get 0
        token_type = [[0] * len(sentences[0]), [1] * len(sentences[1])]
        token_type = pad_flatten(token_type, padding=0)

        # Create attention mask (1: real tokens, 0: padding)
        mask = [[1] * len(sentences[0]), [1] * len(sentences[1])]
        mask = pad_flatten(mask, padding=0)

        # Pad and flatten paired sentences and mask_labels
        # (done after the lengths above are taken, because pad_flatten pads in place)
        padded_sent = pad_flatten(sentences)
        padded_label = pad_flatten(labels)

        # Add the processed pair to the final lists
        input_ids.append(padded_sent)
        mlm_labels.append(padded_label)
        token_type_ids.append(token_type)
        attention_mask.append(mask)

    # Create a dataframe of bert data
    bert_data = pd.DataFrame({
        'input_ids': input_ids,
        'token_type_ids': token_type_ids,
        'attention_mask': attention_mask,
        'mlm_labels': mlm_labels,
        'nsp_labels': nsp_labels
    })

    return bert_data

In [None]:
# Process imdb_data for BERT training
imdb_bert_data = data_for_BERT(imdb_data, PAD_IDX, MASK_IDX, CLS_IDX, SEP_IDX, PUNCT_IDX)
imdb_bert_data.head()

In [None]:
imdb_bert_data.info()

In [None]:
# Save the dataframe as CSV
imdb_bert_data.to_csv('imdb_bert_data.csv', index=False)