<a href="https://colab.research.google.com/github/Doris-QZ/Transformers-from-Scratch-BERT-and-GPT2-in-PyTorch/blob/main/1_Data_For_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Introduction

This notebook prepares the **IMDB dataset** for BERT-style pre-training. The data is processed to align with **BERT’s training objectives**:

* **Masked Language Modeling (MLM)**: randomly masks tokens, which the model must then predict.
* **Next Sentence Prediction (NSP)**: creates sentence pairs with both positive and negative examples.

At the end of this notebook, we'll have a pandas DataFrame with the following columns:
* **input_ids** - token IDs of paired sentences
* **token_type_ids** - token type IDs (0 for the first sentence, 1 for the second)
* **attention_mask** - 1 for real tokens, 0 for padding
* **mlm_labels** - labels for masked language model prediction (the original token ID at selected positions, the [PAD] ID elsewhere)
* **nsp_labels** - labels for next sentence prediction (1 for consecutive sentences, 0 for random pairs)
   
Although this is not the same corpus used in the original BERT paper (BooksCorpus + Wikipedia), the dataset is structured in a way that allows the model to be trained with the same objectives.  


In [1]:
from datasets import load_dataset
import transformers
from transformers import BertTokenizer

import torch
from torch.nn.utils.rnn import pad_sequence

import pandas as pd
import random

In [13]:
# Suppress the warning
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

# Load the data
imdb_data = load_dataset('imdb', split='train')

In [3]:
# Take a look at the data
print(f"{imdb_data}\n")
imdb_data[0]

Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})



{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far be

BERT uses **WordPiece**, a subword tokenization algorithm; the BERT-base vocabulary has 30,522 entries. We’ll load the BERT-base tokenizer from Hugging Face’s `transformers` package to tokenize the text.

**Note:** Although the BERT tokenizer can directly return `input_ids`, `attention_mask`, and `token_type_ids`, in this project we only use its `encode` method to obtain token IDs. We then **create `input_ids`, `attention_mask`, and `token_type_ids` manually**, since our goal is to **explore how data is prepared for masked language modeling (MLM) and next sentence prediction (NSP)**.

In [14]:
# Load BERT-base tokenizer
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Get vocabulary and vocabulary size
vocab = bert_tokenizer.get_vocab()
vocab_size = len(vocab)

### Function for Preparing Data for Masked Language Modeling (MLM)
According to the BERT paper, **15% of tokens are randomly selected for masking**.  

For example:

> Original text: "my dog is very cute."  
> Randomly selected token: "cute"

For the selected tokens:
* 80% of the time ---> replaced with [MASK] ---> "my dog is very **[MASK]**."
* 10% of the time ---> left unchanged ---> "my dog is very **cute**."
* 10% of the time ---> replaced with a random token ---> "my dog is very **apple**."

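Regardless of how the selected token is altered, the label stays the same: the original token at the selected position and [PAD] everywhere else (including the final period).
> Label: [PAD] [PAD] [PAD] [PAD] cute [PAD]

During training, positions labeled [PAD] are **ignored when computing the loss**, so the model only learns to **predict the original tokens** (in this case, "cute") at the selected positions.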
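
To make the label alignment concrete, here is a minimal sketch using a made-up toy vocabulary (the IDs are illustrative only and are not BERT's real token IDs). The position of "cute" is selected by hand rather than sampled, so the three possible inputs and their shared label can be compared side by side.

```python
# Toy vocabulary for illustration only; these are NOT real BERT token IDs.
toy_vocab = {"[PAD]": 0, "[MASK]": 1, "my": 2, "dog": 3, "is": 4,
             "very": 5, "cute": 6, ".": 7, "apple": 8}

tokens = ["my", "dog", "is", "very", "cute", "."]
ids = [toy_vocab[t] for t in tokens]
selected = 4  # position of "cute", chosen by hand for this example


def replace_at(seq, pos, new_id):
    """Return a copy of seq with the item at pos replaced by new_id."""
    out = list(seq)
    out[pos] = new_id
    return out


# The three possible inputs for the selected position
masked_input = replace_at(ids, selected, toy_vocab["[MASK]"])   # 80% case
unchanged_input = list(ids)                                     # 10% case
random_input = replace_at(ids, selected, toy_vocab["apple"])    # 10% case

# The label is identical in all three cases: the original ID at the
# selected position and the [PAD] ID everywhere else.
labels = [toy_vocab["[PAD]"]] * len(ids)
labels[selected] = ids[selected]

print(masked_input)     # [2, 3, 4, 5, 1, 7]
print(unchanged_input)  # [2, 3, 4, 5, 6, 7]
print(random_input)     # [2, 3, 4, 5, 8, 7]
print(labels)           # [0, 0, 0, 0, 6, 0]
```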

In [5]:
# Get special token indices
PAD_IDX = bert_tokenizer.pad_token_id
UNK_IDX = bert_tokenizer.unk_token_id
CLS_IDX = bert_tokenizer.cls_token_id
SEP_IDX = bert_tokenizer.sep_token_id
MASK_IDX = bert_tokenizer.mask_token_id

PAD_IDX, UNK_IDX, CLS_IDX, SEP_IDX, MASK_IDX

(0, 100, 101, 102, 103)

In [6]:
special_ids = {PAD_IDX, UNK_IDX, CLS_IDX, SEP_IDX, MASK_IDX}

def random_token(vocab_size: int = vocab_size,
                 special_ids: set[int] | list[int] = special_ids) -> int:
    """
    Returns a random token ID that is not a special token.
    """
    while True:
        idx = random.randint(0, vocab_size-1)
        if idx not in special_ids:
            return idx

def masking(token_id: int, pad_idx: int, mask_idx: int) -> tuple[int, int]:
    """
    Mask a single token for Masked Language Model (MLM) training.

    Args:
        token_id (int): The token_id to be processed.
        pad_idx (int): Index of the [PAD] token.
        mask_idx (int): Index of the [MASK] token.

    Returns:
        tuple[int, int]: A tuple containing:
            1. The processed token id (masked, random, or original).
            2. The label for the token (original token id if masked, or pad_idx)

    """

    # The probability of a token being masked is 15%.
    mask = random.random() <= 0.15

    if not mask:
        token_ = token_id
        label_ = pad_idx
        return token_, label_

    # Generates a random float between 0 and 1
    random_float = random.random()

    # 80% of the selected tokens will be replaced by mask_idx
    if random_float < 0.8:
        token_ = mask_idx
        label_ = token_id
        return token_, label_

    # 10% of the selected tokens will remain unchanged
    if random_float > 0.9:
        token_ = token_id
        label_ = token_id
        return token_, label_

    # 10% of the selected tokens will be replaced by a random token id
    else:
        token_ = random_token()
        label_ = token_id
        return token_, label_
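
As a sanity check on the rule, the sketch below re-implements the same 15% / 80-10-10 logic in a standalone helper and measures the observed proportions over many draws. The helper `mask_one`, the seeded generator, and the sample token ID `2054` are assumptions made for this illustration; the special-token IDs and vocabulary size are the values reported earlier in this notebook.

```python
import random
from collections import Counter

PAD_ID, MASK_ID = 0, 103                 # IDs printed by the tokenizer cell
VOCAB_SIZE = 30522                       # BERT-base vocabulary size
SPECIAL_IDS = {0, 100, 101, 102, 103}    # [PAD], [UNK], [CLS], [SEP], [MASK]


def mask_one(token_id, rng):
    """Apply the masking rule to one token; return (new_id, label, outcome)."""
    if rng.random() >= 0.15:
        return token_id, PAD_ID, "kept"          # not selected
    r = rng.random()
    if r < 0.8:
        return MASK_ID, token_id, "mask"         # replaced by [MASK]
    if r < 0.9:
        new_id = rng.randrange(VOCAB_SIZE)       # replaced by a random token
        while new_id in SPECIAL_IDS:             # redraw special tokens
            new_id = rng.randrange(VOCAB_SIZE)
        return new_id, token_id, "random"
    return token_id, token_id, "unchanged"       # left as is, but still labeled


rng = random.Random(0)   # seeded so the run is repeatable
n = 200_000
counts = Counter(mask_one(2054, rng)[2] for _ in range(n))
selected = n - counts["kept"]

print(f"selected:  {selected / n:.3f}")
for name in ("mask", "random", "unchanged"):
    print(f"{name:<10} {counts[name] / selected:.3f}")
```

The printed fractions should land close to 0.15 for selection and close to 0.8, 0.1, and 0.1 among the selected tokens.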

In [7]:
from datasets import Dataset

PUNCT_IDX = {vocab['.'], vocab['?'], vocab['!']}

def data_for_MLM(dataset: Dataset,
                 tokenizer,
                 pad_idx: int,
                 mask_idx: int,
                 punct_idx: set[int] | list[int]) -> tuple[list, list]:
    """
    Prepare data for Masked Language Model (MLM) training.

    Args:
        dataset (Dataset): The dataset to be processed.
        tokenizer: The tokenizer to encode the text.
        pad_idx (int): Index of the [PAD] token.
        mask_idx (int): Index of the [MASK] token.
        punct_idx (set | list): A set or list of token IDs for sentence-ending punctuation.

    Returns:
        tuple[list, list]: A tuple containing:
            1. List of tokenized and masked sentences
            2. List of labels corresponding to the masked sentences.

    """

    masked_sentences = []
    mlm_labels = []

    for data in dataset:
        cur_tokens = []
        cur_labels = []
        tokens = tokenizer.encode(data['text'], add_special_tokens=False)

        for token_id in tokens:
            token_, label_ = masking(token_id, pad_idx, mask_idx)
            cur_tokens.append(token_)
            cur_labels.append(label_)

            # A sentence-ending token was found: store the sentence (if long enough) and reset.
            if token_id in punct_idx:
                if len(cur_tokens) > 2:
                    masked_sentences.append(cur_tokens)
                    mlm_labels.append(cur_labels)

                # Reset for the next sentence within the same document
                cur_tokens = []
                cur_labels = []

        # Append leftovers (if the document didn't end with ending punctuations)
        if cur_tokens:
            masked_sentences.append(cur_tokens)
            mlm_labels.append(cur_labels)

    return masked_sentences, mlm_labels
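
The sentence-splitting part of this function can be looked at in isolation. The sketch below uses a hypothetical helper, `split_sentences`, and made-up token IDs (90, 91, 92 stand in for the IDs of '.', '?', '!'). It keeps the same two behaviors as the loop above: a sentence that ends in punctuation is dropped if it has two tokens or fewer, while leftover tokens at the end of a document are kept regardless of length.

```python
def split_sentences(token_ids, end_ids, min_len=3):
    """Split a flat list of token IDs into sentences at sentence-ending IDs."""
    sentences, current = [], []
    for tid in token_ids:
        current.append(tid)
        if tid in end_ids:
            if len(current) >= min_len:   # same test as len(cur_tokens) > 2
                sentences.append(current)
            current = []
    if current:                           # leftovers without closing punctuation
        sentences.append(current)
    return sentences


END_IDS = {90, 91, 92}   # toy stand-ins for the IDs of '.', '?', '!'
doc = [11, 12, 13, 90, 14, 91, 15, 16, 17, 18, 92, 19, 20]

print(split_sentences(doc, END_IDS))
# [[11, 12, 13, 90], [15, 16, 17, 18, 92], [19, 20]]
```

The two-token sentence `[14, 91]` is discarded, while the trailing `[19, 20]` survives because the leftover branch has no length check.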


### Function for Preparing Data for Next Sentence Prediction (NSP)

BERT is trained to predict whether a pair of sentences are consecutive in the original text (**Next Sentence Prediction**).

* **Positive examples**: the second sentence follows the first sentence in the dataset.

* **Negative examples**: the second sentence is randomly selected from the dataset.

For each sentence pair, a label is created:

* 1 → Next sentence is correct (positive example)

* 0 → Next sentence is incorrect (negative example)

This allows the model to learn relationships between sentences in addition to word-level predictions.
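
Before the full function, a minimal sketch of what one positive and one negative pair look like. The helper `make_pair` and the toy sentences are assumptions for illustration; 101 and 102 are the [CLS] and [SEP] IDs reported by the tokenizer earlier in this notebook.

```python
import random

CLS, SEP = 101, 102   # [CLS] and [SEP] IDs from the tokenizer cell


def make_pair(first, second):
    """Return (input_ids, token_type_ids) for '[CLS] first [SEP] second [SEP]'."""
    seg_a = [CLS] + first + [SEP]
    seg_b = second + [SEP]
    return seg_a + seg_b, [0] * len(seg_a) + [1] * len(seg_b)


# Toy "sentences"; the token IDs are made up for illustration
sentences = [[11, 12], [21, 22, 23], [31], [41, 42], [51, 52]]
rng = random.Random(0)

# Positive example: sentence i followed by sentence i + 1
i = rng.randrange(len(sentences) - 1)
pos_ids, pos_types = make_pair(sentences[i], sentences[i + 1])
pos_label = 1

# Negative example: the second sentence is neither i itself nor a neighbor of i
candidates = [j for j in range(len(sentences)) if abs(j - i) > 1]
j = rng.choice(candidates)
neg_ids, neg_types = make_pair(sentences[i], sentences[j])
neg_label = 0

print(pos_ids, pos_types, pos_label)
print(neg_ids, neg_types, neg_label)
```

Excluding neighbors on both sides mirrors the `abs(idx_1 - idx_2) == 1` check used in the function that follows.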

In [8]:
def data_for_NSP(masked_sentences: list[list[int]],
                 mlm_labels: list[list[int]],
                 pad_idx: int,
                 cls_idx: int,
                 sep_idx: int) -> tuple[list, list, list]:
    """
    Prepare data for Next Sentence Prediction (NSP).

    Args:
        masked_sentences (list[list[int]]): List of masked sentences.
        mlm_labels (list[list[int]]): List of labels corresponding to masked sentences.
        pad_idx (int): Index of the [PAD] token.
        cls_idx (int): Index of the [CLS] token.
        sep_idx (int): Index of the [SEP] token.

    Returns:
        tuple[list, list, list]: A tuple containing:
            1. List of paired sentences with special token indices added.
            2. List of labels for masked token.
            3. List of NSP labels (1 for consecutive, 0 for random).
    """


    # Make sure the lengths of the inputs are valid
    num_sentence = len(masked_sentences)
    if num_sentence < 2:
        raise ValueError("masked_sentences must contain at least two sentences.")

    if num_sentence != len(mlm_labels):
        raise ValueError("masked_sentences and mlm_labels must have the same length.")


    paired_sents = []
    paired_mlm_labels = []
    nsp_labels = []

    # Create the list of sentence indices
    sentence_idx = list(range(num_sentence))

    while len(sentence_idx) >= 2:

        # ---Positive Pair (consecutive)---
        if random.random() >= 0.5:

            # Pick a starter index
            idx = random.choice(sentence_idx[:-1])

            # Pair (idx) and (idx+1) with special tokens as current inputs/labels
            cur_input = [[cls_idx] + masked_sentences[idx] + [sep_idx],
                          masked_sentences[idx + 1] + [sep_idx]]
            cur_label = [[pad_idx] + mlm_labels[idx] + [pad_idx],
                          mlm_labels[idx + 1] + [pad_idx]]

            paired_sents.append(cur_input)
            paired_mlm_labels.append(cur_label)
            nsp_labels.append(1)      # Append 1 to nsp_label, indicating 'IsNext'

            # Remove idx. Also remove idx+1 if it is still available
            sentence_idx.remove(idx)
            if (idx+1) in sentence_idx:
                sentence_idx.remove(idx+1)

        # ---Negative Pair (random)---
        else:

            # Pick two random indices
            idx_1, idx_2 = random.sample(sentence_idx, 2)

            # Make sure the two indices are NOT neighbors
            attempts = 0
            while abs(idx_1 - idx_2) == 1 and attempts < 5:
                idx_1, idx_2 = random.sample(sentence_idx, 2)
                attempts += 1

            # If the two indices are still neighbors after 5 tries, skip the iteration to avoid bad data.
            if abs(idx_1 - idx_2) == 1:
                continue

            cur_input = [[cls_idx] + masked_sentences[idx_1] + [sep_idx],
                         masked_sentences[idx_2] + [sep_idx]]
            cur_label = [[pad_idx] + mlm_labels[idx_1] + [pad_idx],
                         mlm_labels[idx_2] + [pad_idx]]

            paired_sents.append(cur_input)
            paired_mlm_labels.append(cur_label)
            nsp_labels.append(0)    # Append 0 to nsp_label, indicating 'NotNext'

            # Remove both indices
            sentence_idx.remove(idx_1)
            sentence_idx.remove(idx_2)

    return paired_sents, paired_mlm_labels, nsp_labels

### Function to Create Final BERT Training Data

The final function combines the **MLM** and **NSP** data preparation steps:

1. Randomly masks 15% of tokens and prepares the corresponding labels for **Masked Language Modeling (MLM)**.

2. Generates sentence pairs and labels for **Next Sentence Prediction (NSP)**.

3. Returns a **pandas DataFrame** containing all the processed data.

The output DataFrame can then be saved as a CSV file, which is later loaded in the `Reproducing BERT Model from Scratch using PyTorch.ipynb` notebook. In that notebook, the data is **converted to a PyTorch dataset** and loaded into a **PyTorch DataLoader** for training.

In [9]:
def data_for_BERT(dataset: Dataset,
                  tokenizer,
                  pad_idx: int,
                  mask_idx: int,
                  cls_idx: int,
                  sep_idx: int,
                  punct_idx: set[int] | list[int]) -> pd.DataFrame:
    """
    Prepare data for BERT training: Masked Language Modeling and Next Sentence Prediction.

    Args:
        dataset (Dataset): The dataset to be processed.
        tokenizer: The tokenizer to encode the text.
        pad_idx (int): Index of the [PAD] token.
        mask_idx (int): Index of the [MASK] token.
        cls_idx (int): Index of the [CLS] token.
        sep_idx (int): Index of the [SEP] token.
        punct_idx (set | list): A set or list of indices of ending punctuations.

    Returns:
        A Pandas dataframe with the following columns:
            input_ids - token IDs of paired sentences
            token_type_ids - token type IDs (0 for the first sentence, 1 for the second)
            attention_mask - 1 for real tokens, 0 for padding
            mlm_labels - labels for masked language model prediction
            nsp_labels - labels for next sentence prediction

    """

    # Get masked_sentences and the corresponding labels from the dataset
    masked_sentences, mask_labels = data_for_MLM(dataset, tokenizer, pad_idx, mask_idx, punct_idx)

    # Get paired_sentences, labels, and nsp_labels list
    paired_sentences, paired_mask_labels, nsp_labels = data_for_NSP(masked_sentences, mask_labels, pad_idx, cls_idx, sep_idx)

    input_ids, token_type_ids, attention_mask, mlm_labels = [], [], [], []

    for sentences, labels in zip(paired_sentences, paired_mask_labels):

        # flatten paired sentences and mask_labels
        padded_sent = sentences[0] + sentences[1]
        padded_label = labels[0] + labels[1]

        # Create token types (0 for the first sentence, 1 for the second)
        token_type = [0] * len(sentences[0]) + [1] * len(sentences[1])

        # Create attention mask (all 1s here, since no padding has been added yet)
        mask = [1] * len(sentences[0]) + [1] * len(sentences[1])

        # Append data to final lists
        input_ids.append(padded_sent)
        mlm_labels.append(padded_label)
        token_type_ids.append(token_type)
        attention_mask.append(mask)

    # Create dataframe
    bert_data = pd.DataFrame({
        'input_ids': input_ids,
        'token_type_ids': token_type_ids,
        'attention_mask': attention_mask,
        'mlm_labels': mlm_labels,
        'nsp_labels': nsp_labels
    })

    return bert_data
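
Note that the rows produced by this function have different lengths and every attention-mask entry is 1, because nothing is padded at this stage. The sketch below shows one way rows might be padded to a common length later on. The helper `pad_batch` is hypothetical and the training notebook may do this differently (for example with `torch.nn.utils.rnn.pad_sequence`); the pad ID 0 is the [PAD] ID reported earlier.

```python
PAD_ID = 0  # [PAD] ID reported by the tokenizer cell


def pad_batch(rows, pad_value=PAD_ID, max_len=None):
    """Right-pad (and truncate) lists to a common length; return (padded, mask)."""
    if max_len is None:
        max_len = max(len(r) for r in rows)
    padded, mask = [], []
    for r in rows:
        r = r[:max_len]                       # truncate rows that are too long
        n_pad = max_len - len(r)
        padded.append(r + [pad_value] * n_pad)
        mask.append([1] * len(r) + [0] * n_pad)   # 1 = real token, 0 = padding
    return padded, mask


batch = [[101, 7, 8, 102, 9, 102], [101, 5, 102, 6, 102]]
input_ids, attention_mask = pad_batch(batch)

print(input_ids)       # [[101, 7, 8, 102, 9, 102], [101, 5, 102, 6, 102, 0]]
print(attention_mask)  # [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 0]]
```

The same helper could be applied to `token_type_ids` and `mlm_labels` so that all four columns stay aligned.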

In [10]:
# Process imdb_data for BERT training
imdb_bert_data = data_for_BERT(imdb_data, bert_tokenizer, PAD_IDX, MASK_IDX, CLS_IDX, SEP_IDX, PUNCT_IDX)
imdb_bert_data.head()

Token indices sequence length is longer than the specified maximum sequence length for this model (718 > 512). Running this sequence through the model will result in indexing errors


| | input_ids | token_type_ids | attention_mask | mlm_labels | nsp_labels |
|---|---|---|---|---|---|
| 0 | [101, 11387, 2005, 1996, 3787, 1997, 3348, 101... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | [0, 2298, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | 1 |
| 1 | [101, 2010, 3566, 2003, 1037, 16299, 103, 1996... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, ... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | [0, 0, 0, 0, 0, 0, 1998, 0, 0, 0, 0, 0, 2028, ... | 1 |
| 2 | [101, 3100, 1010, 1045, 2001, 11471, 1998, 103... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | [0, 0, 0, 0, 0, 11471, 0, 2787, 0, 0, 0, 0, 0,... | 0 |
| 3 | [101, 1998, 1045, 1005, 103, 2469, 2008, 2065,... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | [0, 0, 0, 0, 1049, 0, 0, 0, 0, 0, 2009, 0, 0, ... | 0 |
| 4 | [101, 1045, 2074, 2293, 1996, 10256, 4038, 199... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, ... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ... | 0 |


In [11]:
imdb_bert_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 185667 entries, 0 to 185666
Data columns (total 5 columns):
 #   Column          Non-Null Count   Dtype 
---  ------          --------------   ----- 
 0   input_ids       185667 non-null  object
 1   token_type_ids  185667 non-null  object
 2   attention_mask  185667 non-null  object
 3   mlm_labels      185667 non-null  object
 4   nsp_labels      185667 non-null  int64 
dtypes: int64(1), object(4)
memory usage: 7.1+ MB


In [12]:
# Save the dataframe as CSV
imdb_bert_data.to_csv('imdb_bert_data.csv', index=False)
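
One thing to keep in mind when reading this file back: CSV stores each list-valued cell as text, so a column such as `input_ids` comes back as a string like `"[101, 2023, ...]"` and has to be parsed into a list again. The sketch below illustrates the round trip with the standard library only, using an in-memory buffer and a made-up row instead of the real file.

```python
import ast
import csv
import io

# Write one made-up row the way a list-valued column ends up in a CSV file
buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(["input_ids", "nsp_labels"])
writer.writerow([str([101, 2023, 102, 2003, 102]), 1])

# Read it back: every field is now a string
buffer.seek(0)
row = next(csv.DictReader(buffer))
print(type(row["input_ids"]).__name__)   # str

# ast.literal_eval safely turns the text back into a list of ints
input_ids = ast.literal_eval(row["input_ids"])
nsp_label = int(row["nsp_labels"])
print(input_ids)                          # [101, 2023, 102, 2003, 102]
```

With pandas, passing `converters={'input_ids': ast.literal_eval}` (and likewise for the other list columns) to `pd.read_csv` is one way to get the same result at load time.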