<a href="https://colab.research.google.com/github/Doris-QZ/Reproducing-BERT-from-Scratch-with-PyTorch/blob/main/1_Data_For_BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Introduction

This notebook prepares the **IMDB dataset** for BERT-style pre-training. The data is processed to be **compatible with BERT’s training objectives**, including:

* **Masked Language Modeling (MLM)**: randomly masks tokens for prediction.

* **Next Sentence Prediction (NSP)**: creates paired sentences with both positive and negative examples.

Although this is not the same corpus used in the original BERT paper (BooksCorpus + Wikipedia), the dataset is structured in a way that allows the model to be trained with the same objectives.


In [None]:
from datasets import load_dataset

import nltk
from nltk.tokenize import word_tokenize
nltk.download('punkt_tab')

import torch
from torch.nn.utils.rnn import pad_sequence

import pandas as pd
from collections import Counter
import random

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


In [None]:
# Suppress the warning
import warnings
warnings.filterwarnings('ignore', category=UserWarning)

# Load the data
imdb_data = load_dataset('imdb', split='train')


In [None]:
# Take a look at the data
print(f"{imdb_data}\n")
imdb_data[0]

Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})



{'text': 'I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far be

### Build the dictionary

First, we'll build a vocabulary that maps tokens, including special tokens, to their corresponding indices. Note that the BERT paper uses **WordPiece (a sub-word tokenization algorithm)** with a vocabulary of 30,000 tokens. Here, however, we use **NLTK’s word-based tokenizer**, which results in a much larger vocabulary.

In [None]:
def build_vocab(dataset, special_tokens):
    """
    Build a vocabulary mapping tokens to indices.

    Args:
        dataset (Iterable[str]): The texts from which to build the vocabulary.
        special_tokens (List[str]): A list of special tokens to include in the vocabulary.

    Returns:
        A dictionary that maps tokens to their corresponding indices.

    """
    # Special tokens take the first indices
    token_to_idx = {token: i for i, token in enumerate(special_tokens)}

    # Count word frequencies across all texts
    word_freq = Counter()
    for text in dataset:
        word_freq.update(word_tokenize(text))

    # Sort words by frequency, most frequent first (no vocabulary size cap is applied)
    sorted_words = sorted(word_freq.items(), key=lambda item: item[1], reverse=True)

    # Add words to token_to_idx, continuing after the special tokens
    for idx, (word, _) in enumerate(sorted_words, start=len(special_tokens)):
        token_to_idx[word] = idx

    return token_to_idx

# Build the vocabulary from the dataset
vocab = build_vocab(dataset=imdb_data['text'],
                    special_tokens=['[PAD]','[CLS]', '[SEP]','[MASK]','[UNK]']
                    )

len(vocab)

132093
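
Because the vocabulary above is built from the same split that is later encoded, every lookup succeeds on this data. For text outside that split, the reserved `[UNK]` index could serve as a fallback. The sketch below illustrates the idea on a toy corpus. It uses `str.split` in place of `word_tokenize` to stay dependency-free, and `build_toy_vocab` and `encode` are hypothetical helpers that are not part of this notebook.

```python
from collections import Counter

SPECIAL_TOKENS = ['[PAD]', '[CLS]', '[SEP]', '[MASK]', '[UNK]']

def build_toy_vocab(texts, special_tokens=SPECIAL_TOKENS):
    """Give special tokens the first indices, then words by descending frequency."""
    token_to_idx = {tok: i for i, tok in enumerate(special_tokens)}
    word_freq = Counter()
    for text in texts:
        word_freq.update(text.split())  # whitespace split stands in for word_tokenize
    for word, _ in word_freq.most_common():
        token_to_idx[word] = len(token_to_idx)
    return token_to_idx

def encode(tokens, vocab):
    """Convert tokens to indices, falling back to [UNK] for unseen tokens."""
    unk = vocab['[UNK]']
    return [vocab.get(tok, unk) for tok in tokens]

toy_vocab = build_toy_vocab(["the movie was great", "the plot was thin"])
ids = encode("the acting was great".split(), toy_vocab)
# ids == [5, 4, 6, 8]; 'acting' is unseen, so it maps to the [UNK] index 4
```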

### Function for Preparing Data for Masked Language Modeling (MLM)
According to the BERT paper, **15% of tokens are randomly selected for masking**.  

For example:

> Original text: "my dog is very cute."  
> Selected token: "cute"

For the selected tokens:
* 80% of the time ---> replaced with [MASK] ---> "my dog is very **[MASK]**."
* 10% of the time ---> left unchanged ---> "my dog is very **cute**."
* 10% of the time ---> replaced with a random token ---> "my dog is very **apple**."
<br>

Regardless of how the selected token is altered, the label remains the same (one label per token, including the final period):
>[PAD] [PAD] [PAD] [PAD] cute [PAD]

During training, positions labeled [PAD] are **ignored**, and the model predicts the original tokens at the selected positions.
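
To make the label layout concrete, the sketch below builds the label sequence for the example sentence and lists the positions that would count toward the MLM loss. The selected position is fixed by hand rather than sampled, and `make_labels` is a hypothetical helper used only for illustration.

```python
PAD = '[PAD]'

def make_labels(tokens, selected_positions):
    """Return the original token at selected positions and [PAD] elsewhere."""
    return [tok if i in selected_positions else PAD for i, tok in enumerate(tokens)]

tokens = ['my', 'dog', 'is', 'very', 'cute', '.']
selected = {4}  # assumption: 'cute' is the token chosen for masking

labels = make_labels(tokens, selected)
# labels == ['[PAD]', '[PAD]', '[PAD]', '[PAD]', 'cute', '[PAD]']

# The three possible inputs all share this single label sequence.
masked_input = ['my', 'dog', 'is', 'very', '[MASK]', '.']
unchanged_input = list(tokens)
random_input = ['my', 'dog', 'is', 'very', 'apple', '.']

# Only positions whose label is not [PAD] count toward the loss.
loss_positions = [i for i, label in enumerate(labels) if label != PAD]
# loss_positions == [4]
```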

In [None]:
def masking(token, vocab):
    """
    Helper function to mask a single token for Masked Language Model (MLM) training.

    Args:
        token (str): The token to be processed.
        vocab (dict): The vocabulary to choose random tokens from.

    Returns:
        tuple of two str:
            1. The processed token, which may be replaced with '[MASK]', left unchanged, or replaced with a random token.
            2. The label for the token: the original token if it was selected, or '[PAD]' if it was not.

    """
    # Each token is selected for masking with probability 15%.
    mask = random.random() < 0.15

    if not mask:
        # Unselected tokens pass through unchanged with a [PAD] label
        return token, '[PAD]'

    # Draw a second random float in [0, 1) to decide how the selected token is altered
    random_float = random.random()

    # 80% of the selected tokens are replaced by the [MASK] token
    if random_float < 0.8:
        return '[MASK]', token

    # 10% of the selected tokens remain unchanged
    if random_float > 0.9:
        return token, token

    # The remaining 10% are replaced by a random token from the vocabulary
    # (special tokens can be drawn too, since they are part of vocab)
    random_idx = random.randint(0, len(vocab) - 1)
    random_token = list(vocab.keys())[random_idx]
    return random_token, token
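
As a quick sanity check on the 15% selection rate and the 80/10/10 split, the sketch below counts outcomes over many draws. It re-implements only the decision logic in a hypothetical `classify_draw` helper with a seeded `random.Random` so the run is repeatable; it does not call `masking` itself.

```python
import random
from collections import Counter

def classify_draw(rng):
    """Return which masking outcome a single token would receive."""
    if rng.random() >= 0.15:
        return 'not_selected'
    r = rng.random()
    if r < 0.8:
        return 'mask'
    if r < 0.9:
        return 'random'
    return 'unchanged'

rng = random.Random(0)  # fixed seed, chosen arbitrarily
n_draws = 100_000
counts = Counter(classify_draw(rng) for _ in range(n_draws))

n_selected = n_draws - counts['not_selected']
selected_rate = n_selected / n_draws                 # expected near 0.15
mask_share = counts['mask'] / n_selected             # expected near 0.80
random_share = counts['random'] / n_selected         # expected near 0.10
unchanged_share = counts['unchanged'] / n_selected   # expected near 0.10
```

The shares are measured relative to the selected tokens, so roughly 12% of all tokens end up as `[MASK]`.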

In [None]:
def data_for_MLM(dataset, vocab):
    """
    Prepare data for Masked Language Modeling (MLM) training.

    Args:
        dataset (Dataset): The dataset to be processed.
        vocab (dict): Vocabulary built from the dataset.

    Returns:
        tuple of two lists:
            1. List of tokenized sentences, where each token is either replaced with '[MASK]', left unchanged, or replaced with a random token.
            2. List of label sequences aligned with the tokenized sentences. Each label is either the original token
               at selected positions or '[PAD]' at unselected positions.
    """

    end_of_sentence = {'.', '?', '!'}
    masked_sentences = []
    labels = []

    for data in dataset:
        # Start each review with empty buffers so that leftover tokens from the
        # previous review are not merged into this review's first sentence.
        cur_tokens = []
        cur_labels = []

        for token in word_tokenize(data['text']):
            token_, label_ = masking(token, vocab)
            cur_tokens.append(token_)
            cur_labels.append(label_)

            # An end-of-sentence token closes the current sentence.
            if token in end_of_sentence:
                # Keep only sentences with more than two tokens
                if len(cur_tokens) > 2:
                    masked_sentences.append(cur_tokens)
                    labels.append(cur_labels)
                cur_tokens = []
                cur_labels = []

        # Note: trailing tokens of a review that are not followed by ending punctuation are dropped here. They could be appended as a final sentence if desired.

    return masked_sentences, labels
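
The sentence-splitting rule in `data_for_MLM` can be seen in isolation on a review that is already tokenized. The sketch below is a simplified stand-in: the hypothetical `split_sentences` does no masking. It closes a sentence at `.`, `?`, or `!`, keeps it only if it has more than two tokens, and, as the note in the cell above describes, drops trailing tokens that never reach an end mark.

```python
END_MARKS = {'.', '?', '!'}

def split_sentences(tokens, min_len=3):
    """Group tokens into sentences that end with '.', '?', or '!'."""
    sentences, current = [], []
    for tok in tokens:
        current.append(tok)
        if tok in END_MARKS:
            if len(current) >= min_len:  # same rule as len(cur_tokens) > 2
                sentences.append(current)
            current = []
    return sentences  # tokens left in `current` are discarded

tokens = ['Great', 'film', '.', 'Why', '?', 'I', 'loved', 'it', '!', 'No', 'ending']
sentences = split_sentences(tokens)
# sentences == [['Great', 'film', '.'], ['I', 'loved', 'it', '!']]
```

Here `['Why', '?']` is too short to keep, and `'No', 'ending'` has no end mark.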


### Function for Preparing Data for Next Sentence Prediction (NSP)

BERT is trained to predict whether a pair of sentences are consecutive in the original text (**Next Sentence Prediction**).

* **Positive examples**: the second sentence follows the first sentence in the dataset.

* **Negative examples**: the second sentence is randomly selected from the dataset.

For each sentence pair, a label is created:

* 1 → Next sentence is correct (positive example)

* 0 → Next sentence is incorrect (negative example)

This allows the model to learn relationships between sentences in addition to word-level predictions.
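
A minimal sketch of one positive and one negative pair on toy sentences is shown below. `make_pair` is a hypothetical helper, the seed is arbitrary, and excluding the true next sentence from the negative candidates is an assumption of this sketch rather than something the description above requires.

```python
import random

def make_pair(first, second):
    """Wrap two tokenized sentences in BERT's special tokens."""
    return [['[CLS]'] + first + ['[SEP]'], second + ['[SEP]']]

sentences = [
    ['I', 'rented', 'the', 'film', '.'],
    ['It', 'was', 'dull', '.'],
    ['The', 'cast', 'was', 'fine', '.'],
    ['I', 'would', 'skip', 'it', '.'],
]

rng = random.Random(0)  # arbitrary seed for repeatability

# Positive example: sentence i followed by sentence i + 1, labeled 1.
i = rng.randrange(len(sentences) - 1)
positive = (make_pair(sentences[i], sentences[i + 1]), 1)

# Negative example: the second sentence is drawn from the other sentences, labeled 0.
candidates = [j for j in range(len(sentences)) if j not in (i, i + 1)]
j = rng.choice(candidates)
negative = (make_pair(sentences[i], sentences[j]), 0)
```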

In [None]:
def data_for_NSP(masked_sentences, labels):
    """
    Prepare data for Next Sentence Prediction (NSP).

    Args:
        masked_sentences (List): List of tokenized (and masked) sentences.
        labels (List): List of MLM labels corresponding to masked_sentences.

    Returns:
        tuple of three lists:
            1. List of paired sentences with special tokens added.
            2. List of paired MLM labels, aligned with the paired sentences.
            3. List of integers for Next Sentence Prediction (1 = consecutive, 0 = not consecutive).
    """

    # Make sure the lengths of the inputs are valid
    num_sentence = len(masked_sentences)
    if num_sentence < 2:
        raise ValueError("masked_sentences must contain at least two sentences.")

    if num_sentence != len(labels):
        raise ValueError("masked_sentences and labels must have the same length.")


    paired_inputs = []
    paired_labels = []
    is_next = []

    # Create the list of sentence indices
    sentence_idx = list(range(num_sentence))

    while len(sentence_idx) >= 2:
        if random.random() >= 0.5:
            # Randomly choose an index from sentence_idx, excluding the last
            # remaining one so that idx + 1 is always a valid sentence index
            idx = random.choice(sentence_idx[:-1])

            # Pair two consecutive sentences (idx, idx + 1) with special tokens as current inputs/labels.
            # Note: sentence idx + 1 is used here even if an earlier pair already consumed it.
            cur_input = [['[CLS]'] + masked_sentences[idx] + ['[SEP]'],
                          masked_sentences[idx + 1] + ['[SEP]']]
            cur_label = [['[PAD]'] + labels[idx] + ['[PAD]'],
                          labels[idx + 1] + ['[PAD]']]

            # Add current inputs/labels to paired_inputs/paired_labels
            paired_inputs.append(cur_input)
            paired_labels.append(cur_label)

            # Append 1 to is_next, indicating that the current inputs are consecutive sentences.
            is_next.append(1)

            # Remove idx, and idx + 1 if it is still available, from sentence_idx
            sentence_idx.remove(idx)
            if idx + 1 in sentence_idx:
                sentence_idx.remove(idx + 1)

        else:
            # Randomly sample two distinct indices from sentence_idx
            idx_1, idx_2 = random.sample(sentence_idx, 2)

            # Pair the two randomly selected sentences (idx_1, idx_2) with special tokens as current inputs/labels
            cur_input = [['[CLS]'] + masked_sentences[idx_1] + ['[SEP]'],
                         masked_sentences[idx_2] + ['[SEP]']]
            cur_label = [['[PAD]'] + labels[idx_1] + ['[PAD]'],
                         labels[idx_2] + ['[PAD]']]

            # Add current inputs/labels to paired_inputs/paired_labels
            paired_inputs.append(cur_input)
            paired_labels.append(cur_label)

            # Append 0 to is_next, indicating that the current inputs are not consecutive sentences.
            is_next.append(0)

            # Remove idx_1 and idx_2 from sentence_idx
            sentence_idx.remove(idx_1)
            sentence_idx.remove(idx_2)

    return paired_inputs, paired_labels, is_next

### Function to Create Final BERT Training Data

The final function combines the **MLM** and **NSP** data preparation steps:

1. Randomly masks 15% of tokens and prepares the corresponding labels for **Masked Language Modeling (MLM)**.

2. Generates sentence pairs and labels for **Next Sentence Prediction (NSP)**.

3. Returns a **pandas DataFrame** containing all the processed data.

The output DataFrame can then be saved as a CSV file, which is later loaded in the `Reproducing BERT Model from Scratch using PyTorch.ipynb` notebook. In that notebook, the data is **converted to a PyTorch dataset** and loaded into a **PyTorch DataLoader** for training.

In [None]:
def data_for_BERT(dataset, vocab):
    """
    Prepare data for BERT training: Masked Language Modeling and Next Sentence Prediction.

    Args:
        dataset (Dataset): The dataset to be processed.
        vocab (dict): The vocabulary built from the dataset.

    Returns:
        A pandas DataFrame with the columns 'bert_inputs', 'bert_labels',
        'segment_labels', and 'is_next'.

    """
    # Pad both halves of a pair to the same length and flatten the nested list
    def pad_flatten(pairs, padding='[PAD]'):
        max_len = max(len(pairs[0]), len(pairs[1]))
        pairs[0].extend([padding] * (max_len - len(pairs[0])))
        pairs[1].extend([padding] * (max_len - len(pairs[1])))
        return [item for sublist in pairs for item in sublist]

    # Convert tokens to indices
    def token_to_idx(tokens):
        return [vocab[token] for token in tokens]

    # Get masked_sentences and the corresponding labels from the dataset
    masked_sentences, mlm_labels = data_for_MLM(dataset, vocab)

    # Get paired sentences, paired labels, and the is_next list
    paired_sentences, paired_labels, is_next = data_for_NSP(masked_sentences, mlm_labels)

    bert_inputs, bert_labels, segment_labels = [], [], []

    for sentences, pair_labels in zip(paired_sentences, paired_labels):
        # Create and pad segment labels (1 = first sentence, 2 = second sentence, 0 = padding)
        seg_label = [[1] * len(sentences[0]), [2] * len(sentences[1])]
        seg_label = pad_flatten(seg_label, padding=0)

        # Pad and flatten paired sentences and labels
        padded_sent = pad_flatten(sentences)
        padded_label = pad_flatten(pair_labels)

        # Convert tokens to indices and add to final lists
        bert_inputs.append(token_to_idx(padded_sent))
        bert_labels.append(token_to_idx(padded_label))
        segment_labels.append(seg_label)

    # Create a DataFrame of the BERT data
    bert_data = pd.DataFrame({
        'bert_inputs': bert_inputs,
        'bert_labels': bert_labels,
        'segment_labels': segment_labels,
        'is_next': is_next
    })

    return bert_data
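
The padding and segment-label step can be illustrated on a single toy pair. The sketch below is a non-mutating variant of the `pad_flatten` idea, written for illustration only; the toy sentences are assumptions, and the segment values (1, 2, and 0 for padding) follow the function above.

```python
def pad_flatten(pair, padding='[PAD]'):
    """Pad both halves of a pair to the longer length, then concatenate them."""
    max_len = max(len(pair[0]), len(pair[1]))
    padded = [half + [padding] * (max_len - len(half)) for half in pair]
    return padded[0] + padded[1]

pair = [['[CLS]', 'It', 'was', 'dull', '.', '[SEP]'],
        ['Skip', 'it', '.', '[SEP]']]
segments = [[1] * len(pair[0]), [2] * len(pair[1])]

flat_tokens = pad_flatten(pair)
# ['[CLS]', 'It', 'was', 'dull', '.', '[SEP]',
#  'Skip', 'it', '.', '[SEP]', '[PAD]', '[PAD]']

flat_segments = pad_flatten(segments, padding=0)
# [1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 0, 0]
```

Tokens and segment labels stay aligned position by position, which is what allows them to be stored side by side in the DataFrame.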

In [None]:
# Process imdb_data for BERT training
imdb_bert_data = data_for_BERT(imdb_data,  vocab)
imdb_bert_data.head()

| | bert_inputs | bert_labels | segment_labels | is_next |
|---|---|---|---|---|
| 0 | [1, 32, 3, 342, 333, 160, 3, 19356, 8, 141, 42... | [0, 32, 22, 0, 0, 0, 42, 0, 0, 0, 0, 0, 0, 0, ... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | 1 |
| 1 | [1, 24, 200, 23, 1556, 74, 3, 3, 42, 882, 52, ... | [0, 0, 0, 0, 0, 0, 12, 11, 0, 0, 0, 0, 0, 0, 0... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | 0 |
| 2 | [1, 1055, 27, 1648, 6, 996, 19, 61, 7, 2, 1055... | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 26, 5,... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, ... | 1 |
| 3 | [1, 221, 103, 3, 86, 28727, 12, 327, 3, 9, 136... | [0, 0, 0, 73, 0, 0, 0, 0, 1863, 0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | 1 |
| 4 | [1, 3, 3, 20, 41, 3782, 42, 9146, 3155, 41, 66... | [0, 54, 45796, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0... | [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... | 1 |


In [None]:
imdb_bert_data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 150288 entries, 0 to 150287
Data columns (total 4 columns):
 #   Column          Non-Null Count   Dtype 
---  ------          --------------   ----- 
 0   bert_inputs     150288 non-null  object
 1   bert_labels     150288 non-null  object
 2   segment_labels  150288 non-null  object
 3   is_next         150288 non-null  int64 
dtypes: int64(1), object(3)
memory usage: 4.6+ MB


In [None]:
# Save the dataframe as CSV
imdb_bert_data.to_csv('imdb_bert_data.csv', index=False)
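
Because CSV has no list type, the list-valued columns are stored as text and come back as strings when the file is read. The sketch below shows this round trip with the standard-library `csv` module on a single made-up row, and parses the string back into a list with `ast.literal_eval`. The row values are illustrative only.

```python
import ast
import csv
import io

# Write one row with a list-valued cell, serialized as its string form.
buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(['bert_inputs', 'is_next'])
writer.writerow([str([1, 32, 3, 2]), 1])

# Read it back: the list arrives as a string, not as a list.
buffer.seek(0)
row = next(csv.DictReader(buffer))
raw = row['bert_inputs']        # '[1, 32, 3, 2]'
parsed = ast.literal_eval(raw)  # [1, 32, 3, 2]
```

When loading with pandas, the same parser can be applied per column through the `converters` argument of `pd.read_csv`.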