## 1. Import Modules and Data

According to the original paper <a href="https://arxiv.org/abs/1810.04805">BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding</a>, BERT is pretrained on BooksCorpus (800M words) and English Wikipedia (2,500M words). More information about both datasets is available on their Hugging Face pages: [bookcorpus](https://huggingface.co/datasets/bookcorpus/bookcorpus) and [wikipedia](https://huggingface.co/datasets/wikimedia/wikipedia).


BERT is pretrained on two unsupervised tasks, described in Section 3.1 of the original paper:

1. Masked LM: The training data generator randomly selects 15% of token positions for prediction. If the i-th token is selected, it is replaced with: (1) the [MASK] token 80% of the time, (2) a random token 10% of the time, or (3) the original i-th token 10% of the time. The model then predicts the original token using cross-entropy loss.

2. Next Sentence Prediction (NSP): This is a binary classification task. When choosing sentences A and B for each pre-training example, 50% of the time B is the actual sentence that follows A (labeled IsNext), and the other 50% of the time B is a random sentence from the corpus (labeled NotNext).
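The two data-generation rules above can be sketched in plain Python. This is a minimal illustration, not the code in the project's `data` module, and it rests on a few assumptions: `MASK_ID = 103` and `VOCAB_SIZE = 30522` assume the bert-base-uncased vocabulary, `IGNORE_INDEX = -100` is an assumed label for positions that are not predicted, and each position is selected independently with probability 0.15, so the 15% holds on average rather than exactly per sequence. The sketch also does not exclude special tokens such as [CLS] and [SEP] from masking, which a real pipeline should. For NSP it uses 1 for IsNext; note that Hugging Face's `BertForPreTraining` uses the opposite convention (0 means B continues A), so check which one your model expects.

```python
import random

MASK_ID = 103        # assumed: id of [MASK] in the bert-base-uncased vocabulary
VOCAB_SIZE = 30522   # assumed: bert-base-uncased vocabulary size
IGNORE_INDEX = -100  # assumed: label for positions that are not predicted

def mask_tokens(token_ids, rng, mask_prob=0.15):
    """Apply the 80/10/10 Masked LM rule; return (inputs, labels)."""
    inputs = list(token_ids)
    labels = [IGNORE_INDEX] * len(token_ids)
    for i, tok in enumerate(token_ids):
        if rng.random() >= mask_prob:
            continue                      # not selected: no prediction at this position
        labels[i] = tok                   # the model must recover the original token
        r = rng.random()
        if r < 0.8:
            inputs[i] = MASK_ID           # 80%: replace with [MASK]
        elif r < 0.9:
            inputs[i] = rng.randrange(VOCAB_SIZE)  # 10%: replace with a random token
        # remaining 10%: keep the original token
    return inputs, labels

def make_nsp_pair(sentences, idx, rng):
    """Return (A, B, is_next), where is_next = 1 means B really follows A."""
    a = sentences[idx]
    if idx + 1 < len(sentences) and rng.random() < 0.5:
        return a, sentences[idx + 1], 1   # IsNext: the actual next sentence
    return a, rng.choice(sentences), 0    # NotNext: a random sentence from the corpus

rng = random.Random(0)
print(mask_tokens([2023, 2003, 1037, 7099, 6251], rng))  # arbitrary token ids
print(make_nsp_pair(["s0", "s1", "s2", "s3"], 1, rng))
```

Because the NotNext branch samples from the whole corpus, it can occasionally pick the true next sentence; this simplified sketch keeps the NotNext label in that case.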

```python
import itertools
import os

from data import load_data
import config

tokenizer, bookcorpus_dl = load_data("bookcorpus", loading_ratio=0.1)
_, wikipedia_dl = load_data("wikipedia", loading_ratio=1/41)


dataloader_size = len(bookcorpus_dl) + len(wikipedia_dl)
print("bookcorpus_dl size:", len(bookcorpus_dl))
print("wikipedia_dl size:", len(wikipedia_dl))
print("Total Dataloader Size:", dataloader_size)
```

```
Downloaded at  [/home/bks/.cache/huggingface/datasets/downloads/4ae3271c145b7cf080a16f6a2e0fa1d6410e16d8e50c32966328a2b721cf86a4.cfcd7c0ef3521898b77bd4829f8599ad4bc97b70065998fbda8f675af882a783 (origin=https://hf-mirror.com/datasets/bookcorpus/bookcorpus/resolve/refs%2Fconvert%2Fparquet/plain_text/train/0000.parquet?download=true)]
Downloaded at  [/home/bks/.cache/huggingface/datasets/downloads/09e1288f58ee4cdc9195bfc8ffcfa2cb28160e6201045fc17901c4b14ac61f08.d220069fc80398db604d42d5a6b1718d7a4699b18cfa1209a922e783f7cccb71 (origin=https://huggingface.co/datasets/wikimedia/wikipedia/resolve/refs%2Fconvert%2Fparquet/20231101.en/train/0000.parquet)]
bookcorpus_dl size: 3902000
wikipedia_dl size: 2438800
Total Dataloader Size: 6340800
```


## 2. Build Model
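The key architectural difference between BERT and GPT is that BERT's self-attention layers do not use a causal mask: every token can attend to tokens on both its left and its right, so the model captures bidirectional context in every layer. This is also why BERT cannot simply reuse GPT's next-token prediction objective: with full bidirectional attention, the token to be predicted would be directly visible to the model. BERT therefore relies on the Masked LM and NSP tasks described above instead.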
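The difference between the two masks is easiest to see on a tiny example. The sketch below is plain Python with hypothetical helper names (it is not the implementation in `modules.bert`): it applies a GPT-style causal mask and a BERT-style padding-only mask to the same equal attention scores, assuming the usual `attention_mask` convention of 1 for real tokens and 0 for padding.

```python
import math

def causal_mask(seq_len):
    """GPT-style: position i may attend only to positions j <= i."""
    return [[j <= i for j in range(seq_len)] for i in range(seq_len)]

def bidirectional_mask(attention_mask):
    """BERT-style: every position may attend to every non-padding position."""
    n = len(attention_mask)
    return [[attention_mask[j] == 1 for j in range(n)] for _ in range(n)]

def masked_softmax(scores, allowed):
    """Softmax over the allowed positions; disallowed positions get weight 0."""
    m = max(s for s, ok in zip(scores, allowed) if ok)
    exps = [math.exp(s - m) if ok else 0.0 for s, ok in zip(scores, allowed)]
    total = sum(exps)
    return [e / total for e in exps]

# Four tokens, the last one is padding; all raw attention scores are equal.
scores = [0.0, 0.0, 0.0, 0.0]
print(masked_softmax(scores, causal_mask(4)[1]))                    # → [0.5, 0.5, 0.0, 0.0]
print(masked_softmax(scores, bidirectional_mask([1, 1, 1, 0])[1]))  # → [0.3333333333333333, 0.3333333333333333, 0.3333333333333333, 0.0]
```

Token 1 can see token 2 only under the bidirectional mask, which is exactly why predicting token 2 from position 1 would leak the answer in BERT.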

```python
import torch
from modules.bert import BertForPreTraining

# Use the GPU if one is available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
bert = BertForPreTraining(
    vocab_size=tokenizer.vocab_size,
    type_vocab_size=2,
    hidden_size=config.hidden_size,
    max_len=config.max_len,
    num_hidden_layers=config.num_layers,
    num_attention_heads=config.attention_heads,
    intermediate_size=config.intermediate_size,
    dropout=config.dropout,
    pad_token_idx=tokenizer.pad_token_id
).to(device)
```
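As a sanity check on the model built above, its parameter count can be worked out by hand. The sketch below assumes the BERT-base configuration (hidden size 768, 12 layers, intermediate size 3072, `max_len` 512, bert-base-uncased vocabulary of 30,522 tokens) and the layer layout of Hugging Face's `BertForPreTraining`, including a pooler and a masked-LM decoder whose weight matrix is tied to the word embeddings, so only its 30,522-element bias adds new parameters. These values are assumptions about `config` and `modules.bert`, not read from them; under them the arithmetic reproduces the `Total Parameters: 110106428` printed by the trainer.

```python
def bert_pretraining_param_count(vocab_size=30522, hidden=768, num_layers=12,
                                 intermediate=3072, max_len=512, type_vocab_size=2):
    """Count parameters of a BERT encoder plus MLM and NSP heads (tied decoder weight)."""
    def linear(n_in, n_out):
        return n_in * n_out + n_out           # weight matrix + bias

    layer_norm = 2 * hidden                   # scale (gamma) + shift (beta)
    embeddings = (vocab_size + max_len + type_vocab_size) * hidden + layer_norm
    per_layer = (
        3 * linear(hidden, hidden)            # query, key, value projections
        + linear(hidden, hidden) + layer_norm # attention output projection + LayerNorm
        + linear(hidden, intermediate)        # feed-forward expansion
        + linear(intermediate, hidden) + layer_norm  # feed-forward projection + LayerNorm
    )
    pooler = linear(hidden, hidden)           # [CLS] pooler feeding the NSP head
    mlm_head = linear(hidden, hidden) + layer_norm + vocab_size  # transform + decoder bias only
    nsp_head = linear(hidden, 2)              # IsNext / NotNext classifier
    return embeddings + num_layers * per_layer + pooler + mlm_head + nsp_head

print(bert_pretraining_param_count())  # → 110106428
```

About 85M of the 110M parameters sit in the 12 Transformer layers and about 23.4M in the word-embedding matrix, which is why vocabulary size matters so much for small models.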

## 3. Pretrain Model
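### 3.1 Trainer
Here I define a `BERTTrainer` class that holds the pretraining setup: the AdamW optimizer, a linear warmup/decay learning-rate schedule, and the training loop.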

```python
import torch
import torch.nn as nn
import tqdm
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from modules import BertForPreTraining

class BERTTrainer:
    """
    BERTTrainer pretrains a BERT model on its two pretraining tasks
    (original paper, Section 3.1):

        1. Masked Language Model: Task #1, Masked LM
        2. Next Sentence Prediction: Task #2, Next Sentence Prediction (NSP)

    Please check README.md for details and a simple example.
    """
    def __init__(
            self,
            bert: BertForPreTraining,
            lr: float = 1e-4,
            weight_decay: float = 0.01,
            warmup_steps: int = 10000,
    ):
        """
        :param bert: BERT model to pretrain
        :param lr: learning rate of the optimizer
        :param weight_decay: AdamW weight decay
        :param warmup_steps: number of linear warmup steps of the learning-rate scheduler
        """
        # Train on the GPU if one is available, otherwise fall back to CPU
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.bert = bert
        self.model = bert.to(self.device)

        # Multi-GPU training with nn.DataParallel when more than one GPU is visible (disabled)
        # if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        #     cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        #     device_ids = list(map(int, cuda_visible_devices.split(','))) if cuda_visible_devices else []
        #     print("Using %d GPUS for BERT" % torch.cuda.device_count())
        #     self.model = nn.DataParallel(self.model, device_ids=device_ids )

        self.optimizer = AdamW(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        # Linear warmup, then linear decay to 0. The scheduler is stepped once per batch,
        # so the total number of steps is (batches per epoch) * (number of epochs).
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=dataloader_size * config.PretrainingConfig.n_epoch,
        )

        # No separate criterion is needed: BertForPreTraining computes the
        # MLM and NSP losses itself when labels are passed to forward().

        print("Total Parameters:", sum(p.nelement() for p in self.model.parameters()))

    def train(self, epoch, train_dataloader):
        self.iteration(epoch, train_dataloader)

    def iteration(self, epoch, data_loader):
        """
        Run one training pass over data_loader, with a backward pass and
        an optimizer step after every batch.

        :param epoch: current epoch index
        :param data_loader: iterable of batches (e.g. a torch.utils.data.DataLoader)
        """
        self.model.train()

        # Progress bar; total is passed explicitly because a chained iterator has no len()
        data_iter = tqdm.tqdm(data_loader,
                              desc="EP_train:%d" % epoch,
                              total=dataloader_size,
                              bar_format="{l_bar}{r_bar}")

        loss_sum = 0.0
        total_correct = 0
        total_element = 0
        n_steps = 0

        for data in data_iter:
            # Move the batch to the training device
            input_ids = data.input_ids.to(self.device)
            attention_mask = data.attention_mask.to(self.device)
            token_type_ids = data.token_type_ids.to(self.device)
            labels = data.labels.to(self.device)
            is_next = data.is_next.to(self.device)

            # Call the module rather than .forward() so that registered hooks run
            total_loss, prediction_scores, seq_relationship_score = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=labels,
                next_sentence_label=is_next,
            )

            self.optimizer.zero_grad()
            total_loss.backward()
            self.optimizer.step()
            self.scheduler.step()

            # Next sentence prediction accuracy
            correct = seq_relationship_score.argmax(dim=-1).eq(is_next.squeeze(-1)).sum().item()

            loss_sum += total_loss.item()
            total_correct += correct
            total_element += is_next.nelement()
            n_steps += 1

        # Average over the batches actually processed
        print("EP%d_train, avg_loss=" % epoch, loss_sum / max(n_steps, 1),
              "NSP_acc=", total_correct * 100.0 / max(total_element, 1))

    def save(self, epoch, file_path):
        """
        Save the current BERT model to disk.

        :param epoch: current epoch number
        :param file_path: output path prefix; the model is written to file_path + "_ep%d.pth" % epoch
        :return: final output path
        """
        output_path = str(file_path) + "_ep%d.pth" % epoch
        # Save the whole module from CPU memory, then move it back to the training device
        torch.save(self.bert.cpu(), output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path

trainer = BERTTrainer(
    bert,
    lr=config.PretrainingConfig.lr,
    weight_decay=config.PretrainingConfig.weight_decay,
    warmup_steps=config.PretrainingConfig.warmup_steps,
)
```

```
Total Parameters: 110106428
```
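`get_linear_schedule_with_warmup` scales the base learning rate by a multiplier that rises linearly from 0 to 1 over `num_warmup_steps` and then decays linearly to 0 at `num_training_steps`. The plain-Python sketch below reproduces that multiplier for illustration (the function name `lr_multiplier` is hypothetical). Because `scheduler.step()` runs once per batch, the sizes printed earlier give 6,340,800 scheduler steps per epoch, so `num_training_steps` is 6,340,800 × `n_epoch`; the example assumes `n_epoch = 1` and the trainer's default `lr` and `warmup_steps`.

```python
def lr_multiplier(step, num_warmup_steps, num_training_steps):
    """Linear warmup to 1.0, then linear decay to 0.0 (clamped at 0)."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))

base_lr = 1e-4                       # default lr of BERTTrainer
warmup, total = 10_000, 6_340_800    # default warmup_steps; one epoch of steps (assumes n_epoch = 1)
for step in (0, 5_000, 10_000, 3_175_400, 6_340_800):
    print(step, base_lr * lr_multiplier(step, warmup, total))
```

With these numbers the warmup covers only about 0.16% of one epoch, so almost all of training happens on the linear decay.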


### 3.2 Train loop
A simple test run of the pretraining loop.

```python
print("Training Start")
for epoch in range(config.PretrainingConfig.n_epoch):
    # Build a fresh chained iterator every epoch: itertools.chain is exhausted after one pass
    train_dl = itertools.chain(bookcorpus_dl, wikipedia_dl)
    trainer.train(epoch, train_dl)
    if epoch % config.PretrainingConfig.checkpoint_freq == 0:
        trainer.save(epoch, config.trained_path)
```
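`total_loss` in the loop is returned by `BertForPreTraining` itself. Assuming the model follows the common convention (used, for example, by Hugging Face's `BertForPreTraining`) of adding the mean masked-LM cross-entropy over predicted positions to the NSP cross-entropy, the computation for a single example looks like the plain-Python sketch below. `IGNORE_INDEX = -100` is an assumed label for positions that were not selected for prediction; the project's own convention may differ.

```python
import math

IGNORE_INDEX = -100  # assumed label for positions that were not selected for prediction

def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed stably with the log-sum-exp trick."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]

def pretraining_loss(mlm_logits, mlm_labels, nsp_logits, nsp_label):
    """Mean MLM cross-entropy over predicted positions plus NSP cross-entropy."""
    mlm_losses = [cross_entropy(row, label)
                  for row, label in zip(mlm_logits, mlm_labels)
                  if label != IGNORE_INDEX]
    mlm_loss = sum(mlm_losses) / len(mlm_losses) if mlm_losses else 0.0
    return mlm_loss + cross_entropy(nsp_logits, nsp_label)

# Uniform logits over a toy 4-token vocabulary: each MLM term is log 4, NSP is log 2.
loss = pretraining_loss([[0.0] * 4] * 3, [IGNORE_INDEX, 2, 1], [0.0, 0.0], 1)
print(loss, math.log(8))
```

Under this convention, a model that outputs near-uniform predictions starts at roughly log(vocab_size) + log 2; with a 30,522-token vocabulary that is about 11.0, a useful reference point for the first `avg_loss` values.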